# Mirror of https://github.com/MLSysBook/TinyTorch.git
# Synced 2026-06-02 21:10:57 -05:00 · 63 lines · 1.8 KiB · Python
# AUTOGENERATED! DO NOT EDIT! File to edit: ../../notebooks/01_setup.ipynb.

# %% auto 0
# Public names exposed by a star-import of this module.
__all__ = [
    'hello_tinytorch',
    'format_tensor_shape',
    'validate_tensor_shapes',
]

# %% ../../notebooks/01_setup.ipynb 4
def hello_tinytorch() -> str:
    """
    Return a greeting message for new TinyTorch users.

    This function serves as the "hello world" for the TinyTorch system.
    It introduces students to the nbdev export workflow.

    Returns:
        A fixed welcoming message string that includes:
        - Welcoming content
        - TinyTorch branding (🔥 emoji at both ends)
        - Encouraging message about building ML systems
    """
    return "🔥 Welcome to TinyTorch! Ready to build ML systems from scratch! 🔥"

# %% ../../notebooks/01_setup.ipynb 8
from typing import Any, List, Dict, Tuple
import numpy as np


def format_tensor_shape(shape: tuple) -> str:
    """
    Format a tensor shape tuple for pretty printing.

    Each dimension is converted with ``str`` and the results are joined with
    ``", "`` inside a single pair of parentheses. No trailing comma is added,
    so a one-dimensional shape renders as ``(3)`` rather than Python's tuple
    repr ``(3,)``, and an empty shape renders as ``()``.

    Args:
        shape: Tuple representing tensor dimensions

    Returns:
        Formatted string representation of the shape

    Example:
        >>> format_tensor_shape((3, 4, 5))
        '(3, 4, 5)'
        >>> format_tensor_shape((3,))
        '(3)'
        >>> format_tensor_shape(())
        '()'
    """
    dims = ", ".join(str(dim) for dim in shape)
    return f"({dims})"

# %% ../../notebooks/01_setup.ipynb 9
def validate_tensor_shapes(*shapes: tuple) -> bool:
    """
    Validate that tensor shapes are compatible for operations.

    "Compatible" here means exactly equal: every shape after the first is
    compared to the first with ``==``. Broadcasting rules are not considered.

    Args:
        *shapes: Variable number of shape tuples to validate

    Returns:
        True if all shapes equal the first one (or if fewer than two shapes
        are given), False otherwise

    Example:
        >>> validate_tensor_shapes((3, 4), (3, 4))
        True
        >>> validate_tensor_shapes((3, 4), (2, 4))
        False
        >>> validate_tensor_shapes((3, 4))
        True
    """
    # Zero or one shape has nothing to disagree with.
    if len(shapes) < 2:
        return True

    reference, *others = shapes
    return all(candidate == reference for candidate in others)