diff --git a/backend/chainlit/server.py b/backend/chainlit/server.py index f9393e5e3b..81717d03e4 100644 --- a/backend/chainlit/server.py +++ b/backend/chainlit/server.py @@ -1290,6 +1290,8 @@ async def connect_mcp( payload: ConnectMCPRequest, current_user: UserParam, ): + import asyncio + from mcp import ClientSession from mcp.client.sse import sse_client from mcp.client.stdio import ( @@ -1307,7 +1309,7 @@ async def connect_mcp( StdioMcpConnection, validate_mcp_command, ) - from chainlit.session import WebsocketSession + from chainlit.session import McpSession, WebsocketSession session = WebsocketSession.get_by_id(payload.sessionId) context = init_ws_context(session) @@ -1323,113 +1325,203 @@ async def connect_mcp( ) mcp_enabled = config.features.mcp.enabled - if mcp_enabled: - if payload.name in session.mcp_sessions: - old_client_session, old_exit_stack = session.mcp_sessions[payload.name] - if on_mcp_disconnect := config.code.on_mcp_disconnect: - await on_mcp_disconnect(payload.name, old_client_session) + if not mcp_enabled: + raise HTTPException( + status_code=400, + detail="This app does not support MCP.", + ) + + # Disconnect previous session for this name (reconnection) + if payload.name in session.mcp_sessions: + old_mcp = session.mcp_sessions.pop(payload.name) + if on_mcp_disconnect := config.code.on_mcp_disconnect: try: - await old_exit_stack.aclose() + await on_mcp_disconnect(payload.name, old_mcp.client) except Exception: - pass - + logger.debug( + "Error in on_mcp_disconnect callback for %s", + payload.name, + exc_info=True, + ) try: - exit_stack = AsyncExitStack() - mcp_connection: McpConnection - - if payload.clientType == "sse": - if not config.features.mcp.sse.enabled: - raise HTTPException( - status_code=400, - detail="SSE MCP is not enabled", - ) + await old_mcp.close() + except Exception: + logger.debug( + "Error closing old MCP session %s", payload.name, exc_info=True + ) - mcp_connection = SseMcpConnection( - url=payload.url, - name=payload.name, - headers=getattr(payload, "headers", None), - ) + # ── Validate config before launching the background task ── + mcp_connection: McpConnection + + if payload.clientType == "sse": + if not config.features.mcp.sse.enabled: + raise HTTPException( + status_code=400, + detail="SSE MCP is not enabled", + ) + mcp_connection = SseMcpConnection( + url=payload.url, + name=payload.name, + headers=getattr(payload, "headers", None), + ) + elif payload.clientType == "stdio": + if not config.features.mcp.stdio.enabled: + raise HTTPException( + status_code=400, + detail="Stdio MCP is not enabled", + ) + env_from_cmd, command, args = validate_mcp_command(payload.fullCommand) + mcp_connection = StdioMcpConnection( + command=command, args=args, name=payload.name + ) + elif payload.clientType == "streamable-http": + if not config.features.mcp.streamable_http.enabled: + raise HTTPException( + status_code=400, + detail="HTTP MCP is not enabled", + ) + mcp_connection = HttpMcpConnection( + url=payload.url, + name=payload.name, + headers=getattr(payload, "headers", None), + ) + else: + raise HTTPException( + status_code=400, + detail=f"Unknown MCP client type: {payload.clientType}", + ) - transport = await exit_stack.enter_async_context( - sse_client( - url=mcp_connection.url, - headers=mcp_connection.headers, + # ── Launch the MCP connection in its own background task ── + # + # The background task owns the AsyncExitStack: it enters all context + # managers, calls initialize(), signals ``ready_event``, and then + # blocks on ``stop_event.wait()``. When the stop event fires the + # task wakes up and closes the exit stack *in the same task* that + # opened it — avoiding the cross-task cancel-scope corruption from + # https://github.com/Chainlit/chainlit/issues/2182. + + ready_event: asyncio.Event = asyncio.Event() + stop_event: asyncio.Event = asyncio.Event() + # Mutable container to pass the ClientSession back from the bg task. + result_holder: dict[str, object] = {} + + async def _mcp_session_runner() -> None: + exit_stack = AsyncExitStack() + try: + try: + if isinstance(mcp_connection, SseMcpConnection): + transport = await exit_stack.enter_async_context( + sse_client( + url=mcp_connection.url, + headers=mcp_connection.headers, + ) ) - ) - elif payload.clientType == "stdio": - if not config.features.mcp.stdio.enabled: - raise HTTPException( - status_code=400, - detail="Stdio MCP is not enabled", + elif isinstance(mcp_connection, StdioMcpConnection): + env = get_default_environment() + env.update(env_from_cmd) + server_params = StdioServerParameters( + command=command, args=args, env=env ) + transport = await exit_stack.enter_async_context( + stdio_client(server_params) + ) + elif isinstance(mcp_connection, HttpMcpConnection): + transport = await exit_stack.enter_async_context( + streamablehttp_client( + url=mcp_connection.url, + headers=mcp_connection.headers, + ) + ) + else: + raise ValueError(f"Unknown client type: {payload.clientType}") - env_from_cmd, command, args = validate_mcp_command(payload.fullCommand) - mcp_connection = StdioMcpConnection( - command=command, args=args, name=payload.name - ) - - env = get_default_environment() - env.update(env_from_cmd) - # Create the server parameters - server_params = StdioServerParameters( - command=command, args=args, env=env - ) - - transport = await exit_stack.enter_async_context( - stdio_client(server_params) - ) + read, write = transport[:2] - elif payload.clientType == "streamable-http": - if not config.features.mcp.streamable_http.enabled: - raise HTTPException( - status_code=400, - detail="HTTP MCP is not enabled", - ) - mcp_connection = HttpMcpConnection( - url=payload.url, - name=payload.name, - headers=getattr(payload, "headers", None), - ) - transport = await exit_stack.enter_async_context( - streamablehttp_client( - url=mcp_connection.url, - headers=mcp_connection.headers, + mcp_client: ClientSession = await exit_stack.enter_async_context( + ClientSession( + read_stream=read, + write_stream=write, + sampling_callback=None, ) ) - # The transport can return (read, write) for stdio, sse - # Or (read, write, get_session_id) for streamable-http - # We are only interested in the read and write streams here. - read, write = transport[:2] + await mcp_client.initialize() + result_holder["client"] = mcp_client - mcp_session: ClientSession = await exit_stack.enter_async_context( - ClientSession( - read_stream=read, write_stream=write, sampling_callback=None + except BaseException as exc: + result_holder["error"] = exc + return # outer finally closes exit_stack + finally: + # Always signal the caller so it doesn't wait forever. + ready_event.set() + + # ── Keep the task (and the exit stack) alive ── + try: + await stop_event.wait() + except asyncio.CancelledError: + logger.debug("MCP background task for %r cancelled", payload.name) + finally: + # Close exit_stack in ALL paths (error, normal shutdown, + # cancellation) — always in the same task that opened it. + logger.debug("Closing MCP exit stack for %r (same-task)", payload.name) + try: + await exit_stack.aclose() + except BaseException: + logger.debug( + "Error closing MCP exit stack for %r", + payload.name, + exc_info=True, ) - ) - # Initialize the session - await mcp_session.initialize() + task = asyncio.create_task( + _mcp_session_runner(), name=f"mcp-session-{payload.name}" + ) - # Store the session - session.mcp_sessions[mcp_connection.name] = (mcp_session, exit_stack) + # Wait for the background task to finish initialisation. + await ready_event.wait() - # Call the callback - if config.code.on_mcp_connect: - await config.code.on_mcp_connect(mcp_connection, mcp_session) + if "error" in result_holder: + # The task already exited and cleaned up its exit stack. + # Make sure the task itself is fully done. + try: + await task + except BaseException: + pass + return JSONResponse( + status_code=400, + content={ + "detail": f"Could not connect to the MCP: {result_holder['error']!s}" + }, + ) + mcp_client_session = cast("ClientSession", result_holder["client"]) + + # Call the user callback + if config.code.on_mcp_connect: + try: + await config.code.on_mcp_connect(mcp_connection, mcp_client_session) except Exception as e: - raise HTTPException( + # Callback failed — tear down the connection. + stop_event.set() + try: + await task + except BaseException: + pass + return JSONResponse( status_code=400, - detail=f"Could not connect to the MCP: {e!s}", + content={"detail": f"Could not connect to the MCP: {e!s}"}, ) - else: - raise HTTPException( - status_code=400, - detail="This app does not support MCP.", - ) - tool_list = await mcp_session.list_tools() + # Store the session + mcp_session_obj = McpSession( + name=mcp_connection.name, + client=mcp_client_session, + task=task, + stop_event=stop_event, + ) + session.mcp_sessions[mcp_connection.name] = mcp_session_obj + + tool_list = await mcp_client_session.list_tools() return JSONResponse( content={ @@ -1475,22 +1567,17 @@ async def disconnect_mcp( callback = config.code.on_mcp_disconnect if payload.name in session.mcp_sessions: + mcp_session_obj = session.mcp_sessions.pop(payload.name) try: - client_session, exit_stack = session.mcp_sessions[payload.name] if callback: - await callback(payload.name, client_session) - - try: - await exit_stack.aclose() - except Exception: - pass - del session.mcp_sessions[payload.name] - + await callback(payload.name, mcp_session_obj.client) except Exception as e: raise HTTPException( status_code=400, - detail=f"Could not disconnect to the MCP: {e!s}", + detail=f"Could not disconnect from the MCP: {e!s}", ) + finally: + await mcp_session_obj.close() return JSONResponse(content={"success": True}) diff --git a/backend/chainlit/session.py b/backend/chainlit/session.py index d6bd3f6214..2910f58140 100644 --- a/backend/chainlit/session.py +++ b/backend/chainlit/session.py @@ -4,7 +4,7 @@ import re import shutil import uuid -from contextlib import AsyncExitStack +from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, Literal, Optional, Union import aiofiles @@ -19,6 +19,60 @@ from chainlit.types import FileDict from chainlit.user import PersistedUser, User +_CLOSE_TIMEOUT = 10.0 # seconds to wait for a background MCP task to finish + + +@dataclass +class McpSession: + """Lifecycle wrapper for a single MCP connection. + + Each MCP connection is run inside its own ``asyncio.Task``. That task + creates the ``AsyncExitStack``, enters all context managers (transport, + ``ClientSession``), calls ``initialize()``, and then blocks on + ``stop_event.wait()``. When the event is set the task wakes up and + closes the exit stack **in the same task** that opened it, avoiding + the cross-task cancel-scope corruption described in + https://github.com/Chainlit/chainlit/issues/2182. + + Original solution by @nigiva: + https://github.com/Chainlit/chainlit/issues/2182#issuecomment-2840283194 + """ + + name: str + client: "ClientSession" + task: asyncio.Task + stop_event: asyncio.Event = field(default_factory=asyncio.Event) + + async def close(self) -> None: + """Signal the background task to shut down and wait for it.""" + self.stop_event.set() + try: + await asyncio.wait_for(self.task, timeout=_CLOSE_TIMEOUT) + except asyncio.TimeoutError: + logger.warning( + "MCP session %r did not shut down within %.1fs — cancelling", + self.name, + _CLOSE_TIMEOUT, + ) + self.task.cancel() + try: + await self.task + except BaseException: + pass + except asyncio.CancelledError: + pass + except BaseException: + logger.debug("Error while closing MCP session %r", self.name, exc_info=True) + + # Backward-compatible tuple unpacking. + # The original Chainlit format is ``(ClientSession, AsyncExitStack)``. + # Code that does ``client, _ = mcp_sessions[name]`` will get the + # ``ClientSession`` and a safe sentinel (not the real exit stack, + # which must only be closed by the owning background task). + def __iter__(self): + return iter((self.client, self)) + + ClientType = Literal["webapp", "copilot", "teams", "slack", "discord"] @@ -214,7 +268,7 @@ class WebsocketSession(BaseSession): to_clear: bool = False - mcp_sessions: dict[str, tuple["ClientSession", AsyncExitStack]] + mcp_sessions: dict[str, McpSession] def __init__( self, @@ -321,11 +375,16 @@ async def delete(self): ws_sessions_sid.pop(self.socket_id, None) ws_sessions_id.pop(self.id, None) - for _, exit_stack in self.mcp_sessions.values(): + for mcp_session in list(self.mcp_sessions.values()): try: - await exit_stack.aclose() + await mcp_session.close() except Exception: - pass + logger.debug( + "Error closing MCP session %r during session delete", + mcp_session.name, + exc_info=True, + ) + self.mcp_sessions.clear() async def flush_method_queue(self): for method_name, queue in self.thread_queues.items(): diff --git a/backend/tests/test_session.py b/backend/tests/test_session.py index e98b7a0994..f1a26656f3 100644 --- a/backend/tests/test_session.py +++ b/backend/tests/test_session.py @@ -1,3 +1,4 @@ +import builtins import json import tempfile import uuid @@ -10,11 +11,19 @@ BaseSession, HTTPSession, JSONEncoderIgnoreNonSerializable, + McpSession, WebsocketSession, clean_metadata, ) +def make_exception_group(message: str, exceptions: list[BaseException]): + base_exception_group = getattr(builtins, "BaseExceptionGroup", None) + if base_exception_group is None: + pytest.skip("BaseExceptionGroup is unavailable on this Python version") + return base_exception_group(message, exceptions) # type: ignore[misc] + + class TestJSONEncoderIgnoreNonSerializable: """Test suite for JSONEncoderIgnoreNonSerializable.""" @@ -613,10 +622,144 @@ async def test_websocket_session_delete_with_mcp_sessions(self): client_type="webapp", ) - # Mock MCP session with exit stack - mock_exit_stack = AsyncMock() - session.mcp_sessions["mcp1"] = (Mock(), mock_exit_stack) + # Create a real McpSession with a completed task + import asyncio + + stop = asyncio.Event() + stop.set() # already stopped + + async def _noop(): + pass + + task = asyncio.create_task(_noop()) + await task # let it finish + + mcp = McpSession( + name="mcp1", + client=Mock(), + task=task, + stop_event=stop, + ) + session.mcp_sessions["mcp1"] = mcp + + await session.delete() + + assert "mcp1" not in session.mcp_sessions + + @pytest.mark.asyncio + async def test_websocket_session_delete_with_hanging_mcp(self): + """Test that session delete handles a slow MCP session gracefully.""" + import asyncio + + with tempfile.TemporaryDirectory() as tmpdir: + with patch("chainlit.config.FILES_DIRECTORY", Path(tmpdir)): + session = WebsocketSession( + id="ws_id", + socket_id="socket_123", + emit=Mock(), + emit_call=Mock(), + user_env={}, + client_type="webapp", + ) + + stop = asyncio.Event() + + async def _hang(): + await stop.wait() + + task = asyncio.create_task(_hang()) + + mcp = McpSession( + name="mcp1", + client=Mock(), + task=task, + stop_event=stop, + ) + session.mcp_sessions["mcp1"] = mcp + # delete() should close the session cleanly await session.delete() - mock_exit_stack.aclose.assert_called_once() + assert task.done() + assert "mcp1" not in session.mcp_sessions + + +class TestMcpSession: + """Test suite for the McpSession dataclass.""" + + @pytest.mark.asyncio + async def test_close_signals_stop_and_awaits_task(self): + """close() sets the stop event and waits for the task.""" + import asyncio + + stop = asyncio.Event() + + async def _runner(): + await stop.wait() + + task = asyncio.create_task(_runner()) + mcp = McpSession( + name="test", + client=Mock(), + task=task, + stop_event=stop, + ) + + await mcp.close() + + assert stop.is_set() + assert task.done() + + @pytest.mark.asyncio + async def test_close_cancels_on_timeout(self): + """close() cancels a task that doesn't respond to stop_event.""" + import asyncio + + stop = asyncio.Event() + + async def _stuck(): + # Ignore stop_event entirely + await asyncio.sleep(3600) + + task = asyncio.create_task(_stuck()) + mcp = McpSession( + name="stuck", + client=Mock(), + task=task, + stop_event=stop, + ) + + # Temporarily reduce timeout for this test + import chainlit.session as session_mod + + original_timeout = session_mod._CLOSE_TIMEOUT + session_mod._CLOSE_TIMEOUT = 0.1 + try: + await mcp.close() + finally: + session_mod._CLOSE_TIMEOUT = original_timeout + + assert task.done() + + @pytest.mark.asyncio + async def test_close_idempotent(self): + """Calling close() twice does not raise.""" + import asyncio + + stop = asyncio.Event() + + async def _runner(): + await stop.wait() + + task = asyncio.create_task(_runner()) + mcp = McpSession( + name="test", + client=Mock(), + task=task, + stop_event=stop, + ) + + await mcp.close() + await mcp.close() # second call should be safe + + assert task.done()