Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 93 additions & 65 deletions TablePro/Core/AI/AnthropicProvider.swift
Original file line number Diff line number Diff line change
Expand Up @@ -43,74 +43,15 @@ final class AnthropicProvider: ChatTransport {
)
}

var inputTokens = 0
var outputTokens = 0
var toolUseIdsByIndex: [Int: String] = [:]

var state = AnthropicStreamState()
for try await line in bytes.lines {
if Task.isCancelled { break }

guard line.hasPrefix("data: ") else { continue }
let jsonString = String(line.dropFirst(6))
guard jsonString != "[DONE]",
let data = jsonString.data(using: .utf8),
let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any],
let type = json["type"] as? String
else { continue }

switch type {
case "content_block_start":
if let index = json["index"] as? Int,
let block = json["content_block"] as? [String: Any],
(block["type"] as? String) == "tool_use",
let blockId = block["id"] as? String,
let blockName = block["name"] as? String {
toolUseIdsByIndex[index] = blockId
continuation.yield(.toolUseStart(id: blockId, name: blockName))
}
case "content_block_delta":
guard let delta = json["delta"] as? [String: Any] else { break }
let deltaType = delta["type"] as? String
if deltaType == "input_json_delta" {
if let index = json["index"] as? Int,
let id = toolUseIdsByIndex[index],
let partial = delta["partial_json"] as? String {
continuation.yield(.toolUseDelta(id: id, inputJSONDelta: partial))
}
} else if let text = delta["text"] as? String {
continuation.yield(.textDelta(text))
}
case "content_block_stop":
if let index = json["index"] as? Int,
let id = toolUseIdsByIndex.removeValue(forKey: index) {
continuation.yield(.toolUseEnd(id: id))
}
case "message_start":
if let message = json["message"] as? [String: Any],
let usage = message["usage"] as? [String: Any],
let tokens = usage["input_tokens"] as? Int {
inputTokens = tokens
}
case "message_delta":
if let usage = json["usage"] as? [String: Any],
let tokens = usage["output_tokens"] as? Int {
outputTokens = tokens
}
case "error":
if let errorObj = json["error"] as? [String: Any],
let message = errorObj["message"] as? String {
throw AIProviderError.streamingFailed(message)
}
default:
break
}
guard let json = Self.decodeStreamLine(line) else { continue }
let events = try Self.parseChunk(json, state: &state)
for event in events { continuation.yield(event) }
}

if inputTokens > 0 || outputTokens > 0 {
continuation.yield(.usage(AITokenUsage(
inputTokens: inputTokens,
outputTokens: outputTokens
)))
if let usage = state.finalUsageEvent() {
continuation.yield(usage)
}

continuation.finish()
Expand Down Expand Up @@ -220,6 +161,81 @@ final class AnthropicProvider: ChatTransport {
return request
}

/// Decodes one SSE `data` line to a JSON object.
///
/// Returns `nil` for non-data lines, the `[DONE]` sentinel, and unparsable
/// payloads. Keeping this separate from `parseChunk` lets tests skip the
/// SSE framing and feed JSON dictionaries directly.
///
/// The single space after the `data:` field name is optional in the SSE
/// wire format, so both `data: {...}` and `data:{...}` are accepted.
///
/// - Parameter line: One line of the streamed response body.
/// - Returns: The decoded top-level JSON object, or `nil` when the line
///   does not carry one.
static func decodeStreamLine(_ line: String) -> [String: Any]? {
    let fieldPrefix = "data:"
    guard line.hasPrefix(fieldPrefix) else { return nil }

    var payload = line.dropFirst(fieldPrefix.count)
    // At most one leading space belongs to the SSE framing; anything after
    // it is payload (JSON parsing tolerates further leading whitespace).
    if payload.first == " " {
        payload = payload.dropFirst()
    }

    guard payload != "[DONE]",
          let json = try? JSONSerialization.jsonObject(with: Data(payload.utf8)) as? [String: Any]
    else { return nil }
    return json
}

/// Translate a single Anthropic SSE event JSON into zero or more
/// `ChatStreamEvent`s.
///
/// Mutates `state` to carry index→id mappings and token counters across
/// calls. Events whose payload is missing a required field are ignored
/// (an empty array is returned), with one exception: an `error` event
/// always throws, because dropping it would make a failed stream look
/// like a normal completion.
///
/// - Parameters:
///   - json: The decoded JSON object of one stream event.
///   - state: Accumulator shared by all chunks of the same stream.
/// - Returns: The events produced by this chunk, possibly empty.
/// - Throws: `AIProviderError.streamingFailed` on `error` events.
static func parseChunk(
    _ json: [String: Any],
    state: inout AnthropicStreamState
) throws -> [ChatStreamEvent] {
    guard let type = json["type"] as? String else { return [] }
    switch type {
    case "content_block_start":
        // Only `tool_use` blocks are announced; other block types produce
        // their output through `content_block_delta`.
        guard let index = json["index"] as? Int,
              let block = json["content_block"] as? [String: Any],
              (block["type"] as? String) == "tool_use",
              let blockId = block["id"] as? String,
              let blockName = block["name"] as? String
        else { return [] }
        // Later deltas reference the block by index only, so remember its id.
        state.toolUseIdsByIndex[index] = blockId
        return [.toolUseStart(id: blockId, name: blockName)]
    case "content_block_delta":
        guard let delta = json["delta"] as? [String: Any] else { return [] }
        if (delta["type"] as? String) == "input_json_delta" {
            guard let index = json["index"] as? Int,
                  let id = state.toolUseIdsByIndex[index],
                  let partial = delta["partial_json"] as? String
            else { return [] }
            return [.toolUseDelta(id: id, inputJSONDelta: partial)]
        }
        if let text = delta["text"] as? String {
            return [.textDelta(text)]
        }
        return []
    case "content_block_stop":
        // Removing the mapping means a stop for a non-tool block (never
        // registered) yields nothing.
        guard let index = json["index"] as? Int,
              let id = state.toolUseIdsByIndex.removeValue(forKey: index)
        else { return [] }
        return [.toolUseEnd(id: id)]
    case "message_start":
        if let message = json["message"] as? [String: Any],
           let usage = message["usage"] as? [String: Any],
           let tokens = usage["input_tokens"] as? Int {
            state.inputTokens = tokens
        }
        return []
    case "message_delta":
        if let usage = json["usage"] as? [String: Any],
           let tokens = usage["output_tokens"] as? Int {
            state.outputTokens = tokens
        }
        return []
    case "error":
        // Prefer the human-readable message, then the machine-readable error
        // type, then a generic description, so the failure is never swallowed.
        let errorObj = json["error"] as? [String: Any]
        let message = (errorObj?["message"] as? String)
            ?? (errorObj?["type"] as? String)
            ?? "Unknown streaming error"
        throw AIProviderError.streamingFailed(message)
    default:
        return []
    }
}

static func encodeToolSpec(_ spec: ChatToolSpec) throws -> [String: Any] {
[
"name": spec.name,
Expand Down Expand Up @@ -282,3 +298,15 @@ final class AnthropicProvider: ChatTransport {
return try JSONSerialization.jsonObject(with: data, options: [.fragmentsAllowed])
}
}

/// Accumulator threaded through successive `AnthropicProvider.parseChunk` calls.
struct AnthropicStreamState {
    /// Most recent input-token count reported by the stream.
    var inputTokens: Int = 0
    /// Most recent output-token count reported by the stream.
    var outputTokens: Int = 0
    /// Ids of the tool-use blocks that are currently open, keyed by block index.
    var toolUseIdsByIndex: [Int: String] = [:]

    /// Builds the closing usage event for the stream.
    ///
    /// - Returns: A `.usage` event with the recorded counters, or `nil`
    ///   when neither counter is positive.
    func finalUsageEvent() -> ChatStreamEvent? {
        if inputTokens <= 0 && outputTokens <= 0 {
            return nil
        }
        let totals = AITokenUsage(inputTokens: inputTokens, outputTokens: outputTokens)
        return .usage(totals)
    }
}
134 changes: 82 additions & 52 deletions TablePro/Core/AI/GeminiProvider.swift
Original file line number Diff line number Diff line change
Expand Up @@ -43,54 +43,19 @@ final class GeminiProvider: ChatTransport {
)
}

var inputTokens = 0
var outputTokens = 0

var state = GeminiStreamState()
for try await line in bytes.lines {
if Task.isCancelled { break }

guard line.hasPrefix("data: ") else { continue }
let jsonString = String(line.dropFirst(6))

guard let data = jsonString.data(using: .utf8),
let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any]
else { continue }

if let candidates = json["candidates"] as? [[String: Any]],
let firstCandidate = candidates.first,
let content = firstCandidate["content"] as? [String: Any],
let parts = content["parts"] as? [[String: Any]] {
for part in parts {
if let text = part["text"] as? String, !text.isEmpty {
continuation.yield(.textDelta(text))
}
if let functionCall = part["functionCall"] as? [String: Any],
let name = functionCall["name"] as? String {
let id = UUID().uuidString
let argsObject = functionCall["args"] ?? [String: Any]()
let argsString = encodeArgsToJSONString(argsObject)
continuation.yield(.toolUseStart(id: id, name: name))
continuation.yield(.toolUseDelta(id: id, inputJSONDelta: argsString))
continuation.yield(.toolUseEnd(id: id))
}
}
}

if let usageMetadata = json["usageMetadata"] as? [String: Any] {
if let prompt = usageMetadata["promptTokenCount"] as? Int {
inputTokens = prompt
}
if let candidates = usageMetadata["candidatesTokenCount"] as? Int {
outputTokens = candidates
}
}
guard let json = Self.decodeStreamLine(line) else { continue }
let events = Self.parseChunk(
json,
state: &state,
idGenerator: { UUID().uuidString }
)
for event in events { continuation.yield(event) }
}

if inputTokens > 0 || outputTokens > 0 {
continuation.yield(.usage(AITokenUsage(
inputTokens: inputTokens,
outputTokens: outputTokens
)))
if let usage = state.finalUsageEvent() {
continuation.yield(usage)
}

continuation.finish()
Expand Down Expand Up @@ -317,7 +282,69 @@ final class GeminiProvider: ChatTransport {
}
}

private func encodeArgsToJSONString(_ args: Any) -> String {

/// Maps an HTTP failure to an `AIProviderError`.
///
/// Status 403 is reported as an authentication failure; every other status
/// is delegated to `AIProviderError.mapHTTPError(statusCode:body:)`.
///
/// - Parameters:
///   - statusCode: The HTTP status code of the failed response.
///   - body: The response body as text.
/// - Returns: The error describing the failure.
func mapHTTPError(statusCode: Int, body: String) -> AIProviderError {
    guard statusCode == 403 else {
        return AIProviderError.mapHTTPError(statusCode: statusCode, body: body)
    }
    // Prefer the message parsed out of the body; fall back to the raw body.
    let detail = AIProviderError.parseErrorMessage(from: body) ?? body
    return .authenticationFailed(detail)
}

/// Decodes one Gemini SSE `data` line to a JSON object.
///
/// Returns `nil` for non-data lines and for payloads that are not a JSON
/// object. The single space after the `data:` field name is optional in the
/// SSE wire format, so both `data: {...}` and `data:{...}` are accepted.
///
/// - Parameter line: One line of the streamed response body.
/// - Returns: The decoded top-level JSON object, or `nil` when the line
///   does not carry one.
static func decodeStreamLine(_ line: String) -> [String: Any]? {
    let fieldPrefix = "data:"
    guard line.hasPrefix(fieldPrefix) else { return nil }

    var payload = line.dropFirst(fieldPrefix.count)
    // At most one leading space belongs to the SSE framing; anything after
    // it is payload (JSON parsing tolerates further leading whitespace).
    if payload.first == " " {
        payload = payload.dropFirst()
    }

    guard let json = try? JSONSerialization.jsonObject(with: Data(payload.utf8)) as? [String: Any]
    else { return nil }
    return json
}

/// Converts one decoded Gemini stream chunk into chat stream events.
///
/// Only the first candidate is read. For each of its parts, a non-empty
/// `text` becomes a `.textDelta`, and a `functionCall` becomes a complete
/// start / delta / end triple carrying the JSON-encoded arguments.
///
/// Gemini does not provide tool-call ids on `functionCall` parts, so one is
/// synthesized per call. `idGenerator` is injected so tests can pin the
/// synthetic id to a stable value; production passes `{ UUID().uuidString }`.
/// It is invoked once per `functionCall` part, so several calls in one chunk
/// get distinct ids whenever the generator returns fresh values.
///
/// - Parameters:
///   - json: The decoded JSON object of one stream chunk.
///   - state: Accumulator for the token counters found in `usageMetadata`.
///   - idGenerator: Supplies the id for each synthesized tool call.
/// - Returns: The events produced by this chunk, possibly empty.
static func parseChunk(
    _ json: [String: Any],
    state: inout GeminiStreamState,
    idGenerator: () -> String
) -> [ChatStreamEvent] {
    // Walk down to the first candidate's parts; any missing level means no parts.
    let candidateList = json["candidates"] as? [[String: Any]]
    let firstContent = candidateList?.first?["content"] as? [String: Any]
    let parts = (firstContent?["parts"] as? [[String: Any]]) ?? []

    var produced: [ChatStreamEvent] = []
    for part in parts {
        if let text = part["text"] as? String, !text.isEmpty {
            produced.append(.textDelta(text))
        }

        guard let call = part["functionCall"] as? [String: Any],
              let toolName = call["name"] as? String
        else { continue }

        let callId = idGenerator()
        // A call without `args` is encoded as an empty JSON object.
        let rawArgs = call["args"] ?? [String: Any]()
        let encodedArgs = encodeArgsToJSONString(rawArgs)
        produced += [
            .toolUseStart(id: callId, name: toolName),
            .toolUseDelta(id: callId, inputJSONDelta: encodedArgs),
            .toolUseEnd(id: callId)
        ]
    }

    guard let usage = json["usageMetadata"] as? [String: Any] else {
        return produced
    }
    if let promptCount = usage["promptTokenCount"] as? Int {
        state.inputTokens = promptCount
    }
    if let candidateCount = usage["candidatesTokenCount"] as? Int {
        state.outputTokens = candidateCount
    }
    return produced
}

static func encodeArgsToJSONString(_ args: Any) -> String {
guard JSONSerialization.isValidJSONObject(args) else {
Self.logger.warning("Gemini functionCall args was not a valid JSON object; falling back to empty input")
return "{}"
Expand All @@ -330,12 +357,15 @@ final class GeminiProvider: ChatTransport {
return "{}"
}
}
}

func mapHTTPError(statusCode: Int, body: String) -> AIProviderError {
if statusCode == 403 {
let message = AIProviderError.parseErrorMessage(from: body) ?? body
return .authenticationFailed(message)
}
return AIProviderError.mapHTTPError(statusCode: statusCode, body: body)
/// Accumulator threaded through successive `GeminiProvider.parseChunk` calls.
struct GeminiStreamState {
    /// Most recent prompt-token count reported by the stream.
    var inputTokens: Int = 0
    /// Most recent candidates-token count reported by the stream.
    var outputTokens: Int = 0

    /// Builds the closing usage event for the stream.
    ///
    /// - Returns: A `.usage` event with the recorded counters, or `nil`
    ///   when neither counter is positive.
    func finalUsageEvent() -> ChatStreamEvent? {
        if inputTokens <= 0 && outputTokens <= 0 {
            return nil
        }
        let totals = AITokenUsage(inputTokens: inputTokens, outputTokens: outputTokens)
        return .usage(totals)
    }
}
Loading
Loading