from app.security.permissions import check_investigation_permission
from fastapi import APIRouter, HTTPException, Depends, status
from pydantic import BaseModel, Field
from typing import Literal, List, Optional
from flowsint_core.utils import flatten
from sqlalchemy.orm import Session
from app.api.schemas.sketch import SketchCreate, SketchRead, SketchUpdate
from flowsint_core.core.models import Sketch, Profile
from uuid import UUID
from flowsint_core.core.graph_db import neo4j_connection
from flowsint_core.core.postgre_db import get_db
from app.api.deps import get_current_user

router = APIRouter()


class NodeData(BaseModel):
    """Display data attached to a graph node; unknown extra fields are kept."""

    label: str = Field(default="Node", description="Label/name of the node")
    color: str = Field(default="Node", description="Color of the node")
    type: str = Field(default="Node", description="Type of the node")

    # Add any other specific data fields that might be common across nodes
    class Config:
        extra = "allow"  # Accept any additional fields


class NodeInput(BaseModel):
    """Request body for adding a node to a sketch."""

    # NOTE(review): add_node derives the Neo4j label from data.type, not from
    # this field — confirm which one is meant to be authoritative.
    type: str = Field(..., description="Type of the node")
    data: NodeData = Field(
        default_factory=NodeData, description="Additional data for the node"
    )


def dict_to_cypher_props(props: dict, prefix: str = "") -> str:
    """Render the keys of *props* as a Cypher map body of ``key: $prefixkey`` pairs.

    Keys are emitted verbatim, so callers must only pass keys that are safe to
    splice into a query.
    """
    return ", ".join(f"{key}: ${prefix}{key}" for key in props)


def _quote_cypher_name(name: str) -> str:
    """Return *name* wrapped in backticks with embedded backticks escaped.

    Labels and relationship types cannot be passed as query parameters, so they
    are spliced into the query text; doubling backticks (and normalising the
    ``\\u0060`` escape first) keeps a client-supplied name from closing the
    quoted identifier and injecting Cypher.
    """
    sanitized = name.replace("\\u0060", "`").replace("`", "``")
    return f"`{sanitized}`"


def _reject_unsafe_property_keys(properties: dict) -> None:
    """Raise a 400 if any property name is not a plain identifier.

    Property names end up as literal map keys and parameter names in the query
    text, so anything other than an identifier would produce broken or
    attacker-controlled Cypher.
    """
    invalid = [str(key) for key in properties if not str(key).isidentifier()]
    if invalid:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid property name(s): {', '.join(invalid)}",
        )


class NodeDeleteInput(BaseModel):
    """Request body listing the element ids of nodes to delete."""

    nodeIds: List[str]


class NodeEditInput(BaseModel):
    """Request body for editing a single node."""

    nodeId: str
    data: NodeData = Field(
        default_factory=NodeData, description="Updated data for the node"
    )


class NodeMergeInput(BaseModel):
    """Request body describing the node that replaces the merged ones."""

    id: str
    data: NodeData = Field(
        default_factory=NodeData, description="Updated data for the node"
    )


@router.post("/create", response_model=SketchRead, status_code=status.HTTP_201_CREATED)
def create_sketch(
    data: SketchCreate,
    db: Session = Depends(get_db),
    current_user: Profile = Depends(get_current_user),
):
    """Create a sketch owned by the current user.

    The permission check on the target investigation runs before anything is
    written.
    """
    # model_dump() is the Pydantic v2 spelling already used elsewhere in this file.
    sketch_data = data.model_dump()
    check_investigation_permission(
        current_user.id, sketch_data.get("investigation_id"), actions=["create"], db=db
    )
    sketch_data["owner_id"] = current_user.id
    sketch = Sketch(**sketch_data)
    db.add(sketch)
    db.commit()
    db.refresh(sketch)
    return sketch


@router.get("", response_model=List[SketchRead])
def list_sketches(
    db: Session = Depends(get_db), current_user: Profile = Depends(get_current_user)
):
    """Return every sketch owned by the current user."""
    return db.query(Sketch).filter(Sketch.owner_id == current_user.id).all()


@router.get("/{sketch_id}")
def get_sketch_by_id(
    sketch_id: UUID,
    db: Session = Depends(get_db),
    current_user: Profile = Depends(get_current_user),
):
    """Return one sketch by id, or 404 if it does not exist."""
    # NOTE(review): unlike update_sketch/delete_sketch this lookup is not
    # restricted to the caller's own sketches — confirm that is intended.
    sketch = db.query(Sketch).filter(Sketch.id == sketch_id).first()
    if not sketch:
        raise HTTPException(status_code=404, detail="Sketch not found")
    return sketch


@router.put("/{id}", response_model=SketchRead)
def update_sketch(
    id: UUID,
    payload: SketchUpdate,
    db: Session = Depends(get_db),
    current_user: Profile = Depends(get_current_user),
):
    """Apply the fields present in *payload* to a sketch owned by the caller."""
    sketch = (
        db.query(Sketch)
        .filter(Sketch.owner_id == current_user.id)
        .filter(Sketch.id == id)
        .first()
    )
    if not sketch:
        raise HTTPException(status_code=404, detail="Sketch not found")
    # exclude_unset keeps fields the client did not send from being overwritten.
    for key, value in payload.model_dump(exclude_unset=True).items():
        setattr(sketch, key, value)
    db.commit()
    db.refresh(sketch)
    return sketch


@router.delete("/{id}", status_code=204)
def delete_sketch(
    id: UUID,
    db: Session = Depends(get_db),
    current_user: Profile = Depends(get_current_user),
):
    """Delete a sketch owned by the caller together with its graph data."""
    sketch = (
        db.query(Sketch)
        .filter(Sketch.id == id, Sketch.owner_id == current_user.id)
        .first()
    )
    if not sketch:
        raise HTTPException(status_code=404, detail="Sketch not found")

    # Delete all nodes and relationships in Neo4j first
    neo4j_query = """
    MATCH (n {sketch_id: $sketch_id})
    DETACH DELETE n
    """
    try:
        neo4j_connection.query(neo4j_query, {"sketch_id": str(id)})
    except Exception as e:
        print(f"Neo4j cleanup error: {e}")
        raise HTTPException(status_code=500, detail="Failed to clean up graph data")

    # Then delete the sketch from PostgreSQL
    db.delete(sketch)
    db.commit()


@router.get("/{id}/graph")
async def get_sketch_nodes(
    id: str,
    format: Optional[str] = None,
    db: Session = Depends(get_db),
    # TODO(review): this endpoint does not depend on get_current_user, so the
    # sketch lookup below is not restricted to the caller.
):
    """
    Get the nodes and relationships for a sketch.

    Args:
        id: The ID of the sketch
        format: Optional format parameter. If "inline", returns inline relationships
        db: The database session

    Returns:
        A dictionary containing the nodes and relationships for the sketch
            nds: []
            rls: []
        Or if format=inline: the value returned by get_inline_relationships

    Raises:
        HTTPException: 404 if no sketch with this id exists.
    """
    sketch = db.query(Sketch).filter(Sketch.id == id).first()
    if not sketch:
        raise HTTPException(status_code=404, detail="Graph not found")

    nodes_query = """
    MATCH (n)
    WHERE n.sketch_id = $sketch_id
    RETURN elementId(n) as id, labels(n) as labels, properties(n) as data
    LIMIT 100000
    """
    nodes_result = neo4j_connection.query(nodes_query, parameters={"sketch_id": id})
    node_ids = [record["id"] for record in nodes_result]

    # Only relationships whose two endpoints were both returned above are kept.
    rels_query = """
    UNWIND $node_ids AS nid
    MATCH (a)-[r]->(b)
    WHERE elementId(a) = nid AND elementId(b) IN $node_ids
    RETURN elementId(r) as id, type(r) as type, elementId(a) as source,
           elementId(b) as target, properties(r) as data
    """
    rels_result = neo4j_connection.query(rels_query, parameters={"node_ids": node_ids})

    nodes = [
        {
            "id": str(record["id"]),
            "data": record["data"],
            "label": record["data"].get("label", "Node"),
            "idx": idx,
        }
        for idx, record in enumerate(nodes_result)
    ]
    rels = [
        {
            "id": str(record["id"]),
            "source": str(record["source"]),
            "target": str(record["target"]),
            "data": record["data"],
            "label": record["type"],
        }
        for record in rels_result
    ]

    if format == "inline":
        from flowsint_core.utils import get_inline_relationships

        return get_inline_relationships(nodes, rels)
    return {"nds": nodes, "rls": rels}


@router.post("/{sketch_id}/nodes/add")
def add_node(
    sketch_id: str, node: NodeInput, current_user: Profile = Depends(get_current_user)
):
    """Create (or match) a node in the sketch and return it.

    The Neo4j label is taken from ``node.data.type``; the node's properties are
    the flattened ``node.data`` plus ``type``, ``caption``, ``label`` and
    ``sketch_id``.

    Raises:
        HTTPException: 400 for unusable property names or an empty result,
            500 when the query fails or the result cannot be read.
    """
    node_data = node.data.model_dump()
    node_type = node_data["type"]

    properties = {
        "type": node_type.lower(),
        "sketch_id": sketch_id,
        "caption": node_data["label"],
        "label": node_data["label"],
    }
    # NOTE(review): flattened client data is applied last, so if flatten() keeps
    # the "type" key it replaces the lower-cased value set above — confirm.
    properties.update(flatten(node_data))
    # The sketch comes from the URL; an extra "sketch_id" field in the request
    # body must not be able to place the node in a different sketch.
    properties["sketch_id"] = sketch_id

    _reject_unsafe_property_keys(properties)
    cypher_props = dict_to_cypher_props(properties)

    create_query = f"""
        MERGE (d:{_quote_cypher_name(node_type)} {{ {cypher_props} }})
        RETURN d as node, elementId(d) as id
    """

    try:
        create_result = neo4j_connection.query(create_query, properties)
    except Exception as e:
        print(f"Query execution error: {e}")
        raise HTTPException(status_code=500, detail=f"Database error: {str(e)}")

    if not create_result:
        raise HTTPException(
            status_code=400, detail="Node creation failed - no result returned"
        )

    try:
        new_node = create_result[0]["node"]
        new_node["id"] = create_result[0]["id"]
    except (IndexError, KeyError) as e:
        print(f"Error extracting node_id: {e}, result: {create_result}")
        raise HTTPException(
            status_code=500, detail="Failed to extract node data from response"
        )

    new_node["data"] = node_data
    new_node["data"]["id"] = new_node["id"]

    return {
        "status": "node added",
        "node": new_node,
    }


class RelationInput(BaseModel):
    """Request body for adding a relationship between two nodes."""

    source: str
    target: str
    # NOTE(review): add_edge always creates a single source->target
    # relationship; this field is validated but otherwise unused.
    type: Literal["one-way", "two-way"]
    label: str = "RELATED_TO"  # Optional: relationship name


@router.post("/{sketch_id}/relations/add")
def add_edge(
    sketch_id: str,
    relation: RelationInput,
    current_user: Profile = Depends(get_current_user),
):
    """Create (or match) a relationship from ``relation.source`` to ``relation.target``.

    The relationship type is ``relation.label`` and it is tagged with the
    sketch id.

    Raises:
        HTTPException: 500 when the query fails, 400 when it returns nothing
            (for example when either endpoint does not exist).
    """
    query = f"""
        MATCH (a) WHERE elementId(a) = $from_id
        MATCH (b) WHERE elementId(b) = $to_id
        MERGE (a)-[r:{_quote_cypher_name(relation.label)} {{sketch_id: $sketch_id}}]->(b)
        RETURN r
    """
    params = {
        "from_id": relation.source,
        "to_id": relation.target,
        "sketch_id": sketch_id,
    }

    try:
        result = neo4j_connection.query(query, params)
    except Exception as e:
        print(f"Edge creation error: {e}")
        raise HTTPException(status_code=500, detail="Failed to create edge")

    if not result:
        raise HTTPException(status_code=400, detail="Edge creation failed")

    return {
        "status": "edge added",
        "edge": result[0]["r"],
    }
@router.put("/{sketch_id}/nodes/edit")
def edit_node(
    sketch_id: str,
    node_edit: NodeEditInput,
    db: Session = Depends(get_db),
    current_user: Profile = Depends(get_current_user),
):
    """Update the properties of one node in a sketch owned by the caller.

    Raises:
        HTTPException: 404 if the sketch or the node is not found, 500 when the
            query fails.
    """
    # First verify the sketch exists and belongs to the user
    sketch = (
        db.query(Sketch)
        .filter(Sketch.id == sketch_id, Sketch.owner_id == current_user.id)
        .first()
    )
    if not sketch:
        raise HTTPException(status_code=404, detail="Sketch not found")

    node_data = node_edit.data.model_dump()
    node_type = node_data.get("type", "Node")

    # Prepare properties to update
    properties = {
        "type": node_type.lower(),
        "caption": node_data.get("label", "Node"),
        "label": node_data.get("label", "Node"),
    }
    # Add any additional data from the flattened node_data
    properties.update(flatten(node_data))
    # An edit must not move the node to another sketch.
    properties.pop("sketch_id", None)

    # The properties travel as one map parameter: client-chosen property names
    # never reach the query text and cannot collide with $node_id / $sketch_id.
    query = """
    MATCH (n)
    WHERE elementId(n) = $node_id AND n.sketch_id = $sketch_id
    SET n += $props
    RETURN n as node
    """
    params = {
        "node_id": node_edit.nodeId,
        "sketch_id": sketch_id,
        "props": properties,
    }

    try:
        result = neo4j_connection.query(query, params)
    except Exception as e:
        print(f"Node update error: {e}")
        raise HTTPException(status_code=500, detail="Failed to update node")

    if not result:
        raise HTTPException(status_code=404, detail="Node not found or not accessible")

    updated_node = result[0]["node"]
    updated_node["data"] = node_data

    return {
        "status": "node updated",
        "node": updated_node,
    }


@router.delete("/{sketch_id}/nodes")
def delete_nodes(
    sketch_id: str,
    nodes: NodeDeleteInput,
    db: Session = Depends(get_db),
    current_user: Profile = Depends(get_current_user),
):
    """Delete the listed nodes, and their relationships, from the caller's sketch.

    The returned count is the number of ids requested, not the number of nodes
    actually removed.
    """
    # First verify the sketch exists and belongs to the user
    sketch = (
        db.query(Sketch)
        .filter(Sketch.id == sketch_id, Sketch.owner_id == current_user.id)
        .first()
    )
    if not sketch:
        raise HTTPException(status_code=404, detail="Sketch not found")

    # Delete nodes and their relationships
    query = """
    UNWIND $node_ids AS node_id
    MATCH (n)
    WHERE elementId(n) = node_id AND n.sketch_id = $sketch_id
    DETACH DELETE n
    """

    try:
        neo4j_connection.query(
            query, {"node_ids": nodes.nodeIds, "sketch_id": sketch_id}
        )
    except Exception as e:
        print(f"Node deletion error: {e}")
        raise HTTPException(status_code=500, detail="Failed to delete nodes")

    return {"status": "nodes deleted", "count": len(nodes.nodeIds)}


def _escape_cypher_identifier(name: str) -> str:
    """Return *name* backtick-quoted for use as a label or relationship type.

    Embedded backticks are doubled (after normalising the ``\\u0060`` escape)
    so the value cannot terminate the quoted identifier.
    """
    sanitized = name.replace("\\u0060", "`").replace("`", "``")
    return f"`{sanitized}`"


@router.post("/{sketch_id}/nodes/merge")
def merge_nodes(
    sketch_id: str,
    oldNodes: List[str],
    newNode: NodeMergeInput,
    db: Session = Depends(get_db),
    current_user: Profile = Depends(get_current_user),
):
    """Merge several nodes of a sketch into one node.

    The target node is created or matched on ``nodeId`` within the sketch, the
    relationships of the old nodes are copied onto it, and the old nodes are
    then deleted.

    Raises:
        HTTPException: 404 if the sketch is not found, 400 if ``newNode.id`` is
            empty, 500 when any graph query fails.
    """
    # 1. Check the sketch
    sketch = (
        db.query(Sketch)
        .filter(Sketch.id == sketch_id, Sketch.owner_id == current_user.id)
        .first()
    )
    if not sketch:
        raise HTTPException(status_code=404, detail="Sketch not found")

    old_node_ids = list(oldNodes)

    # 2. Prepare the single target node (keyed by nodeId)
    node_id = newNode.id
    if not node_id:
        raise HTTPException(status_code=400, detail="newNode.id is required")

    properties = flatten(newNode.data.model_dump()) if newNode.data else {}
    # Without sketch_id the merged node would be invisible to every query that
    # filters on it.
    properties["sketch_id"] = sketch_id

    # NodeMergeInput has no "type" of its own; the label comes from its data.
    node_label = _escape_cypher_identifier(newNode.data.type)

    # 3. Create or merge the new node
    create_query = f"""
    MERGE (new:{node_label} {{nodeId: $nodeId, sketch_id: $sketch_id}})
    SET new += $nodeData
    RETURN elementId(new) as newElementId
    """
    try:
        result = neo4j_connection.query(
            create_query,
            # SET += needs a map, so the property dict itself is the parameter.
            {"nodeId": node_id, "sketch_id": sketch_id, "nodeData": properties},
        )
        new_node_element_id = result[0]["newElementId"]
    except Exception as e:
        print(f"Error creating/merging new node: {e}")
        raise HTTPException(status_code=500, detail="Failed to create new node")

    # 4. Collect every relationship type used by the old nodes
    rel_types_query = """
    MATCH (old)
    WHERE elementId(old) IN $oldNodeIds AND old.sketch_id = $sketch_id
    MATCH (old)-[r]-()
    RETURN DISTINCT type(r) AS relType
    """
    try:
        rel_types_result = neo4j_connection.query(
            rel_types_query, {"oldNodeIds": old_node_ids, "sketch_id": sketch_id}
        )
        rel_types = [row["relType"] for row in rel_types_result]
    except Exception as e:
        print(f"Error fetching relation types: {e}")
        raise HTTPException(status_code=500, detail="Failed to fetch relation types")

    merge_params = {
        "newElementId": new_node_element_id,
        "oldNodeIds": old_node_ids,
        "sketch_id": sketch_id,
    }
    # Shared filter: only old nodes of this sketch, never the merged node itself.
    old_filter = """
        WHERE elementId(old) IN $oldNodeIds
          AND old.sketch_id = $sketch_id
          AND elementId(old) <> $newElementId
    """

    # 5. Copy the relationships, one independent query per type and direction.
    # In a single chained query, a type with no incoming (or outgoing) edge
    # yields zero rows, which would silently skip every later step.
    copy_queries = []
    for rel_type in rel_types:
        quoted_type = _escape_cypher_identifier(rel_type)
        # Incoming relationships
        copy_queries.append(
            f"""
        MATCH (new) WHERE elementId(new) = $newElementId
        MATCH (src)-[r:{quoted_type}]->(old)
        {old_filter}
        MERGE (src)-[newRel:{quoted_type}]->(new)
        ON CREATE SET newRel = properties(r)
        ON MATCH SET newRel += properties(r)
        """
        )
        # Outgoing relationships
        copy_queries.append(
            f"""
        MATCH (new) WHERE elementId(new) = $newElementId
        MATCH (old)-[r:{quoted_type}]->(dst)
        {old_filter}
        MERGE (new)-[newRel:{quoted_type}]->(dst)
        ON CREATE SET newRel = properties(r)
        ON MATCH SET newRel += properties(r)
        """
        )

    # 6. Delete the old nodes
    delete_query = f"""
    MATCH (old)
    {old_filter}
    DETACH DELETE old
    """

    # 7. Run the queries
    try:
        for copy_query in copy_queries:
            neo4j_connection.query(copy_query, merge_params)
        neo4j_connection.query(delete_query, merge_params)
    except Exception as e:
        print(f"Node merging error: {e}")
        raise HTTPException(status_code=500, detail="Failed to merge node relations")

    return {"status": "nodes merged", "count": len(old_node_ids)}


@router.get("/{sketch_id}/nodes/{node_id}")
def get_related_nodes(
    sketch_id: str,
    node_id: str,
    db: Session = Depends(get_db),
    # TODO(review): no get_current_user dependency and no sketch ownership
    # check on this endpoint — confirm whether that is intentional.
):
    """Return a node together with its direct neighbours in the same sketch.

    Returns:
        ``{"nds": [...], "rls": [...]}`` where ``nds`` starts with the centre
        node and ``rls`` holds each relationship once.

    Raises:
        HTTPException: 404 if the node is not in the sketch, 500 when a query
            fails.
    """
    # First, let's get the center node
    center_query = """
    MATCH (n)
    WHERE elementId(n) = $node_id AND n.sketch_id = $sketch_id
    RETURN elementId(n) as id, labels(n) as labels, properties(n) as data
    """
    query_params = {"sketch_id": sketch_id, "node_id": node_id}

    try:
        center_result = neo4j_connection.query(center_query, query_params)
    except Exception as e:
        print(f"Center node query error: {e}")
        raise HTTPException(status_code=500, detail="Failed to retrieve center node")

    if not center_result:
        raise HTTPException(status_code=404, detail="Node not found")

    # Now get all relationships and connected nodes
    relationships_query = """
    MATCH (n)
    WHERE elementId(n) = $node_id AND n.sketch_id = $sketch_id
    OPTIONAL MATCH (n)-[r]->(other)
    WHERE other.sketch_id = $sketch_id
    RETURN
        elementId(r) as rel_id,
        type(r) as rel_type,
        properties(r) as rel_data,
        elementId(other) as other_node_id,
        labels(other) as other_node_labels,
        properties(other) as other_node_data,
        'outgoing' as direction
    UNION
    MATCH (n)
    WHERE elementId(n) = $node_id AND n.sketch_id = $sketch_id
    OPTIONAL MATCH (other)-[r]->(n)
    WHERE other.sketch_id = $sketch_id
    RETURN
        elementId(r) as rel_id,
        type(r) as rel_type,
        properties(r) as rel_data,
        elementId(other) as other_node_id,
        labels(other) as other_node_labels,
        properties(other) as other_node_data,
        'incoming' as direction
    """

    try:
        result = neo4j_connection.query(relationships_query, query_params)
    except Exception as e:
        print(f"Related nodes query error: {e}")
        raise HTTPException(status_code=500, detail="Failed to retrieve related nodes")

    # Extract center node info
    center_record = center_result[0]
    center_node = {
        "id": center_record["id"],
        "labels": center_record["labels"],
        "data": center_record["data"],
        "label": center_record["data"].get("label", "Node"),
        "type": "custom",
        "caption": center_record["data"].get("label", "Node"),
    }

    # Collect all related nodes and relationships
    related_nodes = []
    relationships = []
    seen_nodes = set()
    seen_relationships = set()

    for record in result:
        rel_id = record["rel_id"]
        # Skip if no relationship found (OPTIONAL MATCH produced a null row)
        if not rel_id:
            continue

        other_id = record["other_node_id"]

        # Add relationship if not seen
        if rel_id not in seen_relationships:
            if record["direction"] == "outgoing":
                source, target = center_node["id"], other_id
            else:  # incoming
                source, target = other_id, center_node["id"]
            relationships.append(
                {
                    "id": rel_id,
                    "type": "straight",
                    "source": source,
                    "target": target,
                    "data": record["rel_data"],
                    "caption": record["rel_type"],
                }
            )
            seen_relationships.add(rel_id)

        # Add related node if not seen
        if other_id and other_id not in seen_nodes:
            other_data = record["other_node_data"]
            related_nodes.append(
                {
                    "id": other_id,
                    "labels": record["other_node_labels"],
                    "data": other_data,
                    "label": other_data.get("label", "Node"),
                    "type": "custom",
                    "caption": other_data.get("label", "Node"),
                }
            )
            seen_nodes.add(other_id)

    # Combine center node with related nodes
    all_nodes = [center_node] + related_nodes

    return {"nds": all_nodes, "rls": relationships}