Trainer SDK reference¶
The SDK is for custom training loops and algorithm experiments. Use it when the CLI is too high-level and you want to control rollout, reward calculation, advantage construction, loss selection, role scoring, or checkpoint cadence directly from Python.
- class areno.Trainer(world_size, model_path, backend_type=None, custom_config=None, metrics_log_dir=None)¶
Main entry point for local Areno training workflows.
Trainer initializes the tokenizer and backend workers, generates rollout batches, runs policy training steps, manages PPO/DPO auxiliary roles, scores logprobs/values/rewards, and saves Hugging Face-compatible checkpoints.
It provides methods to:
create a local tensor-parallel Areno backend
load prompt batches from dataset-like objects
generate text rollouts from string prompts or token ids
run agentic rollouts through a local OpenAI-compatible proxy
train policy batches with caller-provided loss functions
prepare reference, reward, and critic roles for PPO/DPO workflows
score logprobs, values, and rewards through backend-owned roles
save Hugging Face-compatible checkpoints
Direct rollout calls must run inside async with trainer.rollout_session(...). The session is the lifecycle boundary for rollout state, actor onload/offload, and optional agentic proxy serving.
Typical flow
import asyncio
import areno
from areno import Trainer

async def main():
    # Near-instant: constructs the Python wrapper only.
    trainer = Trainer(
        world_size=1,
        model_path="Qwen/Qwen3.5-4B",
        backend_type=areno.Areno,
        custom_config=areno.ArenoConfig(tp_size=1),
    )
    # Takes a moment: loads tokenizer, starts workers, loads checkpoint.
    trainer.init()
    # Rollout calls must run inside an explicit rollout session. The
    # session owns actor onload/offload and rollout-state cleanup.
    sampling = areno.SamplingParams(max_new_tokens=128)
    async with trainer.rollout_session(sampling_params=sampling, proxy=False):
        rollout = trainer.rollout_batch(["Solve 12 * 13."], n_samples=1, sampling_params=sampling)
    # Build TrainSequence rows and a loss function for your algorithm,
    # then run one backend optimizer step.
    # stats = trainer.train(batch_data, loss_fn, mini_bs=1)
    # Release metric writers and local resources.
    trainer.close()

asyncio.run(main())
Note
Trainer(...) does not load the model. init() is the expensive boundary because it initializes workers and model weights. Rollout, scoring, and training calls then reuse the initialized backend.
import areno
from areno import Trainer
trainer = Trainer(
    world_size=1,
    model_path="Qwen/Qwen3.5-4B",
    backend_type=areno.Areno,
    custom_config=areno.ArenoConfig(tp_size=1),
)
trainer.init()
- Parameters:
world_size (int) – Total number of devices or local worker ranks.
model_path (str) – Local checkpoint path or Hugging Face repo ID.
backend_type – Backend selector. Defaults to Areno when omitted.
custom_config – Backend-specific configuration, such as areno.ArenoConfig(tp_size=1).
metrics_log_dir (str | None) – Optional TensorBoard metrics directory.
- init()¶
Load the tokenizer, create the backend context, and initialize backend workers.
trainer.init()
- Returns:
None
Important
Call init() exactly once before rollout, scoring, training, or checkpoint saving.
- get_tokenizer()¶
Return the initialized tokenizer.
tokenizer = trainer.get_tokenizer()
ids = tokenizer.encode("Hello")
- Returns:
tokenizer object from the selected model path.
- load_prompt_batches(dataset, *, batch_size, max_prompt_tokens, prompt_key='prompt', solutions_key='solutions')¶
Yield tokenized prompt batches from a dataset-like object.
The dataset must already expose the normalized prompt schema. If your raw dataset has different field names, normalize it before calling this method or use the CLI --dataset-loader-fn path.
- Parameters:
dataset – Object supporting len(dataset) and row indexing.
batch_size (int) – Number of accepted rows per prompt batch.
max_prompt_tokens (int) – Skip rows whose tokenized prompt is longer than this limit.
prompt_key (str) – Field containing the prompt text.
solutions_key (str) – Optional field containing reference answers.
- Returns:
iterable of PromptBatch.
for prompt_batch in trainer.load_prompt_batches(
    dataset,
    batch_size=8,
    max_prompt_tokens=1024,
):
    prompts = [item.prompt for item in prompt_batch.items]
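If the raw rows use other field names, a small normalization pass can produce the expected keys before batching. The sketch below is plain Python and does not call Areno. The raw field names (`question`, `answer`), the choice to store `solutions` as a list of strings, and the whitespace-based `count_tokens` stand-in are assumptions for illustration. It also mimics the documented behavior of skipping over-long prompts and counting only accepted rows toward `batch_size`.

```python
def normalize_rows(raw_rows, question_key="question", answer_key="answer"):
    """Map raw records onto the default prompt/solutions keys."""
    return [
        {"prompt": str(row[question_key]), "solutions": [str(row[answer_key])]}
        for row in raw_rows
    ]


def batch_accepted_rows(rows, batch_size, max_prompt_tokens, count_tokens):
    """Yield lists of accepted rows, skipping prompts over the token limit."""
    batch = []
    for row in rows:
        if count_tokens(row["prompt"]) > max_prompt_tokens:
            continue  # skipped rows do not count toward batch_size
        batch.append(row)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch  # final partial batch (assumed; the real loader may differ)


raw = [
    {"question": "What is 2 + 2?", "answer": "4"},
    {"question": "word " * 50, "answer": "n/a"},
    {"question": "What is 3 * 3?", "answer": "9"},
]
rows = normalize_rows(raw)
# Whitespace splitting stands in for the real tokenizer here.
batches = list(batch_accepted_rows(rows, 2, 16, lambda text: len(text.split())))
```

The normalized list supports `len()` and row indexing, which is what the `dataset` parameter asks for.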
- rollout_batch(prompts, n_samples, sampling_params)¶
Generate completions from text prompts.
Must be called inside async with trainer.rollout_session(..., proxy=False) for direct prompt rollouts. The explicit session defines the rollout lifecycle and prevents accidental consecutive rollouts from leaving stale rollout state.
- Parameters:
prompts (list[str]) – Prompt strings.
n_samples (int) – Number of completions per prompt.
sampling_params (SamplingParams) – Generation controls.
- Returns:
list[RolloutResult]
This method tokenizes prompts with encode_generation_prompt and then delegates to rollout_token_batch().
from areno import SamplingParams
sampling = SamplingParams(max_new_tokens=128, temperature=1.0)
async with trainer.rollout_session(sampling_params=sampling, proxy=False):
    rollouts = trainer.rollout_batch(
        ["Solve 12 * 13."],
        n_samples=4,
        sampling_params=sampling,
    )
- rollout_token_batch(prompt_tokens, n_samples, sampling_params)¶
Generate completions from pre-tokenized prompts.
Must be called inside an explicit rollout session, same as rollout_batch().
- Parameters:
prompt_tokens (list[list[int]]) – Prompt token ids.
n_samples (int) – Number of completions per prompt.
sampling_params (SamplingParams) – Generation controls.
- Returns:
list[RolloutResult]
Use this method when your loop already tokenized prompts while building a dataset batch.
tokenizer = trainer.get_tokenizer()
prompt_tokens = [tokenizer.encode("Solve 12 * 13.")]
sampling = SamplingParams(max_new_tokens=128, temperature=1.0)
async with trainer.rollout_session(sampling_params=sampling, proxy=False):
    rollouts = trainer.rollout_token_batch(
        prompt_tokens,
        n_samples=4,
        sampling_params=sampling,
    )
- rollout_session(*, sampling_params, loss_mask_policy=None, max_running_prompts=None, timeout_s=300.0, proxy=True)¶
Create an async rollout session.
The session is the required lifecycle boundary for rollout. On enter, it prepares actor rollout state. On exit, it finalizes rollout-only state and prepares the backend for scoring or training. For direct prompt rollout, pass proxy=False. For agentic rollout, keep the default proxy=True so the session starts a local OpenAI-compatible proxy.
In proxy mode, agent code points a standard OpenAI client at ctx.get_base_url(). The proxy returns OpenAI responses with Areno token and logprob metadata; run_agent returns explicit trajectory turns built from those responses. Assistant text and assistant tool-call spans are trainable by default; tool-result spans are masked unless enabled through LossMaskPolicy.
- Parameters:
sampling_params (SamplingParams) – Default generation controls.
loss_mask_policy (LossMaskPolicy | None) – Optional span-level loss mask policy.
max_running_prompts (int | None) – Global concurrent prompt budget.
timeout_s (float) – Proxy request and agent-function timeout.
proxy (bool) – Whether to start the local OpenAI-compatible proxy.
- Returns:
async RolloutSession context manager.
async with trainer.rollout_session(
    sampling_params=SamplingParams(max_new_tokens=32, temperature=0.7),
    max_running_prompts=64,
) as ctx:
    print(ctx.get_base_url())
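When agent code fans out many requests at once, a client-side cap can keep in-flight requests within the session's prompt budget. The sketch below only illustrates what a concurrent budget means, using `asyncio.Semaphore`; it is not a description of how the session enforces `max_running_prompts`, and `run_with_budget` and its sleep-based stand-in for a generation call are hypothetical.

```python
import asyncio


async def run_with_budget(n_requests, max_running):
    """Run n_requests fake calls with at most max_running in flight."""
    semaphore = asyncio.Semaphore(max_running)
    running = 0
    peak = 0

    async def one_request(index):
        nonlocal running, peak
        async with semaphore:
            running += 1
            peak = max(peak, running)
            await asyncio.sleep(0.01)  # stands in for a generation call
            running -= 1
        return index

    results = await asyncio.gather(*(one_request(i) for i in range(n_requests)))
    return results, peak


results, peak = asyncio.run(run_with_budget(n_requests=10, max_running=3))
```

`peak` records the highest number of simultaneously running calls, which never exceeds the semaphore size.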
- train(batch_data, loss_fn, mini_bs=8, gradient_accumulation_steps=None)¶
Run one backend policy training step with a caller-provided loss function.
- Parameters:
batch_data (list[TrainSequence]) – Token, mask, logprob, reward, and advantage rows.
loss_fn (Callable) – Loss function called by the backend.
mini_bs (int) – Backend training microbatch size.
gradient_accumulation_steps (int | None) – Optimizer step interval in microbatches.
- Returns:
dict[str, float] with scalar training metrics.
loss_fn receives the backend data pack and current logprobs. Built-in loss functions live under areno.loss_fns.
from functools import partial
from areno.loss_fns import gspo_loss_fn
stats = trainer.train(batch, partial(gspo_loss_fn, clip_eps=3.0e-4), mini_bs=4)
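To make the role of a small `clip_eps` concrete, here is a plain-Python sketch of a sequence-level clipped surrogate in the GSPO style. It is not `areno.loss_fns.gspo_loss_fn`, whose signature and data-pack layout are not shown here; the function names are hypothetical and the sketch works on one sequence with ordinary lists. Note that `response_mask` below is `True` on response tokens, the opposite of `prompt_mask`.

```python
import math


def sequence_ratio(new_logprobs, old_logprobs, response_mask):
    """Length-normalized sequence importance ratio over response tokens."""
    diffs = [
        new - old
        for new, old, is_response in zip(new_logprobs, old_logprobs, response_mask)
        if is_response
    ]
    if not diffs:
        return 1.0
    return math.exp(sum(diffs) / len(diffs))


def clipped_sequence_loss(ratio, advantage, clip_eps):
    """Negative clipped surrogate for a single sequence."""
    clipped = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
    return -min(ratio * advantage, clipped * advantage)


mask = [False, False, True, True]
old = [0.0, 0.0, -1.0, -2.0]
new = [0.0, 0.0, -0.9, -1.9]
ratio = sequence_ratio(new, old, mask)
loss = clipped_sequence_loss(ratio, advantage=1.0, clip_eps=3.0e-4)
```

Because the ratio is averaged over the response length before exponentiation, it stays close to 1 for small policy changes, which is why a clip range on the order of 1e-4 can still be meaningful at the sequence level.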
- ensure_roles(roles)¶
Prepare backend-owned auxiliary model roles for algorithms like PPO and DPO.
- Parameters:
roles (dict[str, ModelRole]) – Role name to model role configuration.
- Returns:
None
from areno import ModelRole
trainer.ensure_roles({
    "ref": ModelRole(name="ref", path="/path/to/reference", trainable=False),
    "critic": ModelRole(name="critic", path="/path/to/critic", trainable=True, optimizer_lr=1e-5),
})
- score_logprobs(role, token_rows)¶
Score fixed token sequences with a backend-owned model role.
- Parameters:
role (str) – Role name, such as ref or actor.
token_rows (list[list[int]]) – Token rows to score.
- Returns:
list[list[float]]
ref_logprobs = trainer.score_logprobs("ref", token_rows)
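Reference logprobs are usually compared against the rollout policy's logprobs to estimate a KL term. The following is a minimal sketch on plain lists, assuming the simple log-ratio estimator and the `prompt_mask` convention used by `TrainSequence` (`True` on prompt or padded positions); the helper names are hypothetical and not part of the SDK.

```python
def per_token_kl(logprobs, ref_logprobs, prompt_mask):
    """Log-ratio KL estimate per token; zero on prompt or padded positions."""
    if not (len(logprobs) == len(ref_logprobs) == len(prompt_mask)):
        raise ValueError("logprobs, ref_logprobs, and prompt_mask must be aligned")
    return [
        0.0 if masked else lp - ref
        for lp, ref, masked in zip(logprobs, ref_logprobs, prompt_mask)
    ]


def mean_response_kl(kl, prompt_mask):
    """Average the estimate over response positions only."""
    values = [k for k, masked in zip(kl, prompt_mask) if not masked]
    return sum(values) / len(values) if values else 0.0


kl = per_token_kl([0.0, -1.0, -2.0], [0.0, -1.5, -1.0], [True, False, False])
kl_mean = mean_response_kl(kl, [True, False, False])
```

Individual per-token values can be negative with this estimator; only the expectation under the policy is non-negative.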
- score_values(role, token_rows)¶
Score per-token critic values with a backend-owned model role.
- Parameters:
role (str) – Role name, such as critic.
token_rows (list[list[int]]) – Token rows to score.
- Returns:
list[list[float]]
values = trainer.score_values("critic", token_rows)
- score_rewards(role, token_rows)¶
Score sequence rewards with a backend-owned reward model role.
- Parameters:
role (str) – Role name, such as reward.
token_rows (list[list[int]]) – Token rows to score.
- Returns:
list[float]
rewards = trainer.score_rewards("reward", token_rows)
- train_values(role, batch_data, mini_bs, gradient_accumulation_steps=None, *, cliprange_value=0.5, value_loss_coef=0.5)¶
Train a backend-owned critic or value role.
- Parameters:
role (str) – Role name, such as critic.
batch_data (list[TrainSequence]) – Training rows.
mini_bs (int) – Critic training microbatch size.
gradient_accumulation_steps (int | None) – Optimizer step interval in microbatches.
cliprange_value (float) – PPO value-function clipping range.
value_loss_coef (float) – Value loss coefficient.
- Returns:
dict[str, float]
critic_stats = trainer.train_values("critic", batch_data, mini_bs=4)
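PPO-style loops typically fill the `returns` and `advantages` fields from critic values before training the critic. Generalized advantage estimation (GAE) is one common way to do that; the SDK reference does not prescribe it. The sketch below assumes per-token rewards for the response positions (for example, the sequence reward placed on the final token), a value of zero after the last token, and caller-chosen `gamma` and `lam`.

```python
def gae(rewards, values, gamma=1.0, lam=0.95):
    """Return (advantages, returns) for one response using GAE."""
    if len(rewards) != len(values):
        raise ValueError("rewards and values must have the same length")
    advantages = [0.0] * len(rewards)
    next_value = 0.0  # value after the final token is treated as zero
    next_advantage = 0.0
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * next_value - values[t]
        next_advantage = delta + gamma * lam * next_advantage
        advantages[t] = next_advantage
        next_value = values[t]
    returns = [adv + val for adv, val in zip(advantages, values)]
    return advantages, returns


advantages, returns = gae([0.0, 0.0, 1.0], [0.5, 0.5, 0.5], gamma=1.0, lam=1.0)
```

With `gamma=1.0` and `lam=1.0` the returns reduce to the undiscounted reward-to-go, so every position in this example has a return of 1.0.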
- save_checkpoint(path)¶
Save a Hugging Face-compatible checkpoint when supported by the backend.
- Parameters:
path (str) – Output directory.
- Returns:
saved checkpoint path as str.
saved_path = trainer.save_checkpoint("/tmp/areno-step-10")
- close()¶
Release local resources such as metric writers.
- Returns:
None
Data classes¶
- class areno.SamplingParams(greedy=False, top_p=1.0, top_k=-1, max_new_tokens=16, max_context_len=None, temperature=1.0, stop=None, stop_token_ids=None, ignore_eos=False, skip_special_tokens=True, max_prompt_len=None)¶
Generation controls used by rollout APIs.
- Parameters:
greedy (bool) – Force greedy decoding. Overrides temperature in the backend.
top_p (float) – Nucleus sampling threshold.
top_k (int) – Top-k sampling threshold. -1 disables top-k filtering.
max_new_tokens (int) – Maximum number of generated response tokens.
max_context_len (int | None) – Optional total context cap for agentic trajectories. The cap is applied to the prompt plus all generated turns concatenated into the trainable trajectory row.
temperature (float) – Sampling temperature.
stop (list[str] | None) – Stop strings.
stop_token_ids (list[int] | None) – Stop token ids.
ignore_eos (bool) – Continue generation without EOS stopping.
skip_special_tokens (bool) – Decode helper preference for completions.
max_prompt_len (int | None) – Optional prompt length cap.
- class areno.TrainSequence(prompt_mask=None, tokens=None, logprobs=None, advantages=None, returns=None, values=None, ref_logprobs=None, reward=0.0, eos_token_id=0)¶
One rollout sequence converted into a policy-gradient training sample.
- Parameters:
prompt_mask (list[bool]) – True for prompt or padded positions; losses train on response positions.
tokens (list[int]) – Prompt and response token ids.
logprobs (list[float]) – Rollout-policy logprobs aligned with tokens.
advantages (list[float]) – Per-token advantages.
returns (list[float]) – Optional value targets for PPO.
values (list[float]) – Optional old value predictions for PPO.
ref_logprobs (list[float]) – Optional reference logprobs for KL.
reward (float) – Sequence-level reward.
eos_token_id (int) – EOS id used for padding backend packs.
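The per-token fields must line up position by position. As a rough sketch, a helper like the hypothetical `build_row` below assembles aligned lists for one response and checks their lengths; it returns a plain dict so it runs without Areno installed. Filling prompt positions with `0.0` for logprobs and advantages is an assumption that matches the GSPO-style example in this reference.

```python
def build_row(prompt_tokens, resp_tokens, resp_logprobs, advantage, reward):
    """Assemble aligned per-token fields for one prompt/response pair."""
    if len(resp_tokens) != len(resp_logprobs):
        raise ValueError("response tokens and logprobs must be aligned")
    n_prompt, n_resp = len(prompt_tokens), len(resp_tokens)
    row = {
        "prompt_mask": [True] * n_prompt + [False] * n_resp,
        "tokens": list(prompt_tokens) + list(resp_tokens),
        "logprobs": [0.0] * n_prompt + list(resp_logprobs),
        "advantages": [0.0] * n_prompt + [advantage] * n_resp,
        "reward": reward,
    }
    per_token = ("prompt_mask", "tokens", "logprobs", "advantages")
    if len({len(row[key]) for key in per_token}) != 1:
        raise ValueError("per-token fields have different lengths")
    return row


row = build_row([1, 2, 3], [7, 8], [-0.5, -0.25], advantage=1.0, reward=1.0)
```

Since the dict keys match the constructor's keyword names, the result should map onto `TrainSequence(**row, eos_token_id=...)`.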
- class areno.ModelRole(name, path, trainable, optimizer_lr=None)¶
Auxiliary model role owned by the backend.
- Parameters:
name (str) – Role name, for example ref, reward, or critic.
path (str) – Checkpoint path or Hugging Face repo ID.
trainable (bool) – Whether the role has an optimizer.
optimizer_lr (float | None) – Optimizer LR for trainable roles.
- class areno.ArenoConfig(model_path=None, tp_size=1, dp_size=None, devices=None, dummy_load=False, optimizer=None, runtime=None, max_running_prompts=64, decode_progress_interval_s=10.0)¶
Backend configuration for the local Areno engine.
- Parameters:
model_path (str | None) – Optional backend model path override.
tp_size (int) – Tensor-parallel size.
dp_size (int | None) – Data-parallel size. Defaults to world_size // tp_size.
devices (list[int] | None) – Device ids for worker ranks.
dummy_load (bool) – Build model without loading checkpoint weights.
optimizer (dict | None) – Advanced optimizer config passed to the engine.
runtime (dict | None) – Advanced runtime config passed to the engine. Set runtime={"attn_backend": "native"} to run without flash-attn on the areno_accel native compatibility path. Areno also falls back to native attention on GPUs that flash-attn does not support, such as Tesla T4, and prints a warning. The default is "flash" for normal high-throughput training on supported GPUs.
max_running_prompts (int) – Concurrent rollout prompt limit.
decode_progress_interval_s (float) – Worker decode progress log interval. Logs report per-DP scheduled decode throughput for the current window and include cuda_graph=True when CUDA graph replay was used in that window.
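The `dp_size` default is plain integer division, which only tiles the devices evenly when `world_size` is a multiple of `tp_size`. The hypothetical helper below sketches that arithmetic; the explicit divisibility check is an assumption about sensible usage, not documented Areno behavior.

```python
def default_dp_size(world_size, tp_size):
    """Mirror the documented default: world_size // tp_size."""
    if tp_size < 1 or world_size < tp_size:
        raise ValueError("need 1 <= tp_size <= world_size")
    if world_size % tp_size != 0:
        # Assumed guard: leftover ranks would not belong to any TP group.
        raise ValueError("world_size should be a multiple of tp_size")
    return world_size // tp_size


dp_single = default_dp_size(world_size=1, tp_size=1)
dp_multi = default_dp_size(world_size=8, tp_size=2)
```

For example, 8 ranks with `tp_size=2` give 4 data-parallel replicas of 2 tensor-parallel ranks each.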
- class areno.api.agentic.AgentBatch(records, prompts, input_tokens, n_samples)¶
Prompt batch expanded into one item per prompt/sample pair for agent execution.
- Parameters:
records (list[dict]) – Source dataset records.
prompts (list[str]) – Prompt strings.
input_tokens (list[list[int]]) – Prompt token ids.
n_samples (int) – Samples per prompt.
- class areno.api.agentic.RewardRecord(...)¶
Unified reward input for agentic rollouts.
Reward functions receive one RewardRecord per completed trajectory. For multi-turn agents, the record represents one prompt/sample pair, not one HTTP request.
completion contains concatenated assistant response spans for backwards compatibility.
final_answer contains the last assistant response.
messages is the full OpenAI-style message list, including tool-result messages.
rendered_completion is the same trajectory rendered through the tokenizer chat template when available.
tool_calls and tool_results expose parsed tool calls and environment observations.
tokens, logprobs, and loss_mask describe the model-generated response spans.
Tool-result/context spans are included in train rows so logprob scoring sees the same context as rollout, but they are masked from policy loss by default.
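A reward function can combine several of these fields. The sketch below uses a `SimpleNamespace` stand-in instead of a real `RewardRecord` so it runs on its own; only the attribute names `tool_calls` and `final_answer` come from the description above, and the reward values and the assumption that `final_answer` may be empty or `None` are illustrative.

```python
from types import SimpleNamespace


def shaped_reward(record):
    """Penalize missing tool calls; reward a tool call plus a final answer."""
    if not record.tool_calls:
        return -1.0
    answer = (record.final_answer or "").strip()
    return 1.0 if answer else 0.5


# Stand-in records; real ones are built by the trainer from trajectories.
no_tools = SimpleNamespace(tool_calls=[], final_answer="left")
tool_only = SimpleNamespace(tool_calls=[{"name": "choose_move"}], final_answer="")
complete = SimpleNamespace(tool_calls=[{"name": "choose_move"}], final_answer="Moved up.")
scores = [shaped_reward(r) for r in (no_tools, tool_only, complete)]
```

Since each record covers a whole prompt/sample pair, the function scores the full trajectory rather than individual requests.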
- class areno.api.agentic.LossMaskPolicy(assistant_text=True, assistant_tool_calls=True, tool_results=False, final_assistant_text=True, system_prompt=False, user_prompt=False)¶
Span-level policy-loss controls for agentic trajectories.
- Parameters:
assistant_text (bool) – Train assistant text spans.
assistant_tool_calls (bool) – Train assistant tool-call spans.
tool_results (bool) – Train tool-result spans. Defaults to False.
final_assistant_text (bool) – Reserved for final-response text spans.
system_prompt (bool) – Reserved for system prompt spans.
user_prompt (bool) – Reserved for user prompt spans.
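Conceptually, the policy turns a sequence of typed spans into a per-token mask. The following sketch shows that mapping with a hypothetical `MaskPolicy` stand-in that mirrors the first three flags; the `(kind, n_tokens)` span representation and the convention that `True` means "contributes to the loss" are assumptions for illustration, not the SDK's internal format.

```python
from dataclasses import dataclass


@dataclass
class MaskPolicy:
    """Hypothetical stand-in mirroring three LossMaskPolicy flags."""
    assistant_text: bool = True
    assistant_tool_calls: bool = True
    tool_results: bool = False


def build_loss_mask(spans, policy):
    """Expand (kind, n_tokens) spans into one bool per token (True = trained)."""
    mask = []
    for kind, n_tokens in spans:
        trainable = getattr(policy, kind, False)  # unknown kinds stay masked
        mask.extend([trainable] * n_tokens)
    return mask


spans = [
    ("user_prompt", 3),
    ("assistant_tool_calls", 2),
    ("tool_results", 4),
    ("assistant_text", 1),
]
mask = build_loss_mask(spans, MaskPolicy())
trained_with_tools = sum(build_loss_mask(spans, MaskPolicy(tool_results=True)))
```

Note that this mask uses the opposite polarity from `TrainSequence.prompt_mask`, where `True` marks positions excluded from the loss.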
One GSPO-style rollout/train step¶
import asyncio
from functools import partial
from datasets import load_dataset
import areno
from areno import SamplingParams, TrainSequence, Trainer
from areno.loss_fns import gspo_loss_fn

def normalize_rewards(rewards):
    # Group-normalize rewards into zero-mean, unit-variance advantages.
    mean = sum(rewards) / len(rewards)
    var = sum((reward - mean) ** 2 for reward in rewards) / max(len(rewards), 1)
    std = max(var ** 0.5, 1e-6)
    return [(reward - mean) / std for reward in rewards]

async def main():
    trainer = Trainer(
        world_size=1,
        model_path="Qwen/Qwen3.5-4B",
        backend_type=areno.Areno,
        custom_config=areno.ArenoConfig(tp_size=1),
    )
    trainer.init()
    row = load_dataset("gsm8k", "main", split="train[0:1]")[0]
    target = str(row["answer"]).rsplit("####", 1)[-1].strip()
    prompt = (
        "Solve the problem and put the final answer in \\boxed{}.\n\n"
        f"Problem: {row['question']}\nSolution:"
    )
    prompt_tokens = trainer.get_tokenizer().encode(prompt)
    sampling = SamplingParams(max_new_tokens=128, temperature=1.0)
    async with trainer.rollout_session(sampling_params=sampling, proxy=False):
        rollout = trainer.rollout_token_batch([prompt_tokens], n_samples=4, sampling_params=sampling)[0]
    completions = [trainer.get_tokenizer().decode(seq.resp_tokens) for seq in rollout.sequences]
    rewards = [1.0 if target in completion else 0.0 for completion in completions]
    advantages = normalize_rewards(rewards)
    batch = []
    for seq, reward, advantage in zip(rollout.sequences, rewards, advantages, strict=True):
        response_len = len(seq.resp_tokens)
        batch.append(
            TrainSequence(
                prompt_mask=[True] * len(prompt_tokens) + [False] * response_len,
                tokens=prompt_tokens + seq.resp_tokens,
                logprobs=[0.0] * len(prompt_tokens) + seq.resp_logprobs,
                advantages=[0.0] * len(prompt_tokens) + [advantage] * response_len,
                reward=reward,
                eos_token_id=trainer.get_tokenizer().eos_token_id,
            )
        )
    stats = trainer.train(batch, partial(gspo_loss_fn, clip_eps=3.0e-4), mini_bs=4)
    print(stats)
    trainer.close()

asyncio.run(main())
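The substring check `target in completion` is deliberately simple, but it also accepts completions that merely contain the target digits, such as "1560" for a target of "156". A stricter variant, sketched here under the assumption that the model follows the prompt's `\boxed{}` instruction, compares only the last boxed span. `extract_boxed` and `exact_match_reward` are hypothetical helpers, and the regex does not handle nested braces.

```python
import re


def extract_boxed(text):
    """Return the contents of the last \\boxed{...} span, or None."""
    matches = re.findall(r"\\boxed\{([^{}]*)\}", text)
    return matches[-1].strip() if matches else None


def exact_match_reward(completion, target):
    """1.0 only when the last boxed answer equals the target exactly."""
    return 1.0 if extract_boxed(completion) == target.strip() else 0.0


completions = [
    "12 * 13 = 156, so \\boxed{156}",
    "The answer is 1560",
    "\\boxed{ 156 }",
    "\\boxed{15}",
]
rewards = [exact_match_reward(text, "156") for text in completions]
```

Swapping this in for the substring check leaves the rest of the loop unchanged, since it still yields one float per completion.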
Agentic rollout with tools¶
This example shows the SDK pieces used by --agent-fn. The agent calls a
local OpenAI-compatible proxy with Chat Completions tools and returns
explicit trajectories. Areno parses supported model-native tool-call output
from those responses, then the trainer converts trajectories into the same
token, logprob, reward, and loss-mask rows used by regular rollouts.
import asyncio
import areno
from areno import SamplingParams, Trainer
from areno.api.agentic import AgentBatch, AgentTrajectory, AgentTrajectoryTurn
from openai import AsyncOpenAI

tools = [
    {
        "type": "function",
        "function": {
            "name": "choose_move",
            "parameters": {
                "type": "object",
                "properties": {
                    "direction": {"type": "string", "enum": ["up", "down", "left", "right"]},
                },
                "required": ["direction"],
            },
        },
    }
]

async def run_agent(ctx, batch):
    client = AsyncOpenAI(base_url=ctx.get_base_url(), api_key=ctx.api_key, max_retries=0)

    async def run_one(item):
        messages = [
            {"role": "system", "content": "Call choose_move with the selected direction."},
            {"role": "user", "content": item.prompt},
        ]
        tool_choice = {"type": "function", "function": {"name": "choose_move"}}
        response = await client.chat.completions.create(
            model="policy",
            messages=messages,
            tools=tools,
            tool_choice=tool_choice,
            max_tokens=16,
            temperature=0.7,
        )
        return AgentTrajectoryTurn(
            item=item,
            messages=messages,
            response=response,
            tools=tools,
            tool_choice=tool_choice,
        )

    try:
        turns = await asyncio.gather(*(run_one(item) for item in batch.iter_samples()))
        return [AgentTrajectory(turns=[turn]) for turn in turns]
    finally:
        await client.close()

def reward_fn(record):
    if not record.tool_calls:
        return -1.0
    return 1.0

async def collect_agentic_trajectories(trainer, prompt_batch):
    agent_batch = AgentBatch.from_prompt_batch(prompt_batch, n_samples=4)
    async with trainer.rollout_session(
        sampling_params=SamplingParams(max_new_tokens=16, temperature=0.7),
        max_running_prompts=len(agent_batch),
    ) as ctx:
        return await run_agent(ctx, agent_batch)

trainer = Trainer(
    world_size=1,
    model_path="Qwen/Qwen3-0.6B",
    backend_type=areno.Areno,
    custom_config=areno.ArenoConfig(tp_size=1),
)
trainer.init()
# In CLI training, --agent-fn returns these trajectories to the trainer.
# In a custom loop, load a PromptBatch and call:
# trajectories = asyncio.run(collect_agentic_trajectories(trainer, prompt_batch))