diff --git a/application/cmd/cre_main.py b/application/cmd/cre_main.py index ead5a428..4c321ea5 100644 --- a/application/cmd/cre_main.py +++ b/application/cmd/cre_main.py @@ -61,6 +61,7 @@ def register_node(node: defs.Node, collection: db.Node_collection) -> db.Node: defs.Standard.__name__, defs.Code.__name__, defs.Tool.__name__, + defs.Attack.__name__, ]: # if a node links another node it is likely that a writer wants to reference something # in that case, find which of the two nodes has at least one CRE attached to it and link both to the parent CRE @@ -183,6 +184,7 @@ def parse_file( defs.Credoctypes.Standard.value, defs.Credoctypes.Code.value, defs.Credoctypes.Tool.value, + defs.Credoctypes.Attack.value, ): # document = defs.Standard(**contents) doctype = contents.get("doctype") @@ -192,7 +194,11 @@ def parse_file( else ( defs.Code if doctype == defs.Credoctypes.Code.value - else defs.Tool if doctype == defs.Credoctypes.Tool.value else None + else ( + defs.Tool + if doctype == defs.Credoctypes.Tool.value + else defs.Attack + ) ) ) document = from_dict( diff --git a/application/database/db.py b/application/database/db.py index 35cc444a..76405774 100644 --- a/application/database/db.py +++ b/application/database/db.py @@ -1494,7 +1494,11 @@ def add_node( logger.info( f"knew of node {entry.name}:{entry.section_id}:{entry.section}:{entry.link} ,updating" ) - if node.section and node.section != entry.section: + if ( + hasattr(node, "section") + and node.section + and node.section != entry.section + ): entry.section = node.section entry.link = node.hyperlink self.session.commit() @@ -1653,7 +1657,7 @@ def add_link( ) if entry: logger.debug( - f"knew of link {node.name}:{node.section}" + f"knew of link {node.name}:{getattr(node, 'section', 'None')}" f"=={cre.name} of type {entry.type}," f"updating type to {ltype.value}" ) @@ -1663,7 +1667,7 @@ def add_link( else: logger.debug( f"did not know of link {node.id})" - f"{node.name}:{node.section}=={cre.id}){cre.name}" + f"{node.name}:{getattr(node, 'section', 'None')}=={cre.id}){cre.name}" " ,adding" ) self.session.add(Links(type=ltype.value, cre=cre.id, node=node.id)) @@ -1950,6 +1954,8 @@ def dbNodeFromNode(doc: cre_defs.Node) -> Optional[Node]: return dbNodeFromCode(doc) elif doc.doctype == cre_defs.Credoctypes.Tool: return dbNodeFromTool(doc) + elif doc.doctype == cre_defs.Credoctypes.Attack: + return dbNodeFromAttack(doc) else: return None @@ -2041,6 +2047,13 @@ def nodeFromDB(dbnode: Node) -> cre_defs.Node: tags=tags, description=dbnode.description, ) + elif dbnode.ntype == cre_defs.Attack.__name__: + return cre_defs.Attack( + name=dbnode.name, + hyperlink=dbnode.link, + tags=tags, + description=dbnode.description, + ) else: raise ValueError( f"Db node {dbnode.name} has an unrecognised ntype {dbnode.ntype}" @@ -2140,3 +2153,17 @@ def gap_analysis( ) logger.info(f"stored gapa analysis for {'>>>'.join(node_names)}, successfully") return (node_names, grouped_paths, extra_paths_dict) + + +def dbNodeFromAttack(attack: cre_defs.Node) -> Node: + attack = cast(cre_defs.Attack, attack) + tags = "" + if attack.tags: + tags = ",".join(attack.tags) + return Node( + name=attack.name, + ntype=attack.doctype.value, + tags=tags, + description=attack.description, + link=attack.hyperlink, + ) diff --git a/application/defs/cre_defs.py b/application/defs/cre_defs.py index 5edf7121..10e78f96 100644 --- a/application/defs/cre_defs.py +++ b/application/defs/cre_defs.py @@ -179,6 +179,7 @@ class Credoctypes(str, Enum, metaclass=EnumMetaWithContains): Standard = "Standard" Tool = "Tool" Code = "Code" + Attack = "Attack" @staticmethod def from_str(typ: str) -> "Credoctypes": @@ -526,3 +527,8 @@ def __hash__(self) -> int: @dataclass(eq=False) class Code(Node): doctype: Credoctypes = Credoctypes.Code + + +@dataclass(eq=False) +class Attack(Node): + doctype: Credoctypes = Credoctypes.Attack diff --git a/application/manual_seed_attacks.py b/application/manual_seed_attacks.py new file mode 100644 index 00000000..55ad6c17 --- /dev/null +++ b/application/manual_seed_attacks.py @@ -0,0 +1,151 @@ +import os +import sys +import logging + +# Ensure application matches the import path +sys.path.append(os.getcwd()) + +from application.database import db +from application.defs import cre_defs as defs +from application.cmd.cre_main import db_connect +from application.config import CMDConfig +from application import create_app + +# Import our new utility +from application.utils.attack_mapper import link_attack_to_cre_by_cwe + +logging.basicConfig() +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +def seed_cwe_structure(collection): + """Mocks the existing CWE->CRE structure if it doesn't exist.""" + print("\n--- Seeding Mock CWE Structure ---") + + # 1. Create a CRE "Input Validation" (Mock) + cre = defs.CRE( + name="Input Validation (Mock)", + id="999-999", + description="Mitigation for input attacks", + ) + db_cre = collection.add_cre(cre) + cre.id = db_cre.id # Use DB PK (UUID) for linking + print(f" Added/Found CRE: {cre.name}") + + # 2. Create CWE-22 (Standard) + cwe = defs.Standard( + name="CWE-22", + section="Path Traversal", + hyperlink="https://cwe.mitre.org/data/definitions/22.html", + ) + db_cwe = collection.add_node(cwe) + cwe.id = db_cwe.id # Use DB PK (UUID) for linking + print(f" Added/Found CWE: {cwe.name}") + + # 3. Link CWE -> CRE + # add_link(cre, node) + collection.add_link(cre=cre, node=cwe, ltype=defs.LinkTypes.Related) + print(f" Linked CWE-22 -> Input Validation") + + +def seed_attacks(): + # Database path + db_path = os.path.abspath("standards_cache.sqlite") + print(f"Connecting to DB at {db_path}...") + + # Setup context + conf = CMDConfig(db_uri=db_path) + app = create_app(conf=conf) + app_context = app.app_context() + app_context.push() + + collection = db.Node_collection() + + # Step 0: Ensure CWE infrastructure exists + seed_cwe_structure(collection) + + print("\n--- Seeding Attacks ---") + # Define Attacks with Descriptions containing CWEs + attacks = [ + defs.Attack( + name="Path Traversal", + hyperlink="https://owasp.org/www-community/attacks/Path_Traversal", + tags=["OWASP", "Attack"], + description="The Path Traversal attack technique allows... Related CWEs: CWE-22.", + ), + defs.Attack( + name="SQL Injection", + hyperlink="https://owasp.org/www-community/attacks/SQL_Injection", + tags=["OWASP", "Attack"], + description="SQL Injection attacks... Related CWEs: CWE-89.", + # Note: CWE-89 is not mocked above, so this should NOT link, testing negative case/robustness + ), + ] + + # Register Attacks + for attack in attacks: + # Idempotency Check + existing = collection.get_nodes(name=attack.name, ntype=defs.Attack.__name__) + if existing: + # Update description to ensure we test parsing + if existing[0].description != attack.description: + print(f" Updating description for {attack.name}") + # We can use add_node to update + collection.add_node(attack) + else: + print(f" Skipping existing: {attack.name}") + else: + db_node = collection.add_node(attack) + print(f" Added: {attack.name}") + + # Ensure attack.id is the DB UUID for linking + db_node_obj = ( + collection.session.query(db.Node) + .filter(db.Node.name == attack.name) + .first() + ) + if db_node_obj: + attack.id = db_node_obj.id + + # --- PHASE 2: AUTO LINKING --- + print(f" Running Auto-Linking for {attack.name}...") + linked_cres = link_attack_to_cre_by_cwe(attack, collection) + if linked_cres: + print(f" ✅ Linked to: {linked_cres}") + else: + print(f" (No links created)") + + # Verification + print("\n--- Final Verification ---") + # Verify Path Traversal is linked to "Input Validation (Mock)" + # We can check by fetching the CRE and listing its links + + cre_nodes = collection.get_CREs(name="Input Validation (Mock)") + if not cre_nodes: + print("❌ Mock CRE not found during verification!") + sys.exit(1) + + mock_cre = cre_nodes[0] + # Verify it links to "Path Traversal" + found_link = False + for link in mock_cre.links: + # link.document is the linked node + # We need to check if it's our attack + if ( + link.document.name == "Path Traversal" + and link.document.doctype == defs.Credoctypes.Attack + ): + found_link = True + print( + f"✅ Verified Link: CRE '{mock_cre.name}' <--> Attack '{link.document.name}'" + ) + break + + if not found_link: + print("❌ Verification Failed: Link between CRE and Attack NOT found.") + sys.exit(1) + + +if __name__ == "__main__": + seed_attacks() diff --git a/application/utils/attack_mapper.py b/application/utils/attack_mapper.py new file mode 100644 index 00000000..655cccce --- /dev/null +++ b/application/utils/attack_mapper.py @@ -0,0 +1,82 @@ +import re +import logging +from typing import List, Optional +from application.database import db +from application.defs import cre_defs as defs + +logger = logging.getLogger(__name__) + + +def extract_cwe_ids(text: str) -> List[str]: + """Finds all occurrences of 'CWE-' in the text.""" + if not text: + return [] + # Match CWE-123, CWE: 123, etc. strictly CWE-\d+ for now as per plan + matches = re.findall(r"CWE-(\d+)", text, re.IGNORECASE) + return [f"CWE-{m}" for m in matches] + + +def link_attack_to_cre_by_cwe( + attack: defs.Attack, collection: db.Node_collection +) -> List[str]: + """ + Links the Attack node to CREs that are already linked to the CWEs mentioned in the Attack's description. + Returns a list of linked CRE names/IDs. + """ + linked_cres = [] + cwe_ids = extract_cwe_ids(attack.description) + + if not cwe_ids: + return [] + + for cwe_name in set(cwe_ids): + # 1. Find the CWE Node + # CWEs are Standards. + cwe_nodes = collection.get_nodes(name=cwe_name) + if not cwe_nodes: + continue + + cwe_node = cwe_nodes[0] + + # 2. Find CREs linked to this CWE + # This requires querying the Links table. + # Node_collection doesn't expose get_links_for_node directly? + # We can access the session or use get_CREs with include_only? + # get_CREs doesn't filter by "linked to node X". + # We will iterate manually or use db.session if available. + # collection.session is likely the db.session + + links = ( + collection.session.query(db.Links) + .filter(db.Links.node == cwe_node.id) + .all() + ) + + for link in links: + cre_id = link.cre + + # Use get_CREs by internal_id first + cre_list = collection.get_CREs(internal_id=cre_id) + + if not cre_list: + # Fallback: Maybe the link stores external_id? + cre_list = collection.get_CREs(external_id=cre_id) + + if cre_list: + cre = cre_list[0] + # HACK: Retrieve the underlying DB UUID to ensure Links table uses PK + # get_CREs returns defs.CRE where .id is usually .external_id + # But Links table FK points to .id (UUID) + db_cre = ( + collection.session.query(db.CRE) + .filter(db.CRE.external_id == cre.id) + .first() + ) + if db_cre: + if db_cre: + cre.id = db_cre.id + + collection.add_link(cre=cre, node=attack, ltype=defs.LinkTypes.Related) + linked_cres.append(f"{cre.name} (via {cwe_name})") + + return linked_cres