diff --git a/flowsint-api/app/api/routes/sketches.py b/flowsint-api/app/api/routes/sketches.py index 87e3d57..2ecaf81 100644 --- a/flowsint-api/app/api/routes/sketches.py +++ b/flowsint-api/app/api/routes/sketches.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Literal +from typing import Any, Dict, List, Literal, Optional from uuid import UUID from fastapi import ( @@ -15,12 +15,10 @@ from flowsint_core.core.graph_db import neo4j_connection from flowsint_core.core.graph_repository import GraphRepository from flowsint_core.core.models import Profile, Sketch from flowsint_core.core.postgre_db import get_db -from flowsint_core.core.types import Role from flowsint_core.imports import FileParseResult, parse_import_file from flowsint_core.utils import flatten from flowsint_types import TYPE_REGISTRY from pydantic import BaseModel, Field -from sqlalchemy import or_ from sqlalchemy.orm import Session from app.api.deps import get_current_user @@ -63,14 +61,12 @@ class NodeEditInput(BaseModel): default_factory=NodeData, description="Updated data for the node" ) - class RelationshipEditInput(BaseModel): relationshipId: str data: Dict[str, Any] = Field( default_factory=dict, description="Updated data for the relationship" ) - class NodeMergeInput(BaseModel): id: str data: NodeData = Field( @@ -84,9 +80,12 @@ def create_sketch( db: Session = Depends(get_db), current_user: Profile = Depends(get_current_user), ): - sketch_data = data.dict() + sketch_data = data.model_dump() + investigation_id = sketch_data.get("investigation_id") + if not investigation_id: + raise HTTPException(status_code=404, detail="Investigation not found") check_investigation_permission( - current_user.id, sketch_data.get("investigation_id"), actions=["create"], db=db + current_user.id, investigation_id, actions=["create"], db=db ) sketch_data["owner_id"] = current_user.id sketch = Sketch(**sketch_data) @@ -100,25 +99,7 @@ def create_sketch( def list_sketches( db: Session = Depends(get_db), current_user: Profile = Depends(get_current_user) ): - from flowsint_core.core.models import InvestigationUserRole - - # Get all investigations where user has at least VIEWER role - allowed_roles_for_read = [Role.OWNER, Role.EDITOR, Role.VIEWER] - - query = db.query(Sketch).join( - InvestigationUserRole, - InvestigationUserRole.investigation_id == Sketch.investigation_id, - ) - - query = query.filter(InvestigationUserRole.user_id == current_user.id) - - # Filter by allowed roles - conditions = [ - InvestigationUserRole.roles.any(role) for role in allowed_roles_for_read - ] - query = query.filter(or_(*conditions)) - - return query.distinct().all() + return db.query(Sketch).filter(Sketch.owner_id == current_user.id).all() @router.get("/{sketch_id}") @@ -185,7 +166,7 @@ def delete_sketch( @router.get("/{id}/graph") async def get_sketch_nodes( id: str, - format: str = None, + format: str | None = None, db: Session = Depends(get_db), current_user: Profile = Depends(get_current_user), ): @@ -247,80 +228,6 @@ async def get_sketch_nodes( return {"nds": nodes, "rls": rels} -@router.get("/{id}/export") -async def export_sketch( - id: str, - format: str = "json", - db: Session = Depends(get_db), - current_user: Profile = Depends(get_current_user), -): - """ - Export the sketch in the specified format. - Args: - id: The ID of the sketch - format: Export format - "json" or "png" (default: "json") - db: The database session - current_user: The current user - Returns: - The sketch data in the requested format - """ - sketch = db.query(Sketch).filter(Sketch.id == id).first() - if not sketch: - raise HTTPException(status_code=404, detail="Sketch not found") - check_investigation_permission( - current_user.id, sketch.investigation_id, actions=["read"], db=db - ) - - # Get all nodes and relationships using GraphRepository - graph_repo = GraphRepository(neo4j_connection) - graph_data = graph_repo.get_sketch_graph(id, limit=100000) - - nodes_result = graph_data["nodes"] - rels_result = graph_data["relationships"] - - nodes = [ - { - "id": str(record["id"]), - "data": record["data"], - "label": record["data"].get("label", "Node"), - "idx": idx, - **({"x": record["data"]["x"]} if "x" in record["data"] else {}), - **({"y": record["data"]["y"]} if "y" in record["data"] else {}), - } - for idx, record in enumerate(nodes_result) - ] - - rels = [ - { - "id": str(record["id"]), - "source": str(record["source"]), - "target": str(record["target"]), - "data": record["data"], - "label": record["type"], - } - for record in rels_result - ] - - if format == "json": - from fastapi.responses import JSONResponse - - return JSONResponse( - content={ - "sketch": { - "id": str(sketch.id), - "title": sketch.title, - "description": sketch.description, - }, - "graph": {"nodes": nodes, "edges": rels}, - } - ) - elif format == "png": - # TODO: Implement PNG export - raise HTTPException(status_code=501, detail="PNG export not yet implemented") - else: - raise HTTPException(status_code=400, detail=f"Unsupported format: {format}") - - def clean_empty_values(data: dict) -> dict: """Remove empty string values from dict to avoid Pydantic validation errors.""" cleaned = {} @@ -413,7 +320,7 @@ class RelationInput(BaseModel): source: str target: str type: Literal["one-way", "two-way"] - label: str = "RELATED_TO" # Optionnel : nom de la relation + label: str = "RELATED_TO" @router.post("/{sketch_id}/relations/add") @@ -519,50 +426,6 @@ def edit_node( } -@router.put("/{sketch_id}/relationships/edit") -@update_sketch_timestamp -def edit_relationship( - sketch_id: str, - relationship_edit: RelationshipEditInput, - background_tasks: BackgroundTasks, - db: Session = Depends(get_db), - current_user: Profile = Depends(get_current_user), -): - # First verify the sketch exists and belongs to the user - sketch = db.query(Sketch).filter(Sketch.id == sketch_id).first() - if not sketch: - raise HTTPException(status_code=404, detail="Sketch not found") - check_investigation_permission( - current_user.id, sketch.investigation_id, actions=["update"], db=db - ) - - # Update edge using GraphRepository - try: - graph_repo = GraphRepository(neo4j_connection) - result = graph_repo.update_relationship( - element_id=relationship_edit.relationshipId, - properties=relationship_edit.data, - sketch_id=sketch_id, - ) - except Exception as e: - print(f"Relationship update error: {e}") - raise HTTPException(status_code=500, detail="Failed to update relationship") - - if not result: - raise HTTPException( - status_code=404, detail="Relationship not found or not accessible" - ) - - return { - "status": "relationship updated", - "relationship": { - "id": result["id"], - "label": result["type"], - "data": result["data"], - }, - } - - class NodePosition(BaseModel): nodeId: str x: float @@ -673,6 +536,49 @@ def delete_relationships( return {"status": "relationships deleted", "count": deleted_count} +@router.put("/{sketch_id}/relationships/edit") +@update_sketch_timestamp +def edit_relationship( + sketch_id: str, + relationship_edit: RelationshipEditInput, + background_tasks: BackgroundTasks, + db: Session = Depends(get_db), + current_user: Profile = Depends(get_current_user), +): + # First verify the sketch exists and belongs to the user + sketch = db.query(Sketch).filter(Sketch.id == sketch_id).first() + if not sketch: + raise HTTPException(status_code=404, detail="Sketch not found") + check_investigation_permission( + current_user.id, sketch.investigation_id, actions=["update"], db=db + ) + + # Update edge using GraphRepository + try: + graph_repo = GraphRepository(neo4j_connection) + result = graph_repo.update_relationship( + element_id=relationship_edit.relationshipId, + properties=relationship_edit.data, + sketch_id=sketch_id, + ) + except Exception as e: + print(f"Relationship update error: {e}") + raise HTTPException(status_code=500, detail="Failed to update relationship") + + if not result: + raise HTTPException( + status_code=404, detail="Relationship not found or not accessible" + ) + + return { + "status": "relationship updated", + "relationship": { + "id": result["id"], + "label": result["type"], + "data": result["data"], + }, + } + @router.post("/{sketch_id}/nodes/merge") @update_sketch_timestamp @@ -762,9 +668,10 @@ async def analyze_import_file( current_user: Profile = Depends(get_current_user), ): """ - Analyze an uploaded TXT file for import. + Analyze an uploaded TXT or XML file for import. Each line represents one entity. Detects entity types and provides preview. """ + # Verify sketch exists and user has access sketch = db.query(Sketch).filter(Sketch.id == sketch_id).first() if not sketch: @@ -774,10 +681,10 @@ async def analyze_import_file( ) # Validate file extension - if not file.filename or not file.filename.lower().endswith(".txt"): + if not file.filename or not file.filename.lower().endswith((".txt", ".json")): raise HTTPException( status_code=400, - detail="Only .txt files are supported. Please upload a text file with one value per line.", + detail="Only .txt and .json files are supported. Please upload a correct format.", ) # Read file content @@ -790,7 +697,7 @@ async def analyze_import_file( try: result = parse_import_file( file_content=content, - filename=file.filename or "unknown.csv", + filename=file.filename or "unknown.txt", max_preview_rows=10000000, ) except ValueError as e: @@ -808,6 +715,7 @@ class EntityMapping(BaseModel): entity_type: str include: bool = True label: str + node_id: Optional[str] = None data: Dict[str, Any] # Entity data from frontend @@ -834,7 +742,6 @@ async def execute_import( Uses the entity mappings provided by the frontend (no file re-parsing needed). """ import json - # Verify sketch exists and user has access sketch = db.query(Sketch).filter(Sketch.id == sketch_id).first() if not sketch: @@ -845,8 +752,10 @@ async def execute_import( # Parse entity mappings try: - mappings_data = json.loads(entity_mappings_json) - entity_mappings = [EntityMapping(**m) for m in mappings_data] + mappings = json.loads(entity_mappings_json) + nodes = mappings.get("nodes") + edges = mappings.get("edges") + entity_mappings = [EntityMapping(**m) for m in nodes] except json.JSONDecodeError: raise HTTPException(status_code=400, detail="Invalid entity_mappings JSON") except Exception as e: @@ -860,11 +769,13 @@ async def execute_import( # Convert entity mappings to Pydantic objects pydantic_nodes = [] conversion_errors = [] - + # Mapping dict to store frontend node_id -> index in pydantic_nodes + nodes_mapping_indices = {} for idx, mapping in enumerate(entities_to_import): entity_type = mapping.entity_type label = mapping.label entity_data = mapping.data + node_id = mapping.node_id # Get the Pydantic type from registry DetectedType = TYPE_REGISTRY.get_lowercase(entity_type) @@ -882,18 +793,66 @@ async def execute_import( try: pydantic_obj = DetectedType(**cleaned_data) pydantic_nodes.append(pydantic_obj) + if node_id: + # Map frontend node_id to index in pydantic_nodes list + nodes_mapping_indices[node_id] = len(pydantic_nodes) - 1 + except Exception as e: conversion_errors.append(f"Entity {idx + 1} ({label}): {str(e)}") - # Batch create all nodes + # Batch create all nodes first graph_repo = GraphRepository(neo4j_connection) + try: - result = graph_repo.batch_create_nodes( + nodes_result = graph_repo.batch_create_nodes( nodes=pydantic_nodes, sketch_id=sketch_id ) - nodes_created = result["nodes_created"] - batch_errors = result.get("errors", []) - all_errors = conversion_errors + batch_errors + nodes_created = nodes_result["nodes_created"] + node_element_ids = nodes_result.get("node_ids", []) + # Now create edges using element IDs + edges_to_insert = [] + edge_errors = [] + if edges and nodes_mapping_indices and node_element_ids: + for idx, edge in enumerate(edges): + from_id = edge.get("from_id") + to_id = edge.get("to_id") + + # Get indices in pydantic_nodes list + from_idx = nodes_mapping_indices.get(from_id) + to_idx = nodes_mapping_indices.get(to_id) + + if from_idx is None or to_idx is None: + edge_errors.append( + f"Edge {idx}: Missing source or target node (from: {from_id}, to: {to_id})" + ) + continue + + # Get corresponding element IDs from created nodes + if from_idx >= len(node_element_ids) or to_idx >= len(node_element_ids): + edge_errors.append( + f"Edge {idx}: Node index out of range (from_idx={from_idx}, to_idx={to_idx}, len={len(node_element_ids)})" + ) + continue + + from_element_id = node_element_ids[from_idx] + to_element_id = node_element_ids[to_idx] + + edge_to_insert = { + "from_element_id": from_element_id, + "to_element_id": to_element_id, + "rel_type": edge.get("label", "RELATED_TO"), + } + edges_to_insert.append(edge_to_insert) + + # Create edges using element IDs + if len(edges_to_insert) > 0: + edges_result = graph_repo.batch_create_edges_by_element_id( + edges=edges_to_insert, sketch_id=sketch_id + ) + edge_errors.extend(edges_result.get("errors", [])) + + batch_errors = nodes_result.get("errors", []) + all_errors = conversion_errors + batch_errors + edge_errors except Exception as e: raise HTTPException(status_code=500, detail=f"Batch import failed: {str(e)}") @@ -905,3 +864,75 @@ async def execute_import( nodes_skipped=nodes_skipped, errors=all_errors[:50], # Limit to first 50 errors ) + + +@router.get("/{id}/export") +async def export_sketch( + id: str, + format: str = "json", + db: Session = Depends(get_db), + current_user: Profile = Depends(get_current_user), +): + """ + Export the sketch in the specified format. + Args: + id: The ID of the sketch + format: Export format - "json" or "png" (default: "json") + db: The database session + current_user: The current user + Returns: + The sketch data in the requested format + """ + sketch = db.query(Sketch).filter(Sketch.id == id).first() + if not sketch: + raise HTTPException(status_code=404, detail="Sketch not found") + check_investigation_permission( + current_user.id, sketch.investigation_id, actions=["read"], db=db + ) + + # Get all nodes and relationships using GraphRepository + graph_repo = GraphRepository(neo4j_connection) + graph_data = graph_repo.get_sketch_graph(id, limit=100000) + + nodes_result = graph_data["nodes"] + rels_result = graph_data["relationships"] + + nodes = [ + { + "id": str(record["id"]), + "data": record["data"], + "label": record["data"].get("label", "Node"), + "idx": idx, + **({"x": record["data"]["x"]} if "x" in record["data"] else {}), + **({"y": record["data"]["y"]} if "y" in record["data"] else {}), + } + for idx, record in enumerate(nodes_result) + ] + + rels = [ + { + "id": str(record["id"]), + "source": str(record["source"]), + "target": str(record["target"]), + "data": record["data"], + "label": record["type"], + } + for record in rels_result + ] + + if format == "json": + from fastapi.responses import JSONResponse + + return JSONResponse( + content={ + "sketch": { + "id": str(sketch.id), + "title": sketch.title, + "description": sketch.description, + }, + "graph": {"nodes": nodes, "edges": rels}, + } + ) + + else: + raise HTTPException(status_code=400, detail=f"Unsupported format: {format}")