Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 29 additions & 35 deletions lib/crewai/src/crewai/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,7 +705,6 @@ def __init__(
else:
self.stop = stop

self.set_callbacks(callbacks or [])
self.set_env_callbacks()

@staticmethod
Expand Down Expand Up @@ -1786,18 +1785,23 @@ def call(
if not self._invoke_before_llm_call_hooks(messages, from_agent):
raise ValueError("LLM call blocked by before_llm_call hook")

effective_callbacks = callbacks if callbacks is not None else self.callbacks

# --- 5) Set up callbacks if provided
with suppress_warnings():
if callbacks and len(callbacks) > 0:
self.set_callbacks(callbacks)
try:
# --- 6) Prepare parameters for the completion call
params = self._prepare_completion_params(messages, tools)
if effective_callbacks and len(effective_callbacks) > 0:
# Avoid mutating LiteLLM global callback lists. Pass callbacks per request
# so concurrent LLM instances don't race on shared global state.
params["callbacks"] = effective_callbacks

# --- 7) Make the completion call and handle response
if self.stream:
result = self._handle_streaming_response(
params=params,
callbacks=callbacks,
callbacks=effective_callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
Expand All @@ -1806,7 +1810,7 @@ def call(
else:
result = self._handle_non_streaming_response(
params=params,
callbacks=callbacks,
callbacks=effective_callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
Expand Down Expand Up @@ -1932,18 +1936,22 @@ async def acall(
msg_role: Literal["assistant"] = "assistant"
message["role"] = msg_role

effective_callbacks = callbacks if callbacks is not None else self.callbacks

with suppress_warnings():
if callbacks and len(callbacks) > 0:
self.set_callbacks(callbacks)
try:
params = self._prepare_completion_params(
messages, tools, skip_file_processing=True
)
if effective_callbacks and len(effective_callbacks) > 0:
# Avoid mutating LiteLLM global callback lists. Pass callbacks per request
# so concurrent LLM instances don't race on shared global state.
params["callbacks"] = effective_callbacks

if self.stream:
return await self._ahandle_streaming_response(
params=params,
callbacks=callbacks,
callbacks=effective_callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
Expand All @@ -1952,7 +1960,7 @@ async def acall(

return await self._ahandle_non_streaming_response(
params=params,
callbacks=callbacks,
callbacks=effective_callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
Expand Down Expand Up @@ -2002,6 +2010,18 @@ async def acall(
),
)
raise
)

crewai_event_bus.emit(
self,
event=LLMCallFailedEvent(
error=str(e),
from_task=from_task,
from_agent=from_agent,
call_id=get_current_call_id(),
),
)
raise

def _handle_emit_call_events(
self,
Expand Down Expand Up @@ -2303,32 +2323,6 @@ def get_context_window_size(self) -> int:
self.context_window_size = int(value * CONTEXT_WINDOW_USAGE_RATIO)
return self.context_window_size

@staticmethod
def set_callbacks(callbacks: list[Any]) -> None:
"""
Attempt to keep a single set of callbacks in litellm by removing old
duplicates and adding new ones.

Note: This only affects the litellm fallback path. Native providers
don't use litellm callbacks - they emit events via base_llm.py.
"""
if not LITELLM_AVAILABLE:
# When litellm is not available, callbacks are still stored
# but not registered with litellm globals
return

with suppress_warnings():
callback_types = [type(callback) for callback in callbacks]
for callback in litellm.success_callback[:]:
if type(callback) in callback_types:
litellm.success_callback.remove(callback)

for callback in litellm._async_success_callback[:]:
if type(callback) in callback_types:
litellm._async_success_callback.remove(callback)

litellm.callbacks = callbacks

@staticmethod
def set_env_callbacks() -> None:
"""Sets the success and failure callbacks for the LiteLLM library from environment variables.
Expand Down
107 changes: 107 additions & 0 deletions lib/crewai/tests/llms/test_concurrency.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import threading
import time
from unittest.mock import MagicMock, patch
import pytest
from crewai.llm import LLM
from crewai.utilities.token_counter_callback import TokenCalcHandler
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess

def test_concurrent_llm_calls_isolation():
    """Concurrent LLM.call invocations with distinct callbacks must stay isolated.

    Two worker threads each pass their own TokenCalcHandler. We then verify
    that every mocked litellm.completion invocation received exactly the
    handler belonging to the thread that issued it, and that per-thread token
    accounting never bleeds across threads.
    """
    # Patch at module scope so the single mock is shared by every worker thread.
    with patch("crewai.llm.litellm.completion") as fake_completion:

        def fake_call(*_args, **kwargs):
            # Recover the originating thread id from the prompt text so each
            # thread gets a distinguishable response and usage payload.
            msgs = kwargs.get("messages", [])
            text = msgs[0]["content"] if msgs else ""
            worker = text.split("thread ")[-1]

            message = MagicMock()
            message.content = f"Response for thread {worker}"
            choice = MagicMock()
            choice.message = message
            response = MagicMock()
            response.choices = [choice]

            worker_num = int(worker)

            # Usage must expose attributes (not dict keys) to match the real
            # litellm response object.
            usage = MagicMock()
            usage.prompt_tokens = 10 + worker_num
            usage.completion_tokens = 5 + worker_num
            usage.total_tokens = 15 + 2 * worker_num
            # None here sidesteps any cached-token accounting path.
            usage.prompt_tokens_details = None
            response.usage = usage

            # Small delay widens the overlap window between the two threads.
            time.sleep(0.1)
            return response

        fake_completion.side_effect = fake_call

        def worker_body(worker_id, shared):
            process = TokenProcess()
            calc = TokenCalcHandler(token_cost_process=process)

            # Record the handler up front so the main thread can match it
            # against the callbacks actually passed to the mock.
            shared[worker_id] = {"handler": calc, "summary": None}

            client = LLM(model="gpt-4o-mini", is_litellm=True)
            client.call(
                messages=[{"role": "user", "content": f"Hello from thread {worker_id}"}],
                callbacks=[calc],
            )

            shared[worker_id]["summary"] = process.get_summary()

        shared_results = {}
        workers = [
            threading.Thread(target=worker_body, args=(wid, shared_results))
            for wid in (1, 2)
        ]
        for w in workers:
            w.start()
        for w in workers:
            w.join()

        assert fake_completion.call_count == 2

        # Each recorded call must carry exactly the one handler owned by the
        # thread that made it — this is the isolation guarantee under test.
        for recorded in fake_completion.call_args_list:
            call_kwargs = recorded.kwargs
            prompt_messages = call_kwargs.get("messages", [])
            text = prompt_messages[0]["content"]
            wid = int(text.split("thread ")[-1])

            expected = shared_results[wid]["handler"]

            assert "callbacks" in call_kwargs
            passed = call_kwargs["callbacks"]
            assert len(passed) == 1
            assert passed[0] == expected, f"Callback mismatch for thread {wid}"

        # Token totals must reflect each thread's own usage payload only.
        first = shared_results[1]["summary"]
        assert first.prompt_tokens == 11
        assert first.completion_tokens == 6

        second = shared_results[2]["summary"]
        assert second.prompt_tokens == 12
        assert second.completion_tokens == 7
54 changes: 40 additions & 14 deletions lib/crewai/tests/test_llm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import logging
import os
from time import sleep
from unittest.mock import MagicMock, patch

from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
Expand All @@ -18,27 +17,54 @@
import pytest


# TODO: This test fails without print statement, which makes me think that something is happening asynchronously that we need to eventually fix and dive deeper into at a later date
@pytest.mark.vcr()
def test_llm_callback_replacement():
"""Callbacks passed to different LLM instances should not interfere with each other.

Historically this was flaky because CrewAI mutated LiteLLM's global callback lists,
so callbacks could be removed/overwritten while a different request was still in-flight.
"""
llm1 = LLM(model="gpt-4o-mini", is_litellm=True)
llm2 = LLM(model="gpt-4o-mini", is_litellm=True)

calc_handler_1 = TokenCalcHandler(token_cost_process=TokenProcess())
calc_handler_2 = TokenCalcHandler(token_cost_process=TokenProcess())

llm1.call(
messages=[{"role": "user", "content": "Hello, world!"}],
callbacks=[calc_handler_1],
)
usage_metrics_1 = calc_handler_1.token_cost_process.get_summary()
with patch("litellm.completion") as mock_completion:
# Return a minimal response object with .choices[0].message.content and .usage
def _mock_response(content: str):
mock_message = MagicMock()
mock_message.content = content
mock_choice = MagicMock()
mock_choice.message = mock_message
mock_response = MagicMock()
mock_response.choices = [mock_choice]
mock_response.usage = {
"prompt_tokens": 10,
"completion_tokens": 10,
"total_tokens": 20,
}
return mock_response

mock_completion.side_effect = [
_mock_response("Hello from call 1"),
_mock_response("Hello from call 2"),
]

llm1.call(
messages=[{"role": "user", "content": "Hello, world!"}],
callbacks=[calc_handler_1],
)
usage_metrics_1 = calc_handler_1.token_cost_process.get_summary()

llm2.call(
messages=[{"role": "user", "content": "Hello, world from another agent!"}],
callbacks=[calc_handler_2],
)
sleep(5)
usage_metrics_2 = calc_handler_2.token_cost_process.get_summary()
llm2.call(
messages=[{"role": "user", "content": "Hello, world from another agent!"}],
callbacks=[calc_handler_2],
)
usage_metrics_2 = calc_handler_2.token_cost_process.get_summary()

# Ensure callbacks are passed per-request (no reliance on global LiteLLM callback lists)
assert mock_completion.call_args_list[0].kwargs["callbacks"] == [calc_handler_1]
assert mock_completion.call_args_list[1].kwargs["callbacks"] == [calc_handler_2]

# The first handler should not have been updated
assert usage_metrics_1.successful_requests == 1
Expand Down