diff --git a/TablePro/Core/AI/AnthropicProvider.swift b/TablePro/Core/AI/AnthropicProvider.swift index d1f88c445..571f19510 100644 --- a/TablePro/Core/AI/AnthropicProvider.swift +++ b/TablePro/Core/AI/AnthropicProvider.swift @@ -43,74 +43,15 @@ final class AnthropicProvider: ChatTransport { ) } - var inputTokens = 0 - var outputTokens = 0 - var toolUseIdsByIndex: [Int: String] = [:] - + var state = AnthropicStreamState() for try await line in bytes.lines { if Task.isCancelled { break } - - guard line.hasPrefix("data: ") else { continue } - let jsonString = String(line.dropFirst(6)) - guard jsonString != "[DONE]", - let data = jsonString.data(using: .utf8), - let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any], - let type = json["type"] as? String - else { continue } - - switch type { - case "content_block_start": - if let index = json["index"] as? Int, - let block = json["content_block"] as? [String: Any], - (block["type"] as? String) == "tool_use", - let blockId = block["id"] as? String, - let blockName = block["name"] as? String { - toolUseIdsByIndex[index] = blockId - continuation.yield(.toolUseStart(id: blockId, name: blockName)) - } - case "content_block_delta": - guard let delta = json["delta"] as? [String: Any] else { break } - let deltaType = delta["type"] as? String - if deltaType == "input_json_delta" { - if let index = json["index"] as? Int, - let id = toolUseIdsByIndex[index], - let partial = delta["partial_json"] as? String { - continuation.yield(.toolUseDelta(id: id, inputJSONDelta: partial)) - } - } else if let text = delta["text"] as? String { - continuation.yield(.textDelta(text)) - } - case "content_block_stop": - if let index = json["index"] as? Int, - let id = toolUseIdsByIndex.removeValue(forKey: index) { - continuation.yield(.toolUseEnd(id: id)) - } - case "message_start": - if let message = json["message"] as? [String: Any], - let usage = message["usage"] as? [String: Any], - let tokens = usage["input_tokens"] as? Int { - inputTokens = tokens - } - case "message_delta": - if let usage = json["usage"] as? [String: Any], - let tokens = usage["output_tokens"] as? Int { - outputTokens = tokens - } - case "error": - if let errorObj = json["error"] as? [String: Any], - let message = errorObj["message"] as? String { - throw AIProviderError.streamingFailed(message) - } - default: - break - } + guard let json = Self.decodeStreamLine(line) else { continue } + let events = try Self.parseChunk(json, state: &state) + for event in events { continuation.yield(event) } } - - if inputTokens > 0 || outputTokens > 0 { - continuation.yield(.usage(AITokenUsage( - inputTokens: inputTokens, - outputTokens: outputTokens - ))) + if let usage = state.finalUsageEvent() { + continuation.yield(usage) } continuation.finish() @@ -220,6 +161,81 @@ final class AnthropicProvider: ChatTransport { return request } + /// Decodes one SSE line of the form `data: {...}` to a JSON object. + /// Returns `nil` for non-data lines, the `[DONE]` sentinel, and unparsable + /// payloads. Keeping this separate from `parseChunk` lets tests skip the + /// SSE framing and feed JSON dictionaries directly. + static func decodeStreamLine(_ line: String) -> [String: Any]? { + guard line.hasPrefix("data: ") else { return nil } + let jsonString = String(line.dropFirst(6)) + guard jsonString != "[DONE]", + let data = jsonString.data(using: .utf8), + let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any] + else { return nil } + return json + } + + /// Translate a single Anthropic SSE event JSON into zero or more + /// `ChatStreamEvent`s. Mutates `state` to carry index→id mappings and + /// token counters across calls. Throws `AIProviderError.streamingFailed` + /// on `error` events. + static func parseChunk( + _ json: [String: Any], + state: inout AnthropicStreamState + ) throws -> [ChatStreamEvent] { + guard let type = json["type"] as? String else { return [] } + switch type { + case "content_block_start": + guard let index = json["index"] as? Int, + let block = json["content_block"] as? [String: Any], + (block["type"] as? String) == "tool_use", + let blockId = block["id"] as? String, + let blockName = block["name"] as? String + else { return [] } + state.toolUseIdsByIndex[index] = blockId + return [.toolUseStart(id: blockId, name: blockName)] + case "content_block_delta": + guard let delta = json["delta"] as? [String: Any] else { return [] } + if (delta["type"] as? String) == "input_json_delta" { + guard let index = json["index"] as? Int, + let id = state.toolUseIdsByIndex[index], + let partial = delta["partial_json"] as? String + else { return [] } + return [.toolUseDelta(id: id, inputJSONDelta: partial)] + } + if let text = delta["text"] as? String { + return [.textDelta(text)] + } + return [] + case "content_block_stop": + guard let index = json["index"] as? Int, + let id = state.toolUseIdsByIndex.removeValue(forKey: index) + else { return [] } + return [.toolUseEnd(id: id)] + case "message_start": + if let message = json["message"] as? [String: Any], + let usage = message["usage"] as? [String: Any], + let tokens = usage["input_tokens"] as? Int { + state.inputTokens = tokens + } + return [] + case "message_delta": + if let usage = json["usage"] as? [String: Any], + let tokens = usage["output_tokens"] as? Int { + state.outputTokens = tokens + } + return [] + case "error": + if let errorObj = json["error"] as? [String: Any], + let message = errorObj["message"] as? String { + throw AIProviderError.streamingFailed(message) + } + return [] + default: + return [] + } + } + static func encodeToolSpec(_ spec: ChatToolSpec) throws -> [String: Any] { [ "name": spec.name, @@ -282,3 +298,15 @@ final class AnthropicProvider: ChatTransport { return try JSONSerialization.jsonObject(with: data, options: [.fragmentsAllowed]) } } + +/// Mutable state carried across `AnthropicProvider.parseChunk` calls. +struct AnthropicStreamState { + var inputTokens: Int = 0 + var outputTokens: Int = 0 + var toolUseIdsByIndex: [Int: String] = [:] + + func finalUsageEvent() -> ChatStreamEvent? { + guard inputTokens > 0 || outputTokens > 0 else { return nil } + return .usage(AITokenUsage(inputTokens: inputTokens, outputTokens: outputTokens)) + } +} diff --git a/TablePro/Core/AI/GeminiProvider.swift b/TablePro/Core/AI/GeminiProvider.swift index 04eca318a..ebe7f3980 100644 --- a/TablePro/Core/AI/GeminiProvider.swift +++ b/TablePro/Core/AI/GeminiProvider.swift @@ -43,54 +43,19 @@ final class GeminiProvider: ChatTransport { ) } - var inputTokens = 0 - var outputTokens = 0 - + var state = GeminiStreamState() for try await line in bytes.lines { if Task.isCancelled { break } - - guard line.hasPrefix("data: ") else { continue } - let jsonString = String(line.dropFirst(6)) - - guard let data = jsonString.data(using: .utf8), - let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any] - else { continue } - - if let candidates = json["candidates"] as? [[String: Any]], - let firstCandidate = candidates.first, - let content = firstCandidate["content"] as? [String: Any], - let parts = content["parts"] as? [[String: Any]] { - for part in parts { - if let text = part["text"] as? String, !text.isEmpty { - continuation.yield(.textDelta(text)) - } - if let functionCall = part["functionCall"] as? [String: Any], - let name = functionCall["name"] as? String { - let id = UUID().uuidString - let argsObject = functionCall["args"] ?? [String: Any]() - let argsString = encodeArgsToJSONString(argsObject) - continuation.yield(.toolUseStart(id: id, name: name)) - continuation.yield(.toolUseDelta(id: id, inputJSONDelta: argsString)) - continuation.yield(.toolUseEnd(id: id)) - } - } - } - - if let usageMetadata = json["usageMetadata"] as? [String: Any] { - if let prompt = usageMetadata["promptTokenCount"] as? Int { - inputTokens = prompt - } - if let candidates = usageMetadata["candidatesTokenCount"] as? Int { - outputTokens = candidates - } - } + guard let json = Self.decodeStreamLine(line) else { continue } + let events = Self.parseChunk( + json, + state: &state, + idGenerator: { UUID().uuidString } + ) + for event in events { continuation.yield(event) } } - - if inputTokens > 0 || outputTokens > 0 { - continuation.yield(.usage(AITokenUsage( - inputTokens: inputTokens, - outputTokens: outputTokens - ))) + if let usage = state.finalUsageEvent() { + continuation.yield(usage) } continuation.finish() @@ -317,7 +282,69 @@ final class GeminiProvider: ChatTransport { } } - private func encodeArgsToJSONString(_ args: Any) -> String { + + func mapHTTPError(statusCode: Int, body: String) -> AIProviderError { + if statusCode == 403 { + let message = AIProviderError.parseErrorMessage(from: body) ?? body + return .authenticationFailed(message) + } + return AIProviderError.mapHTTPError(statusCode: statusCode, body: body) + } + + /// Decodes one Gemini SSE line. Returns nil for non-data lines. + static func decodeStreamLine(_ line: String) -> [String: Any]? { + guard line.hasPrefix("data: ") else { return nil } + let jsonString = String(line.dropFirst(6)) + guard let data = jsonString.data(using: .utf8), + let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any] + else { return nil } + return json + } + + /// Translate a single Gemini chunk to events. + /// + /// Gemini does not provide tool-call ids on `functionCall` parts, so we + /// synthesize one per call. `idGenerator` is injected so tests can pin the + /// synthetic id to a stable value; production passes `{ UUID().uuidString }`. + /// Each call to `idGenerator()` returns a fresh id, so multiple + /// `functionCall` parts in one chunk get distinct ids in production. + static func parseChunk( + _ json: [String: Any], + state: inout GeminiStreamState, + idGenerator: () -> String + ) -> [ChatStreamEvent] { + var events: [ChatStreamEvent] = [] + if let candidates = json["candidates"] as? [[String: Any]], + let firstCandidate = candidates.first, + let content = firstCandidate["content"] as? [String: Any], + let parts = content["parts"] as? [[String: Any]] { + for part in parts { + if let text = part["text"] as? String, !text.isEmpty { + events.append(.textDelta(text)) + } + if let functionCall = part["functionCall"] as? [String: Any], + let name = functionCall["name"] as? String { + let id = idGenerator() + let argsObject = functionCall["args"] ?? [String: Any]() + let argsString = encodeArgsToJSONString(argsObject) + events.append(.toolUseStart(id: id, name: name)) + events.append(.toolUseDelta(id: id, inputJSONDelta: argsString)) + events.append(.toolUseEnd(id: id)) + } + } + } + if let usageMetadata = json["usageMetadata"] as? [String: Any] { + if let prompt = usageMetadata["promptTokenCount"] as? Int { + state.inputTokens = prompt + } + if let candidates = usageMetadata["candidatesTokenCount"] as? Int { + state.outputTokens = candidates + } + } + return events + } + + static func encodeArgsToJSONString(_ args: Any) -> String { guard JSONSerialization.isValidJSONObject(args) else { Self.logger.warning("Gemini functionCall args was not a valid JSON object; falling back to empty input") return "{}" @@ -330,12 +357,15 @@ final class GeminiProvider: ChatTransport { return "{}" } } +} - func mapHTTPError(statusCode: Int, body: String) -> AIProviderError { - if statusCode == 403 { - let message = AIProviderError.parseErrorMessage(from: body) ?? body - return .authenticationFailed(message) - } - return AIProviderError.mapHTTPError(statusCode: statusCode, body: body) +/// Mutable state carried across `GeminiProvider.parseChunk` calls. +struct GeminiStreamState { + var inputTokens: Int = 0 + var outputTokens: Int = 0 + + func finalUsageEvent() -> ChatStreamEvent? { + guard inputTokens > 0 || outputTokens > 0 else { return nil } + return .usage(AITokenUsage(inputTokens: inputTokens, outputTokens: outputTokens)) } } diff --git a/TablePro/Core/AI/OpenAICompatibleProvider.swift b/TablePro/Core/AI/OpenAICompatibleProvider.swift index f5611b921..e2389d173 100644 --- a/TablePro/Core/AI/OpenAICompatibleProvider.swift +++ b/TablePro/Core/AI/OpenAICompatibleProvider.swift @@ -60,98 +60,19 @@ final class OpenAICompatibleProvider: ChatTransport { ) } - var inputTokens = 0 - var outputTokens = 0 - var toolCallIndexToId: [Int: String] = [:] - var toolCallOrder: [Int] = [] - + var state = OpenAIStreamState() for try await line in bytes.lines { if Task.isCancelled { break } - - let jsonString: String - if self.providerType == .ollama { - guard !line.isEmpty else { continue } - jsonString = line - } else { - guard line.hasPrefix("data: ") else { continue } - let payload = String(line.dropFirst(6)) - guard payload != "[DONE]" else { break } - jsonString = payload - } - - guard let data = jsonString.data(using: .utf8), - let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any] - else { continue } - - let choices = json["choices"] as? [[String: Any]] - let firstChoice = choices?.first - let delta = firstChoice?["delta"] as? [String: Any] - - if let delta, let content = delta["content"] as? String, !content.isEmpty { - continuation.yield(.textDelta(content)) - } else if let message = json["message"] as? [String: Any], - let content = message["content"] as? String, - !content.isEmpty { - continuation.yield(.textDelta(content)) - } - - if let delta, let toolCalls = delta["tool_calls"] as? [[String: Any]] { - handleToolCallDeltas( - toolCalls, - indexToId: &toolCallIndexToId, - order: &toolCallOrder, - continuation: continuation - ) - } else if let message = json["message"] as? [String: Any], - let toolCalls = message["tool_calls"] as? [[String: Any]] { - handleOllamaToolCalls( - toolCalls, - indexToId: &toolCallIndexToId, - order: &toolCallOrder, - continuation: continuation - ) - } - - if let finishReason = firstChoice?["finish_reason"] as? String, - finishReason == "tool_calls" { - for index in toolCallOrder { - if let id = toolCallIndexToId[index] { - continuation.yield(.toolUseEnd(id: id)) - } - } - toolCallIndexToId.removeAll() - toolCallOrder.removeAll() - } - - if let usage = json["usage"] as? [String: Any], - let promptTokens = usage["prompt_tokens"] as? Int, - let completionTokens = usage["completion_tokens"] as? Int { - inputTokens = promptTokens - outputTokens = completionTokens - } else if let done = json["done"] as? Bool, done, - let promptEval = json["prompt_eval_count"] as? Int, - let evalCount = json["eval_count"] as? Int { - inputTokens = promptEval - outputTokens = evalCount - } - - if json["done"] as? Bool == true { - for index in toolCallOrder { - if let id = toolCallIndexToId[index] { - continuation.yield(.toolUseEnd(id: id)) - } - } - toolCallIndexToId.removeAll() - toolCallOrder.removeAll() - break + guard let json = Self.decodeStreamLine(line, providerType: self.providerType) else { + if line == "data: [DONE]" { break } + continue } + let result = Self.parseChunk(json, state: &state) + for event in result.events { continuation.yield(event) } + if result.shouldBreak { break } } - - if inputTokens > 0 || outputTokens > 0 { - continuation.yield(.usage(AITokenUsage( - inputTokens: inputTokens, - outputTokens: outputTokens - ))) + if let usage = state.finalUsageEvent() { + continuation.yield(usage) } continuation.finish() @@ -166,53 +87,122 @@ final class OpenAICompatibleProvider: ChatTransport { } } - private func handleToolCallDeltas( + /// Decodes one streaming line. OpenAI/OpenRouter/Custom use SSE framing + /// (`data: {...}`); Ollama emits NDJSON (one JSON object per line). The + /// `[DONE]` sentinel returns nil; the caller should break on it. + static func decodeStreamLine(_ line: String, providerType: AIProviderType) -> [String: Any]? { + let jsonString: String + if providerType == .ollama { + guard !line.isEmpty else { return nil } + jsonString = line + } else { + guard line.hasPrefix("data: ") else { return nil } + let payload = String(line.dropFirst(6)) + guard payload != "[DONE]" else { return nil } + jsonString = payload + } + guard let data = jsonString.data(using: .utf8), + let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any] + else { return nil } + return json + } + + /// Translate one chunk JSON to events. Mutates state to thread tool-call + /// index→id mapping, ordering, and token counters across chunks. + /// Returns `(events, shouldBreak)` so the caller can stop the stream when + /// Ollama emits `done: true`. + static func parseChunk( + _ json: [String: Any], + state: inout OpenAIStreamState + ) -> (events: [ChatStreamEvent], shouldBreak: Bool) { + var events: [ChatStreamEvent] = [] + let choices = json["choices"] as? [[String: Any]] + let firstChoice = choices?.first + let delta = firstChoice?["delta"] as? [String: Any] + + if let delta, let content = delta["content"] as? String, !content.isEmpty { + events.append(.textDelta(content)) + } else if let message = json["message"] as? [String: Any], + let content = message["content"] as? String, + !content.isEmpty { + events.append(.textDelta(content)) + } + + if let delta, let toolCalls = delta["tool_calls"] as? [[String: Any]] { + events.append(contentsOf: handleToolCallDeltas(toolCalls, state: &state)) + } else if let message = json["message"] as? [String: Any], + let toolCalls = message["tool_calls"] as? [[String: Any]] { + events.append(contentsOf: handleOllamaToolCalls(toolCalls, state: &state)) + } + + if let finishReason = firstChoice?["finish_reason"] as? String, + finishReason == "tool_calls" { + events.append(contentsOf: state.flushToolUseEnds()) + } + + if let usage = json["usage"] as? [String: Any], + let promptTokens = usage["prompt_tokens"] as? Int, + let completionTokens = usage["completion_tokens"] as? Int { + state.inputTokens = promptTokens + state.outputTokens = completionTokens + } else if let done = json["done"] as? Bool, done, + let promptEval = json["prompt_eval_count"] as? Int, + let evalCount = json["eval_count"] as? Int { + state.inputTokens = promptEval + state.outputTokens = evalCount + } + + // Ollama signals stream-end via `done: true`. We flush again here only + // when finish_reason didn't already drain the tool-call map (which + // typically isn't set on Ollama responses). + let shouldBreak = (json["done"] as? Bool) == true + if shouldBreak, !state.toolCallIndexToId.isEmpty { + events.append(contentsOf: state.flushToolUseEnds()) + } + return (events, shouldBreak) + } + + private static func handleToolCallDeltas( _ toolCalls: [[String: Any]], - indexToId: inout [Int: String], - order: inout [Int], - continuation: AsyncThrowingStream.Continuation - ) { + state: inout OpenAIStreamState + ) -> [ChatStreamEvent] { + var events: [ChatStreamEvent] = [] for toolCall in toolCalls { guard let index = toolCall["index"] as? Int else { continue } let function = toolCall["function"] as? [String: Any] - - if indexToId[index] == nil { + if state.toolCallIndexToId[index] == nil { let id = (toolCall["id"] as? String) ?? "call_\(index)_\(UUID().uuidString.prefix(8))" let name = (function?["name"] as? String) ?? "" - indexToId[index] = id - order.append(index) - continuation.yield(.toolUseStart(id: id, name: name)) + state.toolCallIndexToId[index] = id + state.toolCallOrder.append(index) + events.append(.toolUseStart(id: id, name: name)) } - - if let id = indexToId[index], + if let id = state.toolCallIndexToId[index], let arguments = function?["arguments"] as? String, !arguments.isEmpty { - continuation.yield(.toolUseDelta(id: id, inputJSONDelta: arguments)) + events.append(.toolUseDelta(id: id, inputJSONDelta: arguments)) } } + return events } - private func handleOllamaToolCalls( + private static func handleOllamaToolCalls( _ toolCalls: [[String: Any]], - indexToId: inout [Int: String], - order: inout [Int], - continuation: AsyncThrowingStream.Continuation - ) { + state: inout OpenAIStreamState + ) -> [ChatStreamEvent] { + var events: [ChatStreamEvent] = [] for (offset, toolCall) in toolCalls.enumerated() { guard let function = toolCall["function"] as? [String: Any], let name = function["name"] as? String else { continue } - let index = (toolCall["index"] as? Int) ?? offset let id = (toolCall["id"] as? String) ?? "call_\(index)_\(UUID().uuidString.prefix(8))" - - if indexToId[index] == nil { - indexToId[index] = id - order.append(index) - continuation.yield(.toolUseStart(id: id, name: name)) + if state.toolCallIndexToId[index] == nil { + state.toolCallIndexToId[index] = id + state.toolCallOrder.append(index) + events.append(.toolUseStart(id: id, name: name)) } - let argumentsString: String if let stringArgs = function["arguments"] as? String { argumentsString = stringArgs @@ -223,11 +213,11 @@ final class OpenAICompatibleProvider: ChatTransport { } else { argumentsString = "" } - - if !argumentsString.isEmpty, let resolvedId = indexToId[index] { - continuation.yield(.toolUseDelta(id: resolvedId, inputJSONDelta: argumentsString)) + if !argumentsString.isEmpty, let resolvedId = state.toolCallIndexToId[index] { + events.append(.toolUseDelta(id: resolvedId, inputJSONDelta: argumentsString)) } } + return events } func fetchAvailableModels() async throws -> [String] { @@ -500,3 +490,29 @@ final class OpenAICompatibleProvider: ChatTransport { return models.compactMap { $0["name"] as? String }.sorted() } } + +/// Mutable state carried across `OpenAICompatibleProvider.parseChunk` calls. +struct OpenAIStreamState { + var inputTokens: Int = 0 + var outputTokens: Int = 0 + var toolCallIndexToId: [Int: String] = [:] + var toolCallOrder: [Int] = [] + + /// Yield `.toolUseEnd` for every tracked tool call and clear the map. + /// Called when the provider signals tool-call completion (`finish_reason` + /// or Ollama `done: true`). + mutating func flushToolUseEnds() -> [ChatStreamEvent] { + let events: [ChatStreamEvent] = toolCallOrder.compactMap { index in + guard let id = toolCallIndexToId[index] else { return nil } + return .toolUseEnd(id: id) + } + toolCallIndexToId.removeAll() + toolCallOrder.removeAll() + return events + } + + func finalUsageEvent() -> ChatStreamEvent? { + guard inputTokens > 0 || outputTokens > 0 else { return nil } + return .usage(AITokenUsage(inputTokens: inputTokens, outputTokens: outputTokens)) + } +} diff --git a/TableProTests/Core/AI/AnthropicProviderParserTests.swift b/TableProTests/Core/AI/AnthropicProviderParserTests.swift new file mode 100644 index 000000000..061054640 --- /dev/null +++ b/TableProTests/Core/AI/AnthropicProviderParserTests.swift @@ -0,0 +1,203 @@ +// +// AnthropicProviderParserTests.swift +// TableProTests +// + +import Foundation +@testable import TablePro +import Testing + +@Suite("AnthropicProvider stream parser") +struct AnthropicProviderParserTests { + private func parse(_ json: [String: Any], state: inout AnthropicStreamState) throws -> [ChatStreamEvent] { + try AnthropicProvider.parseChunk(json, state: &state) + } + + @Test("text_delta yields textDelta") + func textDelta() throws { + var state = AnthropicStreamState() + let events = try parse([ + "type": "content_block_delta", + "delta": ["type": "text_delta", "text": "hello"] + ], state: &state) + guard case .textDelta(let text) = events.first else { + Issue.record("expected textDelta; got \(events)") + return + } + #expect(text == "hello") + } + + @Test("content_block_start with tool_use yields toolUseStart and remembers index→id") + func toolUseStart() throws { + var state = AnthropicStreamState() + let events = try parse([ + "type": "content_block_start", + "index": 1, + "content_block": [ + "type": "tool_use", + "id": "toolu_abc", + "name": "list_tables" + ] + ], state: &state) + #expect(events.count == 1) + if case .toolUseStart(let id, let name) = events.first { + #expect(id == "toolu_abc") + #expect(name == "list_tables") + } else { + Issue.record("expected toolUseStart; got \(events)") + } + #expect(state.toolUseIdsByIndex[1] == "toolu_abc") + } + + @Test("input_json_delta resolves index back to id from state") + func inputJSONDelta() throws { + var state = AnthropicStreamState() + state.toolUseIdsByIndex[1] = "toolu_abc" + let events = try parse([ + "type": "content_block_delta", + "index": 1, + "delta": ["type": "input_json_delta", "partial_json": #"{"foo":"#] + ], state: &state) + if case .toolUseDelta(let id, let delta) = events.first { + #expect(id == "toolu_abc") + #expect(delta == #"{"foo":"#) + } else { + Issue.record("expected toolUseDelta; got \(events)") + } + } + + @Test("input_json_delta for unknown index yields nothing") + func inputJSONDeltaUnknownIndex() throws { + var state = AnthropicStreamState() + let events = try parse([ + "type": "content_block_delta", + "index": 99, + "delta": ["type": "input_json_delta", "partial_json": "x"] + ], state: &state) + #expect(events.isEmpty) + } + + @Test("content_block_stop yields toolUseEnd and clears the index mapping") + func toolUseEnd() throws { + var state = AnthropicStreamState() + state.toolUseIdsByIndex[1] = "toolu_abc" + let events = try parse([ + "type": "content_block_stop", + "index": 1 + ], state: &state) + if case .toolUseEnd(let id) = events.first { + #expect(id == "toolu_abc") + } else { + Issue.record("expected toolUseEnd; got \(events)") + } + #expect(state.toolUseIdsByIndex[1] == nil) + } + + @Test("Fragmented input_json_delta concatenates correctly via state") + func fragmentedDelta() throws { + var state = AnthropicStreamState() + _ = try parse([ + "type": "content_block_start", + "index": 0, + "content_block": ["type": "tool_use", "id": "tid", "name": "list_tables"] + ], state: &state) + let chunk1 = try parse([ + "type": "content_block_delta", + "index": 0, + "delta": ["type": "input_json_delta", "partial_json": #"{"con"#] + ], state: &state) + let chunk2 = try parse([ + "type": "content_block_delta", + "index": 0, + "delta": ["type": "input_json_delta", "partial_json": #"nection_id":"x"}"#] + ], state: &state) + let stop = try parse([ + "type": "content_block_stop", + "index": 0 + ], state: &state) + let combined = chunk1 + chunk2 + stop + // textDelta accumulator on the consumer side reassembles fragments. + let deltas = combined.compactMap { event -> String? in + if case .toolUseDelta(_, let d) = event { return d } + return nil + } + #expect(deltas.joined() == #"{"connection_id":"x"}"#) + } + + @Test("message_start tracks input tokens") + func messageStart() throws { + var state = AnthropicStreamState() + _ = try parse([ + "type": "message_start", + "message": ["usage": ["input_tokens": 42]] + ], state: &state) + #expect(state.inputTokens == 42) + } + + @Test("message_delta tracks output tokens") + func messageDelta() throws { + var state = AnthropicStreamState() + _ = try parse([ + "type": "message_delta", + "usage": ["output_tokens": 100] + ], state: &state) + #expect(state.outputTokens == 100) + } + + @Test("finalUsageEvent emits .usage when tokens were observed") + func finalUsage() throws { + var state = AnthropicStreamState() + state.inputTokens = 42 + state.outputTokens = 100 + guard case .usage(let usage) = state.finalUsageEvent() else { + Issue.record("expected usage event") + return + } + #expect(usage.inputTokens == 42) + #expect(usage.outputTokens == 100) + } + + @Test("finalUsageEvent returns nil when no tokens observed") + func noUsageNoEvent() { + let state = AnthropicStreamState() + #expect(state.finalUsageEvent() == nil) + } + + @Test("error event throws streamingFailed") + func errorEvent() { + var state = AnthropicStreamState() + #expect(throws: AIProviderError.self) { + _ = try AnthropicProvider.parseChunk( + ["type": "error", "error": ["message": "rate limited"]], + state: &state + ) + } + } + + @Test("decodeStreamLine returns nil for non-data lines and [DONE]") + func framingDecode() { + #expect(AnthropicProvider.decodeStreamLine("event: message_stop") == nil) + #expect(AnthropicProvider.decodeStreamLine("data: [DONE]") == nil) + #expect(AnthropicProvider.decodeStreamLine("data: {\"type\":\"x\"}") != nil) + } + + @Test("Unknown event types yield no events and don't throw") + func unknownEventType() throws { + var state = AnthropicStreamState() + let events = try AnthropicProvider.parseChunk( + ["type": "ping", "extra": "ignored"], + state: &state + ) + #expect(events.isEmpty) + // State should be unchanged. + #expect(state.toolUseIdsByIndex.isEmpty) + #expect(state.inputTokens == 0) + } + + @Test("Chunk with no type field yields no events and doesn't throw") + func chunkWithoutType() throws { + var state = AnthropicStreamState() + let events = try AnthropicProvider.parseChunk(["random": "data"], state: &state) + #expect(events.isEmpty) + } +} diff --git a/TableProTests/Core/AI/GeminiProviderParserTests.swift b/TableProTests/Core/AI/GeminiProviderParserTests.swift new file mode 100644 index 000000000..db3435d36 --- /dev/null +++ b/TableProTests/Core/AI/GeminiProviderParserTests.swift @@ -0,0 +1,163 @@ +// +// GeminiProviderParserTests.swift +// TableProTests +// + +import Foundation +@testable import TablePro +import Testing + +@Suite("GeminiProvider stream parser") +struct GeminiProviderParserTests { + private let stableID = "stable-id" + + private func parse(_ json: [String: Any], state: inout GeminiStreamState) -> [ChatStreamEvent] { + GeminiProvider.parseChunk(json, state: &state, idGenerator: { self.stableID }) + } + + @Test("Text part yields textDelta") + func textPart() { + var state = GeminiStreamState() + let events = parse([ + "candidates": [[ + "content": [ + "parts": [["text": "hello"]] + ] + ]] + ], state: &state) + guard case .textDelta(let text) = events.first else { + Issue.record("expected textDelta") + return + } + #expect(text == "hello") + } + + @Test("functionCall part yields the start/delta/end trio in one chunk") + func functionCallTrio() { + var state = GeminiStreamState() + let events = parse([ + "candidates": [[ + "content": [ + "parts": [[ + "functionCall": [ + "name": "list_tables", + "args": ["connection_id": "abc"] + ] + ]] + ] + ]] + ], state: &state) + #expect(events.count == 3) + if case .toolUseStart(let id, let name) = events[0] { + #expect(id == stableID) + #expect(name == "list_tables") + } else { + Issue.record("expected toolUseStart at index 0") + } + if case .toolUseDelta(let id, let delta) = events[1] { + #expect(id == stableID) + #expect(delta.contains("connection_id")) + #expect(delta.contains("abc")) + } else { + Issue.record("expected toolUseDelta at index 1") + } + if case .toolUseEnd(let id) = events[2] { + #expect(id == stableID) + } else { + Issue.record("expected toolUseEnd at index 2") + } + } + + @Test("Mixed text + functionCall parts yield events in part order") + func mixedParts() { + var state = GeminiStreamState() + let events = parse([ + "candidates": [[ + "content": [ + "parts": [ + ["text": "I'll check"], + ["functionCall": ["name": "list_tables", "args": [String: Any]()]] + ] + ] + ]] + ], state: &state) + // Order: textDelta, toolUseStart, toolUseDelta, toolUseEnd + #expect(events.count == 4) + guard case .textDelta = events[0] else { + Issue.record("expected textDelta first") + return + } + guard case .toolUseStart = events[1] else { + Issue.record("expected toolUseStart second") + return + } + } + + @Test("Empty parts array yields no events") + func emptyParts() { + var state = GeminiStreamState() + let events = parse([ + "candidates": [[ + "content": ["parts": [[String: Any]]()] + ]] + ], state: &state) + #expect(events.isEmpty) + } + + @Test("usageMetadata populates state token counters") + func usageTokens() { + var state = GeminiStreamState() + _ = parse([ + "usageMetadata": [ + "promptTokenCount": 100, + "candidatesTokenCount": 50 + ] + ], state: &state) + #expect(state.inputTokens == 100) + #expect(state.outputTokens == 50) + } + + @Test("encodeArgsToJSONString returns {} on invalid input") + func argsFallback() { + let invalid: Any = NSObject() // not JSON-serializable + #expect(GeminiProvider.encodeArgsToJSONString(invalid) == "{}") + } + + @Test("encodeArgsToJSONString round-trips a valid object") + func argsRoundTrip() { + let result = GeminiProvider.encodeArgsToJSONString(["a": 1, "b": "x"]) + #expect(result.contains("\"a\"")) + #expect(result.contains("\"b\"")) + } + + @Test("Chunk without candidates yields no events") + func chunkWithoutCandidates() { + var state = GeminiStreamState() + let events = parse(["unrelated": "data"], state: &state) + #expect(events.isEmpty) + } + + @Test("Multiple functionCall parts in one chunk get distinct ids") + func multipleFunctionCallsGetDistinctIds() { + var state = GeminiStreamState() + var counter = 0 + let events = GeminiProvider.parseChunk([ + "candidates": [[ + "content": [ + "parts": [ + ["functionCall": ["name": "list_tables", "args": [String: Any]()]], + ["functionCall": ["name": "describe_table", "args": [String: Any]()]] + ] + ] + ]] + ], state: &state, idGenerator: { + defer { counter += 1 } + return "id-\(counter)" + }) + let starts = events.compactMap { event -> String? in + if case .toolUseStart(let id, _) = event { return id } + return nil + } + #expect(starts == ["id-0", "id-1"]) + } +} diff --git a/TableProTests/Core/AI/OpenAICompatibleProviderParserTests.swift b/TableProTests/Core/AI/OpenAICompatibleProviderParserTests.swift new file mode 100644 index 000000000..614cd6bab --- /dev/null +++ b/TableProTests/Core/AI/OpenAICompatibleProviderParserTests.swift @@ -0,0 +1,210 @@ +// +// OpenAICompatibleProviderParserTests.swift +// TableProTests +// + +import Foundation +@testable import TablePro +import Testing + +@Suite("OpenAICompatibleProvider stream parser") +struct OpenAICompatibleProviderParserTests { + @Test("delta.content yields textDelta") + func textDelta() { + var state = OpenAIStreamState() + let result = OpenAICompatibleProvider.parseChunk([ + "choices": [[ + "delta": ["content": "hello"] + ]] + ], state: &state) + #expect(result.shouldBreak == false) + guard case .textDelta(let text) = result.events.first else { + Issue.record("expected textDelta; got \(result.events)") + return + } + #expect(text == "hello") + } + + @Test("First tool_calls chunk emits toolUseStart with id and name") + func toolUseStart() { + var state = OpenAIStreamState() + let result = OpenAICompatibleProvider.parseChunk([ + "choices": [[ + "delta": [ + "tool_calls": [[ + "index": 0, + "id": "call_abc", + "type": "function", + "function": ["name": "list_tables", "arguments": ""] + ]] + ] + ]] + ], state: &state) + #expect(result.events.count == 1) + if case .toolUseStart(let id, let name) = result.events.first { + #expect(id == "call_abc") + #expect(name == "list_tables") + } else { + Issue.record("expected toolUseStart; got \(result.events)") + } + #expect(state.toolCallIndexToId[0] == "call_abc") + } + + @Test("Subsequent tool_calls chunks emit toolUseDelta only") + func toolUseDelta() { + var state = OpenAIStreamState() + state.toolCallIndexToId[0] = "call_abc" + state.toolCallOrder = [0] + let result = OpenAICompatibleProvider.parseChunk([ + "choices": [[ + "delta": [ + "tool_calls": [[ + "index": 0, + "function": ["arguments": #"{"foo":"#] + ]] + ] + ]] + ], state: &state) + #expect(result.events.count == 1) + if case .toolUseDelta(let id, let delta) = result.events.first { + #expect(id == "call_abc") + #expect(delta == #"{"foo":"#) + } else { + Issue.record("expected toolUseDelta; got \(result.events)") + } + } + + @Test("finish_reason: tool_calls flushes toolUseEnds for all tracked calls") + func finishReasonTriggersFlush() { + var state = OpenAIStreamState() + state.toolCallIndexToId = [0: "call_a", 1: "call_b"] + state.toolCallOrder = [0, 1] + let result = OpenAICompatibleProvider.parseChunk([ + "choices": [["finish_reason": "tool_calls"]] + ], state: &state) + let endIds = result.events.compactMap { event -> String? in + if case .toolUseEnd(let id) = event { return id } + return nil + } + #expect(endIds == ["call_a", "call_b"]) + #expect(state.toolCallIndexToId.isEmpty) + #expect(state.toolCallOrder.isEmpty) + } + + @Test("Ollama message.tool_calls with arguments-as-object encodes to JSON string") + func ollamaArgumentsAsObject() { + var state = OpenAIStreamState() + let result = OpenAICompatibleProvider.parseChunk([ + "message": [ + "tool_calls": [[ + "function": [ + "name": "list_tables", + "arguments": ["connection_id": "abc"] // object, not string + ] + ]] + ] + ], state: &state) + let deltaPayload = result.events.compactMap { event -> String? in + if case .toolUseDelta(_, let s) = event { return s } + return nil + }.first + #expect(deltaPayload?.contains("connection_id") == true) + #expect(deltaPayload?.contains("abc") == true) + } + + @Test("Ollama message.tool_calls with arguments-as-string passes through verbatim") + func ollamaArgumentsAsString() { + var state = OpenAIStreamState() + let result = OpenAICompatibleProvider.parseChunk([ + "message": [ + "tool_calls": [[ + "function": [ + "name": "list_tables", + "arguments": #"{"connection_id":"abc"}"# + ] + ]] + ] + ], state: &state) + let delta = result.events.compactMap { event -> String? in + if case .toolUseDelta(_, let s) = event { return s } + return nil + }.first + #expect(delta == #"{"connection_id":"abc"}"#) + } + + @Test("Ollama done: true sets shouldBreak and flushes pending tool ends") + func ollamaDoneFlushesAndBreaks() { + var state = OpenAIStreamState() + state.toolCallIndexToId[0] = "call_a" + state.toolCallOrder = [0] + let result = OpenAICompatibleProvider.parseChunk([ + "done": true, + "prompt_eval_count": 50, + "eval_count": 200 + ], state: &state) + #expect(result.shouldBreak == true) + #expect(result.events.contains(where: { event in + if case .toolUseEnd(let id) = event { return id == "call_a" } + return false + })) + #expect(state.inputTokens == 50) + #expect(state.outputTokens == 200) + } + + @Test("usage object populates state token counters") + func usageTokens() { + var state = OpenAIStreamState() + _ = OpenAICompatibleProvider.parseChunk([ + "usage": ["prompt_tokens": 30, "completion_tokens": 90] + ], state: &state) + #expect(state.inputTokens == 30) + #expect(state.outputTokens == 90) + } + + @Test("message.content path yields textDelta (Ollama non-stream + final-message OpenAI)") + func messageContentPathYieldsTextDelta() { + var state = OpenAIStreamState() + let result = OpenAICompatibleProvider.parseChunk([ + "message": ["content": "hi"] + ], state: &state) + guard case .textDelta(let text) = result.events.first else { + Issue.record("expected textDelta from message.content; got \(result.events)") + return + } + #expect(text == "hi") + } + + @Test("Empty chunk yields no events and doesn't break") + func emptyChunk() { + var state = OpenAIStreamState() + let result = OpenAICompatibleProvider.parseChunk([:], state: &state) + #expect(result.events.isEmpty) + #expect(result.shouldBreak == false) + } + + @Test("done: true with no pending tool calls breaks without emitting") + func doneWithNoPendingTools() { + var state = OpenAIStreamState() + let result = OpenAICompatibleProvider.parseChunk(["done": true], state: &state) + #expect(result.shouldBreak == true) + #expect(result.events.isEmpty) + } + + @Test("decodeStreamLine respects providerType (SSE vs NDJSON)") + func decodeStreamLineFraming() { + let openAIParsed = OpenAICompatibleProvider.decodeStreamLine( + #"data: {"choices":[]}"#, + providerType: .openAI + ) + #expect(openAIParsed != nil) + let openAIDone = OpenAICompatibleProvider.decodeStreamLine("data: [DONE]", providerType: .openAI) + #expect(openAIDone == nil) + let ollamaParsed = OpenAICompatibleProvider.decodeStreamLine( + #"{"done":true}"#, + providerType: .ollama + ) + #expect(ollamaParsed != nil) + let ollamaEmpty = OpenAICompatibleProvider.decodeStreamLine("", providerType: .ollama) + #expect(ollamaEmpty == nil) + } +}