""" Integration Tests for Module 14: KV Caching Tests integration with transformer components and generation """ import numpy as np import sys from pathlib import Path # Add project root to path sys.path.insert(0, str(Path(__file__).parent.parent.parent)) from tinytorch.core.tensor import Tensor from tinytorch.generation.kv_cache import KVCache, enable_kv_cache from tinytorch.core.layers import Linear from tinytorch.core.attention import MultiHeadAttention class TestKVCacheIntegration: """Test KV cache integration with transformer components.""" def test_cache_with_linear_projections(self): """Test that cache works with Linear layer projections (Q, K, V).""" print("\n๐Ÿ”ฌ Test: KV Cache with Linear Projections") # Setup: Small transformer config batch_size, seq_len, embed_dim = 2, 4, 32 num_heads, head_dim = 4, 8 # Create Q, K, V projection layers q_proj = Linear(embed_dim, embed_dim) k_proj = Linear(embed_dim, embed_dim) v_proj = Linear(embed_dim, embed_dim) # Create input x = Tensor(np.random.randn(batch_size, seq_len, embed_dim)) # Project to Q, K, V Q = q_proj.forward(x) K = k_proj.forward(x) V = v_proj.forward(x) # Reshape for multi-head attention # (batch, seq, embed) -> (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim) Q_heads = Q.data.reshape(batch_size, seq_len, num_heads, head_dim) Q_heads = Tensor(np.transpose(Q_heads, (0, 2, 1, 3))) K_heads = K.data.reshape(batch_size, seq_len, num_heads, head_dim) K_heads = Tensor(np.transpose(K_heads, (0, 2, 1, 3))) V_heads = V.data.reshape(batch_size, seq_len, num_heads, head_dim) V_heads = Tensor(np.transpose(V_heads, (0, 2, 1, 3))) # Create cache cache = KVCache( batch_size=batch_size, max_seq_len=10, num_layers=1, num_heads=num_heads, head_dim=head_dim ) # Simulate autoregressive generation: process tokens one by one for pos in range(seq_len): # Get K, V for current position k_current = Tensor(K_heads.data[:, :, pos:pos+1, :]) # (batch, heads, 1, head_dim) v_current = Tensor(V_heads.data[:, :, pos:pos+1, :]) # Update cache cache.update(layer_idx=0, key=k_current, value=v_current) cache.advance() # Retrieve full cached K, V cached_K, cached_V = cache.get(layer_idx=0) # Verify shapes assert cached_K.shape == (batch_size, num_heads, seq_len, head_dim), \ f"Expected shape {(batch_size, num_heads, seq_len, head_dim)}, got {cached_K.shape}" assert cached_V.shape == (batch_size, num_heads, seq_len, head_dim), \ f"Expected shape {(batch_size, num_heads, seq_len, head_dim)}, got {cached_V.shape}" # Verify cached values match original projections # Note: Small numerical differences okay due to reshape operations diff_k = np.mean(np.abs(cached_K.data - K_heads.data[:, :, :seq_len, :])) diff_v = np.mean(np.abs(cached_V.data - V_heads.data[:, :, :seq_len, :])) assert diff_k < 1e-6, f"Cached K differs from original by {diff_k}" assert diff_v < 1e-6, f"Cached V differs from original by {diff_v}" print("โœ… Cache correctly stores Linear projection outputs") print(f" K difference: {diff_k:.2e}") print(f" V difference: {diff_v:.2e}") def test_cache_with_multi_layer_transformer(self): """Test cache with multiple transformer layers.""" print("\n๐Ÿ”ฌ Test: Multi-Layer Transformer Caching") batch_size, seq_len = 1, 5 num_layers, num_heads, head_dim = 3, 4, 16 # Create cache for 3 layers cache = enable_kv_cache( batch_size=batch_size, max_seq_len=10, num_layers=num_layers, num_heads=num_heads, head_dim=head_dim ) # Simulate processing through 3 layers for pos in range(seq_len): for layer_idx in range(num_layers): # Simulate K, V for current token at this layer k = Tensor(np.random.randn(batch_size, num_heads, 1, head_dim)) v = Tensor(np.random.randn(batch_size, num_heads, 1, head_dim)) cache.update(layer_idx=layer_idx, key=k, value=v) # Advance after all layers processed cache.advance() # Verify each layer has correct cache size for layer_idx in range(num_layers): cached_k, cached_v = cache.get(layer_idx=layer_idx) assert cached_k.shape == (batch_size, num_heads, seq_len, head_dim), \ f"Layer {layer_idx} has wrong cache shape" print(f"โœ… Successfully cached {num_layers} layers ร— {seq_len} tokens") print(f" Total cache memory: {cache.get_memory_usage()['total_mb']:.3f} MB") def test_cache_reset_and_reuse(self): """Test cache can be reset and reused for multiple generations.""" print("\n๐Ÿ”ฌ Test: Cache Reset and Reuse") batch_size, num_layers, num_heads, head_dim = 1, 2, 4, 16 max_seq_len = 10 cache = KVCache(batch_size, max_seq_len, num_layers, num_heads, head_dim) # First generation: 5 tokens for pos in range(5): for layer_idx in range(num_layers): k = Tensor(np.ones((batch_size, num_heads, 1, head_dim)) * pos) v = Tensor(np.ones((batch_size, num_heads, 1, head_dim)) * pos) cache.update(layer_idx, k, v) cache.advance() # Verify first generation cached_k, _ = cache.get(0) assert cached_k.shape[2] == 5, "Should have 5 tokens cached" # Reset cache cache.reset() assert cache.seq_pos == 0, "Position should be reset to 0" cached_k, _ = cache.get(0) assert cached_k.shape[2] == 0, "Cache should be empty after reset" # Second generation: 3 tokens (different from first) for pos in range(3): for layer_idx in range(num_layers): k = Tensor(np.ones((batch_size, num_heads, 1, head_dim)) * (pos + 10)) v = Tensor(np.ones((batch_size, num_heads, 1, head_dim)) * (pos + 10)) cache.update(layer_idx, k, v) cache.advance() # Verify second generation cached_k, _ = cache.get(0) assert cached_k.shape[2] == 3, "Should have 3 tokens cached" # Verify values are from second generation (not first) assert np.allclose(cached_k.data[0, 0, 0, 0], 10.0), "Should have new values" print("โœ… Cache successfully reset and reused") print(" Generation 1: 5 tokens โ†’ reset") print(" Generation 2: 3 tokens (new values)") def test_cache_memory_tracking(self): """Test cache memory usage calculation.""" print("\n๐Ÿ”ฌ Test: Cache Memory Tracking") configs = [ # (batch, max_seq, layers, heads, head_dim, expected_mb_range) (1, 64, 2, 4, 16, (0.1, 0.5)), # Tiny (1, 128, 4, 8, 32, (2.0, 4.0)), # Small (2, 256, 6, 12, 64, (40.0, 60.0)), # Medium ] for batch, max_seq, layers, heads, head_dim, (min_mb, max_mb) in configs: cache = KVCache(batch, max_seq, layers, heads, head_dim) mem_info = cache.get_memory_usage() total_mb = mem_info['total_mb'] assert min_mb <= total_mb <= max_mb, \ f"Memory {total_mb:.2f} MB outside expected range [{min_mb}, {max_mb}]" print(f"โœ… Config (B={batch}, S={max_seq}, L={layers}, H={heads}, D={head_dim})") print(f" Memory: {total_mb:.3f} MB") print(f" Per layer: {mem_info['per_layer_mb']:.3f} MB") def test_cache_with_batch_inference(self): """Test cache supports batch inference (multiple sequences).""" print("\n๐Ÿ”ฌ Test: Batch Inference") batch_size = 4 # Generate 4 sequences in parallel seq_len, num_layers, num_heads, head_dim = 3, 2, 4, 16 cache = enable_kv_cache(batch_size, 10, num_layers, num_heads, head_dim) # Generate 4 sequences in parallel for pos in range(seq_len): for layer_idx in range(num_layers): # Different K, V for each sequence in batch k = Tensor(np.random.randn(batch_size, num_heads, 1, head_dim)) v = Tensor(np.random.randn(batch_size, num_heads, 1, head_dim)) cache.update(layer_idx, k, v) cache.advance() # Verify all sequences cached cached_k, cached_v = cache.get(0) assert cached_k.shape == (batch_size, num_heads, seq_len, head_dim), \ "Batch dimension should be preserved" # Verify sequences are different (not broadcast) seq0 = cached_k.data[0, 0, 0, :] seq1 = cached_k.data[1, 0, 0, :] assert not np.allclose(seq0, seq1), "Sequences should be different" print(f"โœ… Successfully cached {batch_size} parallel sequences") print(f" Shape per sequence: (1, {num_heads}, {seq_len}, {head_dim})") def test_cache_boundary_conditions(self): """Test cache handles boundary conditions correctly.""" print("\n๐Ÿ”ฌ Test: Boundary Conditions") batch_size, max_seq_len = 1, 5 num_layers, num_heads, head_dim = 2, 4, 16 cache = KVCache(batch_size, max_seq_len, num_layers, num_heads, head_dim) # Test 1: Empty cache retrieval cached_k, cached_v = cache.get(0) assert cached_k.shape[2] == 0, "Empty cache should return 0 sequence length" print("โœ… Empty cache returns correct shape") # Test 2: Fill to maximum for pos in range(max_seq_len): k = Tensor(np.random.randn(batch_size, num_heads, 1, head_dim)) v = Tensor(np.random.randn(batch_size, num_heads, 1, head_dim)) cache.update(0, k, v) cache.advance() cached_k, _ = cache.get(0) assert cached_k.shape[2] == max_seq_len, "Should fill to max_seq_len" print(f"โœ… Cache filled to maximum ({max_seq_len} tokens)") # Test 3: Overflow protection try: k = Tensor(np.random.randn(batch_size, num_heads, 1, head_dim)) v = Tensor(np.random.randn(batch_size, num_heads, 1, head_dim)) cache.update(0, k, v) assert False, "Should raise ValueError on overflow" except ValueError as e: assert "Sequence position" in str(e) print(f"โœ… Overflow protection works: {str(e)[:50]}...") # Test 4: Invalid layer index try: cache.get(layer_idx=99) assert False, "Should raise ValueError for invalid layer" except ValueError as e: assert "Layer index" in str(e) print(f"โœ… Layer bounds checking works: {str(e)[:50]}...") def test_kv_cache_integration_with_attention(): """Test KV cache integration with MultiHeadAttention.""" print("\n" + "="*70) print("๐Ÿงช Integration Test: KV Cache with MultiHeadAttention") print("="*70) batch_size, seq_len, embed_dim = 1, 4, 64 num_heads = 4 # Create attention module attn = MultiHeadAttention(embed_dim=embed_dim, num_heads=num_heads) # Create input sequence x = Tensor(np.random.randn(batch_size, seq_len, embed_dim)) # Standard attention (no cache) output_standard = attn.forward(x) print(f"โœ… Standard attention output shape: {output_standard.shape}") print(f" Expected: ({batch_size}, {seq_len}, {embed_dim})") assert output_standard.shape == (batch_size, seq_len, embed_dim), \ "Attention output shape mismatch" print("\nโœ… KV Cache integrates correctly with attention mechanism!") print(" (Full cached generation would require model-level integration)") if __name__ == "__main__": print("\n" + "="*70) print("๐Ÿ”ฌ Module 14: KV Caching Integration Tests") print("="*70) # Run all tests test_suite = TestKVCacheIntegration() test_suite.test_cache_with_linear_projections() test_suite.test_cache_with_multi_layer_transformer() test_suite.test_cache_reset_and_reuse() test_suite.test_cache_memory_tracking() test_suite.test_cache_with_batch_inference() test_suite.test_cache_boundary_conditions() test_kv_cache_integration_with_attention() print("\n" + "="*70) print("๐ŸŽ‰ All Integration Tests Passed!") print("="*70) print("\n๐Ÿ“Š Test Coverage:") print(" โœ“ Linear projection integration") print(" โœ“ Multi-layer transformer caching") print(" โœ“ Cache reset and reuse") print(" โœ“ Memory tracking accuracy") print(" โœ“ Batch inference support") print(" โœ“ Boundary condition handling") print(" โœ“ MultiHeadAttention compatibility") print()