Skip to content
3 changes: 3 additions & 0 deletions ext/dapr-ext-workflow/dapr/ext/workflow/logger/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,8 @@ def warning(self, msg, *args, **kwargs):
# Forward ``msg`` and any extra positional/keyword arguments unchanged to the
# ``error`` method of the wrapped ``self._logger``.
def error(self, msg, *args, **kwargs):
self._logger.error(msg, *args, **kwargs)

# Forward ``msg`` and any extra positional/keyword arguments unchanged to the
# ``exception`` method of the wrapped ``self._logger``.
# NOTE(review): presumably ``self._logger`` is a standard ``logging.Logger``, in
# which case this is meant to be called from an ``except`` block so the active
# traceback is attached -- confirm against the class constructor.
def exception(self, msg, *args, **kwargs):
self._logger.exception(msg, *args, **kwargs)

# Forward ``msg`` and any extra positional/keyword arguments unchanged to the
# ``critical`` method of the wrapped ``self._logger``.
def critical(self, msg, *args, **kwargs):
self._logger.critical(msg, *args, **kwargs)
110 changes: 99 additions & 11 deletions ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""

import inspect
import time
from functools import wraps
from typing import Optional, Sequence, TypeVar, Union

Expand Down Expand Up @@ -54,8 +55,11 @@ def __init__(
maximum_concurrent_activity_work_items: Optional[int] = None,
maximum_concurrent_orchestration_work_items: Optional[int] = None,
maximum_thread_pool_workers: Optional[int] = None,
worker_ready_timeout: Optional[float] = None,
):
self._logger = Logger('WorkflowRuntime', logger_options)
self._worker_ready_timeout = 30.0 if worker_ready_timeout is None else worker_ready_timeout

metadata = tuple()
if settings.DAPR_API_TOKEN:
metadata = ((DAPR_API_TOKEN_HEADER, settings.DAPR_API_TOKEN),)
Expand Down Expand Up @@ -86,10 +90,20 @@ def register_workflow(self, fn: Workflow, *, name: Optional[str] = None):

def orchestrationWrapper(ctx: task.OrchestrationContext, inp: Optional[TInput] = None):
    """Call the registered workflow function with a Dapr workflow context.

    Wraps ``ctx`` in a ``DaprWorkflowContext`` and invokes ``fn`` with it,
    passing ``inp`` as a second argument only when it is not ``None``.

    Args:
        ctx: Orchestration context supplied by the worker.
        inp: Optional workflow input.

    Returns:
        Whatever ``fn`` returns.

    Raises:
        Exception: Any error raised while building the context or calling
            ``fn`` is logged with the instance id and re-raised unchanged.
    """
    # Read the id up front so it is available to the log message even if
    # building the context fails; fall back to 'unknown' if ctx lacks it.
    instance_id = getattr(ctx, 'instance_id', 'unknown')

    try:
        daprWfContext = DaprWorkflowContext(ctx, self._logger.get_options())
        if inp is None:
            return fn(daprWfContext)
        return fn(daprWfContext, inp)
    except Exception as e:
        # NOTE(review): if ``fn`` is a generator function, calling it only
        # creates the generator; errors raised while it is iterated later do
        # not pass through this handler -- confirm whether that path needs
        # logging as well.
        self._logger.exception(
            f'Workflow execution failed - instance_id: {instance_id}, error: {e}'
        )
        raise

if hasattr(fn, '_workflow_registered'):
# whenever a workflow is registered, it has a _dapr_alternate_name attribute
Expand Down Expand Up @@ -152,10 +166,20 @@ def register_activity(self, fn: Activity, *, name: Optional[str] = None):

def activityWrapper(ctx: task.ActivityContext, inp: Optional[TInput] = None):
    """Call the registered activity function with a workflow activity context.

    Wraps ``ctx`` in a ``WorkflowActivityContext`` and invokes ``fn`` with it,
    passing ``inp`` as a second argument only when it is not ``None``.

    Args:
        ctx: Activity context supplied by the worker.
        inp: Optional activity input.

    Returns:
        Whatever ``fn`` returns.

    Raises:
        Exception: Any error raised while building the context or calling
            ``fn`` is logged with the task id and re-raised unchanged.
    """
    # Read the id up front so it is available to the log message even if
    # building the context fails; fall back to 'unknown' if ctx lacks it.
    activity_id = getattr(ctx, 'task_id', 'unknown')

    try:
        wfActivityContext = WorkflowActivityContext(ctx)
        if inp is None:
            return fn(wfActivityContext)
        return fn(wfActivityContext, inp)
    except Exception as e:
        self._logger.exception(
            f'Activity execution failed - task_id: {activity_id}, error: {e}'
        )
        raise

if hasattr(fn, '_activity_registered'):
# whenever an activity is registered, it has a _dapr_alternate_name attribute
Expand All @@ -174,13 +198,77 @@ def activityWrapper(ctx: task.ActivityContext, inp: Optional[TInput] = None):
)
fn.__dict__['_activity_registered'] = True

def wait_for_worker_ready(self, timeout: float = 30.0) -> bool:
    """Wait for the worker's gRPC stream to become ready to receive work items.

    Polls the worker's ``is_worker_ready()`` method until it returns a truthy
    value or ``timeout`` seconds have passed on the monotonic clock. The
    worker is always probed at least once, so an already-ready worker is
    reported as ready even with ``timeout=0``.

    Args:
        timeout: Maximum time in seconds to wait for the worker to be ready.
            Defaults to 30 seconds.

    Returns:
        True if the worker reported ready within the timeout. False if the
        timeout expired (a warning is logged) or if the worker does not
        expose ``is_worker_ready`` at all.
    """
    if not hasattr(self.__worker, 'is_worker_ready'):
        return False

    poll_interval = 0.1  # 100ms
    # Measure against a monotonic deadline rather than summing the nominal
    # sleep interval: that also counts time spent inside is_worker_ready()
    # and any sleep overshoot, so the wait cannot drift past the timeout.
    deadline = time.monotonic() + timeout

    while True:
        if self.__worker.is_worker_ready():
            return True
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        # Never sleep past the deadline on the final iteration.
        time.sleep(min(poll_interval, remaining))

    self._logger.warning(
        f'WorkflowRuntime worker readiness check timed out after {timeout} seconds'
    )
    return False

def start(self):
    """Start listening for work items on a background thread.

    After starting the worker, this waits (up to the ``worker_ready_timeout``
    given at construction) for the worker's gRPC stream to report ready, so
    that workflows can be scheduled immediately after ``start()`` returns.
    If the worker does not expose ``is_worker_ready``, readiness cannot be
    verified: a warning is logged and the method returns.

    Raises:
        RuntimeError: If the worker does not report ready within the timeout.
        Exception: Any error raised by the worker's ``start()`` or by the
            readiness probe is logged and re-raised unchanged.
    """
    try:
        self.__worker.start()
    except Exception as start_error:
        self._logger.exception(f'WorkflowRuntime worker did not start: {start_error}')
        raise

    if not hasattr(self.__worker, 'is_worker_ready'):
        self._logger.warning(
            'Unable to verify stream readiness. Workflows scheduled immediately may not be received.'
        )
        return

    # Only the probe itself is guarded: a plain timeout is reported by the
    # RuntimeError below, not logged as if the probe had raised.
    try:
        is_ready = self.wait_for_worker_ready(timeout=self._worker_ready_timeout)
    except Exception as ready_error:
        self._logger.exception(
            f'WorkflowRuntime wait_for_worker_ready() raised exception: {ready_error}'
        )
        raise

    if not is_ready:
        raise RuntimeError('WorkflowRuntime worker and its stream are not ready')

    self._logger.debug('WorkflowRuntime worker is ready and its stream can receive work items')

def shutdown(self):
    """Stop listening for work items on the background thread.

    Raises:
        Exception: Any error raised by the worker's ``stop()`` is logged and
            re-raised unchanged.
    """
    try:
        self.__worker.stop()
    except Exception as stop_error:
        # Log before propagating, the same way start() reports a failed
        # worker start, so a failed shutdown leaves a trace.
        self._logger.exception(f'WorkflowRuntime worker did not stop: {stop_error}')
        raise

def versioned_workflow(
self,
Expand Down
127 changes: 127 additions & 0 deletions ext/dapr-ext-workflow/tests/test_workflow_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,17 @@


# Test double that the tests in this file patch in for
# ``durabletask.worker._Registry``. It records every registration: names are
# appended to the module-level ``listOrchestrators`` / ``listActivities`` lists,
# and the registered callables are kept per instance, keyed by name.
class FakeTaskHubGrpcWorker:
# Start with empty name -> callable maps for orchestrators and activities.
def __init__(self):
self._orchestrator_fns = {}
self._activity_fns = {}

# Record an orchestrator registration: append ``name`` to the shared list and
# store ``fn`` under ``name`` (a later registration with the same name
# overwrites the stored callable).
def add_named_orchestrator(self, name: str, fn):
listOrchestrators.append(name)
self._orchestrator_fns[name] = fn

# Record an activity registration: append ``name`` to the shared list and
# store ``fn`` under ``name`` (a later registration with the same name
# overwrites the stored callable).
def add_named_activity(self, name: str, fn):
listActivities.append(name)
self._activity_fns[name] = fn


class WorkflowRuntimeTest(unittest.TestCase):
Expand Down Expand Up @@ -171,3 +177,124 @@ def test_decorator_register_optinal_name(self):
wanted_activity = ['test_act']
assert listActivities == wanted_activity
assert client_act._dapr_alternate_name == 'test_act'


class WorkflowRuntimeWorkerReadyTest(unittest.TestCase):
    """Tests for wait_for_worker_ready() and start() stream readiness."""

    def _new_runtime(self, **kwargs):
        """Build a WorkflowRuntime with the durabletask registry patched out.

        The shared registration lists are cleared first. The patch is undone
        through addCleanup so it cannot leak into other tests.
        """
        listActivities.clear()
        listOrchestrators.clear()
        patcher = mock.patch(
            'durabletask.worker._Registry', return_value=FakeTaskHubGrpcWorker()
        )
        patcher.start()
        self.addCleanup(patcher.stop)
        return WorkflowRuntime(**kwargs)

    @staticmethod
    def _worker_without_ready_probe():
        """Return a mock worker that has no ``is_worker_ready`` attribute.

        Restricting the mock with ``spec`` is enough: attributes outside the
        list raise AttributeError, so ``hasattr`` reports False.
        """
        return mock.MagicMock(spec=['start', 'stop', '_registry'])

    def setUp(self):
        self.runtime = self._new_runtime()

    def test_wait_for_worker_ready_returns_false_when_no_is_worker_ready(self):
        self.runtime._WorkflowRuntime__worker = self._worker_without_ready_probe()
        self.assertFalse(self.runtime.wait_for_worker_ready(timeout=0.1))

    def test_wait_for_worker_ready_returns_true_when_ready(self):
        mock_worker = mock.MagicMock()
        mock_worker.is_worker_ready.return_value = True
        self.runtime._WorkflowRuntime__worker = mock_worker
        self.assertTrue(self.runtime.wait_for_worker_ready(timeout=1.0))
        mock_worker.is_worker_ready.assert_called()

    def test_wait_for_worker_ready_returns_true_after_poll(self):
        """Worker becomes ready on second poll."""
        mock_worker = mock.MagicMock()
        mock_worker.is_worker_ready.side_effect = [False, True]
        self.runtime._WorkflowRuntime__worker = mock_worker
        self.assertTrue(self.runtime.wait_for_worker_ready(timeout=1.0))
        self.assertEqual(mock_worker.is_worker_ready.call_count, 2)

    def test_wait_for_worker_ready_returns_false_on_timeout(self):
        mock_worker = mock.MagicMock()
        mock_worker.is_worker_ready.return_value = False
        self.runtime._WorkflowRuntime__worker = mock_worker
        self.assertFalse(self.runtime.wait_for_worker_ready(timeout=0.2))

    def test_start_succeeds_when_worker_ready(self):
        mock_worker = mock.MagicMock()
        mock_worker.is_worker_ready.return_value = True
        self.runtime._WorkflowRuntime__worker = mock_worker
        self.runtime.start()
        mock_worker.start.assert_called_once()
        mock_worker.is_worker_ready.assert_called()

    def test_start_logs_debug_when_worker_stream_ready(self):
        """start() logs at debug when worker and stream are ready."""
        mock_worker = mock.MagicMock()
        mock_worker.is_worker_ready.return_value = True
        self.runtime._WorkflowRuntime__worker = mock_worker
        with mock.patch.object(self.runtime._logger, 'debug') as mock_debug:
            self.runtime.start()
        mock_debug.assert_called_once()
        call_args = mock_debug.call_args[0][0]
        self.assertIn('ready', call_args)
        self.assertIn('stream', call_args)

    def test_start_logs_exception_when_worker_start_fails(self):
        """start() logs exception when worker.start() raises."""
        mock_worker = mock.MagicMock()
        mock_worker.start.side_effect = RuntimeError('start failed')
        self.runtime._WorkflowRuntime__worker = mock_worker
        with mock.patch.object(self.runtime._logger, 'exception') as mock_exception:
            with self.assertRaises(RuntimeError):
                self.runtime.start()
        mock_exception.assert_called_once()
        self.assertIn('did not start', mock_exception.call_args[0][0])

    def test_start_raises_when_worker_not_ready(self):
        runtime = self._new_runtime(worker_ready_timeout=0.2)
        mock_worker = mock.MagicMock()
        mock_worker.is_worker_ready.return_value = False
        runtime._WorkflowRuntime__worker = mock_worker
        with self.assertRaises(RuntimeError) as ctx:
            runtime.start()
        self.assertIn('not ready', str(ctx.exception))

    def test_start_logs_warning_when_no_is_worker_ready(self):
        """start() still starts the worker and warns when readiness cannot be checked."""
        mock_worker = self._worker_without_ready_probe()
        self.runtime._WorkflowRuntime__worker = mock_worker
        with mock.patch.object(self.runtime._logger, 'warning') as mock_warning:
            self.runtime.start()
        mock_worker.start.assert_called_once()
        mock_warning.assert_called_once()
        self.assertIn('Unable to verify', mock_warning.call_args[0][0])

    def test_worker_ready_timeout_init(self):
        rt = self._new_runtime(worker_ready_timeout=15.0)
        self.assertEqual(rt._worker_ready_timeout, 15.0)

    def test_start_raises_when_worker_start_fails(self):
        mock_worker = mock.MagicMock()
        mock_worker.is_worker_ready.return_value = True
        mock_worker.start.side_effect = RuntimeError('start failed')
        self.runtime._WorkflowRuntime__worker = mock_worker
        with self.assertRaises(RuntimeError) as ctx:
            self.runtime.start()
        self.assertIn('start failed', str(ctx.exception))
        mock_worker.start.assert_called_once()

    def test_start_raises_when_wait_for_worker_ready_raises(self):
        mock_worker = mock.MagicMock()
        mock_worker.start.return_value = None
        mock_worker.is_worker_ready.side_effect = ValueError('ready check failed')
        self.runtime._WorkflowRuntime__worker = mock_worker
        with self.assertRaises(ValueError) as ctx:
            self.runtime.start()
        self.assertIn('ready check failed', str(ctx.exception))

    def test_shutdown_raises_when_worker_stop_fails(self):
        mock_worker = mock.MagicMock()
        mock_worker.stop.side_effect = RuntimeError('stop failed')
        self.runtime._WorkflowRuntime__worker = mock_worker
        with self.assertRaises(RuntimeError) as ctx:
            self.runtime.shutdown()
        self.assertIn('stop failed', str(ctx.exception))