mirror of
https://github.com/MLSysBook/TinyTorch.git
synced 2026-04-29 06:37:58 -05:00
Introduces the foundational CLI structure and core components for the TinyTorch project. This initial commit establishes the command-line interface (CLI) using `argparse` for training, evaluation, benchmarking, and system information. It also lays out the basic directory structure and essential modules, including tensor operations, autograd, neural network layers, optimizers, data loading, and MLOps components.
84 lines
1.7 KiB
YAML
84 lines
1.7 KiB
YAML
# TinyTorch Default Configuration
# This file contains default settings for training neural networks

# Model configuration
model:
  type: "mlp"  # mlp, cnn
  input_size: 784  # for MNIST (28*28)
  hidden_sizes: [128, 64]
  output_size: 10
  activation: "relu"
  dropout: 0.0

  # CNN-specific settings (used when model.type == "cnn")
  # NOTE(review): this copy of the file lost its indentation; `cnn` is placed
  # under `model` because the comment above refers to model.type. Confirm the
  # nesting against the code that loads this config.
  cnn:
    channels: [32, 64]
    kernel_sizes: [3, 3]
    pool_sizes: [2, 2]
    fc_hidden: 128

# Dataset configuration
dataset:
  name: "mnist"  # mnist, cifar10
  batch_size: 64
  shuffle: true
  num_workers: 0
  download: true
  data_dir: "./data"

  # Data augmentation (for training)
  # NOTE(review): nesting under `dataset` is reconstructed (indentation was
  # lost in this copy) - confirm against the config loader.
  augmentation:
    enabled: false
    transforms:
      - "random_crop"
      - "random_flip"

# Training configuration
# NOTE(review): indentation was lost in this copy. The lr_scheduler,
# validation and checkpointing sections are placed under `training` because
# their settings are expressed in epochs / training data; confirm the nesting
# against the code that loads this config.
training:
  epochs: 10
  learning_rate: 0.001
  optimizer: "adam"  # sgd, adam
  weight_decay: 0.0
  momentum: 0.9  # for SGD

  # Learning rate scheduling
  lr_scheduler:
    enabled: false
    type: "step"  # step, cosine
    step_size: 3
    gamma: 0.1

  # Validation and testing
  validation:
    enabled: true
    split: 0.1  # fraction of training data for validation
    frequency: 1  # validate every N epochs

  # Checkpointing
  checkpointing:
    enabled: true
    save_best: true
    save_frequency: 5  # save every N epochs
    max_checkpoints: 3  # keep only N latest checkpoints

# Logging and monitoring
logging:
  level: "INFO"  # DEBUG, INFO, WARNING, ERROR
  log_frequency: 100  # log every N batches
  # Names of the metrics to record.
  metrics:
    - "loss"
    - "accuracy"
    - "learning_rate"

# Profiling
# `enabled` is off by default; `memory` and `compute` presumably select what
# is profiled once it is switched on - verify against the profiler code.
profiling:
  enabled: false
  memory: true
  compute: true

# System configuration
system:
  device: "auto"  # auto, cpu, cuda
  mixed_precision: false
  seed: 42
  deterministic: false