diff --git a/tinytorch/generation/kv_cache.py b/tinytorch/generation/kv_cache.py index 1cbc93cf..55d8504b 100644 --- a/tinytorch/generation/kv_cache.py +++ b/tinytorch/generation/kv_cache.py @@ -17,7 +17,7 @@ # %% auto 0 __all__ = ['KVCache', 'enable_kv_cache', 'disable_kv_cache'] -# %% ../../modules/source/14_kvcaching/kvcaching_dev.ipynb 1 +# %% ../../modules/source/15_memoization/memoization_dev.ipynb 1 import numpy as np import time from typing import Tuple, Optional, Dict, List @@ -25,7 +25,7 @@ from typing import Tuple, Optional, Dict, List # Import TinyTorch components from previous modules from ..core.tensor import Tensor -# %% ../../modules/source/14_kvcaching/kvcaching_dev.ipynb 5 +# %% ../../modules/source/15_memoization/memoization_dev.ipynb 7 class KVCache: """ Efficient key-value cache for autoregressive generation. @@ -298,7 +298,7 @@ class KVCache: 'total_elements': total_elements } -# %% ../../modules/source/14_kvcaching/kvcaching_dev.ipynb 9 +# %% ../../modules/source/15_memoization/memoization_dev.ipynb 11 def enable_kv_cache(batch_size: int, max_seq_len: int, num_layers: int, num_heads: int, head_dim: int) -> KVCache: """ @@ -351,7 +351,7 @@ def enable_kv_cache(batch_size: int, max_seq_len: int, num_layers: int, return cache -# %% ../../modules/source/14_kvcaching/kvcaching_dev.ipynb 14 +# %% ../../modules/source/15_memoization/memoization_dev.ipynb 16 def enable_kv_cache(model): """ Enable KV caching for a transformer model WITHOUT modifying Module 12/13 code.