Files
TinyTorch/modules/12_attention_ABOUT.html
2025-12-05 00:52:38 +00:00

1572 lines
98 KiB
HTML
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
<!DOCTYPE html>
<html lang="en" data-content_root="../" >
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<title>12. Attention - The Mechanism That Powers Modern AI &#8212; Tiny🔥Torch</title>
<script data-cfasync="false">
{
// Re-apply the reader's stored mode/theme choice to <html> data attributes.
// This runs before the theme stylesheets linked below; a missing entry becomes "".
const rootDataset = document.documentElement.dataset;
rootDataset.mode = localStorage.getItem("mode") || "";
rootDataset.theme = localStorage.getItem("theme") || "";
}
</script>
<!-- Loaded before other Sphinx assets -->
<link href="../_static/styles/theme.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" />
<link href="../_static/styles/bootstrap.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" />
<link href="../_static/styles/pydata-sphinx-theme.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" />
<link href="../_static/vendor/fontawesome/6.5.2/css/all.min.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" />
<link rel="preload" as="font" type="font/woff2" crossorigin href="../_static/vendor/fontawesome/6.5.2/webfonts/fa-solid-900.woff2" />
<link rel="preload" as="font" type="font/woff2" crossorigin href="../_static/vendor/fontawesome/6.5.2/webfonts/fa-brands-400.woff2" />
<link rel="preload" as="font" type="font/woff2" crossorigin href="../_static/vendor/fontawesome/6.5.2/webfonts/fa-regular-400.woff2" />
<link rel="stylesheet" type="text/css" href="../_static/pygments.css?v=03e43079" />
<link rel="stylesheet" type="text/css" href="../_static/styles/sphinx-book-theme.css?v=eba8b062" />
<link rel="stylesheet" type="text/css" href="../_static/togglebutton.css?v=13237357" />
<link rel="stylesheet" type="text/css" href="../_static/copybutton.css?v=76b2166b" />
<link rel="stylesheet" type="text/css" href="../_static/mystnb.8ecb98da25f57f5357bf6f572d296f466b2cfe2517ffebfabe82451661e28f02.css" />
<link rel="stylesheet" type="text/css" href="../_static/sphinx-thebe.css?v=4fa983c6" />
<link rel="stylesheet" type="text/css" href="../_static/sphinx-design.min.css?v=95c83b7e" />
<link rel="stylesheet" type="text/css" href="../_static/custom.css?v=009d37f4" />
<!-- Pre-loaded scripts that we'll load fully later -->
<link rel="preload" as="script" href="../_static/scripts/bootstrap.js?digest=dfe6caa3a7d634c4db9b" />
<link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=dfe6caa3a7d634c4db9b" />
<script src="../_static/vendor/fontawesome/6.5.2/js/all.min.js?digest=dfe6caa3a7d634c4db9b"></script>
<script src="../_static/documentation_options.js?v=9eb32ce0"></script>
<script src="../_static/doctools.js?v=9a2dae69"></script>
<script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="../_static/clipboard.min.js?v=a7894cd8"></script>
<script src="../_static/copybutton.js?v=f281be69"></script>
<script src="../_static/scripts/sphinx-book-theme.js?v=887ef09a"></script>
<!-- Inline configuration globals, presumably read by togglebutton.js and sphinx-thebe.js loaded alongside them. -->
<script>let toggleHintShow = 'Click to show';</script>
<script>let toggleHintHide = 'Click to hide';</script>
<script>let toggleOpenOnPrint = 'true';</script>
<script src="../_static/togglebutton.js?v=4a39c7ea"></script>
<script>var togglebuttonSelector = '.toggle, .admonition.dropdown';</script>
<script src="../_static/design-tabs.js?v=f930bc37"></script>
<script>const THEBE_JS_URL = "https://unpkg.com/thebe@0.8.2/lib/index.js"; const thebe_selector = ".thebe,.cell"; const thebe_selector_input = "pre"; const thebe_selector_output = ".output, .cell_output"</script>
<script async="async" src="../_static/sphinx-thebe.js?v=c100c467"></script>
<!-- The togglebuttonSelector and THEBE_* settings are defined exactly once, above. A second classic script re-declaring the same top-level const bindings throws "SyntaxError: Identifier 'THEBE_JS_URL' has already been declared", so no duplicate copy is emitted here. -->
<script type="module" src="https://cdn.jsdelivr.net/npm/mermaid@10.6.1/dist/mermaid.esm.min.mjs"></script>
<script type="module" src="https://cdn.jsdelivr.net/npm/@mermaid-js/layout-elk@0.2.0/dist/mermaid-layout-elk.esm.min.mjs"></script>
<script type="module">
// Register the ELK layout engine with Mermaid and turn off automatic rendering;
// diagrams are rendered explicitly by mermaid.run() in the page's load handler.
import mermaid from "https://cdn.jsdelivr.net/npm/mermaid@10.6.1/dist/mermaid.esm.min.mjs";
import elkLayouts from "https://cdn.jsdelivr.net/npm/@mermaid-js/layout-elk@0.2.0/dist/mermaid-layout-elk.esm.min.mjs";
mermaid.registerLayoutLoaders(elkLayouts);
mermaid.initialize({ startOnLoad: false });
</script>
<script src="https://cdn.jsdelivr.net/npm/d3@7.9.0/dist/d3.min.js"></script>
<script type="module">import mermaid from "https://cdn.jsdelivr.net/npm/mermaid@10.6.1/dist/mermaid.esm.min.mjs";
// Base sizing for Mermaid diagrams rendered inline in the page: the <pre> spans
// the full width and each diagram's SVG gets a fixed 500px height.
const defaultStyle = Object.assign(document.createElement('style'), {
    textContent: `pre.mermaid {
/* Same as .mermaid-container > pre */
display: block;
width: 100%;
}
pre.mermaid > svg {
/* Same as .mermaid-container > pre > svg */
height: 500px;
width: 100%;
max-width: 100% !important;
}
`,
});
document.head.append(defaultStyle);
// Stylesheet for the fullscreen diagram viewer. It covers:
// - the per-diagram wrapper (.mermaid-container) and its toggle button (.mermaid-fullscreen-btn);
// - the modal overlay (.mermaid-fullscreen-modal, shown while it has the `active` class);
// - the modal's content box and the close button;
// - `dark-theme` variants of each, toggled by the load handler.
// NOTE(review): the load handler sets inline `top`/`right` offsets on .mermaid-fullscreen-btn.
// Nothing in this sheet gives the button or .mermaid-container a `position`, so those offsets
// have no effect unless another stylesheet positions them. Confirm against custom.css.
const fullscreenStyle = Object.assign(document.createElement('style'), {
    textContent: `.mermaid-container {
display: flex;
flex-direction: row;
width: 100%;
}
.mermaid-container > pre {
display: block;
width: 100%;
}
.mermaid-container > pre > svg {
height: 500px;
width: 100%;
max-width: 100% !important;
}
.mermaid-fullscreen-btn {
width: 28px;
height: 28px;
background: rgba(255, 255, 255, 0.95);
border: 1px solid rgba(0, 0, 0, 0.3);
border-radius: 4px;
cursor: pointer;
display: flex;
align-items: center;
justify-content: center;
transition: all 0.2s;
box-shadow: 0 2px 6px rgba(0, 0, 0, 0.2);
font-size: 14px;
line-height: 1;
padding: 0;
color: #333;
}
.mermaid-fullscreen-btn:hover {
opacity: 100% !important;
background: rgba(255, 255, 255, 1);
box-shadow: 0 3px 10px rgba(0, 0, 0, 0.3);
transform: scale(1.1);
}
.mermaid-fullscreen-btn.dark-theme {
background: rgba(50, 50, 50, 0.95);
border: 1px solid rgba(255, 255, 255, 0.3);
color: #e0e0e0;
}
.mermaid-fullscreen-btn.dark-theme:hover {
background: rgba(60, 60, 60, 1);
box-shadow: 0 3px 10px rgba(255, 255, 255, 0.2);
}
.mermaid-fullscreen-modal {
display: none;
position: fixed !important;
top: 0 !important;
left: 0 !important;
width: 95vw;
height: 100vh;
background: rgba(255, 255, 255, 0.98);
z-index: 9999;
padding: 20px;
overflow: auto;
}
.mermaid-fullscreen-modal.dark-theme {
background: rgba(0, 0, 0, 0.98);
}
.mermaid-fullscreen-modal.active {
display: flex;
align-items: center;
justify-content: center;
}
.mermaid-container-fullscreen {
position: relative;
width: 95vw;
height: 90vh;
max-width: 95vw;
max-height: 90vh;
background: white;
border-radius: 8px;
padding: 20px;
box-shadow: 0 10px 40px rgba(0, 0, 0, 0.3);
overflow: auto;
display: flex;
align-items: center;
justify-content: center;
}
.mermaid-container-fullscreen.dark-theme {
background: #1a1a1a;
box-shadow: 0 10px 40px rgba(0, 0, 0, 0.8);
}
.mermaid-container-fullscreen pre.mermaid {
width: 100%;
height: 100%;
display: flex;
align-items: center;
justify-content: center;
}
.mermaid-container-fullscreen .mermaid svg {
height: 100% !important;
width: 100% !important;
cursor: grab;
}
.mermaid-fullscreen-close {
position: fixed !important;
top: 20px !important;
right: 20px !important;
width: 40px;
height: 40px;
background: rgba(255, 255, 255, 0.95);
border: 1px solid rgba(0, 0, 0, 0.2);
border-radius: 50%;
cursor: pointer;
z-index: 10000;
display: flex;
align-items: center;
justify-content: center;
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.3);
transition: all 0.2s;
font-size: 24px;
line-height: 1;
color: #333;
}
.mermaid-fullscreen-close:hover {
background: white;
box-shadow: 0 6px 16px rgba(0, 0, 0, 0.4);
transform: scale(1.1);
}
.mermaid-fullscreen-close.dark-theme {
background: rgba(50, 50, 50, 0.95);
border: 1px solid rgba(255, 255, 255, 0.2);
color: #e0e0e0;
}
.mermaid-fullscreen-close.dark-theme:hover {
background: rgba(60, 60, 60, 1);
box-shadow: 0 6px 16px rgba(255, 255, 255, 0.2);
}
.mermaid-fullscreen-modal .mermaid-fullscreen-btn {
display: none !important;
}`,
});
document.head.append(fullscreenStyle);
// Detect if page has dark background
const isDarkTheme = () => {
const bgColor = window.getComputedStyle(document.body).backgroundColor;
const match = bgColor.match(/rgb\((\d+),\s*(\d+),\s*(\d+)/);
if (match) {
const r = parseInt(match[1]);
const g = parseInt(match[2]);
const b = parseInt(match[3]);
const brightness = (r * 299 + g * 587 + b * 114) / 1000;
return brightness < 128;
}
return false;
};
// Render every Mermaid diagram, then, once Mermaid has marked all of them processed,
// wrap each diagram in a `.mermaid-container` with a fullscreen button.
// Clicking a button clones its diagram into one shared modal dialog. The dialog closes
// via its close button, a click on the backdrop, or Escape; on close, the page scroll
// position and keyboard focus are restored.
// The literal comparisons ("False" === "True", "True" !== "True") look like build-time
// template substitutions: d3 zoom is disabled and fullscreen support is enabled.
const load = async () => {
    await mermaid.run();
    const all_mermaids = document.querySelectorAll(".mermaid");
    const mermaids_processed = document.querySelectorAll(".mermaid[data-processed='true']");
    if ("False" === "True") {
        // Optional d3 zoom for inline diagrams (constant-false here, never taken).
        const mermaids_to_add_zoom = -1 === -1 ? all_mermaids.length : -1;
        if(mermaids_to_add_zoom > 0) {
            var svgs = d3.selectAll("");
            if(all_mermaids.length !== mermaids_processed.length) {
                setTimeout(load, 200);
                return;
            } else if(svgs.size() !== mermaids_to_add_zoom) {
                setTimeout(load, 200);
                return;
            } else {
                svgs.each(function() {
                    var svg = d3.select(this);
                    svg.html("<g class='wrapper'>" + svg.html() + "</g>");
                    var inner = svg.select("g");
                    var zoom = d3.zoom().on("zoom", function(event) {
                        inner.attr("transform", event.transform);
                    });
                    svg.call(zoom);
                });
            }
        }
    } else if(all_mermaids.length !== mermaids_processed.length) {
        // Wait for mermaid to process all diagrams
        setTimeout(load, 200);
        return;
    }
    const darkTheme = isDarkTheme();
    // Stop here if not adding fullscreen capability
    if ("True" !== "True") return;
    const modal = document.createElement('div');
    modal.className = 'mermaid-fullscreen-modal' + (darkTheme ? ' dark-theme' : '');
    modal.setAttribute('role', 'dialog');
    modal.setAttribute('aria-modal', 'true');
    modal.setAttribute('aria-label', 'Fullscreen diagram viewer');
    modal.innerHTML = `
<button class="mermaid-fullscreen-close${darkTheme ? ' dark-theme' : ''}" aria-label="Close fullscreen">✕</button>
<div class="mermaid-container-fullscreen${darkTheme ? ' dark-theme' : ''}"></div>
`;
    document.body.appendChild(modal);
    const modalContent = modal.querySelector('.mermaid-container-fullscreen');
    const closeBtn = modal.querySelector('.mermaid-fullscreen-close');
    // Page scroll position captured when the modal opens; restored when it closes.
    let previousScrollOffset = [window.scrollX, window.scrollY];
    // Button that opened the modal; keyboard focus returns to it on close.
    let openedFrom = null;
    const closeModal = () => {
        modal.classList.remove('active');
        modalContent.innerHTML = '';
        document.body.style.overflow = '';
        window.scrollTo({left: previousScrollOffset[0], top: previousScrollOffset[1], behavior: 'instant'});
        if (openedFrom) {
            openedFrom.focus({preventScroll: true});
            openedFrom = null;
        }
    };
    closeBtn.addEventListener('click', closeModal);
    modal.addEventListener('click', (e) => {
        // Only a click on the backdrop itself closes the modal, not clicks on the diagram.
        if (e.target === modal) closeModal();
    });
    document.addEventListener('keydown', (e) => {
        if (e.key === 'Escape' && modal.classList.contains('active')) {
            closeModal();
        }
    });
    const allButtons = [];
    document.querySelectorAll('.mermaid').forEach((mermaidDiv) => {
        // Skip diagrams already wrapped in a container and the clone shown inside the modal.
        if (mermaidDiv.parentNode.classList.contains('mermaid-container') ||
            mermaidDiv.closest('.mermaid-fullscreen-modal')) {
            return;
        }
        const container = document.createElement('div');
        container.className = 'mermaid-container';
        mermaidDiv.parentNode.insertBefore(container, mermaidDiv);
        container.appendChild(mermaidDiv);
        const fullscreenBtn = document.createElement('button');
        // Explicit type so the button can never submit an enclosing form.
        fullscreenBtn.type = 'button';
        fullscreenBtn.className = 'mermaid-fullscreen-btn' + (darkTheme ? ' dark-theme' : '');
        fullscreenBtn.setAttribute('aria-label', 'View diagram in fullscreen');
        fullscreenBtn.textContent = '⛶';
        fullscreenBtn.style.opacity = '50%';
        // Calculate dynamic position based on diagram's margin and padding
        const diagramStyle = window.getComputedStyle(mermaidDiv);
        const marginTop = parseFloat(diagramStyle.marginTop) || 0;
        const marginRight = parseFloat(diagramStyle.marginRight) || 0;
        const paddingTop = parseFloat(diagramStyle.paddingTop) || 0;
        const paddingRight = parseFloat(diagramStyle.paddingRight) || 0;
        fullscreenBtn.style.top = `${marginTop + paddingTop + 4}px`;
        fullscreenBtn.style.right = `${marginRight + paddingRight + 4}px`;
        fullscreenBtn.addEventListener('click', () => {
            // scrollX (the numeric offset, not the window.scroll method) so that closing
            // the modal restores the horizontal position as well as the vertical one.
            previousScrollOffset = [window.scrollX, window.scrollY];
            openedFrom = fullscreenBtn;
            const clone = mermaidDiv.cloneNode(true);
            modalContent.innerHTML = '';
            modalContent.appendChild(clone);
            const svg = clone.querySelector('svg');
            if (svg) {
                // Drop the rendered fixed size so the cloned SVG scales to the modal.
                svg.removeAttribute('width');
                svg.removeAttribute('height');
                svg.style.width = '100%';
                svg.style.height = 'auto';
                svg.style.maxWidth = '100%';
                svg.style.display = 'block';
                if ("False" === "True") {
                    setTimeout(() => {
                        const g = svg.querySelector('g');
                        if (g) {
                            var svgD3 = d3.select(svg);
                            svgD3.html("<g class='wrapper'>" + svgD3.html() + "</g>");
                            var inner = svgD3.select("g");
                            var zoom = d3.zoom().on("zoom", function(event) {
                                inner.attr("transform", event.transform);
                            });
                            svgD3.call(zoom);
                        }
                    }, 100);
                }
            }
            modal.classList.add('active');
            document.body.style.overflow = 'hidden';
            // Move keyboard focus into the dialog (aria-modal) at its close button.
            closeBtn.focus({preventScroll: true});
        });
        container.appendChild(fullscreenBtn);
        allButtons.push(fullscreenBtn);
    });
    // Update theme classes when theme changes
    const updateTheme = () => {
        const dark = isDarkTheme();
        allButtons.forEach(btn => {
            if (dark) {
                btn.classList.add('dark-theme');
            } else {
                btn.classList.remove('dark-theme');
            }
        });
        if (dark) {
            modal.classList.add('dark-theme');
            modalContent.classList.add('dark-theme');
            closeBtn.classList.add('dark-theme');
        } else {
            modal.classList.remove('dark-theme');
            modalContent.classList.remove('dark-theme');
            closeBtn.classList.remove('dark-theme');
        }
    };
    // Watch for theme changes
    const observer = new MutationObserver(updateTheme);
    observer.observe(document.documentElement, {
        attributes: true,
        attributeFilter: ['class', 'style', 'data-theme']
    });
    observer.observe(document.body, {
        attributes: true,
        attributeFilter: ['class', 'style']
    });
};
window.addEventListener("load", load);
</script>
<script>
// Record this page's name on the DOCUMENTATION_OPTIONS global (presumably defined by documentation_options.js).
DOCUMENTATION_OPTIONS["pagename"] = 'modules/12_attention_ABOUT';
</script>
<script src="../_static/ml-timeline.js?v=76e9b3e3"></script>
<script src="../_static/wip-banner.js?v=04a7e74d"></script>
<script src="../_static/marimo-badges.js?v=e6289128"></script>
<script src="../_static/sidebar-link.js?v=404b701b"></script>
<script src="../_static/hero-carousel.js?v=10341d2a"></script>
<script src="../_static/subscribe-modal.js?v=42919b64"></script>
<link rel="icon" href="../_static/favicon.svg"/>
<link rel="index" title="Index" href="../genindex.html" />
<link rel="search" title="Search" href="../search.html" />
<link rel="next" title="13. Transformers - Complete GPT Architecture" href="13_transformers_ABOUT.html" />
<link rel="prev" title="11. Embeddings - Token to Vector Representations" href="11_embeddings_ABOUT.html" />
<meta name="viewport" content="width=device-width, initial-scale=1"/>
<meta name="docsearch:language" content="en"/>
</head>
<body data-bs-spy="scroll" data-bs-target=".bd-toc-nav" data-offset="180" data-bs-root-margin="0px 0px -60%" data-default-mode="">
<div id="pst-skip-link" class="skip-link d-print-none"><a href="#main-content">Skip to main content</a></div>
<div id="pst-scroll-pixel-helper"></div>
<button type="button" class="btn rounded-pill" id="pst-back-to-top">
<i class="fa-solid fa-arrow-up"></i>Back to top</button>
<input type="checkbox"
class="sidebar-toggle"
id="pst-primary-sidebar-checkbox"/>
<label class="overlay overlay-primary" for="pst-primary-sidebar-checkbox"></label>
<input type="checkbox"
class="sidebar-toggle"
id="pst-secondary-sidebar-checkbox"/>
<label class="overlay overlay-secondary" for="pst-secondary-sidebar-checkbox"></label>
<div class="search-button__wrapper">
<div class="search-button__overlay"></div>
<div class="search-button__search-container">
<form class="bd-search d-flex align-items-center"
action="../search.html"
method="get">
<i class="fa-solid fa-magnifying-glass"></i>
<input type="search"
class="form-control"
name="q"
id="search-input"
placeholder="Search..."
aria-label="Search..."
autocomplete="off"
autocorrect="off"
autocapitalize="off"
spellcheck="false"/>
<span class="search-button__kbd-shortcut"><kbd class="kbd-shortcut__modifier">Ctrl</kbd>+<kbd>K</kbd></span>
</form></div>
</div>
<div class="pst-async-banner-revealer d-none">
<aside id="bd-header-version-warning" class="d-none d-print-none" aria-label="Version warning"></aside>
</div>
<header class="bd-header navbar navbar-expand-lg bd-navbar d-print-none">
</header>
<div class="bd-container">
<div class="bd-container__inner bd-page-width">
<div class="bd-sidebar-primary bd-sidebar">
<div class="sidebar-header-items sidebar-primary__section">
</div>
<div class="sidebar-primary-items__start sidebar-primary__section">
<div class="sidebar-primary-item">
<a class="navbar-brand logo" href="../intro.html">
<img src="../_static/logo-tinytorch.png" class="logo__image only-light" alt="Tiny🔥Torch - Home"/>
<script>
{
// Dark-mode copy of the logo, written from script so it exists only when JavaScript runs.
const darkLogoImg = `<img src="../_static/logo-tinytorch.png" class="logo__image only-dark" alt="Tiny🔥Torch - Home"/>`;
document.write(darkLogoImg);
}
</script>
</a></div>
<div class="sidebar-primary-item">
<script>
{
// Sidebar search field button, written from script so it exists only when JavaScript runs.
const searchFieldButton = `
<button class="btn search-button-field search-button__button" title="Search" aria-label="Search" data-bs-placement="bottom" data-bs-toggle="tooltip">
<i class="fa-solid fa-magnifying-glass"></i>
<span class="search-button__default-text">Search</span>
<span class="search-button__kbd-shortcut"><kbd class="kbd-shortcut__modifier">Ctrl</kbd>+<kbd class="kbd-shortcut__modifier">K</kbd></span>
</button>
`;
document.write(searchFieldButton);
}
</script></div>
<div class="sidebar-primary-item"><nav class="bd-links bd-docs-nav" aria-label="Main">
<div class="bd-toc-item navbar-nav active">
<p aria-level="2" class="caption" role="heading"><span class="caption-text">🚀 Getting Started</span></p>
<ul class="nav bd-sidenav">
<li class="toctree-l1"><a class="reference internal" href="../getting-started.html">Complete Guide</a></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">🏗 Foundation Tier (01-07)</span></p>
<ul class="nav bd-sidenav">
<li class="toctree-l1"><a class="reference internal" href="../tiers/foundation.html">📖 Tier Overview</a></li>
<li class="toctree-l1"><a class="reference internal" href="01_tensor_ABOUT.html">01. Tensor</a></li>
<li class="toctree-l1"><a class="reference internal" href="02_activations_ABOUT.html">02. Activations</a></li>
<li class="toctree-l1"><a class="reference internal" href="03_layers_ABOUT.html">03. Layers</a></li>
<li class="toctree-l1"><a class="reference internal" href="04_losses_ABOUT.html">04. Losses</a></li>
<li class="toctree-l1"><a class="reference internal" href="05_autograd_ABOUT.html">05. Autograd</a></li>
<li class="toctree-l1"><a class="reference internal" href="06_optimizers_ABOUT.html">06. Optimizers</a></li>
<li class="toctree-l1"><a class="reference internal" href="07_training_ABOUT.html">07. Training</a></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">🏛️ Architecture Tier (08-13)</span></p>
<ul class="current nav bd-sidenav">
<li class="toctree-l1"><a class="reference internal" href="../tiers/architecture.html">📖 Tier Overview</a></li>
<li class="toctree-l1"><a class="reference internal" href="08_dataloader_ABOUT.html">08. DataLoader</a></li>
<li class="toctree-l1"><a class="reference internal" href="09_spatial_ABOUT.html">09. Convolutions</a></li>
<li class="toctree-l1"><a class="reference internal" href="10_tokenization_ABOUT.html">10. Tokenization</a></li>
<li class="toctree-l1"><a class="reference internal" href="11_embeddings_ABOUT.html">11. Embeddings</a></li>
<li class="toctree-l1 current active"><a class="current reference internal" href="#">12. Attention</a></li>
<li class="toctree-l1"><a class="reference internal" href="13_transformers_ABOUT.html">13. Transformers</a></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">⏱️ Optimization Tier (14-19)</span></p>
<ul class="nav bd-sidenav">
<li class="toctree-l1"><a class="reference internal" href="../tiers/optimization.html">📖 Tier Overview</a></li>
<li class="toctree-l1"><a class="reference internal" href="14_profiling_ABOUT.html">14. Profiling</a></li>
<li class="toctree-l1"><a class="reference internal" href="15_quantization_ABOUT.html">15. Quantization</a></li>
<li class="toctree-l1"><a class="reference internal" href="16_compression_ABOUT.html">16. Compression</a></li>
<li class="toctree-l1"><a class="reference internal" href="17_memoization_ABOUT.html">17. Memoization</a></li>
<li class="toctree-l1"><a class="reference internal" href="18_acceleration_ABOUT.html">18. Acceleration</a></li>
<li class="toctree-l1"><a class="reference internal" href="19_benchmarking_ABOUT.html">19. Benchmarking</a></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">🏅 Capstone Competition</span></p>
<ul class="nav bd-sidenav">
<li class="toctree-l1"><a class="reference internal" href="../tiers/olympics.html">📖 Competition Overview</a></li>
<li class="toctree-l1"><a class="reference internal" href="20_capstone_ABOUT.html">20. Torch Olympics</a></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">🧭 Course Orientation</span></p>
<ul class="nav bd-sidenav">
<li class="toctree-l1"><a class="reference internal" href="../chapters/00-introduction.html">Course Structure</a></li>
<li class="toctree-l1"><a class="reference internal" href="../prerequisites.html">Prerequisites &amp; Resources</a></li>
<li class="toctree-l1"><a class="reference internal" href="../chapters/learning-journey.html">Learning Journey</a></li>
<li class="toctree-l1"><a class="reference internal" href="../chapters/milestones.html">Historical Milestones</a></li>
<li class="toctree-l1"><a class="reference internal" href="../faq.html">FAQ</a></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">🛠️ TITO CLI Reference</span></p>
<ul class="nav bd-sidenav">
<li class="toctree-l1"><a class="reference internal" href="../tito/overview.html">Command Overview</a></li>
<li class="toctree-l1"><a class="reference internal" href="../tito/modules.html">Module Workflow</a></li>
<li class="toctree-l1"><a class="reference internal" href="../tito/milestones.html">Milestone System</a></li>
<li class="toctree-l1"><a class="reference internal" href="../tito/data.html">Progress &amp; Data</a></li>
<li class="toctree-l1"><a class="reference internal" href="../tito/troubleshooting.html">Troubleshooting</a></li>
<li class="toctree-l1"><a class="reference internal" href="../datasets.html">Datasets Guide</a></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">🤝 Community</span></p>
<ul class="nav bd-sidenav">
<li class="toctree-l1"><a class="reference internal" href="../community.html">Ecosystem</a></li>
<li class="toctree-l1"><a class="reference internal" href="../resources.html">Learning Resources</a></li>
<li class="toctree-l1"><a class="reference internal" href="../credits.html">Credits &amp; Acknowledgments</a></li>
</ul>
</div>
</nav></div>
</div>
<div class="sidebar-primary-items__end sidebar-primary__section">
</div>
<div id="rtd-footer-container"></div>
</div>
<main id="main-content" class="bd-main" role="main">
<div class="sbt-scroll-pixel-helper"></div>
<div class="bd-content">
<div class="bd-article-container">
<div class="bd-header-article d-print-none">
<div class="header-article-items header-article__inner">
<div class="header-article-items__start">
<div class="header-article-item"><button class="sidebar-toggle primary-toggle btn btn-sm" title="Toggle primary sidebar" data-bs-placement="bottom" data-bs-toggle="tooltip">
<span class="fa-solid fa-bars"></span>
</button></div>
</div>
<div class="header-article-items__end">
<div class="header-article-item">
<div class="article-header-buttons">
<div class="dropdown dropdown-download-buttons">
<button class="btn dropdown-toggle" type="button" data-bs-toggle="dropdown" aria-expanded="false" aria-label="Download this page">
<i class="fas fa-download"></i>
</button>
<ul class="dropdown-menu">
<li><a href="../_sources/modules/12_attention_ABOUT.md" target="_blank"
class="btn btn-sm btn-download-source-button dropdown-item"
title="Download source file"
data-bs-placement="left" data-bs-toggle="tooltip"
>
<span class="btn__icon-container">
<i class="fas fa-file"></i>
</span>
<span class="btn__text-container">.md</span>
</a>
</li>
<li>
<button onclick="window.print()"
class="btn btn-sm btn-download-pdf-button dropdown-item"
title="Print to PDF"
data-bs-placement="left" data-bs-toggle="tooltip"
>
<span class="btn__icon-container">
<i class="fas fa-file-pdf"></i>
</span>
<span class="btn__text-container">.pdf</span>
</button>
</li>
</ul>
</div>
<button onclick="toggleFullScreen()"
class="btn btn-sm btn-fullscreen-button"
title="Fullscreen mode"
data-bs-placement="bottom" data-bs-toggle="tooltip"
>
<span class="btn__icon-container">
<i class="fas fa-expand"></i>
</span>
</button>
<script>
{
// Light/dark/auto theme switcher, written from script so it exists only when JavaScript runs.
const themeSwitchButton = `
<button class="btn btn-sm nav-link pst-navbar-icon theme-switch-button" title="light/dark" aria-label="light/dark" data-bs-placement="bottom" data-bs-toggle="tooltip">
<i class="theme-switch fa-solid fa-sun fa-lg" data-mode="light"></i>
<i class="theme-switch fa-solid fa-moon fa-lg" data-mode="dark"></i>
<i class="theme-switch fa-solid fa-circle-half-stroke fa-lg" data-mode="auto"></i>
</button>
`;
document.write(themeSwitchButton);
}
</script>
<script>
{
// Compact header search icon, likewise script-only.
const searchIconButton = `
<button class="btn btn-sm pst-navbar-icon search-button search-button__button" title="Search" aria-label="Search" data-bs-placement="bottom" data-bs-toggle="tooltip">
<i class="fa-solid fa-magnifying-glass fa-lg"></i>
</button>
`;
document.write(searchIconButton);
}
</script>
<button class="sidebar-toggle secondary-toggle btn btn-sm" title="Toggle secondary sidebar" data-bs-placement="bottom" data-bs-toggle="tooltip">
<span class="fa-solid fa-list"></span>
</button>
</div></div>
</div>
</div>
</div>
<div id="jb-print-docs-body" class="onlyprint">
<h1>12. Attention - The Mechanism That Powers Modern AI</h1>
<!-- Table of contents -->
<div id="print-main-content">
<div id="jb-print-toc">
<div>
<h2> Contents </h2>
</div>
<nav aria-label="Page">
<ul class="visible nav section-nav flex-column">
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#overview">Overview</a></li>
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#learning-objectives">Learning Objectives</a></li>
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#build-use-reflect">Build → Use → Reflect</a></li>
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#implementation-guide">Implementation Guide</a><ul class="nav section-nav flex-column">
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#attention-mechanism-flow">Attention Mechanism Flow</a></li>
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#core-components">Core Components</a><ul class="nav section-nav flex-column">
<li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#scaled-dot-product-attention-scaled-dot-product-attention">1. Scaled Dot-Product Attention (<code class="docutils literal notranslate"><span class="pre">scaled_dot_product_attention</span></code>)</a></li>
<li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#multi-head-attention-multiheadattention">2. Multi-Head Attention (<code class="docutils literal notranslate"><span class="pre">MultiHeadAttention</span></code>)</a></li>
<li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#masking-utilities">3. Masking Utilities</a></li>
</ul>
</li>
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#attention-complexity-analysis">Attention Complexity Analysis</a><ul class="nav section-nav flex-column">
<li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#time-complexity-o-n2-d">Time Complexity: O(n² × d)</a></li>
<li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#memory-complexity-o-batch-heads-n2">Memory Complexity: O(batch × heads × n²)</a></li>
</ul>
</li>
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#comparing-to-pytorch">Comparing to PyTorch</a></li>
</ul>
</li>
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#getting-started">Getting Started</a><ul class="nav section-nav flex-column">
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#prerequisites">Prerequisites</a></li>
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#development-workflow">Development Workflow</a></li>
</ul>
</li>
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#testing">Testing</a><ul class="nav section-nav flex-column">
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#comprehensive-test-suite">Comprehensive Test Suite</a></li>
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#test-coverage-areas">Test Coverage Areas</a></li>
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#inline-testing-complexity-analysis">Inline Testing &amp; Complexity Analysis</a></li>
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#manual-testing-examples">Manual Testing Examples</a></li>
</ul>
</li>
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#systems-thinking-questions">Systems Thinking Questions</a><ul class="nav section-nav flex-column">
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#real-world-applications">Real-World Applications</a></li>
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#mathematical-foundations">Mathematical Foundations</a></li>
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#computational-characteristics">Computational Characteristics</a></li>
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#how-your-implementation-maps-to-pytorch">How Your Implementation Maps to PyTorch</a></li>
</ul>
</li>
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#ready-to-build">Ready to Build?</a></li>
</ul>
</nav>
</div>
</div>
</div>
<div id="searchbox"></div>
<article class="bd-article">
<section id="attention-the-mechanism-that-powers-modern-ai">
<h1>12. Attention - The Mechanism That Powers Modern AI<a class="headerlink" href="#attention-the-mechanism-that-powers-modern-ai" title="Link to this heading">#</a></h1>
<p><strong>ARCHITECTURE TIER</strong> | Difficulty: ⭐⭐⭐ (3/4) | Time: 5-6 hours</p>
<section id="overview">
<h2>Overview<a class="headerlink" href="#overview" title="Link to this heading">#</a></h2>
<p>Implement the attention mechanism that revolutionized AI and sparked the modern transformer era. This module builds scaled dot-product attention and multi-head attention—the exact mechanisms powering GPT, BERT, Claude, and every major language model deployed today. You’ll implement attention with explicit loops to viscerally understand the O(n²) complexity that defines both the power and limitations of transformer architectures.</p>
<p>The “Attention Is All You Need” paper (Vaswani et al., 2017) introduced these mechanisms and replaced RNNs with pure attention architectures, enabling parallelization and global context from layer one. Understanding attention from first principles—including its computational bottlenecks—is essential for working with production transformers and understanding why FlashAttention, sparse attention, and linear attention are active research frontiers.</p>
</section>
<section id="learning-objectives">
<h2>Learning Objectives<a class="headerlink" href="#learning-objectives" title="Link to this heading">#</a></h2>
<p>By the end of this module, you will be able to:</p>
<ul class="simple">
<li><p><strong>Understand O(n²) Complexity</strong>: Implement attention with explicit loops to witness quadratic scaling in memory and computation, understanding why long-context AI remains challenging</p></li>
<li><p><strong>Build Scaled Dot-Product Attention</strong>: Implement softmax(QK^T / √d_k)V with proper numerical stability, understanding how 1/√d_k scaling keeps the softmax out of saturation, where its gradients would vanish</p></li>
<li><p><strong>Create Multi-Head Attention</strong>: Build parallel attention heads that learn different patterns (syntax, semantics, position) and concatenate their outputs for rich representations</p></li>
<li><p><strong>Master Masking Strategies</strong>: Implement causal masking for autoregressive generation (GPT), understand bidirectional attention for encoding (BERT), and handle padding masks</p></li>
<li><p><strong>Analyze Production Trade-offs</strong>: Experience attention’s memory bottleneck firsthand, understand why FlashAttention matters, and explore the compute-memory trade-off space</p></li>
</ul>
</section>
<section id="build-use-reflect">
<h2>Build → Use → Reflect<a class="headerlink" href="#build-use-reflect" title="Link to this heading">#</a></h2>
<p>This module follows TinyTorch’s <strong>Build → Use → Reflect</strong> framework:</p>
<ol class="arabic simple">
<li><p><strong>Build</strong>: Implement scaled dot-product attention with explicit O(n²) loops (educational), create MultiHeadAttention class with Q/K/V projections and head splitting, and build masking utilities</p></li>
<li><p><strong>Use</strong>: Apply attention to realistic sequences with causal masking for language modeling, visualize attention patterns showing what the model “sees,” and test with different head configurations</p></li>
<li><p><strong>Reflect</strong>: Why does attention scale O(n²)? How do different heads specialize without supervision? What memory bottlenecks emerge at large-model scale (e.g., GPT-3’s 96 heads per layer, combined with 8K+ token contexts)?</p></li>
</ol>
</section>
<section id="implementation-guide">
<h2>Implementation Guide<a class="headerlink" href="#implementation-guide" title="Link to this heading">#</a></h2>
<section id="attention-mechanism-flow">
<h3>Attention Mechanism Flow<a class="headerlink" href="#attention-mechanism-flow" title="Link to this heading">#</a></h3>
<p>The attention mechanism transforms queries, keys, and values into context-aware representations:</p>
<pre class="mermaid">
graph LR
A[Query&lt;br/&gt;Q: n×d] --&gt; D[Scores&lt;br/&gt;QK^T/√d]
B[Key&lt;br/&gt;K: n×d] --&gt; D
D --&gt; E[Attention&lt;br/&gt;Weights&lt;br/&gt;softmax]
E --&gt; F[Context&lt;br/&gt;×V]
C[Value&lt;br/&gt;V: n×d] --&gt; F
F --&gt; G[Output&lt;br/&gt;n×d]
style A fill:#e3f2fd
style B fill:#e3f2fd
style C fill:#e3f2fd
style D fill:#fff3e0
style E fill:#ffe0b2
style F fill:#f3e5f5
style G fill:#f0fdf4
</pre><p><strong>Flow</strong>: Queries attend to Keys (QK^T) → Scale by 1/√d_k (divide by √d_k) → Softmax for weights → Weighted sum of Values → Context output</p>
</section>
<section id="core-components">
<h3>Core Components<a class="headerlink" href="#core-components" title="Link to this heading">#</a></h3>
<p>Your attention implementation consists of three fundamental building blocks:</p>
<section id="scaled-dot-product-attention-scaled-dot-product-attention">
<h4>1. Scaled Dot-Product Attention (<code class="docutils literal notranslate"><span class="pre">scaled_dot_product_attention</span></code>)<a class="headerlink" href="#scaled-dot-product-attention-scaled-dot-product-attention" title="Link to this heading">#</a></h4>
<p>The mathematical foundation that powers all transformers:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span><span class="w"> </span><span class="nf">scaled_dot_product_attention</span><span class="p">(</span><span class="n">Q</span><span class="p">,</span> <span class="n">K</span><span class="p">,</span> <span class="n">V</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Attention(Q, K, V) = softmax(QK^T / √d_k) V</span>
<span class="sd"> This exact formula powers GPT, BERT, Claude, and all transformers.</span>
<span class="sd"> Implemented with explicit loops to show O(n²) complexity.</span>
<span class="sd"> Args:</span>
<span class="sd"> Q: Query matrix (batch, seq_len, d_model)</span>
<span class="sd"> K: Key matrix (batch, seq_len, d_model)</span>
<span class="sd"> V: Value matrix (batch, seq_len, d_model)</span>
<span class="sd"> mask: Optional causal mask (batch, seq_len, seq_len)</span>
<span class="sd"> Returns:</span>
<span class="sd"> output: Attended values (batch, seq_len, d_model)</span>
<span class="sd"> attention_weights: Attention matrix (batch, seq_len, seq_len)</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="c1"># Step 1: Compute attention scores (O(n²) operation)</span>
<span class="c1"># For each query i and key j: score[i,j] = Q[i] · K[j]</span>
<span class="c1"># Step 2: Scale by 1/√d_k for numerical stability</span>
<span class="c1"># Prevents softmax saturation as dimensionality increases</span>
<span class="c1"># Step 3: Apply optional causal mask</span>
<span class="c1"># Masked positions set to -1e9 (becomes ~0 after softmax)</span>
<span class="c1"># Step 4: Softmax normalization (each row sums to 1)</span>
<span class="c1"># Converts scores to probability distribution</span>
<span class="c1"># Step 5: Weighted sum of values (another O(n²) operation)</span>
<span class="c1"># output[i] = Σ(attention_weights[i,j] × V[j]) for all j</span>
</pre></div>
</div>
<p><strong>Key Implementation Details:</strong></p>
<ul class="simple">
<li><p><strong>Explicit Loops</strong>: Educational implementation shows exactly where O(n²) complexity comes from (every query attends to every key)</p></li>
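<li><p><strong>Scaling Factor</strong>: The variance of a dot product grows linearly with d_k (for components with unit variance). Multiplying by 1/√d_k brings the scores back to roughly unit variance, so softmax does not saturate and gradients keep flowing</p></li>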
<li><p><strong>Masking Before Softmax</strong>: Setting masked positions to -1e9 makes them effectively zero after softmax</p></li>
<li><p><strong>Return Attention Weights</strong>: Essential for visualization and interpretability analysis</p></li>
</ul>
<p><strong>What You’ll Learn:</strong></p>
<ul class="simple">
<li><p>Why attention weights must sum to 1 (probability distribution property)</p></li>
<li><p>How the scaling factor keeps softmax from saturating and its gradients from vanishing</p></li>
<li><p>The core computational cost: about 2n²d multiply-adds (n²d for QK^T plus n²d for weights×V), i.e., roughly 4n²d FLOPs</p></li>
<li><p>Why memory scales as O(batch × n²) for attention matrices</p></li>
</ul>
</section>
<section id="multi-head-attention-multiheadattention">
<h4>2. Multi-Head Attention (<code class="docutils literal notranslate"><span class="pre">MultiHeadAttention</span></code>)<a class="headerlink" href="#multi-head-attention-multiheadattention" title="Link to this heading">#</a></h4>
<p>Parallel attention “heads” that learn different relationship patterns:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span><span class="w"> </span><span class="nc">MultiHeadAttention</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Multi-head attention from &#39;Attention is All You Need&#39;.</span>
<span class="sd"> Projects input to Q, K, V, splits into multiple heads,</span>
<span class="sd"> applies attention in parallel, concatenates, and projects output.</span>
<span class="sd"> Example: d_model=512, num_heads=8</span>
<span class="sd"> → Each head processes 64 dimensions (512 ÷ 8)</span>
<span class="sd"> → 8 heads learn different attention patterns in parallel</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">):</span>
<span class="c1"># Validate: embed_dim must be divisible by num_heads</span>
<span class="c1"># Create Q, K, V projection layers (Linear(embed_dim, embed_dim))</span>
<span class="c1"># Create output projection layer</span>
<span class="k">def</span><span class="w"> </span><span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<span class="c1"># 1. Project input to Q, K, V</span>
<span class="c1"># 2. Split into heads: (batch, seq, embed_dim) → (batch, heads, seq, head_dim)</span>
<span class="c1"># 3. Apply attention to each head in parallel</span>
<span class="c1"># 4. Concatenate heads back together</span>
<span class="c1"># 5. Apply output projection to mix information across heads</span>
</pre></div>
</div>
<p><strong>Architecture Flow:</strong></p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span>Input (batch, seq, 512)
↓ [Q/K/V Linear Projections]
Q, K, V (batch, seq, 512)
↓ [Reshape &amp; Split into 8 heads]
(batch, 8 heads, seq, 64 per head)
↓ [Parallel Attention on Each Head]
Head₁ learns syntax patterns (subject-verb agreement)
Head₂ learns semantics (word similarity)
Head₃ learns position (relative distance)
Head₄ learns long-range (coreference)
...
↓ [Concatenate Heads]
(batch, seq, 512)
↓ [Output Projection]
Output (batch, seq, 512)
</pre></div>
</div>
<p><strong>Key Implementation Details:</strong></p>
<ul class="simple">
<li><p><strong>Head Splitting</strong>: Reshape (batch, seq, embed_dim) to (batch, seq, heads, head_dim), then transpose to (batch, heads, seq, head_dim)</p></li>
<li><p><strong>Parallel Processing</strong>: All heads compute simultaneously—GPU parallelism critical for efficiency</p></li>
<li><p><strong>Four Linear Layers</strong>: Three for Q/K/V projections, one for output (standard transformer architecture)</p></li>
<li><p><strong>Head Concatenation</strong>: Reverse the split operation to merge heads back to original dimensions</p></li>
</ul>
<p><strong>What You’ll Learn:</strong></p>
<ul class="simple">
<li><p>Why multiple heads capture richer representations than single-head</p></li>
<li><p>How heads naturally specialize without explicit supervision</p></li>
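<li><p>The computational trade-off: the same total O(n²d) attention FLOPs as single-head attention (each head works on head_dim = embed_dim / num_heads dimensions), but num_heads separate n×n attention matrices, so attention-weight memory grows with the number of heads</p></li>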
<li><p>Why head_dim = embed_dim / num_heads is the standard configuration</p></li>
</ul>
</section>
<section id="masking-utilities">
<h4>3. Masking Utilities<a class="headerlink" href="#masking-utilities" title="Link to this heading">#</a></h4>
<p>Control information flow patterns for different tasks:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span><span class="w"> </span><span class="nf">create_causal_mask</span><span class="p">(</span><span class="n">seq_len</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Lower triangular mask for autoregressive (GPT-style) attention.</span>
<span class="sd"> Position i can only attend to positions ≤ i (no future peeking).</span>
<span class="sd"> Example (seq_len=4):</span>
<span class="sd"> [[1, 0, 0, 0], # Position 0 sees only position 0</span>
<span class="sd"> [1, 1, 0, 0], # Position 1 sees 0, 1</span>
<span class="sd"> [1, 1, 1, 0], # Position 2 sees 0, 1, 2</span>
<span class="sd"> [1, 1, 1, 1]] # Position 3 sees all positions</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">tril</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">))))</span>
<span class="k">def</span><span class="w"> </span><span class="nf">create_padding_mask</span><span class="p">(</span><span class="n">lengths</span><span class="p">,</span> <span class="n">max_length</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Prevents attention to padding tokens in variable-length sequences.</span>
<span class="sd"> Essential for efficient batching of different-length sequences.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="c1"># Create mask where 1=real token, 0=padding</span>
<span class="c1"># Shape: (batch_size, 1, 1, max_length) for broadcasting</span>
</pre></div>
</div>
<p><strong>Masking Strategies:</strong></p>
<ul class="simple">
<li><p><strong>Causal (GPT)</strong>: Lower triangular—blocks n(n-1)/2 connections for autoregressive generation</p></li>
<li><p><strong>Bidirectional (BERT)</strong>: No mask—full n² connections for encoding with full context</p></li>
<li><p><strong>Padding</strong>: Batch-specific—prevents attention to padding tokens in variable-length batches</p></li>
<li><p><strong>Combined</strong>: Masks can be multiplied element-wise, which acts as a logical AND (e.g., causal × padding)</p></li>
</ul>
<p><strong>What You’ll Learn:</strong></p>
<ul class="simple">
<li><p>How masking strategy fundamentally defines model capabilities (generation vs encoding)</p></li>
<li><p>Why causal masking is essential for language-model training: without it, each position could see the very token it is trying to predict</p></li>
<li><p>The performance benefit of efficient batching with padding masks</p></li>
<li><p>How mask shape broadcasting works with attention scores</p></li>
</ul>
</section>
</section>
<section id="attention-complexity-analysis">
<h3>Attention Complexity Analysis<a class="headerlink" href="#attention-complexity-analysis" title="Link to this heading">#</a></h3>
<p>Understanding the computational and memory bottlenecks:</p>
<section id="time-complexity-o-n2-d">
<h4>Time Complexity: O(n² × d)<a class="headerlink" href="#time-complexity-o-n2-d" title="Link to this heading">#</a></h4>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span>For sequence length n and embedding dimension d:
QK^T computation:
- n queries × n keys = n² similarity scores
- Each score: dot product over d dimensions
- Total: O(n² × d) operations
Softmax normalization:
- Apply to n² scores
- Total: O(n²) operations
Attention × Values:
- Each of n outputs is a weighted sum of n value vectors
- Each value vector has d dimensions
- Total: O(n² × d) operations
Dominant: O(n² × d) for both QK^T and weights×V
</pre></div>
</div>
<p><strong>Scaling Impact:</strong></p>
<ul class="simple">
<li><p>Doubling sequence length quadruples compute</p></li>
<li><p>n=1024 → 1M scores per head</p></li>
<li><p>n=4096 (2× GPT-3’s 2048-token context) → 16M scores per head (16× more than n=1024)</p></li>
<li><p>n=32K (e.g., a 32K-token context window) → ~1B scores per head (~1000× more than n=1024)</p></li>
</ul>
</section>
<section id="memory-complexity-o-batch-heads-n2">
<h4>Memory Complexity: O(batch × heads × n²)<a class="headerlink" href="#memory-complexity-o-batch-heads-n2" title="Link to this heading">#</a></h4>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span>Attention weights matrix shape: (batch, heads, seq_len, seq_len)
Example: GPT-3 scale training (weights kept for the backward pass)
- batch=32, heads=96, seq=2048
- Attention weights: 32 × 96 × 2048 × 2048 ≈ 12.9 billion values per layer
- At FP32 (4 bytes): ≈ 51.5 GB just for one layer’s attention weights
- With 96 layers: ≈ 4.9 TB total (clearly infeasible!)
This is why:
- FlashAttention fuses operations to avoid storing attention matrix
- Mixed precision training uses FP16 (2× memory reduction)
- Gradient checkpointing recomputes instead of storing
- Production models use extensive optimization tricks
</pre></div>
</div>
<p><strong>The Memory Bottleneck:</strong></p>
<ul class="simple">
<li><p>For long contexts (32K+ tokens), attention memory dominates total usage</p></li>
<li><p>Storing attention weights becomes infeasible—must compute on-the-fly</p></li>
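<li><p>FlashAttention breakthrough: O(n) memory instead of O(n²) by tiling the computation, using an online softmax, and fusing it into a single kernel (recomputing attention in the backward pass instead of storing it)</p></li>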
<li><p>Understanding this bottleneck guides all modern attention optimization research</p></li>
</ul>
</section>
</section>
<section id="comparing-to-pytorch">
<h3>Comparing to PyTorch<a class="headerlink" href="#comparing-to-pytorch" title="Link to this heading">#</a></h3>
<p>Your implementation vs <code class="docutils literal notranslate"><span class="pre">torch.nn.MultiheadAttention</span></code>:</p>
<div class="pst-scrollable-table-container"><table class="table">
<thead>
<tr class="row-odd"><th class="head"><p>Aspect</p></th>
<th class="head"><p>Your TinyTorch Implementation</p></th>
<th class="head"><p>PyTorch Production</p></th>
</tr>
</thead>
<tbody>
<tr class="row-even"><td><p><strong>Algorithm</strong></p></td>
<td><p>Exact same: softmax(QK^T/√d_k)V</p></td>
<td><p>Same mathematical formula</p></td>
</tr>
<tr class="row-odd"><td><p><strong>Loops</strong></p></td>
<td><p>Explicit (educational)</p></td>
<td><p>Fused GPU kernels</p></td>
</tr>
<tr class="row-even"><td><p><strong>Masking</strong></p></td>
<td><p>Manual application</p></td>
<td><p>Built-in mask parameter</p></td>
</tr>
<tr class="row-odd"><td><p><strong>Memory</strong></p></td>
<td><p>O(n²) attention matrix stored</p></td>
<td><p>FlashAttention-optimized</p></td>
</tr>
<tr class="row-even"><td><p><strong>Batching</strong></p></td>
<td><p>Standard implementation</p></td>
<td><p>Highly optimized kernels</p></td>
</tr>
<tr class="row-odd"><td><p><strong>Numerical Stability</strong></p></td>
<td><p>1/√d_k scaling</p></td>
<td><p>Same + additional safeguards</p></td>
</tr>
</tbody>
</table>
</div>
<p><strong>What You Gained:</strong></p>
<ul class="simple">
<li><p>Deep understanding of O(n²) complexity by seeing explicit loops</p></li>
<li><p>Insight into why FlashAttention and kernel fusion matter</p></li>
<li><p>Knowledge of masking strategies and their architectural implications</p></li>
<li><p>Foundation for understanding advanced attention variants (sparse, linear)</p></li>
</ul>
</section>
</section>
<section id="getting-started">
<h2>Getting Started<a class="headerlink" href="#getting-started" title="Link to this heading">#</a></h2>
<section id="prerequisites">
<h3>Prerequisites<a class="headerlink" href="#prerequisites" title="Link to this heading">#</a></h3>
<p>Ensure you understand these foundations:</p>
<div class="highlight-bash notranslate"><div class="highlight"><pre><span></span><span class="c1"># Activate TinyTorch environment</span>
<span class="nb">source</span><span class="w"> </span>scripts/activate-tinytorch
<span class="c1"># Verify prerequisite modules</span>
tito<span class="w"> </span><span class="nb">test</span><span class="w"> </span>tensor<span class="w"> </span><span class="c1"># Matrix operations (matmul, transpose)</span>
tito<span class="w"> </span><span class="nb">test</span><span class="w"> </span>activations<span class="w"> </span><span class="c1"># Softmax for attention normalization</span>
tito<span class="w"> </span><span class="nb">test</span><span class="w"> </span>layers<span class="w"> </span><span class="c1"># Linear layers for Q/K/V projections</span>
tito<span class="w"> </span><span class="nb">test</span><span class="w"> </span>embeddings<span class="w"> </span><span class="c1"># Token/position embeddings attention operates on</span>
</pre></div>
</div>
<p><strong>Core Concepts You’ll Need:</strong></p>
<ul class="simple">
<li><p><strong>Matrix Multiplication</strong>: Understanding QK^T computation and broadcasting</p></li>
<li><p><strong>Softmax Numerical Stability</strong>: Subtracting max before exp prevents overflow</p></li>
<li><p><strong>Layer Composition</strong>: How Q/K/V projections combine with attention</p></li>
<li><p><strong>Shape Manipulation</strong>: Reshape and transpose operations for head splitting</p></li>
</ul>
</section>
<section id="development-workflow">
<h3>Development Workflow<a class="headerlink" href="#development-workflow" title="Link to this heading">#</a></h3>
<ol class="arabic simple">
<li><p><strong>Open the development file</strong>: <code class="docutils literal notranslate"><span class="pre">modules/12_attention/attention_dev.ipynb</span></code> (notebook) or <code class="docutils literal notranslate"><span class="pre">attention_dev.py</span></code> (script)</p></li>
<li><p><strong>Implement scaled_dot_product_attention</strong>: Build core attention formula with explicit loops showing O(n²) complexity</p></li>
<li><p><strong>Create MultiHeadAttention class</strong>: Add Q/K/V projections, head splitting, parallel attention, and output projection</p></li>
<li><p><strong>Build masking utilities</strong>: Create causal mask for GPT-style attention and padding mask for batching</p></li>
<li><p><strong>Test and analyze</strong>: Run comprehensive tests, visualize attention patterns, and profile computational scaling</p></li>
<li><p><strong>Export and verify</strong>: <code class="docutils literal notranslate"><span class="pre">tito</span> <span class="pre">module</span> <span class="pre">complete</span> <span class="pre">12</span> <span class="pre">&amp;&amp;</span> <span class="pre">tito</span> <span class="pre">test</span> <span class="pre">attention</span></code></p></li>
</ol>
</section>
</section>
<section id="testing">
<h2>Testing<a class="headerlink" href="#testing" title="Link to this heading">#</a></h2>
<section id="comprehensive-test-suite">
<h3>Comprehensive Test Suite<a class="headerlink" href="#comprehensive-test-suite" title="Link to this heading">#</a></h3>
<p>Run the full test suite to verify attention functionality:</p>
<div class="highlight-bash notranslate"><div class="highlight"><pre><span></span><span class="c1"># TinyTorch CLI (recommended)</span>
tito<span class="w"> </span><span class="nb">test</span><span class="w"> </span>attention
<span class="c1"># Direct pytest execution</span>
python<span class="w"> </span>-m<span class="w"> </span>pytest<span class="w"> </span>tests/<span class="w"> </span>-k<span class="w"> </span>attention<span class="w"> </span>-v
<span class="c1"># Inline testing during development</span>
python<span class="w"> </span>modules/12_attention/attention_dev.py
</pre></div>
</div>
</section>
<section id="test-coverage-areas">
<h3>Test Coverage Areas<a class="headerlink" href="#test-coverage-areas" title="Link to this heading">#</a></h3>
<ul class="simple">
<li><p><strong>Attention Scores Computation</strong>: Verifies QK^T produces correct shapes and values</p></li>
<li><p><strong>Numerical Stability</strong>: Confirms 1/√d_k scaling prevents softmax saturation</p></li>
<li><p><strong>Probability Normalization</strong>: Validates attention weights sum to 1.0 per query</p></li>
<li><p><strong>Causal Masking</strong>: Tests that future positions get zero attention weight</p></li>
<li><p><strong>Multi-Head Configuration</strong>: Checks head splitting, parallel processing, and concatenation</p></li>
<li><p><strong>Shape Preservation</strong>: Ensures input shape equals output shape</p></li>
<li><p><strong>Gradient Flow</strong>: Verifies differentiability through attention computation graph</p></li>
<li><p><strong>Computational Complexity</strong>: Profiles O(n²) scaling with increasing sequence length</p></li>
</ul>
</section>
<section id="inline-testing-complexity-analysis">
<h3>Inline Testing &amp; Complexity Analysis<a class="headerlink" href="#inline-testing-complexity-analysis" title="Link to this heading">#</a></h3>
<p>The module includes comprehensive validation and performance analysis:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="err">🔬</span> <span class="n">Unit</span> <span class="n">Test</span><span class="p">:</span> <span class="n">Scaled</span> <span class="n">Dot</span><span class="o">-</span><span class="n">Product</span> <span class="n">Attention</span><span class="o">...</span>
<span class="err"></span> <span class="n">Attention</span> <span class="n">scores</span> <span class="n">computed</span> <span class="n">correctly</span> <span class="p">(</span><span class="n">QK</span><span class="o">^</span><span class="n">T</span> <span class="n">shape</span> <span class="n">verified</span><span class="p">)</span>
<span class="err"></span> <span class="n">Scaling</span> <span class="n">factor</span> <span class="mi">1</span><span class="o">/</span><span class="err">√</span><span class="n">d_k</span> <span class="n">applied</span>
<span class="err"></span> <span class="n">Softmax</span> <span class="n">normalization</span> <span class="n">verified</span> <span class="p">(</span><span class="n">each</span> <span class="n">row</span> <span class="n">sums</span> <span class="n">to</span> <span class="mf">1.0</span><span class="p">)</span>
<span class="err"></span> <span class="n">Output</span> <span class="n">shape</span> <span class="n">matches</span> <span class="n">expected</span> <span class="p">(</span><span class="n">batch</span><span class="p">,</span> <span class="n">seq</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span>
<span class="err"></span> <span class="n">Causal</span> <span class="n">masking</span> <span class="n">blocks</span> <span class="n">future</span> <span class="n">positions</span> <span class="n">correctly</span>
<span class="err">📈</span> <span class="n">Progress</span><span class="p">:</span> <span class="n">Scaled</span> <span class="n">Dot</span><span class="o">-</span><span class="n">Product</span> <span class="n">Attention</span> <span class="err"></span>
<span class="err">🔬</span> <span class="n">Unit</span> <span class="n">Test</span><span class="p">:</span> <span class="n">Multi</span><span class="o">-</span><span class="n">Head</span> <span class="n">Attention</span><span class="o">...</span>
<span class="err"></span> <span class="mi">8</span> <span class="n">heads</span> <span class="n">process</span> <span class="mi">512</span> <span class="n">dimensions</span> <span class="ow">in</span> <span class="n">parallel</span>
<span class="err"></span> <span class="n">Head</span> <span class="n">splitting</span> <span class="ow">and</span> <span class="n">concatenation</span> <span class="n">correct</span>
<span class="err"></span> <span class="n">Q</span><span class="o">/</span><span class="n">K</span><span class="o">/</span><span class="n">V</span> <span class="n">projection</span> <span class="n">layers</span> <span class="n">initialized</span> <span class="n">properly</span>
<span class="err"></span> <span class="n">Output</span> <span class="n">projection</span> <span class="n">applied</span>
<span class="err"></span> <span class="n">Shape</span><span class="p">:</span> <span class="p">(</span><span class="n">batch</span><span class="p">,</span> <span class="n">seq</span><span class="p">,</span> <span class="mi">512</span><span class="p">)</span> <span class="err">→</span> <span class="p">(</span><span class="n">batch</span><span class="p">,</span> <span class="n">seq</span><span class="p">,</span> <span class="mi">512</span><span class="p">)</span> <span class="err"></span>
<span class="err">📈</span> <span class="n">Progress</span><span class="p">:</span> <span class="n">Multi</span><span class="o">-</span><span class="n">Head</span> <span class="n">Attention</span> <span class="err"></span>
<span class="err">📊</span> <span class="n">Analyzing</span> <span class="n">Attention</span> <span class="n">Complexity</span><span class="o">...</span>
<span class="n">Seq</span> <span class="n">Len</span> <span class="o">|</span> <span class="n">Attention</span> <span class="n">Matrix</span> <span class="o">|</span> <span class="n">Memory</span> <span class="p">(</span><span class="n">KB</span><span class="p">)</span> <span class="o">|</span> <span class="n">Scaling</span>
<span class="o">--------------------------------------------------------</span>
<span class="mi">16</span> <span class="o">|</span> <span class="mi">256</span> <span class="o">|</span> <span class="mf">1.00</span> <span class="o">|</span> <span class="mf">1.0</span><span class="n">x</span>
<span class="mi">32</span> <span class="o">|</span> <span class="mi">1</span><span class="p">,</span><span class="mi">024</span> <span class="o">|</span> <span class="mf">4.00</span> <span class="o">|</span> <span class="mf">4.0</span><span class="n">x</span>
<span class="mi">64</span> <span class="o">|</span> <span class="mi">4</span><span class="p">,</span><span class="mi">096</span> <span class="o">|</span> <span class="mf">16.00</span> <span class="o">|</span> <span class="mf">4.0</span><span class="n">x</span>
<span class="mi">128</span> <span class="o">|</span> <span class="mi">16</span><span class="p">,</span><span class="mi">384</span> <span class="o">|</span> <span class="mf">64.00</span> <span class="o">|</span> <span class="mf">4.0</span><span class="n">x</span>
<span class="mi">256</span> <span class="o">|</span> <span class="mi">65</span><span class="p">,</span><span class="mi">536</span> <span class="o">|</span> <span class="mf">256.00</span> <span class="o">|</span> <span class="mf">4.0</span><span class="n">x</span>
<span class="err">💡</span> <span class="n">Memory</span> <span class="n">scales</span> <span class="k">as</span> <span class="n">O</span><span class="p">(</span><span class="n">n</span><span class="err">²</span><span class="p">)</span> <span class="k">with</span> <span class="n">sequence</span> <span class="n">length</span>
<span class="err">🚀</span> <span class="n">For</span> <span class="n">seq_len</span><span class="o">=</span><span class="mi">2048</span> <span class="p">(</span><span class="n">GPT</span><span class="o">-</span><span class="mi">3</span><span class="p">),</span> <span class="n">attention</span> <span class="n">matrix</span> <span class="n">needs</span> <span class="mi">16</span> <span class="n">MB</span> <span class="n">per</span> <span class="n">layer</span>
</pre></div>
</div>
</section>
<section id="manual-testing-examples">
<h3>Manual Testing Examples<a class="headerlink" href="#manual-testing-examples" title="Link to this heading">#</a></h3>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">from</span><span class="w"> </span><span class="nn">attention_dev</span><span class="w"> </span><span class="kn">import</span> <span class="n">scaled_dot_product_attention</span><span class="p">,</span> <span class="n">MultiHeadAttention</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">tinytorch.core.tensor</span><span class="w"> </span><span class="kn">import</span> <span class="n">Tensor</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">numpy</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">np</span>
<span class="c1"># Test 1: Basic scaled dot-product attention</span>
<span class="n">batch</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">d_model</span> <span class="o">=</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="mi">64</span>
<span class="n">Q</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">batch</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">d_model</span><span class="p">))</span>
<span class="n">K</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">batch</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">d_model</span><span class="p">))</span>
<span class="n">V</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">batch</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">d_model</span><span class="p">))</span>
<span class="n">output</span><span class="p">,</span> <span class="n">weights</span> <span class="o">=</span> <span class="n">scaled_dot_product_attention</span><span class="p">(</span><span class="n">Q</span><span class="p">,</span> <span class="n">K</span><span class="p">,</span> <span class="n">V</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Output shape: </span><span class="si">{</span><span class="n">output</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span> <span class="c1"># (2, 10, 64)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Weights shape: </span><span class="si">{</span><span class="n">weights</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span> <span class="c1"># (2, 10, 10)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Weights sum: </span><span class="si">{</span><span class="n">weights</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span> <span class="c1"># All ~1.0</span>
<span class="c1"># Test 2: Multi-head attention</span>
<span class="n">mha</span> <span class="o">=</span> <span class="n">MultiHeadAttention</span><span class="p">(</span><span class="n">embed_dim</span><span class="o">=</span><span class="mi">128</span><span class="p">,</span> <span class="n">num_heads</span><span class="o">=</span><span class="mi">8</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="mi">128</span><span class="p">))</span>
<span class="n">attended</span> <span class="o">=</span> <span class="n">mha</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Multi-head output: </span><span class="si">{</span><span class="n">attended</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span> <span class="c1"># (2, 10, 128)</span>
<span class="c1"># Test 3: Causal masking for language modeling</span>
<span class="n">causal_mask</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">tril</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="n">batch</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">))))</span>
<span class="n">causal_output</span><span class="p">,</span> <span class="n">causal_weights</span> <span class="o">=</span> <span class="n">scaled_dot_product_attention</span><span class="p">(</span><span class="n">Q</span><span class="p">,</span> <span class="n">K</span><span class="p">,</span> <span class="n">V</span><span class="p">,</span> <span class="n">causal_mask</span><span class="p">)</span>
<span class="c1"># Verify the entire upper triangle is zero (no attention to future positions)</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Future attention blocked:&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">allclose</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">triu</span><span class="p">(</span><span class="n">causal_weights</span><span class="o">.</span><span class="n">data</span><span class="p">,</span> <span class="n">k</span><span class="o">=</span><span class="mi">1</span><span class="p">),</span> <span class="mi">0</span><span class="p">))</span>
<span class="c1"># Test 4: Visualize attention patterns</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;</span><span class="se">\n</span><span class="s2">Attention pattern (position → position):&quot;</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">weights</span><span class="o">.</span><span class="n">data</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="p">:</span><span class="mi">5</span><span class="p">,</span> <span class="p">:</span><span class="mi">5</span><span class="p">]</span><span class="o">.</span><span class="n">round</span><span class="p">(</span><span class="mi">3</span><span class="p">))</span> <span class="c1"># First 5x5 submatrix</span>
</pre></div>
</div>
</section>
</section>
<section id="systems-thinking-questions">
<h2>Systems Thinking Questions<a class="headerlink" href="#systems-thinking-questions" title="Link to this heading">#</a></h2>
<section id="real-world-applications">
<h3>Real-World Applications<a class="headerlink" href="#real-world-applications" title="Link to this heading">#</a></h3>
<ul class="simple">
<li><p><strong>Large Language Models (GPT-3, GPT-4, Claude)</strong>: GPT-3 alone stacks 96 layers with 96 heads each—9,216 attention heads evaluated per forward pass; because attention-score computation scales as O(n²), attention's share of total compute grows with context length</p></li>
<li><p><strong>Machine Translation (Google Translate)</strong>: Cross-attention between source and target languages enables word alignment; attention weights provide interpretable translation decisions</p></li>
<li><p><strong>Vision Transformers (ViT)</strong>: Self-attention over image patches rivals or replaces convolutions in vision models at Google, Meta, and OpenAI; every patch sees the whole image from layer 1 instead of relying on deep CNN stacks to grow the receptive field</p></li>
<li><p><strong>Scientific AI (AlphaFold2)</strong>: Attention over amino-acid sequences and residue pairs (the Evoformer) captures long-range interactions; this attention-based architecture largely solved the 50-year-old protein structure prediction problem</p></li>
</ul>
</section>
<section id="mathematical-foundations">
<h3>Mathematical Foundations<a class="headerlink" href="#mathematical-foundations" title="Link to this heading">#</a></h3>
<ul class="simple">
<li><p><strong>Query-Key-Value Paradigm</strong>: Attention implements differentiable “search”—queries look for relevant keys and retrieve corresponding values</p></li>
<li><p><strong>Scaling Factor (1/√d_k)</strong>: If the components of Q and K are independent with zero mean and unit variance, each entry of QK^T has variance d_k; dividing by √d_k restores unit variance, keeping softmax out of its saturated, near-one-hot regime (critical for gradient flow)</p></li>
<li><p><strong>Softmax Normalization</strong>: Converts arbitrary scores to valid probability distribution; enables differentiable, learned routing mechanism</p></li>
<li><p><strong>Masking Implementation</strong>: Setting masked positions to -∞ before softmax makes them effectively zero attention weight after normalization</p></li>
</ul>
</section>
<section id="computational-characteristics">
<h3>Computational Characteristics<a class="headerlink" href="#computational-characteristics" title="Link to this heading">#</a></h3>
<ul class="simple">
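<li><p><strong>Quadratic Memory Scaling</strong>: The attention matrix is O(n²). At GPT-3 scale (2048 context, 96 heads, 96 layers), one FP32 attention map is 2048² × 4 bytes ≈ 16 MB per head, ~1.6 GB per layer per sequence, and ~150 GB if every layer's weights are kept for the backward pass—understanding this guides optimization priorities</p></li>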
<li><p><strong>Time-Memory Trade-off</strong>: Instead of storing the attention matrix, you can recompute it during the backward pass (gradient checkpointing), trading memory for an extra forward computation</p></li>
<li><p><strong>Parallelization Benefits</strong>: Unlike RNNs, all n² attention scores compute simultaneously; fully utilizes GPU parallelism for massive speedup</p></li>
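<li><p><strong>FlashAttention Breakthrough</strong>: Uses tiling and an online softmax so the full n×n matrix is never materialized in GPU memory—extra memory drops from O(n²) to O(n) while compute stays O(n²); far fewer memory reads/writes yield 2-4× speedups and make longer contexts (8K+ tokens) practical</p></li>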
</ul>
</section>
<section id="how-your-implementation-maps-to-pytorch">
<h3>How Your Implementation Maps to PyTorch<a class="headerlink" href="#how-your-implementation-maps-to-pytorch" title="Link to this heading">#</a></h3>
<p><strong>What you just built:</strong></p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="c1"># Your TinyTorch attention implementation</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">tinytorch.core.attention</span><span class="w"> </span><span class="kn">import</span> <span class="n">MultiHeadAttention</span>
<span class="c1"># Create multi-head attention</span>
<span class="n">mha</span> <span class="o">=</span> <span class="n">MultiHeadAttention</span><span class="p">(</span><span class="n">embed_dim</span><span class="o">=</span><span class="mi">512</span><span class="p">,</span> <span class="n">num_heads</span><span class="o">=</span><span class="mi">8</span><span class="p">)</span>
<span class="c1"># Forward pass</span>
<span class="n">query</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="o">...</span><span class="p">)</span> <span class="c1"># (batch, seq_len, embed_dim)</span>
<span class="n">key</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="o">...</span><span class="p">)</span>
<span class="n">value</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="o">...</span><span class="p">)</span>
<span class="c1"># Compute attention: YOUR implementation</span>
<span class="n">output</span><span class="p">,</span> <span class="n">attn_weights</span> <span class="o">=</span> <span class="n">mha</span><span class="p">(</span><span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">causal_mask</span><span class="p">)</span>
<span class="c1"># output shape: (batch, seq_len, embed_dim)</span>
<span class="c1"># attn_weights shape: (batch, num_heads, seq_len, seq_len)</span>
</pre></div>
</div>
<p><strong>How PyTorch does it:</strong></p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="c1"># PyTorch equivalent</span>
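<span class="kn">import</span><span class="w"> </span><span class="nn">torch</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">torch.nn</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">nn</span>
<span class="c1"># Create multi-head attention</span>
<span class="n">mha</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">MultiheadAttention</span><span class="p">(</span><span class="n">embed_dim</span><span class="o">=</span><span class="mi">512</span><span class="p">,</span> <span class="n">num_heads</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span> <span class="n">batch_first</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="c1"># Forward pass</span>
<span class="n">query</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="o">...</span><span class="p">)</span> <span class="c1"># (batch, seq_len, embed_dim)</span>
<span class="n">key</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="o">...</span><span class="p">)</span>
<span class="n">value</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="o">...</span><span class="p">)</span>
<span class="c1"># PyTorch boolean masks mark BLOCKED positions with True (the inverse of a tril-of-ones mask)</span>
<span class="n">seq_len</span> <span class="o">=</span> <span class="n">query</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
<span class="n">causal_mask</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">triu</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">),</span> <span class="n">diagonal</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="c1"># Compute attention: PyTorch implementation</span>
<span class="n">output</span><span class="p">,</span> <span class="n">attn_weights</span> <span class="o">=</span> <span class="n">mha</span><span class="p">(</span><span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span><span class="p">,</span> <span class="n">attn_mask</span><span class="o">=</span><span class="n">causal_mask</span><span class="p">,</span> <span class="n">average_attn_weights</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="c1"># Same shapes: output (batch, seq_len, embed_dim); per-head attn_weights (batch, num_heads, seq_len, seq_len)</span>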
</pre></div>
</div>
<p><strong>Key Insight</strong>: Your attention implementation computes the <strong>exact same mathematical formula</strong> that powers GPT, BERT, and every transformer model:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span>Attention(Q, K, V) = softmax(QK^T / √d_k) V
</pre></div>
</div>
<p>When you implement this with explicit loops, you viscerally understand the O(n²) memory scaling that limits context length in production transformers.</p>
<p><strong>What's the SAME?</strong></p>
<ul class="simple">
<li><p><strong>Core formula</strong>: Scaled dot-product attention (Vaswani et al., 2017)</p></li>
<li><p><strong>Multi-head architecture</strong>: Parallel attention in representation subspaces</p></li>
<li><p><strong>Masking patterns</strong>: Causal masking (GPT), padding masking (BERT)</p></li>
<li><p><strong>API design</strong>: <code class="docutils literal notranslate"><span class="pre">(query,</span> <span class="pre">key,</span> <span class="pre">value)</span></code> inputs, attention weights output</p></li>
<li><p><strong>Conceptual bottleneck</strong>: O(n²) memory for attention matrix</p></li>
</ul>
<p><strong>What's different in production PyTorch?</strong></p>
<ul class="simple">
<li><p><strong>Backend</strong>: C++/CUDA kernels ~10-100× faster than Python loops</p></li>
<li><p><strong>Memory optimization</strong>: Fused kernels avoid materializing full attention matrix</p></li>
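<li><p><strong>FlashAttention</strong>: PyTorch 2.0+ <code class="docutils literal notranslate"><span class="pre">scaled_dot_product_attention</span></code> dispatches to FlashAttention / memory-efficient kernels when the inputs allow it (O(n) extra memory vs your O(n²))</p></li>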
<li><p><strong>Multi-query / grouped-query attention</strong>: Production systems share key/value heads across query heads (MQA, GQA) to shrink the KV cache</p></li>
</ul>
<p><strong>Why this matters</strong>: When you see <code class="docutils literal notranslate"><span class="pre">RuntimeError:</span> <span class="pre">CUDA</span> <span class="pre">out</span> <span class="pre">of</span> <span class="pre">memory</span></code> training transformers with long sequences, you understand it's often the O(n²) attention matrix from YOUR implementation—doubling sequence length quadruples that matrix's memory. When papers mention “linear attention” or “flash attention”, you know they're solving the scaling bottleneck you experienced.</p>
<p><strong>Production usage example</strong>:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="c1"># PyTorch Transformer implementation (after TinyTorch)</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">torch</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">torch.nn</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">nn</span>
<span class="k">class</span><span class="w"> </span><span class="nc">TransformerBlock</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="o">=</span><span class="mi">512</span><span class="p">,</span> <span class="n">num_heads</span><span class="o">=</span><span class="mi">8</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
        <span class="c1"># Uses same multi-head attention you built</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">mha</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">MultiheadAttention</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">,</span> <span class="n">batch_first</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">ffn</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span>
            <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="mi">4</span> <span class="o">*</span> <span class="n">d_model</span><span class="p">),</span>
            <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(),</span>
            <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">4</span> <span class="o">*</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span>
        <span class="p">)</span>
    <span class="k">def</span><span class="w"> </span><span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
        <span class="c1"># Same pattern you implemented</span>
        <span class="n">attn_out</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">mha</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">attn_mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span> <span class="c1"># YOUR attention logic</span>
        <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">attn_out</span> <span class="c1"># Residual connection</span>
        <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">ffn</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">x</span>
</pre></div>
</div>
<p>After implementing attention yourself, you understand that GPT's causal attention is your <code class="docutils literal notranslate"><span class="pre">mask=causal_mask</span></code>, BERT's bidirectional attention simply omits the causal mask (a <code class="docutils literal notranslate"><span class="pre">mask=padding_mask</span></code> only hides pad tokens), and every transformer's O(n²) scaling comes from the attention matrix you explicitly computed in your implementation.</p>
</section>
</section>
<section id="ready-to-build">
<h2>Ready to Build?<a class="headerlink" href="#ready-to-build" title="Link to this heading">#</a></h2>
<p>You're about to implement the mechanism that sparked the AI revolution and powers every modern language model. Understanding attention from first principles—including its computational bottlenecks—will give you deep insight into why transformers dominate AI and what limitations remain.</p>
<p><strong>Your Mission</strong>: Implement scaled dot-product attention with explicit loops to viscerally understand O(n²) complexity. Build multi-head attention that processes parallel representation subspaces. Master causal and padding masking for different architectural patterns. Test on real sequences, visualize attention patterns, and profile computational scaling.</p>
<p><strong>Why This Matters</strong>: The attention mechanism you're building didn't just improve NLP—it became a shared building block across nearly every domain of deep learning. GPT, BERT, Vision Transformers, AlphaFold, DALL-E, and Claude all use the exact formula you're implementing. Understanding attention's power (global context, parallelizable) and limitations (quadratic scaling) is essential for working with production AI systems.</p>
<p><strong>After Completion</strong>: Module 13 (Transformers) will combine your attention with feedforward layers and normalization to build complete transformer blocks. Module 14 (Profiling) will measure your attention's O(n²) scaling and identify optimization opportunities. Module 18 (Acceleration) will implement FlashAttention-style optimizations for your mechanism.</p>
<p>Choose your preferred way to engage with this module:</p>
<div class="sd-container-fluid sd-sphinx-override sd-mb-4 docutils">
<div class="sd-row sd-row-cols-1 sd-row-cols-xs-1 sd-row-cols-sm-2 sd-row-cols-md-3 sd-row-cols-lg-3 docutils">
<div class="sd-col sd-d-flex-row docutils">
<div class="sd-card sd-sphinx-override sd-w-100 sd-shadow-sm sd-card-hover docutils">
<div class="sd-card-body docutils">
<div class="sd-card-title sd-font-weight-bold docutils">
🚀 Launch Binder</div>
<p class="sd-card-text">Run this module interactively in your browser. No installation required!</p>
</div>
<a class="sd-stretched-link sd-hide-link-text reference external" href="https://mybinder.org/v2/gh/mlsysbook/TinyTorch/main?filepath=modules/12_attention/attention_dev.ipynb"><span>https://mybinder.org/v2/gh/mlsysbook/TinyTorch/main?filepath=modules/12_attention/attention_dev.ipynb</span></a></div>
</div>
<div class="sd-col sd-d-flex-row docutils">
<div class="sd-card sd-sphinx-override sd-w-100 sd-shadow-sm sd-card-hover docutils">
<div class="sd-card-body docutils">
<div class="sd-card-title sd-font-weight-bold docutils">
⚡ Open in Colab</div>
<p class="sd-card-text">Use Google Colab for GPU access and cloud compute power.</p>
</div>
<a class="sd-stretched-link sd-hide-link-text reference external" href="https://colab.research.google.com/github/mlsysbook/TinyTorch/blob/main/modules/12_attention/attention_dev.ipynb"><span>https://colab.research.google.com/github/mlsysbook/TinyTorch/blob/main/modules/12_attention/attention_dev.ipynb</span></a></div>
</div>
<div class="sd-col sd-d-flex-row docutils">
<div class="sd-card sd-sphinx-override sd-w-100 sd-shadow-sm sd-card-hover docutils">
<div class="sd-card-body docutils">
<div class="sd-card-title sd-font-weight-bold docutils">
📖 View Source</div>
<p class="sd-card-text">Browse the notebook source code and understand the implementation.</p>
</div>
<a class="sd-stretched-link sd-hide-link-text reference external" href="https://github.com/mlsysbook/TinyTorch/blob/main/modules/12_attention/attention_dev.ipynb"><span>https://github.com/mlsysbook/TinyTorch/blob/main/modules/12_attention/attention_dev.ipynb</span></a></div>
</div>
</div>
</div>
<div class="tip admonition">
<p class="admonition-title">💾 Save Your Progress</p>
<p><strong>Binder sessions are temporary!</strong> Download your completed notebook when done, or switch to local development for persistent work.</p>
</div>
<hr class="docutils" />
<div class="prev-next-area">
<a class="left-prev" href="11_embeddings_ABOUT.html" title="previous page">← Module 11: Embeddings</a>
<a class="right-next" href="13_transformers_ABOUT.html" title="next page">Module 13: Transformers →</a>
</div>
</section>
</section>
<script type="text/x-thebe-config">
{
requestKernel: true,
binderOptions: {
repo: "binder-examples/jupyter-stacks-datascience",
ref: "master",
},
codeMirrorConfig: {
theme: "abcdef",
mode: "python"
},
kernelOptions: {
name: "python3",
path: "./modules"
},
predefinedOutput: true
}
</script>
<script>kernelName = 'python3'</script>
</article>
<footer class="prev-next-footer d-print-none">
<div class="prev-next-area">
<a class="left-prev"
href="11_embeddings_ABOUT.html"
title="previous page">
<i class="fa-solid fa-angle-left"></i>
<div class="prev-next-info">
<p class="prev-next-subtitle">previous</p>
<p class="prev-next-title">11. Embeddings - Token to Vector Representations</p>
</div>
</a>
<a class="right-next"
href="13_transformers_ABOUT.html"
title="next page">
<div class="prev-next-info">
<p class="prev-next-subtitle">next</p>
<p class="prev-next-title">13. Transformers - Complete GPT Architecture</p>
</div>
<i class="fa-solid fa-angle-right"></i>
</a>
</div>
</footer>
</div>
<div class="bd-sidebar-secondary bd-toc"><div class="sidebar-secondary-items sidebar-secondary__inner">
<div class="sidebar-secondary-item">
<div class="page-toc tocsection onthispage">
<i class="fa-solid fa-list"></i> Contents
</div>
<nav class="bd-toc-nav page-toc">
<ul class="visible nav section-nav flex-column">
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#overview">Overview</a></li>
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#learning-objectives">Learning Objectives</a></li>
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#build-use-reflect">Build → Use → Reflect</a></li>
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#implementation-guide">Implementation Guide</a><ul class="nav section-nav flex-column">
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#attention-mechanism-flow">Attention Mechanism Flow</a></li>
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#core-components">Core Components</a><ul class="nav section-nav flex-column">
<li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#scaled-dot-product-attention-scaled-dot-product-attention">1. Scaled Dot-Product Attention (<code class="docutils literal notranslate"><span class="pre">scaled_dot_product_attention</span></code>)</a></li>
<li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#multi-head-attention-multiheadattention">2. Multi-Head Attention (<code class="docutils literal notranslate"><span class="pre">MultiHeadAttention</span></code>)</a></li>
<li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#masking-utilities">3. Masking Utilities</a></li>
</ul>
</li>
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#attention-complexity-analysis">Attention Complexity Analysis</a><ul class="nav section-nav flex-column">
<li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#time-complexity-o-n2-d">Time Complexity: O(n² × d)</a></li>
<li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#memory-complexity-o-batch-heads-n2">Memory Complexity: O(batch × heads × n²)</a></li>
</ul>
</li>
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#comparing-to-pytorch">Comparing to PyTorch</a></li>
</ul>
</li>
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#getting-started">Getting Started</a><ul class="nav section-nav flex-column">
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#prerequisites">Prerequisites</a></li>
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#development-workflow">Development Workflow</a></li>
</ul>
</li>
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#testing">Testing</a><ul class="nav section-nav flex-column">
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#comprehensive-test-suite">Comprehensive Test Suite</a></li>
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#test-coverage-areas">Test Coverage Areas</a></li>
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#inline-testing-complexity-analysis">Inline Testing &amp; Complexity Analysis</a></li>
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#manual-testing-examples">Manual Testing Examples</a></li>
</ul>
</li>
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#systems-thinking-questions">Systems Thinking Questions</a><ul class="nav section-nav flex-column">
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#real-world-applications">Real-World Applications</a></li>
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#mathematical-foundations">Mathematical Foundations</a></li>
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#computational-characteristics">Computational Characteristics</a></li>
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#how-your-implementation-maps-to-pytorch">How Your Implementation Maps to PyTorch</a></li>
</ul>
</li>
<li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#ready-to-build">Ready to Build?</a></li>
</ul>
</nav></div>
</div></div>
</div>
<footer class="bd-footer-content">
<div class="bd-footer-content__inner container">
<div class="footer-item">
<p class="component-author">
By Prof. Vijay Janapa Reddi (Harvard University)
</p>
</div>
<div class="footer-item">
<p class="copyright">
© Copyright 2025.
<br/>
</p>
</div>
<div class="footer-item">
</div>
<div class="footer-item">
</div>
</div>
</footer>
</main>
</div>
</div>
<!-- Scripts loaded after <body> so the DOM is not blocked -->
<script src="../_static/scripts/bootstrap.js?digest=dfe6caa3a7d634c4db9b"></script>
<script src="../_static/scripts/pydata-sphinx-theme.js?digest=dfe6caa3a7d634c4db9b"></script>
<footer class="bd-footer">
</footer>
</body>
</html>