SFT Trainer


Editorial Notes

Supervised fine-tuning is the simplest and most common way to adapt a model to your data, and the SFTTrainer is where most TRL users begin. Pay close attention to dataset format. The trainer accepts both language-modeling and prompt-completion shapes and automatically applies the chat template to conversational data, so mismatched formats are the most common source of silent quality loss. Two gotchas are worth remembering: completion-only loss is on by default for prompt-completion datasets, and training adapters with PEFT usually calls for a higher learning rate, around 1e-4. Read the TRL overview first, and pair this page with the PEFT LoRA guide when you train adapters.


Original Documentation


Overview#

TRL supports the Supervised Fine-Tuning (SFT) Trainer for training language models.

This post-training method was contributed by Younes Belkada.

Quick start#

This example demonstrates how to train a language model using the SFTTrainer from TRL. We train a Qwen 3 0.6B model on the Capybara dataset, a compact, diverse multi-turn dataset to benchmark reasoning and generalization.

from trl import SFTTrainer
from datasets import load_dataset

trainer = SFTTrainer(
    model="Qwen/Qwen3-0.6B",
    train_dataset=load_dataset("trl-lib/Capybara", split="train"),
)
trainer.train()

Expected dataset type and format#

SFT supports both language modeling and prompt-completion datasets. The SFTTrainer is compatible with both standard and conversational dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.

# Standard language modeling
{"text": "The sky is blue."}

# Conversational language modeling
{"messages": [{"role": "user", "content": "What color is the sky?"},
              {"role": "assistant", "content": "It is blue."}]}

# Standard prompt-completion
{"prompt": "The sky is",
 "completion": " blue."}

# Conversational prompt-completion
{"prompt": [{"role": "user", "content": "What color is the sky?"}],
 "completion": [{"role": "assistant", "content": "It is blue."}]}
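The four shapes above differ only in two ways: which keys are present, and whether the values are plain strings or lists of role/content messages. The following minimal sketch tells them apart. It assumes exactly the key names shown. describe_format is a hypothetical helper for inspecting your own data, not a TRL function.

```python
from typing import Any


def describe_format(example: dict[str, Any]) -> tuple[str, str]:
    """Classify one example as (type, format).

    type   -> "language_modeling" or "prompt_completion"
    format -> "standard" or "conversational"
    """
    if "messages" in example:
        return "language_modeling", "conversational"
    if "text" in example:
        return "language_modeling", "standard"
    if "prompt" in example and "completion" in example:
        # Conversational examples hold lists of {"role", "content"} dicts instead of strings.
        is_chat = isinstance(example["prompt"], list)
        return "prompt_completion", "conversational" if is_chat else "standard"
    raise ValueError(f"Unrecognized example keys: {sorted(example)}")


print(describe_format({"text": "The sky is blue."}))  # ('language_modeling', 'standard')
print(describe_format({"prompt": "The sky is", "completion": " blue."}))  # ('prompt_completion', 'standard')
```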

If your dataset is not in one of these formats, you can preprocess it to convert it into the expected format. Here is an example with the FreedomIntelligence/medical-o1-reasoning-SFT dataset:

from datasets import load_dataset

dataset = load_dataset("FreedomIntelligence/medical-o1-reasoning-SFT", "en")

def preprocess_function(example):
    return {
        "prompt": [{"role": "user", "content": example["Question"]}],
        "completion": [
            {"role": "assistant", "content": f"<think>{example['Complex_CoT']}</think>{example['Response']}"}
        ],
    }

dataset = dataset.map(preprocess_function, remove_columns=["Question", "Response", "Complex_CoT"])
print(next(iter(dataset["train"])))
{
    "prompt": [
        {
            "content": "Given the symptoms of sudden weakness in the left arm and leg, recent long-distance travel, and the presence of swollen and tender right lower leg, what specific cardiac abnormality is most likely to be found upon further evaluation that could explain these findings?",
            "role": "user",
        }
    ],
    "completion": [
        {
            "content": "<think>Okay, let's see what's going on here. We've got sudden weakness [...] clicks into place!</think>The specific cardiac abnormality most likely to be found in [...] the presence of a PFO facilitating a paradoxical embolism.",
            "role": "assistant",
        }
    ],
}

Looking deeper into the SFT method#

Supervised Fine-Tuning (SFT) is the simplest and most commonly used method to adapt a language model to a target dataset. The model is trained in a fully supervised fashion using pairs of input and output sequences. The goal is to minimize the negative log-likelihood (NLL) of the target sequence, conditioning on the input.

This section breaks down how SFT works in practice, covering the key steps: preprocessing, tokenization and loss computation.

Preprocessing and tokenization#

During training, each example is expected to contain a text field or a (prompt, completion) pair, depending on the dataset format. For more details on the expected formats, see Dataset formats. The SFTTrainer tokenizes each input using the model’s tokenizer. If both prompt and completion are provided separately, they are concatenated before tokenization.
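The following pure-Python sketch shows the concatenation step for a standard prompt-completion example. It uses a whitespace "tokenizer" as a stand-in for the model's real tokenizer. The text is joined first and then tokenized. The prompt is tokenized alone only to find where the completion starts. completion_mask is an illustrative name for the per-token flag that later lets the loss skip prompt tokens, and TRL's internal field names may differ. With real subword tokenizers, the prompt's tokens are not guaranteed to be a prefix of the full sequence's tokens, and this sketch ignores that case.

```python
def toy_tokenize(text: str) -> list[str]:
    # Stand-in for a real tokenizer: splits on whitespace.
    return text.split()


def tokenize_prompt_completion(prompt: str, completion: str) -> dict[str, list]:
    # Concatenate the raw text first, then tokenize the full sequence.
    tokens = toy_tokenize(prompt + completion)
    # Tokenize the prompt alone to find where the completion begins.
    n_prompt = len(toy_tokenize(prompt))
    # 0 marks prompt tokens, 1 marks completion tokens.
    completion_mask = [0] * n_prompt + [1] * (len(tokens) - n_prompt)
    return {"tokens": tokens, "completion_mask": completion_mask}


print(tokenize_prompt_completion("The sky is", " blue."))
# {'tokens': ['The', 'sky', 'is', 'blue.'], 'completion_mask': [0, 0, 0, 1]}
```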

Computing the loss#


The loss used in SFT is the token-level cross-entropy loss, defined as:

$$ \mathcal{L}_{\text{SFT}}(\theta) = - \sum_{t=1}^{T} \log p_\theta(y_t \mid y_{<t}), $$

where \( y_t \) is the target token at timestep \( t \), and the model is trained to predict the next token given the previous ones. In practice, padding tokens are masked out during loss computation.
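The following NumPy sketch computes this loss for one sequence. It assumes the logits are already aligned so that row \( t \) scores label \( t \), and that masked positions carry the label -100. The formula above is a sum, while the sketch averages over the kept tokens, which only rescales the loss. This illustrates the formula and is not TRL's implementation.

```python
import numpy as np

IGNORE_INDEX = -100  # label value excluded from the loss


def sft_nll(logits: np.ndarray, labels: np.ndarray) -> float:
    """Mean of -log p(y_t | y_<t) over positions whose label is not IGNORE_INDEX.

    logits: (T, V) unnormalized scores, aligned so row t predicts labels[t]
    labels: (T,) target token IDs
    """
    rows = np.nonzero(labels != IGNORE_INDEX)[0]
    z = logits[rows] - logits[rows].max(axis=-1, keepdims=True)  # stabilize exp
    log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))  # log-softmax
    return float(-log_probs[np.arange(len(rows)), labels[rows]].mean())


logits = np.zeros((3, 4))  # uniform predictions over a 4-token vocabulary
labels = np.array([1, IGNORE_INDEX, 2])
print(sft_nll(logits, labels))  # log(4), about 1.386
```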

The paper On the Generalization of SFT: A Reinforcement Learning Perspective with Reward Rectification proposes an alternative loss function, called Dynamic Fine-Tuning (DFT), which aims to improve generalization by rectifying the reward signal. This method can be enabled by setting loss_type="dft" in the SFTConfig. For more details, see Paper Index - Dynamic Fine-Tuning.
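As we understand the cited paper, DFT keeps the token-level form of the SFT loss but weights each token's log-likelihood by that token's own probability. Gradients are stopped through the weight, written \( \operatorname{sg} \) below. Here it is in the notation of the SFT loss above, as a sketch of the objective rather than TRL's exact implementation:

$$ \mathcal{L}_{\text{DFT}}(\theta) = - \sum_{t=1}^{T} \operatorname{sg}\big(p_\theta(y_t \mid y_{<t})\big) \log p_\theta(y_t \mid y_{<t}) $$

As a result, tokens the model currently finds unlikely receive a smaller gradient than under plain NLL.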

For a memory-efficient variant of the standard loss, set loss_type="chunked_nll" in the SFTConfig. It computes the same quantity as "nll", with two differences. First, the lm_head projection skips tokens whose labels are ignored. Second, the cross-entropy is processed in chunks, so peak activation memory does not scale with the full vocab × seq_len logits tensor. See Chunked cross-entropy for reducing peak memory usage.
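The following NumPy sketch illustrates the idea; it is not TRL's implementation. Positions with an ignored label are dropped before the lm_head projection. The remaining logits are materialized a few rows at a time, so at most a (chunk_size, V) array exists at once. A reference function that builds the full logits matrix confirms that both give the same value. The shapes, seed and chunk size are arbitrary assumptions.

```python
import numpy as np

IGNORE_INDEX = -100


def _logsumexp_rows(logits: np.ndarray) -> np.ndarray:
    # Numerically stable log(sum(exp(x))) for each row
    m = logits.max(axis=-1, keepdims=True)
    return (m + np.log(np.exp(logits - m).sum(axis=-1, keepdims=True)))[:, 0]


def chunked_nll(hidden: np.ndarray, lm_head: np.ndarray, labels: np.ndarray, chunk_size: int = 2) -> float:
    """Mean cross-entropy computed a few positions at a time.

    hidden:  (T, d) final hidden states, aligned so row t predicts labels[t]
    lm_head: (V, d) output projection weights
    labels:  (T,) target IDs, IGNORE_INDEX where no loss is taken
    """
    rows = np.nonzero(labels != IGNORE_INDEX)[0]  # ignored positions are never projected
    if len(rows) == 0:
        return 0.0
    total = 0.0
    for start in range(0, len(rows), chunk_size):
        idx = rows[start:start + chunk_size]
        logits = hidden[idx] @ lm_head.T  # at most (chunk_size, V) logits in memory
        target = logits[np.arange(len(idx)), labels[idx]]
        total += float((_logsumexp_rows(logits) - target).sum())
    return total / len(rows)


def full_nll(hidden: np.ndarray, lm_head: np.ndarray, labels: np.ndarray) -> float:
    # Reference version that materializes the whole (T, V) logits matrix
    logits = hidden @ lm_head.T
    rows = np.nonzero(labels != IGNORE_INDEX)[0]
    target = logits[rows, labels[rows]]
    return float((_logsumexp_rows(logits[rows]) - target).mean())


rng = np.random.default_rng(0)
hidden = rng.normal(size=(5, 8))
lm_head = rng.normal(size=(11, 8))
labels = np.array([3, IGNORE_INDEX, 7, 0, IGNORE_INDEX])
print(np.isclose(chunked_nll(hidden, lm_head, labels), full_nll(hidden, lm_head, labels)))  # True
```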

Label shifting and masking#

During training, the loss is computed using a one-token shift: the model is trained to predict each token in the sequence from all previous tokens. Specifically, the prediction made at position \( t \) is scored against the token at position \( t+1 \), so the target labels are the input sequence offset by one position. Padding tokens (if present) are excluded from the loss by setting their labels to an ignore index (default: -100). This ensures that the loss focuses only on meaningful, non-padding tokens.
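The following pure-Python sketch shows the bookkeeping on a toy batch row. It assumes 0 is the pad token ID and that the attention mask marks padding with 0. Both are illustrative choices, not TRL settings. Labels start as a copy of the inputs, padding becomes -100, and each remaining target is predicted from the tokens before it.

```python
IGNORE_INDEX = -100


def build_labels(input_ids: list[int], attention_mask: list[int]) -> list[int]:
    # Copy the inputs; positions where attention_mask is 0 (padding) get IGNORE_INDEX.
    return [tok if keep else IGNORE_INDEX for tok, keep in zip(input_ids, attention_mask)]


def training_pairs(input_ids: list[int], labels: list[int]) -> list[tuple[list[int], int]]:
    # The prediction at position t is scored against labels[t + 1]:
    # (context seen so far, target), skipping ignored targets.
    return [
        (input_ids[: t + 1], labels[t + 1])
        for t in range(len(labels) - 1)
        if labels[t + 1] != IGNORE_INDEX
    ]


input_ids = [11, 12, 13, 0, 0]  # 0 is a hypothetical pad token ID
attention_mask = [1, 1, 1, 0, 0]
labels = build_labels(input_ids, attention_mask)
print(labels)                             # [11, 12, 13, -100, -100]
print(training_pairs(input_ids, labels))  # [([11], 12), ([11, 12], 13)]
```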

Logged metrics#

While training and evaluating, we record the following metrics:

  • global_step: The total number of optimizer steps taken so far.
  • epoch: The current epoch number, based on dataset iteration.
  • num_tokens: The total number of tokens processed so far.
  • loss: The average cross-entropy loss computed over non-masked tokens in the current logging interval.
  • entropy: The average entropy of the model’s predicted token distribution over non-masked tokens.
  • mean_token_accuracy: The proportion of non-masked tokens for which the model’s top-1 prediction matches the ground truth token.
  • learning_rate: The current learning rate, which may change dynamically if a scheduler is used.
  • grad_norm: The L2 norm of the gradients, computed before gradient clipping.
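The following NumPy sketch illustrates the entropy and mean_token_accuracy definitions above for a single sequence. It assumes logits already aligned with their targets and -100 at masked positions. Entropy here is in nats. The sketch leaves out any aggregation across batches, devices or logging intervals.

```python
import numpy as np

IGNORE_INDEX = -100


def token_metrics(logits: np.ndarray, labels: np.ndarray) -> dict[str, float]:
    """Average entropy and top-1 accuracy over non-masked positions.

    logits: (T, V), aligned so row t predicts labels[t]
    labels: (T,) with IGNORE_INDEX at masked positions
    """
    rows = np.nonzero(labels != IGNORE_INDEX)[0]
    z = logits[rows] - logits[rows].max(axis=-1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    # Entropy (nats) of each predicted distribution, averaged over kept tokens
    entropy = -(np.exp(log_probs) * log_probs).sum(axis=-1).mean()
    # Fraction of kept tokens whose argmax prediction equals the label
    accuracy = (logits[rows].argmax(axis=-1) == labels[rows]).mean()
    return {"entropy": float(entropy), "mean_token_accuracy": float(accuracy)}


logits = np.array([[5.0, 0.0, 0.0],
                   [0.0, 0.0, 0.0],
                   [0.0, 3.0, 0.0]])
labels = np.array([0, IGNORE_INDEX, 2])
print(token_metrics(logits, labels)["mean_token_accuracy"])  # 0.5: first kept token right, second wrong
```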

Customization#

Model initialization#

You can pass any keyword argument of the from_pretrained() method through the SFTConfig. For example, to load a model in a different precision, analogous to

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", dtype=torch.bfloat16)

you can do so by passing the model_init_kwargs={"dtype": torch.bfloat16} argument to the SFTConfig.

import torch
from trl import SFTConfig

training_args = SFTConfig(
    model_init_kwargs={"dtype": torch.bfloat16},
)

Note that all keyword arguments of from_pretrained() are supported.

Packing#

SFTTrainer supports example packing, where multiple examples are packed in the same input sequence to increase training efficiency. To enable packing, simply pass packing=True to the SFTConfig constructor.

training_args = SFTConfig(packing=True)

For more details on packing, see Packing.
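Packing has to decide which examples share a sequence. One common heuristic is best-fit decreasing. The following simplified pure-Python sketch applies it to example lengths only. It illustrates the heuristic and is not necessarily the strategy or code TRL uses; see Packing for the options it exposes. Whether packed examples can attend to one another depends on the attention implementation, which this sketch does not model.

```python
def pack_best_fit_decreasing(lengths: list[int], max_length: int) -> list[list[int]]:
    """Group example indices into bins whose total length is at most max_length.

    Longest examples are placed first. Each goes into the open bin with the least
    remaining space that can still hold it, or into a new bin if none can.
    """
    order = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)
    bins: list[list[int]] = []
    remaining: list[int] = []
    for i in order:
        n = min(lengths[i], max_length)  # over-long examples are assumed truncated
        candidates = [b for b in range(len(bins)) if remaining[b] >= n]
        if candidates:
            best = min(candidates, key=lambda b: remaining[b])
            bins[best].append(i)
            remaining[best] -= n
        else:
            bins.append([i])
            remaining.append(max_length - n)
    return bins


print(pack_best_fit_decreasing([5, 4, 3, 2, 1], max_length=6))
# [[0, 4], [1, 3], [2]]
```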

Train on assistant messages only#

To train on assistant messages only, use a conversational dataset and set assistant_only_loss=True in the SFTConfig. This setting ensures that loss is computed only on the assistant responses, ignoring user or system messages.

training_args = SFTConfig(assistant_only_loss=True)


This functionality requires the chat template to include {% generation %} and {% endgeneration %} keywords. For known model families (e.g. Qwen3), TRL automatically patches the template when assistant_only_loss=True. See Chat Templates for the full list of bundled training templates. For other models, check that your chat template includes these keywords. See HuggingFaceTB/SmolLM3-3B for an example.
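The following toy sketch shows the mask that these markers produce. Tokens rendered inside the generation block count toward the loss, and everything else is masked. The sketch makes three assumptions:

  • The conversation uses a ChatML-like layout.
  • Text is split on whitespace instead of being run through a real tokenizer.
  • The closing <|im_end|> of an assistant turn counts as trainable, so the model learns where to stop. Whether this holds depends on where a given template places its markers.

render_with_assistant_mask is a hypothetical helper, not a TRL or transformers function.

```python
def render_with_assistant_mask(messages: list[dict[str, str]]) -> tuple[list[str], list[int]]:
    """Render messages into whitespace 'tokens' plus a mask that is 1 only on assistant output."""
    tokens: list[str] = []
    mask: list[int] = []
    for message in messages:
        header = [f"<|im_start|>{message['role']}"]
        body = message["content"].split() + ["<|im_end|>"]
        tokens += header + body
        # Role headers are never trained on; assistant content and its end marker are.
        trainable = 1 if message["role"] == "assistant" else 0
        mask += [0] * len(header) + [trainable] * len(body)
    return tokens, mask


messages = [
    {"role": "user", "content": "What color is the sky?"},
    {"role": "assistant", "content": "It is blue."},
]
tokens, mask = render_with_assistant_mask(messages)
print([tok for tok, m in zip(tokens, mask) if m])  # ['It', 'is', 'blue.', '<|im_end|>']
```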

Train on completion only#

To train on completion only, use a prompt-completion dataset. By default, the trainer computes the loss on the completion tokens only, ignoring the prompt tokens. If you want to train on the full sequence, set completion_only_loss=False in the SFTConfig.

from trl import SFTConfig, SFTTrainer
from datasets import load_dataset

# Load a prompt-completion dataset; loss is computed on the completion only by default
dataset = load_dataset("trl-lib/kto-mix-14k", split="train")

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    args=SFTConfig(completion_only_loss=True),  # True by default for prompt-completion datasets
    train_dataset=dataset,
)
trainer.train()


Training on completion only is compatible with training on assistant messages only. In this case, use a conversational prompt-completion dataset and set assistant_only_loss=True in the SFTConfig.

Train adapters with PEFT#

We support tight integration with the 🤗 PEFT library, allowing any user to conveniently train adapters and share them on the Hub instead of training the entire model.

from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig

dataset = load_dataset("trl-lib/Capybara", split="train")

trainer = SFTTrainer(
    "Qwen/Qwen3-0.6B",
    train_dataset=dataset,
    peft_config=LoraConfig(),
)

trainer.train()

You can also continue training an existing PeftModel. To do so, load the PeftModel outside SFTTrainer and pass it directly to the trainer without passing the peft_config argument.

from datasets import load_dataset
from trl import SFTTrainer
from peft import AutoPeftModelForCausalLM

model = AutoPeftModelForCausalLM.from_pretrained("trl-lib/Qwen3-4B-LoRA", is_trainable=True)
dataset = load_dataset("trl-lib/Capybara", split="train")

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
)

trainer.train()

When training adapters, you typically use a higher learning rate (≈1e‑4) since only new parameters are being learned.

SFTConfig(learning_rate=1e-4, ...)
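The usual LoRA arithmetic makes "only new parameters" concrete. A rank-r adapter on a d_out × d_in weight trains two small matrices with r(d_out + d_in) parameters in total, while the base weight stays frozen. The layer size and rank below are hypothetical, chosen only to make the numbers easy to read.

```python
def lora_param_count(d_out: int, d_in: int, rank: int) -> int:
    """Trainable parameters LoRA adds to one linear layer: B is (d_out x rank), A is (rank x d_in)."""
    return rank * (d_out + d_in)


d_out = d_in = 1024  # hypothetical layer size, chosen for illustration
full = d_out * d_in  # parameters in the frozen base weight
lora = lora_param_count(d_out, d_in, rank=8)
print(full, lora, f"{lora / full:.4%}")  # 1048576 16384 1.5625%
```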

Train with Liger Kernel#

Liger Kernel is a collection of Triton kernels for LLM training that boosts multi-GPU throughput by 20%, cuts memory use by 60% (enabling up to 4× longer context), and works seamlessly with tools like FlashAttention, PyTorch FSDP, and DeepSpeed. For more information, see Liger Kernel Integration.

Rapid Experimentation for SFT#

RapidFire AI is an open-source experimentation engine that sits on top of TRL and lets you launch multiple SFT configurations at once, even on a single GPU. Instead of trying configurations sequentially, RapidFire lets you see all their learning curves earlier, stop underperforming runs, and clone promising ones with new settings in flight without restarting. For more information, see RapidFire AI Integration.

Train with Unsloth#

Unsloth is an open‑source framework for fine‑tuning and reinforcement learning that trains LLMs (like Llama, Mistral, Gemma, DeepSeek, and more) up to 2× faster with up to 70% less VRAM, while providing a streamlined, Hugging Face–compatible workflow for training, evaluation, and deployment. For more information, see Unsloth Integration.

Instruction tuning example#

Instruction tuning teaches a base language model to follow user instructions and engage in conversations. This requires:

  1. Chat template: Defines how to structure conversations into text sequences, including role markers (user/assistant), special tokens, and turn boundaries. Read more about chat templates in Chat templates.
  2. Conversational dataset: Contains instruction-response pairs.

This example shows how to transform the Qwen 3 0.6B Base model into an instruction-following model using the Capybara dataset and a chat template from HuggingFaceTB/SmolLM3-3B. The SFT Trainer automatically handles tokenizer updates and special token configuration.

from trl import SFTConfig, SFTTrainer
from datasets import load_dataset

trainer = SFTTrainer(
    model="Qwen/Qwen3-0.6B-Base",
    args=SFTConfig(
        output_dir="Qwen3-0.6B-Instruct",
        chat_template_path="HuggingFaceTB/SmolLM3-3B",
    ),
    train_dataset=load_dataset("trl-lib/Capybara", split="train"),
)
trainer.train()

Some base models, like those from Qwen, already ship a chat template in the model’s tokenizer. In that case, it is not necessary to apply clone_chat_template(), because the tokenizer already handles the formatting. However, the EOS token must still match the chat template so that the model’s responses terminate correctly. To do so, specify eos_token in SFTConfig; for example, for Qwen/Qwen2.5-1.5B, set eos_token="<|im_end|>".
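The template layout itself shows why <|im_end|> is the right EOS for such models. In ChatML-style formatting, every turn, including the assistant's, is closed by that marker, so it is the token a trained model emits when its reply is done. The following minimal rendering sketch assumes that layout, the same one used in the pipeline prompt below. render_chatml is a hypothetical helper, and real templates may add system prompts or other details.

```python
def render_chatml(messages: list[dict[str, str]], add_generation_prompt: bool = False) -> str:
    """Render messages in a ChatML-style layout (hypothetical helper, not a library function)."""
    # Every turn, whatever its role, is closed by <|im_end|>.
    text = "".join(f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" for m in messages)
    if add_generation_prompt:
        # Open an assistant turn for the model to complete.
        text += "<|im_start|>assistant\n"
    return text


prompt = render_chatml(
    [{"role": "user", "content": "What is the capital of France? Answer in one word."}],
    add_generation_prompt=True,
)
print(repr(prompt))
```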

Once trained, your model can follow instructions and engage in conversations using its new chat template.

>>> from transformers import pipeline
>>> pipe = pipeline("text-generation", model="Qwen3-0.6B-Instruct/checkpoint-5000")
>>> prompt = "<|im_start|>user\nWhat is the capital of France? Answer in one word.<|im_end|>\n<|im_start|>assistant\n"
>>> response = pipe(prompt)
>>> response[0]["generated_text"]
'<|im_start|>user\nWhat is the capital of France? Answer in one word.<|im_end|>\n<|im_start|>assistant\nThe capital of France is Paris.'

Alternatively, use the structured conversation format (recommended):

>>> prompt = [{"role": "user", "content": "What is the capital of France? Answer in one word."}]
>>> response = pipe(prompt)
>>> response[0]["generated_text"]
[{'role': 'user', 'content': 'What is the capital of France? Answer in one word.'}, {'role': 'assistant', 'content': 'The capital of France is Paris.'}]

Tool Calling with SFT#

The SFTTrainer fully supports fine-tuning models with tool calling capabilities. In this case, each dataset example should include:

  • The conversation messages, including any tool calls (tool_calls) and tool responses (tool role messages)
  • The list of available tools in the tools column, typically provided as JSON schemas

For details on the expected dataset structure, see the Dataset Format — Tool Calling section.
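The following is a hypothetical single example in that shape. It assumes the JSON-schema "function" layout that transformers chat templates commonly accept for tools and tool calls. The get_weather tool, its arguments and the replies are invented for illustration. Check the exact column layout TRL expects against the Dataset Format page referenced above. The helper undeclared_tool_calls is also hypothetical: it is a quick sanity check that every tool the assistant calls is listed in tools.

```python
# Hypothetical tool schema and conversation, for illustration only.
get_weather_schema = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get the current temperature for a city.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}

example = {
    "messages": [
        {"role": "user", "content": "How warm is it in Paris?"},
        {"role": "assistant", "tool_calls": [
            {"type": "function", "function": {"name": "get_weather", "arguments": {"city": "Paris"}}}
        ]},
        {"role": "tool", "name": "get_weather", "content": "22°C"},
        {"role": "assistant", "content": "It is currently 22°C in Paris."},
    ],
    "tools": [get_weather_schema],
}


def undeclared_tool_calls(example: dict) -> list[str]:
    """Return names of tools that are called in the messages but not declared in `tools`."""
    declared = {tool["function"]["name"] for tool in example["tools"]}
    called = [
        call["function"]["name"]
        for message in example["messages"]
        for call in message.get("tool_calls", [])
    ]
    return [name for name in called if name not in declared]


print(undeclared_tool_calls(example))  # []
```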

Training Vision Language Models#

SFTTrainer fully supports training Vision-Language Models (VLMs). To train a VLM, provide a dataset with either an image column (single image per sample) or an images column (list of images per sample). For more information on the expected dataset structure, see the Dataset Format — Vision Dataset section. An example of such a dataset is the LLaVA Instruct Mix.

from trl import SFTConfig, SFTTrainer
from datasets import load_dataset

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-VL-3B-Instruct",
    args=SFTConfig(max_length=None),
    train_dataset=load_dataset("trl-lib/llava-instruct-mix", split="train"),
)
trainer.train()

For VLMs, truncating may remove image tokens, leading to errors during training. To avoid this, set max_length=None in the SFTConfig. This allows the model to process the full sequence length without truncating image tokens.

SFTConfig(max_length=None, ...)

Only use max_length when you’ve verified that truncation won’t remove image tokens for the entire dataset.
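The following is a minimal pure-Python check for that verification. It makes three assumptions:

  • The dataset has already been tokenized with the model's processor.
  • The ID of the image placeholder token is known.
  • Truncation keeps the first max_length tokens.

Both helpers and the token ID are hypothetical, not part of TRL.

```python
def truncation_drops_image_tokens(input_ids: list[int], image_token_id: int, max_length: int) -> bool:
    """True if keeping only the first max_length tokens would cut at least one image token."""
    return image_token_id in input_ids[max_length:]


def max_length_is_safe(dataset_input_ids: list[list[int]], image_token_id: int, max_length: int) -> bool:
    """True if no example in the (already tokenized) dataset would lose image tokens."""
    return not any(
        truncation_drops_image_tokens(ids, image_token_id, max_length) for ids in dataset_input_ids
    )


IMAGE = 9999  # hypothetical image-placeholder token ID; each VLM defines its own
sample = [1, 2, IMAGE, IMAGE, IMAGE, 3, 4]
print(truncation_drops_image_tokens(sample, IMAGE, max_length=4))  # True
print(truncation_drops_image_tokens(sample, IMAGE, max_length=5))  # False
```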

SFTTrainer#

trl.SFTTrainer


Trainer for Supervised Fine-Tuning (SFT) method.

This class is a wrapper around the Trainer class and inherits all of its attributes and methods.

Example:

from trl import SFTTrainer
from datasets import load_dataset

dataset = load_dataset("roneneldan/TinyStories", split="train[:1%]")

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    train_dataset=dataset,
)
trainer.train()

Parameters:

model (str or PreTrainedModel or PeftModel) : Model to be trained. Can be either:
  • A string, being the model id of a pretrained model hosted inside a model repo on huggingface.co, or a path to a directory containing model weights saved using save_pretrained, e.g., './my_model_directory/'. The model is loaded using <ModelArchitecture>.from_pretrained (where <ModelArchitecture> is derived from the model config) with the keyword arguments in args.model_init_kwargs.
  • A PreTrainedModel object. Only causal language models are supported.
  • A PeftModel object. Only causal language models are supported.
If you’re training a model with an MoE architecture and want to include the load balancing/auxiliary loss as a part of the final loss, remember to set the output_router_logits config of the model to True.

args (SFTConfig, optional) : Configuration for this trainer. If None, a default configuration is used.

data_collator (DataCollator, optional) : Function to use to form a batch from a list of elements of the processed train_dataset or eval_dataset. Will default to DataCollatorForLanguageModeling if the model is a language model and DataCollatorForVisionLanguageModeling if the model is a vision-language model. Custom collators must truncate sequences before padding; the trainer does not apply post-collation truncation.

train_dataset (Dataset or IterableDataset) : Dataset to use for training. This trainer supports both language modeling type and prompt-completion type. The format of the samples can be either:
  • Standard: Each sample contains plain text.
  • Conversational: Each sample contains structured messages (e.g., role and content).
The trainer also supports processed datasets (tokenized) as long as they contain an input_ids field.

eval_dataset (Dataset, IterableDataset or dict[str, Dataset | IterableDataset]) : Dataset to use for evaluation. It must meet the same requirements as train_dataset.

processing_class (PreTrainedTokenizerBase, ProcessorMixin, optional) : Processing class used to process the data. If None, the processing class is loaded from the model’s name with from_pretrained. A padding token, tokenizer.pad_token, must be set. If the processing class has not set a padding token, tokenizer.eos_token will be used as the default.

compute_loss_func (Callable, optional) : A function that accepts the raw model outputs, labels, and the number of items in the entire accumulated batch (batch_size * gradient_accumulation_steps) and returns the loss. For example, see the default loss function used by Trainer.

compute_metrics (Callable[[EvalPrediction], dict], optional) : The function that will be used to compute metrics at evaluation. Must take an EvalPrediction and return a dictionary mapping strings to metric values. When passing SFTConfig with batch_eval_metrics set to True, your compute_metrics function must take a boolean compute_result argument. This will be triggered after the last eval batch to signal that the function needs to calculate and return the global summary statistics rather than accumulating the batch-level statistics.

callbacks (list of TrainerCallback, optional) : List of callbacks to customize the training loop. These are added to the list of default callbacks. If you want to remove one of the default callbacks, use the remove_callback method.

optimizers (tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None], optional, defaults to (None, None)) : A tuple containing the optimizer and the scheduler to use. Will default to an instance of AdamW on your model and a scheduler given by get_linear_schedule_with_warmup controlled by args.

optimizer_cls_and_kwargs (tuple[Type[torch.optim.Optimizer], Dict[str, Any]], optional) : A tuple containing the optimizer class and keyword arguments to use. Overrides optim and optim_args in args. Incompatible with the optimizers argument. Unlike optimizers, this argument avoids the need to place model parameters on the correct devices before initializing the Trainer.

preprocess_logits_for_metrics (Callable[[torch.Tensor, torch.Tensor], torch.Tensor], optional) : A function that preprocesses the logits right before caching them at each evaluation step. Must take two tensors, the logits and the labels, and return the logits once processed as desired. The modifications made by this function will be reflected in the predictions received by compute_metrics. Note that the labels (second parameter) will be None if the dataset does not have them.

peft_config (PeftConfig, optional) : PEFT configuration used to wrap the model. If None, the model is not wrapped.

formatting_func (Callable, optional) : Formatting function applied to the dataset before tokenization. Applying the formatting function explicitly converts the dataset into a language modeling type.

train#

train(resume_from_checkpoint: str | bool | None = None, trial: optuna.Trial | dict[str, Any] | None = None, ignore_keys_for_eval: list[str] | None = None)

Main training entry point.

Parameters:

resume_from_checkpoint (str or bool, optional) : If a str, local path to a saved checkpoint as saved by a previous instance of Trainer. If a bool and equals True, load the last checkpoint in args.output_dir as saved by a previous instance of Trainer. If present, training will resume from the model/optimizer/scheduler states loaded here.

trial (optuna.Trial or dict[str, Any], optional) : The trial run or the hyperparameter dictionary for hyperparameter search.

ignore_keys_for_eval (list[str], optional) : A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions for evaluation during the training.

Returns:

~trainer_utils.TrainOutput

Object containing the global step count, training loss, and metrics.

save_model#

Will save the model, so you can reload it using from_pretrained().

Will only save from the main process.

push_to_hub#

Upload self.model and self.processing_class to the 🤗 model hub on the repo self.args.hub_model_id.

Parameters:

commit_message (str, optional, defaults to "End of training") : Message to commit while pushing.

blocking (bool, optional, defaults to True) : Whether the function should return only when the git push has finished.

token (str, optional, defaults to None) : Token with write permission to overwrite Trainer’s original args.

revision (str, optional) : The git revision to commit from. Defaults to the head of the “main” branch.

kwargs (dict[str, Any], optional) : Additional keyword arguments passed along to ~Trainer.create_model_card.

Returns:

The URL of the repository where the model was pushed if blocking=True, or a Future object tracking the progress of the commit if blocking=False.

SFTConfig#

trl.SFTConfig

Configuration class for the SFTTrainer.

This class includes only the parameters that are specific to SFT training. For a full list of training arguments, please refer to the TrainingArguments documentation. Note that default values in this class may differ from those in TrainingArguments.

Using HfArgumentParser we can turn this class into argparse arguments that can be specified on the command line.

These parameters have default values different from TrainingArguments:

  • logging_steps: Defaults to 10 instead of 500.
  • gradient_checkpointing: Defaults to True instead of False.
  • bf16: Defaults to True if fp16 is not set, instead of False.
  • learning_rate: Defaults to 2e-5 instead of 5e-5.

Link last verified June 7, 2026.
Source: TRL Docs