From 58bc7e6898e311a76f8687e093a23fa64844c9f9 Mon Sep 17 00:00:00 2001 From: Ngo Quoc Dat Date: Thu, 7 May 2026 19:32:59 +0700 Subject: [PATCH] feat: per-connection AI rules --- CHANGELOG.md | 1 + TablePro/Core/AI/AISchemaContext.swift | 8 +- TablePro/Core/Storage/ConnectionStorage.swift | 1 + .../Connection/DatabaseConnection.swift | 7 +- TablePro/ViewModels/AIChatViewModel.swift | 7 +- .../ConnectionFormCoordinator.swift | 6 + .../ConnectionForm/ConnectionFormPane.swift | 5 + .../ConnectionForm/ConnectionFormView.swift | 2 + .../Panes/AIRulesPaneView.swift | 91 ++++++++++++ .../ViewModels/AIRulesPaneViewModel.swift | 23 +++ .../DatabaseConnectionAIRulesTests.swift | 138 ++++++++++++++++++ docs/features/ai-assistant.mdx | 14 ++ 12 files changed, 299 insertions(+), 4 deletions(-) create mode 100644 TablePro/Views/ConnectionForm/Panes/AIRulesPaneView.swift create mode 100644 TablePro/Views/ConnectionForm/ViewModels/AIRulesPaneViewModel.swift create mode 100644 TableProTests/Models/DatabaseConnectionAIRulesTests.swift diff --git a/CHANGELOG.md b/CHANGELOG.md index 5f4830bc9..7bee92efc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - AI Chat: attach a saved query as a chip via `@`. Type `@` and pick a saved SQL query to send its name and body to the AI alongside your message. - AI Chat: user-defined slash commands. Create your own commands in Settings -> AI -> Custom Slash Commands. Templates support `{{query}}`, `{{schema}}`, `{{database}}`, and `{{body}}` placeholders that get substituted at send time. - AI Chat: tool calling can now run write queries (`execute_query`) and destructive DDL (`confirm_destructive_operation` after the AI passes the verbatim phrase). The connection's safe mode policy still gates execution, so the user remains the final approver. +- AI Chat: per-connection rules. Add custom guidance (table conventions, PII columns, naming) in the connection's AI Rules tab; the AI sees it on every chat turn for that connection. ### Changed diff --git a/TablePro/Core/AI/AISchemaContext.swift b/TablePro/Core/AI/AISchemaContext.swift index fdfac1e6d..f921e450e 100644 --- a/TablePro/Core/AI/AISchemaContext.swift +++ b/TablePro/Core/AI/AISchemaContext.swift @@ -24,7 +24,8 @@ struct AISchemaContext { settings: AISettings, identifierQuote: String = "\"", editorLanguage: EditorLanguage, - queryLanguageName: String + queryLanguageName: String, + connectionRules: String? = nil ) -> String { var parts: [String] = [] @@ -67,6 +68,11 @@ struct AISchemaContext { parts.append("\n## Recent Query Results\n\(results)") } + if let rules = connectionRules?.trimmingCharacters(in: .whitespacesAndNewlines), + !rules.isEmpty { + parts.append("\n## Connection-Specific Rules\n\(rules)") + } + let langTag = editorLanguage.codeBlockTag switch editorLanguage { diff --git a/TablePro/Core/Storage/ConnectionStorage.swift b/TablePro/Core/Storage/ConnectionStorage.swift index 541a3b44a..45d57ec97 100644 --- a/TablePro/Core/Storage/ConnectionStorage.swift +++ b/TablePro/Core/Storage/ConnectionStorage.swift @@ -221,6 +221,7 @@ final class ConnectionStorage { sshTunnelMode: connection.sshTunnelMode, safeModeLevel: connection.safeModeLevel, aiPolicy: connection.aiPolicy, + aiRules: connection.aiRules, redisDatabase: connection.redisDatabase, startupCommands: connection.startupCommands, sortOrder: connection.sortOrder, diff --git a/TablePro/Models/Connection/DatabaseConnection.swift b/TablePro/Models/Connection/DatabaseConnection.swift index fcc759a1b..e03884ed4 100644 --- a/TablePro/Models/Connection/DatabaseConnection.swift +++ b/TablePro/Models/Connection/DatabaseConnection.swift @@ -276,6 +276,7 @@ struct DatabaseConnection: Identifiable, Hashable { var sshTunnelMode: SSHTunnelMode var safeModeLevel: SafeModeLevel var aiPolicy: AIConnectionPolicy? + var aiRules: String? var externalAccess: ExternalAccessLevel = .readOnly var additionalFields: [String: String] = [:] var redisDatabase: Int? @@ -356,6 +357,7 @@ struct DatabaseConnection: Identifiable, Hashable { sshTunnelMode: SSHTunnelMode = .disabled, safeModeLevel: SafeModeLevel = .silent, aiPolicy: AIConnectionPolicy? = nil, + aiRules: String? = nil, externalAccess: ExternalAccessLevel = .readOnly, mongoAuthSource: String? = nil, mongoReadPreference: String? = nil, @@ -402,6 +404,7 @@ struct DatabaseConnection: Identifiable, Hashable { self.sshTunnelMode = sshTunnelMode } self.aiPolicy = aiPolicy + self.aiRules = aiRules self.externalAccess = externalAccess self.redisDatabase = redisDatabase self.startupCommands = startupCommands @@ -454,7 +457,7 @@ extension DatabaseConnection: Codable { private enum CodingKeys: String, CodingKey { case id, name, host, port, database, username, type case sshConfig, sslConfig, color, tagId, groupId, sshProfileId - case sshTunnelMode, safeModeLevel, aiPolicy, externalAccess, additionalFields + case sshTunnelMode, safeModeLevel, aiPolicy, aiRules, externalAccess, additionalFields case redisDatabase, startupCommands, sortOrder, localOnly, isSample } @@ -475,6 +478,7 @@ extension DatabaseConnection: Codable { sshProfileId = try container.decodeIfPresent(UUID.self, forKey: .sshProfileId) safeModeLevel = try container.decodeIfPresent(SafeModeLevel.self, forKey: .safeModeLevel) ?? .silent aiPolicy = try container.decodeIfPresent(AIConnectionPolicy.self, forKey: .aiPolicy) + aiRules = try container.decodeIfPresent(String.self, forKey: .aiRules) externalAccess = try container.decodeIfPresent(ExternalAccessLevel.self, forKey: .externalAccess) ?? .readOnly additionalFields = try container.decodeIfPresent([String: String].self, forKey: .additionalFields) ?? [:] redisDatabase = try container.decodeIfPresent(Int.self, forKey: .redisDatabase) @@ -517,6 +521,7 @@ extension DatabaseConnection: Codable { try container.encode(sshTunnelMode, forKey: .sshTunnelMode) try container.encode(safeModeLevel, forKey: .safeModeLevel) try container.encodeIfPresent(aiPolicy, forKey: .aiPolicy) + try container.encodeIfPresent(aiRules, forKey: .aiRules) try container.encode(externalAccess, forKey: .externalAccess) try container.encode(additionalFields, forKey: .additionalFields) try container.encodeIfPresent(redisDatabase, forKey: .redisDatabase) diff --git a/TablePro/ViewModels/AIChatViewModel.swift b/TablePro/ViewModels/AIChatViewModel.swift index a685c1ebc..d1672224d 100644 --- a/TablePro/ViewModels/AIChatViewModel.swift +++ b/TablePro/ViewModels/AIChatViewModel.swift @@ -992,7 +992,8 @@ final class AIChatViewModel { settings: $0.settings, identifierQuote: $0.identifierQuote, editorLanguage: $0.editorLanguage, - queryLanguageName: $0.queryLanguageName + queryLanguageName: $0.queryLanguageName, + connectionRules: $0.connectionRules ) } } @@ -1189,6 +1190,7 @@ final class AIChatViewModel { let identifierQuote: String let editorLanguage: EditorLanguage let queryLanguageName: String + let connectionRules: String? } private func capturePromptContext(settings: AISettings) -> PromptContext? { @@ -1204,7 +1206,8 @@ final class AIChatViewModel { settings: settings, identifierQuote: PluginManager.shared.sqlDialect(for: connection.type)?.identifierQuote ?? "\"", editorLanguage: PluginManager.shared.editorLanguage(for: connection.type), - queryLanguageName: PluginManager.shared.queryLanguageName(for: connection.type) + queryLanguageName: PluginManager.shared.queryLanguageName(for: connection.type), + connectionRules: connection.aiRules ) } } diff --git a/TablePro/Views/ConnectionForm/ConnectionFormCoordinator.swift b/TablePro/Views/ConnectionForm/ConnectionFormCoordinator.swift index 4f8e2b72e..5f7db87f2 100644 --- a/TablePro/Views/ConnectionForm/ConnectionFormCoordinator.swift +++ b/TablePro/Views/ConnectionForm/ConnectionFormCoordinator.swift @@ -31,6 +31,7 @@ final class ConnectionFormCoordinator { var ssl: SSLPaneViewModel var customization: CustomizationPaneViewModel var advanced: AdvancedPaneViewModel + var aiRules: AIRulesPaneViewModel var selectedPane: ConnectionFormPane = .general var hasLoadedData: Bool = false @@ -67,6 +68,7 @@ final class ConnectionFormCoordinator { } panes.append(.customization) panes.append(.advanced) + panes.append(.aiRules) return panes } @@ -91,6 +93,7 @@ final class ConnectionFormCoordinator { self.ssl = SSLPaneViewModel() self.customization = CustomizationPaneViewModel() self.advanced = AdvancedPaneViewModel() + self.aiRules = AIRulesPaneViewModel() let ref = WeakCoordinatorRef(self) network.coordinator = ref @@ -99,6 +102,7 @@ final class ConnectionFormCoordinator { ssl.coordinator = ref customization.coordinator = ref advanced.coordinator = ref + aiRules.coordinator = ref let resolvedInitialType = initialParsedURL?.type ?? initialType if let resolvedInitialType { @@ -135,6 +139,7 @@ final class ConnectionFormCoordinator { ssl.load(from: existing) customization.load(from: existing) advanced.load(from: existing) + aiRules.load(from: existing) } hasLoadedData = true } @@ -250,6 +255,7 @@ final class ConnectionFormCoordinator { sshTunnelMode: sshTunnelMode, safeModeLevel: customization.safeModeLevel, aiPolicy: advanced.aiPolicy, + aiRules: aiRules.trimmedRules, externalAccess: advanced.externalAccess, redisDatabase: advanced.additionalFieldValues["redisDatabase"].map { Int($0) ?? 0 }, startupCommands: advanced.startupCommands.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty diff --git a/TablePro/Views/ConnectionForm/ConnectionFormPane.swift b/TablePro/Views/ConnectionForm/ConnectionFormPane.swift index 96fa4b1e7..7a2d0e37d 100644 --- a/TablePro/Views/ConnectionForm/ConnectionFormPane.swift +++ b/TablePro/Views/ConnectionForm/ConnectionFormPane.swift @@ -11,6 +11,7 @@ enum ConnectionFormPane: String, CaseIterable, Identifiable, Hashable { case ssl case customization case advanced + case aiRules var id: String { rawValue } @@ -21,6 +22,7 @@ enum ConnectionFormPane: String, CaseIterable, Identifiable, Hashable { case .ssl: return String(localized: "SSL/TLS") case .customization: return String(localized: "Customization") case .advanced: return String(localized: "Advanced") + case .aiRules: return String(localized: "AI Rules") } } @@ -31,6 +33,7 @@ enum ConnectionFormPane: String, CaseIterable, Identifiable, Hashable { case .ssl: return "lock.fill" case .customization: return "paintbrush" case .advanced: return "gearshape.2" + case .aiRules: return "sparkles" } } @@ -48,6 +51,8 @@ enum ConnectionFormPane: String, CaseIterable, Identifiable, Hashable { issues = coordinator.customization.validationIssues case .advanced: issues = coordinator.advanced.validationIssues + case .aiRules: + issues = [] } return issues.isEmpty ? nil : "exclamationmark.triangle.fill" } diff --git a/TablePro/Views/ConnectionForm/ConnectionFormView.swift b/TablePro/Views/ConnectionForm/ConnectionFormView.swift index ef807f21a..629c33579 100644 --- a/TablePro/Views/ConnectionForm/ConnectionFormView.swift +++ b/TablePro/Views/ConnectionForm/ConnectionFormView.swift @@ -90,6 +90,8 @@ private struct ConnectionFormDetail: View { CustomizationPaneView(coordinator: coordinator) case .advanced: AdvancedPaneView(coordinator: coordinator) + case .aiRules: + AIRulesPaneView(coordinator: coordinator) } } .navigationSplitViewColumnWidth(min: 480, ideal: 580) diff --git a/TablePro/Views/ConnectionForm/Panes/AIRulesPaneView.swift b/TablePro/Views/ConnectionForm/Panes/AIRulesPaneView.swift new file mode 100644 index 000000000..fc7fe7fb0 --- /dev/null +++ b/TablePro/Views/ConnectionForm/Panes/AIRulesPaneView.swift @@ -0,0 +1,91 @@ +// +// AIRulesPaneView.swift +// TablePro +// + +import AppKit +import SwiftUI + +struct AIRulesPaneView: View { + @Bindable var coordinator: ConnectionFormCoordinator + + var body: some View { + Form { + Section { + AIRulesEditor(text: $coordinator.aiRules.rules) + .frame(minHeight: 280) + } header: { + Text(String(localized: "Rules")) + } footer: { + VStack(alignment: .leading, spacing: 4) { + // swiftlint:disable:next line_length + Text("Custom guidance the AI sees on every chat turn for this connection. Use it for table conventions, naming, columns to avoid (PII, soft-deleted rows), join hints, or business rules the schema doesn't show.") + Text(String(localized: "Plain text. Markdown is preserved as written.")) + } + .font(.caption) + .foregroundStyle(.secondary) + } + + Section { + // swiftlint:disable:next line_length + Text(verbatim: "- Tables prefixed with `tmp_` are scratch and safe to ignore\n- `users.email_hash` is the join key, not `users.email`\n- Always filter `orders` by `deleted_at IS NULL`\n- Never select `users.ssn`") + .font(.system(.caption, design: .monospaced)) + .foregroundStyle(.secondary) + .frame(maxWidth: .infinity, alignment: .leading) + .textSelection(.enabled) + } header: { + Text(String(localized: "Examples")) + } + } + .formStyle(.grouped) + .scrollContentBackground(.hidden) + } +} + +private struct AIRulesEditor: NSViewRepresentable { + @Binding var text: String + + func makeNSView(context: Context) -> NSScrollView { + let scrollView = NSTextView.scrollableTextView() + guard let textView = scrollView.documentView as? NSTextView else { return scrollView } + + textView.font = .monospacedSystemFont(ofSize: NSFont.systemFontSize, weight: .regular) + textView.isAutomaticQuoteSubstitutionEnabled = false + textView.isAutomaticDashSubstitutionEnabled = false + textView.isAutomaticTextReplacementEnabled = false + textView.isAutomaticSpellingCorrectionEnabled = false + textView.isRichText = false + textView.string = text + textView.textContainerInset = NSSize(width: 4, height: 6) + textView.delegate = context.coordinator + + scrollView.borderType = .bezelBorder + scrollView.hasVerticalScroller = true + + return scrollView + } + + func updateNSView(_ scrollView: NSScrollView, context: Context) { + guard let textView = scrollView.documentView as? NSTextView else { return } + if textView.string != text { + textView.string = text + } + } + + func makeCoordinator() -> Coordinator { + Coordinator(text: $text) + } + + final class Coordinator: NSObject, NSTextViewDelegate { + private var text: Binding + + init(text: Binding) { + self.text = text + } + + func textDidChange(_ notification: Notification) { + guard let textView = notification.object as? NSTextView else { return } + text.wrappedValue = textView.string + } + } +} diff --git a/TablePro/Views/ConnectionForm/ViewModels/AIRulesPaneViewModel.swift b/TablePro/Views/ConnectionForm/ViewModels/AIRulesPaneViewModel.swift new file mode 100644 index 000000000..47c67b767 --- /dev/null +++ b/TablePro/Views/ConnectionForm/ViewModels/AIRulesPaneViewModel.swift @@ -0,0 +1,23 @@ +// +// AIRulesPaneViewModel.swift +// TablePro +// + +import Foundation + +@Observable +@MainActor +final class AIRulesPaneViewModel { + var rules: String = "" + + var coordinator: WeakCoordinatorRef? + + func load(from connection: DatabaseConnection) { + rules = connection.aiRules ?? "" + } + + var trimmedRules: String? { + let trimmed = rules.trimmingCharacters(in: .whitespacesAndNewlines) + return trimmed.isEmpty ? nil : rules + } +} diff --git a/TableProTests/Models/DatabaseConnectionAIRulesTests.swift b/TableProTests/Models/DatabaseConnectionAIRulesTests.swift new file mode 100644 index 000000000..cff9cf799 --- /dev/null +++ b/TableProTests/Models/DatabaseConnectionAIRulesTests.swift @@ -0,0 +1,138 @@ +// +// DatabaseConnectionAIRulesTests.swift +// TableProTests +// + +import Foundation +import Testing + +@testable import TablePro + +@Suite("DatabaseConnection.aiRules") +struct DatabaseConnectionAIRulesTests { + @Test("aiRules defaults to nil") + func defaultsToNil() { + let conn = TestFixtures.makeConnection() + #expect(conn.aiRules == nil) + } + + @Test("init populates aiRules") + func initPopulatesAIRules() { + let conn = DatabaseConnection( + name: "Test", + type: .mysql, + aiRules: "Always filter by tenant_id." + ) + #expect(conn.aiRules == "Always filter by tenant_id.") + } + + @Test("aiRules is mutable on var") + func aiRulesMutable() { + var conn = TestFixtures.makeConnection() + conn.aiRules = "Avoid users.ssn" + #expect(conn.aiRules == "Avoid users.ssn") + } + + @Test("Codable round-trip preserves aiRules") + func codableRoundTripWithRules() throws { + let original = DatabaseConnection( + name: "Prod", + type: .postgresql, + aiRules: "- Tables prefixed with `tmp_` are scratch.\n- Never select users.ssn." + ) + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(DatabaseConnection.self, from: data) + #expect(decoded.aiRules == original.aiRules) + } + + @Test("Codable round-trip preserves nil aiRules") + func codableRoundTripNilRules() throws { + let original = TestFixtures.makeConnection() + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(DatabaseConnection.self, from: data) + #expect(decoded.aiRules == nil) + } + + @Test("Decode without aiRules key produces nil for forward compatibility") + func decodeLegacyJSONWithoutAIRulesKey() throws { + let id = UUID() + let legacyJSON = """ + { + "id": "\(id.uuidString)", + "name": "Legacy", + "type": "MySQL" + } + """ + let data = Data(legacyJSON.utf8) + let decoded = try JSONDecoder().decode(DatabaseConnection.self, from: data) + #expect(decoded.aiRules == nil) + #expect(decoded.name == "Legacy") + } + + @Test("Empty aiRules string round-trips as empty string") + func emptyStringRoundTrip() throws { + let original = DatabaseConnection( + name: "Empty", + type: .mysql, + aiRules: "" + ) + let data = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(DatabaseConnection.self, from: data) + #expect(decoded.aiRules == "") + } + + @Test("System prompt includes connection rules section when rules are non-empty") + func systemPromptIncludesRulesSection() { + let prompt = AISchemaContext.buildSystemPrompt( + databaseType: .postgresql, + databaseName: "shop", + tables: [], + columnsByTable: [:], + foreignKeys: [:], + currentQuery: nil, + queryResults: nil, + settings: AISettings.default, + editorLanguage: .sql, + queryLanguageName: "SQL", + connectionRules: "Filter orders by deleted_at IS NULL." + ) + #expect(prompt.contains("## Connection-Specific Rules")) + #expect(prompt.contains("Filter orders by deleted_at IS NULL.")) + } + + @Test("System prompt omits connection rules section when nil") + func systemPromptOmitsRulesWhenNil() { + let prompt = AISchemaContext.buildSystemPrompt( + databaseType: .postgresql, + databaseName: "shop", + tables: [], + columnsByTable: [:], + foreignKeys: [:], + currentQuery: nil, + queryResults: nil, + settings: AISettings.default, + editorLanguage: .sql, + queryLanguageName: "SQL", + connectionRules: nil + ) + #expect(!prompt.contains("## Connection-Specific Rules")) + } + + @Test("System prompt omits connection rules section when whitespace only") + func systemPromptOmitsRulesWhenWhitespace() { + let prompt = AISchemaContext.buildSystemPrompt( + databaseType: .postgresql, + databaseName: "shop", + tables: [], + columnsByTable: [:], + foreignKeys: [:], + currentQuery: nil, + queryResults: nil, + settings: AISettings.default, + editorLanguage: .sql, + queryLanguageName: "SQL", + connectionRules: " \n\t " + ) + #expect(!prompt.contains("## Connection-Specific Rules")) + } +} diff --git a/docs/features/ai-assistant.mdx b/docs/features/ai-assistant.mdx index 434212a3e..9b616b16c 100644 --- a/docs/features/ai-assistant.mdx +++ b/docs/features/ai-assistant.mdx @@ -216,6 +216,20 @@ Under **Settings** > **AI** > **Context**: - **Include query results** (default off) - **Max schema tables** (default 20) +## Connection rules + +Pin domain knowledge to a specific connection so the AI sees it on every chat turn. Open the connection's edit form and pick **AI Rules** in the sidebar. + +Rules are plain text. Use them for facts the schema doesn't show: + +- Table conventions: `Tables prefixed with tmp_ are scratch and safe to ignore.` +- Join keys: `users.email_hash is the join key, not users.email.` +- Soft deletes: `Always filter orders by deleted_at IS NULL.` +- PII to avoid: `Never select users.ssn or users.dob.` +- Business rules: `Active accounts have status = 'active' AND verified_at IS NOT NULL.` + +The text is appended to the system prompt under a `## Connection-Specific Rules` section, after the schema and any attached context. Rules persist per connection and follow that connection across app launches. + ## Privacy Set a per-connection AI policy in the connection form: **Use Default**, **Always Allow**, **Ask Each Time**, or **Never**. New connections default to **Ask Each Time**.