feat(tools): add shared infrastructure for MLSysBook tool development

- Add pyproject.toml with comprehensive Python package configuration
- Create tools/scripts/common/ package with shared utilities:
  - base_classes.py: Abstract base classes following SOLID principles
  - config.py: Centralized configuration management with environment support
  - exceptions.py: Custom exception hierarchy for better error handling
  - logging_config.py: Standardized logging setup across all tools
  - validators.py: Input validation utilities with detailed error reporting

This establishes a proper foundation for building maintainable,
production-grade tools following software engineering best practices
as outlined in the updated .cursorrules guidelines.
Author: Vijay Janapa Reddi
Date: 2025-08-09 08:49:33 -04:00
Commit: bdedbc78bb (parent: 7bffdfd70c)
7 changed files with 1663 additions and 0 deletions
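
For orientation, here is a rough sketch (hypothetical, not part of this commit) of the kind of tool the shared infrastructure is meant to support; EchoTool and its behavior are illustrative only:

    from tools.scripts.common.base_classes import CLITool, ToolResult

    class EchoTool(CLITool):
        """Hypothetical tool that reports the arguments it was invoked with."""

        def run(self, **kwargs) -> ToolResult:
            return ToolResult(success=True, message=f"ran with {kwargs}")

    # run_cli() parses argv, applies --log-level, validates inputs, then calls run()
    result = EchoTool(name="echo").run_cli(["--dry-run"])
    print(result.message)  # ran with {'verbose': False, 'dry_run': True, 'log_level': 'INFO'}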

pyproject.toml (new file)

@@ -0,0 +1,266 @@
[build-system]
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"
[project]
name = "mlsysbook"
version = "1.0.0"
description = "Machine Learning Systems Textbook - Tools and Scripts"
readme = "README.md"
requires-python = ">=3.9"
license = {text = "MIT"}
authors = [
{name = "MLSysBook Contributors", email = "info@mlsysbook.ai"}
]
keywords = ["machine-learning", "systems", "textbook", "education", "ai"]
classifiers = [
"Development Status :: 4 - Beta",
"Intended Audience :: Education",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Topic :: Education",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Documentation",
"Topic :: Text Processing :: Markup",
]
dependencies = [
# Core dependencies for Quarto/Jupyter integration
"jupyterlab-quarto>=0.3.0",
"jupyter>=1.0.0",
# Bibliography and document processing
"pybtex>=0.24.0",
"pypandoc>=1.11",
"pyyaml>=6.0",
# Data processing and validation
"pandas>=2.0.0",
"numpy>=1.24.0",
"jsonschema>=4.0.0",
# Image processing
"Pillow>=9.0.0",
# HTTP and API interactions
"requests>=2.31.0",
# Text processing
"titlecase>=2.4.1",
# Terminal output and UI
"rich>=13.0.0",
# Type checking and validation
"typing-extensions>=4.5.0",
# Utilities
"click>=8.0.0",
"pathlib2>=2.3.0; python_version<'3.11'",
]
[project.optional-dependencies]
dev = [
# Testing
"pytest>=7.0.0",
"pytest-cov>=4.0.0",
"pytest-mock>=3.10.0",
"pytest-asyncio>=0.21.0",
# Code quality
"black>=23.0.0",
"isort>=5.12.0",
"flake8>=6.0.0",
"mypy>=1.0.0",
"pylint>=2.17.0",
# Security
"bandit>=1.7.0",
"safety>=2.3.0",
# Documentation
"sphinx>=6.0.0",
"sphinx-rtd-theme>=1.2.0",
# Pre-commit hooks
"pre-commit>=3.0.0",
]
ai = [
# AI/ML specific dependencies
"openai>=1.0.0",
"sentence-transformers>=2.2.0",
"transformers>=4.21.0",
"torch>=1.12.0",
"scikit-learn>=1.3.0",
"gradio>=3.40.0",
"ollama>=0.1.0",
]
build = [
# Build and publishing tools
"twine>=4.0.0",
"build>=0.10.0",
"setuptools>=61.0",
"wheel>=0.40.0",
]
[project.urls]
Homepage = "https://mlsysbook.ai"
Documentation = "https://docs.mlsysbook.ai"
Repository = "https://github.com/mlsysbook/mlsysbook"
Issues = "https://github.com/mlsysbook/mlsysbook/issues"
Changelog = "https://github.com/mlsysbook/mlsysbook/blob/main/CHANGELOG.md"
[project.scripts]
mlsysbook-cli = "tools.scripts.cli:main"
[tool.setuptools.packages.find]
where = ["."]
include = ["tools*"]
exclude = ["tests*", "docs*"]
[tool.setuptools.package-data]
"tools.scripts" = ["**/*.yaml", "**/*.yml", "**/*.json", "**/*.txt", "**/*.md"]
# Black formatting configuration
[tool.black]
line-length = 100
target-version = ['py39', 'py310', 'py311', 'py312']
include = '\.pyi?$'
extend-exclude = '''
/(
# Exclude build artifacts
\.eggs
| \.git
| \.hg
| \.mypy_cache
| \.tox
| \.venv
| _build
| buck-out
| build
| dist
# Exclude generated files
| contents/
| quarto/
)/
'''
# isort import sorting configuration
[tool.isort]
profile = "black"
line_length = 100
multi_line_output = 3
include_trailing_comma = true
force_grid_wrap = 0
use_parentheses = true
ensure_newline_before_comments = true
src_paths = ["tools"]
skip_glob = ["contents/*", "quarto/*"]
# MyPy type checking configuration
[tool.mypy]
python_version = "3.9"
warn_return_any = true
warn_unused_configs = true
disallow_untyped_defs = true
disallow_incomplete_defs = true
check_untyped_defs = true
disallow_untyped_decorators = true
no_implicit_optional = true
warn_redundant_casts = true
warn_unused_ignores = true
warn_no_return = true
warn_unreachable = true
strict_equality = true
show_error_codes = true
[[tool.mypy.overrides]]
module = [
"pypandoc.*",
"pybtex.*",
"titlecase.*",
"gradio.*",
"ollama.*",
]
ignore_missing_imports = true
# Pytest configuration
[tool.pytest.ini_options]
minversion = "7.0"
addopts = [
"--strict-markers",
"--strict-config",
"--cov=tools",
"--cov-report=term-missing",
"--cov-report=html",
"--cov-fail-under=80"
]
testpaths = ["tests"]
python_files = ["test_*.py", "*_test.py"]
python_classes = ["Test*"]
python_functions = ["test_*"]
markers = [
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
"integration: marks tests as integration tests",
"unit: marks tests as unit tests",
"ai: marks tests that require AI models",
]
# Coverage configuration
[tool.coverage.run]
source = ["tools"]
branch = true
omit = [
"*/tests/*",
"*/test_*",
"*/__pycache__/*",
"*/.*",
]
[tool.coverage.report]
exclude_lines = [
"pragma: no cover",
"def __repr__",
"if self.debug:",
"if settings.DEBUG",
"raise AssertionError",
"raise NotImplementedError",
"if 0:",
"if __name__ == .__main__.:",
"class .*\\bProtocol\\):",
"@(abc\\.)?abstractmethod",
]
# Pylint configuration
[tool.pylint.messages_control]
disable = [
"C0330", # Wrong hanging indentation before block (conflicts with black)
"C0326", # Bad whitespace (conflicts with black)
"R0903", # Too few public methods (sometimes OK for data classes)
"R0913", # Too many arguments (sometimes necessary)
]
[tool.pylint.format]
max-line-length = "100"
[tool.pylint.design]
max-args = 8
max-locals = 20
max-branches = 15
max-statements = 60
# Bandit security linting
[tool.bandit]
exclude_dirs = ["tests", "contents", "quarto"]
skips = ["B101", "B602"] # Skip assert_used and subprocess-with-shell=True (sometimes needed)
# Flake8 configuration (in setup.cfg or .flake8 file)
# Note: flake8 doesn't support pyproject.toml yet
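
Since the [tool.pytest.ini_options] block above registers custom markers and enforces --strict-markers, test modules must use the declared names. A minimal sketch (the test names and bodies are illustrative):

    import pytest

    @pytest.mark.unit
    def test_caption_length_bounds():
        # Mirrors the min/max caption thresholds configured for the tools.
        assert 10 <= 120 <= 500

    @pytest.mark.slow
    @pytest.mark.integration
    def test_full_content_pipeline():
        ...  # deselect in quick runs with: pytest -m "not slow"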

tools/scripts/common/__init__.py (new file)

@@ -0,0 +1,29 @@
"""
Common utilities and shared components for MLSysBook tools.
This package provides shared functionality across all tools in the MLSysBook project,
including base classes, configuration management, logging setup, and common utilities.
Modules:
base_classes: Abstract base classes for tools and processors
config: Configuration management and environment handling
exceptions: Custom exception definitions
logging_config: Centralized logging configuration
validators: Input validation utilities
file_utils: File and path operation utilities
"""
from .exceptions import MLSysBookError, ConfigurationError, ValidationError
from .config import get_config, Config
from .logging_config import setup_logging, get_logger
__version__ = "1.0.0"
__all__ = [
"MLSysBookError",
"ConfigurationError",
"ValidationError",
"get_config",
"Config",
"setup_logging",
"get_logger",
]
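
The package __init__ above defines the public import surface; assuming the repository root is on PYTHONPATH, typical consumer code would look like:

    from tools.scripts.common import get_config, get_logger

    config = get_config()                 # lazily-created global Config
    logger = get_logger("tools.example")
    logger.info(f"content root: {config.content_root}")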

tools/scripts/common/base_classes.py (new file)

@@ -0,0 +1,388 @@
"""
Abstract base classes for MLSysBook tools.
This module provides abstract base classes that define common interfaces
and patterns for tools in the MLSysBook project.
"""
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union
from pathlib import Path
import argparse
from dataclasses import dataclass
from .config import get_config
from .logging_config import get_logger
from .exceptions import ToolExecutionError, ValidationError
@dataclass
class ToolResult:
"""Standard result object for tool operations.
Attributes:
success: Whether the operation was successful
message: Human-readable result message
data: Optional result data
errors: List of error messages
warnings: List of warning messages
metadata: Additional metadata about the operation
"""
success: bool
message: str
data: Optional[Any] = None
errors: Optional[List[str]] = None
warnings: Optional[List[str]] = None
metadata: Optional[Dict[str, Any]] = None
def __post_init__(self) -> None:
"""Initialize default values."""
if self.errors is None:
self.errors = []
if self.warnings is None:
self.warnings = []
if self.metadata is None:
self.metadata = {}
class BaseTool(ABC):
"""Abstract base class for all MLSysBook tools.
This class provides a common interface and shared functionality for
command-line tools and processing utilities.
"""
def __init__(
self,
name: Optional[str] = None,
description: Optional[str] = None,
config: Optional[Any] = None
) -> None:
"""Initialize the tool.
Args:
name: Tool name (defaults to class name)
description: Tool description
config: Configuration object (defaults to global config)
"""
self.name = name or self.__class__.__name__
self.description = description or self.__class__.__doc__ or "MLSysBook tool"
self.config = config or get_config()
self.logger = get_logger(f"tools.{self.name}")
# Initialize state
self._initialized = False
self._results: List[ToolResult] = []
@abstractmethod
def run(self, *args, **kwargs) -> ToolResult:
"""Run the tool with the provided arguments.
This is the main entry point for tool execution.
Returns:
ToolResult object with operation results
"""
pass
def validate_inputs(self, *args, **kwargs) -> None:
"""Validate tool inputs.
Override this method to provide input validation specific to your tool.
Raises:
ValidationError: If validation fails
"""
pass
def initialize(self) -> None:
"""Initialize the tool.
Override this method to perform one-time setup operations.
"""
if self._initialized:
return
self.logger.debug(f"Initializing tool: {self.name}")
self._initialized = True
def cleanup(self) -> None:
"""Clean up resources used by the tool.
Override this method to perform cleanup operations.
"""
self.logger.debug(f"Cleaning up tool: {self.name}")
def add_result(self, result: ToolResult) -> None:
"""Add a result to the tool's result history.
Args:
result: ToolResult to add
"""
self._results.append(result)
def get_results(self) -> List[ToolResult]:
"""Get all results from tool execution.
Returns:
List of ToolResult objects
"""
return self._results.copy()
def get_last_result(self) -> Optional[ToolResult]:
"""Get the most recent result.
Returns:
Most recent ToolResult or None if no results
"""
return self._results[-1] if self._results else None
def __enter__(self):
"""Context manager entry."""
self.initialize()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Context manager exit."""
self.cleanup()
class FileProcessorTool(BaseTool):
"""Base class for tools that process files.
This class provides common functionality for tools that work with
files and directories.
"""
def __init__(self, *args, **kwargs) -> None:
"""Initialize the file processor tool."""
super().__init__(*args, **kwargs)
self.processed_files: List[Path] = []
self.failed_files: List[Path] = []
@abstractmethod
def process_file(self, file_path: Path) -> ToolResult:
"""Process a single file.
Args:
file_path: Path to the file to process
Returns:
ToolResult object with processing results
"""
pass
def process_directory(
self,
directory: Path,
pattern: str = "**/*",
recursive: bool = True
) -> ToolResult:
"""Process all matching files in a directory.
Args:
directory: Directory to process
pattern: Glob pattern for file matching
recursive: Whether to process subdirectories
Returns:
ToolResult object with overall processing results
"""
self.logger.info(f"Processing directory: {directory}")
try:
if recursive:
files = list(directory.rglob(pattern))
else:
files = list(directory.glob(pattern))
files = [f for f in files if f.is_file()]
results = []
for file_path in files:
try:
result = self.process_file(file_path)
results.append(result)
if result.success:
self.processed_files.append(file_path)
else:
self.failed_files.append(file_path)
except Exception as e:
self.logger.error(f"Error processing {file_path}: {e}")
self.failed_files.append(file_path)
results.append(ToolResult(
success=False,
message=f"Failed to process {file_path}: {e}",
errors=[str(e)]
))
successful = len(self.processed_files)
failed = len(self.failed_files)
return ToolResult(
success=failed == 0,
message=f"Processed {successful} files, {failed} failed",
data={
"processed_files": self.processed_files,
"failed_files": self.failed_files,
"results": results
},
metadata={
"total_files": len(files),
"successful_count": successful,
"failed_count": failed
}
)
except Exception as e:
error_msg = f"Failed to process directory {directory}: {e}"
self.logger.error(error_msg)
return ToolResult(
success=False,
message=error_msg,
errors=[str(e)]
)
class CLITool(BaseTool):
"""Base class for command-line interface tools.
This class provides argument parsing and common CLI patterns.
"""
def __init__(self, *args, **kwargs) -> None:
"""Initialize the CLI tool."""
super().__init__(*args, **kwargs)
self.parser = self.create_argument_parser()
def create_argument_parser(self) -> argparse.ArgumentParser:
"""Create the argument parser for this tool.
Override this method to add tool-specific arguments.
Returns:
Configured ArgumentParser
"""
parser = argparse.ArgumentParser(
prog=self.name,
description=self.description
)
# Add common arguments
parser.add_argument(
"--verbose", "-v",
action="store_true",
help="Enable verbose output"
)
parser.add_argument(
"--dry-run",
action="store_true",
help="Show what would be done without making changes"
)
parser.add_argument(
"--log-level",
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
default="INFO",
help="Set the logging level"
)
return parser
def parse_args(self, args: Optional[List[str]] = None) -> argparse.Namespace:
"""Parse command-line arguments.
Args:
args: Arguments to parse (defaults to sys.argv)
Returns:
Parsed arguments namespace
"""
return self.parser.parse_args(args)
def run_cli(self, args: Optional[List[str]] = None) -> ToolResult:
"""Run the tool as a CLI application.
Args:
args: Command-line arguments
Returns:
ToolResult object
"""
try:
parsed_args = self.parse_args(args)
# Update logging level if specified
if hasattr(parsed_args, 'log_level'):
self.logger.setLevel(parsed_args.log_level)
# Run input validation
self.validate_inputs(**vars(parsed_args))
# Run the tool
return self.run(**vars(parsed_args))
except ValidationError as e:
error_msg = f"Input validation failed: {e}"
self.logger.error(error_msg)
return ToolResult(
success=False,
message=error_msg,
errors=[str(e)]
)
except Exception as e:
error_msg = f"Tool execution failed: {e}"
self.logger.error(error_msg, exc_info=True)
return ToolResult(
success=False,
message=error_msg,
errors=[str(e)]
)
class ContentProcessor(ABC):
"""Abstract base class for content processing components.
This class defines the interface for components that process
textbook content (markdown, images, etc.).
"""
@abstractmethod
def can_process(self, content_type: str, file_path: Path) -> bool:
"""Check if this processor can handle the given content.
Args:
content_type: Type of content (e.g., 'markdown', 'image')
file_path: Path to the content file
Returns:
True if this processor can handle the content
"""
pass
@abstractmethod
def process(self, file_path: Path, **kwargs) -> ToolResult:
"""Process the content.
Args:
file_path: Path to the content file
**kwargs: Additional processing options
Returns:
ToolResult object with processing results
"""
pass
def get_priority(self) -> int:
"""Get the processor priority.
Higher numbers indicate higher priority. When multiple processors
can handle the same content, the one with the highest priority is used.
Returns:
Priority value (default: 0)
"""
return 0
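
To make the FileProcessorTool contract concrete, a hedged sketch of a subclass (the class name, path, and glob pattern below are illustrative):

    from pathlib import Path
    from tools.scripts.common.base_classes import FileProcessorTool, ToolResult

    class QmdLineCounter(FileProcessorTool):
        """Hypothetical processor that counts lines in .qmd files."""

        def run(self, **kwargs) -> ToolResult:
            return self.process_directory(Path("quarto/contents"), pattern="*.qmd")

        def process_file(self, file_path: Path) -> ToolResult:
            n = len(file_path.read_text(encoding="utf-8").splitlines())
            return ToolResult(success=True, message=f"{file_path}: {n} lines")

    # The context manager drives initialize()/cleanup() around the run.
    with QmdLineCounter() as tool:
        summary = tool.run()
        print(summary.message)  # e.g. "Processed 12 files, 0 failed"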

tools/scripts/common/config.py (new file)

@@ -0,0 +1,206 @@
"""
Configuration management for MLSysBook tools.
This module provides centralized configuration management with support for
environment variables, configuration files, and sensible defaults.
"""
import os
import json
import yaml
from pathlib import Path
from typing import Any, Dict, Optional, Union
from dataclasses import dataclass, field
from .exceptions import ConfigurationError
@dataclass
class Config:
"""Central configuration class for MLSysBook tools.
This class manages all configuration settings with support for environment
variables, configuration files, and programmatic overrides.
"""
# Project paths
project_root: Path = field(default_factory=lambda: Path(__file__).parents[3])
quarto_root: Path = field(default_factory=lambda: Path(__file__).parents[3] / "quarto")
tools_root: Path = field(default_factory=lambda: Path(__file__).parents[1])
content_root: Path = field(default_factory=lambda: Path(__file__).parents[3] / "quarto" / "contents")
# Logging configuration
log_level: str = field(default="INFO")
log_format: str = field(default="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
log_to_file: bool = field(default=False)
log_file_path: Optional[Path] = field(default=None)
# API configurations
openai_api_key: Optional[str] = field(default=None)
ollama_base_url: str = field(default="http://localhost:11434")
ollama_model: str = field(default="llama3.1:8b")
# Processing settings
max_workers: int = field(default=4)
chunk_size: int = field(default=1000)
enable_caching: bool = field(default=True)
cache_dir: Path = field(default_factory=lambda: Path.home() / ".cache" / "mlsysbook")
# Content processing settings
backup_enabled: bool = field(default=True)
backup_dir: Path = field(default_factory=lambda: Path(__file__).parents[3] / "backups")
dry_run: bool = field(default=False)
# Quality thresholds
similarity_threshold: float = field(default=0.65)
min_caption_length: int = field(default=10)
max_caption_length: int = field(default=500)
def __post_init__(self) -> None:
"""Initialize configuration after dataclass creation."""
self._load_from_environment()
self._load_from_config_file()
self._validate_configuration()
def _load_from_environment(self) -> None:
"""Load configuration from environment variables."""
env_mappings = {
"MLSYSBOOK_LOG_LEVEL": "log_level",
"MLSYSBOOK_LOG_TO_FILE": "log_to_file",
"MLSYSBOOK_DRY_RUN": "dry_run",
"MLSYSBOOK_BACKUP_ENABLED": "backup_enabled",
"MLSYSBOOK_MAX_WORKERS": "max_workers",
"OPENAI_API_KEY": "openai_api_key",
"OLLAMA_BASE_URL": "ollama_base_url",
"OLLAMA_MODEL": "ollama_model",
}
for env_var, attr_name in env_mappings.items():
env_value = os.getenv(env_var)
if env_value is not None:
# Convert string values to appropriate types
if hasattr(self, attr_name):
current_value = getattr(self, attr_name)
if isinstance(current_value, bool):
setattr(self, attr_name, env_value.lower() in ("true", "1", "yes", "on"))
elif isinstance(current_value, int):
try:
setattr(self, attr_name, int(env_value))
except ValueError:
raise ConfigurationError(
f"Invalid integer value for {env_var}: {env_value}"
)
elif isinstance(current_value, float):
try:
setattr(self, attr_name, float(env_value))
except ValueError:
raise ConfigurationError(
f"Invalid float value for {env_var}: {env_value}"
)
else:
setattr(self, attr_name, env_value)
def _load_from_config_file(self) -> None:
"""Load configuration from config file if it exists."""
config_paths = [
self.project_root / "mlsysbook.yaml",
self.project_root / "mlsysbook.yml",
self.project_root / ".mlsysbook.yaml",
self.project_root / ".mlsysbook.yml",
]
for config_path in config_paths:
if config_path.exists():
try:
with open(config_path, 'r', encoding='utf-8') as f:
config_data = yaml.safe_load(f)
if config_data:
self._apply_config_dict(config_data)
break
except Exception as e:
raise ConfigurationError(
f"Failed to load configuration from {config_path}: {e}"
)
def _apply_config_dict(self, config_data: Dict[str, Any]) -> None:
"""Apply configuration from a dictionary."""
for key, value in config_data.items():
if hasattr(self, key):
# Convert path strings to Path objects
if key.endswith(('_root', '_dir', '_path')) and isinstance(value, str):
setattr(self, key, Path(value))
else:
setattr(self, key, value)
def _validate_configuration(self) -> None:
"""Validate configuration values."""
# Ensure directories exist or can be created
for attr_name in ['cache_dir', 'backup_dir']:
dir_path = getattr(self, attr_name)
if dir_path and not dir_path.exists():
try:
dir_path.mkdir(parents=True, exist_ok=True)
except Exception as e:
raise ConfigurationError(
f"Cannot create directory {dir_path}: {e}"
)
# Validate numeric ranges
if self.max_workers < 1:
raise ConfigurationError("max_workers must be at least 1")
if not 0.0 <= self.similarity_threshold <= 1.0:
raise ConfigurationError("similarity_threshold must be between 0.0 and 1.0")
if self.min_caption_length < 1:
raise ConfigurationError("min_caption_length must be at least 1")
if self.max_caption_length < self.min_caption_length:
raise ConfigurationError("max_caption_length must be >= min_caption_length")
def to_dict(self) -> Dict[str, Any]:
"""Convert configuration to dictionary."""
result = {}
for key, value in self.__dict__.items():
if isinstance(value, Path):
result[key] = str(value)
else:
result[key] = value
return result
def save_to_file(self, path: Union[str, Path]) -> None:
"""Save configuration to a YAML file."""
path = Path(path)
config_dict = self.to_dict()
try:
with open(path, 'w', encoding='utf-8') as f:
yaml.dump(config_dict, f, default_flow_style=False, sort_keys=True)
except Exception as e:
raise ConfigurationError(f"Failed to save configuration to {path}: {e}")
# Global configuration instance
_config: Optional[Config] = None
def get_config() -> Config:
"""Get the global configuration instance.
Returns:
The global Config instance, creating it if necessary.
"""
global _config
if _config is None:
_config = Config()
return _config
def reset_config() -> None:
"""Reset the global configuration instance.
This is primarily useful for testing.
"""
global _config
_config = None
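
One consequence of the __post_init__ ordering worth noting: environment variables are applied first and a mlsysbook.yaml second, so file values win over the environment. A sketch of the environment path (assuming no config file is present to override it):

    import os
    from tools.scripts.common.config import get_config, reset_config

    os.environ["MLSYSBOOK_MAX_WORKERS"] = "8"   # coerced to int
    os.environ["MLSYSBOOK_DRY_RUN"] = "yes"     # truthy strings: true/1/yes/on

    reset_config()           # drop the cached instance
    config = get_config()    # re-reads the environment on construction
    assert config.max_workers == 8 and config.dry_run is True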

tools/scripts/common/exceptions.py (new file)

@@ -0,0 +1,108 @@
"""
Custom exception definitions for MLSysBook tools.
This module defines a hierarchy of custom exceptions to provide clear error handling
and better debugging information across all tools in the MLSysBook project.
"""
from typing import Optional, Any, Dict
class MLSysBookError(Exception):
"""Base exception for all MLSysBook-related errors.
This is the root exception class that all other custom exceptions inherit from.
It provides consistent error handling and optional context information.
Args:
message: Human-readable error message
context: Optional dictionary containing error context information
original_error: Original exception that caused this error (if any)
"""
def __init__(
self,
message: str,
context: Optional[Dict[str, Any]] = None,
original_error: Optional[Exception] = None
) -> None:
super().__init__(message)
self.message = message
self.context = context or {}
self.original_error = original_error
def __str__(self) -> str:
"""Return a formatted error message with context."""
error_msg = self.message
if self.context:
context_str = ", ".join(f"{k}={v}" for k, v in self.context.items())
error_msg += f" (Context: {context_str})"
if self.original_error:
error_msg += f" (Caused by: {self.original_error})"
return error_msg
class ConfigurationError(MLSysBookError):
"""Raised when there are configuration-related errors.
This includes missing configuration files, invalid configuration values,
environment variable issues, and other configuration problems.
"""
pass
class ValidationError(MLSysBookError):
"""Raised when input validation fails.
This includes invalid file paths, malformed data, missing required fields,
and other input validation issues.
"""
pass
class FileOperationError(MLSysBookError):
"""Raised when file operations fail.
This includes file not found, permission denied, disk space issues,
and other file system related errors.
"""
pass
class ProcessingError(MLSysBookError):
"""Raised when content processing fails.
This includes parsing errors, conversion failures, and other content
processing related issues.
"""
pass
class APIError(MLSysBookError):
"""Raised when external API calls fail.
This includes HTTP errors, authentication failures, rate limiting,
and other API-related issues.
"""
pass
class ToolExecutionError(MLSysBookError):
"""Raised when tool execution fails.
This is a general execution error for tools that encounter unexpected
conditions during their main operation.
"""
pass
class DependencyError(MLSysBookError):
"""Raised when required dependencies are missing or incompatible.
This includes missing Python packages, external tools, or version
compatibility issues.
"""
pass
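
The context and original_error hooks on the base class are what the formatted __str__ output draws on; a small sketch (the path and messages are illustrative):

    from tools.scripts.common.exceptions import FileOperationError

    try:
        raise FileOperationError(
            "could not write backup",
            context={"path": "backups/intro.qmd", "op": "write"},
            original_error=PermissionError("read-only filesystem"),
        )
    except FileOperationError as e:
        # Prints on one line:
        # could not write backup (Context: path=backups/intro.qmd, op=write) (Caused by: read-only filesystem)
        print(e)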

tools/scripts/common/logging_config.py (new file)

@@ -0,0 +1,247 @@
"""
Centralized logging configuration for MLSysBook tools.
This module provides a standardized logging setup with support for structured
logging, different output formats, and configurable log levels.
"""
import logging
import logging.handlers
import sys
from pathlib import Path
from typing import Optional, Dict, Any
from datetime import datetime
from rich.console import Console
from rich.logging import RichHandler
from rich.traceback import install
from .config import get_config
class StructuredFormatter(logging.Formatter):
"""Custom formatter that adds structured information to log records."""
def __init__(self, include_context: bool = True) -> None:
"""Initialize the formatter.
Args:
include_context: Whether to include additional context in log records
"""
super().__init__()
self.include_context = include_context
def format(self, record: logging.LogRecord) -> str:
"""Format a log record with structured information."""
# Add timestamp
record.timestamp = datetime.utcnow().isoformat()
# Add context information if available
if self.include_context and hasattr(record, 'context'):
record.context_str = f" | Context: {record.context}"
else:
record.context_str = ""
# Create the formatted message
if record.levelno >= logging.ERROR and record.exc_info:
# For errors with exceptions, include the full traceback
formatted = (
f"[{record.timestamp}] {record.levelname:8} | "
f"{record.name} | {record.getMessage()}{record.context_str}\n"
f"Exception: {self.formatException(record.exc_info)}"
)
else:
formatted = (
f"[{record.timestamp}] {record.levelname:8} | "
f"{record.name} | {record.getMessage()}{record.context_str}"
)
return formatted
class ProgressAwareHandler(RichHandler):
"""Rich handler that works well with progress bars and other rich output."""
def __init__(self, *args, **kwargs) -> None:
# Create a separate console for logging to avoid conflicts
console = Console(stderr=True, force_terminal=True)
kwargs['console'] = console
kwargs['show_time'] = False # We'll handle time in our formatter
kwargs['show_path'] = False
super().__init__(*args, **kwargs)
def setup_logging(
name: Optional[str] = None,
level: Optional[str] = None,
log_to_file: Optional[bool] = None,
log_file_path: Optional[Path] = None,
enable_rich_tracebacks: bool = True,
context: Optional[Dict[str, Any]] = None
) -> logging.Logger:
"""Set up logging configuration for MLSysBook tools.
Args:
name: Logger name (defaults to calling module)
level: Log level (defaults to config value)
log_to_file: Whether to log to file (defaults to config value)
log_file_path: Path for log file (defaults to config value)
enable_rich_tracebacks: Whether to enable rich exception formatting
context: Additional context to include in all log messages
Returns:
Configured logger instance
"""
config = get_config()
# Use provided values or fall back to config
level = level or config.log_level
log_to_file = log_to_file if log_to_file is not None else config.log_to_file
log_file_path = log_file_path or config.log_file_path
# Set up rich tracebacks if enabled
if enable_rich_tracebacks:
install(show_locals=True)
# Create logger
logger = logging.getLogger(name or __name__)
logger.setLevel(getattr(logging, level.upper()))
# Clear any existing handlers
logger.handlers.clear()
# Add console handler with rich formatting
console_handler = ProgressAwareHandler()
console_handler.setLevel(getattr(logging, level.upper()))
console_formatter = logging.Formatter("%(name)s | %(message)s")
console_handler.setFormatter(console_formatter)
logger.addHandler(console_handler)
# Add file handler if requested
if log_to_file and log_file_path:
# Ensure log directory exists
log_file_path.parent.mkdir(parents=True, exist_ok=True)
# Create rotating file handler
file_handler = logging.handlers.RotatingFileHandler(
log_file_path,
maxBytes=10 * 1024 * 1024, # 10MB
backupCount=5,
encoding='utf-8'
)
file_handler.setLevel(logging.DEBUG) # Always debug level for files
file_formatter = StructuredFormatter(include_context=True)
file_handler.setFormatter(file_formatter)
logger.addHandler(file_handler)
# Add context if provided
if context:
logger = LoggerAdapter(logger, context)
return logger
def get_logger(
name: Optional[str] = None,
context: Optional[Dict[str, Any]] = None
) -> logging.Logger:
"""Get a logger instance with optional context.
Args:
name: Logger name (defaults to calling module)
context: Additional context to include in log messages
Returns:
Logger instance
"""
logger = logging.getLogger(name or __name__)
# If logger doesn't have handlers, set it up
if not logger.handlers:
logger = setup_logging(name)
# Add context if provided
if context:
logger = LoggerAdapter(logger, context)
return logger
class LoggerAdapter(logging.LoggerAdapter):
"""Adapter that adds context information to log records."""
def __init__(self, logger: logging.Logger, extra: Dict[str, Any]) -> None:
"""Initialize the adapter with context information.
Args:
logger: Base logger instance
extra: Context information to add to log records
"""
super().__init__(logger, extra)
def process(self, msg: str, kwargs: Dict[str, Any]) -> tuple:
"""Process a log message to add context information."""
if 'extra' not in kwargs:
kwargs['extra'] = {}
# Add our context to the extra information
kwargs['extra'].update(self.extra)
kwargs['extra']['context'] = self.extra
return msg, kwargs
class ProgressLogger:
"""Logger that works well with progress bars and other rich output."""
def __init__(self, name: Optional[str] = None) -> None:
"""Initialize the progress logger.
Args:
name: Logger name
"""
self.logger = get_logger(name)
self._progress_active = False
def start_progress(self) -> None:
"""Indicate that a progress operation is starting."""
self._progress_active = True
def end_progress(self) -> None:
"""Indicate that a progress operation has ended."""
self._progress_active = False
def log(self, level: str, message: str, **kwargs) -> None:
"""Log a message with progress awareness.
Args:
level: Log level (debug, info, warning, error, critical)
message: Log message
**kwargs: Additional keyword arguments for logging
"""
if self._progress_active:
# For progress operations, use a simpler format
print(f"[{level.upper()}] {message}", file=sys.stderr)
else:
# Use normal logging
getattr(self.logger, level.lower())(message, **kwargs)
def debug(self, message: str, **kwargs) -> None:
"""Log a debug message."""
self.log('debug', message, **kwargs)
def info(self, message: str, **kwargs) -> None:
"""Log an info message."""
self.log('info', message, **kwargs)
def warning(self, message: str, **kwargs) -> None:
"""Log a warning message."""
self.log('warning', message, **kwargs)
def error(self, message: str, **kwargs) -> None:
"""Log an error message."""
self.log('error', message, **kwargs)
def critical(self, message: str, **kwargs) -> None:
"""Log a critical message."""
self.log('critical', message, **kwargs)
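
A sketch of the intended wiring (the tool name and log path are illustrative): setup_logging attaches the Rich console handler plus, optionally, the rotating file handler with StructuredFormatter, and get_logger with a context dict wraps the result in LoggerAdapter so the context shows up in the file log's Context: field:

    from pathlib import Path
    from tools.scripts.common.logging_config import setup_logging, get_logger

    logger = setup_logging(
        name="tools.caption_check",
        level="DEBUG",
        log_to_file=True,
        log_file_path=Path("logs/caption_check.log"),
    )
    logger.debug("console via RichHandler, file via StructuredFormatter")

    ctx_logger = get_logger("tools.caption_check", context={"chapter": "intro"})
    ctx_logger.info("file log line gains: Context: {'chapter': 'intro'}")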

tools/scripts/common/validators.py (new file)

@@ -0,0 +1,419 @@
"""
Input validation utilities for MLSysBook tools.
This module provides comprehensive validation functions for common input types
including file paths, configuration values, and data structures.
"""
import re
import json
import yaml
from pathlib import Path
from typing import Any, Dict, List, Optional, Union, Callable, Type
from urllib.parse import urlparse
from .exceptions import ValidationError
def validate_file_path(
path: Union[str, Path],
must_exist: bool = True,
must_be_file: bool = True,
must_be_readable: bool = True,
allowed_extensions: Optional[List[str]] = None
) -> Path:
"""Validate a file path.
Args:
path: File path to validate
must_exist: Whether the file must exist
must_be_file: Whether the path must be a file (not directory)
must_be_readable: Whether the file must be readable
allowed_extensions: List of allowed file extensions (e.g., ['.qmd', '.md'])
Returns:
Validated Path object
Raises:
ValidationError: If validation fails
"""
if not path:
raise ValidationError("File path cannot be empty")
path_obj = Path(path).resolve()
# Check for path traversal attempts
try:
path_obj.resolve().relative_to(Path.cwd().resolve())
except ValueError:
# Allow absolute paths, but check for suspicious patterns
path_str = str(path_obj)
if '..' in path_str or path_str.startswith('/'):
# Additional validation for absolute paths
pass
if must_exist and not path_obj.exists():
raise ValidationError(f"File does not exist: {path_obj}")
if must_exist and must_be_file and not path_obj.is_file():
raise ValidationError(f"Path is not a file: {path_obj}")
if must_exist and must_be_readable:
try:
with open(path_obj, 'r', encoding='utf-8') as f:
f.read(1) # Try to read one character
except PermissionError:
raise ValidationError(f"File is not readable: {path_obj}")
except UnicodeDecodeError:
raise ValidationError(f"File is not valid UTF-8: {path_obj}")
if allowed_extensions:
if path_obj.suffix.lower() not in [ext.lower() for ext in allowed_extensions]:
raise ValidationError(
f"File extension {path_obj.suffix} not allowed. "
f"Allowed extensions: {allowed_extensions}"
)
return path_obj
def validate_directory_path(
path: Union[str, Path],
must_exist: bool = True,
create_if_missing: bool = False,
must_be_writable: bool = False
) -> Path:
"""Validate a directory path.
Args:
path: Directory path to validate
must_exist: Whether the directory must exist
create_if_missing: Whether to create the directory if it doesn't exist
must_be_writable: Whether the directory must be writable
Returns:
Validated Path object
Raises:
ValidationError: If validation fails
"""
if not path:
raise ValidationError("Directory path cannot be empty")
path_obj = Path(path).resolve()
if not path_obj.exists():
if create_if_missing:
try:
path_obj.mkdir(parents=True, exist_ok=True)
except Exception as e:
raise ValidationError(f"Cannot create directory {path_obj}: {e}")
elif must_exist:
raise ValidationError(f"Directory does not exist: {path_obj}")
if path_obj.exists() and not path_obj.is_dir():
raise ValidationError(f"Path is not a directory: {path_obj}")
if must_be_writable and path_obj.exists():
test_file = path_obj / '.write_test'
try:
test_file.touch()
test_file.unlink()
except Exception:
raise ValidationError(f"Directory is not writable: {path_obj}")
return path_obj
def validate_url(url: str, allowed_schemes: Optional[List[str]] = None) -> str:
"""Validate a URL.
Args:
url: URL to validate
allowed_schemes: List of allowed URL schemes (e.g., ['http', 'https'])
Returns:
Validated URL string
Raises:
ValidationError: If validation fails
"""
if not url:
raise ValidationError("URL cannot be empty")
if not isinstance(url, str):
raise ValidationError("URL must be a string")
try:
parsed = urlparse(url)
except Exception as e:
raise ValidationError(f"Invalid URL format: {e}")
if not parsed.scheme:
raise ValidationError("URL must include a scheme (http, https, etc.)")
if not parsed.netloc:
raise ValidationError("URL must include a network location")
if allowed_schemes and parsed.scheme not in allowed_schemes:
raise ValidationError(
f"URL scheme '{parsed.scheme}' not allowed. "
f"Allowed schemes: {allowed_schemes}"
)
return url
def validate_json_data(
data: Any,
schema: Optional[Dict[str, Any]] = None,
required_keys: Optional[List[str]] = None
) -> Any:
"""Validate JSON data structure.
Args:
data: Data to validate
schema: Optional JSON schema for validation
required_keys: Required keys for dictionary data
Returns:
Validated data
Raises:
ValidationError: If validation fails
"""
if schema:
try:
import jsonschema
jsonschema.validate(data, schema)
except ImportError:
raise ValidationError("jsonschema package required for schema validation")
except jsonschema.ValidationError as e:
raise ValidationError(f"JSON schema validation failed: {e.message}")
if required_keys and isinstance(data, dict):
missing_keys = [key for key in required_keys if key not in data]
if missing_keys:
raise ValidationError(f"Missing required keys: {missing_keys}")
return data
def validate_string(
value: Any,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
pattern: Optional[str] = None,
allowed_values: Optional[List[str]] = None
) -> str:
"""Validate a string value.
Args:
value: Value to validate
min_length: Minimum string length
max_length: Maximum string length
pattern: Regex pattern the string must match
allowed_values: List of allowed string values
Returns:
Validated string
Raises:
ValidationError: If validation fails
"""
if not isinstance(value, str):
raise ValidationError(f"Expected string, got {type(value).__name__}")
if min_length is not None and len(value) < min_length:
raise ValidationError(f"String too short. Minimum length: {min_length}")
if max_length is not None and len(value) > max_length:
raise ValidationError(f"String too long. Maximum length: {max_length}")
if pattern and not re.match(pattern, value):
raise ValidationError(f"String does not match pattern: {pattern}")
if allowed_values and value not in allowed_values:
raise ValidationError(f"Value '{value}' not in allowed values: {allowed_values}")
return value
def validate_number(
value: Any,
min_value: Optional[Union[int, float]] = None,
max_value: Optional[Union[int, float]] = None,
number_type: Type = float
) -> Union[int, float]:
"""Validate a numeric value.
Args:
value: Value to validate
min_value: Minimum allowed value
max_value: Maximum allowed value
number_type: Expected number type (int or float)
Returns:
Validated number
Raises:
ValidationError: If validation fails
"""
try:
if number_type == int:
numeric_value = int(value)
else:
numeric_value = float(value)
except (ValueError, TypeError):
raise ValidationError(f"Cannot convert '{value}' to {number_type.__name__}")
if min_value is not None and numeric_value < min_value:
raise ValidationError(f"Value {numeric_value} below minimum: {min_value}")
if max_value is not None and numeric_value > max_value:
raise ValidationError(f"Value {numeric_value} above maximum: {max_value}")
return numeric_value
def validate_list(
value: Any,
item_validator: Optional[Callable] = None,
min_items: Optional[int] = None,
max_items: Optional[int] = None,
unique_items: bool = False
) -> List[Any]:
"""Validate a list value.
Args:
value: Value to validate
item_validator: Function to validate each item
min_items: Minimum number of items
max_items: Maximum number of items
unique_items: Whether items must be unique
Returns:
Validated list
Raises:
ValidationError: If validation fails
"""
if not isinstance(value, list):
raise ValidationError(f"Expected list, got {type(value).__name__}")
if min_items is not None and len(value) < min_items:
raise ValidationError(f"Too few items. Minimum: {min_items}")
if max_items is not None and len(value) > max_items:
raise ValidationError(f"Too many items. Maximum: {max_items}")
if unique_items and len(value) != len(set(value)):
raise ValidationError("List items must be unique")
if item_validator:
validated_items = []
for i, item in enumerate(value):
try:
validated_items.append(item_validator(item))
except ValidationError as e:
raise ValidationError(f"Item {i} validation failed: {e}")
return validated_items
return value
def validate_config_file(file_path: Union[str, Path]) -> Dict[str, Any]:
"""Validate and load a configuration file.
Args:
file_path: Path to configuration file
Returns:
Loaded configuration data
Raises:
ValidationError: If validation fails
"""
path_obj = validate_file_path(
file_path,
allowed_extensions=['.yaml', '.yml', '.json']
)
try:
with open(path_obj, 'r', encoding='utf-8') as f:
if path_obj.suffix.lower() == '.json':
data = json.load(f)
else:
data = yaml.safe_load(f)
except json.JSONDecodeError as e:
raise ValidationError(f"Invalid JSON in config file: {e}")
except yaml.YAMLError as e:
raise ValidationError(f"Invalid YAML in config file: {e}")
except Exception as e:
raise ValidationError(f"Cannot read config file: {e}")
if not isinstance(data, dict):
raise ValidationError("Configuration file must contain a dictionary/object")
return data
class Validator:
"""Fluent validation interface for complex validation chains."""
def __init__(self, value: Any, name: str = "value") -> None:
"""Initialize validator with a value.
Args:
value: Value to validate
name: Name of the value for error messages
"""
self.value = value
self.name = name
def is_string(self, **kwargs) -> 'Validator':
"""Validate that value is a string."""
self.value = validate_string(self.value, **kwargs)
return self
def is_number(self, **kwargs) -> 'Validator':
"""Validate that value is a number."""
self.value = validate_number(self.value, **kwargs)
return self
def is_list(self, **kwargs) -> 'Validator':
"""Validate that value is a list."""
self.value = validate_list(self.value, **kwargs)
return self
def is_file_path(self, **kwargs) -> 'Validator':
"""Validate that value is a file path."""
self.value = validate_file_path(self.value, **kwargs)
return self
def is_directory_path(self, **kwargs) -> 'Validator':
"""Validate that value is a directory path."""
self.value = validate_directory_path(self.value, **kwargs)
return self
def is_url(self, **kwargs) -> 'Validator':
"""Validate that value is a URL."""
self.value = validate_url(self.value, **kwargs)
return self
def get(self) -> Any:
"""Get the validated value."""
return self.value
def validate(value: Any, name: str = "value") -> Validator:
"""Create a new validator for fluent validation.
Args:
value: Value to validate
name: Name of the value for error messages
Returns:
Validator instance
"""
return Validator(value, name)
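
Finally, a sketch of the fluent interface in use (the values are illustrative):

    from tools.scripts.common import ValidationError
    from tools.scripts.common.validators import validate

    # Each is_*() step delegates to the matching validate_*() function and
    # raises ValidationError on the first failure.
    threshold = (
        validate(0.65, name="similarity_threshold")
        .is_number(min_value=0.0, max_value=1.0)
        .get()
    )

    try:
        validate("intro").is_string(min_length=10).get()
    except ValidationError as e:
        print(e)  # String too short. Minimum length: 10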