-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathasyncio_buffered_pipeline.py
87 lines (74 loc) · 2.91 KB
/
asyncio_buffered_pipeline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import asyncio
import collections
def buffered_pipeline():
    """Return a factory that inserts bounded read-ahead buffers into an async pipeline.

    The returned function is ``buffer_iterable(iterable, buffer_size=1)``.
    Each call wraps the async iterable ``iterable`` in an async generator that
    yields the same values in the same order. Meanwhile, a background task
    (created with ``asyncio.create_task``) pulls values from upstream while the
    consumer is busy, holding at most ``buffer_size`` values ahead of it.

    All buffers created from one returned function share a single list of
    background tasks. If any buffer stage raises, or is closed early, every
    task still in that list is cancelled and awaited before the exception is
    re-raised. Because of this, sharing one factory between unrelated
    pipelines lets an error in one cancel the others.
    """
    # Background producer tasks started by buffers from this factory.
    tasks = []

    def queue(size):
        # The regular asyncio.Queue doesn't have a function to wait for space
        # in the queue without also immediately putting an item into it, which
        # would mean effective minimum buffer_size is 2: an item in the queue
        # and in memory waiting to put into it. To allow a buffer_size of 1,
        # we need to check there is space _before_ fetching the item from
        # upstream. This requires a custom queue implementation.
        #
        # We also can guarantee there will be at most one getter and putter at
        # any one time, and that _put won't be called until there is space in
        # the queue, so we can have much simpler code than asyncio.Queue
        _queue = collections.deque()
        at_least_one_in_queue = asyncio.Event()
        until_space = asyncio.Event()
        until_space.set()

        async def _space():
            # Wait until the queue holds fewer than `size` items.
            await until_space.wait()

        async def _get():
            nonlocal at_least_one_in_queue
            await at_least_one_in_queue.wait()
            value = _queue.popleft()
            until_space.set()
            if not _queue:
                # Swap in a fresh, unset event so the next _get waits for a put.
                at_least_one_in_queue = asyncio.Event()
            return value

        def _put(item):
            nonlocal until_space
            _queue.append(item)
            at_least_one_in_queue.set()
            if len(_queue) >= size:
                # Queue is full: the next _space() call blocks until a _get.
                until_space = asyncio.Event()

        return _space, _get, _put

    async def _buffer_iterable(iterable, buffer_size=1):
        """Yield the values of async ``iterable``, reading up to ``buffer_size`` ahead."""
        nonlocal tasks
        queue_space, queue_get, queue_put = queue(buffer_size)
        iterator = iterable.__aiter__()

        async def _iterate():
            # Producer: only fetch from upstream once there is space, so at
            # most buffer_size values are held ahead of the consumer.
            try:
                while True:
                    await queue_space()
                    value = await iterator.__anext__()
                    queue_put((None, value))
                    value = None  # So value can be garbage collected
            except BaseException as exception:
                # Forward every way of stopping (StopAsyncIteration at the
                # normal end, errors, cancellation) to the consumer. This
                # final put deliberately skips the space check.
                queue_put((exception, None))

        task = asyncio.create_task(_iterate())
        tasks.append(task)

        try:
            # Once running, the producer always ends by queueing an exception,
            # so this loop exits only via StopAsyncIteration or another error.
            while True:
                exception, value = await queue_get()
                if exception is not None:
                    raise exception from None
                yield value
                value = None  # So value can be garbage collected
        except StopAsyncIteration:
            # Upstream is exhausted and its producer has finished; stop
            # tracking it so a long-lived factory doesn't accumulate
            # completed tasks. It may already be gone if another stage's
            # cleanup reset `tasks`.
            if task in tasks:
                tasks.remove(task)
        except BaseException:
            # An error from upstream, or GeneratorExit when the consumer closes
            # this generator early: cancel every tracked producer, wait for
            # them to finish, then re-raise the original exception.
            for other in tasks:
                other.cancel()
            all_tasks = tasks
            tasks = []
            for other in all_tasks:
                try:
                    await other
                except asyncio.CancelledError:
                    pass
            raise

    return _buffer_iterable