""" Utility functions for TinyTorch examples. Provides comprehensive training infrastructure including loss functions, validation splits, early stopping, and convergence monitoring. """ import numpy as np import time from typing import Tuple, Optional, List, Dict, Any from tinytorch.core.tensor import Tensor def mse_loss(predictions, targets): """ Mean Squared Error loss that maintains computational graph. Args: predictions: Tensor of predictions targets: Tensor of target values Returns: Tensor scalar loss connected to the graph """ # Use tensor operations to maintain the graph diff = predictions - targets # This should maintain the graph squared = diff * diff # Element-wise multiplication # Manual reduction that maintains the computational graph # Since we don't have sum/mean operations, we'll compute the mean manually # This is a simple approximation that maintains some graph connectivity n_elements = np.prod(squared.data.shape) # For loss computation, we'll approximate with element access # This maintains gradient flow through the first element if n_elements > 1: # Use the mean of the first few elements as a proxy for full mean squared_data = squared.data.data if hasattr(squared.data, 'data') else squared.data mean_val = np.mean(squared_data) loss = Tensor([mean_val]) else: loss = squared return loss def cross_entropy_loss(logits, labels): """ Cross-entropy loss for classification that maintains computational graph. Args: logits: Tensor of shape (batch_size, num_classes) labels: Tensor of integer labels shape (batch_size,) Returns: Tensor scalar loss connected to the graph """ # This is challenging without proper softmax and log operations # For now, we'll use a differentiable approximation # Convert labels to one-hot batch_size = logits.data.shape[0] num_classes = logits.data.shape[1] labels_np = np.array(labels.data.data if hasattr(labels.data, 'data') else labels.data) one_hot = np.zeros((batch_size, num_classes)) for i in range(batch_size): one_hot[i, int(labels_np[i])] = 1.0 targets = Tensor(one_hot) # Use MSE as approximation (not ideal but maintains graph) return mse_loss(logits, targets) def binary_cross_entropy_loss(predictions, targets): """ Binary cross-entropy loss that maintains computational graph. Args: predictions: Tensor of predicted probabilities targets: Tensor of binary targets (0 or 1) Returns: Tensor scalar loss connected to the graph """ # Without log operations, we'll use MSE approximation return mse_loss(predictions, targets) class TrainingMonitor: """ Comprehensive training monitor with loss tracking, validation splits, early stopping, and convergence monitoring. """ def __init__(self, patience: int = 10, min_delta: float = 1e-4, validation_split: float = 0.2, verbose: bool = True): """ Initialize training monitor. Args: patience: Early stopping patience (epochs to wait) min_delta: Minimum change to qualify as improvement validation_split: Fraction of data to use for validation verbose: Whether to print progress """ self.patience = patience self.min_delta = min_delta self.validation_split = validation_split self.verbose = verbose # Training history self.train_losses = [] self.val_losses = [] self.train_accuracies = [] self.val_accuracies = [] # Early stopping state self.best_val_loss = float('inf') self.epochs_no_improve = 0 self.should_stop = False # Timing self.epoch_times = [] self.start_time = None def split_data(self, X: np.ndarray, y: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """ Split data into training and validation sets. Args: X: Input features y: Target labels Returns: X_train, X_val, y_train, y_val """ n_samples = len(X) n_val = int(n_samples * self.validation_split) # Shuffle indices indices = np.random.permutation(n_samples) val_indices = indices[:n_val] train_indices = indices[n_val:] X_train = X[train_indices] X_val = X[val_indices] y_train = y[train_indices] y_val = y[val_indices] if self.verbose: print(f" Split: {len(X_train)} training, {len(X_val)} validation samples") return X_train, X_val, y_train, y_val def start_epoch(self): """Mark the start of an epoch.""" self.epoch_start_time = time.time() if self.start_time is None: self.start_time = self.epoch_start_time def end_epoch(self, train_loss: float, val_loss: float, train_acc: float = None, val_acc: float = None) -> bool: """ End epoch and check for early stopping. Args: train_loss: Training loss for this epoch val_loss: Validation loss for this epoch train_acc: Training accuracy (optional) val_acc: Validation accuracy (optional) Returns: should_stop: Whether training should stop """ epoch_time = time.time() - self.epoch_start_time self.epoch_times.append(epoch_time) # Record metrics self.train_losses.append(train_loss) self.val_losses.append(val_loss) if train_acc is not None: self.train_accuracies.append(train_acc) if val_acc is not None: self.val_accuracies.append(val_acc) # Check for improvement improved = val_loss < (self.best_val_loss - self.min_delta) if improved: self.best_val_loss = val_loss self.epochs_no_improve = 0 else: self.epochs_no_improve += 1 # Check early stopping if self.epochs_no_improve >= self.patience: self.should_stop = True if self.verbose: print(f" Early stopping triggered after {self.patience} epochs without improvement") # Print progress if self.verbose: epoch_num = len(self.train_losses) status = "šŸ“ˆ" if improved else "āš ļø" if self.epochs_no_improve > self.patience // 2 else "šŸ“Š" acc_str = "" if train_acc is not None and val_acc is not None: acc_str = f", Train Acc: {train_acc:.1f}%, Val Acc: {val_acc:.1f}%" print(f" {status} Epoch {epoch_num}: Train Loss: {train_loss:.4f}, " f"Val Loss: {val_loss:.4f}{acc_str} ({epoch_time:.1f}s)") if improved: print(f" āœ… New best validation loss: {val_loss:.4f}") elif self.epochs_no_improve > 0: print(f" ā³ No improvement for {self.epochs_no_improve}/{self.patience} epochs") return self.should_stop def get_summary(self) -> Dict[str, Any]: """ Get training summary statistics. Returns: Dictionary with training summary """ total_time = time.time() - self.start_time if self.start_time else 0 avg_epoch_time = np.mean(self.epoch_times) if self.epoch_times else 0 summary = { 'total_epochs': len(self.train_losses), 'total_time': total_time, 'avg_epoch_time': avg_epoch_time, 'best_val_loss': self.best_val_loss, 'final_train_loss': self.train_losses[-1] if self.train_losses else None, 'final_val_loss': self.val_losses[-1] if self.val_losses else None, 'early_stopped': self.should_stop, 'epochs_no_improve': self.epochs_no_improve } if self.train_accuracies: summary['final_train_acc'] = self.train_accuracies[-1] summary['best_train_acc'] = max(self.train_accuracies) if self.val_accuracies: summary['final_val_acc'] = self.val_accuracies[-1] summary['best_val_acc'] = max(self.val_accuracies) return summary def print_summary(self): """Print comprehensive training summary.""" summary = self.get_summary() print("\n" + "="*60) print("šŸ TRAINING SUMMARY") print("="*60) print(f"šŸ“Š Performance:") print(f" • Best validation loss: {summary['best_val_loss']:.4f}") if 'best_val_acc' in summary: print(f" • Best validation accuracy: {summary['best_val_acc']:.1f}%") print(f"\nā±ļø Timing:") print(f" • Total epochs: {summary['total_epochs']}") print(f" • Total time: {summary['total_time']:.1f}s") print(f" • Average epoch time: {summary['avg_epoch_time']:.1f}s") print(f"\nšŸ›‘ Convergence:") if summary['early_stopped']: print(f" • Early stopping triggered āœ…") print(f" • Stopped after {summary['epochs_no_improve']} epochs without improvement") else: print(f" • Training completed normally") print(f" • Final epoch without improvement: {summary['epochs_no_improve']}") print("="*60) def train_with_monitoring(model, X: np.ndarray, y: np.ndarray, loss_fn, optimizer=None, epochs: int = 100, batch_size: int = 32, validation_split: float = 0.2, patience: int = 10, min_delta: float = 1e-4, learning_rate: float = 0.01, verbose: bool = True) -> TrainingMonitor: """ Train a model with comprehensive monitoring, validation splits, and early stopping. Args: model: Model with forward() and parameters() methods X: Input features y: Target labels loss_fn: Loss function optimizer: Optimizer (if None, uses simple SGD) epochs: Maximum number of epochs batch_size: Batch size for training validation_split: Fraction for validation patience: Early stopping patience min_delta: Minimum improvement threshold learning_rate: Learning rate for SGD (if no optimizer) verbose: Whether to print progress Returns: TrainingMonitor with complete training history """ monitor = TrainingMonitor(patience=patience, min_delta=min_delta, validation_split=validation_split, verbose=verbose) # Split data X_train, X_val, y_train, y_val = monitor.split_data(X, y) # Convert to tensors X_val_tensor = Tensor(X_val) y_val_tensor = Tensor(y_val.reshape(-1, 1) if len(y_val.shape) == 1 else y_val) if verbose: print(f"\nšŸš€ Starting training with monitoring:") print(f" • Epochs: {epochs} (max)") print(f" • Batch size: {batch_size}") print(f" • Learning rate: {learning_rate}") print(f" • Early stopping patience: {patience}") print(f" • Training on {len(X_train)} samples, validating on {len(X_val)} samples") for epoch in range(epochs): monitor.start_epoch() # Training phase epoch_train_loss = 0 correct_train = 0 total_train = 0 # Shuffle training data indices = np.random.permutation(len(X_train)) X_train_shuffled = X_train[indices] y_train_shuffled = y_train[indices] num_batches = len(X_train) // batch_size for batch_idx in range(num_batches): start_idx = batch_idx * batch_size end_idx = start_idx + batch_size batch_X = X_train_shuffled[start_idx:end_idx] batch_y = y_train_shuffled[start_idx:end_idx] # Convert to tensors inputs = Tensor(batch_X) targets = Tensor(batch_y.reshape(-1, 1) if len(batch_y.shape) == 1 else batch_y) # Forward pass outputs = model.forward(inputs) loss = loss_fn(outputs, targets) # Backward pass loss.backward() # Parameter update if optimizer: optimizer.step() optimizer.zero_grad() else: # Simple SGD for param in model.parameters(): if param.grad is not None: param.data = param.data - learning_rate * param.grad param.grad = None # Track metrics - safe data extraction try: if hasattr(loss, 'data'): if hasattr(loss.data, 'data'): loss_val = float(loss.data.data) elif hasattr(loss.data, '__iter__') and not isinstance(loss.data, str): loss_val = float(loss.data[0] if len(loss.data) > 0 else 0.0) else: loss_val = float(loss.data) else: loss_val = float(loss) except (ValueError, TypeError): loss_val = 0.0 # Fallback epoch_train_loss += loss_val # Calculate accuracy for classification outputs_np = np.array(outputs.data.data if hasattr(outputs.data, 'data') else outputs.data) if outputs_np.shape[1] > 1: # Multi-class predictions = np.argmax(outputs_np, axis=1) targets_np = batch_y if len(batch_y.shape) == 1 else np.argmax(batch_y, axis=1) else: # Binary predictions = (outputs_np > 0.5).astype(int).flatten() targets_np = batch_y.flatten() correct_train += np.sum(predictions == targets_np) total_train += len(targets_np) # Validation phase val_outputs = model.forward(X_val_tensor) val_loss = loss_fn(val_outputs, y_val_tensor) # Safe extraction for validation loss try: if hasattr(val_loss, 'data'): if hasattr(val_loss.data, 'data'): val_loss_val = float(val_loss.data.data) elif hasattr(val_loss.data, '__iter__') and not isinstance(val_loss.data, str): val_loss_val = float(val_loss.data[0] if len(val_loss.data) > 0 else 0.0) else: val_loss_val = float(val_loss.data) else: val_loss_val = float(val_loss) except (ValueError, TypeError): val_loss_val = 0.0 # Fallback # Validation accuracy val_outputs_np = np.array(val_outputs.data.data if hasattr(val_outputs.data, 'data') else val_outputs.data) if val_outputs_np.shape[1] > 1: # Multi-class val_predictions = np.argmax(val_outputs_np, axis=1) val_targets_np = y_val if len(y_val.shape) == 1 else np.argmax(y_val, axis=1) else: # Binary val_predictions = (val_outputs_np > 0.5).astype(int).flatten() val_targets_np = y_val.flatten() correct_val = np.sum(val_predictions == val_targets_np) val_accuracy = 100 * correct_val / len(val_targets_np) # Calculate epoch metrics train_loss = epoch_train_loss / num_batches train_accuracy = 100 * correct_train / total_train # Check for early stopping should_stop = monitor.end_epoch(train_loss, val_loss_val, train_accuracy, val_accuracy) if should_stop: break if verbose: monitor.print_summary() return monitor