From 955c4b4a91a7c3095ec05693b3a8c800f779b56a Mon Sep 17 00:00:00 2001 From: Ngo Quoc Dat Date: Thu, 7 May 2026 15:03:24 +0700 Subject: [PATCH 1/2] refactor: extract per-chunk parsers from provider streamChat loops; add 27 parser tests --- TablePro/Core/AI/AnthropicProvider.swift | 158 ++++++----- TablePro/Core/AI/GeminiProvider.swift | 137 +++++----- .../Core/AI/OpenAICompatibleProvider.swift | 245 +++++++++--------- .../AI/AnthropicProviderParserTests.swift | 183 +++++++++++++ .../Core/AI/GeminiProviderParserTests.swift | 132 ++++++++++ .../OpenAICompatibleProviderParserTests.swift | 181 +++++++++++++ 6 files changed, 797 insertions(+), 239 deletions(-) create mode 100644 TableProTests/Core/AI/AnthropicProviderParserTests.swift create mode 100644 TableProTests/Core/AI/GeminiProviderParserTests.swift create mode 100644 TableProTests/Core/AI/OpenAICompatibleProviderParserTests.swift diff --git a/TablePro/Core/AI/AnthropicProvider.swift b/TablePro/Core/AI/AnthropicProvider.swift index d1f88c445..571f19510 100644 --- a/TablePro/Core/AI/AnthropicProvider.swift +++ b/TablePro/Core/AI/AnthropicProvider.swift @@ -43,74 +43,15 @@ final class AnthropicProvider: ChatTransport { ) } - var inputTokens = 0 - var outputTokens = 0 - var toolUseIdsByIndex: [Int: String] = [:] - + var state = AnthropicStreamState() for try await line in bytes.lines { if Task.isCancelled { break } - - guard line.hasPrefix("data: ") else { continue } - let jsonString = String(line.dropFirst(6)) - guard jsonString != "[DONE]", - let data = jsonString.data(using: .utf8), - let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any], - let type = json["type"] as? String - else { continue } - - switch type { - case "content_block_start": - if let index = json["index"] as? Int, - let block = json["content_block"] as? [String: Any], - (block["type"] as? String) == "tool_use", - let blockId = block["id"] as? String, - let blockName = block["name"] as? String { - toolUseIdsByIndex[index] = blockId - continuation.yield(.toolUseStart(id: blockId, name: blockName)) - } - case "content_block_delta": - guard let delta = json["delta"] as? [String: Any] else { break } - let deltaType = delta["type"] as? String - if deltaType == "input_json_delta" { - if let index = json["index"] as? Int, - let id = toolUseIdsByIndex[index], - let partial = delta["partial_json"] as? String { - continuation.yield(.toolUseDelta(id: id, inputJSONDelta: partial)) - } - } else if let text = delta["text"] as? String { - continuation.yield(.textDelta(text)) - } - case "content_block_stop": - if let index = json["index"] as? Int, - let id = toolUseIdsByIndex.removeValue(forKey: index) { - continuation.yield(.toolUseEnd(id: id)) - } - case "message_start": - if let message = json["message"] as? [String: Any], - let usage = message["usage"] as? [String: Any], - let tokens = usage["input_tokens"] as? Int { - inputTokens = tokens - } - case "message_delta": - if let usage = json["usage"] as? [String: Any], - let tokens = usage["output_tokens"] as? Int { - outputTokens = tokens - } - case "error": - if let errorObj = json["error"] as? [String: Any], - let message = errorObj["message"] as? String { - throw AIProviderError.streamingFailed(message) - } - default: - break - } + guard let json = Self.decodeStreamLine(line) else { continue } + let events = try Self.parseChunk(json, state: &state) + for event in events { continuation.yield(event) } } - - if inputTokens > 0 || outputTokens > 0 { - continuation.yield(.usage(AITokenUsage( - inputTokens: inputTokens, - outputTokens: outputTokens - ))) + if let usage = state.finalUsageEvent() { + continuation.yield(usage) } continuation.finish() @@ -220,6 +161,81 @@ final class AnthropicProvider: ChatTransport { return request } + /// Decodes one SSE line of the form `data: {...}` to a JSON object. + /// Returns `nil` for non-data lines, the `[DONE]` sentinel, and unparsable + /// payloads. Keeping this separate from `parseChunk` lets tests skip the + /// SSE framing and feed JSON dictionaries directly. + static func decodeStreamLine(_ line: String) -> [String: Any]? { + guard line.hasPrefix("data: ") else { return nil } + let jsonString = String(line.dropFirst(6)) + guard jsonString != "[DONE]", + let data = jsonString.data(using: .utf8), + let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any] + else { return nil } + return json + } + + /// Translate a single Anthropic SSE event JSON into zero or more + /// `ChatStreamEvent`s. Mutates `state` to carry index→id mappings and + /// token counters across calls. Throws `AIProviderError.streamingFailed` + /// on `error` events. + static func parseChunk( + _ json: [String: Any], + state: inout AnthropicStreamState + ) throws -> [ChatStreamEvent] { + guard let type = json["type"] as? String else { return [] } + switch type { + case "content_block_start": + guard let index = json["index"] as? Int, + let block = json["content_block"] as? [String: Any], + (block["type"] as? String) == "tool_use", + let blockId = block["id"] as? String, + let blockName = block["name"] as? String + else { return [] } + state.toolUseIdsByIndex[index] = blockId + return [.toolUseStart(id: blockId, name: blockName)] + case "content_block_delta": + guard let delta = json["delta"] as? [String: Any] else { return [] } + if (delta["type"] as? String) == "input_json_delta" { + guard let index = json["index"] as? Int, + let id = state.toolUseIdsByIndex[index], + let partial = delta["partial_json"] as? String + else { return [] } + return [.toolUseDelta(id: id, inputJSONDelta: partial)] + } + if let text = delta["text"] as? String { + return [.textDelta(text)] + } + return [] + case "content_block_stop": + guard let index = json["index"] as? Int, + let id = state.toolUseIdsByIndex.removeValue(forKey: index) + else { return [] } + return [.toolUseEnd(id: id)] + case "message_start": + if let message = json["message"] as? [String: Any], + let usage = message["usage"] as? [String: Any], + let tokens = usage["input_tokens"] as? Int { + state.inputTokens = tokens + } + return [] + case "message_delta": + if let usage = json["usage"] as? [String: Any], + let tokens = usage["output_tokens"] as? Int { + state.outputTokens = tokens + } + return [] + case "error": + if let errorObj = json["error"] as? [String: Any], + let message = errorObj["message"] as? String { + throw AIProviderError.streamingFailed(message) + } + return [] + default: + return [] + } + } + static func encodeToolSpec(_ spec: ChatToolSpec) throws -> [String: Any] { [ "name": spec.name, @@ -282,3 +298,15 @@ final class AnthropicProvider: ChatTransport { return try JSONSerialization.jsonObject(with: data, options: [.fragmentsAllowed]) } } + +/// Mutable state carried across `AnthropicProvider.parseChunk` calls. +struct AnthropicStreamState { + var inputTokens: Int = 0 + var outputTokens: Int = 0 + var toolUseIdsByIndex: [Int: String] = [:] + + func finalUsageEvent() -> ChatStreamEvent? { + guard inputTokens > 0 || outputTokens > 0 else { return nil } + return .usage(AITokenUsage(inputTokens: inputTokens, outputTokens: outputTokens)) + } +} diff --git a/TablePro/Core/AI/GeminiProvider.swift b/TablePro/Core/AI/GeminiProvider.swift index 04eca318a..8d19de5d9 100644 --- a/TablePro/Core/AI/GeminiProvider.swift +++ b/TablePro/Core/AI/GeminiProvider.swift @@ -43,54 +43,19 @@ final class GeminiProvider: ChatTransport { ) } - var inputTokens = 0 - var outputTokens = 0 - + var state = GeminiStreamState() for try await line in bytes.lines { if Task.isCancelled { break } - - guard line.hasPrefix("data: ") else { continue } - let jsonString = String(line.dropFirst(6)) - - guard let data = jsonString.data(using: .utf8), - let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any] - else { continue } - - if let candidates = json["candidates"] as? [[String: Any]], - let firstCandidate = candidates.first, - let content = firstCandidate["content"] as? [String: Any], - let parts = content["parts"] as? [[String: Any]] { - for part in parts { - if let text = part["text"] as? String, !text.isEmpty { - continuation.yield(.textDelta(text)) - } - if let functionCall = part["functionCall"] as? [String: Any], - let name = functionCall["name"] as? String { - let id = UUID().uuidString - let argsObject = functionCall["args"] ?? [String: Any]() - let argsString = encodeArgsToJSONString(argsObject) - continuation.yield(.toolUseStart(id: id, name: name)) - continuation.yield(.toolUseDelta(id: id, inputJSONDelta: argsString)) - continuation.yield(.toolUseEnd(id: id)) - } - } - } - - if let usageMetadata = json["usageMetadata"] as? [String: Any] { - if let prompt = usageMetadata["promptTokenCount"] as? Int { - inputTokens = prompt - } - if let candidates = usageMetadata["candidatesTokenCount"] as? Int { - outputTokens = candidates - } - } + guard let json = Self.decodeStreamLine(line) else { continue } + let events = Self.parseChunk( + json, + state: &state, + idGenerator: { UUID().uuidString } + ) + for event in events { continuation.yield(event) } } - - if inputTokens > 0 || outputTokens > 0 { - continuation.yield(.usage(AITokenUsage( - inputTokens: inputTokens, - outputTokens: outputTokens - ))) + if let usage = state.finalUsageEvent() { + continuation.yield(usage) } continuation.finish() @@ -317,19 +282,6 @@ final class GeminiProvider: ChatTransport { } } - private func encodeArgsToJSONString(_ args: Any) -> String { - guard JSONSerialization.isValidJSONObject(args) else { - Self.logger.warning("Gemini functionCall args was not a valid JSON object; falling back to empty input") - return "{}" - } - do { - let data = try JSONSerialization.data(withJSONObject: args) - return String(data: data, encoding: .utf8) ?? "{}" - } catch { - Self.logger.warning("Gemini functionCall args serialization failed: \(error.localizedDescription, privacy: .public)") - return "{}" - } - } func mapHTTPError(statusCode: Int, body: String) -> AIProviderError { if statusCode == 403 { @@ -338,4 +290,73 @@ final class GeminiProvider: ChatTransport { } return AIProviderError.mapHTTPError(statusCode: statusCode, body: body) } + + /// Decodes one Gemini SSE line. Returns nil for non-data lines. + static func decodeStreamLine(_ line: String) -> [String: Any]? { + guard line.hasPrefix("data: ") else { return nil } + let jsonString = String(line.dropFirst(6)) + guard let data = jsonString.data(using: .utf8), + let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any] + else { return nil } + return json + } + + /// Translate a single Gemini chunk to events. The injected `idGenerator` + /// lets tests pin the synthetic id Gemini doesn't provide; production + /// passes `UUID().uuidString`. + static func parseChunk( + _ json: [String: Any], + state: inout GeminiStreamState, + idGenerator: () -> String + ) -> [ChatStreamEvent] { + var events: [ChatStreamEvent] = [] + if let candidates = json["candidates"] as? [[String: Any]], + let firstCandidate = candidates.first, + let content = firstCandidate["content"] as? [String: Any], + let parts = content["parts"] as? [[String: Any]] { + for part in parts { + if let text = part["text"] as? String, !text.isEmpty { + events.append(.textDelta(text)) + } + if let functionCall = part["functionCall"] as? [String: Any], + let name = functionCall["name"] as? String { + let id = idGenerator() + let argsObject = functionCall["args"] ?? [String: Any]() + let argsString = encodeArgsToJSONString(argsObject) + events.append(.toolUseStart(id: id, name: name)) + events.append(.toolUseDelta(id: id, inputJSONDelta: argsString)) + events.append(.toolUseEnd(id: id)) + } + } + } + if let usageMetadata = json["usageMetadata"] as? [String: Any] { + if let prompt = usageMetadata["promptTokenCount"] as? Int { + state.inputTokens = prompt + } + if let candidates = usageMetadata["candidatesTokenCount"] as? Int { + state.outputTokens = candidates + } + } + return events + } + + static func encodeArgsToJSONString(_ args: Any) -> String { + guard JSONSerialization.isValidJSONObject(args) else { return "{}" } + guard let data = try? JSONSerialization.data(withJSONObject: args), + let string = String(data: data, encoding: .utf8) else { + return "{}" + } + return string + } +} + +/// Mutable state carried across `GeminiProvider.parseChunk` calls. +struct GeminiStreamState { + var inputTokens: Int = 0 + var outputTokens: Int = 0 + + func finalUsageEvent() -> ChatStreamEvent? { + guard inputTokens > 0 || outputTokens > 0 else { return nil } + return .usage(AITokenUsage(inputTokens: inputTokens, outputTokens: outputTokens)) + } } diff --git a/TablePro/Core/AI/OpenAICompatibleProvider.swift b/TablePro/Core/AI/OpenAICompatibleProvider.swift index f5611b921..908d47c76 100644 --- a/TablePro/Core/AI/OpenAICompatibleProvider.swift +++ b/TablePro/Core/AI/OpenAICompatibleProvider.swift @@ -60,98 +60,19 @@ final class OpenAICompatibleProvider: ChatTransport { ) } - var inputTokens = 0 - var outputTokens = 0 - var toolCallIndexToId: [Int: String] = [:] - var toolCallOrder: [Int] = [] - + var state = OpenAIStreamState() for try await line in bytes.lines { if Task.isCancelled { break } - - let jsonString: String - if self.providerType == .ollama { - guard !line.isEmpty else { continue } - jsonString = line - } else { - guard line.hasPrefix("data: ") else { continue } - let payload = String(line.dropFirst(6)) - guard payload != "[DONE]" else { break } - jsonString = payload - } - - guard let data = jsonString.data(using: .utf8), - let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any] - else { continue } - - let choices = json["choices"] as? [[String: Any]] - let firstChoice = choices?.first - let delta = firstChoice?["delta"] as? [String: Any] - - if let delta, let content = delta["content"] as? String, !content.isEmpty { - continuation.yield(.textDelta(content)) - } else if let message = json["message"] as? [String: Any], - let content = message["content"] as? String, - !content.isEmpty { - continuation.yield(.textDelta(content)) - } - - if let delta, let toolCalls = delta["tool_calls"] as? [[String: Any]] { - handleToolCallDeltas( - toolCalls, - indexToId: &toolCallIndexToId, - order: &toolCallOrder, - continuation: continuation - ) - } else if let message = json["message"] as? [String: Any], - let toolCalls = message["tool_calls"] as? [[String: Any]] { - handleOllamaToolCalls( - toolCalls, - indexToId: &toolCallIndexToId, - order: &toolCallOrder, - continuation: continuation - ) - } - - if let finishReason = firstChoice?["finish_reason"] as? String, - finishReason == "tool_calls" { - for index in toolCallOrder { - if let id = toolCallIndexToId[index] { - continuation.yield(.toolUseEnd(id: id)) - } - } - toolCallIndexToId.removeAll() - toolCallOrder.removeAll() - } - - if let usage = json["usage"] as? [String: Any], - let promptTokens = usage["prompt_tokens"] as? Int, - let completionTokens = usage["completion_tokens"] as? Int { - inputTokens = promptTokens - outputTokens = completionTokens - } else if let done = json["done"] as? Bool, done, - let promptEval = json["prompt_eval_count"] as? Int, - let evalCount = json["eval_count"] as? Int { - inputTokens = promptEval - outputTokens = evalCount - } - - if json["done"] as? Bool == true { - for index in toolCallOrder { - if let id = toolCallIndexToId[index] { - continuation.yield(.toolUseEnd(id: id)) - } - } - toolCallIndexToId.removeAll() - toolCallOrder.removeAll() - break + guard let json = Self.decodeStreamLine(line, providerType: self.providerType) else { + if line == "data: [DONE]" { break } + continue } + let result = Self.parseChunk(json, state: &state) + for event in result.events { continuation.yield(event) } + if result.shouldBreak { break } } - - if inputTokens > 0 || outputTokens > 0 { - continuation.yield(.usage(AITokenUsage( - inputTokens: inputTokens, - outputTokens: outputTokens - ))) + if let usage = state.finalUsageEvent() { + continuation.yield(usage) } continuation.finish() @@ -166,53 +87,119 @@ final class OpenAICompatibleProvider: ChatTransport { } } - private func handleToolCallDeltas( + /// Decodes one streaming line. OpenAI/OpenRouter/Custom use SSE framing + /// (`data: {...}`); Ollama emits NDJSON (one JSON object per line). The + /// `[DONE]` sentinel returns nil; the caller should break on it. + static func decodeStreamLine(_ line: String, providerType: AIProviderType) -> [String: Any]? { + let jsonString: String + if providerType == .ollama { + guard !line.isEmpty else { return nil } + jsonString = line + } else { + guard line.hasPrefix("data: ") else { return nil } + let payload = String(line.dropFirst(6)) + guard payload != "[DONE]" else { return nil } + jsonString = payload + } + guard let data = jsonString.data(using: .utf8), + let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any] + else { return nil } + return json + } + + /// Translate one chunk JSON to events. Mutates state to thread tool-call + /// index→id mapping, ordering, and token counters across chunks. + /// Returns `(events, shouldBreak)` so the caller can stop the stream when + /// Ollama emits `done: true`. + static func parseChunk( + _ json: [String: Any], + state: inout OpenAIStreamState + ) -> (events: [ChatStreamEvent], shouldBreak: Bool) { + var events: [ChatStreamEvent] = [] + let choices = json["choices"] as? [[String: Any]] + let firstChoice = choices?.first + let delta = firstChoice?["delta"] as? [String: Any] + + if let delta, let content = delta["content"] as? String, !content.isEmpty { + events.append(.textDelta(content)) + } else if let message = json["message"] as? [String: Any], + let content = message["content"] as? String, + !content.isEmpty { + events.append(.textDelta(content)) + } + + if let delta, let toolCalls = delta["tool_calls"] as? [[String: Any]] { + events.append(contentsOf: handleToolCallDeltas(toolCalls, state: &state)) + } else if let message = json["message"] as? [String: Any], + let toolCalls = message["tool_calls"] as? [[String: Any]] { + events.append(contentsOf: handleOllamaToolCalls(toolCalls, state: &state)) + } + + if let finishReason = firstChoice?["finish_reason"] as? String, + finishReason == "tool_calls" { + events.append(contentsOf: state.flushToolUseEnds()) + } + + if let usage = json["usage"] as? [String: Any], + let promptTokens = usage["prompt_tokens"] as? Int, + let completionTokens = usage["completion_tokens"] as? Int { + state.inputTokens = promptTokens + state.outputTokens = completionTokens + } else if let done = json["done"] as? Bool, done, + let promptEval = json["prompt_eval_count"] as? Int, + let evalCount = json["eval_count"] as? Int { + state.inputTokens = promptEval + state.outputTokens = evalCount + } + + let shouldBreak = (json["done"] as? Bool) == true + if shouldBreak { + events.append(contentsOf: state.flushToolUseEnds()) + } + return (events, shouldBreak) + } + + private static func handleToolCallDeltas( _ toolCalls: [[String: Any]], - indexToId: inout [Int: String], - order: inout [Int], - continuation: AsyncThrowingStream.Continuation - ) { + state: inout OpenAIStreamState + ) -> [ChatStreamEvent] { + var events: [ChatStreamEvent] = [] for toolCall in toolCalls { guard let index = toolCall["index"] as? Int else { continue } let function = toolCall["function"] as? [String: Any] - - if indexToId[index] == nil { + if state.toolCallIndexToId[index] == nil { let id = (toolCall["id"] as? String) ?? "call_\(index)_\(UUID().uuidString.prefix(8))" let name = (function?["name"] as? String) ?? "" - indexToId[index] = id - order.append(index) - continuation.yield(.toolUseStart(id: id, name: name)) + state.toolCallIndexToId[index] = id + state.toolCallOrder.append(index) + events.append(.toolUseStart(id: id, name: name)) } - - if let id = indexToId[index], + if let id = state.toolCallIndexToId[index], let arguments = function?["arguments"] as? String, !arguments.isEmpty { - continuation.yield(.toolUseDelta(id: id, inputJSONDelta: arguments)) + events.append(.toolUseDelta(id: id, inputJSONDelta: arguments)) } } + return events } - private func handleOllamaToolCalls( + private static func handleOllamaToolCalls( _ toolCalls: [[String: Any]], - indexToId: inout [Int: String], - order: inout [Int], - continuation: AsyncThrowingStream.Continuation - ) { + state: inout OpenAIStreamState + ) -> [ChatStreamEvent] { + var events: [ChatStreamEvent] = [] for (offset, toolCall) in toolCalls.enumerated() { guard let function = toolCall["function"] as? [String: Any], let name = function["name"] as? String else { continue } - let index = (toolCall["index"] as? Int) ?? offset let id = (toolCall["id"] as? String) ?? "call_\(index)_\(UUID().uuidString.prefix(8))" - - if indexToId[index] == nil { - indexToId[index] = id - order.append(index) - continuation.yield(.toolUseStart(id: id, name: name)) + if state.toolCallIndexToId[index] == nil { + state.toolCallIndexToId[index] = id + state.toolCallOrder.append(index) + events.append(.toolUseStart(id: id, name: name)) } - let argumentsString: String if let stringArgs = function["arguments"] as? String { argumentsString = stringArgs @@ -223,11 +210,11 @@ final class OpenAICompatibleProvider: ChatTransport { } else { argumentsString = "" } - - if !argumentsString.isEmpty, let resolvedId = indexToId[index] { - continuation.yield(.toolUseDelta(id: resolvedId, inputJSONDelta: argumentsString)) + if !argumentsString.isEmpty, let resolvedId = state.toolCallIndexToId[index] { + events.append(.toolUseDelta(id: resolvedId, inputJSONDelta: argumentsString)) } } + return events } func fetchAvailableModels() async throws -> [String] { @@ -500,3 +487,29 @@ final class OpenAICompatibleProvider: ChatTransport { return models.compactMap { $0["name"] as? String }.sorted() } } + +/// Mutable state carried across `OpenAICompatibleProvider.parseChunk` calls. +struct OpenAIStreamState { + var inputTokens: Int = 0 + var outputTokens: Int = 0 + var toolCallIndexToId: [Int: String] = [:] + var toolCallOrder: [Int] = [] + + /// Yield `.toolUseEnd` for every tracked tool call and clear the map. + /// Called when the provider signals tool-call completion (`finish_reason` + /// or Ollama `done: true`). + mutating func flushToolUseEnds() -> [ChatStreamEvent] { + let events: [ChatStreamEvent] = toolCallOrder.compactMap { index in + guard let id = toolCallIndexToId[index] else { return nil } + return .toolUseEnd(id: id) + } + toolCallIndexToId.removeAll() + toolCallOrder.removeAll() + return events + } + + func finalUsageEvent() -> ChatStreamEvent? { + guard inputTokens > 0 || outputTokens > 0 else { return nil } + return .usage(AITokenUsage(inputTokens: inputTokens, outputTokens: outputTokens)) + } +} diff --git a/TableProTests/Core/AI/AnthropicProviderParserTests.swift b/TableProTests/Core/AI/AnthropicProviderParserTests.swift new file mode 100644 index 000000000..8e5acd96f --- /dev/null +++ b/TableProTests/Core/AI/AnthropicProviderParserTests.swift @@ -0,0 +1,183 @@ +// +// AnthropicProviderParserTests.swift +// TableProTests +// + +import Foundation +@testable import TablePro +import Testing + +@Suite("AnthropicProvider stream parser") +struct AnthropicProviderParserTests { + private func parse(_ json: [String: Any], state: inout AnthropicStreamState) throws -> [ChatStreamEvent] { + try AnthropicProvider.parseChunk(json, state: &state) + } + + @Test("text_delta yields textDelta") + func textDelta() throws { + var state = AnthropicStreamState() + let events = try parse([ + "type": "content_block_delta", + "delta": ["type": "text_delta", "text": "hello"] + ], state: &state) + guard case .textDelta(let text) = events.first else { + Issue.record("expected textDelta; got \(events)") + return + } + #expect(text == "hello") + } + + @Test("content_block_start with tool_use yields toolUseStart and remembers index→id") + func toolUseStart() throws { + var state = AnthropicStreamState() + let events = try parse([ + "type": "content_block_start", + "index": 1, + "content_block": [ + "type": "tool_use", + "id": "toolu_abc", + "name": "list_tables" + ] + ], state: &state) + #expect(events.count == 1) + if case .toolUseStart(let id, let name) = events.first { + #expect(id == "toolu_abc") + #expect(name == "list_tables") + } else { + Issue.record("expected toolUseStart; got \(events)") + } + #expect(state.toolUseIdsByIndex[1] == "toolu_abc") + } + + @Test("input_json_delta resolves index back to id from state") + func inputJSONDelta() throws { + var state = AnthropicStreamState() + state.toolUseIdsByIndex[1] = "toolu_abc" + let events = try parse([ + "type": "content_block_delta", + "index": 1, + "delta": ["type": "input_json_delta", "partial_json": #"{"foo":"#] + ], state: &state) + if case .toolUseDelta(let id, let delta) = events.first { + #expect(id == "toolu_abc") + #expect(delta == #"{"foo":"#) + } else { + Issue.record("expected toolUseDelta; got \(events)") + } + } + + @Test("input_json_delta for unknown index yields nothing") + func inputJSONDeltaUnknownIndex() throws { + var state = AnthropicStreamState() + let events = try parse([ + "type": "content_block_delta", + "index": 99, + "delta": ["type": "input_json_delta", "partial_json": "x"] + ], state: &state) + #expect(events.isEmpty) + } + + @Test("content_block_stop yields toolUseEnd and clears the index mapping") + func toolUseEnd() throws { + var state = AnthropicStreamState() + state.toolUseIdsByIndex[1] = "toolu_abc" + let events = try parse([ + "type": "content_block_stop", + "index": 1 + ], state: &state) + if case .toolUseEnd(let id) = events.first { + #expect(id == "toolu_abc") + } else { + Issue.record("expected toolUseEnd; got \(events)") + } + #expect(state.toolUseIdsByIndex[1] == nil) + } + + @Test("Fragmented input_json_delta concatenates correctly via state") + func fragmentedDelta() throws { + var state = AnthropicStreamState() + _ = try parse([ + "type": "content_block_start", + "index": 0, + "content_block": ["type": "tool_use", "id": "tid", "name": "list_tables"] + ], state: &state) + let chunk1 = try parse([ + "type": "content_block_delta", + "index": 0, + "delta": ["type": "input_json_delta", "partial_json": #"{"con"#] + ], state: &state) + let chunk2 = try parse([ + "type": "content_block_delta", + "index": 0, + "delta": ["type": "input_json_delta", "partial_json": #"nection_id":"x"}"#] + ], state: &state) + let stop = try parse([ + "type": "content_block_stop", + "index": 0 + ], state: &state) + let combined = chunk1 + chunk2 + stop + // textDelta accumulator on the consumer side reassembles fragments. + let deltas = combined.compactMap { event -> String? in + if case .toolUseDelta(_, let d) = event { return d } + return nil + } + #expect(deltas.joined() == #"{"connection_id":"x"}"#) + } + + @Test("message_start tracks input tokens") + func messageStart() throws { + var state = AnthropicStreamState() + _ = try parse([ + "type": "message_start", + "message": ["usage": ["input_tokens": 42]] + ], state: &state) + #expect(state.inputTokens == 42) + } + + @Test("message_delta tracks output tokens") + func messageDelta() throws { + var state = AnthropicStreamState() + _ = try parse([ + "type": "message_delta", + "usage": ["output_tokens": 100] + ], state: &state) + #expect(state.outputTokens == 100) + } + + @Test("finalUsageEvent emits .usage when tokens were observed") + func finalUsage() throws { + var state = AnthropicStreamState() + state.inputTokens = 42 + state.outputTokens = 100 + guard case .usage(let usage) = state.finalUsageEvent() else { + Issue.record("expected usage event") + return + } + #expect(usage.inputTokens == 42) + #expect(usage.outputTokens == 100) + } + + @Test("finalUsageEvent returns nil when no tokens observed") + func noUsageNoEvent() { + let state = AnthropicStreamState() + #expect(state.finalUsageEvent() == nil) + } + + @Test("error event throws streamingFailed") + func errorEvent() { + var state = AnthropicStreamState() + #expect(throws: AIProviderError.self) { + _ = try AnthropicProvider.parseChunk( + ["type": "error", "error": ["message": "rate limited"]], + state: &state + ) + } + } + + @Test("decodeStreamLine returns nil for non-data lines and [DONE]") + func framingDecode() { + #expect(AnthropicProvider.decodeStreamLine("event: message_stop") == nil) + #expect(AnthropicProvider.decodeStreamLine("data: [DONE]") == nil) + #expect(AnthropicProvider.decodeStreamLine("data: {\"type\":\"x\"}") != nil) + } +} diff --git a/TableProTests/Core/AI/GeminiProviderParserTests.swift b/TableProTests/Core/AI/GeminiProviderParserTests.swift new file mode 100644 index 000000000..e1f4d560f --- /dev/null +++ b/TableProTests/Core/AI/GeminiProviderParserTests.swift @@ -0,0 +1,132 @@ +// +// GeminiProviderParserTests.swift +// TableProTests +// + +import Foundation +@testable import TablePro +import Testing + +@Suite("GeminiProvider stream parser") +struct GeminiProviderParserTests { + private let stableID = "stable-id" + + private func parse(_ json: [String: Any], state: inout GeminiStreamState) -> [ChatStreamEvent] { + GeminiProvider.parseChunk(json, state: &state, idGenerator: { self.stableID }) + } + + @Test("Text part yields textDelta") + func textPart() { + var state = GeminiStreamState() + let events = parse([ + "candidates": [[ + "content": [ + "parts": [["text": "hello"]] + ] + ]] + ], state: &state) + guard case .textDelta(let text) = events.first else { + Issue.record("expected textDelta") + return + } + #expect(text == "hello") + } + + @Test("functionCall part yields the start/delta/end trio in one chunk") + func functionCallTrio() { + var state = GeminiStreamState() + let events = parse([ + "candidates": [[ + "content": [ + "parts": [[ + "functionCall": [ + "name": "list_tables", + "args": ["connection_id": "abc"] + ] + ]] + ] + ]] + ], state: &state) + #expect(events.count == 3) + if case .toolUseStart(let id, let name) = events[0] { + #expect(id == stableID) + #expect(name == "list_tables") + } else { + Issue.record("expected toolUseStart at index 0") + } + if case .toolUseDelta(let id, let delta) = events[1] { + #expect(id == stableID) + #expect(delta.contains("connection_id")) + #expect(delta.contains("abc")) + } else { + Issue.record("expected toolUseDelta at index 1") + } + if case .toolUseEnd(let id) = events[2] { + #expect(id == stableID) + } else { + Issue.record("expected toolUseEnd at index 2") + } + } + + @Test("Mixed text + functionCall parts yield events in part order") + func mixedParts() { + var state = GeminiStreamState() + let events = parse([ + "candidates": [[ + "content": [ + "parts": [ + ["text": "I'll check"], + ["functionCall": ["name": "list_tables", "args": [String: Any]()]] + ] + ] + ]] + ], state: &state) + // Order: textDelta, toolUseStart, toolUseDelta, toolUseEnd + #expect(events.count == 4) + guard case .textDelta = events[0] else { + Issue.record("expected textDelta first") + return + } + guard case .toolUseStart = events[1] else { + Issue.record("expected toolUseStart second") + return + } + } + + @Test("Empty parts array yields no events") + func emptyParts() { + var state = GeminiStreamState() + let events = parse([ + "candidates": [[ + "content": ["parts": [[String: Any]]()] + ]] + ], state: &state) + #expect(events.isEmpty) + } + + @Test("usageMetadata populates state token counters") + func usageTokens() { + var state = GeminiStreamState() + _ = parse([ + "usageMetadata": [ + "promptTokenCount": 100, + "candidatesTokenCount": 50 + ] + ], state: &state) + #expect(state.inputTokens == 100) + #expect(state.outputTokens == 50) + } + + @Test("encodeArgsToJSONString returns {} on invalid input") + func argsFallback() { + let invalid: Any = NSObject() // not JSON-serializable + #expect(GeminiProvider.encodeArgsToJSONString(invalid) == "{}") + } + + @Test("encodeArgsToJSONString round-trips a valid object") + func argsRoundTrip() { + let result = GeminiProvider.encodeArgsToJSONString(["a": 1, "b": "x"]) + #expect(result.contains("\"a\"")) + #expect(result.contains("\"b\"")) + } +} diff --git a/TableProTests/Core/AI/OpenAICompatibleProviderParserTests.swift b/TableProTests/Core/AI/OpenAICompatibleProviderParserTests.swift new file mode 100644 index 000000000..4373a1dd4 --- /dev/null +++ b/TableProTests/Core/AI/OpenAICompatibleProviderParserTests.swift @@ -0,0 +1,181 @@ +// +// OpenAICompatibleProviderParserTests.swift +// TableProTests +// + +import Foundation +@testable import TablePro +import Testing + +@Suite("OpenAICompatibleProvider stream parser") +struct OpenAICompatibleProviderParserTests { + @Test("delta.content yields textDelta") + func textDelta() { + var state = OpenAIStreamState() + let result = OpenAICompatibleProvider.parseChunk([ + "choices": [[ + "delta": ["content": "hello"] + ]] + ], state: &state) + #expect(result.shouldBreak == false) + guard case .textDelta(let text) = result.events.first else { + Issue.record("expected textDelta; got \(result.events)") + return + } + #expect(text == "hello") + } + + @Test("First tool_calls chunk emits toolUseStart with id and name") + func toolUseStart() { + var state = OpenAIStreamState() + let result = OpenAICompatibleProvider.parseChunk([ + "choices": [[ + "delta": [ + "tool_calls": [[ + "index": 0, + "id": "call_abc", + "type": "function", + "function": ["name": "list_tables", "arguments": ""] + ]] + ] + ]] + ], state: &state) + #expect(result.events.count == 1) + if case .toolUseStart(let id, let name) = result.events.first { + #expect(id == "call_abc") + #expect(name == "list_tables") + } else { + Issue.record("expected toolUseStart; got \(result.events)") + } + #expect(state.toolCallIndexToId[0] == "call_abc") + } + + @Test("Subsequent tool_calls chunks emit toolUseDelta only") + func toolUseDelta() { + var state = OpenAIStreamState() + state.toolCallIndexToId[0] = "call_abc" + state.toolCallOrder = [0] + let result = OpenAICompatibleProvider.parseChunk([ + "choices": [[ + "delta": [ + "tool_calls": [[ + "index": 0, + "function": ["arguments": #"{"foo":"#] + ]] + ] + ]] + ], state: &state) + #expect(result.events.count == 1) + if case .toolUseDelta(let id, let delta) = result.events.first { + #expect(id == "call_abc") + #expect(delta == #"{"foo":"#) + } else { + Issue.record("expected toolUseDelta; got \(result.events)") + } + } + + @Test("finish_reason: tool_calls flushes toolUseEnds for all tracked calls") + func finishReasonTriggersFlush() { + var state = OpenAIStreamState() + state.toolCallIndexToId = [0: "call_a", 1: "call_b"] + state.toolCallOrder = [0, 1] + let result = OpenAICompatibleProvider.parseChunk([ + "choices": [["finish_reason": "tool_calls"]] + ], state: &state) + let endIds = result.events.compactMap { event -> String? in + if case .toolUseEnd(let id) = event { return id } + return nil + } + #expect(endIds == ["call_a", "call_b"]) + #expect(state.toolCallIndexToId.isEmpty) + #expect(state.toolCallOrder.isEmpty) + } + + @Test("Ollama message.tool_calls with arguments-as-object encodes to JSON string") + func ollamaArgumentsAsObject() { + var state = OpenAIStreamState() + let result = OpenAICompatibleProvider.parseChunk([ + "message": [ + "tool_calls": [[ + "function": [ + "name": "list_tables", + "arguments": ["connection_id": "abc"] // object, not string + ] + ]] + ] + ], state: &state) + let deltaPayload = result.events.compactMap { event -> String? in + if case .toolUseDelta(_, let s) = event { return s } + return nil + }.first + #expect(deltaPayload?.contains("connection_id") == true) + #expect(deltaPayload?.contains("abc") == true) + } + + @Test("Ollama message.tool_calls with arguments-as-string passes through verbatim") + func ollamaArgumentsAsString() { + var state = OpenAIStreamState() + let result = OpenAICompatibleProvider.parseChunk([ + "message": [ + "tool_calls": [[ + "function": [ + "name": "list_tables", + "arguments": #"{"connection_id":"abc"}"# + ] + ]] + ] + ], state: &state) + let delta = result.events.compactMap { event -> String? in + if case .toolUseDelta(_, let s) = event { return s } + return nil + }.first + #expect(delta == #"{"connection_id":"abc"}"#) + } + + @Test("Ollama done: true sets shouldBreak and flushes pending tool ends") + func ollamaDoneFlushesAndBreaks() { + var state = OpenAIStreamState() + state.toolCallIndexToId[0] = "call_a" + state.toolCallOrder = [0] + let result = OpenAICompatibleProvider.parseChunk([ + "done": true, + "prompt_eval_count": 50, + "eval_count": 200 + ], state: &state) + #expect(result.shouldBreak == true) + #expect(result.events.contains(where: { event in + if case .toolUseEnd(let id) = event { return id == "call_a" } + return false + })) + #expect(state.inputTokens == 50) + #expect(state.outputTokens == 200) + } + + @Test("usage object populates state token counters") + func usageTokens() { + var state = OpenAIStreamState() + _ = OpenAICompatibleProvider.parseChunk([ + "usage": ["prompt_tokens": 30, "completion_tokens": 90] + ], state: &state) + #expect(state.inputTokens == 30) + #expect(state.outputTokens == 90) + } + + @Test("decodeStreamLine respects providerType (SSE vs NDJSON)") + func decodeStreamLineFraming() { + let openAIParsed = OpenAICompatibleProvider.decodeStreamLine( + #"data: {"choices":[]}"#, + providerType: .openAI + ) + #expect(openAIParsed != nil) + let openAIDone = OpenAICompatibleProvider.decodeStreamLine("data: [DONE]", providerType: .openAI) + #expect(openAIDone == nil) + let ollamaParsed = OpenAICompatibleProvider.decodeStreamLine( + #"{"done":true}"#, + providerType: .ollama + ) + #expect(ollamaParsed != nil) + let ollamaEmpty = OpenAICompatibleProvider.decodeStreamLine("", providerType: .ollama) + #expect(ollamaEmpty == nil) + } +} From df36fbe451e971556ecb324933e5463d8b95ac1e Mon Sep 17 00:00:00 2001 From: Ngo Quoc Dat Date: Thu, 7 May 2026 15:07:56 +0700 Subject: [PATCH 2/2] refactor: address parser-tests review (restore Gemini OSLog warnings, doc idGenerator, default-case + message.content tests) --- TablePro/Core/AI/GeminiProvider.swift | 23 +++++++++----- .../Core/AI/OpenAICompatibleProvider.swift | 5 ++- .../AI/AnthropicProviderParserTests.swift | 20 ++++++++++++ .../Core/AI/GeminiProviderParserTests.swift | 31 +++++++++++++++++++ .../OpenAICompatibleProviderParserTests.swift | 29 +++++++++++++++++ 5 files changed, 100 insertions(+), 8 deletions(-) diff --git a/TablePro/Core/AI/GeminiProvider.swift b/TablePro/Core/AI/GeminiProvider.swift index 8d19de5d9..ebe7f3980 100644 --- a/TablePro/Core/AI/GeminiProvider.swift +++ b/TablePro/Core/AI/GeminiProvider.swift @@ -301,9 +301,13 @@ final class GeminiProvider: ChatTransport { return json } - /// Translate a single Gemini chunk to events. The injected `idGenerator` - /// lets tests pin the synthetic id Gemini doesn't provide; production - /// passes `UUID().uuidString`. + /// Translate a single Gemini chunk to events. + /// + /// Gemini does not provide tool-call ids on `functionCall` parts, so we + /// synthesize one per call. `idGenerator` is injected so tests can pin the + /// synthetic id to a stable value; production passes `{ UUID().uuidString }`. + /// Each call to `idGenerator()` returns a fresh id, so multiple + /// `functionCall` parts in one chunk get distinct ids in production. static func parseChunk( _ json: [String: Any], state: inout GeminiStreamState, @@ -341,12 +345,17 @@ final class GeminiProvider: ChatTransport { } static func encodeArgsToJSONString(_ args: Any) -> String { - guard JSONSerialization.isValidJSONObject(args) else { return "{}" } - guard let data = try? JSONSerialization.data(withJSONObject: args), - let string = String(data: data, encoding: .utf8) else { + guard JSONSerialization.isValidJSONObject(args) else { + Self.logger.warning("Gemini functionCall args was not a valid JSON object; falling back to empty input") + return "{}" + } + do { + let data = try JSONSerialization.data(withJSONObject: args) + return String(data: data, encoding: .utf8) ?? "{}" + } catch { + Self.logger.warning("Gemini functionCall args serialization failed: \(error.localizedDescription, privacy: .public)") return "{}" } - return string } } diff --git a/TablePro/Core/AI/OpenAICompatibleProvider.swift b/TablePro/Core/AI/OpenAICompatibleProvider.swift index 908d47c76..e2389d173 100644 --- a/TablePro/Core/AI/OpenAICompatibleProvider.swift +++ b/TablePro/Core/AI/OpenAICompatibleProvider.swift @@ -152,8 +152,11 @@ final class OpenAICompatibleProvider: ChatTransport { state.outputTokens = evalCount } + // Ollama signals stream-end via `done: true`. We flush again here only + // when finish_reason didn't already drain the tool-call map (which + // typically isn't set on Ollama responses). let shouldBreak = (json["done"] as? Bool) == true - if shouldBreak { + if shouldBreak, !state.toolCallIndexToId.isEmpty { events.append(contentsOf: state.flushToolUseEnds()) } return (events, shouldBreak) diff --git a/TableProTests/Core/AI/AnthropicProviderParserTests.swift b/TableProTests/Core/AI/AnthropicProviderParserTests.swift index 8e5acd96f..061054640 100644 --- a/TableProTests/Core/AI/AnthropicProviderParserTests.swift +++ b/TableProTests/Core/AI/AnthropicProviderParserTests.swift @@ -180,4 +180,24 @@ struct AnthropicProviderParserTests { #expect(AnthropicProvider.decodeStreamLine("data: [DONE]") == nil) #expect(AnthropicProvider.decodeStreamLine("data: {\"type\":\"x\"}") != nil) } + + @Test("Unknown event types yield no events and don't throw") + func unknownEventType() throws { + var state = AnthropicStreamState() + let events = try AnthropicProvider.parseChunk( + ["type": "ping", "extra": "ignored"], + state: &state + ) + #expect(events.isEmpty) + // State should be unchanged. + #expect(state.toolUseIdsByIndex.isEmpty) + #expect(state.inputTokens == 0) + } + + @Test("Chunk with no type field yields no events and doesn't throw") + func chunkWithoutType() throws { + var state = AnthropicStreamState() + let events = try AnthropicProvider.parseChunk(["random": "data"], state: &state) + #expect(events.isEmpty) + } } diff --git a/TableProTests/Core/AI/GeminiProviderParserTests.swift b/TableProTests/Core/AI/GeminiProviderParserTests.swift index e1f4d560f..db3435d36 100644 --- a/TableProTests/Core/AI/GeminiProviderParserTests.swift +++ b/TableProTests/Core/AI/GeminiProviderParserTests.swift @@ -129,4 +129,35 @@ struct GeminiProviderParserTests { #expect(result.contains("\"a\"")) #expect(result.contains("\"b\"")) } + + @Test("Chunk without candidates yields no events") + func chunkWithoutCandidates() { + var state = GeminiStreamState() + let events = parse(["unrelated": "data"], state: &state) + #expect(events.isEmpty) + } + + @Test("Multiple functionCall parts in one chunk get distinct ids") + func multipleFunctionCallsGetDistinctIds() { + var state = GeminiStreamState() + var counter = 0 + let events = GeminiProvider.parseChunk([ + "candidates": [[ + "content": [ + "parts": [ + ["functionCall": ["name": "list_tables", "args": [String: Any]()]], + ["functionCall": ["name": "describe_table", "args": [String: Any]()]] + ] + ] + ]] + ], state: &state, idGenerator: { + defer { counter += 1 } + return "id-\(counter)" + }) + let starts = events.compactMap { event -> String? in + if case .toolUseStart(let id, _) = event { return id } + return nil + } + #expect(starts == ["id-0", "id-1"]) + } } diff --git a/TableProTests/Core/AI/OpenAICompatibleProviderParserTests.swift b/TableProTests/Core/AI/OpenAICompatibleProviderParserTests.swift index 4373a1dd4..614cd6bab 100644 --- a/TableProTests/Core/AI/OpenAICompatibleProviderParserTests.swift +++ b/TableProTests/Core/AI/OpenAICompatibleProviderParserTests.swift @@ -161,6 +161,35 @@ struct OpenAICompatibleProviderParserTests { #expect(state.outputTokens == 90) } + @Test("message.content path yields textDelta (Ollama non-stream + final-message OpenAI)") + func messageContentPathYieldsTextDelta() { + var state = OpenAIStreamState() + let result = OpenAICompatibleProvider.parseChunk([ + "message": ["content": "hi"] + ], state: &state) + guard case .textDelta(let text) = result.events.first else { + Issue.record("expected textDelta from message.content; got \(result.events)") + return + } + #expect(text == "hi") + } + + @Test("Empty chunk yields no events and doesn't break") + func emptyChunk() { + var state = OpenAIStreamState() + let result = OpenAICompatibleProvider.parseChunk([:], state: &state) + #expect(result.events.isEmpty) + #expect(result.shouldBreak == false) + } + + @Test("done: true with no pending tool calls breaks without emitting") + func doneWithNoPendingTools() { + var state = OpenAIStreamState() + let result = OpenAICompatibleProvider.parseChunk(["done": true], state: &state) + #expect(result.shouldBreak == true) + #expect(result.events.isEmpty) + } + @Test("decodeStreamLine respects providerType (SSE vs NDJSON)") func decodeStreamLineFraming() { let openAIParsed = OpenAICompatibleProvider.decodeStreamLine(