-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathtflex_tpu_topology.py
74 lines (52 loc) · 2.38 KB
/
tflex_tpu_topology.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import os
import json
import base64
from tensorflow.python.tpu import tpu as tpu_ops
from tensorflow.compiler.tf2xla.python import xla
from tensorflow.compiler.tf2xla.ops import gen_xla_ops
from tensorflow.python.tpu import tpu_strategy_util
from tensorflow.python.tpu import device_assignment as device_assignment_lib
from tensorflow.python.tpu import topology as topology_lib
from tensorflow.contrib.cluster_resolver import TPUClusterResolver as BaseTPUClusterResolver
from tensorflow.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
# Name of the JSON file (in the current working directory) that persists
# base64-encoded serialized TPU topologies between runs, keyed by TPU name.
_TOPOLOGY_CACHE_FILENAME = '.tpu_topology_cache.json'
class Context:
    """Empty attribute bag: a mutable module-level namespace for shared state."""
# Module-level singleton holding shared TPU state.  The globals() guard keeps
# the same Context object (and any attributes callers attached to it) alive
# across a module reload.
# NOTE(review): the scrape stripped indentation, so the original guard scope is
# ambiguous — this assumes only the Context construction was guarded; the
# trailing `api.topology = cached_topology()` at the bottom of the module
# re-derives the topology on reload either way.
if 'api' not in globals():
    api = Context()
api.topology = None       # topology_lib.Topology once discovered, else None
api.topology_cache = {}   # TPU name -> base64-encoded serialized topology
try:
    with open(_TOPOLOGY_CACHE_FILENAME, 'r') as f:
        api.topology_cache = json.load(f)
except FileNotFoundError:
    # No cache written yet; start with an empty one.
    pass
except json.JSONDecodeError:
    # Fix: a corrupt/truncated cache file used to crash the module at import
    # time.  The cache is best-effort, so discard it; it is rewritten on the
    # next successful get_topology().
    api.topology_cache = {}
def cached_topology(name=None):
    """Return the cached TPU topology for `name`, or None on a cache miss.

    Args:
      name: TPU name to look up; defaults to the TPU_NAME environment
        variable (empty string when unset).

    Returns:
      A `topology_lib.Topology` deserialized from the base64-encoded cache
      entry, or None when no entry exists.
    """
    if name is None:
        name = os.environ.get('TPU_NAME', '')
    encoded = api.topology_cache.get(name)
    if encoded is None:
        return None
    return topology_lib.Topology(serialized=base64.b64decode(encoded))
def get_cluster_resolver(cluster_resolver=None):
    """Return `cluster_resolver` unchanged, or build one from $TPU_NAME.

    Args:
      cluster_resolver: an existing resolver to pass through, or None to
        construct a BaseTPUClusterResolver for the TPU named by the TPU_NAME
        environment variable (raises KeyError when TPU_NAME is unset).
    """
    if cluster_resolver is not None:
        return cluster_resolver
    return BaseTPUClusterResolver(os.environ['TPU_NAME'])
def get_topology(cluster_resolver=None):
    """Return the TPU topology, initializing the TPU system on a cache miss.

    On a miss, initializes the TPU system via the resolver, stores the
    base64-encoded serialized topology in `api.topology_cache` under
    $TPU_NAME, and persists the cache to disk.

    Args:
      cluster_resolver: optional resolver forwarded to get_cluster_resolver.

    Returns:
      The `topology_lib.Topology`, also stored on `api.topology`.
    """
    api.topology = cached_topology()
    if api.topology is None:
        resolver = get_cluster_resolver(cluster_resolver)
        api.topology = tpu_strategy_util.initialize_tpu_system(resolver)
        encoded = base64.b64encode(api.topology.serialized()).decode('utf8')
        api.topology_cache[os.environ['TPU_NAME']] = encoded
        with open(_TOPOLOGY_CACHE_FILENAME, 'w') as f:
            json.dump(api.topology_cache, f)
    return api.topology
def get_task_and_cores_to_replicas():
    """Map (task, core) pairs to replica ids for the current topology.

    Requires `api.topology` to be populated (e.g. via get_topology()).
    Delegates to a private TensorFlow helper.
    """
    coords = api.topology.device_coordinates
    return device_assignment_lib._compute_task_and_cores_to_replicas(
        coords, api.topology)
def get_core_assignment(*core_ids):
    """Build a DeviceAssignment pinning one replica to each requested core.

    Args:
      *core_ids: integer indices into task 0's device-coordinate list of the
        current topology.

    Returns:
      A `device_assignment_lib.DeviceAssignment` with one single-core replica
      per id, in the given order.
    """
    # Fix: the original called get_topology() twice — once for the assignment
    # and once inside the comprehension — redundantly repeating the cache
    # lookup (env read + deserialization).  Fetch it once.
    topology = get_topology()
    coords = [[topology.device_coordinates[0][i]] for i in core_ids]
    return device_assignment_lib.DeviceAssignment(topology, coords)
def get_metadata(cluster_resolver=None):
    """Query TPU system metadata (including topology) from the TPU master.

    Args:
      cluster_resolver: optional resolver forwarded to get_cluster_resolver.

    Returns:
      The metadata object produced by TensorFlow's private
      `_query_tpu_system_metadata` helper.
    """
    resolver = get_cluster_resolver(cluster_resolver)
    master = resolver.get_master()
    cluster_def = resolver.cluster_spec().as_cluster_def()
    return tpu_system_metadata_lib._query_tpu_system_metadata(
        master, cluster_def=cluster_def, query_topology=True)
# Eagerly populate api.topology from the on-disk cache at import time;
# stays None when no cache entry exists for $TPU_NAME (get_topology() then
# initializes the TPU system on first use).
api.topology = cached_topology()