diff options
| author | Tamas Berghammer <tberghammer@google.com> | 2016-06-03 17:53:47 -0700 |
|---|---|---|
| committer | Tamas Berghammer <tberghammer@google.com> | 2016-06-06 16:38:21 -0700 |
| commit | b0575e93e4c39dec69365b850088a1eb7f82c5b3 (patch) | |
| tree | af0dd50df9603fdc849ad26d3d6b50384035d51e /python | |
| parent | 9f183e211f11e0e5c0a1fa519107dbfbb1a8d4ca (diff) | |
| download | platform_external_protobuf-b0575e93e4c39dec69365b850088a1eb7f82c5b3.tar.gz platform_external_protobuf-b0575e93e4c39dec69365b850088a1eb7f82c5b3.tar.bz2 platform_external_protobuf-b0575e93e4c39dec69365b850088a1eb7f82c5b3.zip | |
Update from protobuf v2.6.1 to protobuf 3.0.0-beta-3
This change just copies the upstream code into the repository without
fixing the Android.mk or fixing the possible compile errors. All of those
will be fixed with followup CLs.
Bug: b/28974522
Change-Id: I79fb3966dbef85915965692fa6ab14dc611ed9ea
Diffstat (limited to 'python')
90 files changed, 16537 insertions, 4992 deletions
diff --git a/python/MANIFEST.in b/python/MANIFEST.in new file mode 100644 index 000000000..260888263 --- /dev/null +++ b/python/MANIFEST.in @@ -0,0 +1,14 @@ +prune google/protobuf/internal/import_test_package +exclude google/protobuf/internal/*_pb2.py +exclude google/protobuf/internal/*_test.py +exclude google/protobuf/internal/*.proto +exclude google/protobuf/internal/test_util.py + +recursive-exclude google *_test.py +recursive-exclude google *_test.proto +recursive-exclude google unittest*_pb2.py + +global-exclude *.dll +global-exclude *.pyc +global-exclude *.pyo +global-exclude *.so diff --git a/python/README.txt b/python/README.md index da044af44..57acfd94d 100644 --- a/python/README.txt +++ b/python/README.md @@ -1,4 +1,8 @@ Protocol Buffers - Google's data interchange format +=================================================== + +[](https://travis-ci.org/google/protobuf) + Copyright 2008 Google Inc. This directory contains the Python Protocol Buffers runtime library. @@ -26,7 +30,7 @@ join the Protocol Buffers discussion list and let us know! Installation ============ -1) Make sure you have Python 2.4 or newer. If in doubt, run: +1) Make sure you have Python 2.6 or newer. If in doubt, run: $ python -V @@ -35,7 +39,7 @@ Installation If you would rather install it manually, you may do so by following the instructions on this page: - http://peak.telecommunity.com/DevCenter/EasyInstall#installation-instructions + https://packaging.python.org/en/latest/installing.html#setup-for-installing-packages 3) Build the C++ code, or install a binary distribution of protoc. If you install a binary distribution, make sure that it is the same @@ -46,9 +50,38 @@ Installation 4) Build and run the tests: $ python setup.py build - $ python setup.py google_test + $ python setup.py test + + To build, test, and use the C++ implementation, you must first compile + libprotobuf.so: + + $ (cd .. 
&& make) + + On OS X: + + If you are running a homebrew-provided python, you must make sure another + version of protobuf is not already installed, as homebrew's python will + search /usr/local/lib for libprotobuf.so before it searches ../src/.libs + You can either unlink homebrew's protobuf or install the libprotobuf you + built earlier: + + $ brew unlink protobuf + or + $ (cd .. && make install) + + On other *nix: - If you want to test c++ implementation, run: + You must make libprotobuf.so dynamically available. You can either + install libprotobuf you built earlier, or set LD_LIBRARY_PATH: + + $ export LD_LIBRARY_PATH=../src/.libs + or + $ (cd .. && make install) + + To build the C++ implementation run: + $ python setup.py build --cpp_implementation + + Then run the tests like so: $ python setup.py test --cpp_implementation If some tests fail, this library may not work correctly on your @@ -64,14 +97,17 @@ Installation 5) Install: - $ python setup.py install - or: - $ python setup.py install --cpp_implementation + $ python setup.py install + + or: + + $ (cd .. && make install) + $ python setup.py install --cpp_implementation This step may require superuser privileges. - NOTE: To use C++ implementation, you need to install C++ protobuf runtime - library of the same version and export the environment variable before this - step. See the "C++ Implementation" section below for more details. + NOTE: To use C++ implementation, you need to export an environment + variable before running your program. See the "C++ Implementation" + section below for more details. Usage ===== @@ -87,19 +123,5 @@ C++ Implementation The C++ implementation for Python messages is built as a Python extension to improve the overall protobuf Python performance. -To use the C++ implementation, you need to: -1) Install the C++ protobuf runtime library, please see instructions in the - parent directory. 
-2) Export an environment variable: - - $ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=cpp - $ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION=2 - -You need to export this variable before running setup.py script to build and -install the extension. You must also set the variable at runtime, otherwise -the pure-Python implementation will be used. In a future release, we will -change the default so that C++ implementation is used whenever it is available. -It is strongly recommended to run `python setup.py test` after setting the -variable to "cpp", so the tests will be against C++ implemented Python -messages. - +To use the C++ implementation, you need to install the C++ protobuf runtime +library, please see instructions in the parent directory. diff --git a/python/ez_setup.py b/python/ez_setup.py deleted file mode 100755 index 3aec98e48..000000000 --- a/python/ez_setup.py +++ /dev/null @@ -1,284 +0,0 @@ -#!python - -# This file was obtained from: -# http://peak.telecommunity.com/dist/ez_setup.py -# on 2011/1/21. - -"""Bootstrap setuptools installation - -If you want to use setuptools in your package's setup.py, just include this -file in the same directory with it, and add this to the top of your setup.py:: - - from ez_setup import use_setuptools - use_setuptools() - -If you want to require a specific version of setuptools, set a download -mirror, or use an alternate download directory, you can do so by supplying -the appropriate options to ``use_setuptools()``. - -This file can also be run as a script to install or upgrade setuptools. 
-""" -import sys -DEFAULT_VERSION = "0.6c11" -DEFAULT_URL = "http://pypi.python.org/packages/%s/s/setuptools/" % sys.version[:3] - -md5_data = { - 'setuptools-0.6b1-py2.3.egg': '8822caf901250d848b996b7f25c6e6ca', - 'setuptools-0.6b1-py2.4.egg': 'b79a8a403e4502fbb85ee3f1941735cb', - 'setuptools-0.6b2-py2.3.egg': '5657759d8a6d8fc44070a9d07272d99b', - 'setuptools-0.6b2-py2.4.egg': '4996a8d169d2be661fa32a6e52e4f82a', - 'setuptools-0.6b3-py2.3.egg': 'bb31c0fc7399a63579975cad9f5a0618', - 'setuptools-0.6b3-py2.4.egg': '38a8c6b3d6ecd22247f179f7da669fac', - 'setuptools-0.6b4-py2.3.egg': '62045a24ed4e1ebc77fe039aa4e6f7e5', - 'setuptools-0.6b4-py2.4.egg': '4cb2a185d228dacffb2d17f103b3b1c4', - 'setuptools-0.6c1-py2.3.egg': 'b3f2b5539d65cb7f74ad79127f1a908c', - 'setuptools-0.6c1-py2.4.egg': 'b45adeda0667d2d2ffe14009364f2a4b', - 'setuptools-0.6c10-py2.3.egg': 'ce1e2ab5d3a0256456d9fc13800a7090', - 'setuptools-0.6c10-py2.4.egg': '57d6d9d6e9b80772c59a53a8433a5dd4', - 'setuptools-0.6c10-py2.5.egg': 'de46ac8b1c97c895572e5e8596aeb8c7', - 'setuptools-0.6c10-py2.6.egg': '58ea40aef06da02ce641495523a0b7f5', - 'setuptools-0.6c11-py2.3.egg': '2baeac6e13d414a9d28e7ba5b5a596de', - 'setuptools-0.6c11-py2.4.egg': 'bd639f9b0eac4c42497034dec2ec0c2b', - 'setuptools-0.6c11-py2.5.egg': '64c94f3bf7a72a13ec83e0b24f2749b2', - 'setuptools-0.6c11-py2.6.egg': 'bfa92100bd772d5a213eedd356d64086', - 'setuptools-0.6c2-py2.3.egg': 'f0064bf6aa2b7d0f3ba0b43f20817c27', - 'setuptools-0.6c2-py2.4.egg': '616192eec35f47e8ea16cd6a122b7277', - 'setuptools-0.6c3-py2.3.egg': 'f181fa125dfe85a259c9cd6f1d7b78fa', - 'setuptools-0.6c3-py2.4.egg': 'e0ed74682c998bfb73bf803a50e7b71e', - 'setuptools-0.6c3-py2.5.egg': 'abef16fdd61955514841c7c6bd98965e', - 'setuptools-0.6c4-py2.3.egg': 'b0b9131acab32022bfac7f44c5d7971f', - 'setuptools-0.6c4-py2.4.egg': '2a1f9656d4fbf3c97bf946c0a124e6e2', - 'setuptools-0.6c4-py2.5.egg': '8f5a052e32cdb9c72bcf4b5526f28afc', - 'setuptools-0.6c5-py2.3.egg': 'ee9fd80965da04f2f3e6b3576e9d8167', - 
'setuptools-0.6c5-py2.4.egg': 'afe2adf1c01701ee841761f5bcd8aa64', - 'setuptools-0.6c5-py2.5.egg': 'a8d3f61494ccaa8714dfed37bccd3d5d', - 'setuptools-0.6c6-py2.3.egg': '35686b78116a668847237b69d549ec20', - 'setuptools-0.6c6-py2.4.egg': '3c56af57be3225019260a644430065ab', - 'setuptools-0.6c6-py2.5.egg': 'b2f8a7520709a5b34f80946de5f02f53', - 'setuptools-0.6c7-py2.3.egg': '209fdf9adc3a615e5115b725658e13e2', - 'setuptools-0.6c7-py2.4.egg': '5a8f954807d46a0fb67cf1f26c55a82e', - 'setuptools-0.6c7-py2.5.egg': '45d2ad28f9750e7434111fde831e8372', - 'setuptools-0.6c8-py2.3.egg': '50759d29b349db8cfd807ba8303f1902', - 'setuptools-0.6c8-py2.4.egg': 'cba38d74f7d483c06e9daa6070cce6de', - 'setuptools-0.6c8-py2.5.egg': '1721747ee329dc150590a58b3e1ac95b', - 'setuptools-0.6c9-py2.3.egg': 'a83c4020414807b496e4cfbe08507c03', - 'setuptools-0.6c9-py2.4.egg': '260a2be2e5388d66bdaee06abec6342a', - 'setuptools-0.6c9-py2.5.egg': 'fe67c3e5a17b12c0e7c541b7ea43a8e6', - 'setuptools-0.6c9-py2.6.egg': 'ca37b1ff16fa2ede6e19383e7b59245a', -} - -import sys, os -try: from hashlib import md5 -except ImportError: from md5 import md5 - -def _validate_md5(egg_name, data): - if egg_name in md5_data: - digest = md5(data).hexdigest() - if digest != md5_data[egg_name]: - print >>sys.stderr, ( - "md5 validation of %s failed! (Possible download problem?)" - % egg_name - ) - sys.exit(2) - return data - -def use_setuptools( - version=DEFAULT_VERSION, download_base=DEFAULT_URL, to_dir=os.curdir, - download_delay=15 -): - """Automatically find/download setuptools and make it available on sys.path - - `version` should be a valid setuptools version number that is available - as an egg for download under the `download_base` URL (which should end with - a '/'). `to_dir` is the directory where setuptools will be downloaded, if - it is not already available. If `download_delay` is specified, it should - be the number of seconds that will be paused before initiating a download, - should one be required. 
If an older version of setuptools is installed, - this routine will print a message to ``sys.stderr`` and raise SystemExit in - an attempt to abort the calling script. - """ - was_imported = 'pkg_resources' in sys.modules or 'setuptools' in sys.modules - def do_download(): - egg = download_setuptools(version, download_base, to_dir, download_delay) - sys.path.insert(0, egg) - import setuptools; setuptools.bootstrap_install_from = egg - try: - import pkg_resources - except ImportError: - return do_download() - try: - return do_download() - pkg_resources.require("setuptools>="+version); return - except pkg_resources.VersionConflict, e: - if was_imported: - print >>sys.stderr, ( - "The required version of setuptools (>=%s) is not available, and\n" - "can't be installed while this script is running. Please install\n" - " a more recent version first, using 'easy_install -U setuptools'." - "\n\n(Currently using %r)" - ) % (version, e.args[0]) - sys.exit(2) - except pkg_resources.DistributionNotFound: - pass - - del pkg_resources, sys.modules['pkg_resources'] # reload ok - return do_download() - -def download_setuptools( - version=DEFAULT_VERSION, download_base=DEFAULT_URL, to_dir=os.curdir, - delay = 15 -): - """Download setuptools from a specified location and return its filename - - `version` should be a valid setuptools version number that is available - as an egg for download under the `download_base` URL (which should end - with a '/'). `to_dir` is the directory where the egg will be downloaded. - `delay` is the number of seconds to pause before an actual download attempt. 
- """ - import urllib2, shutil - egg_name = "setuptools-%s-py%s.egg" % (version,sys.version[:3]) - url = download_base + egg_name - saveto = os.path.join(to_dir, egg_name) - src = dst = None - if not os.path.exists(saveto): # Avoid repeated downloads - try: - from distutils import log - if delay: - log.warn(""" ---------------------------------------------------------------------------- -This script requires setuptools version %s to run (even to display -help). I will attempt to download it for you (from -%s), but -you may need to enable firewall access for this script first. -I will start the download in %d seconds. - -(Note: if this machine does not have network access, please obtain the file - - %s - -and place it in this directory before rerunning this script.) ----------------------------------------------------------------------------""", - version, download_base, delay, url - ); from time import sleep; sleep(delay) - log.warn("Downloading %s", url) - src = urllib2.urlopen(url) - # Read/write all in one block, so we don't create a corrupt file - # if the download is interrupted. - data = _validate_md5(egg_name, src.read()) - dst = open(saveto,"wb"); dst.write(data) - finally: - if src: src.close() - if dst: dst.close() - return os.path.realpath(saveto) - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -def main(argv, version=DEFAULT_VERSION): - """Install or upgrade setuptools and EasyInstall""" - try: - import setuptools - except ImportError: - egg = None - try: - egg = download_setuptools(version, delay=0) - sys.path.insert(0,egg) - from setuptools.command.easy_install import main - return main(list(argv)+[egg]) # we're done here - finally: - if egg and os.path.exists(egg): - os.unlink(egg) - else: - if setuptools.__version__ == '0.0.1': - print >>sys.stderr, ( - "You have an obsolete version of setuptools installed. Please\n" - "remove it from your system entirely before rerunning this script." 
- ) - sys.exit(2) - - req = "setuptools>="+version - import pkg_resources - try: - pkg_resources.require(req) - except pkg_resources.VersionConflict: - try: - from setuptools.command.easy_install import main - except ImportError: - from easy_install import main - main(list(argv)+[download_setuptools(delay=0)]) - sys.exit(0) # try to force an exit - else: - if argv: - from setuptools.command.easy_install import main - main(argv) - else: - print "Setuptools version",version,"or greater has been installed." - print '(Run "ez_setup.py -U setuptools" to reinstall or upgrade.)' - -def update_md5(filenames): - """Update our built-in md5 registry""" - - import re - - for name in filenames: - base = os.path.basename(name) - f = open(name,'rb') - md5_data[base] = md5(f.read()).hexdigest() - f.close() - - data = [" %r: %r,\n" % it for it in md5_data.items()] - data.sort() - repl = "".join(data) - - import inspect - srcfile = inspect.getsourcefile(sys.modules[__name__]) - f = open(srcfile, 'rb'); src = f.read(); f.close() - - match = re.search("\nmd5_data = {\n([^}]+)}", src) - if not match: - print >>sys.stderr, "Internal error!" 
- sys.exit(2) - - src = src[:match.start(1)] + repl + src[match.end(1):] - f = open(srcfile,'w') - f.write(src) - f.close() - - -if __name__=='__main__': - if len(sys.argv)>2 and sys.argv[1]=='--md5update': - update_md5(sys.argv[2:]) - else: - main(sys.argv[1:]) diff --git a/python/google/__init__.py b/python/google/__init__.py index de40ea7ca..558561412 100755 --- a/python/google/__init__.py +++ b/python/google/__init__.py @@ -1 +1,4 @@ -__import__('pkg_resources').declare_namespace(__name__) +try: + __import__('pkg_resources').declare_namespace(__name__) +except ImportError: + __path__ = __import__('pkgutil').extend_path(__path__, __name__) diff --git a/python/google/protobuf/__init__.py b/python/google/protobuf/__init__.py index e69de29bb..2a3c6771a 100755 --- a/python/google/protobuf/__init__.py +++ b/python/google/protobuf/__init__.py @@ -0,0 +1,39 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# Copyright 2007 Google Inc. All Rights Reserved. + +__version__ = '3.0.0b3' + +if __name__ != '__main__': + try: + __import__('pkg_resources').declare_namespace(__name__) + except ImportError: + __path__ = __import__('pkgutil').extend_path(__path__, __name__) diff --git a/python/google/protobuf/descriptor.py b/python/google/protobuf/descriptor.py index 6da8bb0bc..3209b34d5 100755 --- a/python/google/protobuf/descriptor.py +++ b/python/google/protobuf/descriptor.py @@ -28,28 +28,23 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# Needs to stay compatible with Python 2.5 due to GAE. -# -# Copyright 2007 Google Inc. All Rights Reserved. - """Descriptors essentially contain exactly the information found in a .proto file, in types that make this information accessible in Python. 
""" __author__ = 'robinson@google.com (Will Robinson)' -from google.protobuf.internal import api_implementation +import six +from google.protobuf.internal import api_implementation +_USE_C_DESCRIPTORS = False if api_implementation.Type() == 'cpp': # Used by MakeDescriptor in cpp mode import os import uuid - - if api_implementation.Version() == 2: - from google.protobuf.pyext import _message - else: - from google.protobuf.internal import cpp_message + from google.protobuf.pyext import _message + _USE_C_DESCRIPTORS = getattr(_message, '_USE_C_DESCRIPTORS', False) class Error(Exception): @@ -60,12 +55,29 @@ class TypeTransformationError(Error): """Error transforming between python proto type and corresponding C++ type.""" -class DescriptorBase(object): +if _USE_C_DESCRIPTORS: + # This metaclass allows to override the behavior of code like + # isinstance(my_descriptor, FieldDescriptor) + # and make it return True when the descriptor is an instance of the extension + # type written in C++. + class DescriptorMetaclass(type): + def __instancecheck__(cls, obj): + if super(DescriptorMetaclass, cls).__instancecheck__(obj): + return True + if isinstance(obj, cls._C_DESCRIPTOR_CLASS): + return True + return False +else: + # The standard metaclass; nothing changes. + DescriptorMetaclass = type + + +class DescriptorBase(six.with_metaclass(DescriptorMetaclass)): """Descriptors base class. This class is the base of all descriptor classes. It provides common options - related functionaility. + related functionality. Attributes: has_options: True if the descriptor has non-default options. Usually it @@ -75,6 +87,11 @@ class DescriptorBase(object): avoid some bootstrapping issues. """ + if _USE_C_DESCRIPTORS: + # The class, or tuple of classes, that are considered as "virtual + # subclasses" of this descriptor class. 
+ _C_DESCRIPTOR_CLASS = () + def __init__(self, options, options_class_name): """Initialize the descriptor given its options message and the name of the class of the options message. The name of the class is required in case @@ -201,6 +218,9 @@ class Descriptor(_NestedDescriptorBase): fields_by_name: (dict str -> FieldDescriptor) Same FieldDescriptor objects as in |fields|, but indexed by "name" attribute in each FieldDescriptor. + fields_by_camelcase_name: (dict str -> FieldDescriptor) Same + FieldDescriptor objects as in |fields|, but indexed by + "camelcase_name" attribute in each FieldDescriptor. nested_types: (list of Descriptors) Descriptor references for all protocol message types nested within this one. @@ -224,9 +244,6 @@ class Descriptor(_NestedDescriptorBase): is_extendable: Does this type define any extension ranges? - options: (descriptor_pb2.MessageOptions) Protocol message options or None - to use default message options. - oneofs: (list of OneofDescriptor) The list of descriptors for oneof fields in this message. oneofs_by_name: (dict str -> OneofDescriptor) Same objects as in |oneofs|, @@ -235,13 +252,25 @@ class Descriptor(_NestedDescriptorBase): file: (FileDescriptor) Reference to file descriptor. """ + if _USE_C_DESCRIPTORS: + _C_DESCRIPTOR_CLASS = _message.Descriptor + + def __new__(cls, name, full_name, filename, containing_type, fields, + nested_types, enum_types, extensions, options=None, + is_extendable=True, extension_ranges=None, oneofs=None, + file=None, serialized_start=None, serialized_end=None, + syntax=None): + _message.Message._CheckCalledFromGeneratedFile() + return _message.default_pool.FindMessageTypeByName(full_name) + # NOTE(tmarek): The file argument redefining a builtin is nothing we can # fix right now since we don't know how many clients already rely on the # name of the argument. 
def __init__(self, name, full_name, filename, containing_type, fields, nested_types, enum_types, extensions, options=None, is_extendable=True, extension_ranges=None, oneofs=None, - file=None, serialized_start=None, serialized_end=None): # pylint:disable=redefined-builtin + file=None, serialized_start=None, serialized_end=None, + syntax=None): # pylint:disable=redefined-builtin """Arguments to __init__() are as described in the description of Descriptor fields above. @@ -263,6 +292,7 @@ class Descriptor(_NestedDescriptorBase): field.containing_type = self self.fields_by_number = dict((f.number, f) for f in fields) self.fields_by_name = dict((f.name, f) for f in fields) + self._fields_by_camelcase_name = None self.nested_types = nested_types for nested_type in nested_types: @@ -286,6 +316,14 @@ class Descriptor(_NestedDescriptorBase): self.oneofs_by_name = dict((o.name, o) for o in self.oneofs) for oneof in self.oneofs: oneof.containing_type = self + self.syntax = syntax or "proto2" + + @property + def fields_by_camelcase_name(self): + if self._fields_by_camelcase_name is None: + self._fields_by_camelcase_name = dict( + (f.camelcase_name, f) for f in self.fields) + return self._fields_by_camelcase_name def EnumValueName(self, enum, value): """Returns the string name of an enum value. @@ -335,6 +373,7 @@ class FieldDescriptor(DescriptorBase): name: (str) Name of this field, exactly as it appears in .proto. full_name: (str) Name of this field, including containing scope. This is particularly relevant for extensions. + camelcase_name: (str) Camelcase name of this field. index: (int) Dense, 0-indexed index giving the order that this field textually appears within its message in the .proto file. number: (int) Tag number declared for this field in the .proto file. 
@@ -452,6 +491,19 @@ class FieldDescriptor(DescriptorBase): FIRST_RESERVED_FIELD_NUMBER = 19000 LAST_RESERVED_FIELD_NUMBER = 19999 + if _USE_C_DESCRIPTORS: + _C_DESCRIPTOR_CLASS = _message.FieldDescriptor + + def __new__(cls, name, full_name, index, number, type, cpp_type, label, + default_value, message_type, enum_type, containing_type, + is_extension, extension_scope, options=None, + has_default_value=True, containing_oneof=None): + _message.Message._CheckCalledFromGeneratedFile() + if is_extension: + return _message.default_pool.FindExtensionByName(full_name) + else: + return _message.default_pool.FindFieldByName(full_name) + def __init__(self, name, full_name, index, number, type, cpp_type, label, default_value, message_type, enum_type, containing_type, is_extension, extension_scope, options=None, @@ -466,6 +518,7 @@ class FieldDescriptor(DescriptorBase): super(FieldDescriptor, self).__init__(options, 'FieldOptions') self.name = name self.full_name = full_name + self._camelcase_name = None self.index = index self.number = number self.type = type @@ -481,23 +534,18 @@ class FieldDescriptor(DescriptorBase): self.containing_oneof = containing_oneof if api_implementation.Type() == 'cpp': if is_extension: - if api_implementation.Version() == 2: - # pylint: disable=protected-access - self._cdescriptor = ( - _message.Message._GetExtensionDescriptor(full_name)) - # pylint: enable=protected-access - else: - self._cdescriptor = cpp_message.GetExtensionDescriptor(full_name) + self._cdescriptor = _message.default_pool.FindExtensionByName(full_name) else: - if api_implementation.Version() == 2: - # pylint: disable=protected-access - self._cdescriptor = _message.Message._GetFieldDescriptor(full_name) - # pylint: enable=protected-access - else: - self._cdescriptor = cpp_message.GetFieldDescriptor(full_name) + self._cdescriptor = _message.default_pool.FindFieldByName(full_name) else: self._cdescriptor = None + @property + def camelcase_name(self): + if self._camelcase_name is 
None: + self._camelcase_name = _ToCamelCase(self.name) + return self._camelcase_name + @staticmethod def ProtoTypeToCppProtoType(proto_type): """Converts from a Python proto type to a C++ Proto Type. @@ -544,6 +592,15 @@ class EnumDescriptor(_NestedDescriptorBase): None to use default enum options. """ + if _USE_C_DESCRIPTORS: + _C_DESCRIPTOR_CLASS = _message.EnumDescriptor + + def __new__(cls, name, full_name, filename, values, + containing_type=None, options=None, file=None, + serialized_start=None, serialized_end=None): + _message.Message._CheckCalledFromGeneratedFile() + return _message.default_pool.FindEnumTypeByName(full_name) + def __init__(self, name, full_name, filename, values, containing_type=None, options=None, file=None, serialized_start=None, serialized_end=None): @@ -588,6 +645,17 @@ class EnumValueDescriptor(DescriptorBase): None to use default enum value options options. """ + if _USE_C_DESCRIPTORS: + _C_DESCRIPTOR_CLASS = _message.EnumValueDescriptor + + def __new__(cls, name, index, number, type=None, options=None): + _message.Message._CheckCalledFromGeneratedFile() + # There is no way we can build a complete EnumValueDescriptor with the + # given parameters (the name of the Enum is not known, for example). + # Fortunately generated files just pass it to the EnumDescriptor() + # constructor, which will ignore it, so returning None is good enough. + return None + def __init__(self, name, index, number, type=None, options=None): """Arguments are as described in the attribute description above.""" super(EnumValueDescriptor, self).__init__(options, 'EnumValueOptions') @@ -611,6 +679,13 @@ class OneofDescriptor(object): oneof can contain. 
""" + if _USE_C_DESCRIPTORS: + _C_DESCRIPTOR_CLASS = _message.OneofDescriptor + + def __new__(cls, name, full_name, index, containing_type, fields): + _message.Message._CheckCalledFromGeneratedFile() + return _message.default_pool.FindOneofByName(full_name) + def __init__(self, name, full_name, index, containing_type, fields): """Arguments are as described in the attribute description above.""" self.name = name @@ -704,36 +779,58 @@ class FileDescriptor(DescriptorBase): name: name of file, relative to root of source tree. package: name of the package + syntax: string indicating syntax of the file (can be "proto2" or "proto3") serialized_pb: (str) Byte string of serialized descriptor_pb2.FileDescriptorProto. dependencies: List of other FileDescriptors this FileDescriptor depends on. + public_dependencies: A list of FileDescriptors, subset of the dependencies + above, which were declared as "public". message_types_by_name: Dict of message names of their descriptors. enum_types_by_name: Dict of enum names and their descriptors. extensions_by_name: Dict of extension names and their descriptors. + pool: the DescriptorPool this descriptor belongs to. When not passed to the + constructor, the global default pool is used. """ + if _USE_C_DESCRIPTORS: + _C_DESCRIPTOR_CLASS = _message.FileDescriptor + + def __new__(cls, name, package, options=None, serialized_pb=None, + dependencies=None, public_dependencies=None, + syntax=None, pool=None): + # FileDescriptor() is called from various places, not only from generated + # files, to register dynamic proto files and messages. + if serialized_pb: + # TODO(amauryfa): use the pool passed as argument. This will work only + # for C++-implemented DescriptorPools. 
+ return _message.default_pool.AddSerializedFile(serialized_pb) + else: + return super(FileDescriptor, cls).__new__(cls) + def __init__(self, name, package, options=None, serialized_pb=None, - dependencies=None): + dependencies=None, public_dependencies=None, + syntax=None, pool=None): """Constructor.""" super(FileDescriptor, self).__init__(options, 'FileOptions') + if pool is None: + from google.protobuf import descriptor_pool + pool = descriptor_pool.Default() + self.pool = pool self.message_types_by_name = {} self.name = name self.package = package + self.syntax = syntax or "proto2" self.serialized_pb = serialized_pb self.enum_types_by_name = {} self.extensions_by_name = {} self.dependencies = (dependencies or []) + self.public_dependencies = (public_dependencies or []) if (api_implementation.Type() == 'cpp' and self.serialized_pb is not None): - if api_implementation.Version() == 2: - # pylint: disable=protected-access - _message.Message._BuildFile(self.serialized_pb) - # pylint: enable=protected-access - else: - cpp_message.BuildFile(self.serialized_pb) + _message.default_pool.AddSerializedFile(self.serialized_pb) def CopyToProto(self, proto): """Copies this to a descriptor_pb2.FileDescriptorProto. @@ -754,7 +851,29 @@ def _ParseOptions(message, string): return message -def MakeDescriptor(desc_proto, package='', build_file_if_cpp=True): +def _ToCamelCase(name): + """Converts name to camel-case and returns it.""" + capitalize_next = False + result = [] + + for c in name: + if c == '_': + if result: + capitalize_next = True + elif capitalize_next: + result.append(c.upper()) + capitalize_next = False + else: + result += c + + # Lower-case the first letter. + if result and result[0].isupper(): + result[0] = result[0].lower() + return ''.join(result) + + +def MakeDescriptor(desc_proto, package='', build_file_if_cpp=True, + syntax=None): """Make a protobuf Descriptor given a DescriptorProto protobuf. Handles nested descriptors. 
Note that this is limited to the scope of defining @@ -766,6 +885,8 @@ def MakeDescriptor(desc_proto, package='', build_file_if_cpp=True): package: Optional package name for the new message Descriptor (string). build_file_if_cpp: Update the C++ descriptor pool if api matches. Set to False on recursion, so no duplicates are created. + syntax: The syntax/semantics that should be used. Set to "proto3" to get + proto3 field presence semantics. Returns: A Descriptor for protobuf messages. """ @@ -778,10 +899,10 @@ def MakeDescriptor(desc_proto, package='', build_file_if_cpp=True): file_descriptor_proto = descriptor_pb2.FileDescriptorProto() file_descriptor_proto.message_type.add().MergeFrom(desc_proto) - # Generate a random name for this proto file to prevent conflicts with - # any imported ones. We need to specify a file name so BuildFile accepts - # our FileDescriptorProto, but it is not important what that file name - # is actually set to. + # Generate a random name for this proto file to prevent conflicts with any + # imported ones. We need to specify a file name so the descriptor pool + # accepts our FileDescriptorProto, but it is not important what that file + # name is actually set to. 
proto_name = str(uuid.uuid4()) if package: @@ -791,12 +912,11 @@ def MakeDescriptor(desc_proto, package='', build_file_if_cpp=True): else: file_descriptor_proto.name = proto_name + '.proto' - if api_implementation.Version() == 2: - # pylint: disable=protected-access - _message.Message._BuildFile(file_descriptor_proto.SerializeToString()) - # pylint: enable=protected-access - else: - cpp_message.BuildFile(file_descriptor_proto.SerializeToString()) + _message.default_pool.Add(file_descriptor_proto) + result = _message.default_pool.FindFileByName(file_descriptor_proto.name) + + if _USE_C_DESCRIPTORS: + return result.message_types_by_name[desc_proto.name] full_message_name = [desc_proto.name] if package: full_message_name.insert(0, package) @@ -819,7 +939,8 @@ def MakeDescriptor(desc_proto, package='', build_file_if_cpp=True): # used by fields in the message, so no loops are possible here. nested_desc = MakeDescriptor(nested_proto, package='.'.join(full_message_name), - build_file_if_cpp=False) + build_file_if_cpp=False, + syntax=syntax) nested_types[full_name] = nested_desc fields = [] @@ -841,9 +962,10 @@ def MakeDescriptor(desc_proto, package='', build_file_if_cpp=True): field_proto.number, field_proto.type, FieldDescriptor.ProtoTypeToCppProtoType(field_proto.type), field_proto.label, None, nested_desc, enum_desc, None, False, None, - has_default_value=False) + options=field_proto.options, has_default_value=False) fields.append(field) desc_name = '.'.join(full_message_name) return Descriptor(desc_proto.name, desc_name, None, None, fields, - nested_types.values(), enum_types.values(), []) + list(nested_types.values()), list(enum_types.values()), [], + options=desc_proto.options) diff --git a/python/google/protobuf/descriptor_database.py b/python/google/protobuf/descriptor_database.py index 55fb8c707..1333f9966 100644 --- a/python/google/protobuf/descriptor_database.py +++ b/python/google/protobuf/descriptor_database.py @@ -65,6 +65,7 @@ class 
DescriptorDatabase(object): raise DescriptorDatabaseConflictingDefinitionError( '%s already added, but with different descriptor.' % proto_name) + # Add the top-level Message, Enum and Extension descriptors to the index. package = file_desc_proto.package for message in file_desc_proto.message_type: self._file_desc_protos_by_symbol.update( @@ -72,6 +73,9 @@ class DescriptorDatabase(object): for enum in file_desc_proto.enum_type: self._file_desc_protos_by_symbol[ '.'.join((package, enum.name))] = file_desc_proto + for extension in file_desc_proto.extension: + self._file_desc_protos_by_symbol[ + '.'.join((package, extension.name))] = file_desc_proto def FindFileByName(self, name): """Finds the file descriptor proto by file name. @@ -133,5 +137,5 @@ def _ExtractSymbols(desc_proto, package): for nested_type in desc_proto.nested_type: for symbol in _ExtractSymbols(nested_type, message_name): yield symbol - for enum_type in desc_proto.enum_type: - yield '.'.join((message_name, enum_type.name)) + for enum_type in desc_proto.enum_type: + yield '.'.join((message_name, enum_type.name)) diff --git a/python/google/protobuf/descriptor_pool.py b/python/google/protobuf/descriptor_pool.py index cf234cfae..20a337017 100644 --- a/python/google/protobuf/descriptor_pool.py +++ b/python/google/protobuf/descriptor_pool.py @@ -57,13 +57,14 @@ directly instead of this class. __author__ = 'matthewtoia@google.com (Matt Toia)' -import sys - from google.protobuf import descriptor from google.protobuf import descriptor_database from google.protobuf import text_encoding +_USE_C_DESCRIPTORS = descriptor._USE_C_DESCRIPTORS + + def _NormalizeFullyQualifiedName(name): """Remove leading period from fully-qualified type name. 
"""Initializes a Pool of protocol buffers.
+ + Returns: + The field descriptor for the named field. + """ + full_name = _NormalizeFullyQualifiedName(full_name) + message_name, _, field_name = full_name.rpartition('.') + message_descriptor = self.FindMessageTypeByName(message_name) + return message_descriptor.fields_by_name[field_name] + + def FindExtensionByName(self, full_name): + """Loads the named extension descriptor from the pool. + + Args: + full_name: The full name of the extension descriptor to load. + + Returns: + A FieldDescriptor, describing the named extension. + """ + full_name = _NormalizeFullyQualifiedName(full_name) + message_name, _, extension_name = full_name.rpartition('.') + try: + # Most extensions are nested inside a message. + scope = self.FindMessageTypeByName(message_name) + except KeyError: + # Some extensions are defined at file scope. + scope = self.FindFileContainingSymbol(full_name) + return scope.extensions_by_name[extension_name] + def _ConvertFileProtoToFileDescriptor(self, file_proto): """Creates a FileDescriptor from a proto or returns a cached copy. @@ -267,62 +319,88 @@ class DescriptorPool(object): if file_proto.name not in self._file_descriptors: built_deps = list(self._GetDeps(file_proto.dependency)) direct_deps = [self.FindFileByName(n) for n in file_proto.dependency] + public_deps = [direct_deps[i] for i in file_proto.public_dependency] file_descriptor = descriptor.FileDescriptor( + pool=self, name=file_proto.name, package=file_proto.package, + syntax=file_proto.syntax, options=file_proto.options, serialized_pb=file_proto.SerializeToString(), - dependencies=direct_deps) - scope = {} - - # This loop extracts all the message and enum types from all the - # dependencoes of the file_proto. This is necessary to create the - # scope of available message types when defining the passed in - # file proto. 
- for dependency in built_deps: - scope.update(self._ExtractSymbols( - dependency.message_types_by_name.values())) - scope.update((_PrefixWithDot(enum.full_name), enum) - for enum in dependency.enum_types_by_name.values()) - - for message_type in file_proto.message_type: - message_desc = self._ConvertMessageDescriptor( - message_type, file_proto.package, file_descriptor, scope) - file_descriptor.message_types_by_name[message_desc.name] = message_desc - - for enum_type in file_proto.enum_type: - file_descriptor.enum_types_by_name[enum_type.name] = ( - self._ConvertEnumDescriptor(enum_type, file_proto.package, - file_descriptor, None, scope)) - - for index, extension_proto in enumerate(file_proto.extension): - extension_desc = self.MakeFieldDescriptor( - extension_proto, file_proto.package, index, is_extension=True) - extension_desc.containing_type = self._GetTypeFromScope( - file_descriptor.package, extension_proto.extendee, scope) - self.SetFieldType(extension_proto, extension_desc, - file_descriptor.package, scope) - file_descriptor.extensions_by_name[extension_desc.name] = extension_desc - - for desc_proto in file_proto.message_type: - self.SetAllFieldTypes(file_proto.package, desc_proto, scope) - - if file_proto.package: - desc_proto_prefix = _PrefixWithDot(file_proto.package) + dependencies=direct_deps, + public_dependencies=public_deps) + if _USE_C_DESCRIPTORS: + # When using C++ descriptors, all objects defined in the file were added + # to the C++ database when the FileDescriptor was built above. + # Just add them to this descriptor pool. 
+ def _AddMessageDescriptor(message_desc): + self._descriptors[message_desc.full_name] = message_desc + for nested in message_desc.nested_types: + _AddMessageDescriptor(nested) + for enum_type in message_desc.enum_types: + _AddEnumDescriptor(enum_type) + def _AddEnumDescriptor(enum_desc): + self._enum_descriptors[enum_desc.full_name] = enum_desc + for message_type in file_descriptor.message_types_by_name.values(): + _AddMessageDescriptor(message_type) + for enum_type in file_descriptor.enum_types_by_name.values(): + _AddEnumDescriptor(enum_type) else: - desc_proto_prefix = '' + scope = {} + + # This loop extracts all the message and enum types from all the + # dependencies of the file_proto. This is necessary to create the + # scope of available message types when defining the passed in + # file proto. + for dependency in built_deps: + scope.update(self._ExtractSymbols( + dependency.message_types_by_name.values())) + scope.update((_PrefixWithDot(enum.full_name), enum) + for enum in dependency.enum_types_by_name.values()) + + for message_type in file_proto.message_type: + message_desc = self._ConvertMessageDescriptor( + message_type, file_proto.package, file_descriptor, scope, + file_proto.syntax) + file_descriptor.message_types_by_name[message_desc.name] = ( + message_desc) + + for enum_type in file_proto.enum_type: + file_descriptor.enum_types_by_name[enum_type.name] = ( + self._ConvertEnumDescriptor(enum_type, file_proto.package, + file_descriptor, None, scope)) + + for index, extension_proto in enumerate(file_proto.extension): + extension_desc = self._MakeFieldDescriptor( + extension_proto, file_proto.package, index, is_extension=True) + extension_desc.containing_type = self._GetTypeFromScope( + file_descriptor.package, extension_proto.extendee, scope) + self._SetFieldType(extension_proto, extension_desc, + file_descriptor.package, scope) + file_descriptor.extensions_by_name[extension_desc.name] = ( + extension_desc) + + for desc_proto in 
file_proto.message_type: + self._SetAllFieldTypes(file_proto.package, desc_proto, scope) + + if file_proto.package: + desc_proto_prefix = _PrefixWithDot(file_proto.package) + else: + desc_proto_prefix = '' + + for desc_proto in file_proto.message_type: + desc = self._GetTypeFromScope( + desc_proto_prefix, desc_proto.name, scope) + file_descriptor.message_types_by_name[desc_proto.name] = desc - for desc_proto in file_proto.message_type: - desc = self._GetTypeFromScope(desc_proto_prefix, desc_proto.name, scope) - file_descriptor.message_types_by_name[desc_proto.name] = desc self.Add(file_proto) self._file_descriptors[file_proto.name] = file_descriptor return self._file_descriptors[file_proto.name] def _ConvertMessageDescriptor(self, desc_proto, package=None, file_desc=None, - scope=None): + scope=None, syntax=None): """Adds the proto to the pool in the specified package. Args: @@ -349,15 +427,17 @@ class DescriptorPool(object): scope = {} nested = [ - self._ConvertMessageDescriptor(nested, desc_name, file_desc, scope) + self._ConvertMessageDescriptor( + nested, desc_name, file_desc, scope, syntax) for nested in desc_proto.nested_type] enums = [ self._ConvertEnumDescriptor(enum, desc_name, file_desc, None, scope) for enum in desc_proto.enum_type] - fields = [self.MakeFieldDescriptor(field, desc_name, index) + fields = [self._MakeFieldDescriptor(field, desc_name, index) for index, field in enumerate(desc_proto.field)] extensions = [ - self.MakeFieldDescriptor(extension, desc_name, index, is_extension=True) + self._MakeFieldDescriptor(extension, desc_name, index, + is_extension=True) for index, extension in enumerate(desc_proto.extension)] oneofs = [ descriptor.OneofDescriptor(desc.name, '.'.join((desc_name, desc.name)), @@ -383,7 +463,8 @@ class DescriptorPool(object): extension_ranges=extension_ranges, file=file_desc, serialized_start=None, - serialized_end=None) + serialized_end=None, + syntax=syntax) for nested in desc.nested_types: nested.containing_type = desc for 
enum in desc.enum_types: @@ -436,8 +517,8 @@ class DescriptorPool(object): self._enum_descriptors[enum_name] = desc return desc - def MakeFieldDescriptor(self, field_proto, message_name, index, - is_extension=False): + def _MakeFieldDescriptor(self, field_proto, message_name, index, + is_extension=False): """Creates a field descriptor from a FieldDescriptorProto. For message and enum type fields, this method will do a look up @@ -478,7 +559,7 @@ class DescriptorPool(object): extension_scope=None, options=field_proto.options) - def SetAllFieldTypes(self, package, desc_proto, scope): + def _SetAllFieldTypes(self, package, desc_proto, scope): """Sets all the descriptor's fields's types. This method also sets the containing types on any extensions. @@ -499,18 +580,18 @@ class DescriptorPool(object): nested_package = '.'.join([package, desc_proto.name]) for field_proto, field_desc in zip(desc_proto.field, main_desc.fields): - self.SetFieldType(field_proto, field_desc, nested_package, scope) + self._SetFieldType(field_proto, field_desc, nested_package, scope) for extension_proto, extension_desc in ( zip(desc_proto.extension, main_desc.extensions)): extension_desc.containing_type = self._GetTypeFromScope( nested_package, extension_proto.extendee, scope) - self.SetFieldType(extension_proto, extension_desc, nested_package, scope) + self._SetFieldType(extension_proto, extension_desc, nested_package, scope) for nested_type in desc_proto.nested_type: - self.SetAllFieldTypes(nested_package, nested_type, scope) + self._SetAllFieldTypes(nested_package, nested_type, scope) - def SetFieldType(self, field_proto, field_desc, package, scope): + def _SetFieldType(self, field_proto, field_desc, package, scope): """Sets the field's type, cpp_type, message_type and enum_type. 
Args: @@ -554,15 +635,29 @@ class DescriptorPool(object): field_desc.default_value = field_proto.default_value.lower() == 'true' elif field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM: field_desc.default_value = field_desc.enum_type.values_by_name[ - field_proto.default_value].index + field_proto.default_value].number elif field_proto.type == descriptor.FieldDescriptor.TYPE_BYTES: field_desc.default_value = text_encoding.CUnescape( field_proto.default_value) else: + # All other types are of the "int" type. field_desc.default_value = int(field_proto.default_value) else: field_desc.has_default_value = False - field_desc.default_value = None + if (field_proto.type == descriptor.FieldDescriptor.TYPE_DOUBLE or + field_proto.type == descriptor.FieldDescriptor.TYPE_FLOAT): + field_desc.default_value = 0.0 + elif field_proto.type == descriptor.FieldDescriptor.TYPE_STRING: + field_desc.default_value = u'' + elif field_proto.type == descriptor.FieldDescriptor.TYPE_BOOL: + field_desc.default_value = False + elif field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM: + field_desc.default_value = field_desc.enum_type.values[0].number + elif field_proto.type == descriptor.FieldDescriptor.TYPE_BYTES: + field_desc.default_value = b'' + else: + # All other types are of the "int" type. + field_desc.default_value = 0 field_desc.type = field_proto.type @@ -641,3 +736,16 @@ class DescriptorPool(object): def _PrefixWithDot(name): return name if name.startswith('.') else '.%s' % name + + +if _USE_C_DESCRIPTORS: + # TODO(amauryfa): This pool could be constructed from Python code, when we + # support a flag like 'use_cpp_generated_pool=True'. 
+ # pylint: disable=protected-access + _DEFAULT = descriptor._message.default_pool +else: + _DEFAULT = DescriptorPool() + + +def Default(): + return _DEFAULT diff --git a/python/google/protobuf/internal/_parameterized.py b/python/google/protobuf/internal/_parameterized.py new file mode 100755 index 000000000..dea3f1997 --- /dev/null +++ b/python/google/protobuf/internal/_parameterized.py @@ -0,0 +1,443 @@ +#! /usr/bin/env python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. 
Parameters for individual test cases can be tuples (with positional parameters)
def testSubtract(self, arg1, arg2, result):
(c.op1, c.op2, c.result) for c in testcases
+ + The naming_type is used to determine the name of the concrete + functions as reported by the unittest framework. If naming_type is + _FIRST_ARG, the testcases must be tuples, and the first element must + have a string representation that is a valid Python identifier. + + Args: + test_method: The decorated test method. + testcases: (list of tuple/dict) A list of parameter + tuples/dicts for individual test invocations. + naming_type: The test naming type, either _NAMED or _ARGUMENT_REPR. + """ + self._test_method = test_method + self.testcases = testcases + self._naming_type = naming_type + + def __call__(self, *args, **kwargs): + raise RuntimeError('You appear to be running a parameterized test case ' + 'without having inherited from parameterized.' + 'ParameterizedTestCase. This is bad because none of ' + 'your test cases are actually being run.') + + def __iter__(self): + test_method = self._test_method + naming_type = self._naming_type + + def MakeBoundParamTest(testcase_params): + @functools.wraps(test_method) + def BoundParamTest(self): + if isinstance(testcase_params, collections.Mapping): + test_method(self, **testcase_params) + elif _NonStringIterable(testcase_params): + test_method(self, *testcase_params) + else: + test_method(self, testcase_params) + + if naming_type is _FIRST_ARG: + # Signal the metaclass that the name of the test function is unique + # and descriptive. + BoundParamTest.__x_use_name__ = True + BoundParamTest.__name__ += str(testcase_params[0]) + testcase_params = testcase_params[1:] + elif naming_type is _ARGUMENT_REPR: + # __x_extra_id__ is used to pass naming information to the __new__ + # method of TestGeneratorMetaclass. + # The metaclass will make sure to create a unique, but nondescriptive + # name for this test. + BoundParamTest.__x_extra_id__ = '(%s)' % ( + _FormatParameterList(testcase_params),) + else: + raise RuntimeError('%s is not a valid naming type.' 
% (naming_type,)) + + BoundParamTest.__doc__ = '%s(%s)' % ( + BoundParamTest.__name__, _FormatParameterList(testcase_params)) + if test_method.__doc__: + BoundParamTest.__doc__ += '\n%s' % (test_method.__doc__,) + return BoundParamTest + return (MakeBoundParamTest(c) for c in self.testcases) + + +def _IsSingletonList(testcases): + """True iff testcases contains only a single non-tuple element.""" + return len(testcases) == 1 and not isinstance(testcases[0], tuple) + + +def _ModifyClass(class_object, testcases, naming_type): + assert not getattr(class_object, '_id_suffix', None), ( + 'Cannot add parameters to %s,' + ' which already has parameterized methods.' % (class_object,)) + class_object._id_suffix = id_suffix = {} + # We change the size of __dict__ while we iterate over it, + # which Python 3.x will complain about, so use copy(). + for name, obj in class_object.__dict__.copy().items(): + if (name.startswith(unittest.TestLoader.testMethodPrefix) + and isinstance(obj, types.FunctionType)): + delattr(class_object, name) + methods = {} + _UpdateClassDictForParamTestCase( + methods, id_suffix, name, + _ParameterizedTestIter(obj, testcases, naming_type)) + for name, meth in methods.items(): + setattr(class_object, name, meth) + + +def _ParameterDecorator(naming_type, testcases): + """Implementation of the parameterization decorators. + + Args: + naming_type: The naming type. + testcases: Testcase parameters. + + Returns: + A function for modifying the decorated object. 
+ """ + def _Apply(obj): + if isinstance(obj, type): + _ModifyClass( + obj, + list(testcases) if not isinstance(testcases, collections.Sequence) + else testcases, + naming_type) + return obj + else: + return _ParameterizedTestIter(obj, testcases, naming_type) + + if _IsSingletonList(testcases): + assert _NonStringIterable(testcases[0]), ( + 'Single parameter argument must be a non-string iterable') + testcases = testcases[0] + + return _Apply + + +def Parameters(*testcases): + """A decorator for creating parameterized tests. + + See the module docstring for a usage example. + Args: + *testcases: Parameters for the decorated method, either a single + iterable, or a list of tuples/dicts/objects (for tests + with only one argument). + + Returns: + A test generator to be handled by TestGeneratorMetaclass. + """ + return _ParameterDecorator(_ARGUMENT_REPR, testcases) + + +def NamedParameters(*testcases): + """A decorator for creating parameterized tests. + + See the module docstring for a usage example. The first element of + each parameter tuple should be a string and will be appended to the + name of the test method. + + Args: + *testcases: Parameters for the decorated method, either a single + iterable, or a list of tuples. + + Returns: + A test generator to be handled by TestGeneratorMetaclass. + """ + return _ParameterDecorator(_FIRST_ARG, testcases) + + +class TestGeneratorMetaclass(type): + """Metaclass for test cases with test generators. + + A test generator is an iterable in a testcase that produces callables. These + callables must be single-argument methods. These methods are injected into + the class namespace and the original iterable is removed. If the name of the + iterable conforms to the test pattern, the injected methods will be picked + up as tests by the unittest framework. + + In general, it is supposed to be used in conjuction with the + Parameters decorator. 
+ """ + + def __new__(mcs, class_name, bases, dct): + dct['_id_suffix'] = id_suffix = {} + for name, obj in dct.items(): + if (name.startswith(unittest.TestLoader.testMethodPrefix) and + _NonStringIterable(obj)): + iterator = iter(obj) + dct.pop(name) + _UpdateClassDictForParamTestCase(dct, id_suffix, name, iterator) + + return type.__new__(mcs, class_name, bases, dct) + + +def _UpdateClassDictForParamTestCase(dct, id_suffix, name, iterator): + """Adds individual test cases to a dictionary. + + Args: + dct: The target dictionary. + id_suffix: The dictionary for mapping names to test IDs. + name: The original name of the test case. + iterator: The iterator generating the individual test cases. + """ + for idx, func in enumerate(iterator): + assert callable(func), 'Test generators must yield callables, got %r' % ( + func,) + if getattr(func, '__x_use_name__', False): + new_name = func.__name__ + else: + new_name = '%s%s%d' % (name, _SEPARATOR, idx) + assert new_name not in dct, ( + 'Name of parameterized test case "%s" not unique' % (new_name,)) + dct[new_name] = func + id_suffix[new_name] = getattr(func, '__x_extra_id__', '') + + +class ParameterizedTestCase(unittest.TestCase): + """Base class for test cases using the Parameters decorator.""" + __metaclass__ = TestGeneratorMetaclass + + def _OriginalName(self): + return self._testMethodName.split(_SEPARATOR)[0] + + def __str__(self): + return '%s (%s)' % (self._OriginalName(), _StrClass(self.__class__)) + + def id(self): # pylint: disable=invalid-name + """Returns the descriptive ID of the test. + + This is used internally by the unittesting framework to get a name + for the test to be used in reports. + + Returns: + The test id. + """ + return '%s.%s%s' % (_StrClass(self.__class__), + self._OriginalName(), + self._id_suffix.get(self._testMethodName, '')) + + +def CoopParameterizedTestCase(other_base_class): + """Returns a new base class with a cooperative metaclass base. 
+ + This enables the ParameterizedTestCase to be used in combination + with other base classes that have custom metaclasses, such as + mox.MoxTestBase. + + Only works with metaclasses that do not override type.__new__. + + Example: + + import google3 + import mox + + from google3.testing.pybase import parameterized + + class ExampleTest(parameterized.CoopParameterizedTestCase(mox.MoxTestBase)): + ... + + Args: + other_base_class: (class) A test case base class. + + Returns: + A new class object. + """ + metaclass = type( + 'CoopMetaclass', + (other_base_class.__metaclass__, + TestGeneratorMetaclass), {}) + return metaclass( + 'CoopParameterizedTestCase', + (other_base_class, ParameterizedTestCase), {}) diff --git a/python/google/protobuf/internal/any_test.proto b/python/google/protobuf/internal/any_test.proto new file mode 100644 index 000000000..cd641ca0b --- /dev/null +++ b/python/google/protobuf/internal/any_test.proto @@ -0,0 +1,42 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. 
+// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: jieluo@google.com (Jie Luo) + +syntax = "proto3"; + +package google.protobuf.internal; + +import "google/protobuf/any.proto"; + +message TestAny { + google.protobuf.Any value = 1; + int32 int_value = 2; +} diff --git a/python/google/protobuf/internal/api_implementation.cc b/python/google/protobuf/internal/api_implementation.cc index 83db40b11..6db12e8dc 100644 --- a/python/google/protobuf/internal/api_implementation.cc +++ b/python/google/protobuf/internal/api_implementation.cc @@ -50,10 +50,7 @@ namespace python { // and // PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION=2 #ifdef PYTHON_PROTO2_CPP_IMPL_V1 -#if PY_MAJOR_VERSION >= 3 -#error "PYTHON_PROTO2_CPP_IMPL_V1 is not supported under Python 3." -#endif -static int kImplVersion = 1; +#error "PYTHON_PROTO2_CPP_IMPL_V1 is no longer supported." #else #ifdef PYTHON_PROTO2_CPP_IMPL_V2 static int kImplVersion = 2; @@ -62,14 +59,7 @@ static int kImplVersion = 2; static int kImplVersion = 0; #else -// The defaults are set here. Python 3 uses the fast C++ APIv2 by default. -// Python 2 still uses the Python version by default until some compatibility -// issues can be worked around. 
-#if PY_MAJOR_VERSION >= 3 -static int kImplVersion = 2; -#else -static int kImplVersion = 0; -#endif +static int kImplVersion = -1; // -1 means "Unspecified by compiler flags". #endif // PYTHON_PROTO2_PYTHON_IMPL #endif // PYTHON_PROTO2_CPP_IMPL_V2 diff --git a/python/google/protobuf/internal/api_implementation.py b/python/google/protobuf/internal/api_implementation.py index f7926c16a..460a4a6c2 100755 --- a/python/google/protobuf/internal/api_implementation.py +++ b/python/google/protobuf/internal/api_implementation.py @@ -32,6 +32,7 @@ """ import os +import warnings import sys try: @@ -40,14 +41,33 @@ try: # The compile-time constants in the _api_implementation module can be used to # switch to a certain implementation of the Python API at build time. _api_version = _api_implementation.api_version - del _api_implementation + _proto_extension_modules_exist_in_build = True except ImportError: - _api_version = 0 + _api_version = -1 # Unspecified by compiler flags. + _proto_extension_modules_exist_in_build = False + +if _api_version == 1: + raise ValueError('api_version=1 is no longer supported.') +if _api_version < 0: # Still unspecified? + try: + # The presence of this module in a build allows the proto implementation to + # be upgraded merely via build deps rather than a compiler flag or the + # runtime environment variable. + # pylint: disable=g-import-not-at-top + from google.protobuf import _use_fast_cpp_protos + # Work around a known issue in the classic bootstrap .par import hook. + if not _use_fast_cpp_protos: + raise ImportError('_use_fast_cpp_protos import succeeded but was None') + del _use_fast_cpp_protos + _api_version = 2 + except ImportError: + if _proto_extension_modules_exist_in_build: + if sys.version_info[0] >= 3: # Python 3 defaults to C++ impl v2. + _api_version = 2 + # TODO(b/17427486): Make Python 2 default to C++ impl v2. 
_default_implementation_type = ( - 'python' if _api_version == 0 else 'cpp') -_default_version_str = ( - '1' if _api_version <= 1 else '2') + 'python' if _api_version <= 0 else 'cpp') # This environment variable can be used to switch to a certain implementation # of the Python API, overriding the compile-time constants in the @@ -59,18 +79,22 @@ _implementation_type = os.getenv('PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION', if _implementation_type != 'python': _implementation_type = 'cpp' +if 'PyPy' in sys.version and _implementation_type == 'cpp': + warnings.warn('PyPy does not work yet with cpp protocol buffers. ' + 'Falling back to the python implementation.') + _implementation_type = 'python' + # This environment variable can be used to switch between the two # 'cpp' implementations, overriding the compile-time constants in the -# _api_implementation module. Right now only 1 and 2 are valid values. Any other -# value will be ignored. +# _api_implementation module. Right now only '2' is supported. Any other +# value will cause an error to be raised. _implementation_version_str = os.getenv( - 'PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION', - _default_version_str) + 'PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION', '2') -if _implementation_version_str not in ('1', '2'): +if _implementation_version_str != '2': raise ValueError( - "unsupported PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION: '" + - _implementation_version_str + "' (supported versions: 1, 2)" + 'unsupported PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION: "' + + _implementation_version_str + '" (supported versions: 2)' ) _implementation_version = int(_implementation_version_str) diff --git a/python/google/protobuf/internal/api_implementation_default_test.py b/python/google/protobuf/internal/api_implementation_default_test.py deleted file mode 100644 index 78d5cf232..000000000 --- a/python/google/protobuf/internal/api_implementation_default_test.py +++ /dev/null @@ -1,63 +0,0 @@ -#! 
/usr/bin/python -# -# Protocol Buffers - Google's data interchange format -# Copyright 2008 Google Inc. All rights reserved. -# https://developers.google.com/protocol-buffers/ -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are -# met: -# -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above -# copyright notice, this list of conditions and the following disclaimer -# in the documentation and/or other materials provided with the -# distribution. -# * Neither the name of Google Inc. nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -"""Test that the api_implementation defaults are what we expect.""" - -import os -import sys -# Clear environment implementation settings before the google3 imports. 
-os.environ.pop('PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION', None) -os.environ.pop('PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION', None) - -# pylint: disable=g-import-not-at-top -from google.apputils import basetest -from google.protobuf.internal import api_implementation - - -class ApiImplementationDefaultTest(basetest.TestCase): - - if sys.version_info.major <= 2: - - def testThatPythonIsTheDefault(self): - """If -DPYTHON_PROTO_*IMPL* was given at build time, this may fail.""" - self.assertEqual('python', api_implementation.Type()) - - else: - - def testThatCppApiV2IsTheDefault(self): - """If -DPYTHON_PROTO_*IMPL* was given at build time, this may fail.""" - self.assertEqual('cpp', api_implementation.Type()) - self.assertEqual(2, api_implementation.Version()) - - -if __name__ == '__main__': - basetest.main() diff --git a/python/google/protobuf/internal/containers.py b/python/google/protobuf/internal/containers.py index 20bfa8570..97cdd848e 100755 --- a/python/google/protobuf/internal/containers.py +++ b/python/google/protobuf/internal/containers.py @@ -41,6 +41,146 @@ are: __author__ = 'petar@google.com (Petar Petrov)' +import collections +import sys + +if sys.version_info[0] < 3: + # We would use collections.MutableMapping all the time, but in Python 2 it + # doesn't define __slots__. This causes two significant problems: + # + # 1. we can't disallow arbitrary attribute assignment, even if our derived + # classes *do* define __slots__. + # + # 2. we can't safely derive a C type from it without __slots__ defined (the + # interpreter expects to find a dict at tp_dictoffset, which we can't + # robustly provide. And we don't want an instance dict anyway. + # + # So this is the Python 2.7 definition of Mapping/MutableMapping functions + # verbatim, except that: + # 1. We declare __slots__. + # 2. We don't declare this as a virtual base class. The classes defined + # in collections are the interesting base classes, not us. 
+ # + # Note: deriving from object is critical. It is the only thing that makes + # this a true type, allowing us to derive from it in C++ cleanly and making + # __slots__ properly disallow arbitrary element assignment. + + class Mapping(object): + __slots__ = () + + def get(self, key, default=None): + try: + return self[key] + except KeyError: + return default + + def __contains__(self, key): + try: + self[key] + except KeyError: + return False + else: + return True + + def iterkeys(self): + return iter(self) + + def itervalues(self): + for key in self: + yield self[key] + + def iteritems(self): + for key in self: + yield (key, self[key]) + + def keys(self): + return list(self) + + def items(self): + return [(key, self[key]) for key in self] + + def values(self): + return [self[key] for key in self] + + # Mappings are not hashable by default, but subclasses can change this + __hash__ = None + + def __eq__(self, other): + if not isinstance(other, collections.Mapping): + return NotImplemented + return dict(self.items()) == dict(other.items()) + + def __ne__(self, other): + return not (self == other) + + class MutableMapping(Mapping): + __slots__ = () + + __marker = object() + + def pop(self, key, default=__marker): + try: + value = self[key] + except KeyError: + if default is self.__marker: + raise + return default + else: + del self[key] + return value + + def popitem(self): + try: + key = next(iter(self)) + except StopIteration: + raise KeyError + value = self[key] + del self[key] + return key, value + + def clear(self): + try: + while True: + self.popitem() + except KeyError: + pass + + def update(*args, **kwds): + if len(args) > 2: + raise TypeError("update() takes at most 2 positional " + "arguments ({} given)".format(len(args))) + elif not args: + raise TypeError("update() takes at least 1 argument (0 given)") + self = args[0] + other = args[1] if len(args) >= 2 else () + + if isinstance(other, Mapping): + for key in other: + self[key] = other[key] + elif 
hasattr(other, "keys"): + for key in other.keys(): + self[key] = other[key] + else: + for key, value in other: + self[key] = value + for key, value in kwds.items(): + self[key] = value + + def setdefault(self, key, default=None): + try: + return self[key] + except KeyError: + self[key] = default + return default + + collections.Mapping.register(Mapping) + collections.MutableMapping.register(MutableMapping) + +else: + # In Python 3 we can just use MutableMapping directly, because it defines + # __slots__. + MutableMapping = collections.MutableMapping + class BaseContainer(object): @@ -119,15 +259,23 @@ class RepeatedScalarFieldContainer(BaseContainer): self._message_listener.Modified() def extend(self, elem_seq): - """Extends by appending the given sequence. Similar to list.extend().""" - if not elem_seq: - return + """Extends by appending the given iterable. Similar to list.extend().""" - new_values = [] - for elem in elem_seq: - new_values.append(self._type_checker.CheckValue(elem)) - self._values.extend(new_values) - self._message_listener.Modified() + if elem_seq is None: + return + try: + elem_seq_iter = iter(elem_seq) + except TypeError: + if not elem_seq: + # silently ignore falsy inputs :-/. + # TODO(ptucker): Deprecate this behavior. b/18413862 + return + raise + + new_values = [self._type_checker.CheckValue(elem) for elem in elem_seq_iter] + if new_values: + self._values.extend(new_values) + self._message_listener.Modified() def MergeFrom(self, other): """Appends the contents of another repeated field of the same type to this @@ -141,6 +289,12 @@ class RepeatedScalarFieldContainer(BaseContainer): self._values.remove(elem) self._message_listener.Modified() + def pop(self, key=-1): + """Removes and returns an item at a given index. 
Similar to list.pop().""" + value = self._values[key] + self.__delitem__(key) + return value + def __setitem__(self, key, value): """Sets the item on the specified position.""" if isinstance(key, slice): # PY3 @@ -183,6 +337,8 @@ class RepeatedScalarFieldContainer(BaseContainer): # We are presumably comparing against some other sequence type. return other == self._values +collections.MutableSequence.register(BaseContainer) + class RepeatedCompositeFieldContainer(BaseContainer): @@ -245,6 +401,12 @@ class RepeatedCompositeFieldContainer(BaseContainer): self._values.remove(elem) self._message_listener.Modified() + def pop(self, key=-1): + """Removes and returns an item at a given index. Similar to list.pop().""" + value = self._values[key] + self.__delitem__(key) + return value + def __getslice__(self, start, stop): """Retrieves the subset of items from between the specified indices.""" return self._values[start:stop] @@ -267,3 +429,183 @@ class RepeatedCompositeFieldContainer(BaseContainer): raise TypeError('Can only compare repeated composite fields against ' 'other repeated composite fields.') return self._values == other._values + + +class ScalarMap(MutableMapping): + + """Simple, type-checked, dict-like container for holding repeated scalars.""" + + # Disallows assignment to other attributes. + __slots__ = ['_key_checker', '_value_checker', '_values', '_message_listener'] + + def __init__(self, message_listener, key_checker, value_checker): + """ + Args: + message_listener: A MessageListener implementation. + The ScalarMap will call this object's Modified() method when it + is modified. + key_checker: A type_checkers.ValueChecker instance to run on keys + inserted into this container. + value_checker: A type_checkers.ValueChecker instance to run on values + inserted into this container. 
+ """ + self._message_listener = message_listener + self._key_checker = key_checker + self._value_checker = value_checker + self._values = {} + + def __getitem__(self, key): + try: + return self._values[key] + except KeyError: + key = self._key_checker.CheckValue(key) + val = self._value_checker.DefaultValue() + self._values[key] = val + return val + + def __contains__(self, item): + # We check the key's type to match the strong-typing flavor of the API. + # Also this makes it easier to match the behavior of the C++ implementation. + self._key_checker.CheckValue(item) + return item in self._values + + # We need to override this explicitly, because our defaultdict-like behavior + # will make the default implementation (from our base class) always insert + # the key. + def get(self, key, default=None): + if key in self: + return self[key] + else: + return default + + def __setitem__(self, key, value): + checked_key = self._key_checker.CheckValue(key) + checked_value = self._value_checker.CheckValue(value) + self._values[checked_key] = checked_value + self._message_listener.Modified() + + def __delitem__(self, key): + del self._values[key] + self._message_listener.Modified() + + def __len__(self): + return len(self._values) + + def __iter__(self): + return iter(self._values) + + def __repr__(self): + return repr(self._values) + + def MergeFrom(self, other): + self._values.update(other._values) + self._message_listener.Modified() + + def InvalidateIterators(self): + # It appears that the only way to reliably invalidate iterators to + # self._values is to ensure that its size changes. + original = self._values + self._values = original.copy() + original[None] = None + + # This is defined in the abstract base, but we can do it much more cheaply. 
+ def clear(self): + self._values.clear() + self._message_listener.Modified() + + +class MessageMap(MutableMapping): + + """Simple, type-checked, dict-like container for with submessage values.""" + + # Disallows assignment to other attributes. + __slots__ = ['_key_checker', '_values', '_message_listener', + '_message_descriptor'] + + def __init__(self, message_listener, message_descriptor, key_checker): + """ + Args: + message_listener: A MessageListener implementation. + The ScalarMap will call this object's Modified() method when it + is modified. + key_checker: A type_checkers.ValueChecker instance to run on keys + inserted into this container. + value_checker: A type_checkers.ValueChecker instance to run on values + inserted into this container. + """ + self._message_listener = message_listener + self._message_descriptor = message_descriptor + self._key_checker = key_checker + self._values = {} + + def __getitem__(self, key): + try: + return self._values[key] + except KeyError: + key = self._key_checker.CheckValue(key) + new_element = self._message_descriptor._concrete_class() + new_element._SetListener(self._message_listener) + self._values[key] = new_element + self._message_listener.Modified() + + return new_element + + def get_or_create(self, key): + """get_or_create() is an alias for getitem (ie. map[key]). + + Args: + key: The key to get or create in the map. + + This is useful in cases where you want to be explicit that the call is + mutating the map. This can avoid lint errors for statements like this + that otherwise would appear to be pointless statements: + + msg.my_map[key] + """ + return self[key] + + # We need to override this explicitly, because our defaultdict-like behavior + # will make the default implementation (from our base class) always insert + # the key. 
+ def get(self, key, default=None): + if key in self: + return self[key] + else: + return default + + def __contains__(self, item): + return item in self._values + + def __setitem__(self, key, value): + raise ValueError('May not set values directly, call my_map[key].foo = 5') + + def __delitem__(self, key): + del self._values[key] + self._message_listener.Modified() + + def __len__(self): + return len(self._values) + + def __iter__(self): + return iter(self._values) + + def __repr__(self): + return repr(self._values) + + def MergeFrom(self, other): + for key in other: + self[key].MergeFrom(other[key]) + # self._message_listener.Modified() not required here, because + # mutations to submessages already propagate. + + def InvalidateIterators(self): + # It appears that the only way to reliably invalidate iterators to + # self._values is to ensure that its size changes. + original = self._values + self._values = original.copy() + original[None] = None + + # This is defined in the abstract base, but we can do it much more cheaply. + def clear(self): + self._values.clear() + self._message_listener.Modified() diff --git a/python/google/protobuf/internal/cpp_message.py b/python/google/protobuf/internal/cpp_message.py deleted file mode 100755 index 0313cb0bc..000000000 --- a/python/google/protobuf/internal/cpp_message.py +++ /dev/null @@ -1,663 +0,0 @@ -# Protocol Buffers - Google's data interchange format -# Copyright 2008 Google Inc. All rights reserved. -# https://developers.google.com/protocol-buffers/ -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are -# met: -# -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. 
-# * Redistributions in binary form must reproduce the above -# copyright notice, this list of conditions and the following disclaimer -# in the documentation and/or other materials provided with the -# distribution. -# * Neither the name of Google Inc. nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -"""Contains helper functions used to create protocol message classes from -Descriptor objects at runtime backed by the protocol buffer C++ API. 
-""" - -__author__ = 'petar@google.com (Petar Petrov)' - -import copy_reg -import operator -from google.protobuf.internal import _net_proto2___python -from google.protobuf.internal import enum_type_wrapper -from google.protobuf import message - - -_LABEL_REPEATED = _net_proto2___python.LABEL_REPEATED -_LABEL_OPTIONAL = _net_proto2___python.LABEL_OPTIONAL -_CPPTYPE_MESSAGE = _net_proto2___python.CPPTYPE_MESSAGE -_TYPE_MESSAGE = _net_proto2___python.TYPE_MESSAGE - - -def GetDescriptorPool(): - """Creates a new DescriptorPool C++ object.""" - return _net_proto2___python.NewCDescriptorPool() - - -_pool = GetDescriptorPool() - - -def GetFieldDescriptor(full_field_name): - """Searches for a field descriptor given a full field name.""" - return _pool.FindFieldByName(full_field_name) - - -def BuildFile(content): - """Registers a new proto file in the underlying C++ descriptor pool.""" - _net_proto2___python.BuildFile(content) - - -def GetExtensionDescriptor(full_extension_name): - """Searches for extension descriptor given a full field name.""" - return _pool.FindExtensionByName(full_extension_name) - - -def NewCMessage(full_message_name): - """Creates a new C++ protocol message by its name.""" - return _net_proto2___python.NewCMessage(full_message_name) - - -def ScalarProperty(cdescriptor): - """Returns a scalar property for the given descriptor.""" - - def Getter(self): - return self._cmsg.GetScalar(cdescriptor) - - def Setter(self, value): - self._cmsg.SetScalar(cdescriptor, value) - - return property(Getter, Setter) - - -def CompositeProperty(cdescriptor, message_type): - """Returns a Python property the given composite field.""" - - def Getter(self): - sub_message = self._composite_fields.get(cdescriptor.name, None) - if sub_message is None: - cmessage = self._cmsg.NewSubMessage(cdescriptor) - sub_message = message_type._concrete_class(__cmessage=cmessage) - self._composite_fields[cdescriptor.name] = sub_message - return sub_message - - return property(Getter) - - 
-class RepeatedScalarContainer(object): - """Container for repeated scalar fields.""" - - __slots__ = ['_message', '_cfield_descriptor', '_cmsg'] - - def __init__(self, msg, cfield_descriptor): - self._message = msg - self._cmsg = msg._cmsg - self._cfield_descriptor = cfield_descriptor - - def append(self, value): - self._cmsg.AddRepeatedScalar( - self._cfield_descriptor, value) - - def extend(self, sequence): - for element in sequence: - self.append(element) - - def insert(self, key, value): - values = self[slice(None, None, None)] - values.insert(key, value) - self._cmsg.AssignRepeatedScalar(self._cfield_descriptor, values) - - def remove(self, value): - values = self[slice(None, None, None)] - values.remove(value) - self._cmsg.AssignRepeatedScalar(self._cfield_descriptor, values) - - def __setitem__(self, key, value): - values = self[slice(None, None, None)] - values[key] = value - self._cmsg.AssignRepeatedScalar(self._cfield_descriptor, values) - - def __getitem__(self, key): - return self._cmsg.GetRepeatedScalar(self._cfield_descriptor, key) - - def __delitem__(self, key): - self._cmsg.DeleteRepeatedField(self._cfield_descriptor, key) - - def __len__(self): - return len(self[slice(None, None, None)]) - - def __eq__(self, other): - if self is other: - return True - if not operator.isSequenceType(other): - raise TypeError( - 'Can only compare repeated scalar fields against sequences.') - # We are presumably comparing against some other sequence type. - return other == self[slice(None, None, None)] - - def __ne__(self, other): - return not self == other - - def __hash__(self): - raise TypeError('unhashable object') - - def sort(self, *args, **kwargs): - # Maintain compatibility with the previous interface. 
- if 'sort_function' in kwargs: - kwargs['cmp'] = kwargs.pop('sort_function') - self._cmsg.AssignRepeatedScalar(self._cfield_descriptor, - sorted(self, *args, **kwargs)) - - -def RepeatedScalarProperty(cdescriptor): - """Returns a Python property the given repeated scalar field.""" - - def Getter(self): - container = self._composite_fields.get(cdescriptor.name, None) - if container is None: - container = RepeatedScalarContainer(self, cdescriptor) - self._composite_fields[cdescriptor.name] = container - return container - - def Setter(self, new_value): - raise AttributeError('Assignment not allowed to repeated field ' - '"%s" in protocol message object.' % cdescriptor.name) - - doc = 'Magic attribute generated for "%s" proto field.' % cdescriptor.name - return property(Getter, Setter, doc=doc) - - -class RepeatedCompositeContainer(object): - """Container for repeated composite fields.""" - - __slots__ = ['_message', '_subclass', '_cfield_descriptor', '_cmsg'] - - def __init__(self, msg, cfield_descriptor, subclass): - self._message = msg - self._cmsg = msg._cmsg - self._subclass = subclass - self._cfield_descriptor = cfield_descriptor - - def add(self, **kwargs): - cmessage = self._cmsg.AddMessage(self._cfield_descriptor) - return self._subclass(__cmessage=cmessage, __owner=self._message, **kwargs) - - def extend(self, elem_seq): - """Extends by appending the given sequence of elements of the same type - as this one, copying each individual message. - """ - for message in elem_seq: - self.add().MergeFrom(message) - - def remove(self, value): - # TODO(protocol-devel): This is inefficient as it needs to generate a - # message pointer for each message only to do index(). Move this to a C++ - # extension function. 
- self.__delitem__(self[slice(None, None, None)].index(value)) - - def MergeFrom(self, other): - for message in other[:]: - self.add().MergeFrom(message) - - def __getitem__(self, key): - cmessages = self._cmsg.GetRepeatedMessage( - self._cfield_descriptor, key) - subclass = self._subclass - if not isinstance(cmessages, list): - return subclass(__cmessage=cmessages, __owner=self._message) - - return [subclass(__cmessage=m, __owner=self._message) for m in cmessages] - - def __delitem__(self, key): - self._cmsg.DeleteRepeatedField( - self._cfield_descriptor, key) - - def __len__(self): - return self._cmsg.FieldLength(self._cfield_descriptor) - - def __eq__(self, other): - """Compares the current instance with another one.""" - if self is other: - return True - if not isinstance(other, self.__class__): - raise TypeError('Can only compare repeated composite fields against ' - 'other repeated composite fields.') - messages = self[slice(None, None, None)] - other_messages = other[slice(None, None, None)] - return messages == other_messages - - def __hash__(self): - raise TypeError('unhashable object') - - def sort(self, cmp=None, key=None, reverse=False, **kwargs): - # Maintain compatibility with the old interface. - if cmp is None and 'sort_function' in kwargs: - cmp = kwargs.pop('sort_function') - - # The cmp function, if provided, is passed the results of the key function, - # so we only need to wrap one of them. - if key is None: - index_key = self.__getitem__ - else: - index_key = lambda i: key(self[i]) - - # Sort the list of current indexes by the underlying object. - indexes = range(len(self)) - indexes.sort(cmp=cmp, key=index_key, reverse=reverse) - - # Apply the transposition. - for dest, src in enumerate(indexes): - if dest == src: - continue - self._cmsg.SwapRepeatedFieldElements(self._cfield_descriptor, dest, src) - # Don't swap the same value twice. 
- indexes[src] = src - - -def RepeatedCompositeProperty(cdescriptor, message_type): - """Returns a Python property for the given repeated composite field.""" - - def Getter(self): - container = self._composite_fields.get(cdescriptor.name, None) - if container is None: - container = RepeatedCompositeContainer( - self, cdescriptor, message_type._concrete_class) - self._composite_fields[cdescriptor.name] = container - return container - - def Setter(self, new_value): - raise AttributeError('Assignment not allowed to repeated field ' - '"%s" in protocol message object.' % cdescriptor.name) - - doc = 'Magic attribute generated for "%s" proto field.' % cdescriptor.name - return property(Getter, Setter, doc=doc) - - -class ExtensionDict(object): - """Extension dictionary added to each protocol message.""" - - def __init__(self, msg): - self._message = msg - self._cmsg = msg._cmsg - self._values = {} - - def __setitem__(self, extension, value): - from google.protobuf import descriptor - if not isinstance(extension, descriptor.FieldDescriptor): - raise KeyError('Bad extension %r.' % (extension,)) - cdescriptor = extension._cdescriptor - if (cdescriptor.label != _LABEL_OPTIONAL or - cdescriptor.cpp_type == _CPPTYPE_MESSAGE): - raise TypeError('Extension %r is repeated and/or a composite type.' % ( - extension.full_name,)) - self._cmsg.SetScalar(cdescriptor, value) - self._values[extension] = value - - def __getitem__(self, extension): - from google.protobuf import descriptor - if not isinstance(extension, descriptor.FieldDescriptor): - raise KeyError('Bad extension %r.' 
% (extension,)) - - cdescriptor = extension._cdescriptor - if (cdescriptor.label != _LABEL_REPEATED and - cdescriptor.cpp_type != _CPPTYPE_MESSAGE): - return self._cmsg.GetScalar(cdescriptor) - - ext = self._values.get(extension, None) - if ext is not None: - return ext - - ext = self._CreateNewHandle(extension) - self._values[extension] = ext - return ext - - def ClearExtension(self, extension): - from google.protobuf import descriptor - if not isinstance(extension, descriptor.FieldDescriptor): - raise KeyError('Bad extension %r.' % (extension,)) - self._cmsg.ClearFieldByDescriptor(extension._cdescriptor) - if extension in self._values: - del self._values[extension] - - def HasExtension(self, extension): - from google.protobuf import descriptor - if not isinstance(extension, descriptor.FieldDescriptor): - raise KeyError('Bad extension %r.' % (extension,)) - return self._cmsg.HasFieldByDescriptor(extension._cdescriptor) - - def _FindExtensionByName(self, name): - """Tries to find a known extension with the specified name. - - Args: - name: Extension full name. - - Returns: - Extension field descriptor. - """ - return self._message._extensions_by_name.get(name, None) - - def _CreateNewHandle(self, extension): - cdescriptor = extension._cdescriptor - if (cdescriptor.label != _LABEL_REPEATED and - cdescriptor.cpp_type == _CPPTYPE_MESSAGE): - cmessage = self._cmsg.NewSubMessage(cdescriptor) - return extension.message_type._concrete_class(__cmessage=cmessage) - - if cdescriptor.label == _LABEL_REPEATED: - if cdescriptor.cpp_type == _CPPTYPE_MESSAGE: - return RepeatedCompositeContainer( - self._message, cdescriptor, extension.message_type._concrete_class) - else: - return RepeatedScalarContainer(self._message, cdescriptor) - # This shouldn't happen! 
- assert False - return None - - -def NewMessage(bases, message_descriptor, dictionary): - """Creates a new protocol message *class*.""" - _AddClassAttributesForNestedExtensions(message_descriptor, dictionary) - _AddEnumValues(message_descriptor, dictionary) - _AddDescriptors(message_descriptor, dictionary) - return bases - - -def InitMessage(message_descriptor, cls): - """Constructs a new message instance (called before instance's __init__).""" - cls._extensions_by_name = {} - _AddInitMethod(message_descriptor, cls) - _AddMessageMethods(message_descriptor, cls) - _AddPropertiesForExtensions(message_descriptor, cls) - copy_reg.pickle(cls, lambda obj: (cls, (), obj.__getstate__())) - - -def _AddDescriptors(message_descriptor, dictionary): - """Sets up a new protocol message class dictionary. - - Args: - message_descriptor: A Descriptor instance describing this message type. - dictionary: Class dictionary to which we'll add a '__slots__' entry. - """ - dictionary['__descriptors'] = {} - for field in message_descriptor.fields: - dictionary['__descriptors'][field.name] = GetFieldDescriptor( - field.full_name) - - dictionary['__slots__'] = list(dictionary['__descriptors'].iterkeys()) + [ - '_cmsg', '_owner', '_composite_fields', 'Extensions', '_HACK_REFCOUNTS'] - - -def _AddEnumValues(message_descriptor, dictionary): - """Sets class-level attributes for all enum fields defined in this message. - - Args: - message_descriptor: Descriptor object for this message type. - dictionary: Class dictionary that should be populated. 
- """ - for enum_type in message_descriptor.enum_types: - dictionary[enum_type.name] = enum_type_wrapper.EnumTypeWrapper(enum_type) - for enum_value in enum_type.values: - dictionary[enum_value.name] = enum_value.number - - -def _AddClassAttributesForNestedExtensions(message_descriptor, dictionary): - """Adds class attributes for the nested extensions.""" - extension_dict = message_descriptor.extensions_by_name - for extension_name, extension_field in extension_dict.iteritems(): - assert extension_name not in dictionary - dictionary[extension_name] = extension_field - - -def _AddInitMethod(message_descriptor, cls): - """Adds an __init__ method to cls.""" - - # Create and attach message field properties to the message class. - # This can be done just once per message class, since property setters and - # getters are passed the message instance. - # This makes message instantiation extremely fast, and at the same time it - # doesn't require the creation of property objects for each message instance, - # which saves a lot of memory. - for field in message_descriptor.fields: - field_cdescriptor = cls.__descriptors[field.name] - if field.label == _LABEL_REPEATED: - if field.cpp_type == _CPPTYPE_MESSAGE: - value = RepeatedCompositeProperty(field_cdescriptor, field.message_type) - else: - value = RepeatedScalarProperty(field_cdescriptor) - elif field.cpp_type == _CPPTYPE_MESSAGE: - value = CompositeProperty(field_cdescriptor, field.message_type) - else: - value = ScalarProperty(field_cdescriptor) - setattr(cls, field.name, value) - - # Attach a constant with the field number. 
- constant_name = field.name.upper() + '_FIELD_NUMBER' - setattr(cls, constant_name, field.number) - - def Init(self, **kwargs): - """Message constructor.""" - cmessage = kwargs.pop('__cmessage', None) - if cmessage: - self._cmsg = cmessage - else: - self._cmsg = NewCMessage(message_descriptor.full_name) - - # Keep a reference to the owner, as the owner keeps a reference to the - # underlying protocol buffer message. - owner = kwargs.pop('__owner', None) - if owner: - self._owner = owner - - if message_descriptor.is_extendable: - self.Extensions = ExtensionDict(self) - else: - # Reference counting in the C++ code is broken and depends on - # the Extensions reference to keep this object alive during unit - # tests (see b/4856052). Remove this once b/4945904 is fixed. - self._HACK_REFCOUNTS = self - self._composite_fields = {} - - for field_name, field_value in kwargs.iteritems(): - field_cdescriptor = self.__descriptors.get(field_name, None) - if not field_cdescriptor: - raise ValueError('Protocol message has no "%s" field.' 
% field_name) - if field_cdescriptor.label == _LABEL_REPEATED: - if field_cdescriptor.cpp_type == _CPPTYPE_MESSAGE: - field_name = getattr(self, field_name) - for val in field_value: - field_name.add().MergeFrom(val) - else: - getattr(self, field_name).extend(field_value) - elif field_cdescriptor.cpp_type == _CPPTYPE_MESSAGE: - getattr(self, field_name).MergeFrom(field_value) - else: - setattr(self, field_name, field_value) - - Init.__module__ = None - Init.__doc__ = None - cls.__init__ = Init - - -def _IsMessageSetExtension(field): - """Checks if a field is a message set extension.""" - return (field.is_extension and - field.containing_type.has_options and - field.containing_type.GetOptions().message_set_wire_format and - field.type == _TYPE_MESSAGE and - field.message_type == field.extension_scope and - field.label == _LABEL_OPTIONAL) - - -def _AddMessageMethods(message_descriptor, cls): - """Adds the methods to a protocol message class.""" - if message_descriptor.is_extendable: - - def ClearExtension(self, extension): - self.Extensions.ClearExtension(extension) - - def HasExtension(self, extension): - return self.Extensions.HasExtension(extension) - - def HasField(self, field_name): - return self._cmsg.HasField(field_name) - - def ClearField(self, field_name): - child_cmessage = None - if field_name in self._composite_fields: - child_field = self._composite_fields[field_name] - del self._composite_fields[field_name] - - child_cdescriptor = self.__descriptors[field_name] - # TODO(anuraag): Support clearing repeated message fields as well. 
- if (child_cdescriptor.label != _LABEL_REPEATED and - child_cdescriptor.cpp_type == _CPPTYPE_MESSAGE): - child_field._owner = None - child_cmessage = child_field._cmsg - - if child_cmessage is not None: - self._cmsg.ClearField(field_name, child_cmessage) - else: - self._cmsg.ClearField(field_name) - - def Clear(self): - cmessages_to_release = [] - for field_name, child_field in self._composite_fields.iteritems(): - child_cdescriptor = self.__descriptors[field_name] - # TODO(anuraag): Support clearing repeated message fields as well. - if (child_cdescriptor.label != _LABEL_REPEATED and - child_cdescriptor.cpp_type == _CPPTYPE_MESSAGE): - child_field._owner = None - cmessages_to_release.append((child_cdescriptor, child_field._cmsg)) - self._composite_fields.clear() - self._cmsg.Clear(cmessages_to_release) - - def IsInitialized(self, errors=None): - if self._cmsg.IsInitialized(): - return True - if errors is not None: - errors.extend(self.FindInitializationErrors()); - return False - - def SerializeToString(self): - if not self.IsInitialized(): - raise message.EncodeError( - 'Message %s is missing required fields: %s' % ( - self._cmsg.full_name, ','.join(self.FindInitializationErrors()))) - return self._cmsg.SerializeToString() - - def SerializePartialToString(self): - return self._cmsg.SerializePartialToString() - - def ParseFromString(self, serialized): - self.Clear() - self.MergeFromString(serialized) - - def MergeFromString(self, serialized): - byte_size = self._cmsg.MergeFromString(serialized) - if byte_size < 0: - raise message.DecodeError('Unable to merge from string.') - return byte_size - - def MergeFrom(self, msg): - if not isinstance(msg, cls): - raise TypeError( - "Parameter to MergeFrom() must be instance of same class: " - "expected %s got %s." 
% (cls.__name__, type(msg).__name__)) - self._cmsg.MergeFrom(msg._cmsg) - - def CopyFrom(self, msg): - self._cmsg.CopyFrom(msg._cmsg) - - def ByteSize(self): - return self._cmsg.ByteSize() - - def SetInParent(self): - return self._cmsg.SetInParent() - - def ListFields(self): - all_fields = [] - field_list = self._cmsg.ListFields() - fields_by_name = cls.DESCRIPTOR.fields_by_name - for is_extension, field_name in field_list: - if is_extension: - extension = cls._extensions_by_name[field_name] - all_fields.append((extension, self.Extensions[extension])) - else: - field_descriptor = fields_by_name[field_name] - all_fields.append( - (field_descriptor, getattr(self, field_name))) - all_fields.sort(key=lambda item: item[0].number) - return all_fields - - def FindInitializationErrors(self): - return self._cmsg.FindInitializationErrors() - - def __str__(self): - return str(self._cmsg) - - def __eq__(self, other): - if self is other: - return True - if not isinstance(other, self.__class__): - return False - return self.ListFields() == other.ListFields() - - def __ne__(self, other): - return not self == other - - def __hash__(self): - raise TypeError('unhashable object') - - def __unicode__(self): - # Lazy import to prevent circular import when text_format imports this file. - from google.protobuf import text_format - return text_format.MessageToString(self, as_utf8=True).decode('utf-8') - - # Attach the local methods to the message class. - for key, value in locals().copy().iteritems(): - if key not in ('key', 'value', '__builtins__', '__name__', '__doc__'): - setattr(cls, key, value) - - # Static methods: - - def RegisterExtension(extension_handle): - extension_handle.containing_type = cls.DESCRIPTOR - cls._extensions_by_name[extension_handle.full_name] = extension_handle - - if _IsMessageSetExtension(extension_handle): - # MessageSet extension. Also register under type name. 
- cls._extensions_by_name[ - extension_handle.message_type.full_name] = extension_handle - cls.RegisterExtension = staticmethod(RegisterExtension) - - def FromString(string): - msg = cls() - msg.MergeFromString(string) - return msg - cls.FromString = staticmethod(FromString) - - - -def _AddPropertiesForExtensions(message_descriptor, cls): - """Adds properties for all fields in this protocol message type.""" - extension_dict = message_descriptor.extensions_by_name - for extension_name, extension_field in extension_dict.iteritems(): - constant_name = extension_name.upper() + '_FIELD_NUMBER' - setattr(cls, constant_name, extension_field.number) diff --git a/python/google/protobuf/internal/decoder.py b/python/google/protobuf/internal/decoder.py index a4b906086..31869e457 100755 --- a/python/google/protobuf/internal/decoder.py +++ b/python/google/protobuf/internal/decoder.py @@ -28,10 +28,6 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -#PY25 compatible for GAE. -# -# Copyright 2009 Google Inc. All Rights Reserved. - """Code for decoding protocol buffer primitives. This code is very similar to encoder.py -- read the docs for that module first. @@ -85,8 +81,12 @@ we repeatedly read a tag, look up the corresponding decoder, and invoke it. __author__ = 'kenton@google.com (Kenton Varda)' import struct -import sys ##PY25 -_PY2 = sys.version_info[0] < 3 ##PY25 + +import six + +if six.PY3: + long = int + from google.protobuf.internal import encoder from google.protobuf.internal import wire_format from google.protobuf import message @@ -114,14 +114,11 @@ def _VarintDecoder(mask, result_type): decoder returns a (value, new_pos) pair. 
""" - local_ord = ord - py2 = _PY2 ##PY25 -##!PY25 py2 = str is bytes def DecodeVarint(buffer, pos): result = 0 shift = 0 while 1: - b = local_ord(buffer[pos]) if py2 else buffer[pos] + b = six.indexbytes(buffer, pos) result |= ((b & 0x7f) << shift) pos += 1 if not (b & 0x80): @@ -137,14 +134,11 @@ def _VarintDecoder(mask, result_type): def _SignedVarintDecoder(mask, result_type): """Like _VarintDecoder() but decodes signed values.""" - local_ord = ord - py2 = _PY2 ##PY25 -##!PY25 py2 = str is bytes def DecodeVarint(buffer, pos): result = 0 shift = 0 while 1: - b = local_ord(buffer[pos]) if py2 else buffer[pos] + b = six.indexbytes(buffer, pos) result |= ((b & 0x7f) << shift) pos += 1 if not (b & 0x80): @@ -183,10 +177,8 @@ def ReadTag(buffer, pos): use that, but not in Python. """ - py2 = _PY2 ##PY25 -##!PY25 py2 = str is bytes start = pos - while (ord(buffer[pos]) if py2 else buffer[pos]) & 0x80: + while six.indexbytes(buffer, pos) & 0x80: pos += 1 pos += 1 return (buffer[start:pos], pos) @@ -301,7 +293,6 @@ def _FloatDecoder(): """ local_unpack = struct.unpack - b = (lambda x:x) if _PY2 else lambda x:x.encode('latin1') ##PY25 def InnerDecode(buffer, pos): # We expect a 32-bit value in little-endian byte order. Bit 1 is the sign @@ -312,17 +303,12 @@ def _FloatDecoder(): # If this value has all its exponent bits set, then it's non-finite. # In Python 2.4, struct.unpack will convert it to a finite 64-bit value. # To avoid that, we parse it specially. - if ((float_bytes[3:4] in b('\x7F\xFF')) ##PY25 -##!PY25 if ((float_bytes[3:4] in b'\x7F\xFF') - and (float_bytes[2:3] >= b('\x80'))): ##PY25 -##!PY25 and (float_bytes[2:3] >= b'\x80')): + if (float_bytes[3:4] in b'\x7F\xFF' and float_bytes[2:3] >= b'\x80'): # If at least one significand bit is set... - if float_bytes[0:3] != b('\x00\x00\x80'): ##PY25 -##!PY25 if float_bytes[0:3] != b'\x00\x00\x80': + if float_bytes[0:3] != b'\x00\x00\x80': return (_NAN, new_pos) # If sign bit is set... 
- if float_bytes[3:4] == b('\xFF'): ##PY25 -##!PY25 if float_bytes[3:4] == b'\xFF': + if float_bytes[3:4] == b'\xFF': return (_NEG_INF, new_pos) return (_POS_INF, new_pos) @@ -341,7 +327,6 @@ def _DoubleDecoder(): """ local_unpack = struct.unpack - b = (lambda x:x) if _PY2 else lambda x:x.encode('latin1') ##PY25 def InnerDecode(buffer, pos): # We expect a 64-bit value in little-endian byte order. Bit 1 is the sign @@ -352,12 +337,9 @@ def _DoubleDecoder(): # If this value has all its exponent bits set and at least one significand # bit set, it's not a number. In Python 2.4, struct.unpack will treat it # as inf or -inf. To avoid that, we treat it specially. -##!PY25 if ((double_bytes[7:8] in b'\x7F\xFF') -##!PY25 and (double_bytes[6:7] >= b'\xF0') -##!PY25 and (double_bytes[0:7] != b'\x00\x00\x00\x00\x00\x00\xF0')): - if ((double_bytes[7:8] in b('\x7F\xFF')) ##PY25 - and (double_bytes[6:7] >= b('\xF0')) ##PY25 - and (double_bytes[0:7] != b('\x00\x00\x00\x00\x00\x00\xF0'))): ##PY25 + if ((double_bytes[7:8] in b'\x7F\xFF') + and (double_bytes[6:7] >= b'\xF0') + and (double_bytes[0:7] != b'\x00\x00\x00\x00\x00\x00\xF0')): return (_NAN, new_pos) # Note that we expect someone up-stack to catch struct.error and convert @@ -480,12 +462,12 @@ def StringDecoder(field_number, is_repeated, is_packed, key, new_default): """Returns a decoder for a string field.""" local_DecodeVarint = _DecodeVarint - local_unicode = unicode + local_unicode = six.text_type def _ConvertToUnicode(byte_str): try: return local_unicode(byte_str, 'utf-8') - except UnicodeDecodeError, e: + except UnicodeDecodeError as e: # add more information to the error message and re-raise it. 
e.reason = '%s in field: %s' % (e, key.full_name) raise @@ -621,9 +603,6 @@ def MessageDecoder(field_number, is_repeated, is_packed, key, new_default): if value is None: value = field_dict.setdefault(key, new_default(message)) while 1: - value = field_dict.get(key) - if value is None: - value = field_dict.setdefault(key, new_default(message)) # Read length. (size, pos) = local_DecodeVarint(buffer, pos) new_pos = pos + size @@ -736,6 +715,50 @@ def MessageSetItemDecoder(extensions_by_number): return DecodeItem # -------------------------------------------------------------------- + +def MapDecoder(field_descriptor, new_default, is_message_map): + """Returns a decoder for a map field.""" + + key = field_descriptor + tag_bytes = encoder.TagBytes(field_descriptor.number, + wire_format.WIRETYPE_LENGTH_DELIMITED) + tag_len = len(tag_bytes) + local_DecodeVarint = _DecodeVarint + # Can't read _concrete_class yet; might not be initialized. + message_type = field_descriptor.message_type + + def DecodeMap(buffer, pos, end, message, field_dict): + submsg = message_type._concrete_class() + value = field_dict.get(key) + if value is None: + value = field_dict.setdefault(key, new_default(message)) + while 1: + # Read length. + (size, pos) = local_DecodeVarint(buffer, pos) + new_pos = pos + size + if new_pos > end: + raise _DecodeError('Truncated message.') + # Read sub-message. + submsg.Clear() + if submsg._InternalParse(buffer, pos, new_pos) != new_pos: + # The only reason _InternalParse would return early is if it + # encountered an end-group tag. + raise _DecodeError('Unexpected end-group tag.') + + if is_message_map: + value[submsg.key].MergeFrom(submsg.value) + else: + value[submsg.key] = submsg.value + + # Predict that the next tag is another copy of the same repeated field. + pos = new_pos + tag_len + if buffer[new_pos:pos] != tag_bytes or new_pos == end: + # Prediction failed. Return. 
+ return new_pos + + return DecodeMap + +# -------------------------------------------------------------------- # Optimization is not as heavy here because calls to SkipField() are rare, # except for handling end-group tags. diff --git a/python/google/protobuf/internal/descriptor_database_test.py b/python/google/protobuf/internal/descriptor_database_test.py index fc65b69ad..5225a4582 100644 --- a/python/google/protobuf/internal/descriptor_database_test.py +++ b/python/google/protobuf/internal/descriptor_database_test.py @@ -1,4 +1,4 @@ -#! /usr/bin/python +#! /usr/bin/env python # # Protocol Buffers - Google's data interchange format # Copyright 2008 Google Inc. All rights reserved. @@ -34,13 +34,17 @@ __author__ = 'matthewtoia@google.com (Matt Toia)' -from google.apputils import basetest +try: + import unittest2 as unittest #PY26 +except ImportError: + import unittest + from google.protobuf import descriptor_pb2 from google.protobuf.internal import factory_test2_pb2 from google.protobuf import descriptor_database -class DescriptorDatabaseTest(basetest.TestCase): +class DescriptorDatabaseTest(unittest.TestCase): def testAdd(self): db = descriptor_database.DescriptorDatabase() @@ -48,16 +52,18 @@ class DescriptorDatabaseTest(basetest.TestCase): factory_test2_pb2.DESCRIPTOR.serialized_pb) db.Add(file_desc_proto) - self.assertEquals(file_desc_proto, db.FindFileByName( + self.assertEqual(file_desc_proto, db.FindFileByName( 'google/protobuf/internal/factory_test2.proto')) - self.assertEquals(file_desc_proto, db.FindFileContainingSymbol( + self.assertEqual(file_desc_proto, db.FindFileContainingSymbol( 'google.protobuf.python.internal.Factory2Message')) - self.assertEquals(file_desc_proto, db.FindFileContainingSymbol( + self.assertEqual(file_desc_proto, db.FindFileContainingSymbol( 'google.protobuf.python.internal.Factory2Message.NestedFactory2Message')) - self.assertEquals(file_desc_proto, db.FindFileContainingSymbol( + self.assertEqual(file_desc_proto, 
db.FindFileContainingSymbol( 'google.protobuf.python.internal.Factory2Enum')) - self.assertEquals(file_desc_proto, db.FindFileContainingSymbol( + self.assertEqual(file_desc_proto, db.FindFileContainingSymbol( 'google.protobuf.python.internal.Factory2Message.NestedFactory2Enum')) + self.assertEqual(file_desc_proto, db.FindFileContainingSymbol( + 'google.protobuf.python.internal.MessageWithNestedEnumOnly.NestedEnum')) if __name__ == '__main__': - basetest.main() + unittest.main() diff --git a/python/google/protobuf/internal/descriptor_pool_test.py b/python/google/protobuf/internal/descriptor_pool_test.py index d2f855798..6a13e0bcf 100644 --- a/python/google/protobuf/internal/descriptor_pool_test.py +++ b/python/google/protobuf/internal/descriptor_pool_test.py @@ -1,4 +1,4 @@ -#! /usr/bin/python +#! /usr/bin/env python # # Protocol Buffers - Google's data interchange format # Copyright 2008 Google Inc. All rights reserved. @@ -35,9 +35,15 @@ __author__ = 'matthewtoia@google.com (Matt Toia)' import os -import unittest +import sys -from google.apputils import basetest +try: + import unittest2 as unittest #PY26 +except ImportError: + import unittest + +from google.protobuf import unittest_import_pb2 +from google.protobuf import unittest_import_public_pb2 from google.protobuf import unittest_pb2 from google.protobuf import descriptor_pb2 from google.protobuf.internal import api_implementation @@ -45,12 +51,15 @@ from google.protobuf.internal import descriptor_pool_test1_pb2 from google.protobuf.internal import descriptor_pool_test2_pb2 from google.protobuf.internal import factory_test1_pb2 from google.protobuf.internal import factory_test2_pb2 +from google.protobuf.internal import more_messages_pb2 from google.protobuf import descriptor from google.protobuf import descriptor_database from google.protobuf import descriptor_pool +from google.protobuf import message_factory +from google.protobuf import symbol_database -class DescriptorPoolTest(basetest.TestCase): +class 
DescriptorPoolTest(unittest.TestCase): def setUp(self): self.pool = descriptor_pool.DescriptorPool() @@ -65,15 +74,15 @@ class DescriptorPoolTest(basetest.TestCase): name1 = 'google/protobuf/internal/factory_test1.proto' file_desc1 = self.pool.FindFileByName(name1) self.assertIsInstance(file_desc1, descriptor.FileDescriptor) - self.assertEquals(name1, file_desc1.name) - self.assertEquals('google.protobuf.python.internal', file_desc1.package) + self.assertEqual(name1, file_desc1.name) + self.assertEqual('google.protobuf.python.internal', file_desc1.package) self.assertIn('Factory1Message', file_desc1.message_types_by_name) name2 = 'google/protobuf/internal/factory_test2.proto' file_desc2 = self.pool.FindFileByName(name2) self.assertIsInstance(file_desc2, descriptor.FileDescriptor) - self.assertEquals(name2, file_desc2.name) - self.assertEquals('google.protobuf.python.internal', file_desc2.package) + self.assertEqual(name2, file_desc2.name) + self.assertEqual('google.protobuf.python.internal', file_desc2.package) self.assertIn('Factory2Message', file_desc2.message_types_by_name) def testFindFileByNameFailure(self): @@ -84,17 +93,17 @@ class DescriptorPoolTest(basetest.TestCase): file_desc1 = self.pool.FindFileContainingSymbol( 'google.protobuf.python.internal.Factory1Message') self.assertIsInstance(file_desc1, descriptor.FileDescriptor) - self.assertEquals('google/protobuf/internal/factory_test1.proto', - file_desc1.name) - self.assertEquals('google.protobuf.python.internal', file_desc1.package) + self.assertEqual('google/protobuf/internal/factory_test1.proto', + file_desc1.name) + self.assertEqual('google.protobuf.python.internal', file_desc1.package) self.assertIn('Factory1Message', file_desc1.message_types_by_name) file_desc2 = self.pool.FindFileContainingSymbol( 'google.protobuf.python.internal.Factory2Message') self.assertIsInstance(file_desc2, descriptor.FileDescriptor) - self.assertEquals('google/protobuf/internal/factory_test2.proto', - file_desc2.name) - 
self.assertEquals('google.protobuf.python.internal', file_desc2.package) + self.assertEqual('google/protobuf/internal/factory_test2.proto', + file_desc2.name) + self.assertEqual('google.protobuf.python.internal', file_desc2.package) self.assertIn('Factory2Message', file_desc2.message_types_by_name) def testFindFileContainingSymbolFailure(self): @@ -105,72 +114,72 @@ class DescriptorPoolTest(basetest.TestCase): msg1 = self.pool.FindMessageTypeByName( 'google.protobuf.python.internal.Factory1Message') self.assertIsInstance(msg1, descriptor.Descriptor) - self.assertEquals('Factory1Message', msg1.name) - self.assertEquals('google.protobuf.python.internal.Factory1Message', - msg1.full_name) - self.assertEquals(None, msg1.containing_type) + self.assertEqual('Factory1Message', msg1.name) + self.assertEqual('google.protobuf.python.internal.Factory1Message', + msg1.full_name) + self.assertEqual(None, msg1.containing_type) nested_msg1 = msg1.nested_types[0] - self.assertEquals('NestedFactory1Message', nested_msg1.name) - self.assertEquals(msg1, nested_msg1.containing_type) + self.assertEqual('NestedFactory1Message', nested_msg1.name) + self.assertEqual(msg1, nested_msg1.containing_type) nested_enum1 = msg1.enum_types[0] - self.assertEquals('NestedFactory1Enum', nested_enum1.name) - self.assertEquals(msg1, nested_enum1.containing_type) + self.assertEqual('NestedFactory1Enum', nested_enum1.name) + self.assertEqual(msg1, nested_enum1.containing_type) - self.assertEquals(nested_msg1, msg1.fields_by_name[ + self.assertEqual(nested_msg1, msg1.fields_by_name[ 'nested_factory_1_message'].message_type) - self.assertEquals(nested_enum1, msg1.fields_by_name[ + self.assertEqual(nested_enum1, msg1.fields_by_name[ 'nested_factory_1_enum'].enum_type) msg2 = self.pool.FindMessageTypeByName( 'google.protobuf.python.internal.Factory2Message') self.assertIsInstance(msg2, descriptor.Descriptor) - self.assertEquals('Factory2Message', msg2.name) - 
self.assertEquals('google.protobuf.python.internal.Factory2Message', - msg2.full_name) + self.assertEqual('Factory2Message', msg2.name) + self.assertEqual('google.protobuf.python.internal.Factory2Message', + msg2.full_name) self.assertIsNone(msg2.containing_type) nested_msg2 = msg2.nested_types[0] - self.assertEquals('NestedFactory2Message', nested_msg2.name) - self.assertEquals(msg2, nested_msg2.containing_type) + self.assertEqual('NestedFactory2Message', nested_msg2.name) + self.assertEqual(msg2, nested_msg2.containing_type) nested_enum2 = msg2.enum_types[0] - self.assertEquals('NestedFactory2Enum', nested_enum2.name) - self.assertEquals(msg2, nested_enum2.containing_type) + self.assertEqual('NestedFactory2Enum', nested_enum2.name) + self.assertEqual(msg2, nested_enum2.containing_type) - self.assertEquals(nested_msg2, msg2.fields_by_name[ + self.assertEqual(nested_msg2, msg2.fields_by_name[ 'nested_factory_2_message'].message_type) - self.assertEquals(nested_enum2, msg2.fields_by_name[ + self.assertEqual(nested_enum2, msg2.fields_by_name[ 'nested_factory_2_enum'].enum_type) self.assertTrue(msg2.fields_by_name['int_with_default'].has_default_value) - self.assertEquals( + self.assertEqual( 1776, msg2.fields_by_name['int_with_default'].default_value) self.assertTrue( msg2.fields_by_name['double_with_default'].has_default_value) - self.assertEquals( + self.assertEqual( 9.99, msg2.fields_by_name['double_with_default'].default_value) self.assertTrue( msg2.fields_by_name['string_with_default'].has_default_value) - self.assertEquals( + self.assertEqual( 'hello world', msg2.fields_by_name['string_with_default'].default_value) self.assertTrue(msg2.fields_by_name['bool_with_default'].has_default_value) self.assertFalse(msg2.fields_by_name['bool_with_default'].default_value) self.assertTrue(msg2.fields_by_name['enum_with_default'].has_default_value) - self.assertEquals( + self.assertEqual( 1, msg2.fields_by_name['enum_with_default'].default_value) msg3 = 
self.pool.FindMessageTypeByName( 'google.protobuf.python.internal.Factory2Message.NestedFactory2Message') - self.assertEquals(nested_msg2, msg3) + self.assertEqual(nested_msg2, msg3) self.assertTrue(msg2.fields_by_name['bytes_with_default'].has_default_value) - self.assertEquals( + self.assertEqual( b'a\xfb\x00c', msg2.fields_by_name['bytes_with_default'].default_value) @@ -190,35 +199,66 @@ class DescriptorPoolTest(basetest.TestCase): enum1 = self.pool.FindEnumTypeByName( 'google.protobuf.python.internal.Factory1Enum') self.assertIsInstance(enum1, descriptor.EnumDescriptor) - self.assertEquals(0, enum1.values_by_name['FACTORY_1_VALUE_0'].number) - self.assertEquals(1, enum1.values_by_name['FACTORY_1_VALUE_1'].number) + self.assertEqual(0, enum1.values_by_name['FACTORY_1_VALUE_0'].number) + self.assertEqual(1, enum1.values_by_name['FACTORY_1_VALUE_1'].number) nested_enum1 = self.pool.FindEnumTypeByName( 'google.protobuf.python.internal.Factory1Message.NestedFactory1Enum') self.assertIsInstance(nested_enum1, descriptor.EnumDescriptor) - self.assertEquals( + self.assertEqual( 0, nested_enum1.values_by_name['NESTED_FACTORY_1_VALUE_0'].number) - self.assertEquals( + self.assertEqual( 1, nested_enum1.values_by_name['NESTED_FACTORY_1_VALUE_1'].number) enum2 = self.pool.FindEnumTypeByName( 'google.protobuf.python.internal.Factory2Enum') self.assertIsInstance(enum2, descriptor.EnumDescriptor) - self.assertEquals(0, enum2.values_by_name['FACTORY_2_VALUE_0'].number) - self.assertEquals(1, enum2.values_by_name['FACTORY_2_VALUE_1'].number) + self.assertEqual(0, enum2.values_by_name['FACTORY_2_VALUE_0'].number) + self.assertEqual(1, enum2.values_by_name['FACTORY_2_VALUE_1'].number) nested_enum2 = self.pool.FindEnumTypeByName( 'google.protobuf.python.internal.Factory2Message.NestedFactory2Enum') self.assertIsInstance(nested_enum2, descriptor.EnumDescriptor) - self.assertEquals( + self.assertEqual( 0, nested_enum2.values_by_name['NESTED_FACTORY_2_VALUE_0'].number) - 
self.assertEquals( + self.assertEqual( 1, nested_enum2.values_by_name['NESTED_FACTORY_2_VALUE_1'].number) def testFindEnumTypeByNameFailure(self): with self.assertRaises(KeyError): self.pool.FindEnumTypeByName('Does not exist') + def testFindFieldByName(self): + field = self.pool.FindFieldByName( + 'google.protobuf.python.internal.Factory1Message.list_value') + self.assertEqual(field.name, 'list_value') + self.assertEqual(field.label, field.LABEL_REPEATED) + with self.assertRaises(KeyError): + self.pool.FindFieldByName('Does not exist') + + def testFindExtensionByName(self): + # An extension defined in a message. + extension = self.pool.FindExtensionByName( + 'google.protobuf.python.internal.Factory2Message.one_more_field') + self.assertEqual(extension.name, 'one_more_field') + # An extension defined at file scope. + extension = self.pool.FindExtensionByName( + 'google.protobuf.python.internal.another_field') + self.assertEqual(extension.name, 'another_field') + self.assertEqual(extension.number, 1002) + with self.assertRaises(KeyError): + self.pool.FindFieldByName('Does not exist') + + def testExtensionsAreNotFields(self): + with self.assertRaises(KeyError): + self.pool.FindFieldByName('google.protobuf.python.internal.another_field') + with self.assertRaises(KeyError): + self.pool.FindFieldByName( + 'google.protobuf.python.internal.Factory2Message.one_more_field') + with self.assertRaises(KeyError): + self.pool.FindExtensionByName( + 'google.protobuf.python.internal.Factory1Message.list_value') + def testUserDefinedDB(self): db = descriptor_database.DescriptorDatabase() self.pool = descriptor_pool.DescriptorPool(db) @@ -226,32 +266,109 @@ class DescriptorPoolTest(basetest.TestCase): db.Add(self.factory_test2_fd) self.testFindMessageTypeByName() + def testAddSerializedFile(self): + self.pool = descriptor_pool.DescriptorPool() + self.pool.AddSerializedFile(self.factory_test1_fd.SerializeToString()) + 
self.pool.AddSerializedFile(self.factory_test2_fd.SerializeToString()) + self.testFindMessageTypeByName() + def testComplexNesting(self): + more_messages_desc = descriptor_pb2.FileDescriptorProto.FromString( + more_messages_pb2.DESCRIPTOR.serialized_pb) test1_desc = descriptor_pb2.FileDescriptorProto.FromString( descriptor_pool_test1_pb2.DESCRIPTOR.serialized_pb) test2_desc = descriptor_pb2.FileDescriptorProto.FromString( descriptor_pool_test2_pb2.DESCRIPTOR.serialized_pb) + self.pool.Add(more_messages_desc) self.pool.Add(test1_desc) self.pool.Add(test2_desc) TEST1_FILE.CheckFile(self, self.pool) TEST2_FILE.CheckFile(self, self.pool) + def testEnumDefaultValue(self): + """Test the default value of enums which don't start at zero.""" + def _CheckDefaultValue(file_descriptor): + default_value = (file_descriptor + .message_types_by_name['DescriptorPoolTest1'] + .fields_by_name['nested_enum'] + .default_value) + self.assertEqual(default_value, + descriptor_pool_test1_pb2.DescriptorPoolTest1.BETA) + # First check what the generated descriptor contains. + _CheckDefaultValue(descriptor_pool_test1_pb2.DESCRIPTOR) + # Then check the generated pool. Normally this is the same descriptor. + file_descriptor = symbol_database.Default().pool.FindFileByName( + 'google/protobuf/internal/descriptor_pool_test1.proto') + self.assertIs(file_descriptor, descriptor_pool_test1_pb2.DESCRIPTOR) + _CheckDefaultValue(file_descriptor) + + # Then check the dynamic pool and its internal DescriptorDatabase. 
+ descriptor_proto = descriptor_pb2.FileDescriptorProto.FromString( + descriptor_pool_test1_pb2.DESCRIPTOR.serialized_pb) + self.pool.Add(descriptor_proto) + # And do the same check as above + file_descriptor = self.pool.FindFileByName( + 'google/protobuf/internal/descriptor_pool_test1.proto') + _CheckDefaultValue(file_descriptor) + + def testDefaultValueForCustomMessages(self): + """Check the value returned by non-existent fields.""" + def _CheckValueAndType(value, expected_value, expected_type): + self.assertEqual(value, expected_value) + self.assertIsInstance(value, expected_type) + + def _CheckDefaultValues(msg): + try: + int64 = long + except NameError: # Python3 + int64 = int + try: + unicode_type = unicode + except NameError: # Python3 + unicode_type = str + _CheckValueAndType(msg.optional_int32, 0, int) + _CheckValueAndType(msg.optional_uint64, 0, (int64, int)) + _CheckValueAndType(msg.optional_float, 0, (float, int)) + _CheckValueAndType(msg.optional_double, 0, (float, int)) + _CheckValueAndType(msg.optional_bool, False, bool) + _CheckValueAndType(msg.optional_string, u'', unicode_type) + _CheckValueAndType(msg.optional_bytes, b'', bytes) + _CheckValueAndType(msg.optional_nested_enum, msg.FOO, int) + # First for the generated message + _CheckDefaultValues(unittest_pb2.TestAllTypes()) + # Then for a message built with from the DescriptorPool. 
+ pool = descriptor_pool.DescriptorPool() + pool.Add(descriptor_pb2.FileDescriptorProto.FromString( + unittest_import_public_pb2.DESCRIPTOR.serialized_pb)) + pool.Add(descriptor_pb2.FileDescriptorProto.FromString( + unittest_import_pb2.DESCRIPTOR.serialized_pb)) + pool.Add(descriptor_pb2.FileDescriptorProto.FromString( + unittest_pb2.DESCRIPTOR.serialized_pb)) + message_class = message_factory.MessageFactory(pool).GetPrototype( + pool.FindMessageTypeByName( + unittest_pb2.TestAllTypes.DESCRIPTOR.full_name)) + _CheckDefaultValues(message_class()) + class ProtoFile(object): - def __init__(self, name, package, messages, dependencies=None): + def __init__(self, name, package, messages, dependencies=None, + public_dependencies=None): self.name = name self.package = package self.messages = messages self.dependencies = dependencies or [] + self.public_dependencies = public_dependencies or [] def CheckFile(self, test, pool): file_desc = pool.FindFileByName(self.name) - test.assertEquals(self.name, file_desc.name) - test.assertEquals(self.package, file_desc.package) + test.assertEqual(self.name, file_desc.name) + test.assertEqual(self.package, file_desc.package) dependencies_names = [f.name for f in file_desc.dependencies] test.assertEqual(self.dependencies, dependencies_names) + public_dependencies_names = [f.name for f in file_desc.public_dependencies] + test.assertEqual(self.public_dependencies, public_dependencies_names) for name, msg_type in self.messages.items(): msg_type.CheckType(test, None, name, file_desc) @@ -328,7 +445,7 @@ class EnumField(object): test.assertEqual(descriptor.FieldDescriptor.CPPTYPE_ENUM, field_desc.cpp_type) test.assertTrue(field_desc.has_default_value) - test.assertEqual(enum_desc.values_by_name[self.default_value].index, + test.assertEqual(enum_desc.values_by_name[self.default_value].number, field_desc.default_value) test.assertEqual(msg_desc, field_desc.containing_type) test.assertEqual(enum_desc, field_desc.enum_type) @@ -399,12 +516,12 @@ 
class ExtensionField(object): test.assertEqual(self.extended_type, field_desc.containing_type.name) -class AddDescriptorTest(basetest.TestCase): +class AddDescriptorTest(unittest.TestCase): def _TestMessage(self, prefix): pool = descriptor_pool.DescriptorPool() pool.AddDescriptor(unittest_pb2.TestAllTypes.DESCRIPTOR) - self.assertEquals( + self.assertEqual( 'protobuf_unittest.TestAllTypes', pool.FindMessageTypeByName( prefix + 'protobuf_unittest.TestAllTypes').full_name) @@ -415,22 +532,24 @@ class AddDescriptorTest(basetest.TestCase): prefix + 'protobuf_unittest.TestAllTypes.NestedMessage') pool.AddDescriptor(unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR) - self.assertEquals( + self.assertEqual( 'protobuf_unittest.TestAllTypes.NestedMessage', pool.FindMessageTypeByName( prefix + 'protobuf_unittest.TestAllTypes.NestedMessage').full_name) # Files are implicitly also indexed when messages are added. - self.assertEquals( + self.assertEqual( 'google/protobuf/unittest.proto', pool.FindFileByName( 'google/protobuf/unittest.proto').name) - self.assertEquals( + self.assertEqual( 'google/protobuf/unittest.proto', pool.FindFileContainingSymbol( prefix + 'protobuf_unittest.TestAllTypes.NestedMessage').name) + @unittest.skipIf(api_implementation.Type() == 'cpp', + 'With the cpp implementation, Add() must be called first') def testMessage(self): self._TestMessage('') self._TestMessage('.') @@ -438,7 +557,7 @@ class AddDescriptorTest(basetest.TestCase): def _TestEnum(self, prefix): pool = descriptor_pool.DescriptorPool() pool.AddEnumDescriptor(unittest_pb2.ForeignEnum.DESCRIPTOR) - self.assertEquals( + self.assertEqual( 'protobuf_unittest.ForeignEnum', pool.FindEnumTypeByName( prefix + 'protobuf_unittest.ForeignEnum').full_name) @@ -449,30 +568,34 @@ class AddDescriptorTest(basetest.TestCase): prefix + 'protobuf_unittest.ForeignEnum.NestedEnum') pool.AddEnumDescriptor(unittest_pb2.TestAllTypes.NestedEnum.DESCRIPTOR) - self.assertEquals( + self.assertEqual( 
'protobuf_unittest.TestAllTypes.NestedEnum', pool.FindEnumTypeByName( prefix + 'protobuf_unittest.TestAllTypes.NestedEnum').full_name) # Files are implicitly also indexed when enums are added. - self.assertEquals( + self.assertEqual( 'google/protobuf/unittest.proto', pool.FindFileByName( 'google/protobuf/unittest.proto').name) - self.assertEquals( + self.assertEqual( 'google/protobuf/unittest.proto', pool.FindFileContainingSymbol( prefix + 'protobuf_unittest.TestAllTypes.NestedEnum').name) + @unittest.skipIf(api_implementation.Type() == 'cpp', + 'With the cpp implementation, Add() must be called first') def testEnum(self): self._TestEnum('') self._TestEnum('.') + @unittest.skipIf(api_implementation.Type() == 'cpp', + 'With the cpp implementation, Add() must be called first') def testFile(self): pool = descriptor_pool.DescriptorPool() pool.AddFileDescriptor(unittest_pb2.DESCRIPTOR) - self.assertEquals( + self.assertEqual( 'google/protobuf/unittest.proto', pool.FindFileByName( 'google/protobuf/unittest.proto').name) @@ -483,6 +606,67 @@ class AddDescriptorTest(basetest.TestCase): pool.FindFileContainingSymbol( 'protobuf_unittest.TestAllTypes') + def testEmptyDescriptorPool(self): + # Check that an empty DescriptorPool() contains no messages. + pool = descriptor_pool.DescriptorPool() + proto_file_name = descriptor_pb2.DESCRIPTOR.name + self.assertRaises(KeyError, pool.FindFileByName, proto_file_name) + # Add the above file to the pool + file_descriptor = descriptor_pb2.FileDescriptorProto() + descriptor_pb2.DESCRIPTOR.CopyToProto(file_descriptor) + pool.Add(file_descriptor) + # Now it exists. + self.assertTrue(pool.FindFileByName(proto_file_name)) + + def testCustomDescriptorPool(self): + # Create a new pool, and add a file descriptor. 
+ pool = descriptor_pool.DescriptorPool() + file_desc = descriptor_pb2.FileDescriptorProto( + name='some/file.proto', package='package') + file_desc.message_type.add(name='Message') + pool.Add(file_desc) + self.assertEqual(pool.FindFileByName('some/file.proto').name, + 'some/file.proto') + self.assertEqual(pool.FindMessageTypeByName('package.Message').name, + 'Message') + + +@unittest.skipIf( + api_implementation.Type() != 'cpp', + 'default_pool is only supported by the C++ implementation') +class DefaultPoolTest(unittest.TestCase): + + def testFindMethods(self): + # pylint: disable=g-import-not-at-top + from google.protobuf.pyext import _message + pool = _message.default_pool + self.assertIs( + pool.FindFileByName('google/protobuf/unittest.proto'), + unittest_pb2.DESCRIPTOR) + self.assertIs( + pool.FindMessageTypeByName('protobuf_unittest.TestAllTypes'), + unittest_pb2.TestAllTypes.DESCRIPTOR) + self.assertIs( + pool.FindFieldByName('protobuf_unittest.TestAllTypes.optional_int32'), + unittest_pb2.TestAllTypes.DESCRIPTOR.fields_by_name['optional_int32']) + self.assertIs( + pool.FindExtensionByName('protobuf_unittest.optional_int32_extension'), + unittest_pb2.DESCRIPTOR.extensions_by_name['optional_int32_extension']) + self.assertIs( + pool.FindEnumTypeByName('protobuf_unittest.ForeignEnum'), + unittest_pb2.ForeignEnum.DESCRIPTOR) + self.assertIs( + pool.FindOneofByName('protobuf_unittest.TestAllTypes.oneof_field'), + unittest_pb2.TestAllTypes.DESCRIPTOR.oneofs_by_name['oneof_field']) + + def testAddFileDescriptor(self): + # pylint: disable=g-import-not-at-top + from google.protobuf.pyext import _message + pool = _message.default_pool + file_desc = descriptor_pb2.FileDescriptorProto(name='some/file.proto') + pool.Add(file_desc) + pool.AddSerializedFile(file_desc.SerializeToString()) + TEST1_FILE = ProtoFile( 'google/protobuf/internal/descriptor_pool_test1.proto', @@ -557,8 +741,10 @@ TEST2_FILE = ProtoFile( ExtensionField(1001, 'DescriptorPoolTest1')), ]), }, - 
dependencies=['google/protobuf/internal/descriptor_pool_test1.proto']) + dependencies=['google/protobuf/internal/descriptor_pool_test1.proto', + 'google/protobuf/internal/more_messages.proto'], + public_dependencies=['google/protobuf/internal/more_messages.proto']) if __name__ == '__main__': - basetest.main() + unittest.main() diff --git a/python/google/protobuf/internal/descriptor_pool_test1.proto b/python/google/protobuf/internal/descriptor_pool_test1.proto index 6dfe4ef32..00816b78e 100644 --- a/python/google/protobuf/internal/descriptor_pool_test1.proto +++ b/python/google/protobuf/internal/descriptor_pool_test1.proto @@ -28,6 +28,8 @@ // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +syntax = "proto2"; + package google.protobuf.python.internal; diff --git a/python/google/protobuf/internal/descriptor_pool_test2.proto b/python/google/protobuf/internal/descriptor_pool_test2.proto index fbc84382a..a218eccb9 100644 --- a/python/google/protobuf/internal/descriptor_pool_test2.proto +++ b/python/google/protobuf/internal/descriptor_pool_test2.proto @@ -28,9 +28,12 @@ // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +syntax = "proto2"; + package google.protobuf.python.internal; import "google/protobuf/internal/descriptor_pool_test1.proto"; +import public "google/protobuf/internal/more_messages.proto"; message DescriptorPoolTest3 { diff --git a/python/google/protobuf/internal/descriptor_python_test.py b/python/google/protobuf/internal/descriptor_python_test.py deleted file mode 100644 index 5471ae021..000000000 --- a/python/google/protobuf/internal/descriptor_python_test.py +++ /dev/null @@ -1,54 +0,0 @@ -#! /usr/bin/python -# -# Protocol Buffers - Google's data interchange format -# Copyright 2008 Google Inc. All rights reserved. 
-# https://developers.google.com/protocol-buffers/ -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are -# met: -# -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above -# copyright notice, this list of conditions and the following disclaimer -# in the documentation and/or other materials provided with the -# distribution. -# * Neither the name of Google Inc. nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -"""Unittest for descriptor.py for the pure Python implementation.""" - -import os -os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python' - -# We must set the implementation version above before the google3 imports. -# pylint: disable=g-import-not-at-top -from google.apputils import basetest -from google.protobuf.internal import api_implementation -# Run all tests from the original module by putting them in our namespace. 
-# pylint: disable=wildcard-import -from google.protobuf.internal.descriptor_test import * - - -class ConfirmPurePythonTest(basetest.TestCase): - - def testImplementationSetting(self): - self.assertEqual('python', api_implementation.Type()) - - -if __name__ == '__main__': - basetest.main() diff --git a/python/google/protobuf/internal/descriptor_test.py b/python/google/protobuf/internal/descriptor_test.py index b3777e396..b8e755533 100755 --- a/python/google/protobuf/internal/descriptor_test.py +++ b/python/google/protobuf/internal/descriptor_test.py @@ -1,4 +1,4 @@ -#! /usr/bin/python +#! /usr/bin/env python # # Protocol Buffers - Google's data interchange format # Copyright 2008 Google Inc. All rights reserved. @@ -34,12 +34,22 @@ __author__ = 'robinson@google.com (Will Robinson)' -from google.apputils import basetest +import sys + +try: + import unittest2 as unittest #PY26 +except ImportError: + import unittest + from google.protobuf import unittest_custom_options_pb2 from google.protobuf import unittest_import_pb2 from google.protobuf import unittest_pb2 from google.protobuf import descriptor_pb2 +from google.protobuf.internal import api_implementation +from google.protobuf.internal import test_util from google.protobuf import descriptor +from google.protobuf import descriptor_pool +from google.protobuf import symbol_database from google.protobuf import text_format @@ -48,44 +58,31 @@ name: 'TestEmptyMessage' """ -class DescriptorTest(basetest.TestCase): +class DescriptorTest(unittest.TestCase): def setUp(self): - self.my_file = descriptor.FileDescriptor( + file_proto = descriptor_pb2.FileDescriptorProto( name='some/filename/some.proto', - package='protobuf_unittest' - ) - self.my_enum = descriptor.EnumDescriptor( - name='ForeignEnum', - full_name='protobuf_unittest.ForeignEnum', - filename=None, - file=self.my_file, - values=[ - descriptor.EnumValueDescriptor(name='FOREIGN_FOO', index=0, number=4), - descriptor.EnumValueDescriptor(name='FOREIGN_BAR', index=1, 
number=5), - descriptor.EnumValueDescriptor(name='FOREIGN_BAZ', index=2, number=6), - ]) - self.my_message = descriptor.Descriptor( - name='NestedMessage', - full_name='protobuf_unittest.TestAllTypes.NestedMessage', - filename=None, - file=self.my_file, - containing_type=None, - fields=[ - descriptor.FieldDescriptor( - name='bb', - full_name='protobuf_unittest.TestAllTypes.NestedMessage.bb', - index=0, number=1, - type=5, cpp_type=1, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None), - ], - nested_types=[], - enum_types=[ - self.my_enum, - ], - extensions=[]) + package='protobuf_unittest') + message_proto = file_proto.message_type.add( + name='NestedMessage') + message_proto.field.add( + name='bb', + number=1, + type=descriptor_pb2.FieldDescriptorProto.TYPE_INT32, + label=descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL) + enum_proto = message_proto.enum_type.add( + name='ForeignEnum') + enum_proto.value.add(name='FOREIGN_FOO', number=4) + enum_proto.value.add(name='FOREIGN_BAR', number=5) + enum_proto.value.add(name='FOREIGN_BAZ', number=6) + + self.pool = self.GetDescriptorPool() + self.pool.Add(file_proto) + self.my_file = self.pool.FindFileByName(file_proto.name) + self.my_message = self.my_file.message_types_by_name[message_proto.name] + self.my_enum = self.my_message.enum_types_by_name[enum_proto.name] + self.my_method = descriptor.MethodDescriptor( name='Bar', full_name='protobuf_unittest.TestService.Bar', @@ -102,6 +99,9 @@ class DescriptorTest(basetest.TestCase): self.my_method ]) + def GetDescriptorPool(self): + return symbol_database.Default().pool + def testEnumValueName(self): self.assertEqual(self.my_message.EnumValueName('ForeignEnum', 4), 'FOREIGN_FOO') @@ -173,6 +173,11 @@ class DescriptorTest(basetest.TestCase): self.assertEqual(unittest_custom_options_pb2.METHODOPT1_VAL2, method_options.Extensions[method_opt1]) + message_descriptor = ( + 
unittest_custom_options_pb2.DummyMessageContainingEnum.DESCRIPTOR) + self.assertTrue(file_descriptor.has_options) + self.assertFalse(message_descriptor.has_options) + def testDifferentCustomOptionTypes(self): kint32min = -2**31 kint64min = -2**63 @@ -393,9 +398,130 @@ class DescriptorTest(basetest.TestCase): def testFileDescriptor(self): self.assertEqual(self.my_file.name, 'some/filename/some.proto') self.assertEqual(self.my_file.package, 'protobuf_unittest') - - -class DescriptorCopyToProtoTest(basetest.TestCase): + self.assertEqual(self.my_file.pool, self.pool) + # Generated modules also belong to the default pool. + self.assertEqual(unittest_pb2.DESCRIPTOR.pool, descriptor_pool.Default()) + + @unittest.skipIf( + api_implementation.Type() != 'cpp' or api_implementation.Version() != 2, + 'Immutability of descriptors is only enforced in v2 implementation') + def testImmutableCppDescriptor(self): + message_descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR + with self.assertRaises(AttributeError): + message_descriptor.fields_by_name = None + with self.assertRaises(TypeError): + message_descriptor.fields_by_name['Another'] = None + with self.assertRaises(TypeError): + message_descriptor.fields.append(None) + + +class NewDescriptorTest(DescriptorTest): + """Redo the same tests as above, but with a separate DescriptorPool.""" + + def GetDescriptorPool(self): + return descriptor_pool.DescriptorPool() + + +class GeneratedDescriptorTest(unittest.TestCase): + """Tests for the properties of descriptors in generated code.""" + + def CheckMessageDescriptor(self, message_descriptor): + # Basic properties + self.assertEqual(message_descriptor.name, 'TestAllTypes') + self.assertEqual(message_descriptor.full_name, + 'protobuf_unittest.TestAllTypes') + # Test equality and hashability + self.assertEqual(message_descriptor, message_descriptor) + self.assertEqual(message_descriptor.fields[0].containing_type, + message_descriptor) + self.assertIn(message_descriptor, 
[message_descriptor]) + self.assertIn(message_descriptor, {message_descriptor: None}) + # Test field containers + self.CheckDescriptorSequence(message_descriptor.fields) + self.CheckDescriptorMapping(message_descriptor.fields_by_name) + self.CheckDescriptorMapping(message_descriptor.fields_by_number) + self.CheckDescriptorMapping(message_descriptor.fields_by_camelcase_name) + + def CheckFieldDescriptor(self, field_descriptor): + # Basic properties + self.assertEqual(field_descriptor.name, 'optional_int32') + self.assertEqual(field_descriptor.camelcase_name, 'optionalInt32') + self.assertEqual(field_descriptor.full_name, + 'protobuf_unittest.TestAllTypes.optional_int32') + self.assertEqual(field_descriptor.containing_type.name, 'TestAllTypes') + # Test equality and hashability + self.assertEqual(field_descriptor, field_descriptor) + self.assertEqual( + field_descriptor.containing_type.fields_by_name['optional_int32'], + field_descriptor) + self.assertEqual( + field_descriptor.containing_type.fields_by_camelcase_name[ + 'optionalInt32'], + field_descriptor) + self.assertIn(field_descriptor, [field_descriptor]) + self.assertIn(field_descriptor, {field_descriptor: None}) + + def CheckDescriptorSequence(self, sequence): + # Verifies that a property like 'messageDescriptor.fields' has all the + # properties of an immutable abc.Sequence. + self.assertGreater(len(sequence), 0) # Sized + self.assertEqual(len(sequence), len(list(sequence))) # Iterable + item = sequence[0] + self.assertEqual(item, sequence[0]) + self.assertIn(item, sequence) # Container + self.assertEqual(sequence.index(item), 0) + self.assertEqual(sequence.count(item), 1) + reversed_iterator = reversed(sequence) + self.assertEqual(list(reversed_iterator), list(sequence)[::-1]) + self.assertRaises(StopIteration, next, reversed_iterator) + + def CheckDescriptorMapping(self, mapping): + # Verifies that a property like 'messageDescriptor.fields' has all the + # properties of an immutable abc.Mapping. 
+ self.assertGreater(len(mapping), 0) # Sized + self.assertEqual(len(mapping), len(list(mapping))) # Iterable + if sys.version_info >= (3,): + key, item = next(iter(mapping.items())) + else: + key, item = mapping.items()[0] + self.assertIn(key, mapping) # Container + self.assertEqual(mapping.get(key), item) + # keys(), iterkeys() &co + item = (next(iter(mapping.keys())), next(iter(mapping.values()))) + self.assertEqual(item, next(iter(mapping.items()))) + if sys.version_info < (3,): + def CheckItems(seq, iterator): + self.assertEqual(next(iterator), seq[0]) + self.assertEqual(list(iterator), seq[1:]) + CheckItems(mapping.keys(), mapping.iterkeys()) + CheckItems(mapping.values(), mapping.itervalues()) + CheckItems(mapping.items(), mapping.iteritems()) + + def testDescriptor(self): + message_descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR + self.CheckMessageDescriptor(message_descriptor) + field_descriptor = message_descriptor.fields_by_name['optional_int32'] + self.CheckFieldDescriptor(field_descriptor) + field_descriptor = message_descriptor.fields_by_camelcase_name[ + 'optionalInt32'] + self.CheckFieldDescriptor(field_descriptor) + + def testCppDescriptorContainer(self): + # Check that the collection is still valid even if the parent disappeared. 
+ enum = unittest_pb2.TestAllTypes.DESCRIPTOR.enum_types_by_name['NestedEnum'] + values = enum.values + del enum + self.assertEqual('FOO', values[0].name) + + def testCppDescriptorContainer_Iterator(self): + # Same test with the iterator + enum = unittest_pb2.TestAllTypes.DESCRIPTOR.enum_types_by_name['NestedEnum'] + values_iter = iter(enum.values) + del enum + self.assertEqual('FOO', next(values_iter).name) + + +class DescriptorCopyToProtoTest(unittest.TestCase): """Tests for CopyTo functions of Descriptor.""" def _AssertProtoEqual(self, actual_proto, expected_class, expected_ascii): @@ -471,7 +597,7 @@ class DescriptorCopyToProtoTest(basetest.TestCase): """ self._InternalTestCopyToProto( - unittest_pb2._FOREIGNENUM, + unittest_pb2.ForeignEnum.DESCRIPTOR, descriptor_pb2.EnumDescriptorProto, TEST_FOREIGN_ENUM_ASCII) @@ -588,13 +714,15 @@ class DescriptorCopyToProtoTest(basetest.TestCase): output_type: '.protobuf_unittest.BarResponse' > """ - self._InternalTestCopyToProto( - unittest_pb2.TestService.DESCRIPTOR, - descriptor_pb2.ServiceDescriptorProto, - TEST_SERVICE_ASCII) + # TODO(rocking): enable this test after the proto descriptor change is + # checked in. 
+ #self._InternalTestCopyToProto( + # unittest_pb2.TestService.DESCRIPTOR, + # descriptor_pb2.ServiceDescriptorProto, + # TEST_SERVICE_ASCII) -class MakeDescriptorTest(basetest.TestCase): +class MakeDescriptorTest(unittest.TestCase): def testMakeDescriptorWithNestedFields(self): file_descriptor_proto = descriptor_pb2.FileDescriptorProto() @@ -665,5 +793,30 @@ class MakeDescriptorTest(basetest.TestCase): descriptor.FieldDescriptor.CPPTYPE_UINT64) + def testMakeDescriptorWithOptions(self): + descriptor_proto = descriptor_pb2.DescriptorProto() + aggregate_message = unittest_custom_options_pb2.AggregateMessage + aggregate_message.DESCRIPTOR.CopyToProto(descriptor_proto) + reformed_descriptor = descriptor.MakeDescriptor(descriptor_proto) + + options = reformed_descriptor.GetOptions() + self.assertEqual(101, + options.Extensions[unittest_custom_options_pb2.msgopt].i) + + def testCamelcaseName(self): + descriptor_proto = descriptor_pb2.DescriptorProto() + descriptor_proto.name = 'Bar' + names = ['foo_foo', 'FooBar', 'fooBaz', 'fooFoo', 'foobar'] + camelcase_names = ['fooFoo', 'fooBar', 'fooBaz', 'fooFoo', 'foobar'] + for index in range(len(names)): + field = descriptor_proto.field.add() + field.number = index + 1 + field.name = names[index] + result = descriptor.MakeDescriptor(descriptor_proto) + for index in range(len(camelcase_names)): + self.assertEqual(result.fields[index].camelcase_name, + camelcase_names[index]) + + if __name__ == '__main__': - basetest.main() + unittest.main() diff --git a/python/google/protobuf/internal/encoder.py b/python/google/protobuf/internal/encoder.py index 38a5138ae..48ef2df31 100755 --- a/python/google/protobuf/internal/encoder.py +++ b/python/google/protobuf/internal/encoder.py @@ -28,10 +28,6 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -#PY25 compatible for GAE. -# -# Copyright 2009 Google Inc. All Rights Reserved. 
- """Code for encoding protocol message primitives. Contains the logic for encoding every logical protocol field type @@ -45,7 +41,7 @@ FieldDescriptor) we construct two functions: a "sizer" and an "encoder". The sizer takes a value of this field's type and computes its byte size. The encoder takes a writer function and a value. It encodes the value into byte strings and invokes the writer function to write those strings. Typically the -writer function is the write() method of a cStringIO. +writer function is the write() method of a BytesIO. We try to do as much work as possible when constructing the writer and the sizer rather than when calling them. In particular: @@ -71,8 +67,9 @@ sizer rather than when calling them. In particular: __author__ = 'kenton@google.com (Kenton Varda)' import struct -import sys ##PY25 -_PY2 = sys.version_info[0] < 3 ##PY25 + +import six + from google.protobuf.internal import wire_format @@ -314,7 +311,7 @@ def MessageSizer(field_number, is_repeated, is_packed): # -------------------------------------------------------------------- -# MessageSet is special. +# MessageSet is special: it needs custom logic to compute its size properly. def MessageSetItemSizer(field_number): @@ -339,6 +336,32 @@ def MessageSetItemSizer(field_number): return FieldSize +# -------------------------------------------------------------------- +# Map is special: it needs custom logic to compute its size properly. + + +def MapSizer(field_descriptor): + """Returns a sizer for a map field.""" + + # Can't look at field_descriptor.message_type._concrete_class because it may + # not have been initialized yet. + message_type = field_descriptor.message_type + message_sizer = MessageSizer(field_descriptor.number, False, False) + + def FieldSize(map_value): + total = 0 + for key in map_value: + value = map_value[key] + # It's wasteful to create the messages and throw them away one second + # later since we'll do the same for the actual encode. 
But there's not an + # obvious way to avoid this within the current design without tons of code + # duplication. + entry_msg = message_type._concrete_class(key=key, value=value) + total += message_sizer(entry_msg) + return total + + return FieldSize + # ==================================================================== # Encoders! @@ -346,16 +369,14 @@ def MessageSetItemSizer(field_number): def _VarintEncoder(): """Return an encoder for a basic varint value (does not include tag).""" - local_chr = _PY2 and chr or (lambda x: bytes((x,))) ##PY25 -##!PY25 local_chr = chr if bytes is str else lambda x: bytes((x,)) def EncodeVarint(write, value): bits = value & 0x7f value >>= 7 while value: - write(local_chr(0x80|bits)) + write(six.int2byte(0x80|bits)) bits = value & 0x7f value >>= 7 - return write(local_chr(bits)) + return write(six.int2byte(bits)) return EncodeVarint @@ -364,18 +385,16 @@ def _SignedVarintEncoder(): """Return an encoder for a basic signed varint value (does not include tag).""" - local_chr = _PY2 and chr or (lambda x: bytes((x,))) ##PY25 -##!PY25 local_chr = chr if bytes is str else lambda x: bytes((x,)) def EncodeSignedVarint(write, value): if value < 0: value += (1 << 64) bits = value & 0x7f value >>= 7 while value: - write(local_chr(0x80|bits)) + write(six.int2byte(0x80|bits)) bits = value & 0x7f value >>= 7 - return write(local_chr(bits)) + return write(six.int2byte(bits)) return EncodeSignedVarint @@ -390,8 +409,7 @@ def _VarintBytes(value): pieces = [] _EncodeVarint(pieces.append, value) - return "".encode("latin1").join(pieces) ##PY25 -##!PY25 return b"".join(pieces) + return b"".join(pieces) def TagBytes(field_number, wire_type): @@ -529,33 +547,26 @@ def _FloatingPointEncoder(wire_type, format): format: The format string to pass to struct.pack(). 
""" - b = _PY2 and (lambda x:x) or (lambda x:x.encode('latin1')) ##PY25 value_size = struct.calcsize(format) if value_size == 4: def EncodeNonFiniteOrRaise(write, value): # Remember that the serialized form uses little-endian byte order. if value == _POS_INF: - write(b('\x00\x00\x80\x7F')) ##PY25 -##!PY25 write(b'\x00\x00\x80\x7F') + write(b'\x00\x00\x80\x7F') elif value == _NEG_INF: - write(b('\x00\x00\x80\xFF')) ##PY25 -##!PY25 write(b'\x00\x00\x80\xFF') + write(b'\x00\x00\x80\xFF') elif value != value: # NaN - write(b('\x00\x00\xC0\x7F')) ##PY25 -##!PY25 write(b'\x00\x00\xC0\x7F') + write(b'\x00\x00\xC0\x7F') else: raise elif value_size == 8: def EncodeNonFiniteOrRaise(write, value): if value == _POS_INF: - write(b('\x00\x00\x00\x00\x00\x00\xF0\x7F')) ##PY25 -##!PY25 write(b'\x00\x00\x00\x00\x00\x00\xF0\x7F') + write(b'\x00\x00\x00\x00\x00\x00\xF0\x7F') elif value == _NEG_INF: - write(b('\x00\x00\x00\x00\x00\x00\xF0\xFF')) ##PY25 -##!PY25 write(b'\x00\x00\x00\x00\x00\x00\xF0\xFF') + write(b'\x00\x00\x00\x00\x00\x00\xF0\xFF') elif value != value: # NaN - write(b('\x00\x00\x00\x00\x00\x00\xF8\x7F')) ##PY25 -##!PY25 write(b'\x00\x00\x00\x00\x00\x00\xF8\x7F') + write(b'\x00\x00\x00\x00\x00\x00\xF8\x7F') else: raise else: @@ -631,10 +642,8 @@ DoubleEncoder = _FloatingPointEncoder(wire_format.WIRETYPE_FIXED64, '<d') def BoolEncoder(field_number, is_repeated, is_packed): """Returns an encoder for a boolean field.""" -##!PY25 false_byte = b'\x00' -##!PY25 true_byte = b'\x01' - false_byte = '\x00'.encode('latin1') ##PY25 - true_byte = '\x01'.encode('latin1') ##PY25 + false_byte = b'\x00' + true_byte = b'\x01' if is_packed: tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) local_EncodeVarint = _EncodeVarint @@ -770,8 +779,7 @@ def MessageSetItemEncoder(field_number): } } """ - start_bytes = "".encode("latin1").join([ ##PY25 -##!PY25 start_bytes = b"".join([ + start_bytes = b"".join([ TagBytes(1, wire_format.WIRETYPE_START_GROUP), TagBytes(2, 
wire_format.WIRETYPE_VARINT), _VarintBytes(field_number), @@ -786,3 +794,30 @@ def MessageSetItemEncoder(field_number): return write(end_bytes) return EncodeField + + +# -------------------------------------------------------------------- +# As before, Map is special. + + +def MapEncoder(field_descriptor): + """Encoder for extensions of MessageSet. + + Maps always have a wire format like this: + message MapEntry { + key_type key = 1; + value_type value = 2; + } + repeated MapEntry map = N; + """ + # Can't look at field_descriptor.message_type._concrete_class because it may + # not have been initialized yet. + message_type = field_descriptor.message_type + encode_message = MessageEncoder(field_descriptor.number, False, False) + + def EncodeField(write, value): + for key in value: + entry_msg = message_type._concrete_class(key=key, value=value[key]) + encode_message(write, entry_msg) + + return EncodeField diff --git a/python/google/protobuf/internal/factory_test1.proto b/python/google/protobuf/internal/factory_test1.proto index 9f5a39192..d2fbbeecf 100644 --- a/python/google/protobuf/internal/factory_test1.proto +++ b/python/google/protobuf/internal/factory_test1.proto @@ -30,6 +30,7 @@ // Author: matthewtoia@google.com (Matt Toia) +syntax = "proto2"; package google.protobuf.python.internal; diff --git a/python/google/protobuf/internal/factory_test2.proto b/python/google/protobuf/internal/factory_test2.proto index 27feb6ce4..bb1b54ada 100644 --- a/python/google/protobuf/internal/factory_test2.proto +++ b/python/google/protobuf/internal/factory_test2.proto @@ -30,6 +30,7 @@ // Author: matthewtoia@google.com (Matt Toia) +syntax = "proto2"; package google.protobuf.python.internal; @@ -87,6 +88,12 @@ message LoopMessage { optional Factory2Message loop = 1; } +message MessageWithNestedEnumOnly { + enum NestedEnum { + NESTED_MESSAGE_ENUM_0 = 0; + } +} + extend Factory1Message { optional string another_field = 1002; } diff --git 
a/python/google/protobuf/internal/generator_test.py b/python/google/protobuf/internal/generator_test.py index 422fa9a66..83ea5f509 100755 --- a/python/google/protobuf/internal/generator_test.py +++ b/python/google/protobuf/internal/generator_test.py @@ -1,4 +1,4 @@ -#! /usr/bin/python +#! /usr/bin/env python # # Protocol Buffers - Google's data interchange format # Copyright 2008 Google Inc. All rights reserved. @@ -35,18 +35,23 @@ # indirect testing of the protocol compiler output. """Unittest that directly tests the output of the pure-Python protocol -compiler. See //google/protobuf/reflection_test.py for a test which +compiler. See //google/protobuf/internal/reflection_test.py for a test which further ensures that we can use Python protocol message objects as we expect. """ __author__ = 'robinson@google.com (Will Robinson)' -from google.apputils import basetest +try: + import unittest2 as unittest #PY26 +except ImportError: + import unittest + from google.protobuf.internal import test_bad_identifiers_pb2 from google.protobuf import unittest_custom_options_pb2 from google.protobuf import unittest_import_pb2 from google.protobuf import unittest_import_public_pb2 from google.protobuf import unittest_mset_pb2 +from google.protobuf import unittest_mset_wire_format_pb2 from google.protobuf import unittest_no_generic_services_pb2 from google.protobuf import unittest_pb2 from google.protobuf import service @@ -55,7 +60,7 @@ from google.protobuf import symbol_database MAX_EXTENSION = 536870912 -class GeneratorTest(basetest.TestCase): +class GeneratorTest(unittest.TestCase): def testNestedMessageDescriptor(self): field_name = 'optional_nested_message' @@ -142,18 +147,18 @@ class GeneratorTest(basetest.TestCase): self.assertTrue(not non_extension_descriptor.is_extension) def testOptions(self): - proto = unittest_mset_pb2.TestMessageSet() + proto = unittest_mset_wire_format_pb2.TestMessageSet() self.assertTrue(proto.DESCRIPTOR.GetOptions().message_set_wire_format) def 
testMessageWithCustomOptions(self): proto = unittest_custom_options_pb2.TestMessageWithCustomOptions() enum_options = proto.DESCRIPTOR.enum_types_by_name['AnEnum'].GetOptions() self.assertTrue(enum_options is not None) - # TODO(gps): We really should test for the presense of the enum_opt1 + # TODO(gps): We really should test for the presence of the enum_opt1 # extension and for its value to be set to -789. def testNestedTypes(self): - self.assertEquals( + self.assertEqual( set(unittest_pb2.TestAllTypes.DESCRIPTOR.nested_types), set([ unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR, @@ -291,53 +296,53 @@ class GeneratorTest(basetest.TestCase): self.assertIs(desc.oneofs[0], desc.oneofs_by_name['oneof_field']) nested_names = set(['oneof_uint32', 'oneof_nested_message', 'oneof_string', 'oneof_bytes']) - self.assertSameElements( + self.assertEqual( nested_names, - [field.name for field in desc.oneofs[0].fields]) - for field_name, field_desc in desc.fields_by_name.iteritems(): + set([field.name for field in desc.oneofs[0].fields])) + for field_name, field_desc in desc.fields_by_name.items(): if field_name in nested_names: self.assertIs(desc.oneofs[0], field_desc.containing_oneof) else: self.assertIsNone(field_desc.containing_oneof) -class SymbolDatabaseRegistrationTest(basetest.TestCase): +class SymbolDatabaseRegistrationTest(unittest.TestCase): """Checks that messages, enums and files are correctly registered.""" def testGetSymbol(self): - self.assertEquals( + self.assertEqual( unittest_pb2.TestAllTypes, symbol_database.Default().GetSymbol( 'protobuf_unittest.TestAllTypes')) - self.assertEquals( + self.assertEqual( unittest_pb2.TestAllTypes.NestedMessage, symbol_database.Default().GetSymbol( 'protobuf_unittest.TestAllTypes.NestedMessage')) with self.assertRaises(KeyError): symbol_database.Default().GetSymbol('protobuf_unittest.NestedMessage') - self.assertEquals( + self.assertEqual( unittest_pb2.TestAllTypes.OptionalGroup, symbol_database.Default().GetSymbol( 
'protobuf_unittest.TestAllTypes.OptionalGroup')) - self.assertEquals( + self.assertEqual( unittest_pb2.TestAllTypes.RepeatedGroup, symbol_database.Default().GetSymbol( 'protobuf_unittest.TestAllTypes.RepeatedGroup')) def testEnums(self): - self.assertEquals( + self.assertEqual( 'protobuf_unittest.ForeignEnum', symbol_database.Default().pool.FindEnumTypeByName( 'protobuf_unittest.ForeignEnum').full_name) - self.assertEquals( + self.assertEqual( 'protobuf_unittest.TestAllTypes.NestedEnum', symbol_database.Default().pool.FindEnumTypeByName( 'protobuf_unittest.TestAllTypes.NestedEnum').full_name) def testFindFileByName(self): - self.assertEquals( + self.assertEqual( 'google/protobuf/unittest.proto', symbol_database.Default().pool.FindFileByName( 'google/protobuf/unittest.proto').name) if __name__ == '__main__': - basetest.main() + unittest.main() diff --git a/python/google/protobuf/internal/message_python_test.py b/python/google/protobuf/internal/import_test_package/__init__.py index c40623a8c..5121dd0ec 100644 --- a/python/google/protobuf/internal/message_python_test.py +++ b/python/google/protobuf/internal/import_test_package/__init__.py @@ -1,5 +1,3 @@ -#! /usr/bin/python -# # Protocol Buffers - Google's data interchange format # Copyright 2008 Google Inc. All rights reserved. # https://developers.google.com/protocol-buffers/ @@ -30,25 +28,6 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -"""Tests for ..public.message for the pure Python implementation.""" - -import os -os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python' - -# We must set the implementation version above before the google3 imports. -# pylint: disable=g-import-not-at-top -from google.apputils import basetest -from google.protobuf.internal import api_implementation -# Run all tests from the original module by putting them in our namespace. 
-# pylint: disable=wildcard-import -from google.protobuf.internal.message_test import * - - -class ConfirmPurePythonTest(basetest.TestCase): - - def testImplementationSetting(self): - self.assertEqual('python', api_implementation.Type()) - +"""Sample module importing a nested proto from itself.""" -if __name__ == '__main__': - basetest.main() +from google.protobuf.internal.import_test_package import outer_pb2 as myproto diff --git a/python/google/protobuf/internal/import_test_package/inner.proto b/python/google/protobuf/internal/import_test_package/inner.proto new file mode 100644 index 000000000..2887c1230 --- /dev/null +++ b/python/google/protobuf/internal/import_test_package/inner.proto @@ -0,0 +1,37 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +syntax = "proto2"; + +package google.protobuf.python.internal.import_test_package; + +message Inner { + optional int32 value = 1 [default = 57]; +} diff --git a/python/google/protobuf/internal/import_test_package/outer.proto b/python/google/protobuf/internal/import_test_package/outer.proto new file mode 100644 index 000000000..a27fb5c8f --- /dev/null +++ b/python/google/protobuf/internal/import_test_package/outer.proto @@ -0,0 +1,39 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. 
+// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +syntax = "proto2"; + +package google.protobuf.python.internal.import_test_package; + +import "google/protobuf/internal/import_test_package/inner.proto"; + +message Outer { + optional Inner inner = 1; +} diff --git a/python/google/protobuf/internal/json_format_test.py b/python/google/protobuf/internal/json_format_test.py new file mode 100644 index 000000000..bdc9f49a4 --- /dev/null +++ b/python/google/protobuf/internal/json_format_test.py @@ -0,0 +1,769 @@ +#! /usr/bin/env python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. 
nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Test for google.protobuf.json_format.""" + +__author__ = 'jieluo@google.com (Jie Luo)' + +import json +import math +import sys + +try: + import unittest2 as unittest #PY26 +except ImportError: + import unittest + +from google.protobuf import any_pb2 +from google.protobuf import duration_pb2 +from google.protobuf import field_mask_pb2 +from google.protobuf import struct_pb2 +from google.protobuf import timestamp_pb2 +from google.protobuf import wrappers_pb2 +from google.protobuf.internal import well_known_types +from google.protobuf import json_format +from google.protobuf.util import json_format_proto3_pb2 + + +class JsonFormatBase(unittest.TestCase): + + def FillAllFields(self, message): + message.int32_value = 20 + message.int64_value = -20 + message.uint32_value = 3120987654 + message.uint64_value = 12345678900 + message.float_value = float('-inf') + message.double_value = 3.1415 + message.bool_value = True + message.string_value = 'foo' + message.bytes_value = b'bar' + message.message_value.value = 10 + message.enum_value = 
json_format_proto3_pb2.BAR + # Repeated + message.repeated_int32_value.append(0x7FFFFFFF) + message.repeated_int32_value.append(-2147483648) + message.repeated_int64_value.append(9007199254740992) + message.repeated_int64_value.append(-9007199254740992) + message.repeated_uint32_value.append(0xFFFFFFF) + message.repeated_uint32_value.append(0x7FFFFFF) + message.repeated_uint64_value.append(9007199254740992) + message.repeated_uint64_value.append(9007199254740991) + message.repeated_float_value.append(0) + + message.repeated_double_value.append(1E-15) + message.repeated_double_value.append(float('inf')) + message.repeated_bool_value.append(True) + message.repeated_bool_value.append(False) + message.repeated_string_value.append('Few symbols!#$,;') + message.repeated_string_value.append('bar') + message.repeated_bytes_value.append(b'foo') + message.repeated_bytes_value.append(b'bar') + message.repeated_message_value.add().value = 10 + message.repeated_message_value.add().value = 11 + message.repeated_enum_value.append(json_format_proto3_pb2.FOO) + message.repeated_enum_value.append(json_format_proto3_pb2.BAR) + self.message = message + + def CheckParseBack(self, message, parsed_message): + json_format.Parse(json_format.MessageToJson(message), + parsed_message) + self.assertEqual(message, parsed_message) + + def CheckError(self, text, error_message): + message = json_format_proto3_pb2.TestMessage() + self.assertRaisesRegexp( + json_format.ParseError, + error_message, + json_format.Parse, text, message) + + +class JsonFormatTest(JsonFormatBase): + + def testEmptyMessageToJson(self): + message = json_format_proto3_pb2.TestMessage() + self.assertEqual(json_format.MessageToJson(message), + '{}') + parsed_message = json_format_proto3_pb2.TestMessage() + self.CheckParseBack(message, parsed_message) + + def testPartialMessageToJson(self): + message = json_format_proto3_pb2.TestMessage( + string_value='test', + repeated_int32_value=[89, 4]) + 
self.assertEqual(json.loads(json_format.MessageToJson(message)), + json.loads('{"stringValue": "test", ' + '"repeatedInt32Value": [89, 4]}')) + parsed_message = json_format_proto3_pb2.TestMessage() + self.CheckParseBack(message, parsed_message) + + def testAllFieldsToJson(self): + message = json_format_proto3_pb2.TestMessage() + text = ('{"int32Value": 20, ' + '"int64Value": "-20", ' + '"uint32Value": 3120987654,' + '"uint64Value": "12345678900",' + '"floatValue": "-Infinity",' + '"doubleValue": 3.1415,' + '"boolValue": true,' + '"stringValue": "foo",' + '"bytesValue": "YmFy",' + '"messageValue": {"value": 10},' + '"enumValue": "BAR",' + '"repeatedInt32Value": [2147483647, -2147483648],' + '"repeatedInt64Value": ["9007199254740992", "-9007199254740992"],' + '"repeatedUint32Value": [268435455, 134217727],' + '"repeatedUint64Value": ["9007199254740992", "9007199254740991"],' + '"repeatedFloatValue": [0],' + '"repeatedDoubleValue": [1e-15, "Infinity"],' + '"repeatedBoolValue": [true, false],' + '"repeatedStringValue": ["Few symbols!#$,;", "bar"],' + '"repeatedBytesValue": ["Zm9v", "YmFy"],' + '"repeatedMessageValue": [{"value": 10}, {"value": 11}],' + '"repeatedEnumValue": ["FOO", "BAR"]' + '}') + self.FillAllFields(message) + self.assertEqual( + json.loads(json_format.MessageToJson(message)), + json.loads(text)) + parsed_message = json_format_proto3_pb2.TestMessage() + json_format.Parse(text, parsed_message) + self.assertEqual(message, parsed_message) + + def testJsonEscapeString(self): + message = json_format_proto3_pb2.TestMessage() + if sys.version_info[0] < 3: + message.string_value = '&\n<\"\r>\b\t\f\\\001/\xe2\x80\xa8\xe2\x80\xa9' + else: + message.string_value = '&\n<\"\r>\b\t\f\\\001/' + message.string_value += (b'\xe2\x80\xa8\xe2\x80\xa9').decode('utf-8') + self.assertEqual( + json_format.MessageToJson(message), + '{\n "stringValue": ' + '"&\\n<\\\"\\r>\\b\\t\\f\\\\\\u0001/\\u2028\\u2029"\n}') + parsed_message = json_format_proto3_pb2.TestMessage() + 
self.CheckParseBack(message, parsed_message) + text = u'{"int32Value": "\u0031"}' + json_format.Parse(text, message) + self.assertEqual(message.int32_value, 1) + + def testAlwaysSeriliaze(self): + message = json_format_proto3_pb2.TestMessage( + string_value='foo') + self.assertEqual( + json.loads(json_format.MessageToJson(message, True)), + json.loads('{' + '"repeatedStringValue": [],' + '"stringValue": "foo",' + '"repeatedBoolValue": [],' + '"repeatedUint32Value": [],' + '"repeatedInt32Value": [],' + '"enumValue": "FOO",' + '"int32Value": 0,' + '"floatValue": 0,' + '"int64Value": "0",' + '"uint32Value": 0,' + '"repeatedBytesValue": [],' + '"repeatedUint64Value": [],' + '"repeatedDoubleValue": [],' + '"bytesValue": "",' + '"boolValue": false,' + '"repeatedEnumValue": [],' + '"uint64Value": "0",' + '"doubleValue": 0,' + '"repeatedFloatValue": [],' + '"repeatedInt64Value": [],' + '"repeatedMessageValue": []}')) + parsed_message = json_format_proto3_pb2.TestMessage() + self.CheckParseBack(message, parsed_message) + + def testMapFields(self): + message = json_format_proto3_pb2.TestMap() + message.bool_map[True] = 1 + message.bool_map[False] = 2 + message.int32_map[1] = 2 + message.int32_map[2] = 3 + message.int64_map[1] = 2 + message.int64_map[2] = 3 + message.uint32_map[1] = 2 + message.uint32_map[2] = 3 + message.uint64_map[1] = 2 + message.uint64_map[2] = 3 + message.string_map['1'] = 2 + message.string_map['null'] = 3 + self.assertEqual( + json.loads(json_format.MessageToJson(message, True)), + json.loads('{' + '"boolMap": {"false": 2, "true": 1},' + '"int32Map": {"1": 2, "2": 3},' + '"int64Map": {"1": 2, "2": 3},' + '"uint32Map": {"1": 2, "2": 3},' + '"uint64Map": {"1": 2, "2": 3},' + '"stringMap": {"1": 2, "null": 3}' + '}')) + parsed_message = json_format_proto3_pb2.TestMap() + self.CheckParseBack(message, parsed_message) + + def testOneofFields(self): + message = json_format_proto3_pb2.TestOneof() + # Always print does not affect oneof fields. 
+ self.assertEqual( + json_format.MessageToJson(message, True), + '{}') + message.oneof_int32_value = 0 + self.assertEqual( + json_format.MessageToJson(message, True), + '{\n' + ' "oneofInt32Value": 0\n' + '}') + parsed_message = json_format_proto3_pb2.TestOneof() + self.CheckParseBack(message, parsed_message) + + def testTimestampMessage(self): + message = json_format_proto3_pb2.TestTimestamp() + message.value.seconds = 0 + message.value.nanos = 0 + message.repeated_value.add().seconds = 20 + message.repeated_value[0].nanos = 1 + message.repeated_value.add().seconds = 0 + message.repeated_value[1].nanos = 10000 + message.repeated_value.add().seconds = 100000000 + message.repeated_value[2].nanos = 0 + # Maximum time + message.repeated_value.add().seconds = 253402300799 + message.repeated_value[3].nanos = 999999999 + # Minimum time + message.repeated_value.add().seconds = -62135596800 + message.repeated_value[4].nanos = 0 + self.assertEqual( + json.loads(json_format.MessageToJson(message, True)), + json.loads('{' + '"value": "1970-01-01T00:00:00Z",' + '"repeatedValue": [' + ' "1970-01-01T00:00:20.000000001Z",' + ' "1970-01-01T00:00:00.000010Z",' + ' "1973-03-03T09:46:40Z",' + ' "9999-12-31T23:59:59.999999999Z",' + ' "0001-01-01T00:00:00Z"' + ']' + '}')) + parsed_message = json_format_proto3_pb2.TestTimestamp() + self.CheckParseBack(message, parsed_message) + text = (r'{"value": "1970-01-01T00:00:00.01+08:00",' + r'"repeatedValue":[' + r' "1970-01-01T00:00:00.01+08:30",' + r' "1970-01-01T00:00:00.01-01:23"]}') + json_format.Parse(text, parsed_message) + self.assertEqual(parsed_message.value.seconds, -8 * 3600) + self.assertEqual(parsed_message.value.nanos, 10000000) + self.assertEqual(parsed_message.repeated_value[0].seconds, -8.5 * 3600) + self.assertEqual(parsed_message.repeated_value[1].seconds, 3600 + 23 * 60) + + def testDurationMessage(self): + message = json_format_proto3_pb2.TestDuration() + message.value.seconds = 1 + message.repeated_value.add().seconds = 0 
+ message.repeated_value[0].nanos = 10 + message.repeated_value.add().seconds = -1 + message.repeated_value[1].nanos = -1000 + message.repeated_value.add().seconds = 10 + message.repeated_value[2].nanos = 11000000 + message.repeated_value.add().seconds = -315576000000 + message.repeated_value.add().seconds = 315576000000 + self.assertEqual( + json.loads(json_format.MessageToJson(message, True)), + json.loads('{' + '"value": "1s",' + '"repeatedValue": [' + ' "0.000000010s",' + ' "-1.000001s",' + ' "10.011s",' + ' "-315576000000s",' + ' "315576000000s"' + ']' + '}')) + parsed_message = json_format_proto3_pb2.TestDuration() + self.CheckParseBack(message, parsed_message) + + def testFieldMaskMessage(self): + message = json_format_proto3_pb2.TestFieldMask() + message.value.paths.append('foo.bar') + message.value.paths.append('bar') + self.assertEqual( + json_format.MessageToJson(message, True), + '{\n' + ' "value": "foo.bar,bar"\n' + '}') + parsed_message = json_format_proto3_pb2.TestFieldMask() + self.CheckParseBack(message, parsed_message) + + def testWrapperMessage(self): + message = json_format_proto3_pb2.TestWrapper() + message.bool_value.value = False + message.int32_value.value = 0 + message.string_value.value = '' + message.bytes_value.value = b'' + message.repeated_bool_value.add().value = True + message.repeated_bool_value.add().value = False + message.repeated_int32_value.add() + self.assertEqual( + json.loads(json_format.MessageToJson(message, True)), + json.loads('{\n' + ' "int32Value": 0,' + ' "boolValue": false,' + ' "stringValue": "",' + ' "bytesValue": "",' + ' "repeatedBoolValue": [true, false],' + ' "repeatedInt32Value": [0],' + ' "repeatedUint32Value": [],' + ' "repeatedFloatValue": [],' + ' "repeatedDoubleValue": [],' + ' "repeatedBytesValue": [],' + ' "repeatedInt64Value": [],' + ' "repeatedUint64Value": [],' + ' "repeatedStringValue": []' + '}')) + parsed_message = json_format_proto3_pb2.TestWrapper() + self.CheckParseBack(message, parsed_message) 
+ + def testStructMessage(self): + message = json_format_proto3_pb2.TestStruct() + message.value['name'] = 'Jim' + message.value['age'] = 10 + message.value['attend'] = True + message.value['email'] = None + message.value.get_or_create_struct('address')['city'] = 'SFO' + message.value['address']['house_number'] = 1024 + struct_list = message.value.get_or_create_list('list') + struct_list.extend([6, 'seven', True, False, None]) + struct_list.add_struct()['subkey2'] = 9 + message.repeated_value.add()['age'] = 11 + message.repeated_value.add() + self.assertEqual( + json.loads(json_format.MessageToJson(message, False)), + json.loads( + '{' + ' "value": {' + ' "address": {' + ' "city": "SFO", ' + ' "house_number": 1024' + ' }, ' + ' "age": 10, ' + ' "name": "Jim", ' + ' "attend": true, ' + ' "email": null, ' + ' "list": [6, "seven", true, false, null, {"subkey2": 9}]' + ' },' + ' "repeatedValue": [{"age": 11}, {}]' + '}')) + parsed_message = json_format_proto3_pb2.TestStruct() + self.CheckParseBack(message, parsed_message) + + def testValueMessage(self): + message = json_format_proto3_pb2.TestValue() + message.value.string_value = 'hello' + message.repeated_value.add().number_value = 11.1 + message.repeated_value.add().bool_value = False + message.repeated_value.add().null_value = 0 + self.assertEqual( + json.loads(json_format.MessageToJson(message, False)), + json.loads( + '{' + ' "value": "hello",' + ' "repeatedValue": [11.1, false, null]' + '}')) + parsed_message = json_format_proto3_pb2.TestValue() + self.CheckParseBack(message, parsed_message) + # Can't parse back if the Value message is not set. 
+ message.repeated_value.add() + self.assertEqual( + json.loads(json_format.MessageToJson(message, False)), + json.loads( + '{' + ' "value": "hello",' + ' "repeatedValue": [11.1, false, null, null]' + '}')) + + def testListValueMessage(self): + message = json_format_proto3_pb2.TestListValue() + message.value.values.add().number_value = 11.1 + message.value.values.add().null_value = 0 + message.value.values.add().bool_value = True + message.value.values.add().string_value = 'hello' + message.value.values.add().struct_value['name'] = 'Jim' + message.repeated_value.add().values.add().number_value = 1 + message.repeated_value.add() + self.assertEqual( + json.loads(json_format.MessageToJson(message, False)), + json.loads( + '{"value": [11.1, null, true, "hello", {"name": "Jim"}]\n,' + '"repeatedValue": [[1], []]}')) + parsed_message = json_format_proto3_pb2.TestListValue() + self.CheckParseBack(message, parsed_message) + + def testAnyMessage(self): + message = json_format_proto3_pb2.TestAny() + value1 = json_format_proto3_pb2.MessageType() + value2 = json_format_proto3_pb2.MessageType() + value1.value = 1234 + value2.value = 5678 + message.value.Pack(value1) + message.repeated_value.add().Pack(value1) + message.repeated_value.add().Pack(value2) + message.repeated_value.add() + self.assertEqual( + json.loads(json_format.MessageToJson(message, True)), + json.loads( + '{\n' + ' "repeatedValue": [ {\n' + ' "@type": "type.googleapis.com/proto3.MessageType",\n' + ' "value": 1234\n' + ' }, {\n' + ' "@type": "type.googleapis.com/proto3.MessageType",\n' + ' "value": 5678\n' + ' },\n' + ' {}],\n' + ' "value": {\n' + ' "@type": "type.googleapis.com/proto3.MessageType",\n' + ' "value": 1234\n' + ' }\n' + '}\n')) + parsed_message = json_format_proto3_pb2.TestAny() + self.CheckParseBack(message, parsed_message) + + def testWellKnownInAnyMessage(self): + message = any_pb2.Any() + int32_value = wrappers_pb2.Int32Value() + int32_value.value = 1234 + message.Pack(int32_value) + 
self.assertEqual( + json.loads(json_format.MessageToJson(message, True)), + json.loads( + '{\n' + ' "@type": \"type.googleapis.com/google.protobuf.Int32Value\",\n' + ' "value": 1234\n' + '}\n')) + parsed_message = any_pb2.Any() + self.CheckParseBack(message, parsed_message) + + timestamp = timestamp_pb2.Timestamp() + message.Pack(timestamp) + self.assertEqual( + json.loads(json_format.MessageToJson(message, True)), + json.loads( + '{\n' + ' "@type": "type.googleapis.com/google.protobuf.Timestamp",\n' + ' "value": "1970-01-01T00:00:00Z"\n' + '}\n')) + self.CheckParseBack(message, parsed_message) + + duration = duration_pb2.Duration() + duration.seconds = 1 + message.Pack(duration) + self.assertEqual( + json.loads(json_format.MessageToJson(message, True)), + json.loads( + '{\n' + ' "@type": "type.googleapis.com/google.protobuf.Duration",\n' + ' "value": "1s"\n' + '}\n')) + self.CheckParseBack(message, parsed_message) + + field_mask = field_mask_pb2.FieldMask() + field_mask.paths.append('foo.bar') + field_mask.paths.append('bar') + message.Pack(field_mask) + self.assertEqual( + json.loads(json_format.MessageToJson(message, True)), + json.loads( + '{\n' + ' "@type": "type.googleapis.com/google.protobuf.FieldMask",\n' + ' "value": "foo.bar,bar"\n' + '}\n')) + self.CheckParseBack(message, parsed_message) + + struct_message = struct_pb2.Struct() + struct_message['name'] = 'Jim' + message.Pack(struct_message) + self.assertEqual( + json.loads(json_format.MessageToJson(message, True)), + json.loads( + '{\n' + ' "@type": "type.googleapis.com/google.protobuf.Struct",\n' + ' "value": {"name": "Jim"}\n' + '}\n')) + self.CheckParseBack(message, parsed_message) + + nested_any = any_pb2.Any() + int32_value.value = 5678 + nested_any.Pack(int32_value) + message.Pack(nested_any) + self.assertEqual( + json.loads(json_format.MessageToJson(message, True)), + json.loads( + '{\n' + ' "@type": "type.googleapis.com/google.protobuf.Any",\n' + ' "value": {\n' + ' "@type": 
"type.googleapis.com/google.protobuf.Int32Value",\n' + ' "value": 5678\n' + ' }\n' + '}\n')) + self.CheckParseBack(message, parsed_message) + + def testParseNull(self): + message = json_format_proto3_pb2.TestMessage() + parsed_message = json_format_proto3_pb2.TestMessage() + self.FillAllFields(parsed_message) + json_format.Parse('{"int32Value": null, ' + '"int64Value": null, ' + '"uint32Value": null,' + '"uint64Value": null,' + '"floatValue": null,' + '"doubleValue": null,' + '"boolValue": null,' + '"stringValue": null,' + '"bytesValue": null,' + '"messageValue": null,' + '"enumValue": null,' + '"repeatedInt32Value": null,' + '"repeatedInt64Value": null,' + '"repeatedUint32Value": null,' + '"repeatedUint64Value": null,' + '"repeatedFloatValue": null,' + '"repeatedDoubleValue": null,' + '"repeatedBoolValue": null,' + '"repeatedStringValue": null,' + '"repeatedBytesValue": null,' + '"repeatedMessageValue": null,' + '"repeatedEnumValue": null' + '}', + parsed_message) + self.assertEqual(message, parsed_message) + self.assertRaisesRegexp( + json_format.ParseError, + 'Failed to parse repeatedInt32Value field: ' + 'null is not allowed to be used as an element in a repeated field.', + json_format.Parse, + '{"repeatedInt32Value":[1, null]}', + parsed_message) + + def testNanFloat(self): + message = json_format_proto3_pb2.TestMessage() + message.float_value = float('nan') + text = '{\n "floatValue": "NaN"\n}' + self.assertEqual(json_format.MessageToJson(message), text) + parsed_message = json_format_proto3_pb2.TestMessage() + json_format.Parse(text, parsed_message) + self.assertTrue(math.isnan(parsed_message.float_value)) + + def testParseEmptyText(self): + self.CheckError('', + r'Failed to load JSON: (Expecting value)|(No JSON).') + + def testParseBadEnumValue(self): + self.CheckError( + '{"enumValue": 1}', + 'Enum value must be a string literal with double quotes. 
' + 'Type "proto3.EnumType" has no value named 1.') + self.CheckError( + '{"enumValue": "baz"}', + 'Enum value must be a string literal with double quotes. ' + 'Type "proto3.EnumType" has no value named baz.') + + def testParseBadIdentifer(self): + self.CheckError('{int32Value: 1}', + (r'Failed to load JSON: Expecting property name' + r'( enclosed in double quotes)?: line 1')) + self.CheckError('{"unknownName": 1}', + 'Message type "proto3.TestMessage" has no field named ' + '"unknownName".') + + def testDuplicateField(self): + # Duplicate key check is not supported for python2.6 + if sys.version_info < (2, 7): + return + self.CheckError('{"int32Value": 1,\n"int32Value":2}', + 'Failed to load JSON: duplicate key int32Value.') + + def testInvalidBoolValue(self): + self.CheckError('{"boolValue": 1}', + 'Failed to parse boolValue field: ' + 'Expected true or false without quotes.') + self.CheckError('{"boolValue": "true"}', + 'Failed to parse boolValue field: ' + 'Expected true or false without quotes.') + + def testInvalidIntegerValue(self): + message = json_format_proto3_pb2.TestMessage() + text = '{"int32Value": 0x12345}' + self.assertRaises(json_format.ParseError, + json_format.Parse, text, message) + self.CheckError('{"int32Value": 012345}', + (r'Failed to load JSON: Expecting \'?,\'? 
delimiter: ' + r'line 1.')) + self.CheckError('{"int32Value": 1.0}', + 'Failed to parse int32Value field: ' + 'Couldn\'t parse integer: 1.0.') + self.CheckError('{"int32Value": " 1 "}', + 'Failed to parse int32Value field: ' + 'Couldn\'t parse integer: " 1 ".') + self.CheckError('{"int32Value": "1 "}', + 'Failed to parse int32Value field: ' + 'Couldn\'t parse integer: "1 ".') + self.CheckError('{"int32Value": 12345678901234567890}', + 'Failed to parse int32Value field: Value out of range: ' + '12345678901234567890.') + self.CheckError('{"int32Value": 1e5}', + 'Failed to parse int32Value field: ' + 'Couldn\'t parse integer: 100000.0.') + self.CheckError('{"uint32Value": -1}', + 'Failed to parse uint32Value field: ' + 'Value out of range: -1.') + + def testInvalidFloatValue(self): + self.CheckError('{"floatValue": "nan"}', + 'Failed to parse floatValue field: Couldn\'t ' + 'parse float "nan", use "NaN" instead.') + + def testInvalidBytesValue(self): + self.CheckError('{"bytesValue": "AQI"}', + 'Failed to parse bytesValue field: Incorrect padding.') + self.CheckError('{"bytesValue": "AQI*"}', + 'Failed to parse bytesValue field: Incorrect padding.') + + def testInvalidMap(self): + message = json_format_proto3_pb2.TestMap() + text = '{"int32Map": {"null": 2, "2": 3}}' + self.assertRaisesRegexp( + json_format.ParseError, + 'Failed to parse int32Map field: invalid literal', + json_format.Parse, text, message) + text = '{"int32Map": {1: 2, "2": 3}}' + self.assertRaisesRegexp( + json_format.ParseError, + (r'Failed to load JSON: Expecting property name' + r'( enclosed in double quotes)?: line 1'), + json_format.Parse, text, message) + text = '{"boolMap": {"null": 1}}' + self.assertRaisesRegexp( + json_format.ParseError, + 'Failed to parse boolMap field: Expected "true" or "false", not null.', + json_format.Parse, text, message) + if sys.version_info < (2, 7): + return + text = r'{"stringMap": {"a": 3, "\u0061": 2}}' + self.assertRaisesRegexp( + json_format.ParseError, + 
'Failed to load JSON: duplicate key a', + json_format.Parse, text, message) + + def testInvalidTimestamp(self): + message = json_format_proto3_pb2.TestTimestamp() + text = '{"value": "10000-01-01T00:00:00.00Z"}' + self.assertRaisesRegexp( + json_format.ParseError, + 'time data \'10000-01-01T00:00:00\' does not match' + ' format \'%Y-%m-%dT%H:%M:%S\'.', + json_format.Parse, text, message) + text = '{"value": "1970-01-01T00:00:00.0123456789012Z"}' + self.assertRaisesRegexp( + well_known_types.ParseError, + 'nanos 0123456789012 more than 9 fractional digits.', + json_format.Parse, text, message) + text = '{"value": "1972-01-01T01:00:00.01+08"}' + self.assertRaisesRegexp( + well_known_types.ParseError, + (r'Invalid timezone offset value: \+08.'), + json_format.Parse, text, message) + # Time smaller than minimum time. + text = '{"value": "0000-01-01T00:00:00Z"}' + self.assertRaisesRegexp( + json_format.ParseError, + 'Failed to parse value field: year is out of range.', + json_format.Parse, text, message) + # Time bigger than maxinum time. 
+ message.value.seconds = 253402300800 + self.assertRaisesRegexp( + OverflowError, + 'date value out of range', + json_format.MessageToJson, message) + + def testInvalidOneof(self): + message = json_format_proto3_pb2.TestOneof() + text = '{"oneofInt32Value": 1, "oneofStringValue": "2"}' + self.assertRaisesRegexp( + json_format.ParseError, + 'Message type "proto3.TestOneof"' + ' should not have multiple "oneof_value" oneof fields.', + json_format.Parse, text, message) + + def testInvalidListValue(self): + message = json_format_proto3_pb2.TestListValue() + text = '{"value": 1234}' + self.assertRaisesRegexp( + json_format.ParseError, + r'Failed to parse value field: ListValue must be in \[\] which is 1234', + json_format.Parse, text, message) + + def testInvalidStruct(self): + message = json_format_proto3_pb2.TestStruct() + text = '{"value": 1234}' + self.assertRaisesRegexp( + json_format.ParseError, + 'Failed to parse value field: Struct must be in a dict which is 1234', + json_format.Parse, text, message) + + def testInvalidAny(self): + message = any_pb2.Any() + text = '{"@type": "type.googleapis.com/google.protobuf.Int32Value"}' + self.assertRaisesRegexp( + KeyError, + 'value', + json_format.Parse, text, message) + text = '{"value": 1234}' + self.assertRaisesRegexp( + json_format.ParseError, + '@type is missing when parsing any message.', + json_format.Parse, text, message) + text = '{"@type": "type.googleapis.com/MessageNotExist", "value": 1234}' + self.assertRaisesRegexp( + TypeError, + 'Can not find message descriptor by type_url: ' + 'type.googleapis.com/MessageNotExist.', + json_format.Parse, text, message) + # Only last part is to be used: b/25630112 + text = (r'{"@type": "incorrect.googleapis.com/google.protobuf.Int32Value",' + r'"value": 1234}') + json_format.Parse(text, message) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/google/protobuf/internal/message_factory_python_test.py 
b/python/google/protobuf/internal/message_factory_python_test.py deleted file mode 100644 index 85e02b259..000000000 --- a/python/google/protobuf/internal/message_factory_python_test.py +++ /dev/null @@ -1,54 +0,0 @@ -#! /usr/bin/python -# -# Protocol Buffers - Google's data interchange format -# Copyright 2008 Google Inc. All rights reserved. -# https://developers.google.com/protocol-buffers/ -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are -# met: -# -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above -# copyright notice, this list of conditions and the following disclaimer -# in the documentation and/or other materials provided with the -# distribution. -# * Neither the name of Google Inc. nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- -"""Tests for ..public.message_factory for the pure Python implementation.""" - -import os -os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python' - -# We must set the implementation version above before the google3 imports. -# pylint: disable=g-import-not-at-top -from google.apputils import basetest -from google.protobuf.internal import api_implementation -# Run all tests from the original module by putting them in our namespace. -# pylint: disable=wildcard-import -from google.protobuf.internal.message_factory_test import * - - -class ConfirmPurePythonTest(basetest.TestCase): - - def testImplementationSetting(self): - self.assertEqual('python', api_implementation.Type()) - - -if __name__ == '__main__': - basetest.main() diff --git a/python/google/protobuf/internal/message_factory_test.py b/python/google/protobuf/internal/message_factory_test.py index fcf134103..7bb7d1ace 100644 --- a/python/google/protobuf/internal/message_factory_test.py +++ b/python/google/protobuf/internal/message_factory_test.py @@ -1,4 +1,4 @@ -#! /usr/bin/python +#! /usr/bin/env python # # Protocol Buffers - Google's data interchange format # Copyright 2008 Google Inc. All rights reserved. 
@@ -34,7 +34,11 @@ __author__ = 'matthewtoia@google.com (Matt Toia)' -from google.apputils import basetest +try: + import unittest2 as unittest #PY26 +except ImportError: + import unittest + from google.protobuf import descriptor_pb2 from google.protobuf.internal import factory_test1_pb2 from google.protobuf.internal import factory_test2_pb2 @@ -43,7 +47,7 @@ from google.protobuf import descriptor_pool from google.protobuf import message_factory -class MessageFactoryTest(basetest.TestCase): +class MessageFactoryTest(unittest.TestCase): def setUp(self): self.factory_test1_fd = descriptor_pb2.FileDescriptorProto.FromString( @@ -81,9 +85,9 @@ class MessageFactoryTest(basetest.TestCase): serialized = msg.SerializeToString() converted = factory_test2_pb2.Factory2Message.FromString(serialized) reserialized = converted.SerializeToString() - self.assertEquals(serialized, reserialized) + self.assertEqual(serialized, reserialized) result = cls.FromString(reserialized) - self.assertEquals(msg, result) + self.assertEqual(msg, result) def testGetPrototype(self): db = descriptor_database.DescriptorDatabase() @@ -93,28 +97,29 @@ class MessageFactoryTest(basetest.TestCase): factory = message_factory.MessageFactory() cls = factory.GetPrototype(pool.FindMessageTypeByName( 'google.protobuf.python.internal.Factory2Message')) - self.assertIsNot(cls, factory_test2_pb2.Factory2Message) + self.assertFalse(cls is factory_test2_pb2.Factory2Message) self._ExerciseDynamicClass(cls) cls2 = factory.GetPrototype(pool.FindMessageTypeByName( 'google.protobuf.python.internal.Factory2Message')) - self.assertIs(cls, cls2) + self.assertTrue(cls is cls2) def testGetMessages(self): # performed twice because multiple calls with the same input must be allowed for _ in range(2): - messages = message_factory.GetMessages([self.factory_test2_fd, - self.factory_test1_fd]) - self.assertContainsSubset( - ['google.protobuf.python.internal.Factory2Message', - 'google.protobuf.python.internal.Factory1Message'], - 
messages.keys()) + messages = message_factory.GetMessages([self.factory_test1_fd, + self.factory_test2_fd]) + self.assertTrue( + set(['google.protobuf.python.internal.Factory2Message', + 'google.protobuf.python.internal.Factory1Message'], + ).issubset(set(messages.keys()))) self._ExerciseDynamicClass( messages['google.protobuf.python.internal.Factory2Message']) - self.assertContainsSubset( - ['google.protobuf.python.internal.Factory2Message.one_more_field', - 'google.protobuf.python.internal.another_field'], - (messages['google.protobuf.python.internal.Factory1Message'] - ._extensions_by_name.keys())) + self.assertTrue( + set(['google.protobuf.python.internal.Factory2Message.one_more_field', + 'google.protobuf.python.internal.another_field'], + ).issubset( + set(messages['google.protobuf.python.internal.Factory1Message'] + ._extensions_by_name.keys()))) factory_msg1 = messages['google.protobuf.python.internal.Factory1Message'] msg1 = messages['google.protobuf.python.internal.Factory1Message']() ext1 = factory_msg1._extensions_by_name[ @@ -123,9 +128,63 @@ class MessageFactoryTest(basetest.TestCase): 'google.protobuf.python.internal.another_field'] msg1.Extensions[ext1] = 'test1' msg1.Extensions[ext2] = 'test2' - self.assertEquals('test1', msg1.Extensions[ext1]) - self.assertEquals('test2', msg1.Extensions[ext2]) + self.assertEqual('test1', msg1.Extensions[ext1]) + self.assertEqual('test2', msg1.Extensions[ext2]) + + def testDuplicateExtensionNumber(self): + pool = descriptor_pool.DescriptorPool() + factory = message_factory.MessageFactory(pool=pool) + + # Add Container message. 
+ f = descriptor_pb2.FileDescriptorProto() + f.name = 'google/protobuf/internal/container.proto' + f.package = 'google.protobuf.python.internal' + msg = f.message_type.add() + msg.name = 'Container' + rng = msg.extension_range.add() + rng.start = 1 + rng.end = 10 + pool.Add(f) + msgs = factory.GetMessages([f.name]) + self.assertIn('google.protobuf.python.internal.Container', msgs) + + # Extend container. + f = descriptor_pb2.FileDescriptorProto() + f.name = 'google/protobuf/internal/extension.proto' + f.package = 'google.protobuf.python.internal' + f.dependency.append('google/protobuf/internal/container.proto') + msg = f.message_type.add() + msg.name = 'Extension' + ext = msg.extension.add() + ext.name = 'extension_field' + ext.number = 2 + ext.label = descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL + ext.type_name = 'Extension' + ext.extendee = 'Container' + pool.Add(f) + msgs = factory.GetMessages([f.name]) + self.assertIn('google.protobuf.python.internal.Extension', msgs) + + # Add Duplicate extending the same field number. 
+ f = descriptor_pb2.FileDescriptorProto() + f.name = 'google/protobuf/internal/duplicate.proto' + f.package = 'google.protobuf.python.internal' + f.dependency.append('google/protobuf/internal/container.proto') + msg = f.message_type.add() + msg.name = 'Duplicate' + ext = msg.extension.add() + ext.name = 'extension_field' + ext.number = 2 + ext.label = descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL + ext.type_name = 'Duplicate' + ext.extendee = 'Container' + pool.Add(f) + + with self.assertRaises(Exception) as cm: + factory.GetMessages([f.name]) + + self.assertIsInstance(cm.exception, (AssertionError, ValueError)) if __name__ == '__main__': - basetest.main() + unittest.main() diff --git a/python/google/protobuf/internal/message_set_extensions.proto b/python/google/protobuf/internal/message_set_extensions.proto new file mode 100644 index 000000000..14e5f1937 --- /dev/null +++ b/python/google/protobuf/internal/message_set_extensions.proto @@ -0,0 +1,74 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. 
+// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// This file contains messages that extend MessageSet. + +syntax = "proto2"; +package google.protobuf.internal; + + +// A message with message_set_wire_format. +message TestMessageSet { + option message_set_wire_format = true; + extensions 4 to max; +} + +message TestMessageSetExtension1 { + extend TestMessageSet { + optional TestMessageSetExtension1 message_set_extension = 98418603; + } + optional int32 i = 15; +} + +message TestMessageSetExtension2 { + extend TestMessageSet { + optional TestMessageSetExtension2 message_set_extension = 98418634; + } + optional string str = 25; +} + +message TestMessageSetExtension3 { + optional string text = 35; +} + +extend TestMessageSet { + optional TestMessageSetExtension3 message_set_extension3 = 98418655; +} + +// This message was used to generate +// //net/proto2/python/internal/testdata/message_set_message, but is commented +// out since it must not actually exist in code, to simulate an "unknown" +// extension. 
+// message TestMessageSetUnknownExtension { +// extend TestMessageSet { +// optional TestMessageSetUnknownExtension message_set_extension = 56141421; +// } +// optional int64 a = 1; +// } diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py index 48b7ffd4e..4ee31d8ed 100755 --- a/python/google/protobuf/internal/message_test.py +++ b/python/google/protobuf/internal/message_test.py @@ -1,4 +1,4 @@ -#! /usr/bin/python +#! /usr/bin/env python # # Protocol Buffers - Google's data interchange format # Copyright 2008 Google Inc. All rights reserved. @@ -43,17 +43,36 @@ abstract interface. __author__ = 'gps@google.com (Gregory P. Smith)' + +import collections import copy import math import operator import pickle +import six import sys -from google.apputils import basetest +try: + import unittest2 as unittest #PY26 +except ImportError: + import unittest + +from google.protobuf import map_unittest_pb2 from google.protobuf import unittest_pb2 +from google.protobuf import unittest_proto3_arena_pb2 +from google.protobuf import descriptor_pb2 +from google.protobuf import descriptor_pool +from google.protobuf import message_factory +from google.protobuf import text_format from google.protobuf.internal import api_implementation +from google.protobuf.internal import packed_field_test_pb2 from google.protobuf.internal import test_util from google.protobuf import message +from google.protobuf.internal import _parameterized + +if six.PY3: + long = int + # Python pre-2.6 does not have isinf() or isnan() functions, so we have # to provide our own. 
@@ -69,88 +88,72 @@ def IsNegInf(val): return isinf(val) and (val < 0) -class MessageTest(basetest.TestCase): +@_parameterized.Parameters( + (unittest_pb2), + (unittest_proto3_arena_pb2)) +class MessageTest(unittest.TestCase): - def testBadUtf8String(self): + def testBadUtf8String(self, message_module): if api_implementation.Type() != 'python': self.skipTest("Skipping testBadUtf8String, currently only the python " "api implementation raises UnicodeDecodeError when a " "string field contains bad utf-8.") bad_utf8_data = test_util.GoldenFileData('bad_utf8_string') with self.assertRaises(UnicodeDecodeError) as context: - unittest_pb2.TestAllTypes.FromString(bad_utf8_data) - self.assertIn('field: protobuf_unittest.TestAllTypes.optional_string', - str(context.exception)) - - def testGoldenMessage(self): - golden_data = test_util.GoldenFileData( - 'golden_message_oneof_implemented') - golden_message = unittest_pb2.TestAllTypes() - golden_message.ParseFromString(golden_data) - test_util.ExpectAllFieldsSet(self, golden_message) - self.assertEqual(golden_data, golden_message.SerializeToString()) - golden_copy = copy.deepcopy(golden_message) - self.assertEqual(golden_data, golden_copy.SerializeToString()) + message_module.TestAllTypes.FromString(bad_utf8_data) + self.assertIn('TestAllTypes.optional_string', str(context.exception)) + + def testGoldenMessage(self, message_module): + # Proto3 doesn't have the "default_foo" members or foreign enums, + # and doesn't preserve unknown fields, so for proto3 we use a golden + # message that doesn't have these fields set. 
+ if message_module is unittest_pb2: + golden_data = test_util.GoldenFileData( + 'golden_message_oneof_implemented') + else: + golden_data = test_util.GoldenFileData('golden_message_proto3') - def testGoldenExtensions(self): - golden_data = test_util.GoldenFileData('golden_message') - golden_message = unittest_pb2.TestAllExtensions() + golden_message = message_module.TestAllTypes() golden_message.ParseFromString(golden_data) - all_set = unittest_pb2.TestAllExtensions() - test_util.SetAllExtensions(all_set) - self.assertEquals(all_set, golden_message) + if message_module is unittest_pb2: + test_util.ExpectAllFieldsSet(self, golden_message) self.assertEqual(golden_data, golden_message.SerializeToString()) golden_copy = copy.deepcopy(golden_message) self.assertEqual(golden_data, golden_copy.SerializeToString()) - def testGoldenPackedMessage(self): + def testGoldenPackedMessage(self, message_module): golden_data = test_util.GoldenFileData('golden_packed_fields_message') - golden_message = unittest_pb2.TestPackedTypes() + golden_message = message_module.TestPackedTypes() golden_message.ParseFromString(golden_data) - all_set = unittest_pb2.TestPackedTypes() + all_set = message_module.TestPackedTypes() test_util.SetAllPackedFields(all_set) - self.assertEquals(all_set, golden_message) + self.assertEqual(all_set, golden_message) self.assertEqual(golden_data, all_set.SerializeToString()) golden_copy = copy.deepcopy(golden_message) self.assertEqual(golden_data, golden_copy.SerializeToString()) - def testGoldenPackedExtensions(self): - golden_data = test_util.GoldenFileData('golden_packed_fields_message') - golden_message = unittest_pb2.TestPackedExtensions() - golden_message.ParseFromString(golden_data) - all_set = unittest_pb2.TestPackedExtensions() - test_util.SetAllPackedExtensions(all_set) - self.assertEquals(all_set, golden_message) - self.assertEqual(golden_data, all_set.SerializeToString()) - golden_copy = copy.deepcopy(golden_message) - self.assertEqual(golden_data, 
golden_copy.SerializeToString()) - - def testPickleSupport(self): + def testPickleSupport(self, message_module): golden_data = test_util.GoldenFileData('golden_message') - golden_message = unittest_pb2.TestAllTypes() + golden_message = message_module.TestAllTypes() golden_message.ParseFromString(golden_data) pickled_message = pickle.dumps(golden_message) unpickled_message = pickle.loads(pickled_message) - self.assertEquals(unpickled_message, golden_message) - - - def testPickleIncompleteProto(self): - golden_message = unittest_pb2.TestRequired(a=1) - pickled_message = pickle.dumps(golden_message) - - unpickled_message = pickle.loads(pickled_message) - self.assertEquals(unpickled_message, golden_message) - self.assertEquals(unpickled_message.a, 1) - # This is still an incomplete proto - so serializing should fail - self.assertRaises(message.EncodeError, unpickled_message.SerializeToString) + self.assertEqual(unpickled_message, golden_message) + + def testPositiveInfinity(self, message_module): + if message_module is unittest_pb2: + golden_data = (b'\x5D\x00\x00\x80\x7F' + b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F' + b'\xCD\x02\x00\x00\x80\x7F' + b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\x7F') + else: + golden_data = (b'\x5D\x00\x00\x80\x7F' + b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F' + b'\xCA\x02\x04\x00\x00\x80\x7F' + b'\xD2\x02\x08\x00\x00\x00\x00\x00\x00\xF0\x7F') - def testPositiveInfinity(self): - golden_data = (b'\x5D\x00\x00\x80\x7F' - b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F' - b'\xCD\x02\x00\x00\x80\x7F' - b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\x7F') - golden_message = unittest_pb2.TestAllTypes() + golden_message = message_module.TestAllTypes() golden_message.ParseFromString(golden_data) self.assertTrue(IsPosInf(golden_message.optional_float)) self.assertTrue(IsPosInf(golden_message.optional_double)) @@ -158,12 +161,19 @@ class MessageTest(basetest.TestCase): self.assertTrue(IsPosInf(golden_message.repeated_double[0])) self.assertEqual(golden_data, 
golden_message.SerializeToString()) - def testNegativeInfinity(self): - golden_data = (b'\x5D\x00\x00\x80\xFF' - b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF' - b'\xCD\x02\x00\x00\x80\xFF' - b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\xFF') - golden_message = unittest_pb2.TestAllTypes() + def testNegativeInfinity(self, message_module): + if message_module is unittest_pb2: + golden_data = (b'\x5D\x00\x00\x80\xFF' + b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF' + b'\xCD\x02\x00\x00\x80\xFF' + b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\xFF') + else: + golden_data = (b'\x5D\x00\x00\x80\xFF' + b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF' + b'\xCA\x02\x04\x00\x00\x80\xFF' + b'\xD2\x02\x08\x00\x00\x00\x00\x00\x00\xF0\xFF') + + golden_message = message_module.TestAllTypes() golden_message.ParseFromString(golden_data) self.assertTrue(IsNegInf(golden_message.optional_float)) self.assertTrue(IsNegInf(golden_message.optional_double)) @@ -171,12 +181,12 @@ class MessageTest(basetest.TestCase): self.assertTrue(IsNegInf(golden_message.repeated_double[0])) self.assertEqual(golden_data, golden_message.SerializeToString()) - def testNotANumber(self): + def testNotANumber(self, message_module): golden_data = (b'\x5D\x00\x00\xC0\x7F' b'\x61\x00\x00\x00\x00\x00\x00\xF8\x7F' b'\xCD\x02\x00\x00\xC0\x7F' b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF8\x7F') - golden_message = unittest_pb2.TestAllTypes() + golden_message = message_module.TestAllTypes() golden_message.ParseFromString(golden_data) self.assertTrue(isnan(golden_message.optional_float)) self.assertTrue(isnan(golden_message.optional_double)) @@ -188,47 +198,47 @@ class MessageTest(basetest.TestCase): # verify the serialized string can be converted into a correctly # behaving protocol buffer. 
serialized = golden_message.SerializeToString() - message = unittest_pb2.TestAllTypes() + message = message_module.TestAllTypes() message.ParseFromString(serialized) self.assertTrue(isnan(message.optional_float)) self.assertTrue(isnan(message.optional_double)) self.assertTrue(isnan(message.repeated_float[0])) self.assertTrue(isnan(message.repeated_double[0])) - def testPositiveInfinityPacked(self): + def testPositiveInfinityPacked(self, message_module): golden_data = (b'\xA2\x06\x04\x00\x00\x80\x7F' b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\x7F') - golden_message = unittest_pb2.TestPackedTypes() + golden_message = message_module.TestPackedTypes() golden_message.ParseFromString(golden_data) self.assertTrue(IsPosInf(golden_message.packed_float[0])) self.assertTrue(IsPosInf(golden_message.packed_double[0])) self.assertEqual(golden_data, golden_message.SerializeToString()) - def testNegativeInfinityPacked(self): + def testNegativeInfinityPacked(self, message_module): golden_data = (b'\xA2\x06\x04\x00\x00\x80\xFF' b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\xFF') - golden_message = unittest_pb2.TestPackedTypes() + golden_message = message_module.TestPackedTypes() golden_message.ParseFromString(golden_data) self.assertTrue(IsNegInf(golden_message.packed_float[0])) self.assertTrue(IsNegInf(golden_message.packed_double[0])) self.assertEqual(golden_data, golden_message.SerializeToString()) - def testNotANumberPacked(self): + def testNotANumberPacked(self, message_module): golden_data = (b'\xA2\x06\x04\x00\x00\xC0\x7F' b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF8\x7F') - golden_message = unittest_pb2.TestPackedTypes() + golden_message = message_module.TestPackedTypes() golden_message.ParseFromString(golden_data) self.assertTrue(isnan(golden_message.packed_float[0])) self.assertTrue(isnan(golden_message.packed_double[0])) serialized = golden_message.SerializeToString() - message = unittest_pb2.TestPackedTypes() + message = message_module.TestPackedTypes() 
message.ParseFromString(serialized) self.assertTrue(isnan(message.packed_float[0])) self.assertTrue(isnan(message.packed_double[0])) - def testExtremeFloatValues(self): - message = unittest_pb2.TestAllTypes() + def testExtremeFloatValues(self, message_module): + message = message_module.TestAllTypes() # Most positive exponent, no significand bits set. kMostPosExponentNoSigBits = math.pow(2, 127) @@ -272,8 +282,8 @@ class MessageTest(basetest.TestCase): message.ParseFromString(message.SerializeToString()) self.assertTrue(message.optional_float == -kMostNegExponentOneSigBit) - def testExtremeDoubleValues(self): - message = unittest_pb2.TestAllTypes() + def testExtremeDoubleValues(self, message_module): + message = message_module.TestAllTypes() # Most positive exponent, no significand bits set. kMostPosExponentNoSigBits = math.pow(2, 1023) @@ -317,29 +327,43 @@ class MessageTest(basetest.TestCase): message.ParseFromString(message.SerializeToString()) self.assertTrue(message.optional_double == -kMostNegExponentOneSigBit) - def testFloatPrinting(self): - message = unittest_pb2.TestAllTypes() + def testFloatPrinting(self, message_module): + message = message_module.TestAllTypes() message.optional_float = 2.0 self.assertEqual(str(message), 'optional_float: 2.0\n') - def testHighPrecisionFloatPrinting(self): - message = unittest_pb2.TestAllTypes() + def testHighPrecisionFloatPrinting(self, message_module): + message = message_module.TestAllTypes() message.optional_double = 0.12345678912345678 - if sys.version_info.major >= 3: + if sys.version_info >= (3,): self.assertEqual(str(message), 'optional_double: 0.12345678912345678\n') else: self.assertEqual(str(message), 'optional_double: 0.123456789123\n') - def testUnknownFieldPrinting(self): - populated = unittest_pb2.TestAllTypes() + def testUnknownFieldPrinting(self, message_module): + populated = message_module.TestAllTypes() test_util.SetAllNonLazyFields(populated) - empty = unittest_pb2.TestEmptyMessage() + empty = 
message_module.TestEmptyMessage() empty.ParseFromString(populated.SerializeToString()) self.assertEqual(str(empty), '') - def testSortingRepeatedScalarFieldsDefaultComparator(self): + def testRepeatedNestedFieldIteration(self, message_module): + msg = message_module.TestAllTypes() + msg.repeated_nested_message.add(bb=1) + msg.repeated_nested_message.add(bb=2) + msg.repeated_nested_message.add(bb=3) + msg.repeated_nested_message.add(bb=4) + + self.assertEqual([1, 2, 3, 4], + [m.bb for m in msg.repeated_nested_message]) + self.assertEqual([4, 3, 2, 1], + [m.bb for m in reversed(msg.repeated_nested_message)]) + self.assertEqual([4, 3, 2, 1], + [m.bb for m in msg.repeated_nested_message[::-1]]) + + def testSortingRepeatedScalarFieldsDefaultComparator(self, message_module): """Check some different types with the default comparator.""" - message = unittest_pb2.TestAllTypes() + message = message_module.TestAllTypes() # TODO(mattp): would testing more scalar types strengthen test? message.repeated_int32.append(1) @@ -374,9 +398,9 @@ class MessageTest(basetest.TestCase): self.assertEqual(message.repeated_bytes[1], b'b') self.assertEqual(message.repeated_bytes[2], b'c') - def testSortingRepeatedScalarFieldsCustomComparator(self): + def testSortingRepeatedScalarFieldsCustomComparator(self, message_module): """Check some different types with custom comparator.""" - message = unittest_pb2.TestAllTypes() + message = message_module.TestAllTypes() message.repeated_int32.append(-3) message.repeated_int32.append(-2) @@ -394,9 +418,9 @@ class MessageTest(basetest.TestCase): self.assertEqual(message.repeated_string[1], 'bb') self.assertEqual(message.repeated_string[2], 'aaa') - def testSortingRepeatedCompositeFieldsCustomComparator(self): + def testSortingRepeatedCompositeFieldsCustomComparator(self, message_module): """Check passing a custom comparator to sort a repeated composite field.""" - message = unittest_pb2.TestAllTypes() + message = message_module.TestAllTypes() 
message.repeated_nested_message.add().bb = 1 message.repeated_nested_message.add().bb = 3 @@ -412,9 +436,34 @@ class MessageTest(basetest.TestCase): self.assertEqual(message.repeated_nested_message[4].bb, 5) self.assertEqual(message.repeated_nested_message[5].bb, 6) - def testRepeatedCompositeFieldSortArguments(self): + def testSortingRepeatedCompositeFieldsStable(self, message_module): + """Check that sorting a repeated composite field is stable.""" + message = message_module.TestAllTypes() + + message.repeated_nested_message.add().bb = 21 + message.repeated_nested_message.add().bb = 20 + message.repeated_nested_message.add().bb = 13 + message.repeated_nested_message.add().bb = 33 + message.repeated_nested_message.add().bb = 11 + message.repeated_nested_message.add().bb = 24 + message.repeated_nested_message.add().bb = 10 + message.repeated_nested_message.sort(key=lambda z: z.bb // 10) + self.assertEqual( + [13, 11, 10, 21, 20, 24, 33], + [n.bb for n in message.repeated_nested_message]) + + # Make sure that for the C++ implementation, the underlying fields + # are actually reordered. + pb = message.SerializeToString() + message.Clear() + message.MergeFromString(pb) + self.assertEqual( + [13, 11, 10, 21, 20, 24, 33], + [n.bb for n in message.repeated_nested_message]) + + def testRepeatedCompositeFieldSortArguments(self, message_module): """Check sorting a repeated composite field using list.sort() arguments.""" - message = unittest_pb2.TestAllTypes() + message = message_module.TestAllTypes() get_bb = operator.attrgetter('bb') cmp_bb = lambda a, b: cmp(a.bb, b.bb) @@ -430,7 +479,7 @@ class MessageTest(basetest.TestCase): message.repeated_nested_message.sort(key=get_bb, reverse=True) self.assertEqual([k.bb for k in message.repeated_nested_message], [6, 5, 4, 3, 2, 1]) - if sys.version_info.major >= 3: return # No cmp sorting in PY3. + if sys.version_info >= (3,): return # No cmp sorting in PY3. 
message.repeated_nested_message.sort(sort_function=cmp_bb) self.assertEqual([k.bb for k in message.repeated_nested_message], [1, 2, 3, 4, 5, 6]) @@ -438,9 +487,9 @@ class MessageTest(basetest.TestCase): self.assertEqual([k.bb for k in message.repeated_nested_message], [6, 5, 4, 3, 2, 1]) - def testRepeatedScalarFieldSortArguments(self): + def testRepeatedScalarFieldSortArguments(self, message_module): """Check sorting a scalar field using list.sort() arguments.""" - message = unittest_pb2.TestAllTypes() + message = message_module.TestAllTypes() message.repeated_int32.append(-3) message.repeated_int32.append(-2) @@ -449,7 +498,7 @@ class MessageTest(basetest.TestCase): self.assertEqual(list(message.repeated_int32), [-1, -2, -3]) message.repeated_int32.sort(key=abs, reverse=True) self.assertEqual(list(message.repeated_int32), [-3, -2, -1]) - if sys.version_info.major < 3: # No cmp sorting in PY3. + if sys.version_info < (3,): # No cmp sorting in PY3. abs_cmp = lambda a, b: cmp(abs(a), abs(b)) message.repeated_int32.sort(sort_function=abs_cmp) self.assertEqual(list(message.repeated_int32), [-1, -2, -3]) @@ -463,16 +512,16 @@ class MessageTest(basetest.TestCase): self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa']) message.repeated_string.sort(key=len, reverse=True) self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c']) - if sys.version_info.major < 3: # No cmp sorting in PY3. + if sys.version_info < (3,): # No cmp sorting in PY3. 
len_cmp = lambda a, b: cmp(len(a), len(b)) message.repeated_string.sort(sort_function=len_cmp) self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa']) message.repeated_string.sort(cmp=len_cmp, reverse=True) self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c']) - def testRepeatedFieldsComparable(self): - m1 = unittest_pb2.TestAllTypes() - m2 = unittest_pb2.TestAllTypes() + def testRepeatedFieldsComparable(self, message_module): + m1 = message_module.TestAllTypes() + m2 = message_module.TestAllTypes() m1.repeated_int32.append(0) m1.repeated_int32.append(1) m1.repeated_int32.append(2) @@ -486,7 +535,7 @@ class MessageTest(basetest.TestCase): m2.repeated_nested_message.add().bb = 2 m2.repeated_nested_message.add().bb = 3 - if sys.version_info.major >= 3: return # No cmp() in PY3. + if sys.version_info >= (3,): return # No cmp() in PY3. # These comparisons should not raise errors. _ = m1 < m2 @@ -505,54 +554,11 @@ class MessageTest(basetest.TestCase): # TODO(anuraag): Implement extensiondict comparison in C++ and then add test - def testParsingMerge(self): - """Check the merge behavior when a required or optional field appears - multiple times in the input.""" - messages = [ - unittest_pb2.TestAllTypes(), - unittest_pb2.TestAllTypes(), - unittest_pb2.TestAllTypes() ] - messages[0].optional_int32 = 1 - messages[1].optional_int64 = 2 - messages[2].optional_int32 = 3 - messages[2].optional_string = 'hello' - - merged_message = unittest_pb2.TestAllTypes() - merged_message.optional_int32 = 3 - merged_message.optional_int64 = 2 - merged_message.optional_string = 'hello' - - generator = unittest_pb2.TestParsingMerge.RepeatedFieldsGenerator() - generator.field1.extend(messages) - generator.field2.extend(messages) - generator.field3.extend(messages) - generator.ext1.extend(messages) - generator.ext2.extend(messages) - generator.group1.add().field1.MergeFrom(messages[0]) - generator.group1.add().field1.MergeFrom(messages[1]) - 
generator.group1.add().field1.MergeFrom(messages[2]) - generator.group2.add().field1.MergeFrom(messages[0]) - generator.group2.add().field1.MergeFrom(messages[1]) - generator.group2.add().field1.MergeFrom(messages[2]) - - data = generator.SerializeToString() - parsing_merge = unittest_pb2.TestParsingMerge() - parsing_merge.ParseFromString(data) - - # Required and optional fields should be merged. - self.assertEqual(parsing_merge.required_all_types, merged_message) - self.assertEqual(parsing_merge.optional_all_types, merged_message) - self.assertEqual(parsing_merge.optionalgroup.optional_group_all_types, - merged_message) - self.assertEqual(parsing_merge.Extensions[ - unittest_pb2.TestParsingMerge.optional_ext], - merged_message) - - # Repeated fields should not be merged. - self.assertEqual(len(parsing_merge.repeated_all_types), 3) - self.assertEqual(len(parsing_merge.repeatedgroup), 3) - self.assertEqual(len(parsing_merge.Extensions[ - unittest_pb2.TestParsingMerge.repeated_ext]), 3) + def testRepeatedFieldsAreSequences(self, message_module): + m = message_module.TestAllTypes() + self.assertIsInstance(m.repeated_int32, collections.MutableSequence) + self.assertIsInstance(m.repeated_nested_message, + collections.MutableSequence) def ensureNestedMessageExists(self, msg, attribute): """Make sure that a nested message object exists. 
@@ -563,12 +569,28 @@ class MessageTest(basetest.TestCase): getattr(msg, attribute) self.assertFalse(msg.HasField(attribute)) - def testOneofGetCaseNonexistingField(self): - m = unittest_pb2.TestAllTypes() + def testOneofGetCaseNonexistingField(self, message_module): + m = message_module.TestAllTypes() self.assertRaises(ValueError, m.WhichOneof, 'no_such_oneof_field') - def testOneofSemantics(self): - m = unittest_pb2.TestAllTypes() + def testOneofDefaultValues(self, message_module): + m = message_module.TestAllTypes() + self.assertIs(None, m.WhichOneof('oneof_field')) + self.assertFalse(m.HasField('oneof_uint32')) + + # Oneof is set even when setting it to a default value. + m.oneof_uint32 = 0 + self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field')) + self.assertTrue(m.HasField('oneof_uint32')) + self.assertFalse(m.HasField('oneof_string')) + + m.oneof_string = "" + self.assertEqual('oneof_string', m.WhichOneof('oneof_field')) + self.assertTrue(m.HasField('oneof_string')) + self.assertFalse(m.HasField('oneof_uint32')) + + def testOneofSemantics(self, message_module): + m = message_module.TestAllTypes() self.assertIs(None, m.WhichOneof('oneof_field')) m.oneof_uint32 = 11 @@ -580,6 +602,18 @@ class MessageTest(basetest.TestCase): self.assertFalse(m.HasField('oneof_uint32')) self.assertTrue(m.HasField('oneof_string')) + # Read nested message accessor without accessing submessage. + m.oneof_nested_message + self.assertEqual('oneof_string', m.WhichOneof('oneof_field')) + self.assertTrue(m.HasField('oneof_string')) + self.assertFalse(m.HasField('oneof_nested_message')) + + # Read accessor of nested message without accessing submessage. 
+ m.oneof_nested_message.bb + self.assertEqual('oneof_string', m.WhichOneof('oneof_field')) + self.assertTrue(m.HasField('oneof_string')) + self.assertFalse(m.HasField('oneof_nested_message')) + m.oneof_nested_message.bb = 11 self.assertEqual('oneof_nested_message', m.WhichOneof('oneof_field')) self.assertFalse(m.HasField('oneof_string')) @@ -590,72 +624,1099 @@ class MessageTest(basetest.TestCase): self.assertFalse(m.HasField('oneof_nested_message')) self.assertTrue(m.HasField('oneof_bytes')) - def testOneofCompositeFieldReadAccess(self): - m = unittest_pb2.TestAllTypes() + def testOneofCompositeFieldReadAccess(self, message_module): + m = message_module.TestAllTypes() m.oneof_uint32 = 11 self.ensureNestedMessageExists(m, 'oneof_nested_message') self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field')) self.assertEqual(11, m.oneof_uint32) - def testOneofHasField(self): - m = unittest_pb2.TestAllTypes() - self.assertFalse(m.HasField('oneof_field')) + def testOneofWhichOneof(self, message_module): + m = message_module.TestAllTypes() + self.assertIs(None, m.WhichOneof('oneof_field')) + if message_module is unittest_pb2: + self.assertFalse(m.HasField('oneof_field')) + m.oneof_uint32 = 11 - self.assertTrue(m.HasField('oneof_field')) + self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field')) + if message_module is unittest_pb2: + self.assertTrue(m.HasField('oneof_field')) + m.oneof_bytes = b'bb' - self.assertTrue(m.HasField('oneof_field')) + self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field')) + m.ClearField('oneof_bytes') - self.assertFalse(m.HasField('oneof_field')) + self.assertIs(None, m.WhichOneof('oneof_field')) + if message_module is unittest_pb2: + self.assertFalse(m.HasField('oneof_field')) - def testOneofClearField(self): - m = unittest_pb2.TestAllTypes() + def testOneofClearField(self, message_module): + m = message_module.TestAllTypes() m.oneof_uint32 = 11 m.ClearField('oneof_field') - self.assertFalse(m.HasField('oneof_field')) + if 
message_module is unittest_pb2: + self.assertFalse(m.HasField('oneof_field')) self.assertFalse(m.HasField('oneof_uint32')) self.assertIs(None, m.WhichOneof('oneof_field')) - def testOneofClearSetField(self): - m = unittest_pb2.TestAllTypes() + def testOneofClearSetField(self, message_module): + m = message_module.TestAllTypes() m.oneof_uint32 = 11 m.ClearField('oneof_uint32') - self.assertFalse(m.HasField('oneof_field')) + if message_module is unittest_pb2: + self.assertFalse(m.HasField('oneof_field')) self.assertFalse(m.HasField('oneof_uint32')) self.assertIs(None, m.WhichOneof('oneof_field')) - def testOneofClearUnsetField(self): - m = unittest_pb2.TestAllTypes() + def testOneofClearUnsetField(self, message_module): + m = message_module.TestAllTypes() m.oneof_uint32 = 11 self.ensureNestedMessageExists(m, 'oneof_nested_message') m.ClearField('oneof_nested_message') self.assertEqual(11, m.oneof_uint32) - self.assertTrue(m.HasField('oneof_field')) + if message_module is unittest_pb2: + self.assertTrue(m.HasField('oneof_field')) self.assertTrue(m.HasField('oneof_uint32')) self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field')) - def testOneofDeserialize(self): - m = unittest_pb2.TestAllTypes() + def testOneofDeserialize(self, message_module): + m = message_module.TestAllTypes() m.oneof_uint32 = 11 - m2 = unittest_pb2.TestAllTypes() + m2 = message_module.TestAllTypes() m2.ParseFromString(m.SerializeToString()) self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field')) - def testSortEmptyRepeatedCompositeContainer(self): + def testOneofCopyFrom(self, message_module): + m = message_module.TestAllTypes() + m.oneof_uint32 = 11 + m2 = message_module.TestAllTypes() + m2.CopyFrom(m) + self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field')) + + def testOneofNestedMergeFrom(self, message_module): + m = message_module.NestedTestAllTypes() + m.payload.oneof_uint32 = 11 + m2 = message_module.NestedTestAllTypes() + m2.payload.oneof_bytes = b'bb' + 
m2.child.payload.oneof_bytes = b'bb' + m2.MergeFrom(m) + self.assertEqual('oneof_uint32', m2.payload.WhichOneof('oneof_field')) + self.assertEqual('oneof_bytes', m2.child.payload.WhichOneof('oneof_field')) + + def testOneofMessageMergeFrom(self, message_module): + m = message_module.NestedTestAllTypes() + m.payload.oneof_nested_message.bb = 11 + m.child.payload.oneof_nested_message.bb = 12 + m2 = message_module.NestedTestAllTypes() + m2.payload.oneof_uint32 = 13 + m2.MergeFrom(m) + self.assertEqual('oneof_nested_message', + m2.payload.WhichOneof('oneof_field')) + self.assertEqual('oneof_nested_message', + m2.child.payload.WhichOneof('oneof_field')) + + def testOneofNestedMessageInit(self, message_module): + m = message_module.TestAllTypes( + oneof_nested_message=message_module.TestAllTypes.NestedMessage()) + self.assertEqual('oneof_nested_message', m.WhichOneof('oneof_field')) + + def testOneofClear(self, message_module): + m = message_module.TestAllTypes() + m.oneof_uint32 = 11 + m.Clear() + self.assertIsNone(m.WhichOneof('oneof_field')) + m.oneof_bytes = b'bb' + self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field')) + + def testAssignByteStringToUnicodeField(self, message_module): + """Assigning a byte string to a string field should result + in the value being converted to a Unicode string.""" + m = message_module.TestAllTypes() + m.optional_string = str('') + self.assertIsInstance(m.optional_string, six.text_type) + + def testLongValuedSlice(self, message_module): + """It should be possible to use long-valued indicies in slices + + This didn't used to work in the v2 C++ implementation. 
+ """ + m = message_module.TestAllTypes() + + # Repeated scalar + m.repeated_int32.append(1) + sl = m.repeated_int32[long(0):long(len(m.repeated_int32))] + self.assertEqual(len(m.repeated_int32), len(sl)) + + # Repeated composite + m.repeated_nested_message.add().bb = 3 + sl = m.repeated_nested_message[long(0):long(len(m.repeated_nested_message))] + self.assertEqual(len(m.repeated_nested_message), len(sl)) + + def testExtendShouldNotSwallowExceptions(self, message_module): + """This didn't use to work in the v2 C++ implementation.""" + m = message_module.TestAllTypes() + with self.assertRaises(NameError) as _: + m.repeated_int32.extend(a for i in range(10)) # pylint: disable=undefined-variable + with self.assertRaises(NameError) as _: + m.repeated_nested_enum.extend( + a for i in range(10)) # pylint: disable=undefined-variable + + FALSY_VALUES = [None, False, 0, 0.0, b'', u'', bytearray(), [], {}, set()] + + def testExtendInt32WithNothing(self, message_module): + """Test no-ops extending repeated int32 fields.""" + m = message_module.TestAllTypes() + self.assertSequenceEqual([], m.repeated_int32) + + # TODO(ptucker): Deprecate this behavior. b/18413862 + for falsy_value in MessageTest.FALSY_VALUES: + m.repeated_int32.extend(falsy_value) + self.assertSequenceEqual([], m.repeated_int32) + + m.repeated_int32.extend([]) + self.assertSequenceEqual([], m.repeated_int32) + + def testExtendFloatWithNothing(self, message_module): + """Test no-ops extending repeated float fields.""" + m = message_module.TestAllTypes() + self.assertSequenceEqual([], m.repeated_float) + + # TODO(ptucker): Deprecate this behavior. 
b/18413862 + for falsy_value in MessageTest.FALSY_VALUES: + m.repeated_float.extend(falsy_value) + self.assertSequenceEqual([], m.repeated_float) + + m.repeated_float.extend([]) + self.assertSequenceEqual([], m.repeated_float) + + def testExtendStringWithNothing(self, message_module): + """Test no-ops extending repeated string fields.""" + m = message_module.TestAllTypes() + self.assertSequenceEqual([], m.repeated_string) + + # TODO(ptucker): Deprecate this behavior. b/18413862 + for falsy_value in MessageTest.FALSY_VALUES: + m.repeated_string.extend(falsy_value) + self.assertSequenceEqual([], m.repeated_string) + + m.repeated_string.extend([]) + self.assertSequenceEqual([], m.repeated_string) + + def testExtendInt32WithPythonList(self, message_module): + """Test extending repeated int32 fields with python lists.""" + m = message_module.TestAllTypes() + self.assertSequenceEqual([], m.repeated_int32) + m.repeated_int32.extend([0]) + self.assertSequenceEqual([0], m.repeated_int32) + m.repeated_int32.extend([1, 2]) + self.assertSequenceEqual([0, 1, 2], m.repeated_int32) + m.repeated_int32.extend([3, 4]) + self.assertSequenceEqual([0, 1, 2, 3, 4], m.repeated_int32) + + def testExtendFloatWithPythonList(self, message_module): + """Test extending repeated float fields with python lists.""" + m = message_module.TestAllTypes() + self.assertSequenceEqual([], m.repeated_float) + m.repeated_float.extend([0.0]) + self.assertSequenceEqual([0.0], m.repeated_float) + m.repeated_float.extend([1.0, 2.0]) + self.assertSequenceEqual([0.0, 1.0, 2.0], m.repeated_float) + m.repeated_float.extend([3.0, 4.0]) + self.assertSequenceEqual([0.0, 1.0, 2.0, 3.0, 4.0], m.repeated_float) + + def testExtendStringWithPythonList(self, message_module): + """Test extending repeated string fields with python lists.""" + m = message_module.TestAllTypes() + self.assertSequenceEqual([], m.repeated_string) + m.repeated_string.extend(['']) + self.assertSequenceEqual([''], m.repeated_string) + 
m.repeated_string.extend(['11', '22']) + self.assertSequenceEqual(['', '11', '22'], m.repeated_string) + m.repeated_string.extend(['33', '44']) + self.assertSequenceEqual(['', '11', '22', '33', '44'], m.repeated_string) + + def testExtendStringWithString(self, message_module): + """Test extending repeated string fields with characters from a string.""" + m = message_module.TestAllTypes() + self.assertSequenceEqual([], m.repeated_string) + m.repeated_string.extend('abc') + self.assertSequenceEqual(['a', 'b', 'c'], m.repeated_string) + + class TestIterable(object): + """This iterable object mimics the behavior of numpy.array. + + __nonzero__ fails for length > 1, and returns bool(item[0]) for length == 1. + + """ + + def __init__(self, values=None): + self._list = values or [] + + def __nonzero__(self): + size = len(self._list) + if size == 0: + return False + if size == 1: + return bool(self._list[0]) + raise ValueError('Truth value is ambiguous.') + + def __len__(self): + return len(self._list) + + def __iter__(self): + return self._list.__iter__() + + def testExtendInt32WithIterable(self, message_module): + """Test extending repeated int32 fields with iterable.""" + m = message_module.TestAllTypes() + self.assertSequenceEqual([], m.repeated_int32) + m.repeated_int32.extend(MessageTest.TestIterable([])) + self.assertSequenceEqual([], m.repeated_int32) + m.repeated_int32.extend(MessageTest.TestIterable([0])) + self.assertSequenceEqual([0], m.repeated_int32) + m.repeated_int32.extend(MessageTest.TestIterable([1, 2])) + self.assertSequenceEqual([0, 1, 2], m.repeated_int32) + m.repeated_int32.extend(MessageTest.TestIterable([3, 4])) + self.assertSequenceEqual([0, 1, 2, 3, 4], m.repeated_int32) + + def testExtendFloatWithIterable(self, message_module): + """Test extending repeated float fields with iterable.""" + m = message_module.TestAllTypes() + self.assertSequenceEqual([], m.repeated_float) + m.repeated_float.extend(MessageTest.TestIterable([])) + 
self.assertSequenceEqual([], m.repeated_float) + m.repeated_float.extend(MessageTest.TestIterable([0.0])) + self.assertSequenceEqual([0.0], m.repeated_float) + m.repeated_float.extend(MessageTest.TestIterable([1.0, 2.0])) + self.assertSequenceEqual([0.0, 1.0, 2.0], m.repeated_float) + m.repeated_float.extend(MessageTest.TestIterable([3.0, 4.0])) + self.assertSequenceEqual([0.0, 1.0, 2.0, 3.0, 4.0], m.repeated_float) + + def testExtendStringWithIterable(self, message_module): + """Test extending repeated string fields with iterable.""" + m = message_module.TestAllTypes() + self.assertSequenceEqual([], m.repeated_string) + m.repeated_string.extend(MessageTest.TestIterable([])) + self.assertSequenceEqual([], m.repeated_string) + m.repeated_string.extend(MessageTest.TestIterable([''])) + self.assertSequenceEqual([''], m.repeated_string) + m.repeated_string.extend(MessageTest.TestIterable(['1', '2'])) + self.assertSequenceEqual(['', '1', '2'], m.repeated_string) + m.repeated_string.extend(MessageTest.TestIterable(['3', '4'])) + self.assertSequenceEqual(['', '1', '2', '3', '4'], m.repeated_string) + + def testPickleRepeatedScalarContainer(self, message_module): + # TODO(tibell): The pure-Python implementation support pickling of + # scalar containers in *some* cases. For now the cpp2 version + # throws an exception to avoid a segfault. Investigate if we + # want to support pickling of these fields. + # + # For more information see: https://b2.corp.google.com/u/0/issues/18677897 + if (api_implementation.Type() != 'cpp' or + api_implementation.Version() == 2): + return + m = message_module.TestAllTypes() + with self.assertRaises(pickle.PickleError) as _: + pickle.dumps(m.repeated_int32, pickle.HIGHEST_PROTOCOL) + + def testSortEmptyRepeatedCompositeContainer(self, message_module): """Exercise a scenario that has led to segfaults in the past. 
""" - m = unittest_pb2.TestAllTypes() + m = message_module.TestAllTypes() m.repeated_nested_message.sort() - def testHasFieldOnRepeatedField(self): + def testHasFieldOnRepeatedField(self, message_module): """Using HasField on a repeated field should raise an exception. """ - m = unittest_pb2.TestAllTypes() + m = message_module.TestAllTypes() with self.assertRaises(ValueError) as _: m.HasField('repeated_int32') + def testRepeatedScalarFieldPop(self, message_module): + m = message_module.TestAllTypes() + with self.assertRaises(IndexError) as _: + m.repeated_int32.pop() + m.repeated_int32.extend(range(5)) + self.assertEqual(4, m.repeated_int32.pop()) + self.assertEqual(0, m.repeated_int32.pop(0)) + self.assertEqual(2, m.repeated_int32.pop(1)) + self.assertEqual([1, 3], m.repeated_int32) + + def testRepeatedCompositeFieldPop(self, message_module): + m = message_module.TestAllTypes() + with self.assertRaises(IndexError) as _: + m.repeated_nested_message.pop() + for i in range(5): + n = m.repeated_nested_message.add() + n.bb = i + self.assertEqual(4, m.repeated_nested_message.pop().bb) + self.assertEqual(0, m.repeated_nested_message.pop(0).bb) + self.assertEqual(2, m.repeated_nested_message.pop(1).bb) + self.assertEqual([1, 3], [n.bb for n in m.repeated_nested_message]) + + +# Class to test proto2-only features (required, extensions, etc.) 
+class Proto2Test(unittest.TestCase): + + def testFieldPresence(self): + message = unittest_pb2.TestAllTypes() + + self.assertFalse(message.HasField("optional_int32")) + self.assertFalse(message.HasField("optional_bool")) + self.assertFalse(message.HasField("optional_nested_message")) + + with self.assertRaises(ValueError): + message.HasField("field_doesnt_exist") + + with self.assertRaises(ValueError): + message.HasField("repeated_int32") + with self.assertRaises(ValueError): + message.HasField("repeated_nested_message") + + self.assertEqual(0, message.optional_int32) + self.assertEqual(False, message.optional_bool) + self.assertEqual(0, message.optional_nested_message.bb) + + # Fields are set even when setting the values to default values. + message.optional_int32 = 0 + message.optional_bool = False + message.optional_nested_message.bb = 0 + self.assertTrue(message.HasField("optional_int32")) + self.assertTrue(message.HasField("optional_bool")) + self.assertTrue(message.HasField("optional_nested_message")) + + # Set the fields to non-default values. + message.optional_int32 = 5 + message.optional_bool = True + message.optional_nested_message.bb = 15 + + self.assertTrue(message.HasField("optional_int32")) + self.assertTrue(message.HasField("optional_bool")) + self.assertTrue(message.HasField("optional_nested_message")) + + # Clearing the fields unsets them and resets their value to default. 
+ message.ClearField("optional_int32") + message.ClearField("optional_bool") + message.ClearField("optional_nested_message") + + self.assertFalse(message.HasField("optional_int32")) + self.assertFalse(message.HasField("optional_bool")) + self.assertFalse(message.HasField("optional_nested_message")) + self.assertEqual(0, message.optional_int32) + self.assertEqual(False, message.optional_bool) + self.assertEqual(0, message.optional_nested_message.bb) + + # TODO(tibell): The C++ implementations actually allows assignment + # of unknown enum values to *scalar* fields (but not repeated + # fields). Once checked enum fields becomes the default in the + # Python implementation, the C++ implementation should follow suit. + def testAssignInvalidEnum(self): + """It should not be possible to assign an invalid enum number to an + enum field.""" + m = unittest_pb2.TestAllTypes() + + with self.assertRaises(ValueError) as _: + m.optional_nested_enum = 1234567 + self.assertRaises(ValueError, m.repeated_nested_enum.append, 1234567) + + def testGoldenExtensions(self): + golden_data = test_util.GoldenFileData('golden_message') + golden_message = unittest_pb2.TestAllExtensions() + golden_message.ParseFromString(golden_data) + all_set = unittest_pb2.TestAllExtensions() + test_util.SetAllExtensions(all_set) + self.assertEqual(all_set, golden_message) + self.assertEqual(golden_data, golden_message.SerializeToString()) + golden_copy = copy.deepcopy(golden_message) + self.assertEqual(golden_data, golden_copy.SerializeToString()) + + def testGoldenPackedExtensions(self): + golden_data = test_util.GoldenFileData('golden_packed_fields_message') + golden_message = unittest_pb2.TestPackedExtensions() + golden_message.ParseFromString(golden_data) + all_set = unittest_pb2.TestPackedExtensions() + test_util.SetAllPackedExtensions(all_set) + self.assertEqual(all_set, golden_message) + self.assertEqual(golden_data, all_set.SerializeToString()) + golden_copy = copy.deepcopy(golden_message) + 
self.assertEqual(golden_data, golden_copy.SerializeToString()) + + def testPickleIncompleteProto(self): + golden_message = unittest_pb2.TestRequired(a=1) + pickled_message = pickle.dumps(golden_message) + + unpickled_message = pickle.loads(pickled_message) + self.assertEqual(unpickled_message, golden_message) + self.assertEqual(unpickled_message.a, 1) + # This is still an incomplete proto - so serializing should fail + self.assertRaises(message.EncodeError, unpickled_message.SerializeToString) + + + # TODO(haberman): this isn't really a proto2-specific test except that this + # message has a required field in it. Should probably be factored out so + # that we can test the other parts with proto3. + def testParsingMerge(self): + """Check the merge behavior when a required or optional field appears + multiple times in the input.""" + messages = [ + unittest_pb2.TestAllTypes(), + unittest_pb2.TestAllTypes(), + unittest_pb2.TestAllTypes() ] + messages[0].optional_int32 = 1 + messages[1].optional_int64 = 2 + messages[2].optional_int32 = 3 + messages[2].optional_string = 'hello' + + merged_message = unittest_pb2.TestAllTypes() + merged_message.optional_int32 = 3 + merged_message.optional_int64 = 2 + merged_message.optional_string = 'hello' + + generator = unittest_pb2.TestParsingMerge.RepeatedFieldsGenerator() + generator.field1.extend(messages) + generator.field2.extend(messages) + generator.field3.extend(messages) + generator.ext1.extend(messages) + generator.ext2.extend(messages) + generator.group1.add().field1.MergeFrom(messages[0]) + generator.group1.add().field1.MergeFrom(messages[1]) + generator.group1.add().field1.MergeFrom(messages[2]) + generator.group2.add().field1.MergeFrom(messages[0]) + generator.group2.add().field1.MergeFrom(messages[1]) + generator.group2.add().field1.MergeFrom(messages[2]) + + data = generator.SerializeToString() + parsing_merge = unittest_pb2.TestParsingMerge() + parsing_merge.ParseFromString(data) + + # Required and optional fields 
should be merged. + self.assertEqual(parsing_merge.required_all_types, merged_message) + self.assertEqual(parsing_merge.optional_all_types, merged_message) + self.assertEqual(parsing_merge.optionalgroup.optional_group_all_types, + merged_message) + self.assertEqual(parsing_merge.Extensions[ + unittest_pb2.TestParsingMerge.optional_ext], + merged_message) + + # Repeated fields should not be merged. + self.assertEqual(len(parsing_merge.repeated_all_types), 3) + self.assertEqual(len(parsing_merge.repeatedgroup), 3) + self.assertEqual(len(parsing_merge.Extensions[ + unittest_pb2.TestParsingMerge.repeated_ext]), 3) + + def testPythonicInit(self): + message = unittest_pb2.TestAllTypes( + optional_int32=100, + optional_fixed32=200, + optional_float=300.5, + optional_bytes=b'x', + optionalgroup={'a': 400}, + optional_nested_message={'bb': 500}, + optional_nested_enum='BAZ', + repeatedgroup=[{'a': 600}, + {'a': 700}], + repeated_nested_enum=['FOO', unittest_pb2.TestAllTypes.BAR], + default_int32=800, + oneof_string='y') + self.assertIsInstance(message, unittest_pb2.TestAllTypes) + self.assertEqual(100, message.optional_int32) + self.assertEqual(200, message.optional_fixed32) + self.assertEqual(300.5, message.optional_float) + self.assertEqual(b'x', message.optional_bytes) + self.assertEqual(400, message.optionalgroup.a) + self.assertIsInstance(message.optional_nested_message, unittest_pb2.TestAllTypes.NestedMessage) + self.assertEqual(500, message.optional_nested_message.bb) + self.assertEqual(unittest_pb2.TestAllTypes.BAZ, + message.optional_nested_enum) + self.assertEqual(2, len(message.repeatedgroup)) + self.assertEqual(600, message.repeatedgroup[0].a) + self.assertEqual(700, message.repeatedgroup[1].a) + self.assertEqual(2, len(message.repeated_nested_enum)) + self.assertEqual(unittest_pb2.TestAllTypes.FOO, + message.repeated_nested_enum[0]) + self.assertEqual(unittest_pb2.TestAllTypes.BAR, + message.repeated_nested_enum[1]) + self.assertEqual(800, 
message.default_int32) + self.assertEqual('y', message.oneof_string) + self.assertFalse(message.HasField('optional_int64')) + self.assertEqual(0, len(message.repeated_float)) + self.assertEqual(42, message.default_int64) + + message = unittest_pb2.TestAllTypes(optional_nested_enum=u'BAZ') + self.assertEqual(unittest_pb2.TestAllTypes.BAZ, + message.optional_nested_enum) + + with self.assertRaises(ValueError): + unittest_pb2.TestAllTypes( + optional_nested_message={'INVALID_NESTED_FIELD': 17}) + + with self.assertRaises(TypeError): + unittest_pb2.TestAllTypes( + optional_nested_message={'bb': 'INVALID_VALUE_TYPE'}) + + with self.assertRaises(ValueError): + unittest_pb2.TestAllTypes(optional_nested_enum='INVALID_LABEL') + + with self.assertRaises(ValueError): + unittest_pb2.TestAllTypes(repeated_nested_enum='FOO') + + + +# Class to test proto3-only features/behavior (updated field presence & enums) +class Proto3Test(unittest.TestCase): + + # Utility method for comparing equality with a map. + def assertMapIterEquals(self, map_iter, dict_value): + # Avoid mutating caller's copy. + dict_value = dict(dict_value) + + for k, v in map_iter: + self.assertEqual(v, dict_value[k]) + del dict_value[k] + + self.assertEqual({}, dict_value) + + def testFieldPresence(self): + message = unittest_proto3_arena_pb2.TestAllTypes() + + # We can't test presence of non-repeated, non-submessage fields. + with self.assertRaises(ValueError): + message.HasField('optional_int32') + with self.assertRaises(ValueError): + message.HasField('optional_float') + with self.assertRaises(ValueError): + message.HasField('optional_string') + with self.assertRaises(ValueError): + message.HasField('optional_bool') + + # But we can still test presence of submessage fields. + self.assertFalse(message.HasField('optional_nested_message')) + + # As with proto2, we can't test presence of fields that don't exist, or + # repeated fields. 
+ with self.assertRaises(ValueError): + message.HasField('field_doesnt_exist') + + with self.assertRaises(ValueError): + message.HasField('repeated_int32') + with self.assertRaises(ValueError): + message.HasField('repeated_nested_message') + + # Fields should default to their type-specific default. + self.assertEqual(0, message.optional_int32) + self.assertEqual(0, message.optional_float) + self.assertEqual('', message.optional_string) + self.assertEqual(False, message.optional_bool) + self.assertEqual(0, message.optional_nested_message.bb) + + # Setting a submessage should still return proper presence information. + message.optional_nested_message.bb = 0 + self.assertTrue(message.HasField('optional_nested_message')) + + # Set the fields to non-default values. + message.optional_int32 = 5 + message.optional_float = 1.1 + message.optional_string = 'abc' + message.optional_bool = True + message.optional_nested_message.bb = 15 + + # Clearing the fields unsets them and resets their value to default. + message.ClearField('optional_int32') + message.ClearField('optional_float') + message.ClearField('optional_string') + message.ClearField('optional_bool') + message.ClearField('optional_nested_message') + + self.assertEqual(0, message.optional_int32) + self.assertEqual(0, message.optional_float) + self.assertEqual('', message.optional_string) + self.assertEqual(False, message.optional_bool) + self.assertEqual(0, message.optional_nested_message.bb) + + def testAssignUnknownEnum(self): + """Assigning an unknown enum value is allowed and preserves the value.""" + m = unittest_proto3_arena_pb2.TestAllTypes() + + m.optional_nested_enum = 1234567 + self.assertEqual(1234567, m.optional_nested_enum) + m.repeated_nested_enum.append(22334455) + self.assertEqual(22334455, m.repeated_nested_enum[0]) + # Assignment is a different code path than append for the C++ impl. 
+ m.repeated_nested_enum[0] = 7654321 + self.assertEqual(7654321, m.repeated_nested_enum[0]) + serialized = m.SerializeToString() + + m2 = unittest_proto3_arena_pb2.TestAllTypes() + m2.ParseFromString(serialized) + self.assertEqual(1234567, m2.optional_nested_enum) + self.assertEqual(7654321, m2.repeated_nested_enum[0]) + + # Map isn't really a proto3-only feature. But there is no proto2 equivalent + # of google/protobuf/map_unittest.proto right now, so it's not easy to + # test both with the same test like we do for the other proto2/proto3 tests. + # (google/protobuf/map_protobuf_unittest.proto is very different in the set + # of messages and fields it contains). + def testScalarMapDefaults(self): + msg = map_unittest_pb2.TestMap() + + # Scalars start out unset. + self.assertFalse(-123 in msg.map_int32_int32) + self.assertFalse(-2**33 in msg.map_int64_int64) + self.assertFalse(123 in msg.map_uint32_uint32) + self.assertFalse(2**33 in msg.map_uint64_uint64) + self.assertFalse(123 in msg.map_int32_double) + self.assertFalse(False in msg.map_bool_bool) + self.assertFalse('abc' in msg.map_string_string) + self.assertFalse(111 in msg.map_int32_bytes) + self.assertFalse(888 in msg.map_int32_enum) + + # Accessing an unset key returns the default. 
+ self.assertEqual(0, msg.map_int32_int32[-123]) + self.assertEqual(0, msg.map_int64_int64[-2**33]) + self.assertEqual(0, msg.map_uint32_uint32[123]) + self.assertEqual(0, msg.map_uint64_uint64[2**33]) + self.assertEqual(0.0, msg.map_int32_double[123]) + self.assertTrue(isinstance(msg.map_int32_double[123], float)) + self.assertEqual(False, msg.map_bool_bool[False]) + self.assertTrue(isinstance(msg.map_bool_bool[False], bool)) + self.assertEqual('', msg.map_string_string['abc']) + self.assertEqual(b'', msg.map_int32_bytes[111]) + self.assertEqual(0, msg.map_int32_enum[888]) + + # It also sets the value in the map + self.assertTrue(-123 in msg.map_int32_int32) + self.assertTrue(-2**33 in msg.map_int64_int64) + self.assertTrue(123 in msg.map_uint32_uint32) + self.assertTrue(2**33 in msg.map_uint64_uint64) + self.assertTrue(123 in msg.map_int32_double) + self.assertTrue(False in msg.map_bool_bool) + self.assertTrue('abc' in msg.map_string_string) + self.assertTrue(111 in msg.map_int32_bytes) + self.assertTrue(888 in msg.map_int32_enum) + + self.assertIsInstance(msg.map_string_string['abc'], six.text_type) + + # Accessing an unset key still throws TypeError if the type of the key + # is incorrect. + with self.assertRaises(TypeError): + msg.map_string_string[123] + + with self.assertRaises(TypeError): + 123 in msg.map_string_string + + def testMapGet(self): + # Need to test that get() properly returns the default, even though the dict + # has defaultdict-like semantics. 
+ msg = map_unittest_pb2.TestMap() + + self.assertIsNone(msg.map_int32_int32.get(5)) + self.assertEqual(10, msg.map_int32_int32.get(5, 10)) + self.assertIsNone(msg.map_int32_int32.get(5)) + + msg.map_int32_int32[5] = 15 + self.assertEqual(15, msg.map_int32_int32.get(5)) + + self.assertIsNone(msg.map_int32_foreign_message.get(5)) + self.assertEqual(10, msg.map_int32_foreign_message.get(5, 10)) + + submsg = msg.map_int32_foreign_message[5] + self.assertIs(submsg, msg.map_int32_foreign_message.get(5)) + + def testScalarMap(self): + msg = map_unittest_pb2.TestMap() + + self.assertEqual(0, len(msg.map_int32_int32)) + self.assertFalse(5 in msg.map_int32_int32) -class ValidTypeNamesTest(basetest.TestCase): + msg.map_int32_int32[-123] = -456 + msg.map_int64_int64[-2**33] = -2**34 + msg.map_uint32_uint32[123] = 456 + msg.map_uint64_uint64[2**33] = 2**34 + msg.map_string_string['abc'] = '123' + msg.map_int32_enum[888] = 2 + + self.assertEqual([], msg.FindInitializationErrors()) + + self.assertEqual(1, len(msg.map_string_string)) + + # Bad key. + with self.assertRaises(TypeError): + msg.map_string_string[123] = '123' + + # Verify that trying to assign a bad key doesn't actually add a member to + # the map. + self.assertEqual(1, len(msg.map_string_string)) + + # Bad value. + with self.assertRaises(TypeError): + msg.map_string_string['123'] = 123 + + serialized = msg.SerializeToString() + msg2 = map_unittest_pb2.TestMap() + msg2.ParseFromString(serialized) + + # Bad key. + with self.assertRaises(TypeError): + msg2.map_string_string[123] = '123' + + # Bad value. 
+ with self.assertRaises(TypeError): + msg2.map_string_string['123'] = 123 + + self.assertEqual(-456, msg2.map_int32_int32[-123]) + self.assertEqual(-2**34, msg2.map_int64_int64[-2**33]) + self.assertEqual(456, msg2.map_uint32_uint32[123]) + self.assertEqual(2**34, msg2.map_uint64_uint64[2**33]) + self.assertEqual('123', msg2.map_string_string['abc']) + self.assertEqual(2, msg2.map_int32_enum[888]) + + def testStringUnicodeConversionInMap(self): + msg = map_unittest_pb2.TestMap() + + unicode_obj = u'\u1234' + bytes_obj = unicode_obj.encode('utf8') + + msg.map_string_string[bytes_obj] = bytes_obj + + (key, value) = list(msg.map_string_string.items())[0] + + self.assertEqual(key, unicode_obj) + self.assertEqual(value, unicode_obj) + + self.assertIsInstance(key, six.text_type) + self.assertIsInstance(value, six.text_type) + + def testMessageMap(self): + msg = map_unittest_pb2.TestMap() + + self.assertEqual(0, len(msg.map_int32_foreign_message)) + self.assertFalse(5 in msg.map_int32_foreign_message) + + msg.map_int32_foreign_message[123] + # get_or_create() is an alias for getitem. + msg.map_int32_foreign_message.get_or_create(-456) + + self.assertEqual(2, len(msg.map_int32_foreign_message)) + self.assertIn(123, msg.map_int32_foreign_message) + self.assertIn(-456, msg.map_int32_foreign_message) + self.assertEqual(2, len(msg.map_int32_foreign_message)) + + # Bad key. + with self.assertRaises(TypeError): + msg.map_int32_foreign_message['123'] + + # Can't assign directly to submessage. + with self.assertRaises(ValueError): + msg.map_int32_foreign_message[999] = msg.map_int32_foreign_message[123] + + # Verify that trying to assign a bad key doesn't actually add a member to + # the map. 
+ self.assertEqual(2, len(msg.map_int32_foreign_message)) + + serialized = msg.SerializeToString() + msg2 = map_unittest_pb2.TestMap() + msg2.ParseFromString(serialized) + + self.assertEqual(2, len(msg2.map_int32_foreign_message)) + self.assertIn(123, msg2.map_int32_foreign_message) + self.assertIn(-456, msg2.map_int32_foreign_message) + self.assertEqual(2, len(msg2.map_int32_foreign_message)) + + def testMergeFrom(self): + msg = map_unittest_pb2.TestMap() + msg.map_int32_int32[12] = 34 + msg.map_int32_int32[56] = 78 + msg.map_int64_int64[22] = 33 + msg.map_int32_foreign_message[111].c = 5 + msg.map_int32_foreign_message[222].c = 10 + + msg2 = map_unittest_pb2.TestMap() + msg2.map_int32_int32[12] = 55 + msg2.map_int64_int64[88] = 99 + msg2.map_int32_foreign_message[222].c = 15 + + msg2.MergeFrom(msg) + + self.assertEqual(34, msg2.map_int32_int32[12]) + self.assertEqual(78, msg2.map_int32_int32[56]) + self.assertEqual(33, msg2.map_int64_int64[22]) + self.assertEqual(99, msg2.map_int64_int64[88]) + self.assertEqual(5, msg2.map_int32_foreign_message[111].c) + self.assertEqual(10, msg2.map_int32_foreign_message[222].c) + + # Verify that there is only one entry per key, even though the MergeFrom + # may have internally created multiple entries for a single key in the + # list representation. + as_dict = {} + for key in msg2.map_int32_foreign_message: + self.assertFalse(key in as_dict) + as_dict[key] = msg2.map_int32_foreign_message[key].c + + self.assertEqual({111: 5, 222: 10}, as_dict) + + # Special case: test that delete of item really removes the item, even if + # there might have physically been duplicate keys due to the previous merge. + # This is only a special case for the C++ implementation which stores the + # map as an array. 
+ del msg2.map_int32_int32[12] + self.assertFalse(12 in msg2.map_int32_int32) + + del msg2.map_int32_foreign_message[222] + self.assertFalse(222 in msg2.map_int32_foreign_message) + + def testMergeFromBadType(self): + msg = map_unittest_pb2.TestMap() + with self.assertRaisesRegexp( + TypeError, + r'Parameter to MergeFrom\(\) must be instance of same class: expected ' + r'.*TestMap got int\.'): + msg.MergeFrom(1) + + def testCopyFromBadType(self): + msg = map_unittest_pb2.TestMap() + with self.assertRaisesRegexp( + TypeError, + r'Parameter to [A-Za-z]*From\(\) must be instance of same class: ' + r'expected .*TestMap got int\.'): + msg.CopyFrom(1) + + def testIntegerMapWithLongs(self): + msg = map_unittest_pb2.TestMap() + msg.map_int32_int32[long(-123)] = long(-456) + msg.map_int64_int64[long(-2**33)] = long(-2**34) + msg.map_uint32_uint32[long(123)] = long(456) + msg.map_uint64_uint64[long(2**33)] = long(2**34) + + serialized = msg.SerializeToString() + msg2 = map_unittest_pb2.TestMap() + msg2.ParseFromString(serialized) + + self.assertEqual(-456, msg2.map_int32_int32[-123]) + self.assertEqual(-2**34, msg2.map_int64_int64[-2**33]) + self.assertEqual(456, msg2.map_uint32_uint32[123]) + self.assertEqual(2**34, msg2.map_uint64_uint64[2**33]) + + def testMapAssignmentCausesPresence(self): + msg = map_unittest_pb2.TestMapSubmessage() + msg.test_map.map_int32_int32[123] = 456 + + serialized = msg.SerializeToString() + msg2 = map_unittest_pb2.TestMapSubmessage() + msg2.ParseFromString(serialized) + + self.assertEqual(msg, msg2) + + # Now test that various mutations of the map properly invalidate the + # cached size of the submessage. 
+ msg.test_map.map_int32_int32[888] = 999 + serialized = msg.SerializeToString() + msg2.ParseFromString(serialized) + self.assertEqual(msg, msg2) + + msg.test_map.map_int32_int32.clear() + serialized = msg.SerializeToString() + msg2.ParseFromString(serialized) + self.assertEqual(msg, msg2) + + def testMapAssignmentCausesPresenceForSubmessages(self): + msg = map_unittest_pb2.TestMapSubmessage() + msg.test_map.map_int32_foreign_message[123].c = 5 + + serialized = msg.SerializeToString() + msg2 = map_unittest_pb2.TestMapSubmessage() + msg2.ParseFromString(serialized) + + self.assertEqual(msg, msg2) + + # Now test that various mutations of the map properly invalidate the + # cached size of the submessage. + msg.test_map.map_int32_foreign_message[888].c = 7 + serialized = msg.SerializeToString() + msg2.ParseFromString(serialized) + self.assertEqual(msg, msg2) + + msg.test_map.map_int32_foreign_message[888].MergeFrom( + msg.test_map.map_int32_foreign_message[123]) + serialized = msg.SerializeToString() + msg2.ParseFromString(serialized) + self.assertEqual(msg, msg2) + + msg.test_map.map_int32_foreign_message.clear() + serialized = msg.SerializeToString() + msg2.ParseFromString(serialized) + self.assertEqual(msg, msg2) + + def testModifyMapWhileIterating(self): + msg = map_unittest_pb2.TestMap() + + string_string_iter = iter(msg.map_string_string) + int32_foreign_iter = iter(msg.map_int32_foreign_message) + + msg.map_string_string['abc'] = '123' + msg.map_int32_foreign_message[5].c = 5 + + with self.assertRaises(RuntimeError): + for key in string_string_iter: + pass + + with self.assertRaises(RuntimeError): + for key in int32_foreign_iter: + pass + + def testSubmessageMap(self): + msg = map_unittest_pb2.TestMap() + + submsg = msg.map_int32_foreign_message[111] + self.assertIs(submsg, msg.map_int32_foreign_message[111]) + self.assertIsInstance(submsg, unittest_pb2.ForeignMessage) + + submsg.c = 5 + + serialized = msg.SerializeToString() + msg2 = map_unittest_pb2.TestMap() 
+ msg2.ParseFromString(serialized) + + self.assertEqual(5, msg2.map_int32_foreign_message[111].c) + + # Doesn't allow direct submessage assignment. + with self.assertRaises(ValueError): + msg.map_int32_foreign_message[88] = unittest_pb2.ForeignMessage() + + def testMapIteration(self): + msg = map_unittest_pb2.TestMap() + + for k, v in msg.map_int32_int32.items(): + # Should not be reached. + self.assertTrue(False) + + msg.map_int32_int32[2] = 4 + msg.map_int32_int32[3] = 6 + msg.map_int32_int32[4] = 8 + self.assertEqual(3, len(msg.map_int32_int32)) + + matching_dict = {2: 4, 3: 6, 4: 8} + self.assertMapIterEquals(msg.map_int32_int32.items(), matching_dict) + + def testMapItems(self): + # Map items used to have strange behaviors when use c extension. Because + # [] may reorder the map and invalidate any exsting iterators. + # TODO(jieluo): Check if [] reordering the map is a bug or intended + # behavior. + msg = map_unittest_pb2.TestMap() + msg.map_string_string['local_init_op'] = '' + msg.map_string_string['trainable_variables'] = '' + msg.map_string_string['variables'] = '' + msg.map_string_string['init_op'] = '' + msg.map_string_string['summaries'] = '' + items1 = msg.map_string_string.items() + items2 = msg.map_string_string.items() + self.assertEqual(items1, items2) + + def testMapIterationClearMessage(self): + # Iterator needs to work even if message and map are deleted. 
+ msg = map_unittest_pb2.TestMap() + + msg.map_int32_int32[2] = 4 + msg.map_int32_int32[3] = 6 + msg.map_int32_int32[4] = 8 + + it = msg.map_int32_int32.items() + del msg + + matching_dict = {2: 4, 3: 6, 4: 8} + self.assertMapIterEquals(it, matching_dict) + + def testMapConstruction(self): + msg = map_unittest_pb2.TestMap(map_int32_int32={1: 2, 3: 4}) + self.assertEqual(2, msg.map_int32_int32[1]) + self.assertEqual(4, msg.map_int32_int32[3]) + + msg = map_unittest_pb2.TestMap( + map_int32_foreign_message={3: unittest_pb2.ForeignMessage(c=5)}) + self.assertEqual(5, msg.map_int32_foreign_message[3].c) + + def testMapValidAfterFieldCleared(self): + # Map needs to work even if field is cleared. + # For the C++ implementation this tests the correctness of + # ScalarMapContainer::Release() + msg = map_unittest_pb2.TestMap() + int32_map = msg.map_int32_int32 + + int32_map[2] = 4 + int32_map[3] = 6 + int32_map[4] = 8 + + msg.ClearField('map_int32_int32') + self.assertEqual(b'', msg.SerializeToString()) + matching_dict = {2: 4, 3: 6, 4: 8} + self.assertMapIterEquals(int32_map.items(), matching_dict) + + def testMessageMapValidAfterFieldCleared(self): + # Map needs to work even if field is cleared. + # For the C++ implementation this tests the correctness of + # ScalarMapContainer::Release() + msg = map_unittest_pb2.TestMap() + int32_foreign_message = msg.map_int32_foreign_message + + int32_foreign_message[2].c = 5 + + msg.ClearField('map_int32_foreign_message') + self.assertEqual(b'', msg.SerializeToString()) + self.assertTrue(2 in int32_foreign_message.keys()) + + def testMapIterInvalidatedByClearField(self): + # Map iterator is invalidated when field is cleared. + # But this case does need to not crash the interpreter. 
+ # For the C++ implementation this tests the correctness of + # ScalarMapContainer::Release() + msg = map_unittest_pb2.TestMap() + + it = iter(msg.map_int32_int32) + + msg.ClearField('map_int32_int32') + with self.assertRaises(RuntimeError): + for _ in it: + pass + + it = iter(msg.map_int32_foreign_message) + msg.ClearField('map_int32_foreign_message') + with self.assertRaises(RuntimeError): + for _ in it: + pass + + def testMapDelete(self): + msg = map_unittest_pb2.TestMap() + + self.assertEqual(0, len(msg.map_int32_int32)) + + msg.map_int32_int32[4] = 6 + self.assertEqual(1, len(msg.map_int32_int32)) + + with self.assertRaises(KeyError): + del msg.map_int32_int32[88] + + del msg.map_int32_int32[4] + self.assertEqual(0, len(msg.map_int32_int32)) + + def testMapsAreMapping(self): + msg = map_unittest_pb2.TestMap() + self.assertIsInstance(msg.map_int32_int32, collections.Mapping) + self.assertIsInstance(msg.map_int32_int32, collections.MutableMapping) + self.assertIsInstance(msg.map_int32_foreign_message, collections.Mapping) + self.assertIsInstance(msg.map_int32_foreign_message, + collections.MutableMapping) + + def testMapFindInitializationErrorsSmokeTest(self): + msg = map_unittest_pb2.TestMap() + msg.map_string_string['abc'] = '123' + msg.map_int32_int32[35] = 64 + msg.map_string_foreign_message['foo'].c = 5 + self.assertEqual(0, len(msg.FindInitializationErrors())) + + + +class ValidTypeNamesTest(unittest.TestCase): def assertImportFromName(self, msg, base_name): # Parse <type 'module.class_name'> to extra 'some.name' as a string. 
@@ -676,6 +1737,116 @@ class ValidTypeNamesTest(basetest.TestCase): self.assertImportFromName(pb.repeated_int32, 'Scalar') self.assertImportFromName(pb.repeated_nested_message, 'Composite') +class PackedFieldTest(unittest.TestCase): + + def setMessage(self, message): + message.repeated_int32.append(1) + message.repeated_int64.append(1) + message.repeated_uint32.append(1) + message.repeated_uint64.append(1) + message.repeated_sint32.append(1) + message.repeated_sint64.append(1) + message.repeated_fixed32.append(1) + message.repeated_fixed64.append(1) + message.repeated_sfixed32.append(1) + message.repeated_sfixed64.append(1) + message.repeated_float.append(1.0) + message.repeated_double.append(1.0) + message.repeated_bool.append(True) + message.repeated_nested_enum.append(1) + + def testPackedFields(self): + message = packed_field_test_pb2.TestPackedTypes() + self.setMessage(message) + golden_data = (b'\x0A\x01\x01' + b'\x12\x01\x01' + b'\x1A\x01\x01' + b'\x22\x01\x01' + b'\x2A\x01\x02' + b'\x32\x01\x02' + b'\x3A\x04\x01\x00\x00\x00' + b'\x42\x08\x01\x00\x00\x00\x00\x00\x00\x00' + b'\x4A\x04\x01\x00\x00\x00' + b'\x52\x08\x01\x00\x00\x00\x00\x00\x00\x00' + b'\x5A\x04\x00\x00\x80\x3f' + b'\x62\x08\x00\x00\x00\x00\x00\x00\xf0\x3f' + b'\x6A\x01\x01' + b'\x72\x01\x01') + self.assertEqual(golden_data, message.SerializeToString()) + + def testUnpackedFields(self): + message = packed_field_test_pb2.TestUnpackedTypes() + self.setMessage(message) + golden_data = (b'\x08\x01' + b'\x10\x01' + b'\x18\x01' + b'\x20\x01' + b'\x28\x02' + b'\x30\x02' + b'\x3D\x01\x00\x00\x00' + b'\x41\x01\x00\x00\x00\x00\x00\x00\x00' + b'\x4D\x01\x00\x00\x00' + b'\x51\x01\x00\x00\x00\x00\x00\x00\x00' + b'\x5D\x00\x00\x80\x3f' + b'\x61\x00\x00\x00\x00\x00\x00\xf0\x3f' + b'\x68\x01' + b'\x70\x01') + self.assertEqual(golden_data, message.SerializeToString()) + + +@unittest.skipIf(api_implementation.Type() != 'cpp', + 'explicit tests of the C++ implementation') +class 
OversizeProtosTest(unittest.TestCase): + + def setUp(self): + self.file_desc = """ + name: "f/f.msg2" + package: "f" + message_type { + name: "msg1" + field { + name: "payload" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + } + message_type { + name: "msg2" + field { + name: "field" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: "msg1" + } + } + """ + pool = descriptor_pool.DescriptorPool() + desc = descriptor_pb2.FileDescriptorProto() + text_format.Parse(self.file_desc, desc) + pool.Add(desc) + self.proto_cls = message_factory.MessageFactory(pool).GetPrototype( + pool.FindMessageTypeByName('f.msg2')) + self.p = self.proto_cls() + self.p.field.payload = 'c' * (1024 * 1024 * 64 + 1) + self.p_serialized = self.p.SerializeToString() + + def testAssertOversizeProto(self): + from google.protobuf.pyext._message import SetAllowOversizeProtos + SetAllowOversizeProtos(False) + q = self.proto_cls() + try: + q.ParseFromString(self.p_serialized) + except message.DecodeError as e: + self.assertEqual(str(e), 'Error parsing message') + + def testSucceedOversizeProto(self): + from google.protobuf.pyext._message import SetAllowOversizeProtos + SetAllowOversizeProtos(True) + q = self.proto_cls() + q.ParseFromString(self.p_serialized) + self.assertEqual(self.p.field.payload, q.field.payload) if __name__ == '__main__': - basetest.main() + unittest.main() diff --git a/python/google/protobuf/internal/missing_enum_values.proto b/python/google/protobuf/internal/missing_enum_values.proto index e90f0cd3e..1850be5bb 100644 --- a/python/google/protobuf/internal/missing_enum_values.proto +++ b/python/google/protobuf/internal/missing_enum_values.proto @@ -28,6 +28,8 @@ // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+syntax = "proto2"; + package google.protobuf.python.internal; message TestEnumValues { @@ -48,3 +50,7 @@ message TestMissingEnumValues { repeated NestedEnum repeated_nested_enum = 2; repeated NestedEnum packed_nested_enum = 3 [packed = true]; } + +message JustString { + required string dummy = 1; +} diff --git a/python/google/protobuf/internal/more_extensions.proto b/python/google/protobuf/internal/more_extensions.proto index c04e597f9..78f146736 100644 --- a/python/google/protobuf/internal/more_extensions.proto +++ b/python/google/protobuf/internal/more_extensions.proto @@ -30,6 +30,7 @@ // Author: robinson@google.com (Will Robinson) +syntax = "proto2"; package google.protobuf.internal; diff --git a/python/google/protobuf/internal/more_extensions_dynamic.proto b/python/google/protobuf/internal/more_extensions_dynamic.proto index 88bd9c1b2..11f85ef60 100644 --- a/python/google/protobuf/internal/more_extensions_dynamic.proto +++ b/python/google/protobuf/internal/more_extensions_dynamic.proto @@ -34,6 +34,7 @@ // generated C++ type is available for the extendee, but the extension is // defined in a file whose C++ type is not in the binary. +syntax = "proto2"; import "google/protobuf/internal/more_extensions.proto"; diff --git a/python/google/protobuf/internal/more_messages.proto b/python/google/protobuf/internal/more_messages.proto index 61db66c56..2c6ab9efd 100644 --- a/python/google/protobuf/internal/more_messages.proto +++ b/python/google/protobuf/internal/more_messages.proto @@ -30,6 +30,7 @@ // Author: robinson@google.com (Will Robinson) +syntax = "proto2"; package google.protobuf.internal; diff --git a/python/google/protobuf/internal/packed_field_test.proto b/python/google/protobuf/internal/packed_field_test.proto new file mode 100644 index 000000000..0dfdc10a8 --- /dev/null +++ b/python/google/protobuf/internal/packed_field_test.proto @@ -0,0 +1,73 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. 
+// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +syntax = "proto3"; + +package google.protobuf.python.internal; + +message TestPackedTypes { + enum NestedEnum { + FOO = 0; + BAR = 1; + BAZ = 2; + } + + repeated int32 repeated_int32 = 1; + repeated int64 repeated_int64 = 2; + repeated uint32 repeated_uint32 = 3; + repeated uint64 repeated_uint64 = 4; + repeated sint32 repeated_sint32 = 5; + repeated sint64 repeated_sint64 = 6; + repeated fixed32 repeated_fixed32 = 7; + repeated fixed64 repeated_fixed64 = 8; + repeated sfixed32 repeated_sfixed32 = 9; + repeated sfixed64 repeated_sfixed64 = 10; + repeated float repeated_float = 11; + repeated double repeated_double = 12; + repeated bool repeated_bool = 13; + repeated NestedEnum repeated_nested_enum = 14; +} + +message TestUnpackedTypes { + repeated int32 repeated_int32 = 1 [packed = false]; + repeated int64 repeated_int64 = 2 [packed = false]; + repeated uint32 repeated_uint32 = 3 [packed = false]; + repeated uint64 repeated_uint64 = 4 [packed = false]; + repeated sint32 repeated_sint32 = 5 [packed = false]; + repeated sint64 repeated_sint64 = 6 [packed = false]; + repeated fixed32 repeated_fixed32 = 7 [packed = false]; + repeated fixed64 repeated_fixed64 = 8 [packed = false]; + repeated sfixed32 repeated_sfixed32 = 9 [packed = false]; + repeated sfixed64 repeated_sfixed64 = 10 [packed = false]; + repeated float repeated_float = 11 [packed = false]; + repeated double repeated_double = 12 [packed = false]; + repeated bool repeated_bool = 13 [packed = false]; + repeated TestPackedTypes.NestedEnum repeated_nested_enum = 14 [packed = false]; +} diff --git a/python/google/protobuf/internal/proto_builder_test.py b/python/google/protobuf/internal/proto_builder_test.py new file mode 100644 index 000000000..36dfbfded --- /dev/null +++ b/python/google/protobuf/internal/proto_builder_test.py @@ -0,0 +1,96 @@ +#! /usr/bin/env python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. 
+# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +"""Tests for google.protobuf.proto_builder.""" + +try: + from collections import OrderedDict +except ImportError: + from ordereddict import OrderedDict #PY26 +try: + import unittest2 as unittest +except ImportError: + import unittest + +from google.protobuf import descriptor_pb2 +from google.protobuf import descriptor_pool +from google.protobuf import proto_builder +from google.protobuf import text_format + + +class ProtoBuilderTest(unittest.TestCase): + + def setUp(self): + self.ordered_fields = OrderedDict([ + ('foo', descriptor_pb2.FieldDescriptorProto.TYPE_INT64), + ('bar', descriptor_pb2.FieldDescriptorProto.TYPE_STRING), + ]) + self._fields = dict(self.ordered_fields) + + def testMakeSimpleProtoClass(self): + """Test that we can create a proto class.""" + proto_cls = proto_builder.MakeSimpleProtoClass( + self._fields, + full_name='net.proto2.python.public.proto_builder_test.Test') + proto = proto_cls() + proto.foo = 12345 + proto.bar = 'asdf' + self.assertMultiLineEqual( + 'bar: "asdf"\nfoo: 12345\n', text_format.MessageToString(proto)) + + def testOrderedFields(self): + """Test that the field order is maintained when given an OrderedDict.""" + proto_cls = proto_builder.MakeSimpleProtoClass( + self.ordered_fields, + full_name='net.proto2.python.public.proto_builder_test.OrderedTest') + proto = proto_cls() + proto.foo = 12345 + proto.bar = 'asdf' + self.assertMultiLineEqual( + 'foo: 12345\nbar: "asdf"\n', text_format.MessageToString(proto)) + + def testMakeSameProtoClassTwice(self): + """Test that the DescriptorPool is used.""" + pool = descriptor_pool.DescriptorPool() + proto_cls1 = proto_builder.MakeSimpleProtoClass( + self._fields, + full_name='net.proto2.python.public.proto_builder_test.Test', + pool=pool) + proto_cls2 = proto_builder.MakeSimpleProtoClass( + self._fields, + full_name='net.proto2.python.public.proto_builder_test.Test', + pool=pool) + self.assertIs(proto_cls1.DESCRIPTOR, proto_cls2.DESCRIPTOR) + + +if __name__ == '__main__': + 
unittest.main() diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py index a5c26f452..f8f73dd20 100755 --- a/python/google/protobuf/internal/python_message.py +++ b/python/google/protobuf/internal/python_message.py @@ -28,10 +28,6 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# Keep it Python2.5 compatible for GAE. -# -# Copyright 2007 Google Inc. All Rights Reserved. -# # This code is meant to work on Python 2.4 and above only. # # TODO(robinson): Helpers for verbose, common checks like seeing if a @@ -54,19 +50,21 @@ this file*. __author__ = 'robinson@google.com (Will Robinson)' +from io import BytesIO import sys -if sys.version_info[0] < 3: - try: - from cStringIO import StringIO as BytesIO - except ImportError: - from StringIO import StringIO as BytesIO - import copy_reg as copyreg -else: - from io import BytesIO - import copyreg import struct import weakref +import six +try: + import six.moves.copyreg as copyreg +except ImportError: + # On some platforms, for example gMac, we run native Python because there is + # nothing like hermetic Python. This means lesser control on the system and + # the six.moves package may be missing (is missing on 20150321 on gMac). Be + # extra conservative and try to load the old replacement if it fails. + import copy_reg as copyreg + # We use "as" to avoid name collisions with variables. 
from google.protobuf.internal import containers from google.protobuf.internal import decoder @@ -74,41 +72,121 @@ from google.protobuf.internal import encoder from google.protobuf.internal import enum_type_wrapper from google.protobuf.internal import message_listener as message_listener_mod from google.protobuf.internal import type_checkers +from google.protobuf.internal import well_known_types from google.protobuf.internal import wire_format from google.protobuf import descriptor as descriptor_mod from google.protobuf import message as message_mod +from google.protobuf import symbol_database from google.protobuf import text_format _FieldDescriptor = descriptor_mod.FieldDescriptor +_AnyFullTypeName = 'google.protobuf.Any' -def NewMessage(bases, descriptor, dictionary): - _AddClassAttributesForNestedExtensions(descriptor, dictionary) - _AddSlots(descriptor, dictionary) - return bases +class GeneratedProtocolMessageType(type): + """Metaclass for protocol message classes created at runtime from Descriptors. -def InitMessage(descriptor, cls): - cls._decoders_by_tag = {} - cls._extensions_by_name = {} - cls._extensions_by_number = {} - if (descriptor.has_options and - descriptor.GetOptions().message_set_wire_format): - cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = ( - decoder.MessageSetItemDecoder(cls._extensions_by_number), None) + We add implementations for all methods described in the Message class. We + also create properties to allow getting/setting all fields in the protocol + message. Finally, we create slots to prevent users from accidentally + "setting" nonexistent fields in the protocol message, which then wouldn't get + serialized / deserialized properly. - # Attach stuff to each FieldDescriptor for quick lookup later on. - for field in descriptor.fields: - _AttachFieldHelpers(cls, field) + The protocol compiler currently uses this metaclass to create protocol + message classes at runtime. 
Clients can also manually create their own + classes at runtime, as in this example: - _AddEnumValues(descriptor, cls) - _AddInitMethod(descriptor, cls) - _AddPropertiesForFields(descriptor, cls) - _AddPropertiesForExtensions(descriptor, cls) - _AddStaticMethods(cls) - _AddMessageMethods(descriptor, cls) - _AddPrivateHelperMethods(descriptor, cls) - copyreg.pickle(cls, lambda obj: (cls, (), obj.__getstate__())) + mydescriptor = Descriptor(.....) + class MyProtoClass(Message): + __metaclass__ = GeneratedProtocolMessageType + DESCRIPTOR = mydescriptor + myproto_instance = MyProtoClass() + myproto.foo_field = 23 + ... + + The above example will not work for nested types. If you wish to include them, + use reflection.MakeClass() instead of manually instantiating the class in + order to create the appropriate class structure. + """ + + # Must be consistent with the protocol-compiler code in + # proto2/compiler/internal/generator.*. + _DESCRIPTOR_KEY = 'DESCRIPTOR' + + def __new__(cls, name, bases, dictionary): + """Custom allocation for runtime-generated class types. + + We override __new__ because this is apparently the only place + where we can meaningfully set __slots__ on the class we're creating(?). + (The interplay between metaclasses and slots is not very well-documented). + + Args: + name: Name of the class (ignored, but required by the + metaclass protocol). + bases: Base classes of the class we're constructing. + (Should be message.Message). We ignore this field, but + it's required by the metaclass protocol + dictionary: The class dictionary of the class we're + constructing. dictionary[_DESCRIPTOR_KEY] must contain + a Descriptor object describing this protocol message + type. + + Returns: + Newly-allocated class. 
+ """ + descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY] + if descriptor.full_name in well_known_types.WKTBASES: + bases += (well_known_types.WKTBASES[descriptor.full_name],) + _AddClassAttributesForNestedExtensions(descriptor, dictionary) + _AddSlots(descriptor, dictionary) + + superclass = super(GeneratedProtocolMessageType, cls) + new_class = superclass.__new__(cls, name, bases, dictionary) + return new_class + + def __init__(cls, name, bases, dictionary): + """Here we perform the majority of our work on the class. + We add enum getters, an __init__ method, implementations + of all Message methods, and properties for all fields + in the protocol type. + + Args: + name: Name of the class (ignored, but required by the + metaclass protocol). + bases: Base classes of the class we're constructing. + (Should be message.Message). We ignore this field, but + it's required by the metaclass protocol + dictionary: The class dictionary of the class we're + constructing. dictionary[_DESCRIPTOR_KEY] must contain + a Descriptor object describing this protocol message + type. + """ + descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY] + cls._decoders_by_tag = {} + cls._extensions_by_name = {} + cls._extensions_by_number = {} + if (descriptor.has_options and + descriptor.GetOptions().message_set_wire_format): + cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = ( + decoder.MessageSetItemDecoder(cls._extensions_by_number), None) + + # Attach stuff to each FieldDescriptor for quick lookup later on. 
+ for field in descriptor.fields: + _AttachFieldHelpers(cls, field) + + descriptor._concrete_class = cls # pylint: disable=protected-access + _AddEnumValues(descriptor, cls) + _AddInitMethod(descriptor, cls) + _AddPropertiesForFields(descriptor, cls) + _AddPropertiesForExtensions(descriptor, cls) + _AddStaticMethods(cls) + _AddMessageMethods(descriptor, cls) + _AddPrivateHelperMethods(descriptor, cls) + copyreg.pickle(cls, lambda obj: (cls, (), obj.__getstate__())) + + superclass = super(GeneratedProtocolMessageType, cls) + superclass.__init__(name, bases, dictionary) # Stateless helpers for GeneratedProtocolMessageType below. @@ -194,16 +272,40 @@ def _IsMessageSetExtension(field): field.containing_type.has_options and field.containing_type.GetOptions().message_set_wire_format and field.type == _FieldDescriptor.TYPE_MESSAGE and - field.message_type == field.extension_scope and field.label == _FieldDescriptor.LABEL_OPTIONAL) +def _IsMapField(field): + return (field.type == _FieldDescriptor.TYPE_MESSAGE and + field.message_type.has_options and + field.message_type.GetOptions().map_entry) + + +def _IsMessageMapField(field): + value_type = field.message_type.fields_by_name["value"] + return value_type.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE + + def _AttachFieldHelpers(cls, field_descriptor): is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED) - is_packed = (field_descriptor.has_options and - field_descriptor.GetOptions().packed) - - if _IsMessageSetExtension(field_descriptor): + is_packable = (is_repeated and + wire_format.IsTypePackable(field_descriptor.type)) + if not is_packable: + is_packed = False + elif field_descriptor.containing_type.syntax == "proto2": + is_packed = (field_descriptor.has_options and + field_descriptor.GetOptions().packed) + else: + has_packed_false = (field_descriptor.has_options and + field_descriptor.GetOptions().HasField("packed") and + field_descriptor.GetOptions().packed == False) + is_packed = not 
has_packed_false + is_map_entry = _IsMapField(field_descriptor) + + if is_map_entry: + field_encoder = encoder.MapEncoder(field_descriptor) + sizer = encoder.MapSizer(field_descriptor) + elif _IsMessageSetExtension(field_descriptor): field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number) sizer = encoder.MessageSetItemSizer(field_descriptor.number) else: @@ -219,12 +321,27 @@ def _AttachFieldHelpers(cls, field_descriptor): def AddDecoder(wiretype, is_packed): tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype) - cls._decoders_by_tag[tag_bytes] = ( - type_checkers.TYPE_TO_DECODER[field_descriptor.type]( - field_descriptor.number, is_repeated, is_packed, - field_descriptor, field_descriptor._default_constructor), - field_descriptor if field_descriptor.containing_oneof is not None - else None) + decode_type = field_descriptor.type + if (decode_type == _FieldDescriptor.TYPE_ENUM and + type_checkers.SupportsOpenEnums(field_descriptor)): + decode_type = _FieldDescriptor.TYPE_INT32 + + oneof_descriptor = None + if field_descriptor.containing_oneof is not None: + oneof_descriptor = field_descriptor + + if is_map_entry: + is_message_map = _IsMessageMapField(field_descriptor) + + field_decoder = decoder.MapDecoder( + field_descriptor, _GetInitializeDefaultForMap(field_descriptor), + is_message_map) + else: + field_decoder = type_checkers.TYPE_TO_DECODER[decode_type]( + field_descriptor.number, is_repeated, is_packed, + field_descriptor, field_descriptor._default_constructor) + + cls._decoders_by_tag[tag_bytes] = (field_decoder, oneof_descriptor) AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type], False) @@ -237,7 +354,7 @@ def _AttachFieldHelpers(cls, field_descriptor): def _AddClassAttributesForNestedExtensions(descriptor, dictionary): extension_dict = descriptor.extensions_by_name - for extension_name, extension_field in extension_dict.iteritems(): + for extension_name, extension_field in extension_dict.items(): assert 
extension_name not in dictionary dictionary[extension_name] = extension_field @@ -257,6 +374,26 @@ def _AddEnumValues(descriptor, cls): setattr(cls, enum_value.name, enum_value.number) +def _GetInitializeDefaultForMap(field): + if field.label != _FieldDescriptor.LABEL_REPEATED: + raise ValueError('map_entry set on non-repeated field %s' % ( + field.name)) + fields_by_name = field.message_type.fields_by_name + key_checker = type_checkers.GetTypeChecker(fields_by_name['key']) + + value_field = fields_by_name['value'] + if _IsMessageMapField(field): + def MakeMessageMapDefault(message): + return containers.MessageMap( + message._listener_for_children, value_field.message_type, key_checker) + return MakeMessageMapDefault + else: + value_checker = type_checkers.GetTypeChecker(value_field) + def MakePrimitiveMapDefault(message): + return containers.ScalarMap( + message._listener_for_children, key_checker, value_checker) + return MakePrimitiveMapDefault + def _DefaultValueConstructorForField(field): """Returns a function which returns a default value for a field. @@ -271,6 +408,9 @@ def _DefaultValueConstructorForField(field): value may refer back to |message| via a weak reference. 
""" + if _IsMapField(field): + return _GetInitializeDefaultForMap(field) + if field.label == _FieldDescriptor.LABEL_REPEATED: if field.has_default_value and field.default_value != []: raise ValueError('Repeated field default value not empty list: %s' % ( @@ -295,7 +435,10 @@ def _DefaultValueConstructorForField(field): message_type = field.message_type def MakeSubMessageDefault(message): result = message_type._concrete_class() - result._SetListener(message._listener_for_children) + result._SetListener( + _OneofListener(message, field) + if field.containing_oneof is not None + else message._listener_for_children) return result return MakeSubMessageDefault @@ -306,9 +449,35 @@ def _DefaultValueConstructorForField(field): return MakeScalarDefault +def _ReraiseTypeErrorWithFieldName(message_name, field_name): + """Re-raise the currently-handled TypeError with the field name added.""" + exc = sys.exc_info()[1] + if len(exc.args) == 1 and type(exc) is TypeError: + # simple TypeError; add field name to exception message + exc = TypeError('%s for field %s.%s' % (str(exc), message_name, field_name)) + + # re-raise possibly-amended exception with original traceback: + six.reraise(type(exc), exc, sys.exc_info()[2]) + + def _AddInitMethod(message_descriptor, cls): """Adds an __init__ method to cls.""" - fields = message_descriptor.fields + + def _GetIntegerEnumValue(enum_type, value): + """Convert a string or integer enum value to an integer. + + If the value is a string, it is converted to the enum value in + enum_type with the same name. If the value is not a string, it's + returned as-is. (No conversion or bounds-checking is done.) 
+ """ + if isinstance(value, six.string_types): + try: + return enum_type.values_by_name[value].number + except KeyError: + raise ValueError('Enum type %s: unknown label "%s"' % ( + enum_type.full_name, value)) + return value + def init(self, **kwargs): self._cached_byte_size = 0 self._cached_byte_size_dirty = len(kwargs) > 0 @@ -323,25 +492,52 @@ def _AddInitMethod(message_descriptor, cls): self._is_present_in_parent = False self._listener = message_listener_mod.NullMessageListener() self._listener_for_children = _Listener(self) - for field_name, field_value in kwargs.iteritems(): + for field_name, field_value in kwargs.items(): field = _GetFieldByName(message_descriptor, field_name) if field is None: raise TypeError("%s() got an unexpected keyword argument '%s'" % (message_descriptor.name, field_name)) + if field_value is None: + # field=None is the same as no field at all. + continue if field.label == _FieldDescriptor.LABEL_REPEATED: copy = field._default_constructor(self) if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: # Composite - for val in field_value: - copy.add().MergeFrom(val) + if _IsMapField(field): + if _IsMessageMapField(field): + for key in field_value: + copy[key].MergeFrom(field_value[key]) + else: + copy.update(field_value) + else: + for val in field_value: + if isinstance(val, dict): + copy.add(**val) + else: + copy.add().MergeFrom(val) else: # Scalar + if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM: + field_value = [_GetIntegerEnumValue(field.enum_type, val) + for val in field_value] copy.extend(field_value) self._fields[field] = copy elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: copy = field._default_constructor(self) - copy.MergeFrom(field_value) + new_val = field_value + if isinstance(field_value, dict): + new_val = field.message_type._concrete_class(**field_value) + try: + copy.MergeFrom(new_val) + except TypeError: + _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name) self._fields[field] = copy else: - 
setattr(self, field_name, field_value) + if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM: + field_value = _GetIntegerEnumValue(field.enum_type, field_value) + try: + setattr(self, field_name, field_value) + except TypeError: + _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name) init.__module__ = None init.__doc__ = None @@ -360,7 +556,8 @@ def _GetFieldByName(message_descriptor, field_name): try: return message_descriptor.fields_by_name[field_name] except KeyError: - raise ValueError('Protocol message has no "%s" field.' % field_name) + raise ValueError('Protocol message %s has no "%s" field.' % + (message_descriptor.name, field_name)) def _AddPropertiesForFields(descriptor, cls): @@ -459,6 +656,7 @@ def _AddPropertiesForNonRepeatedScalarField(field, cls): type_checker = type_checkers.GetTypeChecker(field) default_value = field.default_value valid_values = set() + is_proto3 = field.containing_type.syntax == "proto3" def getter(self): # TODO(protobuf-team): This may be broken since there may not be @@ -466,15 +664,24 @@ def _AddPropertiesForNonRepeatedScalarField(field, cls): return self._fields.get(field, default_value) getter.__module__ = None getter.__doc__ = 'Getter for %s.' % proto_field_name + + clear_when_set_to_default = is_proto3 and not field.containing_oneof + def field_setter(self, new_value): # pylint: disable=protected-access - self._fields[field] = type_checker.CheckValue(new_value) + # Testing the value for truthiness captures all of the proto3 defaults + # (0, 0.0, enum 0, and False). + new_value = type_checker.CheckValue(new_value) + if clear_when_set_to_default and not new_value: + self._fields.pop(field, None) + else: + self._fields[field] = new_value # Check _cached_byte_size_dirty inline to improve performance, since scalar # setters are called frequently. 
if not self._cached_byte_size_dirty: self._Modified() - if field.containing_oneof is not None: + if field.containing_oneof: def setter(self, new_value): field_setter(self, new_value) self._UpdateOneofState(field) @@ -505,21 +712,11 @@ def _AddPropertiesForNonRepeatedCompositeField(field, cls): proto_field_name = field.name property_name = _PropertyName(proto_field_name) - # TODO(komarek): Can anyone explain to me why we cache the message_type this - # way, instead of referring to field.message_type inside of getter(self)? - # What if someone sets message_type later on (which makes for simpler - # dyanmic proto descriptor and class creation code). - message_type = field.message_type - def getter(self): field_value = self._fields.get(field) if field_value is None: # Construct a new object to represent this field. - field_value = message_type._concrete_class() # use field.message_type? - field_value._SetListener( - _OneofListener(self, field) - if field.containing_oneof is not None - else self._listener_for_children) + field_value = field._default_constructor(self) # Atomically check if another thread has preempted us and, if not, swap # in the new object we just created. 
If someone has preempted us, we @@ -546,7 +743,7 @@ def _AddPropertiesForNonRepeatedCompositeField(field, cls): def _AddPropertiesForExtensions(descriptor, cls): """Adds properties for all fields in this protocol message type.""" extension_dict = descriptor.extensions_by_name - for extension_name, extension_field in extension_dict.iteritems(): + for extension_name, extension_field in extension_dict.items(): constant_name = extension_name.upper() + "_FIELD_NUMBER" setattr(cls, constant_name, extension_field.number) @@ -601,30 +798,41 @@ def _AddListFieldsMethod(message_descriptor, cls): """Helper for _AddMessageMethods().""" def ListFields(self): - all_fields = [item for item in self._fields.iteritems() if _IsPresent(item)] + all_fields = [item for item in self._fields.items() if _IsPresent(item)] all_fields.sort(key = lambda item: item[0].number) return all_fields cls.ListFields = ListFields +_Proto3HasError = 'Protocol message has no non-repeated submessage field "%s"' +_Proto2HasError = 'Protocol message has no non-repeated field "%s"' def _AddHasFieldMethod(message_descriptor, cls): """Helper for _AddMessageMethods().""" - singular_fields = {} + is_proto3 = (message_descriptor.syntax == "proto3") + error_msg = _Proto3HasError if is_proto3 else _Proto2HasError + + hassable_fields = {} for field in message_descriptor.fields: - if field.label != _FieldDescriptor.LABEL_REPEATED: - singular_fields[field.name] = field - # Fields inside oneofs are never repeated (enforced by the compiler). - for field in message_descriptor.oneofs: - singular_fields[field.name] = field + if field.label == _FieldDescriptor.LABEL_REPEATED: + continue + # For proto3, only submessages and fields inside a oneof have presence. + if (is_proto3 and field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE and + not field.containing_oneof): + continue + hassable_fields[field.name] = field + + if not is_proto3: + # Fields inside oneofs are never repeated (enforced by the compiler). 
+ for oneof in message_descriptor.oneofs: + hassable_fields[oneof.name] = oneof def HasField(self, field_name): try: - field = singular_fields[field_name] + field = hassable_fields[field_name] except KeyError: - raise ValueError( - 'Protocol message has no singular "%s" field.' % field_name) + raise ValueError(error_msg % field_name) if isinstance(field, descriptor_mod.OneofDescriptor): try: @@ -654,9 +862,15 @@ def _AddClearFieldMethod(message_descriptor, cls): else: return except KeyError: - raise ValueError('Protocol message has no "%s" field.' % field_name) + raise ValueError('Protocol message %s() has no "%s" field.' % + (message_descriptor.name, field_name)) if field in self._fields: + # To match the C++ implementation, we need to invalidate iterators + # for map fields when ClearField() happens. + if hasattr(self._fields[field], 'InvalidateIterators'): + self._fields[field].InvalidateIterators() + # Note: If the field is a sub-message, its listener will still point # at us. That's fine, because the worst than can happen is that it # will call _Modified() and invalidate our byte size. Big deal. @@ -685,16 +899,6 @@ def _AddClearExtensionMethod(cls): cls.ClearExtension = ClearExtension -def _AddClearMethod(message_descriptor, cls): - """Helper for _AddMessageMethods().""" - def Clear(self): - # Clear fields. - self._fields = {} - self._unknown_fields = () - self._Modified() - cls.Clear = Clear - - def _AddHasExtensionMethod(cls): """Helper for _AddMessageMethods().""" def HasExtension(self, extension_handle): @@ -709,6 +913,38 @@ def _AddHasExtensionMethod(cls): return extension_handle in self._fields cls.HasExtension = HasExtension +def _InternalUnpackAny(msg): + """Unpacks Any message and returns the unpacked message. + + This internal method is differnt from public Any Unpack method which takes + the target message as argument. _InternalUnpackAny method does not have + target message type and need to find the message type in descriptor pool. 
+ + Args: + msg: An Any message to be unpacked. + + Returns: + The unpacked message. + """ + type_url = msg.type_url + db = symbol_database.Default() + + if not type_url: + return None + + # TODO(haberman): For now we just strip the hostname. Better logic will be + # required. + type_name = type_url.split("/")[-1] + descriptor = db.pool.FindMessageTypeByName(type_name) + + if descriptor is None: + return None + + message_class = db.GetPrototype(descriptor) + message = message_class() + + message.ParseFromString(msg.value) + return message def _AddEqualsMethod(message_descriptor, cls): """Helper for _AddMessageMethods().""" @@ -720,6 +956,12 @@ def _AddEqualsMethod(message_descriptor, cls): if self is other: return True + if self.DESCRIPTOR.full_name == _AnyFullTypeName: + any_a = _InternalUnpackAny(self) + any_b = _InternalUnpackAny(other) + if any_a and any_b: + return any_a == any_b + if not self.ListFields() == other.ListFields(): return False @@ -741,6 +983,13 @@ def _AddStrMethod(message_descriptor, cls): cls.__str__ = __str__ +def _AddReprMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + def __repr__(self): + return text_format.MessageToString(self) + cls.__repr__ = __repr__ + + def _AddUnicodeMethod(unused_message_descriptor, cls): """Helper for _AddMessageMethods().""" @@ -749,16 +998,6 @@ def _AddUnicodeMethod(unused_message_descriptor, cls): cls.__unicode__ = __unicode__ -def _AddSetListenerMethod(cls): - """Helper for _AddMessageMethods().""" - def SetListener(self, listener): - if listener is None: - self._listener = message_listener_mod.NullMessageListener() - else: - self._listener = listener - cls._SetListener = SetListener - - def _BytesForNonRepeatedElement(value, field_number, field_type): """Returns the number of bytes needed to serialize a non-repeated element. 
The returned byte count includes space for tag information and any @@ -845,7 +1084,7 @@ def _AddMergeFromStringMethod(message_descriptor, cls): except (IndexError, TypeError): # Now ord(buf[p:p+1]) == ord('') gets TypeError. raise message_mod.DecodeError('Truncated message.') - except struct.error, e: + except struct.error as e: raise message_mod.DecodeError(e) return length # Return this for legacy reasons. cls.MergeFromString = MergeFromString @@ -853,6 +1092,7 @@ def _AddMergeFromStringMethod(message_descriptor, cls): local_ReadTag = decoder.ReadTag local_SkipField = decoder.SkipField decoders_by_tag = cls._decoders_by_tag + is_proto3 = message_descriptor.syntax == "proto3" def InternalParse(self, buffer, pos, end): self._Modified() @@ -866,9 +1106,11 @@ def _AddMergeFromStringMethod(message_descriptor, cls): new_pos = local_SkipField(buffer, new_pos, end, tag_bytes) if new_pos == -1: return pos - if not unknown_field_list: - unknown_field_list = self._unknown_fields = [] - unknown_field_list.append((tag_bytes, buffer[value_start_pos:new_pos])) + if not is_proto3: + if not unknown_field_list: + unknown_field_list = self._unknown_fields = [] + unknown_field_list.append( + (tag_bytes, buffer[value_start_pos:new_pos])) pos = new_pos else: pos = field_decoder(buffer, new_pos, end, self, field_dict) @@ -909,6 +1151,9 @@ def _AddIsInitializedMethod(message_descriptor, cls): for field, value in list(self._fields.items()): # dict can change size! 
if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: if field.label == _FieldDescriptor.LABEL_REPEATED: + if (field.message_type.has_options and + field.message_type.GetOptions().map_entry): + continue for element in value: if not element.IsInitialized(): if errors is not None: @@ -944,16 +1189,26 @@ def _AddIsInitializedMethod(message_descriptor, cls): else: name = field.name - if field.label == _FieldDescriptor.LABEL_REPEATED: - for i in xrange(len(value)): + if _IsMapField(field): + if _IsMessageMapField(field): + for key in value: + element = value[key] + prefix = "%s[%s]." % (name, key) + sub_errors = element.FindInitializationErrors() + errors += [prefix + error for error in sub_errors] + else: + # ScalarMaps can't have any initialization errors. + pass + elif field.label == _FieldDescriptor.LABEL_REPEATED: + for i in range(len(value)): element = value[i] prefix = "%s[%d]." % (name, i) sub_errors = element.FindInitializationErrors() - errors += [ prefix + error for error in sub_errors ] + errors += [prefix + error for error in sub_errors] else: prefix = name + "." sub_errors = value.FindInitializationErrors() - errors += [ prefix + error for error in sub_errors ] + errors += [prefix + error for error in sub_errors] return errors @@ -975,7 +1230,7 @@ def _AddMergeFromMethod(cls): fields = self._fields - for field, value in msg._fields.iteritems(): + for field, value in msg._fields.items(): if field.label == LABEL_REPEATED: field_value = fields.get(field) if field_value is None: @@ -993,6 +1248,8 @@ def _AddMergeFromMethod(cls): field_value.MergeFrom(value) else: self._fields[field] = value + if field.containing_oneof: + self._UpdateOneofState(field) if msg._unknown_fields: if not self._unknown_fields: @@ -1020,6 +1277,32 @@ def _AddWhichOneofMethod(message_descriptor, cls): cls.WhichOneof = WhichOneof +def _Clear(self): + # Clear fields. 
+ self._fields = {} + self._unknown_fields = () + self._oneofs = {} + self._Modified() + + +def _DiscardUnknownFields(self): + self._unknown_fields = [] + for field, value in self.ListFields(): + if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + if field.label == _FieldDescriptor.LABEL_REPEATED: + for sub_message in value: + sub_message.DiscardUnknownFields() + else: + value.DiscardUnknownFields() + + +def _SetListener(self, listener): + if listener is None: + self._listener = message_listener_mod.NullMessageListener() + else: + self._listener = listener + + def _AddMessageMethods(message_descriptor, cls): """Adds implementations of all Message methods to cls.""" _AddListFieldsMethod(message_descriptor, cls) @@ -1028,11 +1311,10 @@ def _AddMessageMethods(message_descriptor, cls): if message_descriptor.is_extendable: _AddClearExtensionMethod(cls) _AddHasExtensionMethod(cls) - _AddClearMethod(message_descriptor, cls) _AddEqualsMethod(message_descriptor, cls) _AddStrMethod(message_descriptor, cls) + _AddReprMethod(message_descriptor, cls) _AddUnicodeMethod(message_descriptor, cls) - _AddSetListenerMethod(cls) _AddByteSizeMethod(message_descriptor, cls) _AddSerializeToStringMethod(message_descriptor, cls) _AddSerializePartialToStringMethod(message_descriptor, cls) @@ -1040,6 +1322,11 @@ def _AddMessageMethods(message_descriptor, cls): _AddIsInitializedMethod(message_descriptor, cls) _AddMergeFromMethod(cls) _AddWhichOneofMethod(message_descriptor, cls) + # Adds methods which do not depend on cls. + cls.Clear = _Clear + cls.DiscardUnknownFields = _DiscardUnknownFields + cls._SetListener = _SetListener + def _AddPrivateHelperMethods(message_descriptor, cls): """Adds implementation of private helper methods to cls.""" @@ -1232,11 +1519,10 @@ class _ExtensionDict(object): # It's slightly wasteful to lookup the type checker each time, # but we expect this to be a vanishingly uncommon case anyway. 
- type_checker = type_checkers.GetTypeChecker( - extension_handle) + type_checker = type_checkers.GetTypeChecker(extension_handle) # pylint: disable=protected-access self._extended_message._fields[extension_handle] = ( - type_checker.CheckValue(value)) + type_checker.CheckValue(value)) self._extended_message._Modified() def _FindExtensionByName(self, name): @@ -1249,3 +1535,14 @@ class _ExtensionDict(object): Extension field descriptor. """ return self._extended_message._extensions_by_name.get(name, None) + + def _FindExtensionByNumber(self, number): + """Tries to find a known extension with the field number. + + Args: + number: Extension field number. + + Returns: + Extension field descriptor. + """ + return self._extended_message._extensions_by_number.get(number, None) diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py index d59815d00..6dc2fffe2 100755 --- a/python/google/protobuf/internal/reflection_test.py +++ b/python/google/protobuf/internal/reflection_test.py @@ -1,4 +1,4 @@ -#! /usr/bin/python +#! /usr/bin/env python # -*- coding: utf-8 -*- # # Protocol Buffers - Google's data interchange format @@ -35,14 +35,17 @@ pure-Python protocol compiler. 
""" -__author__ = 'robinson@google.com (Will Robinson)' - import copy import gc import operator +import six import struct -from google.apputils import basetest +try: + import unittest2 as unittest #PY26 +except ImportError: + import unittest + from google.protobuf import unittest_import_pb2 from google.protobuf import unittest_mset_pb2 from google.protobuf import unittest_pb2 @@ -54,6 +57,7 @@ from google.protobuf import text_format from google.protobuf.internal import api_implementation from google.protobuf.internal import more_extensions_pb2 from google.protobuf.internal import more_messages_pb2 +from google.protobuf.internal import message_set_extensions_pb2 from google.protobuf.internal import wire_format from google.protobuf.internal import test_util from google.protobuf.internal import decoder @@ -104,7 +108,7 @@ class _MiniDecoder(object): return self._pos == len(self._bytes) -class ReflectionTest(basetest.TestCase): +class ReflectionTest(unittest.TestCase): def assertListsEqual(self, values, others): self.assertEqual(len(values), len(others)) @@ -116,11 +120,13 @@ class ReflectionTest(basetest.TestCase): proto = unittest_pb2.TestAllTypes( optional_int32=24, optional_double=54.321, - optional_string='optional_string') + optional_string='optional_string', + optional_float=None) self.assertEqual(24, proto.optional_int32) self.assertEqual(54.321, proto.optional_double) self.assertEqual('optional_string', proto.optional_string) + self.assertFalse(proto.HasField("optional_float")) def testRepeatedScalarConstructor(self): # Constructor with only repeated scalar types should succeed. 
@@ -128,12 +134,14 @@ class ReflectionTest(basetest.TestCase): repeated_int32=[1, 2, 3, 4], repeated_double=[1.23, 54.321], repeated_bool=[True, False, False], - repeated_string=["optional_string"]) + repeated_string=["optional_string"], + repeated_float=None) - self.assertEquals([1, 2, 3, 4], list(proto.repeated_int32)) - self.assertEquals([1.23, 54.321], list(proto.repeated_double)) - self.assertEquals([True, False, False], list(proto.repeated_bool)) - self.assertEquals(["optional_string"], list(proto.repeated_string)) + self.assertEqual([1, 2, 3, 4], list(proto.repeated_int32)) + self.assertEqual([1.23, 54.321], list(proto.repeated_double)) + self.assertEqual([True, False, False], list(proto.repeated_bool)) + self.assertEqual(["optional_string"], list(proto.repeated_string)) + self.assertEqual([], list(proto.repeated_float)) def testRepeatedCompositeConstructor(self): # Constructor with only repeated composite types should succeed. @@ -152,18 +160,18 @@ class ReflectionTest(basetest.TestCase): unittest_pb2.TestAllTypes.RepeatedGroup(a=1), unittest_pb2.TestAllTypes.RepeatedGroup(a=2)]) - self.assertEquals( + self.assertEqual( [unittest_pb2.TestAllTypes.NestedMessage( bb=unittest_pb2.TestAllTypes.FOO), unittest_pb2.TestAllTypes.NestedMessage( bb=unittest_pb2.TestAllTypes.BAR)], list(proto.repeated_nested_message)) - self.assertEquals( + self.assertEqual( [unittest_pb2.ForeignMessage(c=-43), unittest_pb2.ForeignMessage(c=45324), unittest_pb2.ForeignMessage(c=12)], list(proto.repeated_foreign_message)) - self.assertEquals( + self.assertEqual( [unittest_pb2.TestAllTypes.RepeatedGroup(), unittest_pb2.TestAllTypes.RepeatedGroup(a=1), unittest_pb2.TestAllTypes.RepeatedGroup(a=2)], @@ -184,23 +192,25 @@ class ReflectionTest(basetest.TestCase): repeated_foreign_message=[ unittest_pb2.ForeignMessage(c=-43), unittest_pb2.ForeignMessage(c=45324), - unittest_pb2.ForeignMessage(c=12)]) + unittest_pb2.ForeignMessage(c=12)], + optional_nested_message=None) self.assertEqual(24, 
proto.optional_int32) self.assertEqual('optional_string', proto.optional_string) - self.assertEquals([1.23, 54.321], list(proto.repeated_double)) - self.assertEquals([True, False, False], list(proto.repeated_bool)) - self.assertEquals( + self.assertEqual([1.23, 54.321], list(proto.repeated_double)) + self.assertEqual([True, False, False], list(proto.repeated_bool)) + self.assertEqual( [unittest_pb2.TestAllTypes.NestedMessage( bb=unittest_pb2.TestAllTypes.FOO), unittest_pb2.TestAllTypes.NestedMessage( bb=unittest_pb2.TestAllTypes.BAR)], list(proto.repeated_nested_message)) - self.assertEquals( + self.assertEqual( [unittest_pb2.ForeignMessage(c=-43), unittest_pb2.ForeignMessage(c=45324), unittest_pb2.ForeignMessage(c=12)], list(proto.repeated_foreign_message)) + self.assertFalse(proto.HasField("optional_nested_message")) def testConstructorTypeError(self): self.assertRaises( @@ -224,18 +234,18 @@ class ReflectionTest(basetest.TestCase): def testConstructorInvalidatesCachedByteSize(self): message = unittest_pb2.TestAllTypes(optional_int32 = 12) - self.assertEquals(2, message.ByteSize()) + self.assertEqual(2, message.ByteSize()) message = unittest_pb2.TestAllTypes( optional_nested_message = unittest_pb2.TestAllTypes.NestedMessage()) - self.assertEquals(3, message.ByteSize()) + self.assertEqual(3, message.ByteSize()) message = unittest_pb2.TestAllTypes(repeated_int32 = [12]) - self.assertEquals(3, message.ByteSize()) + self.assertEqual(3, message.ByteSize()) message = unittest_pb2.TestAllTypes( repeated_nested_message = [unittest_pb2.TestAllTypes.NestedMessage()]) - self.assertEquals(3, message.ByteSize()) + self.assertEqual(3, message.ByteSize()) def testSimpleHasBits(self): # Test a scalar. 
@@ -469,7 +479,7 @@ class ReflectionTest(basetest.TestCase): proto.repeated_string.extend(['foo', 'bar']) proto.repeated_string.extend([]) proto.repeated_string.append('baz') - proto.repeated_string.extend(str(x) for x in xrange(2)) + proto.repeated_string.extend(str(x) for x in range(2)) proto.optional_int32 = 21 proto.repeated_bool # Access but don't set anything; should not be listed. self.assertEqual( @@ -611,14 +621,18 @@ class ReflectionTest(basetest.TestCase): def TestGetAndDeserialize(field_name, value, expected_type): proto = unittest_pb2.TestAllTypes() setattr(proto, field_name, value) - self.assertTrue(isinstance(getattr(proto, field_name), expected_type)) + self.assertIsInstance(getattr(proto, field_name), expected_type) proto2 = unittest_pb2.TestAllTypes() proto2.ParseFromString(proto.SerializeToString()) - self.assertTrue(isinstance(getattr(proto2, field_name), expected_type)) + self.assertIsInstance(getattr(proto2, field_name), expected_type) TestGetAndDeserialize('optional_int32', 1, int) TestGetAndDeserialize('optional_int32', 1 << 30, int) TestGetAndDeserialize('optional_uint32', 1 << 30, int) + try: + integer_64 = long + except NameError: # Python3 + integer_64 = int if struct.calcsize('L') == 4: # Python only has signed ints, so 32-bit python can't fit an uint32 # in an int. 
@@ -626,10 +640,10 @@ class ReflectionTest(basetest.TestCase): else: # 64-bit python can fit uint32 inside an int TestGetAndDeserialize('optional_uint32', 1 << 31, int) - TestGetAndDeserialize('optional_int64', 1 << 30, long) - TestGetAndDeserialize('optional_int64', 1 << 60, long) - TestGetAndDeserialize('optional_uint64', 1 << 30, long) - TestGetAndDeserialize('optional_uint64', 1 << 60, long) + TestGetAndDeserialize('optional_int64', 1 << 30, integer_64) + TestGetAndDeserialize('optional_int64', 1 << 60, integer_64) + TestGetAndDeserialize('optional_uint64', 1 << 30, integer_64) + TestGetAndDeserialize('optional_uint64', 1 << 60, integer_64) def testSingleScalarBoundsChecking(self): def TestMinAndMaxIntegers(field_name, expected_min, expected_max): @@ -755,18 +769,18 @@ class ReflectionTest(basetest.TestCase): def testEnum_KeysAndValues(self): self.assertEqual(['FOREIGN_FOO', 'FOREIGN_BAR', 'FOREIGN_BAZ'], - unittest_pb2.ForeignEnum.keys()) + list(unittest_pb2.ForeignEnum.keys())) self.assertEqual([4, 5, 6], - unittest_pb2.ForeignEnum.values()) + list(unittest_pb2.ForeignEnum.values())) self.assertEqual([('FOREIGN_FOO', 4), ('FOREIGN_BAR', 5), ('FOREIGN_BAZ', 6)], - unittest_pb2.ForeignEnum.items()) + list(unittest_pb2.ForeignEnum.items())) proto = unittest_pb2.TestAllTypes() - self.assertEqual(['FOO', 'BAR', 'BAZ', 'NEG'], proto.NestedEnum.keys()) - self.assertEqual([1, 2, 3, -1], proto.NestedEnum.values()) + self.assertEqual(['FOO', 'BAR', 'BAZ', 'NEG'], list(proto.NestedEnum.keys())) + self.assertEqual([1, 2, 3, -1], list(proto.NestedEnum.values())) self.assertEqual([('FOO', 1), ('BAR', 2), ('BAZ', 3), ('NEG', -1)], - proto.NestedEnum.items()) + list(proto.NestedEnum.items())) def testRepeatedScalars(self): proto = unittest_pb2.TestAllTypes() @@ -805,7 +819,7 @@ class ReflectionTest(basetest.TestCase): self.assertEqual([5, 25, 20, 15, 30], proto.repeated_int32[:]) # Test slice assignment with an iterator - proto.repeated_int32[1:4] = (i for i in xrange(3)) + 
proto.repeated_int32[1:4] = (i for i in range(3)) self.assertEqual([5, 0, 1, 2, 30], proto.repeated_int32) # Test slice assignment. @@ -896,7 +910,7 @@ class ReflectionTest(basetest.TestCase): self.assertTrue(proto.repeated_nested_message) self.assertEqual(2, len(proto.repeated_nested_message)) self.assertListsEqual([m0, m1], proto.repeated_nested_message) - self.assertTrue(isinstance(m0, unittest_pb2.TestAllTypes.NestedMessage)) + self.assertIsInstance(m0, unittest_pb2.TestAllTypes.NestedMessage) # Test out-of-bounds indices. self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__, @@ -1008,9 +1022,8 @@ class ReflectionTest(basetest.TestCase): containing_type=None, nested_types=[], enum_types=[], fields=[foo_field_descriptor], extensions=[], options=descriptor_pb2.MessageOptions()) - class MyProtoClass(message.Message): + class MyProtoClass(six.with_metaclass(reflection.GeneratedProtocolMessageType, message.Message)): DESCRIPTOR = mydescriptor - __metaclass__ = reflection.GeneratedProtocolMessageType myproto_instance = MyProtoClass() self.assertEqual(0, myproto_instance.foo_field) self.assertTrue(not myproto_instance.HasField('foo_field')) @@ -1050,14 +1063,13 @@ class ReflectionTest(basetest.TestCase): new_field.label = descriptor_pb2.FieldDescriptorProto.LABEL_REPEATED desc = descriptor.MakeDescriptor(desc_proto) - self.assertTrue(desc.fields_by_name.has_key('name')) - self.assertTrue(desc.fields_by_name.has_key('year')) - self.assertTrue(desc.fields_by_name.has_key('automatic')) - self.assertTrue(desc.fields_by_name.has_key('price')) - self.assertTrue(desc.fields_by_name.has_key('owners')) - - class CarMessage(message.Message): - __metaclass__ = reflection.GeneratedProtocolMessageType + self.assertTrue('name' in desc.fields_by_name) + self.assertTrue('year' in desc.fields_by_name) + self.assertTrue('automatic' in desc.fields_by_name) + self.assertTrue('price' in desc.fields_by_name) + self.assertTrue('owners' in desc.fields_by_name) + + class 
CarMessage(six.with_metaclass(reflection.GeneratedProtocolMessageType, message.Message)): DESCRIPTOR = desc prius = CarMessage() @@ -1175,7 +1187,7 @@ class ReflectionTest(basetest.TestCase): self.assertTrue(1 in unittest_pb2.TestAllExtensions._extensions_by_number) # Make sure extensions haven't been registered into types that shouldn't # have any. - self.assertEquals(0, len(unittest_pb2.TestAllTypes._extensions_by_name)) + self.assertEqual(0, len(unittest_pb2.TestAllTypes._extensions_by_name)) # If message A directly contains message B, and # a.HasField('b') is currently False, then mutating any @@ -1252,15 +1264,18 @@ class ReflectionTest(basetest.TestCase): # Try something that *is* an extension handle, just not for # this message... - unknown_handle = more_extensions_pb2.optional_int_extension - self.assertRaises(KeyError, extendee_proto.HasExtension, - unknown_handle) - self.assertRaises(KeyError, extendee_proto.ClearExtension, - unknown_handle) - self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__, - unknown_handle) - self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__, - unknown_handle, 5) + for unknown_handle in (more_extensions_pb2.optional_int_extension, + more_extensions_pb2.optional_message_extension, + more_extensions_pb2.repeated_int_extension, + more_extensions_pb2.repeated_message_extension): + self.assertRaises(KeyError, extendee_proto.HasExtension, + unknown_handle) + self.assertRaises(KeyError, extendee_proto.ClearExtension, + unknown_handle) + self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__, + unknown_handle) + self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__, + unknown_handle, 5) # Try call HasExtension() with a valid handle, but for a # *repeated* field. (Just as with non-extension repeated @@ -1496,18 +1511,18 @@ class ReflectionTest(basetest.TestCase): test_util.SetAllNonLazyFields(proto) # Clear the message. 
proto.Clear() - self.assertEquals(proto.ByteSize(), 0) + self.assertEqual(proto.ByteSize(), 0) empty_proto = unittest_pb2.TestAllTypes() - self.assertEquals(proto, empty_proto) + self.assertEqual(proto, empty_proto) # Test if extensions which were set are cleared. proto = unittest_pb2.TestAllExtensions() test_util.SetAllExtensions(proto) # Clear the message. proto.Clear() - self.assertEquals(proto.ByteSize(), 0) + self.assertEqual(proto.ByteSize(), 0) empty_proto = unittest_pb2.TestAllExtensions() - self.assertEquals(proto, empty_proto) + self.assertEqual(proto, empty_proto) def testDisconnectingBeforeClear(self): proto = unittest_pb2.TestAllTypes() @@ -1618,7 +1633,7 @@ class ReflectionTest(basetest.TestCase): self.assertFalse(proto.IsInitialized(errors)) self.assertEqual(errors, ['a', 'b', 'c']) - @basetest.unittest.skipIf( + @unittest.skipIf( api_implementation.Type() != 'cpp' or api_implementation.Version() != 2, 'Errors are only available from the most recent C++ implementation.') def testFileDescriptorErrors(self): @@ -1660,30 +1675,29 @@ class ReflectionTest(basetest.TestCase): setattr, proto, 'optional_bytes', u'unicode object') # Check that the default value is of python's 'unicode' type. - self.assertEqual(type(proto.optional_string), unicode) + self.assertEqual(type(proto.optional_string), six.text_type) - proto.optional_string = unicode('Testing') + proto.optional_string = six.text_type('Testing') self.assertEqual(proto.optional_string, str('Testing')) # Assign a value of type 'str' which can be encoded in UTF-8. proto.optional_string = str('Testing') - self.assertEqual(proto.optional_string, unicode('Testing')) + self.assertEqual(proto.optional_string, six.text_type('Testing')) - # Try to assign a 'str' value which contains bytes that aren't 7-bit ASCII. + # Try to assign a 'bytes' object which contains non-UTF-8. 
self.assertRaises(ValueError, setattr, proto, 'optional_string', b'a\x80a') - if str is bytes: # PY2 - # Assign a 'str' object which contains a UTF-8 encoded string. - self.assertRaises(ValueError, - setattr, proto, 'optional_string', 'Тест') - else: - proto.optional_string = 'Тест' - # No exception thrown. + # No exception: Assign already encoded UTF-8 bytes to a string field. + utf8_bytes = u'Тест'.encode('utf-8') + proto.optional_string = utf8_bytes + # No exception: Assign the a non-ascii unicode object. + proto.optional_string = u'Тест' + # No exception thrown (normal str assignment containing ASCII). proto.optional_string = 'abc' def testStringUTF8Serialization(self): - proto = unittest_mset_pb2.TestMessageSet() - extension_message = unittest_mset_pb2.TestMessageSetExtension2 + proto = message_set_extensions_pb2.TestMessageSet() + extension_message = message_set_extensions_pb2.TestMessageSetExtension2 extension = extension_message.message_set_extension test_utf8 = u'Тест' @@ -1703,19 +1717,18 @@ class ReflectionTest(basetest.TestCase): bytes_read = raw.MergeFromString(serialized) self.assertEqual(len(serialized), bytes_read) - message2 = unittest_mset_pb2.TestMessageSetExtension2() + message2 = message_set_extensions_pb2.TestMessageSetExtension2() self.assertEqual(1, len(raw.item)) # Check that the type_id is the same as the tag ID in the .proto file. - self.assertEqual(raw.item[0].type_id, 1547769) + self.assertEqual(raw.item[0].type_id, 98418634) # Check the actual bytes on the wire. 
- self.assertTrue( - raw.item[0].message.endswith(test_utf8_bytes)) + self.assertTrue(raw.item[0].message.endswith(test_utf8_bytes)) bytes_read = message2.MergeFromString(raw.item[0].message) self.assertEqual(len(raw.item[0].message), bytes_read) - self.assertEqual(type(message2.str), unicode) + self.assertEqual(type(message2.str), six.text_type) self.assertEqual(message2.str, test_utf8) # The pure Python API throws an exception on MergeFromString(), @@ -1739,7 +1752,7 @@ class ReflectionTest(basetest.TestCase): def testBytesInTextFormat(self): proto = unittest_pb2.TestAllTypes(optional_bytes=b'\x00\x7f\x80\xff') self.assertEqual(u'optional_bytes: "\\000\\177\\200\\377"\n', - unicode(proto)) + six.text_type(proto)) def testEmptyNestedMessage(self): proto = unittest_pb2.TestAllTypes() @@ -1774,12 +1787,29 @@ class ReflectionTest(basetest.TestCase): proto.optionalgroup.SetInParent() self.assertTrue(proto.HasField('optionalgroup')) + def testPackageInitializationImport(self): + """Test that we can import nested messages from their __init__.py. + + Such setup is not trivial since at the time of processing of __init__.py one + can't refer to its submodules by name in code, so expressions like + google.protobuf.internal.import_test_package.inner_pb2 + don't work. They do work in imports, so we have assign an alias at import + and then use that alias in generated code. + """ + # We import here since it's the import that used to fail, and we want + # the failure to have the right context. + # pylint: disable=g-import-not-at-top + from google.protobuf.internal import import_test_package + # pylint: enable=g-import-not-at-top + msg = import_test_package.myproto.Outer() + # Just check the default value. + self.assertEqual(57, msg.inner.value) # Since we had so many tests for protocol buffer equality, we broke these out # into separate TestCase classes. 
-class TestAllTypesEqualityTest(basetest.TestCase): +class TestAllTypesEqualityTest(unittest.TestCase): def setUp(self): self.first_proto = unittest_pb2.TestAllTypes() @@ -1795,7 +1825,7 @@ class TestAllTypesEqualityTest(basetest.TestCase): self.assertEqual(self.first_proto, self.second_proto) -class FullProtosEqualityTest(basetest.TestCase): +class FullProtosEqualityTest(unittest.TestCase): """Equality tests using completely-full protos as a starting point.""" @@ -1881,7 +1911,7 @@ class FullProtosEqualityTest(basetest.TestCase): self.assertEqual(self.first_proto, self.second_proto) -class ExtensionEqualityTest(basetest.TestCase): +class ExtensionEqualityTest(unittest.TestCase): def testExtensionEquality(self): first_proto = unittest_pb2.TestAllExtensions() @@ -1914,7 +1944,7 @@ class ExtensionEqualityTest(basetest.TestCase): self.assertEqual(first_proto, second_proto) -class MutualRecursionEqualityTest(basetest.TestCase): +class MutualRecursionEqualityTest(unittest.TestCase): def testEqualityWithMutualRecursion(self): first_proto = unittest_pb2.TestMutualRecursionA() @@ -1926,7 +1956,7 @@ class MutualRecursionEqualityTest(basetest.TestCase): self.assertEqual(first_proto, second_proto) -class ByteSizeTest(basetest.TestCase): +class ByteSizeTest(unittest.TestCase): def setUp(self): self.proto = unittest_pb2.TestAllTypes() @@ -2222,7 +2252,7 @@ class ByteSizeTest(basetest.TestCase): # * Handling of empty submessages (with and without "has" # bits set). 
-class SerializationTest(basetest.TestCase): +class SerializationTest(unittest.TestCase): def testSerializeEmtpyMessage(self): first_proto = unittest_pb2.TestAllTypes() @@ -2289,7 +2319,7 @@ class SerializationTest(basetest.TestCase): test_util.SetAllFields(first_proto) serialized = first_proto.SerializeToString() - for truncation_point in xrange(len(serialized) + 1): + for truncation_point in range(len(serialized) + 1): try: second_proto = unittest_pb2.TestAllTypes() unknown_fields = unittest_pb2.TestEmptyMessage() @@ -2366,13 +2396,15 @@ class SerializationTest(basetest.TestCase): self.assertEqual(42, second_proto.optional_nested_message.bb) def testMessageSetWireFormat(self): - proto = unittest_mset_pb2.TestMessageSet() - extension_message1 = unittest_mset_pb2.TestMessageSetExtension1 - extension_message2 = unittest_mset_pb2.TestMessageSetExtension2 + proto = message_set_extensions_pb2.TestMessageSet() + extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1 + extension_message2 = message_set_extensions_pb2.TestMessageSetExtension2 extension1 = extension_message1.message_set_extension extension2 = extension_message2.message_set_extension + extension3 = message_set_extensions_pb2.message_set_extension3 proto.Extensions[extension1].i = 123 proto.Extensions[extension2].str = 'foo' + proto.Extensions[extension3].text = 'bar' # Serialize using the MessageSet wire format (this is specified in the # .proto file). 
@@ -2384,27 +2416,34 @@ class SerializationTest(basetest.TestCase): self.assertEqual( len(serialized), raw.MergeFromString(serialized)) - self.assertEqual(2, len(raw.item)) + self.assertEqual(3, len(raw.item)) - message1 = unittest_mset_pb2.TestMessageSetExtension1() + message1 = message_set_extensions_pb2.TestMessageSetExtension1() self.assertEqual( len(raw.item[0].message), message1.MergeFromString(raw.item[0].message)) self.assertEqual(123, message1.i) - message2 = unittest_mset_pb2.TestMessageSetExtension2() + message2 = message_set_extensions_pb2.TestMessageSetExtension2() self.assertEqual( len(raw.item[1].message), message2.MergeFromString(raw.item[1].message)) self.assertEqual('foo', message2.str) + message3 = message_set_extensions_pb2.TestMessageSetExtension3() + self.assertEqual( + len(raw.item[2].message), + message3.MergeFromString(raw.item[2].message)) + self.assertEqual('bar', message3.text) + # Deserialize using the MessageSet wire format. - proto2 = unittest_mset_pb2.TestMessageSet() + proto2 = message_set_extensions_pb2.TestMessageSet() self.assertEqual( len(serialized), proto2.MergeFromString(serialized)) self.assertEqual(123, proto2.Extensions[extension1].i) self.assertEqual('foo', proto2.Extensions[extension2].str) + self.assertEqual('bar', proto2.Extensions[extension3].text) # Check byte size. self.assertEqual(proto2.ByteSize(), len(serialized)) @@ -2417,39 +2456,39 @@ class SerializationTest(basetest.TestCase): # Add an item. item = raw.item.add() - item.type_id = 1545008 - extension_message1 = unittest_mset_pb2.TestMessageSetExtension1 - message1 = unittest_mset_pb2.TestMessageSetExtension1() + item.type_id = 98418603 + extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1 + message1 = message_set_extensions_pb2.TestMessageSetExtension1() message1.i = 12345 item.message = message1.SerializeToString() # Add a second, unknown extension. 
item = raw.item.add() - item.type_id = 1545009 - extension_message1 = unittest_mset_pb2.TestMessageSetExtension1 - message1 = unittest_mset_pb2.TestMessageSetExtension1() + item.type_id = 98418604 + extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1 + message1 = message_set_extensions_pb2.TestMessageSetExtension1() message1.i = 12346 item.message = message1.SerializeToString() # Add another unknown extension. item = raw.item.add() - item.type_id = 1545010 - message1 = unittest_mset_pb2.TestMessageSetExtension2() + item.type_id = 98418605 + message1 = message_set_extensions_pb2.TestMessageSetExtension2() message1.str = 'foo' item.message = message1.SerializeToString() serialized = raw.SerializeToString() # Parse message using the message set wire format. - proto = unittest_mset_pb2.TestMessageSet() + proto = message_set_extensions_pb2.TestMessageSet() self.assertEqual( len(serialized), proto.MergeFromString(serialized)) # Check that the message parsed well. - extension_message1 = unittest_mset_pb2.TestMessageSetExtension1 + extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1 extension1 = extension_message1.message_set_extension - self.assertEquals(12345, proto.Extensions[extension1].i) + self.assertEqual(12345, proto.Extensions[extension1].i) def testUnknownFields(self): proto = unittest_pb2.TestAllTypes() @@ -2734,9 +2773,10 @@ class SerializationTest(basetest.TestCase): def testInitArgsUnknownFieldName(self): def InitalizeEmptyMessageWithExtraKeywordArg(): unused_proto = unittest_pb2.TestEmptyMessage(unknown='unknown') - self._CheckRaises(ValueError, - InitalizeEmptyMessageWithExtraKeywordArg, - 'Protocol message has no "unknown" field.') + self._CheckRaises( + ValueError, + InitalizeEmptyMessageWithExtraKeywordArg, + 'Protocol message TestEmptyMessage has no "unknown" field.') def testInitRequiredKwargs(self): proto = unittest_pb2.TestRequired(a=1, b=1, c=1) @@ -2773,10 +2813,10 @@ class 
SerializationTest(basetest.TestCase): self.assertEqual(3, proto.repeated_int32[2]) -class OptionsTest(basetest.TestCase): +class OptionsTest(unittest.TestCase): def testMessageOptions(self): - proto = unittest_mset_pb2.TestMessageSet() + proto = message_set_extensions_pb2.TestMessageSet() self.assertEqual(True, proto.DESCRIPTOR.GetOptions().message_set_wire_format) proto = unittest_pb2.TestAllTypes() @@ -2795,13 +2835,16 @@ class OptionsTest(basetest.TestCase): proto.packed_double.append(3.0) for field_descriptor, _ in proto.ListFields(): self.assertEqual(True, field_descriptor.GetOptions().packed) - self.assertEqual(reflection._FieldDescriptor.LABEL_REPEATED, + self.assertEqual(descriptor.FieldDescriptor.LABEL_REPEATED, field_descriptor.label) -class ClassAPITest(basetest.TestCase): +class ClassAPITest(unittest.TestCase): + @unittest.skipIf( + api_implementation.Type() == 'cpp' and api_implementation.Version() == 2, + 'C++ implementation requires a call to MakeDescriptor()') def testMakeClassWithNestedDescriptor(self): leaf_desc = descriptor.Descriptor('leaf', 'package.parent.child.leaf', '', containing_type=None, fields=[], @@ -2882,13 +2925,19 @@ class ClassAPITest(basetest.TestCase): def testParsingFlatClassWithExplicitClassDeclaration(self): """Test that the generated class can parse a flat message.""" + # TODO(xiaofeng): This test fails with cpp implemetnation in the call + # of six.with_metaclass(). The other two callsites of with_metaclass + # in this file are both excluded from cpp test, so it might be expected + # to fail. Need someone more familiar with the python code to take a + # look at this. 
+ if api_implementation.Type() != 'python': + return file_descriptor = descriptor_pb2.FileDescriptorProto() file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('A')) msg_descriptor = descriptor.MakeDescriptor( file_descriptor.message_type[0]) - class MessageClass(message.Message): - __metaclass__ = reflection.GeneratedProtocolMessageType + class MessageClass(six.with_metaclass(reflection.GeneratedProtocolMessageType, message.Message)): DESCRIPTOR = msg_descriptor msg = MessageClass() msg_str = ( @@ -2931,4 +2980,4 @@ class ClassAPITest(basetest.TestCase): self.assertEqual(msg.bar.baz.deep, 4) if __name__ == '__main__': - basetest.main() + unittest.main() diff --git a/python/google/protobuf/internal/service_reflection_test.py b/python/google/protobuf/internal/service_reflection_test.py index 07dcf4453..62900b1d1 100755 --- a/python/google/protobuf/internal/service_reflection_test.py +++ b/python/google/protobuf/internal/service_reflection_test.py @@ -1,4 +1,4 @@ -#! /usr/bin/python +#! /usr/bin/env python # # Protocol Buffers - Google's data interchange format # Copyright 2008 Google Inc. All rights reserved. 
@@ -34,13 +34,18 @@ __author__ = 'petar@google.com (Petar Petrov)' -from google.apputils import basetest + +try: + import unittest2 as unittest #PY26 +except ImportError: + import unittest + from google.protobuf import unittest_pb2 from google.protobuf import service_reflection from google.protobuf import service -class FooUnitTest(basetest.TestCase): +class FooUnitTest(unittest.TestCase): def testService(self): class MockRpcChannel(service.RpcChannel): @@ -80,7 +85,7 @@ class FooUnitTest(basetest.TestCase): self.assertEqual('Method Bar not implemented.', rpc_controller.failure_message) self.assertEqual(None, self.callback_response) - + class MyServiceImpl(unittest_pb2.TestService): def Foo(self, rpc_controller, request, done): self.foo_called = True @@ -118,19 +123,18 @@ class FooUnitTest(basetest.TestCase): rpc_controller = 'controller' request = 'request' - # GetDescriptor now static, still works as instance method for compatability + # GetDescriptor now static, still works as instance method for compatibility self.assertEqual(unittest_pb2.TestService_Stub.GetDescriptor(), stub.GetDescriptor()) # Invoke method. stub.Foo(rpc_controller, request, MyCallback) - self.assertTrue(isinstance(self.callback_response, - unittest_pb2.FooResponse)) + self.assertIsInstance(self.callback_response, unittest_pb2.FooResponse) self.assertEqual(request, channel.request) self.assertEqual(rpc_controller, channel.controller) self.assertEqual(stub.GetDescriptor().methods[0], channel.method) if __name__ == '__main__': - basetest.main() + unittest.main() diff --git a/python/google/protobuf/internal/symbol_database_test.py b/python/google/protobuf/internal/symbol_database_test.py index 47572d58c..c99b426dc 100644 --- a/python/google/protobuf/internal/symbol_database_test.py +++ b/python/google/protobuf/internal/symbol_database_test.py @@ -1,4 +1,4 @@ -#! /usr/bin/python +#! /usr/bin/env python # # Protocol Buffers - Google's data interchange format # Copyright 2008 Google Inc. 
All rights reserved. @@ -32,24 +32,33 @@ """Tests for google.protobuf.symbol_database.""" -from google.apputils import basetest +try: + import unittest2 as unittest #PY26 +except ImportError: + import unittest + from google.protobuf import unittest_pb2 +from google.protobuf import descriptor from google.protobuf import symbol_database - -class SymbolDatabaseTest(basetest.TestCase): +class SymbolDatabaseTest(unittest.TestCase): def _Database(self): - db = symbol_database.SymbolDatabase() - # Register representative types from unittest_pb2. - db.RegisterFileDescriptor(unittest_pb2.DESCRIPTOR) - db.RegisterMessage(unittest_pb2.TestAllTypes) - db.RegisterMessage(unittest_pb2.TestAllTypes.NestedMessage) - db.RegisterMessage(unittest_pb2.TestAllTypes.OptionalGroup) - db.RegisterMessage(unittest_pb2.TestAllTypes.RepeatedGroup) - db.RegisterEnumDescriptor(unittest_pb2.ForeignEnum.DESCRIPTOR) - db.RegisterEnumDescriptor(unittest_pb2.TestAllTypes.NestedEnum.DESCRIPTOR) - return db + # TODO(b/17734095): Remove this difference when the C++ implementation + # supports multiple databases. + if descriptor._USE_C_DESCRIPTORS: + return symbol_database.Default() + else: + db = symbol_database.SymbolDatabase() + # Register representative types from unittest_pb2. 
+ db.RegisterFileDescriptor(unittest_pb2.DESCRIPTOR) + db.RegisterMessage(unittest_pb2.TestAllTypes) + db.RegisterMessage(unittest_pb2.TestAllTypes.NestedMessage) + db.RegisterMessage(unittest_pb2.TestAllTypes.OptionalGroup) + db.RegisterMessage(unittest_pb2.TestAllTypes.RepeatedGroup) + db.RegisterEnumDescriptor(unittest_pb2.ForeignEnum.DESCRIPTOR) + db.RegisterEnumDescriptor(unittest_pb2.TestAllTypes.NestedEnum.DESCRIPTOR) + return db def testGetPrototype(self): instance = self._Database().GetPrototype( @@ -64,57 +73,57 @@ class SymbolDatabaseTest(basetest.TestCase): messages['protobuf_unittest.TestAllTypes']) def testGetSymbol(self): - self.assertEquals( + self.assertEqual( unittest_pb2.TestAllTypes, self._Database().GetSymbol( 'protobuf_unittest.TestAllTypes')) - self.assertEquals( + self.assertEqual( unittest_pb2.TestAllTypes.NestedMessage, self._Database().GetSymbol( 'protobuf_unittest.TestAllTypes.NestedMessage')) - self.assertEquals( + self.assertEqual( unittest_pb2.TestAllTypes.OptionalGroup, self._Database().GetSymbol( 'protobuf_unittest.TestAllTypes.OptionalGroup')) - self.assertEquals( + self.assertEqual( unittest_pb2.TestAllTypes.RepeatedGroup, self._Database().GetSymbol( 'protobuf_unittest.TestAllTypes.RepeatedGroup')) def testEnums(self): # Check registration of types in the pool. 
- self.assertEquals( + self.assertEqual( 'protobuf_unittest.ForeignEnum', self._Database().pool.FindEnumTypeByName( 'protobuf_unittest.ForeignEnum').full_name) - self.assertEquals( + self.assertEqual( 'protobuf_unittest.TestAllTypes.NestedEnum', self._Database().pool.FindEnumTypeByName( 'protobuf_unittest.TestAllTypes.NestedEnum').full_name) def testFindMessageTypeByName(self): - self.assertEquals( + self.assertEqual( 'protobuf_unittest.TestAllTypes', self._Database().pool.FindMessageTypeByName( 'protobuf_unittest.TestAllTypes').full_name) - self.assertEquals( + self.assertEqual( 'protobuf_unittest.TestAllTypes.NestedMessage', self._Database().pool.FindMessageTypeByName( 'protobuf_unittest.TestAllTypes.NestedMessage').full_name) def testFindFindContainingSymbol(self): # Lookup based on either enum or message. - self.assertEquals( + self.assertEqual( 'google/protobuf/unittest.proto', self._Database().pool.FindFileContainingSymbol( 'protobuf_unittest.TestAllTypes.NestedEnum').name) - self.assertEquals( + self.assertEqual( 'google/protobuf/unittest.proto', self._Database().pool.FindFileContainingSymbol( 'protobuf_unittest.TestAllTypes').name) def testFindFileByName(self): - self.assertEquals( + self.assertEqual( 'google/protobuf/unittest.proto', self._Database().pool.FindFileByName( 'google/protobuf/unittest.proto').name) if __name__ == '__main__': - basetest.main() + unittest.main() diff --git a/python/google/protobuf/internal/test_bad_identifiers.proto b/python/google/protobuf/internal/test_bad_identifiers.proto index 9eb18cb09..c4860ea88 100644 --- a/python/google/protobuf/internal/test_bad_identifiers.proto +++ b/python/google/protobuf/internal/test_bad_identifiers.proto @@ -30,6 +30,7 @@ // Author: kenton@google.com (Kenton Varda) +syntax = "proto2"; package protobuf_unittest; diff --git a/python/google/protobuf/internal/test_util.py b/python/google/protobuf/internal/test_util.py index 787f46505..2c805599b 100755 --- a/python/google/protobuf/internal/test_util.py 
+++ b/python/google/protobuf/internal/test_util.py @@ -38,15 +38,23 @@ __author__ = 'robinson@google.com (Will Robinson)' import os.path +import sys + from google.protobuf import unittest_import_pb2 from google.protobuf import unittest_pb2 +from google.protobuf import descriptor_pb2 +# Tests whether the given TestAllTypes message is proto2 or not. +# This is used to gate several fields/features that only exist +# for the proto2 version of the message. +def IsProto2(message): + return message.DESCRIPTOR.syntax == "proto2" def SetAllNonLazyFields(message): """Sets every non-lazy field in the message to a unique value. Args: - message: A unittest_pb2.TestAllTypes instance. + message: A TestAllTypes instance. """ # @@ -69,7 +77,8 @@ def SetAllNonLazyFields(message): message.optional_string = u'115' message.optional_bytes = b'116' - message.optionalgroup.a = 117 + if IsProto2(message): + message.optionalgroup.a = 117 message.optional_nested_message.bb = 118 message.optional_foreign_message.c = 119 message.optional_import_message.d = 120 @@ -77,7 +86,8 @@ def SetAllNonLazyFields(message): message.optional_nested_enum = unittest_pb2.TestAllTypes.BAZ message.optional_foreign_enum = unittest_pb2.FOREIGN_BAZ - message.optional_import_enum = unittest_import_pb2.IMPORT_BAZ + if IsProto2(message): + message.optional_import_enum = unittest_import_pb2.IMPORT_BAZ message.optional_string_piece = u'124' message.optional_cord = u'125' @@ -102,7 +112,8 @@ def SetAllNonLazyFields(message): message.repeated_string.append(u'215') message.repeated_bytes.append(b'216') - message.repeatedgroup.add().a = 217 + if IsProto2(message): + message.repeatedgroup.add().a = 217 message.repeated_nested_message.add().bb = 218 message.repeated_foreign_message.add().c = 219 message.repeated_import_message.add().d = 220 @@ -110,7 +121,8 @@ def SetAllNonLazyFields(message): message.repeated_nested_enum.append(unittest_pb2.TestAllTypes.BAR) message.repeated_foreign_enum.append(unittest_pb2.FOREIGN_BAR) - 
message.repeated_import_enum.append(unittest_import_pb2.IMPORT_BAR) + if IsProto2(message): + message.repeated_import_enum.append(unittest_import_pb2.IMPORT_BAR) message.repeated_string_piece.append(u'224') message.repeated_cord.append(u'225') @@ -132,7 +144,8 @@ def SetAllNonLazyFields(message): message.repeated_string.append(u'315') message.repeated_bytes.append(b'316') - message.repeatedgroup.add().a = 317 + if IsProto2(message): + message.repeatedgroup.add().a = 317 message.repeated_nested_message.add().bb = 318 message.repeated_foreign_message.add().c = 319 message.repeated_import_message.add().d = 320 @@ -140,7 +153,8 @@ def SetAllNonLazyFields(message): message.repeated_nested_enum.append(unittest_pb2.TestAllTypes.BAZ) message.repeated_foreign_enum.append(unittest_pb2.FOREIGN_BAZ) - message.repeated_import_enum.append(unittest_import_pb2.IMPORT_BAZ) + if IsProto2(message): + message.repeated_import_enum.append(unittest_import_pb2.IMPORT_BAZ) message.repeated_string_piece.append(u'324') message.repeated_cord.append(u'325') @@ -149,28 +163,29 @@ def SetAllNonLazyFields(message): # Fields that have defaults. 
# - message.default_int32 = 401 - message.default_int64 = 402 - message.default_uint32 = 403 - message.default_uint64 = 404 - message.default_sint32 = 405 - message.default_sint64 = 406 - message.default_fixed32 = 407 - message.default_fixed64 = 408 - message.default_sfixed32 = 409 - message.default_sfixed64 = 410 - message.default_float = 411 - message.default_double = 412 - message.default_bool = False - message.default_string = '415' - message.default_bytes = b'416' - - message.default_nested_enum = unittest_pb2.TestAllTypes.FOO - message.default_foreign_enum = unittest_pb2.FOREIGN_FOO - message.default_import_enum = unittest_import_pb2.IMPORT_FOO - - message.default_string_piece = '424' - message.default_cord = '425' + if IsProto2(message): + message.default_int32 = 401 + message.default_int64 = 402 + message.default_uint32 = 403 + message.default_uint64 = 404 + message.default_sint32 = 405 + message.default_sint64 = 406 + message.default_fixed32 = 407 + message.default_fixed64 = 408 + message.default_sfixed32 = 409 + message.default_sfixed64 = 410 + message.default_float = 411 + message.default_double = 412 + message.default_bool = False + message.default_string = '415' + message.default_bytes = b'416' + + message.default_nested_enum = unittest_pb2.TestAllTypes.FOO + message.default_foreign_enum = unittest_pb2.FOREIGN_FOO + message.default_import_enum = unittest_import_pb2.IMPORT_FOO + + message.default_string_piece = '424' + message.default_cord = '425' message.oneof_uint32 = 601 message.oneof_nested_message.bb = 602 @@ -386,7 +401,8 @@ def ExpectAllFieldsSet(test_case, message): test_case.assertTrue(message.HasField('optional_string')) test_case.assertTrue(message.HasField('optional_bytes')) - test_case.assertTrue(message.HasField('optionalgroup')) + if IsProto2(message): + test_case.assertTrue(message.HasField('optionalgroup')) test_case.assertTrue(message.HasField('optional_nested_message')) 
test_case.assertTrue(message.HasField('optional_foreign_message')) test_case.assertTrue(message.HasField('optional_import_message')) @@ -398,7 +414,8 @@ def ExpectAllFieldsSet(test_case, message): test_case.assertTrue(message.HasField('optional_nested_enum')) test_case.assertTrue(message.HasField('optional_foreign_enum')) - test_case.assertTrue(message.HasField('optional_import_enum')) + if IsProto2(message): + test_case.assertTrue(message.HasField('optional_import_enum')) test_case.assertTrue(message.HasField('optional_string_piece')) test_case.assertTrue(message.HasField('optional_cord')) @@ -419,7 +436,8 @@ def ExpectAllFieldsSet(test_case, message): test_case.assertEqual('115', message.optional_string) test_case.assertEqual(b'116', message.optional_bytes) - test_case.assertEqual(117, message.optionalgroup.a) + if IsProto2(message): + test_case.assertEqual(117, message.optionalgroup.a) test_case.assertEqual(118, message.optional_nested_message.bb) test_case.assertEqual(119, message.optional_foreign_message.c) test_case.assertEqual(120, message.optional_import_message.d) @@ -430,8 +448,9 @@ def ExpectAllFieldsSet(test_case, message): message.optional_nested_enum) test_case.assertEqual(unittest_pb2.FOREIGN_BAZ, message.optional_foreign_enum) - test_case.assertEqual(unittest_import_pb2.IMPORT_BAZ, - message.optional_import_enum) + if IsProto2(message): + test_case.assertEqual(unittest_import_pb2.IMPORT_BAZ, + message.optional_import_enum) # ----------------------------------------------------------------- @@ -451,13 +470,15 @@ def ExpectAllFieldsSet(test_case, message): test_case.assertEqual(2, len(message.repeated_string)) test_case.assertEqual(2, len(message.repeated_bytes)) - test_case.assertEqual(2, len(message.repeatedgroup)) + if IsProto2(message): + test_case.assertEqual(2, len(message.repeatedgroup)) test_case.assertEqual(2, len(message.repeated_nested_message)) test_case.assertEqual(2, len(message.repeated_foreign_message)) test_case.assertEqual(2, 
len(message.repeated_import_message)) test_case.assertEqual(2, len(message.repeated_nested_enum)) test_case.assertEqual(2, len(message.repeated_foreign_enum)) - test_case.assertEqual(2, len(message.repeated_import_enum)) + if IsProto2(message): + test_case.assertEqual(2, len(message.repeated_import_enum)) test_case.assertEqual(2, len(message.repeated_string_piece)) test_case.assertEqual(2, len(message.repeated_cord)) @@ -478,7 +499,8 @@ def ExpectAllFieldsSet(test_case, message): test_case.assertEqual('215', message.repeated_string[0]) test_case.assertEqual(b'216', message.repeated_bytes[0]) - test_case.assertEqual(217, message.repeatedgroup[0].a) + if IsProto2(message): + test_case.assertEqual(217, message.repeatedgroup[0].a) test_case.assertEqual(218, message.repeated_nested_message[0].bb) test_case.assertEqual(219, message.repeated_foreign_message[0].c) test_case.assertEqual(220, message.repeated_import_message[0].d) @@ -488,8 +510,9 @@ def ExpectAllFieldsSet(test_case, message): message.repeated_nested_enum[0]) test_case.assertEqual(unittest_pb2.FOREIGN_BAR, message.repeated_foreign_enum[0]) - test_case.assertEqual(unittest_import_pb2.IMPORT_BAR, - message.repeated_import_enum[0]) + if IsProto2(message): + test_case.assertEqual(unittest_import_pb2.IMPORT_BAR, + message.repeated_import_enum[0]) test_case.assertEqual(301, message.repeated_int32[1]) test_case.assertEqual(302, message.repeated_int64[1]) @@ -507,7 +530,8 @@ def ExpectAllFieldsSet(test_case, message): test_case.assertEqual('315', message.repeated_string[1]) test_case.assertEqual(b'316', message.repeated_bytes[1]) - test_case.assertEqual(317, message.repeatedgroup[1].a) + if IsProto2(message): + test_case.assertEqual(317, message.repeatedgroup[1].a) test_case.assertEqual(318, message.repeated_nested_message[1].bb) test_case.assertEqual(319, message.repeated_foreign_message[1].c) test_case.assertEqual(320, message.repeated_import_message[1].d) @@ -517,53 +541,55 @@ def ExpectAllFieldsSet(test_case, 
message): message.repeated_nested_enum[1]) test_case.assertEqual(unittest_pb2.FOREIGN_BAZ, message.repeated_foreign_enum[1]) - test_case.assertEqual(unittest_import_pb2.IMPORT_BAZ, - message.repeated_import_enum[1]) + if IsProto2(message): + test_case.assertEqual(unittest_import_pb2.IMPORT_BAZ, + message.repeated_import_enum[1]) # ----------------------------------------------------------------- - test_case.assertTrue(message.HasField('default_int32')) - test_case.assertTrue(message.HasField('default_int64')) - test_case.assertTrue(message.HasField('default_uint32')) - test_case.assertTrue(message.HasField('default_uint64')) - test_case.assertTrue(message.HasField('default_sint32')) - test_case.assertTrue(message.HasField('default_sint64')) - test_case.assertTrue(message.HasField('default_fixed32')) - test_case.assertTrue(message.HasField('default_fixed64')) - test_case.assertTrue(message.HasField('default_sfixed32')) - test_case.assertTrue(message.HasField('default_sfixed64')) - test_case.assertTrue(message.HasField('default_float')) - test_case.assertTrue(message.HasField('default_double')) - test_case.assertTrue(message.HasField('default_bool')) - test_case.assertTrue(message.HasField('default_string')) - test_case.assertTrue(message.HasField('default_bytes')) - - test_case.assertTrue(message.HasField('default_nested_enum')) - test_case.assertTrue(message.HasField('default_foreign_enum')) - test_case.assertTrue(message.HasField('default_import_enum')) - - test_case.assertEqual(401, message.default_int32) - test_case.assertEqual(402, message.default_int64) - test_case.assertEqual(403, message.default_uint32) - test_case.assertEqual(404, message.default_uint64) - test_case.assertEqual(405, message.default_sint32) - test_case.assertEqual(406, message.default_sint64) - test_case.assertEqual(407, message.default_fixed32) - test_case.assertEqual(408, message.default_fixed64) - test_case.assertEqual(409, message.default_sfixed32) - test_case.assertEqual(410, 
message.default_sfixed64) - test_case.assertEqual(411, message.default_float) - test_case.assertEqual(412, message.default_double) - test_case.assertEqual(False, message.default_bool) - test_case.assertEqual('415', message.default_string) - test_case.assertEqual(b'416', message.default_bytes) - - test_case.assertEqual(unittest_pb2.TestAllTypes.FOO, - message.default_nested_enum) - test_case.assertEqual(unittest_pb2.FOREIGN_FOO, - message.default_foreign_enum) - test_case.assertEqual(unittest_import_pb2.IMPORT_FOO, - message.default_import_enum) + if IsProto2(message): + test_case.assertTrue(message.HasField('default_int32')) + test_case.assertTrue(message.HasField('default_int64')) + test_case.assertTrue(message.HasField('default_uint32')) + test_case.assertTrue(message.HasField('default_uint64')) + test_case.assertTrue(message.HasField('default_sint32')) + test_case.assertTrue(message.HasField('default_sint64')) + test_case.assertTrue(message.HasField('default_fixed32')) + test_case.assertTrue(message.HasField('default_fixed64')) + test_case.assertTrue(message.HasField('default_sfixed32')) + test_case.assertTrue(message.HasField('default_sfixed64')) + test_case.assertTrue(message.HasField('default_float')) + test_case.assertTrue(message.HasField('default_double')) + test_case.assertTrue(message.HasField('default_bool')) + test_case.assertTrue(message.HasField('default_string')) + test_case.assertTrue(message.HasField('default_bytes')) + + test_case.assertTrue(message.HasField('default_nested_enum')) + test_case.assertTrue(message.HasField('default_foreign_enum')) + test_case.assertTrue(message.HasField('default_import_enum')) + + test_case.assertEqual(401, message.default_int32) + test_case.assertEqual(402, message.default_int64) + test_case.assertEqual(403, message.default_uint32) + test_case.assertEqual(404, message.default_uint64) + test_case.assertEqual(405, message.default_sint32) + test_case.assertEqual(406, message.default_sint64) + 
test_case.assertEqual(407, message.default_fixed32) + test_case.assertEqual(408, message.default_fixed64) + test_case.assertEqual(409, message.default_sfixed32) + test_case.assertEqual(410, message.default_sfixed64) + test_case.assertEqual(411, message.default_float) + test_case.assertEqual(412, message.default_double) + test_case.assertEqual(False, message.default_bool) + test_case.assertEqual('415', message.default_string) + test_case.assertEqual(b'416', message.default_bytes) + + test_case.assertEqual(unittest_pb2.TestAllTypes.FOO, + message.default_nested_enum) + test_case.assertEqual(unittest_pb2.FOREIGN_FOO, + message.default_foreign_enum) + test_case.assertEqual(unittest_import_pb2.IMPORT_FOO, + message.default_import_enum) def GoldenFile(filename): @@ -578,6 +604,14 @@ def GoldenFile(filename): return open(full_path, 'rb') path = os.path.join(path, '..') + # Search internally. + path = '.' + full_path = os.path.join(path, 'third_party/py/google/protobuf/testdata', + filename) + if os.path.exists(full_path): + # Found it. Load the golden file from the testdata directory. + return open(full_path, 'rb') + raise RuntimeError( 'Could not find golden files. This test must be run from within the ' 'protobuf source package so that it can read test data files from the ' @@ -594,7 +628,7 @@ def SetAllPackedFields(message): """Sets every field in the message to a unique value. Args: - message: A unittest_pb2.TestPackedTypes instance. + message: A TestPackedTypes instance. """ message.packed_int32.extend([601, 701]) message.packed_int64.extend([602, 702]) diff --git a/python/google/protobuf/internal/text_encoding_test.py b/python/google/protobuf/internal/text_encoding_test.py index db0222bd3..c7d182c44 100755 --- a/python/google/protobuf/internal/text_encoding_test.py +++ b/python/google/protobuf/internal/text_encoding_test.py @@ -1,4 +1,4 @@ -#! /usr/bin/python +#! /usr/bin/env python # # Protocol Buffers - Google's data interchange format # Copyright 2008 Google Inc. 
All rights reserved. @@ -32,7 +32,11 @@ """Tests for google.protobuf.text_encoding.""" -from google.apputils import basetest +try: + import unittest2 as unittest #PY26 +except ImportError: + import unittest + from google.protobuf import text_encoding TEST_VALUES = [ @@ -50,19 +54,19 @@ TEST_VALUES = [ b"\010\011\012\013\014\015")] -class TextEncodingTestCase(basetest.TestCase): +class TextEncodingTestCase(unittest.TestCase): def testCEscape(self): for escaped, escaped_utf8, unescaped in TEST_VALUES: - self.assertEquals(escaped, + self.assertEqual(escaped, text_encoding.CEscape(unescaped, as_utf8=False)) - self.assertEquals(escaped_utf8, + self.assertEqual(escaped_utf8, text_encoding.CEscape(unescaped, as_utf8=True)) def testCUnescape(self): for escaped, escaped_utf8, unescaped in TEST_VALUES: - self.assertEquals(unescaped, text_encoding.CUnescape(escaped)) - self.assertEquals(unescaped, text_encoding.CUnescape(escaped_utf8)) + self.assertEqual(unescaped, text_encoding.CUnescape(escaped)) + self.assertEqual(unescaped, text_encoding.CUnescape(escaped_utf8)) if __name__ == "__main__": - basetest.main() + unittest.main() diff --git a/python/google/protobuf/internal/text_format_test.py b/python/google/protobuf/internal/text_format_test.py index b0a3a5f72..ab2bf05b8 100755 --- a/python/google/protobuf/internal/text_format_test.py +++ b/python/google/protobuf/internal/text_format_test.py @@ -1,4 +1,4 @@ -#! /usr/bin/python +#! /usr/bin/env python # # Protocol Buffers - Google's data interchange format # Copyright 2008 Google Inc. All rights reserved. 
@@ -34,16 +34,42 @@ __author__ = 'kenton@google.com (Kenton Varda)' + import re +import six +import string -from google.apputils import basetest -from google.protobuf import text_format +try: + import unittest2 as unittest #PY26 +except ImportError: + import unittest + +from google.protobuf.internal import _parameterized + +from google.protobuf import map_unittest_pb2 +from google.protobuf import unittest_mset_pb2 +from google.protobuf import unittest_pb2 +from google.protobuf import unittest_proto3_arena_pb2 from google.protobuf.internal import api_implementation from google.protobuf.internal import test_util -from google.protobuf import unittest_pb2 -from google.protobuf import unittest_mset_pb2 +from google.protobuf.internal import message_set_extensions_pb2 +from google.protobuf import text_format -class TextFormatTest(basetest.TestCase): + +# Low-level nuts-n-bolts tests. +class SimpleTextFormatTests(unittest.TestCase): + + # The members of _QUOTES are formatted into a regexp template that + # expects single characters. Therefore it's an error (in addition to being + # non-sensical in the first place) to try to specify a "quote mark" that is + # more than one character. + def testQuoteMarksAreSingleChars(self): + for quote in text_format._QUOTES: + self.assertEqual(1, len(quote)) + + +# Base class with some common functionality. 
+class TextFormatBase(unittest.TestCase): def ReadGolden(self, golden_filename): with test_util.GoldenFile(golden_filename) as f: @@ -55,70 +81,26 @@ class TextFormatTest(basetest.TestCase): self.assertMultiLineEqual(text, ''.join(golden_lines)) def CompareToGoldenText(self, text, golden_text): - self.assertMultiLineEqual(text, golden_text) + self.assertEqual(text, golden_text) - def testPrintAllFields(self): - message = unittest_pb2.TestAllTypes() - test_util.SetAllFields(message) - self.CompareToGoldenFile( - self.RemoveRedundantZeros(text_format.MessageToString(message)), - 'text_format_unittest_data_oneof_implemented.txt') - - def testPrintInIndexOrder(self): - message = unittest_pb2.TestFieldOrderings() - message.my_string = '115' - message.my_int = 101 - message.my_float = 111 - self.CompareToGoldenText( - self.RemoveRedundantZeros(text_format.MessageToString( - message, use_index_order=True)), - 'my_string: \"115\"\nmy_int: 101\nmy_float: 111\n') - self.CompareToGoldenText( - self.RemoveRedundantZeros(text_format.MessageToString( - message)), 'my_int: 101\nmy_string: \"115\"\nmy_float: 111\n') - - def testPrintAllExtensions(self): - message = unittest_pb2.TestAllExtensions() - test_util.SetAllExtensions(message) - self.CompareToGoldenFile( - self.RemoveRedundantZeros(text_format.MessageToString(message)), - 'text_format_unittest_extensions_data.txt') + def RemoveRedundantZeros(self, text): + # Some platforms print 1e+5 as 1e+005. This is fine, but we need to remove + # these zeros in order to match the golden file. + text = text.replace('e+0','e+').replace('e+0','e+') \ + .replace('e-0','e-').replace('e-0','e-') + # Floating point fields are printed with .0 suffix even if they are + # actualy integer numbers. 
+ text = re.compile('\.0$', re.MULTILINE).sub('', text) + return text - def testPrintAllFieldsPointy(self): - message = unittest_pb2.TestAllTypes() - test_util.SetAllFields(message) - self.CompareToGoldenFile( - self.RemoveRedundantZeros( - text_format.MessageToString(message, pointy_brackets=True)), - 'text_format_unittest_data_pointy_oneof.txt') - def testPrintAllExtensionsPointy(self): - message = unittest_pb2.TestAllExtensions() - test_util.SetAllExtensions(message) - self.CompareToGoldenFile( - self.RemoveRedundantZeros(text_format.MessageToString( - message, pointy_brackets=True)), - 'text_format_unittest_extensions_data_pointy.txt') +@_parameterized.Parameters( + (unittest_pb2), + (unittest_proto3_arena_pb2)) +class TextFormatTest(TextFormatBase): - def testPrintMessageSet(self): - message = unittest_mset_pb2.TestMessageSetContainer() - ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension - ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension - message.message_set.Extensions[ext1].i = 23 - message.message_set.Extensions[ext2].str = 'foo' - self.CompareToGoldenText( - text_format.MessageToString(message), - 'message_set {\n' - ' [protobuf_unittest.TestMessageSetExtension1] {\n' - ' i: 23\n' - ' }\n' - ' [protobuf_unittest.TestMessageSetExtension2] {\n' - ' str: \"foo\"\n' - ' }\n' - '}\n') - - def testPrintExotic(self): - message = unittest_pb2.TestAllTypes() + def testPrintExotic(self, message_module): + message = message_module.TestAllTypes() message.repeated_int64.append(-9223372036854775808) message.repeated_uint64.append(18446744073709551615) message.repeated_double.append(123.456) @@ -137,61 +119,44 @@ class TextFormatTest(basetest.TestCase): ' "\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\""\n' 'repeated_string: "\\303\\274\\352\\234\\237"\n') - def testPrintExoticUnicodeSubclass(self): - class UnicodeSub(unicode): + def testPrintExoticUnicodeSubclass(self, message_module): + class UnicodeSub(six.text_type): 
pass - message = unittest_pb2.TestAllTypes() + message = message_module.TestAllTypes() message.repeated_string.append(UnicodeSub(u'\u00fc\ua71f')) self.CompareToGoldenText( text_format.MessageToString(message), 'repeated_string: "\\303\\274\\352\\234\\237"\n') - def testPrintNestedMessageAsOneLine(self): - message = unittest_pb2.TestAllTypes() + def testPrintNestedMessageAsOneLine(self, message_module): + message = message_module.TestAllTypes() msg = message.repeated_nested_message.add() msg.bb = 42 self.CompareToGoldenText( text_format.MessageToString(message, as_one_line=True), 'repeated_nested_message { bb: 42 }') - def testPrintRepeatedFieldsAsOneLine(self): - message = unittest_pb2.TestAllTypes() + def testPrintRepeatedFieldsAsOneLine(self, message_module): + message = message_module.TestAllTypes() message.repeated_int32.append(1) message.repeated_int32.append(1) message.repeated_int32.append(3) - message.repeated_string.append("Google") - message.repeated_string.append("Zurich") + message.repeated_string.append('Google') + message.repeated_string.append('Zurich') self.CompareToGoldenText( text_format.MessageToString(message, as_one_line=True), 'repeated_int32: 1 repeated_int32: 1 repeated_int32: 3 ' 'repeated_string: "Google" repeated_string: "Zurich"') - def testPrintNestedNewLineInStringAsOneLine(self): - message = unittest_pb2.TestAllTypes() - message.optional_string = "a\nnew\nline" + def testPrintNestedNewLineInStringAsOneLine(self, message_module): + message = message_module.TestAllTypes() + message.optional_string = 'a\nnew\nline' self.CompareToGoldenText( text_format.MessageToString(message, as_one_line=True), 'optional_string: "a\\nnew\\nline"') - def testPrintMessageSetAsOneLine(self): - message = unittest_mset_pb2.TestMessageSetContainer() - ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension - ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension - message.message_set.Extensions[ext1].i = 23 - 
message.message_set.Extensions[ext2].str = 'foo' - self.CompareToGoldenText( - text_format.MessageToString(message, as_one_line=True), - 'message_set {' - ' [protobuf_unittest.TestMessageSetExtension1] {' - ' i: 23' - ' }' - ' [protobuf_unittest.TestMessageSetExtension2] {' - ' str: \"foo\"' - ' }' - ' }') - - def testPrintExoticAsOneLine(self): - message = unittest_pb2.TestAllTypes() + def testPrintExoticAsOneLine(self, message_module): + message = message_module.TestAllTypes() message.repeated_int64.append(-9223372036854775808) message.repeated_uint64.append(18446744073709551615) message.repeated_double.append(123.456) @@ -211,8 +176,8 @@ class TextFormatTest(basetest.TestCase): '"\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\""' ' repeated_string: "\\303\\274\\352\\234\\237"') - def testRoundTripExoticAsOneLine(self): - message = unittest_pb2.TestAllTypes() + def testRoundTripExoticAsOneLine(self, message_module): + message = message_module.TestAllTypes() message.repeated_int64.append(-9223372036854775808) message.repeated_uint64.append(18446744073709551615) message.repeated_double.append(123.456) @@ -224,33 +189,33 @@ class TextFormatTest(basetest.TestCase): # Test as_utf8 = False. wire_text = text_format.MessageToString( message, as_one_line=True, as_utf8=False) - parsed_message = unittest_pb2.TestAllTypes() + parsed_message = message_module.TestAllTypes() r = text_format.Parse(wire_text, parsed_message) self.assertIs(r, parsed_message) - self.assertEquals(message, parsed_message) + self.assertEqual(message, parsed_message) # Test as_utf8 = True. 
wire_text = text_format.MessageToString( message, as_one_line=True, as_utf8=True) - parsed_message = unittest_pb2.TestAllTypes() + parsed_message = message_module.TestAllTypes() r = text_format.Parse(wire_text, parsed_message) self.assertIs(r, parsed_message) - self.assertEquals(message, parsed_message, + self.assertEqual(message, parsed_message, '\n%s != %s' % (message, parsed_message)) - def testPrintRawUtf8String(self): - message = unittest_pb2.TestAllTypes() + def testPrintRawUtf8String(self, message_module): + message = message_module.TestAllTypes() message.repeated_string.append(u'\u00fc\ua71f') text = text_format.MessageToString(message, as_utf8=True) self.CompareToGoldenText(text, 'repeated_string: "\303\274\352\234\237"\n') - parsed_message = unittest_pb2.TestAllTypes() + parsed_message = message_module.TestAllTypes() text_format.Parse(text, parsed_message) - self.assertEquals(message, parsed_message, + self.assertEqual(message, parsed_message, '\n%s != %s' % (message, parsed_message)) - def testPrintFloatFormat(self): + def testPrintFloatFormat(self, message_module): # Check that float_format argument is passed to sub-message formatting. - message = unittest_pb2.NestedTestAllTypes() + message = message_module.NestedTestAllTypes() # We use 1.25 as it is a round number in binary. The proto 32-bit float # will not gain additional imprecise digits as a 64-bit Python float and # show up in its str. 32-bit 1.2 is noisy when extended to 64-bit: @@ -272,93 +237,62 @@ class TextFormatTest(basetest.TestCase): text_message = text_format.MessageToString(message, float_format='.15g') self.CompareToGoldenText( self.RemoveRedundantZeros(text_message), - 'payload {{\n {}\n {}\n {}\n {}\n}}\n'.format(*formatted_fields)) + 'payload {{\n {0}\n {1}\n {2}\n {3}\n}}\n'.format(*formatted_fields)) # as_one_line=True is a separate code branch where float_format is passed. 
text_message = text_format.MessageToString(message, as_one_line=True, float_format='.15g') self.CompareToGoldenText( self.RemoveRedundantZeros(text_message), - 'payload {{ {} {} {} {} }}'.format(*formatted_fields)) + 'payload {{ {0} {1} {2} {3} }}'.format(*formatted_fields)) - def testMessageToString(self): - message = unittest_pb2.ForeignMessage() + def testMessageToString(self, message_module): + message = message_module.ForeignMessage() message.c = 123 self.assertEqual('c: 123\n', str(message)) - def RemoveRedundantZeros(self, text): - # Some platforms print 1e+5 as 1e+005. This is fine, but we need to remove - # these zeros in order to match the golden file. - text = text.replace('e+0','e+').replace('e+0','e+') \ - .replace('e-0','e-').replace('e-0','e-') - # Floating point fields are printed with .0 suffix even if they are - # actualy integer numbers. - text = re.compile('\.0$', re.MULTILINE).sub('', text) - return text - - def testParseGolden(self): - golden_text = '\n'.join(self.ReadGolden('text_format_unittest_data.txt')) - parsed_message = unittest_pb2.TestAllTypes() - r = text_format.Parse(golden_text, parsed_message) - self.assertIs(r, parsed_message) - - message = unittest_pb2.TestAllTypes() - test_util.SetAllFields(message) - self.assertEquals(message, parsed_message) - - def testParseGoldenExtensions(self): - golden_text = '\n'.join(self.ReadGolden( - 'text_format_unittest_extensions_data.txt')) - parsed_message = unittest_pb2.TestAllExtensions() - text_format.Parse(golden_text, parsed_message) - - message = unittest_pb2.TestAllExtensions() - test_util.SetAllExtensions(message) - self.assertEquals(message, parsed_message) - - def testParseAllFields(self): - message = unittest_pb2.TestAllTypes() + def testPrintField(self, message_module): + message = message_module.TestAllTypes() + field = message.DESCRIPTOR.fields_by_name['optional_float'] + value = message.optional_float + out = text_format.TextWriter(False) + text_format.PrintField(field, value, 
out) + self.assertEqual('optional_float: 0.0\n', out.getvalue()) + out.close() + # Test Printer + out = text_format.TextWriter(False) + printer = text_format._Printer(out) + printer.PrintField(field, value) + self.assertEqual('optional_float: 0.0\n', out.getvalue()) + out.close() + + def testPrintFieldValue(self, message_module): + message = message_module.TestAllTypes() + field = message.DESCRIPTOR.fields_by_name['optional_float'] + value = message.optional_float + out = text_format.TextWriter(False) + text_format.PrintFieldValue(field, value, out) + self.assertEqual('0.0', out.getvalue()) + out.close() + # Test Printer + out = text_format.TextWriter(False) + printer = text_format._Printer(out) + printer.PrintFieldValue(field, value) + self.assertEqual('0.0', out.getvalue()) + out.close() + + def testParseAllFields(self, message_module): + message = message_module.TestAllTypes() test_util.SetAllFields(message) ascii_text = text_format.MessageToString(message) - parsed_message = unittest_pb2.TestAllTypes() + parsed_message = message_module.TestAllTypes() text_format.Parse(ascii_text, parsed_message) self.assertEqual(message, parsed_message) - test_util.ExpectAllFieldsSet(self, message) + if message_module is unittest_pb2: + test_util.ExpectAllFieldsSet(self, message) - def testParseAllExtensions(self): - message = unittest_pb2.TestAllExtensions() - test_util.SetAllExtensions(message) - ascii_text = text_format.MessageToString(message) - - parsed_message = unittest_pb2.TestAllExtensions() - text_format.Parse(ascii_text, parsed_message) - self.assertEqual(message, parsed_message) - - def testParseMessageSet(self): - message = unittest_pb2.TestAllTypes() - text = ('repeated_uint64: 1\n' - 'repeated_uint64: 2\n') - text_format.Parse(text, message) - self.assertEqual(1, message.repeated_uint64[0]) - self.assertEqual(2, message.repeated_uint64[1]) - - message = unittest_mset_pb2.TestMessageSetContainer() - text = ('message_set {\n' - ' 
[protobuf_unittest.TestMessageSetExtension1] {\n' - ' i: 23\n' - ' }\n' - ' [protobuf_unittest.TestMessageSetExtension2] {\n' - ' str: \"foo\"\n' - ' }\n' - '}\n') - text_format.Parse(text, message) - ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension - ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension - self.assertEquals(23, message.message_set.Extensions[ext1].i) - self.assertEquals('foo', message.message_set.Extensions[ext2].str) - - def testParseExotic(self): - message = unittest_pb2.TestAllTypes() + def testParseExotic(self, message_module): + message = message_module.TestAllTypes() text = ('repeated_int64: -9223372036854775808\n' 'repeated_uint64: 18446744073709551615\n' 'repeated_double: 123.456\n' @@ -383,8 +317,8 @@ class TextFormatTest(basetest.TestCase): self.assertEqual(u'\u00fc\ua71f', message.repeated_string[2]) self.assertEqual(u'\u00fc', message.repeated_string[3]) - def testParseTrailingCommas(self): - message = unittest_pb2.TestAllTypes() + def testParseTrailingCommas(self, message_module): + message = message_module.TestAllTypes() text = ('repeated_int64: 100;\n' 'repeated_int64: 200;\n' 'repeated_int64: 300,\n' @@ -398,101 +332,87 @@ class TextFormatTest(basetest.TestCase): self.assertEqual(u'one', message.repeated_string[0]) self.assertEqual(u'two', message.repeated_string[1]) - def testParseEmptyText(self): - message = unittest_pb2.TestAllTypes() + def testParseRepeatedScalarShortFormat(self, message_module): + message = message_module.TestAllTypes() + text = ('repeated_int64: [100, 200];\n' + 'repeated_int64: 300,\n' + 'repeated_string: ["one", "two"];\n') + text_format.Parse(text, message) + + self.assertEqual(100, message.repeated_int64[0]) + self.assertEqual(200, message.repeated_int64[1]) + self.assertEqual(300, message.repeated_int64[2]) + self.assertEqual(u'one', message.repeated_string[0]) + self.assertEqual(u'two', message.repeated_string[1]) + + def testParseRepeatedMessageShortFormat(self, 
message_module): + message = message_module.TestAllTypes() + text = ('repeated_nested_message: [{bb: 100}, {bb: 200}],\n' + 'repeated_nested_message: {bb: 300}\n' + 'repeated_nested_message [{bb: 400}];\n') + text_format.Parse(text, message) + + self.assertEqual(100, message.repeated_nested_message[0].bb) + self.assertEqual(200, message.repeated_nested_message[1].bb) + self.assertEqual(300, message.repeated_nested_message[2].bb) + self.assertEqual(400, message.repeated_nested_message[3].bb) + + def testParseEmptyText(self, message_module): + message = message_module.TestAllTypes() text = '' text_format.Parse(text, message) - self.assertEquals(unittest_pb2.TestAllTypes(), message) + self.assertEqual(message_module.TestAllTypes(), message) - def testParseInvalidUtf8(self): - message = unittest_pb2.TestAllTypes() + def testParseInvalidUtf8(self, message_module): + message = message_module.TestAllTypes() text = 'repeated_string: "\\xc3\\xc3"' self.assertRaises(text_format.ParseError, text_format.Parse, text, message) - def testParseSingleWord(self): - message = unittest_pb2.TestAllTypes() + def testParseSingleWord(self, message_module): + message = message_module.TestAllTypes() text = 'foo' - self.assertRaisesWithLiteralMatch( + six.assertRaisesRegex(self, text_format.ParseError, - ('1:1 : Message type "protobuf_unittest.TestAllTypes" has no field named ' - '"foo".'), + (r'1:1 : Message type "\w+.TestAllTypes" has no field named ' + r'"foo".'), text_format.Parse, text, message) - def testParseUnknownField(self): - message = unittest_pb2.TestAllTypes() + def testParseUnknownField(self, message_module): + message = message_module.TestAllTypes() text = 'unknown_field: 8\n' - self.assertRaisesWithLiteralMatch( + six.assertRaisesRegex(self, text_format.ParseError, - ('1:1 : Message type "protobuf_unittest.TestAllTypes" has no field named ' - '"unknown_field".'), + (r'1:1 : Message type "\w+.TestAllTypes" has no field named ' + r'"unknown_field".'), text_format.Parse, text, 
message) - def testParseBadExtension(self): - message = unittest_pb2.TestAllExtensions() - text = '[unknown_extension]: 8\n' - self.assertRaisesWithLiteralMatch( - text_format.ParseError, - '1:2 : Extension "unknown_extension" not registered.', - text_format.Parse, text, message) - message = unittest_pb2.TestAllTypes() - self.assertRaisesWithLiteralMatch( - text_format.ParseError, - ('1:2 : Message type "protobuf_unittest.TestAllTypes" does not have ' - 'extensions.'), - text_format.Parse, text, message) - - def testParseGroupNotClosed(self): - message = unittest_pb2.TestAllTypes() - text = 'RepeatedGroup: <' - self.assertRaisesWithLiteralMatch( - text_format.ParseError, '1:16 : Expected ">".', - text_format.Parse, text, message) - - text = 'RepeatedGroup: {' - self.assertRaisesWithLiteralMatch( - text_format.ParseError, '1:16 : Expected "}".', - text_format.Parse, text, message) - - def testParseEmptyGroup(self): - message = unittest_pb2.TestAllTypes() - text = 'OptionalGroup: {}' - text_format.Parse(text, message) - self.assertTrue(message.HasField('optionalgroup')) - - message.Clear() - - message = unittest_pb2.TestAllTypes() - text = 'OptionalGroup: <>' - text_format.Parse(text, message) - self.assertTrue(message.HasField('optionalgroup')) - - def testParseBadEnumValue(self): - message = unittest_pb2.TestAllTypes() + def testParseBadEnumValue(self, message_module): + message = message_module.TestAllTypes() text = 'optional_nested_enum: BARR' - self.assertRaisesWithLiteralMatch( + six.assertRaisesRegex(self, text_format.ParseError, - ('1:23 : Enum type "protobuf_unittest.TestAllTypes.NestedEnum" ' - 'has no value named BARR.'), + (r'1:23 : Enum type "\w+.TestAllTypes.NestedEnum" ' + r'has no value named BARR.'), text_format.Parse, text, message) - message = unittest_pb2.TestAllTypes() + message = message_module.TestAllTypes() text = 'optional_nested_enum: 100' - self.assertRaisesWithLiteralMatch( + six.assertRaisesRegex(self, text_format.ParseError, - ('1:23 : 
Enum type "protobuf_unittest.TestAllTypes.NestedEnum" ' - 'has no value with number 100.'), + (r'1:23 : Enum type "\w+.TestAllTypes.NestedEnum" ' + r'has no value with number 100.'), text_format.Parse, text, message) - def testParseBadIntValue(self): - message = unittest_pb2.TestAllTypes() + def testParseBadIntValue(self, message_module): + message = message_module.TestAllTypes() text = 'optional_int32: bork' - self.assertRaisesWithLiteralMatch( + six.assertRaisesRegex(self, text_format.ParseError, ('1:17 : Couldn\'t parse integer: bork'), text_format.Parse, text, message) - def testParseStringFieldUnescape(self): - message = unittest_pb2.TestAllTypes() + def testParseStringFieldUnescape(self, message_module): + message = message_module.TestAllTypes() text = r'''repeated_string: "\xf\x62" repeated_string: "\\xf\\x62" repeated_string: "\\\xf\\\x62" @@ -511,43 +431,483 @@ class TextFormatTest(basetest.TestCase): message.repeated_string[4]) self.assertEqual(SLASH + 'x20', message.repeated_string[5]) - def testMergeRepeatedScalars(self): - message = unittest_pb2.TestAllTypes() + def testMergeDuplicateScalars(self, message_module): + message = message_module.TestAllTypes() text = ('optional_int32: 42 ' 'optional_int32: 67') r = text_format.Merge(text, message) self.assertIs(r, message) self.assertEqual(67, message.optional_int32) - def testParseRepeatedScalars(self): - message = unittest_pb2.TestAllTypes() - text = ('optional_int32: 42 ' - 'optional_int32: 67') - self.assertRaisesWithLiteralMatch( - text_format.ParseError, - ('1:36 : Message type "protobuf_unittest.TestAllTypes" should not ' - 'have multiple "optional_int32" fields.'), - text_format.Parse, text, message) - - def testMergeRepeatedNestedMessageScalars(self): - message = unittest_pb2.TestAllTypes() + def testMergeDuplicateNestedMessageScalars(self, message_module): + message = message_module.TestAllTypes() text = ('optional_nested_message { bb: 1 } ' 'optional_nested_message { bb: 2 }') r = 
text_format.Merge(text, message) self.assertTrue(r is message) self.assertEqual(2, message.optional_nested_message.bb) - def testParseRepeatedNestedMessageScalars(self): + def testParseOneof(self, message_module): + m = message_module.TestAllTypes() + m.oneof_uint32 = 11 + m2 = message_module.TestAllTypes() + text_format.Parse(text_format.MessageToString(m), m2) + self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field')) + + def testParseMultipleOneof(self, message_module): + m_string = '\n'.join([ + 'oneof_uint32: 11', + 'oneof_string: "foo"']) + m2 = message_module.TestAllTypes() + if message_module is unittest_pb2: + with self.assertRaisesRegexp( + text_format.ParseError, ' is specified along with field '): + text_format.Parse(m_string, m2) + else: + text_format.Parse(m_string, m2) + self.assertEqual('oneof_string', m2.WhichOneof('oneof_field')) + + +# These are tests that aren't fundamentally specific to proto2, but are at +# the moment because of differences between the proto2 and proto3 test schemas. +# Ideally the schemas would be made more similar so these tests could pass. 
+class OnlyWorksWithProto2RightNowTests(TextFormatBase): + + def testPrintAllFieldsPointy(self): message = unittest_pb2.TestAllTypes() - text = ('optional_nested_message { bb: 1 } ' - 'optional_nested_message { bb: 2 }') - self.assertRaisesWithLiteralMatch( + test_util.SetAllFields(message) + self.CompareToGoldenFile( + self.RemoveRedundantZeros( + text_format.MessageToString(message, pointy_brackets=True)), + 'text_format_unittest_data_pointy_oneof.txt') + + def testParseGolden(self): + golden_text = '\n'.join(self.ReadGolden( + 'text_format_unittest_data_oneof_implemented.txt')) + parsed_message = unittest_pb2.TestAllTypes() + r = text_format.Parse(golden_text, parsed_message) + self.assertIs(r, parsed_message) + + message = unittest_pb2.TestAllTypes() + test_util.SetAllFields(message) + self.assertEqual(message, parsed_message) + + def testPrintAllFields(self): + message = unittest_pb2.TestAllTypes() + test_util.SetAllFields(message) + self.CompareToGoldenFile( + self.RemoveRedundantZeros(text_format.MessageToString(message)), + 'text_format_unittest_data_oneof_implemented.txt') + + def testPrintAllFieldsPointy(self): + message = unittest_pb2.TestAllTypes() + test_util.SetAllFields(message) + self.CompareToGoldenFile( + self.RemoveRedundantZeros( + text_format.MessageToString(message, pointy_brackets=True)), + 'text_format_unittest_data_pointy_oneof.txt') + + def testPrintInIndexOrder(self): + message = unittest_pb2.TestFieldOrderings() + message.my_string = '115' + message.my_int = 101 + message.my_float = 111 + message.optional_nested_message.oo = 0 + message.optional_nested_message.bb = 1 + self.CompareToGoldenText( + self.RemoveRedundantZeros(text_format.MessageToString( + message, use_index_order=True)), + 'my_string: \"115\"\nmy_int: 101\nmy_float: 111\n' + 'optional_nested_message {\n oo: 0\n bb: 1\n}\n') + self.CompareToGoldenText( + self.RemoveRedundantZeros(text_format.MessageToString( + message)), + 'my_int: 101\nmy_string: \"115\"\nmy_float: 111\n' + 
'optional_nested_message {\n bb: 1\n oo: 0\n}\n') + + def testMergeLinesGolden(self): + opened = self.ReadGolden('text_format_unittest_data_oneof_implemented.txt') + parsed_message = unittest_pb2.TestAllTypes() + r = text_format.MergeLines(opened, parsed_message) + self.assertIs(r, parsed_message) + + message = unittest_pb2.TestAllTypes() + test_util.SetAllFields(message) + self.assertEqual(message, parsed_message) + + def testParseLinesGolden(self): + opened = self.ReadGolden('text_format_unittest_data_oneof_implemented.txt') + parsed_message = unittest_pb2.TestAllTypes() + r = text_format.ParseLines(opened, parsed_message) + self.assertIs(r, parsed_message) + + message = unittest_pb2.TestAllTypes() + test_util.SetAllFields(message) + self.assertEqual(message, parsed_message) + + def testPrintMap(self): + message = map_unittest_pb2.TestMap() + + message.map_int32_int32[-123] = -456 + message.map_int64_int64[-2**33] = -2**34 + message.map_uint32_uint32[123] = 456 + message.map_uint64_uint64[2**33] = 2**34 + message.map_string_string["abc"] = "123" + message.map_int32_foreign_message[111].c = 5 + + # Maps are serialized to text format using their underlying repeated + # representation. 
+ self.CompareToGoldenText( + text_format.MessageToString(message), + 'map_int32_int32 {\n' + ' key: -123\n' + ' value: -456\n' + '}\n' + 'map_int64_int64 {\n' + ' key: -8589934592\n' + ' value: -17179869184\n' + '}\n' + 'map_uint32_uint32 {\n' + ' key: 123\n' + ' value: 456\n' + '}\n' + 'map_uint64_uint64 {\n' + ' key: 8589934592\n' + ' value: 17179869184\n' + '}\n' + 'map_string_string {\n' + ' key: "abc"\n' + ' value: "123"\n' + '}\n' + 'map_int32_foreign_message {\n' + ' key: 111\n' + ' value {\n' + ' c: 5\n' + ' }\n' + '}\n') + + def testMapOrderEnforcement(self): + message = map_unittest_pb2.TestMap() + for letter in string.ascii_uppercase[13:26]: + message.map_string_string[letter] = 'dummy' + for letter in reversed(string.ascii_uppercase[0:13]): + message.map_string_string[letter] = 'dummy' + golden = ''.join(( + 'map_string_string {\n key: "%c"\n value: "dummy"\n}\n' % (letter,) + for letter in string.ascii_uppercase)) + self.CompareToGoldenText(text_format.MessageToString(message), golden) + + def testMapOrderSemantics(self): + golden_lines = self.ReadGolden('map_test_data.txt') + # The C++ implementation emits defaulted-value fields, while the Python + # implementation does not. Adjusting for this is awkward, but it is + # valuable to test against a common golden file. + line_blacklist = (' key: 0\n', + ' value: 0\n', + ' key: false\n', + ' value: false\n') + golden_lines = [line for line in golden_lines if line not in line_blacklist] + + message = map_unittest_pb2.TestMap() + text_format.ParseLines(golden_lines, message) + candidate = text_format.MessageToString(message) + # The Python implementation emits "1.0" for the double value that the C++ + # implementation emits as "1". + candidate = candidate.replace('1.0', '1', 2) + self.assertMultiLineEqual(candidate, ''.join(golden_lines)) + + +# Tests of proto2-only features (MessageSet, extensions, etc.). 
+class Proto2Tests(TextFormatBase): + + def testPrintMessageSet(self): + message = unittest_mset_pb2.TestMessageSetContainer() + ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension + ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension + message.message_set.Extensions[ext1].i = 23 + message.message_set.Extensions[ext2].str = 'foo' + self.CompareToGoldenText( + text_format.MessageToString(message), + 'message_set {\n' + ' [protobuf_unittest.TestMessageSetExtension1] {\n' + ' i: 23\n' + ' }\n' + ' [protobuf_unittest.TestMessageSetExtension2] {\n' + ' str: \"foo\"\n' + ' }\n' + '}\n') + + message = message_set_extensions_pb2.TestMessageSet() + ext = message_set_extensions_pb2.message_set_extension3 + message.Extensions[ext].text = 'bar' + self.CompareToGoldenText( + text_format.MessageToString(message), + '[google.protobuf.internal.TestMessageSetExtension3] {\n' + ' text: \"bar\"\n' + '}\n') + + def testPrintMessageSetByFieldNumber(self): + out = text_format.TextWriter(False) + message = unittest_mset_pb2.TestMessageSetContainer() + ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension + ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension + message.message_set.Extensions[ext1].i = 23 + message.message_set.Extensions[ext2].str = 'foo' + text_format.PrintMessage(message, out, use_field_number=True) + self.CompareToGoldenText( + out.getvalue(), + '1 {\n' + ' 1545008 {\n' + ' 15: 23\n' + ' }\n' + ' 1547769 {\n' + ' 25: \"foo\"\n' + ' }\n' + '}\n') + out.close() + + def testPrintMessageSetAsOneLine(self): + message = unittest_mset_pb2.TestMessageSetContainer() + ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension + ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension + message.message_set.Extensions[ext1].i = 23 + message.message_set.Extensions[ext2].str = 'foo' + self.CompareToGoldenText( + text_format.MessageToString(message, as_one_line=True), + 
'message_set {' + ' [protobuf_unittest.TestMessageSetExtension1] {' + ' i: 23' + ' }' + ' [protobuf_unittest.TestMessageSetExtension2] {' + ' str: \"foo\"' + ' }' + ' }') + + def testParseMessageSet(self): + message = unittest_pb2.TestAllTypes() + text = ('repeated_uint64: 1\n' + 'repeated_uint64: 2\n') + text_format.Parse(text, message) + self.assertEqual(1, message.repeated_uint64[0]) + self.assertEqual(2, message.repeated_uint64[1]) + + message = unittest_mset_pb2.TestMessageSetContainer() + text = ('message_set {\n' + ' [protobuf_unittest.TestMessageSetExtension1] {\n' + ' i: 23\n' + ' }\n' + ' [protobuf_unittest.TestMessageSetExtension2] {\n' + ' str: \"foo\"\n' + ' }\n' + '}\n') + text_format.Parse(text, message) + ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension + ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension + self.assertEqual(23, message.message_set.Extensions[ext1].i) + self.assertEqual('foo', message.message_set.Extensions[ext2].str) + + def testParseMessageByFieldNumber(self): + message = unittest_pb2.TestAllTypes() + text = ('34: 1\n' + 'repeated_uint64: 2\n') + text_format.Parse(text, message, allow_field_number=True) + self.assertEqual(1, message.repeated_uint64[0]) + self.assertEqual(2, message.repeated_uint64[1]) + + message = unittest_mset_pb2.TestMessageSetContainer() + text = ('1 {\n' + ' 1545008 {\n' + ' 15: 23\n' + ' }\n' + ' 1547769 {\n' + ' 25: \"foo\"\n' + ' }\n' + '}\n') + text_format.Parse(text, message, allow_field_number=True) + ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension + ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension + self.assertEqual(23, message.message_set.Extensions[ext1].i) + self.assertEqual('foo', message.message_set.Extensions[ext2].str) + + # Can't parse field number without set allow_field_number=True. 
+ message = unittest_pb2.TestAllTypes() + text = '34:1\n' + six.assertRaisesRegex( + self, text_format.ParseError, - ('1:65 : Message type "protobuf_unittest.TestAllTypes.NestedMessage" ' - 'should not have multiple "bb" fields.'), + (r'1:1 : Message type "\w+.TestAllTypes" has no field named ' + r'"34".'), text_format.Parse, text, message) - def testMergeRepeatedExtensionScalars(self): + # Can't parse if field number is not found. + text = '1234:1\n' + six.assertRaisesRegex( + self, + text_format.ParseError, + (r'1:1 : Message type "\w+.TestAllTypes" has no field named ' + r'"1234".'), + text_format.Parse, text, message, allow_field_number=True) + + def testPrintAllExtensions(self): + message = unittest_pb2.TestAllExtensions() + test_util.SetAllExtensions(message) + self.CompareToGoldenFile( + self.RemoveRedundantZeros(text_format.MessageToString(message)), + 'text_format_unittest_extensions_data.txt') + + def testPrintAllExtensionsPointy(self): + message = unittest_pb2.TestAllExtensions() + test_util.SetAllExtensions(message) + self.CompareToGoldenFile( + self.RemoveRedundantZeros(text_format.MessageToString( + message, pointy_brackets=True)), + 'text_format_unittest_extensions_data_pointy.txt') + + def testParseGoldenExtensions(self): + golden_text = '\n'.join(self.ReadGolden( + 'text_format_unittest_extensions_data.txt')) + parsed_message = unittest_pb2.TestAllExtensions() + text_format.Parse(golden_text, parsed_message) + + message = unittest_pb2.TestAllExtensions() + test_util.SetAllExtensions(message) + self.assertEqual(message, parsed_message) + + def testParseAllExtensions(self): + message = unittest_pb2.TestAllExtensions() + test_util.SetAllExtensions(message) + ascii_text = text_format.MessageToString(message) + + parsed_message = unittest_pb2.TestAllExtensions() + text_format.Parse(ascii_text, parsed_message) + self.assertEqual(message, parsed_message) + + def testParseAllowedUnknownExtension(self): + # Skip over unknown extension correctly. 
+ message = unittest_mset_pb2.TestMessageSetContainer() + text = ('message_set {\n' + ' [unknown_extension] {\n' + ' i: 23\n' + ' bin: "\xe0"' + ' [nested_unknown_ext]: {\n' + ' i: 23\n' + ' test: "test_string"\n' + ' floaty_float: -0.315\n' + ' num: -inf\n' + ' multiline_str: "abc"\n' + ' "def"\n' + ' "xyz."\n' + ' [nested_unknown_ext]: <\n' + ' i: 23\n' + ' i: 24\n' + ' pointfloat: .3\n' + ' test: "test_string"\n' + ' floaty_float: -0.315\n' + ' num: -inf\n' + ' long_string: "test" "test2" \n' + ' >\n' + ' }\n' + ' }\n' + ' [unknown_extension]: 5\n' + '}\n') + text_format.Parse(text, message, allow_unknown_extension=True) + golden = 'message_set {\n}\n' + self.CompareToGoldenText(text_format.MessageToString(message), golden) + + # Catch parse errors in unknown extension. + message = unittest_mset_pb2.TestMessageSetContainer() + malformed = ('message_set {\n' + ' [unknown_extension] {\n' + ' i:\n' # Missing value. + ' }\n' + '}\n') + six.assertRaisesRegex(self, + text_format.ParseError, + 'Invalid field value: }', + text_format.Parse, malformed, message, + allow_unknown_extension=True) + + message = unittest_mset_pb2.TestMessageSetContainer() + malformed = ('message_set {\n' + ' [unknown_extension] {\n' + ' str: "malformed string\n' # Missing closing quote. + ' }\n' + '}\n') + six.assertRaisesRegex(self, + text_format.ParseError, + 'Invalid field value: "', + text_format.Parse, malformed, message, + allow_unknown_extension=True) + + message = unittest_mset_pb2.TestMessageSetContainer() + malformed = ('message_set {\n' + ' [unknown_extension] {\n' + ' str: "malformed\n multiline\n string\n' + ' }\n' + '}\n') + six.assertRaisesRegex(self, + text_format.ParseError, + 'Invalid field value: "', + text_format.Parse, malformed, message, + allow_unknown_extension=True) + + message = unittest_mset_pb2.TestMessageSetContainer() + malformed = ('message_set {\n' + ' [malformed_extension] <\n' + ' i: -5\n' + ' \n' # Missing '>' here. 
+ '}\n') + six.assertRaisesRegex(self, + text_format.ParseError, + '5:1 : Expected ">".', + text_format.Parse, malformed, message, + allow_unknown_extension=True) + + # Don't allow unknown fields with allow_unknown_extension=True. + message = unittest_mset_pb2.TestMessageSetContainer() + malformed = ('message_set {\n' + ' unknown_field: true\n' + ' \n' # Missing '>' here. + '}\n') + six.assertRaisesRegex(self, + text_format.ParseError, + ('2:3 : Message type ' + '"proto2_wireformat_unittest.TestMessageSet" has no' + ' field named "unknown_field".'), + text_format.Parse, malformed, message, + allow_unknown_extension=True) + + # Parse known extension correcty. + message = unittest_mset_pb2.TestMessageSetContainer() + text = ('message_set {\n' + ' [protobuf_unittest.TestMessageSetExtension1] {\n' + ' i: 23\n' + ' }\n' + ' [protobuf_unittest.TestMessageSetExtension2] {\n' + ' str: \"foo\"\n' + ' }\n' + '}\n') + text_format.Parse(text, message, allow_unknown_extension=True) + ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension + ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension + self.assertEqual(23, message.message_set.Extensions[ext1].i) + self.assertEqual('foo', message.message_set.Extensions[ext2].str) + + def testParseBadExtension(self): + message = unittest_pb2.TestAllExtensions() + text = '[unknown_extension]: 8\n' + six.assertRaisesRegex(self, + text_format.ParseError, + '1:2 : Extension "unknown_extension" not registered.', + text_format.Parse, text, message) + message = unittest_pb2.TestAllTypes() + six.assertRaisesRegex(self, + text_format.ParseError, + ('1:2 : Message type "protobuf_unittest.TestAllTypes" does not have ' + 'extensions.'), + text_format.Parse, text, message) + + def testMergeDuplicateExtensionScalars(self): message = unittest_pb2.TestAllExtensions() text = ('[protobuf_unittest.optional_int32_extension]: 42 ' '[protobuf_unittest.optional_int32_extension]: 67') @@ -556,46 +916,102 @@ class 
TextFormatTest(basetest.TestCase): 67, message.Extensions[unittest_pb2.optional_int32_extension]) - def testParseRepeatedExtensionScalars(self): + def testParseDuplicateExtensionScalars(self): message = unittest_pb2.TestAllExtensions() text = ('[protobuf_unittest.optional_int32_extension]: 42 ' '[protobuf_unittest.optional_int32_extension]: 67') - self.assertRaisesWithLiteralMatch( + six.assertRaisesRegex(self, text_format.ParseError, ('1:96 : Message type "protobuf_unittest.TestAllExtensions" ' 'should not have multiple ' '"protobuf_unittest.optional_int32_extension" extensions.'), text_format.Parse, text, message) - def testParseLinesGolden(self): - opened = self.ReadGolden('text_format_unittest_data.txt') - parsed_message = unittest_pb2.TestAllTypes() - r = text_format.ParseLines(opened, parsed_message) - self.assertIs(r, parsed_message) + def testParseDuplicateNestedMessageScalars(self): + message = unittest_pb2.TestAllTypes() + text = ('optional_nested_message { bb: 1 } ' + 'optional_nested_message { bb: 2 }') + six.assertRaisesRegex(self, + text_format.ParseError, + ('1:65 : Message type "protobuf_unittest.TestAllTypes.NestedMessage" ' + 'should not have multiple "bb" fields.'), + text_format.Parse, text, message) + def testParseDuplicateScalars(self): message = unittest_pb2.TestAllTypes() - test_util.SetAllFields(message) - self.assertEquals(message, parsed_message) + text = ('optional_int32: 42 ' + 'optional_int32: 67') + six.assertRaisesRegex(self, + text_format.ParseError, + ('1:36 : Message type "protobuf_unittest.TestAllTypes" should not ' + 'have multiple "optional_int32" fields.'), + text_format.Parse, text, message) - def testMergeLinesGolden(self): - opened = self.ReadGolden('text_format_unittest_data.txt') - parsed_message = unittest_pb2.TestAllTypes() - r = text_format.MergeLines(opened, parsed_message) - self.assertIs(r, parsed_message) + def testParseGroupNotClosed(self): + message = unittest_pb2.TestAllTypes() + text = 'RepeatedGroup: <' + 
six.assertRaisesRegex(self, + text_format.ParseError, '1:16 : Expected ">".', + text_format.Parse, text, message) + text = 'RepeatedGroup: {' + six.assertRaisesRegex(self, + text_format.ParseError, '1:16 : Expected "}".', + text_format.Parse, text, message) + def testParseEmptyGroup(self): message = unittest_pb2.TestAllTypes() - test_util.SetAllFields(message) - self.assertEqual(message, parsed_message) + text = 'OptionalGroup: {}' + text_format.Parse(text, message) + self.assertTrue(message.HasField('optionalgroup')) - def testParseOneof(self): - m = unittest_pb2.TestAllTypes() - m.oneof_uint32 = 11 - m2 = unittest_pb2.TestAllTypes() - text_format.Parse(text_format.MessageToString(m), m2) - self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field')) + message.Clear() + + message = unittest_pb2.TestAllTypes() + text = 'OptionalGroup: <>' + text_format.Parse(text, message) + self.assertTrue(message.HasField('optionalgroup')) + + # Maps aren't really proto2-only, but our test schema only has maps for + # proto2. 
+ def testParseMap(self): + text = ('map_int32_int32 {\n' + ' key: -123\n' + ' value: -456\n' + '}\n' + 'map_int64_int64 {\n' + ' key: -8589934592\n' + ' value: -17179869184\n' + '}\n' + 'map_uint32_uint32 {\n' + ' key: 123\n' + ' value: 456\n' + '}\n' + 'map_uint64_uint64 {\n' + ' key: 8589934592\n' + ' value: 17179869184\n' + '}\n' + 'map_string_string {\n' + ' key: "abc"\n' + ' value: "123"\n' + '}\n' + 'map_int32_foreign_message {\n' + ' key: 111\n' + ' value {\n' + ' c: 5\n' + ' }\n' + '}\n') + message = map_unittest_pb2.TestMap() + text_format.Parse(text, message) + + self.assertEqual(-456, message.map_int32_int32[-123]) + self.assertEqual(-2**34, message.map_int64_int64[-2**33]) + self.assertEqual(456, message.map_uint32_uint32[123]) + self.assertEqual(2**34, message.map_uint64_uint64[2**33]) + self.assertEqual("123", message.map_string_string["abc"]) + self.assertEqual(5, message.map_int32_foreign_message[111].c) -class TokenizerTest(basetest.TestCase): +class TokenizerTest(unittest.TestCase): def testSimpleTokenCases(self): text = ('identifier1:"string1"\n \n\n' @@ -740,4 +1156,4 @@ class TokenizerTest(basetest.TestCase): if __name__ == '__main__': - basetest.main() + unittest.main() diff --git a/python/google/protobuf/internal/type_checkers.py b/python/google/protobuf/internal/type_checkers.py index 56d264604..1be3ad9a3 100755 --- a/python/google/protobuf/internal/type_checkers.py +++ b/python/google/protobuf/internal/type_checkers.py @@ -28,10 +28,6 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -#PY25 compatible for GAE. -# -# Copyright 2008 Google Inc. All Rights Reserved. - """Provides type checking routines. 
This module defines type checking utilities in the forms of dictionaries: @@ -49,8 +45,11 @@ TYPE_TO_DESERIALIZE_METHOD: A dictionary with field types and deserialization __author__ = 'robinson@google.com (Will Robinson)' -import sys ##PY25 -if sys.version < '2.6': bytes = str ##PY25 +import six + +if six.PY3: + long = int + from google.protobuf.internal import api_implementation from google.protobuf.internal import decoder from google.protobuf.internal import encoder @@ -59,6 +58,8 @@ from google.protobuf import descriptor _FieldDescriptor = descriptor.FieldDescriptor +def SupportsOpenEnums(field_descriptor): + return field_descriptor.containing_type.syntax == "proto3" def GetTypeChecker(field): """Returns a type checker for a message field of the specified types. @@ -74,7 +75,11 @@ def GetTypeChecker(field): field.type == _FieldDescriptor.TYPE_STRING): return UnicodeValueChecker() if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM: - return EnumValueChecker(field.enum_type) + if SupportsOpenEnums(field): + # When open enums are supported, any int32 can be assigned. + return _VALUE_CHECKERS[_FieldDescriptor.CPPTYPE_INT32] + else: + return EnumValueChecker(field.enum_type) return _VALUE_CHECKERS[field.cpp_type] @@ -104,6 +109,16 @@ class TypeChecker(object): return proposed_value +class TypeCheckerWithDefault(TypeChecker): + + def __init__(self, default_value, *acceptable_types): + TypeChecker.__init__(self, acceptable_types) + self._default_value = default_value + + def DefaultValue(self): + return self._default_value + + # IntValueChecker and its subclasses perform integer type-checks # and bounds-checks. class IntValueChecker(object): @@ -111,9 +126,9 @@ class IntValueChecker(object): """Checker used for integer fields. 
Performs type-check and range check.""" def CheckValue(self, proposed_value): - if not isinstance(proposed_value, (int, long)): + if not isinstance(proposed_value, six.integer_types): message = ('%.1024r has type %s, but expected one of: %s' % - (proposed_value, type(proposed_value), (int, long))) + (proposed_value, type(proposed_value), six.integer_types)) raise TypeError(message) if not self._MIN <= proposed_value <= self._MAX: raise ValueError('Value out of range: %d' % proposed_value) @@ -123,6 +138,9 @@ class IntValueChecker(object): proposed_value = self._TYPE(proposed_value) return proposed_value + def DefaultValue(self): + return 0 + class EnumValueChecker(object): @@ -132,14 +150,17 @@ class EnumValueChecker(object): self._enum_type = enum_type def CheckValue(self, proposed_value): - if not isinstance(proposed_value, (int, long)): + if not isinstance(proposed_value, six.integer_types): message = ('%.1024r has type %s, but expected one of: %s' % - (proposed_value, type(proposed_value), (int, long))) + (proposed_value, type(proposed_value), six.integer_types)) raise TypeError(message) if proposed_value not in self._enum_type.values_by_number: raise ValueError('Unknown enum value: %d' % proposed_value) return proposed_value + def DefaultValue(self): + return self._enum_type.values[0].number + class UnicodeValueChecker(object): @@ -149,23 +170,25 @@ class UnicodeValueChecker(object): """ def CheckValue(self, proposed_value): - if not isinstance(proposed_value, (bytes, unicode)): + if not isinstance(proposed_value, (bytes, six.text_type)): message = ('%.1024r has type %s, but expected one of: %s' % - (proposed_value, type(proposed_value), (bytes, unicode))) + (proposed_value, type(proposed_value), (bytes, six.text_type))) raise TypeError(message) - # If the value is of type 'bytes' make sure that it is in 7-bit ASCII - # encoding. + # If the value is of type 'bytes' make sure that it is valid UTF-8 data. 
if isinstance(proposed_value, bytes): try: - proposed_value = proposed_value.decode('ascii') + proposed_value = proposed_value.decode('utf-8') except UnicodeDecodeError: - raise ValueError('%.1024r has type bytes, but isn\'t in 7-bit ASCII ' - 'encoding. Non-ASCII strings must be converted to ' + raise ValueError('%.1024r has type bytes, but isn\'t valid UTF-8 ' + 'encoding. Non-UTF-8 strings must be converted to ' 'unicode objects before being added.' % (proposed_value)) return proposed_value + def DefaultValue(self): + return u"" + class Int32ValueChecker(IntValueChecker): # We're sure to use ints instead of longs here since comparison may be more @@ -199,12 +222,13 @@ _VALUE_CHECKERS = { _FieldDescriptor.CPPTYPE_INT64: Int64ValueChecker(), _FieldDescriptor.CPPTYPE_UINT32: Uint32ValueChecker(), _FieldDescriptor.CPPTYPE_UINT64: Uint64ValueChecker(), - _FieldDescriptor.CPPTYPE_DOUBLE: TypeChecker( - float, int, long), - _FieldDescriptor.CPPTYPE_FLOAT: TypeChecker( - float, int, long), - _FieldDescriptor.CPPTYPE_BOOL: TypeChecker(bool, int), - _FieldDescriptor.CPPTYPE_STRING: TypeChecker(bytes), + _FieldDescriptor.CPPTYPE_DOUBLE: TypeCheckerWithDefault( + 0.0, float, int, long), + _FieldDescriptor.CPPTYPE_FLOAT: TypeCheckerWithDefault( + 0.0, float, int, long), + _FieldDescriptor.CPPTYPE_BOOL: TypeCheckerWithDefault( + False, bool, int), + _FieldDescriptor.CPPTYPE_STRING: TypeCheckerWithDefault(b'', bytes), } diff --git a/python/google/protobuf/internal/unknown_fields_test.py b/python/google/protobuf/internal/unknown_fields_test.py index 71775609d..84073f1c5 100755 --- a/python/google/protobuf/internal/unknown_fields_test.py +++ b/python/google/protobuf/internal/unknown_fields_test.py @@ -1,4 +1,4 @@ -#! /usr/bin/python +#! 
/usr/bin/env python # -*- coding: utf-8 -*- # # Protocol Buffers - Google's data interchange format @@ -35,16 +35,28 @@ __author__ = 'bohdank@google.com (Bohdan Koval)' -from google.apputils import basetest +try: + import unittest2 as unittest #PY26 +except ImportError: + import unittest from google.protobuf import unittest_mset_pb2 from google.protobuf import unittest_pb2 +from google.protobuf import unittest_proto3_arena_pb2 +from google.protobuf.internal import api_implementation from google.protobuf.internal import encoder +from google.protobuf.internal import message_set_extensions_pb2 from google.protobuf.internal import missing_enum_values_pb2 from google.protobuf.internal import test_util from google.protobuf.internal import type_checkers -class UnknownFieldsTest(basetest.TestCase): +def SkipIfCppImplementation(func): + return unittest.skipIf( + api_implementation.Type() == 'cpp' and api_implementation.Version() == 2, + 'C++ implementation does not expose unknown fields to Python')(func) + + +class UnknownFieldsTest(unittest.TestCase): def setUp(self): self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR @@ -53,7 +65,98 @@ class UnknownFieldsTest(basetest.TestCase): self.all_fields_data = self.all_fields.SerializeToString() self.empty_message = unittest_pb2.TestEmptyMessage() self.empty_message.ParseFromString(self.all_fields_data) - self.unknown_fields = self.empty_message._unknown_fields + + def testSerialize(self): + data = self.empty_message.SerializeToString() + + # Don't use assertEqual because we don't want to dump raw binary data to + # stdout. + self.assertTrue(data == self.all_fields_data) + + def testSerializeProto3(self): + # Verify that proto3 doesn't preserve unknown fields. 
+ message = unittest_proto3_arena_pb2.TestEmptyMessage() + message.ParseFromString(self.all_fields_data) + self.assertEqual(0, len(message.SerializeToString())) + + def testByteSize(self): + self.assertEqual(self.all_fields.ByteSize(), self.empty_message.ByteSize()) + + def testListFields(self): + # Make sure ListFields doesn't return unknown fields. + self.assertEqual(0, len(self.empty_message.ListFields())) + + def testSerializeMessageSetWireFormatUnknownExtension(self): + # Create a message using the message set wire format with an unknown + # message. + raw = unittest_mset_pb2.RawMessageSet() + + # Add an unknown extension. + item = raw.item.add() + item.type_id = 98418603 + message1 = message_set_extensions_pb2.TestMessageSetExtension1() + message1.i = 12345 + item.message = message1.SerializeToString() + + serialized = raw.SerializeToString() + + # Parse message using the message set wire format. + proto = message_set_extensions_pb2.TestMessageSet() + proto.MergeFromString(serialized) + + # Verify that the unknown extension is serialized unchanged + reserialized = proto.SerializeToString() + new_raw = unittest_mset_pb2.RawMessageSet() + new_raw.MergeFromString(reserialized) + self.assertEqual(raw, new_raw) + + def testEquals(self): + message = unittest_pb2.TestEmptyMessage() + message.ParseFromString(self.all_fields_data) + self.assertEqual(self.empty_message, message) + + self.all_fields.ClearField('optional_string') + message.ParseFromString(self.all_fields.SerializeToString()) + self.assertNotEqual(self.empty_message, message) + + def testDiscardUnknownFields(self): + self.empty_message.DiscardUnknownFields() + self.assertEqual(b'', self.empty_message.SerializeToString()) + # Test message field and repeated message field. 
+ message = unittest_pb2.TestAllTypes() + other_message = unittest_pb2.TestAllTypes() + other_message.optional_string = 'discard' + message.optional_nested_message.ParseFromString( + other_message.SerializeToString()) + message.repeated_nested_message.add().ParseFromString( + other_message.SerializeToString()) + self.assertNotEqual( + b'', message.optional_nested_message.SerializeToString()) + self.assertNotEqual( + b'', message.repeated_nested_message[0].SerializeToString()) + message.DiscardUnknownFields() + self.assertEqual(b'', message.optional_nested_message.SerializeToString()) + self.assertEqual( + b'', message.repeated_nested_message[0].SerializeToString()) + + +class UnknownFieldsAccessorsTest(unittest.TestCase): + + def setUp(self): + self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR + self.all_fields = unittest_pb2.TestAllTypes() + test_util.SetAllFields(self.all_fields) + self.all_fields_data = self.all_fields.SerializeToString() + self.empty_message = unittest_pb2.TestEmptyMessage() + self.empty_message.ParseFromString(self.all_fields_data) + if api_implementation.Type() != 'cpp': + # _unknown_fields is an implementation detail. + self.unknown_fields = self.empty_message._unknown_fields + + # All the tests that use GetField() check an implementation detail of the + # Python implementation, which stores unknown fields as serialized strings. + # These tests are skipped by the C++ implementation: it's enough to check that + # the message is correctly serialized. 
def GetField(self, name): field_descriptor = self.descriptor.fields_by_name[name] @@ -66,45 +169,45 @@ class UnknownFieldsTest(basetest.TestCase): decoder(value, 0, len(value), self.all_fields, result_dict) return result_dict[field_descriptor] + @SkipIfCppImplementation def testEnum(self): value = self.GetField('optional_nested_enum') self.assertEqual(self.all_fields.optional_nested_enum, value) + @SkipIfCppImplementation def testRepeatedEnum(self): value = self.GetField('repeated_nested_enum') self.assertEqual(self.all_fields.repeated_nested_enum, value) + @SkipIfCppImplementation def testVarint(self): value = self.GetField('optional_int32') self.assertEqual(self.all_fields.optional_int32, value) + @SkipIfCppImplementation def testFixed32(self): value = self.GetField('optional_fixed32') self.assertEqual(self.all_fields.optional_fixed32, value) + @SkipIfCppImplementation def testFixed64(self): value = self.GetField('optional_fixed64') self.assertEqual(self.all_fields.optional_fixed64, value) + @SkipIfCppImplementation def testLengthDelimited(self): value = self.GetField('optional_string') self.assertEqual(self.all_fields.optional_string, value) + @SkipIfCppImplementation def testGroup(self): value = self.GetField('optionalgroup') self.assertEqual(self.all_fields.optionalgroup, value) - def testSerialize(self): - data = self.empty_message.SerializeToString() - - # Don't use assertEqual because we don't want to dump raw binary data to - # stdout. 
- self.assertTrue(data == self.all_fields_data) - def testCopyFrom(self): message = unittest_pb2.TestEmptyMessage() message.CopyFrom(self.empty_message) - self.assertEqual(self.unknown_fields, message._unknown_fields) + self.assertEqual(message.SerializeToString(), self.all_fields_data) def testMergeFrom(self): message = unittest_pb2.TestAllTypes() @@ -118,64 +221,27 @@ class UnknownFieldsTest(basetest.TestCase): message.optional_uint32 = 4 destination = unittest_pb2.TestEmptyMessage() destination.ParseFromString(message.SerializeToString()) - unknown_fields = destination._unknown_fields[:] destination.MergeFrom(source) - self.assertEqual(unknown_fields + source._unknown_fields, - destination._unknown_fields) + # Check that the fields where correctly merged, even stored in the unknown + # fields set. + message.ParseFromString(destination.SerializeToString()) + self.assertEqual(message.optional_int32, 1) + self.assertEqual(message.optional_uint32, 2) + self.assertEqual(message.optional_int64, 3) def testClear(self): self.empty_message.Clear() - self.assertEqual(0, len(self.empty_message._unknown_fields)) - - def testByteSize(self): - self.assertEqual(self.all_fields.ByteSize(), self.empty_message.ByteSize()) + # All cleared, even unknown fields. + self.assertEqual(self.empty_message.SerializeToString(), b'') def testUnknownExtensions(self): message = unittest_pb2.TestEmptyMessageWithExtensions() message.ParseFromString(self.all_fields_data) - self.assertEqual(self.empty_message._unknown_fields, - message._unknown_fields) + self.assertEqual(message.SerializeToString(), self.all_fields_data) - def testListFields(self): - # Make sure ListFields doesn't return unknown fields. - self.assertEqual(0, len(self.empty_message.ListFields())) - def testSerializeMessageSetWireFormatUnknownExtension(self): - # Create a message using the message set wire format with an unknown - # message. - raw = unittest_mset_pb2.RawMessageSet() - - # Add an unknown extension. 
- item = raw.item.add() - item.type_id = 1545009 - message1 = unittest_mset_pb2.TestMessageSetExtension1() - message1.i = 12345 - item.message = message1.SerializeToString() - - serialized = raw.SerializeToString() - - # Parse message using the message set wire format. - proto = unittest_mset_pb2.TestMessageSet() - proto.MergeFromString(serialized) - - # Verify that the unknown extension is serialized unchanged - reserialized = proto.SerializeToString() - new_raw = unittest_mset_pb2.RawMessageSet() - new_raw.MergeFromString(reserialized) - self.assertEqual(raw, new_raw) - - def testEquals(self): - message = unittest_pb2.TestEmptyMessage() - message.ParseFromString(self.all_fields_data) - self.assertEqual(self.empty_message, message) - - self.all_fields.ClearField('optional_string') - message.ParseFromString(self.all_fields.SerializeToString()) - self.assertNotEqual(self.empty_message, message) - - -class UnknownFieldsTest(basetest.TestCase): +class UnknownEnumValuesTest(unittest.TestCase): def setUp(self): self.descriptor = missing_enum_values_pb2.TestEnumValues.DESCRIPTOR @@ -194,7 +260,14 @@ class UnknownFieldsTest(basetest.TestCase): self.message_data = self.message.SerializeToString() self.missing_message = missing_enum_values_pb2.TestMissingEnumValues() self.missing_message.ParseFromString(self.message_data) - self.unknown_fields = self.missing_message._unknown_fields + if api_implementation.Type() != 'cpp': + # _unknown_fields is an implementation detail. + self.unknown_fields = self.missing_message._unknown_fields + + # All the tests that use GetField() check an implementation detail of the + # Python implementation, which stores unknown fields as serialized strings. + # These tests are skipped by the C++ implementation: it's enough to check that + # the message is correctly serialized. 
def GetField(self, name): field_descriptor = self.descriptor.fields_by_name[name] @@ -208,15 +281,31 @@ class UnknownFieldsTest(basetest.TestCase): decoder(value, 0, len(value), self.message, result_dict) return result_dict[field_descriptor] + def testUnknownParseMismatchEnumValue(self): + just_string = missing_enum_values_pb2.JustString() + just_string.dummy = 'blah' + + missing = missing_enum_values_pb2.TestEnumValues() + # The parse is invalid, storing the string proto into the set of + # unknown fields. + missing.ParseFromString(just_string.SerializeToString()) + + # Fetching the enum field shouldn't crash, instead returning the + # default value. + self.assertEqual(missing.optional_nested_enum, 0) + + @SkipIfCppImplementation def testUnknownEnumValue(self): self.assertFalse(self.missing_message.HasField('optional_nested_enum')) value = self.GetField('optional_nested_enum') self.assertEqual(self.message.optional_nested_enum, value) + @SkipIfCppImplementation def testUnknownRepeatedEnumValue(self): value = self.GetField('repeated_nested_enum') self.assertEqual(self.message.repeated_nested_enum, value) + @SkipIfCppImplementation def testUnknownPackedEnumValue(self): value = self.GetField('packed_nested_enum') self.assertEqual(self.message.packed_nested_enum, value) @@ -228,4 +317,4 @@ class UnknownFieldsTest(basetest.TestCase): if __name__ == '__main__': - basetest.main() + unittest.main() diff --git a/python/google/protobuf/internal/well_known_types.py b/python/google/protobuf/internal/well_known_types.py new file mode 100644 index 000000000..7c5dffd0f --- /dev/null +++ b/python/google/protobuf/internal/well_known_types.py @@ -0,0 +1,724 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. 
+# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Contains well known classes. 
+ +This files defines well known classes which need extra maintenance including: + - Any + - Duration + - FieldMask + - Struct + - Timestamp +""" + +__author__ = 'jieluo@google.com (Jie Luo)' + +from datetime import datetime +from datetime import timedelta +import six + +from google.protobuf.descriptor import FieldDescriptor + +_TIMESTAMPFOMAT = '%Y-%m-%dT%H:%M:%S' +_NANOS_PER_SECOND = 1000000000 +_NANOS_PER_MILLISECOND = 1000000 +_NANOS_PER_MICROSECOND = 1000 +_MILLIS_PER_SECOND = 1000 +_MICROS_PER_SECOND = 1000000 +_SECONDS_PER_DAY = 24 * 3600 + + +class Error(Exception): + """Top-level module error.""" + + +class ParseError(Error): + """Thrown in case of parsing error.""" + + +class Any(object): + """Class for Any Message type.""" + + def Pack(self, msg, type_url_prefix='type.googleapis.com/'): + """Packs the specified message into current Any message.""" + if len(type_url_prefix) < 1 or type_url_prefix[-1] != '/': + self.type_url = '%s/%s' % (type_url_prefix, msg.DESCRIPTOR.full_name) + else: + self.type_url = '%s%s' % (type_url_prefix, msg.DESCRIPTOR.full_name) + self.value = msg.SerializeToString() + + def Unpack(self, msg): + """Unpacks the current Any message into specified message.""" + descriptor = msg.DESCRIPTOR + if not self.Is(descriptor): + return False + msg.ParseFromString(self.value) + return True + + def TypeName(self): + """Returns the protobuf type name of the inner message.""" + # Only last part is to be used: b/25630112 + return self.type_url.split('/')[-1] + + def Is(self, descriptor): + """Checks if this Any represents the given protobuf type.""" + return self.TypeName() == descriptor.full_name + + +class Timestamp(object): + """Class for Timestamp message type.""" + + def ToJsonString(self): + """Converts Timestamp to RFC 3339 date string format. + + Returns: + A string converted from timestamp. The string is always Z-normalized + and uses 3, 6 or 9 fractional digits as required to represent the + exact time. 
Example of the return format: '1972-01-01T10:00:20.021Z' + """ + nanos = self.nanos % _NANOS_PER_SECOND + total_sec = self.seconds + (self.nanos - nanos) // _NANOS_PER_SECOND + seconds = total_sec % _SECONDS_PER_DAY + days = (total_sec - seconds) // _SECONDS_PER_DAY + dt = datetime(1970, 1, 1) + timedelta(days, seconds) + + result = dt.isoformat() + if (nanos % 1e9) == 0: + # If there are 0 fractional digits, the fractional + # point '.' should be omitted when serializing. + return result + 'Z' + if (nanos % 1e6) == 0: + # Serialize 3 fractional digits. + return result + '.%03dZ' % (nanos / 1e6) + if (nanos % 1e3) == 0: + # Serialize 6 fractional digits. + return result + '.%06dZ' % (nanos / 1e3) + # Serialize 9 fractional digits. + return result + '.%09dZ' % nanos + + def FromJsonString(self, value): + """Parse a RFC 3339 date string format to Timestamp. + + Args: + value: A date string. Any fractional digits (or none) and any offset are + accepted as long as they fit into nano-seconds precision. + Example of accepted format: '1972-01-01T10:00:20.021-05:00' + + Raises: + ParseError: On parsing problems. + """ + timezone_offset = value.find('Z') + if timezone_offset == -1: + timezone_offset = value.find('+') + if timezone_offset == -1: + timezone_offset = value.rfind('-') + if timezone_offset == -1: + raise ParseError( + 'Failed to parse timestamp: missing valid timezone offset.') + time_value = value[0:timezone_offset] + # Parse datetime and nanos. 
+ point_position = time_value.find('.') + if point_position == -1: + second_value = time_value + nano_value = '' + else: + second_value = time_value[:point_position] + nano_value = time_value[point_position + 1:] + date_object = datetime.strptime(second_value, _TIMESTAMPFOMAT) + td = date_object - datetime(1970, 1, 1) + seconds = td.seconds + td.days * _SECONDS_PER_DAY + if len(nano_value) > 9: + raise ParseError( + 'Failed to parse Timestamp: nanos {0} more than ' + '9 fractional digits.'.format(nano_value)) + if nano_value: + nanos = round(float('0.' + nano_value) * 1e9) + else: + nanos = 0 + # Parse timezone offsets. + if value[timezone_offset] == 'Z': + if len(value) != timezone_offset + 1: + raise ParseError('Failed to parse timestamp: invalid trailing' + ' data {0}.'.format(value)) + else: + timezone = value[timezone_offset:] + pos = timezone.find(':') + if pos == -1: + raise ParseError( + 'Invalid timezone offset value: {0}.'.format(timezone)) + if timezone[0] == '+': + seconds -= (int(timezone[1:pos])*60+int(timezone[pos+1:]))*60 + else: + seconds += (int(timezone[1:pos])*60+int(timezone[pos+1:]))*60 + # Set seconds and nanos + self.seconds = int(seconds) + self.nanos = int(nanos) + + def GetCurrentTime(self): + """Get the current UTC into Timestamp.""" + self.FromDatetime(datetime.utcnow()) + + def ToNanoseconds(self): + """Converts Timestamp to nanoseconds since epoch.""" + return self.seconds * _NANOS_PER_SECOND + self.nanos + + def ToMicroseconds(self): + """Converts Timestamp to microseconds since epoch.""" + return (self.seconds * _MICROS_PER_SECOND + + self.nanos // _NANOS_PER_MICROSECOND) + + def ToMilliseconds(self): + """Converts Timestamp to milliseconds since epoch.""" + return (self.seconds * _MILLIS_PER_SECOND + + self.nanos // _NANOS_PER_MILLISECOND) + + def ToSeconds(self): + """Converts Timestamp to seconds since epoch.""" + return self.seconds + + def FromNanoseconds(self, nanos): + """Converts nanoseconds since epoch to Timestamp.""" + 
self.seconds = nanos // _NANOS_PER_SECOND + self.nanos = nanos % _NANOS_PER_SECOND + + def FromMicroseconds(self, micros): + """Converts microseconds since epoch to Timestamp.""" + self.seconds = micros // _MICROS_PER_SECOND + self.nanos = (micros % _MICROS_PER_SECOND) * _NANOS_PER_MICROSECOND + + def FromMilliseconds(self, millis): + """Converts milliseconds since epoch to Timestamp.""" + self.seconds = millis // _MILLIS_PER_SECOND + self.nanos = (millis % _MILLIS_PER_SECOND) * _NANOS_PER_MILLISECOND + + def FromSeconds(self, seconds): + """Converts seconds since epoch to Timestamp.""" + self.seconds = seconds + self.nanos = 0 + + def ToDatetime(self): + """Converts Timestamp to datetime.""" + return datetime.utcfromtimestamp( + self.seconds + self.nanos / float(_NANOS_PER_SECOND)) + + def FromDatetime(self, dt): + """Converts datetime to Timestamp.""" + td = dt - datetime(1970, 1, 1) + self.seconds = td.seconds + td.days * _SECONDS_PER_DAY + self.nanos = td.microseconds * _NANOS_PER_MICROSECOND + + +class Duration(object): + """Class for Duration message type.""" + + def ToJsonString(self): + """Converts Duration to string format. + + Returns: + A string converted from self. The string format will contains + 3, 6, or 9 fractional digits depending on the precision required to + represent the exact Duration value. For example: "1s", "1.010s", + "1.000000100s", "-3.100s" + """ + if self.seconds < 0 or self.nanos < 0: + result = '-' + seconds = - self.seconds + int((0 - self.nanos) // 1e9) + nanos = (0 - self.nanos) % 1e9 + else: + result = '' + seconds = self.seconds + int(self.nanos // 1e9) + nanos = self.nanos % 1e9 + result += '%d' % seconds + if (nanos % 1e9) == 0: + # If there are 0 fractional digits, the fractional + # point '.' should be omitted when serializing. + return result + 's' + if (nanos % 1e6) == 0: + # Serialize 3 fractional digits. + return result + '.%03ds' % (nanos / 1e6) + if (nanos % 1e3) == 0: + # Serialize 6 fractional digits. 
+ return result + '.%06ds' % (nanos / 1e3) + # Serialize 9 fractional digits. + return result + '.%09ds' % nanos + + def FromJsonString(self, value): + """Converts a string to Duration. + + Args: + value: A string to be converted. The string must end with 's'. Any + fractional digits (or none) are accepted as long as they fit into + precision. For example: "1s", "1.01s", "1.0000001s", "-3.100s + + Raises: + ParseError: On parsing problems. + """ + if len(value) < 1 or value[-1] != 's': + raise ParseError( + 'Duration must end with letter "s": {0}.'.format(value)) + try: + pos = value.find('.') + if pos == -1: + self.seconds = int(value[:-1]) + self.nanos = 0 + else: + self.seconds = int(value[:pos]) + if value[0] == '-': + self.nanos = int(round(float('-0{0}'.format(value[pos: -1])) *1e9)) + else: + self.nanos = int(round(float('0{0}'.format(value[pos: -1])) *1e9)) + except ValueError: + raise ParseError( + 'Couldn\'t parse duration: {0}.'.format(value)) + + def ToNanoseconds(self): + """Converts a Duration to nanoseconds.""" + return self.seconds * _NANOS_PER_SECOND + self.nanos + + def ToMicroseconds(self): + """Converts a Duration to microseconds.""" + micros = _RoundTowardZero(self.nanos, _NANOS_PER_MICROSECOND) + return self.seconds * _MICROS_PER_SECOND + micros + + def ToMilliseconds(self): + """Converts a Duration to milliseconds.""" + millis = _RoundTowardZero(self.nanos, _NANOS_PER_MILLISECOND) + return self.seconds * _MILLIS_PER_SECOND + millis + + def ToSeconds(self): + """Converts a Duration to seconds.""" + return self.seconds + + def FromNanoseconds(self, nanos): + """Converts nanoseconds to Duration.""" + self._NormalizeDuration(nanos // _NANOS_PER_SECOND, + nanos % _NANOS_PER_SECOND) + + def FromMicroseconds(self, micros): + """Converts microseconds to Duration.""" + self._NormalizeDuration( + micros // _MICROS_PER_SECOND, + (micros % _MICROS_PER_SECOND) * _NANOS_PER_MICROSECOND) + + def FromMilliseconds(self, millis): + """Converts milliseconds to 
Duration.""" + self._NormalizeDuration( + millis // _MILLIS_PER_SECOND, + (millis % _MILLIS_PER_SECOND) * _NANOS_PER_MILLISECOND) + + def FromSeconds(self, seconds): + """Converts seconds to Duration.""" + self.seconds = seconds + self.nanos = 0 + + def ToTimedelta(self): + """Converts Duration to timedelta.""" + return timedelta( + seconds=self.seconds, microseconds=_RoundTowardZero( + self.nanos, _NANOS_PER_MICROSECOND)) + + def FromTimedelta(self, td): + """Convertd timedelta to Duration.""" + self._NormalizeDuration(td.seconds + td.days * _SECONDS_PER_DAY, + td.microseconds * _NANOS_PER_MICROSECOND) + + def _NormalizeDuration(self, seconds, nanos): + """Set Duration by seconds and nonas.""" + # Force nanos to be negative if the duration is negative. + if seconds < 0 and nanos > 0: + seconds += 1 + nanos -= _NANOS_PER_SECOND + self.seconds = seconds + self.nanos = nanos + + +def _RoundTowardZero(value, divider): + """Truncates the remainder part after division.""" + # For some languanges, the sign of the remainder is implementation + # dependent if any of the operands is negative. Here we enforce + # "rounded toward zero" semantics. For example, for (-5) / 2 an + # implementation may give -3 as the result with the remainder being + # 1. This function ensures we always return -2 (closer to zero). 
+ result = value // divider + remainder = value % divider + if result < 0 and remainder > 0: + return result + 1 + else: + return result + + +class FieldMask(object): + """Class for FieldMask message type.""" + + def ToJsonString(self): + """Converts FieldMask to string according to proto3 JSON spec.""" + return ','.join(self.paths) + + def FromJsonString(self, value): + """Converts string to FieldMask according to proto3 JSON spec.""" + self.Clear() + for path in value.split(','): + self.paths.append(path) + + def IsValidForDescriptor(self, message_descriptor): + """Checks whether the FieldMask is valid for Message Descriptor.""" + for path in self.paths: + if not _IsValidPath(message_descriptor, path): + return False + return True + + def AllFieldsFromDescriptor(self, message_descriptor): + """Gets all direct fields of Message Descriptor to FieldMask.""" + self.Clear() + for field in message_descriptor.fields: + self.paths.append(field.name) + + def CanonicalFormFromMask(self, mask): + """Converts a FieldMask to the canonical form. + + Removes paths that are covered by another path. For example, + "foo.bar" is covered by "foo" and will be removed if "foo" + is also in the FieldMask. Then sorts all paths in alphabetical order. + + Args: + mask: The original FieldMask to be converted. 
+ """ + tree = _FieldMaskTree(mask) + tree.ToFieldMask(self) + + def Union(self, mask1, mask2): + """Merges mask1 and mask2 into this FieldMask.""" + _CheckFieldMaskMessage(mask1) + _CheckFieldMaskMessage(mask2) + tree = _FieldMaskTree(mask1) + tree.MergeFromFieldMask(mask2) + tree.ToFieldMask(self) + + def Intersect(self, mask1, mask2): + """Intersects mask1 and mask2 into this FieldMask.""" + _CheckFieldMaskMessage(mask1) + _CheckFieldMaskMessage(mask2) + tree = _FieldMaskTree(mask1) + intersection = _FieldMaskTree() + for path in mask2.paths: + tree.IntersectPath(path, intersection) + intersection.ToFieldMask(self) + + def MergeMessage( + self, source, destination, + replace_message_field=False, replace_repeated_field=False): + """Merges fields specified in FieldMask from source to destination. + + Args: + source: Source message. + destination: The destination message to be merged into. + replace_message_field: Replace message field if True. Merge message + field if False. + replace_repeated_field: Replace repeated field if True. Append + elements of repeated field if False. 
+ """ + tree = _FieldMaskTree(self) + tree.MergeMessage( + source, destination, replace_message_field, replace_repeated_field) + + +def _IsValidPath(message_descriptor, path): + """Checks whether the path is valid for Message Descriptor.""" + parts = path.split('.') + last = parts.pop() + for name in parts: + field = message_descriptor.fields_by_name[name] + if (field is None or + field.label == FieldDescriptor.LABEL_REPEATED or + field.type != FieldDescriptor.TYPE_MESSAGE): + return False + message_descriptor = field.message_type + return last in message_descriptor.fields_by_name + + +def _CheckFieldMaskMessage(message): + """Raises ValueError if message is not a FieldMask.""" + message_descriptor = message.DESCRIPTOR + if (message_descriptor.name != 'FieldMask' or + message_descriptor.file.name != 'google/protobuf/field_mask.proto'): + raise ValueError('Message {0} is not a FieldMask.'.format( + message_descriptor.full_name)) + + +class _FieldMaskTree(object): + """Represents a FieldMask in a tree structure. + + For example, given a FieldMask "foo.bar,foo.baz,bar.baz", + the FieldMaskTree will be: + [_root] -+- foo -+- bar + | | + | +- baz + | + +- bar --- baz + In the tree, each leaf node represents a field path. + """ + + def __init__(self, field_mask=None): + """Initializes the tree by FieldMask.""" + self._root = {} + if field_mask: + self.MergeFromFieldMask(field_mask) + + def MergeFromFieldMask(self, field_mask): + """Merges a FieldMask to the tree.""" + for path in field_mask.paths: + self.AddPath(path) + + def AddPath(self, path): + """Adds a field path into the tree. + + If the field path to add is a sub-path of an existing field path + in the tree (i.e., a leaf node), it means the tree already matches + the given path so nothing will be added to the tree. If the path + matches an existing non-leaf node in the tree, that non-leaf node + will be turned into a leaf node with all its children removed because + the path matches all the node's children. 
Otherwise, a new path will + be added. + + Args: + path: The field path to add. + """ + node = self._root + for name in path.split('.'): + if name not in node: + node[name] = {} + elif not node[name]: + # Pre-existing empty node implies we already have this entire tree. + return + node = node[name] + # Remove any sub-trees we might have had. + node.clear() + + def ToFieldMask(self, field_mask): + """Converts the tree to a FieldMask.""" + field_mask.Clear() + _AddFieldPaths(self._root, '', field_mask) + + def IntersectPath(self, path, intersection): + """Calculates the intersection part of a field path with this tree. + + Args: + path: The field path to calculates. + intersection: The out tree to record the intersection part. + """ + node = self._root + for name in path.split('.'): + if name not in node: + return + elif not node[name]: + intersection.AddPath(path) + return + node = node[name] + intersection.AddLeafNodes(path, node) + + def AddLeafNodes(self, prefix, node): + """Adds leaf nodes begin with prefix to this tree.""" + if not node: + self.AddPath(prefix) + for name in node: + child_path = prefix + '.' + name + self.AddLeafNodes(child_path, node[name]) + + def MergeMessage( + self, source, destination, + replace_message, replace_repeated): + """Merge all fields specified by this tree from source to destination.""" + _MergeMessage( + self._root, source, destination, replace_message, replace_repeated) + + +def _StrConvert(value): + """Converts value to str if it is not.""" + # This file is imported by c extension and some methods like ClearField + # requires string for the field name. py2/py3 has different text + # type and may use unicode. 
+ if not isinstance(value, str): + return value.encode('utf-8') + return value + + +def _MergeMessage( + node, source, destination, replace_message, replace_repeated): + """Merge all fields specified by a sub-tree from source to destination.""" + source_descriptor = source.DESCRIPTOR + for name in node: + child = node[name] + field = source_descriptor.fields_by_name[name] + if field is None: + raise ValueError('Error: Can\'t find field {0} in message {1}.'.format( + name, source_descriptor.full_name)) + if child: + # Sub-paths are only allowed for singular message fields. + if (field.label == FieldDescriptor.LABEL_REPEATED or + field.cpp_type != FieldDescriptor.CPPTYPE_MESSAGE): + raise ValueError('Error: Field {0} in message {1} is not a singular ' + 'message field and cannot have sub-fields.'.format( + name, source_descriptor.full_name)) + _MergeMessage( + child, getattr(source, name), getattr(destination, name), + replace_message, replace_repeated) + continue + if field.label == FieldDescriptor.LABEL_REPEATED: + if replace_repeated: + destination.ClearField(_StrConvert(name)) + repeated_source = getattr(source, name) + repeated_destination = getattr(destination, name) + if field.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE: + for item in repeated_source: + repeated_destination.add().MergeFrom(item) + else: + repeated_destination.extend(repeated_source) + else: + if field.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE: + if replace_message: + destination.ClearField(_StrConvert(name)) + if source.HasField(name): + getattr(destination, name).MergeFrom(getattr(source, name)) + else: + setattr(destination, name, getattr(source, name)) + + +def _AddFieldPaths(node, prefix, field_mask): + """Adds the field paths descended from node to field_mask.""" + if not node: + field_mask.paths.append(prefix) + return + for name in sorted(node): + if prefix: + child_path = prefix + '.' 
+ name + else: + child_path = name + _AddFieldPaths(node[name], child_path, field_mask) + + +_INT_OR_FLOAT = six.integer_types + (float,) + + +def _SetStructValue(struct_value, value): + if value is None: + struct_value.null_value = 0 + elif isinstance(value, bool): + # Note: this check must come before the number check because in Python + # True and False are also considered numbers. + struct_value.bool_value = value + elif isinstance(value, six.string_types): + struct_value.string_value = value + elif isinstance(value, _INT_OR_FLOAT): + struct_value.number_value = value + else: + raise ValueError('Unexpected type') + + +def _GetStructValue(struct_value): + which = struct_value.WhichOneof('kind') + if which == 'struct_value': + return struct_value.struct_value + elif which == 'null_value': + return None + elif which == 'number_value': + return struct_value.number_value + elif which == 'string_value': + return struct_value.string_value + elif which == 'bool_value': + return struct_value.bool_value + elif which == 'list_value': + return struct_value.list_value + elif which is None: + raise ValueError('Value not set') + + +class Struct(object): + """Class for Struct message type.""" + + __slots__ = [] + + def __getitem__(self, key): + return _GetStructValue(self.fields[key]) + + def __setitem__(self, key, value): + _SetStructValue(self.fields[key], value) + + def get_or_create_list(self, key): + """Returns a list for this key, creating if it didn't exist already.""" + return self.fields[key].list_value + + def get_or_create_struct(self, key): + """Returns a struct for this key, creating if it didn't exist already.""" + return self.fields[key].struct_value + + # TODO(haberman): allow constructing/merging from dict. 
+ + +class ListValue(object): + """Class for ListValue message type.""" + + def __len__(self): + return len(self.values) + + def append(self, value): + _SetStructValue(self.values.add(), value) + + def extend(self, elem_seq): + for value in elem_seq: + self.append(value) + + def __getitem__(self, index): + """Retrieves item by the specified index.""" + return _GetStructValue(self.values.__getitem__(index)) + + def __setitem__(self, index, value): + _SetStructValue(self.values.__getitem__(index), value) + + def items(self): + for i in range(len(self)): + yield self[i] + + def add_struct(self): + """Appends and returns a struct value as the next value in the list.""" + return self.values.add().struct_value + + def add_list(self): + """Appends and returns a list value as the next value in the list.""" + return self.values.add().list_value + + +WKTBASES = { + 'google.protobuf.Any': Any, + 'google.protobuf.Duration': Duration, + 'google.protobuf.FieldMask': FieldMask, + 'google.protobuf.ListValue': ListValue, + 'google.protobuf.Struct': Struct, + 'google.protobuf.Timestamp': Timestamp, +} diff --git a/python/google/protobuf/internal/well_known_types_test.py b/python/google/protobuf/internal/well_known_types_test.py new file mode 100644 index 000000000..2f32ac994 --- /dev/null +++ b/python/google/protobuf/internal/well_known_types_test.py @@ -0,0 +1,644 @@ +#! /usr/bin/env python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. 
+# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +"""Test for google.protobuf.internal.well_known_types.""" + +__author__ = 'jieluo@google.com (Jie Luo)' + +from datetime import datetime + +try: + import unittest2 as unittest #PY26 +except ImportError: + import unittest + +from google.protobuf import any_pb2 +from google.protobuf import duration_pb2 +from google.protobuf import field_mask_pb2 +from google.protobuf import struct_pb2 +from google.protobuf import timestamp_pb2 +from google.protobuf import unittest_pb2 +from google.protobuf.internal import any_test_pb2 +from google.protobuf.internal import test_util +from google.protobuf.internal import well_known_types +from google.protobuf import descriptor +from google.protobuf import text_format + + +class TimeUtilTestBase(unittest.TestCase): + + def CheckTimestampConversion(self, message, text): + self.assertEqual(text, message.ToJsonString()) + parsed_message = timestamp_pb2.Timestamp() + parsed_message.FromJsonString(text) + self.assertEqual(message, parsed_message) + + def CheckDurationConversion(self, message, text): + self.assertEqual(text, message.ToJsonString()) + parsed_message = duration_pb2.Duration() + parsed_message.FromJsonString(text) + self.assertEqual(message, parsed_message) + + +class TimeUtilTest(TimeUtilTestBase): + + def testTimestampSerializeAndParse(self): + message = timestamp_pb2.Timestamp() + # Generated output should contain 3, 6, or 9 fractional digits. + message.seconds = 0 + message.nanos = 0 + self.CheckTimestampConversion(message, '1970-01-01T00:00:00Z') + message.nanos = 10000000 + self.CheckTimestampConversion(message, '1970-01-01T00:00:00.010Z') + message.nanos = 10000 + self.CheckTimestampConversion(message, '1970-01-01T00:00:00.000010Z') + message.nanos = 10 + self.CheckTimestampConversion(message, '1970-01-01T00:00:00.000000010Z') + # Test min timestamps. + message.seconds = -62135596800 + message.nanos = 0 + self.CheckTimestampConversion(message, '0001-01-01T00:00:00Z') + # Test max timestamps. 
+    # Parsing accepts any fractional digits as long as they fit into nano
+    # precision.
+    message.FromJsonString('1970-01-01T00:00:00.1Z')
+    self.assertEqual(0, message.seconds)
+    self.assertEqual(100000000, message.nanos)
+    # Parsing accepts offsets.
+ message.FromJsonString('0.1s') + self.assertEqual(100000000, message.nanos) + message.FromJsonString('0.0000001s') + self.assertEqual(100, message.nanos) + + def testTimestampIntegerConversion(self): + message = timestamp_pb2.Timestamp() + message.FromNanoseconds(1) + self.assertEqual('1970-01-01T00:00:00.000000001Z', + message.ToJsonString()) + self.assertEqual(1, message.ToNanoseconds()) + + message.FromNanoseconds(-1) + self.assertEqual('1969-12-31T23:59:59.999999999Z', + message.ToJsonString()) + self.assertEqual(-1, message.ToNanoseconds()) + + message.FromMicroseconds(1) + self.assertEqual('1970-01-01T00:00:00.000001Z', + message.ToJsonString()) + self.assertEqual(1, message.ToMicroseconds()) + + message.FromMicroseconds(-1) + self.assertEqual('1969-12-31T23:59:59.999999Z', + message.ToJsonString()) + self.assertEqual(-1, message.ToMicroseconds()) + + message.FromMilliseconds(1) + self.assertEqual('1970-01-01T00:00:00.001Z', + message.ToJsonString()) + self.assertEqual(1, message.ToMilliseconds()) + + message.FromMilliseconds(-1) + self.assertEqual('1969-12-31T23:59:59.999Z', + message.ToJsonString()) + self.assertEqual(-1, message.ToMilliseconds()) + + message.FromSeconds(1) + self.assertEqual('1970-01-01T00:00:01Z', + message.ToJsonString()) + self.assertEqual(1, message.ToSeconds()) + + message.FromSeconds(-1) + self.assertEqual('1969-12-31T23:59:59Z', + message.ToJsonString()) + self.assertEqual(-1, message.ToSeconds()) + + message.FromNanoseconds(1999) + self.assertEqual(1, message.ToMicroseconds()) + # For negative values, Timestamp will be rounded down. + # For example, "1969-12-31T23:59:59.5Z" (i.e., -0.5s) rounded to seconds + # will be "1969-12-31T23:59:59Z" (i.e., -1s) rather than + # "1970-01-01T00:00:00Z" (i.e., 0s). 
+ message.FromNanoseconds(-1999) + self.assertEqual(-2, message.ToMicroseconds()) + + def testDurationIntegerConversion(self): + message = duration_pb2.Duration() + message.FromNanoseconds(1) + self.assertEqual('0.000000001s', + message.ToJsonString()) + self.assertEqual(1, message.ToNanoseconds()) + + message.FromNanoseconds(-1) + self.assertEqual('-0.000000001s', + message.ToJsonString()) + self.assertEqual(-1, message.ToNanoseconds()) + + message.FromMicroseconds(1) + self.assertEqual('0.000001s', + message.ToJsonString()) + self.assertEqual(1, message.ToMicroseconds()) + + message.FromMicroseconds(-1) + self.assertEqual('-0.000001s', + message.ToJsonString()) + self.assertEqual(-1, message.ToMicroseconds()) + + message.FromMilliseconds(1) + self.assertEqual('0.001s', + message.ToJsonString()) + self.assertEqual(1, message.ToMilliseconds()) + + message.FromMilliseconds(-1) + self.assertEqual('-0.001s', + message.ToJsonString()) + self.assertEqual(-1, message.ToMilliseconds()) + + message.FromSeconds(1) + self.assertEqual('1s', message.ToJsonString()) + self.assertEqual(1, message.ToSeconds()) + + message.FromSeconds(-1) + self.assertEqual('-1s', + message.ToJsonString()) + self.assertEqual(-1, message.ToSeconds()) + + # Test truncation behavior. + message.FromNanoseconds(1999) + self.assertEqual(1, message.ToMicroseconds()) + + # For negative values, Duration will be rounded towards 0. 
+ message.FromNanoseconds(-1999) + self.assertEqual(-1, message.ToMicroseconds()) + + def testDatetimeConverison(self): + message = timestamp_pb2.Timestamp() + dt = datetime(1970, 1, 1) + message.FromDatetime(dt) + self.assertEqual(dt, message.ToDatetime()) + + message.FromMilliseconds(1999) + self.assertEqual(datetime(1970, 1, 1, 0, 0, 1, 999000), + message.ToDatetime()) + + def testTimedeltaConversion(self): + message = duration_pb2.Duration() + message.FromNanoseconds(1999999999) + td = message.ToTimedelta() + self.assertEqual(1, td.seconds) + self.assertEqual(999999, td.microseconds) + + message.FromNanoseconds(-1999999999) + td = message.ToTimedelta() + self.assertEqual(-1, td.days) + self.assertEqual(86398, td.seconds) + self.assertEqual(1, td.microseconds) + + message.FromMicroseconds(-1) + td = message.ToTimedelta() + self.assertEqual(-1, td.days) + self.assertEqual(86399, td.seconds) + self.assertEqual(999999, td.microseconds) + converted_message = duration_pb2.Duration() + converted_message.FromTimedelta(td) + self.assertEqual(message, converted_message) + + def testInvalidTimestamp(self): + message = timestamp_pb2.Timestamp() + self.assertRaisesRegexp( + ValueError, + 'time data \'10000-01-01T00:00:00\' does not match' + ' format \'%Y-%m-%dT%H:%M:%S\'', + message.FromJsonString, '10000-01-01T00:00:00.00Z') + self.assertRaisesRegexp( + well_known_types.ParseError, + 'nanos 0123456789012 more than 9 fractional digits.', + message.FromJsonString, + '1970-01-01T00:00:00.0123456789012Z') + self.assertRaisesRegexp( + well_known_types.ParseError, + (r'Invalid timezone offset value: \+08.'), + message.FromJsonString, + '1972-01-01T01:00:00.01+08',) + self.assertRaisesRegexp( + ValueError, + 'year is out of range', + message.FromJsonString, + '0000-01-01T00:00:00Z') + message.seconds = 253402300800 + self.assertRaisesRegexp( + OverflowError, + 'date value out of range', + message.ToJsonString) + + def testInvalidDuration(self): + message = duration_pb2.Duration() 
+ self.assertRaisesRegexp( + well_known_types.ParseError, + 'Duration must end with letter "s": 1.', + message.FromJsonString, '1') + self.assertRaisesRegexp( + well_known_types.ParseError, + 'Couldn\'t parse duration: 1...2s.', + message.FromJsonString, '1...2s') + + +class FieldMaskTest(unittest.TestCase): + + def testStringFormat(self): + mask = field_mask_pb2.FieldMask() + self.assertEqual('', mask.ToJsonString()) + mask.paths.append('foo') + self.assertEqual('foo', mask.ToJsonString()) + mask.paths.append('bar') + self.assertEqual('foo,bar', mask.ToJsonString()) + + mask.FromJsonString('') + self.assertEqual('', mask.ToJsonString()) + mask.FromJsonString('foo') + self.assertEqual(['foo'], mask.paths) + mask.FromJsonString('foo,bar') + self.assertEqual(['foo', 'bar'], mask.paths) + + def testDescriptorToFieldMask(self): + mask = field_mask_pb2.FieldMask() + msg_descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR + mask.AllFieldsFromDescriptor(msg_descriptor) + self.assertEqual(75, len(mask.paths)) + self.assertTrue(mask.IsValidForDescriptor(msg_descriptor)) + for field in msg_descriptor.fields: + self.assertTrue(field.name in mask.paths) + mask.paths.append('optional_nested_message.bb') + self.assertTrue(mask.IsValidForDescriptor(msg_descriptor)) + mask.paths.append('repeated_nested_message.bb') + self.assertFalse(mask.IsValidForDescriptor(msg_descriptor)) + + def testCanonicalFrom(self): + mask = field_mask_pb2.FieldMask() + out_mask = field_mask_pb2.FieldMask() + # Paths will be sorted. + mask.FromJsonString('baz.quz,bar,foo') + out_mask.CanonicalFormFromMask(mask) + self.assertEqual('bar,baz.quz,foo', out_mask.ToJsonString()) + # Duplicated paths will be removed. + mask.FromJsonString('foo,bar,foo') + out_mask.CanonicalFormFromMask(mask) + self.assertEqual('bar,foo', out_mask.ToJsonString()) + # Sub-paths of other paths will be removed. 
+ mask.FromJsonString('foo.b1,bar.b1,foo.b2,bar') + out_mask.CanonicalFormFromMask(mask) + self.assertEqual('bar,foo.b1,foo.b2', out_mask.ToJsonString()) + + # Test more deeply nested cases. + mask.FromJsonString( + 'foo.bar.baz1,foo.bar.baz2.quz,foo.bar.baz2') + out_mask.CanonicalFormFromMask(mask) + self.assertEqual('foo.bar.baz1,foo.bar.baz2', + out_mask.ToJsonString()) + mask.FromJsonString( + 'foo.bar.baz1,foo.bar.baz2,foo.bar.baz2.quz') + out_mask.CanonicalFormFromMask(mask) + self.assertEqual('foo.bar.baz1,foo.bar.baz2', + out_mask.ToJsonString()) + mask.FromJsonString( + 'foo.bar.baz1,foo.bar.baz2,foo.bar.baz2.quz,foo.bar') + out_mask.CanonicalFormFromMask(mask) + self.assertEqual('foo.bar', out_mask.ToJsonString()) + mask.FromJsonString( + 'foo.bar.baz1,foo.bar.baz2,foo.bar.baz2.quz,foo') + out_mask.CanonicalFormFromMask(mask) + self.assertEqual('foo', out_mask.ToJsonString()) + + def testUnion(self): + mask1 = field_mask_pb2.FieldMask() + mask2 = field_mask_pb2.FieldMask() + out_mask = field_mask_pb2.FieldMask() + mask1.FromJsonString('foo,baz') + mask2.FromJsonString('bar,quz') + out_mask.Union(mask1, mask2) + self.assertEqual('bar,baz,foo,quz', out_mask.ToJsonString()) + # Overlap with duplicated paths. + mask1.FromJsonString('foo,baz.bb') + mask2.FromJsonString('baz.bb,quz') + out_mask.Union(mask1, mask2) + self.assertEqual('baz.bb,foo,quz', out_mask.ToJsonString()) + # Overlap with paths covering some other paths. + mask1.FromJsonString('foo.bar.baz,quz') + mask2.FromJsonString('foo.bar,bar') + out_mask.Union(mask1, mask2) + self.assertEqual('bar,foo.bar,quz', out_mask.ToJsonString()) + + def testIntersect(self): + mask1 = field_mask_pb2.FieldMask() + mask2 = field_mask_pb2.FieldMask() + out_mask = field_mask_pb2.FieldMask() + # Test cases without overlapping. + mask1.FromJsonString('foo,baz') + mask2.FromJsonString('bar,quz') + out_mask.Intersect(mask1, mask2) + self.assertEqual('', out_mask.ToJsonString()) + # Overlap with duplicated paths. 
+ mask1.FromJsonString('foo,baz.bb') + mask2.FromJsonString('baz.bb,quz') + out_mask.Intersect(mask1, mask2) + self.assertEqual('baz.bb', out_mask.ToJsonString()) + # Overlap with paths covering some other paths. + mask1.FromJsonString('foo.bar.baz,quz') + mask2.FromJsonString('foo.bar,bar') + out_mask.Intersect(mask1, mask2) + self.assertEqual('foo.bar.baz', out_mask.ToJsonString()) + mask1.FromJsonString('foo.bar,bar') + mask2.FromJsonString('foo.bar.baz,quz') + out_mask.Intersect(mask1, mask2) + self.assertEqual('foo.bar.baz', out_mask.ToJsonString()) + + def testMergeMessage(self): + # Test merge one field. + src = unittest_pb2.TestAllTypes() + test_util.SetAllFields(src) + for field in src.DESCRIPTOR.fields: + if field.containing_oneof: + continue + field_name = field.name + dst = unittest_pb2.TestAllTypes() + # Only set one path to mask. + mask = field_mask_pb2.FieldMask() + mask.paths.append(field_name) + mask.MergeMessage(src, dst) + # The expected result message. + msg = unittest_pb2.TestAllTypes() + if field.label == descriptor.FieldDescriptor.LABEL_REPEATED: + repeated_src = getattr(src, field_name) + repeated_msg = getattr(msg, field_name) + if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: + for item in repeated_src: + repeated_msg.add().CopyFrom(item) + else: + repeated_msg.extend(repeated_src) + elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: + getattr(msg, field_name).CopyFrom(getattr(src, field_name)) + else: + setattr(msg, field_name, getattr(src, field_name)) + # Only field specified in mask is merged. + self.assertEqual(msg, dst) + + # Test merge nested fields. 
+ nested_src = unittest_pb2.NestedTestAllTypes() + nested_dst = unittest_pb2.NestedTestAllTypes() + nested_src.child.payload.optional_int32 = 1234 + nested_src.child.child.payload.optional_int32 = 5678 + mask = field_mask_pb2.FieldMask() + mask.FromJsonString('child.payload') + mask.MergeMessage(nested_src, nested_dst) + self.assertEqual(1234, nested_dst.child.payload.optional_int32) + self.assertEqual(0, nested_dst.child.child.payload.optional_int32) + + mask.FromJsonString('child.child.payload') + mask.MergeMessage(nested_src, nested_dst) + self.assertEqual(1234, nested_dst.child.payload.optional_int32) + self.assertEqual(5678, nested_dst.child.child.payload.optional_int32) + + nested_dst.Clear() + mask.FromJsonString('child.child.payload') + mask.MergeMessage(nested_src, nested_dst) + self.assertEqual(0, nested_dst.child.payload.optional_int32) + self.assertEqual(5678, nested_dst.child.child.payload.optional_int32) + + nested_dst.Clear() + mask.FromJsonString('child') + mask.MergeMessage(nested_src, nested_dst) + self.assertEqual(1234, nested_dst.child.payload.optional_int32) + self.assertEqual(5678, nested_dst.child.child.payload.optional_int32) + + # Test MergeOptions. + nested_dst.Clear() + nested_dst.child.payload.optional_int64 = 4321 + # Message fields will be merged by default. + mask.FromJsonString('child.payload') + mask.MergeMessage(nested_src, nested_dst) + self.assertEqual(1234, nested_dst.child.payload.optional_int32) + self.assertEqual(4321, nested_dst.child.payload.optional_int64) + # Change the behavior to replace message fields. + mask.FromJsonString('child.payload') + mask.MergeMessage(nested_src, nested_dst, True, False) + self.assertEqual(1234, nested_dst.child.payload.optional_int32) + self.assertEqual(0, nested_dst.child.payload.optional_int64) + + # By default, fields missing in source are not cleared in destination. 
+ nested_dst.payload.optional_int32 = 1234 + self.assertTrue(nested_dst.HasField('payload')) + mask.FromJsonString('payload') + mask.MergeMessage(nested_src, nested_dst) + self.assertTrue(nested_dst.HasField('payload')) + # But they are cleared when replacing message fields. + nested_dst.Clear() + nested_dst.payload.optional_int32 = 1234 + mask.FromJsonString('payload') + mask.MergeMessage(nested_src, nested_dst, True, False) + self.assertFalse(nested_dst.HasField('payload')) + + nested_src.payload.repeated_int32.append(1234) + nested_dst.payload.repeated_int32.append(5678) + # Repeated fields will be appended by default. + mask.FromJsonString('payload.repeated_int32') + mask.MergeMessage(nested_src, nested_dst) + self.assertEqual(2, len(nested_dst.payload.repeated_int32)) + self.assertEqual(5678, nested_dst.payload.repeated_int32[0]) + self.assertEqual(1234, nested_dst.payload.repeated_int32[1]) + # Change the behavior to replace repeated fields. + mask.FromJsonString('payload.repeated_int32') + mask.MergeMessage(nested_src, nested_dst, False, True) + self.assertEqual(1, len(nested_dst.payload.repeated_int32)) + self.assertEqual(1234, nested_dst.payload.repeated_int32[0]) + + +class StructTest(unittest.TestCase): + + def testStruct(self): + struct = struct_pb2.Struct() + struct_class = struct.__class__ + + struct['key1'] = 5 + struct['key2'] = 'abc' + struct['key3'] = True + struct.get_or_create_struct('key4')['subkey'] = 11.0 + struct_list = struct.get_or_create_list('key5') + struct_list.extend([6, 'seven', True, False, None]) + struct_list.add_struct()['subkey2'] = 9 + + self.assertTrue(isinstance(struct, well_known_types.Struct)) + self.assertEquals(5, struct['key1']) + self.assertEquals('abc', struct['key2']) + self.assertIs(True, struct['key3']) + self.assertEquals(11, struct['key4']['subkey']) + inner_struct = struct_class() + inner_struct['subkey2'] = 9 + self.assertEquals([6, 'seven', True, False, None, inner_struct], + list(struct['key5'].items())) + + 
serialized = struct.SerializeToString() + + struct2 = struct_pb2.Struct() + struct2.ParseFromString(serialized) + + self.assertEquals(struct, struct2) + + self.assertTrue(isinstance(struct2, well_known_types.Struct)) + self.assertEquals(5, struct2['key1']) + self.assertEquals('abc', struct2['key2']) + self.assertIs(True, struct2['key3']) + self.assertEquals(11, struct2['key4']['subkey']) + self.assertEquals([6, 'seven', True, False, None, inner_struct], + list(struct2['key5'].items())) + + struct_list = struct2['key5'] + self.assertEquals(6, struct_list[0]) + self.assertEquals('seven', struct_list[1]) + self.assertEquals(True, struct_list[2]) + self.assertEquals(False, struct_list[3]) + self.assertEquals(None, struct_list[4]) + self.assertEquals(inner_struct, struct_list[5]) + + struct_list[1] = 7 + self.assertEquals(7, struct_list[1]) + + struct_list.add_list().extend([1, 'two', True, False, None]) + self.assertEquals([1, 'two', True, False, None], + list(struct_list[6].items())) + + text_serialized = str(struct) + struct3 = struct_pb2.Struct() + text_format.Merge(text_serialized, struct3) + self.assertEquals(struct, struct3) + + struct.get_or_create_struct('key3')['replace'] = 12 + self.assertEquals(12, struct['key3']['replace']) + + +class AnyTest(unittest.TestCase): + + def testAnyMessage(self): + # Creates and sets message. + msg = any_test_pb2.TestAny() + msg_descriptor = msg.DESCRIPTOR + all_types = unittest_pb2.TestAllTypes() + all_descriptor = all_types.DESCRIPTOR + all_types.repeated_string.append(u'\u00fc\ua71f') + # Packs to Any. + msg.value.Pack(all_types) + self.assertEqual(msg.value.type_url, + 'type.googleapis.com/%s' % all_descriptor.full_name) + self.assertEqual(msg.value.value, + all_types.SerializeToString()) + # Tests Is() method. + self.assertTrue(msg.value.Is(all_descriptor)) + self.assertFalse(msg.value.Is(msg_descriptor)) + # Unpacks Any. 
+ unpacked_message = unittest_pb2.TestAllTypes() + self.assertTrue(msg.value.Unpack(unpacked_message)) + self.assertEqual(all_types, unpacked_message) + # Unpacks to different type. + self.assertFalse(msg.value.Unpack(msg)) + # Only Any messages have Pack method. + try: + msg.Pack(all_types) + except AttributeError: + pass + else: + raise AttributeError('%s should not have Pack method.' % + msg_descriptor.full_name) + + def testMessageName(self): + # Creates and sets message. + submessage = any_test_pb2.TestAny() + submessage.int_value = 12345 + msg = any_pb2.Any() + msg.Pack(submessage) + self.assertEqual(msg.TypeName(), 'google.protobuf.internal.TestAny') + + def testPackWithCustomTypeUrl(self): + submessage = any_test_pb2.TestAny() + submessage.int_value = 12345 + msg = any_pb2.Any() + # Pack with a custom type URL prefix. + msg.Pack(submessage, 'type.myservice.com') + self.assertEqual(msg.type_url, + 'type.myservice.com/%s' % submessage.DESCRIPTOR.full_name) + # Pack with a custom type URL prefix ending with '/'. + msg.Pack(submessage, 'type.myservice.com/') + self.assertEqual(msg.type_url, + 'type.myservice.com/%s' % submessage.DESCRIPTOR.full_name) + # Pack with an empty type URL prefix. + msg.Pack(submessage, '') + self.assertEqual(msg.type_url, + '/%s' % submessage.DESCRIPTOR.full_name) + # Test unpacking the type. + unpacked_message = any_test_pb2.TestAny() + self.assertTrue(msg.Unpack(unpacked_message)) + self.assertEqual(submessage, unpacked_message) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/google/protobuf/internal/wire_format_test.py b/python/google/protobuf/internal/wire_format_test.py index f39035cae..da120f336 100755 --- a/python/google/protobuf/internal/wire_format_test.py +++ b/python/google/protobuf/internal/wire_format_test.py @@ -1,4 +1,4 @@ -#! /usr/bin/python +#! /usr/bin/env python # # Protocol Buffers - Google's data interchange format # Copyright 2008 Google Inc. All rights reserved. 
@@ -34,12 +34,16 @@ __author__ = 'robinson@google.com (Will Robinson)' -from google.apputils import basetest +try: + import unittest2 as unittest #PY26 +except ImportError: + import unittest + from google.protobuf import message from google.protobuf.internal import wire_format -class WireFormatTest(basetest.TestCase): +class WireFormatTest(unittest.TestCase): def testPackTag(self): field_number = 0xabc @@ -250,4 +254,4 @@ class WireFormatTest(basetest.TestCase): if __name__ == '__main__': - basetest.main() + unittest.main() diff --git a/python/google/protobuf/json_format.py b/python/google/protobuf/json_format.py new file mode 100644 index 000000000..7921556e6 --- /dev/null +++ b/python/google/protobuf/json_format.py @@ -0,0 +1,645 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Contains routines for printing protocol messages in JSON format. + +Simple usage example: + + # Create a proto object and serialize it to a json format string. + message = my_proto_pb2.MyMessage(foo='bar') + json_string = json_format.MessageToJson(message) + + # Parse a json format string to proto object. + message = json_format.Parse(json_string, my_proto_pb2.MyMessage()) +""" + +__author__ = 'jieluo@google.com (Jie Luo)' + +import base64 +import json +import math +import six +import sys + +from google.protobuf import descriptor +from google.protobuf import symbol_database + +_TIMESTAMPFOMAT = '%Y-%m-%dT%H:%M:%S' +_INT_TYPES = frozenset([descriptor.FieldDescriptor.CPPTYPE_INT32, + descriptor.FieldDescriptor.CPPTYPE_UINT32, + descriptor.FieldDescriptor.CPPTYPE_INT64, + descriptor.FieldDescriptor.CPPTYPE_UINT64]) +_INT64_TYPES = frozenset([descriptor.FieldDescriptor.CPPTYPE_INT64, + descriptor.FieldDescriptor.CPPTYPE_UINT64]) +_FLOAT_TYPES = frozenset([descriptor.FieldDescriptor.CPPTYPE_FLOAT, + descriptor.FieldDescriptor.CPPTYPE_DOUBLE]) +_INFINITY = 'Infinity' +_NEG_INFINITY = '-Infinity' +_NAN = 'NaN' + + +class Error(Exception): + """Top-level module error for json_format.""" + + +class SerializeToJsonError(Error): + """Thrown if serialization to JSON fails.""" + + +class ParseError(Error): + """Thrown in case of parsing error.""" + + +def MessageToJson(message, including_default_value_fields=False): + """Converts protobuf message 
to JSON format. + + Args: + message: The protocol buffers message instance to serialize. + including_default_value_fields: If True, singular primitive fields, + repeated fields, and map fields will always be serialized. If + False, only serialize non-empty fields. Singular message fields + and oneof fields are not affected by this option. + + Returns: + A string containing the JSON formatted protocol buffer message. + """ + js = _MessageToJsonObject(message, including_default_value_fields) + return json.dumps(js, indent=2) + + +def _MessageToJsonObject(message, including_default_value_fields): + """Converts message to an object according to Proto3 JSON Specification.""" + message_descriptor = message.DESCRIPTOR + full_name = message_descriptor.full_name + if _IsWrapperMessage(message_descriptor): + return _WrapperMessageToJsonObject(message) + if full_name in _WKTJSONMETHODS: + return _WKTJSONMETHODS[full_name][0]( + message, including_default_value_fields) + js = {} + return _RegularMessageToJsonObject( + message, js, including_default_value_fields) + + +def _IsMapEntry(field): + return (field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and + field.message_type.has_options and + field.message_type.GetOptions().map_entry) + + +def _RegularMessageToJsonObject(message, js, including_default_value_fields): + """Converts normal message according to Proto3 JSON Specification.""" + fields = message.ListFields() + include_default = including_default_value_fields + + try: + for field, value in fields: + name = field.camelcase_name + if _IsMapEntry(field): + # Convert a map field. 
+        # Skip the field which has been serialized already.
+ continue + if _IsMapEntry(field): + js[name] = {} + elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED: + js[name] = [] + else: + js[name] = _FieldToJsonObject(field, field.default_value) + + except ValueError as e: + raise SerializeToJsonError( + 'Failed to serialize {0} field: {1}.'.format(field.name, e)) + + return js + + +def _FieldToJsonObject( + field, value, including_default_value_fields=False): + """Converts field value according to Proto3 JSON Specification.""" + if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: + return _MessageToJsonObject(value, including_default_value_fields) + elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_ENUM: + enum_value = field.enum_type.values_by_number.get(value, None) + if enum_value is not None: + return enum_value.name + else: + raise SerializeToJsonError('Enum field contains an integer value ' + 'which can not mapped to an enum value.') + elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_STRING: + if field.type == descriptor.FieldDescriptor.TYPE_BYTES: + # Use base64 Data encoding for bytes + return base64.b64encode(value).decode('utf-8') + else: + return value + elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_BOOL: + return bool(value) + elif field.cpp_type in _INT64_TYPES: + return str(value) + elif field.cpp_type in _FLOAT_TYPES: + if math.isinf(value): + if value < 0.0: + return _NEG_INFINITY + else: + return _INFINITY + if math.isnan(value): + return _NAN + return value + + +def _AnyMessageToJsonObject(message, including_default): + """Converts Any message according to Proto3 JSON Specification.""" + if not message.ListFields(): + return {} + js = {} + type_url = message.type_url + js['@type'] = type_url + sub_message = _CreateMessageFromTypeUrl(type_url) + sub_message.ParseFromString(message.value) + message_descriptor = sub_message.DESCRIPTOR + full_name = message_descriptor.full_name + if _IsWrapperMessage(message_descriptor): + js['value'] = 
+  # Duration, Timestamp and FieldMask have ToJsonString method to do the
+  # conversion. Users can also call the method directly.
+    message: A protocol buffer message to merge into.
+
+  Returns:
+    The same message passed as argument.
+
+  Raises:
+    ParseError: On JSON parsing problems.
+ """ + if not isinstance(text, six.text_type): text = text.decode('utf-8') + try: + if sys.version_info < (2, 7): + # object_pair_hook is not supported before python2.7 + js = json.loads(text) + else: + js = json.loads(text, object_pairs_hook=_DuplicateChecker) + except ValueError as e: + raise ParseError('Failed to load JSON: {0}.'.format(str(e))) + _ConvertMessage(js, message) + return message + + +def _ConvertFieldValuePair(js, message): + """Convert field value pairs into regular message. + + Args: + js: A JSON object to convert the field value pairs. + message: A regular protocol message to record the data. + + Raises: + ParseError: In case of problems converting. + """ + names = [] + message_descriptor = message.DESCRIPTOR + for name in js: + try: + field = message_descriptor.fields_by_camelcase_name.get(name, None) + if not field: + raise ParseError( + 'Message type "{0}" has no field named "{1}".'.format( + message_descriptor.full_name, name)) + if name in names: + raise ParseError( + 'Message type "{0}" should not have multiple "{1}" fields.'.format( + message.DESCRIPTOR.full_name, name)) + names.append(name) + # Check no other oneof field is parsed. + if field.containing_oneof is not None: + oneof_name = field.containing_oneof.name + if oneof_name in names: + raise ParseError('Message type "{0}" should not have multiple "{1}" ' + 'oneof fields.'.format( + message.DESCRIPTOR.full_name, oneof_name)) + names.append(oneof_name) + + value = js[name] + if value is None: + message.ClearField(field.name) + continue + + # Parse field value. + if _IsMapEntry(field): + message.ClearField(field.name) + _ConvertMapFieldValue(value, message, field) + elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED: + message.ClearField(field.name) + if not isinstance(value, list): + raise ParseError('repeated field {0} must be in [] which is ' + '{1}.'.format(name, value)) + if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: + # Repeated message field. 
+ for item in value: + sub_message = getattr(message, field.name).add() + # None is a null_value in Value. + if (item is None and + sub_message.DESCRIPTOR.full_name != 'google.protobuf.Value'): + raise ParseError('null is not allowed to be used as an element' + ' in a repeated field.') + _ConvertMessage(item, sub_message) + else: + # Repeated scalar field. + for item in value: + if item is None: + raise ParseError('null is not allowed to be used as an element' + ' in a repeated field.') + getattr(message, field.name).append( + _ConvertScalarFieldValue(item, field)) + elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: + sub_message = getattr(message, field.name) + _ConvertMessage(value, sub_message) + else: + setattr(message, field.name, _ConvertScalarFieldValue(value, field)) + except ParseError as e: + if field and field.containing_oneof is None: + raise ParseError('Failed to parse {0} field: {1}'.format(name, e)) + else: + raise ParseError(str(e)) + except ValueError as e: + raise ParseError('Failed to parse {0} field: {1}.'.format(name, e)) + except TypeError as e: + raise ParseError('Failed to parse {0} field: {1}.'.format(name, e)) + + +def _ConvertMessage(value, message): + """Convert a JSON object into a message. + + Args: + value: A JSON object. + message: A WKT or regular protocol message to record the data. + + Raises: + ParseError: In case of convert problems. 
+ """ + message_descriptor = message.DESCRIPTOR + full_name = message_descriptor.full_name + if _IsWrapperMessage(message_descriptor): + _ConvertWrapperMessage(value, message) + elif full_name in _WKTJSONMETHODS: + _WKTJSONMETHODS[full_name][1](value, message) + else: + _ConvertFieldValuePair(value, message) + + +def _ConvertAnyMessage(value, message): + """Convert a JSON representation into Any message.""" + if isinstance(value, dict) and not value: + return + try: + type_url = value['@type'] + except KeyError: + raise ParseError('@type is missing when parsing any message.') + + sub_message = _CreateMessageFromTypeUrl(type_url) + message_descriptor = sub_message.DESCRIPTOR + full_name = message_descriptor.full_name + if _IsWrapperMessage(message_descriptor): + _ConvertWrapperMessage(value['value'], sub_message) + elif full_name in _WKTJSONMETHODS: + _WKTJSONMETHODS[full_name][1](value['value'], sub_message) + else: + del value['@type'] + _ConvertFieldValuePair(value, sub_message) + # Sets Any message + message.value = sub_message.SerializeToString() + message.type_url = type_url + + +def _ConvertGenericMessage(value, message): + """Convert a JSON representation into message with FromJsonString.""" + # Durantion, Timestamp, FieldMask have FromJsonString method to do the + # convert. Users can also call the method directly. 
+ message.FromJsonString(value) + + +_INT_OR_FLOAT = six.integer_types + (float,) + + +def _ConvertValueMessage(value, message): + """Convert a JSON representation into Value message.""" + if isinstance(value, dict): + _ConvertStructMessage(value, message.struct_value) + elif isinstance(value, list): + _ConvertListValueMessage(value, message.list_value) + elif value is None: + message.null_value = 0 + elif isinstance(value, bool): + message.bool_value = value + elif isinstance(value, six.string_types): + message.string_value = value + elif isinstance(value, _INT_OR_FLOAT): + message.number_value = value + else: + raise ParseError('Unexpected type for Value message.') + + +def _ConvertListValueMessage(value, message): + """Convert a JSON representation into ListValue message.""" + if not isinstance(value, list): + raise ParseError( + 'ListValue must be in [] which is {0}.'.format(value)) + message.ClearField('values') + for item in value: + _ConvertValueMessage(item, message.values.add()) + + +def _ConvertStructMessage(value, message): + """Convert a JSON representation into Struct message.""" + if not isinstance(value, dict): + raise ParseError( + 'Struct must be in a dict which is {0}.'.format(value)) + for key in value: + _ConvertValueMessage(value[key], message.fields[key]) + return + + +def _ConvertWrapperMessage(value, message): + """Convert a JSON representation into Wrapper message.""" + field = message.DESCRIPTOR.fields_by_name['value'] + setattr(message, 'value', _ConvertScalarFieldValue(value, field)) + + +def _ConvertMapFieldValue(value, message, field): + """Convert map field value for a message map field. + + Args: + value: A JSON object to convert the map field value. + message: A protocol message to record the converted data. + field: The descriptor of the map field to be converted. + + Raises: + ParseError: In case of convert problems. 
+ """ + if not isinstance(value, dict): + raise ParseError( + 'Map field {0} must be in a dict which is {1}.'.format( + field.name, value)) + key_field = field.message_type.fields_by_name['key'] + value_field = field.message_type.fields_by_name['value'] + for key in value: + key_value = _ConvertScalarFieldValue(key, key_field, True) + if value_field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: + _ConvertMessage(value[key], getattr(message, field.name)[key_value]) + else: + getattr(message, field.name)[key_value] = _ConvertScalarFieldValue( + value[key], value_field) + + +def _ConvertScalarFieldValue(value, field, require_str=False): + """Convert a single scalar field value. + + Args: + value: A scalar value to convert the scalar field value. + field: The descriptor of the field to convert. + require_str: If True, the field value must be a str. + + Returns: + The converted scalar field value + + Raises: + ParseError: In case of convert problems. + """ + if field.cpp_type in _INT_TYPES: + return _ConvertInteger(value) + elif field.cpp_type in _FLOAT_TYPES: + return _ConvertFloat(value) + elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_BOOL: + return _ConvertBool(value, require_str) + elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_STRING: + if field.type == descriptor.FieldDescriptor.TYPE_BYTES: + return base64.b64decode(value) + else: + return value + elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_ENUM: + # Convert an enum value. + enum_value = field.enum_type.values_by_name.get(value, None) + if enum_value is None: + raise ParseError( + 'Enum value must be a string literal with double quotes. ' + 'Type "{0}" has no value named {1}.'.format( + field.enum_type.full_name, value)) + return enum_value.number + + +def _ConvertInteger(value): + """Convert an integer. + + Args: + value: A scalar value to convert. + + Returns: + The integer value. + + Raises: + ParseError: If an integer couldn't be consumed. 
+ """ + if isinstance(value, float): + raise ParseError('Couldn\'t parse integer: {0}.'.format(value)) + + if isinstance(value, six.text_type) and value.find(' ') != -1: + raise ParseError('Couldn\'t parse integer: "{0}".'.format(value)) + + return int(value) + + +def _ConvertFloat(value): + """Convert an floating point number.""" + if value == 'nan': + raise ParseError('Couldn\'t parse float "nan", use "NaN" instead.') + try: + # Assume Python compatible syntax. + return float(value) + except ValueError: + # Check alternative spellings. + if value == _NEG_INFINITY: + return float('-inf') + elif value == _INFINITY: + return float('inf') + elif value == _NAN: + return float('nan') + else: + raise ParseError('Couldn\'t parse float: {0}.'.format(value)) + + +def _ConvertBool(value, require_str): + """Convert a boolean value. + + Args: + value: A scalar value to convert. + require_str: If True, value must be a str. + + Returns: + The bool parsed. + + Raises: + ParseError: If a boolean value couldn't be consumed. 
+ """ + if require_str: + if value == 'true': + return True + elif value == 'false': + return False + else: + raise ParseError('Expected "true" or "false", not {0}.'.format(value)) + + if not isinstance(value, bool): + raise ParseError('Expected true or false without quotes.') + return value + +_WKTJSONMETHODS = { + 'google.protobuf.Any': [_AnyMessageToJsonObject, + _ConvertAnyMessage], + 'google.protobuf.Duration': [_GenericMessageToJsonObject, + _ConvertGenericMessage], + 'google.protobuf.FieldMask': [_GenericMessageToJsonObject, + _ConvertGenericMessage], + 'google.protobuf.ListValue': [_ListValueMessageToJsonObject, + _ConvertListValueMessage], + 'google.protobuf.Struct': [_StructMessageToJsonObject, + _ConvertStructMessage], + 'google.protobuf.Timestamp': [_GenericMessageToJsonObject, + _ConvertGenericMessage], + 'google.protobuf.Value': [_ValueMessageToJsonObject, + _ConvertValueMessage] +} diff --git a/python/google/protobuf/message.py b/python/google/protobuf/message.py index c186452a7..606f735f3 100755 --- a/python/google/protobuf/message.py +++ b/python/google/protobuf/message.py @@ -36,7 +36,6 @@ __author__ = 'robinson@google.com (Will Robinson)' - class Error(Exception): pass class DecodeError(Error): pass class EncodeError(Error): pass @@ -233,12 +232,21 @@ class Message(object): raise NotImplementedError def HasField(self, field_name): - """Checks if a certain field is set for the message. Note if the - field_name is not defined in the message descriptor, ValueError will be - raised.""" + """Checks if a certain field is set for the message, or if any field inside + a oneof group is set. Note that if the field_name is not defined in the + message descriptor, ValueError will be raised.""" raise NotImplementedError def ClearField(self, field_name): + """Clears the contents of a given field, or the field set inside a oneof + group. 
If the name neither refers to a defined field or oneof group, + ValueError is raised.""" + raise NotImplementedError + + def WhichOneof(self, oneof_group): + """Returns the name of the field that is set inside a oneof group, or + None if no field is set. If no group with the given name exists, ValueError + will be raised.""" raise NotImplementedError def HasExtension(self, extension_handle): @@ -247,6 +255,9 @@ class Message(object): def ClearExtension(self, extension_handle): raise NotImplementedError + def DiscardUnknownFields(self): + raise NotImplementedError + def ByteSize(self): """Returns the serialized size of this message. Recursively calls ByteSize() on all contained messages. diff --git a/python/google/protobuf/message_factory.py b/python/google/protobuf/message_factory.py index 7fd7bec0b..1b059d130 100644 --- a/python/google/protobuf/message_factory.py +++ b/python/google/protobuf/message_factory.py @@ -28,10 +28,6 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -#PY25 compatible for GAE. -# -# Copyright 2012 Google Inc. All Rights Reserved. - """Provides a factory class for generating dynamic messages. 
The easiest way to use this class is if you have access to the FileDescriptor @@ -43,8 +39,6 @@ my_proto_instance = message_classes['some.proto.package.MessageName']() __author__ = 'matthewtoia@google.com (Matt Toia)' -import sys ##PY25 -from google.protobuf import descriptor_database from google.protobuf import descriptor_pool from google.protobuf import message from google.protobuf import reflection @@ -55,8 +49,7 @@ class MessageFactory(object): def __init__(self, pool=None): """Initializes a new factory.""" - self.pool = (pool or descriptor_pool.DescriptorPool( - descriptor_database.DescriptorDatabase())) + self.pool = pool or descriptor_pool.DescriptorPool() # local cache of all classes built from protobuf descriptors self._classes = {} @@ -75,8 +68,7 @@ class MessageFactory(object): """ if descriptor.full_name not in self._classes: descriptor_name = descriptor.name - if sys.version_info[0] < 3: ##PY25 -##!PY25 if str is bytes: # PY2 + if str is bytes: # PY2 descriptor_name = descriptor.name.encode('ascii', 'ignore') result_class = reflection.GeneratedProtocolMessageType( descriptor_name, @@ -111,7 +103,7 @@ class MessageFactory(object): result = {} for file_name in files: file_desc = self.pool.FindFileByName(file_name) - for name, msg in file_desc.message_types_by_name.iteritems(): + for name, msg in file_desc.message_types_by_name.items(): if file_desc.package: full_name = '.'.join([file_desc.package, name]) else: @@ -128,7 +120,7 @@ class MessageFactory(object): # ignore the registration if the original was the same, or raise # an error if they were different. 
- for name, extension in file_desc.extensions_by_name.iteritems(): + for name, extension in file_desc.extensions_by_name.items(): if extension.containing_type.full_name not in self._classes: self.GetPrototype(extension.containing_type) extended_class = self._classes[extension.containing_type.full_name] diff --git a/python/google/protobuf/proto_builder.py b/python/google/protobuf/proto_builder.py new file mode 100644 index 000000000..736caed38 --- /dev/null +++ b/python/google/protobuf/proto_builder.py @@ -0,0 +1,130 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Dynamic Protobuf class creator.""" + +try: + from collections import OrderedDict +except ImportError: + from ordereddict import OrderedDict #PY26 +import hashlib +import os + +from google.protobuf import descriptor_pb2 +from google.protobuf import message_factory + + +def _GetMessageFromFactory(factory, full_name): + """Get a proto class from the MessageFactory by name. + + Args: + factory: a MessageFactory instance. + full_name: str, the fully qualified name of the proto type. + Returns: + A class, for the type identified by full_name. + Raises: + KeyError, if the proto is not found in the factory's descriptor pool. + """ + proto_descriptor = factory.pool.FindMessageTypeByName(full_name) + proto_cls = factory.GetPrototype(proto_descriptor) + return proto_cls + + +def MakeSimpleProtoClass(fields, full_name=None, pool=None): + """Create a Protobuf class whose fields are basic types. + + Note: this doesn't validate field names! + + Args: + fields: dict of {name: field_type} mappings for each field in the proto. If + this is an OrderedDict the order will be maintained, otherwise the + fields will be sorted by name. + full_name: optional str, the fully-qualified name of the proto type. + pool: optional DescriptorPool instance. + Returns: + a class, the new protobuf class with a FileDescriptor. 
+ """ + factory = message_factory.MessageFactory(pool=pool) + + if full_name is not None: + try: + proto_cls = _GetMessageFromFactory(factory, full_name) + return proto_cls + except KeyError: + # The factory's DescriptorPool doesn't know about this class yet. + pass + + # Get a list of (name, field_type) tuples from the fields dict. If fields was + # an OrderedDict we keep the order, but otherwise we sort the field to ensure + # consistent ordering. + field_items = fields.items() + if not isinstance(fields, OrderedDict): + field_items = sorted(field_items) + + # Use a consistent file name that is unlikely to conflict with any imported + # proto files. + fields_hash = hashlib.sha1() + for f_name, f_type in field_items: + fields_hash.update(f_name.encode('utf-8')) + fields_hash.update(str(f_type).encode('utf-8')) + proto_file_name = fields_hash.hexdigest() + '.proto' + + # If the proto is anonymous, use the same hash to name it. + if full_name is None: + full_name = ('net.proto2.python.public.proto_builder.AnonymousProto_' + + fields_hash.hexdigest()) + try: + proto_cls = _GetMessageFromFactory(factory, full_name) + return proto_cls + except KeyError: + # The factory's DescriptorPool doesn't know about this class yet. + pass + + # This is the first time we see this proto: add a new descriptor to the pool. 
+ factory.pool.Add( + _MakeFileDescriptorProto(proto_file_name, full_name, field_items)) + return _GetMessageFromFactory(factory, full_name) + + +def _MakeFileDescriptorProto(proto_file_name, full_name, field_items): + """Populate FileDescriptorProto for MessageFactory's DescriptorPool.""" + package, name = full_name.rsplit('.', 1) + file_proto = descriptor_pb2.FileDescriptorProto() + file_proto.name = os.path.join(package.replace('.', '/'), proto_file_name) + file_proto.package = package + desc_proto = file_proto.message_type.add() + desc_proto.name = name + for f_number, (f_name, f_type) in enumerate(field_items, 1): + field_proto = desc_proto.field.add() + field_proto.name = f_name + field_proto.number = f_number + field_proto.label = descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL + field_proto.type = f_type + return file_proto diff --git a/python/google/protobuf/pyext/__init__.py b/python/google/protobuf/pyext/__init__.py index e69de29bb..558561412 100644 --- a/python/google/protobuf/pyext/__init__.py +++ b/python/google/protobuf/pyext/__init__.py @@ -0,0 +1,4 @@ +try: + __import__('pkg_resources').declare_namespace(__name__) +except ImportError: + __path__ = __import__('pkgutil').extend_path(__path__, __name__) diff --git a/python/google/protobuf/pyext/cpp_message.py b/python/google/protobuf/pyext/cpp_message.py index dcf34a023..b215211ee 100644 --- a/python/google/protobuf/pyext/cpp_message.py +++ b/python/google/protobuf/pyext/cpp_message.py @@ -37,25 +37,29 @@ Descriptor objects at runtime backed by the protocol buffer C++ API. __author__ = 'tibell@google.com (Johan Tibell)' from google.protobuf.pyext import _message -from google.protobuf import message -def NewMessage(bases, message_descriptor, dictionary): - """Creates a new protocol message *class*.""" - new_bases = [] - for base in bases: - if base is message.Message: - # _message.Message must come before message.Message as it - # overrides methods in that class. 
- new_bases.append(_message.Message) - new_bases.append(base) - return tuple(new_bases) +class GeneratedProtocolMessageType(_message.MessageMeta): + """Metaclass for protocol message classes created at runtime from Descriptors. -def InitMessage(message_descriptor, cls): - """Constructs a new message instance (called before instance's __init__).""" + The protocol compiler currently uses this metaclass to create protocol + message classes at runtime. Clients can also manually create their own + classes at runtime, as in this example: - def SubInit(self, **kwargs): - super(cls, self).__init__(message_descriptor, **kwargs) - cls.__init__ = SubInit - cls.AddDescriptors(message_descriptor) + mydescriptor = Descriptor(.....) + class MyProtoClass(Message): + __metaclass__ = GeneratedProtocolMessageType + DESCRIPTOR = mydescriptor + myproto_instance = MyProtoClass() + myproto.foo_field = 23 + ... + + The above example will not work for nested types. If you wish to include them, + use reflection.MakeClass() instead of manually instantiating the class in + order to create the appropriate class structure. + """ + + # Must be consistent with the protocol-compiler code in + # proto2/compiler/internal/generator.*. 
+ _DESCRIPTOR_KEY = 'DESCRIPTOR' diff --git a/python/google/protobuf/pyext/descriptor.cc b/python/google/protobuf/pyext/descriptor.cc index 3f7be73cb..235575389 100644 --- a/python/google/protobuf/pyext/descriptor.cc +++ b/python/google/protobuf/pyext/descriptor.cc @@ -31,29 +31,119 @@ // Author: petar@google.com (Petar Petrov) #include <Python.h> +#include <frameobject.h> #include <string> +#include <google/protobuf/io/coded_stream.h> #include <google/protobuf/descriptor.pb.h> +#include <google/protobuf/dynamic_message.h> #include <google/protobuf/pyext/descriptor.h> +#include <google/protobuf/pyext/descriptor_containers.h> +#include <google/protobuf/pyext/descriptor_pool.h> +#include <google/protobuf/pyext/message.h> #include <google/protobuf/pyext/scoped_pyobject_ptr.h> -#define C(str) const_cast<char*>(str) - #if PY_MAJOR_VERSION >= 3 #define PyString_FromStringAndSize PyUnicode_FromStringAndSize + #define PyString_Check PyUnicode_Check + #define PyString_InternFromString PyUnicode_InternFromString #define PyInt_FromLong PyLong_FromLong + #define PyInt_FromSize_t PyLong_FromSize_t #if PY_VERSION_HEX < 0x03030000 #error "Python 3.0 - 3.2 are not supported." - #else - #define PyString_AsString(ob) \ - (PyUnicode_Check(ob)? PyUnicode_AsUTF8(ob): PyBytes_AS_STRING(ob)) #endif + #define PyString_AsStringAndSize(ob, charpp, sizep) \ + (PyUnicode_Check(ob)? \ + ((*(charpp) = PyUnicode_AsUTF8AndSize(ob, (sizep))) == NULL? -1: 0): \ + PyBytes_AsStringAndSize(ob, (charpp), (sizep))) #endif namespace google { namespace protobuf { namespace python { +// Store interned descriptors, so that the same C++ descriptor yields the same +// Python object. Objects are not immortal: this map does not own the +// references, and items are deleted when the last reference to the object is +// released. +// This is enough to support the "is" operator on live objects. +// All descriptors are stored here. 
+hash_map<const void*, PyObject*> interned_descriptors; + +PyObject* PyString_FromCppString(const string& str) { + return PyString_FromStringAndSize(str.c_str(), str.size()); +} + +// Check that the calling Python code is the global scope of a _pb2.py module. +// This function is used to support the current code generated by the proto +// compiler, which creates descriptors, then update some properties. +// For example: +// message_descriptor = Descriptor( +// name='Message', +// fields = [FieldDescriptor(name='field')] +// message_descriptor.fields[0].containing_type = message_descriptor +// +// This code is still executed, but the descriptors now have no other storage +// than the (const) C++ pointer, and are immutable. +// So we let this code pass, by simply ignoring the new value. +// +// From user code, descriptors still look immutable. +// +// TODO(amauryfa): Change the proto2 compiler to remove the assignments, and +// remove this hack. +bool _CalledFromGeneratedFile(int stacklevel) { +#ifndef PYPY_VERSION + // This check is not critical and is somewhat difficult to implement correctly + // in PyPy. + PyFrameObject* frame = PyEval_GetFrame(); + if (frame == NULL) { + return false; + } + while (stacklevel-- > 0) { + frame = frame->f_back; + if (frame == NULL) { + return false; + } + } + if (frame->f_globals != frame->f_locals) { + // Not at global module scope + return false; + } + + if (frame->f_code->co_filename == NULL) { + return false; + } + char* filename; + Py_ssize_t filename_size; + if (PyString_AsStringAndSize(frame->f_code->co_filename, + &filename, &filename_size) < 0) { + // filename is not a string. + PyErr_Clear(); + return false; + } + if (filename_size < 7) { + // filename is too short. + return false; + } + if (strcmp(&filename[filename_size - 7], "_pb2.py") != 0) { + // Filename is not ending with _pb2. + return false; + } +#endif + return true; +} + +// If the calling code is not a _pb2.py file, raise AttributeError. 
+// To be used in attribute setters. +static int CheckCalledFromGeneratedFile(const char* attr_name) { + if (_CalledFromGeneratedFile(0)) { + return 0; + } + PyErr_Format(PyExc_AttributeError, + "attribute is not writable: %s", attr_name); + return -1; +} + #ifndef PyVarObject_HEAD_INIT #define PyVarObject_HEAD_INIT(type, size) PyObject_HEAD_INIT(type) size, @@ -63,58 +153,484 @@ namespace python { #endif -static google::protobuf::DescriptorPool* g_descriptor_pool = NULL; +// Helper functions for descriptor objects. + +// A set of templates to retrieve the C++ FileDescriptor of any descriptor. +template<class DescriptorClass> +const FileDescriptor* GetFileDescriptor(const DescriptorClass* descriptor) { + return descriptor->file(); +} +template<> +const FileDescriptor* GetFileDescriptor(const FileDescriptor* descriptor) { + return descriptor; +} +template<> +const FileDescriptor* GetFileDescriptor(const EnumValueDescriptor* descriptor) { + return descriptor->type()->file(); +} +template<> +const FileDescriptor* GetFileDescriptor(const OneofDescriptor* descriptor) { + return descriptor->containing_type()->file(); +} + +// Converts options into a Python protobuf, and cache the result. +// +// This is a bit tricky because options can contain extension fields defined in +// the same proto file. In this case the options parsed from the serialized_pb +// have unkown fields, and we need to parse them again. +// +// Always returns a new reference. +template<class DescriptorClass> +static PyObject* GetOrBuildOptions(const DescriptorClass *descriptor) { + // Options (and their extensions) are completely resolved in the proto file + // containing the descriptor. + PyDescriptorPool* pool = GetDescriptorPool_FromPool( + GetFileDescriptor(descriptor)->pool()); + + hash_map<const void*, PyObject*>* descriptor_options = + pool->descriptor_options; + // First search in the cache. 
+ if (descriptor_options->find(descriptor) != descriptor_options->end()) { + PyObject *value = (*descriptor_options)[descriptor]; + Py_INCREF(value); + return value; + } + + // Build the Options object: get its Python class, and make a copy of the C++ + // read-only instance. + const Message& options(descriptor->options()); + const Descriptor *message_type = options.GetDescriptor(); + CMessageClass* message_class( + cdescriptor_pool::GetMessageClass(pool, message_type)); + if (message_class == NULL) { + // The Options message was not found in the current DescriptorPool. + // In this case, there cannot be extensions to these options, and we can + // try to use the basic pool instead. + PyErr_Clear(); + message_class = cdescriptor_pool::GetMessageClass( + GetDefaultDescriptorPool(), message_type); + } + if (message_class == NULL) { + PyErr_Format(PyExc_TypeError, "Could not retrieve class for Options: %s", + message_type->full_name().c_str()); + return NULL; + } + ScopedPyObjectPtr value( + PyEval_CallObject(message_class->AsPyObject(), NULL)); + if (value == NULL) { + return NULL; + } + if (!PyObject_TypeCheck(value.get(), &CMessage_Type)) { + PyErr_Format(PyExc_TypeError, "Invalid class for %s: %s", + message_type->full_name().c_str(), + Py_TYPE(value.get())->tp_name); + return NULL; + } + CMessage* cmsg = reinterpret_cast<CMessage*>(value.get()); + + const Reflection* reflection = options.GetReflection(); + const UnknownFieldSet& unknown_fields(reflection->GetUnknownFields(options)); + if (unknown_fields.empty()) { + cmsg->message->CopyFrom(options); + } else { + // Reparse options string! 
XXX call cmessage::MergeFromString + string serialized; + options.SerializeToString(&serialized); + io::CodedInputStream input( + reinterpret_cast<const uint8*>(serialized.c_str()), serialized.size()); + input.SetExtensionRegistry(pool->pool, pool->message_factory); + bool success = cmsg->message->MergePartialFromCodedStream(&input); + if (!success) { + PyErr_Format(PyExc_ValueError, "Error parsing Options message"); + return NULL; + } + } + + // Cache the result. + Py_INCREF(value.get()); + (*pool->descriptor_options)[descriptor] = value.get(); + + return value.release(); +} + +// Copy the C++ descriptor to a Python message. +// The Python message is an instance of descriptor_pb2.DescriptorProto +// or similar. +template<class DescriptorProtoClass, class DescriptorClass> +static PyObject* CopyToPythonProto(const DescriptorClass *descriptor, + PyObject *target) { + const Descriptor* self_descriptor = + DescriptorProtoClass::default_instance().GetDescriptor(); + CMessage* message = reinterpret_cast<CMessage*>(target); + if (!PyObject_TypeCheck(target, &CMessage_Type) || + message->message->GetDescriptor() != self_descriptor) { + PyErr_Format(PyExc_TypeError, "Not a %s message", + self_descriptor->full_name().c_str()); + return NULL; + } + cmessage::AssureWritable(message); + DescriptorProtoClass* descriptor_message = + static_cast<DescriptorProtoClass*>(message->message); + descriptor->CopyTo(descriptor_message); + Py_RETURN_NONE; +} + +// All Descriptors classes share the same memory layout. +typedef struct PyBaseDescriptor { + PyObject_HEAD + + // Pointer to the C++ proto2 descriptor. + // Like all descriptors, it is owned by the global DescriptorPool. + const void* descriptor; + + // Owned reference to the DescriptorPool, to ensure it is kept alive. + PyDescriptorPool* pool; +} PyBaseDescriptor; + + +// FileDescriptor structure "inherits" from the base descriptor. 
+typedef struct PyFileDescriptor { + PyBaseDescriptor base; + + // The cached version of serialized pb. Either NULL, or a Bytes string. + // We own the reference. + PyObject *serialized_pb; +} PyFileDescriptor; -namespace cfield_descriptor { -static void Dealloc(CFieldDescriptor* self) { - Py_CLEAR(self->descriptor_field); +namespace descriptor { + +// Creates or retrieve a Python descriptor of the specified type. +// Objects are interned: the same descriptor will return the same object if it +// was kept alive. +// 'was_created' is an optional pointer to a bool, and is set to true if a new +// object was allocated. +// Always return a new reference. +template<class DescriptorClass> +PyObject* NewInternedDescriptor(PyTypeObject* type, + const DescriptorClass* descriptor, + bool* was_created) { + if (was_created) { + *was_created = false; + } + if (descriptor == NULL) { + PyErr_BadInternalCall(); + return NULL; + } + + // See if the object is in the map of interned descriptors + hash_map<const void*, PyObject*>::iterator it = + interned_descriptors.find(descriptor); + if (it != interned_descriptors.end()) { + GOOGLE_DCHECK(Py_TYPE(it->second) == type); + Py_INCREF(it->second); + return it->second; + } + // Create a new descriptor object + PyBaseDescriptor* py_descriptor = PyObject_New( + PyBaseDescriptor, type); + if (py_descriptor == NULL) { + return NULL; + } + py_descriptor->descriptor = descriptor; + + // and cache it. + interned_descriptors.insert( + std::make_pair(descriptor, reinterpret_cast<PyObject*>(py_descriptor))); + + // Ensures that the DescriptorPool stays alive. + PyDescriptorPool* pool = GetDescriptorPool_FromPool( + GetFileDescriptor(descriptor)->pool()); + if (pool == NULL) { + // Don't DECREF, the object is not fully initialized. 
+ PyObject_Del(py_descriptor); + return NULL; + } + Py_INCREF(pool); + py_descriptor->pool = pool; + + if (was_created) { + *was_created = true; + } + return reinterpret_cast<PyObject*>(py_descriptor); +} + +static void Dealloc(PyBaseDescriptor* self) { + // Remove from interned dictionary + interned_descriptors.erase(self->descriptor); + Py_CLEAR(self->pool); Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self)); } -static PyObject* GetFullName(CFieldDescriptor* self, void *closure) { - return PyString_FromStringAndSize( - self->descriptor->full_name().c_str(), - self->descriptor->full_name().size()); +static PyGetSetDef Getters[] = { + {NULL} +}; + +PyTypeObject PyBaseDescriptor_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + FULL_MODULE_NAME ".DescriptorBase", // tp_name + sizeof(PyBaseDescriptor), // tp_basicsize + 0, // tp_itemsize + (destructor)Dealloc, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + 0, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + "Descriptors base class", // tp_doc + 0, // tp_traverse + 0, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + 0, // tp_iter + 0, // tp_iternext + 0, // tp_methods + 0, // tp_members + Getters, // tp_getset +}; + +} // namespace descriptor + +const void* PyDescriptor_AsVoidPtr(PyObject* obj) { + if (!PyObject_TypeCheck(obj, &descriptor::PyBaseDescriptor_Type)) { + PyErr_SetString(PyExc_TypeError, "Not a BaseDescriptor"); + return NULL; + } + return reinterpret_cast<PyBaseDescriptor*>(obj)->descriptor; } -static PyObject* GetName(CFieldDescriptor *self, void *closure) { - return PyString_FromStringAndSize( - self->descriptor->name().c_str(), - self->descriptor->name().size()); +namespace message_descriptor { + +// Unchecked accessor to the C++ pointer. 
+static const Descriptor* _GetDescriptor(PyBaseDescriptor* self) { + return reinterpret_cast<const Descriptor*>(self->descriptor); } -static PyObject* GetCppType(CFieldDescriptor *self, void *closure) { - return PyInt_FromLong(self->descriptor->cpp_type()); +static PyObject* GetName(PyBaseDescriptor* self, void *closure) { + return PyString_FromCppString(_GetDescriptor(self)->name()); } -static PyObject* GetLabel(CFieldDescriptor *self, void *closure) { - return PyInt_FromLong(self->descriptor->label()); +static PyObject* GetFullName(PyBaseDescriptor* self, void *closure) { + return PyString_FromCppString(_GetDescriptor(self)->full_name()); } -static PyObject* GetID(CFieldDescriptor *self, void *closure) { - return PyLong_FromVoidPtr(self); +static PyObject* GetFile(PyBaseDescriptor *self, void *closure) { + return PyFileDescriptor_FromDescriptor(_GetDescriptor(self)->file()); +} + +static PyObject* GetConcreteClass(PyBaseDescriptor* self, void *closure) { + // Retuns the canonical class for the given descriptor. + // This is the class that was registered with the primary descriptor pool + // which contains this descriptor. + // This might not be the one you expect! For example the returned object does + // not know about extensions defined in a custom pool. 
+ CMessageClass* concrete_class(cdescriptor_pool::GetMessageClass( + GetDescriptorPool_FromPool(_GetDescriptor(self)->file()->pool()), + _GetDescriptor(self))); + Py_XINCREF(concrete_class); + return concrete_class->AsPyObject(); +} + +static PyObject* GetFieldsByName(PyBaseDescriptor* self, void *closure) { + return NewMessageFieldsByName(_GetDescriptor(self)); +} + +static PyObject* GetFieldsByCamelcaseName(PyBaseDescriptor* self, + void *closure) { + return NewMessageFieldsByCamelcaseName(_GetDescriptor(self)); +} + +static PyObject* GetFieldsByNumber(PyBaseDescriptor* self, void *closure) { + return NewMessageFieldsByNumber(_GetDescriptor(self)); +} + +static PyObject* GetFieldsSeq(PyBaseDescriptor* self, void *closure) { + return NewMessageFieldsSeq(_GetDescriptor(self)); +} + +static PyObject* GetNestedTypesByName(PyBaseDescriptor* self, void *closure) { + return NewMessageNestedTypesByName(_GetDescriptor(self)); +} + +static PyObject* GetNestedTypesSeq(PyBaseDescriptor* self, void *closure) { + return NewMessageNestedTypesSeq(_GetDescriptor(self)); +} + +static PyObject* GetExtensionsByName(PyBaseDescriptor* self, void *closure) { + return NewMessageExtensionsByName(_GetDescriptor(self)); +} + +static PyObject* GetExtensions(PyBaseDescriptor* self, void *closure) { + return NewMessageExtensionsSeq(_GetDescriptor(self)); +} + +static PyObject* GetEnumsSeq(PyBaseDescriptor* self, void *closure) { + return NewMessageEnumsSeq(_GetDescriptor(self)); +} + +static PyObject* GetEnumTypesByName(PyBaseDescriptor* self, void *closure) { + return NewMessageEnumsByName(_GetDescriptor(self)); +} + +static PyObject* GetEnumValuesByName(PyBaseDescriptor* self, void *closure) { + return NewMessageEnumValuesByName(_GetDescriptor(self)); +} + +static PyObject* GetOneofsByName(PyBaseDescriptor* self, void *closure) { + return NewMessageOneofsByName(_GetDescriptor(self)); +} + +static PyObject* GetOneofsSeq(PyBaseDescriptor* self, void *closure) { + return 
NewMessageOneofsSeq(_GetDescriptor(self)); +} + +static PyObject* IsExtendable(PyBaseDescriptor *self, void *closure) { + if (_GetDescriptor(self)->extension_range_count() > 0) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } +} + +static PyObject* GetExtensionRanges(PyBaseDescriptor *self, void *closure) { + const Descriptor* descriptor = _GetDescriptor(self); + PyObject* range_list = PyList_New(descriptor->extension_range_count()); + + for (int i = 0; i < descriptor->extension_range_count(); i++) { + const Descriptor::ExtensionRange* range = descriptor->extension_range(i); + PyObject* start = PyInt_FromLong(range->start); + PyObject* end = PyInt_FromLong(range->end); + PyList_SetItem(range_list, i, PyTuple_Pack(2, start, end)); + } + + return range_list; +} + +static PyObject* GetContainingType(PyBaseDescriptor *self, void *closure) { + const Descriptor* containing_type = + _GetDescriptor(self)->containing_type(); + if (containing_type) { + return PyMessageDescriptor_FromDescriptor(containing_type); + } else { + Py_RETURN_NONE; + } +} + +static int SetContainingType(PyBaseDescriptor *self, PyObject *value, + void *closure) { + return CheckCalledFromGeneratedFile("containing_type"); +} + +static PyObject* GetHasOptions(PyBaseDescriptor *self, void *closure) { + const MessageOptions& options(_GetDescriptor(self)->options()); + if (&options != &MessageOptions::default_instance()) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } +} +static int SetHasOptions(PyBaseDescriptor *self, PyObject *value, + void *closure) { + return CheckCalledFromGeneratedFile("has_options"); +} + +static PyObject* GetOptions(PyBaseDescriptor *self) { + return GetOrBuildOptions(_GetDescriptor(self)); +} + +static int SetOptions(PyBaseDescriptor *self, PyObject *value, + void *closure) { + return CheckCalledFromGeneratedFile("_options"); +} + +static PyObject* CopyToProto(PyBaseDescriptor *self, PyObject *target) { + return CopyToPythonProto<DescriptorProto>(_GetDescriptor(self), 
target); +} + +static PyObject* EnumValueName(PyBaseDescriptor *self, PyObject *args) { + const char *enum_name; + int number; + if (!PyArg_ParseTuple(args, "si", &enum_name, &number)) + return NULL; + const EnumDescriptor *enum_type = + _GetDescriptor(self)->FindEnumTypeByName(enum_name); + if (enum_type == NULL) { + PyErr_SetString(PyExc_KeyError, enum_name); + return NULL; + } + const EnumValueDescriptor *enum_value = + enum_type->FindValueByNumber(number); + if (enum_value == NULL) { + PyErr_Format(PyExc_KeyError, "%d", number); + return NULL; + } + return PyString_FromCppString(enum_value->name()); +} + +static PyObject* GetSyntax(PyBaseDescriptor *self, void *closure) { + return PyString_InternFromString( + FileDescriptor::SyntaxName(_GetDescriptor(self)->file()->syntax())); } static PyGetSetDef Getters[] = { - { C("full_name"), (getter)GetFullName, NULL, "Full name", NULL}, - { C("name"), (getter)GetName, NULL, "last name", NULL}, - { C("cpp_type"), (getter)GetCppType, NULL, "C++ Type", NULL}, - { C("label"), (getter)GetLabel, NULL, "Label", NULL}, - { C("id"), (getter)GetID, NULL, "ID", NULL}, + { "name", (getter)GetName, NULL, "Last name"}, + { "full_name", (getter)GetFullName, NULL, "Full name"}, + { "_concrete_class", (getter)GetConcreteClass, NULL, "concrete class"}, + { "file", (getter)GetFile, NULL, "File descriptor"}, + + { "fields", (getter)GetFieldsSeq, NULL, "Fields sequence"}, + { "fields_by_name", (getter)GetFieldsByName, NULL, "Fields by name"}, + { "fields_by_camelcase_name", (getter)GetFieldsByCamelcaseName, NULL, + "Fields by camelCase name"}, + { "fields_by_number", (getter)GetFieldsByNumber, NULL, "Fields by number"}, + { "nested_types", (getter)GetNestedTypesSeq, NULL, "Nested types sequence"}, + { "nested_types_by_name", (getter)GetNestedTypesByName, NULL, + "Nested types by name"}, + { "extensions", (getter)GetExtensions, NULL, "Extensions Sequence"}, + { "extensions_by_name", (getter)GetExtensionsByName, NULL, + "Extensions by name"}, 
+ { "extension_ranges", (getter)GetExtensionRanges, NULL, "Extension ranges"}, + { "enum_types", (getter)GetEnumsSeq, NULL, "Enum sequence"}, + { "enum_types_by_name", (getter)GetEnumTypesByName, NULL, + "Enum types by name"}, + { "enum_values_by_name", (getter)GetEnumValuesByName, NULL, + "Enum values by name"}, + { "oneofs_by_name", (getter)GetOneofsByName, NULL, "Oneofs by name"}, + { "oneofs", (getter)GetOneofsSeq, NULL, "Oneofs by name"}, + { "containing_type", (getter)GetContainingType, (setter)SetContainingType, + "Containing type"}, + { "is_extendable", (getter)IsExtendable, (setter)NULL}, + { "has_options", (getter)GetHasOptions, (setter)SetHasOptions, "Has Options"}, + { "_options", (getter)NULL, (setter)SetOptions, "Options"}, + { "syntax", (getter)GetSyntax, (setter)NULL, "Syntax"}, {NULL} }; -} // namespace cfield_descriptor +static PyMethodDef Methods[] = { + { "GetOptions", (PyCFunction)GetOptions, METH_NOARGS, }, + { "CopyToProto", (PyCFunction)CopyToProto, METH_O, }, + { "EnumValueName", (PyCFunction)EnumValueName, METH_VARARGS, }, + {NULL} +}; + +} // namespace message_descriptor -PyTypeObject CFieldDescriptor_Type = { +PyTypeObject PyMessageDescriptor_Type = { PyVarObject_HEAD_INIT(&PyType_Type, 0) - C("google.protobuf.internal." - "_net_proto2___python." 
- "CFieldDescriptor"), // tp_name - sizeof(CFieldDescriptor), // tp_basicsize + FULL_MODULE_NAME ".MessageDescriptor", // tp_name + sizeof(PyBaseDescriptor), // tp_basicsize 0, // tp_itemsize - (destructor)cfield_descriptor::Dealloc, // tp_dealloc + 0, // tp_dealloc 0, // tp_print 0, // tp_getattr 0, // tp_setattr @@ -130,223 +646,934 @@ PyTypeObject CFieldDescriptor_Type = { 0, // tp_setattro 0, // tp_as_buffer Py_TPFLAGS_DEFAULT, // tp_flags - C("A Field Descriptor"), // tp_doc + "A Message Descriptor", // tp_doc 0, // tp_traverse 0, // tp_clear 0, // tp_richcompare 0, // tp_weaklistoffset 0, // tp_iter 0, // tp_iternext - 0, // tp_methods + message_descriptor::Methods, // tp_methods 0, // tp_members - cfield_descriptor::Getters, // tp_getset - 0, // tp_base - 0, // tp_dict - 0, // tp_descr_get - 0, // tp_descr_set - 0, // tp_dictoffset - 0, // tp_init - PyType_GenericAlloc, // tp_alloc - PyType_GenericNew, // tp_new - PyObject_Del, // tp_free + message_descriptor::Getters, // tp_getset + &descriptor::PyBaseDescriptor_Type, // tp_base }; -namespace cdescriptor_pool { - -static void Dealloc(CDescriptorPool* self) { - Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self)); +PyObject* PyMessageDescriptor_FromDescriptor( + const Descriptor* message_descriptor) { + return descriptor::NewInternedDescriptor( + &PyMessageDescriptor_Type, message_descriptor, NULL); } -static PyObject* NewCDescriptor( - const google::protobuf::FieldDescriptor* field_descriptor) { - CFieldDescriptor* cfield_descriptor = PyObject_New( - CFieldDescriptor, &CFieldDescriptor_Type); - if (cfield_descriptor == NULL) { +const Descriptor* PyMessageDescriptor_AsDescriptor(PyObject* obj) { + if (!PyObject_TypeCheck(obj, &PyMessageDescriptor_Type)) { + PyErr_SetString(PyExc_TypeError, "Not a MessageDescriptor"); return NULL; } - cfield_descriptor->descriptor = field_descriptor; - cfield_descriptor->descriptor_field = NULL; + return reinterpret_cast<const Descriptor*>( + 
reinterpret_cast<PyBaseDescriptor*>(obj)->descriptor); +} + +namespace field_descriptor { - return reinterpret_cast<PyObject*>(cfield_descriptor); +// Unchecked accessor to the C++ pointer. +static const FieldDescriptor* _GetDescriptor( + PyBaseDescriptor *self) { + return reinterpret_cast<const FieldDescriptor*>(self->descriptor); } -PyObject* FindFieldByName(CDescriptorPool* self, PyObject* name) { - const char* full_field_name = PyString_AsString(name); - if (full_field_name == NULL) { - return NULL; +static PyObject* GetFullName(PyBaseDescriptor* self, void *closure) { + return PyString_FromCppString(_GetDescriptor(self)->full_name()); +} + +static PyObject* GetName(PyBaseDescriptor *self, void *closure) { + return PyString_FromCppString(_GetDescriptor(self)->name()); +} + +static PyObject* GetCamelcaseName(PyBaseDescriptor* self, void *closure) { + return PyString_FromCppString(_GetDescriptor(self)->camelcase_name()); +} + +static PyObject* GetType(PyBaseDescriptor *self, void *closure) { + return PyInt_FromLong(_GetDescriptor(self)->type()); +} + +static PyObject* GetCppType(PyBaseDescriptor *self, void *closure) { + return PyInt_FromLong(_GetDescriptor(self)->cpp_type()); +} + +static PyObject* GetLabel(PyBaseDescriptor *self, void *closure) { + return PyInt_FromLong(_GetDescriptor(self)->label()); +} + +static PyObject* GetNumber(PyBaseDescriptor *self, void *closure) { + return PyInt_FromLong(_GetDescriptor(self)->number()); +} + +static PyObject* GetIndex(PyBaseDescriptor *self, void *closure) { + return PyInt_FromLong(_GetDescriptor(self)->index()); +} + +static PyObject* GetID(PyBaseDescriptor *self, void *closure) { + return PyLong_FromVoidPtr(self); +} + +static PyObject* IsExtension(PyBaseDescriptor *self, void *closure) { + return PyBool_FromLong(_GetDescriptor(self)->is_extension()); +} + +static PyObject* HasDefaultValue(PyBaseDescriptor *self, void *closure) { + return PyBool_FromLong(_GetDescriptor(self)->has_default_value()); +} + +static 
PyObject* GetDefaultValue(PyBaseDescriptor *self, void *closure) { + PyObject *result; + + switch (_GetDescriptor(self)->cpp_type()) { + case FieldDescriptor::CPPTYPE_INT32: { + int32 value = _GetDescriptor(self)->default_value_int32(); + result = PyInt_FromLong(value); + break; + } + case FieldDescriptor::CPPTYPE_INT64: { + int64 value = _GetDescriptor(self)->default_value_int64(); + result = PyLong_FromLongLong(value); + break; + } + case FieldDescriptor::CPPTYPE_UINT32: { + uint32 value = _GetDescriptor(self)->default_value_uint32(); + result = PyInt_FromSize_t(value); + break; + } + case FieldDescriptor::CPPTYPE_UINT64: { + uint64 value = _GetDescriptor(self)->default_value_uint64(); + result = PyLong_FromUnsignedLongLong(value); + break; + } + case FieldDescriptor::CPPTYPE_FLOAT: { + float value = _GetDescriptor(self)->default_value_float(); + result = PyFloat_FromDouble(value); + break; + } + case FieldDescriptor::CPPTYPE_DOUBLE: { + double value = _GetDescriptor(self)->default_value_double(); + result = PyFloat_FromDouble(value); + break; + } + case FieldDescriptor::CPPTYPE_BOOL: { + bool value = _GetDescriptor(self)->default_value_bool(); + result = PyBool_FromLong(value); + break; + } + case FieldDescriptor::CPPTYPE_STRING: { + string value = _GetDescriptor(self)->default_value_string(); + result = ToStringObject(_GetDescriptor(self), value); + break; + } + case FieldDescriptor::CPPTYPE_ENUM: { + const EnumValueDescriptor* value = + _GetDescriptor(self)->default_value_enum(); + result = PyInt_FromLong(value->number()); + break; + } + default: + PyErr_Format(PyExc_NotImplementedError, "default value for %s", + _GetDescriptor(self)->full_name().c_str()); + return NULL; + } + return result; +} + +static PyObject* GetCDescriptor(PyObject *self, void *closure) { + Py_INCREF(self); + return self; +} + +static PyObject *GetEnumType(PyBaseDescriptor *self, void *closure) { + const EnumDescriptor* enum_type = _GetDescriptor(self)->enum_type(); + if (enum_type) { + 
return PyEnumDescriptor_FromDescriptor(enum_type); + } else { + Py_RETURN_NONE; } +} - const google::protobuf::FieldDescriptor* field_descriptor = NULL; +static int SetEnumType(PyBaseDescriptor *self, PyObject *value, void *closure) { + return CheckCalledFromGeneratedFile("enum_type"); +} - field_descriptor = self->pool->FindFieldByName(full_field_name); +static PyObject *GetMessageType(PyBaseDescriptor *self, void *closure) { + const Descriptor* message_type = _GetDescriptor(self)->message_type(); + if (message_type) { + return PyMessageDescriptor_FromDescriptor(message_type); + } else { + Py_RETURN_NONE; + } +} - if (field_descriptor == NULL) { - PyErr_Format(PyExc_TypeError, "Couldn't find field %.200s", - full_field_name); - return NULL; +static int SetMessageType(PyBaseDescriptor *self, PyObject *value, + void *closure) { + return CheckCalledFromGeneratedFile("message_type"); +} + +static PyObject* GetContainingType(PyBaseDescriptor *self, void *closure) { + const Descriptor* containing_type = + _GetDescriptor(self)->containing_type(); + if (containing_type) { + return PyMessageDescriptor_FromDescriptor(containing_type); + } else { + Py_RETURN_NONE; } +} - return NewCDescriptor(field_descriptor); +static int SetContainingType(PyBaseDescriptor *self, PyObject *value, + void *closure) { + return CheckCalledFromGeneratedFile("containing_type"); } -PyObject* FindExtensionByName(CDescriptorPool* self, PyObject* arg) { - const char* full_field_name = PyString_AsString(arg); - if (full_field_name == NULL) { - return NULL; +static PyObject* GetExtensionScope(PyBaseDescriptor *self, void *closure) { + const Descriptor* extension_scope = + _GetDescriptor(self)->extension_scope(); + if (extension_scope) { + return PyMessageDescriptor_FromDescriptor(extension_scope); + } else { + Py_RETURN_NONE; } +} + +static PyObject* GetContainingOneof(PyBaseDescriptor *self, void *closure) { + const OneofDescriptor* containing_oneof = + _GetDescriptor(self)->containing_oneof(); + if 
(containing_oneof) { + return PyOneofDescriptor_FromDescriptor(containing_oneof); + } else { + Py_RETURN_NONE; + } +} + +static int SetContainingOneof(PyBaseDescriptor *self, PyObject *value, + void *closure) { + return CheckCalledFromGeneratedFile("containing_oneof"); +} + +static PyObject* GetHasOptions(PyBaseDescriptor *self, void *closure) { + const FieldOptions& options(_GetDescriptor(self)->options()); + if (&options != &FieldOptions::default_instance()) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } +} +static int SetHasOptions(PyBaseDescriptor *self, PyObject *value, + void *closure) { + return CheckCalledFromGeneratedFile("has_options"); +} + +static PyObject* GetOptions(PyBaseDescriptor *self) { + return GetOrBuildOptions(_GetDescriptor(self)); +} + +static int SetOptions(PyBaseDescriptor *self, PyObject *value, + void *closure) { + return CheckCalledFromGeneratedFile("_options"); +} + + +static PyGetSetDef Getters[] = { + { "full_name", (getter)GetFullName, NULL, "Full name"}, + { "name", (getter)GetName, NULL, "Unqualified name"}, + { "camelcase_name", (getter)GetCamelcaseName, NULL, "Camelcase name"}, + { "type", (getter)GetType, NULL, "C++ Type"}, + { "cpp_type", (getter)GetCppType, NULL, "C++ Type"}, + { "label", (getter)GetLabel, NULL, "Label"}, + { "number", (getter)GetNumber, NULL, "Number"}, + { "index", (getter)GetIndex, NULL, "Index"}, + { "default_value", (getter)GetDefaultValue, NULL, "Default Value"}, + { "has_default_value", (getter)HasDefaultValue}, + { "is_extension", (getter)IsExtension, NULL, "ID"}, + { "id", (getter)GetID, NULL, "ID"}, + { "_cdescriptor", (getter)GetCDescriptor, NULL, "HAACK REMOVE ME"}, + + { "message_type", (getter)GetMessageType, (setter)SetMessageType, + "Message type"}, + { "enum_type", (getter)GetEnumType, (setter)SetEnumType, "Enum type"}, + { "containing_type", (getter)GetContainingType, (setter)SetContainingType, + "Containing type"}, + { "extension_scope", (getter)GetExtensionScope, (setter)NULL, + 
"Extension scope"}, + { "containing_oneof", (getter)GetContainingOneof, (setter)SetContainingOneof, + "Containing oneof"}, + { "has_options", (getter)GetHasOptions, (setter)SetHasOptions, "Has Options"}, + { "_options", (getter)NULL, (setter)SetOptions, "Options"}, + {NULL} +}; + +static PyMethodDef Methods[] = { + { "GetOptions", (PyCFunction)GetOptions, METH_NOARGS, }, + {NULL} +}; + +} // namespace field_descriptor + +PyTypeObject PyFieldDescriptor_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + FULL_MODULE_NAME ".FieldDescriptor", // tp_name + sizeof(PyBaseDescriptor), // tp_basicsize + 0, // tp_itemsize + 0, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + 0, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + "A Field Descriptor", // tp_doc + 0, // tp_traverse + 0, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + 0, // tp_iter + 0, // tp_iternext + field_descriptor::Methods, // tp_methods + 0, // tp_members + field_descriptor::Getters, // tp_getset + &descriptor::PyBaseDescriptor_Type, // tp_base +}; + +PyObject* PyFieldDescriptor_FromDescriptor( + const FieldDescriptor* field_descriptor) { + return descriptor::NewInternedDescriptor( + &PyFieldDescriptor_Type, field_descriptor, NULL); +} - const google::protobuf::FieldDescriptor* field_descriptor = - self->pool->FindExtensionByName(full_field_name); - if (field_descriptor == NULL) { - PyErr_Format(PyExc_TypeError, "Couldn't find field %.200s", - full_field_name); +const FieldDescriptor* PyFieldDescriptor_AsDescriptor(PyObject* obj) { + if (!PyObject_TypeCheck(obj, &PyFieldDescriptor_Type)) { + PyErr_SetString(PyExc_TypeError, "Not a FieldDescriptor"); return NULL; } + return reinterpret_cast<const FieldDescriptor*>( + reinterpret_cast<PyBaseDescriptor*>(obj)->descriptor); +} + 
+namespace enum_descriptor { + +// Unchecked accessor to the C++ pointer. +static const EnumDescriptor* _GetDescriptor( + PyBaseDescriptor *self) { + return reinterpret_cast<const EnumDescriptor*>(self->descriptor); +} + +static PyObject* GetFullName(PyBaseDescriptor* self, void *closure) { + return PyString_FromCppString(_GetDescriptor(self)->full_name()); +} + +static PyObject* GetName(PyBaseDescriptor *self, void *closure) { + return PyString_FromCppString(_GetDescriptor(self)->name()); +} + +static PyObject* GetFile(PyBaseDescriptor *self, void *closure) { + return PyFileDescriptor_FromDescriptor(_GetDescriptor(self)->file()); +} + +static PyObject* GetEnumvaluesByName(PyBaseDescriptor* self, void *closure) { + return NewEnumValuesByName(_GetDescriptor(self)); +} + +static PyObject* GetEnumvaluesByNumber(PyBaseDescriptor* self, void *closure) { + return NewEnumValuesByNumber(_GetDescriptor(self)); +} + +static PyObject* GetEnumvaluesSeq(PyBaseDescriptor* self, void *closure) { + return NewEnumValuesSeq(_GetDescriptor(self)); +} + +static PyObject* GetContainingType(PyBaseDescriptor *self, void *closure) { + const Descriptor* containing_type = + _GetDescriptor(self)->containing_type(); + if (containing_type) { + return PyMessageDescriptor_FromDescriptor(containing_type); + } else { + Py_RETURN_NONE; + } +} + +static int SetContainingType(PyBaseDescriptor *self, PyObject *value, + void *closure) { + return CheckCalledFromGeneratedFile("containing_type"); +} + + +static PyObject* GetHasOptions(PyBaseDescriptor *self, void *closure) { + const EnumOptions& options(_GetDescriptor(self)->options()); + if (&options != &EnumOptions::default_instance()) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } +} +static int SetHasOptions(PyBaseDescriptor *self, PyObject *value, + void *closure) { + return CheckCalledFromGeneratedFile("has_options"); +} + +static PyObject* GetOptions(PyBaseDescriptor *self) { + return GetOrBuildOptions(_GetDescriptor(self)); +} + +static int 
SetOptions(PyBaseDescriptor *self, PyObject *value, + void *closure) { + return CheckCalledFromGeneratedFile("_options"); +} - return NewCDescriptor(field_descriptor); +static PyObject* CopyToProto(PyBaseDescriptor *self, PyObject *target) { + return CopyToPythonProto<EnumDescriptorProto>(_GetDescriptor(self), target); } static PyMethodDef Methods[] = { - { C("FindFieldByName"), - (PyCFunction)FindFieldByName, - METH_O, - C("Searches for a field descriptor by full name.") }, - { C("FindExtensionByName"), - (PyCFunction)FindExtensionByName, - METH_O, - C("Searches for extension descriptor by full name.") }, + { "GetOptions", (PyCFunction)GetOptions, METH_NOARGS, }, + { "CopyToProto", (PyCFunction)CopyToProto, METH_O, }, {NULL} }; -} // namespace cdescriptor_pool +static PyGetSetDef Getters[] = { + { "full_name", (getter)GetFullName, NULL, "Full name"}, + { "name", (getter)GetName, NULL, "last name"}, + { "file", (getter)GetFile, NULL, "File descriptor"}, + { "values", (getter)GetEnumvaluesSeq, NULL, "values"}, + { "values_by_name", (getter)GetEnumvaluesByName, NULL, + "Enum values by name"}, + { "values_by_number", (getter)GetEnumvaluesByNumber, NULL, + "Enum values by number"}, + + { "containing_type", (getter)GetContainingType, (setter)SetContainingType, + "Containing type"}, + { "has_options", (getter)GetHasOptions, (setter)SetHasOptions, "Has Options"}, + { "_options", (getter)NULL, (setter)SetOptions, "Options"}, + {NULL} +}; + +} // namespace enum_descriptor -PyTypeObject CDescriptorPool_Type = { +PyTypeObject PyEnumDescriptor_Type = { PyVarObject_HEAD_INIT(&PyType_Type, 0) - C("google.protobuf.internal." - "_net_proto2___python." 
- "CFieldDescriptor"), // tp_name - sizeof(CDescriptorPool), // tp_basicsize - 0, // tp_itemsize - (destructor)cdescriptor_pool::Dealloc, // tp_dealloc - 0, // tp_print - 0, // tp_getattr - 0, // tp_setattr - 0, // tp_compare - 0, // tp_repr - 0, // tp_as_number - 0, // tp_as_sequence - 0, // tp_as_mapping - 0, // tp_hash - 0, // tp_call - 0, // tp_str - 0, // tp_getattro - 0, // tp_setattro - 0, // tp_as_buffer - Py_TPFLAGS_DEFAULT, // tp_flags - C("A Descriptor Pool"), // tp_doc - 0, // tp_traverse - 0, // tp_clear - 0, // tp_richcompare - 0, // tp_weaklistoffset - 0, // tp_iter - 0, // tp_iternext - cdescriptor_pool::Methods, // tp_methods - 0, // tp_members - 0, // tp_getset - 0, // tp_base - 0, // tp_dict - 0, // tp_descr_get - 0, // tp_descr_set - 0, // tp_dictoffset - 0, // tp_init - PyType_GenericAlloc, // tp_alloc - PyType_GenericNew, // tp_new - PyObject_Del, // tp_free + FULL_MODULE_NAME ".EnumDescriptor", // tp_name + sizeof(PyBaseDescriptor), // tp_basicsize + 0, // tp_itemsize + 0, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + 0, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + "A Enum Descriptor", // tp_doc + 0, // tp_traverse + 0, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + 0, // tp_iter + 0, // tp_iternext + enum_descriptor::Methods, // tp_getset + 0, // tp_members + enum_descriptor::Getters, // tp_getset + &descriptor::PyBaseDescriptor_Type, // tp_base }; -google::protobuf::DescriptorPool* GetDescriptorPool() { - if (g_descriptor_pool == NULL) { - g_descriptor_pool = new google::protobuf::DescriptorPool( - google::protobuf::DescriptorPool::generated_pool()); +PyObject* PyEnumDescriptor_FromDescriptor( + const EnumDescriptor* enum_descriptor) { + return descriptor::NewInternedDescriptor( + &PyEnumDescriptor_Type, 
enum_descriptor, NULL); +} + +const EnumDescriptor* PyEnumDescriptor_AsDescriptor(PyObject* obj) { + if (!PyObject_TypeCheck(obj, &PyEnumDescriptor_Type)) { + PyErr_SetString(PyExc_TypeError, "Not an EnumDescriptor"); + return NULL; + } + return reinterpret_cast<const EnumDescriptor*>( + reinterpret_cast<PyBaseDescriptor*>(obj)->descriptor); +} + +namespace enumvalue_descriptor { + +// Unchecked accessor to the C++ pointer. +static const EnumValueDescriptor* _GetDescriptor( + PyBaseDescriptor *self) { + return reinterpret_cast<const EnumValueDescriptor*>(self->descriptor); +} + +static PyObject* GetName(PyBaseDescriptor *self, void *closure) { + return PyString_FromCppString(_GetDescriptor(self)->name()); +} + +static PyObject* GetNumber(PyBaseDescriptor *self, void *closure) { + return PyInt_FromLong(_GetDescriptor(self)->number()); +} + +static PyObject* GetIndex(PyBaseDescriptor *self, void *closure) { + return PyInt_FromLong(_GetDescriptor(self)->index()); +} + +static PyObject* GetType(PyBaseDescriptor *self, void *closure) { + return PyEnumDescriptor_FromDescriptor(_GetDescriptor(self)->type()); +} + +static PyObject* GetHasOptions(PyBaseDescriptor *self, void *closure) { + const EnumValueOptions& options(_GetDescriptor(self)->options()); + if (&options != &EnumValueOptions::default_instance()) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; } - return g_descriptor_pool; +} +static int SetHasOptions(PyBaseDescriptor *self, PyObject *value, + void *closure) { + return CheckCalledFromGeneratedFile("has_options"); +} + +static PyObject* GetOptions(PyBaseDescriptor *self) { + return GetOrBuildOptions(_GetDescriptor(self)); +} + +static int SetOptions(PyBaseDescriptor *self, PyObject *value, + void *closure) { + return CheckCalledFromGeneratedFile("_options"); +} + + +static PyGetSetDef Getters[] = { + { "name", (getter)GetName, NULL, "name"}, + { "number", (getter)GetNumber, NULL, "number"}, + { "index", (getter)GetIndex, NULL, "index"}, + { "type", 
(getter)GetType, NULL, "index"}, + + { "has_options", (getter)GetHasOptions, (setter)SetHasOptions, "Has Options"}, + { "_options", (getter)NULL, (setter)SetOptions, "Options"}, + {NULL} +}; + +static PyMethodDef Methods[] = { + { "GetOptions", (PyCFunction)GetOptions, METH_NOARGS, }, + {NULL} +}; + +} // namespace enumvalue_descriptor + +PyTypeObject PyEnumValueDescriptor_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + FULL_MODULE_NAME ".EnumValueDescriptor", // tp_name + sizeof(PyBaseDescriptor), // tp_basicsize + 0, // tp_itemsize + 0, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + 0, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + "A EnumValue Descriptor", // tp_doc + 0, // tp_traverse + 0, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + 0, // tp_iter + 0, // tp_iternext + enumvalue_descriptor::Methods, // tp_methods + 0, // tp_members + enumvalue_descriptor::Getters, // tp_getset + &descriptor::PyBaseDescriptor_Type, // tp_base +}; + +PyObject* PyEnumValueDescriptor_FromDescriptor( + const EnumValueDescriptor* enumvalue_descriptor) { + return descriptor::NewInternedDescriptor( + &PyEnumValueDescriptor_Type, enumvalue_descriptor, NULL); +} + +namespace file_descriptor { + +// Unchecked accessor to the C++ pointer. 
+static const FileDescriptor* _GetDescriptor(PyFileDescriptor *self) { + return reinterpret_cast<const FileDescriptor*>(self->base.descriptor); +} + +static void Dealloc(PyFileDescriptor* self) { + Py_XDECREF(self->serialized_pb); + descriptor::Dealloc(&self->base); +} + +static PyObject* GetPool(PyFileDescriptor *self, void *closure) { + PyObject* pool = reinterpret_cast<PyObject*>( + GetDescriptorPool_FromPool(_GetDescriptor(self)->pool())); + Py_XINCREF(pool); + return pool; +} + +static PyObject* GetName(PyFileDescriptor *self, void *closure) { + return PyString_FromCppString(_GetDescriptor(self)->name()); } -PyObject* Python_NewCDescriptorPool(PyObject* ignored, PyObject* args) { - CDescriptorPool* cdescriptor_pool = PyObject_New( - CDescriptorPool, &CDescriptorPool_Type); - if (cdescriptor_pool == NULL) { +static PyObject* GetPackage(PyFileDescriptor *self, void *closure) { + return PyString_FromCppString(_GetDescriptor(self)->package()); +} + +static PyObject* GetSerializedPb(PyFileDescriptor *self, void *closure) { + PyObject *serialized_pb = self->serialized_pb; + if (serialized_pb != NULL) { + Py_INCREF(serialized_pb); + return serialized_pb; + } + FileDescriptorProto file_proto; + _GetDescriptor(self)->CopyTo(&file_proto); + string contents; + file_proto.SerializePartialToString(&contents); + self->serialized_pb = PyBytes_FromStringAndSize( + contents.c_str(), contents.size()); + if (self->serialized_pb == NULL) { return NULL; } - cdescriptor_pool->pool = GetDescriptorPool(); - return reinterpret_cast<PyObject*>(cdescriptor_pool); + Py_INCREF(self->serialized_pb); + return self->serialized_pb; } +static PyObject* GetMessageTypesByName(PyFileDescriptor* self, void *closure) { + return NewFileMessageTypesByName(_GetDescriptor(self)); +} -// Collects errors that occur during proto file building to allow them to be -// propagated in the python exception instead of only living in ERROR logs. 
-class BuildFileErrorCollector : public google::protobuf::DescriptorPool::ErrorCollector { - public: - BuildFileErrorCollector() : error_message(""), had_errors(false) {} +static PyObject* GetEnumTypesByName(PyFileDescriptor* self, void *closure) { + return NewFileEnumTypesByName(_GetDescriptor(self)); +} - void AddError(const string& filename, const string& element_name, - const Message* descriptor, ErrorLocation location, - const string& message) { - // Replicates the logging behavior that happens in the C++ implementation - // when an error collector is not passed in. - if (!had_errors) { - error_message += - ("Invalid proto descriptor for file \"" + filename + "\":\n"); - } - // As this only happens on failure and will result in the program not - // running at all, no effort is made to optimize this string manipulation. - error_message += (" " + element_name + ": " + message + "\n"); +static PyObject* GetExtensionsByName(PyFileDescriptor* self, void *closure) { + return NewFileExtensionsByName(_GetDescriptor(self)); +} + +static PyObject* GetDependencies(PyFileDescriptor* self, void *closure) { + return NewFileDependencies(_GetDescriptor(self)); +} + +static PyObject* GetPublicDependencies(PyFileDescriptor* self, void *closure) { + return NewFilePublicDependencies(_GetDescriptor(self)); +} + +static PyObject* GetHasOptions(PyFileDescriptor *self, void *closure) { + const FileOptions& options(_GetDescriptor(self)->options()); + if (&options != &FileOptions::default_instance()) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; } +} +static int SetHasOptions(PyFileDescriptor *self, PyObject *value, + void *closure) { + return CheckCalledFromGeneratedFile("has_options"); +} + +static PyObject* GetOptions(PyFileDescriptor *self) { + return GetOrBuildOptions(_GetDescriptor(self)); +} + +static int SetOptions(PyFileDescriptor *self, PyObject *value, + void *closure) { + return CheckCalledFromGeneratedFile("_options"); +} - string error_message; - bool had_errors; 
+static PyObject* GetSyntax(PyFileDescriptor *self, void *closure) { + return PyString_InternFromString( + FileDescriptor::SyntaxName(_GetDescriptor(self)->syntax())); +} + +static PyObject* CopyToProto(PyFileDescriptor *self, PyObject *target) { + return CopyToPythonProto<FileDescriptorProto>(_GetDescriptor(self), target); +} + +static PyGetSetDef Getters[] = { + { "pool", (getter)GetPool, NULL, "pool"}, + { "name", (getter)GetName, NULL, "name"}, + { "package", (getter)GetPackage, NULL, "package"}, + { "serialized_pb", (getter)GetSerializedPb}, + { "message_types_by_name", (getter)GetMessageTypesByName, NULL, + "Messages by name"}, + { "enum_types_by_name", (getter)GetEnumTypesByName, NULL, "Enums by name"}, + { "extensions_by_name", (getter)GetExtensionsByName, NULL, + "Extensions by name"}, + { "dependencies", (getter)GetDependencies, NULL, "Dependencies"}, + { "public_dependencies", (getter)GetPublicDependencies, NULL, "Dependencies"}, + + { "has_options", (getter)GetHasOptions, (setter)SetHasOptions, "Has Options"}, + { "_options", (getter)NULL, (setter)SetOptions, "Options"}, + { "syntax", (getter)GetSyntax, (setter)NULL, "Syntax"}, + {NULL} +}; + +static PyMethodDef Methods[] = { + { "GetOptions", (PyCFunction)GetOptions, METH_NOARGS, }, + { "CopyToProto", (PyCFunction)CopyToProto, METH_O, }, + {NULL} }; -PyObject* Python_BuildFile(PyObject* ignored, PyObject* arg) { - char* message_type; - Py_ssize_t message_len; +} // namespace file_descriptor - if (PyBytes_AsStringAndSize(arg, &message_type, &message_len) < 0) { +PyTypeObject PyFileDescriptor_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + FULL_MODULE_NAME ".FileDescriptor", // tp_name + sizeof(PyFileDescriptor), // tp_basicsize + 0, // tp_itemsize + (destructor)file_descriptor::Dealloc, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + 0, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str 
+ 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + "A File Descriptor", // tp_doc + 0, // tp_traverse + 0, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + 0, // tp_iter + 0, // tp_iternext + file_descriptor::Methods, // tp_methods + 0, // tp_members + file_descriptor::Getters, // tp_getset + &descriptor::PyBaseDescriptor_Type, // tp_base + 0, // tp_dict + 0, // tp_descr_get + 0, // tp_descr_set + 0, // tp_dictoffset + 0, // tp_init + 0, // tp_alloc + 0, // tp_new + PyObject_Del, // tp_free +}; + +PyObject* PyFileDescriptor_FromDescriptor( + const FileDescriptor* file_descriptor) { + return PyFileDescriptor_FromDescriptorWithSerializedPb(file_descriptor, + NULL); +} + +PyObject* PyFileDescriptor_FromDescriptorWithSerializedPb( + const FileDescriptor* file_descriptor, PyObject *serialized_pb) { + bool was_created; + PyObject* py_descriptor = descriptor::NewInternedDescriptor( + &PyFileDescriptor_Type, file_descriptor, &was_created); + if (py_descriptor == NULL) { return NULL; } + if (was_created) { + PyFileDescriptor* cfile_descriptor = + reinterpret_cast<PyFileDescriptor*>(py_descriptor); + Py_XINCREF(serialized_pb); + cfile_descriptor->serialized_pb = serialized_pb; + } + // TODO(amauryfa): In the case of a cached object, check that serialized_pb + // is the same as before. + + return py_descriptor; +} - google::protobuf::FileDescriptorProto file_proto; - if (!file_proto.ParseFromArray(message_type, message_len)) { - PyErr_SetString(PyExc_TypeError, "Couldn't parse file content!"); +const FileDescriptor* PyFileDescriptor_AsDescriptor(PyObject* obj) { + if (!PyObject_TypeCheck(obj, &PyFileDescriptor_Type)) { + PyErr_SetString(PyExc_TypeError, "Not a FileDescriptor"); return NULL; } + return reinterpret_cast<const FileDescriptor*>( + reinterpret_cast<PyBaseDescriptor*>(obj)->descriptor); +} + +namespace oneof_descriptor { + +// Unchecked accessor to the C++ pointer. 
+static const OneofDescriptor* _GetDescriptor( + PyBaseDescriptor *self) { + return reinterpret_cast<const OneofDescriptor*>(self->descriptor); +} + +static PyObject* GetName(PyBaseDescriptor* self, void *closure) { + return PyString_FromCppString(_GetDescriptor(self)->name()); +} + +static PyObject* GetFullName(PyBaseDescriptor* self, void *closure) { + return PyString_FromCppString(_GetDescriptor(self)->full_name()); +} + +static PyObject* GetIndex(PyBaseDescriptor *self, void *closure) { + return PyInt_FromLong(_GetDescriptor(self)->index()); +} - if (google::protobuf::DescriptorPool::generated_pool()->FindFileByName( - file_proto.name()) != NULL) { +static PyObject* GetFields(PyBaseDescriptor* self, void *closure) { + return NewOneofFieldsSeq(_GetDescriptor(self)); +} + +static PyObject* GetContainingType(PyBaseDescriptor *self, void *closure) { + const Descriptor* containing_type = + _GetDescriptor(self)->containing_type(); + if (containing_type) { + return PyMessageDescriptor_FromDescriptor(containing_type); + } else { Py_RETURN_NONE; } +} - BuildFileErrorCollector error_collector; - const google::protobuf::FileDescriptor* descriptor = - GetDescriptorPool()->BuildFileCollectingErrors(file_proto, - &error_collector); - if (descriptor == NULL) { - PyErr_Format(PyExc_TypeError, - "Couldn't build proto file into descriptor pool!\n%s", - error_collector.error_message.c_str()); - return NULL; +static PyGetSetDef Getters[] = { + { "name", (getter)GetName, NULL, "Name"}, + { "full_name", (getter)GetFullName, NULL, "Full name"}, + { "index", (getter)GetIndex, NULL, "Index"}, + + { "containing_type", (getter)GetContainingType, NULL, "Containing type"}, + { "fields", (getter)GetFields, NULL, "Fields"}, + {NULL} +}; + +} // namespace oneof_descriptor + +PyTypeObject PyOneofDescriptor_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + FULL_MODULE_NAME ".OneofDescriptor", // tp_name + sizeof(PyBaseDescriptor), // tp_basicsize + 0, // tp_itemsize + 0, // tp_dealloc + 0, // 
tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + 0, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + "A Oneof Descriptor", // tp_doc + 0, // tp_traverse + 0, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + 0, // tp_iter + 0, // tp_iternext + 0, // tp_methods + 0, // tp_members + oneof_descriptor::Getters, // tp_getset + &descriptor::PyBaseDescriptor_Type, // tp_base +}; + +PyObject* PyOneofDescriptor_FromDescriptor( + const OneofDescriptor* oneof_descriptor) { + return descriptor::NewInternedDescriptor( + &PyOneofDescriptor_Type, oneof_descriptor, NULL); +} + +// Add a enum values to a type dictionary. +static bool AddEnumValues(PyTypeObject *type, + const EnumDescriptor* enum_descriptor) { + for (int i = 0; i < enum_descriptor->value_count(); ++i) { + const EnumValueDescriptor* value = enum_descriptor->value(i); + ScopedPyObjectPtr obj(PyInt_FromLong(value->number())); + if (obj == NULL) { + return false; + } + if (PyDict_SetItemString(type->tp_dict, value->name().c_str(), obj.get()) < + 0) { + return false; + } } + return true; +} - Py_RETURN_NONE; +static bool AddIntConstant(PyTypeObject *type, const char* name, int value) { + ScopedPyObjectPtr obj(PyInt_FromLong(value)); + if (PyDict_SetItemString(type->tp_dict, name, obj.get()) < 0) { + return false; + } + return true; } + bool InitDescriptor() { - CFieldDescriptor_Type.tp_new = PyType_GenericNew; - if (PyType_Ready(&CFieldDescriptor_Type) < 0) + if (PyType_Ready(&PyMessageDescriptor_Type) < 0) + return false; + + if (PyType_Ready(&PyFieldDescriptor_Type) < 0) + return false; + + if (!AddEnumValues(&PyFieldDescriptor_Type, + FieldDescriptorProto::Label_descriptor())) { + return false; + } + if (!AddEnumValues(&PyFieldDescriptor_Type, + FieldDescriptorProto::Type_descriptor())) { + return false; + } +#define 
ADD_FIELDDESC_CONSTANT(NAME) AddIntConstant( \ + &PyFieldDescriptor_Type, #NAME, FieldDescriptor::NAME) + if (!ADD_FIELDDESC_CONSTANT(CPPTYPE_INT32) || + !ADD_FIELDDESC_CONSTANT(CPPTYPE_INT64) || + !ADD_FIELDDESC_CONSTANT(CPPTYPE_UINT32) || + !ADD_FIELDDESC_CONSTANT(CPPTYPE_UINT64) || + !ADD_FIELDDESC_CONSTANT(CPPTYPE_DOUBLE) || + !ADD_FIELDDESC_CONSTANT(CPPTYPE_FLOAT) || + !ADD_FIELDDESC_CONSTANT(CPPTYPE_BOOL) || + !ADD_FIELDDESC_CONSTANT(CPPTYPE_ENUM) || + !ADD_FIELDDESC_CONSTANT(CPPTYPE_STRING) || + !ADD_FIELDDESC_CONSTANT(CPPTYPE_MESSAGE)) { + return false; + } +#undef ADD_FIELDDESC_CONSTANT + + if (PyType_Ready(&PyEnumDescriptor_Type) < 0) + return false; + + if (PyType_Ready(&PyEnumValueDescriptor_Type) < 0) + return false; + + if (PyType_Ready(&PyFileDescriptor_Type) < 0) + return false; + + if (PyType_Ready(&PyOneofDescriptor_Type) < 0) return false; - CDescriptorPool_Type.tp_new = PyType_GenericNew; - if (PyType_Ready(&CDescriptorPool_Type) < 0) + if (!InitDescriptorMappingTypes()) return false; return true; diff --git a/python/google/protobuf/pyext/descriptor.h b/python/google/protobuf/pyext/descriptor.h index ae7a1b9c6..eb99df182 100644 --- a/python/google/protobuf/pyext/descriptor.h +++ b/python/google/protobuf/pyext/descriptor.h @@ -34,60 +34,61 @@ #define GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_H__ #include <Python.h> -#include <structmember.h> #include <google/protobuf/descriptor.h> -#if PY_VERSION_HEX < 0x02050000 && !defined(PY_SSIZE_T_MIN) -typedef int Py_ssize_t; -#define PY_SSIZE_T_MAX INT_MAX -#define PY_SSIZE_T_MIN INT_MIN -#endif - namespace google { namespace protobuf { namespace python { -typedef struct CFieldDescriptor { - PyObject_HEAD - - // The proto2 descriptor that this object represents. - const google::protobuf::FieldDescriptor* descriptor; - - // Reference to the original field object in the Python DESCRIPTOR. 
- PyObject* descriptor_field; -} CFieldDescriptor; - -typedef struct { - PyObject_HEAD - - const google::protobuf::DescriptorPool* pool; -} CDescriptorPool; - -extern PyTypeObject CFieldDescriptor_Type; - -extern PyTypeObject CDescriptorPool_Type; +extern PyTypeObject PyMessageDescriptor_Type; +extern PyTypeObject PyFieldDescriptor_Type; +extern PyTypeObject PyEnumDescriptor_Type; +extern PyTypeObject PyEnumValueDescriptor_Type; +extern PyTypeObject PyFileDescriptor_Type; +extern PyTypeObject PyOneofDescriptor_Type; -namespace cdescriptor_pool { - -// Looks up a field by name. Returns a CDescriptor corresponding to -// the field on success, or NULL on failure. -// +// Wraps a Descriptor in a Python object. +// The C++ pointer is usually borrowed from the global DescriptorPool. +// In any case, it must stay alive as long as the Python object. // Returns a new reference. -PyObject* FindFieldByName(CDescriptorPool* self, PyObject* name); - -// Looks up an extension by name. Returns a CDescriptor corresponding -// to the field on success, or NULL on failure. -// +PyObject* PyMessageDescriptor_FromDescriptor(const Descriptor* descriptor); +PyObject* PyFieldDescriptor_FromDescriptor(const FieldDescriptor* descriptor); +PyObject* PyEnumDescriptor_FromDescriptor(const EnumDescriptor* descriptor); +PyObject* PyEnumValueDescriptor_FromDescriptor( + const EnumValueDescriptor* descriptor); +PyObject* PyOneofDescriptor_FromDescriptor(const OneofDescriptor* descriptor); +PyObject* PyFileDescriptor_FromDescriptor( + const FileDescriptor* file_descriptor); + +// Alternate constructor of PyFileDescriptor, used when we already have a +// serialized FileDescriptorProto that can be cached. // Returns a new reference. -PyObject* FindExtensionByName(CDescriptorPool* self, PyObject* arg); - -} // namespace cdescriptor_pool +PyObject* PyFileDescriptor_FromDescriptorWithSerializedPb( + const FileDescriptor* file_descriptor, PyObject* serialized_pb); + +// Return the C++ descriptor pointer. 
+// This function checks the parameter type; on error, return NULL with a Python +// exception set. +const Descriptor* PyMessageDescriptor_AsDescriptor(PyObject* obj); +const FieldDescriptor* PyFieldDescriptor_AsDescriptor(PyObject* obj); +const EnumDescriptor* PyEnumDescriptor_AsDescriptor(PyObject* obj); +const FileDescriptor* PyFileDescriptor_AsDescriptor(PyObject* obj); + +// Returns the raw C++ pointer. +const void* PyDescriptor_AsVoidPtr(PyObject* obj); + +// Check that the calling Python code is the global scope of a _pb2.py module. +// This function is used to support the current code generated by the proto +// compiler, which insists on modifying descriptors after they have been +// created. +// +// stacklevel indicates which Python frame should be the _pb2.py module. +// +// Don't use this function outside descriptor classes. +bool _CalledFromGeneratedFile(int stacklevel); -PyObject* Python_NewCDescriptorPool(PyObject* ignored, PyObject* args); -PyObject* Python_BuildFile(PyObject* ignored, PyObject* args); bool InitDescriptor(); -google::protobuf::DescriptorPool* GetDescriptorPool(); } // namespace python } // namespace protobuf diff --git a/python/google/protobuf/pyext/descriptor_containers.cc b/python/google/protobuf/pyext/descriptor_containers.cc new file mode 100644 index 000000000..e505d8122 --- /dev/null +++ b/python/google/protobuf/pyext/descriptor_containers.cc @@ -0,0 +1,1652 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. 
+// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Mappings and Sequences of descriptors. +// Used by Descriptor.fields_by_name, EnumDescriptor.values... +// +// They avoid the allocation of a full dictionary or a full list: they simply +// store a pointer to the parent descriptor, use the C++ Descriptor methods (see +// google/protobuf/descriptor.h) to retrieve other descriptors, and create +// Python objects on the fly. +// +// The containers fully conform to abc.Mapping and abc.Sequence, and behave just +// like read-only dictionaries and lists. +// +// Because the interface of C++ Descriptors is quite regular, this file actually +// defines only three types, the exact behavior of a container is controlled by +// a DescriptorContainerDef structure, which contains functions that uses the +// public Descriptor API. 
+// +// Note: This DescriptorContainerDef is similar to the "virtual methods table" +// that a C++ compiler generates for a class. We have to make it explicit +// because the Python API is based on C, and does not play well with C++ +// inheritance. + +#include <Python.h> + +#include <google/protobuf/descriptor.h> +#include <google/protobuf/pyext/descriptor_containers.h> +#include <google/protobuf/pyext/descriptor_pool.h> +#include <google/protobuf/pyext/descriptor.h> +#include <google/protobuf/pyext/scoped_pyobject_ptr.h> + +#if PY_MAJOR_VERSION >= 3 + #define PyString_FromStringAndSize PyUnicode_FromStringAndSize + #define PyString_FromFormat PyUnicode_FromFormat + #define PyInt_FromLong PyLong_FromLong + #if PY_VERSION_HEX < 0x03030000 + #error "Python 3.0 - 3.2 are not supported." + #endif + #define PyString_AsStringAndSize(ob, charpp, sizep) \ + (PyUnicode_Check(ob)? \ + ((*(charpp) = PyUnicode_AsUTF8AndSize(ob, (sizep))) == NULL? -1: 0): \ + PyBytes_AsStringAndSize(ob, (charpp), (sizep))) +#endif + +namespace google { +namespace protobuf { +namespace python { + +struct PyContainer; + +typedef int (*CountMethod)(PyContainer* self); +typedef const void* (*GetByIndexMethod)(PyContainer* self, int index); +typedef const void* (*GetByNameMethod)(PyContainer* self, const string& name); +typedef const void* (*GetByCamelcaseNameMethod)(PyContainer* self, + const string& name); +typedef const void* (*GetByNumberMethod)(PyContainer* self, int index); +typedef PyObject* (*NewObjectFromItemMethod)(const void* descriptor); +typedef const string& (*GetItemNameMethod)(const void* descriptor); +typedef const string& (*GetItemCamelcaseNameMethod)(const void* descriptor); +typedef int (*GetItemNumberMethod)(const void* descriptor); +typedef int (*GetItemIndexMethod)(const void* descriptor); + +struct DescriptorContainerDef { + const char* mapping_name; + // Returns the number of items in the container. 
+ CountMethod count_fn; + // Retrieve item by index (usually the order of declaration in the proto file) + // Used by sequences, but also iterators. 0 <= index < Count(). + GetByIndexMethod get_by_index_fn; + // Retrieve item by name (usually a call to some 'FindByName' method). + // Used by "by_name" mappings. + GetByNameMethod get_by_name_fn; + // Retrieve item by camelcase name (usually a call to some + // 'FindByCamelcaseName' method). Used by "by_camelcase_name" mappings. + GetByCamelcaseNameMethod get_by_camelcase_name_fn; + // Retrieve item by declared number (field tag, or enum value). + // Used by "by_number" mappings. + GetByNumberMethod get_by_number_fn; + // Converts a item C++ descriptor to a Python object. Returns a new reference. + NewObjectFromItemMethod new_object_from_item_fn; + // Retrieve the name of an item. Used by iterators on "by_name" mappings. + GetItemNameMethod get_item_name_fn; + // Retrieve the camelcase name of an item. Used by iterators on + // "by_camelcase_name" mappings. + GetItemCamelcaseNameMethod get_item_camelcase_name_fn; + // Retrieve the number of an item. Used by iterators on "by_number" mappings. + GetItemNumberMethod get_item_number_fn; + // Retrieve the index of an item for the container type. + // Used by "__contains__". + // If not set, "x in sequence" will do a linear search. + GetItemIndexMethod get_item_index_fn; +}; + +struct PyContainer { + PyObject_HEAD + + // The proto2 descriptor this container belongs to the global DescriptorPool. + const void* descriptor; + + // A pointer to a static structure with function pointers that control the + // behavior of the container. Very similar to the table of virtual functions + // of a C++ class. + const DescriptorContainerDef* container_def; + + // The kind of container: list, or dict by name or value. 
+ enum ContainerKind { + KIND_SEQUENCE, + KIND_BYNAME, + KIND_BYCAMELCASENAME, + KIND_BYNUMBER, + } kind; +}; + +struct PyContainerIterator { + PyObject_HEAD + + // The container we are iterating over. Own a reference. + PyContainer* container; + + // The current index in the iterator. + int index; + + // The kind of container: list, or dict by name or value. + enum IterKind { + KIND_ITERKEY, + KIND_ITERVALUE, + KIND_ITERITEM, + KIND_ITERVALUE_REVERSED, // For sequences + } kind; +}; + +namespace descriptor { + +// Returns the C++ item descriptor for a given Python key. +// When the descriptor is found, return true and set *item. +// When the descriptor is not found, return true, but set *item to NULL. +// On error, returns false with an exception set. +static bool _GetItemByKey(PyContainer* self, PyObject* key, const void** item) { + switch (self->kind) { + case PyContainer::KIND_BYNAME: + { + char* name; + Py_ssize_t name_size; + if (PyString_AsStringAndSize(key, &name, &name_size) < 0) { + if (PyErr_ExceptionMatches(PyExc_TypeError)) { + // Not a string, cannot be in the container. + PyErr_Clear(); + *item = NULL; + return true; + } + return false; + } + *item = self->container_def->get_by_name_fn( + self, string(name, name_size)); + return true; + } + case PyContainer::KIND_BYCAMELCASENAME: + { + char* camelcase_name; + Py_ssize_t name_size; + if (PyString_AsStringAndSize(key, &camelcase_name, &name_size) < 0) { + if (PyErr_ExceptionMatches(PyExc_TypeError)) { + // Not a string, cannot be in the container. + PyErr_Clear(); + *item = NULL; + return true; + } + return false; + } + *item = self->container_def->get_by_camelcase_name_fn( + self, string(camelcase_name, name_size)); + return true; + } + case PyContainer::KIND_BYNUMBER: + { + Py_ssize_t number = PyNumber_AsSsize_t(key, NULL); + if (number == -1 && PyErr_Occurred()) { + if (PyErr_ExceptionMatches(PyExc_TypeError)) { + // Not a number, cannot be in the container. 
+ PyErr_Clear(); + *item = NULL; + return true; + } + return false; + } + *item = self->container_def->get_by_number_fn(self, number); + return true; + } + default: + PyErr_SetNone(PyExc_NotImplementedError); + return false; + } +} + +// Returns the key of the object at the given index. +// Used when iterating over mappings. +static PyObject* _NewKey_ByIndex(PyContainer* self, Py_ssize_t index) { + const void* item = self->container_def->get_by_index_fn(self, index); + switch (self->kind) { + case PyContainer::KIND_BYNAME: + { + const string& name(self->container_def->get_item_name_fn(item)); + return PyString_FromStringAndSize(name.c_str(), name.size()); + } + case PyContainer::KIND_BYCAMELCASENAME: + { + const string& name( + self->container_def->get_item_camelcase_name_fn(item)); + return PyString_FromStringAndSize(name.c_str(), name.size()); + } + case PyContainer::KIND_BYNUMBER: + { + int value = self->container_def->get_item_number_fn(item); + return PyInt_FromLong(value); + } + default: + PyErr_SetNone(PyExc_NotImplementedError); + return NULL; + } +} + +// Returns the object at the given index. +// Also used when iterating over mappings. +static PyObject* _NewObj_ByIndex(PyContainer* self, Py_ssize_t index) { + return self->container_def->new_object_from_item_fn( + self->container_def->get_by_index_fn(self, index)); +} + +static Py_ssize_t Length(PyContainer* self) { + return self->container_def->count_fn(self); +} + +// The DescriptorMapping type. 
+ +static PyObject* Subscript(PyContainer* self, PyObject* key) { + const void* item = NULL; + if (!_GetItemByKey(self, key, &item)) { + return NULL; + } + if (!item) { + PyErr_SetObject(PyExc_KeyError, key); + return NULL; + } + return self->container_def->new_object_from_item_fn(item); +} + +static int AssSubscript(PyContainer* self, PyObject* key, PyObject* value) { + if (_CalledFromGeneratedFile(0)) { + return 0; + } + PyErr_Format(PyExc_TypeError, + "'%.200s' object does not support item assignment", + Py_TYPE(self)->tp_name); + return -1; +} + +static PyMappingMethods MappingMappingMethods = { + (lenfunc)Length, // mp_length + (binaryfunc)Subscript, // mp_subscript + (objobjargproc)AssSubscript, // mp_ass_subscript +}; + +static int Contains(PyContainer* self, PyObject* key) { + const void* item = NULL; + if (!_GetItemByKey(self, key, &item)) { + return -1; + } + if (item) { + return 1; + } else { + return 0; + } +} + +static PyObject* ContainerRepr(PyContainer* self) { + const char* kind = ""; + switch (self->kind) { + case PyContainer::KIND_SEQUENCE: + kind = "sequence"; + break; + case PyContainer::KIND_BYNAME: + kind = "mapping by name"; + break; + case PyContainer::KIND_BYCAMELCASENAME: + kind = "mapping by camelCase name"; + break; + case PyContainer::KIND_BYNUMBER: + kind = "mapping by number"; + break; + } + return PyString_FromFormat( + "<%s %s>", self->container_def->mapping_name, kind); +} + +extern PyTypeObject DescriptorMapping_Type; +extern PyTypeObject DescriptorSequence_Type; + +// A sequence container can only be equal to another sequence container, or (for +// backward compatibility) to a list containing the same items. +// Returns 1 if equal, 0 if unequal, -1 on error. +static int DescriptorSequence_Equal(PyContainer* self, PyObject* other) { + // Check the identity of C++ pointers. 
+ if (PyObject_TypeCheck(other, &DescriptorSequence_Type)) { + PyContainer* other_container = reinterpret_cast<PyContainer*>(other); + if (self->descriptor == other_container->descriptor && + self->container_def == other_container->container_def && + self->kind == other_container->kind) { + return 1; + } else { + return 0; + } + } + + // If other is a list + if (PyList_Check(other)) { + // return list(self) == other + int size = Length(self); + if (size != PyList_Size(other)) { + return false; + } + for (int index = 0; index < size; index++) { + ScopedPyObjectPtr value1(_NewObj_ByIndex(self, index)); + if (value1 == NULL) { + return -1; + } + PyObject* value2 = PyList_GetItem(other, index); + if (value2 == NULL) { + return -1; + } + int cmp = PyObject_RichCompareBool(value1.get(), value2, Py_EQ); + if (cmp != 1) // error or not equal + return cmp; + } + // All items were found and equal + return 1; + } + + // Any other object is different. + return 0; +} + +// A mapping container can only be equal to another mapping container, or (for +// backward compatibility) to a dict containing the same items. +// Returns 1 if equal, 0 if unequal, -1 on error. +static int DescriptorMapping_Equal(PyContainer* self, PyObject* other) { + // Check the identity of C++ pointers. 
+ if (PyObject_TypeCheck(other, &DescriptorMapping_Type)) { + PyContainer* other_container = reinterpret_cast<PyContainer*>(other); + if (self->descriptor == other_container->descriptor && + self->container_def == other_container->container_def && + self->kind == other_container->kind) { + return 1; + } else { + return 0; + } + } + + // If other is a dict + if (PyDict_Check(other)) { + // equivalent to dict(self.items()) == other + int size = Length(self); + if (size != PyDict_Size(other)) { + return false; + } + for (int index = 0; index < size; index++) { + ScopedPyObjectPtr key(_NewKey_ByIndex(self, index)); + if (key == NULL) { + return -1; + } + ScopedPyObjectPtr value1(_NewObj_ByIndex(self, index)); + if (value1 == NULL) { + return -1; + } + PyObject* value2 = PyDict_GetItem(other, key.get()); + if (value2 == NULL) { + // Not found in the other dictionary + return 0; + } + int cmp = PyObject_RichCompareBool(value1.get(), value2, Py_EQ); + if (cmp != 1) // error or not equal + return cmp; + } + // All items were found and equal + return 1; + } + + // Any other object is different. 
+ return 0; +} + +static PyObject* RichCompare(PyContainer* self, PyObject* other, int opid) { + if (opid != Py_EQ && opid != Py_NE) { + Py_INCREF(Py_NotImplemented); + return Py_NotImplemented; + } + + int result; + + if (self->kind == PyContainer::KIND_SEQUENCE) { + result = DescriptorSequence_Equal(self, other); + } else { + result = DescriptorMapping_Equal(self, other); + } + if (result < 0) { + return NULL; + } + if (result ^ (opid == Py_NE)) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } +} + +static PySequenceMethods MappingSequenceMethods = { + 0, // sq_length + 0, // sq_concat + 0, // sq_repeat + 0, // sq_item + 0, // sq_slice + 0, // sq_ass_item + 0, // sq_ass_slice + (objobjproc)Contains, // sq_contains +}; + +static PyObject* Get(PyContainer* self, PyObject* args) { + PyObject* key; + PyObject* default_value = Py_None; + if (!PyArg_UnpackTuple(args, "get", 1, 2, &key, &default_value)) { + return NULL; + } + + const void* item; + if (!_GetItemByKey(self, key, &item)) { + return NULL; + } + if (item == NULL) { + Py_INCREF(default_value); + return default_value; + } + return self->container_def->new_object_from_item_fn(item); +} + +static PyObject* Keys(PyContainer* self, PyObject* args) { + Py_ssize_t count = Length(self); + ScopedPyObjectPtr list(PyList_New(count)); + if (list == NULL) { + return NULL; + } + for (Py_ssize_t index = 0; index < count; ++index) { + PyObject* key = _NewKey_ByIndex(self, index); + if (key == NULL) { + return NULL; + } + PyList_SET_ITEM(list.get(), index, key); + } + return list.release(); +} + +static PyObject* Values(PyContainer* self, PyObject* args) { + Py_ssize_t count = Length(self); + ScopedPyObjectPtr list(PyList_New(count)); + if (list == NULL) { + return NULL; + } + for (Py_ssize_t index = 0; index < count; ++index) { + PyObject* value = _NewObj_ByIndex(self, index); + if (value == NULL) { + return NULL; + } + PyList_SET_ITEM(list.get(), index, value); + } + return list.release(); +} + +static PyObject* 
Items(PyContainer* self, PyObject* args) { + Py_ssize_t count = Length(self); + ScopedPyObjectPtr list(PyList_New(count)); + if (list == NULL) { + return NULL; + } + for (Py_ssize_t index = 0; index < count; ++index) { + ScopedPyObjectPtr obj(PyTuple_New(2)); + if (obj == NULL) { + return NULL; + } + PyObject* key = _NewKey_ByIndex(self, index); + if (key == NULL) { + return NULL; + } + PyTuple_SET_ITEM(obj.get(), 0, key); + PyObject* value = _NewObj_ByIndex(self, index); + if (value == NULL) { + return NULL; + } + PyTuple_SET_ITEM(obj.get(), 1, value); + PyList_SET_ITEM(list.get(), index, obj.release()); + } + return list.release(); +} + +static PyObject* NewContainerIterator(PyContainer* mapping, + PyContainerIterator::IterKind kind); + +static PyObject* Iter(PyContainer* self) { + return NewContainerIterator(self, PyContainerIterator::KIND_ITERKEY); +} +static PyObject* IterKeys(PyContainer* self, PyObject* args) { + return NewContainerIterator(self, PyContainerIterator::KIND_ITERKEY); +} +static PyObject* IterValues(PyContainer* self, PyObject* args) { + return NewContainerIterator(self, PyContainerIterator::KIND_ITERVALUE); +} +static PyObject* IterItems(PyContainer* self, PyObject* args) { + return NewContainerIterator(self, PyContainerIterator::KIND_ITERITEM); +} + +static PyMethodDef MappingMethods[] = { + { "get", (PyCFunction)Get, METH_VARARGS, }, + { "keys", (PyCFunction)Keys, METH_NOARGS, }, + { "values", (PyCFunction)Values, METH_NOARGS, }, + { "items", (PyCFunction)Items, METH_NOARGS, }, + { "iterkeys", (PyCFunction)IterKeys, METH_NOARGS, }, + { "itervalues", (PyCFunction)IterValues, METH_NOARGS, }, + { "iteritems", (PyCFunction)IterItems, METH_NOARGS, }, + {NULL} +}; + +PyTypeObject DescriptorMapping_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + "DescriptorMapping", // tp_name + sizeof(PyContainer), // tp_basicsize + 0, // tp_itemsize + 0, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 
(reprfunc)ContainerRepr, // tp_repr + 0, // tp_as_number + &MappingSequenceMethods, // tp_as_sequence + &MappingMappingMethods, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + 0, // tp_doc + 0, // tp_traverse + 0, // tp_clear + (richcmpfunc)RichCompare, // tp_richcompare + 0, // tp_weaklistoffset + (getiterfunc)Iter, // tp_iter + 0, // tp_iternext + MappingMethods, // tp_methods + 0, // tp_members + 0, // tp_getset + 0, // tp_base + 0, // tp_dict + 0, // tp_descr_get + 0, // tp_descr_set + 0, // tp_dictoffset + 0, // tp_init + 0, // tp_alloc + 0, // tp_new + 0, // tp_free +}; + +// The DescriptorSequence type. + +static PyObject* GetItem(PyContainer* self, Py_ssize_t index) { + if (index < 0) { + index += Length(self); + } + if (index < 0 || index >= Length(self)) { + PyErr_SetString(PyExc_IndexError, "index out of range"); + return NULL; + } + return _NewObj_ByIndex(self, index); +} + +// Returns the position of the item in the sequence, of -1 if not found. +// This function never fails. +int Find(PyContainer* self, PyObject* item) { + // The item can only be in one position: item.index. + // Check that self[item.index] == item, it's faster than a linear search. + // + // This assumes that sequences are only defined by syntax of the .proto file: + // a specific item belongs to only one sequence, depending on its position in + // the .proto file definition. + const void* descriptor_ptr = PyDescriptor_AsVoidPtr(item); + if (descriptor_ptr == NULL) { + // Not a descriptor, it cannot be in the list. + return -1; + } + if (self->container_def->get_item_index_fn) { + int index = self->container_def->get_item_index_fn(descriptor_ptr); + if (index < 0 || index >= Length(self)) { + // This index is not from this collection. 
+ return -1; + } + if (self->container_def->get_by_index_fn(self, index) != descriptor_ptr) { + // The descriptor at this index is not the same. + return -1; + } + // self[item.index] == item, so return the index. + return index; + } else { + // Fall back to linear search. + int length = Length(self); + for (int index=0; index < length; index++) { + if (self->container_def->get_by_index_fn(self, index) == descriptor_ptr) { + return index; + } + } + // Not found + return -1; + } +} + +// Implements list.index(): the position of the item is in the sequence. +static PyObject* Index(PyContainer* self, PyObject* item) { + int position = Find(self, item); + if (position < 0) { + // Not found + PyErr_SetNone(PyExc_ValueError); + return NULL; + } else { + return PyInt_FromLong(position); + } +} +// Implements "list.__contains__()": is the object in the sequence. +static int SeqContains(PyContainer* self, PyObject* item) { + int position = Find(self, item); + if (position < 0) { + return 0; + } else { + return 1; + } +} + +// Implements list.count(): number of occurrences of the item in the sequence. +// An item can only appear once in a sequence. If it exists, return 1. 
+static PyObject* Count(PyContainer* self, PyObject* item) { + int position = Find(self, item); + if (position < 0) { + return PyInt_FromLong(0); + } else { + return PyInt_FromLong(1); + } +} + +static PyObject* Append(PyContainer* self, PyObject* args) { + if (_CalledFromGeneratedFile(0)) { + Py_RETURN_NONE; + } + PyErr_Format(PyExc_TypeError, + "'%.200s' object is not a mutable sequence", + Py_TYPE(self)->tp_name); + return NULL; +} + +static PyObject* Reversed(PyContainer* self, PyObject* args) { + return NewContainerIterator(self, + PyContainerIterator::KIND_ITERVALUE_REVERSED); +} + +static PyMethodDef SeqMethods[] = { + { "index", (PyCFunction)Index, METH_O, }, + { "count", (PyCFunction)Count, METH_O, }, + { "append", (PyCFunction)Append, METH_O, }, + { "__reversed__", (PyCFunction)Reversed, METH_NOARGS, }, + {NULL} +}; + +static PySequenceMethods SeqSequenceMethods = { + (lenfunc)Length, // sq_length + 0, // sq_concat + 0, // sq_repeat + (ssizeargfunc)GetItem, // sq_item + 0, // sq_slice + 0, // sq_ass_item + 0, // sq_ass_slice + (objobjproc)SeqContains, // sq_contains +}; + +PyTypeObject DescriptorSequence_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + "DescriptorSequence", // tp_name + sizeof(PyContainer), // tp_basicsize + 0, // tp_itemsize + 0, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + (reprfunc)ContainerRepr, // tp_repr + 0, // tp_as_number + &SeqSequenceMethods, // tp_as_sequence + 0, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + 0, // tp_doc + 0, // tp_traverse + 0, // tp_clear + (richcmpfunc)RichCompare, // tp_richcompare + 0, // tp_weaklistoffset + 0, // tp_iter + 0, // tp_iternext + SeqMethods, // tp_methods + 0, // tp_members + 0, // tp_getset + 0, // tp_base + 0, // tp_dict + 0, // tp_descr_get + 0, // tp_descr_set + 0, // tp_dictoffset + 0, // tp_init + 0, // tp_alloc + 0, // 
tp_new + 0, // tp_free +}; + +static PyObject* NewMappingByName( + DescriptorContainerDef* container_def, const void* descriptor) { + PyContainer* self = PyObject_New(PyContainer, &DescriptorMapping_Type); + if (self == NULL) { + return NULL; + } + self->descriptor = descriptor; + self->container_def = container_def; + self->kind = PyContainer::KIND_BYNAME; + return reinterpret_cast<PyObject*>(self); +} + +static PyObject* NewMappingByCamelcaseName( + DescriptorContainerDef* container_def, const void* descriptor) { + PyContainer* self = PyObject_New(PyContainer, &DescriptorMapping_Type); + if (self == NULL) { + return NULL; + } + self->descriptor = descriptor; + self->container_def = container_def; + self->kind = PyContainer::KIND_BYCAMELCASENAME; + return reinterpret_cast<PyObject*>(self); +} + +static PyObject* NewMappingByNumber( + DescriptorContainerDef* container_def, const void* descriptor) { + if (container_def->get_by_number_fn == NULL || + container_def->get_item_number_fn == NULL) { + PyErr_SetNone(PyExc_NotImplementedError); + return NULL; + } + PyContainer* self = PyObject_New(PyContainer, &DescriptorMapping_Type); + if (self == NULL) { + return NULL; + } + self->descriptor = descriptor; + self->container_def = container_def; + self->kind = PyContainer::KIND_BYNUMBER; + return reinterpret_cast<PyObject*>(self); +} + +static PyObject* NewSequence( + DescriptorContainerDef* container_def, const void* descriptor) { + PyContainer* self = PyObject_New(PyContainer, &DescriptorSequence_Type); + if (self == NULL) { + return NULL; + } + self->descriptor = descriptor; + self->container_def = container_def; + self->kind = PyContainer::KIND_SEQUENCE; + return reinterpret_cast<PyObject*>(self); +} + +// Implement iterators over PyContainers. 
+ +static void Iterator_Dealloc(PyContainerIterator* self) { + Py_CLEAR(self->container); + Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self)); +} + +static PyObject* Iterator_Next(PyContainerIterator* self) { + int count = self->container->container_def->count_fn(self->container); + if (self->index >= count) { + // Return NULL with no exception to indicate the end. + return NULL; + } + int index = self->index; + self->index += 1; + switch (self->kind) { + case PyContainerIterator::KIND_ITERKEY: + return _NewKey_ByIndex(self->container, index); + case PyContainerIterator::KIND_ITERVALUE: + return _NewObj_ByIndex(self->container, index); + case PyContainerIterator::KIND_ITERVALUE_REVERSED: + return _NewObj_ByIndex(self->container, count - index - 1); + case PyContainerIterator::KIND_ITERITEM: + { + PyObject* obj = PyTuple_New(2); + if (obj == NULL) { + return NULL; + } + PyObject* key = _NewKey_ByIndex(self->container, index); + if (key == NULL) { + Py_DECREF(obj); + return NULL; + } + PyTuple_SET_ITEM(obj, 0, key); + PyObject* value = _NewObj_ByIndex(self->container, index); + if (value == NULL) { + Py_DECREF(obj); + return NULL; + } + PyTuple_SET_ITEM(obj, 1, value); + return obj; + } + default: + PyErr_SetNone(PyExc_NotImplementedError); + return NULL; + } +} + +static PyTypeObject ContainerIterator_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + "DescriptorContainerIterator", // tp_name + sizeof(PyContainerIterator), // tp_basicsize + 0, // tp_itemsize + (destructor)Iterator_Dealloc, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + 0, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + 0, // tp_doc + 0, // tp_traverse + 0, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + PyObject_SelfIter, // tp_iter + (iternextfunc)Iterator_Next, 
// tp_iternext + 0, // tp_methods + 0, // tp_members + 0, // tp_getset + 0, // tp_base + 0, // tp_dict + 0, // tp_descr_get + 0, // tp_descr_set + 0, // tp_dictoffset + 0, // tp_init + 0, // tp_alloc + 0, // tp_new + 0, // tp_free +}; + +static PyObject* NewContainerIterator(PyContainer* container, + PyContainerIterator::IterKind kind) { + PyContainerIterator* self = PyObject_New(PyContainerIterator, + &ContainerIterator_Type); + if (self == NULL) { + return NULL; + } + Py_INCREF(container); + self->container = container; + self->kind = kind; + self->index = 0; + + return reinterpret_cast<PyObject*>(self); +} + +} // namespace descriptor + +// Now define the real collections! + +namespace message_descriptor { + +typedef const Descriptor* ParentDescriptor; + +static ParentDescriptor GetDescriptor(PyContainer* self) { + return reinterpret_cast<ParentDescriptor>(self->descriptor); +} + +namespace fields { + +typedef const FieldDescriptor* ItemDescriptor; + +static int Count(PyContainer* self) { + return GetDescriptor(self)->field_count(); +} + +static ItemDescriptor GetByName(PyContainer* self, const string& name) { + return GetDescriptor(self)->FindFieldByName(name); +} + +static ItemDescriptor GetByCamelcaseName(PyContainer* self, + const string& name) { + return GetDescriptor(self)->FindFieldByCamelcaseName(name); +} + +static ItemDescriptor GetByNumber(PyContainer* self, int number) { + return GetDescriptor(self)->FindFieldByNumber(number); +} + +static ItemDescriptor GetByIndex(PyContainer* self, int index) { + return GetDescriptor(self)->field(index); +} + +static PyObject* NewObjectFromItem(ItemDescriptor item) { + return PyFieldDescriptor_FromDescriptor(item); +} + +static const string& GetItemName(ItemDescriptor item) { + return item->name(); +} + +static const string& GetItemCamelcaseName(ItemDescriptor item) { + return item->camelcase_name(); +} + +static int GetItemNumber(ItemDescriptor item) { + return item->number(); +} + +static int 
GetItemIndex(ItemDescriptor item) { + return item->index(); +} + +static DescriptorContainerDef ContainerDef = { + "MessageFields", + (CountMethod)Count, + (GetByIndexMethod)GetByIndex, + (GetByNameMethod)GetByName, + (GetByCamelcaseNameMethod)GetByCamelcaseName, + (GetByNumberMethod)GetByNumber, + (NewObjectFromItemMethod)NewObjectFromItem, + (GetItemNameMethod)GetItemName, + (GetItemCamelcaseNameMethod)GetItemCamelcaseName, + (GetItemNumberMethod)GetItemNumber, + (GetItemIndexMethod)GetItemIndex, +}; + +} // namespace fields + +PyObject* NewMessageFieldsByName(ParentDescriptor descriptor) { + return descriptor::NewMappingByName(&fields::ContainerDef, descriptor); +} + +PyObject* NewMessageFieldsByCamelcaseName(ParentDescriptor descriptor) { + return descriptor::NewMappingByCamelcaseName(&fields::ContainerDef, + descriptor); +} + +PyObject* NewMessageFieldsByNumber(ParentDescriptor descriptor) { + return descriptor::NewMappingByNumber(&fields::ContainerDef, descriptor); +} + +PyObject* NewMessageFieldsSeq(ParentDescriptor descriptor) { + return descriptor::NewSequence(&fields::ContainerDef, descriptor); +} + +namespace nested_types { + +typedef const Descriptor* ItemDescriptor; + +static int Count(PyContainer* self) { + return GetDescriptor(self)->nested_type_count(); +} + +static ItemDescriptor GetByName(PyContainer* self, const string& name) { + return GetDescriptor(self)->FindNestedTypeByName(name); +} + +static ItemDescriptor GetByIndex(PyContainer* self, int index) { + return GetDescriptor(self)->nested_type(index); +} + +static PyObject* NewObjectFromItem(ItemDescriptor item) { + return PyMessageDescriptor_FromDescriptor(item); +} + +static const string& GetItemName(ItemDescriptor item) { + return item->name(); +} + +static int GetItemIndex(ItemDescriptor item) { + return item->index(); +} + +static DescriptorContainerDef ContainerDef = { + "MessageNestedTypes", + (CountMethod)Count, + (GetByIndexMethod)GetByIndex, + (GetByNameMethod)GetByName, + 
(GetByCamelcaseNameMethod)NULL, + (GetByNumberMethod)NULL, + (NewObjectFromItemMethod)NewObjectFromItem, + (GetItemNameMethod)GetItemName, + (GetItemCamelcaseNameMethod)NULL, + (GetItemNumberMethod)NULL, + (GetItemIndexMethod)GetItemIndex, +}; + +} // namespace nested_types + +PyObject* NewMessageNestedTypesSeq(ParentDescriptor descriptor) { + return descriptor::NewSequence(&nested_types::ContainerDef, descriptor); +} + +PyObject* NewMessageNestedTypesByName(ParentDescriptor descriptor) { + return descriptor::NewMappingByName(&nested_types::ContainerDef, descriptor); +} + +namespace enums { + +typedef const EnumDescriptor* ItemDescriptor; + +static int Count(PyContainer* self) { + return GetDescriptor(self)->enum_type_count(); +} + +static ItemDescriptor GetByName(PyContainer* self, const string& name) { + return GetDescriptor(self)->FindEnumTypeByName(name); +} + +static ItemDescriptor GetByIndex(PyContainer* self, int index) { + return GetDescriptor(self)->enum_type(index); +} + +static PyObject* NewObjectFromItem(ItemDescriptor item) { + return PyEnumDescriptor_FromDescriptor(item); +} + +static const string& GetItemName(ItemDescriptor item) { + return item->name(); +} + +static int GetItemIndex(ItemDescriptor item) { + return item->index(); +} + +static DescriptorContainerDef ContainerDef = { + "MessageNestedEnums", + (CountMethod)Count, + (GetByIndexMethod)GetByIndex, + (GetByNameMethod)GetByName, + (GetByCamelcaseNameMethod)NULL, + (GetByNumberMethod)NULL, + (NewObjectFromItemMethod)NewObjectFromItem, + (GetItemNameMethod)GetItemName, + (GetItemCamelcaseNameMethod)NULL, + (GetItemNumberMethod)NULL, + (GetItemIndexMethod)GetItemIndex, +}; + +} // namespace enums + +PyObject* NewMessageEnumsByName(ParentDescriptor descriptor) { + return descriptor::NewMappingByName(&enums::ContainerDef, descriptor); +} + +PyObject* NewMessageEnumsSeq(ParentDescriptor descriptor) { + return descriptor::NewSequence(&enums::ContainerDef, descriptor); +} + +namespace enumvalues { + 
+// This is the "enum_values_by_name" mapping, which collects values from all +// enum types in a message. +// +// Note that the behavior of the C++ descriptor is different: it will search and +// return the first value that matches the name, whereas the Python +// implementation retrieves the last one. + +typedef const EnumValueDescriptor* ItemDescriptor; + +static int Count(PyContainer* self) { + int count = 0; + for (int i = 0; i < GetDescriptor(self)->enum_type_count(); ++i) { + count += GetDescriptor(self)->enum_type(i)->value_count(); + } + return count; +} + +static ItemDescriptor GetByName(PyContainer* self, const string& name) { + return GetDescriptor(self)->FindEnumValueByName(name); +} + +static ItemDescriptor GetByIndex(PyContainer* self, int index) { + // This is not optimal, but the number of enums *types* in a given message + // is small. This function is only used when iterating over the mapping. + const EnumDescriptor* enum_type = NULL; + int enum_type_count = GetDescriptor(self)->enum_type_count(); + for (int i = 0; i < enum_type_count; ++i) { + enum_type = GetDescriptor(self)->enum_type(i); + int enum_value_count = enum_type->value_count(); + if (index < enum_value_count) { + // Found it! + break; + } + index -= enum_value_count; + } + // The next statement cannot overflow, because this function is only called by + // internal iterators which ensure that 0 <= index < Count(). 
+ return enum_type->value(index); +} + +static PyObject* NewObjectFromItem(ItemDescriptor item) { + return PyEnumValueDescriptor_FromDescriptor(item); +} + +static const string& GetItemName(ItemDescriptor item) { + return item->name(); +} + +static DescriptorContainerDef ContainerDef = { + "MessageEnumValues", + (CountMethod)Count, + (GetByIndexMethod)GetByIndex, + (GetByNameMethod)GetByName, + (GetByCamelcaseNameMethod)NULL, + (GetByNumberMethod)NULL, + (NewObjectFromItemMethod)NewObjectFromItem, + (GetItemNameMethod)GetItemName, + (GetItemCamelcaseNameMethod)NULL, + (GetItemNumberMethod)NULL, + (GetItemIndexMethod)NULL, +}; + +} // namespace enumvalues + +PyObject* NewMessageEnumValuesByName(ParentDescriptor descriptor) { + return descriptor::NewMappingByName(&enumvalues::ContainerDef, descriptor); +} + +namespace extensions { + +typedef const FieldDescriptor* ItemDescriptor; + +static int Count(PyContainer* self) { + return GetDescriptor(self)->extension_count(); +} + +static ItemDescriptor GetByName(PyContainer* self, const string& name) { + return GetDescriptor(self)->FindExtensionByName(name); +} + +static ItemDescriptor GetByIndex(PyContainer* self, int index) { + return GetDescriptor(self)->extension(index); +} + +static PyObject* NewObjectFromItem(ItemDescriptor item) { + return PyFieldDescriptor_FromDescriptor(item); +} + +static const string& GetItemName(ItemDescriptor item) { + return item->name(); +} + +static int GetItemIndex(ItemDescriptor item) { + return item->index(); +} + +static DescriptorContainerDef ContainerDef = { + "MessageExtensions", + (CountMethod)Count, + (GetByIndexMethod)GetByIndex, + (GetByNameMethod)GetByName, + (GetByCamelcaseNameMethod)NULL, + (GetByNumberMethod)NULL, + (NewObjectFromItemMethod)NewObjectFromItem, + (GetItemNameMethod)GetItemName, + (GetItemCamelcaseNameMethod)NULL, + (GetItemNumberMethod)NULL, + (GetItemIndexMethod)GetItemIndex, +}; + +} // namespace extensions + +PyObject* 
NewMessageExtensionsByName(ParentDescriptor descriptor) { + return descriptor::NewMappingByName(&extensions::ContainerDef, descriptor); +} + +PyObject* NewMessageExtensionsSeq(ParentDescriptor descriptor) { + return descriptor::NewSequence(&extensions::ContainerDef, descriptor); +} + +namespace oneofs { + +typedef const OneofDescriptor* ItemDescriptor; + +static int Count(PyContainer* self) { + return GetDescriptor(self)->oneof_decl_count(); +} + +static ItemDescriptor GetByName(PyContainer* self, const string& name) { + return GetDescriptor(self)->FindOneofByName(name); +} + +static ItemDescriptor GetByIndex(PyContainer* self, int index) { + return GetDescriptor(self)->oneof_decl(index); +} + +static PyObject* NewObjectFromItem(ItemDescriptor item) { + return PyOneofDescriptor_FromDescriptor(item); +} + +static const string& GetItemName(ItemDescriptor item) { + return item->name(); +} + +static int GetItemIndex(ItemDescriptor item) { + return item->index(); +} + +static DescriptorContainerDef ContainerDef = { + "MessageOneofs", + (CountMethod)Count, + (GetByIndexMethod)GetByIndex, + (GetByNameMethod)GetByName, + (GetByCamelcaseNameMethod)NULL, + (GetByNumberMethod)NULL, + (NewObjectFromItemMethod)NewObjectFromItem, + (GetItemNameMethod)GetItemName, + (GetItemCamelcaseNameMethod)NULL, + (GetItemNumberMethod)NULL, + (GetItemIndexMethod)GetItemIndex, +}; + +} // namespace oneofs + +PyObject* NewMessageOneofsByName(ParentDescriptor descriptor) { + return descriptor::NewMappingByName(&oneofs::ContainerDef, descriptor); +} + +PyObject* NewMessageOneofsSeq(ParentDescriptor descriptor) { + return descriptor::NewSequence(&oneofs::ContainerDef, descriptor); +} + +} // namespace message_descriptor + +namespace enum_descriptor { + +typedef const EnumDescriptor* ParentDescriptor; + +static ParentDescriptor GetDescriptor(PyContainer* self) { + return reinterpret_cast<ParentDescriptor>(self->descriptor); +} + +namespace enumvalues { + +typedef const EnumValueDescriptor* 
ItemDescriptor; + +static int Count(PyContainer* self) { + return GetDescriptor(self)->value_count(); +} + +static ItemDescriptor GetByIndex(PyContainer* self, int index) { + return GetDescriptor(self)->value(index); +} + +static ItemDescriptor GetByName(PyContainer* self, const string& name) { + return GetDescriptor(self)->FindValueByName(name); +} + +static ItemDescriptor GetByNumber(PyContainer* self, int number) { + return GetDescriptor(self)->FindValueByNumber(number); +} + +static PyObject* NewObjectFromItem(ItemDescriptor item) { + return PyEnumValueDescriptor_FromDescriptor(item); +} + +static const string& GetItemName(ItemDescriptor item) { + return item->name(); +} + +static int GetItemNumber(ItemDescriptor item) { + return item->number(); +} + +static int GetItemIndex(ItemDescriptor item) { + return item->index(); +} + +static DescriptorContainerDef ContainerDef = { + "EnumValues", + (CountMethod)Count, + (GetByIndexMethod)GetByIndex, + (GetByNameMethod)GetByName, + (GetByCamelcaseNameMethod)NULL, + (GetByNumberMethod)GetByNumber, + (NewObjectFromItemMethod)NewObjectFromItem, + (GetItemNameMethod)GetItemName, + (GetItemCamelcaseNameMethod)NULL, + (GetItemNumberMethod)GetItemNumber, + (GetItemIndexMethod)GetItemIndex, +}; + +} // namespace enumvalues + +PyObject* NewEnumValuesByName(ParentDescriptor descriptor) { + return descriptor::NewMappingByName(&enumvalues::ContainerDef, descriptor); +} + +PyObject* NewEnumValuesByNumber(ParentDescriptor descriptor) { + return descriptor::NewMappingByNumber(&enumvalues::ContainerDef, descriptor); +} + +PyObject* NewEnumValuesSeq(ParentDescriptor descriptor) { + return descriptor::NewSequence(&enumvalues::ContainerDef, descriptor); +} + +} // namespace enum_descriptor + +namespace oneof_descriptor { + +typedef const OneofDescriptor* ParentDescriptor; + +static ParentDescriptor GetDescriptor(PyContainer* self) { + return reinterpret_cast<ParentDescriptor>(self->descriptor); +} + +namespace fields { + +typedef const 
FieldDescriptor* ItemDescriptor; + +static int Count(PyContainer* self) { + return GetDescriptor(self)->field_count(); +} + +static ItemDescriptor GetByIndex(PyContainer* self, int index) { + return GetDescriptor(self)->field(index); +} + +static PyObject* NewObjectFromItem(ItemDescriptor item) { + return PyFieldDescriptor_FromDescriptor(item); +} + +static int GetItemIndex(ItemDescriptor item) { + return item->index_in_oneof(); +} + +static DescriptorContainerDef ContainerDef = { + "OneofFields", + (CountMethod)Count, + (GetByIndexMethod)GetByIndex, + (GetByNameMethod)NULL, + (GetByCamelcaseNameMethod)NULL, + (GetByNumberMethod)NULL, + (NewObjectFromItemMethod)NewObjectFromItem, + (GetItemNameMethod)NULL, + (GetItemCamelcaseNameMethod)NULL, + (GetItemNumberMethod)NULL, + (GetItemIndexMethod)GetItemIndex, +}; + +} // namespace fields + +PyObject* NewOneofFieldsSeq(ParentDescriptor descriptor) { + return descriptor::NewSequence(&fields::ContainerDef, descriptor); +} + +} // namespace oneof_descriptor + +namespace file_descriptor { + +typedef const FileDescriptor* ParentDescriptor; + +static ParentDescriptor GetDescriptor(PyContainer* self) { + return reinterpret_cast<ParentDescriptor>(self->descriptor); +} + +namespace messages { + +typedef const Descriptor* ItemDescriptor; + +static int Count(PyContainer* self) { + return GetDescriptor(self)->message_type_count(); +} + +static ItemDescriptor GetByName(PyContainer* self, const string& name) { + return GetDescriptor(self)->FindMessageTypeByName(name); +} + +static ItemDescriptor GetByIndex(PyContainer* self, int index) { + return GetDescriptor(self)->message_type(index); +} + +static PyObject* NewObjectFromItem(ItemDescriptor item) { + return PyMessageDescriptor_FromDescriptor(item); +} + +static const string& GetItemName(ItemDescriptor item) { + return item->name(); +} + +static int GetItemIndex(ItemDescriptor item) { + return item->index(); +} + +static DescriptorContainerDef ContainerDef = { + "FileMessages", + 
(CountMethod)Count, + (GetByIndexMethod)GetByIndex, + (GetByNameMethod)GetByName, + (GetByCamelcaseNameMethod)NULL, + (GetByNumberMethod)NULL, + (NewObjectFromItemMethod)NewObjectFromItem, + (GetItemNameMethod)GetItemName, + (GetItemCamelcaseNameMethod)NULL, + (GetItemNumberMethod)NULL, + (GetItemIndexMethod)GetItemIndex, +}; + +} // namespace messages + +PyObject* NewFileMessageTypesByName(const FileDescriptor* descriptor) { + return descriptor::NewMappingByName(&messages::ContainerDef, descriptor); +} + +namespace enums { + +typedef const EnumDescriptor* ItemDescriptor; + +static int Count(PyContainer* self) { + return GetDescriptor(self)->enum_type_count(); +} + +static ItemDescriptor GetByName(PyContainer* self, const string& name) { + return GetDescriptor(self)->FindEnumTypeByName(name); +} + +static ItemDescriptor GetByIndex(PyContainer* self, int index) { + return GetDescriptor(self)->enum_type(index); +} + +static PyObject* NewObjectFromItem(ItemDescriptor item) { + return PyEnumDescriptor_FromDescriptor(item); +} + +static const string& GetItemName(ItemDescriptor item) { + return item->name(); +} + +static int GetItemIndex(ItemDescriptor item) { + return item->index(); +} + +static DescriptorContainerDef ContainerDef = { + "FileEnums", + (CountMethod)Count, + (GetByIndexMethod)GetByIndex, + (GetByNameMethod)GetByName, + (GetByCamelcaseNameMethod)NULL, + (GetByNumberMethod)NULL, + (NewObjectFromItemMethod)NewObjectFromItem, + (GetItemNameMethod)GetItemName, + (GetItemCamelcaseNameMethod)NULL, + (GetItemNumberMethod)NULL, + (GetItemIndexMethod)GetItemIndex, +}; + +} // namespace enums + +PyObject* NewFileEnumTypesByName(const FileDescriptor* descriptor) { + return descriptor::NewMappingByName(&enums::ContainerDef, descriptor); +} + +namespace extensions { + +typedef const FieldDescriptor* ItemDescriptor; + +static int Count(PyContainer* self) { + return GetDescriptor(self)->extension_count(); +} + +static ItemDescriptor GetByName(PyContainer* self, const 
string& name) { + return GetDescriptor(self)->FindExtensionByName(name); +} + +static ItemDescriptor GetByIndex(PyContainer* self, int index) { + return GetDescriptor(self)->extension(index); +} + +static PyObject* NewObjectFromItem(ItemDescriptor item) { + return PyFieldDescriptor_FromDescriptor(item); +} + +static const string& GetItemName(ItemDescriptor item) { + return item->name(); +} + +static int GetItemIndex(ItemDescriptor item) { + return item->index(); +} + +static DescriptorContainerDef ContainerDef = { + "FileExtensions", + (CountMethod)Count, + (GetByIndexMethod)GetByIndex, + (GetByNameMethod)GetByName, + (GetByCamelcaseNameMethod)NULL, + (GetByNumberMethod)NULL, + (NewObjectFromItemMethod)NewObjectFromItem, + (GetItemNameMethod)GetItemName, + (GetItemCamelcaseNameMethod)NULL, + (GetItemNumberMethod)NULL, + (GetItemIndexMethod)GetItemIndex, +}; + +} // namespace extensions + +PyObject* NewFileExtensionsByName(const FileDescriptor* descriptor) { + return descriptor::NewMappingByName(&extensions::ContainerDef, descriptor); +} + +namespace dependencies { + +typedef const FileDescriptor* ItemDescriptor; + +static int Count(PyContainer* self) { + return GetDescriptor(self)->dependency_count(); +} + +static ItemDescriptor GetByIndex(PyContainer* self, int index) { + return GetDescriptor(self)->dependency(index); +} + +static PyObject* NewObjectFromItem(ItemDescriptor item) { + return PyFileDescriptor_FromDescriptor(item); +} + +static DescriptorContainerDef ContainerDef = { + "FileDependencies", + (CountMethod)Count, + (GetByIndexMethod)GetByIndex, + (GetByNameMethod)NULL, + (GetByCamelcaseNameMethod)NULL, + (GetByNumberMethod)NULL, + (NewObjectFromItemMethod)NewObjectFromItem, + (GetItemNameMethod)NULL, + (GetItemCamelcaseNameMethod)NULL, + (GetItemNumberMethod)NULL, + (GetItemIndexMethod)NULL, +}; + +} // namespace dependencies + +PyObject* NewFileDependencies(const FileDescriptor* descriptor) { + return descriptor::NewSequence(&dependencies::ContainerDef, 
descriptor); +} + +namespace public_dependencies { + +typedef const FileDescriptor* ItemDescriptor; + +static int Count(PyContainer* self) { + return GetDescriptor(self)->public_dependency_count(); +} + +static ItemDescriptor GetByIndex(PyContainer* self, int index) { + return GetDescriptor(self)->public_dependency(index); +} + +static PyObject* NewObjectFromItem(ItemDescriptor item) { + return PyFileDescriptor_FromDescriptor(item); +} + +static DescriptorContainerDef ContainerDef = { + "FilePublicDependencies", + (CountMethod)Count, + (GetByIndexMethod)GetByIndex, + (GetByNameMethod)NULL, + (GetByCamelcaseNameMethod)NULL, + (GetByNumberMethod)NULL, + (NewObjectFromItemMethod)NewObjectFromItem, + (GetItemNameMethod)NULL, + (GetItemCamelcaseNameMethod)NULL, + (GetItemNumberMethod)NULL, + (GetItemIndexMethod)NULL, +}; + +} // namespace public_dependencies + +PyObject* NewFilePublicDependencies(const FileDescriptor* descriptor) { + return descriptor::NewSequence(&public_dependencies::ContainerDef, + descriptor); +} + +} // namespace file_descriptor + + +// Register all implementations + +bool InitDescriptorMappingTypes() { + if (PyType_Ready(&descriptor::DescriptorMapping_Type) < 0) + return false; + if (PyType_Ready(&descriptor::DescriptorSequence_Type) < 0) + return false; + if (PyType_Ready(&descriptor::ContainerIterator_Type) < 0) + return false; + return true; +} + +} // namespace python +} // namespace protobuf +} // namespace google diff --git a/python/google/protobuf/pyext/descriptor_containers.h b/python/google/protobuf/pyext/descriptor_containers.h new file mode 100644 index 000000000..ce40747d5 --- /dev/null +++ b/python/google/protobuf/pyext/descriptor_containers.h @@ -0,0 +1,101 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. 
+// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_CONTAINERS_H__ +#define GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_CONTAINERS_H__ + +// Mappings and Sequences of descriptors. +// They implement containers like fields_by_name, EnumDescriptor.values... +// See descriptor_containers.cc for more description. 
+#include <Python.h> + +namespace google { +namespace protobuf { + +class Descriptor; +class FileDescriptor; +class EnumDescriptor; +class OneofDescriptor; + +namespace python { + +// Initialize the various types and objects. +bool InitDescriptorMappingTypes(); + +// Each function below returns a Mapping, or a Sequence of descriptors. +// They all return a new reference. + +namespace message_descriptor { +PyObject* NewMessageFieldsByName(const Descriptor* descriptor); +PyObject* NewMessageFieldsByCamelcaseName(const Descriptor* descriptor); +PyObject* NewMessageFieldsByNumber(const Descriptor* descriptor); +PyObject* NewMessageFieldsSeq(const Descriptor* descriptor); + +PyObject* NewMessageNestedTypesSeq(const Descriptor* descriptor); +PyObject* NewMessageNestedTypesByName(const Descriptor* descriptor); + +PyObject* NewMessageEnumsByName(const Descriptor* descriptor); +PyObject* NewMessageEnumsSeq(const Descriptor* descriptor); +PyObject* NewMessageEnumValuesByName(const Descriptor* descriptor); + +PyObject* NewMessageExtensionsByName(const Descriptor* descriptor); +PyObject* NewMessageExtensionsSeq(const Descriptor* descriptor); + +PyObject* NewMessageOneofsByName(const Descriptor* descriptor); +PyObject* NewMessageOneofsSeq(const Descriptor* descriptor); +} // namespace message_descriptor + +namespace enum_descriptor { +PyObject* NewEnumValuesByName(const EnumDescriptor* descriptor); +PyObject* NewEnumValuesByNumber(const EnumDescriptor* descriptor); +PyObject* NewEnumValuesSeq(const EnumDescriptor* descriptor); +} // namespace enum_descriptor + +namespace oneof_descriptor { +PyObject* NewOneofFieldsSeq(const OneofDescriptor* descriptor); +} // namespace oneof_descriptor + +namespace file_descriptor { +PyObject* NewFileMessageTypesByName(const FileDescriptor* descriptor); + +PyObject* NewFileEnumTypesByName(const FileDescriptor* descriptor); + +PyObject* NewFileExtensionsByName(const FileDescriptor* descriptor); + +PyObject* NewFileDependencies(const 
FileDescriptor* descriptor); +PyObject* NewFilePublicDependencies(const FileDescriptor* descriptor); +} // namespace file_descriptor + + +} // namespace python +} // namespace protobuf + +} // namespace google +#endif // GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_CONTAINERS_H__ diff --git a/python/google/protobuf/pyext/descriptor_cpp2_test.py b/python/google/protobuf/pyext/descriptor_cpp2_test.py deleted file mode 100644 index 3cf45a226..000000000 --- a/python/google/protobuf/pyext/descriptor_cpp2_test.py +++ /dev/null @@ -1,58 +0,0 @@ -#! /usr/bin/python -# -# Protocol Buffers - Google's data interchange format -# Copyright 2008 Google Inc. All rights reserved. -# https://developers.google.com/protocol-buffers/ -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are -# met: -# -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above -# copyright notice, this list of conditions and the following disclaimer -# in the documentation and/or other materials provided with the -# distribution. -# * Neither the name of Google Inc. nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -# A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT -# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -"""Tests for google.protobuf.pyext behavior.""" - -__author__ = 'anuraag@google.com (Anuraag Agrawal)' - -import os -os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'cpp' -os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION'] = '2' - -# We must set the implementation version above before the google3 imports. -# pylint: disable=g-import-not-at-top -from google.apputils import basetest -from google.protobuf.internal import api_implementation -# Run all tests from the original module by putting them in our namespace. -# pylint: disable=wildcard-import -from google.protobuf.internal.descriptor_test import * - - -class ConfirmCppApi2Test(basetest.TestCase): - - def testImplementationSetting(self): - self.assertEqual('cpp', api_implementation.Type()) - self.assertEqual(2, api_implementation.Version()) - - -if __name__ == '__main__': - basetest.main() diff --git a/python/google/protobuf/pyext/descriptor_database.cc b/python/google/protobuf/pyext/descriptor_database.cc new file mode 100644 index 000000000..daa40cc72 --- /dev/null +++ b/python/google/protobuf/pyext/descriptor_database.cc @@ -0,0 +1,148 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. 
+// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// This file defines a C++ DescriptorDatabase, which wraps a Python Database +// and delegate all its operations to Python methods. 
+ +#include <google/protobuf/pyext/descriptor_database.h> + +#include <google/protobuf/stubs/logging.h> +#include <google/protobuf/stubs/common.h> +#include <google/protobuf/descriptor.pb.h> +#include <google/protobuf/pyext/message.h> +#include <google/protobuf/pyext/scoped_pyobject_ptr.h> + +namespace google { +namespace protobuf { +namespace python { + +PyDescriptorDatabase::PyDescriptorDatabase(PyObject* py_database) + : py_database_(py_database) { + Py_INCREF(py_database_); +} + +PyDescriptorDatabase::~PyDescriptorDatabase() { Py_DECREF(py_database_); } + +// Convert a Python object to a FileDescriptorProto pointer. +// Handles all kinds of Python errors, which are simply logged. +static bool GetFileDescriptorProto(PyObject* py_descriptor, + FileDescriptorProto* output) { + if (py_descriptor == NULL) { + if (PyErr_ExceptionMatches(PyExc_KeyError)) { + // Expected error: item was simply not found. + PyErr_Clear(); + } else { + GOOGLE_LOG(ERROR) << "DescriptorDatabase method raised an error"; + PyErr_Print(); + } + return false; + } + if (py_descriptor == Py_None) { + return false; + } + const Descriptor* filedescriptor_descriptor = + FileDescriptorProto::default_instance().GetDescriptor(); + CMessage* message = reinterpret_cast<CMessage*>(py_descriptor); + if (PyObject_TypeCheck(py_descriptor, &CMessage_Type) && + message->message->GetDescriptor() == filedescriptor_descriptor) { + // Fast path: Just use the pointer. + FileDescriptorProto* file_proto = + static_cast<FileDescriptorProto*>(message->message); + *output = *file_proto; + return true; + } else { + // Slow path: serialize the message. This allows to use databases which + // use a different implementation of FileDescriptorProto. 
+ ScopedPyObjectPtr serialized_pb( + PyObject_CallMethod(py_descriptor, "SerializeToString", NULL)); + if (serialized_pb == NULL) { + GOOGLE_LOG(ERROR) + << "DescriptorDatabase method did not return a FileDescriptorProto"; + PyErr_Print(); + return false; + } + char* str; + Py_ssize_t len; + if (PyBytes_AsStringAndSize(serialized_pb.get(), &str, &len) < 0) { + GOOGLE_LOG(ERROR) + << "DescriptorDatabase method did not return a FileDescriptorProto"; + PyErr_Print(); + return false; + } + FileDescriptorProto file_proto; + if (!file_proto.ParseFromArray(str, len)) { + GOOGLE_LOG(ERROR) + << "DescriptorDatabase method did not return a FileDescriptorProto"; + return false; + } + *output = file_proto; + return true; + } +} + +// Find a file by file name. +bool PyDescriptorDatabase::FindFileByName(const string& filename, + FileDescriptorProto* output) { + ScopedPyObjectPtr py_descriptor(PyObject_CallMethod( + py_database_, "FindFileByName", "s#", filename.c_str(), filename.size())); + return GetFileDescriptorProto(py_descriptor.get(), output); +} + +// Find the file that declares the given fully-qualified symbol name. +bool PyDescriptorDatabase::FindFileContainingSymbol( + const string& symbol_name, FileDescriptorProto* output) { + ScopedPyObjectPtr py_descriptor( + PyObject_CallMethod(py_database_, "FindFileContainingSymbol", "s#", + symbol_name.c_str(), symbol_name.size())); + return GetFileDescriptorProto(py_descriptor.get(), output); +} + +// Find the file which defines an extension extending the given message type +// with the given field number. +// Python DescriptorDatabases are not required to implement this method. +bool PyDescriptorDatabase::FindFileContainingExtension( + const string& containing_type, int field_number, + FileDescriptorProto* output) { + ScopedPyObjectPtr py_method( + PyObject_GetAttrString(py_database_, "FindFileContainingExtension")); + if (py_method == NULL) { + // This method is not implemented, returns without error. 
+ PyErr_Clear(); + return false; + } + ScopedPyObjectPtr py_descriptor( + PyObject_CallFunction(py_method.get(), "s#i", containing_type.c_str(), + containing_type.size(), field_number)); + return GetFileDescriptorProto(py_descriptor.get(), output); +} + +} // namespace python +} // namespace protobuf +} // namespace google diff --git a/python/google/protobuf/pyext/descriptor_database.h b/python/google/protobuf/pyext/descriptor_database.h new file mode 100644 index 000000000..fc71c4bcb --- /dev/null +++ b/python/google/protobuf/pyext/descriptor_database.h @@ -0,0 +1,75 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_DATABASE_H__ +#define GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_DATABASE_H__ + +#include <Python.h> + +#include <google/protobuf/descriptor_database.h> + +namespace google { +namespace protobuf { +namespace python { + +class PyDescriptorDatabase : public DescriptorDatabase { + public: + explicit PyDescriptorDatabase(PyObject* py_database); + ~PyDescriptorDatabase(); + + // Implement the abstract interface. All these functions fill the output + // with a copy of FileDescriptorProto. + + // Find a file by file name. + bool FindFileByName(const string& filename, + FileDescriptorProto* output); + + // Find the file that declares the given fully-qualified symbol name. + bool FindFileContainingSymbol(const string& symbol_name, + FileDescriptorProto* output); + + // Find the file which defines an extension extending the given message type + // with the given field number. + // Containing_type must be a fully-qualified type name. + // Python objects are not required to implement this method. + bool FindFileContainingExtension(const string& containing_type, + int field_number, + FileDescriptorProto* output); + + private: + // The python object that implements the database. The reference is owned. 
+ PyObject* py_database_; +}; + +} // namespace python +} // namespace protobuf + +} // namespace google +#endif // GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_DATABASE_H__ diff --git a/python/google/protobuf/pyext/descriptor_pool.cc b/python/google/protobuf/pyext/descriptor_pool.cc new file mode 100644 index 000000000..1faff96bc --- /dev/null +++ b/python/google/protobuf/pyext/descriptor_pool.cc @@ -0,0 +1,593 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Implements the DescriptorPool, which collects all descriptors. + +#include <Python.h> + +#include <google/protobuf/descriptor.pb.h> +#include <google/protobuf/dynamic_message.h> +#include <google/protobuf/pyext/descriptor.h> +#include <google/protobuf/pyext/descriptor_database.h> +#include <google/protobuf/pyext/descriptor_pool.h> +#include <google/protobuf/pyext/message.h> +#include <google/protobuf/pyext/scoped_pyobject_ptr.h> + +#if PY_MAJOR_VERSION >= 3 + #define PyString_FromStringAndSize PyUnicode_FromStringAndSize + #if PY_VERSION_HEX < 0x03030000 + #error "Python 3.0 - 3.2 are not supported." + #endif + #define PyString_AsStringAndSize(ob, charpp, sizep) \ + (PyUnicode_Check(ob)? \ + ((*(charpp) = PyUnicode_AsUTF8AndSize(ob, (sizep))) == NULL? -1: 0): \ + PyBytes_AsStringAndSize(ob, (charpp), (sizep))) +#endif + +namespace google { +namespace protobuf { +namespace python { + +// A map to cache Python Pools per C++ pointer. +// Pointers are not owned here, and belong to the PyDescriptorPool. +static hash_map<const DescriptorPool*, PyDescriptorPool*> descriptor_pool_map; + +namespace cdescriptor_pool { + +// Create a Python DescriptorPool object, but does not fill the "pool" +// attribute. 
+static PyDescriptorPool* _CreateDescriptorPool() { + PyDescriptorPool* cpool = PyObject_New( + PyDescriptorPool, &PyDescriptorPool_Type); + if (cpool == NULL) { + return NULL; + } + + cpool->underlay = NULL; + cpool->database = NULL; + + DynamicMessageFactory* message_factory = new DynamicMessageFactory(); + // This option might be the default some day. + message_factory->SetDelegateToGeneratedFactory(true); + cpool->message_factory = message_factory; + + // TODO(amauryfa): Rewrite the SymbolDatabase in C so that it uses the same + // storage. + cpool->classes_by_descriptor = + new PyDescriptorPool::ClassesByMessageMap(); + cpool->descriptor_options = + new hash_map<const void*, PyObject *>(); + + return cpool; +} + +// Create a Python DescriptorPool, using the given pool as an underlay: +// new messages will be added to a custom pool, not to the underlay. +// +// Ownership of the underlay is not transferred, its pointer should +// stay alive. +static PyDescriptorPool* PyDescriptorPool_NewWithUnderlay( + const DescriptorPool* underlay) { + PyDescriptorPool* cpool = _CreateDescriptorPool(); + if (cpool == NULL) { + return NULL; + } + cpool->pool = new DescriptorPool(underlay); + cpool->underlay = underlay; + + if (!descriptor_pool_map.insert( + std::make_pair(cpool->pool, cpool)).second) { + // Should never happen -- would indicate an internal error / bug. + PyErr_SetString(PyExc_ValueError, "DescriptorPool already registered"); + return NULL; + } + + return cpool; +} + +static PyDescriptorPool* PyDescriptorPool_NewWithDatabase( + DescriptorDatabase* database) { + PyDescriptorPool* cpool = _CreateDescriptorPool(); + if (cpool == NULL) { + return NULL; + } + if (database != NULL) { + cpool->pool = new DescriptorPool(database); + cpool->database = database; + } else { + cpool->pool = new DescriptorPool(); + } + + if (!descriptor_pool_map.insert(std::make_pair(cpool->pool, cpool)).second) { + // Should never happen -- would indicate an internal error / bug. 
+ PyErr_SetString(PyExc_ValueError, "DescriptorPool already registered"); + return NULL; + } + + return cpool; +} + +// The public DescriptorPool constructor. +static PyObject* New(PyTypeObject* type, + PyObject* args, PyObject* kwargs) { + static char* kwlist[] = {"descriptor_db", 0}; + PyObject* py_database = NULL; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O", kwlist, &py_database)) { + return NULL; + } + DescriptorDatabase* database = NULL; + if (py_database && py_database != Py_None) { + database = new PyDescriptorDatabase(py_database); + } + return reinterpret_cast<PyObject*>( + PyDescriptorPool_NewWithDatabase(database)); +} + +static void Dealloc(PyDescriptorPool* self) { + typedef PyDescriptorPool::ClassesByMessageMap::iterator iterator; + descriptor_pool_map.erase(self->pool); + for (iterator it = self->classes_by_descriptor->begin(); + it != self->classes_by_descriptor->end(); ++it) { + Py_DECREF(it->second); + } + delete self->classes_by_descriptor; + for (hash_map<const void*, PyObject*>::iterator it = + self->descriptor_options->begin(); + it != self->descriptor_options->end(); ++it) { + Py_DECREF(it->second); + } + delete self->descriptor_options; + delete self->message_factory; + delete self->database; + delete self->pool; + Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self)); +} + +PyObject* FindMessageByName(PyDescriptorPool* self, PyObject* arg) { + Py_ssize_t name_size; + char* name; + if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) { + return NULL; + } + + const Descriptor* message_descriptor = + self->pool->FindMessageTypeByName(string(name, name_size)); + + if (message_descriptor == NULL) { + PyErr_Format(PyExc_KeyError, "Couldn't find message %.200s", name); + return NULL; + } + + return PyMessageDescriptor_FromDescriptor(message_descriptor); +} + +// Add a message class to our database. 
+int RegisterMessageClass(PyDescriptorPool* self, + const Descriptor* message_descriptor, + CMessageClass* message_class) { + Py_INCREF(message_class); + typedef PyDescriptorPool::ClassesByMessageMap::iterator iterator; + std::pair<iterator, bool> ret = self->classes_by_descriptor->insert( + std::make_pair(message_descriptor, message_class)); + if (!ret.second) { + // Update case: DECREF the previous value. + Py_DECREF(ret.first->second); + ret.first->second = message_class; + } + return 0; +} + +// Retrieve the message class added to our database. +CMessageClass* GetMessageClass(PyDescriptorPool* self, + const Descriptor* message_descriptor) { + typedef PyDescriptorPool::ClassesByMessageMap::iterator iterator; + iterator ret = self->classes_by_descriptor->find(message_descriptor); + if (ret == self->classes_by_descriptor->end()) { + PyErr_Format(PyExc_TypeError, "No message class registered for '%s'", + message_descriptor->full_name().c_str()); + return NULL; + } else { + return ret->second; + } +} + +PyObject* FindFileByName(PyDescriptorPool* self, PyObject* arg) { + Py_ssize_t name_size; + char* name; + if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) { + return NULL; + } + + const FileDescriptor* file_descriptor = + self->pool->FindFileByName(string(name, name_size)); + if (file_descriptor == NULL) { + PyErr_Format(PyExc_KeyError, "Couldn't find file %.200s", + name); + return NULL; + } + + return PyFileDescriptor_FromDescriptor(file_descriptor); +} + +PyObject* FindFieldByName(PyDescriptorPool* self, PyObject* arg) { + Py_ssize_t name_size; + char* name; + if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) { + return NULL; + } + + const FieldDescriptor* field_descriptor = + self->pool->FindFieldByName(string(name, name_size)); + if (field_descriptor == NULL) { + PyErr_Format(PyExc_KeyError, "Couldn't find field %.200s", + name); + return NULL; + } + + return PyFieldDescriptor_FromDescriptor(field_descriptor); +} + +PyObject* 
FindExtensionByName(PyDescriptorPool* self, PyObject* arg) { + Py_ssize_t name_size; + char* name; + if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) { + return NULL; + } + + const FieldDescriptor* field_descriptor = + self->pool->FindExtensionByName(string(name, name_size)); + if (field_descriptor == NULL) { + PyErr_Format(PyExc_KeyError, "Couldn't find extension field %.200s", name); + return NULL; + } + + return PyFieldDescriptor_FromDescriptor(field_descriptor); +} + +PyObject* FindEnumTypeByName(PyDescriptorPool* self, PyObject* arg) { + Py_ssize_t name_size; + char* name; + if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) { + return NULL; + } + + const EnumDescriptor* enum_descriptor = + self->pool->FindEnumTypeByName(string(name, name_size)); + if (enum_descriptor == NULL) { + PyErr_Format(PyExc_KeyError, "Couldn't find enum %.200s", name); + return NULL; + } + + return PyEnumDescriptor_FromDescriptor(enum_descriptor); +} + +PyObject* FindOneofByName(PyDescriptorPool* self, PyObject* arg) { + Py_ssize_t name_size; + char* name; + if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) { + return NULL; + } + + const OneofDescriptor* oneof_descriptor = + self->pool->FindOneofByName(string(name, name_size)); + if (oneof_descriptor == NULL) { + PyErr_Format(PyExc_KeyError, "Couldn't find oneof %.200s", name); + return NULL; + } + + return PyOneofDescriptor_FromDescriptor(oneof_descriptor); +} + +PyObject* FindFileContainingSymbol(PyDescriptorPool* self, PyObject* arg) { + Py_ssize_t name_size; + char* name; + if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) { + return NULL; + } + + const FileDescriptor* file_descriptor = + self->pool->FindFileContainingSymbol(string(name, name_size)); + if (file_descriptor == NULL) { + PyErr_Format(PyExc_KeyError, "Couldn't find symbol %.200s", name); + return NULL; + } + + return PyFileDescriptor_FromDescriptor(file_descriptor); +} + +// These functions should not exist -- the only valid way to 
create +// descriptors is to call Add() or AddSerializedFile(). +// But these AddDescriptor() functions were created in Python and some people +// call them, so we support them for now for compatibility. +// However we do check that the existing descriptor already exists in the pool, +// which appears to always be true for existing calls -- but then why do people +// call a function that will just be a no-op? +// TODO(amauryfa): Need to investigate further. + +PyObject* AddFileDescriptor(PyDescriptorPool* self, PyObject* descriptor) { + const FileDescriptor* file_descriptor = + PyFileDescriptor_AsDescriptor(descriptor); + if (!file_descriptor) { + return NULL; + } + if (file_descriptor != + self->pool->FindFileByName(file_descriptor->name())) { + PyErr_Format(PyExc_ValueError, + "The file descriptor %s does not belong to this pool", + file_descriptor->name().c_str()); + return NULL; + } + Py_RETURN_NONE; +} + +PyObject* AddDescriptor(PyDescriptorPool* self, PyObject* descriptor) { + const Descriptor* message_descriptor = + PyMessageDescriptor_AsDescriptor(descriptor); + if (!message_descriptor) { + return NULL; + } + if (message_descriptor != + self->pool->FindMessageTypeByName(message_descriptor->full_name())) { + PyErr_Format(PyExc_ValueError, + "The message descriptor %s does not belong to this pool", + message_descriptor->full_name().c_str()); + return NULL; + } + Py_RETURN_NONE; +} + +PyObject* AddEnumDescriptor(PyDescriptorPool* self, PyObject* descriptor) { + const EnumDescriptor* enum_descriptor = + PyEnumDescriptor_AsDescriptor(descriptor); + if (!enum_descriptor) { + return NULL; + } + if (enum_descriptor != + self->pool->FindEnumTypeByName(enum_descriptor->full_name())) { + PyErr_Format(PyExc_ValueError, + "The enum descriptor %s does not belong to this pool", + enum_descriptor->full_name().c_str()); + return NULL; + } + Py_RETURN_NONE; +} + +// The code below loads new Descriptors from a serialized FileDescriptorProto. 
+ + +// Collects errors that occur during proto file building to allow them to be +// propagated in the python exception instead of only living in ERROR logs. +class BuildFileErrorCollector : public DescriptorPool::ErrorCollector { + public: + BuildFileErrorCollector() : error_message(""), had_errors(false) {} + + void AddError(const string& filename, const string& element_name, + const Message* descriptor, ErrorLocation location, + const string& message) { + // Replicates the logging behavior that happens in the C++ implementation + // when an error collector is not passed in. + if (!had_errors) { + error_message += + ("Invalid proto descriptor for file \"" + filename + "\":\n"); + had_errors = true; + } + // As this only happens on failure and will result in the program not + // running at all, no effort is made to optimize this string manipulation. + error_message += (" " + element_name + ": " + message + "\n"); + } + + string error_message; + bool had_errors; +}; + +PyObject* AddSerializedFile(PyDescriptorPool* self, PyObject* serialized_pb) { + char* message_type; + Py_ssize_t message_len; + + if (self->database != NULL) { + PyErr_SetString( + PyExc_ValueError, + "Cannot call Add on a DescriptorPool that uses a DescriptorDatabase. " + "Add your file to the underlying database."); + return NULL; + } + + if (PyBytes_AsStringAndSize(serialized_pb, &message_type, &message_len) < 0) { + return NULL; + } + + FileDescriptorProto file_proto; + if (!file_proto.ParseFromArray(message_type, message_len)) { + PyErr_SetString(PyExc_TypeError, "Couldn't parse file content!"); + return NULL; + } + + // If the file was already part of a C++ library, all its descriptors are in + // the underlying pool. No need to do anything else. 
+ const FileDescriptor* generated_file = NULL; + if (self->underlay) { + generated_file = self->underlay->FindFileByName(file_proto.name()); + } + if (generated_file != NULL) { + return PyFileDescriptor_FromDescriptorWithSerializedPb( + generated_file, serialized_pb); + } + + BuildFileErrorCollector error_collector; + const FileDescriptor* descriptor = + self->pool->BuildFileCollectingErrors(file_proto, + &error_collector); + if (descriptor == NULL) { + PyErr_Format(PyExc_TypeError, + "Couldn't build proto file into descriptor pool!\n%s", + error_collector.error_message.c_str()); + return NULL; + } + + return PyFileDescriptor_FromDescriptorWithSerializedPb( + descriptor, serialized_pb); +} + +PyObject* Add(PyDescriptorPool* self, PyObject* file_descriptor_proto) { + ScopedPyObjectPtr serialized_pb( + PyObject_CallMethod(file_descriptor_proto, "SerializeToString", NULL)); + if (serialized_pb == NULL) { + return NULL; + } + return AddSerializedFile(self, serialized_pb.get()); +} + +static PyMethodDef Methods[] = { + { "Add", (PyCFunction)Add, METH_O, + "Adds the FileDescriptorProto and its types to this pool." }, + { "AddSerializedFile", (PyCFunction)AddSerializedFile, METH_O, + "Adds a serialized FileDescriptorProto to this pool." }, + + // TODO(amauryfa): Understand why the Python implementation differs from + // this one, ask users to use another API and deprecate these functions. + { "AddFileDescriptor", (PyCFunction)AddFileDescriptor, METH_O, + "No-op. Add() must have been called before." }, + { "AddDescriptor", (PyCFunction)AddDescriptor, METH_O, + "No-op. Add() must have been called before." }, + { "AddEnumDescriptor", (PyCFunction)AddEnumDescriptor, METH_O, + "No-op. Add() must have been called before." }, + + { "FindFileByName", (PyCFunction)FindFileByName, METH_O, + "Searches for a file descriptor by its .proto name." }, + { "FindMessageTypeByName", (PyCFunction)FindMessageByName, METH_O, + "Searches for a message descriptor by full name." 
}, + { "FindFieldByName", (PyCFunction)FindFieldByName, METH_O, + "Searches for a field descriptor by full name." }, + { "FindExtensionByName", (PyCFunction)FindExtensionByName, METH_O, + "Searches for extension descriptor by full name." }, + { "FindEnumTypeByName", (PyCFunction)FindEnumTypeByName, METH_O, + "Searches for enum type descriptor by full name." }, + { "FindOneofByName", (PyCFunction)FindOneofByName, METH_O, + "Searches for oneof descriptor by full name." }, + + { "FindFileContainingSymbol", (PyCFunction)FindFileContainingSymbol, METH_O, + "Gets the FileDescriptor containing the specified symbol." }, + {NULL} +}; + +} // namespace cdescriptor_pool + +PyTypeObject PyDescriptorPool_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + FULL_MODULE_NAME ".DescriptorPool", // tp_name + sizeof(PyDescriptorPool), // tp_basicsize + 0, // tp_itemsize + (destructor)cdescriptor_pool::Dealloc, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + 0, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + "A Descriptor Pool", // tp_doc + 0, // tp_traverse + 0, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + 0, // tp_iter + 0, // tp_iternext + cdescriptor_pool::Methods, // tp_methods + 0, // tp_members + 0, // tp_getset + 0, // tp_base + 0, // tp_dict + 0, // tp_descr_get + 0, // tp_descr_set + 0, // tp_dictoffset + 0, // tp_init + 0, // tp_alloc + cdescriptor_pool::New, // tp_new + PyObject_Del, // tp_free +}; + +// This is the DescriptorPool which contains all the definitions from the +// generated _pb2.py modules. +static PyDescriptorPool* python_generated_pool = NULL; + +bool InitDescriptorPool() { + if (PyType_Ready(&PyDescriptorPool_Type) < 0) + return false; + + // The Pool of messages declared in Python libraries. 
+ // generated_pool() contains all messages already linked in C++ libraries, and + // is used as underlay. + python_generated_pool = cdescriptor_pool::PyDescriptorPool_NewWithUnderlay( + DescriptorPool::generated_pool()); + if (python_generated_pool == NULL) { + return false; + } + // Register this pool to be found for C++-generated descriptors. + descriptor_pool_map.insert( + std::make_pair(DescriptorPool::generated_pool(), + python_generated_pool)); + + return true; +} + +// The default DescriptorPool used everywhere in this module. +// Today it's the python_generated_pool. +// TODO(amauryfa): Remove all usages of this function: the pool should be +// derived from the context. +PyDescriptorPool* GetDefaultDescriptorPool() { + return python_generated_pool; +} + +PyDescriptorPool* GetDescriptorPool_FromPool(const DescriptorPool* pool) { + // Fast path for standard descriptors. + if (pool == python_generated_pool->pool || + pool == DescriptorPool::generated_pool()) { + return python_generated_pool; + } + hash_map<const DescriptorPool*, PyDescriptorPool*>::iterator it = + descriptor_pool_map.find(pool); + if (it == descriptor_pool_map.end()) { + PyErr_SetString(PyExc_KeyError, "Unknown descriptor pool"); + return NULL; + } + return it->second; +} + +} // namespace python +} // namespace protobuf +} // namespace google diff --git a/python/google/protobuf/pyext/descriptor_pool.h b/python/google/protobuf/pyext/descriptor_pool.h new file mode 100644 index 000000000..2a42c1126 --- /dev/null +++ b/python/google/protobuf/pyext/descriptor_pool.h @@ -0,0 +1,167 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. 
+// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_POOL_H__ +#define GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_POOL_H__ + +#include <Python.h> + +#include <google/protobuf/stubs/hash.h> +#include <google/protobuf/descriptor.h> + +namespace google { +namespace protobuf { +class MessageFactory; + +namespace python { + +// The (meta) type of all Messages classes. 
+struct CMessageClass; + +// Wraps operations to the global DescriptorPool which contains information +// about all messages and fields. +// +// There is normally one pool per process. We make it a Python object only +// because it contains many Python references. +// TODO(amauryfa): See whether such objects can appear in reference cycles, and +// consider adding support for the cyclic GC. +// +// "Methods" that interacts with this DescriptorPool are in the cdescriptor_pool +// namespace. +typedef struct PyDescriptorPool { + PyObject_HEAD + + // The C++ pool containing Descriptors. + DescriptorPool* pool; + + // The C++ pool acting as an underlay. Can be NULL. + // This pointer is not owned and must stay alive. + const DescriptorPool* underlay; + + // The C++ descriptor database used to fetch unknown protos. Can be NULL. + // This pointer is owned. + const DescriptorDatabase* database; + + // DynamicMessageFactory used to create C++ instances of messages. + // This object cache the descriptors that were used, so the DescriptorPool + // needs to get rid of it before it can delete itself. + // + // Note: A C++ MessageFactory is different from the Python MessageFactory. + // The C++ one creates messages, when the Python one creates classes. + MessageFactory* message_factory; + + // Make our own mapping to retrieve Python classes from C++ descriptors. + // + // Descriptor pointers stored here are owned by the DescriptorPool above. + // Python references to classes are owned by this PyDescriptorPool. + typedef hash_map<const Descriptor*, CMessageClass*> ClassesByMessageMap; + ClassesByMessageMap* classes_by_descriptor; + + // Cache the options for any kind of descriptor. + // Descriptor pointers are owned by the DescriptorPool above. + // Python objects are owned by the map. + hash_map<const void*, PyObject*>* descriptor_options; +} PyDescriptorPool; + + +extern PyTypeObject PyDescriptorPool_Type; + +namespace cdescriptor_pool { + +// Looks up a message by name. 
+// Returns a message Descriptor, or NULL if not found. +const Descriptor* FindMessageTypeByName(PyDescriptorPool* self, + const string& name); + +// Registers a new Python class for the given message descriptor. +// On error, returns -1 with a Python exception set. +int RegisterMessageClass(PyDescriptorPool* self, + const Descriptor* message_descriptor, + CMessageClass* message_class); + +// Retrieves the Python class registered with the given message descriptor. +// +// Returns a *borrowed* reference if found, otherwise returns NULL with an +// exception set. +CMessageClass* GetMessageClass(PyDescriptorPool* self, + const Descriptor* message_descriptor); + +// The functions below are also exposed as methods of the DescriptorPool type. + +// Looks up a message by name. Returns a PyMessageDescriptor corresponding to +// the field on success, or NULL on failure. +// +// Returns a new reference. +PyObject* FindMessageByName(PyDescriptorPool* self, PyObject* name); + +// Looks up a field by name. Returns a PyFieldDescriptor corresponding to +// the field on success, or NULL on failure. +// +// Returns a new reference. +PyObject* FindFieldByName(PyDescriptorPool* self, PyObject* name); + +// Looks up an extension by name. Returns a PyFieldDescriptor corresponding +// to the field on success, or NULL on failure. +// +// Returns a new reference. +PyObject* FindExtensionByName(PyDescriptorPool* self, PyObject* arg); + +// Looks up an enum type by name. Returns a PyEnumDescriptor corresponding +// to the field on success, or NULL on failure. +// +// Returns a new reference. +PyObject* FindEnumTypeByName(PyDescriptorPool* self, PyObject* arg); + +// Looks up a oneof by name. Returns a COneofDescriptor corresponding +// to the oneof on success, or NULL on failure. +// +// Returns a new reference. +PyObject* FindOneofByName(PyDescriptorPool* self, PyObject* arg); + +} // namespace cdescriptor_pool + +// Retrieve the global descriptor pool owned by the _message module. 
+// This is the one used by pb2.py generated modules. +// Returns a *borrowed* reference. +// "Default" pool used to register messages from _pb2.py modules. +PyDescriptorPool* GetDefaultDescriptorPool(); + +// Retrieve the python descriptor pool owning a C++ descriptor pool. +// Returns a *borrowed* reference. +PyDescriptorPool* GetDescriptorPool_FromPool(const DescriptorPool* pool); + +// Initialize objects used by this module. +bool InitDescriptorPool(); + +} // namespace python +} // namespace protobuf + +} // namespace google +#endif // GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_POOL_H__ diff --git a/python/google/protobuf/pyext/extension_dict.cc b/python/google/protobuf/pyext/extension_dict.cc index 3861c7943..21bbb8c2b 100644 --- a/python/google/protobuf/pyext/extension_dict.cc +++ b/python/google/protobuf/pyext/extension_dict.cc @@ -33,11 +33,13 @@ #include <google/protobuf/pyext/extension_dict.h> +#include <google/protobuf/stubs/logging.h> #include <google/protobuf/stubs/common.h> #include <google/protobuf/descriptor.h> #include <google/protobuf/dynamic_message.h> #include <google/protobuf/message.h> #include <google/protobuf/pyext/descriptor.h> +#include <google/protobuf/pyext/descriptor_pool.h> #include <google/protobuf/pyext/message.h> #include <google/protobuf/pyext/repeated_composite_container.h> #include <google/protobuf/pyext/repeated_scalar_container.h> @@ -48,36 +50,8 @@ namespace google { namespace protobuf { namespace python { -extern google::protobuf::DynamicMessageFactory* global_message_factory; - namespace extension_dict { -// TODO(tibell): Always use self->message for clarity, just like in -// RepeatedCompositeContainer. 
-static google::protobuf::Message* GetMessage(ExtensionDict* self) { - if (self->parent != NULL) { - return self->parent->message; - } else { - return self->message; - } -} - -CFieldDescriptor* InternalGetCDescriptorFromExtension(PyObject* extension) { - PyObject* cdescriptor = PyObject_GetAttrString(extension, "_cdescriptor"); - if (cdescriptor == NULL) { - PyErr_SetString(PyExc_KeyError, "Unregistered extension."); - return NULL; - } - if (!PyObject_TypeCheck(cdescriptor, &CFieldDescriptor_Type)) { - PyErr_SetString(PyExc_TypeError, "Not a CFieldDescriptor"); - Py_DECREF(cdescriptor); - return NULL; - } - CFieldDescriptor* descriptor = - reinterpret_cast<CFieldDescriptor*>(cdescriptor); - return descriptor; -} - PyObject* len(ExtensionDict* self) { #if PY_MAJOR_VERSION >= 3 return PyLong_FromLong(PyDict_Size(self->values)); @@ -89,10 +63,9 @@ PyObject* len(ExtensionDict* self) { // TODO(tibell): Use VisitCompositeField. int ReleaseExtension(ExtensionDict* self, PyObject* extension, - const google::protobuf::FieldDescriptor* descriptor) { - if (descriptor->label() == google::protobuf::FieldDescriptor::LABEL_REPEATED) { - if (descriptor->cpp_type() == - google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { + const FieldDescriptor* descriptor) { + if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) { + if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { if (repeated_composite_container::Release( reinterpret_cast<RepeatedCompositeContainer*>( extension)) < 0) { @@ -105,10 +78,9 @@ int ReleaseExtension(ExtensionDict* self, return -1; } } - } else if (descriptor->cpp_type() == - google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { + } else if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { if (cmessage::ReleaseSubMessage( - GetMessage(self), descriptor, + self->parent, descriptor, reinterpret_cast<CMessage*>(extension)) < 0) { return -1; } @@ -118,19 +90,17 @@ int ReleaseExtension(ExtensionDict* self, } PyObject* 
subscript(ExtensionDict* self, PyObject* key) { - CFieldDescriptor* cdescriptor = InternalGetCDescriptorFromExtension( - key); - if (cdescriptor == NULL) { + const FieldDescriptor* descriptor = cmessage::GetExtensionDescriptor(key); + if (descriptor == NULL) { return NULL; } - ScopedPyObjectPtr py_cdescriptor(reinterpret_cast<PyObject*>(cdescriptor)); - const google::protobuf::FieldDescriptor* descriptor = cdescriptor->descriptor; - if (descriptor == NULL) { + if (!CheckFieldBelongsToMessage(descriptor, self->message)) { return NULL; } + if (descriptor->label() != FieldDescriptor::LABEL_REPEATED && descriptor->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) { - return cmessage::InternalGetScalar(self->parent, descriptor); + return cmessage::InternalGetScalar(self->message, descriptor); } PyObject* value = PyDict_GetItem(self->values, key); @@ -139,10 +109,18 @@ PyObject* subscript(ExtensionDict* self, PyObject* key) { return value; } + if (self->parent == NULL) { + // We are in "detached" state. Don't allow further modifications. + // TODO(amauryfa): Support adding non-scalars to a detached extension dict. + // This probably requires to store the type of the main message. 
+ PyErr_SetObject(PyExc_KeyError, key); + return NULL; + } + if (descriptor->label() != FieldDescriptor::LABEL_REPEATED && descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { PyObject* sub_message = cmessage::InternalGetSubMessage( - self->parent, cdescriptor); + self->parent, descriptor); if (sub_message == NULL) { return NULL; } @@ -152,33 +130,22 @@ PyObject* subscript(ExtensionDict* self, PyObject* key) { if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) { if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { - // COPIED - PyObject* py_container = PyObject_CallObject( - reinterpret_cast<PyObject*>(&RepeatedCompositeContainer_Type), - NULL); + CMessageClass* message_class = cdescriptor_pool::GetMessageClass( + cmessage::GetDescriptorPoolForMessage(self->parent), + descriptor->message_type()); + if (message_class == NULL) { + return NULL; + } + PyObject* py_container = repeated_composite_container::NewContainer( + self->parent, descriptor, message_class); if (py_container == NULL) { return NULL; } - RepeatedCompositeContainer* container = - reinterpret_cast<RepeatedCompositeContainer*>(py_container); - PyObject* field = cdescriptor->descriptor_field; - PyObject* message_type = PyObject_GetAttrString(field, "message_type"); - PyObject* concrete_class = PyObject_GetAttrString(message_type, - "_concrete_class"); - container->owner = self->owner; - container->parent = self->parent; - container->message = self->parent->message; - container->parent_field = cdescriptor; - container->subclass_init = concrete_class; - Py_DECREF(message_type); PyDict_SetItem(self->values, key, py_container); return py_container; } else { - // COPIED - ScopedPyObjectPtr init_args(PyTuple_Pack(2, self->parent, cdescriptor)); - PyObject* py_container = PyObject_CallObject( - reinterpret_cast<PyObject*>(&RepeatedScalarContainer_Type), - init_args); + PyObject* py_container = repeated_scalar_container::NewContainer( + self->parent, descriptor); if (py_container 
== NULL) { return NULL; } @@ -191,22 +158,25 @@ PyObject* subscript(ExtensionDict* self, PyObject* key) { } int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value) { - CFieldDescriptor* cdescriptor = InternalGetCDescriptorFromExtension( - key); - if (cdescriptor == NULL) { + const FieldDescriptor* descriptor = cmessage::GetExtensionDescriptor(key); + if (descriptor == NULL) { + return -1; + } + if (!CheckFieldBelongsToMessage(descriptor, self->message)) { return -1; } - ScopedPyObjectPtr py_cdescriptor(reinterpret_cast<PyObject*>(cdescriptor)); - const google::protobuf::FieldDescriptor* descriptor = cdescriptor->descriptor; + if (descriptor->label() != FieldDescriptor::LABEL_OPTIONAL || descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { PyErr_SetString(PyExc_TypeError, "Extension is repeated and/or composite " "type"); return -1; } - cmessage::AssureWritable(self->parent); - if (cmessage::InternalSetScalar(self->parent, descriptor, value) < 0) { - return -1; + if (self->parent) { + cmessage::AssureWritable(self->parent); + if (cmessage::InternalSetScalar(self->parent, descriptor, value) < 0) { + return -1; + } } // TODO(tibell): We shouldn't write scalars to the cache. 
PyDict_SetItem(self->values, key, value); @@ -214,22 +184,23 @@ int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value) { } PyObject* ClearExtension(ExtensionDict* self, PyObject* extension) { - CFieldDescriptor* cdescriptor = InternalGetCDescriptorFromExtension( - extension); - if (cdescriptor == NULL) { + const FieldDescriptor* descriptor = + cmessage::GetExtensionDescriptor(extension); + if (descriptor == NULL) { return NULL; } - ScopedPyObjectPtr py_cdescriptor(reinterpret_cast<PyObject*>(cdescriptor)); PyObject* value = PyDict_GetItem(self->values, extension); - if (value != NULL) { - if (ReleaseExtension(self, value, cdescriptor->descriptor) < 0) { + if (self->parent) { + if (value != NULL) { + if (ReleaseExtension(self, value, descriptor) < 0) { + return NULL; + } + } + if (ScopedPyObjectPtr(cmessage::ClearFieldByDescriptor( + self->parent, descriptor)) == NULL) { return NULL; } } - if (cmessage::ClearFieldByDescriptor(self->parent, - cdescriptor->descriptor) == NULL) { - return NULL; - } if (PyDict_DelItem(self->values, extension) < 0) { PyErr_Clear(); } @@ -237,15 +208,20 @@ PyObject* ClearExtension(ExtensionDict* self, PyObject* extension) { } PyObject* HasExtension(ExtensionDict* self, PyObject* extension) { - CFieldDescriptor* cdescriptor = InternalGetCDescriptorFromExtension( - extension); - if (cdescriptor == NULL) { + const FieldDescriptor* descriptor = + cmessage::GetExtensionDescriptor(extension); + if (descriptor == NULL) { return NULL; } - ScopedPyObjectPtr py_cdescriptor(reinterpret_cast<PyObject*>(cdescriptor)); - PyObject* result = cmessage::HasFieldByDescriptor( - self->parent, cdescriptor->descriptor); - return result; + if (self->parent) { + return cmessage::HasFieldByDescriptor(self->parent, descriptor); + } else { + int exists = PyDict_Contains(self->values, extension); + if (exists < 0) { + return NULL; + } + return PyBool_FromLong(exists); + } } PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* name) { @@ 
-254,7 +230,7 @@ PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* name) { if (extensions_by_name == NULL) { return NULL; } - PyObject* result = PyDict_GetItem(extensions_by_name, name); + PyObject* result = PyDict_GetItem(extensions_by_name.get(), name); if (result == NULL) { Py_RETURN_NONE; } else { @@ -263,11 +239,33 @@ PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* name) { } } -int init(ExtensionDict* self, PyObject* args, PyObject* kwargs) { - self->parent = NULL; - self->message = NULL; +PyObject* _FindExtensionByNumber(ExtensionDict* self, PyObject* number) { + ScopedPyObjectPtr extensions_by_number(PyObject_GetAttrString( + reinterpret_cast<PyObject*>(self->parent), "_extensions_by_number")); + if (extensions_by_number == NULL) { + return NULL; + } + PyObject* result = PyDict_GetItem(extensions_by_number.get(), number); + if (result == NULL) { + Py_RETURN_NONE; + } else { + Py_INCREF(result); + return result; + } +} + +ExtensionDict* NewExtensionDict(CMessage *parent) { + ExtensionDict* self = reinterpret_cast<ExtensionDict*>( + PyType_GenericAlloc(&ExtensionDict_Type, 0)); + if (self == NULL) { + return NULL; + } + + self->parent = parent; // Store a borrowed reference. + self->message = parent->message; + self->owner = parent->owner; self->values = PyDict_New(); - return 0; + return self; } void dealloc(ExtensionDict* self) { @@ -288,6 +286,8 @@ static PyMethodDef Methods[] = { EDMETHOD(HasExtension, METH_O, "Checks if the object has an extension."), EDMETHOD(_FindExtensionByName, METH_O, "Finds an extension by name."), + EDMETHOD(_FindExtensionByNumber, METH_O, + "Finds an extension by field number."), { NULL, NULL } }; @@ -295,8 +295,7 @@ static PyMethodDef Methods[] = { PyTypeObject ExtensionDict_Type = { PyVarObject_HEAD_INIT(&PyType_Type, 0) - "google.protobuf.internal." 
- "cpp._message.ExtensionDict", // tp_name + FULL_MODULE_NAME ".ExtensionDict", // tp_name sizeof(ExtensionDict), // tp_basicsize 0, // tp_itemsize (destructor)extension_dict::dealloc, // tp_dealloc @@ -308,7 +307,7 @@ PyTypeObject ExtensionDict_Type = { 0, // tp_as_number 0, // tp_as_sequence &extension_dict::MpMethods, // tp_as_mapping - 0, // tp_hash + PyObject_HashNotImplemented, // tp_hash 0, // tp_call 0, // tp_str 0, // tp_getattro @@ -330,7 +329,7 @@ PyTypeObject ExtensionDict_Type = { 0, // tp_descr_get 0, // tp_descr_set 0, // tp_dictoffset - (initproc)extension_dict::init, // tp_init + 0, // tp_init }; } // namespace python diff --git a/python/google/protobuf/pyext/extension_dict.h b/python/google/protobuf/pyext/extension_dict.h index 13c874a49..2456eda1e 100644 --- a/python/google/protobuf/pyext/extension_dict.h +++ b/python/google/protobuf/pyext/extension_dict.h @@ -41,25 +41,41 @@ #include <google/protobuf/stubs/shared_ptr.h> #endif - namespace google { namespace protobuf { class Message; class FieldDescriptor; +#ifdef _SHARED_PTR_H +using std::shared_ptr; +#else using internal::shared_ptr; +#endif namespace python { struct CMessage; -struct CFieldDescriptor; typedef struct ExtensionDict { PyObject_HEAD; + + // This is the top-level C++ Message object that owns the whole + // proto tree. Every Python container class holds a + // reference to it in order to keep it alive as long as there's a + // Python object that references any part of the tree. shared_ptr<Message> owner; + + // Weak reference to parent message. Used to make sure + // the parent is writable when an extension field is modified. CMessage* parent; + + // Pointer to the C++ Message that this ExtensionDict extends. + // Not owned by us. Message* message; + + // A dict of child messages, indexed by Extension descriptors. + // Similar to CMessage::composite_fields. 
PyObject* values; } ExtensionDict; @@ -67,11 +83,8 @@ extern PyTypeObject ExtensionDict_Type; namespace extension_dict { -// Gets the _cdescriptor reference to a CFieldDescriptor object given a -// python descriptor object. -// -// Returns a new reference. -CFieldDescriptor* InternalGetCDescriptorFromExtension(PyObject* extension); +// Builds an Extensions dict for a specific message. +ExtensionDict* NewExtensionDict(CMessage *parent); // Gets the number of extension values in this ExtensionDict as a python object. // @@ -84,7 +97,7 @@ PyObject* len(ExtensionDict* self); // Returns 0 on success, -1 on failure. int ReleaseExtension(ExtensionDict* self, PyObject* extension, - const google::protobuf::FieldDescriptor* descriptor); + const FieldDescriptor* descriptor); // Gets an extension from the dict for the given extension descriptor. // @@ -104,17 +117,18 @@ int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value); PyObject* ClearExtension(ExtensionDict* self, PyObject* extension); -// Checks if the dict has an extension. -// -// Returns a new python boolean reference. -PyObject* HasExtension(ExtensionDict* self, PyObject* extension); - // Gets an extension from the dict given the extension name as opposed to // descriptor. // // Returns a new reference. PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* name); +// Gets an extension from the dict given the extension field number as +// opposed to descriptor. +// +// Returns a new reference. +PyObject* _FindExtensionByNumber(ExtensionDict* self, PyObject* number); + } // namespace extension_dict } // namespace python } // namespace protobuf diff --git a/python/google/protobuf/pyext/map_container.cc b/python/google/protobuf/pyext/map_container.cc new file mode 100644 index 000000000..e022406d1 --- /dev/null +++ b/python/google/protobuf/pyext/map_container.cc @@ -0,0 +1,970 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. 
+// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +// Author: haberman@google.com (Josh Haberman) + +#include <google/protobuf/pyext/map_container.h> + +#include <memory> +#ifndef _SHARED_PTR_H +#include <google/protobuf/stubs/shared_ptr.h> +#endif + +#include <google/protobuf/stubs/logging.h> +#include <google/protobuf/stubs/common.h> +#include <google/protobuf/stubs/scoped_ptr.h> +#include <google/protobuf/map_field.h> +#include <google/protobuf/map.h> +#include <google/protobuf/message.h> +#include <google/protobuf/pyext/message.h> +#include <google/protobuf/pyext/scoped_pyobject_ptr.h> + +#if PY_MAJOR_VERSION >= 3 + #define PyInt_FromLong PyLong_FromLong + #define PyInt_FromSize_t PyLong_FromSize_t +#endif + +namespace google { +namespace protobuf { +namespace python { + +// Functions that need access to map reflection functionality. +// They need to be contained in this class because it is friended. +class MapReflectionFriend { + public: + // Methods that are in common between the map types. + static PyObject* Contains(PyObject* _self, PyObject* key); + static Py_ssize_t Length(PyObject* _self); + static PyObject* GetIterator(PyObject *_self); + static PyObject* IterNext(PyObject* _self); + + // Methods that differ between the map types. + static PyObject* ScalarMapGetItem(PyObject* _self, PyObject* key); + static PyObject* MessageMapGetItem(PyObject* _self, PyObject* key); + static int ScalarMapSetItem(PyObject* _self, PyObject* key, PyObject* v); + static int MessageMapSetItem(PyObject* _self, PyObject* key, PyObject* v); +}; + +struct MapIterator { + PyObject_HEAD; + + google::protobuf::scoped_ptr< ::google::protobuf::MapIterator> iter; + + // A pointer back to the container, so we can notice changes to the version. + // We own a ref on this. + MapContainer* container; + + // We need to keep a ref on the Message* too, because + // MapIterator::~MapIterator() accesses it. Normally this would be ok because + // the ref on container (above) would guarantee outlive semantics. 
However in + // the case of ClearField(), InitializeAndCopyToParentContainer() resets the + // message pointer (and the owner) to a different message, a copy of the + // original. But our iterator still points to the original, which could now + // get deleted before us. + // + // To prevent this, we ensure that the Message will always stay alive as long + // as this iterator does. This is solely for the benefit of the MapIterator + // destructor -- we should never actually access the iterator in this state + // except to delete it. + shared_ptr<Message> owner; + + // The version of the map when we took the iterator to it. + // + // We store this so that if the map is modified during iteration we can throw + // an error. + uint64 version; + + // True if the container is empty. We signal this separately to avoid calling + // any of the iteration methods, which are non-const. + bool empty; +}; + +Message* MapContainer::GetMutableMessage() { + cmessage::AssureWritable(parent); + return const_cast<Message*>(message); +} + +// Consumes a reference on the Python string object. 
+static bool PyStringToSTL(PyObject* py_string, string* stl_string) { + char *value; + Py_ssize_t value_len; + + if (!py_string) { + return false; + } + if (PyBytes_AsStringAndSize(py_string, &value, &value_len) < 0) { + Py_DECREF(py_string); + return false; + } else { + stl_string->assign(value, value_len); + Py_DECREF(py_string); + return true; + } +} + +static bool PythonToMapKey(PyObject* obj, + const FieldDescriptor* field_descriptor, + MapKey* key) { + switch (field_descriptor->cpp_type()) { + case FieldDescriptor::CPPTYPE_INT32: { + GOOGLE_CHECK_GET_INT32(obj, value, false); + key->SetInt32Value(value); + break; + } + case FieldDescriptor::CPPTYPE_INT64: { + GOOGLE_CHECK_GET_INT64(obj, value, false); + key->SetInt64Value(value); + break; + } + case FieldDescriptor::CPPTYPE_UINT32: { + GOOGLE_CHECK_GET_UINT32(obj, value, false); + key->SetUInt32Value(value); + break; + } + case FieldDescriptor::CPPTYPE_UINT64: { + GOOGLE_CHECK_GET_UINT64(obj, value, false); + key->SetUInt64Value(value); + break; + } + case FieldDescriptor::CPPTYPE_BOOL: { + GOOGLE_CHECK_GET_BOOL(obj, value, false); + key->SetBoolValue(value); + break; + } + case FieldDescriptor::CPPTYPE_STRING: { + string str; + if (!PyStringToSTL(CheckString(obj, field_descriptor), &str)) { + return false; + } + key->SetStringValue(str); + break; + } + default: + PyErr_Format( + PyExc_SystemError, "Type %d cannot be a map key", + field_descriptor->cpp_type()); + return false; + } + return true; +} + +static PyObject* MapKeyToPython(const FieldDescriptor* field_descriptor, + const MapKey& key) { + switch (field_descriptor->cpp_type()) { + case FieldDescriptor::CPPTYPE_INT32: + return PyInt_FromLong(key.GetInt32Value()); + case FieldDescriptor::CPPTYPE_INT64: + return PyLong_FromLongLong(key.GetInt64Value()); + case FieldDescriptor::CPPTYPE_UINT32: + return PyInt_FromSize_t(key.GetUInt32Value()); + case FieldDescriptor::CPPTYPE_UINT64: + return PyLong_FromUnsignedLongLong(key.GetUInt64Value()); + case 
FieldDescriptor::CPPTYPE_BOOL: + return PyBool_FromLong(key.GetBoolValue()); + case FieldDescriptor::CPPTYPE_STRING: + return ToStringObject(field_descriptor, key.GetStringValue()); + default: + PyErr_Format( + PyExc_SystemError, "Couldn't convert type %d to value", + field_descriptor->cpp_type()); + return NULL; + } +} + +// This is only used for ScalarMap, so we don't need to handle the +// CPPTYPE_MESSAGE case. +PyObject* MapValueRefToPython(const FieldDescriptor* field_descriptor, + MapValueRef* value) { + switch (field_descriptor->cpp_type()) { + case FieldDescriptor::CPPTYPE_INT32: + return PyInt_FromLong(value->GetInt32Value()); + case FieldDescriptor::CPPTYPE_INT64: + return PyLong_FromLongLong(value->GetInt64Value()); + case FieldDescriptor::CPPTYPE_UINT32: + return PyInt_FromSize_t(value->GetUInt32Value()); + case FieldDescriptor::CPPTYPE_UINT64: + return PyLong_FromUnsignedLongLong(value->GetUInt64Value()); + case FieldDescriptor::CPPTYPE_FLOAT: + return PyFloat_FromDouble(value->GetFloatValue()); + case FieldDescriptor::CPPTYPE_DOUBLE: + return PyFloat_FromDouble(value->GetDoubleValue()); + case FieldDescriptor::CPPTYPE_BOOL: + return PyBool_FromLong(value->GetBoolValue()); + case FieldDescriptor::CPPTYPE_STRING: + return ToStringObject(field_descriptor, value->GetStringValue()); + case FieldDescriptor::CPPTYPE_ENUM: + return PyInt_FromLong(value->GetEnumValue()); + default: + PyErr_Format( + PyExc_SystemError, "Couldn't convert type %d to value", + field_descriptor->cpp_type()); + return NULL; + } +} + +// This is only used for ScalarMap, so we don't need to handle the +// CPPTYPE_MESSAGE case. 
+static bool PythonToMapValueRef(PyObject* obj, + const FieldDescriptor* field_descriptor, + bool allow_unknown_enum_values, + MapValueRef* value_ref) { + switch (field_descriptor->cpp_type()) { + case FieldDescriptor::CPPTYPE_INT32: { + GOOGLE_CHECK_GET_INT32(obj, value, false); + value_ref->SetInt32Value(value); + return true; + } + case FieldDescriptor::CPPTYPE_INT64: { + GOOGLE_CHECK_GET_INT64(obj, value, false); + value_ref->SetInt64Value(value); + return true; + } + case FieldDescriptor::CPPTYPE_UINT32: { + GOOGLE_CHECK_GET_UINT32(obj, value, false); + value_ref->SetUInt32Value(value); + return true; + } + case FieldDescriptor::CPPTYPE_UINT64: { + GOOGLE_CHECK_GET_UINT64(obj, value, false); + value_ref->SetUInt64Value(value); + return true; + } + case FieldDescriptor::CPPTYPE_FLOAT: { + GOOGLE_CHECK_GET_FLOAT(obj, value, false); + value_ref->SetFloatValue(value); + return true; + } + case FieldDescriptor::CPPTYPE_DOUBLE: { + GOOGLE_CHECK_GET_DOUBLE(obj, value, false); + value_ref->SetDoubleValue(value); + return true; + } + case FieldDescriptor::CPPTYPE_BOOL: { + GOOGLE_CHECK_GET_BOOL(obj, value, false); + value_ref->SetBoolValue(value); + return true;; + } + case FieldDescriptor::CPPTYPE_STRING: { + string str; + if (!PyStringToSTL(CheckString(obj, field_descriptor), &str)) { + return false; + } + value_ref->SetStringValue(str); + return true; + } + case FieldDescriptor::CPPTYPE_ENUM: { + GOOGLE_CHECK_GET_INT32(obj, value, false); + if (allow_unknown_enum_values) { + value_ref->SetEnumValue(value); + return true; + } else { + const EnumDescriptor* enum_descriptor = field_descriptor->enum_type(); + const EnumValueDescriptor* enum_value = + enum_descriptor->FindValueByNumber(value); + if (enum_value != NULL) { + value_ref->SetEnumValue(value); + return true; + } else { + PyErr_Format(PyExc_ValueError, "Unknown enum value: %d", value); + return false; + } + } + break; + } + default: + PyErr_Format( + PyExc_SystemError, "Setting value to a field of unknown type 
%d", + field_descriptor->cpp_type()); + return false; + } +} + +// Map methods common to ScalarMap and MessageMap ////////////////////////////// + +static MapContainer* GetMap(PyObject* obj) { + return reinterpret_cast<MapContainer*>(obj); +} + +Py_ssize_t MapReflectionFriend::Length(PyObject* _self) { + MapContainer* self = GetMap(_self); + const google::protobuf::Message* message = self->message; + return message->GetReflection()->MapSize(*message, + self->parent_field_descriptor); +} + +PyObject* Clear(PyObject* _self) { + MapContainer* self = GetMap(_self); + Message* message = self->GetMutableMessage(); + const Reflection* reflection = message->GetReflection(); + + reflection->ClearField(message, self->parent_field_descriptor); + + Py_RETURN_NONE; +} + +PyObject* MapReflectionFriend::Contains(PyObject* _self, PyObject* key) { + MapContainer* self = GetMap(_self); + + const Message* message = self->message; + const Reflection* reflection = message->GetReflection(); + MapKey map_key; + + if (!PythonToMapKey(key, self->key_field_descriptor, &map_key)) { + return NULL; + } + + if (reflection->ContainsMapKey(*message, self->parent_field_descriptor, + map_key)) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } +} + +// Initializes the underlying Message object of "to" so it becomes a new parent +// repeated scalar, and copies all the values from "from" to it. A child scalar +// container can be released by passing it as both from and to (e.g. making it +// the recipient of the new parent message and copying the values from itself). +static int InitializeAndCopyToParentContainer(MapContainer* from, + MapContainer* to) { + // For now we require from == to, re-evaluate if we want to support deep copy + // as in repeated_scalar_container.cc. 
+ GOOGLE_DCHECK(from == to); + Message* new_message = from->message->New(); + + if (MapReflectionFriend::Length(reinterpret_cast<PyObject*>(from)) > 0) { + // A somewhat roundabout way of copying just one field from old_message to + // new_message. This is the best we can do with what Reflection gives us. + Message* mutable_old = from->GetMutableMessage(); + vector<const FieldDescriptor*> fields; + fields.push_back(from->parent_field_descriptor); + + // Move the map field into the new message. + mutable_old->GetReflection()->SwapFields(mutable_old, new_message, fields); + + // If/when we support from != to, this will be required also to copy the + // map field back into the existing message: + // mutable_old->MergeFrom(*new_message); + } + + // If from == to this could delete old_message. + to->owner.reset(new_message); + + to->parent = NULL; + to->parent_field_descriptor = from->parent_field_descriptor; + to->message = new_message; + + // Invalidate iterators, since they point to the old copy of the field. 
+ to->version++; + + return 0; +} + +int MapContainer::Release() { + return InitializeAndCopyToParentContainer(this, this); +} + + +// ScalarMap /////////////////////////////////////////////////////////////////// + +PyObject *NewScalarMapContainer( + CMessage* parent, const google::protobuf::FieldDescriptor* parent_field_descriptor) { + if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) { + return NULL; + } + +#if PY_MAJOR_VERSION >= 3 + ScopedPyObjectPtr obj(PyType_GenericAlloc( + reinterpret_cast<PyTypeObject *>(ScalarMapContainer_Type), 0)); +#else + ScopedPyObjectPtr obj(PyType_GenericAlloc(&ScalarMapContainer_Type, 0)); +#endif + if (obj.get() == NULL) { + return PyErr_Format(PyExc_RuntimeError, + "Could not allocate new container."); + } + + MapContainer* self = GetMap(obj.get()); + + self->message = parent->message; + self->parent = parent; + self->parent_field_descriptor = parent_field_descriptor; + self->owner = parent->owner; + self->version = 0; + + self->key_field_descriptor = + parent_field_descriptor->message_type()->FindFieldByName("key"); + self->value_field_descriptor = + parent_field_descriptor->message_type()->FindFieldByName("value"); + + if (self->key_field_descriptor == NULL || + self->value_field_descriptor == NULL) { + return PyErr_Format(PyExc_KeyError, + "Map entry descriptor did not have key/value fields"); + } + + return obj.release(); +} + +PyObject* MapReflectionFriend::ScalarMapGetItem(PyObject* _self, + PyObject* key) { + MapContainer* self = GetMap(_self); + + Message* message = self->GetMutableMessage(); + const Reflection* reflection = message->GetReflection(); + MapKey map_key; + MapValueRef value; + + if (!PythonToMapKey(key, self->key_field_descriptor, &map_key)) { + return NULL; + } + + if (reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor, + map_key, &value)) { + self->version++; + } + + return MapValueRefToPython(self->value_field_descriptor, &value); +} + +int 
MapReflectionFriend::ScalarMapSetItem(PyObject* _self, PyObject* key, + PyObject* v) { + MapContainer* self = GetMap(_self); + + Message* message = self->GetMutableMessage(); + const Reflection* reflection = message->GetReflection(); + MapKey map_key; + MapValueRef value; + + if (!PythonToMapKey(key, self->key_field_descriptor, &map_key)) { + return -1; + } + + self->version++; + + if (v) { + // Set item to v. + reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor, + map_key, &value); + + return PythonToMapValueRef(v, self->value_field_descriptor, + reflection->SupportsUnknownEnumValues(), &value) + ? 0 + : -1; + } else { + // Delete key from map. + if (reflection->DeleteMapValue(message, self->parent_field_descriptor, + map_key)) { + return 0; + } else { + PyErr_Format(PyExc_KeyError, "Key not present in map"); + return -1; + } + } +} + +static PyObject* ScalarMapGet(PyObject* self, PyObject* args) { + PyObject* key; + PyObject* default_value = NULL; + if (PyArg_ParseTuple(args, "O|O", &key, &default_value) < 0) { + return NULL; + } + + ScopedPyObjectPtr is_present(MapReflectionFriend::Contains(self, key)); + if (is_present.get() == NULL) { + return NULL; + } + + if (PyObject_IsTrue(is_present.get())) { + return MapReflectionFriend::ScalarMapGetItem(self, key); + } else { + if (default_value != NULL) { + Py_INCREF(default_value); + return default_value; + } else { + Py_RETURN_NONE; + } + } +} + +static void ScalarMapDealloc(PyObject* _self) { + MapContainer* self = GetMap(_self); + self->owner.reset(); + Py_TYPE(_self)->tp_free(_self); +} + +static PyMethodDef ScalarMapMethods[] = { + { "__contains__", MapReflectionFriend::Contains, METH_O, + "Tests whether a key is a member of the map." }, + { "clear", (PyCFunction)Clear, METH_NOARGS, + "Removes all elements from the map." 
}, + { "get", ScalarMapGet, METH_VARARGS, + "Gets the value for the given key if present, or otherwise a default" }, + /* + { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS, + "Makes a deep copy of the class." }, + { "__reduce__", (PyCFunction)Reduce, METH_NOARGS, + "Outputs picklable representation of the repeated field." }, + */ + {NULL, NULL}, +}; + +#if PY_MAJOR_VERSION >= 3 + static PyType_Slot ScalarMapContainer_Type_slots[] = { + {Py_tp_dealloc, (void *)ScalarMapDealloc}, + {Py_mp_length, (void *)MapReflectionFriend::Length}, + {Py_mp_subscript, (void *)MapReflectionFriend::ScalarMapGetItem}, + {Py_mp_ass_subscript, (void *)MapReflectionFriend::ScalarMapSetItem}, + {Py_tp_methods, (void *)ScalarMapMethods}, + {Py_tp_iter, (void *)MapReflectionFriend::GetIterator}, + {0, 0}, + }; + + PyType_Spec ScalarMapContainer_Type_spec = { + FULL_MODULE_NAME ".ScalarMapContainer", + sizeof(MapContainer), + 0, + Py_TPFLAGS_DEFAULT, + ScalarMapContainer_Type_slots + }; + PyObject *ScalarMapContainer_Type; +#else + static PyMappingMethods ScalarMapMappingMethods = { + MapReflectionFriend::Length, // mp_length + MapReflectionFriend::ScalarMapGetItem, // mp_subscript + MapReflectionFriend::ScalarMapSetItem, // mp_ass_subscript + }; + + PyTypeObject ScalarMapContainer_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + FULL_MODULE_NAME ".ScalarMapContainer", // tp_name + sizeof(MapContainer), // tp_basicsize + 0, // tp_itemsize + ScalarMapDealloc, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + &ScalarMapMappingMethods, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + "A scalar map container", // tp_doc + 0, // tp_traverse + 0, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + MapReflectionFriend::GetIterator, // tp_iter + 0, // tp_iternext + 
ScalarMapMethods, // tp_methods + 0, // tp_members + 0, // tp_getset + 0, // tp_base + 0, // tp_dict + 0, // tp_descr_get + 0, // tp_descr_set + 0, // tp_dictoffset + 0, // tp_init + }; +#endif + + +// MessageMap ////////////////////////////////////////////////////////////////// + +static MessageMapContainer* GetMessageMap(PyObject* obj) { + return reinterpret_cast<MessageMapContainer*>(obj); +} + +static PyObject* GetCMessage(MessageMapContainer* self, Message* message) { + // Get or create the CMessage object corresponding to this message. + ScopedPyObjectPtr key(PyLong_FromVoidPtr(message)); + PyObject* ret = PyDict_GetItem(self->message_dict, key.get()); + + if (ret == NULL) { + CMessage* cmsg = cmessage::NewEmptyMessage(self->message_class); + ret = reinterpret_cast<PyObject*>(cmsg); + + if (cmsg == NULL) { + return NULL; + } + cmsg->owner = self->owner; + cmsg->message = message; + cmsg->parent = self->parent; + + if (PyDict_SetItem(self->message_dict, key.get(), ret) < 0) { + Py_DECREF(ret); + return NULL; + } + } else { + Py_INCREF(ret); + } + + return ret; +} + +PyObject* NewMessageMapContainer( + CMessage* parent, const google::protobuf::FieldDescriptor* parent_field_descriptor, + CMessageClass* message_class) { + if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) { + return NULL; + } + +#if PY_MAJOR_VERSION >= 3 + PyObject* obj = PyType_GenericAlloc( + reinterpret_cast<PyTypeObject *>(MessageMapContainer_Type), 0); +#else + PyObject* obj = PyType_GenericAlloc(&MessageMapContainer_Type, 0); +#endif + if (obj == NULL) { + return PyErr_Format(PyExc_RuntimeError, + "Could not allocate new container."); + } + + MessageMapContainer* self = GetMessageMap(obj); + + self->message = parent->message; + self->parent = parent; + self->parent_field_descriptor = parent_field_descriptor; + self->owner = parent->owner; + self->version = 0; + + self->key_field_descriptor = + parent_field_descriptor->message_type()->FindFieldByName("key"); + 
self->value_field_descriptor = + parent_field_descriptor->message_type()->FindFieldByName("value"); + + self->message_dict = PyDict_New(); + if (self->message_dict == NULL) { + return PyErr_Format(PyExc_RuntimeError, + "Could not allocate message dict."); + } + + Py_INCREF(message_class); + self->message_class = message_class; + + if (self->key_field_descriptor == NULL || + self->value_field_descriptor == NULL) { + Py_DECREF(obj); + return PyErr_Format(PyExc_KeyError, + "Map entry descriptor did not have key/value fields"); + } + + return obj; +} + +int MapReflectionFriend::MessageMapSetItem(PyObject* _self, PyObject* key, + PyObject* v) { + if (v) { + PyErr_Format(PyExc_ValueError, + "Direct assignment of submessage not allowed"); + return -1; + } + + // Now we know that this is a delete, not a set. + + MessageMapContainer* self = GetMessageMap(_self); + Message* message = self->GetMutableMessage(); + const Reflection* reflection = message->GetReflection(); + MapKey map_key; + MapValueRef value; + + self->version++; + + if (!PythonToMapKey(key, self->key_field_descriptor, &map_key)) { + return -1; + } + + // Delete key from map. 
+ if (reflection->DeleteMapValue(message, self->parent_field_descriptor, + map_key)) { + return 0; + } else { + PyErr_Format(PyExc_KeyError, "Key not present in map"); + return -1; + } +} + +PyObject* MapReflectionFriend::MessageMapGetItem(PyObject* _self, + PyObject* key) { + MessageMapContainer* self = GetMessageMap(_self); + + Message* message = self->GetMutableMessage(); + const Reflection* reflection = message->GetReflection(); + MapKey map_key; + MapValueRef value; + + if (!PythonToMapKey(key, self->key_field_descriptor, &map_key)) { + return NULL; + } + + if (reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor, + map_key, &value)) { + self->version++; + } + + return GetCMessage(self, value.MutableMessageValue()); +} + +PyObject* MessageMapGet(PyObject* self, PyObject* args) { + PyObject* key; + PyObject* default_value = NULL; + if (PyArg_ParseTuple(args, "O|O", &key, &default_value) < 0) { + return NULL; + } + + ScopedPyObjectPtr is_present(MapReflectionFriend::Contains(self, key)); + if (is_present.get() == NULL) { + return NULL; + } + + if (PyObject_IsTrue(is_present.get())) { + return MapReflectionFriend::MessageMapGetItem(self, key); + } else { + if (default_value != NULL) { + Py_INCREF(default_value); + return default_value; + } else { + Py_RETURN_NONE; + } + } +} + +static void MessageMapDealloc(PyObject* _self) { + MessageMapContainer* self = GetMessageMap(_self); + self->owner.reset(); + Py_DECREF(self->message_dict); + Py_DECREF(self->message_class); + Py_TYPE(_self)->tp_free(_self); +} + +static PyMethodDef MessageMapMethods[] = { + { "__contains__", (PyCFunction)MapReflectionFriend::Contains, METH_O, + "Tests whether the map contains this element."}, + { "clear", (PyCFunction)Clear, METH_NOARGS, + "Removes all elements from the map."}, + { "get", MessageMapGet, METH_VARARGS, + "Gets the value for the given key if present, or otherwise a default" }, + { "get_or_create", MapReflectionFriend::MessageMapGetItem, METH_O, + "Alias 
for getitem, useful to make explicit that the map is mutated." }, + /* + { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS, + "Makes a deep copy of the class." }, + { "__reduce__", (PyCFunction)Reduce, METH_NOARGS, + "Outputs picklable representation of the repeated field." }, + */ + {NULL, NULL}, +}; + +#if PY_MAJOR_VERSION >= 3 + static PyType_Slot MessageMapContainer_Type_slots[] = { + {Py_tp_dealloc, (void *)MessageMapDealloc}, + {Py_mp_length, (void *)MapReflectionFriend::Length}, + {Py_mp_subscript, (void *)MapReflectionFriend::MessageMapGetItem}, + {Py_mp_ass_subscript, (void *)MapReflectionFriend::MessageMapSetItem}, + {Py_tp_methods, (void *)MessageMapMethods}, + {Py_tp_iter, (void *)MapReflectionFriend::GetIterator}, + {0, 0} + }; + + PyType_Spec MessageMapContainer_Type_spec = { + FULL_MODULE_NAME ".MessageMapContainer", + sizeof(MessageMapContainer), + 0, + Py_TPFLAGS_DEFAULT, + MessageMapContainer_Type_slots + }; + + PyObject *MessageMapContainer_Type; +#else + static PyMappingMethods MessageMapMappingMethods = { + MapReflectionFriend::Length, // mp_length + MapReflectionFriend::MessageMapGetItem, // mp_subscript + MapReflectionFriend::MessageMapSetItem, // mp_ass_subscript + }; + + PyTypeObject MessageMapContainer_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + FULL_MODULE_NAME ".MessageMapContainer", // tp_name + sizeof(MessageMapContainer), // tp_basicsize + 0, // tp_itemsize + MessageMapDealloc, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + &MessageMapMappingMethods, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + "A map container for message", // tp_doc + 0, // tp_traverse + 0, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + MapReflectionFriend::GetIterator, // tp_iter + 0, // tp_iternext + MessageMapMethods, // 
tp_methods + 0, // tp_members + 0, // tp_getset + 0, // tp_base + 0, // tp_dict + 0, // tp_descr_get + 0, // tp_descr_set + 0, // tp_dictoffset + 0, // tp_init + }; +#endif + +// MapIterator ///////////////////////////////////////////////////////////////// + +static MapIterator* GetIter(PyObject* obj) { + return reinterpret_cast<MapIterator*>(obj); +} + +PyObject* MapReflectionFriend::GetIterator(PyObject *_self) { + MapContainer* self = GetMap(_self); + + ScopedPyObjectPtr obj(PyType_GenericAlloc(&MapIterator_Type, 0)); + if (obj == NULL) { + return PyErr_Format(PyExc_KeyError, "Could not allocate iterator"); + } + + MapIterator* iter = GetIter(obj.get()); + + Py_INCREF(self); + iter->container = self; + iter->version = self->version; + iter->owner = self->owner; + + if (MapReflectionFriend::Length(_self) > 0) { + Message* message = self->GetMutableMessage(); + const Reflection* reflection = message->GetReflection(); + + iter->iter.reset(new ::google::protobuf::MapIterator( + reflection->MapBegin(message, self->parent_field_descriptor))); + } + + return obj.release(); +} + +PyObject* MapReflectionFriend::IterNext(PyObject* _self) { + MapIterator* self = GetIter(_self); + + // This won't catch mutations to the map performed by MergeFrom(); no easy way + // to address that. 
+ if (self->version != self->container->version) { + return PyErr_Format(PyExc_RuntimeError, + "Map modified during iteration."); + } + + if (self->iter.get() == NULL) { + return NULL; + } + + Message* message = self->container->GetMutableMessage(); + const Reflection* reflection = message->GetReflection(); + + if (*self->iter == + reflection->MapEnd(message, self->container->parent_field_descriptor)) { + return NULL; + } + + PyObject* ret = MapKeyToPython(self->container->key_field_descriptor, + self->iter->GetKey()); + + ++(*self->iter); + + return ret; +} + +static void DeallocMapIterator(PyObject* _self) { + MapIterator* self = GetIter(_self); + self->iter.reset(); + self->owner.reset(); + Py_XDECREF(self->container); + Py_TYPE(_self)->tp_free(_self); +} + +PyTypeObject MapIterator_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + FULL_MODULE_NAME ".MapIterator", // tp_name + sizeof(MapIterator), // tp_basicsize + 0, // tp_itemsize + DeallocMapIterator, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + 0, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + "A scalar map iterator", // tp_doc + 0, // tp_traverse + 0, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + PyObject_SelfIter, // tp_iter + MapReflectionFriend::IterNext, // tp_iternext + 0, // tp_methods + 0, // tp_members + 0, // tp_getset + 0, // tp_base + 0, // tp_dict + 0, // tp_descr_get + 0, // tp_descr_set + 0, // tp_dictoffset + 0, // tp_init +}; + +} // namespace python +} // namespace protobuf +} // namespace google diff --git a/python/google/protobuf/pyext/map_container.h b/python/google/protobuf/pyext/map_container.h new file mode 100644 index 000000000..fbd6713f7 --- /dev/null +++ b/python/google/protobuf/pyext/map_container.h @@ -0,0 +1,142 @@ +// Protocol Buffers - 
Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_MAP_CONTAINER_H__ +#define GOOGLE_PROTOBUF_PYTHON_CPP_MAP_CONTAINER_H__ + +#include <Python.h> + +#include <memory> +#ifndef _SHARED_PTR_H +#include <google/protobuf/stubs/shared_ptr.h> +#endif + +#include <google/protobuf/descriptor.h> +#include <google/protobuf/message.h> + +namespace google { +namespace protobuf { + +class Message; + +#ifdef _SHARED_PTR_H +using std::shared_ptr; +#else +using internal::shared_ptr; +#endif + +namespace python { + +struct CMessage; +struct CMessageClass; + +// This struct is used directly for ScalarMap, and is the base class of +// MessageMapContainer, which is used for MessageMap. +struct MapContainer { + PyObject_HEAD; + + // This is the top-level C++ Message object that owns the whole + // proto tree. Every Python MapContainer holds a + // reference to it in order to keep it alive as long as there's a + // Python object that references any part of the tree. + shared_ptr<Message> owner; + + // Pointer to the C++ Message that contains this container. The + // MapContainer does not own this pointer. + const Message* message; + + // Use to get a mutable message when necessary. + Message* GetMutableMessage(); + + // Weak reference to a parent CMessage object (i.e. may be NULL.) + // + // Used to make sure all ancestors are also mutable when first + // modifying the container. + CMessage* parent; + + // Pointer to the parent's descriptor that describes this + // field. Used together with the parent's message when making a + // default message instance mutable. + // The pointer is owned by the global DescriptorPool. + const FieldDescriptor* parent_field_descriptor; + const FieldDescriptor* key_field_descriptor; + const FieldDescriptor* value_field_descriptor; + + // We bump this whenever we perform a mutation, to invalidate existing + // iterators. + uint64 version; + + // Releases the messages in the container to a new message. + // + // Returns 0 on success, -1 on failure. 
+ int Release(); + + // Set the owner field of self and any children of self. + void SetOwner(const shared_ptr<Message>& new_owner) { + owner = new_owner; + } +}; + +struct MessageMapContainer : public MapContainer { + // The type used to create new child messages. + CMessageClass* message_class; + + // A dict mapping Message* -> CMessage. + PyObject* message_dict; +}; + +#if PY_MAJOR_VERSION >= 3 + extern PyObject *MessageMapContainer_Type; + extern PyType_Spec MessageMapContainer_Type_spec; + extern PyObject *ScalarMapContainer_Type; + extern PyType_Spec ScalarMapContainer_Type_spec; +#else + extern PyTypeObject MessageMapContainer_Type; + extern PyTypeObject ScalarMapContainer_Type; +#endif + +extern PyTypeObject MapIterator_Type; // Both map types use the same iterator. + +// Builds a MapContainer object, from a parent message and a +// field descriptor. +extern PyObject* NewScalarMapContainer( + CMessage* parent, const FieldDescriptor* parent_field_descriptor); + +// Builds a MessageMap object, from a parent message and a +// field descriptor. +extern PyObject* NewMessageMapContainer( + CMessage* parent, const FieldDescriptor* parent_field_descriptor, + CMessageClass* message_class); + +} // namespace python +} // namespace protobuf + +} // namespace google +#endif // GOOGLE_PROTOBUF_PYTHON_CPP_MAP_CONTAINER_H__ diff --git a/python/google/protobuf/pyext/message.cc b/python/google/protobuf/pyext/message.cc index 9fb7083f8..83c151ff6 100644 --- a/python/google/protobuf/pyext/message.cc +++ b/python/google/protobuf/pyext/message.cc @@ -33,12 +33,14 @@ #include <google/protobuf/pyext/message.h> +#include <map> #include <memory> #ifndef _SHARED_PTR_H #include <google/protobuf/stubs/shared_ptr.h> #endif #include <string> #include <vector> +#include <structmember.h> // A Python header file. 
#ifndef PyVarObject_HEAD_INIT #define PyVarObject_HEAD_INIT(type, size) PyObject_HEAD_INIT(type) size, @@ -48,16 +50,21 @@ #endif #include <google/protobuf/descriptor.pb.h> #include <google/protobuf/stubs/common.h> +#include <google/protobuf/stubs/logging.h> #include <google/protobuf/io/coded_stream.h> +#include <google/protobuf/util/message_differencer.h> #include <google/protobuf/descriptor.h> -#include <google/protobuf/dynamic_message.h> #include <google/protobuf/message.h> #include <google/protobuf/text_format.h> +#include <google/protobuf/unknown_field_set.h> #include <google/protobuf/pyext/descriptor.h> +#include <google/protobuf/pyext/descriptor_pool.h> #include <google/protobuf/pyext/extension_dict.h> #include <google/protobuf/pyext/repeated_composite_container.h> #include <google/protobuf/pyext/repeated_scalar_container.h> +#include <google/protobuf/pyext/map_container.h> #include <google/protobuf/pyext/scoped_pyobject_ptr.h> +#include <google/protobuf/stubs/strutil.h> #if PY_MAJOR_VERSION >= 3 #define PyInt_Check PyLong_Check @@ -71,7 +78,11 @@ #error "Python 3.0 - 3.2 are not supported." #else #define PyString_AsString(ob) \ - (PyUnicode_Check(ob)? PyUnicode_AsUTF8(ob): PyBytes_AS_STRING(ob)) + (PyUnicode_Check(ob)? PyUnicode_AsUTF8(ob): PyBytes_AsString(ob)) + #define PyString_AsStringAndSize(ob, charpp, sizep) \ + (PyUnicode_Check(ob)? \ + ((*(charpp) = PyUnicode_AsUTF8AndSize(ob, (sizep))) == NULL? 
-1: 0): \ + PyBytes_AsStringAndSize(ob, (charpp), (sizep))) #endif #endif @@ -79,14 +90,328 @@ namespace google { namespace protobuf { namespace python { +static PyObject* kDESCRIPTOR; +static PyObject* k_extensions_by_name; +static PyObject* k_extensions_by_number; +PyObject* EnumTypeWrapper_class; +static PyObject* PythonMessage_class; +static PyObject* kEmptyWeakref; +static PyObject* WKT_classes = NULL; + +namespace message_meta { + +static int InsertEmptyWeakref(PyTypeObject* base); + +// Add the number of a field descriptor to the containing message class. +// Equivalent to: +// _cls.<field>_FIELD_NUMBER = <number> +static bool AddFieldNumberToClass( + PyObject* cls, const FieldDescriptor* field_descriptor) { + string constant_name = field_descriptor->name() + "_FIELD_NUMBER"; + UpperString(&constant_name); + ScopedPyObjectPtr attr_name(PyString_FromStringAndSize( + constant_name.c_str(), constant_name.size())); + if (attr_name == NULL) { + return false; + } + ScopedPyObjectPtr number(PyInt_FromLong(field_descriptor->number())); + if (number == NULL) { + return false; + } + if (PyObject_SetAttr(cls, attr_name.get(), number.get()) == -1) { + return false; + } + return true; +} + + +// Finalize the creation of the Message class. +static int AddDescriptors(PyObject* cls, const Descriptor* descriptor) { + // If there are extension_ranges, the message is "extendable", and extension + // classes will register themselves in this class. 
+ if (descriptor->extension_range_count() > 0) { + ScopedPyObjectPtr by_name(PyDict_New()); + if (PyObject_SetAttr(cls, k_extensions_by_name, by_name.get()) < 0) { + return -1; + } + ScopedPyObjectPtr by_number(PyDict_New()); + if (PyObject_SetAttr(cls, k_extensions_by_number, by_number.get()) < 0) { + return -1; + } + } + + // For each field set: cls.<field>_FIELD_NUMBER = <number> + for (int i = 0; i < descriptor->field_count(); ++i) { + if (!AddFieldNumberToClass(cls, descriptor->field(i))) { + return -1; + } + } + + // For each enum set cls.<enum name> = EnumTypeWrapper(<enum descriptor>). + for (int i = 0; i < descriptor->enum_type_count(); ++i) { + const EnumDescriptor* enum_descriptor = descriptor->enum_type(i); + ScopedPyObjectPtr enum_type( + PyEnumDescriptor_FromDescriptor(enum_descriptor)); + if (enum_type == NULL) { + return -1; + } + // Add wrapped enum type to message class. + ScopedPyObjectPtr wrapped(PyObject_CallFunctionObjArgs( + EnumTypeWrapper_class, enum_type.get(), NULL)); + if (wrapped == NULL) { + return -1; + } + if (PyObject_SetAttrString( + cls, enum_descriptor->name().c_str(), wrapped.get()) == -1) { + return -1; + } + + // For each enum value add cls.<name> = <number> + for (int j = 0; j < enum_descriptor->value_count(); ++j) { + const EnumValueDescriptor* enum_value_descriptor = + enum_descriptor->value(j); + ScopedPyObjectPtr value_number(PyInt_FromLong( + enum_value_descriptor->number())); + if (value_number == NULL) { + return -1; + } + if (PyObject_SetAttrString(cls, enum_value_descriptor->name().c_str(), + value_number.get()) == -1) { + return -1; + } + } + } + + // For each extension set cls.<extension name> = <extension descriptor>. + // + // Extension descriptors come from + // <message descriptor>.extensions_by_name[name] + // which was defined previously. 
+ for (int i = 0; i < descriptor->extension_count(); ++i) { + const google::protobuf::FieldDescriptor* field = descriptor->extension(i); + ScopedPyObjectPtr extension_field(PyFieldDescriptor_FromDescriptor(field)); + if (extension_field == NULL) { + return -1; + } + + // Add the extension field to the message class. + if (PyObject_SetAttrString( + cls, field->name().c_str(), extension_field.get()) == -1) { + return -1; + } + + // For each extension set cls.<extension name>_FIELD_NUMBER = <number>. + if (!AddFieldNumberToClass(cls, field)) { + return -1; + } + } + + return 0; +} + +static PyObject* New(PyTypeObject* type, + PyObject* args, PyObject* kwargs) { + static char *kwlist[] = {"name", "bases", "dict", 0}; + PyObject *bases, *dict; + const char* name; + + // Check arguments: (name, bases, dict) + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "sO!O!:type", kwlist, + &name, + &PyTuple_Type, &bases, + &PyDict_Type, &dict)) { + return NULL; + } + + // Check bases: only (), or (message.Message,) are allowed + if (!(PyTuple_GET_SIZE(bases) == 0 || + (PyTuple_GET_SIZE(bases) == 1 && + PyTuple_GET_ITEM(bases, 0) == PythonMessage_class))) { + PyErr_SetString(PyExc_TypeError, + "A Message class can only inherit from Message"); + return NULL; + } + + // Check dict['DESCRIPTOR'] + PyObject* py_descriptor = PyDict_GetItem(dict, kDESCRIPTOR); + if (py_descriptor == NULL) { + PyErr_SetString(PyExc_TypeError, "Message class has no DESCRIPTOR"); + return NULL; + } + if (!PyObject_TypeCheck(py_descriptor, &PyMessageDescriptor_Type)) { + PyErr_Format(PyExc_TypeError, "Expected a message Descriptor, got %s", + py_descriptor->ob_type->tp_name); + return NULL; + } + + // Build the arguments to the base metaclass. + // We change the __bases__ classes. 
+ ScopedPyObjectPtr new_args; + const Descriptor* message_descriptor = + PyMessageDescriptor_AsDescriptor(py_descriptor); + if (message_descriptor == NULL) { + return NULL; + } + + if (WKT_classes == NULL) { + ScopedPyObjectPtr well_known_types(PyImport_ImportModule( + "google.protobuf.internal.well_known_types")); + GOOGLE_DCHECK(well_known_types != NULL); + + WKT_classes = PyObject_GetAttrString(well_known_types.get(), "WKTBASES"); + GOOGLE_DCHECK(WKT_classes != NULL); + } + + PyObject* well_known_class = PyDict_GetItemString( + WKT_classes, message_descriptor->full_name().c_str()); + if (well_known_class == NULL) { + new_args.reset(Py_BuildValue("s(OO)O", name, &CMessage_Type, + PythonMessage_class, dict)); + } else { + new_args.reset(Py_BuildValue("s(OOO)O", name, &CMessage_Type, + PythonMessage_class, well_known_class, dict)); + } + + if (new_args == NULL) { + return NULL; + } + // Call the base metaclass. + ScopedPyObjectPtr result(PyType_Type.tp_new(type, new_args.get(), NULL)); + if (result == NULL) { + return NULL; + } + CMessageClass* newtype = reinterpret_cast<CMessageClass*>(result.get()); + + // Insert the empty weakref into the base classes. + if (InsertEmptyWeakref( + reinterpret_cast<PyTypeObject*>(PythonMessage_class)) < 0 || + InsertEmptyWeakref(&CMessage_Type) < 0) { + return NULL; + } + + // Cache the descriptor, both as Python object and as C++ pointer. + const Descriptor* descriptor = + PyMessageDescriptor_AsDescriptor(py_descriptor); + if (descriptor == NULL) { + return NULL; + } + Py_INCREF(py_descriptor); + newtype->py_message_descriptor = py_descriptor; + newtype->message_descriptor = descriptor; + // TODO(amauryfa): Don't always use the canonical pool of the descriptor, + // use the MessageFactory optionally passed in the class dict. 
+ newtype->py_descriptor_pool = GetDescriptorPool_FromPool( + descriptor->file()->pool()); + if (newtype->py_descriptor_pool == NULL) { + return NULL; + } + Py_INCREF(newtype->py_descriptor_pool); + + // Add the message to the DescriptorPool. + if (cdescriptor_pool::RegisterMessageClass(newtype->py_descriptor_pool, + descriptor, newtype) < 0) { + return NULL; + } + + // Continue with type initialization: add other descriptors, enum values... + if (AddDescriptors(result.get(), descriptor) < 0) { + return NULL; + } + return result.release(); +} + +static void Dealloc(CMessageClass *self) { + Py_DECREF(self->py_message_descriptor); + Py_DECREF(self->py_descriptor_pool); + Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self)); +} + + +// This function inserts and empty weakref at the end of the list of +// subclasses for the main protocol buffer Message class. +// +// This eliminates a O(n^2) behaviour in the internal add_subclass +// routine. +static int InsertEmptyWeakref(PyTypeObject *base_type) { +#if PY_MAJOR_VERSION >= 3 + // Python 3.4 has already included the fix for the issue that this + // hack addresses. For further background and the fix please see + // https://bugs.python.org/issue17936. 
+ return 0; +#else + PyObject *subclasses = base_type->tp_subclasses; + if (subclasses && PyList_CheckExact(subclasses)) { + return PyList_Append(subclasses, kEmptyWeakref); + } + return 0; +#endif // PY_MAJOR_VERSION >= 3 +} + +} // namespace message_meta + +PyTypeObject CMessageClass_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + FULL_MODULE_NAME ".MessageMeta", // tp_name + sizeof(CMessageClass), // tp_basicsize + 0, // tp_itemsize + (destructor)message_meta::Dealloc, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + 0, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, // tp_flags + "The metaclass of ProtocolMessages", // tp_doc + 0, // tp_traverse + 0, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + 0, // tp_iter + 0, // tp_iternext + 0, // tp_methods + 0, // tp_members + 0, // tp_getset + 0, // tp_base + 0, // tp_dict + 0, // tp_descr_get + 0, // tp_descr_set + 0, // tp_dictoffset + 0, // tp_init + 0, // tp_alloc + message_meta::New, // tp_new +}; + +static CMessageClass* CheckMessageClass(PyTypeObject* cls) { + if (!PyObject_TypeCheck(cls, &CMessageClass_Type)) { + PyErr_Format(PyExc_TypeError, "Class %s is not a Message", cls->tp_name); + return NULL; + } + return reinterpret_cast<CMessageClass*>(cls); +} + +static const Descriptor* GetMessageDescriptor(PyTypeObject* cls) { + CMessageClass* type = CheckMessageClass(cls); + if (type == NULL) { + return NULL; + } + return type->message_descriptor; +} + // Forward declarations namespace cmessage { -static PyObject* GetDescriptor(CMessage* self, PyObject* name); -static string GetMessageName(CMessage* self); int InternalReleaseFieldByDescriptor( - const google::protobuf::FieldDescriptor* field_descriptor, - PyObject* composite_field, - google::protobuf::Message* 
parent_message); + CMessage* self, + const FieldDescriptor* field_descriptor, + PyObject* composite_field); } // namespace cmessage // --------------------------------------------------------------------- @@ -105,7 +430,7 @@ struct ChildVisitor { // Returns 0 on success, -1 on failure. int VisitCMessage(CMessage* cmessage, - const google::protobuf::FieldDescriptor* field_descriptor) { + const FieldDescriptor* field_descriptor) { return 0; } }; @@ -116,20 +441,26 @@ template<class Visitor> static int VisitCompositeField(const FieldDescriptor* descriptor, PyObject* child, Visitor visitor) { - if (descriptor->label() == google::protobuf::FieldDescriptor::LABEL_REPEATED) { - if (descriptor->cpp_type() == google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { - RepeatedCompositeContainer* container = - reinterpret_cast<RepeatedCompositeContainer*>(child); - if (visitor.VisitRepeatedCompositeContainer(container) == -1) - return -1; + if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) { + if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { + if (descriptor->is_map()) { + MapContainer* container = reinterpret_cast<MapContainer*>(child); + if (visitor.VisitMapContainer(container) == -1) { + return -1; + } + } else { + RepeatedCompositeContainer* container = + reinterpret_cast<RepeatedCompositeContainer*>(child); + if (visitor.VisitRepeatedCompositeContainer(container) == -1) + return -1; + } } else { RepeatedScalarContainer* container = reinterpret_cast<RepeatedScalarContainer*>(child); if (visitor.VisitRepeatedScalarContainer(container) == -1) return -1; } - } else if (descriptor->cpp_type() == - google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { + } else if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { CMessage* cmsg = reinterpret_cast<CMessage*>(child); if (visitor.VisitCMessage(cmsg, descriptor) == -1) return -1; @@ -148,24 +479,33 @@ int ForEachCompositeField(CMessage* self, Visitor visitor) { PyObject* field; // Visit normal 
fields. - while (PyDict_Next(self->composite_fields, &pos, &key, &field)) { - PyObject* cdescriptor = cmessage::GetDescriptor(self, key); - if (cdescriptor != NULL) { - const google::protobuf::FieldDescriptor* descriptor = - reinterpret_cast<CFieldDescriptor*>(cdescriptor)->descriptor; - if (VisitCompositeField(descriptor, field, visitor) == -1) + if (self->composite_fields) { + // Never use self->message in this function, it may be already freed. + const Descriptor* message_descriptor = + GetMessageDescriptor(Py_TYPE(self)); + while (PyDict_Next(self->composite_fields, &pos, &key, &field)) { + Py_ssize_t key_str_size; + char *key_str_data; + if (PyString_AsStringAndSize(key, &key_str_data, &key_str_size) != 0) return -1; + const string key_str(key_str_data, key_str_size); + const FieldDescriptor* descriptor = + message_descriptor->FindFieldByName(key_str); + if (descriptor != NULL) { + if (VisitCompositeField(descriptor, field, visitor) == -1) + return -1; + } } } // Visit extension fields. if (self->extensions != NULL) { + pos = 0; while (PyDict_Next(self->extensions->values, &pos, &key, &field)) { - CFieldDescriptor* cdescriptor = - extension_dict::InternalGetCDescriptorFromExtension(key); - if (cdescriptor == NULL) + const FieldDescriptor* descriptor = cmessage::GetExtensionDescriptor(key); + if (descriptor == NULL) return -1; - if (VisitCompositeField(cdescriptor->descriptor, field, visitor) == -1) + if (VisitCompositeField(descriptor, field, visitor) == -1) return -1; } } @@ -184,25 +524,13 @@ PyObject* kint64min_py; PyObject* kint64max_py; PyObject* kuint64max_py; -PyObject* EnumTypeWrapper_class; PyObject* EncodeError_class; PyObject* DecodeError_class; PyObject* PickleError_class; // Constant PyString values used for GetAttr/GetItem. 
-static PyObject* kDESCRIPTOR; -static PyObject* k__descriptors; +static PyObject* k_cdescriptor; static PyObject* kfull_name; -static PyObject* kname; -static PyObject* kmessage_type; -static PyObject* kis_extendable; -static PyObject* kextensions_by_name; -static PyObject* k_extensions_by_name; -static PyObject* k_extensions_by_number; -static PyObject* k_concrete_class; -static PyObject* kfields_by_name; - -static CDescriptorPool* descriptor_pool; /* Is 64bit */ void FormatTypeError(PyObject* arg, char* expected_types) { @@ -235,12 +563,14 @@ bool CheckAndGetInteger( if (PyObject_RichCompareBool(min, arg, Py_LE) != 1 || PyObject_RichCompareBool(max, arg, Py_GE) != 1) { #endif - PyObject *s = PyObject_Str(arg); - if (s) { - PyErr_Format(PyExc_ValueError, - "Value out of range: %s", - PyString_AsString(s)); - Py_DECREF(s); + if (!PyErr_Occurred()) { + PyObject *s = PyObject_Str(arg); + if (s) { + PyErr_Format(PyExc_ValueError, + "Value out of range: %s", + PyString_AsString(s)); + Py_DECREF(s); + } } return false; } @@ -298,49 +628,59 @@ bool CheckAndGetBool(PyObject* arg, bool* value) { return true; } -bool CheckAndSetString( - PyObject* arg, google::protobuf::Message* message, - const google::protobuf::FieldDescriptor* descriptor, - const google::protobuf::Reflection* reflection, - bool append, - int index) { - GOOGLE_DCHECK(descriptor->type() == google::protobuf::FieldDescriptor::TYPE_STRING || - descriptor->type() == google::protobuf::FieldDescriptor::TYPE_BYTES); - if (descriptor->type() == google::protobuf::FieldDescriptor::TYPE_STRING) { +// Checks whether the given object (which must be "bytes" or "unicode") contains +// valid UTF-8. +bool IsValidUTF8(PyObject* obj) { + if (PyBytes_Check(obj)) { + PyObject* unicode = PyUnicode_FromEncodedObject(obj, "utf-8", NULL); + + // Clear the error indicator; we report our own error when desired. 
+ PyErr_Clear(); + + if (unicode) { + Py_DECREF(unicode); + return true; + } else { + return false; + } + } else { + // Unicode object, known to be valid UTF-8. + return true; + } +} + +bool AllowInvalidUTF8(const FieldDescriptor* field) { return false; } + +PyObject* CheckString(PyObject* arg, const FieldDescriptor* descriptor) { + GOOGLE_DCHECK(descriptor->type() == FieldDescriptor::TYPE_STRING || + descriptor->type() == FieldDescriptor::TYPE_BYTES); + if (descriptor->type() == FieldDescriptor::TYPE_STRING) { if (!PyBytes_Check(arg) && !PyUnicode_Check(arg)) { FormatTypeError(arg, "bytes, unicode"); - return false; + return NULL; } - if (PyBytes_Check(arg)) { - PyObject* unicode = PyUnicode_FromEncodedObject(arg, "ascii", NULL); - if (unicode == NULL) { - PyObject* repr = PyObject_Repr(arg); - PyErr_Format(PyExc_ValueError, - "%s has type str, but isn't in 7-bit ASCII " - "encoding. Non-ASCII strings must be converted to " - "unicode objects before being added.", - PyString_AsString(repr)); - Py_DECREF(repr); - return false; - } else { - Py_DECREF(unicode); - } + if (!IsValidUTF8(arg) && !AllowInvalidUTF8(descriptor)) { + PyObject* repr = PyObject_Repr(arg); + PyErr_Format(PyExc_ValueError, + "%s has type str, but isn't valid UTF-8 " + "encoding. Non-UTF-8 strings must be converted to " + "unicode objects before being added.", + PyString_AsString(repr)); + Py_DECREF(repr); + return NULL; } } else if (!PyBytes_Check(arg)) { FormatTypeError(arg, "bytes"); - return false; + return NULL; } PyObject* encoded_string = NULL; - if (descriptor->type() == google::protobuf::FieldDescriptor::TYPE_STRING) { + if (descriptor->type() == FieldDescriptor::TYPE_STRING) { if (PyBytes_Check(arg)) { -#if PY_MAJOR_VERSION < 3 - encoded_string = PyString_AsEncodedObject(arg, "utf-8", NULL); -#else + // The bytes were already validated as correctly encoded UTF-8 above. encoded_string = arg; // Already encoded. 
Py_INCREF(encoded_string); -#endif } else { encoded_string = PyUnicode_AsEncodedObject(arg, "utf-8", NULL); } @@ -350,14 +690,24 @@ bool CheckAndSetString( Py_INCREF(encoded_string); } - if (encoded_string == NULL) { + return encoded_string; +} + +bool CheckAndSetString( + PyObject* arg, Message* message, + const FieldDescriptor* descriptor, + const Reflection* reflection, + bool append, + int index) { + ScopedPyObjectPtr encoded_string(CheckString(arg, descriptor)); + + if (encoded_string.get() == NULL) { return false; } char* value; Py_ssize_t value_len; - if (PyBytes_AsStringAndSize(encoded_string, &value, &value_len) < 0) { - Py_DECREF(encoded_string); + if (PyBytes_AsStringAndSize(encoded_string.get(), &value, &value_len) < 0) { return false; } @@ -369,13 +719,11 @@ bool CheckAndSetString( } else { reflection->SetRepeatedString(message, descriptor, index, value_string); } - Py_DECREF(encoded_string); return true; } -PyObject* ToStringObject( - const google::protobuf::FieldDescriptor* descriptor, string value) { - if (descriptor->type() != google::protobuf::FieldDescriptor::TYPE_STRING) { +PyObject* ToStringObject(const FieldDescriptor* descriptor, string value) { + if (descriptor->type() != FieldDescriptor::TYPE_STRING) { return PyBytes_FromStringAndSize(value.c_str(), value.length()); } @@ -391,16 +739,36 @@ PyObject* ToStringObject( return result; } -google::protobuf::DynamicMessageFactory* global_message_factory; +bool CheckFieldBelongsToMessage(const FieldDescriptor* field_descriptor, + const Message* message) { + if (message->GetDescriptor() == field_descriptor->containing_type()) { + return true; + } + PyErr_Format(PyExc_KeyError, "Field '%s' does not belong to message '%s'", + field_descriptor->full_name().c_str(), + message->GetDescriptor()->full_name().c_str()); + return false; +} namespace cmessage { +PyDescriptorPool* GetDescriptorPoolForMessage(CMessage* message) { + // No need to check the type: the type of instances of CMessage is always + // an 
instance of CMessageClass. Let's prove it with a debug-only check. + GOOGLE_DCHECK(PyObject_TypeCheck(message, &CMessage_Type)); + return reinterpret_cast<CMessageClass*>(Py_TYPE(message))->py_descriptor_pool; +} + +MessageFactory* GetFactoryForMessage(CMessage* message) { + return GetDescriptorPoolForMessage(message)->message_factory; +} + static int MaybeReleaseOverlappingOneofField( CMessage* cmessage, - const google::protobuf::FieldDescriptor* field) { + const FieldDescriptor* field) { #ifdef GOOGLE_PROTOBUF_HAS_ONEOF - google::protobuf::Message* message = cmessage->message; - const google::protobuf::Reflection* reflection = message->GetReflection(); + Message* message = cmessage->message; + const Reflection* reflection = message->GetReflection(); if (!field->containing_oneof() || !reflection->HasOneof(*message, field->containing_oneof()) || reflection->HasField(*message, field)) { @@ -411,20 +779,20 @@ static int MaybeReleaseOverlappingOneofField( const OneofDescriptor* oneof = field->containing_oneof(); const FieldDescriptor* existing_field = reflection->GetOneofFieldDescriptor(*message, oneof); - if (existing_field->cpp_type() != google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { + if (existing_field->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) { // Non-message fields don't need to be released. return 0; } const char* field_name = existing_field->name().c_str(); - PyObject* child_message = PyDict_GetItemString( - cmessage->composite_fields, field_name); + PyObject* child_message = cmessage->composite_fields ? + PyDict_GetItemString(cmessage->composite_fields, field_name) : NULL; if (child_message == NULL) { // No python reference to this field so no need to release. 
return 0; } if (InternalReleaseFieldByDescriptor( - existing_field, child_message, message) < 0) { + cmessage, existing_field, child_message) < 0) { return -1; } return PyDict_DelItemString(cmessage->composite_fields, field_name); @@ -436,21 +804,21 @@ static int MaybeReleaseOverlappingOneofField( // --------------------------------------------------------------------- // Making a message writable -static google::protobuf::Message* GetMutableMessage( +static Message* GetMutableMessage( CMessage* parent, - const google::protobuf::FieldDescriptor* parent_field) { - google::protobuf::Message* parent_message = parent->message; - const google::protobuf::Reflection* reflection = parent_message->GetReflection(); + const FieldDescriptor* parent_field) { + Message* parent_message = parent->message; + const Reflection* reflection = parent_message->GetReflection(); if (MaybeReleaseOverlappingOneofField(parent, parent_field) < 0) { return NULL; } return reflection->MutableMessage( - parent_message, parent_field, global_message_factory); + parent_message, parent_field, GetFactoryForMessage(parent)); } struct FixupMessageReference : public ChildVisitor { // message must outlive this object. - explicit FixupMessageReference(google::protobuf::Message* message) : + explicit FixupMessageReference(Message* message) : message_(message) {} int VisitRepeatedCompositeContainer(RepeatedCompositeContainer* container) { @@ -463,8 +831,13 @@ struct FixupMessageReference : public ChildVisitor { return 0; } + int VisitMapContainer(MapContainer* container) { + container->message = message_; + return 0; + } + private: - google::protobuf::Message* message_; + Message* message_; }; int AssureWritable(CMessage* self) { @@ -476,20 +849,20 @@ int AssureWritable(CMessage* self) { // If parent is NULL but we are trying to modify a read-only message, this // is a reference to a constant default instance that needs to be replaced // with a mutable top-level message. 
- const Message* prototype = global_message_factory->GetPrototype( - self->message->GetDescriptor()); - self->message = prototype->New(); + self->message = self->message->New(); self->owner.reset(self->message); + // Cascade the new owner to eventual children: even if this message is + // empty, some submessages or repeated containers might exist already. + SetOwner(self, self->owner); } else { // Otherwise, we need a mutable child message. if (AssureWritable(self->parent) == -1) return -1; // Make self->message writable. - google::protobuf::Message* parent_message = self->parent->message; - google::protobuf::Message* mutable_message = GetMutableMessage( + Message* mutable_message = GetMutableMessage( self->parent, - self->parent_field->descriptor); + self->parent_field_descriptor); if (mutable_message == NULL) { return -1; } @@ -500,8 +873,8 @@ int AssureWritable(CMessage* self) { // When a CMessage is made writable its Message pointer is updated // to point to a new mutable Message. When that happens we need to // update any references to the old, read-only CMessage. There are - // three places such references occur: RepeatedScalarContainer, - // RepeatedCompositeContainer, and ExtensionDict. + // four places such references occur: RepeatedScalarContainer, + // RepeatedCompositeContainer, MapContainer, and ExtensionDict. if (self->extensions != NULL) self->extensions->message = self->message; if (ForEachCompositeField(self, FixupMessageReference(self->message)) == -1) @@ -512,26 +885,65 @@ int AssureWritable(CMessage* self) { // --- Globals: -static PyObject* GetDescriptor(CMessage* self, PyObject* name) { - PyObject* descriptors = - PyDict_GetItem(Py_TYPE(self)->tp_dict, k__descriptors); - if (descriptors == NULL) { - PyErr_SetString(PyExc_TypeError, "No __descriptors"); +// Retrieve a C++ FieldDescriptor for a message attribute. +// The C++ message must be valid. 
+// TODO(amauryfa): This function should stay internal, because exception +// handling is not consistent. +static const FieldDescriptor* GetFieldDescriptor( + CMessage* self, PyObject* name) { + const Descriptor *message_descriptor = self->message->GetDescriptor(); + char* field_name; + Py_ssize_t size; + if (PyString_AsStringAndSize(name, &field_name, &size) < 0) { return NULL; } - - return PyDict_GetItem(descriptors, name); + const FieldDescriptor *field_descriptor = + message_descriptor->FindFieldByName(string(field_name, size)); + if (field_descriptor == NULL) { + // Note: No exception is set! + return NULL; + } + return field_descriptor; } -static const google::protobuf::Message* CreateMessage(const char* message_type) { - string message_name(message_type); - const google::protobuf::Descriptor* descriptor = - GetDescriptorPool()->FindMessageTypeByName(message_name); - if (descriptor == NULL) { - PyErr_SetString(PyExc_TypeError, message_type); +// Retrieve a C++ FieldDescriptor for an extension handle. +const FieldDescriptor* GetExtensionDescriptor(PyObject* extension) { + ScopedPyObjectPtr cdescriptor; + if (!PyObject_TypeCheck(extension, &PyFieldDescriptor_Type)) { + // Most callers consider extensions as a plain dictionary. We should + // allow input which is not a field descriptor, and simply pretend it does + // not exist. + PyErr_SetObject(PyExc_KeyError, extension); return NULL; } - return global_message_factory->GetPrototype(descriptor); + return PyFieldDescriptor_AsDescriptor(extension); +} + +// If value is a string, convert it into an enum value based on the labels in +// descriptor, otherwise simply return value. Always returns a new reference. 
+static PyObject* GetIntegerEnumValue(const FieldDescriptor& descriptor, + PyObject* value) { + if (PyString_Check(value) || PyUnicode_Check(value)) { + const EnumDescriptor* enum_descriptor = descriptor.enum_type(); + if (enum_descriptor == NULL) { + PyErr_SetString(PyExc_TypeError, "not an enum field"); + return NULL; + } + char* enum_label; + Py_ssize_t size; + if (PyString_AsStringAndSize(value, &enum_label, &size) < 0) { + return NULL; + } + const EnumValueDescriptor* enum_value_descriptor = + enum_descriptor->FindValueByName(string(enum_label, size)); + if (enum_value_descriptor == NULL) { + PyErr_SetString(PyExc_ValueError, "unknown enum label"); + return NULL; + } + return PyInt_FromLong(enum_value_descriptor->number()); + } + Py_INCREF(value); + return value; } // If cmessage_list is not NULL, this function releases values into the @@ -539,12 +951,13 @@ static const google::protobuf::Message* CreateMessage(const char* message_type) // needs to do this to make sure CMessages stay alive if they're still // referenced after deletion. Repeated scalar container doesn't need to worry. 
int InternalDeleteRepeatedField( - google::protobuf::Message* message, - const google::protobuf::FieldDescriptor* field_descriptor, + CMessage* self, + const FieldDescriptor* field_descriptor, PyObject* slice, PyObject* cmessage_list) { + Message* message = self->message; Py_ssize_t length, from, to, step, slice_length; - const google::protobuf::Reflection* reflection = message->GetReflection(); + const Reflection* reflection = message->GetReflection(); int min, max; length = reflection->FieldSize(*message, field_descriptor); @@ -616,7 +1029,7 @@ int InternalDeleteRepeatedField( CMessage* last_cmessage = reinterpret_cast<CMessage*>( PyList_GET_ITEM(cmessage_list, PyList_GET_SIZE(cmessage_list) - 1)); repeated_composite_container::ReleaseLastTo( - field_descriptor, message, last_cmessage); + self, field_descriptor, last_cmessage); if (PySequence_DelItem(cmessage_list, -1) < 0) { return -1; } @@ -627,39 +1040,8 @@ int InternalDeleteRepeatedField( return 0; } -int InitAttributes(CMessage* self, PyObject* arg, PyObject* kwargs) { - ScopedPyObjectPtr descriptor; - if (arg == NULL) { - descriptor.reset( - PyObject_GetAttr(reinterpret_cast<PyObject*>(self), kDESCRIPTOR)); - if (descriptor == NULL) { - return NULL; - } - } else { - descriptor.reset(arg); - descriptor.inc(); - } - ScopedPyObjectPtr is_extendable(PyObject_GetAttr(descriptor, kis_extendable)); - if (is_extendable == NULL) { - return NULL; - } - int retcode = PyObject_IsTrue(is_extendable); - if (retcode == -1) { - return NULL; - } - if (retcode) { - PyObject* py_extension_dict = PyObject_CallObject( - reinterpret_cast<PyObject*>(&ExtensionDict_Type), NULL); - if (py_extension_dict == NULL) { - return NULL; - } - ExtensionDict* extension_dict = reinterpret_cast<ExtensionDict*>( - py_extension_dict); - extension_dict->parent = self; - extension_dict->message = self->message; - self->extensions = extension_dict; - } - +// Initializes fields of a message. Used in constructors. 
+int InitAttributes(CMessage* self, PyObject* kwargs) { if (kwargs == NULL) { return 0; } @@ -672,46 +1054,138 @@ int InitAttributes(CMessage* self, PyObject* arg, PyObject* kwargs) { PyErr_SetString(PyExc_ValueError, "Field name must be a string"); return -1; } - PyObject* py_cdescriptor = GetDescriptor(self, name); - if (py_cdescriptor == NULL) { - PyErr_Format(PyExc_ValueError, "Protocol message has no \"%s\" field.", + const FieldDescriptor* descriptor = GetFieldDescriptor(self, name); + if (descriptor == NULL) { + PyErr_Format(PyExc_ValueError, "Protocol message %s has no \"%s\" field.", + self->message->GetDescriptor()->name().c_str(), PyString_AsString(name)); return -1; } - const google::protobuf::FieldDescriptor* descriptor = - reinterpret_cast<CFieldDescriptor*>(py_cdescriptor)->descriptor; - if (descriptor->label() == google::protobuf::FieldDescriptor::LABEL_REPEATED) { + if (value == Py_None) { + // field=None is the same as no field at all. + continue; + } + if (descriptor->is_map()) { + ScopedPyObjectPtr map(GetAttr(self, name)); + const FieldDescriptor* value_descriptor = + descriptor->message_type()->FindFieldByName("value"); + if (value_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { + Py_ssize_t map_pos = 0; + PyObject* map_key; + PyObject* map_value; + while (PyDict_Next(value, &map_pos, &map_key, &map_value)) { + ScopedPyObjectPtr function_return; + function_return.reset(PyObject_GetItem(map.get(), map_key)); + if (function_return.get() == NULL) { + return -1; + } + ScopedPyObjectPtr ok(PyObject_CallMethod( + function_return.get(), "MergeFrom", "O", map_value)); + if (ok.get() == NULL) { + return -1; + } + } + } else { + ScopedPyObjectPtr function_return; + function_return.reset( + PyObject_CallMethod(map.get(), "update", "O", value)); + if (function_return.get() == NULL) { + return -1; + } + } + } else if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) { ScopedPyObjectPtr container(GetAttr(self, name)); if (container == 
NULL) { return -1; } - if (descriptor->cpp_type() == google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { - if (repeated_composite_container::Extend( - reinterpret_cast<RepeatedCompositeContainer*>(container.get()), - value) - == NULL) { + if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { + RepeatedCompositeContainer* rc_container = + reinterpret_cast<RepeatedCompositeContainer*>(container.get()); + ScopedPyObjectPtr iter(PyObject_GetIter(value)); + if (iter == NULL) { + PyErr_SetString(PyExc_TypeError, "Value must be iterable"); + return -1; + } + ScopedPyObjectPtr next; + while ((next.reset(PyIter_Next(iter.get()))) != NULL) { + PyObject* kwargs = (PyDict_Check(next.get()) ? next.get() : NULL); + ScopedPyObjectPtr new_msg( + repeated_composite_container::Add(rc_container, NULL, kwargs)); + if (new_msg == NULL) { + return -1; + } + if (kwargs == NULL) { + // next was not a dict, it's a message we need to merge + ScopedPyObjectPtr merged(MergeFrom( + reinterpret_cast<CMessage*>(new_msg.get()), next.get())); + if (merged.get() == NULL) { + return -1; + } + } + } + if (PyErr_Occurred()) { + // Check to see how PyIter_Next() exited. + return -1; + } + } else if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) { + RepeatedScalarContainer* rs_container = + reinterpret_cast<RepeatedScalarContainer*>(container.get()); + ScopedPyObjectPtr iter(PyObject_GetIter(value)); + if (iter == NULL) { + PyErr_SetString(PyExc_TypeError, "Value must be iterable"); + return -1; + } + ScopedPyObjectPtr next; + while ((next.reset(PyIter_Next(iter.get()))) != NULL) { + ScopedPyObjectPtr enum_value( + GetIntegerEnumValue(*descriptor, next.get())); + if (enum_value == NULL) { + return -1; + } + ScopedPyObjectPtr new_msg(repeated_scalar_container::Append( + rs_container, enum_value.get())); + if (new_msg == NULL) { + return -1; + } + } + if (PyErr_Occurred()) { + // Check to see how PyIter_Next() exited. 
return -1; } } else { - if (repeated_scalar_container::Extend( + if (ScopedPyObjectPtr(repeated_scalar_container::Extend( reinterpret_cast<RepeatedScalarContainer*>(container.get()), - value) == + value)) == NULL) { return -1; } } - } else if (descriptor->cpp_type() == - google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { + } else if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { ScopedPyObjectPtr message(GetAttr(self, name)); if (message == NULL) { return -1; } - if (MergeFrom(reinterpret_cast<CMessage*>(message.get()), - value) == NULL) { - return -1; + CMessage* cmessage = reinterpret_cast<CMessage*>(message.get()); + if (PyDict_Check(value)) { + if (InitAttributes(cmessage, value) < 0) { + return -1; + } + } else { + ScopedPyObjectPtr merged(MergeFrom(cmessage, value)); + if (merged == NULL) { + return -1; + } } } else { - if (SetAttr(self, name, value) < 0) { + ScopedPyObjectPtr new_val; + if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) { + new_val.reset(GetIntegerEnumValue(*descriptor, value)); + if (new_val == NULL) { + return -1; + } + } + if (SetAttr(self, name, (new_val.get() == NULL) ? value : new_val.get()) < + 0) { return -1; } } @@ -719,59 +1193,64 @@ int InitAttributes(CMessage* self, PyObject* arg, PyObject* kwargs) { return 0; } -static PyObject* New(PyTypeObject* type, PyObject* args, PyObject* kwargs) { - CMessage* self = reinterpret_cast<CMessage*>(type->tp_alloc(type, 0)); +// Allocates an incomplete Python Message: the caller must fill self->message, +// self->owner and eventually self->parent. 
+CMessage* NewEmptyMessage(CMessageClass* type) { + CMessage* self = reinterpret_cast<CMessage*>( + PyType_GenericAlloc(&type->super.ht_type, 0)); if (self == NULL) { return NULL; } self->message = NULL; self->parent = NULL; - self->parent_field = NULL; + self->parent_field_descriptor = NULL; self->read_only = false; self->extensions = NULL; - self->composite_fields = PyDict_New(); - if (self->composite_fields == NULL) { - return NULL; - } - return reinterpret_cast<PyObject*>(self); -} + self->composite_fields = NULL; -PyObject* NewEmpty(PyObject* type) { - return New(reinterpret_cast<PyTypeObject*>(type), NULL, NULL); + return self; } -static int Init(CMessage* self, PyObject* args, PyObject* kwargs) { - if (kwargs == NULL) { - // TODO(anuraag): Set error - return -1; +// The __new__ method of Message classes. +// Creates a new C++ message and takes ownership. +static PyObject* New(PyTypeObject* cls, + PyObject* unused_args, PyObject* unused_kwargs) { + CMessageClass* type = CheckMessageClass(cls); + if (type == NULL) { + return NULL; } - - PyObject* descriptor = PyTuple_GetItem(args, 0); - if (descriptor == NULL || PyTuple_Size(args) != 1) { - PyErr_SetString(PyExc_ValueError, "args must contain one arg: descriptor"); - return -1; + // Retrieve the message descriptor and the default instance (=prototype). 
+ const Descriptor* message_descriptor = type->message_descriptor; + if (message_descriptor == NULL) { + return NULL; } - - ScopedPyObjectPtr py_message_type(PyObject_GetAttr(descriptor, kfull_name)); - if (py_message_type == NULL) { - return -1; + const Message* default_message = type->py_descriptor_pool->message_factory + ->GetPrototype(message_descriptor); + if (default_message == NULL) { + PyErr_SetString(PyExc_TypeError, message_descriptor->full_name().c_str()); + return NULL; } - const char* message_type = PyString_AsString(py_message_type.get()); - const google::protobuf::Message* message = CreateMessage(message_type); - if (message == NULL) { - return -1; + CMessage* self = NewEmptyMessage(type); + if (self == NULL) { + return NULL; } - - self->message = message->New(); + self->message = default_message->New(); self->owner.reset(self->message); + return reinterpret_cast<PyObject*>(self); +} - if (InitAttributes(self, descriptor, kwargs) < 0) { +// The __init__ method of Message classes. +// It initializes fields from keywords passed to the constructor. +static int Init(CMessage* self, PyObject* args, PyObject* kwargs) { + if (PyTuple_Size(args) != 0) { + PyErr_SetString(PyExc_TypeError, "No positional arguments allowed"); return -1; } - return 0; + + return InitAttributes(self, kwargs); } // --------------------------------------------------------------------- @@ -800,8 +1279,13 @@ struct ClearWeakReferences : public ChildVisitor { return 0; } + int VisitMapContainer(MapContainer* container) { + container->parent = NULL; + return 0; + } + int VisitCMessage(CMessage* cmessage, - const google::protobuf::FieldDescriptor* field_descriptor) { + const FieldDescriptor* field_descriptor) { cmessage->parent = NULL; return 0; } @@ -810,6 +1294,9 @@ struct ClearWeakReferences : public ChildVisitor { static void Dealloc(CMessage* self) { // Null out all weak references from children to this message. 
GOOGLE_CHECK_EQ(0, ForEachCompositeField(self, ClearWeakReferences())); + if (self->extensions) { + self->extensions->parent = NULL; + } Py_CLEAR(self->extensions); Py_CLEAR(self->composite_fields); @@ -851,14 +1338,12 @@ PyObject* IsInitialized(CMessage* self, PyObject* args) { } PyObject* HasFieldByDescriptor( - CMessage* self, const google::protobuf::FieldDescriptor* field_descriptor) { - google::protobuf::Message* message = self->message; - if (!FIELD_BELONGS_TO_MESSAGE(field_descriptor, message)) { - PyErr_SetString(PyExc_KeyError, - "Field does not belong to message!"); + CMessage* self, const FieldDescriptor* field_descriptor) { + Message* message = self->message; + if (!CheckFieldBelongsToMessage(field_descriptor, message)) { return NULL; } - if (field_descriptor->label() == google::protobuf::FieldDescriptor::LABEL_REPEATED) { + if (field_descriptor->label() == FieldDescriptor::LABEL_REPEATED) { PyErr_SetString(PyExc_KeyError, "Field is repeated. A singular method is required."); return NULL; @@ -868,42 +1353,78 @@ PyObject* HasFieldByDescriptor( return PyBool_FromLong(has_field ? 
1 : 0); } -const google::protobuf::FieldDescriptor* FindFieldWithOneofs( - const google::protobuf::Message* message, const char* field_name, bool* in_oneof) { - const google::protobuf::Descriptor* descriptor = message->GetDescriptor(); - const google::protobuf::FieldDescriptor* field_descriptor = +const FieldDescriptor* FindFieldWithOneofs( + const Message* message, const string& field_name, bool* in_oneof) { + *in_oneof = false; + const Descriptor* descriptor = message->GetDescriptor(); + const FieldDescriptor* field_descriptor = descriptor->FindFieldByName(field_name); - if (field_descriptor == NULL) { - const google::protobuf::OneofDescriptor* oneof_desc = - message->GetDescriptor()->FindOneofByName(field_name); - if (oneof_desc == NULL) { - *in_oneof = false; - return NULL; - } else { - *in_oneof = true; - return message->GetReflection()->GetOneofFieldDescriptor( - *message, oneof_desc); + if (field_descriptor != NULL) { + return field_descriptor; + } + const OneofDescriptor* oneof_desc = + descriptor->FindOneofByName(field_name); + if (oneof_desc != NULL) { + *in_oneof = true; + return message->GetReflection()->GetOneofFieldDescriptor(*message, + oneof_desc); + } + return NULL; +} + +bool CheckHasPresence(const FieldDescriptor* field_descriptor, bool in_oneof) { + if (field_descriptor->label() == FieldDescriptor::LABEL_REPEATED) { + PyErr_Format(PyExc_ValueError, + "Protocol message has no singular \"%s\" field.", + field_descriptor->name().c_str()); + return false; + } + + if (field_descriptor->file()->syntax() == FileDescriptor::SYNTAX_PROTO3) { + // HasField() for a oneof *itself* isn't supported. + if (in_oneof) { + PyErr_Format(PyExc_ValueError, + "Can't test oneof field \"%s\" for presence in proto3, use " + "WhichOneof instead.", + field_descriptor->containing_oneof()->name().c_str()); + return false; + } + + // ...but HasField() for fields *in* a oneof is supported. 
+ if (field_descriptor->containing_oneof() != NULL) { + return true; + } + + if (field_descriptor->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) { + PyErr_Format( + PyExc_ValueError, + "Can't test non-submessage field \"%s\" for presence in proto3.", + field_descriptor->name().c_str()); + return false; } } - return field_descriptor; + + return true; } PyObject* HasField(CMessage* self, PyObject* arg) { -#if PY_MAJOR_VERSION < 3 char* field_name; - if (PyString_AsStringAndSize(arg, &field_name, NULL) < 0) { + Py_ssize_t size; +#if PY_MAJOR_VERSION < 3 + if (PyString_AsStringAndSize(arg, &field_name, &size) < 0) { + return NULL; + } #else - char* field_name = PyUnicode_AsUTF8(arg); + field_name = PyUnicode_AsUTF8AndSize(arg, &size); if (!field_name) { -#endif return NULL; } +#endif - google::protobuf::Message* message = self->message; - const google::protobuf::Descriptor* descriptor = message->GetDescriptor(); + Message* message = self->message; bool is_in_oneof; - const google::protobuf::FieldDescriptor* field_descriptor = - FindFieldWithOneofs(message, field_name, &is_in_oneof); + const FieldDescriptor* field_descriptor = + FindFieldWithOneofs(message, string(field_name, size), &is_in_oneof); if (field_descriptor == NULL) { if (!is_in_oneof) { PyErr_Format(PyExc_ValueError, "Unknown field %s.", field_name); @@ -913,44 +1434,51 @@ PyObject* HasField(CMessage* self, PyObject* arg) { } } - if (field_descriptor->label() == google::protobuf::FieldDescriptor::LABEL_REPEATED) { - PyErr_Format(PyExc_ValueError, - "Protocol message has no singular \"%s\" field.", field_name); + if (!CheckHasPresence(field_descriptor, is_in_oneof)) { return NULL; } - bool has_field = - message->GetReflection()->HasField(*message, field_descriptor); - if (!has_field && field_descriptor->cpp_type() == - google::protobuf::FieldDescriptor::CPPTYPE_ENUM) { - // We may have an invalid enum value stored in the UnknownFieldSet and need - // to check presence in there as well. 
- const google::protobuf::UnknownFieldSet& unknown_field_set = + if (message->GetReflection()->HasField(*message, field_descriptor)) { + Py_RETURN_TRUE; + } + if (!message->GetReflection()->SupportsUnknownEnumValues() && + field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) { + // Special case: Python HasField() differs in semantics from C++ + // slightly: we return HasField('enum_field') == true if there is + // an unknown enum value present. To implement this we have to + // look in the UnknownFieldSet. + const UnknownFieldSet& unknown_field_set = message->GetReflection()->GetUnknownFields(*message); for (int i = 0; i < unknown_field_set.field_count(); ++i) { if (unknown_field_set.field(i).number() == field_descriptor->number()) { Py_RETURN_TRUE; } } - Py_RETURN_FALSE; } - return PyBool_FromLong(has_field ? 1 : 0); + Py_RETURN_FALSE; } -PyObject* ClearExtension(CMessage* self, PyObject* arg) { +PyObject* ClearExtension(CMessage* self, PyObject* extension) { if (self->extensions != NULL) { - return extension_dict::ClearExtension(self->extensions, arg); + return extension_dict::ClearExtension(self->extensions, extension); + } else { + const FieldDescriptor* descriptor = GetExtensionDescriptor(extension); + if (descriptor == NULL) { + return NULL; + } + if (ScopedPyObjectPtr(ClearFieldByDescriptor(self, descriptor)) == NULL) { + return NULL; + } } - PyErr_SetString(PyExc_TypeError, "Message is not extendable"); - return NULL; + Py_RETURN_NONE; } -PyObject* HasExtension(CMessage* self, PyObject* arg) { - if (self->extensions != NULL) { - return extension_dict::HasExtension(self->extensions, arg); +PyObject* HasExtension(CMessage* self, PyObject* extension) { + const FieldDescriptor* descriptor = GetExtensionDescriptor(extension); + if (descriptor == NULL) { + return NULL; } - PyErr_SetString(PyExc_TypeError, "Message is not extendable"); - return NULL; + return HasFieldByDescriptor(self, descriptor); } // 
--------------------------------------------------------------------- @@ -1000,8 +1528,13 @@ struct SetOwnerVisitor : public ChildVisitor { return 0; } + int VisitMapContainer(MapContainer* container) { + container->SetOwner(new_owner_); + return 0; + } + int VisitCMessage(CMessage* cmessage, - const google::protobuf::FieldDescriptor* field_descriptor) { + const FieldDescriptor* field_descriptor) { return SetOwner(cmessage, new_owner_); } @@ -1020,18 +1553,18 @@ int SetOwner(CMessage* self, const shared_ptr<Message>& new_owner) { // Releases the message specified by 'field' and returns the // pointer. If the field does not exist a new message is created using // 'descriptor'. The caller takes ownership of the returned pointer. -Message* ReleaseMessage(google::protobuf::Message* message, - const google::protobuf::Descriptor* descriptor, - const google::protobuf::FieldDescriptor* field_descriptor) { - Message* released_message = message->GetReflection()->ReleaseMessage( - message, field_descriptor, global_message_factory); +Message* ReleaseMessage(CMessage* self, + const Descriptor* descriptor, + const FieldDescriptor* field_descriptor) { + MessageFactory* message_factory = GetFactoryForMessage(self); + Message* released_message = self->message->GetReflection()->ReleaseMessage( + self->message, field_descriptor, message_factory); // ReleaseMessage will return NULL which differs from // child_cmessage->message, if the field does not exist. In this case, // the latter points to the default instance via a const_cast<>, so we // have to reset it to a new mutable object since we are taking ownership. 
if (released_message == NULL) { - const Message* prototype = global_message_factory->GetPrototype( - descriptor); + const Message* prototype = message_factory->GetPrototype(descriptor); GOOGLE_DCHECK(prototype != NULL); released_message = prototype->New(); } @@ -1039,16 +1572,16 @@ Message* ReleaseMessage(google::protobuf::Message* message, return released_message; } -int ReleaseSubMessage(google::protobuf::Message* message, - const google::protobuf::FieldDescriptor* field_descriptor, +int ReleaseSubMessage(CMessage* self, + const FieldDescriptor* field_descriptor, CMessage* child_cmessage) { // Release the Message shared_ptr<Message> released_message(ReleaseMessage( - message, child_cmessage->message->GetDescriptor(), field_descriptor)); + self, child_cmessage->message->GetDescriptor(), field_descriptor)); child_cmessage->message = released_message.get(); child_cmessage->owner.swap(released_message); child_cmessage->parent = NULL; - child_cmessage->parent_field = NULL; + child_cmessage->parent_field_descriptor = NULL; child_cmessage->read_only = false; return ForEachCompositeField(child_cmessage, SetOwnerVisitor(child_cmessage->owner)); @@ -1056,8 +1589,8 @@ int ReleaseSubMessage(google::protobuf::Message* message, struct ReleaseChild : public ChildVisitor { // message must outlive this object. 
- explicit ReleaseChild(google::protobuf::Message* parent_message) : - parent_message_(parent_message) {} + explicit ReleaseChild(CMessage* parent) : + parent_(parent) {} int VisitRepeatedCompositeContainer(RepeatedCompositeContainer* container) { return repeated_composite_container::Release( @@ -1069,44 +1602,33 @@ struct ReleaseChild : public ChildVisitor { reinterpret_cast<RepeatedScalarContainer*>(container)); } + int VisitMapContainer(MapContainer* container) { + return reinterpret_cast<MapContainer*>(container)->Release(); + } + int VisitCMessage(CMessage* cmessage, - const google::protobuf::FieldDescriptor* field_descriptor) { - return ReleaseSubMessage(parent_message_, field_descriptor, + const FieldDescriptor* field_descriptor) { + return ReleaseSubMessage(parent_, field_descriptor, reinterpret_cast<CMessage*>(cmessage)); } - google::protobuf::Message* parent_message_; + CMessage* parent_; }; int InternalReleaseFieldByDescriptor( - const google::protobuf::FieldDescriptor* field_descriptor, - PyObject* composite_field, - google::protobuf::Message* parent_message) { + CMessage* self, + const FieldDescriptor* field_descriptor, + PyObject* composite_field) { return VisitCompositeField( field_descriptor, composite_field, - ReleaseChild(parent_message)); -} - -int InternalReleaseField(CMessage* self, PyObject* composite_field, - PyObject* name) { - PyObject* cdescriptor = GetDescriptor(self, name); - if (cdescriptor != NULL) { - const google::protobuf::FieldDescriptor* descriptor = - reinterpret_cast<CFieldDescriptor*>(cdescriptor)->descriptor; - return InternalReleaseFieldByDescriptor( - descriptor, composite_field, self->message); - } - - return 0; + ReleaseChild(self)); } PyObject* ClearFieldByDescriptor( CMessage* self, - const google::protobuf::FieldDescriptor* descriptor) { - if (!FIELD_BELONGS_TO_MESSAGE(descriptor, self->message)) { - PyErr_SetString(PyExc_KeyError, - "Field does not belong to message!"); + const FieldDescriptor* descriptor) { + if 
(!CheckFieldBelongsToMessage(descriptor, self->message)) { return NULL; } AssureWritable(self); @@ -1115,25 +1637,23 @@ PyObject* ClearFieldByDescriptor( } PyObject* ClearField(CMessage* self, PyObject* arg) { - char* field_name; if (!PyString_Check(arg)) { PyErr_SetString(PyExc_TypeError, "field name must be a string"); return NULL; } #if PY_MAJOR_VERSION < 3 - if (PyString_AsStringAndSize(arg, &field_name, NULL) < 0) { - return NULL; - } + const char* field_name = PyString_AS_STRING(arg); + Py_ssize_t size = PyString_GET_SIZE(arg); #else - field_name = PyUnicode_AsUTF8(arg); + Py_ssize_t size; + const char* field_name = PyUnicode_AsUTF8AndSize(arg, &size); #endif AssureWritable(self); - google::protobuf::Message* message = self->message; - const google::protobuf::Descriptor* descriptor = message->GetDescriptor(); + Message* message = self->message; ScopedPyObjectPtr arg_in_oneof; bool is_in_oneof; - const google::protobuf::FieldDescriptor* field_descriptor = - FindFieldWithOneofs(message, field_name, &is_in_oneof); + const FieldDescriptor* field_descriptor = + FindFieldWithOneofs(message, string(field_name, size), &is_in_oneof); if (field_descriptor == NULL) { if (!is_in_oneof) { PyErr_Format(PyExc_ValueError, @@ -1143,24 +1663,27 @@ PyObject* ClearField(CMessage* self, PyObject* arg) { Py_RETURN_NONE; } } else if (is_in_oneof) { - arg_in_oneof.reset(PyString_FromString(field_descriptor->name().c_str())); + const string& name = field_descriptor->name(); + arg_in_oneof.reset(PyString_FromStringAndSize(name.c_str(), name.size())); arg = arg_in_oneof.get(); } - PyObject* composite_field = PyDict_GetItem(self->composite_fields, - arg); + PyObject* composite_field = self->composite_fields ? + PyDict_GetItem(self->composite_fields, arg) : NULL; // Only release the field if there's a possibility that there are // references to it. 
if (composite_field != NULL) { - if (InternalReleaseField(self, composite_field, arg) < 0) { + if (InternalReleaseFieldByDescriptor(self, field_descriptor, + composite_field) < 0) { return NULL; } PyDict_DelItem(self->composite_fields, arg); } message->GetReflection()->ClearField(message, field_descriptor); - if (field_descriptor->cpp_type() == google::protobuf::FieldDescriptor::CPPTYPE_ENUM) { - google::protobuf::UnknownFieldSet* unknown_field_set = + if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_ENUM && + !message->GetReflection()->SupportsUnknownEnumValues()) { + UnknownFieldSet* unknown_field_set = message->GetReflection()->MutableUnknownFields(message); unknown_field_set->DeleteByNumber(field_descriptor->number()); } @@ -1170,25 +1693,12 @@ PyObject* ClearField(CMessage* self, PyObject* arg) { PyObject* Clear(CMessage* self) { AssureWritable(self); - if (ForEachCompositeField(self, ReleaseChild(self->message)) == -1) + if (ForEachCompositeField(self, ReleaseChild(self)) == -1) return NULL; - - // The old ExtensionDict still aliases this CMessage, but all its - // fields have been released. 
- if (self->extensions != NULL) { - Py_CLEAR(self->extensions); - PyObject* py_extension_dict = PyObject_CallObject( - reinterpret_cast<PyObject*>(&ExtensionDict_Type), NULL); - if (py_extension_dict == NULL) { - return NULL; - } - ExtensionDict* extension_dict = reinterpret_cast<ExtensionDict*>( - py_extension_dict); - extension_dict->parent = self; - extension_dict->message = self->message; - self->extensions = extension_dict; + Py_CLEAR(self->extensions); + if (self->composite_fields) { + PyDict_Clear(self->composite_fields); } - PyDict_Clear(self->composite_fields); self->message->Clear(); Py_RETURN_NONE; } @@ -1196,8 +1706,8 @@ PyObject* Clear(CMessage* self) { // --------------------------------------------------------------------- static string GetMessageName(CMessage* self) { - if (self->parent_field != NULL) { - return self->parent_field->descriptor->full_name(); + if (self->parent_field_descriptor != NULL) { + return self->parent_field_descriptor->full_name(); } else { return self->message->GetDescriptor()->full_name(); } @@ -1218,7 +1728,27 @@ static PyObject* SerializeToString(CMessage* self, PyObject* args) { if (joined == NULL) { return NULL; } - PyErr_Format(EncodeError_class, "Message %s is missing required fields: %s", + + // TODO(haberman): this is a (hopefully temporary) hack. The unit testing + // infrastructure reloads all pure-Python modules for every test, but not + // C++ modules (because that's generally impossible: + // http://bugs.python.org/issue1144263). But if we cache EncodeError, we'll + // return the EncodeError from a previous load of the module, which won't + // match a user's attempt to catch EncodeError. So we have to look it up + // again every time. 
+ ScopedPyObjectPtr message_module(PyImport_ImportModule( + "google.protobuf.message")); + if (message_module.get() == NULL) { + return NULL; + } + + ScopedPyObjectPtr encode_error( + PyObject_GetAttrString(message_module.get(), "EncodeError")); + if (encode_error.get() == NULL) { + return NULL; + } + PyErr_Format(encode_error.get(), + "Message %s is missing required fields: %s", GetMessageName(self).c_str(), PyString_AsString(joined.get())); return NULL; } @@ -1244,40 +1774,44 @@ static PyObject* SerializePartialToString(CMessage* self) { // Formats proto fields for ascii dumps using python formatting functions where // appropriate. -class PythonFieldValuePrinter : public google::protobuf::TextFormat::FieldValuePrinter { +class PythonFieldValuePrinter : public TextFormat::FieldValuePrinter { public: - PythonFieldValuePrinter() : float_holder_(PyFloat_FromDouble(0)) {} - // Python has some differences from C++ when printing floating point numbers. // // 1) Trailing .0 is always printed. - // 2) Outputted is rounded to 12 digits. + // 2) (Python2) Output is rounded to 12 digits. + // 3) (Python3) The full precision of the double is preserved (and Python uses + // David M. Gay's dtoa(), when the C++ code uses SimpleDtoa. There are some + // differences, but they rarely happen) // // We override floating point printing with the C-API function for printing // Python floats to ensure consistency. string PrintFloat(float value) const { return PrintDouble(value); } string PrintDouble(double value) const { - reinterpret_cast<PyFloatObject*>(float_holder_.get())->ob_fval = value; - ScopedPyObjectPtr s(PyObject_Str(float_holder_.get())); - if (s == NULL) return string(); -#if PY_MAJOR_VERSION < 3 - char *cstr = PyBytes_AS_STRING(static_cast<PyObject*>(s)); -#else - char *cstr = PyUnicode_AsUTF8(s); -#endif - return string(cstr); - } + // This implementation is not highly optimized (it allocates two temporary + // Python objects) but it is simple and portable. 
If this is shown to be a + // performance bottleneck, we can optimize it, but the results will likely + // be more complicated to accommodate the differing behavior of double + // formatting between Python 2 and Python 3. + // + // (Though a valid question is: do we really want to make out output + // dependent on the Python version?) + ScopedPyObjectPtr py_value(PyFloat_FromDouble(value)); + if (!py_value.get()) { + return string(); + } - private: - // Holder for a python float object which we use to allow us to use - // the Python API for printing doubles. We initialize once and then - // directly modify it for every float printed to save on allocations - // and refcounting. - ScopedPyObjectPtr float_holder_; + ScopedPyObjectPtr py_str(PyObject_Str(py_value.get())); + if (!py_str.get()) { + return string(); + } + + return string(PyString_AsString(py_str.get())); + } }; static PyObject* ToStr(CMessage* self) { - google::protobuf::TextFormat::Printer printer; + TextFormat::Printer printer; // Passes ownership printer.SetDefaultFieldValuePrinter(new PythonFieldValuePrinter()); printer.SetHideUnknownFields(true); @@ -1291,8 +1825,12 @@ static PyObject* ToStr(CMessage* self) { PyObject* MergeFrom(CMessage* self, PyObject* arg) { CMessage* other_message; - if (!PyObject_TypeCheck(reinterpret_cast<PyObject *>(arg), &CMessage_Type)) { - PyErr_SetString(PyExc_TypeError, "Must be a message"); + if (!PyObject_TypeCheck(arg, &CMessage_Type)) { + PyErr_Format(PyExc_TypeError, + "Parameter to MergeFrom() must be instance of same class: " + "expected %s got %s.", + self->message->GetDescriptor()->full_name().c_str(), + Py_TYPE(arg)->tp_name); return NULL; } @@ -1300,8 +1838,8 @@ PyObject* MergeFrom(CMessage* self, PyObject* arg) { if (other_message->message->GetDescriptor() != self->message->GetDescriptor()) { PyErr_Format(PyExc_TypeError, - "Tried to merge from a message with a different type. 
" - "to: %s, from: %s", + "Parameter to MergeFrom() must be instance of same class: " + "expected %s got %s.", self->message->GetDescriptor()->full_name().c_str(), other_message->message->GetDescriptor()->full_name().c_str()); return NULL; @@ -1319,8 +1857,12 @@ PyObject* MergeFrom(CMessage* self, PyObject* arg) { static PyObject* CopyFrom(CMessage* self, PyObject* arg) { CMessage* other_message; - if (!PyObject_TypeCheck(reinterpret_cast<PyObject *>(arg), &CMessage_Type)) { - PyErr_SetString(PyExc_TypeError, "Must be a message"); + if (!PyObject_TypeCheck(arg, &CMessage_Type)) { + PyErr_Format(PyExc_TypeError, + "Parameter to CopyFrom() must be instance of same class: " + "expected %s got %s.", + self->message->GetDescriptor()->full_name().c_str(), + Py_TYPE(arg)->tp_name); return NULL; } @@ -1333,8 +1875,8 @@ static PyObject* CopyFrom(CMessage* self, PyObject* arg) { if (other_message->message->GetDescriptor() != self->message->GetDescriptor()) { PyErr_Format(PyExc_TypeError, - "Tried to copy from a message with a different type. " - "to: %s, from: %s", + "Parameter to CopyFrom() must be instance of same class: " + "expected %s got %s.", self->message->GetDescriptor()->full_name().c_str(), other_message->message->GetDescriptor()->full_name().c_str()); return NULL; @@ -1344,13 +1886,37 @@ static PyObject* CopyFrom(CMessage* self, PyObject* arg) { // CopyFrom on the message will not clean up self->composite_fields, // which can leave us in an inconsistent state, so clear it out here. - Clear(self); + (void)ScopedPyObjectPtr(Clear(self)); self->message->CopyFrom(*other_message->message); Py_RETURN_NONE; } +// Protobuf has a 64MB limit built in, this variable will override this. Please +// do not enable this unless you fully understand the implications: protobufs +// must all be kept in memory at the same time, so if they grow too big you may +// get OOM errors. The protobuf APIs do not provide any tools for processing +// protobufs in chunks. 
If you have protos this big you should break them up if +// it is at all convenient to do so. +static bool allow_oversize_protos = false; + +// Provide a method in the module to set allow_oversize_protos to a boolean +// value. This method returns the newly value of allow_oversize_protos. +static PyObject* SetAllowOversizeProtos(PyObject* m, PyObject* arg) { + if (!arg || !PyBool_Check(arg)) { + PyErr_SetString(PyExc_TypeError, + "Argument to SetAllowOversizeProtos must be boolean"); + return NULL; + } + allow_oversize_protos = PyObject_IsTrue(arg); + if (allow_oversize_protos) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } +} + static PyObject* MergeFromString(CMessage* self, PyObject* arg) { const void* data; Py_ssize_t data_length; @@ -1359,9 +1925,13 @@ static PyObject* MergeFromString(CMessage* self, PyObject* arg) { } AssureWritable(self); - google::protobuf::io::CodedInputStream input( + io::CodedInputStream input( reinterpret_cast<const uint8*>(data), data_length); - input.SetExtensionRegistry(GetDescriptorPool(), global_message_factory); + if (allow_oversize_protos) { + input.SetTotalBytesLimit(INT_MAX, INT_MAX); + } + PyDescriptorPool* pool = GetDescriptorPoolForMessage(self); + input.SetExtensionRegistry(pool->pool, pool->message_factory); bool success = self->message->MergePartialFromCodedStream(&input); if (success) { return PyInt_FromLong(input.CurrentPosition()); @@ -1372,7 +1942,7 @@ static PyObject* MergeFromString(CMessage* self, PyObject* arg) { } static PyObject* ParseFromString(CMessage* self, PyObject* arg) { - if (Clear(self) == NULL) { + if (ScopedPyObjectPtr(Clear(self)) == NULL) { return NULL; } return MergeFromString(self, arg); @@ -1384,14 +1954,12 @@ static PyObject* ByteSize(CMessage* self, PyObject* args) { static PyObject* RegisterExtension(PyObject* cls, PyObject* extension_handle) { - ScopedPyObjectPtr message_descriptor(PyObject_GetAttr(cls, kDESCRIPTOR)); - if (message_descriptor == NULL) { - return NULL; - } - if 
(PyObject_SetAttrString(extension_handle, "containing_type", - message_descriptor) < 0) { + const FieldDescriptor* descriptor = + GetExtensionDescriptor(extension_handle); + if (descriptor == NULL) { return NULL; } + ScopedPyObjectPtr extensions_by_name( PyObject_GetAttr(cls, k_extensions_by_name)); if (extensions_by_name == NULL) { @@ -1402,7 +1970,23 @@ static PyObject* RegisterExtension(PyObject* cls, if (full_name == NULL) { return NULL; } - if (PyDict_SetItem(extensions_by_name, full_name, extension_handle) < 0) { + + // If the extension was already registered, check that it is the same. + PyObject* existing_extension = + PyDict_GetItem(extensions_by_name.get(), full_name.get()); + if (existing_extension != NULL) { + const FieldDescriptor* existing_extension_descriptor = + GetExtensionDescriptor(existing_extension); + if (existing_extension_descriptor != descriptor) { + PyErr_SetString(PyExc_ValueError, "Double registration of Extensions"); + return NULL; + } + // Nothing else to do. + Py_RETURN_NONE; + } + + if (PyDict_SetItem(extensions_by_name.get(), full_name.get(), + extension_handle) < 0) { return NULL; } @@ -1413,36 +1997,52 @@ static PyObject* RegisterExtension(PyObject* cls, PyErr_SetString(PyExc_TypeError, "no extensions_by_number on class"); return NULL; } + ScopedPyObjectPtr number(PyObject_GetAttrString(extension_handle, "number")); if (number == NULL) { return NULL; } - if (PyDict_SetItem(extensions_by_number, number, extension_handle) < 0) { - return NULL; - } - CFieldDescriptor* cdescriptor = - extension_dict::InternalGetCDescriptorFromExtension(extension_handle); - ScopedPyObjectPtr py_cdescriptor(reinterpret_cast<PyObject*>(cdescriptor)); - if (cdescriptor == NULL) { + // If the extension was already registered by number, check that it is the + // same. 
+ existing_extension = PyDict_GetItem(extensions_by_number.get(), number.get()); + if (existing_extension != NULL) { + const FieldDescriptor* existing_extension_descriptor = + GetExtensionDescriptor(existing_extension); + if (existing_extension_descriptor != descriptor) { + const Descriptor* msg_desc = GetMessageDescriptor( + reinterpret_cast<PyTypeObject*>(cls)); + PyErr_Format( + PyExc_ValueError, + "Extensions \"%s\" and \"%s\" both try to extend message type " + "\"%s\" with field number %ld.", + existing_extension_descriptor->full_name().c_str(), + descriptor->full_name().c_str(), + msg_desc->full_name().c_str(), + PyInt_AsLong(number.get())); + return NULL; + } + // Nothing else to do. + Py_RETURN_NONE; + } + if (PyDict_SetItem(extensions_by_number.get(), number.get(), + extension_handle) < 0) { return NULL; } - Py_INCREF(extension_handle); - cdescriptor->descriptor_field = extension_handle; - const google::protobuf::FieldDescriptor* descriptor = cdescriptor->descriptor; + // Check if it's a message set if (descriptor->is_extension() && descriptor->containing_type()->options().message_set_wire_format() && - descriptor->type() == google::protobuf::FieldDescriptor::TYPE_MESSAGE && - descriptor->message_type() == descriptor->extension_scope() && - descriptor->label() == google::protobuf::FieldDescriptor::LABEL_OPTIONAL) { + descriptor->type() == FieldDescriptor::TYPE_MESSAGE && + descriptor->label() == FieldDescriptor::LABEL_OPTIONAL) { ScopedPyObjectPtr message_name(PyString_FromStringAndSize( descriptor->message_type()->full_name().c_str(), descriptor->message_type()->full_name().size())); if (message_name == NULL) { return NULL; } - PyDict_SetItem(extensions_by_name, message_name, extension_handle); + PyDict_SetItem(extensions_by_name.get(), message_name.get(), + extension_handle); } Py_RETURN_NONE; @@ -1454,53 +2054,38 @@ static PyObject* SetInParent(CMessage* self, PyObject* args) { } static PyObject* WhichOneof(CMessage* self, PyObject* arg) { - char* 
oneof_name; - if (!PyString_Check(arg)) { - PyErr_SetString(PyExc_TypeError, "field name must be a string"); + Py_ssize_t name_size; + char *name_data; + if (PyString_AsStringAndSize(arg, &name_data, &name_size) < 0) return NULL; - } - oneof_name = PyString_AsString(arg); - if (oneof_name == NULL) { - return NULL; - } - const google::protobuf::OneofDescriptor* oneof_desc = + string oneof_name = string(name_data, name_size); + const OneofDescriptor* oneof_desc = self->message->GetDescriptor()->FindOneofByName(oneof_name); if (oneof_desc == NULL) { PyErr_Format(PyExc_ValueError, - "Protocol message has no oneof \"%s\" field.", oneof_name); + "Protocol message has no oneof \"%s\" field.", + oneof_name.c_str()); return NULL; } - const google::protobuf::FieldDescriptor* field_in_oneof = + const FieldDescriptor* field_in_oneof = self->message->GetReflection()->GetOneofFieldDescriptor( *self->message, oneof_desc); if (field_in_oneof == NULL) { Py_RETURN_NONE; } else { - return PyString_FromString(field_in_oneof->name().c_str()); + const string& name = field_in_oneof->name(); + return PyString_FromStringAndSize(name.c_str(), name.size()); } } +static PyObject* GetExtensionDict(CMessage* self, void *closure); + static PyObject* ListFields(CMessage* self) { - vector<const google::protobuf::FieldDescriptor*> fields; + vector<const FieldDescriptor*> fields; self->message->GetReflection()->ListFields(*self->message, &fields); - PyObject* descriptor = PyDict_GetItem(Py_TYPE(self)->tp_dict, kDESCRIPTOR); - if (descriptor == NULL) { - return NULL; - } - ScopedPyObjectPtr fields_by_name( - PyObject_GetAttr(descriptor, kfields_by_name)); - if (fields_by_name == NULL) { - return NULL; - } - ScopedPyObjectPtr extensions_by_name(PyObject_GetAttr( - reinterpret_cast<PyObject*>(Py_TYPE(self)), k_extensions_by_name)); - if (extensions_by_name == NULL) { - PyErr_SetString(PyExc_ValueError, "no extensionsbyname"); - return NULL; - } // Normally, the list will be exactly the size of the 
fields. - PyObject* all_fields = PyList_New(fields.size()); + ScopedPyObjectPtr all_fields(PyList_New(fields.size())); if (all_fields == NULL) { return NULL; } @@ -1509,75 +2094,84 @@ static PyObject* ListFields(CMessage* self) { // the field information. Thus the actual size of the py list will be // smaller than the size of fields. Set the actual size at the end. Py_ssize_t actual_size = 0; - for (Py_ssize_t i = 0; i < fields.size(); ++i) { + for (size_t i = 0; i < fields.size(); ++i) { ScopedPyObjectPtr t(PyTuple_New(2)); if (t == NULL) { - Py_DECREF(all_fields); return NULL; } if (fields[i]->is_extension()) { - const string& field_name = fields[i]->full_name(); - PyObject* extension_field = PyDict_GetItemString(extensions_by_name, - field_name.c_str()); + ScopedPyObjectPtr extension_field( + PyFieldDescriptor_FromDescriptor(fields[i])); if (extension_field == NULL) { - // If we couldn't fetch extension_field, it means the module that - // defines this extension has not been explicitly imported in Python - // code, and the extension hasn't been registered. There's nothing much - // we can do about this, so just skip it in the output to match the - // behavior of the python implementation. + return NULL; + } + // With C++ descriptors, the field can always be retrieved, but for + // unknown extensions which have not been imported in Python code, there + // is no message class and we cannot retrieve the value. + // TODO(amauryfa): consider building the class on the fly! + if (fields[i]->message_type() != NULL && + cdescriptor_pool::GetMessageClass( + GetDescriptorPoolForMessage(self), + fields[i]->message_type()) == NULL) { + PyErr_Clear(); continue; } - PyObject* extensions = reinterpret_cast<PyObject*>(self->extensions); + ScopedPyObjectPtr extensions(GetExtensionDict(self, NULL)); if (extensions == NULL) { - Py_DECREF(all_fields); return NULL; } // 'extension' reference later stolen by PyTuple_SET_ITEM. 
- PyObject* extension = PyObject_GetItem(extensions, extension_field); + PyObject* extension = PyObject_GetItem( + extensions.get(), extension_field.get()); if (extension == NULL) { - Py_DECREF(all_fields); return NULL; } - Py_INCREF(extension_field); - PyTuple_SET_ITEM(t.get(), 0, extension_field); + PyTuple_SET_ITEM(t.get(), 0, extension_field.release()); // Steals reference to 'extension' PyTuple_SET_ITEM(t.get(), 1, extension); } else { + // Normal field const string& field_name = fields[i]->name(); ScopedPyObjectPtr py_field_name(PyString_FromStringAndSize( field_name.c_str(), field_name.length())); if (py_field_name == NULL) { PyErr_SetString(PyExc_ValueError, "bad string"); - Py_DECREF(all_fields); return NULL; } - PyObject* field_descriptor = - PyDict_GetItem(fields_by_name, py_field_name); + ScopedPyObjectPtr field_descriptor( + PyFieldDescriptor_FromDescriptor(fields[i])); if (field_descriptor == NULL) { - Py_DECREF(all_fields); return NULL; } - PyObject* field_value = GetAttr(self, py_field_name); + PyObject* field_value = GetAttr(self, py_field_name.get()); if (field_value == NULL) { - PyErr_SetObject(PyExc_ValueError, py_field_name); - Py_DECREF(all_fields); + PyErr_SetObject(PyExc_ValueError, py_field_name.get()); return NULL; } - Py_INCREF(field_descriptor); - PyTuple_SET_ITEM(t.get(), 0, field_descriptor); + PyTuple_SET_ITEM(t.get(), 0, field_descriptor.release()); PyTuple_SET_ITEM(t.get(), 1, field_value); } - PyList_SET_ITEM(all_fields, actual_size, t.release()); + PyList_SET_ITEM(all_fields.get(), actual_size, t.release()); ++actual_size; } - Py_SIZE(all_fields) = actual_size; - return all_fields; + if (static_cast<size_t>(actual_size) != fields.size() && + (PyList_SetSlice(all_fields.get(), actual_size, fields.size(), NULL) < + 0)) { + return NULL; + } + return all_fields.release(); +} + +static PyObject* DiscardUnknownFields(CMessage* self) { + AssureWritable(self); + self->message->DiscardUnknownFields(); + Py_RETURN_NONE; } PyObject* 
FindInitializationErrors(CMessage* self) { - google::protobuf::Message* message = self->message; + Message* message = self->message; vector<string> errors; message->FindInitializationErrors(&errors); @@ -1585,7 +2179,7 @@ PyObject* FindInitializationErrors(CMessage* self) { if (error_list == NULL) { return NULL; } - for (Py_ssize_t i = 0; i < errors.size(); ++i) { + for (size_t i = 0; i < errors.size(); ++i) { const string& error = errors[i]; PyObject* error_string = PyString_FromStringAndSize( error.c_str(), error.length()); @@ -1599,94 +2193,105 @@ PyObject* FindInitializationErrors(CMessage* self) { } static PyObject* RichCompare(CMessage* self, PyObject* other, int opid) { - if (!PyObject_TypeCheck(other, &CMessage_Type)) { - if (opid == Py_EQ) { - Py_RETURN_FALSE; - } else if (opid == Py_NE) { - Py_RETURN_TRUE; - } - } - if (opid == Py_EQ || opid == Py_NE) { - ScopedPyObjectPtr self_fields(ListFields(self)); - ScopedPyObjectPtr other_fields(ListFields( - reinterpret_cast<CMessage*>(other))); - return PyObject_RichCompare(self_fields, other_fields, opid); - } else { + // Only equality comparisons are implemented. + if (opid != Py_EQ && opid != Py_NE) { Py_INCREF(Py_NotImplemented); return Py_NotImplemented; } + bool equals = true; + // If other is not a message, it cannot be equal. + if (!PyObject_TypeCheck(other, &CMessage_Type)) { + equals = false; + } + const google::protobuf::Message* other_message = + reinterpret_cast<CMessage*>(other)->message; + // If messages don't have the same descriptors, they are not equal. + if (equals && + self->message->GetDescriptor() != other_message->GetDescriptor()) { + equals = false; + } + // Check the message contents. 
+ if (equals && !google::protobuf::util::MessageDifferencer::Equals( + *self->message, + *reinterpret_cast<CMessage*>(other)->message)) { + equals = false; + } + if (equals ^ (opid == Py_EQ)) { + Py_RETURN_FALSE; + } else { + Py_RETURN_TRUE; + } } -PyObject* InternalGetScalar( - CMessage* self, - const google::protobuf::FieldDescriptor* field_descriptor) { - google::protobuf::Message* message = self->message; - const google::protobuf::Reflection* reflection = message->GetReflection(); +PyObject* InternalGetScalar(const Message* message, + const FieldDescriptor* field_descriptor) { + const Reflection* reflection = message->GetReflection(); - if (!FIELD_BELONGS_TO_MESSAGE(field_descriptor, message)) { - PyErr_SetString( - PyExc_KeyError, "Field does not belong to message!"); + if (!CheckFieldBelongsToMessage(field_descriptor, message)) { return NULL; } PyObject* result = NULL; switch (field_descriptor->cpp_type()) { - case google::protobuf::FieldDescriptor::CPPTYPE_INT32: { + case FieldDescriptor::CPPTYPE_INT32: { int32 value = reflection->GetInt32(*message, field_descriptor); result = PyInt_FromLong(value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_INT64: { + case FieldDescriptor::CPPTYPE_INT64: { int64 value = reflection->GetInt64(*message, field_descriptor); result = PyLong_FromLongLong(value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: { + case FieldDescriptor::CPPTYPE_UINT32: { uint32 value = reflection->GetUInt32(*message, field_descriptor); result = PyInt_FromSize_t(value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: { + case FieldDescriptor::CPPTYPE_UINT64: { uint64 value = reflection->GetUInt64(*message, field_descriptor); result = PyLong_FromUnsignedLongLong(value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_FLOAT: { + case FieldDescriptor::CPPTYPE_FLOAT: { float value = reflection->GetFloat(*message, field_descriptor); result = PyFloat_FromDouble(value); break; } - case 
google::protobuf::FieldDescriptor::CPPTYPE_DOUBLE: { + case FieldDescriptor::CPPTYPE_DOUBLE: { double value = reflection->GetDouble(*message, field_descriptor); result = PyFloat_FromDouble(value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: { + case FieldDescriptor::CPPTYPE_BOOL: { bool value = reflection->GetBool(*message, field_descriptor); result = PyBool_FromLong(value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_STRING: { + case FieldDescriptor::CPPTYPE_STRING: { string value = reflection->GetString(*message, field_descriptor); result = ToStringObject(field_descriptor, value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: { - if (!message->GetReflection()->HasField(*message, field_descriptor)) { + case FieldDescriptor::CPPTYPE_ENUM: { + if (!message->GetReflection()->SupportsUnknownEnumValues() && + !message->GetReflection()->HasField(*message, field_descriptor)) { // Look for the value in the unknown fields. - google::protobuf::UnknownFieldSet* unknown_field_set = - message->GetReflection()->MutableUnknownFields(message); - for (int i = 0; i < unknown_field_set->field_count(); ++i) { - if (unknown_field_set->field(i).number() == - field_descriptor->number()) { - result = PyInt_FromLong(unknown_field_set->field(i).varint()); + const UnknownFieldSet& unknown_field_set = + message->GetReflection()->GetUnknownFields(*message); + for (int i = 0; i < unknown_field_set.field_count(); ++i) { + if (unknown_field_set.field(i).number() == + field_descriptor->number() && + unknown_field_set.field(i).type() == + google::protobuf::UnknownField::TYPE_VARINT) { + result = PyInt_FromLong(unknown_field_set.field(i).varint()); break; } } } if (result == NULL) { - const google::protobuf::EnumValueDescriptor* enum_value = + const EnumValueDescriptor* enum_value = message->GetReflection()->GetEnum(*message, field_descriptor); result = PyInt_FromLong(enum_value->number()); } @@ -1701,116 +2306,100 @@ PyObject* 
InternalGetScalar( return result; } -PyObject* InternalGetSubMessage(CMessage* self, - CFieldDescriptor* cfield_descriptor) { - PyObject* field = cfield_descriptor->descriptor_field; - ScopedPyObjectPtr message_type(PyObject_GetAttr(field, kmessage_type)); - if (message_type == NULL) { - return NULL; - } - ScopedPyObjectPtr concrete_class( - PyObject_GetAttr(message_type, k_concrete_class)); - if (concrete_class == NULL) { +PyObject* InternalGetSubMessage( + CMessage* self, const FieldDescriptor* field_descriptor) { + const Reflection* reflection = self->message->GetReflection(); + PyDescriptorPool* pool = GetDescriptorPoolForMessage(self); + const Message& sub_message = reflection->GetMessage( + *self->message, field_descriptor, pool->message_factory); + + CMessageClass* message_class = cdescriptor_pool::GetMessageClass( + pool, field_descriptor->message_type()); + if (message_class == NULL) { return NULL; } - PyObject* py_cmsg = cmessage::NewEmpty(concrete_class); - if (py_cmsg == NULL) { + + CMessage* cmsg = cmessage::NewEmptyMessage(message_class); + if (cmsg == NULL) { return NULL; } - if (!PyObject_TypeCheck(py_cmsg, &CMessage_Type)) { - PyErr_SetString(PyExc_TypeError, "Not a CMessage!"); - } - CMessage* cmsg = reinterpret_cast<CMessage*>(py_cmsg); - const google::protobuf::FieldDescriptor* field_descriptor = - cfield_descriptor->descriptor; - const google::protobuf::Reflection* reflection = self->message->GetReflection(); - const google::protobuf::Message& sub_message = reflection->GetMessage( - *self->message, field_descriptor, global_message_factory); cmsg->owner = self->owner; cmsg->parent = self; - cmsg->parent_field = cfield_descriptor; + cmsg->parent_field_descriptor = field_descriptor; cmsg->read_only = !reflection->HasField(*self->message, field_descriptor); - cmsg->message = const_cast<google::protobuf::Message*>(&sub_message); + cmsg->message = const_cast<Message*>(&sub_message); - if (InitAttributes(cmsg, NULL, NULL) < 0) { - Py_DECREF(py_cmsg); 
- return NULL; - } - return py_cmsg; + return reinterpret_cast<PyObject*>(cmsg); } -int InternalSetScalar( - CMessage* self, - const google::protobuf::FieldDescriptor* field_descriptor, +int InternalSetNonOneofScalar( + Message* message, + const FieldDescriptor* field_descriptor, PyObject* arg) { - google::protobuf::Message* message = self->message; - const google::protobuf::Reflection* reflection = message->GetReflection(); - - if (!FIELD_BELONGS_TO_MESSAGE(field_descriptor, message)) { - PyErr_SetString( - PyExc_KeyError, "Field does not belong to message!"); - return -1; - } + const Reflection* reflection = message->GetReflection(); - if (MaybeReleaseOverlappingOneofField(self, field_descriptor) < 0) { + if (!CheckFieldBelongsToMessage(field_descriptor, message)) { return -1; } switch (field_descriptor->cpp_type()) { - case google::protobuf::FieldDescriptor::CPPTYPE_INT32: { + case FieldDescriptor::CPPTYPE_INT32: { GOOGLE_CHECK_GET_INT32(arg, value, -1); reflection->SetInt32(message, field_descriptor, value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_INT64: { + case FieldDescriptor::CPPTYPE_INT64: { GOOGLE_CHECK_GET_INT64(arg, value, -1); reflection->SetInt64(message, field_descriptor, value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: { + case FieldDescriptor::CPPTYPE_UINT32: { GOOGLE_CHECK_GET_UINT32(arg, value, -1); reflection->SetUInt32(message, field_descriptor, value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: { + case FieldDescriptor::CPPTYPE_UINT64: { GOOGLE_CHECK_GET_UINT64(arg, value, -1); reflection->SetUInt64(message, field_descriptor, value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_FLOAT: { + case FieldDescriptor::CPPTYPE_FLOAT: { GOOGLE_CHECK_GET_FLOAT(arg, value, -1); reflection->SetFloat(message, field_descriptor, value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_DOUBLE: { + case FieldDescriptor::CPPTYPE_DOUBLE: { GOOGLE_CHECK_GET_DOUBLE(arg, 
value, -1); reflection->SetDouble(message, field_descriptor, value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: { + case FieldDescriptor::CPPTYPE_BOOL: { GOOGLE_CHECK_GET_BOOL(arg, value, -1); reflection->SetBool(message, field_descriptor, value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_STRING: { + case FieldDescriptor::CPPTYPE_STRING: { if (!CheckAndSetString( arg, message, field_descriptor, reflection, false, -1)) { return -1; } break; } - case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: { + case FieldDescriptor::CPPTYPE_ENUM: { GOOGLE_CHECK_GET_INT32(arg, value, -1); - const google::protobuf::EnumDescriptor* enum_descriptor = - field_descriptor->enum_type(); - const google::protobuf::EnumValueDescriptor* enum_value = - enum_descriptor->FindValueByNumber(value); - if (enum_value != NULL) { - reflection->SetEnum(message, field_descriptor, enum_value); + if (reflection->SupportsUnknownEnumValues()) { + reflection->SetEnumValue(message, field_descriptor, value); } else { - PyErr_Format(PyExc_ValueError, "Unknown enum value: %d", value); - return -1; + const EnumDescriptor* enum_descriptor = field_descriptor->enum_type(); + const EnumValueDescriptor* enum_value = + enum_descriptor->FindValueByNumber(value); + if (enum_value != NULL) { + reflection->SetEnum(message, field_descriptor, enum_value); + } else { + PyErr_Format(PyExc_ValueError, "Unknown enum value: %d", value); + return -1; + } } break; } @@ -1824,6 +2413,21 @@ int InternalSetScalar( return 0; } +int InternalSetScalar( + CMessage* self, + const FieldDescriptor* field_descriptor, + PyObject* arg) { + if (!CheckFieldBelongsToMessage(field_descriptor, self->message)) { + return -1; + } + + if (MaybeReleaseOverlappingOneofField(self, field_descriptor) < 0) { + return -1; + } + + return InternalSetNonOneofScalar(self->message, field_descriptor, arg); +} + PyObject* FromString(PyTypeObject* cls, PyObject* serialized) { PyObject* py_cmsg = PyObject_CallObject( 
reinterpret_cast<PyObject*>(cls), NULL); @@ -1838,197 +2442,9 @@ PyObject* FromString(PyTypeObject* cls, PyObject* serialized) { return NULL; } - if (InitAttributes(cmsg, NULL, NULL) < 0) { - Py_DECREF(py_cmsg); - return NULL; - } return py_cmsg; } -static PyObject* AddDescriptors(PyTypeObject* cls, - PyObject* descriptor) { - if (PyObject_SetAttr(reinterpret_cast<PyObject*>(cls), - k_extensions_by_name, PyDict_New()) < 0) { - return NULL; - } - if (PyObject_SetAttr(reinterpret_cast<PyObject*>(cls), - k_extensions_by_number, PyDict_New()) < 0) { - return NULL; - } - - ScopedPyObjectPtr field_descriptors(PyDict_New()); - - ScopedPyObjectPtr fields(PyObject_GetAttrString(descriptor, "fields")); - if (fields == NULL) { - return NULL; - } - - ScopedPyObjectPtr _NUMBER_string(PyString_FromString("_FIELD_NUMBER")); - if (_NUMBER_string == NULL) { - return NULL; - } - - const Py_ssize_t fields_size = PyList_GET_SIZE(fields.get()); - for (int i = 0; i < fields_size; ++i) { - PyObject* field = PyList_GET_ITEM(fields.get(), i); - ScopedPyObjectPtr field_name(PyObject_GetAttr(field, kname)); - ScopedPyObjectPtr full_field_name(PyObject_GetAttr(field, kfull_name)); - if (field_name == NULL || full_field_name == NULL) { - PyErr_SetString(PyExc_TypeError, "Name is null"); - return NULL; - } - - PyObject* field_descriptor = - cdescriptor_pool::FindFieldByName(descriptor_pool, full_field_name); - if (field_descriptor == NULL) { - PyErr_SetString(PyExc_TypeError, "Couldn't find field"); - return NULL; - } - Py_INCREF(field); - CFieldDescriptor* cfield_descriptor = reinterpret_cast<CFieldDescriptor*>( - field_descriptor); - cfield_descriptor->descriptor_field = field; - if (PyDict_SetItem(field_descriptors, field_name, field_descriptor) < 0) { - return NULL; - } - - // The FieldDescriptor's name field might either be of type bytes or - // of type unicode, depending on whether the FieldDescriptor was - // parsed from a serialized message or read from the - // <message>_pb2.py module. 
- ScopedPyObjectPtr field_name_upcased( - PyObject_CallMethod(field_name, "upper", NULL)); - if (field_name_upcased == NULL) { - return NULL; - } - - ScopedPyObjectPtr field_number_name(PyObject_CallMethod( - field_name_upcased, "__add__", "(O)", _NUMBER_string.get())); - if (field_number_name == NULL) { - return NULL; - } - - ScopedPyObjectPtr number(PyInt_FromLong( - cfield_descriptor->descriptor->number())); - if (number == NULL) { - return NULL; - } - if (PyObject_SetAttr(reinterpret_cast<PyObject*>(cls), - field_number_name, number) == -1) { - return NULL; - } - } - - PyDict_SetItem(cls->tp_dict, k__descriptors, field_descriptors); - - // Enum Values - ScopedPyObjectPtr enum_types(PyObject_GetAttrString(descriptor, - "enum_types")); - if (enum_types == NULL) { - return NULL; - } - ScopedPyObjectPtr type_iter(PyObject_GetIter(enum_types)); - if (type_iter == NULL) { - return NULL; - } - ScopedPyObjectPtr enum_type; - while ((enum_type.reset(PyIter_Next(type_iter))) != NULL) { - ScopedPyObjectPtr wrapped(PyObject_CallFunctionObjArgs( - EnumTypeWrapper_class, enum_type.get(), NULL)); - if (wrapped == NULL) { - return NULL; - } - ScopedPyObjectPtr enum_name(PyObject_GetAttr(enum_type, kname)); - if (enum_name == NULL) { - return NULL; - } - if (PyObject_SetAttr(reinterpret_cast<PyObject*>(cls), - enum_name, wrapped) == -1) { - return NULL; - } - - ScopedPyObjectPtr enum_values(PyObject_GetAttrString(enum_type, "values")); - if (enum_values == NULL) { - return NULL; - } - ScopedPyObjectPtr values_iter(PyObject_GetIter(enum_values)); - if (values_iter == NULL) { - return NULL; - } - ScopedPyObjectPtr enum_value; - while ((enum_value.reset(PyIter_Next(values_iter))) != NULL) { - ScopedPyObjectPtr value_name(PyObject_GetAttr(enum_value, kname)); - if (value_name == NULL) { - return NULL; - } - ScopedPyObjectPtr value_number(PyObject_GetAttrString(enum_value, - "number")); - if (value_number == NULL) { - return NULL; - } - if 
(PyObject_SetAttr(reinterpret_cast<PyObject*>(cls), - value_name, value_number) == -1) { - return NULL; - } - } - if (PyErr_Occurred()) { // If PyIter_Next failed - return NULL; - } - } - if (PyErr_Occurred()) { // If PyIter_Next failed - return NULL; - } - - ScopedPyObjectPtr extension_dict( - PyObject_GetAttr(descriptor, kextensions_by_name)); - if (extension_dict == NULL || !PyDict_Check(extension_dict)) { - PyErr_SetString(PyExc_TypeError, "extensions_by_name not a dict"); - return NULL; - } - Py_ssize_t pos = 0; - PyObject* extension_name; - PyObject* extension_field; - - while (PyDict_Next(extension_dict, &pos, &extension_name, &extension_field)) { - if (PyObject_SetAttr(reinterpret_cast<PyObject*>(cls), - extension_name, extension_field) == -1) { - return NULL; - } - ScopedPyObjectPtr py_cfield_descriptor( - PyObject_GetAttrString(extension_field, "_cdescriptor")); - if (py_cfield_descriptor == NULL) { - return NULL; - } - CFieldDescriptor* cfield_descriptor = - reinterpret_cast<CFieldDescriptor*>(py_cfield_descriptor.get()); - Py_INCREF(extension_field); - cfield_descriptor->descriptor_field = extension_field; - - ScopedPyObjectPtr field_name_upcased( - PyObject_CallMethod(extension_name, "upper", NULL)); - if (field_name_upcased == NULL) { - return NULL; - } - ScopedPyObjectPtr field_number_name(PyObject_CallMethod( - field_name_upcased, "__add__", "(O)", _NUMBER_string.get())); - if (field_number_name == NULL) { - return NULL; - } - ScopedPyObjectPtr number(PyInt_FromLong( - cfield_descriptor->descriptor->number())); - if (number == NULL) { - return NULL; - } - if (PyObject_SetAttr(reinterpret_cast<PyObject*>(cls), - field_number_name, PyInt_FromLong( - cfield_descriptor->descriptor->number())) == -1) { - return NULL; - } - } - - Py_RETURN_NONE; -} - PyObject* DeepCopy(CMessage* self, PyObject* arg) { PyObject* clone = PyObject_CallObject( reinterpret_cast<PyObject*>(Py_TYPE(self)), NULL); @@ -2039,12 +2455,9 @@ PyObject* DeepCopy(CMessage* self, 
PyObject* arg) { Py_DECREF(clone); return NULL; } - if (InitAttributes(reinterpret_cast<CMessage*>(clone), NULL, NULL) < 0) { - Py_DECREF(clone); - return NULL; - } - if (MergeFrom(reinterpret_cast<CMessage*>(clone), - reinterpret_cast<PyObject*>(self)) == NULL) { + if (ScopedPyObjectPtr(MergeFrom( + reinterpret_cast<CMessage*>(clone), + reinterpret_cast<PyObject*>(self))) == NULL) { Py_DECREF(clone); return NULL; } @@ -2063,16 +2476,16 @@ PyObject* ToUnicode(CMessage* self) { return NULL; } Py_INCREF(Py_True); - ScopedPyObjectPtr encoded(PyObject_CallMethodObjArgs(text_format, method_name, - self, Py_True, NULL)); + ScopedPyObjectPtr encoded(PyObject_CallMethodObjArgs( + text_format.get(), method_name.get(), self, Py_True, NULL)); Py_DECREF(Py_True); if (encoded == NULL) { return NULL; } #if PY_MAJOR_VERSION < 3 - PyObject* decoded = PyString_AsDecodedObject(encoded, "utf-8", NULL); + PyObject* decoded = PyString_AsDecodedObject(encoded.get(), "utf-8", NULL); #else - PyObject* decoded = PyUnicode_FromEncodedObject(encoded, "utf-8", NULL); + PyObject* decoded = PyUnicode_FromEncodedObject(encoded.get(), "utf-8", NULL); #endif if (decoded == NULL) { return NULL; @@ -2095,7 +2508,7 @@ PyObject* Reduce(CMessage* self) { if (serialized == NULL) { return NULL; } - if (PyDict_SetItemString(state, "serialized", serialized) < 0) { + if (PyDict_SetItemString(state.get(), "serialized", serialized.get()) < 0) { return NULL; } return Py_BuildValue("OOO", constructor.get(), args.get(), state.get()); @@ -2110,24 +2523,49 @@ PyObject* SetState(CMessage* self, PyObject* state) { if (serialized == NULL) { return NULL; } - if (ParseFromString(self, serialized) == NULL) { + if (ScopedPyObjectPtr(ParseFromString(self, serialized)) == NULL) { return NULL; } Py_RETURN_NONE; } // CMessage static methods: -PyObject* _GetFieldDescriptor(PyObject* unused, PyObject* arg) { - return cdescriptor_pool::FindFieldByName(descriptor_pool, arg); +PyObject* _CheckCalledFromGeneratedFile(PyObject* 
unused, + PyObject* unused_arg) { + if (!_CalledFromGeneratedFile(1)) { + PyErr_SetString(PyExc_TypeError, + "Descriptors should not be created directly, " + "but only retrieved from their parent."); + return NULL; + } + Py_RETURN_NONE; } -PyObject* _GetExtensionDescriptor(PyObject* unused, PyObject* arg) { - return cdescriptor_pool::FindExtensionByName(descriptor_pool, arg); +static PyObject* GetExtensionDict(CMessage* self, void *closure) { + if (self->extensions) { + Py_INCREF(self->extensions); + return reinterpret_cast<PyObject*>(self->extensions); + } + + // If there are extension_ranges, the message is "extendable". Allocate a + // dictionary to store the extension fields. + const Descriptor* descriptor = GetMessageDescriptor(Py_TYPE(self)); + if (descriptor->extension_range_count() > 0) { + ExtensionDict* extension_dict = extension_dict::NewExtensionDict(self); + if (extension_dict == NULL) { + return NULL; + } + self->extensions = extension_dict; + Py_INCREF(self->extensions); + return reinterpret_cast<PyObject*>(self->extensions); + } + + PyErr_SetNone(PyExc_AttributeError); + return NULL; } -static PyMemberDef Members[] = { - {"Extensions", T_OBJECT_EX, offsetof(CMessage, extensions), 0, - "Extension dict"}, +static PyGetSetDef Getters[] = { + {"Extensions", (getter)GetExtensionDict, NULL, "Extension dict"}, {NULL} }; @@ -2140,8 +2578,6 @@ static PyMethodDef Methods[] = { "Inputs picklable representation of the message." }, { "__unicode__", (PyCFunction)ToUnicode, METH_NOARGS, "Outputs a unicode representation of the message." }, - { "AddDescriptors", (PyCFunction)AddDescriptors, METH_O | METH_CLASS, - "Adds field descriptors to the class" }, { "ByteSize", (PyCFunction)ByteSize, METH_NOARGS, "Returns the size of the message in bytes." }, { "Clear", (PyCFunction)Clear, METH_NOARGS, @@ -2152,6 +2588,8 @@ static PyMethodDef Methods[] = { "Clears a message field." 
}, { "CopyFrom", (PyCFunction)CopyFrom, METH_O, "Copies a protocol message into the current message." }, + { "DiscardUnknownFields", (PyCFunction)DiscardUnknownFields, METH_NOARGS, + "Discards the unknown fields." }, { "FindInitializationErrors", (PyCFunction)FindInitializationErrors, METH_NOARGS, "Finds unset required fields." }, @@ -2185,113 +2623,117 @@ static PyMethodDef Methods[] = { "or None if no field is set." }, // Static Methods. - { "_BuildFile", (PyCFunction)Python_BuildFile, METH_O | METH_STATIC, - "Registers a new protocol buffer file in the global C++ descriptor pool." }, - { "_GetFieldDescriptor", (PyCFunction)_GetFieldDescriptor, - METH_O | METH_STATIC, "Finds a field descriptor in the message pool." }, - { "_GetExtensionDescriptor", (PyCFunction)_GetExtensionDescriptor, - METH_O | METH_STATIC, - "Finds a extension descriptor in the message pool." }, + { "_CheckCalledFromGeneratedFile", (PyCFunction)_CheckCalledFromGeneratedFile, + METH_NOARGS | METH_STATIC, + "Raises TypeError if the caller is not in a _pb2.py file."}, { NULL, NULL} }; +static bool SetCompositeField( + CMessage* self, PyObject* name, PyObject* value) { + if (self->composite_fields == NULL) { + self->composite_fields = PyDict_New(); + if (self->composite_fields == NULL) { + return false; + } + } + return PyDict_SetItem(self->composite_fields, name, value) == 0; +} + PyObject* GetAttr(CMessage* self, PyObject* name) { - PyObject* value = PyDict_GetItem(self->composite_fields, name); + PyObject* value = self->composite_fields ? 
+ PyDict_GetItem(self->composite_fields, name) : NULL; if (value != NULL) { Py_INCREF(value); return value; } - PyObject* descriptor = GetDescriptor(self, name); - if (descriptor != NULL) { - CFieldDescriptor* cdescriptor = - reinterpret_cast<CFieldDescriptor*>(descriptor); - const google::protobuf::FieldDescriptor* field_descriptor = cdescriptor->descriptor; - if (field_descriptor->label() == google::protobuf::FieldDescriptor::LABEL_REPEATED) { - if (field_descriptor->cpp_type() == - google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { - PyObject* py_container = PyObject_CallObject( - reinterpret_cast<PyObject*>(&RepeatedCompositeContainer_Type), - NULL); - if (py_container == NULL) { - return NULL; - } - RepeatedCompositeContainer* container = - reinterpret_cast<RepeatedCompositeContainer*>(py_container); - PyObject* field = cdescriptor->descriptor_field; - PyObject* message_type = PyObject_GetAttr(field, kmessage_type); - if (message_type == NULL) { - return NULL; - } - PyObject* concrete_class = - PyObject_GetAttr(message_type, k_concrete_class); - if (concrete_class == NULL) { - return NULL; - } - container->parent = self; - container->parent_field = cdescriptor; - container->message = self->message; - container->owner = self->owner; - container->subclass_init = concrete_class; - Py_DECREF(message_type); - if (PyDict_SetItem(self->composite_fields, name, py_container) < 0) { - Py_DECREF(py_container); - return NULL; - } - return py_container; - } else { - ScopedPyObjectPtr init_args(PyTuple_Pack(2, self, cdescriptor)); - PyObject* py_container = PyObject_CallObject( - reinterpret_cast<PyObject*>(&RepeatedScalarContainer_Type), - init_args); - if (py_container == NULL) { - return NULL; - } - if (PyDict_SetItem(self->composite_fields, name, py_container) < 0) { - Py_DECREF(py_container); - return NULL; - } - return py_container; + const FieldDescriptor* field_descriptor = GetFieldDescriptor(self, name); + if (field_descriptor == NULL) { + return 
CMessage_Type.tp_base->tp_getattro( + reinterpret_cast<PyObject*>(self), name); + } + + if (field_descriptor->is_map()) { + PyObject* py_container = NULL; + const Descriptor* entry_type = field_descriptor->message_type(); + const FieldDescriptor* value_type = entry_type->FindFieldByName("value"); + if (value_type->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { + CMessageClass* value_class = cdescriptor_pool::GetMessageClass( + GetDescriptorPoolForMessage(self), value_type->message_type()); + if (value_class == NULL) { + return NULL; } + py_container = + NewMessageMapContainer(self, field_descriptor, value_class); } else { - if (field_descriptor->cpp_type() == - google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { - PyObject* sub_message = InternalGetSubMessage(self, cdescriptor); - if (PyDict_SetItem(self->composite_fields, name, sub_message) < 0) { - Py_DECREF(sub_message); - return NULL; - } - return sub_message; - } else { - return InternalGetScalar(self, field_descriptor); + py_container = NewScalarMapContainer(self, field_descriptor); + } + if (py_container == NULL) { + return NULL; + } + if (!SetCompositeField(self, name, py_container)) { + Py_DECREF(py_container); + return NULL; + } + return py_container; + } + + if (field_descriptor->label() == FieldDescriptor::LABEL_REPEATED) { + PyObject* py_container = NULL; + if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { + CMessageClass* message_class = cdescriptor_pool::GetMessageClass( + GetDescriptorPoolForMessage(self), field_descriptor->message_type()); + if (message_class == NULL) { + return NULL; } + py_container = repeated_composite_container::NewContainer( + self, field_descriptor, message_class); + } else { + py_container = repeated_scalar_container::NewContainer( + self, field_descriptor); + } + if (py_container == NULL) { + return NULL; + } + if (!SetCompositeField(self, name, py_container)) { + Py_DECREF(py_container); + return NULL; + } + return py_container; + } + + if 
(field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { + PyObject* sub_message = InternalGetSubMessage(self, field_descriptor); + if (sub_message == NULL) { + return NULL; + } + if (!SetCompositeField(self, name, sub_message)) { + Py_DECREF(sub_message); + return NULL; } + return sub_message; } - return CMessage_Type.tp_base->tp_getattro(reinterpret_cast<PyObject*>(self), - name); + return InternalGetScalar(self->message, field_descriptor); } int SetAttr(CMessage* self, PyObject* name, PyObject* value) { - if (PyDict_Contains(self->composite_fields, name)) { + if (self->composite_fields && PyDict_Contains(self->composite_fields, name)) { PyErr_SetString(PyExc_TypeError, "Can't set composite field"); return -1; } - PyObject* descriptor = GetDescriptor(self, name); - if (descriptor != NULL) { + const FieldDescriptor* field_descriptor = GetFieldDescriptor(self, name); + if (field_descriptor != NULL) { AssureWritable(self); - CFieldDescriptor* cdescriptor = - reinterpret_cast<CFieldDescriptor*>(descriptor); - const google::protobuf::FieldDescriptor* field_descriptor = cdescriptor->descriptor; - if (field_descriptor->label() == google::protobuf::FieldDescriptor::LABEL_REPEATED) { + if (field_descriptor->label() == FieldDescriptor::LABEL_REPEATED) { PyErr_Format(PyExc_AttributeError, "Assignment not allowed to repeated " "field \"%s\" in protocol message object.", field_descriptor->name().c_str()); return -1; } else { - if (field_descriptor->cpp_type() == - google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { + if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { PyErr_Format(PyExc_AttributeError, "Assignment not allowed to " "field \"%s\" in protocol message object.", field_descriptor->name().c_str()); @@ -2302,16 +2744,18 @@ int SetAttr(CMessage* self, PyObject* name, PyObject* value) { } } - PyErr_Format(PyExc_AttributeError, "Assignment not allowed"); + PyErr_Format(PyExc_AttributeError, + "Assignment not allowed " + "(no field 
\"%s\" in protocol message object).", + PyString_AsString(name)); return -1; } } // namespace cmessage PyTypeObject CMessage_Type = { - PyVarObject_HEAD_INIT(&PyType_Type, 0) - "google.protobuf.internal." - "cpp._message.CMessage", // tp_name + PyVarObject_HEAD_INIT(&CMessageClass_Type, 0) + FULL_MODULE_NAME ".CMessage", // tp_name sizeof(CMessage), // tp_basicsize 0, // tp_itemsize (destructor)cmessage::Dealloc, // tp_dealloc @@ -2319,11 +2763,11 @@ PyTypeObject CMessage_Type = { 0, // tp_getattr 0, // tp_setattr 0, // tp_compare - 0, // tp_repr + (reprfunc)cmessage::ToStr, // tp_repr 0, // tp_as_number 0, // tp_as_sequence 0, // tp_as_mapping - 0, // tp_hash + PyObject_HashNotImplemented, // tp_hash 0, // tp_call (reprfunc)cmessage::ToStr, // tp_str (getattrofunc)cmessage::GetAttr, // tp_getattro @@ -2338,8 +2782,8 @@ PyTypeObject CMessage_Type = { 0, // tp_iter 0, // tp_iternext cmessage::Methods, // tp_methods - cmessage::Members, // tp_members - 0, // tp_getset + 0, // tp_members + cmessage::Getters, // tp_getset 0, // tp_base 0, // tp_dict 0, // tp_descr_get @@ -2355,7 +2799,7 @@ PyTypeObject CMessage_Type = { const Message* (*GetCProtoInsidePyProtoPtr)(PyObject* msg); Message* (*MutableCProtoInsidePyProtoPtr)(PyObject* msg); -static const google::protobuf::Message* GetCProtoInsidePyProtoImpl(PyObject* msg) { +static const Message* GetCProtoInsidePyProtoImpl(PyObject* msg) { if (!PyObject_TypeCheck(msg, &CMessage_Type)) { return NULL; } @@ -2363,12 +2807,12 @@ static const google::protobuf::Message* GetCProtoInsidePyProtoImpl(PyObject* msg return cmsg->message; } -static google::protobuf::Message* MutableCProtoInsidePyProtoImpl(PyObject* msg) { +static Message* MutableCProtoInsidePyProtoImpl(PyObject* msg) { if (!PyObject_TypeCheck(msg, &CMessage_Type)) { return NULL; } CMessage* cmsg = reinterpret_cast<CMessage*>(msg); - if (PyDict_Size(cmsg->composite_fields) != 0 || + if ((cmsg->composite_fields && PyDict_Size(cmsg->composite_fields) != 0) || 
(cmsg->extensions != NULL && PyDict_Size(cmsg->extensions->values) != 0)) { // There is currently no way of accurately syncing arbitrary changes to @@ -2401,87 +2845,212 @@ void InitGlobals() { kuint64max_py = PyLong_FromUnsignedLongLong(kuint64max); kDESCRIPTOR = PyString_FromString("DESCRIPTOR"); - k__descriptors = PyString_FromString("__descriptors"); + k_cdescriptor = PyString_FromString("_cdescriptor"); kfull_name = PyString_FromString("full_name"); - kis_extendable = PyString_FromString("is_extendable"); - kextensions_by_name = PyString_FromString("extensions_by_name"); k_extensions_by_name = PyString_FromString("_extensions_by_name"); k_extensions_by_number = PyString_FromString("_extensions_by_number"); - k_concrete_class = PyString_FromString("_concrete_class"); - kmessage_type = PyString_FromString("message_type"); - kname = PyString_FromString("name"); - kfields_by_name = PyString_FromString("fields_by_name"); - - global_message_factory = new DynamicMessageFactory(GetDescriptorPool()); - global_message_factory->SetDelegateToGeneratedFactory(true); - descriptor_pool = reinterpret_cast<google::protobuf::python::CDescriptorPool*>( - Python_NewCDescriptorPool(NULL, NULL)); + PyObject *dummy_obj = PySet_New(NULL); + kEmptyWeakref = PyWeakref_NewRef(dummy_obj, NULL); + Py_DECREF(dummy_obj); } bool InitProto2MessageModule(PyObject *m) { - InitGlobals(); - - google::protobuf::python::CMessage_Type.tp_hash = PyObject_HashNotImplemented; - if (PyType_Ready(&google::protobuf::python::CMessage_Type) < 0) { + // Initialize types and globals in descriptor.cc + if (!InitDescriptor()) { return false; } - // All three of these are actually set elsewhere, directly onto the child - // protocol buffer message class, but set them here as well to document that - // subclasses need to set these. 
- PyDict_SetItem(google::protobuf::python::CMessage_Type.tp_dict, kDESCRIPTOR, Py_None); - PyDict_SetItem(google::protobuf::python::CMessage_Type.tp_dict, - k_extensions_by_name, Py_None); - PyDict_SetItem(google::protobuf::python::CMessage_Type.tp_dict, - k_extensions_by_number, Py_None); + // Initialize types and globals in descriptor_pool.cc + if (!InitDescriptorPool()) { + return false; + } - PyModule_AddObject(m, "Message", reinterpret_cast<PyObject*>( - &google::protobuf::python::CMessage_Type)); + // Initialize constants defined in this file. + InitGlobals(); - google::protobuf::python::RepeatedScalarContainer_Type.tp_new = PyType_GenericNew; - google::protobuf::python::RepeatedScalarContainer_Type.tp_hash = - PyObject_HashNotImplemented; - if (PyType_Ready(&google::protobuf::python::RepeatedScalarContainer_Type) < 0) { + CMessageClass_Type.tp_base = &PyType_Type; + if (PyType_Ready(&CMessageClass_Type) < 0) { return false; } + PyModule_AddObject(m, "MessageMeta", + reinterpret_cast<PyObject*>(&CMessageClass_Type)); - PyModule_AddObject(m, "RepeatedScalarContainer", - reinterpret_cast<PyObject*>( - &google::protobuf::python::RepeatedScalarContainer_Type)); + if (PyType_Ready(&CMessage_Type) < 0) { + return false; + } - google::protobuf::python::RepeatedCompositeContainer_Type.tp_new = PyType_GenericNew; - google::protobuf::python::RepeatedCompositeContainer_Type.tp_hash = - PyObject_HashNotImplemented; - if (PyType_Ready(&google::protobuf::python::RepeatedCompositeContainer_Type) < 0) { + // DESCRIPTOR is set on each protocol buffer message class elsewhere, but set + // it here as well to document that subclasses need to set it. + PyDict_SetItem(CMessage_Type.tp_dict, kDESCRIPTOR, Py_None); + // Subclasses with message extensions will override _extensions_by_name and + // _extensions_by_number with fresh mutable dictionaries in AddDescriptors. + // All other classes can share this same immutable mapping. 
+ ScopedPyObjectPtr empty_dict(PyDict_New()); + if (empty_dict == NULL) { + return false; + } + ScopedPyObjectPtr immutable_dict(PyDictProxy_New(empty_dict.get())); + if (immutable_dict == NULL) { + return false; + } + if (PyDict_SetItem(CMessage_Type.tp_dict, + k_extensions_by_name, immutable_dict.get()) < 0) { + return false; + } + if (PyDict_SetItem(CMessage_Type.tp_dict, + k_extensions_by_number, immutable_dict.get()) < 0) { return false; } - PyModule_AddObject( - m, "RepeatedCompositeContainer", - reinterpret_cast<PyObject*>( - &google::protobuf::python::RepeatedCompositeContainer_Type)); + PyModule_AddObject(m, "Message", reinterpret_cast<PyObject*>(&CMessage_Type)); - google::protobuf::python::ExtensionDict_Type.tp_new = PyType_GenericNew; - google::protobuf::python::ExtensionDict_Type.tp_hash = PyObject_HashNotImplemented; - if (PyType_Ready(&google::protobuf::python::ExtensionDict_Type) < 0) { - return false; + // Initialize Repeated container types. + { + if (PyType_Ready(&RepeatedScalarContainer_Type) < 0) { + return false; + } + + PyModule_AddObject(m, "RepeatedScalarContainer", + reinterpret_cast<PyObject*>( + &RepeatedScalarContainer_Type)); + + if (PyType_Ready(&RepeatedCompositeContainer_Type) < 0) { + return false; + } + + PyModule_AddObject( + m, "RepeatedCompositeContainer", + reinterpret_cast<PyObject*>( + &RepeatedCompositeContainer_Type)); + + // Register them as collections.Sequence + ScopedPyObjectPtr collections(PyImport_ImportModule("collections")); + if (collections == NULL) { + return false; + } + ScopedPyObjectPtr mutable_sequence( + PyObject_GetAttrString(collections.get(), "MutableSequence")); + if (mutable_sequence == NULL) { + return false; + } + if (ScopedPyObjectPtr( + PyObject_CallMethod(mutable_sequence.get(), "register", "O", + &RepeatedScalarContainer_Type)) == NULL) { + return false; + } + if (ScopedPyObjectPtr( + PyObject_CallMethod(mutable_sequence.get(), "register", "O", + &RepeatedCompositeContainer_Type)) == NULL) { + 
return false; + } } - PyModule_AddObject( - m, "ExtensionDict", - reinterpret_cast<PyObject*>(&google::protobuf::python::ExtensionDict_Type)); + // Initialize Map container types. + { + // ScalarMapContainer_Type derives from our MutableMapping type. + ScopedPyObjectPtr containers(PyImport_ImportModule( + "google.protobuf.internal.containers")); + if (containers == NULL) { + return false; + } + + ScopedPyObjectPtr mutable_mapping( + PyObject_GetAttrString(containers.get(), "MutableMapping")); + if (mutable_mapping == NULL) { + return false; + } + + if (!PyObject_TypeCheck(mutable_mapping.get(), &PyType_Type)) { + return false; + } + + Py_INCREF(mutable_mapping.get()); +#if PY_MAJOR_VERSION >= 3 + PyObject* bases = PyTuple_New(1); + PyTuple_SET_ITEM(bases, 0, mutable_mapping.get()); + + ScalarMapContainer_Type = + PyType_FromSpecWithBases(&ScalarMapContainer_Type_spec, bases); + PyModule_AddObject(m, "ScalarMapContainer", ScalarMapContainer_Type); +#else + ScalarMapContainer_Type.tp_base = + reinterpret_cast<PyTypeObject*>(mutable_mapping.get()); + + if (PyType_Ready(&ScalarMapContainer_Type) < 0) { + return false; + } - if (!google::protobuf::python::InitDescriptor()) { + PyModule_AddObject(m, "ScalarMapContainer", + reinterpret_cast<PyObject*>(&ScalarMapContainer_Type)); +#endif + + if (PyType_Ready(&MapIterator_Type) < 0) { + return false; + } + + PyModule_AddObject(m, "MapIterator", + reinterpret_cast<PyObject*>(&MapIterator_Type)); + + +#if PY_MAJOR_VERSION >= 3 + MessageMapContainer_Type = + PyType_FromSpecWithBases(&MessageMapContainer_Type_spec, bases); + PyModule_AddObject(m, "MessageMapContainer", MessageMapContainer_Type); +#else + Py_INCREF(mutable_mapping.get()); + MessageMapContainer_Type.tp_base = + reinterpret_cast<PyTypeObject*>(mutable_mapping.get()); + + if (PyType_Ready(&MessageMapContainer_Type) < 0) { + return false; + } + + PyModule_AddObject(m, "MessageMapContainer", + reinterpret_cast<PyObject*>(&MessageMapContainer_Type)); +#endif + } + + 
if (PyType_Ready(&ExtensionDict_Type) < 0) { return false; } + PyModule_AddObject( + m, "ExtensionDict", + reinterpret_cast<PyObject*>(&ExtensionDict_Type)); + + // Expose the DescriptorPool used to hold all descriptors added from generated + // pb2.py files. + // PyModule_AddObject steals a reference. + Py_INCREF(GetDefaultDescriptorPool()); + PyModule_AddObject(m, "default_pool", + reinterpret_cast<PyObject*>(GetDefaultDescriptorPool())); + + PyModule_AddObject(m, "DescriptorPool", reinterpret_cast<PyObject*>( + &PyDescriptorPool_Type)); + + // This implementation provides full Descriptor types, we advertise it so that + // descriptor.py can use them in replacement of the Python classes. + PyModule_AddIntConstant(m, "_USE_C_DESCRIPTORS", 1); + + PyModule_AddObject(m, "Descriptor", reinterpret_cast<PyObject*>( + &PyMessageDescriptor_Type)); + PyModule_AddObject(m, "FieldDescriptor", reinterpret_cast<PyObject*>( + &PyFieldDescriptor_Type)); + PyModule_AddObject(m, "EnumDescriptor", reinterpret_cast<PyObject*>( + &PyEnumDescriptor_Type)); + PyModule_AddObject(m, "EnumValueDescriptor", reinterpret_cast<PyObject*>( + &PyEnumValueDescriptor_Type)); + PyModule_AddObject(m, "FileDescriptor", reinterpret_cast<PyObject*>( + &PyFileDescriptor_Type)); + PyModule_AddObject(m, "OneofDescriptor", reinterpret_cast<PyObject*>( + &PyOneofDescriptor_Type)); PyObject* enum_type_wrapper = PyImport_ImportModule( "google.protobuf.internal.enum_type_wrapper"); if (enum_type_wrapper == NULL) { return false; } - google::protobuf::python::EnumTypeWrapper_class = + EnumTypeWrapper_class = PyObject_GetAttrString(enum_type_wrapper, "EnumTypeWrapper"); Py_DECREF(enum_type_wrapper); @@ -2490,25 +3059,21 @@ bool InitProto2MessageModule(PyObject *m) { if (message_module == NULL) { return false; } - google::protobuf::python::EncodeError_class = PyObject_GetAttrString(message_module, - "EncodeError"); - google::protobuf::python::DecodeError_class = PyObject_GetAttrString(message_module, - 
"DecodeError"); + EncodeError_class = PyObject_GetAttrString(message_module, "EncodeError"); + DecodeError_class = PyObject_GetAttrString(message_module, "DecodeError"); + PythonMessage_class = PyObject_GetAttrString(message_module, "Message"); Py_DECREF(message_module); PyObject* pickle_module = PyImport_ImportModule("pickle"); if (pickle_module == NULL) { return false; } - google::protobuf::python::PickleError_class = PyObject_GetAttrString(pickle_module, - "PickleError"); + PickleError_class = PyObject_GetAttrString(pickle_module, "PickleError"); Py_DECREF(pickle_module); // Override {Get,Mutable}CProtoInsidePyProto. - google::protobuf::python::GetCProtoInsidePyProtoPtr = - google::protobuf::python::GetCProtoInsidePyProtoImpl; - google::protobuf::python::MutableCProtoInsidePyProtoPtr = - google::protobuf::python::MutableCProtoInsidePyProtoImpl; + GetCProtoInsidePyProtoPtr = GetCProtoInsidePyProtoImpl; + MutableCProtoInsidePyProtoPtr = MutableCProtoInsidePyProtoImpl; return true; } @@ -2516,6 +3081,12 @@ bool InitProto2MessageModule(PyObject *m) { } // namespace python } // namespace protobuf +static PyMethodDef ModuleMethods[] = { + {"SetAllowOversizeProtos", + (PyCFunction)google::protobuf::python::cmessage::SetAllowOversizeProtos, + METH_O, "Enable/disable oversize proto parsing."}, + { NULL, NULL} +}; #if PY_MAJOR_VERSION >= 3 static struct PyModuleDef _module = { @@ -2523,7 +3094,7 @@ static struct PyModuleDef _module = { "_message", google::protobuf::python::module_docstring, -1, - NULL, + ModuleMethods, /* m_methods */ NULL, NULL, NULL, @@ -2542,7 +3113,8 @@ extern "C" { #if PY_MAJOR_VERSION >= 3 m = PyModule_Create(&_module); #else - m = Py_InitModule3("_message", NULL, google::protobuf::python::module_docstring); + m = Py_InitModule3("_message", ModuleMethods, + google::protobuf::python::module_docstring); #endif if (m == NULL) { return INITFUNC_ERRORVAL; diff --git a/python/google/protobuf/pyext/message.h b/python/google/protobuf/pyext/message.h index 
9f4978f44..3a4bec81c 100644 --- a/python/google/protobuf/pyext/message.h +++ b/python/google/protobuf/pyext/message.h @@ -42,20 +42,27 @@ #endif #include <string> - namespace google { namespace protobuf { class Message; class Reflection; class FieldDescriptor; - +class Descriptor; +class DescriptorPool; +class MessageFactory; + +#ifdef _SHARED_PTR_H +using std::shared_ptr; +using ::std::string; +#else using internal::shared_ptr; +#endif namespace python { -struct CFieldDescriptor; struct ExtensionDict; +struct PyDescriptorPool; typedef struct CMessage { PyObject_HEAD; @@ -79,13 +86,11 @@ typedef struct CMessage { // to use this pointer will result in a crash. struct CMessage* parent; - // Weak reference to the parent's descriptor that describes this submessage. + // Pointer to the parent's descriptor that describes this submessage. // Used together with the parent's message when making a default message // instance mutable. - // TODO(anuraag): With a bit of work on the Python/C++ layer, it should be - // possible to make this a direct pointer to a C++ FieldDescriptor, this would - // be easier if this implementation replaces upstream. - CFieldDescriptor* parent_field; + // The pointer is owned by the global DescriptorPool. + const FieldDescriptor* parent_field_descriptor; // Pointer to the C++ Message object for this CMessage. The // CMessage does not own this pointer. @@ -111,49 +116,91 @@ typedef struct CMessage { extern PyTypeObject CMessage_Type; + +// The (meta) type of all Messages classes. +// It allows us to cache some C++ pointers in the class object itself, they are +// faster to extract than from the type's dictionary. + +struct CMessageClass { + // This is how CPython subclasses C structures: the base structure must be + // the first member of the object. + PyHeapTypeObject super; + + // C++ descriptor of this message. + const Descriptor* message_descriptor; + + // Owned reference, used to keep the pointer above alive. 
+ PyObject* py_message_descriptor; + + // The Python DescriptorPool used to create the class. It is needed to resolve + // fields descriptors, including extensions fields; its C++ MessageFactory is + // used to instantiate submessages. + // This can be different from DESCRIPTOR.file.pool, in the case of a custom + // DescriptorPool which defines new extensions. + // We own the reference, because it's important to keep the descriptors and + // factory alive. + PyDescriptorPool* py_descriptor_pool; + + PyObject* AsPyObject() { + return reinterpret_cast<PyObject*>(this); + } +}; + + namespace cmessage { -// Create a new empty message that can be populated by the parent. -PyObject* NewEmpty(PyObject* type); +// Internal function to create a new empty Message Python object, but with empty +// pointers to the C++ objects. +// The caller must fill self->message, self->owner and eventually self->parent. +CMessage* NewEmptyMessage(CMessageClass* type); // Release a submessage from its proto tree, making it a new top-level messgae. // A new message will be created if this is a read-only default instance. // // Corresponds to reflection api method ReleaseMessage. -int ReleaseSubMessage(google::protobuf::Message* message, - const google::protobuf::FieldDescriptor* field_descriptor, +int ReleaseSubMessage(CMessage* self, + const FieldDescriptor* field_descriptor, CMessage* child_cmessage); +// Retrieves the C++ descriptor of a Python Extension descriptor. +// On error, return NULL with an exception set. +const FieldDescriptor* GetExtensionDescriptor(PyObject* extension); + // Initializes a new CMessage instance for a submessage. Only called once per // submessage as the result is cached in composite_fields. // // Corresponds to reflection api method GetMessage. 
-PyObject* InternalGetSubMessage(CMessage* self, - CFieldDescriptor* cfield_descriptor); +PyObject* InternalGetSubMessage( + CMessage* self, const FieldDescriptor* field_descriptor); // Deletes a range of C++ submessages in a repeated field (following a // removal in a RepeatedCompositeContainer). // // Releases messages to the provided cmessage_list if it is not NULL rather // than just removing them from the underlying proto. This cmessage_list must -// have a CMessage for each underlying submessage. The CMessages refered to +// have a CMessage for each underlying submessage. The CMessages referred to // by slice will be removed from cmessage_list by this function. // // Corresponds to reflection api method RemoveLast. -int InternalDeleteRepeatedField(google::protobuf::Message* message, - const google::protobuf::FieldDescriptor* field_descriptor, +int InternalDeleteRepeatedField(CMessage* self, + const FieldDescriptor* field_descriptor, PyObject* slice, PyObject* cmessage_list); // Sets the specified scalar value to the message. int InternalSetScalar(CMessage* self, - const google::protobuf::FieldDescriptor* field_descriptor, + const FieldDescriptor* field_descriptor, PyObject* value); +// Sets the specified scalar value to the message. Requires it is not a Oneof. +int InternalSetNonOneofScalar(Message* message, + const FieldDescriptor* field_descriptor, + PyObject* arg); + // Retrieves the specified scalar value from the message. // // Returns a new python reference. -PyObject* InternalGetScalar(CMessage* self, - const google::protobuf::FieldDescriptor* field_descriptor); +PyObject* InternalGetScalar(const Message* message, + const FieldDescriptor* field_descriptor); // Clears the message, removing all contained data. Extension dictionary and // submessages are released first if there are remaining external references. @@ -169,8 +216,7 @@ PyObject* Clear(CMessage* self); // // Corresponds to reflection api method ClearField. 
PyObject* ClearFieldByDescriptor( - CMessage* self, - const google::protobuf::FieldDescriptor* descriptor); + CMessage* self, const FieldDescriptor* descriptor); // Clears the data for the given field name. The message is released if there // are any external references. @@ -183,17 +229,15 @@ PyObject* ClearField(CMessage* self, PyObject* arg); // // Corresponds to reflection api method HasField PyObject* HasFieldByDescriptor( - CMessage* self, const google::protobuf::FieldDescriptor* field_descriptor); + CMessage* self, const FieldDescriptor* field_descriptor); // Checks if the message has the named field. // // Corresponds to reflection api method HasField. PyObject* HasField(CMessage* self, PyObject* arg); -// Initializes constants/enum values on a message. This is called by -// RepeatedCompositeContainer and ExtensionDict after calling the constructor. -// TODO(anuraag): Make it always called from within the constructor since it can -int InitAttributes(CMessage* self, PyObject* descriptor, PyObject* kwargs); +// Initializes values of fields on a newly constructed message. +int InitAttributes(CMessage* self, PyObject* kwargs); PyObject* MergeFrom(CMessage* self, PyObject* arg); @@ -216,16 +260,23 @@ int SetOwner(CMessage* self, const shared_ptr<Message>& new_owner); int AssureWritable(CMessage* self); +// Returns the "best" DescriptorPool for the given message. +// This is often equivalent to message.DESCRIPTOR.pool, but not always, when +// the message class was created from a MessageFactory using a custom pool which +// uses the generated pool as an underlay. +// +// The returned pool is suitable for finding fields and building submessages, +// even in the case of extensions. 
+PyDescriptorPool* GetDescriptorPoolForMessage(CMessage* message); + } // namespace cmessage + /* Is 64bit */ #define IS_64BIT (SIZEOF_LONG == 8) -#define FIELD_BELONGS_TO_MESSAGE(field_descriptor, message) \ - ((message)->GetDescriptor() == (field_descriptor)->containing_type()) - #define FIELD_IS_REPEATED(field_descriptor) \ - ((field_descriptor)->label() == google::protobuf::FieldDescriptor::LABEL_REPEATED) + ((field_descriptor)->label() == FieldDescriptor::LABEL_REPEATED) #define GOOGLE_CHECK_GET_INT32(arg, value, err) \ int32 value; \ @@ -278,7 +329,7 @@ extern PyObject* kint64min_py; extern PyObject* kint64max_py; extern PyObject* kuint64max_py; -#define C(str) const_cast<char*>(str) +#define FULL_MODULE_NAME "google.protobuf.pyext._message" void FormatTypeError(PyObject* arg, char* expected_types); template<class T> @@ -287,14 +338,19 @@ bool CheckAndGetInteger( bool CheckAndGetDouble(PyObject* arg, double* value); bool CheckAndGetFloat(PyObject* arg, float* value); bool CheckAndGetBool(PyObject* arg, bool* value); +PyObject* CheckString(PyObject* arg, const FieldDescriptor* descriptor); bool CheckAndSetString( - PyObject* arg, google::protobuf::Message* message, - const google::protobuf::FieldDescriptor* descriptor, - const google::protobuf::Reflection* reflection, + PyObject* arg, Message* message, + const FieldDescriptor* descriptor, + const Reflection* reflection, bool append, int index); -PyObject* ToStringObject( - const google::protobuf::FieldDescriptor* descriptor, string value); +PyObject* ToStringObject(const FieldDescriptor* descriptor, string value); + +// Check if the passed field descriptor belongs to the given message. 
+// If not, return false and set a Python exception (a KeyError) +bool CheckFieldBelongsToMessage(const FieldDescriptor* field_descriptor, + const Message* message); extern PyObject* PickleError_class; diff --git a/python/google/protobuf/pyext/message_factory_cpp2_test.py b/python/google/protobuf/pyext/message_factory_cpp2_test.py deleted file mode 100644 index 32ab4f852..000000000 --- a/python/google/protobuf/pyext/message_factory_cpp2_test.py +++ /dev/null @@ -1,56 +0,0 @@ -#! /usr/bin/python -# -# Protocol Buffers - Google's data interchange format -# Copyright 2008 Google Inc. All rights reserved. -# https://developers.google.com/protocol-buffers/ -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are -# met: -# -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above -# copyright notice, this list of conditions and the following disclaimer -# in the documentation and/or other materials provided with the -# distribution. -# * Neither the name of Google Inc. nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -# A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT -# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -"""Tests for google.protobuf.message_factory.""" - -import os -os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'cpp' -os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION'] = '2' - -# We must set the implementation version above before the google3 imports. -# pylint: disable=g-import-not-at-top -from google.apputils import basetest -from google.protobuf.internal import api_implementation -# Run all tests from the original module by putting them in our namespace. -# pylint: disable=wildcard-import -from google.protobuf.internal.message_factory_test import * - - -class ConfirmCppApi2Test(basetest.TestCase): - - def testImplementationSetting(self): - self.assertEqual('cpp', api_implementation.Type()) - self.assertEqual(2, api_implementation.Version()) - - -if __name__ == '__main__': - basetest.main() diff --git a/python/google/protobuf/pyext/proto2_api_test.proto b/python/google/protobuf/pyext/proto2_api_test.proto index 72c31b1cd..18aecfb7d 100644 --- a/python/google/protobuf/pyext/proto2_api_test.proto +++ b/python/google/protobuf/pyext/proto2_api_test.proto @@ -28,6 +28,8 @@ // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+syntax = "proto2"; + import "google/protobuf/internal/cpp/proto1_api_test.proto"; package google.protobuf.python.internal; diff --git a/python/google/protobuf/pyext/python.proto b/python/google/protobuf/pyext/python.proto index d47d402c6..cce645d71 100644 --- a/python/google/protobuf/pyext/python.proto +++ b/python/google/protobuf/pyext/python.proto @@ -33,6 +33,7 @@ // These message definitions are used to exercises known corner cases // in the C++ implementation of the Python API. +syntax = "proto2"; package google.protobuf.python.internal; @@ -63,4 +64,5 @@ message TestAllExtensions { extend TestAllExtensions { optional TestAllTypes.NestedMessage optional_nested_message_extension = 1; + repeated TestAllTypes.NestedMessage repeated_nested_message_extension = 2; } diff --git a/python/google/protobuf/pyext/reflection_cpp2_generated_test.py b/python/google/protobuf/pyext/reflection_cpp2_generated_test.py deleted file mode 100755 index 552efd482..000000000 --- a/python/google/protobuf/pyext/reflection_cpp2_generated_test.py +++ /dev/null @@ -1,94 +0,0 @@ -#! /usr/bin/python -# -*- coding: utf-8 -*- -# -# Protocol Buffers - Google's data interchange format -# Copyright 2008 Google Inc. All rights reserved. -# https://developers.google.com/protocol-buffers/ -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are -# met: -# -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above -# copyright notice, this list of conditions and the following disclaimer -# in the documentation and/or other materials provided with the -# distribution. -# * Neither the name of Google Inc. nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. 
-# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -"""Unittest for reflection.py, which tests the generated C++ implementation.""" - -__author__ = 'jasonh@google.com (Jason Hsueh)' - -import os -os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'cpp' -os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION'] = '2' - -from google.apputils import basetest -from google.protobuf.internal import api_implementation -from google.protobuf.internal import more_extensions_dynamic_pb2 -from google.protobuf.internal import more_extensions_pb2 -from google.protobuf.internal.reflection_test import * - - -class ReflectionCppTest(basetest.TestCase): - def testImplementationSetting(self): - self.assertEqual('cpp', api_implementation.Type()) - self.assertEqual(2, api_implementation.Version()) - - def testExtensionOfGeneratedTypeInDynamicFile(self): - """Tests that a file built dynamically can extend a generated C++ type. - - The C++ implementation uses a DescriptorPool that has the generated - DescriptorPool as an underlay. Typically, a type can only find - extensions in its own pool. With the python C-extension, the generated C++ - extendee may be available, but not the extension. 
This tests that the - C-extension implements the correct special handling to make such extensions - available. - """ - pb1 = more_extensions_pb2.ExtendedMessage() - # Test that basic accessors work. - self.assertFalse( - pb1.HasExtension(more_extensions_dynamic_pb2.dynamic_int32_extension)) - self.assertFalse( - pb1.HasExtension(more_extensions_dynamic_pb2.dynamic_message_extension)) - pb1.Extensions[more_extensions_dynamic_pb2.dynamic_int32_extension] = 17 - pb1.Extensions[more_extensions_dynamic_pb2.dynamic_message_extension].a = 24 - self.assertTrue( - pb1.HasExtension(more_extensions_dynamic_pb2.dynamic_int32_extension)) - self.assertTrue( - pb1.HasExtension(more_extensions_dynamic_pb2.dynamic_message_extension)) - - # Now serialize the data and parse to a new message. - pb2 = more_extensions_pb2.ExtendedMessage() - pb2.MergeFromString(pb1.SerializeToString()) - - self.assertTrue( - pb2.HasExtension(more_extensions_dynamic_pb2.dynamic_int32_extension)) - self.assertTrue( - pb2.HasExtension(more_extensions_dynamic_pb2.dynamic_message_extension)) - self.assertEqual( - 17, pb2.Extensions[more_extensions_dynamic_pb2.dynamic_int32_extension]) - self.assertEqual( - 24, - pb2.Extensions[more_extensions_dynamic_pb2.dynamic_message_extension].a) - - - -if __name__ == '__main__': - basetest.main() diff --git a/python/google/protobuf/pyext/repeated_composite_container.cc b/python/google/protobuf/pyext/repeated_composite_container.cc index 5c05b3d89..4f339e772 100644 --- a/python/google/protobuf/pyext/repeated_composite_container.cc +++ b/python/google/protobuf/pyext/repeated_composite_container.cc @@ -38,11 +38,13 @@ #include <google/protobuf/stubs/shared_ptr.h> #endif +#include <google/protobuf/stubs/logging.h> #include <google/protobuf/stubs/common.h> #include <google/protobuf/descriptor.h> #include <google/protobuf/dynamic_message.h> #include <google/protobuf/message.h> #include <google/protobuf/pyext/descriptor.h> +#include <google/protobuf/pyext/descriptor_pool.h> 
#include <google/protobuf/pyext/message.h> #include <google/protobuf/pyext/scoped_pyobject_ptr.h> @@ -56,8 +58,6 @@ namespace google { namespace protobuf { namespace python { -extern google::protobuf::DynamicMessageFactory* global_message_factory; - namespace repeated_composite_container { // TODO(tibell): We might also want to check: @@ -65,144 +65,25 @@ namespace repeated_composite_container { #define GOOGLE_CHECK_ATTACHED(self) \ do { \ GOOGLE_CHECK_NOTNULL((self)->message); \ - GOOGLE_CHECK_NOTNULL((self)->parent_field); \ + GOOGLE_CHECK_NOTNULL((self)->parent_field_descriptor); \ } while (0); #define GOOGLE_CHECK_RELEASED(self) \ do { \ GOOGLE_CHECK((self)->owner.get() == NULL); \ GOOGLE_CHECK((self)->message == NULL); \ - GOOGLE_CHECK((self)->parent_field == NULL); \ + GOOGLE_CHECK((self)->parent_field_descriptor == NULL); \ GOOGLE_CHECK((self)->parent == NULL); \ } while (0); -// Returns a new reference. -static PyObject* GetKey(PyObject* x) { - // Just the identity function. - Py_INCREF(x); - return x; -} - -#define GET_KEY(keyfunc, value) \ - ((keyfunc) == NULL ? \ - GetKey((value)) : \ - PyObject_CallFunctionObjArgs((keyfunc), (value), NULL)) - -// Converts a comparison function that returns -1, 0, or 1 into a -// less-than predicate. -// -// Returns -1 on error, 1 if x < y, 0 if x >= y. -static int islt(PyObject *x, PyObject *y, PyObject *compare) { - if (compare == NULL) - return PyObject_RichCompareBool(x, y, Py_LT); - - ScopedPyObjectPtr res(PyObject_CallFunctionObjArgs(compare, x, y, NULL)); - if (res == NULL) - return -1; - if (!PyInt_Check(res)) { - PyErr_Format(PyExc_TypeError, - "comparison function must return int, not %.200s", - Py_TYPE(res)->tp_name); - return -1; - } - return PyInt_AsLong(res) < 0; -} - -// Copied from uarrsort.c but swaps memcpy swaps with protobuf/python swaps -// TODO(anuraag): Is there a better way to do this then reinventing the wheel? 
-static int InternalQuickSort(RepeatedCompositeContainer* self, - Py_ssize_t start, - Py_ssize_t limit, - PyObject* cmp, - PyObject* keyfunc) { - if (limit - start <= 1) - return 0; // Nothing to sort. - - GOOGLE_CHECK_ATTACHED(self); - - google::protobuf::Message* message = self->message; - const google::protobuf::Reflection* reflection = message->GetReflection(); - const google::protobuf::FieldDescriptor* descriptor = self->parent_field->descriptor; - Py_ssize_t left; - Py_ssize_t right; - - PyObject* children = self->child_messages; - - do { - left = start; - right = limit; - ScopedPyObjectPtr mid( - GET_KEY(keyfunc, PyList_GET_ITEM(children, (start + limit) / 2))); - do { - ScopedPyObjectPtr key(GET_KEY(keyfunc, PyList_GET_ITEM(children, left))); - int is_lt = islt(key, mid, cmp); - if (is_lt == -1) - return -1; - /* array[left]<x */ - while (is_lt) { - ++left; - ScopedPyObjectPtr key(GET_KEY(keyfunc, - PyList_GET_ITEM(children, left))); - is_lt = islt(key, mid, cmp); - if (is_lt == -1) - return -1; - } - key.reset(GET_KEY(keyfunc, PyList_GET_ITEM(children, right - 1))); - is_lt = islt(mid, key, cmp); - if (is_lt == -1) - return -1; - while (is_lt) { - --right; - ScopedPyObjectPtr key(GET_KEY(keyfunc, - PyList_GET_ITEM(children, right - 1))); - is_lt = islt(mid, key, cmp); - if (is_lt == -1) - return -1; - } - if (left < right) { - --right; - if (left < right) { - reflection->SwapElements(message, descriptor, left, right); - PyObject* tmp = PyList_GET_ITEM(children, left); - PyList_SET_ITEM(children, left, PyList_GET_ITEM(children, right)); - PyList_SET_ITEM(children, right, tmp); - } - ++left; - } - } while (left < right); - - if ((right - start) < (limit - left)) { - /* sort [start..right[ */ - if (start < (right - 1)) { - InternalQuickSort(self, start, right, cmp, keyfunc); - } - - /* sort [left..limit[ */ - start = left; - } else { - /* sort [left..limit[ */ - if (left < (limit - 1)) { - InternalQuickSort(self, left, limit, cmp, keyfunc); - } - - /* sort 
[start..right[ */ - limit = right; - } - } while (start < (limit - 1)); - - return 0; -} - -#undef GET_KEY - // --------------------------------------------------------------------- // len() static Py_ssize_t Length(RepeatedCompositeContainer* self) { - google::protobuf::Message* message = self->message; + Message* message = self->message; if (message != NULL) { return message->GetReflection()->FieldSize(*message, - self->parent_field->descriptor); + self->parent_field_descriptor); } else { // The container has been released (i.e. by a call to Clear() or // ClearField() on the parent) and thus there's no message. @@ -221,23 +102,22 @@ static int UpdateChildMessages(RepeatedCompositeContainer* self) { // be removed in such a way so there's no need to worry about that. Py_ssize_t message_length = Length(self); Py_ssize_t child_length = PyList_GET_SIZE(self->child_messages); - google::protobuf::Message* message = self->message; - const google::protobuf::Reflection* reflection = message->GetReflection(); + Message* message = self->message; + const Reflection* reflection = message->GetReflection(); for (Py_ssize_t i = child_length; i < message_length; ++i) { const Message& sub_message = reflection->GetRepeatedMessage( - *(self->message), self->parent_field->descriptor, i); - ScopedPyObjectPtr py_cmsg(cmessage::NewEmpty(self->subclass_init)); - if (py_cmsg == NULL) { + *(self->message), self->parent_field_descriptor, i); + CMessage* cmsg = cmessage::NewEmptyMessage(self->child_message_class); + ScopedPyObjectPtr py_cmsg(reinterpret_cast<PyObject*>(cmsg)); + if (cmsg == NULL) { return -1; } - CMessage* cmsg = reinterpret_cast<CMessage*>(py_cmsg.get()); cmsg->owner = self->owner; - cmsg->message = const_cast<google::protobuf::Message*>(&sub_message); + cmsg->message = const_cast<Message*>(&sub_message); cmsg->parent = self->parent; - if (cmessage::InitAttributes(cmsg, NULL, NULL) < 0) { + if (PyList_Append(self->child_messages, py_cmsg.get()) < 0) { return -1; } - 
PyList_Append(self->child_messages, py_cmsg); } return 0; } @@ -255,26 +135,27 @@ static PyObject* AddToAttached(RepeatedCompositeContainer* self, } if (cmessage::AssureWritable(self->parent) == -1) return NULL; - google::protobuf::Message* message = self->message; - google::protobuf::Message* sub_message = + Message* message = self->message; + Message* sub_message = message->GetReflection()->AddMessage(message, - self->parent_field->descriptor); - PyObject* py_cmsg = cmessage::NewEmpty(self->subclass_init); - if (py_cmsg == NULL) { + self->parent_field_descriptor); + CMessage* cmsg = cmessage::NewEmptyMessage(self->child_message_class); + if (cmsg == NULL) return NULL; - } - CMessage* cmsg = reinterpret_cast<CMessage*>(py_cmsg); cmsg->owner = self->owner; cmsg->message = sub_message; cmsg->parent = self->parent; - // cmessage::InitAttributes must be called after cmsg->message has - // been set. - if (cmessage::InitAttributes(cmsg, NULL, kwargs) < 0) { + if (cmessage::InitAttributes(cmsg, kwargs) < 0) { + Py_DECREF(cmsg); + return NULL; + } + + PyObject* py_cmsg = reinterpret_cast<PyObject*>(cmsg); + if (PyList_Append(self->child_messages, py_cmsg) < 0) { Py_DECREF(py_cmsg); return NULL; } - PyList_Append(self->child_messages, py_cmsg); return py_cmsg; } @@ -283,20 +164,16 @@ static PyObject* AddToReleased(RepeatedCompositeContainer* self, PyObject* kwargs) { GOOGLE_CHECK_RELEASED(self); - // Create the CMessage - PyObject* py_cmsg = PyObject_CallObject(self->subclass_init, NULL); + // Create a new Message detached from the rest. 
+ PyObject* py_cmsg = PyEval_CallObjectWithKeywords( + self->child_message_class->AsPyObject(), NULL, kwargs); if (py_cmsg == NULL) return NULL; - CMessage* cmsg = reinterpret_cast<CMessage*>(py_cmsg); - if (cmessage::InitAttributes(cmsg, NULL, kwargs) < 0) { + + if (PyList_Append(self->child_messages, py_cmsg) < 0) { Py_DECREF(py_cmsg); return NULL; } - - // The Message got created by the call to subclass_init above and - // it set self->owner to the newly allocated message. - - PyList_Append(self->child_messages, py_cmsg); return py_cmsg; } @@ -323,8 +200,8 @@ PyObject* Extend(RepeatedCompositeContainer* self, PyObject* value) { return NULL; } ScopedPyObjectPtr next; - while ((next.reset(PyIter_Next(iter))) != NULL) { - if (!PyObject_TypeCheck(next, &CMessage_Type)) { + while ((next.reset(PyIter_Next(iter.get()))) != NULL) { + if (!PyObject_TypeCheck(next.get(), &CMessage_Type)) { PyErr_SetString(PyExc_TypeError, "Not a cmessage"); return NULL; } @@ -333,7 +210,8 @@ PyObject* Extend(RepeatedCompositeContainer* self, PyObject* value) { return NULL; } CMessage* new_cmessage = reinterpret_cast<CMessage*>(new_message.get()); - if (cmessage::MergeFrom(new_cmessage, next) == NULL) { + if (ScopedPyObjectPtr(cmessage::MergeFrom(new_cmessage, next.get())) == + NULL) { return NULL; } } @@ -354,35 +232,9 @@ PyObject* Subscript(RepeatedCompositeContainer* self, PyObject* slice) { if (UpdateChildMessages(self) < 0) { return NULL; } - Py_ssize_t from; - Py_ssize_t to; - Py_ssize_t step; - Py_ssize_t length = Length(self); - Py_ssize_t slicelength; - if (PySlice_Check(slice)) { -#if PY_MAJOR_VERSION >= 3 - if (PySlice_GetIndicesEx(slice, -#else - if (PySlice_GetIndicesEx(reinterpret_cast<PySliceObject*>(slice), -#endif - length, &from, &to, &step, &slicelength) == -1) { - return NULL; - } - return PyList_GetSlice(self->child_messages, from, to); - } else if (PyInt_Check(slice) || PyLong_Check(slice)) { - from = to = PyLong_AsLong(slice); - if (from < 0) { - from = to = length + 
from; - } - PyObject* result = PyList_GetItem(self->child_messages, from); - if (result == NULL) { - return NULL; - } - Py_INCREF(result); - return result; - } - PyErr_SetString(PyExc_TypeError, "index must be an integer or slice"); - return NULL; + // Just forward the call to the subscript-handling function of the + // list containing the child messages. + return PyObject_GetItem(self->child_messages, slice); } int AssignSubscript(RepeatedCompositeContainer* self, @@ -397,9 +249,9 @@ int AssignSubscript(RepeatedCompositeContainer* self, } // Delete from the underlying Message, if any. - if (self->message != NULL) { - if (cmessage::InternalDeleteRepeatedField(self->message, - self->parent_field->descriptor, + if (self->parent != NULL) { + if (cmessage::InternalDeleteRepeatedField(self->parent, + self->parent_field_descriptor, slice, self->child_messages) < 0) { return -1; @@ -441,7 +293,7 @@ static PyObject* Remove(RepeatedCompositeContainer* self, PyObject* value) { return NULL; } ScopedPyObjectPtr py_index(PyLong_FromLong(index)); - if (AssignSubscript(self, py_index, NULL) < 0) { + if (AssignSubscript(self, py_index.get(), NULL) < 0) { return NULL; } Py_RETURN_NONE; @@ -465,17 +317,17 @@ static PyObject* RichCompare(RepeatedCompositeContainer* self, if (full_slice == NULL) { return NULL; } - ScopedPyObjectPtr list(Subscript(self, full_slice)); + ScopedPyObjectPtr list(Subscript(self, full_slice.get())); if (list == NULL) { return NULL; } ScopedPyObjectPtr other_list( - Subscript( - reinterpret_cast<RepeatedCompositeContainer*>(other), full_slice)); + Subscript(reinterpret_cast<RepeatedCompositeContainer*>(other), + full_slice.get())); if (other_list == NULL) { return NULL; } - return PyObject_RichCompare(list, other_list, opid); + return PyObject_RichCompare(list.get(), other_list.get(), opid); } else { Py_INCREF(Py_NotImplemented); return Py_NotImplemented; @@ -485,58 +337,39 @@ static PyObject* RichCompare(RepeatedCompositeContainer* self, // 
--------------------------------------------------------------------- // sort() -static PyObject* SortAttached(RepeatedCompositeContainer* self, - PyObject* args, - PyObject* kwds) { - // Sort the underlying Message array. - PyObject *compare = NULL; - int reverse = 0; - PyObject *keyfunc = NULL; - static char *kwlist[] = {"cmp", "key", "reverse", 0}; - - if (args != NULL) { - if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OOi:sort", - kwlist, &compare, &keyfunc, &reverse)) - return NULL; - } - if (compare == Py_None) - compare = NULL; - if (keyfunc == Py_None) - keyfunc = NULL; - +static void ReorderAttached(RepeatedCompositeContainer* self) { + Message* message = self->message; + const Reflection* reflection = message->GetReflection(); + const FieldDescriptor* descriptor = self->parent_field_descriptor; const Py_ssize_t length = Length(self); - if (InternalQuickSort(self, 0, length, compare, keyfunc) < 0) - return NULL; - - // Finally reverse the result if requested. - if (reverse) { - google::protobuf::Message* message = self->message; - const google::protobuf::Reflection* reflection = message->GetReflection(); - const google::protobuf::FieldDescriptor* descriptor = self->parent_field->descriptor; - // Reverse the Message array. - for (int i = 0; i < length / 2; ++i) - reflection->SwapElements(message, descriptor, i, length - i - 1); + // Since Python protobuf objects are never arena-allocated, adding and + // removing message pointers to the underlying array is just updating + // pointers. + for (Py_ssize_t i = 0; i < length; ++i) + reflection->ReleaseLast(message, descriptor); - // Reverse the Python list. 
- ScopedPyObjectPtr res(PyObject_CallMethod(self->child_messages, - "reverse", NULL)); - if (res == NULL) - return NULL; + for (Py_ssize_t i = 0; i < length; ++i) { + CMessage* py_cmsg = reinterpret_cast<CMessage*>( + PyList_GET_ITEM(self->child_messages, i)); + reflection->AddAllocatedMessage(message, descriptor, py_cmsg->message); } - - Py_RETURN_NONE; } -static PyObject* SortReleased(RepeatedCompositeContainer* self, - PyObject* args, - PyObject* kwds) { +// Returns 0 if successful; returns -1 and sets an exception if +// unsuccessful. +static int SortPythonMessages(RepeatedCompositeContainer* self, + PyObject* args, + PyObject* kwds) { ScopedPyObjectPtr m(PyObject_GetAttrString(self->child_messages, "sort")); if (m == NULL) - return NULL; - if (PyObject_Call(m, args, kwds) == NULL) - return NULL; - Py_RETURN_NONE; + return -1; + if (PyObject_Call(m.get(), args, kwds) == NULL) + return -1; + if (self->message != NULL) { + ReorderAttached(self); + } + return 0; } static PyObject* Sort(RepeatedCompositeContainer* self, @@ -554,13 +387,13 @@ static PyObject* Sort(RepeatedCompositeContainer* self, } } - if (UpdateChildMessages(self) < 0) + if (UpdateChildMessages(self) < 0) { return NULL; - if (self->message == NULL) { - return SortReleased(self, args, kwds); - } else { - return SortAttached(self, args, kwds); } + if (SortPythonMessages(self, args, kwds) < 0) { + return NULL; + } + Py_RETURN_NONE; } // --------------------------------------------------------------------- @@ -581,46 +414,43 @@ static PyObject* Item(RepeatedCompositeContainer* self, Py_ssize_t index) { return item; } -// The caller takes ownership of the returned Message. -Message* ReleaseLast(const FieldDescriptor* field, - const Descriptor* type, - Message* message) { - GOOGLE_CHECK_NOTNULL(field); - GOOGLE_CHECK_NOTNULL(type); - GOOGLE_CHECK_NOTNULL(message); - - Message* released_message = message->GetReflection()->ReleaseLast( - message, field); - // TODO(tibell): Deal with proto1. 
- - // ReleaseMessage will return NULL which differs from - // child_cmessage->message, if the field does not exist. In this case, - // the latter points to the default instance via a const_cast<>, so we - // have to reset it to a new mutable object since we are taking ownership. - if (released_message == NULL) { - const Message* prototype = global_message_factory->GetPrototype(type); - GOOGLE_CHECK_NOTNULL(prototype); - return prototype->New(); - } else { - return released_message; +static PyObject* Pop(RepeatedCompositeContainer* self, + PyObject* args) { + Py_ssize_t index = -1; + if (!PyArg_ParseTuple(args, "|n", &index)) { + return NULL; } + PyObject* item = Item(self, index); + if (item == NULL) { + PyErr_Format(PyExc_IndexError, + "list index (%zd) out of range", + index); + return NULL; + } + ScopedPyObjectPtr py_index(PyLong_FromSsize_t(index)); + if (AssignSubscript(self, py_index.get(), NULL) < 0) { + return NULL; + } + return item; } -// Release field of message and transfer the ownership to cmessage. -void ReleaseLastTo(const FieldDescriptor* field, - Message* message, - CMessage* cmessage) { +// Release field of parent message and transfer the ownership to target. +void ReleaseLastTo(CMessage* parent, + const FieldDescriptor* field, + CMessage* target) { + GOOGLE_CHECK_NOTNULL(parent); GOOGLE_CHECK_NOTNULL(field); - GOOGLE_CHECK_NOTNULL(message); - GOOGLE_CHECK_NOTNULL(cmessage); + GOOGLE_CHECK_NOTNULL(target); shared_ptr<Message> released_message( - ReleaseLast(field, cmessage->message->GetDescriptor(), message)); - cmessage->parent = NULL; - cmessage->parent_field = NULL; - cmessage->message = released_message.get(); - cmessage->read_only = false; - cmessage::SetOwner(cmessage, released_message); + parent->message->GetReflection()->ReleaseLast(parent->message, field)); + // TODO(tibell): Deal with proto1. 
+ + target->parent = NULL; + target->parent_field_descriptor = NULL; + target->message = released_message.get(); + target->read_only = false; + cmessage::SetOwner(target, released_message); } // Called to release a container using @@ -633,7 +463,7 @@ int Release(RepeatedCompositeContainer* self) { } Message* message = self->message; - const FieldDescriptor* field = self->parent_field->descriptor; + const FieldDescriptor* field = self->parent_field_descriptor; // The reflection API only lets us release the last message in a // repeated field. Therefore we iterate through the children @@ -643,12 +473,12 @@ int Release(RepeatedCompositeContainer* self) { for (Py_ssize_t i = size - 1; i >= 0; --i) { CMessage* child_cmessage = reinterpret_cast<CMessage*>( PyList_GET_ITEM(self->child_messages, i)); - ReleaseLastTo(field, message, child_cmessage); + ReleaseLastTo(self->parent, field, child_cmessage); } // Detach from containing message. self->parent = NULL; - self->parent_field = NULL; + self->parent_field_descriptor = NULL; self->message = NULL; self->owner.reset(); @@ -670,22 +500,40 @@ int SetOwner(RepeatedCompositeContainer* self, return 0; } -static int Init(RepeatedCompositeContainer* self, - PyObject* args, - PyObject* kwargs) { - self->message = NULL; - self->parent = NULL; - self->parent_field = NULL; - self->subclass_init = NULL; +// The private constructor of RepeatedCompositeContainer objects. 
+PyObject *NewContainer( + CMessage* parent, + const FieldDescriptor* parent_field_descriptor, + CMessageClass* concrete_class) { + if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) { + return NULL; + } + + RepeatedCompositeContainer* self = + reinterpret_cast<RepeatedCompositeContainer*>( + PyType_GenericAlloc(&RepeatedCompositeContainer_Type, 0)); + if (self == NULL) { + return NULL; + } + + self->message = parent->message; + self->parent = parent; + self->parent_field_descriptor = parent_field_descriptor; + self->owner = parent->owner; + Py_INCREF(concrete_class); + self->child_message_class = concrete_class; self->child_messages = PyList_New(0); - return 0; + + return reinterpret_cast<PyObject*>(self); } static void Dealloc(RepeatedCompositeContainer* self) { Py_CLEAR(self->child_messages); + Py_CLEAR(self->child_message_class); // TODO(tibell): Do we need to call delete on these objects to make // sure their destructors are called? self->owner.reset(); + Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self)); } @@ -707,6 +555,8 @@ static PyMethodDef Methods[] = { "Adds an object to the repeated container." }, { "extend", (PyCFunction) Extend, METH_O, "Adds objects to the repeated container." }, + { "pop", (PyCFunction)Pop, METH_VARARGS, + "Removes an object from the repeated container and returns it." }, { "remove", (PyCFunction) Remove, METH_O, "Removes an object from the repeated container." }, { "sort", (PyCFunction) Sort, METH_VARARGS | METH_KEYWORDS, @@ -720,9 +570,8 @@ static PyMethodDef Methods[] = { PyTypeObject RepeatedCompositeContainer_Type = { PyVarObject_HEAD_INIT(&PyType_Type, 0) - "google.protobuf.internal." 
- "cpp._message.RepeatedCompositeContainer", // tp_name - sizeof(RepeatedCompositeContainer), // tp_basicsize + FULL_MODULE_NAME ".RepeatedCompositeContainer", // tp_name + sizeof(RepeatedCompositeContainer), // tp_basicsize 0, // tp_itemsize (destructor)repeated_composite_container::Dealloc, // tp_dealloc 0, // tp_print @@ -733,7 +582,7 @@ PyTypeObject RepeatedCompositeContainer_Type = { 0, // tp_as_number &repeated_composite_container::SqMethods, // tp_as_sequence &repeated_composite_container::MpMethods, // tp_as_mapping - 0, // tp_hash + PyObject_HashNotImplemented, // tp_hash 0, // tp_call 0, // tp_str 0, // tp_getattro @@ -755,7 +604,7 @@ PyTypeObject RepeatedCompositeContainer_Type = { 0, // tp_descr_get 0, // tp_descr_set 0, // tp_dictoffset - (initproc)repeated_composite_container::Init, // tp_init + 0, // tp_init }; } // namespace python diff --git a/python/google/protobuf/pyext/repeated_composite_container.h b/python/google/protobuf/pyext/repeated_composite_container.h index 898ef5a71..a7b56b61b 100644 --- a/python/google/protobuf/pyext/repeated_composite_container.h +++ b/python/google/protobuf/pyext/repeated_composite_container.h @@ -43,30 +43,33 @@ #include <string> #include <vector> - namespace google { namespace protobuf { class FieldDescriptor; class Message; +#ifdef _SHARED_PTR_H +using std::shared_ptr; +#else using internal::shared_ptr; +#endif namespace python { struct CMessage; -struct CFieldDescriptor; +struct CMessageClass; // A RepeatedCompositeContainer can be in one of two states: attached // or released. // // When in the attached state all modifications to the container are // done both on the 'message' and on the 'child_messages' -// list. In this state all Messages refered to by the children in +// list. In this state all Messages referred to by the children in // 'child_messages' are owner by the 'owner'. // // When in the released state 'message', 'owner', 'parent', and -// 'parent_field' are NULL. 
+// 'parent_field_descriptor' are NULL. typedef struct RepeatedCompositeContainer { PyObject_HEAD; @@ -82,7 +85,8 @@ typedef struct RepeatedCompositeContainer { CMessage* parent; // A descriptor used to modify the underlying 'message'. - CFieldDescriptor* parent_field; + // The pointer is owned by the global DescriptorPool. + const FieldDescriptor* parent_field_descriptor; // Pointer to the C++ Message that contains this container. The // RepeatedCompositeContainer does not own this pointer. @@ -91,8 +95,8 @@ typedef struct RepeatedCompositeContainer { // calling Clear() or ClearField() on the parent. Message* message; - // A callable that is used to create new child messages. - PyObject* subclass_init; + // The type used to create new child messages. + CMessageClass* child_message_class; // A list of child messages. PyObject* child_messages; @@ -102,8 +106,12 @@ extern PyTypeObject RepeatedCompositeContainer_Type; namespace repeated_composite_container { -// Returns the number of items in this repeated composite container. -static Py_ssize_t Length(RepeatedCompositeContainer* self); +// Builds a RepeatedCompositeContainer object, from a parent message and a +// field descriptor. +PyObject *NewContainer( + CMessage* parent, + const FieldDescriptor* parent_field_descriptor, + CMessageClass *child_message_class); // Appends a new CMessage to the container and returns it. The // CMessage is initialized using the content of kwargs. @@ -143,8 +151,7 @@ int AssignSubscript(RepeatedCompositeContainer* self, // Releases the messages in the container to the given message. // // Returns 0 on success, -1 on failure. -int ReleaseToMessage(RepeatedCompositeContainer* self, - google::protobuf::Message* new_message); +int ReleaseToMessage(RepeatedCompositeContainer* self, Message* new_message); // Releases the messages in the container to a new message. 
// @@ -156,13 +163,13 @@ int SetOwner(RepeatedCompositeContainer* self, const shared_ptr<Message>& new_owner); // Removes the last element of the repeated message field 'field' on -// the Message 'message', and transfers the ownership of the released -// Message to 'cmessage'. +// the Message 'parent', and transfers the ownership of the released +// Message to 'target'. // // Corresponds to reflection api method ReleaseMessage. -void ReleaseLastTo(const FieldDescriptor* field, - Message* message, - CMessage* cmessage); +void ReleaseLastTo(CMessage* parent, + const FieldDescriptor* field, + CMessage* target); } // namespace repeated_composite_container } // namespace python diff --git a/python/google/protobuf/pyext/repeated_scalar_container.cc b/python/google/protobuf/pyext/repeated_scalar_container.cc index e627d37dc..95da85f87 100644 --- a/python/google/protobuf/pyext/repeated_scalar_container.cc +++ b/python/google/protobuf/pyext/repeated_scalar_container.cc @@ -39,10 +39,12 @@ #endif #include <google/protobuf/stubs/common.h> +#include <google/protobuf/stubs/logging.h> #include <google/protobuf/descriptor.h> #include <google/protobuf/dynamic_message.h> #include <google/protobuf/message.h> #include <google/protobuf/pyext/descriptor.h> +#include <google/protobuf/pyext/descriptor_pool.h> #include <google/protobuf/pyext/message.h> #include <google/protobuf/pyext/scoped_pyobject_ptr.h> @@ -52,7 +54,7 @@ #error "Python 3.0 - 3.2 are not supported." #else #define PyString_AsString(ob) \ - (PyUnicode_Check(ob)? PyUnicode_AsUTF8(ob): PyBytes_AS_STRING(ob)) + (PyUnicode_Check(ob)? 
PyUnicode_AsUTF8(ob): PyBytes_AsString(ob)) #endif #endif @@ -60,17 +62,15 @@ namespace google { namespace protobuf { namespace python { -extern google::protobuf::DynamicMessageFactory* global_message_factory; - namespace repeated_scalar_container { static int InternalAssignRepeatedField( RepeatedScalarContainer* self, PyObject* list) { self->message->GetReflection()->ClearField(self->message, - self->parent_field->descriptor); + self->parent_field_descriptor); for (Py_ssize_t i = 0; i < PyList_GET_SIZE(list); ++i) { PyObject* value = PyList_GET_ITEM(list, i); - if (Append(self, value) == NULL) { + if (ScopedPyObjectPtr(Append(self, value)) == NULL) { return -1; } } @@ -78,25 +78,19 @@ static int InternalAssignRepeatedField( } static Py_ssize_t Len(RepeatedScalarContainer* self) { - google::protobuf::Message* message = self->message; + Message* message = self->message; return message->GetReflection()->FieldSize(*message, - self->parent_field->descriptor); + self->parent_field_descriptor); } static int AssignItem(RepeatedScalarContainer* self, Py_ssize_t index, PyObject* arg) { cmessage::AssureWritable(self->parent); - google::protobuf::Message* message = self->message; - const google::protobuf::FieldDescriptor* field_descriptor = - self->parent_field->descriptor; - if (!FIELD_BELONGS_TO_MESSAGE(field_descriptor, message)) { - PyErr_SetString( - PyExc_KeyError, "Field does not belong to message!"); - return -1; - } + Message* message = self->message; + const FieldDescriptor* field_descriptor = self->parent_field_descriptor; - const google::protobuf::Reflection* reflection = message->GetReflection(); + const Reflection* reflection = message->GetReflection(); int field_size = reflection->FieldSize(*message, field_descriptor); if (index < 0) { index = field_size + index; @@ -110,8 +104,8 @@ static int AssignItem(RepeatedScalarContainer* self, if (arg == NULL) { ScopedPyObjectPtr py_index(PyLong_FromLong(index)); - return cmessage::InternalDeleteRepeatedField(message, 
field_descriptor, - py_index, NULL); + return cmessage::InternalDeleteRepeatedField(self->parent, field_descriptor, + py_index.get(), NULL); } if (PySequence_Check(arg) && !(PyBytes_Check(arg) || PyUnicode_Check(arg))) { @@ -120,64 +114,68 @@ static int AssignItem(RepeatedScalarContainer* self, } switch (field_descriptor->cpp_type()) { - case google::protobuf::FieldDescriptor::CPPTYPE_INT32: { + case FieldDescriptor::CPPTYPE_INT32: { GOOGLE_CHECK_GET_INT32(arg, value, -1); reflection->SetRepeatedInt32(message, field_descriptor, index, value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_INT64: { + case FieldDescriptor::CPPTYPE_INT64: { GOOGLE_CHECK_GET_INT64(arg, value, -1); reflection->SetRepeatedInt64(message, field_descriptor, index, value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: { + case FieldDescriptor::CPPTYPE_UINT32: { GOOGLE_CHECK_GET_UINT32(arg, value, -1); reflection->SetRepeatedUInt32(message, field_descriptor, index, value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: { + case FieldDescriptor::CPPTYPE_UINT64: { GOOGLE_CHECK_GET_UINT64(arg, value, -1); reflection->SetRepeatedUInt64(message, field_descriptor, index, value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_FLOAT: { + case FieldDescriptor::CPPTYPE_FLOAT: { GOOGLE_CHECK_GET_FLOAT(arg, value, -1); reflection->SetRepeatedFloat(message, field_descriptor, index, value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_DOUBLE: { + case FieldDescriptor::CPPTYPE_DOUBLE: { GOOGLE_CHECK_GET_DOUBLE(arg, value, -1); reflection->SetRepeatedDouble(message, field_descriptor, index, value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: { + case FieldDescriptor::CPPTYPE_BOOL: { GOOGLE_CHECK_GET_BOOL(arg, value, -1); reflection->SetRepeatedBool(message, field_descriptor, index, value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_STRING: { + case FieldDescriptor::CPPTYPE_STRING: { if 
(!CheckAndSetString( arg, message, field_descriptor, reflection, false, index)) { return -1; } break; } - case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: { + case FieldDescriptor::CPPTYPE_ENUM: { GOOGLE_CHECK_GET_INT32(arg, value, -1); - const google::protobuf::EnumDescriptor* enum_descriptor = - field_descriptor->enum_type(); - const google::protobuf::EnumValueDescriptor* enum_value = - enum_descriptor->FindValueByNumber(value); - if (enum_value != NULL) { - reflection->SetRepeatedEnum(message, field_descriptor, index, - enum_value); + if (reflection->SupportsUnknownEnumValues()) { + reflection->SetRepeatedEnumValue(message, field_descriptor, index, + value); } else { - ScopedPyObjectPtr s(PyObject_Str(arg)); - if (s != NULL) { - PyErr_Format(PyExc_ValueError, "Unknown enum value: %s", - PyString_AsString(s.get())); + const EnumDescriptor* enum_descriptor = field_descriptor->enum_type(); + const EnumValueDescriptor* enum_value = + enum_descriptor->FindValueByNumber(value); + if (enum_value != NULL) { + reflection->SetRepeatedEnum(message, field_descriptor, index, + enum_value); + } else { + ScopedPyObjectPtr s(PyObject_Str(arg)); + if (s != NULL) { + PyErr_Format(PyExc_ValueError, "Unknown enum value: %s", + PyString_AsString(s.get())); + } + return -1; } - return -1; } break; } @@ -191,10 +189,9 @@ static int AssignItem(RepeatedScalarContainer* self, } static PyObject* Item(RepeatedScalarContainer* self, Py_ssize_t index) { - google::protobuf::Message* message = self->message; - const google::protobuf::FieldDescriptor* field_descriptor = - self->parent_field->descriptor; - const google::protobuf::Reflection* reflection = message->GetReflection(); + Message* message = self->message; + const FieldDescriptor* field_descriptor = self->parent_field_descriptor; + const Reflection* reflection = message->GetReflection(); int field_size = reflection->FieldSize(*message, field_descriptor); if (index < 0) { @@ -202,80 +199,80 @@ static PyObject* 
Item(RepeatedScalarContainer* self, Py_ssize_t index) { } if (index < 0 || index >= field_size) { PyErr_Format(PyExc_IndexError, - "list assignment index (%d) out of range", - static_cast<int>(index)); + "list index (%zd) out of range", + index); return NULL; } PyObject* result = NULL; switch (field_descriptor->cpp_type()) { - case google::protobuf::FieldDescriptor::CPPTYPE_INT32: { + case FieldDescriptor::CPPTYPE_INT32: { int32 value = reflection->GetRepeatedInt32( *message, field_descriptor, index); result = PyInt_FromLong(value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_INT64: { + case FieldDescriptor::CPPTYPE_INT64: { int64 value = reflection->GetRepeatedInt64( *message, field_descriptor, index); result = PyLong_FromLongLong(value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: { + case FieldDescriptor::CPPTYPE_UINT32: { uint32 value = reflection->GetRepeatedUInt32( *message, field_descriptor, index); result = PyLong_FromLongLong(value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: { + case FieldDescriptor::CPPTYPE_UINT64: { uint64 value = reflection->GetRepeatedUInt64( *message, field_descriptor, index); result = PyLong_FromUnsignedLongLong(value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_FLOAT: { + case FieldDescriptor::CPPTYPE_FLOAT: { float value = reflection->GetRepeatedFloat( *message, field_descriptor, index); result = PyFloat_FromDouble(value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_DOUBLE: { + case FieldDescriptor::CPPTYPE_DOUBLE: { double value = reflection->GetRepeatedDouble( *message, field_descriptor, index); result = PyFloat_FromDouble(value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: { + case FieldDescriptor::CPPTYPE_BOOL: { bool value = reflection->GetRepeatedBool( *message, field_descriptor, index); result = PyBool_FromLong(value ? 
1 : 0); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: { - const google::protobuf::EnumValueDescriptor* enum_value = + case FieldDescriptor::CPPTYPE_ENUM: { + const EnumValueDescriptor* enum_value = message->GetReflection()->GetRepeatedEnum( *message, field_descriptor, index); result = PyInt_FromLong(enum_value->number()); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_STRING: { + case FieldDescriptor::CPPTYPE_STRING: { string value = reflection->GetRepeatedString( *message, field_descriptor, index); result = ToStringObject(field_descriptor, value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: { + case FieldDescriptor::CPPTYPE_MESSAGE: { PyObject* py_cmsg = PyObject_CallObject(reinterpret_cast<PyObject*>( &CMessage_Type), NULL); if (py_cmsg == NULL) { return NULL; } CMessage* cmsg = reinterpret_cast<CMessage*>(py_cmsg); - const google::protobuf::Message& msg = reflection->GetRepeatedMessage( + const Message& msg = reflection->GetRepeatedMessage( *message, field_descriptor, index); cmsg->owner = self->owner; cmsg->parent = self->parent; - cmsg->message = const_cast<google::protobuf::Message*>(&msg); + cmsg->message = const_cast<Message*>(&msg); cmsg->read_only = false; result = reinterpret_cast<PyObject*>(py_cmsg); break; @@ -337,7 +334,7 @@ static PyObject* Subscript(RepeatedScalarContainer* self, PyObject* slice) { break; } ScopedPyObjectPtr s(Item(self, index)); - PyList_Append(list, s); + PyList_Append(list, s.get()); } } else { if (step > 0) { @@ -348,7 +345,7 @@ static PyObject* Subscript(RepeatedScalarContainer* self, PyObject* slice) { break; } ScopedPyObjectPtr s(Item(self, index)); - PyList_Append(list, s); + PyList_Append(list, s.get()); } } return list; @@ -356,75 +353,71 @@ static PyObject* Subscript(RepeatedScalarContainer* self, PyObject* slice) { PyObject* Append(RepeatedScalarContainer* self, PyObject* item) { cmessage::AssureWritable(self->parent); - google::protobuf::Message* message = 
self->message; - const google::protobuf::FieldDescriptor* field_descriptor = - self->parent_field->descriptor; - - if (!FIELD_BELONGS_TO_MESSAGE(field_descriptor, message)) { - PyErr_SetString( - PyExc_KeyError, "Field does not belong to message!"); - return NULL; - } + Message* message = self->message; + const FieldDescriptor* field_descriptor = self->parent_field_descriptor; - const google::protobuf::Reflection* reflection = message->GetReflection(); + const Reflection* reflection = message->GetReflection(); switch (field_descriptor->cpp_type()) { - case google::protobuf::FieldDescriptor::CPPTYPE_INT32: { + case FieldDescriptor::CPPTYPE_INT32: { GOOGLE_CHECK_GET_INT32(item, value, NULL); reflection->AddInt32(message, field_descriptor, value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_INT64: { + case FieldDescriptor::CPPTYPE_INT64: { GOOGLE_CHECK_GET_INT64(item, value, NULL); reflection->AddInt64(message, field_descriptor, value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: { + case FieldDescriptor::CPPTYPE_UINT32: { GOOGLE_CHECK_GET_UINT32(item, value, NULL); reflection->AddUInt32(message, field_descriptor, value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: { + case FieldDescriptor::CPPTYPE_UINT64: { GOOGLE_CHECK_GET_UINT64(item, value, NULL); reflection->AddUInt64(message, field_descriptor, value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_FLOAT: { + case FieldDescriptor::CPPTYPE_FLOAT: { GOOGLE_CHECK_GET_FLOAT(item, value, NULL); reflection->AddFloat(message, field_descriptor, value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_DOUBLE: { + case FieldDescriptor::CPPTYPE_DOUBLE: { GOOGLE_CHECK_GET_DOUBLE(item, value, NULL); reflection->AddDouble(message, field_descriptor, value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: { + case FieldDescriptor::CPPTYPE_BOOL: { GOOGLE_CHECK_GET_BOOL(item, value, NULL); reflection->AddBool(message, 
field_descriptor, value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_STRING: { + case FieldDescriptor::CPPTYPE_STRING: { if (!CheckAndSetString( item, message, field_descriptor, reflection, true, -1)) { return NULL; } break; } - case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: { + case FieldDescriptor::CPPTYPE_ENUM: { GOOGLE_CHECK_GET_INT32(item, value, NULL); - const google::protobuf::EnumDescriptor* enum_descriptor = - field_descriptor->enum_type(); - const google::protobuf::EnumValueDescriptor* enum_value = - enum_descriptor->FindValueByNumber(value); - if (enum_value != NULL) { - reflection->AddEnum(message, field_descriptor, enum_value); + if (reflection->SupportsUnknownEnumValues()) { + reflection->AddEnumValue(message, field_descriptor, value); } else { - ScopedPyObjectPtr s(PyObject_Str(item)); - if (s != NULL) { - PyErr_Format(PyExc_ValueError, "Unknown enum value: %s", - PyString_AsString(s.get())); + const EnumDescriptor* enum_descriptor = field_descriptor->enum_type(); + const EnumValueDescriptor* enum_value = + enum_descriptor->FindValueByNumber(value); + if (enum_value != NULL) { + reflection->AddEnum(message, field_descriptor, enum_value); + } else { + ScopedPyObjectPtr s(PyObject_Str(item)); + if (s != NULL) { + PyErr_Format(PyExc_ValueError, "Unknown enum value: %s", + PyString_AsString(s.get())); + } + return NULL; } - return NULL; } break; } @@ -449,9 +442,9 @@ static int AssSubscript(RepeatedScalarContainer* self, bool create_list = false; cmessage::AssureWritable(self->parent); - google::protobuf::Message* message = self->message; - const google::protobuf::FieldDescriptor* field_descriptor = - self->parent_field->descriptor; + Message* message = self->message; + const FieldDescriptor* field_descriptor = + self->parent_field_descriptor; #if PY_MAJOR_VERSION < 3 if (PyInt_Check(slice)) { @@ -461,7 +454,7 @@ static int AssSubscript(RepeatedScalarContainer* self, if (PyLong_Check(slice)) { from = to = PyLong_AsLong(slice); } 
else if (PySlice_Check(slice)) { - const google::protobuf::Reflection* reflection = message->GetReflection(); + const Reflection* reflection = message->GetReflection(); length = reflection->FieldSize(*message, field_descriptor); #if PY_MAJOR_VERSION >= 3 if (PySlice_GetIndicesEx(slice, @@ -479,7 +472,7 @@ static int AssSubscript(RepeatedScalarContainer* self, if (value == NULL) { return cmessage::InternalDeleteRepeatedField( - message, field_descriptor, slice, NULL); + self->parent, field_descriptor, slice, NULL); } if (!create_list) { @@ -490,30 +483,36 @@ static int AssSubscript(RepeatedScalarContainer* self, if (full_slice == NULL) { return -1; } - ScopedPyObjectPtr new_list(Subscript(self, full_slice)); + ScopedPyObjectPtr new_list(Subscript(self, full_slice.get())); if (new_list == NULL) { return -1; } - if (PySequence_SetSlice(new_list, from, to, value) < 0) { + if (PySequence_SetSlice(new_list.get(), from, to, value) < 0) { return -1; } - return InternalAssignRepeatedField(self, new_list); + return InternalAssignRepeatedField(self, new_list.get()); } PyObject* Extend(RepeatedScalarContainer* self, PyObject* value) { cmessage::AssureWritable(self->parent); - if (PyObject_Not(value)) { + + // TODO(ptucker): Deprecate this behavior. 
b/18413862 + if (value == Py_None) { + Py_RETURN_NONE; + } + if ((Py_TYPE(value)->tp_as_sequence == NULL) && PyObject_Not(value)) { Py_RETURN_NONE; } + ScopedPyObjectPtr iter(PyObject_GetIter(value)); if (iter == NULL) { PyErr_SetString(PyExc_TypeError, "Value must be iterable"); return NULL; } ScopedPyObjectPtr next; - while ((next.reset(PyIter_Next(iter))) != NULL) { - if (Append(self, next) == NULL) { + while ((next.reset(PyIter_Next(iter.get()))) != NULL) { + if (ScopedPyObjectPtr(Append(self, next.get())) == NULL) { return NULL; } } @@ -530,11 +529,11 @@ static PyObject* Insert(RepeatedScalarContainer* self, PyObject* args) { return NULL; } ScopedPyObjectPtr full_slice(PySlice_New(NULL, NULL, NULL)); - ScopedPyObjectPtr new_list(Subscript(self, full_slice)); - if (PyList_Insert(new_list, index, value) < 0) { + ScopedPyObjectPtr new_list(Subscript(self, full_slice.get())); + if (PyList_Insert(new_list.get(), index, value) < 0) { return NULL; } - int ret = InternalAssignRepeatedField(self, new_list); + int ret = InternalAssignRepeatedField(self, new_list.get()); if (ret < 0) { return NULL; } @@ -545,7 +544,7 @@ static PyObject* Remove(RepeatedScalarContainer* self, PyObject* value) { Py_ssize_t match_index = -1; for (Py_ssize_t i = 0; i < Len(self); ++i) { ScopedPyObjectPtr elem(Item(self, i)); - if (PyObject_RichCompareBool(elem, value, Py_EQ)) { + if (PyObject_RichCompareBool(elem.get(), value, Py_EQ)) { match_index = i; break; } @@ -580,15 +579,15 @@ static PyObject* RichCompare(RepeatedScalarContainer* self, ScopedPyObjectPtr other_list_deleter; if (PyObject_TypeCheck(other, &RepeatedScalarContainer_Type)) { other_list_deleter.reset(Subscript( - reinterpret_cast<RepeatedScalarContainer*>(other), full_slice)); + reinterpret_cast<RepeatedScalarContainer*>(other), full_slice.get())); other = other_list_deleter.get(); } - ScopedPyObjectPtr list(Subscript(self, full_slice)); + ScopedPyObjectPtr list(Subscript(self, full_slice.get())); if (list == NULL) { return 
NULL; } - return PyObject_RichCompare(list, other, opid); + return PyObject_RichCompare(list.get(), other, opid); } PyObject* Reduce(RepeatedScalarContainer* unused_self) { @@ -619,66 +618,63 @@ static PyObject* Sort(RepeatedScalarContainer* self, if (full_slice == NULL) { return NULL; } - ScopedPyObjectPtr list(Subscript(self, full_slice)); + ScopedPyObjectPtr list(Subscript(self, full_slice.get())); if (list == NULL) { return NULL; } - ScopedPyObjectPtr m(PyObject_GetAttrString(list, "sort")); + ScopedPyObjectPtr m(PyObject_GetAttrString(list.get(), "sort")); if (m == NULL) { return NULL; } - ScopedPyObjectPtr res(PyObject_Call(m, args, kwds)); + ScopedPyObjectPtr res(PyObject_Call(m.get(), args, kwds)); if (res == NULL) { return NULL; } - int ret = InternalAssignRepeatedField(self, list); + int ret = InternalAssignRepeatedField(self, list.get()); if (ret < 0) { return NULL; } Py_RETURN_NONE; } -static int Init(RepeatedScalarContainer* self, - PyObject* args, - PyObject* kwargs) { - PyObject* py_parent; - PyObject* py_parent_field; - if (!PyArg_UnpackTuple(args, "__init__()", 2, 2, &py_parent, - &py_parent_field)) { - return -1; +static PyObject* Pop(RepeatedScalarContainer* self, + PyObject* args) { + Py_ssize_t index = -1; + if (!PyArg_ParseTuple(args, "|n", &index)) { + return NULL; } - - if (!PyObject_TypeCheck(py_parent, &CMessage_Type)) { - PyErr_Format(PyExc_TypeError, - "expect %s, but got %s", - CMessage_Type.tp_name, - Py_TYPE(py_parent)->tp_name); - return -1; + PyObject* item = Item(self, index); + if (item == NULL) { + PyErr_Format(PyExc_IndexError, + "list index (%zd) out of range", + index); + return NULL; } - - if (!PyObject_TypeCheck(py_parent_field, &CFieldDescriptor_Type)) { - PyErr_Format(PyExc_TypeError, - "expect %s, but got %s", - CFieldDescriptor_Type.tp_name, - Py_TYPE(py_parent_field)->tp_name); - return -1; + if (AssignItem(self, index, NULL) < 0) { + return NULL; } + return item; +} - CMessage* cmessage = 
reinterpret_cast<CMessage*>(py_parent); - CFieldDescriptor* cdescriptor = reinterpret_cast<CFieldDescriptor*>( - py_parent_field); +// The private constructor of RepeatedScalarContainer objects. +PyObject *NewContainer( + CMessage* parent, const FieldDescriptor* parent_field_descriptor) { + if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) { + return NULL; + } - if (!FIELD_BELONGS_TO_MESSAGE(cdescriptor->descriptor, cmessage->message)) { - PyErr_SetString( - PyExc_KeyError, "Field does not belong to message!"); - return -1; + RepeatedScalarContainer* self = reinterpret_cast<RepeatedScalarContainer*>( + PyType_GenericAlloc(&RepeatedScalarContainer_Type, 0)); + if (self == NULL) { + return NULL; } - self->message = cmessage->message; - self->parent = cmessage; - self->parent_field = cdescriptor; - self->owner = cmessage->owner; - return 0; + self->message = parent->message; + self->parent = parent; + self->parent_field_descriptor = parent_field_descriptor; + self->owner = parent->owner; + + return reinterpret_cast<PyObject*>(self); } // Initializes the underlying Message object of "to" so it becomes a new parent @@ -692,20 +688,16 @@ static int InitializeAndCopyToParentContainer( if (full_slice == NULL) { return -1; } - ScopedPyObjectPtr values(Subscript(from, full_slice)); + ScopedPyObjectPtr values(Subscript(from, full_slice.get())); if (values == NULL) { return -1; } - google::protobuf::Message* new_message = global_message_factory->GetPrototype( - from->message->GetDescriptor())->New(); + Message* new_message = from->message->New(); to->parent = NULL; - // TODO(anuraag): Document why it's OK to hang on to parent_field, - // even though it's a weak reference. It ought to be enough to - // hold on to the FieldDescriptor only. 
- to->parent_field = from->parent_field; + to->parent_field_descriptor = from->parent_field_descriptor; to->message = new_message; to->owner.reset(new_message); - if (InternalAssignRepeatedField(to, values) < 0) { + if (InternalAssignRepeatedField(to, values.get()) < 0) { return -1; } return 0; @@ -716,23 +708,17 @@ int Release(RepeatedScalarContainer* self) { } PyObject* DeepCopy(RepeatedScalarContainer* self, PyObject* arg) { - ScopedPyObjectPtr init_args( - PyTuple_Pack(2, self->parent, self->parent_field)); - PyObject* clone = PyObject_CallObject( - reinterpret_cast<PyObject*>(&RepeatedScalarContainer_Type), init_args); + RepeatedScalarContainer* clone = reinterpret_cast<RepeatedScalarContainer*>( + PyType_GenericAlloc(&RepeatedScalarContainer_Type, 0)); if (clone == NULL) { return NULL; } - if (!PyObject_TypeCheck(clone, &RepeatedScalarContainer_Type)) { - Py_DECREF(clone); - return NULL; - } - if (InitializeAndCopyToParentContainer( - self, reinterpret_cast<RepeatedScalarContainer*>(clone)) < 0) { + + if (InitializeAndCopyToParentContainer(self, clone) < 0) { Py_DECREF(clone); return NULL; } - return clone; + return reinterpret_cast<PyObject*>(clone); } static void Dealloc(RepeatedScalarContainer* self) { @@ -771,6 +757,8 @@ static PyMethodDef Methods[] = { "Appends objects to the repeated container." }, { "insert", (PyCFunction)Insert, METH_VARARGS, "Appends objects to the repeated container." }, + { "pop", (PyCFunction)Pop, METH_VARARGS, + "Removes an object from the repeated container and returns it." }, { "remove", (PyCFunction)Remove, METH_O, "Removes an object from the repeated container." }, { "sort", (PyCFunction)Sort, METH_VARARGS | METH_KEYWORDS, @@ -782,8 +770,7 @@ static PyMethodDef Methods[] = { PyTypeObject RepeatedScalarContainer_Type = { PyVarObject_HEAD_INIT(&PyType_Type, 0) - "google.protobuf.internal." 
- "cpp._message.RepeatedScalarContainer", // tp_name + FULL_MODULE_NAME ".RepeatedScalarContainer", // tp_name sizeof(RepeatedScalarContainer), // tp_basicsize 0, // tp_itemsize (destructor)repeated_scalar_container::Dealloc, // tp_dealloc @@ -795,7 +782,7 @@ PyTypeObject RepeatedScalarContainer_Type = { 0, // tp_as_number &repeated_scalar_container::SqMethods, // tp_as_sequence &repeated_scalar_container::MpMethods, // tp_as_mapping - 0, // tp_hash + PyObject_HashNotImplemented, // tp_hash 0, // tp_call 0, // tp_str 0, // tp_getattro @@ -817,7 +804,7 @@ PyTypeObject RepeatedScalarContainer_Type = { 0, // tp_descr_get 0, // tp_descr_set 0, // tp_dictoffset - (initproc)repeated_scalar_container::Init, // tp_init + 0, // tp_init }; } // namespace python diff --git a/python/google/protobuf/pyext/repeated_scalar_container.h b/python/google/protobuf/pyext/repeated_scalar_container.h index 69d15d5cd..555e621c9 100644 --- a/python/google/protobuf/pyext/repeated_scalar_container.h +++ b/python/google/protobuf/pyext/repeated_scalar_container.h @@ -41,17 +41,21 @@ #include <google/protobuf/stubs/shared_ptr.h> #endif +#include <google/protobuf/descriptor.h> namespace google { namespace protobuf { class Message; +#ifdef _SHARED_PTR_H +using std::shared_ptr; +#else using internal::shared_ptr; +#endif namespace python { -struct CFieldDescriptor; struct CMessage; typedef struct RepeatedScalarContainer { @@ -73,16 +77,22 @@ typedef struct RepeatedScalarContainer { // modifying the container. CMessage* parent; - // Weak reference to the parent's descriptor that describes this + // Pointer to the parent's descriptor that describes this // field. Used together with the parent's message when making a // default message instance mutable. - CFieldDescriptor* parent_field; + // The pointer is owned by the global DescriptorPool. 
+ const FieldDescriptor* parent_field_descriptor; } RepeatedScalarContainer; extern PyTypeObject RepeatedScalarContainer_Type; namespace repeated_scalar_container { +// Builds a RepeatedScalarContainer object, from a parent message and a +// field descriptor. +extern PyObject *NewContainer( + CMessage* parent, const FieldDescriptor* parent_field_descriptor); + // Appends the scalar 'item' to the end of the container 'self'. // // Returns None if successful; returns NULL and sets an exception if diff --git a/python/google/protobuf/pyext/scoped_pyobject_ptr.h b/python/google/protobuf/pyext/scoped_pyobject_ptr.h index 9f337c3c8..a128cd4c6 100644 --- a/python/google/protobuf/pyext/scoped_pyobject_ptr.h +++ b/python/google/protobuf/pyext/scoped_pyobject_ptr.h @@ -33,12 +33,14 @@ #ifndef GOOGLE_PROTOBUF_PYTHON_CPP_SCOPED_PYOBJECT_PTR_H__ #define GOOGLE_PROTOBUF_PYTHON_CPP_SCOPED_PYOBJECT_PTR_H__ +#include <google/protobuf/stubs/common.h> + #include <Python.h> namespace google { class ScopedPyObjectPtr { public: - // Constructor. Defaults to intializing with NULL. + // Constructor. Defaults to initializing with NULL. // There is no way to create an uninitialized ScopedPyObjectPtr. explicit ScopedPyObjectPtr(PyObject* p = NULL) : ptr_(p) { } @@ -49,24 +51,23 @@ class ScopedPyObjectPtr { // Reset. Deletes the current owned object, if any. // Then takes ownership of a new object, if given. - // this->reset(this->get()) works. + // This function must be called with a reference that you own. + // this->reset(this->get()) is wrong! + // this->reset(this->release()) is OK. PyObject* reset(PyObject* p = NULL) { - if (p != ptr_) { - Py_XDECREF(ptr_); - ptr_ = p; - } + Py_XDECREF(ptr_); + ptr_ = p; return ptr_; } // Releases ownership of the object. + // The caller now owns the returned reference. 
PyObject* release() { PyObject* p = ptr_; ptr_ = NULL; return p; } - operator PyObject*() { return ptr_; } - PyObject* operator->() const { assert(ptr_ != NULL); return ptr_; diff --git a/python/google/protobuf/reflection.py b/python/google/protobuf/reflection.py index 1fc704a23..0c757264f 100755 --- a/python/google/protobuf/reflection.py +++ b/python/google/protobuf/reflection.py @@ -49,108 +49,23 @@ __author__ = 'robinson@google.com (Will Robinson)' from google.protobuf.internal import api_implementation -from google.protobuf import descriptor as descriptor_mod from google.protobuf import message -_FieldDescriptor = descriptor_mod.FieldDescriptor - if api_implementation.Type() == 'cpp': - if api_implementation.Version() == 2: - from google.protobuf.pyext import cpp_message - _NewMessage = cpp_message.NewMessage - _InitMessage = cpp_message.InitMessage - else: - from google.protobuf.internal import cpp_message - _NewMessage = cpp_message.NewMessage - _InitMessage = cpp_message.InitMessage + from google.protobuf.pyext import cpp_message as message_impl else: - from google.protobuf.internal import python_message - _NewMessage = python_message.NewMessage - _InitMessage = python_message.InitMessage - - -class GeneratedProtocolMessageType(type): - - """Metaclass for protocol message classes created at runtime from Descriptors. - - We add implementations for all methods described in the Message class. We - also create properties to allow getting/setting all fields in the protocol - message. Finally, we create slots to prevent users from accidentally - "setting" nonexistent fields in the protocol message, which then wouldn't get - serialized / deserialized properly. + from google.protobuf.internal import python_message as message_impl - The protocol compiler currently uses this metaclass to create protocol - message classes at runtime. Clients can also manually create their own - classes at runtime, as in this example: - - mydescriptor = Descriptor(.....) 
- class MyProtoClass(Message): - __metaclass__ = GeneratedProtocolMessageType - DESCRIPTOR = mydescriptor - myproto_instance = MyProtoClass() - myproto.foo_field = 23 - ... - - The above example will not work for nested types. If you wish to include them, - use reflection.MakeClass() instead of manually instantiating the class in - order to create the appropriate class structure. - """ - - # Must be consistent with the protocol-compiler code in - # proto2/compiler/internal/generator.*. - _DESCRIPTOR_KEY = 'DESCRIPTOR' - - def __new__(cls, name, bases, dictionary): - """Custom allocation for runtime-generated class types. - - We override __new__ because this is apparently the only place - where we can meaningfully set __slots__ on the class we're creating(?). - (The interplay between metaclasses and slots is not very well-documented). - - Args: - name: Name of the class (ignored, but required by the - metaclass protocol). - bases: Base classes of the class we're constructing. - (Should be message.Message). We ignore this field, but - it's required by the metaclass protocol - dictionary: The class dictionary of the class we're - constructing. dictionary[_DESCRIPTOR_KEY] must contain - a Descriptor object describing this protocol message - type. - - Returns: - Newly-allocated class. - """ - descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY] - bases = _NewMessage(bases, descriptor, dictionary) - superclass = super(GeneratedProtocolMessageType, cls) - - new_class = superclass.__new__(cls, name, bases, dictionary) - setattr(descriptor, '_concrete_class', new_class) - return new_class - - def __init__(cls, name, bases, dictionary): - """Here we perform the majority of our work on the class. - We add enum getters, an __init__ method, implementations - of all Message methods, and properties for all fields - in the protocol type. - - Args: - name: Name of the class (ignored, but required by the - metaclass protocol). 
- bases: Base classes of the class we're constructing. - (Should be message.Message). We ignore this field, but - it's required by the metaclass protocol - dictionary: The class dictionary of the class we're - constructing. dictionary[_DESCRIPTOR_KEY] must contain - a Descriptor object describing this protocol message - type. - """ - descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY] - _InitMessage(descriptor, cls) - superclass = super(GeneratedProtocolMessageType, cls) - superclass.__init__(name, bases, dictionary) +# The type of all Message classes. +# Part of the public interface. +# +# Used by generated files, but clients can also use it at runtime: +# mydescriptor = pool.FindDescriptor(.....) +# class MyProtoClass(Message): +# __metaclass__ = GeneratedProtocolMessageType +# DESCRIPTOR = mydescriptor +GeneratedProtocolMessageType = message_impl.GeneratedProtocolMessageType def ParseMessage(descriptor, byte_str): diff --git a/python/google/protobuf/symbol_database.py b/python/google/protobuf/symbol_database.py index 4c70b3938..87760f263 100644 --- a/python/google/protobuf/symbol_database.py +++ b/python/google/protobuf/symbol_database.py @@ -72,12 +72,12 @@ class SymbolDatabase(object): buffer types used within a program. """ - def __init__(self): + def __init__(self, pool=None): """Constructor.""" self._symbols = {} self._symbols_by_file = {} - self.pool = descriptor_pool.DescriptorPool() + self.pool = pool or descriptor_pool.Default() def RegisterMessage(self, message): """Registers the given message type in the local database. 
@@ -177,7 +177,7 @@ class SymbolDatabase(object): result.update(self._symbols_by_file[f]) return result -_DEFAULT = SymbolDatabase() +_DEFAULT = SymbolDatabase(pool=descriptor_pool.Default()) def Default(): diff --git a/python/google/protobuf/text_encoding.py b/python/google/protobuf/text_encoding.py index 2d86a67ca..989956382 100644 --- a/python/google/protobuf/text_encoding.py +++ b/python/google/protobuf/text_encoding.py @@ -28,15 +28,13 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -#PY25 compatible for GAE. -# """Encoding related utilities.""" - import re -import sys ##PY25 + +import six # Lookup table for utf8 -_cescape_utf8_to_str = [chr(i) for i in xrange(0, 256)] +_cescape_utf8_to_str = [chr(i) for i in range(0, 256)] _cescape_utf8_to_str[9] = r'\t' # optional escape _cescape_utf8_to_str[10] = r'\n' # optional escape _cescape_utf8_to_str[13] = r'\r' # optional escape @@ -46,9 +44,9 @@ _cescape_utf8_to_str[34] = r'\"' # necessary escape _cescape_utf8_to_str[92] = r'\\' # necessary escape # Lookup table for non-utf8, with necessary escapes at (o >= 127 or o < 32) -_cescape_byte_to_str = ([r'\%03o' % i for i in xrange(0, 32)] + - [chr(i) for i in xrange(32, 127)] + - [r'\%03o' % i for i in xrange(127, 256)]) +_cescape_byte_to_str = ([r'\%03o' % i for i in range(0, 32)] + + [chr(i) for i in range(32, 127)] + + [r'\%03o' % i for i in range(127, 256)]) _cescape_byte_to_str[9] = r'\t' # optional escape _cescape_byte_to_str[10] = r'\n' # optional escape _cescape_byte_to_str[13] = r'\r' # optional escape @@ -75,7 +73,7 @@ def CEscape(text, as_utf8): """ # PY3 hack: make Ord work for str and bytes: # //platforms/networking/data uses unicode here, hence basestring. 
- Ord = ord if isinstance(text, basestring) else lambda x: x + Ord = ord if isinstance(text, six.string_types) else lambda x: x if as_utf8: return ''.join(_cescape_utf8_to_str[Ord(c)] for c in text) return ''.join(_cescape_byte_to_str[Ord(c)] for c in text) @@ -100,8 +98,7 @@ def CUnescape(text): # allow single-digit hex escapes (like '\xf'). result = _CUNESCAPE_HEX.sub(ReplaceHex, text) - if sys.version_info[0] < 3: ##PY25 -##!PY25 if str is bytes: # PY2 + if str is bytes: # PY2 return result.decode('string_escape') result = ''.join(_cescape_highbit_to_str[ord(c)] for c in result) return (result.encode('ascii') # Make it bytes to allow decode. diff --git a/python/google/protobuf/text_format.py b/python/google/protobuf/text_format.py index 2429fa59f..6f1e3c8b7 100755 --- a/python/google/protobuf/text_format.py +++ b/python/google/protobuf/text_format.py @@ -28,17 +28,28 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -#PY25 compatible for GAE. -# -# Copyright 2007 Google Inc. All Rights Reserved. +"""Contains routines for printing protocol messages in text format. + +Simple usage example: + + # Create a proto object and serialize it to a text proto string. + message = my_proto_pb2.MyMessage(foo='bar') + text_proto = text_format.MessageToString(message) -"""Contains routines for printing protocol messages in text format.""" + # Parse a text proto string. 
+ message = text_format.Parse(text_proto, my_proto_pb2.MyMessage()) +""" __author__ = 'kenton@google.com (Kenton Varda)' -import cStringIO +import io import re +import six + +if six.PY3: + long = int + from google.protobuf.internal import type_checkers from google.protobuf import descriptor from google.protobuf import text_encoding @@ -55,6 +66,7 @@ _FLOAT_INFINITY = re.compile('-?inf(?:inity)?f?', re.IGNORECASE) _FLOAT_NAN = re.compile('nanf?', re.IGNORECASE) _FLOAT_TYPES = frozenset([descriptor.FieldDescriptor.CPPTYPE_FLOAT, descriptor.FieldDescriptor.CPPTYPE_DOUBLE]) +_QUOTES = frozenset(("'", '"')) class Error(Exception): @@ -62,17 +74,38 @@ class Error(Exception): class ParseError(Error): - """Thrown in case of ASCII parsing error.""" + """Thrown in case of text parsing error.""" + + +class TextWriter(object): + def __init__(self, as_utf8): + if six.PY2: + self._writer = io.BytesIO() + else: + self._writer = io.StringIO() + + def write(self, val): + if six.PY2: + if isinstance(val, six.text_type): + val = val.encode('utf-8') + return self._writer.write(val) + + def close(self): + return self._writer.close() + + def getvalue(self): + return self._writer.getvalue() def MessageToString(message, as_utf8=False, as_one_line=False, pointy_brackets=False, use_index_order=False, - float_format=None): + float_format=None, use_field_number=False): """Convert protobuf message to text format. Floating point values can be formatted compactly with 15 digits of precision (which is the most that IEEE 754 "double" can guarantee) - using float_format='.15g'. + using float_format='.15g'. To ensure that converting to text and back to a + proto will result in an identical value, float_format='.17g' should be used. Args: message: The protocol buffers message. @@ -85,15 +118,16 @@ def MessageToString(message, as_utf8=False, as_one_line=False, field number order. 
float_format: If set, use this to specify floating point number formatting (per the "Format Specification Mini-Language"); otherwise, str() is used. + use_field_number: If True, print field numbers instead of names. Returns: A string of the text formatted protocol buffer message. """ - out = cStringIO.StringIO() - PrintMessage(message, out, as_utf8=as_utf8, as_one_line=as_one_line, - pointy_brackets=pointy_brackets, - use_index_order=use_index_order, - float_format=float_format) + out = TextWriter(as_utf8) + printer = _Printer(out, 0, as_utf8, as_one_line, + pointy_brackets, use_index_order, float_format, + use_field_number) + printer.PrintMessage(message) result = out.getvalue() out.close() if as_one_line: @@ -101,262 +135,446 @@ def MessageToString(message, as_utf8=False, as_one_line=False, return result +def _IsMapEntry(field): + return (field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and + field.message_type.has_options and + field.message_type.GetOptions().map_entry) + + def PrintMessage(message, out, indent=0, as_utf8=False, as_one_line=False, pointy_brackets=False, use_index_order=False, - float_format=None): - fields = message.ListFields() - if use_index_order: - fields.sort(key=lambda x: x[0].index) - for field, value in fields: - if field.label == descriptor.FieldDescriptor.LABEL_REPEATED: - for element in value: - PrintField(field, element, out, indent, as_utf8, as_one_line, - pointy_brackets=pointy_brackets, - float_format=float_format) - else: - PrintField(field, value, out, indent, as_utf8, as_one_line, - pointy_brackets=pointy_brackets, - float_format=float_format) + float_format=None, use_field_number=False): + printer = _Printer(out, indent, as_utf8, as_one_line, + pointy_brackets, use_index_order, float_format, + use_field_number) + printer.PrintMessage(message) def PrintField(field, value, out, indent=0, as_utf8=False, as_one_line=False, - pointy_brackets=False, float_format=None): - """Print a single field name/value pair. 
For repeated fields, the value - should be a single element.""" - - out.write(' ' * indent) - if field.is_extension: - out.write('[') - if (field.containing_type.GetOptions().message_set_wire_format and - field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and - field.message_type == field.extension_scope and - field.label == descriptor.FieldDescriptor.LABEL_OPTIONAL): - out.write(field.message_type.full_name) - else: - out.write(field.full_name) - out.write(']') - elif field.type == descriptor.FieldDescriptor.TYPE_GROUP: - # For groups, use the capitalized name. - out.write(field.message_type.name) - else: - out.write(field.name) - - if field.cpp_type != descriptor.FieldDescriptor.CPPTYPE_MESSAGE: - # The colon is optional in this case, but our cross-language golden files - # don't include it. - out.write(': ') - - PrintFieldValue(field, value, out, indent, as_utf8, as_one_line, - pointy_brackets=pointy_brackets, - float_format=float_format) - if as_one_line: - out.write(' ') - else: - out.write('\n') + pointy_brackets=False, use_index_order=False, float_format=None): + """Print a single field name/value pair.""" + printer = _Printer(out, indent, as_utf8, as_one_line, + pointy_brackets, use_index_order, float_format) + printer.PrintField(field, value) def PrintFieldValue(field, value, out, indent=0, as_utf8=False, as_one_line=False, pointy_brackets=False, + use_index_order=False, float_format=None): - """Print a single field value (not including name). 
For repeated fields, - the value should be a single element.""" + """Print a single field value (not including name).""" + printer = _Printer(out, indent, as_utf8, as_one_line, + pointy_brackets, use_index_order, float_format) + printer.PrintFieldValue(field, value) - if pointy_brackets: - openb = '<' - closeb = '>' - else: - openb = '{' - closeb = '}' - - if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: - if as_one_line: - out.write(' %s ' % openb) - PrintMessage(value, out, indent, as_utf8, as_one_line, - pointy_brackets=pointy_brackets, - float_format=float_format) - out.write(closeb) - else: - out.write(' %s\n' % openb) - PrintMessage(value, out, indent + 2, as_utf8, as_one_line, - pointy_brackets=pointy_brackets, - float_format=float_format) - out.write(' ' * indent + closeb) - elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_ENUM: - enum_value = field.enum_type.values_by_number.get(value, None) - if enum_value is not None: - out.write(enum_value.name) - else: - out.write(str(value)) - elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_STRING: - out.write('\"') - if isinstance(value, unicode): - out_value = value.encode('utf-8') - else: - out_value = value - if field.type == descriptor.FieldDescriptor.TYPE_BYTES: - # We need to escape non-UTF8 chars in TYPE_BYTES field. - out_as_utf8 = False - else: - out_as_utf8 = as_utf8 - out.write(text_encoding.CEscape(out_value, out_as_utf8)) - out.write('\"') - elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_BOOL: - if value: - out.write('true') + +class _Printer(object): + """Text format printer for protocol message.""" + + def __init__(self, out, indent=0, as_utf8=False, as_one_line=False, + pointy_brackets=False, use_index_order=False, float_format=None, + use_field_number=False): + """Initialize the Printer. + + Floating point values can be formatted compactly with 15 digits of + precision (which is the most that IEEE 754 "double" can guarantee) + using float_format='.15g'. 
To ensure that converting to text and back to a + proto will result in an identical value, float_format='.17g' should be used. + + Args: + out: To record the text format result. + indent: The indent level for pretty print. + as_utf8: Produce text output in UTF8 format. + as_one_line: Don't introduce newlines between fields. + pointy_brackets: If True, use angle brackets instead of curly braces for + nesting. + use_index_order: If True, print fields of a proto message using the order + defined in source code instead of the field number. By default, use the + field number order. + float_format: If set, use this to specify floating point number formatting + (per the "Format Specification Mini-Language"); otherwise, str() is + used. + use_field_number: If True, print field numbers instead of names. + """ + self.out = out + self.indent = indent + self.as_utf8 = as_utf8 + self.as_one_line = as_one_line + self.pointy_brackets = pointy_brackets + self.use_index_order = use_index_order + self.float_format = float_format + self.use_field_number = use_field_number + + def PrintMessage(self, message): + """Convert protobuf message to text format. + + Args: + message: The protocol buffers message. + """ + fields = message.ListFields() + if self.use_index_order: + fields.sort(key=lambda x: x[0].index) + for field, value in fields: + if _IsMapEntry(field): + for key in sorted(value): + # This is slow for maps with submessage entires because it copies the + # entire tree. Unfortunately this would take significant refactoring + # of this file to work around. + # + # TODO(haberman): refactor and optimize if this becomes an issue. 
+ entry_submsg = field.message_type._concrete_class( + key=key, value=value[key]) + self.PrintField(field, entry_submsg) + elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED: + for element in value: + self.PrintField(field, element) + else: + self.PrintField(field, value) + + def PrintField(self, field, value): + """Print a single field name/value pair.""" + out = self.out + out.write(' ' * self.indent) + if self.use_field_number: + out.write(str(field.number)) else: - out.write('false') - elif field.cpp_type in _FLOAT_TYPES and float_format is not None: - out.write('{1:{0}}'.format(float_format, value)) - else: - out.write(str(value)) + if field.is_extension: + out.write('[') + if (field.containing_type.GetOptions().message_set_wire_format and + field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and + field.label == descriptor.FieldDescriptor.LABEL_OPTIONAL): + out.write(field.message_type.full_name) + else: + out.write(field.full_name) + out.write(']') + elif field.type == descriptor.FieldDescriptor.TYPE_GROUP: + # For groups, use the capitalized name. + out.write(field.message_type.name) + else: + out.write(field.name) + if field.cpp_type != descriptor.FieldDescriptor.CPPTYPE_MESSAGE: + # The colon is optional in this case, but our cross-language golden files + # don't include it. + out.write(': ') -def _ParseOrMerge(lines, message, allow_multiple_scalars): - """Converts an ASCII representation of a protocol message into a message. + self.PrintFieldValue(field, value) + if self.as_one_line: + out.write(' ') + else: + out.write('\n') - Args: - lines: Lines of a message's ASCII representation. - message: A protocol buffer message to merge into. - allow_multiple_scalars: Determines if repeated values for a non-repeated - field are permitted, e.g., the string "foo: 1 foo: 2" for a - required/optional field named "foo". + def PrintFieldValue(self, field, value): + """Print a single field value (not including name). 
- Raises: - ParseError: On ASCII parsing problems. - """ - tokenizer = _Tokenizer(lines) - while not tokenizer.AtEnd(): - _MergeField(tokenizer, message, allow_multiple_scalars) + For repeated fields, the value should be a single element. + + Args: + field: The descriptor of the field to be printed. + value: The value of the field. + """ + out = self.out + if self.pointy_brackets: + openb = '<' + closeb = '>' + else: + openb = '{' + closeb = '}' + + if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: + if self.as_one_line: + out.write(' %s ' % openb) + self.PrintMessage(value) + out.write(closeb) + else: + out.write(' %s\n' % openb) + self.indent += 2 + self.PrintMessage(value) + self.indent -= 2 + out.write(' ' * self.indent + closeb) + elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_ENUM: + enum_value = field.enum_type.values_by_number.get(value, None) + if enum_value is not None: + out.write(enum_value.name) + else: + out.write(str(value)) + elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_STRING: + out.write('\"') + if isinstance(value, six.text_type): + out_value = value.encode('utf-8') + else: + out_value = value + if field.type == descriptor.FieldDescriptor.TYPE_BYTES: + # We need to escape non-UTF8 chars in TYPE_BYTES field. + out_as_utf8 = False + else: + out_as_utf8 = self.as_utf8 + out.write(text_encoding.CEscape(out_value, out_as_utf8)) + out.write('\"') + elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_BOOL: + if value: + out.write('true') + else: + out.write('false') + elif field.cpp_type in _FLOAT_TYPES and self.float_format is not None: + out.write('{1:{0}}'.format(self.float_format, value)) + else: + out.write(str(value)) -def Parse(text, message): - """Parses an ASCII representation of a protocol message into a message. +def Parse(text, message, + allow_unknown_extension=False, allow_field_number=False): + """Parses an text representation of a protocol message into a message. 
Args: - text: Message ASCII representation. + text: Message text representation. message: A protocol buffer message to merge into. + allow_unknown_extension: if True, skip over missing extensions and keep + parsing + allow_field_number: if True, both field number and field name are allowed. Returns: The same message passed as argument. Raises: - ParseError: On ASCII parsing problems. + ParseError: On text parsing problems. """ - if not isinstance(text, str): text = text.decode('utf-8') - return ParseLines(text.split('\n'), message) + if not isinstance(text, str): + text = text.decode('utf-8') + return ParseLines(text.split('\n'), message, allow_unknown_extension, + allow_field_number) -def Merge(text, message): - """Parses an ASCII representation of a protocol message into a message. +def Merge(text, message, allow_unknown_extension=False, + allow_field_number=False): + """Parses an text representation of a protocol message into a message. Like Parse(), but allows repeated values for a non-repeated field, and uses the last one. Args: - text: Message ASCII representation. + text: Message text representation. message: A protocol buffer message to merge into. + allow_unknown_extension: if True, skip over missing extensions and keep + parsing + allow_field_number: if True, both field number and field name are allowed. Returns: The same message passed as argument. Raises: - ParseError: On ASCII parsing problems. + ParseError: On text parsing problems. """ - return MergeLines(text.split('\n'), message) + return MergeLines(text.split('\n'), message, allow_unknown_extension, + allow_field_number) -def ParseLines(lines, message): - """Parses an ASCII representation of a protocol message into a message. +def ParseLines(lines, message, allow_unknown_extension=False, + allow_field_number=False): + """Parses an text representation of a protocol message into a message. Args: - lines: An iterable of lines of a message's ASCII representation. 
+ lines: An iterable of lines of a message's text representation. message: A protocol buffer message to merge into. + allow_unknown_extension: if True, skip over missing extensions and keep + parsing + allow_field_number: if True, both field number and field name are allowed. Returns: The same message passed as argument. Raises: - ParseError: On ASCII parsing problems. + ParseError: On text parsing problems. """ - _ParseOrMerge(lines, message, False) - return message + parser = _Parser(allow_unknown_extension, allow_field_number) + return parser.ParseLines(lines, message) -def MergeLines(lines, message): - """Parses an ASCII representation of a protocol message into a message. +def MergeLines(lines, message, allow_unknown_extension=False, + allow_field_number=False): + """Parses an text representation of a protocol message into a message. Args: - lines: An iterable of lines of a message's ASCII representation. + lines: An iterable of lines of a message's text representation. message: A protocol buffer message to merge into. + allow_unknown_extension: if True, skip over missing extensions and keep + parsing + allow_field_number: if True, both field number and field name are allowed. Returns: The same message passed as argument. Raises: - ParseError: On ASCII parsing problems. + ParseError: On text parsing problems. """ - _ParseOrMerge(lines, message, True) - return message + parser = _Parser(allow_unknown_extension, allow_field_number) + return parser.MergeLines(lines, message) -def _MergeField(tokenizer, message, allow_multiple_scalars): - """Merges a single protocol message field into a message. +class _Parser(object): + """Text format parser for protocol message.""" - Args: - tokenizer: A tokenizer to parse the field name and values. - message: A protocol message to record the data. - allow_multiple_scalars: Determines if repeated values for a non-repeated - field are permitted, e.g., the string "foo: 1 foo: 2" for a - required/optional field named "foo". 
+ def __init__(self, allow_unknown_extension=False, allow_field_number=False): + self.allow_unknown_extension = allow_unknown_extension + self.allow_field_number = allow_field_number - Raises: - ParseError: In case of ASCII parsing problems. - """ - message_descriptor = message.DESCRIPTOR - if tokenizer.TryConsume('['): - name = [tokenizer.ConsumeIdentifier()] - while tokenizer.TryConsume('.'): - name.append(tokenizer.ConsumeIdentifier()) - name = '.'.join(name) - - if not message_descriptor.is_extendable: - raise tokenizer.ParseErrorPreviousToken( - 'Message type "%s" does not have extensions.' % - message_descriptor.full_name) - # pylint: disable=protected-access - field = message.Extensions._FindExtensionByName(name) - # pylint: enable=protected-access - if not field: - raise tokenizer.ParseErrorPreviousToken( - 'Extension "%s" not registered.' % name) - elif message_descriptor != field.containing_type: - raise tokenizer.ParseErrorPreviousToken( - 'Extension "%s" does not extend message type "%s".' % ( - name, message_descriptor.full_name)) - tokenizer.Consume(']') - else: - name = tokenizer.ConsumeIdentifier() - field = message_descriptor.fields_by_name.get(name, None) + def ParseFromString(self, text, message): + """Parses an text representation of a protocol message into a message.""" + if not isinstance(text, str): + text = text.decode('utf-8') + return self.ParseLines(text.split('\n'), message) + + def ParseLines(self, lines, message): + """Parses an text representation of a protocol message into a message.""" + self._allow_multiple_scalars = False + self._ParseOrMerge(lines, message) + return message + + def MergeFromString(self, text, message): + """Merges an text representation of a protocol message into a message.""" + return self._MergeLines(text.split('\n'), message) - # Group names are expected to be capitalized as they appear in the - # .proto file, which actually matches their type names, not their field - # names. 
- if not field: - field = message_descriptor.fields_by_name.get(name.lower(), None) - if field and field.type != descriptor.FieldDescriptor.TYPE_GROUP: - field = None + def MergeLines(self, lines, message): + """Merges an text representation of a protocol message into a message.""" + self._allow_multiple_scalars = True + self._ParseOrMerge(lines, message) + return message - if (field and field.type == descriptor.FieldDescriptor.TYPE_GROUP and - field.message_type.name != name): - field = None + def _ParseOrMerge(self, lines, message): + """Converts an text representation of a protocol message into a message. - if not field: - raise tokenizer.ParseErrorPreviousToken( - 'Message type "%s" has no field named "%s".' % ( - message_descriptor.full_name, name)) + Args: + lines: Lines of a message's text representation. + message: A protocol buffer message to merge into. - if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: - tokenizer.TryConsume(':') + Raises: + ParseError: On text parsing problems. + """ + tokenizer = _Tokenizer(lines) + while not tokenizer.AtEnd(): + self._MergeField(tokenizer, message) + + def _MergeField(self, tokenizer, message): + """Merges a single protocol message field into a message. + + Args: + tokenizer: A tokenizer to parse the field name and values. + message: A protocol message to record the data. + + Raises: + ParseError: In case of text parsing problems. + """ + message_descriptor = message.DESCRIPTOR + if (hasattr(message_descriptor, 'syntax') and + message_descriptor.syntax == 'proto3'): + # Proto3 doesn't represent presence so we can't test if multiple + # scalars have occurred. We have to allow them. 
+ self._allow_multiple_scalars = True + if tokenizer.TryConsume('['): + name = [tokenizer.ConsumeIdentifier()] + while tokenizer.TryConsume('.'): + name.append(tokenizer.ConsumeIdentifier()) + name = '.'.join(name) + + if not message_descriptor.is_extendable: + raise tokenizer.ParseErrorPreviousToken( + 'Message type "%s" does not have extensions.' % + message_descriptor.full_name) + # pylint: disable=protected-access + field = message.Extensions._FindExtensionByName(name) + # pylint: enable=protected-access + if not field: + if self.allow_unknown_extension: + field = None + else: + raise tokenizer.ParseErrorPreviousToken( + 'Extension "%s" not registered.' % name) + elif message_descriptor != field.containing_type: + raise tokenizer.ParseErrorPreviousToken( + 'Extension "%s" does not extend message type "%s".' % ( + name, message_descriptor.full_name)) + + tokenizer.Consume(']') + + else: + name = tokenizer.ConsumeIdentifier() + if self.allow_field_number and name.isdigit(): + number = ParseInteger(name, True, True) + field = message_descriptor.fields_by_number.get(number, None) + if not field and message_descriptor.is_extendable: + field = message.Extensions._FindExtensionByNumber(number) + else: + field = message_descriptor.fields_by_name.get(name, None) + + # Group names are expected to be capitalized as they appear in the + # .proto file, which actually matches their type names, not their field + # names. + if not field: + field = message_descriptor.fields_by_name.get(name.lower(), None) + if field and field.type != descriptor.FieldDescriptor.TYPE_GROUP: + field = None + + if (field and field.type == descriptor.FieldDescriptor.TYPE_GROUP and + field.message_type.name != name): + field = None + + if not field: + raise tokenizer.ParseErrorPreviousToken( + 'Message type "%s" has no field named "%s".' 
% ( + message_descriptor.full_name, name)) + + if field: + if not self._allow_multiple_scalars and field.containing_oneof: + # Check if there's a different field set in this oneof. + # Note that we ignore the case if the same field was set before, and we + # apply _allow_multiple_scalars to non-scalar fields as well. + which_oneof = message.WhichOneof(field.containing_oneof.name) + if which_oneof is not None and which_oneof != field.name: + raise tokenizer.ParseErrorPreviousToken( + 'Field "%s" is specified along with field "%s", another member ' + 'of oneof "%s" for message type "%s".' % ( + field.name, which_oneof, field.containing_oneof.name, + message_descriptor.full_name)) + + if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: + tokenizer.TryConsume(':') + merger = self._MergeMessageField + else: + tokenizer.Consume(':') + merger = self._MergeScalarField + + if (field.label == descriptor.FieldDescriptor.LABEL_REPEATED + and tokenizer.TryConsume('[')): + # Short repeated format, e.g. "foo: [1, 2, 3]" + while True: + merger(tokenizer, message, field) + if tokenizer.TryConsume(']'): break + tokenizer.Consume(',') + + else: + merger(tokenizer, message, field) + + else: # Proto field is unknown. + assert self.allow_unknown_extension + _SkipFieldContents(tokenizer) + + # For historical reasons, fields may optionally be separated by commas or + # semicolons. + if not tokenizer.TryConsume(','): + tokenizer.TryConsume(';') + + def _MergeMessageField(self, tokenizer, message, field): + """Merges a single scalar field into a message. + + Args: + tokenizer: A tokenizer to parse the field value. + message: The message of which field is a member. + field: The descriptor of the field to be merged. + + Raises: + ParseError: In case of text parsing problems. 
+ """ + is_map_entry = _IsMapEntry(field) if tokenizer.TryConsume('<'): end_token = '>' @@ -367,6 +585,9 @@ def _MergeField(tokenizer, message, allow_multiple_scalars): if field.label == descriptor.FieldDescriptor.LABEL_REPEATED: if field.is_extension: sub_message = message.Extensions[field].add() + elif is_map_entry: + # pylint: disable=protected-access + sub_message = field.message_type._concrete_class() else: sub_message = getattr(message, field.name).add() else: @@ -378,10 +599,117 @@ def _MergeField(tokenizer, message, allow_multiple_scalars): while not tokenizer.TryConsume(end_token): if tokenizer.AtEnd(): - raise tokenizer.ParseErrorPreviousToken('Expected "%s".' % (end_token)) - _MergeField(tokenizer, sub_message, allow_multiple_scalars) + raise tokenizer.ParseErrorPreviousToken('Expected "%s".' % (end_token,)) + self._MergeField(tokenizer, sub_message) + + if is_map_entry: + value_cpptype = field.message_type.fields_by_name['value'].cpp_type + if value_cpptype == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: + value = getattr(message, field.name)[sub_message.key] + value.MergeFrom(sub_message.value) + else: + getattr(message, field.name)[sub_message.key] = sub_message.value + + def _MergeScalarField(self, tokenizer, message, field): + """Merges a single scalar field into a message. + + Args: + tokenizer: A tokenizer to parse the field value. + message: A protocol message to record the data. + field: The descriptor of the field to be merged. + + Raises: + ParseError: In case of text parsing problems. + RuntimeError: On runtime errors. 
+ """ + _ = self.allow_unknown_extension + value = None + + if field.type in (descriptor.FieldDescriptor.TYPE_INT32, + descriptor.FieldDescriptor.TYPE_SINT32, + descriptor.FieldDescriptor.TYPE_SFIXED32): + value = tokenizer.ConsumeInt32() + elif field.type in (descriptor.FieldDescriptor.TYPE_INT64, + descriptor.FieldDescriptor.TYPE_SINT64, + descriptor.FieldDescriptor.TYPE_SFIXED64): + value = tokenizer.ConsumeInt64() + elif field.type in (descriptor.FieldDescriptor.TYPE_UINT32, + descriptor.FieldDescriptor.TYPE_FIXED32): + value = tokenizer.ConsumeUint32() + elif field.type in (descriptor.FieldDescriptor.TYPE_UINT64, + descriptor.FieldDescriptor.TYPE_FIXED64): + value = tokenizer.ConsumeUint64() + elif field.type in (descriptor.FieldDescriptor.TYPE_FLOAT, + descriptor.FieldDescriptor.TYPE_DOUBLE): + value = tokenizer.ConsumeFloat() + elif field.type == descriptor.FieldDescriptor.TYPE_BOOL: + value = tokenizer.ConsumeBool() + elif field.type == descriptor.FieldDescriptor.TYPE_STRING: + value = tokenizer.ConsumeString() + elif field.type == descriptor.FieldDescriptor.TYPE_BYTES: + value = tokenizer.ConsumeByteString() + elif field.type == descriptor.FieldDescriptor.TYPE_ENUM: + value = tokenizer.ConsumeEnum(field) + else: + raise RuntimeError('Unknown field type %d' % field.type) + + if field.label == descriptor.FieldDescriptor.LABEL_REPEATED: + if field.is_extension: + message.Extensions[field].append(value) + else: + getattr(message, field.name).append(value) + else: + if field.is_extension: + if not self._allow_multiple_scalars and message.HasExtension(field): + raise tokenizer.ParseErrorPreviousToken( + 'Message type "%s" should not have multiple "%s" extensions.' % + (message.DESCRIPTOR.full_name, field.full_name)) + else: + message.Extensions[field] = value + else: + if not self._allow_multiple_scalars and message.HasField(field.name): + raise tokenizer.ParseErrorPreviousToken( + 'Message type "%s" should not have multiple "%s" fields.' 
% + (message.DESCRIPTOR.full_name, field.name)) + else: + setattr(message, field.name, value) + + +def _SkipFieldContents(tokenizer): + """Skips over contents (value or message) of a field. + + Args: + tokenizer: A tokenizer to parse the field name and values. + """ + # Try to guess the type of this field. + # If this field is not a message, there should be a ":" between the + # field name and the field value and also the field value should not + # start with "{" or "<" which indicates the beginning of a message body. + # If there is no ":" or there is a "{" or "<" after ":", this field has + # to be a message or the input is ill-formed. + if tokenizer.TryConsume(':') and not tokenizer.LookingAt( + '{') and not tokenizer.LookingAt('<'): + _SkipFieldValue(tokenizer) + else: + _SkipFieldMessage(tokenizer) + + +def _SkipField(tokenizer): + """Skips over a complete field (name and value/message). + + Args: + tokenizer: A tokenizer to parse the field name and values. + """ + if tokenizer.TryConsume('['): + # Consume extension name. + tokenizer.ConsumeIdentifier() + while tokenizer.TryConsume('.'): + tokenizer.ConsumeIdentifier() + tokenizer.Consume(']') else: - _MergeScalarField(tokenizer, message, field, allow_multiple_scalars) + tokenizer.ConsumeIdentifier() + + _SkipFieldContents(tokenizer) # For historical reasons, fields may optionally be separated by commas or # semicolons. @@ -389,76 +717,50 @@ def _MergeField(tokenizer, message, allow_multiple_scalars): tokenizer.TryConsume(';') -def _MergeScalarField(tokenizer, message, field, allow_multiple_scalars): - """Merges a single protocol message scalar field into a message. +def _SkipFieldMessage(tokenizer): + """Skips over a field message. Args: - tokenizer: A tokenizer to parse the field value. - message: A protocol message to record the data. - field: The descriptor of the field to be merged. 
- allow_multiple_scalars: Determines if repeated values for a non-repeated - field are permitted, e.g., the string "foo: 1 foo: 2" for a - required/optional field named "foo". + tokenizer: A tokenizer to parse the field name and values. + """ + + if tokenizer.TryConsume('<'): + delimiter = '>' + else: + tokenizer.Consume('{') + delimiter = '}' + + while not tokenizer.LookingAt('>') and not tokenizer.LookingAt('}'): + _SkipField(tokenizer) + + tokenizer.Consume(delimiter) + + +def _SkipFieldValue(tokenizer): + """Skips over a field value. + + Args: + tokenizer: A tokenizer to parse the field name and values. Raises: - ParseError: In case of ASCII parsing problems. - RuntimeError: On runtime errors. + ParseError: In case an invalid field value is found. """ - tokenizer.Consume(':') - value = None - - if field.type in (descriptor.FieldDescriptor.TYPE_INT32, - descriptor.FieldDescriptor.TYPE_SINT32, - descriptor.FieldDescriptor.TYPE_SFIXED32): - value = tokenizer.ConsumeInt32() - elif field.type in (descriptor.FieldDescriptor.TYPE_INT64, - descriptor.FieldDescriptor.TYPE_SINT64, - descriptor.FieldDescriptor.TYPE_SFIXED64): - value = tokenizer.ConsumeInt64() - elif field.type in (descriptor.FieldDescriptor.TYPE_UINT32, - descriptor.FieldDescriptor.TYPE_FIXED32): - value = tokenizer.ConsumeUint32() - elif field.type in (descriptor.FieldDescriptor.TYPE_UINT64, - descriptor.FieldDescriptor.TYPE_FIXED64): - value = tokenizer.ConsumeUint64() - elif field.type in (descriptor.FieldDescriptor.TYPE_FLOAT, - descriptor.FieldDescriptor.TYPE_DOUBLE): - value = tokenizer.ConsumeFloat() - elif field.type == descriptor.FieldDescriptor.TYPE_BOOL: - value = tokenizer.ConsumeBool() - elif field.type == descriptor.FieldDescriptor.TYPE_STRING: - value = tokenizer.ConsumeString() - elif field.type == descriptor.FieldDescriptor.TYPE_BYTES: - value = tokenizer.ConsumeByteString() - elif field.type == descriptor.FieldDescriptor.TYPE_ENUM: - value = tokenizer.ConsumeEnum(field) - else: - raise 
RuntimeError('Unknown field type %d' % field.type) + # String/bytes tokens can come in multiple adjacent string literals. + # If we can consume one, consume as many as we can. + if tokenizer.TryConsumeByteString(): + while tokenizer.TryConsumeByteString(): + pass + return - if field.label == descriptor.FieldDescriptor.LABEL_REPEATED: - if field.is_extension: - message.Extensions[field].append(value) - else: - getattr(message, field.name).append(value) - else: - if field.is_extension: - if not allow_multiple_scalars and message.HasExtension(field): - raise tokenizer.ParseErrorPreviousToken( - 'Message type "%s" should not have multiple "%s" extensions.' % - (message.DESCRIPTOR.full_name, field.full_name)) - else: - message.Extensions[field] = value - else: - if not allow_multiple_scalars and message.HasField(field.name): - raise tokenizer.ParseErrorPreviousToken( - 'Message type "%s" should not have multiple "%s" fields.' % - (message.DESCRIPTOR.full_name, field.name)) - else: - setattr(message, field.name, value) + if (not tokenizer.TryConsumeIdentifier() and + not tokenizer.TryConsumeInt64() and + not tokenizer.TryConsumeUint64() and + not tokenizer.TryConsumeFloat()): + raise ParseError('Invalid field value: ' + tokenizer.token) class _Tokenizer(object): - """Protocol buffer ASCII representation tokenizer. + """Protocol buffer text representation tokenizer. This class handles the lower level string parsing by splitting it into meaningful tokens. 
@@ -467,11 +769,13 @@ class _Tokenizer(object): """ _WHITESPACE = re.compile('(\\s|(#.*$))+', re.MULTILINE) - _TOKEN = re.compile( - '[a-zA-Z_][0-9a-zA-Z_+-]*|' # an identifier - '[0-9+-][0-9a-zA-Z_.+-]*|' # a number - '\"([^\"\n\\\\]|\\\\.)*(\"|\\\\?$)|' # a double-quoted string - '\'([^\'\n\\\\]|\\\\.)*(\'|\\\\?$)') # a single-quoted string + _TOKEN = re.compile('|'.join([ + r'[a-zA-Z_][0-9a-zA-Z_+-]*', # an identifier + r'([0-9+-]|(\.[0-9]))[0-9a-zA-Z_.+-]*', # a number + ] + [ # quoted str for each quote mark + r'{qt}([^{qt}\n\\]|\\.)*({qt}|\\?$)'.format(qt=mark) for mark in _QUOTES + ])) + _IDENTIFIER = re.compile(r'\w+') def __init__(self, lines): @@ -488,6 +792,9 @@ class _Tokenizer(object): self._SkipWhitespace() self.NextToken() + def LookingAt(self, token): + return self.token == token + def AtEnd(self): """Checks the end of the text was reached. @@ -499,7 +806,7 @@ class _Tokenizer(object): def _PopLine(self): while len(self._current_line) <= self._column: try: - self._current_line = self._lines.next() + self._current_line = next(self._lines) except StopIteration: self._current_line = '' self._more_lines = False @@ -543,6 +850,13 @@ class _Tokenizer(object): if not self.TryConsume(token): raise self._ParseError('Expected "%s".' % token) + def TryConsumeIdentifier(self): + try: + self.ConsumeIdentifier() + return True + except ParseError: + return False + def ConsumeIdentifier(self): """Consumes protocol message field identifier. 
@@ -569,7 +883,7 @@ class _Tokenizer(object): """ try: result = ParseInteger(self.token, is_signed=True, is_long=False) - except ValueError, e: + except ValueError as e: raise self._ParseError(str(e)) self.NextToken() return result @@ -585,11 +899,18 @@ class _Tokenizer(object): """ try: result = ParseInteger(self.token, is_signed=False, is_long=False) - except ValueError, e: + except ValueError as e: raise self._ParseError(str(e)) self.NextToken() return result + def TryConsumeInt64(self): + try: + self.ConsumeInt64() + return True + except ParseError: + return False + def ConsumeInt64(self): """Consumes a signed 64bit integer number. @@ -601,11 +922,18 @@ class _Tokenizer(object): """ try: result = ParseInteger(self.token, is_signed=True, is_long=True) - except ValueError, e: + except ValueError as e: raise self._ParseError(str(e)) self.NextToken() return result + def TryConsumeUint64(self): + try: + self.ConsumeUint64() + return True + except ParseError: + return False + def ConsumeUint64(self): """Consumes an unsigned 64bit integer number. @@ -617,11 +945,18 @@ class _Tokenizer(object): """ try: result = ParseInteger(self.token, is_signed=False, is_long=True) - except ValueError, e: + except ValueError as e: raise self._ParseError(str(e)) self.NextToken() return result + def TryConsumeFloat(self): + try: + self.ConsumeFloat() + return True + except ParseError: + return False + def ConsumeFloat(self): """Consumes an floating point number. 
@@ -633,7 +968,7 @@ class _Tokenizer(object): """ try: result = ParseFloat(self.token) - except ValueError, e: + except ValueError as e: raise self._ParseError(str(e)) self.NextToken() return result @@ -649,11 +984,18 @@ class _Tokenizer(object): """ try: result = ParseBool(self.token) - except ValueError, e: + except ValueError as e: raise self._ParseError(str(e)) self.NextToken() return result + def TryConsumeByteString(self): + try: + self.ConsumeByteString() + return True + except ParseError: + return False + def ConsumeString(self): """Consumes a string value. @@ -665,8 +1007,8 @@ class _Tokenizer(object): """ the_bytes = self.ConsumeByteString() try: - return unicode(the_bytes, 'utf-8') - except UnicodeDecodeError, e: + return six.text_type(the_bytes, 'utf-8') + except UnicodeDecodeError as e: raise self._StringParseError(e) def ConsumeByteString(self): @@ -679,10 +1021,9 @@ class _Tokenizer(object): ParseError: If a byte array value couldn't be consumed. """ the_list = [self._ConsumeSingleByteString()] - while self.token and self.token[0] in ('\'', '"'): + while self.token and self.token[0] in _QUOTES: the_list.append(self._ConsumeSingleByteString()) - return ''.encode('latin1').join(the_list) ##PY25 -##!PY25 return b''.join(the_list) + return b''.join(the_list) def _ConsumeSingleByteString(self): """Consume one token of a string literal. @@ -690,17 +1031,22 @@ class _Tokenizer(object): String literals (whether bytes or text) can come in multiple adjacent tokens which are automatically concatenated, like in C or Python. This method only consumes one token. + + Returns: + The token parsed. + Raises: + ParseError: When the wrong format data is found. 
""" text = self.token - if len(text) < 1 or text[0] not in ('\'', '"'): - raise self._ParseError('Expected string.') + if len(text) < 1 or text[0] not in _QUOTES: + raise self._ParseError('Expected string but found: %r' % (text,)) if len(text) < 2 or text[-1] != text[0]: - raise self._ParseError('String missing ending quote.') + raise self._ParseError('String missing ending quote: %r' % (text,)) try: result = text_encoding.CUnescape(text[1:-1]) - except ValueError, e: + except ValueError as e: raise self._ParseError(str(e)) self.NextToken() return result @@ -708,7 +1054,7 @@ class _Tokenizer(object): def ConsumeEnum(self, field): try: result = ParseEnum(field, self.token) - except ValueError, e: + except ValueError as e: raise self._ParseError(str(e)) self.NextToken() return result diff --git a/python/mox.py b/python/mox.py index ce80ba505..257468e52 100755 --- a/python/mox.py +++ b/python/mox.py @@ -31,7 +31,7 @@ If an unexpected method (or an expected method with unexpected parameters) is called, then an exception will be raised. Once you are done interacting with the mock, you need to verify that -all the expected interactions occured. (Maybe your code exited +all the expected interactions occurred. (Maybe your code exited prematurely without calling some cleanup method!) The verify phase ensures that every expected method was called; otherwise, an exception will be raised. diff --git a/python/setup.py b/python/setup.py index 2450a7743..0f4b53c4a 100755 --- a/python/setup.py +++ b/python/setup.py @@ -1,29 +1,24 @@ -#! /usr/bin/python +#! /usr/bin/env python # # See README for usage instructions. -import sys +import glob import os import subprocess +import sys # We must use setuptools, not distutils, because we need to use the # namespace_packages option for the "google" package. 
-try: - from setuptools import setup, Extension -except ImportError: - try: - from ez_setup import use_setuptools - use_setuptools() - from setuptools import setup, Extension - except ImportError: - sys.stderr.write( - "Could not import setuptools; make sure you have setuptools or " - "ez_setup installed.\n") - raise +from setuptools import setup, Extension, find_packages + from distutils.command.clean import clean as _clean -from distutils.command.build_py import build_py as _build_py -from distutils.spawn import find_executable -maintainer_email = "protobuf@googlegroups.com" +if sys.version_info[0] == 3: + # Python 3 + from distutils.command.build_py import build_py_2to3 as _build_py +else: + # Python 2 + from distutils.command.build_py import build_py as _build_py +from distutils.spawn import find_executable # Find the Protocol Compiler. if 'PROTOC' in os.environ and os.path.exists(os.environ['PROTOC']): @@ -39,23 +34,38 @@ elif os.path.exists("../vsprojects/Release/protoc.exe"): else: protoc = find_executable("protoc") -def generate_proto(source): + +def GetVersion(): + """Gets the version from google/protobuf/__init__.py + + Do not import google.protobuf.__init__ directly, because an installed + protobuf library may be loaded instead.""" + + with open(os.path.join('google', 'protobuf', '__init__.py')) as version_file: + exec(version_file.read(), globals()) + return __version__ + + +def generate_proto(source, require = True): """Invokes the Protocol Compiler to generate a _pb2.py from the given .proto file. Does nothing if the output already exists and is newer than the input.""" + if not require and not os.path.exists(source): + return + output = source.replace(".proto", "_pb2.py").replace("../src/", "") if (not os.path.exists(output) or (os.path.exists(source) and os.path.getmtime(source) > os.path.getmtime(output))): - print ("Generating %s..." % output) + print("Generating %s..." 
% output) if not os.path.exists(source): sys.stderr.write("Can't find required file: %s\n" % source) sys.exit(-1) - if protoc == None: + if protoc is None: sys.stderr.write( "protoc is not installed nor found in ../src. Please compile it " "or install the binary package.\n") @@ -66,39 +76,35 @@ def generate_proto(source): sys.exit(-1) def GenerateUnittestProtos(): - generate_proto("../src/google/protobuf/unittest.proto") - generate_proto("../src/google/protobuf/unittest_custom_options.proto") - generate_proto("../src/google/protobuf/unittest_import.proto") - generate_proto("../src/google/protobuf/unittest_import_public.proto") - generate_proto("../src/google/protobuf/unittest_mset.proto") - generate_proto("../src/google/protobuf/unittest_no_generic_services.proto") - generate_proto("google/protobuf/internal/descriptor_pool_test1.proto") - generate_proto("google/protobuf/internal/descriptor_pool_test2.proto") - generate_proto("google/protobuf/internal/test_bad_identifiers.proto") - generate_proto("google/protobuf/internal/missing_enum_values.proto") - generate_proto("google/protobuf/internal/more_extensions.proto") - generate_proto("google/protobuf/internal/more_extensions_dynamic.proto") - generate_proto("google/protobuf/internal/more_messages.proto") - generate_proto("google/protobuf/internal/factory_test1.proto") - generate_proto("google/protobuf/internal/factory_test2.proto") - generate_proto("google/protobuf/pyext/python.proto") - -def MakeTestSuite(): - # Test C++ implementation - import unittest - import google.protobuf.pyext.descriptor_cpp2_test as descriptor_cpp2_test - import google.protobuf.pyext.message_factory_cpp2_test \ - as message_factory_cpp2_test - import google.protobuf.pyext.reflection_cpp2_generated_test \ - as reflection_cpp2_generated_test - - loader = unittest.defaultTestLoader - suite = unittest.TestSuite() - for test in [ descriptor_cpp2_test, - message_factory_cpp2_test, - reflection_cpp2_generated_test]: - 
suite.addTest(loader.loadTestsFromModule(test)) - return suite + generate_proto("../src/google/protobuf/map_unittest.proto", False) + generate_proto("../src/google/protobuf/unittest_arena.proto", False) + generate_proto("../src/google/protobuf/unittest_no_arena.proto", False) + generate_proto("../src/google/protobuf/unittest_no_arena_import.proto", False) + generate_proto("../src/google/protobuf/unittest.proto", False) + generate_proto("../src/google/protobuf/unittest_custom_options.proto", False) + generate_proto("../src/google/protobuf/unittest_import.proto", False) + generate_proto("../src/google/protobuf/unittest_import_public.proto", False) + generate_proto("../src/google/protobuf/unittest_mset.proto", False) + generate_proto("../src/google/protobuf/unittest_mset_wire_format.proto", False) + generate_proto("../src/google/protobuf/unittest_no_generic_services.proto", False) + generate_proto("../src/google/protobuf/unittest_proto3_arena.proto", False) + generate_proto("../src/google/protobuf/util/json_format_proto3.proto", False) + generate_proto("google/protobuf/internal/any_test.proto", False) + generate_proto("google/protobuf/internal/descriptor_pool_test1.proto", False) + generate_proto("google/protobuf/internal/descriptor_pool_test2.proto", False) + generate_proto("google/protobuf/internal/factory_test1.proto", False) + generate_proto("google/protobuf/internal/factory_test2.proto", False) + generate_proto("google/protobuf/internal/import_test_package/inner.proto", False) + generate_proto("google/protobuf/internal/import_test_package/outer.proto", False) + generate_proto("google/protobuf/internal/missing_enum_values.proto", False) + generate_proto("google/protobuf/internal/message_set_extensions.proto", False) + generate_proto("google/protobuf/internal/more_extensions.proto", False) + generate_proto("google/protobuf/internal/more_extensions_dynamic.proto", False) + generate_proto("google/protobuf/internal/more_messages.proto", False) + 
generate_proto("google/protobuf/internal/packed_field_test.proto", False) + generate_proto("google/protobuf/internal/test_bad_identifiers.proto", False) + generate_proto("google/protobuf/pyext/python.proto", False) + class clean(_clean): def run(self): @@ -108,7 +114,8 @@ class clean(_clean): filepath = os.path.join(dirpath, filename) if filepath.endswith("_pb2.py") or filepath.endswith(".pyc") or \ filepath.endswith(".so") or filepath.endswith(".o") or \ - filepath.endswith('google/protobuf/compiler/__init__.py'): + filepath.endswith('google/protobuf/compiler/__init__.py') or \ + filepath.endswith('google/protobuf/util/__init__.py'): os.remove(filepath) # _clean is an old-style class, so super() doesn't work. _clean.run(self) @@ -118,84 +125,126 @@ class build_py(_build_py): # Generate necessary .proto file if it doesn't exist. generate_proto("../src/google/protobuf/descriptor.proto") generate_proto("../src/google/protobuf/compiler/plugin.proto") + generate_proto("../src/google/protobuf/any.proto") + generate_proto("../src/google/protobuf/api.proto") + generate_proto("../src/google/protobuf/duration.proto") + generate_proto("../src/google/protobuf/empty.proto") + generate_proto("../src/google/protobuf/field_mask.proto") + generate_proto("../src/google/protobuf/source_context.proto") + generate_proto("../src/google/protobuf/struct.proto") + generate_proto("../src/google/protobuf/timestamp.proto") + generate_proto("../src/google/protobuf/type.proto") + generate_proto("../src/google/protobuf/wrappers.proto") GenerateUnittestProtos() # Make sure google.protobuf/** are valid packages. - for path in ['', 'internal/', 'compiler/', 'pyext/']: + for path in ['', 'internal/', 'compiler/', 'pyext/', 'util/']: try: open('google/protobuf/%s__init__.py' % path, 'a').close() except EnvironmentError: pass # _build_py is an old-style class, so super() doesn't work. _build_py.run(self) - # TODO(mrovner): Subclass to run 2to3 on some files only. 
- # Tracing what https://wiki.python.org/moin/PortingPythonToPy3k's "Approach 2" - # section on how to get 2to3 to run on source files during install under - # Python 3. This class seems like a good place to put logic that calls - # python3's distutils.util.run_2to3 on the subset of the files we have in our - # release that are subject to conversion. - # See code reference in previous code review. + +class test_conformance(_build_py): + target = 'test_python' + def run(self): + if sys.version_info >= (2, 7): + # Python 2.6 dodges these extra failures. + os.environ["CONFORMANCE_PYTHON_EXTRA_FAILURES"] = ( + "--failure_list failure_list_python-post26.txt") + cmd = 'cd ../conformance && make %s' % (test_conformance.target) + status = subprocess.check_call(cmd, shell=True) + + +def get_option_from_sys_argv(option_str): + if option_str in sys.argv: + sys.argv.remove(option_str) + return True + return False + if __name__ == '__main__': ext_module_list = [] - cpp_impl = '--cpp_implementation' - if cpp_impl in sys.argv: - sys.argv.remove(cpp_impl) + warnings_as_errors = '--warnings_as_errors' + if get_option_from_sys_argv('--cpp_implementation'): + # Link libprotobuf.a and libprotobuf-lite.a statically with the + # extension. Note that those libraries have to be compiled with + # -fPIC for this to work. 
+ compile_static_ext = get_option_from_sys_argv('--compile_static_extension') + extra_compile_args = ['-Wno-write-strings', + '-Wno-invalid-offsetof', + '-Wno-sign-compare'] + libraries = ['protobuf'] + extra_objects = None + if compile_static_ext: + libraries = None + extra_objects = ['../src/.libs/libprotobuf.a', + '../src/.libs/libprotobuf-lite.a'] + test_conformance.target = 'test_python_cpp' + + if "clang" in os.popen('$CC --version 2> /dev/null').read(): + extra_compile_args.append('-Wno-shorten-64-to-32') + + if warnings_as_errors in sys.argv: + extra_compile_args.append('-Werror') + sys.argv.remove(warnings_as_errors) + # C++ implementation extension - ext_module_list.append(Extension( - "google.protobuf.pyext._message", - [ "google/protobuf/pyext/descriptor.cc", - "google/protobuf/pyext/message.cc", - "google/protobuf/pyext/extension_dict.cc", - "google/protobuf/pyext/repeated_scalar_container.cc", - "google/protobuf/pyext/repeated_composite_container.cc" ], - define_macros=[('GOOGLE_PROTOBUF_HAS_ONEOF', '1')], - include_dirs = [ ".", "../src"], - libraries = [ "protobuf" ], - library_dirs = [ '../src/.libs' ], - )) - - setup(name = 'protobuf', - version = '2.6.1', - packages = [ 'google' ], - namespace_packages = [ 'google' ], - test_suite = 'setup.MakeTestSuite', - google_test_dir = "google/protobuf/internal", - # Must list modules explicitly so that we don't install tests. 
- py_modules = [ - 'google.protobuf.internal.api_implementation', - 'google.protobuf.internal.containers', - 'google.protobuf.internal.cpp_message', - 'google.protobuf.internal.decoder', - 'google.protobuf.internal.encoder', - 'google.protobuf.internal.enum_type_wrapper', - 'google.protobuf.internal.message_listener', - 'google.protobuf.internal.python_message', - 'google.protobuf.internal.type_checkers', - 'google.protobuf.internal.wire_format', - 'google.protobuf.descriptor', - 'google.protobuf.descriptor_pb2', - 'google.protobuf.compiler.plugin_pb2', - 'google.protobuf.message', - 'google.protobuf.descriptor_database', - 'google.protobuf.descriptor_pool', - 'google.protobuf.message_factory', - 'google.protobuf.pyext.cpp_message', - 'google.protobuf.reflection', - 'google.protobuf.service', - 'google.protobuf.service_reflection', - 'google.protobuf.symbol_database', - 'google.protobuf.text_encoding', - 'google.protobuf.text_format'], - cmdclass = { 'clean': clean, 'build_py': build_py }, - install_requires = ['setuptools'], - setup_requires = ['google-apputils'], - ext_modules = ext_module_list, - url = 'https://developers.google.com/protocol-buffers/', - maintainer = maintainer_email, - maintainer_email = 'protobuf@googlegroups.com', - license = 'New BSD License', - description = 'Protocol Buffers', - long_description = - "Protocol Buffers are Google's data interchange format.", - ) + ext_module_list.extend([ + Extension( + "google.protobuf.pyext._message", + glob.glob('google/protobuf/pyext/*.cc'), + include_dirs=[".", "../src"], + libraries=libraries, + extra_objects=extra_objects, + library_dirs=['../src/.libs'], + extra_compile_args=extra_compile_args, + ), + Extension( + "google.protobuf.internal._api_implementation", + glob.glob('google/protobuf/internal/api_implementation.cc'), + extra_compile_args=['-DPYTHON_PROTO2_CPP_IMPL_V2'], + ), + ]) + os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'cpp' + + # Keep this list of dependencies in sync with 
tox.ini. + install_requires = ['six>=1.9', 'setuptools'] + if sys.version_info <= (2,7): + install_requires.append('ordereddict') + install_requires.append('unittest2') + + setup( + name='protobuf', + version=GetVersion(), + description='Protocol Buffers', + long_description="Protocol Buffers are Google's data interchange format", + url='https://developers.google.com/protocol-buffers/', + maintainer='protobuf@googlegroups.com', + maintainer_email='protobuf@googlegroups.com', + license='New BSD License', + classifiers=[ + "Programming Language :: Python", + "Programming Language :: Python :: 2", + "Programming Language :: Python :: 2.6", + "Programming Language :: Python :: 2.7", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.3", + "Programming Language :: Python :: 3.4", + ], + namespace_packages=['google'], + packages=find_packages( + exclude=[ + 'import_test_package', + ], + ), + test_suite='google.protobuf.internal', + cmdclass={ + 'clean': clean, + 'build_py': build_py, + 'test_conformance': test_conformance, + }, + install_requires=install_requires, + ext_modules=ext_module_list, + ) diff --git a/python/tox.ini b/python/tox.ini new file mode 100644 index 000000000..cf8d54016 --- /dev/null +++ b/python/tox.ini @@ -0,0 +1,24 @@ +[tox] +envlist = + py{26,27,33,34}-{cpp,python} + +[testenv] +usedevelop=true +passenv = CC +setenv = + cpp: LD_LIBRARY_PATH={toxinidir}/../src/.libs + cpp: DYLD_LIBRARY_PATH={toxinidir}/../src/.libs + cpp: PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=cpp +commands = + python setup.py -q build_py + python: python setup.py -q build + cpp: python setup.py -q build --cpp_implementation --warnings_as_errors + python: python setup.py -q test -q + cpp: python setup.py -q test -q --cpp_implementation + python: python setup.py -q test_conformance + cpp: python setup.py -q test_conformance --cpp_implementation +deps = + # Keep this list of dependencies in sync with setup.py. 
+ six>=1.9 + py26: ordereddict + py26: unittest2 |
