diff --git a/CHANGELOG.md b/CHANGELOG.md index 5f4830bc9..38c44cb9d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - AI Chat: new installations default to opt-in context. Schema, current query, and query results no longer auto-include in every prompt; attach them via the `@` menu when you want them. Existing users keep their current Settings -> AI -> "Include schema/current query/query results" choices unchanged. +- AI Chat: the Ask / Edit / Agent mode picker now changes which tools the provider sees. Ask exposes read-only schema lookups. Edit adds `execute_query` for SELECT/INSERT/UPDATE/DELETE. Agent adds destructive DDL via `confirm_destructive_operation`. The connection's safe mode policy still gates execution. ### Fixed diff --git a/TablePro/Core/AI/Chat/ChatToolRegistry.swift b/TablePro/Core/AI/Chat/ChatToolRegistry.swift index 6af88eaf9..358120593 100644 --- a/TablePro/Core/AI/Chat/ChatToolRegistry.swift +++ b/TablePro/Core/AI/Chat/ChatToolRegistry.swift @@ -13,6 +13,20 @@ final class ChatToolRegistry { private static let logger = Logger(subsystem: "com.TablePro", category: "ChatToolRegistry") + private static let readOnlyToolNames: Set = [ + "list_connections", + "get_connection_status", + "list_databases", + "list_schemas", + "list_tables", + "describe_table", + "get_table_ddl" + ] + + private static let editModeToolNames: Set = readOnlyToolNames.union([ + "execute_query" + ]) + private var tools: [String: any ChatTool] = [:] init() {} @@ -33,6 +47,11 @@ final class ChatToolRegistry { tools[name] } + func tool(named name: String, in mode: AIChatMode) -> (any ChatTool)? { + guard Self.isToolAllowed(name: name, in: mode) else { return nil } + return tools[name] + } + var allTools: [any ChatTool] { tools.values .sorted { $0.name < $1.name } @@ -41,4 +60,23 @@ final class ChatToolRegistry { var allSpecs: [ChatToolSpec] { allTools.map(\.spec) } + + func allTools(for mode: AIChatMode) -> [any ChatTool] { + allTools.filter { Self.isToolAllowed(name: $0.name, in: mode) } + } + + func allSpecs(for mode: AIChatMode) -> [ChatToolSpec] { + allTools(for: mode).map(\.spec) + } + + static func isToolAllowed(name: String, in mode: AIChatMode) -> Bool { + switch mode { + case .ask: + return readOnlyToolNames.contains(name) + case .edit: + return editModeToolNames.contains(name) + case .agent: + return true + } + } } diff --git a/TablePro/Models/AI/AIModels.swift b/TablePro/Models/AI/AIModels.swift index 6e5885154..9ba51e5f8 100644 --- a/TablePro/Models/AI/AIModels.swift +++ b/TablePro/Models/AI/AIModels.swift @@ -155,6 +155,28 @@ enum AIChatMode: String, Codable, CaseIterable, Identifiable, Sendable { case .agent: return "infinity" } } + + var helpText: String { + switch self { + case .ask: + return String(localized: "Ask: read-only schema lookups. AI can browse but not run queries.") + case .edit: + return String(localized: "Edit: read-only tools plus running queries. Destructive DDL stays blocked.") + case .agent: + return String(localized: "Agent: full tool access including destructive DDL. Safe mode still gates execution.") + } + } + + var systemPromptNote: String { + switch self { + case .ask: + return "You are in Ask mode. Tools are read-only: schema lookups only. You cannot run queries or modify data." + case .edit: + return "You are in Edit mode. You can read schema and run SELECT/INSERT/UPDATE/DELETE via execute_query. Destructive DDL is blocked." + case .agent: + return "You are in Agent mode. All tools are available, including destructive DDL via confirm_destructive_operation. Safe mode policy still gates execution." + } + } } // MARK: - AI Settings diff --git a/TablePro/ViewModels/AIChatViewModel.swift b/TablePro/ViewModels/AIChatViewModel.swift index a685c1ebc..d7638de18 100644 --- a/TablePro/ViewModels/AIChatViewModel.swift +++ b/TablePro/ViewModels/AIChatViewModel.swift @@ -828,9 +828,10 @@ final class AIChatViewModel { assistantID: UUID, settings: AISettings ) { + let chatMode = settings.chatMode streamingTask = Task.detached(priority: .userInitiated) { [weak self] in do { - let systemPrompt = Self.buildSystemPrompt(promptContext) + let systemPrompt = Self.buildSystemPrompt(promptContext, mode: chatMode) guard let self else { return } let preflightOK = await self.preflightCheck( systemPrompt: systemPrompt, @@ -839,7 +840,7 @@ final class AIChatViewModel { ) guard preflightOK else { return } - let toolSpecs = await MainActor.run { ChatToolRegistry.shared.allSpecs } + let toolSpecs = await MainActor.run { ChatToolRegistry.shared.allSpecs(for: chatMode) } var workingTurns = chatMessages var currentAssistantID = assistantID let flushInterval: ContinuousClock.Duration = .milliseconds(150) @@ -934,7 +935,7 @@ final class AIChatViewModel { authPolicy: ChatToolBootstrap.authPolicy ) } - let toolResultBlocks = await Self.executeToolUses(toolUseBlocks, context: context) + let toolResultBlocks = await Self.executeToolUses(toolUseBlocks, mode: chatMode, context: context) guard !Task.isCancelled else { return } let continuation = await self.appendToolRoundtrip( @@ -979,8 +980,8 @@ final class AIChatViewModel { } } - nonisolated private static func buildSystemPrompt(_ promptContext: PromptContext?) -> String? { - promptContext.map { + nonisolated private static func buildSystemPrompt(_ promptContext: PromptContext?, mode: AIChatMode) -> String? { + let schemaPrompt = promptContext.map { AISchemaContext.buildSystemPrompt( databaseType: $0.databaseType, databaseName: $0.databaseName, @@ -995,6 +996,9 @@ final class AIChatViewModel { queryLanguageName: $0.queryLanguageName ) } + let modeNote = mode.systemPromptNote + guard let schemaPrompt, !schemaPrompt.isEmpty else { return modeNote } + return "\(schemaPrompt)\n\n\(modeNote)" } private struct ToolRoundtripContinuation { @@ -1110,13 +1114,14 @@ final class AIChatViewModel { /// avoid polluting global state. nonisolated static func executeToolUses( _ blocks: [ToolUseBlock], + mode: AIChatMode, context: ChatToolContext, registry: ChatToolRegistry? = nil ) async -> [ToolResultBlock] { await withTaskGroup(of: (Int, ToolResultBlock).self) { group in for (index, block) in blocks.enumerated() { group.addTask { - (index, await runToolUse(block, context: context, registry: registry)) + (index, await runToolUse(block, mode: mode, context: context, registry: registry)) } } var indexed: [(Int, ToolResultBlock)] = [] @@ -1127,14 +1132,25 @@ final class AIChatViewModel { nonisolated private static func runToolUse( _ block: ToolUseBlock, + mode: AIChatMode, context: ChatToolContext, registry: ChatToolRegistry? ) async -> ToolResultBlock { if Task.isCancelled { return ToolResultBlock(toolUseId: block.id, content: "Cancelled", isError: true) } + guard ChatToolRegistry.isToolAllowed(name: block.name, in: mode) else { + Self.logger.warning( + "Tool '\(block.name, privacy: .public)' blocked in \(mode.rawValue, privacy: .public) mode" + ) + return ToolResultBlock( + toolUseId: block.id, + content: "Tool '\(block.name)' is not available in \(mode.displayName) mode", + isError: true + ) + } let tool = await MainActor.run { - (registry ?? ChatToolRegistry.shared).tool(named: block.name) + (registry ?? ChatToolRegistry.shared).tool(named: block.name, in: mode) } guard let tool else { Self.logger.warning("Tool '\(block.name, privacy: .public)' not registered; returning error") diff --git a/TablePro/Views/AIChat/AIChatPanelView.swift b/TablePro/Views/AIChat/AIChatPanelView.swift index 43b0f95ba..e1c28610d 100644 --- a/TablePro/Views/AIChat/AIChatPanelView.swift +++ b/TablePro/Views/AIChat/AIChatPanelView.swift @@ -278,7 +278,7 @@ struct AIChatPanelView: View { } .menuStyle(.borderlessButton) .fixedSize() - .help(String(localized: "Chat mode")) + .help(settingsManager.ai.chatMode.helpText) } @ViewBuilder diff --git a/TableProTests/Core/AI/ChatToolRegistryModeTests.swift b/TableProTests/Core/AI/ChatToolRegistryModeTests.swift new file mode 100644 index 000000000..97663c19b --- /dev/null +++ b/TableProTests/Core/AI/ChatToolRegistryModeTests.swift @@ -0,0 +1,99 @@ +// +// ChatToolRegistryModeTests.swift +// TableProTests +// + +import Foundation +@testable import TablePro +import Testing + +@Suite("ChatToolRegistry mode gating") +@MainActor +struct ChatToolRegistryModeTests { + private struct StubTool: ChatTool { + let name: String + let description = "" + let inputSchema: JSONValue = .object(["type": .string("object")]) + + func execute(input: JSONValue, context: ChatToolContext) async throws -> ChatToolResult { + ChatToolResult(content: "ok") + } + } + + private static let readOnlyToolNames: [String] = [ + "list_connections", + "get_connection_status", + "list_databases", + "list_schemas", + "list_tables", + "describe_table", + "get_table_ddl" + ] + + private static func makeRegistryWithAllTools() -> ChatToolRegistry { + let registry = ChatToolRegistry() + for name in readOnlyToolNames { + registry.register(StubTool(name: name)) + } + registry.register(StubTool(name: "execute_query")) + registry.register(StubTool(name: "confirm_destructive_operation")) + return registry + } + + @Test("Ask mode exposes only read-only tools") + func askModeReadOnly() { + let registry = Self.makeRegistryWithAllTools() + let names = Set(registry.allSpecs(for: .ask).map(\.name)) + #expect(names == Set(Self.readOnlyToolNames)) + #expect(!names.contains("execute_query")) + #expect(!names.contains("confirm_destructive_operation")) + } + + @Test("Edit mode adds execute_query but blocks confirm_destructive_operation") + func editModeAddsExecuteQuery() { + let registry = Self.makeRegistryWithAllTools() + let names = Set(registry.allSpecs(for: .edit).map(\.name)) + let expected = Set(Self.readOnlyToolNames + ["execute_query"]) + #expect(names == expected) + #expect(names.contains("execute_query")) + #expect(!names.contains("confirm_destructive_operation")) + } + + @Test("Agent mode exposes every registered tool including confirm_destructive_operation") + func agentModeExposesAll() { + let registry = Self.makeRegistryWithAllTools() + let names = Set(registry.allSpecs(for: .agent).map(\.name)) + let expected = Set(Self.readOnlyToolNames + ["execute_query", "confirm_destructive_operation"]) + #expect(names == expected) + #expect(names.contains("confirm_destructive_operation")) + } + + @Test("isToolAllowed agrees with allSpecs for every mode and tool name") + func isToolAllowedMatchesSpecs() { + let registry = Self.makeRegistryWithAllTools() + for mode in AIChatMode.allCases { + let allowedFromSpecs = Set(registry.allSpecs(for: mode).map(\.name)) + for tool in registry.allTools { + let allowed = ChatToolRegistry.isToolAllowed(name: tool.name, in: mode) + #expect(allowed == allowedFromSpecs.contains(tool.name)) + } + } + } + + @Test("tool(named:in:) returns nil for tools blocked by the mode") + func toolLookupRespectsMode() { + let registry = Self.makeRegistryWithAllTools() + #expect(registry.tool(named: "execute_query", in: .ask) == nil) + #expect(registry.tool(named: "execute_query", in: .edit)?.name == "execute_query") + #expect(registry.tool(named: "confirm_destructive_operation", in: .edit) == nil) + #expect(registry.tool(named: "confirm_destructive_operation", in: .agent)?.name == "confirm_destructive_operation") + #expect(registry.tool(named: "list_tables", in: .ask)?.name == "list_tables") + } + + @Test("Unknown tool names are not allowed in any mode except agent") + func unknownToolsBlockedOutsideAgent() { + #expect(ChatToolRegistry.isToolAllowed(name: "future_tool", in: .ask) == false) + #expect(ChatToolRegistry.isToolAllowed(name: "future_tool", in: .edit) == false) + #expect(ChatToolRegistry.isToolAllowed(name: "future_tool", in: .agent) == true) + } +} diff --git a/TableProTests/Core/AI/ExecuteToolUsesTests.swift b/TableProTests/Core/AI/ExecuteToolUsesTests.swift index f82ba5273..95454c33d 100644 --- a/TableProTests/Core/AI/ExecuteToolUsesTests.swift +++ b/TableProTests/Core/AI/ExecuteToolUsesTests.swift @@ -61,6 +61,7 @@ struct ExecuteToolUsesTests { let blocks = [ToolUseBlock(id: "u1", name: "alpha", input: .object([:]))] let results = await AIChatViewModel.executeToolUses( blocks, + mode: .agent, context: makeContext(), registry: registry ) @@ -83,6 +84,7 @@ struct ExecuteToolUsesTests { ] let results = await AIChatViewModel.executeToolUses( blocks, + mode: .agent, context: makeContext(), registry: registry ) @@ -96,6 +98,7 @@ struct ExecuteToolUsesTests { let blocks = [ToolUseBlock(id: "u1", name: "ghost", input: .object([:]))] let results = await AIChatViewModel.executeToolUses( blocks, + mode: .agent, context: makeContext(), registry: registry ) @@ -111,6 +114,7 @@ struct ExecuteToolUsesTests { let blocks = [ToolUseBlock(id: "u1", name: "boom", input: .object([:]))] let results = await AIChatViewModel.executeToolUses( blocks, + mode: .agent, context: makeContext(), registry: registry ) @@ -126,6 +130,7 @@ struct ExecuteToolUsesTests { let blocks = [ToolUseBlock(id: "u1", name: "warn", input: .object([:]))] let results = await AIChatViewModel.executeToolUses( blocks, + mode: .agent, context: makeContext(), registry: registry ) @@ -143,6 +148,7 @@ struct ExecuteToolUsesTests { ] let results = await AIChatViewModel.executeToolUses( blocks, + mode: .agent, context: makeContext(), registry: registry ) @@ -160,6 +166,7 @@ struct ExecuteToolUsesTests { let input: JSONValue = .object(["query": .string("SELECT 1")]) _ = await AIChatViewModel.executeToolUses( [ToolUseBlock(id: "u1", name: "alpha", input: input)], + mode: .agent, context: makeContext(), registry: registry ) @@ -172,9 +179,44 @@ struct ExecuteToolUsesTests { let registry = ChatToolRegistry() let results = await AIChatViewModel.executeToolUses( [], + mode: .agent, context: makeContext(), registry: registry ) #expect(results.isEmpty) } + + @Test("execute_query blocked in Ask mode returns isError result without invoking tool") + func askModeBlocksExecuteQuery() async { + let registry = ChatToolRegistry() + let stub = StubTool(name: "execute_query", response: "should-not-run") + registry.register(stub) + let blocks = [ToolUseBlock(id: "u1", name: "execute_query", input: .object([:]))] + let results = await AIChatViewModel.executeToolUses( + blocks, + mode: .ask, + context: makeContext(), + registry: registry + ) + #expect(results.count == 1) + #expect(results[0].isError == true) + #expect(stub.invocations.isEmpty) + } + + @Test("confirm_destructive_operation blocked in Edit mode returns isError result") + func editModeBlocksDestructiveConfirm() async { + let registry = ChatToolRegistry() + let stub = StubTool(name: "confirm_destructive_operation", response: "should-not-run") + registry.register(stub) + let blocks = [ToolUseBlock(id: "u1", name: "confirm_destructive_operation", input: .object([:]))] + let results = await AIChatViewModel.executeToolUses( + blocks, + mode: .edit, + context: makeContext(), + registry: registry + ) + #expect(results.count == 1) + #expect(results[0].isError == true) + #expect(stub.invocations.isEmpty) + } }