mirror of
https://github.com/reconurge/flowsint.git
synced 2026-03-09 07:17:07 -05:00
feat(api): add json import support
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, List, Literal
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import (
|
||||
@@ -15,12 +15,10 @@ from flowsint_core.core.graph_db import neo4j_connection
|
||||
from flowsint_core.core.graph_repository import GraphRepository
|
||||
from flowsint_core.core.models import Profile, Sketch
|
||||
from flowsint_core.core.postgre_db import get_db
|
||||
from flowsint_core.core.types import Role
|
||||
from flowsint_core.imports import FileParseResult, parse_import_file
|
||||
from flowsint_core.utils import flatten
|
||||
from flowsint_types import TYPE_REGISTRY
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
@@ -63,14 +61,12 @@ class NodeEditInput(BaseModel):
|
||||
default_factory=NodeData, description="Updated data for the node"
|
||||
)
|
||||
|
||||
|
||||
class RelationshipEditInput(BaseModel):
|
||||
relationshipId: str
|
||||
data: Dict[str, Any] = Field(
|
||||
default_factory=dict, description="Updated data for the relationship"
|
||||
)
|
||||
|
||||
|
||||
class NodeMergeInput(BaseModel):
|
||||
id: str
|
||||
data: NodeData = Field(
|
||||
@@ -84,9 +80,12 @@ def create_sketch(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
sketch_data = data.dict()
|
||||
sketch_data = data.model_dump()
|
||||
investigation_id = sketch_data.get("investigation_id")
|
||||
if not investigation_id:
|
||||
raise HTTPException(status_code=404, detail="Investigation not found")
|
||||
check_investigation_permission(
|
||||
current_user.id, sketch_data.get("investigation_id"), actions=["create"], db=db
|
||||
current_user.id, investigation_id, actions=["create"], db=db
|
||||
)
|
||||
sketch_data["owner_id"] = current_user.id
|
||||
sketch = Sketch(**sketch_data)
|
||||
@@ -100,25 +99,7 @@ def create_sketch(
|
||||
def list_sketches(
|
||||
db: Session = Depends(get_db), current_user: Profile = Depends(get_current_user)
|
||||
):
|
||||
from flowsint_core.core.models import InvestigationUserRole
|
||||
|
||||
# Get all investigations where user has at least VIEWER role
|
||||
allowed_roles_for_read = [Role.OWNER, Role.EDITOR, Role.VIEWER]
|
||||
|
||||
query = db.query(Sketch).join(
|
||||
InvestigationUserRole,
|
||||
InvestigationUserRole.investigation_id == Sketch.investigation_id,
|
||||
)
|
||||
|
||||
query = query.filter(InvestigationUserRole.user_id == current_user.id)
|
||||
|
||||
# Filter by allowed roles
|
||||
conditions = [
|
||||
InvestigationUserRole.roles.any(role) for role in allowed_roles_for_read
|
||||
]
|
||||
query = query.filter(or_(*conditions))
|
||||
|
||||
return query.distinct().all()
|
||||
return db.query(Sketch).filter(Sketch.owner_id == current_user.id).all()
|
||||
|
||||
|
||||
@router.get("/{sketch_id}")
|
||||
@@ -185,7 +166,7 @@ def delete_sketch(
|
||||
@router.get("/{id}/graph")
|
||||
async def get_sketch_nodes(
|
||||
id: str,
|
||||
format: str = None,
|
||||
format: str | None = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
@@ -247,80 +228,6 @@ async def get_sketch_nodes(
|
||||
return {"nds": nodes, "rls": rels}
|
||||
|
||||
|
||||
@router.get("/{id}/export")
|
||||
async def export_sketch(
|
||||
id: str,
|
||||
format: str = "json",
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Export the sketch in the specified format.
|
||||
Args:
|
||||
id: The ID of the sketch
|
||||
format: Export format - "json" or "png" (default: "json")
|
||||
db: The database session
|
||||
current_user: The current user
|
||||
Returns:
|
||||
The sketch data in the requested format
|
||||
"""
|
||||
sketch = db.query(Sketch).filter(Sketch.id == id).first()
|
||||
if not sketch:
|
||||
raise HTTPException(status_code=404, detail="Sketch not found")
|
||||
check_investigation_permission(
|
||||
current_user.id, sketch.investigation_id, actions=["read"], db=db
|
||||
)
|
||||
|
||||
# Get all nodes and relationships using GraphRepository
|
||||
graph_repo = GraphRepository(neo4j_connection)
|
||||
graph_data = graph_repo.get_sketch_graph(id, limit=100000)
|
||||
|
||||
nodes_result = graph_data["nodes"]
|
||||
rels_result = graph_data["relationships"]
|
||||
|
||||
nodes = [
|
||||
{
|
||||
"id": str(record["id"]),
|
||||
"data": record["data"],
|
||||
"label": record["data"].get("label", "Node"),
|
||||
"idx": idx,
|
||||
**({"x": record["data"]["x"]} if "x" in record["data"] else {}),
|
||||
**({"y": record["data"]["y"]} if "y" in record["data"] else {}),
|
||||
}
|
||||
for idx, record in enumerate(nodes_result)
|
||||
]
|
||||
|
||||
rels = [
|
||||
{
|
||||
"id": str(record["id"]),
|
||||
"source": str(record["source"]),
|
||||
"target": str(record["target"]),
|
||||
"data": record["data"],
|
||||
"label": record["type"],
|
||||
}
|
||||
for record in rels_result
|
||||
]
|
||||
|
||||
if format == "json":
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
return JSONResponse(
|
||||
content={
|
||||
"sketch": {
|
||||
"id": str(sketch.id),
|
||||
"title": sketch.title,
|
||||
"description": sketch.description,
|
||||
},
|
||||
"graph": {"nodes": nodes, "edges": rels},
|
||||
}
|
||||
)
|
||||
elif format == "png":
|
||||
# TODO: Implement PNG export
|
||||
raise HTTPException(status_code=501, detail="PNG export not yet implemented")
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail=f"Unsupported format: {format}")
|
||||
|
||||
|
||||
def clean_empty_values(data: dict) -> dict:
|
||||
"""Remove empty string values from dict to avoid Pydantic validation errors."""
|
||||
cleaned = {}
|
||||
@@ -413,7 +320,7 @@ class RelationInput(BaseModel):
|
||||
source: str
|
||||
target: str
|
||||
type: Literal["one-way", "two-way"]
|
||||
label: str = "RELATED_TO" # Optionnel : nom de la relation
|
||||
label: str = "RELATED_TO"
|
||||
|
||||
|
||||
@router.post("/{sketch_id}/relations/add")
|
||||
@@ -519,50 +426,6 @@ def edit_node(
|
||||
}
|
||||
|
||||
|
||||
@router.put("/{sketch_id}/relationships/edit")
|
||||
@update_sketch_timestamp
|
||||
def edit_relationship(
|
||||
sketch_id: str,
|
||||
relationship_edit: RelationshipEditInput,
|
||||
background_tasks: BackgroundTasks,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
# First verify the sketch exists and belongs to the user
|
||||
sketch = db.query(Sketch).filter(Sketch.id == sketch_id).first()
|
||||
if not sketch:
|
||||
raise HTTPException(status_code=404, detail="Sketch not found")
|
||||
check_investigation_permission(
|
||||
current_user.id, sketch.investigation_id, actions=["update"], db=db
|
||||
)
|
||||
|
||||
# Update edge using GraphRepository
|
||||
try:
|
||||
graph_repo = GraphRepository(neo4j_connection)
|
||||
result = graph_repo.update_relationship(
|
||||
element_id=relationship_edit.relationshipId,
|
||||
properties=relationship_edit.data,
|
||||
sketch_id=sketch_id,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Relationship update error: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to update relationship")
|
||||
|
||||
if not result:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Relationship not found or not accessible"
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "relationship updated",
|
||||
"relationship": {
|
||||
"id": result["id"],
|
||||
"label": result["type"],
|
||||
"data": result["data"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class NodePosition(BaseModel):
|
||||
nodeId: str
|
||||
x: float
|
||||
@@ -673,6 +536,49 @@ def delete_relationships(
|
||||
|
||||
return {"status": "relationships deleted", "count": deleted_count}
|
||||
|
||||
@router.put("/{sketch_id}/relationships/edit")
|
||||
@update_sketch_timestamp
|
||||
def edit_relationship(
|
||||
sketch_id: str,
|
||||
relationship_edit: RelationshipEditInput,
|
||||
background_tasks: BackgroundTasks,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
# First verify the sketch exists and belongs to the user
|
||||
sketch = db.query(Sketch).filter(Sketch.id == sketch_id).first()
|
||||
if not sketch:
|
||||
raise HTTPException(status_code=404, detail="Sketch not found")
|
||||
check_investigation_permission(
|
||||
current_user.id, sketch.investigation_id, actions=["update"], db=db
|
||||
)
|
||||
|
||||
# Update edge using GraphRepository
|
||||
try:
|
||||
graph_repo = GraphRepository(neo4j_connection)
|
||||
result = graph_repo.update_relationship(
|
||||
element_id=relationship_edit.relationshipId,
|
||||
properties=relationship_edit.data,
|
||||
sketch_id=sketch_id,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Relationship update error: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to update relationship")
|
||||
|
||||
if not result:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Relationship not found or not accessible"
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "relationship updated",
|
||||
"relationship": {
|
||||
"id": result["id"],
|
||||
"label": result["type"],
|
||||
"data": result["data"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@router.post("/{sketch_id}/nodes/merge")
|
||||
@update_sketch_timestamp
|
||||
@@ -762,9 +668,10 @@ async def analyze_import_file(
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Analyze an uploaded TXT file for import.
|
||||
Analyze an uploaded TXT or XML file for import.
|
||||
Each line represents one entity. Detects entity types and provides preview.
|
||||
"""
|
||||
|
||||
# Verify sketch exists and user has access
|
||||
sketch = db.query(Sketch).filter(Sketch.id == sketch_id).first()
|
||||
if not sketch:
|
||||
@@ -774,10 +681,10 @@ async def analyze_import_file(
|
||||
)
|
||||
|
||||
# Validate file extension
|
||||
if not file.filename or not file.filename.lower().endswith(".txt"):
|
||||
if not file.filename or not file.filename.lower().endswith((".txt", ".json")):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Only .txt files are supported. Please upload a text file with one value per line.",
|
||||
detail="Only .txt and .json files are supported. Please upload a correct format.",
|
||||
)
|
||||
|
||||
# Read file content
|
||||
@@ -790,7 +697,7 @@ async def analyze_import_file(
|
||||
try:
|
||||
result = parse_import_file(
|
||||
file_content=content,
|
||||
filename=file.filename or "unknown.csv",
|
||||
filename=file.filename or "unknown.txt",
|
||||
max_preview_rows=10000000,
|
||||
)
|
||||
except ValueError as e:
|
||||
@@ -808,6 +715,7 @@ class EntityMapping(BaseModel):
|
||||
entity_type: str
|
||||
include: bool = True
|
||||
label: str
|
||||
node_id: Optional[str] = None
|
||||
data: Dict[str, Any] # Entity data from frontend
|
||||
|
||||
|
||||
@@ -834,7 +742,6 @@ async def execute_import(
|
||||
Uses the entity mappings provided by the frontend (no file re-parsing needed).
|
||||
"""
|
||||
import json
|
||||
|
||||
# Verify sketch exists and user has access
|
||||
sketch = db.query(Sketch).filter(Sketch.id == sketch_id).first()
|
||||
if not sketch:
|
||||
@@ -845,8 +752,10 @@ async def execute_import(
|
||||
|
||||
# Parse entity mappings
|
||||
try:
|
||||
mappings_data = json.loads(entity_mappings_json)
|
||||
entity_mappings = [EntityMapping(**m) for m in mappings_data]
|
||||
mappings = json.loads(entity_mappings_json)
|
||||
nodes = mappings.get("nodes")
|
||||
edges = mappings.get("edges")
|
||||
entity_mappings = [EntityMapping(**m) for m in nodes]
|
||||
except json.JSONDecodeError:
|
||||
raise HTTPException(status_code=400, detail="Invalid entity_mappings JSON")
|
||||
except Exception as e:
|
||||
@@ -860,11 +769,13 @@ async def execute_import(
|
||||
# Convert entity mappings to Pydantic objects
|
||||
pydantic_nodes = []
|
||||
conversion_errors = []
|
||||
|
||||
# Mapping dict to store frontend node_id -> index in pydantic_nodes
|
||||
nodes_mapping_indices = {}
|
||||
for idx, mapping in enumerate(entities_to_import):
|
||||
entity_type = mapping.entity_type
|
||||
label = mapping.label
|
||||
entity_data = mapping.data
|
||||
node_id = mapping.node_id
|
||||
|
||||
# Get the Pydantic type from registry
|
||||
DetectedType = TYPE_REGISTRY.get_lowercase(entity_type)
|
||||
@@ -882,18 +793,66 @@ async def execute_import(
|
||||
try:
|
||||
pydantic_obj = DetectedType(**cleaned_data)
|
||||
pydantic_nodes.append(pydantic_obj)
|
||||
if node_id:
|
||||
# Map frontend node_id to index in pydantic_nodes list
|
||||
nodes_mapping_indices[node_id] = len(pydantic_nodes) - 1
|
||||
|
||||
except Exception as e:
|
||||
conversion_errors.append(f"Entity {idx + 1} ({label}): {str(e)}")
|
||||
|
||||
# Batch create all nodes
|
||||
# Batch create all nodes first
|
||||
graph_repo = GraphRepository(neo4j_connection)
|
||||
|
||||
try:
|
||||
result = graph_repo.batch_create_nodes(
|
||||
nodes_result = graph_repo.batch_create_nodes(
|
||||
nodes=pydantic_nodes, sketch_id=sketch_id
|
||||
)
|
||||
nodes_created = result["nodes_created"]
|
||||
batch_errors = result.get("errors", [])
|
||||
all_errors = conversion_errors + batch_errors
|
||||
nodes_created = nodes_result["nodes_created"]
|
||||
node_element_ids = nodes_result.get("node_ids", [])
|
||||
# Now create edges using element IDs
|
||||
edges_to_insert = []
|
||||
edge_errors = []
|
||||
if edges and nodes_mapping_indices and node_element_ids:
|
||||
for idx, edge in enumerate(edges):
|
||||
from_id = edge.get("from_id")
|
||||
to_id = edge.get("to_id")
|
||||
|
||||
# Get indices in pydantic_nodes list
|
||||
from_idx = nodes_mapping_indices.get(from_id)
|
||||
to_idx = nodes_mapping_indices.get(to_id)
|
||||
|
||||
if from_idx is None or to_idx is None:
|
||||
edge_errors.append(
|
||||
f"Edge {idx}: Missing source or target node (from: {from_id}, to: {to_id})"
|
||||
)
|
||||
continue
|
||||
|
||||
# Get corresponding element IDs from created nodes
|
||||
if from_idx >= len(node_element_ids) or to_idx >= len(node_element_ids):
|
||||
edge_errors.append(
|
||||
f"Edge {idx}: Node index out of range (from_idx={from_idx}, to_idx={to_idx}, len={len(node_element_ids)})"
|
||||
)
|
||||
continue
|
||||
|
||||
from_element_id = node_element_ids[from_idx]
|
||||
to_element_id = node_element_ids[to_idx]
|
||||
|
||||
edge_to_insert = {
|
||||
"from_element_id": from_element_id,
|
||||
"to_element_id": to_element_id,
|
||||
"rel_type": edge.get("label", "RELATED_TO"),
|
||||
}
|
||||
edges_to_insert.append(edge_to_insert)
|
||||
|
||||
# Create edges using element IDs
|
||||
if len(edges_to_insert) > 0:
|
||||
edges_result = graph_repo.batch_create_edges_by_element_id(
|
||||
edges=edges_to_insert, sketch_id=sketch_id
|
||||
)
|
||||
edge_errors.extend(edges_result.get("errors", []))
|
||||
|
||||
batch_errors = nodes_result.get("errors", [])
|
||||
all_errors = conversion_errors + batch_errors + edge_errors
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Batch import failed: {str(e)}")
|
||||
|
||||
@@ -905,3 +864,75 @@ async def execute_import(
|
||||
nodes_skipped=nodes_skipped,
|
||||
errors=all_errors[:50], # Limit to first 50 errors
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{id}/export")
|
||||
async def export_sketch(
|
||||
id: str,
|
||||
format: str = "json",
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Export the sketch in the specified format.
|
||||
Args:
|
||||
id: The ID of the sketch
|
||||
format: Export format - "json" or "png" (default: "json")
|
||||
db: The database session
|
||||
current_user: The current user
|
||||
Returns:
|
||||
The sketch data in the requested format
|
||||
"""
|
||||
sketch = db.query(Sketch).filter(Sketch.id == id).first()
|
||||
if not sketch:
|
||||
raise HTTPException(status_code=404, detail="Sketch not found")
|
||||
check_investigation_permission(
|
||||
current_user.id, sketch.investigation_id, actions=["read"], db=db
|
||||
)
|
||||
|
||||
# Get all nodes and relationships using GraphRepository
|
||||
graph_repo = GraphRepository(neo4j_connection)
|
||||
graph_data = graph_repo.get_sketch_graph(id, limit=100000)
|
||||
|
||||
nodes_result = graph_data["nodes"]
|
||||
rels_result = graph_data["relationships"]
|
||||
|
||||
nodes = [
|
||||
{
|
||||
"id": str(record["id"]),
|
||||
"data": record["data"],
|
||||
"label": record["data"].get("label", "Node"),
|
||||
"idx": idx,
|
||||
**({"x": record["data"]["x"]} if "x" in record["data"] else {}),
|
||||
**({"y": record["data"]["y"]} if "y" in record["data"] else {}),
|
||||
}
|
||||
for idx, record in enumerate(nodes_result)
|
||||
]
|
||||
|
||||
rels = [
|
||||
{
|
||||
"id": str(record["id"]),
|
||||
"source": str(record["source"]),
|
||||
"target": str(record["target"]),
|
||||
"data": record["data"],
|
||||
"label": record["type"],
|
||||
}
|
||||
for record in rels_result
|
||||
]
|
||||
|
||||
if format == "json":
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
return JSONResponse(
|
||||
content={
|
||||
"sketch": {
|
||||
"id": str(sketch.id),
|
||||
"title": sketch.title,
|
||||
"description": sketch.description,
|
||||
},
|
||||
"graph": {"nodes": nodes, "edges": rels},
|
||||
}
|
||||
)
|
||||
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail=f"Unsupported format: {format}")
|
||||
|
||||
Reference in New Issue
Block a user