mirror of
https://github.com/MLSysBook/TinyTorch.git
synced 2026-03-12 00:03:35 -05:00
✨ Add Shakespeare dataset to DatasetManager
- Add get_shakespeare() method to download tiny-shakespeare.txt
- Downloads from Karpathy's char-rnn repository (1MB corpus)
- Returns raw text for character-level language modeling
- Follows the same pattern as the MNIST/CIFAR-10 downloads
- Includes a test in the main() function
This commit is contained in:
@@ -133,6 +133,29 @@ class DatasetManager:
|
||||
print(f"📊 CIFAR-10 loaded: {len(train_data)} training, {len(test_data)} test images")
|
||||
return (train_data, train_labels), (test_data, test_labels)
|
||||
|
||||
def get_shakespeare(self):
    """Download (if needed) and load the tiny-shakespeare text corpus.

    The corpus is cached at ``<data_dir>/shakespeare/tiny-shakespeare.txt``
    and is fetched through ``self.download_with_progress`` only when that
    cached file is missing.

    Returns:
        str: The whole corpus as a single string, decoded as UTF-8.

    Raises:
        Any exception raised by ``self.download_with_progress``. When the
        download fails, no cache file is left behind, so the next call
        retries the download instead of loading a truncated file.
    """
    shakespeare_dir = self.data_dir / "shakespeare"
    # parents=True so a fresh data_dir that does not exist yet is created too.
    shakespeare_dir.mkdir(parents=True, exist_ok=True)

    # Shakespeare text file
    text_file = shakespeare_dir / "tiny-shakespeare.txt"

    if not text_file.exists():
        # Download from Karpathy's char-rnn repo
        url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
        print("📥 Downloading tiny-shakespeare.txt...")

        # Download to a sibling ".part" file and promote it only on success.
        # Writing straight to text_file would let an interrupted download
        # pass the exists() check above on every later call.
        partial_file = text_file.with_name(text_file.name + ".part")
        try:
            self.download_with_progress(url, partial_file)
            partial_file.replace(text_file)
        finally:
            # After a successful replace() the partial file is already gone;
            # after a failure this removes the incomplete download.
            if partial_file.exists():
                partial_file.unlink()

    # Load text
    with open(text_file, 'r', encoding='utf-8') as f:
        text = f.read()

    print(f"📊 Shakespeare loaded: {len(text):,} characters, {len(text.split()):,} words")
    print(f"   First 100 chars: {text[:100]!r}...")

    return text
|
||||
|
||||
def get_xor_data(self, num_samples=1000):
|
||||
"""Generate XOR problem data for non-linear milestone."""
|
||||
print("🧮 Generating XOR problem data...")
|
||||
@@ -217,6 +240,13 @@ def main():
|
||||
except Exception as e:
|
||||
print(f" CIFAR-10 download failed: {e}")
|
||||
|
||||
print("\n5. Testing Shakespeare Text:")
|
||||
try:
|
||||
text = manager.get_shakespeare()
|
||||
print(f" Length: {len(text):,} characters")
|
||||
except Exception as e:
|
||||
print(f" Shakespeare download failed: {e}")
|
||||
|
||||
print("\n✅ Dataset Manager test complete!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user