Trainer
- class Trainer
A base Trainer class supporting continued pretraining and SFT, designed to be subclassed for implementing other training objectives, e.g. DPO.
- Parameters:
model (kithara.Model) – The model to be trained or evaluated
optimizer (keras.Optimizer) – The optimizer used for training
train_dataloader (kithara.Dataloader) – A dataloader that provides training batches
eval_dataloader (Optional[kithara.Dataloader]) – A dataloader that provides evaluation batches
steps (Optional[int]) – The total number of training steps to execute. Defaults to None; if epochs is also None, training runs for 1 epoch
epochs (Optional[int]) – The total number of training epochs to execute. Defaults to None
log_steps_interval (int) – The interval between logging steps. Each log includes the current loss value and performance metrics
eval_steps_interval (Optional[int]) – The interval between evaluation steps. Only one of eval_steps_interval or eval_epochs_interval can be set
eval_epochs_interval (Optional[int]) – The interval between evaluation epochs. Only one of eval_steps_interval or eval_epochs_interval can be set
max_eval_samples (int) – The maximum number of samples to use during evaluation. Uses the entire evaluation dataset if not provided
tensorboard_dir (Optional[str]) – The directory path for TensorBoard logs. Can be either a local directory or a Google Cloud Storage path
profiler (Optional[kithara.Profiler]) – A profiler instance for monitoring performance metrics
checkpointer (Optional[kithara.Checkpointer]) – A checkpointer instance for saving model checkpoints
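The rule that only one of eval_steps_interval or eval_epochs_interval can be set can be pictured as a simple argument check. The sketch below is a standalone illustration, not Kithara's code: the helper name check_eval_intervals and the choice of ValueError are assumptions made for the example.

```python
from typing import Optional


def check_eval_intervals(
    eval_steps_interval: Optional[int] = None,
    eval_epochs_interval: Optional[int] = None,
) -> None:
    """Hypothetical helper: reject configs that set both evaluation intervals."""
    if eval_steps_interval is not None and eval_epochs_interval is not None:
        # Assumption: the error type is illustrative; the real Trainer may differ.
        raise ValueError(
            "Set only one of eval_steps_interval or eval_epochs_interval."
        )


check_eval_intervals(eval_steps_interval=100)    # accepted: step-based evaluation
check_eval_intervals(eval_epochs_interval=1)     # accepted: epoch-based evaluation
try:
    check_eval_intervals(eval_steps_interval=100, eval_epochs_interval=1)
except ValueError as err:
    print(f"rejected: {err}")
```

Leaving both arguments unset passes this check; the Trainer's documented default for that case is covered in the Note.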
- train()
Execute the main training loop. Handles epoch iteration, batch processing, loss computation, model updates, progress logging, and periodic evaluation.
Example usage:
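    import keras
    import kithara
    from kithara import Trainer

    # my_model, train_loader and eval_loader are assumed to be created beforehand
    trainer = Trainer(
        model=my_model,
        optimizer=keras.optimizers.Adam(learning_rate=1e-4),
        train_dataloader=train_loader,
        eval_dataloader=eval_loader,
        steps=1000,                # stop after 1000 training steps
        log_steps_interval=10,     # log loss and metrics every 10 steps
        eval_steps_interval=100,   # evaluate every 100 steps
        tensorboard_dir="local_dir_or_gs_bucket",
        checkpointer=kithara.Checkpointer(
            save_dir="local_dir_or_gs_bucket",
            save_interval=100,
        ),
    )
    trainer.train()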
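To get a feel for how the step-based intervals in this example interact, the toy loop below counts when logging and evaluation would fire. It is a sketch under stated assumptions, not the Trainer's actual loop: steps are assumed to be counted from 1, and an event is assumed to fire whenever the step number is a multiple of its interval. The function name simulate_schedule is hypothetical.

```python
from typing import List, Tuple


def simulate_schedule(
    steps: int, log_steps_interval: int, eval_steps_interval: int
) -> Tuple[List[int], List[int]]:
    """Toy stand-in for a training loop; no model or data is involved."""
    log_at: List[int] = []
    eval_at: List[int] = []
    for step in range(1, steps + 1):  # assumption: steps are 1-indexed
        # A real loop would fetch a batch, compute the loss,
        # and apply the optimizer update at this point.
        if step % log_steps_interval == 0:
            log_at.append(step)
        if step % eval_steps_interval == 0:
            eval_at.append(step)
    return log_at, eval_at


log_at, eval_at = simulate_schedule(
    steps=1000, log_steps_interval=10, eval_steps_interval=100
)
print(len(log_at), len(eval_at))  # 100 log events, 10 evaluation passes
print(eval_at[:3])                # [100, 200, 300]
```

Under these assumptions, the settings in the example would produce ten log entries between consecutive evaluations.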
Note
- If both steps and epochs are None, defaults to training for 1 epoch
- If eval_dataloader is provided but no evaluation interval is set, defaults to evaluating every epoch
- The trainer automatically handles data sharding for distributed training
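The two default rules above can be written out as a small resolution function. This is only a sketch of the documented behavior, not Kithara's implementation: the function resolve_defaults, its has_eval_dataloader flag, and the returned dictionary are hypothetical names chosen for illustration.

```python
from typing import Dict, Optional


def resolve_defaults(
    steps: Optional[int],
    epochs: Optional[int],
    has_eval_dataloader: bool,
    eval_steps_interval: Optional[int] = None,
    eval_epochs_interval: Optional[int] = None,
) -> Dict[str, Optional[int]]:
    """Hypothetical helper mirroring the documented defaults."""
    # Rule 1: with neither steps nor epochs given, train for 1 epoch.
    if steps is None and epochs is None:
        epochs = 1
    # Rule 2: with an eval dataloader but no interval, evaluate every epoch.
    if (
        has_eval_dataloader
        and eval_steps_interval is None
        and eval_epochs_interval is None
    ):
        eval_epochs_interval = 1
    return {
        "steps": steps,
        "epochs": epochs,
        "eval_steps_interval": eval_steps_interval,
        "eval_epochs_interval": eval_epochs_interval,
    }


print(resolve_defaults(steps=None, epochs=None, has_eval_dataloader=True))
print(resolve_defaults(steps=1000, epochs=None, has_eval_dataloader=True,
                       eval_steps_interval=100))
```

Explicitly supplied values pass through unchanged; only the unset combinations are filled in.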