mirror of
https://github.com/MLSysBook/TinyTorch.git
synced 2026-05-24 16:35:52 -05:00
1140 lines
98 KiB
HTML
1140 lines
98 KiB
HTML
|
||
<!DOCTYPE html>
|
||
|
||
|
||
<html lang="en" data-content_root="../" >
|
||
|
||
<head>
|
||
<meta charset="utf-8" />
|
||
<meta name="viewport" content="width=device-width, initial-scale=1" />
|
||
|
||
<title>06. Optimizers — Tiny🔥Torch</title>
|
||
|
||
|
||
|
||
<script data-cfasync="false">
|
||
document.documentElement.dataset.mode = localStorage.getItem("mode") || "";
|
||
document.documentElement.dataset.theme = localStorage.getItem("theme") || "";
|
||
</script>
|
||
|
||
<!-- Loaded before other Sphinx assets -->
|
||
<link href="../_static/styles/theme.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" />
|
||
<link href="../_static/styles/bootstrap.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" />
|
||
<link href="../_static/styles/pydata-sphinx-theme.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" />
|
||
|
||
|
||
<link href="../_static/vendor/fontawesome/6.5.2/css/all.min.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" />
|
||
<link rel="preload" as="font" type="font/woff2" crossorigin href="../_static/vendor/fontawesome/6.5.2/webfonts/fa-solid-900.woff2" />
|
||
<link rel="preload" as="font" type="font/woff2" crossorigin href="../_static/vendor/fontawesome/6.5.2/webfonts/fa-brands-400.woff2" />
|
||
<link rel="preload" as="font" type="font/woff2" crossorigin href="../_static/vendor/fontawesome/6.5.2/webfonts/fa-regular-400.woff2" />
|
||
|
||
<link rel="stylesheet" type="text/css" href="../_static/pygments.css?v=03e43079" />
|
||
<link rel="stylesheet" type="text/css" href="../_static/styles/sphinx-book-theme.css?v=eba8b062" />
|
||
<link rel="stylesheet" type="text/css" href="../_static/togglebutton.css?v=13237357" />
|
||
<link rel="stylesheet" type="text/css" href="../_static/copybutton.css?v=76b2166b" />
|
||
<link rel="stylesheet" type="text/css" href="../_static/mystnb.8ecb98da25f57f5357bf6f572d296f466b2cfe2517ffebfabe82451661e28f02.css" />
|
||
<link rel="stylesheet" type="text/css" href="../_static/sphinx-thebe.css?v=4fa983c6" />
|
||
<link rel="stylesheet" type="text/css" href="../_static/sphinx-design.min.css?v=95c83b7e" />
|
||
<link rel="stylesheet" type="text/css" href="../_static/custom.css?v=009d37f4" />
|
||
|
||
<!-- Pre-loaded scripts that we'll load fully later -->
|
||
<link rel="preload" as="script" href="../_static/scripts/bootstrap.js?digest=dfe6caa3a7d634c4db9b" />
|
||
<link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=dfe6caa3a7d634c4db9b" />
|
||
<script src="../_static/vendor/fontawesome/6.5.2/js/all.min.js?digest=dfe6caa3a7d634c4db9b"></script>
|
||
|
||
<script src="../_static/documentation_options.js?v=9eb32ce0"></script>
|
||
<script src="../_static/doctools.js?v=9a2dae69"></script>
|
||
<script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
|
||
<script src="../_static/clipboard.min.js?v=a7894cd8"></script>
|
||
<script src="../_static/copybutton.js?v=f281be69"></script>
|
||
<script src="../_static/scripts/sphinx-book-theme.js?v=887ef09a"></script>
|
||
<script>let toggleHintShow = 'Click to show';</script>
|
||
<script>let toggleHintHide = 'Click to hide';</script>
|
||
<script>let toggleOpenOnPrint = 'true';</script>
|
||
<script src="../_static/togglebutton.js?v=4a39c7ea"></script>
|
||
<script>var togglebuttonSelector = '.toggle, .admonition.dropdown';</script>
|
||
<script src="../_static/design-tabs.js?v=f930bc37"></script>
|
||
<script>const THEBE_JS_URL = "https://unpkg.com/thebe@0.8.2/lib/index.js"; const thebe_selector = ".thebe,.cell"; const thebe_selector_input = "pre"; const thebe_selector_output = ".output, .cell_output"</script>
|
||
<script async="async" src="../_static/sphinx-thebe.js?v=c100c467"></script>
|
||
<script>var togglebuttonSelector = '.toggle, .admonition.dropdown';</script>
|
||
<!-- Duplicate thebe configuration script removed: THEBE_JS_URL and the thebe_selector, thebe_selector_input and thebe_selector_output constants are already declared by an earlier inline script in this head, and redeclaring a top-level const in a second classic script throws "SyntaxError: Identifier 'THEBE_JS_URL' has already been declared". -->
|
||
<script>DOCUMENTATION_OPTIONS.pagename = 'modules/06_optimizers_ABOUT';</script>
|
||
<script src="../_static/ml-timeline.js?v=76e9b3e3"></script>
|
||
<script src="../_static/wip-banner.js?v=04a7e74d"></script>
|
||
<script src="../_static/marimo-badges.js?v=e6289128"></script>
|
||
<script src="../_static/sidebar-link.js?v=404b701b"></script>
|
||
<script src="../_static/hero-carousel.js?v=10341d2a"></script>
|
||
<script src="../_static/subscribe-modal.js?v=42919b64"></script>
|
||
<link rel="icon" href="../_static/favicon.svg"/>
|
||
<link rel="index" title="Index" href="../genindex.html" />
|
||
<link rel="search" title="Search" href="../search.html" />
|
||
<link rel="next" title="07. Training" href="07_training_ABOUT.html" />
|
||
<link rel="prev" title="05. Autograd" href="05_autograd_ABOUT.html" />
|
||
<meta name="viewport" content="width=device-width, initial-scale=1"/>
|
||
<meta name="docsearch:language" content="en"/>
|
||
</head>
|
||
|
||
|
||
<body data-bs-spy="scroll" data-bs-target=".bd-toc-nav" data-offset="180" data-bs-root-margin="0px 0px -60%" data-default-mode="">
|
||
|
||
|
||
|
||
<div id="pst-skip-link" class="skip-link d-print-none"><a href="#main-content">Skip to main content</a></div>
|
||
|
||
<div id="pst-scroll-pixel-helper"></div>
|
||
|
||
<button type="button" class="btn rounded-pill" id="pst-back-to-top">
|
||
<i class="fa-solid fa-arrow-up"></i>Back to top</button>
|
||
|
||
|
||
<input type="checkbox"
|
||
class="sidebar-toggle"
|
||
id="pst-primary-sidebar-checkbox"/>
|
||
<label class="overlay overlay-primary" for="pst-primary-sidebar-checkbox"></label>
|
||
|
||
<input type="checkbox"
|
||
class="sidebar-toggle"
|
||
id="pst-secondary-sidebar-checkbox"/>
|
||
<label class="overlay overlay-secondary" for="pst-secondary-sidebar-checkbox"></label>
|
||
|
||
<div class="search-button__wrapper">
|
||
<div class="search-button__overlay"></div>
|
||
<div class="search-button__search-container">
|
||
<form class="bd-search d-flex align-items-center"
|
||
action="../search.html"
|
||
method="get">
|
||
<i class="fa-solid fa-magnifying-glass"></i>
|
||
<input type="search"
|
||
class="form-control"
|
||
name="q"
|
||
id="search-input"
|
||
placeholder="Search..."
|
||
aria-label="Search..."
|
||
autocomplete="off"
|
||
autocorrect="off"
|
||
autocapitalize="off"
|
||
spellcheck="false"/>
|
||
<span class="search-button__kbd-shortcut"><kbd class="kbd-shortcut__modifier">Ctrl</kbd>+<kbd>K</kbd></span>
|
||
</form></div>
|
||
</div>
|
||
|
||
<div class="pst-async-banner-revealer d-none">
|
||
<aside id="bd-header-version-warning" class="d-none d-print-none" aria-label="Version warning"></aside>
|
||
</div>
|
||
|
||
|
||
<header class="bd-header navbar navbar-expand-lg bd-navbar d-print-none">
|
||
</header>
|
||
|
||
|
||
<div class="bd-container">
|
||
<div class="bd-container__inner bd-page-width">
|
||
|
||
|
||
|
||
<div class="bd-sidebar-primary bd-sidebar">
|
||
|
||
|
||
|
||
<div class="sidebar-header-items sidebar-primary__section">
|
||
|
||
|
||
|
||
|
||
</div>
|
||
|
||
<div class="sidebar-primary-items__start sidebar-primary__section">
|
||
<div class="sidebar-primary-item">
|
||
|
||
|
||
|
||
|
||
|
||
<a class="navbar-brand logo" href="../intro.html">
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<img src="../_static/logo-tinytorch.png" class="logo__image only-light" alt="Tiny🔥Torch - Home"/>
|
||
<script>document.write(`<img src="../_static/logo-tinytorch.png" class="logo__image only-dark" alt="Tiny🔥Torch - Home"/>`);</script>
|
||
|
||
|
||
</a></div>
|
||
<div class="sidebar-primary-item">
|
||
|
||
<script>
|
||
document.write(`
|
||
<button class="btn search-button-field search-button__button" title="Search" aria-label="Search" data-bs-placement="bottom" data-bs-toggle="tooltip">
|
||
<i class="fa-solid fa-magnifying-glass"></i>
|
||
<span class="search-button__default-text">Search</span>
|
||
<span class="search-button__kbd-shortcut"><kbd class="kbd-shortcut__modifier">Ctrl</kbd>+<kbd class="kbd-shortcut__modifier">K</kbd></span>
|
||
</button>
|
||
`);
|
||
</script></div>
|
||
<div class="sidebar-primary-item"><nav class="bd-links bd-docs-nav" aria-label="Main">
|
||
<div class="bd-toc-item navbar-nav active">
|
||
<p aria-level="2" class="caption" role="heading"><span class="caption-text">🚀 Getting Started</span></p>
|
||
<ul class="nav bd-sidenav">
|
||
<li class="toctree-l1"><a class="reference internal" href="../getting-started.html">Complete Guide</a></li>
|
||
</ul>
|
||
<p aria-level="2" class="caption" role="heading"><span class="caption-text">🏗 Foundation Tier (01-07)</span></p>
|
||
<ul class="current nav bd-sidenav">
|
||
<li class="toctree-l1"><a class="reference internal" href="../tiers/foundation.html">📖 Tier Overview</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="01_tensor_ABOUT.html">01. Tensor</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="02_activations_ABOUT.html">02. Activations</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="03_layers_ABOUT.html">03. Layers</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="04_losses_ABOUT.html">04. Losses</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="05_autograd_ABOUT.html">05. Autograd</a></li>
|
||
<li class="toctree-l1 current active"><a class="current reference internal" href="#">06. Optimizers</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="07_training_ABOUT.html">07. Training</a></li>
|
||
</ul>
|
||
<p aria-level="2" class="caption" role="heading"><span class="caption-text">🏛️ Architecture Tier (08-13)</span></p>
|
||
<ul class="nav bd-sidenav">
|
||
<li class="toctree-l1"><a class="reference internal" href="../tiers/architecture.html">📖 Tier Overview</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="08_dataloader_ABOUT.html">08. DataLoader</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="09_spatial_ABOUT.html">09. Convolutions</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="10_tokenization_ABOUT.html">10. Tokenization</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="11_embeddings_ABOUT.html">11. Embeddings</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="12_attention_ABOUT.html">12. Attention</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="13_transformers_ABOUT.html">13. Transformers</a></li>
|
||
</ul>
|
||
<p aria-level="2" class="caption" role="heading"><span class="caption-text">⏱️ Optimization Tier (14-19)</span></p>
|
||
<ul class="nav bd-sidenav">
|
||
<li class="toctree-l1"><a class="reference internal" href="../tiers/optimization.html">📖 Tier Overview</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="14_profiling_ABOUT.html">14. Profiling</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="15_quantization_ABOUT.html">15. Quantization</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="16_compression_ABOUT.html">16. Compression</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="17_memoization_ABOUT.html">17. Memoization</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="18_acceleration_ABOUT.html">18. Acceleration</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="19_benchmarking_ABOUT.html">19. Benchmarking</a></li>
|
||
</ul>
|
||
<p aria-level="2" class="caption" role="heading"><span class="caption-text">🏅 Capstone Competition</span></p>
|
||
<ul class="nav bd-sidenav">
|
||
<li class="toctree-l1"><a class="reference internal" href="../tiers/olympics.html">📖 Competition Overview</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="20_capstone_ABOUT.html">20. Torch Olympics</a></li>
|
||
</ul>
|
||
<p aria-level="2" class="caption" role="heading"><span class="caption-text">🧭 Course Orientation</span></p>
|
||
<ul class="nav bd-sidenav">
|
||
<li class="toctree-l1"><a class="reference internal" href="../chapters/00-introduction.html">Course Structure</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../prerequisites.html">Prerequisites & Resources</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../chapters/learning-journey.html">Learning Journey</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../chapters/milestones.html">Historical Milestones</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../faq.html">FAQ</a></li>
|
||
</ul>
|
||
<p aria-level="2" class="caption" role="heading"><span class="caption-text">🛠️ TITO CLI Reference</span></p>
|
||
<ul class="nav bd-sidenav">
|
||
<li class="toctree-l1"><a class="reference internal" href="../tito/overview.html">Command Overview</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../tito/modules.html">Module Workflow</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../tito/milestones.html">Milestone System</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../tito/data.html">Progress & Data</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../tito/troubleshooting.html">Troubleshooting</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../datasets.html">Datasets Guide</a></li>
|
||
</ul>
|
||
<p aria-level="2" class="caption" role="heading"><span class="caption-text">🤝 Community</span></p>
|
||
<ul class="nav bd-sidenav">
|
||
<li class="toctree-l1"><a class="reference internal" href="../community.html">Ecosystem</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../resources.html">Learning Resources</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../credits.html">Credits & Acknowledgments</a></li>
|
||
</ul>
|
||
|
||
</div>
|
||
</nav></div>
|
||
</div>
|
||
|
||
|
||
<div class="sidebar-primary-items__end sidebar-primary__section">
|
||
</div>
|
||
|
||
<div id="rtd-footer-container"></div>
|
||
|
||
|
||
</div>
|
||
|
||
<main id="main-content" class="bd-main" role="main">
|
||
|
||
|
||
|
||
<div class="sbt-scroll-pixel-helper"></div>
|
||
|
||
<div class="bd-content">
|
||
<div class="bd-article-container">
|
||
|
||
<div class="bd-header-article d-print-none">
|
||
<div class="header-article-items header-article__inner">
|
||
|
||
<div class="header-article-items__start">
|
||
|
||
<div class="header-article-item"><button class="sidebar-toggle primary-toggle btn btn-sm" title="Toggle primary sidebar" data-bs-placement="bottom" data-bs-toggle="tooltip">
|
||
<span class="fa-solid fa-bars"></span>
|
||
</button></div>
|
||
|
||
</div>
|
||
|
||
|
||
<div class="header-article-items__end">
|
||
|
||
<div class="header-article-item">
|
||
|
||
<div class="article-header-buttons">
|
||
|
||
|
||
|
||
|
||
|
||
<div class="dropdown dropdown-download-buttons">
|
||
<button class="btn dropdown-toggle" type="button" data-bs-toggle="dropdown" aria-expanded="false" aria-label="Download this page">
|
||
<i class="fas fa-download"></i>
|
||
</button>
|
||
<ul class="dropdown-menu">
|
||
|
||
|
||
|
||
<li><a href="../_sources/modules/06_optimizers_ABOUT.md" target="_blank"
|
||
class="btn btn-sm btn-download-source-button dropdown-item"
|
||
title="Download source file"
|
||
data-bs-placement="left" data-bs-toggle="tooltip"
|
||
>
|
||
|
||
|
||
<span class="btn__icon-container">
|
||
<i class="fas fa-file"></i>
|
||
</span>
|
||
<span class="btn__text-container">.md</span>
|
||
</a>
|
||
</li>
|
||
|
||
|
||
|
||
|
||
<li>
|
||
<button onclick="window.print()"
|
||
class="btn btn-sm btn-download-pdf-button dropdown-item"
|
||
title="Print to PDF"
|
||
data-bs-placement="left" data-bs-toggle="tooltip"
|
||
>
|
||
|
||
|
||
<span class="btn__icon-container">
|
||
<i class="fas fa-file-pdf"></i>
|
||
</span>
|
||
<span class="btn__text-container">.pdf</span>
|
||
</button>
|
||
</li>
|
||
|
||
</ul>
|
||
</div>
|
||
|
||
|
||
|
||
|
||
<button onclick="toggleFullScreen()"
|
||
class="btn btn-sm btn-fullscreen-button"
|
||
title="Fullscreen mode"
|
||
data-bs-placement="bottom" data-bs-toggle="tooltip"
|
||
>
|
||
|
||
|
||
<span class="btn__icon-container">
|
||
<i class="fas fa-expand"></i>
|
||
</span>
|
||
|
||
</button>
|
||
|
||
|
||
|
||
<script>
|
||
document.write(`
|
||
<button class="btn btn-sm nav-link pst-navbar-icon theme-switch-button" title="light/dark" aria-label="light/dark" data-bs-placement="bottom" data-bs-toggle="tooltip">
|
||
<i class="theme-switch fa-solid fa-sun fa-lg" data-mode="light"></i>
|
||
<i class="theme-switch fa-solid fa-moon fa-lg" data-mode="dark"></i>
|
||
<i class="theme-switch fa-solid fa-circle-half-stroke fa-lg" data-mode="auto"></i>
|
||
</button>
|
||
`);
|
||
</script>
|
||
|
||
|
||
<script>
|
||
document.write(`
|
||
<button class="btn btn-sm pst-navbar-icon search-button search-button__button" title="Search" aria-label="Search" data-bs-placement="bottom" data-bs-toggle="tooltip">
|
||
<i class="fa-solid fa-magnifying-glass fa-lg"></i>
|
||
</button>
|
||
`);
|
||
</script>
|
||
<button class="sidebar-toggle secondary-toggle btn btn-sm" title="Toggle secondary sidebar" data-bs-placement="bottom" data-bs-toggle="tooltip">
|
||
<span class="fa-solid fa-list"></span>
|
||
</button>
|
||
</div></div>
|
||
|
||
</div>
|
||
|
||
</div>
|
||
</div>
|
||
|
||
|
||
|
||
<div id="jb-print-docs-body" class="onlyprint">
|
||
<h1>06. Optimizers</h1>
|
||
<!-- Table of contents -->
|
||
<div id="print-main-content">
|
||
<div id="jb-print-toc">
|
||
|
||
<div>
|
||
<h2> Contents </h2>
|
||
</div>
|
||
<nav aria-label="Page">
|
||
<ul class="visible nav section-nav flex-column">
|
||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#overview">Overview</a></li>
|
||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#learning-objectives">Learning Objectives</a></li>
|
||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#build-use-reflect">Build → Use → Reflect</a></li>
|
||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#implementation-guide">Implementation Guide</a><ul class="nav section-nav flex-column">
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#core-optimization-algorithms">Core Optimization Algorithms</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#sgd-with-momentum-implementation">SGD with Momentum Implementation</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#adam-optimizer-implementation">Adam Optimizer Implementation</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#adamw-implementation-decoupled-weight-decay">AdamW Implementation (Decoupled Weight Decay)</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#complete-training-integration">Complete Training Integration</a></li>
|
||
</ul>
|
||
</li>
|
||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#getting-started">Getting Started</a><ul class="nav section-nav flex-column">
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#prerequisites">Prerequisites</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#development-workflow">Development Workflow</a></li>
|
||
</ul>
|
||
</li>
|
||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#testing">Testing</a><ul class="nav section-nav flex-column">
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#comprehensive-test-suite">Comprehensive Test Suite</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#test-coverage-areas">Test Coverage Areas</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#inline-testing-convergence-analysis">Inline Testing & Convergence Analysis</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#manual-testing-examples">Manual Testing Examples</a></li>
|
||
</ul>
|
||
</li>
|
||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#systems-thinking-questions">Systems Thinking Questions</a><ul class="nav section-nav flex-column">
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#real-world-applications">Real-World Applications</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#optimization-theory-foundations">Optimization Theory Foundations</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#performance-characteristics">Performance Characteristics</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#critical-thinking-memory-vs-convergence-trade-offs">Critical Thinking: Memory vs Convergence Trade-offs</a></li>
|
||
</ul>
|
||
</li>
|
||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#ready-to-build">Ready to Build?</a></li>
|
||
</ul>
|
||
</nav>
|
||
</div>
|
||
</div>
|
||
</div>
|
||
|
||
|
||
|
||
<div id="searchbox"></div>
|
||
<article class="bd-article">
|
||
|
||
<section id="optimizers">
|
||
<h1>06. Optimizers<a class="headerlink" href="#optimizers" title="Link to this heading">#</a></h1>
|
||
<p><strong>FOUNDATION TIER</strong> | Difficulty: ⭐⭐⭐⭐ (4/4) | Time: 6-8 hours</p>
|
||
<section id="overview">
|
||
<h2>Overview<a class="headerlink" href="#overview" title="Link to this heading">#</a></h2>
|
||
<p>Welcome to the Optimizers module! You’ll implement the learning algorithms that power every neural network—transforming gradients into intelligent parameter updates that enable models to learn from data. This module builds the optimization foundation used across all modern deep learning frameworks.</p>
|
||
</section>
|
||
<section id="learning-objectives">
|
||
<h2>Learning Objectives<a class="headerlink" href="#learning-objectives" title="Link to this heading">#</a></h2>
|
||
<p>By the end of this module, you will be able to:</p>
|
||
<ul class="simple">
|
||
<li><p><strong>Understand optimization dynamics</strong>: Master convergence behavior, learning rate sensitivity, and how gradients guide parameter updates in high-dimensional loss landscapes</p></li>
|
||
<li><p><strong>Implement core optimization algorithms</strong>: Build SGD, momentum, Adam, and AdamW optimizers from mathematical first principles</p></li>
|
||
<li><p><strong>Analyze memory-convergence trade-offs</strong>: Understand why Adam needs roughly 3x the memory of the parameters themselves (the weights plus per-parameter first- and second-moment buffers), yet often converges faster than SGD</p></li>
|
||
<li><p><strong>Master adaptive learning rates</strong>: See how Adam’s per-parameter learning rates handle different gradient scales automatically</p></li>
|
||
<li><p><strong>Connect to production frameworks</strong>: Understand how your implementations mirror PyTorch’s torch.optim.SGD and torch.optim.Adam design patterns</p></li>
|
||
</ul>
|
||
</section>
|
||
<section id="build-use-reflect">
|
||
<h2>Build → Use → Reflect<a class="headerlink" href="#build-use-reflect" title="Link to this heading">#</a></h2>
|
||
<p>This module follows TinyTorch’s <strong>Build → Use → Reflect</strong> framework:</p>
|
||
<ol class="arabic simple">
|
||
<li><p><strong>Build</strong>: Implement SGD with momentum, Adam optimizer with adaptive learning rates, and AdamW with decoupled weight decay from mathematical foundations</p></li>
|
||
<li><p><strong>Use</strong>: Apply optimization algorithms to train neural networks on real classification and regression tasks</p></li>
|
||
<li><p><strong>Reflect</strong>: Why does Adam converge faster initially but SGD often achieves better final test accuracy? What’s the memory cost of adaptive learning rates?</p></li>
|
||
</ol>
|
||
</section>
|
||
<section id="implementation-guide">
|
||
<h2>Implementation Guide<a class="headerlink" href="#implementation-guide" title="Link to this heading">#</a></h2>
|
||
<section id="core-optimization-algorithms">
|
||
<h3>Core Optimization Algorithms<a class="headerlink" href="#core-optimization-algorithms" title="Link to this heading">#</a></h3>
|
||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="c1"># Base optimizer class with parameter management</span>
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">Optimizer</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">"""Base class defining optimizer interface."""</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">params</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]):</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">params</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">params</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">step_count</span> <span class="o">=</span> <span class="mi">0</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">zero_grad</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">"""Clear gradients from all parameters."""</span>
|
||
<span class="k">for</span> <span class="n">param</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="p">:</span>
|
||
<span class="n">param</span><span class="o">.</span><span class="n">grad</span> <span class="o">=</span> <span class="kc">None</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">"""Update parameters - implemented by subclasses."""</span>
|
||
<span class="k">raise</span> <span class="ne">NotImplementedError</span>
|
||
|
||
<span class="c1"># SGD with momentum for accelerated convergence</span>
|
||
<span class="n">sgd</span> <span class="o">=</span> <span class="n">SGD</span><span class="p">(</span><span class="n">params</span><span class="o">=</span><span class="p">[</span><span class="n">w1</span><span class="p">,</span> <span class="n">w2</span><span class="p">,</span> <span class="n">bias</span><span class="p">],</span> <span class="n">lr</span><span class="o">=</span><span class="mf">0.01</span><span class="p">,</span> <span class="n">momentum</span><span class="o">=</span><span class="mf">0.9</span><span class="p">)</span>
|
||
<span class="n">sgd</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span> <span class="c1"># Clear previous gradients</span>
|
||
<span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span> <span class="c1"># Compute new gradients via autograd</span>
|
||
<span class="n">sgd</span><span class="o">.</span><span class="n">step</span><span class="p">()</span> <span class="c1"># Update parameters with momentum</span>
|
||
|
||
<span class="c1"># Adam optimizer with adaptive learning rates</span>
|
||
<span class="n">adam</span> <span class="o">=</span> <span class="n">Adam</span><span class="p">(</span><span class="n">params</span><span class="o">=</span><span class="p">[</span><span class="n">w1</span><span class="p">,</span> <span class="n">w2</span><span class="p">,</span> <span class="n">bias</span><span class="p">],</span> <span class="n">lr</span><span class="o">=</span><span class="mf">0.001</span><span class="p">,</span> <span class="n">betas</span><span class="o">=</span><span class="p">(</span><span class="mf">0.9</span><span class="p">,</span> <span class="mf">0.999</span><span class="p">))</span>
|
||
<span class="n">adam</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span>
|
||
<span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
|
||
<span class="n">adam</span><span class="o">.</span><span class="n">step</span><span class="p">()</span> <span class="c1"># Adaptive updates per parameter</span>
|
||
|
||
<span class="c1"># AdamW with decoupled weight decay</span>
|
||
<span class="n">adamw</span> <span class="o">=</span> <span class="n">AdamW</span><span class="p">(</span><span class="n">params</span><span class="o">=</span><span class="p">[</span><span class="n">w1</span><span class="p">,</span> <span class="n">w2</span><span class="p">,</span> <span class="n">bias</span><span class="p">],</span> <span class="n">lr</span><span class="o">=</span><span class="mf">0.001</span><span class="p">,</span> <span class="n">weight_decay</span><span class="o">=</span><span class="mf">0.01</span><span class="p">)</span>
|
||
<span class="n">adamw</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span>
|
||
<span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
|
||
<span class="n">adamw</span><span class="o">.</span><span class="n">step</span><span class="p">()</span> <span class="c1"># Adam + proper regularization</span>
|
||
</pre></div>
|
||
</div>
|
||
</section>
|
||
<section id="sgd-with-momentum-implementation">
|
||
<h3>SGD with Momentum Implementation<a class="headerlink" href="#sgd-with-momentum-implementation" title="Link to this heading">#</a></h3>
|
||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span><span class="w"> </span><span class="nc">SGD</span><span class="p">(</span><span class="n">Optimizer</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">"""Stochastic Gradient Descent with momentum.</span>
|
||
|
||
<span class="sd"> Momentum physics: velocity accumulates gradients over time,</span>
|
||
<span class="sd"> smoothing noisy updates and accelerating in consistent directions.</span>
|
||
<span class="sd"> """</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">params</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">lr</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.01</span><span class="p">,</span>
|
||
<span class="n">momentum</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span> <span class="n">weight_decay</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">):</span>
|
||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">params</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">lr</span> <span class="o">=</span> <span class="n">lr</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">momentum</span> <span class="o">=</span> <span class="n">momentum</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">weight_decay</span> <span class="o">=</span> <span class="n">weight_decay</span>
|
||
<span class="c1"># Initialize momentum buffers (created lazily)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">momentum_buffers</span> <span class="o">=</span> <span class="p">[</span><span class="kc">None</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="p">]</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">"""Update parameters using momentum: v = βv + ∇L, θ = θ - αv"""</span>
|
||
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">param</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="n">param</span><span class="o">.</span><span class="n">grad</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">continue</span>
|
||
|
||
<span class="n">grad</span> <span class="o">=</span> <span class="n">param</span><span class="o">.</span><span class="n">grad</span>
|
||
|
||
<span class="c1"># Apply weight decay</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight_decay</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="n">grad</span> <span class="o">=</span> <span class="n">grad</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight_decay</span> <span class="o">*</span> <span class="n">param</span><span class="o">.</span><span class="n">data</span>
|
||
|
||
<span class="c1"># Update momentum buffer</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum_buffers</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">momentum_buffers</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">param</span><span class="o">.</span><span class="n">data</span><span class="p">)</span>
|
||
|
||
<span class="c1"># Update velocity: v_t = β*v_{t-1} + grad</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">momentum_buffers</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">momentum</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum_buffers</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
|
||
<span class="o">+</span> <span class="n">grad</span><span class="p">)</span>
|
||
<span class="n">grad</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">momentum_buffers</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
|
||
|
||
<span class="c1"># Update parameter: θ_t = θ_{t-1} - α*v_t</span>
|
||
<span class="n">param</span><span class="o">.</span><span class="n">data</span> <span class="o">=</span> <span class="n">param</span><span class="o">.</span><span class="n">data</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">lr</span> <span class="o">*</span> <span class="n">grad</span>
|
||
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">step_count</span> <span class="o">+=</span> <span class="mi">1</span>
|
||
</pre></div>
|
||
</div>
|
||
</section>
|
||
<section id="adam-optimizer-implementation">
|
||
<h3>Adam Optimizer Implementation<a class="headerlink" href="#adam-optimizer-implementation" title="Link to this heading">#</a></h3>
|
||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span><span class="w"> </span><span class="nc">Adam</span><span class="p">(</span><span class="n">Optimizer</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">"""Adam optimizer with adaptive learning rates.</span>
|
||
|
||
<span class="sd"> Combines momentum (first moment) with RMSprop-style adaptive rates</span>
|
||
<span class="sd"> (second moment) for robust optimization across different scales.</span>
|
||
<span class="sd"> """</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">params</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">lr</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.001</span><span class="p">,</span>
|
||
<span class="n">betas</span><span class="p">:</span> <span class="nb">tuple</span> <span class="o">=</span> <span class="p">(</span><span class="mf">0.9</span><span class="p">,</span> <span class="mf">0.999</span><span class="p">),</span> <span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-8</span><span class="p">,</span>
|
||
<span class="n">weight_decay</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">):</span>
|
||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">params</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">lr</span> <span class="o">=</span> <span class="n">lr</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">beta1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">beta2</span> <span class="o">=</span> <span class="n">betas</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">eps</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">weight_decay</span> <span class="o">=</span> <span class="n">weight_decay</span>
|
||
|
||
<span class="c1"># Initialize moment estimates (m and v: params + 2 buffers = 3x parameter memory vs 1x for plain SGD)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">m_buffers</span> <span class="o">=</span> <span class="p">[</span><span class="kc">None</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="p">]</span> <span class="c1"># First moment</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">v_buffers</span> <span class="o">=</span> <span class="p">[</span><span class="kc">None</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="p">]</span> <span class="c1"># Second moment</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">"""Update parameters with adaptive learning rates"""</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">step_count</span> <span class="o">+=</span> <span class="mi">1</span>
|
||
|
||
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">param</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="n">param</span><span class="o">.</span><span class="n">grad</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">continue</span>
|
||
|
||
<span class="n">grad</span> <span class="o">=</span> <span class="n">param</span><span class="o">.</span><span class="n">grad</span>
|
||
|
||
<span class="c1"># Coupled L2 weight decay: folded into the gradient, so it gets rescaled by the adaptive denominator (the issue AdamW fixes)</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight_decay</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="n">grad</span> <span class="o">=</span> <span class="n">grad</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight_decay</span> <span class="o">*</span> <span class="n">param</span><span class="o">.</span><span class="n">data</span>
|
||
|
||
<span class="c1"># Initialize buffers if needed</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">m_buffers</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">m_buffers</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">param</span><span class="o">.</span><span class="n">data</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">v_buffers</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">param</span><span class="o">.</span><span class="n">data</span><span class="p">)</span>
|
||
|
||
<span class="c1"># Update biased first moment: m_t = β1*m_{t-1} + (1-β1)*grad</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">m_buffers</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">beta1</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">m_buffers</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
|
||
<span class="o">+</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">beta1</span><span class="p">)</span> <span class="o">*</span> <span class="n">grad</span><span class="p">)</span>
|
||
|
||
<span class="c1"># Update biased second moment: v_t = β2*v_{t-1} + (1-β2)*grad²</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">v_buffers</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">beta2</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">v_buffers</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
|
||
<span class="o">+</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">beta2</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="n">grad</span> <span class="o">**</span> <span class="mi">2</span><span class="p">))</span>
|
||
|
||
<span class="c1"># Bias correction (critical for early training steps)</span>
|
||
<span class="n">bias_correction1</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">beta1</span> <span class="o">**</span> <span class="bp">self</span><span class="o">.</span><span class="n">step_count</span>
|
||
<span class="n">bias_correction2</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">beta2</span> <span class="o">**</span> <span class="bp">self</span><span class="o">.</span><span class="n">step_count</span>
|
||
|
||
<span class="n">m_hat</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">m_buffers</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">/</span> <span class="n">bias_correction1</span>
|
||
<span class="n">v_hat</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">v_buffers</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">/</span> <span class="n">bias_correction2</span>
|
||
|
||
<span class="c1"># Adaptive parameter update: θ = θ - α*m_hat/(√v_hat + ε)</span>
|
||
<span class="n">param</span><span class="o">.</span><span class="n">data</span> <span class="o">=</span> <span class="p">(</span><span class="n">param</span><span class="o">.</span><span class="n">data</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">lr</span> <span class="o">*</span> <span class="n">m_hat</span>
|
||
<span class="o">/</span> <span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">v_hat</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span><span class="p">))</span>
|
||
</pre></div>
|
||
</div>
|
||
</section>
|
||
<section id="adamw-implementation-decoupled-weight-decay">
|
||
<h3>AdamW Implementation (Decoupled Weight Decay)<a class="headerlink" href="#adamw-implementation-decoupled-weight-decay" title="Link to this heading">#</a></h3>
|
||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span><span class="w"> </span><span class="nc">AdamW</span><span class="p">(</span><span class="n">Optimizer</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">"""AdamW optimizer with decoupled weight decay.</span>
|
||
|
||
<span class="sd"> AdamW fixes Adam's weight decay bug by applying regularization</span>
|
||
<span class="sd"> directly to parameters, separate from gradient-based updates.</span>
|
||
<span class="sd"> """</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">params</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">lr</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.001</span><span class="p">,</span>
|
||
<span class="n">betas</span><span class="p">:</span> <span class="nb">tuple</span> <span class="o">=</span> <span class="p">(</span><span class="mf">0.9</span><span class="p">,</span> <span class="mf">0.999</span><span class="p">),</span> <span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-8</span><span class="p">,</span>
|
||
<span class="n">weight_decay</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.01</span><span class="p">):</span>
|
||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">params</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">lr</span> <span class="o">=</span> <span class="n">lr</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">beta1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">beta2</span> <span class="o">=</span> <span class="n">betas</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">eps</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">weight_decay</span> <span class="o">=</span> <span class="n">weight_decay</span>
|
||
|
||
<span class="c1"># Initialize moment buffers (same as Adam)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">m_buffers</span> <span class="o">=</span> <span class="p">[</span><span class="kc">None</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="p">]</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">v_buffers</span> <span class="o">=</span> <span class="p">[</span><span class="kc">None</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="p">]</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">"""Perform AdamW update with decoupled weight decay"""</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">step_count</span> <span class="o">+=</span> <span class="mi">1</span>
|
||
|
||
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">param</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">params</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="n">param</span><span class="o">.</span><span class="n">grad</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">continue</span>
|
||
|
||
<span class="c1"># Get gradient (NOT modified by weight decay - key difference!)</span>
|
||
<span class="n">grad</span> <span class="o">=</span> <span class="n">param</span><span class="o">.</span><span class="n">grad</span>
|
||
|
||
<span class="c1"># Initialize buffers if needed</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">m_buffers</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">m_buffers</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">param</span><span class="o">.</span><span class="n">data</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">v_buffers</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">param</span><span class="o">.</span><span class="n">data</span><span class="p">)</span>
|
||
|
||
<span class="c1"># Update moments using pure gradients</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">m_buffers</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">beta1</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">m_buffers</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
|
||
<span class="o">+</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">beta1</span><span class="p">)</span> <span class="o">*</span> <span class="n">grad</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">v_buffers</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">beta2</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">v_buffers</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
|
||
<span class="o">+</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">beta2</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="n">grad</span> <span class="o">**</span> <span class="mi">2</span><span class="p">))</span>
|
||
|
||
<span class="c1"># Compute bias correction</span>
|
||
<span class="n">bias_correction1</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">beta1</span> <span class="o">**</span> <span class="bp">self</span><span class="o">.</span><span class="n">step_count</span>
|
||
<span class="n">bias_correction2</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">beta2</span> <span class="o">**</span> <span class="bp">self</span><span class="o">.</span><span class="n">step_count</span>
|
||
|
||
<span class="n">m_hat</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">m_buffers</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">/</span> <span class="n">bias_correction1</span>
|
||
<span class="n">v_hat</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">v_buffers</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">/</span> <span class="n">bias_correction2</span>
|
||
|
||
<span class="c1"># Apply gradient-based update</span>
|
||
<span class="n">param</span><span class="o">.</span><span class="n">data</span> <span class="o">=</span> <span class="p">(</span><span class="n">param</span><span class="o">.</span><span class="n">data</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">lr</span> <span class="o">*</span> <span class="n">m_hat</span>
|
||
<span class="o">/</span> <span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">v_hat</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span><span class="p">))</span>
|
||
|
||
<span class="c1"># Apply decoupled weight decay directly to the parameters (it never passes through the gradient or the m/v moment estimates)</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight_decay</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="n">param</span><span class="o">.</span><span class="n">data</span> <span class="o">=</span> <span class="n">param</span><span class="o">.</span><span class="n">data</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">lr</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight_decay</span><span class="p">)</span>
|
||
</pre></div>
|
||
</div>
|
||
</section>
|
||
<section id="complete-training-integration">
|
||
<h3>Complete Training Integration<a class="headerlink" href="#complete-training-integration" title="Link to this heading">#</a></h3>
|
||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="c1"># Modern training workflow combining all components</span>
|
||
<span class="kn">from</span><span class="w"> </span><span class="nn">tinytorch.core.tensor</span><span class="w"> </span><span class="kn">import</span> <span class="n">Tensor</span>
|
||
<span class="kn">from</span><span class="w"> </span><span class="nn">tinytorch.core.optimizers</span><span class="w"> </span><span class="kn">import</span> <span class="n">SGD</span><span class="p">,</span> <span class="n">Adam</span><span class="p">,</span> <span class="n">AdamW</span>
|
||
|
||
<span class="c1"># Model setup (Sequential, Linear, ReLU, and CrossEntropyLoss come from previous modules)</span>
|
||
<span class="n">model</span> <span class="o">=</span> <span class="n">Sequential</span><span class="p">([</span>
|
||
<span class="n">Linear</span><span class="p">(</span><span class="mi">784</span><span class="p">,</span> <span class="mi">128</span><span class="p">),</span> <span class="n">ReLU</span><span class="p">(),</span>
|
||
<span class="n">Linear</span><span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="mi">64</span><span class="p">),</span> <span class="n">ReLU</span><span class="p">(),</span>
|
||
<span class="n">Linear</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="mi">10</span><span class="p">)</span>
|
||
<span class="p">])</span>
|
||
|
||
<span class="c1"># Optimization setup</span>
|
||
<span class="n">optimizer</span> <span class="o">=</span> <span class="n">AdamW</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="mf">0.001</span><span class="p">,</span> <span class="n">weight_decay</span><span class="o">=</span><span class="mf">0.01</span><span class="p">)</span>
|
||
<span class="n">criterion</span> <span class="o">=</span> <span class="n">CrossEntropyLoss</span><span class="p">()</span>
|
||
|
||
<span class="c1"># Training loop</span>
|
||
<span class="k">for</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_epochs</span><span class="p">):</span>
|
||
<span class="n">epoch_loss</span> <span class="o">=</span> <span class="mf">0.0</span>
|
||
|
||
<span class="k">for</span> <span class="n">batch_inputs</span><span class="p">,</span> <span class="n">batch_targets</span> <span class="ow">in</span> <span class="n">dataloader</span><span class="p">:</span>
|
||
<span class="c1"># Forward pass</span>
|
||
<span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">batch_inputs</span><span class="p">)</span>
|
||
<span class="n">loss</span> <span class="o">=</span> <span class="n">criterion</span><span class="p">(</span><span class="n">predictions</span><span class="p">,</span> <span class="n">batch_targets</span><span class="p">)</span>
|
||
|
||
<span class="c1"># Backward pass and optimization</span>
|
||
<span class="n">optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span> <span class="c1"># Clear old gradients</span>
|
||
<span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span> <span class="c1"># Compute new gradients</span>
|
||
<span class="n">optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span> <span class="c1"># Update parameters</span>
|
||
|
||
<span class="n">epoch_loss</span> <span class="o">+=</span> <span class="n">loss</span><span class="o">.</span><span class="n">data</span>
|
||
|
||
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Epoch </span><span class="si">{</span><span class="n">epoch</span><span class="si">}</span><span class="s2">: Loss = </span><span class="si">{</span><span class="n">epoch_loss</span><span class="si">:</span><span class="s2">.4f</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||
</pre></div>
|
||
</div>
|
||
</section>
|
||
</section>
|
||
<section id="getting-started">
|
||
<h2>Getting Started<a class="headerlink" href="#getting-started" title="Link to this heading">#</a></h2>
|
||
<section id="prerequisites">
|
||
<h3>Prerequisites<a class="headerlink" href="#prerequisites" title="Link to this heading">#</a></h3>
|
||
<p>Activate the TinyTorch environment and verify that the prerequisite modules pass their tests:</p>
|
||
<div class="highlight-bash notranslate"><div class="highlight"><pre><span></span><span class="c1"># Activate TinyTorch environment</span>
|
||
<span class="nb">source</span><span class="w"> </span>scripts/activate-tinytorch
|
||
|
||
<span class="c1"># Verify prerequisite modules</span>
|
||
tito<span class="w"> </span><span class="nb">test</span><span class="w"> </span>tensor
|
||
tito<span class="w"> </span><span class="nb">test</span><span class="w"> </span>autograd
|
||
</pre></div>
|
||
</div>
|
||
<p><strong>Required Background:</strong></p>
|
||
<ul class="simple">
|
||
<li><p><strong>Tensor Operations</strong>: Understanding parameter storage and update mechanics</p></li>
|
||
<li><p><strong>Automatic Differentiation</strong>: Gradients computed via backpropagation</p></li>
|
||
<li><p><strong>Calculus</strong>: Derivatives, gradient descent, chain rule</p></li>
|
||
<li><p><strong>Linear Algebra</strong>: Vector operations, element-wise operations</p></li>
|
||
</ul>
|
||
</section>
|
||
<section id="development-workflow">
|
||
<h3>Development Workflow<a class="headerlink" href="#development-workflow" title="Link to this heading">#</a></h3>
|
||
<ol class="arabic simple">
|
||
<li><p><strong>Open the development file</strong>: <code class="docutils literal notranslate"><span class="pre">modules/06_optimizers/optimizers_dev.ipynb</span></code></p></li>
|
||
<li><p><strong>Implement Optimizer base class</strong>: Start with parameter management and zero_grad interface</p></li>
|
||
<li><p><strong>Build SGD with momentum</strong>: Add velocity accumulation for smoother convergence</p></li>
|
||
<li><p><strong>Create Adam optimizer</strong>: Implement adaptive learning rates with moment estimation and bias correction</p></li>
|
||
<li><p><strong>Add AdamW optimizer</strong>: Build decoupled weight decay for proper regularization</p></li>
|
||
<li><p><strong>Export and verify</strong>: <code class="docutils literal notranslate"><span class="pre">tito</span> <span class="pre">module</span> <span class="pre">complete</span> <span class="pre">06</span> <span class="pre">&&</span> <span class="pre">tito</span> <span class="pre">test</span> <span class="pre">optimizers</span></code></p></li>
|
||
</ol>
|
||
<p><strong>Development Tips:</strong></p>
|
||
<ul class="simple">
|
||
<li><p>Test each optimizer on simple quadratic functions (f(x) = x²), where you can check convergence against the known analytical minimum (x = 0)</p></li>
|
||
<li><p>Compare convergence speed between SGD and Adam on the same problem</p></li>
|
||
<li><p>Visualize loss curves to understand optimization dynamics</p></li>
|
||
<li><p>Check momentum/moment buffers are properly initialized and updated</p></li>
|
||
<li><p>Compare Adam vs AdamW to see the effect of decoupled weight decay</p></li>
|
||
</ul>
|
||
</section>
|
||
</section>
|
||
<section id="testing">
|
||
<h2>Testing<a class="headerlink" href="#testing" title="Link to this heading">#</a></h2>
|
||
<section id="comprehensive-test-suite">
|
||
<h3>Comprehensive Test Suite<a class="headerlink" href="#comprehensive-test-suite" title="Link to this heading">#</a></h3>
|
||
<p>Run the full test suite to verify optimization algorithm correctness:</p>
|
||
<div class="highlight-bash notranslate"><div class="highlight"><pre><span></span><span class="c1"># TinyTorch CLI (recommended)</span>
|
||
tito<span class="w"> </span><span class="nb">test</span><span class="w"> </span>optimizers
|
||
|
||
<span class="c1"># Direct pytest execution</span>
|
||
python<span class="w"> </span>-m<span class="w"> </span>pytest<span class="w"> </span>tests/<span class="w"> </span>-k<span class="w"> </span>optimizers<span class="w"> </span>-v
|
||
|
||
<span class="c1"># Test specific optimizer</span>
|
||
python<span class="w"> </span>-m<span class="w"> </span>pytest<span class="w"> </span>tests/test_optimizers.py::test_adam_convergence<span class="w"> </span>-v
|
||
</pre></div>
|
||
</div>
|
||
</section>
|
||
<section id="test-coverage-areas">
|
||
<h3>Test Coverage Areas<a class="headerlink" href="#test-coverage-areas" title="Link to this heading">#</a></h3>
|
||
<ul class="simple">
|
||
<li><p><strong>Algorithm Implementation</strong>: Verify Optimizer base, SGD, Adam, and AdamW compute mathematically correct parameter updates</p></li>
|
||
<li><p><strong>Mathematical Correctness</strong>: Test against analytical solutions for convex optimization problems (quadratic functions)</p></li>
|
||
<li><p><strong>State Management</strong>: Ensure proper momentum and moment estimation tracking across training steps</p></li>
|
||
<li><p><strong>Memory Efficiency</strong>: Verify buffer initialization and memory usage patterns</p></li>
|
||
<li><p><strong>Training Integration</strong>: Test optimizers in complete neural network training workflows with real data</p></li>
|
||
</ul>
|
||
</section>
|
||
<section id="inline-testing-convergence-analysis">
|
||
<h3>Inline Testing & Convergence Analysis<a class="headerlink" href="#inline-testing-convergence-analysis" title="Link to this heading">#</a></h3>
|
||
<p>The module includes comprehensive mathematical validation and convergence visualization:</p>
|
||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="c1"># Example inline test output</span>
|
||
<span class="err">🔬</span> <span class="n">Unit</span> <span class="n">Test</span><span class="p">:</span> <span class="n">Base</span> <span class="n">Optimizer</span><span class="o">...</span>
|
||
<span class="err">✅</span> <span class="n">Parameter</span> <span class="n">validation</span> <span class="n">working</span> <span class="n">correctly</span>
|
||
<span class="err">✅</span> <span class="n">zero_grad</span> <span class="n">clears</span> <span class="nb">all</span> <span class="n">gradients</span> <span class="n">properly</span>
|
||
<span class="err">✅</span> <span class="n">Error</span> <span class="n">handling</span> <span class="k">for</span> <span class="n">non</span><span class="o">-</span><span class="n">gradient</span> <span class="n">parameters</span>
|
||
<span class="err">📈</span> <span class="n">Progress</span><span class="p">:</span> <span class="n">Base</span> <span class="n">Optimizer</span> <span class="err">✓</span>
|
||
|
||
<span class="c1"># SGD with momentum validation</span>
|
||
<span class="err">🔬</span> <span class="n">Unit</span> <span class="n">Test</span><span class="p">:</span> <span class="n">SGD</span> <span class="k">with</span> <span class="n">momentum</span><span class="o">...</span>
|
||
<span class="err">✅</span> <span class="n">Parameter</span> <span class="n">updates</span> <span class="n">follow</span> <span class="n">momentum</span> <span class="n">equation</span> <span class="n">v_t</span> <span class="o">=</span> <span class="n">βv_</span><span class="p">{</span><span class="n">t</span><span class="o">-</span><span class="mi">1</span><span class="p">}</span> <span class="o">+</span> <span class="err">∇</span><span class="n">L</span>
|
||
<span class="err">✅</span> <span class="n">Velocity</span> <span class="n">accumulation</span> <span class="n">working</span> <span class="n">correctly</span>
|
||
<span class="err">✅</span> <span class="n">Weight</span> <span class="n">decay</span> <span class="n">applied</span> <span class="n">properly</span>
|
||
<span class="err">✅</span> <span class="n">Momentum</span> <span class="n">accelerates</span> <span class="n">convergence</span> <span class="n">vs</span> <span class="n">vanilla</span> <span class="n">SGD</span>
|
||
<span class="err">📈</span> <span class="n">Progress</span><span class="p">:</span> <span class="n">SGD</span> <span class="k">with</span> <span class="n">Momentum</span> <span class="err">✓</span>
|
||
|
||
<span class="c1"># Adam optimizer validation</span>
|
||
<span class="err">🔬</span> <span class="n">Unit</span> <span class="n">Test</span><span class="p">:</span> <span class="n">Adam</span> <span class="n">optimizer</span><span class="o">...</span>
|
||
<span class="err">✅</span> <span class="n">First</span> <span class="n">moment</span> <span class="n">estimation</span> <span class="p">(</span><span class="n">m_t</span><span class="p">)</span> <span class="n">computed</span> <span class="n">correctly</span>
|
||
<span class="err">✅</span> <span class="n">Second</span> <span class="n">moment</span> <span class="n">estimation</span> <span class="p">(</span><span class="n">v_t</span><span class="p">)</span> <span class="n">computed</span> <span class="n">correctly</span>
|
||
<span class="err">✅</span> <span class="n">Bias</span> <span class="n">correction</span> <span class="n">applied</span> <span class="n">properly</span> <span class="p">(</span><span class="n">critical</span> <span class="k">for</span> <span class="n">early</span> <span class="n">steps</span><span class="p">)</span>
|
||
<span class="err">✅</span> <span class="n">Adaptive</span> <span class="n">learning</span> <span class="n">rates</span> <span class="n">working</span> <span class="n">per</span> <span class="n">parameter</span>
|
||
<span class="err">✅</span> <span class="n">Convergence</span> <span class="n">faster</span> <span class="n">than</span> <span class="n">SGD</span> <span class="n">on</span> <span class="n">ill</span><span class="o">-</span><span class="n">conditioned</span> <span class="n">problem</span>
|
||
<span class="err">📈</span> <span class="n">Progress</span><span class="p">:</span> <span class="n">Adam</span> <span class="n">Optimizer</span> <span class="err">✓</span>
|
||
|
||
<span class="c1"># AdamW decoupled weight decay validation</span>
|
||
<span class="err">🔬</span> <span class="n">Unit</span> <span class="n">Test</span><span class="p">:</span> <span class="n">AdamW</span> <span class="n">optimizer</span><span class="o">...</span>
|
||
<span class="err">✅</span> <span class="n">Weight</span> <span class="n">decay</span> <span class="n">decoupled</span> <span class="kn">from</span><span class="w"> </span><span class="nn">gradient</span> <span class="n">updates</span>
|
||
<span class="err">✅</span> <span class="n">Results</span> <span class="n">differ</span> <span class="kn">from</span><span class="w"> </span><span class="nn">Adam</span> <span class="p">(</span><span class="n">proving</span> <span class="n">proper</span> <span class="n">implementation</span><span class="p">)</span>
|
||
<span class="err">✅</span> <span class="n">Regularization</span> <span class="n">consistent</span> <span class="n">across</span> <span class="n">gradient</span> <span class="n">scales</span>
|
||
<span class="err">✅</span> <span class="n">With</span> <span class="n">zero</span> <span class="n">weight</span> <span class="n">decay</span><span class="p">,</span> <span class="n">matches</span> <span class="n">Adam</span> <span class="n">behavior</span>
|
||
<span class="err">📈</span> <span class="n">Progress</span><span class="p">:</span> <span class="n">AdamW</span> <span class="n">Optimizer</span> <span class="err">✓</span>
|
||
</pre></div>
|
||
</div>
|
||
</section>
|
||
<section id="manual-testing-examples">
|
||
<h3>Manual Testing Examples<a class="headerlink" href="#manual-testing-examples" title="Link to this heading">#</a></h3>
|
||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">from</span><span class="w"> </span><span class="nn">tinytorch.core.optimizers</span><span class="w"> </span><span class="kn">import</span> <span class="n">SGD</span><span class="p">,</span> <span class="n">Adam</span><span class="p">,</span> <span class="n">AdamW</span>
|
||
<span class="kn">from</span><span class="w"> </span><span class="nn">tinytorch.core.tensor</span><span class="w"> </span><span class="kn">import</span> <span class="n">Tensor</span>
|
||
|
||
<span class="c1"># Test 1: SGD convergence on simple quadratic</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="s2">"Test 1: SGD on f(x) = x²"</span><span class="p">)</span>
|
||
<span class="n">x</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">([</span><span class="mf">10.0</span><span class="p">],</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||
<span class="n">sgd</span> <span class="o">=</span> <span class="n">SGD</span><span class="p">([</span><span class="n">x</span><span class="p">],</span> <span class="n">lr</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span> <span class="n">momentum</span><span class="o">=</span><span class="mf">0.9</span><span class="p">)</span>
|
||
|
||
<span class="k">for</span> <span class="n">step</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">100</span><span class="p">):</span>
|
||
<span class="n">sgd</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span>
|
||
<span class="n">loss</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span> <span class="c1"># Minimize f(x) = x², minimum at x=0</span>
|
||
<span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
|
||
<span class="n">sgd</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
|
||
|
||
<span class="k">if</span> <span class="n">step</span> <span class="o">%</span> <span class="mi">10</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Step </span><span class="si">{</span><span class="n">step</span><span class="si">}</span><span class="s2">: x = </span><span class="si">{</span><span class="n">x</span><span class="o">.</span><span class="n">data</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">, loss = </span><span class="si">{</span><span class="n">loss</span><span class="o">.</span><span class="n">data</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||
<span class="c1"># Expected: x should converge to 0</span>
|
||
|
||
<span class="c1"># Test 2: Adam on multidimensional optimization</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="s2">"</span><span class="se">\n</span><span class="s2">Test 2: Adam on f(x,y) = x² + y²"</span><span class="p">)</span>
|
||
<span class="n">params</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">([</span><span class="mf">5.0</span><span class="p">,</span> <span class="o">-</span><span class="mf">3.0</span><span class="p">],</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||
<span class="n">adam</span> <span class="o">=</span> <span class="n">Adam</span><span class="p">([</span><span class="n">params</span><span class="p">],</span> <span class="n">lr</span><span class="o">=</span><span class="mf">0.1</span><span class="p">)</span>
|
||
|
||
<span class="k">for</span> <span class="n">step</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">50</span><span class="p">):</span>
|
||
<span class="n">adam</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span>
|
||
<span class="n">loss</span> <span class="o">=</span> <span class="p">(</span><span class="n">params</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span> <span class="c1"># Minimize ||x||²</span>
|
||
<span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
|
||
<span class="n">adam</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
|
||
|
||
<span class="k">if</span> <span class="n">step</span> <span class="o">%</span> <span class="mi">10</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Step </span><span class="si">{</span><span class="n">step</span><span class="si">}</span><span class="s2">: params = </span><span class="si">{</span><span class="n">params</span><span class="o">.</span><span class="n">data</span><span class="si">}</span><span class="s2">, loss = </span><span class="si">{</span><span class="n">loss</span><span class="o">.</span><span class="n">data</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||
<span class="c1"># Expected: Both parameters converge to 0</span>
|
||
|
||
<span class="c1"># Test 3: Compare SGD vs Adam vs AdamW convergence</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="s2">"</span><span class="se">\n</span><span class="s2">Test 3: Optimizer comparison"</span><span class="p">)</span>
|
||
<span class="n">x_sgd</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">([</span><span class="mf">10.0</span><span class="p">],</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||
<span class="n">x_adam</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">([</span><span class="mf">10.0</span><span class="p">],</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||
<span class="n">x_adamw</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">([</span><span class="mf">10.0</span><span class="p">],</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||
|
||
<span class="n">sgd</span> <span class="o">=</span> <span class="n">SGD</span><span class="p">([</span><span class="n">x_sgd</span><span class="p">],</span> <span class="n">lr</span><span class="o">=</span><span class="mf">0.01</span><span class="p">,</span> <span class="n">momentum</span><span class="o">=</span><span class="mf">0.9</span><span class="p">)</span>
|
||
<span class="n">adam</span> <span class="o">=</span> <span class="n">Adam</span><span class="p">([</span><span class="n">x_adam</span><span class="p">],</span> <span class="n">lr</span><span class="o">=</span><span class="mf">0.01</span><span class="p">)</span>
|
||
<span class="n">adamw</span> <span class="o">=</span> <span class="n">AdamW</span><span class="p">([</span><span class="n">x_adamw</span><span class="p">],</span> <span class="n">lr</span><span class="o">=</span><span class="mf">0.01</span><span class="p">,</span> <span class="n">weight_decay</span><span class="o">=</span><span class="mf">0.01</span><span class="p">)</span>
|
||
|
||
<span class="k">for</span> <span class="n">step</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">20</span><span class="p">):</span>
|
||
<span class="c1"># SGD update</span>
|
||
<span class="n">sgd</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span>
|
||
<span class="n">loss_sgd</span> <span class="o">=</span> <span class="p">(</span><span class="n">x_sgd</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span>
|
||
<span class="n">loss_sgd</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
|
||
<span class="n">sgd</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
|
||
|
||
<span class="c1"># Adam update</span>
|
||
<span class="n">adam</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span>
|
||
<span class="n">loss_adam</span> <span class="o">=</span> <span class="p">(</span><span class="n">x_adam</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span>
|
||
<span class="n">loss_adam</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
|
||
<span class="n">adam</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
|
||
|
||
<span class="c1"># AdamW update</span>
|
||
<span class="n">adamw</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span>
|
||
<span class="n">loss_adamw</span> <span class="o">=</span> <span class="p">(</span><span class="n">x_adamw</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span>
|
||
<span class="n">loss_adamw</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
|
||
<span class="n">adamw</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
|
||
|
||
<span class="k">if</span> <span class="n">step</span> <span class="o">%</span> <span class="mi">5</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Step </span><span class="si">{</span><span class="n">step</span><span class="si">}</span><span class="s2">: SGD=</span><span class="si">{</span><span class="n">x_sgd</span><span class="o">.</span><span class="n">data</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">, Adam=</span><span class="si">{</span><span class="n">x_adam</span><span class="o">.</span><span class="n">data</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">, AdamW=</span><span class="si">{</span><span class="n">x_adamw</span><span class="o">.</span><span class="n">data</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||
<span class="c1"># Expected: Adam/AdamW converge faster initially</span>
|
||
</pre></div>
|
||
</div>
|
||
</section>
|
||
</section>
|
||
<section id="systems-thinking-questions">
|
||
<h2>Systems Thinking Questions<a class="headerlink" href="#systems-thinking-questions" title="Link to this heading">#</a></h2>
|
||
<section id="real-world-applications">
|
||
<h3>Real-World Applications<a class="headerlink" href="#real-world-applications" title="Link to this heading">#</a></h3>
|
||
<ul class="simple">
|
||
<li><p><strong>Large Language Models</strong>: GPT and BERT training relies on AdamW optimizer for stable convergence across billions of parameters with varying gradient scales and proper regularization</p></li>
|
||
<li><p><strong>Computer Vision</strong>: CNNs such as ResNet are typically trained with SGD with momentum for best final test accuracy despite slower initial convergence, while Vision Transformers are usually trained with Adam or AdamW</p></li>
|
||
<li><p><strong>Recommendation Systems</strong>: Online learning systems use adaptive optimizers like Adam for continuous model updates with non-stationary data distributions</p></li>
|
||
<li><p><strong>Reinforcement Learning</strong>: Policy gradient methods depend heavily on careful optimizer choice and learning rate tuning due to high variance gradients</p></li>
|
||
</ul>
|
||
</section>
|
||
<section id="optimization-theory-foundations">
|
||
<h3>Optimization Theory Foundations<a class="headerlink" href="#optimization-theory-foundations" title="Link to this heading">#</a></h3>
|
||
<ul class="simple">
|
||
<li><p><strong>Gradient Descent</strong>: Update rule θ_{t+1} = θ_t - α∇L(θ_t) where α is learning rate controlling step size in steepest descent direction</p></li>
|
||
<li><p><strong>Momentum</strong>: Velocity accumulation v_{t+1} = βv_t + ∇L(θ_t), then θ_{t+1} = θ_t - αv_{t+1} smooths noisy gradients and accelerates convergence</p></li>
|
||
<li><p><strong>Adam</strong>: Combines momentum (first moment m_t) with adaptive learning rates (second moment v_t), includes bias correction for early training steps</p></li>
|
||
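<li><p><strong>AdamW</strong>: Decouples weight decay from the gradient-based update. Instead of adding the decay term to the gradient, where Adam’s adaptive scaling would distort it, AdamW shrinks the parameters directly in proportion to the weight-decay coefficient. This fixes Adam’s weight-decay regularization problem</p></li>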
|
||
</ul>
|
||
</section>
|
||
<section id="performance-characteristics">
|
||
<h3>Performance Characteristics<a class="headerlink" href="#performance-characteristics" title="Link to this heading">#</a></h3>
|
||
<ul class="simple">
|
||
<li><p><strong>SGD Memory</strong>: about 2n stored values for n parameters (parameters + momentum buffers, not counting gradients), making it the most memory-efficient optimizer with momentum</p></li>
|
||
<li><p><strong>Adam Memory</strong>: about 3n stored values due to first and second moment buffers (params + m_buffers + v_buffers), i.e. 1.5x the cost of SGD with momentum</p></li>
|
||
<li><p><strong>Convergence Speed</strong>: Adam often converges faster initially due to adaptive rates, especially with sparse gradients or varying scales</p></li>
|
||
<li><p><strong>Final Performance</strong>: SGD with momentum often achieves better test accuracy on computer vision tasks despite slower convergence</p></li>
|
||
<li><p><strong>Learning Rate Sensitivity</strong>: Adam/AdamW are more robust to learning rate choice than vanilla SGD, making them popular for transformer training</p></li>
|
||
<li><p><strong>Computational Cost</strong>: Adam requires ~1.5x more computation per step (moment updates + bias correction + sqrt operations) than SGD</p></li>
|
||
</ul>
|
||
</section>
|
||
<section id="critical-thinking-memory-vs-convergence-trade-offs">
|
||
<h3>Critical Thinking: Memory vs Convergence Trade-offs<a class="headerlink" href="#critical-thinking-memory-vs-convergence-trade-offs" title="Link to this heading">#</a></h3>
|
||
<p><strong>Reflection Question</strong>: Why does Adam use 3x the memory of parameter-only storage (and 1.5x that of SGD with momentum), and when is this trade-off worth it?</p>
|
||
<p><strong>Key Insights:</strong></p>
|
||
<ul class="simple">
|
||
<li><p><strong>Memory Cost</strong>: Adam stores parameter data + first moment (momentum) + second moment (variance) for every parameter</p></li>
|
||
<li><p><strong>Adaptive Benefit</strong>: Per-parameter learning rates handle different gradient scales automatically</p></li>
|
||
<li><p><strong>Use Case</strong>: Transformers benefit from Adam (varying embedding vs attention scales), CNNs often prefer SGD (more uniform scales)</p></li>
|
||
<li><p><strong>Production Decision</strong>: Memory-constrained systems (mobile, edge devices) may prefer SGD despite slower convergence</p></li>
|
||
<li><p><strong>Training Time</strong>: Faster convergence can save GPU hours, offsetting memory cost in cloud training scenarios</p></li>
|
||
</ul>
|
||
<p><strong>Reflection Question</strong>: Why does SGD with momentum often achieve better test accuracy than Adam on vision tasks, despite slower training?</p>
|
||
<p><strong>Key Insights:</strong></p>
|
||
<ul class="simple">
|
||
<li><p><strong>Generalization</strong>: SGD’s gradient noise tends to steer it toward flatter minima, which often generalize better to test data</p></li>
|
||
<li><p><strong>Overfitting</strong>: Adam’s fast convergence may lead to sharper minima with worse generalization</p></li>
|
||
<li><p><strong>Learning Rate Schedule</strong>: Careful learning rate decay with SGD achieves better final performance</p></li>
|
||
<li><p><strong>Task Dependency</strong>: Effect is strongest on CNNs, less pronounced on transformers</p></li>
|
||
<li><p><strong>Modern Practice</strong>: AdamW with proper weight decay often bridges this gap</p></li>
|
||
</ul>
|
||
<p><strong>Reflection Question</strong>: How does AdamW’s decoupled weight decay fix Adam’s regularization bug?</p>
|
||
<p><strong>Key Insights:</strong></p>
|
||
<ul class="simple">
|
||
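<li><p><strong>Adam Bug</strong>: Adam with L2-style weight decay adds the decay term to the gradient, so it is divided by the adaptive denominator like any other gradient component. Parameters with large historical gradients are therefore regularized less, making regularization strength inconsistent</p></li>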
|
||
<li><p><strong>AdamW Fix</strong>: Applies weight decay directly to the parameters, separately from the adaptive gradient step, decoupling optimization from regularization</p></li>
|
||
<li><p><strong>Consistency</strong>: Weight decay effect is now uniform across parameters regardless of gradient magnitudes</p></li>
|
||
<li><p><strong>Production Impact</strong>: AdamW is now preferred over Adam in most modern training pipelines (BERT, GPT-3, etc.)</p></li>
|
||
</ul>
|
||
</section>
|
||
</section>
|
||
<section id="ready-to-build">
|
||
<h2>Ready to Build?<a class="headerlink" href="#ready-to-build" title="Link to this heading">#</a></h2>
|
||
<p>You’re about to implement the algorithms that enable all of modern deep learning! Every neural network—from the image classifiers in your phone to GPT-4—depends on the optimization algorithms you’re building in this module.</p>
|
||
<p>Understanding these algorithms from first principles will transform how you think about training. When you implement momentum and see how velocity accumulation smooths noisy gradients, when you build Adam’s adaptive learning rates and understand why they help with varying parameter scales, and when you create AdamW and see how decoupled weight decay fixes Adam’s weight-decay problem, you’ll develop deep intuition for why some training configurations work and others fail.</p>
|
||
<p>Take your time with the mathematics. Test your optimizers on simple quadratic functions where you can verify convergence analytically. Compare SGD vs Adam vs AdamW on the same problem to see their different behaviors. Visualize loss curves to understand optimization dynamics. Monitor memory usage to see the trade-offs. This hands-on experience will make you a better practitioner who can debug training failures, tune hyperparameters effectively, and make informed decisions about optimizer choice in production systems. Enjoy building the intelligence behind intelligent systems!</p>
|
||
<p>Choose your preferred way to engage with this module:</p>
|
||
<div class="sd-container-fluid sd-sphinx-override sd-mb-4 docutils">
|
||
<div class="sd-row sd-row-cols-1 sd-row-cols-xs-1 sd-row-cols-sm-2 sd-row-cols-md-3 sd-row-cols-lg-3 docutils">
|
||
<div class="sd-col sd-d-flex-row docutils">
|
||
<div class="sd-card sd-sphinx-override sd-w-100 sd-shadow-sm sd-card-hover docutils">
|
||
<div class="sd-card-body docutils">
|
||
<div class="sd-card-title sd-font-weight-bold docutils">
|
||
🚀 Launch Binder</div>
|
||
<p class="sd-card-text">Run this module interactively in your browser. No installation required!</p>
|
||
</div>
|
||
<a class="sd-stretched-link sd-hide-link-text reference external" href="https://mybinder.org/v2/gh/mlsysbook/TinyTorch/main?filepath=modules/06_optimizers/optimizers_dev.ipynb"><span>https://mybinder.org/v2/gh/mlsysbook/TinyTorch/main?filepath=modules/06_optimizers/optimizers_dev.ipynb</span></a></div>
|
||
</div>
|
||
<div class="sd-col sd-d-flex-row docutils">
|
||
<div class="sd-card sd-sphinx-override sd-w-100 sd-shadow-sm sd-card-hover docutils">
|
||
<div class="sd-card-body docutils">
|
||
<div class="sd-card-title sd-font-weight-bold docutils">
|
||
⚡ Open in Colab</div>
|
||
<p class="sd-card-text">Use Google Colab for GPU access and cloud compute power.</p>
|
||
</div>
|
||
<a class="sd-stretched-link sd-hide-link-text reference external" href="https://colab.research.google.com/github/mlsysbook/TinyTorch/blob/main/modules/06_optimizers/optimizers_dev.ipynb"><span>https://colab.research.google.com/github/mlsysbook/TinyTorch/blob/main/modules/06_optimizers/optimizers_dev.ipynb</span></a></div>
|
||
</div>
|
||
<div class="sd-col sd-d-flex-row docutils">
|
||
<div class="sd-card sd-sphinx-override sd-w-100 sd-shadow-sm sd-card-hover docutils">
|
||
<div class="sd-card-body docutils">
|
||
<div class="sd-card-title sd-font-weight-bold docutils">
|
||
📖 View Source</div>
|
||
<p class="sd-card-text">Browse the Jupyter notebook and understand the implementation.</p>
|
||
</div>
|
||
<a class="sd-stretched-link sd-hide-link-text reference external" href="https://github.com/mlsysbook/TinyTorch/blob/main/modules/06_optimizers/optimizers_dev.ipynb"><span>https://github.com/mlsysbook/TinyTorch/blob/main/modules/06_optimizers/optimizers_dev.ipynb</span></a></div>
|
||
</div>
|
||
</div>
|
||
</div>
|
||
<div class="tip admonition">
|
||
<p class="admonition-title">💾 Save Your Progress</p>
|
||
<p><strong>Binder sessions are temporary!</strong> Download your completed notebook when done, or switch to local development for persistent work.</p>
|
||
</div>
|
||
<hr class="docutils" />
|
||
<div class="prev-next-area">
|
||
<a class="left-prev" href="../modules/05_autograd_ABOUT.html" title="previous page">← Previous Module</a>
|
||
<a class="right-next" href="../modules/07_training_ABOUT.html" title="next page">Next Module →</a>
|
||
</div>
|
||
</section>
|
||
</section>
|
||
|
||
<script type="text/x-thebe-config">
|
||
{
|
||
requestKernel: true,
|
||
binderOptions: {
|
||
repo: "binder-examples/jupyter-stacks-datascience",
|
||
ref: "master",
|
||
},
|
||
codeMirrorConfig: {
|
||
theme: "abcdef",
|
||
mode: "python"
|
||
},
|
||
kernelOptions: {
|
||
name: "python3",
|
||
path: "./modules"
|
||
},
|
||
predefinedOutput: true
|
||
}
|
||
</script>
|
||
<script>kernelName = 'python3'</script>
|
||
|
||
</article>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<footer class="prev-next-footer d-print-none">
|
||
|
||
<div class="prev-next-area">
|
||
<a class="left-prev"
|
||
href="05_autograd_ABOUT.html"
|
||
title="previous page">
|
||
<i class="fa-solid fa-angle-left" aria-hidden="true"></i>
|
||
<div class="prev-next-info">
|
||
<p class="prev-next-subtitle">previous</p>
|
||
<p class="prev-next-title">05. Autograd</p>
|
||
</div>
|
||
</a>
|
||
<a class="right-next"
|
||
href="07_training_ABOUT.html"
|
||
title="next page">
|
||
<div class="prev-next-info">
|
||
<p class="prev-next-subtitle">next</p>
|
||
<p class="prev-next-title">07. Training</p>
|
||
</div>
|
||
<i class="fa-solid fa-angle-right" aria-hidden="true"></i>
|
||
</a>
|
||
</div>
|
||
</footer>
|
||
|
||
</div>
|
||
|
||
|
||
|
||
<div class="bd-sidebar-secondary bd-toc"><div class="sidebar-secondary-items sidebar-secondary__inner">
|
||
|
||
|
||
<div class="sidebar-secondary-item">
|
||
<div class="page-toc tocsection onthispage">
|
||
<i class="fa-solid fa-list" aria-hidden="true"></i> Contents
|
||
</div>
|
||
<nav class="bd-toc-nav page-toc">
|
||
<ul class="visible nav section-nav flex-column">
|
||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#overview">Overview</a></li>
|
||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#learning-objectives">Learning Objectives</a></li>
|
||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#build-use-reflect">Build → Use → Reflect</a></li>
|
||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#implementation-guide">Implementation Guide</a><ul class="nav section-nav flex-column">
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#core-optimization-algorithms">Core Optimization Algorithms</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#sgd-with-momentum-implementation">SGD with Momentum Implementation</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#adam-optimizer-implementation">Adam Optimizer Implementation</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#adamw-implementation-decoupled-weight-decay">AdamW Implementation (Decoupled Weight Decay)</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#complete-training-integration">Complete Training Integration</a></li>
|
||
</ul>
|
||
</li>
|
||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#getting-started">Getting Started</a><ul class="nav section-nav flex-column">
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#prerequisites">Prerequisites</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#development-workflow">Development Workflow</a></li>
|
||
</ul>
|
||
</li>
|
||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#testing">Testing</a><ul class="nav section-nav flex-column">
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#comprehensive-test-suite">Comprehensive Test Suite</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#test-coverage-areas">Test Coverage Areas</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#inline-testing-convergence-analysis">Inline Testing & Convergence Analysis</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#manual-testing-examples">Manual Testing Examples</a></li>
|
||
</ul>
|
||
</li>
|
||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#systems-thinking-questions">Systems Thinking Questions</a><ul class="nav section-nav flex-column">
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#real-world-applications">Real-World Applications</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#optimization-theory-foundations">Optimization Theory Foundations</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#performance-characteristics">Performance Characteristics</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#critical-thinking-memory-vs-convergence-trade-offs">Critical Thinking: Memory vs Convergence Trade-offs</a></li>
|
||
</ul>
|
||
</li>
|
||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#ready-to-build">Ready to Build?</a></li>
|
||
</ul>
|
||
</nav></div>
|
||
|
||
</div></div>
|
||
|
||
|
||
</div>
|
||
<footer class="bd-footer-content">
|
||
|
||
<div class="bd-footer-content__inner container">
|
||
|
||
<div class="footer-item">
|
||
|
||
<p class="component-author">
|
||
By Prof. Vijay Janapa Reddi (Harvard University)
|
||
</p>
|
||
|
||
</div>
|
||
|
||
<div class="footer-item">
|
||
|
||
|
||
<p class="copyright">
|
||
|
||
© Copyright 2025.
|
||
<br/>
|
||
|
||
</p>
|
||
|
||
</div>
|
||
|
||
<div class="footer-item">
|
||
|
||
</div>
|
||
|
||
<div class="footer-item">
|
||
|
||
</div>
|
||
|
||
</div>
|
||
</footer>
|
||
|
||
|
||
</main>
|
||
</div>
|
||
</div>
|
||
|
||
<!-- Scripts loaded after <body> so the DOM is not blocked -->
|
||
<script src="../_static/scripts/bootstrap.js?digest=dfe6caa3a7d634c4db9b"></script>
|
||
<script src="../_static/scripts/pydata-sphinx-theme.js?digest=dfe6caa3a7d634c4db9b"></script>
|
||
|
||
<footer class="bd-footer">
|
||
</footer>
|
||
</body>
|
||
</html> |