mirror of
https://github.com/reconurge/flowsint.git
synced 2026-05-07 20:28:48 -05:00
Compare commits
12 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1ba3c0bd6c | ||
|
|
2fec5b82e9 | ||
|
|
1cc88d1033 | ||
|
|
2372678bda | ||
|
|
fd20468c79 | ||
|
|
427035fd41 | ||
|
|
152be112d0 | ||
|
|
01908e682b | ||
|
|
20a3d35909 | ||
|
|
16c9497268 | ||
|
|
12ef177127 | ||
|
|
14f79fecc5 |
8
.github/workflows/images.yml
vendored
8
.github/workflows/images.yml
vendored
@@ -21,7 +21,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
@@ -78,7 +78,7 @@ jobs:
|
||||
severity: "CRITICAL,HIGH"
|
||||
|
||||
- name: Upload Trivy scan results
|
||||
uses: github/codeql-action/upload-sarif@v3
|
||||
uses: github/codeql-action/upload-sarif@v4
|
||||
if: always() && steps.trivy.outcome == 'success'
|
||||
with:
|
||||
sarif_file: "trivy-frontend.sarif"
|
||||
@@ -94,7 +94,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
@@ -152,7 +152,7 @@ jobs:
|
||||
severity: "CRITICAL,HIGH"
|
||||
|
||||
- name: Upload Trivy scan results
|
||||
uses: github/codeql-action/upload-sarif@v3
|
||||
uses: github/codeql-action/upload-sarif@v4
|
||||
if: always() && steps.trivy.outcome == 'success'
|
||||
with:
|
||||
sarif_file: "trivy-backend.sarif"
|
||||
|
||||
@@ -117,7 +117,7 @@ WORKDIR /app/flowsint-api
|
||||
RUN chmod +x entrypoint.sh
|
||||
|
||||
# Switch to non-root user
|
||||
USER flowsint
|
||||
# USER flowsint
|
||||
|
||||
EXPOSE 5001
|
||||
|
||||
|
||||
@@ -1,42 +1,30 @@
|
||||
from uuid import UUID, uuid4
|
||||
from app.security.permissions import check_investigation_permission
|
||||
from uuid import UUID
|
||||
from fastapi import APIRouter, HTTPException, Depends, status
|
||||
from typing import List
|
||||
from datetime import datetime
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from flowsint_core.core.postgre_db import get_db
|
||||
from flowsint_core.core.models import Analysis, Profile, InvestigationUserRole
|
||||
from flowsint_core.core.types import Role
|
||||
from flowsint_core.core.models import Profile
|
||||
from flowsint_core.core.services import (
|
||||
create_analysis_service,
|
||||
NotFoundError,
|
||||
PermissionDeniedError,
|
||||
)
|
||||
from app.api.deps import get_current_user
|
||||
from app.api.schemas.analysis import AnalysisRead, AnalysisCreate, AnalysisUpdate
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# Get the list of all analyses for the current user
|
||||
@router.get("", response_model=List[AnalysisRead])
|
||||
def get_analyses(
|
||||
db: Session = Depends(get_db), current_user: Profile = Depends(get_current_user)
|
||||
):
|
||||
# Get all analyses from investigations where user has at least VIEWER role
|
||||
allowed_roles_for_read = [Role.OWNER, Role.EDITOR, Role.VIEWER]
|
||||
|
||||
query = db.query(Analysis).join(
|
||||
InvestigationUserRole,
|
||||
InvestigationUserRole.investigation_id == Analysis.investigation_id,
|
||||
)
|
||||
|
||||
query = query.filter(InvestigationUserRole.user_id == current_user.id)
|
||||
|
||||
# Filter by allowed roles
|
||||
conditions = [InvestigationUserRole.roles.any(role) for role in allowed_roles_for_read]
|
||||
query = query.filter(or_(*conditions))
|
||||
|
||||
return query.distinct().all()
|
||||
"""Get all analyses accessible to the current user."""
|
||||
service = create_analysis_service(db)
|
||||
return service.get_accessible_analyses(current_user.id)
|
||||
|
||||
|
||||
# Create a New analysis
|
||||
@router.post(
|
||||
"/create", response_model=AnalysisRead, status_code=status.HTTP_201_CREATED
|
||||
)
|
||||
@@ -45,64 +33,47 @@ def create_analysis(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
check_investigation_permission(
|
||||
current_user.id, payload.investigation_id, actions=["create"], db=db
|
||||
)
|
||||
new_analysis = Analysis(
|
||||
id=uuid4(),
|
||||
title=payload.title,
|
||||
description=payload.description,
|
||||
content=payload.content,
|
||||
owner_id=current_user.id,
|
||||
investigation_id=payload.investigation_id,
|
||||
created_at=datetime.utcnow(),
|
||||
last_updated_at=datetime.utcnow(),
|
||||
)
|
||||
db.add(new_analysis)
|
||||
db.commit()
|
||||
db.refresh(new_analysis)
|
||||
return new_analysis
|
||||
service = create_analysis_service(db)
|
||||
try:
|
||||
return service.create(
|
||||
title=payload.title,
|
||||
description=payload.description,
|
||||
content=payload.content,
|
||||
investigation_id=payload.investigation_id,
|
||||
owner_id=current_user.id,
|
||||
)
|
||||
except PermissionDeniedError:
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
|
||||
|
||||
# Get an analysis by ID
|
||||
@router.get("/{analysis_id}", response_model=AnalysisRead)
|
||||
def get_analysis_by_id(
|
||||
analysis_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
analysis = (
|
||||
db.query(Analysis)
|
||||
.filter(Analysis.id == analysis_id)
|
||||
.first()
|
||||
)
|
||||
if not analysis:
|
||||
service = create_analysis_service(db)
|
||||
try:
|
||||
return service.get_by_id(analysis_id, current_user.id)
|
||||
except NotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Analysis not found")
|
||||
check_investigation_permission(
|
||||
current_user.id, analysis.investigation_id, actions=["read"], db=db
|
||||
)
|
||||
return analysis
|
||||
except PermissionDeniedError:
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
|
||||
|
||||
# Get analyses by investigation ID
|
||||
@router.get("/investigation/{investigation_id}", response_model=List[AnalysisRead])
|
||||
def get_analyses_by_investigation(
|
||||
investigation_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
check_investigation_permission(
|
||||
current_user.id, investigation_id, actions=["read"], db=db
|
||||
)
|
||||
analyses = (
|
||||
db.query(Analysis)
|
||||
.filter(Analysis.investigation_id == investigation_id)
|
||||
.all()
|
||||
)
|
||||
return analyses
|
||||
service = create_analysis_service(db)
|
||||
try:
|
||||
return service.get_by_investigation(investigation_id, current_user.id)
|
||||
except PermissionDeniedError:
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
|
||||
|
||||
# Update an analysis by ID
|
||||
@router.put("/{analysis_id}", response_model=AnalysisRead)
|
||||
def update_analysis(
|
||||
analysis_id: UUID,
|
||||
@@ -110,51 +81,33 @@ def update_analysis(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
analysis = (
|
||||
db.query(Analysis)
|
||||
.filter(Analysis.id == analysis_id)
|
||||
.first()
|
||||
)
|
||||
if not analysis:
|
||||
raise HTTPException(status_code=404, detail="Analysis not found")
|
||||
check_investigation_permission(
|
||||
current_user.id, analysis.investigation_id, actions=["update"], db=db
|
||||
)
|
||||
if payload.title is not None:
|
||||
analysis.title = payload.title
|
||||
if payload.description is not None:
|
||||
analysis.description = payload.description
|
||||
if payload.content is not None:
|
||||
analysis.content = payload.content
|
||||
if payload.investigation_id is not None:
|
||||
# Check permission for the new investigation as well
|
||||
check_investigation_permission(
|
||||
current_user.id, payload.investigation_id, actions=["update"], db=db
|
||||
service = create_analysis_service(db)
|
||||
try:
|
||||
return service.update(
|
||||
analysis_id=analysis_id,
|
||||
user_id=current_user.id,
|
||||
title=payload.title,
|
||||
description=payload.description,
|
||||
content=payload.content,
|
||||
investigation_id=payload.investigation_id,
|
||||
)
|
||||
analysis.investigation_id = payload.investigation_id
|
||||
analysis.last_updated_at = datetime.utcnow()
|
||||
db.commit()
|
||||
db.refresh(analysis)
|
||||
return analysis
|
||||
except NotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Analysis not found")
|
||||
except PermissionDeniedError:
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
|
||||
|
||||
# Delete an analysis by ID
|
||||
@router.delete("/{analysis_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
def delete_analysis(
|
||||
analysis_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
analysis = (
|
||||
db.query(Analysis)
|
||||
.filter(Analysis.id == analysis_id)
|
||||
.first()
|
||||
)
|
||||
if not analysis:
|
||||
service = create_analysis_service(db)
|
||||
try:
|
||||
service.delete(analysis_id, current_user.id)
|
||||
return None
|
||||
except NotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Analysis not found")
|
||||
check_investigation_permission(
|
||||
current_user.id, analysis.investigation_id, actions=["delete"], db=db
|
||||
)
|
||||
db.delete(analysis)
|
||||
db.commit()
|
||||
return None
|
||||
except PermissionDeniedError:
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.exc import SQLAlchemyError, IntegrityError
|
||||
from flowsint_core.core.auth import (
|
||||
verify_password,
|
||||
create_access_token,
|
||||
get_password_hash,
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from flowsint_core.core.services import (
|
||||
create_auth_service,
|
||||
AuthenticationError,
|
||||
ConflictError,
|
||||
DatabaseError,
|
||||
)
|
||||
from app.api.schemas.profile import ProfileCreate
|
||||
from flowsint_core.core.models import Profile
|
||||
from flowsint_core.core.postgre_db import get_db
|
||||
|
||||
router = APIRouter()
|
||||
@@ -18,50 +19,23 @@ router = APIRouter()
|
||||
def login_for_access_token(
|
||||
form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)
|
||||
):
|
||||
service = create_auth_service(db)
|
||||
try:
|
||||
user = db.query(Profile).filter(Profile.email == form_data.username).first()
|
||||
|
||||
if not user or not verify_password(form_data.password, user.hashed_password):
|
||||
raise HTTPException(status_code=400, detail="Incorrect email or password")
|
||||
|
||||
access_token = create_access_token(data={"sub": user.email})
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"user_id": user.id,
|
||||
"token_type": "bearer",
|
||||
}
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
# Log optionnel
|
||||
return service.authenticate(form_data.username, form_data.password)
|
||||
except AuthenticationError:
|
||||
raise HTTPException(status_code=400, detail="Incorrect email or password")
|
||||
except (DatabaseError, SQLAlchemyError) as e:
|
||||
print(f"[ERROR] DB error during login: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.post("/register", status_code=201)
|
||||
def register(user: ProfileCreate, db: Session = Depends(get_db)):
|
||||
service = create_auth_service(db)
|
||||
try:
|
||||
existing_user = db.query(Profile).filter(Profile.email == user.email).first()
|
||||
|
||||
if existing_user:
|
||||
raise HTTPException(status_code=400, detail="Email already registered")
|
||||
|
||||
hashed_password = get_password_hash(user.password)
|
||||
new_user = Profile(email=user.email, hashed_password=hashed_password)
|
||||
|
||||
db.add(new_user)
|
||||
db.commit()
|
||||
db.refresh(new_user)
|
||||
|
||||
return {
|
||||
"message": "User registered successfully",
|
||||
"email": new_user.email,
|
||||
}
|
||||
|
||||
except IntegrityError:
|
||||
db.rollback()
|
||||
return service.register(user.email, user.password)
|
||||
except ConflictError:
|
||||
raise HTTPException(status_code=400, detail="Email already registered")
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
db.rollback()
|
||||
except (DatabaseError, SQLAlchemyError) as e:
|
||||
print(f"[ERROR] DB error during registration: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi import APIRouter, HTTPException, Depends, status
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
import os
|
||||
@@ -6,90 +6,43 @@ from mistralai import Mistral
|
||||
from mistralai.models import UserMessage, AssistantMessage, SystemMessage
|
||||
import json
|
||||
from uuid import UUID, uuid4
|
||||
from fastapi import APIRouter, HTTPException, Depends, status
|
||||
from typing import Dict, List, Optional
|
||||
from datetime import datetime
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from flowsint_core.core.postgre_db import get_db
|
||||
from flowsint_core.core.models import Chat, ChatMessage, Profile, InvestigationUserRole
|
||||
from flowsint_core.core.types import Role
|
||||
from flowsint_core.core.models import ChatMessage, Profile
|
||||
from flowsint_core.core.services import (
|
||||
create_chat_service,
|
||||
NotFoundError,
|
||||
)
|
||||
from app.api.deps import get_current_user
|
||||
from app.api.schemas.chat import ChatCreate, ChatRead
|
||||
from app.security.permissions import check_investigation_permission
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def clean_context(context: List[Dict]) -> List[Dict]:
|
||||
print(context)
|
||||
"""Remove unnecessary keys from context data."""
|
||||
cleaned = []
|
||||
for item in context:
|
||||
if isinstance(item, dict):
|
||||
# Create a copy and remove unwanted keys
|
||||
cleaned_item = item["data"].copy()
|
||||
# Remove top-level keys
|
||||
cleaned_item.pop("id", None)
|
||||
cleaned_item.pop("sketch_id", None)
|
||||
# Remove from data if it exists
|
||||
if "data" in cleaned_item and isinstance(cleaned_item["data"], dict):
|
||||
cleaned_item["data"].pop("sketch_id", None)
|
||||
# Remove measured/dimensions
|
||||
cleaned_item.pop("measured", None)
|
||||
cleaned.append(cleaned_item)
|
||||
return cleaned
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
prompt: str
|
||||
context: Optional[List[Dict]] = None
|
||||
|
||||
|
||||
# Get all chats
|
||||
@router.get("/", response_model=List[ChatRead])
|
||||
def get_chats(
|
||||
db: Session = Depends(get_db), current_user: Profile = Depends(get_current_user)
|
||||
):
|
||||
chats = db.query(Chat).filter(Chat.owner_id == current_user.id).all()
|
||||
|
||||
# Sort messages for each chat by created_at in ascending order
|
||||
for chat in chats:
|
||||
chat.messages.sort(key=lambda x: x.created_at)
|
||||
|
||||
return chats
|
||||
service = create_chat_service(db)
|
||||
return service.get_chats_for_user(current_user.id)
|
||||
|
||||
|
||||
# @router.get("/delete-all", status_code=status.HTTP_204_NO_CONTENT)
|
||||
# def delete_all_chat(db: Session = Depends(get_db)):
|
||||
# chats = db.query(Chat).all()
|
||||
# for chat in chats:
|
||||
# db.delete(chat)
|
||||
# db.commit()
|
||||
# return {"result": "done"}
|
||||
|
||||
|
||||
# Get analyses by investigation ID
|
||||
@router.get("/investigation/{investigation_id}", response_model=List[ChatRead])
|
||||
def get_chats_by_investigation(
|
||||
investigation_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
chats = (
|
||||
db.query(Chat)
|
||||
.filter(
|
||||
Chat.investigation_id == investigation_id, Chat.owner_id == current_user.id
|
||||
)
|
||||
.order_by(Chat.created_at.asc())
|
||||
.all()
|
||||
)
|
||||
|
||||
# Sort messages for each chat by created_at in ascending order
|
||||
for chat in chats:
|
||||
chat.messages.sort(key=lambda x: x.created_at)
|
||||
|
||||
return chats
|
||||
service = create_chat_service(db)
|
||||
return service.get_by_investigation(investigation_id, current_user.id)
|
||||
|
||||
|
||||
@router.post("/stream/{chat_id}")
|
||||
@@ -99,30 +52,18 @@ async def stream_chat(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
# Check if Chat exists
|
||||
chat = (
|
||||
db.query(Chat)
|
||||
.filter(Chat.id == chat_id, Chat.owner_id == current_user.id)
|
||||
.first()
|
||||
)
|
||||
if not chat:
|
||||
service = create_chat_service(db)
|
||||
|
||||
try:
|
||||
chat = service.get_by_id(chat_id, current_user.id)
|
||||
except NotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Chat not found")
|
||||
|
||||
# Update chat's last_updated_at
|
||||
chat.last_updated_at = datetime.utcnow()
|
||||
db.commit()
|
||||
# A new message is created
|
||||
user_message = ChatMessage(
|
||||
id=uuid4(),
|
||||
content=payload.prompt,
|
||||
context=payload.context,
|
||||
chat_id=chat_id,
|
||||
is_bot=False,
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
db.add(user_message)
|
||||
db.commit()
|
||||
db.refresh(user_message)
|
||||
# Add user message
|
||||
service.add_user_message(chat_id, current_user.id, payload.prompt, payload.context)
|
||||
|
||||
# Prepare AI context
|
||||
ai_context = service.prepare_ai_context(chat, payload.prompt, payload.context)
|
||||
|
||||
try:
|
||||
api_key = os.environ.get("MISTRAL_API_KEY")
|
||||
@@ -134,35 +75,15 @@ async def stream_chat(
|
||||
client = Mistral(api_key=api_key)
|
||||
model = "mistral-small-latest"
|
||||
accumulated_content = []
|
||||
context_message = None
|
||||
# Convert database messages to Mistral format
|
||||
|
||||
# Build messages for Mistral
|
||||
messages = [
|
||||
SystemMessage(
|
||||
content="You are a CTI/OSINT investigator and you are trying to investigate on a variety of real life cases. Use your knowledge and analytics capabilities to analyse the context and answer the question the best you can. If you need to reference some items (an IP, a domain or something particular) please use the code brackets, like : `12.23.34.54` to reference it."
|
||||
)
|
||||
]
|
||||
|
||||
# Add context as a single system message if provided
|
||||
if payload.context:
|
||||
try:
|
||||
# Clean context by removing unnecessary keys
|
||||
cleaned_context = clean_context(payload.context)
|
||||
if cleaned_context:
|
||||
context_str = json.dumps(cleaned_context, indent=2, default=str)
|
||||
context_message = f"Context: {context_str}"
|
||||
# Limit context message length to avoid token limits
|
||||
if len(context_message) > 2000:
|
||||
context_message = context_message[:2000] + "..."
|
||||
except Exception as e:
|
||||
# If context processing fails, skip it
|
||||
print(f"Context processing error: {e}")
|
||||
|
||||
# Sort messages by created_at in ascending order and get recent messages
|
||||
sorted_messages = sorted(chat.messages, key=lambda x: x.created_at)
|
||||
recent_messages = (
|
||||
sorted_messages[-5:] if len(sorted_messages) > 5 else sorted_messages
|
||||
)
|
||||
for message in recent_messages:
|
||||
for message in ai_context["recent_messages"]:
|
||||
if message.is_bot:
|
||||
messages.append(
|
||||
AssistantMessage(content=json.dumps(message.content, default=str))
|
||||
@@ -172,11 +93,10 @@ async def stream_chat(
|
||||
UserMessage(content=json.dumps(message.content, default=str))
|
||||
)
|
||||
|
||||
# Add the current context
|
||||
if context_message:
|
||||
messages.append(SystemMessage(content=context_message))
|
||||
# Add the current user message
|
||||
messages.append(UserMessage(content=payload.prompt))
|
||||
if ai_context["context_message"]:
|
||||
messages.append(SystemMessage(content=ai_context["context_message"]))
|
||||
|
||||
messages.append(UserMessage(content=ai_context["user_prompt"]))
|
||||
|
||||
async def generate():
|
||||
response = await client.chat.stream_async(model=model, messages=messages)
|
||||
@@ -187,18 +107,8 @@ async def stream_chat(
|
||||
accumulated_content.append(content_chunk)
|
||||
yield f"data: {json.dumps({'content': content_chunk})}\n\n"
|
||||
|
||||
# Create the bot message after all chunks have been processed
|
||||
chat_message = ChatMessage(
|
||||
id=uuid4(),
|
||||
content="".join(accumulated_content),
|
||||
chat_id=chat_id,
|
||||
is_bot=True,
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
db.add(chat_message)
|
||||
db.commit()
|
||||
db.refresh(chat_message)
|
||||
# Save bot message after streaming completes
|
||||
service.add_bot_message(chat_id, "".join(accumulated_content))
|
||||
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
@@ -207,63 +117,43 @@ async def stream_chat(
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# Create a new chat
|
||||
@router.post("/create", response_model=ChatRead, status_code=status.HTTP_201_CREATED)
|
||||
def create_chat(
|
||||
payload: ChatCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
new_chat = Chat(
|
||||
id=uuid4(),
|
||||
service = create_chat_service(db)
|
||||
return service.create(
|
||||
title=payload.title,
|
||||
description=payload.description,
|
||||
owner_id=current_user.id,
|
||||
investigation_id=payload.investigation_id,
|
||||
created_at=datetime.utcnow(),
|
||||
last_updated_at=datetime.utcnow(),
|
||||
owner_id=current_user.id,
|
||||
)
|
||||
db.add(new_chat)
|
||||
db.commit()
|
||||
db.refresh(new_chat)
|
||||
return new_chat
|
||||
|
||||
|
||||
# Get a chat by ID
|
||||
@router.get("/{chat_id}", response_model=ChatRead)
|
||||
def get_chat_by_id(
|
||||
chat_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
chat = (
|
||||
db.query(Chat)
|
||||
.filter(Chat.id == chat_id, Chat.owner_id == current_user.id)
|
||||
.first()
|
||||
)
|
||||
if not chat:
|
||||
service = create_chat_service(db)
|
||||
try:
|
||||
return service.get_by_id(chat_id, current_user.id)
|
||||
except NotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Chat not found")
|
||||
|
||||
# Sort messages by created_at in ascending order
|
||||
chat.messages.sort(key=lambda x: x.created_at)
|
||||
|
||||
return chat
|
||||
|
||||
|
||||
# Delete an chat by ID
|
||||
@router.delete("/{chat_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
def delete_chat(
|
||||
chat_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
chat = (
|
||||
db.query(Chat)
|
||||
.filter(Chat.id == chat_id, Chat.owner_id == current_user.id)
|
||||
.first()
|
||||
)
|
||||
if not chat:
|
||||
service = create_chat_service(db)
|
||||
try:
|
||||
service.delete(chat_id, current_user.id)
|
||||
return None
|
||||
except NotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Chat not found")
|
||||
db.delete(chat)
|
||||
db.commit()
|
||||
return None
|
||||
|
||||
@@ -3,8 +3,15 @@ from uuid import UUID
|
||||
from typing import List
|
||||
from fastapi import APIRouter, HTTPException, Depends, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from flowsint_core.core.postgre_db import get_db
|
||||
from flowsint_core.core.models import CustomType, Profile
|
||||
from flowsint_core.core.models import Profile
|
||||
from flowsint_core.core.services import (
|
||||
create_custom_type_service,
|
||||
NotFoundError,
|
||||
ValidationError,
|
||||
ConflictError,
|
||||
)
|
||||
from app.api.deps import get_current_user
|
||||
from app.api.schemas.custom_type import (
|
||||
CustomTypeCreate,
|
||||
@@ -29,47 +36,22 @@ def create_custom_type(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Create a new custom type.
|
||||
|
||||
Validates the JSON Schema and stores it in the database.
|
||||
"""
|
||||
# Validate the JSON Schema
|
||||
validate_json_schema(custom_type.json_schema)
|
||||
|
||||
# Calculate checksum
|
||||
checksum = calculate_schema_checksum(custom_type.json_schema)
|
||||
|
||||
# Check for duplicate name for this user
|
||||
existing = (
|
||||
db.query(CustomType)
|
||||
.filter(
|
||||
CustomType.owner_id == current_user.id,
|
||||
CustomType.name == custom_type.name
|
||||
"""Create a new custom type."""
|
||||
service = create_custom_type_service(db)
|
||||
try:
|
||||
return service.create(
|
||||
name=custom_type.name,
|
||||
json_schema=custom_type.json_schema,
|
||||
user_id=current_user.id,
|
||||
description=custom_type.description,
|
||||
status=custom_type.status,
|
||||
validate_schema_func=validate_json_schema,
|
||||
calculate_checksum_func=calculate_schema_checksum,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Custom type with name '{custom_type.name}' already exists"
|
||||
)
|
||||
|
||||
# Create the custom type
|
||||
db_custom_type = CustomType(
|
||||
name=custom_type.name,
|
||||
owner_id=current_user.id,
|
||||
schema=custom_type.json_schema,
|
||||
description=custom_type.description,
|
||||
status=custom_type.status,
|
||||
checksum=checksum,
|
||||
)
|
||||
|
||||
db.add(db_custom_type)
|
||||
db.commit()
|
||||
db.refresh(db_custom_type)
|
||||
|
||||
return db_custom_type
|
||||
except ValidationError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except ConflictError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.get("", response_model=List[CustomTypeRead])
|
||||
@@ -78,23 +60,12 @@ def list_custom_types(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
List all custom types for the current user.
|
||||
|
||||
Can be filtered by status (draft, published, archived).
|
||||
"""
|
||||
query = db.query(CustomType).filter(CustomType.owner_id == current_user.id)
|
||||
|
||||
if status:
|
||||
if status not in ["draft", "published", "archived"]:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Status must be one of: draft, published, archived"
|
||||
)
|
||||
query = query.filter(CustomType.status == status)
|
||||
|
||||
custom_types = query.order_by(CustomType.created_at.desc()).all()
|
||||
return custom_types
|
||||
"""List all custom types for the current user."""
|
||||
service = create_custom_type_service(db)
|
||||
try:
|
||||
return service.list_custom_types(current_user.id, status)
|
||||
except ValidationError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/{id}", response_model=CustomTypeRead)
|
||||
@@ -104,19 +75,11 @@ def get_custom_type(
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
"""Get a specific custom type by ID."""
|
||||
custom_type = (
|
||||
db.query(CustomType)
|
||||
.filter(CustomType.id == id, CustomType.owner_id == current_user.id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not custom_type:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Custom type not found"
|
||||
)
|
||||
|
||||
return custom_type
|
||||
service = create_custom_type_service(db)
|
||||
try:
|
||||
return service.get_by_id(id, current_user.id)
|
||||
except NotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Custom type not found")
|
||||
|
||||
|
||||
@router.get("/{id}/schema")
|
||||
@@ -126,19 +89,11 @@ def get_custom_type_schema(
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
"""Get the raw JSON Schema for a custom type."""
|
||||
custom_type = (
|
||||
db.query(CustomType)
|
||||
.filter(CustomType.id == id, CustomType.owner_id == current_user.id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not custom_type:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Custom type not found"
|
||||
)
|
||||
|
||||
return custom_type.schema
|
||||
service = create_custom_type_service(db)
|
||||
try:
|
||||
return service.get_schema(id, current_user.id)
|
||||
except NotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Custom type not found")
|
||||
|
||||
|
||||
@router.put("/{id}", response_model=CustomTypeRead)
|
||||
@@ -148,58 +103,25 @@ def update_custom_type(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Update a custom type.
|
||||
|
||||
If the schema is changed, a new checksum is calculated.
|
||||
"""
|
||||
custom_type = (
|
||||
db.query(CustomType)
|
||||
.filter(CustomType.id == id, CustomType.owner_id == current_user.id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not custom_type:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Custom type not found"
|
||||
"""Update a custom type."""
|
||||
service = create_custom_type_service(db)
|
||||
try:
|
||||
return service.update(
|
||||
custom_type_id=id,
|
||||
user_id=current_user.id,
|
||||
name=update_data.name,
|
||||
json_schema=update_data.json_schema,
|
||||
description=update_data.description,
|
||||
status=update_data.status,
|
||||
validate_schema_func=validate_json_schema,
|
||||
calculate_checksum_func=calculate_schema_checksum,
|
||||
)
|
||||
|
||||
# Update fields
|
||||
if update_data.name is not None:
|
||||
# Check for duplicate name
|
||||
existing = (
|
||||
db.query(CustomType)
|
||||
.filter(
|
||||
CustomType.owner_id == current_user.id,
|
||||
CustomType.name == update_data.name,
|
||||
CustomType.id != id
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Custom type with name '{update_data.name}' already exists"
|
||||
)
|
||||
custom_type.name = update_data.name
|
||||
|
||||
if update_data.json_schema is not None:
|
||||
# Validate the new schema
|
||||
validate_json_schema(update_data.json_schema)
|
||||
custom_type.schema = update_data.json_schema
|
||||
custom_type.checksum = calculate_schema_checksum(update_data.json_schema)
|
||||
|
||||
if update_data.description is not None:
|
||||
custom_type.description = update_data.description
|
||||
|
||||
if update_data.status is not None:
|
||||
custom_type.status = update_data.status
|
||||
|
||||
db.commit()
|
||||
db.refresh(custom_type)
|
||||
|
||||
return custom_type
|
||||
except NotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Custom type not found")
|
||||
except ValidationError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except ConflictError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/{id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
@@ -209,22 +131,12 @@ def delete_custom_type(
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
"""Delete a custom type."""
|
||||
custom_type = (
|
||||
db.query(CustomType)
|
||||
.filter(CustomType.id == id, CustomType.owner_id == current_user.id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not custom_type:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Custom type not found"
|
||||
)
|
||||
|
||||
db.delete(custom_type)
|
||||
db.commit()
|
||||
|
||||
return None
|
||||
service = create_custom_type_service(db)
|
||||
try:
|
||||
service.delete(id, current_user.id)
|
||||
return None
|
||||
except NotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Custom type not found")
|
||||
|
||||
|
||||
@router.post("/{id}/validate", response_model=CustomTypeValidateResponse)
|
||||
@@ -234,30 +146,18 @@ def validate_payload(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Validate a payload against a custom type's schema.
|
||||
|
||||
This is useful for testing before publishing a type.
|
||||
"""
|
||||
custom_type = (
|
||||
db.query(CustomType)
|
||||
.filter(CustomType.id == id, CustomType.owner_id == current_user.id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not custom_type:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Custom type not found"
|
||||
"""Validate a payload against a custom type's schema."""
|
||||
service = create_custom_type_service(db)
|
||||
try:
|
||||
is_valid, errors = service.validate_payload(
|
||||
id,
|
||||
current_user.id,
|
||||
payload_data.payload,
|
||||
validate_payload_func=validate_payload_against_schema,
|
||||
)
|
||||
|
||||
# Validate the payload against the schema
|
||||
is_valid, errors = validate_payload_against_schema(
|
||||
payload_data.payload,
|
||||
custom_type.schema
|
||||
)
|
||||
|
||||
return CustomTypeValidateResponse(
|
||||
valid=is_valid,
|
||||
errors=errors if not is_valid else None
|
||||
)
|
||||
return CustomTypeValidateResponse(
|
||||
valid=is_valid,
|
||||
errors=errors if not is_valid else None,
|
||||
)
|
||||
except NotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Custom type not found")
|
||||
|
||||
@@ -1,38 +1,20 @@
|
||||
from typing import Any, List, Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from flowsint_core.core.celery import celery
|
||||
from flowsint_core.core.graph import create_graph_service, GraphEdge, GraphNode
|
||||
from flowsint_core.core.models import CustomType, Profile
|
||||
from flowsint_core.core.postgre_db import get_db
|
||||
from flowsint_core.core.types import FlowBranch
|
||||
from flowsint_enrichers import ENRICHER_REGISTRY, load_all_enrichers
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from flowsint_core.core.celery import celery
|
||||
from flowsint_core.core.graph import create_graph_service
|
||||
from flowsint_core.core.models import Profile
|
||||
from flowsint_core.core.postgre_db import get_db
|
||||
from flowsint_core.core.services import create_enricher_service
|
||||
from flowsint_enrichers import ENRICHER_REGISTRY, load_all_enrichers
|
||||
from app.api.deps import get_current_user
|
||||
|
||||
# Auto-discover and register all enrichers
|
||||
load_all_enrichers()
|
||||
|
||||
|
||||
class FlowComputationRequest(BaseModel):
|
||||
nodes: List[GraphNode]
|
||||
edges: List[GraphEdge]
|
||||
inputType: Optional[str] = None
|
||||
|
||||
|
||||
class FlowComputationResponse(BaseModel):
|
||||
flowBranches: List[FlowBranch]
|
||||
initialData: Any
|
||||
|
||||
|
||||
class StepSimulationRequest(BaseModel):
|
||||
flowBranches: List[FlowBranch]
|
||||
currentStepIndex: int
|
||||
|
||||
|
||||
class launchEnricherPayload(BaseModel):
|
||||
node_ids: List[str]
|
||||
sketch_id: str
|
||||
@@ -41,30 +23,15 @@ class launchEnricherPayload(BaseModel):
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# Get the list of all enrichers
|
||||
@router.get("/")
|
||||
def get_enrichers(
|
||||
category: Optional[str] = Query(None),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
if not category or category.lower() == "undefined":
|
||||
return ENRICHER_REGISTRY.list(exclude=["n8n_connector"])
|
||||
# Si catégorie custom
|
||||
custom_type = (
|
||||
db.query(CustomType)
|
||||
.filter(
|
||||
CustomType.owner_id == current_user.id,
|
||||
CustomType.status == "published",
|
||||
func.lower(CustomType.name) == category.lower(),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if custom_type:
|
||||
return ENRICHER_REGISTRY.list(exclude=["n8n_connector"], wobbly_type=True)
|
||||
|
||||
return ENRICHER_REGISTRY.list_by_input_type(category, exclude=["n8n_connector"])
|
||||
"""Get all enrichers, optionally filtered by category."""
|
||||
service = create_enricher_service(db)
|
||||
return service.get_enrichers(category, current_user.id, ENRICHER_REGISTRY)
|
||||
|
||||
|
||||
@router.post("/{enricher_name}/launch")
|
||||
@@ -78,12 +45,13 @@ async def launch_enricher(
|
||||
graph_service = create_graph_service(sketch_id=payload.sketch_id)
|
||||
entities = graph_service.get_nodes_by_ids_for_task(payload.node_ids)
|
||||
|
||||
# send deserialized nodes
|
||||
# Send deserialized nodes
|
||||
entities = [entity.model_dump(mode="json", serialize_as_any=True) for entity in entities]
|
||||
if not entities:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="No entities found with provided IDs"
|
||||
)
|
||||
|
||||
task = celery.send_task(
|
||||
"run_enricher",
|
||||
args=[
|
||||
|
||||
@@ -1,16 +1,20 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from sqlalchemy.orm import Session
|
||||
from flowsint_core.core.postgre_db import get_db
|
||||
from flowsint_core.core.models import Log, Sketch, Scan
|
||||
from flowsint_core.core.events import event_emitter
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
from flowsint_core.core.types import Event
|
||||
from app.api.deps import get_current_user, get_current_user_sse
|
||||
from flowsint_core.core.models import Profile, Sketch
|
||||
from app.security.permissions import check_investigation_permission
|
||||
import json
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime
|
||||
|
||||
from flowsint_core.core.postgre_db import get_db
|
||||
from flowsint_core.core.events import event_emitter
|
||||
from flowsint_core.core.models import Profile
|
||||
from flowsint_core.core.services import (
|
||||
create_log_service,
|
||||
NotFoundError,
|
||||
PermissionDeniedError,
|
||||
DatabaseError,
|
||||
)
|
||||
from app.api.deps import get_current_user, get_current_user_sse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -21,62 +25,16 @@ def get_logs_by_sketch(
|
||||
limit: int = 100,
|
||||
since: datetime | None = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user)
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
"""Get historical logs for a specific sketch with optional filtering"""
|
||||
# Check if sketch exists
|
||||
sketch = db.query(Sketch).filter(Sketch.id == sketch_id).first()
|
||||
if not sketch:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Sketch with id {sketch_id} not found"
|
||||
)
|
||||
|
||||
check_investigation_permission(
|
||||
current_user.id, sketch.investigation_id, actions=["read"], db=db
|
||||
)
|
||||
|
||||
print(
|
||||
f"[EventEmitter] Fetching logs for sketch {sketch_id} (limit: {limit}, since: {since})"
|
||||
)
|
||||
query = (
|
||||
db.query(Log).filter(Log.sketch_id == sketch_id).order_by(Log.created_at.desc())
|
||||
)
|
||||
|
||||
if since:
|
||||
query = query.filter(Log.created_at > since)
|
||||
else:
|
||||
# Default to last 24 hours if no since parameter
|
||||
query = query.filter(Log.created_at > datetime.utcnow() - timedelta(days=1))
|
||||
|
||||
logs = query.limit(limit).all()
|
||||
|
||||
# Reverse to show chronologically (oldest to newest)
|
||||
logs = list(reversed(logs))
|
||||
|
||||
results = []
|
||||
for log in logs:
|
||||
# Ensure payload is always a dictionary
|
||||
if isinstance(log.content, dict):
|
||||
payload = log.content
|
||||
elif isinstance(log.content, str):
|
||||
payload = {"message": log.content}
|
||||
elif log.content is None:
|
||||
payload = {}
|
||||
else:
|
||||
# Handle other types by converting to string and wrapping
|
||||
payload = {"content": str(log.content)}
|
||||
|
||||
results.append(
|
||||
Event(
|
||||
id=str(log.id),
|
||||
sketch_id=str(log.sketch_id) if log.sketch_id else None,
|
||||
type=log.type,
|
||||
payload=payload,
|
||||
created_at=log.created_at
|
||||
)
|
||||
)
|
||||
|
||||
return results
|
||||
"""Get historical logs for a specific sketch with optional filtering."""
|
||||
service = create_log_service(db)
|
||||
try:
|
||||
return service.get_logs_by_sketch(sketch_id, current_user.id, limit, since)
|
||||
except NotFoundError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except PermissionDeniedError:
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
|
||||
|
||||
@router.get("/sketch/{sketch_id}/stream")
|
||||
@@ -84,26 +42,22 @@ async def stream_events(
|
||||
request: Request,
|
||||
sketch_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user_sse)
|
||||
current_user: Profile = Depends(get_current_user_sse),
|
||||
):
|
||||
"""Stream events for a specific scan in real-time"""
|
||||
|
||||
# Check if sketch exists
|
||||
sketch = db.query(Sketch).filter(Sketch.id == sketch_id).first()
|
||||
if not sketch:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Sketch with id {sketch_id} not found"
|
||||
)
|
||||
|
||||
check_investigation_permission(
|
||||
current_user.id, sketch.investigation_id, actions=["read"], db=db
|
||||
)
|
||||
"""Stream events for a specific sketch in real-time."""
|
||||
service = create_log_service(db)
|
||||
try:
|
||||
# Verify permission
|
||||
service._get_sketch_with_permission(sketch_id, current_user.id, ["read"])
|
||||
except NotFoundError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except PermissionDeniedError:
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
|
||||
async def event_generator():
|
||||
channel = sketch_id
|
||||
await event_emitter.subscribe(channel)
|
||||
try:
|
||||
# Initial connection message
|
||||
yield 'data: {"event": "connected", "data": "Connected to log stream"}\n\n'
|
||||
while True:
|
||||
if await request.is_disconnected():
|
||||
@@ -111,15 +65,12 @@ async def stream_events(
|
||||
|
||||
data = await event_emitter.get_message(channel)
|
||||
if data is None:
|
||||
await asyncio.sleep(0.1) # avoid tight loop on None
|
||||
await asyncio.sleep(0.1)
|
||||
continue
|
||||
|
||||
# Handle different types of events
|
||||
if isinstance(data, dict) and data.get("type") == "enricher_complete":
|
||||
# Send enricher completion event
|
||||
yield json.dumps({"event": "enricher_complete", "data": data})
|
||||
else:
|
||||
# Send regular log event
|
||||
yield json.dumps({"event": "log", "data": data})
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
@@ -147,25 +98,16 @@ def delete_scan_logs(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
"""Delete all logs for a specific scan"""
|
||||
# Check if sketch exists and user has permission
|
||||
sketch = db.query(Sketch).filter(Sketch.id == sketch_id).first()
|
||||
if not sketch:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Sketch with id {sketch_id} not found"
|
||||
)
|
||||
|
||||
check_investigation_permission(
|
||||
current_user.id, sketch.investigation_id, actions=["delete"], db=db
|
||||
)
|
||||
|
||||
"""Delete all logs for a specific sketch."""
|
||||
service = create_log_service(db)
|
||||
try:
|
||||
db.query(Log).filter(Log.sketch_id == sketch_id).delete()
|
||||
db.commit()
|
||||
return {"message": f"All logs have been deleted successfully"}
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=500, detail=f"Failed to delete logs: {str(e)}")
|
||||
return service.delete_logs_by_sketch(sketch_id, current_user.id)
|
||||
except NotFoundError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except PermissionDeniedError:
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
except DatabaseError as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/sketch/{sketch_id}/status/stream")
|
||||
@@ -173,26 +115,21 @@ async def stream_sketch_status(
|
||||
request: Request,
|
||||
sketch_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user_sse)
|
||||
current_user: Profile = Depends(get_current_user_sse),
|
||||
):
|
||||
"""Stream COMPLETED events for a specific sketch (for graph refresh)"""
|
||||
|
||||
# Check if sketch exists
|
||||
sketch = db.query(Sketch).filter(Sketch.id == sketch_id).first()
|
||||
if not sketch:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Sketch with id {sketch_id} not found"
|
||||
)
|
||||
|
||||
check_investigation_permission(
|
||||
current_user.id, sketch.investigation_id, actions=["read"], db=db
|
||||
)
|
||||
"""Stream COMPLETED events for a specific sketch (for graph refresh)."""
|
||||
service = create_log_service(db)
|
||||
try:
|
||||
service._get_sketch_with_permission(sketch_id, current_user.id, ["read"])
|
||||
except NotFoundError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except PermissionDeniedError:
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
|
||||
async def status_generator():
|
||||
channel = f"{sketch_id}_status"
|
||||
await event_emitter.subscribe(channel)
|
||||
try:
|
||||
# Initial connection message
|
||||
yield json.dumps({"event": "connected", "data": "Connected to status stream"})
|
||||
|
||||
while True:
|
||||
@@ -204,7 +141,6 @@ async def stream_sketch_status(
|
||||
await asyncio.sleep(0.1)
|
||||
continue
|
||||
|
||||
# Send status event
|
||||
yield json.dumps({"event": "status", "data": data})
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
@@ -231,29 +167,21 @@ async def stream_status(
|
||||
request: Request,
|
||||
scan_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user_sse)
|
||||
current_user: Profile = Depends(get_current_user_sse),
|
||||
):
|
||||
"""Stream status updates for a specific scan in real-time"""
|
||||
|
||||
# Check if scan exists and user has permission
|
||||
scan = db.query(Scan).filter(Scan.id == scan_id).first()
|
||||
if not scan:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Scan with id {scan_id} not found"
|
||||
)
|
||||
|
||||
# Check investigation permission via sketch
|
||||
sketch = db.query(Sketch).filter(Sketch.id == scan.sketch_id).first()
|
||||
if sketch:
|
||||
check_investigation_permission(
|
||||
current_user.id, sketch.investigation_id, actions=["read"], db=db
|
||||
)
|
||||
"""Stream status updates for a specific scan in real-time."""
|
||||
service = create_log_service(db)
|
||||
try:
|
||||
service.get_scan_with_permission(scan_id, current_user.id)
|
||||
except NotFoundError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except PermissionDeniedError:
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
|
||||
async def status_generator():
|
||||
print("[EventEmitter] Start status generator")
|
||||
await event_emitter.subscribe(f"scan_{scan_id}_status")
|
||||
try:
|
||||
# Initial connection message
|
||||
yield 'data: {"event": "connected", "data": "Connected to status stream"}\n\n'
|
||||
|
||||
while True:
|
||||
@@ -265,9 +193,7 @@ async def stream_status(
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
except asyncio.CancelledError:
|
||||
print(
|
||||
f"[EventEmitter] Client disconnected from status stream for scan_id: {scan_id}"
|
||||
)
|
||||
print(f"[EventEmitter] Client disconnected from status stream for scan_id: {scan_id}")
|
||||
finally:
|
||||
await event_emitter.unsubscribe(f"scan_{scan_id}_status")
|
||||
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import UUID, uuid4
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from uuid import UUID
|
||||
|
||||
# Auto-discover and register all enrichers
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from flowsint_core.core.celery import celery
|
||||
from flowsint_core.core.graph import create_graph_service
|
||||
from flowsint_core.core.models import CustomType, Flow, Profile, Sketch
|
||||
from flowsint_core.core.models import Profile
|
||||
from flowsint_core.core.postgre_db import get_db
|
||||
from flowsint_core.core.services import (
|
||||
NotFoundError,
|
||||
PermissionDeniedError,
|
||||
create_flow_service,
|
||||
)
|
||||
from flowsint_core.core.types import FlowBranch, FlowEdge, FlowNode, FlowStep
|
||||
from flowsint_core.utils import extract_input_schema_flow
|
||||
from flowsint_enrichers import ENRICHER_REGISTRY, load_all_enrichers
|
||||
@@ -31,12 +34,10 @@ from flowsint_types import (
|
||||
Website,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.api.schemas.flow import FlowCreate, FlowRead, FlowUpdate
|
||||
from app.security.permissions import check_investigation_permission
|
||||
|
||||
load_all_enrichers()
|
||||
|
||||
@@ -71,37 +72,10 @@ def get_flows(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
if not category or category.lower() == "undefined":
|
||||
return db.query(Flow).order_by(Flow.last_updated_at.desc()).all()
|
||||
|
||||
custom_type = (
|
||||
db.query(CustomType)
|
||||
.filter(
|
||||
CustomType.owner_id == current_user.id,
|
||||
CustomType.status == "published",
|
||||
func.lower(CustomType.name) == category.lower(),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if custom_type:
|
||||
flows = db.query(Flow).order_by(Flow.last_updated_at.desc()).all()
|
||||
return [
|
||||
{
|
||||
**(flow.to_dict() if hasattr(flow, "to_dict") else flow.__dict__),
|
||||
"wobblyType": True,
|
||||
}
|
||||
for flow in flows
|
||||
]
|
||||
|
||||
flows = db.query(Flow).order_by(Flow.last_updated_at.desc()).all()
|
||||
return [
|
||||
flow
|
||||
for flow in flows
|
||||
if any(cat.lower() == category.lower() for cat in flow.category)
|
||||
]
|
||||
service = create_flow_service(db)
|
||||
return service.get_all_flows(category, current_user.id)
|
||||
|
||||
|
||||
# Returns the "raw_materials" for the flow editor
|
||||
@router.get("/raw_materials")
|
||||
async def get_material_list():
|
||||
enrichers = ENRICHER_REGISTRY.list_by_categories()
|
||||
@@ -147,56 +121,46 @@ async def get_material_list():
|
||||
extract_input_schema_flow(CryptoNFT),
|
||||
]
|
||||
|
||||
# Put types first, then add all enricher categories
|
||||
flattened_enrichers = {"types": object_inputs}
|
||||
flattened_enrichers.update(enricher_categories)
|
||||
|
||||
return {"items": flattened_enrichers}
|
||||
|
||||
|
||||
# Returns the "raw_materials" for the flow editor
|
||||
@router.get("/input_type/{input_type}")
|
||||
async def get_material_list(input_type: str):
|
||||
async def get_material_by_input_type(input_type: str):
|
||||
enrichers = ENRICHER_REGISTRY.list_by_input_type(input_type)
|
||||
return {"items": enrichers}
|
||||
|
||||
|
||||
# Create a new flow
|
||||
@router.post("/create", response_model=FlowRead, status_code=status.HTTP_201_CREATED)
|
||||
def create_flow(
|
||||
payload: FlowCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
new_flow = Flow(
|
||||
id=uuid4(),
|
||||
service = create_flow_service(db)
|
||||
return service.create(
|
||||
name=payload.name,
|
||||
description=payload.description,
|
||||
category=payload.category,
|
||||
flow_schema=payload.flow_schema,
|
||||
created_at=datetime.utcnow(),
|
||||
last_updated_at=datetime.utcnow(),
|
||||
)
|
||||
db.add(new_flow)
|
||||
db.commit()
|
||||
db.refresh(new_flow)
|
||||
return new_flow
|
||||
|
||||
|
||||
# Get a flow by ID
|
||||
@router.get("/{flow_id}", response_model=FlowRead)
|
||||
def get_flow_by_id(
|
||||
flow_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
flow = db.query(Flow).filter(Flow.id == flow_id).first()
|
||||
if not flow:
|
||||
raise HTTPException(status_code=404, detail="flow not found")
|
||||
return flow
|
||||
service = create_flow_service(db)
|
||||
try:
|
||||
return service.get_by_id(flow_id)
|
||||
except NotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Flow not found")
|
||||
|
||||
|
||||
# Update a flow by ID
|
||||
@router.put("/{flow_id}", response_model=FlowRead)
|
||||
def update_flow(
|
||||
flow_id: UUID,
|
||||
@@ -204,37 +168,25 @@ def update_flow(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
flow = db.query(Flow).filter(Flow.id == flow_id).first()
|
||||
if not flow:
|
||||
raise HTTPException(status_code=404, detail="flow not found")
|
||||
update_data = payload.model_dump(exclude_unset=True)
|
||||
for key, value in update_data.items():
|
||||
print(f"only update {key}")
|
||||
if key == "category":
|
||||
if "SocialAccount" in value:
|
||||
value.append("Username")
|
||||
setattr(flow, key, value)
|
||||
|
||||
flow.last_updated_at = datetime.utcnow()
|
||||
|
||||
db.commit()
|
||||
db.refresh(flow)
|
||||
return flow
|
||||
service = create_flow_service(db)
|
||||
try:
|
||||
return service.update(flow_id, payload.model_dump(exclude_unset=True))
|
||||
except NotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Flow not found")
|
||||
|
||||
|
||||
# Delete a flow by ID
|
||||
@router.delete("/{flow_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
def delete_flow(
|
||||
flow_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
flow = db.query(Flow).filter(Flow.id == flow_id).first()
|
||||
if not flow:
|
||||
raise HTTPException(status_code=404, detail="flow not found")
|
||||
db.delete(flow)
|
||||
db.commit()
|
||||
return None
|
||||
service = create_flow_service(db)
|
||||
try:
|
||||
service.delete(flow_id)
|
||||
return None
|
||||
except NotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Flow not found")
|
||||
|
||||
|
||||
@router.post("/{flow_id}/launch")
|
||||
@@ -244,19 +196,10 @@ async def launch_flow(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
service = create_flow_service(db)
|
||||
try:
|
||||
flow = db.query(Flow).filter(Flow.id == flow_id).first()
|
||||
if flow is None:
|
||||
raise HTTPException(status_code=404, detail="flow not found")
|
||||
|
||||
# Check investigation permission via sketch
|
||||
sketch = db.query(Sketch).filter(Sketch.id == payload.sketch_id).first()
|
||||
if not sketch:
|
||||
raise HTTPException(status_code=404, detail="Sketch not found")
|
||||
|
||||
check_investigation_permission(
|
||||
current_user.id, sketch.investigation_id, actions=["update"], db=db
|
||||
)
|
||||
flow = service.get_by_id(UUID(flow_id))
|
||||
service.get_sketch_for_launch(payload.sketch_id, current_user.id)
|
||||
|
||||
# Retrieve entities from Neo4J by their element IDs
|
||||
graph_service = create_graph_service(sketch_id=payload.sketch_id)
|
||||
@@ -266,13 +209,14 @@ async def launch_flow(
|
||||
nodes = [FlowNode(**node) for node in flow.flow_schema["nodes"]]
|
||||
edges = [FlowEdge(**edge) for edge in flow.flow_schema["edges"]]
|
||||
|
||||
entities = [entity.model_dump(mode="json", serialize_as_any=True) for entity in entities]
|
||||
entities = [
|
||||
entity.model_dump(mode="json", serialize_as_any=True) for entity in entities
|
||||
]
|
||||
|
||||
|
||||
# For flow computation, we still need a sample value
|
||||
# Use the label from the first node data
|
||||
sample_value = (
|
||||
entities[0].get("nodeLabel", "sample_value") if len(entities) else "sample_value"
|
||||
entities[0].get("nodeLabel", "sample_value")
|
||||
if len(entities)
|
||||
else "sample_value"
|
||||
)
|
||||
flow_branches = compute_flow_branches(sample_value, nodes, edges)
|
||||
serializable_branches = [branch.model_dump() for branch in flow_branches]
|
||||
@@ -288,8 +232,10 @@ async def launch_flow(
|
||||
)
|
||||
return {"id": task.id}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except NotFoundError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except PermissionDeniedError:
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise HTTPException(status_code=500, detail=f"Error launching flow: {str(e)}")
|
||||
@@ -332,7 +278,6 @@ def compute_flow_branches(
|
||||
initial_value: Any, nodes: List[FlowNode], edges: List[FlowEdge]
|
||||
) -> List[FlowBranch]:
|
||||
"""Computes flow branches based on nodes and edges with proper DFS traversal"""
|
||||
# Find input nodes (starting points)
|
||||
input_nodes = [node for node in nodes if node.data.get("type") == "type"]
|
||||
|
||||
if not input_nodes:
|
||||
@@ -357,34 +302,25 @@ def compute_flow_branches(
|
||||
node_map = {node.id: node for node in nodes}
|
||||
branches = []
|
||||
branch_counter = 0
|
||||
# Track enricher outputs across all branches
|
||||
enricher_outputs = {}
|
||||
|
||||
def calculate_path_length(start_node: str, visited: set = None) -> int:
|
||||
"""Calculate the shortest possible path length from a node to any leaf"""
|
||||
if visited is None:
|
||||
visited = set()
|
||||
|
||||
if start_node in visited:
|
||||
return float("inf")
|
||||
|
||||
visited.add(start_node)
|
||||
out_edges = [edge for edge in edges if edge.source == start_node]
|
||||
|
||||
if not out_edges:
|
||||
return 1
|
||||
|
||||
min_length = float("inf")
|
||||
for edge in out_edges:
|
||||
length = calculate_path_length(edge.target, visited.copy())
|
||||
min_length = min(min_length, length)
|
||||
|
||||
return 1 + min_length
|
||||
|
||||
def get_outgoing_edges(node_id: str) -> List[FlowEdge]:
|
||||
"""Get outgoing edges sorted by the shortest possible path length"""
|
||||
out_edges = [edge for edge in edges if edge.source == node_id]
|
||||
# Sort edges by the length of the shortest possible path from their target
|
||||
return sorted(out_edges, key=lambda e: calculate_path_length(e.target))
|
||||
|
||||
def create_step(
|
||||
@@ -420,7 +356,6 @@ def compute_flow_branches(
|
||||
) -> None:
|
||||
nonlocal branch_counter
|
||||
|
||||
# Skip if node is already in current path (cycle detection)
|
||||
if current_node_id in path:
|
||||
return
|
||||
|
||||
@@ -428,7 +363,6 @@ def compute_flow_branches(
|
||||
if not current_node:
|
||||
return
|
||||
|
||||
# Process node outputs
|
||||
is_input_node = current_node.data.get("type") == "type"
|
||||
if is_input_node:
|
||||
outputs_array = current_node.data["outputs"].get("properties", [])
|
||||
@@ -437,18 +371,14 @@ def compute_flow_branches(
|
||||
)
|
||||
current_outputs = {first_output_name: initial_value}
|
||||
else:
|
||||
# Check if we already have outputs for this enricher
|
||||
if current_node_id in enricher_outputs:
|
||||
current_outputs = enricher_outputs[current_node_id]
|
||||
else:
|
||||
current_outputs = process_node_data(current_node, input_data)
|
||||
# Store the outputs for future use
|
||||
enricher_outputs[current_node_id] = current_outputs
|
||||
|
||||
# Extract node parameters
|
||||
node_params = current_node.data.get("params", {})
|
||||
|
||||
# Create and add current step
|
||||
current_step = create_step(
|
||||
current_node_id,
|
||||
branch_id,
|
||||
@@ -462,19 +392,15 @@ def compute_flow_branches(
|
||||
path.append(current_node_id)
|
||||
branch_visited.add(current_node_id)
|
||||
|
||||
# Get all outgoing edges sorted by path length
|
||||
out_edges = get_outgoing_edges(current_node_id)
|
||||
|
||||
if not out_edges:
|
||||
# Leaf node reached, save the branch
|
||||
branches.append(FlowBranch(id=branch_id, name=branch_name, steps=steps[:]))
|
||||
else:
|
||||
# Process each outgoing edge in order of shortest path
|
||||
for i, edge in enumerate(out_edges):
|
||||
if edge.target in path: # Skip if would create cycle
|
||||
if edge.target in path:
|
||||
continue
|
||||
|
||||
# Prepare next node's input
|
||||
output_key = edge.sourceHandle
|
||||
if not output_key and current_outputs:
|
||||
output_key = list(current_outputs.keys())[0]
|
||||
@@ -488,7 +414,6 @@ def compute_flow_branches(
|
||||
next_input = {edge.targetHandle or "input": output_value}
|
||||
|
||||
if i == 0:
|
||||
# Continue in same branch (will be shortest path)
|
||||
explore_branch(
|
||||
edge.target,
|
||||
branch_id,
|
||||
@@ -501,31 +426,26 @@ def compute_flow_branches(
|
||||
current_outputs,
|
||||
)
|
||||
else:
|
||||
# Create new branch starting from current node
|
||||
branch_counter += 1
|
||||
new_branch_id = f"{branch_id}-{branch_counter}"
|
||||
new_branch_name = f"{branch_name} (Branch {branch_counter})"
|
||||
new_steps = steps[: len(steps)] # Copy steps up to current node
|
||||
new_branch_visited = (
|
||||
branch_visited.copy()
|
||||
) # Create new visited set for the branch
|
||||
new_steps = steps[: len(steps)]
|
||||
new_branch_visited = branch_visited.copy()
|
||||
explore_branch(
|
||||
edge.target,
|
||||
new_branch_id,
|
||||
new_branch_name,
|
||||
depth + 1,
|
||||
next_input,
|
||||
path[:], # Create new path copy for branch
|
||||
path[:],
|
||||
new_branch_visited,
|
||||
new_steps,
|
||||
current_outputs,
|
||||
)
|
||||
|
||||
# Backtrack: remove current node from path and remove its step
|
||||
path.pop()
|
||||
steps.pop()
|
||||
|
||||
# Start exploration from each input node
|
||||
for index, input_node in enumerate(input_nodes):
|
||||
branch_id = f"branch-{index}"
|
||||
branch_name = f"Flow {index + 1}" if len(input_nodes) > 1 else "Main Flow"
|
||||
@@ -535,76 +455,61 @@ def compute_flow_branches(
|
||||
branch_name,
|
||||
0,
|
||||
{},
|
||||
[], # Use list for path to maintain order
|
||||
set(), # Use set for visited to check membership
|
||||
[],
|
||||
set(),
|
||||
[],
|
||||
None,
|
||||
)
|
||||
|
||||
# Sort branches by length (number of steps)
|
||||
branches.sort(key=lambda branch: len(branch.steps))
|
||||
return branches
|
||||
|
||||
|
||||
def process_node_data(node: FlowNode, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Traite les données de nœud en fonction du type de nœud et des entrées"""
|
||||
"""Process node data based on node type and inputs"""
|
||||
outputs = {}
|
||||
output_types = node.data["outputs"].get("properties", [])
|
||||
|
||||
for output in output_types:
|
||||
output_name = output.get("name", "output")
|
||||
class_name = node.data.get("class_name", "")
|
||||
# For simulation purposes, we'll return a placeholder value based on the enricher type
|
||||
|
||||
if class_name in ["ReverseResolveEnricher", "ResolveEnricher"]:
|
||||
# IP/Domain resolution enrichers
|
||||
outputs[output_name] = (
|
||||
"192.168.1.1" if "ip" in output_name.lower() else "example.com"
|
||||
)
|
||||
elif class_name == "SubdomainEnricher":
|
||||
# Subdomain enricher
|
||||
outputs[output_name] = f"sub.{inputs.get('input', 'example.com')}"
|
||||
|
||||
elif class_name == "WhoisEnricher":
|
||||
# WHOIS enricher
|
||||
outputs[output_name] = {
|
||||
"domain": inputs.get("input", "example.com"),
|
||||
"registrar": "Example Registrar",
|
||||
"creation_date": "2020-01-01",
|
||||
}
|
||||
|
||||
elif class_name == "IpToInfosEnricher":
|
||||
# Geolocation enricher
|
||||
outputs[output_name] = {
|
||||
"country": "France",
|
||||
"city": "Paris",
|
||||
"coordinates": {"lat": 48.8566, "lon": 2.3522},
|
||||
}
|
||||
|
||||
elif class_name == "MaigretEnricher":
|
||||
# Social media enricher
|
||||
outputs[output_name] = {
|
||||
"username": inputs.get("input", "user123"),
|
||||
"platforms": ["twitter", "github", "linkedin"],
|
||||
}
|
||||
|
||||
elif class_name == "HoleheEnricher":
|
||||
# Email verification enricher
|
||||
outputs[output_name] = {
|
||||
"email": inputs.get("input", "user@example.com"),
|
||||
"exists": True,
|
||||
"platforms": ["gmail", "github"],
|
||||
}
|
||||
|
||||
elif class_name == "SireneEnricher":
|
||||
# Organization enricher
|
||||
outputs[output_name] = {
|
||||
"name": inputs.get("input", "Example Corp"),
|
||||
"siret": "12345678901234",
|
||||
"address": "1 Example Street",
|
||||
}
|
||||
|
||||
else:
|
||||
# For unknown enrichers, pass through the input
|
||||
outputs[output_name] = inputs.get("input") or f"flowed_{output_name}"
|
||||
|
||||
return outputs
|
||||
|
||||
@@ -1,18 +1,16 @@
|
||||
from uuid import UUID, uuid4
|
||||
from app.security.permissions import check_investigation_permission
|
||||
from uuid import UUID
|
||||
from fastapi import APIRouter, HTTPException, Depends, status
|
||||
from typing import List
|
||||
from datetime import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from flowsint_core.core.types import Role
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
from flowsint_core.core.postgre_db import get_db
|
||||
from flowsint_core.core.models import (
|
||||
Analysis,
|
||||
Investigation,
|
||||
InvestigationUserRole,
|
||||
Profile,
|
||||
Sketch,
|
||||
from flowsint_core.core.models import Profile
|
||||
from flowsint_core.core.services import (
|
||||
create_investigation_service,
|
||||
NotFoundError,
|
||||
PermissionDeniedError,
|
||||
DatabaseError,
|
||||
)
|
||||
from app.api.deps import get_current_user
|
||||
from app.api.schemas.investigation import (
|
||||
@@ -21,61 +19,23 @@ from app.api.schemas.investigation import (
|
||||
InvestigationUpdate,
|
||||
)
|
||||
from app.api.schemas.sketch import SketchRead
|
||||
from flowsint_core.core.graph import create_graph_service
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def get_user_accessible_investigations(
|
||||
user_id: str, db: Session, allowed_roles: list[Role] = None
|
||||
) -> list[Investigation]:
|
||||
"""
|
||||
Returns all investigations accessible to user depending on its roles
|
||||
"""
|
||||
query = db.query(Investigation).join(
|
||||
InvestigationUserRole,
|
||||
InvestigationUserRole.investigation_id == Investigation.id,
|
||||
)
|
||||
|
||||
query = query.filter(InvestigationUserRole.user_id == user_id)
|
||||
|
||||
if allowed_roles:
|
||||
# ARRAY(Role) contains any of allowed_roles
|
||||
conditions = [InvestigationUserRole.roles.any(role) for role in allowed_roles]
|
||||
# Inclut également le propriétaire de l’investigation
|
||||
query = query.filter(or_(*conditions, Investigation.owner_id == user_id))
|
||||
|
||||
return (
|
||||
query.options(
|
||||
selectinload(Investigation.sketches),
|
||||
selectinload(Investigation.analyses),
|
||||
selectinload(Investigation.owner),
|
||||
)
|
||||
.distinct()
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
# Get the list of all investigations
|
||||
@router.get("", response_model=List[InvestigationRead])
|
||||
def get_investigations(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Récupère toutes les investigations accessibles à l'utilisateur
|
||||
selon ses rôles (OWNER, EDITOR, VIEWER).
|
||||
"""
|
||||
allowed_roles_for_read = [Role.OWNER, Role.EDITOR, Role.VIEWER]
|
||||
|
||||
investigations = get_user_accessible_investigations(
|
||||
user_id=current_user.id, db=db, allowed_roles=allowed_roles_for_read
|
||||
"""Get all investigations accessible to the user based on their roles."""
|
||||
service = create_investigation_service(db)
|
||||
allowed_roles = [Role.OWNER, Role.EDITOR, Role.VIEWER]
|
||||
return service.get_accessible_investigations(
|
||||
user_id=current_user.id, allowed_roles=allowed_roles
|
||||
)
|
||||
|
||||
return investigations
|
||||
|
||||
|
||||
# Create a new investigation
|
||||
@router.post(
|
||||
"/create", response_model=InvestigationRead, status_code=status.HTTP_201_CREATED
|
||||
)
|
||||
@@ -84,73 +44,46 @@ def create_investigation(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
new_investigation = Investigation(
|
||||
id=uuid4(),
|
||||
service = create_investigation_service(db)
|
||||
return service.create(
|
||||
name=payload.name,
|
||||
description=payload.description or payload.name,
|
||||
description=payload.description,
|
||||
owner_id=current_user.id,
|
||||
status="active",
|
||||
)
|
||||
db.add(new_investigation)
|
||||
|
||||
new_roles = InvestigationUserRole(
|
||||
id=uuid4(),
|
||||
user_id=current_user.id,
|
||||
investigation_id=new_investigation.id,
|
||||
roles=[Role.OWNER],
|
||||
)
|
||||
db.add(new_roles)
|
||||
|
||||
db.commit()
|
||||
db.refresh(new_investigation)
|
||||
db.refresh(new_roles)
|
||||
|
||||
return new_investigation
|
||||
|
||||
|
||||
# Get a investigation by ID
|
||||
@router.get("/{investigation_id}", response_model=InvestigationRead)
|
||||
def get_investigation_by_id(
|
||||
investigation_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
check_investigation_permission(current_user.id, investigation_id, actions=["read"], db=db)
|
||||
investigation = (
|
||||
db.query(Investigation)
|
||||
.options(
|
||||
selectinload(Investigation.sketches),
|
||||
selectinload(Investigation.analyses),
|
||||
selectinload(Investigation.owner),
|
||||
)
|
||||
.filter(Investigation.id == investigation_id)
|
||||
.filter(Investigation.owner_id == current_user.id)
|
||||
.first()
|
||||
)
|
||||
if not investigation:
|
||||
service = create_investigation_service(db)
|
||||
try:
|
||||
return service.get_by_id(investigation_id, current_user.id)
|
||||
except NotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Investigation not found")
|
||||
return investigation
|
||||
except PermissionDeniedError:
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
|
||||
|
||||
# Get a investigation by ID
|
||||
@router.get("/{investigation_id}/sketches", response_model=List[SketchRead])
|
||||
def get_sketches_by_investigation(
|
||||
investigation_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
check_investigation_permission(current_user.id, investigation_id, actions=["read"], db=db)
|
||||
sketches = (
|
||||
db.query(Sketch).filter(Sketch.investigation_id == investigation_id).all()
|
||||
)
|
||||
if not sketches:
|
||||
service = create_investigation_service(db)
|
||||
try:
|
||||
return service.get_sketches(investigation_id, current_user.id)
|
||||
except NotFoundError:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="No sketches found for this investigation"
|
||||
)
|
||||
return sketches
|
||||
except PermissionDeniedError:
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
|
||||
|
||||
# Update a investigation by ID
|
||||
@router.put("/{investigation_id}", response_model=InvestigationRead)
|
||||
def update_investigation(
|
||||
investigation_id: UUID,
|
||||
@@ -158,69 +91,34 @@ def update_investigation(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
check_investigation_permission(current_user.id, investigation_id, actions=["write"], db=db)
|
||||
investigation = (
|
||||
db.query(Investigation).filter(Investigation.id == investigation_id).first()
|
||||
)
|
||||
if not investigation:
|
||||
service = create_investigation_service(db)
|
||||
try:
|
||||
return service.update(
|
||||
investigation_id=investigation_id,
|
||||
user_id=current_user.id,
|
||||
name=payload.name,
|
||||
description=payload.description,
|
||||
status=payload.status,
|
||||
)
|
||||
except NotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Investigation not found")
|
||||
|
||||
investigation.name = payload.name
|
||||
investigation.description = payload.description
|
||||
investigation.status = payload.status
|
||||
investigation.last_updated_at = datetime.utcnow()
|
||||
|
||||
db.commit()
|
||||
db.refresh(investigation)
|
||||
return investigation
|
||||
except PermissionDeniedError:
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
|
||||
|
||||
# Delete a investigation by ID
|
||||
@router.delete("/{investigation_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
def delete_investigation(
|
||||
investigation_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
check_investigation_permission(current_user.id, investigation_id, actions=["delete"], db=db)
|
||||
investigation = (
|
||||
db.query(Investigation)
|
||||
.filter(
|
||||
Investigation.id == investigation_id,
|
||||
Investigation.owner_id == current_user.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not investigation:
|
||||
service = create_investigation_service(db)
|
||||
try:
|
||||
service.delete(investigation_id, current_user.id)
|
||||
return None
|
||||
except NotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Investigation not found")
|
||||
|
||||
# Get all sketches related to this investigation
|
||||
sketches = (
|
||||
db.query(Sketch).filter(Sketch.investigation_id == investigation_id).all()
|
||||
)
|
||||
analyses = (
|
||||
db.query(Analysis).filter(Sketch.investigation_id == investigation_id).all()
|
||||
)
|
||||
|
||||
# Delete all nodes and relationships for each sketch in Neo4j using GraphService
|
||||
for sketch in sketches:
|
||||
try:
|
||||
graph_service = create_graph_service(
|
||||
sketch_id=str(sketch.id),
|
||||
enable_batching=False,
|
||||
)
|
||||
graph_service.delete_all_sketch_nodes()
|
||||
except Exception as e:
|
||||
print(f"Neo4j cleanup error for sketch {sketch.id}: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to clean up graph data")
|
||||
|
||||
# Delete all sketches from PostgreSQL
|
||||
for sketch in sketches:
|
||||
db.delete(sketch)
|
||||
for analysis in analyses:
|
||||
db.delete(analysis)
|
||||
|
||||
# Finally delete the investigation
|
||||
db.delete(investigation)
|
||||
db.commit()
|
||||
return None
|
||||
except PermissionDeniedError:
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
except DatabaseError:
|
||||
raise HTTPException(status_code=500, detail="Failed to clean up graph data")
|
||||
|
||||
@@ -1,23 +1,28 @@
|
||||
from uuid import UUID, uuid4
|
||||
from uuid import UUID
|
||||
from fastapi import APIRouter, HTTPException, Depends, status
|
||||
from typing import List
|
||||
from flowsint_core.core.vault import Vault
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from flowsint_core.core.services import (
|
||||
create_key_service,
|
||||
NotFoundError,
|
||||
DatabaseError,
|
||||
)
|
||||
from flowsint_core.core.postgre_db import get_db
|
||||
from flowsint_core.core.models import Profile, Key
|
||||
from flowsint_core.core.models import Profile
|
||||
from app.api.deps import get_current_user
|
||||
from app.api.schemas.key import KeyRead, KeyCreate
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# Get the list of all keys for a user, just the public method for viewing
|
||||
@router.get("", response_model=List[KeyRead])
|
||||
def get_keys(
|
||||
db: Session = Depends(get_db), current_user: Profile = Depends(get_current_user)
|
||||
):
|
||||
keys = db.query(Key).filter(Key.owner_id == current_user.id).all()
|
||||
response_data = [
|
||||
service = create_key_service(db)
|
||||
keys = service.get_keys_for_user(current_user.id)
|
||||
return [
|
||||
KeyRead(
|
||||
id=key.id,
|
||||
owner_id=key.owner_id,
|
||||
@@ -26,61 +31,57 @@ def get_keys(
|
||||
)
|
||||
for key in keys
|
||||
]
|
||||
return response_data
|
||||
|
||||
|
||||
# Get a key by ID, just the public method for viewing
|
||||
@router.get("/{id}", response_model=KeyRead)
|
||||
def get_key_by_id(
|
||||
id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
key = db.query(Key).filter(Key.id == id, Key.owner_id == current_user.id).first()
|
||||
if not key:
|
||||
service = create_key_service(db)
|
||||
try:
|
||||
key = service.get_key_by_id(id, current_user.id)
|
||||
return KeyRead(
|
||||
id=key.id,
|
||||
owner_id=key.owner_id,
|
||||
name=key.name,
|
||||
created_at=key.created_at,
|
||||
)
|
||||
except NotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Key not found")
|
||||
|
||||
# Create a response with obfuscated key
|
||||
response_data = KeyRead(
|
||||
id=key.id,
|
||||
owner_id=key.owner_id,
|
||||
name=key.name,
|
||||
created_at=key.created_at,
|
||||
)
|
||||
return response_data
|
||||
|
||||
|
||||
# Create a new key
|
||||
@router.post("/create", response_model=KeyRead, status_code=status.HTTP_201_CREATED)
|
||||
def create_key(
|
||||
payload: KeyCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
service = create_key_service(db)
|
||||
try:
|
||||
vault = Vault(db=db, owner_id=current_user.id)
|
||||
key = vault.set_secret(vault_ref=payload.name, plain_key=payload.key)
|
||||
if not key:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="An error occured creating the key."
|
||||
)
|
||||
return key
|
||||
except Exception as e:
|
||||
key = service.create_key(payload.name, payload.key, current_user.id)
|
||||
return KeyRead(
|
||||
id=key.id,
|
||||
owner_id=key.owner_id,
|
||||
name=key.name,
|
||||
created_at=key.created_at,
|
||||
)
|
||||
except DatabaseError:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="An error occured creating the key."
|
||||
status_code=500, detail="An error occurred creating the key."
|
||||
)
|
||||
|
||||
|
||||
# Delete a key by ID
|
||||
@router.delete("/{id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
def delete_key(
|
||||
id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
key = db.query(Key).filter(Key.id == id, Key.owner_id == current_user.id).first()
|
||||
if not key:
|
||||
service = create_key_service(db)
|
||||
try:
|
||||
service.delete_key(id, current_user.id)
|
||||
return None
|
||||
except NotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Key not found")
|
||||
db.delete(key)
|
||||
db.commit()
|
||||
return None
|
||||
|
||||
@@ -1,84 +1,56 @@
|
||||
from uuid import UUID
|
||||
from fastapi import APIRouter, HTTPException, Depends, status
|
||||
from typing import List
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from flowsint_core.core.postgre_db import get_db
|
||||
from flowsint_core.core.models import Scan, Profile, Sketch, InvestigationUserRole
|
||||
from flowsint_core.core.types import Role
|
||||
from flowsint_core.core.models import Profile
|
||||
from flowsint_core.core.services import (
|
||||
create_scan_service,
|
||||
NotFoundError,
|
||||
PermissionDeniedError,
|
||||
)
|
||||
from app.api.deps import get_current_user
|
||||
from app.api.schemas.scan import ScanRead
|
||||
from app.security.permissions import check_investigation_permission
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# Get the list of all scans
|
||||
@router.get(
|
||||
"",
|
||||
response_model=List[ScanRead],
|
||||
)
|
||||
@router.get("", response_model=List[ScanRead])
|
||||
def get_scans(
|
||||
db: Session = Depends(get_db), current_user: Profile = Depends(get_current_user)
|
||||
):
|
||||
# Get all scans from sketches in investigations where user has at least VIEWER role
|
||||
allowed_roles_for_read = [Role.OWNER, Role.EDITOR, Role.VIEWER]
|
||||
|
||||
query = db.query(Scan).join(
|
||||
Sketch, Sketch.id == Scan.sketch_id
|
||||
).join(
|
||||
InvestigationUserRole,
|
||||
InvestigationUserRole.investigation_id == Sketch.investigation_id,
|
||||
)
|
||||
|
||||
query = query.filter(InvestigationUserRole.user_id == current_user.id)
|
||||
|
||||
# Filter by allowed roles
|
||||
conditions = [InvestigationUserRole.roles.any(role) for role in allowed_roles_for_read]
|
||||
query = query.filter(or_(*conditions))
|
||||
|
||||
return query.distinct().all()
|
||||
"""Get all scans accessible to the current user."""
|
||||
service = create_scan_service(db)
|
||||
return service.get_accessible_scans(current_user.id)
|
||||
|
||||
|
||||
# Get a scan by ID
|
||||
@router.get("/{id}", response_model=ScanRead)
|
||||
def get_scan_by_id(
|
||||
id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
scan = db.query(Scan).filter(Scan.id == id).first()
|
||||
if not scan:
|
||||
service = create_scan_service(db)
|
||||
try:
|
||||
return service.get_by_id(id, current_user.id)
|
||||
except NotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Scan not found")
|
||||
|
||||
# Check investigation permission via sketch
|
||||
sketch = db.query(Sketch).filter(Sketch.id == scan.sketch_id).first()
|
||||
if sketch:
|
||||
check_investigation_permission(
|
||||
current_user.id, sketch.investigation_id, actions=["read"], db=db
|
||||
)
|
||||
|
||||
return scan
|
||||
except PermissionDeniedError:
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
|
||||
|
||||
# Delete a scan by ID
|
||||
@router.delete("/{id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
def delete_scan_by_id(
|
||||
id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
scan = db.query(Scan).filter(Scan.id == id).first()
|
||||
if not scan:
|
||||
service = create_scan_service(db)
|
||||
try:
|
||||
service.delete(id, current_user.id)
|
||||
return None
|
||||
except NotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Scan not found")
|
||||
|
||||
# Check investigation permission via sketch
|
||||
sketch = db.query(Sketch).filter(Sketch.id == scan.sketch_id).first()
|
||||
if sketch:
|
||||
check_investigation_permission(
|
||||
current_user.id, sketch.investigation_id, actions=["delete"], db=db
|
||||
)
|
||||
|
||||
db.delete(scan)
|
||||
db.commit()
|
||||
return None
|
||||
except PermissionDeniedError:
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
|
||||
@@ -11,23 +11,29 @@ from fastapi import (
|
||||
UploadFile,
|
||||
status,
|
||||
)
|
||||
from flowsint_core.core.graph import create_graph_service, GraphNode
|
||||
from flowsint_core.core.models import Profile, Sketch
|
||||
from flowsint_core.core.graph import GraphNode
|
||||
from flowsint_core.core.models import Profile
|
||||
from flowsint_core.core.postgre_db import get_db
|
||||
from flowsint_core.core.services import (
|
||||
create_sketch_service,
|
||||
NotFoundError,
|
||||
PermissionDeniedError,
|
||||
ValidationError,
|
||||
DatabaseError,
|
||||
)
|
||||
from flowsint_core.imports import (
|
||||
EntityMapping,
|
||||
ImportService,
|
||||
create_import_service,
|
||||
FileParseResult,
|
||||
)
|
||||
from flowsint_core.utils import flatten
|
||||
from flowsint_core.core.graph import create_graph_service
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.api.schemas.sketch import SketchCreate, SketchRead, SketchUpdate
|
||||
from app.api.sketch_utils import update_sketch_timestamp
|
||||
from app.security.permissions import check_investigation_permission
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -48,7 +54,6 @@ class RelationshipDeleteInput(BaseModel):
|
||||
relationshipIds: List[str]
|
||||
|
||||
|
||||
|
||||
class NodeEditInput(BaseModel):
|
||||
nodeId: str
|
||||
updates: Dict[str, Any]
|
||||
@@ -68,174 +73,6 @@ class NodeMergeInput(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
@router.post("/create", response_model=SketchRead, status_code=status.HTTP_201_CREATED)
|
||||
def create_sketch(
|
||||
data: SketchCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
sketch_data = data.model_dump()
|
||||
investigation_id = sketch_data.get("investigation_id")
|
||||
if not investigation_id:
|
||||
raise HTTPException(status_code=404, detail="Investigation not found")
|
||||
check_investigation_permission(
|
||||
current_user.id, investigation_id, actions=["create"], db=db
|
||||
)
|
||||
sketch_data["owner_id"] = current_user.id
|
||||
sketch = Sketch(**sketch_data)
|
||||
db.add(sketch)
|
||||
db.commit()
|
||||
db.refresh(sketch)
|
||||
return sketch
|
||||
|
||||
|
||||
@router.get("", response_model=List[SketchRead])
|
||||
def list_sketches(
|
||||
db: Session = Depends(get_db), current_user: Profile = Depends(get_current_user)
|
||||
):
|
||||
return db.query(Sketch).filter(Sketch.owner_id == current_user.id).all()
|
||||
|
||||
|
||||
@router.get("/{sketch_id}")
|
||||
def get_sketch_by_id(
|
||||
sketch_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
sketch = db.query(Sketch).filter(Sketch.id == sketch_id).first()
|
||||
if not sketch:
|
||||
raise HTTPException(status_code=404, detail="Sketch not found")
|
||||
check_investigation_permission(
|
||||
current_user.id, sketch.investigation_id, actions=["read"], db=db
|
||||
)
|
||||
return sketch
|
||||
|
||||
|
||||
@router.put("/{id}", response_model=SketchRead)
|
||||
def update_sketch(
|
||||
id: UUID,
|
||||
payload: SketchUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
sketch = db.query(Sketch).filter(Sketch.id == id).first()
|
||||
if not sketch:
|
||||
raise HTTPException(status_code=404, detail="Sketch not found")
|
||||
check_investigation_permission(
|
||||
current_user.id, sketch.investigation_id, actions=["update"], db=db
|
||||
)
|
||||
for key, value in payload.model_dump(exclude_unset=True).items():
|
||||
setattr(sketch, key, value)
|
||||
db.commit()
|
||||
db.refresh(sketch)
|
||||
return sketch
|
||||
|
||||
|
||||
@router.delete("/{id}", status_code=204)
|
||||
def delete_sketch(
|
||||
id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
sketch = db.query(Sketch).filter(Sketch.id == id).first()
|
||||
if not sketch:
|
||||
raise HTTPException(status_code=404, detail="Sketch not found")
|
||||
check_investigation_permission(
|
||||
current_user.id, sketch.investigation_id, actions=["delete"], db=db
|
||||
)
|
||||
|
||||
# Delete all nodes and relationships in Neo4j first using GraphService
|
||||
try:
|
||||
graph_service = create_graph_service(
|
||||
sketch_id=str(id), enable_batching=False
|
||||
)
|
||||
graph_service.delete_all_sketch_nodes()
|
||||
except Exception as e:
|
||||
print(f"Neo4j cleanup error: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to clean up graph data")
|
||||
|
||||
# Then delete the sketch from PostgreSQL
|
||||
db.delete(sketch)
|
||||
db.commit()
|
||||
|
||||
|
||||
@router.get("/{sketch_id}/graph")
|
||||
async def get_sketch_nodes(
|
||||
sketch_id: str,
|
||||
format: str | None = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Get the nodes and edges for a sketch.
|
||||
Args:
|
||||
id: The ID of the sketch
|
||||
format: Optional format parameter. If "inline", returns inline relationships
|
||||
db: The database session
|
||||
current_user: The current user
|
||||
Returns:
|
||||
A dictionary containing the nodes and relationships for the sketch
|
||||
nds: []
|
||||
rls: []
|
||||
Or if format=inline: List of inline relationship strings
|
||||
"""
|
||||
sketch = db.query(Sketch).filter(Sketch.id == sketch_id).first()
|
||||
if not sketch:
|
||||
raise HTTPException(status_code=404, detail="Graph not found")
|
||||
check_investigation_permission(
|
||||
current_user.id, sketch.investigation_id, actions=["read"], db=db
|
||||
)
|
||||
# Get all nodes and relationships using GraphService
|
||||
graph_service = create_graph_service(
|
||||
sketch_id=sketch_id, enable_batching=False
|
||||
)
|
||||
graph_data = graph_service.get_sketch_graph()
|
||||
if format == "inline":
|
||||
from flowsint_core.utils import get_inline_relationships
|
||||
|
||||
return get_inline_relationships(graph_data.nodes, graph_data.edges)
|
||||
|
||||
graph = graph_data.model_dump(mode="json", serialize_as_any=True)
|
||||
|
||||
return {"nds": graph["nodes"], "rls": graph["edges"]}
|
||||
|
||||
|
||||
@router.post("/{sketch_id}/nodes/add")
|
||||
@update_sketch_timestamp
|
||||
def add_node(
|
||||
sketch_id: str,
|
||||
node: GraphNode,
|
||||
background_tasks: BackgroundTasks,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
sketch = db.query(Sketch).filter(Sketch.id == sketch_id).first()
|
||||
if not sketch:
|
||||
raise HTTPException(status_code=404, detail="Sketch not found")
|
||||
check_investigation_permission(
|
||||
current_user.id, sketch.investigation_id, actions=["update"], db=db
|
||||
)
|
||||
try:
|
||||
graph_service = create_graph_service(
|
||||
sketch_id=sketch_id,
|
||||
enable_batching=False,
|
||||
)
|
||||
node_id = graph_service.create_node(node)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise HTTPException(status_code=500, detail=f"Database error: {str(e)}")
|
||||
|
||||
if not node_id:
|
||||
raise HTTPException(status_code=400, detail="Node creation failed")
|
||||
|
||||
node.id = node_id
|
||||
|
||||
return {
|
||||
"status": "node added",
|
||||
"node": node,
|
||||
}
|
||||
|
||||
|
||||
class RelationInput(BaseModel):
|
||||
source: str
|
||||
target: str
|
||||
@@ -243,88 +80,6 @@ class RelationInput(BaseModel):
|
||||
label: str = "RELATED_TO"
|
||||
|
||||
|
||||
@router.post("/{sketch_id}/relations/add")
|
||||
@update_sketch_timestamp
|
||||
def add_edge(
|
||||
sketch_id: str,
|
||||
relation: RelationInput,
|
||||
background_tasks: BackgroundTasks,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
sketch = db.query(Sketch).filter(Sketch.id == sketch_id).first()
|
||||
if not sketch:
|
||||
raise HTTPException(status_code=404, detail="Sketch not found")
|
||||
check_investigation_permission(
|
||||
current_user.id, sketch.investigation_id, actions=["update"], db=db
|
||||
)
|
||||
|
||||
# Create relationship using GraphService
|
||||
try:
|
||||
graph_service = create_graph_service(
|
||||
sketch_id=sketch_id,
|
||||
enable_batching=False,
|
||||
)
|
||||
result = graph_service.create_relationship_by_element_id(
|
||||
from_element_id=relation.source,
|
||||
to_element_id=relation.target,
|
||||
rel_label=relation.label,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Edge creation error: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to create edge")
|
||||
|
||||
if not result:
|
||||
raise HTTPException(status_code=400, detail="Edge creation failed")
|
||||
|
||||
return {
|
||||
"status": "edge added",
|
||||
"edge": result,
|
||||
}
|
||||
|
||||
|
||||
@router.put("/{sketch_id}/nodes/edit")
|
||||
@update_sketch_timestamp
|
||||
def edit_node(
|
||||
sketch_id: str,
|
||||
node_edit: NodeEditInput,
|
||||
background_tasks: BackgroundTasks,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
# First verify the sketch exists and belongs to the user
|
||||
sketch = db.query(Sketch).filter(Sketch.id == sketch_id).first()
|
||||
if not sketch:
|
||||
raise HTTPException(status_code=404, detail="Sketch not found")
|
||||
check_investigation_permission(
|
||||
current_user.id, sketch.investigation_id, actions=["update"], db=db
|
||||
)
|
||||
|
||||
updates = node_edit.updates
|
||||
try:
|
||||
graph_service = create_graph_service(
|
||||
sketch_id=sketch_id,
|
||||
enable_batching=False,
|
||||
)
|
||||
updated_element_id = graph_service.update_node(
|
||||
element_id=node_edit.nodeId,
|
||||
updates=updates,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Node update error: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to update node")
|
||||
|
||||
if not updated_element_id:
|
||||
raise HTTPException(status_code=404, detail="Node not found or not accessible")
|
||||
|
||||
return {
|
||||
"status": "node updated",
|
||||
"node": {
|
||||
"id": updated_element_id,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class NodePosition(BaseModel):
|
||||
nodeId: str
|
||||
x: float
|
||||
@@ -335,6 +90,186 @@ class UpdatePositionsInput(BaseModel):
|
||||
positions: List[NodePosition]
|
||||
|
||||
|
||||
class EntityMappingInput(BaseModel):
|
||||
"""Pydantic model for parsing entity mapping input from frontend."""
|
||||
id: str
|
||||
entity_type: str
|
||||
include: bool = True
|
||||
nodeLabel: str
|
||||
node_id: Optional[str] = None
|
||||
data: Dict[str, Any]
|
||||
|
||||
|
||||
class ImportExecuteResponse(BaseModel):
|
||||
"""Response model for import execution."""
|
||||
status: str
|
||||
nodes_created: int
|
||||
nodes_skipped: int
|
||||
errors: List[str]
|
||||
|
||||
|
||||
@router.post("/create", response_model=SketchRead, status_code=status.HTTP_201_CREATED)
|
||||
def create_sketch(
|
||||
data: SketchCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
service = create_sketch_service(db)
|
||||
try:
|
||||
sketch_data = data.model_dump()
|
||||
return service.create(
|
||||
title=sketch_data.get("title"),
|
||||
description=sketch_data.get("description"),
|
||||
investigation_id=sketch_data.get("investigation_id"),
|
||||
owner_id=current_user.id,
|
||||
)
|
||||
except ValidationError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except PermissionDeniedError:
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
|
||||
|
||||
@router.get("", response_model=List[SketchRead])
|
||||
def list_sketches(
|
||||
db: Session = Depends(get_db), current_user: Profile = Depends(get_current_user)
|
||||
):
|
||||
service = create_sketch_service(db)
|
||||
return service.list_sketches(current_user.id)
|
||||
|
||||
|
||||
@router.get("/{sketch_id}")
|
||||
def get_sketch_by_id(
|
||||
sketch_id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
service = create_sketch_service(db)
|
||||
try:
|
||||
return service.get_by_id(sketch_id, current_user.id)
|
||||
except NotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Sketch not found")
|
||||
except PermissionDeniedError:
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
|
||||
|
||||
@router.put("/{id}", response_model=SketchRead)
|
||||
def update_sketch(
|
||||
id: UUID,
|
||||
payload: SketchUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
service = create_sketch_service(db)
|
||||
try:
|
||||
return service.update(id, current_user.id, payload.model_dump(exclude_unset=True))
|
||||
except NotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Sketch not found")
|
||||
except PermissionDeniedError:
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
|
||||
|
||||
@router.delete("/{id}", status_code=204)
|
||||
def delete_sketch(
|
||||
id: UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
service = create_sketch_service(db)
|
||||
try:
|
||||
service.delete(id, current_user.id)
|
||||
except NotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Sketch not found")
|
||||
except PermissionDeniedError:
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
except DatabaseError:
|
||||
raise HTTPException(status_code=500, detail="Failed to clean up graph data")
|
||||
|
||||
|
||||
@router.get("/{sketch_id}/graph")
|
||||
async def get_sketch_nodes(
|
||||
sketch_id: str,
|
||||
format: str | None = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
"""Get the nodes and edges for a sketch."""
|
||||
service = create_sketch_service(db)
|
||||
try:
|
||||
return service.get_graph(UUID(sketch_id), current_user.id, format)
|
||||
except NotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Graph not found")
|
||||
except PermissionDeniedError:
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
|
||||
|
||||
@router.post("/{sketch_id}/nodes/add")
|
||||
@update_sketch_timestamp
|
||||
def add_node(
|
||||
sketch_id: str,
|
||||
node: GraphNode,
|
||||
background_tasks: BackgroundTasks,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
service = create_sketch_service(db)
|
||||
try:
|
||||
return service.add_node(UUID(sketch_id), current_user.id, node)
|
||||
except NotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Sketch not found")
|
||||
except PermissionDeniedError:
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
except ValidationError:
|
||||
raise HTTPException(status_code=400, detail="Node creation failed")
|
||||
except DatabaseError as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/{sketch_id}/relations/add")
|
||||
@update_sketch_timestamp
|
||||
def add_edge(
|
||||
sketch_id: str,
|
||||
relation: RelationInput,
|
||||
background_tasks: BackgroundTasks,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
service = create_sketch_service(db)
|
||||
try:
|
||||
return service.add_relationship(
|
||||
UUID(sketch_id), current_user.id, relation.source, relation.target, relation.label
|
||||
)
|
||||
except NotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Sketch not found")
|
||||
except PermissionDeniedError:
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
except ValidationError:
|
||||
raise HTTPException(status_code=400, detail="Edge creation failed")
|
||||
except DatabaseError:
|
||||
raise HTTPException(status_code=500, detail="Failed to create edge")
|
||||
|
||||
|
||||
@router.put("/{sketch_id}/nodes/edit")
|
||||
@update_sketch_timestamp
|
||||
def edit_node(
|
||||
sketch_id: str,
|
||||
node_edit: NodeEditInput,
|
||||
background_tasks: BackgroundTasks,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
service = create_sketch_service(db)
|
||||
try:
|
||||
return service.update_node(
|
||||
UUID(sketch_id), current_user.id, node_edit.nodeId, node_edit.updates
|
||||
)
|
||||
except NotFoundError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except PermissionDeniedError:
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
except DatabaseError:
|
||||
raise HTTPException(status_code=500, detail="Failed to update node")
|
||||
|
||||
|
||||
@router.put("/{sketch_id}/nodes/positions")
|
||||
@update_sketch_timestamp
|
||||
def update_node_positions(
|
||||
@@ -344,40 +279,18 @@ def update_node_positions(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Update positions (x, y) for multiple nodes in batch.
|
||||
This is used to persist node positions after drag operations in the graph viewer.
|
||||
"""
|
||||
# Verify the sketch exists and user has access
|
||||
sketch = db.query(Sketch).filter(Sketch.id == sketch_id).first()
|
||||
if not sketch:
|
||||
raise HTTPException(status_code=404, detail="Sketch not found")
|
||||
check_investigation_permission(
|
||||
current_user.id, sketch.investigation_id, actions=["update"], db=db
|
||||
)
|
||||
|
||||
if not data.positions:
|
||||
return {"status": "no positions to update", "count": 0}
|
||||
|
||||
# Convert Pydantic models to dicts for GraphService
|
||||
positions = [pos.model_dump() for pos in data.positions]
|
||||
|
||||
# Update positions using GraphService
|
||||
"""Update positions (x, y) for multiple nodes in batch."""
|
||||
service = create_sketch_service(db)
|
||||
try:
|
||||
graph_service = create_graph_service(
|
||||
sketch_id=sketch_id,
|
||||
enable_batching=False,
|
||||
)
|
||||
updated_count = graph_service.update_nodes_positions(positions=positions)
|
||||
except Exception as e:
|
||||
print(f"Position update error: {e}")
|
||||
positions = [pos.model_dump() for pos in data.positions]
|
||||
return service.update_node_positions(UUID(sketch_id), current_user.id, positions)
|
||||
except NotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Sketch not found")
|
||||
except PermissionDeniedError:
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
except DatabaseError:
|
||||
raise HTTPException(status_code=500, detail="Failed to update node positions")
|
||||
|
||||
return {
|
||||
"status": "positions updated",
|
||||
"count": updated_count,
|
||||
}
|
||||
|
||||
|
||||
@router.delete("/{sketch_id}/nodes")
|
||||
@update_sketch_timestamp
|
||||
@@ -388,27 +301,16 @@ def delete_nodes(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
# First verify the sketch exists and belongs to the user
|
||||
sketch = db.query(Sketch).filter(Sketch.id == sketch_id).first()
|
||||
if not sketch:
|
||||
raise HTTPException(status_code=404, detail="Sketch not found")
|
||||
check_investigation_permission(
|
||||
current_user.id, sketch.investigation_id, actions=["update"], db=db
|
||||
)
|
||||
|
||||
# Delete nodes and their relationships using GraphService
|
||||
service = create_sketch_service(db)
|
||||
try:
|
||||
graph_service = create_graph_service(
|
||||
sketch_id=sketch_id,
|
||||
enable_batching=False,
|
||||
)
|
||||
deleted_count = graph_service.delete_nodes(nodes.nodeIds)
|
||||
except Exception as e:
|
||||
print(f"Node deletion error: {e}")
|
||||
return service.delete_nodes(UUID(sketch_id), current_user.id, nodes.nodeIds)
|
||||
except NotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Sketch not found")
|
||||
except PermissionDeniedError:
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
except DatabaseError:
|
||||
raise HTTPException(status_code=500, detail="Failed to delete nodes")
|
||||
|
||||
return {"status": "nodes deleted", "count": deleted_count}
|
||||
|
||||
|
||||
@router.delete("/{sketch_id}/relationships")
|
||||
@update_sketch_timestamp
|
||||
@@ -419,29 +321,18 @@ def delete_relationships(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
# First verify the sketch exists and belongs to the user
|
||||
sketch = db.query(Sketch).filter(Sketch.id == sketch_id).first()
|
||||
if not sketch:
|
||||
raise HTTPException(status_code=404, detail="Sketch not found")
|
||||
check_investigation_permission(
|
||||
current_user.id, sketch.investigation_id, actions=["update"], db=db
|
||||
)
|
||||
|
||||
# Delete relationships using GraphService
|
||||
service = create_sketch_service(db)
|
||||
try:
|
||||
graph_service = create_graph_service(
|
||||
sketch_id=sketch_id,
|
||||
enable_batching=False,
|
||||
return service.delete_relationships(
|
||||
UUID(sketch_id), current_user.id, relationships.relationshipIds
|
||||
)
|
||||
deleted_count = graph_service.delete_relationships(
|
||||
relationships.relationshipIds
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Relationship deletion error: {e}")
|
||||
except NotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Sketch not found")
|
||||
except PermissionDeniedError:
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
except DatabaseError:
|
||||
raise HTTPException(status_code=500, detail="Failed to delete relationships")
|
||||
|
||||
return {"status": "relationships deleted", "count": deleted_count}
|
||||
|
||||
|
||||
@router.put("/{sketch_id}/relationships/edit")
|
||||
@update_sketch_timestamp
|
||||
@@ -452,42 +343,21 @@ def edit_relationship(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
# First verify the sketch exists and belongs to the user
|
||||
sketch = db.query(Sketch).filter(Sketch.id == sketch_id).first()
|
||||
if not sketch:
|
||||
raise HTTPException(status_code=404, detail="Sketch not found")
|
||||
check_investigation_permission(
|
||||
current_user.id, sketch.investigation_id, actions=["update"], db=db
|
||||
)
|
||||
|
||||
# Update edge using GraphService
|
||||
service = create_sketch_service(db)
|
||||
try:
|
||||
graph_service = create_graph_service(
|
||||
sketch_id=sketch_id,
|
||||
enable_batching=False,
|
||||
return service.update_relationship(
|
||||
UUID(sketch_id),
|
||||
current_user.id,
|
||||
relationship_edit.relationshipId,
|
||||
relationship_edit.data,
|
||||
)
|
||||
result = graph_service.update_relationship(
|
||||
element_id=relationship_edit.relationshipId,
|
||||
properties=relationship_edit.data,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Relationship update error: {e}")
|
||||
except NotFoundError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except PermissionDeniedError:
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
except DatabaseError:
|
||||
raise HTTPException(status_code=500, detail="Failed to update relationship")
|
||||
|
||||
if not result:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Relationship not found or not accessible"
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "relationship updated",
|
||||
"relationship": {
|
||||
"id": result["id"],
|
||||
"label": result["type"],
|
||||
"data": result["data"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@router.post("/{sketch_id}/nodes/merge")
|
||||
@update_sketch_timestamp
|
||||
@@ -499,49 +369,20 @@ def merge_nodes(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
sketch = db.query(Sketch).filter(Sketch.id == sketch_id).first()
|
||||
if not sketch:
|
||||
raise HTTPException(status_code=404, detail="Sketch not found")
|
||||
check_investigation_permission(
|
||||
current_user.id, sketch.investigation_id, actions=["update"], db=db
|
||||
)
|
||||
|
||||
if not oldNodes or len(oldNodes) == 0:
|
||||
raise HTTPException(status_code=400, detail="oldNodes cannot be empty")
|
||||
|
||||
node_data = newNode.data.model_dump() if newNode.data else {}
|
||||
node_type = node_data.get("type", "Node")
|
||||
|
||||
properties = {
|
||||
"type": node_type.lower(),
|
||||
"label": node_data.get("label", "Merged Node"),
|
||||
}
|
||||
|
||||
flattened_data = flatten(node_data)
|
||||
properties.update(flattened_data)
|
||||
|
||||
service = create_sketch_service(db)
|
||||
try:
|
||||
graph_service = create_graph_service(
|
||||
sketch_id=sketch_id,
|
||||
enable_batching=False,
|
||||
node_data = newNode.data.model_dump() if newNode.data else {}
|
||||
return service.merge_nodes(
|
||||
UUID(sketch_id), current_user.id, oldNodes, newNode.id, node_data
|
||||
)
|
||||
new_node_element_id = graph_service.merge_nodes(
|
||||
old_node_ids=oldNodes,
|
||||
new_node_data=properties,
|
||||
new_node_id=newNode.id,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Node merge error: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to merge nodes: {str(e)}")
|
||||
|
||||
if not new_node_element_id:
|
||||
raise HTTPException(status_code=500, detail="Failed to merge nodes")
|
||||
|
||||
return {
|
||||
"status": "nodes merged",
|
||||
"count": len(oldNodes),
|
||||
"new_node_id": new_node_element_id,
|
||||
}
|
||||
except NotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Sketch not found")
|
||||
except PermissionDeniedError:
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
except ValidationError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except DatabaseError as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/{sketch_id}/nodes/{node_id}")
|
||||
@@ -551,26 +392,16 @@ def get_related_nodes(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
sketch = db.query(Sketch).filter(Sketch.id == sketch_id).first()
|
||||
if not sketch:
|
||||
raise HTTPException(status_code=404, detail="Sketch not found")
|
||||
check_investigation_permission(
|
||||
current_user.id, sketch.investigation_id, actions=["read"], db=db
|
||||
)
|
||||
|
||||
service = create_sketch_service(db)
|
||||
try:
|
||||
graph_service = create_graph_service(sketch_id=sketch_id)
|
||||
result = graph_service.get_neighbors(node_id)
|
||||
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return service.get_neighbors(UUID(sketch_id), current_user.id, node_id)
|
||||
except NotFoundError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except PermissionDeniedError:
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
except DatabaseError:
|
||||
raise HTTPException(status_code=500, detail="Failed to retrieve related nodes")
|
||||
|
||||
if not result.nodes:
|
||||
raise HTTPException(status_code=404, detail="Node not found")
|
||||
|
||||
return {"nds": result.nodes, "rls": result.edges}
|
||||
|
||||
|
||||
@router.post("/{sketch_id}/import/analyze", response_model=FileParseResult)
|
||||
async def analyze_import_file(
|
||||
@@ -579,32 +410,26 @@ async def analyze_import_file(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Analyze an uploaded TXT or JSON file for import.
|
||||
Each line represents one entity. Detects entity types and provides preview.
|
||||
"""
|
||||
# Verify sketch exists and user has access
|
||||
sketch = db.query(Sketch).filter(Sketch.id == sketch_id).first()
|
||||
if not sketch:
|
||||
"""Analyze an uploaded TXT or JSON file for import."""
|
||||
service = create_sketch_service(db)
|
||||
try:
|
||||
service.get_by_id(UUID(sketch_id), current_user.id)
|
||||
except NotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Sketch not found")
|
||||
check_investigation_permission(
|
||||
current_user.id, sketch.investigation_id, actions=["read"], db=db
|
||||
)
|
||||
except PermissionDeniedError:
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
|
||||
# Validate file extension
|
||||
if not file.filename or not file.filename.lower().endswith((".txt", ".json")):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Only .txt and .json files are supported. Please upload a correct format.",
|
||||
)
|
||||
|
||||
# Read file content
|
||||
try:
|
||||
content = await file.read()
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"Failed to read file: {str(e)}")
|
||||
|
||||
# Analyze file using ImportService
|
||||
try:
|
||||
result = ImportService.analyze_file(
|
||||
file_content=content,
|
||||
@@ -618,26 +443,6 @@ async def analyze_import_file(
|
||||
return result
|
||||
|
||||
|
||||
class EntityMappingInput(BaseModel):
|
||||
"""Pydantic model for parsing entity mapping input from frontend."""
|
||||
|
||||
id: str
|
||||
entity_type: str
|
||||
include: bool = True
|
||||
nodeLabel: str
|
||||
node_id: Optional[str] = None
|
||||
data: Dict[str, Any]
|
||||
|
||||
|
||||
class ImportExecuteResponse(BaseModel):
|
||||
"""Response model for import execution."""
|
||||
|
||||
status: str
|
||||
nodes_created: int
|
||||
nodes_skipped: int
|
||||
errors: List[str]
|
||||
|
||||
|
||||
@router.post("/{sketch_id}/import/execute", response_model=ImportExecuteResponse)
|
||||
@update_sketch_timestamp
|
||||
async def execute_import(
|
||||
@@ -647,26 +452,21 @@ async def execute_import(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Execute the import of entities into the sketch.
|
||||
Uses the entity mappings provided by the frontend (no file re-parsing needed).
|
||||
"""
|
||||
"""Execute the import of entities into the sketch."""
|
||||
import json
|
||||
|
||||
# Verify sketch exists and user has access
|
||||
sketch = db.query(Sketch).filter(Sketch.id == sketch_id).first()
|
||||
if not sketch:
|
||||
service = create_sketch_service(db)
|
||||
try:
|
||||
service.get_by_id(UUID(sketch_id), current_user.id)
|
||||
except NotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Sketch not found")
|
||||
check_investigation_permission(
|
||||
current_user.id, sketch.investigation_id, actions=["update"], db=db
|
||||
)
|
||||
except PermissionDeniedError:
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
|
||||
# Parse entity mappings JSON
|
||||
try:
|
||||
mappings = json.loads(entity_mappings_json)
|
||||
nodes = mappings.get("nodes", [])
|
||||
edges = mappings.get("edges", [])
|
||||
print(nodes)
|
||||
entity_mapping_inputs = [EntityMappingInput(**m) for m in nodes]
|
||||
except json.JSONDecodeError:
|
||||
raise HTTPException(status_code=400, detail="Invalid entity_mappings JSON")
|
||||
@@ -675,7 +475,6 @@ async def execute_import(
|
||||
status_code=400, detail=f"Failed to parse entity_mappings: {str(e)}"
|
||||
)
|
||||
|
||||
# Convert Pydantic inputs to service dataclasses
|
||||
entity_mappings = [
|
||||
EntityMapping(
|
||||
id=m.id,
|
||||
@@ -688,10 +487,7 @@ async def execute_import(
|
||||
for m in entity_mapping_inputs
|
||||
]
|
||||
|
||||
# Execute import using ImportService
|
||||
graph_service = create_graph_service(
|
||||
sketch_id=sketch_id, enable_batching=False
|
||||
)
|
||||
graph_service = create_graph_service(sketch_id=sketch_id, enable_batching=False)
|
||||
import_service = create_import_service(graph_service)
|
||||
|
||||
try:
|
||||
@@ -717,39 +513,13 @@ async def export_sketch(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Profile = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Export the sketch in the specified format.
|
||||
Args:
|
||||
id: The ID of the sketch
|
||||
format: Export format - "json" (only format for now)
|
||||
db: The database session
|
||||
current_user: The current user
|
||||
Returns:
|
||||
The sketch data in the requested format
|
||||
"""
|
||||
sketch = db.query(Sketch).filter(Sketch.id == id).first()
|
||||
if not sketch:
|
||||
"""Export the sketch in the specified format."""
|
||||
service = create_sketch_service(db)
|
||||
try:
|
||||
return service.export_sketch(UUID(id), current_user.id, format)
|
||||
except NotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Sketch not found")
|
||||
check_investigation_permission(
|
||||
current_user.id, sketch.investigation_id, actions=["read"], db=db
|
||||
)
|
||||
|
||||
# Get all nodes and relationships using GraphService
|
||||
graph_service = create_graph_service(
|
||||
sketch_id=id, enable_batching=False
|
||||
)
|
||||
graph_data = graph_service.get_sketch_graph()
|
||||
|
||||
if format == "json":
|
||||
return {
|
||||
"sketch": {
|
||||
"id": str(sketch.id),
|
||||
"title": sketch.title,
|
||||
"description": sketch.description,
|
||||
},
|
||||
"nodes": [node.model_dump(mode="json") for node in graph_data.nodes],
|
||||
"edges": [edge.model_dump(mode="json") for edge in graph_data.edges],
|
||||
}
|
||||
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail=f"Unsupported format: {format}")
|
||||
except PermissionDeniedError:
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
except ValidationError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
@@ -1,318 +1,18 @@
|
||||
from typing import Any, Dict, Optional, Type
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from flowsint_core.core.models import CustomType, Profile
|
||||
from flowsint_core.core.postgre_db import get_db
|
||||
from flowsint_types.registry import get_type
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from flowsint_core.core.models import Profile
|
||||
from flowsint_core.core.postgre_db import get_db
|
||||
from flowsint_core.core.services import create_type_registry_service
|
||||
from app.api.deps import get_current_user
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# Helper function to get a type by name from the registry
|
||||
def get_type_from_registry(type_name: str) -> Optional[Type[BaseModel]]:
|
||||
"""Get a type from the TYPE_REGISTRY by name."""
|
||||
return get_type(type_name, case_sensitive=True)
|
||||
|
||||
|
||||
# Returns the "types" for the sketches
|
||||
@router.get("/")
|
||||
async def get_types_list(
|
||||
db: Session = Depends(get_db), current_user: Profile = Depends(get_current_user)
|
||||
):
|
||||
# Define categories with type names to look up in TYPE_REGISTRY
|
||||
# Format: (type_name, label_key, optional_icon)
|
||||
category_definitions = [
|
||||
{
|
||||
"id": uuid4(),
|
||||
"type": "global",
|
||||
"key": "global_category",
|
||||
"icon": "phrase",
|
||||
"label": "Global",
|
||||
"fields": [],
|
||||
"children": [
|
||||
("Phrase", "text", None),
|
||||
("Location", "address", None),
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": uuid4(),
|
||||
"type": "person",
|
||||
"key": "person_category",
|
||||
"icon": "individual",
|
||||
"label": "Identities & Entities",
|
||||
"fields": [],
|
||||
"children": [
|
||||
("Individual", "full_name", None),
|
||||
("Username", "value", "username"),
|
||||
("Organization", "name", None),
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": uuid4(),
|
||||
"type": "organization",
|
||||
"key": "organization_category",
|
||||
"icon": "organization",
|
||||
"label": "Organization",
|
||||
"fields": [],
|
||||
"children": [
|
||||
("Organization", "name", None),
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": uuid4(),
|
||||
"type": "contact_category",
|
||||
"key": "contact",
|
||||
"icon": "phone",
|
||||
"label": "Communication & Contact",
|
||||
"fields": [],
|
||||
"children": [
|
||||
("Phone", "number", None),
|
||||
("Email", "email", None),
|
||||
("Username", "value", None),
|
||||
("SocialAccount", "username", "socialaccount"),
|
||||
("Message", "content", "message"),
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": uuid4(),
|
||||
"type": "network_category",
|
||||
"key": "network",
|
||||
"icon": "domain",
|
||||
"label": "Network",
|
||||
"fields": [],
|
||||
"children": [
|
||||
("ASN", "number", None),
|
||||
("CIDR", "network", None),
|
||||
("Domain", "domain", None),
|
||||
("Website", "url", None),
|
||||
("Ip", "address", None),
|
||||
("Port", "number", None),
|
||||
("DNSRecord", "name", "dns"),
|
||||
("SSLCertificate", "subject", "ssl"),
|
||||
("WebTracker", "name", "webtracker"),
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": uuid4(),
|
||||
"type": "security_category",
|
||||
"key": "security",
|
||||
"icon": "credential",
|
||||
"label": "Security & Access",
|
||||
"fields": [],
|
||||
"children": [
|
||||
("Credential", "username", "credential"),
|
||||
("Session", "session_id", "session"),
|
||||
("Device", "device_id", "device"),
|
||||
("Malware", "name", "malware"),
|
||||
("Weapon", "name", "weapon"),
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": uuid4(),
|
||||
"type": "files_category",
|
||||
"key": "files",
|
||||
"icon": "file",
|
||||
"label": "Files & Documents",
|
||||
"fields": [],
|
||||
"children": [
|
||||
("Document", "title", "document"),
|
||||
("File", "filename", "file"),
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": uuid4(),
|
||||
"type": "financial_category",
|
||||
"key": "financial",
|
||||
"icon": "creditcard",
|
||||
"label": "Financial Data",
|
||||
"fields": [],
|
||||
"children": [
|
||||
("BankAccount", "account_number", "creditcard"),
|
||||
("CreditCard", "card_number", "creditcard"),
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": uuid4(),
|
||||
"type": "leak_category",
|
||||
"key": "leaks",
|
||||
"icon": "breach",
|
||||
"label": "Leaks",
|
||||
"fields": [],
|
||||
"children": [
|
||||
("Leak", "name", "breach"),
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": uuid4(),
|
||||
"type": "crypto_category",
|
||||
"key": "crypto",
|
||||
"icon": "cryptowallet",
|
||||
"label": "Crypto",
|
||||
"fields": [],
|
||||
"children": [
|
||||
("CryptoWallet", "address", "cryptowallet"),
|
||||
("CryptoWalletTransaction", "hash", "cryptowallet"),
|
||||
("CryptoNFT", "name", "cryptowallet"),
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
# Build the types list by looking up each type in TYPE_REGISTRY
|
||||
types = []
|
||||
for category in category_definitions:
|
||||
category_copy = category.copy()
|
||||
children_schemas = []
|
||||
|
||||
for child_def in category["children"]:
|
||||
type_name, label_key, icon = child_def
|
||||
model = get_type_from_registry(type_name)
|
||||
|
||||
if model:
|
||||
children_schemas.append(
|
||||
extract_input_schema(model, label_key=label_key, icon=icon)
|
||||
)
|
||||
else:
|
||||
# Log warning but continue - type might not be available
|
||||
print(f"Warning: Type {type_name} not found in TYPE_REGISTRY")
|
||||
|
||||
category_copy["children"] = children_schemas
|
||||
types.append(category_copy)
|
||||
|
||||
# Add custom types
|
||||
custom_types = (
|
||||
db.query(CustomType)
|
||||
.filter(
|
||||
CustomType.owner_id == current_user.id,
|
||||
CustomType.status == "published", # Only show published custom types
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
if custom_types:
|
||||
custom_types_children = []
|
||||
for custom_type in custom_types:
|
||||
# Extract the label_key from the schema (use first required field or first property)
|
||||
schema = custom_type.schema
|
||||
properties = schema.get("properties", {})
|
||||
required = schema.get("required", [])
|
||||
|
||||
# Try to use the first required field, or the first property
|
||||
label_key = (
|
||||
required[0]
|
||||
if required
|
||||
else list(properties.keys())[0]
|
||||
if properties
|
||||
else "value"
|
||||
)
|
||||
|
||||
custom_types_children.append(
|
||||
{
|
||||
"id": custom_type.id,
|
||||
"type": custom_type.name,
|
||||
"key": custom_type.name.lower(),
|
||||
"label_key": label_key,
|
||||
"icon": "custom",
|
||||
"label": custom_type.name,
|
||||
"description": custom_type.description or "",
|
||||
"fields": [
|
||||
{
|
||||
"name": prop,
|
||||
"label": info.get("title", prop),
|
||||
"description": info.get("description", ""),
|
||||
"type": "text",
|
||||
"required": prop in required,
|
||||
}
|
||||
for prop, info in properties.items()
|
||||
],
|
||||
"custom": True, # Mark as custom type
|
||||
}
|
||||
)
|
||||
|
||||
types.append(
|
||||
{
|
||||
"id": uuid4(),
|
||||
"type": "custom_types_category",
|
||||
"key": "custom_types",
|
||||
"icon": "custom",
|
||||
"label": "Custom types",
|
||||
"fields": [],
|
||||
"children": custom_types_children,
|
||||
}
|
||||
)
|
||||
|
||||
return types
|
||||
|
||||
|
||||
def extract_input_schema(
|
||||
model: Type[BaseModel], label_key: str, icon: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
adapter = TypeAdapter(model)
|
||||
schema = adapter.json_schema()
|
||||
# Use the main schema properties, not the $defs
|
||||
type_name = model.__name__
|
||||
details = schema
|
||||
return {
|
||||
"id": uuid4(),
|
||||
"type": type_name,
|
||||
"key": type_name.lower(),
|
||||
"label_key": label_key,
|
||||
"icon": icon or type_name.lower(),
|
||||
"label": type_name,
|
||||
"description": details.get("description", ""),
|
||||
"fields": [
|
||||
resolve_field(prop, details=info, schema=schema)
|
||||
for prop, info in details.get("properties", {}).items()
|
||||
# exclude nodeLabel from properties to fill
|
||||
if prop != "nodeLabel"
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def resolve_field(prop: str, details: dict, schema: dict = None) -> Dict:
|
||||
"""_summary_
|
||||
The fields can sometimes contain nested complex objects, like:
|
||||
- Organization having Individual[] as dirigeants, so we want to skip those.
|
||||
Args:
|
||||
details (dict): _description_
|
||||
schema_context (dict, optional): _description_. Defaults to None.
|
||||
|
||||
Returns:
|
||||
str: _description_
|
||||
"""
|
||||
field = {
|
||||
"name": prop,
|
||||
"label": details.get("title", prop),
|
||||
"description": details.get("description", ""),
|
||||
"type": "text",
|
||||
}
|
||||
if has_enum(details):
|
||||
field["type"] = "select"
|
||||
field["options"] = [
|
||||
{"label": label, "value": label} for label in get_enum_values(details)
|
||||
]
|
||||
field["required"] = is_required(details)
|
||||
|
||||
return field
|
||||
|
||||
|
||||
def has_enum(schema: dict) -> bool:
|
||||
any_of = schema.get("anyOf", [])
|
||||
return any(isinstance(entry, dict) and "enum" in entry for entry in any_of)
|
||||
|
||||
|
||||
def is_required(schema: dict) -> bool:
|
||||
any_of = schema.get("anyOf", [])
|
||||
return not any(entry == {"type": "null"} for entry in any_of)
|
||||
|
||||
|
||||
def get_enum_values(schema: dict) -> list:
|
||||
enum_values = []
|
||||
for entry in schema.get("anyOf", []):
|
||||
if isinstance(entry, dict) and "enum" in entry:
|
||||
enum_values.extend(entry["enum"])
|
||||
return enum_values
|
||||
"""Get the complete types list for sketches."""
|
||||
service = create_type_registry_service(db)
|
||||
return service.get_types_list(current_user.id)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { useTheme } from 'next-themes'
|
||||
import { Toaster as Sonner, ToasterProps } from 'sonner'
|
||||
import { useTheme } from '../theme-provider'
|
||||
|
||||
const Toaster = ({ ...props }: ToasterProps) => {
|
||||
const { theme = 'system' } = useTheme()
|
||||
|
||||
70
flowsint-core/src/flowsint_core/core/services/__init__.py
Normal file
70
flowsint-core/src/flowsint_core/core/services/__init__.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""
|
||||
Service layer for flowsint-core.
|
||||
|
||||
This module provides business logic services that encapsulate database operations
|
||||
and domain logic, enabling cleaner route handlers and better testability.
|
||||
"""
|
||||
|
||||
from .exceptions import (
|
||||
ServiceError,
|
||||
NotFoundError,
|
||||
PermissionDeniedError,
|
||||
ValidationError,
|
||||
DatabaseError,
|
||||
AuthenticationError,
|
||||
ConflictError,
|
||||
)
|
||||
from .base import BaseService
|
||||
from .auth_service import AuthService, create_auth_service
|
||||
from .key_service import KeyService, create_key_service
|
||||
from .investigation_service import InvestigationService, create_investigation_service
|
||||
from .sketch_service import SketchService, create_sketch_service
|
||||
from .analysis_service import AnalysisService, create_analysis_service
|
||||
from .chat_service import ChatService, create_chat_service
|
||||
from .scan_service import ScanService, create_scan_service
|
||||
from .log_service import LogService, create_log_service
|
||||
from .flow_service import FlowService, create_flow_service
|
||||
from .custom_type_service import CustomTypeService, create_custom_type_service
|
||||
from .type_registry_service import TypeRegistryService, create_type_registry_service
|
||||
from .enricher_service import EnricherService, create_enricher_service
|
||||
|
||||
__all__ = [
|
||||
# Exceptions
|
||||
"ServiceError",
|
||||
"NotFoundError",
|
||||
"PermissionDeniedError",
|
||||
"ValidationError",
|
||||
"DatabaseError",
|
||||
"AuthenticationError",
|
||||
"ConflictError",
|
||||
# Base
|
||||
"BaseService",
|
||||
# Services - Phase 1
|
||||
"AuthService",
|
||||
"create_auth_service",
|
||||
"KeyService",
|
||||
"create_key_service",
|
||||
# Services - Phase 2
|
||||
"InvestigationService",
|
||||
"create_investigation_service",
|
||||
"SketchService",
|
||||
"create_sketch_service",
|
||||
# Services - Phase 3
|
||||
"AnalysisService",
|
||||
"create_analysis_service",
|
||||
"ChatService",
|
||||
"create_chat_service",
|
||||
"ScanService",
|
||||
"create_scan_service",
|
||||
"LogService",
|
||||
"create_log_service",
|
||||
# Services - Phase 4
|
||||
"FlowService",
|
||||
"create_flow_service",
|
||||
"CustomTypeService",
|
||||
"create_custom_type_service",
|
||||
"TypeRegistryService",
|
||||
"create_type_registry_service",
|
||||
"EnricherService",
|
||||
"create_enricher_service",
|
||||
]
|
||||
@@ -0,0 +1,215 @@
|
||||
"""
|
||||
Analysis service for managing analyses within investigations.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import UUID, uuid4
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ..models import Analysis, InvestigationUserRole
|
||||
from ..types import Role
|
||||
from .base import BaseService
|
||||
from .exceptions import NotFoundError, PermissionDeniedError
|
||||
|
||||
|
||||
class AnalysisService(BaseService):
|
||||
"""
|
||||
Service for analysis CRUD operations.
|
||||
"""
|
||||
|
||||
def get_accessible_analyses(self, user_id: UUID) -> List[Analysis]:
|
||||
"""
|
||||
Get all analyses accessible to a user based on their investigation roles.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID
|
||||
|
||||
Returns:
|
||||
List of accessible analyses
|
||||
"""
|
||||
allowed_roles = [Role.OWNER, Role.EDITOR, Role.VIEWER]
|
||||
|
||||
query = self._db.query(Analysis).join(
|
||||
InvestigationUserRole,
|
||||
InvestigationUserRole.investigation_id == Analysis.investigation_id,
|
||||
)
|
||||
|
||||
query = query.filter(InvestigationUserRole.user_id == user_id)
|
||||
|
||||
conditions = [InvestigationUserRole.roles.any(role) for role in allowed_roles]
|
||||
query = query.filter(or_(*conditions))
|
||||
|
||||
return query.distinct().all()
|
||||
|
||||
def get_by_id(self, analysis_id: UUID, user_id: UUID) -> Analysis:
|
||||
"""
|
||||
Get an analysis by ID with permission check.
|
||||
|
||||
Args:
|
||||
analysis_id: The analysis ID
|
||||
user_id: The user's ID
|
||||
|
||||
Returns:
|
||||
The analysis
|
||||
|
||||
Raises:
|
||||
NotFoundError: If analysis not found
|
||||
PermissionDeniedError: If user doesn't have permission
|
||||
"""
|
||||
analysis = self._db.query(Analysis).filter(Analysis.id == analysis_id).first()
|
||||
if not analysis:
|
||||
raise NotFoundError("Analysis not found")
|
||||
|
||||
self._check_permission(user_id, analysis.investigation_id, ["read"])
|
||||
return analysis
|
||||
|
||||
def get_by_investigation(
|
||||
self, investigation_id: UUID, user_id: UUID
|
||||
) -> List[Analysis]:
|
||||
"""
|
||||
Get all analyses for an investigation.
|
||||
|
||||
Args:
|
||||
investigation_id: The investigation ID
|
||||
user_id: The user's ID
|
||||
|
||||
Returns:
|
||||
List of analyses
|
||||
|
||||
Raises:
|
||||
PermissionDeniedError: If user doesn't have permission
|
||||
"""
|
||||
self._check_permission(user_id, investigation_id, ["read"])
|
||||
|
||||
return (
|
||||
self._db.query(Analysis)
|
||||
.filter(Analysis.investigation_id == investigation_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
def create(
|
||||
self,
|
||||
title: str,
|
||||
description: Optional[str],
|
||||
content: Optional[Dict[str, Any]],
|
||||
investigation_id: UUID,
|
||||
owner_id: UUID,
|
||||
) -> Analysis:
|
||||
"""
|
||||
Create a new analysis.
|
||||
|
||||
Args:
|
||||
title: Analysis title
|
||||
description: Analysis description
|
||||
content: Analysis content (JSON)
|
||||
investigation_id: Parent investigation ID
|
||||
owner_id: Owner user ID
|
||||
|
||||
Returns:
|
||||
The created analysis
|
||||
|
||||
Raises:
|
||||
PermissionDeniedError: If user can't create in this investigation
|
||||
"""
|
||||
self._check_permission(owner_id, investigation_id, ["create"])
|
||||
|
||||
new_analysis = Analysis(
|
||||
id=uuid4(),
|
||||
title=title,
|
||||
description=description,
|
||||
content=content,
|
||||
owner_id=owner_id,
|
||||
investigation_id=investigation_id,
|
||||
created_at=datetime.utcnow(),
|
||||
last_updated_at=datetime.utcnow(),
|
||||
)
|
||||
self._add(new_analysis)
|
||||
self._commit()
|
||||
self._refresh(new_analysis)
|
||||
return new_analysis
|
||||
|
||||
def update(
|
||||
self,
|
||||
analysis_id: UUID,
|
||||
user_id: UUID,
|
||||
title: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
content: Optional[Dict[str, Any]] = None,
|
||||
investigation_id: Optional[UUID] = None,
|
||||
) -> Analysis:
|
||||
"""
|
||||
Update an analysis.
|
||||
|
||||
Args:
|
||||
analysis_id: The analysis ID
|
||||
user_id: The user's ID
|
||||
title: New title (optional)
|
||||
description: New description (optional)
|
||||
content: New content (optional)
|
||||
investigation_id: New investigation ID (optional)
|
||||
|
||||
Returns:
|
||||
The updated analysis
|
||||
|
||||
Raises:
|
||||
NotFoundError: If analysis not found
|
||||
PermissionDeniedError: If user doesn't have permission
|
||||
"""
|
||||
analysis = self._db.query(Analysis).filter(Analysis.id == analysis_id).first()
|
||||
if not analysis:
|
||||
raise NotFoundError("Analysis not found")
|
||||
|
||||
self._check_permission(user_id, analysis.investigation_id, ["update"])
|
||||
|
||||
if title is not None:
|
||||
analysis.title = title
|
||||
if description is not None:
|
||||
analysis.description = description
|
||||
if content is not None:
|
||||
analysis.content = content
|
||||
if investigation_id is not None:
|
||||
# Check permission for the new investigation as well
|
||||
self._check_permission(user_id, investigation_id, ["update"])
|
||||
analysis.investigation_id = investigation_id
|
||||
|
||||
analysis.last_updated_at = datetime.utcnow()
|
||||
self._commit()
|
||||
self._refresh(analysis)
|
||||
return analysis
|
||||
|
||||
def delete(self, analysis_id: UUID, user_id: UUID) -> None:
|
||||
"""
|
||||
Delete an analysis.
|
||||
|
||||
Args:
|
||||
analysis_id: The analysis ID
|
||||
user_id: The user's ID
|
||||
|
||||
Raises:
|
||||
NotFoundError: If analysis not found
|
||||
PermissionDeniedError: If user doesn't have permission
|
||||
"""
|
||||
analysis = self._db.query(Analysis).filter(Analysis.id == analysis_id).first()
|
||||
if not analysis:
|
||||
raise NotFoundError("Analysis not found")
|
||||
|
||||
self._check_permission(user_id, analysis.investigation_id, ["delete"])
|
||||
|
||||
self._delete(analysis)
|
||||
self._commit()
|
||||
|
||||
|
||||
def create_analysis_service(db: Session) -> AnalysisService:
|
||||
"""
|
||||
Factory function to create an AnalysisService instance.
|
||||
|
||||
Args:
|
||||
db: SQLAlchemy database session
|
||||
|
||||
Returns:
|
||||
Configured AnalysisService instance
|
||||
"""
|
||||
return AnalysisService(db=db)
|
||||
@@ -0,0 +1,94 @@
|
||||
"""
|
||||
Authentication service for user login and registration.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from ..models import Profile
|
||||
from ..auth import verify_password, create_access_token, get_password_hash
|
||||
from .base import BaseService
|
||||
from .exceptions import AuthenticationError, ConflictError, DatabaseError
|
||||
|
||||
|
||||
class AuthService(BaseService):
|
||||
"""
|
||||
Service for user authentication and registration.
|
||||
"""
|
||||
|
||||
def authenticate(self, email: str, password: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Authenticate a user and return an access token.
|
||||
|
||||
Args:
|
||||
email: User's email
|
||||
password: User's password
|
||||
|
||||
Returns:
|
||||
Dictionary containing access_token, user_id, and token_type
|
||||
|
||||
Raises:
|
||||
AuthenticationError: If credentials are invalid
|
||||
"""
|
||||
user = self._db.query(Profile).filter(Profile.email == email).first()
|
||||
|
||||
if not user or not verify_password(password, user.hashed_password):
|
||||
raise AuthenticationError("Incorrect email or password")
|
||||
|
||||
access_token = create_access_token(data={"sub": user.email})
|
||||
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"user_id": user.id,
|
||||
"token_type": "bearer",
|
||||
}
|
||||
|
||||
def register(self, email: str, password: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Register a new user.
|
||||
|
||||
Args:
|
||||
email: User's email
|
||||
password: User's password
|
||||
|
||||
Returns:
|
||||
Dictionary containing success message and email
|
||||
|
||||
Raises:
|
||||
ConflictError: If email is already registered
|
||||
DatabaseError: If database operation fails
|
||||
"""
|
||||
existing_user = self._db.query(Profile).filter(Profile.email == email).first()
|
||||
|
||||
if existing_user:
|
||||
raise ConflictError("Email already registered")
|
||||
|
||||
hashed_password = get_password_hash(password)
|
||||
new_user = Profile(email=email, hashed_password=hashed_password)
|
||||
|
||||
try:
|
||||
self._add(new_user)
|
||||
self._commit()
|
||||
self._refresh(new_user)
|
||||
|
||||
return {
|
||||
"message": "User registered successfully",
|
||||
"email": new_user.email,
|
||||
}
|
||||
except IntegrityError:
|
||||
self._rollback()
|
||||
raise ConflictError("Email already registered")
|
||||
|
||||
|
||||
def create_auth_service(db: Session) -> AuthService:
|
||||
"""
|
||||
Factory function to create an AuthService instance.
|
||||
|
||||
Args:
|
||||
db: SQLAlchemy database session
|
||||
|
||||
Returns:
|
||||
Configured AuthService instance
|
||||
"""
|
||||
return AuthService(db=db)
|
||||
194
flowsint-core/src/flowsint_core/core/services/base.py
Normal file
194
flowsint-core/src/flowsint_core/core/services/base.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""
|
||||
Base service class providing common functionality for all services.
|
||||
"""
|
||||
|
||||
from typing import Type, TypeVar, Optional, List
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from .exceptions import DatabaseError, NotFoundError, PermissionDeniedError
|
||||
from ..models import InvestigationUserRole
|
||||
from ..types import Role
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class BaseService:
|
||||
"""
|
||||
Base class for all services.
|
||||
|
||||
Provides common database operations and error handling patterns.
|
||||
"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
"""
|
||||
Initialize the service with a database session.
|
||||
|
||||
Args:
|
||||
db: SQLAlchemy database session
|
||||
"""
|
||||
self._db = db
|
||||
|
||||
@property
|
||||
def db(self) -> Session:
|
||||
"""Get the database session."""
|
||||
return self._db
|
||||
|
||||
def _commit(self) -> None:
|
||||
"""
|
||||
Commit the current transaction.
|
||||
|
||||
Raises:
|
||||
DatabaseError: If the commit fails
|
||||
"""
|
||||
try:
|
||||
self._db.commit()
|
||||
except SQLAlchemyError as e:
|
||||
self._db.rollback()
|
||||
raise DatabaseError(f"Database error: {e}")
|
||||
|
||||
def _rollback(self) -> None:
|
||||
"""Rollback the current transaction."""
|
||||
self._db.rollback()
|
||||
|
||||
def _flush(self) -> None:
|
||||
"""
|
||||
Flush pending changes to the database.
|
||||
|
||||
Raises:
|
||||
DatabaseError: If the flush fails
|
||||
"""
|
||||
try:
|
||||
self._db.flush()
|
||||
except SQLAlchemyError as e:
|
||||
self._db.rollback()
|
||||
raise DatabaseError(f"Database error: {e}")
|
||||
|
||||
def _refresh(self, entity: T) -> T:
|
||||
"""
|
||||
Refresh an entity from the database.
|
||||
|
||||
Args:
|
||||
entity: The entity to refresh
|
||||
|
||||
Returns:
|
||||
The refreshed entity
|
||||
"""
|
||||
self._db.refresh(entity)
|
||||
return entity
|
||||
|
||||
def _add(self, entity: T) -> T:
|
||||
"""
|
||||
Add an entity to the session.
|
||||
|
||||
Args:
|
||||
entity: The entity to add
|
||||
|
||||
Returns:
|
||||
The added entity
|
||||
"""
|
||||
self._db.add(entity)
|
||||
return entity
|
||||
|
||||
def _delete(self, entity: T) -> None:
|
||||
"""
|
||||
Mark an entity for deletion.
|
||||
|
||||
Args:
|
||||
entity: The entity to delete
|
||||
"""
|
||||
self._db.delete(entity)
|
||||
|
||||
def _get_or_404(self, model: Type[T], id: UUID) -> T:
|
||||
"""
|
||||
Get an entity by ID or raise NotFoundError.
|
||||
|
||||
Args:
|
||||
model: The SQLAlchemy model class
|
||||
id: The entity ID
|
||||
|
||||
Returns:
|
||||
The found entity
|
||||
|
||||
Raises:
|
||||
NotFoundError: If the entity is not found
|
||||
"""
|
||||
entity = self._db.query(model).filter(model.id == id).first()
|
||||
if not entity:
|
||||
raise NotFoundError(f"{model.__name__} not found")
|
||||
return entity
|
||||
|
||||
def _get_by_id(self, model: Type[T], id: UUID) -> Optional[T]:
|
||||
"""
|
||||
Get an entity by ID, returning None if not found.
|
||||
|
||||
Args:
|
||||
model: The SQLAlchemy model class
|
||||
id: The entity ID
|
||||
|
||||
Returns:
|
||||
The found entity or None
|
||||
"""
|
||||
return self._db.query(model).filter(model.id == id).first()
|
||||
|
||||
def _get_all(self, model: Type[T]) -> List[T]:
|
||||
"""
|
||||
Get all entities of a given model.
|
||||
|
||||
Args:
|
||||
model: The SQLAlchemy model class
|
||||
|
||||
Returns:
|
||||
List of all entities
|
||||
"""
|
||||
return self._db.query(model).all()
|
||||
|
||||
def _can_user(self, roles: List[Role], actions: List[str]) -> bool:
|
||||
"""
|
||||
Check if at least one role in the list allows at least one action.
|
||||
|
||||
Args:
|
||||
roles: List of user roles
|
||||
actions: List of actions to check
|
||||
|
||||
Returns:
|
||||
True if any role allows any action
|
||||
"""
|
||||
for role in roles:
|
||||
for action in actions:
|
||||
if role == Role.OWNER:
|
||||
return True
|
||||
if role == Role.EDITOR and action in ["read", "create", "update"]:
|
||||
return True
|
||||
if role == Role.VIEWER and action == "read":
|
||||
return True
|
||||
return False
|
||||
|
||||
def _check_permission(
|
||||
self, user_id: UUID, investigation_id: UUID, actions: List[str]
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a user has permission to perform actions on an investigation.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID
|
||||
investigation_id: The investigation ID
|
||||
actions: List of actions to check
|
||||
|
||||
Returns:
|
||||
True if user has permission
|
||||
|
||||
Raises:
|
||||
PermissionDeniedError: If user doesn't have permission
|
||||
"""
|
||||
role_entry = (
|
||||
self._db.query(InvestigationUserRole)
|
||||
.filter_by(user_id=user_id, investigation_id=investigation_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not role_entry or not self._can_user(role_entry.roles, actions):
|
||||
raise PermissionDeniedError("Forbidden")
|
||||
return True
|
||||
288
flowsint-core/src/flowsint_core/core/services/chat_service.py
Normal file
288
flowsint-core/src/flowsint_core/core/services/chat_service.py
Normal file
@@ -0,0 +1,288 @@
|
||||
"""
|
||||
Chat service for managing chats and messages with AI integration.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import UUID, uuid4
|
||||
from datetime import datetime
|
||||
import os
|
||||
import json
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ..models import Chat, ChatMessage
|
||||
from .base import BaseService
|
||||
from .exceptions import NotFoundError, PermissionDeniedError, DatabaseError
|
||||
|
||||
|
||||
def clean_context(context: List[Dict]) -> List[Dict]:
|
||||
"""Remove unnecessary keys from context data."""
|
||||
cleaned = []
|
||||
for item in context:
|
||||
if isinstance(item, dict):
|
||||
cleaned_item = item.get("data", item).copy()
|
||||
cleaned_item.pop("id", None)
|
||||
cleaned_item.pop("sketch_id", None)
|
||||
if "data" in cleaned_item and isinstance(cleaned_item["data"], dict):
|
||||
cleaned_item["data"].pop("sketch_id", None)
|
||||
cleaned_item.pop("measured", None)
|
||||
cleaned.append(cleaned_item)
|
||||
return cleaned
|
||||
|
||||
|
||||
class ChatService(BaseService):
|
||||
"""
|
||||
Service for chat CRUD operations and AI message streaming.
|
||||
"""
|
||||
|
||||
def get_chats_for_user(self, user_id: UUID) -> List[Chat]:
|
||||
"""
|
||||
Get all chats owned by a user.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID
|
||||
|
||||
Returns:
|
||||
List of chats with sorted messages
|
||||
"""
|
||||
chats = self._db.query(Chat).filter(Chat.owner_id == user_id).all()
|
||||
|
||||
# Sort messages for each chat by created_at in ascending order
|
||||
for chat in chats:
|
||||
chat.messages.sort(key=lambda x: x.created_at)
|
||||
|
||||
return chats
|
||||
|
||||
def get_by_investigation(
|
||||
self, investigation_id: UUID, user_id: UUID
|
||||
) -> List[Chat]:
|
||||
"""
|
||||
Get all chats for an investigation.
|
||||
|
||||
Args:
|
||||
investigation_id: The investigation ID
|
||||
user_id: The user's ID
|
||||
|
||||
Returns:
|
||||
List of chats with sorted messages
|
||||
"""
|
||||
chats = (
|
||||
self._db.query(Chat)
|
||||
.filter(Chat.investigation_id == investigation_id, Chat.owner_id == user_id)
|
||||
.order_by(Chat.created_at.asc())
|
||||
.all()
|
||||
)
|
||||
|
||||
for chat in chats:
|
||||
chat.messages.sort(key=lambda x: x.created_at)
|
||||
|
||||
return chats
|
||||
|
||||
def get_by_id(self, chat_id: UUID, user_id: UUID) -> Chat:
|
||||
"""
|
||||
Get a chat by ID.
|
||||
|
||||
Args:
|
||||
chat_id: The chat ID
|
||||
user_id: The user's ID
|
||||
|
||||
Returns:
|
||||
The chat with sorted messages
|
||||
|
||||
Raises:
|
||||
NotFoundError: If chat not found or doesn't belong to user
|
||||
"""
|
||||
chat = (
|
||||
self._db.query(Chat)
|
||||
.filter(Chat.id == chat_id, Chat.owner_id == user_id)
|
||||
.first()
|
||||
)
|
||||
if not chat:
|
||||
raise NotFoundError("Chat not found")
|
||||
|
||||
chat.messages.sort(key=lambda x: x.created_at)
|
||||
return chat
|
||||
|
||||
def create(
|
||||
self,
|
||||
title: str,
|
||||
description: Optional[str],
|
||||
investigation_id: Optional[UUID],
|
||||
owner_id: UUID,
|
||||
) -> Chat:
|
||||
"""
|
||||
Create a new chat.
|
||||
|
||||
Args:
|
||||
title: Chat title
|
||||
description: Chat description
|
||||
investigation_id: Parent investigation ID (optional)
|
||||
owner_id: Owner user ID
|
||||
|
||||
Returns:
|
||||
The created chat
|
||||
"""
|
||||
new_chat = Chat(
|
||||
id=uuid4(),
|
||||
title=title,
|
||||
description=description,
|
||||
owner_id=owner_id,
|
||||
investigation_id=investigation_id,
|
||||
created_at=datetime.utcnow(),
|
||||
last_updated_at=datetime.utcnow(),
|
||||
)
|
||||
self._add(new_chat)
|
||||
self._commit()
|
||||
self._refresh(new_chat)
|
||||
return new_chat
|
||||
|
||||
def delete(self, chat_id: UUID, user_id: UUID) -> None:
|
||||
"""
|
||||
Delete a chat.
|
||||
|
||||
Args:
|
||||
chat_id: The chat ID
|
||||
user_id: The user's ID
|
||||
|
||||
Raises:
|
||||
NotFoundError: If chat not found or doesn't belong to user
|
||||
"""
|
||||
chat = (
|
||||
self._db.query(Chat)
|
||||
.filter(Chat.id == chat_id, Chat.owner_id == user_id)
|
||||
.first()
|
||||
)
|
||||
if not chat:
|
||||
raise NotFoundError("Chat not found")
|
||||
|
||||
self._delete(chat)
|
||||
self._commit()
|
||||
|
||||
def add_user_message(
|
||||
self, chat_id: UUID, user_id: UUID, content: str, context: Optional[List[Dict]] = None
|
||||
) -> ChatMessage:
|
||||
"""
|
||||
Add a user message to a chat.
|
||||
|
||||
Args:
|
||||
chat_id: The chat ID
|
||||
user_id: The user's ID
|
||||
content: Message content
|
||||
context: Optional context data
|
||||
|
||||
Returns:
|
||||
The created message
|
||||
|
||||
Raises:
|
||||
NotFoundError: If chat not found
|
||||
"""
|
||||
chat = (
|
||||
self._db.query(Chat)
|
||||
.filter(Chat.id == chat_id, Chat.owner_id == user_id)
|
||||
.first()
|
||||
)
|
||||
if not chat:
|
||||
raise NotFoundError("Chat not found")
|
||||
|
||||
# Update chat's last_updated_at
|
||||
chat.last_updated_at = datetime.utcnow()
|
||||
|
||||
user_message = ChatMessage(
|
||||
id=uuid4(),
|
||||
content=content,
|
||||
context=context,
|
||||
chat_id=chat_id,
|
||||
is_bot=False,
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
self._add(user_message)
|
||||
self._commit()
|
||||
self._refresh(user_message)
|
||||
return user_message
|
||||
|
||||
def add_bot_message(self, chat_id: UUID, content: str) -> ChatMessage:
|
||||
"""
|
||||
Add a bot message to a chat.
|
||||
|
||||
Args:
|
||||
chat_id: The chat ID
|
||||
content: Message content
|
||||
|
||||
Returns:
|
||||
The created message
|
||||
"""
|
||||
chat_message = ChatMessage(
|
||||
id=uuid4(),
|
||||
content=content,
|
||||
chat_id=chat_id,
|
||||
is_bot=True,
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
self._add(chat_message)
|
||||
self._commit()
|
||||
self._refresh(chat_message)
|
||||
return chat_message
|
||||
|
||||
def get_chat_with_context(self, chat_id: UUID, user_id: UUID) -> Chat:
|
||||
"""
|
||||
Get a chat with its messages for AI context building.
|
||||
|
||||
Args:
|
||||
chat_id: The chat ID
|
||||
user_id: The user's ID
|
||||
|
||||
Returns:
|
||||
The chat
|
||||
|
||||
Raises:
|
||||
NotFoundError: If chat not found
|
||||
"""
|
||||
return self.get_by_id(chat_id, user_id)
|
||||
|
||||
def prepare_ai_context(
|
||||
self, chat: Chat, user_prompt: str, context: Optional[List[Dict]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Prepare context for AI message generation.
|
||||
|
||||
Args:
|
||||
chat: The chat
|
||||
user_prompt: The user's prompt
|
||||
context: Optional additional context
|
||||
|
||||
Returns:
|
||||
Dictionary with prepared context for AI
|
||||
"""
|
||||
context_message = None
|
||||
if context:
|
||||
try:
|
||||
cleaned_context = clean_context(context)
|
||||
if cleaned_context:
|
||||
context_str = json.dumps(cleaned_context, indent=2, default=str)
|
||||
context_message = f"Context: {context_str}"
|
||||
if len(context_message) > 2000:
|
||||
context_message = context_message[:2000] + "..."
|
||||
except Exception as e:
|
||||
print(f"Context processing error: {e}")
|
||||
|
||||
sorted_messages = sorted(chat.messages, key=lambda x: x.created_at)
|
||||
recent_messages = sorted_messages[-5:] if len(sorted_messages) > 5 else sorted_messages
|
||||
|
||||
return {
|
||||
"recent_messages": recent_messages,
|
||||
"context_message": context_message,
|
||||
"user_prompt": user_prompt,
|
||||
}
|
||||
|
||||
|
||||
def create_chat_service(db: Session) -> ChatService:
|
||||
"""
|
||||
Factory function to create a ChatService instance.
|
||||
|
||||
Args:
|
||||
db: SQLAlchemy database session
|
||||
|
||||
Returns:
|
||||
Configured ChatService instance
|
||||
"""
|
||||
return ChatService(db=db)
|
||||
@@ -0,0 +1,270 @@
|
||||
"""
|
||||
Custom type service for managing user-defined types.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ..models import CustomType
|
||||
from .base import BaseService
|
||||
from .exceptions import NotFoundError, ValidationError, ConflictError
|
||||
|
||||
|
||||
class CustomTypeService(BaseService):
|
||||
"""
|
||||
Service for custom type CRUD operations and validation.
|
||||
"""
|
||||
|
||||
def list_custom_types(
|
||||
self, user_id: UUID, status: Optional[str] = None
|
||||
) -> List[CustomType]:
|
||||
"""
|
||||
List all custom types for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID
|
||||
status: Optional status filter (draft, published, archived)
|
||||
|
||||
Returns:
|
||||
List of custom types
|
||||
|
||||
Raises:
|
||||
ValidationError: If invalid status provided
|
||||
"""
|
||||
query = self._db.query(CustomType).filter(CustomType.owner_id == user_id)
|
||||
|
||||
if status:
|
||||
if status not in ["draft", "published", "archived"]:
|
||||
raise ValidationError(
|
||||
"Status must be one of: draft, published, archived"
|
||||
)
|
||||
query = query.filter(CustomType.status == status)
|
||||
|
||||
return query.order_by(CustomType.created_at.desc()).all()
|
||||
|
||||
def get_by_id(self, custom_type_id: UUID, user_id: UUID) -> CustomType:
|
||||
"""
|
||||
Get a custom type by ID.
|
||||
|
||||
Args:
|
||||
custom_type_id: The custom type ID
|
||||
user_id: The user's ID
|
||||
|
||||
Returns:
|
||||
The custom type
|
||||
|
||||
Raises:
|
||||
NotFoundError: If custom type not found
|
||||
"""
|
||||
custom_type = (
|
||||
self._db.query(CustomType)
|
||||
.filter(CustomType.id == custom_type_id, CustomType.owner_id == user_id)
|
||||
.first()
|
||||
)
|
||||
if not custom_type:
|
||||
raise NotFoundError("Custom type not found")
|
||||
return custom_type
|
||||
|
||||
def get_schema(self, custom_type_id: UUID, user_id: UUID) -> Dict[str, Any]:
|
||||
"""
|
||||
Get the raw JSON Schema for a custom type.
|
||||
|
||||
Args:
|
||||
custom_type_id: The custom type ID
|
||||
user_id: The user's ID
|
||||
|
||||
Returns:
|
||||
The JSON schema
|
||||
|
||||
Raises:
|
||||
NotFoundError: If custom type not found
|
||||
"""
|
||||
custom_type = self.get_by_id(custom_type_id, user_id)
|
||||
return custom_type.schema
|
||||
|
||||
def create(
|
||||
self,
|
||||
name: str,
|
||||
json_schema: Dict[str, Any],
|
||||
user_id: UUID,
|
||||
description: Optional[str] = None,
|
||||
status: str = "draft",
|
||||
validate_schema_func=None,
|
||||
calculate_checksum_func=None,
|
||||
) -> CustomType:
|
||||
"""
|
||||
Create a new custom type.
|
||||
|
||||
Args:
|
||||
name: Type name
|
||||
json_schema: The JSON Schema
|
||||
user_id: The owner's ID
|
||||
description: Optional description
|
||||
status: Initial status (default: draft)
|
||||
validate_schema_func: Function to validate JSON schema
|
||||
calculate_checksum_func: Function to calculate schema checksum
|
||||
|
||||
Returns:
|
||||
The created custom type
|
||||
|
||||
Raises:
|
||||
ValidationError: If schema is invalid
|
||||
ConflictError: If name already exists
|
||||
"""
|
||||
# Validate the JSON Schema
|
||||
if validate_schema_func:
|
||||
validate_schema_func(json_schema)
|
||||
|
||||
# Calculate checksum
|
||||
checksum = calculate_checksum_func(json_schema) if calculate_checksum_func else None
|
||||
|
||||
# Check for duplicate name for this user
|
||||
existing = (
|
||||
self._db.query(CustomType)
|
||||
.filter(CustomType.owner_id == user_id, CustomType.name == name)
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
raise ConflictError(f"Custom type with name '{name}' already exists")
|
||||
|
||||
# Create the custom type
|
||||
db_custom_type = CustomType(
|
||||
name=name,
|
||||
owner_id=user_id,
|
||||
schema=json_schema,
|
||||
description=description,
|
||||
status=status,
|
||||
checksum=checksum,
|
||||
)
|
||||
|
||||
self._add(db_custom_type)
|
||||
self._commit()
|
||||
self._refresh(db_custom_type)
|
||||
|
||||
return db_custom_type
|
||||
|
||||
def update(
|
||||
self,
|
||||
custom_type_id: UUID,
|
||||
user_id: UUID,
|
||||
name: Optional[str] = None,
|
||||
json_schema: Optional[Dict[str, Any]] = None,
|
||||
description: Optional[str] = None,
|
||||
status: Optional[str] = None,
|
||||
validate_schema_func=None,
|
||||
calculate_checksum_func=None,
|
||||
) -> CustomType:
|
||||
"""
|
||||
Update a custom type.
|
||||
|
||||
Args:
|
||||
custom_type_id: The custom type ID
|
||||
user_id: The user's ID
|
||||
name: New name (optional)
|
||||
json_schema: New schema (optional)
|
||||
description: New description (optional)
|
||||
status: New status (optional)
|
||||
validate_schema_func: Function to validate JSON schema
|
||||
calculate_checksum_func: Function to calculate schema checksum
|
||||
|
||||
Returns:
|
||||
The updated custom type
|
||||
|
||||
Raises:
|
||||
NotFoundError: If custom type not found
|
||||
ValidationError: If schema is invalid
|
||||
ConflictError: If name already exists
|
||||
"""
|
||||
custom_type = self.get_by_id(custom_type_id, user_id)
|
||||
|
||||
if name is not None:
|
||||
# Check for duplicate name
|
||||
existing = (
|
||||
self._db.query(CustomType)
|
||||
.filter(
|
||||
CustomType.owner_id == user_id,
|
||||
CustomType.name == name,
|
||||
CustomType.id != custom_type_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
raise ConflictError(f"Custom type with name '{name}' already exists")
|
||||
custom_type.name = name
|
||||
|
||||
if json_schema is not None:
|
||||
if validate_schema_func:
|
||||
validate_schema_func(json_schema)
|
||||
custom_type.schema = json_schema
|
||||
if calculate_checksum_func:
|
||||
custom_type.checksum = calculate_checksum_func(json_schema)
|
||||
|
||||
if description is not None:
|
||||
custom_type.description = description
|
||||
|
||||
if status is not None:
|
||||
custom_type.status = status
|
||||
|
||||
self._commit()
|
||||
self._refresh(custom_type)
|
||||
|
||||
return custom_type
|
||||
|
||||
def delete(self, custom_type_id: UUID, user_id: UUID) -> None:
|
||||
"""
|
||||
Delete a custom type.
|
||||
|
||||
Args:
|
||||
custom_type_id: The custom type ID
|
||||
user_id: The user's ID
|
||||
|
||||
Raises:
|
||||
NotFoundError: If custom type not found
|
||||
"""
|
||||
custom_type = self.get_by_id(custom_type_id, user_id)
|
||||
self._delete(custom_type)
|
||||
self._commit()
|
||||
|
||||
def validate_payload(
|
||||
self,
|
||||
custom_type_id: UUID,
|
||||
user_id: UUID,
|
||||
payload: Dict[str, Any],
|
||||
validate_payload_func=None,
|
||||
) -> Tuple[bool, Optional[List[str]]]:
|
||||
"""
|
||||
Validate a payload against a custom type's schema.
|
||||
|
||||
Args:
|
||||
custom_type_id: The custom type ID
|
||||
user_id: The user's ID
|
||||
payload: The payload to validate
|
||||
validate_payload_func: Function to validate payload against schema
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, errors)
|
||||
|
||||
Raises:
|
||||
NotFoundError: If custom type not found
|
||||
"""
|
||||
custom_type = self.get_by_id(custom_type_id, user_id)
|
||||
|
||||
if validate_payload_func:
|
||||
return validate_payload_func(payload, custom_type.schema)
|
||||
|
||||
return True, None
|
||||
|
||||
|
||||
def create_custom_type_service(db: Session) -> CustomTypeService:
|
||||
"""
|
||||
Factory function to create a CustomTypeService instance.
|
||||
|
||||
Args:
|
||||
db: SQLAlchemy database session
|
||||
|
||||
Returns:
|
||||
Configured CustomTypeService instance
|
||||
"""
|
||||
return CustomTypeService(db=db)
|
||||
@@ -0,0 +1,64 @@
|
||||
"""
|
||||
Enricher service for managing enricher operations.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ..models import CustomType
|
||||
from .base import BaseService
|
||||
|
||||
|
||||
class EnricherService(BaseService):
|
||||
"""
|
||||
Service for enricher operations and listing.
|
||||
"""
|
||||
|
||||
def get_enrichers(
|
||||
self, category: Optional[str], user_id: UUID, enricher_registry
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get enrichers, optionally filtered by category.
|
||||
|
||||
Args:
|
||||
category: Optional category filter
|
||||
user_id: The user's ID
|
||||
enricher_registry: The enricher registry instance
|
||||
|
||||
Returns:
|
||||
List of enrichers
|
||||
"""
|
||||
if not category or category.lower() == "undefined":
|
||||
return enricher_registry.list(exclude=["n8n_connector"])
|
||||
|
||||
# Check if category is a custom type
|
||||
custom_type = (
|
||||
self._db.query(CustomType)
|
||||
.filter(
|
||||
CustomType.owner_id == user_id,
|
||||
CustomType.status == "published",
|
||||
func.lower(CustomType.name) == category.lower(),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if custom_type:
|
||||
return enricher_registry.list(exclude=["n8n_connector"], wobbly_type=True)
|
||||
|
||||
return enricher_registry.list_by_input_type(category, exclude=["n8n_connector"])
|
||||
|
||||
|
||||
def create_enricher_service(db: Session) -> EnricherService:
|
||||
"""
|
||||
Factory function to create an EnricherService instance.
|
||||
|
||||
Args:
|
||||
db: SQLAlchemy database session
|
||||
|
||||
Returns:
|
||||
Configured EnricherService instance
|
||||
"""
|
||||
return EnricherService(db=db)
|
||||
56
flowsint-core/src/flowsint_core/core/services/exceptions.py
Normal file
56
flowsint-core/src/flowsint_core/core/services/exceptions.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""
|
||||
Domain exceptions for the service layer.
|
||||
|
||||
These exceptions represent business logic errors that can be caught
|
||||
and converted to appropriate HTTP responses by route handlers.
|
||||
"""
|
||||
|
||||
|
||||
class ServiceError(Exception):
|
||||
"""Base exception for all service errors."""
|
||||
|
||||
def __init__(self, message: str = "A service error occurred"):
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
class NotFoundError(ServiceError):
|
||||
"""Entity not found."""
|
||||
|
||||
def __init__(self, message: str = "Entity not found"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class PermissionDeniedError(ServiceError):
|
||||
"""User does not have permission to perform the action."""
|
||||
|
||||
def __init__(self, message: str = "Permission denied"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ValidationError(ServiceError):
|
||||
"""Input validation failed."""
|
||||
|
||||
def __init__(self, message: str = "Validation failed"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class DatabaseError(ServiceError):
|
||||
"""Database operation failed."""
|
||||
|
||||
def __init__(self, message: str = "Database operation failed"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class AuthenticationError(ServiceError):
|
||||
"""Authentication failed."""
|
||||
|
||||
def __init__(self, message: str = "Authentication failed"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ConflictError(ServiceError):
|
||||
"""Resource conflict (e.g., duplicate entry)."""
|
||||
|
||||
def __init__(self, message: str = "Resource conflict"):
|
||||
super().__init__(message)
|
||||
200
flowsint-core/src/flowsint_core/core/services/flow_service.py
Normal file
200
flowsint-core/src/flowsint_core/core/services/flow_service.py
Normal file
@@ -0,0 +1,200 @@
|
||||
"""
|
||||
Flow service for managing flows and flow computations.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import UUID, uuid4
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ..models import Flow, CustomType, Sketch
|
||||
from ..types import FlowBranch, FlowEdge, FlowNode, FlowStep
|
||||
from .base import BaseService
|
||||
from .exceptions import NotFoundError, PermissionDeniedError
|
||||
|
||||
|
||||
class FlowService(BaseService):
|
||||
"""
|
||||
Service for flow CRUD operations and flow computations.
|
||||
"""
|
||||
|
||||
def get_all_flows(
|
||||
self, category: Optional[str], user_id: UUID
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get all flows, optionally filtered by category.
|
||||
|
||||
Args:
|
||||
category: Optional category filter
|
||||
user_id: The user's ID
|
||||
|
||||
Returns:
|
||||
List of flows
|
||||
"""
|
||||
if not category or category.lower() == "undefined":
|
||||
return self._db.query(Flow).order_by(Flow.last_updated_at.desc()).all()
|
||||
|
||||
# Check if category is a custom type
|
||||
custom_type = (
|
||||
self._db.query(CustomType)
|
||||
.filter(
|
||||
CustomType.owner_id == user_id,
|
||||
CustomType.status == "published",
|
||||
func.lower(CustomType.name) == category.lower(),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if custom_type:
|
||||
flows = self._db.query(Flow).order_by(Flow.last_updated_at.desc()).all()
|
||||
return [
|
||||
{
|
||||
**(flow.to_dict() if hasattr(flow, "to_dict") else flow.__dict__),
|
||||
"wobblyType": True,
|
||||
}
|
||||
for flow in flows
|
||||
]
|
||||
|
||||
# Filter by category
|
||||
flows = self._db.query(Flow).order_by(Flow.last_updated_at.desc()).all()
|
||||
return [
|
||||
flow
|
||||
for flow in flows
|
||||
if any(cat.lower() == category.lower() for cat in flow.category)
|
||||
]
|
||||
|
||||
def get_by_id(self, flow_id: UUID) -> Flow:
|
||||
"""
|
||||
Get a flow by ID.
|
||||
|
||||
Args:
|
||||
flow_id: The flow ID
|
||||
|
||||
Returns:
|
||||
The flow
|
||||
|
||||
Raises:
|
||||
NotFoundError: If flow not found
|
||||
"""
|
||||
flow = self._db.query(Flow).filter(Flow.id == flow_id).first()
|
||||
if not flow:
|
||||
raise NotFoundError("Flow not found")
|
||||
return flow
|
||||
|
||||
def create(
|
||||
self,
|
||||
name: str,
|
||||
description: Optional[str],
|
||||
category: List[str],
|
||||
flow_schema: Dict[str, Any],
|
||||
) -> Flow:
|
||||
"""
|
||||
Create a new flow.
|
||||
|
||||
Args:
|
||||
name: Flow name
|
||||
description: Flow description
|
||||
category: List of categories
|
||||
flow_schema: Flow schema (nodes and edges)
|
||||
|
||||
Returns:
|
||||
The created flow
|
||||
"""
|
||||
new_flow = Flow(
|
||||
id=uuid4(),
|
||||
name=name,
|
||||
description=description,
|
||||
category=category,
|
||||
flow_schema=flow_schema,
|
||||
created_at=datetime.utcnow(),
|
||||
last_updated_at=datetime.utcnow(),
|
||||
)
|
||||
self._add(new_flow)
|
||||
self._commit()
|
||||
self._refresh(new_flow)
|
||||
return new_flow
|
||||
|
||||
def update(
|
||||
self, flow_id: UUID, updates: Dict[str, Any]
|
||||
) -> Flow:
|
||||
"""
|
||||
Update a flow.
|
||||
|
||||
Args:
|
||||
flow_id: The flow ID
|
||||
updates: Dictionary of updates
|
||||
|
||||
Returns:
|
||||
The updated flow
|
||||
|
||||
Raises:
|
||||
NotFoundError: If flow not found
|
||||
"""
|
||||
flow = self._db.query(Flow).filter(Flow.id == flow_id).first()
|
||||
if not flow:
|
||||
raise NotFoundError("Flow not found")
|
||||
|
||||
for key, value in updates.items():
|
||||
if key == "category":
|
||||
if "SocialAccount" in value:
|
||||
value.append("Username")
|
||||
setattr(flow, key, value)
|
||||
|
||||
flow.last_updated_at = datetime.utcnow()
|
||||
self._commit()
|
||||
self._refresh(flow)
|
||||
return flow
|
||||
|
||||
def delete(self, flow_id: UUID) -> None:
|
||||
"""
|
||||
Delete a flow.
|
||||
|
||||
Args:
|
||||
flow_id: The flow ID
|
||||
|
||||
Raises:
|
||||
NotFoundError: If flow not found
|
||||
"""
|
||||
flow = self._db.query(Flow).filter(Flow.id == flow_id).first()
|
||||
if not flow:
|
||||
raise NotFoundError("Flow not found")
|
||||
|
||||
self._delete(flow)
|
||||
self._commit()
|
||||
|
||||
def get_sketch_for_launch(self, sketch_id: str, user_id: UUID) -> Sketch:
|
||||
"""
|
||||
Get sketch for flow launch with permission check.
|
||||
|
||||
Args:
|
||||
sketch_id: The sketch ID
|
||||
user_id: The user's ID
|
||||
|
||||
Returns:
|
||||
The sketch
|
||||
|
||||
Raises:
|
||||
NotFoundError: If sketch not found
|
||||
PermissionDeniedError: If user doesn't have permission
|
||||
"""
|
||||
sketch = self._db.query(Sketch).filter(Sketch.id == sketch_id).first()
|
||||
if not sketch:
|
||||
raise NotFoundError("Sketch not found")
|
||||
|
||||
self._check_permission(user_id, sketch.investigation_id, ["update"])
|
||||
return sketch
|
||||
|
||||
|
||||
def create_flow_service(db: Session) -> FlowService:
|
||||
"""
|
||||
Factory function to create a FlowService instance.
|
||||
|
||||
Args:
|
||||
db: SQLAlchemy database session
|
||||
|
||||
Returns:
|
||||
Configured FlowService instance
|
||||
"""
|
||||
return FlowService(db=db)
|
||||
@@ -0,0 +1,269 @@
|
||||
"""
|
||||
Investigation service for managing investigations and user roles.
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from uuid import UUID, uuid4
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
from ..models import Investigation, InvestigationUserRole, Sketch, Analysis
|
||||
from ..types import Role
|
||||
from ..graph import create_graph_service
|
||||
from .base import BaseService
|
||||
from .exceptions import NotFoundError, PermissionDeniedError, DatabaseError
|
||||
|
||||
|
||||
class InvestigationService(BaseService):
|
||||
"""
|
||||
Service for investigation CRUD operations and role management.
|
||||
"""
|
||||
|
||||
def get_accessible_investigations(
|
||||
self, user_id: UUID, allowed_roles: Optional[List[Role]] = None
|
||||
) -> List[Investigation]:
|
||||
"""
|
||||
Get all investigations accessible to a user based on their roles.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID
|
||||
allowed_roles: Optional list of roles to filter by
|
||||
|
||||
Returns:
|
||||
List of accessible investigations
|
||||
"""
|
||||
query = self._db.query(Investigation).join(
|
||||
InvestigationUserRole,
|
||||
InvestigationUserRole.investigation_id == Investigation.id,
|
||||
)
|
||||
|
||||
query = query.filter(InvestigationUserRole.user_id == user_id)
|
||||
|
||||
if allowed_roles:
|
||||
conditions = [
|
||||
InvestigationUserRole.roles.any(role) for role in allowed_roles
|
||||
]
|
||||
query = query.filter(or_(*conditions, Investigation.owner_id == user_id))
|
||||
|
||||
return (
|
||||
query.options(
|
||||
selectinload(Investigation.sketches),
|
||||
selectinload(Investigation.analyses),
|
||||
selectinload(Investigation.owner),
|
||||
)
|
||||
.distinct()
|
||||
.all()
|
||||
)
|
||||
|
||||
def get_by_id(self, investigation_id: UUID, user_id: UUID) -> Investigation:
|
||||
"""
|
||||
Get an investigation by ID with permission check.
|
||||
|
||||
Args:
|
||||
investigation_id: The investigation ID
|
||||
user_id: The user's ID
|
||||
|
||||
Returns:
|
||||
The investigation
|
||||
|
||||
Raises:
|
||||
PermissionDeniedError: If user doesn't have read permission
|
||||
NotFoundError: If investigation not found
|
||||
"""
|
||||
self._check_permission(user_id, investigation_id, actions=["read"])
|
||||
|
||||
investigation = (
|
||||
self._db.query(Investigation)
|
||||
.options(
|
||||
selectinload(Investigation.sketches),
|
||||
selectinload(Investigation.analyses),
|
||||
selectinload(Investigation.owner),
|
||||
)
|
||||
.filter(Investigation.id == investigation_id)
|
||||
.filter(Investigation.owner_id == user_id)
|
||||
.first()
|
||||
)
|
||||
if not investigation:
|
||||
raise NotFoundError("Investigation not found")
|
||||
return investigation
|
||||
|
||||
def get_sketches(self, investigation_id: UUID, user_id: UUID) -> List[Sketch]:
|
||||
"""
|
||||
Get all sketches for an investigation.
|
||||
|
||||
Args:
|
||||
investigation_id: The investigation ID
|
||||
user_id: The user's ID
|
||||
|
||||
Returns:
|
||||
List of sketches
|
||||
|
||||
Raises:
|
||||
PermissionDeniedError: If user doesn't have read permission
|
||||
NotFoundError: If no sketches found
|
||||
"""
|
||||
self._check_permission(user_id, investigation_id, actions=["read"])
|
||||
|
||||
sketches = (
|
||||
self._db.query(Sketch)
|
||||
.filter(Sketch.investigation_id == investigation_id)
|
||||
.all()
|
||||
)
|
||||
if not sketches:
|
||||
raise NotFoundError("No sketches found for this investigation")
|
||||
return sketches
|
||||
|
||||
def create(
|
||||
self, name: str, description: Optional[str], owner_id: UUID
|
||||
) -> Investigation:
|
||||
"""
|
||||
Create a new investigation with owner role.
|
||||
|
||||
Args:
|
||||
name: Investigation name
|
||||
description: Investigation description
|
||||
owner_id: The owner's user ID
|
||||
|
||||
Returns:
|
||||
The created investigation
|
||||
"""
|
||||
new_investigation = Investigation(
|
||||
id=uuid4(),
|
||||
name=name,
|
||||
description=description or name,
|
||||
owner_id=owner_id,
|
||||
status="active",
|
||||
)
|
||||
self._add(new_investigation)
|
||||
|
||||
new_roles = InvestigationUserRole(
|
||||
id=uuid4(),
|
||||
user_id=owner_id,
|
||||
investigation_id=new_investigation.id,
|
||||
roles=[Role.OWNER],
|
||||
)
|
||||
self._add(new_roles)
|
||||
|
||||
self._commit()
|
||||
self._refresh(new_investigation)
|
||||
|
||||
return new_investigation
|
||||
|
||||
def update(
|
||||
self,
|
||||
investigation_id: UUID,
|
||||
user_id: UUID,
|
||||
name: str,
|
||||
description: str,
|
||||
status: str,
|
||||
) -> Investigation:
|
||||
"""
|
||||
Update an investigation.
|
||||
|
||||
Args:
|
||||
investigation_id: The investigation ID
|
||||
user_id: The user's ID
|
||||
name: New name
|
||||
description: New description
|
||||
status: New status
|
||||
|
||||
Returns:
|
||||
The updated investigation
|
||||
|
||||
Raises:
|
||||
PermissionDeniedError: If user doesn't have write permission
|
||||
NotFoundError: If investigation not found
|
||||
"""
|
||||
self._check_permission(user_id, investigation_id, actions=["write"])
|
||||
|
||||
investigation = (
|
||||
self._db.query(Investigation)
|
||||
.filter(Investigation.id == investigation_id)
|
||||
.first()
|
||||
)
|
||||
if not investigation:
|
||||
raise NotFoundError("Investigation not found")
|
||||
|
||||
investigation.name = name
|
||||
investigation.description = description
|
||||
investigation.status = status
|
||||
investigation.last_updated_at = datetime.utcnow()
|
||||
|
||||
self._commit()
|
||||
self._refresh(investigation)
|
||||
return investigation
|
||||
|
||||
def delete(self, investigation_id: UUID, user_id: UUID) -> None:
|
||||
"""
|
||||
Delete an investigation and all related data.
|
||||
|
||||
Args:
|
||||
investigation_id: The investigation ID
|
||||
user_id: The user's ID
|
||||
|
||||
Raises:
|
||||
PermissionDeniedError: If user doesn't have delete permission
|
||||
NotFoundError: If investigation not found
|
||||
DatabaseError: If graph cleanup fails
|
||||
"""
|
||||
self._check_permission(user_id, investigation_id, actions=["delete"])
|
||||
|
||||
investigation = (
|
||||
self._db.query(Investigation)
|
||||
.filter(
|
||||
Investigation.id == investigation_id,
|
||||
Investigation.owner_id == user_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not investigation:
|
||||
raise NotFoundError("Investigation not found")
|
||||
|
||||
# Get all sketches related to this investigation
|
||||
sketches = (
|
||||
self._db.query(Sketch)
|
||||
.filter(Sketch.investigation_id == investigation_id)
|
||||
.all()
|
||||
)
|
||||
analyses = (
|
||||
self._db.query(Analysis)
|
||||
.filter(Analysis.investigation_id == investigation_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Delete all nodes and relationships for each sketch in Neo4j
|
||||
for sketch in sketches:
|
||||
try:
|
||||
graph_service = create_graph_service(
|
||||
sketch_id=str(sketch.id),
|
||||
enable_batching=False,
|
||||
)
|
||||
graph_service.delete_all_sketch_nodes()
|
||||
except Exception as e:
|
||||
print(f"Neo4j cleanup error for sketch {sketch.id}: {e}")
|
||||
raise DatabaseError("Failed to clean up graph data")
|
||||
|
||||
# Delete all sketches and analyses from PostgreSQL
|
||||
for sketch in sketches:
|
||||
self._delete(sketch)
|
||||
for analysis in analyses:
|
||||
self._delete(analysis)
|
||||
|
||||
# Finally delete the investigation
|
||||
self._delete(investigation)
|
||||
self._commit()
|
||||
|
||||
|
||||
def create_investigation_service(db: Session) -> InvestigationService:
|
||||
"""
|
||||
Factory function to create an InvestigationService instance.
|
||||
|
||||
Args:
|
||||
db: SQLAlchemy database session
|
||||
|
||||
Returns:
|
||||
Configured InvestigationService instance
|
||||
"""
|
||||
return InvestigationService(db=db)
|
||||
120
flowsint-core/src/flowsint_core/core/services/key_service.py
Normal file
120
flowsint-core/src/flowsint_core/core/services/key_service.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""
|
||||
API key management service with Vault integration.
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ..models import Key
|
||||
from ..vault import Vault
|
||||
from .base import BaseService
|
||||
from .exceptions import NotFoundError, DatabaseError
|
||||
|
||||
|
||||
class KeyService(BaseService):
|
||||
"""
|
||||
Service for API key management with encryption via Vault.
|
||||
"""
|
||||
|
||||
def get_keys_for_user(self, user_id: UUID) -> List[Key]:
|
||||
"""
|
||||
Get all keys for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID
|
||||
|
||||
Returns:
|
||||
List of Key entities (without decrypted values)
|
||||
"""
|
||||
return self._db.query(Key).filter(Key.owner_id == user_id).all()
|
||||
|
||||
def get_key_by_id(self, key_id: UUID, user_id: UUID) -> Key:
|
||||
"""
|
||||
Get a specific key by ID for a user.
|
||||
|
||||
Args:
|
||||
key_id: The key's ID
|
||||
user_id: The user's ID (for ownership verification)
|
||||
|
||||
Returns:
|
||||
The Key entity
|
||||
|
||||
Raises:
|
||||
NotFoundError: If key not found or doesn't belong to user
|
||||
"""
|
||||
key = (
|
||||
self._db.query(Key)
|
||||
.filter(Key.id == key_id, Key.owner_id == user_id)
|
||||
.first()
|
||||
)
|
||||
if not key:
|
||||
raise NotFoundError("Key not found")
|
||||
return key
|
||||
|
||||
def create_key(self, name: str, key_value: str, user_id: UUID) -> Key:
|
||||
"""
|
||||
Create a new encrypted API key.
|
||||
|
||||
Args:
|
||||
name: Key name (e.g., "shodan", "whoxy")
|
||||
key_value: The plain text key value
|
||||
user_id: The owner's ID
|
||||
|
||||
Returns:
|
||||
The created Key entity
|
||||
|
||||
Raises:
|
||||
DatabaseError: If key creation fails
|
||||
"""
|
||||
try:
|
||||
vault = Vault(db=self._db, owner_id=user_id)
|
||||
key = vault.set_secret(vault_ref=name, plain_key=key_value)
|
||||
if not key:
|
||||
raise DatabaseError("An error occurred creating the key")
|
||||
return key
|
||||
except Exception as e:
|
||||
raise DatabaseError(f"An error occurred creating the key: {e}")
|
||||
|
||||
def delete_key(self, key_id: UUID, user_id: UUID) -> None:
|
||||
"""
|
||||
Delete a key by ID.
|
||||
|
||||
Args:
|
||||
key_id: The key's ID
|
||||
user_id: The user's ID (for ownership verification)
|
||||
|
||||
Raises:
|
||||
NotFoundError: If key not found or doesn't belong to user
|
||||
"""
|
||||
key = self.get_key_by_id(key_id, user_id)
|
||||
self._delete(key)
|
||||
self._commit()
|
||||
|
||||
def get_decrypted_key(self, name_or_id: str, user_id: UUID) -> Optional[str]:
|
||||
"""
|
||||
Get a decrypted key value by name or ID.
|
||||
|
||||
Args:
|
||||
name_or_id: Either the key name or UUID
|
||||
user_id: The owner's ID
|
||||
|
||||
Returns:
|
||||
The decrypted key value or None if not found
|
||||
"""
|
||||
vault = Vault(db=self._db, owner_id=user_id)
|
||||
return vault.get_secret(name_or_id)
|
||||
|
||||
|
||||
def create_key_service(db: Session) -> KeyService:
|
||||
"""
|
||||
Factory function to create a KeyService instance.
|
||||
|
||||
Args:
|
||||
db: SQLAlchemy database session
|
||||
|
||||
Returns:
|
||||
Configured KeyService instance
|
||||
"""
|
||||
return KeyService(db=db)
|
||||
161
flowsint-core/src/flowsint_core/core/services/log_service.py
Normal file
161
flowsint-core/src/flowsint_core/core/services/log_service.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""
|
||||
Log service for managing event logs.
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ..models import Log, Sketch, Scan
|
||||
from ..types import Event
|
||||
from ..enums import EventLevel
|
||||
from .base import BaseService
|
||||
from .exceptions import NotFoundError, PermissionDeniedError, DatabaseError
|
||||
|
||||
|
||||
class LogService(BaseService):
|
||||
"""
|
||||
Service for log operations.
|
||||
"""
|
||||
|
||||
def _get_sketch_with_permission(
|
||||
self, sketch_id: str, user_id: UUID, actions: List[str]
|
||||
) -> Sketch:
|
||||
"""Get sketch and verify user has permission."""
|
||||
sketch = self._db.query(Sketch).filter(Sketch.id == sketch_id).first()
|
||||
if not sketch:
|
||||
raise NotFoundError(f"Sketch with id {sketch_id} not found")
|
||||
self._check_permission(user_id, sketch.investigation_id, actions)
|
||||
return sketch
|
||||
|
||||
def get_logs_by_sketch(
|
||||
self,
|
||||
sketch_id: str,
|
||||
user_id: UUID,
|
||||
limit: int = 100,
|
||||
since: Optional[datetime] = None,
|
||||
) -> List[Event]:
|
||||
"""
|
||||
Get historical logs for a specific sketch.
|
||||
|
||||
Args:
|
||||
sketch_id: The sketch ID
|
||||
user_id: The user's ID
|
||||
limit: Maximum number of logs to return
|
||||
since: Only return logs after this time
|
||||
|
||||
Returns:
|
||||
List of Event objects
|
||||
|
||||
Raises:
|
||||
NotFoundError: If sketch not found
|
||||
PermissionDeniedError: If user doesn't have permission
|
||||
"""
|
||||
self._get_sketch_with_permission(sketch_id, user_id, ["read"])
|
||||
|
||||
query = (
|
||||
self._db.query(Log)
|
||||
.filter(Log.sketch_id == sketch_id)
|
||||
.order_by(Log.created_at.desc())
|
||||
)
|
||||
|
||||
if since:
|
||||
query = query.filter(Log.created_at > since)
|
||||
else:
|
||||
# Default to last 24 hours if no since parameter
|
||||
query = query.filter(Log.created_at > datetime.utcnow() - timedelta(days=1))
|
||||
|
||||
logs = query.limit(limit).all()
|
||||
|
||||
# Reverse to show chronologically (oldest to newest)
|
||||
logs = list(reversed(logs))
|
||||
|
||||
results = []
|
||||
for log in logs:
|
||||
# Ensure payload is always a dictionary
|
||||
if isinstance(log.content, dict):
|
||||
payload = log.content
|
||||
elif isinstance(log.content, str):
|
||||
payload = {"message": log.content}
|
||||
elif log.content is None:
|
||||
payload = {}
|
||||
else:
|
||||
payload = {"content": str(log.content)}
|
||||
|
||||
results.append(
|
||||
Event(
|
||||
id=str(log.id),
|
||||
sketch_id=str(log.sketch_id) if log.sketch_id else None,
|
||||
type=log.type,
|
||||
payload=payload,
|
||||
created_at=log.created_at,
|
||||
)
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
def delete_logs_by_sketch(self, sketch_id: str, user_id: UUID) -> dict:
|
||||
"""
|
||||
Delete all logs for a specific sketch.
|
||||
|
||||
Args:
|
||||
sketch_id: The sketch ID
|
||||
user_id: The user's ID
|
||||
|
||||
Returns:
|
||||
Success message
|
||||
|
||||
Raises:
|
||||
NotFoundError: If sketch not found
|
||||
PermissionDeniedError: If user doesn't have permission
|
||||
DatabaseError: If deletion fails
|
||||
"""
|
||||
self._get_sketch_with_permission(sketch_id, user_id, ["delete"])
|
||||
|
||||
try:
|
||||
self._db.query(Log).filter(Log.sketch_id == sketch_id).delete()
|
||||
self._commit()
|
||||
return {"message": "All logs have been deleted successfully"}
|
||||
except Exception as e:
|
||||
self._rollback()
|
||||
raise DatabaseError(f"Failed to delete logs: {str(e)}")
|
||||
|
||||
def get_scan_with_permission(self, scan_id: str, user_id: UUID) -> Scan:
|
||||
"""
|
||||
Get a scan and verify user has permission via sketch.
|
||||
|
||||
Args:
|
||||
scan_id: The scan ID
|
||||
user_id: The user's ID
|
||||
|
||||
Returns:
|
||||
The scan
|
||||
|
||||
Raises:
|
||||
NotFoundError: If scan not found
|
||||
PermissionDeniedError: If user doesn't have permission
|
||||
"""
|
||||
scan = self._db.query(Scan).filter(Scan.id == scan_id).first()
|
||||
if not scan:
|
||||
raise NotFoundError(f"Scan with id {scan_id} not found")
|
||||
|
||||
sketch = self._db.query(Sketch).filter(Sketch.id == scan.sketch_id).first()
|
||||
if sketch:
|
||||
self._check_permission(user_id, sketch.investigation_id, ["read"])
|
||||
|
||||
return scan
|
||||
|
||||
|
||||
def create_log_service(db: Session) -> LogService:
|
||||
"""
|
||||
Factory function to create a LogService instance.
|
||||
|
||||
Args:
|
||||
db: SQLAlchemy database session
|
||||
|
||||
Returns:
|
||||
Configured LogService instance
|
||||
"""
|
||||
return LogService(db=db)
|
||||
111
flowsint-core/src/flowsint_core/core/services/scan_service.py
Normal file
111
flowsint-core/src/flowsint_core/core/services/scan_service.py
Normal file
@@ -0,0 +1,111 @@
|
||||
"""
|
||||
Scan service for managing scans.
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ..models import Scan, Sketch, InvestigationUserRole
|
||||
from ..types import Role
|
||||
from .base import BaseService
|
||||
from .exceptions import NotFoundError, PermissionDeniedError
|
||||
|
||||
|
||||
class ScanService(BaseService):
|
||||
"""
|
||||
Service for scan operations.
|
||||
"""
|
||||
|
||||
def get_accessible_scans(self, user_id: UUID) -> List[Scan]:
|
||||
"""
|
||||
Get all scans accessible to a user based on their investigation roles.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID
|
||||
|
||||
Returns:
|
||||
List of accessible scans
|
||||
"""
|
||||
allowed_roles = [Role.OWNER, Role.EDITOR, Role.VIEWER]
|
||||
|
||||
query = (
|
||||
self._db.query(Scan)
|
||||
.join(Sketch, Sketch.id == Scan.sketch_id)
|
||||
.join(
|
||||
InvestigationUserRole,
|
||||
InvestigationUserRole.investigation_id == Sketch.investigation_id,
|
||||
)
|
||||
)
|
||||
|
||||
query = query.filter(InvestigationUserRole.user_id == user_id)
|
||||
|
||||
conditions = [InvestigationUserRole.roles.any(role) for role in allowed_roles]
|
||||
query = query.filter(or_(*conditions))
|
||||
|
||||
return query.distinct().all()
|
||||
|
||||
def get_by_id(self, scan_id: UUID, user_id: UUID) -> Scan:
|
||||
"""
|
||||
Get a scan by ID with permission check.
|
||||
|
||||
Args:
|
||||
scan_id: The scan ID
|
||||
user_id: The user's ID
|
||||
|
||||
Returns:
|
||||
The scan
|
||||
|
||||
Raises:
|
||||
NotFoundError: If scan not found
|
||||
PermissionDeniedError: If user doesn't have permission
|
||||
"""
|
||||
scan = self._db.query(Scan).filter(Scan.id == scan_id).first()
|
||||
if not scan:
|
||||
raise NotFoundError("Scan not found")
|
||||
|
||||
# Check investigation permission via sketch
|
||||
sketch = self._db.query(Sketch).filter(Sketch.id == scan.sketch_id).first()
|
||||
if sketch:
|
||||
self._check_permission(user_id, sketch.investigation_id, ["read"])
|
||||
|
||||
return scan
|
||||
|
||||
def delete(self, scan_id: UUID, user_id: UUID) -> None:
|
||||
"""
|
||||
Delete a scan.
|
||||
|
||||
Args:
|
||||
scan_id: The scan ID
|
||||
user_id: The user's ID
|
||||
|
||||
Raises:
|
||||
NotFoundError: If scan not found
|
||||
PermissionDeniedError: If user doesn't have permission
|
||||
"""
|
||||
scan = self._db.query(Scan).filter(Scan.id == scan_id).first()
|
||||
if not scan:
|
||||
raise NotFoundError("Scan not found")
|
||||
|
||||
# Check investigation permission via sketch
|
||||
sketch = self._db.query(Sketch).filter(Sketch.id == scan.sketch_id).first()
|
||||
if sketch:
|
||||
self._check_permission(user_id, sketch.investigation_id, ["delete"])
|
||||
|
||||
self._delete(scan)
|
||||
self._commit()
|
||||
|
||||
|
||||
def create_scan_service(db: Session) -> ScanService:
|
||||
"""
|
||||
Factory function to create a ScanService instance.
|
||||
|
||||
Args:
|
||||
db: SQLAlchemy database session
|
||||
|
||||
Returns:
|
||||
Configured ScanService instance
|
||||
"""
|
||||
return ScanService(db=db)
|
||||
399
flowsint-core/src/flowsint_core/core/services/sketch_service.py
Normal file
399
flowsint-core/src/flowsint_core/core/services/sketch_service.py
Normal file
@@ -0,0 +1,399 @@
|
||||
"""
|
||||
Sketch service for managing sketches and graph operations.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ..models import Sketch
|
||||
from ..graph import create_graph_service, GraphNode
|
||||
from ..graph.types import GraphData
|
||||
from .base import BaseService
|
||||
from .exceptions import NotFoundError, PermissionDeniedError, ValidationError, DatabaseError
|
||||
|
||||
|
||||
class SketchService(BaseService):
|
||||
"""
|
||||
Service for sketch CRUD operations and graph interactions.
|
||||
"""
|
||||
|
||||
def _get_sketch_with_permission(
|
||||
self, sketch_id: UUID, user_id: UUID, actions: List[str]
|
||||
) -> Sketch:
|
||||
"""Get sketch and verify user has permission."""
|
||||
sketch = self._db.query(Sketch).filter(Sketch.id == sketch_id).first()
|
||||
if not sketch:
|
||||
raise NotFoundError("Sketch not found")
|
||||
self._check_permission(user_id, sketch.investigation_id, actions)
|
||||
return sketch
|
||||
|
||||
def list_sketches(self, user_id: UUID) -> List[Sketch]:
|
||||
"""Get all sketches owned by a user."""
|
||||
return self._db.query(Sketch).filter(Sketch.owner_id == user_id).all()
|
||||
|
||||
def get_by_id(self, sketch_id: UUID, user_id: UUID) -> Sketch:
|
||||
"""Get a sketch by ID with permission check."""
|
||||
return self._get_sketch_with_permission(sketch_id, user_id, ["read"])
|
||||
|
||||
def create(
|
||||
self,
|
||||
title: str,
|
||||
description: Optional[str],
|
||||
investigation_id: UUID,
|
||||
owner_id: UUID,
|
||||
) -> Sketch:
|
||||
"""
|
||||
Create a new sketch.
|
||||
|
||||
Args:
|
||||
title: Sketch title
|
||||
description: Sketch description
|
||||
investigation_id: Parent investigation ID
|
||||
owner_id: Owner user ID
|
||||
|
||||
Returns:
|
||||
The created sketch
|
||||
|
||||
Raises:
|
||||
ValidationError: If investigation_id is missing
|
||||
PermissionDeniedError: If user can't create in this investigation
|
||||
"""
|
||||
if not investigation_id:
|
||||
raise ValidationError("Investigation not found")
|
||||
|
||||
self._check_permission(owner_id, investigation_id, ["create"])
|
||||
|
||||
sketch = Sketch(
|
||||
title=title,
|
||||
description=description,
|
||||
investigation_id=investigation_id,
|
||||
owner_id=owner_id,
|
||||
)
|
||||
self._add(sketch)
|
||||
self._commit()
|
||||
self._refresh(sketch)
|
||||
return sketch
|
||||
|
||||
def update(
|
||||
self, sketch_id: UUID, user_id: UUID, updates: Dict[str, Any]
|
||||
) -> Sketch:
|
||||
"""Update a sketch with permission check."""
|
||||
sketch = self._get_sketch_with_permission(sketch_id, user_id, ["update"])
|
||||
|
||||
for key, value in updates.items():
|
||||
if hasattr(sketch, key):
|
||||
setattr(sketch, key, value)
|
||||
|
||||
self._commit()
|
||||
self._refresh(sketch)
|
||||
return sketch
|
||||
|
||||
def delete(self, sketch_id: UUID, user_id: UUID) -> None:
|
||||
"""Delete a sketch and its graph data."""
|
||||
sketch = self._get_sketch_with_permission(sketch_id, user_id, ["delete"])
|
||||
|
||||
# Delete all nodes and relationships in Neo4j first
|
||||
try:
|
||||
graph_service = create_graph_service(
|
||||
sketch_id=str(sketch_id), enable_batching=False
|
||||
)
|
||||
graph_service.delete_all_sketch_nodes()
|
||||
except Exception as e:
|
||||
print(f"Neo4j cleanup error: {e}")
|
||||
raise DatabaseError("Failed to clean up graph data")
|
||||
|
||||
# Then delete the sketch from PostgreSQL
|
||||
self._delete(sketch)
|
||||
self._commit()
|
||||
|
||||
# --- Graph operations ---
|
||||
|
||||
def get_graph(
|
||||
self, sketch_id: UUID, user_id: UUID, format: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get the graph data for a sketch.
|
||||
|
||||
Args:
|
||||
sketch_id: Sketch ID
|
||||
user_id: User ID for permission check
|
||||
format: Optional format ("inline" for inline relationships)
|
||||
|
||||
Returns:
|
||||
Graph data as dict with "nds" and "rls" keys, or inline format
|
||||
"""
|
||||
sketch = self._get_sketch_with_permission(sketch_id, user_id, ["read"])
|
||||
|
||||
graph_service = create_graph_service(
|
||||
sketch_id=str(sketch_id), enable_batching=False
|
||||
)
|
||||
graph_data = graph_service.get_sketch_graph()
|
||||
|
||||
if format == "inline":
|
||||
from flowsint_core.utils import get_inline_relationships
|
||||
return get_inline_relationships(graph_data.nodes, graph_data.edges)
|
||||
|
||||
graph = graph_data.model_dump(mode="json", serialize_as_any=True)
|
||||
return {"nds": graph["nodes"], "rls": graph["edges"]}
|
||||
|
||||
def add_node(
|
||||
self, sketch_id: UUID, user_id: UUID, node: GraphNode
|
||||
) -> Dict[str, Any]:
|
||||
"""Add a node to the sketch graph."""
|
||||
self._get_sketch_with_permission(sketch_id, user_id, ["update"])
|
||||
|
||||
try:
|
||||
graph_service = create_graph_service(
|
||||
sketch_id=str(sketch_id), enable_batching=False
|
||||
)
|
||||
node_id = graph_service.create_node(node)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise DatabaseError(f"Database error: {str(e)}")
|
||||
|
||||
if not node_id:
|
||||
raise ValidationError("Node creation failed")
|
||||
|
||||
node.id = node_id
|
||||
return {"status": "node added", "node": node}
|
||||
|
||||
def add_relationship(
|
||||
self,
|
||||
sketch_id: UUID,
|
||||
user_id: UUID,
|
||||
source: str,
|
||||
target: str,
|
||||
label: str = "RELATED_TO",
|
||||
) -> Dict[str, Any]:
|
||||
"""Add a relationship between nodes."""
|
||||
self._get_sketch_with_permission(sketch_id, user_id, ["update"])
|
||||
|
||||
try:
|
||||
graph_service = create_graph_service(
|
||||
sketch_id=str(sketch_id), enable_batching=False
|
||||
)
|
||||
result = graph_service.create_relationship_by_element_id(
|
||||
from_element_id=source,
|
||||
to_element_id=target,
|
||||
rel_label=label,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Edge creation error: {e}")
|
||||
raise DatabaseError("Failed to create edge")
|
||||
|
||||
if not result:
|
||||
raise ValidationError("Edge creation failed")
|
||||
|
||||
return {"status": "edge added", "edge": result}
|
||||
|
||||
def update_node(
|
||||
self, sketch_id: UUID, user_id: UUID, node_id: str, updates: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""Update a node's properties."""
|
||||
self._get_sketch_with_permission(sketch_id, user_id, ["update"])
|
||||
|
||||
try:
|
||||
graph_service = create_graph_service(
|
||||
sketch_id=str(sketch_id), enable_batching=False
|
||||
)
|
||||
updated_element_id = graph_service.update_node(
|
||||
element_id=node_id, updates=updates
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Node update error: {e}")
|
||||
raise DatabaseError("Failed to update node")
|
||||
|
||||
if not updated_element_id:
|
||||
raise NotFoundError("Node not found or not accessible")
|
||||
|
||||
return {"status": "node updated", "node": {"id": updated_element_id}}
|
||||
|
||||
def update_node_positions(
|
||||
self, sketch_id: UUID, user_id: UUID, positions: List[Dict[str, Any]]
|
||||
) -> Dict[str, Any]:
|
||||
"""Update positions for multiple nodes."""
|
||||
self._get_sketch_with_permission(sketch_id, user_id, ["update"])
|
||||
|
||||
if not positions:
|
||||
return {"status": "no positions to update", "count": 0}
|
||||
|
||||
try:
|
||||
graph_service = create_graph_service(
|
||||
sketch_id=str(sketch_id), enable_batching=False
|
||||
)
|
||||
updated_count = graph_service.update_nodes_positions(positions=positions)
|
||||
except Exception as e:
|
||||
print(f"Position update error: {e}")
|
||||
raise DatabaseError("Failed to update node positions")
|
||||
|
||||
return {"status": "positions updated", "count": updated_count}
|
||||
|
||||
def delete_nodes(
|
||||
self, sketch_id: UUID, user_id: UUID, node_ids: List[str]
|
||||
) -> Dict[str, Any]:
|
||||
"""Delete nodes from the graph."""
|
||||
self._get_sketch_with_permission(sketch_id, user_id, ["update"])
|
||||
|
||||
try:
|
||||
graph_service = create_graph_service(
|
||||
sketch_id=str(sketch_id), enable_batching=False
|
||||
)
|
||||
deleted_count = graph_service.delete_nodes(node_ids)
|
||||
except Exception as e:
|
||||
print(f"Node deletion error: {e}")
|
||||
raise DatabaseError("Failed to delete nodes")
|
||||
|
||||
return {"status": "nodes deleted", "count": deleted_count}
|
||||
|
||||
def delete_relationships(
|
||||
self, sketch_id: UUID, user_id: UUID, relationship_ids: List[str]
|
||||
) -> Dict[str, Any]:
|
||||
"""Delete relationships from the graph."""
|
||||
self._get_sketch_with_permission(sketch_id, user_id, ["update"])
|
||||
|
||||
try:
|
||||
graph_service = create_graph_service(
|
||||
sketch_id=str(sketch_id), enable_batching=False
|
||||
)
|
||||
deleted_count = graph_service.delete_relationships(relationship_ids)
|
||||
except Exception as e:
|
||||
print(f"Relationship deletion error: {e}")
|
||||
raise DatabaseError("Failed to delete relationships")
|
||||
|
||||
return {"status": "relationships deleted", "count": deleted_count}
|
||||
|
||||
def update_relationship(
|
||||
self,
|
||||
sketch_id: UUID,
|
||||
user_id: UUID,
|
||||
relationship_id: str,
|
||||
data: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
"""Update a relationship's properties."""
|
||||
self._get_sketch_with_permission(sketch_id, user_id, ["update"])
|
||||
|
||||
try:
|
||||
graph_service = create_graph_service(
|
||||
sketch_id=str(sketch_id), enable_batching=False
|
||||
)
|
||||
result = graph_service.update_relationship(
|
||||
element_id=relationship_id, properties=data
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Relationship update error: {e}")
|
||||
raise DatabaseError("Failed to update relationship")
|
||||
|
||||
if not result:
|
||||
raise NotFoundError("Relationship not found or not accessible")
|
||||
|
||||
return {
|
||||
"status": "relationship updated",
|
||||
"relationship": {
|
||||
"id": result["id"],
|
||||
"label": result["type"],
|
||||
"data": result["data"],
|
||||
},
|
||||
}
|
||||
|
||||
def merge_nodes(
|
||||
self,
|
||||
sketch_id: UUID,
|
||||
user_id: UUID,
|
||||
old_node_ids: List[str],
|
||||
new_node_id: str,
|
||||
node_data: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
"""Merge multiple nodes into one."""
|
||||
from flowsint_core.utils import flatten
|
||||
|
||||
self._get_sketch_with_permission(sketch_id, user_id, ["update"])
|
||||
|
||||
if not old_node_ids:
|
||||
raise ValidationError("oldNodes cannot be empty")
|
||||
|
||||
node_type = node_data.get("type", "Node")
|
||||
properties = {
|
||||
"type": node_type.lower(),
|
||||
"label": node_data.get("label", "Merged Node"),
|
||||
}
|
||||
flattened_data = flatten(node_data)
|
||||
properties.update(flattened_data)
|
||||
|
||||
try:
|
||||
graph_service = create_graph_service(
|
||||
sketch_id=str(sketch_id), enable_batching=False
|
||||
)
|
||||
new_node_element_id = graph_service.merge_nodes(
|
||||
old_node_ids=old_node_ids,
|
||||
new_node_data=properties,
|
||||
new_node_id=new_node_id,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Node merge error: {e}")
|
||||
raise DatabaseError(f"Failed to merge nodes: {str(e)}")
|
||||
|
||||
if not new_node_element_id:
|
||||
raise DatabaseError("Failed to merge nodes")
|
||||
|
||||
return {
|
||||
"status": "nodes merged",
|
||||
"count": len(old_node_ids),
|
||||
"new_node_id": new_node_element_id,
|
||||
}
|
||||
|
||||
def get_neighbors(
|
||||
self, sketch_id: UUID, user_id: UUID, node_id: str
|
||||
) -> Dict[str, Any]:
|
||||
"""Get neighboring nodes and edges for a node."""
|
||||
self._get_sketch_with_permission(sketch_id, user_id, ["read"])
|
||||
|
||||
try:
|
||||
graph_service = create_graph_service(sketch_id=str(sketch_id))
|
||||
result = graph_service.get_neighbors(node_id)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise DatabaseError("Failed to retrieve related nodes")
|
||||
|
||||
if not result.nodes:
|
||||
raise NotFoundError("Node not found")
|
||||
|
||||
return {"nds": result.nodes, "rls": result.edges}
|
||||
|
||||
def export_sketch(
|
||||
self, sketch_id: UUID, user_id: UUID, format: str = "json"
|
||||
) -> Dict[str, Any]:
|
||||
"""Export sketch data."""
|
||||
sketch = self._get_sketch_with_permission(sketch_id, user_id, ["read"])
|
||||
|
||||
graph_service = create_graph_service(
|
||||
sketch_id=str(sketch_id), enable_batching=False
|
||||
)
|
||||
graph_data = graph_service.get_sketch_graph()
|
||||
|
||||
if format == "json":
|
||||
return {
|
||||
"sketch": {
|
||||
"id": str(sketch.id),
|
||||
"title": sketch.title,
|
||||
"description": sketch.description,
|
||||
},
|
||||
"nodes": [node.model_dump(mode="json") for node in graph_data.nodes],
|
||||
"edges": [edge.model_dump(mode="json") for edge in graph_data.edges],
|
||||
}
|
||||
else:
|
||||
raise ValidationError(f"Unsupported format: {format}")
|
||||
|
||||
|
||||
def create_sketch_service(db: Session) -> SketchService:
|
||||
"""
|
||||
Factory function to create a SketchService instance.
|
||||
|
||||
Args:
|
||||
db: SQLAlchemy database session
|
||||
|
||||
Returns:
|
||||
Configured SketchService instance
|
||||
"""
|
||||
return SketchService(db=db)
|
||||
@@ -0,0 +1,328 @@
|
||||
"""
|
||||
Type registry service for managing flowsint types.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
|
||||
from ..models import CustomType
|
||||
from .base import BaseService
|
||||
|
||||
|
||||
class TypeRegistryService(BaseService):
|
||||
"""
|
||||
Service for type registry operations and schema extraction.
|
||||
"""
|
||||
|
||||
def get_types_list(self, user_id: UUID) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get the complete types list for sketches.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID
|
||||
|
||||
Returns:
|
||||
List of type categories with their children
|
||||
"""
|
||||
from flowsint_types.registry import get_type
|
||||
|
||||
# Define categories with type names
|
||||
category_definitions = self._get_category_definitions()
|
||||
|
||||
types = []
|
||||
for category in category_definitions:
|
||||
category_copy = category.copy()
|
||||
children_schemas = []
|
||||
|
||||
for child_def in category["children"]:
|
||||
type_name, label_key, icon = child_def
|
||||
model = get_type(type_name, case_sensitive=True)
|
||||
|
||||
if model:
|
||||
children_schemas.append(
|
||||
self._extract_input_schema(model, label_key=label_key, icon=icon)
|
||||
)
|
||||
else:
|
||||
print(f"Warning: Type {type_name} not found in TYPE_REGISTRY")
|
||||
|
||||
category_copy["children"] = children_schemas
|
||||
types.append(category_copy)
|
||||
|
||||
# Add custom types
|
||||
custom_types = (
|
||||
self._db.query(CustomType)
|
||||
.filter(
|
||||
CustomType.owner_id == user_id,
|
||||
CustomType.status == "published",
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
if custom_types:
|
||||
custom_types_children = []
|
||||
for custom_type in custom_types:
|
||||
schema = custom_type.schema
|
||||
properties = schema.get("properties", {})
|
||||
required = schema.get("required", [])
|
||||
|
||||
label_key = (
|
||||
required[0]
|
||||
if required
|
||||
else list(properties.keys())[0]
|
||||
if properties
|
||||
else "value"
|
||||
)
|
||||
|
||||
custom_types_children.append(
|
||||
{
|
||||
"id": custom_type.id,
|
||||
"type": custom_type.name,
|
||||
"key": custom_type.name.lower(),
|
||||
"label_key": label_key,
|
||||
"icon": "custom",
|
||||
"label": custom_type.name,
|
||||
"description": custom_type.description or "",
|
||||
"fields": [
|
||||
{
|
||||
"name": prop,
|
||||
"label": info.get("title", prop),
|
||||
"description": info.get("description", ""),
|
||||
"type": "text",
|
||||
"required": prop in required,
|
||||
}
|
||||
for prop, info in properties.items()
|
||||
],
|
||||
"custom": True,
|
||||
}
|
||||
)
|
||||
|
||||
types.append(
|
||||
{
|
||||
"id": uuid4(),
|
||||
"type": "custom_types_category",
|
||||
"key": "custom_types",
|
||||
"icon": "custom",
|
||||
"label": "Custom types",
|
||||
"fields": [],
|
||||
"children": custom_types_children,
|
||||
}
|
||||
)
|
||||
|
||||
return types
|
||||
|
||||
def _get_category_definitions(self) -> List[Dict[str, Any]]:
|
||||
"""Get the category definitions for types."""
|
||||
return [
|
||||
{
|
||||
"id": uuid4(),
|
||||
"type": "global",
|
||||
"key": "global_category",
|
||||
"icon": "phrase",
|
||||
"label": "Global",
|
||||
"fields": [],
|
||||
"children": [
|
||||
("Phrase", "text", None),
|
||||
("Location", "address", None),
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": uuid4(),
|
||||
"type": "person",
|
||||
"key": "person_category",
|
||||
"icon": "individual",
|
||||
"label": "Identities & Entities",
|
||||
"fields": [],
|
||||
"children": [
|
||||
("Individual", "full_name", None),
|
||||
("Username", "value", "username"),
|
||||
("Organization", "name", None),
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": uuid4(),
|
||||
"type": "organization",
|
||||
"key": "organization_category",
|
||||
"icon": "organization",
|
||||
"label": "Organization",
|
||||
"fields": [],
|
||||
"children": [
|
||||
("Organization", "name", None),
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": uuid4(),
|
||||
"type": "contact_category",
|
||||
"key": "contact",
|
||||
"icon": "phone",
|
||||
"label": "Communication & Contact",
|
||||
"fields": [],
|
||||
"children": [
|
||||
("Phone", "number", None),
|
||||
("Email", "email", None),
|
||||
("Username", "value", None),
|
||||
("SocialAccount", "username", "socialaccount"),
|
||||
("Message", "content", "message"),
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": uuid4(),
|
||||
"type": "network_category",
|
||||
"key": "network",
|
||||
"icon": "domain",
|
||||
"label": "Network",
|
||||
"fields": [],
|
||||
"children": [
|
||||
("ASN", "number", None),
|
||||
("CIDR", "network", None),
|
||||
("Domain", "domain", None),
|
||||
("Website", "url", None),
|
||||
("Ip", "address", None),
|
||||
("Port", "number", None),
|
||||
("DNSRecord", "name", "dns"),
|
||||
("SSLCertificate", "subject", "ssl"),
|
||||
("WebTracker", "name", "webtracker"),
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": uuid4(),
|
||||
"type": "security_category",
|
||||
"key": "security",
|
||||
"icon": "credential",
|
||||
"label": "Security & Access",
|
||||
"fields": [],
|
||||
"children": [
|
||||
("Credential", "username", "credential"),
|
||||
("Session", "session_id", "session"),
|
||||
("Device", "device_id", "device"),
|
||||
("Malware", "name", "malware"),
|
||||
("Weapon", "name", "weapon"),
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": uuid4(),
|
||||
"type": "files_category",
|
||||
"key": "files",
|
||||
"icon": "file",
|
||||
"label": "Files & Documents",
|
||||
"fields": [],
|
||||
"children": [
|
||||
("Document", "title", "document"),
|
||||
("File", "filename", "file"),
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": uuid4(),
|
||||
"type": "financial_category",
|
||||
"key": "financial",
|
||||
"icon": "creditcard",
|
||||
"label": "Financial Data",
|
||||
"fields": [],
|
||||
"children": [
|
||||
("BankAccount", "account_number", "creditcard"),
|
||||
("CreditCard", "card_number", "creditcard"),
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": uuid4(),
|
||||
"type": "leak_category",
|
||||
"key": "leaks",
|
||||
"icon": "breach",
|
||||
"label": "Leaks",
|
||||
"fields": [],
|
||||
"children": [
|
||||
("Leak", "name", "breach"),
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": uuid4(),
|
||||
"type": "crypto_category",
|
||||
"key": "crypto",
|
||||
"icon": "cryptowallet",
|
||||
"label": "Crypto",
|
||||
"fields": [],
|
||||
"children": [
|
||||
("CryptoWallet", "address", "cryptowallet"),
|
||||
("CryptoWalletTransaction", "hash", "cryptowallet"),
|
||||
("CryptoNFT", "name", "cryptowallet"),
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
def _extract_input_schema(
|
||||
self, model: Type[BaseModel], label_key: str, icon: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Extract input schema from a Pydantic model."""
|
||||
adapter = TypeAdapter(model)
|
||||
schema = adapter.json_schema()
|
||||
type_name = model.__name__
|
||||
details = schema
|
||||
|
||||
return {
|
||||
"id": uuid4(),
|
||||
"type": type_name,
|
||||
"key": type_name.lower(),
|
||||
"label_key": label_key,
|
||||
"icon": icon or type_name.lower(),
|
||||
"label": type_name,
|
||||
"description": details.get("description", ""),
|
||||
"fields": [
|
||||
self._resolve_field(prop, details=info, schema=schema)
|
||||
for prop, info in details.get("properties", {}).items()
|
||||
if prop != "nodeLabel"
|
||||
],
|
||||
}
|
||||
|
||||
def _resolve_field(
|
||||
self, prop: str, details: dict, schema: dict = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Resolve a field definition from schema."""
|
||||
field = {
|
||||
"name": prop,
|
||||
"label": details.get("title", prop),
|
||||
"description": details.get("description", ""),
|
||||
"type": "text",
|
||||
}
|
||||
|
||||
if self._has_enum(details):
|
||||
field["type"] = "select"
|
||||
field["options"] = [
|
||||
{"label": label, "value": label}
|
||||
for label in self._get_enum_values(details)
|
||||
]
|
||||
|
||||
field["required"] = self._is_required(details)
|
||||
return field
|
||||
|
||||
def _has_enum(self, schema: dict) -> bool:
|
||||
"""Check if schema has enum values."""
|
||||
any_of = schema.get("anyOf", [])
|
||||
return any(isinstance(entry, dict) and "enum" in entry for entry in any_of)
|
||||
|
||||
def _is_required(self, schema: dict) -> bool:
|
||||
"""Check if field is required."""
|
||||
any_of = schema.get("anyOf", [])
|
||||
return not any(entry == {"type": "null"} for entry in any_of)
|
||||
|
||||
def _get_enum_values(self, schema: dict) -> list:
|
||||
"""Get enum values from schema."""
|
||||
enum_values = []
|
||||
for entry in schema.get("anyOf", []):
|
||||
if isinstance(entry, dict) and "enum" in entry:
|
||||
enum_values.extend(entry["enum"])
|
||||
return enum_values
|
||||
|
||||
|
||||
def create_type_registry_service(db: Session) -> TypeRegistryService:
|
||||
"""
|
||||
Factory function to create a TypeRegistryService instance.
|
||||
|
||||
Args:
|
||||
db: SQLAlchemy database session
|
||||
|
||||
Returns:
|
||||
Configured TypeRegistryService instance
|
||||
"""
|
||||
return TypeRegistryService(db=db)
|
||||
@@ -1,22 +1,25 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
from typing import Any, List, Union, Dict, Set, Optional
|
||||
from typing import Any, Dict, List, Optional, Set, Union
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from flowsint_core.core.enricher_base import Enricher
|
||||
from flowsint_enrichers.registry import flowsint_enricher
|
||||
from flowsint_core.core.logger import Logger
|
||||
from flowsint_core.utils import is_root_domain, is_valid_domain
|
||||
from flowsint_types.address import Location
|
||||
from flowsint_types.domain import Domain
|
||||
from flowsint_types.email import Email
|
||||
from flowsint_types.individual import Individual
|
||||
from flowsint_types.organization import Organization
|
||||
from flowsint_types.email import Email
|
||||
from flowsint_types.phone import Phone
|
||||
from flowsint_core.utils import is_valid_domain, is_root_domain
|
||||
from flowsint_types.address import Location
|
||||
from flowsint_core.core.logger import Logger
|
||||
|
||||
from flowsint_enrichers.registry import flowsint_enricher
|
||||
from tools.network.whoxy import WhoxyTool
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
@flowsint_enricher
|
||||
class DomainToHistoryEnricher(Enricher):
|
||||
"""[WHOXY] Takes a domain and returns history infos about it (history, organization, owners, emails, etc.)."""
|
||||
@@ -185,7 +188,12 @@ class DomainToHistoryEnricher(Enricher):
|
||||
)
|
||||
|
||||
# Extract other non-redacted information (country, email, etc.)
|
||||
self.__extract_additional_info_from_contact(contact, contact_type, domain_name, extracted_info["original_domain"].domain)
|
||||
self.__extract_additional_info_from_contact(
|
||||
contact,
|
||||
contact_type,
|
||||
domain_name,
|
||||
extracted_info["original_domain"].domain,
|
||||
)
|
||||
else:
|
||||
Logger.info(
|
||||
self.sketch_id,
|
||||
@@ -220,8 +228,6 @@ class DomainToHistoryEnricher(Enricher):
|
||||
# A record is valid if it has a domain name - we'll filter contacts individually later
|
||||
return True
|
||||
|
||||
|
||||
|
||||
def __is_redacted(self, value: str) -> bool:
|
||||
"""Check if a value is redacted."""
|
||||
if not value:
|
||||
@@ -238,14 +244,16 @@ class DomainToHistoryEnricher(Enricher):
|
||||
if self.__is_redacted(full_name) or not full_name:
|
||||
Logger.info(
|
||||
self.sketch_id,
|
||||
{"message": f"[WHOXY] Skipping contact with redacted/empty name: {full_name}"},
|
||||
{
|
||||
"message": f"[WHOXY] Skipping contact with redacted/empty name: {full_name}"
|
||||
},
|
||||
)
|
||||
return None
|
||||
|
||||
# Parse full name into first and last name
|
||||
name_parts = full_name.strip().split()
|
||||
first_name = name_parts[0] if name_parts else ""
|
||||
last_name = " ".join(name_parts[1:]) if len(name_parts) > 1 else ""
|
||||
first_name = name_parts[0] if name_parts else "N/A"
|
||||
last_name = " ".join(name_parts[1:]) if len(name_parts) > 1 else "N/A"
|
||||
|
||||
# Extract email and phone
|
||||
email_raw = contact.get("email_address", "")
|
||||
@@ -322,9 +330,7 @@ class DomainToHistoryEnricher(Enricher):
|
||||
if not all([address, city, zip_code, country]):
|
||||
return None
|
||||
|
||||
return Location(
|
||||
address=address, city=city, zip=zip_code, country=country
|
||||
)
|
||||
return Location(address=address, city=city, zip=zip_code, country=country)
|
||||
|
||||
def __extract_organization_from_contact(
|
||||
self, contact: Dict[str, Any], contact_type: str
|
||||
@@ -349,13 +355,17 @@ class DomainToHistoryEnricher(Enricher):
|
||||
return organization
|
||||
|
||||
def __extract_additional_info_from_contact(
|
||||
self, contact: Dict[str, Any], contact_type: str, domain_name: str, original_domain: str
|
||||
self,
|
||||
contact: Dict[str, Any],
|
||||
contact_type: str,
|
||||
domain_name: str,
|
||||
original_domain: str,
|
||||
):
|
||||
"""Extract additional non-redacted information from contact data."""
|
||||
# Extract country information
|
||||
country_name = contact.get("country_name", "")
|
||||
country_code = contact.get("country_code", "")
|
||||
|
||||
|
||||
if country_name and not self.__is_redacted(country_name):
|
||||
Logger.info(
|
||||
self.sketch_id,
|
||||
@@ -363,7 +373,7 @@ class DomainToHistoryEnricher(Enricher):
|
||||
"message": f"[WHOXY] Found country: {country_name} ({contact_type}) for {domain_name}"
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
if country_code and not self.__is_redacted(country_code):
|
||||
Logger.info(
|
||||
self.sketch_id,
|
||||
@@ -380,7 +390,7 @@ class DomainToHistoryEnricher(Enricher):
|
||||
for email in email_list:
|
||||
if email and self.__is_valid_email(email):
|
||||
emails.append(email)
|
||||
|
||||
|
||||
if emails:
|
||||
Logger.info(
|
||||
self.sketch_id,
|
||||
@@ -389,7 +399,9 @@ class DomainToHistoryEnricher(Enricher):
|
||||
},
|
||||
)
|
||||
|
||||
def postprocess(self, results: List[OutputType], original_input: List[InputType]) -> List[OutputType]:
|
||||
def postprocess(
|
||||
self, results: List[OutputType], original_input: List[InputType]
|
||||
) -> List[OutputType]:
|
||||
"""Create Neo4j nodes and relationships from extracted data."""
|
||||
if not self._graph_service:
|
||||
Logger.info(
|
||||
@@ -440,7 +452,9 @@ class DomainToHistoryEnricher(Enricher):
|
||||
# Create relationship between original domain and found domain
|
||||
original_domain_obj = Domain(domain=original_domain_name)
|
||||
domain_obj_rel = Domain(domain=domain_name)
|
||||
self.create_relationship(original_domain_obj, domain_obj_rel, "HAS_RELATED_DOMAIN")
|
||||
self.create_relationship(
|
||||
original_domain_obj, domain_obj_rel, "HAS_RELATED_DOMAIN"
|
||||
)
|
||||
|
||||
# Create individual node if not already processed
|
||||
individual_id = (
|
||||
@@ -458,7 +472,9 @@ class DomainToHistoryEnricher(Enricher):
|
||||
|
||||
# Create relationship between individual and domain
|
||||
domain_obj_contact = Domain(domain=domain_name)
|
||||
self.create_relationship(individual, domain_obj_contact, f"IS_{contact_type.upper()}_CONTACT")
|
||||
self.create_relationship(
|
||||
individual, domain_obj_contact, f"IS_{contact_type.upper()}_CONTACT"
|
||||
)
|
||||
|
||||
# Process email addresses
|
||||
if individual.email_addresses:
|
||||
@@ -537,7 +553,9 @@ class DomainToHistoryEnricher(Enricher):
|
||||
# Create relationship between original domain and found domain
|
||||
original_domain_obj3 = Domain(domain=original_domain_name)
|
||||
domain_obj_rel3 = Domain(domain=domain_name)
|
||||
self.create_relationship(original_domain_obj3, domain_obj_rel3, "HAS_RELATED_DOMAIN")
|
||||
self.create_relationship(
|
||||
original_domain_obj3, domain_obj_rel3, "HAS_RELATED_DOMAIN"
|
||||
)
|
||||
|
||||
# Create organization node if not already processed
|
||||
if organization.name not in processed_organizations:
|
||||
@@ -552,7 +570,9 @@ class DomainToHistoryEnricher(Enricher):
|
||||
|
||||
# Create relationship between organization and domain
|
||||
domain_obj_org = Domain(domain=domain_name)
|
||||
self.create_relationship(organization, domain_obj_org, f"IS_{contact_type.upper()}_CONTACT")
|
||||
self.create_relationship(
|
||||
organization, domain_obj_org, f"IS_{contact_type.upper()}_CONTACT"
|
||||
)
|
||||
|
||||
self.log_graph_message(
|
||||
f"Processed organization {organization.name} ({contact_type}) for domain {domain_name}"
|
||||
|
||||
60
flowsint-enrichers/src/flowsint_enrichers/email/to_domain.py
Normal file
60
flowsint-enrichers/src/flowsint_enrichers/email/to_domain.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from typing import List
|
||||
|
||||
from flowsint_core.core.enricher_base import Enricher
|
||||
from flowsint_types.domain import Domain
|
||||
from flowsint_types.email import Email
|
||||
|
||||
from flowsint_enrichers.registry import flowsint_enricher
|
||||
|
||||
|
||||
@flowsint_enricher
|
||||
class EmailToDomainEnricher(Enricher):
|
||||
"""From email to domain."""
|
||||
|
||||
InputType = Email
|
||||
OutputType = Domain
|
||||
|
||||
@classmethod
|
||||
def name(cls) -> str:
|
||||
return "email_to_domain"
|
||||
|
||||
@classmethod
|
||||
def category(cls) -> str:
|
||||
return "Email"
|
||||
|
||||
@classmethod
|
||||
def key(cls) -> str:
|
||||
return "email"
|
||||
|
||||
async def scan(self, data: List[InputType]) -> List[OutputType]:
|
||||
results: List[OutputType] = []
|
||||
|
||||
for email in data:
|
||||
splitted = email.email.split("@")
|
||||
domain = splitted[1]
|
||||
results.append(Domain(domain=domain))
|
||||
|
||||
return results
|
||||
|
||||
def postprocess(
|
||||
self, results: List[OutputType], original_input: List[InputType]
|
||||
) -> List[OutputType]:
|
||||
for email_obj, domain_obj in zip(original_input, results):
|
||||
if not self._graph_service:
|
||||
continue
|
||||
# Create email node
|
||||
self.create_node(email_obj)
|
||||
self.create_node(domain_obj)
|
||||
# Create relationship between email and gravatar
|
||||
self.create_relationship(email_obj, domain_obj, "HAS_DOMAIN")
|
||||
|
||||
self.log_graph_message(
|
||||
f"Exctracted domain for {email_obj.email} -> domain: {domain_obj.domain}"
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# Make types available at module level for easy access
|
||||
InputType = EmailToDomainEnricher.InputType
|
||||
OutputType = EmailToDomainEnricher.OutputType
|
||||
@@ -1,16 +1,18 @@
|
||||
import os
|
||||
import re
|
||||
from typing import Any, List, Dict, Set, Optional
|
||||
from flowsint_core.core.enricher_base import Enricher
|
||||
from flowsint_enrichers.registry import flowsint_enricher
|
||||
from flowsint_types.domain import Domain
|
||||
from flowsint_types.individual import Individual
|
||||
from flowsint_types.email import Email
|
||||
from flowsint_types.phone import Phone
|
||||
from flowsint_types.address import Location
|
||||
from flowsint_core.core.logger import Logger
|
||||
from tools.network.whoxy import WhoxyTool
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from flowsint_core.core.enricher_base import Enricher
|
||||
from flowsint_core.core.logger import Logger
|
||||
from flowsint_types.address import Location
|
||||
from flowsint_types.domain import Domain
|
||||
from flowsint_types.email import Email
|
||||
from flowsint_types.individual import Individual
|
||||
from flowsint_types.phone import Phone
|
||||
|
||||
from flowsint_enrichers.registry import flowsint_enricher
|
||||
from tools.network.whoxy import WhoxyTool
|
||||
|
||||
load_dotenv()
|
||||
|
||||
@@ -142,8 +144,8 @@ class EmailToDomainsEnricher(Enricher):
|
||||
|
||||
# Parse full name into first and last name
|
||||
name_parts = full_name.strip().split()
|
||||
first_name = name_parts[0] if name_parts else ""
|
||||
last_name = " ".join(name_parts[1:]) if len(name_parts) > 1 else ""
|
||||
first_name = name_parts[0] if name_parts else "N/A"
|
||||
last_name = " ".join(name_parts[1:]) if len(name_parts) > 1 else "N/A"
|
||||
|
||||
# Extract email and phone
|
||||
email = contact.get("email_address", "")
|
||||
@@ -210,11 +212,11 @@ class EmailToDomainsEnricher(Enricher):
|
||||
if not all([address, city, zip_code, country]):
|
||||
return None
|
||||
|
||||
return Location(
|
||||
address=address, city=city, zip=zip_code, country=country
|
||||
)
|
||||
return Location(address=address, city=city, zip=zip_code, country=country)
|
||||
|
||||
def postprocess(self, results: List[OutputType], original_input: List[InputType]) -> List[OutputType]:
|
||||
def postprocess(
|
||||
self, results: List[OutputType], original_input: List[InputType]
|
||||
) -> List[OutputType]:
|
||||
"""Create Neo4j nodes and relationships from extracted data."""
|
||||
if not self._graph_service:
|
||||
return results
|
||||
@@ -299,7 +301,9 @@ class EmailToDomainsEnricher(Enricher):
|
||||
|
||||
# Create relationship between individual and domain
|
||||
domain_obj_ind = Domain(domain=domain_name)
|
||||
self.create_relationship(individual, domain_obj_ind, f"IS_{contact_type}_CONTACT")
|
||||
self.create_relationship(
|
||||
individual, domain_obj_ind, f"IS_{contact_type}_CONTACT"
|
||||
)
|
||||
|
||||
# Create relationship between individual and email
|
||||
email_obj_ind = Email(email=email_address)
|
||||
|
||||
@@ -84,9 +84,21 @@ class WebsiteToCrawler(Enricher):
|
||||
{"message": f" Found on: {item.source_url}"},
|
||||
)
|
||||
if item.type == "email":
|
||||
website_result["emails"].append(Email(email=item.value))
|
||||
try:
|
||||
website_result["emails"].append(Email(email=item.value))
|
||||
except Exception as e:
|
||||
Logger.warn(
|
||||
self.sketch_id,
|
||||
{"message": f"Skipping invalid email '{item.value}': {e}"},
|
||||
)
|
||||
if item.type == "phone":
|
||||
website_result["phones"].append(Phone(number=item.value))
|
||||
try:
|
||||
website_result["phones"].append(Phone(number=item.value))
|
||||
except Exception as e:
|
||||
Logger.warn(
|
||||
self.sketch_id,
|
||||
{"message": f"Skipping invalid phone '{item.value}': {e}"},
|
||||
)
|
||||
|
||||
# Log results
|
||||
Logger.info(
|
||||
|
||||
Reference in New Issue
Block a user