diff --git a/flowsint-api/app/api/routes/enrichers.py b/flowsint-api/app/api/routes/enrichers.py index 987bb4a..f96b461 100644 --- a/flowsint-api/app/api/routes/enrichers.py +++ b/flowsint-api/app/api/routes/enrichers.py @@ -1,24 +1,25 @@ -from fastapi import APIRouter, HTTPException, Depends, Query -from typing import List, Any, Optional -from pydantic import BaseModel -from flowsint_enrichers import ENRICHER_REGISTRY, load_all_enrichers +from typing import Any, List, Optional + +from fastapi import APIRouter, Depends, HTTPException, Query from flowsint_core.core.celery import celery -from flowsint_core.core.types import Node, Edge, FlowBranch +from flowsint_core.core.graph import create_graph_service, GraphEdge, GraphNode from flowsint_core.core.models import CustomType, Profile -from flowsint_core.core.graph_repository import GraphRepository -from flowsint_types import clean_neo4j_node_data -from app.api.deps import get_current_user from flowsint_core.core.postgre_db import get_db -from sqlalchemy.orm import Session +from flowsint_core.core.types import FlowBranch +from flowsint_enrichers import ENRICHER_REGISTRY, load_all_enrichers +from pydantic import BaseModel from sqlalchemy import func +from sqlalchemy.orm import Session + +from app.api.deps import get_current_user # Auto-discover and register all enrichers load_all_enrichers() class FlowComputationRequest(BaseModel): - nodes: List[Node] - edges: List[Edge] + nodes: List[GraphNode] + edges: List[GraphEdge] inputType: Optional[str] = None @@ -74,21 +75,20 @@ async def launch_enricher( ): try: # Retrieve nodes from Neo4J by their element IDs - graph_repo = GraphRepository() - nodes_data = graph_repo.get_nodes_by_ids(payload.node_ids, payload.sketch_id) - - if not nodes_data: - raise HTTPException(status_code=404, detail="No nodes found with provided IDs") - - # Clean Neo4J-specific fields from node data - # The enricher's preprocess() will handle Pydantic validation - cleaned_nodes = [clean_neo4j_node_data(node_data) for node_data in nodes_data] + graph_service = create_graph_service(sketch_id=payload.sketch_id) + entities = graph_service.get_nodes_by_ids_for_task(payload.node_ids) + # send deserialized nodes + entities = [entity.model_dump(mode="json", serialize_as_any=True) for entity in entities] + if not entities: + raise HTTPException( + status_code=404, detail="No entities found with provided IDs" + ) task = celery.send_task( "run_enricher", args=[ enricher_name, - cleaned_nodes, + entities, payload.sketch_id, str(current_user.id), ], @@ -98,4 +98,7 @@ async def launch_enricher( except HTTPException: raise except Exception as e: - raise HTTPException(status_code=500, detail=f"Error launching enricher: {str(e)}") + print(e) + raise HTTPException( + status_code=500, detail=f"Error launching enricher: {str(e)}" + ) diff --git a/flowsint-api/app/api/routes/flows.py b/flowsint-api/app/api/routes/flows.py index ebed538..f2d5c04 100644 --- a/flowsint-api/app/api/routes/flows.py +++ b/flowsint-api/app/api/routes/flows.py @@ -1,50 +1,49 @@ -from uuid import UUID, uuid4 -from fastapi import APIRouter, HTTPException, Depends, status, Query -from typing import Dict, List, Any, Optional -from pydantic import BaseModel from datetime import datetime -from flowsint_core.utils import extract_input_schema_flow -from flowsint_enrichers import ENRICHER_REGISTRY, load_all_enrichers +from typing import Any, Dict, List, Optional +from uuid import UUID, uuid4 +from fastapi import APIRouter, Depends, HTTPException, Query, status # Auto-discover and register all enrichers -load_all_enrichers() from flowsint_core.core.celery import celery -from flowsint_core.core.graph_repository import GraphRepository -from flowsint_types import ( - Domain, - Phrase, - Ip, - SocialAccount, - Organization, - Email, - Phone, - Username, - clean_neo4j_node_data, -) -from flowsint_core.core.types import Node, Edge, FlowStep, FlowBranch -from sqlalchemy.orm import Session +from flowsint_core.core.graph import create_graph_service +from flowsint_core.core.models import CustomType, Flow, Profile, Sketch from flowsint_core.core.postgre_db import get_db -from flowsint_core.core.models import Flow, Profile, CustomType, Sketch -from app.api.deps import get_current_user -from sqlalchemy import func -from app.api.schemas.flow import FlowRead, FlowCreate, FlowUpdate -from app.security.permissions import check_investigation_permission +from flowsint_core.core.types import FlowBranch, FlowEdge, FlowNode, FlowStep +from flowsint_core.utils import extract_input_schema_flow +from flowsint_enrichers import ENRICHER_REGISTRY, load_all_enrichers from flowsint_types import ( ASN, CIDR, + CryptoNFT, CryptoWallet, CryptoWalletTransaction, - CryptoNFT, - Website, + DNSRecord, + Domain, + Email, Individual, + Ip, + Organization, + Phone, + Phrase, Port, - DNSRecord + SocialAccount, + Username, + Website, ) +from pydantic import BaseModel +from sqlalchemy import func +from sqlalchemy.orm import Session + +from app.api.deps import get_current_user +from app.api.schemas.flow import FlowCreate, FlowRead, FlowUpdate +from app.security.permissions import check_investigation_permission + +load_all_enrichers() class FlowComputationRequest(BaseModel): - nodes: List[Node] - edges: List[Edge] + nodes: List[FlowNode] + edges: List[FlowEdge] inputType: Optional[str] = None @@ -169,7 +168,6 @@ def create_flow( db: Session = Depends(get_db), current_user: Profile = Depends(get_current_user), ): - new_flow = Flow( id=uuid4(), name=payload.name, @@ -260,24 +258,22 @@ async def launch_flow( current_user.id, sketch.investigation_id, actions=["update"], db=db ) - # Retrieve nodes from Neo4J by their element IDs - graph_repo = GraphRepository() - nodes_data = graph_repo.get_nodes_by_ids(payload.node_ids, payload.sketch_id) - - if not nodes_data: - raise HTTPException(status_code=404, detail="No nodes found with provided IDs") - - # Clean Neo4J-specific fields from node data - # The enricher's preprocess() will handle Pydantic validation - cleaned_nodes = [clean_neo4j_node_data(node_data) for node_data in nodes_data] + # Retrieve entities from Neo4J by their element IDs + graph_service = create_graph_service(sketch_id=payload.sketch_id) + entities = graph_service.get_nodes_by_ids_for_task(payload.node_ids) # Compute flow branches - nodes = [Node(**node) for node in flow.flow_schema["nodes"]] - edges = [Edge(**edge) for edge in flow.flow_schema["edges"]] + nodes = [FlowNode(**node) for node in flow.flow_schema["nodes"]] + edges = [FlowEdge(**edge) for edge in flow.flow_schema["edges"]] + + entities = [entity.model_dump(mode="json", serialize_as_any=True) for entity in entities] + # For flow computation, we still need a sample value # Use the label from the first node data - sample_value = nodes_data[0].get('label', 'sample_value') if nodes_data else 'sample_value' + sample_value = ( + entities[0].get("nodeLabel", "sample_value") if len(entities) else "sample_value" + ) flow_branches = compute_flow_branches(sample_value, nodes, edges) serializable_branches = [branch.model_dump() for branch in flow_branches] @@ -285,7 +281,7 @@ async def launch_flow( "run_flow", args=[ serializable_branches, - cleaned_nodes, + entities, payload.sketch_id, str(current_user.id), ], @@ -295,6 +291,7 @@ async def launch_flow( except HTTPException: raise except Exception as e: + print(e) raise HTTPException(status_code=500, detail=f"Error launching flow: {str(e)}") @@ -332,7 +329,7 @@ def generate_sample_data(type_str: str) -> Any: def compute_flow_branches( - initial_value: Any, nodes: List[Node], edges: List[Edge] + initial_value: Any, nodes: List[FlowNode], edges: List[FlowEdge] ) -> List[FlowBranch]: """Computes flow branches based on nodes and edges with proper DFS traversal""" # Find input nodes (starting points) @@ -384,7 +381,7 @@ def compute_flow_branches( return 1 + min_length - def get_outgoing_edges(node_id: str) -> List[Edge]: + def get_outgoing_edges(node_id: str) -> List[FlowEdge]: """Get outgoing edges sorted by the shortest possible path length""" out_edges = [edge for edge in edges if edge.source == node_id] # Sort edges by the length of the shortest possible path from their target @@ -549,7 +546,7 @@ def compute_flow_branches( return branches -def process_node_data(node: Node, inputs: Dict[str, Any]) -> Dict[str, Any]: +def process_node_data(node: FlowNode, inputs: Dict[str, Any]) -> Dict[str, Any]: """Traite les données de nœud en fonction du type de nœud et des entrées""" outputs = {} output_types = node.data["outputs"].get("properties", []) diff --git a/flowsint-api/app/api/routes/investigations.py b/flowsint-api/app/api/routes/investigations.py index 043b6e0..a01a95e 100644 --- a/flowsint-api/app/api/routes/investigations.py +++ b/flowsint-api/app/api/routes/investigations.py @@ -21,8 +21,7 @@ from app.api.schemas.investigation import ( InvestigationUpdate, ) from app.api.schemas.sketch import SketchRead -from flowsint_core.core.graph_db import neo4j_connection -from flowsint_core.core.graph_repository import GraphRepository +from flowsint_core.core.graph import create_graph_service router = APIRouter() @@ -203,11 +202,14 @@ def delete_investigation( db.query(Analysis).filter(Sketch.investigation_id == investigation_id).all() ) - # Delete all nodes and relationships for each sketch in Neo4j using GraphRepository - graph_repo = GraphRepository(neo4j_connection) + # Delete all nodes and relationships for each sketch in Neo4j using GraphService for sketch in sketches: try: - graph_repo.delete_all_sketch_nodes(str(sketch.id)) + graph_service = create_graph_service( + sketch_id=str(sketch.id), + enable_batching=False, + ) + graph_service.delete_all_sketch_nodes() except Exception as e: print(f"Neo4j cleanup error for sketch {sketch.id}: {e}") raise HTTPException(status_code=500, detail="Failed to clean up graph data") diff --git a/flowsint-api/app/api/routes/sketches.py b/flowsint-api/app/api/routes/sketches.py index fc3aa53..a987304 100644 --- a/flowsint-api/app/api/routes/sketches.py +++ b/flowsint-api/app/api/routes/sketches.py @@ -11,13 +11,16 @@ from fastapi import ( UploadFile, status, ) -from flowsint_core.core.graph_db import neo4j_connection -from flowsint_core.core.graph_repository import GraphRepository +from flowsint_core.core.graph import create_graph_service, GraphNode from flowsint_core.core.models import Profile, Sketch from flowsint_core.core.postgre_db import get_db -from flowsint_core.imports import FileParseResult, parse_import_file +from flowsint_core.imports import ( + EntityMapping, + ImportService, + create_import_service, + FileParseResult, +) from flowsint_core.utils import flatten -from flowsint_types import TYPE_REGISTRY from pydantic import BaseModel, Field from sqlalchemy.orm import Session @@ -37,16 +40,6 @@ class NodeData(BaseModel): extra = "allow" -class NodeInput(BaseModel): - type: str = Field(..., description="Type of the node") - data: NodeData = Field( - default_factory=NodeData, description="Additional data for the node" - ) - - class Config: - extra = "allow" # Accept any additional fields - - class NodeDeleteInput(BaseModel): nodeIds: List[str] @@ -55,11 +48,10 @@ class RelationshipDeleteInput(BaseModel): relationshipIds: List[str] + class NodeEditInput(BaseModel): nodeId: str - data: NodeData = Field( - default_factory=NodeData, description="Updated data for the node" - ) + updates: Dict[str, Any] class RelationshipEditInput(BaseModel): @@ -152,10 +144,12 @@ def delete_sketch( current_user.id, sketch.investigation_id, actions=["delete"], db=db ) - # Delete all nodes and relationships in Neo4j first using GraphRepository + # Delete all nodes and relationships in Neo4j first using GraphService try: - graph_repo = GraphRepository(neo4j_connection) - graph_repo.delete_all_sketch_nodes(str(id)) + graph_service = create_graph_service( + sketch_id=str(id), enable_batching=False + ) + graph_service.delete_all_sketch_nodes() except Exception as e: print(f"Neo4j cleanup error: {e}") raise HTTPException(status_code=500, detail="Failed to clean up graph data") @@ -165,15 +159,15 @@ def delete_sketch( db.commit() -@router.get("/{id}/graph") +@router.get("/{sketch_id}/graph") async def get_sketch_nodes( - id: str, + sketch_id: str, format: str | None = None, db: Session = Depends(get_db), current_user: Profile = Depends(get_current_user), ): """ - Get the nodes and relationships for a sketch. + Get the nodes and edges for a sketch. Args: id: The ID of the sketch format: Optional format parameter. If "inline", returns inline relationships @@ -185,79 +179,32 @@ async def get_sketch_nodes( rls: [] Or if format=inline: List of inline relationship strings """ - sketch = db.query(Sketch).filter(Sketch.id == id).first() + sketch = db.query(Sketch).filter(Sketch.id == sketch_id).first() if not sketch: raise HTTPException(status_code=404, detail="Graph not found") check_investigation_permission( current_user.id, sketch.investigation_id, actions=["read"], db=db ) - # Get all nodes and relationships using GraphRepository - graph_repo = GraphRepository(neo4j_connection) - graph_data = graph_repo.get_sketch_graph(id, limit=100000) - - nodes_result = graph_data["nodes"] - rels_result = graph_data["relationships"] - - nodes = [ - { - "id": str(record["id"]), - "data": record["data"], - "label": record["data"].get("label", "Node"), - "idx": idx, - # Extract x and y positions if they exist - **({"x": record["data"]["x"]} if "x" in record["data"] else {}), - **({"y": record["data"]["y"]} if "y" in record["data"] else {}), - } - for idx, record in enumerate(nodes_result) - ] - - rels = [ - { - "id": str(record["id"]), - "source": str(record["source"]), - "target": str(record["target"]), - "data": record["data"], - "label": record["type"], - } - for record in rels_result - ] - + # Get all nodes and relationships using GraphService + graph_service = create_graph_service( + sketch_id=sketch_id, enable_batching=False + ) + graph_data = graph_service.get_sketch_graph() if format == "inline": from flowsint_core.utils import get_inline_relationships - return get_inline_relationships(nodes, rels) + return get_inline_relationships(graph_data.nodes, graph_data.edges) - return {"nds": nodes, "rls": rels} + graph = graph_data.model_dump(mode="json", serialize_as_any=True) - -def clean_empty_values(data: dict) -> dict: - """Remove empty string values from dict to avoid Pydantic validation errors.""" - cleaned = {} - for key, value in data.items(): - if value == "" or value is None: - continue - if isinstance(value, dict): - cleaned_nested = clean_empty_values(value) - if cleaned_nested: - cleaned[key] = cleaned_nested - elif isinstance(value, list): - cleaned_list = [ - clean_empty_values(item) if isinstance(item, dict) else item - for item in value - if item != "" and item is not None - ] - if cleaned_list: - cleaned[key] = cleaned_list - else: - cleaned[key] = value - return cleaned + return {"nds": graph["nodes"], "rls": graph["edges"]} @router.post("/{sketch_id}/nodes/add") @update_sketch_timestamp def add_node( sketch_id: str, - node: NodeInput, + node: GraphNode, background_tasks: BackgroundTasks, db: Session = Depends(get_db), current_user: Profile = Depends(get_current_user), @@ -268,53 +215,24 @@ def add_node( check_investigation_permission( current_user.id, sketch.investigation_id, actions=["update"], db=db ) - - node_data = node.data.model_dump() - node_type = node_data["type"] - - DetectedType = TYPE_REGISTRY.get_lowercase(node_type) - if not DetectedType: - raise HTTPException(status_code=400, detail=f"Unknown type: {node_type}") - - cleaned_data = clean_empty_values(node_data) - try: - pydantic_obj = DetectedType(**cleaned_data) - except Exception as e: - print(f"Pydantic validation error: {e}") - raise HTTPException( - status_code=400, detail=f"Invalid data for type {node_type}: {str(e)}" + graph_service = create_graph_service( + sketch_id=sketch_id, + enable_batching=False, ) - - try: - graph_repo = GraphRepository(neo4j_connection) - element_id = graph_repo.create_node(pydantic_obj, sketch_id=sketch_id) + node_id = graph_service.create_node(node) except Exception as e: + print(e) raise HTTPException(status_code=500, detail=f"Database error: {str(e)}") - if not element_id: + if not node_id: raise HTTPException(status_code=400, detail="Node creation failed") - obj_dict = ( - pydantic_obj.model_dump(mode="json") - if hasattr(pydantic_obj, "model_dump") - else pydantic_obj.dict() - ) - # Extract only non-dict values, skip None keys - obj_properties = { - k: (v if v is not None else "") - for k, v in obj_dict.items() - if k is not None and not isinstance(v, dict) - } - obj_properties["id"] = element_id - obj_properties["type"] = node_type + node.id = node_id return { "status": "node added", - "node": { - "id": element_id, - "data": obj_properties, - }, + "node": node, } @@ -341,14 +259,16 @@ def add_edge( current_user.id, sketch.investigation_id, actions=["update"], db=db ) - # Create relationship using GraphRepository + # Create relationship using GraphService try: - graph_repo = GraphRepository(neo4j_connection) - result = graph_repo.create_relationship_by_element_id( + graph_service = create_graph_service( + sketch_id=sketch_id, + enable_batching=False, + ) + result = graph_service.create_relationship_by_element_id( from_element_id=relation.source, to_element_id=relation.target, - rel_type=relation.label, - sketch_id=sketch_id, + rel_label=relation.label, ) except Exception as e: print(f"Edge creation error: {e}") @@ -380,33 +300,15 @@ def edit_node( current_user.id, sketch.investigation_id, actions=["update"], db=db ) - node_data = node_edit.data.model_dump() - node_type = node_data.get("type", "Node") - - # Get the Pydantic type from registry - DetectedType = TYPE_REGISTRY.get_lowercase(node_type) - if not DetectedType: - raise HTTPException(status_code=400, detail=f"Unknown type: {node_type}") - - # Clean empty values - cleaned_data = clean_empty_values(node_data) - - # Convert to Pydantic object + updates = node_edit.updates try: - pydantic_obj = DetectedType(**cleaned_data) - except Exception as e: - print(f"Pydantic validation error: {e}") - raise HTTPException( - status_code=400, detail=f"Invalid data for type {node_type}: {str(e)}" - ) - - # Update node using GraphRepository - try: - graph_repo = GraphRepository(neo4j_connection) - updated_element_id = graph_repo.update_node( - element_id=node_edit.nodeId, - node_obj=pydantic_obj, + graph_service = create_graph_service( sketch_id=sketch_id, + enable_batching=False, + ) + updated_element_id = graph_service.update_node( + element_id=node_edit.nodeId, + updates=updates, ) except Exception as e: print(f"Node update error: {e}") @@ -415,15 +317,10 @@ def edit_node( if not updated_element_id: raise HTTPException(status_code=404, detail="Node not found or not accessible") - # Return updated node with its data - pydantic_data = pydantic_obj.model_dump(mode="json") - pydantic_data["id"] = updated_element_id - return { "status": "node updated", "node": { "id": updated_element_id, - "data": pydantic_data, }, } @@ -462,15 +359,16 @@ def update_node_positions( if not data.positions: return {"status": "no positions to update", "count": 0} - # Convert Pydantic models to dicts for GraphRepository + # Convert Pydantic models to dicts for GraphService positions = [pos.model_dump() for pos in data.positions] - # Update positions using GraphRepository + # Update positions using GraphService try: - graph_repo = GraphRepository(neo4j_connection) - updated_count = graph_repo.update_nodes_positions( - positions=positions, sketch_id=sketch_id + graph_service = create_graph_service( + sketch_id=sketch_id, + enable_batching=False, ) + updated_count = graph_service.update_nodes_positions(positions=positions) except Exception as e: print(f"Position update error: {e}") raise HTTPException(status_code=500, detail="Failed to update node positions") @@ -498,10 +396,13 @@ def delete_nodes( current_user.id, sketch.investigation_id, actions=["update"], db=db ) - # Delete nodes and their relationships using GraphRepository + # Delete nodes and their relationships using GraphService try: - graph_repo = GraphRepository(neo4j_connection) - deleted_count = graph_repo.delete_nodes(nodes.nodeIds, sketch_id) + graph_service = create_graph_service( + sketch_id=sketch_id, + enable_batching=False, + ) + deleted_count = graph_service.delete_nodes(nodes.nodeIds) except Exception as e: print(f"Node deletion error: {e}") raise HTTPException(status_code=500, detail="Failed to delete nodes") @@ -526,11 +427,14 @@ def delete_relationships( current_user.id, sketch.investigation_id, actions=["update"], db=db ) - # Delete relationships using GraphRepository + # Delete relationships using GraphService try: - graph_repo = GraphRepository(neo4j_connection) - deleted_count = graph_repo.delete_relationships( - relationships.relationshipIds, sketch_id + graph_service = create_graph_service( + sketch_id=sketch_id, + enable_batching=False, + ) + deleted_count = graph_service.delete_relationships( + relationships.relationshipIds ) except Exception as e: print(f"Relationship deletion error: {e}") @@ -556,13 +460,15 @@ def edit_relationship( current_user.id, sketch.investigation_id, actions=["update"], db=db ) - # Update edge using GraphRepository + # Update edge using GraphService try: - graph_repo = GraphRepository(neo4j_connection) - result = graph_repo.update_relationship( + graph_service = create_graph_service( + sketch_id=sketch_id, + enable_batching=False, + ) + result = graph_service.update_relationship( element_id=relationship_edit.relationshipId, properties=relationship_edit.data, - sketch_id=sketch_id, ) except Exception as e: print(f"Relationship update error: {e}") @@ -615,12 +521,14 @@ def merge_nodes( properties.update(flattened_data) try: - graph_repo = GraphRepository(neo4j_connection) - new_node_element_id = graph_repo.merge_nodes( + graph_service = create_graph_service( + sketch_id=sketch_id, + enable_batching=False, + ) + new_node_element_id = graph_service.merge_nodes( old_node_ids=oldNodes, new_node_data=properties, new_node_id=newNode.id, - sketch_id=sketch_id, ) except Exception as e: print(f"Node merge error: {e}") @@ -651,16 +559,17 @@ def get_related_nodes( ) try: - graph_repo = GraphRepository(neo4j_connection) - result = graph_repo.get_related_nodes(node_id=node_id, sketch_id=sketch_id) + graph_service = create_graph_service(sketch_id=sketch_id) + result = graph_service.get_neighbors(node_id) + except Exception as e: - print(f"Related nodes query error: {e}") + print(e) raise HTTPException(status_code=500, detail="Failed to retrieve related nodes") - if not result["nds"]: + if not result.nodes: raise HTTPException(status_code=404, detail="Node not found") - return result + return {"nds": result.nodes, "rls": result.edges} @router.post("/{sketch_id}/import/analyze", response_model=FileParseResult) @@ -671,10 +580,9 @@ async def analyze_import_file( current_user: Profile = Depends(get_current_user), ): """ - Analyze an uploaded TXT or XML file for import. + Analyze an uploaded TXT or JSON file for import. Each line represents one entity. Detects entity types and provides preview. """ - # Verify sketch exists and user has access sketch = db.query(Sketch).filter(Sketch.id == sketch_id).first() if not sketch: @@ -696,12 +604,11 @@ async def analyze_import_file( except Exception as e: raise HTTPException(status_code=400, detail=f"Failed to read file: {str(e)}") - # Parse and analyze the file + # Analyze file using ImportService try: - result = parse_import_file( + result = ImportService.analyze_file( file_content=content, filename=file.filename or "unknown.txt", - max_preview_rows=10000000, ) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) @@ -711,15 +618,15 @@ async def analyze_import_file( return result -class EntityMapping(BaseModel): - """Mapping configuration for an entity.""" +class EntityMappingInput(BaseModel): + """Pydantic model for parsing entity mapping input from frontend.""" - id: str # Unique identifier for this entity (generated by frontend) + id: str entity_type: str include: bool = True - label: str + nodeLabel: str node_id: Optional[str] = None - data: Dict[str, Any] # Entity data from frontend + data: Dict[str, Any] class ImportExecuteResponse(BaseModel): @@ -754,12 +661,13 @@ async def execute_import( current_user.id, sketch.investigation_id, actions=["update"], db=db ) - # Parse entity mappings + # Parse entity mappings JSON try: mappings = json.loads(entity_mappings_json) - nodes = mappings.get("nodes") - edges = mappings.get("edges") - entity_mappings = [EntityMapping(**m) for m in nodes] + nodes = mappings.get("nodes", []) + edges = mappings.get("edges", []) + print(nodes) + entity_mapping_inputs = [EntityMappingInput(**m) for m in nodes] except json.JSONDecodeError: raise HTTPException(status_code=400, detail="Invalid entity_mappings JSON") except Exception as e: @@ -767,106 +675,38 @@ async def execute_import( status_code=400, detail=f"Failed to parse entity_mappings: {str(e)}" ) - # Filter only entities marked for inclusion - entities_to_import = [m for m in entity_mappings if m.include] + # Convert Pydantic inputs to service dataclasses + entity_mappings = [ + EntityMapping( + id=m.id, + entity_type=m.entity_type, + nodeLabel=m.nodeLabel, + data=m.data, + include=m.include, + node_id=m.node_id, + ) + for m in entity_mapping_inputs + ] - # Convert entity mappings to Pydantic objects - pydantic_nodes = [] - conversion_errors = [] - # Mapping dict to store frontend node_id -> index in pydantic_nodes - nodes_mapping_indices = {} - for idx, mapping in enumerate(entities_to_import): - entity_type = mapping.entity_type - label = mapping.label - entity_data = mapping.data - node_id = mapping.node_id - - # Get the Pydantic type from registry - DetectedType = TYPE_REGISTRY.get_lowercase(entity_type) - if not DetectedType: - conversion_errors.append( - f"Entity {idx + 1} ({label}): Unknown type {entity_type}" - ) - continue - - # Clean empty values and add label - cleaned_data = clean_empty_values(entity_data) - cleaned_data["label"] = label - - # Convert to Pydantic object - try: - pydantic_obj = DetectedType(**cleaned_data) - pydantic_nodes.append(pydantic_obj) - if node_id: - # Map frontend node_id to index in pydantic_nodes list - nodes_mapping_indices[node_id] = len(pydantic_nodes) - 1 - - except Exception as e: - conversion_errors.append(f"Entity {idx + 1} ({label}): {str(e)}") - - # Batch create all nodes first - graph_repo = GraphRepository(neo4j_connection) + # Execute import using ImportService + graph_service = create_graph_service( + sketch_id=sketch_id, enable_batching=False + ) + import_service = create_import_service(graph_service) try: - nodes_result = graph_repo.batch_create_nodes( - nodes=pydantic_nodes, sketch_id=sketch_id + result = import_service.execute_import( + entity_mappings=entity_mappings, + edges=edges, ) - nodes_created = nodes_result["nodes_created"] - node_element_ids = nodes_result.get("node_ids", []) - # Now create edges using element IDs - edges_to_insert = [] - edge_errors = [] - if edges and nodes_mapping_indices and node_element_ids: - for idx, edge in enumerate(edges): - from_id = edge.get("from_id") - to_id = edge.get("to_id") - - # Get indices in pydantic_nodes list - from_idx = nodes_mapping_indices.get(from_id) - to_idx = nodes_mapping_indices.get(to_id) - - if from_idx is None or to_idx is None: - edge_errors.append( - f"Edge {idx}: Missing source or target node (from: {from_id}, to: {to_id})" - ) - continue - - # Get corresponding element IDs from created nodes - if from_idx >= len(node_element_ids) or to_idx >= len(node_element_ids): - edge_errors.append( - f"Edge {idx}: Node index out of range (from_idx={from_idx}, to_idx={to_idx}, len={len(node_element_ids)})" - ) - continue - - from_element_id = node_element_ids[from_idx] - to_element_id = node_element_ids[to_idx] - - edge_to_insert = { - "from_element_id": from_element_id, - "to_element_id": to_element_id, - "rel_type": edge.get("label", "RELATED_TO"), - } - edges_to_insert.append(edge_to_insert) - - # Create edges using element IDs - if len(edges_to_insert) > 0: - edges_result = graph_repo.batch_create_edges_by_element_id( - edges=edges_to_insert, sketch_id=sketch_id - ) - edge_errors.extend(edges_result.get("errors", [])) - - batch_errors = nodes_result.get("errors", []) - all_errors = conversion_errors + batch_errors + edge_errors except Exception as e: - raise HTTPException(status_code=500, detail=f"Batch import failed: {str(e)}") - - nodes_skipped = len(entities_to_import) - nodes_created + raise HTTPException(status_code=500, detail=f"Import failed: {str(e)}") return ImportExecuteResponse( - status="completed" if not all_errors else "completed_with_errors", - nodes_created=nodes_created, - nodes_skipped=nodes_skipped, - errors=all_errors[:50], # Limit to first 50 errors + status=result.status, + nodes_created=result.nodes_created, + nodes_skipped=result.nodes_skipped, + errors=result.errors, ) @@ -881,7 +721,7 @@ async def export_sketch( Export the sketch in the specified format. Args: id: The ID of the sketch - format: Export format - "json" or "png" (default: "json") + format: Export format - "json" (only format for now) db: The database session current_user: The current user Returns: @@ -894,50 +734,22 @@ async def export_sketch( current_user.id, sketch.investigation_id, actions=["read"], db=db ) - # Get all nodes and relationships using GraphRepository - graph_repo = GraphRepository(neo4j_connection) - graph_data = graph_repo.get_sketch_graph(id, limit=100000) - - nodes_result = graph_data["nodes"] - rels_result = graph_data["relationships"] - - nodes = [ - { - "id": str(record["id"]), - "data": record["data"], - "label": record["data"].get("label", "Node"), - "idx": idx, - **({"x": record["data"]["x"]} if "x" in record["data"] else {}), - **({"y": record["data"]["y"]} if "y" in record["data"] else {}), - } - for idx, record in enumerate(nodes_result) - ] - - rels = [ - { - "id": str(record["id"]), - "source": str(record["source"]), - "target": str(record["target"]), - "data": record["data"], - "label": record["type"], - } - for record in rels_result - ] + # Get all nodes and relationships using GraphService + graph_service = create_graph_service( + sketch_id=id, enable_batching=False + ) + graph_data = graph_service.get_sketch_graph() if format == "json": - from fastapi.responses import JSONResponse - - return JSONResponse( - content={ - "sketch": { - "id": str(sketch.id), - "title": sketch.title, - "description": sketch.description, - }, - "nodes": nodes, - "edges": rels, - } - ) + return { + "sketch": { + "id": str(sketch.id), + "title": sketch.title, + "description": sketch.description, + }, + "nodes": [node.model_dump(mode="json") for node in graph_data.nodes], + "edges": [edge.model_dump(mode="json") for edge in graph_data.edges], + } else: raise HTTPException(status_code=400, detail=f"Unsupported format: {format}") diff --git a/flowsint-api/app/api/routes/types.py b/flowsint-api/app/api/routes/types.py index 91d86fe..931bd4b 100644 --- a/flowsint-api/app/api/routes/types.py +++ b/flowsint-api/app/api/routes/types.py @@ -1,12 +1,14 @@ from typing import Any, Dict, Optional, Type from uuid import uuid4 + from fastapi import APIRouter, Depends +from flowsint_core.core.models import CustomType, Profile +from flowsint_core.core.postgre_db import get_db +from flowsint_types.registry import get_type from pydantic import BaseModel, TypeAdapter from sqlalchemy.orm import Session -from flowsint_core.core.postgre_db import get_db -from flowsint_core.core.models import CustomType, Profile + from app.api.deps import get_current_user -from flowsint_types.registry import get_type router = APIRouter() @@ -203,7 +205,9 @@ async def get_types_list( label_key = ( required[0] if required - else list(properties.keys())[0] if properties else "value" + else list(properties.keys())[0] + if properties + else "value" ) custom_types_children.append( @@ -247,7 +251,6 @@ async def get_types_list( def extract_input_schema( model: Type[BaseModel], label_key: str, icon: Optional[str] = None ) -> Dict[str, Any]: - adapter = TypeAdapter(model) schema = adapter.json_schema() # Use the main schema properties, not the $defs @@ -264,8 +267,8 @@ def extract_input_schema( "fields": [ resolve_field(prop, details=info, schema=schema) for prop, info in details.get("properties", {}).items() - # exclude label from properties to fill - if prop != "label" + # exclude nodeLabel from properties to fill + if prop != "nodeLabel" ], } diff --git a/flowsint-api/app/main.py b/flowsint-api/app/main.py index fc9fd73..66b044b 100644 --- a/flowsint-api/app/main.py +++ b/flowsint-api/app/main.py @@ -1,7 +1,4 @@ from fastapi import FastAPI -from flowsint_core.core.graph_db import Neo4jConnection -import os -from dotenv import load_dotenv from fastapi.middleware.cors import CORSMiddleware # Routes to be included @@ -18,19 +15,12 @@ from app.api.routes import keys from app.api.routes import types from app.api.routes import custom_types -load_dotenv() - -URI = os.getenv("NEO4J_URI_BOLT") -USERNAME = os.getenv("NEO4J_USERNAME") -PASSWORD = os.getenv("NEO4J_PASSWORD") - origins = [ "*", ] app = FastAPI() -neo4j_connection = Neo4jConnection(URI, USERNAME, PASSWORD) app.add_middleware( CORSMiddleware, diff --git a/flowsint-api/app/security/permissions.py b/flowsint-api/app/security/permissions.py index 46b0a08..2650f8e 100644 --- a/flowsint-api/app/security/permissions.py +++ b/flowsint-api/app/security/permissions.py @@ -1,3 +1,5 @@ +from uuid import UUID + from fastapi import HTTPException from flowsint_core.core.models import InvestigationUserRole from flowsint_core.core.types import Role @@ -21,7 +23,9 @@ def can_user(roles: list[Role], actions: list[str]) -> bool: from fastapi import HTTPException -def check_investigation_permission(user_id: str, investigation_id: str, actions: list[str], db): +def check_investigation_permission( + user_id: UUID, investigation_id: str, actions: list[str], db +): role_entry = ( db.query(InvestigationUserRole) .filter_by(user_id=user_id, investigation_id=investigation_id)