feat: python formatter finally working

This commit is contained in:
dextmorgn
2025-08-13 20:08:06 +02:00
parent 4d5f96bb8d
commit 7620a6d145
148 changed files with 6777 additions and 3269 deletions

View File

@@ -10,15 +10,25 @@ from app.api.schemas.analysis import AnalysisRead, AnalysisCreate, AnalysisUpdat
router = APIRouter()
# Get the list of all analyses for the current user
@router.get("", response_model=List[AnalysisRead])
def get_analyses(db: Session = Depends(get_db), current_user: Profile = Depends(get_current_user)):
def get_analyses(
db: Session = Depends(get_db), current_user: Profile = Depends(get_current_user)
):
analyses = db.query(Analysis).filter(Analysis.owner_id == current_user.id).all()
return analyses
# Create a new analysis
@router.post("/create", response_model=AnalysisRead, status_code=status.HTTP_201_CREATED)
def create_analysis(payload: AnalysisCreate, db: Session = Depends(get_db), current_user: Profile = Depends(get_current_user)):
@router.post(
"/create", response_model=AnalysisRead, status_code=status.HTTP_201_CREATED
)
def create_analysis(
payload: AnalysisCreate,
db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user),
):
new_analysis = Analysis(
id=uuid4(),
title=payload.title,
@@ -34,31 +44,55 @@ def create_analysis(payload: AnalysisCreate, db: Session = Depends(get_db), curr
db.refresh(new_analysis)
return new_analysis
# Get an analysis by ID
@router.get("/{analysis_id}", response_model=AnalysisRead)
def get_analysis_by_id(analysis_id: UUID, db: Session = Depends(get_db), current_user: Profile = Depends(get_current_user)):
analysis = db.query(Analysis).filter(Analysis.id == analysis_id, Analysis.owner_id == current_user.id).first()
def get_analysis_by_id(
analysis_id: UUID,
db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user),
):
analysis = (
db.query(Analysis)
.filter(Analysis.id == analysis_id, Analysis.owner_id == current_user.id)
.first()
)
if not analysis:
raise HTTPException(status_code=404, detail="Analysis not found")
return analysis
# Get analyses by investigation ID
@router.get("/investigation/{investigation_id}", response_model=List[AnalysisRead])
def get_analyses_by_investigation(
investigation_id: UUID,
db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user)
current_user: Profile = Depends(get_current_user),
):
analyses = db.query(Analysis).filter(
Analysis.investigation_id == investigation_id,
Analysis.owner_id == current_user.id
).all()
analyses = (
db.query(Analysis)
.filter(
Analysis.investigation_id == investigation_id,
Analysis.owner_id == current_user.id,
)
.all()
)
return analyses
# Update an analysis by ID
@router.put("/{analysis_id}", response_model=AnalysisRead)
def update_analysis(analysis_id: UUID, payload: AnalysisUpdate, db: Session = Depends(get_db), current_user: Profile = Depends(get_current_user)):
analysis = db.query(Analysis).filter(Analysis.id == analysis_id, Analysis.owner_id == current_user.id).first()
def update_analysis(
analysis_id: UUID,
payload: AnalysisUpdate,
db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user),
):
analysis = (
db.query(Analysis)
.filter(Analysis.id == analysis_id, Analysis.owner_id == current_user.id)
.first()
)
if not analysis:
raise HTTPException(status_code=404, detail="Analysis not found")
if payload.title is not None:
@@ -74,12 +108,21 @@ def update_analysis(analysis_id: UUID, payload: AnalysisUpdate, db: Session = De
db.refresh(analysis)
return analysis
# Delete an analysis by ID
@router.delete("/{analysis_id}", status_code=status.HTTP_204_NO_CONTENT)
def delete_analysis(analysis_id: UUID, db: Session = Depends(get_db), current_user: Profile = Depends(get_current_user)):
analysis = db.query(Analysis).filter(Analysis.id == analysis_id, Analysis.owner_id == current_user.id).first()
def delete_analysis(
analysis_id: UUID,
db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user),
):
analysis = (
db.query(Analysis)
.filter(Analysis.id == analysis_id, Analysis.owner_id == current_user.id)
.first()
)
if not analysis:
raise HTTPException(status_code=404, detail="Analysis not found")
db.delete(analysis)
db.commit()
return None
return None

View File

@@ -10,14 +10,18 @@ from flowsint_core.core.postgre_db import get_db
router = APIRouter()
@router.post("/token")
def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
def login_for_access_token(
form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)
):
user = db.query(Profile).filter(Profile.email == form_data.username).first()
if not user or not verify_password(form_data.password, user.hashed_password):
raise HTTPException(status_code=400, detail="Incorrect email or password")
access_token = create_access_token(data={"sub": user.email})
return {"access_token": access_token, "user_id": user.id, "token_type": "bearer"}
@router.post("/register", status_code=201)
def register(user: ProfileCreate, db: Session = Depends(get_db)):
print(user)
@@ -30,6 +34,5 @@ def register(user: ProfileCreate, db: Session = Depends(get_db)):
db.add(new_user)
db.commit()
db.refresh(new_user)
return {"message": "User registered successfully", "email": new_user.email}

View File

@@ -14,6 +14,7 @@ from flowsint_core.core.postgre_db import get_db
from app.models.models import Chat, ChatMessage, Profile
from app.api.deps import get_current_user
from app.api.schemas.chat import ChatCreate, ChatRead
router = APIRouter()
@@ -25,22 +26,23 @@ def clean_context(context: List[Dict]) -> List[Dict]:
if isinstance(item, dict):
# Create a copy and remove unwanted keys
cleaned_item = item["data"].copy()
# Remove top-level keys
cleaned_item.pop('id', None)
cleaned_item.pop('sketch_id', None)
cleaned_item.pop("id", None)
cleaned_item.pop("sketch_id", None)
# Remove from data if it exists
if 'data' in cleaned_item and isinstance(cleaned_item['data'], dict):
cleaned_item['data'].pop('sketch_id', None)
if "data" in cleaned_item and isinstance(cleaned_item["data"], dict):
cleaned_item["data"].pop("sketch_id", None)
# Remove measured/dimensions
cleaned_item.pop('measured', None)
cleaned_item.pop("measured", None)
cleaned.append(cleaned_item)
print(cleaned)
return cleaned
class ChatRequest(BaseModel):
prompt: str
context: Optional[List[Dict]] = None
@@ -49,17 +51,14 @@ class ChatRequest(BaseModel):
# Get all chats
@router.get("/", response_model=List[ChatRead])
def get_chats(
db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user)
db: Session = Depends(get_db), current_user: Profile = Depends(get_current_user)
):
chats = db.query(Chat).filter(
Chat.owner_id == current_user.id
).all()
chats = db.query(Chat).filter(Chat.owner_id == current_user.id).all()
# Sort messages for each chat by created_at in ascending order
for chat in chats:
chat.messages.sort(key=lambda x: x.created_at)
return chats
@@ -69,7 +68,7 @@ def get_chats(
# for chat in chats:
# db.delete(chat)
# db.commit()
# return {"result": "done"}
# return {"result": "done"}
# Get analyses by investigation ID
@@ -77,26 +76,39 @@ def get_chats(
def get_chats_by_investigation(
investigation_id: UUID,
db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user)
current_user: Profile = Depends(get_current_user),
):
chats = db.query(Chat).filter(
Chat.investigation_id == investigation_id,
Chat.owner_id == current_user.id
).all()
chats = (
db.query(Chat)
.filter(
Chat.investigation_id == investigation_id, Chat.owner_id == current_user.id
)
.all()
)
# Sort messages for each chat by created_at in ascending order
for chat in chats:
chat.messages.sort(key=lambda x: x.created_at)
return chats
@router.post("/stream/{chat_id}")
async def stream_chat(chat_id: UUID, payload: ChatRequest, db: Session = Depends(get_db), current_user: Profile = Depends(get_current_user)):
async def stream_chat(
chat_id: UUID,
payload: ChatRequest,
db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user),
):
# Check if Chat exists
chat = db.query(Chat).filter(Chat.id == chat_id, Chat.owner_id == current_user.id).first()
chat = (
db.query(Chat)
.filter(Chat.id == chat_id, Chat.owner_id == current_user.id)
.first()
)
if not chat:
raise HTTPException(status_code=404, detail="Chat not found")
# Update chat's last_updated_at
chat.last_updated_at = datetime.utcnow()
db.commit()
@@ -112,19 +124,25 @@ async def stream_chat(chat_id: UUID, payload: ChatRequest, db: Session = Depends
db.add(user_message)
db.commit()
db.refresh(user_message)
try:
api_key = os.environ.get("MISTRAL_API_KEY")
if not api_key:
raise HTTPException(status_code=500, detail="Mistral API key not configured")
raise HTTPException(
status_code=500, detail="Mistral API key not configured"
)
client = Mistral(api_key=api_key)
model = "mistral-small-latest"
accumulated_content = []
context_message = None
# Convert database messages to Mistral format
messages = [SystemMessage(content="You are a CTI/OSINT investigator and you are trying to investigate on a variety of real life cases. Use your knowledge and analytics capabilities to analyse the context and answer the question the best you can. If you need to reference some items (an IP, a domain or something particular) please use the code brackets, like : `12.23.34.54` to reference it.")]
messages = [
SystemMessage(
content="You are a CTI/OSINT investigator and you are trying to investigate on a variety of real life cases. Use your knowledge and analytics capabilities to analyse the context and answer the question the best you can. If you need to reference some items (an IP, a domain or something particular) please use the code brackets, like : `12.23.34.54` to reference it."
)
]
# Add context as a single system message if provided
if payload.context:
try:
@@ -139,16 +157,22 @@ async def stream_chat(chat_id: UUID, payload: ChatRequest, db: Session = Depends
except Exception as e:
# If context processing fails, skip it
print(f"Context processing error: {e}")
# Sort messages by created_at in ascending order and get recent messages
sorted_messages = sorted(chat.messages, key=lambda x: x.created_at)
recent_messages = sorted_messages[-5:] if len(sorted_messages) > 5 else sorted_messages
recent_messages = (
sorted_messages[-5:] if len(sorted_messages) > 5 else sorted_messages
)
for message in recent_messages:
if message.is_bot:
messages.append(AssistantMessage(content=json.dumps(message.content, default=str)))
messages.append(
AssistantMessage(content=json.dumps(message.content, default=str))
)
else:
messages.append(UserMessage(content=json.dumps(message.content, default=str)))
messages.append(
UserMessage(content=json.dumps(message.content, default=str))
)
# Add the current context
if context_message:
messages.append(SystemMessage(content=context_message))
@@ -156,17 +180,14 @@ async def stream_chat(chat_id: UUID, payload: ChatRequest, db: Session = Depends
messages.append(UserMessage(content=payload.prompt))
async def generate():
response = await client.chat.stream_async(
model=model,
messages=messages
)
response = await client.chat.stream_async(model=model, messages=messages)
async for chunk in response:
if chunk.data.choices[0].delta.content is not None:
content_chunk = chunk.data.choices[0].delta.content
accumulated_content.append(content_chunk)
yield f"data: {json.dumps({'content': content_chunk})}\n\n"
# Create the bot message after all chunks have been processed
chat_message = ChatMessage(
id=uuid4(),
@@ -175,24 +196,25 @@ async def stream_chat(chat_id: UUID, payload: ChatRequest, db: Session = Depends
is_bot=True,
created_at=datetime.utcnow(),
)
db.add(chat_message)
db.commit()
db.refresh(chat_message)
yield "data: [DONE]\n\n"
return StreamingResponse(
generate(),
media_type="text/event-stream"
)
return StreamingResponse(generate(), media_type="text/event-stream")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# Create a new chat
@router.post("/create", response_model=ChatRead, status_code=status.HTTP_201_CREATED)
def create_chat(payload: ChatCreate, db: Session = Depends(get_db), current_user: Profile = Depends(get_current_user)):
def create_chat(
payload: ChatCreate,
db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user),
):
new_chat = Chat(
id=uuid4(),
title=payload.title,
@@ -207,29 +229,42 @@ def create_chat(payload: ChatCreate, db: Session = Depends(get_db), current_user
db.refresh(new_chat)
return new_chat
# Get a chat by ID
@router.get("/{chat_id}", response_model=ChatRead)
def get_chat_by_id(chat_id: UUID, db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user)
):
chat = db.query(Chat).filter(Chat.id == chat_id,
Chat.owner_id == current_user.id
).first()
def get_chat_by_id(
chat_id: UUID,
db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user),
):
chat = (
db.query(Chat)
.filter(Chat.id == chat_id, Chat.owner_id == current_user.id)
.first()
)
if not chat:
raise HTTPException(status_code=404, detail="Chat not found")
# Sort messages by created_at in ascending order
chat.messages.sort(key=lambda x: x.created_at)
return chat
# Delete an chat by ID
@router.delete("/{chat_id}", status_code=status.HTTP_204_NO_CONTENT)
def delete_chat(chat_id: UUID, db: Session = Depends(get_db), current_user: Profile = Depends(get_current_user)):
chat = db.query(Chat).filter(Chat.id == chat_id, Chat.owner_id == current_user.id).first()
def delete_chat(
chat_id: UUID,
db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user),
):
chat = (
db.query(Chat)
.filter(Chat.id == chat_id, Chat.owner_id == current_user.id)
.first()
)
if not chat:
raise HTTPException(status_code=404, detail="Chat not found")
db.delete(chat)
db.commit()
return None
return None

View File

@@ -13,6 +13,7 @@ from datetime import datetime, timedelta
router = APIRouter()
@router.get("/sketch/{sketch_id}/logs")
def get_logs_by_sketch(
sketch_id: str,
@@ -20,27 +21,33 @@ def get_logs_by_sketch(
since: datetime | None = None,
db: Session = Depends(get_db),
# current_user: Profile = Depends(get_current_user)
):
):
"""Get historical logs for a specific sketch with optional filtering"""
# Check if sketch exists
sketch = db.query(Sketch).filter(Sketch.id == sketch_id).first()
if not sketch:
raise HTTPException(status_code=404, detail=f"Sketch with id {sketch_id} not found")
raise HTTPException(
status_code=404, detail=f"Sketch with id {sketch_id} not found"
)
print(
f"[EventEmitter] Fetching logs for sketch {sketch_id} (limit: {limit}, since: {since})"
)
query = (
db.query(Log).filter(Log.sketch_id == sketch_id).order_by(Log.created_at.desc())
)
print(f"[EventEmitter] Fetching logs for sketch {sketch_id} (limit: {limit}, since: {since})")
query = db.query(Log).filter(Log.sketch_id == sketch_id).order_by(Log.created_at.desc())
if since:
query = query.filter(Log.created_at > since)
else:
# Default to last 24 hours if no since parameter
query = query.filter(Log.created_at > datetime.utcnow() - timedelta(days=1))
logs = query.limit(limit).all()
# Reverse to show chronologically (oldest to newest)
logs = list(reversed(logs))
results = []
for log in logs:
# Ensure payload is always a dictionary
@@ -53,53 +60,58 @@ def get_logs_by_sketch(
else:
# Handle other types by converting to string and wrapping
payload = {"content": str(log.content)}
results.append(Event(
id=str(log.id),
sketch_id=str(log.sketch_id) if log.sketch_id else None,
type=log.type,
payload=payload
))
results.append(
Event(
id=str(log.id),
sketch_id=str(log.sketch_id) if log.sketch_id else None,
type=log.type,
payload=payload,
)
)
return results
@router.get("/sketch/{sketch_id}/stream")
async def stream_events(request: Request, sketch_id: str,
db: Session = Depends(get_db),
# current_user: Profile = Depends(get_current_user)
):
async def stream_events(
request: Request,
sketch_id: str,
db: Session = Depends(get_db),
# current_user: Profile = Depends(get_current_user)
):
"""Stream events for a specific scan in real-time"""
# Check if sketch exists
sketch = db.query(Sketch).filter(Sketch.id == sketch_id).first()
if not sketch:
raise HTTPException(status_code=404, detail=f"Sketch with id {sketch_id} not found")
raise HTTPException(
status_code=404, detail=f"Sketch with id {sketch_id} not found"
)
async def event_generator():
channel = sketch_id
await event_emitter.subscribe(channel)
try:
# Initial connection message
yield "data: {\"event\": \"connected\", \"data\": \"Connected to log stream\"}\n\n"
yield 'data: {"event": "connected", "data": "Connected to log stream"}\n\n'
while True:
if await request.is_disconnected():
break
data = await event_emitter.get_message(channel)
if data is None:
await asyncio.sleep(.1) # avoid tight loop on None
await asyncio.sleep(0.1) # avoid tight loop on None
continue
# Handle different types of events
if isinstance(data, dict) and data.get('type') == 'scanner_complete':
if isinstance(data, dict) and data.get("type") == "scanner_complete":
# Send scanner completion event
yield json.dumps({'event': 'scanner_complete', 'data': data})
yield json.dumps({"event": "scanner_complete", "data": data})
else:
# Send regular log event
yield json.dumps({'event': 'log', 'data': data})
await asyncio.sleep(.1)
yield json.dumps({"event": "log", "data": data})
await asyncio.sleep(0.1)
except asyncio.CancelledError:
print(f"[EventEmitter] Client disconnected from sketch_id: {sketch_id}")
@@ -114,16 +126,17 @@ async def stream_events(request: Request, sketch_id: str,
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no"
}
"X-Accel-Buffering": "no",
},
)
@router.delete("/sketch/{sketch_id}/logs")
def delete_scan_logs(
sketch_id:str,
sketch_id: str,
db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user)
):
current_user: Profile = Depends(get_current_user),
):
"""Delete all logs for a specific scan"""
try:
db.query(Log).filter(Log.sketch_id == sketch_id).delete()
@@ -133,9 +146,9 @@ def delete_scan_logs(
db.rollback()
raise HTTPException(status_code=500, detail=f"Failed to delete logs: {str(e)}")
@router.get("/status/scan/{scan_id}/stream")
async def stream_status(request: Request, scan_id: str,
db: Session = Depends(get_db)):
async def stream_status(request: Request, scan_id: str, db: Session = Depends(get_db)):
"""Stream status updates for a specific scan in real-time"""
async def status_generator():
@@ -143,7 +156,7 @@ async def stream_status(request: Request, scan_id: str,
await event_emitter.subscribe(f"scan_{scan_id}_status")
try:
# Initial connection message
yield "data: {\"event\": \"connected\", \"data\": \"Connected to status stream\"}\n\n"
yield 'data: {"event": "connected", "data": "Connected to status stream"}\n\n'
while True:
data = await event_emitter.get_message(f"scan_{scan_id}_status")
@@ -154,7 +167,9 @@ async def stream_status(request: Request, scan_id: str,
yield f"data: {data}\n\n"
except asyncio.CancelledError:
print(f"[EventEmitter] Client disconnected from status stream for scan_id: {scan_id}")
print(
f"[EventEmitter] Client disconnected from status stream for scan_id: {scan_id}"
)
finally:
await event_emitter.unsubscribe(f"scan_{scan_id}_status")
@@ -164,6 +179,6 @@ async def stream_status(request: Request, scan_id: str,
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no"
}
)
"X-Accel-Buffering": "no",
},
)

View File

@@ -6,21 +6,37 @@ from sqlalchemy.orm import Session, selectinload
from flowsint_core.core.postgre_db import get_db
from app.models.models import Analysis, Investigation, Profile, Sketch
from app.api.deps import get_current_user
from app.api.schemas.investigation import InvestigationRead, InvestigationCreate, InvestigationUpdate
from app.api.schemas.investigation import (
InvestigationRead,
InvestigationCreate,
InvestigationUpdate,
)
from app.api.schemas.sketch import SketchRead
from flowsint_core.core.graph_db import neo4j_connection
router = APIRouter()
# Get the list of all investigations
@router.get("", response_model=List[InvestigationRead])
def get_investigations(db: Session = Depends(get_db), current_user: Profile = Depends(get_current_user)):
investigations = db.query(Investigation).filter(Investigation.owner_id == current_user.id).all()
def get_investigations(
db: Session = Depends(get_db), current_user: Profile = Depends(get_current_user)
):
investigations = (
db.query(Investigation).filter(Investigation.owner_id == current_user.id).all()
)
return investigations
# Create a new investigation
@router.post("/create", response_model=InvestigationRead, status_code=status.HTTP_201_CREATED)
def create_investigation(payload: InvestigationCreate, db: Session = Depends(get_db), current_user: Profile = Depends(get_current_user)):
@router.post(
"/create", response_model=InvestigationRead, status_code=status.HTTP_201_CREATED
)
def create_investigation(
payload: InvestigationCreate,
db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user),
):
new_investigation = Investigation(
id=uuid4(),
name=payload.name,
@@ -35,33 +51,57 @@ def create_investigation(payload: InvestigationCreate, db: Session = Depends(get
db.refresh(new_investigation)
return new_investigation
# Get a investigation by ID
@router.get("/{investigation_id}", response_model=InvestigationRead)
def get_investigation_by_id(investigation_id: UUID, db: Session = Depends(get_db), current_user: Profile = Depends(get_current_user)):
investigation = db.query(Investigation).options(selectinload(Investigation.sketches)).filter(Investigation.id == investigation_id).filter(Investigation.owner_id == current_user.id).first()
def get_investigation_by_id(
investigation_id: UUID,
db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user),
):
investigation = (
db.query(Investigation)
.options(selectinload(Investigation.sketches))
.filter(Investigation.id == investigation_id)
.filter(Investigation.owner_id == current_user.id)
.first()
)
if not investigation:
raise HTTPException(status_code=404, detail="Investigation not found")
return investigation
# Get a investigation by ID
@router.get("/{investigation_id}/sketches", response_model=List[SketchRead])
def get_sketches_by_investigation(
investigation_id: UUID,
db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user)
current_user: Profile = Depends(get_current_user),
):
sketches = db.query(Sketch).filter(Sketch.investigation_id == investigation_id).all()
sketches = (
db.query(Sketch).filter(Sketch.investigation_id == investigation_id).all()
)
if not sketches:
raise HTTPException(status_code=404, detail="No sketches found for this investigation")
raise HTTPException(
status_code=404, detail="No sketches found for this investigation"
)
return sketches
# Update a investigation by ID
@router.put("/{investigation_id}", response_model=InvestigationRead)
def update_investigation(investigation_id: UUID, payload: InvestigationUpdate, db: Session = Depends(get_db), current_user: Profile = Depends(get_current_user)):
investigation = db.query(Investigation).filter(Investigation.id == investigation_id).first()
def update_investigation(
investigation_id: UUID,
payload: InvestigationUpdate,
db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user),
):
investigation = (
db.query(Investigation).filter(Investigation.id == investigation_id).first()
)
if not investigation:
raise HTTPException(status_code=404, detail="Investigation not found")
investigation.name = payload.name
investigation.description = payload.description
investigation.status = payload.status
@@ -71,18 +111,33 @@ def update_investigation(investigation_id: UUID, payload: InvestigationUpdate, d
db.refresh(investigation)
return investigation
# Delete a investigation by ID
@router.delete("/{investigation_id}", status_code=status.HTTP_204_NO_CONTENT)
def delete_investigation(investigation_id: UUID, db: Session = Depends(get_db), current_user: Profile = Depends(get_current_user)):
investigation = db.query(Investigation).filter(Investigation.id == investigation_id, Investigation.owner_id == current_user.id).first()
def delete_investigation(
investigation_id: UUID,
db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user),
):
investigation = (
db.query(Investigation)
.filter(
Investigation.id == investigation_id,
Investigation.owner_id == current_user.id,
)
.first()
)
if not investigation:
raise HTTPException(status_code=404, detail="Investigation not found")
# Get all sketches related to this investigation
sketches = db.query(Sketch).filter(Sketch.investigation_id == investigation_id).all()
analyses = db.query(Analysis).filter(Sketch.investigation_id == investigation_id).all()
# Get all sketches related to this investigation
sketches = (
db.query(Sketch).filter(Sketch.investigation_id == investigation_id).all()
)
analyses = (
db.query(Analysis).filter(Sketch.investigation_id == investigation_id).all()
)
# Delete all nodes and relationships for each sketch in Neo4j
for sketch in sketches:
neo4j_query = """
@@ -94,13 +149,13 @@ def delete_investigation(investigation_id: UUID, db: Session = Depends(get_db),
except Exception as e:
print(f"Neo4j cleanup error for sketch {sketch.id}: {e}")
raise HTTPException(status_code=500, detail="Failed to clean up graph data")
# Delete all sketches from PostgreSQL
for sketch in sketches:
db.delete(sketch)
for analysis in analyses:
db.delete(analysis)
# Finally delete the investigation
db.delete(investigation)
db.commit()

View File

@@ -10,45 +10,62 @@ from datetime import datetime
router = APIRouter()
def obfuscate_key(key: str) -> str:
"""Obfuscate a key by showing only the last 4 characters, replacing others with asterisks."""
if len(key) <= 4:
return key
return "*" * (len(key) - 4) + key[-4:]
# Get the list of all keys for a user
@router.get("", response_model=List[KeyRead])
def get_keys(db: Session = Depends(get_db), current_user: Profile = Depends(get_current_user)):
def get_keys(
db: Session = Depends(get_db), current_user: Profile = Depends(get_current_user)
):
keys = db.query(Key).filter(Key.owner_id == current_user.id).all()
response_data = [KeyRead(
id=key.id,
owner_id=key.owner_id,
encrypted_key=obfuscate_key(key.encrypted_key),
name=key.name,
created_at=key.created_at
) for key in keys]
response_data = [
KeyRead(
id=key.id,
owner_id=key.owner_id,
encrypted_key=obfuscate_key(key.encrypted_key),
name=key.name,
created_at=key.created_at,
)
for key in keys
]
return response_data
# Get a key by ID
@router.get("/{id}", response_model=KeyRead)
def get_key_by_id(id: UUID, db: Session = Depends(get_db), current_user: Profile = Depends(get_current_user)):
def get_key_by_id(
id: UUID,
db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user),
):
key = db.query(Key).filter(Key.id == id, Key.owner_id == current_user.id).first()
if not key:
raise HTTPException(status_code=404, detail="Key not found")
# Create a response with obfuscated key
response_data = KeyRead(
id=key.id,
owner_id=key.owner_id,
encrypted_key=obfuscate_key(key.encrypted_key),
name=key.name,
created_at=key.created_at
created_at=key.created_at,
)
return response_data
# Create a new key
@router.post("/create", response_model=KeyRead, status_code=status.HTTP_201_CREATED)
def create_key(payload: KeyCreate, db: Session = Depends(get_db), current_user: Profile = Depends(get_current_user)):
def create_key(
payload: KeyCreate,
db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user),
):
new_key = Key(
id=uuid4(),
name=payload.name,
@@ -61,12 +78,17 @@ def create_key(payload: KeyCreate, db: Session = Depends(get_db), current_user:
db.refresh(new_key)
return new_key
# Delete a key by ID
@router.delete("/{id}", status_code=status.HTTP_204_NO_CONTENT)
def delete_key(id: UUID, db: Session = Depends(get_db), current_user: Profile = Depends(get_current_user)):
def delete_key(
id: UUID,
db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user),
):
key = db.query(Key).filter(Key.id == id, Key.owner_id == current_user.id).first()
if not key:
raise HTTPException(status_code=404, detail="Key not found")
db.delete(key)
db.commit()
return None
return None

View File

@@ -9,20 +9,32 @@ from app.api.schemas.scan import ScanRead
router = APIRouter()
# Get the list of all scans
@router.get("", response_model=List[ScanRead],)
def get_scans(db: Session = Depends(get_db), current_user: Profile = Depends(get_current_user)):
@router.get(
"",
response_model=List[ScanRead],
)
def get_scans(
db: Session = Depends(get_db), current_user: Profile = Depends(get_current_user)
):
scans = db.query(Scan).all()
return scans
# Get a scan by ID
@router.get("/{id}", response_model=ScanRead)
def get_scan_by_id(id: UUID, db: Session = Depends(get_db), current_user: Profile = Depends(get_current_user)):
def get_scan_by_id(
id: UUID,
db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user),
):
scan = db.query(Scan).filter(Scan.id == id).first()
if not scan:
raise HTTPException(status_code=404, detail="Transform not found")
return scan
# Delete a scan by ID
@router.delete("", status_code=status.HTTP_204_NO_CONTENT)
def delete_scan(db: Session = Depends(get_db)):
@@ -30,14 +42,21 @@ def delete_scan(db: Session = Depends(get_db)):
db.commit()
return None
# Delete a scan by ID
@router.delete("/{id}", status_code=status.HTTP_204_NO_CONTENT)
def delete_scan(id: UUID, db: Session = Depends(get_db), current_user: Profile = Depends(get_current_user)):
scan = db.query(Scan).filter(Scan.id == id, Scan.owner_id == current_user.id).first()
def delete_scan(
id: UUID,
db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user),
):
scan = (
db.query(Scan).filter(Scan.id == id, Scan.owner_id == current_user.id).first()
)
if not scan:
raise HTTPException(status_code=404, detail="Scan not found")
scan = db.query(Scan).filter(Scan.id == id).all()
# Finally delete the scan
db.delete(scan)
db.commit()

View File

@@ -16,14 +16,15 @@ from app.api.deps import get_current_user
router = APIRouter()
@router.post("/create", response_model=SketchRead, status_code=status.HTTP_201_CREATED)
def create_sketch(
data: SketchCreate,
db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user)
current_user: Profile = Depends(get_current_user),
):
sketch_data = data.dict()
sketch_data['owner_id'] = current_user.id
sketch_data["owner_id"] = current_user.id
sketch = Sketch(**sketch_data)
db.add(sketch)
db.commit()
@@ -32,18 +33,31 @@ def create_sketch(
@router.get("", response_model=List[SketchRead])
def list_sketches(db: Session = Depends(get_db), current_user: Profile = Depends(get_current_user)):
def list_sketches(
db: Session = Depends(get_db), current_user: Profile = Depends(get_current_user)
):
return db.query(Sketch).filter(Sketch.owner_id == current_user.id).all()
@router.get("/{sketch_id}")
def get_sketch_by_id(sketch_id: UUID, db: Session = Depends(get_db), current_user: Profile = Depends(get_current_user)):
def get_sketch_by_id(
sketch_id: UUID,
db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user),
):
sketch = db.query(Sketch).filter(Sketch.id == sketch_id).first()
if not sketch:
raise HTTPException(status_code=404, detail="Sketch not found")
return sketch
@router.put("/{id}", response_model=SketchRead)
def update_sketch(id: UUID, data: SketchUpdate, db: Session = Depends(get_db), current_user: Profile = Depends(get_current_user)):
def update_sketch(
id: UUID,
data: SketchUpdate,
db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user),
):
sketch = db.query(Sketch).filter(Sketch.owner_id == current_user.id).get(id)
if not sketch:
raise HTTPException(status_code=404, detail="Sketch not found")
@@ -53,12 +67,21 @@ def update_sketch(id: UUID, data: SketchUpdate, db: Session = Depends(get_db), c
db.refresh(sketch)
return sketch
@router.delete("/{id}", status_code=204)
def delete_sketch(id: UUID, db: Session = Depends(get_db), current_user: Profile = Depends(get_current_user)):
sketch = db.query(Sketch).filter(Sketch.id == id, Sketch.owner_id == current_user.id).first()
def delete_sketch(
id: UUID,
db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user),
):
sketch = (
db.query(Sketch)
.filter(Sketch.id == id, Sketch.owner_id == current_user.id)
.first()
)
if not sketch:
raise HTTPException(status_code=404, detail="Sketch not found")
# Delete all nodes and relationships in Neo4j first
neo4j_query = """
MATCH (n {sketch_id: $sketch_id})
@@ -69,16 +92,17 @@ def delete_sketch(id: UUID, db: Session = Depends(get_db), current_user: Profile
except Exception as e:
print(f"Neo4j cleanup error: {e}")
raise HTTPException(status_code=500, detail="Failed to clean up graph data")
# Then delete the sketch from PostgreSQL
db.delete(sketch)
db.commit()
@router.get("/{id}/graph")
async def get_sketch_nodes(
id: str,
id: str,
format: str = None,
db: Session = Depends(get_db),
db: Session = Depends(get_db),
# current_user: Profile = Depends(get_current_user)
):
"""
@@ -94,12 +118,18 @@ async def get_sketch_nodes(
rls: []
Or if format=inline: List of inline relationship strings
"""
sketch = db.query(Sketch).filter(Sketch.id == id,
# Sketch.owner_id == current_user.id
).first()
sketch = (
db.query(Sketch)
.filter(
Sketch.id == id,
# Sketch.owner_id == current_user.id
)
.first()
)
if not sketch:
raise HTTPException(status_code=404, detail="Graph not found")
import random
nodes_query = """
MATCH (n)
WHERE n.sketch_id = $sketch_id
@@ -126,11 +156,11 @@ async def get_sketch_nodes(
"label": record["data"].get("label", "Node"),
"type": "custom",
"caption": record["data"].get("label", "Node"),
"position":{
"position": {
"x": random.random() * 1000,
"y": random.random() * 1000,
},
"idx": idx
"idx": idx,
}
for idx, record in enumerate(nodes_result)
]
@@ -150,10 +180,12 @@ async def get_sketch_nodes(
if format == "inline":
from flowsint_core.utils import get_inline_relationships
return get_inline_relationships(nodes, rels)
return {"nds": nodes, "rls": rels}
class NodeData(BaseModel):
label: str = Field(default="Node", description="Label/name of the node")
color: str = Field(default="Node", description="Color of the node")
@@ -163,17 +195,24 @@ class NodeData(BaseModel):
class Config:
extra = "allow" # Accept any additional fields
class NodeInput(BaseModel):
type: str = Field(..., description="Type of the node")
data: NodeData = Field(default_factory=NodeData, description="Additional data for the node")
data: NodeData = Field(
default_factory=NodeData, description="Additional data for the node"
)
def dict_to_cypher_props(props: dict, prefix: str = "") -> str:
return ", ".join(f"{key}: ${prefix}{key}" for key in props)
@router.post("/{sketch_id}/nodes/add")
def add_node(sketch_id: str, node: NodeInput, current_user: Profile = Depends(get_current_user)):
def add_node(
sketch_id: str, node: NodeInput, current_user: Profile = Depends(get_current_user)
):
node_data = node.data.model_dump()
node_type = node_data["type"]
properties = {
@@ -186,7 +225,6 @@ def add_node(sketch_id: str, node: NodeInput, current_user: Profile = Depends(ge
if node_data:
flattened_data = flatten(node_data)
properties.update(flattened_data)
cypher_props = dict_to_cypher_props(properties)
@@ -202,15 +240,19 @@ def add_node(sketch_id: str, node: NodeInput, current_user: Profile = Depends(ge
raise HTTPException(status_code=500, detail=f"Database error: {str(e)}")
if not create_result:
raise HTTPException(status_code=400, detail="Node creation failed - no result returned")
raise HTTPException(
status_code=400, detail="Node creation failed - no result returned"
)
try:
new_node = create_result[0]["node"]
new_node["id"] = create_result[0]["id"]
except (IndexError, KeyError) as e:
print(f"Error extracting node_id: {e}, result: {create_result}")
raise HTTPException(status_code=500, detail="Failed to extract node data from response")
raise HTTPException(
status_code=500, detail="Failed to extract node data from response"
)
new_node["data"] = node_data
new_node["data"]["id"] = new_node["id"]
@@ -219,14 +261,20 @@ def add_node(sketch_id: str, node: NodeInput, current_user: Profile = Depends(ge
"node": new_node,
}
class RelationInput(BaseModel):
source: Any
target: Any
type: Literal["one-way", "two-way"]
label: str = "RELATED_TO" # Optionnel : nom de la relation
@router.post("/{sketch_id}/relations/add")
def add_edge(sketch_id: str, relation: RelationInput, current_user: Profile = Depends(get_current_user)):
def add_edge(
sketch_id: str,
relation: RelationInput,
current_user: Profile = Depends(get_current_user),
):
query = f"""
MATCH (a) WHERE elementId(a) = $from_id
@@ -255,85 +303,87 @@ def add_edge(sketch_id: str, relation: RelationInput, current_user: Profile = De
"edge": result[0]["r"],
}
class NodeDeleteInput(BaseModel):
nodeIds: List[str]
class NodeEditInput(BaseModel):
nodeId: str
data: NodeData = Field(default_factory=NodeData, description="Updated data for the node")
data: NodeData = Field(
default_factory=NodeData, description="Updated data for the node"
)
@router.put("/{sketch_id}/nodes/edit")
def edit_node(
sketch_id: str,
node_edit: NodeEditInput,
db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user)
current_user: Profile = Depends(get_current_user),
):
# First verify the sketch exists and belongs to the user
sketch = db.query(Sketch).filter(Sketch.id == sketch_id).first()
if not sketch:
raise HTTPException(status_code=404, detail="Sketch not found")
node_data = node_edit.data.model_dump()
node_type = node_data.get("type", "Node")
# Prepare properties to update
properties = {
"type": node_type.lower(),
"caption": node_data.get("label", "Node"),
"label": node_data.get("label", "Node"),
}
# Add any additional data from the flattened node_data
if node_data:
flattened_data = flatten(node_data)
properties.update(flattened_data)
# Build the SET clause for the Cypher query
set_clause = ", ".join(f"n.{key} = ${key}" for key in properties.keys())
query = f"""
MATCH (n)
WHERE elementId(n) = $node_id AND n.sketch_id = $sketch_id
SET {set_clause}
RETURN n as node
"""
params = {
"node_id": node_edit.nodeId,
"sketch_id": sketch_id,
**properties
}
params = {"node_id": node_edit.nodeId, "sketch_id": sketch_id, **properties}
try:
result = neo4j_connection.query(query, params)
except Exception as e:
print(f"Node update error: {e}")
raise HTTPException(status_code=500, detail="Failed to update node")
if not result:
raise HTTPException(status_code=404, detail="Node not found or not accessible")
updated_node = result[0]["node"]
updated_node["data"] = node_data
return {
"status": "node updated",
"node": updated_node,
}
@router.delete("/{sketch_id}/nodes")
def delete_nodes(
sketch_id: str,
nodes: NodeDeleteInput,
db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user)
current_user: Profile = Depends(get_current_user),
):
# First verify the sketch exists and belongs to the user
sketch = db.query(Sketch).filter(Sketch.id == sketch_id).first()
if not sketch:
raise HTTPException(status_code=404, detail="Sketch not found")
# Delete nodes and their relationships
query = """
UNWIND $node_ids AS node_id
@@ -341,18 +391,18 @@ def delete_nodes(
WHERE elementId(n) = node_id AND n.sketch_id = $sketch_id
DETACH DELETE n
"""
try:
neo4j_connection.query(query, {
"node_ids": nodes.nodeIds,
"sketch_id": sketch_id
})
neo4j_connection.query(
query, {"node_ids": nodes.nodeIds, "sketch_id": sketch_id}
)
except Exception as e:
print(f"Node deletion error: {e}")
raise HTTPException(status_code=500, detail="Failed to delete nodes")
return {"status": "nodes deleted", "count": len(nodes.nodeIds)}
@router.get("/{sketch_id}/nodes/{node_id}")
def get_related_nodes(
sketch_id: str,
@@ -364,7 +414,7 @@ def get_related_nodes(
# sketch = db.query(Sketch).filter(Sketch.id == sketch_id, Sketch.owner_id == current_user.id).first()
# if not sketch:
# raise HTTPException(status_code=404, detail="Sketch not found")
# Query to get all direct relationships and connected nodes
# First, let's get the center node
center_query = """
@@ -372,19 +422,18 @@ def get_related_nodes(
WHERE elementId(n) = $node_id AND n.sketch_id = $sketch_id
RETURN elementId(n) as id, labels(n) as labels, properties(n) as data
"""
try:
center_result = neo4j_connection.query(center_query, {
"sketch_id": sketch_id,
"node_id": node_id
})
center_result = neo4j_connection.query(
center_query, {"sketch_id": sketch_id, "node_id": node_id}
)
except Exception as e:
print(f"Center node query error: {e}")
raise HTTPException(status_code=500, detail="Failed to retrieve center node")
if not center_result:
raise HTTPException(status_code=404, detail="Node not found")
# Now get all relationships and connected nodes
relationships_query = """
MATCH (n)
@@ -415,16 +464,15 @@ def get_related_nodes(
properties(other) as other_node_data,
'incoming' as direction
"""
try:
result = neo4j_connection.query(relationships_query, {
"sketch_id": sketch_id,
"node_id": node_id
})
result = neo4j_connection.query(
relationships_query, {"sketch_id": sketch_id, "node_id": node_id}
)
except Exception as e:
print(f"Related nodes query error: {e}")
raise HTTPException(status_code=500, detail="Failed to retrieve related nodes")
# Extract center node info
center_record = center_result[0]
center_node = {
@@ -433,59 +481,61 @@ def get_related_nodes(
"data": center_record["data"],
"label": center_record["data"].get("label", "Node"),
"type": "custom",
"caption": center_record["data"].get("label", "Node")
"caption": center_record["data"].get("label", "Node"),
}
# Collect all related nodes and relationships
related_nodes = []
relationships = []
seen_nodes = set()
seen_relationships = set()
for record in result:
# Skip if no relationship found
if not record["rel_id"]:
continue
# Add relationship if not seen
if record["rel_id"] not in seen_relationships:
if record["direction"] == "outgoing":
relationships.append({
"id": record["rel_id"],
"type": "straight",
"source": center_node["id"],
"target": record["other_node_id"],
"data": record["rel_data"],
"caption": record["rel_type"]
})
relationships.append(
{
"id": record["rel_id"],
"type": "straight",
"source": center_node["id"],
"target": record["other_node_id"],
"data": record["rel_data"],
"caption": record["rel_type"],
}
)
else: # incoming
relationships.append({
"id": record["rel_id"],
"type": "straight",
"source": record["other_node_id"],
"target": center_node["id"],
"data": record["rel_data"],
"caption": record["rel_type"]
})
relationships.append(
{
"id": record["rel_id"],
"type": "straight",
"source": record["other_node_id"],
"target": center_node["id"],
"data": record["rel_data"],
"caption": record["rel_type"],
}
)
seen_relationships.add(record["rel_id"])
# Add related node if not seen
if record["other_node_id"] and record["other_node_id"] not in seen_nodes:
related_nodes.append({
"id": record["other_node_id"],
"labels": record["other_node_labels"],
"data": record["other_node_data"],
"label": record["other_node_data"].get("label", "Node"),
"type": "custom",
"caption": record["other_node_data"].get("label", "Node")
})
related_nodes.append(
{
"id": record["other_node_id"],
"labels": record["other_node_labels"],
"data": record["other_node_data"],
"label": record["other_node_data"].get("label", "Node"),
"type": "custom",
"caption": record["other_node_data"].get("label", "Node"),
}
)
seen_nodes.add(record["other_node_id"])
# Combine center node with related nodes
all_nodes = [center_node] + related_nodes
return {
"nds": all_nodes,
"rls": relationships
}
return {"nds": all_nodes, "rls": relationships}

View File

@@ -13,46 +13,62 @@ from flowsint_core.core.postgre_db import get_db
from app.models.models import Transform, Profile
from app.api.deps import get_current_user
from app.api.schemas.transform import TransformRead, TransformCreate, TransformUpdate
from flowsint_types import ASN, CIDR, CryptoWallet, CryptoWalletTransaction, CryptoNFT, Website, Individual
from flowsint_types import (
ASN,
CIDR,
CryptoWallet,
CryptoWalletTransaction,
CryptoNFT,
Website,
Individual,
)
class FlowComputationRequest(BaseModel):
nodes: List[Node]
edges: List[Edge]
inputType: Optional[str] = None
class FlowComputationResponse(BaseModel):
transformBranches: List[FlowBranch]
initialData: Any
class StepSimulationRequest(BaseModel):
transformBranches: List[FlowBranch]
currentStepIndex: int
class LaunchTransformPayload(BaseModel):
values: List[str]
sketch_id: str
router = APIRouter()
# Get the list of all transforms
@router.get("", response_model=List[TransformRead])
def get_transforms(
category: Optional[str] = Query(None),
db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user)
current_user: Profile = Depends(get_current_user),
):
query = db.query(Transform)
if category is not None and category != "undefined":
# Case-insensitive filtering by checking if any category matches (case-insensitive)
transforms = query.all()
return [
transform for transform in transforms
transform
for transform in transforms
if any(cat.lower() == category.lower() for cat in transform.category)
]
return query.order_by(Transform.last_updated_at.desc()).all()
# Returns the "raw_materials" for the transform editor
@router.get("/raw_materials")
async def get_material_list():
@@ -72,7 +88,7 @@ async def get_material_list():
"params": scanner.get("params"),
"params_schema": scanner.get("params_schema"),
"required_params": scanner.get("required_params"),
"icon": scanner.get("icon")
"icon": scanner.get("icon"),
}
for scanner in scanner_list
]
@@ -92,27 +108,33 @@ async def get_material_list():
extract_input_schema_transform(Email),
extract_input_schema_transform(CryptoWallet),
extract_input_schema_transform(CryptoWalletTransaction),
extract_input_schema_transform(CryptoNFT)
extract_input_schema_transform(CryptoNFT),
]
# Put types first, then add all scanner categories
flattened_scanners = {"types": object_inputs}
flattened_scanners.update(scanner_categories)
return {"items": flattened_scanners}
# Returns the "raw_materials" for the transform editor
@router.get("/input_type/{input_type}")
async def get_material_list(input_type:str):
async def get_material_list(input_type: str):
scanners = ScannerRegistry.list_by_input_type(input_type)
return {"items": scanners}
# Create a new transform
@router.post("/create", response_model=TransformRead, status_code=status.HTTP_201_CREATED)
def create_transform(payload: TransformCreate, db: Session = Depends(get_db), current_user: Profile = Depends(get_current_user)):
@router.post(
"/create", response_model=TransformRead, status_code=status.HTTP_201_CREATED
)
def create_transform(
payload: TransformCreate,
db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user),
):
new_transform = Transform(
id=uuid4(),
name=payload.name,
@@ -127,28 +149,39 @@ def create_transform(payload: TransformCreate, db: Session = Depends(get_db), cu
db.refresh(new_transform)
return new_transform
# Get a transform by ID
@router.get("/{transform_id}", response_model=TransformRead)
def get_transform_by_id(transform_id: UUID, db: Session = Depends(get_db), current_user: Profile = Depends(get_current_user)):
def get_transform_by_id(
transform_id: UUID,
db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user),
):
transform = db.query(Transform).filter(Transform.id == transform_id).first()
if not transform:
raise HTTPException(status_code=404, detail="Transform not found")
return transform
# Update a transform by ID
@router.put("/{transform_id}", response_model=TransformRead)
def update_transform(transform_id: UUID, payload: TransformUpdate, db: Session = Depends(get_db), current_user: Profile = Depends(get_current_user)):
def update_transform(
transform_id: UUID,
payload: TransformUpdate,
db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user),
):
transform = db.query(Transform).filter(Transform.id == transform_id).first()
if not transform:
raise HTTPException(status_code=404, detail="Transform not found")
update_data = payload.dict(exclude_unset=True)
for key, value in update_data.items():
print(f'only update {key}')
print(f"only update {key}")
if key == "category":
if "SocialProfile" in value:
value.append("Username")
setattr(transform, key, value)
transform.last_updated_at = datetime.utcnow()
db.commit()
@@ -158,7 +191,11 @@ def update_transform(transform_id: UUID, payload: TransformUpdate, db: Session =
# Delete a transform by ID
@router.delete("/{transform_id}", status_code=status.HTTP_204_NO_CONTENT)
def delete_transform(transform_id: UUID, db: Session = Depends(get_db), current_user: Profile = Depends(get_current_user)):
def delete_transform(
transform_id: UUID,
db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user),
):
transform = db.query(Transform).filter(Transform.id == transform_id).first()
if not transform:
raise HTTPException(status_code=404, detail="Transform not found")
@@ -172,7 +209,7 @@ async def launch_transform(
transform_id: str,
payload: LaunchTransformPayload,
db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user)
current_user: Profile = Depends(get_current_user),
):
try:
transform = db.query(Transform).filter(Transform.id == transform_id).first()
@@ -180,32 +217,37 @@ async def launch_transform(
raise HTTPException(status_code=404, detail="Transform not found")
nodes = [Node(**node) for node in transform.transform_schema["nodes"]]
edges = [Edge(**edge) for edge in transform.transform_schema["edges"]]
transform_branches = compute_transform_branches(
payload.values,
nodes,
edges
)
transform_branches = compute_transform_branches(payload.values, nodes, edges)
serializable_branches = [branch.model_dump() for branch in transform_branches]
task = celery.send_task("run_transform", args=[serializable_branches, payload.values, payload.sketch_id, str(current_user.id)])
task = celery.send_task(
"run_transform",
args=[
serializable_branches,
payload.values,
payload.sketch_id,
str(current_user.id),
],
)
return {"id": task.id}
except Exception as e:
print(e)
raise HTTPException(status_code=404, detail="Transform not found")
@router.post("/{transform_id}/compute", response_model=FlowComputationResponse)
def compute_transforms(request: FlowComputationRequest, current_user: Profile = Depends(get_current_user)):
def compute_transforms(
request: FlowComputationRequest, current_user: Profile = Depends(get_current_user)
):
initial_data = generate_sample_data(request.inputType or "string")
transform_branches = compute_transform_branches(
initial_data,
request.nodes,
request.edges
)
return FlowComputationResponse(
transformBranches=transform_branches,
initialData=initial_data
initial_data, request.nodes, request.edges
)
return FlowComputationResponse(
transformBranches=transform_branches, initialData=initial_data
)
def generate_sample_data(type_str: str) -> Any:
type_str = type_str.lower() if type_str else "string"
if type_str == "string":
@@ -229,7 +271,10 @@ def generate_sample_data(type_str: str) -> Any:
else:
return f"sample_{type_str}"
def compute_transform_branches(initial_value: Any, nodes: List[Node], edges: List[Edge]) -> List[FlowBranch]:
def compute_transform_branches(
initial_value: Any, nodes: List[Node], edges: List[Edge]
) -> List[FlowBranch]:
"""Computes flow branches based on nodes and edges with proper DFS traversal"""
# Find input nodes (starting points)
input_nodes = [node for node in nodes if node.data.get("type") == "type"]
@@ -263,33 +308,38 @@ def compute_transform_branches(initial_value: Any, nodes: List[Node], edges: Lis
"""Calculate the shortest possible path length from a node to any leaf"""
if visited is None:
visited = set()
if start_node in visited:
return float('inf')
return float("inf")
visited.add(start_node)
out_edges = [edge for edge in edges if edge.source == start_node]
if not out_edges:
return 1
min_length = float('inf')
min_length = float("inf")
for edge in out_edges:
length = calculate_path_length(edge.target, visited.copy())
min_length = min(min_length, length)
return 1 + min_length
def get_outgoing_edges(node_id: str) -> List[Edge]:
"""Get outgoing edges sorted by the shortest possible path length"""
out_edges = [edge for edge in edges if edge.source == node_id]
# Sort edges by the length of the shortest possible path from their target
return sorted(
out_edges,
key=lambda e: calculate_path_length(e.target)
)
return sorted(out_edges, key=lambda e: calculate_path_length(e.target))
def create_step(node_id: str, branch_id: str, depth: int, input_data: Dict[str, Any], is_input_node: bool, outputs: Dict[str, Any], node_params: Optional[Dict[str, Any]] = None) -> FlowStep:
def create_step(
node_id: str,
branch_id: str,
depth: int,
input_data: Dict[str, Any],
is_input_node: bool,
outputs: Dict[str, Any],
node_params: Optional[Dict[str, Any]] = None,
) -> FlowStep:
return FlowStep(
nodeId=node_id,
params=node_params,
@@ -301,7 +351,17 @@ def compute_transform_branches(initial_value: Any, nodes: List[Node], edges: Lis
depth=depth,
)
def explore_branch(current_node_id: str, branch_id: str, branch_name: str, depth: int, input_data: Dict[str, Any], path: List[str], branch_visited: set, steps: List[FlowStep], parent_outputs: Dict[str, Any] = None) -> None:
def explore_branch(
current_node_id: str,
branch_id: str,
branch_name: str,
depth: int,
input_data: Dict[str, Any],
path: List[str],
branch_visited: set,
steps: List[FlowStep],
parent_outputs: Dict[str, Any] = None,
) -> None:
nonlocal branch_counter
# Skip if node is already in current path (cycle detection)
@@ -316,7 +376,9 @@ def compute_transform_branches(initial_value: Any, nodes: List[Node], edges: Lis
is_input_node = current_node.data.get("type") == "type"
if is_input_node:
outputs_array = current_node.data["outputs"].get("properties", [])
first_output_name = outputs_array[0].get("name", "output") if outputs_array else "output"
first_output_name = (
outputs_array[0].get("name", "output") if outputs_array else "output"
)
current_outputs = {first_output_name: initial_value}
else:
# Check if we already have outputs for this scanner
@@ -331,14 +393,22 @@ def compute_transform_branches(initial_value: Any, nodes: List[Node], edges: Lis
node_params = current_node.data.get("params", {})
# Create and add current step
current_step = create_step(current_node_id, branch_id, depth, input_data, is_input_node, current_outputs, node_params)
current_step = create_step(
current_node_id,
branch_id,
depth,
input_data,
is_input_node,
current_outputs,
node_params,
)
steps.append(current_step)
path.append(current_node_id)
branch_visited.add(current_node_id)
# Get all outgoing edges sorted by path length
out_edges = get_outgoing_edges(current_node_id)
if not out_edges:
# Leaf node reached, save the branch
branches.append(FlowBranch(id=branch_id, name=branch_name, steps=steps[:]))
@@ -347,15 +417,17 @@ def compute_transform_branches(initial_value: Any, nodes: List[Node], edges: Lis
for i, edge in enumerate(out_edges):
if edge.target in path: # Skip if would create cycle
continue
# Prepare next node's input
output_key = edge.sourceHandle
if not output_key and current_outputs:
output_key = list(current_outputs.keys())[0]
output_value = current_outputs.get(output_key) if output_key else None
if output_value is None and parent_outputs:
output_value = parent_outputs.get(output_key) if output_key else None
output_value = (
parent_outputs.get(output_key) if output_key else None
)
next_input = {edge.targetHandle or "input": output_value}
@@ -370,15 +442,17 @@ def compute_transform_branches(initial_value: Any, nodes: List[Node], edges: Lis
path,
branch_visited,
steps,
current_outputs
current_outputs,
)
else:
# Create new branch starting from current node
branch_counter += 1
new_branch_id = f"{branch_id}-{branch_counter}"
new_branch_name = f"{branch_name} (Branch {branch_counter})"
new_steps = steps[:len(steps)] # Copy steps up to current node
new_branch_visited = branch_visited.copy() # Create new visited set for the branch
new_steps = steps[: len(steps)] # Copy steps up to current node
new_branch_visited = (
branch_visited.copy()
) # Create new visited set for the branch
explore_branch(
edge.target,
new_branch_id,
@@ -388,7 +462,7 @@ def compute_transform_branches(initial_value: Any, nodes: List[Node], edges: Lis
path[:], # Create new path copy for branch
new_branch_visited,
new_steps,
current_outputs
current_outputs,
)
# Backtrack: remove current node from path and remove its step
@@ -408,70 +482,73 @@ def compute_transform_branches(initial_value: Any, nodes: List[Node], edges: Lis
[], # Use list for path to maintain order
set(), # Use set for visited to check membership
[],
None
None,
)
# Sort branches by length (number of steps)
branches.sort(key=lambda branch: len(branch.steps))
return branches
def process_node_data(node: Node, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Traite les données de nœud en fonction du type de nœud et des entrées"""
outputs = {}
output_types = node.data["outputs"].get("properties", [])
for output in output_types:
output_name = output.get("name", "output")
class_name = node.data.get("class_name", "")
# For simulation purposes, we'll return a placeholder value based on the scanner type
if class_name in ["ReverseResolveScanner", "ResolveScanner"]:
# IP/Domain resolution scanners
outputs[output_name] = "192.168.1.1" if "ip" in output_name.lower() else "example.com"
outputs[output_name] = (
"192.168.1.1" if "ip" in output_name.lower() else "example.com"
)
elif class_name == "SubdomainScanner":
# Subdomain scanner
outputs[output_name] = f"sub.{inputs.get('input', 'example.com')}"
elif class_name == "WhoisScanner":
# WHOIS scanner
outputs[output_name] = {
"domain": inputs.get("input", "example.com"),
"registrar": "Example Registrar",
"creation_date": "2020-01-01"
"creation_date": "2020-01-01",
}
elif class_name == "GeolocationScanner":
# Geolocation scanner
outputs[output_name] = {
"country": "France",
"city": "Paris",
"coordinates": {"lat": 48.8566, "lon": 2.3522}
"coordinates": {"lat": 48.8566, "lon": 2.3522},
}
elif class_name == "MaigretScanner":
# Social media scanner
outputs[output_name] = {
"username": inputs.get("input", "user123"),
"platforms": ["twitter", "github", "linkedin"]
"platforms": ["twitter", "github", "linkedin"],
}
elif class_name == "HoleheScanner":
# Email verification scanner
outputs[output_name] = {
"email": inputs.get("input", "user@example.com"),
"exists": True,
"platforms": ["gmail", "github"]
"platforms": ["gmail", "github"],
}
elif class_name == "SireneScanner":
# Organization scanner
outputs[output_name] = {
"name": inputs.get("input", "Example Corp"),
"siret": "12345678901234",
"address": "1 Example Street"
"address": "1 Example Street",
}
else:
# For unknown scanners, pass through the input
outputs[output_name] = inputs.get("input") or f"transformed_{output_name}"
return outputs
return outputs

View File

@@ -3,18 +3,44 @@ from uuid import UUID, uuid4
from fastapi import APIRouter
from pydantic import BaseModel, TypeAdapter
from flowsint_types import (
Domain, Ip, SocialProfile, Organization, Email, ASN, CIDR,
CryptoWallet, CryptoWalletTransaction, CryptoNFT, Website, Individual,
Phone, Leak, Username, Credential, Session,
DNSRecord, SSLCertificate, Device, Document, File, Message,
Malware, Weapon, BankAccount, CreditCard, WebTracker, Phrase
Domain,
Ip,
SocialProfile,
Organization,
Email,
ASN,
CIDR,
CryptoWallet,
CryptoWalletTransaction,
CryptoNFT,
Website,
Individual,
Phone,
Leak,
Username,
Credential,
Session,
DNSRecord,
SSLCertificate,
Device,
Document,
File,
Message,
Malware,
Weapon,
BankAccount,
CreditCard,
WebTracker,
Phrase,
)
# from flowsint_types.script import Script
# from flowsint_types.reputation_score import ReputationScore
# from flowsint_types.risk_profile import RiskProfile
router = APIRouter()
# Returns the "types" for the sketches
@router.get("/")
async def get_types_list():
@@ -39,9 +65,13 @@ async def get_types_list():
"fields": [],
"children": [
extract_input_schema(Individual, label_key="full_name"),
extract_input_schema(SocialProfile, label_key="username", icon="socialprofile"),
extract_input_schema(
SocialProfile, label_key="username", icon="socialprofile"
),
extract_input_schema(Organization, label_key="name"),
extract_input_schema(Username, label_key="username", icon="socialprofile"),
extract_input_schema(
Username, label_key="username", icon="socialprofile"
),
# extract_input_schema(Alias, label_key="alias", icon="alias"),
# extract_input_schema(Affiliation, label_key="organization", icon="affiliation"),
],
@@ -67,7 +97,9 @@ async def get_types_list():
"children": [
extract_input_schema(Phone, label_key="number"),
extract_input_schema(Email, label_key="email"),
extract_input_schema(SocialProfile, label_key="username", icon="socialprofile"),
extract_input_schema(
SocialProfile, label_key="username", icon="socialprofile"
),
extract_input_schema(Message, label_key="content", icon="message"),
],
},
@@ -97,7 +129,9 @@ async def get_types_list():
"label": "Security & Access",
"fields": [],
"children": [
extract_input_schema(Credential, label_key="username", icon="credential"),
extract_input_schema(
Credential, label_key="username", icon="credential"
),
extract_input_schema(Session, label_key="session_id", icon="session"),
extract_input_schema(Device, label_key="device_id", icon="device"),
extract_input_schema(Malware, label_key="name", icon="malware"),
@@ -127,8 +161,12 @@ async def get_types_list():
"label": "Financial Data",
"fields": [],
"children": [
extract_input_schema(BankAccount, label_key="account_number", icon="creditcard"),
extract_input_schema(CreditCard, label_key="card_number", icon="creditcard"),
extract_input_schema(
BankAccount, label_key="account_number", icon="creditcard"
),
extract_input_schema(
CreditCard, label_key="card_number", icon="creditcard"
),
],
},
{
@@ -150,41 +188,47 @@ async def get_types_list():
"label": "Crypto",
"fields": [],
"children": [
extract_input_schema(CryptoWallet, label_key="address", icon="cryptowallet"),
extract_input_schema(CryptoWalletTransaction, label_key="hash", icon="cryptowallet"),
extract_input_schema(
CryptoWallet, label_key="address", icon="cryptowallet"
),
extract_input_schema(
CryptoWalletTransaction, label_key="hash", icon="cryptowallet"
),
extract_input_schema(CryptoNFT, label_key="name", icon="cryptowallet"),
],
}
},
]
return types
def extract_input_schema(
model: Type[BaseModel], label_key: str, icon: Optional[str] = None
) -> Dict[str, Any]:
def extract_input_schema(model: Type[BaseModel], label_key:str, icon: Optional[str]=None) -> Dict[str, Any]:
adapter = TypeAdapter(model)
schema = adapter.json_schema()
# Use the main schema properties, not the $defs
type_name = model.__name__
details = schema
return {
"id": uuid4(),
"type": type_name,
"key": type_name.lower(),
"label_key": label_key,
"icon": icon or type_name.lower(),
"label": type_name,
"description": details.get("description",""),
"fields": [resolve_field(prop, details=info, schema=schema)
for prop, info in details.get("properties", {}).items()
]
"id": uuid4(),
"type": type_name,
"key": type_name.lower(),
"label_key": label_key,
"icon": icon or type_name.lower(),
"label": type_name,
"description": details.get("description", ""),
"fields": [
resolve_field(prop, details=info, schema=schema)
for prop, info in details.get("properties", {}).items()
],
}
def resolve_field(prop:str, details: dict, schema: dict = None) -> Dict:
def resolve_field(prop: str, details: dict, schema: dict = None) -> Dict:
"""_summary_
The fields can sometimes contain nested complex objects, like:
The fields can sometimes contain nested complex objects, like:
- Organization having Individual[] as dirigeants, so we want to skip those.
Args:
details (dict): _description_
@@ -193,28 +237,35 @@ def resolve_field(prop:str, details: dict, schema: dict = None) -> Dict:
Returns:
str: _description_
"""
field = { "name": prop,"label": details["title"], "description": details["description"], "type":"text"}
field = {
"name": prop,
"label": details["title"],
"description": details["description"],
"type": "text",
}
if has_enum(details):
field["type"] = "select"
field["options"]= [
{ "label": label, "value": label } for label in get_enum_values(details)
field["type"] = "select"
field["options"] = [
{"label": label, "value": label} for label in get_enum_values(details)
]
field["required"] = is_required(details)
return field
def has_enum(schema: dict) -> bool:
any_of = schema.get('anyOf', [])
return any(isinstance(entry, dict) and 'enum' in entry for entry in any_of)
any_of = schema.get("anyOf", [])
return any(isinstance(entry, dict) and "enum" in entry for entry in any_of)
def is_required(schema: dict) -> bool:
any_of = schema.get('anyOf', [])
return not any(entry == {'type': 'null'} for entry in any_of)
any_of = schema.get("anyOf", [])
return not any(entry == {"type": "null"} for entry in any_of)
def get_enum_values(schema: dict) -> list:
enum_values = []
for entry in schema.get('anyOf', []):
if isinstance(entry, dict) and 'enum' in entry:
enum_values.extend(entry['enum'])
return enum_values
for entry in schema.get("anyOf", []):
if isinstance(entry, dict) and "enum" in entry:
enum_values.extend(entry["enum"])
return enum_values