mirror of https://github.com/MLSysBook/TinyTorch.git
synced 2026-04-29 23:40:02 -05:00
feat: Complete transformer integration with milestones
- Add tokenization module (tinytorch/text/tokenization.py)
- Update Milestone 05 transformer demos (validation, TinyCoder, Shakespeare)
- Update book chapters with milestones overview
- Update README and integration plan
- Sync module notebooks and metadata
tinytorch/text/tokenization.py (generated, new file, +465)
@@ -0,0 +1,465 @@
# ╔═══════════════════════════════════════════════════════════════════════════════╗
# ║ 🚨 CRITICAL WARNING 🚨 ║
# ║ AUTOGENERATED! DO NOT EDIT! ║
# ║ ║
# ║ This file is AUTOMATICALLY GENERATED from source modules. ║
# ║ ANY CHANGES MADE HERE WILL BE LOST when modules are re-exported! ║
# ║ ║
# ║ ✅ TO EDIT: modules/source/XX_tokenization/tokenization_dev.py ║
# ║ ✅ TO EXPORT: Run 'tito module complete <module_name>' ║
# ║ ║
# ║ 🛡️ STUDENT PROTECTION: This file contains optimized implementations. ║
# ║ Editing it directly may break module functionality and training. ║
# ║ ║
# ║ 🎓 LEARNING TIP: Work in modules/source/ - that's where real development ║
# ║ happens! The tinytorch/ directory is just the compiled output. ║
# ╚═══════════════════════════════════════════════════════════════════════════════╝

# %% auto 0
__all__ = ['Tokenizer', 'CharTokenizer', 'BPETokenizer']

# %% ../../modules/source/10_tokenization/tokenization_dev.ipynb 0
#| default_exp text.tokenization
#| export
from typing import List, Optional, Set, Tuple
from collections import Counter

# %% ../../modules/source/10_tokenization/tokenization_dev.ipynb 8
class Tokenizer:
    """
    Base tokenizer class providing the interface for all tokenizers.

    This defines the contract that all tokenizers must follow:
    - encode(): text → list of token IDs
    - decode(): list of token IDs → text
    """

    def encode(self, text: str) -> List[int]:
        """
        Convert text to a list of token IDs.

        TODO: Implement encoding logic in subclasses

        APPROACH:
        1. Subclasses will override this method
        2. Return list of integer token IDs

        EXAMPLE:
        >>> tokenizer = CharTokenizer(['a', 'b', 'c'])
        >>> tokenizer.encode("abc")
        [1, 2, 3]
        """
        ### BEGIN SOLUTION
        raise NotImplementedError("Subclasses must implement encode()")
        ### END SOLUTION

    def decode(self, tokens: List[int]) -> str:
        """
        Convert list of token IDs back to text.

        TODO: Implement decoding logic in subclasses

        APPROACH:
        1. Subclasses will override this method
        2. Return reconstructed text string

        EXAMPLE:
        >>> tokenizer = CharTokenizer(['a', 'b', 'c'])
        >>> tokenizer.decode([1, 2, 3])
        "abc"
        """
        ### BEGIN SOLUTION
        raise NotImplementedError("Subclasses must implement decode()")
        ### END SOLUTION

# %% ../../modules/source/10_tokenization/tokenization_dev.ipynb 11
class CharTokenizer(Tokenizer):
    """
    Character-level tokenizer that treats each character as a separate token.

    This is the simplest tokenization approach - every character in the
    vocabulary gets its own unique ID.
    """

    def __init__(self, vocab: Optional[List[str]] = None):
        """
        Initialize character tokenizer.

        TODO: Set up vocabulary mappings

        APPROACH:
        1. Store vocabulary list
        2. Create char→id and id→char mappings
        3. Handle special tokens (unknown character)

        EXAMPLE:
        >>> tokenizer = CharTokenizer(['a', 'b', 'c'])
        >>> tokenizer.vocab_size
        4  # 3 chars + 1 unknown token
        """
        ### BEGIN SOLUTION
        if vocab is None:
            vocab = []

        # Add special unknown token
        self.vocab = ['<UNK>'] + vocab
        self.vocab_size = len(self.vocab)

        # Create bidirectional mappings
        self.char_to_id = {char: idx for idx, char in enumerate(self.vocab)}
        self.id_to_char = {idx: char for idx, char in enumerate(self.vocab)}

        # Store unknown token ID
        self.unk_id = 0
        ### END SOLUTION

    def build_vocab(self, corpus: List[str]) -> None:
        """
        Build vocabulary from a corpus of text.

        TODO: Extract unique characters and build vocabulary

        APPROACH:
        1. Collect all unique characters from corpus
        2. Sort for consistent ordering
        3. Rebuild mappings with new vocabulary

        HINTS:
        - Use set() to find unique characters
        - Join all texts then convert to set
        - Don't forget the <UNK> token
        """
        ### BEGIN SOLUTION
        # Collect all unique characters
        all_chars = set()
        for text in corpus:
            all_chars.update(text)

        # Sort for consistent ordering
        unique_chars = sorted(list(all_chars))

        # Rebuild vocabulary with <UNK> token first
        self.vocab = ['<UNK>'] + unique_chars
        self.vocab_size = len(self.vocab)

        # Rebuild mappings
        self.char_to_id = {char: idx for idx, char in enumerate(self.vocab)}
        self.id_to_char = {idx: char for idx, char in enumerate(self.vocab)}
        ### END SOLUTION

    def encode(self, text: str) -> List[int]:
        """
        Encode text to list of character IDs.

        TODO: Convert each character to its vocabulary ID

        APPROACH:
        1. Iterate through each character in text
        2. Look up character ID in vocabulary
        3. Use unknown token ID for unseen characters

        EXAMPLE:
        >>> tokenizer = CharTokenizer(['h', 'e', 'l', 'o'])
        >>> tokenizer.encode("hello")
        [1, 2, 3, 3, 4]  # maps to h,e,l,l,o
        """
        ### BEGIN SOLUTION
        tokens = []
        for char in text:
            tokens.append(self.char_to_id.get(char, self.unk_id))
        return tokens
        ### END SOLUTION

    def decode(self, tokens: List[int]) -> str:
        """
        Decode list of token IDs back to text.

        TODO: Convert each token ID back to its character

        APPROACH:
        1. Look up each token ID in vocabulary
        2. Join characters into string
        3. Handle invalid token IDs gracefully

        EXAMPLE:
        >>> tokenizer = CharTokenizer(['h', 'e', 'l', 'o'])
        >>> tokenizer.decode([1, 2, 3, 3, 4])
        "hello"
        """
        ### BEGIN SOLUTION
        chars = []
        for token_id in tokens:
            # Use unknown token for invalid IDs
            char = self.id_to_char.get(token_id, '<UNK>')
            chars.append(char)
        return ''.join(chars)
        ### END SOLUTION

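# Usage sketch (illustrative only, assuming the CharTokenizer defined above):
#
#     tokenizer = CharTokenizer()
#     tokenizer.build_vocab(["hello world"])  # vocab = ['<UNK>'] + sorted unique chars
#     tokenizer.encode("hello")               # -> [4, 3, 5, 5, 6] for this corpus
#     tokenizer.decode([4, 3, 5, 5, 6])       # -> "hello"
#     tokenizer.encode("z")                   # -> [0]; unseen characters map to <UNK>
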
# %% ../../modules/source/10_tokenization/tokenization_dev.ipynb 15
class BPETokenizer(Tokenizer):
    """
    Byte Pair Encoding (BPE) tokenizer that learns subword units.

    BPE works by:
    1. Starting with character-level vocabulary
    2. Finding most frequent character pairs
    3. Merging frequent pairs into single tokens
    4. Repeating until desired vocabulary size
    """

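    # Worked example (illustrative): on the tiny corpus ["low", "low", "lower"],
    # training starts from ['l', 'o', 'w</w>'] and ['l', 'o', 'w', 'e', 'r</w>'].
    # The pair ('l', 'o') occurs 3 times, so it is merged first into 'lo';
    # ('lo', 'w</w>') then occurs twice and becomes 'low</w>'. Merging continues
    # until the vocabulary reaches vocab_size or no adjacent pairs remain.
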
    def __init__(self, vocab_size: int = 1000):
        """
        Initialize BPE tokenizer.

        TODO: Set up basic tokenizer state

        APPROACH:
        1. Store target vocabulary size
        2. Initialize empty vocabulary and merge rules
        3. Set up mappings for encoding/decoding
        """
        ### BEGIN SOLUTION
        self.vocab_size = vocab_size
        self.vocab = []
        self.merges = []  # List of learned merge pairs, in order
        self.token_to_id = {}
        self.id_to_token = {}
        ### END SOLUTION

    def _get_word_tokens(self, word: str) -> List[str]:
        """
        Convert word to list of characters with end-of-word marker.

        TODO: Tokenize word into character sequence

        APPROACH:
        1. Split word into characters
        2. Add </w> marker to last character
        3. Return list of tokens

        EXAMPLE:
        >>> tokenizer._get_word_tokens("hello")
        ['h', 'e', 'l', 'l', 'o</w>']
        """
        ### BEGIN SOLUTION
        if not word:
            return []

        tokens = list(word)
        tokens[-1] += '</w>'  # Mark end of word
        return tokens
        ### END SOLUTION

    def _get_pairs(self, word_tokens: List[str]) -> Set[Tuple[str, str]]:
        """
        Get all adjacent pairs from word tokens.

        TODO: Extract all consecutive character pairs

        APPROACH:
        1. Iterate through adjacent tokens
        2. Create pairs of consecutive tokens
        3. Return set of unique pairs

        EXAMPLE:
        >>> tokenizer._get_pairs(['h', 'e', 'l', 'l', 'o</w>'])
        {('h', 'e'), ('e', 'l'), ('l', 'l'), ('l', 'o</w>')}
        """
        ### BEGIN SOLUTION
        pairs = set()
        for i in range(len(word_tokens) - 1):
            pairs.add((word_tokens[i], word_tokens[i + 1]))
        return pairs
        ### END SOLUTION

    def train(self, corpus: List[str], vocab_size: Optional[int] = None) -> None:
        """
        Train BPE on corpus to learn merge rules.

        TODO: Implement BPE training algorithm

        APPROACH:
        1. Build initial character vocabulary
        2. Count word frequencies in corpus
        3. Iteratively merge most frequent pairs
        4. Build final vocabulary and mappings

        HINTS:
        - Start with character-level tokens
        - Use frequency counts to guide merging
        - Stop when vocabulary reaches target size
        """
        ### BEGIN SOLUTION
        if vocab_size:
            self.vocab_size = vocab_size

        # Count word frequencies
        word_freq = Counter(corpus)

        # Initialize vocabulary with characters
        vocab = set()
        word_tokens = {}

        for word in word_freq:
            tokens = self._get_word_tokens(word)
            word_tokens[word] = tokens
            vocab.update(tokens)

        # Convert to sorted list for consistency
        self.vocab = sorted(list(vocab))

        # Add special tokens
        if '<UNK>' not in self.vocab:
            self.vocab = ['<UNK>'] + self.vocab

        # Learn merges
        self.merges = []

        while len(self.vocab) < self.vocab_size:
            # Count all pairs across all words
            pair_counts = Counter()

            for word, freq in word_freq.items():
                tokens = word_tokens[word]
                pairs = self._get_pairs(tokens)
                for pair in pairs:
                    pair_counts[pair] += freq

            if not pair_counts:
                break

            # Get most frequent pair
            best_pair = pair_counts.most_common(1)[0][0]

            # Merge this pair in all words
            for word in word_tokens:
                tokens = word_tokens[word]
                new_tokens = []
                i = 0
                while i < len(tokens):
                    if (i < len(tokens) - 1 and
                            tokens[i] == best_pair[0] and
                            tokens[i + 1] == best_pair[1]):
                        # Merge pair
                        new_tokens.append(best_pair[0] + best_pair[1])
                        i += 2
                    else:
                        new_tokens.append(tokens[i])
                        i += 1
                word_tokens[word] = new_tokens

            # Add merged token to vocabulary
            merged_token = best_pair[0] + best_pair[1]
            self.vocab.append(merged_token)
            self.merges.append(best_pair)

        # Build final mappings
        self._build_mappings()
        ### END SOLUTION

    def _build_mappings(self):
        """Build token-to-ID and ID-to-token mappings."""
        ### BEGIN SOLUTION
        self.token_to_id = {token: idx for idx, token in enumerate(self.vocab)}
        self.id_to_token = {idx: token for idx, token in enumerate(self.vocab)}
        ### END SOLUTION

    def _apply_merges(self, tokens: List[str]) -> List[str]:
        """
        Apply learned merge rules to token sequence.

        TODO: Apply BPE merges to token list

        APPROACH:
        1. Start with character-level tokens
        2. Apply each merge rule in order
        3. Continue until no more merges possible
        """
        ### BEGIN SOLUTION
        if not self.merges:
            return tokens

        for merge_pair in self.merges:
            new_tokens = []
            i = 0
            while i < len(tokens):
                if (i < len(tokens) - 1 and
                        tokens[i] == merge_pair[0] and
                        tokens[i + 1] == merge_pair[1]):
                    # Apply merge
                    new_tokens.append(merge_pair[0] + merge_pair[1])
                    i += 2
                else:
                    new_tokens.append(tokens[i])
                    i += 1
            tokens = new_tokens

        return tokens
        ### END SOLUTION

    def encode(self, text: str) -> List[int]:
        """
        Encode text using BPE.

        TODO: Apply BPE encoding to text

        APPROACH:
        1. Split text into words
        2. Convert each word to character tokens
        3. Apply BPE merges
        4. Convert to token IDs
        """
        ### BEGIN SOLUTION
        if not self.vocab:
            return []

        # Simple word splitting (could be more sophisticated)
        words = text.split()
        all_tokens = []

        for word in words:
            # Get character-level tokens
            word_tokens = self._get_word_tokens(word)

            # Apply BPE merges
            merged_tokens = self._apply_merges(word_tokens)

            all_tokens.extend(merged_tokens)

        # Convert to IDs
        token_ids = []
        for token in all_tokens:
            token_ids.append(self.token_to_id.get(token, 0))  # 0 = <UNK>

        return token_ids
        ### END SOLUTION

    def decode(self, tokens: List[int]) -> str:
        """
        Decode token IDs back to text.

        TODO: Convert token IDs back to readable text

        APPROACH:
        1. Convert IDs to tokens
        2. Join tokens together
        3. Clean up word boundaries and markers
        """
        ### BEGIN SOLUTION
        if not self.id_to_token:
            return ""

        # Convert IDs to tokens
        token_strings = []
        for token_id in tokens:
            token = self.id_to_token.get(token_id, '<UNK>')
            token_strings.append(token)

        # Join and clean up
        text = ''.join(token_strings)

        # Replace end-of-word markers with spaces
        text = text.replace('</w>', ' ')

        # Clean up extra spaces
        text = ' '.join(text.split())

        return text
        ### END SOLUTION

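# Usage sketch (illustrative only, assuming the classes defined above). Token IDs
# depend on the learned merge order, so only the round-tripped text is asserted.
if __name__ == "__main__":
    corpus = ["low", "low", "lower", "newest", "widest"]
    bpe = BPETokenizer(vocab_size=50)
    bpe.train(corpus)                    # learns merges such as ('l', 'o') -> 'lo'
    ids = bpe.encode("lowest")           # subword IDs; unknown pieces map to 0 (<UNK>)
    assert bpe.decode(ids) == "lowest"   # '</w>' markers restore the word boundary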