Training#

This document provides a detailed guide for training the policy model for robot manipulation tasks. It covers the overall training workflow, model and dataset configuration, and important hyperparameters. The training pipeline is modular and designed to support various policies, datasets, and fine-tuning options.

Overview of the Training Process#

The entire training process includes the following steps:

  1. Model Initialization: Load pretrained models and configure trainable parameters

  2. Dataset Loading: Configure dataset paths, transforms, and data loaders

  3. Training Configuration: Set up training arguments, optimizers, and schedulers

  4. Training Execution: Run training with checkpointing and logging

  5. Evaluation: Optional evaluation during or after training

Trainer Configuration#

Model Configuration#

We provide multiple policy models through a unified interface in scripts/train/train.py:

  • pi0

  • gr00t_n1

  • gr00t_n1_5

  • dp_clip

  • act_clip

Some models support granular control over which components to fine-tune:

# model_cls is the policy class selected from the `policy` name in scripts/train/train.py
model = model_cls.from_pretrained(
    pretrained_model_name_or_path="path/to/pretrained/model",
    tune_llm=True,  # fine-tune the backbone's LLM
    tune_visual=True,  # fine-tune the backbone's vision tower
    tune_projector=True,  # fine-tune the action head's projector
    tune_diffusion_model=True,  # fine-tune the action head's DiT
)
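For example, the fine-tuning defaults used later in this guide freeze the language backbone while leaving the vision tower and action head trainable:

model = model_cls.from_pretrained(
    pretrained_model_name_or_path="path/to/pretrained/model",
    tune_llm=False,             # keep the LLM backbone frozen
    tune_visual=True,           # fine-tune the vision tower
    tune_projector=True,        # fine-tune the action head's projector
    tune_diffusion_model=True,  # fine-tune the action head's DiT
)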

If you want to add your own model, refer to this document.

Dataset Loading#

The following code snippet demonstrates how to load a dataset using the LeRobotSingleDataset class, which is designed to work with the LeRobot dataset format:

from grmanipulation.dataset.base import LeRobotSingleDataset
from grmanipulation.dataset.embodiment_tags import EmbodimentTag

embodiment_tag = EmbodimentTag(config.embodiment_tag)

train_dataset = LeRobotSingleDataset(
    dataset_path=config.dataset_path,
    embodiment_tag=embodiment_tag,
    modality_configs=modality_configs,
    transform=transforms,
)
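The modality_configs and transforms passed above typically come from the dataset's data config. A minimal sketch, assuming the DATA_CONFIG_MAP registry referenced in the hyperparameters below exposes them (the import path and method names here are assumptions):

from grmanipulation.dataset.data_config import DATA_CONFIG_MAP  # assumed import path

# Illustrative only: look up the data config by name and derive the
# modality configs and transforms from it.
data_config = DATA_CONFIG_MAP["genmanip-v1"]
modality_configs = data_config.modality_config()
transforms = data_config.transform()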

Important Hyperparameters#

In scripts/train/train.py, training is configured through a YAML file based on TrainingArguments from the 🤗 transformers library.

This makes it easier to manage, share, and reproduce training configurations. Example usage:

python scripts/train/train.py --config configs/train/pi0_genmanip.yaml
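The YAML file collects the hyperparameters documented in the sections below. A minimal sketch, assuming the field names match the parameter names used by scripts/train/train.py (the exact schema is defined there):

policy: pi0                  # Policy model to train
dataset_path: genmanip-demo  # Dataset to train on
data_config: genmanip-v1     # Data configuration name from DATA_CONFIG_MAP
output_dir: ./checkpoints    # Where checkpoints are written
batch_size: 16
max_steps: 10000
learning_rate: 1e-4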


Policy and Dataset#

policy = ""                     # Options: pi0, gr00t_n1, gr00t_n1_5, dp_clip, pi0fast, act_clip
dataset_path = "genmanip-demo"  # Path to the dataset
data_config = "genmanip-v1"     # Data configuration name from DATA_CONFIG_MAP
output_dir = ""                 # Directory to save model checkpoints

Training Parameters#

batch_size = 16                  # Batch size per GPU
gradient_accumulation_steps = 1  # Gradient accumulation steps
max_steps = 10000                # Maximum training steps
save_steps = 500                 # Save checkpoints every 500 steps
num_gpus = 1                     # Number of GPUs for training
resume_from_checkpoint = False   # Resume from a checkpoint if available
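With the defaults above, the effective batch size is batch_size × gradient_accumulation_steps × num_gpus = 16 × 1 × 1 = 16 samples per optimizer step; increase gradient_accumulation_steps when GPU memory limits the per-GPU batch size.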

Learning Rate & Optimizer#

Use cosine annealing with warm-up:

learning_rate = 1e-4     # Learning rate
weight_decay = 1e-5      # Weight decay for AdamW
warmup_ratio = 0.05      # Warm-up ratio for total steps
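With max_steps = 10000 and warmup_ratio = 0.05, the learning rate warms up linearly for the first 500 steps and then decays along a cosine curve. A minimal sketch of an equivalent schedule using the standard 🤗 transformers helper (illustrative, not the repository's exact wiring):

import torch
from transformers import get_cosine_schedule_with_warmup

# Illustrative only: AdamW plus a cosine schedule matching the
# hyperparameters above (model is the policy loaded earlier).
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
warmup_steps = int(0.05 * 10000)  # warmup_ratio * max_steps = 500
scheduler = get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps=warmup_steps, num_training_steps=10000
)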

Model Fine-tuning#

base_model_path = ""          # Path or HuggingFace model ID for base model
tune_llm = False              # Fine-tune language model backbone
tune_visual = True            # Fine-tune vision tower
tune_projector = True         # Fine-tune projector
tune_diffusion_model = True   # Fine-tune diffusion model
use_pretrained_model = False  # Whether to initialize from a pretrained model
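These flags correspond one-to-one to the from_pretrained arguments shown in Model Configuration above, so the YAML file controls the same fine-tuning granularity.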

LoRA Configuration#

lora_rank = 0         # Rank of the LoRA adapter (0 disables LoRA)
lora_alpha = 16       # LoRA scaling factor
lora_dropout = 0.1    # Dropout rate for the LoRA layers
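When lora_rank is set above zero, these values play the standard LoRA roles. A minimal sketch of the equivalent adapter configuration, assuming a PEFT-style implementation (an assumption; the repository may wire LoRA differently):

from peft import LoraConfig

# Illustrative only; assumes a PEFT-style LoRA implementation.
lora_config = LoraConfig(
    r=32,              # lora_rank (0 would disable LoRA)
    lora_alpha=16,     # scaling factor; effective scale is lora_alpha / r
    lora_dropout=0.1,  # dropout applied inside the LoRA layers
)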

Data Loading#

embodiment_tag = "gr1"      # Embodiment tag (e.g., gr1, new_embodiment)
video_backend = "torchcodec"    # Video backend: decord, torchvision_av, opencv, or torchcodec
dataloader_num_workers = 8  # Number of workers for data loading

⚠️ Note: The default torchcodec backend works for most video data. decord supports H.264 videos but cannot handle the AV1 format. When processing AV1 videos, torchvision_av may cause communication deadlocks in multi-node setups. See the documentation on supported video standards for more details.

Miscellaneous#

augsteps = 4         # Number of augmentation steps
report_to = "wandb"  # Logging backend: wandb or tensorboard

⚠️ Note: You need to log in to your own Weights & Biases (wandb) account.
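If you have not logged in yet, run this once on the training machine and paste your API key when prompted:

wandb login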