mirror of
https://github.com/harvard-edge/cs249r_book.git
synced 2026-04-29 00:59:07 -05:00
feat(tools): add shared infrastructure for MLSysBook tool development

- Add pyproject.toml with comprehensive Python package configuration
- Create tools/scripts/common/ package with shared utilities:
  - base_classes.py: Abstract base classes following SOLID principles
  - config.py: Centralized configuration management with environment support
  - exceptions.py: Custom exception hierarchy for better error handling
  - logging_config.py: Standardized logging setup across all tools
  - validators.py: Input validation utilities with detailed error reporting

This establishes a proper foundation for building maintainable, production-grade
tools, following the software engineering best practices outlined in the updated
.cursorrules guidelines.
This commit is contained in:
266
pyproject.toml
Normal file
266
pyproject.toml
Normal file
@@ -0,0 +1,266 @@
|
||||
[build-system]
|
||||
requires = ["setuptools>=61.0", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "mlsysbook"
|
||||
version = "1.0.0"
|
||||
description = "Machine Learning Systems Textbook - Tools and Scripts"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.9"
|
||||
license = {text = "MIT"}
|
||||
authors = [
|
||||
{name = "MLSysBook Contributors", email = "info@mlsysbook.ai"}
|
||||
]
|
||||
keywords = ["machine-learning", "systems", "textbook", "education", "ai"]
|
||||
classifiers = [
|
||||
"Development Status :: 4 - Beta",
|
||||
"Intended Audience :: Education",
|
||||
"Intended Audience :: Science/Research",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Operating System :: OS Independent",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Topic :: Education",
|
||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||
"Topic :: Documentation",
|
||||
"Topic :: Text Processing :: Markup",
|
||||
]
|
||||
|
||||
dependencies = [
|
||||
# Core dependencies for Quarto/Jupyter integration
|
||||
"jupyterlab-quarto>=0.3.0",
|
||||
"jupyter>=1.0.0",
|
||||
|
||||
# Bibliography and document processing
|
||||
"pybtex>=0.24.0",
|
||||
"pypandoc>=1.11",
|
||||
"pyyaml>=6.0",
|
||||
|
||||
# Data processing and validation
|
||||
"pandas>=2.0.0",
|
||||
"numpy>=1.24.0",
|
||||
"jsonschema>=4.0.0",
|
||||
|
||||
# Image processing
|
||||
"Pillow>=9.0.0",
|
||||
|
||||
# HTTP and API interactions
|
||||
"requests>=2.31.0",
|
||||
|
||||
# Text processing
|
||||
"titlecase>=2.4.1",
|
||||
|
||||
# Terminal output and UI
|
||||
"rich>=13.0.0",
|
||||
|
||||
# Type checking and validation
|
||||
"typing-extensions>=4.5.0",
|
||||
|
||||
# Utilities
|
||||
"click>=8.0.0",
|
||||
"pathlib2>=2.3.0; python_version<'3.11'",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
# Testing
|
||||
"pytest>=7.0.0",
|
||||
"pytest-cov>=4.0.0",
|
||||
"pytest-mock>=3.10.0",
|
||||
"pytest-asyncio>=0.21.0",
|
||||
|
||||
# Code quality
|
||||
"black>=23.0.0",
|
||||
"isort>=5.12.0",
|
||||
"flake8>=6.0.0",
|
||||
"mypy>=1.0.0",
|
||||
"pylint>=2.17.0",
|
||||
|
||||
# Security
|
||||
"bandit>=1.7.0",
|
||||
"safety>=2.3.0",
|
||||
|
||||
# Documentation
|
||||
"sphinx>=6.0.0",
|
||||
"sphinx-rtd-theme>=1.2.0",
|
||||
|
||||
# Pre-commit hooks
|
||||
"pre-commit>=3.0.0",
|
||||
]
|
||||
|
||||
ai = [
|
||||
# AI/ML specific dependencies
|
||||
"openai>=1.0.0",
|
||||
"sentence-transformers>=2.2.0",
|
||||
"transformers>=4.21.0",
|
||||
"torch>=1.12.0",
|
||||
"scikit-learn>=1.3.0",
|
||||
"gradio>=3.40.0",
|
||||
"ollama>=0.1.0",
|
||||
]
|
||||
|
||||
build = [
|
||||
# Build and publishing tools
|
||||
"twine>=4.0.0",
|
||||
"build>=0.10.0",
|
||||
"setuptools>=61.0",
|
||||
"wheel>=0.40.0",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://mlsysbook.ai"
|
||||
Documentation = "https://docs.mlsysbook.ai"
|
||||
Repository = "https://github.com/mlsysbook/mlsysbook"
|
||||
Issues = "https://github.com/mlsysbook/mlsysbook/issues"
|
||||
Changelog = "https://github.com/mlsysbook/mlsysbook/blob/main/CHANGELOG.md"
|
||||
|
||||
[project.scripts]
|
||||
mlsysbook-cli = "tools.scripts.cli:main"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["."]
|
||||
include = ["tools*"]
|
||||
exclude = ["tests*", "docs*"]
|
||||
|
||||
[tool.setuptools.package-data]
|
||||
"tools.scripts" = ["**/*.yaml", "**/*.yml", "**/*.json", "**/*.txt", "**/*.md"]
|
||||
|
||||
# Black formatting configuration
|
||||
[tool.black]
|
||||
line-length = 100
|
||||
target-version = ['py39', 'py310', 'py311', 'py312']
|
||||
include = '\.pyi?$'
|
||||
extend-exclude = '''
|
||||
/(
|
||||
# Exclude build artifacts
|
||||
\.eggs
|
||||
| \.git
|
||||
| \.hg
|
||||
| \.mypy_cache
|
||||
| \.tox
|
||||
| \.venv
|
||||
| _build
|
||||
| buck-out
|
||||
| build
|
||||
| dist
|
||||
# Exclude generated files
|
||||
| contents/
|
||||
| quarto/
|
||||
)/
|
||||
'''
|
||||
|
||||
# isort import sorting configuration
|
||||
[tool.isort]
|
||||
profile = "black"
|
||||
line_length = 100
|
||||
multi_line_output = 3
|
||||
include_trailing_comma = true
|
||||
force_grid_wrap = 0
|
||||
use_parentheses = true
|
||||
ensure_newline_before_comments = true
|
||||
src_paths = ["tools"]
|
||||
skip_glob = ["contents/*", "quarto/*"]
|
||||
|
||||
# MyPy type checking configuration
|
||||
[tool.mypy]
|
||||
python_version = "3.9"
|
||||
warn_return_any = true
|
||||
warn_unused_configs = true
|
||||
disallow_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
check_untyped_defs = true
|
||||
disallow_untyped_decorators = true
|
||||
no_implicit_optional = true
|
||||
warn_redundant_casts = true
|
||||
warn_unused_ignores = true
|
||||
warn_no_return = true
|
||||
warn_unreachable = true
|
||||
strict_equality = true
|
||||
show_error_codes = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = [
|
||||
"pypandoc.*",
|
||||
"pybtex.*",
|
||||
"titlecase.*",
|
||||
"gradio.*",
|
||||
"ollama.*",
|
||||
]
|
||||
ignore_missing_imports = true
|
||||
|
||||
# Pytest configuration
|
||||
[tool.pytest.ini_options]
|
||||
minversion = "7.0"
|
||||
addopts = [
|
||||
"--strict-markers",
|
||||
"--strict-config",
|
||||
"--cov=tools",
|
||||
"--cov-report=term-missing",
|
||||
"--cov-report=html",
|
||||
"--cov-fail-under=80"
|
||||
]
|
||||
testpaths = ["tests"]
|
||||
python_files = ["test_*.py", "*_test.py"]
|
||||
python_classes = ["Test*"]
|
||||
python_functions = ["test_*"]
|
||||
markers = [
|
||||
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
|
||||
"integration: marks tests as integration tests",
|
||||
"unit: marks tests as unit tests",
|
||||
"ai: marks tests that require AI models",
|
||||
]
|
||||
|
||||
# Coverage configuration
|
||||
[tool.coverage.run]
|
||||
source = ["tools"]
|
||||
branch = true
|
||||
omit = [
|
||||
"*/tests/*",
|
||||
"*/test_*",
|
||||
"*/__pycache__/*",
|
||||
"*/.*",
|
||||
]
|
||||
|
||||
[tool.coverage.report]
|
||||
exclude_lines = [
|
||||
"pragma: no cover",
|
||||
"def __repr__",
|
||||
"if self.debug:",
|
||||
"if settings.DEBUG",
|
||||
"raise AssertionError",
|
||||
"raise NotImplementedError",
|
||||
"if 0:",
|
||||
"if __name__ == .__main__.:",
|
||||
"class .*\\bProtocol\\):",
|
||||
"@(abc\\.)?abstractmethod",
|
||||
]
|
||||
|
||||
# Pylint configuration
|
||||
[tool.pylint.messages_control]
|
||||
disable = [
|
||||
"C0330", # Wrong hanging indentation before block (conflicts with black)
|
||||
"C0326", # Bad whitespace (conflicts with black)
|
||||
"R0903", # Too few public methods (sometimes OK for data classes)
|
||||
"R0913", # Too many arguments (sometimes necessary)
|
||||
]
|
||||
|
||||
[tool.pylint.format]
|
||||
max-line-length = "100"
|
||||
|
||||
[tool.pylint.design]
|
||||
max-args = 8
|
||||
max-locals = 20
|
||||
max-branches = 15
|
||||
max-statements = 60
|
||||
|
||||
# Bandit security linting
|
||||
[tool.bandit]
|
||||
exclude_dirs = ["tests", "contents", "quarto"]
|
||||
skips = ["B101", "B601"] # Skip assert_used and shell=True (sometimes needed)
|
||||
|
||||
# Flake8 configuration (in setup.cfg or .flake8 file)
|
||||
# Note: flake8 doesn't support pyproject.toml yet
|
||||
29
tools/scripts/common/__init__.py
Normal file
29
tools/scripts/common/__init__.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""
Common utilities and shared components for MLSysBook tools.

This package provides shared functionality across all tools in the MLSysBook project,
including base classes, configuration management, logging setup, and common utilities.

Modules:
    base_classes: Abstract base classes for tools and processors
    config: Configuration management and environment handling
    exceptions: Custom exception definitions
    logging_config: Centralized logging configuration
    validators: Input validation utilities
    file_utils: File and path operation utilities
"""

# Re-export the most commonly used names so callers can write
# ``from tools.scripts.common import get_logger`` instead of reaching
# into individual submodules.
from .exceptions import MLSysBookError, ConfigurationError, ValidationError
from .config import get_config, Config
from .logging_config import setup_logging, get_logger

# Package version; keep in sync with ``project.version`` in pyproject.toml.
__version__ = "1.0.0"

# Public API exported by ``from tools.scripts.common import *``.
# NOTE(review): the module docstring also lists ``validators`` and
# ``file_utils``, but nothing from them is re-exported here — confirm
# whether they should be added.
__all__ = [
    "MLSysBookError",
    "ConfigurationError",
    "ValidationError",
    "get_config",
    "Config",
    "setup_logging",
    "get_logger",
]
|
||||
388
tools/scripts/common/base_classes.py
Normal file
388
tools/scripts/common/base_classes.py
Normal file
@@ -0,0 +1,388 @@
|
||||
"""
|
||||
Abstract base classes for MLSysBook tools.
|
||||
|
||||
This module provides abstract base classes that define common interfaces
|
||||
and patterns for tools in the MLSysBook project.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from pathlib import Path
|
||||
import argparse
|
||||
from dataclasses import dataclass
|
||||
|
||||
from .config import get_config
|
||||
from .logging_config import get_logger
|
||||
from .exceptions import ToolExecutionError, ValidationError
|
||||
|
||||
|
||||
@dataclass
class ToolResult:
    """Standard result object returned by tool operations.

    Attributes:
        success: Whether the operation was successful
        message: Human-readable result message
        data: Optional result data
        errors: List of error messages (empty list when none occurred)
        warnings: List of warning messages (empty list when none occurred)
        metadata: Additional metadata about the operation (empty dict by default)
    """
    success: bool
    message: str
    data: Optional[Any] = None
    errors: Optional[List[str]] = None
    warnings: Optional[List[str]] = None
    metadata: Optional[Dict[str, Any]] = None

    def __post_init__(self) -> None:
        """Replace any unset collection field with a fresh empty container.

        Defaults are declared as ``None`` (not mutable literals) so each
        instance gets its own list/dict here.
        """
        for attr, factory in (("errors", list), ("warnings", list), ("metadata", dict)):
            if getattr(self, attr) is None:
                setattr(self, attr, factory())
|
||||
|
||||
|
||||
class BaseTool(ABC):
    """Abstract base class for all MLSysBook tools.

    Supplies the shared plumbing every tool needs — a logger, a config
    object, lazy one-time initialization, and a result history — so that
    subclasses only have to implement :meth:`run`.
    """

    def __init__(
        self,
        name: Optional[str] = None,
        description: Optional[str] = None,
        config: Optional[Any] = None
    ) -> None:
        """Initialize the tool.

        Args:
            name: Tool name (defaults to the class name)
            description: Tool description (defaults to the class docstring)
            config: Configuration object (defaults to the global config)
        """
        cls = self.__class__
        self.name = name or cls.__name__
        self.description = description or cls.__doc__ or "MLSysBook tool"
        self.config = config or get_config()
        self.logger = get_logger(f"tools.{self.name}")

        # Execution state: initialize() flips the flag once; results accumulate.
        self._initialized = False
        self._results: List[ToolResult] = []

    @abstractmethod
    def run(self, *args, **kwargs) -> ToolResult:
        """Run the tool with the provided arguments.

        This is the main entry point for tool execution.

        Returns:
            ToolResult object with operation results
        """

    def validate_inputs(self, *args, **kwargs) -> None:
        """Validate tool inputs.

        The default implementation accepts anything; override to add
        tool-specific checks.

        Raises:
            ValidationError: If validation fails
        """

    def initialize(self) -> None:
        """Perform one-time setup.

        Safe to call repeatedly — only the first call has any effect.
        Override to add setup work (and call ``super().initialize()``).
        """
        if not self._initialized:
            self.logger.debug(f"Initializing tool: {self.name}")
            self._initialized = True

    def cleanup(self) -> None:
        """Release resources held by the tool (override as needed)."""
        self.logger.debug(f"Cleaning up tool: {self.name}")

    def add_result(self, result: ToolResult) -> None:
        """Record *result* in the tool's result history.

        Args:
            result: ToolResult to add
        """
        self._results.append(result)

    def get_results(self) -> List[ToolResult]:
        """Return a copy of every recorded result, oldest first.

        Returns:
            List of ToolResult objects
        """
        return list(self._results)

    def get_last_result(self) -> Optional[ToolResult]:
        """Return the most recently recorded result.

        Returns:
            The latest ToolResult, or None when nothing has been recorded.
        """
        if not self._results:
            return None
        return self._results[-1]

    def __enter__(self):
        """Context manager entry: initialize and return the tool."""
        self.initialize()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit: always run cleanup."""
        self.cleanup()
|
||||
|
||||
|
||||
class FileProcessorTool(BaseTool):
    """Base class for tools that process files.

    Provides the shared machinery for iterating over directories and
    aggregating per-file results; subclasses implement :meth:`process_file`.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Initialize the file processor tool.

        Accepts the same arguments as BaseTool (name, description, config).
        """
        super().__init__(*args, **kwargs)
        # Cumulative history across ALL process_directory() calls.
        self.processed_files: List[Path] = []
        self.failed_files: List[Path] = []

    @abstractmethod
    def process_file(self, file_path: Path) -> ToolResult:
        """Process a single file.

        Args:
            file_path: Path to the file to process

        Returns:
            ToolResult object with processing results
        """

    def process_directory(
        self,
        directory: Path,
        pattern: str = "**/*",
        recursive: bool = True
    ) -> ToolResult:
        """Process all matching files in a directory.

        Args:
            directory: Directory to process
            pattern: Glob pattern for file matching
            recursive: Whether to process subdirectories

        Returns:
            ToolResult whose counts, data, and success flag reflect THIS
            call only; the ``processed_files``/``failed_files`` attributes
            keep the cumulative history across calls.
        """
        self.logger.info(f"Processing directory: {directory}")

        try:
            glob = directory.rglob if recursive else directory.glob
            files = [f for f in glob(pattern) if f.is_file()]

            # Per-call tracking. BUG FIX: the previous implementation derived
            # `successful`/`failed` from len(self.processed_files) /
            # len(self.failed_files), which accumulate across calls — so a
            # second process_directory() call reported (and failed because of)
            # files from earlier calls, inconsistent with the per-call
            # `total_files` metadata.
            succeeded: List[Path] = []
            failed_now: List[Path] = []
            results = []
            for file_path in files:
                try:
                    result = self.process_file(file_path)
                except Exception as e:
                    # A crash in process_file fails just this file, not the run.
                    self.logger.error(f"Error processing {file_path}: {e}")
                    result = ToolResult(
                        success=False,
                        message=f"Failed to process {file_path}: {e}",
                        errors=[str(e)]
                    )
                results.append(result)
                (succeeded if result.success else failed_now).append(file_path)

            # Keep the cumulative history up to date.
            self.processed_files.extend(succeeded)
            self.failed_files.extend(failed_now)

            successful = len(succeeded)
            failed = len(failed_now)

            return ToolResult(
                success=failed == 0,
                message=f"Processed {successful} files, {failed} failed",
                data={
                    "processed_files": succeeded,
                    "failed_files": failed_now,
                    "results": results
                },
                metadata={
                    "total_files": len(files),
                    "successful_count": successful,
                    "failed_count": failed
                }
            )

        except Exception as e:
            # Directory-level failure (bad path, glob error, ...).
            error_msg = f"Failed to process directory {directory}: {e}"
            self.logger.error(error_msg)
            return ToolResult(
                success=False,
                message=error_msg,
                errors=[str(e)]
            )
|
||||
|
||||
|
||||
class CLITool(BaseTool):
    """Base class for command-line interface tools.

    Adds argparse-based argument handling on top of BaseTool, plus a
    :meth:`run_cli` entry point that wraps validation and execution.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Initialize the CLI tool and build its argument parser."""
        super().__init__(*args, **kwargs)
        self.parser = self.create_argument_parser()

    def create_argument_parser(self) -> argparse.ArgumentParser:
        """Create the argument parser for this tool.

        Override this method to add tool-specific arguments.

        Returns:
            Configured ArgumentParser
        """
        parser = argparse.ArgumentParser(
            prog=self.name,
            description=self.description,
        )

        # Flags shared by every CLI tool, registered in one data-driven pass.
        common_flags = [
            (("--verbose", "-v"),
             {"action": "store_true",
              "help": "Enable verbose output"}),
            (("--dry-run",),
             {"action": "store_true",
              "help": "Show what would be done without making changes"}),
            (("--log-level",),
             {"choices": ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
              "default": "INFO",
              "help": "Set the logging level"}),
        ]
        for flag_names, flag_options in common_flags:
            parser.add_argument(*flag_names, **flag_options)

        return parser

    def parse_args(self, args: Optional[List[str]] = None) -> argparse.Namespace:
        """Parse command-line arguments.

        Args:
            args: Arguments to parse (defaults to sys.argv)

        Returns:
            Parsed arguments namespace
        """
        return self.parser.parse_args(args)

    def run_cli(self, args: Optional[List[str]] = None) -> ToolResult:
        """Run the tool as a CLI application.

        Args:
            args: Command-line arguments

        Returns:
            ToolResult object (failures are reported in the result rather
            than raised)
        """
        try:
            namespace = self.parse_args(args)

            # Honour --log-level before doing any real work.
            if hasattr(namespace, 'log_level'):
                self.logger.setLevel(namespace.log_level)

            options = vars(namespace)
            self.validate_inputs(**options)
            return self.run(**options)

        except ValidationError as e:
            error_msg = f"Input validation failed: {e}"
            self.logger.error(error_msg)
            return ToolResult(
                success=False,
                message=error_msg,
                errors=[str(e)],
            )
        except Exception as e:
            error_msg = f"Tool execution failed: {e}"
            self.logger.error(error_msg, exc_info=True)
            return ToolResult(
                success=False,
                message=error_msg,
                errors=[str(e)],
            )
|
||||
|
||||
|
||||
class ContentProcessor(ABC):
    """Abstract base class for content processing components.

    Defines the interface for components that process textbook content
    (markdown, images, etc.). Concrete processors implement
    :meth:`can_process` and :meth:`process`.
    """

    @abstractmethod
    def can_process(self, content_type: str, file_path: Path) -> bool:
        """Report whether this processor can handle the given content.

        Args:
            content_type: Type of content (e.g., 'markdown', 'image')
            file_path: Path to the content file

        Returns:
            True if this processor can handle the content
        """

    @abstractmethod
    def process(self, file_path: Path, **kwargs) -> ToolResult:
        """Process the content.

        Args:
            file_path: Path to the content file
            **kwargs: Additional processing options

        Returns:
            ToolResult object with processing results
        """

    def get_priority(self) -> int:
        """Return this processor's priority.

        Higher numbers win: when multiple processors can handle the same
        content, the one with the highest priority is used.

        Returns:
            Priority value (default: 0)
        """
        return 0
|
||||
206
tools/scripts/common/config.py
Normal file
206
tools/scripts/common/config.py
Normal file
@@ -0,0 +1,206 @@
|
||||
"""
|
||||
Configuration management for MLSysBook tools.
|
||||
|
||||
This module provides centralized configuration management with support for
|
||||
environment variables, configuration files, and sensible defaults.
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from .exceptions import ConfigurationError
|
||||
|
||||
|
||||
@dataclass
class Config:
    """Central configuration class for MLSysBook tools.

    This class manages all configuration settings with support for environment
    variables, configuration files, and programmatic overrides.

    Load order (see ``__post_init__``): dataclass defaults, then environment
    variables, then the first ``mlsysbook.y(a)ml`` file found in the project
    root, then validation.
    """

    # Project paths — derived from this file's location
    # (tools/scripts/common/config.py → parents[3] is the repo root).
    project_root: Path = field(default_factory=lambda: Path(__file__).parents[3])
    quarto_root: Path = field(default_factory=lambda: Path(__file__).parents[3] / "quarto")
    tools_root: Path = field(default_factory=lambda: Path(__file__).parents[1])
    content_root: Path = field(default_factory=lambda: Path(__file__).parents[3] / "quarto" / "contents")

    # Logging configuration
    log_level: str = field(default="INFO")
    log_format: str = field(default="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    log_to_file: bool = field(default=False)
    log_file_path: Optional[Path] = field(default=None)

    # API configurations
    openai_api_key: Optional[str] = field(default=None)
    ollama_base_url: str = field(default="http://localhost:11434")
    ollama_model: str = field(default="llama3.1:8b")

    # Processing settings
    max_workers: int = field(default=4)
    chunk_size: int = field(default=1000)
    enable_caching: bool = field(default=True)
    cache_dir: Path = field(default_factory=lambda: Path.home() / ".cache" / "mlsysbook")

    # Content processing settings
    backup_enabled: bool = field(default=True)
    backup_dir: Path = field(default_factory=lambda: Path(__file__).parents[3] / "backups")
    dry_run: bool = field(default=False)

    # Quality thresholds
    similarity_threshold: float = field(default=0.65)
    min_caption_length: int = field(default=10)
    max_caption_length: int = field(default=500)

    def __post_init__(self) -> None:
        """Initialize configuration after dataclass creation.

        NOTE(review): the config file is applied AFTER environment variables,
        so file values override env vars — confirm that precedence is the
        intended one (env-over-file is the more common convention).
        """
        self._load_from_environment()
        self._load_from_config_file()
        self._validate_configuration()

    def _load_from_environment(self) -> None:
        """Load configuration from environment variables.

        Values are coerced to the type of the attribute's current value;
        unparseable numbers raise ConfigurationError.
        """
        # Environment variable name -> Config attribute name.
        env_mappings = {
            "MLSYSBOOK_LOG_LEVEL": "log_level",
            "MLSYSBOOK_LOG_TO_FILE": "log_to_file",
            "MLSYSBOOK_DRY_RUN": "dry_run",
            "MLSYSBOOK_BACKUP_ENABLED": "backup_enabled",
            "MLSYSBOOK_MAX_WORKERS": "max_workers",
            "OPENAI_API_KEY": "openai_api_key",
            "OLLAMA_BASE_URL": "ollama_base_url",
            "OLLAMA_MODEL": "ollama_model",
        }

        for env_var, attr_name in env_mappings.items():
            env_value = os.getenv(env_var)
            if env_value is not None:
                # Convert string values to appropriate types, keyed off the
                # type of the attribute's current (default) value.
                if hasattr(self, attr_name):
                    current_value = getattr(self, attr_name)
                    if isinstance(current_value, bool):
                        # Accept the usual truthy spellings; anything else is False.
                        setattr(self, attr_name, env_value.lower() in ("true", "1", "yes", "on"))
                    elif isinstance(current_value, int):
                        try:
                            setattr(self, attr_name, int(env_value))
                        except ValueError:
                            raise ConfigurationError(
                                f"Invalid integer value for {env_var}: {env_value}"
                            )
                    elif isinstance(current_value, float):
                        try:
                            setattr(self, attr_name, float(env_value))
                        except ValueError:
                            raise ConfigurationError(
                                f"Invalid float value for {env_var}: {env_value}"
                            )
                    else:
                        # Strings (and None-defaulted attrs) are taken verbatim.
                        setattr(self, attr_name, env_value)

    def _load_from_config_file(self) -> None:
        """Load configuration from the first config file found, if any.

        Candidates are checked in order; only the first existing file is
        loaded. A file that exists but fails to parse raises
        ConfigurationError rather than being skipped.
        """
        config_paths = [
            self.project_root / "mlsysbook.yaml",
            self.project_root / "mlsysbook.yml",
            self.project_root / ".mlsysbook.yaml",
            self.project_root / ".mlsysbook.yml",
        ]

        for config_path in config_paths:
            if config_path.exists():
                try:
                    with open(config_path, 'r', encoding='utf-8') as f:
                        config_data = yaml.safe_load(f)

                    # safe_load returns None for an empty file — skip applying then.
                    if config_data:
                        self._apply_config_dict(config_data)
                    break
                except Exception as e:
                    raise ConfigurationError(
                        f"Failed to load configuration from {config_path}: {e}"
                    )

    def _apply_config_dict(self, config_data: Dict[str, Any]) -> None:
        """Apply configuration from a dictionary.

        Unknown keys are silently ignored; only existing attributes are set.
        """
        for key, value in config_data.items():
            if hasattr(self, key):
                # Convert path strings to Path objects, using the attribute
                # naming convention (_root/_dir/_path suffix) as the signal.
                if key.endswith(('_root', '_dir', '_path')) and isinstance(value, str):
                    setattr(self, key, Path(value))
                else:
                    setattr(self, key, value)

    def _validate_configuration(self) -> None:
        """Validate configuration values.

        Raises:
            ConfigurationError: If a directory cannot be created or a
                numeric value is out of range.
        """
        # Ensure directories exist or can be created (side effect: mkdir).
        for attr_name in ['cache_dir', 'backup_dir']:
            dir_path = getattr(self, attr_name)
            if dir_path and not dir_path.exists():
                try:
                    dir_path.mkdir(parents=True, exist_ok=True)
                except Exception as e:
                    raise ConfigurationError(
                        f"Cannot create directory {dir_path}: {e}"
                    )

        # Validate numeric ranges
        if self.max_workers < 1:
            raise ConfigurationError("max_workers must be at least 1")

        if not 0.0 <= self.similarity_threshold <= 1.0:
            raise ConfigurationError("similarity_threshold must be between 0.0 and 1.0")

        if self.min_caption_length < 1:
            raise ConfigurationError("min_caption_length must be at least 1")

        if self.max_caption_length < self.min_caption_length:
            raise ConfigurationError("max_caption_length must be >= min_caption_length")

    def to_dict(self) -> Dict[str, Any]:
        """Convert configuration to a dictionary.

        Path values are stringified so the result is YAML/JSON-serializable.
        """
        result = {}
        for key, value in self.__dict__.items():
            if isinstance(value, Path):
                result[key] = str(value)
            else:
                result[key] = value
        return result

    def save_to_file(self, path: Union[str, Path]) -> None:
        """Save configuration to a YAML file.

        Args:
            path: Destination file path.

        Raises:
            ConfigurationError: If the file cannot be written.
        """
        path = Path(path)
        config_dict = self.to_dict()

        try:
            with open(path, 'w', encoding='utf-8') as f:
                yaml.dump(config_dict, f, default_flow_style=False, sort_keys=True)
        except Exception as e:
            raise ConfigurationError(f"Failed to save configuration to {path}: {e}")
|
||||
|
||||
|
||||
# Process-wide singleton; created lazily by get_config().
_config: Optional[Config] = None


def get_config() -> Config:
    """Return the global Config instance.

    The instance is built on first use and cached for the lifetime of the
    process (or until reset_config() is called).

    Returns:
        The global Config instance, creating it if necessary.
    """
    global _config
    if _config is None:
        _config = Config()
    return _config


def reset_config() -> None:
    """Discard the cached Config so the next get_config() rebuilds it.

    This is primarily useful for testing.
    """
    global _config
    _config = None
|
||||
108
tools/scripts/common/exceptions.py
Normal file
108
tools/scripts/common/exceptions.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""
|
||||
Custom exception definitions for MLSysBook tools.
|
||||
|
||||
This module defines a hierarchy of custom exceptions to provide clear error handling
|
||||
and better debugging information across all tools in the MLSysBook project.
|
||||
"""
|
||||
|
||||
from typing import Optional, Any, Dict
|
||||
|
||||
|
||||
class MLSysBookError(Exception):
    """Base exception for all MLSysBook-related errors.

    Root of the project's exception hierarchy. Carries an optional context
    dict and the original exception, both of which are folded into the
    string representation.

    Args:
        message: Human-readable error message
        context: Optional dictionary containing error context information
        original_error: Original exception that caused this error (if any)
    """

    def __init__(
        self,
        message: str,
        context: Optional[Dict[str, Any]] = None,
        original_error: Optional[Exception] = None
    ) -> None:
        super().__init__(message)
        self.message = message
        self.context = context or {}
        self.original_error = original_error

    def __str__(self) -> str:
        """Return the message, annotated with context and cause when present."""
        parts = [self.message]

        if self.context:
            details = ", ".join(f"{k}={v}" for k, v in self.context.items())
            parts.append(f" (Context: {details})")

        if self.original_error:
            parts.append(f" (Caused by: {self.original_error})")

        return "".join(parts)
|
||||
|
||||
|
||||
class ConfigurationError(MLSysBookError):
|
||||
"""Raised when there are configuration-related errors.
|
||||
|
||||
This includes missing configuration files, invalid configuration values,
|
||||
environment variable issues, and other configuration problems.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class ValidationError(MLSysBookError):
|
||||
"""Raised when input validation fails.
|
||||
|
||||
This includes invalid file paths, malformed data, missing required fields,
|
||||
and other input validation issues.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class FileOperationError(MLSysBookError):
|
||||
"""Raised when file operations fail.
|
||||
|
||||
This includes file not found, permission denied, disk space issues,
|
||||
and other file system related errors.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class ProcessingError(MLSysBookError):
|
||||
"""Raised when content processing fails.
|
||||
|
||||
This includes parsing errors, conversion failures, and other content
|
||||
processing related issues.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class APIError(MLSysBookError):
|
||||
"""Raised when external API calls fail.
|
||||
|
||||
This includes HTTP errors, authentication failures, rate limiting,
|
||||
and other API-related issues.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class ToolExecutionError(MLSysBookError):
|
||||
"""Raised when tool execution fails.
|
||||
|
||||
This is a general execution error for tools that encounter unexpected
|
||||
conditions during their main operation.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class DependencyError(MLSysBookError):
    """Raised when required dependencies are missing or incompatible.

    This includes missing Python packages, external tools, or version
    compatibility issues.
    Inherits the context/original_error formatting from MLSysBookError.
    """
    pass
|
||||
247
tools/scripts/common/logging_config.py
Normal file
247
tools/scripts/common/logging_config.py
Normal file
@@ -0,0 +1,247 @@
|
||||
"""
|
||||
Centralized logging configuration for MLSysBook tools.
|
||||
|
||||
This module provides a standardized logging setup with support for structured
|
||||
logging, different output formats, and configurable log levels.
|
||||
"""
|
||||
|
||||
import logging
import logging.handlers
import sys
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, Optional

from rich.console import Console
from rich.logging import RichHandler
from rich.traceback import install

from .config import get_config
|
||||
|
||||
|
||||
class StructuredFormatter(logging.Formatter):
    """Formatter that emits timestamped, pipe-delimited log lines.

    Each record is rendered as ``[timestamp] LEVEL | logger | message``,
    optionally followed by attached context and, for error records that
    carry exception info, the formatted traceback on following lines.
    """

    def __init__(self, include_context: bool = True) -> None:
        """Initialize the formatter.

        Args:
            include_context: Whether to include additional context in log records
        """
        super().__init__()
        self.include_context = include_context

    def format(self, record: logging.LogRecord) -> str:
        """Format a log record with structured information.

        Args:
            record: The record to render.

        Returns:
            The formatted log line (multi-line for errors with exc_info).
        """
        # FIX: datetime.utcnow() is deprecated (Python 3.12+) and returns a
        # naive datetime. Use an aware UTC timestamp instead; the rendered
        # value now carries an explicit "+00:00" offset.
        record.timestamp = datetime.now(timezone.utc).isoformat()

        # Attach caller-supplied context (set by LoggerAdapter) when enabled.
        if self.include_context and hasattr(record, 'context'):
            record.context_str = f" | Context: {record.context}"
        else:
            record.context_str = ""

        # Errors that carry exception info get the full traceback appended.
        if record.levelno >= logging.ERROR and record.exc_info:
            formatted = (
                f"[{record.timestamp}] {record.levelname:8} | "
                f"{record.name} | {record.getMessage()}{record.context_str}\n"
                f"Exception: {self.formatException(record.exc_info)}"
            )
        else:
            formatted = (
                f"[{record.timestamp}] {record.levelname:8} | "
                f"{record.name} | {record.getMessage()}{record.context_str}"
            )

        return formatted
|
||||
|
||||
|
||||
class ProgressAwareHandler(RichHandler):
    """Rich handler that works well with progress bars and other rich output."""

    def __init__(self, *args, **kwargs) -> None:
        """Route log output to a dedicated stderr console.

        A private console keeps log lines from colliding with rich output
        on the main console; the time and path columns are suppressed
        because the formatter supplies its own timestamp.
        """
        kwargs.update(
            console=Console(stderr=True, force_terminal=True),
            show_time=False,   # timestamp is added by our own formatter
            show_path=False,
        )
        super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
def setup_logging(
    name: Optional[str] = None,
    level: Optional[str] = None,
    log_to_file: Optional[bool] = None,
    log_file_path: Optional[Path] = None,
    enable_rich_tracebacks: bool = True,
    context: Optional[Dict[str, Any]] = None
) -> "logging.Logger | logging.LoggerAdapter":
    """Set up logging configuration for MLSysBook tools.

    Args:
        name: Logger name (defaults to this module's name)
        level: Log level name, e.g. "INFO" (defaults to config value)
        log_to_file: Whether to also log to a rotating file (defaults to config)
        log_file_path: Path for the log file (defaults to config value)
        enable_rich_tracebacks: Whether to install rich exception formatting
        context: Extra key/value context added to every message; when given,
            the return value is a LoggerAdapter wrapping the logger.

    Returns:
        A configured Logger, or a LoggerAdapter when ``context`` is provided.
        (FIX: the previous annotation promised a plain ``logging.Logger``,
        but LoggerAdapter is not a Logger subclass.)
    """
    config = get_config()

    # Explicit arguments win; otherwise fall back to the shared config.
    level = level or config.log_level
    log_to_file = log_to_file if log_to_file is not None else config.log_to_file
    log_file_path = log_file_path or config.log_file_path

    if enable_rich_tracebacks:
        # Installs rich's global excepthook; affects the whole process.
        install(show_locals=True)

    logger = logging.getLogger(name or __name__)
    logger.setLevel(getattr(logging, level.upper()))

    # Replace handlers from any previous setup_logging() call so repeated
    # setup does not duplicate output.
    logger.handlers.clear()

    # Console output via rich, with a minimal formatter (rich renders level).
    console_handler = ProgressAwareHandler()
    console_handler.setLevel(getattr(logging, level.upper()))
    console_handler.setFormatter(logging.Formatter("%(name)s | %(message)s"))
    logger.addHandler(console_handler)

    if log_to_file and log_file_path:
        # Ensure the log directory exists before the handler opens the file.
        log_file_path.parent.mkdir(parents=True, exist_ok=True)

        # Rotate at 10 MB, keeping five backups.
        file_handler = logging.handlers.RotatingFileHandler(
            log_file_path,
            maxBytes=10 * 1024 * 1024,
            backupCount=5,
            encoding='utf-8'
        )
        # Files always capture DEBUG regardless of the console level.
        file_handler.setLevel(logging.DEBUG)
        file_handler.setFormatter(StructuredFormatter(include_context=True))
        logger.addHandler(file_handler)

    if context:
        # NOTE: LoggerAdapter is defined later in this module; resolved at
        # call time, so the forward reference is safe.
        logger = LoggerAdapter(logger, context)

    return logger
|
||||
|
||||
|
||||
def get_logger(
    name: Optional[str] = None,
    context: Optional[Dict[str, Any]] = None
) -> logging.Logger:
    """Return a logger, configuring it on first use.

    Args:
        name: Logger name (defaults to this module's name)
        context: Extra key/value context added to every log message

    Returns:
        Logger instance (wrapped in an adapter when context is given)
    """
    logger = logging.getLogger(name or __name__)

    # A logger with no handlers has never been configured; do it now.
    if not logger.handlers:
        logger = setup_logging(name)

    # Wrap with the context-injecting adapter only when requested.
    return LoggerAdapter(logger, context) if context else logger
|
||||
|
||||
|
||||
class LoggerAdapter(logging.LoggerAdapter):
    """Adapter that injects fixed context into every log record."""

    def __init__(self, logger: logging.Logger, extra: Dict[str, Any]) -> None:
        """Bind context to a base logger.

        Args:
            logger: Base logger instance
            extra: Context key/value pairs added to each record
        """
        super().__init__(logger, extra)

    def process(self, msg: str, kwargs: Dict[str, Any]) -> tuple:
        """Merge the bound context into a call's ``extra`` mapping."""
        merged = kwargs.setdefault('extra', {})
        merged.update(self.extra)
        # Also expose the whole context under one key so formatters
        # (e.g. StructuredFormatter) can render it as a single unit.
        merged['context'] = self.extra
        return msg, kwargs
|
||||
|
||||
|
||||
class ProgressLogger:
    """Logger facade that degrades gracefully while progress bars run.

    While a progress operation is active, messages are written directly to
    stderr in a compact form instead of going through the logging machinery,
    which keeps them from corrupting rich progress output.
    """

    def __init__(self, name: Optional[str] = None) -> None:
        """Create the facade.

        Args:
            name: Logger name passed through to get_logger()
        """
        self.logger = get_logger(name)
        self._progress_active = False

    def start_progress(self) -> None:
        """Mark the beginning of a progress operation."""
        self._progress_active = True

    def end_progress(self) -> None:
        """Mark the end of a progress operation."""
        self._progress_active = False

    def log(self, level: str, message: str, **kwargs) -> None:
        """Emit a message, routing around logging while progress is active.

        Args:
            level: Log level name (debug, info, warning, error, critical)
            message: Log message
            **kwargs: Extra keyword arguments forwarded to the logger
        """
        if self._progress_active:
            # Plain stderr write: compact and safe next to progress bars.
            print(f"[{level.upper()}] {message}", file=sys.stderr)
        else:
            getattr(self.logger, level.lower())(message, **kwargs)

    def debug(self, message: str, **kwargs) -> None:
        """Log a debug message."""
        self.log('debug', message, **kwargs)

    def info(self, message: str, **kwargs) -> None:
        """Log an info message."""
        self.log('info', message, **kwargs)

    def warning(self, message: str, **kwargs) -> None:
        """Log a warning message."""
        self.log('warning', message, **kwargs)

    def error(self, message: str, **kwargs) -> None:
        """Log an error message."""
        self.log('error', message, **kwargs)

    def critical(self, message: str, **kwargs) -> None:
        """Log a critical message."""
        self.log('critical', message, **kwargs)
|
||||
419
tools/scripts/common/validators.py
Normal file
419
tools/scripts/common/validators.py
Normal file
@@ -0,0 +1,419 @@
|
||||
"""
|
||||
Input validation utilities for MLSysBook tools.
|
||||
|
||||
This module provides comprehensive validation functions for common input types
|
||||
including file paths, configuration values, and data structures.
|
||||
"""
|
||||
|
||||
import re
|
||||
import json
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union, Callable, Type
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from .exceptions import ValidationError
|
||||
|
||||
|
||||
def validate_file_path(
    path: Union[str, Path],
    must_exist: bool = True,
    must_be_file: bool = True,
    must_be_readable: bool = True,
    allowed_extensions: Optional[List[str]] = None
) -> Path:
    """Validate a file path.

    Args:
        path: File path to validate
        must_exist: Whether the file must exist
        must_be_file: Whether the path must be a file (not a directory)
        must_be_readable: Whether the file must be readable as UTF-8 text
        allowed_extensions: Allowed file extensions (e.g. ['.qmd', '.md']),
            compared case-insensitively

    Returns:
        Resolved, validated Path object

    Raises:
        ValidationError: If any requested check fails
    """
    if not path:
        raise ValidationError("File path cannot be empty")

    # resolve() normalizes the path and eliminates '..' components, so no
    # separate traversal check is needed.  (FIX: the previous version had a
    # try/except "traversal check" whose every branch was a no-op; that dead
    # code has been removed.)
    path_obj = Path(path).resolve()

    if must_exist and not path_obj.exists():
        raise ValidationError(f"File does not exist: {path_obj}")

    if must_exist and must_be_file and not path_obj.is_file():
        raise ValidationError(f"Path is not a file: {path_obj}")

    if must_exist and must_be_readable:
        # Reading one character verifies both permissions and UTF-8 encoding.
        try:
            with open(path_obj, 'r', encoding='utf-8') as f:
                f.read(1)
        except PermissionError as e:
            raise ValidationError(f"File is not readable: {path_obj}") from e
        except UnicodeDecodeError as e:
            raise ValidationError(f"File is not valid UTF-8: {path_obj}") from e

    if allowed_extensions:
        # Set membership avoids re-lowering the whole list per comparison.
        allowed = {ext.lower() for ext in allowed_extensions}
        if path_obj.suffix.lower() not in allowed:
            raise ValidationError(
                f"File extension {path_obj.suffix} not allowed. "
                f"Allowed extensions: {allowed_extensions}"
            )

    return path_obj
|
||||
|
||||
|
||||
def validate_directory_path(
    path: Union[str, Path],
    must_exist: bool = True,
    create_if_missing: bool = False,
    must_be_writable: bool = False
) -> Path:
    """Validate a directory path.

    Args:
        path: Directory path to validate
        must_exist: Whether the directory must already exist
        create_if_missing: Create the directory (and parents) if absent
        must_be_writable: Verify the directory accepts new files

    Returns:
        Resolved, validated Path object

    Raises:
        ValidationError: If validation fails
    """
    if not path:
        raise ValidationError("Directory path cannot be empty")

    dir_path = Path(path).resolve()

    if not dir_path.exists():
        if create_if_missing:
            try:
                dir_path.mkdir(parents=True, exist_ok=True)
            except Exception as e:
                raise ValidationError(f"Cannot create directory {dir_path}: {e}")
        elif must_exist:
            raise ValidationError(f"Directory does not exist: {dir_path}")

    if dir_path.exists() and not dir_path.is_dir():
        raise ValidationError(f"Path is not a directory: {dir_path}")

    if must_be_writable and dir_path.exists():
        # Probe writability by touching and removing a marker file.
        probe = dir_path / '.write_test'
        try:
            probe.touch()
            probe.unlink()
        except Exception:
            raise ValidationError(f"Directory is not writable: {dir_path}")

    return dir_path
|
||||
|
||||
|
||||
def validate_url(url: str, allowed_schemes: Optional[List[str]] = None) -> str:
    """Validate a URL.

    Args:
        url: URL to validate
        allowed_schemes: Permitted URL schemes (e.g. ['http', 'https'])

    Returns:
        The URL unchanged when valid

    Raises:
        ValidationError: If validation fails
    """
    if not url:
        raise ValidationError("URL cannot be empty")

    if not isinstance(url, str):
        raise ValidationError("URL must be a string")

    try:
        parts = urlparse(url)
    except Exception as e:
        raise ValidationError(f"Invalid URL format: {e}")

    # Both a scheme and a host are required for a usable absolute URL.
    if not parts.scheme:
        raise ValidationError("URL must include a scheme (http, https, etc.)")
    if not parts.netloc:
        raise ValidationError("URL must include a network location")

    if allowed_schemes and parts.scheme not in allowed_schemes:
        raise ValidationError(
            f"URL scheme '{parts.scheme}' not allowed. "
            f"Allowed schemes: {allowed_schemes}"
        )

    return url
|
||||
|
||||
|
||||
def validate_json_data(
    data: Any,
    schema: Optional[Dict[str, Any]] = None,
    required_keys: Optional[List[str]] = None
) -> Any:
    """Validate JSON data structure.

    Args:
        data: Data to validate
        schema: Optional JSON schema to validate against (needs jsonschema)
        required_keys: Keys that must be present when data is a dict

    Returns:
        The data unchanged when valid

    Raises:
        ValidationError: If validation fails or jsonschema is unavailable
    """
    if schema:
        # jsonschema is an optional dependency; import lazily so the common
        # schema-less path never pays for it.
        try:
            import jsonschema
            jsonschema.validate(data, schema)
        except ImportError:
            raise ValidationError("jsonschema package required for schema validation")
        except jsonschema.ValidationError as e:
            raise ValidationError(f"JSON schema validation failed: {e.message}")

    if required_keys and isinstance(data, dict):
        absent = [key for key in required_keys if key not in data]
        if absent:
            raise ValidationError(f"Missing required keys: {absent}")

    return data
|
||||
|
||||
|
||||
def validate_string(
    value: Any,
    min_length: Optional[int] = None,
    max_length: Optional[int] = None,
    pattern: Optional[str] = None,
    allowed_values: Optional[List[str]] = None
) -> str:
    """Validate a string value.

    Args:
        value: Value to validate
        min_length: Minimum length, inclusive
        max_length: Maximum length, inclusive
        pattern: Regex the string must match at its start (re.match)
        allowed_values: Closed set of acceptable values

    Returns:
        The validated string

    Raises:
        ValidationError: If validation fails
    """
    if not isinstance(value, str):
        raise ValidationError(f"Expected string, got {type(value).__name__}")

    length = len(value)
    if min_length is not None and length < min_length:
        raise ValidationError(f"String too short. Minimum length: {min_length}")
    if max_length is not None and length > max_length:
        raise ValidationError(f"String too long. Maximum length: {max_length}")

    # re.match anchors at the start only; the whole string need not match.
    if pattern and not re.match(pattern, value):
        raise ValidationError(f"String does not match pattern: {pattern}")

    if allowed_values and value not in allowed_values:
        raise ValidationError(f"Value '{value}' not in allowed values: {allowed_values}")

    return value
|
||||
|
||||
|
||||
def validate_number(
    value: Any,
    min_value: Optional[Union[int, float]] = None,
    max_value: Optional[Union[int, float]] = None,
    number_type: Type = float
) -> Union[int, float]:
    """Validate a numeric value.

    Args:
        value: Value to validate (converted via int() or float())
        min_value: Inclusive lower bound
        max_value: Inclusive upper bound
        number_type: Target type, int or float (default float)

    Returns:
        The converted, range-checked number

    Raises:
        ValidationError: If conversion or a range check fails
    """
    converter = int if number_type == int else float
    try:
        result = converter(value)
    except (ValueError, TypeError):
        raise ValidationError(f"Cannot convert '{value}' to {number_type.__name__}")

    if min_value is not None and result < min_value:
        raise ValidationError(f"Value {result} below minimum: {min_value}")
    if max_value is not None and result > max_value:
        raise ValidationError(f"Value {result} above maximum: {max_value}")

    return result
|
||||
|
||||
|
||||
def validate_list(
    value: Any,
    item_validator: Optional[Callable] = None,
    min_items: Optional[int] = None,
    max_items: Optional[int] = None,
    unique_items: bool = False
) -> List[Any]:
    """Validate a list value.

    Args:
        value: Value to validate
        item_validator: Callable applied to each item; its return values
            form the result list
        min_items: Minimum item count, inclusive
        max_items: Maximum item count, inclusive
        unique_items: Require all items distinct (items must be hashable)

    Returns:
        The validated list (transformed when item_validator is given)

    Raises:
        ValidationError: If validation fails
    """
    if not isinstance(value, list):
        raise ValidationError(f"Expected list, got {type(value).__name__}")

    count = len(value)
    if min_items is not None and count < min_items:
        raise ValidationError(f"Too few items. Minimum: {min_items}")
    if max_items is not None and count > max_items:
        raise ValidationError(f"Too many items. Maximum: {max_items}")

    # Duplicate detection via set; items must therefore be hashable.
    if unique_items and count != len(set(value)):
        raise ValidationError("List items must be unique")

    if item_validator:
        checked = []
        for index, element in enumerate(value):
            try:
                checked.append(item_validator(element))
            except ValidationError as e:
                raise ValidationError(f"Item {index} validation failed: {e}")
        return checked

    return value
|
||||
|
||||
|
||||
def validate_config_file(file_path: Union[str, Path]) -> Dict[str, Any]:
    """Validate and load a YAML or JSON configuration file.

    Args:
        file_path: Path to the configuration file (.yaml/.yml/.json)

    Returns:
        The parsed configuration mapping

    Raises:
        ValidationError: If the file is missing, unparsable, or does not
            contain a mapping at the top level
    """
    config_path = validate_file_path(
        file_path,
        allowed_extensions=['.yaml', '.yml', '.json']
    )

    try:
        with open(config_path, 'r', encoding='utf-8') as fh:
            if config_path.suffix.lower() == '.json':
                loaded = json.load(fh)
            else:
                # safe_load refuses arbitrary Python object construction.
                loaded = yaml.safe_load(fh)
    except json.JSONDecodeError as e:
        raise ValidationError(f"Invalid JSON in config file: {e}")
    except yaml.YAMLError as e:
        raise ValidationError(f"Invalid YAML in config file: {e}")
    except Exception as e:
        raise ValidationError(f"Cannot read config file: {e}")

    if not isinstance(loaded, dict):
        raise ValidationError("Configuration file must contain a dictionary/object")

    return loaded
|
||||
|
||||
|
||||
class Validator:
    """Fluent validation interface for chaining multiple checks."""

    def __init__(self, value: Any, name: str = "value") -> None:
        """Wrap a value for fluent validation.

        Args:
            value: Value to validate
            name: Label used in error messages
        """
        self.value = value
        self.name = name

    def is_string(self, **kwargs) -> 'Validator':
        """Assert the value is a string; kwargs as for validate_string."""
        self.value = validate_string(self.value, **kwargs)
        return self

    def is_number(self, **kwargs) -> 'Validator':
        """Assert the value is numeric; kwargs as for validate_number."""
        self.value = validate_number(self.value, **kwargs)
        return self

    def is_list(self, **kwargs) -> 'Validator':
        """Assert the value is a list; kwargs as for validate_list."""
        self.value = validate_list(self.value, **kwargs)
        return self

    def is_file_path(self, **kwargs) -> 'Validator':
        """Assert the value is a file path; kwargs as for validate_file_path."""
        self.value = validate_file_path(self.value, **kwargs)
        return self

    def is_directory_path(self, **kwargs) -> 'Validator':
        """Assert the value is a directory path; kwargs as for validate_directory_path."""
        self.value = validate_directory_path(self.value, **kwargs)
        return self

    def is_url(self, **kwargs) -> 'Validator':
        """Assert the value is a URL; kwargs as for validate_url."""
        self.value = validate_url(self.value, **kwargs)
        return self

    def get(self) -> Any:
        """Return the (possibly converted) validated value."""
        return self.value
|
||||
|
||||
|
||||
def validate(value: Any, name: str = "value") -> Validator:
    """Create a new validator for fluent validation.

    Entry point for chains such as ``validate(x).is_string().get()``.

    Args:
        value: Value to validate
        name: Name of the value for error messages

    Returns:
        Validator instance wrapping ``value``
    """
    return Validator(value, name)
|
||||
Reference in New Issue
Block a user