feat(api): enrichers from templates

This commit is contained in:
dextmorgn
2026-02-02 11:41:44 +01:00
parent 5ff6ba2b1c
commit 53a03575cd
8 changed files with 622 additions and 8 deletions

View File

@@ -0,0 +1,48 @@
"""add enricher_templates table
Revision ID: a1b2c3d4e5f6
Revises: 8173aba964e7
Create Date: 2025-01-31
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = 'a1b2c3d4e5f6'
down_revision: Union[str, None] = '8173aba964e7'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
op.create_table('enricher_templates',
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('name', sa.Text(), nullable=False),
sa.Column('category', sa.Text(), nullable=False),
sa.Column('version', sa.Float(), nullable=False, server_default='1.0'),
sa.Column('content', postgresql.JSONB(astext_type=sa.Text()), nullable=False),
sa.Column('is_public', sa.Boolean(), nullable=False, server_default='false'),
sa.Column('owner_id', sa.UUID(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
sa.ForeignKeyConstraint(['owner_id'], ['profiles.id'], onupdate='CASCADE', ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_index('idx_enricher_templates_owner_id', 'enricher_templates', ['owner_id'], unique=False)
op.create_index('idx_enricher_templates_name', 'enricher_templates', ['name'], unique=False)
op.create_index('idx_enricher_templates_category', 'enricher_templates', ['category'], unique=False)
op.create_index('idx_enricher_templates_is_public', 'enricher_templates', ['is_public'], unique=False)
def downgrade() -> None:
"""Downgrade schema."""
op.drop_index('idx_enricher_templates_is_public', table_name='enricher_templates')
op.drop_index('idx_enricher_templates_category', table_name='enricher_templates')
op.drop_index('idx_enricher_templates_name', table_name='enricher_templates')
op.drop_index('idx_enricher_templates_owner_id', table_name='enricher_templates')
op.drop_table('enricher_templates')

View File

@@ -0,0 +1,28 @@
"""add description to enricher_templates
Revision ID: b2c3d4e5f6a7
Revises: a1b2c3d4e5f6
Create Date: 2025-01-31
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = 'b2c3d4e5f6a7'
down_revision: Union[str, None] = 'a1b2c3d4e5f6'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Add description column to enricher_templates table."""
op.add_column('enricher_templates', sa.Column('description', sa.Text(), nullable=True))
def downgrade() -> None:
"""Remove description column from enricher_templates table."""
op.drop_column('enricher_templates', 'description')

View File

@@ -0,0 +1,150 @@
"""API routes for enricher templates management."""
from typing import List
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query, status
from flowsint_core.core.models import Profile
from flowsint_core.core.postgre_db import get_db
from flowsint_core.core.services import (
ConflictError,
NotFoundError,
create_enricher_template_service,
)
from flowsint_core.core.template_enricher import TemplateEnricher
from flowsint_core.templates.types import Template
from sqlalchemy.orm import Session
from app.api.deps import get_current_user
from app.api.schemas.enricher_template import (
EnricherTemplateCreate,
EnricherTemplateList,
EnricherTemplateRead,
EnricherTemplateTestRequest,
EnricherTemplateTestResponse,
EnricherTemplateUpdate,
)
router = APIRouter()
@router.post(
"", response_model=EnricherTemplateRead, status_code=status.HTTP_201_CREATED
)
def create_template(
template: EnricherTemplateCreate,
db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user),
):
"""Create a new enricher template."""
content = template.content
name = content.get("name", template.name)
description = content.get("description", template.description)
category = content.get("category", template.category)
version = float(content.get("version", template.version))
service = create_enricher_template_service(db)
try:
return service.create_template(
name=name,
description=description,
category=category,
version=version,
content=content,
is_public=template.is_public,
owner_id=current_user.id,
)
except ConflictError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.get("", response_model=List[EnricherTemplateList])
def list_templates(
category: str = Query(None, description="Filter by category"),
include_public: bool = Query(
True, description="Include public templates from other users"
),
db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user),
):
"""List enricher templates."""
service = create_enricher_template_service(db)
return service.list_templates(current_user.id, category, include_public)
@router.post("/{template_id}/test", response_model=EnricherTemplateTestResponse)
async def test_template(
template_id: UUID,
test_request: EnricherTemplateTestRequest,
db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user),
):
"""Test an enricher template with a sample input value."""
service = create_enricher_template_service(db)
try:
db_template = service.get_template(template_id, current_user.id)
except NotFoundError:
raise HTTPException(status_code=404, detail="Template not found")
try:
content = db_template.content
template = Template(**content)
enricher = TemplateEnricher(sketch_id="123", scan_id="123", template=template)
pre = enricher.preprocess([test_request.input_value])
results = await enricher.scan(pre)
data = {"results": results}
return EnricherTemplateTestResponse(
success=True, data=data, status_code=200, url=template.request.url
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"An error occured : {e}")
@router.get("/{template_id}", response_model=EnricherTemplateRead)
def get_template(
template_id: UUID,
db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user),
):
"""Get a specific enricher template by ID."""
service = create_enricher_template_service(db)
try:
return service.get_template(template_id, current_user.id)
except NotFoundError:
raise HTTPException(status_code=404, detail="Template not found")
@router.put("/{template_id}", response_model=EnricherTemplateRead)
def update_template(
template_id: UUID,
update_data: EnricherTemplateUpdate,
db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user),
):
"""Update an enricher template. Only the owner can update."""
service = create_enricher_template_service(db)
try:
return service.update_template(
template_id=template_id,
owner_id=current_user.id,
update_data=update_data.model_dump(exclude_unset=True),
)
except NotFoundError:
raise HTTPException(status_code=404, detail="Template not found")
except ConflictError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.delete("/{template_id}", status_code=status.HTTP_204_NO_CONTENT)
def delete_template(
template_id: UUID,
db: Session = Depends(get_db),
current_user: Profile = Depends(get_current_user),
):
"""Delete an enricher template. Only the owner can delete."""
service = create_enricher_template_service(db)
try:
service.delete_template(template_id, current_user.id)
except NotFoundError:
raise HTTPException(status_code=404, detail="Template not found")
return None

View File

@@ -8,7 +8,10 @@ from flowsint_core.core.celery import celery
from flowsint_core.core.graph import create_graph_service
from flowsint_core.core.models import Profile
from flowsint_core.core.postgre_db import get_db
from flowsint_core.core.services import create_enricher_service
from flowsint_core.core.services import (
create_enricher_service,
create_enricher_template_service,
)
from flowsint_enrichers import ENRICHER_REGISTRY, load_all_enrichers
from app.api.deps import get_current_user
@@ -30,8 +33,17 @@ def get_enrichers(
current_user: Profile = Depends(get_current_user),
):
"""Get all enrichers, optionally filtered by category."""
service = create_enricher_service(db)
return service.get_enrichers(category, current_user.id, ENRICHER_REGISTRY)
enricher_service = create_enricher_service(db)
base_enrichers = enricher_service.get_enrichers(
category, current_user.id, ENRICHER_REGISTRY
)
template_service = create_enricher_template_service(db)
template_enrichers = template_service.list_by_category_for_user(
current_user.id, category
)
return [*base_enrichers, *template_enrichers]
@router.post("/{enricher_name}/launch")
@@ -39,6 +51,7 @@ async def launch_enricher(
enricher_name: str,
payload: launchEnricherPayload,
current_user: Profile = Depends(get_current_user),
db: Session = Depends(get_db),
):
try:
# Retrieve nodes from Neo4J by their element IDs
@@ -52,8 +65,21 @@ async def launch_enricher(
status_code=404, detail="No entities found with provided IDs"
)
enricher_in_registry = ENRICHER_REGISTRY.enricher_exists(enricher_name)
if not enricher_in_registry:
template_service = create_enricher_template_service(db)
template = template_service.find_by_name(enricher_name, current_user.id)
if not template:
raise HTTPException(
status_code=404,
detail=f"Enricher '{enricher_name}' not found in registry",
)
task_name = (
"run_enricher" if enricher_in_registry else "run_enricher_from_template"
)
task = celery.send_task(
"run_enricher",
task_name,
args=[
enricher_name,
entities,

View File

@@ -1,7 +1,8 @@
from .base import ORMBase
from typing import Any, Dict, List, Optional
from pydantic import UUID4, BaseModel
from typing import Optional
from typing import List, Optional, Dict, Any
from .base import ORMBase
class EnricherCreate(BaseModel):
@@ -21,6 +22,6 @@ class EnricherRead(ORMBase):
class EnricherUpdate(BaseModel):
name: Optional[str] = None
class_name: str = None
class_name: Optional[str] = None
description: Optional[str] = None
category: Optional[List[str]] = None

View File

@@ -0,0 +1,165 @@
"""Pydantic schemas for enricher templates."""
from datetime import datetime
from typing import Any, Dict, Optional
from pydantic import UUID4, BaseModel, Field, field_validator
from .base import ORMBase
class EnricherTemplateCreate(BaseModel):
"""Schema for creating a new enricher template."""
name: str = Field(
..., min_length=1, max_length=255, description="Name of the template"
)
description: Optional[str] = Field(
None, max_length=1000, description="Description of the template"
)
category: str = Field(
..., min_length=1, max_length=100, description="Category (e.g., Ip, Domain)"
)
version: float = Field(default=1.0, ge=0, description="Template version")
content: Dict[str, Any] = Field(
..., description="Template content as parsed YAML/JSON"
)
is_public: bool = Field(
default=False, description="Whether the template is publicly visible"
)
@field_validator("content")
@classmethod
def validate_content(cls, v: Dict[str, Any]) -> Dict[str, Any]:
"""Validate that content has required template fields."""
required_fields = [
"name",
"category",
"version",
"input",
"request",
"output",
"response",
]
missing = [f for f in required_fields if f not in v]
if missing:
raise ValueError(
f"Missing required fields in content: {', '.join(missing)}"
)
# Validate input
if "input" in v and "type" not in v.get("input", {}):
raise ValueError("input.type is required")
# Validate request
request = v.get("request", {})
if "method" not in request:
raise ValueError("request.method is required")
if request.get("method") not in ["GET", "POST"]:
raise ValueError("request.method must be GET or POST")
if "url" not in request:
raise ValueError("request.url is required")
# Validate output
if "output" in v and "type" not in v.get("output", {}):
raise ValueError("output.type is required")
# Validate response
response = v.get("response", {})
if "expect" not in response:
raise ValueError("response.expect is required")
if response.get("expect") not in ["json", "xml", "text"]:
raise ValueError("response.expect must be json, xml, or text")
return v
class EnricherTemplateUpdate(BaseModel):
"""Schema for updating an existing enricher template."""
name: Optional[str] = Field(None, min_length=1, max_length=255)
description: Optional[str] = Field(None, max_length=1000)
category: Optional[str] = Field(None, min_length=1, max_length=100)
version: Optional[float] = Field(None, ge=0)
content: Optional[Dict[str, Any]] = None
is_public: Optional[bool] = None
@field_validator("content")
@classmethod
def validate_content(cls, v: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
"""Validate content if provided."""
if v is None:
return v
required_fields = [
"name",
"category",
"version",
"input",
"request",
"output",
"response",
]
missing = [f for f in required_fields if f not in v]
if missing:
raise ValueError(
f"Missing required fields in content: {', '.join(missing)}"
)
return v
class EnricherTemplateRead(ORMBase):
"""Schema for reading an enricher template."""
id: UUID4
name: str
description: Optional[str]
category: str
version: float
content: Dict[str, Any]
is_public: bool
owner_id: UUID4
created_at: datetime
updated_at: datetime
class EnricherTemplateList(ORMBase):
"""Schema for listing enricher templates (minimal fields)."""
id: UUID4
name: str
description: Optional[str]
category: str
version: float
is_public: bool
owner_id: UUID4
created_at: datetime
updated_at: datetime
class EnricherTemplateTestRequest(BaseModel):
"""Schema for testing an enricher template by ID."""
input_value: str = Field(
..., min_length=1, description="The value to test the template with"
)
class EnricherTemplateTestContentRequest(BaseModel):
"""Schema for testing template content directly (without saving)."""
input_value: str = Field(
..., min_length=1, description="The value to test the template with"
)
content: Dict[str, Any] = Field(..., description="Template content to test")
class EnricherTemplateTestResponse(BaseModel):
"""Schema for test response."""
success: bool
data: Optional[Dict[str, Any]] = None
error: Optional[str] = None
status_code: Optional[int] = None
url: str

View File

@@ -14,6 +14,7 @@ from app.api.routes import scan
from app.api.routes import keys
from app.api.routes import types
from app.api.routes import custom_types
from app.api.routes import enricher_templates
origins = [
"*",
@@ -51,3 +52,4 @@ app.include_router(scan.router, prefix="/api/scans", tags=["scans"])
app.include_router(keys.router, prefix="/api/keys", tags=["keys"])
app.include_router(types.router, prefix="/api/types", tags=["types"])
app.include_router(custom_types.router, prefix="/api/custom-types", tags=["custom-types"])
app.include_router(enricher_templates.router, prefix="/api/enrichers/templates", tags=["enricher-templates"])

View File

@@ -0,0 +1,194 @@
"""
Enricher template service for managing enricher template operations.
"""
from typing import Dict, List, Optional
from uuid import UUID
from sqlalchemy import or_
from sqlalchemy.orm import Session
from ..models import EnricherTemplate
from .base import BaseService
from .exceptions import ConflictError, NotFoundError, PermissionDeniedError
class EnricherTemplateService(BaseService):
"""Service for enricher template CRUD and lookup operations."""
def create_template(
self,
name: str,
description: Optional[str],
category: str,
version: float,
content: dict,
is_public: bool,
owner_id: UUID,
) -> EnricherTemplate:
self._check_duplicate_name(name, owner_id)
template = EnricherTemplate(
name=name,
description=description,
category=category,
version=version,
content=content,
is_public=is_public,
owner_id=owner_id,
)
self._add(template)
self._commit()
self._refresh(template)
return template
def list_templates(
self,
owner_id: UUID,
category: Optional[str] = None,
include_public: bool = True,
) -> List[EnricherTemplate]:
if include_public:
query = self._db.query(EnricherTemplate).filter(
or_(
EnricherTemplate.owner_id == owner_id,
EnricherTemplate.is_public,
)
)
else:
query = self._db.query(EnricherTemplate).filter(
EnricherTemplate.owner_id == owner_id
)
if category:
query = query.filter(EnricherTemplate.category == category)
return query.order_by(EnricherTemplate.created_at.desc()).all()
def get_template(self, template_id: UUID, user_id: UUID) -> EnricherTemplate:
template = (
self._db.query(EnricherTemplate)
.filter(
EnricherTemplate.id == template_id,
or_(
EnricherTemplate.owner_id == user_id,
EnricherTemplate.is_public,
),
)
.first()
)
if not template:
raise NotFoundError("Template not found")
return template
def get_owned_template(
self, template_id: UUID, owner_id: UUID
) -> EnricherTemplate:
template = (
self._db.query(EnricherTemplate)
.filter(
EnricherTemplate.id == template_id,
EnricherTemplate.owner_id == owner_id,
)
.first()
)
if not template:
raise NotFoundError("Template not found")
return template
def update_template(
self,
template_id: UUID,
owner_id: UUID,
update_data: dict,
) -> EnricherTemplate:
template = self.get_owned_template(template_id, owner_id)
content = update_data.get("content")
if content is not None:
new_name = content.get("name")
if new_name and new_name != template.name:
self._check_duplicate_name(new_name, owner_id, exclude_id=template_id)
template.name = new_name
new_category = content.get("category")
if new_category:
template.category = new_category
new_version = content.get("version")
if new_version is not None:
template.version = float(new_version)
template.description = content.get("description")
template.content = content
# Explicit field updates override content values
if update_data.get("name") is not None:
self._check_duplicate_name(
update_data["name"], owner_id, exclude_id=template_id
)
template.name = update_data["name"]
if update_data.get("category") is not None:
template.category = update_data["category"]
if update_data.get("description") is not None:
template.description = update_data["description"]
if update_data.get("version") is not None:
template.version = update_data["version"]
if update_data.get("is_public") is not None:
template.is_public = update_data["is_public"]
self._commit()
self._refresh(template)
return template
def delete_template(self, template_id: UUID, owner_id: UUID) -> None:
template = self.get_owned_template(template_id, owner_id)
self._delete(template)
self._commit()
def find_by_name(self, name: str, user_id: UUID) -> Optional[EnricherTemplate]:
return (
self._db.query(EnricherTemplate)
.filter(
EnricherTemplate.name == name,
or_(
EnricherTemplate.owner_id == user_id,
EnricherTemplate.is_public,
),
)
.first()
)
def list_by_category_for_user(
self, owner_id: UUID, category: Optional[str] = None
) -> List[EnricherTemplate]:
query = self._db.query(EnricherTemplate).filter(
EnricherTemplate.owner_id == owner_id
)
if category:
query = query.filter(EnricherTemplate.category == category)
return query.order_by(EnricherTemplate.created_at.desc()).all()
def _check_duplicate_name(
self,
name: str,
owner_id: UUID,
exclude_id: Optional[UUID] = None,
) -> None:
query = self._db.query(EnricherTemplate).filter(
EnricherTemplate.owner_id == owner_id,
EnricherTemplate.name == name,
)
if exclude_id:
query = query.filter(EnricherTemplate.id != exclude_id)
if query.first():
raise ConflictError(f"Template with name '{name}' already exists")
def create_enricher_template_service(db: Session) -> EnricherTemplateService:
"""Factory function to create an EnricherTemplateService instance."""
return EnricherTemplateService(db=db)