Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 171 additions & 0 deletions src/Microsoft.ML.Tokenizers/Utils/BytePairEncoder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,15 @@ public static (int Id, int TokenIndex, int TokenLength)[] BytePairEncode(ReadOnl
return [(ranks[mergingBytes], 0, 1)];
}

// For large inputs, use heap-based algorithm to avoid O(n²) behavior.
// Threshold of 128 chosen empirically: linear scan is cache-friendly for small inputs,
// while heap overhead (O(log n) per operation) becomes worthwhile for larger inputs.
// Based on upstream tiktoken using 100, adjusted upward for C#'s efficient span operations.
if (mergingBytes.Length > 128)
{
return BytePairEncodeLarge(mergingBytes, ranks, indexMappingSpan);
}

(int Index, int Rank)[]? arrayPoolArray = null;
int requiredLength = mergingBytes.Length + 1;
Span<(int Index, int Rank)> byteIndicesAndRanks = requiredLength <= 64 ?
Expand Down Expand Up @@ -116,6 +125,168 @@ int GetRank(Span<(int Index, int Rank)> byteIndicesAndRanks, int startIndex, int
return result;
}

// Per-byte node used by BytePairEncodeLarge. The nodes form a doubly linked
// list of tokens over the byte buffer: state[i] describes the token that
// currently starts at byte offset i.
private struct State
{
    // Start offset of the previous token (int.MaxValue for the first token).
    public int Prev;
    // Exclusive end offset of this token (start + 1 before any merges).
    public int End;
    // Exclusive end this token would have if merged with its right neighbor.
    public int NextEnd;
    // Rank of the merge with the right neighbor, or int.MaxValue when the
    // combined bytes are not in the vocabulary (no merge possible).
    public int NextRank;
    // Note: In the Tiktoken tokenizer, the rank is also the token Id.
    // This field is used to cache the rank/Id after a merge so we don't need to re-look it up.
    // Using this code with a different tokenizer where rank != token Id would produce wrong results.
    // Remains int.MaxValue for tokens that were never merged (single bytes).
    public int CurRank;
}

// A candidate merge queued in the priority heap: the token starting at Start
// may be merged with its right neighbor at cost Rank.
private struct MergeEntry : IComparable<MergeEntry>
{
    public int Rank;
    public int Start;

    // Orders entries by ascending Rank; equal-rank candidates are applied
    // left-to-right by breaking ties on the Start offset.
    public int CompareTo(MergeEntry other)
    {
        int byRank = Rank.CompareTo(other.Rank);
        return byRank != 0 ? byRank : Start.CompareTo(other.Start);
    }
}

// Byte pair encodes large inputs using a min-heap of candidate merges, so each
// merge costs O(log n) instead of the linear rescan used by the small-input path.
// state[i] is a doubly linked list node for the token starting at byte i; the
// heap holds (Start, Rank) merge candidates ordered by rank. Stale heap entries
// are filtered by comparing against the node's current NextRank.
// indexMappingSpan maps byte offsets back to original (pre-normalization) indices
// for the reported TokenIndex/TokenLength values.
private static (int Id, int TokenIndex, int TokenLength)[] BytePairEncodeLarge(ReadOnlyMemory<byte> mergingBytes, IReadOnlyDictionary<ReadOnlyMemory<byte>, int> ranks, ReadOnlySpan<int> indexMappingSpan)
{
    int stateLength = mergingBytes.Length;
    State[] statePoolArray = ArrayPool<State>.Shared.Rent(stateLength);
    Span<State> state = statePoolArray.AsSpan(0, stateLength);

    // First node: a 1-byte token at offset 0 with no predecessor.
    state[0] = new State
    {
        Prev = int.MaxValue,
        End = 1,
        NextEnd = 2,
        NextRank = int.MaxValue,
        CurRank = int.MaxValue
    };

    var heap = new PriorityQueue<MergeEntry>();

    // Seed the heap with every adjacent byte pair present in the vocabulary,
    // and initialize the remaining 1-byte token nodes.
    for (int i = 0; i < mergingBytes.Length - 1; i++)
    {
        var slice = mergingBytes.Slice(i, 2);
        if (ranks.TryGetValue(slice, out int rank))
        {
            heap.Enqueue(new MergeEntry { Start = i, Rank = rank });
            state[i].NextRank = rank;
        }

        state[i + 1] = new State
        {
            Prev = i,
            End = i + 2,
            NextEnd = i + 3,
            NextRank = int.MaxValue,
            CurRank = int.MaxValue
        };
    }

    // Local function to add a potential merge to the heap. Records the merged
    // token's prospective end on the node at 'start'; NextRank stays
    // int.MaxValue when the combined bytes are not in the vocabulary, which
    // also invalidates any earlier heap entries for 'start'.
    void PotentialMerge(Span<State> stateSpan, PriorityQueue<MergeEntry> heapQueue, int start, int nextEndItem)
    {
        stateSpan[start].NextEnd = nextEndItem;
        stateSpan[start].NextRank = int.MaxValue;

        if (nextEndItem <= mergingBytes.Length)
        {
            var slice = mergingBytes.Slice(start, nextEndItem - start);
            if (ranks.TryGetValue(slice, out int rank))
            {
                heapQueue.Enqueue(new MergeEntry { Start = start, Rank = rank });
                stateSpan[start].NextRank = rank;
            }
        }
    }

    // Repeatedly apply the lowest-rank merge until no mergeable pair remains.
    while (heap.Count > 0)
    {
        MergeEntry left = heap.Dequeue();

        if (left.Rank == int.MaxValue)
        {
            break;
        }

        // Skip stale entries: a node's NextRank changes whenever one of its
        // neighbors is merged, so only an entry matching the node's current
        // NextRank is still actionable.
        if (left.Rank != state[left.Start].NextRank)
        {
            continue;
        }

        int leftStart = left.Start;
        int rightStart = state[leftStart].End;
        int rightEnd = state[leftStart].NextEnd;
        int rightNextEnd = state[rightStart].NextEnd;

        // Absorb the right neighbor: cache the merged token's rank/Id, extend
        // the end, then propose a merge with the token after the absorbed one.
        state[leftStart].CurRank = state[leftStart].NextRank;
        state[leftStart].End = rightEnd;
        PotentialMerge(state, heap, leftStart, rightNextEnd);

        // Fix the back-link of the token that now follows the merged token.
        if (rightEnd < state.Length)
        {
            state[rightEnd].Prev = leftStart;
        }

        // The previous token's potential merge now spans the merged token, so
        // recompute its NextEnd/NextRank against the new boundary.
        if (leftStart > 0)
        {
            int prevStart = state[leftStart].Prev;
            PotentialMerge(state, heap, prevStart, rightEnd);
        }

        // The absorbed node is no longer a token start; this invalidates any
        // heap entries still referencing it.
        state[rightStart].NextRank = int.MaxValue;
    }

    // Use ArrayPool for the result buffer to avoid List<T> overhead.
    // The maximum number of tokens is mergingBytes.Length (no merges).
    var resultPoolArray = ArrayPool<(int Id, int TokenIndex, int TokenLength)>.Shared.Rent(mergingBytes.Length);
    int resultCount = 0;
    int currentIndex = 0;

    // Walk the surviving tokens via the End links and emit
    // (Id, original start index, original length) for each.
    while (currentIndex < state.Length)
    {
        int startIndex = currentIndex;
        int endIndex = state[currentIndex].End;

        int mappedStartIndex = indexMappingSpan[startIndex];
        int mappedEndIndex = indexMappingSpan[endIndex];

        int finalEndIndex = endIndex;

        // Handle partial characters/elements at token boundaries.
        // If the byte at endIndex-1 maps to the same character as endIndex,
        // extend the token to include the complete character.
        if (finalEndIndex > 0 && indexMappingSpan[finalEndIndex - 1] == mappedEndIndex)
        {
            finalEndIndex++;
            while (finalEndIndex < indexMappingSpan.Length && indexMappingSpan[finalEndIndex] == mappedEndIndex)
            {
                finalEndIndex++;
            }
        }

        // CurRank caches the Id assigned when the token was merged; unmerged
        // single bytes fall back to a vocabulary lookup.
        // NOTE(review): the indexer throws if the slice is absent — assumes every
        // single byte is in 'ranks'; confirm against the caller's vocabulary.
        int tokenId = state[currentIndex].CurRank != int.MaxValue
            ? state[currentIndex].CurRank
            : ranks[mergingBytes.SliceStartEnd(startIndex, endIndex)];

        resultPoolArray[resultCount++] = (tokenId, mappedStartIndex, indexMappingSpan[finalEndIndex] - mappedStartIndex);

        currentIndex = state[currentIndex].End;
    }

    // NOTE(review): pooled arrays are not returned if an exception is thrown
    // above (e.g. a failed ranks lookup) — consider try/finally; the arrays
    // would simply be collected, so this is a perf nit, not a correctness bug.
    ArrayPool<State>.Shared.Return(statePoolArray);

    var result = resultPoolArray.AsSpan(0, resultCount).ToArray();
    ArrayPool<(int Id, int TokenIndex, int TokenLength)>.Shared.Return(resultPoolArray);
    return result;
}

// Returns the sub-memory covering the half-open byte range [start, end).
private static ReadOnlyMemory<byte> SliceStartEnd(this ReadOnlyMemory<byte> memory, int start, int end)
{
    int length = end - start;
    return memory.Slice(start, length);
}
}
}
4 changes: 4 additions & 0 deletions src/Microsoft.ML.Tokenizers/Utils/PriorityQueue.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ internal class PriorityQueue<T> where T : IComparable<T>
{
private readonly List<T> _data;

// Creates an empty priority queue with no preallocated capacity; delegates to
// the capacity-based constructor so the backing list grows on demand.
public PriorityQueue() : this(0)
{
}

public PriorityQueue(int capacity)
{
_data = new List<T>(capacity);
Expand Down
72 changes: 71 additions & 1 deletion test/Microsoft.ML.Tokenizers.Tests/TiktokenTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.DotNet.RemoteExecutor;
using System;
using System.Buffers;
using System.Collections.Generic;
Expand All @@ -13,6 +12,7 @@
using System.Text;
using System.Text.Json;
using System.Threading.Tasks;
using Microsoft.DotNet.RemoteExecutor;
using Xunit;

namespace Microsoft.ML.Tokenizers.Tests
Expand Down Expand Up @@ -848,6 +848,76 @@ public void TestOss()

// Reads the tokenizer's non-public "Vocabulary" property via reflection;
// returns null if the property is missing or has an unexpected type.
private static IReadOnlyDictionary<string, int>? GetVocabulary(TiktokenTokenizer tiktoken)
{
    PropertyInfo? vocabularyProperty = typeof(TiktokenTokenizer).GetProperty("Vocabulary", BindingFlags.Instance | BindingFlags.NonPublic);
    return vocabularyProperty?.GetValue(tiktoken) as IReadOnlyDictionary<string, int>;
}

[Fact]
public void TestLargeInputOptimization()
{
    // Verify that large inputs and threshold-boundary cases round-trip
    // correctly via the public API. This exercises the large-input
    // optimization path but does not compare it to the small-input path.
    string[] inputs =
    {
        new string('a', 1000),                                                    // adversarial repeated input (historically O(n^2))
        string.Join(" ", Enumerable.Repeat("Hello World! This is a test.", 50)),  // realistic large input
        new string('x', 128),                                                     // exactly at the threshold
        new string('x', 127),                                                     // just below the threshold
        new string('x', 129),                                                     // just above the threshold
    };

    foreach (string input in inputs)
    {
        IReadOnlyList<int> ids = GPT4.EncodeToIds(input);
        string decoded = GPT4.Decode(ids);
        Assert.Equal(input, decoded);
    }
}

[Theory]
[InlineData(200)]
[InlineData(500)]
[InlineData(1000)]
[InlineData(2000)]
public void TestLargeInputConsistency(int length)
{
    // Verify that large inputs round-trip through the public API. These tests
    // focus on observable behavior (round-tripping and reconstruction from
    // tokens), not on comparing internal code paths.

    // Adversarial case: a single repeated character of the requested length.
    string repeated = new string('z', length);
    IReadOnlyList<int> repeatedIds = GPT4.EncodeToIds(repeated);
    Assert.Equal(repeated, GPT4.Decode(repeatedIds));

    // Realistic case: mixed text trimmed to exactly the requested length.
    string mixed = string.Join(" ", Enumerable.Repeat("Hello World! Test123", length / 20 + 1)).Substring(0, length);
    IReadOnlyList<int> mixedIds = GPT4.EncodeToIds(mixed);
    Assert.Equal(mixed, GPT4.Decode(mixedIds));

    // EncodeToTokens should apply no normalization and produce tokens whose
    // concatenated values reconstruct the original input.
    IReadOnlyList<EncodedToken> tokens = GPT4.EncodeToTokens(repeated, out string? normalizedText);
    Assert.Null(normalizedText);
    Assert.Equal(repeated, string.Concat(tokens.Select(t => t.Value)));
}
}
}

Loading