diff --git a/lib/crewai/src/crewai/llm.py b/lib/crewai/src/crewai/llm.py index 75b1f65468..8062acfbe0 100644 --- a/lib/crewai/src/crewai/llm.py +++ b/lib/crewai/src/crewai/llm.py @@ -705,7 +705,6 @@ def __init__( else: self.stop = stop - self.set_callbacks(callbacks or []) self.set_env_callbacks() @staticmethod @@ -1786,18 +1785,23 @@ def call( if not self._invoke_before_llm_call_hooks(messages, from_agent): raise ValueError("LLM call blocked by before_llm_call hook") + effective_callbacks = callbacks if callbacks is not None else self.callbacks + # --- 5) Set up callbacks if provided with suppress_warnings(): - if callbacks and len(callbacks) > 0: - self.set_callbacks(callbacks) try: # --- 6) Prepare parameters for the completion call params = self._prepare_completion_params(messages, tools) + if effective_callbacks and len(effective_callbacks) > 0: + # Avoid mutating LiteLLM global callback lists. Pass callbacks per request + # so concurrent LLM instances don't race on shared global state. 
+ params["callbacks"] = effective_callbacks + # --- 7) Make the completion call and handle response if self.stream: result = self._handle_streaming_response( params=params, - callbacks=callbacks, + callbacks=effective_callbacks, available_functions=available_functions, from_task=from_task, from_agent=from_agent, @@ -1806,7 +1810,7 @@ def call( else: result = self._handle_non_streaming_response( params=params, - callbacks=callbacks, + callbacks=effective_callbacks, available_functions=available_functions, from_task=from_task, from_agent=from_agent, @@ -1932,18 +1936,22 @@ async def acall( msg_role: Literal["assistant"] = "assistant" message["role"] = msg_role + effective_callbacks = callbacks if callbacks is not None else self.callbacks + with suppress_warnings(): - if callbacks and len(callbacks) > 0: - self.set_callbacks(callbacks) try: params = self._prepare_completion_params( messages, tools, skip_file_processing=True ) + if effective_callbacks and len(effective_callbacks) > 0: + # Avoid mutating LiteLLM global callback lists. Pass callbacks per request + # so concurrent LLM instances don't race on shared global state. 
+ params["callbacks"] = effective_callbacks if self.stream: return await self._ahandle_streaming_response( params=params, - callbacks=callbacks, + callbacks=effective_callbacks, available_functions=available_functions, from_task=from_task, from_agent=from_agent, @@ -1952,7 +1960,7 @@ async def acall( return await self._ahandle_non_streaming_response( params=params, - callbacks=callbacks, + callbacks=effective_callbacks, available_functions=available_functions, from_task=from_task, from_agent=from_agent, @@ -2303,32 +2311,6 @@ def get_context_window_size(self) -> int: self.context_window_size = int(value * CONTEXT_WINDOW_USAGE_RATIO) return self.context_window_size - @staticmethod - def set_callbacks(callbacks: list[Any]) -> None: - """ - Attempt to keep a single set of callbacks in litellm by removing old - duplicates and adding new ones. - - Note: This only affects the litellm fallback path. Native providers - don't use litellm callbacks - they emit events via base_llm.py. - """ - if not LITELLM_AVAILABLE: - # When litellm is not available, callbacks are still stored - # but not registered with litellm globals - return - - with suppress_warnings(): - callback_types = [type(callback) for callback in callbacks] - for callback in litellm.success_callback[:]: - if type(callback) in callback_types: - litellm.success_callback.remove(callback) - - for callback in litellm._async_success_callback[:]: - if type(callback) in callback_types: - litellm._async_success_callback.remove(callback) - - litellm.callbacks = callbacks - @staticmethod - def set_env_callbacks() -> None: """Sets the success and failure callbacks for the LiteLLM library from environment variables. 
diff --git a/lib/crewai/tests/llms/test_concurrency.py b/lib/crewai/tests/llms/test_concurrency.py new file mode 100644 index 0000000000..f10d59fdbb --- /dev/null +++ b/lib/crewai/tests/llms/test_concurrency.py @@ -0,0 +1,107 @@ +import threading +import time +from unittest.mock import MagicMock, patch +import pytest +from crewai.llm import LLM +from crewai.utilities.token_counter_callback import TokenCalcHandler +from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess + +def test_concurrent_llm_calls_isolation(): + """ + Test that concurrent LLM calls with different callbacks do not interfere with each other. + """ + + # We patch globally so it applies to all threads + # We use crewai.llm.litellm to be safe + with patch("crewai.llm.litellm.completion") as mock_completion: + + # Setup mock to return a valid response structure + def side_effect(*args, **kwargs): + messages = kwargs.get("messages", []) + content = messages[0]["content"] if messages else "" + thread_id = content.split("thread ")[-1] + + mock_message = MagicMock() + mock_message.content = f"Response for thread {thread_id}" + mock_choice = MagicMock() + mock_choice.message = mock_message + mock_response = MagicMock() + mock_response.choices = [mock_choice] + + # Use unique usage stats based on thread ID + tid_int = int(thread_id) + + # Create a usage object with attributes (not a dict) + usage_obj = MagicMock() + usage_obj.prompt_tokens = 10 + tid_int + usage_obj.completion_tokens = 5 + tid_int + usage_obj.total_tokens = 15 + 2 * tid_int + # Mock prompt_tokens_details to be None or have cached_tokens=0 + usage_obj.prompt_tokens_details = None + + mock_response.usage = usage_obj + + # Simulate slight delay + time.sleep(0.1) + return mock_response + + mock_completion.side_effect = side_effect + + # Define the workload + def run_llm_request(thread_id, result_container): + token_process = TokenProcess() + handler = TokenCalcHandler(token_cost_process=token_process) + + # Store 
handler so we can verify it later + result_container[thread_id] = { + "handler": handler, + "summary": None + } + + llm = LLM(model="gpt-4o-mini", is_litellm=True) + + llm.call( + messages=[{"role": "user", "content": f"Hello from thread {thread_id}"}], + callbacks=[handler] + ) + + result_container[thread_id]["summary"] = token_process.get_summary() + + results = {} + threads = [] + + # Start threads + for i in [1, 2]: + t = threading.Thread(target=run_llm_request, args=(i, results)) + threads.append(t) + t.start() + + for t in threads: + t.join() + + # Verification + assert mock_completion.call_count == 2 + + # Check each call arguments + for call_args in mock_completion.call_args_list: + kwargs = call_args.kwargs + messages = kwargs.get("messages", []) + content = messages[0]["content"] + thread_id = int(content.split("thread ")[-1]) + + expected_handler = results[thread_id]["handler"] + + # CRITICAL CHECK: Verify ONLY the expected handler was passed + assert "callbacks" in kwargs + callbacks = kwargs["callbacks"] + assert len(callbacks) == 1 + assert callbacks[0] == expected_handler, f"Callback mismatch for thread {thread_id}" + + # Verify token usage isolation + summary1 = results[1]["summary"] + assert summary1.prompt_tokens == 11 + assert summary1.completion_tokens == 6 + + summary2 = results[2]["summary"] + assert summary2.prompt_tokens == 12 + assert summary2.completion_tokens == 7 diff --git a/lib/crewai/tests/test_llm.py b/lib/crewai/tests/test_llm.py index 1ed2171664..1336109ae0 100644 --- a/lib/crewai/tests/test_llm.py +++ b/lib/crewai/tests/test_llm.py @@ -1,6 +1,5 @@ import logging import os -from time import sleep from unittest.mock import MagicMock, patch from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess @@ -18,27 +17,54 @@ import pytest -# TODO: This test fails without print statement, which makes me think that something is happening asynchronously that we need to eventually fix and dive deeper into at a later date 
-@pytest.mark.vcr() def test_llm_callback_replacement(): + """Callbacks passed to different LLM instances should not interfere with each other. + + Historically this was flaky because CrewAI mutated LiteLLM's global callback lists, + so callbacks could be removed/overwritten while a different request was still in-flight. + """ llm1 = LLM(model="gpt-4o-mini", is_litellm=True) llm2 = LLM(model="gpt-4o-mini", is_litellm=True) calc_handler_1 = TokenCalcHandler(token_cost_process=TokenProcess()) calc_handler_2 = TokenCalcHandler(token_cost_process=TokenProcess()) - llm1.call( - messages=[{"role": "user", "content": "Hello, world!"}], - callbacks=[calc_handler_1], - ) - usage_metrics_1 = calc_handler_1.token_cost_process.get_summary() + with patch("litellm.completion") as mock_completion: + # Return a minimal response object with .choices[0].message.content and .usage + def _mock_response(content: str): + mock_message = MagicMock() + mock_message.content = content + mock_choice = MagicMock() + mock_choice.message = mock_message + mock_response = MagicMock() + mock_response.choices = [mock_choice] + mock_response.usage = { + "prompt_tokens": 10, + "completion_tokens": 10, + "total_tokens": 20, + } + return mock_response + + mock_completion.side_effect = [ + _mock_response("Hello from call 1"), + _mock_response("Hello from call 2"), + ] + + llm1.call( + messages=[{"role": "user", "content": "Hello, world!"}], + callbacks=[calc_handler_1], + ) + usage_metrics_1 = calc_handler_1.token_cost_process.get_summary() - llm2.call( - messages=[{"role": "user", "content": "Hello, world from another agent!"}], - callbacks=[calc_handler_2], - ) - sleep(5) - usage_metrics_2 = calc_handler_2.token_cost_process.get_summary() + llm2.call( + messages=[{"role": "user", "content": "Hello, world from another agent!"}], + callbacks=[calc_handler_2], + ) + usage_metrics_2 = calc_handler_2.token_cost_process.get_summary() + + # Ensure callbacks are passed per-request (no reliance on global 
LiteLLM callback lists) + assert mock_completion.call_args_list[0].kwargs["callbacks"] == [calc_handler_1] + assert mock_completion.call_args_list[1].kwargs["callbacks"] == [calc_handler_2] # The first handler should not have been updated assert usage_metrics_1.successful_requests == 1