FiretitanServiceClient & TrainingClient

Summary: Connect to a trainer endpoint and use the training client for forward/backward passes, optimizer steps, and checkpointing.

Overview#

FiretitanServiceClient is the recommended direct SDK entry point. In the managed path, it creates or reattaches the FireTitan trainer, optional reference trainer, and optional inference deployment, then returns Tinker-compatible training and sampling clients.

For most direct SDK code, create it with FiretitanServiceClient.from_firetitan_config(...). The bare constructor is still useful when you already have a trainer endpoint URL, but that is a lower-level compatibility path.

from fireworks.training.sdk import FiretitanServiceClient, GradAccNormalization

FiretitanServiceClient#

from_firetitan_config(...)#

Create a lazy SDK-managed service. The trainer and deployment are provisioned on the first client call, usually create_training_client(...):

service = FiretitanServiceClient.from_firetitan_config(
    api_key="<FIREWORKS_API_KEY>",
    base_url="https://api.fireworks.ai",
    base_model="accounts/fireworks/models/qwen3-8b",
    tokenizer_model="Qwen/Qwen3-8B",
    lora_rank=0,
    training_shape_id="accounts/fireworks/trainingShapes/qwen3-8b-128k-h200",
    deployment_id="research-serving",   # set create_deployment=False for trainer-only flows
    learning_rate=1e-5,
    replica_count=1,                     # deployment replicas
    cleanup_trainer_on_close=True,
    cleanup_deployment_on_close="scale_to_zero",
)

training_client = service.create_training_client(
    base_model="accounts/fireworks/models/qwen3-8b",
    lora_rank=0,
)

Core managed config fields:

Field (type, default): description
api_key (str | None, default: the FIREWORKS_API_KEY environment variable): Fireworks API key.
base_url (str | None, default: https://api.fireworks.ai): Control-plane URL.
inference_url (str | None, default: None): Optional inference gateway URL.
base_model (str, required): Fireworks base model resource name.
tokenizer_model (str | None, default: None): Hugging Face tokenizer name used by get_tokenizer() and sampler setup.
lora_rank (int, default: 0): 0 for full-parameter training; a positive value for LoRA.
training_shape_id (str | None, default: None): User-facing training shape ID. The SDK resolves the pinned version.
reference_training_shape_id (str | None, default: None): Optional shape for a separate forward-only reference trainer.
trainer_job_id (str | None, default: None): Reattach to an existing trainer instead of creating one.
reference_trainer_job_id (str | None, default: None): Reattach to an existing reference trainer.
create_deployment (bool, default: True): Whether to create or reattach an inference deployment. Set False for trainer-only SFT/DPO-style loops.
deployment_id (str | None, default: None): Create or reattach an inference deployment for sampling and weight sync.
deployment_shape (str | None, default: the linked shape): Optional deployment shape override. Usually inherited from the training shape.
trainer_replica_count (int | None, default: None): Data-parallel HSDP replicas for the trainer.
replica_count (int, default: 1): Inference deployment replicas.
cleanup_trainer_on_close (bool, default: False): Delete the SDK-managed policy trainer when service.close() runs.
cleanup_reference_trainer_on_close (bool, default: True): Delete SDK-managed separate reference trainers when they are released or closed.
cleanup_deployment_on_close ("scale_to_zero" | "delete" | None, default: None): Optional deployment cleanup action on close.

The managed service exposes resolved metadata after provisioning:

print(service.trainer_job_id)
print(service.deployment_id)
print(service.max_context_length)
print(service.reference_trainer_job_id)  # None when the reference is shared

Bare constructor#

service = FiretitanServiceClient(
    base_url=endpoint.base_url,  # From TrainerJobManager.create_and_wait(...)
    api_key="<FIREWORKS_API_KEY>",
)

base_url is the trainer endpoint URL from TrainerServiceEndpoint.base_url. Use this constructor only when you intentionally manage the trainer lifecycle yourself; new user code should use from_firetitan_config(...).

create_training_client(base_model, lora_rank, user_metadata)#

Creates a FiretitanTrainingClient for training operations:

training_client = service.create_training_client(
    base_model="accounts/fireworks/models/qwen3-8b",
    lora_rank=0,  # Must match lora_rank from job creation
)

Parameter (type, default): description
base_model (str, required): Must match the trainer job's base_model.
lora_rank (int, default: 0): Must match the trainer creation config (0 for full-parameter).
user_metadata (dict[str, str] | None, default: None): Optional run metadata.

A ValueError is raised if you attempt to create a second training client with the same (base_model, lora_rank) on the same FiretitanServiceClient instance. Create a new FiretitanServiceClient for a separate trainer.
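The one-client-per-(base_model, lora_rank) rule can be modeled with a small registry. This is a hypothetical stand-in that mirrors only the documented duplicate-key ValueError. It is not the SDK's implementation, and it does not model other constraints, such as lora_rank having to match the trainer's creation config.

```python
class ClientRegistry:
    """Hypothetical model of the per-service uniqueness rule for training clients."""

    def __init__(self):
        self._keys = set()

    def register(self, base_model, lora_rank=0):
        key = (base_model, lora_rank)
        if key in self._keys:
            # The SDK documents a ValueError for a second client with the same key.
            raise ValueError(f"training client already exists for {key}")
        self._keys.add(key)
        return key
```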

Connecting to an existing trainer#

If you already have a running trainer (e.g. from a previous session), connect directly by URL:

service = FiretitanServiceClient(
    base_url="https://<existing-trainer-url>",
    api_key="<FIREWORKS_API_KEY>",
)
training_client = service.create_training_client(
    base_model="accounts/fireworks/models/qwen3-8b",
    lora_rank=0,
)

create_base_training_client(base_model, user_metadata=None)#

Creates a base-only client on the same trainer session. Use this as a frozen reference for LoRA KL/reference logprobs without launching a separate forward-only trainer:

reference_client = service.create_base_training_client(base_model=base_model)
ref = reference_client.forward(datums, "cross_entropy").result()

Do not call forward_backward, forward_backward_custom, or optim_step on this client; it is for reference forward passes only.
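A common use of the reference logprobs is a KL penalty. The minimal sketch below assumes both inputs are plain lists of per-token logprobs for the same response tokens. In real code they would come from loss_fn_outputs[i]["logprobs"].data, as in the forward(...) example. The function computes the simple "k1" estimator, mean(log p_policy - log p_ref). The function name and the choice of estimator are illustrative; this page does not prescribe one.

```python
def kl_estimate(policy_logprobs, reference_logprobs):
    """Per-token KL estimate: mean of (log p_policy - log p_ref).

    Both arguments are per-token logprobs for the same tokens, sampled from the
    policy; this is the simple "k1" estimator of KL(policy || reference).
    """
    if len(policy_logprobs) != len(reference_logprobs):
        raise ValueError("logprobs must align token-for-token")
    if not policy_logprobs:
        return 0.0
    diffs = [p - r for p, r in zip(policy_logprobs, reference_logprobs)]
    return sum(diffs) / len(diffs)
```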

create_reference_client(base_model, lora_rank=0, user_metadata=None)#

Create a frozen reference client for KL/DPO baseline logprobs:

reference_client = service.create_reference_client(base_model, lora_rank=0)
ref = reference_client.forward(datums, "cross_entropy").result()

The SDK chooses the backing automatically. LoRA policies without an explicit reference shape reuse the policy trainer with the adapter disabled. Full-parameter policies, explicit reference_training_shape_id, or explicit reference_trainer_job_id use a separate forward-only reference trainer owned by the service.
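The backing rule above can be restated as a small decision function, which helps when estimating how many trainers a run will use. This is a hypothetical restatement of the documented behavior, not an SDK function.

```python
def reference_backing(lora_rank, reference_training_shape_id=None,
                      reference_trainer_job_id=None):
    """Which trainer backs create_reference_client(), per the documented rule."""
    explicit_reference = (reference_training_shape_id is not None
                          or reference_trainer_job_id is not None)
    if lora_rank > 0 and not explicit_reference:
        # LoRA policy without an explicit reference: reuse the policy trainer
        # with the adapter disabled.
        return "policy-trainer-adapter-disabled"
    # Full-parameter policy, or an explicitly configured reference trainer/shape.
    return "separate-forward-only-trainer"
```

In the shared case, service.reference_trainer_job_id is None.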

create_sampling_client(model_path=None, ...)#

Return a Tinker-shaped sampling client backed by the SDK-managed deployment. When model_path is provided, the SDK first syncs that sampler snapshot to the deployment:

saved = training_client.save_weights_for_sampler("step-100").result()
sampler = service.create_sampling_client(model_path=saved.path)

This replaces calling a standalone weight-sync helper from user code: the SDK tracks the base/delta chain and builds the weight-sync metadata internally.

create_deployment_sampler(model_path=None, tokenizer=None, concurrency_controller=None)#

Return the FireTitan-native DeploymentSampler directly. Use this when you need tokenized completions, inference logprobs, routing matrices, or adaptive concurrency:

sampler = service.create_deployment_sampler(
    model_path=saved.path,
    tokenizer=tokenizer,
    concurrency_controller=controller,
)

hotload_sampler_snapshot(model_path)#

Low-level method for syncing a previously saved sampler snapshot into the SDK-managed deployment without constructing a sampler:

service.hotload_sampler_snapshot(saved.path)

FiretitanTrainingClient#

The training client returned by create_training_client(). Core training RPCs like forward(...), forward_backward_custom(...), optim_step(...), save_state(...), and load_state_with_optimizer(...) return futures. Fireworks convenience helpers like save_weights_for_sampler_ext(...), list_checkpoints(), and resolve_checkpoint_path(...) return concrete values directly.

forward(datums, loss_type)#

Forward-only pass (no gradient computation). Useful for computing reference logprobs in GRPO/DPO:

result = training_client.forward(datums, "cross_entropy").result()
logprobs = result.loss_fn_outputs[0]["logprobs"].data

Built-in loss types such as "cross_entropy" require datums with target_tokens in loss_fn_inputs. Datums built with datum_from_model_input_weights use the weight-based format and will fail with built-in losses. For built-in losses, use the target-token tinker.Datum example in Loss Functions. For weight-based datums (see Building datums), use forward_backward_custom with the custom-loss pattern in Example: simple cross-entropy.

forward_backward_custom(datums, loss_fn)#

Forward + backward with your custom loss function. See Loss Functions for details:

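def my_loss(data, logprobs_list):
    # compute_loss is a placeholder for your own objective; it must return a
    # scalar, differentiable tensor built from logprobs_list.
    loss = compute_loss(data, logprobs_list)
    # Return the loss plus a dict of scalar metrics to report.
    return loss, {"loss": float(loss.item())}

result = training_client.forward_backward_custom(datums, my_loss).result()
print(result.metrics)  # e.g. {"loss": 0.42}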
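To make the loss callback concrete, the sketch below computes a weighted negative log-likelihood with plain Python lists. It is an illustration only. The real callback receives tensors in logprobs_list and must return a differentiable scalar tensor, and Building datums describes how per-token weights are attached to each datum. The function name and the list-based inputs are assumptions. Dividing by the weight total yourself may interact with the grad_accumulation_normalization mode passed to optim_step; see Loss Functions.

```python
def weighted_nll(logprobs_list, weights_list):
    """Weighted mean negative log-likelihood over all sequences.

    Plain lists of per-token floats (one list per sequence) stand in for the
    tensors the real callback receives. Returns (loss, metrics) like my_loss.
    """
    total_loss = 0.0
    total_weight = 0.0
    for logprobs, weights in zip(logprobs_list, weights_list):
        if len(logprobs) != len(weights):
            raise ValueError("each sequence needs one weight per logprob")
        total_loss += -sum(w * lp for w, lp in zip(weights, logprobs))
        total_weight += sum(weights)
    # Guard against an all-zero weight mask.
    loss = total_loss / max(total_weight, 1e-8)
    return loss, {"loss": loss}
```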

For embedding-space objectives, pass output="embedding" and choose pooling="mean" or "last"; your loss function then receives pooled embedding tensors instead of logprobs:

result = training_client.forward_backward_custom(
    datums,
    embedding_loss,
    output="embedding",
    pooling="mean",
).result()
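The two pooling modes reduce a sequence of per-token embeddings to one vector. The plain-Python sketch below shows what "mean" and "last" pooling generally compute, assuming an unpadded sequence. It does not cover whether the SDK masks padding tokens.

```python
def pool(token_embeddings, pooling="mean"):
    """Reduce a list of per-token vectors to one vector ("mean" or "last")."""
    if not token_embeddings:
        raise ValueError("need at least one token embedding")
    if pooling == "last":
        # Embedding of the final token.
        return list(token_embeddings[-1])
    if pooling == "mean":
        # Average each dimension across all token positions.
        n = len(token_embeddings)
        return [sum(column) / n for column in zip(*token_embeddings)]
    raise ValueError(f"unknown pooling: {pooling!r}")
```

In a causal decoder, the last token's embedding has attended to the whole sequence, which is a common reason to prefer "last". "mean" weights every position equally.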

optim_step(adam_params, grad_accumulation_normalization=None)#

Apply an optimizer update after accumulating gradients:

import tinker

training_client.optim_step(
    tinker.AdamParams(
        learning_rate=1e-5,
        beta1=0.9,
        beta2=0.999,
        eps=1e-8,
        weight_decay=0.01,
    )
).result()

The optional grad_accumulation_normalization argument controls how accumulated gradients are normalized. Pass GradAccNormalization.NUM_LOSS_TOKENS, GradAccNormalization.NUM_SEQUENCES, or GradAccNormalization.NONE rather than raw strings. See Loss Functions for when to use each mode.
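The three modes differ only in the divisor applied to the summed gradient across accumulated micro-batches. The sketch below applies each mode to a scalar stand-in for the accumulated gradient, following the descriptions in the GradAccNormalization table. The micro-batch tuple format is an illustrative assumption, and the trainer performs the actual reduction server-side. The sketch keys on the enum's wire values only to stay self-contained; real code should pass the enum members.

```python
def normalized_gradient(micro_batches, mode):
    """Combine per-micro-batch gradient sums according to a normalization mode.

    micro_batches: list of (grad_sum, num_loss_tokens, num_sequences) tuples,
    where grad_sum is a scalar stand-in for a summed gradient tensor.
    mode: a wire value: "num_loss_tokens", "num_sequences", or "none".
    """
    total_grad = sum(g for g, _, _ in micro_batches)
    if mode == "none":
        return total_grad  # raw gradient sum
    if mode == "num_loss_tokens":
        denom = sum(t for _, t, _ in micro_batches)
    elif mode == "num_sequences":
        denom = sum(s for _, _, s in micro_batches)
    else:
        raise ValueError(f"unknown mode: {mode!r}")
    return total_grad / denom
```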

save_weights_for_sampler(name, ttl_seconds=None, checkpoint_type=None)#

Save serving-compatible sampler weights and return a future. This is the normal Tinker-shaped API:

saved = training_client.save_weights_for_sampler(
    "step-100",
    checkpoint_type="base",  # optional: "base" or "delta"
).result()
print(saved.path)  # Snapshot identity for create_sampling_client(model_path=...)

Full-parameter training saves a base checkpoint first and deltas after that by default. LoRA training always saves base checkpoints. The returned path is a public snapshot identity, not a raw storage URI.
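The default checkpoint_type sequence can be summarized as a function of the training mode. This is a hypothetical restatement of the documented defaults, assuming a single training session with no explicit checkpoint_type overrides. It is useful for predicting which snapshots will be deltas and is not an SDK call.

```python
def default_checkpoint_types(lora_rank, num_saves):
    """Checkpoint types produced by successive saves when checkpoint_type is unset."""
    if num_saves <= 0:
        return []
    if lora_rank > 0:
        return ["base"] * num_saves  # LoRA always saves base checkpoints
    # Full-parameter: the first save is a base, later saves are deltas.
    return ["base"] + ["delta"] * (num_saves - 1)
```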

save_weights_for_sampler_ext(name, checkpoint_type, ttl_seconds)#

Fireworks-specific extension that returns a concrete SaveSamplerResult instead of a future:

result = training_client.save_weights_for_sampler_ext(
    "step-100",
    checkpoint_type="base",  # "base" for full weights, "delta" for incremental
)
print(result.snapshot_name)  # Session-qualified name for weight sync

Parameter (type, default): description
name (str, required): Checkpoint name (auto-suffixed with the session ID).
checkpoint_type (str | None, default: None): "base" for full weights, "delta" for incremental.
ttl_seconds (int | None, default: None): Auto-delete the checkpoint after this many seconds.

On full-parameter training, only checkpoint_type="base" produces a promotable blob; "delta" cannot be promoted. LoRA is always promotable. See Checkpoint kinds for the full promotability matrix.
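The promotability rule above can be expressed as a hypothetical helper. Checkpoint kinds remains the authoritative matrix. Here checkpoint_type is the type actually saved, not the possibly-None argument.

```python
def is_promotable(lora_rank, checkpoint_type):
    """Hypothetical check mirroring the documented promotability rule."""
    if checkpoint_type not in ("base", "delta"):
        raise ValueError("checkpoint_type must be 'base' or 'delta'")
    if lora_rank > 0:
        return True  # LoRA snapshots are always promotable
    return checkpoint_type == "base"  # full-parameter: only base blobs
```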

save_weights_for_sampler_ext saves the snapshot only; it does not mutate a deployment. To serve the snapshot, pass result.snapshot_name to the managed service weight-sync path, or use create_sampling_client(model_path=...) / create_deployment_sampler(model_path=...), which sync and return a sampler.

save_state(name, ttl_seconds=None, timeout=None)#

Save full train state (weights + optimizer) for resume:

training_client.save_state("train_state_step_100").result()

Parameter (type, default): description
name (str, required): Checkpoint name.
ttl_seconds (int | None, default: None): Auto-delete the checkpoint after this many seconds.
timeout (float | None, default: None): If set, block until the save completes or the timeout expires.
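A minimal training-loop sketch that ties these calls together. Each step runs forward_backward_custom and then optim_step, and every save_every steps the loop calls save_state using the naming pattern from the example above. The function and its arguments are hypothetical. In real code, training_client would be the FiretitanTrainingClient, loss_fn a custom loss such as my_loss, and adam_params a tinker.AdamParams. The loop resolves every future immediately, which is simple but does not overlap work.

```python
def train_loop(training_client, batches, loss_fn, adam_params, save_every=100):
    """Run one forward/backward + optimizer step per batch, saving state periodically.

    training_client is assumed to expose the future-returning RPCs documented
    on this page; each future is resolved with .result() before moving on.
    Returns the names passed to save_state.
    """
    saved_names = []
    for step, datums in enumerate(batches, start=1):
        training_client.forward_backward_custom(datums, loss_fn).result()
        training_client.optim_step(adam_params).result()
        if step % save_every == 0:
            name = f"train_state_step_{step}"
            training_client.save_state(name).result()
            saved_names.append(name)
    return saved_names
```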

load_state_with_optimizer(name)#

Restore full train state (weights + optimizer) from a checkpoint:

training_client.load_state_with_optimizer("train_state_step_100").result()

load_state(name)#

Load model weights from a checkpoint without restoring optimizer state. The optimizer is reset so the next optim_step starts fresh:

training_client.load_state("train_state_step_100").result()

load_adapter(adapter_path)#

Load Hugging Face PEFT adapter weights into the current LoRA session. This is a weights-only warm start; it does not restore optimizer state, scheduler state, or data cursor.

training_client.load_adapter("gs://my-bucket/adapters/run-42").result()

list_checkpoints()#

List available DCP checkpoints from the trainer. Returns a list[str]:

checkpoint_names = training_client.list_checkpoints()
print(checkpoint_names)  # e.g. ["step-2", "step-4"]
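When resuming, you usually want the newest checkpoint. The sketch below picks it from list_checkpoints() output, assuming names end in a step number as in the "step-4" and "train_state_step_100" examples; your naming may differ. It sorts numerically, because a plain string comparison ranks "step-4" above "step-10".

```python
import re

# Matches a trailing step number such as "step-4" or "train_state_step_100".
_STEP_RE = re.compile(r"step[-_](\d+)$")


def latest_step_checkpoint(checkpoint_names):
    """Return the name with the highest step number, or None if none match."""
    best_name, best_step = None, -1
    for name in checkpoint_names:
        match = _STEP_RE.search(name)
        if match and int(match.group(1)) > best_step:
            best_name, best_step = name, int(match.group(1))
    return best_name
```

The result could presumably be passed to load_state_with_optimizer(...) when it is not None.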

resolve_checkpoint_path(checkpoint_name, source_job_id)#

Resolve a checkpoint path for cross-job resume:

checkpoint_ref = training_client.resolve_checkpoint_path(
    "step-4",
    source_job_id="previous-job-id",
)
training_client.load_state_with_optimizer(checkpoint_ref).result()

SaveSamplerResult#

Returned by save_weights_for_sampler_ext:

Field (type): description
path (str): Snapshot name from the trainer.
snapshot_name (str): Session-qualified name for weight sync operations.

GradAccNormalization#

Enum for optim_step’s grad_accumulation_normalization parameter:

Enum (wire value): description
GradAccNormalization.NUM_LOSS_TOKENS ("num_loss_tokens"): Normalize by total loss tokens across accumulated micro-batches.
GradAccNormalization.NUM_SEQUENCES ("num_sequences"): Normalize by total sequences across accumulated micro-batches.
GradAccNormalization.NONE ("none"): No normalization (raw gradient sum).

Link last verified June 7, 2026.
Source: Fireworks AI Docs