mirror of
https://github.com/MLSysBook/TinyTorch.git
synced 2026-04-29 20:38:58 -05:00
fix(package): Add PyTorch-style __call__ methods to exported modules
Resolved transformer training issues by adding __call__ methods to: - Embedding, PositionalEncoding, EmbeddingLayer (text.embeddings) - LayerNorm, MLP, TransformerBlock, GPT (models.transformer) - MultiHeadAttention (core.attention) This enables PyTorch-style syntax: model(x) instead of model.forward(x) All transformer diagnostic tests now pass (5/5 ✓)(https://claude.ai/code)
This commit is contained in:
4
tinytorch/core/attention.py
generated
4
tinytorch/core/attention.py
generated
@@ -280,6 +280,10 @@ class MultiHeadAttention:
|
||||
return output
|
||||
### END SOLUTION
|
||||
|
||||
def __call__(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
|
||||
"""Allows the attention layer to be called like a function."""
|
||||
return self.forward(x, mask)
|
||||
|
||||
def parameters(self) -> List[Tensor]:
|
||||
"""
|
||||
Return all trainable parameters.
|
||||
|
||||
Reference in New Issue
Block a user