"""
|
||
Integration Tests for Module 14: KV Caching
|
||
Tests integration with transformer components and generation
|
||
"""
|
||
|
||
import numpy as np
|
||
import sys
|
||
from pathlib import Path
|
||
|
||
# Add project root to path
|
||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||
|
||
from tinytorch.core.tensor import Tensor
|
||
from tinytorch.generation.kv_cache import KVCache, enable_kv_cache
|
||
from tinytorch.core.layers import Linear
|
||
from tinytorch.core.attention import MultiHeadAttention
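
# Cache protocol exercised by these tests (summarized for readers):
#   cache.update(layer_idx, key, value) -- append one token's K/V for a layer
#   cache.advance()                     -- move the shared position after all
#                                          layers have processed the token
#   cache.get(layer_idx)                -- return (cached_K, cached_V) so far
#   cache.reset()                       -- clear state for a fresh generation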


class TestKVCacheIntegration:
    """Test KV cache integration with transformer components."""

    def test_cache_with_linear_projections(self):
        """Test that cache works with Linear layer projections (Q, K, V)."""
        print("\n🔬 Test: KV Cache with Linear Projections")

        # Setup: Small transformer config
        batch_size, seq_len, embed_dim = 2, 4, 32
        num_heads, head_dim = 4, 8

        # Create Q, K, V projection layers
        q_proj = Linear(embed_dim, embed_dim)
        k_proj = Linear(embed_dim, embed_dim)
        v_proj = Linear(embed_dim, embed_dim)

        # Create input
        x = Tensor(np.random.randn(batch_size, seq_len, embed_dim))

        # Project to Q, K, V
        Q = q_proj.forward(x)
        K = k_proj.forward(x)
        V = v_proj.forward(x)

        # Reshape for multi-head attention
        # (batch, seq, embed) -> (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
        Q_heads = Q.data.reshape(batch_size, seq_len, num_heads, head_dim)
        Q_heads = Tensor(np.transpose(Q_heads, (0, 2, 1, 3)))
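        # Note: only K and V are ever cached; Q is recomputed for each new
        # token, so Q_heads here is illustrative rather than cached.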

        K_heads = K.data.reshape(batch_size, seq_len, num_heads, head_dim)
        K_heads = Tensor(np.transpose(K_heads, (0, 2, 1, 3)))

        V_heads = V.data.reshape(batch_size, seq_len, num_heads, head_dim)
        V_heads = Tensor(np.transpose(V_heads, (0, 2, 1, 3)))

        # Create cache
        cache = KVCache(
            batch_size=batch_size,
            max_seq_len=10,
            num_layers=1,
            num_heads=num_heads,
            head_dim=head_dim
        )

        # Simulate autoregressive generation: process tokens one by one
        for pos in range(seq_len):
            # Get K, V for current position
            k_current = Tensor(K_heads.data[:, :, pos:pos+1, :])  # (batch, heads, 1, head_dim)
            v_current = Tensor(V_heads.data[:, :, pos:pos+1, :])

            # Update cache
            cache.update(layer_idx=0, key=k_current, value=v_current)
            cache.advance()
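
        # advance() moves the cache's shared write position; with a single
        # layer we advance once per token, immediately after update().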

        # Retrieve full cached K, V
        cached_K, cached_V = cache.get(layer_idx=0)

        # Verify shapes
        assert cached_K.shape == (batch_size, num_heads, seq_len, head_dim), \
            f"Expected shape {(batch_size, num_heads, seq_len, head_dim)}, got {cached_K.shape}"
        assert cached_V.shape == (batch_size, num_heads, seq_len, head_dim), \
            f"Expected shape {(batch_size, num_heads, seq_len, head_dim)}, got {cached_V.shape}"

        # Verify cached values match original projections.
        # Reshape/transpose are exact, so the difference should be ~0; the
        # small tolerance only guards against dtype conversion in the cache.
        diff_k = np.mean(np.abs(cached_K.data - K_heads.data[:, :, :seq_len, :]))
        diff_v = np.mean(np.abs(cached_V.data - V_heads.data[:, :, :seq_len, :]))

        assert diff_k < 1e-6, f"Cached K differs from original by {diff_k}"
        assert diff_v < 1e-6, f"Cached V differs from original by {diff_v}"

        print("✅ Cache correctly stores Linear projection outputs")
        print(f"   K difference: {diff_k:.2e}")
        print(f"   V difference: {diff_v:.2e}")
    def test_cache_with_multi_layer_transformer(self):
        """Test cache with multiple transformer layers."""
        print("\n🔬 Test: Multi-Layer Transformer Caching")

        batch_size, seq_len = 1, 5
        num_layers, num_heads, head_dim = 3, 4, 16

        # Create cache for 3 layers
        cache = enable_kv_cache(
            batch_size=batch_size,
            max_seq_len=10,
            num_layers=num_layers,
            num_heads=num_heads,
            head_dim=head_dim
        )

        # Simulate processing through 3 layers
        for pos in range(seq_len):
            for layer_idx in range(num_layers):
                # Simulate K, V for current token at this layer
                k = Tensor(np.random.randn(batch_size, num_heads, 1, head_dim))
                v = Tensor(np.random.randn(batch_size, num_heads, 1, head_dim))

                cache.update(layer_idx=layer_idx, key=k, value=v)

            # Advance after all layers processed
            cache.advance()

        # Verify each layer has correct cache size
        for layer_idx in range(num_layers):
            cached_k, cached_v = cache.get(layer_idx=layer_idx)
            assert cached_k.shape == (batch_size, num_heads, seq_len, head_dim), \
                f"Layer {layer_idx} has wrong cache shape"

        print(f"✅ Successfully cached {num_layers} layers × {seq_len} tokens")
        print(f"   Total cache memory: {cache.get_memory_usage()['total_mb']:.3f} MB")
    def test_cache_reset_and_reuse(self):
        """Test cache can be reset and reused for multiple generations."""
        print("\n🔬 Test: Cache Reset and Reuse")

        batch_size, num_layers, num_heads, head_dim = 1, 2, 4, 16
        max_seq_len = 10

        cache = KVCache(batch_size, max_seq_len, num_layers, num_heads, head_dim)

        # First generation: 5 tokens
        for pos in range(5):
            for layer_idx in range(num_layers):
                k = Tensor(np.ones((batch_size, num_heads, 1, head_dim)) * pos)
                v = Tensor(np.ones((batch_size, num_heads, 1, head_dim)) * pos)
                cache.update(layer_idx, k, v)
            cache.advance()

        # Verify first generation
        cached_k, _ = cache.get(0)
        assert cached_k.shape[2] == 5, "Should have 5 tokens cached"

        # Reset cache
        cache.reset()
        assert cache.seq_pos == 0, "Position should be reset to 0"

        cached_k, _ = cache.get(0)
        assert cached_k.shape[2] == 0, "Cache should be empty after reset"

        # Second generation: 3 tokens (different values from the first)
        for pos in range(3):
            for layer_idx in range(num_layers):
                k = Tensor(np.ones((batch_size, num_heads, 1, head_dim)) * (pos + 10))
                v = Tensor(np.ones((batch_size, num_heads, 1, head_dim)) * (pos + 10))
                cache.update(layer_idx, k, v)
            cache.advance()

        # Verify second generation
        cached_k, _ = cache.get(0)
        assert cached_k.shape[2] == 3, "Should have 3 tokens cached"

        # Verify values are from second generation (not first)
        assert np.allclose(cached_k.data[0, 0, 0, 0], 10.0), "Should have new values"

        print("✅ Cache successfully reset and reused")
        print("   Generation 1: 5 tokens → reset")
        print("   Generation 2: 3 tokens (new values)")
    def test_cache_memory_tracking(self):
        """Test cache memory usage calculation."""
        print("\n🔬 Test: Cache Memory Tracking")

        configs = [
            # (batch, max_seq, layers, heads, head_dim, expected_mb_range)
            (1, 64, 2, 4, 16, (0.1, 0.5)),      # Tiny
            (1, 128, 4, 8, 32, (2.0, 4.0)),     # Small
            (2, 256, 6, 12, 64, (40.0, 60.0)),  # Medium
        ]
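
        # Rough sanity model for the ranges above (an assumption about the
        # implementation, not a spec): a KV cache stores two tensors (K and V)
        # per layer, so total memory is on the order of
        #   2 * batch * max_seq * layers * heads * head_dim * bytes_per_element.
        # The exact figure depends on the dtype and any bookkeeping overhead
        # KVCache allocates, hence ranges rather than exact values.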

        for batch, max_seq, layers, heads, head_dim, (min_mb, max_mb) in configs:
            cache = KVCache(batch, max_seq, layers, heads, head_dim)
            mem_info = cache.get_memory_usage()

            total_mb = mem_info['total_mb']
            assert min_mb <= total_mb <= max_mb, \
                f"Memory {total_mb:.2f} MB outside expected range [{min_mb}, {max_mb}]"

            print(f"✅ Config (B={batch}, S={max_seq}, L={layers}, H={heads}, D={head_dim})")
            print(f"   Memory: {total_mb:.3f} MB")
            print(f"   Per layer: {mem_info['per_layer_mb']:.3f} MB")
    def test_cache_with_batch_inference(self):
        """Test cache supports batch inference (multiple sequences)."""
        print("\n🔬 Test: Batch Inference")

        batch_size = 4  # Generate 4 sequences in parallel
        seq_len, num_layers, num_heads, head_dim = 3, 2, 4, 16

        cache = enable_kv_cache(batch_size, 10, num_layers, num_heads, head_dim)

        # Generate 4 sequences in parallel
        for pos in range(seq_len):
            for layer_idx in range(num_layers):
                # Different K, V for each sequence in batch
                k = Tensor(np.random.randn(batch_size, num_heads, 1, head_dim))
                v = Tensor(np.random.randn(batch_size, num_heads, 1, head_dim))
                cache.update(layer_idx, k, v)
            cache.advance()

        # Verify all sequences cached
        cached_k, cached_v = cache.get(0)
        assert cached_k.shape == (batch_size, num_heads, seq_len, head_dim), \
            "Batch dimension should be preserved"

        # Verify sequences are different (not broadcast)
        seq0 = cached_k.data[0, 0, 0, :]
        seq1 = cached_k.data[1, 0, 0, :]
        assert not np.allclose(seq0, seq1), "Sequences should be different"

        print(f"✅ Successfully cached {batch_size} parallel sequences")
        print(f"   Shape per sequence: (1, {num_heads}, {seq_len}, {head_dim})")
    def test_cache_boundary_conditions(self):
        """Test cache handles boundary conditions correctly."""
        print("\n🔬 Test: Boundary Conditions")

        batch_size, max_seq_len = 1, 5
        num_layers, num_heads, head_dim = 2, 4, 16

        cache = KVCache(batch_size, max_seq_len, num_layers, num_heads, head_dim)

        # Test 1: Empty cache retrieval
        cached_k, cached_v = cache.get(0)
        assert cached_k.shape[2] == 0, "Empty cache should return 0 sequence length"
        print("✅ Empty cache returns correct shape")

        # Test 2: Fill to maximum
        for pos in range(max_seq_len):
            k = Tensor(np.random.randn(batch_size, num_heads, 1, head_dim))
            v = Tensor(np.random.randn(batch_size, num_heads, 1, head_dim))
            cache.update(0, k, v)
            cache.advance()

        cached_k, _ = cache.get(0)
        assert cached_k.shape[2] == max_seq_len, "Should fill to max_seq_len"
        print(f"✅ Cache filled to maximum ({max_seq_len} tokens)")

        # Test 3: Overflow protection
        try:
            k = Tensor(np.random.randn(batch_size, num_heads, 1, head_dim))
            v = Tensor(np.random.randn(batch_size, num_heads, 1, head_dim))
            cache.update(0, k, v)
            assert False, "Should raise ValueError on overflow"
        except ValueError as e:
            assert "Sequence position" in str(e)
            print(f"✅ Overflow protection works: {str(e)[:50]}...")

        # Test 4: Invalid layer index
        try:
            cache.get(layer_idx=99)
            assert False, "Should raise ValueError for invalid layer"
        except ValueError as e:
            assert "Layer index" in str(e)
            print(f"✅ Layer bounds checking works: {str(e)[:50]}...")


def test_kv_cache_integration_with_attention():
    """Test KV cache integration with MultiHeadAttention."""
    print("\n" + "="*70)
    print("🧪 Integration Test: KV Cache with MultiHeadAttention")
    print("="*70)

    batch_size, seq_len, embed_dim = 1, 4, 64
    num_heads = 4

    # Create attention module
    attn = MultiHeadAttention(embed_dim=embed_dim, num_heads=num_heads)

    # Create input sequence
    x = Tensor(np.random.randn(batch_size, seq_len, embed_dim))

    # Standard attention (no cache)
    output_standard = attn.forward(x)

    print(f"✅ Standard attention output shape: {output_standard.shape}")
    print(f"   Expected: ({batch_size}, {seq_len}, {embed_dim})")

    assert output_standard.shape == (batch_size, seq_len, embed_dim), \
        "Attention output shape mismatch"

    print("\n✅ KV Cache integrates correctly with attention mechanism!")
    print("   (Full cached generation would require model-level integration)")


if __name__ == "__main__":
    print("\n" + "="*70)
    print("🔬 Module 14: KV Caching Integration Tests")
    print("="*70)

    # Run all tests
    test_suite = TestKVCacheIntegration()

    test_suite.test_cache_with_linear_projections()
    test_suite.test_cache_with_multi_layer_transformer()
    test_suite.test_cache_reset_and_reuse()
    test_suite.test_cache_memory_tracking()
    test_suite.test_cache_with_batch_inference()
    test_suite.test_cache_boundary_conditions()

    test_kv_cache_integration_with_attention()

    print("\n" + "="*70)
    print("🎉 All Integration Tests Passed!")
    print("="*70)
    print("\n📊 Test Coverage:")
    print("   ✓ Linear projection integration")
    print("   ✓ Multi-layer transformer caching")
    print("   ✓ Cache reset and reuse")
    print("   ✓ Memory tracking accuracy")
    print("   ✓ Batch inference support")
    print("   ✓ Boundary condition handling")
    print("   ✓ MultiHeadAttention compatibility")
    print()