Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ load_full_state_path: ""
# If enable_checkpointing is true, an asynchronous checkpointer will be used if
# async_checkpointing is true, else a synchronous one is used. If you have
# problems with the checkpointer we recommend trying the synchronous one.
enable_checkpointing: true
enable_checkpointing: false
save_checkpoint_on_completion: true
async_checkpointing: true
checkpoint_period: 10_000
Expand Down Expand Up @@ -839,9 +839,7 @@ tpu_num_sparse_cores_to_trace: 2
# - upload xplane profiling, if it is enabled.
# - upload training metrics, at the defined log_period interval.
managed_mldiagnostics: false # Whether to enable the managed diagnostics
managed_mldiagnostics_on_demand_profiling: true # Enable on-demand profiling server by default
managed_mldiagnostics_run_group: "" # Optional. Used to group multiple runs.
managed_mldiagnostics_region: "" # Optional. GCP region for managed mldiagnostics. If empty, it will be auto-detected by the SDK.

# Dump HLO and jaxpr options
dump_hlo: false
Expand Down Expand Up @@ -1124,12 +1122,14 @@ remat_policy_for_vit: "minimal" # Remat policy for multimodal model's vision en
image_size_for_vit: 896 # Default for Gemma3, and should be overwritten by model's config
image_path: "" # Local image path used for decoding, can be multiple paths separated by comma, exp "/path/image1.jpg,/path/image2.jpg"
video_path: "" # Local video path used for decoding, can be multiple paths separated by comma, exp "/path/video1.mp4,/path/video2.mp4"
video_directory: "" # Local video directory used for SFT training, e.g. "/mounted/LLaVA-Video-178K"
audio_path: "" # Local audio path used for decoding, can be multiple paths separated by comma, exp "/path/audio1.wav,/path/audio2.wav"
image_placeholder: "<|image|>"
video_placeholder: "<|video|>"
audio_placeholder: "<|audio|>"
use_audio_in_video: false
posemb_type_for_vit: "learn"
filter_sft_sequences_by_length: false
# max_num_images_per_example only applies for training when your image column is a list of images.
# -1 means no limit, and will pad to the max possible number of images determined by sequence length.
# Set it to avoid unnecessary padding if you know the maximum number of images per example.
Expand Down
38 changes: 38 additions & 0 deletions src/maxtext/configs/post_train/sft-vision-llava-video-178k.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

base_config: "base.yml"

use_sft: true
use_tunix_gradient_accumulation: true
use_multimodal: true
sft_train_on_completion_only: true
packing: false # packing is not supported yet
freeze_vision_encoder_params: true
learning_rate: 2.e-5

# -------------- Model --------------
model_name: "qwen3-omni-30b-a3b"
tokenizer_path: "Qwen/Qwen3-Omni-30B-A3B-Instruct"

# -------------- HF pipeline --------------
dataset_type: "hf"
hf_path: "parquet"
hf_train_files: "gs://hengtaoguo-maxtext-logs/datasets/LLaVA-Video-178K/0_30_s_academic_v0_1/*.parquet"
train_split: "train"
train_data_columns: ["query", "label"]
train_image_column: "video"

# Local SSD path for videos on the TPU VM
video_directory: "/mounted/LLaVA-Video-178K/0_30_s_academic_v0_1"
5 changes: 5 additions & 0 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1872,13 +1872,18 @@ class MultimodalGeneral(BaseModel):
description="Maximum number of images per example for training with image lists. -1 means no limit.",
)
video_path: PathStr = Field("", description="Path to a video for decoding.")
video_directory: PathStr = Field("", description="Local directory path containing video files for SFT.")
audio_path: PathStr = Field("", description="Path to an audio file for decoding.")
video_placeholder: str = Field("<|video|>", description="Placeholder string for video in text prompts.")
audio_placeholder: str = Field("<|audio|>", description="Placeholder string for audio in text prompts.")
use_audio_in_video: bool = Field(False, description="Extract and use audio from video files.")
use_mrope: bool = Field(False, description="Enable Multi-dimensional RoPE for Qwen3-Omni models.")
mrope_section: list[int] = Field([24, 20, 20], description="Dimensions for temporal, height, width in MRoPE.")
position_id_per_seconds: int = Field(25, description="Temporal granularity for MRoPE (tokens per second).")
filter_sft_sequences_by_length: bool = Field(
False,
description="Filter out multimodal SFT sequences that exceed max_prefill_predict_length or max_target_length.",
)


class VisionTower(BaseModel):
Expand Down
31 changes: 31 additions & 0 deletions src/maxtext/input_pipeline/hf_data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,25 @@ def vision_sft_preprocessing_pipeline(
"""pipeline for multimodal SFT with HF dataset"""

assert len(text_columns) == 2, f"Need two text_columns for query and response, received {text_columns=}"

# Format conversations if columns are missing
features_keys = list(dataset.features.keys()) if dataset.features else []
if "conversations" in features_keys and not all(col in features_keys for col in text_columns):
def format_llava_video_dataset(example):
conversations = example["conversations"]
query = ""
label = ""
for turn in conversations:
if turn["from"] == "human" and not query:
query = turn["value"]
elif turn["from"] == "gpt" and not label:
label = turn["value"]
example[text_columns[0]] = query
example[text_columns[1]] = label
return example

dataset = dataset.map(format_llava_video_dataset)

# Tunix GA requires per-micro-batch slicing at the data level,
# whereas Native GA processes the full batch and splits it internally.
if config.elastic_enabled:
Expand Down Expand Up @@ -137,6 +156,18 @@ def vision_sft_preprocessing_pipeline(
fn_kwargs={"column_name": text_columns[0], "config": config},
)

# Filter out sequences exceeding max_prefill_predict_length or max_target_length
if getattr(config, "filter_sft_sequences_by_length", False):
max_prefill = getattr(config, "max_prefill_predict_length", 8192)
max_target = getattr(config, "max_target_length", 8192 + 512)

def filter_by_length(example):
prefill_len = len(example[text_columns[0]])
response_len = len(example[text_columns[1]])
return (prefill_len <= max_prefill) and (prefill_len + response_len <= max_target)

dataset = dataset.filter(filter_by_length)

dataset = input_pipeline_utils.HFDataSource(
dataset=dataset,
dataloading_host_index=dataloading_host_index,
Expand Down
36 changes: 26 additions & 10 deletions src/maxtext/input_pipeline/input_pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,17 +92,25 @@ def _process_string(string_tensor):

def reformat_prompt(example, column, image_placeholder, model_name):
"""reformat prompt for multimodal SFT"""
if isinstance(example["images"], list):
num_images = len(example["images"])
if isinstance(example["images"], str):
example[column] = mm_processor.reformat_prompt(
example[column], image_placeholder, model_name, num_images=0, video_placeholder=image_placeholder, num_videos=1
)
else:
num_images = 1
example[column] = mm_processor.reformat_prompt(example[column], image_placeholder, model_name, num_images)
if isinstance(example["images"], list):
num_images = len(example["images"])
else:
num_images = 1
example[column] = mm_processor.reformat_prompt(example[column], image_placeholder, model_name, num_images)
return example


def reformat_response(example, column, model_name):
"""reformat response for multimodal SFT"""
example[column] = mm_processor.reformat_response(example[column][0], model_name)
val = example[column]
if isinstance(val, (list, tuple)) and len(val) > 0:
val = val[0]
example[column] = mm_processor.reformat_response(val, model_name)
return example


Expand All @@ -120,9 +128,17 @@ def merge_image_columns(example, image_columns, max_num_images_per_example):


def pre_process_image_sft(example, image_column, config):
"""pre-process image for multimodal SFT"""
"""pre-process image or video for multimodal SFT"""

def _process_image_fn(image):
if isinstance(image, str):
import os

video_directory = getattr(config, "video_directory", "")
if video_directory:
image = os.path.join(video_directory, image)
return mm_processor.preprocess_image_for_training(image, config)

if isinstance(image, list):
image = [np.array(mm_utils.convert_to_RGB(img)) for img in image]
else:
Expand All @@ -131,7 +147,7 @@ def _process_image_fn(image):
image = mm_processor.preprocess_image_for_training(image, config)
return image

example[image_column] = _process_image_fn(example[image_column])
example[image_column] = _process_image_fn(example[image_column]) if example.get(image_column) is not None else None
return example


Expand Down Expand Up @@ -702,12 +718,12 @@ def _pad_image_and_mask(self, preprocessed_image: mm_utils.PreprocessorOutput) -
if not isinstance(preprocessed_image, mm_utils.PreprocessorOutput):
raise TypeError(f"Input must be multimodal_utils.PreprocessorOutput, but got {type(preprocessed_image)}")

if preprocessed_image.pixel_values is None:
raise ValueError("Input preprocessed_image must have pixel_values to pad images.")

if self.config.model_name and self.config.model_name.startswith("qwen3-omni"):
return preprocessed_image

if preprocessed_image.pixel_values is None:
raise ValueError("Input preprocessed_image must have pixel_values to pad images.")

# Determine the maximum number of images/masks allowed.
image_offsets = mm_processor.get_image_offsets(self.config, preprocessed_image)
single_image_offset = image_offsets // preprocessed_image.pixel_values.shape[0]
Expand Down
8 changes: 6 additions & 2 deletions src/maxtext/multimodal/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,13 @@ def preprocess_image_for_training(image, config):

return preprocess_mm_data_llama4(image)
elif config.model_name in ["qwen3-omni-30b-a3b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]:
from maxtext.multimodal.processor_qwen3_omni import preprocess_mm_data_qwen3_omni_for_training # pylint: disable=import-outside-toplevel
from maxtext.multimodal.processor_qwen3_omni import preprocess_mm_data_qwen3_omni_for_training, preprocess_mm_data_qwen3_omni_for_training_video # pylint: disable=import-outside-toplevel

return preprocess_mm_data_qwen3_omni_for_training(image, config)
if isinstance(image, str):
use_audio_in_video = getattr(config, "use_audio_in_video", False)
return preprocess_mm_data_qwen3_omni_for_training_video(image, config)
else:
return preprocess_mm_data_qwen3_omni_for_training(image, config)
else:
raise ValueError(f"Model {config.model_name} not supported for image preprocessing.")

Expand Down
Loading
Loading