diff --git a/tinytorch/milestones/05_2017_transformer/01_vaswani_attention.py b/tinytorch/milestones/05_2017_transformer/01_vaswani_attention.py
index cbe58637d..1041b69fc 100755
--- a/tinytorch/milestones/05_2017_transformer/01_vaswani_attention.py
+++ b/tinytorch/milestones/05_2017_transformer/01_vaswani_attention.py
@@ -544,6 +544,20 @@ def main():
     results['copying'] = (passed2, acc2)
 
     # Challenge 3: Mixed Tasks (the real test)
+    # Reset model and optimizer so it learns both tasks from scratch
+    # with prefix conditioning (otherwise it's stuck in "copy mode")
+    console.print("\n[dim]Reinitializing model for mixed task learning...[/dim]")
+    model = AttentionTransformer(
+        vocab_size=vocab_size,
+        embed_dim=embed_dim,
+        num_heads=num_heads,
+        seq_len=seq_len + 1,
+        num_layers=num_layers
+    )
+    for param in model.parameters():
+        param.requires_grad = True
+    optimizer = Adam(model.parameters(), lr=lr)
+
     train_mixed = generate_mixed_data(800, seq_len)
     test_mixed = generate_mixed_data(300, seq_len)
     passed3, acc3 = run_challenge(