Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
279 changes: 183 additions & 96 deletions backend/chainlit/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1290,6 +1290,8 @@ async def connect_mcp(
payload: ConnectMCPRequest,
current_user: UserParam,
):
import asyncio

from mcp import ClientSession
from mcp.client.sse import sse_client
from mcp.client.stdio import (
Expand All @@ -1307,7 +1309,7 @@ async def connect_mcp(
StdioMcpConnection,
validate_mcp_command,
)
from chainlit.session import WebsocketSession
from chainlit.session import McpSession, WebsocketSession

session = WebsocketSession.get_by_id(payload.sessionId)
context = init_ws_context(session)
Expand All @@ -1323,113 +1325,203 @@ async def connect_mcp(
)

mcp_enabled = config.features.mcp.enabled
if mcp_enabled:
if payload.name in session.mcp_sessions:
old_client_session, old_exit_stack = session.mcp_sessions[payload.name]
if on_mcp_disconnect := config.code.on_mcp_disconnect:
await on_mcp_disconnect(payload.name, old_client_session)
if not mcp_enabled:
raise HTTPException(
status_code=400,
detail="This app does not support MCP.",
)

# Disconnect previous session for this name (reconnection)
if payload.name in session.mcp_sessions:
old_mcp = session.mcp_sessions.pop(payload.name)
if on_mcp_disconnect := config.code.on_mcp_disconnect:
try:
await old_exit_stack.aclose()
await on_mcp_disconnect(payload.name, old_mcp.client)
except Exception:
pass

logger.debug(
"Error in on_mcp_disconnect callback for %s",
payload.name,
exc_info=True,
)
try:
exit_stack = AsyncExitStack()
mcp_connection: McpConnection

if payload.clientType == "sse":
if not config.features.mcp.sse.enabled:
raise HTTPException(
status_code=400,
detail="SSE MCP is not enabled",
)
await old_mcp.close()
except Exception:
logger.debug(
"Error closing old MCP session %s", payload.name, exc_info=True
)

mcp_connection = SseMcpConnection(
url=payload.url,
name=payload.name,
headers=getattr(payload, "headers", None),
)
# ── Validate config before launching the background task ──
mcp_connection: McpConnection

if payload.clientType == "sse":
if not config.features.mcp.sse.enabled:
raise HTTPException(
status_code=400,
detail="SSE MCP is not enabled",
)
mcp_connection = SseMcpConnection(
url=payload.url,
name=payload.name,
headers=getattr(payload, "headers", None),
)
elif payload.clientType == "stdio":
if not config.features.mcp.stdio.enabled:
raise HTTPException(
status_code=400,
detail="Stdio MCP is not enabled",
)
env_from_cmd, command, args = validate_mcp_command(payload.fullCommand)
mcp_connection = StdioMcpConnection(
command=command, args=args, name=payload.name
)
elif payload.clientType == "streamable-http":
if not config.features.mcp.streamable_http.enabled:
raise HTTPException(
status_code=400,
detail="HTTP MCP is not enabled",
)
mcp_connection = HttpMcpConnection(
url=payload.url,
name=payload.name,
headers=getattr(payload, "headers", None),
)
else:
raise HTTPException(
status_code=400,
detail=f"Unknown MCP client type: {payload.clientType}",
)

transport = await exit_stack.enter_async_context(
sse_client(
url=mcp_connection.url,
headers=mcp_connection.headers,
# ── Launch the MCP connection in its own background task ──
#
# The background task owns the AsyncExitStack: it enters all context
# managers, calls initialize(), signals ``ready_event``, and then
# blocks on ``stop_event.wait()``. When the stop event fires the
# task wakes up and closes the exit stack *in the same task* that
# opened it — avoiding the cross-task cancel-scope corruption from
# https://github.com/Chainlit/chainlit/issues/2182.

ready_event: asyncio.Event = asyncio.Event()
stop_event: asyncio.Event = asyncio.Event()
# Mutable container to pass the ClientSession back from the bg task.
result_holder: dict[str, object] = {}

async def _mcp_session_runner() -> None:
exit_stack = AsyncExitStack()
try:
try:
if isinstance(mcp_connection, SseMcpConnection):
transport = await exit_stack.enter_async_context(
sse_client(
url=mcp_connection.url,
headers=mcp_connection.headers,
)
)
)
elif payload.clientType == "stdio":
if not config.features.mcp.stdio.enabled:
raise HTTPException(
status_code=400,
detail="Stdio MCP is not enabled",
elif isinstance(mcp_connection, StdioMcpConnection):
env = get_default_environment()
env.update(env_from_cmd)
server_params = StdioServerParameters(
command=command, args=args, env=env
)
transport = await exit_stack.enter_async_context(
stdio_client(server_params)
)
elif isinstance(mcp_connection, HttpMcpConnection):
transport = await exit_stack.enter_async_context(
streamablehttp_client(
url=mcp_connection.url,
headers=mcp_connection.headers,
)
)
else:
raise ValueError(f"Unknown client type: {payload.clientType}")

env_from_cmd, command, args = validate_mcp_command(payload.fullCommand)
mcp_connection = StdioMcpConnection(
command=command, args=args, name=payload.name
)

env = get_default_environment()
env.update(env_from_cmd)
# Create the server parameters
server_params = StdioServerParameters(
command=command, args=args, env=env
)

transport = await exit_stack.enter_async_context(
stdio_client(server_params)
)
read, write = transport[:2]

elif payload.clientType == "streamable-http":
if not config.features.mcp.streamable_http.enabled:
raise HTTPException(
status_code=400,
detail="HTTP MCP is not enabled",
)
mcp_connection = HttpMcpConnection(
url=payload.url,
name=payload.name,
headers=getattr(payload, "headers", None),
)
transport = await exit_stack.enter_async_context(
streamablehttp_client(
url=mcp_connection.url,
headers=mcp_connection.headers,
mcp_client: ClientSession = await exit_stack.enter_async_context(
ClientSession(
read_stream=read,
write_stream=write,
sampling_callback=None,
)
)

# The transport can return (read, write) for stdio, sse
# Or (read, write, get_session_id) for streamable-http
# We are only interested in the read and write streams here.
read, write = transport[:2]
await mcp_client.initialize()
result_holder["client"] = mcp_client

mcp_session: ClientSession = await exit_stack.enter_async_context(
ClientSession(
read_stream=read, write_stream=write, sampling_callback=None
except BaseException as exc:
result_holder["error"] = exc
return # outer finally closes exit_stack
finally:
# Always signal the caller so it doesn't wait forever.
ready_event.set()

# ── Keep the task (and the exit stack) alive ──
try:
await stop_event.wait()
except asyncio.CancelledError:
logger.debug("MCP background task for %r cancelled", payload.name)
finally:
# Close exit_stack in ALL paths (error, normal shutdown,
# cancellation) — always in the same task that opened it.
logger.debug("Closing MCP exit stack for %r (same-task)", payload.name)
try:
await exit_stack.aclose()
except BaseException:
logger.debug(
"Error closing MCP exit stack for %r",
payload.name,
exc_info=True,
)
)

# Initialize the session
await mcp_session.initialize()
task = asyncio.create_task(
_mcp_session_runner(), name=f"mcp-session-{payload.name}"
)

# Store the session
session.mcp_sessions[mcp_connection.name] = (mcp_session, exit_stack)
# Wait for the background task to finish initialisation.
await ready_event.wait()

# Call the callback
if config.code.on_mcp_connect:
await config.code.on_mcp_connect(mcp_connection, mcp_session)
if "error" in result_holder:
# The task already exited and cleaned up its exit stack.
# Make sure the task itself is fully done.
try:
await task
except BaseException:
pass
return JSONResponse(
status_code=400,
content={
"detail": f"Could not connect to the MCP: {result_holder['error']!s}"
},
)

mcp_client_session = cast("ClientSession", result_holder["client"])

# Call the user callback
if config.code.on_mcp_connect:
try:
await config.code.on_mcp_connect(mcp_connection, mcp_client_session)
except Exception as e:
raise HTTPException(
# Callback failed — tear down the connection.
stop_event.set()
try:
await task
except BaseException:
pass
return JSONResponse(
status_code=400,
detail=f"Could not connect to the MCP: {e!s}",
content={"detail": f"Could not connect to the MCP: {e!s}"},
)
else:
raise HTTPException(
status_code=400,
detail="This app does not support MCP.",
)

tool_list = await mcp_session.list_tools()
# Store the session
mcp_session_obj = McpSession(
name=mcp_connection.name,
client=mcp_client_session,
task=task,
stop_event=stop_event,
)
session.mcp_sessions[mcp_connection.name] = mcp_session_obj

tool_list = await mcp_client_session.list_tools()

return JSONResponse(
content={
Expand Down Expand Up @@ -1475,22 +1567,17 @@ async def disconnect_mcp(

callback = config.code.on_mcp_disconnect
if payload.name in session.mcp_sessions:
mcp_session_obj = session.mcp_sessions.pop(payload.name)
try:
client_session, exit_stack = session.mcp_sessions[payload.name]
if callback:
await callback(payload.name, client_session)

try:
await exit_stack.aclose()
except Exception:
pass
del session.mcp_sessions[payload.name]

await callback(payload.name, mcp_session_obj.client)
except Exception as e:
raise HTTPException(
status_code=400,
detail=f"Could not disconnect to the MCP: {e!s}",
detail=f"Could not disconnect from the MCP: {e!s}",
)
finally:
await mcp_session_obj.close()

return JSONResponse(content={"success": True})

Expand Down
Loading
Loading