Support mixed dict / non-dict elements in the YAML Create transform.

Reviewer notes on this copy of the patch (text before the first
"diff --git" line is ignored by `git apply` and `patch`):

* The hunks arrived with every newline and all leading whitespace
  collapsed. Line breaks and indentation below were re-created by hand;
  check the context lines against the target files before applying, or
  apply with `git apply --ignore-whitespace`.
* The create.yaml hunk header declares three context lines but only two
  survived; the third is assumed to be a blank line after the MillWheel
  entry -- TODO confirm against the real file.
* YamlProvidersCreateTest is inserted *before* the
  `if __name__ == '__main__':` guard. Appended after it, the class would
  be defined only after `unittest.main()` had already run and exited, so
  the test would be skipped whenever the module is executed directly.
* The "Normalize elements ..." comment in create() was reworded to say
  what the code does; the logic itself is unchanged.
* The `index` lines still carry the blob ids of the patch as originally
  generated.

diff --git a/sdks/python/apache_beam/yaml/tests/create.yaml b/sdks/python/apache_beam/yaml/tests/create.yaml
index bf346f7667c8..6cd7807681c0 100644
--- a/sdks/python/apache_beam/yaml/tests/create.yaml
+++ b/sdks/python/apache_beam/yaml/tests/create.yaml
@@ -138,3 +138,21 @@ pipelines:
               - {sdk: MapReduce, year: 2004}
               - {sdk: MillWheel, year: 2008}

+
+  # Simple Create with mixed types
+  - pipeline:
+      type: chain
+      transforms:
+        - type: Create
+          config:
+            elements:
+              - 1
+              - {a: 2, c: "hello"}
+              - 3
+        - type: AssertEqual
+          config:
+            elements:
+              - {element: 1, a: null, c: null}
+              - {element: null, a: 2, c: "hello"}
+              - {element: 3, a: null, c: null}
+
diff --git a/sdks/python/apache_beam/yaml/yaml_provider.py b/sdks/python/apache_beam/yaml/yaml_provider.py
index e9882602d100..5a3ccf6b0c2e 100755
--- a/sdks/python/apache_beam/yaml/yaml_provider.py
+++ b/sdks/python/apache_beam/yaml/yaml_provider.py
@@ -864,6 +864,19 @@ def create(elements: Iterable[Any], reshuffle: Optional[bool] = True):
                 str: "bar"
                 values: [4, 5, 6]

+    If the elements are a mix of dicts and non-dicts, the non-dict elements
+    will be wrapped in a Row with a single field "element". For example::
+
+        type: Create
+        config:
+          elements: [1, {"a": 2}]
+
+    will result in an output with two elements with a schema of
+    Row(element=int, a=int) looking like::
+
+        Row(element=1, a=None)
+        Row(element=None, a=2)
+
     Args:
       elements: The set of elements that should belong to the PCollection.
         YAML/JSON-style mappings will be interpreted as Beam rows.
@@ -878,6 +891,25 @@ def create(elements: Iterable[Any], reshuffle: Optional[bool] = True):
     if not isinstance(elements, Iterable) or isinstance(elements, (dict, str)):
       raise TypeError('elements must be a list of elements')

+    if elements:
+      # When dicts and non-dicts are mixed, wrap non-dicts as {'element': e}.
+      has_dict = False
+      has_non_dict = False
+      for e in elements:
+        if isinstance(e, dict):
+          has_dict = True
+        else:
+          has_non_dict = True
+        if has_dict and has_non_dict:
+          break
+
+      if has_dict and has_non_dict:
+        elements = [
+            e if isinstance(e, dict) else {
+                'element': e
+            } for e in elements
+        ]
+
     # Check if elements have different keys
     updated_elements = elements
     if elements and all(isinstance(e, dict) for e in elements):
diff --git a/sdks/python/apache_beam/yaml/yaml_provider_unit_test.py b/sdks/python/apache_beam/yaml/yaml_provider_unit_test.py
index 1ebae9a3b446..e1e3ee847d96 100644
--- a/sdks/python/apache_beam/yaml/yaml_provider_unit_test.py
+++ b/sdks/python/apache_beam/yaml/yaml_provider_unit_test.py
@@ -364,3 +364,16 @@ def test_empty_base(self):
+class YamlProvidersCreateTest(unittest.TestCase):
+  def test_create_mixed_types(self):
+    with beam.Pipeline() as p:
+      # A mix of a primitive (Row(element=1)) and a dict (Row(a=2))
+      result = p | YamlProviders.create([1, {"a": 2}])
+      assert_that(
+          result | beam.Map(lambda x: sorted(x._asdict().items())),
+          equal_to([
+              [('a', None), ('element', 1)],
+              [('a', 2), ('element', None)],
+          ]))
+
+
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
   unittest.main()