mirror of
https://github.com/harvard-edge/cs249r_book.git
synced 2026-05-03 08:08:51 -05:00
Checkpoint the branch-wide content/config revisions together with workbench enhancements so chapter rendering and developer workflows stay aligned. This captures the current validation-driven formatting and parallel build/debug improvements in one commit.
154 lines
4.5 KiB
Python
154 lines
4.5 KiB
Python
|
|
import sys
|
|
import os
|
|
import matplotlib.pyplot as plt
|
|
import matplotlib.patches as patches
|
|
import numpy as np
|
|
|
|
# Make the shared plotting module (book/quarto/mlsys/viz.py) importable.
# This script lives in book/tools/scripts/genai/, so climbing three levels
# from its directory reaches book/, and quarto/mlsys hangs off that.
sys.path.append(
    os.path.abspath(
        os.path.join(
            os.path.dirname(__file__),
            "../../../quarto/mlsys",
        )
    )
)
|
|
|
|
# Prefer the book's shared style and palette from viz.py; fall back to a local
# palette plus a stock matplotlib style when viz cannot be imported.
try:
    import viz
except ImportError:
    print("Warning: Could not import viz.py, using fallback style.", file=sys.stderr)
    COLORS = {
        "primary": "#333333",
        "RedLine": "#CB202D",
        "BlueLine": "#006395",
        "GreenLine": "#008F45",
        "OrangeLine": "#E67817",
        "grid": "#CCCCCC",
    }
    plt.style.use('seaborn-v0_8-whitegrid')
else:
    # Only the import itself is guarded: an error raised while applying the
    # book style should surface, not silently switch to the fallback palette.
    viz.set_book_style()
    # Take a private copy so the YellowLine entry added below does not mutate
    # (or overwrite an entry of) the palette dict owned by the viz module.
    COLORS = dict(viz.COLORS)

# Custom yellow that fits the palette better than pure yellow but is distinct
# from orange.
COLORS["YellowLine"] = "#F4D03F"
|
|
|
|
def draw_node(ax, x, y, label=None, radius=0.15, color="#E0E0E0"):
    """Add a filled circle to ``ax`` and return the ``Circle`` patch.

    The circle is centered at ``(x, y)`` in data coordinates, has the given
    ``radius`` and fill ``color``, and is outlined in the palette's primary
    color at zorder 10. If ``label`` is truthy, it is written in bold at the
    center, one zorder level above the circle.
    """
    outline = COLORS["primary"]
    node = patches.Circle(
        (x, y),
        radius,
        facecolor=color,
        edgecolor=outline,
        linewidth=1.5,
        zorder=10,
    )
    ax.add_patch(node)

    if not label:
        return node

    ax.text(
        x, y, label,
        ha='center', va='center',
        fontsize=14, fontweight='bold',
        zorder=11, color=outline,
    )
    return node
|
|
|
|
def draw_rect(ax, x, y, width, height, color):
    """Draw a borderless, solid-colored data block centered at ``(x, y)``.

    Args:
        ax: Axes to draw on.
        x, y: Center of the block in data coordinates.
        width, height: Block size in data units.
        color: Fill color.

    Returns:
        The added ``Rectangle`` patch (drawn at zorder 15).
    """
    rect = patches.Rectangle(
        (x - width / 2, y - height / 2), width, height,
        facecolor=color,
        # The string "none" disables the outline. edgecolor=None would instead
        # mean "use the rcParams default", which draws an edge whenever the
        # active style sets patch.force_edgecolor.
        edgecolor="none",
        zorder=15,
    )
    ax.add_patch(rect)
    return rect
|
|
|
|
def draw_arrow(ax, start, end):
    """Draw a straight arrow on ``ax`` with its tail at ``start`` and head at ``end``.

    Both points are (x, y) data coordinates. The arrow is a text-less
    annotation, so only the ``-|>`` arrow is rendered. It uses the primary
    palette color, is not shrunk at either end, and sits at zorder 5 (below the
    zorder-10 circles drawn by draw_node).
    """
    style = {
        "arrowstyle": "-|>",
        "color": COLORS["primary"],
        "lw": 1.5,
        "shrinkA": 0,
        "shrinkB": 0,
    }
    ax.annotate("", xytext=start, xy=end, arrowprops=style, zorder=5)
|
|
|
|
def setup_subplot(ax, title):
    """Prepare one diagram panel: fixed data window, hidden axes, caption below.

    The data window is x in [-1.2, 1.2] and y in [-0.2, 1.8]. ``title`` is
    drawn as bold text at (0, -0.15), i.e. underneath the panel's contents.
    """
    ax.set_xlim(-1.2, 1.2)
    ax.set_ylim(-0.2, 1.8)
    # Circle patches are sized in data units, so they only render round when
    # one x unit and one y unit span the same on-screen distance. The 'box'
    # mode keeps the limits above and resizes the axes box instead.
    ax.set_aspect('equal', adjustable='box')
    ax.axis('off')
    # Caption goes at the bottom, so use ax.text rather than ax.set_title (top).
    ax.text(0, -0.15, title, ha='center', va='center', fontsize=14,
            fontweight='bold', color=COLORS["primary"])
|
|
|
|
# Layout constants shared by the panel helpers below (data coordinates; the
# visible window of each panel is fixed by setup_subplot).
_SRC_Y = 0.2                     # y of the single node at the bottom of a panel
_DEST_Y = 1.4                    # y of the row of four nodes at the top
_NODE_RADIUS = 0.15              # draw_node's default radius; arrows end on it
_BLOCK_W, _BLOCK_H = 0.08, 0.15  # one data block drawn inside a node
_SEG_W, _SEG_H = 0.06, 0.12      # one segment of the multi-color composite block


def _draw_composite_block(ax, y, palette):
    """Draw one segment per palette color, side by side, centered on x=0."""
    left = -len(palette) * _SEG_W / 2 + _SEG_W / 2
    for i, color in enumerate(palette):
        draw_rect(ax, left + i * _SEG_W, y, _SEG_W, _SEG_H, color)


def _link(ax, x_top, *, upward):
    """Arrow between the bottom node (x=0) and the top node at ``x_top``.

    Endpoints sit on the circles' edges. ``upward=True`` points from the bottom
    node to the top node; ``upward=False`` points the other way.
    """
    bottom = (0, _SRC_Y + _NODE_RADIUS)   # top edge of the bottom node
    top = (x_top, _DEST_Y - _NODE_RADIUS)  # bottom edge of the top node
    if upward:
        draw_arrow(ax, bottom, top)
    else:
        draw_arrow(ax, top, bottom)


def _draw_broadcast(ax, xs, color):
    """Broadcast panel: the bottom node's single block is copied to every top node."""
    setup_subplot(ax, "Broadcast")
    draw_node(ax, 0, _SRC_Y)
    draw_rect(ax, 0, _SRC_Y, _BLOCK_W, _BLOCK_H, color)
    for x in xs:
        draw_node(ax, x, _DEST_Y)
        draw_rect(ax, x, _DEST_Y, _BLOCK_W, _BLOCK_H, color)
        _link(ax, x, upward=True)


def _draw_partitioned(ax, title, xs, palette, *, upward):
    """Scatter/Gather panel: bottom node holds every segment, top node i holds segment i.

    ``upward=True`` draws Scatter (arrows bottom -> top); ``upward=False`` draws
    Gather (arrows top -> bottom). Nodes and blocks are identical in both.
    """
    setup_subplot(ax, title)
    draw_node(ax, 0, _SRC_Y)
    _draw_composite_block(ax, _SRC_Y, palette)
    for x, color in zip(xs, palette):
        draw_node(ax, x, _DEST_Y)
        draw_rect(ax, x, _DEST_Y, _BLOCK_W, _BLOCK_H, color)
        _link(ax, x, upward=upward)


def _draw_reduction(ax, xs, values):
    """Reduction panel: top-node values flow down and are summed at the bottom node."""
    setup_subplot(ax, "Reduction")
    # Derive the result label from the inputs so the figure cannot show a wrong sum.
    draw_node(ax, 0, _SRC_Y, label=str(sum(values)))
    for x, value in zip(xs, values):
        draw_node(ax, x, _DEST_Y, label=str(value))
        _link(ax, x, upward=False)


def generate_diagram(output_path=None):
    """Render the 2x2 communication-primitives figure and save it as a PNG.

    Panels: Broadcast (top left), Scatter (top right), Gather (bottom left),
    Reduction (bottom right). The figure is saved at 300 dpi with a tight
    bounding box, and the saved path is printed.

    Args:
        output_path: Destination file. Defaults to "comm_primitives.png",
            resolved against the current working directory.

    Returns:
        The absolute path of the written image.
    """
    if output_path is None:
        output_path = "comm_primitives.png"
    output_path = os.path.abspath(output_path)

    # Four evenly spaced columns for the top row of nodes.
    xs = np.linspace(-0.9, 0.9, 4)
    palette = [COLORS["RedLine"], COLORS["YellowLine"],
               COLORS["GreenLine"], COLORS["BlueLine"]]

    fig, axes = plt.subplots(2, 2, figsize=(10, 8))
    try:
        _draw_broadcast(axes[0, 0], xs, palette[0])
        _draw_partitioned(axes[0, 1], "Scatter", xs, palette, upward=True)
        _draw_partitioned(axes[1, 0], "Gather", xs, palette, upward=False)
        _draw_reduction(axes[1, 1], xs, [1, 3, 5, 7])

        fig.tight_layout()
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even if drawing or saving fails; pyplot otherwise
        # keeps it registered for the rest of the process.
        plt.close(fig)

    print(f"Generated image at: {output_path}")
    return output_path
|
|
|
|
if __name__ == "__main__":
    # Executed directly (not imported): render the figure and write the PNG.
    generate_diagram()
|