refactor(api): better batch import

This commit is contained in:
dextmorgn
2025-12-05 15:47:43 +01:00
parent 79089f7a2d
commit 5bb7fae4bd

View File

@@ -9,6 +9,7 @@ from fastapi import (
Form,
BackgroundTasks,
)
from flowsint_types import TYPE_REGISTRY
from pydantic import BaseModel, Field
from typing import Literal, List, Optional, Dict, Any
from datetime import datetime, timezone
@@ -29,12 +30,10 @@ router = APIRouter()
class NodeData(BaseModel):
    # Payload carried by a graph node. Three fields are declared explicitly,
    # each with a default; anything else the client sends is kept as well
    # (see Config below).
    label: str = Field(
        default="Node",
        description="Label/name of the node",
    )
    # NOTE(review): the default colour is the literal string "Node", same as
    # label/type — looks copy-pasted; confirm the intended default.
    color: str = Field(
        default="Node",
        description="Color of the node",
    )
    type: str = Field(
        default="Node",
        description="Type of the node",
    )

    class Config:
        # Undeclared fields are accepted instead of being rejected.
        extra = "allow"
class NodeInput(BaseModel):
@@ -47,10 +46,6 @@ class NodeInput(BaseModel):
extra = "allow" # Accept any additional fields
def dict_to_cypher_props(props: dict, prefix: str = "") -> str:
    """Render the keys of *props* as the inside of a Cypher property map.

    Each key ``k`` becomes ``k: $<prefix>k``; entries are joined with ``", "``.
    Only the keys are used — the values are expected to be supplied
    separately as query parameters under the ``<prefix>k`` names.

    NOTE(review): keys are interpolated verbatim, so callers must only pass
    keys that are safe Cypher identifiers.
    """
    placeholders = [f"{name}: ${prefix}{name}" for name in props]
    return ", ".join(placeholders)
class NodeDeleteInput(BaseModel):
nodeIds: List[str]
@@ -224,6 +219,29 @@ async def get_sketch_nodes(
return {"nds": nodes, "rls": rels}
def _clean_value(value: Any) -> Any:
    """Return *value* with empty content stripped, or ``None`` if nothing remains.

    ``None`` doubles as the "drop this entry" marker; that is unambiguous
    because ``None`` is itself one of the values being removed.
    """
    if value is None or value == "":
        return None
    if isinstance(value, dict):
        # An all-empty nested dict collapses to {} and is dropped entirely.
        return clean_empty_values(value) or None
    if isinstance(value, list):
        kept = [item for item in map(_clean_value, value) if item is not None]
        return kept or None
    # Falsy-but-meaningful scalars (0, False, ...) are preserved as-is.
    return value


def clean_empty_values(data: dict) -> dict:
    """Return a copy of *data* without empty values, recursively.

    Removed: ``None``, empty strings, and any dict or list that ends up empty
    after its own contents have been cleaned. The same rule applies at every
    nesting level, including dicts and lists that appear as list items, so no
    empty ``{}`` / ``[]`` placeholder is left behind for Pydantic to reject.

    Values such as ``0`` and ``False`` are kept. The input is not mutated.

    Args:
        data: Mapping to clean (typically a dumped node/entity payload).

    Returns:
        A new dict containing only the non-empty entries of *data*.
    """
    cleaned = {}
    for key, value in data.items():
        kept = _clean_value(value)
        if kept is not None:
            cleaned[key] = kept
    return cleaned
@router.post("/{sketch_id}/nodes/add")
@update_sketch_timestamp
def add_node(
@@ -241,60 +259,41 @@ def add_node(
)
node_data = node.data.model_dump()
node_type = node_data["type"]
properties = {
"type": node_type.lower(),
"sketch_id": sketch_id,
"caption": node_data["label"],
"label": node_data["label"],
}
DetectedType = TYPE_REGISTRY.get_lowercase(node_type)
if not DetectedType:
raise HTTPException(status_code=400, detail=f"Unknown type: {node_type}")
if node_data:
flattened_data = flatten(node_data)
properties.update(flattened_data)
cypher_props = dict_to_cypher_props(properties)
# Add created_at to parameters
properties_with_timestamp = {
**properties,
"created_at": datetime.now(timezone.utc).isoformat(),
}
create_query = f"""
MERGE (d:`{node_type}` {{ {cypher_props} }})
ON CREATE SET d.created_at = $created_at
RETURN d as node, elementId(d) as id
"""
cleaned_data = clean_empty_values(node_data)
try:
create_result = neo4j_connection.query(create_query, properties_with_timestamp)
pydantic_obj = DetectedType(**cleaned_data)
except Exception as e:
print(f"Pydantic validation error: {e}")
raise HTTPException(
status_code=400, detail=f"Invalid data for type {node_type}: {str(e)}"
)
try:
graph_repo = GraphRepository(neo4j_connection)
element_id = graph_repo.create_node(pydantic_obj, sketch_id=sketch_id)
except Exception as e:
print(f"Query execution error: {e}")
raise HTTPException(status_code=500, detail=f"Database error: {str(e)}")
if not create_result:
raise HTTPException(
status_code=400, detail="Node creation failed - no result returned"
)
if not element_id:
raise HTTPException(status_code=400, detail="Node creation failed")
try:
new_node = create_result[0]["node"]
new_node["id"] = create_result[0]["id"]
except (IndexError, KeyError) as e:
print(f"Error extracting node_id: {e}, result: {create_result}")
raise HTTPException(
status_code=500, detail="Failed to extract node data from response"
)
new_node["data"] = node_data
new_node["data"]["id"] = new_node["id"]
pydantic_data = pydantic_obj.model_dump(mode="json")
pydantic_data["id"] = element_id
pydantic_data["type"] = node_type
return {
"status": "node added",
"node": new_node,
"node": {
"id": element_id,
"data": pydantic_data,
},
}
@@ -363,39 +362,48 @@ def edit_node(
node_data = node_edit.data.model_dump()
node_type = node_data.get("type", "Node")
# Prepare properties to update
properties = {
"type": node_type.lower(),
"caption": node_data.get("label", "Node"),
"label": node_data.get("label", "Node"),
}
# Get the Pydantic type from registry
DetectedType = TYPE_REGISTRY.get_lowercase(node_type)
if not DetectedType:
raise HTTPException(status_code=400, detail=f"Unknown type: {node_type}")
# Add any additional data from the flattened node_data
if node_data:
flattened_data = flatten(node_data)
properties.update(flattened_data)
# Clean empty values
cleaned_data = clean_empty_values(node_data)
# Remove sketch_id from properties to avoid conflict (it's passed separately for security)
properties.pop("sketch_id", None)
# Convert to Pydantic object
try:
pydantic_obj = DetectedType(**cleaned_data)
except Exception as e:
print(f"Pydantic validation error: {e}")
raise HTTPException(
status_code=400, detail=f"Invalid data for type {node_type}: {str(e)}"
)
# Update node using GraphRepository
try:
graph_repo = GraphRepository(neo4j_connection)
updated_node = graph_repo.update_node_by_element_id(
element_id=node_edit.nodeId, sketch_id=sketch_id, **properties
updated_element_id = graph_repo.update_node(
element_id=node_edit.nodeId,
node_obj=pydantic_obj,
sketch_id=sketch_id,
)
except Exception as e:
print(f"Node update error: {e}")
raise HTTPException(status_code=500, detail="Failed to update node")
if not updated_node:
if not updated_element_id:
raise HTTPException(status_code=404, detail="Node not found or not accessible")
updated_node["data"] = node_data
# Return updated node with its data
pydantic_data = pydantic_obj.model_dump(mode="json")
pydantic_data["id"] = updated_element_id
return {
"status": "node updated",
"node": updated_node,
"node": {
"id": updated_element_id,
"data": pydantic_data,
},
}
@@ -520,7 +528,6 @@ def merge_nodes(
db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user),
):
# 1. Verify the sketch exists
sketch = db.query(Sketch).filter(Sketch.id == sketch_id).first()
if not sketch:
raise HTTPException(status_code=404, detail="Sketch not found")
@@ -531,125 +538,35 @@ def merge_nodes(
if not oldNodes or len(oldNodes) == 0:
raise HTTPException(status_code=400, detail="oldNodes cannot be empty")
oldNodeIds = oldNodes
# 2. Prepare the merged node data
node_data = newNode.data.model_dump() if newNode.data else {}
node_type = node_data.get("type", "Node")
# Build properties for the new merged node
properties = {
"type": node_type.lower(),
"sketch_id": sketch_id,
"label": node_data.get("label", "Merged Node"),
"caption": node_data.get("label", "Merged Node"),
}
# Add all other data from the node
flattened_data = flatten(node_data)
properties.update(flattened_data)
# 3. Check if the newNode.id is one of the old nodes (reusing existing node)
# or if we need to create a brand new node
is_reusing_node = newNode.id in oldNodeIds
if is_reusing_node:
# Update the existing node that we're keeping
set_clause = ", ".join(f"n.{key} = ${key}" for key in properties.keys())
create_query = f"""
MATCH (n)
WHERE elementId(n) = $nodeId AND n.sketch_id = $sketch_id
SET {set_clause}
RETURN elementId(n) as newElementId
"""
params = {"nodeId": newNode.id, "sketch_id": sketch_id, **properties}
else:
# Create a completely new node with created_at timestamp
properties["created_at"] = datetime.now(timezone.utc).isoformat()
create_query = f"""
CREATE (n:`{node_type}`)
SET n = $properties
RETURN elementId(n) as newElementId
"""
params = {"properties": properties}
try:
result = neo4j_connection.query(create_query, params)
if not result:
raise HTTPException(
status_code=500, detail="Failed to create/update merged node"
)
new_node_element_id = result[0]["newElementId"]
except Exception as e:
print(f"Error creating/updating merged node: {e}")
raise HTTPException(
status_code=500, detail=f"Failed to create merged node: {str(e)}"
)
# 4. Copy all relationships from old nodes to the new node
# This handles both incoming and outgoing relationships while preserving types and properties
copy_relationships_query = """
MATCH (new) WHERE elementId(new) = $newElementId
UNWIND $oldNodeIds AS oldNodeId
MATCH (old) WHERE elementId(old) = oldNodeId AND old.sketch_id = $sketch_id
// Copy incoming relationships - get all unique combinations
WITH new, collect(old) as oldNodes
UNWIND oldNodes as old
MATCH (src)-[r]->(old)
WHERE elementId(src) NOT IN $oldNodeIds AND elementId(src) <> $newElementId
WITH new, src, type(r) as relType, properties(r) as relProps, r
MERGE (src)-[newRel:RELATED_TO {sketch_id: $sketch_id}]->(new)
SET newRel = relProps
WITH new, $oldNodeIds as oldNodeIds
UNWIND oldNodeIds AS oldNodeId
MATCH (old) WHERE elementId(old) = oldNodeId AND old.sketch_id = $sketch_id
// Copy outgoing relationships
MATCH (old)-[r]->(dst)
WHERE elementId(dst) NOT IN oldNodeIds AND elementId(dst) <> $newElementId
WITH new, dst, type(r) as relType, properties(r) as relProps
MERGE (new)-[newRel:RELATED_TO {sketch_id: $sketch_id}]->(dst)
SET newRel = relProps
"""
try:
neo4j_connection.query(
copy_relationships_query,
{
"newElementId": new_node_element_id,
"oldNodeIds": oldNodeIds,
"sketch_id": sketch_id,
},
graph_repo = GraphRepository(neo4j_connection)
new_node_element_id = graph_repo.merge_nodes(
old_node_ids=oldNodes,
new_node_data=properties,
new_node_id=newNode.id,
sketch_id=sketch_id,
)
except Exception as e:
print(f"Error copying relationships: {e}")
# Don't fail if relationship copying has issues, continue to deletion
print(f"Node merge error: {e}")
raise HTTPException(status_code=500, detail=f"Failed to merge nodes: {str(e)}")
# 5. Delete the old nodes (except if we're reusing one)
nodes_to_delete = [nid for nid in oldNodeIds if nid != new_node_element_id]
if nodes_to_delete:
delete_query = """
UNWIND $nodeIds AS nodeId
MATCH (old)
WHERE elementId(old) = nodeId AND old.sketch_id = $sketch_id
DETACH DELETE old
"""
try:
neo4j_connection.query(
delete_query, {"nodeIds": nodes_to_delete, "sketch_id": sketch_id}
)
except Exception as e:
print(f"Error deleting old nodes: {e}")
raise HTTPException(status_code=500, detail="Failed to delete old nodes")
if not new_node_element_id:
raise HTTPException(status_code=500, detail="Failed to merge nodes")
return {
"status": "nodes merged",
"count": len(oldNodeIds),
"count": len(oldNodes),
"new_node_id": new_node_element_id,
}
@@ -661,7 +578,6 @@ def get_related_nodes(
db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user),
):
# First verify the sketch exists and belongs to the user
sketch = db.query(Sketch).filter(Sketch.id == sketch_id).first()
if not sketch:
raise HTTPException(status_code=404, detail="Sketch not found")
@@ -669,134 +585,17 @@ def get_related_nodes(
current_user.id, sketch.investigation_id, actions=["read"], db=db
)
# Query to get all direct relationships and connected nodes
# First, let's get the center node
center_query = """
MATCH (n)
WHERE elementId(n) = $node_id AND n.sketch_id = $sketch_id
RETURN elementId(n) as id, labels(n) as labels, properties(n) as data
"""
try:
center_result = neo4j_connection.query(
center_query, {"sketch_id": sketch_id, "node_id": node_id}
)
except Exception as e:
print(f"Center node query error: {e}")
raise HTTPException(status_code=500, detail="Failed to retrieve center node")
if not center_result:
raise HTTPException(status_code=404, detail="Node not found")
# Now get all relationships and connected nodes
relationships_query = """
MATCH (n)
WHERE elementId(n) = $node_id AND n.sketch_id = $sketch_id
OPTIONAL MATCH (n)-[r]->(other)
WHERE other.sketch_id = $sketch_id
OPTIONAL MATCH (other)-[r2]->(n)
WHERE other.sketch_id = $sketch_id
RETURN
elementId(r) as rel_id,
type(r) as rel_type,
properties(r) as rel_data,
elementId(other) as other_node_id,
labels(other) as other_node_labels,
properties(other) as other_node_data,
'outgoing' as direction
UNION
MATCH (n)
WHERE elementId(n) = $node_id AND n.sketch_id = $sketch_id
OPTIONAL MATCH (other)-[r]->(n)
WHERE other.sketch_id = $sketch_id
RETURN
elementId(r) as rel_id,
type(r) as rel_type,
properties(r) as rel_data,
elementId(other) as other_node_id,
labels(other) as other_node_labels,
properties(other) as other_node_data,
'incoming' as direction
"""
try:
result = neo4j_connection.query(
relationships_query, {"sketch_id": sketch_id, "node_id": node_id}
)
graph_repo = GraphRepository(neo4j_connection)
result = graph_repo.get_related_nodes(node_id=node_id, sketch_id=sketch_id)
except Exception as e:
print(f"Related nodes query error: {e}")
raise HTTPException(status_code=500, detail="Failed to retrieve related nodes")
# Extract center node info
center_record = center_result[0]
center_node = {
"id": center_record["id"],
"labels": center_record["labels"],
"data": center_record["data"],
"label": center_record["data"].get("label", "Node"),
"type": "custom",
"caption": center_record["data"].get("label", "Node"),
}
if not result["nds"]:
raise HTTPException(status_code=404, detail="Node not found")
# Collect all related nodes and relationships
related_nodes = []
relationships = []
seen_nodes = set()
seen_relationships = set()
for record in result:
# Skip if no relationship found
if not record["rel_id"]:
continue
# Add relationship if not seen
if record["rel_id"] not in seen_relationships:
if record["direction"] == "outgoing":
relationships.append(
{
"id": record["rel_id"],
"type": "straight",
"source": center_node["id"],
"target": record["other_node_id"],
"data": record["rel_data"],
"caption": record["rel_type"],
}
)
else: # incoming
relationships.append(
{
"id": record["rel_id"],
"type": "straight",
"source": record["other_node_id"],
"target": center_node["id"],
"data": record["rel_data"],
"caption": record["rel_type"],
}
)
seen_relationships.add(record["rel_id"])
# Add related node if not seen
if (
record["other_node_id"]
and record["other_node_id"] not in seen_nodes
and record["other_node_id"] != center_node["id"]
):
related_nodes.append(
{
"id": record["other_node_id"],
"labels": record["other_node_labels"],
"data": record["other_node_data"],
"label": record["other_node_data"].get("label", "Node"),
"type": "custom",
"caption": record["other_node_data"].get("label", "Node"),
}
)
seen_nodes.add(record["other_node_id"])
# Combine center node with related nodes
all_nodes = [center_node] + related_nodes
return {"nds": all_nodes, "rls": relationships}
return result
@router.post("/{sketch_id}/import/analyze", response_model=FileParseResult)
@@ -902,47 +701,47 @@ async def execute_import(
# Filter only entities marked for inclusion
entities_to_import = [m for m in entity_mappings if m.include]
# Import entities using GraphRepository
graph_repo = GraphRepository(neo4j_connection)
nodes_created = 0
nodes_skipped = 0
errors = []
# Convert entity mappings to Pydantic objects
pydantic_nodes = []
conversion_errors = []
for idx, mapping in enumerate(entities_to_import):
# Use data from mapping directly
entity_type = mapping.entity_type
label = mapping.label
entity_data = mapping.data
# Flatten entity data for storage
flattened_data = flatten(entity_data)
# Get the Pydantic type from registry
DetectedType = TYPE_REGISTRY.get_lowercase(entity_type)
if not DetectedType:
conversion_errors.append(f"Entity {idx + 1} ({label}): Unknown type {entity_type}")
continue
# Remove fields that are passed as explicit parameters to avoid conflicts
flattened_data.pop('label', None)
flattened_data.pop('type', None)
flattened_data.pop('sketch_id', None)
# Clean empty values and add label
cleaned_data = clean_empty_values(entity_data)
cleaned_data["label"] = label
# Create node using GraphRepository
# Convert to Pydantic object
try:
node_id = graph_repo.create_node_from_import(
node_type=entity_type,
label=label,
sketch_id=sketch_id,
**flattened_data,
)
if node_id:
nodes_created += 1
else:
nodes_skipped += 1
errors.append(f"Entity {idx + 1} ({label}): Failed to create node")
pydantic_obj = DetectedType(**cleaned_data)
pydantic_nodes.append(pydantic_obj)
except Exception as e:
error_msg = f"Entity {idx + 1} ({label}): {str(e)}"
errors.append(error_msg)
nodes_skipped += 1
conversion_errors.append(f"Entity {idx + 1} ({label}): {str(e)}")
# Batch create all nodes
graph_repo = GraphRepository(neo4j_connection)
try:
result = graph_repo.batch_create_nodes(nodes=pydantic_nodes, sketch_id=sketch_id)
nodes_created = result["nodes_created"]
batch_errors = result.get("errors", [])
all_errors = conversion_errors + batch_errors
except Exception as e:
raise HTTPException(status_code=500, detail=f"Batch import failed: {str(e)}")
nodes_skipped = len(entities_to_import) - nodes_created
return ImportExecuteResponse(
status="completed" if not errors else "completed_with_errors",
status="completed" if not all_errors else "completed_with_errors",
nodes_created=nodes_created,
nodes_skipped=nodes_skipped,
errors=errors[:50], # Limit to first 50 errors
errors=all_errors[:50], # Limit to first 50 errors
)