diff --git a/agentex-ui/components/task-header/investigate-traces-button.tsx b/agentex-ui/components/task-header/investigate-traces-button.tsx index ffac1b3d..7663ab89 100644 --- a/agentex-ui/components/task-header/investigate-traces-button.tsx +++ b/agentex-ui/components/task-header/investigate-traces-button.tsx @@ -10,15 +10,15 @@ import { cn } from '@/lib/utils'; type InvestigateTracesButtonProps = { className?: string; disabled?: boolean; - taskId: string; + traceId: string; }; export const InvestigateTracesButton = forwardRef< HTMLAnchorElement, InvestigateTracesButtonProps ->(({ className, disabled = false, taskId, ...props }, ref) => { +>(({ className, disabled = false, traceId, ...props }, ref) => { const { sgpAppURL } = useAgentexClient(); - const sgpTracesURL = `${sgpAppURL}/beta/monitor?trace_id=${taskId}&tt-trace-id=${taskId}`; + const sgpTracesURL = `${sgpAppURL}/beta/monitor?trace_id=${traceId}&tt-trace-id=${traceId}`; if (!sgpAppURL) { return null; diff --git a/agentex-ui/components/task-header/task-header.tsx b/agentex-ui/components/task-header/task-header.tsx index 4a024de8..b8b3aa2b 100644 --- a/agentex-ui/components/task-header/task-header.tsx +++ b/agentex-ui/components/task-header/task-header.tsx @@ -13,6 +13,7 @@ import { SelectValue, } from '@/components/ui/select'; import { useSafeSearchParams } from '@/hooks/use-safe-search-params'; +import { useSpans } from '@/hooks/use-spans'; import type { Agent } from 'agentex/resources'; @@ -36,6 +37,8 @@ export function TaskHeader({ }: TaskHeaderProps) { const displayTaskId = taskId ? taskId.split('-')[0] : ''; const { agentName: selectedAgentName } = useSafeSearchParams(); + const { spans } = useSpans(taskId); + const traceId = spans[0]?.trace_id ?? taskId; const copyTaskId = async () => { if (taskId) { @@ -109,7 +112,7 @@ export function TaskHeader({ icon={Activity} /> )} - {taskId && } + {taskId && traceId && } diff --git a/agentex-ui/components/task-messages/task-message-reasoning-content.tsx b/agentex-ui/components/task-messages/task-message-reasoning-content.tsx index cdeed4db..1a0b8360 100644 --- a/agentex-ui/components/task-messages/task-message-reasoning-content.tsx +++ b/agentex-ui/components/task-messages/task-message-reasoning-content.tsx @@ -59,10 +59,9 @@ function TaskMessageReasoningImpl({ message }: TaskMessageReasoningProps) { if (message.content.type !== 'reasoning') { throw new Error('Message content is not a ReasoningContent'); } - return [ - ...(message.content.content ?? []), - ...(message.content.summary ?? []), - ].join('\n\n'); + const content = message.content.content ?? []; + const summary = message.content.summary ?? []; + return (content.length > 0 ? content : summary).join('\n\n'); }, [message.content]); const updateBlurEffects = () => { diff --git a/agentex-ui/components/task-messages/task-messages.tsx b/agentex-ui/components/task-messages/task-messages.tsx index 21075cca..98cdd172 100644 --- a/agentex-ui/components/task-messages/task-messages.tsx +++ b/agentex-ui/components/task-messages/task-messages.tsx @@ -151,17 +151,34 @@ function TaskMessagesImpl({ const shouldShowThinkingForLastPair = useMemo(() => { if (messagePairs.length === 0) return false; + if (rpcStatus !== 'pending' && rpcStatus !== 'success') return false; const lastPair = messagePairs[messagePairs.length - 1]!; - const hasNoAgentMessages = lastPair.agentMessages.length === 0; - const hasUserMessage = lastPair.userMessage !== null; - - return ( - hasUserMessage && - hasNoAgentMessages && - (rpcStatus === 'pending' || rpcStatus === 'success') - ); - }, [messagePairs, rpcStatus]); + + // No agent messages yet — waiting for first response + if (lastPair.agentMessages.length === 0) { + return lastPair.userMessage !== null; + } + + const lastAgentMessage = + lastPair.agentMessages[lastPair.agentMessages.length - 1]!; + const lastType = lastAgentMessage.content.type; + + // Already have text streaming or complete — not "thinking" + if (lastType === 'text') return false; + + // Tool or reasoning still in progress — show their own indicator, not "Thinking..." + if (lastAgentMessage.streaming_status === 'IN_PROGRESS') return false; + if ( + lastType === 'tool_request' && + pendingToolCallIds.has(lastAgentMessage.content.tool_call_id) + ) + return false; + + // Last message is a completed tool_request, tool_response, reasoning, or data + // with no following text — agent is thinking about the next step + return true; + }, [messagePairs, rpcStatus, pendingToolCallIds]); // Measure container height for last-pair min-height useEffect(() => { diff --git a/agentex-ui/hooks/use-spans.ts b/agentex-ui/hooks/use-spans.ts index 34539aa5..f2213186 100644 --- a/agentex-ui/hooks/use-spans.ts +++ b/agentex-ui/hooks/use-spans.ts @@ -8,8 +8,8 @@ import type { Span } from 'agentex/resources'; export const spansKeys = { all: ['spans'] as const, - byTraceId: (traceId: string | null) => - traceId ? ([...spansKeys.all, traceId] as const) : spansKeys.all, + byTaskId: (taskId: string | null) => + taskId ? ([...spansKeys.all, 'task', taskId] as const) : spansKeys.all, }; type UseSpansState = { @@ -21,24 +21,37 @@ type UseSpansState = { /** * Fetches execution spans for observability and debugging of task execution. * - * Spans are OpenTelemetry-style trace records that show the execution flow of an agent task. - * The query is automatically disabled when no traceId is provided. + * Queries by task_id first. Falls back to trace_id=taskId for backward + * compatibility with spans created before the task_id column was added. * - * @param traceId - string | null - The trace ID to fetch spans for, or null to disable the query + * @param taskId - string | null - The task ID to fetch spans for, or null to disable the query * @returns UseSpansState - Object containing the spans array, loading state, and any error message */ -export function useSpans(traceId: string | null): UseSpansState { +export function useSpans(taskId: string | null): UseSpansState { const { agentexClient } = useAgentexClient(); const { data, isLoading, error } = useQuery({ - queryKey: spansKeys.byTraceId(traceId), + queryKey: spansKeys.byTaskId(taskId), queryFn: async ({ signal }) => { - if (!traceId) { + if (!taskId) { return []; } - return await agentexClient.spans.list({ trace_id: traceId }, { signal }); + + // task_id is not yet in the SDK types (SDK update pending), but the + // server already accepts it — cast until the SDK is regenerated. + const spansByTaskId = await agentexClient.spans.list( + { task_id: taskId } as Parameters[0], + { signal } + ); + + if (spansByTaskId.length > 0) { + return spansByTaskId; + } + + // Fallback: query by trace_id=taskId for backward compat with old spans + return await agentexClient.spans.list({ trace_id: taskId }, { signal }); }, - enabled: traceId !== null, + enabled: taskId !== null, }); return { diff --git a/agentex/database/migrations/alembic/versions/2026_04_14_1126_add_task_id_to_spans_57c5ed4f59ae.py b/agentex/database/migrations/alembic/versions/2026_04_14_1126_add_task_id_to_spans_57c5ed4f59ae.py new file mode 100644 index 00000000..06ebcebc --- /dev/null +++ b/agentex/database/migrations/alembic/versions/2026_04_14_1126_add_task_id_to_spans_57c5ed4f59ae.py @@ -0,0 +1,52 @@ +"""add_task_id_to_spans + +Revision ID: 57c5ed4f59ae +Revises: 4a9b7787ccd7 +Create Date: 2026-04-14 11:26:45.193515 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '57c5ed4f59ae' +down_revision: Union[str, None] = '4a9b7787ccd7' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Add nullable task_id column first (no FK yet, so backfill can run freely) + op.add_column('spans', sa.Column('task_id', sa.String(), nullable=True)) + + # Backfill task_id from trace_id where trace_id is a valid task ID. + # Uses a JOIN instead of a subquery for efficient matching. + op.execute(""" + UPDATE spans + SET task_id = spans.trace_id + FROM tasks + WHERE spans.trace_id = tasks.id + AND spans.task_id IS NULL + """) + + # Add FK constraint after backfill (NULL values are allowed by FK) + op.create_foreign_key( + 'fk_spans_task_id_tasks', + 'spans', + 'tasks', + ['task_id'], + ['id'], + ondelete='SET NULL', + ) + + # Add index for querying spans by task_id + op.create_index('ix_spans_task_id', 'spans', ['task_id']) + + +def downgrade() -> None: + op.drop_index('ix_spans_task_id', table_name='spans') + op.drop_constraint('fk_spans_task_id_tasks', 'spans', type_='foreignkey') + op.drop_column('spans', 'task_id') diff --git a/agentex/src/adapters/orm.py b/agentex/src/adapters/orm.py index 8fbd6d08..ac5ee39a 100644 --- a/agentex/src/adapters/orm.py +++ b/agentex/src/adapters/orm.py @@ -150,6 +150,7 @@ class SpanORM(BaseORM): __tablename__ = "spans" id = Column(String, primary_key=True, default=orm_id) # Using UUIDs for IDs trace_id = Column(String, nullable=False) + task_id = Column(String, ForeignKey("tasks.id", ondelete="SET NULL"), nullable=True) parent_id = Column(String, nullable=True) name = Column(String, nullable=False) start_time = Column(DateTime(timezone=True), nullable=False) @@ -166,6 +167,8 @@ class SpanORM(BaseORM): Index("ix_spans_trace_id_start_time", "trace_id", "start_time"), # Index for traversing span hierarchy Index("ix_spans_parent_id", "parent_id"), + # Index for filtering spans by task_id + Index("ix_spans_task_id", "task_id"), ) diff --git a/agentex/src/api/routes/spans.py b/agentex/src/api/routes/spans.py index 65ef1d11..888b71ac 100644 --- a/agentex/src/api/routes/spans.py +++ b/agentex/src/api/routes/spans.py @@ -23,6 +23,7 @@ async def create_span( return await span_use_case.create( id=request.id, trace_id=request.trace_id, + task_id=request.task_id, name=request.name, parent_id=request.parent_id, start_time=request.start_time, @@ -48,6 +49,7 @@ async def partial_update_span( return await span_use_case.partial_update( id=span_id, trace_id=request.trace_id, + task_id=request.task_id, name=request.name, parent_id=request.parent_id, start_time=request.start_time, @@ -80,17 +82,19 @@ async def get_span( async def list_spans( span_use_case: DSpanUseCase, trace_id: str | None = None, + task_id: str | None = None, limit: int = Query(default=50, ge=1, le=1000), page_number: int = Query(default=1, ge=1), order_by: str | None = None, order_direction: str = "desc", ) -> list[Span]: """ - List all spans for a given trace ID + List spans, optionally filtered by trace_id and/or task_id """ - logger.info(f"Listing spans for trace ID: {trace_id}") + logger.info(f"Listing spans for trace_id={trace_id}, task_id={task_id}") spans = await span_use_case.list( trace_id=trace_id, + task_id=task_id, limit=limit, page_number=page_number, order_by=order_by, diff --git a/agentex/src/api/schemas/spans.py b/agentex/src/api/schemas/spans.py index 8826b8b4..6f493e7b 100644 --- a/agentex/src/api/schemas/spans.py +++ b/agentex/src/api/schemas/spans.py @@ -17,6 +17,11 @@ class CreateSpanRequest(BaseModel): title="The trace ID for this span", description="Unique identifier for the trace this span belongs to", ) + task_id: str | None = Field( + None, + title="The task ID this span is associated with", + description="ID of the task this span belongs to", + ) parent_id: str | None = Field( None, title="The parent span ID if this is a child span", @@ -56,6 +61,11 @@ class UpdateSpanRequest(BaseModel): title="The trace ID for this span", description="Unique identifier for the trace this span belongs to", ) + task_id: str | None = Field( + None, + title="The task ID this span is associated with", + description="ID of the task this span belongs to", + ) parent_id: str | None = Field( None, title="The parent span ID if this is a child span", diff --git a/agentex/src/domain/entities/spans.py b/agentex/src/domain/entities/spans.py index 002f0d2b..76243704 100644 --- a/agentex/src/domain/entities/spans.py +++ b/agentex/src/domain/entities/spans.py @@ -15,6 +15,10 @@ class SpanEntity(BaseModel): ..., title="The trace ID for this span", ) + task_id: str | None = Field( + None, + title="The task ID this span is associated with", + ) parent_id: str | None = Field( None, title="The parent span ID if this is a child span", diff --git a/agentex/src/domain/use_cases/spans_use_case.py b/agentex/src/domain/use_cases/spans_use_case.py index 432068af..8faff82c 100644 --- a/agentex/src/domain/use_cases/spans_use_case.py +++ b/agentex/src/domain/use_cases/spans_use_case.py @@ -21,6 +21,7 @@ async def create( name: str, trace_id: str, id: str | None = None, + task_id: str | None = None, parent_id: str | None = None, start_time: datetime | None = None, end_time: datetime | None = None, @@ -38,6 +39,7 @@ async def create( span = SpanEntity( id=id, trace_id=trace_id, + task_id=task_id, parent_id=parent_id, name=name, start_time=start_time, @@ -52,6 +54,7 @@ async def partial_update( self, id: str, trace_id: str | None = None, + task_id: str | None = None, name: str | None = None, parent_id: str | None = None, start_time: datetime | None = None, @@ -70,6 +73,9 @@ async def partial_update( if trace_id is not None: span.trace_id = trace_id + if task_id is not None: + span.task_id = task_id + if name is not None: span.name = name @@ -108,19 +114,20 @@ async def list( limit: int, page_number: int, trace_id: str | None = None, + task_id: str | None = None, order_by: str | None = None, order_direction: str = "desc", ) -> list[SpanEntity]: """ - List all spans for a given trace ID + List spans, optionally filtered by trace_id and/or task_id """ - # Note: This would require custom implementation in the repository - # or filtering after fetching all spans - - if trace_id: - filters = {"trace_id": trace_id} - else: - filters = None + filters: dict[str, str] | None = None + if trace_id or task_id: + filters = {} + if trace_id: + filters["trace_id"] = trace_id + if task_id: + filters["task_id"] = task_id return await self.span_repo.list( filters=filters, limit=limit, diff --git a/agentex/tests/integration/api/spans/test_spans_api.py b/agentex/tests/integration/api/spans/test_spans_api.py index 82b0244e..d606c756 100644 --- a/agentex/tests/integration/api/spans/test_spans_api.py +++ b/agentex/tests/integration/api/spans/test_spans_api.py @@ -7,7 +7,9 @@ import pytest import pytest_asyncio +from src.domain.entities.agents import ACPType, AgentEntity from src.domain.entities.spans import SpanEntity +from src.domain.entities.tasks import TaskEntity from src.utils.ids import orm_id @@ -15,9 +17,43 @@ class TestSpansAPIIntegration: """Integration tests for span endpoints using API-first validation""" + @pytest_asyncio.fixture + async def test_agent(self, isolated_repositories): + """Create a test agent for task creation.""" + agent_repo = isolated_repositories["agent_repository"] + return await agent_repo.create( + AgentEntity( + id=orm_id(), + name="spans-test-agent", + description="Agent for span integration tests", + acp_url="http://test:8000", + acp_type=ACPType.SYNC, + ) + ) + + @pytest_asyncio.fixture + async def test_tasks(self, isolated_repositories, test_agent): + """Create test tasks that can be referenced by spans via FK.""" + task_repo = isolated_repositories["task_repository"] + tasks = {} + for name in [ + "task-a", + "task-b", + "task-x", + "task-y", + "task-create", + "task-update", + ]: + task = await task_repo.create( + agent_id=test_agent.id, + task=TaskEntity(id=orm_id(), name=name), + ) + tasks[name] = task + return tasks + @pytest_asyncio.fixture async def test_pagination_spans(self, isolated_repositories): - """Create a test task for message creation""" + """Create spans for pagination tests""" span_repo = isolated_repositories["span_repository"] spans = [] for i in range(60): @@ -30,11 +66,16 @@ async def test_pagination_spans(self, isolated_repositories): spans.append(await span_repo.create(span)) return spans - async def test_create_and_retrieve_span_consistency(self, isolated_client): + async def test_create_and_retrieve_span_consistency( + self, isolated_client, test_tasks + ): """Test span creation and validate POST → GET consistency (API-first)""" + task_id = test_tasks["task-create"].id + # Given - Span creation data span_data = { "trace_id": "test-trace-123", + "task_id": task_id, "name": "test-operation", "start_time": "2024-01-01T10:00:00Z", "end_time": "2024-01-01T10:00:05Z", @@ -53,6 +94,7 @@ async def test_create_and_retrieve_span_consistency(self, isolated_client): # Validate response has required fields assert "id" in created_span assert created_span["trace_id"] == span_data["trace_id"] + assert created_span["task_id"] == task_id assert created_span["name"] == span_data["name"] span_id = created_span["id"] @@ -64,12 +106,28 @@ async def test_create_and_retrieve_span_consistency(self, isolated_client): # Validate POST/GET consistency assert retrieved_span["id"] == span_id assert retrieved_span["trace_id"] == span_data["trace_id"] + assert retrieved_span["task_id"] == task_id assert retrieved_span["name"] == span_data["name"] assert retrieved_span["input"] == span_data["input"] assert retrieved_span["output"] == span_data["output"] - async def test_update_span_and_validate_changes(self, isolated_client): + async def test_create_span_without_task_id(self, isolated_client): + """Test span creation without task_id (should default to null)""" + span_data = { + "trace_id": "test-trace-no-task", + "name": "test-no-task", + "start_time": "2024-01-01T10:00:00Z", + } + + create_response = await isolated_client.post("/spans", json=span_data) + assert create_response.status_code == 200 + created_span = create_response.json() + assert created_span["task_id"] is None + + async def test_update_span_and_validate_changes(self, isolated_client, test_tasks): """Test span update and validate PATCH → GET consistency""" + task_id = test_tasks["task-update"].id + # Given - Create a span first initial_data = { "trace_id": "update-trace-456", @@ -80,9 +138,10 @@ async def test_update_span_and_validate_changes(self, isolated_client): assert create_response.status_code == 200 span_id = create_response.json()["id"] - # When - Update the span + # When - Update the span including task_id update_data = { "name": "updated-name", + "task_id": task_id, "parent_id": "parent-id", "start_time": "2024-01-01T10:10:00Z", "end_time": "2024-01-01T10:10:05Z", @@ -104,6 +163,7 @@ async def test_update_span_and_validate_changes(self, isolated_client): # Validate changes were applied assert updated_span["name"] == "updated-name" + assert updated_span["task_id"] == task_id assert updated_span["output"]["status"] == "completed" assert updated_span["parent_id"] == "parent-id" assert updated_span["start_time"] == "2024-01-01T10:10:00Z" @@ -123,6 +183,7 @@ async def test_update_span_and_validate_changes(self, isolated_client): assert patch_response.status_code == 200 updated_span = patch_response.json() assert updated_span["name"] == "updated-name" + assert updated_span["task_id"] == task_id # Still set from prior update assert updated_span["output"]["status"] == "completed" assert updated_span["parent_id"] == "parent-id" assert updated_span["start_time"] == "2024-01-01T10:10:00Z" @@ -131,7 +192,7 @@ async def test_update_span_and_validate_changes(self, isolated_client): assert updated_span["trace_id"] == "updated-trace-789" assert updated_span["data"] == {"test": True, "version": "2.0.0"} - async def test_list_spans_with_filtering(self, isolated_client): + async def test_list_spans_with_trace_id_filtering(self, isolated_client): """Test list spans endpoint with trace_id filtering""" # Given - Create spans with different trace_ids trace_id_1 = "list-trace-001" @@ -173,6 +234,114 @@ async def test_list_spans_with_filtering(self, isolated_client): for span in spans: assert span["trace_id"] == trace_id_1 + async def test_list_spans_with_task_id_filtering(self, isolated_client, test_tasks): + """Test list spans endpoint with task_id filtering""" + task_id_a = test_tasks["task-a"].id + task_id_b = test_tasks["task-b"].id + + for i in range(3): + resp = await isolated_client.post( + "/spans", + json={ + "trace_id": f"trace-task-filter-{i}", + "task_id": task_id_a, + "name": f"span-task-a-{i}", + "start_time": "2024-01-01T10:00:00Z", + }, + ) + assert resp.status_code == 200 + + for i in range(2): + resp = await isolated_client.post( + "/spans", + json={ + "trace_id": f"trace-task-filter-b-{i}", + "task_id": task_id_b, + "name": f"span-task-b-{i}", + "start_time": "2024-01-01T10:00:00Z", + }, + ) + assert resp.status_code == 200 + + # One span with no task_id + resp = await isolated_client.post( + "/spans", + json={ + "trace_id": "trace-no-task", + "name": "span-no-task", + "start_time": "2024-01-01T10:00:00Z", + }, + ) + assert resp.status_code == 200 + + # When - Filter by task_id_a + response = await isolated_client.get(f"/spans?task_id={task_id_a}") + assert response.status_code == 200 + spans = response.json() + assert len(spans) == 3 + for span in spans: + assert span["task_id"] == task_id_a + + # When - Filter by task_id_b + response = await isolated_client.get(f"/spans?task_id={task_id_b}") + assert response.status_code == 200 + spans = response.json() + assert len(spans) == 2 + for span in spans: + assert span["task_id"] == task_id_b + + # When - No filter returns all 6 + response = await isolated_client.get("/spans") + assert response.status_code == 200 + assert len(response.json()) == 6 + + async def test_list_spans_with_combined_trace_and_task_filtering( + self, isolated_client, test_tasks + ): + """Test list spans with both trace_id and task_id filters""" + shared_trace = "combined-trace" + task_id_x = test_tasks["task-x"].id + task_id_y = test_tasks["task-y"].id + + await isolated_client.post( + "/spans", + json={ + "trace_id": shared_trace, + "task_id": task_id_x, + "name": "span-match", + "start_time": "2024-01-01T10:00:00Z", + }, + ) + await isolated_client.post( + "/spans", + json={ + "trace_id": shared_trace, + "task_id": task_id_y, + "name": "span-same-trace-diff-task", + "start_time": "2024-01-01T10:00:00Z", + }, + ) + await isolated_client.post( + "/spans", + json={ + "trace_id": "other-trace", + "task_id": task_id_x, + "name": "span-diff-trace-same-task", + "start_time": "2024-01-01T10:00:00Z", + }, + ) + + # When - Filter by both trace_id and task_id + response = await isolated_client.get( + f"/spans?trace_id={shared_trace}&task_id={task_id_x}" + ) + assert response.status_code == 200 + spans = response.json() + assert len(spans) == 1 + assert spans[0]["name"] == "span-match" + assert spans[0]["trace_id"] == shared_trace + assert spans[0]["task_id"] == task_id_x + async def test_get_span_non_existent(self, isolated_client): """Test getting a non-existent span returns 404""" # When - Get a non-existent span diff --git a/agentex/tests/unit/repositories/test_span_repository.py b/agentex/tests/unit/repositories/test_span_repository.py index e18ca541..806a969a 100644 --- a/agentex/tests/unit/repositories/test_span_repository.py +++ b/agentex/tests/unit/repositories/test_span_repository.py @@ -11,7 +11,7 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", "..", "src")) -from adapters.orm import BaseORM +from adapters.orm import BaseORM, TaskORM from domain.entities.spans import SpanEntity from domain.repositories.span_repository import SpanRepository from utils.ids import orm_id @@ -48,6 +48,12 @@ async def test_span_repository_crud_operations(postgres_url): async_session_maker = async_sessionmaker(engine, expire_on_commit=False) span_repo = SpanRepository(async_session_maker, async_session_maker) + # Create a task row to satisfy the FK constraint on spans.task_id + task_id = orm_id() + async with async_session_maker() as session: + session.add(TaskORM(id=task_id, name="test-task")) + await session.commit() + # Test CREATE operation with JSON fields now = datetime.now(UTC) span_id = orm_id() @@ -56,6 +62,7 @@ async def test_span_repository_crud_operations(postgres_url): span = SpanEntity( id=span_id, trace_id=trace_id, + task_id=task_id, parent_id=None, name="test-span-operation", start_time=now, @@ -68,6 +75,7 @@ async def test_span_repository_crud_operations(postgres_url): created_span = await span_repo.create(span) assert created_span.id == span_id assert created_span.trace_id == trace_id + assert created_span.task_id == task_id assert created_span.name == "test-span-operation" assert created_span.input["operation"] == "test" assert created_span.data["metadata"]["version"] == "1.0" @@ -78,6 +86,7 @@ async def test_span_repository_crud_operations(postgres_url): updated_span = SpanEntity( id=span_id, trace_id=trace_id, + task_id=task_id, parent_id=None, name="test-span-operation", start_time=now, @@ -108,6 +117,7 @@ async def test_span_repository_crud_operations(postgres_url): child_span = SpanEntity( id=child_span_id, trace_id=trace_id, + task_id=task_id, parent_id=span_id, # Child of the first span name="child-span-operation", start_time=child_start_time, @@ -145,3 +155,49 @@ async def test_span_repository_crud_operations(postgres_url): print("✅ Test isolation provided by session-scoped PostgreSQL container") print("🎉 ALL SPAN REPOSITORY TESTS PASSED!") + + +@pytest.mark.asyncio +@pytest.mark.unit +async def test_span_task_id_set_null_on_task_delete(postgres_url): + """Deleting a referenced task should null out spans.task_id, not fail with FK violation.""" + + sqlalchemy_asyncpg_url = postgres_url.replace( + "postgresql+psycopg2://", "postgresql+asyncpg://" + ) + + engine = create_async_engine(sqlalchemy_asyncpg_url, echo=False) + async with engine.begin() as conn: + await conn.run_sync(BaseORM.metadata.create_all) + + async_session_maker = async_sessionmaker(engine, expire_on_commit=False) + span_repo = SpanRepository(async_session_maker, async_session_maker) + + # Seed a task and a span referencing it + task_id = orm_id() + span_id = orm_id() + async with async_session_maker() as session: + session.add(TaskORM(id=task_id, name="task-to-delete")) + await session.commit() + + await span_repo.create( + SpanEntity( + id=span_id, + trace_id=orm_id(), + task_id=task_id, + parent_id=None, + name="span-with-task-fk", + start_time=datetime.now(UTC), + ) + ) + + # Delete the task — should succeed, not raise a FK violation + async with async_session_maker() as session: + task = await session.get(TaskORM, task_id) + await session.delete(task) + await session.commit() + + # Span should survive with task_id set to NULL + retrieved = await span_repo.get(id=span_id) + assert retrieved is not None + assert retrieved.task_id is None diff --git a/agentex/tests/unit/repositories/test_task_repository.py b/agentex/tests/unit/repositories/test_task_repository.py index b3ff59e5..eab1b0f1 100644 --- a/agentex/tests/unit/repositories/test_task_repository.py +++ b/agentex/tests/unit/repositories/test_task_repository.py @@ -63,10 +63,12 @@ async def test_task_repository_crud_operations(postgres_url): agent_repo = AgentRepository(async_session_maker, async_session_maker) # First, create an agent (required for task creation) + # Use unique names to avoid collisions with other tests sharing the same session-scoped DB agent_id = orm_id() + unique_suffix = agent_id[:8] agent = AgentEntity( id=agent_id, - name="test-agent-for-tasks", + name=f"test-agent-for-tasks-{unique_suffix}", description="Test agent for task repository testing", docker_image="test/agent:latest", status=AgentStatus.READY, @@ -81,15 +83,16 @@ async def test_task_repository_crud_operations(postgres_url): task_id = orm_id() task = TaskEntity( id=task_id, - name="test-task", + name=f"test-task-{unique_suffix}", status=TaskStatus.RUNNING, status_reason="Task is running for testing", ) # Test CREATE operation + task_name = f"test-task-{unique_suffix}" created_task = await task_repo.create(agent_id, task) assert created_task.id == task_id - assert created_task.name == "test-task" + assert created_task.name == task_name assert created_task.status == TaskStatus.RUNNING assert created_task.status_reason == "Task is running for testing" assert created_task.created_at is not None @@ -104,9 +107,9 @@ async def test_task_repository_crud_operations(postgres_url): print("✅ GET by ID operation successful") # Test GET operation by name - retrieved_task_by_name = await task_repo.get(name="test-task") + retrieved_task_by_name = await task_repo.get(name=task_name) assert retrieved_task_by_name.id == created_task.id - assert retrieved_task_by_name.name == "test-task" + assert retrieved_task_by_name.name == task_name print("✅ GET by name operation successful") # Test GET agent by task ID @@ -118,7 +121,7 @@ async def test_task_repository_crud_operations(postgres_url): # Test UPDATE operation updated_task = TaskEntity( id=task_id, - name="test-task", # Keep same name + name=task_name, # Keep same name status=TaskStatus.COMPLETED, status_reason="Task completed successfully", ) @@ -140,7 +143,7 @@ async def test_task_repository_crud_operations(postgres_url): task_id_2 = orm_id() task_2 = TaskEntity( id=task_id_2, - name="test-task-2", + name=f"test-task-2-{unique_suffix}", status=TaskStatus.FAILED, status_reason="Second test task", ) @@ -211,10 +214,12 @@ async def test_task_repository_params_support(postgres_url): agent_repo = AgentRepository(async_session_maker, async_session_maker) # First, create an agent (required for task creation) + # Use unique names to avoid collisions with other tests sharing the same session-scoped DB agent_id = orm_id() + unique_suffix = agent_id[:8] agent = AgentEntity( id=agent_id, - name="test-agent-params", + name=f"test-agent-params-{unique_suffix}", description="Test agent for params testing", docker_image="test/agent:latest", status=AgentStatus.READY, @@ -233,9 +238,10 @@ async def test_task_repository_params_support(postgres_url): "max_tokens": 1000, "nested": {"key": "value", "number": 42}, } + task_name = f"test-task-with-params-{unique_suffix}" task = TaskEntity( id=task_id, - name="test-task-with-params", + name=task_name, status=TaskStatus.RUNNING, status_reason="Task with params for testing", params=task_params, @@ -244,7 +250,7 @@ async def test_task_repository_params_support(postgres_url): # Create task with params created_task = await task_repo.create(agent_id, task) assert created_task.id == task_id - assert created_task.name == "test-task-with-params" + assert created_task.name == task_name assert created_task.params == task_params print("✅ CREATE operation with params successful") @@ -256,7 +262,7 @@ async def test_task_repository_params_support(postgres_url): print("✅ GET by ID operation preserves params") # Test GET operation by name preserves params - retrieved_task_by_name = await task_repo.get(name="test-task-with-params") + retrieved_task_by_name = await task_repo.get(name=task_name) assert retrieved_task_by_name.id == created_task.id assert retrieved_task_by_name.params == task_params print("✅ GET by name operation preserves params") @@ -270,7 +276,7 @@ async def test_task_repository_params_support(postgres_url): } updated_task = TaskEntity( id=task_id, - name="test-task-with-params", + name=task_name, status=TaskStatus.COMPLETED, status_reason="Task completed with updated params", params=updated_params, @@ -293,7 +299,7 @@ async def test_task_repository_params_support(postgres_url): task_id_null = orm_id() task_null_params = TaskEntity( id=task_id_null, - name="test-task-null-params", + name=f"test-task-null-params-{unique_suffix}", status=TaskStatus.RUNNING, status_reason="Task with null params", params=None, @@ -349,10 +355,12 @@ async def test_task_repository_task_metadata_support(postgres_url): agent_repo = AgentRepository(async_session_maker, async_session_maker) # First, create an agent (required for task creation) + # Use unique names to avoid collisions with other tests sharing the same session-scoped DB agent_id = orm_id() + unique_suffix = agent_id[:8] agent = AgentEntity( id=agent_id, - name="test-agent-metadata", + name=f"test-agent-metadata-{unique_suffix}", description="Test agent for task_metadata testing", docker_image="test/agent:latest", status=AgentStatus.READY, @@ -397,9 +405,10 @@ async def test_task_repository_task_metadata_support(postgres_url): "numeric_precision": 123.456789, }, } + task_name = f"test-task-with-metadata-{unique_suffix}" task = TaskEntity( id=task_id, - name="test-task-with-metadata", + name=task_name, status=TaskStatus.RUNNING, status_reason="Task with task_metadata for testing", task_metadata=task_metadata, @@ -408,7 +417,7 @@ async def test_task_repository_task_metadata_support(postgres_url): # Create task with task_metadata created_task = await task_repo.create(agent_id, task) assert created_task.id == task_id - assert created_task.name == "test-task-with-metadata" + assert created_task.name == task_name assert created_task.task_metadata == task_metadata print("✅ CREATE operation with task_metadata successful") @@ -420,7 +429,7 @@ async def test_task_repository_task_metadata_support(postgres_url): print("✅ GET by ID operation preserves task_metadata") # Test GET operation by name preserves task_metadata - retrieved_task_by_name = await task_repo.get(name="test-task-with-metadata") + retrieved_task_by_name = await task_repo.get(name=task_name) assert retrieved_task_by_name.id == created_task.id assert retrieved_task_by_name.task_metadata == task_metadata print("✅ GET by name operation preserves task_metadata") @@ -461,7 +470,7 @@ async def test_task_repository_task_metadata_support(postgres_url): } updated_task = TaskEntity( id=task_id, - name="test-task-with-metadata", + name=task_name, status=TaskStatus.COMPLETED, status_reason="Task completed with updated task_metadata", task_metadata=updated_metadata, @@ -484,7 +493,7 @@ async def test_task_repository_task_metadata_support(postgres_url): task_id_null = orm_id() task_null_metadata = TaskEntity( id=task_id_null, - name="test-task-null-metadata", + name=f"test-task-null-metadata-{unique_suffix}", status=TaskStatus.RUNNING, status_reason="Task with null task_metadata", task_metadata=None, @@ -540,10 +549,12 @@ async def test_task_repository_null_task_metadata_handling(postgres_url): agent_repo = AgentRepository(async_session_maker, async_session_maker) # First, create an agent (required for task creation) + # Use unique names to avoid collisions with other tests sharing the same session-scoped DB agent_id = orm_id() + unique_suffix = agent_id[:8] agent = AgentEntity( id=agent_id, - name="test-agent-null-metadata", + name=f"test-agent-null-metadata-{unique_suffix}", description="Test agent for null task_metadata testing", docker_image="test/agent:latest", status=AgentStatus.READY, @@ -556,9 +567,10 @@ async def test_task_repository_null_task_metadata_handling(postgres_url): # Test CREATE with task_metadata=None task_id_null = orm_id() + task_name = f"test-task-null-metadata-handling-{unique_suffix}" task_null = TaskEntity( id=task_id_null, - name="test-task-null-metadata-handling", + name=task_name, status=TaskStatus.RUNNING, status_reason="Task with null task_metadata", task_metadata=None, @@ -575,9 +587,7 @@ async def test_task_repository_null_task_metadata_handling(postgres_url): assert retrieved_null_task.task_metadata is None print("✅ Retrieval preserves null task_metadata") - retrieved_null_by_name = await task_repo.get( - name="test-task-null-metadata-handling" - ) + retrieved_null_by_name = await task_repo.get(name=task_name) assert retrieved_null_by_name.id == task_id_null assert retrieved_null_by_name.task_metadata is None print("✅ Retrieval by name preserves null task_metadata") @@ -594,7 +604,7 @@ async def test_task_repository_null_task_metadata_handling(postgres_url): } updated_task = TaskEntity( id=task_id_null, - name="test-task-null-metadata-handling", + name=task_name, status=TaskStatus.RUNNING, status_reason="Task updated with populated task_metadata", task_metadata=populated_metadata, @@ -613,7 +623,7 @@ async def test_task_repository_null_task_metadata_handling(postgres_url): # Test UPDATE from populated back to null task_metadata updated_back_to_null = TaskEntity( id=task_id_null, - name="test-task-null-metadata-handling", + name=task_name, status=TaskStatus.COMPLETED, status_reason="Task updated back to null task_metadata", task_metadata=None, @@ -686,10 +696,13 @@ async def test_list_with_join_includes_task_metadata(postgres_url): task_repo = TaskRepository(async_session_maker, async_session_maker) agent_repo = AgentRepository(async_session_maker, async_session_maker) + # Use unique names to avoid collisions with other tests sharing the same session-scoped DB + unique_suffix = orm_id()[:8] + # Create test agents agent_1 = AgentEntity( id=orm_id(), - name="agent-with-metadata-tasks", + name=f"agent-with-metadata-tasks-{unique_suffix}", description="Test agent for task metadata join testing", docker_image="test/agent:latest", status=AgentStatus.READY, @@ -700,7 +713,7 @@ async def test_list_with_join_includes_task_metadata(postgres_url): agent_2 = AgentEntity( id=orm_id(), - name="agent-with-null-metadata-tasks", + name=f"agent-with-null-metadata-tasks-{unique_suffix}", description="Test agent for null task metadata join testing", docker_image="test/agent:latest", status=AgentStatus.READY, @@ -712,7 +725,7 @@ async def test_list_with_join_includes_task_metadata(postgres_url): # Create tasks with task_metadata task_with_metadata_1 = TaskEntity( id=orm_id(), - name="task-with-metadata-1", + name=f"task-with-metadata-1-{unique_suffix}", status=TaskStatus.RUNNING, status_reason="Task with metadata for join testing", task_metadata={ @@ -726,7 +739,7 @@ async def test_list_with_join_includes_task_metadata(postgres_url): task_with_metadata_2 = TaskEntity( id=orm_id(), - name="task-with-metadata-2", + name=f"task-with-metadata-2-{unique_suffix}", status=TaskStatus.FAILED, status_reason="Another task with metadata", task_metadata={ @@ -741,7 +754,7 @@ async def test_list_with_join_includes_task_metadata(postgres_url): # Create tasks without task_metadata (null) task_without_metadata_1 = TaskEntity( id=orm_id(), - name="task-without-metadata-1", + name=f"task-without-metadata-1-{unique_suffix}", status=TaskStatus.RUNNING, status_reason="Task without metadata", task_metadata=None, @@ -750,7 +763,7 @@ async def test_list_with_join_includes_task_metadata(postgres_url): task_without_metadata_2 = TaskEntity( id=orm_id(), - name="task-without-metadata-2", + name=f"task-without-metadata-2-{unique_suffix}", status=TaskStatus.COMPLETED, status_reason="Another task without metadata", task_metadata=None, @@ -765,25 +778,25 @@ async def test_list_with_join_includes_task_metadata(postgres_url): tasks_by_name = {task.name: task for task in all_tasks} # Verify tasks with metadata - assert "task-with-metadata-1" in tasks_by_name - metadata_task_1 = tasks_by_name["task-with-metadata-1"] + assert task_with_metadata_1.name in tasks_by_name + metadata_task_1 = tasks_by_name[task_with_metadata_1.name] assert metadata_task_1.task_metadata is not None assert metadata_task_1.task_metadata["priority"] == "high" assert metadata_task_1.task_metadata["category"] == "testing" - assert "task-with-metadata-2" in tasks_by_name - metadata_task_2 = tasks_by_name["task-with-metadata-2"] + assert task_with_metadata_2.name in tasks_by_name + metadata_task_2 = tasks_by_name[task_with_metadata_2.name] assert metadata_task_2.task_metadata is not None assert metadata_task_2.task_metadata["priority"] == "medium" assert metadata_task_2.task_metadata["category"] == "integration" # Verify tasks without metadata (null) - assert "task-without-metadata-1" in tasks_by_name - null_task_1 = tasks_by_name["task-without-metadata-1"] + assert task_without_metadata_1.name in tasks_by_name + null_task_1 = tasks_by_name[task_without_metadata_1.name] assert null_task_1.task_metadata is None - assert "task-without-metadata-2" in tasks_by_name - null_task_2 = tasks_by_name["task-without-metadata-2"] + assert task_without_metadata_2.name in tasks_by_name + null_task_2 = tasks_by_name[task_without_metadata_2.name] assert null_task_2.task_metadata is None print("✅ list_with_join returns task_metadata for all tasks") @@ -801,7 +814,7 @@ async def test_list_with_join_includes_task_metadata(postgres_url): # Test filtering by agent_name agent_2_tasks = await task_repo.list_with_join( - agent_name="agent-with-null-metadata-tasks", order_direction="asc" + agent_name=agent_2.name, order_direction="asc" ) assert len(agent_2_tasks) == 2 for task in agent_2_tasks: @@ -828,9 +841,9 @@ async def test_list_with_join_includes_task_metadata(postgres_url): ) assert len(ordered_by_name) == 4 # Verify ordering is correct and task_metadata is preserved - assert ordered_by_name[0].name == "task-with-metadata-1" + assert ordered_by_name[0].name == task_with_metadata_1.name assert ordered_by_name[0].task_metadata is not None - assert ordered_by_name[3].name == "task-without-metadata-2" + assert ordered_by_name[3].name == task_without_metadata_2.name assert ordered_by_name[3].task_metadata is None print("✅ Ordering works correctly with task_metadata present") @@ -890,9 +903,12 @@ async def test_list_with_join(postgres_url): task_repo = TaskRepository(async_session_maker, async_session_maker) agent_repo = AgentRepository(async_session_maker, async_session_maker) + # Use unique names to avoid collisions with other tests sharing the same session-scoped DB + unique_suffix = orm_id()[:8] + agent_1 = AgentEntity( id=orm_id(), - name="agent-1", + name=f"agent-1-{unique_suffix}", description="Test agent for task repository testing", docker_image="test/agent:latest", status=AgentStatus.READY, @@ -903,7 +919,7 @@ async def test_list_with_join(postgres_url): agent_2 = AgentEntity( id=orm_id(), - name="agent-2", + name=f"agent-2-{unique_suffix}", description="Test agent for task repository testing", docker_image="test/agent:latest", status=AgentStatus.READY, @@ -914,7 +930,7 @@ async def test_list_with_join(postgres_url): task_1_1 = TaskEntity( id=orm_id(), - name="agent-1-task-1", + name=f"agent-1-task-1-{unique_suffix}", status=TaskStatus.RUNNING, status_reason="status reason b", ) @@ -922,7 +938,7 @@ async def test_list_with_join(postgres_url): task_1_2 = TaskEntity( id=orm_id(), - name="agent-1-task-2", + name=f"agent-1-task-2-{unique_suffix}", status=TaskStatus.FAILED, status_reason="status reason a", ) @@ -930,7 +946,7 @@ async def test_list_with_join(postgres_url): task_2_1 = TaskEntity( id=orm_id(), - name="agent-2-task-1", + name=f"agent-2-task-1-{unique_suffix}", status=TaskStatus.RUNNING, status_reason="status reason a", ) @@ -976,7 +992,7 @@ async def test_list_with_join(postgres_url): assert_task_lists_by_name( expected=[task_1_1, task_1_2], received=await task_repo.list_with_join( - agent_name="agent-1", + agent_name=agent_1.name, order_direction="asc", ), ) @@ -1086,7 +1102,7 @@ async def test_list_with_join(postgres_url): assert_task_lists_by_name( expected=[task_1_2], received=await task_repo.list_with_join( - agent_name="agent-1", + agent_name=agent_1.name, task_filters={"status": TaskStatus.FAILED}, order_direction="asc", ),