GRPO With Replay Buffer
This experimental trainer trains a model with GRPO, but replaces groups (and their corresponding completions) whose rewards have zero standard deviation with groups that have high rewards and non-zero standard deviation and that were already used to train the model in prior batches.
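To see why zero-standard-deviation groups are worth replacing, the following minimal sketch computes group-relative advantages in plain Python. It is an illustration under stated assumptions, not the trainer's code: the helper names `group_rewards` and `group_advantages`, the `eps` value, and the use of the population standard deviation are all choices made for this example.

```python
import statistics


def group_rewards(rewards, num_generations):
    """Split a flat reward list into consecutive groups of `num_generations`."""
    return [rewards[i:i + num_generations] for i in range(0, len(rewards), num_generations)]


def group_advantages(group, eps=1e-4):
    """Group-relative advantages: (reward - group mean) / (group std + eps).

    `eps` and the population std are assumptions made for this sketch.
    """
    mean = statistics.fmean(group)
    std = statistics.pstdev(group)
    return [(r - mean) / (std + eps) for r in group]


# Two prompts, four completions each. The first group has identical rewards.
rewards = [0.0, 0.0, 0.0, 0.0, 0.1, 0.9, 0.4, 0.6]
groups = group_rewards(rewards, num_generations=4)

# Flag groups whose rewards do not vary at all.
zero_std = [statistics.pstdev(g) == 0.0 for g in groups]
print(zero_std)

# Every advantage in a zero-std group is 0, so the group contributes no learning signal.
print(group_advantages(groups[0]))
print(group_advantages(groups[1]))
```

When every completion in a group receives the same reward, all advantages in that group are zero, which is the situation the replay buffer is meant to address.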
Usage#
import torch
from trl.experimental.grpo_with_replay_buffer import GRPOWithReplayBufferConfig, GRPOWithReplayBufferTrainer
from datasets import load_dataset
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
# Guarantee that some rewards have 0 std
def custom_reward_func(completions, **kwargs):
    if torch.rand(1).item() < 0.25:
        return [0] * len(completions)  # simulate identical rewards (0 std)
    else:
        return torch.rand(len(completions)).tolist()
training_args = GRPOWithReplayBufferConfig(
    output_dir="./tmp",
    learning_rate=1e-4,
    per_device_train_batch_size=4,
    num_generations=4,
    max_completion_length=8,
    replay_buffer_size=8,
    report_to="none",
)
trainer = GRPOWithReplayBufferTrainer(
    model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
    reward_funcs=[custom_reward_func],
    args=training_args,
    train_dataset=dataset,
)
# Snapshot of the parameters before training, e.g. to compare against afterwards
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
trainer.train()
GRPOWithReplayBufferTrainer[[trl.experimental.grpo_with_replay_buffer.GRPOWithReplayBufferTrainer]]#
trl.experimental.grpo_with_replay_buffer.GRPOWithReplayBufferTrainer[[trl.experimental.grpo_with_replay_buffer.GRPOWithReplayBufferTrainer]]#
train[[trl.experimental.grpo_with_replay_buffer.GRPOWithReplayBufferTrainer.train]]#
train(resume_from_checkpoint: str | bool | None = None, trial: optuna.Trial | dict[str, Any] | None = None, ignore_keys_for_eval: list[str] | None = None)
Main training entry point.
Parameters:
resume_from_checkpoint (str or bool, optional) : If a str, local path to a saved checkpoint as saved by a previous instance of Trainer. If a bool and equals True, load the last checkpoint in args.output_dir as saved by a previous instance of Trainer. If present, training will resume from the model/optimizer/scheduler states loaded here.
trial (optuna.Trial or dict[str, Any], optional) : The trial run or the hyperparameter dictionary for hyperparameter search.
ignore_keys_for_eval (list[str], optional) : A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions for evaluation during the training.
Returns:
trainer_utils.TrainOutput
Object containing the global step count, training loss, and metrics.
save_model[[trl.experimental.grpo_with_replay_buffer.GRPOWithReplayBufferTrainer.save_model]]#
Saves the model so that you can reload it using from_pretrained().
Only the main process saves.
push_to_hub[[trl.experimental.grpo_with_replay_buffer.GRPOWithReplayBufferTrainer.push_to_hub]]#
Upload self.model and self.processing_class to the 🤗 model hub on the repo self.args.hub_model_id.
Parameters:
commit_message (str, optional, defaults to "End of training") : Message to commit while pushing.
blocking (bool, optional, defaults to True) : Whether the function should return only when the git push has finished.
token (str, optional, defaults to None) : Token with write permission to overwrite Trainer’s original args.
revision (str, optional) : The git revision to commit from. Defaults to the head of the “main” branch.
kwargs (dict[str, Any], optional) : Additional keyword arguments passed along to Trainer.create_model_card.
Returns:
The URL of the repository where the model was pushed if blocking=True, or a Future object tracking the
progress of the commit if blocking=False.
GRPOWithReplayBufferConfig[[trl.experimental.grpo_with_replay_buffer.GRPOWithReplayBufferConfig]]#
trl.experimental.grpo_with_replay_buffer.GRPOWithReplayBufferConfig[[trl.experimental.grpo_with_replay_buffer.GRPOWithReplayBufferConfig]]#
New Parameters:
replay_buffer_size (int, optional, defaults to 0):
Size of the replay buffer, a cache that stores the rollouts with the highest advantage scores and variance
per group. If a new group has 0 variance, it is replaced with a group sampled from the replay buffer.
ReplayBuffer[[trl.experimental.grpo_with_replay_buffer.ReplayBuffer]]#
trl.experimental.grpo_with_replay_buffer.ReplayBuffer[[trl.experimental.grpo_with_replay_buffer.ReplayBuffer]]#
A simple replay buffer to store and sample previously seen rollouts.
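The idea of a bounded buffer that keeps high-scoring rollouts and hands them back when a fresh group is degenerate can be sketched as follows. This is a toy illustration, not the library's `ReplayBuffer`: the class `ToyReplayBuffer`, its `add` and `sample` methods, the helper `replace_zero_std_groups`, the dictionary layout of a group, uniform sampling, and the caller-supplied score are all hypothetical choices made for this example.

```python
import heapq
import itertools
import random


class ToyReplayBuffer:
    """Bounded buffer keeping the highest-scoring groups (hypothetical sketch)."""

    def __init__(self, max_size):
        self.max_size = max_size
        self._heap = []  # min-heap of (score, insertion index, group)
        self._counter = itertools.count()  # tie-breaker so groups are never compared

    def __len__(self):
        return len(self._heap)

    def add(self, score, group):
        """Store `group`, evicting the lowest-scoring entry once the buffer is full."""
        if self.max_size <= 0:
            return
        entry = (score, next(self._counter), group)
        if len(self._heap) < self.max_size:
            heapq.heappush(self._heap, entry)
        elif score > self._heap[0][0]:
            heapq.heapreplace(self._heap, entry)

    def sample(self, rng=random):
        """Return one stored group chosen uniformly at random, or None if empty."""
        if not self._heap:
            return None
        return rng.choice(self._heap)[2]


def replace_zero_std_groups(groups, buffer, rng=random):
    """Swap groups whose rewards are all identical for groups drawn from the buffer."""
    result = []
    for group in groups:
        degenerate = len(set(group["rewards"])) == 1
        if degenerate and len(buffer) > 0:
            result.append(buffer.sample(rng))
        else:
            result.append(group)
    return result


buffer = ToyReplayBuffer(max_size=2)
buffer.add(0.3, {"prompt": "a", "rewards": [0.1, 0.9]})
buffer.add(0.1, {"prompt": "b", "rewards": [0.4, 0.5]})
buffer.add(0.5, {"prompt": "c", "rewards": [0.0, 1.0]})  # evicts "b", the lowest score
kept = sorted(group["prompt"] for _, _, group in buffer._heap)

batch = [
    {"prompt": "x", "rewards": [0.0, 0.0]},  # zero variance, will be replaced
    {"prompt": "y", "rewards": [0.2, 0.7]},  # informative, kept as is
]
new_batch = replace_zero_std_groups(batch, buffer, random.Random(0))
```

A min-heap keeps eviction cheap because the lowest-scoring entry is always at the root. How the score is derived from advantages and variance is left to the caller here, since the documentation above does not spell it out.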