diff --git a/endpoints/__init__.py b/endpoints/__init__.py index d811e9f..bc54a84 100644 --- a/endpoints/__init__.py +++ b/endpoints/__init__.py @@ -20,9 +20,9 @@ # pylint: disable=wildcard-import from __future__ import absolute_import -from protorpc import message_types -from protorpc import messages -from protorpc import remote +from .bundled.protorpc import message_types +from .bundled.protorpc import messages +from .bundled.protorpc import remote from .api_config import api, method from .api_config import AUTH_LEVEL, EMAIL_SCOPE diff --git a/endpoints/_endpointscfg_impl.py b/endpoints/_endpointscfg_impl.py index e5d97ba..4e8c72d 100644 --- a/endpoints/_endpointscfg_impl.py +++ b/endpoints/_endpointscfg_impl.py @@ -188,7 +188,7 @@ def GenApiConfig(service_class_names, config_string_generator=None, resolved_services.extend(service.get_api_classes()) elif (not isinstance(service, type) or not issubclass(service, remote.Service)): - raise TypeError('%s is not a ProtoRPC service' % service_class_name) + raise TypeError('%s is not a subclass of endpoints.remote.Service' % service_class_name) else: resolved_services.append(service) diff --git a/endpoints/api_config.py b/endpoints/api_config.py index 93b9113..d753805 100644 --- a/endpoints/api_config.py +++ b/endpoints/api_config.py @@ -43,7 +43,7 @@ def entries_get(self, request): import attr import semver -from protorpc import util +from .bundled.protorpc import util from . import api_exceptions from . import constants diff --git a/endpoints/apiserving.py b/endpoints/apiserving.py index fd2776f..5ec941f 100644 --- a/endpoints/apiserving.py +++ b/endpoints/apiserving.py @@ -70,7 +70,7 @@ def list(self, request): from endpoints_management.control import client as control_client from endpoints_management.control import wsgi as control_wsgi -from protorpc.wsgi import service as wsgi_service +from .bundled.protorpc.wsgi import service as wsgi_service from . import api_config from . import api_exceptions @@ -564,6 +564,10 @@ def api_server(api_services, **kwargs): if 'protocols' in kwargs: raise TypeError("__init__() got an unexpected keyword argument 'protocols'") + for service in api_services: + if not issubclass(service, remote.Service): + raise TypeError('%s is not a subclass of endpoints.remote.Service' % service) + # Construct the api serving app apis_app = _ApiServer(api_services, **kwargs) dispatcher = endpoints_dispatcher.EndpointsDispatcherMiddleware(apis_app) diff --git a/endpoints/bundled/__init__.py b/endpoints/bundled/__init__.py new file mode 100644 index 0000000..7bbe865 --- /dev/null +++ b/endpoints/bundled/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2018 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Embedded libraries.""" diff --git a/endpoints/bundled/protorpc/__init__.py b/endpoints/bundled/protorpc/__init__.py new file mode 100644 index 0000000..9005262 --- /dev/null +++ b/endpoints/bundled/protorpc/__init__.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python +# +# Copyright 2011 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Main module for ProtoRPC package.""" + +__author__ = 'rafek@google.com (Rafe Kaplan)' +__version__ = '1.0' diff --git a/endpoints/bundled/protorpc/definition.py b/endpoints/bundled/protorpc/definition.py new file mode 100644 index 0000000..46ee167 --- /dev/null +++ b/endpoints/bundled/protorpc/definition.py @@ -0,0 +1,290 @@ +#!/usr/bin/env python +# +# Copyright 2010 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Stub library.""" +import six + +__author__ = 'rafek@google.com (Rafe Kaplan)' + +import sys +import types + +from . import descriptor +from . import message_types +from . import messages +from . import protobuf +from . import remote +from . import util + +__all__ = [ + 'define_enum', + 'define_field', + 'define_file', + 'define_message', + 'define_service', + 'import_file', + 'import_file_set', +] + + +# Map variant back to message field classes. +def _build_variant_map(): + """Map variants to fields. + + Returns: + Dictionary mapping field variant to its associated field type. + """ + result = {} + for name in dir(messages): + value = getattr(messages, name) + if isinstance(value, type) and issubclass(value, messages.Field): + for variant in getattr(value, 'VARIANTS', []): + result[variant] = value + return result + +_VARIANT_MAP = _build_variant_map() + +_MESSAGE_TYPE_MAP = { + message_types.DateTimeMessage.definition_name(): message_types.DateTimeField, +} + + +def _get_or_define_module(full_name, modules): + """Helper method for defining new modules. + + Args: + full_name: Fully qualified name of module to create or return. + modules: Dictionary of all modules. Defaults to sys.modules. + + Returns: + Named module if found in 'modules', else creates new module and inserts in + 'modules'. Will also construct parent modules if necessary. + """ + module = modules.get(full_name) + if not module: + module = types.ModuleType(full_name) + modules[full_name] = module + + split_name = full_name.rsplit('.', 1) + if len(split_name) > 1: + parent_module_name, sub_module_name = split_name + parent_module = _get_or_define_module(parent_module_name, modules) + setattr(parent_module, sub_module_name, module) + + return module + + +def define_enum(enum_descriptor, module_name): + """Define Enum class from descriptor. + + Args: + enum_descriptor: EnumDescriptor to build Enum class from. + module_name: Module name to give new descriptor class. + + Returns: + New messages.Enum sub-class as described by enum_descriptor. + """ + enum_values = enum_descriptor.values or [] + + class_dict = dict((value.name, value.number) for value in enum_values) + class_dict['__module__'] = module_name + return type(str(enum_descriptor.name), (messages.Enum,), class_dict) + + +def define_field(field_descriptor): + """Define Field instance from descriptor. + + Args: + field_descriptor: FieldDescriptor class to build field instance from. + + Returns: + New field instance as described by enum_descriptor. + """ + field_class = _VARIANT_MAP[field_descriptor.variant] + params = {'number': field_descriptor.number, + 'variant': field_descriptor.variant, + } + + if field_descriptor.label == descriptor.FieldDescriptor.Label.REQUIRED: + params['required'] = True + elif field_descriptor.label == descriptor.FieldDescriptor.Label.REPEATED: + params['repeated'] = True + + message_type_field = _MESSAGE_TYPE_MAP.get(field_descriptor.type_name) + if message_type_field: + return message_type_field(**params) + elif field_class in (messages.EnumField, messages.MessageField): + return field_class(field_descriptor.type_name, **params) + else: + if field_descriptor.default_value: + value = field_descriptor.default_value + try: + value = descriptor._DEFAULT_FROM_STRING_MAP[field_class](value) + except (TypeError, ValueError, KeyError): + pass # Let the value pass to the constructor. + params['default'] = value + return field_class(**params) + + +def define_message(message_descriptor, module_name): + """Define Message class from descriptor. + + Args: + message_descriptor: MessageDescriptor to describe message class from. + module_name: Module name to give to new descriptor class. + + Returns: + New messages.Message sub-class as described by message_descriptor. + """ + class_dict = {'__module__': module_name} + + for enum in message_descriptor.enum_types or []: + enum_instance = define_enum(enum, module_name) + class_dict[enum.name] = enum_instance + + # TODO(rafek): support nested messages when supported by descriptor. + + for field in message_descriptor.fields or []: + field_instance = define_field(field) + class_dict[field.name] = field_instance + + class_name = message_descriptor.name.encode('utf-8') + return type(class_name, (messages.Message,), class_dict) + + +def define_service(service_descriptor, module): + """Define a new service proxy. + + Args: + service_descriptor: ServiceDescriptor class that describes the service. + module: Module to add service to. Request and response types are found + relative to this module. + + Returns: + Service class proxy capable of communicating with a remote server. + """ + class_dict = {'__module__': module.__name__} + class_name = service_descriptor.name.encode('utf-8') + + for method_descriptor in service_descriptor.methods or []: + request_definition = messages.find_definition( + method_descriptor.request_type, module) + response_definition = messages.find_definition( + method_descriptor.response_type, module) + + method_name = method_descriptor.name.encode('utf-8') + def remote_method(self, request): + """Actual service method.""" + raise NotImplementedError('Method is not implemented') + remote_method.__name__ = method_name + remote_method_decorator = remote.method(request_definition, + response_definition) + + class_dict[method_name] = remote_method_decorator(remote_method) + + service_class = type(class_name, (remote.Service,), class_dict) + return service_class + + +def define_file(file_descriptor, module=None): + """Define module from FileDescriptor. + + Args: + file_descriptor: FileDescriptor instance to describe module from. + module: Module to add contained objects to. Module name overrides value + in file_descriptor.package. Definitions are added to existing + module if provided. + + Returns: + If no module provided, will create a new module with its name set to the + file descriptor's package. If a module is provided, returns the same + module. + """ + if module is None: + module = types.ModuleType(file_descriptor.package) + + for enum_descriptor in file_descriptor.enum_types or []: + enum_class = define_enum(enum_descriptor, module.__name__) + setattr(module, enum_descriptor.name, enum_class) + + for message_descriptor in file_descriptor.message_types or []: + message_class = define_message(message_descriptor, module.__name__) + setattr(module, message_descriptor.name, message_class) + + for service_descriptor in file_descriptor.service_types or []: + service_class = define_service(service_descriptor, module) + setattr(module, service_descriptor.name, service_class) + + return module + + +@util.positional(1) +def import_file(file_descriptor, modules=None): + """Import FileDescriptor in to module space. + + This is like define_file except that a new module and any required parent + modules are created and added to the modules parameter or sys.modules if not + provided. + + Args: + file_descriptor: FileDescriptor instance to describe module from. + modules: Dictionary of modules to update. Modules and their parents that + do not exist will be created. If an existing module is found that + matches file_descriptor.package, that module is updated with the + FileDescriptor contents. + + Returns: + Module found in modules, else a new module. + """ + if not file_descriptor.package: + raise ValueError('File descriptor must have package name') + + if modules is None: + modules = sys.modules + + module = _get_or_define_module(file_descriptor.package.encode('utf-8'), + modules) + + return define_file(file_descriptor, module) + + +@util.positional(1) +def import_file_set(file_set, modules=None, _open=open): + """Import FileSet in to module space. + + Args: + file_set: If string, open file and read serialized FileSet. Otherwise, + a FileSet instance to import definitions from. + modules: Dictionary of modules to update. Modules and their parents that + do not exist will be created. If an existing module is found that + matches file_descriptor.package, that module is updated with the + FileDescriptor contents. + _open: Used for dependency injection during tests. + """ + if isinstance(file_set, six.string_types): + encoded_file = _open(file_set, 'rb') + try: + encoded_file_set = encoded_file.read() + finally: + encoded_file.close() + + file_set = protobuf.decode_message(descriptor.FileSet, encoded_file_set) + + for file_descriptor in file_set.files: + # Do not reload built in protorpc classes. + if not file_descriptor.package.startswith('protorpc.'): + import_file(file_descriptor, modules=modules) diff --git a/endpoints/bundled/protorpc/descriptor.py b/endpoints/bundled/protorpc/descriptor.py new file mode 100644 index 0000000..5f9e2e7 --- /dev/null +++ b/endpoints/bundled/protorpc/descriptor.py @@ -0,0 +1,712 @@ +#!/usr/bin/env python +# +# Copyright 2010 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Services descriptor definitions. + +Contains message definitions and functions for converting +service classes into transmittable message format. + +Describing an Enum instance, Enum class, Field class or Message class will +generate an appropriate descriptor object that describes that class. +This message can itself be used to transmit information to clients wishing +to know the description of an enum value, enum, field or message without +needing to download the source code. This format is also compatible with +other, non-Python languages. + +The descriptors are modeled to be binary compatible with: + + http://code.google.com/p/protobuf/source/browse/trunk/src/google/protobuf/descriptor.proto + +NOTE: The names of types and fields are not always the same between these +descriptors and the ones defined in descriptor.proto. This was done in order +to make source code files that use these descriptors easier to read. For +example, it is not necessary to prefix TYPE to all the values in +FieldDescriptor.Variant as is done in descriptor.proto FieldDescriptorProto.Type. + +Example: + + class Pixel(messages.Message): + + x = messages.IntegerField(1, required=True) + y = messages.IntegerField(2, required=True) + + color = messages.BytesField(3) + + # Describe Pixel class using message descriptor. + fields = [] + + field = FieldDescriptor() + field.name = 'x' + field.number = 1 + field.label = FieldDescriptor.Label.REQUIRED + field.variant = FieldDescriptor.Variant.INT64 + fields.append(field) + + field = FieldDescriptor() + field.name = 'y' + field.number = 2 + field.label = FieldDescriptor.Label.REQUIRED + field.variant = FieldDescriptor.Variant.INT64 + fields.append(field) + + field = FieldDescriptor() + field.name = 'color' + field.number = 3 + field.label = FieldDescriptor.Label.OPTIONAL + field.variant = FieldDescriptor.Variant.BYTES + fields.append(field) + + message = MessageDescriptor() + message.name = 'Pixel' + message.fields = fields + + # Describing is the equivalent of building the above message. + message == describe_message(Pixel) + +Public Classes: + EnumValueDescriptor: Describes Enum values. + EnumDescriptor: Describes Enum classes. + FieldDescriptor: Describes field instances. + FileDescriptor: Describes a single 'file' unit. + FileSet: Describes a collection of file descriptors. + MessageDescriptor: Describes Message classes. + MethodDescriptor: Describes a method of a service. + ServiceDescriptor: Describes a services. + +Public Functions: + describe_enum_value: Describe an individual enum-value. + describe_enum: Describe an Enum class. + describe_field: Describe a Field definition. + describe_file: Describe a 'file' unit from a Python module or object. + describe_file_set: Describe a file set from a list of modules or objects. + describe_message: Describe a Message definition. + describe_method: Describe a Method definition. + describe_service: Describe a Service definition. +""" +import six + +__author__ = 'rafek@google.com (Rafe Kaplan)' + +import codecs +import types + +from . import messages +from . import util + + +__all__ = ['EnumDescriptor', + 'EnumValueDescriptor', + 'FieldDescriptor', + 'MessageDescriptor', + 'MethodDescriptor', + 'FileDescriptor', + 'FileSet', + 'ServiceDescriptor', + 'DescriptorLibrary', + + 'describe_enum', + 'describe_enum_value', + 'describe_field', + 'describe_message', + 'describe_method', + 'describe_file', + 'describe_file_set', + 'describe_service', + 'describe', + 'import_descriptor_loader', + ] + + +# NOTE: MessageField is missing because message fields cannot have +# a default value at this time. +# TODO(rafek): Support default message values. +# +# Map to functions that convert default values of fields of a given type +# to a string. The function must return a value that is compatible with +# FieldDescriptor.default_value and therefore a unicode string. +_DEFAULT_TO_STRING_MAP = { + messages.IntegerField: six.text_type, + messages.FloatField: six.text_type, + messages.BooleanField: lambda value: value and u'true' or u'false', + messages.BytesField: lambda value: codecs.escape_encode(value)[0], + messages.StringField: lambda value: value, + messages.EnumField: lambda value: six.text_type(value.number), +} + +_DEFAULT_FROM_STRING_MAP = { + messages.IntegerField: int, + messages.FloatField: float, + messages.BooleanField: lambda value: value == u'true', + messages.BytesField: lambda value: codecs.escape_decode(value)[0], + messages.StringField: lambda value: value, + messages.EnumField: int, +} + + +class EnumValueDescriptor(messages.Message): + """Enum value descriptor. + + Fields: + name: Name of enumeration value. + number: Number of enumeration value. + """ + + # TODO(rafek): Why are these listed as optional in descriptor.proto. + # Harmonize? + name = messages.StringField(1, required=True) + number = messages.IntegerField(2, + required=True, + variant=messages.Variant.INT32) + + +class EnumDescriptor(messages.Message): + """Enum class descriptor. + + Fields: + name: Name of Enum without any qualification. + values: Values defined by Enum class. + """ + + name = messages.StringField(1) + values = messages.MessageField(EnumValueDescriptor, 2, repeated=True) + + +class FieldDescriptor(messages.Message): + """Field definition descriptor. + + Enums: + Variant: Wire format hint sub-types for field. + Label: Values for optional, required and repeated fields. + + Fields: + name: Name of field. + number: Number of field. + variant: Variant of field. + type_name: Type name for message and enum fields. + default_value: String representation of default value. + """ + + Variant = messages.Variant + + class Label(messages.Enum): + """Field label.""" + + OPTIONAL = 1 + REQUIRED = 2 + REPEATED = 3 + + name = messages.StringField(1, required=True) + number = messages.IntegerField(3, + required=True, + variant=messages.Variant.INT32) + label = messages.EnumField(Label, 4, default=Label.OPTIONAL) + variant = messages.EnumField(Variant, 5) + type_name = messages.StringField(6) + + # For numeric types, contains the original text representation of the value. + # For booleans, "true" or "false". + # For strings, contains the default text contents (not escaped in any way). + # For bytes, contains the C escaped value. All bytes < 128 are that are + # traditionally considered unprintable are also escaped. + default_value = messages.StringField(7) + + +class MessageDescriptor(messages.Message): + """Message definition descriptor. + + Fields: + name: Name of Message without any qualification. + fields: Fields defined for message. + message_types: Nested Message classes defined on message. + enum_types: Nested Enum classes defined on message. + """ + + name = messages.StringField(1) + fields = messages.MessageField(FieldDescriptor, 2, repeated=True) + + message_types = messages.MessageField( + 'protorpc.descriptor.MessageDescriptor', 3, repeated=True) + enum_types = messages.MessageField(EnumDescriptor, 4, repeated=True) + + +class MethodDescriptor(messages.Message): + """Service method definition descriptor. + + Fields: + name: Name of service method. + request_type: Fully qualified or relative name of request message type. + response_type: Fully qualified or relative name of response message type. + """ + + name = messages.StringField(1) + + request_type = messages.StringField(2) + response_type = messages.StringField(3) + + +class ServiceDescriptor(messages.Message): + """Service definition descriptor. + + Fields: + name: Name of Service without any qualification. + methods: Remote methods of Service. + """ + + name = messages.StringField(1) + + methods = messages.MessageField(MethodDescriptor, 2, repeated=True) + + +class FileDescriptor(messages.Message): + """Description of file containing protobuf definitions. + + Fields: + package: Fully qualified name of package that definitions belong to. + message_types: Message definitions contained in file. + enum_types: Enum definitions contained in file. + service_types: Service definitions contained in file. + """ + + package = messages.StringField(2) + + # TODO(rafek): Add dependency field + + message_types = messages.MessageField(MessageDescriptor, 4, repeated=True) + enum_types = messages.MessageField(EnumDescriptor, 5, repeated=True) + service_types = messages.MessageField(ServiceDescriptor, 6, repeated=True) + + +class FileSet(messages.Message): + """A collection of FileDescriptors. + + Fields: + files: Files in file-set. + """ + + files = messages.MessageField(FileDescriptor, 1, repeated=True) + + +def describe_enum_value(enum_value): + """Build descriptor for Enum instance. + + Args: + enum_value: Enum value to provide descriptor for. + + Returns: + Initialized EnumValueDescriptor instance describing the Enum instance. + """ + enum_value_descriptor = EnumValueDescriptor() + enum_value_descriptor.name = six.text_type(enum_value.name) + enum_value_descriptor.number = enum_value.number + return enum_value_descriptor + + +def describe_enum(enum_definition): + """Build descriptor for Enum class. + + Args: + enum_definition: Enum class to provide descriptor for. + + Returns: + Initialized EnumDescriptor instance describing the Enum class. + """ + enum_descriptor = EnumDescriptor() + enum_descriptor.name = enum_definition.definition_name().split('.')[-1] + + values = [] + for number in enum_definition.numbers(): + value = enum_definition.lookup_by_number(number) + values.append(describe_enum_value(value)) + + if values: + enum_descriptor.values = values + + return enum_descriptor + + +def describe_field(field_definition): + """Build descriptor for Field instance. + + Args: + field_definition: Field instance to provide descriptor for. + + Returns: + Initialized FieldDescriptor instance describing the Field instance. + """ + field_descriptor = FieldDescriptor() + field_descriptor.name = field_definition.name + field_descriptor.number = field_definition.number + field_descriptor.variant = field_definition.variant + + if isinstance(field_definition, messages.EnumField): + field_descriptor.type_name = field_definition.type.definition_name() + + if isinstance(field_definition, messages.MessageField): + field_descriptor.type_name = field_definition.message_type.definition_name() + + if field_definition.default is not None: + field_descriptor.default_value = _DEFAULT_TO_STRING_MAP[ + type(field_definition)](field_definition.default) + + # Set label. + if field_definition.repeated: + field_descriptor.label = FieldDescriptor.Label.REPEATED + elif field_definition.required: + field_descriptor.label = FieldDescriptor.Label.REQUIRED + else: + field_descriptor.label = FieldDescriptor.Label.OPTIONAL + + return field_descriptor + + +def describe_message(message_definition): + """Build descriptor for Message class. + + Args: + message_definition: Message class to provide descriptor for. + + Returns: + Initialized MessageDescriptor instance describing the Message class. + """ + message_descriptor = MessageDescriptor() + message_descriptor.name = message_definition.definition_name().split('.')[-1] + + fields = sorted(message_definition.all_fields(), + key=lambda v: v.number) + if fields: + message_descriptor.fields = [describe_field(field) for field in fields] + + try: + nested_messages = message_definition.__messages__ + except AttributeError: + pass + else: + message_descriptors = [] + for name in nested_messages: + value = getattr(message_definition, name) + message_descriptors.append(describe_message(value)) + + message_descriptor.message_types = message_descriptors + + try: + nested_enums = message_definition.__enums__ + except AttributeError: + pass + else: + enum_descriptors = [] + for name in nested_enums: + value = getattr(message_definition, name) + enum_descriptors.append(describe_enum(value)) + + message_descriptor.enum_types = enum_descriptors + + return message_descriptor + + +def describe_method(method): + """Build descriptor for service method. + + Args: + method: Remote service method to describe. + + Returns: + Initialized MethodDescriptor instance describing the service method. + """ + method_info = method.remote + descriptor = MethodDescriptor() + descriptor.name = method_info.method.__name__ + descriptor.request_type = method_info.request_type.definition_name() + descriptor.response_type = method_info.response_type.definition_name() + + return descriptor + + +def describe_service(service_class): + """Build descriptor for service. + + Args: + service_class: Service class to describe. + + Returns: + Initialized ServiceDescriptor instance describing the service. + """ + descriptor = ServiceDescriptor() + descriptor.name = service_class.__name__ + methods = [] + remote_methods = service_class.all_remote_methods() + for name in sorted(remote_methods.keys()): + if name == 'get_descriptor': + continue + + method = remote_methods[name] + methods.append(describe_method(method)) + if methods: + descriptor.methods = methods + + return descriptor + + +def describe_file(module): + """Build a file from a specified Python module. + + Args: + module: Python module to describe. + + Returns: + Initialized FileDescriptor instance describing the module. + """ + # May not import remote at top of file because remote depends on this + # file + # TODO(rafek): Straighten out this dependency. Possibly move these functions + # from descriptor to their own module. + from . import remote + + descriptor = FileDescriptor() + descriptor.package = util.get_package_for_module(module) + + if not descriptor.package: + descriptor.package = None + + message_descriptors = [] + enum_descriptors = [] + service_descriptors = [] + + # Need to iterate over all top level attributes of the module looking for + # message, enum and service definitions. Each definition must be itself + # described. + for name in sorted(dir(module)): + value = getattr(module, name) + + if isinstance(value, type): + if issubclass(value, messages.Message): + message_descriptors.append(describe_message(value)) + + elif issubclass(value, messages.Enum): + enum_descriptors.append(describe_enum(value)) + + elif issubclass(value, remote.Service): + service_descriptors.append(describe_service(value)) + + if message_descriptors: + descriptor.message_types = message_descriptors + + if enum_descriptors: + descriptor.enum_types = enum_descriptors + + if service_descriptors: + descriptor.service_types = service_descriptors + + return descriptor + + +def describe_file_set(modules): + """Build a file set from a specified Python modules. + + Args: + modules: Iterable of Python module to describe. + + Returns: + Initialized FileSet instance describing the modules. + """ + descriptor = FileSet() + file_descriptors = [] + for module in modules: + file_descriptors.append(describe_file(module)) + + if file_descriptors: + descriptor.files = file_descriptors + + return descriptor + + +def describe(value): + """Describe any value as a descriptor. + + Helper function for describing any object with an appropriate descriptor + object. + + Args: + value: Value to describe as a descriptor. + + Returns: + Descriptor message class if object is describable as a descriptor, else + None. + """ + from . import remote + if isinstance(value, types.ModuleType): + return describe_file(value) + elif callable(value) and hasattr(value, 'remote'): + return describe_method(value) + elif isinstance(value, messages.Field): + return describe_field(value) + elif isinstance(value, messages.Enum): + return describe_enum_value(value) + elif isinstance(value, type): + if issubclass(value, messages.Message): + return describe_message(value) + elif issubclass(value, messages.Enum): + return describe_enum(value) + elif issubclass(value, remote.Service): + return describe_service(value) + return None + + +@util.positional(1) +def import_descriptor_loader(definition_name, importer=__import__): + """Find objects by importing modules as needed. + + A definition loader is a function that resolves a definition name to a + descriptor. + + The import finder resolves definitions to their names by importing modules + when necessary. + + Args: + definition_name: Name of definition to find. + importer: Import function used for importing new modules. + + Returns: + Appropriate descriptor for any describable type located by name. + + Raises: + DefinitionNotFoundError when a name does not refer to either a definition + or a module. + """ + # Attempt to import descriptor as a module. + if definition_name.startswith('.'): + definition_name = definition_name[1:] + if not definition_name.startswith('.'): + leaf = definition_name.split('.')[-1] + if definition_name: + try: + module = importer(definition_name, '', '', [leaf]) + except ImportError: + pass + else: + return describe(module) + + try: + # Attempt to use messages.find_definition to find item. + return describe(messages.find_definition(definition_name, + importer=__import__)) + except messages.DefinitionNotFoundError as err: + # There are things that find_definition will not find, but if the parent + # is loaded, its children can be searched for a match. + split_name = definition_name.rsplit('.', 1) + if len(split_name) > 1: + parent, child = split_name + try: + parent_definition = import_descriptor_loader(parent, importer=importer) + except messages.DefinitionNotFoundError: + # Fall through to original error. + pass + else: + # Check the parent definition for a matching descriptor. + if isinstance(parent_definition, FileDescriptor): + search_list = parent_definition.service_types or [] + elif isinstance(parent_definition, ServiceDescriptor): + search_list = parent_definition.methods or [] + elif isinstance(parent_definition, EnumDescriptor): + search_list = parent_definition.values or [] + elif isinstance(parent_definition, MessageDescriptor): + search_list = parent_definition.fields or [] + else: + search_list = [] + + for definition in search_list: + if definition.name == child: + return definition + + # Still didn't find. Reraise original exception. + raise err + + +class DescriptorLibrary(object): + """A descriptor library is an object that contains known definitions. + + A descriptor library contains a cache of descriptor objects mapped by + definition name. It contains all types of descriptors except for + file sets. + + When a definition name is requested that the library does not know about + it can be provided with a descriptor loader which attempt to resolve the + missing descriptor. + """ + + @util.positional(1) + def __init__(self, + descriptors=None, + descriptor_loader=import_descriptor_loader): + """Constructor. + + Args: + descriptors: A dictionary or dictionary-like object that can be used + to store and cache descriptors by definition name. + definition_loader: A function used for resolving missing descriptors. + The function takes a definition name as its parameter and returns + an appropriate descriptor. It may raise DefinitionNotFoundError. + """ + self.__descriptor_loader = descriptor_loader + self.__descriptors = descriptors or {} + + def lookup_descriptor(self, definition_name): + """Lookup descriptor by name. + + Get descriptor from library by name. If descriptor is not found will + attempt to find via descriptor loader if provided. + + Args: + definition_name: Definition name to find. + + Returns: + Descriptor that describes definition name. + + Raises: + DefinitionNotFoundError if not descriptor exists for definition name. + """ + try: + return self.__descriptors[definition_name] + except KeyError: + pass + + if self.__descriptor_loader: + definition = self.__descriptor_loader(definition_name) + self.__descriptors[definition_name] = definition + return definition + else: + raise messages.DefinitionNotFoundError( + 'Could not find definition for %s' % definition_name) + + def lookup_package(self, definition_name): + """Determines the package name for any definition. + + Determine the package that any definition name belongs to. May check + parent for package name and will resolve missing descriptors if provided + descriptor loader. + + Args: + definition_name: Definition name to find package for. + """ + while True: + descriptor = self.lookup_descriptor(definition_name) + if isinstance(descriptor, FileDescriptor): + return descriptor.package + else: + index = definition_name.rfind('.') + if index < 0: + return None + definition_name = definition_name[:index] diff --git a/endpoints/bundled/protorpc/google_imports.py b/endpoints/bundled/protorpc/google_imports.py new file mode 100644 index 0000000..7ed0f40 --- /dev/null +++ b/endpoints/bundled/protorpc/google_imports.py @@ -0,0 +1,15 @@ +"""Dynamically decide from where to import other SDK modules. + +All other protorpc code should import other SDK modules from +this module. If necessary, add new imports here (in both places). +""" + +__author__ = 'yey@google.com (Ye Yuan)' + +# pylint: disable=g-import-not-at-top +# pylint: disable=unused-import + +try: + from google.net.proto import ProtocolBuffer +except ImportError: + pass diff --git a/endpoints/bundled/protorpc/message_types.py b/endpoints/bundled/protorpc/message_types.py new file mode 100644 index 0000000..a1474b7 --- /dev/null +++ b/endpoints/bundled/protorpc/message_types.py @@ -0,0 +1,122 @@ +#!/usr/bin/env python +# +# Copyright 2010 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Simple protocol message types. + +Includes new message and field types that are outside what is defined by the +protocol buffers standard. +""" + +__author__ = 'rafek@google.com (Rafe Kaplan)' + +import datetime + +from . import messages +from . import util + +__all__ = [ + 'DateTimeField', + 'DateTimeMessage', + 'VoidMessage', +] + +package = 'protorpc.message_types' + + +class VoidMessage(messages.Message): + """Empty message.""" + + +class DateTimeMessage(messages.Message): + """Message to store/transmit a DateTime. + + Fields: + milliseconds: Milliseconds since Jan 1st 1970 local time. + time_zone_offset: Optional time zone offset, in minutes from UTC. + """ + milliseconds = messages.IntegerField(1, required=True) + time_zone_offset = messages.IntegerField(2) + + +class DateTimeField(messages.MessageField): + """Field definition for datetime values. + + Stores a python datetime object as a field. If time zone information is + included in the datetime object, it will be included in + the encoded data when this is encoded/decoded. + """ + + type = datetime.datetime + + message_type = DateTimeMessage + + @util.positional(3) + def __init__(self, + number, + **kwargs): + super(DateTimeField, self).__init__(self.message_type, + number, + **kwargs) + + def value_from_message(self, message): + """Convert DateTimeMessage to a datetime. + + Args: + A DateTimeMessage instance. + + Returns: + A datetime instance. + """ + message = super(DateTimeField, self).value_from_message(message) + if message.time_zone_offset is None: + return datetime.datetime.utcfromtimestamp(message.milliseconds / 1000.0) + + # Need to subtract the time zone offset, because when we call + # datetime.fromtimestamp, it will add the time zone offset to the + # value we pass. + milliseconds = (message.milliseconds - + 60000 * message.time_zone_offset) + + timezone = util.TimeZoneOffset(message.time_zone_offset) + return datetime.datetime.fromtimestamp(milliseconds / 1000.0, + tz=timezone) + + def value_to_message(self, value): + value = super(DateTimeField, self).value_to_message(value) + # First, determine the delta from the epoch, so we can fill in + # DateTimeMessage's milliseconds field. + if value.tzinfo is None: + time_zone_offset = 0 + local_epoch = datetime.datetime.utcfromtimestamp(0) + else: + time_zone_offset = util.total_seconds(value.tzinfo.utcoffset(value)) + # Determine Jan 1, 1970 local time. + local_epoch = datetime.datetime.fromtimestamp(-time_zone_offset, + tz=value.tzinfo) + delta = value - local_epoch + + # Create and fill in the DateTimeMessage, including time zone if + # one was specified. + message = DateTimeMessage() + message.milliseconds = int(util.total_seconds(delta) * 1000) + if value.tzinfo is not None: + utc_offset = value.tzinfo.utcoffset(value) + if utc_offset is not None: + message.time_zone_offset = int( + util.total_seconds(value.tzinfo.utcoffset(value)) / 60) + + return message diff --git a/endpoints/bundled/protorpc/messages.py b/endpoints/bundled/protorpc/messages.py new file mode 100644 index 0000000..f86ed48 --- /dev/null +++ b/endpoints/bundled/protorpc/messages.py @@ -0,0 +1,1951 @@ +#!/usr/bin/env python +# +# Copyright 2010 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Stand-alone implementation of in memory protocol messages. + +Public Classes: + Enum: Represents an enumerated type. + Variant: Hint for wire format to determine how to serialize. + Message: Base class for user defined messages. + IntegerField: Field for integer values. + FloatField: Field for float values. + BooleanField: Field for boolean values. + BytesField: Field for binary string values. + StringField: Field for UTF-8 string values. + MessageField: Field for other message type values. + EnumField: Field for enumerated type values. + +Public Exceptions (indentation indications class hierarchy): + EnumDefinitionError: Raised when enumeration is incorrectly defined. + FieldDefinitionError: Raised when field is incorrectly defined. + InvalidVariantError: Raised when variant is not compatible with field type. + InvalidDefaultError: Raised when default is not compatiable with field. + InvalidNumberError: Raised when field number is out of range or reserved. + MessageDefinitionError: Raised when message is incorrectly defined. + DuplicateNumberError: Raised when field has duplicate number with another. + ValidationError: Raised when a message or field is not valid. + DefinitionNotFoundError: Raised when definition not found. +""" +import six + +__author__ = 'rafek@google.com (Rafe Kaplan)' + + +import types +import weakref + +from . import util + +__all__ = ['MAX_ENUM_VALUE', + 'MAX_FIELD_NUMBER', + 'FIRST_RESERVED_FIELD_NUMBER', + 'LAST_RESERVED_FIELD_NUMBER', + + 'Enum', + 'Field', + 'FieldList', + 'Variant', + 'Message', + 'IntegerField', + 'FloatField', + 'BooleanField', + 'BytesField', + 'StringField', + 'MessageField', + 'EnumField', + 'find_definition', + + 'Error', + 'DecodeError', + 'EncodeError', + 'EnumDefinitionError', + 'FieldDefinitionError', + 'InvalidVariantError', + 'InvalidDefaultError', + 'InvalidNumberError', + 'MessageDefinitionError', + 'DuplicateNumberError', + 'ValidationError', + 'DefinitionNotFoundError', + ] + +package = 'protorpc.messages' + + +# TODO(rafek): Add extended module test to ensure all exceptions +# in services extends Error. +Error = util.Error + + +class EnumDefinitionError(Error): + """Enumeration definition error.""" + + +class FieldDefinitionError(Error): + """Field definition error.""" + + +class InvalidVariantError(FieldDefinitionError): + """Invalid variant provided to field.""" + + +class InvalidDefaultError(FieldDefinitionError): + """Invalid default provided to field.""" + + +class InvalidNumberError(FieldDefinitionError): + """Invalid number provided to field.""" + + +class MessageDefinitionError(Error): + """Message definition error.""" + + +class DuplicateNumberError(Error): + """Duplicate number assigned to field.""" + + +class DefinitionNotFoundError(Error): + """Raised when definition is not found.""" + + +class DecodeError(Error): + """Error found decoding message from encoded form.""" + + +class EncodeError(Error): + """Error found when encoding message.""" + + +class ValidationError(Error): + """Invalid value for message error.""" + + def __str__(self): + """Prints string with field name if present on exception.""" + message = Error.__str__(self) + try: + field_name = self.field_name + except AttributeError: + return message + else: + return message + + +# Attributes that are reserved by a class definition that +# may not be used by either Enum or Message class definitions. +_RESERVED_ATTRIBUTE_NAMES = frozenset( + ['__module__', '__doc__', '__qualname__']) + +_POST_INIT_FIELD_ATTRIBUTE_NAMES = frozenset( + ['name', + '_message_definition', + '_MessageField__type', + '_EnumField__type', + '_EnumField__resolved_default']) + +_POST_INIT_ATTRIBUTE_NAMES = frozenset( + ['_message_definition']) + +# Maximum enumeration value as defined by the protocol buffers standard. +# All enum values must be less than or equal to this value. +MAX_ENUM_VALUE = (2 ** 29) - 1 + +# Maximum field number as defined by the protocol buffers standard. +# All field numbers must be less than or equal to this value. +MAX_FIELD_NUMBER = (2 ** 29) - 1 + +# Field numbers between 19000 and 19999 inclusive are reserved by the +# protobuf protocol and may not be used by fields. +FIRST_RESERVED_FIELD_NUMBER = 19000 +LAST_RESERVED_FIELD_NUMBER = 19999 + + +class _DefinitionClass(type): + """Base meta-class used for definition meta-classes. + + The Enum and Message definition classes share some basic functionality. + Both of these classes may be contained by a Message definition. After + initialization, neither class may have attributes changed + except for the protected _message_definition attribute, and that attribute + may change only once. + """ + + __initialized = False + + def __init__(cls, name, bases, dct): + """Constructor.""" + type.__init__(cls, name, bases, dct) + # Base classes may never be initialized. + if cls.__bases__ != (object,): + cls.__initialized = True + + def message_definition(cls): + """Get outer Message definition that contains this definition. + + Returns: + Containing Message definition if definition is contained within one, + else None. + """ + try: + return cls._message_definition() + except AttributeError: + return None + + def __setattr__(cls, name, value): + """Overridden so that cannot set variables on definition classes after init. + + Setting attributes on a class must work during the period of initialization + to set the enumation value class variables and build the name/number maps. + Once __init__ has set the __initialized flag to True prohibits setting any + more values on the class. The class is in effect frozen. + + Args: + name: Name of value to set. + value: Value to set. + """ + if cls.__initialized and name not in _POST_INIT_ATTRIBUTE_NAMES: + raise AttributeError('May not change values: %s' % name) + else: + type.__setattr__(cls, name, value) + + def __delattr__(cls, name): + """Overridden so that cannot delete varaibles on definition classes.""" + raise TypeError('May not delete attributes on definition class') + + def definition_name(cls): + """Helper method for creating definition name. + + Names will be generated to include the classes package name, scope (if the + class is nested in another definition) and class name. + + By default, the package name for a definition is derived from its module + name. However, this value can be overriden by placing a 'package' attribute + in the module that contains the definition class. For example: + + package = 'some.alternate.package' + + class MyMessage(Message): + ... + + >>> MyMessage.definition_name() + some.alternate.package.MyMessage + + Returns: + Dot-separated fully qualified name of definition. + """ + outer_definition_name = cls.outer_definition_name() + if outer_definition_name is None: + return six.text_type(cls.__name__) + else: + return u'%s.%s' % (outer_definition_name, cls.__name__) + + def outer_definition_name(cls): + """Helper method for creating outer definition name. + + Returns: + If definition is nested, will return the outer definitions name, else the + package name. + """ + outer_definition = cls.message_definition() + if not outer_definition: + return util.get_package_for_module(cls.__module__) + else: + return outer_definition.definition_name() + + def definition_package(cls): + """Helper method for creating creating the package of a definition. + + Returns: + Name of package that definition belongs to. + """ + outer_definition = cls.message_definition() + if not outer_definition: + return util.get_package_for_module(cls.__module__) + else: + return outer_definition.definition_package() + + +class _EnumClass(_DefinitionClass): + """Meta-class used for defining the Enum base class. + + Meta-class enables very specific behavior for any defined Enum + class. All attributes defined on an Enum sub-class must be integers. + Each attribute defined on an Enum sub-class is translated + into an instance of that sub-class, with the name of the attribute + as its name, and the number provided as its value. It also ensures + that only one level of Enum class hierarchy is possible. In other + words it is not possible to delcare sub-classes of sub-classes of + Enum. + + This class also defines some functions in order to restrict the + behavior of the Enum class and its sub-classes. It is not possible + to change the behavior of the Enum class in later classes since + any new classes may be defined with only integer values, and no methods. + """ + + def __init__(cls, name, bases, dct): + # Can only define one level of sub-classes below Enum. + if not (bases == (object,) or bases == (Enum,)): + raise EnumDefinitionError('Enum type %s may only inherit from Enum' % + (name,)) + + cls.__by_number = {} + cls.__by_name = {} + + # Enum base class does not need to be initialized or locked. + if bases != (object,): + # Replace integer with number. + for attribute, value in dct.items(): + + # Module will be in every enum class. + if attribute in _RESERVED_ATTRIBUTE_NAMES: + continue + + # Reject anything that is not an int. + if not isinstance(value, six.integer_types): + raise EnumDefinitionError( + 'May only use integers in Enum definitions. Found: %s = %s' % + (attribute, value)) + + # Protocol buffer standard recommends non-negative values. + # Reject negative values. + if value < 0: + raise EnumDefinitionError( + 'Must use non-negative enum values. Found: %s = %d' % + (attribute, value)) + + if value > MAX_ENUM_VALUE: + raise EnumDefinitionError( + 'Must use enum values less than or equal %d. Found: %s = %d' % + (MAX_ENUM_VALUE, attribute, value)) + + if value in cls.__by_number: + raise EnumDefinitionError( + 'Value for %s = %d is already defined: %s' % + (attribute, value, cls.__by_number[value].name)) + + # Create enum instance and list in new Enum type. + instance = object.__new__(cls) + cls.__init__(instance, attribute, value) + cls.__by_name[instance.name] = instance + cls.__by_number[instance.number] = instance + setattr(cls, attribute, instance) + + _DefinitionClass.__init__(cls, name, bases, dct) + + def __iter__(cls): + """Iterate over all values of enum. + + Yields: + Enumeration instances of the Enum class in arbitrary order. + """ + return iter(cls.__by_number.values()) + + def names(cls): + """Get all names for Enum. + + Returns: + An iterator for names of the enumeration in arbitrary order. + """ + return cls.__by_name.keys() + + def numbers(cls): + """Get all numbers for Enum. + + Returns: + An iterator for all numbers of the enumeration in arbitrary order. + """ + return cls.__by_number.keys() + + def lookup_by_name(cls, name): + """Look up Enum by name. + + Args: + name: Name of enum to find. + + Returns: + Enum sub-class instance of that value. + """ + return cls.__by_name[name] + + def lookup_by_number(cls, number): + """Look up Enum by number. + + Args: + number: Number of enum to find. + + Returns: + Enum sub-class instance of that value. + """ + return cls.__by_number[number] + + def __len__(cls): + return len(cls.__by_name) + + +class Enum(six.with_metaclass(_EnumClass, object)): + """Base class for all enumerated types.""" + + __slots__ = set(('name', 'number')) + + def __new__(cls, index): + """Acts as look-up routine after class is initialized. + + The purpose of overriding __new__ is to provide a way to treat + Enum subclasses as casting types, similar to how the int type + functions. A program can pass a string or an integer and this + method with "convert" that value in to an appropriate Enum instance. + + Args: + index: Name or number to look up. During initialization + this is always the name of the new enum value. + + Raises: + TypeError: When an inappropriate index value is passed provided. + """ + # If is enum type of this class, return it. + if isinstance(index, cls): + return index + + # If number, look up by number. + if isinstance(index, six.integer_types): + try: + return cls.lookup_by_number(index) + except KeyError: + pass + + # If name, look up by name. + if isinstance(index, six.string_types): + try: + return cls.lookup_by_name(index) + except KeyError: + pass + + raise TypeError('No such value for %s in Enum %s' % + (index, cls.__name__)) + + def __init__(self, name, number=None): + """Initialize new Enum instance. + + Since this should only be called during class initialization any + calls that happen after the class is frozen raises an exception. + """ + # Immediately return if __init__ was called after _Enum.__init__(). + # It means that casting operator version of the class constructor + # is being used. + if getattr(type(self), '_DefinitionClass__initialized'): + return + object.__setattr__(self, 'name', name) + object.__setattr__(self, 'number', number) + + def __setattr__(self, name, value): + raise TypeError('May not change enum values') + + def __str__(self): + return self.name + + def __int__(self): + return self.number + + def __repr__(self): + return '%s(%s, %d)' % (type(self).__name__, self.name, self.number) + + def __reduce__(self): + """Enable pickling. + + Returns: + A 2-tuple containing the class and __new__ args to be used for restoring + a pickled instance. + """ + return self.__class__, (self.number,) + + def __cmp__(self, other): + """Order is by number.""" + if isinstance(other, type(self)): + return cmp(self.number, other.number) + return NotImplemented + + def __lt__(self, other): + """Order is by number.""" + if isinstance(other, type(self)): + return self.number < other.number + return NotImplemented + + def __le__(self, other): + """Order is by number.""" + if isinstance(other, type(self)): + return self.number <= other.number + return NotImplemented + + def __eq__(self, other): + """Order is by number.""" + if isinstance(other, type(self)): + return self.number == other.number + return NotImplemented + + def __ne__(self, other): + """Order is by number.""" + if isinstance(other, type(self)): + return self.number != other.number + return NotImplemented + + def __ge__(self, other): + """Order is by number.""" + if isinstance(other, type(self)): + return self.number >= other.number + return NotImplemented + + def __gt__(self, other): + """Order is by number.""" + if isinstance(other, type(self)): + return self.number > other.number + return NotImplemented + + def __hash__(self): + """Hash by number.""" + return hash(self.number) + + @classmethod + def to_dict(cls): + """Make dictionary version of enumerated class. + + Dictionary created this way can be used with def_num. + + Returns: + A dict (name) -> number + """ + return dict((item.name, item.number) for item in iter(cls)) + + @staticmethod + def def_enum(dct, name): + """Define enum class from dictionary. + + Args: + dct: Dictionary of enumerated values for type. + name: Name of enum. + """ + return type(name, (Enum,), dct) + + +# TODO(rafek): Determine to what degree this enumeration should be compatible +# with FieldDescriptor.Type in: +# +# http://code.google.com/p/protobuf/source/browse/trunk/src/google/protobuf/descriptor.proto +class Variant(Enum): + """Wire format variant. + + Used by the 'protobuf' wire format to determine how to transmit + a single piece of data. May be used by other formats. + + See: http://code.google.com/apis/protocolbuffers/docs/encoding.html + + Values: + DOUBLE: 64-bit floating point number. + FLOAT: 32-bit floating point number. + INT64: 64-bit signed integer. + UINT64: 64-bit unsigned integer. + INT32: 32-bit signed integer. + BOOL: Boolean value (True or False). + STRING: String of UTF-8 encoded text. + MESSAGE: Embedded message as byte string. + BYTES: String of 8-bit bytes. + UINT32: 32-bit unsigned integer. + ENUM: Enum value as integer. + SINT32: 32-bit signed integer. Uses "zig-zag" encoding. + SINT64: 64-bit signed integer. Uses "zig-zag" encoding. + """ + DOUBLE = 1 + FLOAT = 2 + INT64 = 3 + UINT64 = 4 + INT32 = 5 + BOOL = 8 + STRING = 9 + MESSAGE = 11 + BYTES = 12 + UINT32 = 13 + ENUM = 14 + SINT32 = 17 + SINT64 = 18 + + +class _MessageClass(_DefinitionClass): + """Meta-class used for defining the Message base class. + + For more details about Message classes, see the Message class docstring. + Information contained there may help understanding this class. + + Meta-class enables very specific behavior for any defined Message + class. All attributes defined on an Message sub-class must be field + instances, Enum class definitions or other Message class definitions. Each + field attribute defined on an Message sub-class is added to the set of + field definitions and the attribute is translated in to a slot. It also + ensures that only one level of Message class hierarchy is possible. In other + words it is not possible to declare sub-classes of sub-classes of + Message. + + This class also defines some functions in order to restrict the + behavior of the Message class and its sub-classes. It is not possible + to change the behavior of the Message class in later classes since + any new classes may be defined with only field, Enums and Messages, and + no methods. + """ + + def __new__(cls, name, bases, dct): + """Create new Message class instance. + + The __new__ method of the _MessageClass type is overridden so as to + allow the translation of Field instances to slots. + """ + by_number = {} + by_name = {} + + variant_map = {} + + if bases != (object,): + # Can only define one level of sub-classes below Message. + if bases != (Message,): + raise MessageDefinitionError( + 'Message types may only inherit from Message') + + enums = [] + messages = [] + # Must not use iteritems because this loop will change the state of dct. + for key, field in dct.items(): + + if key in _RESERVED_ATTRIBUTE_NAMES: + continue + + if isinstance(field, type) and issubclass(field, Enum): + enums.append(key) + continue + + if (isinstance(field, type) and + issubclass(field, Message) and + field is not Message): + messages.append(key) + continue + + # Reject anything that is not a field. + if type(field) is Field or not isinstance(field, Field): + raise MessageDefinitionError( + 'May only use fields in message definitions. Found: %s = %s' % + (key, field)) + + if field.number in by_number: + raise DuplicateNumberError( + 'Field with number %d declared more than once in %s' % + (field.number, name)) + + field.name = key + + # Place in name and number maps. + by_name[key] = field + by_number[field.number] = field + + # Add enums if any exist. + if enums: + dct['__enums__'] = sorted(enums) + + # Add messages if any exist. + if messages: + dct['__messages__'] = sorted(messages) + + dct['_Message__by_number'] = by_number + dct['_Message__by_name'] = by_name + + return _DefinitionClass.__new__(cls, name, bases, dct) + + def __init__(cls, name, bases, dct): + """Initializer required to assign references to new class.""" + if bases != (object,): + for value in dct.values(): + if isinstance(value, _DefinitionClass) and not value is Message: + value._message_definition = weakref.ref(cls) + + for field in cls.all_fields(): + field._message_definition = weakref.ref(cls) + + _DefinitionClass.__init__(cls, name, bases, dct) + + +class Message(six.with_metaclass(_MessageClass, object)): + """Base class for user defined message objects. + + Used to define messages for efficient transmission across network or + process space. Messages are defined using the field classes (IntegerField, + FloatField, EnumField, etc.). + + Messages are more restricted than normal classes in that they may only + contain field attributes and other Message and Enum definitions. These + restrictions are in place because the structure of the Message class is + intentended to itself be transmitted across network or process space and + used directly by clients or even other servers. As such methods and + non-field attributes could not be transmitted with the structural information + causing discrepancies between different languages and implementations. + + Initialization and validation: + + A Message object is considered to be initialized if it has all required + fields and any nested messages are also initialized. + + Calling 'check_initialized' will raise a ValidationException if it is not + initialized; 'is_initialized' returns a boolean value indicating if it is + valid. + + Validation automatically occurs when Message objects are created + and populated. Validation that a given value will be compatible with + a field that it is assigned to can be done through the Field instances + validate() method. The validate method used on a message will check that + all values of a message and its sub-messages are valid. Assingning an + invalid value to a field will raise a ValidationException. + + Example: + + # Trade type. + class TradeType(Enum): + BUY = 1 + SELL = 2 + SHORT = 3 + CALL = 4 + + class Lot(Message): + price = IntegerField(1, required=True) + quantity = IntegerField(2, required=True) + + class Order(Message): + symbol = StringField(1, required=True) + total_quantity = IntegerField(2, required=True) + trade_type = EnumField(TradeType, 3, required=True) + lots = MessageField(Lot, 4, repeated=True) + limit = IntegerField(5) + + order = Order(symbol='GOOG', + total_quantity=10, + trade_type=TradeType.BUY) + + lot1 = Lot(price=304, + quantity=7) + + lot2 = Lot(price = 305, + quantity=3) + + order.lots = [lot1, lot2] + + # Now object is initialized! + order.check_initialized() + """ + + def __init__(self, **kwargs): + """Initialize internal messages state. + + Args: + A message can be initialized via the constructor by passing in keyword + arguments corresponding to fields. For example: + + class Date(Message): + day = IntegerField(1) + month = IntegerField(2) + year = IntegerField(3) + + Invoking: + + date = Date(day=6, month=6, year=1911) + + is the same as doing: + + date = Date() + date.day = 6 + date.month = 6 + date.year = 1911 + """ + # Tag being an essential implementation detail must be private. + self.__tags = {} + self.__unrecognized_fields = {} + + assigned = set() + for name, value in kwargs.items(): + setattr(self, name, value) + assigned.add(name) + + # initialize repeated fields. + for field in self.all_fields(): + if field.repeated and field.name not in assigned: + setattr(self, field.name, []) + + + def check_initialized(self): + """Check class for initialization status. + + Check that all required fields are initialized + + Raises: + ValidationError: If message is not initialized. + """ + for name, field in self.__by_name.items(): + value = getattr(self, name) + if value is None: + if field.required: + raise ValidationError("Message %s is missing required field %s" % + (type(self).__name__, name)) + else: + try: + if (isinstance(field, MessageField) and + issubclass(field.message_type, Message)): + if field.repeated: + for item in value: + item_message_value = field.value_to_message(item) + item_message_value.check_initialized() + else: + message_value = field.value_to_message(value) + message_value.check_initialized() + except ValidationError as err: + if not hasattr(err, 'message_name'): + err.message_name = type(self).__name__ + raise + + def is_initialized(self): + """Get initialization status. + + Returns: + True if message is valid, else False. + """ + try: + self.check_initialized() + except ValidationError: + return False + else: + return True + + @classmethod + def all_fields(cls): + """Get all field definition objects. + + Ordering is arbitrary. + + Returns: + Iterator over all values in arbitrary order. + """ + return cls.__by_name.values() + + @classmethod + def field_by_name(cls, name): + """Get field by name. + + Returns: + Field object associated with name. + + Raises: + KeyError if no field found by that name. + """ + return cls.__by_name[name] + + @classmethod + def field_by_number(cls, number): + """Get field by number. + + Returns: + Field object associated with number. + + Raises: + KeyError if no field found by that number. + """ + return cls.__by_number[number] + + def get_assigned_value(self, name): + """Get the assigned value of an attribute. + + Get the underlying value of an attribute. If value has not been set, will + not return the default for the field. + + Args: + name: Name of attribute to get. + + Returns: + Value of attribute, None if it has not been set. + """ + message_type = type(self) + try: + field = message_type.field_by_name(name) + except KeyError: + raise AttributeError('Message %s has no field %s' % ( + message_type.__name__, name)) + return self.__tags.get(field.number) + + def reset(self, name): + """Reset assigned value for field. + + Resetting a field will return it to its default value or None. + + Args: + name: Name of field to reset. + """ + message_type = type(self) + try: + field = message_type.field_by_name(name) + except KeyError: + if name not in message_type.__by_name: + raise AttributeError('Message %s has no field %s' % ( + message_type.__name__, name)) + if field.repeated: + self.__tags[field.number] = FieldList(field, []) + else: + self.__tags.pop(field.number, None) + + def all_unrecognized_fields(self): + """Get the names of all unrecognized fields in this message.""" + return list(self.__unrecognized_fields.keys()) + + def get_unrecognized_field_info(self, key, value_default=None, + variant_default=None): + """Get the value and variant of an unknown field in this message. + + Args: + key: The name or number of the field to retrieve. + value_default: Value to be returned if the key isn't found. + variant_default: Value to be returned as variant if the key isn't + found. + + Returns: + (value, variant), where value and variant are whatever was passed + to set_unrecognized_field. + """ + value, variant = self.__unrecognized_fields.get(key, (value_default, + variant_default)) + return value, variant + + def set_unrecognized_field(self, key, value, variant): + """Set an unrecognized field, used when decoding a message. + + Args: + key: The name or number used to refer to this unknown value. + value: The value of the field. + variant: Type information needed to interpret the value or re-encode it. + + Raises: + TypeError: If the variant is not an instance of messages.Variant. + """ + if not isinstance(variant, Variant): + raise TypeError('Variant type %s is not valid.' % variant) + self.__unrecognized_fields[key] = value, variant + + def __setattr__(self, name, value): + """Change set behavior for messages. + + Messages may only be assigned values that are fields. + + Does not try to validate field when set. + + Args: + name: Name of field to assign to. + value: Value to assign to field. + + Raises: + AttributeError when trying to assign value that is not a field. + """ + if name in self.__by_name or name.startswith('_Message__'): + object.__setattr__(self, name, value) + else: + raise AttributeError("May not assign arbitrary value %s " + "to message %s" % (name, type(self).__name__)) + + def __repr__(self): + """Make string representation of message. + + Example: + + class MyMessage(messages.Message): + integer_value = messages.IntegerField(1) + string_value = messages.StringField(2) + + my_message = MyMessage() + my_message.integer_value = 42 + my_message.string_value = u'A string' + + print my_message + >>> + + Returns: + String representation of message, including the values + of all fields and repr of all sub-messages. + """ + body = ['<', type(self).__name__] + for field in sorted(self.all_fields(), + key=lambda f: f.number): + attribute = field.name + value = self.get_assigned_value(field.name) + if value is not None: + body.append('\n %s: %s' % (attribute, repr(value))) + body.append('>') + return ''.join(body) + + def __eq__(self, other): + """Equality operator. + + Does field by field comparison with other message. For + equality, must be same type and values of all fields must be + equal. + + Messages not required to be initialized for comparison. + + Does not attempt to determine equality for values that have + default values that are not set. In other words: + + class HasDefault(Message): + + attr1 = StringField(1, default='default value') + + message1 = HasDefault() + message2 = HasDefault() + message2.attr1 = 'default value' + + message1 != message2 + + Does not compare unknown values. + + Args: + other: Other message to compare with. + """ + # TODO(rafek): Implement "equivalent" which does comparisons + # taking default values in to consideration. + if self is other: + return True + + if type(self) is not type(other): + return False + + return self.__tags == other.__tags + + def __ne__(self, other): + """Not equals operator. + + Does field by field comparison with other message. For + non-equality, must be different type or any value of a field must be + non-equal to the same field in the other instance. + + Messages not required to be initialized for comparison. + + Args: + other: Other message to compare with. + """ + return not self.__eq__(other) + + +class FieldList(list): + """List implementation that validates field values. + + This list implementation overrides all methods that add values in to a list + in order to validate those new elements. Attempting to add or set list + values that are not of the correct type will raise ValidationError. + """ + + def __init__(self, field_instance, sequence): + """Constructor. + + Args: + field_instance: Instance of field that validates the list. + sequence: List or tuple to construct list from. + """ + if not field_instance.repeated: + raise FieldDefinitionError('FieldList may only accept repeated fields') + self.__field = field_instance + self.__field.validate(sequence) + list.__init__(self, sequence) + + def __getstate__(self): + """Enable pickling. + + The assigned field instance can't be pickled if it belongs to a Message + definition (message_definition uses a weakref), so the Message class and + field number are returned in that case. + + Returns: + A 3-tuple containing: + - The field instance, or None if it belongs to a Message class. + - The Message class that the field instance belongs to, or None. + - The field instance number of the Message class it belongs to, or None. + """ + message_class = self.__field.message_definition() + if message_class is None: + return self.__field, None, None + else: + return None, message_class, self.__field.number + + def __setstate__(self, state): + """Enable unpickling. + + Args: + state: A 3-tuple containing: + - The field instance, or None if it belongs to a Message class. + - The Message class that the field instance belongs to, or None. + - The field instance number of the Message class it belongs to, or None. + """ + field_instance, message_class, number = state + if field_instance is None: + self.__field = message_class.field_by_number(number) + else: + self.__field = field_instance + + @property + def field(self): + """Field that validates list.""" + return self.__field + + def __setslice__(self, i, j, sequence): + """Validate slice assignment to list.""" + self.__field.validate(sequence) + list.__setslice__(self, i, j, sequence) + + def __setitem__(self, index, value): + """Validate item assignment to list.""" + if isinstance(index, slice): + self.__field.validate(value) + else: + self.__field.validate_element(value) + list.__setitem__(self, index, value) + + def append(self, value): + """Validate item appending to list.""" + self.__field.validate_element(value) + return list.append(self, value) + + def extend(self, sequence): + """Validate extension of list.""" + self.__field.validate(sequence) + return list.extend(self, sequence) + + def insert(self, index, value): + """Validate item insertion to list.""" + self.__field.validate_element(value) + return list.insert(self, index, value) + + +class _FieldMeta(type): + + def __init__(cls, name, bases, dct): + getattr(cls, '_Field__variant_to_type').update( + (variant, cls) for variant in dct.get('VARIANTS', [])) + type.__init__(cls, name, bases, dct) + + +# TODO(rafek): Prevent additional field subclasses. +class Field(six.with_metaclass(_FieldMeta, object)): + + __initialized = False + __variant_to_type = {} + + @util.positional(2) + def __init__(self, + number, + required=False, + repeated=False, + variant=None, + default=None): + """Constructor. + + The required and repeated parameters are mutually exclusive. Setting both + to True will raise a FieldDefinitionError. + + Sub-class Attributes: + Each sub-class of Field must define the following: + VARIANTS: Set of variant types accepted by that field. + DEFAULT_VARIANT: Default variant type if not specified in constructor. + + Args: + number: Number of field. Must be unique per message class. + required: Whether or not field is required. Mutually exclusive with + 'repeated'. + repeated: Whether or not field is repeated. Mutually exclusive with + 'required'. + variant: Wire-format variant hint. + default: Default value for field if not found in stream. + + Raises: + InvalidVariantError when invalid variant for field is provided. + InvalidDefaultError when invalid default for field is provided. + FieldDefinitionError when invalid number provided or mutually exclusive + fields are used. + InvalidNumberError when the field number is out of range or reserved. + """ + if not isinstance(number, int) or not 1 <= number <= MAX_FIELD_NUMBER: + raise InvalidNumberError('Invalid number for field: %s\n' + 'Number must be 1 or greater and %d or less' % + (number, MAX_FIELD_NUMBER)) + + if FIRST_RESERVED_FIELD_NUMBER <= number <= LAST_RESERVED_FIELD_NUMBER: + raise InvalidNumberError('Tag number %d is a reserved number.\n' + 'Numbers %d to %d are reserved' % + (number, FIRST_RESERVED_FIELD_NUMBER, + LAST_RESERVED_FIELD_NUMBER)) + + if repeated and required: + raise FieldDefinitionError('Cannot set both repeated and required') + + if variant is None: + variant = self.DEFAULT_VARIANT + + if repeated and default is not None: + raise FieldDefinitionError('Repeated fields may not have defaults') + + if variant not in self.VARIANTS: + raise InvalidVariantError( + 'Invalid variant: %s\nValid variants for %s are %r' % + (variant, type(self).__name__, sorted(self.VARIANTS))) + + self.number = number + self.required = required + self.repeated = repeated + self.variant = variant + + if default is not None: + try: + self.validate_default(default) + except ValidationError as err: + try: + name = self.name + except AttributeError: + # For when raising error before name initialization. + raise InvalidDefaultError('Invalid default value for %s: %r: %s' % + (self.__class__.__name__, default, err)) + else: + raise InvalidDefaultError('Invalid default value for field %s: ' + '%r: %s' % (name, default, err)) + + self.__default = default + self.__initialized = True + + def __setattr__(self, name, value): + """Setter overidden to prevent assignment to fields after creation. + + Args: + name: Name of attribute to set. + value: Value to assign. + """ + # Special case post-init names. They need to be set after constructor. + if name in _POST_INIT_FIELD_ATTRIBUTE_NAMES: + object.__setattr__(self, name, value) + return + + # All other attributes must be set before __initialized. + if not self.__initialized: + # Not initialized yet, allow assignment. + object.__setattr__(self, name, value) + else: + raise AttributeError('Field objects are read-only') + + def __set__(self, message_instance, value): + """Set value on message. + + Args: + message_instance: Message instance to set value on. + value: Value to set on message. + """ + # Reaches in to message instance directly to assign to private tags. + if value is None: + if self.repeated: + raise ValidationError( + 'May not assign None to repeated field %s' % self.name) + else: + message_instance._Message__tags.pop(self.number, None) + else: + if self.repeated: + value = FieldList(self, value) + else: + value = self.validate(value) + message_instance._Message__tags[self.number] = value + + def __get__(self, message_instance, message_class): + if message_instance is None: + return self + + result = message_instance._Message__tags.get(self.number) + if result is None: + return self.default + else: + return result + + def validate_element(self, value): + """Validate single element of field. + + This is different from validate in that it is used on individual + values of repeated fields. + + Args: + value: Value to validate. + + Returns: + The value casted in the expectes type. + + Raises: + ValidationError if value is not expected type. + """ + if not isinstance(value, self.type): + # Authorize in values as float + if isinstance(value, six.integer_types) and self.type == float: + return float(value) + + if value is None: + if self.required: + raise ValidationError('Required field is missing') + else: + try: + name = self.name + except AttributeError: + raise ValidationError('Expected type %s for %s, ' + 'found %s (type %s)' % + (self.type, self.__class__.__name__, + value, type(value))) + else: + raise ValidationError('Expected type %s for field %s, ' + 'found %s (type %s)' % + (self.type, name, value, type(value))) + return value + + def __validate(self, value, validate_element): + """Internal validation function. + + Validate an internal value using a function to validate individual elements. + + Args: + value: Value to validate. + validate_element: Function to use to validate individual elements. + + Raises: + ValidationError if value is not expected type. + """ + if not self.repeated: + return validate_element(value) + else: + # Must be a list or tuple, may not be a string. + if isinstance(value, (list, tuple)): + result = [] + for element in value: + if element is None: + try: + name = self.name + except AttributeError: + raise ValidationError('Repeated values for %s ' + 'may not be None' % self.__class__.__name__) + else: + raise ValidationError('Repeated values for field %s ' + 'may not be None' % name) + result.append(validate_element(element)) + return result + elif value is not None: + try: + name = self.name + except AttributeError: + raise ValidationError('%s is repeated. Found: %s' % ( + self.__class__.__name__, value)) + else: + raise ValidationError('Field %s is repeated. Found: %s' % (name, + value)) + return value + + def validate(self, value): + """Validate value assigned to field. + + Args: + value: Value to validate. + + Returns: + the value eventually casted in the correct type. + + Raises: + ValidationError if value is not expected type. + """ + return self.__validate(value, self.validate_element) + + def validate_default_element(self, value): + """Validate value as assigned to field default field. + + Some fields may allow for delayed resolution of default types necessary + in the case of circular definition references. In this case, the default + value might be a place holder that is resolved when needed after all the + message classes are defined. + + Args: + value: Default value to validate. + + Returns: + the value eventually casted in the correct type. + + Raises: + ValidationError if value is not expected type. + """ + return self.validate_element(value) + + def validate_default(self, value): + """Validate default value assigned to field. + + Args: + value: Value to validate. + + Returns: + the value eventually casted in the correct type. + + Raises: + ValidationError if value is not expected type. + """ + return self.__validate(value, self.validate_default_element) + + def message_definition(self): + """Get Message definition that contains this Field definition. + + Returns: + Containing Message definition for Field. Will return None if for + some reason Field is defined outside of a Message class. + """ + try: + return self._message_definition() + except AttributeError: + return None + + @property + def default(self): + """Get default value for field.""" + return self.__default + + @classmethod + def lookup_field_type_by_variant(cls, variant): + return cls.__variant_to_type[variant] + + +class IntegerField(Field): + """Field definition for integer values.""" + + VARIANTS = frozenset([Variant.INT32, + Variant.INT64, + Variant.UINT32, + Variant.UINT64, + Variant.SINT32, + Variant.SINT64, + ]) + + DEFAULT_VARIANT = Variant.INT64 + + type = six.integer_types + + +class FloatField(Field): + """Field definition for float values.""" + + VARIANTS = frozenset([Variant.FLOAT, + Variant.DOUBLE, + ]) + + DEFAULT_VARIANT = Variant.DOUBLE + + type = float + + +class BooleanField(Field): + """Field definition for boolean values.""" + + VARIANTS = frozenset([Variant.BOOL]) + + DEFAULT_VARIANT = Variant.BOOL + + type = bool + + +class BytesField(Field): + """Field definition for byte string values.""" + + VARIANTS = frozenset([Variant.BYTES]) + + DEFAULT_VARIANT = Variant.BYTES + + type = bytes + + +class StringField(Field): + """Field definition for unicode string values.""" + + VARIANTS = frozenset([Variant.STRING]) + + DEFAULT_VARIANT = Variant.STRING + + type = six.text_type + + def validate_element(self, value): + """Validate StringField allowing for str and unicode. + + Raises: + ValidationError if a str value is not 7-bit ascii. + """ + # If value is str is it considered valid. Satisfies "required=True". + if isinstance(value, bytes): + try: + six.text_type(value, 'ascii') + except UnicodeDecodeError as err: + try: + name = self.name + except AttributeError: + validation_error = ValidationError( + 'Field encountered non-ASCII string %r: %s' % (value, + err)) + else: + validation_error = ValidationError( + 'Field %s encountered non-ASCII string %r: %s' % (self.name, + value, + err)) + validation_error.field_name = self.name + raise validation_error + else: + return super(StringField, self).validate_element(value) + + +class MessageField(Field): + """Field definition for sub-message values. + + Message fields contain instance of other messages. Instances stored + on messages stored on message fields are considered to be owned by + the containing message instance and should not be shared between + owning instances. + + Message fields must be defined to reference a single type of message. + Normally message field are defined by passing the referenced message + class in to the constructor. + + It is possible to define a message field for a type that does not yet + exist by passing the name of the message in to the constructor instead + of a message class. Resolution of the actual type of the message is + deferred until it is needed, for example, during message verification. + Names provided to the constructor must refer to a class within the same + python module as the class that is using it. Names refer to messages + relative to the containing messages scope. For example, the two fields + of OuterMessage refer to the same message type: + + class Outer(Message): + + inner_relative = MessageField('Inner', 1) + inner_absolute = MessageField('Outer.Inner', 2) + + class Inner(Message): + ... + + When resolving an actual type, MessageField will traverse the entire + scope of nested messages to match a message name. This makes it easy + for siblings to reference siblings: + + class Outer(Message): + + class Inner(Message): + + sibling = MessageField('Sibling', 1) + + class Sibling(Message): + ... + """ + + VARIANTS = frozenset([Variant.MESSAGE]) + + DEFAULT_VARIANT = Variant.MESSAGE + + @util.positional(3) + def __init__(self, + message_type, + number, + required=False, + repeated=False, + variant=None): + """Constructor. + + Args: + message_type: Message type for field. Must be subclass of Message. + number: Number of field. Must be unique per message class. + required: Whether or not field is required. Mutually exclusive to + 'repeated'. + repeated: Whether or not field is repeated. Mutually exclusive to + 'required'. + variant: Wire-format variant hint. + + Raises: + FieldDefinitionError when invalid message_type is provided. + """ + valid_type = (isinstance(message_type, six.string_types) or + (message_type is not Message and + isinstance(message_type, type) and + issubclass(message_type, Message))) + + if not valid_type: + raise FieldDefinitionError('Invalid message class: %s' % message_type) + + if isinstance(message_type, six.string_types): + self.__type_name = message_type + self.__type = None + else: + self.__type = message_type + + super(MessageField, self).__init__(number, + required=required, + repeated=repeated, + variant=variant) + + def __set__(self, message_instance, value): + """Set value on message. + + Args: + message_instance: Message instance to set value on. + value: Value to set on message. + """ + message_type = self.type + if isinstance(message_type, type) and issubclass(message_type, Message): + if self.repeated: + if value and isinstance(value, (list, tuple)): + value = [(message_type(**v) if isinstance(v, dict) else v) + for v in value] + elif isinstance(value, dict): + value = message_type(**value) + super(MessageField, self).__set__(message_instance, value) + + @property + def type(self): + """Message type used for field.""" + if self.__type is None: + message_type = find_definition(self.__type_name, self.message_definition()) + if not (message_type is not Message and + isinstance(message_type, type) and + issubclass(message_type, Message)): + raise FieldDefinitionError('Invalid message class: %s' % message_type) + self.__type = message_type + return self.__type + + @property + def message_type(self): + """Underlying message type used for serialization. + + Will always be a sub-class of Message. This is different from type + which represents the python value that message_type is mapped to for + use by the user. + """ + return self.type + + def value_from_message(self, message): + """Convert a message to a value instance. + + Used by deserializers to convert from underlying messages to + value of expected user type. + + Args: + message: A message instance of type self.message_type. + + Returns: + Value of self.message_type. + """ + if not isinstance(message, self.message_type): + raise DecodeError('Expected type %s, got %s: %r' % + (self.message_type.__name__, + type(message).__name__, + message)) + return message + + def value_to_message(self, value): + """Convert a value instance to a message. + + Used by serializers to convert Python user types to underlying + messages for transmission. + + Args: + value: A value of type self.type. + + Returns: + An instance of type self.message_type. + """ + if not isinstance(value, self.type): + raise EncodeError('Expected type %s, got %s: %r' % + (self.type.__name__, + type(value).__name__, + value)) + return value + + +class EnumField(Field): + """Field definition for enum values. + + Enum fields may have default values that are delayed until the associated enum + type is resolved. This is necessary to support certain circular references. + + For example: + + class Message1(Message): + + class Color(Enum): + + RED = 1 + GREEN = 2 + BLUE = 3 + + # This field default value will be validated when default is accessed. + animal = EnumField('Message2.Animal', 1, default='HORSE') + + class Message2(Message): + + class Animal(Enum): + + DOG = 1 + CAT = 2 + HORSE = 3 + + # This fields default value will be validated right away since Color is + # already fully resolved. + color = EnumField(Message1.Color, 1, default='RED') + """ + + VARIANTS = frozenset([Variant.ENUM]) + + DEFAULT_VARIANT = Variant.ENUM + + def __init__(self, enum_type, number, **kwargs): + """Constructor. + + Args: + enum_type: Enum type for field. Must be subclass of Enum. + number: Number of field. Must be unique per message class. + required: Whether or not field is required. Mutually exclusive to + 'repeated'. + repeated: Whether or not field is repeated. Mutually exclusive to + 'required'. + variant: Wire-format variant hint. + default: Default value for field if not found in stream. + + Raises: + FieldDefinitionError when invalid enum_type is provided. + """ + valid_type = (isinstance(enum_type, six.string_types) or + (enum_type is not Enum and + isinstance(enum_type, type) and + issubclass(enum_type, Enum))) + + if not valid_type: + raise FieldDefinitionError('Invalid enum type: %s' % enum_type) + + if isinstance(enum_type, six.string_types): + self.__type_name = enum_type + self.__type = None + else: + self.__type = enum_type + + super(EnumField, self).__init__(number, **kwargs) + + def validate_default_element(self, value): + """Validate default element of Enum field. + + Enum fields allow for delayed resolution of default values when the type + of the field has not been resolved. The default value of a field may be + a string or an integer. If the Enum type of the field has been resolved, + the default value is validated against that type. + + Args: + value: Value to validate. + + Raises: + ValidationError if value is not expected message type. + """ + if isinstance(value, (six.string_types, six.integer_types)): + # Validation of the value does not happen for delayed resolution + # enumerated types. Ignore if type is not yet resolved. + if self.__type: + self.__type(value) + return + + return super(EnumField, self).validate_default_element(value) + + @property + def type(self): + """Enum type used for field.""" + if self.__type is None: + found_type = find_definition(self.__type_name, self.message_definition()) + if not (found_type is not Enum and + isinstance(found_type, type) and + issubclass(found_type, Enum)): + raise FieldDefinitionError('Invalid enum type: %s' % found_type) + + self.__type = found_type + return self.__type + + @property + def default(self): + """Default for enum field. + + Will cause resolution of Enum type and unresolved default value. + """ + try: + return self.__resolved_default + except AttributeError: + resolved_default = super(EnumField, self).default + if isinstance(resolved_default, (six.string_types, six.integer_types)): + resolved_default = self.type(resolved_default) + self.__resolved_default = resolved_default + return self.__resolved_default + + +@util.positional(2) +def find_definition(name, relative_to=None, importer=__import__): + """Find definition by name in module-space. + + The find algorthm will look for definitions by name relative to a message + definition or by fully qualfied name. If no definition is found relative + to the relative_to parameter it will do the same search against the container + of relative_to. If relative_to is a nested Message, it will search its + message_definition(). If that message has no message_definition() it will + search its module. If relative_to is a module, it will attempt to look for + the containing module and search relative to it. If the module is a top-level + module, it will look for the a message using a fully qualified name. If + no message is found then, the search fails and DefinitionNotFoundError is + raised. + + For example, when looking for any definition 'foo.bar.ADefinition' relative to + an actual message definition abc.xyz.SomeMessage: + + find_definition('foo.bar.ADefinition', SomeMessage) + + It is like looking for the following fully qualified names: + + abc.xyz.SomeMessage. foo.bar.ADefinition + abc.xyz. foo.bar.ADefinition + abc. foo.bar.ADefinition + foo.bar.ADefinition + + When resolving the name relative to Message definitions and modules, the + algorithm searches any Messages or sub-modules found in its path. + Non-Message values are not searched. + + A name that begins with '.' is considered to be a fully qualified name. The + name is always searched for from the topmost package. For example, assume + two message types: + + abc.xyz.SomeMessage + xyz.SomeMessage + + Searching for '.xyz.SomeMessage' relative to 'abc' will resolve to + 'xyz.SomeMessage' and not 'abc.xyz.SomeMessage'. For this kind of name, + the relative_to parameter is effectively ignored and always set to None. + + For more information about package name resolution, please see: + + http://code.google.com/apis/protocolbuffers/docs/proto.html#packages + + Args: + name: Name of definition to find. May be fully qualified or relative name. + relative_to: Search for definition relative to message definition or module. + None will cause a fully qualified name search. + importer: Import function to use for resolving modules. + + Returns: + Enum or Message class definition associated with name. + + Raises: + DefinitionNotFoundError if no definition is found in any search path. + """ + # Check parameters. + if not (relative_to is None or + isinstance(relative_to, types.ModuleType) or + isinstance(relative_to, type) and issubclass(relative_to, Message)): + raise TypeError('relative_to must be None, Message definition or module. ' + 'Found: %s' % relative_to) + + name_path = name.split('.') + + # Handle absolute path reference. + if not name_path[0]: + relative_to = None + name_path = name_path[1:] + + def search_path(): + """Performs a single iteration searching the path from relative_to. + + This is the function that searches up the path from a relative object. + + fully.qualified.object . relative.or.nested.Definition + ----------------------------> + ^ + | + this part of search --+ + + Returns: + Message or Enum at the end of name_path, else None. + """ + next = relative_to + for node in name_path: + # Look for attribute first. + attribute = getattr(next, node, None) + + if attribute is not None: + next = attribute + else: + # If module, look for sub-module. + if next is None or isinstance(next, types.ModuleType): + if next is None: + module_name = node + else: + module_name = '%s.%s' % (next.__name__, node) + + try: + fromitem = module_name.split('.')[-1] + next = importer(module_name, '', '', [str(fromitem)]) + except ImportError: + return None + else: + return None + + if (not isinstance(next, types.ModuleType) and + not (isinstance(next, type) and + issubclass(next, (Message, Enum)))): + return None + + return next + + while True: + found = search_path() + if isinstance(found, type) and issubclass(found, (Enum, Message)): + return found + else: + # Find next relative_to to search against. + # + # fully.qualified.object . relative.or.nested.Definition + # <--------------------- + # ^ + # | + # does this part of search + if relative_to is None: + # Fully qualified search was done. Nothing found. Fail. + raise DefinitionNotFoundError('Could not find definition for %s' + % (name,)) + else: + if isinstance(relative_to, types.ModuleType): + # Find parent module. + module_path = relative_to.__name__.split('.')[:-1] + if not module_path: + relative_to = None + else: + # Should not raise ImportError. If it does... weird and + # unexepected. Propagate. + relative_to = importer( + '.'.join(module_path), '', '', [module_path[-1]]) + elif (isinstance(relative_to, type) and + issubclass(relative_to, Message)): + parent = relative_to.message_definition() + if parent is None: + last_module_name = relative_to.__module__.split('.')[-1] + relative_to = importer( + relative_to.__module__, '', '', [last_module_name]) + else: + relative_to = parent diff --git a/endpoints/bundled/protorpc/non_sdk_imports.py b/endpoints/bundled/protorpc/non_sdk_imports.py new file mode 100644 index 0000000..5b971ec --- /dev/null +++ b/endpoints/bundled/protorpc/non_sdk_imports.py @@ -0,0 +1,21 @@ +"""Dynamically decide from where to import other non SDK Google modules. + +All other protorpc code should import other non SDK modules from +this module. If necessary, add new imports here (in both places). +""" + +__author__ = 'yey@google.com (Ye Yuan)' + +# pylint: disable=g-import-not-at-top +# pylint: disable=unused-import + +try: + from google.protobuf import descriptor + normal_environment = True +except ImportError: + normal_environment = False + +if normal_environment: + from google.protobuf import descriptor_pb2 + from google.protobuf import message + from google.protobuf import reflection diff --git a/endpoints/bundled/protorpc/protobuf.py b/endpoints/bundled/protorpc/protobuf.py new file mode 100644 index 0000000..18d0074 --- /dev/null +++ b/endpoints/bundled/protorpc/protobuf.py @@ -0,0 +1,359 @@ +#!/usr/bin/env python +# +# Copyright 2010 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Protocol buffer support for message types. + +For more details about protocol buffer encoding and decoding please see: + + http://code.google.com/apis/protocolbuffers/docs/encoding.html + +Public Exceptions: + DecodeError: Raised when a decode error occurs from incorrect protobuf format. + +Public Functions: + encode_message: Encodes a message in to a protocol buffer string. + decode_message: Decode from a protocol buffer string to a message. +""" +import six + +__author__ = 'rafek@google.com (Rafe Kaplan)' + + +import array + +from . import message_types +from . import messages +from . import util +from .google_imports import ProtocolBuffer + + +__all__ = ['ALTERNATIVE_CONTENT_TYPES', + 'CONTENT_TYPE', + 'encode_message', + 'decode_message', + ] + +CONTENT_TYPE = 'application/octet-stream' + +ALTERNATIVE_CONTENT_TYPES = ['application/x-google-protobuf'] + + +class _Encoder(ProtocolBuffer.Encoder): + """Extension of protocol buffer encoder. + + Original protocol buffer encoder does not have complete set of methods + for handling required encoding. This class adds them. + """ + + # TODO(rafek): Implement the missing encoding types. + def no_encoding(self, value): + """No encoding available for type. + + Args: + value: Value to encode. + + Raises: + NotImplementedError at all times. + """ + raise NotImplementedError() + + def encode_enum(self, value): + """Encode an enum value. + + Args: + value: Enum to encode. + """ + self.putVarInt32(value.number) + + def encode_message(self, value): + """Encode a Message in to an embedded message. + + Args: + value: Message instance to encode. + """ + self.putPrefixedString(encode_message(value)) + + + def encode_unicode_string(self, value): + """Helper to properly pb encode unicode strings to UTF-8. + + Args: + value: String value to encode. + """ + if isinstance(value, six.text_type): + value = value.encode('utf-8') + self.putPrefixedString(value) + + +class _Decoder(ProtocolBuffer.Decoder): + """Extension of protocol buffer decoder. + + Original protocol buffer decoder does not have complete set of methods + for handling required decoding. This class adds them. + """ + + # TODO(rafek): Implement the missing encoding types. + def no_decoding(self): + """No decoding available for type. + + Raises: + NotImplementedError at all times. + """ + raise NotImplementedError() + + def decode_string(self): + """Decode a unicode string. + + Returns: + Next value in stream as a unicode string. + """ + return self.getPrefixedString().decode('UTF-8') + + def decode_boolean(self): + """Decode a boolean value. + + Returns: + Next value in stream as a boolean. + """ + return bool(self.getBoolean()) + + +# Number of bits used to describe a protocol buffer bits used for the variant. +_WIRE_TYPE_BITS = 3 +_WIRE_TYPE_MASK = 7 + + +# Maps variant to underlying wire type. Many variants map to same type. +_VARIANT_TO_WIRE_TYPE = { + messages.Variant.DOUBLE: _Encoder.DOUBLE, + messages.Variant.FLOAT: _Encoder.FLOAT, + messages.Variant.INT64: _Encoder.NUMERIC, + messages.Variant.UINT64: _Encoder.NUMERIC, + messages.Variant.INT32: _Encoder.NUMERIC, + messages.Variant.BOOL: _Encoder.NUMERIC, + messages.Variant.STRING: _Encoder.STRING, + messages.Variant.MESSAGE: _Encoder.STRING, + messages.Variant.BYTES: _Encoder.STRING, + messages.Variant.UINT32: _Encoder.NUMERIC, + messages.Variant.ENUM: _Encoder.NUMERIC, + messages.Variant.SINT32: _Encoder.NUMERIC, + messages.Variant.SINT64: _Encoder.NUMERIC, +} + + +# Maps variant to encoder method. +_VARIANT_TO_ENCODER_MAP = { + messages.Variant.DOUBLE: _Encoder.putDouble, + messages.Variant.FLOAT: _Encoder.putFloat, + messages.Variant.INT64: _Encoder.putVarInt64, + messages.Variant.UINT64: _Encoder.putVarUint64, + messages.Variant.INT32: _Encoder.putVarInt32, + messages.Variant.BOOL: _Encoder.putBoolean, + messages.Variant.STRING: _Encoder.encode_unicode_string, + messages.Variant.MESSAGE: _Encoder.encode_message, + messages.Variant.BYTES: _Encoder.encode_unicode_string, + messages.Variant.UINT32: _Encoder.no_encoding, + messages.Variant.ENUM: _Encoder.encode_enum, + messages.Variant.SINT32: _Encoder.no_encoding, + messages.Variant.SINT64: _Encoder.no_encoding, +} + + +# Basic wire format decoders. Used for reading unknown values. +_WIRE_TYPE_TO_DECODER_MAP = { + _Encoder.NUMERIC: _Decoder.getVarInt64, + _Encoder.DOUBLE: _Decoder.getDouble, + _Encoder.STRING: _Decoder.getPrefixedString, + _Encoder.FLOAT: _Decoder.getFloat, +} + + +# Map wire type to variant. Used to find a variant for unknown values. +_WIRE_TYPE_TO_VARIANT_MAP = { + _Encoder.NUMERIC: messages.Variant.INT64, + _Encoder.DOUBLE: messages.Variant.DOUBLE, + _Encoder.STRING: messages.Variant.STRING, + _Encoder.FLOAT: messages.Variant.FLOAT, +} + + +# Wire type to name mapping for error messages. +_WIRE_TYPE_NAME = { + _Encoder.NUMERIC: 'NUMERIC', + _Encoder.DOUBLE: 'DOUBLE', + _Encoder.STRING: 'STRING', + _Encoder.FLOAT: 'FLOAT', +} + + +# Maps variant to decoder method. +_VARIANT_TO_DECODER_MAP = { + messages.Variant.DOUBLE: _Decoder.getDouble, + messages.Variant.FLOAT: _Decoder.getFloat, + messages.Variant.INT64: _Decoder.getVarInt64, + messages.Variant.UINT64: _Decoder.getVarUint64, + messages.Variant.INT32: _Decoder.getVarInt32, + messages.Variant.BOOL: _Decoder.decode_boolean, + messages.Variant.STRING: _Decoder.decode_string, + messages.Variant.MESSAGE: _Decoder.getPrefixedString, + messages.Variant.BYTES: _Decoder.getPrefixedString, + messages.Variant.UINT32: _Decoder.no_decoding, + messages.Variant.ENUM: _Decoder.getVarInt32, + messages.Variant.SINT32: _Decoder.no_decoding, + messages.Variant.SINT64: _Decoder.no_decoding, +} + + +def encode_message(message): + """Encode Message instance to protocol buffer. + + Args: + Message instance to encode in to protocol buffer. + + Returns: + String encoding of Message instance in protocol buffer format. + + Raises: + messages.ValidationError if message is not initialized. + """ + message.check_initialized() + encoder = _Encoder() + + # Get all fields, from the known fields we parsed and the unknown fields + # we saved. Note which ones were known, so we can process them differently. + all_fields = [(field.number, field) for field in message.all_fields()] + all_fields.extend((key, None) + for key in message.all_unrecognized_fields() + if isinstance(key, six.integer_types)) + all_fields.sort() + for field_num, field in all_fields: + if field: + # Known field. + value = message.get_assigned_value(field.name) + if value is None: + continue + variant = field.variant + repeated = field.repeated + else: + # Unrecognized field. + value, variant = message.get_unrecognized_field_info(field_num) + if not isinstance(variant, messages.Variant): + continue + repeated = isinstance(value, (list, tuple)) + + tag = ((field_num << _WIRE_TYPE_BITS) | _VARIANT_TO_WIRE_TYPE[variant]) + + # Write value to wire. + if repeated: + values = value + else: + values = [value] + for next in values: + encoder.putVarInt32(tag) + if isinstance(field, messages.MessageField): + next = field.value_to_message(next) + field_encoder = _VARIANT_TO_ENCODER_MAP[variant] + field_encoder(encoder, next) + + return encoder.buffer().tostring() + + +def decode_message(message_type, encoded_message): + """Decode protocol buffer to Message instance. + + Args: + message_type: Message type to decode data to. + encoded_message: Encoded version of message as string. + + Returns: + Decoded instance of message_type. + + Raises: + DecodeError if an error occurs during decoding, such as incompatible + wire format for a field. + messages.ValidationError if merged message is not initialized. + """ + message = message_type() + message_array = array.array('B') + message_array.fromstring(encoded_message) + try: + decoder = _Decoder(message_array, 0, len(message_array)) + + while decoder.avail() > 0: + # Decode tag and variant information. + encoded_tag = decoder.getVarInt32() + tag = encoded_tag >> _WIRE_TYPE_BITS + wire_type = encoded_tag & _WIRE_TYPE_MASK + try: + found_wire_type_decoder = _WIRE_TYPE_TO_DECODER_MAP[wire_type] + except: + raise messages.DecodeError('No such wire type %d' % wire_type) + + if tag < 1: + raise messages.DecodeError('Invalid tag value %d' % tag) + + try: + field = message.field_by_number(tag) + except KeyError: + # Unexpected tags are ok. + field = None + wire_type_decoder = found_wire_type_decoder + else: + expected_wire_type = _VARIANT_TO_WIRE_TYPE[field.variant] + if expected_wire_type != wire_type: + raise messages.DecodeError('Expected wire type %s but found %s' % ( + _WIRE_TYPE_NAME[expected_wire_type], + _WIRE_TYPE_NAME[wire_type])) + + wire_type_decoder = _VARIANT_TO_DECODER_MAP[field.variant] + + value = wire_type_decoder(decoder) + + # Save unknown fields and skip additional processing. + if not field: + # When saving this, save it under the tag number (which should + # be unique), and set the variant and value so we know how to + # interpret the value later. + variant = _WIRE_TYPE_TO_VARIANT_MAP.get(wire_type) + if variant: + message.set_unrecognized_field(tag, value, variant) + continue + + # Special case Enum and Message types. + if isinstance(field, messages.EnumField): + try: + value = field.type(value) + except TypeError: + raise messages.DecodeError('Invalid enum value %s' % value) + elif isinstance(field, messages.MessageField): + value = decode_message(field.message_type, value) + value = field.value_from_message(value) + + # Merge value in to message. + if field.repeated: + values = getattr(message, field.name) + if values is None: + setattr(message, field.name, [value]) + else: + values.append(value) + else: + setattr(message, field.name, value) + except ProtocolBuffer.ProtocolBufferDecodeError as err: + raise messages.DecodeError('Decoding error: %s' % str(err)) + + message.check_initialized() + return message diff --git a/endpoints/bundled/protorpc/protojson.py b/endpoints/bundled/protorpc/protojson.py new file mode 100644 index 0000000..8e2c94e --- /dev/null +++ b/endpoints/bundled/protorpc/protojson.py @@ -0,0 +1,363 @@ +#!/usr/bin/env python +# +# Copyright 2010 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""JSON support for message types. + +Public classes: + MessageJSONEncoder: JSON encoder for message objects. + +Public functions: + encode_message: Encodes a message in to a JSON string. + decode_message: Merge from a JSON string in to a message. +""" +import six + +__author__ = 'rafek@google.com (Rafe Kaplan)' + +import base64 +import binascii +import logging + +from . import message_types +from . import messages +from . import util + +__all__ = [ + 'ALTERNATIVE_CONTENT_TYPES', + 'CONTENT_TYPE', + 'MessageJSONEncoder', + 'encode_message', + 'decode_message', + 'ProtoJson', +] + + +def _load_json_module(): + """Try to load a valid json module. + + There are more than one json modules that might be installed. They are + mostly compatible with one another but some versions may be different. + This function attempts to load various json modules in a preferred order. + It does a basic check to guess if a loaded version of json is compatible. + + Returns: + Compatible json module. + + Raises: + ImportError if there are no json modules or the loaded json module is + not compatible with ProtoRPC. + """ + first_import_error = None + for module_name in ['json', + 'simplejson']: + try: + module = __import__(module_name, {}, {}, 'json') + if not hasattr(module, 'JSONEncoder'): + message = ('json library "%s" is not compatible with ProtoRPC' % + module_name) + logging.warning(message) + raise ImportError(message) + else: + return module + except ImportError as err: + if not first_import_error: + first_import_error = err + + logging.error('Must use valid json library (Python 2.6 json or simplejson)') + raise first_import_error +json = _load_json_module() + + +# TODO: Rename this to MessageJsonEncoder. +class MessageJSONEncoder(json.JSONEncoder): + """Message JSON encoder class. + + Extension of JSONEncoder that can build JSON from a message object. + """ + + def __init__(self, protojson_protocol=None, **kwargs): + """Constructor. + + Args: + protojson_protocol: ProtoJson instance. + """ + super(MessageJSONEncoder, self).__init__(**kwargs) + self.__protojson_protocol = protojson_protocol or ProtoJson.get_default() + + def default(self, value): + """Return dictionary instance from a message object. + + Args: + value: Value to get dictionary for. If not encodable, will + call superclasses default method. + """ + if isinstance(value, messages.Enum): + return str(value) + + if six.PY3 and isinstance(value, bytes): + return value.decode('utf8') + + if isinstance(value, messages.Message): + result = {} + for field in value.all_fields(): + item = value.get_assigned_value(field.name) + if item not in (None, [], ()): + result[field.name] = self.__protojson_protocol.encode_field( + field, item) + # Handle unrecognized fields, so they're included when a message is + # decoded then encoded. + for unknown_key in value.all_unrecognized_fields(): + unrecognized_field, _ = value.get_unrecognized_field_info(unknown_key) + result[unknown_key] = unrecognized_field + return result + else: + return super(MessageJSONEncoder, self).default(value) + + +class ProtoJson(object): + """ProtoRPC JSON implementation class. + + Implementation of JSON based protocol used for serializing and deserializing + message objects. Instances of remote.ProtocolConfig constructor or used with + remote.Protocols.add_protocol. See the remote.py module for more details. + """ + + CONTENT_TYPE = 'application/json' + ALTERNATIVE_CONTENT_TYPES = [ + 'application/x-javascript', + 'text/javascript', + 'text/x-javascript', + 'text/x-json', + 'text/json', + ] + + def encode_field(self, field, value): + """Encode a python field value to a JSON value. + + Args: + field: A ProtoRPC field instance. + value: A python value supported by field. + + Returns: + A JSON serializable value appropriate for field. + """ + if isinstance(field, messages.BytesField): + if field.repeated: + value = [base64.b64encode(byte) for byte in value] + else: + value = base64.b64encode(value) + elif isinstance(field, message_types.DateTimeField): + # DateTimeField stores its data as a RFC 3339 compliant string. + if field.repeated: + value = [i.isoformat() for i in value] + else: + value = value.isoformat() + return value + + def encode_message(self, message): + """Encode Message instance to JSON string. + + Args: + Message instance to encode in to JSON string. + + Returns: + String encoding of Message instance in protocol JSON format. + + Raises: + messages.ValidationError if message is not initialized. + """ + message.check_initialized() + + return json.dumps(message, cls=MessageJSONEncoder, protojson_protocol=self) + + def decode_message(self, message_type, encoded_message): + """Merge JSON structure to Message instance. + + Args: + message_type: Message to decode data to. + encoded_message: JSON encoded version of message. + + Returns: + Decoded instance of message_type. + + Raises: + ValueError: If encoded_message is not valid JSON. + messages.ValidationError if merged message is not initialized. + """ + dictionary = json.loads(encoded_message) if encoded_message.strip() else {} + message = self.__decode_dictionary(message_type, dictionary) + message.check_initialized() + return message + + def __find_variant(self, value): + """Find the messages.Variant type that describes this value. + + Args: + value: The value whose variant type is being determined. + + Returns: + The messages.Variant value that best describes value's type, or None if + it's a type we don't know how to handle. + """ + if isinstance(value, bool): + return messages.Variant.BOOL + elif isinstance(value, six.integer_types): + return messages.Variant.INT64 + elif isinstance(value, float): + return messages.Variant.DOUBLE + elif isinstance(value, six.string_types): + return messages.Variant.STRING + elif isinstance(value, (list, tuple)): + # Find the most specific variant that covers all elements. + variant_priority = [None, messages.Variant.INT64, messages.Variant.DOUBLE, + messages.Variant.STRING] + chosen_priority = 0 + for v in value: + variant = self.__find_variant(v) + try: + priority = variant_priority.index(variant) + except IndexError: + priority = -1 + if priority > chosen_priority: + chosen_priority = priority + return variant_priority[chosen_priority] + # Unrecognized type. + return None + + def __decode_dictionary(self, message_type, dictionary): + """Merge dictionary in to message. + + Args: + message: Message to merge dictionary in to. + dictionary: Dictionary to extract information from. Dictionary + is as parsed from JSON. Nested objects will also be dictionaries. + """ + message = message_type() + for key, value in six.iteritems(dictionary): + if value is None: + try: + message.reset(key) + except AttributeError: + pass # This is an unrecognized field, skip it. + continue + + try: + field = message.field_by_name(key) + except KeyError: + # Save unknown values. + variant = self.__find_variant(value) + if variant: + if key.isdigit(): + key = int(key) + message.set_unrecognized_field(key, value, variant) + else: + logging.warning('No variant found for unrecognized field: %s', key) + continue + + # Normalize values in to a list. + if isinstance(value, list): + if not value: + continue + else: + value = [value] + + valid_value = [] + for item in value: + valid_value.append(self.decode_field(field, item)) + + if field.repeated: + existing_value = getattr(message, field.name) + setattr(message, field.name, valid_value) + else: + setattr(message, field.name, valid_value[-1]) + return message + + def decode_field(self, field, value): + """Decode a JSON value to a python value. + + Args: + field: A ProtoRPC field instance. + value: A serialized JSON value. + + Return: + A Python value compatible with field. + """ + if isinstance(field, messages.EnumField): + try: + return field.type(value) + except TypeError: + raise messages.DecodeError('Invalid enum value "%s"' % (value or '')) + + elif isinstance(field, messages.BytesField): + try: + return base64.b64decode(value) + except (binascii.Error, TypeError) as err: + raise messages.DecodeError('Base64 decoding error: %s' % err) + + elif isinstance(field, message_types.DateTimeField): + try: + return util.decode_datetime(value) + except ValueError as err: + raise messages.DecodeError(err) + + elif (isinstance(field, messages.MessageField) and + issubclass(field.type, messages.Message)): + return self.__decode_dictionary(field.type, value) + + elif (isinstance(field, messages.FloatField) and + isinstance(value, (six.integer_types, six.string_types))): + try: + return float(value) + except: + pass + + elif (isinstance(field, messages.IntegerField) and + isinstance(value, six.string_types)): + try: + return int(value) + except: + pass + + return value + + @staticmethod + def get_default(): + """Get default instanceof ProtoJson.""" + try: + return ProtoJson.__default + except AttributeError: + ProtoJson.__default = ProtoJson() + return ProtoJson.__default + + @staticmethod + def set_default(protocol): + """Set the default instance of ProtoJson. + + Args: + protocol: A ProtoJson instance. + """ + if not isinstance(protocol, ProtoJson): + raise TypeError('Expected protocol of type ProtoJson') + ProtoJson.__default = protocol + +CONTENT_TYPE = ProtoJson.CONTENT_TYPE + +ALTERNATIVE_CONTENT_TYPES = ProtoJson.ALTERNATIVE_CONTENT_TYPES + +encode_message = ProtoJson.get_default().encode_message + +decode_message = ProtoJson.get_default().decode_message diff --git a/endpoints/bundled/protorpc/protorpc_test.proto b/endpoints/bundled/protorpc/protorpc_test.proto new file mode 100644 index 0000000..50d76e0 --- /dev/null +++ b/endpoints/bundled/protorpc/protorpc_test.proto @@ -0,0 +1,83 @@ +// Copyright 2010 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package protorpc; + +// Message used to nest inside another message. +message NestedMessage { + required string a_value = 1; +} + +// Message that contains nested messages. +message HasNestedMessage { + optional NestedMessage nested = 1; + repeated NestedMessage repeated_nested = 2; +} + +message HasDefault { + optional string a_value = 1 [default="a default"]; +} + +// Message that contains all variants as optional fields. +message OptionalMessage { + enum SimpleEnum { + VAL1 = 1; + VAL2 = 2; + } + + optional double double_value = 1; + optional float float_value = 2; + optional int64 int64_value = 3; + optional uint64 uint64_value = 4; + optional int32 int32_value = 5; + optional bool bool_value = 6; + optional string string_value = 7; + optional bytes bytes_value = 8; + optional SimpleEnum enum_value = 10; + + // TODO(rafek): Add support for these variants. + // optional uint32 uint32_value = 9; + // optional sint32 sint32_value = 11; + // optional sint64 sint64_value = 12; +} + +// Message that contains all variants as repeated fields. +message RepeatedMessage { + enum SimpleEnum { + VAL1 = 1; + VAL2 = 2; + } + + repeated double double_value = 1; + repeated float float_value = 2; + repeated int64 int64_value = 3; + repeated uint64 uint64_value = 4; + repeated int32 int32_value = 5; + repeated bool bool_value = 6; + repeated string string_value = 7; + repeated bytes bytes_value = 8; + repeated SimpleEnum enum_value = 10; + + // TODO(rafek): Add support for these variants. + // repeated uint32 uint32_value = 9; + // repeated sint32 sint32_value = 11; + // repeated sint64 sint64_value = 12; +} + +// Message that has nested message with all optional fields. +message HasOptionalNestedMessage { + optional OptionalMessage nested = 1; + repeated OptionalMessage repeated_nested = 2; +} diff --git a/endpoints/bundled/protorpc/protourlencode.py b/endpoints/bundled/protorpc/protourlencode.py new file mode 100644 index 0000000..9f6059e --- /dev/null +++ b/endpoints/bundled/protorpc/protourlencode.py @@ -0,0 +1,563 @@ +#!/usr/bin/env python +# +# Copyright 2010 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""URL encoding support for messages types. + +Protocol support for URL encoded form parameters. + +Nested Fields: + Nested fields are repesented by dot separated names. For example, consider + the following messages: + + class WebPage(Message): + + title = StringField(1) + tags = StringField(2, repeated=True) + + class WebSite(Message): + + name = StringField(1) + home = MessageField(WebPage, 2) + pages = MessageField(WebPage, 3, repeated=True) + + And consider the object: + + page = WebPage() + page.title = 'Welcome to NewSite 2010' + + site = WebSite() + site.name = 'NewSite 2010' + site.home = page + + The URL encoded representation of this constellation of objects is. + + name=NewSite+2010&home.title=Welcome+to+NewSite+2010 + + An object that exists but does not have any state can be represented with + a reference to its name alone with no value assigned to it. For example: + + page = WebSite() + page.name = 'My Empty Site' + page.home = WebPage() + + is represented as: + + name=My+Empty+Site&home= + + This represents a site with an empty uninitialized home page. + +Repeated Fields: + Repeated fields are represented by the name of and the index of each value + separated by a dash. For example, consider the following message: + + home = Page() + home.title = 'Nome' + + news = Page() + news.title = 'News' + news.tags = ['news', 'articles'] + + instance = WebSite() + instance.name = 'Super fun site' + instance.pages = [home, news, preferences] + + An instance of this message can be represented as: + + name=Super+fun+site&page-0.title=Home&pages-1.title=News&... + pages-1.tags-0=new&pages-1.tags-1=articles + +Helper classes: + + URLEncodedRequestBuilder: Used for encapsulating the logic used for building + a request message from a URL encoded RPC. +""" +import six + +__author__ = 'rafek@google.com (Rafe Kaplan)' + +import cgi +import re +import urllib + +from . import message_types +from . import messages +from . import util + +__all__ = ['CONTENT_TYPE', + 'URLEncodedRequestBuilder', + 'encode_message', + 'decode_message', + ] + +CONTENT_TYPE = 'application/x-www-form-urlencoded' + +_FIELD_NAME_REGEX = re.compile(r'^([a-zA-Z_][a-zA-Z_0-9]*)(?:-([0-9]+))?$') + + +class URLEncodedRequestBuilder(object): + """Helper that encapsulates the logic used for building URL encoded messages. + + This helper is used to map query parameters from a URL encoded RPC to a + message instance. + """ + + @util.positional(2) + def __init__(self, message, prefix=''): + """Constructor. + + Args: + message: Message instance to build from parameters. + prefix: Prefix expected at the start of valid parameters. + """ + self.__parameter_prefix = prefix + + # The empty tuple indicates the root message, which has no path. + # __messages is a full cache that makes it very easy to look up message + # instances by their paths. See make_path for details about what a path + # is. + self.__messages = {(): message} + + # This is a cache that stores paths which have been checked for + # correctness. Correctness means that an index is present for repeated + # fields on the path and absent for non-repeated fields. The cache is + # also used to check that indexes are added in the right order so that + # dicontiguous ranges of indexes are ignored. + self.__checked_indexes = set([()]) + + def make_path(self, parameter_name): + """Parse a parameter name and build a full path to a message value. + + The path of a method is a tuple of 2-tuples describing the names and + indexes within repeated fields from the root message (the message being + constructed by the builder) to an arbitrarily nested message within it. + + Each 2-tuple node of a path (name, index) is: + name: The name of the field that refers to the message instance. + index: The index within a repeated field that refers to the message + instance, None if not a repeated field. + + For example, consider: + + class VeryInner(messages.Message): + ... + + class Inner(messages.Message): + + very_inner = messages.MessageField(VeryInner, 1, repeated=True) + + class Outer(messages.Message): + + inner = messages.MessageField(Inner, 1) + + If this builder is building an instance of Outer, that instance is + referred to in the URL encoded parameters without a path. Therefore + its path is (). + + The child 'inner' is referred to by its path (('inner', None)). + + The first child of repeated field 'very_inner' on the Inner instance + is referred to by (('inner', None), ('very_inner', 0)). + + Examples: + # Correct reference to model where nation is a Message, district is + # repeated Message and county is any not repeated field type. + >>> make_path('nation.district-2.county') + (('nation', None), ('district', 2), ('county', None)) + + # Field is not part of model. + >>> make_path('nation.made_up_field') + None + + # nation field is not repeated and index provided. + >>> make_path('nation-1') + None + + # district field is repeated and no index provided. + >>> make_path('nation.district') + None + + Args: + parameter_name: Name of query parameter as passed in from the request. + in order to make a path, this parameter_name must point to a valid + field within the message structure. Nodes of the path that refer to + repeated fields must be indexed with a number, non repeated nodes must + not have an index. + + Returns: + Parsed version of the parameter_name as a tuple of tuples: + attribute: Name of attribute associated with path. + index: Postitive integer index when it is a repeated field, else None. + Will return None if the parameter_name does not have the right prefix, + does not point to a field within the message structure, does not have + an index if it is a repeated field or has an index but is not a repeated + field. + """ + if parameter_name.startswith(self.__parameter_prefix): + parameter_name = parameter_name[len(self.__parameter_prefix):] + else: + return None + + path = [] + name = [] + message_type = type(self.__messages[()]) # Get root message. + + for item in parameter_name.split('.'): + # This will catch sub_message.real_message_field.not_real_field + if not message_type: + return None + + item_match = _FIELD_NAME_REGEX.match(item) + if not item_match: + return None + attribute = item_match.group(1) + index = item_match.group(2) + if index: + index = int(index) + + try: + field = message_type.field_by_name(attribute) + except KeyError: + return None + + if field.repeated != (index is not None): + return None + + if isinstance(field, messages.MessageField): + message_type = field.message_type + else: + message_type = None + + # Path is valid so far. Append node and continue. + path.append((attribute, index)) + + return tuple(path) + + def __check_index(self, parent_path, name, index): + """Check correct index use and value relative to a given path. + + Check that for a given path the index is present for repeated fields + and that it is in range for the existing list that it will be inserted + in to or appended to. + + Args: + parent_path: Path to check against name and index. + name: Name of field to check for existance. + index: Index to check. If field is repeated, should be a number within + range of the length of the field, or point to the next item for + appending. + """ + # Don't worry about non-repeated fields. + # It's also ok if index is 0 because that means next insert will append. + if not index: + return True + + parent = self.__messages.get(parent_path, None) + value_list = getattr(parent, name, None) + # If the list does not exist then the index should be 0. Since it is + # not, path is not valid. + if not value_list: + return False + + # The index must either point to an element of the list or to the tail. + return len(value_list) >= index + + def __check_indexes(self, path): + """Check that all indexes are valid and in the right order. + + This method must iterate over the path and check that all references + to indexes point to an existing message or to the end of the list, meaning + the next value should be appended to the repeated field. + + Args: + path: Path to check indexes for. Tuple of 2-tuples (name, index). See + make_path for more information. + + Returns: + True if all the indexes of the path are within range, else False. + """ + if path in self.__checked_indexes: + return True + + # Start with the root message. + parent_path = () + + for name, index in path: + next_path = parent_path + ((name, index),) + # First look in the checked indexes cache. + if next_path not in self.__checked_indexes: + if not self.__check_index(parent_path, name, index): + return False + self.__checked_indexes.add(next_path) + + parent_path = next_path + + return True + + def __get_or_create_path(self, path): + """Get a message from the messages cache or create it and add it. + + This method will also create any parent messages based on the path. + + When a new instance of a given message is created, it is stored in + __message by its path. + + Args: + path: Path of message to get. Path must be valid, in other words + __check_index(path) returns true. Tuple of 2-tuples (name, index). + See make_path for more information. + + Returns: + Message instance if the field being pointed to by the path is a + message, else will return None for non-message fields. + """ + message = self.__messages.get(path, None) + if message: + return message + + parent_path = () + parent = self.__messages[()] # Get the root object + + for name, index in path: + field = parent.field_by_name(name) + next_path = parent_path + ((name, index),) + next_message = self.__messages.get(next_path, None) + if next_message is None: + next_message = field.message_type() + self.__messages[next_path] = next_message + if not field.repeated: + setattr(parent, field.name, next_message) + else: + list_value = getattr(parent, field.name, None) + if list_value is None: + setattr(parent, field.name, [next_message]) + else: + list_value.append(next_message) + + parent_path = next_path + parent = next_message + + return parent + + def add_parameter(self, parameter, values): + """Add a single parameter. + + Adds a single parameter and its value to the request message. + + Args: + parameter: Query string parameter to map to request. + values: List of values to assign to request message. + + Returns: + True if parameter was valid and added to the message, else False. + + Raises: + DecodeError if the parameter refers to a valid field, and the values + parameter does not have one and only one value. Non-valid query + parameters may have multiple values and should not cause an error. + """ + path = self.make_path(parameter) + + if not path: + return False + + # Must check that all indexes of all items in the path are correct before + # instantiating any of them. For example, consider: + # + # class Repeated(object): + # ... + # + # class Inner(object): + # + # repeated = messages.MessageField(Repeated, 1, repeated=True) + # + # class Outer(object): + # + # inner = messages.MessageField(Inner, 1) + # + # instance = Outer() + # builder = URLEncodedRequestBuilder(instance) + # builder.add_parameter('inner.repeated') + # + # assert not hasattr(instance, 'inner') + # + # The check is done relative to the instance of Outer pass in to the + # constructor of the builder. This instance is not referred to at all + # because all names are assumed to be relative to it. + # + # The 'repeated' part of the path is not correct because it is missing an + # index. Because it is missing an index, it should not create an instance + # of Repeated. In this case add_parameter will return False and have no + # side effects. + # + # A correct path that would cause a new Inner instance to be inserted at + # instance.inner and a new Repeated instance to be appended to the + # instance.inner.repeated list would be 'inner.repeated-0'. + if not self.__check_indexes(path): + return False + + # Ok to build objects. + parent_path = path[:-1] + parent = self.__get_or_create_path(parent_path) + name, index = path[-1] + field = parent.field_by_name(name) + + if len(values) != 1: + raise messages.DecodeError( + 'Found repeated values for field %s.' % field.name) + + value = values[0] + + if isinstance(field, messages.IntegerField): + converted_value = int(value) + elif isinstance(field, message_types.DateTimeField): + try: + converted_value = util.decode_datetime(value) + except ValueError as e: + raise messages.DecodeError(e) + elif isinstance(field, messages.MessageField): + # Just make sure it's instantiated. Assignment to field or + # appending to list is done in __get_or_create_path. + self.__get_or_create_path(path) + return True + elif isinstance(field, messages.StringField): + converted_value = value.decode('utf-8') + elif isinstance(field, messages.BooleanField): + converted_value = value.lower() == 'true' and True or False + else: + try: + converted_value = field.type(value) + except TypeError: + raise messages.DecodeError('Invalid enum value "%s"' % value) + + if field.repeated: + value_list = getattr(parent, field.name, None) + if value_list is None: + setattr(parent, field.name, [converted_value]) + else: + if index == len(value_list): + value_list.append(converted_value) + else: + # Index should never be above len(value_list) because it was + # verified during the index check above. + value_list[index] = converted_value + else: + setattr(parent, field.name, converted_value) + + return True + + +@util.positional(1) +def encode_message(message, prefix=''): + """Encode Message instance to url-encoded string. + + Args: + message: Message instance to encode in to url-encoded string. + prefix: Prefix to append to field names of contained values. + + Returns: + String encoding of Message in URL encoded format. + + Raises: + messages.ValidationError if message is not initialized. + """ + message.check_initialized() + + parameters = [] + def build_message(parent, prefix): + """Recursively build parameter list for URL response. + + Args: + parent: Message to build parameters for. + prefix: Prefix to append to field names of contained values. + + Returns: + True if some value of parent was added to the parameters list, + else False, meaning the object contained no values. + """ + has_any_values = False + for field in sorted(parent.all_fields(), key=lambda f: f.number): + next_value = parent.get_assigned_value(field.name) + if next_value is None: + continue + + # Found a value. Ultimate return value should be True. + has_any_values = True + + # Normalize all values in to a list. + if not field.repeated: + next_value = [next_value] + + for index, item in enumerate(next_value): + # Create a name with an index if it is a repeated field. + if field.repeated: + field_name = '%s%s-%s' % (prefix, field.name, index) + else: + field_name = prefix + field.name + + if isinstance(field, message_types.DateTimeField): + # DateTimeField stores its data as a RFC 3339 compliant string. + parameters.append((field_name, item.isoformat())) + elif isinstance(field, messages.MessageField): + # Message fields must be recursed in to in order to construct + # their component parameter values. + if not build_message(item, field_name + '.'): + # The nested message is empty. Append an empty value to + # represent it. + parameters.append((field_name, '')) + elif isinstance(field, messages.BooleanField): + parameters.append((field_name, item and 'true' or 'false')) + else: + if isinstance(item, six.text_type): + item = item.encode('utf-8') + parameters.append((field_name, str(item))) + + return has_any_values + + build_message(message, prefix) + + # Also add any unrecognized values from the decoded string. + for key in message.all_unrecognized_fields(): + values, _ = message.get_unrecognized_field_info(key) + if not isinstance(values, (list, tuple)): + values = (values,) + for value in values: + parameters.append((key, value)) + + return urllib.urlencode(parameters) + + +def decode_message(message_type, encoded_message, **kwargs): + """Decode urlencoded content to message. + + Args: + message_type: Message instance to merge URL encoded content into. + encoded_message: URL encoded message. + prefix: Prefix to append to field names of contained values. + + Returns: + Decoded instance of message_type. + """ + message = message_type() + builder = URLEncodedRequestBuilder(message, **kwargs) + arguments = cgi.parse_qs(encoded_message, keep_blank_values=True) + for argument, values in sorted(six.iteritems(arguments)): + added = builder.add_parameter(argument, values) + # Save off any unknown values, so they're still accessible. + if not added: + message.set_unrecognized_field(argument, values, messages.Variant.STRING) + message.check_initialized() + return message diff --git a/endpoints/bundled/protorpc/registry.py b/endpoints/bundled/protorpc/registry.py new file mode 100644 index 0000000..23ba876 --- /dev/null +++ b/endpoints/bundled/protorpc/registry.py @@ -0,0 +1,240 @@ +#!/usr/bin/env python +# +# Copyright 2010 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Service regsitry for service discovery. + +The registry service can be deployed on a server in order to provide a +central place where remote clients can discover available. + +On the server side, each service is registered by their name which is unique +to the registry. Typically this name provides enough information to identify +the service and locate it within a server. For example, for an HTTP based +registry the name is the URL path on the host where the service is invocable. + +The registry is also able to resolve the full descriptor.FileSet necessary to +describe the service and all required data-types (messages and enums). + +A configured registry is itself a remote service and should reference itself. +""" + +import sys + +from . import descriptor +from . import messages +from . import remote +from . import util + + +__all__ = [ + 'ServiceMapping', + 'ServicesResponse', + 'GetFileSetRequest', + 'GetFileSetResponse', + 'RegistryService', +] + + +class ServiceMapping(messages.Message): + """Description of registered service. + + Fields: + name: Name of service. On HTTP based services this will be the + URL path used for invocation. + definition: Fully qualified name of the service definition. Useful + for clients that can look up service definitions based on an existing + repository of definitions. + """ + + name = messages.StringField(1, required=True) + definition = messages.StringField(2, required=True) + + +class ServicesResponse(messages.Message): + """Response containing all registered services. + + May also contain complete descriptor file-set for all services known by the + registry. + + Fields: + services: Service mappings for all registered services in registry. + file_set: Descriptor file-set describing all services, messages and enum + types needed for use with all requested services if asked for in the + request. + """ + + services = messages.MessageField(ServiceMapping, 1, repeated=True) + + +class GetFileSetRequest(messages.Message): + """Request for service descriptor file-set. + + Request to retrieve file sets for specific services. + + Fields: + names: Names of services to retrieve file-set for. + """ + + names = messages.StringField(1, repeated=True) + + +class GetFileSetResponse(messages.Message): + """Descriptor file-set for all names in GetFileSetRequest. + + Fields: + file_set: Descriptor file-set containing all descriptors for services, + messages and enum types needed for listed names in request. + """ + + file_set = messages.MessageField(descriptor.FileSet, 1, required=True) + + +class RegistryService(remote.Service): + """Registry service. + + Maps names to services and is able to describe all descriptor file-sets + necessary to use contined services. + + On an HTTP based server, the name is the URL path to the service. + """ + + @util.positional(2) + def __init__(self, registry, modules=None): + """Constructor. + + Args: + registry: Map of name to service class. This map is not copied and may + be modified after the reigstry service has been configured. + modules: Module dict to draw descriptors from. Defaults to sys.modules. + """ + # Private Attributes: + # __registry: Map of name to service class. Refers to same instance as + # registry parameter. + # __modules: Mapping of module name to module. + # __definition_to_modules: Mapping of definition types to set of modules + # that they refer to. This cache is used to make repeated look-ups + # faster and to prevent circular references from causing endless loops. + + self.__registry = registry + if modules is None: + modules = sys.modules + self.__modules = modules + # This cache will only last for a single request. + self.__definition_to_modules = {} + + def __find_modules_for_message(self, message_type): + """Find modules referred to by a message type. + + Determines the entire list of modules ultimately referred to by message_type + by iterating over all of its message and enum fields. Includes modules + referred to fields within its referred messages. + + Args: + message_type: Message type to find all referring modules for. + + Returns: + Set of modules referred to by message_type by traversing all its + message and enum fields. + """ + # TODO(rafek): Maybe this should be a method on Message and Service? + def get_dependencies(message_type, seen=None): + """Get all dependency definitions of a message type. + + This function works by collecting the types of all enumeration and message + fields defined within the message type. When encountering a message + field, it will recursivly find all of the associated message's + dependencies. It will terminate on circular dependencies by keeping track + of what definitions it already via the seen set. + + Args: + message_type: Message type to get dependencies for. + seen: Set of definitions that have already been visited. + + Returns: + All dependency message and enumerated types associated with this message + including the message itself. + """ + if seen is None: + seen = set() + seen.add(message_type) + + for field in message_type.all_fields(): + if isinstance(field, messages.MessageField): + if field.message_type not in seen: + get_dependencies(field.message_type, seen) + elif isinstance(field, messages.EnumField): + seen.add(field.type) + + return seen + + found_modules = self.__definition_to_modules.setdefault(message_type, set()) + if not found_modules: + dependencies = get_dependencies(message_type) + found_modules.update(self.__modules[definition.__module__] + for definition in dependencies) + + return found_modules + + def __describe_file_set(self, names): + """Get file-set for named services. + + Args: + names: List of names to get file-set for. + + Returns: + descriptor.FileSet containing all the descriptors for all modules + ultimately referred to by all service types request by names parameter. + """ + service_modules = set() + if names: + for service in (self.__registry[name] for name in names): + found_modules = self.__definition_to_modules.setdefault(service, set()) + if not found_modules: + found_modules.add(self.__modules[service.__module__]) + for method_name in service.all_remote_methods(): + method = getattr(service, method_name) + for message_type in (method.remote.request_type, + method.remote.response_type): + found_modules.update( + self.__find_modules_for_message(message_type)) + service_modules.update(found_modules) + + return descriptor.describe_file_set(service_modules) + + @property + def registry(self): + """Get service registry associated with this service instance.""" + return self.__registry + + @remote.method(response_type=ServicesResponse) + def services(self, request): + """Get all registered services.""" + response = ServicesResponse() + response.services = [] + for name, service_class in self.__registry.items(): + mapping = ServiceMapping() + mapping.name = name.decode('utf-8') + mapping.definition = service_class.definition_name().decode('utf-8') + response.services.append(mapping) + + return response + + @remote.method(GetFileSetRequest, GetFileSetResponse) + def get_file_set(self, request): + """Get file-set for registered servies.""" + response = GetFileSetResponse() + response.file_set = self.__describe_file_set(request.names) + return response diff --git a/endpoints/bundled/protorpc/remote.py b/endpoints/bundled/protorpc/remote.py new file mode 100644 index 0000000..61fe6c8 --- /dev/null +++ b/endpoints/bundled/protorpc/remote.py @@ -0,0 +1,1248 @@ +#!/usr/bin/env python +# +# Copyright 2010 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Remote service library. + +This module contains classes that are useful for building remote services that +conform to a standard request and response model. To conform to this model +a service must be like the following class: + + # Each service instance only handles a single request and is then discarded. + # Make these objects light weight. + class Service(object): + + # It must be possible to construct service objects without any parameters. + # If your constructor needs extra information you should provide a + # no-argument factory function to create service instances. + def __init__(self): + ... + + # Each remote method must use the 'method' decorator, passing the request + # and response message types. The remote method itself must take a single + # parameter which is an instance of RequestMessage and return an instance + # of ResponseMessage. + @method(RequestMessage, ResponseMessage) + def remote_method(self, request): + # Return an instance of ResponseMessage. + + # A service object may optionally implement an 'initialize_request_state' + # method that takes as a parameter a single instance of a RequestState. If + # a service does not implement this method it will not receive the request + # state. + def initialize_request_state(self, state): + ... + +The 'Service' class is provided as a convenient base class that provides the +above functionality. It implements all required and optional methods for a +service. It also has convenience methods for creating factory functions that +can pass persistent global state to a new service instance. + +The 'method' decorator is used to declare which methods of a class are +meant to service RPCs. While this decorator is not responsible for handling +actual remote method invocations, such as handling sockets, handling various +RPC protocols and checking messages for correctness, it does attach information +to methods that responsible classes can examine and ensure the correctness +of the RPC. + +When the method decorator is used on a method, the wrapper method will have a +'remote' property associated with it. The 'remote' property contains the +request_type and response_type expected by the methods implementation. + +On its own, the method decorator does not provide any support for subclassing +remote methods. In order to extend a service, one would need to redecorate +the sub-classes methods. For example: + + class MyService(Service): + + @method(DoSomethingRequest, DoSomethingResponse) + def do_stuff(self, request): + ... implement do_stuff ... + + class MyBetterService(MyService): + + @method(DoSomethingRequest, DoSomethingResponse) + def do_stuff(self, request): + response = super(MyBetterService, self).do_stuff.remote.method(request) + ... do stuff with response ... + return response + +A Service subclass also has a Stub class that can be used with a transport for +making RPCs. When a stub is created, it is capable of doing both synchronous +and asynchronous RPCs if the underlying transport supports it. To make a stub +using an HTTP transport do: + + my_service = MyService.Stub(HttpTransport('')) + +For synchronous calls, just call the expected methods on the service stub: + + request = DoSomethingRequest() + ... + response = my_service.do_something(request) + +Each stub instance has an async object that can be used for initiating +asynchronous RPCs if the underlying protocol transport supports it. To +make an asynchronous call, do: + + rpc = my_service.async.do_something(request) + response = rpc.get_response() +""" + +from __future__ import with_statement +import six + +__author__ = 'rafek@google.com (Rafe Kaplan)' + +import functools +import logging +import sys +import threading +from wsgiref import headers as wsgi_headers + +from . import message_types +from . import messages +from . import protobuf +from . import protojson +from . import util + + +__all__ = [ + 'ApplicationError', + 'MethodNotFoundError', + 'NetworkError', + 'RequestError', + 'RpcError', + 'ServerError', + 'ServiceConfigurationError', + 'ServiceDefinitionError', + + 'HttpRequestState', + 'ProtocolConfig', + 'Protocols', + 'RequestState', + 'RpcState', + 'RpcStatus', + 'Service', + 'StubBase', + 'check_rpc_status', + 'get_remote_method_info', + 'is_error_status', + 'method', + 'remote', +] + + +class ServiceDefinitionError(messages.Error): + """Raised when a service is improperly defined.""" + + +class ServiceConfigurationError(messages.Error): + """Raised when a service is incorrectly configured.""" + + +# TODO: Use error_name to map to specific exception message types. +class RpcStatus(messages.Message): + """Status of on-going or complete RPC. + + Fields: + state: State of RPC. + error_name: Error name set by application. Only set when + status is APPLICATION_ERROR. For use by application to transmit + specific reason for error. + error_message: Error message associated with status. + """ + + class State(messages.Enum): + """Enumeration of possible RPC states. + + Values: + OK: Completed successfully. + RUNNING: Still running, not complete. + REQUEST_ERROR: Request was malformed or incomplete. + SERVER_ERROR: Server experienced an unexpected error. + NETWORK_ERROR: An error occured on the network. + APPLICATION_ERROR: The application is indicating an error. + When in this state, RPC should also set application_error. + """ + OK = 0 + RUNNING = 1 + + REQUEST_ERROR = 2 + SERVER_ERROR = 3 + NETWORK_ERROR = 4 + APPLICATION_ERROR = 5 + METHOD_NOT_FOUND_ERROR = 6 + + state = messages.EnumField(State, 1, required=True) + error_message = messages.StringField(2) + error_name = messages.StringField(3) + + +RpcState = RpcStatus.State + + +class RpcError(messages.Error): + """Base class for RPC errors. + + Each sub-class of RpcError is associated with an error value from RpcState + and has an attribute STATE that refers to that value. + """ + + def __init__(self, message, cause=None): + super(RpcError, self).__init__(message) + self.cause = cause + + @classmethod + def from_state(cls, state): + """Get error class from RpcState. + + Args: + state: RpcState value. Can be enum value itself, string or int. + + Returns: + Exception class mapped to value if state is an error. Returns None + if state is OK or RUNNING. + """ + return _RPC_STATE_TO_ERROR.get(RpcState(state)) + + +class RequestError(RpcError): + """Raised when wrong request objects received during method invocation.""" + + STATE = RpcState.REQUEST_ERROR + + +class MethodNotFoundError(RequestError): + """Raised when unknown method requested by RPC.""" + + STATE = RpcState.METHOD_NOT_FOUND_ERROR + + +class NetworkError(RpcError): + """Raised when network error occurs during RPC.""" + + STATE = RpcState.NETWORK_ERROR + + +class ServerError(RpcError): + """Unexpected error occured on server.""" + + STATE = RpcState.SERVER_ERROR + + +class ApplicationError(RpcError): + """Raised for application specific errors. + + Attributes: + error_name: Application specific error name for exception. + """ + + STATE = RpcState.APPLICATION_ERROR + + def __init__(self, message, error_name=None): + """Constructor. + + Args: + message: Application specific error message. + error_name: Application specific error name. Must be None, string + or unicode string. + """ + super(ApplicationError, self).__init__(message) + self.error_name = error_name + + def __str__(self): + return self.args[0] or '' + + def __repr__(self): + if self.error_name is None: + error_format = '' + else: + error_format = ', %r' % self.error_name + return '%s(%r%s)' % (type(self).__name__, self.args[0], error_format) + + +_RPC_STATE_TO_ERROR = { + RpcState.REQUEST_ERROR: RequestError, + RpcState.NETWORK_ERROR: NetworkError, + RpcState.SERVER_ERROR: ServerError, + RpcState.APPLICATION_ERROR: ApplicationError, + RpcState.METHOD_NOT_FOUND_ERROR: MethodNotFoundError, +} + +class _RemoteMethodInfo(object): + """Object for encapsulating remote method information. + + An instance of this method is associated with the 'remote' attribute + of the methods 'invoke_remote_method' instance. + + Instances of this class are created by the remote decorator and should not + be created directly. + """ + + def __init__(self, + method, + request_type, + response_type): + """Constructor. + + Args: + method: The method which implements the remote method. This is a + function that will act as an instance method of a class definition + that is decorated by '@method'. It must always take 'self' as its + first parameter. + request_type: Expected request type for the remote method. + response_type: Expected response type for the remote method. + """ + self.__method = method + self.__request_type = request_type + self.__response_type = response_type + + @property + def method(self): + """Original undecorated method.""" + return self.__method + + @property + def request_type(self): + """Expected request type for remote method.""" + if isinstance(self.__request_type, six.string_types): + self.__request_type = messages.find_definition( + self.__request_type, + relative_to=sys.modules[self.__method.__module__]) + return self.__request_type + + @property + def response_type(self): + """Expected response type for remote method.""" + if isinstance(self.__response_type, six.string_types): + self.__response_type = messages.find_definition( + self.__response_type, + relative_to=sys.modules[self.__method.__module__]) + return self.__response_type + + +def method(request_type=message_types.VoidMessage, + response_type=message_types.VoidMessage): + """Method decorator for creating remote methods. + + Args: + request_type: Message type of expected request. + response_type: Message type of expected response. + + Returns: + 'remote_method_wrapper' function. + + Raises: + TypeError: if the request_type or response_type parameters are not + proper subclasses of messages.Message. + """ + if (not isinstance(request_type, six.string_types) and + (not isinstance(request_type, type) or + not issubclass(request_type, messages.Message) or + request_type is messages.Message)): + raise TypeError( + 'Must provide message class for request-type. Found %s', + request_type) + + if (not isinstance(response_type, six.string_types) and + (not isinstance(response_type, type) or + not issubclass(response_type, messages.Message) or + response_type is messages.Message)): + raise TypeError( + 'Must provide message class for response-type. Found %s', + response_type) + + def remote_method_wrapper(method): + """Decorator used to wrap method. + + Args: + method: Original method being wrapped. + + Returns: + 'invoke_remote_method' function responsible for actual invocation. + This invocation function instance is assigned an attribute 'remote' + which contains information about the remote method: + request_type: Expected request type for remote method. + response_type: Response type returned from remote method. + + Raises: + TypeError: If request_type or response_type is not a subclass of Message + or is the Message class itself. + """ + + @functools.wraps(method) + def invoke_remote_method(service_instance, request): + """Function used to replace original method. + + Invoke wrapped remote method. Checks to ensure that request and + response objects are the correct types. + + Does not check whether messages are initialized. + + Args: + service_instance: The service object whose method is being invoked. + This is passed to 'self' during the invocation of the original + method. + request: Request message. + + Returns: + Results of calling wrapped remote method. + + Raises: + RequestError: Request object is not of the correct type. + ServerError: Response object is not of the correct type. + """ + if not isinstance(request, remote_method_info.request_type): + raise RequestError('Method %s.%s expected request type %s, ' + 'received %s' % + (type(service_instance).__name__, + method.__name__, + remote_method_info.request_type, + type(request))) + response = method(service_instance, request) + if not isinstance(response, remote_method_info.response_type): + raise ServerError('Method %s.%s expected response type %s, ' + 'sent %s' % + (type(service_instance).__name__, + method.__name__, + remote_method_info.response_type, + type(response))) + return response + + remote_method_info = _RemoteMethodInfo(method, + request_type, + response_type) + + invoke_remote_method.remote = remote_method_info + return invoke_remote_method + + return remote_method_wrapper + + +def remote(request_type, response_type): + """Temporary backward compatibility alias for method.""" + logging.warning('The remote decorator has been renamed method. It will be ' + 'removed in very soon from future versions of ProtoRPC.') + return method(request_type, response_type) + + +def get_remote_method_info(method): + """Get remote method info object from remote method. + + Returns: + Remote method info object if method is a remote method, else None. + """ + if not callable(method): + return None + + try: + method_info = method.remote + except AttributeError: + return None + + if not isinstance(method_info, _RemoteMethodInfo): + return None + + return method_info + + +class StubBase(object): + """Base class for client side service stubs. + + The remote method stubs are created by the _ServiceClass meta-class + when a Service class is first created. The resulting stub will + extend both this class and the service class it handles communications for. + + Assume that there is a service: + + class NewContactRequest(messages.Message): + + name = messages.StringField(1, required=True) + phone = messages.StringField(2) + email = messages.StringField(3) + + class NewContactResponse(message.Message): + + contact_id = messages.StringField(1) + + class AccountService(remote.Service): + + @remote.method(NewContactRequest, NewContactResponse): + def new_contact(self, request): + ... implementation ... + + A stub of this service can be called in two ways. The first is to pass in a + correctly initialized NewContactRequest message: + + request = NewContactRequest() + request.name = 'Bob Somebody' + request.phone = '+1 415 555 1234' + + response = account_service_stub.new_contact(request) + + The second way is to pass in keyword parameters that correspond with the root + request message type: + + account_service_stub.new_contact(name='Bob Somebody', + phone='+1 415 555 1234') + + The second form will create a request message of the appropriate type. + """ + + def __init__(self, transport): + """Constructor. + + Args: + transport: Underlying transport to communicate with remote service. + """ + self.__transport = transport + + @property + def transport(self): + """Transport used to communicate with remote service.""" + return self.__transport + + +class _ServiceClass(type): + """Meta-class for service class.""" + + def __new_async_method(cls, remote): + """Create asynchronous method for Async handler. + + Args: + remote: RemoteInfo to create method for. + """ + def async_method(self, *args, **kwargs): + """Asynchronous remote method. + + Args: + self: Instance of StubBase.Async subclass. + + Stub methods either take a single positional argument when a full + request message is passed in, or keyword arguments, but not both. + + See docstring for StubBase for more information on how to use remote + stub methods. + + Returns: + Rpc instance used to represent asynchronous RPC. + """ + if args and kwargs: + raise TypeError('May not provide both args and kwargs') + + if not args: + # Construct request object from arguments. + request = remote.request_type() + for name, value in six.iteritems(kwargs): + setattr(request, name, value) + else: + # First argument is request object. + request = args[0] + + return self.transport.send_rpc(remote, request) + + async_method.__name__ = remote.method.__name__ + async_method = util.positional(2)(async_method) + async_method.remote = remote + return async_method + + def __new_sync_method(cls, async_method): + """Create synchronous method for stub. + + Args: + async_method: asynchronous method to delegate calls to. + """ + def sync_method(self, *args, **kwargs): + """Synchronous remote method. + + Args: + self: Instance of StubBase.Async subclass. + args: Tuple (request,): + request: Request object. + kwargs: Field values for request. Must be empty if request object + is provided. + + Returns: + Response message from synchronized RPC. + """ + return async_method(self.async, *args, **kwargs).response + sync_method.__name__ = async_method.__name__ + sync_method.remote = async_method.remote + return sync_method + + def __create_async_methods(cls, remote_methods): + """Construct a dictionary of asynchronous methods based on remote methods. + + Args: + remote_methods: Dictionary of methods with associated RemoteInfo objects. + + Returns: + Dictionary of asynchronous methods with assocaited RemoteInfo objects. + Results added to AsyncStub subclass. + """ + async_methods = {} + for method_name, method in remote_methods.items(): + async_methods[method_name] = cls.__new_async_method(method.remote) + return async_methods + + def __create_sync_methods(cls, async_methods): + """Construct a dictionary of synchronous methods based on remote methods. + + Args: + async_methods: Dictionary of async methods to delegate calls to. + + Returns: + Dictionary of synchronous methods with assocaited RemoteInfo objects. + Results added to Stub subclass. + """ + sync_methods = {} + for method_name, async_method in async_methods.items(): + sync_methods[method_name] = cls.__new_sync_method(async_method) + return sync_methods + + def __new__(cls, name, bases, dct): + """Instantiate new service class instance.""" + if StubBase not in bases: + # Collect existing remote methods. + base_methods = {} + for base in bases: + try: + remote_methods = base.__remote_methods + except AttributeError: + pass + else: + base_methods.update(remote_methods) + + # Set this class private attribute so that base_methods do not have + # to be recacluated in __init__. + dct['_ServiceClass__base_methods'] = base_methods + + for attribute, value in dct.items(): + base_method = base_methods.get(attribute, None) + if base_method: + if not callable(value): + raise ServiceDefinitionError( + 'Must override %s in %s with a method.' % ( + attribute, name)) + + if get_remote_method_info(value): + raise ServiceDefinitionError( + 'Do not use method decorator when overloading remote method %s ' + 'on service %s.' % + (attribute, name)) + + base_remote_method_info = get_remote_method_info(base_method) + remote_decorator = method( + base_remote_method_info.request_type, + base_remote_method_info.response_type) + new_remote_method = remote_decorator(value) + dct[attribute] = new_remote_method + + return type.__new__(cls, name, bases, dct) + + def __init__(cls, name, bases, dct): + """Create uninitialized state on new class.""" + type.__init__(cls, name, bases, dct) + + # Only service implementation classes should have remote methods and stub + # sub classes created. Stub implementations have their own methods passed + # in to the type constructor. + if StubBase not in bases: + # Create list of remote methods. + cls.__remote_methods = dict(cls.__base_methods) + + for attribute, value in dct.items(): + value = getattr(cls, attribute) + remote_method_info = get_remote_method_info(value) + if remote_method_info: + cls.__remote_methods[attribute] = value + + # Build asynchronous stub class. + stub_attributes = {'Service': cls} + async_methods = cls.__create_async_methods(cls.__remote_methods) + stub_attributes.update(async_methods) + async_class = type('AsyncStub', (StubBase, cls), stub_attributes) + cls.AsyncStub = async_class + + # Constructor for synchronous stub class. + def __init__(self, transport): + """Constructor. + + Args: + transport: Underlying transport to communicate with remote service. + """ + super(cls.Stub, self).__init__(transport) + self.async = cls.AsyncStub(transport) + + # Build synchronous stub class. + stub_attributes = {'Service': cls, + '__init__': __init__} + stub_attributes.update(cls.__create_sync_methods(async_methods)) + + cls.Stub = type('Stub', (StubBase, cls), stub_attributes) + + @staticmethod + def all_remote_methods(cls): + """Get all remote methods of service. + + Returns: + Dict from method name to unbound method. + """ + return dict(cls.__remote_methods) + + +class RequestState(object): + """Request state information. + + Properties: + remote_host: Remote host name where request originated. + remote_address: IP address where request originated. + server_host: Host of server within which service resides. + server_port: Post which service has recevied request from. + """ + + @util.positional(1) + def __init__(self, + remote_host=None, + remote_address=None, + server_host=None, + server_port=None): + """Constructor. + + Args: + remote_host: Assigned to property. + remote_address: Assigned to property. + server_host: Assigned to property. + server_port: Assigned to property. + """ + self.__remote_host = remote_host + self.__remote_address = remote_address + self.__server_host = server_host + self.__server_port = server_port + + @property + def remote_host(self): + return self.__remote_host + + @property + def remote_address(self): + return self.__remote_address + + @property + def server_host(self): + return self.__server_host + + @property + def server_port(self): + return self.__server_port + + def _repr_items(self): + for name in ['remote_host', + 'remote_address', + 'server_host', + 'server_port']: + yield name, getattr(self, name) + + def __repr__(self): + """String representation of state.""" + state = [self.__class__.__name__] + for name, value in self._repr_items(): + if value: + state.append('%s=%r' % (name, value)) + + return '<%s>' % (' '.join(state),) + + +class HttpRequestState(RequestState): + """HTTP request state information. + + NOTE: Does not attempt to represent certain types of information from the + request such as the query string as query strings are not permitted in + ProtoRPC URLs unless required by the underlying message format. + + Properties: + headers: wsgiref.headers.Headers instance of HTTP request headers. + http_method: HTTP method as a string. + service_path: Path on HTTP service where service is mounted. This path + will not include the remote method name. + """ + + @util.positional(1) + def __init__(self, + http_method=None, + service_path=None, + headers=None, + **kwargs): + """Constructor. + + Args: + Same as RequestState, including: + http_method: Assigned to property. + service_path: Assigned to property. + headers: HTTP request headers. If instance of Headers, assigned to + property without copying. If dict, will convert to name value pairs + for use with Headers constructor. Otherwise, passed as parameters to + Headers constructor. + """ + super(HttpRequestState, self).__init__(**kwargs) + + self.__http_method = http_method + self.__service_path = service_path + + # Initialize headers. + if isinstance(headers, dict): + header_list = [] + for key, value in sorted(headers.items()): + if not isinstance(value, list): + value = [value] + for item in value: + header_list.append((key, item)) + headers = header_list + self.__headers = wsgi_headers.Headers(headers or []) + + @property + def http_method(self): + return self.__http_method + + @property + def service_path(self): + return self.__service_path + + @property + def headers(self): + return self.__headers + + def _repr_items(self): + for item in super(HttpRequestState, self)._repr_items(): + yield item + + for name in ['http_method', 'service_path']: + yield name, getattr(self, name) + + yield 'headers', list(self.headers.items()) + + +class Service(six.with_metaclass(_ServiceClass, object)): + """Service base class. + + Base class used for defining remote services. Contains reflection functions, + useful helpers and built-in remote methods. + + Services are expected to be constructed via either a constructor or factory + which takes no parameters. However, it might be required that some state or + configuration is passed in to a service across multiple requests. + + To do this, define parameters to the constructor of the service and use + the 'new_factory' class method to build a constructor that will transmit + parameters to the constructor. For example: + + class MyService(Service): + + def __init__(self, configuration, state): + self.configuration = configuration + self.state = state + + configuration = MyServiceConfiguration() + global_state = MyServiceState() + + my_service_factory = MyService.new_factory(configuration, + state=global_state) + + The contract with any service handler is that a new service object is created + to handle each user request, and that the construction does not take any + parameters. The factory satisfies this condition: + + new_instance = my_service_factory() + assert new_instance.state is global_state + + Attributes: + request_state: RequestState set via initialize_request_state. + """ + + __request_state = None + + @classmethod + def all_remote_methods(cls): + """Get all remote methods for service class. + + Built-in methods do not appear in the dictionary of remote methods. + + Returns: + Dictionary mapping method name to remote method. + """ + return _ServiceClass.all_remote_methods(cls) + + @classmethod + def new_factory(cls, *args, **kwargs): + """Create factory for service. + + Useful for passing configuration or state objects to the service. Accepts + arbitrary parameters and keywords, however, underlying service must accept + also accept not other parameters in its constructor. + + Args: + args: Args to pass to service constructor. + kwargs: Keyword arguments to pass to service constructor. + + Returns: + Factory function that will create a new instance and forward args and + keywords to the constructor. + """ + + def service_factory(): + return cls(*args, **kwargs) + + # Update docstring so that it is easier to debug. + full_class_name = '%s.%s' % (cls.__module__, cls.__name__) + service_factory.__doc__ = ( + 'Creates new instances of service %s.\n\n' + 'Returns:\n' + ' New instance of %s.' + % (cls.__name__, full_class_name)) + + # Update name so that it is easier to debug the factory function. + service_factory.__name__ = '%s_service_factory' % cls.__name__ + + service_factory.service_class = cls + + return service_factory + + def initialize_request_state(self, request_state): + """Save request state for use in remote method. + + Args: + request_state: RequestState instance. + """ + self.__request_state = request_state + + @classmethod + def definition_name(cls): + """Get definition name for Service class. + + Package name is determined by the global 'package' attribute in the + module that contains the Service definition. If no 'package' attribute + is available, uses module name. If no module is found, just uses class + name as name. + + Returns: + Fully qualified service name. + """ + try: + return cls.__definition_name + except AttributeError: + outer_definition_name = cls.outer_definition_name() + if outer_definition_name is None: + cls.__definition_name = cls.__name__ + else: + cls.__definition_name = '%s.%s' % (outer_definition_name, cls.__name__) + + return cls.__definition_name + + @classmethod + def outer_definition_name(cls): + """Get outer definition name. + + Returns: + Package for service. Services are never nested inside other definitions. + """ + return cls.definition_package() + + @classmethod + def definition_package(cls): + """Get package for service. + + Returns: + Package name for service. + """ + try: + return cls.__definition_package + except AttributeError: + cls.__definition_package = util.get_package_for_module(cls.__module__) + + return cls.__definition_package + + @property + def request_state(self): + """Request state associated with this Service instance.""" + return self.__request_state + + +def is_error_status(status): + """Function that determines whether the RPC status is an error. + + Args: + status: Initialized RpcStatus message to check for errors. + """ + status.check_initialized() + return RpcError.from_state(status.state) is not None + + +def check_rpc_status(status): + """Function converts an error status to a raised exception. + + Args: + status: Initialized RpcStatus message to check for errors. + + Raises: + RpcError according to state set on status, if it is an error state. + """ + status.check_initialized() + error_class = RpcError.from_state(status.state) + if error_class is not None: + if error_class is ApplicationError: + raise error_class(status.error_message, status.error_name) + else: + raise error_class(status.error_message) + + +class ProtocolConfig(object): + """Configuration for single protocol mapping. + + A read-only protocol configuration provides a given protocol implementation + with a name and a set of content-types that it recognizes. + + Properties: + protocol: The protocol implementation for configuration (usually a module, + for example, protojson, protobuf, etc.). This is an object that has the + following attributes: + CONTENT_TYPE: Used as the default content-type if default_content_type + is not set. + ALTERNATIVE_CONTENT_TYPES (optional): A list of alternative + content-types to the default that indicate the same protocol. + encode_message: Function that matches the signature of + ProtocolConfig.encode_message. Used for encoding a ProtoRPC message. + decode_message: Function that matches the signature of + ProtocolConfig.decode_message. Used for decoding a ProtoRPC message. + name: Name of protocol configuration. + default_content_type: The default content type for the protocol. Overrides + CONTENT_TYPE defined on protocol. + alternative_content_types: A list of alternative content-types supported + by the protocol. Must not contain the default content-type, nor + duplicates. Overrides ALTERNATIVE_CONTENT_TYPE defined on protocol. + content_types: A list of all content-types supported by configuration. + Combination of default content-type and alternatives. + """ + + def __init__(self, + protocol, + name, + default_content_type=None, + alternative_content_types=None): + """Constructor. + + Args: + protocol: The protocol implementation for configuration. + name: The name of the protocol configuration. + default_content_type: The default content-type for protocol. If none + provided it will check protocol.CONTENT_TYPE. + alternative_content_types: A list of content-types. If none provided, + it will check protocol.ALTERNATIVE_CONTENT_TYPES. If that attribute + does not exist, will be an empty tuple. + + Raises: + ServiceConfigurationError if there are any duplicate content-types. + """ + self.__protocol = protocol + self.__name = name + self.__default_content_type = (default_content_type or + protocol.CONTENT_TYPE).lower() + if alternative_content_types is None: + alternative_content_types = getattr(protocol, + 'ALTERNATIVE_CONTENT_TYPES', + ()) + self.__alternative_content_types = tuple( + content_type.lower() for content_type in alternative_content_types) + self.__content_types = ( + (self.__default_content_type,) + self.__alternative_content_types) + + # Detect duplicate content types in definition. + previous_type = None + for content_type in sorted(self.content_types): + if content_type == previous_type: + raise ServiceConfigurationError( + 'Duplicate content-type %s' % content_type) + previous_type = content_type + + @property + def protocol(self): + return self.__protocol + + @property + def name(self): + return self.__name + + @property + def default_content_type(self): + return self.__default_content_type + + @property + def alternate_content_types(self): + return self.__alternative_content_types + + @property + def content_types(self): + return self.__content_types + + def encode_message(self, message): + """Encode message. + + Args: + message: Message instance to encode. + + Returns: + String encoding of Message instance encoded in protocol's format. + """ + return self.__protocol.encode_message(message) + + def decode_message(self, message_type, encoded_message): + """Decode buffer to Message instance. + + Args: + message_type: Message type to decode data to. + encoded_message: Encoded version of message as string. + + Returns: + Decoded instance of message_type. + """ + return self.__protocol.decode_message(message_type, encoded_message) + + +class Protocols(object): + """Collection of protocol configurations. + + Used to describe a complete set of content-type mappings for multiple + protocol configurations. + + Properties: + names: Sorted list of the names of registered protocols. + content_types: Sorted list of supported content-types. + """ + + __default_protocols = None + __lock = threading.Lock() + + def __init__(self): + """Constructor.""" + self.__by_name = {} + self.__by_content_type = {} + + def add_protocol_config(self, config): + """Add a protocol configuration to protocol mapping. + + Args: + config: A ProtocolConfig. + + Raises: + ServiceConfigurationError if protocol.name is already registered + or any of it's content-types are already registered. + """ + if config.name in self.__by_name: + raise ServiceConfigurationError( + 'Protocol name %r is already in use' % config.name) + for content_type in config.content_types: + if content_type in self.__by_content_type: + raise ServiceConfigurationError( + 'Content type %r is already in use' % content_type) + + self.__by_name[config.name] = config + self.__by_content_type.update((t, config) for t in config.content_types) + + def add_protocol(self, *args, **kwargs): + """Add a protocol configuration from basic parameters. + + Simple helper method that creates and registeres a ProtocolConfig instance. + """ + self.add_protocol_config(ProtocolConfig(*args, **kwargs)) + + @property + def names(self): + return tuple(sorted(self.__by_name)) + + @property + def content_types(self): + return tuple(sorted(self.__by_content_type)) + + def lookup_by_name(self, name): + """Look up a ProtocolConfig by name. + + Args: + name: Name of protocol to look for. + + Returns: + ProtocolConfig associated with name. + + Raises: + KeyError if there is no protocol for name. + """ + return self.__by_name[name.lower()] + + def lookup_by_content_type(self, content_type): + """Look up a ProtocolConfig by content-type. + + Args: + content_type: Content-type to find protocol configuration for. + + Returns: + ProtocolConfig associated with content-type. + + Raises: + KeyError if there is no protocol for content-type. + """ + return self.__by_content_type[content_type.lower()] + + @classmethod + def new_default(cls): + """Create default protocols configuration. + + Returns: + New Protocols instance configured for protobuf and protorpc. + """ + protocols = cls() + protocols.add_protocol(protobuf, 'protobuf') + protocols.add_protocol(protojson.ProtoJson.get_default(), 'protojson') + return protocols + + @classmethod + def get_default(cls): + """Get the global default Protocols instance. + + Returns: + Current global default Protocols instance. + """ + default_protocols = cls.__default_protocols + if default_protocols is None: + with cls.__lock: + default_protocols = cls.__default_protocols + if default_protocols is None: + default_protocols = cls.new_default() + cls.__default_protocols = default_protocols + return default_protocols + + @classmethod + def set_default(cls, protocols): + """Set the global default Protocols instance. + + Args: + protocols: A Protocols instance. + + Raises: + TypeError: If protocols is not an instance of Protocols. + """ + if not isinstance(protocols, Protocols): + raise TypeError( + 'Expected value of type "Protocols", found %r' % protocols) + with cls.__lock: + cls.__default_protocols = protocols diff --git a/endpoints/bundled/protorpc/transport.py b/endpoints/bundled/protorpc/transport.py new file mode 100644 index 0000000..5d7e564 --- /dev/null +++ b/endpoints/bundled/protorpc/transport.py @@ -0,0 +1,412 @@ +#!/usr/bin/env python +# +# Copyright 2010 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Transport library for ProtoRPC. + +Contains underlying infrastructure used for communicating RPCs over low level +transports such as HTTP. + +Includes HTTP transport built over urllib2. +""" + +import six.moves.http_client +import logging +import os +import socket +import sys +import urlparse + +from . import messages +from . import protobuf +from . import remote +from . import util +import six + +__all__ = [ + 'RpcStateError', + + 'HttpTransport', + 'LocalTransport', + 'Rpc', + 'Transport', +] + + +class RpcStateError(messages.Error): + """Raised when trying to put RPC in to an invalid state.""" + + +class Rpc(object): + """Represents a client side RPC. + + An RPC is created by the transport class and is used with a single RPC. While + an RPC is still in process, the response is set to None. When it is complete + the response will contain the response message. + """ + + def __init__(self, request): + """Constructor. + + Args: + request: Request associated with this RPC. + """ + self.__request = request + self.__response = None + self.__state = remote.RpcState.RUNNING + self.__error_message = None + self.__error_name = None + + @property + def request(self): + """Request associated with RPC.""" + return self.__request + + @property + def response(self): + """Response associated with RPC.""" + self.wait() + self.__check_status() + return self.__response + + @property + def state(self): + """State associated with RPC.""" + return self.__state + + @property + def error_message(self): + """Error, if any, associated with RPC.""" + self.wait() + return self.__error_message + + @property + def error_name(self): + """Error name, if any, associated with RPC.""" + self.wait() + return self.__error_name + + def wait(self): + """Wait for an RPC to finish.""" + if self.__state == remote.RpcState.RUNNING: + self._wait_impl() + + def _wait_impl(self): + """Implementation for wait().""" + raise NotImplementedError() + + def __check_status(self): + error_class = remote.RpcError.from_state(self.__state) + if error_class is not None: + if error_class is remote.ApplicationError: + raise error_class(self.__error_message, self.__error_name) + else: + raise error_class(self.__error_message) + + def __set_state(self, state, error_message=None, error_name=None): + if self.__state != remote.RpcState.RUNNING: + raise RpcStateError( + 'RPC must be in RUNNING state to change to %s' % state) + if state == remote.RpcState.RUNNING: + raise RpcStateError('RPC is already in RUNNING state') + self.__state = state + self.__error_message = error_message + self.__error_name = error_name + + def set_response(self, response): + # TODO: Even more specific type checking. + if not isinstance(response, messages.Message): + raise TypeError('Expected Message type, received %r' % (response)) + + self.__response = response + self.__set_state(remote.RpcState.OK) + + def set_status(self, status): + status.check_initialized() + self.__set_state(status.state, status.error_message, status.error_name) + + +class Transport(object): + """Transport base class. + + Provides basic support for implementing a ProtoRPC transport such as one + that can send and receive messages over HTTP. + + Implementations override _start_rpc. This method receives a RemoteInfo + instance and a request Message. The transport is expected to set the rpc + response or raise an exception before termination. + """ + + @util.positional(1) + def __init__(self, protocol=protobuf): + """Constructor. + + Args: + protocol: If string, will look up a protocol from the default Protocols + instance by name. Can also be an instance of remote.ProtocolConfig. + If neither, it must be an object that implements a protocol interface + by implementing encode_message, decode_message and set CONTENT_TYPE. + For example, the modules protobuf and protojson can be used directly. + """ + if isinstance(protocol, six.string_types): + protocols = remote.Protocols.get_default() + try: + protocol = protocols.lookup_by_name(protocol) + except KeyError: + protocol = protocols.lookup_by_content_type(protocol) + if isinstance(protocol, remote.ProtocolConfig): + self.__protocol = protocol.protocol + self.__protocol_config = protocol + else: + self.__protocol = protocol + self.__protocol_config = remote.ProtocolConfig( + protocol, 'default', default_content_type=protocol.CONTENT_TYPE) + + @property + def protocol(self): + """Protocol associated with this transport.""" + return self.__protocol + + @property + def protocol_config(self): + """Protocol associated with this transport.""" + return self.__protocol_config + + def send_rpc(self, remote_info, request): + """Initiate sending an RPC over the transport. + + Args: + remote_info: RemoteInfo instance describing remote method. + request: Request message to send to service. + + Returns: + An Rpc instance intialized with the request.. + """ + request.check_initialized() + + rpc = self._start_rpc(remote_info, request) + + return rpc + + def _start_rpc(self, remote_info, request): + """Start a remote procedure call. + + Args: + remote_info: RemoteInfo instance describing remote method. + request: Request message to send to service. + + Returns: + An Rpc instance initialized with the request. + """ + raise NotImplementedError() + + +class HttpTransport(Transport): + """Transport for communicating with HTTP servers.""" + + @util.positional(2) + def __init__(self, + service_url, + protocol=protobuf): + """Constructor. + + Args: + service_url: URL where the service is located. All communication via + the transport will go to this URL. + protocol: The protocol implementation. Must implement encode_message and + decode_message. Can also be an instance of remote.ProtocolConfig. + """ + super(HttpTransport, self).__init__(protocol=protocol) + self.__service_url = service_url + + def __get_rpc_status(self, response, content): + """Get RPC status from HTTP response. + + Args: + response: HTTPResponse object. + content: Content read from HTTP response. + + Returns: + RpcStatus object parsed from response, else an RpcStatus with a generic + HTTP error. + """ + # Status above 400 may have RpcStatus content. + if response.status >= 400: + content_type = response.getheader('content-type') + if content_type == self.protocol_config.default_content_type: + try: + rpc_status = self.protocol.decode_message(remote.RpcStatus, content) + except Exception as decode_err: + logging.warning( + 'An error occurred trying to parse status: %s\n%s', + str(decode_err), content) + else: + if rpc_status.is_initialized(): + return rpc_status + else: + logging.warning( + 'Body does not result in an initialized RpcStatus message:\n%s', + content) + + # If no RpcStatus message present, attempt to forward any content. If empty + # use standard error message. + if not content.strip(): + content = six.moves.http_client.responses.get(response.status, 'Unknown Error') + return remote.RpcStatus(state=remote.RpcState.SERVER_ERROR, + error_message='HTTP Error %s: %s' % ( + response.status, content or 'Unknown Error')) + + def __set_response(self, remote_info, connection, rpc): + """Set response on RPC. + + Sets response or status from HTTP request. Implements the wait method of + Rpc instance. + + Args: + remote_info: Remote info for invoked RPC. + connection: HTTPConnection that is making request. + rpc: Rpc instance. + """ + try: + response = connection.getresponse() + + content = response.read() + + if response.status == six.moves.http_client.OK: + response = self.protocol.decode_message(remote_info.response_type, + content) + rpc.set_response(response) + else: + status = self.__get_rpc_status(response, content) + rpc.set_status(status) + finally: + connection.close() + + def _start_rpc(self, remote_info, request): + """Start a remote procedure call. + + Args: + remote_info: A RemoteInfo instance for this RPC. + request: The request message for this RPC. + + Returns: + An Rpc instance initialized with a Request. + """ + method_url = '%s.%s' % (self.__service_url, remote_info.method.__name__) + encoded_request = self.protocol.encode_message(request) + + url = urlparse.urlparse(method_url) + if url.scheme == 'https': + connection_type = six.moves.http_client.HTTPSConnection + else: + connection_type = six.moves.http_client.HTTPConnection + connection = connection_type(url.hostname, url.port) + try: + self._send_http_request(connection, url.path, encoded_request) + rpc = Rpc(request) + except remote.RpcError: + # Pass through all ProtoRPC errors + connection.close() + raise + except socket.error as err: + connection.close() + raise remote.NetworkError('Socket error: %s %r' % (type(err).__name__, + err.args), + err) + except Exception as err: + connection.close() + raise remote.NetworkError('Error communicating with HTTP server', + err) + else: + wait_impl = lambda: self.__set_response(remote_info, connection, rpc) + rpc._wait_impl = wait_impl + + return rpc + + def _send_http_request(self, connection, http_path, encoded_request): + connection.request( + 'POST', + http_path, + encoded_request, + headers={'Content-type': self.protocol_config.default_content_type, + 'Content-length': len(encoded_request)}) + + +class LocalTransport(Transport): + """Local transport that sends messages directly to services. + + Useful in tests or creating code that can work with either local or remote + services. Using LocalTransport is preferrable to simply instantiating a + single instance of a service and reusing it. The entire request process + involves instantiating a new instance of a service, initializing it with + request state and then invoking the remote method for every request. + """ + + def __init__(self, service_factory): + """Constructor. + + Args: + service_factory: Service factory or class. + """ + super(LocalTransport, self).__init__() + self.__service_class = getattr(service_factory, + 'service_class', + service_factory) + self.__service_factory = service_factory + + @property + def service_class(self): + return self.__service_class + + @property + def service_factory(self): + return self.__service_factory + + def _start_rpc(self, remote_info, request): + """Start a remote procedure call. + + Args: + remote_info: RemoteInfo instance describing remote method. + request: Request message to send to service. + + Returns: + An Rpc instance initialized with the request. + """ + rpc = Rpc(request) + def wait_impl(): + instance = self.__service_factory() + try: + initalize_request_state = instance.initialize_request_state + except AttributeError: + pass + else: + host = six.text_type(os.uname()[1]) + initalize_request_state(remote.RequestState(remote_host=host, + remote_address=u'127.0.0.1', + server_host=host, + server_port=-1)) + try: + response = remote_info.method(instance, request) + assert isinstance(response, remote_info.response_type) + except remote.ApplicationError: + raise + except: + exc_type, exc_value, traceback = sys.exc_info() + message = 'Unexpected error %s: %s' % (exc_type.__name__, exc_value) + six.reraise(remote.ServerError, message, traceback) + rpc.set_response(response) + rpc._wait_impl = wait_impl + return rpc diff --git a/endpoints/bundled/protorpc/util.py b/endpoints/bundled/protorpc/util.py new file mode 100644 index 0000000..935295c --- /dev/null +++ b/endpoints/bundled/protorpc/util.py @@ -0,0 +1,494 @@ +#!/usr/bin/env python +# +# Copyright 2010 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Common utility library.""" + +from __future__ import with_statement +import six + +__author__ = ['rafek@google.com (Rafe Kaplan)', + 'guido@google.com (Guido van Rossum)', +] + +import cgi +import datetime +import functools +import inspect +import os +import re +import sys + +__all__ = ['AcceptItem', + 'AcceptError', + 'Error', + 'choose_content_type', + 'decode_datetime', + 'get_package_for_module', + 'pad_string', + 'parse_accept_header', + 'positional', + 'PROTORPC_PROJECT_URL', + 'TimeZoneOffset', + 'total_seconds', +] + + +class Error(Exception): + """Base class for protorpc exceptions.""" + + +class AcceptError(Error): + """Raised when there is an error parsing the accept header.""" + + +PROTORPC_PROJECT_URL = 'http://code.google.com/p/google-protorpc' + +_TIME_ZONE_RE_STRING = r""" + # Examples: + # +01:00 + # -05:30 + # Z12:00 + ((?PZ) | (?P[-+]) + (?P\d\d) : + (?P\d\d))$ +""" +_TIME_ZONE_RE = re.compile(_TIME_ZONE_RE_STRING, re.IGNORECASE | re.VERBOSE) + + +def pad_string(string): + """Pad a string for safe HTTP error responses. + + Prevents Internet Explorer from displaying their own error messages + when sent as the content of error responses. + + Args: + string: A string. + + Returns: + Formatted string left justified within a 512 byte field. + """ + return string.ljust(512) + + +def positional(max_positional_args): + """A decorator to declare that only the first N arguments may be positional. + + This decorator makes it easy to support Python 3 style keyword-only + parameters. For example, in Python 3 it is possible to write: + + def fn(pos1, *, kwonly1=None, kwonly1=None): + ... + + All named parameters after * must be a keyword: + + fn(10, 'kw1', 'kw2') # Raises exception. + fn(10, kwonly1='kw1') # Ok. + + Example: + To define a function like above, do: + + @positional(1) + def fn(pos1, kwonly1=None, kwonly2=None): + ... + + If no default value is provided to a keyword argument, it becomes a required + keyword argument: + + @positional(0) + def fn(required_kw): + ... + + This must be called with the keyword parameter: + + fn() # Raises exception. + fn(10) # Raises exception. + fn(required_kw=10) # Ok. + + When defining instance or class methods always remember to account for + 'self' and 'cls': + + class MyClass(object): + + @positional(2) + def my_method(self, pos1, kwonly1=None): + ... + + @classmethod + @positional(2) + def my_method(cls, pos1, kwonly1=None): + ... + + One can omit the argument to 'positional' altogether, and then no + arguments with default values may be passed positionally. This + would be equivalent to placing a '*' before the first argument + with a default value in Python 3. If there are no arguments with + default values, and no argument is given to 'positional', an error + is raised. + + @positional + def fn(arg1, arg2, required_kw1=None, required_kw2=0): + ... + + fn(1, 3, 5) # Raises exception. + fn(1, 3) # Ok. + fn(1, 3, required_kw1=5) # Ok. + + Args: + max_positional_arguments: Maximum number of positional arguments. All + parameters after the this index must be keyword only. + + Returns: + A decorator that prevents using arguments after max_positional_args from + being used as positional parameters. + + Raises: + TypeError if a keyword-only argument is provided as a positional parameter. + ValueError if no maximum number of arguments is provided and the function + has no arguments with default values. + """ + def positional_decorator(wrapped): + @functools.wraps(wrapped) + def positional_wrapper(*args, **kwargs): + if len(args) > max_positional_args: + plural_s = '' + if max_positional_args != 1: + plural_s = 's' + raise TypeError('%s() takes at most %d positional argument%s ' + '(%d given)' % (wrapped.__name__, + max_positional_args, + plural_s, len(args))) + return wrapped(*args, **kwargs) + return positional_wrapper + + if isinstance(max_positional_args, six.integer_types): + return positional_decorator + else: + args, _, _, defaults = inspect.getargspec(max_positional_args) + if defaults is None: + raise ValueError( + 'Functions with no keyword arguments must specify ' + 'max_positional_args') + return positional(len(args) - len(defaults))(max_positional_args) + + +# TODO(rafek): Support 'level' from the Accept header standard. +class AcceptItem(object): + """Encapsulate a single entry of an Accept header. + + Parses and extracts relevent values from an Accept header and implements + a sort order based on the priority of each requested type as defined + here: + + http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html + + Accept headers are normally a list of comma separated items. Each item + has the format of a normal HTTP header. For example: + + Accept: text/plain, text/html, text/*, */* + + This header means to prefer plain text over HTML, HTML over any other + kind of text and text over any other kind of supported format. + + This class does not attempt to parse the list of items from the Accept header. + The constructor expects the unparsed sub header and the index within the + Accept header that the fragment was found. + + Properties: + index: The index that this accept item was found in the Accept header. + main_type: The main type of the content type. + sub_type: The sub type of the content type. + q: The q value extracted from the header as a float. If there is no q + value, defaults to 1.0. + values: All header attributes parsed form the sub-header. + sort_key: A tuple (no_main_type, no_sub_type, q, no_values, index): + no_main_type: */* has the least priority. + no_sub_type: Items with no sub-type have less priority. + q: Items with lower q value have less priority. + no_values: Items with no values have less priority. + index: Index of item in accept header is the last priority. + """ + + __CONTENT_TYPE_REGEX = re.compile(r'^([^/]+)/([^/]+)$') + + def __init__(self, accept_header, index): + """Parse component of an Accept header. + + Args: + accept_header: Unparsed sub-expression of accept header. + index: The index that this accept item was found in the Accept header. + """ + accept_header = accept_header.lower() + content_type, values = cgi.parse_header(accept_header) + match = self.__CONTENT_TYPE_REGEX.match(content_type) + if not match: + raise AcceptError('Not valid Accept header: %s' % accept_header) + self.__index = index + self.__main_type = match.group(1) + self.__sub_type = match.group(2) + self.__q = float(values.get('q', 1)) + self.__values = values + + if self.__main_type == '*': + self.__main_type = None + + if self.__sub_type == '*': + self.__sub_type = None + + self.__sort_key = (not self.__main_type, + not self.__sub_type, + -self.__q, + not self.__values, + self.__index) + + @property + def index(self): + return self.__index + + @property + def main_type(self): + return self.__main_type + + @property + def sub_type(self): + return self.__sub_type + + @property + def q(self): + return self.__q + + @property + def values(self): + """Copy the dictionary of values parsed from the header fragment.""" + return dict(self.__values) + + @property + def sort_key(self): + return self.__sort_key + + def match(self, content_type): + """Determine if the given accept header matches content type. + + Args: + content_type: Unparsed content type string. + + Returns: + True if accept header matches content type, else False. + """ + content_type, _ = cgi.parse_header(content_type) + match = self.__CONTENT_TYPE_REGEX.match(content_type.lower()) + if not match: + return False + + main_type, sub_type = match.group(1), match.group(2) + if not(main_type and sub_type): + return False + + return ((self.__main_type is None or self.__main_type == main_type) and + (self.__sub_type is None or self.__sub_type == sub_type)) + + + def __cmp__(self, other): + """Comparison operator based on sort keys.""" + if not isinstance(other, AcceptItem): + return NotImplemented + return cmp(self.sort_key, other.sort_key) + + def __str__(self): + """Rebuilds Accept header.""" + content_type = '%s/%s' % (self.__main_type or '*', self.__sub_type or '*') + values = self.values + + if values: + value_strings = ['%s=%s' % (i, v) for i, v in values.items()] + return '%s; %s' % (content_type, '; '.join(value_strings)) + else: + return content_type + + def __repr__(self): + return 'AcceptItem(%r, %d)' % (str(self), self.__index) + + +def parse_accept_header(accept_header): + """Parse accept header. + + Args: + accept_header: Unparsed accept header. Does not include name of header. + + Returns: + List of AcceptItem instances sorted according to their priority. + """ + accept_items = [] + for index, header in enumerate(accept_header.split(',')): + accept_items.append(AcceptItem(header, index)) + return sorted(accept_items) + + +def choose_content_type(accept_header, supported_types): + """Choose most appropriate supported type based on what client accepts. + + Args: + accept_header: Unparsed accept header. Does not include name of header. + supported_types: List of content-types supported by the server. The index + of the supported types determines which supported type is prefered by + the server should the accept header match more than one at the same + priority. + + Returns: + The preferred supported type if the accept header matches any, else None. + """ + for accept_item in parse_accept_header(accept_header): + for supported_type in supported_types: + if accept_item.match(supported_type): + return supported_type + return None + + +@positional(1) +def get_package_for_module(module): + """Get package name for a module. + + Helper calculates the package name of a module. + + Args: + module: Module to get name for. If module is a string, try to find + module in sys.modules. + + Returns: + If module contains 'package' attribute, uses that as package name. + Else, if module is not the '__main__' module, the module __name__. + Else, the base name of the module file name. Else None. + """ + if isinstance(module, six.string_types): + try: + module = sys.modules[module] + except KeyError: + return None + + try: + return six.text_type(module.package) + except AttributeError: + if module.__name__ == '__main__': + try: + file_name = module.__file__ + except AttributeError: + pass + else: + base_name = os.path.basename(file_name) + split_name = os.path.splitext(base_name) + if len(split_name) == 1: + return six.text_type(base_name) + else: + return u'.'.join(split_name[:-1]) + + return six.text_type(module.__name__) + + +def total_seconds(offset): + """Backport of offset.total_seconds() from python 2.7+.""" + seconds = offset.days * 24 * 60 * 60 + offset.seconds + microseconds = seconds * 10**6 + offset.microseconds + return microseconds / (10**6 * 1.0) + + +class TimeZoneOffset(datetime.tzinfo): + """Time zone information as encoded/decoded for DateTimeFields.""" + + def __init__(self, offset): + """Initialize a time zone offset. + + Args: + offset: Integer or timedelta time zone offset, in minutes from UTC. This + can be negative. + """ + super(TimeZoneOffset, self).__init__() + if isinstance(offset, datetime.timedelta): + offset = total_seconds(offset) / 60 + self.__offset = offset + + def utcoffset(self, dt): + """Get the a timedelta with the time zone's offset from UTC. + + Returns: + The time zone offset from UTC, as a timedelta. + """ + return datetime.timedelta(minutes=self.__offset) + + def dst(self, dt): + """Get the daylight savings time offset. + + The formats that ProtoRPC uses to encode/decode time zone information don't + contain any information about daylight savings time. So this always + returns a timedelta of 0. + + Returns: + A timedelta of 0. + """ + return datetime.timedelta(0) + + +def decode_datetime(encoded_datetime): + """Decode a DateTimeField parameter from a string to a python datetime. + + Args: + encoded_datetime: A string in RFC 3339 format. + + Returns: + A datetime object with the date and time specified in encoded_datetime. + + Raises: + ValueError: If the string is not in a recognized format. + """ + # Check if the string includes a time zone offset. Break out the + # part that doesn't include time zone info. Convert to uppercase + # because all our comparisons should be case-insensitive. + time_zone_match = _TIME_ZONE_RE.search(encoded_datetime) + if time_zone_match: + time_string = encoded_datetime[:time_zone_match.start(1)].upper() + else: + time_string = encoded_datetime.upper() + + if '.' in time_string: + format_string = '%Y-%m-%dT%H:%M:%S.%f' + else: + format_string = '%Y-%m-%dT%H:%M:%S' + + decoded_datetime = datetime.datetime.strptime(time_string, format_string) + + if not time_zone_match: + return decoded_datetime + + # Time zone info was included in the parameter. Add a tzinfo + # object to the datetime. Datetimes can't be changed after they're + # created, so we'll need to create a new one. + if time_zone_match.group('z'): + offset_minutes = 0 + else: + sign = time_zone_match.group('sign') + hours, minutes = [int(value) for value in + time_zone_match.group('hours', 'minutes')] + offset_minutes = hours * 60 + minutes + if sign == '-': + offset_minutes *= -1 + + return datetime.datetime(decoded_datetime.year, + decoded_datetime.month, + decoded_datetime.day, + decoded_datetime.hour, + decoded_datetime.minute, + decoded_datetime.second, + decoded_datetime.microsecond, + TimeZoneOffset(offset_minutes)) diff --git a/endpoints/bundled/protorpc/wsgi/__init__.py b/endpoints/bundled/protorpc/wsgi/__init__.py new file mode 100644 index 0000000..00be5b0 --- /dev/null +++ b/endpoints/bundled/protorpc/wsgi/__init__.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python +# +# Copyright 2011 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/endpoints/bundled/protorpc/wsgi/service.py b/endpoints/bundled/protorpc/wsgi/service.py new file mode 100644 index 0000000..954658a --- /dev/null +++ b/endpoints/bundled/protorpc/wsgi/service.py @@ -0,0 +1,267 @@ +#!/usr/bin/env python +# +# Copyright 2011 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""ProtoRPC WSGI service applications. + +Use functions in this module to configure ProtoRPC services for use with +WSGI applications. For more information about WSGI, please see: + + http://wsgi.org/wsgi + http://docs.python.org/library/wsgiref.html +""" +import six + +__author__ = 'rafek@google.com (Rafe Kaplan)' + +import cgi +import six.moves.http_client +import logging +import re + +from .. import messages +from .. import registry +from .. import remote +from .. import util +from . import util as wsgi_util + +__all__ = [ + 'DEFAULT_REGISTRY_PATH', + 'service_app', +] + +_METHOD_PATTERN = r'(?:\.([^?]+))' +_REQUEST_PATH_PATTERN = r'^(%%s)%s$' % _METHOD_PATTERN + +_HTTP_BAD_REQUEST = wsgi_util.error(six.moves.http_client.BAD_REQUEST) +_HTTP_NOT_FOUND = wsgi_util.error(six.moves.http_client.NOT_FOUND) +_HTTP_UNSUPPORTED_MEDIA_TYPE = wsgi_util.error(six.moves.http_client.UNSUPPORTED_MEDIA_TYPE) + +DEFAULT_REGISTRY_PATH = '/protorpc' + + +@util.positional(2) +def service_mapping(service_factory, service_path=r'.*', protocols=None): + """WSGI application that handles a single ProtoRPC service mapping. + + Args: + service_factory: Service factory for creating instances of service request + handlers. Either callable that takes no parameters and returns a service + instance or a service class whose constructor requires no parameters. + service_path: Regular expression for matching requests against. Requests + that do not have matching paths will cause a 404 (Not Found) response. + protocols: remote.Protocols instance that configures supported protocols + on server. + """ + service_class = getattr(service_factory, 'service_class', service_factory) + remote_methods = service_class.all_remote_methods() + path_matcher = re.compile(_REQUEST_PATH_PATTERN % service_path) + + def protorpc_service_app(environ, start_response): + """Actual WSGI application function.""" + path_match = path_matcher.match(environ['PATH_INFO']) + if not path_match: + return _HTTP_NOT_FOUND(environ, start_response) + service_path = path_match.group(1) + method_name = path_match.group(2) + + content_type = environ.get('CONTENT_TYPE') + if not content_type: + content_type = environ.get('HTTP_CONTENT_TYPE') + if not content_type: + return _HTTP_BAD_REQUEST(environ, start_response) + + # TODO(rafek): Handle alternate encodings. + content_type = cgi.parse_header(content_type)[0] + + request_method = environ['REQUEST_METHOD'] + if request_method != 'POST': + content = ('%s.%s is a ProtoRPC method.\n\n' + 'Service %s\n\n' + 'More about ProtoRPC: ' + '%s\n' % + (service_path, + method_name, + service_class.definition_name().encode('utf-8'), + util.PROTORPC_PROJECT_URL)) + error_handler = wsgi_util.error( + six.moves.http_client.METHOD_NOT_ALLOWED, + six.moves.http_client.responses[six.moves.http_client.METHOD_NOT_ALLOWED], + content=content, + content_type='text/plain; charset=utf-8') + return error_handler(environ, start_response) + + local_protocols = protocols or remote.Protocols.get_default() + try: + protocol = local_protocols.lookup_by_content_type(content_type) + except KeyError: + return _HTTP_UNSUPPORTED_MEDIA_TYPE(environ,start_response) + + def send_rpc_error(status_code, state, message, error_name=None): + """Helper function to send an RpcStatus message as response. + + Will create static error handler and begin response. + + Args: + status_code: HTTP integer status code. + state: remote.RpcState enum value to send as response. + message: Helpful message to send in response. + error_name: Error name if applicable. + + Returns: + List containing encoded content response using the same content-type as + the request. + """ + status = remote.RpcStatus(state=state, + error_message=message, + error_name=error_name) + encoded_status = protocol.encode_message(status) + error_handler = wsgi_util.error( + status_code, + content_type=protocol.default_content_type, + content=encoded_status) + return error_handler(environ, start_response) + + method = remote_methods.get(method_name) + if not method: + return send_rpc_error(six.moves.http_client.BAD_REQUEST, + remote.RpcState.METHOD_NOT_FOUND_ERROR, + 'Unrecognized RPC method: %s' % method_name) + + content_length = int(environ.get('CONTENT_LENGTH') or '0') + + remote_info = method.remote + try: + request = protocol.decode_message( + remote_info.request_type, environ['wsgi.input'].read(content_length)) + except (messages.ValidationError, messages.DecodeError) as err: + return send_rpc_error(six.moves.http_client.BAD_REQUEST, + remote.RpcState.REQUEST_ERROR, + 'Error parsing ProtoRPC request ' + '(Unable to parse request content: %s)' % err) + + instance = service_factory() + + initialize_request_state = getattr( + instance, 'initialize_request_state', None) + if initialize_request_state: + # TODO(rafek): This is not currently covered by tests. + server_port = environ.get('SERVER_PORT', None) + if server_port: + server_port = int(server_port) + + headers = [] + for name, value in six.iteritems(environ): + if name.startswith('HTTP_'): + headers.append((name[len('HTTP_'):].lower().replace('_', '-'), value)) + request_state = remote.HttpRequestState( + remote_host=environ.get('REMOTE_HOST', None), + remote_address=environ.get('REMOTE_ADDR', None), + server_host=environ.get('SERVER_HOST', None), + server_port=server_port, + http_method=request_method, + service_path=service_path, + headers=headers) + + initialize_request_state(request_state) + + try: + response = method(instance, request) + encoded_response = protocol.encode_message(response) + except remote.ApplicationError as err: + return send_rpc_error(six.moves.http_client.BAD_REQUEST, + remote.RpcState.APPLICATION_ERROR, + unicode(err), + err.error_name) + except Exception as err: + logging.exception('Encountered unexpected error from ProtoRPC ' + 'method implementation: %s (%s)' % + (err.__class__.__name__, err)) + return send_rpc_error(six.moves.http_client.INTERNAL_SERVER_ERROR, + remote.RpcState.SERVER_ERROR, + 'Internal Server Error') + + response_headers = [('content-type', content_type)] + start_response('%d %s' % (six.moves.http_client.OK, six.moves.http_client.responses[six.moves.http_client.OK],), + response_headers) + return [encoded_response] + + # Return WSGI application. + return protorpc_service_app + + +@util.positional(1) +def service_mappings(services, registry_path=DEFAULT_REGISTRY_PATH): + """Create multiple service mappings with optional RegistryService. + + Use this function to create single WSGI application that maps to + multiple ProtoRPC services plus an optional RegistryService. + + Example: + services = service.service_mappings( + [(r'/time', TimeService), + (r'/weather', WeatherService) + ]) + + In this example, the services WSGI application will map to two services, + TimeService and WeatherService to the '/time' and '/weather' paths + respectively. In addition, it will also add a ProtoRPC RegistryService + configured to serve information about both services at the (default) path + '/protorpc'. + + Args: + services: If a dictionary is provided instead of a list of tuples, the + dictionary item pairs are used as the mappings instead. + Otherwise, a list of tuples (service_path, service_factory): + service_path: The path to mount service on. + service_factory: A service class or service instance factory. + registry_path: A string to change where the registry is mapped (the default + location is '/protorpc'). When None, no registry is created or mounted. + + Returns: + WSGI application that serves ProtoRPC services on their respective URLs + plus optional RegistryService. + """ + if isinstance(services, dict): + services = six.iteritems(services) + + final_mapping = [] + paths = set() + registry_map = {} if registry_path else None + + for service_path, service_factory in services: + try: + service_class = service_factory.service_class + except AttributeError: + service_class = service_factory + + if service_path not in paths: + paths.add(service_path) + else: + raise remote.ServiceConfigurationError( + 'Path %r is already defined in service mapping' % + service_path.encode('utf-8')) + + if registry_map is not None: + registry_map[service_path] = service_class + + final_mapping.append(service_mapping(service_factory, service_path)) + + if registry_map is not None: + final_mapping.append(service_mapping( + registry.RegistryService.new_factory(registry_map), registry_path)) + + return wsgi_util.first_found(final_mapping) diff --git a/endpoints/bundled/protorpc/wsgi/util.py b/endpoints/bundled/protorpc/wsgi/util.py new file mode 100644 index 0000000..344a6bd --- /dev/null +++ b/endpoints/bundled/protorpc/wsgi/util.py @@ -0,0 +1,180 @@ +#!/usr/bin/env python +# +# Copyright 2011 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""WSGI utilities + +Small collection of helpful utilities for working with WSGI. +""" +import six + +__author__ = 'rafek@google.com (Rafe Kaplan)' + +import six.moves.http_client +import re + +from .. import util + +__all__ = ['static_page', + 'error', + 'first_found', +] + +_STATUS_PATTERN = re.compile('^(\d{3})\s') + + +@util.positional(1) +def static_page(content='', + status='200 OK', + content_type='text/html; charset=utf-8', + headers=None): + """Create a WSGI application that serves static content. + + A static page is one that will be the same every time it receives a request. + It will always serve the same status, content and headers. + + Args: + content: Content to serve in response to HTTP request. + status: Status to serve in response to HTTP request. If string, status + is served as is without any error checking. If integer, will look up + status message. Otherwise, parameter is tuple (status, description): + status: Integer status of response. + description: Brief text description of response. + content_type: Convenient parameter for content-type header. Will appear + before any content-type header that appears in 'headers' parameter. + headers: Dictionary of headers or iterable of tuples (name, value): + name: String name of header. + value: String value of header. + + Returns: + WSGI application that serves static content. + """ + if isinstance(status, six.integer_types): + status = '%d %s' % (status, six.moves.http_client.responses.get(status, 'Unknown Error')) + elif not isinstance(status, six.string_types): + status = '%d %s' % tuple(status) + + if isinstance(headers, dict): + headers = six.iteritems(headers) + + headers = [('content-length', str(len(content))), + ('content-type', content_type), + ] + list(headers or []) + + # Ensure all headers are str. + for index, (key, value) in enumerate(headers): + if isinstance(value, six.text_type): + value = value.encode('utf-8') + headers[index] = key, value + + if not isinstance(key, str): + raise TypeError('Header key must be str, found: %r' % (key,)) + + if not isinstance(value, str): + raise TypeError( + 'Header %r must be type str or unicode, found: %r' % (key, value)) + + def static_page_application(environ, start_response): + start_response(status, headers) + return [content] + + return static_page_application + + +@util.positional(2) +def error(status_code, status_message=None, + content_type='text/plain; charset=utf-8', + headers=None, content=None): + """Create WSGI application that statically serves an error page. + + Creates a static error page specifically for non-200 HTTP responses. + + Browsers such as Internet Explorer will display their own error pages for + error content responses smaller than 512 bytes. For this reason all responses + are right-padded up to 512 bytes. + + Error pages that are not provided will content will contain the standard HTTP + status message as their content. + + Args: + status_code: Integer status code of error. + status_message: Status message. + + Returns: + Static WSGI application that sends static error response. + """ + if status_message is None: + status_message = six.moves.http_client.responses.get(status_code, 'Unknown Error') + + if content is None: + content = status_message + + content = util.pad_string(content) + + return static_page(content, + status=(status_code, status_message), + content_type=content_type, + headers=headers) + + +def first_found(apps): + """Serve the first application that does not response with 404 Not Found. + + If no application serves content, will respond with generic 404 Not Found. + + Args: + apps: List of WSGI applications to search through. Will serve the content + of the first of these that does not return a 404 Not Found. Applications + in this list must not modify the environment or any objects in it if they + do not match. Applications that do not obey this restriction can create + unpredictable results. + + Returns: + Compound application that serves the contents of the first application that + does not response with 404 Not Found. + """ + apps = tuple(apps) + not_found = error(six.moves.http_client.NOT_FOUND) + + def first_found_app(environ, start_response): + """Compound application returned from the first_found function.""" + final_result = {} # Used in absence of Python local scoping. + + def first_found_start_response(status, response_headers): + """Replacement for start_response as passed in to first_found_app. + + Called by each application in apps instead of the real start response. + Checks the response status, and if anything other than 404, sets 'status' + and 'response_headers' in final_result. + """ + status_match = _STATUS_PATTERN.match(status) + assert status_match, ('Status must be a string beginning ' + 'with 3 digit number. Found: %s' % status) + status_code = status_match.group(0) + if int(status_code) == six.moves.http_client.NOT_FOUND: + return + + final_result['status'] = status + final_result['response_headers'] = response_headers + + for app in apps: + response = app(environ, first_found_start_response) + if final_result: + start_response(final_result['status'], final_result['response_headers']) + return response + + return not_found(environ, start_response) + return first_found_app diff --git a/endpoints/protojson.py b/endpoints/protojson.py index 83658db..9d36e03 100644 --- a/endpoints/protojson.py +++ b/endpoints/protojson.py @@ -17,7 +17,7 @@ import base64 -from protorpc import protojson +from .bundled.protorpc import protojson from . import messages diff --git a/endpoints/test/apiserving_test.py b/endpoints/test/apiserving_test.py index 56cf226..a9eb137 100644 --- a/endpoints/test/apiserving_test.py +++ b/endpoints/test/apiserving_test.py @@ -28,6 +28,7 @@ import urllib2 import mock +import pytest import test_util import webtest from endpoints import api_config @@ -38,6 +39,8 @@ from endpoints import remote from endpoints import resource_container +from protorpc import remote as nonbundled_remote + package = 'endpoints.test' @@ -362,5 +365,21 @@ def testGetApiConfigs(self): self.assertEqual(TEST_SERVICE_CUSTOM_URL_API_CONFIG, configs) +@api_config.api(name='testapi', version='v3', description='A wonderful API.') +class TestNonbundledService(nonbundled_remote.Service): + + @api_config.method(test_request, + message_types.VoidMessage, + http_method='DELETE', path='items/{id}') + # Silence lint warning about method naming conventions + # pylint: disable=g-bad-name + def delete(self, unused_request): + return message_types.VoidMessage() + + +def test_nonbundled_service_error(): + with pytest.raises(TypeError): + apiserving.api_server([TestNonbundledService]) + if __name__ == '__main__': unittest.main() diff --git a/test-requirements.txt b/test-requirements.txt index 3d07a84..6547b5d 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -4,6 +4,6 @@ pytest>=2.8.3 pytest-cov>=1.8.1 pytest-timeout>=1.0.0 webtest>=2.0.23,<3.0 -git+git://github.com/inklesspen/protorpc.git@endpoints-dependency#egg=protorpc-0.12.0a0 +protorpc>=0.12.0 protobuf>=3.0.0b3 PyYAML==3.12