Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 28 additions & 11 deletions google/genai/_interactions/_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,20 @@ def is_base64_file_input(obj: object) -> TypeGuard[Base64FileInput]:

def is_file_content(obj: object) -> TypeGuard[FileContent]:
return (
isinstance(obj, bytes) or isinstance(obj, tuple) or isinstance(obj, io.IOBase) or isinstance(obj, os.PathLike)
isinstance(obj, bytes)
or isinstance(obj, tuple)
or isinstance(obj, io.IOBase)
or isinstance(obj, os.PathLike)
)


def assert_is_file_content(obj: object, *, key: str | None = None) -> None:
if not is_file_content(obj):
prefix = f"Expected entry at `{key}`" if key is not None else f"Expected file input `{obj!r}`"
prefix = (
f"Expected entry at `{key}`"
if key is not None
else f"Expected file input `{obj!r}`"
)
raise RuntimeError(
f"{prefix} to be bytes, an io.IOBase instance, PathLike or a tuple but received {type(obj)} instead."
) from None
Expand All @@ -71,7 +78,9 @@ def to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None:
elif is_sequence_t(files):
files = [(key, _transform_file(file)) for key, file in files]
else:
raise TypeError(f"Unexpected file type input {type(files)}, expected mapping or sequence")
raise TypeError(
f"Unexpected file type input {type(files)}, expected mapping or sequence"
)

return files

Expand All @@ -80,19 +89,23 @@ def _transform_file(file: FileTypes) -> HttpxFileTypes:
if is_file_content(file):
if isinstance(file, os.PathLike):
path = pathlib.Path(file)
return (path.name, path.read_bytes())
# Return an open file handle instead of loading entire file into memory.
# This prevents OOM errors for large files. httpx supports IO[bytes] directly.
return (path.name, open(path, "rb"))

return file

if is_tuple_t(file):
return (file[0], read_file_content(file[1]), *file[2:])

raise TypeError(f"Expected file types input to be a FileContent type or to be a tuple")
raise TypeError(
f"Expected file types input to be a FileContent type or to be a tuple"
)


def read_file_content(file: FileContent) -> HttpxFileContent:
if isinstance(file, os.PathLike):
return pathlib.Path(file).read_bytes()
return open(pathlib.Path(file), "rb")
return file


Expand All @@ -113,27 +126,31 @@ async def async_to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles
elif is_sequence_t(files):
files = [(key, await _async_transform_file(file)) for key, file in files]
else:
raise TypeError("Unexpected file type input {type(files)}, expected mapping or sequence")
raise TypeError(
"Unexpected file type input {type(files)}, expected mapping or sequence"
)

return files


async def _async_transform_file(file: FileTypes) -> HttpxFileTypes:
if is_file_content(file):
if isinstance(file, os.PathLike):
path = anyio.Path(file)
return (path.name, await path.read_bytes())
path = pathlib.Path(file)
return (path.name, open(path, "rb"))

return file

if is_tuple_t(file):
return (file[0], await async_read_file_content(file[1]), *file[2:])

raise TypeError(f"Expected file types input to be a FileContent type or to be a tuple")
raise TypeError(
f"Expected file types input to be a FileContent type or to be a tuple"
)


async def async_read_file_content(file: FileContent) -> HttpxFileContent:
if isinstance(file, os.PathLike):
return await anyio.Path(file).read_bytes()
return open(pathlib.Path(file), "rb")

return file
163 changes: 120 additions & 43 deletions google/genai/_interactions/_utils/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@
import pathlib
from typing import Any, Mapping, TypeVar, cast
from datetime import date, datetime
from typing_extensions import Literal, get_args, override, get_type_hints as _get_type_hints
from typing_extensions import (
Literal,
get_args,
override,
get_type_hints as _get_type_hints,
)

import anyio
import pydantic
Expand Down Expand Up @@ -196,15 +201,26 @@ def _transform_recursive(

if origin == dict and is_mapping(data):
items_type = get_args(stripped_type)[1]
return {key: _transform_recursive(value, annotation=items_type) for key, value in data.items()}
return {
key: _transform_recursive(value, annotation=items_type)
for key, value in data.items()
}

if (
# List[T]
(is_list_type(stripped_type) and is_list(data))
# Iterable[T]
or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))
or (
is_iterable_type(stripped_type)
and is_iterable(data)
and not isinstance(data, str)
)
# Sequence[T]
or (is_sequence_type(stripped_type) and is_sequence(data) and not isinstance(data, str))
or (
is_sequence_type(stripped_type)
and is_sequence(data)
and not isinstance(data, str)
)
):
# dicts are technically iterable, but it is an iterable on the keys of the dict and is not usually
# intended as an iterable, so we don't transform it.
Expand All @@ -221,7 +237,10 @@ def _transform_recursive(
return data
return list(data)

return [_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]
return [
_transform_recursive(d, annotation=annotation, inner_type=inner_type)
for d in data
]

if is_union_type(stripped_type):
# For union types we run the transformation against all subtypes to ensure that everything is transformed.
Expand All @@ -248,7 +267,9 @@ def _transform_recursive(
return data


def _format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object:
def _format_data(
data: object, format_: PropertyFormat, format_template: str | None
) -> object:
if isinstance(data, (date, datetime)):
if format_ == "iso8601":
return data.isoformat()
Expand All @@ -257,22 +278,35 @@ def _format_data(data: object, format_: PropertyFormat, format_template: str | N
return data.strftime(format_template)

if format_ == "base64" and is_base64_file_input(data):
binary: str | bytes | None = None

if isinstance(data, pathlib.Path):
binary = data.read_bytes()
elif isinstance(data, io.IOBase):
binary = data.read()
return _encode_file_to_base64(data)

if isinstance(binary, str): # type: ignore[unreachable]
binary = binary.encode()

if not isinstance(binary, bytes):
raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}")
return data

return base64.b64encode(binary).decode("ascii")

return data
def _encode_file_to_base64(data: object) -> str:
"""Encode file content to base64 using chunked reading to reduce peak memory usage."""
CHUNK_SIZE = 3 * 1024 * 1024 # 3MB (must be multiple of 3 for base64)
chunks: list[str] = []

if isinstance(data, pathlib.Path):
with open(data, "rb") as f:
while True:
chunk = f.read(CHUNK_SIZE)
if not chunk:
break
chunks.append(base64.b64encode(chunk).decode("ascii"))
elif isinstance(data, io.IOBase):
while True:
chunk = data.read(CHUNK_SIZE)
if not chunk:
break
if isinstance(chunk, str):
chunk = chunk.encode()
chunks.append(base64.b64encode(chunk).decode("ascii"))
else:
raise RuntimeError(f"Could not read bytes from {data}; Received {type(data)}")

return "".join(chunks)


def _transform_typeddict(
Expand All @@ -292,7 +326,9 @@ def _transform_typeddict(
# we do not have a type annotation for this field, leave it as is
result[key] = value
else:
result[_maybe_transform_key(key, type_)] = _transform_recursive(value, annotation=type_)
result[_maybe_transform_key(key, type_)] = _transform_recursive(
value, annotation=type_
)
return result


Expand Down Expand Up @@ -328,7 +364,9 @@ class Params(TypedDict, total=False):

It should be noted that the transformations that this function does are not represented in the type system.
"""
transformed = await _async_transform_recursive(data, annotation=cast(type, expected_type))
transformed = await _async_transform_recursive(
data, annotation=cast(type, expected_type)
)
return cast(_T, transformed)


Expand Down Expand Up @@ -362,15 +400,26 @@ async def _async_transform_recursive(

if origin == dict and is_mapping(data):
items_type = get_args(stripped_type)[1]
return {key: _transform_recursive(value, annotation=items_type) for key, value in data.items()}
return {
key: _transform_recursive(value, annotation=items_type)
for key, value in data.items()
}

if (
# List[T]
(is_list_type(stripped_type) and is_list(data))
# Iterable[T]
or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))
or (
is_iterable_type(stripped_type)
and is_iterable(data)
and not isinstance(data, str)
)
# Sequence[T]
or (is_sequence_type(stripped_type) and is_sequence(data) and not isinstance(data, str))
or (
is_sequence_type(stripped_type)
and is_sequence(data)
and not isinstance(data, str)
)
):
# dicts are technically iterable, but it is an iterable on the keys of the dict and is not usually
# intended as an iterable, so we don't transform it.
Expand All @@ -387,15 +436,22 @@ async def _async_transform_recursive(
return data
return list(data)

return [await _async_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]
return [
await _async_transform_recursive(
d, annotation=annotation, inner_type=inner_type
)
for d in data
]

if is_union_type(stripped_type):
# For union types we run the transformation against all subtypes to ensure that everything is transformed.
#
# TODO: there may be edge cases where the same normalized field name will transform to two different names
# in different subtypes.
for subtype in get_args(stripped_type):
data = await _async_transform_recursive(data, annotation=annotation, inner_type=subtype)
data = await _async_transform_recursive(
data, annotation=annotation, inner_type=subtype
)
return data

if isinstance(data, pydantic.BaseModel):
Expand All @@ -409,12 +465,16 @@ async def _async_transform_recursive(
annotations = get_args(annotated_type)[1:]
for annotation in annotations:
if isinstance(annotation, PropertyInfo) and annotation.format is not None:
return await _async_format_data(data, annotation.format, annotation.format_template)
return await _async_format_data(
data, annotation.format, annotation.format_template
)

return data


async def _async_format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object:
async def _async_format_data(
data: object, format_: PropertyFormat, format_template: str | None
) -> object:
if isinstance(data, (date, datetime)):
if format_ == "iso8601":
return data.isoformat()
Expand All @@ -423,22 +483,35 @@ async def _async_format_data(data: object, format_: PropertyFormat, format_templ
return data.strftime(format_template)

if format_ == "base64" and is_base64_file_input(data):
binary: str | bytes | None = None

if isinstance(data, pathlib.Path):
binary = await anyio.Path(data).read_bytes()
elif isinstance(data, io.IOBase):
binary = data.read()

if isinstance(binary, str): # type: ignore[unreachable]
binary = binary.encode()
return await _async_encode_file_to_base64(data)

if not isinstance(binary, bytes):
raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}")
return data

return base64.b64encode(binary).decode("ascii")

return data
async def _async_encode_file_to_base64(data: object) -> str:
"""Encode file content to base64 using chunked reading to reduce peak memory usage."""
CHUNK_SIZE = 3 * 1024 * 1024 # 3MB (must be multiple of 3 for base64)
chunks: list[str] = []

if isinstance(data, pathlib.Path):
async with await anyio.Path(data).open("rb") as f:
while True:
chunk = await f.read(CHUNK_SIZE)
if not chunk:
break
chunks.append(base64.b64encode(chunk).decode("ascii"))
elif isinstance(data, io.IOBase):
while True:
chunk = data.read(CHUNK_SIZE)
if not chunk:
break
if isinstance(chunk, str):
chunk = chunk.encode()
chunks.append(base64.b64encode(chunk).decode("ascii"))
else:
raise RuntimeError(f"Could not read bytes from {data}; Received {type(data)}")

return "".join(chunks)


async def _async_transform_typeddict(
Expand All @@ -458,7 +531,9 @@ async def _async_transform_typeddict(
# we do not have a type annotation for this field, leave it as is
result[key] = value
else:
result[_maybe_transform_key(key, type_)] = await _async_transform_recursive(value, annotation=type_)
result[_maybe_transform_key(key, type_)] = await _async_transform_recursive(
value, annotation=type_
)
return result


Expand All @@ -469,4 +544,6 @@ def get_type_hints(
localns: Mapping[str, Any] | None = None,
include_extras: bool = False,
) -> dict[str, Any]:
return _get_type_hints(obj, globalns=globalns, localns=localns, include_extras=include_extras)
return _get_type_hints(
obj, globalns=globalns, localns=localns, include_extras=include_extras
)
Loading