Training and Sampling
End-to-end SDK walkthrough: bootstrap resources, train, checkpoint, and sample through a serving deployment.
What this is#
This is the default lifecycle for research loops that need serving-quality evaluation during training: create an SDK-managed trainer and deployment, run iterative updates, save sampler weights, sync those weights to the deployment, then sample through the deployment.
For production RL, prefer the cookbook recipes. They wrap this same SDK-managed service path and handle batching, reference clients, checkpoints, reconnect, and cleanup.
Workflow#
- Create the managed service with FiretitanServiceClient.from_firetitan_config(...).
- Create a training client with service.create_training_client(...).
- Create a deployment sampler with service.create_deployment_sampler(...).
- Run train steps: forward_backward_custom(...) + optim_step(...).
- Save sampler weights with training_client.save_weights_for_sampler(...).result().
- Refresh the sampler with service.create_deployment_sampler(model_path=saved.path, ...).
- Sample and evaluate through the deployment endpoint.
The SDK owns trainer provisioning, deployment provisioning, bucket wiring, base-vs-delta sampler checkpoint selection, weight sync, and teardown. You do not construct TrainerJobManager, DeploymentManager, or WeightSyncer for the normal SDK flow.
End-to-end example#
The only training-shape input you choose below is the shape ID. The SDK resolves the versioned trainer shape and linked deployment shape before launch.
1. Bootstrap trainer and deployment#
import os

import tinker
from transformers import AutoTokenizer
from fireworks.training.sdk import (
    AdaptiveConcurrencyController,
    FiretitanServiceClient,
)

api_key = os.environ["FIREWORKS_API_KEY"]
base_url = os.environ.get("FIREWORKS_BASE_URL", "https://api.fireworks.ai")
base_model = "accounts/fireworks/models/qwen3-8b"
tokenizer_model = "Qwen/Qwen3-8B"
shape_id = "accounts/fireworks/trainingShapes/qwen3-8b-128k-h200"

service = FiretitanServiceClient.from_firetitan_config(
    api_key=api_key,
    base_url=base_url,
    base_model=base_model,
    tokenizer_model=tokenizer_model,
    lora_rank=0,
    training_shape_id=shape_id,
    deployment_id="research-serving",
    learning_rate=1e-5,
    replica_count=1,  # deployment replicas for rollout/eval throughput
    cleanup_trainer_on_close=True,
    cleanup_deployment_on_close="scale_to_zero",
)
training_client = service.create_training_client(base_model=base_model, lora_rank=0)

tokenizer = AutoTokenizer.from_pretrained(tokenizer_model, trust_remote_code=True)
concurrency = AdaptiveConcurrencyController(initial_window=16)
sampler = service.create_deployment_sampler(
    tokenizer=tokenizer,
    concurrency_controller=concurrency,
)
print({"trainer_job_id": service.trainer_job_id, "deployment_id": service.deployment_id})

2. Train step with custom objective#
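The training loop in this step calls build_micro_batches(step), which this walkthrough leaves undefined. As a minimal sketch, assuming each step's examples are already available as an in-memory sequence, a helper along the following lines could produce the micro-batches. The name split_into_micro_batches, its signature, and the example values are hypothetical and not part of the SDK.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def split_into_micro_batches(examples: Sequence[T], micro_batch_size: int) -> List[List[T]]:
    """Split one step's examples into consecutive micro-batches.

    Hypothetical helper, not part of the SDK. The last micro-batch is
    smaller when len(examples) is not a multiple of micro_batch_size.
    """
    if micro_batch_size <= 0:
        raise ValueError("micro_batch_size must be positive")
    return [
        list(examples[start:start + micro_batch_size])
        for start in range(0, len(examples), micro_batch_size)
    ]


if __name__ == "__main__":
    step_examples = [f"example-{i}" for i in range(10)]
    for micro_batch in split_into_micro_batches(step_examples, micro_batch_size=4):
        # In the real loop, each micro-batch would be passed to
        # forward_backward_custom(...) before a single optim_step(...).
        print(len(micro_batch), micro_batch[0])
```

Because one optimizer step follows several forward/backward calls, the micro-batch size mainly trades per-call memory against the number of round trips per step.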
def objective(data, logprobs_list):
    loss = compute_objective(data=data, logprobs_list=logprobs_list)
    return loss, {"loss": float(loss.item())}

for step in range(total_steps):
    # Accumulate gradients client-side: run N forward/backward calls, then one optim_step.
    micro_batches = build_micro_batches(step)
    for micro_batch in micro_batches:
        training_client.forward_backward_custom(micro_batch, objective).result()
    training_client.optim_step(
        tinker.AdamParams(
            learning_rate=1e-5,
            beta1=0.9,
            beta2=0.999,
            eps=1e-8,
            weight_decay=0.01,
        )
    ).result()

3. Save, sync, sample, evaluate#
import asyncio

# Runs inside the training loop from step 2.
if step % eval_interval == 0:
    saved = training_client.save_weights_for_sampler(f"step_{step:05d}").result()
    # Passing model_path syncs the saved snapshot into the SDK-managed
    # deployment and returns a sampler backed by that deployment.
    sampler = service.create_deployment_sampler(
        model_path=saved.path,
        tokenizer=tokenizer,
        concurrency_controller=concurrency,
    )
    completions = asyncio.run(
        sampler.sample_with_tokens(
            messages=eval_prompts,
            n=1,
            max_tokens=512,
        )
    )
    score = evaluate_responses(completions)
    print({"step": step, "checkpoint": saved.path, "eval_score": score})

save_weights_for_sampler(...) returns a future whose .result().path is a public sampler snapshot identity, not a raw storage URI. create_deployment_sampler(model_path=...) consumes that identity, syncs it to the deployment, and returns the FireTitan-native deployment sampler. Use service.create_sampling_client(model_path=...) instead if you need the Tinker-shaped sampling client wrapper.
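The evaluation cadence and snapshot naming used above can be checked in isolation. The sketch below only restates the `step % eval_interval == 0` condition and the `f"step_{step:05d}"` naming pattern as plain functions; the function names are hypothetical and nothing here calls the SDK.

```python
from typing import List


def checkpoint_name(step: int) -> str:
    """Zero-padded snapshot name, same pattern as f"step_{step:05d}"."""
    return f"step_{step:05d}"


def eval_steps(total_steps: int, eval_interval: int) -> List[int]:
    """Steps in range(total_steps) where step % eval_interval == 0 (includes step 0)."""
    if eval_interval <= 0:
        raise ValueError("eval_interval must be positive")
    return [step for step in range(total_steps) if step % eval_interval == 0]


if __name__ == "__main__":
    for step in eval_steps(total_steps=10, eval_interval=4):
        print(step, checkpoint_name(step))
```

Zero-padding keeps the names in step order under a plain lexicographic sort for steps below 100000. Note that the condition also holds at step 0, so a snapshot is saved and synced on the first iteration.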
Concurrency control#
sample_with_tokens(n=K) fans out K concurrent requests. A concurrency controller prevents overloading the deployment:
- AdaptiveConcurrencyController (recommended): automatically adjusts the concurrency window based on the server's prefill queue latency. Starts at initial_window and grows or shrinks between steps using AIMD (additive increase, multiplicative decrease).
- FixedConcurrencyController: a static semaphore with a fixed maximum. Use when you already know the right concurrency for your deployment.
See DeploymentSampler — Concurrency Control for full details and configuration options.
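To make the AIMD behavior concrete, here is a toy additive-increase, multiplicative-decrease window in plain Python. It is an illustration of the general technique only: the class name, the latency target, the step size, and the decrease factor are all assumptions, and it does not reflect the internals of AdaptiveConcurrencyController.

```python
class AimdWindow:
    """Toy AIMD concurrency window (illustrative, not the SDK's controller)."""

    def __init__(
        self,
        initial_window: int = 16,
        min_window: int = 1,
        max_window: int = 256,
        additive_step: int = 1,
        decrease_factor: float = 0.5,
        latency_target_s: float = 0.5,  # assumed threshold for "overloaded"
    ) -> None:
        if not 0.0 < decrease_factor < 1.0:
            raise ValueError("decrease_factor must be in (0, 1)")
        self.window = initial_window
        self.min_window = min_window
        self.max_window = max_window
        self.additive_step = additive_step
        self.decrease_factor = decrease_factor
        self.latency_target_s = latency_target_s

    def update(self, observed_latency_s: float) -> int:
        """Adjust the window from one latency observation and return it."""
        if observed_latency_s <= self.latency_target_s:
            # Additive increase: probe for more capacity slowly.
            self.window = min(self.max_window, self.window + self.additive_step)
        else:
            # Multiplicative decrease: back off quickly under load.
            self.window = max(self.min_window, int(self.window * self.decrease_factor))
        return self.window


if __name__ == "__main__":
    controller = AimdWindow(initial_window=16)
    for latency in (0.1, 0.2, 1.5, 0.3):
        print(latency, controller.update(latency))
```

The asymmetry is the point of the design: slow growth avoids overshooting the deployment's capacity, while a sharp cut relieves queueing as soon as latency rises.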
Reference clients#
For DPO, GRPO with KL, or any objective that needs frozen-reference logprobs, ask the service for a reference client:
reference_client = service.create_reference_client(base_model, lora_rank=0)
ref = reference_client.forward(datums, "cross_entropy").result()

The SDK chooses the backing automatically:
- LoRA policy with no explicit reference_training_shape_id reuses the policy trainer session with adapters disabled.
- Full-parameter policy, or any explicit reference_training_shape_id, uses a separate forward-only reference trainer owned by the service.
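The selection rule above can be restated as a small pure function, which is handy when estimating whether a run will need a second trainer. This is a sketch of the documented rule only: the function name and the returned labels are hypothetical, and the SDK makes this decision internally.

```python
from typing import Optional


def reference_backing(lora_rank: int, reference_training_shape_id: Optional[str] = None) -> str:
    """Restate the documented rule for how a reference client is backed.

    Hypothetical helper for illustration; the returned labels are not SDK values.
    """
    if lora_rank < 0:
        raise ValueError("lora_rank must be 0 (full-parameter) or a positive integer (LoRA)")
    is_lora_policy = lora_rank > 0
    if is_lora_policy and reference_training_shape_id is None:
        # LoRA policy, no explicit reference shape: reuse the policy trainer
        # session with adapters disabled.
        return "policy_session_adapters_disabled"
    # Full-parameter policy, or any explicit reference shape: a separate
    # forward-only reference trainer owned by the service.
    return "separate_reference_trainer"


if __name__ == "__main__":
    print(reference_backing(lora_rank=16))
    print(reference_backing(lora_rank=0))
    print(reference_backing(lora_rank=16, reference_training_shape_id="<reference-shape-id>"))
```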
Reconnecting to a running trainer#
If your client disconnects, re-create the service with the existing trainer job ID. The SDK waits for the trainer, reconnects the training client, and can reuse or reattach the deployment:
service = FiretitanServiceClient.from_firetitan_config(
    api_key=api_key,
    base_url=base_url,
    base_model=base_model,
    tokenizer_model=tokenizer_model,
    lora_rank=0,
    training_shape_id=shape_id,
    trainer_job_id="<existing-trainer-job-id>",
    deployment_id="research-serving",
)
training_client = service.create_training_client(base_model=base_model, lora_rank=0)

For DCP train-state resume, load a saved state after creating the client:
training_client.load_state_with_optimizer("step-100").result()

Cleanup#
Close the service when the loop exits:
try:
    run_training_loop()
finally:
    service.close()

cleanup_trainer_on_close=True deletes SDK-managed trainers. cleanup_deployment_on_close="scale_to_zero" releases deployment GPUs while keeping the deployment resource around for later reuse; use "delete" only when you want to remove the deployment entirely.
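An equivalent way to guarantee the close() call is the standard library's contextlib.closing, which works with any object that exposes close(). The sketch below uses a stub in place of the real service so that it runs on its own; StubService and the simulated failure are hypothetical, and the only assumption about the SDK is the close() method shown above.

```python
from contextlib import closing


class StubService:
    """Stand-in for the managed service; it only models close()."""

    def __init__(self) -> None:
        self.closed = False

    def close(self) -> None:
        self.closed = True


def run_training_loop(service: StubService) -> None:
    # Simulate a failure partway through training.
    raise RuntimeError("simulated failure mid-training")


service = StubService()
try:
    with closing(service):
        run_training_loop(service)
except RuntimeError as exc:
    print(f"loop failed: {exc}")

# close() ran even though the loop raised.
print(f"service closed: {service.closed}")
```

Either form works; the with-statement version keeps the cleanup next to the resource it belongs to when the loop body grows.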
Operational guidance#
- Start from cookbook recipes for SFT, DPO, ORPO, GRPO, IGPO, and async RL; fork them when you need custom loop behavior.
- Use the managed service as the provisioning boundary in direct SDK code. Manager classes are documented only for compatibility and advanced lifecycle debugging.
- Service mode supports both full-parameter and LoRA tuning. Set lora_rank=0 for full-parameter or a positive integer for LoRA.
- Use save_weights_for_sampler(...) for normal sampler refresh. The SDK tracks the base/delta chain and performs weight sync through create_sampling_client(model_path=...) or create_deployment_sampler(model_path=...).
- Use save_state(...) for DCP resume checkpoints. Sampler checkpoints are for serving/evaluation and promotion; DCP checkpoints restore training state.
- Store the exact prompt set and sampler snapshot path for every evaluation sweep.
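One way to follow the last recommendation is to append a JSON Lines record per evaluation sweep that contains the prompts verbatim, a fingerprint of them, and the snapshot path. This is a minimal sketch using only the standard library; the record fields, function names, and file layout are assumptions rather than an SDK format, and the snapshot value in the demo is a placeholder.

```python
import hashlib
import json
from typing import Any, Dict, List


def prompt_set_fingerprint(prompts: List[Any]) -> str:
    """SHA-256 over a canonical JSON encoding of the prompt set (order-sensitive)."""
    canonical = json.dumps(prompts, sort_keys=True, ensure_ascii=False, separators=(",", ":"))
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()


def eval_record(step: int, snapshot_path: str, prompts: List[Any], score: float) -> Dict[str, Any]:
    """Build one record tying a score to the exact prompts and snapshot used."""
    return {
        "step": step,
        "snapshot_path": snapshot_path,
        "prompt_set_sha256": prompt_set_fingerprint(prompts),
        "prompts": prompts,  # stored verbatim so the sweep can be replayed
        "eval_score": score,
    }


def append_record(log_path: str, record: Dict[str, Any]) -> None:
    """Append one record as a single JSON line."""
    with open(log_path, "a", encoding="utf-8") as handle:
        handle.write(json.dumps(record, sort_keys=True, ensure_ascii=False) + "\n")


if __name__ == "__main__":
    demo_prompts = [[{"role": "user", "content": "What is 2 + 2?"}]]
    # "<saved.path>" is a placeholder for the value returned by save_weights_for_sampler.
    print(eval_record(step=0, snapshot_path="<saved.path>", prompts=demo_prompts, score=0.0))
```

Comparing fingerprints across sweeps gives a quick check that two scores were measured on the same prompt set before the full prompts are diffed.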
Related guides#
- Loss Functions — built-in and custom loss function patterns
- Vision Inputs — fine-tune VLMs with image and text data
- Saving and Loading — checkpoint types and weight sync details
- DeploymentSampler reference — sampling API details
- Cleanup and Teardown — managed service cleanup