From d6350ef0c6b3037e6286c776181cb1fd514a314d Mon Sep 17 00:00:00 2001 From: "Josh Grossman (Bounce Security)" <97975715+joshbouncesecurity@users.noreply.github.com> Date: Sun, 22 Mar 2026 16:40:41 +0200 Subject: [PATCH 1/7] feat: DI-aware call resolution for TypeScript/NestJS codebases (#20) * feat: DI-aware call resolution for TypeScript/NestJS codebases The parser couldn't resolve dependency-injected service calls like `this.callService.getById()` because it didn't know that `callService` is an instance of `CallService`. This caused the agentic enhancer to miss critical authorization checks in service layers, producing false positive vulnerability findings. Changes: - typescript_analyzer.js: Extract constructor parameter types as `constructorDeps` metadata on class methods using ts-morph AST - dependency_resolver.js: Use constructorDeps for DI-aware resolution in _resolveMethodCall, with prefix matching for versioned implementations (e.g., CallService -> CallServiceV1) - Agentic enhancer: Add forward-tracing instructions to the prompt so the agent traces into called functions for auth/validation checks - Agentic enhancer: Add get_static_dependencies tool to surface parsed call graph data to the exploration agent - Agentic enhancer: Pass static deps to tool executor before analysis Co-Authored-By: Claude Opus 4.6 (1M context) * test: add tests for DI-aware call resolution and enhancer tools - test_di_resolution.py: Tests constructor deps extraction from TypeScript AST and DI-aware method resolution in call graphs, including versioned implementations and false positive prevention - test_enhancer_tools.py: Tests resolve_dependencies and the get_static_dependencies tool via ToolExecutor Co-Authored-By: Claude Opus 4.6 (1M context) --------- Co-authored-by: Claude Opus 4.6 (1M context) --- .../parsers/javascript/dependency_resolver.js | 40 ++- .../parsers/javascript/typescript_analyzer.js | 32 ++ libs/openant-core/tests/test_di_resolution.py | 310 ++++++++++++++++++ .../openant-core/tests/test_enhancer_tools.py | 127 +++++++ .../utilities/agentic_enhancer/agent.py | 3 + .../utilities/agentic_enhancer/prompts.py | 41 ++- .../agentic_enhancer/repository_index.py | 48 +++ .../utilities/agentic_enhancer/tools.py | 36 ++ 8 files changed, 623 insertions(+), 14 deletions(-) create mode 100644 libs/openant-core/tests/test_di_resolution.py create mode 100644 libs/openant-core/tests/test_enhancer_tools.py diff --git a/libs/openant-core/parsers/javascript/dependency_resolver.js b/libs/openant-core/parsers/javascript/dependency_resolver.js index 52d130e..84769fd 100644 --- a/libs/openant-core/parsers/javascript/dependency_resolver.js +++ b/libs/openant-core/parsers/javascript/dependency_resolver.js @@ -134,7 +134,7 @@ class DependencyResolver { // Skip 'this' (handled above) and common built-ins if (objectName === 'this' || this._isBuiltIn(objectName)) continue; - const resolved = this._resolveMethodCall(objectName, methodName, callerFile); + const resolved = this._resolveMethodCall(objectName, methodName, callerFile, callerFuncId); if (resolved && !seenCalls.has(resolved)) { seenCalls.add(resolved); calls.push(resolved); @@ -240,16 +240,20 @@ class DependencyResolver { /** * Resolve an object.method call + * + * Supports two resolution strategies: + * 1. Direct class name match: objectName === className + * 2. DI-aware resolution: objectName is a constructor-injected parameter, + * use its type annotation to find the target class */ - _resolveMethodCall(objectName, methodName, callerFile) { - // Check if objectName matches a class name - const qualifiedName = `${objectName}.${methodName}`; + _resolveMethodCall(objectName, methodName, callerFile, callerFuncId = null) { const candidates = this.functionsByName[methodName]; if (!candidates || !Array.isArray(candidates)) { return null; } + // 1. Exact class name match (existing behavior) for (const funcId of candidates) { const funcData = this.functions[funcId]; if (funcData && funcData.className === objectName) { @@ -257,6 +261,34 @@ class DependencyResolver { } } + // 2. DI-aware resolution: look up objectName in caller's constructorDeps + // e.g., this.callService.getById() -> constructorDeps says callService: CallService + // -> resolve to CallService.getById + if (callerFuncId) { + const callerFunc = this.functions[callerFuncId]; + if (callerFunc && callerFunc.constructorDeps) { + const typeName = callerFunc.constructorDeps[objectName]; + if (typeName) { + // 2a. Exact type match + for (const funcId of candidates) { + const funcData = this.functions[funcId]; + if (funcData && funcData.className === typeName) { + return funcId; + } + } + + // 2b. Implementation class match: type is often an interface/abstract class + // and the implementation has a suffix (e.g., CallService -> CallServiceV1, CallServiceImpl) + for (const funcId of candidates) { + const funcData = this.functions[funcId]; + if (funcData && funcData.className && funcData.className.startsWith(typeName)) { + return funcId; + } + } + } + } + } + return null; } diff --git a/libs/openant-core/parsers/javascript/typescript_analyzer.js b/libs/openant-core/parsers/javascript/typescript_analyzer.js index a41a80d..08e3128 100644 --- a/libs/openant-core/parsers/javascript/typescript_analyzer.js +++ b/libs/openant-core/parsers/javascript/typescript_analyzer.js @@ -227,6 +227,38 @@ class TypeScriptAnalyzer { className: className, }; } + + // Extract constructor DI metadata for this class + // In NestJS/Angular, constructor parameters with type annotations + // declare injected services: constructor(private callService: CallService) + const constructors = classDecl.getConstructors(); + if (constructors.length > 0) { + const ctor = constructors[0]; + const injections = {}; // paramName -> typeName + + for (const param of ctor.getParameters()) { + const paramName = param.getName(); + const typeNode = param.getTypeNode(); + if (typeNode) { + const typeName = typeNode.getText(); + // Only store simple PascalCase type names (skip union types, generics, primitives) + if (/^[A-Z][a-zA-Z0-9_$]*$/.test(typeName)) { + injections[paramName] = typeName; + } + } + } + + if (Object.keys(injections).length > 0) { + // Store DI metadata on each method of this class + for (const method of classDecl.getMethods()) { + const methodName = method.getName(); + const functionId = `${relativePath}:${className}.${methodName}`; + if (this.functions[functionId]) { + this.functions[functionId].constructorDeps = injections; + } + } + } + } } // Extract methods from object literals in export default diff --git a/libs/openant-core/tests/test_di_resolution.py b/libs/openant-core/tests/test_di_resolution.py new file mode 100644 index 0000000..309b1f0 --- /dev/null +++ b/libs/openant-core/tests/test_di_resolution.py @@ -0,0 +1,310 @@ +"""Tests for dependency injection-aware call resolution. + +Tests that the TypeScript analyzer extracts constructor parameter types +and the dependency resolver uses them to resolve DI-injected service calls. + +Requires Node.js and npm dependencies installed: + cd parsers/javascript && npm install +""" +import json +import shutil +from pathlib import Path + +from utilities.file_io import run_utf8 + +import pytest + +PARSERS_JS_DIR = Path(__file__).parent.parent / "parsers" / "javascript" +NODE_MODULES = PARSERS_JS_DIR / "node_modules" + +pytestmark = pytest.mark.skipif( + not shutil.which("node") or not NODE_MODULES.exists(), + reason="Node.js or JS parser npm dependencies not available", +) + + +def run_node(script_name, *args): + """Run a Node.js script from the JS parsers directory.""" + cmd = ["node", str(PARSERS_JS_DIR / script_name)] + list(args) + return run_utf8(cmd, capture_output=True, text=True, timeout=30) + + +# -- Fixture: NestJS-style DI codebase -- + +RESOLVER_TS = """\ +import { Injectable } from '@nestjs/common'; +import { CallService } from './call.service'; +import { AuthService } from './auth.service'; + +@Injectable() +export class CallResolver { + constructor( + private callService: CallService, + private authService: AuthService, + ) {} + + async getCall(id: string) { + return await this.callService.getById(id); + } + + async deleteCall(id: string) { + return await this.callService.remove(id); + } +} +""" + +SERVICE_TS = """\ +import { Injectable } from '@nestjs/common'; + +@Injectable() +export class CallService { + async getById(id: string) { + const call = await this.repository.findOne(id); + await this.authService.can('read', call); + return call; + } + + async remove(id: string) { + return await this.repository.delete(id); + } +} +""" + +AUTH_SERVICE_TS = """\ +import { Injectable } from '@nestjs/common'; + +@Injectable() +export class AuthService { + async can(action: string, resource: any) { + // authorization check + return true; + } +} +""" + +# Versioned implementation (interface CallService, impl CallServiceV2) +VERSIONED_SERVICE_TS = """\ +import { Injectable } from '@nestjs/common'; + +@Injectable() +export class CallServiceV2 { + async getById(id: string) { + return { id }; + } + + async remove(id: string) { + return true; + } +} +""" + + +@pytest.fixture +def nestjs_repo(tmp_path): + """Create a minimal NestJS-style repo with DI patterns.""" + src = tmp_path / "src" + src.mkdir() + (src / "call.resolver.ts").write_text(RESOLVER_TS) + (src / "call.service.ts").write_text(SERVICE_TS) + (src / "auth.service.ts").write_text(AUTH_SERVICE_TS) + return tmp_path + + +@pytest.fixture +def nestjs_repo_versioned(tmp_path): + """Create a repo where the DI type doesn't exactly match the class name.""" + src = tmp_path / "src" + src.mkdir() + (src / "call.resolver.ts").write_text(RESOLVER_TS) + (src / "call.service.ts").write_text(VERSIONED_SERVICE_TS) + return tmp_path + + +def analyze_and_resolve(repo_path, files): + """Run analyzer + resolver on given files and return resolved data.""" + analyzer_out = repo_path / "analyzer_output.json" + resolved_out = repo_path / "resolved.json" + + file_paths = [str(f) for f in files] + result = run_node( + "typescript_analyzer.js", str(repo_path), + *file_paths, + "--output", str(analyzer_out), + ) + assert result.returncode == 0, f"Analyzer failed: {result.stderr}" + + result = run_node( + "dependency_resolver.js", str(analyzer_out), + "--output", str(resolved_out), + ) + assert result.returncode == 0, f"Resolver failed: {result.stderr}" + + return json.loads(resolved_out.read_text()) + + +class TestConstructorDepsExtraction: + """Test that the analyzer extracts constructorDeps from class constructors.""" + + def test_extracts_constructor_deps(self, nestjs_repo): + analyzer_out = nestjs_repo / "analyzer_output.json" + result = run_node( + "typescript_analyzer.js", str(nestjs_repo), + "src/call.resolver.ts", + "--output", str(analyzer_out), + ) + assert result.returncode == 0 + + data = json.loads(analyzer_out.read_text()) + functions = data["functions"] + + # Find a CallResolver method + resolver_methods = { + fid: f for fid, f in functions.items() + if "CallResolver" in fid + } + assert len(resolver_methods) > 0, "No CallResolver methods found" + + # Each method should have constructorDeps + for fid, func in resolver_methods.items(): + assert "constructorDeps" in func, f"{fid} missing constructorDeps" + deps = func["constructorDeps"] + assert deps.get("callService") == "CallService" + assert deps.get("authService") == "AuthService" + + def test_skips_primitive_types(self, tmp_path): + """Constructor params with primitive types should not be included.""" + src = tmp_path / "src" + src.mkdir() + (src / "example.ts").write_text("""\ +export class Example { + constructor( + private name: string, + private count: number, + private service: MyService, + ) {} + + doWork() { + return this.service.run(); + } +} +""") + analyzer_out = tmp_path / "analyzer_output.json" + result = run_node( + "typescript_analyzer.js", str(tmp_path), + "src/example.ts", + "--output", str(analyzer_out), + ) + assert result.returncode == 0 + + data = json.loads(analyzer_out.read_text()) + func = next( + f for f in data["functions"].values() + if f.get("className") == "Example" + ) + deps = func.get("constructorDeps", {}) + # Only MyService should be captured (PascalCase), not string/number + assert "service" in deps + assert deps["service"] == "MyService" + assert "name" not in deps + assert "count" not in deps + + +class TestDIAwareCallResolution: + """Test that the dependency resolver uses constructorDeps for DI resolution.""" + + def test_resolves_exact_type_match(self, nestjs_repo): + """this.callService.getById() resolves to CallService.getById.""" + data = analyze_and_resolve(nestjs_repo, [ + "src/call.resolver.ts", + "src/call.service.ts", + ]) + + call_graph = data["callGraph"] + + # Find CallResolver.getCall's call graph + resolver_calls = None + for fid, calls in call_graph.items(): + if "CallResolver.getCall" in fid: + resolver_calls = calls + break + + assert resolver_calls is not None, "CallResolver.getCall not in call graph" + assert any( + "CallService.getById" in c for c in resolver_calls + ), f"Expected CallService.getById in calls, got: {resolver_calls}" + + def test_resolves_versioned_implementation(self, nestjs_repo_versioned): + """this.callService.getById() resolves to CallServiceV2.getById via prefix match.""" + data = analyze_and_resolve(nestjs_repo_versioned, [ + "src/call.resolver.ts", + "src/call.service.ts", + ]) + + call_graph = data["callGraph"] + resolver_calls = None + for fid, calls in call_graph.items(): + if "CallResolver.getCall" in fid: + resolver_calls = calls + break + + assert resolver_calls is not None + assert any( + "CallServiceV2.getById" in c for c in resolver_calls + ), f"Expected CallServiceV2.getById in calls, got: {resolver_calls}" + + def test_resolves_multiple_di_methods(self, nestjs_repo): + """Both getById and remove should resolve to CallService methods.""" + data = analyze_and_resolve(nestjs_repo, [ + "src/call.resolver.ts", + "src/call.service.ts", + ]) + + call_graph = data["callGraph"] + + # deleteCall should resolve to CallService.remove + delete_calls = None + for fid, calls in call_graph.items(): + if "CallResolver.deleteCall" in fid: + delete_calls = calls + break + + assert delete_calls is not None + assert any( + "CallService.remove" in c for c in delete_calls + ), f"Expected CallService.remove in calls, got: {delete_calls}" + + def test_no_false_positives_without_di(self, tmp_path): + """Methods without constructor deps should not spuriously resolve.""" + src = tmp_path / "src" + src.mkdir() + (src / "plain.ts").write_text("""\ +export class PlainService { + doWork() { + return this.unknownService.process(); + } +} +""") + (src / "other.ts").write_text("""\ +export class UnknownService { + process() { + return 42; + } +} +""") + data = analyze_and_resolve(tmp_path, [ + "src/plain.ts", + "src/other.ts", + ]) + + call_graph = data["callGraph"] + plain_calls = None + for fid, calls in call_graph.items(): + if "PlainService.doWork" in fid: + plain_calls = calls + break + + # Without constructor deps, unknownService.process() should NOT resolve + assert plain_calls is not None + assert not any( + "UnknownService.process" in c for c in plain_calls + ), f"Should not resolve without DI metadata, got: {plain_calls}" diff --git a/libs/openant-core/tests/test_enhancer_tools.py b/libs/openant-core/tests/test_enhancer_tools.py new file mode 100644 index 0000000..a862f05 --- /dev/null +++ b/libs/openant-core/tests/test_enhancer_tools.py @@ -0,0 +1,127 @@ +"""Tests for the agentic enhancer tools, specifically the get_static_dependencies tool.""" +import pytest + +from utilities.agentic_enhancer.repository_index import RepositoryIndex +from utilities.agentic_enhancer.tools import ToolExecutor + + +def _make_index(functions: dict) -> RepositoryIndex: + """Create a RepositoryIndex from a minimal functions dict.""" + return RepositoryIndex({"functions": functions}) + + +SAMPLE_FUNCTIONS = { + "src/user.controller.ts:UserController.getUser": { + "name": "UserController.getUser", + "code": "async getUser(id) { return this.userService.findById(id); }", + "className": "UserController", + "unitType": "class_method", + "startLine": 10, + "endLine": 12, + }, + "src/user.service.ts:UserService.findById": { + "name": "UserService.findById", + "code": "async findById(id) { return this.repo.findOne(id); }", + "className": "UserService", + "unitType": "class_method", + "startLine": 5, + "endLine": 7, + }, + "src/auth.guard.ts:AuthGuard.canActivate": { + "name": "AuthGuard.canActivate", + "code": "canActivate(context) { return this.validate(context); }", + "className": "AuthGuard", + "unitType": "class_method", + "startLine": 3, + "endLine": 5, + }, +} + + +class TestResolveDependencies: + """Test RepositoryIndex.resolve_dependencies.""" + + def test_resolves_by_function_id(self): + index = _make_index(SAMPLE_FUNCTIONS) + result = index.resolve_dependencies([ + "src/user.service.ts:UserService.findById" + ]) + assert len(result) == 1 + assert result[0]["id"] == "src/user.service.ts:UserService.findById" + assert result[0]["className"] == "UserService" + + def test_resolves_by_qualified_name(self): + """Resolve using Class.method format when full ID is unknown.""" + index = _make_index(SAMPLE_FUNCTIONS) + result = index.resolve_dependencies(["AuthGuard.canActivate"]) + assert len(result) == 1 + assert "AuthGuard.canActivate" in result[0]["id"] + + def test_returns_empty_for_unknown(self): + index = _make_index(SAMPLE_FUNCTIONS) + result = index.resolve_dependencies(["nonExistentFunction"]) + assert result == [] + + def test_deduplicates_results(self): + index = _make_index(SAMPLE_FUNCTIONS) + result = index.resolve_dependencies([ + "src/user.service.ts:UserService.findById", + "src/user.service.ts:UserService.findById", + ]) + assert len(result) == 1 + + +class TestGetStaticDependenciesTool: + """Test the get_static_dependencies tool via ToolExecutor.""" + + def test_returns_resolved_deps(self): + index = _make_index(SAMPLE_FUNCTIONS) + executor = ToolExecutor(index) + executor.set_unit_context( + static_deps=["src/user.service.ts:UserService.findById"], + static_callers=[], + ) + + result = executor.execute("get_static_dependencies", {}) + assert result["dependencies"]["count"] == 1 + assert len(result["dependencies"]["resolved"]) == 1 + assert result["dependencies"]["resolved"][0]["className"] == "UserService" + assert result["callers"]["count"] == 0 + + def test_returns_resolved_callers(self): + index = _make_index(SAMPLE_FUNCTIONS) + executor = ToolExecutor(index) + executor.set_unit_context( + static_deps=[], + static_callers=["src/user.controller.ts:UserController.getUser"], + ) + + result = executor.execute("get_static_dependencies", {}) + assert result["callers"]["count"] == 1 + assert result["callers"]["resolved"][0]["className"] == "UserController" + + def test_empty_context(self): + index = _make_index(SAMPLE_FUNCTIONS) + executor = ToolExecutor(index) + executor.set_unit_context([], []) + + result = executor.execute("get_static_dependencies", {}) + assert result["dependencies"]["count"] == 0 + assert result["callers"]["count"] == 0 + + def test_context_resets_between_units(self): + index = _make_index(SAMPLE_FUNCTIONS) + executor = ToolExecutor(index) + + # First unit + executor.set_unit_context( + static_deps=["src/user.service.ts:UserService.findById"], + static_callers=[], + ) + result1 = executor.execute("get_static_dependencies", {}) + assert result1["dependencies"]["count"] == 1 + + # Second unit - different context + executor.set_unit_context(static_deps=[], static_callers=[]) + result2 = executor.execute("get_static_dependencies", {}) + assert result2["dependencies"]["count"] == 0 diff --git a/libs/openant-core/utilities/agentic_enhancer/agent.py b/libs/openant-core/utilities/agentic_enhancer/agent.py index 62061b7..513f728 100644 --- a/libs/openant-core/utilities/agentic_enhancer/agent.py +++ b/libs/openant-core/utilities/agentic_enhancer/agent.py @@ -161,6 +161,9 @@ def analyze_unit( entry_point_path = self.reachability.get_entry_point_path(unit_id) reaching_entry_point = self.reachability.get_reaching_entry_point(unit_id) + # Set static deps on tool executor for get_static_dependencies tool + self.tool_executor.set_unit_context(static_deps, static_callers) + # Build initial prompt with reachability info user_prompt = get_user_prompt( unit_id=unit_id, diff --git a/libs/openant-core/utilities/agentic_enhancer/prompts.py b/libs/openant-core/utilities/agentic_enhancer/prompts.py index dd9ca83..0594bbc 100644 --- a/libs/openant-core/utilities/agentic_enhancer/prompts.py +++ b/libs/openant-core/utilities/agentic_enhancer/prompts.py @@ -39,25 +39,40 @@ ## Your Analysis Process -1. **Identify Dangerous Operations** +1. **Get Static Dependencies First** + Call `get_static_dependencies` to see what functions this code calls and what calls it. + Then use `read_function` to examine key dependencies — especially service methods + that may contain authorization, validation, or sanitization. + +2. **Identify Dangerous Operations** Look for: eval, exec, SQL queries, file I/O, deserialization, command execution, innerHTML -2. **Trace User Input Reachability** +3. **Trace User Input Reachability (Backward)** If dangerous operations exist, trace BACKWARDS: - Who calls this function? - Who calls those callers? - Does the chain lead to an entry point (route handler, CLI parser, stdin)? -3. **Apply Classification Logic** +4. **Trace Forward Into Called Functions** + Check what the function CALLS — especially service/repository methods: + - Use `search_definitions` to find implementations of called methods + - Look for authorization checks (auth, permission, guard, can, allow, authorize) + - Look for validation/sanitization in called code + - A function may delegate security to its callees (e.g., service-layer auth) + - For `this.someService.method()` patterns, search for the method name definition + +5. **Apply Classification Logic** ``` Has dangerous sink? ├─ No → NEUTRAL or SECURITY_CONTROL └─ Yes → Is reachable from entry point? - ├─ Yes → EXPLOITABLE + ├─ Yes → Are there security controls in called functions? + │ ├─ Yes → May be SECURITY_CONTROL or lower severity + │ └─ No → EXPLOITABLE └─ No → VULNERABLE_INTERNAL ``` -4. **Complete with finish tool** +6. **Complete with finish tool** Provide classification, reasoning, and confidence level. ## Entry Point Examples @@ -150,19 +165,25 @@ def get_user_prompt( ## Your Task -1. **Analyze for dangerous operations**: eval, exec, SQL, file I/O, deserialization, etc. +1. **Start with `get_static_dependencies`** to see resolved callees and callers. + Then use `read_function` to examine called service/repository methods. -2. **Consider reachability**: Can user input reach any dangerous operations? +2. **Analyze for dangerous operations**: eval, exec, SQL, file I/O, deserialization, etc. + +3. **Consider reachability**: Can user input reach any dangerous operations? - If this is an entry point or reachable from one: vulnerabilities are EXPLOITABLE - If not reachable: vulnerabilities are VULNERABLE_INTERNAL -3. **Classify the code**: - - **EXPLOITABLE**: Dangerous ops + user input can reach them +4. **Trace forward**: Check called functions for authorization, validation, or security controls. + A function may delegate security to its service layer. + +5. **Classify the code**: + - **EXPLOITABLE**: Dangerous ops + user input can reach them + no security controls in callees - **VULNERABLE_INTERNAL**: Dangerous ops but no user input path - **SECURITY_CONTROL**: Defensive code (validators, sanitizers) - **NEUTRAL**: No security relevance -4. Call the `finish` tool with your classification and reasoning. +6. Call the `finish` tool with your classification and reasoning. Begin your analysis.""" diff --git a/libs/openant-core/utilities/agentic_enhancer/repository_index.py b/libs/openant-core/utilities/agentic_enhancer/repository_index.py index 06ef199..e027335 100644 --- a/libs/openant-core/utilities/agentic_enhancer/repository_index.py +++ b/libs/openant-core/utilities/agentic_enhancer/repository_index.py @@ -246,6 +246,54 @@ def read_file_section(self, file_path: str, start_line: int, end_line: int) -> O except Exception: return None + def resolve_dependencies(self, dep_names: list[str]) -> list[dict]: + """ + Resolve dependency names from static analysis to function entries. + + Handles both full function IDs (file:Class.method) and simple names. + + Args: + dep_names: List of function IDs or names from static analysis + + Returns: + List of {name, id, file, className} for each resolved dependency + """ + results = [] + seen_ids = set() + + for name in dep_names: + # First try as a direct function ID + func = self.functions.get(name) + if func and name not in seen_ids: + seen_ids.add(name) + results.append({ + "name": name, + "id": name, + "file": name.rsplit(":", 1)[0] if ":" in name else "", + "className": func.get("className") + }) + continue + + # Try exact name match + matches = self.search_by_name(name, exact=True) + if not matches: + # Try just the method part (e.g., "Class.method" -> "method") + parts = name.rsplit(".", 1) + if len(parts) == 2: + matches = self.search_by_name(parts[1], exact=True) + + for m in matches: + if m["id"] not in seen_ids: + seen_ids.add(m["id"]) + results.append({ + "name": name, + "id": m["id"], + "file": m["id"].rsplit(":", 1)[0] if ":" in m["id"] else "", + "className": m.get("className") + }) + + return results + def get_all_function_ids(self) -> list[str]: """ Get list of all function IDs. diff --git a/libs/openant-core/utilities/agentic_enhancer/tools.py b/libs/openant-core/utilities/agentic_enhancer/tools.py index b380c2c..8cf0947 100644 --- a/libs/openant-core/utilities/agentic_enhancer/tools.py +++ b/libs/openant-core/utilities/agentic_enhancer/tools.py @@ -102,6 +102,15 @@ "required": ["file_path", "start_line", "end_line"] } }, + { + "name": "get_static_dependencies", + "description": "Get the statically-analyzed dependencies (functions called) and callers for the unit being analyzed. Returns resolved function IDs that can be read with read_function. Use this first to understand what the code calls and to trace into service methods for auth/validation checks.", + "input_schema": { + "type": "object", + "properties": {}, + "required": [] + } + }, { "name": "finish", "description": "Complete the analysis and return the final result. Call this when you have gathered enough context to understand the code's intent and security implications.", @@ -165,6 +174,13 @@ def __init__(self, index: RepositoryIndex): index: RepositoryIndex instance for searching """ self.index = index + self._unit_static_deps: list[str] = [] + self._unit_static_callers: list[str] = [] + + def set_unit_context(self, static_deps: list[str], static_callers: list[str]): + """Set static dependency data for the current unit being analyzed.""" + self._unit_static_deps = static_deps or [] + self._unit_static_callers = static_callers or [] def execute(self, tool_name: str, tool_input: dict) -> dict: """ @@ -188,6 +204,8 @@ def execute(self, tool_name: str, tool_input: dict) -> dict: return self._list_functions(tool_input) elif tool_name == "read_file_section": return self._read_file_section(tool_input) + elif tool_name == "get_static_dependencies": + return self._get_static_dependencies(tool_input) elif tool_name == "finish": return self._finish(tool_input) else: @@ -315,6 +333,24 @@ def _read_file_section(self, input: dict) -> dict: "content": content } + def _get_static_dependencies(self, input: dict) -> dict: + """Get resolved static dependencies and callers for the current unit.""" + resolved_deps = self.index.resolve_dependencies(self._unit_static_deps) + resolved_callers = self.index.resolve_dependencies(self._unit_static_callers) + + return { + "dependencies": { + "raw": self._unit_static_deps[:20], + "resolved": resolved_deps[:20], + "count": len(self._unit_static_deps) + }, + "callers": { + "raw": self._unit_static_callers[:20], + "resolved": resolved_callers[:20], + "count": len(self._unit_static_callers) + } + } + def _finish(self, input: dict) -> dict: """Process finish tool - just validate and return the input.""" required = ["include_functions", "usage_context", "security_classification", "classification_reasoning", "confidence"] From 22aa52f9e8bec80498ff22bb083898bc2cf1cf84 Mon Sep 17 00:00:00 2001 From: joshbouncesecurity Date: Mon, 4 May 2026 22:31:22 +0300 Subject: [PATCH 2/7] @ fix(tests): replace missing run_utf8 import with subprocess.run test_di_resolution.py imported `run_utf8` from `utilities.file_io`, which does not exist in this repo. The import made the test module unimportable and broke pytest collection for the file (and any wider collection that included it). Mirror the helper used in test_js_parser.py and call subprocess.run directly. Co-Authored-By: Claude Opus 4.7 (1M context) @ --- libs/openant-core/tests/test_di_resolution.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/libs/openant-core/tests/test_di_resolution.py b/libs/openant-core/tests/test_di_resolution.py index 309b1f0..2dcecbe 100644 --- a/libs/openant-core/tests/test_di_resolution.py +++ b/libs/openant-core/tests/test_di_resolution.py @@ -7,11 +7,10 @@ cd parsers/javascript && npm install """ import json +import subprocess import shutil from pathlib import Path -from utilities.file_io import run_utf8 - import pytest PARSERS_JS_DIR = Path(__file__).parent.parent / "parsers" / "javascript" @@ -26,7 +25,7 @@ def run_node(script_name, *args): """Run a Node.js script from the JS parsers directory.""" cmd = ["node", str(PARSERS_JS_DIR / script_name)] + list(args) - return run_utf8(cmd, capture_output=True, text=True, timeout=30) + return subprocess.run(cmd, capture_output=True, text=True, timeout=30) # -- Fixture: NestJS-style DI codebase -- From d0e2214028344ecd48d423fbfab4fe14fe3bf1ff Mon Sep 17 00:00:00 2001 From: joshbouncesecurity Date: Tue, 12 May 2026 11:53:55 +0300 Subject: [PATCH 3/7] refactor: remove agentic enhancer changes (split to separate PR) Agentic enhancer changes (get_static_dependencies tool, prompt rewrite, resolve_dependencies) moved to feat/issue16-07-ts-di-enhancer for independent review and eval. This branch now contains only the DI-aware parser changes as requested in the PR #39 review. Co-Authored-By: Claude Sonnet 4.6 --- .../openant-core/tests/test_enhancer_tools.py | 127 ------------------ .../utilities/agentic_enhancer/agent.py | 3 - .../utilities/agentic_enhancer/prompts.py | 41 ++---- .../agentic_enhancer/repository_index.py | 54 +------- .../utilities/agentic_enhancer/tools.py | 36 ----- 5 files changed, 13 insertions(+), 248 deletions(-) delete mode 100644 libs/openant-core/tests/test_enhancer_tools.py diff --git a/libs/openant-core/tests/test_enhancer_tools.py b/libs/openant-core/tests/test_enhancer_tools.py deleted file mode 100644 index a862f05..0000000 --- a/libs/openant-core/tests/test_enhancer_tools.py +++ /dev/null @@ -1,127 +0,0 @@ -"""Tests for the agentic enhancer tools, specifically the get_static_dependencies tool.""" -import pytest - -from utilities.agentic_enhancer.repository_index import RepositoryIndex -from utilities.agentic_enhancer.tools import ToolExecutor - - -def _make_index(functions: dict) -> RepositoryIndex: - """Create a RepositoryIndex from a minimal functions dict.""" - return RepositoryIndex({"functions": functions}) - - -SAMPLE_FUNCTIONS = { - "src/user.controller.ts:UserController.getUser": { - "name": "UserController.getUser", - "code": "async getUser(id) { return this.userService.findById(id); }", - "className": "UserController", - "unitType": "class_method", - "startLine": 10, - "endLine": 12, - }, - "src/user.service.ts:UserService.findById": { - "name": "UserService.findById", - "code": "async findById(id) { return this.repo.findOne(id); }", - "className": "UserService", - "unitType": "class_method", - "startLine": 5, - "endLine": 7, - }, - "src/auth.guard.ts:AuthGuard.canActivate": { - "name": "AuthGuard.canActivate", - "code": "canActivate(context) { return this.validate(context); }", - "className": "AuthGuard", - "unitType": "class_method", - "startLine": 3, - "endLine": 5, - }, -} - - -class TestResolveDependencies: - """Test RepositoryIndex.resolve_dependencies.""" - - def test_resolves_by_function_id(self): - index = _make_index(SAMPLE_FUNCTIONS) - result = index.resolve_dependencies([ - "src/user.service.ts:UserService.findById" - ]) - assert len(result) == 1 - assert result[0]["id"] == "src/user.service.ts:UserService.findById" - assert result[0]["className"] == "UserService" - - def test_resolves_by_qualified_name(self): - """Resolve using Class.method format when full ID is unknown.""" - index = _make_index(SAMPLE_FUNCTIONS) - result = index.resolve_dependencies(["AuthGuard.canActivate"]) - assert len(result) == 1 - assert "AuthGuard.canActivate" in result[0]["id"] - - def test_returns_empty_for_unknown(self): - index = _make_index(SAMPLE_FUNCTIONS) - result = index.resolve_dependencies(["nonExistentFunction"]) - assert result == [] - - def test_deduplicates_results(self): - index = _make_index(SAMPLE_FUNCTIONS) - result = index.resolve_dependencies([ - "src/user.service.ts:UserService.findById", - "src/user.service.ts:UserService.findById", - ]) - assert len(result) == 1 - - -class TestGetStaticDependenciesTool: - """Test the get_static_dependencies tool via ToolExecutor.""" - - def test_returns_resolved_deps(self): - index = _make_index(SAMPLE_FUNCTIONS) - executor = ToolExecutor(index) - executor.set_unit_context( - static_deps=["src/user.service.ts:UserService.findById"], - static_callers=[], - ) - - result = executor.execute("get_static_dependencies", {}) - assert result["dependencies"]["count"] == 1 - assert len(result["dependencies"]["resolved"]) == 1 - assert result["dependencies"]["resolved"][0]["className"] == "UserService" - assert result["callers"]["count"] == 0 - - def test_returns_resolved_callers(self): - index = _make_index(SAMPLE_FUNCTIONS) - executor = ToolExecutor(index) - executor.set_unit_context( - static_deps=[], - static_callers=["src/user.controller.ts:UserController.getUser"], - ) - - result = executor.execute("get_static_dependencies", {}) - assert result["callers"]["count"] == 1 - assert result["callers"]["resolved"][0]["className"] == "UserController" - - def test_empty_context(self): - index = _make_index(SAMPLE_FUNCTIONS) - executor = ToolExecutor(index) - executor.set_unit_context([], []) - - result = executor.execute("get_static_dependencies", {}) - assert result["dependencies"]["count"] == 0 - assert result["callers"]["count"] == 0 - - def test_context_resets_between_units(self): - index = _make_index(SAMPLE_FUNCTIONS) - executor = ToolExecutor(index) - - # First unit - executor.set_unit_context( - static_deps=["src/user.service.ts:UserService.findById"], - static_callers=[], - ) - result1 = executor.execute("get_static_dependencies", {}) - assert result1["dependencies"]["count"] == 1 - - # Second unit - different context - executor.set_unit_context(static_deps=[], static_callers=[]) - result2 = executor.execute("get_static_dependencies", {}) - assert result2["dependencies"]["count"] == 0 diff --git a/libs/openant-core/utilities/agentic_enhancer/agent.py b/libs/openant-core/utilities/agentic_enhancer/agent.py index 513f728..62061b7 100644 --- a/libs/openant-core/utilities/agentic_enhancer/agent.py +++ b/libs/openant-core/utilities/agentic_enhancer/agent.py @@ -161,9 +161,6 @@ def analyze_unit( entry_point_path = self.reachability.get_entry_point_path(unit_id) reaching_entry_point = self.reachability.get_reaching_entry_point(unit_id) - # Set static deps on tool executor for get_static_dependencies tool - self.tool_executor.set_unit_context(static_deps, static_callers) - # Build initial prompt with reachability info user_prompt = get_user_prompt( unit_id=unit_id, diff --git a/libs/openant-core/utilities/agentic_enhancer/prompts.py b/libs/openant-core/utilities/agentic_enhancer/prompts.py index 0594bbc..dd9ca83 100644 --- a/libs/openant-core/utilities/agentic_enhancer/prompts.py +++ b/libs/openant-core/utilities/agentic_enhancer/prompts.py @@ -39,40 +39,25 @@ ## Your Analysis Process -1. **Get Static Dependencies First** - Call `get_static_dependencies` to see what functions this code calls and what calls it. - Then use `read_function` to examine key dependencies — especially service methods - that may contain authorization, validation, or sanitization. - -2. **Identify Dangerous Operations** +1. **Identify Dangerous Operations** Look for: eval, exec, SQL queries, file I/O, deserialization, command execution, innerHTML -3. **Trace User Input Reachability (Backward)** +2. **Trace User Input Reachability** If dangerous operations exist, trace BACKWARDS: - Who calls this function? - Who calls those callers? - Does the chain lead to an entry point (route handler, CLI parser, stdin)? -4. **Trace Forward Into Called Functions** - Check what the function CALLS — especially service/repository methods: - - Use `search_definitions` to find implementations of called methods - - Look for authorization checks (auth, permission, guard, can, allow, authorize) - - Look for validation/sanitization in called code - - A function may delegate security to its callees (e.g., service-layer auth) - - For `this.someService.method()` patterns, search for the method name definition - -5. **Apply Classification Logic** +3. **Apply Classification Logic** ``` Has dangerous sink? ├─ No → NEUTRAL or SECURITY_CONTROL └─ Yes → Is reachable from entry point? - ├─ Yes → Are there security controls in called functions? - │ ├─ Yes → May be SECURITY_CONTROL or lower severity - │ └─ No → EXPLOITABLE + ├─ Yes → EXPLOITABLE └─ No → VULNERABLE_INTERNAL ``` -6. **Complete with finish tool** +4. **Complete with finish tool** Provide classification, reasoning, and confidence level. ## Entry Point Examples @@ -165,25 +150,19 @@ def get_user_prompt( ## Your Task -1. **Start with `get_static_dependencies`** to see resolved callees and callers. - Then use `read_function` to examine called service/repository methods. +1. **Analyze for dangerous operations**: eval, exec, SQL, file I/O, deserialization, etc. -2. **Analyze for dangerous operations**: eval, exec, SQL, file I/O, deserialization, etc. - -3. **Consider reachability**: Can user input reach any dangerous operations? +2. **Consider reachability**: Can user input reach any dangerous operations? - If this is an entry point or reachable from one: vulnerabilities are EXPLOITABLE - If not reachable: vulnerabilities are VULNERABLE_INTERNAL -4. **Trace forward**: Check called functions for authorization, validation, or security controls. - A function may delegate security to its service layer. - -5. **Classify the code**: - - **EXPLOITABLE**: Dangerous ops + user input can reach them + no security controls in callees +3. **Classify the code**: + - **EXPLOITABLE**: Dangerous ops + user input can reach them - **VULNERABLE_INTERNAL**: Dangerous ops but no user input path - **SECURITY_CONTROL**: Defensive code (validators, sanitizers) - **NEUTRAL**: No security relevance -6. Call the `finish` tool with your classification and reasoning. +4. Call the `finish` tool with your classification and reasoning. Begin your analysis.""" diff --git a/libs/openant-core/utilities/agentic_enhancer/repository_index.py b/libs/openant-core/utilities/agentic_enhancer/repository_index.py index e027335..5af649c 100644 --- a/libs/openant-core/utilities/agentic_enhancer/repository_index.py +++ b/libs/openant-core/utilities/agentic_enhancer/repository_index.py @@ -14,11 +14,12 @@ load_index_from_file: Load index from analyzer_output.json file """ -import json import re from pathlib import Path from typing import Optional +from utilities.file_io import read_json + class RepositoryIndex: """ @@ -246,54 +247,6 @@ def read_file_section(self, file_path: str, start_line: int, end_line: int) -> O except Exception: return None - def resolve_dependencies(self, dep_names: list[str]) -> list[dict]: - """ - Resolve dependency names from static analysis to function entries. - - Handles both full function IDs (file:Class.method) and simple names. - - Args: - dep_names: List of function IDs or names from static analysis - - Returns: - List of {name, id, file, className} for each resolved dependency - """ - results = [] - seen_ids = set() - - for name in dep_names: - # First try as a direct function ID - func = self.functions.get(name) - if func and name not in seen_ids: - seen_ids.add(name) - results.append({ - "name": name, - "id": name, - "file": name.rsplit(":", 1)[0] if ":" in name else "", - "className": func.get("className") - }) - continue - - # Try exact name match - matches = self.search_by_name(name, exact=True) - if not matches: - # Try just the method part (e.g., "Class.method" -> "method") - parts = name.rsplit(".", 1) - if len(parts) == 2: - matches = self.search_by_name(parts[1], exact=True) - - for m in matches: - if m["id"] not in seen_ids: - seen_ids.add(m["id"]) - results.append({ - "name": name, - "id": m["id"], - "file": m["id"].rsplit(":", 1)[0] if ":" in m["id"] else "", - "className": m.get("className") - }) - - return results - def get_all_function_ids(self) -> list[str]: """ Get list of all function IDs. @@ -331,7 +284,6 @@ def load_index_from_file(analyzer_output_path: str, repo_path: str = None) -> Re Returns: RepositoryIndex instance """ - with open(analyzer_output_path, 'r') as f: - analyzer_output = json.load(f) + analyzer_output = read_json(analyzer_output_path) return RepositoryIndex(analyzer_output, repo_path) diff --git a/libs/openant-core/utilities/agentic_enhancer/tools.py b/libs/openant-core/utilities/agentic_enhancer/tools.py index 8cf0947..b380c2c 100644 --- a/libs/openant-core/utilities/agentic_enhancer/tools.py +++ b/libs/openant-core/utilities/agentic_enhancer/tools.py @@ -102,15 +102,6 @@ "required": ["file_path", "start_line", "end_line"] } }, - { - "name": "get_static_dependencies", - "description": "Get the statically-analyzed dependencies (functions called) and callers for the unit being analyzed. Returns resolved function IDs that can be read with read_function. Use this first to understand what the code calls and to trace into service methods for auth/validation checks.", - "input_schema": { - "type": "object", - "properties": {}, - "required": [] - } - }, { "name": "finish", "description": "Complete the analysis and return the final result. Call this when you have gathered enough context to understand the code's intent and security implications.", @@ -174,13 +165,6 @@ def __init__(self, index: RepositoryIndex): index: RepositoryIndex instance for searching """ self.index = index - self._unit_static_deps: list[str] = [] - self._unit_static_callers: list[str] = [] - - def set_unit_context(self, static_deps: list[str], static_callers: list[str]): - """Set static dependency data for the current unit being analyzed.""" - self._unit_static_deps = static_deps or [] - self._unit_static_callers = static_callers or [] def execute(self, tool_name: str, tool_input: dict) -> dict: """ @@ -204,8 +188,6 @@ def execute(self, tool_name: str, tool_input: dict) -> dict: return self._list_functions(tool_input) elif tool_name == "read_file_section": return self._read_file_section(tool_input) - elif tool_name == "get_static_dependencies": - return self._get_static_dependencies(tool_input) elif tool_name == "finish": return self._finish(tool_input) else: @@ -333,24 +315,6 @@ def _read_file_section(self, input: dict) -> dict: "content": content } - def _get_static_dependencies(self, input: dict) -> dict: - """Get resolved static dependencies and callers for the current unit.""" - resolved_deps = self.index.resolve_dependencies(self._unit_static_deps) - resolved_callers = self.index.resolve_dependencies(self._unit_static_callers) - - return { - "dependencies": { - "raw": self._unit_static_deps[:20], - "resolved": resolved_deps[:20], - "count": len(self._unit_static_deps) - }, - "callers": { - "raw": self._unit_static_callers[:20], - "resolved": resolved_callers[:20], - "count": len(self._unit_static_callers) - } - } - def _finish(self, input: dict) -> dict: """Process finish tool - just validate and return the input.""" required = ["include_functions", "usage_context", "security_classification", "classification_reasoning", "confidence"] From 1bc1331d4908a6f68f50c8a96766e125d6bc3c10 Mon Sep 17 00:00:00 2001 From: joshbouncesecurity Date: Tue, 12 May 2026 14:28:29 +0300 Subject: [PATCH 4/7] fix(di-resolution): address PR review feedback - dependency_resolver.js: prefix matching now collects all candidates before returning; skips resolution when multiple classes share the prefix (e.g. CallService matches both CallServiceV1 and CallServiceMock) to preserve the resolver's no-false-positive property - typescript_analyzer.js: strip generic parameters before PascalCase test so Repository resolves as Repository (covers NestJS TypeORM) - typescript_analyzer.js: document the single-constructor assumption - test_di_resolution.py: add test for ambiguous prefix case Co-Authored-By: Claude Sonnet 4.6 --- .../parsers/javascript/dependency_resolver.js | 7 +++- .../parsers/javascript/typescript_analyzer.js | 6 ++- libs/openant-core/tests/test_di_resolution.py | 42 +++++++++++++++++++ 3 files changed, 52 insertions(+), 3 deletions(-) diff --git a/libs/openant-core/parsers/javascript/dependency_resolver.js b/libs/openant-core/parsers/javascript/dependency_resolver.js index 84769fd..fb82682 100644 --- a/libs/openant-core/parsers/javascript/dependency_resolver.js +++ b/libs/openant-core/parsers/javascript/dependency_resolver.js @@ -279,12 +279,17 @@ class DependencyResolver { // 2b. Implementation class match: type is often an interface/abstract class // and the implementation has a suffix (e.g., CallService -> CallServiceV1, CallServiceImpl) + // Collect all prefix matches; if more than one, skip to avoid non-deterministic resolution. + const prefixMatches = []; for (const funcId of candidates) { const funcData = this.functions[funcId]; if (funcData && funcData.className && funcData.className.startsWith(typeName)) { - return funcId; + prefixMatches.push(funcId); } } + if (prefixMatches.length === 1) { + return prefixMatches[0]; + } } } } diff --git a/libs/openant-core/parsers/javascript/typescript_analyzer.js b/libs/openant-core/parsers/javascript/typescript_analyzer.js index 08e3128..5aa002f 100644 --- a/libs/openant-core/parsers/javascript/typescript_analyzer.js +++ b/libs/openant-core/parsers/javascript/typescript_analyzer.js @@ -232,6 +232,7 @@ class TypeScriptAnalyzer { // In NestJS/Angular, constructor parameters with type annotations // declare injected services: constructor(private callService: CallService) const constructors = classDecl.getConstructors(); + // DI classes have a single primary constructor; overloads are unusual in NestJS/Angular. if (constructors.length > 0) { const ctor = constructors[0]; const injections = {}; // paramName -> typeName @@ -240,8 +241,9 @@ class TypeScriptAnalyzer { const paramName = param.getName(); const typeNode = param.getTypeNode(); if (typeNode) { - const typeName = typeNode.getText(); - // Only store simple PascalCase type names (skip union types, generics, primitives) + // Strip generic parameters so Repository resolves as Repository + const typeName = typeNode.getText().replace(/<.*$/, ''); + // Only store simple PascalCase type names (skip union types, primitives) if (/^[A-Z][a-zA-Z0-9_$]*$/.test(typeName)) { injections[paramName] = typeName; } diff --git a/libs/openant-core/tests/test_di_resolution.py b/libs/openant-core/tests/test_di_resolution.py index 2dcecbe..98f93ff 100644 --- a/libs/openant-core/tests/test_di_resolution.py +++ b/libs/openant-core/tests/test_di_resolution.py @@ -272,6 +272,48 @@ def test_resolves_multiple_di_methods(self, nestjs_repo): "CallService.remove" in c for c in delete_calls ), f"Expected CallService.remove in calls, got: {delete_calls}" + def test_ambiguous_prefix_skips_resolution(self, tmp_path): + """When multiple classes share a type-name prefix, resolution is skipped.""" + src = tmp_path / "src" + src.mkdir() + (src / "resolver.ts").write_text("""\ +export class MyResolver { + constructor(private callService: CallService) {} + getCall(id: string) { + return this.callService.getById(id); + } +} +""") + (src / "call_service.ts").write_text("""\ +export class CallServiceV1 { + getById(id: string) { return 'v1'; } +} +""") + (src / "call_service_mock.ts").write_text("""\ +export class CallServiceMock { + getById(id: string) { return 'mock'; } +} +""") + data = analyze_and_resolve(tmp_path, [ + "src/resolver.ts", + "src/call_service.ts", + "src/call_service_mock.ts", + ]) + + call_graph = data["callGraph"] + resolver_calls = None + for fid, calls in call_graph.items(): + if "MyResolver.getCall" in fid: + resolver_calls = calls + break + + # Two classes match the CallService prefix — should not resolve to either + assert resolver_calls is not None + assert not any( + "CallServiceV1.getById" in c or "CallServiceMock.getById" in c + for c in resolver_calls + ), f"Should not resolve ambiguous prefix match, got: {resolver_calls}" + def test_no_false_positives_without_di(self, tmp_path): """Methods without constructor deps should not spuriously resolve.""" src = tmp_path / "src" From f1d6060899949485c500cfddfc7acf1d2259a05e Mon Sep 17 00:00:00 2001 From: joshbouncesecurity Date: Tue, 12 May 2026 14:50:20 +0300 Subject: [PATCH 5/7] feat(di-resolution): class-level metadata table and nominal type matching Move constructorDeps from per-method entries to a class-level `classes` table (className -> { constructorDeps, baseTypes }), matching the output schema of the Python, PHP, Ruby, and Zig parsers. This eliminates N redundant copies per class and provides a single source of truth. Also extract implements/extends clauses as baseTypes and use them in the resolver for nominal type matching. Resolution priority is now: 1. Exact class name match (existing) 2. Exact injected type match (existing) 3. Nominal: unique class that implements/extends the type 4. Prefix: unambiguous prefix match (existing, last resort) Steps 3 and 4 both return null when multiple candidates match to preserve the no-false-positive property. Co-Authored-By: Claude Sonnet 4.6 --- .../parsers/javascript/dependency_resolver.js | 41 ++-- .../parsers/javascript/typescript_analyzer.js | 40 ++-- libs/openant-core/tests/test_di_resolution.py | 198 ++++++++++++++++-- 3 files changed, 230 insertions(+), 49 deletions(-) diff --git a/libs/openant-core/parsers/javascript/dependency_resolver.js b/libs/openant-core/parsers/javascript/dependency_resolver.js index fb82682..90679cc 100644 --- a/libs/openant-core/parsers/javascript/dependency_resolver.js +++ b/libs/openant-core/parsers/javascript/dependency_resolver.js @@ -20,6 +20,7 @@ const path = require('path'); class DependencyResolver { constructor(analyzerOutput, options = {}) { this.functions = analyzerOutput.functions || {}; + this.classes = analyzerOutput.classes || {}; // className -> { constructorDeps, baseTypes } this.callGraph = {}; // functionId -> [calledFunctionIds] this.reverseCallGraph = {}; // functionId -> [callerFunctionIds] this.maxDepth = options.maxDepth || 3; @@ -29,6 +30,7 @@ class DependencyResolver { this.functionsByName = Object.create(null); // simpleName -> [functionIds] this.functionsByFile = Object.create(null); // filePath -> [functionIds] this.imports = Object.create(null); // filePath -> { importedName -> { source, originalName } } + this.classesByBaseType = Object.create(null); // baseTypeName -> [classNames] this._buildIndexes(); } @@ -52,6 +54,13 @@ class DependencyResolver { } this.functionsByFile[filePath].push(funcId); } + + for (const [className, classData] of Object.entries(this.classes)) { + for (const baseType of (classData.baseTypes || [])) { + if (!this.classesByBaseType[baseType]) this.classesByBaseType[baseType] = []; + this.classesByBaseType[baseType].push(className); + } + } } /** @@ -266,8 +275,9 @@ class DependencyResolver { // -> resolve to CallService.getById if (callerFuncId) { const callerFunc = this.functions[callerFuncId]; - if (callerFunc && callerFunc.constructorDeps) { - const typeName = callerFunc.constructorDeps[objectName]; + const classEntry = callerFunc && this.classes[callerFunc.className]; + if (classEntry && classEntry.constructorDeps) { + const typeName = classEntry.constructorDeps[objectName]; if (typeName) { // 2a. Exact type match for (const funcId of candidates) { @@ -277,19 +287,22 @@ class DependencyResolver { } } - // 2b. Implementation class match: type is often an interface/abstract class - // and the implementation has a suffix (e.g., CallService -> CallServiceV1, CallServiceImpl) - // Collect all prefix matches; if more than one, skip to avoid non-deterministic resolution. - const prefixMatches = []; - for (const funcId of candidates) { + // 2b. Nominal type match: prefer candidates whose class implements or extends typeName. + // If exactly one such candidate exists, the resolution is unambiguous. + const nominalClassNames = this.classesByBaseType[typeName] || []; + const nominalMatches = candidates.filter(funcId => { const funcData = this.functions[funcId]; - if (funcData && funcData.className && funcData.className.startsWith(typeName)) { - prefixMatches.push(funcId); - } - } - if (prefixMatches.length === 1) { - return prefixMatches[0]; - } + return funcData && nominalClassNames.includes(funcData.className); + }); + if (nominalMatches.length === 1) return nominalMatches[0]; + + // 2c. Prefix match: last resort for versioned names (e.g., CallService -> CallServiceV1). + // Skip if multiple candidates match to preserve no-false-positive property. + const prefixMatches = candidates.filter(funcId => { + const funcData = this.functions[funcId]; + return funcData && funcData.className && funcData.className.startsWith(typeName); + }); + if (prefixMatches.length === 1) return prefixMatches[0]; } } } diff --git a/libs/openant-core/parsers/javascript/typescript_analyzer.js b/libs/openant-core/parsers/javascript/typescript_analyzer.js index 5aa002f..5b49bc1 100644 --- a/libs/openant-core/parsers/javascript/typescript_analyzer.js +++ b/libs/openant-core/parsers/javascript/typescript_analyzer.js @@ -55,6 +55,7 @@ class TypeScriptAnalyzer { compilerOptions: PERMISSIVE_COMPILER_OPTIONS, }); this.functions = {}; // functionId -> function metadata + this.classes = {}; // className -> { constructorDeps, baseTypes } this.callGraph = {}; // callerId -> array of call info } @@ -155,6 +156,7 @@ class TypeScriptAnalyzer { return { functions: this.functions, + classes: this.classes, callGraph: this.callGraph, }; } @@ -228,11 +230,26 @@ class TypeScriptAnalyzer { }; } - // Extract constructor DI metadata for this class - // In NestJS/Angular, constructor parameters with type annotations - // declare injected services: constructor(private callService: CallService) - const constructors = classDecl.getConstructors(); + // Build class-level metadata: constructorDeps and baseTypes + const classEntry = {}; + + // Extract base types (implements + extends) for nominal DI resolution. + // Strips generics: implements Repository -> Repository + const baseTypes = []; + const extendsExpr = classDecl.getExtends(); + if (extendsExpr) { + const name = extendsExpr.getExpression().getText().replace(/<.*$/, ''); + if (/^[A-Z][a-zA-Z0-9_$]*$/.test(name)) baseTypes.push(name); + } + for (const impl of classDecl.getImplements()) { + const name = impl.getExpression().getText().replace(/<.*$/, ''); + if (/^[A-Z][a-zA-Z0-9_$]*$/.test(name)) baseTypes.push(name); + } + if (baseTypes.length > 0) classEntry.baseTypes = baseTypes; + + // Extract constructor DI metadata. // DI classes have a single primary constructor; overloads are unusual in NestJS/Angular. + const constructors = classDecl.getConstructors(); if (constructors.length > 0) { const ctor = constructors[0]; const injections = {}; // paramName -> typeName @@ -250,16 +267,11 @@ class TypeScriptAnalyzer { } } - if (Object.keys(injections).length > 0) { - // Store DI metadata on each method of this class - for (const method of classDecl.getMethods()) { - const methodName = method.getName(); - const functionId = `${relativePath}:${className}.${methodName}`; - if (this.functions[functionId]) { - this.functions[functionId].constructorDeps = injections; - } - } - } + if (Object.keys(injections).length > 0) classEntry.constructorDeps = injections; + } + + if (Object.keys(classEntry).length > 0) { + this.classes[className] = classEntry; } } diff --git a/libs/openant-core/tests/test_di_resolution.py b/libs/openant-core/tests/test_di_resolution.py index 98f93ff..54ad5a8 100644 --- a/libs/openant-core/tests/test_di_resolution.py +++ b/libs/openant-core/tests/test_di_resolution.py @@ -97,6 +97,39 @@ def run_node(script_name, *args): } """ +# Interface + implementing class for nominal type tests +ICALL_SERVICE_TS = """\ +export interface ICallService { + getById(id: string): Promise; +} +""" + +IMPL_CALL_SERVICE_TS = """\ +import { Injectable } from '@nestjs/common'; +import { ICallService } from './icall.service'; + +@Injectable() +export class CallServiceImpl implements ICallService { + async getById(id: string) { + return { id }; + } +} +""" + +NOMINAL_RESOLVER_TS = """\ +import { Injectable } from '@nestjs/common'; +import { ICallService } from './icall.service'; + +@Injectable() +export class CallResolver { + constructor(private callService: ICallService) {} + + async getCall(id: string) { + return this.callService.getById(id); + } +} +""" + @pytest.fixture def nestjs_repo(tmp_path): @@ -119,6 +152,17 @@ def nestjs_repo_versioned(tmp_path): return tmp_path +@pytest.fixture +def nestjs_repo_nominal(tmp_path): + """Create a repo where injection is via interface and impl uses implements.""" + src = tmp_path / "src" + src.mkdir() + (src / "icall.service.ts").write_text(ICALL_SERVICE_TS) + (src / "call.service.impl.ts").write_text(IMPL_CALL_SERVICE_TS) + (src / "call.resolver.ts").write_text(NOMINAL_RESOLVER_TS) + return tmp_path + + def analyze_and_resolve(repo_path, files): """Run analyzer + resolver on given files and return resolved data.""" analyzer_out = repo_path / "analyzer_output.json" @@ -142,7 +186,7 @@ def analyze_and_resolve(repo_path, files): class TestConstructorDepsExtraction: - """Test that the analyzer extracts constructorDeps from class constructors.""" + """Test that the analyzer extracts constructorDeps into the classes table.""" def test_extracts_constructor_deps(self, nestjs_repo): analyzer_out = nestjs_repo / "analyzer_output.json" @@ -154,21 +198,17 @@ def test_extracts_constructor_deps(self, nestjs_repo): assert result.returncode == 0 data = json.loads(analyzer_out.read_text()) - functions = data["functions"] - - # Find a CallResolver method - resolver_methods = { - fid: f for fid, f in functions.items() - if "CallResolver" in fid - } - assert len(resolver_methods) > 0, "No CallResolver methods found" - - # Each method should have constructorDeps - for fid, func in resolver_methods.items(): - assert "constructorDeps" in func, f"{fid} missing constructorDeps" - deps = func["constructorDeps"] - assert deps.get("callService") == "CallService" - assert deps.get("authService") == "AuthService" + classes = data["classes"] + + assert "CallResolver" in classes, "CallResolver not in classes table" + deps = classes["CallResolver"].get("constructorDeps", {}) + assert deps.get("callService") == "CallService" + assert deps.get("authService") == "AuthService" + + # Methods themselves should NOT carry constructorDeps (stored in classes table instead) + for fid, func in data["functions"].items(): + if "CallResolver" in fid: + assert "constructorDeps" not in func, f"{fid} should not have constructorDeps" def test_skips_primitive_types(self, tmp_path): """Constructor params with primitive types should not be included.""" @@ -196,11 +236,7 @@ def test_skips_primitive_types(self, tmp_path): assert result.returncode == 0 data = json.loads(analyzer_out.read_text()) - func = next( - f for f in data["functions"].values() - if f.get("className") == "Example" - ) - deps = func.get("constructorDeps", {}) + deps = data["classes"].get("Example", {}).get("constructorDeps", {}) # Only MyService should be captured (PascalCase), not string/number assert "service" in deps assert deps["service"] == "MyService" @@ -208,6 +244,126 @@ def test_skips_primitive_types(self, tmp_path): assert "count" not in deps +class TestBaseTypesExtraction: + """Test that the analyzer extracts implements/extends into baseTypes.""" + + def test_extracts_implements(self, nestjs_repo_nominal): + analyzer_out = nestjs_repo_nominal / "analyzer_output.json" + result = run_node( + "typescript_analyzer.js", str(nestjs_repo_nominal), + "src/call.service.impl.ts", + "--output", str(analyzer_out), + ) + assert result.returncode == 0 + + data = json.loads(analyzer_out.read_text()) + base_types = data["classes"].get("CallServiceImpl", {}).get("baseTypes", []) + assert "ICallService" in base_types + + def test_generic_implements_stripped(self, tmp_path): + """implements Repository should store as Repository.""" + src = tmp_path / "src" + src.mkdir() + (src / "impl.ts").write_text("""\ +export class UserRepo implements Repository { + findOne(id: string) { return null; } +} +""") + analyzer_out = tmp_path / "analyzer_output.json" + result = run_node( + "typescript_analyzer.js", str(tmp_path), + "src/impl.ts", + "--output", str(analyzer_out), + ) + assert result.returncode == 0 + + data = json.loads(analyzer_out.read_text()) + base_types = data["classes"].get("UserRepo", {}).get("baseTypes", []) + assert "Repository" in base_types + assert not any("<" in t for t in base_types) + + def test_extracts_extends(self, tmp_path): + src = tmp_path / "src" + src.mkdir() + (src / "impl.ts").write_text("""\ +export class ConcreteService extends BaseService { + run() { return true; } +} +""") + analyzer_out = tmp_path / "analyzer_output.json" + result = run_node( + "typescript_analyzer.js", str(tmp_path), + "src/impl.ts", + "--output", str(analyzer_out), + ) + assert result.returncode == 0 + + data = json.loads(analyzer_out.read_text()) + base_types = data["classes"].get("ConcreteService", {}).get("baseTypes", []) + assert "BaseService" in base_types + + +class TestNominalTypeResolution: + """Test that implements/extends clauses are used for DI resolution.""" + + def test_resolves_via_implements(self, nestjs_repo_nominal): + """this.callService.getById() resolves to CallServiceImpl.getById via implements.""" + data = analyze_and_resolve(nestjs_repo_nominal, [ + "src/call.resolver.ts", + "src/call.service.impl.ts", + ]) + + call_graph = data["callGraph"] + resolver_calls = None + for fid, calls in call_graph.items(): + if "CallResolver.getCall" in fid: + resolver_calls = calls + break + + assert resolver_calls is not None, "CallResolver.getCall not in call graph" + assert any( + "CallServiceImpl.getById" in c for c in resolver_calls + ), f"Expected CallServiceImpl.getById via implements, got: {resolver_calls}" + + def test_nominal_ambiguity_skips_resolution(self, tmp_path): + """Two classes implementing same interface → no resolution (ambiguous).""" + src = tmp_path / "src" + src.mkdir() + (src / "resolver.ts").write_text("""\ +export class MyResolver { + constructor(private svc: IMyService) {} + work() { return this.svc.run(); } +} +""") + (src / "impl_a.ts").write_text("""\ +export class ImplA implements IMyService { + run() { return 'a'; } +} +""") + (src / "impl_b.ts").write_text("""\ +export class ImplB implements IMyService { + run() { return 'b'; } +} +""") + data = analyze_and_resolve(tmp_path, [ + "src/resolver.ts", + "src/impl_a.ts", + "src/impl_b.ts", + ]) + + call_graph = data["callGraph"] + resolver_calls = None + for fid, calls in call_graph.items(): + if "MyResolver.work" in fid: + resolver_calls = calls + break + + assert resolver_calls is not None + assert not any( + "ImplA.run" in c or "ImplB.run" in c for c in resolver_calls + ), f"Should not resolve ambiguous implements, got: {resolver_calls}" + + class TestDIAwareCallResolution: """Test that the dependency resolver uses constructorDeps for DI resolution.""" From 9e55a22ebd8e2e07ce938d907a7dd6ad58dd3a95 Mon Sep 17 00:00:00 2001 From: joshbouncesecurity Date: Tue, 12 May 2026 17:56:13 +0300 Subject: [PATCH 6/7] fix(di-resolution): file-qualify classes table key to prevent same-name collisions Keying this.classes by className alone was last-write-wins for NestJS multi-module apps where two files declare a class with the same name. The caller lookup in _resolveMethodCall then fetched the wrong class entry, silently dropping DI edges for the non-last-parsed class. Change both the analyzer (typescript_analyzer.js) and the resolver (dependency_resolver.js) to key by "filePath:className", matching the convention used by function IDs throughout the codebase. The classesByBaseType index now stores qualified keys; the nominal-match filter constructs the same key from each candidate's funcId prefix. Adds a regression test that asserts both entries are present after parsing two same-named classes, and that the resolver produces the correct call edges for each (not the other's service). Co-Authored-By: Claude Sonnet 4.6 --- .../parsers/javascript/dependency_resolver.js | 17 ++-- .../parsers/javascript/typescript_analyzer.js | 4 +- libs/openant-core/tests/test_di_resolution.py | 92 +++++++++++++++++-- 3 files changed, 98 insertions(+), 15 deletions(-) diff --git a/libs/openant-core/parsers/javascript/dependency_resolver.js b/libs/openant-core/parsers/javascript/dependency_resolver.js index 90679cc..f977a0c 100644 --- a/libs/openant-core/parsers/javascript/dependency_resolver.js +++ b/libs/openant-core/parsers/javascript/dependency_resolver.js @@ -20,7 +20,7 @@ const path = require('path'); class DependencyResolver { constructor(analyzerOutput, options = {}) { this.functions = analyzerOutput.functions || {}; - this.classes = analyzerOutput.classes || {}; // className -> { constructorDeps, baseTypes } + this.classes = analyzerOutput.classes || {}; // "filePath:className" -> { constructorDeps, baseTypes } this.callGraph = {}; // functionId -> [calledFunctionIds] this.reverseCallGraph = {}; // functionId -> [callerFunctionIds] this.maxDepth = options.maxDepth || 3; @@ -30,7 +30,7 @@ class DependencyResolver { this.functionsByName = Object.create(null); // simpleName -> [functionIds] this.functionsByFile = Object.create(null); // filePath -> [functionIds] this.imports = Object.create(null); // filePath -> { importedName -> { source, originalName } } - this.classesByBaseType = Object.create(null); // baseTypeName -> [classNames] + this.classesByBaseType = Object.create(null); // baseTypeName -> ["filePath:className", ...] this._buildIndexes(); } @@ -55,10 +55,10 @@ class DependencyResolver { this.functionsByFile[filePath].push(funcId); } - for (const [className, classData] of Object.entries(this.classes)) { + for (const [classKey, classData] of Object.entries(this.classes)) { for (const baseType of (classData.baseTypes || [])) { if (!this.classesByBaseType[baseType]) this.classesByBaseType[baseType] = []; - this.classesByBaseType[baseType].push(className); + this.classesByBaseType[baseType].push(classKey); } } } @@ -275,7 +275,8 @@ class DependencyResolver { // -> resolve to CallService.getById if (callerFuncId) { const callerFunc = this.functions[callerFuncId]; - const classEntry = callerFunc && this.classes[callerFunc.className]; + const classEntry = callerFunc && callerFunc.className && + this.classes[callerFile + ':' + callerFunc.className]; if (classEntry && classEntry.constructorDeps) { const typeName = classEntry.constructorDeps[objectName]; if (typeName) { @@ -289,10 +290,12 @@ class DependencyResolver { // 2b. Nominal type match: prefer candidates whose class implements or extends typeName. // If exactly one such candidate exists, the resolution is unambiguous. - const nominalClassNames = this.classesByBaseType[typeName] || []; + const nominalClassKeys = this.classesByBaseType[typeName] || []; const nominalMatches = candidates.filter(funcId => { const funcData = this.functions[funcId]; - return funcData && nominalClassNames.includes(funcData.className); + if (!funcData || !funcData.className) return false; + const funcClassKey = funcId.split(':')[0] + ':' + funcData.className; + return nominalClassKeys.includes(funcClassKey); }); if (nominalMatches.length === 1) return nominalMatches[0]; diff --git a/libs/openant-core/parsers/javascript/typescript_analyzer.js b/libs/openant-core/parsers/javascript/typescript_analyzer.js index 5b49bc1..bab57af 100644 --- a/libs/openant-core/parsers/javascript/typescript_analyzer.js +++ b/libs/openant-core/parsers/javascript/typescript_analyzer.js @@ -55,7 +55,7 @@ class TypeScriptAnalyzer { compilerOptions: PERMISSIVE_COMPILER_OPTIONS, }); this.functions = {}; // functionId -> function metadata - this.classes = {}; // className -> { constructorDeps, baseTypes } + this.classes = {}; // "filePath:className" -> { constructorDeps, baseTypes } this.callGraph = {}; // callerId -> array of call info } @@ -271,7 +271,7 @@ class TypeScriptAnalyzer { } if (Object.keys(classEntry).length > 0) { - this.classes[className] = classEntry; + this.classes[`${relativePath}:${className}`] = classEntry; } } diff --git a/libs/openant-core/tests/test_di_resolution.py b/libs/openant-core/tests/test_di_resolution.py index 54ad5a8..06fc8e1 100644 --- a/libs/openant-core/tests/test_di_resolution.py +++ b/libs/openant-core/tests/test_di_resolution.py @@ -163,6 +163,14 @@ def nestjs_repo_nominal(tmp_path): return tmp_path +def find_class(classes, class_name): + """Find a class entry in the file-qualified classes dict (key is "filePath:ClassName").""" + for key, val in classes.items(): + if key.endswith(':' + class_name): + return val + return None + + def analyze_and_resolve(repo_path, files): """Run analyzer + resolver on given files and return resolved data.""" analyzer_out = repo_path / "analyzer_output.json" @@ -200,8 +208,9 @@ def test_extracts_constructor_deps(self, nestjs_repo): data = json.loads(analyzer_out.read_text()) classes = data["classes"] - assert "CallResolver" in classes, "CallResolver not in classes table" - deps = classes["CallResolver"].get("constructorDeps", {}) + call_resolver = find_class(classes, "CallResolver") + assert call_resolver is not None, "CallResolver not in classes table" + deps = call_resolver.get("constructorDeps", {}) assert deps.get("callService") == "CallService" assert deps.get("authService") == "AuthService" @@ -236,7 +245,7 @@ def test_skips_primitive_types(self, tmp_path): assert result.returncode == 0 data = json.loads(analyzer_out.read_text()) - deps = data["classes"].get("Example", {}).get("constructorDeps", {}) + deps = (find_class(data["classes"], "Example") or {}).get("constructorDeps", {}) # Only MyService should be captured (PascalCase), not string/number assert "service" in deps assert deps["service"] == "MyService" @@ -244,6 +253,77 @@ def test_skips_primitive_types(self, tmp_path): assert "count" not in deps + def test_same_name_different_file_no_collision(self, tmp_path): + """Two classes with the same name in different files must not collide. + + Pre-fix: this.classes["UserController"] is last-write-wins, so the first + class's constructorDeps are silently overwritten and its DI calls miss. + Post-fix: both entries are keyed by "filePath:ClassName". + """ + (tmp_path / "admin").mkdir() + (tmp_path / "v2").mkdir() + (tmp_path / "admin" / "user_controller.ts").write_text("""\ +export class UserController { + constructor(private fooService: FooService) {} + getFoo() { return this.fooService.get(); } +} +""") + (tmp_path / "v2" / "user_controller.ts").write_text("""\ +export class UserController { + constructor(private barService: BarService) {} + getBar() { return this.barService.get(); } +} +""") + (tmp_path / "foo_service.ts").write_text("""\ +export class FooService { + get() { return 'foo'; } +} +""") + (tmp_path / "bar_service.ts").write_text("""\ +export class BarService { + get() { return 'bar'; } +} +""") + + # 1. Analyzer: both class entries present (no last-write-wins collision) + analyzer_out = tmp_path / "analyzer_output.json" + result = run_node( + "typescript_analyzer.js", str(tmp_path), + "admin/user_controller.ts", + "v2/user_controller.ts", + "--output", str(analyzer_out), + ) + assert result.returncode == 0 + data = json.loads(analyzer_out.read_text()) + classes = data["classes"] + + admin_entry = next((v for k, v in classes.items() if "admin" in k and k.endswith(":UserController")), None) + v2_entry = next((v for k, v in classes.items() if "v2" in k and k.endswith(":UserController")), None) + assert admin_entry is not None, "admin/UserController missing from classes table" + assert v2_entry is not None, "v2/UserController missing from classes table" + assert admin_entry.get("constructorDeps", {}).get("fooService") == "FooService" + assert v2_entry.get("constructorDeps", {}).get("barService") == "BarService" + + # 2. Resolver: each class resolves calls to the right service (not the other's) + data = analyze_and_resolve(tmp_path, [ + "admin/user_controller.ts", + "v2/user_controller.ts", + "foo_service.ts", + "bar_service.ts", + ]) + call_graph = data["callGraph"] + + admin_calls = next((calls for fid, calls in call_graph.items() if "admin" in fid and "UserController.getFoo" in fid), None) + v2_calls = next((calls for fid, calls in call_graph.items() if "v2" in fid and "UserController.getBar" in fid), None) + + assert admin_calls is not None, "admin/UserController.getFoo not in call graph" + assert v2_calls is not None, "v2/UserController.getBar not in call graph" + assert any("FooService.get" in c for c in admin_calls), \ + f"admin/UserController.getFoo should resolve to FooService.get, got: {admin_calls}" + assert any("BarService.get" in c for c in v2_calls), \ + f"v2/UserController.getBar should resolve to BarService.get, got: {v2_calls}" + + class TestBaseTypesExtraction: """Test that the analyzer extracts implements/extends into baseTypes.""" @@ -257,7 +337,7 @@ def test_extracts_implements(self, nestjs_repo_nominal): assert result.returncode == 0 data = json.loads(analyzer_out.read_text()) - base_types = data["classes"].get("CallServiceImpl", {}).get("baseTypes", []) + base_types = (find_class(data["classes"], "CallServiceImpl") or {}).get("baseTypes", []) assert "ICallService" in base_types def test_generic_implements_stripped(self, tmp_path): @@ -278,7 +358,7 @@ def test_generic_implements_stripped(self, tmp_path): assert result.returncode == 0 data = json.loads(analyzer_out.read_text()) - base_types = data["classes"].get("UserRepo", {}).get("baseTypes", []) + base_types = (find_class(data["classes"], "UserRepo") or {}).get("baseTypes", []) assert "Repository" in base_types assert not any("<" in t for t in base_types) @@ -299,7 +379,7 @@ def test_extracts_extends(self, tmp_path): assert result.returncode == 0 data = json.loads(analyzer_out.read_text()) - base_types = data["classes"].get("ConcreteService", {}).get("baseTypes", []) + base_types = (find_class(data["classes"], "ConcreteService") or {}).get("baseTypes", []) assert "BaseService" in base_types From 6103e39d039ec95a6037f4b768741828e1366f50 Mon Sep 17 00:00:00 2001 From: joshbouncesecurity Date: Tue, 12 May 2026 18:06:28 +0300 Subject: [PATCH 7/7] feat(di-resolution): add field/property injection support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extends DI extraction to cover three additional patterns: - @Inject / @InjectRepository / any @Inject* decorator — type comes from the TypeScript property type annotation (generics stripped as before) - inject(SvcType) functional injection (Angular's inject() API) — type taken from the first call argument Extracted into classEntry.fieldDeps ("filePath:className" -> {propName: typeName}), kept separate from constructorDeps so callers can distinguish the injection style. The resolver now checks both constructorDeps and fieldDeps when resolving this.prop.method() calls. The guard is relaxed to enter the DI path for classes that have only fieldDeps (no constructor injection). Tests: @Inject decorator, @InjectRepository decorator (generic type stripped to Repository), inject() functional form, non-@Inject decorators ignored, full resolver round-trip for field-injected calls. Co-Authored-By: Claude Sonnet 4.6 --- .../parsers/javascript/dependency_resolver.js | 7 +- .../parsers/javascript/typescript_analyzer.js | 39 +++++- libs/openant-core/tests/test_di_resolution.py | 132 ++++++++++++++++++ 3 files changed, 174 insertions(+), 4 deletions(-) diff --git a/libs/openant-core/parsers/javascript/dependency_resolver.js b/libs/openant-core/parsers/javascript/dependency_resolver.js index f977a0c..c7c697e 100644 --- a/libs/openant-core/parsers/javascript/dependency_resolver.js +++ b/libs/openant-core/parsers/javascript/dependency_resolver.js @@ -20,7 +20,7 @@ const path = require('path'); class DependencyResolver { constructor(analyzerOutput, options = {}) { this.functions = analyzerOutput.functions || {}; - this.classes = analyzerOutput.classes || {}; // "filePath:className" -> { constructorDeps, baseTypes } + this.classes = analyzerOutput.classes || {}; // "filePath:className" -> { constructorDeps, fieldDeps, baseTypes } this.callGraph = {}; // functionId -> [calledFunctionIds] this.reverseCallGraph = {}; // functionId -> [callerFunctionIds] this.maxDepth = options.maxDepth || 3; @@ -277,8 +277,9 @@ class DependencyResolver { const callerFunc = this.functions[callerFuncId]; const classEntry = callerFunc && callerFunc.className && this.classes[callerFile + ':' + callerFunc.className]; - if (classEntry && classEntry.constructorDeps) { - const typeName = classEntry.constructorDeps[objectName]; + if (classEntry && (classEntry.constructorDeps || classEntry.fieldDeps)) { + const typeName = (classEntry.constructorDeps || {})[objectName] + ?? (classEntry.fieldDeps || {})[objectName]; if (typeName) { // 2a. Exact type match for (const funcId of candidates) { diff --git a/libs/openant-core/parsers/javascript/typescript_analyzer.js b/libs/openant-core/parsers/javascript/typescript_analyzer.js index bab57af..66a09e3 100644 --- a/libs/openant-core/parsers/javascript/typescript_analyzer.js +++ b/libs/openant-core/parsers/javascript/typescript_analyzer.js @@ -55,7 +55,7 @@ class TypeScriptAnalyzer { compilerOptions: PERMISSIVE_COMPILER_OPTIONS, }); this.functions = {}; // functionId -> function metadata - this.classes = {}; // "filePath:className" -> { constructorDeps, baseTypes } + this.classes = {}; // "filePath:className" -> { constructorDeps, fieldDeps, baseTypes } this.callGraph = {}; // callerId -> array of call info } @@ -270,6 +270,43 @@ class TypeScriptAnalyzer { if (Object.keys(injections).length > 0) classEntry.constructorDeps = injections; } + // Extract field/property injection metadata. + // Covers decorator-based (@Inject, @InjectRepository, etc.) and Angular's inject() function. + const fieldDeps = {}; + for (const prop of classDecl.getProperties()) { + const propName = prop.getName(); + let typeName = null; + + // Decorator-based: any @Inject* decorator signals an injection point; + // the injected type comes from the TypeScript type annotation. + const hasInjectDecorator = prop.getDecorators().some(d => /^Inject/.test(d.getName())); + if (hasInjectDecorator) { + const typeNode = prop.getTypeNode(); + if (typeNode) { + const t = typeNode.getText().replace(/<.*$/, ''); + if (/^[A-Z][a-zA-Z0-9_$]*$/.test(t)) typeName = t; + } + } + + // Functional: private svc = inject(SvcType) (Angular inject() API) + if (!typeName) { + const init = prop.getInitializer(); + if (init && init.getKindName() === 'CallExpression') { + const expr = init.getExpression(); + if (expr && expr.getText() === 'inject') { + const args = init.getArguments(); + if (args.length > 0) { + const t = args[0].getText().replace(/<.*$/, ''); + if (/^[A-Z][a-zA-Z0-9_$]*$/.test(t)) typeName = t; + } + } + } + } + + if (typeName) fieldDeps[propName] = typeName; + } + if (Object.keys(fieldDeps).length > 0) classEntry.fieldDeps = fieldDeps; + if (Object.keys(classEntry).length > 0) { this.classes[`${relativePath}:${className}`] = classEntry; } diff --git a/libs/openant-core/tests/test_di_resolution.py b/libs/openant-core/tests/test_di_resolution.py index 06fc8e1..70fa4ea 100644 --- a/libs/openant-core/tests/test_di_resolution.py +++ b/libs/openant-core/tests/test_di_resolution.py @@ -444,6 +444,138 @@ def test_nominal_ambiguity_skips_resolution(self, tmp_path): ), f"Should not resolve ambiguous implements, got: {resolver_calls}" +class TestFieldDepsExtraction: + """Test that @Inject* decorators and inject() function are captured as fieldDeps.""" + + def test_extracts_inject_decorator(self, tmp_path): + src = tmp_path / "src" + src.mkdir() + (src / "service.ts").write_text("""\ +import { Injectable, Inject } from '@nestjs/common'; + +@Injectable() +export class MyService { + @Inject('TOKEN') + private depService: DepService; + + run() { return this.depService.execute(); } +} +""") + analyzer_out = tmp_path / "analyzer_output.json" + result = run_node( + "typescript_analyzer.js", str(tmp_path), + "src/service.ts", + "--output", str(analyzer_out), + ) + assert result.returncode == 0 + data = json.loads(analyzer_out.read_text()) + field_deps = (find_class(data["classes"], "MyService") or {}).get("fieldDeps", {}) + assert field_deps.get("depService") == "DepService" + + def test_extracts_inject_repository_decorator(self, tmp_path): + src = tmp_path / "src" + src.mkdir() + (src / "service.ts").write_text("""\ +import { Injectable } from '@nestjs/common'; +import { InjectRepository } from '@nestjs/typeorm'; + +@Injectable() +export class UserService { + @InjectRepository(User) + private userRepo: Repository; + + findAll() { return this.userRepo.find(); } +} +""") + analyzer_out = tmp_path / "analyzer_output.json" + result = run_node( + "typescript_analyzer.js", str(tmp_path), + "src/service.ts", + "--output", str(analyzer_out), + ) + assert result.returncode == 0 + data = json.loads(analyzer_out.read_text()) + field_deps = (find_class(data["classes"], "UserService") or {}).get("fieldDeps", {}) + assert field_deps.get("userRepo") == "Repository" + + def test_extracts_functional_inject(self, tmp_path): + src = tmp_path / "src" + src.mkdir() + (src / "component.ts").write_text("""\ +import { inject } from '@angular/core'; + +export class MyComponent { + private authService = inject(AuthService); + + login() { return this.authService.signIn(); } +} +""") + analyzer_out = tmp_path / "analyzer_output.json" + result = run_node( + "typescript_analyzer.js", str(tmp_path), + "src/component.ts", + "--output", str(analyzer_out), + ) + assert result.returncode == 0 + data = json.loads(analyzer_out.read_text()) + field_deps = (find_class(data["classes"], "MyComponent") or {}).get("fieldDeps", {}) + assert field_deps.get("authService") == "AuthService" + + def test_ignores_non_inject_decorator(self, tmp_path): + src = tmp_path / "src" + src.mkdir() + (src / "service.ts").write_text("""\ +export class MyService { + @Column() + private name: string; + + getName() { return this.name; } +} +""") + analyzer_out = tmp_path / "analyzer_output.json" + result = run_node( + "typescript_analyzer.js", str(tmp_path), + "src/service.ts", + "--output", str(analyzer_out), + ) + assert result.returncode == 0 + data = json.loads(analyzer_out.read_text()) + field_deps = (find_class(data["classes"], "MyService") or {}).get("fieldDeps", {}) + assert "name" not in field_deps + + def test_resolves_field_injection_calls(self, tmp_path): + """Calls via @Inject field deps resolve correctly through the full pipeline.""" + src = tmp_path / "src" + src.mkdir() + (src / "service.ts").write_text("""\ +import { Injectable, Inject } from '@nestjs/common'; + +@Injectable() +export class MyService { + @Inject('TOKEN') + private depService: DepService; + + run() { return this.depService.execute(); } +} +""") + (src / "dep_service.ts").write_text("""\ +export class DepService { + execute() { return 'done'; } +} +""") + data = analyze_and_resolve(tmp_path, [ + "src/service.ts", + "src/dep_service.ts", + ]) + call_graph = data["callGraph"] + service_calls = next( + (calls for fid, calls in call_graph.items() if "MyService.run" in fid), None + ) + assert service_calls is not None, "MyService.run not in call graph" + assert any("DepService.execute" in c for c in service_calls), \ + f"Expected DepService.execute via field injection, got: {service_calls}" + + class TestDIAwareCallResolution: """Test that the dependency resolver uses constructorDeps for DI resolution."""