-
Notifications
You must be signed in to change notification settings - Fork 155
/
Copy pathworker.py
292 lines (251 loc) · 11.1 KB
/
worker.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
r""""Contains definitions of the methods used by the _BaseDataLoaderIter workers.
These **needs** to be in global scope since Py2 doesn't support serializing
static methods.
"""
import queue
import random
from dataclasses import dataclass
from typing import Any, Dict, Optional, TypeVar, Union
import torch
from torch._utils import ExceptionWrapper
from torch.utils.data._utils import HAS_NUMPY, MP_STATUS_CHECK_INTERVAL, signal_handling
from torch.utils.data._utils.worker import (
_generate_state,
_IterableDatasetStopIteration,
_ResumeIteration,
ManagerWatchdog,
WorkerInfo,
)
from .incremental_state import (
_DATASET_ITER_STATE,
_DATASET_STATE,
_FETCHER_ENDED,
_FETCHER_STATE,
_IncrementalWorkerState,
_WORKER_ID,
)
from .stateful import Stateful
T = TypeVar("T")
def try_to_serialize(obj: Any) -> Union[dict, None]:
if isinstance(obj, Stateful):
obj_state = obj.state_dict()
else:
obj_state = None
return obj_state
def try_to_deserialize(obj: T, state_dict: dict) -> T:
if isinstance(obj, Stateful):
obj.load_state_dict(state_dict)
return obj # type: ignore[return-value]
return obj
@dataclass(frozen=True)
class _AckStartup:
"""Dummy class used to ack startup and return state at time 0"""
worker_id: int
initial_state: Optional[Union[Dict[str, Any], ExceptionWrapper]]
is_delta: bool = False
def _worker_loop(
dataset_kind,
dataset,
index_queue,
data_queue,
done_event,
auto_collation,
collate_fn,
drop_last,
base_seed,
init_fn,
worker_id,
num_workers,
persistent_workers,
shared_seed,
worker_state,
):
# See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
# logic of this function.
try:
# Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
# module's handlers are executed after Python returns from C low-level
# handlers, likely when the same fatal signal had already happened
# again.
# https://docs.python.org/3/library/signal.html#execution-of-python-signal-handlers
signal_handling._set_worker_signal_handlers()
torch.set_num_threads(1)
seed = base_seed + worker_id
random.seed(seed)
torch.manual_seed(seed)
if HAS_NUMPY:
np_seed = _generate_state(base_seed, worker_id)
import numpy as np
np.random.seed(np_seed)
from torch.utils.data import IterDataPipe
from torch.utils.data.graph_settings import apply_random_seed
shared_rng = torch.Generator()
if isinstance(dataset, IterDataPipe):
assert shared_seed is not None
shared_rng.manual_seed(shared_seed)
dataset = apply_random_seed(dataset, shared_rng)
torch.utils.data._utils.worker._worker_info = WorkerInfo(
id=worker_id, num_workers=num_workers, seed=seed, dataset=dataset
)
from torch.utils.data import _DatasetKind
# See NOTE [ Incremental worker state ]
incremental_worker_state: _IncrementalWorkerState
init_exception = None
fetcher = None
initial_state = None
is_delta = False
try:
if init_fn is not None:
init_fn(worker_id)
if worker_state is None:
fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last)
initial_state = _make_state_dict(worker_id, dataset_kind, fetcher, dataset)
incremental_worker_state = _IncrementalWorkerState(initial_state)
else:
# Always restore in this order:
# 1. try to restore dataset state
# 2. generate dataset iterator
# 3. try to restore iterator state
incremental_worker_state = _IncrementalWorkerState(worker_state)
if worker_state[_DATASET_STATE] is not None:
dataset = try_to_deserialize(dataset, worker_state[_DATASET_STATE])
fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last)
if worker_state[_FETCHER_STATE] is not None:
if dataset_kind == _DatasetKind.Iterable:
if worker_state[_FETCHER_STATE][_DATASET_ITER_STATE] is not None:
dataset_iter = try_to_deserialize(
fetcher.dataset_iter,
worker_state[_FETCHER_STATE][_DATASET_ITER_STATE],
)
if dataset_iter is not None:
fetcher.dataset_iter = dataset_iter
# We always force fetcher to request at least one batch even if
# we know it will lead to immediate stop iteration
fetcher.ended = False
iteration_end = False
initial_state = incremental_worker_state.generate_delta(
_make_state_dict(worker_id, dataset_kind, fetcher, dataset)
)
is_delta = True
del worker_state
except Exception:
init_exception = ExceptionWrapper(where=f"in DataLoader worker process {worker_id}")
# When using Iterable mode, some worker can exit earlier than others due
# to the IterableDataset behaving differently for different workers.
# When such things happen, an `_IterableDatasetStopIteration` object is
# sent over to the main process with the ID of this worker, so that the
# main process won't send more tasks to this worker, and will send
# `None` to this worker to properly exit it.
#
# Note that we cannot set `done_event` from a worker as it is shared
# among all processes. Instead, we set the `iteration_end` flag to
# signify that the iterator is exhausted. When either `done_event` or
# `iteration_end` is set, we skip all processing step and just wait for
# `None`.
iteration_end = False
watchdog = ManagerWatchdog()
while watchdog.is_alive():
try:
r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
except queue.Empty:
continue
if isinstance(r, _AckStartup):
# Send ack and initial state to the main process
data_queue.put(
(
r,
_AckStartup(
worker_id=worker_id, initial_state=init_exception or initial_state, is_delta=is_delta
),
)
)
del initial_state
del is_delta
continue
elif isinstance(r, _ResumeIteration):
iteration_end = False
if isinstance(dataset, IterDataPipe):
assert r.seed is not None
shared_rng.manual_seed(r.seed)
dataset = apply_random_seed(dataset, shared_rng)
try:
# Recreate the fetcher for worker-reuse policy
fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last)
# see NOTE [ Incremental Worker State ]
initial_state = _make_state_dict(worker_id, dataset_kind, fetcher, dataset)
incremental_worker_state = _IncrementalWorkerState(initial_state)
except Exception:
init_exception = ExceptionWrapper(where=f"in DataLoader worker process {worker_id}")
# Acknowledge the main process
data_queue.put((r, _AckStartup(worker_id=worker_id, initial_state=init_exception or initial_state)))
del initial_state
continue
elif r is None:
# Received the final signal
assert done_event.is_set() or iteration_end
break
elif done_event.is_set() or iteration_end:
# `done_event` is set. But I haven't received the final signal
# (None) yet. I will keep continuing until get it, and skip the
# processing steps.
continue
idx, (index, snapshot) = r
data: Union[_IterableDatasetStopIteration, ExceptionWrapper]
delta_state_dict = None
if init_exception is not None:
data = init_exception
init_exception = None
else:
try:
try:
data = fetcher.fetch(index) # type: ignore[union-attr]
except StopIteration:
if not dataset_kind == _DatasetKind.Iterable:
raise
data = _IterableDatasetStopIteration(worker_id)
# Set `iteration_end`
# (1) to save future `next(...)` calls, and
# (2) to avoid sending multiple `_IterableDatasetStopIteration`s.
iteration_end = True
if snapshot or iteration_end:
# Generate incremental diff from prev_state_dict and current_state_dict
state_dict = _make_state_dict(worker_id, dataset_kind, fetcher, dataset)
delta_state_dict = incremental_worker_state.generate_delta(state_dict)
del state_dict
except Exception:
# It is important that we don't store exc_info in a variable.
# `ExceptionWrapper` does the correct thing.
# See NOTE [ Python Traceback Reference Cycle Problem ]
data = ExceptionWrapper(where=f"in DataLoader worker process {worker_id}")
data_queue.put((idx, (data, worker_id, delta_state_dict)))
del data, idx, index, r, delta_state_dict # save memory
except KeyboardInterrupt:
# Main process will raise KeyboardInterrupt anyways.
pass
if done_event.is_set():
data_queue.cancel_join_thread()
data_queue.close()
def _make_state_dict(worker_id, dataset_kind, fetcher, dataset) -> Dict[str, Any]:
from torch.utils.data import _DatasetKind
if dataset_kind == _DatasetKind.Iterable:
fetcher_state = {
_DATASET_ITER_STATE: try_to_serialize(fetcher.dataset_iter),
_FETCHER_ENDED: fetcher.ended,
}
dataset_state = None
if fetcher.dataset_iter is not fetcher.dataset:
dataset_state = try_to_serialize(fetcher.dataset)
else:
fetcher_state = None
# Pick up any user-defined dataset state
dataset_state = try_to_serialize(dataset)
return {
_WORKER_ID: worker_id,
_FETCHER_STATE: fetcher_state,
_DATASET_STATE: dataset_state,
}