mirror of
https://github.com/MLSysBook/TinyTorch.git
synced 2025-12-05 19:17:52 -06:00
Add table support to ASCII box fixer and fix table alignment
- Add table detection (┬ ┼ ┴ column separators) - Fix table alignment by adjusting cell widths - Flag tables with content wider than headers for manual review - Manually fix tables in 04_losses.py (expanded column widths) - Fix table in 01_tensor.py
This commit is contained in:
@@ -834,7 +834,7 @@ Neural Network Usage:
|
||||
┌─────────────────────┬─────────────────────┬─────────────────────┐
|
||||
│ Weight Matrices │ Attention Mechanism │ Gradient Computation│
|
||||
├─────────────────────┼─────────────────────┼─────────────────────┤
|
||||
│ Forward: X @ W │ Q @ K^T attention │ ∂L/∂W = X^T @ ∂L/∂Y│
|
||||
│ Forward: X @ W │ Q @ K^T attention │ ∂L/∂W = X^T @ ∂L/∂Y │
|
||||
│ Backward: X @ W^T │ scores │ │
|
||||
└─────────────────────┴─────────────────────┴─────────────────────┘
|
||||
```
|
||||
|
||||
@@ -532,14 +532,14 @@ Loss: -log(0.003) = 5.8 ← Very high loss ❌
|
||||
```
|
||||
What Cross-Entropy Teaches the Model:
|
||||
|
||||
┌─────────────────┬─────────────────┬─────────────────┐
|
||||
│ Prediction │ True Label │ Learning Signal │
|
||||
├─────────────────┼─────────────────┼─────────────────┤
|
||||
│ Confident ✅ │ Correct ✅ │ "Keep doing this"│
|
||||
│ Uncertain ⚠️ │ Correct ✅ │ "Be more confident"│
|
||||
│ Confident ❌ │ Wrong ❌ │ "STOP! Change everything"│
|
||||
│ Uncertain ⚠️ │ Wrong ❌ │ "Learn the right answer"│
|
||||
└─────────────────┴─────────────────┴─────────────────┘
|
||||
┌─────────────────┬─────────────────┬───────────────────────────┐
|
||||
│ Prediction │ True Label │ Learning Signal │
|
||||
├─────────────────┼─────────────────┼───────────────────────────┤
|
||||
│ Confident ✅ │ Correct ✅ │ "Keep doing this" │
|
||||
│ Uncertain ⚠️ │ Correct ✅ │ "Be more confident" │
|
||||
│ Confident ❌ │ Wrong ❌ │ "STOP! Change everything" │
|
||||
│ Uncertain ⚠️ │ Wrong ❌ │ "Learn the right answer" │
|
||||
└─────────────────┴─────────────────┴───────────────────────────┘
|
||||
|
||||
Loss Landscape by Confidence:
|
||||
Loss
|
||||
@@ -1043,26 +1043,26 @@ Different loss functions have different computational costs, especially at scale
|
||||
Computational Cost Comparison (Batch Size B, Classes C):
|
||||
|
||||
MSELoss:
|
||||
┌───────────────┬───────────────┐
|
||||
┌────────────────┬────────────────┐
|
||||
│ Operation │ Complexity │
|
||||
├───────────────┼───────────────┤
|
||||
├────────────────┼────────────────┤
|
||||
│ Subtraction │ O(B) │
|
||||
│ Squaring │ O(B) │
|
||||
│ Mean │ O(B) │
|
||||
│ Total │ O(B) │
|
||||
└───────────────┴───────────────┘
|
||||
└────────────────┴────────────────┘
|
||||
|
||||
CrossEntropyLoss:
|
||||
┌───────────────┬───────────────┐
|
||||
┌────────────────┬────────────────┐
|
||||
│ Operation │ Complexity │
|
||||
├───────────────┼───────────────┤
|
||||
├────────────────┼────────────────┤
|
||||
│ Max (stability)│ O(B*C) │
|
||||
│ Exponential │ O(B*C) │
|
||||
│ Sum │ O(B*C) │
|
||||
│ Log │ O(B) │
|
||||
│ Indexing │ O(B) │
|
||||
│ Total │ O(B*C) │
|
||||
└───────────────┴───────────────┘
|
||||
└────────────────┴────────────────┘
|
||||
|
||||
Cross-entropy is C times more expensive than MSE!
|
||||
For ImageNet (C=1000), CE is 1000x more expensive than MSE.
|
||||
|
||||
@@ -1,23 +1,24 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Fix ASCII Box Alignment
|
||||
Fix ASCII Box and Table Alignment
|
||||
|
||||
This script finds simple ASCII art boxes in Python files and ensures the
|
||||
This script finds ASCII art boxes and tables in Python files and ensures the
|
||||
right-side vertical bars (│) are perfectly aligned with the top border.
|
||||
|
||||
Handles:
|
||||
- Simple boxes (content lines with exactly 2 │)
|
||||
- Boxes with ├───┤ separator lines
|
||||
- Tables with columns (┬ ┼ ┴ separators)
|
||||
|
||||
Skips (requires manual review):
|
||||
- Nested boxes (content lines with more than 2 │)
|
||||
- Nested boxes (boxes inside boxes)
|
||||
- Side-by-side boxes
|
||||
- Dashed boxes
|
||||
|
||||
Usage:
|
||||
python tools/fix_ascii_boxes.py # Preview changes (dry run)
|
||||
python tools/fix_ascii_boxes.py --fix # Apply fixes
|
||||
python tools/fix_ascii_boxes.py --verbose # Show detailed info
|
||||
python tools/dev/fix_ascii_boxes.py # Preview changes (dry run)
|
||||
python tools/dev/fix_ascii_boxes.py --fix # Apply fixes
|
||||
python tools/dev/fix_ascii_boxes.py --verbose # Show detailed info
|
||||
"""
|
||||
|
||||
import sys
|
||||
@@ -185,6 +186,185 @@ def fix_box_alignment(box_lines: list[str]) -> list[str]:
|
||||
return fixed_lines
|
||||
|
||||
|
||||
def find_tables(content: str) -> list[tuple[int, int, list[str]]]:
|
||||
"""
|
||||
Find ASCII tables (boxes with column separators ┬ ┼ ┴).
|
||||
|
||||
Returns list of (start_line, end_line, lines) tuples.
|
||||
"""
|
||||
lines = content.split('\n')
|
||||
tables = []
|
||||
i = 0
|
||||
|
||||
while i < len(lines):
|
||||
line = lines[i]
|
||||
|
||||
# Look for table start: ┌...┬...┐ (has column separator)
|
||||
if '┌' in line and '┐' in line and '┬' in line:
|
||||
first_corner = line.index('┌')
|
||||
|
||||
# Skip nested (│ before ┌)
|
||||
if '│' in line[:first_corner]:
|
||||
i += 1
|
||||
continue
|
||||
|
||||
table_start = i
|
||||
table_lines = [line]
|
||||
left_pos = first_corner
|
||||
|
||||
# Collect table lines
|
||||
j = i + 1
|
||||
while j < len(lines) and j - i < 100:
|
||||
current_line = lines[j]
|
||||
|
||||
# Content line (multiple │)
|
||||
if '│' in current_line:
|
||||
first_bar = current_line.index('│')
|
||||
if first_bar == left_pos:
|
||||
table_lines.append(current_line)
|
||||
j += 1
|
||||
else:
|
||||
break
|
||||
# Separator line with ┼
|
||||
elif '├' in current_line and '┤' in current_line:
|
||||
table_lines.append(current_line)
|
||||
j += 1
|
||||
# Bottom line with ┴
|
||||
elif '└' in current_line and '┘' in current_line:
|
||||
table_lines.append(current_line)
|
||||
tables.append((table_start, j, table_lines))
|
||||
i = j
|
||||
break
|
||||
else:
|
||||
break
|
||||
|
||||
i += 1
|
||||
else:
|
||||
i += 1
|
||||
|
||||
return tables
|
||||
|
||||
|
||||
def fix_table_alignment(table_lines: list[str]) -> list[str]:
|
||||
"""
|
||||
Fix alignment of a table with columns.
|
||||
|
||||
Strategy: Find the column separator positions from the TOP line,
|
||||
then adjust each content row's cells to match those widths.
|
||||
"""
|
||||
if len(table_lines) < 3:
|
||||
return table_lines
|
||||
|
||||
top_line = table_lines[0]
|
||||
left_pos = top_line.index('┌')
|
||||
right_pos = top_line.index('┐')
|
||||
total_width = right_pos - left_pos + 1
|
||||
|
||||
# Find column separator positions (┬) in top line
|
||||
separators = [i for i, c in enumerate(top_line) if c == '┬']
|
||||
|
||||
# Column boundaries: [left_pos, sep1, sep2, ..., right_pos]
|
||||
boundaries = [left_pos] + separators + [right_pos]
|
||||
|
||||
# Calculate column widths (between separators)
|
||||
col_widths = [boundaries[i+1] - boundaries[i] - 1 for i in range(len(boundaries) - 1)]
|
||||
|
||||
fixed_lines = [top_line] # Top line stays as-is (defines structure)
|
||||
|
||||
for line in table_lines[1:-1]:
|
||||
prefix = line[:left_pos]
|
||||
|
||||
if '├' in line and '┤' in line:
|
||||
# Separator row - rebuild with ┼
|
||||
parts = ['├']
|
||||
for i, w in enumerate(col_widths):
|
||||
parts.append('─' * w)
|
||||
if i < len(col_widths) - 1:
|
||||
parts.append('┼')
|
||||
parts.append('┤')
|
||||
fixed_lines.append(prefix + ''.join(parts))
|
||||
|
||||
elif '│' in line:
|
||||
# Content row - extract cells by splitting on │
|
||||
# Remove prefix, get content between first and last │
|
||||
first_bar = line.index('│')
|
||||
last_bar = line.rindex('│')
|
||||
inner = line[first_bar + 1:last_bar]
|
||||
|
||||
# Split on │ to get cells
|
||||
cells = inner.split('│')
|
||||
|
||||
# Pad/trim each cell to its column width
|
||||
fixed_cells = []
|
||||
for i, cell in enumerate(cells):
|
||||
if i < len(col_widths):
|
||||
w = col_widths[i]
|
||||
# Preserve leading space, trim/pad to width
|
||||
cell_stripped = cell.rstrip()
|
||||
if len(cell_stripped) > w:
|
||||
cell_stripped = cell_stripped[:w]
|
||||
fixed_cells.append(cell_stripped.ljust(w))
|
||||
else:
|
||||
fixed_cells.append(cell)
|
||||
|
||||
fixed_line = prefix + '│' + '│'.join(fixed_cells) + '│'
|
||||
fixed_lines.append(fixed_line)
|
||||
else:
|
||||
fixed_lines.append(line)
|
||||
|
||||
# Bottom line - rebuild with ┴
|
||||
bottom_line = table_lines[-1]
|
||||
if '└' in bottom_line and '┘' in bottom_line:
|
||||
prefix = bottom_line[:left_pos]
|
||||
parts = ['└']
|
||||
for i, w in enumerate(col_widths):
|
||||
parts.append('─' * w)
|
||||
if i < len(col_widths) - 1:
|
||||
parts.append('┴')
|
||||
parts.append('┘')
|
||||
fixed_lines.append(prefix + ''.join(parts))
|
||||
else:
|
||||
fixed_lines.append(bottom_line)
|
||||
|
||||
return fixed_lines
|
||||
|
||||
|
||||
def table_needs_fixing(table_lines: list[str]) -> tuple[bool, bool]:
|
||||
"""
|
||||
Check if table has misaligned columns.
|
||||
|
||||
Returns (needs_fixing, content_too_wide).
|
||||
content_too_wide means content is wider than header - needs manual fix.
|
||||
"""
|
||||
if len(table_lines) < 3:
|
||||
return False, False
|
||||
|
||||
top_line = table_lines[0]
|
||||
left_pos = top_line.index('┌')
|
||||
right_pos = top_line.index('┐')
|
||||
|
||||
content_too_wide = False
|
||||
misaligned = False
|
||||
|
||||
for line in table_lines[1:]:
|
||||
if '│' in line:
|
||||
last_bar = line.rindex('│')
|
||||
if last_bar != right_pos:
|
||||
misaligned = True
|
||||
if last_bar > right_pos:
|
||||
content_too_wide = True
|
||||
elif '┤' in line:
|
||||
last_corner = line.rindex('┤')
|
||||
if last_corner != right_pos:
|
||||
misaligned = True
|
||||
elif '┘' in line:
|
||||
last_corner = line.rindex('┘')
|
||||
if last_corner != right_pos:
|
||||
misaligned = True
|
||||
|
||||
return misaligned, content_too_wide
|
||||
|
||||
|
||||
def count_complex_boxes(content: str) -> int:
|
||||
"""Count boxes that are too complex to auto-fix."""
|
||||
lines = content.split('\n')
|
||||
@@ -195,6 +375,10 @@ def count_complex_boxes(content: str) -> int:
|
||||
line = lines[i]
|
||||
if '┌' in line and '┐' in line:
|
||||
first_corner = line.index('┌')
|
||||
# Skip tables (handled separately)
|
||||
if '┬' in line:
|
||||
i += 1
|
||||
continue
|
||||
# Nested box
|
||||
if '│' in line[:first_corner]:
|
||||
count += 1
|
||||
@@ -208,23 +392,26 @@ def count_complex_boxes(content: str) -> int:
|
||||
return count
|
||||
|
||||
|
||||
def process_file(filepath: Path, fix: bool = False, verbose: bool = False) -> tuple[bool, int, int]:
|
||||
"""Process a single file. Returns (has_changes, num_fixed, num_complex)."""
|
||||
def process_file(filepath: Path, fix: bool = False, verbose: bool = False) -> tuple[bool, int, int, int]:
|
||||
"""Process a single file. Returns (has_changes, num_boxes_fixed, num_tables_fixed, num_complex)."""
|
||||
try:
|
||||
content = filepath.read_text(encoding='utf-8')
|
||||
except Exception as e:
|
||||
return False, 0, 0
|
||||
return False, 0, 0, 0
|
||||
|
||||
boxes = find_simple_boxes(content)
|
||||
tables = find_tables(content)
|
||||
complex_count = count_complex_boxes(content)
|
||||
|
||||
if not boxes and complex_count == 0:
|
||||
return False, 0, 0
|
||||
if not boxes and not tables and complex_count == 0:
|
||||
return False, 0, 0, 0
|
||||
|
||||
lines = content.split('\n')
|
||||
original_content = content
|
||||
boxes_fixed = 0
|
||||
tables_fixed = 0
|
||||
|
||||
# Fix boxes
|
||||
for start_line, end_line, box_lines in reversed(boxes):
|
||||
if not needs_fixing(box_lines):
|
||||
continue
|
||||
@@ -249,13 +436,53 @@ def process_file(filepath: Path, fix: bool = False, verbose: bool = False) -> tu
|
||||
|
||||
lines[start_line:end_line + 1] = fixed_lines
|
||||
|
||||
# Rebuild content after box fixes
|
||||
content = '\n'.join(lines)
|
||||
lines = content.split('\n')
|
||||
|
||||
# Find tables again (line numbers may have shifted)
|
||||
tables = find_tables(content)
|
||||
|
||||
# Fix tables
|
||||
tables_skipped = 0
|
||||
for start_line, end_line, table_lines in reversed(tables):
|
||||
misaligned, content_too_wide = table_needs_fixing(table_lines)
|
||||
|
||||
if not misaligned:
|
||||
continue
|
||||
|
||||
if content_too_wide:
|
||||
# Content is wider than header - needs manual fix
|
||||
tables_skipped += 1
|
||||
if verbose:
|
||||
print(f"\n ⚠️ Table at lines {start_line + 1}-{end_line + 1} has content wider than header (manual fix needed)")
|
||||
continue
|
||||
|
||||
fixed_lines = fix_table_alignment(table_lines)
|
||||
|
||||
if fixed_lines != table_lines:
|
||||
tables_fixed += 1
|
||||
|
||||
if verbose:
|
||||
print(f"\n 📊 Table at lines {start_line + 1}-{end_line + 1}:")
|
||||
print(" Before:")
|
||||
for line in table_lines:
|
||||
print(f" {line}")
|
||||
print(" After:")
|
||||
for line in fixed_lines:
|
||||
print(f" {line}")
|
||||
|
||||
lines[start_line:end_line + 1] = fixed_lines
|
||||
|
||||
complex_count += tables_skipped # Count tables needing manual fix as complex
|
||||
|
||||
new_content = '\n'.join(lines)
|
||||
changes_made = new_content != original_content
|
||||
|
||||
if changes_made and fix:
|
||||
filepath.write_text(new_content, encoding='utf-8')
|
||||
|
||||
return changes_made, boxes_fixed, complex_count
|
||||
return changes_made, boxes_fixed, tables_fixed, complex_count
|
||||
|
||||
|
||||
def main():
|
||||
@@ -286,36 +513,47 @@ def main():
|
||||
|
||||
print(f"🔍 Scanning {len(py_files)} Python files...\n")
|
||||
|
||||
total_fixed = 0
|
||||
total_boxes = 0
|
||||
total_tables = 0
|
||||
total_complex = 0
|
||||
files_changed = 0
|
||||
|
||||
for filepath in sorted(py_files):
|
||||
has_changes, fixed, complex_count = process_file(filepath, fix=args.fix, verbose=args.verbose)
|
||||
total_fixed += fixed
|
||||
has_changes, boxes_fixed, tables_fixed, complex_count = process_file(filepath, fix=args.fix, verbose=args.verbose)
|
||||
total_boxes += boxes_fixed
|
||||
total_tables += tables_fixed
|
||||
total_complex += complex_count
|
||||
|
||||
total_fixed = boxes_fixed + tables_fixed
|
||||
if has_changes or complex_count > 0:
|
||||
if has_changes:
|
||||
files_changed += 1
|
||||
status = "✅ Fixed" if args.fix and fixed > 0 else "⚠️ Needs fixing" if fixed > 0 else ""
|
||||
status = "✅ Fixed" if args.fix and total_fixed > 0 else "⚠️ Needs fixing" if total_fixed > 0 else ""
|
||||
parts = []
|
||||
if fixed > 0:
|
||||
parts.append(f"{fixed} simple")
|
||||
if boxes_fixed > 0:
|
||||
parts.append(f"{boxes_fixed} box{'es' if boxes_fixed != 1 else ''}")
|
||||
if tables_fixed > 0:
|
||||
parts.append(f"{tables_fixed} table{'s' if tables_fixed != 1 else ''}")
|
||||
if complex_count > 0:
|
||||
parts.append(f"{complex_count} complex (manual)")
|
||||
parts.append(f"{complex_count} complex")
|
||||
if parts:
|
||||
print(f"{status}: {filepath} ({', '.join(parts)})" if status else f"📋 {filepath} ({', '.join(parts)})")
|
||||
|
||||
print()
|
||||
if total_fixed == 0 and total_complex == 0:
|
||||
print("✨ All ASCII boxes are properly aligned.")
|
||||
total_all = total_boxes + total_tables
|
||||
if total_all == 0 and total_complex == 0:
|
||||
print("✨ All ASCII boxes and tables are properly aligned.")
|
||||
else:
|
||||
if total_fixed > 0:
|
||||
if total_all > 0:
|
||||
if args.fix:
|
||||
print(f"✅ Fixed {total_fixed} simple box{'es' if total_fixed != 1 else ''} in {files_changed} file{'s' if files_changed != 1 else ''}.")
|
||||
parts = []
|
||||
if total_boxes > 0:
|
||||
parts.append(f"{total_boxes} box{'es' if total_boxes != 1 else ''}")
|
||||
if total_tables > 0:
|
||||
parts.append(f"{total_tables} table{'s' if total_tables != 1 else ''}")
|
||||
print(f"✅ Fixed {' and '.join(parts)} in {files_changed} file{'s' if files_changed != 1 else ''}.")
|
||||
else:
|
||||
print(f"⚠️ Found {total_fixed} misaligned simple box{'es' if total_fixed != 1 else ''}.")
|
||||
print(f"⚠️ Found {total_all} misaligned item{'s' if total_all != 1 else ''}.")
|
||||
print(" Run with --fix to apply corrections.")
|
||||
if total_complex > 0:
|
||||
print(f"📋 Found {total_complex} complex/nested box{'es' if total_complex != 1 else ''} (require manual review).")
|
||||
|
||||
Reference in New Issue
Block a user