diff --git a/examples/map_failure_modes.py b/examples/map_failure_modes.py
new file mode 100644
index 00000000..4fd11d10
--- /dev/null
+++ b/examples/map_failure_modes.py
@@ -0,0 +1,79 @@
+"""Demonstrate how ``map`` and ``map_async`` behave when the mapped function raises.
+
+Usage: ``python map_failure_modes.py [FLAGS]``. FLAGS holds up to three
+characters, one per source in the order: async (not stopping), sync, async
+(stopping). ``T`` starts that source; any other or missing character leaves it
+idle. Without an argument all three sources are started.
+"""
+import asyncio
+import sys
+from itertools import count
+from streamz import Stream
+
+
+async def flaky_async(x, from_where):
+    """Async wrapper around ``flaky_sync``."""
+    return flaky_sync(x, from_where)
+
+
+def flaky_sync(x, from_where):
+    """Return ``x``, raising ``ValueError`` whenever ``x % 5 == 4``."""
+    if x % 5 == 4:
+        raise ValueError(f"I flaked out on {x} for {from_where}")
+    return x
+
+
+def make_counter(name):
+    """Create a source named ``name`` that iterates over ``itertools.count()``."""
+    return Stream.from_iterable(count(), asynchronous=True, stream_name=name)
+
+
+async def main(run_flags):
+    """Build three pipelines and start those whose ``run_flags`` entry is true.
+
+    ``run_flags`` needs at least three entries, ordered: async (not stopping),
+    sync, async (stopping).
+    """
+    async_non_stop_source = make_counter("async not stopping")
+    s_async = async_non_stop_source.rate_limit("500ms").map_async(flaky_async, async_non_stop_source)
+    s_async.sink(print, async_non_stop_source.name)
+
+    sync_source = make_counter("sync")
+    s_sync = sync_source.rate_limit("500ms").map(flaky_sync, sync_source)
+    s_sync.sink(print, sync_source.name)
+
+    async_stopping_source = make_counter("async stopping")
+    s_async_stop = async_stopping_source.rate_limit("500ms").map_async(flaky_async, async_stopping_source, stop_on_exception=True)
+    s_async_stop.sink(print, async_stopping_source.name)
+
+    if run_flags[0]:
+        async_non_stop_source.start()
+    if run_flags[1]:
+        sync_source.start()
+    if run_flags[2]:
+        async_stopping_source.start()
+
+    print(f"{async_non_stop_source.started=}, {sync_source.started=}, {async_stopping_source.started=}")
+    await asyncio.sleep(3)
+    print(f"{async_non_stop_source.stopped=}, {sync_source.stopped=}, {async_stopping_source.stopped=}")
+
+    if run_flags[2]:
+        print()
+        print(f"Restarting {async_stopping_source}")
+        async_stopping_source.start()
+        print()
+    await asyncio.sleep(2)
+    print(f"{async_non_stop_source.stopped=}, {sync_source.stopped=}, {async_stopping_source.stopped=}")
+
+
+if __name__ == "__main__":
+    try:
+        if len(sys.argv) > 1:
+            flags = [char == "T" for char in sys.argv[1][:3]]
+            # main() indexes three flags; treat missing ones as "do not run".
+            flags += [False] * (3 - len(flags))
+        else:
+            flags = [True, True, True]
+        asyncio.run(main(flags))
+    except KeyboardInterrupt:
+        pass
diff --git a/streamz/core.py b/streamz/core.py
index 00b5ed4c..8fea2813 100644
--- a/streamz/core.py
+++ b/streamz/core.py
@@ -1,4 +1,5 @@
 import asyncio
+import concurrent.futures
 from collections import deque, defaultdict
 from datetime import timedelta
 from itertools import chain
@@ -6,7 +7,7 @@
 import logging
 import threading
 from time import time
-from typing import Any, Callable, Hashable, Union
+from typing import Any, Callable, Coroutine, Hashable, Tuple, Union, overload
 import weakref
 
 import toolz
@@ -730,6 +731,9 @@ class map_async(Stream):
         The arguments to pass to the function.
     parallelism:
         The maximum number of parallel Tasks for evaluating func, default value is 1
+    stop_on_exception:
+        Whether to stop the stream when func raises an exception; if False the exception
+        is only logged and processing continues. Default value is False.
     **kwargs:
         Keyword arguments to pass to func
 
@@ -749,38 +753,89 @@ class map_async(Stream):
     6
     8
     """
-    def __init__(self, upstream, func, *args, parallelism=1, **kwargs):
+    def __init__(self, upstream, func, *args, parallelism=1, stop_on_exception=False, **kwargs):
         self.func = func
         stream_name = kwargs.pop('stream_name', None)
         self.kwargs = kwargs
         self.args = args
+        self.stop_on_exception = stop_on_exception
         self.work_queue = asyncio.Queue(maxsize=parallelism)
 
         Stream.__init__(self, upstream, stream_name=stream_name, ensure_io_loop=True)
-        self.work_task = self._create_task(self.work_callback())
+        # (stop_event, task) of the current worker, or None while no worker is active.
+        self.work_task = None
+
+    # The return annotation is quoted so it is not evaluated at definition time:
+    # asyncio.Task only became subscriptable in Python 3.9.
+    def _create_work_task(self) -> "Tuple[asyncio.Event, asyncio.Task[None]]":
+        """Spawn a worker and return it together with the event that stops it."""
+        stop_work = asyncio.Event()
+        work_task = self._create_task(self.work_callback(stop_work))
+        return stop_work, work_task
+
+    def start(self):
+        """Replace any current worker with a fresh one, then delegate to ``Stream.start``."""
+        if self.work_task:
+            stop_work, _ = self.work_task
+            stop_work.set()
+        self.work_task = self._create_work_task()
+        super().start()
+
+    def stop(self):
+        """Signal the current worker (if any) to finish, then delegate to ``Stream.stop``."""
+        # There is no worker before the first start()/update() or after a
+        # previous stop(), so there is nothing to signal in those cases.
+        if self.work_task:
+            stop_work, _ = self.work_task
+            stop_work.set()
+            self.work_task = None
+        super().stop()
 
     def update(self, x, who=None, metadata=None):
+        # Workers are created lazily, so a stream whose worker was stopped
+        # gets a new one as soon as data arrives again.
+        if not self.work_task:
+            self.work_task = self._create_work_task()
         return self._create_task(self._insert_job(x, metadata))
 
+    @overload
+    def _create_task(self, coro: asyncio.Future) -> asyncio.Future:
+        ...
+
+    @overload
+    def _create_task(self, coro: concurrent.futures.Future) -> concurrent.futures.Future:
+        ...
+
+    @overload
+    def _create_task(self, coro: Coroutine) -> asyncio.Task:
+        ...
+
     def _create_task(self, coro):
         if gen.is_future(coro):
             return coro
         return self.loop.asyncio_loop.create_task(coro)
 
-    async def work_callback(self):
-        while True:
+    async def work_callback(self, stop_work: asyncio.Event):
+        """Process queued jobs until ``stop_work`` is set.
+
+        The event is only checked between jobs, so a worker waiting on an empty
+        queue handles one more job before it exits.
+        """
+        while not stop_work.is_set():
+            task, metadata = await self.work_queue.get()
+            self.work_queue.task_done()
             try:
-                task, metadata = await self.work_queue.get()
-                self.work_queue.task_done()
                 result = await task
             except Exception as e:
                 logger.exception(e)
-                raise
+                if self.stop_on_exception:
+                    self.stop()
             else:
                 results = self._emit(result, metadata=metadata)
                 if results:
                     await asyncio.gather(*results)
-                self._release_refs(metadata)
+            # Release references whether the job succeeded or failed.
+            self._release_refs(metadata)
 
     async def _wait_for_work_slot(self):
         while self.work_queue.full():
diff --git a/streamz/sources.py b/streamz/sources.py
index 777f9181..940a2e9f 100644
--- a/streamz/sources.py
+++ b/streamz/sources.py
@@ -786,6 +786,8 @@ async def run(self):
             if self.stopped:
                 break
             await asyncio.gather(*self._emit(x))
+            if self.stopped:
+                break
         self.stopped = True
 
 
diff --git a/streamz/tests/test_core.py b/streamz/tests/test_core.py
index 9245f2e6..2dcc786c 100644
--- a/streamz/tests/test_core.py
+++ b/streamz/tests/test_core.py
@@ -151,6 +151,29 @@ def fail_func():
     assert (time() - start) == pytest.approx(0.1, abs=4e-3)
 
 
+@pytest.mark.asyncio
+async def test_map_async_restart():
+    """A stream stopped by a failing func resumes when its source is restarted."""
+    async def flake_out(x):
+        if x == 2:
+            raise RuntimeError("I fail on 2.")
+        if x > 4:
+            raise RuntimeError("I fail on > 4.")
+        return x
+
+    source = Stream.from_iterable(itertools.count())
+    mapped = source.map_async(flake_out, stop_on_exception=True)
+    results = mapped.sink_to_list()
+    source.start()
+
+    await await_for(lambda: results == [0, 1], 1)
+    await await_for(lambda: not mapped.work_task, 1)
+
+    source.start()
+
+    await await_for(lambda: results == [0, 1, 3, 4], 1)
+
+
 @pytest.mark.asyncio
 async def test_map_async():
     @gen.coroutine