Mirror of https://github.com/MLSysBook/TinyTorch.git (synced 2025-12-05 19:17:52 -06:00)

Improve milestone 05 (Transformer) with letters for better visualization

- Enhanced the attention proof to use A-Z letters instead of numbers
- Shows MCYWUH → HUWYCM instead of [1, 2, 3] → [3, 2, 1]
- More intuitive and fun for students
- Removed the quickdemo, generation, and dialogue scripts (too slow; output was gibberish)
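For reference, a minimal sketch of the letter-based reversal task this commit introduces (assumes only NumPy; tokens_to_letters mirrors the helper added in the diff below):

    import numpy as np

    def tokens_to_letters(tokens):
        # 1=A, 2=B, ..., 26=Z; anything else renders as '?'
        return ''.join(chr(ord('A') + t - 1) if 1 <= t <= 26 else '?' for t in tokens)

    seq = np.random.randint(1, 27, size=6)   # e.g. [13, 3, 25, 23, 21, 8]
    print(tokens_to_letters(seq))            # e.g. "MCYWUH"
    print(tokens_to_letters(seq[::-1]))      # "HUWYCM" - the reversal target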
@@ -241,21 +241,40 @@ class ReversalTransformer:
         return self._params


-def generate_reversal_dataset(num_samples=200, seq_len=6, vocab_size=10):
+def generate_reversal_dataset(num_samples=200, seq_len=6, vocab_size=26):
     """
-    Generate sequence reversal dataset.
+    Generate sequence reversal dataset using letters A-Z.

     Each sample is (input_seq, target_seq) where target = reverse(input)
+    More intuitive than numbers: "CAT" → "TAC", "HELLO" → "OLLEH"
     """
     dataset = []
     for _ in range(num_samples):
-        # Generate random sequence (avoid 0 for clarity)
-        seq = np.random.randint(1, vocab_size, size=seq_len)
+        # Generate random sequence of letters (1-26 maps to A-Z)
+        seq = np.random.randint(1, min(vocab_size, 27), size=seq_len)
         reversed_seq = seq[::-1].copy()
         dataset.append((seq, reversed_seq))
     return dataset
+
+
+def tokens_to_letters(tokens):
+    """Convert token indices to readable letters (1=A, 2=B, ...)"""
+    return ''.join(chr(ord('A') + t - 1) if 1 <= t <= 26 else '?' for t in tokens)
+
+
+# Fun word examples for demonstration
+FUN_WORDS = [
+    "PYTHON",
+    "TORCH",
+    "NEURAL",
+    "TENSOR",
+    "ATTEND",
+    "VASWANI",
+    "QUERY",
+    "HELLO",
+]


 def train_epoch(model, dataset, optimizer, loss_fn):
     """Train for one epoch."""
     total_loss = 0.0
@@ -327,9 +346,9 @@ def main():
     console.print("="*70)
     console.print()

     # Hyperparameters
-    vocab_size = 10
-    seq_len = 6
+    vocab_size = 27  # 0 (padding) + A-Z (1-26)
+    seq_len = 6      # 6-letter "words"
     embed_dim = 32
     num_heads = 4
     lr = 0.001
@@ -339,12 +358,12 @@ def main():

     console.print(Panel(
         f"[bold]Hyperparameters[/bold]\n"
-        f"  Vocabulary size: [cyan]{vocab_size}[/cyan] (tokens 0-9)\n"
-        f"  Sequence length: [cyan]{seq_len}[/cyan]\n"
-        f"  Embedding dim:   [cyan]{embed_dim}[/cyan]\n"
-        f"  Attention heads: [cyan]{num_heads}[/cyan]\n"
-        f"  Learning rate:   [cyan]{lr}[/cyan]\n"
-        f"  Epochs:          [cyan]{epochs}[/cyan]",
+        f"  Vocabulary: [cyan]{vocab_size}[/cyan] tokens (A-Z letters)\n"
+        f"  Sequence:   [cyan]{seq_len}[/cyan] letters per word\n"
+        f"  Embedding:  [cyan]{embed_dim}[/cyan] dimensions\n"
+        f"  Attention:  [cyan]{num_heads}[/cyan] heads\n"
+        f"  Learning:   [cyan]{lr}[/cyan]\n"
+        f"  Epochs:     [cyan]{epochs}[/cyan]",
         title="⚙️ Configuration",
         border_style="blue"
     ))
@@ -352,16 +371,17 @@ def main():

     # Generate data
     console.print("📊 Generating reversal dataset...")
+    console.print("   [dim]Task: Reverse letters like PYTHON → NOHTYP[/dim]")
     train_data = generate_reversal_dataset(num_samples=train_size, seq_len=seq_len, vocab_size=vocab_size)
     test_data = generate_reversal_dataset(num_samples=test_size, seq_len=seq_len, vocab_size=vocab_size)
     console.print(f"   ✓ Training samples: {len(train_data)}")
     console.print(f"   ✓ Test samples: {len(test_data)}\n")

-    # Show example
+    # Show example with letters
     console.print("🔍 Example:")
     ex_in, ex_out = train_data[0]
-    console.print(f"   Input:  {ex_in.tolist()}")
-    console.print(f"   Target: {ex_out.tolist()}")
+    console.print(f"   Input:  [cyan]{tokens_to_letters(ex_in)}[/cyan] → Target: [green]{tokens_to_letters(ex_out)}[/green]")
+    console.print(f"   [dim](Numbers: {ex_in.tolist()} → {ex_out.tolist()})[/dim]")
     console.print()

     # Build model
@@ -458,7 +478,7 @@ def main():
     console.print(table)
     console.print()

-    # Show sample predictions
+    # Show sample predictions with letters
     console.print(Panel("[bold]Sample Predictions[/bold]", border_style="blue"))
     console.print()

@@ -466,9 +486,13 @@ def main():
         match = "✓" if np.array_equal(pred, target) else "✗"
         style = "green" if np.array_equal(pred, target) else "red"

-        console.print(f"  [{style}]{match}[/{style}] Input:  {inp.tolist()}")
-        console.print(f"     Target: {target.tolist()}")
-        console.print(f"     Pred:   {pred.tolist()}\n")
+        inp_str = tokens_to_letters(inp)
+        target_str = tokens_to_letters(target)
+        pred_str = tokens_to_letters(pred)
+
+        console.print(f"  [{style}]{match}[/{style}] Input:  [cyan]{inp_str}[/cyan]")
+        console.print(f"     Target: [green]{target_str}[/green]")
+        console.print(f"     Pred:   [{style}]{pred_str}[/{style}]\n")

     # Verdict
     console.print("="*70)
@@ -1,886 +0,0 @@
#!/usr/bin/env python3
"""
TinyTalks Q&A Generation (2017) - Transformer Era
==================================================

📚 HISTORICAL CONTEXT:
In 2017, Vaswani et al. published "Attention Is All You Need", showing that
attention mechanisms alone (no RNNs!) could achieve state-of-the-art results
on sequence tasks. This breakthrough launched the era of GPT, BERT, and modern LLMs.

🎯 WHAT YOU'RE BUILDING:
Using YOUR TinyTorch implementations, you'll build a character-level conversational
model that learns to answer questions - proving YOUR attention mechanism works!

TinyTalks is PERFECT for learning:
- Small dataset (17.5 KB) = 3-5 minute training!
- Clear Q&A format (easy to verify learning)
- Progressive difficulty (5 levels)
- Instant gratification: Watch your transformer learn to chat!

✅ REQUIRED MODULES (Run after Module 13):
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Module 01 (Tensor)       : YOUR data structure with autograd
Module 02 (Activations)  : YOUR ReLU and GELU activations
Module 03 (Layers)       : YOUR Linear layers
Module 04 (Losses)       : YOUR CrossEntropyLoss
Module 05 (Autograd)     : YOUR automatic differentiation
Module 06 (Optimizers)   : YOUR Adam optimizer
Module 08 (DataLoader)   : YOUR data batching
Module 10 (Tokenization) : YOUR CharTokenizer for text→numbers
Module 11 (Embeddings)   : YOUR token & positional embeddings
Module 12 (Attention)    : YOUR multi-head self-attention
Module 13 (Transformers) : YOUR LayerNorm + TransformerBlock + GPT
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

🏗️ ARCHITECTURE (Character-Level Q&A Model):
┌──────────────────────────────────────────────────────────┐
│                   Output Predictions                     │
│           Character Probabilities (vocab_size)           │
└──────────────────────────────────────────────────────────┘
                            ▲
┌──────────────────────────────────────────────────────────┐
│                   Output Projection                      │
│             Module 03: vectors → vocabulary              │
└──────────────────────────────────────────────────────────┘
                            ▲
┌──────────────────────────────────────────────────────────┐
│                       Layer Norm                         │
│             Module 13: Final normalization               │
└──────────────────────────────────────────────────────────┘
                            ▲
╔══════════════════════════════════════════════════════════╗
║              Transformer Block × N (Repeat)              ║
║  ┌────────────────────────────────────────────────────┐  ║
║  │               Feed Forward Network                 │  ║
║  │         Module 03: Linear → GELU → Linear          │  ║
║  └────────────────────────────────────────────────────┘  ║
║                           ▲                              ║
║  ┌────────────────────────────────────────────────────┐  ║
║  │            Multi-Head Self-Attention               │  ║
║  │  Module 12: Query·Key^T·Value across all positions │  ║
║  └────────────────────────────────────────────────────┘  ║
╚══════════════════════════════════════════════════════════╝
                            ▲
┌──────────────────────────────────────────────────────────┐
│                  Positional Encoding                     │
│           Module 11: Add position information            │
└──────────────────────────────────────────────────────────┘
                            ▲
┌──────────────────────────────────────────────────────────┐
│                  Character Embeddings                    │
│          Module 11: chars → embed_dim vectors            │
└──────────────────────────────────────────────────────────┘
                            ▲
┌──────────────────────────────────────────────────────────┐
│                    Input Characters                      │
│            "Q: What color is the sky? A:"                │
└──────────────────────────────────────────────────────────┘

📊 EXPECTED PERFORMANCE:
- Dataset: 17.5 KB TinyTalks (301 Q&A pairs, 5 difficulty levels)
- Training time: 3-5 minutes (instant gratification!)
- Vocabulary: ~68 unique characters (simple English Q&A)
- Expected: 70-80% accuracy on Level 1-2 questions after training
- Parameters: ~1.2M (perfect size for fast learning on small data)

💡 WHAT TO WATCH FOR:
- Epoch 1-3: Model learns Q&A structure ("A:" follows "Q:")
- Epoch 4-7: Starts giving sensible (if incorrect) answers
- Epoch 8-12: 50-60% accuracy on simple questions
- Epoch 13-20: 70-80% accuracy, proper grammar
- Success = "Wow, my transformer actually learned to answer questions!"
"""

import sys
import os
import numpy as np
import argparse
import time
from rich.console import Console
from rich.panel import Panel
from rich.table import Table
from rich import box

# Add project root to path
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(project_root)

console = Console()


def print_banner():
    """Print a beautiful banner for the milestone"""
    banner_text = """
    ╔══════════════════════════════════════════════════════════════════╗
    ║                                                                  ║
    ║        🤖 TinyTalks Q&A Bot Training (2017)                      ║
    ║           Transformer Architecture                               ║
    ║                                                                  ║
    ║   "Your first transformer learning to answer questions!"        ║
    ║                                                                  ║
    ╚══════════════════════════════════════════════════════════════════╝
    """
    console.print(Panel(banner_text, border_style="bright_blue", box=box.DOUBLE))


def filter_by_levels(text, levels):
    """
    Filter TinyTalks dataset to only include specified difficulty levels.

    Levels are marked in the original generation as:
    L1: Greetings (47 pairs)
    L2: Facts (82 pairs)
    L3: Math (45 pairs)
    L4: Reasoning (87 pairs)
    L5: Context (40 pairs)

    For simplicity, we filter by common patterns:
    L1: Hello, Hi, What is your name, etc.
    L2: What color, How many, etc.
    L3: What is X plus/minus, etc.
    """
    if levels is None or levels == [1, 2, 3, 4, 5]:
        return text  # Use full dataset

    # Parse Q&A pairs
    pairs = []
    blocks = text.strip().split('\n\n')

    for block in blocks:
        lines = block.strip().split('\n')
        if len(lines) == 2 and lines[0].startswith('Q:') and lines[1].startswith('A:'):
            q = lines[0][3:].strip()
            a = lines[1][3:].strip()

            # Classify level (heuristic)
            level = 5  # default
            q_lower = q.lower()

            if any(word in q_lower for word in ['hello', 'hi', 'hey', 'goodbye', 'bye', 'name', 'who are you', 'what are you']):
                level = 1
            elif any(word in q_lower for word in ['color', 'legs', 'days', 'months', 'sound', 'capital']):
                level = 2
            elif any(word in q_lower for word in ['plus', 'minus', 'times', 'divided', 'equals']):
                level = 3
            elif any(word in q_lower for word in ['use', 'where do', 'what do', 'happens if', 'need to']):
                level = 4

            if level in levels:
                pairs.append(f"Q: {q}\nA: {a}")

    filtered_text = '\n\n'.join(pairs)
    console.print(f"[yellow]📊 Filtered to Level(s) {levels}:[/yellow]")
    console.print(f"   Q&A pairs: {len(pairs)}")
    console.print(f"   Characters: {len(filtered_text)}")

    return filtered_text


class TinyTalksDataset:
    """
    Character-level dataset for TinyTalks Q&A.

    Creates sequences of characters for autoregressive language modeling:
    - Input:  "Q: What color is the sky? A: The sk"
    - Target: ": What color is the sky? A: The sky"

    The model learns to predict the next character given previous characters,
    naturally learning the Q&A pattern.
    """

    def __init__(self, text, seq_length=64, levels=None):
        """
        Args:
            text: Full text string (Q&A pairs)
            seq_length: Length of input sequences
            levels: List of difficulty levels to include (1-5), None = all
        """
        from tinytorch.text.tokenization import CharTokenizer

        self.seq_length = seq_length

        # Filter by levels if specified
        if levels:
            text = filter_by_levels(text, levels)

        # Store original text for testing
        self.text = text

        # Build character vocabulary using CharTokenizer
        self.tokenizer = CharTokenizer()
        self.tokenizer.build_vocab([text])

        # Encode entire text
        self.data = self.tokenizer.encode(text)

        console.print(f"[green]✓[/green] Dataset initialized:")
        console.print(f"   Total characters: {len(text)}")
        console.print(f"   Vocabulary size: {self.tokenizer.vocab_size}")
        console.print(f"   Sequence length: {seq_length}")
        console.print(f"   Total sequences: {len(self)}")

    def __len__(self):
        """Number of possible sequences"""
        return len(self.data) - self.seq_length

    def __getitem__(self, idx):
        """
        Get one training example.

        Returns:
            input_seq: Characters [idx : idx+seq_length]
            target_seq: Characters [idx+1 : idx+seq_length+1] (shifted by 1)
        """
        input_seq = self.data[idx:idx + self.seq_length]
        target_seq = self.data[idx + 1:idx + self.seq_length + 1]
        return input_seq, target_seq

    def decode(self, indices):
        """Decode token indices back to text"""
        return self.tokenizer.decode(indices)
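# Aside (not part of the original file): a minimal sketch of the shifted
# input/target pairing that TinyTalksDataset.__getitem__ implements above.
# Plain Python with a toy vocabulary; all names here are illustrative only.
def _shift_example():
    text = "Q: Hi\nA: Hello"
    vocab = {ch: i for i, ch in enumerate(sorted(set(text)))}  # toy char vocab
    data = [vocab[ch] for ch in text]

    seq_length, idx = 5, 0
    input_seq = data[idx:idx + seq_length]            # tokens 0..4
    target_seq = data[idx + 1:idx + seq_length + 1]   # same window shifted by 1

    # At position t, target_seq[t] is the character that follows input_seq[t],
    # so next-character prediction falls out of plain slicing.
    return input_seq, target_seq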
class TinyGPT:
    """
    Character-level GPT model for TinyTalks Q&A.

    This is a simplified GPT architecture:
    1. Token embeddings (convert characters to vectors)
    2. Positional encodings (add position information)
    3. N transformer blocks (self-attention + feed-forward)
    4. Output projection (vectors back to character probabilities)

    Built entirely from YOUR TinyTorch modules!
    """

    def __init__(self, vocab_size, embed_dim=128, num_layers=4, num_heads=4,
                 max_seq_len=64, dropout=0.1):
        """
        Args:
            vocab_size: Number of unique characters
            embed_dim: Dimension of embeddings and hidden states
            num_layers: Number of transformer blocks
            num_heads: Number of attention heads per block
            max_seq_len: Maximum sequence length
            dropout: Dropout probability (for training)
        """
        from tinytorch.core.tensor import Tensor
        from tinytorch.text.embeddings import Embedding, PositionalEncoding
        from tinytorch.models.transformer import LayerNorm, TransformerBlock
        from tinytorch.core.layers import Linear

        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.max_seq_len = max_seq_len

        # 1. Token embeddings: char_id → embed_dim vector
        self.token_embedding = Embedding(vocab_size, embed_dim)

        # 2. Positional encoding: add position information
        self.pos_encoding = PositionalEncoding(max_seq_len, embed_dim)

        # 3. Transformer blocks (stacked)
        self.blocks = []
        for _ in range(num_layers):
            block = TransformerBlock(
                embed_dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=4,  # FFN hidden_dim = 4 * embed_dim
                dropout_prob=dropout
            )
            self.blocks.append(block)

        # 4. Final layer normalization
        self.ln_f = LayerNorm(embed_dim)

        # 5. Output projection: embed_dim → vocab_size
        self.output_proj = Linear(embed_dim, vocab_size)

        console.print(f"[green]✓[/green] TinyGPT model initialized:")
        console.print(f"   Vocabulary: {vocab_size}")
        console.print(f"   Embedding dim: {embed_dim}")
        console.print(f"   Layers: {num_layers}")
        console.print(f"   Heads: {num_heads}")
        console.print(f"   Max sequence: {max_seq_len}")

        # Count parameters
        total_params = self.count_parameters()
        console.print(f"   [bold]Total parameters: {total_params:,}[/bold]")

    def forward(self, x):
        """
        Forward pass through the model.

        Args:
            x: Input tensor of shape (batch, seq_len) with token indices

        Returns:
            logits: Output tensor of shape (batch, seq_len, vocab_size)
        """
        from tinytorch.core.tensor import Tensor

        # 1. Token embeddings: (batch, seq_len) → (batch, seq_len, embed_dim)
        x = self.token_embedding.forward(x)

        # 2. Add positional encoding
        x = self.pos_encoding.forward(x)

        # 3. Pass through transformer blocks
        for block in self.blocks:
            x = block.forward(x)

        # 4. Final layer norm
        x = self.ln_f.forward(x)

        # 5. Project to vocabulary: (batch, seq_len, embed_dim) → (batch, seq_len, vocab_size)
        logits = self.output_proj.forward(x)

        return logits

    def parameters(self):
        """Get all trainable parameters"""
        params = []

        # Token embeddings
        params.extend(self.token_embedding.parameters())

        # Positional encoding (learnable parameters)
        params.extend(self.pos_encoding.parameters())

        # Transformer blocks
        for block in self.blocks:
            params.extend(block.parameters())

        # Final layer norm
        params.extend(self.ln_f.parameters())

        # Output projection
        params.extend(self.output_proj.parameters())

        # Ensure all require gradients
        for param in params:
            param.requires_grad = True

        return params

    def count_parameters(self):
        """Count total trainable parameters"""
        total = 0
        for param in self.parameters():
            total += param.data.size
        return total
    def generate(self, tokenizer, prompt="Q:", max_new_tokens=100, temperature=1.0,
                 return_stats=False, use_cache=False):
        """
        Generate text autoregressively.

        Args:
            tokenizer: CharTokenizer for encoding/decoding
            prompt: Starting text
            max_new_tokens: How many characters to generate
            temperature: Sampling temperature (higher = more random)
            return_stats: If True, return (text, stats_dict) tuple
            use_cache: If True, use KV caching for 10-15x speedup (Module 14)

        Returns:
            Generated text string, or (text, stats) if return_stats=True

        Note:
            KV caching (use_cache=True) transforms generation from O(n²) to O(n):
            - Without cache: Recomputes attention for ALL tokens at each step
            - With cache: Only computes attention for NEW token, reuses past K/V
            - Speedup: ~10-15x for typical sequences (more speedup with longer sequences)
        """
        from tinytorch.core.tensor import Tensor

        # Start timing
        start_time = time.time()

        # Encode prompt
        indices = tokenizer.encode(prompt)
        initial_len = len(indices)

        if use_cache:
            # MODULE 14 OPTIMIZATION: KV-Cached Generation
            # Students learn this AFTER building the base transformer!
            try:
                from tinytorch.generation.kv_cache import enable_kv_cache, disable_kv_cache

                # Enable caching on this model (non-invasive enhancement!)
                # If already enabled, just reset it; otherwise enable fresh
                if hasattr(self, '_cache_enabled') and self._cache_enabled:
                    cache = self._kv_cache
                    cache.reset()
                else:
                    cache = enable_kv_cache(self)

                console.print("[green]✓[/green] KV caching enabled! (Module 14 enhancement)")
                console.print(f"[dim]   Architecture: {cache.num_layers} layers × {cache.num_heads} heads[/dim]")
                console.print(f"[dim]   Memory: {cache.get_memory_usage()['total_mb']:.2f} MB cache[/dim]")
                console.print()

                # Initialize cache with prompt
                # Process prompt tokens one by one to populate cache
                for i in range(len(indices)):
                    token_input = Tensor(np.array([[indices[i]]]))
                    _ = self.forward(token_input)  # Populates cache as side effect
                    if hasattr(self, '_kv_cache'):
                        self._kv_cache.advance()

            except ImportError as e:
                console.print(f"[yellow]⚠️ Module 14 (KV Caching) not available: {e}[/yellow]")
                console.print("[dim]   Falling back to standard generation...[/dim]")
                use_cache = False

        # Standard generation (or fallback from cache)
        # Generate tokens one at a time
        for step in range(max_new_tokens):
            if use_cache and hasattr(self, '_cache_enabled') and self._cache_enabled:
                # CACHED GENERATION: Only process new token
                # Get just the last token (cache handles history)
                new_token = indices[-1:]
                x_input = Tensor(np.array([new_token]))
            else:
                # STANDARD GENERATION: Process full context
                # Get last max_seq_len tokens (context window)
                context = indices[-self.max_seq_len:]
                x_input = Tensor(np.array([context]))

            # Forward pass
            logits = self.forward(x_input)

            # Get logits for last position: (vocab_size,)
            last_logits = logits.data[0, -1, :] / temperature

            # Apply softmax to get probabilities
            exp_logits = np.exp(last_logits - np.max(last_logits))
            probs = exp_logits / np.sum(exp_logits)

            # Sample from distribution
            next_idx = np.random.choice(len(probs), p=probs)

            # Append to sequence
            indices.append(next_idx)

            # Advance cache position if using cache
            if use_cache and hasattr(self, '_kv_cache'):
                self._kv_cache.advance()

            # Stop if we generate newline after "A:"
            if len(indices) > 3 and tokenizer.decode(indices[-3:]) == "\n\nQ":
                break

        # Calculate statistics
        end_time = time.time()
        elapsed_time = end_time - start_time
        tokens_generated = len(indices) - initial_len
        tokens_per_sec = tokens_generated / elapsed_time if elapsed_time > 0 else 0

        generated_text = tokenizer.decode(indices)

        if return_stats:
            stats = {
                'tokens_generated': tokens_generated,
                'time_sec': elapsed_time,
                'tokens_per_sec': tokens_per_sec,
                'total_tokens': len(indices),
                'used_cache': use_cache
            }
            return generated_text, stats

        return generated_text
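# Aside (not part of the original file): rough counting behind the
# O(n²)-vs-O(n) note in generate()'s docstring. Illustrative only; this is
# not the TinyTorch KV-cache API.
def _kv_projection_count(prompt_len, new_tokens, cached):
    # Without a cache, step t recomputes K/V for the whole context of length
    # prompt_len + t; with a cache, only the new token's K/V is computed.
    return sum(1 if cached else (prompt_len + t) for t in range(new_tokens))

# _kv_projection_count(10, 100, cached=False) -> 5950  (grows ~quadratically)
# _kv_projection_count(10, 100, cached=True)  -> 100   (linear in new tokens)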
def test_model_predictions(model, dataset, test_prompts=None):
    """Test model on specific prompts and show predictions with performance"""
    if test_prompts is None:
        test_prompts = ["Q: Hello!", "Q: What is your name?", "Q: Hi!"]

    console.print("\n[bold yellow]🧪 Testing Live Predictions:[/bold yellow]")

    total_speed = 0
    count = 0

    for prompt in test_prompts:
        try:
            full_prompt = prompt + "\nA:"
            response, stats = model.generate(
                dataset.tokenizer,
                prompt=full_prompt,
                max_new_tokens=30,
                temperature=0.5,
                return_stats=True
            )

            # Extract just the answer
            if "\nA:" in response:
                answer = response.split("\nA:")[1].split("\n")[0].strip()
            else:
                answer = response[len(full_prompt):].strip()

            console.print(f"   {prompt}")
            console.print(f"   [cyan]A: {answer}[/cyan]")
            console.print(f"   [dim]⚡ {stats['tokens_per_sec']:.1f} tok/s[/dim]")

            total_speed += stats['tokens_per_sec']
            count += 1
        except Exception as e:
            console.print(f"   {prompt} → [red]Error: {str(e)[:50]}[/red]")

    if count > 0:
        avg_speed = total_speed / count
        console.print(f"\n   [dim]Average generation speed: {avg_speed:.1f} tokens/sec[/dim]")


def train_tinytalks_gpt(model, dataset, optimizer, criterion, epochs=20, batch_size=32,
                        log_interval=50, test_prompts=None):
    """
    Train the TinyGPT model on TinyTalks dataset.

    Training loop:
    1. Sample random batch of sequences
    2. Forward pass: predict next character for each position
    3. Compute cross-entropy loss
    4. Backward pass: compute gradients
    5. Update parameters with Adam
    6. Periodically test on sample questions to show learning

    Args:
        model: TinyGPT instance
        dataset: TinyTalksDataset instance
        optimizer: Adam optimizer
        criterion: CrossEntropyLoss
        epochs: Number of training epochs
        batch_size: Number of sequences per batch
        log_interval: Print loss every N batches
        test_prompts: Optional list of questions to test during training
    """
    from tinytorch.core.tensor import Tensor

    # Note: Autograd is automatically enabled when tinytorch is imported

    console.print("\n[bold cyan]Starting Training...[/bold cyan]")
    console.print(f"   Epochs: {epochs}")
    console.print(f"   Batch size: {batch_size}")
    console.print(f"   Dataset size: {len(dataset)} sequences")
    console.print(f"   Loss updates: Every {log_interval} batches")
    console.print(f"   Model tests: Every 3 epochs")
    console.print()

    start_time = time.time()

    for epoch in range(epochs):
        epoch_start = time.time()
        epoch_loss = 0.0
        num_batches = 0

        # Calculate batches per epoch
        batches_per_epoch = min(500, len(dataset) // batch_size)

        for batch_idx in range(batches_per_epoch):
            # Sample random batch
            batch_indices = np.random.randint(0, len(dataset), size=batch_size)

            batch_inputs = []
            batch_targets = []

            for idx in batch_indices:
                input_seq, target_seq = dataset[int(idx)]
                batch_inputs.append(input_seq)
                batch_targets.append(target_seq)

            # Convert to tensors: (batch, seq_len)
            batch_input = Tensor(np.array(batch_inputs))
            batch_target = Tensor(np.array(batch_targets))

            # Forward pass
            logits = model.forward(batch_input)

            # Reshape for loss computation: (batch, seq, vocab) → (batch*seq, vocab)
            # IMPORTANT: Use Tensor.reshape() to preserve computation graph!
            batch_size_actual, seq_length, vocab_size = logits.shape
            logits_2d = logits.reshape(batch_size_actual * seq_length, vocab_size)
            targets_1d = batch_target.reshape(-1)

            # Compute loss
            loss = criterion.forward(logits_2d, targets_1d)

            # Backward pass
            loss.backward()

            # Update parameters
            optimizer.step()

            # Zero gradients
            optimizer.zero_grad()

            # Track loss
            batch_loss = float(loss.data)
            epoch_loss += batch_loss
            num_batches += 1

            # Log progress - every log_interval batches AND first batch of each epoch
            if (batch_idx + 1) % log_interval == 0 or batch_idx == 0:
                avg_loss = epoch_loss / num_batches
                elapsed = time.time() - start_time
                progress_pct = ((batch_idx + 1) / batches_per_epoch) * 100
                console.print(
                    f"   Epoch {epoch+1}/{epochs} [{progress_pct:5.1f}%] | "
                    f"Batch {batch_idx+1:3d}/{batches_per_epoch} | "
                    f"Loss: {batch_loss:.4f} | "
                    f"Avg: {avg_loss:.4f} | "
                    f"⏱ {elapsed:.1f}s"
                )
                sys.stdout.flush()  # Force immediate output

        # Epoch summary
        avg_epoch_loss = epoch_loss / num_batches
        epoch_time = time.time() - epoch_start
        console.print(
            f"[green]✓[/green] Epoch {epoch+1}/{epochs} complete | "
            f"Avg Loss: {avg_epoch_loss:.4f} | "
            f"Time: {epoch_time:.1f}s"
        )

        # Test model every 3 epochs to show learning progress
        if (epoch + 1) % 3 == 0 or epoch == 0 or epoch == epochs - 1:
            console.print("\n[bold yellow]📝 Testing model on sample questions...[/bold yellow]")
            test_model_predictions(model, dataset, test_prompts)

    total_time = time.time() - start_time
    console.print(f"\n[bold green]✓ Training complete![/bold green]")
    console.print(f"   Total time: {total_time/60:.2f} minutes")
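# Aside (not part of the original file): a shape-only NumPy sketch of the
# reshape step used before the loss in train_tinytalks_gpt. In the real loop
# the same flattening must go through Tensor.reshape() so autograd can trace it.
def _loss_reshape_shapes():
    import numpy as np
    batch, seq_len, vocab = 2, 4, 68
    logits = np.zeros((batch, seq_len, vocab))           # model output
    targets = np.zeros((batch, seq_len), dtype=np.int64)

    logits_2d = logits.reshape(batch * seq_len, vocab)   # (8, 68)
    targets_1d = targets.reshape(-1)                     # (8,)
    # Cross-entropy then scores every (position, next-character) pair as one
    # classification example.
    return logits_2d.shape, targets_1d.shape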
def demo_questions(model, tokenizer):
    """
    Demonstrate the model answering questions with performance metrics.

    Shows how well the model learned from TinyTalks by asking
    various questions from different difficulty levels.
    Also displays generation performance metrics.
    """
    console.print("\n" + "=" * 70)
    console.print("[bold cyan]🤖 TinyBot Demo: Ask Me Questions![/bold cyan]")
    console.print("=" * 70)

    # Test questions from different levels
    test_questions = [
        "Q: Hello!",
        "Q: What is your name?",
        "Q: What color is the sky?",
        "Q: How many legs does a dog have?",
        "Q: What is 2 plus 3?",
        "Q: What do you use a pen for?",
    ]

    # Track performance across all questions
    all_stats = []

    for question in test_questions:
        console.print(f"\n[yellow]{question}[/yellow]")

        # Generate answer with statistics
        response, stats = model.generate(
            tokenizer,
            prompt=question + "\nA:",
            max_new_tokens=50,
            temperature=0.8,
            return_stats=True
        )

        # Extract just the answer part
        if "\nA:" in response:
            answer = response.split("\nA:")[1].split("\n")[0].strip()
            console.print(f"[green]A: {answer}[/green]")
        else:
            console.print(f"[dim]{response}[/dim]")

        # Display performance metrics
        console.print(
            f"[dim]⚡ {stats['tokens_per_sec']:.1f} tok/s | "
            f"📊 {stats['tokens_generated']} tokens | "
            f"⏱️ {stats['time_sec']:.3f}s[/dim]"
        )

        all_stats.append(stats)

    console.print("\n" + "=" * 70)

    # Display performance summary
    if all_stats:
        avg_tokens_per_sec = np.mean([s['tokens_per_sec'] for s in all_stats])
        avg_time = np.mean([s['time_sec'] for s in all_stats])
        total_tokens = sum([s['tokens_generated'] for s in all_stats])
        total_time = sum([s['time_sec'] for s in all_stats])

        perf_table = Table(title="⚡ Generation Performance Summary", box=box.ROUNDED)
        perf_table.add_column("Metric", style="cyan")
        perf_table.add_column("Value", style="green", justify="right")

        perf_table.add_row("Average Speed", f"{avg_tokens_per_sec:.1f} tokens/sec")
        perf_table.add_row("Average Time/Question", f"{avg_time:.3f} seconds")
        perf_table.add_row("Total Tokens Generated", f"{total_tokens} tokens")
        perf_table.add_row("Total Generation Time", f"{total_time:.2f} seconds")
        perf_table.add_row("Questions Answered", f"{len(test_questions)}")

        console.print(perf_table)
        console.print()

        # Educational note about performance
        console.print("[dim]💡 Note: In Module 14 (KV Caching), you'll learn how to make this 10-15x faster![/dim]")
        console.print("[dim]   Current: ~{:.0f} tok/s → With KV Cache: ~{:.0f} tok/s 🚀[/dim]".format(
            avg_tokens_per_sec, avg_tokens_per_sec * 12
        ))


def main():
    """Main training pipeline"""
    parser = argparse.ArgumentParser(description='Train TinyGPT on TinyTalks Q&A')
    parser.add_argument('--epochs', type=int, default=30, help='Number of training epochs (default: 30)')
    parser.add_argument('--batch-size', type=int, default=16, help='Batch size (default: 16)')
    parser.add_argument('--lr', type=float, default=0.001, help='Learning rate (default: 0.001)')
    parser.add_argument('--seq-length', type=int, default=64, help='Sequence length (default: 64)')
    parser.add_argument('--embed-dim', type=int, default=96, help='Embedding dimension (default: 96, ~500K params)')
    parser.add_argument('--num-layers', type=int, default=4, help='Number of transformer layers (default: 4)')
    parser.add_argument('--num-heads', type=int, default=4, help='Number of attention heads (default: 4)')
    parser.add_argument('--levels', type=str, default=None, help='Difficulty levels to train on (e.g. "1" or "1,2"). Default: all levels')
    args = parser.parse_args()

    # Parse levels argument
    if args.levels:
        levels = [int(l.strip()) for l in args.levels.split(',')]
    else:
        levels = None

    print_banner()

    # Import TinyTorch components
    console.print("\n[bold]Importing TinyTorch components...[/bold]")
    try:
        from tinytorch.core.tensor import Tensor
        from tinytorch.core.optimizers import Adam
        from tinytorch.core.losses import CrossEntropyLoss
        from tinytorch.text.tokenization import CharTokenizer
        console.print("[green]✓[/green] All modules imported successfully!")
    except ImportError as e:
        console.print(f"[red]✗[/red] Import error: {e}")
        console.print("\nMake sure you have completed all required modules:")
        console.print("   - Module 01 (Tensor)")
        console.print("   - Module 02 (Activations)")
        console.print("   - Module 03 (Layers)")
        console.print("   - Module 04 (Losses)")
        console.print("   - Module 05 (Autograd)")
        console.print("   - Module 06 (Optimizers)")
        console.print("   - Module 10 (Tokenization)")
        console.print("   - Module 11 (Embeddings)")
        console.print("   - Module 12 (Attention)")
        console.print("   - Module 13 (Transformers)")
        return

    # Load TinyTalks dataset
    console.print("\n[bold]Loading TinyTalks dataset...[/bold]")
    dataset_path = os.path.join(project_root, "datasets", "tinytalks", "splits", "train.txt")

    if not os.path.exists(dataset_path):
        console.print(f"[red]✗[/red] Dataset not found: {dataset_path}")
        console.print("\nPlease generate the dataset first:")
        console.print("   python datasets/tinytalks/scripts/generate_tinytalks.py")
        return

    with open(dataset_path, 'r', encoding='utf-8') as f:
        text = f.read()

    console.print(f"[green]✓[/green] Loaded dataset from: {os.path.basename(dataset_path)}")
    console.print(f"   File size: {len(text)} characters")

    # Create dataset with level filtering
    dataset = TinyTalksDataset(text, seq_length=args.seq_length, levels=levels)

    # Set test prompts based on levels
    if levels and 1 in levels:
        test_prompts = ["Q: Hello!", "Q: What is your name?", "Q: Hi!"]
    elif levels and 2 in levels:
        test_prompts = ["Q: What color is the sky?", "Q: How many legs does a dog have?"]
    elif levels and 3 in levels:
        test_prompts = ["Q: What is 2 plus 3?", "Q: What is 5 minus 2?"]
    else:
        test_prompts = ["Q: Hello!", "Q: What is your name?", "Q: What color is the sky?"]

    # Initialize model
    console.print("\n[bold]Initializing TinyGPT model...[/bold]")
    model = TinyGPT(
        vocab_size=dataset.tokenizer.vocab_size,
        embed_dim=args.embed_dim,
        num_layers=args.num_layers,
        num_heads=args.num_heads,
        max_seq_len=args.seq_length,
        dropout=0.1
    )

    # Initialize optimizer and loss
    console.print("\n[bold]Initializing training components...[/bold]")
    optimizer = Adam(model.parameters(), lr=args.lr)
    criterion = CrossEntropyLoss()
    console.print(f"[green]✓[/green] Optimizer: Adam (lr={args.lr})")
    console.print(f"[green]✓[/green] Loss: CrossEntropyLoss")

    # Print configuration
    table = Table(title="Training Configuration", box=box.ROUNDED)
    table.add_column("Parameter", style="cyan")
    table.add_column("Value", style="green")

    dataset_desc = f"TinyTalks Level(s) {levels}" if levels else "TinyTalks (All Levels)"
    table.add_row("Dataset", dataset_desc)
    table.add_row("Vocabulary Size", str(dataset.tokenizer.vocab_size))
    table.add_row("Model Parameters", f"{model.count_parameters():,}")
    table.add_row("Epochs", str(args.epochs))
    table.add_row("Batch Size", str(args.batch_size))
    table.add_row("Learning Rate", str(args.lr))
    table.add_row("Sequence Length", str(args.seq_length))
    table.add_row("Embedding Dim", str(args.embed_dim))
    table.add_row("Layers", str(args.num_layers))
    table.add_row("Attention Heads", str(args.num_heads))
    table.add_row("Expected Time", "3-5 minutes")

    console.print(table)

    # Train model
    train_tinytalks_gpt(
        model=model,
        dataset=dataset,
        optimizer=optimizer,
        criterion=criterion,
        epochs=args.epochs,
        batch_size=args.batch_size,
        log_interval=5,  # Log every 5 batches for frequent updates
        test_prompts=test_prompts
    )

    # Demo Q&A
    demo_questions(model, dataset.tokenizer)

    # Success message
    console.print("\n[bold green]🎉 Congratulations![/bold green]")
    console.print("You've successfully trained a transformer to answer questions!")
    console.print("\nYou used:")
    console.print("   ✓ YOUR Tensor implementation (Module 01)")
    console.print("   ✓ YOUR Activations (Module 02)")
    console.print("   ✓ YOUR Linear layers (Module 03)")
    console.print("   ✓ YOUR CrossEntropyLoss (Module 04)")
    console.print("   ✓ YOUR Autograd system (Module 05)")
    console.print("   ✓ YOUR Adam optimizer (Module 06)")
    console.print("   ✓ YOUR CharTokenizer (Module 10)")
    console.print("   ✓ YOUR Embeddings (Module 11)")
    console.print("   ✓ YOUR Multi-Head Attention (Module 12)")
    console.print("   ✓ YOUR Transformer blocks (Module 13)")
    console.print("\n[bold]This is the foundation of ChatGPT, built by YOU from scratch![/bold]")


if __name__ == "__main__":
    main()
@@ -1,498 +0,0 @@
#!/usr/bin/env python3
"""
CodeBot - Python Autocomplete Demo
===================================

Train a transformer to autocomplete Python code in 2 minutes!

Student Journey:
1. Watch it train (2 min)
2. See demo completions (2 min)
3. Try it yourself (5 min)
4. Find its limits (2 min)
5. Teach it new patterns (3 min)
"""

import sys
import time
from pathlib import Path
import numpy as np
from typing import List, Dict, Tuple

# Add TinyTorch to path
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))

import tinytorch as tt
from tinytorch.core.tensor import Tensor
from tinytorch.core.optimizers import Adam
from tinytorch.core.losses import CrossEntropyLoss
from tinytorch.models.transformer import GPT
from tinytorch.text.tokenization import CharTokenizer  # Module 10: Students built this!


# ============================================================================
# Python Code Dataset
# ============================================================================

# Hand-curated 50 simple Python patterns for autocomplete
PYTHON_PATTERNS = [
    # Basic arithmetic functions (10)
    "def add(a, b):\n    return a + b",
    "def subtract(a, b):\n    return a - b",
    "def multiply(x, y):\n    return x * y",
    "def divide(a, b):\n    return a / b",
    "def power(base, exp):\n    return base ** exp",
    "def modulo(a, b):\n    return a % b",
    "def max_of_two(a, b):\n    return a if a > b else b",
    "def min_of_two(a, b):\n    return a if a < b else b",
    "def absolute(x):\n    return x if x >= 0 else -x",
    "def square(x):\n    return x * x",

    # For loops (10)
    "for i in range(10):\n    print(i)",
    "for i in range(5):\n    print(i * 2)",
    "for item in items:\n    print(item)",
    "for i in range(len(arr)):\n    arr[i] = arr[i] * 2",
    "for num in numbers:\n    total += num",
    "for i in range(0, 10, 2):\n    print(i)",
    "for char in text:\n    print(char)",
    "for key in dict:\n    print(key, dict[key])",
    "for i, val in enumerate(items):\n    print(i, val)",
    "for x in range(3):\n    for y in range(3):\n        print(x, y)",

    # If statements (10)
    "if x > 0:\n    print('positive')",
    "if x < 0:\n    print('negative')",
    "if x == 0:\n    print('zero')",
    "if age >= 18:\n    print('adult')",
    "if score > 90:\n    grade = 'A'",
    "if name:\n    print(f'Hello {name}')",
    "if x > 0 and x < 10:\n    print('single digit')",
    "if x == 5 or x == 10:\n    print('five or ten')",
    "if not done:\n    continue_work()",
    "if condition:\n    do_something()\nelse:\n    do_other()",

    # List operations (10)
    "numbers = [1, 2, 3, 4, 5]",
    "squares = [x**2 for x in range(10)]",
    "evens = [n for n in numbers if n % 2 == 0]",
    "first = items[0]",
    "last = items[-1]",
    "items.append(new_item)",
    "items.extend(more_items)",
    "items.remove(old_item)",
    "length = len(items)",
    "sorted_items = sorted(items)",

    # String operations (10)
    "text = 'Hello, World!'",
    "upper = text.upper()",
    "lower = text.lower()",
    "words = text.split()",
    "joined = ' '.join(words)",
    "starts = text.startswith('Hello')",
    "ends = text.endswith('!')",
    "replaced = text.replace('World', 'Python')",
    "stripped = text.strip()",
    "message = f'Hello {name}!'",
]


def create_code_dataset() -> Tuple[List[str], List[str]]:
    """
    Split patterns into train and test sets.

    Returns:
        (train_patterns, test_patterns)
    """
    # Use first 45 for training, last 5 for testing
    train = PYTHON_PATTERNS[:45]
    test = PYTHON_PATTERNS[45:]

    return train, test


# ============================================================================
# Tokenization (Using Student's CharTokenizer from Module 10!)
# ============================================================================

def create_tokenizer(texts: List[str]) -> CharTokenizer:
    """
    Create tokenizer using students' CharTokenizer from Module 10.

    This shows how YOUR tokenizer from Module 10 enables real applications!
    """
    tokenizer = CharTokenizer()
    tokenizer.build_vocab(texts)  # Build vocab from our Python patterns
    return tokenizer
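# Aside (not part of the original file): CharTokenizer above comes from
# Module 10; for readers without TinyTorch, a minimal stand-in with the same
# build_vocab/encode/decode surface might look like this sketch (assumed
# interface, not the actual Module 10 implementation):
class _MiniCharTokenizer:
    def build_vocab(self, texts):
        chars = sorted(set(''.join(texts)))  # one id per unique character
        self.stoi = {ch: i for i, ch in enumerate(chars)}
        self.itos = {i: ch for ch, i in self.stoi.items()}
        self.vocab_size = len(chars)

    def encode(self, text):
        return [self.stoi[ch] for ch in text]

    def decode(self, indices):
        return ''.join(self.itos[i] for i in indices)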
# ============================================================================
# Training
# ============================================================================

def train_codebot(
    model: GPT,
    optimizer: Adam,
    tokenizer: CharTokenizer,
    train_patterns: List[str],
    max_steps: int = 5000,
    seq_length: int = 128,
):
    """Train CodeBot on Python patterns."""

    print("\n" + "="*70)
    print("TRAINING CODEBOT...")
    print("="*70)
    print()
    print(f"Loading training data: {len(train_patterns)} Python code patterns ✓")
    print()
    print(f"Model size: ~{sum(np.prod(p.shape) for p in model.parameters()):,} parameters")
    print(f"Training for ~{max_steps:,} steps (estimated 2 minutes)")
    print()

    # Encode and pad patterns
    train_tokens = []
    for pattern in train_patterns:
        tokens = tokenizer.encode(pattern)
        # Truncate or pad to seq_length
        if len(tokens) > seq_length:
            tokens = tokens[:seq_length]
        else:
            tokens = tokens + [0] * (seq_length - len(tokens))  # Pad with 0
        train_tokens.append(tokens)

    # Loss function
    loss_fn = CrossEntropyLoss()

    # Training loop
    start_time = time.time()
    step = 0
    losses = []

    # Progress markers
    progress_points = [0, 500, 1000, 2000, max_steps]
    messages = [
        "[The model knows nothing yet]",
        "[Learning basic patterns...]",
        "[Getting better at Python syntax...]",
        "[Almost there...]",
        "[Training complete!]"
    ]

    while step <= max_steps:
        # Sample random pattern
        tokens = train_tokens[np.random.randint(len(train_tokens))]

        # Create input/target
        input_seq = tokens[:-1]
        target_seq = tokens[1:]

        # Convert to tensors
        x = Tensor(np.array([input_seq], dtype=np.int32), requires_grad=False)
        y_true = Tensor(np.array([target_seq], dtype=np.int32), requires_grad=False)

        # Forward pass
        logits = model.forward(x)

        # Compute loss
        batch_size = 1
        seq_len = logits.data.shape[1]
        vocab_size = logits.data.shape[2]

        logits_flat = logits.reshape((batch_size * seq_len, vocab_size))
        targets_flat = y_true.reshape((batch_size * seq_len,))

        loss = loss_fn(logits_flat, targets_flat)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()

        # Gradient clipping
        for param in model.parameters():
            if param.grad is not None:
                param.grad = np.clip(param.grad, -1.0, 1.0)

        # Update
        optimizer.step()

        # Track
        losses.append(loss.data.item())

        # Print progress at markers
        if step in progress_points:
            avg_loss = np.mean(losses[-100:]) if losses else loss.data.item()
            elapsed = time.time() - start_time
            msg_idx = progress_points.index(step)
            print(f"Step {step:4d}/{max_steps} | Loss: {avg_loss:.3f} | {messages[msg_idx]}")

        step += 1

        # Time limit
        if time.time() - start_time > 180:  # 3 minutes max
            break

    total_time = time.time() - start_time
    final_loss = np.mean(losses[-100:])
    loss_decrease = ((losses[0] - final_loss) / losses[0]) * 100

    print()
    print(f"✓ CodeBot trained in {int(total_time)} seconds!")
    print(f"✓ Loss decreased by {loss_decrease:.0f}%!")
    print()

    return losses


# ============================================================================
# Code Completion
# ============================================================================

def complete_code(
    model: GPT,
    tokenizer: CharTokenizer,
    partial_code: str,
    max_gen_length: int = 50,
) -> str:
    """
    Complete partial Python code.

    Args:
        model: Trained GPT model
        tokenizer: Tokenizer
        partial_code: Incomplete code
        max_gen_length: Max characters to generate

    Returns:
        Completed code
    """
    tokens = tokenizer.encode(partial_code)

    # Generate
    for _ in range(max_gen_length):
        x = Tensor(np.array([tokens], dtype=np.int32), requires_grad=False)
        logits = model.forward(x)

        # Get next token (greedy)
        next_logits = logits.data[0, -1, :]
        next_token = int(np.argmax(next_logits))

        # Stop at padding (0) or if we've generated enough
        if next_token == 0:
            break

        tokens.append(next_token)

    # Decode
    completed = tokenizer.decode(tokens)

    # Return just the generated part
    return completed[len(partial_code):]
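# Aside (not part of the original file): complete_code above decodes greedily
# (argmax each step), unlike the temperature sampling in the TinyTalks script.
# A NumPy-only sketch of the two decoding rules on one logits vector:
def _decode_rules():
    import numpy as np
    logits = np.array([2.0, 1.0, 0.1])       # scores for 3 candidate tokens

    greedy = int(np.argmax(logits))           # deterministic: always token 0

    temperature = 0.8                         # <1 sharpens, >1 flattens
    scaled = logits / temperature
    probs = np.exp(scaled - np.max(scaled))   # numerically stable softmax
    probs /= probs.sum()
    sampled = int(np.random.choice(len(probs), p=probs))
    return greedy, sampled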
# ============================================================================
# Demo Modes
# ============================================================================

def demo_mode(model: GPT, tokenizer: CharTokenizer):
    """Show 5 demo completions."""

    print("\n" + "="*70)
    print("🎯 DEMO MODE: WATCH CODEBOT AUTOCOMPLETE")
    print("="*70)
    print()
    print("I'll show you 5 examples of what CodeBot learned:")
    print()

    demos = [
        ("def subtract(a, b):\n    return a", "Basic Function"),
        ("for i in range(", "For Loop"),
        ("if x > 0:\n    print(", "If Statement"),
        ("squares = [x**2 for x in ", "List Comprehension"),
        ("def multiply(x, y):\n    return x", "Function Return"),
    ]

    success_count = 0

    for i, (partial, name) in enumerate(demos, 1):
        print(f"Example {i}: {name}")
        print("─" * 70)
        print(f"You type: {partial.replace(chr(10), chr(10) + ' ')}")

        completion = complete_code(model, tokenizer, partial, max_gen_length=30)

        print(f"CodeBot adds: {completion[:50]}...")

        # Simple success check (generated something)
        if completion.strip():
            print("✓ Completion generated")
            success_count += 1
        else:
            print("✗ No completion")

        print("─" * 70)
        print()

    print(f"Demo success rate: {success_count}/5 ({success_count*20}%)")
    if success_count >= 4:
        print("🎉 CodeBot is working great!")
    print()


def interactive_mode(model: GPT, tokenizer: CharTokenizer):
    """Let student try CodeBot."""

    print("\n" + "="*70)
    print("🎮 YOUR TURN: TRY CODEBOT!")
    print("="*70)
    print()
    print("Type partial Python code and see what CodeBot suggests.")
    print("Type 'demo' to see examples, 'quit' to exit.")
    print()

    examples = [
        "def add(a, b):\n    return a",
        "for i in range(",
        "if name:\n    print(",
        "numbers = [1, 2, 3]",
    ]

    while True:
        try:
            user_input = input("\nCodeBot> ").strip()

            if not user_input:
                continue

            if user_input.lower() == 'quit':
                print("\n👋 Thanks for trying CodeBot!")
                break

            if user_input.lower() == 'demo':
                print("\nTry these examples:")
                for ex in examples:
                    print(f"  → {ex[:40]}...")
                continue

            # Complete the code
            print()
            completion = complete_code(model, tokenizer, user_input, max_gen_length=50)

            if completion.strip():
                print(f"🤖 CodeBot suggests: {completion}")
                print()
                print(f"Full code:")
                print(user_input + completion)
            else:
                print("⚠️ CodeBot couldn't complete this (maybe it wasn't trained on this pattern?)")

        except KeyboardInterrupt:
            print("\n\n👋 Interrupted. Thanks for trying CodeBot!")
            break
        except Exception as e:
            print(f"\n❌ Error: {e}")


# ============================================================================
# Main
# ============================================================================

def main():
    """Run CodeBot autocomplete demo."""

    print("\n" + "="*70)
    print("🤖 CODEBOT - BUILD YOUR OWN MINI-COPILOT!")
    print("="*70)
    print()
    print("You're about to train a transformer to autocomplete Python code.")
    print()
    print("In 2 minutes, you'll have a working autocomplete that learned:")
    print("  • Basic functions (add, multiply, divide)")
    print("  • For loops and while loops")
    print("  • If statements and conditionals")
    print("  • List operations")
    print("  • Common Python patterns")
    print()
    input("Press ENTER to begin training...")

    # Create dataset
    train_patterns, test_patterns = create_code_dataset()

    # Create tokenizer
    all_patterns = train_patterns + test_patterns
    tokenizer = create_tokenizer(all_patterns)

    # Model config (based on proven sweep results)
    config = {
        'vocab_size': tokenizer.vocab_size,
        'embed_dim': 32,      # Scaled from winning 16d config
        'num_layers': 2,      # Enough for code patterns
        'num_heads': 8,       # Proven winner from sweep
        'max_seq_len': 128,   # Enough for code snippets
    }

    # Create model
    model = GPT(
        vocab_size=config['vocab_size'],
        embed_dim=config['embed_dim'],
        num_layers=config['num_layers'],
        num_heads=config['num_heads'],
        max_seq_len=config['max_seq_len'],
    )

    # Optimizer (proven winning LR)
    learning_rate = 0.0015
    optimizer = Adam(model.parameters(), lr=learning_rate)

    # Train
    losses = train_codebot(
        model=model,
        optimizer=optimizer,
        tokenizer=tokenizer,
        train_patterns=train_patterns,
        max_steps=5000,
        seq_length=config['max_seq_len'],
    )

    print("Ready to test CodeBot!")
    input("Press ENTER to see demo...")

    # Demo mode
    demo_mode(model, tokenizer)

    input("Press ENTER to try it yourself...")

    # Interactive mode
    interactive_mode(model, tokenizer)

    # Summary
    print("\n" + "="*70)
    print("🎓 WHAT YOU LEARNED")
    print("="*70)
    print()
    print("Congratulations! You just:")
    print("  ✓ Trained a transformer from scratch")
    print("  ✓ Saw it learn Python patterns in ~2 minutes")
    print("  ✓ Used it to autocomplete code")
    print("  ✓ Understood its limits (pattern matching, not reasoning)")
    print()
    print("KEY INSIGHTS:")
    print("  1. Transformers learn by pattern matching")
    print("  2. More training data → smarter completions")
    print("  3. They don't 'understand' - they predict patterns")
    print("  4. Real Copilot = same idea, billions more patterns!")
    print()
    print("SCALING PATH:")
    print("  • Your CodeBot: 45 patterns → simple completions")
    print("  • Medium model: 10,000 patterns → decent autocomplete")
    print("  • GitHub Copilot: BILLIONS of patterns → production-ready!")
    print()
    print("Great job! You're now a transformer trainer! 🎉")
    print("="*70)


if __name__ == '__main__':
    main()
@@ -1,481 +0,0 @@
#!/usr/bin/env python3
"""
TinyTalks Quick Demo - Watch Your Transformer Learn to Talk!
=============================================================

A fast, visual demonstration of transformer training.
See the model go from gibberish to coherent answers in ~2 minutes!

Features:
- Smaller model (~50K params) for fast training
- Live dashboard showing training progress
- Rotating prompts to show diverse capabilities
- Learning progression display (gibberish -> coherent)
"""

import sys
import os
import time
import numpy as np
from pathlib import Path

# Add project root to path
sys.path.insert(0, str(Path(__file__).parent.parent.parent))

# Rich for live dashboard
from rich.console import Console
from rich.layout import Layout
from rich.panel import Panel
from rich.table import Table
from rich.live import Live
from rich.text import Text
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn
from rich import box

# TinyTorch imports
from tinytorch.core.tensor import Tensor
from tinytorch.core.optimizers import Adam
from tinytorch.core.losses import CrossEntropyLoss
from tinytorch.models.transformer import GPT
from tinytorch.text.tokenization import CharTokenizer

console = Console()

# =============================================================================
# Configuration - Optimized for ~2 minute training
# =============================================================================

CONFIG = {
    # Model (smaller for speed)
    "n_layer": 2,
    "n_head": 2,
    "n_embd": 64,
    "max_seq_len": 32,        # Shorter sequences for speed

    # Training (optimized for ~2 min on pure Python)
    "epochs": 8,
    "batches_per_epoch": 30,
    "batch_size": 8,
    "learning_rate": 0.003,   # Balanced LR for stable convergence

    # Display
    "update_interval": 5,     # Update dashboard every N batches
}

# Test prompts to show model learning (3 prompts for better progression display)
TEST_PROMPTS = [
    "Q: What is 2+2?\nA:",
    "Q: What color is the sky?\nA:",
    "Q: Say hello\nA:",
]


# =============================================================================
# Dataset
# =============================================================================

class TinyTalksDataset:
    """Simple character-level dataset from TinyTalks."""

    def __init__(self, data_path: Path, seq_len: int):
        self.seq_len = seq_len

        # Load text
        with open(data_path, 'r') as f:
            self.text = f.read()

        # Create tokenizer and build vocabulary
        self.tokenizer = CharTokenizer()
        self.tokenizer.build_vocab([self.text])

        # Tokenize entire text
        self.tokens = self.tokenizer.encode(self.text)

    def __len__(self):
        return len(self.tokens) - self.seq_len

    def get_batch(self, batch_size: int):
        """Get random batch of sequences."""
        indices = np.random.randint(0, len(self) - 1, size=batch_size)

        inputs = []
        targets = []

        for idx in indices:
            seq = self.tokens[idx:idx + self.seq_len + 1]
            inputs.append(seq[:-1])
            targets.append(seq[1:])

        return (
            Tensor(np.array(inputs)),
            Tensor(np.array(targets))
        )


# =============================================================================
# Text Generation
# =============================================================================

def generate_response(model, tokenizer, prompt: str, max_tokens: int = 30) -> str:
    """Generate text from prompt."""
    # Encode prompt
    tokens = tokenizer.encode(prompt)

    for _ in range(max_tokens):
        # Prepare input
        context = tokens[-CONFIG["max_seq_len"]:]
        x = Tensor(np.array([context]))

        # Forward pass
        logits = model.forward(x)

        # Get next token probabilities
        last_logits = logits.data[0, -1, :]

        # Temperature sampling
        temperature = 0.8
        last_logits = last_logits / temperature
        exp_logits = np.exp(last_logits - np.max(last_logits))
        probs = exp_logits / np.sum(exp_logits)

        # Sample
        next_token = np.random.choice(len(probs), p=probs)
        tokens.append(next_token)

        # Stop at newline (end of answer)
        if tokenizer.decode([next_token]) == '\n':
            break

    # Decode and extract answer
    full_text = tokenizer.decode(tokens)

    # Get just the answer part
    if "A:" in full_text:
        answer = full_text.split("A:")[-1].strip()
        # Clean up - take first line
        answer = answer.split('\n')[0].strip()
        return answer if answer else "(empty)"

    return full_text[len(prompt):].strip() or "(empty)"
# =============================================================================
# Dashboard Layout
# =============================================================================

def make_layout() -> Layout:
    """Create the dashboard layout."""
    layout = Layout()

    layout.split_column(
        Layout(name="header", size=3),
        Layout(name="main", ratio=1),
        Layout(name="footer", size=3),
    )

    layout["main"].split_row(
        Layout(name="left", ratio=1),
        Layout(name="outputs", ratio=2),
    )

    layout["left"].split_column(
        Layout(name="progress", ratio=2),
        Layout(name="stats", ratio=1),
    )

    return layout


def make_header() -> Panel:
    """Create header panel."""
    return Panel(
        Text("TinyTalks Quick Demo - Watch Your Transformer Learn!",
             style="bold cyan", justify="center"),
        box=box.ROUNDED,
        style="cyan",
    )


def make_progress_panel(epoch: int, total_epochs: int, batch: int,
                        total_batches: int, loss: float, elapsed: float) -> Panel:
    """Create training progress panel."""
    # Calculate overall progress
    total_steps = total_epochs * total_batches
    current_step = (epoch - 1) * total_batches + batch
    progress_pct = (current_step / total_steps) * 100

    # Progress bar
    bar_width = 20
    filled = int(bar_width * progress_pct / 100)
    bar = "█" * filled + "░" * (bar_width - filled)

    # Estimate time remaining
    if current_step > 0:
        time_per_step = elapsed / current_step
        remaining_steps = total_steps - current_step
        eta = remaining_steps * time_per_step
        eta_str = f"{eta:.0f}s"
    else:
        eta_str = "..."

    content = Text()
    content.append(f"Epoch: {epoch}/{total_epochs}\n", style="bold")
    content.append(f"Batch: {batch}/{total_batches}\n")
    content.append(f"Loss: {loss:.3f}\n\n", style="yellow")
    content.append(f"{bar} {progress_pct:.0f}%\n\n", style="green")
    content.append(f"Elapsed: {elapsed:.0f}s\n")
    content.append(f"ETA: {eta_str}")

    return Panel(
        content,
        title="[bold]Training Progress[/bold]",
        border_style="green",
        box=box.ROUNDED,
    )


def make_outputs_panel(responses: dict, epoch: int) -> Panel:
    """Create model outputs panel showing all epoch responses as a log."""
    content = Text()

    # Show all 3 prompts with full epoch history
    for i, prompt in enumerate(TEST_PROMPTS):
        q = prompt.split('\n')[0]
        content.append(f"{q}\n", style="cyan bold")

        # Show all epochs completed so far
        for ep in range(1, epoch + 1):
            key = f"epoch_{ep}_{i}"
            response = responses.get(key, "...")
            # Most recent epoch is highlighted
            style = "white" if ep == epoch else "dim"
            content.append(f" Ep{ep}: ", style="yellow")
            # Truncate long responses to fit
            display_response = response[:25] + "..." if len(response) > 25 else response
            content.append(f"{display_response}\n", style=style)

        content.append("\n")

    return Panel(
        content,
        title=f"[bold]Learning Progression (Epoch {epoch})[/bold]",
        border_style="blue",
        box=box.ROUNDED,
    )


def make_stats_panel(stats: dict) -> Panel:
    """Create systems stats panel."""
    content = Text()

    content.append("Performance Metrics\n", style="bold")
    content.append(f" Tokens/sec: {stats.get('tokens_per_sec', 0):.1f}\n")
    content.append(f" Batch time: {stats.get('batch_time_ms', 0):.0f}ms\n")
    content.append(f" Memory: {stats.get('memory_mb', 0):.1f}MB\n\n")

    content.append("Model Stats\n", style="bold")
    content.append(f" Parameters: {stats.get('params', 0):,}\n")
    content.append(f" Vocab size: {stats.get('vocab_size', 0)}\n")

    return Panel(
        content,
        title="[bold]Systems[/bold]",
        border_style="magenta",
        box=box.ROUNDED,
    )


def make_footer(message: str = "") -> Panel:
    """Create footer panel."""
    if not message:
        message = "Training in progress... Watch the model learn to answer questions!"

    return Panel(
        Text(message, style="dim", justify="center"),
        box=box.ROUNDED,
        style="dim",
    )


# =============================================================================
# Main Training Loop
# =============================================================================

def main():
    """Main training function with live dashboard."""

    # Welcome
    console.print()
    console.print(Panel.fit(
        "[bold cyan]TinyTalks Quick Demo[/bold cyan]\n\n"
        "Watch a transformer learn to answer questions in real-time!\n"
        "The model starts with random weights (gibberish output)\n"
        "and learns to produce coherent answers.\n\n"
        "[dim]Training time: ~2 minutes[/dim]",
        title="Welcome",
        border_style="cyan",
    ))
    console.print()

    # Load dataset
    project_root = Path(__file__).parent.parent.parent
    data_path = project_root / "datasets" / "tinytalks" / "splits" / "train.txt"

    if not data_path.exists():
        console.print(f"[red]Error: Dataset not found at {data_path}[/red]")
        console.print("[yellow]Please ensure TinyTalks dataset is available.[/yellow]")
        return

    console.print(f"[dim]Loading dataset from {data_path}...[/dim]")
    dataset = TinyTalksDataset(data_path, CONFIG["max_seq_len"])
    console.print(f"[green]✓[/green] Loaded {len(dataset.text):,} characters, vocab size: {dataset.tokenizer.vocab_size}")

    # Create model
    console.print("[dim]Creating model...[/dim]")
    model = GPT(
        vocab_size=dataset.tokenizer.vocab_size,
        embed_dim=CONFIG["n_embd"],
        num_heads=CONFIG["n_head"],
        num_layers=CONFIG["n_layer"],
        max_seq_len=CONFIG["max_seq_len"],
    )

    # Count parameters
    param_count = sum(p.data.size for p in model.parameters())
    console.print(f"[green]✓[/green] Model created: {param_count:,} parameters")
    console.print(f"[dim] {CONFIG['n_layer']} layers, {CONFIG['n_head']} heads, {CONFIG['n_embd']} embed dim[/dim]")

    # Setup training
    optimizer = Adam(model.parameters(), lr=CONFIG["learning_rate"])
    criterion = CrossEntropyLoss()

    console.print()
    console.print("[bold green]Starting training with live dashboard...[/bold green]")
    console.print("[dim]Press Ctrl+C to stop early[/dim]")
    console.print()
    time.sleep(1)

    # Storage for responses and stats
    responses = {}
    stats = {
        "params": param_count,
        "vocab_size": dataset.tokenizer.vocab_size,
        "tokens_per_sec": 0,
        "batch_time_ms": 0,
        "memory_mb": param_count * 4 / (1024 * 1024),  # Rough estimate
    }

    # Create layout
    layout = make_layout()

    # Training loop with live display
    start_time = time.time()
    current_loss = 0.0
    total_tokens = 0

    try:
        with Live(layout, console=console, refresh_per_second=4) as live:
            for epoch in range(1, CONFIG["epochs"] + 1):
                epoch_loss = 0.0

                for batch_idx in range(1, CONFIG["batches_per_epoch"] + 1):
                    batch_start = time.time()

                    # Get batch
                    inputs, targets = dataset.get_batch(CONFIG["batch_size"])

                    # Forward pass
                    logits = model.forward(inputs)

                    # Reshape for loss
                    batch_size, seq_len, vocab_size = logits.shape
                    logits_flat = logits.reshape(batch_size * seq_len, vocab_size)
                    targets_flat = targets.reshape(-1)

                    # Compute loss
                    loss = criterion(logits_flat, targets_flat)

                    # Backward pass
                    loss.backward()

                    # Update
                    optimizer.step()
                    optimizer.zero_grad()

                    # Track loss and stats
                    batch_loss = float(loss.data)
                    epoch_loss += batch_loss
                    current_loss = epoch_loss / batch_idx

                    # Update systems stats
                    batch_time = time.time() - batch_start
                    tokens_in_batch = batch_size * seq_len
                    total_tokens += tokens_in_batch
                    elapsed = time.time() - start_time

                    stats["batch_time_ms"] = batch_time * 1000
                    stats["tokens_per_sec"] = total_tokens / elapsed if elapsed > 0 else 0

                    # Update dashboard
                    layout["header"].update(make_header())
                    layout["progress"].update(make_progress_panel(
                        epoch, CONFIG["epochs"],
                        batch_idx, CONFIG["batches_per_epoch"],
                        current_loss, elapsed
                    ))
                    layout["stats"].update(make_stats_panel(stats))
                    layout["outputs"].update(make_outputs_panel(responses, epoch))
                    layout["footer"].update(make_footer())

                # End of epoch - generate sample responses
                for i, prompt in enumerate(TEST_PROMPTS):
                    response = generate_response(model, dataset.tokenizer, prompt)
                    responses[f"epoch_{epoch}_{i}"] = response

                # Update display with new responses
                layout["outputs"].update(make_outputs_panel(responses, epoch))

                # Show epoch completion message
                layout["footer"].update(make_footer(
                    f"Epoch {epoch} complete! Loss: {current_loss:.3f}"
                ))

        # Training complete
        total_time = time.time() - start_time

        console.print()
        console.print(Panel.fit(
            f"[bold green]Training Complete![/bold green]\n\n"
            f"Total time: {total_time:.1f} seconds\n"
            f"Final loss: {current_loss:.3f}\n"
            f"Epochs: {CONFIG['epochs']}\n\n"
            "[cyan]Watch how your transformer learned to talk![/cyan]",
            title="Success",
            border_style="green",
        ))

        # Show learning progression for all prompts
        console.print()
        console.print("[bold]Full Learning Progression:[/bold]")
        console.print()

        for i, prompt in enumerate(TEST_PROMPTS):
            q = prompt.split('\n')[0]
            table = Table(box=box.ROUNDED, title=q)
            table.add_column("Epoch", style="cyan")
            table.add_column("Response", style="white")

            for epoch in range(1, CONFIG["epochs"] + 1):
                key = f"epoch_{epoch}_{i}"
                resp = responses.get(key, "...")
                table.add_row(str(epoch), resp)

            console.print(table)
            console.print()

    except KeyboardInterrupt:
        console.print("\n[yellow]Training stopped by user[/yellow]")


if __name__ == "__main__":
    main()
567 milestones/06_2018_mlperf/01_optimization_olympics.py Normal file
@@ -0,0 +1,567 @@
#!/usr/bin/env python3
"""
╔══════════════════════════════════════════════════════════════════════════════╗
║  🏆 MILESTONE 06: The Optimization Olympics (MLPerf 2018)                      ║
║  Compress and Accelerate Your Neural Network                                   ║
╚══════════════════════════════════════════════════════════════════════════════╝

Historical Context:
    In 2018, MLPerf was launched to standardize ML benchmarking. The key insight:
    It's not just about accuracy - production ML needs efficiency too.

🎯 WHAT YOU'LL LEARN:
    1. How to PROFILE your model (parameters, size, speed)
    2. How to QUANTIZE (FP32 → INT8 = 4× smaller)
    3. How to PRUNE (remove small weights = 2-4× smaller)
    4. How to measure the TRADEOFFS (accuracy vs efficiency)

🏗️ THE OPTIMIZATION PIPELINE:
┌─────────────────────────────────────────────────────────────────────────┐
│                           YOUR TRAINED MODEL                            │
│                       Accurate but large and slow                       │
└───────────────────────────────┬─────────────────────────────────────────┘
                                │
┌───────────────────────────────▼─────────────────────────────────────────┐
│                            STEP 1: PROFILE                              │
│                    Count parameters, measure latency                    │
└───────────────────────────────┬─────────────────────────────────────────┘
                                │
┌───────────────────────────────▼─────────────────────────────────────────┐
│                            STEP 2: QUANTIZE                             │
│                       FP32 → INT8 (4× compression)                      │
└───────────────────────────────┬─────────────────────────────────────────┘
                                │
┌───────────────────────────────▼─────────────────────────────────────────┐
│                             STEP 3: PRUNE                               │
│                 Remove small weights (2-4× compression)                 │
└───────────────────────────────┬─────────────────────────────────────────┘
                                │
┌───────────────────────────────▼─────────────────────────────────────────┐
│                           OPTIMIZED MODEL 🎉                            │
│                  8-16× smaller, minimal accuracy loss                   │
└─────────────────────────────────────────────────────────────────────────┘

✅ REQUIRED MODULES (Run after Module 16):
    Module 14 (Profiling)    : YOUR profiling tools
    Module 15 (Quantization) : YOUR quantization implementation
    Module 16 (Compression)  : YOUR pruning techniques

📊 EXPECTED RESULTS:
| Optimization   | Size    | Accuracy | Notes                    |
|----------------|---------|----------|--------------------------|
| Baseline       | 100%    | 85-90%   | Full precision           |
| + Quantization | 25%     | 84-89%   | INT8 weights             |
| + Pruning      | 12.5%   | 82-87%   | 50% weights removed      |
| Combined       | ~10%    | 80-85%   | Production ready!        |
"""

import sys
import os
import time
import numpy as np
from pathlib import Path

# Add project root
sys.path.insert(0, os.getcwd())

from rich.console import Console
from rich.panel import Panel
from rich.table import Table
from rich.progress import Progress, SpinnerColumn, TextColumn
from rich import box

console = Console()

# ============================================================================
# SIMPLE MLP FOR DEMONSTRATION
# ============================================================================

class SimpleMLP:
    """Simple MLP for digit classification - the optimization target."""

    def __init__(self, input_size=64, hidden_size=32, num_classes=10):
        from tinytorch.core.tensor import Tensor
        from tinytorch.core.layers import Linear
        from tinytorch.core.activations import ReLU

        self.fc1 = Linear(input_size, hidden_size)
        self.relu = ReLU()
        self.fc2 = Linear(hidden_size, num_classes)

        # Store weight references for optimization
        self.layers = [self.fc1, self.fc2]

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

    def __call__(self, x):
        return self.forward(x)

    def parameters(self):
        params = []
        for layer in self.layers:
            # Check both 'weights' and 'weight' (different naming conventions)
            if hasattr(layer, 'weights'):
                params.append(layer.weights)
            elif hasattr(layer, 'weight'):
                params.append(layer.weight)
            if hasattr(layer, 'bias') and layer.bias is not None:
                params.append(layer.bias)
        return params

    def get_weights(self):
        """Get all weights as a list of numpy arrays."""
        weights = []
        for layer in self.layers:
            if hasattr(layer, 'weights'):
                weights.append(layer.weights.data.copy())
            elif hasattr(layer, 'weight'):
                weights.append(layer.weight.data.copy())
        return weights

    def set_weights(self, weights):
        """Set all weights from a list of numpy arrays."""
        idx = 0
        for layer in self.layers:
            if hasattr(layer, 'weights'):
                layer.weights.data = weights[idx].copy()
                idx += 1
            elif hasattr(layer, 'weight'):
                layer.weight.data = weights[idx].copy()
                idx += 1


# ============================================================================
# OPTIMIZATION FUNCTIONS
# ============================================================================

def count_parameters(model):
    """Count total parameters in model."""
    total = 0
    for param in model.parameters():
        total += param.data.size
    return total


def model_size_bytes(model):
    """Calculate model size in bytes (FP32)."""
    return count_parameters(model) * 4  # 4 bytes per float32
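
# Worked example (sketch): for SimpleMLP(input_size=64, hidden_size=32,
# num_classes=10), assuming each Linear stores a weight matrix plus a bias:
#   fc1: 64 * 32 + 32 = 2,080 parameters
#   fc2: 32 * 10 + 10 =   330 parameters
#   total: 2,410 parameters -> 2,410 * 4 = 9,640 bytes at FP32
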
def quantize_weights(weights, bits=8):
    """
    Quantize weights to lower precision.

    FP32 (4 bytes) → INT8 (1 byte) = 4× compression
    """
    quantized = []
    for w in weights:
        # Simple post-training quantization
        w_min, w_max = w.min(), w.max()
        scale = (w_max - w_min) / (2**bits - 1)
        if scale == 0:
            scale = 1.0

        # Quantize (uint8 holds the full 0..255 range; int8 would overflow above 127)
        w_int = np.round((w - w_min) / scale).astype(np.uint8)

        # Dequantize for inference
        w_dequant = w_int.astype(np.float32) * scale + w_min
        quantized.append(w_dequant)

    return quantized
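
# Illustrative sketch (not part of the milestone script, never called): a quick
# check of the INT8 scheme above. Each dequantized weight should land within
# half a quantization step (scale / 2) of the original value.
def _quantization_sketch():
    w = np.random.randn(4, 4).astype(np.float32)
    (w_q,) = quantize_weights([w], bits=8)
    step = (w.max() - w.min()) / 255  # the 'scale' computed in quantize_weights
    assert np.max(np.abs(w - w_q)) <= step / 2 + 1e-6
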
def prune_weights(weights, sparsity=0.5):
    """
    Prune weights by setting smallest magnitudes to zero.

    Sparsity 0.5 = 50% of weights are zeros = 2× compression potential
    """
    pruned = []
    for w in weights:
        # Find threshold
        flat = np.abs(w.flatten())
        threshold = np.percentile(flat, sparsity * 100)

        # Prune
        mask = np.abs(w) > threshold
        w_pruned = w * mask
        pruned.append(w_pruned)

    return pruned
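
# Illustrative sketch (not part of the milestone script, never called): prune a
# toy matrix and confirm that roughly half its entries become zero. Ties at the
# magnitude threshold are pruned too, so sparsity can land slightly above 0.5.
def _pruning_sketch():
    w = np.arange(-8.0, 8.0, dtype=np.float32).reshape(4, 4)
    (w_p,) = prune_weights([w], sparsity=0.5)
    assert np.mean(w_p == 0) >= 0.5
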
def evaluate_accuracy(model, X, y):
    """Evaluate model accuracy."""
    from tinytorch.core.tensor import Tensor

    logits = model(Tensor(X))
    preds = np.argmax(logits.data, axis=1)
    accuracy = np.mean(preds == y.flatten()) * 100
    return accuracy


def measure_latency(model, X, n_runs=100):
    """Measure average inference latency."""
    from tinytorch.core.tensor import Tensor

    # Warmup
    for _ in range(10):
        _ = model(Tensor(X[:1]))

    # Measure
    times = []
    for _ in range(n_runs):
        start = time.perf_counter()
        _ = model(Tensor(X[:1]))
        end = time.perf_counter()
        times.append(end - start)

    return np.mean(times) * 1000  # ms
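
# Usage sketch: latency is measured on the trained model the same way accuracy is,
#   latency_ms = measure_latency(model, X_test, n_runs=100)
# The warmup loop above matters: the first calls pay one-time costs (allocations,
# caches) that would otherwise inflate the average.
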
# ============================================================================
# MAIN MILESTONE
# ============================================================================

def load_tinydigits():
    """Load TinyDigits dataset (bundled with TinyTorch)."""
    import pickle

    # Find dataset
    project_root = Path(__file__).parent.parent.parent
    train_path = project_root / "datasets" / "tinydigits" / "train.pkl"
    test_path = project_root / "datasets" / "tinydigits" / "test.pkl"

    if train_path.exists() and test_path.exists():
        with open(train_path, 'rb') as f:
            train_data = pickle.load(f)
        with open(test_path, 'rb') as f:
            test_data = pickle.load(f)

        X_train = train_data['images'].astype(np.float32)
        y_train = train_data['labels'].reshape(-1, 1)
        X_test = test_data['images'].astype(np.float32)
        y_test = test_data['labels'].reshape(-1, 1)

        # Flatten images if they're 2D (8x8) - MLP needs flat input
        if len(X_train.shape) == 3:
            X_train = X_train.reshape(X_train.shape[0], -1)
            X_test = X_test.reshape(X_test.shape[0], -1)

        return X_train, y_train, X_test, y_test

    # Fallback: try alternative path or generate synthetic
    alt_path = project_root / "datasets" / "tinydigits" / "tinydigits.pkl"
    if alt_path.exists():
        with open(alt_path, 'rb') as f:
            data = pickle.load(f)
        return data['X_train'], data['y_train'], data['X_test'], data['y_test']

    console.print("[yellow]Dataset not found, using synthetic data for demo...[/yellow]")

    # Generate synthetic digit-like data
    np.random.seed(42)
    X_train = np.random.randn(1000, 64).astype(np.float32)
    y_train = np.random.randint(0, 10, size=(1000, 1))
    X_test = np.random.randn(200, 64).astype(np.float32)
    y_test = np.random.randint(0, 10, size=(200, 1))
    return X_train, y_train, X_test, y_test
def train_baseline(model, X_train, y_train, epochs=10, lr=0.01):
    """Quick training of baseline model."""
    from tinytorch.core.tensor import Tensor
    from tinytorch.core.losses import CrossEntropyLoss
    from tinytorch.core.optimizers import SGD

    loss_fn = CrossEntropyLoss()
    optimizer = SGD(model.parameters(), lr=lr)

    for param in model.parameters():
        param.requires_grad = True

    batch_size = 32
    n_batches = len(X_train) // batch_size

    for epoch in range(epochs):
        indices = np.random.permutation(len(X_train))

        for i in range(n_batches):
            batch_idx = indices[i*batch_size:(i+1)*batch_size]
            X_batch = Tensor(X_train[batch_idx])
            y_batch = Tensor(y_train[batch_idx].flatten())

            logits = model(X_batch)
            loss = loss_fn(logits, y_batch)

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
def main():
    """Run the Optimization Olympics!"""

    # Welcome
    console.print()
    console.print(Panel(
        "[bold cyan]🏆 THE OPTIMIZATION OLYMPICS[/bold cyan]\n\n"
        "[dim]MLPerf 2018: Where accuracy meets efficiency[/dim]\n\n"
        "Today you'll learn to compress a neural network\n"
        "while preserving accuracy - just like real production ML!",
        title="Milestone 06: MLPerf",
        border_style="cyan",
        box=box.DOUBLE
    ))
    console.print()

    # ========================================================================
    # STEP 1: BASELINE
    # ========================================================================

    console.print(Panel(
        "[bold yellow]📊 STEP 1: Establish Baseline[/bold yellow]\n"
        "Train a model and measure its performance",
        border_style="yellow"
    ))

    # Load data
    console.print("[dim]Loading TinyDigits dataset...[/dim]")
    X_train, y_train, X_test, y_test = load_tinydigits()
    console.print(f" ✓ Training: {len(X_train)} samples")
    console.print(f" ✓ Test: {len(X_test)} samples")
    console.print()

    # Train baseline
    console.print("[dim]Training baseline MLP...[/dim]")
    model = SimpleMLP(input_size=64, hidden_size=32, num_classes=10)

    with Progress(SpinnerColumn(), TextColumn("[progress.description]{task.description}"), console=console) as progress:
        task = progress.add_task("Training...", total=None)
        train_baseline(model, X_train, y_train, epochs=15)

    # Baseline metrics
    baseline_params = count_parameters(model)
    baseline_size = model_size_bytes(model)
    baseline_acc = evaluate_accuracy(model, X_test, y_test)
    baseline_latency = measure_latency(model, X_test)
    baseline_weights = model.get_weights()

    # Show baseline
    table = Table(title="📊 Baseline Model", box=box.ROUNDED)
    table.add_column("Metric", style="cyan")
    table.add_column("Value", style="yellow")
    table.add_column("Notes", style="dim")

    table.add_row("Parameters", f"{baseline_params:,}", "Total trainable weights")
    table.add_row("Size", f"{baseline_size:,} bytes", "FP32 precision")
    table.add_row("Accuracy", f"{baseline_acc:.1f}%", "Test set performance")
    table.add_row("Latency", f"{baseline_latency:.3f} ms", "Per-sample inference")

    console.print(table)
    console.print()

    # ========================================================================
    # STEP 2: QUANTIZATION
    # ========================================================================

    console.print(Panel(
        "[bold blue]🗜️ STEP 2: Quantization (FP32 → INT8)[/bold blue]\n"
        "Reduce precision: 4 bytes → 1 byte = 4× smaller",
        border_style="blue"
    ))

    # Apply quantization
    quantized_weights = quantize_weights(baseline_weights, bits=8)
    model.set_weights(quantized_weights)

    quant_size = baseline_size // 4  # INT8 is 4× smaller
    quant_acc = evaluate_accuracy(model, X_test, y_test)
    quant_latency = measure_latency(model, X_test)

    # Show quantization results
    table = Table(title="🗜️ After Quantization", box=box.ROUNDED)
    table.add_column("Metric", style="cyan")
    table.add_column("Before", style="yellow")
    table.add_column("After", style="green")
    table.add_column("Change", style="bold")

    table.add_row(
        "Size",
        f"{baseline_size:,} B",
        f"{quant_size:,} B",
        "[green]4× smaller[/green]"
    )
    table.add_row(
        "Accuracy",
        f"{baseline_acc:.1f}%",
        f"{quant_acc:.1f}%",
        f"[{'green' if abs(baseline_acc - quant_acc) < 2 else 'yellow'}]{baseline_acc - quant_acc:+.1f}%[/]"
    )

    console.print(table)
    console.print()
    # ========================================================================
    # STEP 3: PRUNING
    # ========================================================================

    console.print(Panel(
        "[bold magenta]✂️ STEP 3: Pruning (Remove Small Weights)[/bold magenta]\n"
        "Set 50% of smallest weights to zero = 2× compression potential",
        border_style="magenta"
    ))

    # Apply pruning
    pruned_weights = prune_weights(baseline_weights, sparsity=0.5)
    model.set_weights(pruned_weights)

    # Count zeros
    total_weights = sum(w.size for w in pruned_weights)
    zero_weights = sum(np.sum(w == 0) for w in pruned_weights)
    sparsity = (zero_weights / total_weights * 100) if total_weights > 0 else 0

    pruned_acc = evaluate_accuracy(model, X_test, y_test)

    # Show pruning results
    table = Table(title="✂️ After Pruning", box=box.ROUNDED)
    table.add_column("Metric", style="cyan")
    table.add_column("Before", style="yellow")
    table.add_column("After", style="green")
    table.add_column("Change", style="bold")

    table.add_row(
        "Non-zero weights",
        f"{total_weights:,}",
        f"{total_weights - zero_weights:,}",
        f"[green]{sparsity:.0f}% pruned[/green]"
    )
    table.add_row(
        "Accuracy",
        f"{baseline_acc:.1f}%",
        f"{pruned_acc:.1f}%",
        f"[{'green' if abs(baseline_acc - pruned_acc) < 5 else 'yellow'}]{baseline_acc - pruned_acc:+.1f}%[/]"
    )

    console.print(table)
    console.print()

    # ========================================================================
    # STEP 4: COMBINED
    # ========================================================================

    console.print(Panel(
        "[bold green]🎯 STEP 4: Combined Optimization[/bold green]\n"
        "Apply BOTH quantization AND pruning",
        border_style="green"
    ))

    # Apply both
    combined_weights = prune_weights(baseline_weights, sparsity=0.5)
    combined_weights = quantize_weights(combined_weights, bits=8)
    model.set_weights(combined_weights)

    combined_size = quant_size  # Still quantized
    combined_acc = evaluate_accuracy(model, X_test, y_test)

    # Calculate effective compression (quantization + sparsity)
    effective_compression = 4 * 2  # 4× from quantization, potential 2× from sparsity

    console.print()

    # ========================================================================
    # FINAL RESULTS
    # ========================================================================

    console.print("=" * 70)
    console.print(Panel("[bold]🏆 OPTIMIZATION OLYMPICS RESULTS[/bold]", border_style="gold1"))
    console.print()

    # Final comparison table
    table = Table(title="🎖️ Final Standings", box=box.DOUBLE)
    table.add_column("Configuration", style="cyan", width=20)
    table.add_column("Size", style="yellow", justify="right")
    table.add_column("Accuracy", style="green", justify="right")
    table.add_column("Compression", style="bold magenta", justify="right")

    table.add_row(
        "🥇 Baseline (FP32)",
        f"{baseline_size:,} B",
        f"{baseline_acc:.1f}%",
        "1×"
    )
    table.add_row(
        "🥈 + Quantization",
        f"{quant_size:,} B",
        f"{quant_acc:.1f}%",
        "[green]4×[/green]"
    )
    table.add_row(
        "🥉 + Pruning",
        f"~{baseline_size//2:,} B*",
        f"{pruned_acc:.1f}%",
        "[green]2×[/green]"
    )
    table.add_row(
        "🏆 Combined",
        f"~{baseline_size//8:,} B*",
        f"{combined_acc:.1f}%",
        f"[bold green]{effective_compression}×[/bold green]"
    )

    console.print(table)
    console.print("[dim]* Effective size with sparse storage[/dim]")
    console.print()

    # Key insights
    console.print(Panel(
        "[bold green]🎓 KEY INSIGHTS[/bold green]\n\n"
        f"✅ [cyan]Quantization (FP32 → INT8):[/cyan]\n"
        f" • 4× smaller model size\n"
        f" • {abs(baseline_acc - quant_acc):.1f}% accuracy impact\n"
        f" • [dim]Used by: TensorRT, ONNX Runtime, mobile deployment[/dim]\n\n"
        f"✅ [cyan]Pruning (Remove Small Weights):[/cyan]\n"
        f" • {sparsity:.0f}% weights removed\n"
        f" • {abs(baseline_acc - pruned_acc):.1f}% accuracy impact\n"
        f" • [dim]Used by: Mobile models, edge deployment[/dim]\n\n"
        f"✅ [cyan]Combined:[/cyan]\n"
        f" • {effective_compression}× total compression\n"
        f" • {abs(baseline_acc - combined_acc):.1f}% accuracy impact\n"
        f" • [dim]The secret sauce of production ML![/dim]",
        border_style="cyan",
        box=box.ROUNDED
    ))

    # Verdict
    accuracy_drop = baseline_acc - combined_acc
    if accuracy_drop < 5:
        verdict = "[bold green]🏆 EXCELLENT![/bold green] Great compression with minimal accuracy loss!"
    elif accuracy_drop < 10:
        verdict = "[bold yellow]🥈 GOOD![/bold yellow] Solid compression, acceptable accuracy tradeoff."
    else:
        verdict = "[bold red]⚠️ HIGH LOSS[/bold red] - Consider less aggressive settings."

    console.print(Panel(
        f"{verdict}\n\n"
        f"[dim]You achieved {effective_compression}× compression with {accuracy_drop:.1f}% accuracy loss.[/dim]\n\n"
        "[bold cyan]What you learned:[/bold cyan]\n"
        " ✅ How to profile ML models\n"
        " ✅ Quantization: reduce precision for smaller models\n"
        " ✅ Pruning: remove weights for sparser models\n"
        " ✅ The accuracy-efficiency tradeoff\n\n"
        "[bold]This is how production ML systems are deployed![/bold]",
        title="🎯 Milestone 06 Complete",
        border_style="green",
        box=box.DOUBLE
    ))
    console.print()


if __name__ == "__main__":
    main()
247 tests/milestones/test_milestones_run.py Normal file
@@ -0,0 +1,247 @@
"""
|
||||
Milestone Execution Tests
|
||||
|
||||
WHAT: Verify all milestones can execute without errors.
|
||||
WHY: Milestones are the key student checkpoints - they MUST work reliably.
|
||||
Broken milestones = frustrated students = bad learning experience.
|
||||
|
||||
STUDENT LEARNING:
|
||||
These tests ensure the 6 historical milestones are always working:
|
||||
1. Perceptron (1957) - First neural network
|
||||
2. XOR Crisis (1969) - Multi-layer networks
|
||||
3. MLP Revival (1986) - Backpropagation
|
||||
4. CNN Revolution (1998) - Spatial networks
|
||||
5. Transformer Era (2017) - Attention mechanism
|
||||
6. MLPerf (2018) - Optimization techniques
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import pytest
|
||||
|
||||
# Project root
|
||||
PROJECT_ROOT = Path(__file__).parent.parent.parent
|
||||
|
||||
|
||||
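
# Usage sketch (assumed invocation; adjust paths to your checkout):
#   pytest tests/milestones/test_milestones_run.py -v
# Each test below shells out to a milestone script with subprocess.run and
# asserts a zero exit code, so a broken milestone fails loudly in CI.
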
class TestMilestone01Perceptron:
    """Test Milestone 01: Perceptron (1957)"""

    def test_perceptron_forward_runs(self):
        """
        WHAT: Verify the perceptron forward pass demo runs.
        WHY: This is the first milestone - it must work to build confidence.
        """
        script = PROJECT_ROOT / "milestones" / "01_1957_perceptron" / "01_rosenblatt_forward.py"
        if not script.exists():
            pytest.skip(f"Script not found: {script}")

        result = subprocess.run(
            [sys.executable, str(script)],
            capture_output=True,
            text=True,
            timeout=60,
            cwd=PROJECT_ROOT
        )

        assert result.returncode == 0, f"Perceptron forward failed:\n{result.stderr}"

    def test_perceptron_trained_runs(self):
        """
        WHAT: Verify the trained perceptron demo runs.
        WHY: This proves the full training loop works.
        """
        script = PROJECT_ROOT / "milestones" / "01_1957_perceptron" / "02_rosenblatt_trained.py"
        if not script.exists():
            pytest.skip(f"Script not found: {script}")

        result = subprocess.run(
            [sys.executable, str(script)],
            capture_output=True,
            text=True,
            timeout=120,
            cwd=PROJECT_ROOT
        )

        assert result.returncode == 0, f"Perceptron trained failed:\n{result.stderr}"


class TestMilestone02XOR:
    """Test Milestone 02: XOR Crisis (1969)"""

    def test_xor_crisis_runs(self):
        """
        WHAT: Verify the XOR crisis demo runs (shows single-layer failure).
        WHY: This demonstrates a key historical limitation.
        """
        script = PROJECT_ROOT / "milestones" / "02_1969_xor" / "01_xor_crisis.py"
        if not script.exists():
            pytest.skip(f"Script not found: {script}")

        result = subprocess.run(
            [sys.executable, str(script)],
            capture_output=True,
            text=True,
            timeout=60,
            cwd=PROJECT_ROOT
        )

        assert result.returncode == 0, f"XOR crisis failed:\n{result.stderr}"

    def test_xor_solved_runs(self):
        """
        WHAT: Verify the XOR solved demo runs (multi-layer success).
        WHY: This proves hidden layers enable non-linear classification.
        """
        script = PROJECT_ROOT / "milestones" / "02_1969_xor" / "02_xor_solved.py"
        if not script.exists():
            pytest.skip(f"Script not found: {script}")

        result = subprocess.run(
            [sys.executable, str(script)],
            capture_output=True,
            text=True,
            timeout=120,
            cwd=PROJECT_ROOT
        )

        assert result.returncode == 0, f"XOR solved failed:\n{result.stderr}"


class TestMilestone03MLP:
    """Test Milestone 03: MLP Revival (1986)"""

    def test_mlp_tinydigits_runs(self):
        """
        WHAT: Verify MLP training on TinyDigits runs.
        WHY: This proves backprop works on real data.
        """
        script = PROJECT_ROOT / "milestones" / "03_1986_mlp" / "01_rumelhart_tinydigits.py"
        if not script.exists():
            pytest.skip(f"Script not found: {script}")

        result = subprocess.run(
            [sys.executable, str(script)],
            capture_output=True,
            text=True,
            timeout=180,  # Training can take a bit
            cwd=PROJECT_ROOT
        )

        assert result.returncode == 0, f"MLP TinyDigits failed:\n{result.stderr}"


class TestMilestone04CNN:
    """Test Milestone 04: CNN Revolution (1998)"""

    def test_cnn_tinydigits_runs(self):
        """
        WHAT: Verify CNN training on TinyDigits runs.
        WHY: This proves spatial operations and convolutions work.
        """
        script = PROJECT_ROOT / "milestones" / "04_1998_cnn" / "01_lecun_tinydigits.py"
        if not script.exists():
            pytest.skip(f"Script not found: {script}")

        result = subprocess.run(
            [sys.executable, str(script)],
            capture_output=True,
            text=True,
            timeout=300,  # CNN training can be slow
            cwd=PROJECT_ROOT
        )

        assert result.returncode == 0, f"CNN TinyDigits failed:\n{result.stderr}"


class TestMilestone05Transformer:
    """Test Milestone 05: Transformer Era (2017)"""

    def test_attention_proof_runs(self):
        """
        WHAT: Verify the attention mechanism proof runs.
        WHY: This proves attention can learn cross-position relationships.
        """
        script = PROJECT_ROOT / "milestones" / "05_2017_transformer" / "00_vaswani_attention_proof.py"
        if not script.exists():
            pytest.skip(f"Script not found: {script}")

        result = subprocess.run(
            [sys.executable, str(script)],
            capture_output=True,
            text=True,
            timeout=120,
            cwd=PROJECT_ROOT
        )

        assert result.returncode == 0, f"Attention proof failed:\n{result.stderr}"
        # Verify it achieved good accuracy
        assert "100.0%" in result.stdout or "99" in result.stdout, \
            "Attention proof should achieve near-perfect accuracy"


class TestMilestone06MLPerf:
    """Test Milestone 06: MLPerf (2018)"""

    def test_optimization_olympics_runs(self):
        """
        WHAT: Verify the optimization pipeline runs.
        WHY: This proves profiling, quantization, and pruning work.
        """
        script = PROJECT_ROOT / "milestones" / "06_2018_mlperf" / "01_optimization_olympics.py"
        if not script.exists():
            pytest.skip(f"Script not found: {script}")

        result = subprocess.run(
            [sys.executable, str(script)],
            capture_output=True,
            text=True,
            timeout=180,
            cwd=PROJECT_ROOT
        )

        assert result.returncode == 0, f"Optimization Olympics failed:\n{result.stderr}"
        # Verify compression was achieved
        assert "compression" in result.stdout.lower() or "smaller" in result.stdout.lower(), \
            "Should show compression metrics"


class TestMilestoneCLI:
    """Test milestones work through the CLI."""

    def test_milestones_list_works(self):
        """
        WHAT: Verify `tito milestones list` works.
        WHY: Students need to discover available milestones.
        """
        result = subprocess.run(
            ["tito", "milestones", "list"],
            capture_output=True,
            text=True,
            timeout=30,
            cwd=PROJECT_ROOT
        )

        assert result.returncode == 0, f"tito milestones list failed:\n{result.stderr}"
        assert "Perceptron" in result.stdout, "Should list Perceptron milestone"
        assert "Transformer" in result.stdout, "Should list Transformer milestone"

    def test_milestones_status_works(self):
        """
        WHAT: Verify `tito milestones status` works.
        WHY: Students need to track their progress.
        """
        result = subprocess.run(
            ["tito", "milestones", "status"],
            capture_output=True,
            text=True,
            timeout=30,
            cwd=PROJECT_ROOT
        )

        assert result.returncode == 0, f"tito milestones status failed:\n{result.stderr}"


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
320 tinytorch/_modidx.py generated
@@ -51,6 +51,56 @@ d = { 'settings': { 'branch': 'main',
'tinytorch/applications/tinygpt.py'),
'tinytorch.applications.tinygpt.test_unit_training_pipeline': ( '20_capstone/capstone.html#test_unit_training_pipeline',
'tinytorch/applications/tinygpt.py')},
'tinytorch.bench': { 'tinytorch.bench.Benchmark': ('19_benchmarking/benchmarking.html#benchmark', 'tinytorch/bench.py'),
'tinytorch.bench.Benchmark.__init__': ( '19_benchmarking/benchmarking.html#benchmark.__init__',
'tinytorch/bench.py'),
'tinytorch.bench.Benchmark.compare_models': ( '19_benchmarking/benchmarking.html#benchmark.compare_models',
'tinytorch/bench.py'),
'tinytorch.bench.Benchmark.run_accuracy_benchmark': ( '19_benchmarking/benchmarking.html#benchmark.run_accuracy_benchmark',
'tinytorch/bench.py'),
'tinytorch.bench.Benchmark.run_latency_benchmark': ( '19_benchmarking/benchmarking.html#benchmark.run_latency_benchmark',
'tinytorch/bench.py'),
'tinytorch.bench.Benchmark.run_memory_benchmark': ( '19_benchmarking/benchmarking.html#benchmark.run_memory_benchmark',
'tinytorch/bench.py'),
'tinytorch.bench.BenchmarkResult': ( '19_benchmarking/benchmarking.html#benchmarkresult',
'tinytorch/bench.py'),
'tinytorch.bench.BenchmarkResult.__post_init__': ( '19_benchmarking/benchmarking.html#benchmarkresult.__post_init__',
'tinytorch/bench.py'),
'tinytorch.bench.BenchmarkResult.__str__': ( '19_benchmarking/benchmarking.html#benchmarkresult.__str__',
'tinytorch/bench.py'),
'tinytorch.bench.BenchmarkResult.to_dict': ( '19_benchmarking/benchmarking.html#benchmarkresult.to_dict',
'tinytorch/bench.py'),
'tinytorch.bench.BenchmarkSuite': ( '19_benchmarking/benchmarking.html#benchmarksuite',
'tinytorch/bench.py'),
'tinytorch.bench.BenchmarkSuite.__init__': ( '19_benchmarking/benchmarking.html#benchmarksuite.__init__',
'tinytorch/bench.py'),
'tinytorch.bench.BenchmarkSuite._estimate_energy_efficiency': ( '19_benchmarking/benchmarking.html#benchmarksuite._estimate_energy_efficiency',
'tinytorch/bench.py'),
'tinytorch.bench.BenchmarkSuite.generate_report': ( '19_benchmarking/benchmarking.html#benchmarksuite.generate_report',
'tinytorch/bench.py'),
'tinytorch.bench.BenchmarkSuite.plot_pareto_frontier': ( '19_benchmarking/benchmarking.html#benchmarksuite.plot_pareto_frontier',
'tinytorch/bench.py'),
'tinytorch.bench.BenchmarkSuite.plot_results': ( '19_benchmarking/benchmarking.html#benchmarksuite.plot_results',
'tinytorch/bench.py'),
'tinytorch.bench.BenchmarkSuite.run_full_benchmark': ( '19_benchmarking/benchmarking.html#benchmarksuite.run_full_benchmark',
'tinytorch/bench.py'),
'tinytorch.bench.TinyMLPerf': ('19_benchmarking/benchmarking.html#tinymlperf', 'tinytorch/bench.py'),
'tinytorch.bench.TinyMLPerf.__init__': ( '19_benchmarking/benchmarking.html#tinymlperf.__init__',
'tinytorch/bench.py'),
'tinytorch.bench.TinyMLPerf.generate_compliance_report': ( '19_benchmarking/benchmarking.html#tinymlperf.generate_compliance_report',
'tinytorch/bench.py'),
'tinytorch.bench.TinyMLPerf.run_all_benchmarks': ( '19_benchmarking/benchmarking.html#tinymlperf.run_all_benchmarks',
'tinytorch/bench.py'),
'tinytorch.bench.TinyMLPerf.run_standard_benchmark': ( '19_benchmarking/benchmarking.html#tinymlperf.run_standard_benchmark',
'tinytorch/bench.py'),
'tinytorch.bench.test_unit_benchmark': ( '19_benchmarking/benchmarking.html#test_unit_benchmark',
'tinytorch/bench.py'),
'tinytorch.bench.test_unit_benchmark_result': ( '19_benchmarking/benchmarking.html#test_unit_benchmark_result',
'tinytorch/bench.py'),
'tinytorch.bench.test_unit_benchmark_suite': ( '19_benchmarking/benchmarking.html#test_unit_benchmark_suite',
'tinytorch/bench.py'),
'tinytorch.bench.test_unit_tinymlperf': ( '19_benchmarking/benchmarking.html#test_unit_tinymlperf',
'tinytorch/bench.py')},
'tinytorch.benchmarking.benchmark': { 'tinytorch.benchmarking.benchmark.Benchmark': ( '19_benchmarking/benchmarking.html#benchmark',
'tinytorch/benchmarking/benchmark.py'),
'tinytorch.benchmarking.benchmark.Benchmark.__init__': ( '19_benchmarking/benchmarking.html#benchmark.__init__',
@@ -201,6 +251,86 @@
'tinytorch.core.attention.scaled_dot_product_attention': ( '12_attention/attention.html#scaled_dot_product_attention',
'tinytorch/core/attention.py')},
'tinytorch.core.autograd': {},
'tinytorch.core.dataloader': { 'tinytorch.core.dataloader.Compose': ( '08_dataloader/dataloader.html#compose',
'tinytorch/core/dataloader.py'),
'tinytorch.core.dataloader.Compose.__call__': ( '08_dataloader/dataloader.html#compose.__call__',
'tinytorch/core/dataloader.py'),
'tinytorch.core.dataloader.Compose.__init__': ( '08_dataloader/dataloader.html#compose.__init__',
'tinytorch/core/dataloader.py'),
'tinytorch.core.dataloader.DataLoader': ( '08_dataloader/dataloader.html#dataloader',
'tinytorch/core/dataloader.py'),
'tinytorch.core.dataloader.DataLoader.__init__': ( '08_dataloader/dataloader.html#dataloader.__init__',
'tinytorch/core/dataloader.py'),
'tinytorch.core.dataloader.DataLoader.__iter__': ( '08_dataloader/dataloader.html#dataloader.__iter__',
'tinytorch/core/dataloader.py'),
'tinytorch.core.dataloader.DataLoader.__len__': ( '08_dataloader/dataloader.html#dataloader.__len__',
'tinytorch/core/dataloader.py'),
'tinytorch.core.dataloader.DataLoader._collate_batch': ( '08_dataloader/dataloader.html#dataloader._collate_batch',
'tinytorch/core/dataloader.py'),
'tinytorch.core.dataloader.Dataset': ( '08_dataloader/dataloader.html#dataset',
'tinytorch/core/dataloader.py'),
'tinytorch.core.dataloader.Dataset.__getitem__': ( '08_dataloader/dataloader.html#dataset.__getitem__',
'tinytorch/core/dataloader.py'),
'tinytorch.core.dataloader.Dataset.__len__': ( '08_dataloader/dataloader.html#dataset.__len__',
'tinytorch/core/dataloader.py'),
'tinytorch.core.dataloader.RandomCrop': ( '08_dataloader/dataloader.html#randomcrop',
'tinytorch/core/dataloader.py'),
'tinytorch.core.dataloader.RandomCrop.__call__': ( '08_dataloader/dataloader.html#randomcrop.__call__',
'tinytorch/core/dataloader.py'),
'tinytorch.core.dataloader.RandomCrop.__init__': ( '08_dataloader/dataloader.html#randomcrop.__init__',
'tinytorch/core/dataloader.py'),
'tinytorch.core.dataloader.RandomHorizontalFlip': ( '08_dataloader/dataloader.html#randomhorizontalflip',
'tinytorch/core/dataloader.py'),
'tinytorch.core.dataloader.RandomHorizontalFlip.__call__': ( '08_dataloader/dataloader.html#randomhorizontalflip.__call__',
'tinytorch/core/dataloader.py'),
'tinytorch.core.dataloader.RandomHorizontalFlip.__init__': ( '08_dataloader/dataloader.html#randomhorizontalflip.__init__',
'tinytorch/core/dataloader.py'),
'tinytorch.core.dataloader.TensorDataset': ( '08_dataloader/dataloader.html#tensordataset',
'tinytorch/core/dataloader.py'),
'tinytorch.core.dataloader.TensorDataset.__getitem__': ( '08_dataloader/dataloader.html#tensordataset.__getitem__',
'tinytorch/core/dataloader.py'),
'tinytorch.core.dataloader.TensorDataset.__init__': ( '08_dataloader/dataloader.html#tensordataset.__init__',
'tinytorch/core/dataloader.py'),
'tinytorch.core.dataloader.TensorDataset.__len__': ( '08_dataloader/dataloader.html#tensordataset.__len__',
'tinytorch/core/dataloader.py')},
'tinytorch.core.embeddings': { 'tinytorch.core.embeddings.Embedding': ( '11_embeddings/embeddings.html#embedding',
'tinytorch/core/embeddings.py'),
'tinytorch.core.embeddings.Embedding.__call__': ( '11_embeddings/embeddings.html#embedding.__call__',
'tinytorch/core/embeddings.py'),
'tinytorch.core.embeddings.Embedding.__init__': ( '11_embeddings/embeddings.html#embedding.__init__',
'tinytorch/core/embeddings.py'),
'tinytorch.core.embeddings.Embedding.__repr__': ( '11_embeddings/embeddings.html#embedding.__repr__',
'tinytorch/core/embeddings.py'),
'tinytorch.core.embeddings.Embedding.forward': ( '11_embeddings/embeddings.html#embedding.forward',
'tinytorch/core/embeddings.py'),
'tinytorch.core.embeddings.Embedding.parameters': ( '11_embeddings/embeddings.html#embedding.parameters',
'tinytorch/core/embeddings.py'),
'tinytorch.core.embeddings.EmbeddingLayer': ( '11_embeddings/embeddings.html#embeddinglayer',
'tinytorch/core/embeddings.py'),
'tinytorch.core.embeddings.EmbeddingLayer.__call__': ( '11_embeddings/embeddings.html#embeddinglayer.__call__',
'tinytorch/core/embeddings.py'),
'tinytorch.core.embeddings.EmbeddingLayer.__init__': ( '11_embeddings/embeddings.html#embeddinglayer.__init__',
'tinytorch/core/embeddings.py'),
'tinytorch.core.embeddings.EmbeddingLayer.__repr__': ( '11_embeddings/embeddings.html#embeddinglayer.__repr__',
'tinytorch/core/embeddings.py'),
'tinytorch.core.embeddings.EmbeddingLayer.forward': ( '11_embeddings/embeddings.html#embeddinglayer.forward',
'tinytorch/core/embeddings.py'),
'tinytorch.core.embeddings.EmbeddingLayer.parameters': ( '11_embeddings/embeddings.html#embeddinglayer.parameters',
'tinytorch/core/embeddings.py'),
'tinytorch.core.embeddings.PositionalEncoding': ( '11_embeddings/embeddings.html#positionalencoding',
'tinytorch/core/embeddings.py'),
'tinytorch.core.embeddings.PositionalEncoding.__call__': ( '11_embeddings/embeddings.html#positionalencoding.__call__',
'tinytorch/core/embeddings.py'),
'tinytorch.core.embeddings.PositionalEncoding.__init__': ( '11_embeddings/embeddings.html#positionalencoding.__init__',
'tinytorch/core/embeddings.py'),
'tinytorch.core.embeddings.PositionalEncoding.__repr__': ( '11_embeddings/embeddings.html#positionalencoding.__repr__',
'tinytorch/core/embeddings.py'),
'tinytorch.core.embeddings.PositionalEncoding.forward': ( '11_embeddings/embeddings.html#positionalencoding.forward',
'tinytorch/core/embeddings.py'),
'tinytorch.core.embeddings.PositionalEncoding.parameters': ( '11_embeddings/embeddings.html#positionalencoding.parameters',
'tinytorch/core/embeddings.py'),
'tinytorch.core.embeddings.create_sinusoidal_embeddings': ( '11_embeddings/embeddings.html#create_sinusoidal_embeddings',
'tinytorch/core/embeddings.py')},
'tinytorch.core.layers': { 'tinytorch.core.layers.Dropout': ('03_layers/layers.html#dropout', 'tinytorch/core/layers.py'),
'tinytorch.core.layers.Dropout.__call__': ( '03_layers/layers.html#dropout.__call__',
'tinytorch/core/layers.py'),
@@ -393,6 +523,40 @@
'tinytorch.core.tensor.Tensor.sum': ('01_tensor/tensor.html#tensor.sum', 'tinytorch/core/tensor.py'),
'tinytorch.core.tensor.Tensor.transpose': ( '01_tensor/tensor.html#tensor.transpose',
'tinytorch/core/tensor.py')},
'tinytorch.core.tokenization': { 'tinytorch.core.tokenization.BPETokenizer': ( '10_tokenization/tokenization.html#bpetokenizer',
'tinytorch/core/tokenization.py'),
'tinytorch.core.tokenization.BPETokenizer.__init__': ( '10_tokenization/tokenization.html#bpetokenizer.__init__',
'tinytorch/core/tokenization.py'),
'tinytorch.core.tokenization.BPETokenizer._apply_merges': ( '10_tokenization/tokenization.html#bpetokenizer._apply_merges',
'tinytorch/core/tokenization.py'),
'tinytorch.core.tokenization.BPETokenizer._build_mappings': ( '10_tokenization/tokenization.html#bpetokenizer._build_mappings',
'tinytorch/core/tokenization.py'),
'tinytorch.core.tokenization.BPETokenizer._get_pairs': ( '10_tokenization/tokenization.html#bpetokenizer._get_pairs',
'tinytorch/core/tokenization.py'),
'tinytorch.core.tokenization.BPETokenizer._get_word_tokens': ( '10_tokenization/tokenization.html#bpetokenizer._get_word_tokens',
'tinytorch/core/tokenization.py'),
'tinytorch.core.tokenization.BPETokenizer.decode': ( '10_tokenization/tokenization.html#bpetokenizer.decode',
'tinytorch/core/tokenization.py'),
'tinytorch.core.tokenization.BPETokenizer.encode': ( '10_tokenization/tokenization.html#bpetokenizer.encode',
'tinytorch/core/tokenization.py'),
'tinytorch.core.tokenization.BPETokenizer.train': ( '10_tokenization/tokenization.html#bpetokenizer.train',
'tinytorch/core/tokenization.py'),
'tinytorch.core.tokenization.CharTokenizer': ( '10_tokenization/tokenization.html#chartokenizer',
'tinytorch/core/tokenization.py'),
'tinytorch.core.tokenization.CharTokenizer.__init__': ( '10_tokenization/tokenization.html#chartokenizer.__init__',
'tinytorch/core/tokenization.py'),
'tinytorch.core.tokenization.CharTokenizer.build_vocab': ( '10_tokenization/tokenization.html#chartokenizer.build_vocab',
'tinytorch/core/tokenization.py'),
'tinytorch.core.tokenization.CharTokenizer.decode': ( '10_tokenization/tokenization.html#chartokenizer.decode',
'tinytorch/core/tokenization.py'),
'tinytorch.core.tokenization.CharTokenizer.encode': ( '10_tokenization/tokenization.html#chartokenizer.encode',
'tinytorch/core/tokenization.py'),
'tinytorch.core.tokenization.Tokenizer': ( '10_tokenization/tokenization.html#tokenizer',
'tinytorch/core/tokenization.py'),
'tinytorch.core.tokenization.Tokenizer.decode': ( '10_tokenization/tokenization.html#tokenizer.decode',
'tinytorch/core/tokenization.py'),
'tinytorch.core.tokenization.Tokenizer.encode': ( '10_tokenization/tokenization.html#tokenizer.encode',
'tinytorch/core/tokenization.py')},
'tinytorch.core.training': { 'tinytorch.core.training.CosineSchedule': ( '07_training/training.html#cosineschedule',
'tinytorch/core/training.py'),
'tinytorch.core.training.CosineSchedule.__init__': ( '07_training/training.html#cosineschedule.__init__',
@@ -425,6 +589,50 @@
'tinytorch/core/training.py'),
'tinytorch.core.training.clip_grad_norm': ( '07_training/training.html#clip_grad_norm',
'tinytorch/core/training.py')},
'tinytorch.core.transformer': { 'tinytorch.core.transformer.GPT': ( '13_transformers/transformers.html#gpt',
'tinytorch/core/transformer.py'),
'tinytorch.core.transformer.GPT.__call__': ( '13_transformers/transformers.html#gpt.__call__',
'tinytorch/core/transformer.py'),
'tinytorch.core.transformer.GPT.__init__': ( '13_transformers/transformers.html#gpt.__init__',
'tinytorch/core/transformer.py'),
'tinytorch.core.transformer.GPT._create_causal_mask': ( '13_transformers/transformers.html#gpt._create_causal_mask',
'tinytorch/core/transformer.py'),
|
||||
'tinytorch.core.transformer.GPT.forward': ( '13_transformers/transformers.html#gpt.forward',
|
||||
'tinytorch/core/transformer.py'),
|
||||
'tinytorch.core.transformer.GPT.generate': ( '13_transformers/transformers.html#gpt.generate',
|
||||
'tinytorch/core/transformer.py'),
|
||||
'tinytorch.core.transformer.GPT.parameters': ( '13_transformers/transformers.html#gpt.parameters',
|
||||
'tinytorch/core/transformer.py'),
|
||||
'tinytorch.core.transformer.LayerNorm': ( '13_transformers/transformers.html#layernorm',
|
||||
'tinytorch/core/transformer.py'),
|
||||
'tinytorch.core.transformer.LayerNorm.__call__': ( '13_transformers/transformers.html#layernorm.__call__',
|
||||
'tinytorch/core/transformer.py'),
|
||||
'tinytorch.core.transformer.LayerNorm.__init__': ( '13_transformers/transformers.html#layernorm.__init__',
|
||||
'tinytorch/core/transformer.py'),
|
||||
'tinytorch.core.transformer.LayerNorm.forward': ( '13_transformers/transformers.html#layernorm.forward',
|
||||
'tinytorch/core/transformer.py'),
|
||||
'tinytorch.core.transformer.LayerNorm.parameters': ( '13_transformers/transformers.html#layernorm.parameters',
|
||||
'tinytorch/core/transformer.py'),
|
||||
'tinytorch.core.transformer.MLP': ( '13_transformers/transformers.html#mlp',
|
||||
'tinytorch/core/transformer.py'),
|
||||
'tinytorch.core.transformer.MLP.__call__': ( '13_transformers/transformers.html#mlp.__call__',
|
||||
'tinytorch/core/transformer.py'),
|
||||
'tinytorch.core.transformer.MLP.__init__': ( '13_transformers/transformers.html#mlp.__init__',
|
||||
'tinytorch/core/transformer.py'),
|
||||
'tinytorch.core.transformer.MLP.forward': ( '13_transformers/transformers.html#mlp.forward',
|
||||
'tinytorch/core/transformer.py'),
|
||||
'tinytorch.core.transformer.MLP.parameters': ( '13_transformers/transformers.html#mlp.parameters',
|
||||
'tinytorch/core/transformer.py'),
|
||||
'tinytorch.core.transformer.TransformerBlock': ( '13_transformers/transformers.html#transformerblock',
|
||||
'tinytorch/core/transformer.py'),
|
||||
'tinytorch.core.transformer.TransformerBlock.__call__': ( '13_transformers/transformers.html#transformerblock.__call__',
|
||||
'tinytorch/core/transformer.py'),
|
||||
'tinytorch.core.transformer.TransformerBlock.__init__': ( '13_transformers/transformers.html#transformerblock.__init__',
|
||||
'tinytorch/core/transformer.py'),
|
||||
'tinytorch.core.transformer.TransformerBlock.forward': ( '13_transformers/transformers.html#transformerblock.forward',
|
||||
'tinytorch/core/transformer.py'),
|
||||
'tinytorch.core.transformer.TransformerBlock.parameters': ( '13_transformers/transformers.html#transformerblock.parameters',
|
||||
'tinytorch/core/transformer.py')},
|
||||
'tinytorch.data.loader': { 'tinytorch.data.loader.Compose': ( '08_dataloader/dataloader.html#compose',
|
||||
'tinytorch/data/loader.py'),
|
||||
'tinytorch.data.loader.Compose.__call__': ( '08_dataloader/dataloader.html#compose.__call__',
|
||||
@@ -607,6 +815,118 @@ d = { 'settings': { 'branch': 'main',
|
||||
'tinytorch/optimization/quantization.py'),
|
||||
'tinytorch.optimization.quantization.quantize_model': ( '15_quantization/quantization.html#quantize_model',
|
||||
'tinytorch/optimization/quantization.py')},
|
||||
'tinytorch.perf.acceleration': { 'tinytorch.perf.acceleration.fused_gelu': ( '18_acceleration/acceleration.html#fused_gelu',
|
||||
'tinytorch/perf/acceleration.py'),
|
||||
'tinytorch.perf.acceleration.tiled_matmul': ( '18_acceleration/acceleration.html#tiled_matmul',
|
||||
'tinytorch/perf/acceleration.py'),
|
||||
'tinytorch.perf.acceleration.vectorized_matmul': ( '18_acceleration/acceleration.html#vectorized_matmul',
|
||||
'tinytorch/perf/acceleration.py')},
|
||||
'tinytorch.perf.compression': { 'tinytorch.perf.compression.Compressor': ( '16_compression/compression.html#compressor',
|
||||
'tinytorch/perf/compression.py'),
|
||||
'tinytorch.perf.compression.Compressor.compress_model': ( '16_compression/compression.html#compressor.compress_model',
|
||||
'tinytorch/perf/compression.py'),
|
||||
'tinytorch.perf.compression.Compressor.magnitude_prune': ( '16_compression/compression.html#compressor.magnitude_prune',
|
||||
'tinytorch/perf/compression.py'),
|
||||
'tinytorch.perf.compression.Compressor.measure_sparsity': ( '16_compression/compression.html#compressor.measure_sparsity',
|
||||
'tinytorch/perf/compression.py'),
|
||||
'tinytorch.perf.compression.Compressor.structured_prune': ( '16_compression/compression.html#compressor.structured_prune',
|
||||
'tinytorch/perf/compression.py'),
|
||||
'tinytorch.perf.compression.KnowledgeDistillation': ( '16_compression/compression.html#knowledgedistillation',
|
||||
'tinytorch/perf/compression.py'),
|
||||
'tinytorch.perf.compression.KnowledgeDistillation.__init__': ( '16_compression/compression.html#knowledgedistillation.__init__',
|
||||
'tinytorch/perf/compression.py'),
|
||||
'tinytorch.perf.compression.KnowledgeDistillation._cross_entropy': ( '16_compression/compression.html#knowledgedistillation._cross_entropy',
|
||||
'tinytorch/perf/compression.py'),
|
||||
'tinytorch.perf.compression.KnowledgeDistillation._kl_divergence': ( '16_compression/compression.html#knowledgedistillation._kl_divergence',
|
||||
'tinytorch/perf/compression.py'),
|
||||
'tinytorch.perf.compression.KnowledgeDistillation._softmax': ( '16_compression/compression.html#knowledgedistillation._softmax',
|
||||
'tinytorch/perf/compression.py'),
|
||||
'tinytorch.perf.compression.KnowledgeDistillation.distillation_loss': ( '16_compression/compression.html#knowledgedistillation.distillation_loss',
|
||||
'tinytorch/perf/compression.py'),
|
||||
'tinytorch.perf.compression.compress_model': ( '16_compression/compression.html#compress_model',
|
||||
'tinytorch/perf/compression.py'),
|
||||
'tinytorch.perf.compression.low_rank_approximate': ( '16_compression/compression.html#low_rank_approximate',
|
||||
'tinytorch/perf/compression.py'),
|
||||
'tinytorch.perf.compression.magnitude_prune': ( '16_compression/compression.html#magnitude_prune',
|
||||
'tinytorch/perf/compression.py'),
|
||||
'tinytorch.perf.compression.measure_sparsity': ( '16_compression/compression.html#measure_sparsity',
|
||||
'tinytorch/perf/compression.py'),
|
||||
'tinytorch.perf.compression.structured_prune': ( '16_compression/compression.html#structured_prune',
|
||||
'tinytorch/perf/compression.py')},
|
||||
'tinytorch.perf.memoization': { 'tinytorch.perf.memoization.KVCache': ( '17_memoization/memoization.html#kvcache',
|
||||
'tinytorch/perf/memoization.py'),
|
||||
'tinytorch.perf.memoization.KVCache.__init__': ( '17_memoization/memoization.html#kvcache.__init__',
|
||||
'tinytorch/perf/memoization.py'),
|
||||
'tinytorch.perf.memoization.KVCache.advance': ( '17_memoization/memoization.html#kvcache.advance',
|
||||
'tinytorch/perf/memoization.py'),
|
||||
'tinytorch.perf.memoization.KVCache.get': ( '17_memoization/memoization.html#kvcache.get',
|
||||
'tinytorch/perf/memoization.py'),
|
||||
'tinytorch.perf.memoization.KVCache.get_memory_usage': ( '17_memoization/memoization.html#kvcache.get_memory_usage',
|
||||
'tinytorch/perf/memoization.py'),
|
||||
'tinytorch.perf.memoization.KVCache.reset': ( '17_memoization/memoization.html#kvcache.reset',
|
||||
'tinytorch/perf/memoization.py'),
|
||||
'tinytorch.perf.memoization.KVCache.update': ( '17_memoization/memoization.html#kvcache.update',
|
||||
'tinytorch/perf/memoization.py'),
|
||||
'tinytorch.perf.memoization.create_kv_cache': ( '17_memoization/memoization.html#create_kv_cache',
|
||||
'tinytorch/perf/memoization.py'),
|
||||
'tinytorch.perf.memoization.disable_kv_cache': ( '17_memoization/memoization.html#disable_kv_cache',
|
||||
'tinytorch/perf/memoization.py'),
|
||||
'tinytorch.perf.memoization.enable_kv_cache': ( '17_memoization/memoization.html#enable_kv_cache',
|
||||
'tinytorch/perf/memoization.py')},
|
||||
'tinytorch.perf.profiling': { 'tinytorch.perf.profiling.Profiler': ( '14_profiling/profiling.html#profiler',
|
||||
'tinytorch/perf/profiling.py'),
|
||||
'tinytorch.perf.profiling.Profiler.__init__': ( '14_profiling/profiling.html#profiler.__init__',
|
||||
'tinytorch/perf/profiling.py'),
|
||||
'tinytorch.perf.profiling.Profiler.count_flops': ( '14_profiling/profiling.html#profiler.count_flops',
|
||||
'tinytorch/perf/profiling.py'),
|
||||
'tinytorch.perf.profiling.Profiler.count_parameters': ( '14_profiling/profiling.html#profiler.count_parameters',
|
||||
'tinytorch/perf/profiling.py'),
|
||||
'tinytorch.perf.profiling.Profiler.measure_latency': ( '14_profiling/profiling.html#profiler.measure_latency',
|
||||
'tinytorch/perf/profiling.py'),
|
||||
'tinytorch.perf.profiling.Profiler.measure_memory': ( '14_profiling/profiling.html#profiler.measure_memory',
|
||||
'tinytorch/perf/profiling.py'),
|
||||
'tinytorch.perf.profiling.Profiler.profile_backward_pass': ( '14_profiling/profiling.html#profiler.profile_backward_pass',
|
||||
'tinytorch/perf/profiling.py'),
|
||||
'tinytorch.perf.profiling.Profiler.profile_forward_pass': ( '14_profiling/profiling.html#profiler.profile_forward_pass',
|
||||
'tinytorch/perf/profiling.py'),
|
||||
'tinytorch.perf.profiling.Profiler.profile_layer': ( '14_profiling/profiling.html#profiler.profile_layer',
|
||||
'tinytorch/perf/profiling.py'),
|
||||
'tinytorch.perf.profiling.analyze_weight_distribution': ( '14_profiling/profiling.html#analyze_weight_distribution',
|
||||
'tinytorch/perf/profiling.py'),
|
||||
'tinytorch.perf.profiling.quick_profile': ( '14_profiling/profiling.html#quick_profile',
|
||||
'tinytorch/perf/profiling.py')},
|
||||
'tinytorch.perf.quantization': { 'tinytorch.perf.quantization.QuantizedLinear': ( '15_quantization/quantization.html#quantizedlinear',
|
||||
'tinytorch/perf/quantization.py'),
|
||||
'tinytorch.perf.quantization.QuantizedLinear.__call__': ( '15_quantization/quantization.html#quantizedlinear.__call__',
|
||||
'tinytorch/perf/quantization.py'),
|
||||
'tinytorch.perf.quantization.QuantizedLinear.__init__': ( '15_quantization/quantization.html#quantizedlinear.__init__',
|
||||
'tinytorch/perf/quantization.py'),
|
||||
'tinytorch.perf.quantization.QuantizedLinear.calibrate': ( '15_quantization/quantization.html#quantizedlinear.calibrate',
|
||||
'tinytorch/perf/quantization.py'),
|
||||
'tinytorch.perf.quantization.QuantizedLinear.forward': ( '15_quantization/quantization.html#quantizedlinear.forward',
|
||||
'tinytorch/perf/quantization.py'),
|
||||
'tinytorch.perf.quantization.QuantizedLinear.memory_usage': ( '15_quantization/quantization.html#quantizedlinear.memory_usage',
|
||||
'tinytorch/perf/quantization.py'),
|
||||
'tinytorch.perf.quantization.QuantizedLinear.parameters': ( '15_quantization/quantization.html#quantizedlinear.parameters',
|
||||
'tinytorch/perf/quantization.py'),
|
||||
'tinytorch.perf.quantization.Quantizer': ( '15_quantization/quantization.html#quantizer',
|
||||
'tinytorch/perf/quantization.py'),
|
||||
'tinytorch.perf.quantization.Quantizer.compare_models': ( '15_quantization/quantization.html#quantizer.compare_models',
|
||||
'tinytorch/perf/quantization.py'),
|
||||
'tinytorch.perf.quantization.Quantizer.dequantize_tensor': ( '15_quantization/quantization.html#quantizer.dequantize_tensor',
|
||||
'tinytorch/perf/quantization.py'),
|
||||
'tinytorch.perf.quantization.Quantizer.quantize_model': ( '15_quantization/quantization.html#quantizer.quantize_model',
|
||||
'tinytorch/perf/quantization.py'),
|
||||
'tinytorch.perf.quantization.Quantizer.quantize_tensor': ( '15_quantization/quantization.html#quantizer.quantize_tensor',
|
||||
'tinytorch/perf/quantization.py'),
|
||||
'tinytorch.perf.quantization.compare_model_sizes': ( '15_quantization/quantization.html#compare_model_sizes',
|
||||
'tinytorch/perf/quantization.py'),
|
||||
'tinytorch.perf.quantization.dequantize_int8': ( '15_quantization/quantization.html#dequantize_int8',
|
||||
'tinytorch/perf/quantization.py'),
|
||||
'tinytorch.perf.quantization.quantize_int8': ( '15_quantization/quantization.html#quantize_int8',
|
||||
'tinytorch/perf/quantization.py'),
|
||||
'tinytorch.perf.quantization.quantize_model': ( '15_quantization/quantization.html#quantize_model',
|
||||
'tinytorch/perf/quantization.py')},
|
||||
'tinytorch.profiling.profiler': { 'tinytorch.profiling.profiler.Profiler': ( '14_profiling/profiling.html#profiler',
|
||||
'tinytorch/profiling/profiler.py'),
|
||||
'tinytorch.profiling.profiler.Profiler.__init__': ( '14_profiling/profiling.html#profiler.__init__',
|
||||
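The block above is nbdev's auto-generated `_modidx.py`: under `d['syms']`, each module maps its fully qualified symbols to a `(docs_page#anchor, source_file)` pair, which is how the rendered docs link back to source. A minimal lookup sketch over that layout follows; the `resolve` helper is hypothetical, not part of this commit, and assumes the nesting the hunk headers suggest.

def resolve(symbol, d):
    """Return (docs_page_anchor, source_file) for a fully qualified name.

    Hypothetical helper assuming the nbdev _modidx layout above:
    d['syms'][module_name][symbol] == (docs_anchor, source_path).
    """
    # Try the longest module prefix first so 'a.b.c' wins over 'a.b'.
    for module in sorted(d['syms'], key=len, reverse=True):
        if symbol.startswith(module + '.') and symbol in d['syms'][module]:
            return d['syms'][module][symbol]
    return None

# Example, using values visible in the index above:
# resolve('tinytorch.core.transformer.GPT.generate', d)
# -> ('13_transformers/transformers.html#gpt.generate', 'tinytorch/core/transformer.py')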
tinytorch/core/layers.py (generated, 6 lines changed)
@@ -15,7 +15,7 @@
 # ║ The tinytorch/ directory is generated code - edit source files instead! ║
 # ╚═══════════════════════════════════════════════════════════════════════════════╝
 # %% auto 0
-__all__ = ['Layer', 'Linear', 'Dropout']
+__all__ = ['XAVIER_SCALE_FACTOR', 'HE_SCALE_FACTOR', 'DROPOUT_MIN_PROB', 'DROPOUT_MAX_PROB', 'Layer', 'Linear', 'Dropout']
 
 # %% ../../modules/03_layers/03_layers.ipynb 1
 import numpy as np
@@ -273,7 +273,3 @@ class Dropout(Layer):
 
     def __repr__(self):
         return f"Dropout(p={self.p})"
-
-# Alias for compatibility - Dense is the same as Linear
-# Some frameworks use Dense, some use Linear - they're identical
-Dense = Linear
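The second hunk removes the `Dense = Linear` compatibility alias, so callers written against the Keras-style name must switch to the canonical one. A hedged before/after sketch (constructor arguments are assumed for illustration, not taken from this commit):

from tinytorch.core.layers import Linear

# Before this commit the alias made these equivalent:
#   layer = Dense(128, 64)   # Dense is no longer exported
# After: use the canonical name directly.
layer = Linear(128, 64)  # (in_features, out_features) order assumed for illustration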
@@ -80,9 +80,9 @@ MILESTONE_SCRIPTS = {
         "name": "Transformer Era (2017)",
         "year": 2017,
         "title": "Attention is All You Need",
-        "script": "milestones/05_2017_transformer/03_quickdemo.py",
+        "script": "milestones/05_2017_transformer/00_vaswani_attention_proof.py",
         "required_modules": list(range(1, 14)),
-        "description": "Build transformer with self-attention",
+        "description": "Prove attention works with sequence reversal",
         "historical_context": "Vaswani et al. revolutionized NLP",
         "emoji": "🤖"
     },
@@ -90,10 +90,10 @@ MILESTONE_SCRIPTS = {
         "id": "06",
         "name": "MLPerf Benchmarks (2018)",
         "year": 2018,
-        "title": "Production ML Systems",
-        "script": "milestones/06_2018_mlperf/02_compression.py",
-        "required_modules": list(range(1, 20)),
-        "description": "Optimize for production deployment",
+        "title": "The Optimization Olympics",
+        "script": "milestones/06_2018_mlperf/01_optimization_olympics.py",
+        "required_modules": list(range(1, 17)),  # Needs up to Module 16 (Compression)
+        "description": "Compress and accelerate your neural network",
         "historical_context": "MLPerf standardized ML benchmarks",
         "emoji": "🏆"
     }
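Each MILESTONE_SCRIPTS entry pairs a runnable script with the module prerequisites it needs, so a runner can gate execution on student progress. A minimal sketch of such a gate; the `can_run` helper and `completed_modules` set are illustrative, not from the repo:

def can_run(milestone, completed_modules):
    """True once every module listed in required_modules is finished."""
    return set(milestone["required_modules"]) <= completed_modules

mlperf = {
    "script": "milestones/06_2018_mlperf/01_optimization_olympics.py",
    "required_modules": list(range(1, 17)),  # Modules 1-16, per the hunk above
}
print(can_run(mlperf, set(range(1, 17))))   # True: modules 1-16 done
print(can_run(mlperf, set(range(1, 14))))   # False: modules 14-16 still missing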