mirror of
https://github.com/MLSysBook/TinyTorch.git
synced 2026-05-24 03:46:18 -05:00
1483 lines
105 KiB
HTML
1483 lines
105 KiB
HTML
|
||
<!DOCTYPE html>
|
||
|
||
|
||
<html lang="en" data-content_root="../" >
|
||
|
||
<head>
|
||
<meta charset="utf-8" />
|
||
<meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="viewport" content="width=device-width, initial-scale=1" />
|
||
|
||
<title>07. Training — Tiny🔥Torch</title>
|
||
|
||
|
||
|
||
<script data-cfasync="false">
  // Copy the reader's saved colour mode and theme from localStorage onto
  // <html data-mode/data-theme>. This runs early in <head>, ahead of the theme
  // stylesheets, presumably so the saved palette applies on first paint.
  // A missing key becomes "" (the theme's default).
  (function (rootData) {
    rootData.mode = localStorage.getItem("mode") || "";
    rootData.theme = localStorage.getItem("theme") || "";
  })(document.documentElement.dataset);
</script>
|
||
|
||
<!-- Loaded before other Sphinx assets -->
|
||
<link href="../_static/styles/theme.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" />
|
||
<link href="../_static/styles/bootstrap.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" />
|
||
<link href="../_static/styles/pydata-sphinx-theme.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" />
|
||
|
||
|
||
<link href="../_static/vendor/fontawesome/6.5.2/css/all.min.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" />
|
||
<link rel="preload" as="font" type="font/woff2" crossorigin href="../_static/vendor/fontawesome/6.5.2/webfonts/fa-solid-900.woff2" />
|
||
<link rel="preload" as="font" type="font/woff2" crossorigin href="../_static/vendor/fontawesome/6.5.2/webfonts/fa-brands-400.woff2" />
|
||
<link rel="preload" as="font" type="font/woff2" crossorigin href="../_static/vendor/fontawesome/6.5.2/webfonts/fa-regular-400.woff2" />
|
||
|
||
<link rel="stylesheet" type="text/css" href="../_static/pygments.css?v=03e43079" />
|
||
<link rel="stylesheet" type="text/css" href="../_static/styles/sphinx-book-theme.css?v=eba8b062" />
|
||
<link rel="stylesheet" type="text/css" href="../_static/togglebutton.css?v=13237357" />
|
||
<link rel="stylesheet" type="text/css" href="../_static/copybutton.css?v=76b2166b" />
|
||
<link rel="stylesheet" type="text/css" href="../_static/mystnb.8ecb98da25f57f5357bf6f572d296f466b2cfe2517ffebfabe82451661e28f02.css" />
|
||
<link rel="stylesheet" type="text/css" href="../_static/sphinx-thebe.css?v=4fa983c6" />
|
||
<link rel="stylesheet" type="text/css" href="../_static/sphinx-design.min.css?v=95c83b7e" />
|
||
<link rel="stylesheet" type="text/css" href="../_static/custom.css?v=009d37f4" />
|
||
|
||
<!-- Pre-loaded scripts that we'll load fully later -->
|
||
<link rel="preload" as="script" href="../_static/scripts/bootstrap.js?digest=dfe6caa3a7d634c4db9b" />
|
||
<link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=dfe6caa3a7d634c4db9b" />
|
||
<script src="../_static/vendor/fontawesome/6.5.2/js/all.min.js?digest=dfe6caa3a7d634c4db9b"></script>
|
||
|
||
<script src="../_static/documentation_options.js?v=9eb32ce0"></script>
|
||
<script src="../_static/doctools.js?v=9a2dae69"></script>
|
||
<script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
|
||
<script src="../_static/clipboard.min.js?v=a7894cd8"></script>
|
||
<script src="../_static/copybutton.js?v=f281be69"></script>
|
||
<script src="../_static/scripts/sphinx-book-theme.js?v=887ef09a"></script>
|
||
<!-- Configuration globals for sphinx-togglebutton. -->
<script>let toggleHintShow = 'Click to show';</script>
<script>let toggleHintHide = 'Click to hide';</script>
<script>let toggleOpenOnPrint = 'true';</script>
<script src="../_static/togglebutton.js?v=4a39c7ea"></script>
<script>var togglebuttonSelector = '.toggle, .admonition.dropdown';</script>
<script src="../_static/design-tabs.js?v=f930bc37"></script>
<!-- Configuration globals for sphinx-thebe. Emitted exactly once: separate
     classic scripts share one global lexical scope, so a second
     "const THEBE_JS_URL" is a redeclaration and throws a SyntaxError. -->
<script>const THEBE_JS_URL = "https://unpkg.com/thebe@0.8.2/lib/index.js"; const thebe_selector = ".thebe,.cell"; const thebe_selector_input = "pre"; const thebe_selector_output = ".output, .cell_output"</script>
<script async src="../_static/sphinx-thebe.js?v=c100c467"></script>
|
||
<script type="module" src="https://cdn.jsdelivr.net/npm/mermaid@10.6.1/dist/mermaid.esm.min.mjs"></script>
<script type="module" src="https://cdn.jsdelivr.net/npm/@mermaid-js/layout-elk@0.2.0/dist/mermaid-layout-elk.esm.min.mjs"></script>
<script type="module">
  import mermaid from "https://cdn.jsdelivr.net/npm/mermaid@10.6.1/dist/mermaid.esm.min.mjs";
  import elkLayouts from "https://cdn.jsdelivr.net/npm/@mermaid-js/layout-elk@0.2.0/dist/mermaid-layout-elk.esm.min.mjs";
  // registerLayoutLoaders() is a Mermaid 11 API. On the pinned 10.6.1 build it
  // is undefined, and calling it unguarded throws before initialize() runs,
  // which leaves startOnLoad at its default. Register the ELK layouts only
  // when the API exists, and always disable auto-start, because rendering is
  // triggered explicitly later via mermaid.run().
  if (typeof mermaid.registerLayoutLoaders === "function") {
    mermaid.registerLayoutLoaders(elkLayouts);
  }
  mermaid.initialize({startOnLoad: false});
</script>
<script src="https://cdn.jsdelivr.net/npm/d3@7.9.0/dist/d3.min.js"></script>
|
||
<script type="module">import mermaid from "https://cdn.jsdelivr.net/npm/mermaid@10.6.1/dist/mermaid.esm.min.mjs";
|
||
|
||
// Baseline layout for rendered diagrams: each `pre.mermaid` spans the full
// content width, and its SVG gets a fixed 500px height.
const baseDiagramCss = `pre.mermaid {
/* Same as .mermaid-container > pre */
display: block;
width: 100%;
}

pre.mermaid > svg {
/* Same as .mermaid-container > pre > svg */
height: 500px;
width: 100%;
max-width: 100% !important;
}
`;
document.head.append(Object.assign(document.createElement('style'), { textContent: baseDiagramCss }));
|
||
|
||
// Styles for the fullscreen feature: the `.mermaid-container` wrapper, the
// per-diagram "fullscreen" button, the shared modal with its close button,
// and `.dark-theme` variants of each. Buttons inside the modal are hidden.
const fullscreenCss = `.mermaid-container {
display: flex;
flex-direction: row;
width: 100%;
}

.mermaid-container > pre {
display: block;
width: 100%;
}

.mermaid-container > pre > svg {
height: 500px;
width: 100%;
max-width: 100% !important;
}

.mermaid-fullscreen-btn {
width: 28px;
height: 28px;
background: rgba(255, 255, 255, 0.95);
border: 1px solid rgba(0, 0, 0, 0.3);
border-radius: 4px;
cursor: pointer;
display: flex;
align-items: center;
justify-content: center;
transition: all 0.2s;
box-shadow: 0 2px 6px rgba(0, 0, 0, 0.2);
font-size: 14px;
line-height: 1;
padding: 0;
color: #333;
}

.mermaid-fullscreen-btn:hover {
opacity: 100% !important;
background: rgba(255, 255, 255, 1);
box-shadow: 0 3px 10px rgba(0, 0, 0, 0.3);
transform: scale(1.1);
}

.mermaid-fullscreen-btn.dark-theme {
background: rgba(50, 50, 50, 0.95);
border: 1px solid rgba(255, 255, 255, 0.3);
color: #e0e0e0;
}

.mermaid-fullscreen-btn.dark-theme:hover {
background: rgba(60, 60, 60, 1);
box-shadow: 0 3px 10px rgba(255, 255, 255, 0.2);
}

.mermaid-fullscreen-modal {
display: none;
position: fixed !important;
top: 0 !important;
left: 0 !important;
width: 95vw;
height: 100vh;
background: rgba(255, 255, 255, 0.98);
z-index: 9999;
padding: 20px;
overflow: auto;
}

.mermaid-fullscreen-modal.dark-theme {
background: rgba(0, 0, 0, 0.98);
}

.mermaid-fullscreen-modal.active {
display: flex;
align-items: center;
justify-content: center;
}

.mermaid-container-fullscreen {
position: relative;
width: 95vw;
height: 90vh;
max-width: 95vw;
max-height: 90vh;
background: white;
border-radius: 8px;
padding: 20px;
box-shadow: 0 10px 40px rgba(0, 0, 0, 0.3);
overflow: auto;
display: flex;
align-items: center;
justify-content: center;
}

.mermaid-container-fullscreen.dark-theme {
background: #1a1a1a;
box-shadow: 0 10px 40px rgba(0, 0, 0, 0.8);
}

.mermaid-container-fullscreen pre.mermaid {
width: 100%;
height: 100%;
display: flex;
align-items: center;
justify-content: center;
}

.mermaid-container-fullscreen .mermaid svg {
height: 100% !important;
width: 100% !important;
cursor: grab;
}

.mermaid-fullscreen-close {
position: fixed !important;
top: 20px !important;
right: 20px !important;
width: 40px;
height: 40px;
background: rgba(255, 255, 255, 0.95);
border: 1px solid rgba(0, 0, 0, 0.2);
border-radius: 50%;
cursor: pointer;
z-index: 10000;
display: flex;
align-items: center;
justify-content: center;
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.3);
transition: all 0.2s;
font-size: 24px;
line-height: 1;
color: #333;
}

.mermaid-fullscreen-close:hover {
background: white;
box-shadow: 0 6px 16px rgba(0, 0, 0, 0.4);
transform: scale(1.1);
}

.mermaid-fullscreen-close.dark-theme {
background: rgba(50, 50, 50, 0.95);
border: 1px solid rgba(255, 255, 255, 0.2);
color: #e0e0e0;
}

.mermaid-fullscreen-close.dark-theme:hover {
background: rgba(60, 60, 60, 1);
box-shadow: 0 6px 16px rgba(255, 255, 255, 0.2);
}

.mermaid-fullscreen-modal .mermaid-fullscreen-btn {
display: none !important;
}`;
const fullscreenSheet = document.createElement('style');
fullscreenSheet.textContent = fullscreenCss;
document.head.append(fullscreenSheet);
|
||
|
||
// Detect if page has dark background
|
||
// Report whether the page is currently drawn on a dark background.
// Reads the computed background colour of <body> and, if it is an
// `rgb(r, g, b)` value, returns true when (299 R + 587 G + 114 B) / 1000 is
// below 128. Any other serialisation (e.g. `rgba(...)`, `transparent`) does
// not match the pattern and is treated as light (false).
const isDarkTheme = () => {
  const { backgroundColor } = window.getComputedStyle(document.body);
  const channels = /rgb\((\d+),\s*(\d+),\s*(\d+)/.exec(backgroundColor);
  if (!channels) {
    return false;
  }
  const [red, green, blue] = channels.slice(1, 4).map((part) => parseInt(part));
  return (red * 299 + green * 587 + blue * 114) / 1000 < 128;
};
|
||
|
||
// Render every Mermaid diagram on the page, then give each one a fullscreen
// viewer.
//
// Flow:
// - Run Mermaid, then re-check every 200 ms until every `.mermaid` element
//   carries data-processed="true".
// - Build one shared modal dialog.
// - Wrap each diagram in a `.mermaid-container` with a "⛶" button that
//   clones the diagram into the modal.
// - A MutationObserver keeps the buttons' and the modal's `dark-theme`
//   classes in sync with the page background.
//
// Pan/zoom via d3 is disabled in this build, so no zoom wiring is done here.
const load = async () => {
  await mermaid.run();

  const allMermaids = document.querySelectorAll(".mermaid");
  const processedMermaids = document.querySelectorAll(".mermaid[data-processed='true']");

  // Wait for mermaid to process all diagrams before wrapping them.
  if (allMermaids.length !== processedMermaids.length) {
    setTimeout(load, 200);
    return;
  }

  const darkTheme = isDarkTheme();
  const themeClass = darkTheme ? ' dark-theme' : '';

  // One modal is shared by every diagram on the page.
  const modal = document.createElement('div');
  modal.className = 'mermaid-fullscreen-modal' + themeClass;
  modal.setAttribute('role', 'dialog');
  modal.setAttribute('aria-modal', 'true');
  modal.setAttribute('aria-label', 'Fullscreen diagram viewer');
  modal.innerHTML = `
    <button type="button" class="mermaid-fullscreen-close${themeClass}" aria-label="Close fullscreen">✕</button>
    <div class="mermaid-container-fullscreen${themeClass}"></div>
  `;
  document.body.appendChild(modal);

  const modalContent = modal.querySelector('.mermaid-container-fullscreen');
  const closeBtn = modal.querySelector('.mermaid-fullscreen-close');

  // Captured when the modal opens and restored when it closes.
  let previousScrollOffset = [window.scrollX, window.scrollY];
  let lastTrigger = null;

  const closeModal = () => {
    modal.classList.remove('active');
    modalContent.innerHTML = '';
    document.body.style.overflow = '';
    window.scrollTo({left: previousScrollOffset[0], top: previousScrollOffset[1], behavior: 'instant'});
    // Hand keyboard focus back to the button that opened the dialog.
    if (lastTrigger) {
      lastTrigger.focus({preventScroll: true});
      lastTrigger = null;
    }
  };

  closeBtn.addEventListener('click', closeModal);
  // A click on the backdrop itself (not on the diagram) closes the modal.
  modal.addEventListener('click', (e) => {
    if (e.target === modal) closeModal();
  });
  document.addEventListener('keydown', (e) => {
    if (e.key === 'Escape' && modal.classList.contains('active')) {
      closeModal();
    }
  });

  const allButtons = [];

  document.querySelectorAll('.mermaid').forEach((mermaidDiv) => {
    // Skip diagrams that are already wrapped or that live inside the modal.
    if (mermaidDiv.parentNode.classList.contains('mermaid-container') ||
        mermaidDiv.closest('.mermaid-fullscreen-modal')) {
      return;
    }

    const container = document.createElement('div');
    container.className = 'mermaid-container';
    mermaidDiv.parentNode.insertBefore(container, mermaidDiv);
    container.appendChild(mermaidDiv);

    const fullscreenBtn = document.createElement('button');
    fullscreenBtn.type = 'button';
    fullscreenBtn.className = 'mermaid-fullscreen-btn' + themeClass;
    fullscreenBtn.setAttribute('aria-label', 'View diagram in fullscreen');
    fullscreenBtn.textContent = '⛶';
    fullscreenBtn.style.opacity = '50%';

    // Offset the button by the diagram's own margin and padding.
    // NOTE(review): the stylesheet injected on this page gives the button no
    // `position`. Unless another stylesheet sets one, these top/right offsets
    // have no visual effect — TODO confirm.
    const diagramStyle = window.getComputedStyle(mermaidDiv);
    const marginTop = parseFloat(diagramStyle.marginTop) || 0;
    const marginRight = parseFloat(diagramStyle.marginRight) || 0;
    const paddingTop = parseFloat(diagramStyle.paddingTop) || 0;
    const paddingRight = parseFloat(diagramStyle.paddingRight) || 0;
    fullscreenBtn.style.top = `${marginTop + paddingTop + 4}px`;
    fullscreenBtn.style.right = `${marginRight + paddingRight + 4}px`;

    fullscreenBtn.addEventListener('click', () => {
      // Remember where the reader was; closeModal() scrolls back here.
      // (scrollX, not `window.scroll`, which is a method, not an offset.)
      previousScrollOffset = [window.scrollX, window.scrollY];
      lastTrigger = fullscreenBtn;

      const clone = mermaidDiv.cloneNode(true);
      modalContent.innerHTML = '';
      modalContent.appendChild(clone);

      // Let the cloned SVG scale to the modal instead of keeping its
      // rendered pixel size.
      const svg = clone.querySelector('svg');
      if (svg) {
        svg.removeAttribute('width');
        svg.removeAttribute('height');
        svg.style.width = '100%';
        svg.style.height = 'auto';
        svg.style.maxWidth = '100%';
        svg.style.display = 'block';
      }

      modal.classList.add('active');
      document.body.style.overflow = 'hidden';
      // Move focus into the dialog so keyboard users land on its close button.
      closeBtn.focus({preventScroll: true});
    });

    container.appendChild(fullscreenBtn);
    allButtons.push(fullscreenBtn);
  });

  // Re-apply the `dark-theme` class to the buttons and modal parts whenever
  // the page background's brightness may have changed.
  const updateTheme = () => {
    const dark = isDarkTheme();
    for (const el of [...allButtons, modal, modalContent, closeBtn]) {
      el.classList.toggle('dark-theme', dark);
    }
  };

  // Watch for theme changes signalled through attributes on <html> or <body>.
  const observer = new MutationObserver(updateTheme);
  observer.observe(document.documentElement, {
    attributes: true,
    attributeFilter: ['class', 'style', 'data-theme']
  });
  observer.observe(document.body, {
    attributes: true,
    attributeFilter: ['class', 'style']
  });
};

window.addEventListener("load", load);
|
||
</script>
|
||
<script>DOCUMENTATION_OPTIONS.pagename = 'modules/07_training_ABOUT';</script>
|
||
<script src="../_static/ml-timeline.js?v=76e9b3e3"></script>
|
||
<script src="../_static/wip-banner.js?v=04a7e74d"></script>
|
||
<script src="../_static/marimo-badges.js?v=e6289128"></script>
|
||
<script src="../_static/sidebar-link.js?v=404b701b"></script>
|
||
<script src="../_static/hero-carousel.js?v=10341d2a"></script>
|
||
<script src="../_static/subscribe-modal.js?v=42919b64"></script>
|
||
<link rel="icon" href="../_static/favicon.svg"/>
|
||
<link rel="index" title="Index" href="../genindex.html" />
|
||
<link rel="search" title="Search" href="../search.html" />
|
||
<link rel="next" title="🏛️ Architecture Tier (Modules 08-13)" href="../tiers/architecture.html" />
|
||
<link rel="prev" title="06. Optimizers" href="06_optimizers_ABOUT.html" />
|
||
<meta name="viewport" content="width=device-width, initial-scale=1"/>
|
||
<meta name="docsearch:language" content="en"/>
|
||
</head>
|
||
|
||
|
||
<body data-bs-spy="scroll" data-bs-target=".bd-toc-nav" data-offset="180" data-bs-root-margin="0px 0px -60%" data-default-mode="">
|
||
|
||
|
||
|
||
<div id="pst-skip-link" class="skip-link d-print-none"><a href="#main-content">Skip to main content</a></div>
|
||
|
||
<div id="pst-scroll-pixel-helper"></div>
|
||
|
||
<button type="button" class="btn rounded-pill" id="pst-back-to-top">
|
||
<i class="fa-solid fa-arrow-up"></i>Back to top</button>
|
||
|
||
|
||
<input type="checkbox"
|
||
class="sidebar-toggle"
|
||
id="pst-primary-sidebar-checkbox"/>
|
||
<label class="overlay overlay-primary" for="pst-primary-sidebar-checkbox"></label>
|
||
|
||
<input type="checkbox"
|
||
class="sidebar-toggle"
|
||
id="pst-secondary-sidebar-checkbox"/>
|
||
<label class="overlay overlay-secondary" for="pst-secondary-sidebar-checkbox"></label>
|
||
|
||
<div class="search-button__wrapper">
|
||
<div class="search-button__overlay"></div>
|
||
<div class="search-button__search-container">
|
||
<form class="bd-search d-flex align-items-center"
|
||
action="../search.html"
|
||
method="get">
|
||
<i class="fa-solid fa-magnifying-glass"></i>
|
||
<input type="search"
|
||
class="form-control"
|
||
name="q"
|
||
id="search-input"
|
||
placeholder="Search..."
|
||
aria-label="Search..."
|
||
autocomplete="off"
|
||
autocorrect="off"
|
||
autocapitalize="off"
|
||
spellcheck="false"/>
|
||
<span class="search-button__kbd-shortcut"><kbd class="kbd-shortcut__modifier">Ctrl</kbd>+<kbd>K</kbd></span>
|
||
</form></div>
|
||
</div>
|
||
|
||
<div class="pst-async-banner-revealer d-none">
|
||
<aside id="bd-header-version-warning" class="d-none d-print-none" aria-label="Version warning"></aside>
|
||
</div>
|
||
|
||
|
||
<header class="bd-header navbar navbar-expand-lg bd-navbar d-print-none">
|
||
</header>
|
||
|
||
|
||
<div class="bd-container">
|
||
<div class="bd-container__inner bd-page-width">
|
||
|
||
|
||
|
||
<div class="bd-sidebar-primary bd-sidebar">
|
||
|
||
|
||
|
||
<div class="sidebar-header-items sidebar-primary__section">
|
||
|
||
|
||
|
||
|
||
</div>
|
||
|
||
<div class="sidebar-primary-items__start sidebar-primary__section">
|
||
<div class="sidebar-primary-item">
|
||
|
||
|
||
|
||
|
||
|
||
<a class="navbar-brand logo" href="../intro.html">
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<img src="../_static/logo-tinytorch.png" class="logo__image only-light" alt="Tiny🔥Torch - Home"/>
|
||
<script>document.write(`<img src="../_static/logo-tinytorch.png" class="logo__image only-dark" alt="Tiny🔥Torch - Home"/>`);</script>
|
||
|
||
|
||
</a></div>
|
||
<div class="sidebar-primary-item">
|
||
|
||
<script>
|
||
document.write(`
|
||
<button class="btn search-button-field search-button__button" title="Search" aria-label="Search" data-bs-placement="bottom" data-bs-toggle="tooltip">
|
||
<i class="fa-solid fa-magnifying-glass"></i>
|
||
<span class="search-button__default-text">Search</span>
|
||
<span class="search-button__kbd-shortcut"><kbd class="kbd-shortcut__modifier">Ctrl</kbd>+<kbd class="kbd-shortcut__modifier">K</kbd></span>
|
||
</button>
|
||
`);
|
||
</script></div>
|
||
<div class="sidebar-primary-item"><nav class="bd-links bd-docs-nav" aria-label="Main">
|
||
<div class="bd-toc-item navbar-nav active">
|
||
<p aria-level="2" class="caption" role="heading"><span class="caption-text">🚀 Getting Started</span></p>
|
||
<ul class="nav bd-sidenav">
|
||
<li class="toctree-l1"><a class="reference internal" href="../getting-started.html">Complete Guide</a></li>
|
||
</ul>
|
||
<p aria-level="2" class="caption" role="heading"><span class="caption-text">🏗 Foundation Tier (01-07)</span></p>
|
||
<ul class="current nav bd-sidenav">
|
||
<li class="toctree-l1"><a class="reference internal" href="../tiers/foundation.html">📖 Tier Overview</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="01_tensor_ABOUT.html">01. Tensor</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="02_activations_ABOUT.html">02. Activations</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="03_layers_ABOUT.html">03. Layers</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="04_losses_ABOUT.html">04. Losses</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="05_autograd_ABOUT.html">05. Autograd</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="06_optimizers_ABOUT.html">06. Optimizers</a></li>
|
||
<li class="toctree-l1 current active"><a class="current reference internal" href="#">07. Training</a></li>
|
||
</ul>
|
||
<p aria-level="2" class="caption" role="heading"><span class="caption-text">🏛️ Architecture Tier (08-13)</span></p>
|
||
<ul class="nav bd-sidenav">
|
||
<li class="toctree-l1"><a class="reference internal" href="../tiers/architecture.html">📖 Tier Overview</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="08_dataloader_ABOUT.html">08. DataLoader</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="09_spatial_ABOUT.html">09. Convolutions</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="10_tokenization_ABOUT.html">10. Tokenization</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="11_embeddings_ABOUT.html">11. Embeddings</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="12_attention_ABOUT.html">12. Attention</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="13_transformers_ABOUT.html">13. Transformers</a></li>
|
||
</ul>
|
||
<p aria-level="2" class="caption" role="heading"><span class="caption-text">⏱️ Optimization Tier (14-19)</span></p>
|
||
<ul class="nav bd-sidenav">
|
||
<li class="toctree-l1"><a class="reference internal" href="../tiers/optimization.html">📖 Tier Overview</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="14_profiling_ABOUT.html">14. Profiling</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="15_quantization_ABOUT.html">15. Quantization</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="16_compression_ABOUT.html">16. Compression</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="17_memoization_ABOUT.html">17. Memoization</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="18_acceleration_ABOUT.html">18. Acceleration</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="19_benchmarking_ABOUT.html">19. Benchmarking</a></li>
|
||
</ul>
|
||
<p aria-level="2" class="caption" role="heading"><span class="caption-text">🏅 Capstone Competition</span></p>
|
||
<ul class="nav bd-sidenav">
|
||
<li class="toctree-l1"><a class="reference internal" href="../tiers/olympics.html">📖 Competition Overview</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="20_capstone_ABOUT.html">20. Torch Olympics</a></li>
|
||
</ul>
|
||
<p aria-level="2" class="caption" role="heading"><span class="caption-text">🧭 Course Orientation</span></p>
|
||
<ul class="nav bd-sidenav">
|
||
<li class="toctree-l1"><a class="reference internal" href="../chapters/00-introduction.html">Course Structure</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../prerequisites.html">Prerequisites & Resources</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../chapters/learning-journey.html">Learning Journey</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../chapters/milestones.html">Historical Milestones</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../faq.html">FAQ</a></li>
|
||
</ul>
|
||
<p aria-level="2" class="caption" role="heading"><span class="caption-text">🛠️ TITO CLI Reference</span></p>
|
||
<ul class="nav bd-sidenav">
|
||
<li class="toctree-l1"><a class="reference internal" href="../tito/overview.html">Command Overview</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../tito/modules.html">Module Workflow</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../tito/milestones.html">Milestone System</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../tito/data.html">Progress & Data</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../tito/troubleshooting.html">Troubleshooting</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../datasets.html">Datasets Guide</a></li>
|
||
</ul>
|
||
<p aria-level="2" class="caption" role="heading"><span class="caption-text">🤝 Community</span></p>
|
||
<ul class="nav bd-sidenav">
|
||
<li class="toctree-l1"><a class="reference internal" href="../community.html">Ecosystem</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../resources.html">Learning Resources</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../credits.html">Credits & Acknowledgments</a></li>
|
||
</ul>
|
||
|
||
</div>
|
||
</nav></div>
|
||
</div>
|
||
|
||
|
||
<div class="sidebar-primary-items__end sidebar-primary__section">
|
||
</div>
|
||
|
||
<div id="rtd-footer-container"></div>
|
||
|
||
|
||
</div>
|
||
|
||
<main id="main-content" class="bd-main" role="main">
|
||
|
||
|
||
|
||
<div class="sbt-scroll-pixel-helper"></div>
|
||
|
||
<div class="bd-content">
|
||
<div class="bd-article-container">
|
||
|
||
<div class="bd-header-article d-print-none">
|
||
<div class="header-article-items header-article__inner">
|
||
|
||
<div class="header-article-items__start">
|
||
|
||
<div class="header-article-item"><button class="sidebar-toggle primary-toggle btn btn-sm" title="Toggle primary sidebar" data-bs-placement="bottom" data-bs-toggle="tooltip">
|
||
<span class="fa-solid fa-bars"></span>
|
||
</button></div>
|
||
|
||
</div>
|
||
|
||
|
||
<div class="header-article-items__end">
|
||
|
||
<div class="header-article-item">
|
||
|
||
<div class="article-header-buttons">
|
||
|
||
|
||
|
||
|
||
|
||
<div class="dropdown dropdown-download-buttons">
|
||
<button class="btn dropdown-toggle" type="button" data-bs-toggle="dropdown" aria-expanded="false" aria-label="Download this page">
|
||
<i class="fas fa-download"></i>
|
||
</button>
|
||
<ul class="dropdown-menu">
|
||
|
||
|
||
|
||
<li><a href="../_sources/modules/07_training_ABOUT.md" target="_blank"
|
||
class="btn btn-sm btn-download-source-button dropdown-item"
|
||
title="Download source file"
|
||
data-bs-placement="left" data-bs-toggle="tooltip"
|
||
>
|
||
|
||
|
||
<span class="btn__icon-container">
|
||
<i class="fas fa-file"></i>
|
||
</span>
|
||
<span class="btn__text-container">.md</span>
|
||
</a>
|
||
</li>
|
||
|
||
|
||
|
||
|
||
<li>
|
||
<button onclick="window.print()"
|
||
class="btn btn-sm btn-download-pdf-button dropdown-item"
|
||
title="Print to PDF"
|
||
data-bs-placement="left" data-bs-toggle="tooltip"
|
||
>
|
||
|
||
|
||
<span class="btn__icon-container">
|
||
<i class="fas fa-file-pdf"></i>
|
||
</span>
|
||
<span class="btn__text-container">.pdf</span>
|
||
</button>
|
||
</li>
|
||
|
||
</ul>
|
||
</div>
|
||
|
||
|
||
|
||
|
||
<button onclick="toggleFullScreen()"
|
||
class="btn btn-sm btn-fullscreen-button"
|
||
title="Fullscreen mode"
|
||
data-bs-placement="bottom" data-bs-toggle="tooltip"
|
||
>
|
||
|
||
|
||
<span class="btn__icon-container">
|
||
<i class="fas fa-expand"></i>
|
||
</span>
|
||
|
||
</button>
|
||
|
||
|
||
|
||
<script>
|
||
document.write(`
|
||
<button class="btn btn-sm nav-link pst-navbar-icon theme-switch-button" title="light/dark" aria-label="light/dark" data-bs-placement="bottom" data-bs-toggle="tooltip">
|
||
<i class="theme-switch fa-solid fa-sun fa-lg" data-mode="light"></i>
|
||
<i class="theme-switch fa-solid fa-moon fa-lg" data-mode="dark"></i>
|
||
<i class="theme-switch fa-solid fa-circle-half-stroke fa-lg" data-mode="auto"></i>
|
||
</button>
|
||
`);
|
||
</script>
|
||
|
||
|
||
<script>
|
||
document.write(`
|
||
<button class="btn btn-sm pst-navbar-icon search-button search-button__button" title="Search" aria-label="Search" data-bs-placement="bottom" data-bs-toggle="tooltip">
|
||
<i class="fa-solid fa-magnifying-glass fa-lg"></i>
|
||
</button>
|
||
`);
|
||
</script>
|
||
<button class="sidebar-toggle secondary-toggle btn btn-sm" title="Toggle secondary sidebar" data-bs-placement="bottom" data-bs-toggle="tooltip">
|
||
<span class="fa-solid fa-list"></span>
|
||
</button>
|
||
</div></div>
|
||
|
||
</div>
|
||
|
||
</div>
|
||
</div>
|
||
|
||
|
||
|
||
<div id="jb-print-docs-body" class="onlyprint">
|
||
<h1>07. Training</h1>
|
||
<!-- Table of contents -->
|
||
<div id="print-main-content">
|
||
<div id="jb-print-toc">
|
||
|
||
<div>
|
||
<h2> Contents </h2>
|
||
</div>
|
||
<nav aria-label="Page">
|
||
<ul class="visible nav section-nav flex-column">
|
||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#overview">Overview</a></li>
|
||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#learning-objectives">Learning Objectives</a></li>
|
||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#build-use-reflect">Build → Use → Reflect</a></li>
|
||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#implementation-guide">Implementation Guide</a><ul class="nav section-nav flex-column">
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#the-training-loop-cycle">The Training Loop Cycle</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#cosineschedule-adaptive-learning-rate-management">CosineSchedule - Adaptive Learning Rate Management</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#gradient-clipping-preventing-training-explosions">Gradient Clipping - Preventing Training Explosions</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#trainer-class-complete-training-orchestration">Trainer Class - Complete Training Orchestration</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#complete-training-example">Complete Training Example</a></li>
|
||
</ul>
|
||
</li>
|
||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#getting-started">Getting Started</a><ul class="nav section-nav flex-column">
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#prerequisites">Prerequisites</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#development-workflow">Development Workflow</a></li>
|
||
</ul>
|
||
</li>
|
||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#testing">Testing</a><ul class="nav section-nav flex-column">
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#comprehensive-test-suite">Comprehensive Test Suite</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#test-coverage-areas">Test Coverage Areas</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#inline-testing-training-analysis">Inline Testing & Training Analysis</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#manual-testing-examples">Manual Testing Examples</a></li>
|
||
</ul>
|
||
</li>
|
||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#systems-thinking-questions">Systems Thinking Questions</a><ul class="nav section-nav flex-column">
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#real-world-applications">Real-World Applications</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#training-system-architecture">Training System Architecture</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#training-dynamics">Training Dynamics</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#performance-characteristics">Performance Characteristics</a></li>
|
||
</ul>
|
||
</li>
|
||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#ready-to-build">Ready to Build?</a></li>
|
||
</ul>
|
||
</nav>
|
||
</div>
|
||
</div>
|
||
</div>
|
||
|
||
|
||
|
||
<div id="searchbox"></div>
|
||
<article class="bd-article">
|
||
|
||
<section id="training">
|
||
<h1>07. Training<a class="headerlink" href="#training" title="Link to this heading">#</a></h1>
|
||
<p><strong>FOUNDATION TIER</strong> | Difficulty: ⭐⭐⭐⭐ (4/4) | Time: 6-8 hours</p>
|
||
<section id="overview">
|
||
<h2>Overview<a class="headerlink" href="#overview" title="Link to this heading">#</a></h2>
|
||
<p>Build the complete training infrastructure that orchestrates neural network learning end-to-end. This capstone module of the Foundation tier brings together all previous components—tensors, layers, losses, gradients, and optimizers—into production-ready training loops with learning rate scheduling, gradient clipping, and model checkpointing. You’ll create the same training patterns used in PyTorch, TensorFlow, and production ML systems.</p>
|
||
</section>
|
||
<section id="learning-objectives">
|
||
<h2>Learning Objectives<a class="headerlink" href="#learning-objectives" title="Link to this heading">#</a></h2>
|
||
<p>By the end of this module, you will be able to:</p>
|
||
<ul class="simple">
|
||
<li><p><strong>Implement complete Trainer class</strong>: Orchestrate forward passes, loss computation, backpropagation, and parameter updates into cohesive training loops with train/eval mode switching</p></li>
|
||
<li><p><strong>Build CosineSchedule for adaptive learning rates</strong>: Create learning rate schedulers that start with a high learning rate for rapid early progress, then decay it smoothly along a cosine annealing curve for fine-tuning</p></li>
|
||
<li><p><strong>Create gradient clipping utilities</strong>: Implement global norm gradient clipping to keep exploding gradients from destabilizing training in deep networks</p></li>
|
||
<li><p><strong>Design checkpointing system</strong>: Build save/load functionality that preserves complete training state—model parameters, optimizer buffers, scheduler state, and training history</p></li>
|
||
<li><p><strong>Understand training systems architecture</strong>: Master memory overhead (4-6× model size), gradient accumulation strategies, checkpoint management, and the difference between training and evaluation modes</p></li>
|
||
</ul>
|
||
</section>
|
||
<section id="build-use-reflect">
|
||
<h2>Build → Use → Reflect<a class="headerlink" href="#build-use-reflect" title="Link to this heading">#</a></h2>
|
||
<p>This module follows TinyTorch’s <strong>Build → Use → Reflect</strong> framework:</p>
|
||
<ol class="arabic simple">
|
||
<li><p><strong>Build</strong>: Implement CosineSchedule for learning rate scheduling, clip_grad_norm for gradient stability, and a complete Trainer class with checkpointing</p></li>
|
||
<li><p><strong>Use</strong>: Train neural networks end-to-end with real optimization dynamics, observe learning rate adaptation, and experiment with gradient accumulation</p></li>
|
||
<li><p><strong>Reflect</strong>: Analyze training memory overhead (parameters + gradients + optimizer state), understand when to checkpoint, and compare training strategies across different scenarios</p></li>
|
||
</ol>
|
||
</section>
|
||
<section id="implementation-guide">
|
||
<h2>Implementation Guide<a class="headerlink" href="#implementation-guide" title="Link to this heading">#</a></h2>
|
||
<section id="the-training-loop-cycle">
|
||
<h3>The Training Loop Cycle<a class="headerlink" href="#the-training-loop-cycle" title="Link to this heading">#</a></h3>
|
||
<p>Training orchestrates data, forward pass, loss, gradients, and updates in an iterative cycle:</p>
|
||
<pre class="mermaid">
|
||
graph LR
|
||
A[Data Batch] --> B[Forward Pass<br/>Model]
|
||
B --> C[Loss<br/>Compute]
|
||
C --> D[Backward Pass<br/>Autograd]
|
||
D --> E[Optimizer Step<br/>Update θ]
|
||
E --> F[Next Batch]
|
||
F --> A
|
||
|
||
style A fill:#e3f2fd
|
||
style B fill:#f3e5f5
|
||
style C fill:#fff3e0
|
||
style D fill:#ffe0b2
|
||
style E fill:#fce4ec
|
||
style F fill:#f0fdf4
|
||
</pre><p><strong>Cycle</strong>: Load batch → Forward through model → Compute loss → Backpropagate gradients → Update parameters → Repeat</p>
|
||
</section>
|
||
<section id="cosineschedule-adaptive-learning-rate-management">
|
||
<h3>CosineSchedule - Adaptive Learning Rate Management<a class="headerlink" href="#cosineschedule-adaptive-learning-rate-management" title="Link to this heading">#</a></h3>
|
||
<p>Learning rate scheduling is like adjusting driving speed based on road conditions—start fast on the highway, slow down in neighborhoods for precision. Cosine annealing provides smooth transitions from aggressive learning to fine-tuning:</p>
|
||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span><span class="w"> </span><span class="nc">CosineSchedule</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> Cosine annealing learning rate schedule.</span>
|
||
|
||
<span class="sd"> Starts at max_lr, decreases following cosine curve to min_lr.</span>
|
||
<span class="sd"> Formula: lr = min_lr + (max_lr - min_lr) * (1 + cos(π*epoch/T)) / 2</span>
|
||
<span class="sd"> """</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">max_lr</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span> <span class="n">min_lr</span><span class="o">=</span><span class="mf">0.01</span><span class="p">,</span> <span class="n">total_epochs</span><span class="o">=</span><span class="mi">100</span><span class="p">):</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">max_lr</span> <span class="o">=</span> <span class="n">max_lr</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">min_lr</span> <span class="o">=</span> <span class="n">min_lr</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">total_epochs</span> <span class="o">=</span> <span class="n">total_epochs</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">get_lr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">epoch</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="nb">float</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">"""Get learning rate for current epoch."""</span>
|
||
<span class="k">if</span> <span class="n">epoch</span> <span class="o">>=</span> <span class="bp">self</span><span class="o">.</span><span class="n">total_epochs</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">min_lr</span>
|
||
|
||
<span class="c1"># Cosine annealing: smooth decrease from max to min</span>
|
||
<span class="n">cosine_factor</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">+</span> <span class="n">np</span><span class="o">.</span><span class="n">cos</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">pi</span> <span class="o">*</span> <span class="n">epoch</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">total_epochs</span><span class="p">))</span> <span class="o">/</span> <span class="mi">2</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">min_lr</span> <span class="o">+</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">max_lr</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">min_lr</span><span class="p">)</span> <span class="o">*</span> <span class="n">cosine_factor</span>
|
||
|
||
<span class="c1"># Usage example</span>
|
||
<span class="n">schedule</span> <span class="o">=</span> <span class="n">CosineSchedule</span><span class="p">(</span><span class="n">max_lr</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span> <span class="n">min_lr</span><span class="o">=</span><span class="mf">0.001</span><span class="p">,</span> <span class="n">total_epochs</span><span class="o">=</span><span class="mi">50</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">schedule</span><span class="o">.</span><span class="n">get_lr</span><span class="p">(</span><span class="mi">0</span><span class="p">))</span> <span class="c1"># 0.1 - fast learning initially</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">schedule</span><span class="o">.</span><span class="n">get_lr</span><span class="p">(</span><span class="mi">25</span><span class="p">))</span> <span class="c1"># ~0.05 - gradual slowdown</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">schedule</span><span class="o">.</span><span class="n">get_lr</span><span class="p">(</span><span class="mi">50</span><span class="p">))</span> <span class="c1"># 0.001 - fine-tuning at end</span>
|
||
</pre></div>
|
||
</div>
|
||
</section>
|
||
<section id="gradient-clipping-preventing-training-explosions">
|
||
<h3>Gradient Clipping - Preventing Training Explosions<a class="headerlink" href="#gradient-clipping-preventing-training-explosions" title="Link to this heading">#</a></h3>
|
||
<p>Gradient clipping is a speed governor that prevents dangerously large gradients from destroying training progress. Global norm clipping computes a single norm over all gradients and, only if that norm exceeds the threshold, scales every gradient by the same factor—preserving their direction and relative magnitudes:</p>
|
||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span><span class="w"> </span><span class="nf">clip_grad_norm</span><span class="p">(</span><span class="n">parameters</span><span class="p">:</span> <span class="n">List</span><span class="p">,</span> <span class="n">max_norm</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">)</span> <span class="o">-></span> <span class="nb">float</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> Clip gradients by global norm to prevent exploding gradients.</span>
|
||
|
||
<span class="sd"> Computes total_norm = sqrt(sum of all gradient squares).</span>
|
||
<span class="sd"> If total_norm > max_norm, scales all gradients by max_norm/total_norm.</span>
|
||
<span class="sd"> """</span>
|
||
<span class="k">if</span> <span class="ow">not</span> <span class="n">parameters</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="mf">0.0</span>
|
||
|
||
<span class="c1"># Compute global norm across all parameters</span>
|
||
<span class="n">total_norm</span> <span class="o">=</span> <span class="mf">0.0</span>
|
||
<span class="k">for</span> <span class="n">param</span> <span class="ow">in</span> <span class="n">parameters</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="n">param</span><span class="o">.</span><span class="n">grad</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">grad_data</span> <span class="o">=</span> <span class="n">param</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">data</span> <span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">param</span><span class="o">.</span><span class="n">grad</span><span class="p">,</span> <span class="s1">'data'</span><span class="p">)</span> <span class="k">else</span> <span class="n">param</span><span class="o">.</span><span class="n">grad</span>
|
||
<span class="n">total_norm</span> <span class="o">+=</span> <span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">grad_data</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span>
|
||
|
||
<span class="n">total_norm</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">total_norm</span><span class="p">)</span>
|
||
|
||
<span class="c1"># Clip if necessary - preserves gradient direction</span>
|
||
<span class="k">if</span> <span class="n">total_norm</span> <span class="o">></span> <span class="n">max_norm</span><span class="p">:</span>
|
||
<span class="n">clip_coef</span> <span class="o">=</span> <span class="n">max_norm</span> <span class="o">/</span> <span class="n">total_norm</span>
|
||
<span class="k">for</span> <span class="n">param</span> <span class="ow">in</span> <span class="n">parameters</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="n">param</span><span class="o">.</span><span class="n">grad</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">param</span><span class="o">.</span><span class="n">grad</span><span class="p">,</span> <span class="s1">'data'</span><span class="p">):</span>
|
||
<span class="n">param</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">data</span> <span class="o">*=</span> <span class="n">clip_coef</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">param</span><span class="o">.</span><span class="n">grad</span> <span class="o">*=</span> <span class="n">clip_coef</span>
|
||
|
||
<span class="k">return</span> <span class="nb">float</span><span class="p">(</span><span class="n">total_norm</span><span class="p">)</span>
|
||
|
||
<span class="c1"># Usage example</span>
|
||
<span class="n">params</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">()</span>
|
||
<span class="n">original_norm</span> <span class="o">=</span> <span class="n">clip_grad_norm</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">max_norm</span><span class="o">=</span><span class="mf">1.0</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Gradient norm: </span><span class="si">{</span><span class="n">original_norm</span><span class="si">:</span><span class="s2">.2f</span><span class="si">}</span><span class="s2"> → clipped to 1.0"</span><span class="p">)</span>
|
||
</pre></div>
|
||
</div>
|
||
</section>
|
||
<section id="trainer-class-complete-training-orchestration">
|
||
<h3>Trainer Class - Complete Training Orchestration<a class="headerlink" href="#trainer-class-complete-training-orchestration" title="Link to this heading">#</a></h3>
|
||
<p>The Trainer class conducts the symphony of training—coordinating model, optimizer, loss function, and scheduler into cohesive learning loops with checkpointing and evaluation:</p>
|
||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span><span class="w"> </span><span class="nc">Trainer</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> Complete training orchestrator for neural networks.</span>
|
||
|
||
<span class="sd"> Handles training loops, evaluation, scheduling, gradient clipping,</span>
|
||
<span class="sd"> checkpointing, and train/eval mode switching.</span>
|
||
<span class="sd"> """</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">optimizer</span><span class="p">,</span> <span class="n">loss_fn</span><span class="p">,</span> <span class="n">scheduler</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">grad_clip_norm</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">model</span> <span class="o">=</span> <span class="n">model</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span> <span class="o">=</span> <span class="n">optimizer</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">loss_fn</span> <span class="o">=</span> <span class="n">loss_fn</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">scheduler</span> <span class="o">=</span> <span class="n">scheduler</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">grad_clip_norm</span> <span class="o">=</span> <span class="n">grad_clip_norm</span>
|
||
|
||
<span class="c1"># Training state</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">epoch</span> <span class="o">=</span> <span class="mi">0</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">step</span> <span class="o">=</span> <span class="mi">0</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">training_mode</span> <span class="o">=</span> <span class="kc">True</span>
|
||
|
||
<span class="c1"># History tracking</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">history</span> <span class="o">=</span> <span class="p">{</span>
|
||
<span class="s1">'train_loss'</span><span class="p">:</span> <span class="p">[],</span>
|
||
<span class="s1">'eval_loss'</span><span class="p">:</span> <span class="p">[],</span>
|
||
<span class="s1">'learning_rates'</span><span class="p">:</span> <span class="p">[]</span>
|
||
<span class="p">}</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">train_epoch</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataloader</span><span class="p">,</span> <span class="n">accumulation_steps</span><span class="o">=</span><span class="mi">1</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> Train for one epoch through the dataset.</span>
|
||
|
||
<span class="sd"> Supports gradient accumulation for effective larger batch sizes.</span>
|
||
<span class="sd"> """</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">training</span> <span class="o">=</span> <span class="kc">True</span>
|
||
<span class="n">total_loss</span> <span class="o">=</span> <span class="mf">0.0</span>
|
||
<span class="n">num_batches</span> <span class="o">=</span> <span class="mi">0</span>
|
||
<span class="n">accumulated_loss</span> <span class="o">=</span> <span class="mf">0.0</span>
|
||
|
||
<span class="k">for</span> <span class="n">batch_idx</span><span class="p">,</span> <span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">targets</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">dataloader</span><span class="p">):</span>
|
||
<span class="c1"># Forward pass</span>
|
||
<span class="n">outputs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span>
|
||
<span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_fn</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">outputs</span><span class="p">,</span> <span class="n">targets</span><span class="p">)</span>
|
||
|
||
<span class="c1"># Scale loss for accumulation</span>
|
||
<span class="n">scaled_loss</span> <span class="o">=</span> <span class="n">loss</span><span class="o">.</span><span class="n">data</span> <span class="o">/</span> <span class="n">accumulation_steps</span>
|
||
<span class="n">accumulated_loss</span> <span class="o">+=</span> <span class="n">scaled_loss</span>
|
||
|
||
<span class="c1"># Backward pass</span>
|
||
<span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
|
||
|
||
<span class="c1"># Update every accumulation_steps batches</span>
|
||
<span class="k">if</span> <span class="p">(</span><span class="n">batch_idx</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">%</span> <span class="n">accumulation_steps</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="c1"># Gradient clipping</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">grad_clip_norm</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">clip_grad_norm</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="bp">self</span><span class="o">.</span><span class="n">grad_clip_norm</span><span class="p">)</span>
|
||
|
||
<span class="c1"># Optimizer step</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span>
|
||
|
||
<span class="n">total_loss</span> <span class="o">+=</span> <span class="n">accumulated_loss</span>
|
||
<span class="n">accumulated_loss</span> <span class="o">=</span> <span class="mf">0.0</span>
|
||
<span class="n">num_batches</span> <span class="o">+=</span> <span class="mi">1</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">step</span> <span class="o">+=</span> <span class="mi">1</span>
|
||
|
||
<span class="c1"># Update learning rate</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">scheduler</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">current_lr</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">scheduler</span><span class="o">.</span><span class="n">get_lr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">epoch</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">lr</span> <span class="o">=</span> <span class="n">current_lr</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">history</span><span class="p">[</span><span class="s1">'learning_rates'</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">current_lr</span><span class="p">)</span>
|
||
|
||
<span class="n">avg_loss</span> <span class="o">=</span> <span class="n">total_loss</span> <span class="o">/</span> <span class="nb">max</span><span class="p">(</span><span class="n">num_batches</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">history</span><span class="p">[</span><span class="s1">'train_loss'</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">avg_loss</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">epoch</span> <span class="o">+=</span> <span class="mi">1</span>
|
||
|
||
<span class="k">return</span> <span class="n">avg_loss</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">evaluate</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataloader</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> Evaluate model without updating parameters.</span>
|
||
|
||
<span class="sd"> Sets model.training = False for proper evaluation behavior.</span>
|
||
<span class="sd"> """</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">training</span> <span class="o">=</span> <span class="kc">False</span>
|
||
<span class="n">total_loss</span> <span class="o">=</span> <span class="mf">0.0</span>
|
||
<span class="n">correct</span> <span class="o">=</span> <span class="mi">0</span>
|
||
<span class="n">total</span> <span class="o">=</span> <span class="mi">0</span>
|
||
|
||
<span class="k">for</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">targets</span> <span class="ow">in</span> <span class="n">dataloader</span><span class="p">:</span>
|
||
<span class="c1"># Forward pass only - no gradients</span>
|
||
<span class="n">outputs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span>
|
||
<span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_fn</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">outputs</span><span class="p">,</span> <span class="n">targets</span><span class="p">)</span>
|
||
<span class="n">total_loss</span> <span class="o">+=</span> <span class="n">loss</span><span class="o">.</span><span class="n">data</span>
|
||
|
||
<span class="c1"># Calculate accuracy for classification</span>
|
||
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">outputs</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">></span> <span class="mi">1</span><span class="p">:</span>
|
||
<span class="n">predictions</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">outputs</span><span class="o">.</span><span class="n">data</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">targets</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
|
||
<span class="n">correct</span> <span class="o">+=</span> <span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">predictions</span> <span class="o">==</span> <span class="n">targets</span><span class="o">.</span><span class="n">data</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">correct</span> <span class="o">+=</span> <span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">predictions</span> <span class="o">==</span> <span class="n">np</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">targets</span><span class="o">.</span><span class="n">data</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">))</span>
|
||
<span class="n">total</span> <span class="o">+=</span> <span class="nb">len</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span>
|
||
|
||
<span class="n">avg_loss</span> <span class="o">=</span> <span class="n">total_loss</span> <span class="o">/</span> <span class="nb">len</span><span class="p">(</span><span class="n">dataloader</span><span class="p">)</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">dataloader</span><span class="p">)</span> <span class="o">></span> <span class="mi">0</span> <span class="k">else</span> <span class="mf">0.0</span>
|
||
<span class="n">accuracy</span> <span class="o">=</span> <span class="n">correct</span> <span class="o">/</span> <span class="n">total</span> <span class="k">if</span> <span class="n">total</span> <span class="o">></span> <span class="mi">0</span> <span class="k">else</span> <span class="mf">0.0</span>
|
||
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">history</span><span class="p">[</span><span class="s1">'eval_loss'</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">avg_loss</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">avg_loss</span><span class="p">,</span> <span class="n">accuracy</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">save_checkpoint</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">path</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">"""Save complete training state for resumption."""</span>
|
||
<span class="n">checkpoint</span> <span class="o">=</span> <span class="p">{</span>
|
||
<span class="s1">'epoch'</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">epoch</span><span class="p">,</span>
|
||
<span class="s1">'step'</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">step</span><span class="p">,</span>
|
||
<span class="s1">'model_state'</span><span class="p">:</span> <span class="p">{</span><span class="n">i</span><span class="p">:</span> <span class="n">p</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">p</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">())},</span>
|
||
<span class="s1">'optimizer_state'</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_optimizer_state</span><span class="p">(),</span>
|
||
<span class="s1">'scheduler_state'</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_scheduler_state</span><span class="p">(),</span>
|
||
<span class="s1">'history'</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">history</span><span class="p">,</span>
|
||
<span class="s1">'training_mode'</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_mode</span>
|
||
<span class="p">}</span>
|
||
|
||
<span class="n">Path</span><span class="p">(</span><span class="n">path</span><span class="p">)</span><span class="o">.</span><span class="n">parent</span><span class="o">.</span><span class="n">mkdir</span><span class="p">(</span><span class="n">parents</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">exist_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="s1">'wb'</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
|
||
<span class="n">pickle</span><span class="o">.</span><span class="n">dump</span><span class="p">(</span><span class="n">checkpoint</span><span class="p">,</span> <span class="n">f</span><span class="p">)</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">load_checkpoint</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">path</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">"""Load training state from checkpoint."""</span>
|
||
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="s1">'rb'</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
|
||
<span class="n">checkpoint</span> <span class="o">=</span> <span class="n">pickle</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">f</span><span class="p">)</span>
|
||
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">epoch</span> <span class="o">=</span> <span class="n">checkpoint</span><span class="p">[</span><span class="s1">'epoch'</span><span class="p">]</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">step</span> <span class="o">=</span> <span class="n">checkpoint</span><span class="p">[</span><span class="s1">'step'</span><span class="p">]</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">history</span> <span class="o">=</span> <span class="n">checkpoint</span><span class="p">[</span><span class="s1">'history'</span><span class="p">]</span>
|
||
|
||
<span class="c1"># Restore model parameters</span>
|
||
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">param</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">()):</span>
|
||
<span class="k">if</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">checkpoint</span><span class="p">[</span><span class="s1">'model_state'</span><span class="p">]:</span>
|
||
<span class="n">param</span><span class="o">.</span><span class="n">data</span> <span class="o">=</span> <span class="n">checkpoint</span><span class="p">[</span><span class="s1">'model_state'</span><span class="p">][</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span>
|
||
</pre></div>
|
||
</div>
|
||
</section>
|
||
<section id="complete-training-example">
|
||
<h3>Complete Training Example<a class="headerlink" href="#complete-training-example" title="Link to this heading">#</a></h3>
|
||
<p>Bringing all components together into a complete, production-style training loop:</p>
|
||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">from</span><span class="w"> </span><span class="nn">tinytorch.core.training</span><span class="w"> </span><span class="kn">import</span> <span class="n">Trainer</span><span class="p">,</span> <span class="n">CosineSchedule</span><span class="p">,</span> <span class="n">clip_grad_norm</span>
|
||
<span class="kn">from</span><span class="w"> </span><span class="nn">tinytorch.core.layers</span><span class="w"> </span><span class="kn">import</span> <span class="n">Linear</span>
|
||
<span class="kn">from</span><span class="w"> </span><span class="nn">tinytorch.core.activations</span><span class="w"> </span><span class="kn">import</span> <span class="n">ReLU</span>
|
||
<span class="kn">from</span><span class="w"> </span><span class="nn">tinytorch.core.losses</span><span class="w"> </span><span class="kn">import</span> <span class="n">MSELoss</span>
|
||
<span class="kn">from</span><span class="w"> </span><span class="nn">tinytorch.core.optimizers</span><span class="w"> </span><span class="kn">import</span> <span class="n">SGD</span>
|
||
|
||
<span class="c1"># Build model</span>
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">SimpleNN</span><span class="p">:</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">layer1</span> <span class="o">=</span> <span class="n">Linear</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">5</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">relu</span> <span class="o">=</span> <span class="n">ReLU</span><span class="p">()</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">layer2</span> <span class="o">=</span> <span class="n">Linear</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">training</span> <span class="o">=</span> <span class="kc">True</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
|
||
<span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer1</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
|
||
<span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">relu</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="c1"># ReLU layer keeps layer1 connected to the autograd graph</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer2</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">parameters</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer1</span><span class="o">.</span><span class="n">parameters</span><span class="p">()</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer2</span><span class="o">.</span><span class="n">parameters</span><span class="p">()</span>
|
||
|
||
<span class="c1"># Configure training</span>
|
||
<span class="n">model</span> <span class="o">=</span> <span class="n">SimpleNN</span><span class="p">()</span>
|
||
<span class="n">optimizer</span> <span class="o">=</span> <span class="n">SGD</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span> <span class="n">momentum</span><span class="o">=</span><span class="mf">0.9</span><span class="p">)</span>
|
||
<span class="n">loss_fn</span> <span class="o">=</span> <span class="n">MSELoss</span><span class="p">()</span>
|
||
<span class="n">scheduler</span> <span class="o">=</span> <span class="n">CosineSchedule</span><span class="p">(</span><span class="n">max_lr</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span> <span class="n">min_lr</span><span class="o">=</span><span class="mf">0.001</span><span class="p">,</span> <span class="n">total_epochs</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span>
|
||
|
||
<span class="c1"># Create trainer with gradient clipping</span>
|
||
<span class="n">trainer</span> <span class="o">=</span> <span class="n">Trainer</span><span class="p">(</span>
|
||
<span class="n">model</span><span class="o">=</span><span class="n">model</span><span class="p">,</span>
|
||
<span class="n">optimizer</span><span class="o">=</span><span class="n">optimizer</span><span class="p">,</span>
|
||
<span class="n">loss_fn</span><span class="o">=</span><span class="n">loss_fn</span><span class="p">,</span>
|
||
<span class="n">scheduler</span><span class="o">=</span><span class="n">scheduler</span><span class="p">,</span>
|
||
<span class="n">grad_clip_norm</span><span class="o">=</span><span class="mf">1.0</span> <span class="c1"># Prevent exploding gradients</span>
|
||
<span class="p">)</span>
|
||
|
||
<span class="c1"># Train for multiple epochs</span>
|
||
<span class="k">for</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">10</span><span class="p">):</span>
|
||
<span class="n">train_loss</span> <span class="o">=</span> <span class="n">trainer</span><span class="o">.</span><span class="n">train_epoch</span><span class="p">(</span><span class="n">train_data</span><span class="p">)</span>
|
||
<span class="n">eval_loss</span><span class="p">,</span> <span class="n">accuracy</span> <span class="o">=</span> <span class="n">trainer</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">val_data</span><span class="p">)</span>
|
||
|
||
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Epoch </span><span class="si">{</span><span class="n">epoch</span><span class="si">}</span><span class="s2">: train_loss=</span><span class="si">{</span><span class="n">train_loss</span><span class="si">:</span><span class="s2">.4f</span><span class="si">}</span><span class="s2">, "</span>
|
||
<span class="sa">f</span><span class="s2">"eval_loss=</span><span class="si">{</span><span class="n">eval_loss</span><span class="si">:</span><span class="s2">.4f</span><span class="si">}</span><span class="s2">, accuracy=</span><span class="si">{</span><span class="n">accuracy</span><span class="si">:</span><span class="s2">.4f</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||
|
||
<span class="c1"># Save checkpoint periodically</span>
|
||
<span class="k">if</span> <span class="n">epoch</span> <span class="o">%</span> <span class="mi">5</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="n">trainer</span><span class="o">.</span><span class="n">save_checkpoint</span><span class="p">(</span><span class="sa">f</span><span class="s1">'checkpoint_epoch_</span><span class="si">{</span><span class="n">epoch</span><span class="si">}</span><span class="s1">.pkl'</span><span class="p">)</span>
|
||
|
||
<span class="c1"># Restore from checkpoint</span>
|
||
<span class="n">trainer</span><span class="o">.</span><span class="n">load_checkpoint</span><span class="p">(</span><span class="s1">'checkpoint_epoch_5.pkl'</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Resumed training from epoch </span><span class="si">{</span><span class="n">trainer</span><span class="o">.</span><span class="n">epoch</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||
</pre></div>
|
||
</div>
|
||
</section>
|
||
</section>
|
||
<section id="getting-started">
|
||
<h2>Getting Started<a class="headerlink" href="#getting-started" title="Link to this heading">#</a></h2>
|
||
<section id="prerequisites">
|
||
<h3>Prerequisites<a class="headerlink" href="#prerequisites" title="Link to this heading">#</a></h3>
|
||
<p>Ensure you have completed all Foundation tier modules:</p>
|
||
<div class="highlight-bash notranslate"><div class="highlight"><pre><span></span><span class="c1"># Activate TinyTorch environment</span>
|
||
<span class="nb">source</span><span class="w"> </span>scripts/activate-tinytorch
|
||
|
||
<span class="c1"># Verify all prerequisites (Training is the Foundation capstone!)</span>
|
||
tito<span class="w"> </span><span class="nb">test</span><span class="w"> </span>tensor<span class="w"> </span><span class="c1"># Module 01: Tensor operations</span>
|
||
tito<span class="w"> </span><span class="nb">test</span><span class="w"> </span>activations<span class="w"> </span><span class="c1"># Module 02: Activation functions</span>
|
||
tito<span class="w"> </span><span class="nb">test</span><span class="w"> </span>layers<span class="w"> </span><span class="c1"># Module 03: Neural network layers</span>
|
||
tito<span class="w"> </span><span class="nb">test</span><span class="w"> </span>losses<span class="w"> </span><span class="c1"># Module 04: Loss functions</span>
|
||
tito<span class="w"> </span><span class="nb">test</span><span class="w"> </span>autograd<span class="w"> </span><span class="c1"># Module 05: Automatic differentiation</span>
|
||
tito<span class="w"> </span><span class="nb">test</span><span class="w"> </span>optimizers<span class="w"> </span><span class="c1"># Module 06: Parameter update algorithms</span>
|
||
</pre></div>
|
||
</div>
|
||
</section>
|
||
<section id="development-workflow">
|
||
<h3>Development Workflow<a class="headerlink" href="#development-workflow" title="Link to this heading">#</a></h3>
|
||
<ol class="arabic simple">
|
||
<li><p><strong>Open the development file</strong>: <code class="docutils literal notranslate"><span class="pre">modules/07_training/training.py</span></code></p></li>
|
||
<li><p><strong>Implement CosineSchedule</strong>: Build learning rate scheduler with cosine annealing (smooth max_lr → min_lr transition)</p></li>
|
||
<li><p><strong>Create clip_grad_norm</strong>: Implement global norm gradient clipping to prevent exploding gradients</p></li>
|
||
<li><p><strong>Build Trainer class</strong>: Orchestrate complete training loop with train_epoch(), evaluate(), and checkpointing</p></li>
|
||
<li><p><strong>Add gradient accumulation</strong>: Support larger effective batch sizes within limited memory</p></li>
|
||
<li><p><strong>Test end-to-end training</strong>: Validate complete pipeline with real models and data</p></li>
|
||
<li><p><strong>Export and verify</strong>: <code class="docutils literal notranslate"><span class="pre">tito</span> <span class="pre">module</span> <span class="pre">complete</span> <span class="pre">07</span> <span class="pre">&&</span> <span class="pre">tito</span> <span class="pre">test</span> <span class="pre">training</span></code></p></li>
|
||
</ol>
|
||
</section>
|
||
</section>
|
||
<section id="testing">
|
||
<h2>Testing<a class="headerlink" href="#testing" title="Link to this heading">#</a></h2>
|
||
<section id="comprehensive-test-suite">
|
||
<h3>Comprehensive Test Suite<a class="headerlink" href="#comprehensive-test-suite" title="Link to this heading">#</a></h3>
|
||
<p>Run the full test suite to verify complete training infrastructure:</p>
|
||
<div class="highlight-bash notranslate"><div class="highlight"><pre><span></span><span class="c1"># TinyTorch CLI (recommended)</span>
|
||
tito<span class="w"> </span><span class="nb">test</span><span class="w"> </span>training
|
||
|
||
<span class="c1"># Direct pytest execution</span>
|
||
python<span class="w"> </span>-m<span class="w"> </span>pytest<span class="w"> </span>tests/<span class="w"> </span>-k<span class="w"> </span>training<span class="w"> </span>-v
|
||
</pre></div>
|
||
</div>
|
||
</section>
|
||
<section id="test-coverage-areas">
|
||
<h3>Test Coverage Areas<a class="headerlink" href="#test-coverage-areas" title="Link to this heading">#</a></h3>
|
||
<ul class="simple">
|
||
<li><p><strong>CosineSchedule Correctness</strong>: Verify cosine annealing produces correct learning rates at start, middle, and end epochs</p></li>
|
||
<li><p><strong>Gradient Clipping Stability</strong>: Test global norm computation and uniform scaling when gradients exceed threshold</p></li>
|
||
<li><p><strong>Trainer Orchestration</strong>: Ensure proper coordination of forward pass, backward pass, optimization, and scheduling</p></li>
|
||
<li><p><strong>Checkpointing Completeness</strong>: Validate save/load preserves model state, optimizer buffers, scheduler state, and training history</p></li>
|
||
<li><p><strong>Memory Analysis</strong>: Measure training memory overhead (parameters + gradients + optimizer state ≈ 3× parameter memory for SGD with momentum, 4× for Adam, before counting activations)</p></li>
|
||
</ul>
|
||
</section>
|
||
<section id="inline-testing-training-analysis">
|
||
<h3>Inline Testing & Training Analysis<a class="headerlink" href="#inline-testing-training-analysis" title="Link to this heading">#</a></h3>
|
||
<p>The module includes comprehensive validation of training dynamics:</p>
|
||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="c1"># CosineSchedule validation</span>
|
||
<span class="n">schedule</span> <span class="o">=</span> <span class="n">CosineSchedule</span><span class="p">(</span><span class="n">max_lr</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span> <span class="n">min_lr</span><span class="o">=</span><span class="mf">0.01</span><span class="p">,</span> <span class="n">total_epochs</span><span class="o">=</span><span class="mi">100</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">schedule</span><span class="o">.</span><span class="n">get_lr</span><span class="p">(</span><span class="mi">0</span><span class="p">))</span> <span class="c1"># 0.1 - aggressive learning initially</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">schedule</span><span class="o">.</span><span class="n">get_lr</span><span class="p">(</span><span class="mi">50</span><span class="p">))</span> <span class="c1"># ~0.055 - gradual slowdown</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="n">schedule</span><span class="o">.</span><span class="n">get_lr</span><span class="p">(</span><span class="mi">100</span><span class="p">))</span> <span class="c1"># 0.01 - fine-tuning at end</span>
|
||
|
||
<span class="c1"># Gradient clipping validation</span>
|
||
<span class="n">param</span><span class="o">.</span><span class="n">grad</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mf">100.0</span><span class="p">,</span> <span class="mf">200.0</span><span class="p">])</span> <span class="c1"># Large gradients</span>
|
||
<span class="n">original_norm</span> <span class="o">=</span> <span class="n">clip_grad_norm</span><span class="p">([</span><span class="n">param</span><span class="p">],</span> <span class="n">max_norm</span><span class="o">=</span><span class="mf">1.0</span><span class="p">)</span>
|
||
<span class="c1"># original_norm ≈ 223.6 → clipped to 1.0</span>
|
||
<span class="k">assert</span> <span class="n">np</span><span class="o">.</span><span class="n">isclose</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">linalg</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">param</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">data</span><span class="p">),</span> <span class="mf">1.0</span><span class="p">)</span>
|
||
|
||
<span class="c1"># Trainer integration validation</span>
|
||
<span class="n">trainer</span> <span class="o">=</span> <span class="n">Trainer</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">optimizer</span><span class="p">,</span> <span class="n">loss_fn</span><span class="p">,</span> <span class="n">scheduler</span><span class="p">,</span> <span class="n">grad_clip_norm</span><span class="o">=</span><span class="mf">1.0</span><span class="p">)</span>
|
||
<span class="n">loss</span> <span class="o">=</span> <span class="n">trainer</span><span class="o">.</span><span class="n">train_epoch</span><span class="p">(</span><span class="n">train_data</span><span class="p">)</span>
|
||
<span class="n">eval_loss</span><span class="p">,</span> <span class="n">accuracy</span> <span class="o">=</span> <span class="n">trainer</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">test_data</span><span class="p">)</span>
|
||
<span class="n">trainer</span><span class="o">.</span><span class="n">save_checkpoint</span><span class="p">(</span><span class="s1">'checkpoint.pkl'</span><span class="p">)</span>
|
||
</pre></div>
|
||
</div>
|
||
</section>
|
||
<section id="manual-testing-examples">
|
||
<h3>Manual Testing Examples<a class="headerlink" href="#manual-testing-examples" title="Link to this heading">#</a></h3>
|
||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">from</span><span class="w"> </span><span class="nn">training</span><span class="w"> </span><span class="kn">import</span> <span class="n">Trainer</span><span class="p">,</span> <span class="n">CosineSchedule</span><span class="p">,</span> <span class="n">clip_grad_norm</span>
|
||
<span class="kn">from</span><span class="w"> </span><span class="nn">layers</span><span class="w"> </span><span class="kn">import</span> <span class="n">Linear</span>
|
||
<span class="kn">from</span><span class="w"> </span><span class="nn">losses</span><span class="w"> </span><span class="kn">import</span> <span class="n">MSELoss</span>
|
||
<span class="kn">from</span><span class="w"> </span><span class="nn">optimizers</span><span class="w"> </span><span class="kn">import</span> <span class="n">SGD</span>
|
||
<span class="kn">from</span><span class="w"> </span><span class="nn">tensor</span><span class="w"> </span><span class="kn">import</span> <span class="n">Tensor</span>
|
||
|
||
<span class="c1"># Test complete training pipeline</span>
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">SimpleModel</span><span class="p">:</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">layer</span> <span class="o">=</span> <span class="n">Linear</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">training</span> <span class="o">=</span> <span class="kc">True</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">parameters</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer</span><span class="o">.</span><span class="n">parameters</span><span class="p">()</span>
|
||
|
||
<span class="c1"># Create training system</span>
|
||
<span class="n">model</span> <span class="o">=</span> <span class="n">SimpleModel</span><span class="p">()</span>
|
||
<span class="n">optimizer</span> <span class="o">=</span> <span class="n">SGD</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="mf">0.01</span><span class="p">)</span>
|
||
<span class="n">loss_fn</span> <span class="o">=</span> <span class="n">MSELoss</span><span class="p">()</span>
|
||
<span class="n">scheduler</span> <span class="o">=</span> <span class="n">CosineSchedule</span><span class="p">(</span><span class="n">max_lr</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span> <span class="n">min_lr</span><span class="o">=</span><span class="mf">0.01</span><span class="p">,</span> <span class="n">total_epochs</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span>
|
||
|
||
<span class="n">trainer</span> <span class="o">=</span> <span class="n">Trainer</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">optimizer</span><span class="p">,</span> <span class="n">loss_fn</span><span class="p">,</span> <span class="n">scheduler</span><span class="p">,</span> <span class="n">grad_clip_norm</span><span class="o">=</span><span class="mf">1.0</span><span class="p">)</span>
|
||
|
||
<span class="c1"># Create simple dataset</span>
|
||
<span class="n">train_data</span> <span class="o">=</span> <span class="p">[</span>
|
||
<span class="p">(</span><span class="n">Tensor</span><span class="p">([[</span><span class="mf">1.0</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">]]),</span> <span class="n">Tensor</span><span class="p">([[</span><span class="mf">2.0</span><span class="p">]])),</span>
|
||
<span class="p">(</span><span class="n">Tensor</span><span class="p">([[</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">]]),</span> <span class="n">Tensor</span><span class="p">([[</span><span class="mf">1.5</span><span class="p">]]))</span>
|
||
<span class="p">]</span>
|
||
|
||
<span class="c1"># Train and monitor</span>
|
||
<span class="k">for</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">10</span><span class="p">):</span>
|
||
<span class="n">loss</span> <span class="o">=</span> <span class="n">trainer</span><span class="o">.</span><span class="n">train_epoch</span><span class="p">(</span><span class="n">train_data</span><span class="p">)</span>
|
||
<span class="n">lr</span> <span class="o">=</span> <span class="n">scheduler</span><span class="o">.</span><span class="n">get_lr</span><span class="p">(</span><span class="n">epoch</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Epoch </span><span class="si">{</span><span class="n">epoch</span><span class="si">}</span><span class="s2">: loss=</span><span class="si">{</span><span class="n">loss</span><span class="si">:</span><span class="s2">.4f</span><span class="si">}</span><span class="s2">, lr=</span><span class="si">{</span><span class="n">lr</span><span class="si">:</span><span class="s2">.4f</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||
|
||
<span class="c1"># Test checkpointing</span>
|
||
<span class="n">trainer</span><span class="o">.</span><span class="n">save_checkpoint</span><span class="p">(</span><span class="s1">'test_checkpoint.pkl'</span><span class="p">)</span>
|
||
<span class="n">trainer</span><span class="o">.</span><span class="n">load_checkpoint</span><span class="p">(</span><span class="s1">'test_checkpoint.pkl'</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Restored from epoch </span><span class="si">{</span><span class="n">trainer</span><span class="o">.</span><span class="n">epoch</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||
</pre></div>
|
||
</div>
|
||
</section>
|
||
</section>
|
||
<section id="systems-thinking-questions">
|
||
<h2>Systems Thinking Questions<a class="headerlink" href="#systems-thinking-questions" title="Link to this heading">#</a></h2>
|
||
<section id="real-world-applications">
|
||
<h3>Real-World Applications<a class="headerlink" href="#real-world-applications" title="Link to this heading">#</a></h3>
|
||
<ul class="simple">
|
||
<li><p><strong>Production Training Pipelines</strong>: PyTorch Lightning, Hugging Face Transformers, TensorFlow Estimators all use similar Trainer architectures with checkpointing and scheduling</p></li>
|
||
<li><p><strong>Large-Scale Model Training</strong>: GPT, BERT, and vision models rely on gradient clipping and learning rate scheduling for stable convergence across billions of parameters</p></li>
|
||
<li><p><strong>Research Experimentation</strong>: Academic ML uses checkpointing for long experiments with periodic evaluation and model selection</p></li>
|
||
<li><p><strong>Fault-Tolerant Training</strong>: Cloud training systems use checkpoints to resume after infrastructure failures or spot instance interruptions</p></li>
|
||
</ul>
|
||
</section>
|
||
<section id="training-system-architecture">
|
||
<h3>Training System Architecture<a class="headerlink" href="#training-system-architecture" title="Link to this heading">#</a></h3>
|
||
<ul class="simple">
|
||
<li><p><strong>Memory Breakdown</strong>: Training requires parameters (1×) + gradients (1×) + optimizer state (1× for SGD momentum, 2× for Adam's m and v) = 3-4× the parameter memory footprint, before counting activations</p></li>
|
||
<li><p><strong>Gradient Accumulation</strong>: Enables effective batch size of accumulation_steps × actual_batch_size with fixed memory—trades time for memory efficiency</p></li>
|
||
<li><p><strong>Train/Eval Modes</strong>: Different layer behaviors during training (dropout active, batch norm updates) vs evaluation (dropout off, fixed batch norm)</p></li>
|
||
<li><p><strong>Checkpoint Components</strong>: Must save model parameters, optimizer buffers (momentum, Adam m/v), scheduler state, epoch counter, and training history for exact resumption</p></li>
|
||
</ul>
|
||
</section>
|
||
<section id="training-dynamics">
|
||
<h3>Training Dynamics<a class="headerlink" href="#training-dynamics" title="Link to this heading">#</a></h3>
|
||
<ul class="simple">
|
||
<li><p><strong>Learning Rate Scheduling</strong>: Cosine annealing starts with a high learning rate (large steps for fast progress when far from the optimum) and smoothly decays it toward a minimum (small steps for stable fine-tuning near the solution)</p></li>
|
||
<li><p><strong>Exploding Gradients</strong>: Occur in deep networks and RNNs when gradient magnitudes grow exponentially through backpropagation—gradient clipping prevents training collapse</p></li>
|
||
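<li><p><strong>Gradient Accumulation Trade-offs</strong>: Reduces peak memory by processing small micro-batches, but each optimizer update then requires accumulation_steps forward/backward passes, so wall-clock time per update grows roughly linearly with the number of accumulation steps</p></li>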
|
||
<li><p><strong>Checkpointing Strategy</strong>: Balance disk space (checkpoints of large models can exceed 1GB each) against fault tolerance (more frequent saves = less lost work), and tie saves to periodic evaluation so the best model is kept</p></li>
|
||
</ul>
|
||
</section>
|
||
<section id="performance-characteristics">
|
||
<h3>Performance Characteristics<a class="headerlink" href="#performance-characteristics" title="Link to this heading">#</a></h3>
|
||
<ul class="simple">
|
||
<li><p><strong>Training Memory Scaling</strong>: Adam optimizer uses 4× parameter memory (params + grads + m + v) vs SGD with momentum at 3× (params + grads + momentum)</p></li>
|
||
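<li><p><strong>Checkpoint Overhead</strong>: Checkpoints are larger than the raw weights—optimizer buffers such as Adam’s m and v add another 2× parameter size, and pickle serialization adds object metadata on top—so consider compression for large models, and only load pickled checkpoints from trusted sources, since unpickling can execute arbitrary code</p></li>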
|
||
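<li><p><strong>Learning Rate Impact</strong>: Too high causes instability/divergence, too low causes slow convergence—a schedule varies the rate over the course of training (high early, low late) so a single fixed value doesn’t have to serve both phases</p></li>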
|
||
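<li><p><strong>Global Norm vs Individual Clipping</strong>: Global-norm clipping scales every gradient by the same factor, preserving the overall update direction while bounding its size—clipping each parameter independently changes their relative magnitudes and can distort the optimization trajectory</p></li>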
|
||
</ul>
|
||
</section>
|
||
</section>
|
||
<section id="ready-to-build">
|
||
<h2>Ready to Build?<a class="headerlink" href="#ready-to-build" title="Link to this heading">#</a></h2>
|
||
<p>You’re about to complete the Foundation tier by building the training infrastructure that brings neural networks to life! This is where all your work on tensors, activations, layers, losses, gradients, and optimizers comes together into a cohesive system that actually learns from data.</p>
|
||
<p>Training is the heart of machine learning—the process that transforms random initialization into intelligent models. You’re implementing the same patterns used to train GPT, BERT, ResNet, and every production AI system. Understanding how scheduling, gradient clipping, checkpointing, and mode switching work together gives you mastery over the training process.</p>
|
||
<p>This module is the culmination of everything you’ve built. Take your time understanding how each piece fits into the bigger picture, and enjoy creating a complete ML training system from scratch!</p>
|
||
<p>Choose your preferred way to engage with this module:</p>
|
||
<div class="sd-container-fluid sd-sphinx-override sd-mb-4 docutils">
|
||
<div class="sd-row sd-row-cols-1 sd-row-cols-xs-1 sd-row-cols-sm-2 sd-row-cols-md-3 sd-row-cols-lg-3 docutils">
|
||
<div class="sd-col sd-d-flex-row docutils">
|
||
<div class="sd-card sd-sphinx-override sd-w-100 sd-shadow-sm sd-card-hover docutils">
|
||
<div class="sd-card-body docutils">
|
||
<div class="sd-card-title sd-font-weight-bold docutils">
|
||
🚀 Launch Binder</div>
|
||
<p class="sd-card-text">Run this module interactively in your browser. No installation required!</p>
|
||
</div>
|
||
<a class="sd-stretched-link sd-hide-link-text reference external" href="https://mybinder.org/v2/gh/mlsysbook/TinyTorch/main?filepath=modules/07_training/training.ipynb"><span>https://mybinder.org/v2/gh/mlsysbook/TinyTorch/main?filepath=modules/07_training/training.ipynb</span></a></div>
|
||
</div>
|
||
<div class="sd-col sd-d-flex-row docutils">
|
||
<div class="sd-card sd-sphinx-override sd-w-100 sd-shadow-sm sd-card-hover docutils">
|
||
<div class="sd-card-body docutils">
|
||
<div class="sd-card-title sd-font-weight-bold docutils">
|
||
⚡ Open in Colab</div>
|
||
<p class="sd-card-text">Use Google Colab for GPU access and cloud compute power.</p>
|
||
</div>
|
||
<a class="sd-stretched-link sd-hide-link-text reference external" href="https://colab.research.google.com/github/mlsysbook/TinyTorch/blob/main/modules/07_training/training.ipynb"><span>https://colab.research.google.com/github/mlsysbook/TinyTorch/blob/main/modules/07_training/training.ipynb</span></a></div>
|
||
</div>
|
||
<div class="sd-col sd-d-flex-row docutils">
|
||
<div class="sd-card sd-sphinx-override sd-w-100 sd-shadow-sm sd-card-hover docutils">
|
||
<div class="sd-card-body docutils">
|
||
<div class="sd-card-title sd-font-weight-bold docutils">
|
||
📖 View Source</div>
|
||
<p class="sd-card-text">Browse the Python source code and understand the implementation.</p>
|
||
</div>
|
||
<a class="sd-stretched-link sd-hide-link-text reference external" href="https://github.com/mlsysbook/TinyTorch/blob/main/modules/07_training/training.py"><span>https://github.com/mlsysbook/TinyTorch/blob/main/modules/07_training/training.py</span></a></div>
|
||
</div>
|
||
</div>
|
||
</div>
|
||
<div class="tip admonition">
|
||
<p class="admonition-title">💾 Save Your Progress</p>
|
||
<p><strong>Binder sessions are temporary!</strong> Download your completed notebook when done, or switch to local development for persistent work.</p>
|
||
</div>
|
||
<hr class="docutils" />
|
||
<div class="prev-next-area">
|
||
<a class="left-prev" href="../modules/06_optimizers_ABOUT.html" title="previous page">← Previous Module</a>
|
||
<a class="right-next" href="../modules/08_dataloader_ABOUT.html" title="next page">Next Module →</a>
|
||
</div>
|
||
</section>
|
||
</section>
|
||
|
||
<script type="text/x-thebe-config">
|
||
{
|
||
requestKernel: true,
|
||
binderOptions: {
|
||
repo: "binder-examples/jupyter-stacks-datascience",
|
||
ref: "master",
|
||
},
|
||
codeMirrorConfig: {
|
||
theme: "abcdef",
|
||
mode: "python"
|
||
},
|
||
kernelOptions: {
|
||
name: "python3",
|
||
path: "./modules"
|
||
},
|
||
predefinedOutput: true
|
||
}
|
||
</script>
|
||
<script>kernelName = 'python3'</script>
|
||
|
||
</article>
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<footer class="prev-next-footer d-print-none">
|
||
|
||
<div class="prev-next-area">
|
||
<a class="left-prev"
|
||
href="06_optimizers_ABOUT.html"
|
||
title="previous page">
|
||
<i class="fa-solid fa-angle-left"></i>
|
||
<div class="prev-next-info">
|
||
<p class="prev-next-subtitle">previous</p>
|
||
<p class="prev-next-title">06. Optimizers</p>
|
||
</div>
|
||
</a>
|
||
<a class="right-next"
|
||
href="../tiers/architecture.html"
|
||
title="next page">
|
||
<div class="prev-next-info">
|
||
<p class="prev-next-subtitle">next</p>
|
||
<p class="prev-next-title">🏛️ Architecture Tier (Modules 08-13)</p>
|
||
</div>
|
||
<i class="fa-solid fa-angle-right"></i>
|
||
</a>
|
||
</div>
|
||
</footer>
|
||
|
||
</div>
|
||
|
||
|
||
|
||
<div class="bd-sidebar-secondary bd-toc"><div class="sidebar-secondary-items sidebar-secondary__inner">
|
||
|
||
|
||
<div class="sidebar-secondary-item">
|
||
<div class="page-toc tocsection onthispage">
|
||
<i class="fa-solid fa-list"></i> Contents
|
||
</div>
|
||
<nav class="bd-toc-nav page-toc">
|
||
<ul class="visible nav section-nav flex-column">
|
||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#overview">Overview</a></li>
|
||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#learning-objectives">Learning Objectives</a></li>
|
||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#build-use-reflect">Build → Use → Reflect</a></li>
|
||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#implementation-guide">Implementation Guide</a><ul class="nav section-nav flex-column">
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#the-training-loop-cycle">The Training Loop Cycle</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#cosineschedule-adaptive-learning-rate-management">CosineSchedule - Adaptive Learning Rate Management</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#gradient-clipping-preventing-training-explosions">Gradient Clipping - Preventing Training Explosions</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#trainer-class-complete-training-orchestration">Trainer Class - Complete Training Orchestration</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#complete-training-example">Complete Training Example</a></li>
|
||
</ul>
|
||
</li>
|
||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#getting-started">Getting Started</a><ul class="nav section-nav flex-column">
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#prerequisites">Prerequisites</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#development-workflow">Development Workflow</a></li>
|
||
</ul>
|
||
</li>
|
||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#testing">Testing</a><ul class="nav section-nav flex-column">
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#comprehensive-test-suite">Comprehensive Test Suite</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#test-coverage-areas">Test Coverage Areas</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#inline-testing-training-analysis">Inline Testing & Training Analysis</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#manual-testing-examples">Manual Testing Examples</a></li>
|
||
</ul>
|
||
</li>
|
||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#systems-thinking-questions">Systems Thinking Questions</a><ul class="nav section-nav flex-column">
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#real-world-applications">Real-World Applications</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#training-system-architecture">Training System Architecture</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#training-dynamics">Training Dynamics</a></li>
|
||
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#performance-characteristics">Performance Characteristics</a></li>
|
||
</ul>
|
||
</li>
|
||
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#ready-to-build">Ready to Build?</a></li>
|
||
</ul>
|
||
</nav></div>
|
||
|
||
</div></div>
|
||
|
||
|
||
</div>
|
||
<footer class="bd-footer-content">
|
||
|
||
<div class="bd-footer-content__inner container">
|
||
|
||
<div class="footer-item">
|
||
|
||
<p class="component-author">
|
||
By Prof. Vijay Janapa Reddi (Harvard University)
|
||
</p>
|
||
|
||
</div>
|
||
|
||
<div class="footer-item">
|
||
|
||
|
||
<p class="copyright">
|
||
|
||
© Copyright 2025.
|
||
<br/>
|
||
|
||
</p>
|
||
|
||
</div>
|
||
|
||
<div class="footer-item">
|
||
|
||
</div>
|
||
|
||
<div class="footer-item">
|
||
|
||
</div>
|
||
|
||
</div>
|
||
</footer>
|
||
|
||
|
||
</main>
|
||
</div>
|
||
</div>
|
||
|
||
<!-- Scripts loaded after <body> so the DOM is not blocked -->
|
||
<script src="../_static/scripts/bootstrap.js?digest=dfe6caa3a7d634c4db9b"></script>
|
||
<script src="../_static/scripts/pydata-sphinx-theme.js?digest=dfe6caa3a7d634c4db9b"></script>
|
||
|
||
<footer class="bd-footer">
|
||
</footer>
|
||
</body>
|
||
</html> |