Datasets

TextCompletionDataset

class TextCompletionDataset

A dataset class for standard text completion tasks.

Parameters:
  • source (ray.data.Dataset) – The source Ray dataset containing the text data

  • tokenizer (Optional[AutoTokenizer]) – The tokenizer instance to use

  • tokenizer_handle (Optional[str]) – Handle or name of the HF tokenizer to load when tokenizer is not provided, e.g. “hf://google/gemma-2-2b”

  • column_mapping (Optional[Dict[str, str]]) – Mapping from source column names to the expected column name (“text”)

  • model_type (ModelImplementationType | "auto") – Type of model implementation to use. Supported types: “KerasHub”, “MaxText”, “auto”

  • max_seq_len (int) – Maximum sequence length for tokenization (default: 1024)

  • custom_formatting_fn (Optional[Callable]) – A custom formatting function applied to each raw sample before any other transformation step

  • packing (bool) – Whether to enable sequence packing
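The parameter descriptions above imply an order of operations: custom_formatting_fn sees the raw sample first, column_mapping then renames a source column to “text”, and tokenization is capped at max_seq_len. The sketch below models that flow in plain Python under stated assumptions: ToyTokenizer and preprocess_sample are hypothetical stand-ins, not part of the Kithara API; the formatting function is assumed to take and return a dict; and truncation is assumed to be how max_seq_len is enforced.

```python
from typing import Callable, Dict, List, Optional


class ToyTokenizer:
    """Hypothetical stand-in for a real tokenizer: one integer ID per whitespace-separated word."""

    def __init__(self) -> None:
        self.vocab: Dict[str, int] = {}

    def encode(self, text: str) -> List[int]:
        ids = []
        for word in text.split():
            # Assign IDs in first-seen order, starting at 1; repeated words reuse their ID.
            ids.append(self.vocab.setdefault(word, len(self.vocab) + 1))
        return ids


def preprocess_sample(
    raw: Dict[str, str],
    tokenizer: ToyTokenizer,
    column_mapping: Optional[Dict[str, str]] = None,
    custom_formatting_fn: Optional[Callable[[Dict[str, str]], Dict[str, str]]] = None,
    max_seq_len: int = 1024,
) -> List[int]:
    # 1. The custom formatting function sees the raw sample, with its original column names.
    sample = custom_formatting_fn(raw) if custom_formatting_fn else dict(raw)
    # 2. Rename source columns so the text ends up under "text".
    if column_mapping:
        sample = {column_mapping.get(k, k): v for k, v in sample.items()}
    # 3. Tokenize, then truncate to at most max_seq_len tokens.
    return tokenizer.encode(sample["text"])[:max_seq_len]


def add_prefix(sample: Dict[str, str]) -> Dict[str, str]:
    # Example formatting function: prepend a label to the raw "body" column.
    return {**sample, "body": "Summary: " + sample["body"]}


ids = preprocess_sample(
    {"body": "the cat sat on the mat"},
    ToyTokenizer(),
    column_mapping={"body": "text"},
    custom_formatting_fn=add_prefix,
    max_seq_len=4,
)
print(ids)  # [1, 2, 3, 4]
```

Note that add_prefix refers to the source column name “body”, not “text”, because in this sketch it runs before the column mapping is applied.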

to_packed_dataset()

Converts the current dataset to a PackedDataset for more efficient processing.

Returns:

A new PackedDataset instance

Return type:

PackedDataset

SFTDataset

class SFTDataset(TextCompletionDataset)

A dataset class for Supervised Fine-Tuning (SFT) tasks.

Parameters:
  • source (ray.data.Dataset) – The source Ray dataset containing the training data

  • tokenizer (Optional[AutoTokenizer]) – Hugging Face tokenizer instance

  • tokenizer_handle (Optional[str]) – Handle or name of the HF tokenizer to load when tokenizer is not provided, e.g. “hf://google/gemma-2-2b”

  • column_mapping (Optional[Dict[str, str]]) – Mapping from source column names to the expected column names (“prompt” and “answer”)

  • model_type (ModelImplementationType | "auto") – Type of model implementation to use. Supported types: “KerasHub”, “MaxText”, “auto”

  • max_seq_len (int) – Maximum sequence length for tokenization (default: 1024)

  • custom_formatting_fn (Optional[Callable]) – A custom formatting function applied to each raw sample before any other transformation step
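For SFT, column_mapping plays the same role as for text completion but targets two columns. The minimal sketch below shows the renaming and one way a prompt/answer pair could be fitted into max_seq_len. map_columns and build_sft_example are hypothetical helpers, the token IDs are arbitrary, and the plain concatenation, truncation, and right-padding with pad_value are assumptions rather than a description of what SFTDataset does internally (it may, for example, apply a prompt template).

```python
from typing import Dict, List, Optional, Tuple


def map_columns(
    sample: Dict[str, str], column_mapping: Optional[Dict[str, str]]
) -> Dict[str, str]:
    """Rename source columns to the names SFTDataset expects ("prompt", "answer")."""
    if not column_mapping:
        return dict(sample)
    return {column_mapping.get(k, k): v for k, v in sample.items()}


def build_sft_example(
    prompt_ids: List[int], answer_ids: List[int], max_seq_len: int, pad_value: int = 0
) -> Tuple[List[int], List[int]]:
    """Join prompt and answer token IDs, truncate, then right-pad to max_seq_len.

    Returns (token_ids, padding_mask); the mask is 1 for real tokens and 0 for padding.
    """
    ids = (prompt_ids + answer_ids)[:max_seq_len]
    n_pad = max_seq_len - len(ids)
    return ids + [pad_value] * n_pad, [1] * len(ids) + [0] * n_pad


row = {"input": "What is 2+2?", "output": "4"}
print(map_columns(row, {"input": "prompt", "output": "answer"}))
# {'prompt': 'What is 2+2?', 'answer': '4'}

token_ids, mask = build_sft_example([11, 12, 13], [21], max_seq_len=6)
print(token_ids)  # [11, 12, 13, 21, 0, 0]
print(mask)       # [1, 1, 1, 1, 0, 0]
```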

to_packed_dataset()

Converts the current dataset to a PackedDataset for more efficient processing.

Returns:

A new PackedDataset instance

Return type:

PackedDataset

PackedDataset

class PackedDataset

A dataset class that packs multiple sequences together on the fly for more efficient processing.

Parameters:
  • source_dataset (TextCompletionDataset) – The source dataset containing unpacked sequences

  • pad_value (int) – The value to use for padding (default: 0)

Note

  • Packing requires Flash Attention, which should be enabled by default

  • Packing currently works only with MaxText models

  • Packing does not currently support DDP training
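The idea behind PackedDataset can be illustrated with a small greedy packer: sequences are appended to the current row until the next one would exceed max_seq_len, and the rest of the row is filled with pad_value. This is a sketch under assumptions, not Kithara's implementation: pack_sequences and its output fields (tokens, segment_ids, positions) are hypothetical, and segment IDs are included only because they are a common way for attention kernels to keep packed sequences from attending to one another.

```python
from typing import Dict, List


def _finish_row(
    tokens: List[int],
    segment_ids: List[int],
    positions: List[int],
    max_seq_len: int,
    pad_value: int,
) -> Dict[str, List[int]]:
    """Pad one packed row out to exactly max_seq_len entries."""
    pad = max_seq_len - len(tokens)
    return {
        "tokens": tokens + [pad_value] * pad,
        "segment_ids": segment_ids + [0] * pad,  # 0 marks padding
        "positions": positions + [0] * pad,
    }


def pack_sequences(
    sequences: List[List[int]], max_seq_len: int, pad_value: int = 0
) -> List[Dict[str, List[int]]]:
    """Greedily pack token sequences, in order, into rows of max_seq_len tokens."""
    rows: List[Dict[str, List[int]]] = []
    tokens: List[int] = []
    segment_ids: List[int] = []
    positions: List[int] = []
    for seq in sequences:
        seq = seq[:max_seq_len]  # an over-long sequence is truncated
        if tokens and len(tokens) + len(seq) > max_seq_len:
            # The next sequence does not fit: close the current row and start a new one.
            rows.append(_finish_row(tokens, segment_ids, positions, max_seq_len, pad_value))
            tokens, segment_ids, positions = [], [], []
        segment = segment_ids[-1] + 1 if segment_ids else 1
        tokens += seq
        segment_ids += [segment] * len(seq)
        positions += list(range(len(seq)))  # positions restart for each sequence
    if tokens:
        rows.append(_finish_row(tokens, segment_ids, positions, max_seq_len, pad_value))
    return rows


rows = pack_sequences([[5, 6, 7], [8, 9], [10, 11, 12, 13]], max_seq_len=6)
for row in rows:
    print(row["tokens"], row["segment_ids"])
# [5, 6, 7, 8, 9, 0] [1, 1, 1, 2, 2, 0]
# [10, 11, 12, 13, 0, 0] [1, 1, 1, 1, 0, 0]
```

Packing greedily in arrival order keeps the logic streaming-friendly, which suits on-the-fly processing; the trade-off is that it can leave more padding than a bin-packing approach that reorders sequences.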

Example Usage

Here’s a simple example of using the TextCompletionDataset:

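from kithara import TextCompletionDataset

# ray_dataset is a ray.data.Dataset whose samples have a "text" column
dataset = TextCompletionDataset(
    source=ray_dataset,
    tokenizer_handle="hf://google/gemma-2-2b",
    max_seq_len=512,
)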

For supervised fine-tuning tasks, use the SFTDataset:

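from kithara import SFTDataset

# Here ray_dataset is assumed to have "input" and "output" columns, which
# column_mapping renames to the "prompt" and "answer" columns SFTDataset expects
sft_dataset = SFTDataset(
    source=ray_dataset,
    tokenizer_handle="hf://google/gemma-2-2b",
    column_mapping={"input": "prompt", "output": "answer"},
    max_seq_len=1024,
)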

To enable sequence packing for more efficient processing:

packed_dataset = dataset.to_packed_dataset()