Mirror of https://github.com/MLSysBook/TinyTorch.git (synced 2026-04-30 09:38:39 -05:00)
Fix regression tests for current API
- Update TransformerBlock to use mlp_ratio instead of hidden_dim
- Update PositionalEncoding argument order
- Fix MultiHeadAttention to use self-attention API
- Add missing MultiHeadAttention import
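For quick reference, the call patterns the updated tests rely on look like this. This is a minimal sketch assembled from the diff below: the argument names (mlp_ratio, dropout_prob, the (max_seq_len, embed_dim) order, the single-tensor self-attention call) are taken from the changed test lines, while the batch size and sequence length are illustrative values, not part of the commit.

import numpy as np
from tinytorch.core.tensor import Tensor
from tinytorch.nn import TransformerBlock, PositionalEncoding, MultiHeadAttention

batch_size, seq_length, embed_dim, num_heads = 2, 32, 128, 4  # batch/seq are illustrative

# TransformerBlock: mlp_ratio and dropout_prob replace hidden_dim and dropout
block = TransformerBlock(embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4, dropout_prob=0.1)

# PositionalEncoding: arguments are now (max_seq_len, embed_dim), in that order
pos_enc = PositionalEncoding(seq_length, embed_dim)

# MultiHeadAttention: self-attention takes a single tensor; Q, K, V are derived from x
mha = MultiHeadAttention(embed_dim, num_heads)
x = Tensor(np.random.randn(batch_size, seq_length, embed_dim))
out = mha(x)  # expected shape: (batch_size, seq_length, embed_dim)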
@@ -42,7 +42,7 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
 from tinytorch.core.tensor import Tensor
 from tinytorch.core.layers import Linear
-from tinytorch.nn import TransformerBlock, Embedding, PositionalEncoding
+from tinytorch.nn import TransformerBlock, Embedding, PositionalEncoding, MultiHeadAttention


 def test_transformer_to_linear_3d_to_2d():
@@ -63,8 +63,8 @@ def test_transformer_to_linear_3d_to_2d():
     transformer = TransformerBlock(
         embed_dim=embed_dim,
         num_heads=num_heads,
-        hidden_dim=embed_dim * 4,
-        dropout=0.1
+        mlp_ratio=4,
+        dropout_prob=0.1
     )
     output_proj = Linear(embed_dim, vocab_size)

@@ -140,8 +140,8 @@ def test_full_gpt_architecture_shapes():
     assert x.shape == (batch_size, seq_length, embed_dim)
     print(f"After embedding: {x.shape}")

-    # Positional encoding
-    pos_enc = PositionalEncoding(embed_dim, max_seq_length=seq_length)
+    # Positional encoding (max_seq_len, embed_dim)
+    pos_enc = PositionalEncoding(seq_length, embed_dim)
     x = pos_enc(x)
     assert x.shape == (batch_size, seq_length, embed_dim)
     print(f"After positional encoding: {x.shape}")
@@ -151,7 +151,7 @@ def test_full_gpt_architecture_shapes():
     transformer = TransformerBlock(
         embed_dim=embed_dim,
         num_heads=num_heads,
-        hidden_dim=embed_dim * 4
+        mlp_ratio=4
     )
     x = transformer(x)
     assert x.shape == (batch_size, seq_length, embed_dim)
@@ -187,27 +187,25 @@ def test_attention_kv_cache_shapes():
     embed_dim = 128
     num_heads = 4

-    # Multi-head attention with KV cache
+    # Multi-head attention
    mha = MultiHeadAttention(embed_dim, num_heads)

     # Initial forward pass
     x = Tensor(np.random.randn(batch_size, seq_length, embed_dim))

-    # Without cache
-    output = mha(x, x, x)
+    # Self-attention (Q, K, V all derived from x)
+    output = mha(x)
     assert output.shape == (batch_size, seq_length, embed_dim)
-    print(f"MHA output (no cache): {output.shape}")
+    print(f"MHA output: {output.shape}")

-    # With cache (for autoregressive generation)
-    # Process one token at a time
+    # Process one token at a time (for autoregressive generation)
     for t in range(seq_length):
         x_t = x[:, t:t+1, :]  # Single token
-        output_t = mha(x_t, x_t, x_t)
+        output_t = mha(x_t)
         assert output_t.shape == (batch_size, 1, embed_dim)
         print(f" Token {t} output: {output_t.shape}")

-    print("✅ KV cache shape handling works correctly!")
-    return True
+    print("✅ Attention shape handling works correctly!")


 def test_embedding_dimension_compatibility():