
Teaching LLMs to spell with token healing

We used token healing to help our autocomplete suggest better code.


Why can’t LLMs spell?

You’ve never seen a typo in ChatGPT. So why is this a problem in autocomplete?

This is because LLMs have been trained on complete sequences of text. In high quality training data like books or production codebases, partially typed or misspelled words are uncommon.

However, when we use LLMs for code completion, the codebase is often in a partial state because the developer is still typing and expecting a suggestion.

Consider this example where the developer has just typed Nod. They expect the autocomplete to add an e at their cursor (marked by █), but the autocomplete actually adds a :, which is completely wrong.

Expected:
class Node:
    ...
 
    def get_node(self, value) -> Nod█e
 
Actual:
class Node:
    ...
 
    def get_node(self, value) -> Nod█:

This happens because the string Nod is tokenized into two separate tokens, N and od. The model only sees the token IDs, and the token sequence N od e is not something it has learned to predict from the training data. Node is a separate token, which confuses the LLM, causing it to incorrectly predict :.
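You can see the mismatch yourself with any off-the-shelf BPE tokenizer. This is a quick illustration, not our production setup; the exact splits depend on the model’s vocabulary:

# Illustration only: the exact splits depend on the tokenizer's vocabulary.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # any BPE tokenizer works for this demo
print(tok.tokenize("Node"))  # a common identifier, typically kept as a single token
print(tok.tokenize("Nod"))   # the truncated word typically splits into smaller pieces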

Compare the tokenization we’d need to generate in autocomplete with a similar piece of training data:

Tokenization in autocomplete (the prefix Nod is the tokens N + od, and the model must add e):
class Node: ... def get_node(self, value) -> Node

Tokenization in training data (Node is a single token):
class Node: ... def get_node(self, value) -> Node

In the training data (second example), the LLM simply repeats the token Node, which appears earlier in the context. But when the developer types Nod, they inadvertently break this single common token into the rarely seen sequence N + od + e, a pattern the model has never learned to complete correctly.

Why character-level models don’t work

Wouldn’t this problem just go away if we didn’t tokenize the inputs, and trained LLMs to generate one character at a time?

Modern LLMs could do this, but it would slow down inference by an unacceptable amount. (In a previous blog post, we covered how we generate suggestions in under 100ms.)

A 7B model on an H100 with FP8 quantization can process ~32k input tokens, which translates to roughly 128k characters (at ~4 characters per token).

If we trained the same model with character-level tokens, we’d be limited to just 32k characters, a quarter of the context. Worse yet, inference would be 4x slower for the same text, making character-level models impractical for real-world applications.


An unfortunate consequence of tokenization is that models struggle to choose the right completion when they see partial words.

Fine-tuning isn’t ideal either

One possible solution is to fine-tune the model to respect the prefix. We can use supervised fine-tuning to make the model prefer respecting the prefix, or use algorithms like direct preference optimization (DPO) to penalize the model when it doesn’t respect the prefix.

Another training-based approach is subword regularization: randomly using different tokenizations during training to expose the model to more token boundaries. For example, “Node” could be tokenized as ["Node"], ["N", "ode"], or ["No", "de"] during different training steps. While this helps the model handle partial words better, it comes with significant drawbacks.
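For illustration, here is roughly what sampling a tokenization looks like with a SentencePiece tokenizer (the model file name is hypothetical, and this is not our training setup):

# Sketch of subword regularization sampling with SentencePiece (model path is hypothetical).
import sentencepiece as spm

sp = spm.SentencePieceProcessor(model_file="tokenizer.model")
for _ in range(3):
    # enable_sampling draws a random segmentation instead of the single best one
    print(sp.encode("Node", out_type=str, enable_sampling=True, alpha=0.1, nbest_size=-1))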

There are several reasons why training-based solutions fall short:

  1. Fine-tuning is not a deterministic solution. To completely fix this, we would have to teach the model every possible prefix of each token.
  2. Similarly, subword regularization slows down training and can hurt model performance, as the model has to learn multiple representations for the same concepts.
  3. Fine-tuning for this use case trades off against the model’s performance on other important use cases (we also care about rewriting code in next-edit autocomplete, which does not require respecting the prefix).

Most importantly, fixing this at inference time is actually easier than it sounds!

Token healing

One approach is to backtrack to the last “complete” tokenization point. Instead of trying to generate from the problematic N + od sequence, we remove these incomplete tokens and let the model generate from a clean state:

# Remove "Nod" and generate from the last complete token "->"
def get_node(self, value) ->█

But this creates a new problem: without the partial word as context, the model doesn’t know what the developer intended to type. When we remove “Nod” and generate from ->, the model suggests “None” instead of “Node”:

class Node: ... def get_node(self, value) -> None

You may be familiar with token healing, which handles this problem for a single token.

Consider the following example from Guidance. We want to generate a valid URL after the user has typed http:, but the problem is that the token : followed by // is not common in the training data, as it usually appears as ://.

The url of Google is http:█

At inference time, the LLM gets confused and generates " " after the :, producing http:  (an invalid URL).

To fix this, we can remove : from the end of the string and then force the LLM to generate tokens that start with :.

This is the algorithm for single-token healing:

  1. Remove the final (nth) token and call its string s.
  2. Constrain the next generated token to be sampled only from tokens that begin with s.

So if the nth token is :, the next token can only be chosen from the subset of tokens beginning with : (for example :, ://, and : followed by a space). This generates http://www.google.com.
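A minimal sketch of single-token healing (the function names here are illustrative, not Guidance’s API):

# Sketch of single-token healing; names are illustrative.
def heal_last_token(token_ids, tokenizer):
    """Drop the final token and return the string that the next token must extend."""
    healed_ids = token_ids[:-1]
    forced_prefix = tokenizer.decode(token_ids[-1:])
    return healed_ids, forced_prefix

def allowed_next_tokens(forced_prefix, vocab):
    """Tokens the sampler may pick: anything that starts with the removed string."""
    return [token for token in vocab if token.startswith(forced_prefix)]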

This helps in some cases but wouldn’t completely solve our N od e example: we’d only remove od and still generate from the lone token N, which is an irregular token boundary. We need to go back multiple tokens.

Multi-token healing

Multi-token healing handles cases where we need to remove multiple tokens to properly align with the user’s typed text.

Consider when a developer types sw and expects the model to generate sweep. This creates a tokenization mismatch:

sw    <- what the user typed
sweep <- how the model wants to tokenize "sweep"

The challenge is that sweep tokenizes as ["s", "weep"], but the prefix sw is a single token. We can’t simply constrain to tokens starting with sw because the ideal first token s doesn’t match that pattern.

Let’s see how we can handle this more complex case.

The key insight is to allow tokens that either:

  1. Continue our prefix (e.g., swing continues sw). This is normal token healing.
  2. Are themselves a prefix of our prefix (e.g., s is a prefix of sw).

Here’s the algorithm:

def get_allowed_tokens(prefix: str, vocab: list[str]) -> list[str]:
    """Find all tokens that are valid continuations of our prefix string."""
    allowed_tokens = []
    for token in vocab:  # vocab: every token string in the tokenizer's vocabulary (~150k)
        # Allow if the token extends the prefix OR the prefix starts with the token
        if token.startswith(prefix) or prefix.startswith(token):
            allowed_tokens.append(token)
    return allowed_tokens

For prefix "sw", this would allow:

"swing"  # ✓ starts with "sw"
"swap"   # ✓ starts with "sw"
"s"      # ✓ "sw" starts with "s"
"salad"  # ✗ doesn't match either condition

Here’s how we guide the model step by step through generating sweep while respecting the sw prefix:

Token healing example: "sw" to "sweep"

Step 1: the remaining prefix is "sw". The next token must start with "sw" (e.g., swing, switch) or be a prefix of "sw" (e.g., s). The model picks s.
Step 2: the remaining prefix is "w". The next token must start with "w" or be a prefix of "w". The model picks weep, which satisfies the rest of the prefix.

This state machine approach ensures that at each generation step, we only allow tokens that keep us on track toward our target string.

After the prefix has been satisfied, we just continue generating with no constraints!
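Here’s a minimal sketch of that per-step constraint, tracking how much of the typed prefix is left to satisfy (a simplification of the production logic):

# Sketch of the per-step constraint: track how much of the typed prefix is left.
def advance_prefix(remaining, token):
    """Return the prefix left to satisfy after emitting `token`, or None if disallowed."""
    if token.startswith(remaining):
        return ""                       # token covers the rest of the prefix; unconstrained afterwards
    if remaining.startswith(token):
        return remaining[len(token):]   # token consumes part of the prefix; keep constraining
    return None                         # token would contradict what the user typed

# e.g. advance_prefix("sw", "s") -> "w", then advance_prefix("w", "weep") -> ""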

Integrating into TensorRT-LLM

We want to generate more accurate suggestions using our multi-token healing algorithm, but without losing all the other optimizations of TensorRT-LLM such as speculative decoding.

We know the allowed tokens from our algorithm, but how do we actually force TensorRT-LLM to respect these constraints?

We mask disallowed tokens by setting their logits (the raw scores before probability calculation) to negative infinity. When the model calculates probabilities using softmax, any token with a logit of −∞ gets zero probability:

P(\text{token}) = \frac{e^{-\infty}}{\sum_j e^{\text{logit}_j}} = 0

This effectively removes disallowed tokens from consideration while preserving the relative probabilities of allowed tokens.
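Here’s a minimal PyTorch sketch of the masking step (the real version runs inside TensorRT-LLM’s sampling loop, but the idea is the same):

# Sketch of logit masking; `logits` is a 1-D tensor over the vocabulary for the next token.
import torch

def mask_logits(logits, allowed_token_ids):
    """Keep only the allowed tokens; everything else gets -inf and thus zero probability."""
    masked = torch.full_like(logits, float("-inf"))
    masked[allowed_token_ids] = logits[allowed_token_ids]
    return masked

# softmax turns the -inf logits into exactly zero probability:
# probs = torch.softmax(mask_logits(logits, allowed_token_ids), dim=-1)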

This does work, and the autocomplete now correctly completes Node.

Success! However, this came with a new issue: the autocomplete started taking almost a second to respond. After profiling the code, we found that the for-loop in get_allowed_tokens takes ~700ms.

for token in vocab:
    if token.startswith(prefix) or prefix.startswith(token):
        allowed_tokens.append(token)

This is unacceptable since the rest of the inference only takes 100ms!

Using tries for sub-1ms lookups

The root cause is that get_allowed_tokens performs a linear search and string comparison through all 150k tokens.

To optimize this, we can use tries. A trie stores a large set of strings in a tree so that any prefix can be reached in O(q) time, where q is the length of the query string. Each node represents a single character, and the subtree below it contains every stored string that shares that prefix.

Here’s an example of a trie containing the words cat, car, card, care and careful:

c
└── a
    ├── t              (cat)
    └── r              (car)
        ├── d          (card)
        └── e          (care)
            └── f
                └── u
                    └── l  (careful)

We used a pure Python implementation of tries. We considered optimizations like path compression or using a JIT or Rust bindings, but we found Python was sufficient since there are only 150k tokens and the average token is approximately 4 characters long.
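Here’s a simplified sketch of that trie (not our production code): insert every token once, then a single walk down the query string collects both kinds of matches: tokens that are prefixes of the query (found along the path) and tokens that extend it (the subtree where the walk ends).

# Simplified sketch of the token trie.
class TrieNode:
    def __init__(self):
        self.children = {}   # char -> TrieNode
        self.token = None    # set if a vocab token ends at this node

class TokenTrie:
    def __init__(self, vocab):
        self.root = TrieNode()
        for token in vocab:
            node = self.root
            for char in token:
                node = node.children.setdefault(char, TrieNode())
            node.token = token

    def allowed_tokens(self, prefix):
        """Tokens that are a prefix of `prefix`, plus tokens that start with `prefix`."""
        allowed = []
        node = self.root
        for char in prefix:
            if node.token is not None:
                allowed.append(node.token)   # token is a prefix of what the user typed
            node = node.children.get(char)
            if node is None:
                return allowed               # nothing in the vocab continues this path
        self._collect(node, allowed)         # every token that starts with `prefix`
        return allowed

    def _collect(self, node, out):
        if node.token is not None:
            out.append(node.token)
        for child in node.children.values():
            self._collect(child, out)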

In practice, almost all queries take <300 microseconds, and the data structure only adds around 30 MB of memory.

Trie vs Naive Lookup Speed

Here’s what the speedup looks like for a few example prefixes:

Example Prefix | Matches | Trie (ms) | Naive (ms) | Speedup
test           |      12 |      0.03 |     694.73 | 24081.9x
ing            |      26 |      0.05 |     702.05 | 14022.0x
not            |      41 |      0.08 |     701.23 |  9134.1x
a              |    2367 |      2.69 |     698.42 |   259.2x
" "            |   53021 |     53.58 |     699.24 |    13.0x

This works. We’ve made the autocomplete respect the prefix and already shaved off 600ms. We just have one more hurdle to cross.

We noticed that sometimes the trie lookup would take ~50ms. A quick check reveals that of the 150k entries in our tokenizer, 50k start with a space!

Distribution of tokens by first character

Char | ASCII | Tokens |    % | Description
" "  |    32 |  53021 | 56.7 | Space
.    |    46 |   6100 |  6.5 | Punctuation
_    |    95 |   5530 |  5.9 | Special
(    |    40 |   2641 |  2.8 | Special
a    |    97 |   2367 |  2.5 | Letter (a)

This means that every time the user’s last keystroke is a single space, we run a depth-first search through 50k tokens, which takes ~50ms in Python. To make this faster, we precomputed and cached the query results for all prefixes that match more than 1000 tokens.
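As a sketch (assuming, for illustration, that only single-character prefixes ever cross the threshold, and reusing the TokenTrie from the sketch above):

# Sketch: precompute results for "hot" prefixes (e.g. a single space) at startup.
from string import printable

CACHE_THRESHOLD = 1000
hot_prefix_cache = {}
for char in printable:                        # single-character prefixes only (illustrative)
    matches = trie.allowed_tokens(char)
    if len(matches) > CACHE_THRESHOLD:
        hot_prefix_cache[char] = matches

def lookup(prefix):
    """Serve hot prefixes from the cache; everything else hits the trie directly."""
    return hot_prefix_cache.get(prefix) or trie.allowed_tokens(prefix)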

This lookup is instantaneous, so we end up with a worst-case latency of <1ms!

Conclusion

We solved the autocomplete accuracy problem caused by tokenization mismatches with multi-token healing, and made it fast enough for real-time use with tries and caching, bringing the token lookup down to under 1ms.

We’re using techniques like these to build the best autocomplete and coding agent for JetBrains. We’d love for you to try it out and let us know your thoughts.

We also have a Discord server where we post updates on Sweep!