Mirror of https://github.com/MLSysBook/TinyTorch.git (synced 2025-12-05 19:17:52 -06:00)

Compare commits: 277033d2f9 ... 244ac44a3e (15 commits)
Commits in this comparison:

- 244ac44a3e
- 3a297fdf64
- 033fa148ff
- 7d3e302672
- c3e37567de
- a0468aad18
- bd5622c129
- ccf8a0e797
- 3f1021d448
- 05e29f36f7
- c31d49d045
- ee7276c97e
- c3dfa51fb4
- e1fa4d7f73
- b3e87f9cca
README.md (24 changed lines)
@@ -518,15 +518,31 @@ We welcome contributions! See [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines.
MIT License - see [LICENSE](LICENSE) for details.

## Related Projects

## Related Projects & Name Disambiguation

We acknowledge several excellent educational ML framework projects with similar names:

**Note**: "TinyTorch" is a popular name for educational ML frameworks. This project (MLSysBook/TinyTorch) is the Harvard University course focused on ML systems engineering, part of the [ML Systems Book](https://mlsysbook.ai/tinytorch) ecosystem. We acknowledge and respect other excellent projects with similar names:

### Educational ML Frameworks:

- [tinygrad](https://github.com/tinygrad/tinygrad) - George Hotz's minimalist deep learning framework
- [micrograd](https://github.com/karpathy/micrograd) - Andrej Karpathy's tiny autograd engine
- [MiniTorch](https://minitorch.github.io/) - Cornell's educational framework
- Other TinyTorch implementations - Various educational implementations on GitHub

**Our TinyTorch** focuses specifically on ML systems engineering with a complete curriculum, NBGrader integration, and production deployment—designed as a comprehensive university course rather than a standalone library.

### Other TinyTorch Implementations:

- [msarmi9/tinytorch](https://github.com/msarmi9/tinytorch) - Numpy-based deep learning library
- [keith2018/TinyTorch](https://github.com/keith2018/TinyTorch) - C++ implementation following PyTorch API
- [darglein/TinyTorch](https://github.com/darglein/TinyTorch) - Auto-diff optimization framework
- [joey00072/Tinytorch](https://github.com/joey00072/Tinytorch) - Tiny autograd engine
- [aspfohl/tinytorch](https://github.com/aspfohl/tinytorch) - Pure-python PyTorch implementation
- Several other educational implementations on GitHub

**Our TinyTorch** distinguishes itself through:

- Complete 20-module curriculum (Tensor → Transformers → Optimization → Capstone)
- NBGrader integration for classroom deployment
- ML systems engineering focus (memory, performance, production deployment)
- Part of the [ML Systems Book](https://mlsysbook.ai/tinytorch) educational ecosystem
- Designed as a comprehensive university course, not a standalone library

All these projects share the noble goal of making ML internals accessible through education. We're grateful to be part of this community.

## Acknowledgments
@@ -204,7 +204,7 @@
Harvard University\\[1.5em]

\fontsize{10}{12}\selectfont
\fontsize{12}{12}\selectfont

\textcolor{gray!70}{\href{https://www.tinytorch.ai}{tinytorch.ai}}
\textcolor{gray!70}{\href{https://mlsysbook.ai/tinytorch}{mlsysbook.ai/tinytorch}}
}

\date{}

@@ -217,7 +217,7 @@

% Abstract - REVISED: Curriculum design focus
\begin{abstract}
Machine learning systems engineering requires understanding framework internals: why optimizers consume memory, when computational complexity becomes prohibitive, how to navigate accuracy-latency-memory tradeoffs. Yet current ML education separates algorithms from systems—students learn gradient descent without measuring memory, attention mechanisms without profiling costs, training without understanding optimizer overhead. This divide leaves graduates unable to debug production failures or make informed engineering decisions. We present TinyTorch, a build-from-scratch curriculum where students implement PyTorch's core components (tensors, autograd, optimizers, neural networks) to gain framework transparency. Three pedagogical patterns address the gap: \textbf{progressive disclosure} gradually reveals complexity (gradient features exist from Module 01, activate in Module 05); \textbf{systems-first curriculum} embeds memory profiling from the start; \textbf{historical milestone validation} recreates nearly 70 years of ML breakthroughs (1958--2025) using exclusively student-implemented code. These patterns are grounded in learning theory (situated cognition, cognitive load theory) but represent testable hypotheses requiring empirical validation. The 20-module curriculum (60--80 hours) provides complete open-source infrastructure at \texttt{tinytorch.ai}.
Machine learning systems engineering requires understanding framework internals: why optimizers consume memory, when computational complexity becomes prohibitive, how to navigate accuracy-latency-memory tradeoffs. Yet current ML education separates algorithms from systems—students learn gradient descent without measuring memory, attention mechanisms without profiling costs, training without understanding optimizer overhead. This divide leaves graduates unable to debug production failures or make informed engineering decisions. We present TinyTorch, a build-from-scratch curriculum where students implement PyTorch's core components (tensors, autograd, optimizers, neural networks) to gain framework transparency. Three pedagogical patterns address the gap: \textbf{progressive disclosure} gradually reveals complexity (gradient features exist from Module 01, activate in Module 05); \textbf{systems-first curriculum} embeds memory profiling from the start; \textbf{historical milestone validation} recreates nearly 70 years of ML breakthroughs (1958--2025) using exclusively student-implemented code. These patterns are grounded in learning theory (situated cognition, cognitive load theory) but represent testable hypotheses requiring empirical validation. The 20-module curriculum (60--80 hours) provides complete open-source infrastructure at \texttt{mlsysbook.ai/tinytorch}.
\end{abstract}

@@ -505,6 +505,9 @@ Empirical validation of learning outcomes remains future work (\Cref{sec:discuss

This section presents the 20-module curriculum structure, organized into four tiers that progressively build a complete ML framework. Each module enforces prerequisite mastery, ensuring students build on solid foundations.

\noindent\textbf{Integration with ML Systems Curriculum.}
TinyTorch serves as the hands-on implementation companion to the \emph{Machine Learning Systems} textbook~\citep{mlsysbook2025} (\texttt{mlsysbook.ai}), creating synergy between theoretical foundations and systems engineering practice. While the textbook covers the full ML lifecycle—data engineering, training architectures, deployment monitoring, robust operations, and sustainable AI—TinyTorch provides the complementary experience of building core infrastructure from first principles. This integration enables a complete educational pathway: students study production ML systems architecture in the textbook (Chapter 4: distributed training patterns, Chapter 7: quantization strategies), then implement those same abstractions in TinyTorch (Module 05: autograd for backpropagation, Module 15: INT8 quantization). The two resources address different aspects of the same educational gap: understanding both \emph{how production systems work} (textbook's systems architecture perspective) and \emph{how to build them yourself} (TinyTorch's implementation depth). This theory-practice coupling mirrors industry workflows where ML systems engineers both understand architectural patterns and implement infrastructure components.

\subsection{Prerequisites}

As established in \Cref{sec:intro}, TinyTorch targets students transitioning from framework users to framework engineers. The curriculum assumes intermediate Python proficiency (comfort with classes, functions, and NumPy array operations) alongside mathematical foundations in linear algebra (matrix multiplication, vectors) and basic calculus (derivatives, chain rule). Students should understand complexity analysis (Big-O notation) and basic algorithms. While prior ML coursework (traditional machine learning or deep learning courses) and data structures courses are helpful, they are not strictly required; motivated students can acquire these foundations concurrently.

@@ -596,7 +599,7 @@ The tier then branches into two paths. \textbf{Vision} implements Conv2d with se

Students transition from ``models that train'' to ``systems that deploy.'' Profiling (14) teaches measuring time, memory, and FLOPs (floating-point operations), introducing Amdahl's Law: optimizing 70\% of runtime by 2$\times$ yields only 1.53$\times$ overall speedup because the remaining 30\% becomes the new bottleneck. This teaches that optimization is iterative and measurement-driven. Quantization (15) achieves 4$\times$ compression (FP32$\rightarrow$INT8) with 1--2\% accuracy cost. Compression (16) applies pruning and distillation for 10$\times$ shrinkage. Memoization (17) implements KV caching (storing attention keys and values to avoid recomputation), a technique used in production LLM serving: students discover that naive autoregressive generation recomputes attention keys and values at every step; generating 100 tokens requires 5,050 redundant computations (1+2+...+100). By caching these values and reusing them, students transform $O(n^2)$ generation into $O(n)$, achieving 10--100$\times$ speedup and understanding why this optimization is essential in systems like ChatGPT and Claude for economically viable inference. Acceleration (18) vectorizes convolution for 10--100$\times$ gains. Benchmarking (19) teaches rigorous performance measurement.

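A quick check of the two figures quoted above, using Amdahl's Law and the arithmetic series for autoregressive recomputation:

\[
\text{speedup} = \frac{1}{(1-0.7) + \frac{0.7}{2}} = \frac{1}{0.65} \approx 1.53\times,
\qquad
\sum_{t=1}^{100} t = \frac{100 \cdot 101}{2} = 5{,}050 \text{ recomputed key/value positions.}
\]
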
\textbf{Capstone (Module 20).}
The capstone integrates all 19 modules into production-optimized systems with professional submission infrastructure. Inspired by MLPerf~\citep{reddi2020mlperf}, students optimize prior milestones (CIFAR-10 CNN, transformer generation, or custom architecture) for 10$\times$ faster inference, 4$\times$ smaller size, and sub-100ms latency while maintaining accuracy. Module 20 provides complete benchmarking and submission infrastructure: \texttt{BenchmarkReport} class for collecting model metrics (parameter count, model size, accuracy, latency), \texttt{generate\_submission()} function producing standardized JSON with schema validation, and \texttt{tito community submit} CLI command for leaderboard submission. Students learn professional ML systems workflow: benchmark baseline model, apply optimizations from Modules 14--19 (quantization, pruning, acceleration), benchmark optimized version, generate submission with improvement metrics (speedup, compression ratio, accuracy delta), and validate against required schema. This teaches data-driven optimization mirroring real ML systems engineering while introducing reproducible benchmarking practices essential for production deployment.
The capstone integrates all 19 modules into production-optimized systems. Inspired by MLPerf~\citep{reddi2020mlperf}, students optimize prior milestones (CIFAR-10 CNN, transformer generation, or custom architecture) for 10$\times$ faster inference, 4$\times$ smaller size, and sub-100ms latency while maintaining accuracy. Students learn professional ML systems workflow: benchmark baseline model, apply optimizations from Modules 14--19 (quantization, pruning, acceleration), benchmark optimized version, measure improvement metrics (speedup, compression ratio, accuracy delta), and analyze optimization tradeoffs. This teaches data-driven optimization mirroring real ML systems engineering: profiling to identify bottlenecks, principled tradeoffs between accuracy and efficiency, and reproducible benchmarking practices essential for production deployment.

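The improvement metrics named in both versions of this paragraph (speedup, compression ratio, accuracy delta) reduce to simple ratios of baseline and optimized measurements. A generic sketch of that calculation; the field and function names here are illustrative, not the Module 20 \texttt{BenchmarkReport}/\texttt{generate\_submission} API:

```python
from dataclasses import dataclass

# Illustrative metric record and improvement computation (hypothetical names,
# not TinyTorch's actual capstone API).
@dataclass
class ModelMetrics:
    parameter_count: int
    model_size_mb: float
    accuracy: float
    latency_ms: float

def improvement(baseline: ModelMetrics, optimized: ModelMetrics) -> dict:
    return {
        "speedup": baseline.latency_ms / optimized.latency_ms,
        "compression_ratio": baseline.model_size_mb / optimized.model_size_mb,
        "accuracy_delta": optimized.accuracy - baseline.accuracy,
    }

baseline = ModelMetrics(1_200_000, 4.8, 0.912, 120.0)
optimized = ModelMetrics(1_200_000, 1.2, 0.901, 11.5)
print(improvement(baseline, optimized))
# speedup ≈ 10.4×, compression_ratio = 4.0×, accuracy_delta ≈ -0.011
```
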
\subsection{Module Structure}
\label{subsec:module-pedagogy}

@@ -634,7 +637,7 @@ Second, \textbf{implementation validation beyond unit tests}: Milestones address

\item \textbf{2017 Transformer Era} (after Module 13): Language generation with attention-based architecture. Validates that attention mechanisms, positional embeddings, and autoregressive sampling function correctly through coherent text generation.

\item \textbf{2018 MLPerf Benchmark Era} (after Module 20): Production-optimized system integrating all 20 modules, inspired by MLPerf~\citep{reddi2020mlperf}. Students import from every module: \texttt{from tinytorch.nn import Transformer; from tinytorch.optim import Adam; from tinytorch.profiling import profile\_memory; from tinytorch.capstone import BenchmarkReport, generate\_submission}, demonstrating quantization, compression, and acceleration for 10$\times$ faster inference and 4$\times$ smaller models. Students generate standardized benchmark submissions validated against JSON schema, preparing results for community leaderboard comparison.
\item \textbf{2018 MLPerf Benchmark Era} (after Module 20): Production-optimized system integrating all 20 modules, inspired by MLPerf~\citep{reddi2020mlperf}. Students import from every module (\texttt{tinytorch.nn}, \texttt{tinytorch.optim}, \texttt{tinytorch.profiling}, \texttt{tinytorch.optimization}) demonstrating quantization, compression, and acceleration for 10$\times$ faster inference and 4$\times$ smaller models while maintaining accuracy. This validates that students can apply optimization techniques systematically and measure impact through reproducible benchmarking.
\end{enumerate}

Each milestone: (1) recreates actual breakthroughs using exclusively student code, (2) uses \emph{only} TinyTorch implementations (no PyTorch/TensorFlow), (3) validates success through task-appropriate performance, and (4) demonstrates architectural comparisons showing why new approaches improved over predecessors.

@@ -1086,15 +1089,16 @@ TinyTorch's CPU-only design prioritizes pedagogical transparency, but students b

\noindent\textbf{Energy and Power Profiling.} Edge deployment and sustainable ML~\citep{strubell2019energy,patterson2021carbon} require understanding energy consumption. Future extensions could integrate power profiling tools enabling students to measure energy costs (joules per inference, watt-hours per training epoch) alongside latency and memory. This connects existing optimization techniques (quantization, pruning) taught in Modules 15--18 to concrete sustainability metrics, particularly relevant for edge AI~\citep{banbury2021benchmarking} where battery life constrains deployment.

\noindent\textbf{Hardware Simulation Integration.}
TinyTorch's current profiling infrastructure—memory tracking (tracemalloc), FLOP counting, and performance benchmarking—provides algorithmic-level performance analysis. A natural extension would integrate architecture simulators (e.g., scale-sim~\citep{samajdar2018scale}, timeloop~\citep{parashar2019timeloop}, astra-sim~\citep{kannan2022astrasim}) to connect high-level ML operations with cycle-accurate hardware models. This layered approach mirrors real ML systems engineering: students first understand algorithmic complexity and memory patterns in TinyTorch, then trace those operations down to microarchitectural performance in simulators. Such integration would complete the educational arc from algorithmic implementation $\rightarrow$ systems profiling $\rightarrow$ hardware realization, enabling first-principle analysis of how model architecture, system configuration (memory hierarchy, compute units, interconnects), and hardware substrate jointly determine production performance. Early discussions with students and collaborators suggest strong pedagogical value in this systems-to-hardware pipeline, maintaining TinyTorch's accessibility (no GPU hardware required) while preparing students for hardware-aware optimization through measurement-driven analysis rather than black-box experimentation.

\noindent\textbf{Architecture Extensions.} Potential additions (graph neural networks, diffusion models, reinforcement learning) must justify inclusion through systems pedagogy rather than completeness. The question is not ``Can TinyTorch implement this?'' but rather ``Does implementing this teach fundamental systems concepts unavailable through existing modules?'' Graph convolutions might teach sparse tensor operations; diffusion models might illuminate iterative refinement trade-offs. However, extensions succeed only when maintaining TinyTorch's principle: \textbf{every line of code teaches a systems concept}. Community forks demonstrate this philosophy: quantum ML variants replace tensors with quantum state vectors (teaching circuit depth versus memory); robotics forks emphasize RL simulation overhead and real-time constraints. The curriculum remains intentionally incomplete as a production framework: completeness lies in foundational systems thinking applicable across all ML architectures.

\subsection{Community Adoption and Impact}
\subsection{Community and Sustainability}

TinyTorch serves as the hands-on companion to the Machine Learning Systems textbook, providing practical implementation experience alongside theoretical foundations. Adoption will be measured through multiple channels: (1) \textbf{Educational adoption}: tracking course integrations, student enrollment, and instructor feedback across institutions; (2) \textbf{Capstone community}: inspired by MLPerf benchmarking, the Capstone leaderboard creates competitive systems engineering challenges where students submit optimized implementations competing across accuracy, speed, compression, and efficiency tracks, building community engagement and peer learning; (3) \textbf{Open-source metrics}: GitHub stars, forks, contributions, and community discussions indicating active use beyond formal coursework.
As part of the ML Systems Book ecosystem (\texttt{mlsysbook.ai}), TinyTorch benefits from and contributes to broader educational infrastructure. Integration with the textbook's theoretical foundations~\citep{mlsysbook2025} enables a complete pedagogical pathway: students study production ML systems architecture (distributed training patterns, quantization strategies, deployment considerations), then implement those abstractions in TinyTorch (autograd for backpropagation, INT8 quantization, profiling infrastructure). The open-source model (MIT license) and community-driven development enable collaborative refinement across institutions: instructor discussion forums for pedagogical exchange, shared teaching resources, and empirical validation of learning outcomes.

The submission infrastructure integrates directly into the CLI workflow. Students benchmark their optimized models using Module 20's \texttt{BenchmarkReport} class, generate standardized JSON submissions via \texttt{generate\_submission()}, and submit to the community leaderboard using \texttt{tito community submit submission.json}. The CLI validates submissions against required schema (checking metric ranges, field types, and completeness), displays improvement summary (speedup, compression ratio, accuracy delta), and prepares submissions for leaderboard integration. While leaderboard backend remains under development, the validation infrastructure is production-ready, teaching students professional benchmarking practices: reproducible metrics collection, standardized reporting formats, and schema-driven data validation. Students can also join the global community via \texttt{tito community join} (GitHub-authenticated profiles), view the leaderboard via \texttt{tito community leaderboard} (opens browser), and participate in optimization challenges via \texttt{tito community compete}.

This multi-faceted approach recognizes that educational impact extends beyond traditional classroom metrics to include community building, peer learning, and long-term skill development. The Capstone platform particularly enables students to see how their implementations compare globally, fostering systems thinking through competitive optimization while maintaining educational focus on understanding internals rather than achieving state-of-the-art performance.
Module 20 (Capstone) culminates the curriculum with competitive systems engineering challenges. Inspired by MLPerf benchmarking~\citep{reddi2020mlperf}, students optimize their implementations across accuracy, speed, compression, and efficiency dimensions, comparing results globally through standardized benchmarking infrastructure. This competitive element reinforces systems thinking: optimization requires measurement-driven decisions (profiling bottlenecks), principled tradeoffs (accuracy versus compression), and reproducible methodology (standardized metrics collection). The focus remains pedagogical—understanding \emph{why} optimizations work—rather than achieving state-of-the-art performance, but the competitive framing increases engagement and mirrors real ML engineering workflows.

\section{Conclusion}
\label{sec:conclusion}

@@ -1109,7 +1113,7 @@ Three pedagogical contributions enable this transformation. \textbf{Progressive

\textbf{For educators and bootcamp instructors}: TinyTorch supports flexible integration: self-paced learning requiring zero infrastructure (students run locally on laptops), institutional courses with automated NBGrader assessment, or industry team onboarding for ML engineers transitioning from application development to systems work. The modular structure enables selective adoption: foundation tier only (Modules 01--07, teaching core concepts), architecture focus (adding CNNs and Transformers through Module 13), or complete systems coverage (all 20 modules including optimization and deployment). No GPU access required, no cloud credits needed, no infrastructure barriers.

The complete codebase, curriculum materials, and assessment infrastructure are openly available at \texttt{tinytorch.ai} under permissive open-source licensing. We invite the global ML education community to adopt TinyTorch in courses, contribute curriculum improvements, translate materials for international accessibility, fork for domain-specific variants (quantum ML, robotics, edge AI), and empirically evaluate whether implementation-based pedagogy achieves its promise. The difference between engineers who know \emph{what} ML systems do and engineers who understand \emph{why} they work begins with understanding what's inside \texttt{loss.backward()}, and TinyTorch makes that understanding accessible to everyone.
The complete codebase, curriculum materials, and assessment infrastructure are openly available at \texttt{mlsysbook.ai/tinytorch} (or \texttt{tinytorch.ai}) under permissive open-source licensing. We invite the global ML education community to adopt TinyTorch in courses, contribute curriculum improvements, translate materials for international accessibility, fork for domain-specific variants (quantum ML, robotics, edge AI), and empirically evaluate whether implementation-based pedagogy achieves its promise. The difference between engineers who know \emph{what} ML systems do and engineers who understand \emph{why} they work begins with understanding what's inside \texttt{loss.backward()}, and TinyTorch makes that understanding accessible to everyone.

\section*{Acknowledgments}

@@ -639,4 +639,36 @@
  school = {Aalto University},
  type = {Doctoral dissertation},
  address = {Espoo, Finland},
}

@book{mlsysbook2025,
  author = {Reddi, Vijay Janapa},
  title = {Machine Learning Systems: Design and Implementation},
  year = {2025},
  publisher = {MIT Press},
  note = {Forthcoming. Early access at \url{https://mlsysbook.ai}},
  url = {https://mlsysbook.ai}
}

@article{samajdar2018scale,
  title = {SCALE-Sim: Systolic CNN Accelerator Simulator},
  author = {Samajdar, Ananda and Zhu, Yuhao and Whatmough, Paul and Mattina, Matthew and Krishna, Tushar},
  journal = {arXiv preprint arXiv:1811.02883},
  year = {2018}
}

@inproceedings{parashar2019timeloop,
  title = {Timeloop: A systematic approach to DNN accelerator evaluation},
  author = {Parashar, Angshuman and Raina, Priyanka and Shao, Yakun Sophia and Chen, Yu-Hsin and Ying, Victor A and Mukkara, Anurag and Venkatesan, Rangharajan and Khailany, Brucek and Keckler, Stephen W and Emer, Joel},
  booktitle = {2019 IEEE International Symposium on Performance Analysis of Systems and Software (ISPASS)},
  pages = {304--315},
  year = {2019},
  organization = {IEEE}
}

@article{kannan2022astrasim,
  title = {ASTRA-sim: Enabling SW/HW co-design exploration for distributed DL training platforms},
  author = {Kannan, Saeed Rashidi and Rashidi, Saeed and Sheng, Srinivas and Asghari, Changhai and Zhao, Tuowen and Rajamanickam, Siva and Kumar, Tushar and Melesse, Kartik and Jia, Zhangxi and others},
  journal = {arXiv preprint arXiv:2006.14479},
  year = {2022}
}

File diff suppressed because it is too large
@@ -216,6 +216,10 @@ class Sigmoid:
    Perfect for probabilities and binary classification.
    """

    def parameters(self):
        """Return empty list (activations have no learnable parameters)."""
        return []

    def forward(self, x: Tensor) -> Tensor:
        """
        Apply sigmoid activation element-wise.

@@ -349,6 +353,10 @@ class ReLU:

    Most popular activation for hidden layers.
    """

    def parameters(self):
        """Return empty list (activations have no learnable parameters)."""
        return []

    def forward(self, x: Tensor) -> Tensor:
        """
        Apply ReLU activation element-wise.

@@ -467,6 +475,10 @@ class Tanh:

    Zero-centered alternative to sigmoid.
    """

    def parameters(self):
        """Return empty list (activations have no learnable parameters)."""
        return []

    def forward(self, x: Tensor) -> Tensor:
        """
        Apply tanh activation element-wise.

@@ -590,6 +602,10 @@ class GELU:

    Where Φ(x) is the cumulative distribution function of standard normal.
    """

    def parameters(self):
        """Return empty list (activations have no learnable parameters)."""
        return []

    def forward(self, x: Tensor) -> Tensor:
        """
        Apply GELU activation element-wise.

@@ -713,6 +729,10 @@ class Softmax:

    Sum of all outputs equals 1.0.
    """

    def parameters(self):
        """Return empty list (activations have no learnable parameters)."""
        return []

    def forward(self, x: Tensor, dim: int = -1) -> Tensor:
        """
        Apply softmax activation along specified dimension.

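The pattern added in each hunk above (every layer exposes `parameters()`, even when it owns nothing trainable) lets containers and optimizers collect parameters uniformly. A minimal sketch of that aggregation, using hypothetical container and layer classes rather than TinyTorch's actual implementation:

```python
# Hypothetical sketch: a container can gather trainable tensors from every
# layer without special-casing parameter-free activations.
class TinySequential:
    def __init__(self, *layers):
        self.layers = layers

    def parameters(self):
        params = []
        for layer in self.layers:
            params.extend(layer.parameters())  # activations contribute []
        return params

class FakeLinear:
    def __init__(self):
        self.weight, self.bias = [1.0, 2.0], [0.0]
    def parameters(self):
        return [self.weight, self.bias]

class FakeReLU:
    def parameters(self):
        return []  # same pattern as the diff above

model = TinySequential(FakeLinear(), FakeReLU(), FakeLinear())
print(len(model.parameters()))  # 4 parameter tensors; the ReLU adds none
```
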
File diff suppressed because it is too large
@@ -701,7 +701,7 @@ PREDICTED P TP FP ← Model says "Yes"
          N      FN    TN     ← Model says "No"

BCE Loss for each quadrant:
- True Positive (TP):  -log(prediction)    ← Reward confident correct "Yes"
- False Positive (FP): -log(1-prediction)  ← Punish confident wrong "Yes"
- False Negative (FN): -log(prediction)    ← Punish confident wrong "No"
- True Negative (TN):  -log(1-prediction)  ← Reward confident correct "No"

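A quick numeric check of the quadrant table above (illustrative values only): confident correct predictions incur small loss, confident wrong predictions incur large loss.

```python
import math

# Confident correct vs. confident wrong, per the quadrant table above.
for label, p in [("TP: actual=1, prediction=0.95", 0.95),
                 ("FN: actual=1, prediction=0.05", 0.05)]:
    print(f"{label}  ->  -log(p) = {-math.log(p):.3f}")       # 0.051 vs. 2.996
for label, p in [("TN: actual=0, prediction=0.05", 0.05),
                 ("FP: actual=0, prediction=0.95", 0.95)]:
    print(f"{label}  ->  -log(1-p) = {-math.log(1 - p):.3f}")  # 0.051 vs. 2.996
```
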
File diff suppressed because it is too large
@@ -1,230 +0,0 @@
(Deleted Jupyter notebook: its cell sources are identical to the Python module shown in the next diff.)
@@ -1,172 +0,0 @@
"""
Autograd Systems Analysis - Memory & Performance Profiling

This file contains the P0 critical additions for Module 05 autograd:
- Memory profiling with tracemalloc
- Performance benchmarking
- Computational complexity analysis

These functions should be inserted after test_module() and before the module summary.
"""

import numpy as np
import tracemalloc
import time
from tinytorch.core.tensor import Tensor


def profile_autograd_memory():
    """
    Profile memory usage of autograd operations.

    This function demonstrates the memory cost of gradient tracking
    by comparing requires_grad=True vs. requires_grad=False.
    """
    print("\n" + "=" * 60)
    print("📊 Autograd Memory Profiling")
    print("=" * 60)

    # Test 1: Memory without gradients
    print("\n🔬 Test 1: Memory without gradient tracking...")
    tracemalloc.start()
    x_no_grad = Tensor(np.random.randn(1000, 1000), requires_grad=False)
    y_no_grad = x_no_grad.matmul(x_no_grad)
    mem_no_grad = tracemalloc.get_traced_memory()[1] / (1024 * 1024)  # MB
    tracemalloc.stop()

    # Test 2: Memory with gradients
    print("🔬 Test 2: Memory with gradient tracking...")
    tracemalloc.start()
    x_with_grad = Tensor(np.random.randn(1000, 1000), requires_grad=True)
    y_with_grad = x_with_grad.matmul(x_with_grad)
    mem_with_grad = tracemalloc.get_traced_memory()[1] / (1024 * 1024)  # MB
    tracemalloc.stop()

    # Test 3: Memory after backward
    print("🔬 Test 3: Memory after backward pass...")
    tracemalloc.start()
    x_backward = Tensor(np.random.randn(1000, 1000), requires_grad=True)
    y_backward = x_backward.matmul(x_backward)
    loss = y_backward.sum()
    loss.backward()
    mem_after_backward = tracemalloc.get_traced_memory()[1] / (1024 * 1024)  # MB
    tracemalloc.stop()

    print(f"\n📊 Memory Usage (1000×1000 matrix):")
    print(f"   • No gradients:   {mem_no_grad:.2f} MB")
    print(f"   • With gradients: {mem_with_grad:.2f} MB ({mem_with_grad/mem_no_grad:.2f}× overhead)")
    print(f"   • After backward: {mem_after_backward:.2f} MB")

    graph_overhead = mem_with_grad - mem_no_grad
    gradient_storage = mem_after_backward - mem_with_grad

    print(f"   • Graph overhead:   {graph_overhead:.2f} MB")
    print(f"   • Gradient storage: {gradient_storage:.2f} MB")

    print("\n💡 Key Insight: Autograd adds ~2-3× memory overhead")
    print("   (1× for gradients + 1-2× for computation graph)")


def benchmark_backward_pass():
    """
    Benchmark forward vs. backward pass timing.

    Demonstrates that backward pass is typically 2-3× slower than forward
    due to additional matmul operations for gradient computation.
    """
    print("\n" + "=" * 60)
    print("⚡ Backward Pass Performance Benchmarking")
    print("=" * 60)

    sizes = [100, 500, 1000]

    for size in sizes:
        # Forward pass timing (no gradients)
        x = Tensor(np.random.randn(size, size), requires_grad=False)
        W = Tensor(np.random.randn(size, size), requires_grad=False)

        start = time.perf_counter()
        for _ in range(10):
            y = x.matmul(W)
        forward_time = (time.perf_counter() - start) / 10

        # Forward + backward timing
        x = Tensor(np.random.randn(size, size), requires_grad=True)
        W = Tensor(np.random.randn(size, size), requires_grad=True)

        start = time.perf_counter()
        for _ in range(10):
            x.zero_grad()
            W.zero_grad()
            y = x.matmul(W)
            loss = y.sum()
            loss.backward()
        total_time = (time.perf_counter() - start) / 10

        backward_time = total_time - forward_time

        print(f"\n📐 Matrix size: {size}×{size}")
        print(f"   • Forward pass:  {forward_time*1000:.2f} ms")
        print(f"   • Backward pass: {backward_time*1000:.2f} ms ({backward_time/forward_time:.2f}× forward)")
        print(f"   • Total:         {total_time*1000:.2f} ms")

    print("\n💡 Key Insight: Backward pass ≈ 2-3× forward pass time")
    print("   (backward needs grad @ W.T and x.T @ grad = 2 matmuls vs. 1 in forward)")


def analyze_complexity():
    """
    Display computational complexity analysis for autograd operations.

    Shows time and space complexity for common operations.
    """
    print("\n" + "=" * 60)
    print("📊 Computational Complexity Analysis")
    print("=" * 60)

    print("\n### Time Complexity")
    print("-" * 60)
    print(f"{'Operation':<20} {'Forward':<15} {'Backward':<15} {'Total':<15}")
    print("-" * 60)
    print(f"{'Add':<20} {'O(n)':<15} {'O(n)':<15} {'O(n)':<15}")
    print(f"{'Mul':<20} {'O(n)':<15} {'O(n)':<15} {'O(n)':<15}")
    print(f"{'Matmul (n×n)':<20} {'O(n³)':<15} {'O(n³) × 2':<15} {'O(n³)':<15}")
    print(f"{'Sum':<20} {'O(n)':<15} {'O(n)':<15} {'O(n)':<15}")
    print(f"{'ReLU':<20} {'O(n)':<15} {'O(n)':<15} {'O(n)':<15}")
    print(f"{'Softmax':<20} {'O(n)':<15} {'O(n)':<15} {'O(n)':<15}")
    print("-" * 60)

    print("\n💡 Key Insight: Matrix operations dominate training time")
    print("   For Matmul with (m×k) @ (k×n):")
    print("   - Forward:  O(m×k×n)")
    print("   - Backward grad_A: O(m×n×k)  [grad_Z @ B.T]")
    print("   - Backward grad_B: O(k×m×n)  [A.T @ grad_Z]")
    print("   - Total: ~3× forward pass cost")

    print("\n### Space Complexity")
    print("-" * 60)
    print(f"{'Component':<25} {'Memory Usage':<35}")
    print("-" * 60)
    print(f"{'Parameters':<25} {'P (baseline)':<35}")
    print(f"{'Activations':<25} {'~P (for N layers ≈ P/N per layer)':<35}")
    print(f"{'Gradients':<25} {'P (1:1 with parameters)':<35}")
    print(f"{'Computation Graph':<25} {'0.2-0.5P (Function objects)':<35}")
    print(f"{'Total Training':<25} {'~2.5-3P':<35}")
    print("-" * 60)

    print("\n💡 Key Insight: Training requires ~3× parameter memory")


# Main execution block with all profiling
if __name__ == "__main__":
    print("\n" + "=" * 60)
    print("🔬 AUTOGRAD SYSTEMS ANALYSIS")
    print("=" * 60)

    profile_autograd_memory()
    benchmark_backward_pass()
    analyze_complexity()

    print("\n" + "=" * 60)
    print("✅ Systems analysis complete!")
    print("=" * 60)

@@ -659,7 +659,7 @@ Parameter Sensitivity Landscape:
   output_weight                      embedding_weight
        ↑                                   ↑
        |                                   |
   😱   | steep cliff                 🐌    | gentle slope
        | (needs tiny steps)                | (needs big steps)
        |                                   |
   ━━━●━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━●━━━→

@@ -1153,11 +1153,11 @@ In a typical training step, time is split between data loading and computation:
```
Training Step Breakdown:
┌───────────────────────────────────────────────────────────────┐
│  Data Loading   │   Forward Pass   │   Backward Pass          │
│  ████████████   │   ███████        │   ████████               │
│  40ms           │   25ms           │   35ms                   │
└───────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────┐
│  Data Loading   │   Forward Pass   │   Backward Pass        │
│  ████████████   │   ███████        │   ████████             │
│  40ms           │   25ms           │   35ms                 │
└─────────────────────────────────────────────────────────────┘
                     100ms total per step

Bottleneck Analysis:

@@ -1150,7 +1150,7 @@ Timing Challenges:
│                  Time Variance                      │
├─────────────────┬─────────────────┬─────────────┤
│  System Noise   │  Cache Effects  │  Thermal    │
│                 │                 │  Throttling │
├─────────────────┼─────────────────┼─────────────┤
│  Background     │  Cold start vs  │  CPU slows  │
│  processes      │  warm caches    │  when hot   │

@@ -1,175 +0,0 @@
#!/usr/bin/env python
"""
Validation script to verify quantization module fixes.

This script checks that:
1. Test functions are defined but not called at module level
2. NBGrader metadata is present
3. __main__ guards are in place
"""

import re
import sys

def validate_quantization_module():
    """Validate that all fixes were applied correctly."""

    print("=" * 70)
    print("QUANTIZATION MODULE VALIDATION")
    print("=" * 70)

    with open('quantization_dev.py', 'r') as f:
        content = f.read()
        lines = content.split('\n')

    # Check 1: Test functions should NOT be called at module level
    print("\n1. Checking test execution protection...")
    test_functions = [
        'test_unit_quantize_int8',
        'test_unit_dequantize_int8',
        'test_unit_quantized_linear',
        'test_unit_quantize_model',
        'test_unit_compare_model_sizes',
        'test_module'
    ]

    issues = []
    protected = []

    for i, line in enumerate(lines, 1):
        for test_func in test_functions:
            # Check for unprotected calls (not in if __main__)
            if re.match(rf'^{test_func}\(\)', line.strip()):
                # Look back to see if there's an if __main__ before this
                has_guard = False
                for j in range(max(0, i-5), i):
                    if 'if __name__ ==' in lines[j]:
                        has_guard = True
                        break

                if not has_guard:
                    issues.append(f"Line {i}: {test_func}() called without __main__ guard")
                else:
                    protected.append(f"Line {i}: {test_func}() properly protected")

    if issues:
        print("❌ FAILED: Found unprotected test calls:")
        for issue in issues:
            print(f"   {issue}")
    else:
        print("✅ PASSED: All test functions are protected")
        for p in protected:
            print(f"   ✓ {p}")

    # Check 2: NBGrader metadata presence
    print("\n2. Checking NBGrader metadata...")

    nbgrader_tests = {
        'test-quantize-int8': False,
        'test-dequantize-int8': False,
        'test-quantized-linear': False,
        'test-quantize-model': False,
        'test-compare-sizes': False,
        'test_module': False
    }

    for line in lines:
        for grade_id in nbgrader_tests.keys():
            if f'grade_id": "{grade_id}"' in line or f"'grade_id': '{grade_id}'" in line:
                nbgrader_tests[grade_id] = True

    missing = [k for k, v in nbgrader_tests.items() if not v and k != 'test_module']

    if missing:
        print(f"⚠️  WARNING: Missing NBGrader metadata for: {', '.join(missing)}")
    else:
        print("✅ PASSED: All unit tests have NBGrader metadata")
        for grade_id in nbgrader_tests:
            if nbgrader_tests[grade_id]:
                print(f"   ✓ {grade_id}")

    # Check 3: Demo functions protected
    print("\n3. Checking demo function protection...")

    demo_functions = [
        'demo_motivation_profiling',
        'analyze_quantization_memory',
        'analyze_quantization_accuracy',
        'demo_quantization_with_profiler'
    ]

    demo_protected = []
    demo_issues = []

    for i, line in enumerate(lines, 1):
        for demo_func in demo_functions:
            if re.match(rf'^{demo_func}\(\)', line.strip()):
                # Look back for if __main__ guard
                has_guard = False
                for j in range(max(0, i-5), i):
                    if 'if __name__ ==' in lines[j]:
                        has_guard = True
                        break

                if not has_guard:
                    demo_issues.append(f"Line {i}: {demo_func}() not protected")
                else:
                    demo_protected.append(f"Line {i}: {demo_func}() protected")

    if demo_issues:
        print("❌ FAILED: Found unprotected demo calls:")
        for issue in demo_issues:
            print(f"   {issue}")
    else:
        print("✅ PASSED: All demo functions are protected")
        for p in demo_protected:
            print(f"   ✓ {p}")

    # Check 4: No print statements at module level
    print("\n4. Checking for module-level print statements...")

    unprotected_prints = []
    for i, line in enumerate(lines, 1):
        if line.strip().startswith('print(') and 'def ' not in lines[max(0, i-10):i][-1]:
            # Check if it's in a function or protected
            in_function = False
            has_main_guard = False

            for j in range(max(0, i-20), i):
                if lines[j].strip().startswith('def '):
                    in_function = True
                if 'if __name__ ==' in lines[j]:
                    has_main_guard = True

            if not in_function and not has_main_guard:
                unprotected_prints.append((i, line.strip()))

    if unprotected_prints:
        print("⚠️  WARNING: Found unprotected print statements:")
        for line_num, stmt in unprotected_prints:
            print(f"   Line {line_num}: {stmt[:60]}...")
    else:
        print("✅ PASSED: No unprotected print statements")

    # Summary
    print("\n" + "=" * 70)
    print("VALIDATION SUMMARY")
    print("=" * 70)

    all_passed = not issues and not demo_issues and not missing

    if all_passed:
        print("✅ ALL CHECKS PASSED")
        print("\nThe module is now:")
        print("  • Safe to import (no test execution)")
        print("  • NBGrader compliant")
        print("  • Ready for export with TITO")
        print("  • Can be used as dependency by future modules")
        return 0
    else:
        print("❌ SOME CHECKS FAILED")
        print("\nPlease review the issues above and apply fixes.")
        return 1


if __name__ == "__main__":
    sys.exit(validate_quantization_module())

@@ -426,7 +426,7 @@ def test_unit_measure_sparsity():
    model = SimpleModel(layer1, layer2)  # Test helper for parameter collection

    initial_sparsity = measure_sparsity(model)
    assert initial_sparsity == 0.0, f"Expected 0% sparsity, got {initial_sparsity}%"
    assert initial_sparsity < 1.0, f"Expected <1% sparsity (dense model), got {initial_sparsity}%"

    # Test with manually sparse model - students see which weights are zeroed
    layer1.weight.data[0, 0] = 0  # Zero out specific weight

@@ -577,7 +577,7 @@ def test_unit_magnitude_prune():
    ])

    initial_sparsity = measure_sparsity(model)
    assert initial_sparsity == 0.0, "Model should start with no sparsity"
    assert initial_sparsity < 1.0, "Model should start with minimal sparsity (<1%)"

    # Apply 50% pruning - removes smallest 50% of weights
    magnitude_prune(model, sparsity=0.5)

@@ -746,7 +746,7 @@ def test_unit_structured_prune():
    ])

    initial_sparsity = measure_sparsity(model)
    assert initial_sparsity == 0.0, "Model should start with no sparsity"
    assert initial_sparsity < 1.0, "Model should start with minimal sparsity (<1%)"

    # Apply 33% structured pruning (2 out of 6 channels)
    # This removes entire channels, not scattered weights

@@ -298,24 +298,24 @@ Our KVCache needs to efficiently handle:
```
KVCache Memory Layout:
┌─────────────────────────────────────────────────────────┐
│                     KVCache Object                       │
├─────────────────────────────────────────────────────────┤
│ Layer 0:  ┌─────────────┬─────────────┐                  │
│           │  Key Cache  │ Value Cache │                  │
│           │  (B,H,S,D)  │  (B,H,S,D)  │                  │
│           └─────────────┴─────────────┘                  │
├─────────────────────────────────────────────────────────┤
│ Layer 1:  ┌─────────────┬─────────────┐                  │
│           │  Key Cache  │ Value Cache │                  │
│           │  (B,H,S,D)  │  (B,H,S,D)  │                  │
│           └─────────────┴─────────────┘                  │
├─────────────────────────────────────────────────────────┤
│ ...       ┌─────────────┬─────────────┐                  │
│ Layer N:  │  Key Cache  │ Value Cache │                  │
│           │  (B,H,S,D)  │  (B,H,S,D)  │                  │
│           └─────────────┴─────────────┘                  │
└─────────────────────────────────────────────────────────┘
┌────────────────────────────────────────┐
│             KVCache Object             │
├────────────────────────────────────────┤
│ Layer 0: ┌─────────────┬─────────────┐ │
│          │  Key Cache  │ Value Cache │ │
│          │  (B,H,S,D)  │  (B,H,S,D)  │ │
│          └─────────────┴─────────────┘ │
├────────────────────────────────────────┤
│ Layer 1: ┌─────────────┬─────────────┐ │
│          │  Key Cache  │ Value Cache │ │
│          │  (B,H,S,D)  │  (B,H,S,D)  │ │
│          └─────────────┴─────────────┘ │
├────────────────────────────────────────┤
│ ...      ┌─────────────┬─────────────┐ │
│ Layer N: │  Key Cache  │ Value Cache │ │
│          │  (B,H,S,D)  │  (B,H,S,D)  │ │
│          └─────────────┴─────────────┘ │
└────────────────────────────────────────┘

Where:
B = batch_size (number of sequences)
```

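A minimal sketch of the per-layer key/value cache layout diagrammed above, assuming NumPy arrays and hypothetical class and method names (the module's actual KVCache API may differ): each layer stores growing (B, H, S, D) tensors, and appending one generated token extends S by one instead of recomputing all previous positions.

```python
import numpy as np

# Hypothetical sketch, not TinyTorch's KVCache implementation.
class SimpleKVCache:
    def __init__(self, num_layers, batch, heads, head_dim):
        # One growing key tensor and one growing value tensor per layer.
        self.keys = [np.zeros((batch, heads, 0, head_dim)) for _ in range(num_layers)]
        self.values = [np.zeros((batch, heads, 0, head_dim)) for _ in range(num_layers)]

    def append(self, layer, k_new, v_new):
        # k_new, v_new have shape (B, H, 1, D) for the newly generated position.
        self.keys[layer] = np.concatenate([self.keys[layer], k_new], axis=2)
        self.values[layer] = np.concatenate([self.values[layer], v_new], axis=2)
        return self.keys[layer], self.values[layer]

cache = SimpleKVCache(num_layers=2, batch=1, heads=4, head_dim=8)
k, v = cache.append(0, np.random.randn(1, 4, 1, 8), np.random.randn(1, 4, 1, 8))
print(k.shape)  # (1, 4, 1, 8) → the sequence dimension grows by one per token
```
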
@@ -120,7 +120,7 @@ Solution: Kernel fusion, memory layout optimization
Every processor has fundamental limits:

```
Performance │          Compute Bound Region
  (GFLOPS)  │         ┌─────────────────────
            │         │  Peak Performance
            │         │

@@ -361,8 +361,8 @@ Activation Functions Compared:
ReLU:               GELU:               Sigmoid:
    |                   |               1 ┌─────
    |                   |                ╱ │
    |       ╱───│───     ╱                 │
─────┘      ╱───  │   ───╱                 │
Discontinuous       Smooth Curve    │  Smooth but saturates
gradient at 0       everywhere      │
```

@@ -375,18 +375,18 @@ ReLU: GELU: Sigmoid:
```
Unfused Operations:              Fused Operation:
┌─────────────────┐              ┌─────────────────┐
│ x³ computation  │ → temp1      │                 │
└─────────────────┘              │                 │
┌─────────────────┐              │                 │
│ polynomial part │ → temp2      │  All operations│
└─────────────────┘              │  combined in   │
┌─────────────────┐              │  single kernel │
│ tanh computation│ → temp3      │                 │
└─────────────────┘              │                 │
┌─────────────────┐              │                 │
│ final multiply  │ → result     │                 │
└─────────────────┘              └─────────────────┘
┌─────────────────┐              ┌────────────────────┐
│ x³ computation  │ → temp1      │                    │
└─────────────────┘              │                    │
┌─────────────────┐              │                    │
│ polynomial part │ → temp2      │  All operations    │
└─────────────────┘              │  combined in       │
┌─────────────────┐              │  single kernel     │
│ tanh computation│ → temp3      │                    │
└─────────────────┘              │                    │
┌─────────────────┐              │                    │
│ final multiply  │ → result     │                    │
└─────────────────┘              └────────────────────┘

5 memory round-trips             1 memory round-trip
```

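A small NumPy illustration of the idea in the diagram, using the tanh approximation of GELU. This is a sketch under stated assumptions (not TinyTorch's implementation): the "unfused" version names each temporary explicitly, while true fusion requires a compiled kernel that does all the work in a single pass over memory rather than a single Python expression.

```python
import numpy as np

def gelu_unfused(x):
    temp1 = x ** 3                                # x³ computation
    temp2 = x + 0.044715 * temp1                  # polynomial part
    temp3 = np.tanh(np.sqrt(2 / np.pi) * temp2)   # tanh computation
    return 0.5 * x * (1.0 + temp3)                # final multiply

def gelu_one_expression(x):
    # Same math written as one expression; a fused kernel would additionally
    # avoid materializing the intermediate arrays entirely.
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x ** 3)))

x = np.random.randn(1024)
assert np.allclose(gelu_unfused(x), gelu_one_expression(x))
```
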
@@ -530,7 +530,7 @@ The Benchmark class implements the core measurement logic for different metrics.
|
||||
Benchmark Execution Flow:
|
||||
┌─────────────┐ ┌──────────────┐ ┌─────────────────┐
|
||||
│ Models │ │ Datasets │ │ Measurement │
|
||||
│ [M1, M2...] │ → │ [D1, D2...] │ → │ Protocol │
|
||||
│ [M1, M2...] │ → │ [D1, D2...] │ → │ Protocol │
|
||||
└─────────────┘ └──────────────┘ └─────────────────┘
|
||||
↓
|
||||
┌─────────────────────────────────┐
|
||||
@@ -541,12 +541,12 @@ Benchmark Execution Flow:
|
||||
│ 4. Result aggregation │
|
||||
└─────────────────────────────────┘
|
||||
↓
|
||||
┌────────────────────────────────────┐
|
||||
│ BenchmarkResult │
|
||||
│ • Statistical analysis │
|
||||
│ • Confidence intervals │
|
||||
│ • Metadata (system, conditions) │
|
||||
└────────────────────────────────────┘
|
||||
┌────────────────────────────────────┐
|
||||
│ BenchmarkResult │
|
||||
│ • Statistical analysis │
|
||||
│ • Confidence intervals │
|
||||
│ • Metadata (system, conditions) │
|
||||
└────────────────────────────────────┘
|
||||
```
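
As a hedged sketch of the measurement protocol above (this is not the module's Benchmark class; `fn` stands for any model-plus-dataset callable, e.g. `lambda: model(batch)` with hypothetical names), warmup runs are executed but discarded before timing so that one-time costs such as cache fills or lazy allocation do not contaminate the aggregated statistics:

```python
import time
import statistics

def measure_latency(fn, warmup=5, trials=20):
    # 1. Warmup runs: executed but never recorded.
    for _ in range(warmup):
        fn()

    # 2. Measured trials: record per-run latency.
    latencies = []
    for _ in range(trials):
        start = time.perf_counter()
        fn()
        latencies.append(time.perf_counter() - start)

    # 3. Result aggregation, ready for confidence intervals.
    return {
        "mean_s": statistics.mean(latencies),
        "stdev_s": statistics.stdev(latencies),
        "median_s": statistics.median(latencies),
    }
```

Running this once per model/dataset pair mirrors the Models × Datasets × Measurement Protocol flow in the diagram.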

### Why Warmup Runs Matter

@@ -1,38 +1,32 @@
"""
Tiny🔥Torch Community Commands

Join, update, and manage your community profile for the global builder map.
Login, logout, and connect with the TinyTorch community.
"""

import json
import os
import webbrowser
import urllib.parse
from argparse import ArgumentParser, Namespace
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Any

from rich.panel import Panel
from rich.table import Table
from rich.prompt import Prompt, Confirm
from rich.console import Console

from .base import BaseCommand
from ..core.exceptions import TinyTorchCLIError
from .login import LoginCommand, LogoutCommand


class CommunityCommand(BaseCommand):
    """Community commands - join, update, leave, and manage your profile."""
    """Community commands - login, logout, leaderboard, and benchmarks."""

    @property
    def name(self) -> str:
        return "community"

    @property
    def description(self) -> str:
        return "Join the global community - connect with builders worldwide"

    def add_arguments(self, parser: ArgumentParser) -> None:
        """Add community subcommands."""
        subparsers = parser.add_subparsers(
@@ -40,77 +34,20 @@ class CommunityCommand(BaseCommand):
            help='Community operations',
            metavar='COMMAND'
        )

        # Join command
        join_parser = subparsers.add_parser(
            'join',
            help='Join the TinyTorch community'
        # Login command (delegates to LoginCommand)
        login_parser = subparsers.add_parser(
            'login',
            help='Log in to TinyTorch via web browser'
        )
        join_parser.add_argument(
            '--country',
            help='Your country (optional, auto-detected if possible)'
        )
        join_parser.add_argument(
            '--institution',
            help='Your institution/school (optional)'
        )
        join_parser.add_argument(
            '--course-type',
            choices=['university', 'bootcamp', 'self-paced', 'other'],
            help='Course type (optional)'
        )
        join_parser.add_argument(
            '--experience',
            choices=['beginner', 'intermediate', 'advanced', 'expert'],
            help='Experience level (optional)'
        )

        # Update command
        update_parser = subparsers.add_parser(
            'update',
            help='Update your community profile'
        )
        update_parser.add_argument(
            '--country',
            help='Update country'
        )
        update_parser.add_argument(
            '--institution',
            help='Update institution'
        )
        update_parser.add_argument(
            '--course-type',
            choices=['university', 'bootcamp', 'self-paced', 'other'],
            help='Update course type'
        )
        update_parser.add_argument(
            '--experience',
            choices=['beginner', 'intermediate', 'advanced', 'expert'],
            help='Update experience level'
        )

        # Leave command
        leave_parser = subparsers.add_parser(
            'leave',
            help='Leave the community (removes your profile)'
        )
        leave_parser.add_argument(
            '--force',
            action='store_true',
            help='Skip confirmation'
        )

        # Stats command
        stats_parser = subparsers.add_parser(
            'stats',
            help='View community statistics'
        )

        # Profile command
        profile_parser = subparsers.add_parser(
            'profile',
            help='View your community profile'
        LoginCommand(self.config).add_arguments(login_parser)

        # Logout command (delegates to LogoutCommand)
        logout_parser = subparsers.add_parser(
            'logout',
            help='Log out of TinyTorch'
        )
        LogoutCommand(self.config).add_arguments(logout_parser)

        # Leaderboard command (opens browser)
        leaderboard_parser = subparsers.add_parser(
@@ -133,23 +70,17 @@ class CommunityCommand(BaseCommand):
            'submission_file',
            help='Path to submission JSON file (e.g., submission.json)'
        )

    def run(self, args: Namespace) -> int:
        """Execute community command."""
        if not args.community_command:
            self.console.print("[yellow]Please specify a community command: join, leaderboard, compete, profile[/yellow]")
            self.console.print("[yellow]Please specify a community command: login, logout, leaderboard, compete, submit[/yellow]")
            return 1

        if args.community_command == 'join':
            return self._join_community(args)
        elif args.community_command == 'update':
            return self._update_profile(args)
        elif args.community_command == 'leave':
            return self._leave_community(args)
        elif args.community_command == 'stats':
            return self._show_stats(args)
        elif args.community_command == 'profile':
            return self._show_profile(args)
        if args.community_command == 'login':
            return LoginCommand(self.config).run(args)
        elif args.community_command == 'logout':
            return LogoutCommand(self.config).run(args)
        elif args.community_command == 'leaderboard':
            return self._open_leaderboard(args)
        elif args.community_command == 'compete':
@@ -157,535 +88,8 @@ class CommunityCommand(BaseCommand):
        elif args.community_command == 'submit':
            return self._submit_benchmark(args)
        else:
            self.console.print(f"[red]Unknown community command: {args.community_command}[/red]")
            self.console.print(f"[red]❌ Unknown community command: {args.community_command}[/red]")
            return 1

    def _join_community(self, args: Namespace) -> int:
        """Join the TinyTorch community - GitHub-first flow."""
        console = self.console

        # Check if already joined
        profile = self._get_profile()
        if profile:
            github_username = profile.get("github_username")
            profile_url = profile.get("profile_url", "https://tinytorch.ai/community")
            console.print(Panel(
                f"[yellow]⚠️ You're already in the community![/yellow]\n\n"
                f"GitHub: [cyan]@{github_username}[/cyan]\n"
                f"Profile: [cyan]{profile_url}[/cyan]\n\n"
                f"Update online: [cyan]{profile_url}[/cyan]\n"
                f"View profile: [cyan]tito community profile[/cyan]",
                title="Already Joined",
                border_style="yellow"
            ))
            return 0

        console.print(Panel(
            "[bold cyan]🌍 Join the TinyTorch Community[/bold cyan]\n\n"
            "Connect with ML systems builders worldwide!\n"
            "We'll ask 3 quick questions, then open your browser to complete your profile.",
            title="Welcome",
            border_style="cyan"
        ))

        console.print("\n[dim]Your data:[/dim]")
        console.print(" • Stored locally in [cyan].tinytorch/community.json[/cyan]")
        console.print(" • GitHub username for authentication")
        console.print(" • Basic info shared with community (country, institution)")
        console.print(" • Full profile completed on tinytorch.ai\n")

        # Question 1: GitHub username (REQUIRED)
        console.print("[bold]Question 1/3[/bold]")
        github_username = Prompt.ask(
            "[cyan]GitHub username[/cyan] (required for authentication)",
            default=""
        ).strip()

        if not github_username:
            console.print("[red]❌ GitHub username is required to join the community[/red]")
            console.print("[dim]Your GitHub username is used to:\n"
                          " • Authenticate your profile\n"
                          " • Link to your projects\n"
                          " • Connect with other builders[/dim]")
            return 1

        # Question 2: Country (optional, auto-detect)
        console.print("\n[bold]Question 2/3[/bold]")
        country = self._detect_country()
        if country:
            console.print(f"[dim]Auto-detected: {country}[/dim]")
        country = Prompt.ask(
            "[cyan]Country[/cyan] (for community map, optional)",
            default=country or "",
            show_default=False
        ).strip()

        # Question 3: Institution (optional)
        console.print("\n[bold]Question 3/3[/bold]")
        institution = Prompt.ask(
            "[cyan]Institution/University[/cyan] (optional)",
            default="",
            show_default=False
        ).strip()

        # Create local profile
        profile = {
            "github_username": github_username,
            "joined_at": datetime.now().isoformat(),
            "country": country or None,
            "institution": institution or None,
            "profile_url": f"https://tinytorch.ai/community/{github_username}",
            "last_synced": None
        }

        # Save profile locally
        self._save_profile(profile)

        # Build URL with pre-filled params
        base_url = "https://tinytorch.ai/community/join"
        params = {
            "github": github_username,
        }
        if country:
            params["country"] = country
        if institution:
            params["institution"] = institution

        signup_url = f"{base_url}?{urllib.parse.urlencode(params)}"

        # Show success and open browser
        console.print("\n")
        console.print(Panel(
            f"[bold green]✅ Local profile created![/bold green]\n\n"
            f"👤 GitHub: [cyan]@{github_username}[/cyan]\n"
            f"📍 Country: {country or '[dim]Not specified[/dim]'}\n"
            f"🏫 Institution: {institution or '[dim]Not specified[/dim]'}\n\n"
            f"[bold cyan]🌐 Opening browser to complete your profile...[/bold cyan]\n"
            f"[dim]URL: {signup_url}[/dim]\n\n"
            f"Complete your profile online to:\n"
            f" • Authenticate with GitHub OAuth\n"
            f" • Add bio, interests, and social links\n"
            f" • Join the global community map\n"
            f" • Connect with other builders",
            title="Almost There!",
            border_style="green"
        ))

        # Open browser
        try:
            webbrowser.open(signup_url)
            console.print("\n[green]✓[/green] Browser opened! Complete your profile there.")
        except Exception as e:
            console.print(f"\n[yellow]⚠️ Could not open browser automatically[/yellow]")
            console.print(f"[dim]Please visit: {signup_url}[/dim]")

        console.print(f"\n[dim]💡 View profile later: [cyan]tito community profile[/cyan][/dim]")

        return 0

    def _update_profile(self, args: Namespace) -> int:
        """Update community profile."""
        console = self.console

        # Get existing profile
        profile = self._get_profile()
        if not profile:
            console.print(Panel(
                "[yellow]⚠️ You're not in the community yet.[/yellow]\n\n"
                "Join first: [cyan]tito community join[/cyan]",
                title="Not Joined",
                border_style="yellow"
            ))
            return 1

        console.print(Panel(
            "[bold cyan]📝 Update Your Community Profile[/bold cyan]",
            title="Update Profile",
            border_style="cyan"
        ))

        # Update fields
        updated = False

        if args.country:
            profile["location"]["country"] = args.country
            updated = True
            console.print(f"[green]✅ Updated country: {args.country}[/green]")

        if args.institution:
            profile["institution"]["name"] = args.institution
            updated = True
            console.print(f"[green]✅ Updated institution: {args.institution}[/green]")

        if args.course_type:
            profile["context"]["course_type"] = args.course_type
            updated = True
            console.print(f"[green]✅ Updated course type: {args.course_type}[/green]")

        if args.experience:
            profile["context"]["experience_level"] = args.experience
            updated = True
            console.print(f"[green]✅ Updated experience level: {args.experience}[/green]")

        # If no args provided, do interactive update
        if not updated:
            console.print("\n[cyan]Interactive update (press Enter to keep current value):[/cyan]\n")

            # Country
            current_country = profile["location"].get("country", "")
            new_country = Prompt.ask(
                f"[cyan]Country[/cyan]",
                default=current_country or "",
                show_default=bool(current_country)
            )
            if new_country != current_country:
                profile["location"]["country"] = new_country or None
                updated = True

            # Institution
            current_institution = profile["institution"].get("name", "")
            new_institution = Prompt.ask(
                f"[cyan]Institution[/cyan]",
                default=current_institution or "",
                show_default=bool(current_institution)
            )
            if new_institution != current_institution:
                profile["institution"]["name"] = new_institution or None
                updated = True

        # Update progress if available
        self._update_progress(profile)

        # Save updated profile
        if updated:
            profile["updated_at"] = datetime.now().isoformat()
            self._save_profile(profile)
            console.print("\n[green]✅ Profile updated successfully![/green]")
        else:
            console.print("\n[yellow]No changes made.[/yellow]")

        return 0

    def _leave_community(self, args: Namespace) -> int:
        """Leave the community."""
        console = self.console

        # Get existing profile
        profile = self._get_profile()
        if not profile:
            console.print(Panel(
                "[yellow]⚠️ You're not in the community.[/yellow]",
                title="Not Joined",
                border_style="yellow"
            ))
            return 0

        # Confirm
        if not args.force:
            console.print(Panel(
                "[yellow]⚠️ Warning: This will remove your community profile[/yellow]\n\n"
                "This action cannot be undone.\n"
                "Your benchmark submissions will remain, but your profile will be removed.",
                title="Leave Community",
                border_style="yellow"
            ))

            confirm = Confirm.ask("\n[red]Are you sure you want to leave?[/red]", default=False)
            if not confirm:
                console.print("[cyan]Cancelled.[/cyan]")
                return 0

        # Remove profile
        profile_file = self._get_profile_file()
        if profile_file.exists():
            profile_file.unlink()

        # Stub: Notify website of leave
        self._notify_website_leave(profile.get("anonymous_id") if profile else None)

        console.print(Panel(
            "[green]✅ You've left the community.[/green]\n\n"
            "You can rejoin anytime with: [cyan]tito community join[/cyan]",
            title="Left Community",
            border_style="green"
        ))

        return 0

    def _show_stats(self, args: Namespace) -> int:
        """Show community statistics."""
        console = self.console

        # For now, show local stats
        # In production, this would fetch from a server
        profile = self._get_profile()

        console.print(Panel(
            "[bold cyan]🌍 TinyTorch Community Stats[/bold cyan]\n\n"
            "[dim]Note: Full community stats require server connection.[/dim]\n"
            "This shows your local information.",
            title="Community Stats",
            border_style="cyan"
        ))

        if profile:
            console.print(f"\n[cyan]Your Profile:[/cyan]")
            console.print(f" • Country: {profile['location'].get('country', 'Not specified')}")
            console.print(f" • Institution: {profile['institution'].get('name', 'Not specified')}")
            console.print(f" • Course Type: {profile['context'].get('course_type', 'Not specified')}")
            console.print(f" • Experience: {profile['context'].get('experience_level', 'Not specified')}")
            console.print(f" • Cohort: {profile['context'].get('cohort', 'Not specified')}")
        else:
            console.print("\n[yellow]You're not in the community yet.[/yellow]")
            console.print("Join with: [cyan]tito community join[/cyan]")

        return 0

    def _show_profile(self, args: Namespace) -> int:
        """Show user's community profile."""
        console = self.console

        profile = self._get_profile()
        if not profile:
            console.print(Panel(
                "[yellow]⚠️ You're not in the community yet.[/yellow]\n\n"
                "Join with: [cyan]tito community join[/cyan]",
                title="Not Joined",
                border_style="yellow"
            ))
            return 1

        # Display profile
        profile_table = Table(title="Your Community Profile", show_header=False, box=None)
        profile_table.add_column("Field", style="cyan", width=20)
        profile_table.add_column("Value", style="green")

        profile_table.add_row("Anonymous ID", profile.get("anonymous_id", "N/A"))
        profile_table.add_row("Joined", self._format_date(profile.get("joined_at")))
        profile_table.add_row("Country", profile["location"].get("country", "Not specified"))
        profile_table.add_row("Institution", profile["institution"].get("name", "Not specified"))
        profile_table.add_row("Course Type", profile["context"].get("course_type", "Not specified"))
        profile_table.add_row("Experience", profile["context"].get("experience_level", "Not specified"))
        profile_table.add_row("Cohort", profile["context"].get("cohort", "Not specified"))

        progress = profile.get("progress", {})
        profile_table.add_row("", "")
        profile_table.add_row("[bold]Progress[/bold]", "")
        profile_table.add_row("Setup Verified", "✅" if progress.get("setup_verified") else "❌")
        profile_table.add_row("Milestones Passed", str(progress.get("milestones_passed", 0)))
        profile_table.add_row("Modules Completed", str(progress.get("modules_completed", 0)))
        capstone_score = progress.get("capstone_score")
        profile_table.add_row("Capstone Score", f"{capstone_score}/100" if capstone_score else "Not completed")

        console.print("\n")
        console.print(profile_table)

        return 0

    def _get_profile(self) -> Optional[Dict[str, Any]]:
        """Get user's community profile."""
        profile_file = self._get_profile_file()
        if profile_file.exists():
            try:
                with open(profile_file, 'r') as f:
                    return json.load(f)
            except Exception:
                return None
        return None

    def _save_profile(self, profile: Dict[str, Any]) -> None:
        """Save user's community profile."""
        profile_file = self._get_profile_file()
        profile_file.parent.mkdir(parents=True, exist_ok=True)

        with open(profile_file, 'w') as f:
            json.dump(profile, f, indent=2)

        # Stub: Sync with website if configured
        self._sync_profile_to_website(profile)

    def _get_profile_file(self) -> Path:
        """Get path to profile file (project-local)."""
        return self.config.project_root / ".tinytorch" / "community" / "profile.json"

    def _get_config(self) -> Dict[str, Any]:
        """Get community configuration."""
        config_file = self.config.project_root / ".tinytorch" / "config.json"
        default_config = {
            "website": {
                "base_url": "https://tinytorch.ai",
                "community_map_url": "https://tinytorch.ai/community",
                "api_url": None,  # Set when API is available
                "enabled": False  # Set to True when website integration is ready
            },
            "local": {
                "enabled": True,  # Always use local storage
                "auto_sync": False  # Auto-sync to website when enabled
            }
        }

        if config_file.exists():
            try:
                with open(config_file, 'r') as f:
                    user_config = json.load(f)
                # Merge with defaults
                default_config.update(user_config)
                return default_config
            except Exception:
                pass

        # Create default config if it doesn't exist
        config_file.parent.mkdir(parents=True, exist_ok=True)
        with open(config_file, 'w') as f:
            json.dump(default_config, f, indent=2)

        return default_config

    def _sync_profile_to_website(self, profile: Dict[str, Any]) -> None:
        """Stub: Sync profile to website (local for now, website integration later)."""
        config = self._get_config()

        if not config.get("website", {}).get("enabled", False):
            # Website integration not enabled, just store locally
            return

        # Stub for future website API integration
        api_url = config.get("website", {}).get("api_url")
        if api_url:
            # TODO: Implement API call when website is ready
            # Example:
            # import requests
            # response = requests.post(f"{api_url}/api/community/profile", json=profile)
            # response.raise_for_status()
            pass

    def _detect_country(self) -> Optional[str]:
        """Try to detect country from system."""
        # Try timezone first
        try:
            import time
            tz = time.tzname[0] if time.daylight == 0 else time.tzname[1]
            # This is a simple heuristic - could be improved
            return None  # Don't auto-detect for privacy
        except Exception:
            return None

    def _determine_cohort(self) -> str:
        """Determine cohort based on current date."""
        now = datetime.now()
        month = now.month

        if month in [9, 10, 11, 12]:
            return f"Fall {now.year}"
        elif month in [1, 2, 3, 4, 5]:
            return f"Spring {now.year}"
        else:
            return f"Summer {now.year}"

    def _update_progress(self, profile: Dict[str, Any]) -> None:
        """Update progress information from local data."""
        # Check milestone progress
        milestone_file = Path(".tito") / "milestones.json"
        if milestone_file.exists():
            try:
                with open(milestone_file, 'r') as f:
                    milestones_data = json.load(f)
                completed = milestones_data.get("completed_milestones", [])
                profile["progress"]["milestones_passed"] = len(completed)
            except Exception:
                pass

        # Check module progress
        progress_file = Path(".tito") / "progress.json"
        if progress_file.exists():
            try:
                with open(progress_file, 'r') as f:
                    progress_data = json.load(f)
                completed = progress_data.get("completed_modules", [])
                profile["progress"]["modules_completed"] = len(completed)
            except Exception:
                pass

        # Check capstone score
        benchmark_dir = Path(".tito") / "benchmarks"
        if benchmark_dir.exists():
            # Find latest capstone benchmark
            capstone_files = sorted(benchmark_dir.glob("capstone_*.json"), reverse=True)
            if capstone_files:
                try:
                    with open(capstone_files[0], 'r') as f:
                        capstone_data = json.load(f)
                    profile["progress"]["capstone_score"] = capstone_data.get("overall_score")
                except Exception:
                    pass

    def _format_date(self, date_str: Optional[str]) -> str:
        """Format ISO date string."""
        if not date_str:
            return "N/A"
        try:
            dt = datetime.fromisoformat(date_str.replace('Z', '+00:00'))
            return dt.strftime("%Y-%m-%d %H:%M")
        except Exception:
            return date_str

    def _notify_website_join(self, profile: Dict[str, Any]) -> None:
        """Stub: Notify website when user joins (local for now, website integration later)."""
        config = self._get_config()

        if not config.get("website", {}).get("enabled", False):
            # Website integration not enabled
            return

        api_url = config.get("website", {}).get("api_url")
        if api_url:
            # TODO: Implement API call when website is ready
            # Example:
            # import requests
            # try:
            #     response = requests.post(
            #         f"{api_url}/api/community/join",
            #         json=profile,
            #         timeout=10,  # 10 second timeout
            #         headers={"Content-Type": "application/json"}
            #     )
            #     response.raise_for_status()
            # except requests.Timeout:
            #     self.console.print("[dim]Note: Website sync timed out. Your data is saved locally.[/dim]")
            # except requests.RequestException as e:
            #     # Log error but don't fail the command
            #     self.console.print(f"[dim]Note: Could not sync with website: {e}[/dim]")
            #     self.console.print("[dim]Your data is saved locally and can be synced later.[/dim]")
            pass

    def _notify_website_leave(self, anonymous_id: Optional[str]) -> None:
        """Stub: Notify website when user leaves (local for now, website integration later)."""
        config = self._get_config()

        if not config.get("website", {}).get("enabled", False):
            # Website integration not enabled
            return

        api_url = config.get("website", {}).get("api_url")
        if api_url and anonymous_id:
            # TODO: Implement API call when website is ready
            # Example:
            # import requests
            # try:
            #     response = requests.post(
            #         f"{api_url}/api/community/leave",
            #         json={"anonymous_id": anonymous_id},
            #         timeout=10,  # 10 second timeout
            #         headers={"Content-Type": "application/json"}
            #     )
            #     response.raise_for_status()
            # except requests.Timeout:
            #     self.console.print("[dim]Note: Website sync timed out. Profile removed locally.[/dim]")
            # except requests.RequestException as e:
            #     # Log error but don't fail the command
            #     self.console.print(f"[dim]Note: Could not sync with website: {e}[/dim]")
            #     self.console.print("[dim]Profile removed locally.[/dim]")
            pass

    def _open_leaderboard(self, args: Namespace) -> int:
        """Open community leaderboard in browser."""
@@ -693,10 +97,10 @@ class CommunityCommand(BaseCommand):

        leaderboard_url = "https://tinytorch.ai/community/leaderboard"

        self.console.print(f"[blue]🏆 Opening leaderboard...[/blue]")
        self.console.print(f"[cyan]🏆 Opening leaderboard...[/cyan]")
        try:
            webbrowser.open(leaderboard_url)
            self.console.print(f"[green]✓[/green] Browser opened: [cyan]{leaderboard_url}[/cyan]")
            self.console.print(f"[green]✅[/green] Browser opened: [cyan]{leaderboard_url}[/cyan]")
        except Exception as e:
            self.console.print(f"[yellow]⚠️ Could not open browser automatically[/yellow]")
            self.console.print(f"[dim]Please visit: {leaderboard_url}[/dim]")
@@ -709,10 +113,10 @@ class CommunityCommand(BaseCommand):

        compete_url = "https://tinytorch.ai/community/compete"

        self.console.print(f"[blue]🎯 Opening competitions...[/blue]")
        self.console.print(f"[cyan]🎯 Opening competitions...[/cyan]")
        try:
            webbrowser.open(compete_url)
            self.console.print(f"[green]✓[/green] Browser opened: [cyan]{compete_url}[/cyan]")
            self.console.print(f"[green]✅[/green] Browser opened: [cyan]{compete_url}[/cyan]")
        except Exception as e:
            self.console.print(f"[yellow]⚠️ Could not open browser automatically[/yellow]")
            self.console.print(f"[dim]Please visit: {compete_url}[/dim]")

@@ -182,7 +182,7 @@ class GradeCommand(BaseCommand):
            "  9. tito grade export   # Export grades\n\n"
            "[dim]Note: NBGrader must be installed and configured[/dim]",
            title="Grade Help",
            border_style="bright_blue"
            border_style="bright_cyan"
        )
        self.console.print(help_panel)

@@ -38,7 +38,7 @@ class LoginCommand(BaseCommand):
        try:
            port = receiver.start()
            target_url = f"{ENDPOINTS['cli_login']}?redirect_port={port}"
            self.console.print(f"Opening browser to: [blue]{target_url}[/blue]")
            self.console.print(f"Opening browser to: [cyan]{target_url}[/cyan]")
            self.console.print("Waiting for authentication...")
            webbrowser.open(target_url)
            tokens = receiver.wait_for_tokens()

@@ -89,7 +89,7 @@ class ProtectCommand(BaseCommand):

        # Show header
        console.print(Panel.fit(
            "🛡️ [bold blue]TinyTorch Student Protection System[/bold blue]\n"
            "🛡️ [bold cyan]TinyTorch Student Protection System[/bold cyan]\n"
            "Prevents accidental edits to critical core functionality",
            border_style="blue"
        ))
@@ -112,7 +112,7 @@ class ProtectCommand(BaseCommand):

    def _enable_protection(self, console: Console, args: Namespace) -> int:
        """🔒 Enable comprehensive protection system."""
        console.print("[blue]🔒 Enabling TinyTorch Student Protection System...[/blue]")
        console.print("[cyan]🔒 Enabling TinyTorch Student Protection System...[/cyan]")
        console.print()

        protection_count = 0
@@ -221,7 +221,7 @@ echo "✅ No auto-generated files being committed"
            "📝 GitHub will label files as 'Generated'\n"
            "🚫 Git prevents committing generated files\n"
            "⚙️ VSCode shows protection warnings\n\n"
            "[blue]Students are now protected from breaking CIFAR-10 training![/blue]",
            "[cyan]Students are now protected from breaking CIFAR-10 training![/cyan]",
            border_style="green"
        ))

@@ -262,10 +262,10 @@ echo "✅ No auto-generated files being committed"

    def _show_protection_status(self, console: Console) -> int:
        """🔍 Show current protection status."""
        console.print("[blue]🔍 TinyTorch Protection Status[/blue]")
        console.print("[cyan]🔍 TinyTorch Protection Status[/cyan]")
        console.print()

        table = Table(show_header=True, header_style="bold blue")
        table = Table(show_header=True, header_style="bold cyan")
        table.add_column("Protection Feature", style="cyan")
        table.add_column("Status", justify="center")
        table.add_column("Details", style="dim")
@@ -335,7 +335,7 @@ echo "✅ No auto-generated files being committed"
        """✅ Validate core functionality works correctly."""
        try:
            from tinytorch.core._validation import run_student_protection_checks
            console.print("[blue]🔍 Running comprehensive validation...[/blue]")
            console.print("[cyan]🔍 Running comprehensive validation...[/cyan]")
            console.print()

            try:
@@ -358,7 +358,7 @@ echo "✅ No auto-generated files being committed"

    def _quick_health_check(self, console: Console) -> int:
        """⚡ Quick health check of critical functionality."""
        console.print("[blue]⚡ Quick Health Check[/blue]")
        console.print("[cyan]⚡ Quick Health Check[/cyan]")
        console.print()

        checks = []

@@ -560,7 +560,7 @@ class TestCommand(BaseCommand):
        console = self.console

        # Summary table
        table = Table(title="Test Summary Report", show_header=True, header_style="bold blue")
        table = Table(title="Test Summary Report", show_header=True, header_style="bold cyan")
        table.add_column("Module", style="bold cyan", width=15)
        table.add_column("Status", width=10, justify="center")
        table.add_column("Inline Tests", width=12, justify="center")

163 tito/main.py
@@ -38,7 +38,6 @@ from .commands.milestone import MilestoneCommand
from .commands.setup import SetupCommand
from .commands.benchmark import BenchmarkCommand
from .commands.community import CommunityCommand
from .commands.login import LoginCommand, LogoutCommand

# Configure logging
logging.basicConfig(
@@ -80,35 +79,98 @@ class TinyTorchCLI:
            'test': TestCommand,
            'grade': GradeCommand,
            'logo': LogoCommand,
            # Authentication commands
            'login': LoginCommand,
            'logout': LogoutCommand,
        }

        # Command categorization for help display
        self.student_commands = ['module', 'milestones', 'community', 'benchmark']
        self.developer_commands = ['system', 'src', 'package', 'nbgrader']

        # Welcome screen sections (used for both tito and tito --help)
        self.welcome_sections = {
            'quick_start': [
                ('[green]tito setup[/green]', 'First-time setup'),
                ('[green]tito module start 01[/green]', 'Start Module 01 (tensors)'),
                ('[green]tito module complete 01[/green]', 'Test, export, and track progress'),
            ],
            'track_progress': [
                ('[yellow]tito module status[/yellow]', 'View module progress'),
                ('[yellow]tito milestones status[/yellow]', 'View unlocked capabilities'),
            ],
            'community': [
                ('[cyan]tito community login[/cyan]', 'Log in to TinyTorch'),
                ('[cyan]tito community leaderboard[/cyan]', 'View global leaderboard'),
            ],
            'help_docs': [
                ('[magenta]tito system doctor[/magenta]', 'Check environment health'),
                ('[magenta]tito --help[/magenta]', 'See all commands'),
            ]
        }

    def _generate_welcome_text(self) -> str:
        """Generate dynamic welcome text for interactive mode."""
        lines = []

        # Quick Start
        lines.append("[bold cyan]Quick Start:[/bold cyan]")
        for cmd, desc in self.welcome_sections['quick_start']:
            lines.append(f" {cmd:<38} {desc}")

        # Track Progress
        lines.append("\n[bold cyan]Track Progress:[/bold cyan]")
        for cmd, desc in self.welcome_sections['track_progress']:
            lines.append(f" {cmd:<38} {desc}")

        # Community
        lines.append("\n[bold cyan]Community:[/bold cyan]")
        for cmd, desc in self.welcome_sections['community']:
            lines.append(f" {cmd:<38} {desc}")

        # Help & Docs
        lines.append("\n[bold cyan]Help & Docs:[/bold cyan]")
        for cmd, desc in self.welcome_sections['help_docs']:
            lines.append(f" {cmd:<38} {desc}")

        return "\n".join(lines)

    def _generate_epilog(self) -> str:
        """Generate dynamic epilog from registered commands."""
        lines = []

        # Student Commands section
        lines.append("Student Commands:")
        for cmd_name in self.student_commands:
            if cmd_name in self.commands:
                cmd = self.commands[cmd_name](self.config)
                # Simplify description for epilog (first sentence or shorter version)
                desc = cmd.description.split('.')[0].split('-')[0].strip()
                lines.append(f" {cmd_name:<12} {desc}")
        lines.append("")

        # Developer Commands section
        lines.append("Developer Commands:")
        for cmd_name in self.developer_commands:
            if cmd_name in self.commands:
                cmd = self.commands[cmd_name](self.config)
                desc = cmd.description.split('.')[0].split('-')[0].strip()
                lines.append(f" {cmd_name:<12} {desc}")
        lines.append("")

        # Quick Start section (strip Rich formatting for plain text)
        lines.append("Quick Start:")
        for cmd, desc in self.welcome_sections['quick_start']:
            # Remove Rich color tags for plain epilog
            plain_cmd = cmd.replace('[green]', '').replace('[/green]', '')
            lines.append(f" {plain_cmd:<28} {desc}")

        return "\n".join(lines)

    def create_parser(self) -> argparse.ArgumentParser:
        """Create the main argument parser."""
        parser = argparse.ArgumentParser(
            prog="tito",
            description="Tiny🔥Torch CLI - Build ML systems from scratch",
            formatter_class=argparse.RawDescriptionHelpFormatter,
            epilog="""
Student Commands:
  module       Module workflow - start, work, complete modules
  milestones   Track progress - unlock capabilities as you build
  community    Join global community - connect with builders

Developer Commands:
  system       Environment and configuration
  src          Export src/ to modules/ and tinytorch/
  package      Package management (nbdev)
  nbgrader     Auto-grading tools

Quick Start:
  tito setup                First-time setup
  tito module start 01      Start Module 01 (tensors)
  tito module complete 01   Test, export, and track progress
  tito module status        View your progress
"""
            epilog=self._generate_epilog()
        )

        # Global options
@@ -163,9 +225,49 @@ Quick Start:

        return True

    def _show_help(self) -> int:
        """Show custom Rich-formatted help."""
        from rich.table import Table

        # Show ASCII logo
        print_ascii_logo()

        # Create commands table
        table = Table(show_header=True, header_style="bold cyan", box=None, padding=(0, 2))
        table.add_column("Command", style="green", width=15)
        table.add_column("Description", style="dim")

        # Add all commands dynamically
        for cmd_name, cmd_class in self.commands.items():
            cmd = cmd_class(self.config)
            table.add_row(cmd_name, cmd.description)

        self.console.print()
        self.console.print("[bold cyan]Tiny🔥Torch CLI[/bold cyan] - Build ML systems from scratch")
        self.console.print()
        self.console.print("[bold]Usage:[/bold] [cyan]tito[/cyan] [yellow]COMMAND[/yellow] [dim][OPTIONS][/dim]")
        self.console.print()
        self.console.print("[bold cyan]Available Commands:[/bold cyan]")
        self.console.print(table)
        self.console.print()
        self.console.print(self._generate_welcome_text())
        self.console.print()
        self.console.print("[bold cyan]Global Options:[/bold cyan]")
        self.console.print("  [yellow]--help, -h[/yellow]     Show this help message")
        self.console.print("  [yellow]--version[/yellow]      Show version number")
        self.console.print("  [yellow]--verbose, -v[/yellow]  Enable verbose output")
        self.console.print("  [yellow]--no-color[/yellow]     Disable colored output")
        self.console.print()

        return 0

    def run(self, args: Optional[List[str]] = None) -> int:
        """Run the CLI application."""
        try:
            # Check for help flag before argparse to use Rich formatting
            if args and ('-h' in args or '--help' in args) and len(args) == 1:
                return self._show_help()

            parser = self.create_parser()
            parsed_args = parser.parse_args(args)

@@ -197,22 +299,9 @@ Quick Start:
            # Show ASCII logo first
            print_ascii_logo()

            # Simple, focused welcome message
            help_text = "[bold cyan]Quick Start:[/bold cyan]\n"
            help_text += "  [green]tito setup[/green]               First-time setup\n"
            help_text += "  [green]tito module start 01[/green]     Start Module 01 (tensors)\n"
            help_text += "  [green]tito module complete 01[/green]  Test, export, and track progress\n"
            help_text += "\n[bold cyan]Track Progress:[/bold cyan]\n"
            help_text += "  [yellow]tito module status[/yellow]     View module progress\n"
            help_text += "  [yellow]tito milestones status[/yellow] View unlocked capabilities\n"
            help_text += "\n[bold cyan]Community:[/bold cyan]\n"
            help_text += "  [blue]tito community join[/blue]        Connect with builders worldwide\n"
            help_text += "\n[bold cyan]Help & Docs:[/bold cyan]\n"
            help_text += "  [magenta]tito system doctor[/magenta]   Check environment health\n"
            help_text += "  [magenta]tito --help[/magenta]          See all commands"

            # Generate dynamic welcome message
            self.console.print(Panel(
                help_text,
                self._generate_welcome_text(),
                title="Welcome to Tiny🔥Torch!",
                border_style="bright_green"
            ))