Datasets
TextCompletionDataset
- class TextCompletionDataset
A dataset class for standard text completion tasks.
- Parameters:
source (ray.data.Dataset) – The source Ray dataset containing the text data
tokenizer (Optional[AutoTokenizer]) – The tokenizer instance to use
tokenizer_handle (Optional[str]) – Handle/name of the HF tokenizer to load when no tokenizer instance is provided, e.g. “hf://google/gemma-2-2b”
column_mapping (Optional[Dict[str, str]]) – Mapping of source column name to expected column name (“text”)
model_type (ModelImplementationType | "auto") – Type of model implementation to use. Supported types: “KerasHub”, “MaxText”, “auto”
max_seq_len (int) – Maximum sequence length for tokenization (default: 1024)
custom_formatting_fn (Optional[callable]) – A custom formatting function to apply to the raw sample before any other transformation steps
packing (bool) – Whether to enable sequence packing
- to_packed_dataset()
Converts the current dataset to a PackedDataset for more efficient processing.
- Returns:
A new PackedDataset instance
- Return type:
PackedDataset
SFTDataset
- class SFTDataset(TextCompletionDataset)
A dataset class for Supervised Fine-Tuning (SFT) tasks.
- Parameters:
source (ray.data.Dataset) – The source Ray dataset containing the training data
tokenizer (Optional[AutoTokenizer]) – The Hugging Face tokenizer instance to use
tokenizer_handle (Optional[str]) – Handle/name of the HF tokenizer to load when no tokenizer instance is provided, e.g. “hf://google/gemma-2-2b”
column_mapping (Optional[Dict[str, str]]) – Mapping of source column names to expected column names (“prompt” and “answer”)
model_type (ModelImplementationType | "auto") – Type of model implementation to use. Supported types: “KerasHub”, “MaxText”, “auto”
max_seq_len (int) – Maximum sequence length for tokenization (default: 1024)
custom_formatting_fn (Optional[callable]) – A custom formatting function to apply to the raw sample before any other transformation steps
- to_packed_dataset()
Converts the current dataset to a PackedDataset for more efficient processing.
- Returns:
A new PackedDataset instance
- Return type:
PackedDataset
PackedDataset
- class PackedDataset
A dataset class that packs multiple sequences together on the fly for more efficient processing.
- Parameters:
source_dataset (TextCompletionDataset) – The source dataset containing unpacked sequences
pad_value (int) – The value to use for padding (default: 0)
Note
Packing requires Flash Attention, which should be enabled by default.
Packing currently works only with MaxText models.
Packing does not currently work with DDP training.
Example Usage
Here’s a simple example of using the TextCompletionDataset:
# ray_dataset is an existing ray.data.Dataset with a "text" column
dataset = TextCompletionDataset(
    source=ray_dataset,
    tokenizer_handle="hf://google/gemma-2-2b",
    max_seq_len=512,
)
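The custom_formatting_fn parameter is described as running on the raw sample before any other transformation step. The following is a minimal sketch of what such a function might look like, assuming the raw sample arrives as a dictionary and the function returns a dictionary containing the expected “text” column. The function name format_sample and the “title” and “body” columns are hypothetical, and the exact argument and return types the library expects are not confirmed here.

```python
from typing import Any, Dict


def format_sample(sample: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical formatter: merge a title and a body into one "text" field.

    Assumption: the raw sample is a dict and the result should expose a
    "text" key, matching the column name TextCompletionDataset expects.
    """
    title = str(sample.get("title", "")).strip()
    body = str(sample.get("body", "")).strip()
    # Put the title on its own paragraph only when one is present.
    text = f"{title}\n\n{body}" if title else body
    return {"text": text}


raw = {"title": "  Greeting ", "body": "Hello world."}
formatted = format_sample(raw)
```

Under these assumptions, the function would be passed as custom_formatting_fn=format_sample when constructing the dataset.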
For supervised fine-tuning tasks, use the SFTDataset:
# Source columns "input" and "output" are mapped to "prompt" and "answer"
sft_dataset = SFTDataset(
    source=ray_dataset,
    tokenizer_handle="hf://google/gemma-2-2b",
    column_mapping={"input": "prompt", "output": "answer"},
    max_seq_len=1024,
)
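To make the effect of column_mapping concrete, here is a small stand-alone sketch that renames the keys of a single sample using the same mapping as above. It only illustrates the renaming idea on a plain dictionary. The helper apply_column_mapping is hypothetical and is not part of the library, and how the library treats unmapped columns is an assumption.

```python
from typing import Any, Dict


def apply_column_mapping(
    sample: Dict[str, Any], column_mapping: Dict[str, str]
) -> Dict[str, Any]:
    """Rename keys of `sample` from source names to expected names.

    Assumption: columns that do not appear in the mapping are kept unchanged.
    """
    return {column_mapping.get(key, key): value for key, value in sample.items()}


sample = {"input": "What is 2 + 2?", "output": "4", "id": 7}
mapped = apply_column_mapping(sample, {"input": "prompt", "output": "answer"})
```

After the mapping, the sample exposes the “prompt” and “answer” columns that SFTDataset expects, while the unrelated “id” column is left alone.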
To enable sequence packing for more efficient processing:
packed_dataset = dataset.to_packed_dataset()
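Packing helps because short sequences otherwise waste most of a max_seq_len row on padding. The sketch below shows the general idea with a simple greedy strategy: token sequences are concatenated into a row until the next one would not fit, and the remainder is filled with pad_value. The greedy strategy, the segment IDs, and the function name pack_sequences are assumptions made for illustration. This is not the library's implementation.

```python
from typing import List, Tuple

Row = Tuple[List[int], List[int]]  # (token ids, segment ids)


def pack_sequences(
    sequences: List[List[int]], max_seq_len: int, pad_value: int = 0
) -> List[Row]:
    """Greedily pack token sequences into rows of exactly max_seq_len tokens.

    Segment ids start at 1 within each row and mark which original sequence a
    token came from; padding positions get segment id 0.
    """
    rows: List[Row] = []
    tokens: List[int] = []
    segments: List[int] = []

    def flush() -> None:
        # Pad the current row to max_seq_len and store it.
        pad = max_seq_len - len(tokens)
        rows.append((tokens + [pad_value] * pad, segments + [0] * pad))

    for seq in sequences:
        seq = seq[:max_seq_len]  # truncate sequences longer than one row
        if not seq:
            continue
        if len(tokens) + len(seq) > max_seq_len:
            flush()
            tokens, segments = [], []
        segment_id = segments[-1] + 1 if segments else 1
        tokens.extend(seq)
        segments.extend([segment_id] * len(seq))

    if tokens:
        flush()
    return rows


packed = pack_sequences([[5, 6, 7], [8, 9], [10, 11, 12, 13], [14]], max_seq_len=6)
```

Here four sequences fit into two rows of length 6 instead of four padded rows. The segment IDs are the kind of information an attention implementation would need to keep tokens from attending across sequence boundaries within a packed row.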