A modular, extensible template for PyTorch-based deep learning research. Define your entire experiment — model, optimizer, scheduler, loss function, callbacks — in a single YAML file and run it with one command.
- YAML-Driven Configuration — All experiment settings managed in YAML. Frozen, validated configs prevent silent misconfiguration.
- Callback-Based Training — Extensible training loop with priority-ordered callbacks. Add behaviors (logging, checkpointing, early stopping) without modifying core code.
- Configurable Loss & Metrics — Swap loss functions via YAML (`torch.nn.CrossEntropyLoss`, custom losses). Built-in metric registry (MSE, MAE, R2) with importlib extension.
- Checkpoint & Resume — Full state checkpointing (model, optimizer, scheduler, RNG states) with multi-seed resume via `SeedManifest`.
- Run Provenance — Automatic capture of Python/PyTorch/CUDA versions, GPU info, git commit, and environment variables per run.
- Hyperparameter Optimization — Optuna integration with custom PFL pruner and deep-merge config overrides.
- Experiment Tracking — Seamless Weights & Biases logging via callback.
- CLI — `typer`-based CLI with `train`, `validate`, `preview`, `doctor`, and `analyze` subcommands.
- Reproducibility — Deterministic seed management across Python, NumPy, and PyTorch.
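The reproducibility guarantee comes down to seeding every RNG the run touches before training starts. A stdlib-only sketch (the template additionally seeds NumPy and PyTorch; its actual helper may differ):

```python
import random

def set_seed(seed: int) -> None:
    # Seed Python's stdlib RNG. The full template would also call
    # np.random.seed(seed) and torch.manual_seed(seed) here.
    random.seed(seed)

set_seed(42)
first = random.random()
set_seed(42)
assert random.random() == first  # identical seed reproduces the draw
```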
The training loop emits events at defined hook points. Each concern (logging, early stopping, checkpointing) is an independent, priority-ordered callback:
| Callback | Priority | Hook | Purpose |
|---|---|---|---|
| `NaNDetectionCallback` | 5 | `on_epoch_end` | Detect NaN loss, signal stop |
| `OptimizerModeCallback` | 10 | `on_train_epoch_begin`, `on_val_begin` | SPlus/ScheduleFree train/eval mode |
| `LossPredictionCallback` | 70 | `on_val_end` | Predict final loss for early pruning |
| `WandbLoggingCallback` | 80 | `on_epoch_end` | Log metrics to W&B |
| `PrunerCallback` | 85 | `on_val_end` | Report to Optuna pruner |
| `EarlyStoppingCallback` | 90 | `on_val_end` | Monitor metric, signal stop |
| `CheckpointCallback` | 95 | `on_epoch_end` | Save periodic/best checkpoints |
Adding custom behavior is as simple as subclassing TrainingCallback and adding it to the callback list — zero changes to the training loop.
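A minimal sketch of how such priority-ordered dispatch can work (hook names and stop-signaling in the template's actual `CallbackRunner` may differ):

```python
class TrainingCallback:
    """Base class: subclasses override hooks and set a priority (lower runs first)."""
    priority = 50

    def on_epoch_end(self, trainer, **kwargs):
        pass

class CallbackRunner:
    def __init__(self, callbacks):
        # Stable sort: equal priorities keep their registration order.
        self.callbacks = sorted(callbacks, key=lambda cb: cb.priority)

    def fire(self, hook, trainer, **kwargs):
        # Invoke the hook on every callback, in priority order.
        for cb in self.callbacks:
            getattr(cb, hook)(trainer, **kwargs)

class Announce(TrainingCallback):
    def __init__(self, name, priority):
        self.name, self.priority = name, priority

    def on_epoch_end(self, trainer, **kwargs):
        trainer.append(self.name)  # trainer stands in as a simple event log here

log = []
runner = CallbackRunner([Announce("checkpoint", 95), Announce("nan_check", 5)])
runner.fire("on_epoch_end", log)
print(log)  # ['nan_check', 'checkpoint'] (priority 5 ran first)
```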
- Clone:

  ```sh
  git clone https://github.com/<your-username>/<your-repo>.git
  cd <your-repo>
  ```

- Install dependencies (uv recommended):

  ```sh
  uv venv && source .venv/bin/activate
  uv pip install -U torch wandb rich beaupy numpy optuna matplotlib scienceplots typer tqdm pyyaml pytorch-optimizer pytorch-scheduler
  ```

- Validate your setup:

  ```sh
  python -m cli doctor
  ```

- Preview a config (no training, just inspect):

  ```sh
  python -m cli preview configs/run_template.yaml
  ```

- Train:

  ```sh
  python -m cli train configs/run_template.yaml
  # Or with a device override:
  python -m cli train configs/run_template.yaml --device cpu
  ```

- Hyperparameter optimization:

  ```sh
  python -m cli train configs/run_template.yaml --optimize-config configs/optimize_template.yaml
  ```

- Analyze results:

  ```sh
  python -m cli analyze
  # Or non-interactive:
  python -m cli analyze --project MyProject --group MyGroup --seed 42
  ```

Legacy CLI: `python main.py --run_config configs/run_template.yaml` still works for backward compatibility.
```
pytorch_template/
├── cli.py           # Typer CLI entrypoint (train, validate, preview, doctor, analyze)
├── main.py          # Legacy argparse entrypoint
├── config.py        # RunConfig (frozen, validated) + OptimizeConfig
├── util.py          # Trainer, run(), data loading, analysis helpers
├── callbacks.py     # Callback system (8 built-in callbacks + CallbackRunner)
├── metrics.py       # Metric registry (MSE, MAE, R2 + importlib extension)
├── checkpoint.py    # CheckpointManager + SeedManifest
├── provenance.py    # Environment capture + config hashing
├── model.py         # Model architectures (MLP)
├── pruner.py        # PFL pruner for Optuna
├── configs/
│   ├── run_template.yaml
│   └── optimize_template.yaml
├── recipes/
│   ├── regression/      # Sine wave regression (MLP + MSELoss)
│   └── classification/  # FashionMNIST classification (CNN + CrossEntropyLoss)
├── tests/           # 36 unit tests
└── runs/            # Experiment outputs (auto-created)
```
```yaml
project: PyTorch_Template
device: cuda:0
net: model.MLP
optimizer: pytorch_optimizer.SPlus
scheduler: pytorch_scheduler.ExpHyperbolicLRScheduler
criterion: torch.nn.MSELoss  # Any loss function via importlib
criterion_config: {}         # Arguments for criterion constructor
epochs: 50
batch_size: 256
seeds: [89, 231, 928, 814, 269]
net_config:
  nodes: 64
  layers: 4
optimizer_config:
  lr: 1.e-3
  eps: 1.e-10
scheduler_config:
  total_steps: 50
  upper_bound: 250
  min_lr: 1.e-5
early_stopping_config:
  enabled: false
  patience: 10
  mode: min
  min_delta: 0.0001
checkpoint_config:
  enabled: false
  save_every_n_epochs: 10
  keep_last_k: 3
  save_best: true
  monitor: val_loss
  mode: min
```

Key fields:
| Field | Description |
|---|---|
| `net` | Model class path in `module.Class` format |
| `optimizer` | Optimizer class path (supports `torch.optim.*`, `pytorch_optimizer.*`, custom) |
| `scheduler` | Scheduler class path (supports `torch.optim.lr_scheduler.*`, `pytorch_scheduler.*`, custom) |
| `criterion` | Loss function class path (e.g., `torch.nn.MSELoss`, `torch.nn.CrossEntropyLoss`) |
| `criterion_config` | Arguments passed to the criterion constructor |
| `seeds` | List of random seeds — each seed is a separate training run |
| `checkpoint_config` | Periodic/best checkpoint saving with configurable policy |
All module paths are resolved via `importlib` at runtime. The config is frozen after construction — use `config.with_overrides(field=value)` to create modified copies.
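The resolution step itself can be sketched in a few lines (the template's actual helper and error handling may differ; `collections.OrderedDict` stands in for a class path like `torch.nn.MSELoss` so the snippet runs without PyTorch installed):

```python
import importlib

def resolve(path: str):
    # Split "module.Class" on the last dot, import the module, grab the attribute.
    module_name, _, attr = path.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)

# In the template this would be e.g. resolve("torch.nn.MSELoss")().
cls = resolve("collections.OrderedDict")
print(cls.__name__)  # OrderedDict
```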
See `configs/optimize_template.yaml` for the full template. Key sections: `search_space`, `sampler`, `pruner`.
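The deep-merge behavior used for HPO overrides (nested dicts merged key-by-key, scalars replaced) can be sketched as follows; this is an illustrative stand-in, not the template's exact helper:

```python
def deep_merge(base: dict, override: dict) -> dict:
    # Recursively merge: nested dicts are merged, everything else is replaced.
    out = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(out.get(key), dict):
            out[key] = deep_merge(out[key], value)
        else:
            out[key] = value
    return out

base = {"optimizer_config": {"lr": 1e-3, "eps": 1e-10}, "epochs": 50}
trial = {"optimizer_config": {"lr": 3e-4}}
print(deep_merge(base, trial))
# {'optimizer_config': {'lr': 0.0003, 'eps': 1e-10}, 'epochs': 50}
```

Note that a plain `dict.update` would replace `optimizer_config` wholesale and lose `eps`; the recursion is what keeps untouched nested keys intact.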
Create a model class in `model.py` or a new file. The constructor must accept `(hparams: dict, device: str)`:

```python
# my_model.py
import torch.nn as nn

class MyTransformer(nn.Module):
    def __init__(self, hparams, device="cpu"):
        super().__init__()
        # hparams comes from net_config in YAML
        ...
```

Then point the config at it:

```yaml
net: my_model.MyTransformer
net_config:
  d_model: 256
  nhead: 8
```

Subclass `TrainingCallback` and override hook methods:
```python
# my_callbacks.py
import torch
from callbacks import TrainingCallback

class GradientClipCallback(TrainingCallback):
    priority = 15  # Run early, after OptimizerMode

    def __init__(self, max_norm=1.0):
        self.max_norm = max_norm

    def on_train_step_end(self, trainer, batch_idx, loss, **kwargs):
        torch.nn.utils.clip_grad_norm_(trainer.model.parameters(), self.max_norm)
```

Then add it to the callback list in your training script or extend `run()`.
Register built-in names or importlib paths:

```python
from metrics import MetricRegistry

registry = MetricRegistry(["mse", "mae", "r2", "my_module.MyCustomMetric"])
results = registry.compute(y_pred, y_true)
# {"mse": 0.012, "mae": 0.089, "r2": 0.95, "my_custom_metric": ...}
```

Change one line in YAML — no code changes:
```yaml
# Regression
criterion: torch.nn.MSELoss

# Classification
criterion: torch.nn.CrossEntropyLoss

# Custom
criterion: my_losses.FocalLoss
criterion_config:
  gamma: 2.0
  alpha: 0.25
```

```yaml
# Built-in PyTorch
scheduler: torch.optim.lr_scheduler.CosineAnnealingLR
scheduler_config:
  T_max: 50
  eta_min: 1.e-5

# ExpHyperbolicLR (via pytorch-scheduler)
scheduler: pytorch_scheduler.ExpHyperbolicLRScheduler
scheduler_config:
  total_steps: 50
  upper_bound: 250
  min_lr: 1.e-5
```

Modify `load_data()` in `util.py` to return your `(train_dataset, val_dataset)`. See `recipes/` for examples (regression + classification).
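A minimal `load_data()` shape, loosely modeled on the sine-wave regression recipe (tensor shapes and sample counts here are illustrative, not the template's exact code):

```python
import torch
from torch.utils.data import TensorDataset

def load_data(n_train=8000, n_val=2000):
    # Sample x uniformly on [0, 2*pi) and target y = sin(x).
    # Any pair of (train, val) torch Datasets works here.
    def make(n):
        x = torch.rand(n, 1) * 2 * torch.pi
        return TensorDataset(x, torch.sin(x))
    return make(n_train), make(n_val)

train_ds, val_ds = load_data()
print(len(train_ds), len(val_ds))  # 8000 2000
```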
| Recipe | Task | Model | Loss | Config |
|---|---|---|---|---|
| `recipes/regression/` | Sine wave fitting | MLP (64 nodes, 4 layers) | MSELoss | `config.yaml` |
| `recipes/classification/` | FashionMNIST | SimpleCNN (32 channels) | CrossEntropyLoss | `config.yaml` |

```sh
python -m cli train recipes/regression/config.yaml --device cpu
```

| Command | Description |
|---|---|
| `python -m cli train <config> [--device] [--optimize-config]` | Train model(s) with optional HPO |
| `python -m cli validate <config>` | Validate config without training |
| `python -m cli preview <config>` | Show model architecture and config summary |
| `python -m cli doctor` | Check Python, PyTorch, CUDA, W&B, packages |
| `python -m cli analyze [--project] [--group] [--seed] [--device]` | Analyze trained models |
For a deeper dive into components and customization:
- Project Documentation (Generated by Tutorial-Codebase-Knowledge)
Contributions are welcome! Please feel free to submit a Pull Request.
This project is licensed under the MIT License.
- pytorch-optimizer for optimizers including SPlus.
- pytorch-scheduler for learning rate schedulers including ExpHyperbolicLR.
PFL (Predicted Final Loss) Pruner
The PFL pruner (pruner.PFLPruner) predicts the final performance of a training run based on early-stage metrics, pruning unpromising Optuna trials early.
```yaml
pruner:
  name: pruner.PFLPruner
  kwargs:
    n_startup_trials: 10
    n_warmup_epochs: 10
    top_k: 10
    target_epoch: 50
```

- The first `n_startup_trials` trials run to completion to establish baseline performance.
- For subsequent trials, pruning is considered only after `n_warmup_epochs`.
- The pruner predicts final loss from the current loss history using exponential curve fitting.
- If the predicted final loss is worse than the top-k completed trials, the trial is pruned.
- Supports multi-seed runs by averaging metrics across seeds.
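The curve-fitting idea behind the prediction can be sketched with a plain log-linear least-squares fit (a simplification; `pruner.PFLPruner` may use a richer exponential model alongside the top-k comparison described above):

```python
import math

def predict_final_loss(history, target_epoch):
    # Fit log(loss) = a + b * epoch by least squares, then extrapolate.
    # Equivalent to fitting loss = exp(a) * exp(b * epoch).
    xs = list(range(len(history)))
    ys = [math.log(loss) for loss in history]
    mean_x = sum(xs) / len(xs)
    mean_y = sum(ys) / len(ys)
    b = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys)) \
        / sum((x - mean_x) ** 2 for x in xs)
    a = mean_y - b * mean_x
    return math.exp(a + b * target_epoch)

# A loss that halves every epoch is predicted to keep halving:
print(predict_final_loss([1.0, 0.5, 0.25, 0.125], 10))  # ~0.0009765625 (= 0.5 ** 10)
```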
