Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -171,16 +171,23 @@ def _get_http_options():
timeout,
transcoded_request,
body=None):

uri = transcoded_request['uri']
method = transcoded_request['method']
headers = dict(metadata)
headers['Content-Type'] = 'application/json'
# Build query string manually to avoid URL-encoding special characters like '$'.
# The `requests` library encodes '$' as '%24' when using the `params` argument,
# which causes API errors for parameters like '$alt'. See:
# https://github.com/googleapis/gapic-generator-python/issues/2514
_query_params = rest_helpers.flatten_query_params(query_params, strict=True)
_request_url = "{host}{uri}".format(host=host, uri=uri)
if _query_params:
_request_url = "{}?{}".format(_request_url, urlencode(_query_params, safe="$"))
response = {{ await_prefix }}getattr(session, method)(
"{host}{uri}".format(host=host, uri=uri),
_request_url,
timeout=timeout,
headers=headers,
params=rest_helpers.flatten_query_params(query_params, strict=True),
{% if body_spec %}
data=body,
{% endif %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,13 @@ from google.iam.v1 import policy_pb2 # type: ignore
from google.cloud.location import locations_pb2 # type: ignore
{% endif %}

from requests import __version__ as requests_version
import dataclasses
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
from urllib.parse import urlencode
import warnings

from requests import __version__ as requests_version

{{ shared_macros.operations_mixin_imports(api, service, opts) }}

from .rest_base import _Base{{ service.name }}RestTransport
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ import grpc
from grpc.experimental import aio
{% if "rest" in opts.transport %}
from collections.abc import Iterable
import urllib.parse

from google.protobuf import json_format
import json
{% endif %}
Expand Down Expand Up @@ -45,6 +47,7 @@ from google.api_core import client_options
from google.api_core import exceptions as core_exceptions
from google.api_core import grpc_helpers
from google.api_core import path_template
from google.api_core import rest_helpers
from google.api_core import retry as retries
{% if service.has_lro %}
from google.api_core import future
Expand Down Expand Up @@ -1451,8 +1454,12 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
('$alt', 'json;enum-encoding=int')
{% endif %}
]
actual_params = req.call_args.kwargs['params']
assert expected_params == actual_params
# Verify query params are correctly included in the URL
# Session.request is called as request(method, url, ...), so url is args[1]
actual_url = req.call_args.args[1]
parsed_url = urllib.parse.urlparse(actual_url)
actual_params = urllib.parse.parse_qsl(parsed_url.query, keep_blank_values=True)
assert set(expected_params).issubset(set(actual_params))


def test_{{ method_name }}_rest_unset_required_fields():
Expand All @@ -1461,9 +1468,55 @@ def test_{{ method_name }}_rest_unset_required_fields():
unset_fields = transport.{{ method.transport_safe_name|snake_case }}._get_unset_required_fields({})
assert set(unset_fields) == (set(({% for param in method.query_params|sort %}"{{ param|camel_case }}", {% endfor %})) & set(({% for param in method.input.required_fields %}"{{param.name|camel_case}}", {% endfor %})))


{% endif %}{# required_fields #}


def test_{{ method_name }}_rest_url_query_params_encoding():
# Verify that special characters like '$' are correctly preserved (not URL-encoded)
# when building the URL query string. This tests the urlencode call with safe="$".
transport = transports.{{ service.rest_transport_name }}(credentials=ga_credentials.AnonymousCredentials)
method_class = transport.{{ method.transport_safe_name|snake_case }}.__class__
# Get the _get_response method from the method class
get_response_fn = method_class._get_response.__func__

mock_session = mock.Mock()
mock_response = mock.Mock()
mock_response.status_code = 200
mock_session.get.return_value = mock_response
mock_session.post.return_value = mock_response
mock_session.put.return_value = mock_response
mock_session.patch.return_value = mock_response
mock_session.delete.return_value = mock_response

# Mock flatten_query_params to return query params that include '$' character
with mock.patch.object(rest_helpers, 'flatten_query_params') as mock_flatten:
mock_flatten.return_value = [('$alt', 'json;enum-encoding=int'), ('foo', 'bar')]

transcoded_request = {
'uri': '/v1/test',
'method': '{{ method.http_options[0].method }}',
}

get_response_fn(
host='https://example.com',
metadata=[],
query_params={},
session=mock_session,
timeout=None,
transcoded_request=transcoded_request,
)

# Verify the session method was called with the URL containing query params
session_method = getattr(mock_session, '{{ method.http_options[0].method }}')
assert session_method.called

# The URL should contain '$alt' (not '%24alt') because safe="$" is used
call_url = session_method.call_args.args[0]
assert '$alt=json' in call_url
assert '%24alt' not in call_url
assert 'foo=bar' in call_url


{% if not method.client_streaming %}
@pytest.mark.parametrize("null_interceptor", [True, False])
def test_{{ method_name }}_rest_interceptors(null_interceptor):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,16 +164,23 @@ def _get_http_options():
timeout,
transcoded_request,
body=None):

uri = transcoded_request['uri']
method = transcoded_request['method']
headers = dict(metadata)
headers['Content-Type'] = 'application/json'
# Build query string manually to avoid URL-encoding special characters like '$'.
# The `requests` library encodes '$' as '%24' when using the `params` argument,
# which causes API errors for parameters like '$alt'. See:
# https://github.com/googleapis/gapic-generator-python/issues/2514
_query_params = rest_helpers.flatten_query_params(query_params, strict=True)
_request_url = "{host}{uri}".format(host=host, uri=uri)
if _query_params:
_request_url = "{}?{}".format(_request_url, urlencode(_query_params, safe="$"))
response = {{ await_prefix }}getattr(session, method)(
"{host}{uri}".format(host=host, uri=uri),
_request_url,
timeout=timeout,
headers=headers,
params=rest_helpers.flatten_query_params(query_params, strict=True),
{% if body_spec %}
data=body,
{% endif %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,13 @@ from google.iam.v1 import policy_pb2 # type: ignore
from google.cloud.location import locations_pb2 # type: ignore
{% endif %}

from requests import __version__ as requests_version
import dataclasses
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
from urllib.parse import urlencode
import warnings

from requests import __version__ as requests_version

{{ shared_macros.operations_mixin_imports(api, service, opts) }}

from .rest_base import _Base{{ service.name }}RestTransport
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,10 @@ from google.iam.v1 import policy_pb2 # type: ignore
from google.cloud.location import locations_pb2 # type: ignore
{% endif %}

import json # type: ignore
import dataclasses
import json # type: ignore
from typing import Any, Dict, List, Callable, Tuple, Optional, Sequence, Union
from urllib.parse import urlencode

{{ shared_macros.operations_mixin_imports(api, service, opts) }}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import grpc
from grpc.experimental import aio
{% if "rest" in opts.transport %}
from collections.abc import Iterable, AsyncIterable
import urllib.parse

from google.protobuf import json_format
{% endif %}
import json
Expand Down Expand Up @@ -72,6 +74,7 @@ from google.api_core import exceptions as core_exceptions
from google.api_core import grpc_helpers
from google.api_core import grpc_helpers_async
from google.api_core import path_template
from google.api_core import rest_helpers
from google.api_core import retry as retries
{% if service.has_lro or service.has_extended_lro %}
from google.api_core import future
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1200,8 +1200,12 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
('$alt', 'json;enum-encoding=int')
{% endif %}
]
actual_params = req.call_args.kwargs['params']
assert expected_params == actual_params
# Verify query params are correctly included in the URL
# Session.request is called as request(method, url, ...), so url is args[1]
actual_url = req.call_args.args[1]
parsed_url = urllib.parse.urlparse(actual_url)
actual_params = urllib.parse.parse_qsl(parsed_url.query, keep_blank_values=True)
assert set(expected_params).issubset(set(actual_params))


def test_{{ method_name }}_rest_unset_required_fields():
Expand All @@ -1213,6 +1217,52 @@ def test_{{ method_name }}_rest_unset_required_fields():
{% endif %}{# required_fields #}


def test_{{ method_name }}_rest_url_query_params_encoding():
# Verify that special characters like '$' are correctly preserved (not URL-encoded)
# when building the URL query string. This tests the urlencode call with safe="$".
transport = transports.{{ service.rest_transport_name }}(credentials=ga_credentials.AnonymousCredentials)
method_class = transport.{{ method.transport_safe_name|snake_case }}.__class__
# Get the _get_response method from the method class
get_response_fn = method_class._get_response.__func__

mock_session = mock.Mock()
mock_response = mock.Mock()
mock_response.status_code = 200
mock_session.get.return_value = mock_response
mock_session.post.return_value = mock_response
mock_session.put.return_value = mock_response
mock_session.patch.return_value = mock_response
mock_session.delete.return_value = mock_response

# Mock flatten_query_params to return query params that include '$' character
with mock.patch.object(rest_helpers, 'flatten_query_params') as mock_flatten:
mock_flatten.return_value = [('$alt', 'json;enum-encoding=int'), ('foo', 'bar')]

transcoded_request = {
'uri': '/v1/test',
'method': '{{ method.http_options[0].method }}',
}

get_response_fn(
host='https://example.com',
metadata=[],
query_params={},
session=mock_session,
timeout=None,
transcoded_request=transcoded_request,
)

# Verify the session method was called with the URL containing query params
session_method = getattr(mock_session, '{{ method.http_options[0].method }}')
assert session_method.called

# The URL should contain '$alt' (not '%24alt') because safe="$" is used
call_url = session_method.call_args.args[0]
assert '$alt=json' in call_url
assert '%24alt' not in call_url
assert 'foo=bar' in call_url


{% if method.flattened_fields and not method.client_streaming %}
def test_{{ method_name }}_rest_flattened():
client = {{ service.client_name }}(
Expand Down
Loading