Files
TinyTorch/tinytorch/core/utils.py
2025-07-10 11:13:45 -04:00

63 lines
1.8 KiB
Python

# AUTOGENERATED! DO NOT EDIT! File to edit: ../../notebooks/01_setup.ipynb.
# %% auto 0
# Names exported when this module is star-imported (`from <module> import *`).
__all__ = ['hello_tinytorch', 'format_tensor_shape', 'validate_tensor_shapes']
# %% ../../notebooks/01_setup.ipynb 4
def hello_tinytorch() -> str:
    """
    Return a greeting message for new TinyTorch users.

    This function serves as the "hello world" for the TinyTorch system.
    It introduces students to the nbdev export workflow.

    Returns:
        A welcoming message string that includes:
        - Welcoming content
        - TinyTorch branding (🔥 emoji)
        - Encouraging message about building ML systems
    """
    return "🔥 Welcome to TinyTorch! Ready to build ML systems from scratch! 🔥"
# %% ../../notebooks/01_setup.ipynb 8
from typing import Any, List, Dict, Tuple
import numpy as np
def format_tensor_shape(shape: tuple) -> str:
    """
    Format a tensor shape tuple for pretty printing.

    Each dimension is converted with ``str`` and the pieces are joined with
    ``", "`` inside a single pair of parentheses.

    Args:
        shape: Tuple representing tensor dimensions

    Returns:
        Formatted string representation of the shape. An empty shape gives
        ``'()'``; a one-element shape gives e.g. ``'(3)'`` (no trailing
        comma, unlike ``repr`` of a one-element tuple).

    Example:
        >>> format_tensor_shape((3, 4, 5))
        '(3, 4, 5)'
    """
    return f"({', '.join(map(str, shape))})"
# %% ../../notebooks/01_setup.ipynb 9
def validate_tensor_shapes(*shapes: tuple) -> bool:
    """
    Validate that tensor shapes are compatible for operations.

    "Compatible" here means every shape compares equal (``==``) to the
    first one; no broadcasting rules are applied.

    Args:
        *shapes: Variable number of shape tuples to validate

    Returns:
        True if shapes are compatible, False otherwise. With zero or one
        shape there is nothing to compare, so the result is True.

    Example:
        >>> validate_tensor_shapes((3, 4), (3, 4))
        True
        >>> validate_tensor_shapes((3, 4), (2, 4))
        False
    """
    # Fewer than two shapes: trivially compatible (also avoids shapes[0] on
    # an empty argument list).
    if len(shapes) < 2:
        return True

    first_shape = shapes[0]
    return all(shape == first_shape for shape in shapes[1:])