Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed

- AI Chat: new installations default to opt-in context. Schema, current query, and query results no longer auto-include in every prompt; attach them via the `@` menu when you want them. Existing users keep their current Settings -> AI -> "Include schema/current query/query results" choices unchanged.
- AI Chat: the Ask / Edit / Agent mode picker now changes which tools the provider sees. Ask exposes read-only schema lookups. Edit adds `execute_query` for SELECT/INSERT/UPDATE/DELETE. Agent adds destructive DDL via `confirm_destructive_operation`. The connection's safe mode policy still gates execution.

### Fixed

Expand Down
38 changes: 38 additions & 0 deletions TablePro/Core/AI/Chat/ChatToolRegistry.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,20 @@ final class ChatToolRegistry {

private static let logger = Logger(subsystem: "com.TablePro", category: "ChatToolRegistry")

private static let readOnlyToolNames: Set<String> = [
"list_connections",
"get_connection_status",
"list_databases",
"list_schemas",
"list_tables",
"describe_table",
"get_table_ddl"
]

private static let editModeToolNames: Set<String> = readOnlyToolNames.union([
"execute_query"
])

private var tools: [String: any ChatTool] = [:]

init() {}
Expand All @@ -33,6 +47,11 @@ final class ChatToolRegistry {
tools[name]
}

func tool(named name: String, in mode: AIChatMode) -> (any ChatTool)? {
guard Self.isToolAllowed(name: name, in: mode) else { return nil }
return tools[name]
}

var allTools: [any ChatTool] {
tools.values
.sorted { $0.name < $1.name }
Expand All @@ -41,4 +60,23 @@ final class ChatToolRegistry {
var allSpecs: [ChatToolSpec] {
allTools.map(\.spec)
}

func allTools(for mode: AIChatMode) -> [any ChatTool] {
allTools.filter { Self.isToolAllowed(name: $0.name, in: mode) }
}

func allSpecs(for mode: AIChatMode) -> [ChatToolSpec] {
allTools(for: mode).map(\.spec)
}

static func isToolAllowed(name: String, in mode: AIChatMode) -> Bool {
switch mode {
case .ask:
return readOnlyToolNames.contains(name)
case .edit:
return editModeToolNames.contains(name)
case .agent:
return true
}
}
}
22 changes: 22 additions & 0 deletions TablePro/Models/AI/AIModels.swift
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,28 @@ enum AIChatMode: String, Codable, CaseIterable, Identifiable, Sendable {
case .agent: return "infinity"
}
}

var helpText: String {
switch self {
case .ask:
return String(localized: "Ask: read-only schema lookups. AI can browse but not run queries.")
case .edit:
return String(localized: "Edit: read-only tools plus running queries. Destructive DDL stays blocked.")
case .agent:
return String(localized: "Agent: full tool access including destructive DDL. Safe mode still gates execution.")
}
}

var systemPromptNote: String {
switch self {
case .ask:
return "You are in Ask mode. Tools are read-only: schema lookups only. You cannot run queries or modify data."
case .edit:
return "You are in Edit mode. You can read schema and run SELECT/INSERT/UPDATE/DELETE via execute_query. Destructive DDL is blocked."
case .agent:
return "You are in Agent mode. All tools are available, including destructive DDL via confirm_destructive_operation. Safe mode policy still gates execution."
}
}
}

// MARK: - AI Settings
Expand Down
30 changes: 23 additions & 7 deletions TablePro/ViewModels/AIChatViewModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -828,9 +828,10 @@ final class AIChatViewModel {
assistantID: UUID,
settings: AISettings
) {
let chatMode = settings.chatMode
streamingTask = Task.detached(priority: .userInitiated) { [weak self] in
do {
let systemPrompt = Self.buildSystemPrompt(promptContext)
let systemPrompt = Self.buildSystemPrompt(promptContext, mode: chatMode)
guard let self else { return }
let preflightOK = await self.preflightCheck(
systemPrompt: systemPrompt,
Expand All @@ -839,7 +840,7 @@ final class AIChatViewModel {
)
guard preflightOK else { return }

let toolSpecs = await MainActor.run { ChatToolRegistry.shared.allSpecs }
let toolSpecs = await MainActor.run { ChatToolRegistry.shared.allSpecs(for: chatMode) }
var workingTurns = chatMessages
var currentAssistantID = assistantID
let flushInterval: ContinuousClock.Duration = .milliseconds(150)
Expand Down Expand Up @@ -934,7 +935,7 @@ final class AIChatViewModel {
authPolicy: ChatToolBootstrap.authPolicy
)
}
let toolResultBlocks = await Self.executeToolUses(toolUseBlocks, context: context)
let toolResultBlocks = await Self.executeToolUses(toolUseBlocks, mode: chatMode, context: context)
guard !Task.isCancelled else { return }

let continuation = await self.appendToolRoundtrip(
Expand Down Expand Up @@ -979,8 +980,8 @@ final class AIChatViewModel {
}
}

nonisolated private static func buildSystemPrompt(_ promptContext: PromptContext?) -> String? {
promptContext.map {
nonisolated private static func buildSystemPrompt(_ promptContext: PromptContext?, mode: AIChatMode) -> String? {
let schemaPrompt = promptContext.map {
AISchemaContext.buildSystemPrompt(
databaseType: $0.databaseType,
databaseName: $0.databaseName,
Expand All @@ -995,6 +996,9 @@ final class AIChatViewModel {
queryLanguageName: $0.queryLanguageName
)
}
let modeNote = mode.systemPromptNote
guard let schemaPrompt, !schemaPrompt.isEmpty else { return modeNote }
return "\(schemaPrompt)\n\n\(modeNote)"
}

private struct ToolRoundtripContinuation {
Expand Down Expand Up @@ -1110,13 +1114,14 @@ final class AIChatViewModel {
/// avoid polluting global state.
nonisolated static func executeToolUses(
_ blocks: [ToolUseBlock],
mode: AIChatMode,
context: ChatToolContext,
registry: ChatToolRegistry? = nil
) async -> [ToolResultBlock] {
await withTaskGroup(of: (Int, ToolResultBlock).self) { group in
for (index, block) in blocks.enumerated() {
group.addTask {
(index, await runToolUse(block, context: context, registry: registry))
(index, await runToolUse(block, mode: mode, context: context, registry: registry))
}
}
var indexed: [(Int, ToolResultBlock)] = []
Expand All @@ -1127,14 +1132,25 @@ final class AIChatViewModel {

nonisolated private static func runToolUse(
_ block: ToolUseBlock,
mode: AIChatMode,
context: ChatToolContext,
registry: ChatToolRegistry?
) async -> ToolResultBlock {
if Task.isCancelled {
return ToolResultBlock(toolUseId: block.id, content: "Cancelled", isError: true)
}
guard ChatToolRegistry.isToolAllowed(name: block.name, in: mode) else {
Self.logger.warning(
"Tool '\(block.name, privacy: .public)' blocked in \(mode.rawValue, privacy: .public) mode"
)
return ToolResultBlock(
toolUseId: block.id,
content: "Tool '\(block.name)' is not available in \(mode.displayName) mode",
isError: true
)
}
let tool = await MainActor.run {
(registry ?? ChatToolRegistry.shared).tool(named: block.name)
(registry ?? ChatToolRegistry.shared).tool(named: block.name, in: mode)
}
guard let tool else {
Self.logger.warning("Tool '\(block.name, privacy: .public)' not registered; returning error")
Expand Down
2 changes: 1 addition & 1 deletion TablePro/Views/AIChat/AIChatPanelView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ struct AIChatPanelView: View {
}
.menuStyle(.borderlessButton)
.fixedSize()
.help(String(localized: "Chat mode"))
.help(settingsManager.ai.chatMode.helpText)
}

@ViewBuilder
Expand Down
99 changes: 99 additions & 0 deletions TableProTests/Core/AI/ChatToolRegistryModeTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
//
// ChatToolRegistryModeTests.swift
// TableProTests
//

import Foundation
@testable import TablePro
import Testing

@Suite("ChatToolRegistry mode gating")
@MainActor
struct ChatToolRegistryModeTests {
private struct StubTool: ChatTool {
let name: String
let description = ""
let inputSchema: JSONValue = .object(["type": .string("object")])

func execute(input: JSONValue, context: ChatToolContext) async throws -> ChatToolResult {
ChatToolResult(content: "ok")
}
}

private static let readOnlyToolNames: [String] = [
"list_connections",
"get_connection_status",
"list_databases",
"list_schemas",
"list_tables",
"describe_table",
"get_table_ddl"
]

private static func makeRegistryWithAllTools() -> ChatToolRegistry {
let registry = ChatToolRegistry()
for name in readOnlyToolNames {
registry.register(StubTool(name: name))
}
registry.register(StubTool(name: "execute_query"))
registry.register(StubTool(name: "confirm_destructive_operation"))
return registry
}

@Test("Ask mode exposes only read-only tools")
func askModeReadOnly() {
let registry = Self.makeRegistryWithAllTools()
let names = Set(registry.allSpecs(for: .ask).map(\.name))
#expect(names == Set(Self.readOnlyToolNames))
#expect(!names.contains("execute_query"))
#expect(!names.contains("confirm_destructive_operation"))
}

@Test("Edit mode adds execute_query but blocks confirm_destructive_operation")
func editModeAddsExecuteQuery() {
let registry = Self.makeRegistryWithAllTools()
let names = Set(registry.allSpecs(for: .edit).map(\.name))
let expected = Set(Self.readOnlyToolNames + ["execute_query"])
#expect(names == expected)
#expect(names.contains("execute_query"))
#expect(!names.contains("confirm_destructive_operation"))
}

@Test("Agent mode exposes every registered tool including confirm_destructive_operation")
func agentModeExposesAll() {
let registry = Self.makeRegistryWithAllTools()
let names = Set(registry.allSpecs(for: .agent).map(\.name))
let expected = Set(Self.readOnlyToolNames + ["execute_query", "confirm_destructive_operation"])
#expect(names == expected)
#expect(names.contains("confirm_destructive_operation"))
}

@Test("isToolAllowed agrees with allSpecs for every mode and tool name")
func isToolAllowedMatchesSpecs() {
let registry = Self.makeRegistryWithAllTools()
for mode in AIChatMode.allCases {
let allowedFromSpecs = Set(registry.allSpecs(for: mode).map(\.name))
for tool in registry.allTools {
let allowed = ChatToolRegistry.isToolAllowed(name: tool.name, in: mode)
#expect(allowed == allowedFromSpecs.contains(tool.name))
}
}
}

@Test("tool(named:in:) returns nil for tools blocked by the mode")
func toolLookupRespectsMode() {
let registry = Self.makeRegistryWithAllTools()
#expect(registry.tool(named: "execute_query", in: .ask) == nil)
#expect(registry.tool(named: "execute_query", in: .edit)?.name == "execute_query")
#expect(registry.tool(named: "confirm_destructive_operation", in: .edit) == nil)
#expect(registry.tool(named: "confirm_destructive_operation", in: .agent)?.name == "confirm_destructive_operation")
#expect(registry.tool(named: "list_tables", in: .ask)?.name == "list_tables")
}

@Test("Unknown tool names are not allowed in any mode except agent")
func unknownToolsBlockedOutsideAgent() {
#expect(ChatToolRegistry.isToolAllowed(name: "future_tool", in: .ask) == false)
#expect(ChatToolRegistry.isToolAllowed(name: "future_tool", in: .edit) == false)
#expect(ChatToolRegistry.isToolAllowed(name: "future_tool", in: .agent) == true)
}
}
42 changes: 42 additions & 0 deletions TableProTests/Core/AI/ExecuteToolUsesTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ struct ExecuteToolUsesTests {
let blocks = [ToolUseBlock(id: "u1", name: "alpha", input: .object([:]))]
let results = await AIChatViewModel.executeToolUses(
blocks,
mode: .agent,
context: makeContext(),
registry: registry
)
Expand All @@ -83,6 +84,7 @@ struct ExecuteToolUsesTests {
]
let results = await AIChatViewModel.executeToolUses(
blocks,
mode: .agent,
context: makeContext(),
registry: registry
)
Expand All @@ -96,6 +98,7 @@ struct ExecuteToolUsesTests {
let blocks = [ToolUseBlock(id: "u1", name: "ghost", input: .object([:]))]
let results = await AIChatViewModel.executeToolUses(
blocks,
mode: .agent,
context: makeContext(),
registry: registry
)
Expand All @@ -111,6 +114,7 @@ struct ExecuteToolUsesTests {
let blocks = [ToolUseBlock(id: "u1", name: "boom", input: .object([:]))]
let results = await AIChatViewModel.executeToolUses(
blocks,
mode: .agent,
context: makeContext(),
registry: registry
)
Expand All @@ -126,6 +130,7 @@ struct ExecuteToolUsesTests {
let blocks = [ToolUseBlock(id: "u1", name: "warn", input: .object([:]))]
let results = await AIChatViewModel.executeToolUses(
blocks,
mode: .agent,
context: makeContext(),
registry: registry
)
Expand All @@ -143,6 +148,7 @@ struct ExecuteToolUsesTests {
]
let results = await AIChatViewModel.executeToolUses(
blocks,
mode: .agent,
context: makeContext(),
registry: registry
)
Expand All @@ -160,6 +166,7 @@ struct ExecuteToolUsesTests {
let input: JSONValue = .object(["query": .string("SELECT 1")])
_ = await AIChatViewModel.executeToolUses(
[ToolUseBlock(id: "u1", name: "alpha", input: input)],
mode: .agent,
context: makeContext(),
registry: registry
)
Expand All @@ -172,9 +179,44 @@ struct ExecuteToolUsesTests {
let registry = ChatToolRegistry()
let results = await AIChatViewModel.executeToolUses(
[],
mode: .agent,
context: makeContext(),
registry: registry
)
#expect(results.isEmpty)
}

@Test("execute_query blocked in Ask mode returns isError result without invoking tool")
func askModeBlocksExecuteQuery() async {
let registry = ChatToolRegistry()
let stub = StubTool(name: "execute_query", response: "should-not-run")
registry.register(stub)
let blocks = [ToolUseBlock(id: "u1", name: "execute_query", input: .object([:]))]
let results = await AIChatViewModel.executeToolUses(
blocks,
mode: .ask,
context: makeContext(),
registry: registry
)
#expect(results.count == 1)
#expect(results[0].isError == true)
#expect(stub.invocations.isEmpty)
}

@Test("confirm_destructive_operation blocked in Edit mode returns isError result")
func editModeBlocksDestructiveConfirm() async {
let registry = ChatToolRegistry()
let stub = StubTool(name: "confirm_destructive_operation", response: "should-not-run")
registry.register(stub)
let blocks = [ToolUseBlock(id: "u1", name: "confirm_destructive_operation", input: .object([:]))]
let results = await AIChatViewModel.executeToolUses(
blocks,
mode: .edit,
context: makeContext(),
registry: registry
)
#expect(results.count == 1)
#expect(results[0].isError == true)
#expect(stub.invocations.isEmpty)
}
}
Loading