Files
flowsint/flowsint-api/app/api/routes/flows.py

611 lines
20 KiB
Python

from datetime import datetime
from typing import Any, Dict, List, Optional
from uuid import UUID, uuid4
from fastapi import APIRouter, Depends, HTTPException, Query, status
# Auto-discover and register all enrichers
from flowsint_core.core.celery import celery
from flowsint_core.core.graph import create_graph_service
from flowsint_core.core.models import CustomType, Flow, Profile, Sketch
from flowsint_core.core.postgre_db import get_db
from flowsint_core.core.types import FlowBranch, FlowEdge, FlowNode, FlowStep
from flowsint_core.utils import extract_input_schema_flow
from flowsint_enrichers import ENRICHER_REGISTRY, load_all_enrichers
from flowsint_types import (
ASN,
CIDR,
CryptoNFT,
CryptoWallet,
CryptoWalletTransaction,
DNSRecord,
Domain,
Email,
Individual,
Ip,
Organization,
Phone,
Phrase,
Port,
SocialAccount,
Username,
Website,
)
from pydantic import BaseModel
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.api.deps import get_current_user
from app.api.schemas.flow import FlowCreate, FlowRead, FlowUpdate
from app.security.permissions import check_investigation_permission
# Import-time side effect: load the enricher modules so that ENRICHER_REGISTRY
# is populated before any route below reads it (presumably each enricher
# registers itself when imported — confirm in flowsint_enrichers).
load_all_enrichers()
class FlowComputationRequest(BaseModel):
    # Body of the ``/{flow_id}/compute`` route: the editor graph to simulate.
    nodes: List[FlowNode]
    edges: List[FlowEdge]
    # Type name passed to generate_sample_data to build the sample input;
    # the handler falls back to "string" when it is omitted or empty.
    inputType: Optional[str] = None
class FlowComputationResponse(BaseModel):
    # Response of the ``/{flow_id}/compute`` route.
    # Linear branches produced by compute_flow_branches.
    flowBranches: List[FlowBranch]
    # The generated sample value the branches were computed from.
    initialData: Any
class StepSimulationRequest(BaseModel):
    # NOTE(review): not referenced in the visible part of this module —
    # confirm it is still used elsewhere before relying on it.
    flowBranches: List[FlowBranch]
    # Presumably the index of the step to simulate next — TODO confirm.
    currentStepIndex: int
class launchFlowPayload(BaseModel):
    # Body of the ``/{flow_id}/launch`` route. (Lower-case class name departs
    # from PEP 8 but is left as-is for compatibility.)
    # Graph element IDs of the entities to run the flow on; passed to
    # graph_service.get_nodes_by_ids_for_task.
    node_ids: List[str]
    # Sketch holding those entities; also used for the investigation
    # permission check and sent to the Celery task.
    sketch_id: str
# Router collecting the flow endpoints defined below.
router = APIRouter()
@router.get("/", response_model=List[FlowRead])
def get_flows(
    category: Optional[str] = Query(None),
    db: Session = Depends(get_db),
    current_user: Profile = Depends(get_current_user),
):
    # List flows, most recently updated first.
    #
    # - No ``category`` (or the literal string "undefined"): every flow.
    # - ``category`` matches (case-insensitively) the name of one of the
    #   current user's published custom types: every flow, each serialized to
    #   a dict with an extra "wobblyType": True key.
    # - Otherwise: only flows whose ``category`` list contains ``category``
    #   (case-insensitive). Flows with a NULL category, or non-string entries
    #   in it, are skipped instead of raising.
    if not category or category.lower() == "undefined":
        return db.query(Flow).order_by(Flow.last_updated_at.desc()).all()

    wanted = category.lower()
    custom_type = (
        db.query(CustomType)
        .filter(
            CustomType.owner_id == current_user.id,
            CustomType.status == "published",
            func.lower(CustomType.name) == wanted,
        )
        .first()
    )
    flows = db.query(Flow).order_by(Flow.last_updated_at.desc()).all()

    if custom_type:
        return [
            {
                **(flow.to_dict() if hasattr(flow, "to_dict") else flow.__dict__),
                "wobblyType": True,
            }
            for flow in flows
        ]

    return [
        flow
        for flow in flows
        if any(
            isinstance(cat, str) and cat.lower() == wanted
            for cat in (flow.category or [])
        )
    ]
# Returns the "raw_materials" for the flow editor
@router.get("/raw_materials")
async def get_material_list():
    # Build the flow-editor palette: the input types under "types", followed
    # by one entry per enricher category (a category literally named "types"
    # would replace the input types, as before).
    enricher_fields = (
        "class_name",
        "category",
        "name",
        "module",
        "documentation",
        "description",
        "inputs",
        "outputs",
        "type",
        "params",
        "params_schema",
        "required_params",
        "icon",
    )

    def describe(enricher):
        # "type" is always the constant "enricher"; every other field is
        # copied from the registry entry (None when absent).
        return {
            field: "enricher" if field == "type" else enricher.get(field)
            for field in enricher_fields
        }

    by_category = ENRICHER_REGISTRY.list_by_categories()
    enricher_categories = {
        category: [describe(entry) for entry in entries]
        for category, entries in by_category.items()
    }

    type_models = (
        Phrase,
        Organization,
        Individual,
        Domain,
        Website,
        Ip,
        DNSRecord,
        Port,
        Phone,
        ASN,
        CIDR,
        Username,
        SocialAccount,
        Email,
        CryptoWallet,
        CryptoWalletTransaction,
        CryptoNFT,
    )
    return {
        "items": {
            "types": [extract_input_schema_flow(model) for model in type_models],
            **enricher_categories,
        }
    }
# Returns the enrichers that accept the given input type
@router.get("/input_type/{input_type}")
async def get_material_list(input_type: str):
    # List the enrichers the registry reports for ``input_type``, wrapped
    # as {"items": ...}.
    # NOTE(review): this reuses the name of the "/raw_materials" handler
    # defined earlier in this module. Both routes are still registered (the
    # decorator runs at definition time), but the module attribute
    # ``get_material_list`` now refers only to this handler — consider a
    # distinct name.
    return dict(items=ENRICHER_REGISTRY.list_by_input_type(input_type))
# Create a new flow
@router.post("/create", response_model=FlowRead, status_code=status.HTTP_201_CREATED)
def create_flow(
    payload: FlowCreate,
    db: Session = Depends(get_db),
    current_user: Profile = Depends(get_current_user),
):
    # Persist a new flow built from the payload and return it (201).
    # ``current_user`` is not read here; no owner is recorded on the flow.
    #
    # Take a single timestamp so created_at and last_updated_at are equal on
    # creation (two separate utcnow() calls differ by a few microseconds).
    # NOTE(review): utcnow() is naive and deprecated since Python 3.12; kept to
    # match the rest of this module — confirm the column type before switching
    # to datetime.now(timezone.utc).
    now = datetime.utcnow()
    new_flow = Flow(
        id=uuid4(),
        name=payload.name,
        description=payload.description,
        category=payload.category,
        flow_schema=payload.flow_schema,
        created_at=now,
        last_updated_at=now,
    )
    db.add(new_flow)
    db.commit()
    db.refresh(new_flow)
    return new_flow
# Get a flow by ID
@router.get("/{flow_id}", response_model=FlowRead)
def get_flow_by_id(
    flow_id: UUID,
    db: Session = Depends(get_db),
    current_user: Profile = Depends(get_current_user),
):
    # Fetch a single flow by primary key; 404 "flow not found" if absent.
    found = db.query(Flow).filter_by(id=flow_id).first()
    if not found:
        raise HTTPException(status_code=404, detail="flow not found")
    return found
# Update a flow by ID
@router.put("/{flow_id}", response_model=FlowRead)
def update_flow(
    flow_id: UUID,
    payload: FlowUpdate,
    db: Session = Depends(get_db),
    current_user: Profile = Depends(get_current_user),
):
    # Partially update a flow: only fields explicitly present in the request
    # body (``exclude_unset=True``) are written, then last_updated_at is
    # bumped. 404 "flow not found" if the flow does not exist.
    flow = db.query(Flow).filter(Flow.id == flow_id).first()
    if not flow:
        raise HTTPException(status_code=404, detail="flow not found")
    update_data = payload.model_dump(exclude_unset=True)

    # When "SocialAccount" is among the categories, "Username" is added too.
    # Add it only if missing, so re-saving a flow does not accumulate duplicate
    # "Username" entries, and tolerate an explicit null category.
    categories = update_data.get("category")
    if (
        categories is not None
        and "SocialAccount" in categories
        and "Username" not in categories
    ):
        update_data["category"] = [*categories, "Username"]

    for key, value in update_data.items():
        setattr(flow, key, value)
    flow.last_updated_at = datetime.utcnow()
    db.commit()
    db.refresh(flow)
    return flow
# Delete a flow by ID
@router.delete("/{flow_id}", status_code=status.HTTP_204_NO_CONTENT)
def delete_flow(
    flow_id: UUID,
    db: Session = Depends(get_db),
    current_user: Profile = Depends(get_current_user),
):
    # Delete a flow by primary key and respond 204 with no body; 404
    # "flow not found" if it does not exist. ``current_user`` is not read in
    # the body (the dependency presumably enforces authentication).
    doomed = db.query(Flow).filter_by(id=flow_id).first()
    if not doomed:
        raise HTTPException(status_code=404, detail="flow not found")
    db.delete(doomed)
    db.commit()
@router.post("/{flow_id}/launch")
async def launch_flow(
    flow_id: str,
    payload: launchFlowPayload,
    db: Session = Depends(get_db),
    current_user: Profile = Depends(get_current_user),
):
    # Launch a stored flow on a set of sketch entities.
    #
    # Loads the flow and the sketch, requires "update" permission on the
    # sketch's investigation, fetches the selected entities from the graph,
    # computes the flow branches and enqueues the "run_flow" Celery task.
    # Returns {"id": <task id>}. Raises 404 if the flow or sketch is missing;
    # any other failure is logged and turned into a 500 whose detail carries
    # the error message.
    #
    # NOTE(review): none of the calls below are awaited (SQLAlchemy session,
    # graph service, celery.send_task), so this ``async def`` blocks the event
    # loop while it runs. A plain ``def`` would let FastAPI run it in its
    # threadpool like the other handlers here — confirm nothing awaits it
    # directly before changing.
    import logging  # local import keeps this handler self-contained

    try:
        flow = db.query(Flow).filter(Flow.id == flow_id).first()
        if flow is None:
            raise HTTPException(status_code=404, detail="flow not found")
        # Check investigation permission via sketch
        sketch = db.query(Sketch).filter(Sketch.id == payload.sketch_id).first()
        if not sketch:
            raise HTTPException(status_code=404, detail="Sketch not found")
        check_investigation_permission(
            current_user.id, sketch.investigation_id, actions=["update"], db=db
        )
        # Retrieve entities from Neo4J by their element IDs
        graph_service = create_graph_service(sketch_id=payload.sketch_id)
        entities = graph_service.get_nodes_by_ids_for_task(payload.node_ids)
        # Rebuild the graph from the stored flow schema
        nodes = [FlowNode(**node) for node in flow.flow_schema["nodes"]]
        edges = [FlowEdge(**edge) for edge in flow.flow_schema["edges"]]
        # JSON-compatible dicts, passed to the Celery task below
        entities = [
            entity.model_dump(mode="json", serialize_as_any=True)
            for entity in entities
        ]
        # Branch computation needs a sample value: use the first entity's label
        sample_value = (
            entities[0].get("nodeLabel", "sample_value") if entities else "sample_value"
        )
        flow_branches = compute_flow_branches(sample_value, nodes, edges)
        serializable_branches = [branch.model_dump() for branch in flow_branches]
        task = celery.send_task(
            "run_flow",
            args=[
                serializable_branches,
                entities,
                payload.sketch_id,
                str(current_user.id),
            ],
        )
        return {"id": task.id}
    except HTTPException:
        raise
    except Exception as e:
        # Log with the full traceback (print() dropped it) and keep the
        # original exception chained as the cause.
        logging.getLogger(__name__).exception("Error launching flow %s", flow_id)
        raise HTTPException(
            status_code=500, detail=f"Error launching flow: {str(e)}"
        ) from e
@router.post("/{flow_id}/compute", response_model=FlowComputationResponse)
def compute_flows(
    request: FlowComputationRequest,
    current_user: Profile = Depends(get_current_user),
):
    # Simulate the graph sent in the request body: generate a sample value for
    # ``request.inputType`` (default "string") and split the graph into
    # branches from it. The ``{flow_id}`` path segment is not read.
    sample_input = generate_sample_data(request.inputType or "string")
    return FlowComputationResponse(
        flowBranches=compute_flow_branches(sample_input, request.nodes, request.edges),
        initialData=sample_input,
    )
def generate_sample_data(type_str: str) -> Any:
    """Return a placeholder value for a flow input of the given type name.

    The lookup is case-insensitive; an empty/None ``type_str`` is treated as
    "string". Unknown type names yield ``"sample_<type>"`` (lower-cased).

    Args:
        type_str: Type name such as "string", "number", "email" or "ip".

    Returns:
        The sample value for that type.
    """
    normalized = type_str.lower() if type_str else "string"
    # Built on every call so callers always receive fresh list/dict objects.
    samples = {
        "string": "sample_text",
        "number": 42,
        "boolean": True,
        "array": [1, 2, 3],
        "object": {"key": "value"},
        "url": "https://example.com",
        "email": "user@example.com",
        "domain": "example.com",
        "ip": "192.168.1.1",
    }
    if normalized in samples:
        return samples[normalized]
    return f"sample_{normalized}"
def compute_flow_branches(
    initial_value: Any, nodes: List[FlowNode], edges: List[FlowEdge]
) -> List[FlowBranch]:
    """Split a flow graph into linear branches by depth-first traversal.

    Traversal starts at every node whose ``data["type"] == "type"`` (input
    nodes). At each node the outgoing edges are ordered by the shortest path
    length from their target to a leaf; the edge at index 0 continues the
    current branch and every later edge forks a new branch that copies the
    steps taken so far. A branch is recorded when a node with no outgoing
    edges is reached.

    Args:
        initial_value: Value assigned to each input node's first declared
            output.
        nodes: Flow nodes; ``node.data`` supplies "type", "outputs" and
            "params" (enricher nodes are passed to process_node_data).
        edges: Directed edges; ``sourceHandle``/``targetHandle`` choose which
            output feeds which input.

    Returns:
        The recorded branches sorted by number of steps (shortest first). If
        there is no input node, a single "error" branch with one error step.
    """
    # Find input nodes (starting points)
    input_nodes = [node for node in nodes if node.data.get("type") == "type"]
    if not input_nodes:
        return [
            FlowBranch(
                id="error",
                name="Error",
                steps=[
                    FlowStep(
                        nodeId="error",
                        inputs={},
                        type="error",
                        outputs={},
                        status="error",
                        branchId="error",
                        depth=0,
                    )
                ],
            )
        ]
    node_map = {node.id: node for node in nodes}
    branches = []
    # Shared by every input node's traversal (never reset), so fork suffixes
    # keep increasing across the whole graph.
    branch_counter = 0
    # Track enricher outputs across all branches
    enricher_outputs = {}

    def calculate_path_length(start_node: str, visited: Optional[set] = None) -> float:
        """Return the node count of the shortest path from ``start_node`` to a
        node with no outgoing edges (a leaf counts as 1).

        Paths that revisit a node count as infinite, so a node whose every
        path loops back returns ``float("inf")``.
        """
        # NOTE(review): not memoized and re-run for every edge on every sort,
        # scanning ``edges`` linearly each time — cost grows with the number of
        # distinct paths and may be slow on large, highly branched graphs.
        if visited is None:
            visited = set()
        if start_node in visited:
            return float("inf")
        visited.add(start_node)
        out_edges = [edge for edge in edges if edge.source == start_node]
        if not out_edges:
            return 1
        min_length = float("inf")
        for edge in out_edges:
            length = calculate_path_length(edge.target, visited.copy())
            min_length = min(min_length, length)
        return 1 + min_length

    def get_outgoing_edges(node_id: str) -> List[FlowEdge]:
        """Get outgoing edges sorted by the shortest possible path length"""
        out_edges = [edge for edge in edges if edge.source == node_id]
        # Sort edges by the length of the shortest possible path from their target
        return sorted(out_edges, key=lambda e: calculate_path_length(e.target))

    def create_step(
        node_id: str,
        branch_id: str,
        depth: int,
        input_data: Dict[str, Any],
        is_input_node: bool,
        outputs: Dict[str, Any],
        node_params: Optional[Dict[str, Any]] = None,
    ) -> FlowStep:
        """Build a pending FlowStep; input nodes record no inputs and get type
        "type", all other nodes get type "enricher"."""
        return FlowStep(
            nodeId=node_id,
            params=node_params,
            inputs={} if is_input_node else input_data,
            outputs=outputs,
            type="type" if is_input_node else "enricher",
            status="pending",
            branchId=branch_id,
            depth=depth,
        )

    def explore_branch(
        current_node_id: str,
        branch_id: str,
        branch_name: str,
        depth: int,
        input_data: Dict[str, Any],
        path: List[str],
        branch_visited: set,
        steps: List[FlowStep],
        parent_outputs: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Walk depth-first from ``current_node_id``, appending steps and
        recording finished branches in ``branches``.

        ``path`` and ``steps`` are mutated in place; the entries added for this
        node are popped again before returning. ``parent_outputs`` are the
        previous node's outputs, used as a fallback when the current node's
        outputs have no value for an edge's source handle.
        """
        # NOTE(review): ``branch_visited`` is only added to and copied, never
        # read; cycle detection relies solely on ``path``.
        nonlocal branch_counter
        # Skip if node is already in current path (cycle detection)
        if current_node_id in path:
            return
        current_node = node_map.get(current_node_id)
        if not current_node:
            return
        # Process node outputs
        is_input_node = current_node.data.get("type") == "type"
        if is_input_node:
            # Only the first declared output receives the initial value.
            outputs_array = current_node.data["outputs"].get("properties", [])
            first_output_name = (
                outputs_array[0].get("name", "output") if outputs_array else "output"
            )
            current_outputs = {first_output_name: initial_value}
        else:
            # Check if we already have outputs for this enricher
            # (cached per node id: a node reached again with different inputs
            # reuses the outputs computed on its first visit)
            if current_node_id in enricher_outputs:
                current_outputs = enricher_outputs[current_node_id]
            else:
                current_outputs = process_node_data(current_node, input_data)
                # Store the outputs for future use
                enricher_outputs[current_node_id] = current_outputs
        # Extract node parameters
        node_params = current_node.data.get("params", {})
        # Create and add current step
        current_step = create_step(
            current_node_id,
            branch_id,
            depth,
            input_data,
            is_input_node,
            current_outputs,
            node_params,
        )
        steps.append(current_step)
        path.append(current_node_id)
        branch_visited.add(current_node_id)
        # Get all outgoing edges sorted by path length
        out_edges = get_outgoing_edges(current_node_id)
        if not out_edges:
            # Leaf node reached, save the branch
            branches.append(FlowBranch(id=branch_id, name=branch_name, steps=steps[:]))
        else:
            # Process each outgoing edge in order of shortest path
            # NOTE(review): if the edge at index 0 is skipped as a cycle, no
            # edge continues ``branch_id`` (later edges fork new ids), and if
            # every edge is skipped this path is never recorded — confirm this
            # is intended.
            for i, edge in enumerate(out_edges):
                if edge.target in path:  # Skip if would create cycle
                    continue
                # Prepare next node's input: the value on the edge's source
                # handle, else the first output, else the parent's value.
                output_key = edge.sourceHandle
                if not output_key and current_outputs:
                    output_key = list(current_outputs.keys())[0]
                output_value = current_outputs.get(output_key) if output_key else None
                if output_value is None and parent_outputs:
                    output_value = (
                        parent_outputs.get(output_key) if output_key else None
                    )
                next_input = {edge.targetHandle or "input": output_value}
                if i == 0:
                    # Continue in same branch (will be shortest path)
                    explore_branch(
                        edge.target,
                        branch_id,
                        branch_name,
                        depth + 1,
                        next_input,
                        path,
                        branch_visited,
                        steps,
                        current_outputs,
                    )
                else:
                    # Create new branch starting from current node
                    branch_counter += 1
                    new_branch_id = f"{branch_id}-{branch_counter}"
                    new_branch_name = f"{branch_name} (Branch {branch_counter})"
                    new_steps = steps[: len(steps)]  # Copy steps up to current node
                    new_branch_visited = (
                        branch_visited.copy()
                    )  # Create new visited set for the branch
                    explore_branch(
                        edge.target,
                        new_branch_id,
                        new_branch_name,
                        depth + 1,
                        next_input,
                        path[:],  # Create new path copy for branch
                        new_branch_visited,
                        new_steps,
                        current_outputs,
                    )
        # Backtrack: remove current node from path and remove its step
        path.pop()
        steps.pop()

    # Start exploration from each input node
    for index, input_node in enumerate(input_nodes):
        branch_id = f"branch-{index}"
        branch_name = f"Flow {index + 1}" if len(input_nodes) > 1 else "Main Flow"
        explore_branch(
            input_node.id,
            branch_id,
            branch_name,
            0,
            {},
            [],  # Use list for path to maintain order
            set(),  # Use set for visited to check membership
            [],
            None,
        )
    # Sort branches by length (number of steps)
    branches.sort(key=lambda branch: len(branch.steps))
    return branches
def process_node_data(node: FlowNode, inputs: Dict[str, Any]) -> Dict[str, Any]:
    """Produce simulated outputs for an enricher node from its inputs.

    For every output declared in ``node.data["outputs"]["properties"]`` a
    placeholder value is generated according to ``node.data["class_name"]``.
    Known enrichers get fixed sample values (some echo ``inputs["input"]``);
    unknown enrichers pass ``inputs["input"]`` through, or return
    ``"flowed_<output name>"`` when that input is missing or falsy. No
    enricher is actually run.

    Args:
        node: The flow node; only ``node.data`` is read.
        inputs: Input values keyed by target handle; only "input" is read.

    Returns:
        A dict mapping each declared output name to its placeholder value.
        A node with missing/None "outputs" or "properties" yields ``{}``.
    """
    outputs = {}
    # The class name does not depend on the output, so resolve it once.
    class_name = node.data.get("class_name", "")
    output_types = (node.data.get("outputs") or {}).get("properties") or []
    for output in output_types:
        output_name = output.get("name", "output")
        # For simulation purposes, return a placeholder value based on the enricher type
        if class_name in ["ReverseResolveEnricher", "ResolveEnricher"]:
            # IP/Domain resolution enrichers
            outputs[output_name] = (
                "192.168.1.1" if "ip" in output_name.lower() else "example.com"
            )
        elif class_name == "SubdomainEnricher":
            # Subdomain enricher
            outputs[output_name] = f"sub.{inputs.get('input', 'example.com')}"
        elif class_name == "WhoisEnricher":
            # WHOIS enricher
            outputs[output_name] = {
                "domain": inputs.get("input", "example.com"),
                "registrar": "Example Registrar",
                "creation_date": "2020-01-01",
            }
        elif class_name == "IpToInfosEnricher":
            # Geolocation enricher
            outputs[output_name] = {
                "country": "France",
                "city": "Paris",
                "coordinates": {"lat": 48.8566, "lon": 2.3522},
            }
        elif class_name == "MaigretEnricher":
            # Social media enricher
            outputs[output_name] = {
                "username": inputs.get("input", "user123"),
                "platforms": ["twitter", "github", "linkedin"],
            }
        elif class_name == "HoleheEnricher":
            # Email verification enricher
            outputs[output_name] = {
                "email": inputs.get("input", "user@example.com"),
                "exists": True,
                "platforms": ["gmail", "github"],
            }
        elif class_name == "SireneEnricher":
            # Organization enricher
            outputs[output_name] = {
                "name": inputs.get("input", "Example Corp"),
                "siret": "12345678901234",
                "address": "1 Example Street",
            }
        else:
            # For unknown enrichers, pass through the input
            outputs[output_name] = inputs.get("input") or f"flowed_{output_name}"
    return outputs