feat(api): dedicated sketch event route for refresh

This commit is contained in:
dextmorgn
2025-11-27 22:20:28 +01:00
parent 81a779cd73
commit 6412b1afe8
2 changed files with 71 additions and 36 deletions

View File

@@ -148,6 +148,59 @@ def delete_scan_logs(
raise HTTPException(status_code=500, detail=f"Failed to delete logs: {str(e)}")
@router.get("/sketch/{sketch_id}/status/stream")
async def stream_sketch_status(
request: Request,
sketch_id: str,
db: Session = Depends(get_db),
):
"""Stream COMPLETED events for a specific sketch (for graph refresh)"""
# Check if sketch exists
sketch = db.query(Sketch).filter(Sketch.id == sketch_id).first()
if not sketch:
raise HTTPException(
status_code=404, detail=f"Sketch with id {sketch_id} not found"
)
async def status_generator():
channel = f"{sketch_id}_status"
await event_emitter.subscribe(channel)
try:
# Initial connection message
yield json.dumps({"event": "connected", "data": "Connected to status stream"})
while True:
if await request.is_disconnected():
break
data = await event_emitter.get_message(channel)
if data is None:
await asyncio.sleep(0.1)
continue
# Send status event
yield json.dumps({"event": "status", "data": data})
await asyncio.sleep(0.1)
except asyncio.CancelledError:
print(f"[EventEmitter] Client disconnected from status stream for sketch_id: {sketch_id}")
except Exception as e:
print(f"[EventEmitter] Error in stream_sketch_status: {str(e)}")
finally:
await event_emitter.unsubscribe(channel)
return EventSourceResponse(
status_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)
@router.get("/status/scan/{scan_id}/stream")
async def stream_status(request: Request, scan_id: str, db: Session = Depends(get_db)):
"""Stream status updates for a specific scan in real-time"""

View File

@@ -43,6 +43,9 @@ class NodeInput(BaseModel):
default_factory=NodeData, description="Additional data for the node"
)
class Config:
extra = "allow" # Accept any additional fields
def dict_to_cypher_props(props: dict, prefix: str = "") -> str:
return ", ".join(f"{key}: ${prefix}{key}" for key in props)
@@ -117,11 +120,7 @@ def update_sketch(
db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user),
):
sketch = (
db.query(Sketch)
.filter(Sketch.id == id)
.first()
)
sketch = db.query(Sketch).filter(Sketch.id == id).first()
if not sketch:
raise HTTPException(status_code=404, detail="Sketch not found")
check_investigation_permission(
@@ -140,11 +139,7 @@ def delete_sketch(
db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user),
):
sketch = (
db.query(Sketch)
.filter(Sketch.id == id)
.first()
)
sketch = db.query(Sketch).filter(Sketch.id == id).first()
if not sketch:
raise HTTPException(status_code=404, detail="Sketch not found")
check_investigation_permission(
@@ -169,7 +164,7 @@ async def get_sketch_nodes(
id: str,
format: str = None,
db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user)
current_user: Profile = Depends(get_current_user),
):
"""
Get the nodes and relationships for a sketch.
@@ -184,11 +179,7 @@ async def get_sketch_nodes(
rls: []
Or if format=inline: List of inline relationship strings
"""
sketch = (
db.query(Sketch)
.filter(Sketch.id == id)
.first()
)
sketch = db.query(Sketch).filter(Sketch.id == id).first()
if not sketch:
raise HTTPException(status_code=404, detail="Graph not found")
check_investigation_permission(
@@ -269,7 +260,7 @@ def add_node(
# Add created_at to parameters
properties_with_timestamp = {
**properties,
"created_at": datetime.now(timezone.utc).isoformat()
"created_at": datetime.now(timezone.utc).isoformat(),
}
create_query = f"""
@@ -337,7 +328,7 @@ def add_edge(
from_element_id=relation.source,
to_element_id=relation.target,
rel_type=relation.label,
sketch_id=sketch_id
sketch_id=sketch_id,
)
except Exception as e:
print(f"Edge creation error: {e}")
@@ -391,9 +382,7 @@ def edit_node(
try:
graph_repo = GraphRepository(neo4j_connection)
updated_node = graph_repo.update_node_by_element_id(
element_id=node_edit.nodeId,
sketch_id=sketch_id,
**properties
element_id=node_edit.nodeId, sketch_id=sketch_id, **properties
)
except Exception as e:
print(f"Node update error: {e}")
@@ -451,8 +440,7 @@ def update_node_positions(
try:
graph_repo = GraphRepository(neo4j_connection)
updated_count = graph_repo.update_nodes_positions(
positions=positions,
sketch_id=sketch_id
positions=positions, sketch_id=sketch_id
)
except Exception as e:
print(f"Position update error: {e}")
@@ -512,7 +500,9 @@ def delete_relationships(
# Delete relationships using GraphRepository
try:
graph_repo = GraphRepository(neo4j_connection)
deleted_count = graph_repo.delete_relationships(relationships.relationshipIds, sketch_id)
deleted_count = graph_repo.delete_relationships(
relationships.relationshipIds, sketch_id
)
except Exception as e:
print(f"Relationship deletion error: {e}")
raise HTTPException(status_code=500, detail="Failed to delete relationships")
@@ -669,7 +659,7 @@ def get_related_nodes(
sketch_id: str,
node_id: str,
db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user)
current_user: Profile = Depends(get_current_user),
):
# First verify the sketch exists and belongs to the user
sketch = db.query(Sketch).filter(Sketch.id == sketch_id).first()
@@ -840,11 +830,7 @@ async def analyze_import_file(
Each row represents one entity. Detects entity types and provides preview.
"""
# Verify sketch exists and user has access
sketch = (
db.query(Sketch)
.filter(Sketch.id == sketch_id)
.first()
)
sketch = db.query(Sketch).filter(Sketch.id == sketch_id).first()
if not sketch:
raise HTTPException(status_code=404, detail="Sketch not found")
check_investigation_permission(
@@ -925,11 +911,7 @@ async def execute_import(
import json
# Verify sketch exists and user has access
sketch = (
db.query(Sketch)
.filter(Sketch.id == sketch_id)
.first()
)
sketch = db.query(Sketch).filter(Sketch.id == sketch_id).first()
if not sketch:
raise HTTPException(status_code=404, detail="Sketch not found")
check_investigation_permission(
@@ -990,7 +972,7 @@ async def execute_import(
node_type=entity_type,
label=label,
sketch_id=sketch_id,
**flattened_data
**flattened_data,
)
if node_id:
nodes_created += 1