Remove obsolete KV cache test file
- Delete test_kv_cache_milestone.py: standalone test file no longer needed after module integration
@@ -1,143 +0,0 @@
#!/usr/bin/env python3
"""
Quick test to demonstrate KV cache integration with chatbot milestone.

Tests:
1. Generation WITHOUT cache (baseline)
2. Generation WITH cache enabled (Module 14)
3. Verify cache infrastructure works without breaking model
"""

import sys
import time
import numpy as np

# Add paths
sys.path.insert(0, 'milestones/05_2017_transformer')

print("=" * 70)
print("🧪 Testing KV Cache Integration with TinyTalks ChatBot")
print("=" * 70)
print()

# Import components
print("📦 Importing TinyTorch components...")
from tinytorch.core.tensor import Tensor
from tinytorch.text.tokenization import CharTokenizer
from tinytorch.generation.kv_cache import enable_kv_cache, disable_kv_cache

# Import the TinyGPT model from milestone
from vaswani_chatgpt import TinyGPT

print("✅ All imports successful")
print()

# Create a tiny model for testing
print("🏗️ Building tiny test model...")
model = TinyGPT(
    vocab_size=50,    # Small vocab for testing
    embed_dim=32,     # Tiny model
    num_layers=2,
    num_heads=2,
    max_seq_len=16
)
print(f"✅ Model created: {model.count_parameters():,} parameters")
print()

# Create tokenizer
tokenizer = CharTokenizer(list("abcdefghijklmnopqrstuvwxyz .,!?"))
print(f"✅ Tokenizer created: {tokenizer.vocab_size} tokens")
print()

# Test 1: Generation WITHOUT cache
print("=" * 70)
print("🔬 Test 1: Generation WITHOUT Cache (Baseline)")
print("=" * 70)
prompt = "hello"
print(f"Prompt: '{prompt}'")
print()

start = time.time()
response1, stats1 = model.generate(
    tokenizer,
    prompt=prompt,
    max_new_tokens=10,
    return_stats=True,
    use_cache=False
)
elapsed1 = time.time() - start

print(f"Generated: '{response1[:50]}...'")
print(f"Time: {elapsed1:.3f}s")
print(f"Speed: {stats1['tokens_per_sec']:.1f} tokens/sec")
print(f"Tokens: {stats1['tokens_generated']}")
print()

# Test 2: Generation WITH cache
print("=" * 70)
print("🔬 Test 2: Generation WITH Cache (Module 14)")
print("=" * 70)
print(f"Prompt: '{prompt}'")
print()

start = time.time()
response2, stats2 = model.generate(
    tokenizer,
    prompt=prompt,
    max_new_tokens=10,
    return_stats=True,
    use_cache=True
)
elapsed2 = time.time() - start

print(f"Generated: '{response2[:50]}...'")
print(f"Time: {elapsed2:.3f}s")
print(f"Speed: {stats2['tokens_per_sec']:.1f} tokens/sec")
print(f"Tokens: {stats2['tokens_generated']}")
print()

# Summary
print("=" * 70)
print("📊 Summary")
print("=" * 70)
print(f"Without cache: {stats1['tokens_per_sec']:.1f} tok/s")
print(f"With cache:    {stats2['tokens_per_sec']:.1f} tok/s")
print()

# Check if cache infrastructure was activated
if hasattr(model, '_cache_enabled'):
    print("✅ Cache infrastructure successfully integrated!")
    print(f"   Cache enabled: {model._cache_enabled}")
    if hasattr(model, '_kv_cache'):
        mem = model._kv_cache.get_memory_usage()
        print(f"   Cache memory: {mem['total_mb']:.2f} MB")
else:
    print("⚠️ Cache infrastructure not found on model")

print()
print("=" * 70)
print("📝 Note: Current Implementation")
print("=" * 70)
print("""
This is a REAL implementation of KV caching with actual speedup:
✅ enable_kv_cache() patches the model non-invasively
✅ Cache stores K,V for all previous tokens
✅ Only computes K,V for the NEW token during generation
✅ Uses cached K,V history for attention computation
✅ Achieves 5-7x speedup on this tiny model

The speedup comes from transforming O(n²) to O(n):
- WITHOUT cache: recomputes attention for ALL tokens at each step
- WITH cache: only computes attention for the NEW token, retrieves history

For longer sequences, the speedup will be even higher (10-15x+)!

Students learn:
1. Non-invasive optimization patterns (Module 14 enhances Module 12)
2. Inference vs training optimizations (cache only during generation)
3. Memory-compute trade-offs (small cache = big speedup)
4. Real ML systems engineering (this is how ChatGPT works!)
""")

print("✅ Test complete!")
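The closing note in the deleted script describes the core KV-cache idea: store K and V for all earlier tokens and compute attention only for the newly generated token. Below is a minimal, self-contained sketch of that mechanism for a single attention head. It is a generic illustration under stated assumptions (made-up names like d_model, W_q, attend_with_cache, and random weights), not TinyTorch's actual enable_kv_cache or TinyGPT code.

# Minimal single-head KV-cache sketch (hypothetical illustration, not the
# TinyTorch implementation). Per step, only the new token's K and V are
# computed and appended; attention runs over the stored history, so the
# per-token cost grows linearly instead of quadratically.
import numpy as np

rng = np.random.default_rng(0)
d_model = 32

# Fixed random projections stand in for learned attention weights.
W_q = rng.normal(size=(d_model, d_model))
W_k = rng.normal(size=(d_model, d_model))
W_v = rng.normal(size=(d_model, d_model))

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def attend_with_cache(x_new, cache):
    """Process ONE new token embedding, reusing cached K/V from prior steps."""
    k_new = x_new @ W_k                                        # K for the new token only
    v_new = x_new @ W_v                                        # V for the new token only
    cache["k"] = np.concatenate([cache["k"], k_new], axis=0)   # grow the history
    cache["v"] = np.concatenate([cache["v"], v_new], axis=0)
    q_new = x_new @ W_q                                        # query only for the new position
    scores = q_new @ cache["k"].T / np.sqrt(d_model)           # (1, tokens_so_far)
    return softmax(scores) @ cache["v"]                        # (1, d_model)

# Usage: feed tokens one at a time; the cache holds K/V for every step so far.
cache = {"k": np.zeros((0, d_model)), "v": np.zeros((0, d_model))}
for step in range(5):
    x_new = rng.normal(size=(1, d_model))   # embedding of the newly generated token
    out = attend_with_cache(x_new, cache)
    print(f"step {step}: cache holds {cache['k'].shape[0]} K/V entries, output shape {out.shape}")

In a real model the same trick is applied per layer and per head, and a helper like the script's enable_kv_cache() can attach such a cache to an existing model by wrapping its attention calls rather than modifying the model code, which is the non-invasive pattern the note refers to.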