feat(api): add permission checks for analysis and sketches

This commit is contained in:
dextmorgn
2025-11-11 20:13:13 +01:00
parent 4b7f1da797
commit 3294065ee3
2 changed files with 73 additions and 17 deletions

View File

@@ -1,4 +1,5 @@
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from app.security.permissions import check_investigation_permission
from fastapi import APIRouter, HTTPException, Depends, status from fastapi import APIRouter, HTTPException, Depends, status
from typing import List from typing import List
from datetime import datetime from datetime import datetime
@@ -29,6 +30,9 @@ def create_analysis(
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user), current_user: Profile = Depends(get_current_user),
): ):
check_investigation_permission(
current_user.id, payload.investigation_id, actions=["create"], db=db
)
new_analysis = Analysis( new_analysis = Analysis(
id=uuid4(), id=uuid4(),
title=payload.title, title=payload.title,
@@ -54,11 +58,14 @@ def get_analysis_by_id(
): ):
analysis = ( analysis = (
db.query(Analysis) db.query(Analysis)
.filter(Analysis.id == analysis_id, Analysis.owner_id == current_user.id) .filter(Analysis.id == analysis_id)
.first() .first()
) )
if not analysis: if not analysis:
raise HTTPException(status_code=404, detail="Analysis not found") raise HTTPException(status_code=404, detail="Analysis not found")
check_investigation_permission(
current_user.id, analysis.investigation_id, actions=["read"], db=db
)
return analysis return analysis
@@ -69,12 +76,12 @@ def get_analyses_by_investigation(
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user), current_user: Profile = Depends(get_current_user),
): ):
check_investigation_permission(
current_user.id, investigation_id, actions=["read"], db=db
)
analyses = ( analyses = (
db.query(Analysis) db.query(Analysis)
.filter( .filter(Analysis.investigation_id == investigation_id)
Analysis.investigation_id == investigation_id,
Analysis.owner_id == current_user.id,
)
.all() .all()
) )
return analyses return analyses
@@ -90,11 +97,14 @@ def update_analysis(
): ):
analysis = ( analysis = (
db.query(Analysis) db.query(Analysis)
.filter(Analysis.id == analysis_id, Analysis.owner_id == current_user.id) .filter(Analysis.id == analysis_id)
.first() .first()
) )
if not analysis: if not analysis:
raise HTTPException(status_code=404, detail="Analysis not found") raise HTTPException(status_code=404, detail="Analysis not found")
check_investigation_permission(
current_user.id, analysis.investigation_id, actions=["update"], db=db
)
if payload.title is not None: if payload.title is not None:
analysis.title = payload.title analysis.title = payload.title
if payload.description is not None: if payload.description is not None:
@@ -102,6 +112,10 @@ def update_analysis(
if payload.content is not None: if payload.content is not None:
analysis.content = payload.content analysis.content = payload.content
if payload.investigation_id is not None: if payload.investigation_id is not None:
# Check permission for the new investigation as well
check_investigation_permission(
current_user.id, payload.investigation_id, actions=["update"], db=db
)
analysis.investigation_id = payload.investigation_id analysis.investigation_id = payload.investigation_id
analysis.last_updated_at = datetime.utcnow() analysis.last_updated_at = datetime.utcnow()
db.commit() db.commit()
@@ -118,11 +132,14 @@ def delete_analysis(
): ):
analysis = ( analysis = (
db.query(Analysis) db.query(Analysis)
.filter(Analysis.id == analysis_id, Analysis.owner_id == current_user.id) .filter(Analysis.id == analysis_id)
.first() .first()
) )
if not analysis: if not analysis:
raise HTTPException(status_code=404, detail="Analysis not found") raise HTTPException(status_code=404, detail="Analysis not found")
check_investigation_permission(
current_user.id, analysis.investigation_id, actions=["delete"], db=db
)
db.delete(analysis) db.delete(analysis)
db.commit() db.commit()
return None return None

View File

@@ -98,6 +98,9 @@ def get_sketch_by_id(
sketch = db.query(Sketch).filter(Sketch.id == sketch_id).first() sketch = db.query(Sketch).filter(Sketch.id == sketch_id).first()
if not sketch: if not sketch:
raise HTTPException(status_code=404, detail="Sketch not found") raise HTTPException(status_code=404, detail="Sketch not found")
check_investigation_permission(
current_user.id, sketch.investigation_id, actions=["read"], db=db
)
return sketch return sketch
@@ -110,12 +113,14 @@ def update_sketch(
): ):
sketch = ( sketch = (
db.query(Sketch) db.query(Sketch)
.filter(Sketch.owner_id == current_user.id)
.filter(Sketch.id == id) .filter(Sketch.id == id)
.first() .first()
) )
if not sketch: if not sketch:
raise HTTPException(status_code=404, detail="Sketch not found") raise HTTPException(status_code=404, detail="Sketch not found")
check_investigation_permission(
current_user.id, sketch.investigation_id, actions=["update"], db=db
)
for key, value in payload.model_dump(exclude_unset=True).items(): for key, value in payload.model_dump(exclude_unset=True).items():
setattr(sketch, key, value) setattr(sketch, key, value)
db.commit() db.commit()
@@ -131,11 +136,14 @@ def delete_sketch(
): ):
sketch = ( sketch = (
db.query(Sketch) db.query(Sketch)
.filter(Sketch.id == id, Sketch.owner_id == current_user.id) .filter(Sketch.id == id)
.first() .first()
) )
if not sketch: if not sketch:
raise HTTPException(status_code=404, detail="Sketch not found") raise HTTPException(status_code=404, detail="Sketch not found")
check_investigation_permission(
current_user.id, sketch.investigation_id, actions=["delete"], db=db
)
# Delete all nodes and relationships in Neo4j first # Delete all nodes and relationships in Neo4j first
neo4j_query = """ neo4j_query = """
@@ -158,7 +166,7 @@ async def get_sketch_nodes(
id: str, id: str,
format: str = None, format: str = None,
db: Session = Depends(get_db), db: Session = Depends(get_db),
# current_user: Profile = Depends(get_current_user) current_user: Profile = Depends(get_current_user)
): ):
""" """
Get the nodes and relationships for a sketch. Get the nodes and relationships for a sketch.
@@ -175,14 +183,14 @@ async def get_sketch_nodes(
""" """
sketch = ( sketch = (
db.query(Sketch) db.query(Sketch)
.filter( .filter(Sketch.id == id)
Sketch.id == id,
# Sketch.owner_id == current_user.id
)
.first() .first()
) )
if not sketch: if not sketch:
raise HTTPException(status_code=404, detail="Graph not found") raise HTTPException(status_code=404, detail="Graph not found")
check_investigation_permission(
current_user.id, sketch.investigation_id, actions=["read"], db=db
)
import random import random
nodes_query = """ nodes_query = """
@@ -241,6 +249,13 @@ def add_node(
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user), current_user: Profile = Depends(get_current_user),
): ):
sketch = db.query(Sketch).filter(Sketch.id == sketch_id).first()
if not sketch:
raise HTTPException(status_code=404, detail="Sketch not found")
check_investigation_permission(
current_user.id, sketch.investigation_id, actions=["update"], db=db
)
node_data = node.data.model_dump() node_data = node.data.model_dump()
node_type = node_data["type"] node_type = node_data["type"]
@@ -308,6 +323,12 @@ def add_edge(
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user), current_user: Profile = Depends(get_current_user),
): ):
sketch = db.query(Sketch).filter(Sketch.id == sketch_id).first()
if not sketch:
raise HTTPException(status_code=404, detail="Sketch not found")
check_investigation_permission(
current_user.id, sketch.investigation_id, actions=["update"], db=db
)
query = f""" query = f"""
MATCH (a) WHERE elementId(a) = $from_id MATCH (a) WHERE elementId(a) = $from_id
@@ -350,6 +371,9 @@ def edit_node(
sketch = db.query(Sketch).filter(Sketch.id == sketch_id).first() sketch = db.query(Sketch).filter(Sketch.id == sketch_id).first()
if not sketch: if not sketch:
raise HTTPException(status_code=404, detail="Sketch not found") raise HTTPException(status_code=404, detail="Sketch not found")
check_investigation_permission(
current_user.id, sketch.investigation_id, actions=["update"], db=db
)
node_data = node_edit.data.model_dump() node_data = node_edit.data.model_dump()
node_type = node_data.get("type", "Node") node_type = node_data.get("type", "Node")
@@ -409,6 +433,9 @@ def delete_nodes(
sketch = db.query(Sketch).filter(Sketch.id == sketch_id).first() sketch = db.query(Sketch).filter(Sketch.id == sketch_id).first()
if not sketch: if not sketch:
raise HTTPException(status_code=404, detail="Sketch not found") raise HTTPException(status_code=404, detail="Sketch not found")
check_investigation_permission(
current_user.id, sketch.investigation_id, actions=["update"], db=db
)
# Delete nodes and their relationships # Delete nodes and their relationships
query = """ query = """
@@ -443,6 +470,9 @@ def merge_nodes(
sketch = db.query(Sketch).filter(Sketch.id == sketch_id).first() sketch = db.query(Sketch).filter(Sketch.id == sketch_id).first()
if not sketch: if not sketch:
raise HTTPException(status_code=404, detail="Sketch not found") raise HTTPException(status_code=404, detail="Sketch not found")
check_investigation_permission(
current_user.id, sketch.investigation_id, actions=["update"], db=db
)
if not oldNodes or len(oldNodes) == 0: if not oldNodes or len(oldNodes) == 0:
raise HTTPException(status_code=400, detail="oldNodes cannot be empty") raise HTTPException(status_code=400, detail="oldNodes cannot be empty")
@@ -576,9 +606,12 @@ def get_related_nodes(
current_user: Profile = Depends(get_current_user) current_user: Profile = Depends(get_current_user)
): ):
# First verify the sketch exists and belongs to the user # First verify the sketch exists and belongs to the user
sketch = db.query(Sketch).filter(Sketch.id == sketch_id, Sketch.owner_id == current_user.id).first() sketch = db.query(Sketch).filter(Sketch.id == sketch_id).first()
if not sketch: if not sketch:
raise HTTPException(status_code=404, detail="Sketch not found") raise HTTPException(status_code=404, detail="Sketch not found")
check_investigation_permission(
current_user.id, sketch.investigation_id, actions=["read"], db=db
)
# Query to get all direct relationships and connected nodes # Query to get all direct relationships and connected nodes
# First, let's get the center node # First, let's get the center node
@@ -743,11 +776,14 @@ async def analyze_import_file(
# Verify sketch exists and user has access # Verify sketch exists and user has access
sketch = ( sketch = (
db.query(Sketch) db.query(Sketch)
.filter(Sketch.id == sketch_id, Sketch.owner_id == current_user.id) .filter(Sketch.id == sketch_id)
.first() .first()
) )
if not sketch: if not sketch:
raise HTTPException(status_code=404, detail="Sketch not found") raise HTTPException(status_code=404, detail="Sketch not found")
check_investigation_permission(
current_user.id, sketch.investigation_id, actions=["read"], db=db
)
# Read file content # Read file content
try: try:
@@ -825,11 +861,14 @@ async def execute_import(
# Verify sketch exists and user has access # Verify sketch exists and user has access
sketch = ( sketch = (
db.query(Sketch) db.query(Sketch)
.filter(Sketch.id == sketch_id, Sketch.owner_id == current_user.id) .filter(Sketch.id == sketch_id)
.first() .first()
) )
if not sketch: if not sketch:
raise HTTPException(status_code=404, detail="Sketch not found") raise HTTPException(status_code=404, detail="Sketch not found")
check_investigation_permission(
current_user.id, sketch.investigation_id, actions=["update"], db=db
)
# Parse entity mappings # Parse entity mappings
try: try: