diff --git a/src/Microsoft.ML.Tokenizers/Utils/BytePairEncoder.cs b/src/Microsoft.ML.Tokenizers/Utils/BytePairEncoder.cs
index 725eafa002..9e8a45d4ea 100644
--- a/src/Microsoft.ML.Tokenizers/Utils/BytePairEncoder.cs
+++ b/src/Microsoft.ML.Tokenizers/Utils/BytePairEncoder.cs
@@ -20,6 +20,15 @@ public static (int Id, int TokenIndex, int TokenLength)[] BytePairEncode(ReadOnl
                 return [(ranks[mergingBytes], 0, 1)];
             }
 
+            // For large inputs, use a heap-based algorithm to avoid O(n²) behavior.
+            // The threshold of 128 was chosen empirically: the linear scan is cache-friendly for small inputs,
+            // while the heap's O(log n) per-operation overhead becomes worthwhile for larger ones.
+            // Upstream tiktoken uses 100; the value was adjusted upward for C#'s efficient span operations.
+            if (mergingBytes.Length > 128)
+            {
+                return BytePairEncodeLarge(mergingBytes, ranks, indexMappingSpan);
+            }
+
             (int Index, int Rank)[]? arrayPoolArray = null;
             int requiredLength = mergingBytes.Length + 1;
             Span<(int Index, int Rank)> byteIndicesAndRanks = requiredLength <= 64 ?
@@ -116,6 +125,168 @@ int GetRank(Span<(int Index, int Rank)> byteIndicesAndRanks, int startIndex, int
             return result;
         }
 
+        private struct State
+        {
+            public int Prev;
+            public int End;
+            public int NextEnd;
+            public int NextRank;
+            // Note: In the Tiktoken tokenizer, the rank is also the token Id.
+            // This field caches the rank/Id after a merge so we don't need to look it up again.
+            // Using this code with a different tokenizer where rank != token Id would produce wrong results.
+            public int CurRank;
+        }
+
+        private struct MergeEntry : IComparable<MergeEntry>
+        {
+            public int Rank;
+            public int Start;
+
+            public int CompareTo(MergeEntry other)
+            {
+                int rankComparison = Rank.CompareTo(other.Rank);
+                if (rankComparison != 0)
+                {
+                    return rankComparison;
+                }
+
+                return Start.CompareTo(other.Start);
+            }
+        }
+
+        private static (int Id, int TokenIndex, int TokenLength)[] BytePairEncodeLarge(ReadOnlyMemory<byte> mergingBytes, IReadOnlyDictionary<ReadOnlyMemory<byte>, int> ranks, ReadOnlySpan<int> indexMappingSpan)
+        {
+            int stateLength = mergingBytes.Length;
+            State[] statePoolArray = ArrayPool<State>.Shared.Rent(stateLength);
+            Span<State> state = statePoolArray.AsSpan(0, stateLength);
+
+            state[0] = new State
+            {
+                Prev = int.MaxValue,
+                End = 1,
+                NextEnd = 2,
+                NextRank = int.MaxValue,
+                CurRank = int.MaxValue
+            };
+
+            var heap = new PriorityQueue<MergeEntry>();
+
+            for (int i = 0; i < mergingBytes.Length - 1; i++)
+            {
+                var slice = mergingBytes.Slice(i, 2);
+                if (ranks.TryGetValue(slice, out int rank))
+                {
+                    heap.Enqueue(new MergeEntry { Start = i, Rank = rank });
+                    state[i].NextRank = rank;
+                }
+
+                state[i + 1] = new State
+                {
+                    Prev = i,
+                    End = i + 2,
+                    NextEnd = i + 3,
+                    NextRank = int.MaxValue,
+                    CurRank = int.MaxValue
+                };
+            }
+
+            // Local function to add a potential merge to the heap.
+            void PotentialMerge(Span<State> stateSpan, PriorityQueue<MergeEntry> heapQueue, int start, int nextEndItem)
+            {
+                stateSpan[start].NextEnd = nextEndItem;
+                stateSpan[start].NextRank = int.MaxValue;
+
+                if (nextEndItem <= mergingBytes.Length)
+                {
+                    var slice = mergingBytes.Slice(start, nextEndItem - start);
+                    if (ranks.TryGetValue(slice, out int rank))
+                    {
+                        heapQueue.Enqueue(new MergeEntry { Start = start, Rank = rank });
+                        stateSpan[start].NextRank = rank;
+                    }
+                }
+            }
+
+            while (heap.Count > 0)
+            {
+                MergeEntry left = heap.Dequeue();
+
+                if (left.Rank == int.MaxValue)
+                {
+                    break;
+                }
+
+                if (left.Rank != state[left.Start].NextRank)
+                {
+                    continue;
+                }
+
+                int leftStart = left.Start;
+                int rightStart = state[leftStart].End;
+                int rightEnd = state[leftStart].NextEnd;
+                int rightNextEnd = state[rightStart].NextEnd;
+
+                state[leftStart].CurRank = state[leftStart].NextRank;
+                state[leftStart].End = rightEnd;
+                PotentialMerge(state, heap, leftStart, rightNextEnd);
+
+                if (rightEnd < state.Length)
+                {
+                    state[rightEnd].Prev = leftStart;
+                }
+
+                if (leftStart > 0)
+                {
+                    int prevStart = state[leftStart].Prev;
+                    PotentialMerge(state, heap, prevStart, rightEnd);
+                }
+
+                state[rightStart].NextRank = int.MaxValue;
+            }
+
+            // Use ArrayPool for the result buffer to avoid List overhead.
+            // The maximum number of tokens is mergingBytes.Length (no merges).
+            var resultPoolArray = ArrayPool<(int Id, int TokenIndex, int TokenLength)>.Shared.Rent(mergingBytes.Length);
+            int resultCount = 0;
+            int currentIndex = 0;
+
+            while (currentIndex < state.Length)
+            {
+                int startIndex = currentIndex;
+                int endIndex = state[currentIndex].End;
+
+                int mappedStartIndex = indexMappingSpan[startIndex];
+                int mappedEndIndex = indexMappingSpan[endIndex];
+
+                int finalEndIndex = endIndex;
+
+                // Handle partial characters/elements at token boundaries.
+                // If the byte at endIndex - 1 maps to the same character as endIndex,
+                // extend the token to include the complete character.
+                if (finalEndIndex > 0 && indexMappingSpan[finalEndIndex - 1] == mappedEndIndex)
+                {
+                    finalEndIndex++;
+                    while (finalEndIndex < indexMappingSpan.Length && indexMappingSpan[finalEndIndex] == mappedEndIndex)
+                    {
+                        finalEndIndex++;
+                    }
+                }
+
+                int tokenId = state[currentIndex].CurRank != int.MaxValue
+                    ? state[currentIndex].CurRank
+                    : ranks[mergingBytes.SliceStartEnd(startIndex, endIndex)];
+
+                resultPoolArray[resultCount++] = (tokenId, mappedStartIndex, indexMappingSpan[finalEndIndex] - mappedStartIndex);
+
+                currentIndex = state[currentIndex].End;
+            }
+
+            ArrayPool<State>.Shared.Return(statePoolArray);
+
+            var result = resultPoolArray.AsSpan(0, resultCount).ToArray();
+            ArrayPool<(int Id, int TokenIndex, int TokenLength)>.Shared.Return(resultPoolArray);
+            return result;
+        }
+
         private static ReadOnlyMemory<byte> SliceStartEnd(this ReadOnlyMemory<byte> memory, int start, int end) => memory.Slice(start, end - start);
     }
 }
diff --git a/src/Microsoft.ML.Tokenizers/Utils/PriorityQueue.cs b/src/Microsoft.ML.Tokenizers/Utils/PriorityQueue.cs
index 751ce6bc10..5ae1da0cba 100644
--- a/src/Microsoft.ML.Tokenizers/Utils/PriorityQueue.cs
+++ b/src/Microsoft.ML.Tokenizers/Utils/PriorityQueue.cs
@@ -12,6 +12,10 @@ internal class PriorityQueue<T> where T : IComparable<T>
     {
         private readonly List<T> _data;
 
+        public PriorityQueue() : this(0)
+        {
+        }
+
         public PriorityQueue(int capacity)
         {
             _data = new List<T>(capacity);
diff --git a/test/Microsoft.ML.Tokenizers.Tests/TiktokenTests.cs b/test/Microsoft.ML.Tokenizers.Tests/TiktokenTests.cs
index c7c1e342d8..d15ff22aaa 100644
--- a/test/Microsoft.ML.Tokenizers.Tests/TiktokenTests.cs
+++ b/test/Microsoft.ML.Tokenizers.Tests/TiktokenTests.cs
@@ -2,7 +2,6 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
-using Microsoft.DotNet.RemoteExecutor;
 using System;
 using System.Buffers;
 using System.Collections.Generic;
@@ -13,6 +12,7 @@
 using System.Text;
 using System.Text.Json;
 using System.Threading.Tasks;
+using Microsoft.DotNet.RemoteExecutor;
 using Xunit;
 
 namespace Microsoft.ML.Tokenizers.Tests
@@ -848,6 +848,76 @@ public void TestOss()
 
         private static IReadOnlyDictionary<ReadOnlyMemory<byte>, int>? GetVocabulary(TiktokenTokenizer tiktoken)
             => typeof(TiktokenTokenizer).GetProperty("Vocabulary", BindingFlags.Instance | BindingFlags.NonPublic)?.GetValue(tiktoken) as IReadOnlyDictionary<ReadOnlyMemory<byte>, int>;
+
+        [Fact]
+        public void TestLargeInputOptimization()
+        {
+            // Verify that large inputs (>128 bytes) and boundary cases round-trip correctly via the public API.
+            // This exercises the large-input optimization path but does not directly compare it to the small-input path.
+
+            // Test with repeated characters - this is the adversarial case that caused O(n^2) behavior.
+            string largeRepeatedInput = new string('a', 1000);
+            IReadOnlyList<int> ids = GPT4.EncodeToIds(largeRepeatedInput);
+            string decoded = GPT4.Decode(ids);
+            Assert.Equal(largeRepeatedInput, decoded);
+
+            // Test with a more realistic large input.
+            string largeMixedInput = string.Join(" ", Enumerable.Repeat("Hello World! This is a test.", 50));
+            IReadOnlyList<int> mixedIds = GPT4.EncodeToIds(largeMixedInput);
+            string mixedDecoded = GPT4.Decode(mixedIds);
+            Assert.Equal(largeMixedInput, mixedDecoded);
+
+            // Test the boundary case - exactly at the threshold (128).
+            string boundaryInput = new string('x', 128);
+            IReadOnlyList<int> boundaryIds = GPT4.EncodeToIds(boundaryInput);
+            string boundaryDecoded = GPT4.Decode(boundaryIds);
+            Assert.Equal(boundaryInput, boundaryDecoded);
+
+            // Test just below the threshold (127).
+            string belowThresholdInput = new string('x', 127);
+            IReadOnlyList<int> belowIds = GPT4.EncodeToIds(belowThresholdInput);
+            string belowDecoded = GPT4.Decode(belowIds);
+            Assert.Equal(belowThresholdInput, belowDecoded);
+
+            // Test just above the threshold (129).
+            string aboveThresholdInput = new string('x', 129);
+            IReadOnlyList<int> aboveIds = GPT4.EncodeToIds(aboveThresholdInput);
+            string aboveDecoded = GPT4.Decode(aboveIds);
+            Assert.Equal(aboveThresholdInput, aboveDecoded);
+        }
+
+        [Theory]
+        [InlineData(200)]
+        [InlineData(500)]
+        [InlineData(1000)]
+        [InlineData(2000)]
+        public void TestLargeInputConsistency(int length)
+        {
+            // Verify that large inputs are handled correctly by the public API and round-trip successfully.
+            // These tests focus on observable behavior (round-tripping and reconstruction), not on comparing internal code paths.
+
+            // Test with a repeated character.
+            string inputRepeated = new string('z', length);
+            IReadOnlyList<int> idsRepeated = GPT4.EncodeToIds(inputRepeated);
+
+            // Verify the round trip.
+            string decodedRepeated = GPT4.Decode(idsRepeated);
+            Assert.Equal(inputRepeated, decodedRepeated);
+
+            // Test with mixed content (a more realistic scenario).
+            string inputMixed = string.Join(" ", Enumerable.Repeat("Hello World! Test123", length / 20 + 1)).Substring(0, length);
+            IReadOnlyList<int> idsMixed = GPT4.EncodeToIds(inputMixed);
+            string decodedMixed = GPT4.Decode(idsMixed);
+            Assert.Equal(inputMixed, decodedMixed);
+
+            // Verify with EncodeToTokens as well.
+            IReadOnlyList<EncodedToken> tokens = GPT4.EncodeToTokens(inputRepeated, out string? normalizedText);
+            Assert.Null(normalizedText); // No normalization expected.
+
+            // Reconstruct the text from the tokens.
+            var reconstructed = string.Concat(tokens.Select(t => t.Value));
+            Assert.Equal(inputRepeated, reconstructed);
+        }
     }
 }
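Editor's note: the core of this patch is the heap-driven merge loop in BytePairEncodeLarge, where a doubly linked list of token spans (the State array) is combined with a priority queue and lazy invalidation. For readers unfamiliar with the technique, the following is a minimal standalone sketch of the same idea, not the library's code: it works over string symbols instead of UTF-8 byte spans, uses .NET's built-in PriorityQueue<TElement, TPriority> instead of the repo's internal heap, and all names are illustrative. The (rank, position) priority mirrors MergeEntry.CompareTo, and the re-check on dequeue mirrors the `left.Rank != state[left.Start].NextRank` test in the diff.

    using System;
    using System.Collections.Generic;

    class HeapBpeSketch
    {
        static List<string> Encode(string[] symbols, Dictionary<string, int> ranks)
        {
            int n = symbols.Length;
            string[] tok = (string[])symbols.Clone();
            int[] next = new int[n];   // next/prev form a linked list over live token starts
            int[] prev = new int[n];
            bool[] dead = new bool[n]; // positions merged into a predecessor
            for (int i = 0; i < n; i++) { next[i] = i + 1; prev[i] = i - 1; }

            // Priority = (rank, position): lowest rank wins, leftmost breaks ties.
            var heap = new PriorityQueue<int, (int Rank, int Pos)>();
            void EnqueuePair(int i)
            {
                int j = next[i];
                if (j < n && ranks.TryGetValue(tok[i] + tok[j], out int r))
                    heap.Enqueue(i, (r, i));
            }
            for (int i = 0; i < n - 1; i++) EnqueuePair(i);

            while (heap.TryDequeue(out int i, out (int Rank, int Pos) pri))
            {
                int j = next[i];
                // Lazy invalidation: drop stale entries whose pair vanished or
                // changed rank because an earlier merge touched either side.
                if (dead[i] || j >= n ||
                    !ranks.TryGetValue(tok[i] + tok[j], out int r) || r != pri.Rank)
                    continue;

                tok[i] += tok[j];                 // merge the right token into the left
                dead[j] = true;
                next[i] = next[j];
                if (next[j] < n) prev[next[j]] = i;
                EnqueuePair(i);                   // the pair starting at i changed
                if (prev[i] >= 0) EnqueuePair(prev[i]); // so did the pair ending at i
            }

            var result = new List<string>();
            for (int i = 0; i < n; i = next[i]) result.Add(tok[i]);
            return result;
        }

        static void Main()
        {
            // Toy merge table: lower rank merges first, so "ab" wins before "bc".
            var ranks = new Dictionary<string, int> { ["ab"] = 0, ["bc"] = 1, ["abc"] = 2 };
            Console.WriteLine(string.Join(" ", Encode(new[] { "a", "b", "c" }, ranks)));
            // prints: abc
        }
    }

Each of the n - 1 possible merges costs O(log n) heap work plus O(1) list surgery, which is what turns the adversarial all-'a' input from quadratic into roughly O(n log n).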
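The final loop in BytePairEncodeLarge consults indexMappingSpan so a reported token never ends in the middle of a multi-byte character. The demo below is my reconstruction of that mapping's role for BMP text (not the library's exact table): entry i holds the index of the char that produced UTF-8 byte i, with one extra entry for the end position.

    using System;
    using System.Text;

    class IndexMappingDemo
    {
        static void Main()
        {
            string text = "héllo";                      // 5 chars, 6 UTF-8 bytes
            byte[] utf8 = Encoding.UTF8.GetBytes(text);
            int[] mapping = new int[utf8.Length + 1];
            int b = 0;
            for (int c = 0; c < text.Length; c++)
            {
                // Assumes BMP text (no surrogate pairs) for simplicity.
                int width = Encoding.UTF8.GetByteCount(text, c, 1);
                for (int k = 0; k < width; k++) mapping[b++] = c;
            }
            mapping[b] = text.Length;

            // 'é' occupies bytes 1 and 2, both mapped to char 1. A merge result
            // ending at byte 2 would split the character, so the encoder walks
            // forward until the mapped char index changes - the same idea as the
            // finalEndIndex loop in the diff.
            Console.WriteLine(string.Join(",", mapping)); // 0,1,1,2,3,4,5
        }
    }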
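The new tests verify round-tripping but not timing. A rough way to observe the scaling change from outside the library is the probe below. It is illustrative only, not a rigorous benchmark (use BenchmarkDotNet for real measurements), and it assumes the cl100k_base data package (Microsoft.ML.Tokenizers.Data.Cl100kBase) is referenced so that CreateForModel("gpt-4") can resolve its vocabulary.

    using System;
    using System.Diagnostics;
    using Microsoft.ML.Tokenizers;

    class ScalingProbe
    {
        static void Main()
        {
            TiktokenTokenizer tokenizer = TiktokenTokenizer.CreateForModel("gpt-4");
            foreach (int n in new[] { 1_000, 10_000, 100_000 })
            {
                // A long run of identical letters is the adversarial case from the
                // tests: the pre-tokenizer keeps it as one chunk, every adjacent
                // pair is mergeable, and the old linear scan re-walked the whole
                // buffer after each merge.
                string adversarial = new string('a', n);
                var sw = Stopwatch.StartNew();
                tokenizer.EncodeToIds(adversarial);
                sw.Stop();
                Console.WriteLine($"n={n,7}: {sw.Elapsed.TotalMilliseconds:F2} ms");
            }
        }
    }

Before this patch, doubling n should roughly quadruple the adversarial-case time; afterwards it should grow only slightly faster than linearly. (Remember to discard the first, JIT-warmup iteration when eyeballing numbers.)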