Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 69 additions & 31 deletions google/ads/googleads/interceptors/exception_interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,32 +19,55 @@
so it translates the error to a GoogleAdsFailure instance and raises it.
"""

import grpc

from grpc import UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor
from typing import Any, Callable, TypeVar, Union, NoReturn

import grpc
from grpc import (
UnaryUnaryClientInterceptor,
UnaryStreamClientInterceptor,
ClientCallDetails,
Call,
)

from google.ads.googleads.errors import GoogleAdsException
from .interceptor import Interceptor
from .response_wrappers import _UnaryStreamWrapper, _UnaryUnaryWrapper

# Define generic types for request and response messages.
# These are typically protobuf message instances.
RequestType = TypeVar("RequestType")
ResponseType = TypeVar("ResponseType")
# Type for the continuation callable in intercept_unary_unary
UnaryUnaryContinuation = Callable[
[ClientCallDetails, RequestType], Union[Call, Any]
]
# Type for the continuation callable in intercept_unary_stream
UnaryStreamContinuation = Callable[
[ClientCallDetails, RequestType], Union[grpc.Call, Any]
]


class ExceptionInterceptor(
Interceptor, UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor
):
"""An interceptor that wraps rpc exceptions."""

def __init__(self, api_version, use_proto_plus=False):
_api_version: str
_use_proto_plus: bool

def __init__(self, api_version: str, use_proto_plus: bool = False):
"""Initializes the ExceptionInterceptor.

Args:
api_version: a str of the API version of the request.
use_proto_plus: a boolean of whether returned messages should be
api_version: A str of the API version of the request.
use_proto_plus: A boolean of whether returned messages should be
proto_plus or protobuf.
"""
super().__init__(api_version)
self._api_version = api_version
self._use_proto_plus = use_proto_plus

def _handle_grpc_failure(self, response):
def _handle_grpc_failure(self, response: grpc.Call) -> NoReturn:
"""Attempts to convert failed responses to a GoogleAdsException object.

Handles failed gRPC responses of by attempting to convert them
Expand All @@ -61,16 +84,23 @@ def _handle_grpc_failure(self, response):
Raises:
GoogleAdsException: If the exception's trailing metadata
indicates that it is a GoogleAdsException.
RpcError: If the exception's is a gRPC exception but the trailing
grpc.RpcError: If the exception's is a gRPC exception but the trailing
metadata is empty or is not indicative of a GoogleAdsException,
or if the exception has a status code of INTERNAL or
RESOURCE_EXHAUSTED.
Exception: If not a GoogleAdsException or RpcException the error
will be raised as-is.
"""
raise self._get_error_from_response(response)

def intercept_unary_unary(self, continuation, client_call_details, request):
# Assuming _get_error_from_response is defined in the parent Interceptor
# and raises an exception, so this method effectively has -> NoReturn
raise self._get_error_from_response(response) # type: ignore

def intercept_unary_unary(
self,
continuation: UnaryUnaryContinuation[RequestType, ResponseType],
client_call_details: ClientCallDetails,
request: RequestType,
) -> Union[_UnaryUnaryWrapper, ResponseType, Call]:
"""Intercepts and wraps exceptions in the rpc response.

Overrides abstract method defined in grpc.UnaryUnaryClientInterceptor.
Expand All @@ -79,32 +109,41 @@ def intercept_unary_unary(self, continuation, client_call_details, request):
continuation: a function to continue the request process.
client_call_details: a grpc._interceptor._ClientCallDetails
instance containing request metadata.
request: a SearchGoogleAdsRequest or SearchGoogleAdsStreamRequest
message class instance.
request: A protobuf message class instance for the request.

Returns:
A grpc.Call instance representing a service response.
A _UnaryUnaryWrapper instance if successful, otherwise this method
will raise an exception via _handle_grpc_failure. The actual
return type from continuation can be grpc.Call or a future-like
object that has an `exception()` method.

Raises:
GoogleAdsException: If the exception's trailing metadata
indicates that it is a GoogleAdsException.
RpcError: If the exception's trailing metadata is empty or is not
grpc.RpcError: If the exception's trailing metadata is empty or is not
indicative of a GoogleAdsException, or if the exception has a
status code of INTERNAL or RESOURCE_EXHAUSTED.
"""
response = continuation(client_call_details, request)
exception = response.exception()
response_call = continuation(client_call_details, request)
# response_call is often a grpc.Call / grpc.Future in unary-unary.
# It has an exception() method to check for errors.
exception = response_call.exception()

if exception:
self._handle_grpc_failure(response)
# _handle_grpc_failure is guaranteed to raise, so the execution stops here.
self._handle_grpc_failure(response_call)
else:
# If there's no exception, wrap the successful response.
return _UnaryUnaryWrapper(
response, use_proto_plus=self._use_proto_plus
response_call, use_proto_plus=self._use_proto_plus
)

def intercept_unary_stream(
self, continuation, client_call_details, request
):
self,
continuation: UnaryStreamContinuation[RequestType, ResponseType],
client_call_details: ClientCallDetails,
request: RequestType,
) -> _UnaryStreamWrapper:
"""Intercepts and wraps exceptions in the rpc response.

Overrides abstract method defined in grpc.UnaryStreamClientInterceptor.
Expand All @@ -113,22 +152,21 @@ def intercept_unary_stream(
continuation: a function to continue the request process.
client_call_details: a grpc._interceptor._ClientCallDetails
instance containing request metadata.
request: a SearchGoogleAdsRequest or SearchGoogleAdsStreamRequest
message class instance.
request: A protobuf message class instance for the request.

Returns:
A grpc.Call instance representing a service response.
A _UnaryStreamWrapper instance that wraps the stream response.

Raises:
GoogleAdsException: If the exception's trailing metadata
indicates that it is a GoogleAdsException.
RpcError: If the exception's trailing metadata is empty or is not
indicative of a GoogleAdsException, or if the exception has a
status code of INTERNAL or RESOURCE_EXHAUSTED.
This method itself doesn't raise directly but passes
_handle_grpc_failure to _UnaryStreamWrapper, which may raise if
errors occur during streaming or if the initial call fails.
"""
response = continuation(client_call_details, request)
# In unary-stream, continuation returns an object that is an iterator
# of responses, often a grpc.Call.
response_stream_call = continuation(client_call_details, request)
return _UnaryStreamWrapper(
response,
response_stream_call, # type: ignore
self._handle_grpc_failure,
use_proto_plus=self._use_proto_plus,
)
79 changes: 61 additions & 18 deletions google/ads/googleads/interceptors/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List, TypeVar
from copy import deepcopy
from google.protobuf.message import Message

from google.ads.googleads.util import (
set_nested_message_field,
get_nested_attr,
Expand All @@ -29,7 +32,7 @@
# 1. They are returned as part of a Search or SearchStream request.
# 2. They are returned individually in a Get request.
# 3. They are sent to the API as part of a Mutate request.
_MESSAGES_WITH_SENSITIVE_FIELDS = {
_MESSAGES_WITH_SENSITIVE_FIELDS: Dict[str, List[str]] = {
"CustomerUserAccess": ["email_address", "inviter_user_email_address"],
"CustomerUserAccessInvitation": ["email_address"],
"MutateCustomerUserAccessRequest": [
Expand All @@ -50,26 +53,30 @@
# This is a list of the names of messages that return search results from the
# API. These messages contain other messages that may contain sensitive
# information that needs to be masked before being logged.
_SEARCH_RESPONSE_MESSAGE_NAMES = [
_SEARCH_RESPONSE_MESSAGE_NAMES: List[str] = [
"SearchGoogleAdsResponse",
"SearchGoogleAdsStreamResponse",
]

ProtoMessageT = TypeVar("ProtoMessageT", bound=Message)


def _copy_message(message):
def _copy_message(message: ProtoMessageT) -> ProtoMessageT:
"""Returns a copy of the given message.

Args:
message: An object containing information from an API request
or response.
or response, expected to be a protobuf Message.

Returns:
A copy of the given message.
"""
return deepcopy(message)


def _mask_message_fields(field_list, message, mask):
def _mask_message_fields(
field_list: List[str], message: ProtoMessageT, mask: str
) -> ProtoMessageT:
"""Copies the given message and masks sensitive fields.

Sensitive fields are given as a list of strings and are overridden
Expand All @@ -79,15 +86,21 @@ def _mask_message_fields(field_list, message, mask):
field_list: A list of strings specifying the fields on the message
that should be masked.
message: An object containing information from an API request
or response.
or response, expected to be a protobuf Message.
mask: A str that should replace the sensitive information in the
message.

Returns:
A new instance of the message object with fields copied and masked
where necessary.
"""
copy = _copy_message(message)
# Ensure that the message is not None and is of a type that can be copied.
# The ProtoMessageT TypeVar already implies it's a protobuf message.
if message is None:
# Or handle this case as appropriate, e.g., raise ValueError
return message # Or an empty message of the same type, if possible

copy: ProtoMessageT = _copy_message(message)

for field_path in field_list:
try:
Expand All @@ -98,16 +111,21 @@ def _mask_message_fields(field_list, message, mask):
# AttributeError is raised when the field is not defined on the
# message. In this case there's nothing to mask and the field
# should be skipped.
break
# Original code had "break", which would exit the loop entirely
# after the first AttributeError. "continue" seems more appropriate
# to skip only the problematic field_path.
continue

return copy


def _mask_google_ads_search_response(message, mask):
def _mask_google_ads_search_response(message: Any, mask: str) -> Any:
"""Copies and masks sensitive data in a Search response

Response messages include instances of GoogleAdsSearchResponse and
GoogleAdsSearchStreamResponse.
GoogleAdsSearchStreamResponse. For typing, these are kept as Any
due to the dynamic nature of protobuf messages and to avoid circular
dependencies if specific types were imported.

Args:
message: A SearchGoogleAdsResponse or SearchGoogleAdsStreamResponse
Expand All @@ -118,7 +136,13 @@ def _mask_google_ads_search_response(message, mask):
Returns:
A copy of the message with sensitive fields masked.
"""
copy = _copy_message(message)
# Given message is Any, the copy will also be Any.
# Specific handling for protobuf-like objects is assumed.
copy: Any = _copy_message(message)

# Assuming 'copy' has a 'results' attribute, which is iterable.
if not hasattr(copy, "results"):
return copy # Or raise an error if 'results' is expected

for row in copy.results:
# Each row is an instance of GoogleAdsRow. The ListFields method
Expand Down Expand Up @@ -148,29 +172,48 @@ def _mask_google_ads_search_response(message, mask):
)
# Overwrites the nested message with an exact copy of itself,
# where sensitive fields have been masked.
proto_copy_from(getattr(row, field_name), masked_message)
# for proto_plus messages, _pb holds the protobuf message
# for protobuf messages, it's the message itself
target_nested_message = getattr(row, field_name)
proto_copy_from(target_nested_message, masked_message)

return copy


def mask_message(message, mask):
def mask_message(message: Any, mask: str) -> Any:
"""Copies and returns a message with sensitive fields masked.

Args:
message: An object containing information from an API request
or response.
or response. This is typed as Any due to the variety of
protobuf message types it can handle.
mask: A str that should replace the sensitive information in the
message.

Returns:
A copy of the message instance with sensitive fields masked.
A copy of the message instance with sensitive fields masked, or the
original message if no masking rules apply. The return type is Any,
mirroring the input message type.
"""
class_name = message.__class__.__name__
if not hasattr(message, "__class__") or not hasattr(message.__class__, "__name__"):
# Not an object we can get a class name from, return as is.
return message

class_name: str = message.__class__.__name__

if class_name in _SEARCH_RESPONSE_MESSAGE_NAMES:
# _mask_google_ads_search_response expects Any and returns Any
return _mask_google_ads_search_response(message, mask)
elif class_name in _MESSAGES_WITH_SENSITIVE_FIELDS.keys():
sensitive_fields = _MESSAGES_WITH_SENSITIVE_FIELDS[class_name]
elif class_name in _MESSAGES_WITH_SENSITIVE_FIELDS:
sensitive_fields: List[str] = _MESSAGES_WITH_SENSITIVE_FIELDS[class_name]
# _mask_message_fields is generic over ProtoMessageT.
# Since 'message' is Any here, we're passing Any.
# This might lose some type safety if 'message' isn't actually a Message.
# However, the function's logic implies it expects a message-like object.
# If 'message' here is guaranteed to be a protobuf message,
# we could potentially cast or check, but 'Any' is safer for now.
return _mask_message_fields(sensitive_fields, message, mask)
else:
# If not a special type, return the message as is (or a copy if preferred)
# The original code returns the original message.
return message
Loading