Skip to content
Open
380 changes: 380 additions & 0 deletions config_files/training/config_lorem_ipsum_long_fsdp2_cp.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,380 @@
settings:
experiment_id: ${modalities_env:experiment_id}
config_file_path: ${modalities_env:config_file_path}
referencing_keys:
sample_key: input_ids
target_key: target_ids
prediction_key: logits
cuda_env:
local_rank: ${cuda_env:LOCAL_RANK}
global_rank: ${cuda_env:RANK}
world_size: ${cuda_env:WORLD_SIZE}
paths:
checkpoint_saving_path: data/checkpoints
train_dataset_path: ./data/lorem_ipsum_long.pbin
test_dataset_path: ./data/lorem_ipsum.pbin
experiments_root_path: ${modalities_env:experiments_root_path}
intervals:
training_log_interval_in_steps: 1
checkpointing_interval_in_steps: 32
evaluation_interval_in_steps: 32
consistency_enforcement:
enforce_tokens_per_step_consistency: false
enforce_last_step_logged: false
enforce_last_step_evaluated: false
enforce_last_step_checkpointed: false
step_profile:
gradient_accumulation_steps: 1
local_train_micro_batch_size: 1
sequence_length: 256
dp_degree:
instance_key: dp_degree
pass_type: BY_REFERENCE
training_target:
num_target_tokens:
component_key: number_conversion
variant_key: num_tokens_from_packed_mem_map_dataset_continuous
config:
dataset_path: ${settings.paths.train_dataset_path}
sequence_length: ${settings.step_profile.sequence_length}
dp_degree:
instance_key: dp_degree
pass_type: BY_REFERENCE
local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size}
gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps}
num_target_steps: # for the batch progress subscriber
component_key: number_conversion
variant_key: num_steps_from_num_tokens
config:
dp_degree:
instance_key: dp_degree
pass_type: BY_REFERENCE
local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size}
global_num_tokens: ${settings.training_target.num_target_tokens}
sequence_length: ${settings.step_profile.sequence_length}
gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps}
training_progress:
global_num_seen_tokens: 0
num_seen_steps: 0
num_seen_samples: 0
last_step: -1

collate_fn:
component_key: collate_fn
variant_key: gpt_2_llm_collator
config:
sample_key: ${settings.referencing_keys.sample_key}
target_key: ${settings.referencing_keys.target_key}

train_dataset:
component_key: dataset
variant_key: packed_mem_map_dataset_continuous
config:
raw_data_path: ${settings.paths.train_dataset_path}
sequence_length: ${settings.step_profile.sequence_length}
sample_key: ${settings.referencing_keys.sample_key}

train_dataloader:
component_key: data_loader
variant_key: default
config:
num_workers: 2
pin_memory: true
dataloader_tag: train
dataset:
instance_key: train_dataset
pass_type: BY_REFERENCE
batch_sampler:
component_key: batch_sampler
variant_key: default
config:
batch_size: ${settings.step_profile.local_train_micro_batch_size}
drop_last: true
sampler:
component_key: sampler
variant_key: resumable_distributed_multi_dim_sampler
config:
dataset:
instance_key: train_dataset
pass_type: BY_REFERENCE
device_mesh:
instance_key: device_mesh
pass_type: BY_REFERENCE
data_parallel_key: dp_shard
shuffle: true
seed: 42
drop_last: true
skip_num_global_samples: ${settings.training_progress.num_seen_samples}
collate_fn:
instance_key: collate_fn
pass_type: BY_REFERENCE

test_dataset:
component_key: dataset
variant_key: packed_mem_map_dataset_continuous
config:
raw_data_path: ${settings.paths.test_dataset_path}
sequence_length: ${settings.step_profile.sequence_length}
sample_key: ${settings.referencing_keys.sample_key}

test_dataloader:
component_key: data_loader
variant_key: default
config:
num_workers: 2
pin_memory: true
dataloader_tag: test
dataset:
instance_key: test_dataset
pass_type: BY_REFERENCE
batch_sampler:
component_key: batch_sampler
variant_key: default
config:
batch_size: ${settings.step_profile.local_train_micro_batch_size}
drop_last: true
sampler:
component_key: sampler
variant_key: resumable_distributed_multi_dim_sampler
config:
dataset:
instance_key: test_dataset
pass_type: BY_REFERENCE
device_mesh:
instance_key: device_mesh
pass_type: BY_REFERENCE
data_parallel_key: dp_shard
shuffle: false
seed: 42
drop_last: true
collate_fn:
instance_key: collate_fn
pass_type: BY_REFERENCE

eval_dataloaders:
- instance_key: test_dataloader
pass_type: BY_REFERENCE

checkpoint_saving:
component_key: checkpoint_saving
variant_key: default
config:
checkpoint_saving_strategy:
component_key: checkpoint_saving_strategy
variant_key: save_k_most_recent_checkpoints_strategy
config:
k: -1 # -1 to save all checkpoints
checkpoint_saving_execution:
component_key: checkpoint_saving_execution
variant_key: dcp
config:
checkpoint_path: ${settings.paths.checkpoint_saving_path}
global_rank: ${settings.cuda_env.global_rank}
experiment_id: ${settings.experiment_id}

loss_fn:
component_key: loss
variant_key: clm_cross_entropy_loss
config:
target_key: ${settings.referencing_keys.target_key}
prediction_key: ${settings.referencing_keys.prediction_key}

device_mesh:
component_key: device_mesh
variant_key: default
config:
device_type: cuda
data_parallel_replicate_degree: 1
data_parallel_shard_degree: -1
context_parallel_degree: 2
world_size: ${settings.cuda_env.world_size}

dp_degree:
component_key: number_conversion
variant_key: parallel_degree
config: # get the parallel degree from the device mesh
device_mesh:
instance_key: device_mesh
pass_type: BY_REFERENCE
parallelism_methods: [dp_shard, dp_replicate]

app_state:
component_key: app_state
variant_key: raw
config:
model:
instance_key: initialized_model
pass_type: BY_REFERENCE
optimizer:
instance_key: optimizer
pass_type: BY_REFERENCE
lr_scheduler:
instance_key: lr_scheduler
pass_type: BY_REFERENCE

initialized_model:
component_key: model
variant_key: model_initialized
config:
model:
instance_key: fsdp_model
pass_type: BY_REFERENCE
model_initializer:
component_key: model_initialization
variant_key: composed
config:
model_type: gpt2
weight_init_type: scaled
mean: 0.0
std: 0.02
num_layers: ${model_raw.config.n_layer}
multi_device_generator_policy: error

fsdp_model:
component_key: model
variant_key: fsdp2_wrapped
config:
model:
instance_key: gpt2_cp_model
pass_type: BY_REFERENCE
device_mesh:
instance_key: device_mesh
pass_type: BY_REFERENCE
mixed_precision_settings:
param_dtype: BF_16
reduce_dtype: BF_16
block_names: [GPT2Block]

gpt2_cp_model:
component_key: model
variant_key: gpt2_cp
config:
model:
instance_key: model_raw
pass_type: BY_REFERENCE
device_mesh:
instance_key: device_mesh
pass_type: BY_REFERENCE

model_raw:
component_key: model
variant_key: gpt2
config:
use_meta_device: true
use_weight_tying: false
sample_key: ${settings.referencing_keys.sample_key}
poe_type: NOPE
sequence_length: ${settings.step_profile.sequence_length}
prediction_key: ${loss_fn.config.prediction_key}
vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
n_layer: 2
n_head_q: 8
n_head_kv: 4
ffn_hidden: 128
n_embd: 128
dropout: 0.0
bias: true # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
attention_config:
qkv_transforms:
- type_hint: RotaryTransform
config:
n_embd: ${model_raw.config.n_embd}
n_head: ${model_raw.config.n_head_q} #it has to be head_q here
seq_length_dim: -2
base_freq: 10000
attention_implementation: pytorch_flash
activation_type: swiglu
attention_norm_config:
norm_type: layer_norm
config:
normalized_shape: ${model_raw.config.n_embd}
eps: 1e-5
ffn_norm_config:
norm_type: layer_norm
config:
normalized_shape: ${model_raw.config.n_embd}
eps: 1e-5
lm_head_norm_config:
norm_type: layer_norm
config:
normalized_shape: ${model_raw.config.n_embd}
eps: 1e-5

lr_scheduler:
component_key: scheduler
variant_key: onecycle_lr
config:
optimizer:
instance_key: optimizer
pass_type: BY_REFERENCE
max_lr: 6e-4
div_factor: 10
final_div_factor: 1
total_steps: ${settings.training_target.num_target_steps}
pct_start: 0.01
anneal_strategy: cos
last_epoch: ${settings.training_progress.last_step}

optimizer:
component_key: optimizer
variant_key: adam_w
config:
lr: 0.0001
betas: [0.9, 0.95]
eps: 1e-8
weight_decay: 1e-1
weight_decay_groups_excluded: [embedding, layernorm]
wrapped_model:
instance_key: initialized_model
pass_type: BY_REFERENCE

gradient_clipper:
component_key: gradient_clipper
variant_key: fsdp2
config:
wrapped_model:
instance_key: initialized_model
pass_type: BY_REFERENCE
norm_type: P2_NORM
max_norm: 1.0
device_mesh:
instance_key: device_mesh
pass_type: BY_REFERENCE

progress_subscriber:
component_key: progress_subscriber
variant_key: rich
config:
global_rank: ${settings.cuda_env.global_rank}
num_seen_steps: ${settings.training_progress.num_seen_steps}
num_target_steps: ${settings.training_target.num_target_steps}
train_dataloader_tag: ${train_dataloader.config.dataloader_tag}
eval_dataloaders:
instance_key: eval_dataloaders
pass_type: BY_REFERENCE

evaluation_subscriber:
component_key: results_subscriber
variant_key: wandb
config:
global_rank: ${settings.cuda_env.global_rank}
project: modalities_dcp_tests
mode: OFFLINE
experiment_id: ${settings.experiment_id}
directory: wandb_storage
config_file_path: ${settings.config_file_path}

mfu_calculator:
component_key: mfu_calculator
variant_key: gpt2
config:
n_layer: ${model_raw.config.n_layer}
sequence_length: ${settings.step_profile.sequence_length}
n_embd: ${model_raw.config.n_embd}
world_size: ${settings.cuda_env.world_size}
wrapped_model:
instance_key: initialized_model
pass_type: BY_REFERENCE
device_mesh:
instance_key: device_mesh
pass_type: BY_REFERENCE
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ settings:
global_rank: ${cuda_env:RANK}
world_size: ${cuda_env:WORLD_SIZE}
paths:
experiments_root_path: ${modalities_env:experiments_root_path}
checkpoint_saving_path: data/checkpoints
train_dataset_path: ./data/lorem_ipsum_long.pbin
test_dataset_path: ./data/lorem_ipsum.pbin
Expand Down
Loading
Loading