diff --git a/src/core/record/l1-extractor.test.ts b/src/core/record/l1-extractor.test.ts
new file mode 100644
index 00000000..33525a16
--- /dev/null
+++ b/src/core/record/l1-extractor.test.ts
@@ -0,0 +1,145 @@
+import fs from "node:fs/promises";
+import os from "node:os";
+import path from "node:path";
+import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
+import { extractL1Memories } from "./l1-extractor.js";
+import type { LLMRunner } from "../types.js";
+
+describe("extractL1Memories rule pre-extraction", () => {
+  let tmpDir: string;
+
+  beforeEach(async () => {
+    tmpDir = await fs.mkdtemp(path.join(os.tmpdir(), "tdai-l1-pre-extract-"));
+  });
+
+  afterEach(async () => {
+    await fs.rm(tmpDir, { recursive: true, force: true });
+  });
+
+  it("stores explicit persona statements without calling the LLM", async () => {
+    const llmRunner: LLMRunner = {
+      run: vi.fn(async () => {
+        throw new Error("LLM should not be called for high-confidence persona statements");
+      }),
+    };
+
+    const result = await extractL1Memories({
+      messages: [
+        {
+          id: "msg-1",
+          role: "user",
+          content: "我是 Python 工程师",
+          timestamp: Date.parse("2026-06-01T10:00:00Z"),
+        },
+      ],
+      sessionKey: "session-pre-extract",
+      baseDir: tmpDir,
+      config: {},
+      options: {
+        enableDedup: false,
+        llmRunner,
+      },
+    });
+
+    expect(llmRunner.run).not.toHaveBeenCalled();
+    expect(result.success).toBe(true);
+    expect(result.extractedCount).toBe(1);
+    expect(result.storedCount).toBe(1);
+    expect(result.records[0]).toMatchObject({
+      content: "用户是 Python 工程师。",
+      type: "persona",
+      priority: 80,
+      source_message_ids: ["msg-1"],
+    });
+  });
+
+  it("stores explicit reply instructions without calling the LLM", async () => {
+    const llmRunner: LLMRunner = {
+      run: vi.fn(async () => {
+        throw new Error("LLM should not be called for high-confidence instructions");
+      }),
+    };
+
+    const result = await extractL1Memories({
+      messages: [
+        {
+          id: "msg-1",
+          role: "user",
+          content: "以后请用中文回复",
+          timestamp: Date.parse("2026-06-01T10:00:00Z"),
+        },
+      ],
+      sessionKey: "session-pre-extract",
+      baseDir: tmpDir,
+      config: {},
+      options: {
+        enableDedup: false,
+        llmRunner,
+      },
+    });
+
+    expect(llmRunner.run).not.toHaveBeenCalled();
+    expect(result.success).toBe(true);
+    expect(result.extractedCount).toBe(1);
+    expect(result.storedCount).toBe(1);
+    expect(result.records[0]).toMatchObject({
+      content: "用户要求 AI 以后用中文回复。",
+      type: "instruction",
+      priority: 80,
+      source_message_ids: ["msg-1"],
+    });
+  });
+
+  it("keeps using the LLM when any user message is not covered by direct rules", async () => {
+    const llmRunner: LLMRunner = {
+      run: vi.fn(async () => JSON.stringify([
+        {
+          scene_name: "用户讨论复杂计划",
+          message_ids: ["msg-1", "msg-2"],
+          memories: [
+            {
+              content: "用户正在评估一项复杂迁移计划。",
+              type: "episodic",
+              priority: 70,
+              source_message_ids: ["msg-2"],
+              metadata: {},
+            },
+          ],
+        },
+      ])),
+    };
+
+    const result = await extractL1Memories({
+      messages: [
+        {
+          id: "msg-1",
+          role: "user",
+          content: "我是 Python 工程师",
+          timestamp: Date.parse("2026-06-01T10:00:00Z"),
+        },
+        {
+          id: "msg-2",
+          role: "user",
+          content: "我们今天需要讨论数据库迁移的灰度计划和失败回滚策略。",
+          timestamp: Date.parse("2026-06-01T10:01:00Z"),
+        },
+      ],
+      sessionKey: "session-pre-extract",
+      baseDir: tmpDir,
+      config: {},
+      options: {
+        enableDedup: false,
+        llmRunner,
+      },
+    });
+
+    expect(llmRunner.run).toHaveBeenCalledTimes(1);
+    expect(result.success).toBe(true);
+    expect(result.records).toHaveLength(1);
+    expect(result.records[0]).toMatchObject({
+      content: "用户正在评估一项复杂迁移计划。",
+      type: "episodic",
+      source_message_ids: ["msg-2"],
+    });
+  });
+});
diff --git a/src/core/record/l1-extractor.ts b/src/core/record/l1-extractor.ts
index d38fd4b9..fa624d82 100644
--- a/src/core/record/l1-extractor.ts
+++ b/src/core/record/l1-extractor.ts
@@ -17,6 +17,7 @@ import { EXTRACT_MEMORIES_SYSTEM_PROMPT, formatExtractionPrompt } from "../promp
 import { batchDedup } from "./l1-dedup.js";
 import { writeMemory, generateMemoryId } from "./l1-writer.js";
 import type { ExtractedMemory, MemoryRecord, MemoryType, DedupDecision } from "./l1-writer.js";
+import { preExtractHighConfidenceMemories } from "./pre-extractor.js";
 import { CleanContextRunner } from "../../utils/clean-context-runner.js";
 import { sanitizeJsonForParse, shouldExtractL1 } from "../../utils/sanitize.js";
 import type { IMemoryStore } from "../store/types.js";
@@ -148,44 +149,57 @@ export async function extractL1Memories(params: {
 
   logger?.debug?.(`${TAG} Extracting from ${newMessages.length} new messages (+ ${backgroundMessages.length} background) [${qualifiedMessages.length} qualified from ${messages.length} input]`);
 
-  // Step 1: LLM extraction (scene segmentation + memory extraction)
-  let scenes: SceneSegment[];
-  try {
-    scenes = await callLlmExtraction({
-      newMessages,
-      backgroundMessages,
-      previousSceneName: options.previousSceneName,
-      config,
-      logger,
-      model: options.model,
-      llmRunner: options.llmRunner,
-    });
-    logger?.debug?.(`${TAG} LLM detected ${scenes.length} scene(s)`);
-  } catch (err) {
-    logger?.error(`${TAG} LLM extraction failed: ${err instanceof Error ? err.message : String(err)}`);
-    return { success: false, extractedCount: 0, storedCount: 0, records: [], sceneNames: [] };
-  }
-
-  // Flatten all memories across scenes
   const allExtracted: ExtractedMemory[] = [];
   const sceneNames: string[] = [];
 
-  for (const scene of scenes) {
-    sceneNames.push(scene.scene_name);
-    for (const mem of scene.memories) {
-      const memType = normalizeType(mem.type);
-      if (!memType) {
-        logger?.warn?.(`${TAG} Skipping memory with invalid type "${mem.type}"`);
-        continue;
+  const preExtraction = preExtractHighConfidenceMemories(newMessages);
+  if (preExtraction.canBypassLlm) {
+    for (const memory of preExtraction.memories) {
+      allExtracted.push(memory);
+      if (!sceneNames.includes(memory.scene_name)) {
+        sceneNames.push(memory.scene_name);
       }
-      allExtracted.push({
-        content: mem.content,
-        type: memType,
-        priority: typeof mem.priority === "number" ? mem.priority : 50,
-        source_message_ids: Array.isArray(mem.source_message_ids) ? mem.source_message_ids : [],
-        metadata: mem.metadata ?? {},
-        scene_name: scene.scene_name,
+    }
+    logger?.debug?.(
+      `${TAG} Rule pre-extraction bypassed LLM: ${allExtracted.length} high-confidence memor${allExtracted.length === 1 ? "y" : "ies"}`,
+    );
+  } else {
+    // Step 1: LLM extraction (scene segmentation + memory extraction)
+    let scenes: SceneSegment[];
+    try {
+      scenes = await callLlmExtraction({
+        newMessages,
+        backgroundMessages,
+        previousSceneName: options.previousSceneName,
+        config,
+        logger,
+        model: options.model,
+        llmRunner: options.llmRunner,
       });
+      logger?.debug?.(`${TAG} LLM detected ${scenes.length} scene(s)`);
+    } catch (err) {
+      logger?.error(`${TAG} LLM extraction failed: ${err instanceof Error ? err.message : String(err)}`);
+      return { success: false, extractedCount: 0, storedCount: 0, records: [], sceneNames: [] };
+    }
+
+    // Flatten all memories across scenes
+    for (const scene of scenes) {
+      sceneNames.push(scene.scene_name);
+      for (const mem of scene.memories) {
+        const memType = normalizeType(mem.type);
+        if (!memType) {
+          logger?.warn?.(`${TAG} Skipping memory with invalid type "${mem.type}"`);
+          continue;
+        }
+        allExtracted.push({
+          content: mem.content,
+          type: memType,
+          priority: typeof mem.priority === "number" ? mem.priority : 50,
+          source_message_ids: Array.isArray(mem.source_message_ids) ? mem.source_message_ids : [],
+          metadata: mem.metadata ?? {},
+          scene_name: scene.scene_name,
+        });
+      }
     }
   }
 
diff --git a/src/core/record/pre-extractor.ts b/src/core/record/pre-extractor.ts
new file mode 100644
index 00000000..853f2227
--- /dev/null
+++ b/src/core/record/pre-extractor.ts
@@ -0,0 +1,154 @@
+import type { ConversationMessage } from "../conversation/l0-recorder.js";
+import type { ExtractedMemory } from "./l1-writer.js";
+
+/** Outcome of the rule-based pass that runs before LLM extraction. */
+export interface PreExtractionResult {
+  /** Memories produced by the direct rules, one per matched user message. */
+  memories: ExtractedMemory[];
+  /** True only when at least one rule matched and every user message was covered. */
+  canBypassLlm: boolean;
+}
+
+/** A rule result; the caller attaches `source_message_ids` for the matched message. */
+type RuleMatch = Omit<ExtractedMemory, "source_message_ids">;
+
+/** Messages longer than this (after normalization) are never matched by a direct rule. */
+const MAX_DIRECT_MESSAGE_CHARS = 80;
+
+/**
+ * Question-like text (question mark, trailing question particle, A-not-A form).
+ * Such messages are not statements of fact, so they are left to the LLM.
+ */
+const QUESTION_LIKE = /[?\uFF1F]|[吗呢么]$|是不是|是否/u;
+
+/**
+ * Runs the direct rules over the user messages of a batch.
+ *
+ * @param messages - Messages to inspect; only those with `role === "user"` are considered.
+ * @returns The matched memories and whether they cover every user message, in which
+ *   case the caller may skip LLM extraction.
+ */
+export function preExtractHighConfidenceMemories(messages: ConversationMessage[]): PreExtractionResult {
+  const userMessages = messages.filter((message) => message.role === "user");
+  const memories: ExtractedMemory[] = [];
+  const coveredUserIds = new Set<string>();
+
+  for (const message of userMessages) {
+    const match = matchHighConfidenceRule(message.content);
+    if (!match) continue;
+    memories.push({
+      ...match,
+      source_message_ids: [message.id],
+    });
+    coveredUserIds.add(message.id);
+  }
+
+  return {
+    memories,
+    canBypassLlm: memories.length > 0 && userMessages.every((message) => coveredUserIds.has(message.id)),
+  };
+}
+
+/** Applies the persona rules, then the instruction rules, to one single-line message. */
+function matchHighConfidenceRule(raw: string): RuleMatch | undefined {
+  const trimmed = raw.trim();
+  // Only interior line breaks disqualify a message; a trailing newline is harmless.
+  if (/[\n\r]/.test(trimmed)) {
+    return undefined;
+  }
+
+  const text = normalizeDirectText(trimmed);
+  if (!text || text.length > MAX_DIRECT_MESSAGE_CHARS || QUESTION_LIKE.test(text)) {
+    return undefined;
+  }
+
+  return matchPersonaRule(text) ?? matchInstructionRule(text);
+}
+
+/** Matches first-person identity, occupation, and preference statements. */
+function matchPersonaRule(text: string): RuleMatch | undefined {
+  const identity = text.match(/^我(?:是一名|是一个|是一位|是)\s*(.{2,40})$/u);
+  if (identity) {
+    return {
+      content: `用户是 ${identity[1].trim()}。`,
+      type: "persona",
+      priority: 80,
+      scene_name: "用户介绍个人身份信息",
+      metadata: {},
+    };
+  }
+
+  const occupation = text.match(/^我的(?:职业|工作|岗位)(?:是|为)\s*(.{2,40})$/u);
+  if (occupation) {
+    return {
+      content: `用户的职业是 ${occupation[1].trim()}。`,
+      type: "persona",
+      priority: 80,
+      scene_name: "用户介绍个人身份信息",
+      metadata: {},
+    };
+  }
+
+  const preference = text.match(/^我(喜欢|偏好|擅长)\s*(.{2,40})$/u);
+  if (preference) {
+    return {
+      content: `用户${preference[1]} ${preference[2].trim()}。`,
+      type: "persona",
+      priority: 70,
+      scene_name: "用户介绍个人偏好",
+      metadata: {},
+    };
+  }
+
+  return undefined;
+}
+
+/** Matches standing instructions about how the AI should reply from now on. */
+function matchInstructionRule(text: string): RuleMatch | undefined {
+  const language = text.match(/^以后(?:请|都|要)?\s*(?:用|使用)\s*(.{1,20}?)(?:回复|回答)$/u);
+  if (language) {
+    return {
+      content: `用户要求 AI 以后用${language[1].trim()}回复。`,
+      type: "instruction",
+      priority: 80,
+      scene_name: "用户设置 AI 回复偏好",
+      metadata: {},
+    };
+  }
+
+  const futureDirective = text.match(/^以后(?:请|都|要)?\s*(不要|别|保持|尽量|必须|直接)\s*(.{2,40})$/u);
+  if (futureDirective) {
+    return {
+      content: `用户要求 AI 以后${futureDirective[1]}${futureDirective[2].trim()}。`,
+      type: "instruction",
+      priority: 80,
+      scene_name: "用户设置 AI 回复偏好",
+      metadata: {},
+    };
+  }
+
+  const fromNow = text.match(/^从现在开始(?:请|都|要)?\s*(.{2,50})$/u);
+  if (fromNow) {
+    return {
+      content: `用户要求 AI 从现在开始${fromNow[1].trim()}。`,
+      type: "instruction",
+      priority: 80,
+      scene_name: "用户设置 AI 回复偏好",
+      metadata: {},
+    };
+  }
+
+  return undefined;
+}
+
+/**
+ * Trims the text, drops trailing sentence-ending punctuation, and collapses
+ * whitespace runs to a single space.
+ */
+function normalizeDirectText(text: string): string {
+  return text
+    .trim()
+    .replace(/[。.!\uFF01]+$/u, "")
+    .replace(/\s+/g, " ")
+    .trim();
+}