Models

| Component type | Component Version | Implementation | Configuration | Component Interface | Description |
|---|---|---|---|---|---|
| model | gpt2 | GPT2LLM | GPT2LLMConfig | NNModel | GPT2 model for language modeling |
| model | huggingface_pretrained_model | HuggingFacePretrainedModel | HuggingFacePretrainedModelConfig | NNModel | Hugging Face pretrained model for language modeling |
| model | checkpointed | ModelFactory.get_checkpointed_model | CheckpointedModelConfig | nn.Module | Model instance loaded from a checkpoint |
| model | fsdp_wrapped | ModelFactory.get_fsdp_wrapped_model | FSDPWrappedModelConfig | NNModel | Model that has been sharded via FSDP |
| model | model_initialized | ModelFactory.get_weight_initalized_model | WeightInitializedModelConfig | nn.Module | Model with initialized weights |
| model | coca | CoCa | CoCaConfig | NNModel | CoCa model (Contrastive Captioners) |

Weight Initialization

| Component type | Component Version | Implementation | Configuration | Component Interface | Description |
|---|---|---|---|---|---|
| model_initialization | composed | ComposedInitializationRoutines.get_composed_model_initializer | ComposedModelInitializationConfig | ModelInitializationIF | Component for initializing model weights in place |


Losses

| Component type | Component Version | Implementation | Configuration | Component Interface | Description |
|---|---|---|---|---|---|
| loss | clm_cross_entropy_loss | CLMCrossEntropyLoss | CLMCrossEntropyLossConfig | Loss | Cross-entropy loss for causal language modeling (CLM) |
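As a rough illustration of what a causal language modeling (CLM) cross-entropy loss computes, the pure-Python sketch below averages the negative log-probability that the model assigns to each correct next token. It is not the CLMCrossEntropyLoss implementation: the function name `clm_cross_entropy` is hypothetical, and the sketch assumes the logits are already aligned with targets that are the inputs shifted by one position.

```python
import math
from typing import List


def clm_cross_entropy(logits: List[List[float]], targets: List[int]) -> float:
    """Mean next-token cross-entropy over one sequence (hypothetical helper).

    logits[t] holds the unnormalized scores over the vocabulary at position t,
    and targets[t] is the id of the token that should follow position t.
    """
    assert len(logits) == len(targets)
    total = 0.0
    for scores, target in zip(logits, targets):
        # Numerically stable log-sum-exp: subtract the max score first.
        m = max(scores)
        log_z = m + math.log(sum(math.exp(s - m) for s in scores))
        # Negative log-probability of the correct next token.
        total += log_z - scores[target]
    return total / len(targets)


# Uniform scores over a 4-token vocabulary give a loss of log(4) ≈ 1.386.
print(clm_cross_entropy([[0.0] * 4, [0.0] * 4], [1, 3]))
```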


Optimizers

| Component type | Component Version | Implementation | Configuration | Component Interface | Description |
|---|---|---|---|---|---|
| optimizer | adam | OptimizerFactory.get_adam | AdamOptimizerConfig | Optimizer | Adam optimizer |
| optimizer | adam_w | OptimizerFactory.get_adam_w | AdamWOptimizerConfig | Optimizer | AdamW optimizer |
| optimizer | checkpointed | OptimizerFactory.get_checkpointed_optimizer | CheckpointedOptimizerConfig | Optimizer | Optimizer instantiated from a checkpoint |
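The adam and adam_w entries differ mainly in how weight decay is applied. The scalar sketch below, built on the hypothetical helpers `AdamState` and `adam_step`, follows the update rules described in the PyTorch Adam and AdamW documentation: Adam adds the decay term to the gradient (L2 regularization), while AdamW shrinks the parameter directly before the moment-based update. It is an illustration only, not the OptimizerFactory code.

```python
import math
from dataclasses import dataclass


@dataclass
class AdamState:
    m: float = 0.0  # first-moment (mean) estimate of the gradient
    v: float = 0.0  # second-moment (uncentered variance) estimate
    t: int = 0      # number of updates taken so far


def adam_step(param: float, grad: float, state: AdamState, lr: float = 1e-3,
              betas=(0.9, 0.999), eps: float = 1e-8,
              weight_decay: float = 0.0, decoupled: bool = False) -> float:
    """One Adam (decoupled=False) or AdamW (decoupled=True) update of a scalar."""
    beta1, beta2 = betas
    if decoupled:
        # AdamW: shrink the parameter directly, independent of the gradient statistics.
        param -= lr * weight_decay * param
    else:
        # Adam with L2 regularization: fold the decay term into the gradient.
        grad += weight_decay * param
    state.t += 1
    state.m = beta1 * state.m + (1 - beta1) * grad
    state.v = beta2 * state.v + (1 - beta2) * grad * grad
    # Bias-corrected moment estimates.
    m_hat = state.m / (1 - beta1 ** state.t)
    v_hat = state.v / (1 - beta2 ** state.t)
    return param - lr * m_hat / (math.sqrt(v_hat) + eps)
```

On the first step both variants move the parameter by roughly lr; with decoupled=True the parameter is additionally shrunk by lr * weight_decay * param.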

LR Scheduling

| Component type | Component Version | Implementation | Configuration | Component Interface | Description |
|---|---|---|---|---|---|
| scheduler | dummy_lr | DummyLRScheduler | DummyLRSchedulerConfig | LRScheduler | Dummy LR scheduler that leaves the learning rate unchanged |
| scheduler | step_lr | StepLR | StepLRSchedulerConfig | LRScheduler | Decays the learning rate of each parameter group by gamma every step_size steps |
| scheduler | constant_lr | ConstantLR | ConstantLRSchedulerConfig | LRScheduler | Multiplies the learning rate of each parameter group by a small constant factor until the number of steps reaches a predefined milestone |
| scheduler | onecycle_lr | OneCycleLR | OneCycleLRSchedulerConfig | LRScheduler | Sets the learning rate of each parameter group according to the 1cycle learning rate policy |
| scheduler | cosine_annealing_lr | CosineAnnealingLR | CosineAnnealingLRSchedulerConfig | LRScheduler | Sets the learning rate of each parameter group using a cosine annealing schedule |
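For the schedulers backed by PyTorch, a closed-form view of the per-step learning rate can help when choosing configuration values. The helpers below are hypothetical; their defaults mirror PyTorch's documented defaults for StepLR (gamma=0.1), ConstantLR (factor=1/3, total_iters=5) and CosineAnnealingLR (eta_min=0), but the parameter names in the modalities configs may differ. OneCycleLR is omitted because its schedule has several phases.

```python
import math


def step_lr(base_lr: float, step: int, step_size: int, gamma: float = 0.1) -> float:
    # Multiply by gamma once every `step_size` steps.
    return base_lr * gamma ** (step // step_size)


def constant_lr(base_lr: float, step: int, factor: float = 1.0 / 3, total_iters: int = 5) -> float:
    # Scale by `factor` until the milestone `total_iters`, then use base_lr.
    return base_lr * factor if step < total_iters else base_lr


def cosine_annealing_lr(base_lr: float, step: int, t_max: int, eta_min: float = 0.0) -> float:
    # Half a cosine period from base_lr (step 0) down to eta_min (step t_max).
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max)) / 2
```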


Tokenizers

| Component type | Component Version | Implementation | Configuration | Component Interface | Description |
|---|---|---|---|---|---|
| tokenizer | pretrained_hf_tokenizer | PreTrainedHFTokenizer | PreTrainedHFTokenizerConfig | TokenizerWrapper | Pretrained Hugging Face tokenizer |
| tokenizer | pretrained_sp_tokenizer | PreTrainedSPTokenizer | PreTrainedSPTokenizerConfig | TokenizerWrapper | Pretrained SentencePiece tokenizer |


Datasets

| Component type | Component Version | Implementation | Configuration | Component Interface | Description |
|---|---|---|---|---|---|
| dataset | mem_map_dataset | DatasetFactory.get_mem_map_dataset | MemMapDatasetConfig | Dataset | Memory-mapped dataset |
| dataset | packed_mem_map_dataset_continuous | DatasetFactory.get_packed_mem_map_dataset_continuous | PackedMemMapDatasetContinuousConfig | Dataset | Packed memory-mapped dataset (continuous) |
| dataset | dummy_dataset | DatasetFactory.get_dummy_dataset | DummyDatasetConfig | Dataset | Dummy dataset creating random samples of a specified shape |
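One plausible reading of a "packed, continuous" dataset is that tokenized documents are concatenated into a single token stream and cut into fixed-size blocks that may cross document boundaries. The sketch below shows that idea on Python lists. `pack_continuous` is hypothetical; the real PackedMemMapDatasetContinuous reads tokens from a memory-mapped file and may handle block boundaries and leftover tokens differently.

```python
from itertools import chain
from typing import List


def pack_continuous(documents: List[List[int]], block_size: int) -> List[List[int]]:
    """Concatenate token sequences and cut the stream into equal-length blocks.

    Document boundaries are ignored, so a block may span several documents;
    trailing tokens that do not fill a whole block are dropped.
    """
    stream = list(chain.from_iterable(documents))
    num_blocks = len(stream) // block_size
    return [stream[i * block_size:(i + 1) * block_size] for i in range(num_blocks)]
```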

Data sampling

| Component type | Component Version | Implementation | Configuration | Component Interface | Description |
|---|---|---|---|---|---|
| sampler | distributed_sampler | DistributedSampler | DistributedSamplerConfig | Sampler | Sampler that restricts data loading to a subset of the dataset for distributed training |
| batch_sampler | default | BatchSampler | BatchSamplerConfig | Sampler | Wraps another sampler to yield a mini-batch of indices |
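The sketch below mimics, in pure Python, how PyTorch's DistributedSampler (with shuffle=False and drop_last=False) splits indices across ranks and how a BatchSampler then groups them into mini-batches. The helper names are hypothetical, and shuffling and per-epoch seeding are left out.

```python
import math
from typing import Iterable, Iterator, List


def distributed_indices(dataset_len: int, num_replicas: int, rank: int) -> List[int]:
    """Indices assigned to `rank`, as in DistributedSampler with shuffle=False, drop_last=False."""
    indices = list(range(dataset_len))
    total_size = math.ceil(dataset_len / num_replicas) * num_replicas
    # Pad by repeating indices from the start so every rank gets the same count.
    while len(indices) < total_size:
        indices += indices[: total_size - len(indices)]
    # Interleave: rank r takes every num_replicas-th index starting at r.
    return indices[rank:total_size:num_replicas]


def batch_indices(sampler: Iterable[int], batch_size: int, drop_last: bool) -> Iterator[List[int]]:
    """Group a stream of indices into mini-batches, like BatchSampler."""
    batch: List[int] = []
    for idx in sampler:
        batch.append(idx)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch and not drop_last:
        yield batch
```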

Data collation

| Component type | Component Version | Implementation | Configuration | Component Interface | Description |
|---|---|---|---|---|---|
| collate_fn | gpt_2_llm_collator | GPT2LLMCollateFn | GPT2LLMCollateFnConfig | CollateFnIF | Data collator for the GPT2 model |
| collate_fn | coca_collator | CoCaCollatorFn | CoCaCollateFnConfig | CollateFnIF | Data collator for the CoCa model |
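For a decoder-only model such as GPT2, a collator typically turns each token block into an input sequence and a target sequence shifted by one position, so that position t is trained to predict token t+1. The sketch below shows that shift on plain lists. `gpt2_collate` and the keys `input_ids` and `target_ids` are hypothetical, and a real collator would typically stack tensors rather than return nested lists.

```python
from typing import Dict, List


def gpt2_collate(batch: List[List[int]]) -> Dict[str, List[List[int]]]:
    """Turn token blocks into (input, target) pairs for next-token prediction."""
    # The model sees tokens 0..n-2 and must predict tokens 1..n-1.
    inputs = [tokens[:-1] for tokens in batch]
    targets = [tokens[1:] for tokens in batch]
    return {"input_ids": inputs, "target_ids": targets}
```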

Data loaders

| Component type | Component Version | Implementation | Configuration | Component Interface | Description |
|---|---|---|---|---|---|
| data_loader | default | DataloaderFactory.get_dataloader | LLMDataLoaderConfig | DataLoader | LLM data loader extending the functionality of the PyTorch DataLoader |
| data_loader | repeating_data_loader | DataloaderFactory.get_repeating_dataloader | RepeatingDataLoaderConfig | DataLoader | Data loader that repeats the given data loader for the specified number of epochs |
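Conceptually, a repeating data loader replays the wrapped loader once per epoch. A minimal generator-based sketch, using the hypothetical name `repeat_epochs`, is shown below; the actual repeating data loader may also handle concerns such as reshuffling between epochs, which this sketch does not.

```python
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")


def repeat_epochs(loader: Iterable[T], num_epochs: int) -> Iterator[T]:
    """Yield every batch of `loader` once per epoch, for `num_epochs` epochs."""
    for _ in range(num_epochs):
        # Iterate the underlying loader again; each pass starts a fresh iterator.
        yield from loader
```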


Checkpointing

| Component type | Component Version | Implementation | Configuration | Component Interface | Description |
|---|---|---|---|---|---|
| checkpoint_saving | default | CheckpointSaving | CheckpointSavingConfig | -- | Component for saving checkpoints based on a saving strategy and an execution strategy |
| checkpoint_saving_strategy | save_every_k_steps_checkpointing_strategy | SaveEveryKStepsCheckpointingStrategy | SaveEveryKStepsCheckpointingStrategyConfig | CheckpointSavingStrategyIF | Checkpointing strategy that saves a checkpoint every k steps |
| checkpoint_saving_strategy | save_k_most_recent_checkpoints_strategy | SaveKMostRecentCheckpointsStrategy | SaveKMostRecentCheckpointsStrategyConfig | CheckpointSavingStrategyIF | Checkpointing strategy that keeps only the k most recent checkpoints and deletes older ones |
| checkpoint_saving_execution | fsdp | FSDPCheckpointSaving | FSDPCheckpointSavingConfig | CheckpointSavingExecutionABC | Component for saving checkpoints of FSDP models and optimizers |
| checkpoint_loading | fsdp | FSDPCheckpointLoading | FSDPCheckpointLoadingConfig | CheckpointLoadingIF | Component for loading FSDP checkpoints |
| checkpoint_loading | torch | TorchCheckpointLoading | TorchCheckpointLoadingConfig | CheckpointLoadingIF | Component for loading PyTorch checkpoints |
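The two saving strategies reduce to simple rules on step ids. The hypothetical helpers below sketch them: one decides whether to save at a given step, the other lists which older checkpoints to delete so that only the k most recent remain. They assume k >= 1 and that existing checkpoints are listed oldest first; the real strategy classes may treat special values of k differently.

```python
from typing import List


def save_every_k_steps(step: int, k: int) -> bool:
    """Save a checkpoint whenever the step count is a multiple of k."""
    return step % k == 0


def keep_k_most_recent(existing: List[int], new_step: int, k: int) -> List[int]:
    """Return the step ids whose checkpoints should be deleted after saving `new_step`.

    `existing` holds step ids of checkpoints already on disk, oldest first; assumes k >= 1.
    """
    kept = existing + [new_step]
    # Everything except the last k entries is outdated.
    return kept[:-k]
```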


Logging

| Component type | Component Version | Implementation | Configuration | Component Interface | Description |
|---|---|---|---|---|---|
| progress_subscriber | dummy | ProgressSubscriberFactory.get_dummy_progress_subscriber | DummyProgressSubscriberConfig | MessageSubscriberIF | Dummy progress subscriber that does not consume any messages |
| progress_subscriber | rich | ProgressSubscriberFactory.get_rich_progress_subscriber | RichProgressSubscriberConfig | MessageSubscriberIF | Subscriber for writing rich-formatted console output on training and evaluation progress |
| results_subscriber | wandb | ProgressSubscriberFactory.get_wandb_result_subscriber | WandBEvaluationResultSubscriberConfig | MessageSubscriberIF | Subscriber for logging evaluation results to Weights and Biases |

Layer Norms

| Component type | Component Version | Implementation | Configuration | Component Interface | Description |
|---|---|---|---|---|---|
| layer_norm | rms_norm | RMSLayerNorm | RMSLayerNormConfig | nn.Module | RMS layer norm |
| layer_norm | layer_norm | nn.LayerNorm | LayerNormConfig | nn.Module | Layer norm |
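RMS normalization differs from standard layer normalization in that it skips mean-centering and the bias term and rescales by the root mean square alone. The list-based sketch below contrasts the two; the helper names and the RMS eps value are assumptions, while eps=1e-5 matches the default of nn.LayerNorm.

```python
import math
from typing import List


def rms_norm(x: List[float], weight: List[float], eps: float = 1e-6) -> List[float]:
    # Scale by the root mean square only; no mean subtraction and no bias.
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [w * v / rms for w, v in zip(weight, x)]


def layer_norm(x: List[float], weight: List[float], bias: List[float],
               eps: float = 1e-5) -> List[float]:
    # Center by the mean, then divide by the standard deviation (biased variance).
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    std = math.sqrt(var + eps)
    return [w * (v - mean) / std + b for w, v, b in zip(weight, x, bias)]
```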

Gradient Clipping

| Component type | Component Version | Implementation | Configuration | Component Interface | Description |
|---|---|---|---|---|---|
| gradient_clipper | fsdp | FSDPGradientClipper | FSDPGradientClipperConfig | GradientClipperIF | FSDP gradient clipper |
| gradient_clipper | fsdp_logging_only | FSDPLoggingOnlyGradientClipper | FSDPGradientClipperConfig | GradientClipperIF | Clipper that only logs the gradient norms without clipping the gradients |
| gradient_clipper | dummy | DummyGradientClipper | DummyGradientClipperConfig | GradientClipperIF | Dummy clipper that does not apply any gradient clipping |
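Gradient clipping by global norm computes one L2 norm over all gradients and scales them down only if that norm exceeds a threshold; the logging-only variant computes the same norm but leaves the gradients alone. The single-process sketch below uses the scaling rule of torch.nn.utils.clip_grad_norm_. `clip_grad_norm` is a hypothetical helper, and a sharded FSDP version would additionally need to combine the per-rank sums of squares before taking the square root.

```python
import math
from typing import List, Tuple


def clip_grad_norm(grads: List[List[float]], max_norm: float,
                   logging_only: bool = False) -> Tuple[float, List[List[float]]]:
    """Return the global L2 norm of all gradients and the (possibly) rescaled gradients."""
    total_norm = math.sqrt(sum(g * g for tensor in grads for g in tensor))
    if logging_only:
        # Report the norm but leave the gradients untouched.
        return total_norm, grads
    # Same rule as clip_grad_norm_: scale down only when the norm is too large.
    coef = min(max_norm / (total_norm + 1e-6), 1.0)
    return total_norm, [[g * coef for g in tensor] for tensor in grads]
```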

Number conversions

| Component type | Component Version | Implementation | Configuration | Component Interface | Description |
|---|---|---|---|---|---|
| number_conversion | local_num_batches_from_num_samples | NumberConversion.get_local_num_batches_from_num_samples | LocalNumBatchesFromNumSamplesConfig | -- | Calculates the number of local batches for each rank, given the global number of samples and the number of ranks |
| number_conversion | local_num_batches_from_num_tokens | NumberConversion.get_local_num_batches_from_num_tokens | LocalNumBatchesFromNumTokensConfig | -- | Calculates the number of local batches for each rank, given the global number of tokens and the number of ranks |
| number_conversion | num_steps_from_num_samples | NumberConversion.get_num_steps_from_num_samples | NumStepsFromNumSamplesConfig | -- | Calculates the number of steps, given the global number of samples, the local micro batch size, and the number of ranks |
| number_conversion | num_steps_from_num_tokens | NumberConversion.get_num_steps_from_num_tokens | NumStepsFromNumTokensConfig | -- | Calculates the number of steps, given the global number of tokens, the local micro batch size, and the number of ranks |
| number_conversion | num_tokens_from_num_steps | NumberConversion.get_num_tokens_from_num_steps | NumTokensFromNumStepsConfig | -- | Calculates the number of tokens from the number of steps, the number of ranks, the local micro batch size, the global number of tokens, the sequence length, and the number of gradient accumulation steps |
| number_conversion | last_step_from_checkpoint_path | NumberConversion.get_num_seen_steps_from_checkpoint_path | NumberConversionFromCheckpointPathConfig | -- | Gets the last step id from a model or checkpoint file path |
| number_conversion | global_num_target_tokens_from_checkpoint_path | NumberConversion.get_global_num_target_tokens_from_checkpoint_path | NumberConversionFromCheckpointPathConfig | -- | Gets the number of target tokens from a model or checkpoint file path |
| number_conversion | num_tokens_from_packed_mem_map_dataset_continuous | NumberConversion.get_num_tokens_from_packed_mem_map_dataset_continuous | NumTokensFromPackedMemMapDatasetContinuousConfig | -- | Gets the number of tokens stored in a packed memory-mapped continuous dataset from the dataset file path |
| number_conversion | num_steps_from_raw_dataset_index | NumberConversion.get_num_steps_from_raw_dataset_index | NumStepsFromRawDatasetIndexConfig | -- | Gets the number of steps, derived in part from the index of a raw JSONL dataset. Requires the path to the index file, the number of ranks, the local micro batch size, and the number of gradient accumulation steps |
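Most of the conversions above are plain integer arithmetic. The sketch below shows one consistent set of formulas. The function names, the explicit sequence_length and gradient_accumulation_steps arguments, and the use of floor division (dropping an incomplete final batch or step) are assumptions rather than the NumberConversion code.

```python
def local_num_batches_from_num_samples(num_samples: int, num_ranks: int,
                                       local_micro_batch_size: int) -> int:
    # Each rank receives num_samples // num_ranks samples, grouped into micro batches.
    return num_samples // num_ranks // local_micro_batch_size


def num_steps_from_num_samples(num_samples: int, num_ranks: int, local_micro_batch_size: int,
                               gradient_accumulation_steps: int = 1) -> int:
    # Samples consumed by one optimizer step across all ranks.
    samples_per_step = num_ranks * local_micro_batch_size * gradient_accumulation_steps
    return num_samples // samples_per_step


def num_steps_from_num_tokens(num_tokens: int, num_ranks: int, local_micro_batch_size: int,
                              sequence_length: int, gradient_accumulation_steps: int = 1) -> int:
    # Convert tokens to whole samples first, then reuse the sample-based formula.
    return num_steps_from_num_samples(num_tokens // sequence_length, num_ranks,
                                      local_micro_batch_size, gradient_accumulation_steps)


def num_tokens_from_num_steps(num_steps: int, num_ranks: int, local_micro_batch_size: int,
                              sequence_length: int, gradient_accumulation_steps: int = 1) -> int:
    # Tokens processed per optimizer step, times the number of steps.
    return (num_steps * num_ranks * local_micro_batch_size
            * sequence_length * gradient_accumulation_steps)
```

With these formulas, num_tokens_from_num_steps and num_steps_from_num_tokens are inverses whenever the token count divides evenly.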