diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py index 1c3f0918bafd..aa059b14484b 100644 --- a/sdks/python/apache_beam/ml/inference/base.py +++ b/sdks/python/apache_beam/ml/inference/base.py @@ -178,6 +178,8 @@ def __init__( max_batch_duration_secs: Optional[int] = None, max_batch_weight: Optional[int] = None, element_size_fn: Optional[Callable[[Any], int]] = None, + length_fn: Optional[Callable[[Any], int]] = None, + bucket_boundaries: Optional[list[int]] = None, large_model: bool = False, model_copies: Optional[int] = None, **kwargs): @@ -190,6 +192,11 @@ def __init__( before emitting; used in streaming contexts. max_batch_weight: the maximum weight of a batch. Requires element_size_fn. element_size_fn: a function that returns the size (weight) of an element. + length_fn: a callable mapping an element to its length. When set with + max_batch_duration_secs, enables length-aware bucketed keying so + elements of similar length are batched together. + bucket_boundaries: sorted list of positive boundary values for length + bucketing. Requires length_fn. large_model: set to true if your model is large enough to run into memory pressure if you load multiple copies. model_copies: The exact number of models that you would like loaded @@ -209,6 +216,10 @@ def __init__( self._batching_kwargs['max_batch_weight'] = max_batch_weight if element_size_fn is not None: self._batching_kwargs['element_size_fn'] = element_size_fn + if length_fn is not None: + self._batching_kwargs['length_fn'] = length_fn + if bucket_boundaries is not None: + self._batching_kwargs['bucket_boundaries'] = bucket_boundaries self._large_model = large_model self._model_copies = model_copies self._share_across_processes = large_model or (model_copies is not None) diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py index feccd8b0f12e..a5d26be6695b 100644 --- a/sdks/python/apache_beam/ml/inference/base_test.py +++ b/sdks/python/apache_beam/ml/inference/base_test.py @@ -2278,6 +2278,43 @@ def test_max_batch_duration_secs_only(self): self.assertEqual(kwargs, {'max_batch_duration_secs': 60}) + def test_length_fn_and_bucket_boundaries(self): + """length_fn and bucket_boundaries are passed through to kwargs.""" + handler = FakeModelHandlerForBatching( + length_fn=len, bucket_boundaries=[16, 32, 64]) + kwargs = handler.batch_elements_kwargs() + + self.assertIs(kwargs['length_fn'], len) + self.assertEqual(kwargs['bucket_boundaries'], [16, 32, 64]) + + def test_length_fn_only(self): + """length_fn alone is passed through without bucket_boundaries.""" + handler = FakeModelHandlerForBatching(length_fn=len) + kwargs = handler.batch_elements_kwargs() + + self.assertIs(kwargs['length_fn'], len) + self.assertNotIn('bucket_boundaries', kwargs) + + def test_bucket_boundaries_without_length_fn(self): + """Passing bucket_boundaries without length_fn should fail in BatchElements. 
+
+    Note: ModelHandler.__init__ doesn't validate this; the error is raised
+    by BatchElements when batch_elements_kwargs are used."""
+    handler = FakeModelHandlerForBatching(bucket_boundaries=[10, 20])
+    kwargs = handler.batch_elements_kwargs()
+    # The kwargs are stored, but BatchElements will reject them
+    self.assertEqual(kwargs['bucket_boundaries'], [10, 20])
+    self.assertNotIn('length_fn', kwargs)
+
+  def test_batching_kwargs_none_values_omitted(self):
+    """None values for length_fn and bucket_boundaries are not in kwargs."""
+    handler = FakeModelHandlerForBatching(
+        min_batch_size=5, length_fn=None, bucket_boundaries=None)
+    kwargs = handler.batch_elements_kwargs()
+    self.assertNotIn('length_fn', kwargs)
+    self.assertNotIn('bucket_boundaries', kwargs)
+    self.assertEqual(kwargs['min_batch_size'], 5)
+

 class SimpleFakeModelHandler(base.ModelHandler[int, int, FakeModel]):
   def load_model(self):
diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py
index fbaab6b4ebbb..348bfd5b9acd 100644
--- a/sdks/python/apache_beam/transforms/util.py
+++ b/sdks/python/apache_beam/transforms/util.py
@@ -20,6 +20,7 @@

 # pytype: skip-file

+import bisect
 import collections
 import contextlib
 import hashlib
@@ -1209,6 +1210,28 @@ def process(self, element):
     yield (self.key, element)

+
+class WithLengthBucketKey(DoFn):
+  """Keys elements with (worker_uuid, length_bucket) for length-aware
+  stateful batching. Elements of similar length are routed to the same
+  state partition, reducing padding waste."""
+  def __init__(self, length_fn, bucket_boundaries):
+    self.shared_handle = shared.Shared()
+    self._length_fn = length_fn
+    self._bucket_boundaries = bucket_boundaries
+
+  def setup(self):
+    self.key = self.shared_handle.acquire(
+        load_shared_key, "WithLengthBucketKey").key
+
+  def _get_bucket(self, length):
+    return bisect.bisect_left(self._bucket_boundaries, length)
+
+  def process(self, element):
+    length = self._length_fn(element)
+    bucket = self._get_bucket(length)
+    yield ((self.key, bucket), element)
+

 @typehints.with_input_types(T)
 @typehints.with_output_types(list[T])
 class BatchElements(PTransform):
@@ -1268,7 +1291,18 @@ class BatchElements(PTransform):
       donwstream operations (mostly for testing)
     record_metrics: (optional) whether or not to record beam metrics on
       distributions of the batch size. Defaults to True.
+    length_fn: (optional) a callable mapping an element to its length (int).
+      When set together with max_batch_duration_secs, enables length-aware
+      bucketed keying on the stateful path so that elements of similar length
+      are routed to the same batch, reducing padding waste.
+    bucket_boundaries: (optional) a sorted list of positive boundary values
+      for length bucketing. An element goes to the first bucket i with
+      length <= boundaries[i]; lengths above the last boundary go to
+      bucket len(boundaries). Defaults to [16, 32, 64, 128, 256, 512]
+      when length_fn is set. Requires length_fn.
""" + _DEFAULT_BUCKET_BOUNDARIES = [16, 32, 64, 128, 256, 512] + def __init__( self, min_batch_size=1, @@ -1281,7 +1315,17 @@ def __init__( element_size_fn=lambda x: 1, variance=0.25, clock=time.time, - record_metrics=True): + record_metrics=True, + length_fn=None, + bucket_boundaries=None): + if bucket_boundaries is not None and length_fn is None: + raise ValueError('bucket_boundaries requires length_fn to be set.') + if bucket_boundaries is not None: + if (not bucket_boundaries or any(b <= 0 for b in bucket_boundaries) or + bucket_boundaries != sorted(bucket_boundaries)): + raise ValueError( + 'bucket_boundaries must be a non-empty sorted list of ' + 'positive values.') self._batch_size_estimator = _BatchSizeEstimator( min_batch_size=min_batch_size, max_batch_size=max_batch_size, @@ -1295,13 +1339,23 @@ def __init__( self._element_size_fn = element_size_fn self._max_batch_dur = max_batch_duration_secs self._clock = clock + self._length_fn = length_fn + if length_fn is not None and bucket_boundaries is None: + self._bucket_boundaries = self._DEFAULT_BUCKET_BOUNDARIES + else: + self._bucket_boundaries = bucket_boundaries def expand(self, pcoll): if getattr(pcoll.pipeline.runner, 'is_streaming', False): raise NotImplementedError("Requires stateful processing (BEAM-2687)") elif self._max_batch_dur is not None: coder = coders.registry.get_coder(pcoll) - return pcoll | ParDo(WithSharedKey()) | ParDo( + if self._length_fn is not None: + keying_dofn = WithLengthBucketKey( + self._length_fn, self._bucket_boundaries) + else: + keying_dofn = WithSharedKey() + return pcoll | ParDo(keying_dofn) | ParDo( _pardo_stateful_batch_elements( coder, self._batch_size_estimator, diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py index 7389568691cd..05fcfb3c8b36 100644 --- a/sdks/python/apache_beam/transforms/util_test.py +++ b/sdks/python/apache_beam/transforms/util_test.py @@ -65,6 +65,7 @@ from apache_beam.testing.util import assert_that from apache_beam.testing.util import contains_in_any_order from apache_beam.testing.util import equal_to +from apache_beam.testing.util import is_not_empty from apache_beam.transforms import trigger from apache_beam.transforms import util from apache_beam.transforms import window @@ -1025,6 +1026,236 @@ def test_stateful_grows_to_max_batch(self): | beam.Map(len)) assert_that(res, equal_to([1, 1, 2, 4, 8, 16, 32, 50, 50])) + def test_length_bucket_assignment(self): + """WithLengthBucketKey assigns correct bucket indices.""" + boundaries = [10, 50, 100] + dofn = util.WithLengthBucketKey(length_fn=len, bucket_boundaries=boundaries) + # bisect_left: length < 10 -> bucket 0, 10 <= length < 50 -> bucket 1, etc. 
+ self.assertEqual(dofn._get_bucket(5), 0) + self.assertEqual(dofn._get_bucket(10), 0) + self.assertEqual(dofn._get_bucket(11), 1) + self.assertEqual(dofn._get_bucket(50), 1) + self.assertEqual(dofn._get_bucket(51), 2) + self.assertEqual(dofn._get_bucket(100), 2) + self.assertEqual(dofn._get_bucket(101), 3) + self.assertEqual(dofn._get_bucket(999), 3) + + def test_stateful_length_aware_constant_batch(self): + """Elements in distinct length groups produce separate batches.""" + # Create short strings (len 1-5) and long strings (len 50-55) + short = ['x' * i for i in range(1, 6)] * 4 # 20 short strings + long = ['y' * i for i in range(50, 56)] * 4 # 24 long strings + elements = short + long + + p = TestPipeline('FnApiRunner') + batches = ( + p + | beam.Create(elements) + | util.BatchElements( + min_batch_size=5, + max_batch_size=10, + max_batch_duration_secs=100, + length_fn=len, + bucket_boundaries=[10, 50])) + + # Verify that no batch mixes short and long elements + def check_no_mixing(batch): + lengths = [len(s) for s in batch] + min_len, max_len = min(lengths), max(lengths) + # Within a bucket, all elements should have similar length + assert max_len - min_len < 50, ( + f'Batch mixed short and long: lengths {lengths}') + return True + + checks = batches | beam.Map(check_no_mixing) + assert_that(checks, is_not_empty()) + res = p.run() + res.wait_until_finish() + + def test_stateful_length_aware_default_boundaries(self): + """Default boundaries [16, 32, 64, 128, 256, 512] are applied.""" + be = util.BatchElements(max_batch_duration_secs=100, length_fn=len) + self.assertEqual(be._bucket_boundaries, [16, 32, 64, 128, 256, 512]) + + def test_length_aware_requires_length_fn(self): + """bucket_boundaries without length_fn raises ValueError.""" + with self.assertRaises(ValueError): + util.BatchElements( + max_batch_duration_secs=100, bucket_boundaries=[10, 20]) + + def test_bucket_boundaries_must_be_sorted(self): + """Unsorted boundaries raise ValueError.""" + with self.assertRaises(ValueError): + util.BatchElements( + max_batch_duration_secs=100, + length_fn=len, + bucket_boundaries=[50, 10, 100]) + + def test_bucket_boundaries_must_be_positive(self): + """Non-positive boundaries raise ValueError.""" + with self.assertRaises(ValueError): + util.BatchElements( + max_batch_duration_secs=100, + length_fn=len, + bucket_boundaries=[0, 10, 100]) + + def test_length_fn_without_stateful_is_ignored(self): + """length_fn without max_batch_duration_secs uses non-stateful path.""" + with TestPipeline() as p: + res = ( + p + | beam.Create(['a', 'bb', 'ccc']) + | util.BatchElements( + min_batch_size=3, max_batch_size=3, length_fn=len) + | beam.Map(len)) + assert_that(res, equal_to([3])) + + def test_padding_efficiency_bimodal(self): + """Benchmark: length-aware bucketing yields better padding efficiency + than unbucketed batching on a bimodal length distribution. + + Padding efficiency per batch = sum(lengths) / (max_len * batch_size). + With bucketing, short and long elements land in separate batches, + so each batch pads to a smaller max, improving efficiency. 
+ """ + random.seed(42) + short = ['x' * random.randint(5, 30) for _ in range(500)] + long = ['y' * random.randint(200, 512) for _ in range(500)] + elements = short + long + batch_size = 32 + + def batch_efficiency(batch): + """Returns (useful_tokens, padded_tokens) for one batch.""" + lengths = [len(s) for s in batch] + return (sum(lengths), max(lengths) * len(lengths)) + + # Run WITH bucketing — collect (useful, padded) per batch + p_bucketed = TestPipeline('FnApiRunner') + bucketed_eff = ( + p_bucketed + | 'CreateBucketed' >> beam.Create(elements) + | 'BatchBucketed' >> util.BatchElements( + min_batch_size=batch_size, + max_batch_size=batch_size, + max_batch_duration_secs=100, + length_fn=len, + bucket_boundaries=[16, 32, 64, 128, 256, 512]) + | 'EffBucketed' >> beam.Map(batch_efficiency) + | 'SumBucketed' >> beam.CombineGlobally( + lambda pairs: (sum(p[0] for p in pairs), sum(p[1] for p in pairs)))) + + # Run WITHOUT bucketing + p_unbucketed = TestPipeline('FnApiRunner') + unbucketed_eff = ( + p_unbucketed + | 'CreateUnbucketed' >> beam.Create(elements) + | 'BatchUnbucketed' >> util.BatchElements( + min_batch_size=batch_size, + max_batch_size=batch_size, + max_batch_duration_secs=100) + | 'EffUnbucketed' >> beam.Map(batch_efficiency) + | 'SumUnbucketed' >> beam.CombineGlobally( + lambda pairs: (sum(p[0] for p in pairs), sum(p[1] for p in pairs)))) + + def check_bucketed_above_threshold(totals): + useful, padded = totals[0] + eff = useful / padded if padded else 0 + assert eff > 0.70, ( + f'Bucketed padding efficiency {eff:.2%} should be > 70%') + + def check_unbucketed_below_bucketed(totals): + useful, padded = totals[0] + eff = useful / padded if padded else 0 + # With bimodal data in a single key, short elements get padded + # to the max of each batch which often includes long elements. 
+ assert eff < 0.70, ( + f'Unbucketed efficiency {eff:.2%} expected < 70% for ' + f'bimodal distribution (sanity check)') + + assert_that(bucketed_eff, check_bucketed_above_threshold) + res = p_bucketed.run() + res.wait_until_finish() + + assert_that(unbucketed_eff, check_unbucketed_below_bucketed) + res = p_unbucketed.run() + res.wait_until_finish() + + def test_with_length_bucket_key_setup_and_process(self): + """WithLengthBucketKey.setup() and process() work correctly in pipeline.""" + boundaries = [10, 50] + elements = ['short', 'x' * 30, 'y' * 60] + + with TestPipeline('FnApiRunner') as p: + result = ( + p + | beam.Create(elements) + | beam.ParDo(util.WithLengthBucketKey(len, boundaries))) + + def check_keys(keyed_elements): + # Each element should have format ((worker_key, bucket), element) + for (key, bucket), elem in keyed_elements: + # Verify key is a UUID string + assert isinstance(key, str) and len(key) > 0 + # Verify bucket is correct + if len(elem) < 10: + assert bucket == 0, f'Expected bucket 0 for {elem}' + elif len(elem) < 50: + assert bucket == 1, f'Expected bucket 1 for {elem}' + else: + assert bucket == 2, f'Expected bucket 2 for {elem}' + + assert_that(result, check_keys) + + def test_bucket_boundaries_empty_list(self): + """Empty bucket_boundaries list raises ValueError.""" + with self.assertRaises(ValueError): + util.BatchElements( + max_batch_duration_secs=100, length_fn=len, bucket_boundaries=[]) + + def test_with_custom_bucket_boundaries(self): + """Custom bucket_boundaries are used instead of defaults.""" + custom_boundaries = [5, 15, 25] + be = util.BatchElements( + max_batch_duration_secs=100, + length_fn=len, + bucket_boundaries=custom_boundaries) + self.assertEqual(be._bucket_boundaries, custom_boundaries) + + def test_length_fn_applied_in_pipeline(self): + """Verify length_fn is used for bucketing in stateful batching.""" + # Create strings of different lengths that should go to different buckets + short_strings = ['x' * i for i in range(1, 5)] # lengths 1-4, bucket 0 + medium_strings = ['y' * i for i in range(20, 24)] # lengths 20-23, bucket 1 + elements = short_strings + medium_strings + + with TestPipeline('FnApiRunner') as p: + batches = ( + p + | beam.Create(elements) + | util.BatchElements( + min_batch_size=2, + max_batch_size=10, + max_batch_duration_secs=100, + length_fn=len, + bucket_boundaries=[10, 30])) + + def check_batch_homogeneity(batch): + """Batches should contain elements of similar length.""" + lengths = [len(s) for s in batch] + # If bucketing works, all elements should be in same bucket + # (either all < 10 or all between 10 and 30) + min_len, max_len = min(lengths), max(lengths) + if min_len < 10: + # Short bucket: all should be < 10 + assert max_len < 10, f'Mixed batch: {lengths}' + else: + # Medium bucket: all should be >= 10 + assert min_len >= 10, f'Mixed batch: {lengths}' + return True + + checks = batches | beam.Map(check_batch_homogeneity) + assert_that(checks, is_not_empty()) + class IdentityWindowTest(unittest.TestCase): def test_window_preserved(self):
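Usage sketch (illustrative only): the pipeline, input values, and boundary list
below are made up for demonstration; length_fn, bucket_boundaries, and their
interaction with max_batch_duration_secs are the parameters introduced in this
change, and the tests above exercise the same path on the FnApiRunner.

import apache_beam as beam
from apache_beam.transforms import util

# Mixed-length strings; without bucketing they could land in the same batch
# and force padding up to the longest element.
elements = ['a' * 5, 'b' * 40, 'c' * 300] * 8

with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create(elements)
      | util.BatchElements(
          min_batch_size=4,
          max_batch_size=16,
          max_batch_duration_secs=5,   # bucketing applies only on the
                                       # stateful (keyed) batching path
          length_fn=len,               # element -> length used for bucketing
          bucket_boundaries=[16, 64])  # lengths <= 16, <= 64, and overflow
      # Each batch now contains elements from a single length bucket, so any
      # downstream padding is bounded by that bucket's upper boundary.
      | beam.Map(lambda batch: print([len(s) for s in batch])))

On the RunInference side, the same two arguments are forwarded through
batch_elements_kwargs() by handlers that accept them, as exercised by the new
base_test cases.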