Files
cs249r_book/book/tools/scripts/genai/generate_comm_primitives_diagram.py
Vijay Janapa Reddi 73a956a09b chore(volumes,vscode-ext): batch volume updates and tooling improvements
Checkpoint the branch-wide content/config revisions together with workbench enhancements so chapter rendering and developer workflows stay aligned. This captures the current validation-driven formatting and parallel build/debug improvements in one commit.
2026-02-15 14:03:27 -05:00

154 lines
4.5 KiB
Python

import sys
import os
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
# Make the shared ``viz`` style module importable. This script lives in
# book/tools/scripts/genai/, so "../../../quarto/mlsys" resolves to
# book/quarto/mlsys.
sys.path.append(
    os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../quarto/mlsys"))
)
try:
    import viz

    viz.set_book_style()
    # Take a copy: the "YellowLine" entry added below must not be written into
    # viz.COLORS itself, which belongs to the viz module.
    # NOTE(review): assumes viz.COLORS is a mapping (dict-like) — confirm.
    COLORS = dict(viz.COLORS)
except ImportError:
    # Diagnostics go to stderr so they stay separate from normal script output.
    print("Warning: Could not import viz.py, using fallback style.", file=sys.stderr)
    COLORS = {
        "primary": "#333333",
        "RedLine": "#CB202D",
        "BlueLine": "#006395",
        "GreenLine": "#008F45",
        "OrangeLine": "#E67817",
        "grid": "#CCCCCC",
    }
    # The "seaborn-v0_8-*" style names only exist in matplotlib >= 3.6; older
    # releases raise for an unknown style, so apply it only when it is available.
    if "seaborn-v0_8-whitegrid" in plt.style.available:
        plt.style.use("seaborn-v0_8-whitegrid")

# Custom yellow that fits the palette better than pure yellow but is distinct
# from OrangeLine.
COLORS["YellowLine"] = "#F4D03F"
def draw_node(ax, x, y, label=None, radius=0.15, color="#E0E0E0"):
    """Add a filled circle centred at (x, y) to *ax*, optionally labelled.

    Args:
        ax: Matplotlib Axes to draw on.
        x: Horizontal centre of the node, in data coordinates.
        y: Vertical centre of the node, in data coordinates.
        label: Text drawn at the node centre; skipped when falsy.
        radius: Circle radius in data units.
        color: Fill colour of the circle.

    Returns:
        The ``matplotlib.patches.Circle`` that was added to *ax*.
    """
    ink = COLORS["primary"]
    node = patches.Circle(
        (x, y),
        radius,
        facecolor=color,
        edgecolor=ink,
        linewidth=1.5,
        zorder=10,
    )
    ax.add_patch(node)

    # Guard clause: unlabelled nodes need nothing further.
    if not label:
        return node

    # zorder 11 keeps the label above the circle (zorder 10).
    ax.text(
        x,
        y,
        label,
        ha="center",
        va="center",
        fontsize=14,
        fontweight="bold",
        zorder=11,
        color=ink,
    )
    return node
def draw_rect(ax, x, y, width, height, color):
    """Draw a borderless data rectangle centred at (x, y).

    Args:
        ax: Matplotlib Axes to draw on.
        x: Horizontal centre of the rectangle, in data coordinates.
        y: Vertical centre of the rectangle, in data coordinates.
        width: Rectangle width in data units.
        height: Rectangle height in data units.
        color: Fill colour.

    Returns:
        The ``matplotlib.patches.Rectangle`` that was added to *ax*.
    """
    rect = patches.Rectangle(
        (x - width / 2, y - height / 2),  # Rectangle is anchored at its lower-left corner
        width,
        height,
        facecolor=color,
        # edgecolor=None means "use the rcParams default", which draws an
        # outline whenever the active style sets patch.force_edgecolor;
        # "none" guarantees a borderless block regardless of style.
        edgecolor="none",
        zorder=15,  # above nodes (zorder 10) so the block sits on top of its circle
    )
    ax.add_patch(rect)
    return rect
def draw_arrow(ax, start, end):
    """Draw a filled-head arrow on *ax* pointing from *start* to *end*.

    The arrow is rendered as an empty-text annotation. ``shrinkA`` and
    ``shrinkB`` are zero so the line spans exactly between the two points.

    Args:
        ax: Matplotlib Axes to draw on.
        start: (x, y) tail of the arrow, in data coordinates.
        end: (x, y) head of the arrow, in data coordinates.
    """
    arrow_style = {
        "arrowstyle": "-|>",
        "color": COLORS["primary"],
        "lw": 1.5,
        "shrinkA": 0,
        "shrinkB": 0,
    }
    # zorder 5 keeps arrows underneath nodes (10) and data blocks (15).
    ax.annotate("", xy=end, xytext=start, arrowprops=arrow_style, zorder=5)
def setup_subplot(ax, title):
    """Prepare *ax* as a blank canvas for one primitive panel.

    The function:

    - fixes the data limits shared by every panel;
    - forces a 1:1 data aspect so circles drawn in data units stay round;
    - hides the axis frame and ticks;
    - writes *title* centred beneath the drawing.

    Args:
        ax: Matplotlib Axes to configure.
        title: Panel caption drawn at the bottom.
    """
    ax.set_xlim(-1.2, 1.2)
    ax.set_ylim(-0.2, 1.8)
    # The x span (2.4) and y span (2.0) are mapped onto an axes box whose
    # proportions don't match, so without an equal aspect Circle patches
    # render as ellipses.
    ax.set_aspect("equal")
    ax.axis("off")
    # Title at the bottom, inside the y-limits so it is not clipped.
    ax.text(
        0,
        -0.15,
        title,
        ha="center",
        va="center",
        fontsize=14,
        fontweight="bold",
        color=COLORS["primary"],
    )
# Panel geometry shared by every primitive (data coordinates).
_SRC_Y = 0.2  # y of the single bottom node
_DEST_Y = 1.4  # y of the row of top nodes
_NODE_RADIUS = 0.15  # node radius; arrows start/end on the node rims
_BLOCK_W, _BLOCK_H = 0.08, 0.15  # a single node's data block
_SLICE_W, _SLICE_H = 0.06, 0.12  # one slice of the composite (partitioned) block


def _draw_composite(ax, x, y, palette):
    """Draw one slice per colour side by side, centred as a group at (x, y)."""
    total_w = len(palette) * _SLICE_W
    left_centre = x - total_w / 2 + _SLICE_W / 2
    for i, color in enumerate(palette):
        draw_rect(ax, left_centre + i * _SLICE_W, y, _SLICE_W, _SLICE_H, color)


def _draw_broadcast(ax, xs, color):
    """Broadcast panel: the bottom node's block is copied to every top node."""
    setup_subplot(ax, "Broadcast")
    draw_node(ax, 0, _SRC_Y, radius=_NODE_RADIUS)
    draw_rect(ax, 0, _SRC_Y, _BLOCK_W, _BLOCK_H, color)
    for x in xs:
        draw_node(ax, x, _DEST_Y, radius=_NODE_RADIUS)
        draw_rect(ax, x, _DEST_Y, _BLOCK_W, _BLOCK_H, color)
        draw_arrow(ax, (0, _SRC_Y + _NODE_RADIUS), (x, _DEST_Y - _NODE_RADIUS))


def _draw_partitioned(ax, title, xs, palette, *, outward):
    """Scatter (outward=True) or Gather (outward=False) panel.

    The bottom node holds the full composite block and each top node holds
    one slice. Only the arrow direction differs between the two primitives.
    """
    setup_subplot(ax, title)
    draw_node(ax, 0, _SRC_Y, radius=_NODE_RADIUS)
    _draw_composite(ax, 0, _SRC_Y, palette)
    hub = (0, _SRC_Y + _NODE_RADIUS)  # top rim of the bottom node
    for x, color in zip(xs, palette):
        draw_node(ax, x, _DEST_Y, radius=_NODE_RADIUS)
        draw_rect(ax, x, _DEST_Y, _BLOCK_W, _BLOCK_H, color)
        leaf = (x, _DEST_Y - _NODE_RADIUS)  # bottom rim of this top node
        if outward:
            draw_arrow(ax, hub, leaf)
        else:
            draw_arrow(ax, leaf, hub)


def _draw_reduction(ax, xs, values):
    """Reduction panel: top-node values flow down and are summed at the bottom."""
    setup_subplot(ax, "Reduction")
    # Derive the result from the operands so the label can never disagree with them.
    draw_node(ax, 0, _SRC_Y, label=str(sum(values)), radius=_NODE_RADIUS)
    for x, value in zip(xs, values):
        draw_node(ax, x, _DEST_Y, label=str(value), radius=_NODE_RADIUS)
        draw_arrow(ax, (x, _DEST_Y - _NODE_RADIUS), (0, _SRC_Y + _NODE_RADIUS))


def generate_diagram(output_path=None):
    """Render the four communication primitives as a 2x2 figure and save it.

    Panels in row-major order: Broadcast, Scatter, Gather, Reduction.

    Args:
        output_path: Destination PNG path. Defaults to ``comm_primitives.png``
            in the current working directory.

    Returns:
        The absolute path of the written image.
    """
    if output_path is None:
        output_path = "comm_primitives.png"
    output_path = os.path.abspath(output_path)

    fig, axes = plt.subplots(2, 2, figsize=(10, 8))
    try:
        # Horizontal positions of the four top nodes.
        xs = np.linspace(-0.9, 0.9, 4)
        palette = [
            COLORS["RedLine"],
            COLORS["YellowLine"],
            COLORS["GreenLine"],
            COLORS["BlueLine"],
        ]

        _draw_broadcast(axes[0, 0], xs, palette[0])
        _draw_partitioned(axes[0, 1], "Scatter", xs, palette, outward=True)
        _draw_partitioned(axes[1, 0], "Gather", xs, palette, outward=False)
        _draw_reduction(axes[1, 1], xs, [1, 3, 5, 7])

        fig.tight_layout()
        fig.savefig(output_path, dpi=300, bbox_inches="tight")
    finally:
        # pyplot keeps every figure alive in its registry until closed; release
        # it even if drawing or saving fails.
        plt.close(fig)

    print(f"Generated image at: {output_path}")
    return output_path
# Script entry point: build and save the diagram when run directly
# (importing this module does not render anything).
if __name__ == "__main__":
    generate_diagram()