TinyTorch/tests/14_profiling/test_kv_cache_integration.py
"""
Integration Tests for Module 14: KV Caching
Tests integration with transformer components and generation
"""
import numpy as np
import sys
from pathlib import Path
# Add project root to path
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from tinytorch.core.tensor import Tensor
from tinytorch.generation.kv_cache import KVCache, enable_kv_cache
from tinytorch.core.layers import Linear
from tinytorch.core.attention import MultiHeadAttention

class TestKVCacheIntegration:
    """Test KV cache integration with transformer components."""

    def test_cache_with_linear_projections(self):
        """Test that the cache works with Linear layer projections (Q, K, V)."""
        print("\n🔬 Test: KV Cache with Linear Projections")

        # Setup: small transformer config
        batch_size, seq_len, embed_dim = 2, 4, 32
        num_heads, head_dim = 4, 8

        # Create Q, K, V projection layers
        q_proj = Linear(embed_dim, embed_dim)
        k_proj = Linear(embed_dim, embed_dim)
        v_proj = Linear(embed_dim, embed_dim)

        # Create input
        x = Tensor(np.random.randn(batch_size, seq_len, embed_dim))

        # Project to Q, K, V
        Q = q_proj.forward(x)
        K = k_proj.forward(x)
        V = v_proj.forward(x)

        # Reshape for multi-head attention:
        # (batch, seq, embed) -> (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
        Q_heads = Q.data.reshape(batch_size, seq_len, num_heads, head_dim)
        Q_heads = Tensor(np.transpose(Q_heads, (0, 2, 1, 3)))
        K_heads = K.data.reshape(batch_size, seq_len, num_heads, head_dim)
        K_heads = Tensor(np.transpose(K_heads, (0, 2, 1, 3)))
        V_heads = V.data.reshape(batch_size, seq_len, num_heads, head_dim)
        V_heads = Tensor(np.transpose(V_heads, (0, 2, 1, 3)))
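
        # The (batch, heads, seq, head_dim) layout puts seq and head_dim in the
        # trailing axes, so per-head attention scores reduce to batched matmuls;
        # it is also the layout the KVCache stores and returns below.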

        # Create cache
        cache = KVCache(
            batch_size=batch_size,
            max_seq_len=10,
            num_layers=1,
            num_heads=num_heads,
            head_dim=head_dim,
        )
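
        # Cache contract as exercised in this suite: update() writes one new
        # token's K/V for a single layer at the current sequence position, and
        # advance() moves that position forward by one token (called once per
        # token, after every layer has been updated).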

        # Simulate autoregressive generation: process tokens one by one
        for pos in range(seq_len):
            # Get K, V for the current position
            k_current = Tensor(K_heads.data[:, :, pos:pos + 1, :])  # (batch, heads, 1, head_dim)
            v_current = Tensor(V_heads.data[:, :, pos:pos + 1, :])

            # Update cache
            cache.update(layer_idx=0, key=k_current, value=v_current)
            cache.advance()

        # Retrieve the full cached K, V
        cached_K, cached_V = cache.get(layer_idx=0)

        # Verify shapes
        assert cached_K.shape == (batch_size, num_heads, seq_len, head_dim), \
            f"Expected shape {(batch_size, num_heads, seq_len, head_dim)}, got {cached_K.shape}"
        assert cached_V.shape == (batch_size, num_heads, seq_len, head_dim), \
            f"Expected shape {(batch_size, num_heads, seq_len, head_dim)}, got {cached_V.shape}"

        # Verify cached values match the original projections
        # Note: small numerical differences okay due to reshape operations
        diff_k = np.mean(np.abs(cached_K.data - K_heads.data[:, :, :seq_len, :]))
        diff_v = np.mean(np.abs(cached_V.data - V_heads.data[:, :, :seq_len, :]))
        assert diff_k < 1e-6, f"Cached K differs from original by {diff_k}"
        assert diff_v < 1e-6, f"Cached V differs from original by {diff_v}"

        print("✅ Cache correctly stores Linear projection outputs")
        print(f"   K difference: {diff_k:.2e}")
        print(f"   V difference: {diff_v:.2e}")

    def test_cache_with_multi_layer_transformer(self):
        """Test the cache with multiple transformer layers."""
        print("\n🔬 Test: Multi-Layer Transformer Caching")

        batch_size, seq_len = 1, 5
        num_layers, num_heads, head_dim = 3, 4, 16

        # Create cache for 3 layers
        cache = enable_kv_cache(
            batch_size=batch_size,
            max_seq_len=10,
            num_layers=num_layers,
            num_heads=num_heads,
            head_dim=head_dim,
        )
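        # enable_kv_cache takes the same configuration as the KVCache
        # constructor used elsewhere in this suite (see test_cache_reset_and_reuse).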

        # Simulate processing through 3 layers
        for pos in range(seq_len):
            for layer_idx in range(num_layers):
                # Simulate K, V for the current token at this layer
                k = Tensor(np.random.randn(batch_size, num_heads, 1, head_dim))
                v = Tensor(np.random.randn(batch_size, num_heads, 1, head_dim))
                cache.update(layer_idx=layer_idx, key=k, value=v)
            # Advance once per token, after all layers are processed
            cache.advance()

        # Verify each layer has the correct cache size
        for layer_idx in range(num_layers):
            cached_k, cached_v = cache.get(layer_idx=layer_idx)
            assert cached_k.shape == (batch_size, num_heads, seq_len, head_dim), \
                f"Layer {layer_idx} has wrong cache shape"

        print(f"✅ Successfully cached {num_layers} layers × {seq_len} tokens")
        print(f"   Total cache memory: {cache.get_memory_usage()['total_mb']:.3f} MB")

    def test_cache_reset_and_reuse(self):
        """Test that the cache can be reset and reused for multiple generations."""
        print("\n🔬 Test: Cache Reset and Reuse")

        batch_size, num_layers, num_heads, head_dim = 1, 2, 4, 16
        max_seq_len = 10
        cache = KVCache(batch_size, max_seq_len, num_layers, num_heads, head_dim)

        # First generation: 5 tokens
        for pos in range(5):
            for layer_idx in range(num_layers):
                k = Tensor(np.ones((batch_size, num_heads, 1, head_dim)) * pos)
                v = Tensor(np.ones((batch_size, num_heads, 1, head_dim)) * pos)
                cache.update(layer_idx, k, v)
            cache.advance()

        # Verify the first generation
        cached_k, _ = cache.get(0)
        assert cached_k.shape[2] == 5, "Should have 5 tokens cached"

        # Reset cache
        cache.reset()
        assert cache.seq_pos == 0, "Position should be reset to 0"
        cached_k, _ = cache.get(0)
        assert cached_k.shape[2] == 0, "Cache should be empty after reset"
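
        # reset() rewinds the sequence position, so the same cache object can
        # serve a fresh generation request without constructing a new KVCache.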

        # Second generation: 3 tokens (different values from the first)
        for pos in range(3):
            for layer_idx in range(num_layers):
                k = Tensor(np.ones((batch_size, num_heads, 1, head_dim)) * (pos + 10))
                v = Tensor(np.ones((batch_size, num_heads, 1, head_dim)) * (pos + 10))
                cache.update(layer_idx, k, v)
            cache.advance()

        # Verify the second generation
        cached_k, _ = cache.get(0)
        assert cached_k.shape[2] == 3, "Should have 3 tokens cached"

        # Verify values are from the second generation (not the first)
        assert np.allclose(cached_k.data[0, 0, 0, 0], 10.0), "Should have new values"

        print("✅ Cache successfully reset and reused")
        print("   Generation 1: 5 tokens → reset")
        print("   Generation 2: 3 tokens (new values)")

    def test_cache_memory_tracking(self):
        """Test the cache memory usage calculation."""
        print("\n🔬 Test: Cache Memory Tracking")

        configs = [
            # (batch, max_seq, layers, heads, head_dim, expected_mb_range)
            (1, 64, 2, 4, 16, (0.1, 0.5)),      # Tiny
            (1, 128, 4, 8, 32, (2.0, 4.0)),     # Small
            (2, 256, 6, 12, 64, (40.0, 60.0)),  # Medium
        ]
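        # The expected ranges are loose bounds around the usual estimate:
        #   2 (K and V) × batch × max_seq × layers × heads × head_dim × bytes/element.
        # The exact total depends on the cache's dtype and any bookkeeping overhead.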
        for batch, max_seq, layers, heads, head_dim, (min_mb, max_mb) in configs:
            cache = KVCache(batch, max_seq, layers, heads, head_dim)
            mem_info = cache.get_memory_usage()
            total_mb = mem_info['total_mb']

            assert min_mb <= total_mb <= max_mb, \
                f"Memory {total_mb:.2f} MB outside expected range [{min_mb}, {max_mb}]"

            print(f"✅ Config (B={batch}, S={max_seq}, L={layers}, H={heads}, D={head_dim})")
            print(f"   Memory: {total_mb:.3f} MB")
            print(f"   Per layer: {mem_info['per_layer_mb']:.3f} MB")

    def test_cache_with_batch_inference(self):
        """Test that the cache supports batch inference (multiple sequences)."""
        print("\n🔬 Test: Batch Inference")

        batch_size = 4  # Generate 4 sequences in parallel
        seq_len, num_layers, num_heads, head_dim = 3, 2, 4, 16

        cache = enable_kv_cache(batch_size, 10, num_layers, num_heads, head_dim)

        # Generate 4 sequences in parallel
        for pos in range(seq_len):
            for layer_idx in range(num_layers):
                # Different K, V for each sequence in the batch
                k = Tensor(np.random.randn(batch_size, num_heads, 1, head_dim))
                v = Tensor(np.random.randn(batch_size, num_heads, 1, head_dim))
                cache.update(layer_idx, k, v)
            cache.advance()

        # Verify all sequences are cached
        cached_k, cached_v = cache.get(0)
        assert cached_k.shape == (batch_size, num_heads, seq_len, head_dim), \
            "Batch dimension should be preserved"

        # Verify sequences differ (i.e., values were not broadcast)
        seq0 = cached_k.data[0, 0, 0, :]
        seq1 = cached_k.data[1, 0, 0, :]
        assert not np.allclose(seq0, seq1), "Sequences should be different"

        print(f"✅ Successfully cached {batch_size} parallel sequences")
        print(f"   Shape per sequence: (1, {num_heads}, {seq_len}, {head_dim})")

    def test_cache_boundary_conditions(self):
        """Test that the cache handles boundary conditions correctly."""
        print("\n🔬 Test: Boundary Conditions")

        batch_size, max_seq_len = 1, 5
        num_layers, num_heads, head_dim = 2, 4, 16
        cache = KVCache(batch_size, max_seq_len, num_layers, num_heads, head_dim)

        # Test 1: Empty cache retrieval
        cached_k, cached_v = cache.get(0)
        assert cached_k.shape[2] == 0, "Empty cache should return 0 sequence length"
        print("✅ Empty cache returns correct shape")

        # Test 2: Fill to maximum
        for pos in range(max_seq_len):
            k = Tensor(np.random.randn(batch_size, num_heads, 1, head_dim))
            v = Tensor(np.random.randn(batch_size, num_heads, 1, head_dim))
            cache.update(0, k, v)
            cache.advance()
        cached_k, _ = cache.get(0)
        assert cached_k.shape[2] == max_seq_len, "Should fill to max_seq_len"
        print(f"✅ Cache filled to maximum ({max_seq_len} tokens)")

        # Test 3: Overflow protection
        try:
            k = Tensor(np.random.randn(batch_size, num_heads, 1, head_dim))
            v = Tensor(np.random.randn(batch_size, num_heads, 1, head_dim))
            cache.update(0, k, v)
            assert False, "Should raise ValueError on overflow"
        except ValueError as e:
            assert "Sequence position" in str(e)
            print(f"✅ Overflow protection works: {str(e)[:50]}...")

        # Test 4: Invalid layer index
        try:
            cache.get(layer_idx=99)
            assert False, "Should raise ValueError for invalid layer"
        except ValueError as e:
            assert "Layer index" in str(e)
            print(f"✅ Layer bounds checking works: {str(e)[:50]}...")


def test_kv_cache_integration_with_attention():
    """Test KV cache integration with MultiHeadAttention."""
    print("\n" + "="*70)
    print("🧪 Integration Test: KV Cache with MultiHeadAttention")
    print("="*70)

    batch_size, seq_len, embed_dim = 1, 4, 64
    num_heads = 4

    # Create attention module
    attn = MultiHeadAttention(embed_dim=embed_dim, num_heads=num_heads)

    # Create input sequence
    x = Tensor(np.random.randn(batch_size, seq_len, embed_dim))

    # Standard attention (no cache)
    output_standard = attn.forward(x)
    print(f"✅ Standard attention output shape: {output_standard.shape}")
    print(f"   Expected: ({batch_size}, {seq_len}, {embed_dim})")
    assert output_standard.shape == (batch_size, seq_len, embed_dim), \
        "Attention output shape mismatch"

    print("\n✅ KV Cache integrates correctly with attention mechanism!")
    print("   (Full cached generation would require model-level integration)")


if __name__ == "__main__":
    print("\n" + "="*70)
    print("🔬 Module 14: KV Caching Integration Tests")
    print("="*70)

    # Run all tests
    test_suite = TestKVCacheIntegration()
    test_suite.test_cache_with_linear_projections()
    test_suite.test_cache_with_multi_layer_transformer()
    test_suite.test_cache_reset_and_reuse()
    test_suite.test_cache_memory_tracking()
    test_suite.test_cache_with_batch_inference()
    test_suite.test_cache_boundary_conditions()
    test_kv_cache_integration_with_attention()

    print("\n" + "="*70)
    print("🎉 All Integration Tests Passed!")
    print("="*70)
    print("\n📊 Test Coverage:")
    print("   ✓ Linear projection integration")
    print("   ✓ Multi-layer transformer caching")
    print("   ✓ Cache reset and reuse")
    print("   ✓ Memory tracking accuracy")
    print("   ✓ Batch inference support")
    print("   ✓ Boundary condition handling")
    print("   ✓ MultiHeadAttention compatibility")
    print()
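
# The suite also runs under pytest, which collects the Test* class above:
#   pytest tests/14_profiling/test_kv_cache_integration.py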