Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 145 additions & 0 deletions src/core/record/l1-extractor.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import fs from "node:fs/promises";
import os from "node:os";
import path from "node:path";
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
import { extractL1Memories } from "./l1-extractor.js";
import type { LLMRunner } from "../types.js";

// Suite covering the rule-based pre-extraction path of extractL1Memories:
// messages fully covered by direct rules must not reach the LLM runner, while
// any uncovered user message must still go through it.
describe("extractL1Memories rule pre-extraction", () => {
// Per-test scratch directory passed to extractL1Memories as `baseDir`.
let tmpDir: string;

// Create a fresh temp directory before each test.
beforeEach(async () => {
tmpDir = await fs.mkdtemp(path.join(os.tmpdir(), "tdai-l1-pre-extract-"));
});

// Remove the temp directory afterwards; `force` tolerates it already being gone.
afterEach(async () => {
await fs.rm(tmpDir, { recursive: true, force: true });
});

it("stores explicit persona statements without calling the LLM", async () => {
// Runner stub that fails loudly if the extractor ever invokes it.
const llmRunner: LLMRunner = {
run: vi.fn(async () => {
throw new Error("LLM should not be called for high-confidence persona statements");
}),
};

// Single short user message in the "我是 …" identity form.
const result = await extractL1Memories({
messages: [
{
id: "msg-1",
role: "user",
content: "我是 Python 工程师",
timestamp: Date.parse("2026-06-01T10:00:00Z"),
},
],
sessionKey: "session-pre-extract",
baseDir: tmpDir,
config: {},
options: {
enableDedup: false,
llmRunner,
},
});

// The LLM must be untouched and exactly one persona record returned.
expect(llmRunner.run).not.toHaveBeenCalled();
expect(result.success).toBe(true);
expect(result.extractedCount).toBe(1);
expect(result.storedCount).toBe(1);
expect(result.records[0]).toMatchObject({
content: "用户是 Python 工程师。",
type: "persona",
priority: 80,
source_message_ids: ["msg-1"],
});
});

it("stores explicit reply instructions without calling the LLM", async () => {
// Runner stub that fails loudly if the extractor ever invokes it.
const llmRunner: LLMRunner = {
run: vi.fn(async () => {
throw new Error("LLM should not be called for high-confidence instructions");
}),
};

// Single short user message in the "以后请用 … 回复" instruction form.
const result = await extractL1Memories({
messages: [
{
id: "msg-1",
role: "user",
content: "以后请用中文回复",
timestamp: Date.parse("2026-06-01T10:00:00Z"),
},
],
sessionKey: "session-pre-extract",
baseDir: tmpDir,
config: {},
options: {
enableDedup: false,
llmRunner,
},
});

// The LLM must be untouched and exactly one instruction record returned.
expect(llmRunner.run).not.toHaveBeenCalled();
expect(result.success).toBe(true);
expect(result.extractedCount).toBe(1);
expect(result.storedCount).toBe(1);
expect(result.records[0]).toMatchObject({
content: "用户要求 AI 以后用中文回复。",
type: "instruction",
priority: 80,
source_message_ids: ["msg-1"],
});
});

it("keeps using the LLM when any user message is not covered by direct rules", async () => {
// Runner stub returning a JSON string with one scene holding one episodic memory.
const llmRunner: LLMRunner = {
run: vi.fn(async () => JSON.stringify([
{
scene_name: "用户讨论复杂计划",
message_ids: ["msg-1", "msg-2"],
memories: [
{
content: "用户正在评估一项复杂迁移计划。",
type: "episodic",
priority: 70,
source_message_ids: ["msg-2"],
metadata: {},
},
],
},
])),
};

// msg-1 is in the identity form used by the first test; msg-2 is free-form
// text, so the whole batch is expected to go to the LLM.
const result = await extractL1Memories({
messages: [
{
id: "msg-1",
role: "user",
content: "我是 Python 工程师",
timestamp: Date.parse("2026-06-01T10:00:00Z"),
},
{
id: "msg-2",
role: "user",
content: "我们今天需要讨论数据库迁移的灰度计划和失败回滚策略。",
timestamp: Date.parse("2026-06-01T10:01:00Z"),
},
],
sessionKey: "session-pre-extract",
baseDir: tmpDir,
config: {},
options: {
enableDedup: false,
llmRunner,
},
});

// Exactly one LLM call, and only the LLM-produced record is returned.
expect(llmRunner.run).toHaveBeenCalledTimes(1);
expect(result.success).toBe(true);
expect(result.records).toHaveLength(1);
expect(result.records[0]).toMatchObject({
content: "用户正在评估一项复杂迁移计划。",
type: "episodic",
source_message_ids: ["msg-2"],
});
});
});
80 changes: 47 additions & 33 deletions src/core/record/l1-extractor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import { EXTRACT_MEMORIES_SYSTEM_PROMPT, formatExtractionPrompt } from "../promp
import { batchDedup } from "./l1-dedup.js";
import { writeMemory, generateMemoryId } from "./l1-writer.js";
import type { ExtractedMemory, MemoryRecord, MemoryType, DedupDecision } from "./l1-writer.js";
import { preExtractHighConfidenceMemories } from "./pre-extractor.js";
import { CleanContextRunner } from "../../utils/clean-context-runner.js";
import { sanitizeJsonForParse, shouldExtractL1 } from "../../utils/sanitize.js";
import type { IMemoryStore } from "../store/types.js";
Expand Down Expand Up @@ -148,44 +149,57 @@ export async function extractL1Memories(params: {

logger?.debug?.(`${TAG} Extracting from ${newMessages.length} new messages (+ ${backgroundMessages.length} background) [${qualifiedMessages.length} qualified from ${messages.length} input]`);

// Step 1: LLM extraction (scene segmentation + memory extraction)
let scenes: SceneSegment[];
try {
scenes = await callLlmExtraction({
newMessages,
backgroundMessages,
previousSceneName: options.previousSceneName,
config,
logger,
model: options.model,
llmRunner: options.llmRunner,
});
logger?.debug?.(`${TAG} LLM detected ${scenes.length} scene(s)`);
} catch (err) {
logger?.error(`${TAG} LLM extraction failed: ${err instanceof Error ? err.message : String(err)}`);
return { success: false, extractedCount: 0, storedCount: 0, records: [], sceneNames: [] };
}

// Flatten all memories across scenes
const allExtracted: ExtractedMemory[] = [];
const sceneNames: string[] = [];

for (const scene of scenes) {
sceneNames.push(scene.scene_name);
for (const mem of scene.memories) {
const memType = normalizeType(mem.type);
if (!memType) {
logger?.warn?.(`${TAG} Skipping memory with invalid type "${mem.type}"`);
continue;
const preExtraction = preExtractHighConfidenceMemories(newMessages);
if (preExtraction.canBypassLlm) {
for (const memory of preExtraction.memories) {
allExtracted.push(memory);
if (!sceneNames.includes(memory.scene_name)) {
sceneNames.push(memory.scene_name);
}
allExtracted.push({
content: mem.content,
type: memType,
priority: typeof mem.priority === "number" ? mem.priority : 50,
source_message_ids: Array.isArray(mem.source_message_ids) ? mem.source_message_ids : [],
metadata: mem.metadata ?? {},
scene_name: scene.scene_name,
}
logger?.debug?.(
`${TAG} Rule pre-extraction bypassed LLM: ${allExtracted.length} high-confidence memor${allExtracted.length === 1 ? "y" : "ies"}`,
);
} else {
// Step 1: LLM extraction (scene segmentation + memory extraction)
let scenes: SceneSegment[];
try {
scenes = await callLlmExtraction({
newMessages,
backgroundMessages,
previousSceneName: options.previousSceneName,
config,
logger,
model: options.model,
llmRunner: options.llmRunner,
});
logger?.debug?.(`${TAG} LLM detected ${scenes.length} scene(s)`);
} catch (err) {
logger?.error(`${TAG} LLM extraction failed: ${err instanceof Error ? err.message : String(err)}`);
return { success: false, extractedCount: 0, storedCount: 0, records: [], sceneNames: [] };
}

// Flatten all memories across scenes
for (const scene of scenes) {
sceneNames.push(scene.scene_name);
for (const mem of scene.memories) {
const memType = normalizeType(mem.type);
if (!memType) {
logger?.warn?.(`${TAG} Skipping memory with invalid type "${mem.type}"`);
continue;
}
allExtracted.push({
content: mem.content,
type: memType,
priority: typeof mem.priority === "number" ? mem.priority : 50,
source_message_ids: Array.isArray(mem.source_message_ids) ? mem.source_message_ids : [],
metadata: mem.metadata ?? {},
scene_name: scene.scene_name,
});
}
}
}

Expand Down
122 changes: 122 additions & 0 deletions src/core/record/pre-extractor.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import type { ConversationMessage } from "../conversation/l0-recorder.js";
import type { ExtractedMemory } from "./l1-writer.js";

/**
 * Outcome of the rule-based pre-extraction pass over a batch of messages.
 */
export interface PreExtractionResult {
/** Memories produced by rule matches, one per matched user message. */
memories: ExtractedMemory[];
/** Whether the caller may skip the LLM: at least one memory was produced and every user message matched a rule. */
canBypassLlm: boolean;
}

/** What a single rule yields: an extracted memory without `source_message_ids`, which is filled in from the matched message. */
type RuleMatch = Omit<ExtractedMemory, "source_message_ids">;

/** Longest normalized message (measured with `String.prototype.length`) that the direct rules will consider. */
const MAX_DIRECT_MESSAGE_CHARS = 80;

/**
 * Runs the high-confidence rules over the user messages of a batch.
 *
 * Only messages whose `role` is `"user"` are inspected. Each user message that
 * matches a rule yields exactly one memory whose `source_message_ids` holds
 * that message's id.
 *
 * @param messages - Batch of conversation messages to inspect.
 * @returns The rule-derived memories, plus `canBypassLlm`, which is true only
 *   when the batch has at least one user message and every user message
 *   matched a rule.
 */
export function preExtractHighConfidenceMemories(messages: ConversationMessage[]): PreExtractionResult {
  const memories: ExtractedMemory[] = [];
  let userMessageCount = 0;

  for (const message of messages) {
    if (message.role !== "user") continue;
    userMessageCount += 1;

    const match = matchHighConfidenceRule(message.content);
    if (!match) continue;

    memories.push({
      ...match,
      source_message_ids: [message.id],
    });
  }

  return {
    memories,
    // Coverage is decided by counting matches rather than by looking ids up in
    // a set: with duplicated or missing ids, one matched message could
    // otherwise mask an unmatched one and wrongly skip the LLM.
    canBypassLlm: userMessageCount > 0 && memories.length === userMessageCount,
  };
}

/**
 * Applies the direct rules to one raw message.
 *
 * A message is eligible only if it is a single line and, after normalization,
 * is non-empty and at most `MAX_DIRECT_MESSAGE_CHARS` long. Persona rules are
 * tried first, then instruction rules.
 *
 * @param raw - Message content exactly as received.
 * @returns The first rule match, or `undefined` when the message is not
 *   eligible or no rule applies.
 */
function matchHighConfidenceRule(raw: string): RuleMatch | undefined {
  // Only line breaks *inside* the message make it multi-line. Leading and
  // trailing ones are padding that normalization discards, so they must not
  // disqualify an otherwise single-line message.
  if (/[\n\r]/.test(raw.trim())) {
    return undefined;
  }

  const text = normalizeDirectText(raw);
  if (!text || text.length > MAX_DIRECT_MESSAGE_CHARS) {
    return undefined;
  }

  return matchPersonaRule(text) ?? matchInstructionRule(text);
}

/**
 * Markers showing that a message is not one plain declarative statement:
 * question marks, clause/sentence separators, sentence-final question
 * particles, and common interrogative words. Such text is too ambiguous to
 * store verbatim as a persona fact.
 */
const NON_DECLARATIVE_PERSONA_TEXT =
  /[??,,;;。!!]|(?:吗|呢)$|什么|怎么|谁|哪个|哪些|哪里|哪儿|哪种|多少|是不是|是否/u;

/**
 * Matches short first-person self-descriptions and turns them into persona memories.
 *
 * Recognized forms, tried in order:
 * - identity: "我是 …" (also "我是一名/一个/一位 …"), priority 80
 * - occupation: "我的职业/工作/岗位 是/为 …", priority 80
 * - preference or skill: "我喜欢/偏好/擅长 …", priority 70
 *
 * Questions ("我是谁?") and multi-clause messages ("我是工程师,帮我写个脚本")
 * are rejected so they are not recorded as facts.
 *
 * @param text - Message text; trailing sentence punctuation is presumably
 *   already stripped by the caller, since the patterns are anchored at `$`.
 * @returns The persona memory without `source_message_ids`, or `undefined`
 *   when no form applies.
 */
function matchPersonaRule(text: string): RuleMatch | undefined {
  // Anything that is not a single declarative clause gets no rule match here.
  if (NON_DECLARATIVE_PERSONA_TEXT.test(text)) {
    return undefined;
  }

  const identity = text.match(/^我(?:是一名|是一个|是一位|是)\s*(.{2,40})$/u);
  if (identity) {
    return {
      content: `用户是 ${identity[1].trim()}。`,
      type: "persona",
      priority: 80,
      scene_name: "用户介绍个人身份信息",
      metadata: {},
    };
  }

  const occupation = text.match(/^我的(?:职业|工作|岗位)(?:是|为)\s*(.{2,40})$/u);
  if (occupation) {
    return {
      content: `用户的职业是 ${occupation[1].trim()}。`,
      type: "persona",
      priority: 80,
      scene_name: "用户介绍个人身份信息",
      metadata: {},
    };
  }

  const preference = text.match(/^我(喜欢|偏好|擅长)\s*(.{2,40})$/u);
  if (preference) {
    return {
      content: `用户${preference[1]} ${preference[2].trim()}。`,
      type: "persona",
      priority: 70,
      scene_name: "用户介绍个人偏好",
      metadata: {},
    };
  }

  return undefined;
}

/**
 * Recognizes short standing instructions addressed to the assistant.
 *
 * Three forms are tried in order; the first hit wins:
 * - reply language: "以后(请/都/要) 用/使用 … 回复/回答"
 * - future directive: "以后(请/都/要) 不要/别/保持/尽量/必须/直接 …"
 * - effective now: "从现在开始(请/都/要) …"
 *
 * Every hit is reported as an `instruction` memory with priority 80.
 *
 * @param text - Message text to test against the patterns.
 * @returns The instruction memory without `source_message_ids`, or `undefined`
 *   when none of the forms apply.
 */
function matchInstructionRule(text: string): RuleMatch | undefined {
  // All three forms differ only in the sentence they produce.
  const asInstruction = (content: string): RuleMatch => ({
    content,
    type: "instruction",
    priority: 80,
    scene_name: "用户设置 AI 回复偏好",
    metadata: {},
  });

  const replyLanguage = /^以后(?:请|都|要)?\s*(?:用|使用)\s*(.{1,20}?)(?:回复|回答)$/u.exec(text);
  if (replyLanguage !== null) {
    const language = replyLanguage[1].trim();
    return asInstruction(`用户要求 AI 以后用${language}回复。`);
  }

  const directive = /^以后(?:请|都|要)?\s*(不要|别|保持|尽量|必须|直接)\s*(.{2,40})$/u.exec(text);
  if (directive !== null) {
    const modal = directive[1];
    const action = directive[2].trim();
    return asInstruction(`用户要求 AI 以后${modal}${action}。`);
  }

  const effectiveNow = /^从现在开始(?:请|都|要)?\s*(.{2,50})$/u.exec(text);
  if (effectiveNow !== null) {
    const action = effectiveNow[1].trim();
    return asInstruction(`用户要求 AI 从现在开始${action}。`);
  }

  return undefined;
}

/**
 * Canonicalizes a message before rule matching.
 *
 * Surrounding whitespace is removed, any trailing run of whitespace and
 * sentence-ending punctuation (`。`, `.`, `!`, `!`) is dropped, and inner
 * whitespace runs collapse to a single space.
 *
 * @param text - Raw message content.
 * @returns The normalized text; empty when nothing but whitespace and
 *   punctuation was supplied.
 */
function normalizeDirectText(text: string): string {
  return text
    .trim()
    // Whitespace sits in the same class as the punctuation so that input like
    // "… 回复 !" leaves no dangling space that would defeat `$`-anchored rules.
    .replace(/[\s。.!!]+$/u, "")
    .replace(/\s+/g, " ");
}
Loading