Trainer SDK reference

The SDK is for custom training loops and algorithm experiments. Use it when the CLI is too high-level and you want to control rollout, reward calculation, advantage construction, loss selection, role scoring, or checkpoint cadence directly from Python.

class areno.Trainer(world_size, model_path, backend_type=None, custom_config=None, metrics_log_dir=None)

Main entry point for local Areno training workflows.

Trainer initializes tokenizer and backend workers, generates rollout batches, runs policy training steps, manages PPO/DPO auxiliary roles, scores logprobs/values/rewards, and saves Hugging Face-compatible checkpoints.

It provides methods to:

  • create a local tensor-parallel Areno backend

  • load prompt batches from dataset-like objects

  • generate text rollouts from string prompts or token ids

  • run agentic rollouts through a local OpenAI-compatible proxy

  • train policy batches with caller-provided loss functions

  • prepare reference, reward, and critic roles for PPO/DPO workflows

  • score logprobs, values, and rewards through backend-owned roles

  • save Hugging Face-compatible checkpoints

Direct rollout calls must run inside async with trainer.rollout_session(...). The session is the lifecycle boundary for rollout state, actor onload/offload, and optional agentic proxy serving.

Typical flow

import asyncio
import areno
from areno import Trainer

async def main():
    # Near-instant: constructs the Python wrapper only.
    trainer = Trainer(
        world_size=1,
        model_path="Qwen/Qwen3.5-4B",
        backend_type=areno.Areno,
        custom_config=areno.ArenoConfig(tp_size=1),
    )

    # Takes a moment: loads tokenizer, starts workers, loads checkpoint.
    trainer.init()

    # Rollout calls must run inside an explicit rollout session. The
    # session owns actor onload/offload and rollout-state cleanup.
    sampling = areno.SamplingParams(max_new_tokens=128)
    async with trainer.rollout_session(sampling_params=sampling, proxy=False):
        rollout = trainer.rollout_batch(["Solve 12 * 13."], n_samples=1, sampling_params=sampling)

    # Build TrainSequence rows and a loss function for your algorithm,
    # then run one backend optimizer step.
    # stats = trainer.train(batch_data, loss_fn, mini_bs=1)

    # Release metric writers and local resources.
    trainer.close()

asyncio.run(main())

Note

Trainer(...) does not load the model. init() is the expensive boundary because it initializes workers and model weights. Rollout, scoring, and training calls then reuse the initialized backend.

import areno
from areno import Trainer

trainer = Trainer(
    world_size=1,
    model_path="Qwen/Qwen3.5-4B",
    backend_type=areno.Areno,
    custom_config=areno.ArenoConfig(tp_size=1),
)
trainer.init()
Parameters:
  • world_size (int) – Total number of devices or local worker ranks.

  • model_path (str) – Local checkpoint path or Hugging Face repo ID.

  • backend_type – Backend selector. Defaults to Areno when omitted.

  • custom_config – Backend-specific configuration, such as areno.ArenoConfig(tp_size=1).

  • metrics_log_dir (str | None) – Optional TensorBoard metrics directory.

init()

Load the tokenizer, create the backend context, and initialize backend workers.

trainer.init()
Returns:

None

Important

Call init() exactly once before rollout, scoring, training, or checkpoint saving.

get_tokenizer()

Return the initialized tokenizer.

tokenizer = trainer.get_tokenizer()
ids = tokenizer.encode("Hello")
Returns:

tokenizer object from the selected model path.

load_prompt_batches(dataset, *, batch_size, max_prompt_tokens, prompt_key='prompt', solutions_key='solutions')

Yield tokenized prompt batches from a dataset-like object.

The dataset must already expose the normalized prompt schema. If your raw dataset has different field names, normalize it before calling this method or use the CLI --dataset-loader-fn path.

Parameters:
  • dataset – Object supporting len(dataset) and row indexing.

  • batch_size (int) – Number of accepted rows per prompt batch.

  • max_prompt_tokens (int) – Skip rows whose tokenized prompt is longer than this limit.

  • prompt_key (str) – Field containing the prompt text.

  • solutions_key (str) – Optional field containing reference answers.

Returns:

iterable of PromptBatch.

for prompt_batch in trainer.load_prompt_batches(
    dataset,
    batch_size=8,
    max_prompt_tokens=1024,
):
    prompts = [item.prompt for item in prompt_batch.items]
rollout_batch(prompts, n_samples, sampling_params)

Generate completions from text prompts.

Must be called inside async with trainer.rollout_session(..., proxy=False) for direct prompt rollouts. The explicit session defines the rollout lifecycle and prevents accidental consecutive rollouts from leaving stale rollout state.

Parameters:
  • prompts (list[str]) – Prompt strings.

  • n_samples (int) – Number of completions per prompt.

  • sampling_params (SamplingParams) – Generation controls.

Returns:

list[RolloutResult]

This method tokenizes prompts with encode_generation_prompt and then delegates to rollout_token_batch().

from areno import SamplingParams

sampling = SamplingParams(max_new_tokens=128, temperature=1.0)
async with trainer.rollout_session(sampling_params=sampling, proxy=False):
    rollouts = trainer.rollout_batch(
        ["Solve 12 * 13."],
        n_samples=4,
        sampling_params=sampling,
    )
rollout_token_batch(prompt_tokens, n_samples, sampling_params)

Generate completions from pre-tokenized prompts.

Must be called inside an explicit rollout session, same as rollout_batch().

Parameters:
  • prompt_tokens (list[list[int]]) – Prompt token ids.

  • n_samples (int) – Number of completions per prompt.

  • sampling_params (SamplingParams) – Generation controls.

Returns:

list[RolloutResult]

Use this method when your loop already tokenized prompts while building a dataset batch.

tokenizer = trainer.get_tokenizer()
prompt_tokens = [tokenizer.encode("Solve 12 * 13.")]
sampling = SamplingParams(max_new_tokens=128, temperature=1.0)
async with trainer.rollout_session(sampling_params=sampling, proxy=False):
    rollouts = trainer.rollout_token_batch(
        prompt_tokens,
        n_samples=4,
        sampling_params=sampling,
    )
rollout_session(*, sampling_params, loss_mask_policy=None, max_running_prompts=None, timeout_s=300.0, proxy=True)

Create an async rollout session.

The session is the required lifecycle boundary for rollout. On enter, it prepares actor rollout state. On exit, it finalizes rollout-only state and prepares the backend for scoring or training. For direct prompt rollout, pass proxy=False. For agentic rollout, keep the default proxy=True so the session starts a local OpenAI-compatible proxy.

In proxy mode, agent code calls ctx.get_base_url() with a standard OpenAI client. The proxy returns OpenAI responses with Areno token and logprob metadata; run_agent returns explicit trajectory turns built from those responses. Assistant text and assistant tool-call spans are trainable by default; tool-result spans are masked unless enabled through LossMaskPolicy.

Parameters:
  • sampling_params (SamplingParams) – Default generation controls.

  • loss_mask_policy (LossMaskPolicy | None) – Optional span-level loss mask policy.

  • max_running_prompts (int | None) – Global concurrent prompt budget.

  • timeout_s (float) – Proxy request and agent-function timeout.

  • proxy (bool) – Whether to start the local OpenAI-compatible proxy.

Returns:

async RolloutSession context manager.

async with trainer.rollout_session(
    sampling_params=SamplingParams(max_new_tokens=32, temperature=0.7),
    max_running_prompts=64,
) as ctx:
    print(ctx.get_base_url())
train(batch_data, loss_fn, mini_bs=8, gradient_accumulation_steps=None)

Run one backend policy training step with a caller-provided loss function.

Parameters:
  • batch_data (list[TrainSequence]) – Token, mask, logprob, reward, and advantage rows.

  • loss_fn (Callable) – Loss function called by the backend.

  • mini_bs (int) – Backend training microbatch size.

  • gradient_accumulation_steps (int | None) – Optimizer step interval in microbatches.

Returns:

dict[str, float] with scalar training metrics.

loss_fn receives the backend data pack and current logprobs. Built-in loss functions live under areno.loss_fns.

from functools import partial
from areno.loss_fns import gspo_loss_fn

stats = trainer.train(batch, partial(gspo_loss_fn, clip_eps=3.0e-4), mini_bs=4)
ensure_roles(roles)

Prepare backend-owned auxiliary model roles for algorithms like PPO and DPO.

Parameters:

roles (dict[str, ModelRole]) – Role name to model role configuration.

Returns:

None

from areno import ModelRole

trainer.ensure_roles({
    "ref": ModelRole(name="ref", path="/path/to/reference", trainable=False),
    "critic": ModelRole(name="critic", path="/path/to/critic", trainable=True, optimizer_lr=1e-5),
})
score_logprobs(role, token_rows)

Score fixed token sequences with a backend-owned model role.

Parameters:
  • role (str) – Role name, such as ref or actor.

  • token_rows (list[list[int]]) – Token rows to score.

Returns:

list[list[float]]

ref_logprobs = trainer.score_logprobs("ref", token_rows)
score_values(role, token_rows)

Score per-token critic values with a backend-owned model role.

Parameters:
  • role (str) – Role name, such as critic.

  • token_rows (list[list[int]]) – Token rows to score.

Returns:

list[list[float]]

values = trainer.score_values("critic", token_rows)
score_rewards(role, token_rows)

Score sequence rewards with a backend-owned reward model role.

Parameters:
  • role (str) – Role name, such as reward.

  • token_rows (list[list[int]]) – Token rows to score.

Returns:

list[float]

rewards = trainer.score_rewards("reward", token_rows)
train_values(role, batch_data, mini_bs, gradient_accumulation_steps=None, *, cliprange_value=0.5, value_loss_coef=0.5)

Train a backend-owned critic or value role.

Parameters:
  • role (str) – Role name, such as critic.

  • batch_data (list[TrainSequence]) – Training rows.

  • mini_bs (int) – Critic training microbatch size.

  • gradient_accumulation_steps (int | None) – Optimizer step interval in microbatches.

  • cliprange_value (float) – PPO value-function clipping range.

  • value_loss_coef (float) – Value loss coefficient.

Returns:

dict[str, float]

critic_stats = trainer.train_values("critic", batch_data, mini_bs=4)
save_checkpoint(path)

Save a Hugging Face-compatible checkpoint when supported by the backend.

Parameters:

path (str) – Output directory.

Returns:

saved checkpoint path as str.

saved_path = trainer.save_checkpoint("/tmp/areno-step-10")
close()

Release local resources such as metric writers.

Returns:

None

Data classes

class areno.SamplingParams(greedy=False, top_p=1.0, top_k=-1, max_new_tokens=16, max_context_len=None, temperature=1.0, stop=None, stop_token_ids=None, ignore_eos=False, skip_special_tokens=True, max_prompt_len=None)

Generation controls used by rollout APIs.

Parameters:
  • greedy (bool) – Force greedy decoding. Overrides temperature in the backend.

  • top_p (float) – Nucleus sampling threshold.

  • top_k (int) – Top-k sampling threshold. -1 disables top-k filtering.

  • max_new_tokens (int) – Maximum number of generated response tokens.

  • max_context_len (int | None) – Optional total context cap for agentic trajectories. The cap is applied to the prompt plus all generated turns concatenated into the trainable trajectory row.

  • temperature (float) – Sampling temperature.

  • stop (list[str] | None) – Stop strings.

  • stop_token_ids (list[int] | None) – Stop token ids.

  • ignore_eos (bool) – Continue generation without EOS stopping.

  • skip_special_tokens (bool) – Decode helper preference for completions.

  • max_prompt_len (int | None) – Optional prompt length cap.

class areno.TrainSequence(prompt_mask=None, tokens=None, logprobs=None, advantages=None, returns=None, values=None, ref_logprobs=None, reward=0.0, eos_token_id=0)

One rollout sequence converted into a policy-gradient training sample.

Parameters:
  • prompt_mask (list[bool]) – True for prompt or padded positions; losses train on response positions.

  • tokens (list[int]) – Prompt and response token ids.

  • logprobs (list[float]) – Rollout-policy logprobs aligned with tokens.

  • advantages (list[float]) – Per-token advantages.

  • returns (list[float]) – Optional value targets for PPO.

  • values (list[float]) – Optional old value predictions for PPO.

  • ref_logprobs (list[float]) – Optional reference logprobs for KL.

  • reward (float) – Sequence-level reward.

  • eos_token_id (int) – EOS id used for padding backend packs.

class areno.ModelRole(name, path, trainable, optimizer_lr=None)

Auxiliary model role owned by the backend.

Parameters:
  • name (str) – Role name, for example ref, reward, or critic.

  • path (str) – Checkpoint path or Hugging Face repo ID.

  • trainable (bool) – Whether the role has an optimizer.

  • optimizer_lr (float | None) – Optimizer LR for trainable roles.

class areno.ArenoConfig(model_path=None, tp_size=1, dp_size=None, devices=None, dummy_load=False, optimizer=None, runtime=None, max_running_prompts=64, decode_progress_interval_s=10.0)

Backend configuration for the local Areno engine.

Parameters:
  • model_path (str | None) – Optional backend model path override.

  • tp_size (int) – Tensor-parallel size.

  • dp_size (int | None) – Data-parallel size. Defaults to world_size // tp_size.

  • devices (list[int] | None) – Device ids for worker ranks.

  • dummy_load (bool) – Build model without loading checkpoint weights.

  • optimizer (dict | None) – Advanced optimizer config passed to the engine.

  • runtime (dict | None) – Advanced runtime config passed to the engine. Set runtime={"attn_backend": "native"} to run without flash-attn on the areno_accel native compatibility path. AReno also falls back to native attention on flash-attn-unsupported GPUs such as Tesla T4 and prints a warning. The default is "flash" for normal high-throughput training on supported GPUs.

  • max_running_prompts (int) – Concurrent rollout prompt limit.

  • decode_progress_interval_s (float) – Worker decode progress log interval. Logs report per-DP scheduled decode throughput for the current window and include cuda_graph=True when CUDA graph replay was used in that window.

class areno.api.agentic.AgentBatch(records, prompts, input_tokens, n_samples)

Prompt batch expanded into one item per prompt/sample pair for agent execution.

Parameters:
  • records (list[dict]) – Source dataset records.

  • prompts (list[str]) – Prompt strings.

  • input_tokens (list[list[int]]) – Prompt token ids.

  • n_samples (int) – Samples per prompt.

class areno.api.agentic.RewardRecord(...)

Unified reward input for agentic rollouts.

Reward functions receive one RewardRecord per completed trajectory. For multi-turn agents, the record represents one prompt/sample pair, not one HTTP request.

completion contains concatenated assistant response spans for backwards compatibility. final_answer contains the last assistant response. messages is the full OpenAI-style message list, including tool-result messages. rendered_completion is the same trajectory rendered through the tokenizer chat template when available. tool_calls and tool_results expose parsed tool calls and environment observations. tokens, logprobs, and loss_mask describe the model-generated response spans.

Tool-result/context spans are included in train rows so logprob scoring sees the same context as rollout, but they are masked from policy loss by default.

class areno.api.agentic.LossMaskPolicy(assistant_text=True, assistant_tool_calls=True, tool_results=False, final_assistant_text=True, system_prompt=False, user_prompt=False)

Span-level policy-loss controls for agentic trajectories.

Parameters:
  • assistant_text (bool) – Train assistant text spans.

  • assistant_tool_calls (bool) – Train assistant tool-call spans.

  • tool_results (bool) – Train tool-result spans. Defaults to False.

  • final_assistant_text (bool) – Reserved for final-response text spans.

  • system_prompt (bool) – Reserved for system prompt spans.

  • user_prompt (bool) – Reserved for user prompt spans.

One GSPO-style rollout/train step

import asyncio
from functools import partial

from datasets import load_dataset

import areno
from areno import SamplingParams, TrainSequence, Trainer
from areno.loss_fns import gspo_loss_fn


def normalize_rewards(rewards):
    mean = sum(rewards) / len(rewards)
    var = sum((reward - mean) ** 2 for reward in rewards) / max(len(rewards), 1)
    std = max(var ** 0.5, 1e-6)
    return [(reward - mean) / std for reward in rewards]


async def main():
    trainer = Trainer(
        world_size=1,
        model_path="Qwen/Qwen3.5-4B",
        backend_type=areno.Areno,
        custom_config=areno.ArenoConfig(tp_size=1),
    )
    trainer.init()

    row = load_dataset("gsm8k", "main", split="train[0:1]")[0]
    target = str(row["answer"]).rsplit("####", 1)[-1].strip()
    prompt = (
        "Solve the problem and put the final answer in \\boxed{}.\n\n"
        f"Problem: {row['question']}\nSolution:"
    )
    prompt_tokens = trainer.get_tokenizer().encode(prompt)
    sampling = SamplingParams(max_new_tokens=128, temperature=1.0)

    async with trainer.rollout_session(sampling_params=sampling, proxy=False):
        rollout = trainer.rollout_token_batch([prompt_tokens], n_samples=4, sampling_params=sampling)[0]

    completions = [trainer.get_tokenizer().decode(seq.resp_tokens) for seq in rollout.sequences]
    rewards = [1.0 if target in completion else 0.0 for completion in completions]
    advantages = normalize_rewards(rewards)

    batch = []
    for seq, reward, advantage in zip(rollout.sequences, rewards, advantages, strict=True):
        response_len = len(seq.resp_tokens)
        batch.append(
            TrainSequence(
                prompt_mask=[True] * len(prompt_tokens) + [False] * response_len,
                tokens=prompt_tokens + seq.resp_tokens,
                logprobs=[0.0] * len(prompt_tokens) + seq.resp_logprobs,
                advantages=[0.0] * len(prompt_tokens) + [advantage] * response_len,
                reward=reward,
                eos_token_id=trainer.get_tokenizer().eos_token_id,
            )
        )
    )

    stats = trainer.train(batch, partial(gspo_loss_fn, clip_eps=3.0e-4), mini_bs=4)
    print(stats)
    trainer.close()

asyncio.run(main())

Agentic rollout with tools

This example shows the SDK pieces used by --agent-fn. The agent calls a local OpenAI-compatible proxy with Chat Completions tools and returns explicit trajectories. Areno parses supported model-native tool-call output from those responses, then the trainer converts trajectories into the same token, logprob, reward, and loss-mask rows used by regular rollouts.

import asyncio

import areno
from areno import SamplingParams, Trainer
from areno.api.agentic import AgentBatch, AgentTrajectory, AgentTrajectoryTurn
from openai import AsyncOpenAI


tools = [
    {
        "type": "function",
        "function": {
            "name": "choose_move",
            "parameters": {
                "type": "object",
                "properties": {
                    "direction": {"type": "string", "enum": ["up", "down", "left", "right"]},
                },
                "required": ["direction"],
            },
        },
    }
]


async def run_agent(ctx, batch):
    client = AsyncOpenAI(base_url=ctx.get_base_url(), api_key=ctx.api_key, max_retries=0)

    async def run_one(item):
        messages = [
            {"role": "system", "content": "Call choose_move with the selected direction."},
            {"role": "user", "content": item.prompt},
        ]
        tool_choice = {"type": "function", "function": {"name": "choose_move"}}
        response = await client.chat.completions.create(
            model="policy",
            messages=messages,
            tools=tools,
            tool_choice=tool_choice,
            max_tokens=16,
            temperature=0.7,
        )
        return AgentTrajectoryTurn(
            item=item,
            messages=messages,
            response=response,
            tools=tools,
            tool_choice=tool_choice,
        )

    try:
        turns = await asyncio.gather(*(run_one(item) for item in batch.iter_samples()))
        return [AgentTrajectory(turns=[turn]) for turn in turns]
    finally:
        await client.close()


def reward_fn(record):
    if not record.tool_calls:
        return -1.0
    return 1.0


async def collect_agentic_trajectories(trainer, prompt_batch):
    agent_batch = AgentBatch.from_prompt_batch(prompt_batch, n_samples=4)
    async with trainer.rollout_session(
        sampling_params=SamplingParams(max_new_tokens=16, temperature=0.7),
        max_running_prompts=len(agent_batch),
    ) as ctx:
        return await run_agent(ctx, agent_batch)


trainer = Trainer(
    world_size=1,
    model_path="Qwen/Qwen3-0.6B",
    backend_type=areno.Areno,
    custom_config=areno.ArenoConfig(tp_size=1),
)
trainer.init()

# In CLI training, --agent-fn returns these trajectories to the trainer.
# In a custom loop, load a PromptBatch and call:
# trajectories = asyncio.run(collect_agentic_trajectories(trainer, prompt_batch))