# Training

This document provides a detailed guide for training the policy model for robot manipulation tasks. It covers the overall **[training workflow](#overview-of-the-training-process)**, **[model and dataset configuration](#trainer-configuration)**, and **[important hyperparameters](#important-hyperparameters)**. The training pipeline is modular and designed to support various policies, datasets, and fine-tuning options.

## Overview of the Training Process

The entire training process includes the following steps:

1. **Model Initialization**: Load pretrained models and configure trainable parameters
2. **Dataset Loading**: Configure dataset paths, transforms, and data loaders
3. **Training Configuration**: Set up training arguments, optimizers, and schedulers
4. **Training Execution**: Run training with checkpointing and logging
5. **Evaluation**: Optional evaluation during or after training

## Trainer Configuration

### Model Configuration

We provide multiple policy models through a unified interface in `scripts/train/train.py`:

- `pi0`
- `gr00t_n1`
- `gr00t_n1_5`
- `dp_clip`
- `act_clip`

Some models support granular control over which components to fine-tune:

```python
model = model_cls.from_pretrained(
    pretrained_model_name_or_path="path/to/pretrained/model",
    tune_llm=True,               # backbone's LLM
    tune_visual=True,            # backbone's vision tower
    tune_projector=True,         # action head's projector
    tune_diffusion_model=True,   # action head's DiT
)
```

If you want to add your own model, refer to [this document](./model.md).

### Dataset Loading

The following code snippet demonstrates how to load a dataset using the `LeRobotSingleDataset` class, which is designed to work with the LeRobot dataset format:

```python
from grmanipulation.dataset.base import LeRobotSingleDataset
from grmanipulation.dataset.embodiment_tags import EmbodimentTag

embodiment_tag = EmbodimentTag(config.embodiment_tag)
train_dataset = LeRobotSingleDataset(
    dataset_path=config.dataset_path,
    embodiment_tags=embodiment_tag,
    modality_configs=modality_configs,
    transform=transforms,
)
```

### Important Hyperparameters

In `scripts/train/train.py`, the training scheduler is now configurable through a YAML file using `TrainingArguments` from the 🤗 `transformers` library. This makes it easier to manage, share, and reproduce training configurations.
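As a rough illustration of how such a YAML config can be mapped onto `TrainingArguments`, the sketch below loads a config file and forwards the fields documented in this section; the exact schema and wiring in `scripts/train/train.py` may differ.

```python
# Illustrative sketch only: the YAML keys are assumed to match the
# hyperparameter names documented below, not the repo's exact schema.
import yaml
from transformers import TrainingArguments

with open("configs/train/pi0_genmanip.yaml") as f:
    cfg = yaml.safe_load(f)

training_args = TrainingArguments(
    output_dir=cfg["output_dir"],
    per_device_train_batch_size=cfg["batch_size"],
    gradient_accumulation_steps=cfg["gradient_accumulation_steps"],
    max_steps=cfg["max_steps"],
    save_steps=cfg["save_steps"],
    learning_rate=cfg["learning_rate"],
    weight_decay=cfg["weight_decay"],
    warmup_ratio=cfg["warmup_ratio"],
    lr_scheduler_type="cosine",        # cosine annealing with warm-up (see below)
    dataloader_num_workers=cfg["dataloader_num_workers"],
    report_to=cfg["report_to"],
)
```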
Example usage:

```bash
python scripts/train/train.py --config configs/train/pi0_genmanip.yaml
```

#### Policy and Dataset

```python
policy = ""                     # Options: pi0, gr00t_n1, gr00t_n1_5, dp_clip, pi0fast, act_clip
dataset_path = "genmanip-demo"  # Path to the training dataset
data_config = "genmanip-v1"     # Data configuration name from DATA_CONFIG_MAP
output_dir = ""                 # Directory to save model checkpoints
```

#### Training Parameters

```python
batch_size = 16                  # Batch size per GPU
gradient_accumulation_steps = 1  # Gradient accumulation steps
max_steps = 10000                # Maximum training steps
save_steps = 500                 # Save a checkpoint every 500 steps
num_gpus = 1                     # Number of GPUs for training
resume_from_checkpoint = False   # Resume from a checkpoint if available
```

#### Learning Rate & Optimizer

Use cosine annealing with warm-up:

```python
learning_rate = 1e-4   # Learning rate
weight_decay = 1e-5    # Weight decay for AdamW
warmup_ratio = 0.05    # Warm-up ratio of total steps
```

#### Model Fine-tuning

```python
base_model_path = ""          # Path or HuggingFace model ID of the base model
tune_llm = False              # Fine-tune the language model backbone
tune_visual = True            # Fine-tune the vision tower
tune_projector = True         # Fine-tune the projector
tune_diffusion_model = True   # Fine-tune the diffusion model
use_pretrained_model = False  # Whether to start from a pretrained model
```

#### LoRA Configuration

```python
lora_rank = 0       # LoRA rank
lora_alpha = 16     # LoRA alpha (scaling) value
lora_dropout = 0.1  # LoRA dropout rate
```

#### Data Loading

```python
embodiment_tag = "gr1"        # Embodiment tag (e.g., gr1, new_embodiment)
video_backend = "torchcodec"  # Video backend: decord, torchvision_av, opencv, or torchcodec
dataloader_num_workers = 8    # Number of workers for data loading
```

> ⚠️ Note: The default `torchcodec` works for most video data. `decord` supports H.264 videos but cannot handle the AV1 format. When processing AV1 videos, `torchvision_av` may cause communication deadlocks on multi-node setups. [See more about video formats](https://github.com/huggingface/lerobot/tree/main/benchmarks/video).

#### Miscellaneous

```python
augsteps = 4         # Number of augmentation steps
report_to = "wandb"  # Logging backend: wandb or tensorboard
```

> ⚠️ Note: You need to log in to your own Weights & Biases (wandb) account.
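As a quick sanity check on how the main knobs above interact, here is plain arithmetic on the default values listed in this section (variable names mirror the config fields; this is not framework code):

```python
# Back-of-the-envelope arithmetic on the defaults listed above.
batch_size = 16                  # per-GPU batch size
gradient_accumulation_steps = 1
num_gpus = 1
max_steps = 10_000
warmup_ratio = 0.05
save_steps = 500

effective_batch_size = batch_size * gradient_accumulation_steps * num_gpus  # 16
warmup_steps = int(warmup_ratio * max_steps)     # 500 optimizer steps of LR warm-up
num_checkpoints = max_steps // save_steps        # 20 checkpoints written over the run
samples_seen = effective_batch_size * max_steps  # 160,000 samples processed in total
```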