Workshop Labs


Post-Training 50x Faster

We're announcing Trellis, the fastest open-source post-training code for Kimi K2 Thinking

Addie Foote, Rudolf Laine, Tim Kostolansky

March 12, 2026


Today we're announcing Trellis-KimiK2T, our training codebase that enables fast LoRA training on Kimi-K2-Thinking.

Trellis trains LoRAs on all parameters [1] at 6,600 tokens per second on a single 8xH200 node. That's 50x faster than the best open-source alternative, more than 2x cheaper than the closest private training API, and the only single-node implementation that trains the experts [2].

Until now, the promise of open-weight frontier models has gone unfulfilled: open weights have not meant open training. We previously shared a patch to Hugging Face’s open-source implementation, but it was slow and buggy, and it failed to train the experts [3]. Meanwhile, fine-tuning Kimi-K2.5 with Nvidia NeMo requires a multi-node setup.

We want to enable anyone to finetune a model, so following safety evaluations, we plan to open-source the codebase.

Comparison

Figure 1: We compare the price of our training stack to Hugging Face, Tinker, and Fireworks. We benchmarked with SFT on a Yoda dataset [5], using the best reasonable setup for each option [4]. The exact costs depend on several factors [6]. For Hugging Face, we show the cost when all sequences have the same length (no padding), its best-case scenario. Note that Hugging Face cannot train all the parameters, so our training results are also better. The Tinker and Fireworks tokens-per-dollar figures reflect their pricing; their true cost may be lower.
Figure 2: We compare the time it takes to process a batch of ~64,000 tokens in Tinker, Trellis, and Hugging Face [7].

A correct implementation

Many of the major problems with Hugging Face, and with open-source libraries generally, stem from trying to do everything. The result is poor feature support, unpredictable interactions between features (e.g. quantization working in general but breaking for specific model classes), and bugs that arise when logic written for one use case is automatically applied to another. To design our post-training framework, we started from PrimeRL for some high-level structure, but removed support for other models and implemented Kimi-K2-Thinking from scratch.

Our first task was implementing the forward pass correctly. That matters a lot here, and an incorrect version is much easier to write than a correct one. To debug, we verified outputs against the Moonshot API and compared intermediate activations with the Hugging Face forward pass. One caveat when verifying correctness: GPU computation is non-deterministic, and because floating-point arithmetic is not associative, logically identical code that performs operations in a different order can produce slightly different results. We expect and allow for these small numerical differences.
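
A tiny plain-Python illustration of why bitwise comparison is the wrong test: regrouping a sum changes the result, so activation checks need a tolerance. The helper name activations_match and its tolerance values are illustrative assumptions, not the thresholds we used.

```python
import math

# Floating-point addition is not associative: grouping changes the result.
left = (0.1 + 0.2) + 0.3
right = 0.1 + (0.2 + 0.3)
print(left == right)  # False
print(left, right)    # 0.6000000000000001 0.6


def activations_match(ours, reference, rel_tol=1e-3, abs_tol=1e-5):
    """Hypothetical helper: elementwise comparison that tolerates small
    numerical differences caused by different operation orders."""
    if len(ours) != len(reference):
        return False
    return all(
        math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol)
        for a, b in zip(ours, reference)
    )


# A bitwise check fails, but the tolerance-based check passes.
print(activations_match([left], [right]))  # True
```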

We heavily referenced Hugging Face’s Kimi modeling file, which implements the forward pass. Still, there were several subtleties.

One subtlety was simply where to use configuration parameters. The model config has a parameter rms_norm_eps, which specifies the epsilon for RMS norm computation. Originally, we used that value for every RMS norm in the model. But we noticed activations diverging after the attention block and found that the correct version uses rms_norm_eps only in the main RMS norms, while the RMS norms inside attention use the default epsilon value.
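
A minimal plain-Python sketch of the distinction, assuming a hypothetical config value of 1e-5 and a default epsilon of 1e-6 (both numbers are placeholders, not Kimi-K2-Thinking’s actual values). The function names main_norm and attention_norm are also hypothetical; the point is only which norms read the config.

```python
import math

DEFAULT_EPS = 1e-6  # assumption: the RMSNorm default epsilon in this sketch


def rms_norm(x, weight, eps=DEFAULT_EPS):
    """RMSNorm over a vector: x / sqrt(mean(x^2) + eps) * weight."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


# Hypothetical config value; only the main (per-layer) norms read it.
config = {"rms_norm_eps": 1e-5}


def main_norm(x, weight):
    return rms_norm(x, weight, eps=config["rms_norm_eps"])


def attention_norm(x, weight):
    # Norms inside attention keep the default epsilon, not config["rms_norm_eps"].
    return rms_norm(x, weight)


# With small activations, mean(x^2) is comparable to eps, so the two choices diverge.
x, w = [1e-3, -1e-3], [1.0, 1.0]
print(main_norm(x, w)[0], attention_norm(x, w)[0])
```

When activations are large, eps barely matters, which is why a wrong epsilon can look fine at first and only shows up as a divergence in specific intermediate activations.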

Another major complication arose when loading the model. PrimeRL uses FSDP, a common distributed training method, and its implementation relies on PyTorch’s FSDP functions, which internally expect parameters to have a floating-point type. But our quantized parameters are int4s, packed eight to an int32. The most obvious workaround is to reinterpret them as a float type, then reinterpret them back as ints before use. But of course, there’s a catch.

Take bfloat16: it consists of 16 bits, 1 sign bit followed by 8 exponent bits and 7 mantissa bits. For normal numbers, the value it represents is (-1)^sign × 2^(exponent − 127) × 1.mantissa. There are also special values, though, including inf, -inf, and NaN. A bfloat16 is NaN if exponent == 255 and mantissa != 0, so many distinct bit patterns all represent NaN. When a program encounters a NaN, it may ‘canonicalize’ it: rewrite its bits so that all NaNs share the same canonical bit pattern. This can save some computation down the line, but it’s a big problem for us! We wanted to round-trip int32 → bfloat16 → int32, but whenever the data passed through a float type, some bits could change.
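
The failure can be simulated with integer bit manipulation alone. In this sketch, canonicalize is a stand-in for whatever operation rewrites NaNs, and the canonical pattern 0x7FC0 is an assumption chosen for illustration. It counts the NaN bit patterns and shows a packed word whose low half happens to look like a NaN.

```python
BF16_EXP_MASK = 0x7F80       # the 8 exponent bits
BF16_MANT_MASK = 0x007F      # the 7 mantissa bits
CANONICAL_BF16_NAN = 0x7FC0  # assumption: canonical quiet-NaN pattern in this sketch


def is_bf16_nan(bits16):
    """NaN: exponent bits all ones and a nonzero mantissa."""
    return (bits16 & BF16_EXP_MASK) == BF16_EXP_MASK and (bits16 & BF16_MANT_MASK) != 0


def canonicalize(bits16):
    """Simulates an operation that rewrites every NaN to one canonical pattern."""
    return CANONICAL_BF16_NAN if is_bf16_nan(bits16) else bits16


def roundtrip_int32_via_bf16(word):
    """View a 32-bit word as two bf16 halves, pass them through a float op, view back."""
    lo, hi = word & 0xFFFF, (word >> 16) & 0xFFFF
    return (canonicalize(hi) << 16) | canonicalize(lo)


# How many of the 65,536 bf16 bit patterns are NaN?
nan_patterns = sum(is_bf16_nan(b) for b in range(1 << 16))
print(nan_patterns)  # 254

# Eight packed int4 weights whose low 16 bits happen to form a NaN pattern.
packed = 0x1234_7F81
print(hex(roundtrip_int32_via_bf16(packed)))  # 0x12347fc0: the weights changed
```

Only a small fraction of bit patterns are NaNs, but across hundreds of gigabytes of packed weights some are bound to appear, so the corruption is silent but near certain.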

We patched requires_grad to accept integer parameters, sidestepping float canonicalization entirely. Later, we implemented expert parallelism, which doesn’t rely on PyTorch’s FSDP functions.

Once we got a working implementation, we focused on optimizing it.

Background

Kimi Specs

Kimi-K2-Thinking is a 1-trillion-parameter MoE. It has 61 layers (1 dense layer followed by 60 MoE layers), with 384 routed experts plus one shared expert per MoE layer, and it uses multi-head latent attention (MLA). The weights total 594 GB: the 570 GB of routed experts are INT4-quantized, and everything else (24 GB combined) is in bfloat16.

Hardware requirements

Given the nearly 600 GB of model weights, we use 8xH200s, with a combined ~1.1 TB (8 × 141 GB) of GPU memory.

Distributed training

With a ~600 GB model, which no single GPU can hold, we need distributed training. The alternative, offloading weights to CPU memory or disk, is obnoxiously slow. Distributed training methods determine how weights and computation are partitioned across GPUs, and what is communicated between them.

Figure 3: Pipeline Not-Parallel. Each GPU holds some layers, which are run sequentially on some data. After the layers on one GPU finish, the activations are sent to the next GPU and the forward pass proceeds.

Our previous Hugging Face solution used what we call “Pipeline Not-Parallel”: it’s like pipeline parallelism, but not parallel (Figure 3). It’s very inefficient: only one GPU is active at a time, and all of them sit idle while activations are transferred between GPUs.

Figure 4: FSDP. Each GPU holds a shard of each parameter, and different data starts on each GPU. For each parameter in the forward pass, each GPU sends its shard to all the other GPUs and receives their shards to reconstruct the full parameter, then runs the forward pass with it on its local data.

Our initial implementation used FSDP (fully sharded data parallel), explained in Figure 4. FSDP often exacerbates inefficiencies with MoEs. Each expert processes relatively few tokens, resulting in small matrix multiplies with low arithmetic intensity, which makes expert computation memory-bandwidth-bound. On 8 GPUs, FSDP effectively multiplies the memory bandwidth pressure by 8, since every GPU reads all the weights from HBM during the forward pass rather than a subset. This causes a substantial slowdown whenever computation is memory-bandwidth-bound, which is to say almost always. To resolve this, we use expert parallelism.
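
A back-of-the-envelope sketch of the per-GPU HBM traffic for expert weights in one forward pass, using the 570 GB figure from the specs above. The 4.8 TB/s per-GPU bandwidth is our assumption for the H200, and the estimate ignores caching, compute, and communication; it only illustrates the 8x ratio.

```python
EXPERT_GB = 570      # INT4 expert weights (from the specs above)
NUM_GPUS = 8
HBM_TB_PER_S = 4.8   # assumption: H200 HBM bandwidth per GPU


def expert_bytes_read_per_gpu_gb(strategy):
    """GB of expert weights each GPU reads from HBM in one forward pass."""
    if strategy == "fsdp":
        # Every GPU gathers and reads every expert.
        return EXPERT_GB
    if strategy == "expert_parallel":
        # Every GPU reads only the experts it owns.
        return EXPERT_GB / NUM_GPUS
    raise ValueError(f"unknown strategy: {strategy}")


for strategy in ("fsdp", "expert_parallel"):
    gb = expert_bytes_read_per_gpu_gb(strategy)
    min_ms = gb / (HBM_TB_PER_S * 1000) * 1000  # lower bound just to stream the weights
    print(f"{strategy}: {gb:.2f} GB per GPU, at least {min_ms:.0f} ms of HBM reads")
```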

Figure 5: Expert parallel. Each GPU holds all non-expert parameters and a fraction of the experts in each layer. Different data starts on each GPU, and the forward pass for non-expert parameters happens on each GPU. At each expert layer, each token’s activation is sent to the GPU(s) holding the experts it is routed to. The expert forward pass is computed, then each token’s result is sent back to the GPU it started on, and the forward pass continues.

Our optimized codebase uses expert parallelism, explained in Figure 5. Expert parallelism applies only to the experts and is agnostic to which distributed training method, if any, is used for other parameters.

Expert Parallel

Background: All to all

Expert parallelism has to send tokens between ranks (one process per GPU). The communication operation used is all-to-all, in which each rank organizes its data into chunks, one per destination rank, and likewise receives a chunk from every rank.

Figure 6: All to all. Each rank (a process on a GPU) sends data to all other ranks and receives data from all other ranks. Each rank constructs a list of chunks of data destined for the other ranks, along with an empty buffer that is filled during the all-to-all with the data sent by the other ranks.

PyTorch’s torch.distributed.all_to_all isn’t differentiable, so we implement a differentiable all-to-all. The gradient of an all-to-all is itself an all-to-all: on the backward pass, we simply send the gradients from the receiving ranks back to the original senders.
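
To see why the backward pass is another all-to-all, it helps to model the collective as pure data movement. This single-process simulation is our own stand-in, not torch.distributed’s API: send[src][dst] is the chunk rank src sends to rank dst. The operation transposes that table, so routing gradients back along the reversed paths is the same operation.

```python
def all_to_all(send):
    """Simulated all-to-all: send[src][dst] is the chunk rank src sends to rank dst.
    Returns recv, where recv[dst][src] is the chunk rank dst received from rank src."""
    world = len(send)
    assert all(len(row) == world for row in send), "each rank needs one chunk per rank"
    return [[send[src][dst] for src in range(world)] for dst in range(world)]


def all_to_all_backward(grad_recv):
    """Gradient of all_to_all: each received chunk's gradient goes back to its sender,
    which is just another all_to_all on the gradient table."""
    return all_to_all(grad_recv)


# Three ranks, each with one (variable-length) chunk per destination rank.
send = [
    [["r0->r0"], ["r0->r1", "r0->r1"], []],
    [["r1->r0"], [], ["r1->r2"]],
    [[], ["r2->r1"], ["r2->r2", "r2->r2"]],
]
recv = all_to_all(send)
print(recv[1])  # [['r0->r1', 'r0->r1'], [], ['r2->r1']]

# Gradients shaped like recv flow back into exactly the layout of send.
assert all_to_all_backward(recv) == send
```

In PyTorch, one common way to get this is a torch.autograd.Function whose forward calls the collective and whose backward calls it again with the input and output split sizes swapped.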

Implementation

First, we need to specify which GPU(s) each parameter goes on. In our implementation, all non-expert parameters are replicated on every GPU. The experts are split into slices distributed across GPUs, so with 384 experts per layer and 8 GPUs, each GPU holds a slice of 48 experts per layer.
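
The placement arithmetic, using the sizes from the specs above. The 141 GB per-GPU capacity is our assumption for the H200, and the contiguous slicing in owner_rank (a hypothetical helper) is one simple assignment, not necessarily the one Trellis uses.

```python
NUM_EXPERTS_PER_LAYER = 384
NUM_GPUS = 8
EXPERT_WEIGHTS_GB = 570     # INT4 routed experts, sharded across GPUs
NON_EXPERT_WEIGHTS_GB = 24  # bfloat16, replicated on every GPU
GPU_MEMORY_GB = 141         # assumption: H200 HBM capacity

experts_per_gpu = NUM_EXPERTS_PER_LAYER // NUM_GPUS
weights_per_gpu_gb = EXPERT_WEIGHTS_GB / NUM_GPUS + NON_EXPERT_WEIGHTS_GB
# What remains for activations, LoRA weights, gradients, and optimizer state.
headroom_gb = GPU_MEMORY_GB - weights_per_gpu_gb


def owner_rank(expert_id):
    """Hypothetical contiguous slicing: rank r owns experts [48 * r, 48 * (r + 1))."""
    return expert_id // experts_per_gpu


print(experts_per_gpu)                 # 48
print(weights_per_gpu_gb)              # 95.25
print(headroom_gb)                     # 45.75
print(owner_rank(0), owner_rank(383))  # 0 7
```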

Second, we need an expert-parallel implementation of the expert forward pass that handles the communication between GPUs.

To compute the expert MLP without expert parallelism, we start with the token activations, apply the router, apply the experts to the tokens routed to them, then aggregate the results [8].

To modify this for expert parallelism, we add two all-to-all communication steps (built on torch.distributed.all_to_all), before and after computing the experts. Before, each rank sends its token activations to the ranks holding the corresponding experts, and at the same time receives tokens from other ranks destined for its own experts. After, each rank sends the results back to the ranks that own the tokens and receives the results for its own tokens.

In short, each rank owns some token activations, applies the router, temporarily sends its token activations to other ranks to apply experts, then receives and aggregates the results.
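
A single-process simulation of that dispatch, compute, and combine flow, checked against the non-parallel computation. Everything here is a toy stand-in: activations are scalars, expert and router are hypothetical functions, all_to_all is simulated, and the shared expert (which runs locally on every rank) is omitted. Sending the routing weight along with each token is our layout choice for brevity.

```python
import math

WORLD = 2  # number of simulated ranks (GPUs)
NUM_EXPERTS = 4
EXPERTS_PER_RANK = NUM_EXPERTS // WORLD


def expert(expert_id, x):
    # Hypothetical toy expert: a different scalar function per expert (stands in for an MLP).
    return (expert_id + 1) * x


def router(x):
    # Hypothetical deterministic top-2 router returning (expert_id, weight) pairs.
    first = int(abs(x) * 10) % NUM_EXPERTS
    return [(first, 0.75), ((first + 1) % NUM_EXPERTS, 0.25)]


def moe_reference(tokens):
    # No parallelism: route, apply experts, take the weighted sum per token.
    return [sum(w * expert(e, x) for e, w in router(x)) for x in tokens]


def all_to_all(send):
    # Simulated collective: recv[dst][src] = send[src][dst].
    world = len(send)
    return [[send[src][dst] for src in range(world)] for dst in range(world)]


def moe_expert_parallel(tokens_per_rank):
    # 1. Each rank routes its own tokens and buckets them by the rank owning each expert.
    send = [[[] for _ in range(WORLD)] for _ in range(WORLD)]
    for rank, tokens in enumerate(tokens_per_rank):
        for tok_idx, x in enumerate(tokens):
            for e, w in router(x):
                send[rank][e // EXPERTS_PER_RANK].append((tok_idx, e, w, x))
    # 2. Dispatch: after the all-to-all, each rank holds the tokens for its local experts.
    recv = all_to_all(send)
    # 3. Each rank applies its local experts, keeping results grouped by source rank.
    results = [
        [[(tok_idx, w * expert(e, x)) for tok_idx, e, w, x in chunk] for chunk in recv[rank]]
        for rank in range(WORLD)
    ]
    # 4. Combine: a second all-to-all returns each result to the rank that owns the token.
    returned = all_to_all(results)
    outputs = []
    for rank, tokens in enumerate(tokens_per_rank):
        out = [0.0] * len(tokens)
        for chunk in returned[rank]:
            for tok_idx, y in chunk:
                out[tok_idx] += y  # weighted sum over the token's selected experts
        outputs.append(out)
    return outputs


tokens_per_rank = [[0.1, 0.25, -0.37], [0.42, 0.9]]
ep_outputs = moe_expert_parallel(tokens_per_rank)
for tokens, out in zip(tokens_per_rank, ep_outputs):
    assert all(math.isclose(a, b) for a, b in zip(out, moe_reference(tokens)))
```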

Syncing gradients/weights

With expert parallelism, we’ve duplicated the non-expert parameters across all GPUs. But each GPU’s copy of those parameters sees different data and computes different gradients! By default, the copies of what should be the same parameter would diverge. To fix this, we share the gradients across ranks, average them, and update every copy of the parameters with that average.

Seems simple, but there’s still a problem: naively, this weights all ranks equally. Often, the ranks have processed unequal numbers of tokens, and the gradients should be weighted by the number of tokens, not uniformly across ranks! [9]
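
A toy numeric example of the difference, with scalar per-token gradients on two ranks. The function names are hypothetical. In a real distributed setting, one way to compute the weighted version is to all-reduce the sum of count × mean gradient and the sum of counts separately, then divide; that is our suggestion, not necessarily how Trellis does it.

```python
def uniform_average(grads):
    """Naive sync: every rank counts equally."""
    return sum(grads) / len(grads)


def token_weighted_average(grads, token_counts):
    """Weight each rank's mean gradient by the number of tokens that produced it."""
    total = sum(token_counts)
    return sum(g * n for g, n in zip(grads, token_counts)) / total


# Per-token gradients (toy scalars) on two ranks with unequal token counts.
rank_tokens = [[1.0, 1.0, 1.0], [4.0]]
rank_mean_grads = [sum(t) / len(t) for t in rank_tokens]  # [1.0, 4.0]
counts = [len(t) for t in rank_tokens]                     # [3, 1]

# The quantity we actually want: the mean over all tokens in the global batch.
global_mean = sum(sum(t) for t in rank_tokens) / sum(counts)  # 7 / 4

print(uniform_average(rank_mean_grads))                 # 2.5
print(token_weighted_average(rank_mean_grads, counts))  # 1.75
```

Only the token-weighted version matches global_mean, which is what a single GPU processing the whole batch would compute.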

In theory, that fixes the problem. In practice, the copies still diverge after many updates, likely due to numerical differences. Since syncing is very cheap, we can copy the updated parameters from one rank to all the others.

Optimizations

Before any optimizations, it takes 17.61s per step with batch size 8 and sequence length 2048.

For comparison, the Hugging Face version took 51s per step and also failed to train the experts. Below, we walk through the optimizations and benchmark the time each constant-size step takes. Some optimizations also reduce memory, which frees room for larger batches and can increase throughput beyond what the step time suggests!

Grouped mm

The simplest way to do the expert MLP forward pass is to loop through the experts, multiplying each one by the token activations routed to it. A more efficient way is to stack the experts and apply them all at once with a single grouped matrix multiplication.
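
A pure-Python sketch of the data layout a grouped matmul needs: tokens sorted so each expert’s tokens are contiguous, plus cumulative offsets marking each expert’s group. grouped_matmul here is a reference loop, not a real kernel; a GPU implementation would process all groups in one launch. Top-1 routing and the toy 2x2 experts are simplifying assumptions.

```python
import itertools


def matvec(w, x):
    """Multiply a weight matrix (list of rows) by a vector."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def experts_looped(weights, tokens, assignments):
    """Baseline: loop over experts, one small multiply per routed token."""
    out = [None] * len(tokens)
    for e, w in enumerate(weights):
        for i, a in enumerate(assignments):
            if a == e:
                out[i] = matvec(w, tokens[i])
    return out


def grouped_matmul(stacked_weights, sorted_tokens, offsets):
    """Reference semantics of a grouped GEMM: rows [offsets[e - 1], offsets[e]) of
    sorted_tokens use stacked_weights[e]. A real kernel handles every group in one
    launch; this pure-Python loop only illustrates the expected data layout."""
    out, start = [], 0
    for e, end in enumerate(offsets):
        out.extend(matvec(stacked_weights[e], x) for x in sorted_tokens[start:end])
        start = end
    return out


def experts_grouped(stacked_weights, tokens, assignments):
    # Sort token indices by expert so each expert's tokens are contiguous.
    order = sorted(range(len(tokens)), key=lambda i: assignments[i])
    counts = [0] * len(stacked_weights)
    for a in assignments:
        counts[a] += 1
    offsets = list(itertools.accumulate(counts))  # end index of each expert's group
    sorted_out = grouped_matmul(stacked_weights, [tokens[i] for i in order], offsets)
    # Scatter results back to the original token order.
    out = [None] * len(tokens)
    for pos, i in enumerate(order):
        out[i] = sorted_out[pos]
    return out


weights = [[[1, 0], [0, 1]], [[2, 0], [0, 2]], [[0, 1], [1, 0]]]  # 3 toy experts
tokens = [[1, 2], [3, 4], [5, 6], [7, 8]]
assignments = [2, 0, 2, 1]  # top-1 expert per token, for simplicity
print(experts_grouped(weights, tokens, assignments))  # [[2, 1], [3, 4], [6, 5], [14, 16]]
```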

Current performance: 13.66s per step

Dequantize

Profiling showed that dequantizing the weights was taking over 80% of the total training time. Dequantization consists of two functions, unpack_from_int32 and dequantize, and initially unpack_from_int32 took the majority of the time.

How it works.

There are two parts to dequantization.

unpack_from_int32 takes the tensor of int32s that the int4 weights are packed into and expands it into separate int8s, the smallest standard integer type. Say we’re unpacking a single expert: the input is a packed tensor of shape (out_dim, (in_dim + 7) // 8), since there are 8 int4s per int32, and the output is an int8 tensor of shape (out_dim, in_dim). The compressed_tensors library implements this by first allocating a tensor of (out_dim, in_dim) int32s, then looping 8 times, each time extracting the next 4-bit number from every packed int32.

dequantize multiplies all of the unpacked int4 values by the corresponding scale factor.
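
A plain-Python sketch of both parts for a single row. The encoding details are assumptions for illustration, not necessarily compressed_tensors’ exact format: nibbles are stored lowest-first, signed int4 values are stored as value + 8, and each group of group_size input columns shares one scale. unpack_row_looped mirrors the loop-per-nibble structure described above.

```python
NUM_BITS = 4
PER_WORD = 32 // NUM_BITS     # 8 int4 values per packed int32
MASK = (1 << NUM_BITS) - 1    # 0xF
OFFSET = 1 << (NUM_BITS - 1)  # assumption: signed int4 stored as (value + 8)


def pack_row(values):
    """Pack signed int4 values (-8..7) into 32-bit words, lowest nibble first."""
    words = []
    for start in range(0, len(values), PER_WORD):
        word = 0
        for i, v in enumerate(values[start:start + PER_WORD]):
            word |= ((v + OFFSET) & MASK) << (NUM_BITS * i)
        words.append(word)
    return words


def unpack_row_looped(words, in_dim):
    """Loop-per-nibble unpacking: pass i pulls nibble i out of every packed word."""
    out = [0] * (len(words) * PER_WORD)
    for i in range(PER_WORD):
        for w_idx, word in enumerate(words):
            out[w_idx * PER_WORD + i] = ((word >> (NUM_BITS * i)) & MASK) - OFFSET
    return out[:in_dim]  # drop the padding nibbles of the last partial word


def dequantize_row(int_values, scales, group_size):
    """Multiply each int4 value by the scale of the group of columns it belongs to."""
    return [v * scales[j // group_size] for j, v in enumerate(int_values)]


row = [-8, -1, 0, 1, 7, 3, -4, 2, 5, -6]  # in_dim = 10
words = pack_row(row)
print(len(words))  # 2, i.e. (10 + 7) // 8
assert unpack_row_looped(words, len(row)) == row
print(dequantize_row(row, scales=[0.5, 0.25], group_size=5))
```

The mask after each shift matters in a real int32 tensor: right-shifting a negative signed word sign-extends, and masking with 0xF discards those extra bits.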

Reshape rather than iterate

We need to dequantize all the experts on each GPU for each layer. Initially, we iterated through the experts, calling unpack and dequantize on each separately. To speed it up, we can reshape all experts into one tensor and dequantize them all at once.

Current performance: 11.54s per step

Vectorize unpack_from_int32

Still, unpacking alone was taking over half the total training time. Digging into the stack, we found the compressed_tensors unpack_from_int32 function at the bottom of it. It loops eight times, extracting one of the eight int4s from every int32 on each iteration. We vectorized this to extract them all at once.

Current performance: 8.84s per step

@torch.compile

Compiling applies extra optimizations, such as fusing kernels into fewer, bigger ones, which reduces kernel launch overhead and keeps intermediates in registers or shared memory instead of writing them to and re-reading them from HBM. It can also eliminate redundant code and pick better memory layouts.

Current performance: 4.93s per step

Skipping unused tokens

During fine-tuning, sequences often end with padding tokens whose outputs are masked out of the loss and whose activations are irrelevant. Yet they still incur all-to-all communication and expert compute costs. By identifying trailing pad tokens and excluding them from these operations, we can speed up fine-tuning by more than 2x.
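
A minimal sketch of finding which token positions still need dispatch, assuming a hypothetical pad id of -1 and right-padded sequences. With causal attention, trailing pads cannot influence earlier tokens, so dropping them leaves the real tokens’ outputs unchanged.

```python
PAD_ID = -1  # hypothetical pad token id


def real_lengths(batch, pad_id=PAD_ID):
    """Length of each sequence after stripping its trailing pad tokens."""
    lengths = []
    for seq in batch:
        n = len(seq)
        while n > 0 and seq[n - 1] == pad_id:
            n -= 1
        lengths.append(n)
    return lengths


def tokens_to_dispatch(batch, pad_id=PAD_ID):
    """(sequence, position) pairs that still need all-to-all and expert compute."""
    return [(b, t) for b, n in enumerate(real_lengths(batch, pad_id)) for t in range(n)]


batch = [
    [5, 9, 2, 7, -1, -1, -1, -1],
    [3, 4, 8, 6, 1, 2, -1, -1],
]
kept = tokens_to_dispatch(batch)
print(real_lengths(batch))                    # [4, 6]
print(len(kept), "of", sum(map(len, batch)))  # 10 of 16
```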

Current performance: 2.86s per step

Packing

Going a step further: instead of having so many pad tokens, what if we packed multiple sequences into one? There’s one complication: tokens shouldn’t attend to earlier sequences packed into the same row. Luckily, there’s already a kernel that handles this, and we can use it in our code!

This doesn’t speed up each step; instead, it vastly increases the number of tokens per step. In the current setup, it increases effective tokens per step by nearly 10x but slows each step down ~2-3x, a net gain of roughly 3-5x. The net speedup depends on the average length of your sequences relative to the maximum sequence length.
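
A sketch of packing and of the boundary information an attention kernel needs. Variable-length attention kernels, such as FlashAttention’s varlen interface, take cumulative sequence boundaries like cu_seqlens; we are not claiming this is the exact kernel referred to above. Greedy first-fit-decreasing packing is our assumption, not necessarily Trellis’s packing policy, and may_attend just spells out the intended mask.

```python
def pack_sequences(seqs, max_len):
    """Greedy first-fit-decreasing packing of token-id lists into rows of <= max_len.
    For each row, returns the tokens, the cumulative sequence boundaries (cu_seqlens),
    and position ids that restart at 0 for every packed sequence."""
    rows = []
    for seq in sorted(seqs, key=len, reverse=True):
        if len(seq) > max_len:
            raise ValueError("sequence longer than max_len")
        for row in rows:
            if sum(map(len, row)) + len(seq) <= max_len:
                row.append(seq)
                break
        else:  # no existing row had room
            rows.append([seq])

    packed = []
    for row in rows:
        tokens, cu_seqlens, position_ids = [], [0], []
        for seq in row:
            tokens.extend(seq)
            cu_seqlens.append(cu_seqlens[-1] + len(seq))
            position_ids.extend(range(len(seq)))
        packed.append({"tokens": tokens, "cu_seqlens": cu_seqlens, "position_ids": position_ids})
    return packed


def may_attend(cu_seqlens, query_pos, key_pos):
    """Causal attention restricted to the query's own sequence within a packed row."""
    for start, end in zip(cu_seqlens, cu_seqlens[1:]):
        if start <= query_pos < end:
            return start <= key_pos <= query_pos
    return False


seqs = [[1] * 5, [2] * 3, [3] * 4, [4] * 2, [5] * 6]
packed = pack_sequences(seqs, max_len=8)
print(len(packed))                # 3 rows instead of 5 padded ones
print(packed[0]["cu_seqlens"])    # [0, 6, 8]
print(packed[1]["position_ids"])  # [0, 1, 2, 3, 4, 0, 1, 2]
print(may_attend(packed[0]["cu_seqlens"], query_pos=7, key_pos=5))  # False
```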

After applying all these, we now achieve a throughput of 6,600 tokens per second, more than 50x what the Hugging Face implementation achieved.

What’s next?

Post-training trillion-parameter models isn’t easy, but it’s a necessary step on the path to our vision of the future. To enable people to leverage their data without the risk of third parties storing or using it, we need our own training stack. To build truly user-aligned models, memory, prompting, and context are insufficient; we’ll have to make weight updates. To create models that enhance what people can do and work alongside them, rather than instead of them, we need to post-train frontier models.

In a month and a half, we built a training stack for Kimi-K2-Thinking that beats every alternative available to us. We’ve optimized our throughput from 125 to over 6,600 tokens per second on 8xH200s, and we use this stack internally at half the cost per token of the best available APIs.

We want to enable everyone to finetune a model, so following safety evaluations, we plan to open-source the codebase.

In the meantime, we’re developing our product that lets anyone generate a custom model that works, writes, and thinks more like them. If you’re interested, sign up for our waitlist here.


We thank Jessica Li, Will Brown, Luke Drago, Herbie Bradley, Cody Rushing, Oscar Moxon, Alex DeNuzzo, Daniel McCann-Sayles, and Devansh Pandey for reviewing drafts of this post.

Footnotes

  1. We don’t train the routers, for stability reasons, though the training codebase could.

  2. NeMo’s Kimi-K2.5 recipe is set up for a 256-GPU multi-node setup. Because it’s designed for full fine-tuning rather than LoRA, NeMo dequantizes all weights. Even with modifications, it needs more memory than 8xH200s have.

  3. The best open-source alternative was Hugging Face, which we had previously debugged extensively. Still, it only trains the non-quantized weights (attention and shared experts), not the majority of the weights (the standard experts).

  4. For Hugging Face, we used sequence length 512, padding each sequence up to that length. For Trellis, we used packing, with each batch consisting of eight packed 8,192-token sequences. This resulted in ~400 actual sequences and ~64k real tokens per batch.

  5. The custom Yoda dataset consists of trivia questions and responses in the voice of Yoda generated via an LLM API. The sequences had an average of 160 tokens and a max of 364.

  6. Some factors include the hardware used, sequence length, batch size, and whether the sequence length is fixed or variable (shown in Figure 1). We used 8xH200s for everything, though 8xB200s would likely be cheaper per token. We chose a sequence length appropriate for the dataset, 512, and increased the batch size as much as possible for each option. With variable sequence lengths, the cost depends on how much the lengths vary. With Hugging Face, the cost per token is generally inversely proportional to the ratio of average tokens per sequence to sequence length, whereas with our setup the cost rises more slowly than that as the ratio falls. Tinker and Fireworks manage sequence length themselves and charge uniformly in both cases. For constant sequence length, we used a dataset of random tokens, with every sequence exactly 512 tokens and no padding. More specifics on the step time, tokens per step, etc. here.

  7. The batch size is the number of tokens that fit in one batch with Trellis, but this batch size is somewhat arbitrary (it can be increased with gradient accumulation, or decreased). The batch size does affect the times, and sometimes the relative times. At higher and lower batch sizes, Tinker’s speed stays roughly constant, whereas because they run on a single cluster, Trellis and Hugging Face slow down as batch size increases, and there is a maximum batch size set by the memory available on the cluster.

  8. By default, the forward pass logic for the experts is something like this:

    def expert_moe(token_activations):
        selected_experts, expert_weights = router(token_activations)  # 1. get the selected experts and weights for each token
        expert_output = experts(selected_experts, token_activations)  # 2. get the output of the experts for the tokens routed to it
        results_by_token = sum_expert_outputs_by_token(expert_output, expert_weights)  # 3. get the weighted sum of the expert outputs for each token
        shared_out = shared_experts(token_activations)  # 4. apply shared experts for all tokens
        output = results_by_token + shared_out  # 5. add experts and shared expert results
        return output
    
  9. This mirrors a common bug in gradient accumulation.
