Model
MaxTextModel
from_random
- classmethod MaxTextModel.from_random(model_name: str, seq_len: int = 2048, per_device_batch_size: int = 1, precision: str = 'mixed_bfloat16', scan_layers: bool = False, maxtext_config_args: dict | None = None) → 'MaxTextModel'
Create a randomly initialized MaxText model with the given configuration.
- Parameters:
model_name – Name of the MaxText model configuration to use. Supported: “default”, “llama2-7b”, “llama2-13b”, “llama2-70b”, “llama3-8b”, “llama3-70b”, “llama3.1-8b”, “llama3.1-70b”, “llama3.1-405b”, “llama3.3-70b”, “mistral-7b”, “mixtral-8x7b”, “mixtral-8x22b”, “deepseek3-671b”, “gemma-7b”, “gemma-2b”, “gemma2-2b”, “gemma2-9b”, “gemma2-27b”, “gpt3-175b”, “gpt3-22b”, “gpt3-6b”, “gpt3-52k”
seq_len – Maximum sequence length (default: 2048)
per_device_batch_size – Batch size per device (default: 1)
precision – Precision mode for computations. Supported policies include “float32”, “float16”, “bfloat16”, “mixed_float16”, and “mixed_bfloat16”. Mixed precision policies load model weights in float32 and cast activations to the specified dtype. (default: “mixed_bfloat16”)
scan_layers – Whether to use scan layers for memory efficiency. Setting this to True is recommended for models smaller than 9B parameters, as it improves performance. (default: False)
maxtext_config_args – Additional MaxText configuration arguments (default: None)
- Returns:
A new instance of MaxTextModel with random initialization
Example Usage:
model = MaxTextModel.from_random(
    "gemma2-2b",
    seq_len=8192,  # Seq len and batch size need to be specified up front
    per_device_batch_size=1,
    scan_layers=True,  # Set to True for models <9B for performance gain
)
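The precision strings above appear to follow the Keras dtype-policy naming: a `mixed_` prefix means weights are stored in float32 while activations use the named dtype. The sketch below makes that mapping explicit. `resolve_precision` and `DtypePair` are hypothetical helpers for illustration only, not part of the MaxTextModel API.

```python
from typing import NamedTuple

# Precision strings listed in the parameter description above
SUPPORTED_PRECISIONS = {
    "float32", "float16", "bfloat16", "mixed_float16", "mixed_bfloat16",
}


class DtypePair(NamedTuple):
    variable_dtype: str  # dtype used to store weights
    compute_dtype: str   # dtype used for activations


def resolve_precision(precision: str) -> DtypePair:
    """Hypothetical helper: map a precision string to (weight, activation) dtypes."""
    if precision not in SUPPORTED_PRECISIONS:
        raise ValueError(f"Unsupported precision: {precision!r}")
    if precision.startswith("mixed_"):
        # Mixed policies keep float32 weights and cast activations down
        return DtypePair("float32", precision[len("mixed_"):])
    # Plain policies use a single dtype for both weights and activations
    return DtypePair(precision, precision)


print(resolve_precision("mixed_bfloat16"))
# DtypePair(variable_dtype='float32', compute_dtype='bfloat16')
```

Storing weights in float32 while computing in bfloat16 keeps weight updates numerically stable while most of the arithmetic runs at the lower precision.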
from_preset
- classmethod MaxTextModel.from_preset(preset_handle: str, seq_len: int = 2048, per_device_batch_size: int = 1, precision: str = 'mixed_bfloat16', scan_layers: bool = False, maxtext_config_args: dict | None = None) → 'MaxTextModel'
Create a MaxText model initialized with weights from a HuggingFace checkpoint, loaded from the HuggingFace Hub, a local path, or Google Cloud Storage.
- Parameters:
preset_handle – HuggingFace model identifier for the supported model architectures. Can be:
  - HuggingFace Hub path (e.g. “hf://google/gemma-2-2b”)
  - Local HuggingFace checkpoint path (e.g. “tmp/my_model/checkpoint”)
  - GCS HuggingFace checkpoint path (e.g. “gs://bucket_name/my_model/checkpoint”)
seq_len – Maximum sequence length (default: 2048)
per_device_batch_size – Batch size per device (default: 1)
precision – Precision mode for computations. Supported policies include “float32”, “float16”, “bfloat16”, “mixed_float16”, and “mixed_bfloat16”. Mixed precision policies load model weights in float32 and cast activations to the specified dtype. (default: “mixed_bfloat16”)
scan_layers – Whether to use scan layers. Setting this to True is recommended for models smaller than 9B parameters, as it improves performance. (default: False)
maxtext_config_args – Additional configuration arguments (default: None)
- Returns:
A new instance of MaxTextModel initialized with pretrained weights
Example Usage:
model = MaxTextModel.from_preset(
    "hf://google/gemma-2-2b",  # HuggingFace model
    seq_len=8192,  # Seq len and batch size need to be specified up front
    per_device_batch_size=1,
    scan_layers=True,  # Set to True for models <9B for performance gain
)
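The three accepted preset_handle forms differ only in their prefix. A minimal sketch of how such a handle could be dispatched, assuming prefix matching is sufficient; `parse_preset_handle` is hypothetical and does not necessarily reflect how MaxTextModel resolves handles internally.

```python
from typing import Tuple


def parse_preset_handle(preset_handle: str) -> Tuple[str, str]:
    """Hypothetical helper: return (source, location) for a checkpoint handle."""
    if preset_handle.startswith("hf://"):
        # HuggingFace Hub: strip the prefix to get the repo id
        return "hf_hub", preset_handle[len("hf://"):]
    if preset_handle.startswith("gs://"):
        # Google Cloud Storage: keep the full gs:// path
        return "gcs", preset_handle
    # Anything without a recognized prefix is treated as a local path
    return "local", preset_handle


print(parse_preset_handle("hf://google/gemma-2-2b"))  # ('hf_hub', 'google/gemma-2-2b')
```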
save_in_hf_format
- save_in_hf_format(output_dir: str, dtype: str = 'auto', parallel_threads: int = 8)
Save the model in HuggingFace format, including:
  - Model configuration file (config.json)
  - Model weights file (model.safetensors for models smaller than DEFAULT_MAX_SHARD_SIZE, model-x-of-x.safetensors for larger models)
  - Safetensors index file (model.safetensors.index.json)
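The single-file versus sharded layout can be illustrated with a small planning function. This sketch assumes the standard HuggingFace shard naming (e.g. model-00001-of-00003.safetensors) and a simple greedy packing in tensor order. `plan_shards` is hypothetical; the actual value of DEFAULT_MAX_SHARD_SIZE and the packing strategy used by save_in_hf_format are not specified here. The returned dict mirrors the top-level layout of a safetensors index file.

```python
from typing import Dict, List


def plan_shards(tensor_sizes: Dict[str, int], max_shard_size: int) -> dict:
    """Greedily pack tensors (in insertion order) into shards of at most
    max_shard_size bytes; a tensor larger than the limit gets its own shard."""
    shards: List[List[str]] = [[]]
    current = 0
    for name, size in tensor_sizes.items():
        # Open a new shard when adding this tensor would exceed the limit
        if shards[-1] and current + size > max_shard_size:
            shards.append([])
            current = 0
        shards[-1].append(name)
        current += size

    total = len(shards)
    weight_map: Dict[str, str] = {}
    for i, names in enumerate(shards, start=1):
        # One shard: a single file; several shards: numbered files
        if total == 1:
            filename = "model.safetensors"
        else:
            filename = f"model-{i:05d}-of-{total:05d}.safetensors"
        for name in names:
            weight_map[name] = filename

    # Same top-level keys as a HuggingFace safetensors index file
    return {"metadata": {"total_size": sum(tensor_sizes.values())},
            "weight_map": weight_map}


print(plan_shards({"a": 2, "b": 2, "c": 3}, max_shard_size=5)["weight_map"])
```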
- Parameters:
output_dir – Directory path where the model should be saved. Can be a local folder (e.g. “foldername/”), a HuggingFace Hub repo prefixed with “hf://” (e.g. “hf://your_hf_id/repo_name”), or a Google Cloud Storage path prefixed with “gs://” (e.g. “gs://your_bucket/folder_name”). The directory will be created if it doesn’t exist.
dtype – Data type for saved weights. Defaults to “auto” which saves the model in its current precision type. (default: “auto”)
parallel_threads – Number of parallel threads to use for saving (default: 8). Note: the local system must have at least parallel_threads * DEFAULT_MAX_SHARD_SIZE of free disk space, as each thread maintains a local cache of size DEFAULT_MAX_SHARD_SIZE.
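The disk-space requirement for parallel_threads can be checked before saving. A minimal sketch using the standard library's shutil.disk_usage; `has_enough_disk_space` is a hypothetical helper, and the shard size is passed in explicitly because the actual value of DEFAULT_MAX_SHARD_SIZE is not given in this document. Point `path` at the filesystem that holds the local cache (an assumption: this document does not say where the cache lives).

```python
import shutil


def has_enough_disk_space(path: str, parallel_threads: int, max_shard_size_bytes: int) -> bool:
    """Each thread may cache one full shard, so up to
    parallel_threads * max_shard_size_bytes may be needed at once."""
    required = parallel_threads * max_shard_size_bytes
    free = shutil.disk_usage(path).free
    return free >= required


# Example with a hypothetical 5 GB shard size and the default 8 threads:
# 8 * 5 GB = 40 GB of free space required.
print(has_enough_disk_space(".", parallel_threads=8, max_shard_size_bytes=5 * 1024**3))
```

Lowering parallel_threads reduces the peak disk requirement proportionally, at the cost of slower saving.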
generate
- generate(inputs: str | List[str] | List[int] | np.ndarray | List[np.ndarray] | List[List[int]], max_length: int = 100, stop_token_ids: str | List[int] = 'auto', strip_prompt: bool = False, tokenizer: AutoTokenizer | None = None, tokenizer_handle: str | None = None, return_decoded: bool = True, skip_special_tokens: bool = True, **kwargs) → List[str] | Dict[str, np.ndarray]
Generate text tokens using the model.
- Parameters:
inputs – Inputs can be either strings or integer tokens. String inputs can be a single string or a list of strings. Token inputs can be a numpy array, a list of numpy arrays, a list of integers, or a list of integer lists. If strings are provided, either tokenizer or tokenizer_handle must be provided.
max_length – Maximum total sequence length (prompt + generated tokens). If both tokenizer and tokenizer_handle are None, inputs should already be padded to the desired maximum length, and this argument is ignored. When inputs are strings, this value must be provided. (default: 100)
stop_token_ids – List of token IDs that stop generation. Defaults to “auto”, which extracts the end token id from the tokenizer.
strip_prompt – If True, returns only the generated tokens without the input prompt. If False, returns the full sequence including the prompt. (default: False)
tokenizer – Optional AutoTokenizer instance.
tokenizer_handle – Optional HuggingFace tokenizer identifier string. E.g. “google/gemma-2-2b”.
return_decoded – If True, returns the decoded text using the tokenizer; otherwise, returns the predicted tokens. This option must be set to False if neither tokenizer nor tokenizer_handle is provided. (default: True)
skip_special_tokens – Whether to remove special tokens from the decoded text. Only used when return_decoded is True. (default: True)
- Returns:
A list of strings if return_decoded is True. If return_decoded is False, a dictionary containing ‘token_ids’ (generated token IDs, [B, S]) and ‘padding_mask’ (attention mask, [B, S]).
Example Usage:
# Return tokens
prompt = "what is your name?"
pred_tokens = model.generate(
    prompt,
    max_length=100,
    tokenizer_handle="hf://google/gemma-2-2b",
    return_decoded=False,
)
print(pred_tokens)

# Return text
pred_text = model.generate(
    prompt,
    max_length=100,
    tokenizer_handle="hf://google/gemma-2-2b",
    return_decoded=True,
    strip_prompt=True,
)
print(pred_text)

# Use an initialized tokenizer
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
pred_text = model.generate(
    prompt,
    max_length=100,
    tokenizer=tokenizer,
    return_decoded=True,
    strip_prompt=True,
)
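The interaction between max_length, stop_token_ids, and strip_prompt can be illustrated with a toy greedy loop. This is a sketch, not the model's decoding implementation: `toy_generate` is hypothetical, `next_token` stands in for a model forward pass, and whether generate keeps the stop token in its output is an assumption (the sketch keeps it).

```python
from typing import Callable, Iterable, List


def toy_generate(
    prompt_ids: List[int],
    next_token: Callable[[List[int]], int],
    max_length: int,
    stop_token_ids: Iterable[int],
    strip_prompt: bool = False,
) -> List[int]:
    stop = set(stop_token_ids)
    tokens = list(prompt_ids)
    # max_length bounds prompt + generated tokens together
    while len(tokens) < max_length:
        token = next_token(tokens)
        tokens.append(token)
        # Stop as soon as a stop token is produced (it is kept in the output)
        if token in stop:
            break
    # strip_prompt=True drops the prompt and returns only the new tokens
    return tokens[len(prompt_ids):] if strip_prompt else tokens


def count_up(tokens: List[int]) -> int:
    # Stand-in "model": always predicts the previous token + 1
    return tokens[-1] + 1


print(toy_generate([1, 2], count_up, max_length=6, stop_token_ids=[]))   # [1, 2, 3, 4, 5, 6]
print(toy_generate([1, 2], count_up, max_length=6, stop_token_ids=[4]))  # [1, 2, 3, 4]
print(toy_generate([1, 2], count_up, 6, [4], strip_prompt=True))         # [3, 4]
```

Because max_length counts the prompt, a long prompt leaves fewer positions for new tokens.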
KerasHubModel
from_preset
- classmethod KerasHubModel.from_preset(preset_handle: str, lora_rank: int | None = None, precision: str = 'mixed_bfloat16', sharding_strategy: ShardingStrategy | None = None, **kwargs) → 'KerasHubModel'
Create a KerasHub model initialized with weights from various sources, with optional LoRA adaptation.
- Parameters:
preset_handle – Model identifier that can be:
  - A built-in KerasHub preset identifier (e.g., “bert_base_en”)
  - A Kaggle Models handle (e.g., “kaggle://user/bert/keras/bert_base_en”)
  - A Hugging Face handle (e.g., “hf://user/bert_base_en”)
  - A local directory path (e.g., “./bert_base_en”)
lora_rank – Rank for LoRA adaptation. If None, LoRA is disabled. When enabled, LoRA is applied to the q_proj and v_proj layers. (default: None)
precision – Precision mode for computations. Supported policies include “float32”, “float16”, “bfloat16”, “mixed_float16”, and “mixed_bfloat16”. Mixed precision policies load model weights in float32 and cast activations to the specified dtype. (default: “mixed_bfloat16”)
sharding_strategy – Strategy for distributing model parameters, optimizer states, and data tensors. If None, tensors will be sharded using FSDP. Use kithara.ShardingStrategy to configure custom sharding. (default: None)
- Returns:
A new instance of KerasHubModel initialized with the specified configuration
Example Usage:
# Initialize a model with LoRA adaptation
model = KerasHubModel.from_preset(
    "hf://google/gemma-2-2b",
    lora_rank=4,
)
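With lora_rank set, each adapted projection (q_proj and v_proj) gains a low-rank update. The numpy sketch below shows the standard LoRA formulation from Hu et al.: a frozen weight W plus a trainable product A @ B of rank r, with B initialized to zeros so the adapted layer initially matches the base layer. The scaling factor used by some implementations is omitted, and KerasHub's exact initialization may differ; `init_lora` and `lora_forward` are hypothetical helpers.

```python
import numpy as np


def init_lora(d_in: int, d_out: int, rank: int, seed: int = 0):
    rng = np.random.default_rng(seed)
    # A: (d_in, rank) small random values; B: (rank, d_out) zeros
    A = rng.normal(scale=0.01, size=(d_in, rank)).astype(np.float32)
    B = np.zeros((rank, d_out), dtype=np.float32)
    return A, B


def lora_forward(x, W, A, B):
    # Frozen base projection plus the low-rank update (A @ B is never formed)
    return x @ W + (x @ A) @ B


d_in = d_out = 1024
rank = 4
W = np.random.default_rng(1).normal(size=(d_in, d_out)).astype(np.float32)
A, B = init_lora(d_in, d_out, rank)
x = np.ones((2, d_in), dtype=np.float32)

# With B = 0 the adapted layer reproduces the base layer exactly
assert np.allclose(lora_forward(x, W, A, B), x @ W)

# Trainable parameters: rank * (d_in + d_out) instead of d_in * d_out
print(A.size + B.size, W.size)  # 8192 1048576
```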
save_in_hf_format
- save_in_hf_format(output_dir: str, dtype: str = 'auto', only_save_adapters: bool = False, save_adapters_separately: bool = False, parallel_threads: int = 8)
Save the model in HuggingFace format, including configuration and weights files.
- Parameters:
output_dir – Directory path where the model should be saved. Can be a local folder (e.g. “foldername/”), a HuggingFace Hub repo prefixed with “hf://” (e.g. “hf://your_hf_id/repo_name”), or a Google Cloud Storage path prefixed with “gs://” (e.g. “gs://your_bucket/folder_name”). The directory will be created if it doesn’t exist.
dtype – Data type for saved weights. Defaults to “auto” which saves the model in its current precision type.
only_save_adapters – If True, only adapter weights will be saved. If False, both base model weights and adapter weights will be saved. (default: False)
save_adapters_separately – If False, adapter weights will be merged into the base model weights. If True, adapter weights will be saved separately in HuggingFace’s PEFT format. (default: False)
parallel_threads – Number of parallel threads to use for saving (default: 8). Note: the local system must have at least parallel_threads * DEFAULT_MAX_SHARD_SIZE of free disk space, as each thread maintains a local cache of size DEFAULT_MAX_SHARD_SIZE.
Example Usage:
# Save full model
model.save_in_hf_format("./output_dir")

# Save only LoRA adapters
model.save_in_hf_format(
    "./adapter_weights",
    only_save_adapters=True,
)
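When save_adapters_separately is False, the adapter is folded into the base weights before saving. For a LoRA update this merge is a single addition, W_merged = W + A @ B, as the numpy sketch below checks. `merge_lora` is a hypothetical helper, and any scaling factor the real adapter applies is omitted.

```python
import numpy as np


def merge_lora(W: np.ndarray, A: np.ndarray, B: np.ndarray) -> np.ndarray:
    # Fold the low-rank update into a single dense weight matrix
    return W + A @ B


rng = np.random.default_rng(0)
d_in, d_out, rank = 16, 8, 2
W = rng.normal(size=(d_in, d_out))
A = rng.normal(size=(d_in, rank))
B = rng.normal(size=(rank, d_out))  # nonzero, as after fine-tuning
x = rng.normal(size=(3, d_in))

merged = merge_lora(W, A, B)
# The merged layer gives the same outputs as base weights + adapter
assert np.allclose(x @ merged, x @ W + (x @ A) @ B)
```

Merging yields a plain checkpoint with no extra adapter matmul at inference time; saving separately keeps the adapter small (here rank * (d_in + d_out) = 48 values versus 128 for W) and reusable with the same base model.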
generate
- generate(inputs: str | List[str] | List[int] | np.ndarray | List[np.ndarray] | List[List[int]], max_length: int = 100, stop_token_ids: str | List[int] = 'auto', strip_prompt: bool = False, tokenizer: AutoTokenizer | None = None, tokenizer_handle: str | None = None, return_decoded: bool = True, skip_special_tokens: bool = True, **kwargs) → List[str] | Dict[str, np.ndarray]
Generate text tokens using the model.
- Parameters:
inputs – Inputs can be either strings or integer tokens. String inputs can be a single string or a list of strings. Token inputs can be a numpy array, a list of numpy arrays, a list of integers, or a list of integer lists. If strings are provided, either tokenizer or tokenizer_handle must be provided.
max_length – Maximum total sequence length (prompt + generated tokens). If both tokenizer and tokenizer_handle are None, inputs should already be padded to the desired maximum length, and this argument is ignored. When inputs are strings, this value must be provided. (default: 100)
stop_token_ids – List of token IDs that stop generation. Defaults to “auto”, which extracts the end token id from the tokenizer.
strip_prompt – If True, returns only the generated tokens without the input prompt. If False, returns the full sequence including the prompt. (default: False)
tokenizer – Optional AutoTokenizer instance.
tokenizer_handle – Optional HuggingFace tokenizer identifier string. E.g. “google/gemma-2-2b”.
return_decoded – If True, returns the decoded text using the tokenizer; otherwise, returns the predicted tokens. This option must be set to False if neither tokenizer nor tokenizer_handle is provided. (default: True)
skip_special_tokens – Whether to remove special tokens from the decoded text. Only used when return_decoded is True. (default: True)
- Returns:
A list of strings if return_decoded is True. If return_decoded is False, a dictionary containing ‘token_ids’ (generated token IDs, [B, S]) and ‘padding_mask’ (attention mask, [B, S]).
Example Usage:
# Return tokens
prompt = "what is your name?"
pred_tokens = model.generate(
    prompt,
    max_length=100,
    tokenizer_handle="hf://google/gemma-2-2b",
    return_decoded=False,
)
print(pred_tokens)

# Return text
pred_text = model.generate(
    prompt,
    max_length=100,
    tokenizer_handle="hf://google/gemma-2-2b",
    return_decoded=True,
    strip_prompt=True,
)
print(pred_text)

# Use an initialized tokenizer
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
pred_text = model.generate(
    prompt,
    max_length=100,
    tokenizer=tokenizer,
    return_decoded=True,
    strip_prompt=True,
)
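With return_decoded=False, the result is a pair of [B, S] arrays. A minimal sketch for turning that dictionary into one token list per example, assuming padding_mask is True (or 1) at real tokens and False (or 0) at padding. That matches the usual KerasHub convention but is not stated in this document. `unpad_outputs` is hypothetical.

```python
from typing import Dict, List

import numpy as np


def unpad_outputs(outputs: Dict[str, np.ndarray]) -> List[List[int]]:
    token_ids = np.asarray(outputs["token_ids"])
    mask = np.asarray(outputs["padding_mask"]).astype(bool)
    # Boolean-index each row to keep only positions marked as real tokens
    return [row[row_mask].tolist() for row, row_mask in zip(token_ids, mask)]


outputs = {
    "token_ids": np.array([[5, 6, 0], [7, 0, 0]]),
    "padding_mask": np.array([[1, 1, 0], [1, 0, 0]]),
}
print(unpad_outputs(outputs))  # [[5, 6], [7]]
```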