62 changes: 62 additions & 0 deletions examples/map_failure_modes.py
@@ -0,0 +1,62 @@
import asyncio
import sys
from itertools import count
from streamz import Stream


async def flaky_async(x, from_where):
    return flaky_sync(x, from_where)


def flaky_sync(x, from_where):
    if x % 5 == 4:
        raise ValueError(f"I flaked out on {x} for {from_where}")
    return x


def make_counter(name):
    return Stream.from_iterable(count(), asynchronous=True, stream_name=name)


async def main(run_flags):
    async_non_stop_source = make_counter("async not stopping")
    s_async = async_non_stop_source.rate_limit("500ms").map_async(flaky_async, async_non_stop_source)
    s_async.sink(print, async_non_stop_source.name)

    sync_source = make_counter("sync")
    s_sync = sync_source.rate_limit("500ms").map(flaky_sync, sync_source)
    s_sync.sink(print, sync_source.name)

    async_stopping_source = make_counter("async stopping")
    s_async_stop = async_stopping_source.rate_limit("500ms").map_async(flaky_async, async_stopping_source, stop_on_exception=True)
    s_async_stop.sink(print, async_stopping_source.name)

    if run_flags[0]:
        async_non_stop_source.start()
    if run_flags[1]:
        sync_source.start()
    if run_flags[2]:
        async_stopping_source.start()

    print(f"{async_non_stop_source.started=}, {sync_source.started=}, {async_stopping_source.started=}")
    await asyncio.sleep(3)
    print(f"{async_non_stop_source.stopped=}, {sync_source.stopped=}, {async_stopping_source.stopped=}")

    if run_flags[2]:
        print()
        print(f"Restarting {async_stopping_source}")
        async_stopping_source.start()
        print()
        await asyncio.sleep(2)
        print(f"{async_non_stop_source.stopped=}, {sync_source.stopped=}, {async_stopping_source.stopped=}")


if __name__ == "__main__":
    try:
        if len(sys.argv) > 1:
            flags = [char == "T" for char in sys.argv[1].ljust(3)]  # pad so short flag strings don't IndexError
        else:
            flags = [True, True, True]
        asyncio.run(main(flags))
    except KeyboardInterrupt:
        pass
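The example is driven by a string of T/F flags, one per source, in the order the sources are started above. A hypothetical driver session (assuming the script is importable as map_failure_modes; this snippet is not part of the PR):

import asyncio
from map_failure_modes import main

# Enable only the third, "async stopping" source; equivalent to
# `python examples/map_failure_modes.py FFT` on the command line.
asyncio.run(main([False, False, True]))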
55 changes: 46 additions & 9 deletions streamz/core.py
@@ -1,12 +1,13 @@
 import asyncio
+import concurrent.futures
 from collections import deque, defaultdict
 from datetime import timedelta
 from itertools import chain
 import functools
 import logging
 import threading
 from time import time
-from typing import Any, Callable, Hashable, Union
+from typing import Any, Callable, Coroutine, Hashable, Tuple, Union, overload
 import weakref
 
 import toolz
@@ -730,6 +731,8 @@ class map_async(Stream):
         The arguments to pass to the function.
     parallelism:
         The maximum number of parallel Tasks for evaluating func, default value is 1
+    stop_on_exception:
+        Whether the stream should stop when the mapped func raises an exception. Default value is False.
     **kwargs:
         Keyword arguments to pass to func
@@ -749,38 +752,72 @@ class map_async(Stream):
     6
     8
     """
-    def __init__(self, upstream, func, *args, parallelism=1, **kwargs):
+    def __init__(self, upstream, func, *args, parallelism=1, stop_on_exception=False, **kwargs):
         self.func = func
         stream_name = kwargs.pop('stream_name', None)
         self.kwargs = kwargs
         self.args = args
+        self.stop_on_exception = stop_on_exception
         self.work_queue = asyncio.Queue(maxsize=parallelism)
 
         Stream.__init__(self, upstream, stream_name=stream_name, ensure_io_loop=True)
-        self.work_task = self._create_task(self.work_callback())
+        self.work_task = None
+
+    def _create_work_task(self) -> Tuple[asyncio.Event, asyncio.Task[None]]:
+        stop_work = asyncio.Event()
+        work_task = self._create_task(self.work_callback(stop_work))
+        return stop_work, work_task
+
+    def start(self):
+        if self.work_task:
+            stop_work, _ = self.work_task
+            stop_work.set()
+        self.work_task = self._create_work_task()
+        super().start()
+
+    def stop(self):
+        stop_work, _ = self.work_task
+        stop_work.set()
+        self.work_task = None
+        super().stop()
 
     def update(self, x, who=None, metadata=None):
+        if not self.work_task:
+            self.work_task = self._create_work_task()
         return self._create_task(self._insert_job(x, metadata))
 
+    @overload
+    def _create_task(self, coro: asyncio.Future) -> asyncio.Future:
+        ...
+
+    @overload
+    def _create_task(self, coro: concurrent.futures.Future) -> concurrent.futures.Future:
+        ...
+
+    @overload
+    def _create_task(self, coro: Coroutine) -> asyncio.Task:
+        ...
+
     def _create_task(self, coro):
         if gen.is_future(coro):
             return coro
         return self.loop.asyncio_loop.create_task(coro)
 
-    async def work_callback(self):
-        while True:
+    async def work_callback(self, stop_work: asyncio.Event):
+        while not stop_work.is_set():
+            task, metadata = await self.work_queue.get()
+            self.work_queue.task_done()
             try:
-                task, metadata = await self.work_queue.get()
-                self.work_queue.task_done()
                 result = await task
             except Exception as e:
                 logger.exception(e)
-                raise
-            results = self._emit(result, metadata=metadata)
-            if results:
-                await asyncio.gather(*results)
-            self._release_refs(metadata)
+                if self.stop_on_exception:
+                    self.stop()
+            else:
+                results = self._emit(result, metadata=metadata)
+                if results:
+                    await asyncio.gather(*results)
+            self._release_refs(metadata)
Review discussion on the final self._release_refs(metadata) line:

Contributor Author: map_async calls _retain_refs during the insert into the work queue, so making sure that we call _release_refs even during an exception seems better.

Member: Correct; probably the assumption is that the exception simply stops the whole pipeline, but we can do better. Nodes that filter in/out on exceptions would be reasonable.

Contributor Author: I actually had this idea for the next improvement. It would be better for map/starmap/map_async to flow Exceptions downstream (probably paired with the offending input) so that the graph can fork successes one way and failures to a logging/recovery flow.
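A rough sketch of that follow-up idea, using only nodes that already exist (map_async, filter, sink); the (value, error) pair convention and the helper names are invented here, nothing below is in the PR:

import asyncio
from streamz import Stream

async def work(x):
    if x % 5 == 4:
        raise ValueError(f"flaked on {x}")
    return 2 * x

async def safe_apply(x):
    # Success -> (result, None); failure -> (offending input, exception).
    try:
        return await work(x), None
    except Exception as e:
        return x, e

async def demo():
    source = Stream.from_iterable(range(10), asynchronous=True)
    pairs = source.map_async(safe_apply)
    # Fork the graph: successes go one way, failures to a recovery/logging flow.
    pairs.filter(lambda p: p[1] is None).sink(lambda p: print("ok:", p[0]))
    pairs.filter(lambda p: p[1] is not None).sink(lambda p: print("failed:", p))
    source.start()
    await asyncio.sleep(0.5)

asyncio.run(demo())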


     async def _wait_for_work_slot(self):
         while self.work_queue.full():
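For readers skimming the diff, the restart machinery above reduces to an asyncio.Event handshake between start()/stop() and the worker loop. A standalone sketch of the pattern (simplified, not the streamz code; the real work_callback also emits results and releases refs):

import asyncio

async def worker(stop_work: asyncio.Event, queue: asyncio.Queue):
    # Mirrors work_callback: the event is checked at the top of each iteration.
    while not stop_work.is_set():
        item = await queue.get()
        queue.task_done()
        print("worked on", item)

async def demo():
    queue = asyncio.Queue()

    # start(): pair a fresh Event with a fresh worker task.
    stop_work = asyncio.Event()
    task = asyncio.create_task(worker(stop_work, queue))

    await queue.put("a")
    await queue.join()

    # stop(): signal the worker. It only notices at the next loop-top check,
    # i.e. once its pending queue.get() returns, which is why restarting
    # simply abandons the old (event, task) pair and creates a new one.
    stop_work.set()
    await queue.put("wake-up")  # unblock the pending get()
    await task

asyncio.run(demo())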
2 changes: 2 additions & 0 deletions streamz/sources.py
@@ -786,6 +786,8 @@ async def run(self):
             if self.stopped:
                 break
             await asyncio.gather(*self._emit(x))
+            if self.stopped:
+                break
Comment on lines +789 to +790:

Contributor Author: By not checking self.stopped after coming back from the gather, the source over-consumes the underlying iterable and loses an element.

         self.stopped = True
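The effect described in the comment is easy to reproduce with a plain iterator (an illustrative sketch, not streamz code): without the second check, the for loop pulls one more element before noticing the stop, and that element is silently lost.

from itertools import count

it = count()
stopped = False

def emit(x):
    global stopped
    if x == 1:
        stopped = True  # a stop request lands while element 1 is being emitted
    return x

seen = []
for x in it:
    if stopped:
        break
    seen.append(emit(x))
    if stopped:
        break  # the fix: bail out before the loop advances the iterator again

assert seen == [0, 1]
assert next(it) == 2  # 2 is still available for a restart, not silently consumed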


22 changes: 22 additions & 0 deletions streamz/tests/test_core.py
@@ -151,6 +151,28 @@ def fail_func():
     assert (time() - start) == pytest.approx(0.1, abs=4e-3)
 
 
+@pytest.mark.asyncio
+async def test_map_async_restart():
+    async def flake_out(x):
+        if x == 2:
+            raise RuntimeError("I fail on 2.")
+        if x > 4:
+            raise RuntimeError("I fail on > 4.")
+        return x
+
+    source = Stream.from_iterable(itertools.count())
+    mapped = source.map_async(flake_out, stop_on_exception=True)
+    results = mapped.sink_to_list()
+    source.start()
+
+    # The failure on 2 stops the stream; stop() clears mapped.work_task.
+    await await_for(lambda: results == [0, 1], 1)
+    await await_for(lambda: not mapped.work_task, 1)
+
+    source.start()
+
+    # 2 failed but was consumed; the restart resumes at 3 and stops again on 5.
+    await await_for(lambda: results == [0, 1, 3, 4], 1)
+
+
 @pytest.mark.asyncio
 async def test_map_async():
     @gen.coroutine