Trainer

class Trainer

A base Trainer class that supports continued pretraining and supervised fine-tuning (SFT). It is designed to be subclassed to implement other training objectives, e.g. Direct Preference Optimization (DPO).

Parameters:
  • model (kithara.Model) – The model to be trained or evaluated

  • optimizer (keras.Optimizer) – The optimizer used for training

  • train_dataloader (kithara.Dataloader) – A dataloader that provides training batches

  • eval_dataloader (Optional[kithara.Dataloader]) – A dataloader that provides evaluation batches

  • steps (Optional[int]) – The total number of training steps to execute. Defaults to None; if epochs is also None, training runs for 1 epoch

  • epochs (Optional[int]) – The total number of training epochs to execute. Defaults to None

  • log_steps_interval (int) – The interval between logging steps. Each log includes the current loss value and performance metrics

  • eval_steps_interval (Optional[int]) – The interval between evaluation steps. Only one of eval_steps_interval or eval_epochs_interval can be set

  • eval_epochs_interval (Optional[int]) – The interval between evaluation epochs. Only one of eval_steps_interval or eval_epochs_interval can be set

  • max_eval_samples (int) – The maximum number of samples to use during evaluation. Uses the entire evaluation dataset if not provided

  • tensorboard_dir (Optional[str]) – The directory path for TensorBoard logs. Can be either a local directory or a Google Cloud Storage path

  • profiler (Optional[kithara.Profiler]) – A profiler instance for monitoring performance metrics

  • checkpointer (Optional[kithara.Checkpointer]) – A checkpointer instance for saving model checkpoints
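The parameter list states that only one of eval_steps_interval or eval_epochs_interval can be set. As a minimal sketch of what that constraint means in practice, the standalone check below rejects a configuration that sets both. The function name check_eval_intervals is hypothetical and not part of kithara, and the positive-value check is an added assumption; how Trainer itself reports such a conflict is not documented here.

```python
from typing import Optional


def check_eval_intervals(
    eval_steps_interval: Optional[int] = None,
    eval_epochs_interval: Optional[int] = None,
) -> None:
    """Hypothetical helper: reject configurations that set both intervals."""
    if eval_steps_interval is not None and eval_epochs_interval is not None:
        raise ValueError(
            "Set only one of eval_steps_interval or eval_epochs_interval."
        )
    for name, value in (
        ("eval_steps_interval", eval_steps_interval),
        ("eval_epochs_interval", eval_epochs_interval),
    ):
        # Assumption (not from the docs): an interval must be a positive integer.
        if value is not None and value <= 0:
            raise ValueError(f"{name} must be a positive integer, got {value}.")


# Valid: exactly one interval, or none at all.
check_eval_intervals(eval_steps_interval=100)
check_eval_intervals(eval_epochs_interval=1)
check_eval_intervals()
```

Running a check like this before constructing the trainer surfaces a conflicting configuration early, before any model or data is loaded.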

train()

Execute the main training loop. Handles epoch iteration, batch processing, loss computation, model updates, progress logging, and periodic evaluation.

Example usage:

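import keras
import kithara
from kithara import Trainer

# my_model, train_loader, and eval_loader are assumed to be created beforehand.
trainer = Trainer(
    model=my_model,
    optimizer=keras.optimizers.Adam(learning_rate=1e-4),
    train_dataloader=train_loader,
    eval_dataloader=eval_loader,
    steps=1000,  # stop after 1000 training steps
    log_steps_interval=10,  # log loss and performance metrics every 10 steps
    eval_steps_interval=100,  # evaluate every 100 steps
    tensorboard_dir="local_dir_or_gs_bucket",
    checkpointer=kithara.Checkpointer(
        save_dir="local_dir_or_gs_bucket",
        save_interval=100,
    ),
)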

trainer.train()
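To make the interplay of steps, epochs, log_steps_interval, and eval_steps_interval more concrete, here is a toy loop in plain Python. It is only a sketch of the general pattern that the train() description names (epoch iteration, batch processing, periodic logging and evaluation), not Kithara's implementation. The names run_loop and train_step are hypothetical, and stopping at whichever of steps or epochs is reached first is an assumption.

```python
from typing import Callable, Iterable, List, Optional, Tuple


def run_loop(
    batches: Iterable[float],
    train_step: Callable[[float], float],
    steps: Optional[int] = None,
    epochs: Optional[int] = None,
    log_steps_interval: int = 1,
    eval_steps_interval: Optional[int] = None,
) -> List[Tuple[str, int]]:
    """Toy stand-in for a training loop; returns the (event, step) pairs fired."""
    if steps is None and epochs is None:
        epochs = 1  # documented default: one epoch when neither is set
    batches = list(batches)
    events: List[Tuple[str, int]] = []
    if not batches:
        return events  # nothing to iterate over
    step = 0
    epoch = 0
    # Assumption: when both limits are given, stop at whichever comes first.
    while epochs is None or epoch < epochs:
        for batch in batches:
            if steps is not None and step >= steps:
                return events
            train_step(batch)  # a real step would compute loss and update weights
            step += 1
            if step % log_steps_interval == 0:
                events.append(("log", step))
            if eval_steps_interval and step % eval_steps_interval == 0:
                events.append(("eval", step))
        epoch += 1
    return events


# Six steps over a four-batch dataset: the loop wraps into a second epoch.
events = run_loop([0.0] * 4, lambda b: b, steps=6,
                  log_steps_interval=2, eval_steps_interval=3)
print(events)
```

With these arguments, logging fires on every second step and evaluation on every third, so both fire on step 6.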

Note

  • If both steps and epochs are None, defaults to training for 1 epoch

  • If eval_dataloader is provided but no evaluation interval is set, defaults to evaluating every epoch

  • The trainer automatically handles data sharding for distributed training
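The first two defaults in the notes above can be written out as a small resolution step. The sketch below is illustrative only: Schedule and apply_defaults are hypothetical names rather than kithara APIs, and the code simply restates the documented rules (1 epoch when neither steps nor epochs is given, evaluation every epoch when an eval dataloader is present without an interval).

```python
from dataclasses import dataclass, replace
from typing import Optional


@dataclass(frozen=True)
class Schedule:
    """Hypothetical container for the schedule-related arguments."""
    steps: Optional[int] = None
    epochs: Optional[int] = None
    eval_steps_interval: Optional[int] = None
    eval_epochs_interval: Optional[int] = None


def apply_defaults(schedule: Schedule, has_eval_dataloader: bool) -> Schedule:
    """Return a copy of `schedule` with the documented defaults filled in."""
    # Rule 1: neither steps nor epochs given -> train for 1 epoch.
    if schedule.steps is None and schedule.epochs is None:
        schedule = replace(schedule, epochs=1)
    # Rule 2: eval data given but no interval -> evaluate every epoch.
    no_interval = (
        schedule.eval_steps_interval is None
        and schedule.eval_epochs_interval is None
    )
    if has_eval_dataloader and no_interval:
        schedule = replace(schedule, eval_epochs_interval=1)
    return schedule


resolved = apply_defaults(Schedule(), has_eval_dataloader=True)
print(resolved)
```

Explicitly set values pass through unchanged, so a configuration such as steps=1000 with eval_steps_interval=100 is returned as given.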