feat(api): add json import support

This commit is contained in:
dextmorgn
2026-01-10 18:07:14 +01:00
parent c1fe5d2307
commit 27c9e91c92

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Literal
from typing import Any, Dict, List, Literal, Optional
from uuid import UUID
from fastapi import (
@@ -15,12 +15,10 @@ from flowsint_core.core.graph_db import neo4j_connection
from flowsint_core.core.graph_repository import GraphRepository
from flowsint_core.core.models import Profile, Sketch
from flowsint_core.core.postgre_db import get_db
from flowsint_core.core.types import Role
from flowsint_core.imports import FileParseResult, parse_import_file
from flowsint_core.utils import flatten
from flowsint_types import TYPE_REGISTRY
from pydantic import BaseModel, Field
from sqlalchemy import or_
from sqlalchemy.orm import Session
from app.api.deps import get_current_user
@@ -63,14 +61,12 @@ class NodeEditInput(BaseModel):
default_factory=NodeData, description="Updated data for the node"
)
class RelationshipEditInput(BaseModel):
relationshipId: str
data: Dict[str, Any] = Field(
default_factory=dict, description="Updated data for the relationship"
)
class NodeMergeInput(BaseModel):
id: str
data: NodeData = Field(
@@ -84,9 +80,12 @@ def create_sketch(
db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user),
):
sketch_data = data.dict()
sketch_data = data.model_dump()
investigation_id = sketch_data.get("investigation_id")
if not investigation_id:
raise HTTPException(status_code=404, detail="Investigation not found")
check_investigation_permission(
current_user.id, sketch_data.get("investigation_id"), actions=["create"], db=db
current_user.id, investigation_id, actions=["create"], db=db
)
sketch_data["owner_id"] = current_user.id
sketch = Sketch(**sketch_data)
@@ -100,25 +99,7 @@ def create_sketch(
def list_sketches(
db: Session = Depends(get_db), current_user: Profile = Depends(get_current_user)
):
from flowsint_core.core.models import InvestigationUserRole
# Get all investigations where user has at least VIEWER role
allowed_roles_for_read = [Role.OWNER, Role.EDITOR, Role.VIEWER]
query = db.query(Sketch).join(
InvestigationUserRole,
InvestigationUserRole.investigation_id == Sketch.investigation_id,
)
query = query.filter(InvestigationUserRole.user_id == current_user.id)
# Filter by allowed roles
conditions = [
InvestigationUserRole.roles.any(role) for role in allowed_roles_for_read
]
query = query.filter(or_(*conditions))
return query.distinct().all()
return db.query(Sketch).filter(Sketch.owner_id == current_user.id).all()
@router.get("/{sketch_id}")
@@ -185,7 +166,7 @@ def delete_sketch(
@router.get("/{id}/graph")
async def get_sketch_nodes(
id: str,
format: str = None,
format: str | None = None,
db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user),
):
@@ -247,80 +228,6 @@ async def get_sketch_nodes(
return {"nds": nodes, "rls": rels}
@router.get("/{id}/export")
async def export_sketch(
id: str,
format: str = "json",
db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user),
):
"""
Export the sketch in the specified format.
Args:
id: The ID of the sketch
format: Export format - "json" or "png" (default: "json")
db: The database session
current_user: The current user
Returns:
The sketch data in the requested format
"""
sketch = db.query(Sketch).filter(Sketch.id == id).first()
if not sketch:
raise HTTPException(status_code=404, detail="Sketch not found")
check_investigation_permission(
current_user.id, sketch.investigation_id, actions=["read"], db=db
)
# Get all nodes and relationships using GraphRepository
graph_repo = GraphRepository(neo4j_connection)
graph_data = graph_repo.get_sketch_graph(id, limit=100000)
nodes_result = graph_data["nodes"]
rels_result = graph_data["relationships"]
nodes = [
{
"id": str(record["id"]),
"data": record["data"],
"label": record["data"].get("label", "Node"),
"idx": idx,
**({"x": record["data"]["x"]} if "x" in record["data"] else {}),
**({"y": record["data"]["y"]} if "y" in record["data"] else {}),
}
for idx, record in enumerate(nodes_result)
]
rels = [
{
"id": str(record["id"]),
"source": str(record["source"]),
"target": str(record["target"]),
"data": record["data"],
"label": record["type"],
}
for record in rels_result
]
if format == "json":
from fastapi.responses import JSONResponse
return JSONResponse(
content={
"sketch": {
"id": str(sketch.id),
"title": sketch.title,
"description": sketch.description,
},
"graph": {"nodes": nodes, "edges": rels},
}
)
elif format == "png":
# TODO: Implement PNG export
raise HTTPException(status_code=501, detail="PNG export not yet implemented")
else:
raise HTTPException(status_code=400, detail=f"Unsupported format: {format}")
def clean_empty_values(data: dict) -> dict:
"""Remove empty string values from dict to avoid Pydantic validation errors."""
cleaned = {}
@@ -413,7 +320,7 @@ class RelationInput(BaseModel):
source: str
target: str
type: Literal["one-way", "two-way"]
label: str = "RELATED_TO" # Optionnel : nom de la relation
label: str = "RELATED_TO"
@router.post("/{sketch_id}/relations/add")
@@ -519,50 +426,6 @@ def edit_node(
}
@router.put("/{sketch_id}/relationships/edit")
@update_sketch_timestamp
def edit_relationship(
sketch_id: str,
relationship_edit: RelationshipEditInput,
background_tasks: BackgroundTasks,
db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user),
):
# First verify the sketch exists and belongs to the user
sketch = db.query(Sketch).filter(Sketch.id == sketch_id).first()
if not sketch:
raise HTTPException(status_code=404, detail="Sketch not found")
check_investigation_permission(
current_user.id, sketch.investigation_id, actions=["update"], db=db
)
# Update edge using GraphRepository
try:
graph_repo = GraphRepository(neo4j_connection)
result = graph_repo.update_relationship(
element_id=relationship_edit.relationshipId,
properties=relationship_edit.data,
sketch_id=sketch_id,
)
except Exception as e:
print(f"Relationship update error: {e}")
raise HTTPException(status_code=500, detail="Failed to update relationship")
if not result:
raise HTTPException(
status_code=404, detail="Relationship not found or not accessible"
)
return {
"status": "relationship updated",
"relationship": {
"id": result["id"],
"label": result["type"],
"data": result["data"],
},
}
class NodePosition(BaseModel):
nodeId: str
x: float
@@ -673,6 +536,49 @@ def delete_relationships(
return {"status": "relationships deleted", "count": deleted_count}
@router.put("/{sketch_id}/relationships/edit")
@update_sketch_timestamp
def edit_relationship(
sketch_id: str,
relationship_edit: RelationshipEditInput,
background_tasks: BackgroundTasks,
db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user),
):
# First verify the sketch exists and belongs to the user
sketch = db.query(Sketch).filter(Sketch.id == sketch_id).first()
if not sketch:
raise HTTPException(status_code=404, detail="Sketch not found")
check_investigation_permission(
current_user.id, sketch.investigation_id, actions=["update"], db=db
)
# Update edge using GraphRepository
try:
graph_repo = GraphRepository(neo4j_connection)
result = graph_repo.update_relationship(
element_id=relationship_edit.relationshipId,
properties=relationship_edit.data,
sketch_id=sketch_id,
)
except Exception as e:
print(f"Relationship update error: {e}")
raise HTTPException(status_code=500, detail="Failed to update relationship")
if not result:
raise HTTPException(
status_code=404, detail="Relationship not found or not accessible"
)
return {
"status": "relationship updated",
"relationship": {
"id": result["id"],
"label": result["type"],
"data": result["data"],
},
}
@router.post("/{sketch_id}/nodes/merge")
@update_sketch_timestamp
@@ -762,9 +668,10 @@ async def analyze_import_file(
current_user: Profile = Depends(get_current_user),
):
"""
Analyze an uploaded TXT file for import.
Analyze an uploaded TXT or XML file for import.
Each line represents one entity. Detects entity types and provides preview.
"""
# Verify sketch exists and user has access
sketch = db.query(Sketch).filter(Sketch.id == sketch_id).first()
if not sketch:
@@ -774,10 +681,10 @@ async def analyze_import_file(
)
# Validate file extension
if not file.filename or not file.filename.lower().endswith(".txt"):
if not file.filename or not file.filename.lower().endswith((".txt", ".json")):
raise HTTPException(
status_code=400,
detail="Only .txt files are supported. Please upload a text file with one value per line.",
detail="Only .txt and .json files are supported. Please upload a correct format.",
)
# Read file content
@@ -790,7 +697,7 @@ async def analyze_import_file(
try:
result = parse_import_file(
file_content=content,
filename=file.filename or "unknown.csv",
filename=file.filename or "unknown.txt",
max_preview_rows=10000000,
)
except ValueError as e:
@@ -808,6 +715,7 @@ class EntityMapping(BaseModel):
entity_type: str
include: bool = True
label: str
node_id: Optional[str] = None
data: Dict[str, Any] # Entity data from frontend
@@ -834,7 +742,6 @@ async def execute_import(
Uses the entity mappings provided by the frontend (no file re-parsing needed).
"""
import json
# Verify sketch exists and user has access
sketch = db.query(Sketch).filter(Sketch.id == sketch_id).first()
if not sketch:
@@ -845,8 +752,10 @@ async def execute_import(
# Parse entity mappings
try:
mappings_data = json.loads(entity_mappings_json)
entity_mappings = [EntityMapping(**m) for m in mappings_data]
mappings = json.loads(entity_mappings_json)
nodes = mappings.get("nodes")
edges = mappings.get("edges")
entity_mappings = [EntityMapping(**m) for m in nodes]
except json.JSONDecodeError:
raise HTTPException(status_code=400, detail="Invalid entity_mappings JSON")
except Exception as e:
@@ -860,11 +769,13 @@ async def execute_import(
# Convert entity mappings to Pydantic objects
pydantic_nodes = []
conversion_errors = []
# Mapping dict to store frontend node_id -> index in pydantic_nodes
nodes_mapping_indices = {}
for idx, mapping in enumerate(entities_to_import):
entity_type = mapping.entity_type
label = mapping.label
entity_data = mapping.data
node_id = mapping.node_id
# Get the Pydantic type from registry
DetectedType = TYPE_REGISTRY.get_lowercase(entity_type)
@@ -882,18 +793,66 @@ async def execute_import(
try:
pydantic_obj = DetectedType(**cleaned_data)
pydantic_nodes.append(pydantic_obj)
if node_id:
# Map frontend node_id to index in pydantic_nodes list
nodes_mapping_indices[node_id] = len(pydantic_nodes) - 1
except Exception as e:
conversion_errors.append(f"Entity {idx + 1} ({label}): {str(e)}")
# Batch create all nodes
# Batch create all nodes first
graph_repo = GraphRepository(neo4j_connection)
try:
result = graph_repo.batch_create_nodes(
nodes_result = graph_repo.batch_create_nodes(
nodes=pydantic_nodes, sketch_id=sketch_id
)
nodes_created = result["nodes_created"]
batch_errors = result.get("errors", [])
all_errors = conversion_errors + batch_errors
nodes_created = nodes_result["nodes_created"]
node_element_ids = nodes_result.get("node_ids", [])
# Now create edges using element IDs
edges_to_insert = []
edge_errors = []
if edges and nodes_mapping_indices and node_element_ids:
for idx, edge in enumerate(edges):
from_id = edge.get("from_id")
to_id = edge.get("to_id")
# Get indices in pydantic_nodes list
from_idx = nodes_mapping_indices.get(from_id)
to_idx = nodes_mapping_indices.get(to_id)
if from_idx is None or to_idx is None:
edge_errors.append(
f"Edge {idx}: Missing source or target node (from: {from_id}, to: {to_id})"
)
continue
# Get corresponding element IDs from created nodes
if from_idx >= len(node_element_ids) or to_idx >= len(node_element_ids):
edge_errors.append(
f"Edge {idx}: Node index out of range (from_idx={from_idx}, to_idx={to_idx}, len={len(node_element_ids)})"
)
continue
from_element_id = node_element_ids[from_idx]
to_element_id = node_element_ids[to_idx]
edge_to_insert = {
"from_element_id": from_element_id,
"to_element_id": to_element_id,
"rel_type": edge.get("label", "RELATED_TO"),
}
edges_to_insert.append(edge_to_insert)
# Create edges using element IDs
if len(edges_to_insert) > 0:
edges_result = graph_repo.batch_create_edges_by_element_id(
edges=edges_to_insert, sketch_id=sketch_id
)
edge_errors.extend(edges_result.get("errors", []))
batch_errors = nodes_result.get("errors", [])
all_errors = conversion_errors + batch_errors + edge_errors
except Exception as e:
raise HTTPException(status_code=500, detail=f"Batch import failed: {str(e)}")
@@ -905,3 +864,75 @@ async def execute_import(
nodes_skipped=nodes_skipped,
errors=all_errors[:50], # Limit to first 50 errors
)
@router.get("/{id}/export")
async def export_sketch(
id: str,
format: str = "json",
db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user),
):
"""
Export the sketch in the specified format.
Args:
id: The ID of the sketch
format: Export format - "json" or "png" (default: "json")
db: The database session
current_user: The current user
Returns:
The sketch data in the requested format
"""
sketch = db.query(Sketch).filter(Sketch.id == id).first()
if not sketch:
raise HTTPException(status_code=404, detail="Sketch not found")
check_investigation_permission(
current_user.id, sketch.investigation_id, actions=["read"], db=db
)
# Get all nodes and relationships using GraphRepository
graph_repo = GraphRepository(neo4j_connection)
graph_data = graph_repo.get_sketch_graph(id, limit=100000)
nodes_result = graph_data["nodes"]
rels_result = graph_data["relationships"]
nodes = [
{
"id": str(record["id"]),
"data": record["data"],
"label": record["data"].get("label", "Node"),
"idx": idx,
**({"x": record["data"]["x"]} if "x" in record["data"] else {}),
**({"y": record["data"]["y"]} if "y" in record["data"] else {}),
}
for idx, record in enumerate(nodes_result)
]
rels = [
{
"id": str(record["id"]),
"source": str(record["source"]),
"target": str(record["target"]),
"data": record["data"],
"label": record["type"],
}
for record in rels_result
]
if format == "json":
from fastapi.responses import JSONResponse
return JSONResponse(
content={
"sketch": {
"id": str(sketch.id),
"title": sketch.title,
"description": sketch.description,
},
"graph": {"nodes": nodes, "edges": rels},
}
)
else:
raise HTTPException(status_code=400, detail=f"Unsupported format: {format}")