mirror of
https://github.com/MLSysBook/TinyTorch.git
synced 2026-05-30 09:21:44 -05:00
refactor: update KV cache module path to 15_memoization
Module path updated from 14_kvcaching to 15_memoization to reflect optimization tier restructuring
This commit is contained in:
8
tinytorch/generation/kv_cache.py
generated
8
tinytorch/generation/kv_cache.py
generated
@@ -17,7 +17,7 @@
|
||||
# %% auto 0
|
||||
__all__ = ['KVCache', 'enable_kv_cache', 'disable_kv_cache']
|
||||
|
||||
# %% ../../modules/source/14_kvcaching/kvcaching_dev.ipynb 1
|
||||
# %% ../../modules/source/15_memoization/memoization_dev.ipynb 1
|
||||
import numpy as np
|
||||
import time
|
||||
from typing import Tuple, Optional, Dict, List
|
||||
@@ -25,7 +25,7 @@ from typing import Tuple, Optional, Dict, List
|
||||
# Import TinyTorch components from previous modules
|
||||
from ..core.tensor import Tensor
|
||||
|
||||
# %% ../../modules/source/14_kvcaching/kvcaching_dev.ipynb 5
|
||||
# %% ../../modules/source/15_memoization/memoization_dev.ipynb 7
|
||||
class KVCache:
|
||||
"""
|
||||
Efficient key-value cache for autoregressive generation.
|
||||
@@ -298,7 +298,7 @@ class KVCache:
|
||||
'total_elements': total_elements
|
||||
}
|
||||
|
||||
# %% ../../modules/source/14_kvcaching/kvcaching_dev.ipynb 9
|
||||
# %% ../../modules/source/15_memoization/memoization_dev.ipynb 11
|
||||
def enable_kv_cache(batch_size: int, max_seq_len: int, num_layers: int,
|
||||
num_heads: int, head_dim: int) -> KVCache:
|
||||
"""
|
||||
@@ -351,7 +351,7 @@ def enable_kv_cache(batch_size: int, max_seq_len: int, num_layers: int,
|
||||
|
||||
return cache
|
||||
|
||||
# %% ../../modules/source/14_kvcaching/kvcaching_dev.ipynb 14
|
||||
# %% ../../modules/source/15_memoization/memoization_dev.ipynb 16
|
||||
def enable_kv_cache(model):
|
||||
"""
|
||||
Enable KV caching for a transformer model WITHOUT modifying Module 12/13 code.
|
||||
|
||||
Reference in New Issue
Block a user