aboutsummaryrefslogtreecommitdiffstats
path: root/python
diff options
context:
space:
mode:
authorTamas Berghammer <tberghammer@google.com>2016-06-03 17:53:47 -0700
committerTamas Berghammer <tberghammer@google.com>2016-06-06 16:38:21 -0700
commitb0575e93e4c39dec69365b850088a1eb7f82c5b3 (patch)
treeaf0dd50df9603fdc849ad26d3d6b50384035d51e /python
parent9f183e211f11e0e5c0a1fa519107dbfbb1a8d4ca (diff)
downloadplatform_external_protobuf-b0575e93e4c39dec69365b850088a1eb7f82c5b3.tar.gz
platform_external_protobuf-b0575e93e4c39dec69365b850088a1eb7f82c5b3.tar.bz2
platform_external_protobuf-b0575e93e4c39dec69365b850088a1eb7f82c5b3.zip
Update from protobuf v2.6.1 to protobuf 3.0.0-beta-3
This change just copies the upstream code into the repository without fixing the Android.mk or fixing the possible cmpile errors. All of those will be fixed with foloowup CLs. Bug: b/28974522 Change-Id: I79fb3966dbef85915965692fa6ab14dc611ed9ea
Diffstat (limited to 'python')
-rw-r--r--python/MANIFEST.in14
-rw-r--r--python/README.md (renamed from python/README.txt)74
-rwxr-xr-xpython/ez_setup.py284
-rwxr-xr-xpython/google/__init__.py5
-rwxr-xr-xpython/google/protobuf/__init__.py39
-rwxr-xr-xpython/google/protobuf/descriptor.py222
-rw-r--r--python/google/protobuf/descriptor_database.py8
-rw-r--r--python/google/protobuf/descriptor_pool.py230
-rwxr-xr-xpython/google/protobuf/internal/_parameterized.py443
-rw-r--r--python/google/protobuf/internal/any_test.proto42
-rw-r--r--python/google/protobuf/internal/api_implementation.cc14
-rwxr-xr-xpython/google/protobuf/internal/api_implementation.py48
-rw-r--r--python/google/protobuf/internal/api_implementation_default_test.py63
-rwxr-xr-xpython/google/protobuf/internal/containers.py358
-rwxr-xr-xpython/google/protobuf/internal/cpp_message.py663
-rwxr-xr-xpython/google/protobuf/internal/decoder.py99
-rw-r--r--python/google/protobuf/internal/descriptor_database_test.py24
-rw-r--r--python/google/protobuf/internal/descriptor_pool_test.py312
-rw-r--r--python/google/protobuf/internal/descriptor_pool_test1.proto2
-rw-r--r--python/google/protobuf/internal/descriptor_pool_test2.proto3
-rw-r--r--python/google/protobuf/internal/descriptor_python_test.py54
-rwxr-xr-xpython/google/protobuf/internal/descriptor_test.py247
-rwxr-xr-xpython/google/protobuf/internal/encoder.py109
-rw-r--r--python/google/protobuf/internal/factory_test1.proto1
-rw-r--r--python/google/protobuf/internal/factory_test2.proto7
-rwxr-xr-xpython/google/protobuf/internal/generator_test.py43
-rw-r--r--python/google/protobuf/internal/import_test_package/__init__.py (renamed from python/google/protobuf/internal/message_python_test.py)25
-rw-r--r--python/google/protobuf/internal/import_test_package/inner.proto37
-rw-r--r--python/google/protobuf/internal/import_test_package/outer.proto39
-rw-r--r--python/google/protobuf/internal/json_format_test.py769
-rw-r--r--python/google/protobuf/internal/message_factory_python_test.py54
-rw-r--r--python/google/protobuf/internal/message_factory_test.py101
-rw-r--r--python/google/protobuf/internal/message_set_extensions.proto74
-rwxr-xr-xpython/google/protobuf/internal/message_test.py1533
-rw-r--r--python/google/protobuf/internal/missing_enum_values.proto6
-rw-r--r--python/google/protobuf/internal/more_extensions.proto1
-rw-r--r--python/google/protobuf/internal/more_extensions_dynamic.proto1
-rw-r--r--python/google/protobuf/internal/more_messages.proto1
-rw-r--r--python/google/protobuf/internal/packed_field_test.proto73
-rw-r--r--python/google/protobuf/internal/proto_builder_test.py96
-rwxr-xr-xpython/google/protobuf/internal/python_message.py527
-rwxr-xr-xpython/google/protobuf/internal/reflection_test.py283
-rwxr-xr-xpython/google/protobuf/internal/service_reflection_test.py20
-rw-r--r--python/google/protobuf/internal/symbol_database_test.py61
-rw-r--r--python/google/protobuf/internal/test_bad_identifiers.proto1
-rwxr-xr-xpython/google/protobuf/internal/test_util.py204
-rwxr-xr-xpython/google/protobuf/internal/text_encoding_test.py20
-rwxr-xr-xpython/google/protobuf/internal/text_format_test.py1008
-rwxr-xr-xpython/google/protobuf/internal/type_checkers.py72
-rwxr-xr-xpython/google/protobuf/internal/unknown_fields_test.py211
-rw-r--r--python/google/protobuf/internal/well_known_types.py724
-rw-r--r--python/google/protobuf/internal/well_known_types_test.py644
-rwxr-xr-xpython/google/protobuf/internal/wire_format_test.py12
-rw-r--r--python/google/protobuf/json_format.py645
-rwxr-xr-xpython/google/protobuf/message.py19
-rw-r--r--python/google/protobuf/message_factory.py16
-rw-r--r--python/google/protobuf/proto_builder.py130
-rw-r--r--python/google/protobuf/pyext/__init__.py4
-rw-r--r--python/google/protobuf/pyext/cpp_message.py38
-rw-r--r--python/google/protobuf/pyext/descriptor.cc1593
-rw-r--r--python/google/protobuf/pyext/descriptor.h85
-rw-r--r--python/google/protobuf/pyext/descriptor_containers.cc1652
-rw-r--r--python/google/protobuf/pyext/descriptor_containers.h101
-rw-r--r--python/google/protobuf/pyext/descriptor_cpp2_test.py58
-rw-r--r--python/google/protobuf/pyext/descriptor_database.cc148
-rw-r--r--python/google/protobuf/pyext/descriptor_database.h75
-rw-r--r--python/google/protobuf/pyext/descriptor_pool.cc593
-rw-r--r--python/google/protobuf/pyext/descriptor_pool.h167
-rw-r--r--python/google/protobuf/pyext/extension_dict.cc195
-rw-r--r--python/google/protobuf/pyext/extension_dict.h40
-rw-r--r--python/google/protobuf/pyext/map_container.cc970
-rw-r--r--python/google/protobuf/pyext/map_container.h142
-rw-r--r--python/google/protobuf/pyext/message.cc2372
-rw-r--r--python/google/protobuf/pyext/message.h130
-rw-r--r--python/google/protobuf/pyext/message_factory_cpp2_test.py56
-rw-r--r--python/google/protobuf/pyext/proto2_api_test.proto2
-rw-r--r--python/google/protobuf/pyext/python.proto2
-rwxr-xr-xpython/google/protobuf/pyext/reflection_cpp2_generated_test.py94
-rw-r--r--python/google/protobuf/pyext/repeated_composite_container.cc435
-rw-r--r--python/google/protobuf/pyext/repeated_composite_container.h39
-rw-r--r--python/google/protobuf/pyext/repeated_scalar_container.cc325
-rw-r--r--python/google/protobuf/pyext/repeated_scalar_container.h16
-rw-r--r--python/google/protobuf/pyext/scoped_pyobject_ptr.h17
-rwxr-xr-xpython/google/protobuf/reflection.py107
-rw-r--r--python/google/protobuf/symbol_database.py6
-rw-r--r--python/google/protobuf/text_encoding.py19
-rwxr-xr-xpython/google/protobuf/text_format.py914
-rwxr-xr-xpython/mox.py2
-rwxr-xr-xpython/setup.py293
-rw-r--r--python/tox.ini24
90 files changed, 16537 insertions, 4992 deletions
diff --git a/python/MANIFEST.in b/python/MANIFEST.in
new file mode 100644
index 000000000..260888263
--- /dev/null
+++ b/python/MANIFEST.in
@@ -0,0 +1,14 @@
+prune google/protobuf/internal/import_test_package
+exclude google/protobuf/internal/*_pb2.py
+exclude google/protobuf/internal/*_test.py
+exclude google/protobuf/internal/*.proto
+exclude google/protobuf/internal/test_util.py
+
+recursive-exclude google *_test.py
+recursive-exclude google *_test.proto
+recursive-exclude google unittest*_pb2.py
+
+global-exclude *.dll
+global-exclude *.pyc
+global-exclude *.pyo
+global-exclude *.so
diff --git a/python/README.txt b/python/README.md
index da044af44..57acfd94d 100644
--- a/python/README.txt
+++ b/python/README.md
@@ -1,4 +1,8 @@
Protocol Buffers - Google's data interchange format
+===================================================
+
+[![Build Status](https://travis-ci.org/google/protobuf.svg?branch=master)](https://travis-ci.org/google/protobuf)
+
Copyright 2008 Google Inc.
This directory contains the Python Protocol Buffers runtime library.
@@ -26,7 +30,7 @@ join the Protocol Buffers discussion list and let us know!
Installation
============
-1) Make sure you have Python 2.4 or newer. If in doubt, run:
+1) Make sure you have Python 2.6 or newer. If in doubt, run:
$ python -V
@@ -35,7 +39,7 @@ Installation
If you would rather install it manually, you may do so by following
the instructions on this page:
- http://peak.telecommunity.com/DevCenter/EasyInstall#installation-instructions
+ https://packaging.python.org/en/latest/installing.html#setup-for-installing-packages
3) Build the C++ code, or install a binary distribution of protoc. If
you install a binary distribution, make sure that it is the same
@@ -46,9 +50,38 @@ Installation
4) Build and run the tests:
$ python setup.py build
- $ python setup.py google_test
+ $ python setup.py test
+
+ To build, test, and use the C++ implementation, you must first compile
+ libprotobuf.so:
+
+ $ (cd .. && make)
+
+ On OS X:
+
+ If you are running a homebrew-provided python, you must make sure another
+ version of protobuf is not already installed, as homebrew's python will
+ search /usr/local/lib for libprotobuf.so before it searches ../src/.libs
+ You can either unlink homebrew's protobuf or install the libprotobuf you
+ built earlier:
+
+ $ brew unlink protobuf
+ or
+ $ (cd .. && make install)
+
+ On other *nix:
- If you want to test c++ implementation, run:
+ You must make libprotobuf.so dynamically available. You can either
+ install libprotobuf you built earlier, or set LD_LIBRARY_PATH:
+
+ $ export LD_LIBRARY_PATH=../src/.libs
+ or
+ $ (cd .. && make install)
+
+ To build the C++ implementation run:
+ $ python setup.py build --cpp_implementation
+
+ Then run the tests like so:
$ python setup.py test --cpp_implementation
If some tests fail, this library may not work correctly on your
@@ -64,14 +97,17 @@ Installation
5) Install:
- $ python setup.py install
- or:
- $ python setup.py install --cpp_implementation
+ $ python setup.py install
+
+ or:
+
+ $ (cd .. && make install)
+ $ python setup.py install --cpp_implementation
This step may require superuser privileges.
- NOTE: To use C++ implementation, you need to install C++ protobuf runtime
- library of the same version and export the environment variable before this
- step. See the "C++ Implementation" section below for more details.
+ NOTE: To use C++ implementation, you need to export an environment
+ variable before running your program. See the "C++ Implementation"
+ section below for more details.
Usage
=====
@@ -87,19 +123,5 @@ C++ Implementation
The C++ implementation for Python messages is built as a Python extension to
improve the overall protobuf Python performance.
-To use the C++ implementation, you need to:
-1) Install the C++ protobuf runtime library, please see instructions in the
- parent directory.
-2) Export an environment variable:
-
- $ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=cpp
- $ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION=2
-
-You need to export this variable before running setup.py script to build and
-install the extension. You must also set the variable at runtime, otherwise
-the pure-Python implementation will be used. In a future release, we will
-change the default so that C++ implementation is used whenever it is available.
-It is strongly recommended to run `python setup.py test` after setting the
-variable to "cpp", so the tests will be against C++ implemented Python
-messages.
-
+To use the C++ implementation, you need to install the C++ protobuf runtime
+library, please see instructions in the parent directory.
diff --git a/python/ez_setup.py b/python/ez_setup.py
deleted file mode 100755
index 3aec98e48..000000000
--- a/python/ez_setup.py
+++ /dev/null
@@ -1,284 +0,0 @@
-#!python
-
-# This file was obtained from:
-# http://peak.telecommunity.com/dist/ez_setup.py
-# on 2011/1/21.
-
-"""Bootstrap setuptools installation
-
-If you want to use setuptools in your package's setup.py, just include this
-file in the same directory with it, and add this to the top of your setup.py::
-
- from ez_setup import use_setuptools
- use_setuptools()
-
-If you want to require a specific version of setuptools, set a download
-mirror, or use an alternate download directory, you can do so by supplying
-the appropriate options to ``use_setuptools()``.
-
-This file can also be run as a script to install or upgrade setuptools.
-"""
-import sys
-DEFAULT_VERSION = "0.6c11"
-DEFAULT_URL = "http://pypi.python.org/packages/%s/s/setuptools/" % sys.version[:3]
-
-md5_data = {
- 'setuptools-0.6b1-py2.3.egg': '8822caf901250d848b996b7f25c6e6ca',
- 'setuptools-0.6b1-py2.4.egg': 'b79a8a403e4502fbb85ee3f1941735cb',
- 'setuptools-0.6b2-py2.3.egg': '5657759d8a6d8fc44070a9d07272d99b',
- 'setuptools-0.6b2-py2.4.egg': '4996a8d169d2be661fa32a6e52e4f82a',
- 'setuptools-0.6b3-py2.3.egg': 'bb31c0fc7399a63579975cad9f5a0618',
- 'setuptools-0.6b3-py2.4.egg': '38a8c6b3d6ecd22247f179f7da669fac',
- 'setuptools-0.6b4-py2.3.egg': '62045a24ed4e1ebc77fe039aa4e6f7e5',
- 'setuptools-0.6b4-py2.4.egg': '4cb2a185d228dacffb2d17f103b3b1c4',
- 'setuptools-0.6c1-py2.3.egg': 'b3f2b5539d65cb7f74ad79127f1a908c',
- 'setuptools-0.6c1-py2.4.egg': 'b45adeda0667d2d2ffe14009364f2a4b',
- 'setuptools-0.6c10-py2.3.egg': 'ce1e2ab5d3a0256456d9fc13800a7090',
- 'setuptools-0.6c10-py2.4.egg': '57d6d9d6e9b80772c59a53a8433a5dd4',
- 'setuptools-0.6c10-py2.5.egg': 'de46ac8b1c97c895572e5e8596aeb8c7',
- 'setuptools-0.6c10-py2.6.egg': '58ea40aef06da02ce641495523a0b7f5',
- 'setuptools-0.6c11-py2.3.egg': '2baeac6e13d414a9d28e7ba5b5a596de',
- 'setuptools-0.6c11-py2.4.egg': 'bd639f9b0eac4c42497034dec2ec0c2b',
- 'setuptools-0.6c11-py2.5.egg': '64c94f3bf7a72a13ec83e0b24f2749b2',
- 'setuptools-0.6c11-py2.6.egg': 'bfa92100bd772d5a213eedd356d64086',
- 'setuptools-0.6c2-py2.3.egg': 'f0064bf6aa2b7d0f3ba0b43f20817c27',
- 'setuptools-0.6c2-py2.4.egg': '616192eec35f47e8ea16cd6a122b7277',
- 'setuptools-0.6c3-py2.3.egg': 'f181fa125dfe85a259c9cd6f1d7b78fa',
- 'setuptools-0.6c3-py2.4.egg': 'e0ed74682c998bfb73bf803a50e7b71e',
- 'setuptools-0.6c3-py2.5.egg': 'abef16fdd61955514841c7c6bd98965e',
- 'setuptools-0.6c4-py2.3.egg': 'b0b9131acab32022bfac7f44c5d7971f',
- 'setuptools-0.6c4-py2.4.egg': '2a1f9656d4fbf3c97bf946c0a124e6e2',
- 'setuptools-0.6c4-py2.5.egg': '8f5a052e32cdb9c72bcf4b5526f28afc',
- 'setuptools-0.6c5-py2.3.egg': 'ee9fd80965da04f2f3e6b3576e9d8167',
- 'setuptools-0.6c5-py2.4.egg': 'afe2adf1c01701ee841761f5bcd8aa64',
- 'setuptools-0.6c5-py2.5.egg': 'a8d3f61494ccaa8714dfed37bccd3d5d',
- 'setuptools-0.6c6-py2.3.egg': '35686b78116a668847237b69d549ec20',
- 'setuptools-0.6c6-py2.4.egg': '3c56af57be3225019260a644430065ab',
- 'setuptools-0.6c6-py2.5.egg': 'b2f8a7520709a5b34f80946de5f02f53',
- 'setuptools-0.6c7-py2.3.egg': '209fdf9adc3a615e5115b725658e13e2',
- 'setuptools-0.6c7-py2.4.egg': '5a8f954807d46a0fb67cf1f26c55a82e',
- 'setuptools-0.6c7-py2.5.egg': '45d2ad28f9750e7434111fde831e8372',
- 'setuptools-0.6c8-py2.3.egg': '50759d29b349db8cfd807ba8303f1902',
- 'setuptools-0.6c8-py2.4.egg': 'cba38d74f7d483c06e9daa6070cce6de',
- 'setuptools-0.6c8-py2.5.egg': '1721747ee329dc150590a58b3e1ac95b',
- 'setuptools-0.6c9-py2.3.egg': 'a83c4020414807b496e4cfbe08507c03',
- 'setuptools-0.6c9-py2.4.egg': '260a2be2e5388d66bdaee06abec6342a',
- 'setuptools-0.6c9-py2.5.egg': 'fe67c3e5a17b12c0e7c541b7ea43a8e6',
- 'setuptools-0.6c9-py2.6.egg': 'ca37b1ff16fa2ede6e19383e7b59245a',
-}
-
-import sys, os
-try: from hashlib import md5
-except ImportError: from md5 import md5
-
-def _validate_md5(egg_name, data):
- if egg_name in md5_data:
- digest = md5(data).hexdigest()
- if digest != md5_data[egg_name]:
- print >>sys.stderr, (
- "md5 validation of %s failed! (Possible download problem?)"
- % egg_name
- )
- sys.exit(2)
- return data
-
-def use_setuptools(
- version=DEFAULT_VERSION, download_base=DEFAULT_URL, to_dir=os.curdir,
- download_delay=15
-):
- """Automatically find/download setuptools and make it available on sys.path
-
- `version` should be a valid setuptools version number that is available
- as an egg for download under the `download_base` URL (which should end with
- a '/'). `to_dir` is the directory where setuptools will be downloaded, if
- it is not already available. If `download_delay` is specified, it should
- be the number of seconds that will be paused before initiating a download,
- should one be required. If an older version of setuptools is installed,
- this routine will print a message to ``sys.stderr`` and raise SystemExit in
- an attempt to abort the calling script.
- """
- was_imported = 'pkg_resources' in sys.modules or 'setuptools' in sys.modules
- def do_download():
- egg = download_setuptools(version, download_base, to_dir, download_delay)
- sys.path.insert(0, egg)
- import setuptools; setuptools.bootstrap_install_from = egg
- try:
- import pkg_resources
- except ImportError:
- return do_download()
- try:
- return do_download()
- pkg_resources.require("setuptools>="+version); return
- except pkg_resources.VersionConflict, e:
- if was_imported:
- print >>sys.stderr, (
- "The required version of setuptools (>=%s) is not available, and\n"
- "can't be installed while this script is running. Please install\n"
- " a more recent version first, using 'easy_install -U setuptools'."
- "\n\n(Currently using %r)"
- ) % (version, e.args[0])
- sys.exit(2)
- except pkg_resources.DistributionNotFound:
- pass
-
- del pkg_resources, sys.modules['pkg_resources'] # reload ok
- return do_download()
-
-def download_setuptools(
- version=DEFAULT_VERSION, download_base=DEFAULT_URL, to_dir=os.curdir,
- delay = 15
-):
- """Download setuptools from a specified location and return its filename
-
- `version` should be a valid setuptools version number that is available
- as an egg for download under the `download_base` URL (which should end
- with a '/'). `to_dir` is the directory where the egg will be downloaded.
- `delay` is the number of seconds to pause before an actual download attempt.
- """
- import urllib2, shutil
- egg_name = "setuptools-%s-py%s.egg" % (version,sys.version[:3])
- url = download_base + egg_name
- saveto = os.path.join(to_dir, egg_name)
- src = dst = None
- if not os.path.exists(saveto): # Avoid repeated downloads
- try:
- from distutils import log
- if delay:
- log.warn("""
----------------------------------------------------------------------------
-This script requires setuptools version %s to run (even to display
-help). I will attempt to download it for you (from
-%s), but
-you may need to enable firewall access for this script first.
-I will start the download in %d seconds.
-
-(Note: if this machine does not have network access, please obtain the file
-
- %s
-
-and place it in this directory before rerunning this script.)
----------------------------------------------------------------------------""",
- version, download_base, delay, url
- ); from time import sleep; sleep(delay)
- log.warn("Downloading %s", url)
- src = urllib2.urlopen(url)
- # Read/write all in one block, so we don't create a corrupt file
- # if the download is interrupted.
- data = _validate_md5(egg_name, src.read())
- dst = open(saveto,"wb"); dst.write(data)
- finally:
- if src: src.close()
- if dst: dst.close()
- return os.path.realpath(saveto)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-def main(argv, version=DEFAULT_VERSION):
- """Install or upgrade setuptools and EasyInstall"""
- try:
- import setuptools
- except ImportError:
- egg = None
- try:
- egg = download_setuptools(version, delay=0)
- sys.path.insert(0,egg)
- from setuptools.command.easy_install import main
- return main(list(argv)+[egg]) # we're done here
- finally:
- if egg and os.path.exists(egg):
- os.unlink(egg)
- else:
- if setuptools.__version__ == '0.0.1':
- print >>sys.stderr, (
- "You have an obsolete version of setuptools installed. Please\n"
- "remove it from your system entirely before rerunning this script."
- )
- sys.exit(2)
-
- req = "setuptools>="+version
- import pkg_resources
- try:
- pkg_resources.require(req)
- except pkg_resources.VersionConflict:
- try:
- from setuptools.command.easy_install import main
- except ImportError:
- from easy_install import main
- main(list(argv)+[download_setuptools(delay=0)])
- sys.exit(0) # try to force an exit
- else:
- if argv:
- from setuptools.command.easy_install import main
- main(argv)
- else:
- print "Setuptools version",version,"or greater has been installed."
- print '(Run "ez_setup.py -U setuptools" to reinstall or upgrade.)'
-
-def update_md5(filenames):
- """Update our built-in md5 registry"""
-
- import re
-
- for name in filenames:
- base = os.path.basename(name)
- f = open(name,'rb')
- md5_data[base] = md5(f.read()).hexdigest()
- f.close()
-
- data = [" %r: %r,\n" % it for it in md5_data.items()]
- data.sort()
- repl = "".join(data)
-
- import inspect
- srcfile = inspect.getsourcefile(sys.modules[__name__])
- f = open(srcfile, 'rb'); src = f.read(); f.close()
-
- match = re.search("\nmd5_data = {\n([^}]+)}", src)
- if not match:
- print >>sys.stderr, "Internal error!"
- sys.exit(2)
-
- src = src[:match.start(1)] + repl + src[match.end(1):]
- f = open(srcfile,'w')
- f.write(src)
- f.close()
-
-
-if __name__=='__main__':
- if len(sys.argv)>2 and sys.argv[1]=='--md5update':
- update_md5(sys.argv[2:])
- else:
- main(sys.argv[1:])
diff --git a/python/google/__init__.py b/python/google/__init__.py
index de40ea7ca..558561412 100755
--- a/python/google/__init__.py
+++ b/python/google/__init__.py
@@ -1 +1,4 @@
-__import__('pkg_resources').declare_namespace(__name__)
+try:
+ __import__('pkg_resources').declare_namespace(__name__)
+except ImportError:
+ __path__ = __import__('pkgutil').extend_path(__path__, __name__)
diff --git a/python/google/protobuf/__init__.py b/python/google/protobuf/__init__.py
index e69de29bb..2a3c6771a 100755
--- a/python/google/protobuf/__init__.py
+++ b/python/google/protobuf/__init__.py
@@ -0,0 +1,39 @@
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# https://developers.google.com/protocol-buffers/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+# Copyright 2007 Google Inc. All Rights Reserved.
+
+__version__ = '3.0.0b3'
+
+if __name__ != '__main__':
+ try:
+ __import__('pkg_resources').declare_namespace(__name__)
+ except ImportError:
+ __path__ = __import__('pkgutil').extend_path(__path__, __name__)
diff --git a/python/google/protobuf/descriptor.py b/python/google/protobuf/descriptor.py
index 6da8bb0bc..3209b34d5 100755
--- a/python/google/protobuf/descriptor.py
+++ b/python/google/protobuf/descriptor.py
@@ -28,28 +28,23 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-# Needs to stay compatible with Python 2.5 due to GAE.
-#
-# Copyright 2007 Google Inc. All Rights Reserved.
-
"""Descriptors essentially contain exactly the information found in a .proto
file, in types that make this information accessible in Python.
"""
__author__ = 'robinson@google.com (Will Robinson)'
-from google.protobuf.internal import api_implementation
+import six
+from google.protobuf.internal import api_implementation
+_USE_C_DESCRIPTORS = False
if api_implementation.Type() == 'cpp':
# Used by MakeDescriptor in cpp mode
import os
import uuid
-
- if api_implementation.Version() == 2:
- from google.protobuf.pyext import _message
- else:
- from google.protobuf.internal import cpp_message
+ from google.protobuf.pyext import _message
+ _USE_C_DESCRIPTORS = getattr(_message, '_USE_C_DESCRIPTORS', False)
class Error(Exception):
@@ -60,12 +55,29 @@ class TypeTransformationError(Error):
"""Error transforming between python proto type and corresponding C++ type."""
-class DescriptorBase(object):
+if _USE_C_DESCRIPTORS:
+ # This metaclass allows to override the behavior of code like
+ # isinstance(my_descriptor, FieldDescriptor)
+ # and make it return True when the descriptor is an instance of the extension
+ # type written in C++.
+ class DescriptorMetaclass(type):
+ def __instancecheck__(cls, obj):
+ if super(DescriptorMetaclass, cls).__instancecheck__(obj):
+ return True
+ if isinstance(obj, cls._C_DESCRIPTOR_CLASS):
+ return True
+ return False
+else:
+ # The standard metaclass; nothing changes.
+ DescriptorMetaclass = type
+
+
+class DescriptorBase(six.with_metaclass(DescriptorMetaclass)):
"""Descriptors base class.
This class is the base of all descriptor classes. It provides common options
- related functionaility.
+ related functionality.
Attributes:
has_options: True if the descriptor has non-default options. Usually it
@@ -75,6 +87,11 @@ class DescriptorBase(object):
avoid some bootstrapping issues.
"""
+ if _USE_C_DESCRIPTORS:
+ # The class, or tuple of classes, that are considered as "virtual
+ # subclasses" of this descriptor class.
+ _C_DESCRIPTOR_CLASS = ()
+
def __init__(self, options, options_class_name):
"""Initialize the descriptor given its options message and the name of the
class of the options message. The name of the class is required in case
@@ -201,6 +218,9 @@ class Descriptor(_NestedDescriptorBase):
fields_by_name: (dict str -> FieldDescriptor) Same FieldDescriptor
objects as in |fields|, but indexed by "name" attribute in each
FieldDescriptor.
+ fields_by_camelcase_name: (dict str -> FieldDescriptor) Same
+ FieldDescriptor objects as in |fields|, but indexed by
+ "camelcase_name" attribute in each FieldDescriptor.
nested_types: (list of Descriptors) Descriptor references
for all protocol message types nested within this one.
@@ -224,9 +244,6 @@ class Descriptor(_NestedDescriptorBase):
is_extendable: Does this type define any extension ranges?
- options: (descriptor_pb2.MessageOptions) Protocol message options or None
- to use default message options.
-
oneofs: (list of OneofDescriptor) The list of descriptors for oneof fields
in this message.
oneofs_by_name: (dict str -> OneofDescriptor) Same objects as in |oneofs|,
@@ -235,13 +252,25 @@ class Descriptor(_NestedDescriptorBase):
file: (FileDescriptor) Reference to file descriptor.
"""
+ if _USE_C_DESCRIPTORS:
+ _C_DESCRIPTOR_CLASS = _message.Descriptor
+
+ def __new__(cls, name, full_name, filename, containing_type, fields,
+ nested_types, enum_types, extensions, options=None,
+ is_extendable=True, extension_ranges=None, oneofs=None,
+ file=None, serialized_start=None, serialized_end=None,
+ syntax=None):
+ _message.Message._CheckCalledFromGeneratedFile()
+ return _message.default_pool.FindMessageTypeByName(full_name)
+
# NOTE(tmarek): The file argument redefining a builtin is nothing we can
# fix right now since we don't know how many clients already rely on the
# name of the argument.
def __init__(self, name, full_name, filename, containing_type, fields,
nested_types, enum_types, extensions, options=None,
is_extendable=True, extension_ranges=None, oneofs=None,
- file=None, serialized_start=None, serialized_end=None): # pylint:disable=redefined-builtin
+ file=None, serialized_start=None, serialized_end=None,
+ syntax=None): # pylint:disable=redefined-builtin
"""Arguments to __init__() are as described in the description
of Descriptor fields above.
@@ -263,6 +292,7 @@ class Descriptor(_NestedDescriptorBase):
field.containing_type = self
self.fields_by_number = dict((f.number, f) for f in fields)
self.fields_by_name = dict((f.name, f) for f in fields)
+ self._fields_by_camelcase_name = None
self.nested_types = nested_types
for nested_type in nested_types:
@@ -286,6 +316,14 @@ class Descriptor(_NestedDescriptorBase):
self.oneofs_by_name = dict((o.name, o) for o in self.oneofs)
for oneof in self.oneofs:
oneof.containing_type = self
+ self.syntax = syntax or "proto2"
+
+ @property
+ def fields_by_camelcase_name(self):
+ if self._fields_by_camelcase_name is None:
+ self._fields_by_camelcase_name = dict(
+ (f.camelcase_name, f) for f in self.fields)
+ return self._fields_by_camelcase_name
def EnumValueName(self, enum, value):
"""Returns the string name of an enum value.
@@ -335,6 +373,7 @@ class FieldDescriptor(DescriptorBase):
name: (str) Name of this field, exactly as it appears in .proto.
full_name: (str) Name of this field, including containing scope. This is
particularly relevant for extensions.
+ camelcase_name: (str) Camelcase name of this field.
index: (int) Dense, 0-indexed index giving the order that this
field textually appears within its message in the .proto file.
number: (int) Tag number declared for this field in the .proto file.
@@ -452,6 +491,19 @@ class FieldDescriptor(DescriptorBase):
FIRST_RESERVED_FIELD_NUMBER = 19000
LAST_RESERVED_FIELD_NUMBER = 19999
+ if _USE_C_DESCRIPTORS:
+ _C_DESCRIPTOR_CLASS = _message.FieldDescriptor
+
+ def __new__(cls, name, full_name, index, number, type, cpp_type, label,
+ default_value, message_type, enum_type, containing_type,
+ is_extension, extension_scope, options=None,
+ has_default_value=True, containing_oneof=None):
+ _message.Message._CheckCalledFromGeneratedFile()
+ if is_extension:
+ return _message.default_pool.FindExtensionByName(full_name)
+ else:
+ return _message.default_pool.FindFieldByName(full_name)
+
def __init__(self, name, full_name, index, number, type, cpp_type, label,
default_value, message_type, enum_type, containing_type,
is_extension, extension_scope, options=None,
@@ -466,6 +518,7 @@ class FieldDescriptor(DescriptorBase):
super(FieldDescriptor, self).__init__(options, 'FieldOptions')
self.name = name
self.full_name = full_name
+ self._camelcase_name = None
self.index = index
self.number = number
self.type = type
@@ -481,23 +534,18 @@ class FieldDescriptor(DescriptorBase):
self.containing_oneof = containing_oneof
if api_implementation.Type() == 'cpp':
if is_extension:
- if api_implementation.Version() == 2:
- # pylint: disable=protected-access
- self._cdescriptor = (
- _message.Message._GetExtensionDescriptor(full_name))
- # pylint: enable=protected-access
- else:
- self._cdescriptor = cpp_message.GetExtensionDescriptor(full_name)
+ self._cdescriptor = _message.default_pool.FindExtensionByName(full_name)
else:
- if api_implementation.Version() == 2:
- # pylint: disable=protected-access
- self._cdescriptor = _message.Message._GetFieldDescriptor(full_name)
- # pylint: enable=protected-access
- else:
- self._cdescriptor = cpp_message.GetFieldDescriptor(full_name)
+ self._cdescriptor = _message.default_pool.FindFieldByName(full_name)
else:
self._cdescriptor = None
+ @property
+ def camelcase_name(self):
+ if self._camelcase_name is None:
+ self._camelcase_name = _ToCamelCase(self.name)
+ return self._camelcase_name
+
@staticmethod
def ProtoTypeToCppProtoType(proto_type):
"""Converts from a Python proto type to a C++ Proto Type.
@@ -544,6 +592,15 @@ class EnumDescriptor(_NestedDescriptorBase):
None to use default enum options.
"""
+ if _USE_C_DESCRIPTORS:
+ _C_DESCRIPTOR_CLASS = _message.EnumDescriptor
+
+ def __new__(cls, name, full_name, filename, values,
+ containing_type=None, options=None, file=None,
+ serialized_start=None, serialized_end=None):
+ _message.Message._CheckCalledFromGeneratedFile()
+ return _message.default_pool.FindEnumTypeByName(full_name)
+
def __init__(self, name, full_name, filename, values,
containing_type=None, options=None, file=None,
serialized_start=None, serialized_end=None):
@@ -588,6 +645,17 @@ class EnumValueDescriptor(DescriptorBase):
None to use default enum value options options.
"""
+ if _USE_C_DESCRIPTORS:
+ _C_DESCRIPTOR_CLASS = _message.EnumValueDescriptor
+
+ def __new__(cls, name, index, number, type=None, options=None):
+ _message.Message._CheckCalledFromGeneratedFile()
+ # There is no way we can build a complete EnumValueDescriptor with the
+ # given parameters (the name of the Enum is not known, for example).
+ # Fortunately generated files just pass it to the EnumDescriptor()
+ # constructor, which will ignore it, so returning None is good enough.
+ return None
+
def __init__(self, name, index, number, type=None, options=None):
"""Arguments are as described in the attribute description above."""
super(EnumValueDescriptor, self).__init__(options, 'EnumValueOptions')
@@ -611,6 +679,13 @@ class OneofDescriptor(object):
oneof can contain.
"""
+ if _USE_C_DESCRIPTORS:
+ _C_DESCRIPTOR_CLASS = _message.OneofDescriptor
+
+ def __new__(cls, name, full_name, index, containing_type, fields):
+ _message.Message._CheckCalledFromGeneratedFile()
+ return _message.default_pool.FindOneofByName(full_name)
+
def __init__(self, name, full_name, index, containing_type, fields):
"""Arguments are as described in the attribute description above."""
self.name = name
@@ -704,36 +779,58 @@ class FileDescriptor(DescriptorBase):
name: name of file, relative to root of source tree.
package: name of the package
+ syntax: string indicating syntax of the file (can be "proto2" or "proto3")
serialized_pb: (str) Byte string of serialized
descriptor_pb2.FileDescriptorProto.
dependencies: List of other FileDescriptors this FileDescriptor depends on.
+ public_dependencies: A list of FileDescriptors, subset of the dependencies
+ above, which were declared as "public".
message_types_by_name: Dict of message names of their descriptors.
enum_types_by_name: Dict of enum names and their descriptors.
extensions_by_name: Dict of extension names and their descriptors.
+ pool: the DescriptorPool this descriptor belongs to. When not passed to the
+ constructor, the global default pool is used.
"""
+ if _USE_C_DESCRIPTORS:
+ _C_DESCRIPTOR_CLASS = _message.FileDescriptor
+
+ def __new__(cls, name, package, options=None, serialized_pb=None,
+ dependencies=None, public_dependencies=None,
+ syntax=None, pool=None):
+ # FileDescriptor() is called from various places, not only from generated
+ # files, to register dynamic proto files and messages.
+ if serialized_pb:
+ # TODO(amauryfa): use the pool passed as argument. This will work only
+ # for C++-implemented DescriptorPools.
+ return _message.default_pool.AddSerializedFile(serialized_pb)
+ else:
+ return super(FileDescriptor, cls).__new__(cls)
+
def __init__(self, name, package, options=None, serialized_pb=None,
- dependencies=None):
+ dependencies=None, public_dependencies=None,
+ syntax=None, pool=None):
"""Constructor."""
super(FileDescriptor, self).__init__(options, 'FileOptions')
+ if pool is None:
+ from google.protobuf import descriptor_pool
+ pool = descriptor_pool.Default()
+ self.pool = pool
self.message_types_by_name = {}
self.name = name
self.package = package
+ self.syntax = syntax or "proto2"
self.serialized_pb = serialized_pb
self.enum_types_by_name = {}
self.extensions_by_name = {}
self.dependencies = (dependencies or [])
+ self.public_dependencies = (public_dependencies or [])
if (api_implementation.Type() == 'cpp' and
self.serialized_pb is not None):
- if api_implementation.Version() == 2:
- # pylint: disable=protected-access
- _message.Message._BuildFile(self.serialized_pb)
- # pylint: enable=protected-access
- else:
- cpp_message.BuildFile(self.serialized_pb)
+ _message.default_pool.AddSerializedFile(self.serialized_pb)
def CopyToProto(self, proto):
"""Copies this to a descriptor_pb2.FileDescriptorProto.
@@ -754,7 +851,29 @@ def _ParseOptions(message, string):
return message
-def MakeDescriptor(desc_proto, package='', build_file_if_cpp=True):
+def _ToCamelCase(name):
+ """Converts name to camel-case and returns it."""
+ capitalize_next = False
+ result = []
+
+ for c in name:
+ if c == '_':
+ if result:
+ capitalize_next = True
+ elif capitalize_next:
+ result.append(c.upper())
+ capitalize_next = False
+ else:
+ result += c
+
+ # Lower-case the first letter.
+ if result and result[0].isupper():
+ result[0] = result[0].lower()
+ return ''.join(result)
+
+
+def MakeDescriptor(desc_proto, package='', build_file_if_cpp=True,
+ syntax=None):
"""Make a protobuf Descriptor given a DescriptorProto protobuf.
Handles nested descriptors. Note that this is limited to the scope of defining
@@ -766,6 +885,8 @@ def MakeDescriptor(desc_proto, package='', build_file_if_cpp=True):
package: Optional package name for the new message Descriptor (string).
build_file_if_cpp: Update the C++ descriptor pool if api matches.
Set to False on recursion, so no duplicates are created.
+ syntax: The syntax/semantics that should be used. Set to "proto3" to get
+ proto3 field presence semantics.
Returns:
A Descriptor for protobuf messages.
"""
@@ -778,10 +899,10 @@ def MakeDescriptor(desc_proto, package='', build_file_if_cpp=True):
file_descriptor_proto = descriptor_pb2.FileDescriptorProto()
file_descriptor_proto.message_type.add().MergeFrom(desc_proto)
- # Generate a random name for this proto file to prevent conflicts with
- # any imported ones. We need to specify a file name so BuildFile accepts
- # our FileDescriptorProto, but it is not important what that file name
- # is actually set to.
+ # Generate a random name for this proto file to prevent conflicts with any
+ # imported ones. We need to specify a file name so the descriptor pool
+ # accepts our FileDescriptorProto, but it is not important what that file
+ # name is actually set to.
proto_name = str(uuid.uuid4())
if package:
@@ -791,12 +912,11 @@ def MakeDescriptor(desc_proto, package='', build_file_if_cpp=True):
else:
file_descriptor_proto.name = proto_name + '.proto'
- if api_implementation.Version() == 2:
- # pylint: disable=protected-access
- _message.Message._BuildFile(file_descriptor_proto.SerializeToString())
- # pylint: enable=protected-access
- else:
- cpp_message.BuildFile(file_descriptor_proto.SerializeToString())
+ _message.default_pool.Add(file_descriptor_proto)
+ result = _message.default_pool.FindFileByName(file_descriptor_proto.name)
+
+ if _USE_C_DESCRIPTORS:
+ return result.message_types_by_name[desc_proto.name]
full_message_name = [desc_proto.name]
if package: full_message_name.insert(0, package)
@@ -819,7 +939,8 @@ def MakeDescriptor(desc_proto, package='', build_file_if_cpp=True):
# used by fields in the message, so no loops are possible here.
nested_desc = MakeDescriptor(nested_proto,
package='.'.join(full_message_name),
- build_file_if_cpp=False)
+ build_file_if_cpp=False,
+ syntax=syntax)
nested_types[full_name] = nested_desc
fields = []
@@ -841,9 +962,10 @@ def MakeDescriptor(desc_proto, package='', build_file_if_cpp=True):
field_proto.number, field_proto.type,
FieldDescriptor.ProtoTypeToCppProtoType(field_proto.type),
field_proto.label, None, nested_desc, enum_desc, None, False, None,
- has_default_value=False)
+ options=field_proto.options, has_default_value=False)
fields.append(field)
desc_name = '.'.join(full_message_name)
return Descriptor(desc_proto.name, desc_name, None, None, fields,
- nested_types.values(), enum_types.values(), [])
+ list(nested_types.values()), list(enum_types.values()), [],
+ options=desc_proto.options)
diff --git a/python/google/protobuf/descriptor_database.py b/python/google/protobuf/descriptor_database.py
index 55fb8c707..1333f9966 100644
--- a/python/google/protobuf/descriptor_database.py
+++ b/python/google/protobuf/descriptor_database.py
@@ -65,6 +65,7 @@ class DescriptorDatabase(object):
raise DescriptorDatabaseConflictingDefinitionError(
'%s already added, but with different descriptor.' % proto_name)
+ # Add the top-level Message, Enum and Extension descriptors to the index.
package = file_desc_proto.package
for message in file_desc_proto.message_type:
self._file_desc_protos_by_symbol.update(
@@ -72,6 +73,9 @@ class DescriptorDatabase(object):
for enum in file_desc_proto.enum_type:
self._file_desc_protos_by_symbol[
'.'.join((package, enum.name))] = file_desc_proto
+ for extension in file_desc_proto.extension:
+ self._file_desc_protos_by_symbol[
+ '.'.join((package, extension.name))] = file_desc_proto
def FindFileByName(self, name):
"""Finds the file descriptor proto by file name.
@@ -133,5 +137,5 @@ def _ExtractSymbols(desc_proto, package):
for nested_type in desc_proto.nested_type:
for symbol in _ExtractSymbols(nested_type, message_name):
yield symbol
- for enum_type in desc_proto.enum_type:
- yield '.'.join((message_name, enum_type.name))
+ for enum_type in desc_proto.enum_type:
+ yield '.'.join((message_name, enum_type.name))
diff --git a/python/google/protobuf/descriptor_pool.py b/python/google/protobuf/descriptor_pool.py
index cf234cfae..20a337017 100644
--- a/python/google/protobuf/descriptor_pool.py
+++ b/python/google/protobuf/descriptor_pool.py
@@ -57,13 +57,14 @@ directly instead of this class.
__author__ = 'matthewtoia@google.com (Matt Toia)'
-import sys
-
from google.protobuf import descriptor
from google.protobuf import descriptor_database
from google.protobuf import text_encoding
+_USE_C_DESCRIPTORS = descriptor._USE_C_DESCRIPTORS
+
+
def _NormalizeFullyQualifiedName(name):
"""Remove leading period from fully-qualified type name.
@@ -82,6 +83,12 @@ def _NormalizeFullyQualifiedName(name):
class DescriptorPool(object):
"""A collection of protobufs dynamically constructed by descriptor protos."""
+ if _USE_C_DESCRIPTORS:
+
+ def __new__(cls, descriptor_db=None):
+ # pylint: disable=protected-access
+ return descriptor._message.DescriptorPool(descriptor_db)
+
def __init__(self, descriptor_db=None):
"""Initializes a Pool of proto buffs.
@@ -110,6 +117,20 @@ class DescriptorPool(object):
self._internal_db.Add(file_desc_proto)
+ def AddSerializedFile(self, serialized_file_desc_proto):
+ """Adds the FileDescriptorProto and its types to this pool.
+
+ Args:
+ serialized_file_desc_proto: A bytes string, serialization of the
+ FileDescriptorProto to add.
+ """
+
+ # pylint: disable=g-import-not-at-top
+ from google.protobuf import descriptor_pb2
+ file_desc_proto = descriptor_pb2.FileDescriptorProto.FromString(
+ serialized_file_desc_proto)
+ self.Add(file_desc_proto)
+
def AddDescriptor(self, desc):
"""Adds a Descriptor to the pool, non-recursively.
@@ -175,8 +196,7 @@ class DescriptorPool(object):
try:
file_proto = self._internal_db.FindFileByName(file_name)
- except KeyError:
- _, error, _ = sys.exc_info() #PY25 compatible for GAE.
+ except KeyError as error:
if self._descriptor_db:
file_proto = self._descriptor_db.FindFileByName(file_name)
else:
@@ -211,8 +231,7 @@ class DescriptorPool(object):
try:
file_proto = self._internal_db.FindFileContainingSymbol(symbol)
- except KeyError:
- _, error, _ = sys.exc_info() #PY25 compatible for GAE.
+ except KeyError as error:
if self._descriptor_db:
file_proto = self._descriptor_db.FindFileContainingSymbol(symbol)
else:
@@ -251,6 +270,39 @@ class DescriptorPool(object):
self.FindFileContainingSymbol(full_name)
return self._enum_descriptors[full_name]
+ def FindFieldByName(self, full_name):
+ """Loads the named field descriptor from the pool.
+
+ Args:
+ full_name: The full name of the field descriptor to load.
+
+ Returns:
+ The field descriptor for the named field.
+ """
+ full_name = _NormalizeFullyQualifiedName(full_name)
+ message_name, _, field_name = full_name.rpartition('.')
+ message_descriptor = self.FindMessageTypeByName(message_name)
+ return message_descriptor.fields_by_name[field_name]
+
+ def FindExtensionByName(self, full_name):
+ """Loads the named extension descriptor from the pool.
+
+ Args:
+ full_name: The full name of the extension descriptor to load.
+
+ Returns:
+ A FieldDescriptor, describing the named extension.
+ """
+ full_name = _NormalizeFullyQualifiedName(full_name)
+ message_name, _, extension_name = full_name.rpartition('.')
+ try:
+ # Most extensions are nested inside a message.
+ scope = self.FindMessageTypeByName(message_name)
+ except KeyError:
+ # Some extensions are defined at file scope.
+ scope = self.FindFileContainingSymbol(full_name)
+ return scope.extensions_by_name[extension_name]
+
def _ConvertFileProtoToFileDescriptor(self, file_proto):
"""Creates a FileDescriptor from a proto or returns a cached copy.
@@ -267,62 +319,88 @@ class DescriptorPool(object):
if file_proto.name not in self._file_descriptors:
built_deps = list(self._GetDeps(file_proto.dependency))
direct_deps = [self.FindFileByName(n) for n in file_proto.dependency]
+ public_deps = [direct_deps[i] for i in file_proto.public_dependency]
file_descriptor = descriptor.FileDescriptor(
+ pool=self,
name=file_proto.name,
package=file_proto.package,
+ syntax=file_proto.syntax,
options=file_proto.options,
serialized_pb=file_proto.SerializeToString(),
- dependencies=direct_deps)
- scope = {}
-
- # This loop extracts all the message and enum types from all the
- # dependencoes of the file_proto. This is necessary to create the
- # scope of available message types when defining the passed in
- # file proto.
- for dependency in built_deps:
- scope.update(self._ExtractSymbols(
- dependency.message_types_by_name.values()))
- scope.update((_PrefixWithDot(enum.full_name), enum)
- for enum in dependency.enum_types_by_name.values())
-
- for message_type in file_proto.message_type:
- message_desc = self._ConvertMessageDescriptor(
- message_type, file_proto.package, file_descriptor, scope)
- file_descriptor.message_types_by_name[message_desc.name] = message_desc
-
- for enum_type in file_proto.enum_type:
- file_descriptor.enum_types_by_name[enum_type.name] = (
- self._ConvertEnumDescriptor(enum_type, file_proto.package,
- file_descriptor, None, scope))
-
- for index, extension_proto in enumerate(file_proto.extension):
- extension_desc = self.MakeFieldDescriptor(
- extension_proto, file_proto.package, index, is_extension=True)
- extension_desc.containing_type = self._GetTypeFromScope(
- file_descriptor.package, extension_proto.extendee, scope)
- self.SetFieldType(extension_proto, extension_desc,
- file_descriptor.package, scope)
- file_descriptor.extensions_by_name[extension_desc.name] = extension_desc
-
- for desc_proto in file_proto.message_type:
- self.SetAllFieldTypes(file_proto.package, desc_proto, scope)
-
- if file_proto.package:
- desc_proto_prefix = _PrefixWithDot(file_proto.package)
+ dependencies=direct_deps,
+ public_dependencies=public_deps)
+ if _USE_C_DESCRIPTORS:
+ # When using C++ descriptors, all objects defined in the file were added
+ # to the C++ database when the FileDescriptor was built above.
+ # Just add them to this descriptor pool.
+ def _AddMessageDescriptor(message_desc):
+ self._descriptors[message_desc.full_name] = message_desc
+ for nested in message_desc.nested_types:
+ _AddMessageDescriptor(nested)
+ for enum_type in message_desc.enum_types:
+ _AddEnumDescriptor(enum_type)
+ def _AddEnumDescriptor(enum_desc):
+ self._enum_descriptors[enum_desc.full_name] = enum_desc
+ for message_type in file_descriptor.message_types_by_name.values():
+ _AddMessageDescriptor(message_type)
+ for enum_type in file_descriptor.enum_types_by_name.values():
+ _AddEnumDescriptor(enum_type)
else:
- desc_proto_prefix = ''
+ scope = {}
+
+ # This loop extracts all the message and enum types from all the
+ # dependencies of the file_proto. This is necessary to create the
+ # scope of available message types when defining the passed in
+ # file proto.
+ for dependency in built_deps:
+ scope.update(self._ExtractSymbols(
+ dependency.message_types_by_name.values()))
+ scope.update((_PrefixWithDot(enum.full_name), enum)
+ for enum in dependency.enum_types_by_name.values())
+
+ for message_type in file_proto.message_type:
+ message_desc = self._ConvertMessageDescriptor(
+ message_type, file_proto.package, file_descriptor, scope,
+ file_proto.syntax)
+ file_descriptor.message_types_by_name[message_desc.name] = (
+ message_desc)
+
+ for enum_type in file_proto.enum_type:
+ file_descriptor.enum_types_by_name[enum_type.name] = (
+ self._ConvertEnumDescriptor(enum_type, file_proto.package,
+ file_descriptor, None, scope))
+
+ for index, extension_proto in enumerate(file_proto.extension):
+ extension_desc = self._MakeFieldDescriptor(
+ extension_proto, file_proto.package, index, is_extension=True)
+ extension_desc.containing_type = self._GetTypeFromScope(
+ file_descriptor.package, extension_proto.extendee, scope)
+ self._SetFieldType(extension_proto, extension_desc,
+ file_descriptor.package, scope)
+ file_descriptor.extensions_by_name[extension_desc.name] = (
+ extension_desc)
+
+ for desc_proto in file_proto.message_type:
+ self._SetAllFieldTypes(file_proto.package, desc_proto, scope)
+
+ if file_proto.package:
+ desc_proto_prefix = _PrefixWithDot(file_proto.package)
+ else:
+ desc_proto_prefix = ''
+
+ for desc_proto in file_proto.message_type:
+ desc = self._GetTypeFromScope(
+ desc_proto_prefix, desc_proto.name, scope)
+ file_descriptor.message_types_by_name[desc_proto.name] = desc
- for desc_proto in file_proto.message_type:
- desc = self._GetTypeFromScope(desc_proto_prefix, desc_proto.name, scope)
- file_descriptor.message_types_by_name[desc_proto.name] = desc
self.Add(file_proto)
self._file_descriptors[file_proto.name] = file_descriptor
return self._file_descriptors[file_proto.name]
def _ConvertMessageDescriptor(self, desc_proto, package=None, file_desc=None,
- scope=None):
+ scope=None, syntax=None):
"""Adds the proto to the pool in the specified package.
Args:
@@ -349,15 +427,17 @@ class DescriptorPool(object):
scope = {}
nested = [
- self._ConvertMessageDescriptor(nested, desc_name, file_desc, scope)
+ self._ConvertMessageDescriptor(
+ nested, desc_name, file_desc, scope, syntax)
for nested in desc_proto.nested_type]
enums = [
self._ConvertEnumDescriptor(enum, desc_name, file_desc, None, scope)
for enum in desc_proto.enum_type]
- fields = [self.MakeFieldDescriptor(field, desc_name, index)
+ fields = [self._MakeFieldDescriptor(field, desc_name, index)
for index, field in enumerate(desc_proto.field)]
extensions = [
- self.MakeFieldDescriptor(extension, desc_name, index, is_extension=True)
+ self._MakeFieldDescriptor(extension, desc_name, index,
+ is_extension=True)
for index, extension in enumerate(desc_proto.extension)]
oneofs = [
descriptor.OneofDescriptor(desc.name, '.'.join((desc_name, desc.name)),
@@ -383,7 +463,8 @@ class DescriptorPool(object):
extension_ranges=extension_ranges,
file=file_desc,
serialized_start=None,
- serialized_end=None)
+ serialized_end=None,
+ syntax=syntax)
for nested in desc.nested_types:
nested.containing_type = desc
for enum in desc.enum_types:
@@ -436,8 +517,8 @@ class DescriptorPool(object):
self._enum_descriptors[enum_name] = desc
return desc
- def MakeFieldDescriptor(self, field_proto, message_name, index,
- is_extension=False):
+ def _MakeFieldDescriptor(self, field_proto, message_name, index,
+ is_extension=False):
"""Creates a field descriptor from a FieldDescriptorProto.
For message and enum type fields, this method will do a look up
@@ -478,7 +559,7 @@ class DescriptorPool(object):
extension_scope=None,
options=field_proto.options)
- def SetAllFieldTypes(self, package, desc_proto, scope):
+ def _SetAllFieldTypes(self, package, desc_proto, scope):
"""Sets all the descriptor's fields's types.
This method also sets the containing types on any extensions.
@@ -499,18 +580,18 @@ class DescriptorPool(object):
nested_package = '.'.join([package, desc_proto.name])
for field_proto, field_desc in zip(desc_proto.field, main_desc.fields):
- self.SetFieldType(field_proto, field_desc, nested_package, scope)
+ self._SetFieldType(field_proto, field_desc, nested_package, scope)
for extension_proto, extension_desc in (
zip(desc_proto.extension, main_desc.extensions)):
extension_desc.containing_type = self._GetTypeFromScope(
nested_package, extension_proto.extendee, scope)
- self.SetFieldType(extension_proto, extension_desc, nested_package, scope)
+ self._SetFieldType(extension_proto, extension_desc, nested_package, scope)
for nested_type in desc_proto.nested_type:
- self.SetAllFieldTypes(nested_package, nested_type, scope)
+ self._SetAllFieldTypes(nested_package, nested_type, scope)
- def SetFieldType(self, field_proto, field_desc, package, scope):
+ def _SetFieldType(self, field_proto, field_desc, package, scope):
"""Sets the field's type, cpp_type, message_type and enum_type.
Args:
@@ -554,15 +635,29 @@ class DescriptorPool(object):
field_desc.default_value = field_proto.default_value.lower() == 'true'
elif field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM:
field_desc.default_value = field_desc.enum_type.values_by_name[
- field_proto.default_value].index
+ field_proto.default_value].number
elif field_proto.type == descriptor.FieldDescriptor.TYPE_BYTES:
field_desc.default_value = text_encoding.CUnescape(
field_proto.default_value)
else:
+ # All other types are of the "int" type.
field_desc.default_value = int(field_proto.default_value)
else:
field_desc.has_default_value = False
- field_desc.default_value = None
+ if (field_proto.type == descriptor.FieldDescriptor.TYPE_DOUBLE or
+ field_proto.type == descriptor.FieldDescriptor.TYPE_FLOAT):
+ field_desc.default_value = 0.0
+ elif field_proto.type == descriptor.FieldDescriptor.TYPE_STRING:
+ field_desc.default_value = u''
+ elif field_proto.type == descriptor.FieldDescriptor.TYPE_BOOL:
+ field_desc.default_value = False
+ elif field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM:
+ field_desc.default_value = field_desc.enum_type.values[0].number
+ elif field_proto.type == descriptor.FieldDescriptor.TYPE_BYTES:
+ field_desc.default_value = b''
+ else:
+ # All other types are of the "int" type.
+ field_desc.default_value = 0
field_desc.type = field_proto.type
@@ -641,3 +736,16 @@ class DescriptorPool(object):
def _PrefixWithDot(name):
return name if name.startswith('.') else '.%s' % name
+
+
+if _USE_C_DESCRIPTORS:
+ # TODO(amauryfa): This pool could be constructed from Python code, when we
+ # support a flag like 'use_cpp_generated_pool=True'.
+ # pylint: disable=protected-access
+ _DEFAULT = descriptor._message.default_pool
+else:
+ _DEFAULT = DescriptorPool()
+
+
+def Default():
+ return _DEFAULT
diff --git a/python/google/protobuf/internal/_parameterized.py b/python/google/protobuf/internal/_parameterized.py
new file mode 100755
index 000000000..dea3f1997
--- /dev/null
+++ b/python/google/protobuf/internal/_parameterized.py
@@ -0,0 +1,443 @@
+#! /usr/bin/env python
+#
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# https://developers.google.com/protocol-buffers/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Adds support for parameterized tests to Python's unittest TestCase class.
+
+A parameterized test is a method in a test case that is invoked with different
+argument tuples.
+
+A simple example:
+
+ class AdditionExample(parameterized.ParameterizedTestCase):
+ @parameterized.Parameters(
+ (1, 2, 3),
+ (4, 5, 9),
+ (1, 1, 3))
+ def testAddition(self, op1, op2, result):
+ self.assertEqual(result, op1 + op2)
+
+
+Each invocation is a separate test case and properly isolated just
+like a normal test method, with its own setUp/tearDown cycle. In the
+example above, there are three separate testcases, one of which will
+fail due to an assertion error (1 + 1 != 3).
+
+Parameters for invididual test cases can be tuples (with positional parameters)
+or dictionaries (with named parameters):
+
+ class AdditionExample(parameterized.ParameterizedTestCase):
+ @parameterized.Parameters(
+ {'op1': 1, 'op2': 2, 'result': 3},
+ {'op1': 4, 'op2': 5, 'result': 9},
+ )
+ def testAddition(self, op1, op2, result):
+ self.assertEqual(result, op1 + op2)
+
+If a parameterized test fails, the error message will show the
+original test name (which is modified internally) and the arguments
+for the specific invocation, which are part of the string returned by
+the shortDescription() method on test cases.
+
+The id method of the test, used internally by the unittest framework,
+is also modified to show the arguments. To make sure that test names
+stay the same across several invocations, object representations like
+
+ >>> class Foo(object):
+ ... pass
+ >>> repr(Foo())
+ '<__main__.Foo object at 0x23d8610>'
+
+are turned into '<__main__.Foo>'. For even more descriptive names,
+especially in test logs, you can use the NamedParameters decorator. In
+this case, only tuples are supported, and the first parameters has to
+be a string (or an object that returns an apt name when converted via
+str()):
+
+ class NamedExample(parameterized.ParameterizedTestCase):
+ @parameterized.NamedParameters(
+ ('Normal', 'aa', 'aaa', True),
+ ('EmptyPrefix', '', 'abc', True),
+ ('BothEmpty', '', '', True))
+ def testStartsWith(self, prefix, string, result):
+ self.assertEqual(result, strings.startswith(prefix))
+
+Named tests also have the benefit that they can be run individually
+from the command line:
+
+ $ testmodule.py NamedExample.testStartsWithNormal
+ .
+ --------------------------------------------------------------------
+ Ran 1 test in 0.000s
+
+ OK
+
+Parameterized Classes
+=====================
+If invocation arguments are shared across test methods in a single
+ParameterizedTestCase class, instead of decorating all test methods
+individually, the class itself can be decorated:
+
+ @parameterized.Parameters(
+ (1, 2, 3)
+ (4, 5, 9))
+ class ArithmeticTest(parameterized.ParameterizedTestCase):
+ def testAdd(self, arg1, arg2, result):
+ self.assertEqual(arg1 + arg2, result)
+
+ def testSubtract(self, arg2, arg2, result):
+ self.assertEqual(result - arg1, arg2)
+
+Inputs from Iterables
+=====================
+If parameters should be shared across several test cases, or are dynamically
+created from other sources, a single non-tuple iterable can be passed into
+the decorator. This iterable will be used to obtain the test cases:
+
+ class AdditionExample(parameterized.ParameterizedTestCase):
+ @parameterized.Parameters(
+ c.op1, c.op2, c.result for c in testcases
+ )
+ def testAddition(self, op1, op2, result):
+ self.assertEqual(result, op1 + op2)
+
+
+Single-Argument Test Methods
+============================
+If a test method takes only one argument, the single argument does not need to
+be wrapped into a tuple:
+
+ class NegativeNumberExample(parameterized.ParameterizedTestCase):
+ @parameterized.Parameters(
+ -1, -3, -4, -5
+ )
+ def testIsNegative(self, arg):
+ self.assertTrue(IsNegative(arg))
+"""
+
+__author__ = 'tmarek@google.com (Torsten Marek)'
+
+import collections
+import functools
+import re
+import types
+try:
+ import unittest2 as unittest
+except ImportError:
+ import unittest
+import uuid
+
+import six
+
+ADDR_RE = re.compile(r'\<([a-zA-Z0-9_\-\.]+) object at 0x[a-fA-F0-9]+\>')
+_SEPARATOR = uuid.uuid1().hex
+_FIRST_ARG = object()
+_ARGUMENT_REPR = object()
+
+
+def _CleanRepr(obj):
+ return ADDR_RE.sub(r'<\1>', repr(obj))
+
+
+# Helper function formerly from the unittest module, removed from it in
+# Python 2.7.
+def _StrClass(cls):
+ return '%s.%s' % (cls.__module__, cls.__name__)
+
+
+def _NonStringIterable(obj):
+ return (isinstance(obj, collections.Iterable) and not
+ isinstance(obj, six.string_types))
+
+
+def _FormatParameterList(testcase_params):
+ if isinstance(testcase_params, collections.Mapping):
+ return ', '.join('%s=%s' % (argname, _CleanRepr(value))
+ for argname, value in testcase_params.items())
+ elif _NonStringIterable(testcase_params):
+ return ', '.join(map(_CleanRepr, testcase_params))
+ else:
+ return _FormatParameterList((testcase_params,))
+
+
+class _ParameterizedTestIter(object):
+ """Callable and iterable class for producing new test cases."""
+
+ def __init__(self, test_method, testcases, naming_type):
+ """Returns concrete test functions for a test and a list of parameters.
+
+ The naming_type is used to determine the name of the concrete
+ functions as reported by the unittest framework. If naming_type is
+ _FIRST_ARG, the testcases must be tuples, and the first element must
+ have a string representation that is a valid Python identifier.
+
+ Args:
+ test_method: The decorated test method.
+ testcases: (list of tuple/dict) A list of parameter
+ tuples/dicts for individual test invocations.
+ naming_type: The test naming type, either _NAMED or _ARGUMENT_REPR.
+ """
+ self._test_method = test_method
+ self.testcases = testcases
+ self._naming_type = naming_type
+
+ def __call__(self, *args, **kwargs):
+ raise RuntimeError('You appear to be running a parameterized test case '
+ 'without having inherited from parameterized.'
+ 'ParameterizedTestCase. This is bad because none of '
+ 'your test cases are actually being run.')
+
+ def __iter__(self):
+ test_method = self._test_method
+ naming_type = self._naming_type
+
+ def MakeBoundParamTest(testcase_params):
+ @functools.wraps(test_method)
+ def BoundParamTest(self):
+ if isinstance(testcase_params, collections.Mapping):
+ test_method(self, **testcase_params)
+ elif _NonStringIterable(testcase_params):
+ test_method(self, *testcase_params)
+ else:
+ test_method(self, testcase_params)
+
+ if naming_type is _FIRST_ARG:
+ # Signal the metaclass that the name of the test function is unique
+ # and descriptive.
+ BoundParamTest.__x_use_name__ = True
+ BoundParamTest.__name__ += str(testcase_params[0])
+ testcase_params = testcase_params[1:]
+ elif naming_type is _ARGUMENT_REPR:
+ # __x_extra_id__ is used to pass naming information to the __new__
+ # method of TestGeneratorMetaclass.
+ # The metaclass will make sure to create a unique, but nondescriptive
+ # name for this test.
+ BoundParamTest.__x_extra_id__ = '(%s)' % (
+ _FormatParameterList(testcase_params),)
+ else:
+ raise RuntimeError('%s is not a valid naming type.' % (naming_type,))
+
+ BoundParamTest.__doc__ = '%s(%s)' % (
+ BoundParamTest.__name__, _FormatParameterList(testcase_params))
+ if test_method.__doc__:
+ BoundParamTest.__doc__ += '\n%s' % (test_method.__doc__,)
+ return BoundParamTest
+ return (MakeBoundParamTest(c) for c in self.testcases)
+
+
+def _IsSingletonList(testcases):
+ """True iff testcases contains only a single non-tuple element."""
+ return len(testcases) == 1 and not isinstance(testcases[0], tuple)
+
+
+def _ModifyClass(class_object, testcases, naming_type):
+ assert not getattr(class_object, '_id_suffix', None), (
+ 'Cannot add parameters to %s,'
+ ' which already has parameterized methods.' % (class_object,))
+ class_object._id_suffix = id_suffix = {}
+ # We change the size of __dict__ while we iterate over it,
+ # which Python 3.x will complain about, so use copy().
+ for name, obj in class_object.__dict__.copy().items():
+ if (name.startswith(unittest.TestLoader.testMethodPrefix)
+ and isinstance(obj, types.FunctionType)):
+ delattr(class_object, name)
+ methods = {}
+ _UpdateClassDictForParamTestCase(
+ methods, id_suffix, name,
+ _ParameterizedTestIter(obj, testcases, naming_type))
+ for name, meth in methods.items():
+ setattr(class_object, name, meth)
+
+
+def _ParameterDecorator(naming_type, testcases):
+ """Implementation of the parameterization decorators.
+
+ Args:
+ naming_type: The naming type.
+ testcases: Testcase parameters.
+
+ Returns:
+ A function for modifying the decorated object.
+ """
+ def _Apply(obj):
+ if isinstance(obj, type):
+ _ModifyClass(
+ obj,
+ list(testcases) if not isinstance(testcases, collections.Sequence)
+ else testcases,
+ naming_type)
+ return obj
+ else:
+ return _ParameterizedTestIter(obj, testcases, naming_type)
+
+ if _IsSingletonList(testcases):
+ assert _NonStringIterable(testcases[0]), (
+ 'Single parameter argument must be a non-string iterable')
+ testcases = testcases[0]
+
+ return _Apply
+
+
+def Parameters(*testcases):
+ """A decorator for creating parameterized tests.
+
+ See the module docstring for a usage example.
+ Args:
+ *testcases: Parameters for the decorated method, either a single
+ iterable, or a list of tuples/dicts/objects (for tests
+ with only one argument).
+
+ Returns:
+ A test generator to be handled by TestGeneratorMetaclass.
+ """
+ return _ParameterDecorator(_ARGUMENT_REPR, testcases)
+
+
+def NamedParameters(*testcases):
+ """A decorator for creating parameterized tests.
+
+ See the module docstring for a usage example. The first element of
+ each parameter tuple should be a string and will be appended to the
+ name of the test method.
+
+ Args:
+ *testcases: Parameters for the decorated method, either a single
+ iterable, or a list of tuples.
+
+ Returns:
+ A test generator to be handled by TestGeneratorMetaclass.
+ """
+ return _ParameterDecorator(_FIRST_ARG, testcases)
+
+
+class TestGeneratorMetaclass(type):
+ """Metaclass for test cases with test generators.
+
+ A test generator is an iterable in a testcase that produces callables. These
+ callables must be single-argument methods. These methods are injected into
+ the class namespace and the original iterable is removed. If the name of the
+ iterable conforms to the test pattern, the injected methods will be picked
+ up as tests by the unittest framework.
+
+ In general, it is supposed to be used in conjuction with the
+ Parameters decorator.
+ """
+
+ def __new__(mcs, class_name, bases, dct):
+ dct['_id_suffix'] = id_suffix = {}
+ for name, obj in dct.items():
+ if (name.startswith(unittest.TestLoader.testMethodPrefix) and
+ _NonStringIterable(obj)):
+ iterator = iter(obj)
+ dct.pop(name)
+ _UpdateClassDictForParamTestCase(dct, id_suffix, name, iterator)
+
+ return type.__new__(mcs, class_name, bases, dct)
+
+
+def _UpdateClassDictForParamTestCase(dct, id_suffix, name, iterator):
+ """Adds individual test cases to a dictionary.
+
+ Args:
+ dct: The target dictionary.
+ id_suffix: The dictionary for mapping names to test IDs.
+ name: The original name of the test case.
+ iterator: The iterator generating the individual test cases.
+ """
+ for idx, func in enumerate(iterator):
+ assert callable(func), 'Test generators must yield callables, got %r' % (
+ func,)
+ if getattr(func, '__x_use_name__', False):
+ new_name = func.__name__
+ else:
+ new_name = '%s%s%d' % (name, _SEPARATOR, idx)
+ assert new_name not in dct, (
+ 'Name of parameterized test case "%s" not unique' % (new_name,))
+ dct[new_name] = func
+ id_suffix[new_name] = getattr(func, '__x_extra_id__', '')
+
+
+class ParameterizedTestCase(unittest.TestCase):
+ """Base class for test cases using the Parameters decorator."""
+ __metaclass__ = TestGeneratorMetaclass
+
+ def _OriginalName(self):
+ return self._testMethodName.split(_SEPARATOR)[0]
+
+ def __str__(self):
+ return '%s (%s)' % (self._OriginalName(), _StrClass(self.__class__))
+
+ def id(self): # pylint: disable=invalid-name
+ """Returns the descriptive ID of the test.
+
+ This is used internally by the unittesting framework to get a name
+ for the test to be used in reports.
+
+ Returns:
+ The test id.
+ """
+ return '%s.%s%s' % (_StrClass(self.__class__),
+ self._OriginalName(),
+ self._id_suffix.get(self._testMethodName, ''))
+
+
+def CoopParameterizedTestCase(other_base_class):
+ """Returns a new base class with a cooperative metaclass base.
+
+ This enables the ParameterizedTestCase to be used in combination
+ with other base classes that have custom metaclasses, such as
+ mox.MoxTestBase.
+
+ Only works with metaclasses that do not override type.__new__.
+
+ Example:
+
+ import google3
+ import mox
+
+ from google3.testing.pybase import parameterized
+
+ class ExampleTest(parameterized.CoopParameterizedTestCase(mox.MoxTestBase)):
+ ...
+
+ Args:
+ other_base_class: (class) A test case base class.
+
+ Returns:
+ A new class object.
+ """
+ metaclass = type(
+ 'CoopMetaclass',
+ (other_base_class.__metaclass__,
+ TestGeneratorMetaclass), {})
+ return metaclass(
+ 'CoopParameterizedTestCase',
+ (other_base_class, ParameterizedTestCase), {})
diff --git a/python/google/protobuf/internal/any_test.proto b/python/google/protobuf/internal/any_test.proto
new file mode 100644
index 000000000..cd641ca0b
--- /dev/null
+++ b/python/google/protobuf/internal/any_test.proto
@@ -0,0 +1,42 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// https://developers.google.com/protocol-buffers/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+// Author: jieluo@google.com (Jie Luo)
+
+syntax = "proto3";
+
+package google.protobuf.internal;
+
+import "google/protobuf/any.proto";
+
+message TestAny {
+ google.protobuf.Any value = 1;
+ int32 int_value = 2;
+}
diff --git a/python/google/protobuf/internal/api_implementation.cc b/python/google/protobuf/internal/api_implementation.cc
index 83db40b11..6db12e8dc 100644
--- a/python/google/protobuf/internal/api_implementation.cc
+++ b/python/google/protobuf/internal/api_implementation.cc
@@ -50,10 +50,7 @@ namespace python {
// and
// PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION=2
#ifdef PYTHON_PROTO2_CPP_IMPL_V1
-#if PY_MAJOR_VERSION >= 3
-#error "PYTHON_PROTO2_CPP_IMPL_V1 is not supported under Python 3."
-#endif
-static int kImplVersion = 1;
+#error "PYTHON_PROTO2_CPP_IMPL_V1 is no longer supported."
#else
#ifdef PYTHON_PROTO2_CPP_IMPL_V2
static int kImplVersion = 2;
@@ -62,14 +59,7 @@ static int kImplVersion = 2;
static int kImplVersion = 0;
#else
-// The defaults are set here. Python 3 uses the fast C++ APIv2 by default.
-// Python 2 still uses the Python version by default until some compatibility
-// issues can be worked around.
-#if PY_MAJOR_VERSION >= 3
-static int kImplVersion = 2;
-#else
-static int kImplVersion = 0;
-#endif
+static int kImplVersion = -1; // -1 means "Unspecified by compiler flags".
#endif // PYTHON_PROTO2_PYTHON_IMPL
#endif // PYTHON_PROTO2_CPP_IMPL_V2
diff --git a/python/google/protobuf/internal/api_implementation.py b/python/google/protobuf/internal/api_implementation.py
index f7926c16a..460a4a6c2 100755
--- a/python/google/protobuf/internal/api_implementation.py
+++ b/python/google/protobuf/internal/api_implementation.py
@@ -32,6 +32,7 @@
"""
import os
+import warnings
import sys
try:
@@ -40,14 +41,33 @@ try:
# The compile-time constants in the _api_implementation module can be used to
# switch to a certain implementation of the Python API at build time.
_api_version = _api_implementation.api_version
- del _api_implementation
+ _proto_extension_modules_exist_in_build = True
except ImportError:
- _api_version = 0
+ _api_version = -1 # Unspecified by compiler flags.
+ _proto_extension_modules_exist_in_build = False
+
+if _api_version == 1:
+ raise ValueError('api_version=1 is no longer supported.')
+if _api_version < 0: # Still unspecified?
+ try:
+ # The presence of this module in a build allows the proto implementation to
+ # be upgraded merely via build deps rather than a compiler flag or the
+ # runtime environment variable.
+ # pylint: disable=g-import-not-at-top
+ from google.protobuf import _use_fast_cpp_protos
+ # Work around a known issue in the classic bootstrap .par import hook.
+ if not _use_fast_cpp_protos:
+ raise ImportError('_use_fast_cpp_protos import succeeded but was None')
+ del _use_fast_cpp_protos
+ _api_version = 2
+ except ImportError:
+ if _proto_extension_modules_exist_in_build:
+ if sys.version_info[0] >= 3: # Python 3 defaults to C++ impl v2.
+ _api_version = 2
+ # TODO(b/17427486): Make Python 2 default to C++ impl v2.
_default_implementation_type = (
- 'python' if _api_version == 0 else 'cpp')
-_default_version_str = (
- '1' if _api_version <= 1 else '2')
+ 'python' if _api_version <= 0 else 'cpp')
# This environment variable can be used to switch to a certain implementation
# of the Python API, overriding the compile-time constants in the
@@ -59,18 +79,22 @@ _implementation_type = os.getenv('PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION',
if _implementation_type != 'python':
_implementation_type = 'cpp'
+if 'PyPy' in sys.version and _implementation_type == 'cpp':
+ warnings.warn('PyPy does not work yet with cpp protocol buffers. '
+ 'Falling back to the python implementation.')
+ _implementation_type = 'python'
+
# This environment variable can be used to switch between the two
# 'cpp' implementations, overriding the compile-time constants in the
-# _api_implementation module. Right now only 1 and 2 are valid values. Any other
-# value will be ignored.
+# _api_implementation module. Right now only '2' is supported. Any other
+# value will cause an error to be raised.
_implementation_version_str = os.getenv(
- 'PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION',
- _default_version_str)
+ 'PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION', '2')
-if _implementation_version_str not in ('1', '2'):
+if _implementation_version_str != '2':
raise ValueError(
- "unsupported PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION: '" +
- _implementation_version_str + "' (supported versions: 1, 2)"
+ 'unsupported PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION: "' +
+ _implementation_version_str + '" (supported versions: 2)'
)
_implementation_version = int(_implementation_version_str)
diff --git a/python/google/protobuf/internal/api_implementation_default_test.py b/python/google/protobuf/internal/api_implementation_default_test.py
deleted file mode 100644
index 78d5cf232..000000000
--- a/python/google/protobuf/internal/api_implementation_default_test.py
+++ /dev/null
@@ -1,63 +0,0 @@
-#! /usr/bin/python
-#
-# Protocol Buffers - Google's data interchange format
-# Copyright 2008 Google Inc. All rights reserved.
-# https://developers.google.com/protocol-buffers/
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following disclaimer
-# in the documentation and/or other materials provided with the
-# distribution.
-# * Neither the name of Google Inc. nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-"""Test that the api_implementation defaults are what we expect."""
-
-import os
-import sys
-# Clear environment implementation settings before the google3 imports.
-os.environ.pop('PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION', None)
-os.environ.pop('PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION', None)
-
-# pylint: disable=g-import-not-at-top
-from google.apputils import basetest
-from google.protobuf.internal import api_implementation
-
-
-class ApiImplementationDefaultTest(basetest.TestCase):
-
- if sys.version_info.major <= 2:
-
- def testThatPythonIsTheDefault(self):
- """If -DPYTHON_PROTO_*IMPL* was given at build time, this may fail."""
- self.assertEqual('python', api_implementation.Type())
-
- else:
-
- def testThatCppApiV2IsTheDefault(self):
- """If -DPYTHON_PROTO_*IMPL* was given at build time, this may fail."""
- self.assertEqual('cpp', api_implementation.Type())
- self.assertEqual(2, api_implementation.Version())
-
-
-if __name__ == '__main__':
- basetest.main()
diff --git a/python/google/protobuf/internal/containers.py b/python/google/protobuf/internal/containers.py
index 20bfa8570..97cdd848e 100755
--- a/python/google/protobuf/internal/containers.py
+++ b/python/google/protobuf/internal/containers.py
@@ -41,6 +41,146 @@ are:
__author__ = 'petar@google.com (Petar Petrov)'
+import collections
+import sys
+
+if sys.version_info[0] < 3:
+ # We would use collections.MutableMapping all the time, but in Python 2 it
+ # doesn't define __slots__. This causes two significant problems:
+ #
+ # 1. we can't disallow arbitrary attribute assignment, even if our derived
+ # classes *do* define __slots__.
+ #
+ # 2. we can't safely derive a C type from it without __slots__ defined (the
+ # interpreter expects to find a dict at tp_dictoffset, which we can't
+ # robustly provide. And we don't want an instance dict anyway.
+ #
+ # So this is the Python 2.7 definition of Mapping/MutableMapping functions
+ # verbatim, except that:
+ # 1. We declare __slots__.
+ # 2. We don't declare this as a virtual base class. The classes defined
+ # in collections are the interesting base classes, not us.
+ #
+ # Note: deriving from object is critical. It is the only thing that makes
+ # this a true type, allowing us to derive from it in C++ cleanly and making
+ # __slots__ properly disallow arbitrary element assignment.
+
+ class Mapping(object):
+ __slots__ = ()
+
+ def get(self, key, default=None):
+ try:
+ return self[key]
+ except KeyError:
+ return default
+
+ def __contains__(self, key):
+ try:
+ self[key]
+ except KeyError:
+ return False
+ else:
+ return True
+
+ def iterkeys(self):
+ return iter(self)
+
+ def itervalues(self):
+ for key in self:
+ yield self[key]
+
+ def iteritems(self):
+ for key in self:
+ yield (key, self[key])
+
+ def keys(self):
+ return list(self)
+
+ def items(self):
+ return [(key, self[key]) for key in self]
+
+ def values(self):
+ return [self[key] for key in self]
+
+ # Mappings are not hashable by default, but subclasses can change this
+ __hash__ = None
+
+ def __eq__(self, other):
+ if not isinstance(other, collections.Mapping):
+ return NotImplemented
+ return dict(self.items()) == dict(other.items())
+
+ def __ne__(self, other):
+ return not (self == other)
+
+ class MutableMapping(Mapping):
+ __slots__ = ()
+
+ __marker = object()
+
+ def pop(self, key, default=__marker):
+ try:
+ value = self[key]
+ except KeyError:
+ if default is self.__marker:
+ raise
+ return default
+ else:
+ del self[key]
+ return value
+
+ def popitem(self):
+ try:
+ key = next(iter(self))
+ except StopIteration:
+ raise KeyError
+ value = self[key]
+ del self[key]
+ return key, value
+
+ def clear(self):
+ try:
+ while True:
+ self.popitem()
+ except KeyError:
+ pass
+
+ def update(*args, **kwds):
+ if len(args) > 2:
+ raise TypeError("update() takes at most 2 positional "
+ "arguments ({} given)".format(len(args)))
+ elif not args:
+ raise TypeError("update() takes at least 1 argument (0 given)")
+ self = args[0]
+ other = args[1] if len(args) >= 2 else ()
+
+ if isinstance(other, Mapping):
+ for key in other:
+ self[key] = other[key]
+ elif hasattr(other, "keys"):
+ for key in other.keys():
+ self[key] = other[key]
+ else:
+ for key, value in other:
+ self[key] = value
+ for key, value in kwds.items():
+ self[key] = value
+
+ def setdefault(self, key, default=None):
+ try:
+ return self[key]
+ except KeyError:
+ self[key] = default
+ return default
+
+ collections.Mapping.register(Mapping)
+ collections.MutableMapping.register(MutableMapping)
+
+else:
+ # In Python 3 we can just use MutableMapping directly, because it defines
+ # __slots__.
+ MutableMapping = collections.MutableMapping
+
class BaseContainer(object):
@@ -119,15 +259,23 @@ class RepeatedScalarFieldContainer(BaseContainer):
self._message_listener.Modified()
def extend(self, elem_seq):
- """Extends by appending the given sequence. Similar to list.extend()."""
- if not elem_seq:
- return
+ """Extends by appending the given iterable. Similar to list.extend()."""
- new_values = []
- for elem in elem_seq:
- new_values.append(self._type_checker.CheckValue(elem))
- self._values.extend(new_values)
- self._message_listener.Modified()
+ if elem_seq is None:
+ return
+ try:
+ elem_seq_iter = iter(elem_seq)
+ except TypeError:
+ if not elem_seq:
+ # silently ignore falsy inputs :-/.
+ # TODO(ptucker): Deprecate this behavior. b/18413862
+ return
+ raise
+
+ new_values = [self._type_checker.CheckValue(elem) for elem in elem_seq_iter]
+ if new_values:
+ self._values.extend(new_values)
+ self._message_listener.Modified()
def MergeFrom(self, other):
"""Appends the contents of another repeated field of the same type to this
@@ -141,6 +289,12 @@ class RepeatedScalarFieldContainer(BaseContainer):
self._values.remove(elem)
self._message_listener.Modified()
+ def pop(self, key=-1):
+ """Removes and returns an item at a given index. Similar to list.pop()."""
+ value = self._values[key]
+ self.__delitem__(key)
+ return value
+
def __setitem__(self, key, value):
"""Sets the item on the specified position."""
if isinstance(key, slice): # PY3
@@ -183,6 +337,8 @@ class RepeatedScalarFieldContainer(BaseContainer):
# We are presumably comparing against some other sequence type.
return other == self._values
+collections.MutableSequence.register(BaseContainer)
+
class RepeatedCompositeFieldContainer(BaseContainer):
@@ -245,6 +401,12 @@ class RepeatedCompositeFieldContainer(BaseContainer):
self._values.remove(elem)
self._message_listener.Modified()
+ def pop(self, key=-1):
+ """Removes and returns an item at a given index. Similar to list.pop()."""
+ value = self._values[key]
+ self.__delitem__(key)
+ return value
+
def __getslice__(self, start, stop):
"""Retrieves the subset of items from between the specified indices."""
return self._values[start:stop]
@@ -267,3 +429,183 @@ class RepeatedCompositeFieldContainer(BaseContainer):
raise TypeError('Can only compare repeated composite fields against '
'other repeated composite fields.')
return self._values == other._values
+
+
+class ScalarMap(MutableMapping):
+
+ """Simple, type-checked, dict-like container for holding repeated scalars."""
+
+ # Disallows assignment to other attributes.
+ __slots__ = ['_key_checker', '_value_checker', '_values', '_message_listener']
+
+ def __init__(self, message_listener, key_checker, value_checker):
+ """
+ Args:
+ message_listener: A MessageListener implementation.
+ The ScalarMap will call this object's Modified() method when it
+ is modified.
+ key_checker: A type_checkers.ValueChecker instance to run on keys
+ inserted into this container.
+ value_checker: A type_checkers.ValueChecker instance to run on values
+ inserted into this container.
+ """
+ self._message_listener = message_listener
+ self._key_checker = key_checker
+ self._value_checker = value_checker
+ self._values = {}
+
+ def __getitem__(self, key):
+ try:
+ return self._values[key]
+ except KeyError:
+ key = self._key_checker.CheckValue(key)
+ val = self._value_checker.DefaultValue()
+ self._values[key] = val
+ return val
+
+ def __contains__(self, item):
+ # We check the key's type to match the strong-typing flavor of the API.
+ # Also this makes it easier to match the behavior of the C++ implementation.
+ self._key_checker.CheckValue(item)
+ return item in self._values
+
+ # We need to override this explicitly, because our defaultdict-like behavior
+ # will make the default implementation (from our base class) always insert
+ # the key.
+ def get(self, key, default=None):
+ if key in self:
+ return self[key]
+ else:
+ return default
+
+ def __setitem__(self, key, value):
+ checked_key = self._key_checker.CheckValue(key)
+ checked_value = self._value_checker.CheckValue(value)
+ self._values[checked_key] = checked_value
+ self._message_listener.Modified()
+
+ def __delitem__(self, key):
+ del self._values[key]
+ self._message_listener.Modified()
+
+ def __len__(self):
+ return len(self._values)
+
+ def __iter__(self):
+ return iter(self._values)
+
+ def __repr__(self):
+ return repr(self._values)
+
+ def MergeFrom(self, other):
+ self._values.update(other._values)
+ self._message_listener.Modified()
+
+ def InvalidateIterators(self):
+ # It appears that the only way to reliably invalidate iterators to
+ # self._values is to ensure that its size changes.
+ original = self._values
+ self._values = original.copy()
+ original[None] = None
+
+ # This is defined in the abstract base, but we can do it much more cheaply.
+ def clear(self):
+ self._values.clear()
+ self._message_listener.Modified()
+
+
+class MessageMap(MutableMapping):
+
+ """Simple, type-checked, dict-like container for with submessage values."""
+
+ # Disallows assignment to other attributes.
+ __slots__ = ['_key_checker', '_values', '_message_listener',
+ '_message_descriptor']
+
+ def __init__(self, message_listener, message_descriptor, key_checker):
+ """
+ Args:
+ message_listener: A MessageListener implementation.
+ The ScalarMap will call this object's Modified() method when it
+ is modified.
+ key_checker: A type_checkers.ValueChecker instance to run on keys
+ inserted into this container.
+ value_checker: A type_checkers.ValueChecker instance to run on values
+ inserted into this container.
+ """
+ self._message_listener = message_listener
+ self._message_descriptor = message_descriptor
+ self._key_checker = key_checker
+ self._values = {}
+
+ def __getitem__(self, key):
+ try:
+ return self._values[key]
+ except KeyError:
+ key = self._key_checker.CheckValue(key)
+ new_element = self._message_descriptor._concrete_class()
+ new_element._SetListener(self._message_listener)
+ self._values[key] = new_element
+ self._message_listener.Modified()
+
+ return new_element
+
+ def get_or_create(self, key):
+ """get_or_create() is an alias for getitem (ie. map[key]).
+
+ Args:
+ key: The key to get or create in the map.
+
+ This is useful in cases where you want to be explicit that the call is
+ mutating the map. This can avoid lint errors for statements like this
+ that otherwise would appear to be pointless statements:
+
+ msg.my_map[key]
+ """
+ return self[key]
+
+ # We need to override this explicitly, because our defaultdict-like behavior
+ # will make the default implementation (from our base class) always insert
+ # the key.
+ def get(self, key, default=None):
+ if key in self:
+ return self[key]
+ else:
+ return default
+
+ def __contains__(self, item):
+ return item in self._values
+
+ def __setitem__(self, key, value):
+ raise ValueError('May not set values directly, call my_map[key].foo = 5')
+
+ def __delitem__(self, key):
+ del self._values[key]
+ self._message_listener.Modified()
+
+ def __len__(self):
+ return len(self._values)
+
+ def __iter__(self):
+ return iter(self._values)
+
+ def __repr__(self):
+ return repr(self._values)
+
+ def MergeFrom(self, other):
+ for key in other:
+ self[key].MergeFrom(other[key])
+ # self._message_listener.Modified() not required here, because
+ # mutations to submessages already propagate.
+
+ def InvalidateIterators(self):
+ # It appears that the only way to reliably invalidate iterators to
+ # self._values is to ensure that its size changes.
+ original = self._values
+ self._values = original.copy()
+ original[None] = None
+
+ # This is defined in the abstract base, but we can do it much more cheaply.
+ def clear(self):
+ self._values.clear()
+ self._message_listener.Modified()
diff --git a/python/google/protobuf/internal/cpp_message.py b/python/google/protobuf/internal/cpp_message.py
deleted file mode 100755
index 0313cb0bc..000000000
--- a/python/google/protobuf/internal/cpp_message.py
+++ /dev/null
@@ -1,663 +0,0 @@
-# Protocol Buffers - Google's data interchange format
-# Copyright 2008 Google Inc. All rights reserved.
-# https://developers.google.com/protocol-buffers/
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following disclaimer
-# in the documentation and/or other materials provided with the
-# distribution.
-# * Neither the name of Google Inc. nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-"""Contains helper functions used to create protocol message classes from
-Descriptor objects at runtime backed by the protocol buffer C++ API.
-"""
-
-__author__ = 'petar@google.com (Petar Petrov)'
-
-import copy_reg
-import operator
-from google.protobuf.internal import _net_proto2___python
-from google.protobuf.internal import enum_type_wrapper
-from google.protobuf import message
-
-
-_LABEL_REPEATED = _net_proto2___python.LABEL_REPEATED
-_LABEL_OPTIONAL = _net_proto2___python.LABEL_OPTIONAL
-_CPPTYPE_MESSAGE = _net_proto2___python.CPPTYPE_MESSAGE
-_TYPE_MESSAGE = _net_proto2___python.TYPE_MESSAGE
-
-
-def GetDescriptorPool():
- """Creates a new DescriptorPool C++ object."""
- return _net_proto2___python.NewCDescriptorPool()
-
-
-_pool = GetDescriptorPool()
-
-
-def GetFieldDescriptor(full_field_name):
- """Searches for a field descriptor given a full field name."""
- return _pool.FindFieldByName(full_field_name)
-
-
-def BuildFile(content):
- """Registers a new proto file in the underlying C++ descriptor pool."""
- _net_proto2___python.BuildFile(content)
-
-
-def GetExtensionDescriptor(full_extension_name):
- """Searches for extension descriptor given a full field name."""
- return _pool.FindExtensionByName(full_extension_name)
-
-
-def NewCMessage(full_message_name):
- """Creates a new C++ protocol message by its name."""
- return _net_proto2___python.NewCMessage(full_message_name)
-
-
-def ScalarProperty(cdescriptor):
- """Returns a scalar property for the given descriptor."""
-
- def Getter(self):
- return self._cmsg.GetScalar(cdescriptor)
-
- def Setter(self, value):
- self._cmsg.SetScalar(cdescriptor, value)
-
- return property(Getter, Setter)
-
-
-def CompositeProperty(cdescriptor, message_type):
- """Returns a Python property the given composite field."""
-
- def Getter(self):
- sub_message = self._composite_fields.get(cdescriptor.name, None)
- if sub_message is None:
- cmessage = self._cmsg.NewSubMessage(cdescriptor)
- sub_message = message_type._concrete_class(__cmessage=cmessage)
- self._composite_fields[cdescriptor.name] = sub_message
- return sub_message
-
- return property(Getter)
-
-
-class RepeatedScalarContainer(object):
- """Container for repeated scalar fields."""
-
- __slots__ = ['_message', '_cfield_descriptor', '_cmsg']
-
- def __init__(self, msg, cfield_descriptor):
- self._message = msg
- self._cmsg = msg._cmsg
- self._cfield_descriptor = cfield_descriptor
-
- def append(self, value):
- self._cmsg.AddRepeatedScalar(
- self._cfield_descriptor, value)
-
- def extend(self, sequence):
- for element in sequence:
- self.append(element)
-
- def insert(self, key, value):
- values = self[slice(None, None, None)]
- values.insert(key, value)
- self._cmsg.AssignRepeatedScalar(self._cfield_descriptor, values)
-
- def remove(self, value):
- values = self[slice(None, None, None)]
- values.remove(value)
- self._cmsg.AssignRepeatedScalar(self._cfield_descriptor, values)
-
- def __setitem__(self, key, value):
- values = self[slice(None, None, None)]
- values[key] = value
- self._cmsg.AssignRepeatedScalar(self._cfield_descriptor, values)
-
- def __getitem__(self, key):
- return self._cmsg.GetRepeatedScalar(self._cfield_descriptor, key)
-
- def __delitem__(self, key):
- self._cmsg.DeleteRepeatedField(self._cfield_descriptor, key)
-
- def __len__(self):
- return len(self[slice(None, None, None)])
-
- def __eq__(self, other):
- if self is other:
- return True
- if not operator.isSequenceType(other):
- raise TypeError(
- 'Can only compare repeated scalar fields against sequences.')
- # We are presumably comparing against some other sequence type.
- return other == self[slice(None, None, None)]
-
- def __ne__(self, other):
- return not self == other
-
- def __hash__(self):
- raise TypeError('unhashable object')
-
- def sort(self, *args, **kwargs):
- # Maintain compatibility with the previous interface.
- if 'sort_function' in kwargs:
- kwargs['cmp'] = kwargs.pop('sort_function')
- self._cmsg.AssignRepeatedScalar(self._cfield_descriptor,
- sorted(self, *args, **kwargs))
-
-
-def RepeatedScalarProperty(cdescriptor):
- """Returns a Python property the given repeated scalar field."""
-
- def Getter(self):
- container = self._composite_fields.get(cdescriptor.name, None)
- if container is None:
- container = RepeatedScalarContainer(self, cdescriptor)
- self._composite_fields[cdescriptor.name] = container
- return container
-
- def Setter(self, new_value):
- raise AttributeError('Assignment not allowed to repeated field '
- '"%s" in protocol message object.' % cdescriptor.name)
-
- doc = 'Magic attribute generated for "%s" proto field.' % cdescriptor.name
- return property(Getter, Setter, doc=doc)
-
-
-class RepeatedCompositeContainer(object):
- """Container for repeated composite fields."""
-
- __slots__ = ['_message', '_subclass', '_cfield_descriptor', '_cmsg']
-
- def __init__(self, msg, cfield_descriptor, subclass):
- self._message = msg
- self._cmsg = msg._cmsg
- self._subclass = subclass
- self._cfield_descriptor = cfield_descriptor
-
- def add(self, **kwargs):
- cmessage = self._cmsg.AddMessage(self._cfield_descriptor)
- return self._subclass(__cmessage=cmessage, __owner=self._message, **kwargs)
-
- def extend(self, elem_seq):
- """Extends by appending the given sequence of elements of the same type
- as this one, copying each individual message.
- """
- for message in elem_seq:
- self.add().MergeFrom(message)
-
- def remove(self, value):
- # TODO(protocol-devel): This is inefficient as it needs to generate a
- # message pointer for each message only to do index(). Move this to a C++
- # extension function.
- self.__delitem__(self[slice(None, None, None)].index(value))
-
- def MergeFrom(self, other):
- for message in other[:]:
- self.add().MergeFrom(message)
-
- def __getitem__(self, key):
- cmessages = self._cmsg.GetRepeatedMessage(
- self._cfield_descriptor, key)
- subclass = self._subclass
- if not isinstance(cmessages, list):
- return subclass(__cmessage=cmessages, __owner=self._message)
-
- return [subclass(__cmessage=m, __owner=self._message) for m in cmessages]
-
- def __delitem__(self, key):
- self._cmsg.DeleteRepeatedField(
- self._cfield_descriptor, key)
-
- def __len__(self):
- return self._cmsg.FieldLength(self._cfield_descriptor)
-
- def __eq__(self, other):
- """Compares the current instance with another one."""
- if self is other:
- return True
- if not isinstance(other, self.__class__):
- raise TypeError('Can only compare repeated composite fields against '
- 'other repeated composite fields.')
- messages = self[slice(None, None, None)]
- other_messages = other[slice(None, None, None)]
- return messages == other_messages
-
- def __hash__(self):
- raise TypeError('unhashable object')
-
- def sort(self, cmp=None, key=None, reverse=False, **kwargs):
- # Maintain compatibility with the old interface.
- if cmp is None and 'sort_function' in kwargs:
- cmp = kwargs.pop('sort_function')
-
- # The cmp function, if provided, is passed the results of the key function,
- # so we only need to wrap one of them.
- if key is None:
- index_key = self.__getitem__
- else:
- index_key = lambda i: key(self[i])
-
- # Sort the list of current indexes by the underlying object.
- indexes = range(len(self))
- indexes.sort(cmp=cmp, key=index_key, reverse=reverse)
-
- # Apply the transposition.
- for dest, src in enumerate(indexes):
- if dest == src:
- continue
- self._cmsg.SwapRepeatedFieldElements(self._cfield_descriptor, dest, src)
- # Don't swap the same value twice.
- indexes[src] = src
-
-
-def RepeatedCompositeProperty(cdescriptor, message_type):
- """Returns a Python property for the given repeated composite field."""
-
- def Getter(self):
- container = self._composite_fields.get(cdescriptor.name, None)
- if container is None:
- container = RepeatedCompositeContainer(
- self, cdescriptor, message_type._concrete_class)
- self._composite_fields[cdescriptor.name] = container
- return container
-
- def Setter(self, new_value):
- raise AttributeError('Assignment not allowed to repeated field '
- '"%s" in protocol message object.' % cdescriptor.name)
-
- doc = 'Magic attribute generated for "%s" proto field.' % cdescriptor.name
- return property(Getter, Setter, doc=doc)
-
-
-class ExtensionDict(object):
- """Extension dictionary added to each protocol message."""
-
- def __init__(self, msg):
- self._message = msg
- self._cmsg = msg._cmsg
- self._values = {}
-
- def __setitem__(self, extension, value):
- from google.protobuf import descriptor
- if not isinstance(extension, descriptor.FieldDescriptor):
- raise KeyError('Bad extension %r.' % (extension,))
- cdescriptor = extension._cdescriptor
- if (cdescriptor.label != _LABEL_OPTIONAL or
- cdescriptor.cpp_type == _CPPTYPE_MESSAGE):
- raise TypeError('Extension %r is repeated and/or a composite type.' % (
- extension.full_name,))
- self._cmsg.SetScalar(cdescriptor, value)
- self._values[extension] = value
-
- def __getitem__(self, extension):
- from google.protobuf import descriptor
- if not isinstance(extension, descriptor.FieldDescriptor):
- raise KeyError('Bad extension %r.' % (extension,))
-
- cdescriptor = extension._cdescriptor
- if (cdescriptor.label != _LABEL_REPEATED and
- cdescriptor.cpp_type != _CPPTYPE_MESSAGE):
- return self._cmsg.GetScalar(cdescriptor)
-
- ext = self._values.get(extension, None)
- if ext is not None:
- return ext
-
- ext = self._CreateNewHandle(extension)
- self._values[extension] = ext
- return ext
-
- def ClearExtension(self, extension):
- from google.protobuf import descriptor
- if not isinstance(extension, descriptor.FieldDescriptor):
- raise KeyError('Bad extension %r.' % (extension,))
- self._cmsg.ClearFieldByDescriptor(extension._cdescriptor)
- if extension in self._values:
- del self._values[extension]
-
- def HasExtension(self, extension):
- from google.protobuf import descriptor
- if not isinstance(extension, descriptor.FieldDescriptor):
- raise KeyError('Bad extension %r.' % (extension,))
- return self._cmsg.HasFieldByDescriptor(extension._cdescriptor)
-
- def _FindExtensionByName(self, name):
- """Tries to find a known extension with the specified name.
-
- Args:
- name: Extension full name.
-
- Returns:
- Extension field descriptor.
- """
- return self._message._extensions_by_name.get(name, None)
-
- def _CreateNewHandle(self, extension):
- cdescriptor = extension._cdescriptor
- if (cdescriptor.label != _LABEL_REPEATED and
- cdescriptor.cpp_type == _CPPTYPE_MESSAGE):
- cmessage = self._cmsg.NewSubMessage(cdescriptor)
- return extension.message_type._concrete_class(__cmessage=cmessage)
-
- if cdescriptor.label == _LABEL_REPEATED:
- if cdescriptor.cpp_type == _CPPTYPE_MESSAGE:
- return RepeatedCompositeContainer(
- self._message, cdescriptor, extension.message_type._concrete_class)
- else:
- return RepeatedScalarContainer(self._message, cdescriptor)
- # This shouldn't happen!
- assert False
- return None
-
-
-def NewMessage(bases, message_descriptor, dictionary):
- """Creates a new protocol message *class*."""
- _AddClassAttributesForNestedExtensions(message_descriptor, dictionary)
- _AddEnumValues(message_descriptor, dictionary)
- _AddDescriptors(message_descriptor, dictionary)
- return bases
-
-
-def InitMessage(message_descriptor, cls):
- """Constructs a new message instance (called before instance's __init__)."""
- cls._extensions_by_name = {}
- _AddInitMethod(message_descriptor, cls)
- _AddMessageMethods(message_descriptor, cls)
- _AddPropertiesForExtensions(message_descriptor, cls)
- copy_reg.pickle(cls, lambda obj: (cls, (), obj.__getstate__()))
-
-
-def _AddDescriptors(message_descriptor, dictionary):
- """Sets up a new protocol message class dictionary.
-
- Args:
- message_descriptor: A Descriptor instance describing this message type.
- dictionary: Class dictionary to which we'll add a '__slots__' entry.
- """
- dictionary['__descriptors'] = {}
- for field in message_descriptor.fields:
- dictionary['__descriptors'][field.name] = GetFieldDescriptor(
- field.full_name)
-
- dictionary['__slots__'] = list(dictionary['__descriptors'].iterkeys()) + [
- '_cmsg', '_owner', '_composite_fields', 'Extensions', '_HACK_REFCOUNTS']
-
-
-def _AddEnumValues(message_descriptor, dictionary):
- """Sets class-level attributes for all enum fields defined in this message.
-
- Args:
- message_descriptor: Descriptor object for this message type.
- dictionary: Class dictionary that should be populated.
- """
- for enum_type in message_descriptor.enum_types:
- dictionary[enum_type.name] = enum_type_wrapper.EnumTypeWrapper(enum_type)
- for enum_value in enum_type.values:
- dictionary[enum_value.name] = enum_value.number
-
-
-def _AddClassAttributesForNestedExtensions(message_descriptor, dictionary):
- """Adds class attributes for the nested extensions."""
- extension_dict = message_descriptor.extensions_by_name
- for extension_name, extension_field in extension_dict.iteritems():
- assert extension_name not in dictionary
- dictionary[extension_name] = extension_field
-
-
-def _AddInitMethod(message_descriptor, cls):
- """Adds an __init__ method to cls."""
-
- # Create and attach message field properties to the message class.
- # This can be done just once per message class, since property setters and
- # getters are passed the message instance.
- # This makes message instantiation extremely fast, and at the same time it
- # doesn't require the creation of property objects for each message instance,
- # which saves a lot of memory.
- for field in message_descriptor.fields:
- field_cdescriptor = cls.__descriptors[field.name]
- if field.label == _LABEL_REPEATED:
- if field.cpp_type == _CPPTYPE_MESSAGE:
- value = RepeatedCompositeProperty(field_cdescriptor, field.message_type)
- else:
- value = RepeatedScalarProperty(field_cdescriptor)
- elif field.cpp_type == _CPPTYPE_MESSAGE:
- value = CompositeProperty(field_cdescriptor, field.message_type)
- else:
- value = ScalarProperty(field_cdescriptor)
- setattr(cls, field.name, value)
-
- # Attach a constant with the field number.
- constant_name = field.name.upper() + '_FIELD_NUMBER'
- setattr(cls, constant_name, field.number)
-
- def Init(self, **kwargs):
- """Message constructor."""
- cmessage = kwargs.pop('__cmessage', None)
- if cmessage:
- self._cmsg = cmessage
- else:
- self._cmsg = NewCMessage(message_descriptor.full_name)
-
- # Keep a reference to the owner, as the owner keeps a reference to the
- # underlying protocol buffer message.
- owner = kwargs.pop('__owner', None)
- if owner:
- self._owner = owner
-
- if message_descriptor.is_extendable:
- self.Extensions = ExtensionDict(self)
- else:
- # Reference counting in the C++ code is broken and depends on
- # the Extensions reference to keep this object alive during unit
- # tests (see b/4856052). Remove this once b/4945904 is fixed.
- self._HACK_REFCOUNTS = self
- self._composite_fields = {}
-
- for field_name, field_value in kwargs.iteritems():
- field_cdescriptor = self.__descriptors.get(field_name, None)
- if not field_cdescriptor:
- raise ValueError('Protocol message has no "%s" field.' % field_name)
- if field_cdescriptor.label == _LABEL_REPEATED:
- if field_cdescriptor.cpp_type == _CPPTYPE_MESSAGE:
- field_name = getattr(self, field_name)
- for val in field_value:
- field_name.add().MergeFrom(val)
- else:
- getattr(self, field_name).extend(field_value)
- elif field_cdescriptor.cpp_type == _CPPTYPE_MESSAGE:
- getattr(self, field_name).MergeFrom(field_value)
- else:
- setattr(self, field_name, field_value)
-
- Init.__module__ = None
- Init.__doc__ = None
- cls.__init__ = Init
-
-
-def _IsMessageSetExtension(field):
- """Checks if a field is a message set extension."""
- return (field.is_extension and
- field.containing_type.has_options and
- field.containing_type.GetOptions().message_set_wire_format and
- field.type == _TYPE_MESSAGE and
- field.message_type == field.extension_scope and
- field.label == _LABEL_OPTIONAL)
-
-
-def _AddMessageMethods(message_descriptor, cls):
- """Adds the methods to a protocol message class."""
- if message_descriptor.is_extendable:
-
- def ClearExtension(self, extension):
- self.Extensions.ClearExtension(extension)
-
- def HasExtension(self, extension):
- return self.Extensions.HasExtension(extension)
-
- def HasField(self, field_name):
- return self._cmsg.HasField(field_name)
-
- def ClearField(self, field_name):
- child_cmessage = None
- if field_name in self._composite_fields:
- child_field = self._composite_fields[field_name]
- del self._composite_fields[field_name]
-
- child_cdescriptor = self.__descriptors[field_name]
- # TODO(anuraag): Support clearing repeated message fields as well.
- if (child_cdescriptor.label != _LABEL_REPEATED and
- child_cdescriptor.cpp_type == _CPPTYPE_MESSAGE):
- child_field._owner = None
- child_cmessage = child_field._cmsg
-
- if child_cmessage is not None:
- self._cmsg.ClearField(field_name, child_cmessage)
- else:
- self._cmsg.ClearField(field_name)
-
- def Clear(self):
- cmessages_to_release = []
- for field_name, child_field in self._composite_fields.iteritems():
- child_cdescriptor = self.__descriptors[field_name]
- # TODO(anuraag): Support clearing repeated message fields as well.
- if (child_cdescriptor.label != _LABEL_REPEATED and
- child_cdescriptor.cpp_type == _CPPTYPE_MESSAGE):
- child_field._owner = None
- cmessages_to_release.append((child_cdescriptor, child_field._cmsg))
- self._composite_fields.clear()
- self._cmsg.Clear(cmessages_to_release)
-
- def IsInitialized(self, errors=None):
- if self._cmsg.IsInitialized():
- return True
- if errors is not None:
- errors.extend(self.FindInitializationErrors());
- return False
-
- def SerializeToString(self):
- if not self.IsInitialized():
- raise message.EncodeError(
- 'Message %s is missing required fields: %s' % (
- self._cmsg.full_name, ','.join(self.FindInitializationErrors())))
- return self._cmsg.SerializeToString()
-
- def SerializePartialToString(self):
- return self._cmsg.SerializePartialToString()
-
- def ParseFromString(self, serialized):
- self.Clear()
- self.MergeFromString(serialized)
-
- def MergeFromString(self, serialized):
- byte_size = self._cmsg.MergeFromString(serialized)
- if byte_size < 0:
- raise message.DecodeError('Unable to merge from string.')
- return byte_size
-
- def MergeFrom(self, msg):
- if not isinstance(msg, cls):
- raise TypeError(
- "Parameter to MergeFrom() must be instance of same class: "
- "expected %s got %s." % (cls.__name__, type(msg).__name__))
- self._cmsg.MergeFrom(msg._cmsg)
-
- def CopyFrom(self, msg):
- self._cmsg.CopyFrom(msg._cmsg)
-
- def ByteSize(self):
- return self._cmsg.ByteSize()
-
- def SetInParent(self):
- return self._cmsg.SetInParent()
-
- def ListFields(self):
- all_fields = []
- field_list = self._cmsg.ListFields()
- fields_by_name = cls.DESCRIPTOR.fields_by_name
- for is_extension, field_name in field_list:
- if is_extension:
- extension = cls._extensions_by_name[field_name]
- all_fields.append((extension, self.Extensions[extension]))
- else:
- field_descriptor = fields_by_name[field_name]
- all_fields.append(
- (field_descriptor, getattr(self, field_name)))
- all_fields.sort(key=lambda item: item[0].number)
- return all_fields
-
- def FindInitializationErrors(self):
- return self._cmsg.FindInitializationErrors()
-
- def __str__(self):
- return str(self._cmsg)
-
- def __eq__(self, other):
- if self is other:
- return True
- if not isinstance(other, self.__class__):
- return False
- return self.ListFields() == other.ListFields()
-
- def __ne__(self, other):
- return not self == other
-
- def __hash__(self):
- raise TypeError('unhashable object')
-
- def __unicode__(self):
- # Lazy import to prevent circular import when text_format imports this file.
- from google.protobuf import text_format
- return text_format.MessageToString(self, as_utf8=True).decode('utf-8')
-
- # Attach the local methods to the message class.
- for key, value in locals().copy().iteritems():
- if key not in ('key', 'value', '__builtins__', '__name__', '__doc__'):
- setattr(cls, key, value)
-
- # Static methods:
-
- def RegisterExtension(extension_handle):
- extension_handle.containing_type = cls.DESCRIPTOR
- cls._extensions_by_name[extension_handle.full_name] = extension_handle
-
- if _IsMessageSetExtension(extension_handle):
- # MessageSet extension. Also register under type name.
- cls._extensions_by_name[
- extension_handle.message_type.full_name] = extension_handle
- cls.RegisterExtension = staticmethod(RegisterExtension)
-
- def FromString(string):
- msg = cls()
- msg.MergeFromString(string)
- return msg
- cls.FromString = staticmethod(FromString)
-
-
-
-def _AddPropertiesForExtensions(message_descriptor, cls):
- """Adds properties for all fields in this protocol message type."""
- extension_dict = message_descriptor.extensions_by_name
- for extension_name, extension_field in extension_dict.iteritems():
- constant_name = extension_name.upper() + '_FIELD_NUMBER'
- setattr(cls, constant_name, extension_field.number)
diff --git a/python/google/protobuf/internal/decoder.py b/python/google/protobuf/internal/decoder.py
index a4b906086..31869e457 100755
--- a/python/google/protobuf/internal/decoder.py
+++ b/python/google/protobuf/internal/decoder.py
@@ -28,10 +28,6 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-#PY25 compatible for GAE.
-#
-# Copyright 2009 Google Inc. All Rights Reserved.
-
"""Code for decoding protocol buffer primitives.
This code is very similar to encoder.py -- read the docs for that module first.
@@ -85,8 +81,12 @@ we repeatedly read a tag, look up the corresponding decoder, and invoke it.
__author__ = 'kenton@google.com (Kenton Varda)'
import struct
-import sys ##PY25
-_PY2 = sys.version_info[0] < 3 ##PY25
+
+import six
+
+if six.PY3:
+ long = int
+
from google.protobuf.internal import encoder
from google.protobuf.internal import wire_format
from google.protobuf import message
@@ -114,14 +114,11 @@ def _VarintDecoder(mask, result_type):
decoder returns a (value, new_pos) pair.
"""
- local_ord = ord
- py2 = _PY2 ##PY25
-##!PY25 py2 = str is bytes
def DecodeVarint(buffer, pos):
result = 0
shift = 0
while 1:
- b = local_ord(buffer[pos]) if py2 else buffer[pos]
+ b = six.indexbytes(buffer, pos)
result |= ((b & 0x7f) << shift)
pos += 1
if not (b & 0x80):
@@ -137,14 +134,11 @@ def _VarintDecoder(mask, result_type):
def _SignedVarintDecoder(mask, result_type):
"""Like _VarintDecoder() but decodes signed values."""
- local_ord = ord
- py2 = _PY2 ##PY25
-##!PY25 py2 = str is bytes
def DecodeVarint(buffer, pos):
result = 0
shift = 0
while 1:
- b = local_ord(buffer[pos]) if py2 else buffer[pos]
+ b = six.indexbytes(buffer, pos)
result |= ((b & 0x7f) << shift)
pos += 1
if not (b & 0x80):
@@ -183,10 +177,8 @@ def ReadTag(buffer, pos):
use that, but not in Python.
"""
- py2 = _PY2 ##PY25
-##!PY25 py2 = str is bytes
start = pos
- while (ord(buffer[pos]) if py2 else buffer[pos]) & 0x80:
+ while six.indexbytes(buffer, pos) & 0x80:
pos += 1
pos += 1
return (buffer[start:pos], pos)
@@ -301,7 +293,6 @@ def _FloatDecoder():
"""
local_unpack = struct.unpack
- b = (lambda x:x) if _PY2 else lambda x:x.encode('latin1') ##PY25
def InnerDecode(buffer, pos):
# We expect a 32-bit value in little-endian byte order. Bit 1 is the sign
@@ -312,17 +303,12 @@ def _FloatDecoder():
# If this value has all its exponent bits set, then it's non-finite.
# In Python 2.4, struct.unpack will convert it to a finite 64-bit value.
# To avoid that, we parse it specially.
- if ((float_bytes[3:4] in b('\x7F\xFF')) ##PY25
-##!PY25 if ((float_bytes[3:4] in b'\x7F\xFF')
- and (float_bytes[2:3] >= b('\x80'))): ##PY25
-##!PY25 and (float_bytes[2:3] >= b'\x80')):
+ if (float_bytes[3:4] in b'\x7F\xFF' and float_bytes[2:3] >= b'\x80'):
# If at least one significand bit is set...
- if float_bytes[0:3] != b('\x00\x00\x80'): ##PY25
-##!PY25 if float_bytes[0:3] != b'\x00\x00\x80':
+ if float_bytes[0:3] != b'\x00\x00\x80':
return (_NAN, new_pos)
# If sign bit is set...
- if float_bytes[3:4] == b('\xFF'): ##PY25
-##!PY25 if float_bytes[3:4] == b'\xFF':
+ if float_bytes[3:4] == b'\xFF':
return (_NEG_INF, new_pos)
return (_POS_INF, new_pos)
@@ -341,7 +327,6 @@ def _DoubleDecoder():
"""
local_unpack = struct.unpack
- b = (lambda x:x) if _PY2 else lambda x:x.encode('latin1') ##PY25
def InnerDecode(buffer, pos):
# We expect a 64-bit value in little-endian byte order. Bit 1 is the sign
@@ -352,12 +337,9 @@ def _DoubleDecoder():
# If this value has all its exponent bits set and at least one significand
# bit set, it's not a number. In Python 2.4, struct.unpack will treat it
# as inf or -inf. To avoid that, we treat it specially.
-##!PY25 if ((double_bytes[7:8] in b'\x7F\xFF')
-##!PY25 and (double_bytes[6:7] >= b'\xF0')
-##!PY25 and (double_bytes[0:7] != b'\x00\x00\x00\x00\x00\x00\xF0')):
- if ((double_bytes[7:8] in b('\x7F\xFF')) ##PY25
- and (double_bytes[6:7] >= b('\xF0')) ##PY25
- and (double_bytes[0:7] != b('\x00\x00\x00\x00\x00\x00\xF0'))): ##PY25
+ if ((double_bytes[7:8] in b'\x7F\xFF')
+ and (double_bytes[6:7] >= b'\xF0')
+ and (double_bytes[0:7] != b'\x00\x00\x00\x00\x00\x00\xF0')):
return (_NAN, new_pos)
# Note that we expect someone up-stack to catch struct.error and convert
@@ -480,12 +462,12 @@ def StringDecoder(field_number, is_repeated, is_packed, key, new_default):
"""Returns a decoder for a string field."""
local_DecodeVarint = _DecodeVarint
- local_unicode = unicode
+ local_unicode = six.text_type
def _ConvertToUnicode(byte_str):
try:
return local_unicode(byte_str, 'utf-8')
- except UnicodeDecodeError, e:
+ except UnicodeDecodeError as e:
# add more information to the error message and re-raise it.
e.reason = '%s in field: %s' % (e, key.full_name)
raise
@@ -621,9 +603,6 @@ def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
if value is None:
value = field_dict.setdefault(key, new_default(message))
while 1:
- value = field_dict.get(key)
- if value is None:
- value = field_dict.setdefault(key, new_default(message))
# Read length.
(size, pos) = local_DecodeVarint(buffer, pos)
new_pos = pos + size
@@ -736,6 +715,50 @@ def MessageSetItemDecoder(extensions_by_number):
return DecodeItem
# --------------------------------------------------------------------
+
+def MapDecoder(field_descriptor, new_default, is_message_map):
+ """Returns a decoder for a map field."""
+
+ key = field_descriptor
+ tag_bytes = encoder.TagBytes(field_descriptor.number,
+ wire_format.WIRETYPE_LENGTH_DELIMITED)
+ tag_len = len(tag_bytes)
+ local_DecodeVarint = _DecodeVarint
+ # Can't read _concrete_class yet; might not be initialized.
+ message_type = field_descriptor.message_type
+
+ def DecodeMap(buffer, pos, end, message, field_dict):
+ submsg = message_type._concrete_class()
+ value = field_dict.get(key)
+ if value is None:
+ value = field_dict.setdefault(key, new_default(message))
+ while 1:
+ # Read length.
+ (size, pos) = local_DecodeVarint(buffer, pos)
+ new_pos = pos + size
+ if new_pos > end:
+ raise _DecodeError('Truncated message.')
+ # Read sub-message.
+ submsg.Clear()
+ if submsg._InternalParse(buffer, pos, new_pos) != new_pos:
+ # The only reason _InternalParse would return early is if it
+ # encountered an end-group tag.
+ raise _DecodeError('Unexpected end-group tag.')
+
+ if is_message_map:
+ value[submsg.key].MergeFrom(submsg.value)
+ else:
+ value[submsg.key] = submsg.value
+
+ # Predict that the next tag is another copy of the same repeated field.
+ pos = new_pos + tag_len
+ if buffer[new_pos:pos] != tag_bytes or new_pos == end:
+ # Prediction failed. Return.
+ return new_pos
+
+ return DecodeMap
+
+# --------------------------------------------------------------------
# Optimization is not as heavy here because calls to SkipField() are rare,
# except for handling end-group tags.
diff --git a/python/google/protobuf/internal/descriptor_database_test.py b/python/google/protobuf/internal/descriptor_database_test.py
index fc65b69ad..5225a4582 100644
--- a/python/google/protobuf/internal/descriptor_database_test.py
+++ b/python/google/protobuf/internal/descriptor_database_test.py
@@ -1,4 +1,4 @@
-#! /usr/bin/python
+#! /usr/bin/env python
#
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
@@ -34,13 +34,17 @@
__author__ = 'matthewtoia@google.com (Matt Toia)'
-from google.apputils import basetest
+try:
+ import unittest2 as unittest #PY26
+except ImportError:
+ import unittest
+
from google.protobuf import descriptor_pb2
from google.protobuf.internal import factory_test2_pb2
from google.protobuf import descriptor_database
-class DescriptorDatabaseTest(basetest.TestCase):
+class DescriptorDatabaseTest(unittest.TestCase):
def testAdd(self):
db = descriptor_database.DescriptorDatabase()
@@ -48,16 +52,18 @@ class DescriptorDatabaseTest(basetest.TestCase):
factory_test2_pb2.DESCRIPTOR.serialized_pb)
db.Add(file_desc_proto)
- self.assertEquals(file_desc_proto, db.FindFileByName(
+ self.assertEqual(file_desc_proto, db.FindFileByName(
'google/protobuf/internal/factory_test2.proto'))
- self.assertEquals(file_desc_proto, db.FindFileContainingSymbol(
+ self.assertEqual(file_desc_proto, db.FindFileContainingSymbol(
'google.protobuf.python.internal.Factory2Message'))
- self.assertEquals(file_desc_proto, db.FindFileContainingSymbol(
+ self.assertEqual(file_desc_proto, db.FindFileContainingSymbol(
'google.protobuf.python.internal.Factory2Message.NestedFactory2Message'))
- self.assertEquals(file_desc_proto, db.FindFileContainingSymbol(
+ self.assertEqual(file_desc_proto, db.FindFileContainingSymbol(
'google.protobuf.python.internal.Factory2Enum'))
- self.assertEquals(file_desc_proto, db.FindFileContainingSymbol(
+ self.assertEqual(file_desc_proto, db.FindFileContainingSymbol(
'google.protobuf.python.internal.Factory2Message.NestedFactory2Enum'))
+ self.assertEqual(file_desc_proto, db.FindFileContainingSymbol(
+ 'google.protobuf.python.internal.MessageWithNestedEnumOnly.NestedEnum'))
if __name__ == '__main__':
- basetest.main()
+ unittest.main()
diff --git a/python/google/protobuf/internal/descriptor_pool_test.py b/python/google/protobuf/internal/descriptor_pool_test.py
index d2f855798..6a13e0bcf 100644
--- a/python/google/protobuf/internal/descriptor_pool_test.py
+++ b/python/google/protobuf/internal/descriptor_pool_test.py
@@ -1,4 +1,4 @@
-#! /usr/bin/python
+#! /usr/bin/env python
#
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
@@ -35,9 +35,15 @@
__author__ = 'matthewtoia@google.com (Matt Toia)'
import os
-import unittest
+import sys
-from google.apputils import basetest
+try:
+ import unittest2 as unittest #PY26
+except ImportError:
+ import unittest
+
+from google.protobuf import unittest_import_pb2
+from google.protobuf import unittest_import_public_pb2
from google.protobuf import unittest_pb2
from google.protobuf import descriptor_pb2
from google.protobuf.internal import api_implementation
@@ -45,12 +51,15 @@ from google.protobuf.internal import descriptor_pool_test1_pb2
from google.protobuf.internal import descriptor_pool_test2_pb2
from google.protobuf.internal import factory_test1_pb2
from google.protobuf.internal import factory_test2_pb2
+from google.protobuf.internal import more_messages_pb2
from google.protobuf import descriptor
from google.protobuf import descriptor_database
from google.protobuf import descriptor_pool
+from google.protobuf import message_factory
+from google.protobuf import symbol_database
-class DescriptorPoolTest(basetest.TestCase):
+class DescriptorPoolTest(unittest.TestCase):
def setUp(self):
self.pool = descriptor_pool.DescriptorPool()
@@ -65,15 +74,15 @@ class DescriptorPoolTest(basetest.TestCase):
name1 = 'google/protobuf/internal/factory_test1.proto'
file_desc1 = self.pool.FindFileByName(name1)
self.assertIsInstance(file_desc1, descriptor.FileDescriptor)
- self.assertEquals(name1, file_desc1.name)
- self.assertEquals('google.protobuf.python.internal', file_desc1.package)
+ self.assertEqual(name1, file_desc1.name)
+ self.assertEqual('google.protobuf.python.internal', file_desc1.package)
self.assertIn('Factory1Message', file_desc1.message_types_by_name)
name2 = 'google/protobuf/internal/factory_test2.proto'
file_desc2 = self.pool.FindFileByName(name2)
self.assertIsInstance(file_desc2, descriptor.FileDescriptor)
- self.assertEquals(name2, file_desc2.name)
- self.assertEquals('google.protobuf.python.internal', file_desc2.package)
+ self.assertEqual(name2, file_desc2.name)
+ self.assertEqual('google.protobuf.python.internal', file_desc2.package)
self.assertIn('Factory2Message', file_desc2.message_types_by_name)
def testFindFileByNameFailure(self):
@@ -84,17 +93,17 @@ class DescriptorPoolTest(basetest.TestCase):
file_desc1 = self.pool.FindFileContainingSymbol(
'google.protobuf.python.internal.Factory1Message')
self.assertIsInstance(file_desc1, descriptor.FileDescriptor)
- self.assertEquals('google/protobuf/internal/factory_test1.proto',
- file_desc1.name)
- self.assertEquals('google.protobuf.python.internal', file_desc1.package)
+ self.assertEqual('google/protobuf/internal/factory_test1.proto',
+ file_desc1.name)
+ self.assertEqual('google.protobuf.python.internal', file_desc1.package)
self.assertIn('Factory1Message', file_desc1.message_types_by_name)
file_desc2 = self.pool.FindFileContainingSymbol(
'google.protobuf.python.internal.Factory2Message')
self.assertIsInstance(file_desc2, descriptor.FileDescriptor)
- self.assertEquals('google/protobuf/internal/factory_test2.proto',
- file_desc2.name)
- self.assertEquals('google.protobuf.python.internal', file_desc2.package)
+ self.assertEqual('google/protobuf/internal/factory_test2.proto',
+ file_desc2.name)
+ self.assertEqual('google.protobuf.python.internal', file_desc2.package)
self.assertIn('Factory2Message', file_desc2.message_types_by_name)
def testFindFileContainingSymbolFailure(self):
@@ -105,72 +114,72 @@ class DescriptorPoolTest(basetest.TestCase):
msg1 = self.pool.FindMessageTypeByName(
'google.protobuf.python.internal.Factory1Message')
self.assertIsInstance(msg1, descriptor.Descriptor)
- self.assertEquals('Factory1Message', msg1.name)
- self.assertEquals('google.protobuf.python.internal.Factory1Message',
- msg1.full_name)
- self.assertEquals(None, msg1.containing_type)
+ self.assertEqual('Factory1Message', msg1.name)
+ self.assertEqual('google.protobuf.python.internal.Factory1Message',
+ msg1.full_name)
+ self.assertEqual(None, msg1.containing_type)
nested_msg1 = msg1.nested_types[0]
- self.assertEquals('NestedFactory1Message', nested_msg1.name)
- self.assertEquals(msg1, nested_msg1.containing_type)
+ self.assertEqual('NestedFactory1Message', nested_msg1.name)
+ self.assertEqual(msg1, nested_msg1.containing_type)
nested_enum1 = msg1.enum_types[0]
- self.assertEquals('NestedFactory1Enum', nested_enum1.name)
- self.assertEquals(msg1, nested_enum1.containing_type)
+ self.assertEqual('NestedFactory1Enum', nested_enum1.name)
+ self.assertEqual(msg1, nested_enum1.containing_type)
- self.assertEquals(nested_msg1, msg1.fields_by_name[
+ self.assertEqual(nested_msg1, msg1.fields_by_name[
'nested_factory_1_message'].message_type)
- self.assertEquals(nested_enum1, msg1.fields_by_name[
+ self.assertEqual(nested_enum1, msg1.fields_by_name[
'nested_factory_1_enum'].enum_type)
msg2 = self.pool.FindMessageTypeByName(
'google.protobuf.python.internal.Factory2Message')
self.assertIsInstance(msg2, descriptor.Descriptor)
- self.assertEquals('Factory2Message', msg2.name)
- self.assertEquals('google.protobuf.python.internal.Factory2Message',
- msg2.full_name)
+ self.assertEqual('Factory2Message', msg2.name)
+ self.assertEqual('google.protobuf.python.internal.Factory2Message',
+ msg2.full_name)
self.assertIsNone(msg2.containing_type)
nested_msg2 = msg2.nested_types[0]
- self.assertEquals('NestedFactory2Message', nested_msg2.name)
- self.assertEquals(msg2, nested_msg2.containing_type)
+ self.assertEqual('NestedFactory2Message', nested_msg2.name)
+ self.assertEqual(msg2, nested_msg2.containing_type)
nested_enum2 = msg2.enum_types[0]
- self.assertEquals('NestedFactory2Enum', nested_enum2.name)
- self.assertEquals(msg2, nested_enum2.containing_type)
+ self.assertEqual('NestedFactory2Enum', nested_enum2.name)
+ self.assertEqual(msg2, nested_enum2.containing_type)
- self.assertEquals(nested_msg2, msg2.fields_by_name[
+ self.assertEqual(nested_msg2, msg2.fields_by_name[
'nested_factory_2_message'].message_type)
- self.assertEquals(nested_enum2, msg2.fields_by_name[
+ self.assertEqual(nested_enum2, msg2.fields_by_name[
'nested_factory_2_enum'].enum_type)
self.assertTrue(msg2.fields_by_name['int_with_default'].has_default_value)
- self.assertEquals(
+ self.assertEqual(
1776, msg2.fields_by_name['int_with_default'].default_value)
self.assertTrue(
msg2.fields_by_name['double_with_default'].has_default_value)
- self.assertEquals(
+ self.assertEqual(
9.99, msg2.fields_by_name['double_with_default'].default_value)
self.assertTrue(
msg2.fields_by_name['string_with_default'].has_default_value)
- self.assertEquals(
+ self.assertEqual(
'hello world', msg2.fields_by_name['string_with_default'].default_value)
self.assertTrue(msg2.fields_by_name['bool_with_default'].has_default_value)
self.assertFalse(msg2.fields_by_name['bool_with_default'].default_value)
self.assertTrue(msg2.fields_by_name['enum_with_default'].has_default_value)
- self.assertEquals(
+ self.assertEqual(
1, msg2.fields_by_name['enum_with_default'].default_value)
msg3 = self.pool.FindMessageTypeByName(
'google.protobuf.python.internal.Factory2Message.NestedFactory2Message')
- self.assertEquals(nested_msg2, msg3)
+ self.assertEqual(nested_msg2, msg3)
self.assertTrue(msg2.fields_by_name['bytes_with_default'].has_default_value)
- self.assertEquals(
+ self.assertEqual(
b'a\xfb\x00c',
msg2.fields_by_name['bytes_with_default'].default_value)
@@ -190,35 +199,66 @@ class DescriptorPoolTest(basetest.TestCase):
enum1 = self.pool.FindEnumTypeByName(
'google.protobuf.python.internal.Factory1Enum')
self.assertIsInstance(enum1, descriptor.EnumDescriptor)
- self.assertEquals(0, enum1.values_by_name['FACTORY_1_VALUE_0'].number)
- self.assertEquals(1, enum1.values_by_name['FACTORY_1_VALUE_1'].number)
+ self.assertEqual(0, enum1.values_by_name['FACTORY_1_VALUE_0'].number)
+ self.assertEqual(1, enum1.values_by_name['FACTORY_1_VALUE_1'].number)
nested_enum1 = self.pool.FindEnumTypeByName(
'google.protobuf.python.internal.Factory1Message.NestedFactory1Enum')
self.assertIsInstance(nested_enum1, descriptor.EnumDescriptor)
- self.assertEquals(
+ self.assertEqual(
0, nested_enum1.values_by_name['NESTED_FACTORY_1_VALUE_0'].number)
- self.assertEquals(
+ self.assertEqual(
1, nested_enum1.values_by_name['NESTED_FACTORY_1_VALUE_1'].number)
enum2 = self.pool.FindEnumTypeByName(
'google.protobuf.python.internal.Factory2Enum')
self.assertIsInstance(enum2, descriptor.EnumDescriptor)
- self.assertEquals(0, enum2.values_by_name['FACTORY_2_VALUE_0'].number)
- self.assertEquals(1, enum2.values_by_name['FACTORY_2_VALUE_1'].number)
+ self.assertEqual(0, enum2.values_by_name['FACTORY_2_VALUE_0'].number)
+ self.assertEqual(1, enum2.values_by_name['FACTORY_2_VALUE_1'].number)
nested_enum2 = self.pool.FindEnumTypeByName(
'google.protobuf.python.internal.Factory2Message.NestedFactory2Enum')
self.assertIsInstance(nested_enum2, descriptor.EnumDescriptor)
- self.assertEquals(
+ self.assertEqual(
0, nested_enum2.values_by_name['NESTED_FACTORY_2_VALUE_0'].number)
- self.assertEquals(
+ self.assertEqual(
1, nested_enum2.values_by_name['NESTED_FACTORY_2_VALUE_1'].number)
def testFindEnumTypeByNameFailure(self):
with self.assertRaises(KeyError):
self.pool.FindEnumTypeByName('Does not exist')
+ def testFindFieldByName(self):
+ field = self.pool.FindFieldByName(
+ 'google.protobuf.python.internal.Factory1Message.list_value')
+ self.assertEqual(field.name, 'list_value')
+ self.assertEqual(field.label, field.LABEL_REPEATED)
+ with self.assertRaises(KeyError):
+ self.pool.FindFieldByName('Does not exist')
+
+ def testFindExtensionByName(self):
+ # An extension defined in a message.
+ extension = self.pool.FindExtensionByName(
+ 'google.protobuf.python.internal.Factory2Message.one_more_field')
+ self.assertEqual(extension.name, 'one_more_field')
+ # An extension defined at file scope.
+ extension = self.pool.FindExtensionByName(
+ 'google.protobuf.python.internal.another_field')
+ self.assertEqual(extension.name, 'another_field')
+ self.assertEqual(extension.number, 1002)
+ with self.assertRaises(KeyError):
+ self.pool.FindFieldByName('Does not exist')
+
+ def testExtensionsAreNotFields(self):
+ with self.assertRaises(KeyError):
+ self.pool.FindFieldByName('google.protobuf.python.internal.another_field')
+ with self.assertRaises(KeyError):
+ self.pool.FindFieldByName(
+ 'google.protobuf.python.internal.Factory2Message.one_more_field')
+ with self.assertRaises(KeyError):
+ self.pool.FindExtensionByName(
+ 'google.protobuf.python.internal.Factory1Message.list_value')
+
def testUserDefinedDB(self):
db = descriptor_database.DescriptorDatabase()
self.pool = descriptor_pool.DescriptorPool(db)
@@ -226,32 +266,109 @@ class DescriptorPoolTest(basetest.TestCase):
db.Add(self.factory_test2_fd)
self.testFindMessageTypeByName()
+ def testAddSerializedFile(self):
+ self.pool = descriptor_pool.DescriptorPool()
+ self.pool.AddSerializedFile(self.factory_test1_fd.SerializeToString())
+ self.pool.AddSerializedFile(self.factory_test2_fd.SerializeToString())
+ self.testFindMessageTypeByName()
+
def testComplexNesting(self):
+ more_messages_desc = descriptor_pb2.FileDescriptorProto.FromString(
+ more_messages_pb2.DESCRIPTOR.serialized_pb)
test1_desc = descriptor_pb2.FileDescriptorProto.FromString(
descriptor_pool_test1_pb2.DESCRIPTOR.serialized_pb)
test2_desc = descriptor_pb2.FileDescriptorProto.FromString(
descriptor_pool_test2_pb2.DESCRIPTOR.serialized_pb)
+ self.pool.Add(more_messages_desc)
self.pool.Add(test1_desc)
self.pool.Add(test2_desc)
TEST1_FILE.CheckFile(self, self.pool)
TEST2_FILE.CheckFile(self, self.pool)
+ def testEnumDefaultValue(self):
+ """Test the default value of enums which don't start at zero."""
+ def _CheckDefaultValue(file_descriptor):
+ default_value = (file_descriptor
+ .message_types_by_name['DescriptorPoolTest1']
+ .fields_by_name['nested_enum']
+ .default_value)
+ self.assertEqual(default_value,
+ descriptor_pool_test1_pb2.DescriptorPoolTest1.BETA)
+ # First check what the generated descriptor contains.
+ _CheckDefaultValue(descriptor_pool_test1_pb2.DESCRIPTOR)
+ # Then check the generated pool. Normally this is the same descriptor.
+ file_descriptor = symbol_database.Default().pool.FindFileByName(
+ 'google/protobuf/internal/descriptor_pool_test1.proto')
+ self.assertIs(file_descriptor, descriptor_pool_test1_pb2.DESCRIPTOR)
+ _CheckDefaultValue(file_descriptor)
+
+ # Then check the dynamic pool and its internal DescriptorDatabase.
+ descriptor_proto = descriptor_pb2.FileDescriptorProto.FromString(
+ descriptor_pool_test1_pb2.DESCRIPTOR.serialized_pb)
+ self.pool.Add(descriptor_proto)
+ # And do the same check as above
+ file_descriptor = self.pool.FindFileByName(
+ 'google/protobuf/internal/descriptor_pool_test1.proto')
+ _CheckDefaultValue(file_descriptor)
+
+ def testDefaultValueForCustomMessages(self):
+ """Check the value returned by non-existent fields."""
+ def _CheckValueAndType(value, expected_value, expected_type):
+ self.assertEqual(value, expected_value)
+ self.assertIsInstance(value, expected_type)
+
+ def _CheckDefaultValues(msg):
+ try:
+ int64 = long
+ except NameError: # Python3
+ int64 = int
+ try:
+ unicode_type = unicode
+ except NameError: # Python3
+ unicode_type = str
+ _CheckValueAndType(msg.optional_int32, 0, int)
+ _CheckValueAndType(msg.optional_uint64, 0, (int64, int))
+ _CheckValueAndType(msg.optional_float, 0, (float, int))
+ _CheckValueAndType(msg.optional_double, 0, (float, int))
+ _CheckValueAndType(msg.optional_bool, False, bool)
+ _CheckValueAndType(msg.optional_string, u'', unicode_type)
+ _CheckValueAndType(msg.optional_bytes, b'', bytes)
+ _CheckValueAndType(msg.optional_nested_enum, msg.FOO, int)
+ # First for the generated message
+ _CheckDefaultValues(unittest_pb2.TestAllTypes())
+ # Then for a message built with from the DescriptorPool.
+ pool = descriptor_pool.DescriptorPool()
+ pool.Add(descriptor_pb2.FileDescriptorProto.FromString(
+ unittest_import_public_pb2.DESCRIPTOR.serialized_pb))
+ pool.Add(descriptor_pb2.FileDescriptorProto.FromString(
+ unittest_import_pb2.DESCRIPTOR.serialized_pb))
+ pool.Add(descriptor_pb2.FileDescriptorProto.FromString(
+ unittest_pb2.DESCRIPTOR.serialized_pb))
+ message_class = message_factory.MessageFactory(pool).GetPrototype(
+ pool.FindMessageTypeByName(
+ unittest_pb2.TestAllTypes.DESCRIPTOR.full_name))
+ _CheckDefaultValues(message_class())
+
class ProtoFile(object):
- def __init__(self, name, package, messages, dependencies=None):
+ def __init__(self, name, package, messages, dependencies=None,
+ public_dependencies=None):
self.name = name
self.package = package
self.messages = messages
self.dependencies = dependencies or []
+ self.public_dependencies = public_dependencies or []
def CheckFile(self, test, pool):
file_desc = pool.FindFileByName(self.name)
- test.assertEquals(self.name, file_desc.name)
- test.assertEquals(self.package, file_desc.package)
+ test.assertEqual(self.name, file_desc.name)
+ test.assertEqual(self.package, file_desc.package)
dependencies_names = [f.name for f in file_desc.dependencies]
test.assertEqual(self.dependencies, dependencies_names)
+ public_dependencies_names = [f.name for f in file_desc.public_dependencies]
+ test.assertEqual(self.public_dependencies, public_dependencies_names)
for name, msg_type in self.messages.items():
msg_type.CheckType(test, None, name, file_desc)
@@ -328,7 +445,7 @@ class EnumField(object):
test.assertEqual(descriptor.FieldDescriptor.CPPTYPE_ENUM,
field_desc.cpp_type)
test.assertTrue(field_desc.has_default_value)
- test.assertEqual(enum_desc.values_by_name[self.default_value].index,
+ test.assertEqual(enum_desc.values_by_name[self.default_value].number,
field_desc.default_value)
test.assertEqual(msg_desc, field_desc.containing_type)
test.assertEqual(enum_desc, field_desc.enum_type)
@@ -399,12 +516,12 @@ class ExtensionField(object):
test.assertEqual(self.extended_type, field_desc.containing_type.name)
-class AddDescriptorTest(basetest.TestCase):
+class AddDescriptorTest(unittest.TestCase):
def _TestMessage(self, prefix):
pool = descriptor_pool.DescriptorPool()
pool.AddDescriptor(unittest_pb2.TestAllTypes.DESCRIPTOR)
- self.assertEquals(
+ self.assertEqual(
'protobuf_unittest.TestAllTypes',
pool.FindMessageTypeByName(
prefix + 'protobuf_unittest.TestAllTypes').full_name)
@@ -415,22 +532,24 @@ class AddDescriptorTest(basetest.TestCase):
prefix + 'protobuf_unittest.TestAllTypes.NestedMessage')
pool.AddDescriptor(unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR)
- self.assertEquals(
+ self.assertEqual(
'protobuf_unittest.TestAllTypes.NestedMessage',
pool.FindMessageTypeByName(
prefix + 'protobuf_unittest.TestAllTypes.NestedMessage').full_name)
# Files are implicitly also indexed when messages are added.
- self.assertEquals(
+ self.assertEqual(
'google/protobuf/unittest.proto',
pool.FindFileByName(
'google/protobuf/unittest.proto').name)
- self.assertEquals(
+ self.assertEqual(
'google/protobuf/unittest.proto',
pool.FindFileContainingSymbol(
prefix + 'protobuf_unittest.TestAllTypes.NestedMessage').name)
+ @unittest.skipIf(api_implementation.Type() == 'cpp',
+ 'With the cpp implementation, Add() must be called first')
def testMessage(self):
self._TestMessage('')
self._TestMessage('.')
@@ -438,7 +557,7 @@ class AddDescriptorTest(basetest.TestCase):
def _TestEnum(self, prefix):
pool = descriptor_pool.DescriptorPool()
pool.AddEnumDescriptor(unittest_pb2.ForeignEnum.DESCRIPTOR)
- self.assertEquals(
+ self.assertEqual(
'protobuf_unittest.ForeignEnum',
pool.FindEnumTypeByName(
prefix + 'protobuf_unittest.ForeignEnum').full_name)
@@ -449,30 +568,34 @@ class AddDescriptorTest(basetest.TestCase):
prefix + 'protobuf_unittest.ForeignEnum.NestedEnum')
pool.AddEnumDescriptor(unittest_pb2.TestAllTypes.NestedEnum.DESCRIPTOR)
- self.assertEquals(
+ self.assertEqual(
'protobuf_unittest.TestAllTypes.NestedEnum',
pool.FindEnumTypeByName(
prefix + 'protobuf_unittest.TestAllTypes.NestedEnum').full_name)
# Files are implicitly also indexed when enums are added.
- self.assertEquals(
+ self.assertEqual(
'google/protobuf/unittest.proto',
pool.FindFileByName(
'google/protobuf/unittest.proto').name)
- self.assertEquals(
+ self.assertEqual(
'google/protobuf/unittest.proto',
pool.FindFileContainingSymbol(
prefix + 'protobuf_unittest.TestAllTypes.NestedEnum').name)
+ @unittest.skipIf(api_implementation.Type() == 'cpp',
+ 'With the cpp implementation, Add() must be called first')
def testEnum(self):
self._TestEnum('')
self._TestEnum('.')
+ @unittest.skipIf(api_implementation.Type() == 'cpp',
+ 'With the cpp implementation, Add() must be called first')
def testFile(self):
pool = descriptor_pool.DescriptorPool()
pool.AddFileDescriptor(unittest_pb2.DESCRIPTOR)
- self.assertEquals(
+ self.assertEqual(
'google/protobuf/unittest.proto',
pool.FindFileByName(
'google/protobuf/unittest.proto').name)
@@ -483,6 +606,67 @@ class AddDescriptorTest(basetest.TestCase):
pool.FindFileContainingSymbol(
'protobuf_unittest.TestAllTypes')
+ def testEmptyDescriptorPool(self):
+ # Check that an empty DescriptorPool() contains no messages.
+ pool = descriptor_pool.DescriptorPool()
+ proto_file_name = descriptor_pb2.DESCRIPTOR.name
+ self.assertRaises(KeyError, pool.FindFileByName, proto_file_name)
+ # Add the above file to the pool
+ file_descriptor = descriptor_pb2.FileDescriptorProto()
+ descriptor_pb2.DESCRIPTOR.CopyToProto(file_descriptor)
+ pool.Add(file_descriptor)
+ # Now it exists.
+ self.assertTrue(pool.FindFileByName(proto_file_name))
+
+ def testCustomDescriptorPool(self):
+ # Create a new pool, and add a file descriptor.
+ pool = descriptor_pool.DescriptorPool()
+ file_desc = descriptor_pb2.FileDescriptorProto(
+ name='some/file.proto', package='package')
+ file_desc.message_type.add(name='Message')
+ pool.Add(file_desc)
+ self.assertEqual(pool.FindFileByName('some/file.proto').name,
+ 'some/file.proto')
+ self.assertEqual(pool.FindMessageTypeByName('package.Message').name,
+ 'Message')
+
+
+@unittest.skipIf(
+ api_implementation.Type() != 'cpp',
+ 'default_pool is only supported by the C++ implementation')
+class DefaultPoolTest(unittest.TestCase):
+
+ def testFindMethods(self):
+ # pylint: disable=g-import-not-at-top
+ from google.protobuf.pyext import _message
+ pool = _message.default_pool
+ self.assertIs(
+ pool.FindFileByName('google/protobuf/unittest.proto'),
+ unittest_pb2.DESCRIPTOR)
+ self.assertIs(
+ pool.FindMessageTypeByName('protobuf_unittest.TestAllTypes'),
+ unittest_pb2.TestAllTypes.DESCRIPTOR)
+ self.assertIs(
+ pool.FindFieldByName('protobuf_unittest.TestAllTypes.optional_int32'),
+ unittest_pb2.TestAllTypes.DESCRIPTOR.fields_by_name['optional_int32'])
+ self.assertIs(
+ pool.FindExtensionByName('protobuf_unittest.optional_int32_extension'),
+ unittest_pb2.DESCRIPTOR.extensions_by_name['optional_int32_extension'])
+ self.assertIs(
+ pool.FindEnumTypeByName('protobuf_unittest.ForeignEnum'),
+ unittest_pb2.ForeignEnum.DESCRIPTOR)
+ self.assertIs(
+ pool.FindOneofByName('protobuf_unittest.TestAllTypes.oneof_field'),
+ unittest_pb2.TestAllTypes.DESCRIPTOR.oneofs_by_name['oneof_field'])
+
+ def testAddFileDescriptor(self):
+ # pylint: disable=g-import-not-at-top
+ from google.protobuf.pyext import _message
+ pool = _message.default_pool
+ file_desc = descriptor_pb2.FileDescriptorProto(name='some/file.proto')
+ pool.Add(file_desc)
+ pool.AddSerializedFile(file_desc.SerializeToString())
+
TEST1_FILE = ProtoFile(
'google/protobuf/internal/descriptor_pool_test1.proto',
@@ -557,8 +741,10 @@ TEST2_FILE = ProtoFile(
ExtensionField(1001, 'DescriptorPoolTest1')),
]),
},
- dependencies=['google/protobuf/internal/descriptor_pool_test1.proto'])
+ dependencies=['google/protobuf/internal/descriptor_pool_test1.proto',
+ 'google/protobuf/internal/more_messages.proto'],
+ public_dependencies=['google/protobuf/internal/more_messages.proto'])
if __name__ == '__main__':
- basetest.main()
+ unittest.main()
diff --git a/python/google/protobuf/internal/descriptor_pool_test1.proto b/python/google/protobuf/internal/descriptor_pool_test1.proto
index 6dfe4ef32..00816b78e 100644
--- a/python/google/protobuf/internal/descriptor_pool_test1.proto
+++ b/python/google/protobuf/internal/descriptor_pool_test1.proto
@@ -28,6 +28,8 @@
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+syntax = "proto2";
+
package google.protobuf.python.internal;
diff --git a/python/google/protobuf/internal/descriptor_pool_test2.proto b/python/google/protobuf/internal/descriptor_pool_test2.proto
index fbc84382a..a218eccb9 100644
--- a/python/google/protobuf/internal/descriptor_pool_test2.proto
+++ b/python/google/protobuf/internal/descriptor_pool_test2.proto
@@ -28,9 +28,12 @@
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+syntax = "proto2";
+
package google.protobuf.python.internal;
import "google/protobuf/internal/descriptor_pool_test1.proto";
+import public "google/protobuf/internal/more_messages.proto";
message DescriptorPoolTest3 {
diff --git a/python/google/protobuf/internal/descriptor_python_test.py b/python/google/protobuf/internal/descriptor_python_test.py
deleted file mode 100644
index 5471ae021..000000000
--- a/python/google/protobuf/internal/descriptor_python_test.py
+++ /dev/null
@@ -1,54 +0,0 @@
-#! /usr/bin/python
-#
-# Protocol Buffers - Google's data interchange format
-# Copyright 2008 Google Inc. All rights reserved.
-# https://developers.google.com/protocol-buffers/
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following disclaimer
-# in the documentation and/or other materials provided with the
-# distribution.
-# * Neither the name of Google Inc. nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-"""Unittest for descriptor.py for the pure Python implementation."""
-
-import os
-os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'
-
-# We must set the implementation version above before the google3 imports.
-# pylint: disable=g-import-not-at-top
-from google.apputils import basetest
-from google.protobuf.internal import api_implementation
-# Run all tests from the original module by putting them in our namespace.
-# pylint: disable=wildcard-import
-from google.protobuf.internal.descriptor_test import *
-
-
-class ConfirmPurePythonTest(basetest.TestCase):
-
- def testImplementationSetting(self):
- self.assertEqual('python', api_implementation.Type())
-
-
-if __name__ == '__main__':
- basetest.main()
diff --git a/python/google/protobuf/internal/descriptor_test.py b/python/google/protobuf/internal/descriptor_test.py
index b3777e396..b8e755533 100755
--- a/python/google/protobuf/internal/descriptor_test.py
+++ b/python/google/protobuf/internal/descriptor_test.py
@@ -1,4 +1,4 @@
-#! /usr/bin/python
+#! /usr/bin/env python
#
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
@@ -34,12 +34,22 @@
__author__ = 'robinson@google.com (Will Robinson)'
-from google.apputils import basetest
+import sys
+
+try:
+ import unittest2 as unittest #PY26
+except ImportError:
+ import unittest
+
from google.protobuf import unittest_custom_options_pb2
from google.protobuf import unittest_import_pb2
from google.protobuf import unittest_pb2
from google.protobuf import descriptor_pb2
+from google.protobuf.internal import api_implementation
+from google.protobuf.internal import test_util
from google.protobuf import descriptor
+from google.protobuf import descriptor_pool
+from google.protobuf import symbol_database
from google.protobuf import text_format
@@ -48,44 +58,31 @@ name: 'TestEmptyMessage'
"""
-class DescriptorTest(basetest.TestCase):
+class DescriptorTest(unittest.TestCase):
def setUp(self):
- self.my_file = descriptor.FileDescriptor(
+ file_proto = descriptor_pb2.FileDescriptorProto(
name='some/filename/some.proto',
- package='protobuf_unittest'
- )
- self.my_enum = descriptor.EnumDescriptor(
- name='ForeignEnum',
- full_name='protobuf_unittest.ForeignEnum',
- filename=None,
- file=self.my_file,
- values=[
- descriptor.EnumValueDescriptor(name='FOREIGN_FOO', index=0, number=4),
- descriptor.EnumValueDescriptor(name='FOREIGN_BAR', index=1, number=5),
- descriptor.EnumValueDescriptor(name='FOREIGN_BAZ', index=2, number=6),
- ])
- self.my_message = descriptor.Descriptor(
- name='NestedMessage',
- full_name='protobuf_unittest.TestAllTypes.NestedMessage',
- filename=None,
- file=self.my_file,
- containing_type=None,
- fields=[
- descriptor.FieldDescriptor(
- name='bb',
- full_name='protobuf_unittest.TestAllTypes.NestedMessage.bb',
- index=0, number=1,
- type=5, cpp_type=1, label=1,
- has_default_value=False, default_value=0,
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None),
- ],
- nested_types=[],
- enum_types=[
- self.my_enum,
- ],
- extensions=[])
+ package='protobuf_unittest')
+ message_proto = file_proto.message_type.add(
+ name='NestedMessage')
+ message_proto.field.add(
+ name='bb',
+ number=1,
+ type=descriptor_pb2.FieldDescriptorProto.TYPE_INT32,
+ label=descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL)
+ enum_proto = message_proto.enum_type.add(
+ name='ForeignEnum')
+ enum_proto.value.add(name='FOREIGN_FOO', number=4)
+ enum_proto.value.add(name='FOREIGN_BAR', number=5)
+ enum_proto.value.add(name='FOREIGN_BAZ', number=6)
+
+ self.pool = self.GetDescriptorPool()
+ self.pool.Add(file_proto)
+ self.my_file = self.pool.FindFileByName(file_proto.name)
+ self.my_message = self.my_file.message_types_by_name[message_proto.name]
+ self.my_enum = self.my_message.enum_types_by_name[enum_proto.name]
+
self.my_method = descriptor.MethodDescriptor(
name='Bar',
full_name='protobuf_unittest.TestService.Bar',
@@ -102,6 +99,9 @@ class DescriptorTest(basetest.TestCase):
self.my_method
])
+ def GetDescriptorPool(self):
+ return symbol_database.Default().pool
+
def testEnumValueName(self):
self.assertEqual(self.my_message.EnumValueName('ForeignEnum', 4),
'FOREIGN_FOO')
@@ -173,6 +173,11 @@ class DescriptorTest(basetest.TestCase):
self.assertEqual(unittest_custom_options_pb2.METHODOPT1_VAL2,
method_options.Extensions[method_opt1])
+ message_descriptor = (
+ unittest_custom_options_pb2.DummyMessageContainingEnum.DESCRIPTOR)
+ self.assertTrue(file_descriptor.has_options)
+ self.assertFalse(message_descriptor.has_options)
+
def testDifferentCustomOptionTypes(self):
kint32min = -2**31
kint64min = -2**63
@@ -393,9 +398,130 @@ class DescriptorTest(basetest.TestCase):
def testFileDescriptor(self):
self.assertEqual(self.my_file.name, 'some/filename/some.proto')
self.assertEqual(self.my_file.package, 'protobuf_unittest')
-
-
-class DescriptorCopyToProtoTest(basetest.TestCase):
+ self.assertEqual(self.my_file.pool, self.pool)
+ # Generated modules also belong to the default pool.
+ self.assertEqual(unittest_pb2.DESCRIPTOR.pool, descriptor_pool.Default())
+
+ @unittest.skipIf(
+ api_implementation.Type() != 'cpp' or api_implementation.Version() != 2,
+ 'Immutability of descriptors is only enforced in v2 implementation')
+ def testImmutableCppDescriptor(self):
+ message_descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
+ with self.assertRaises(AttributeError):
+ message_descriptor.fields_by_name = None
+ with self.assertRaises(TypeError):
+ message_descriptor.fields_by_name['Another'] = None
+ with self.assertRaises(TypeError):
+ message_descriptor.fields.append(None)
+
+
+class NewDescriptorTest(DescriptorTest):
+ """Redo the same tests as above, but with a separate DescriptorPool."""
+
+ def GetDescriptorPool(self):
+ return descriptor_pool.DescriptorPool()
+
+
+class GeneratedDescriptorTest(unittest.TestCase):
+ """Tests for the properties of descriptors in generated code."""
+
+ def CheckMessageDescriptor(self, message_descriptor):
+ # Basic properties
+ self.assertEqual(message_descriptor.name, 'TestAllTypes')
+ self.assertEqual(message_descriptor.full_name,
+ 'protobuf_unittest.TestAllTypes')
+ # Test equality and hashability
+ self.assertEqual(message_descriptor, message_descriptor)
+ self.assertEqual(message_descriptor.fields[0].containing_type,
+ message_descriptor)
+ self.assertIn(message_descriptor, [message_descriptor])
+ self.assertIn(message_descriptor, {message_descriptor: None})
+ # Test field containers
+ self.CheckDescriptorSequence(message_descriptor.fields)
+ self.CheckDescriptorMapping(message_descriptor.fields_by_name)
+ self.CheckDescriptorMapping(message_descriptor.fields_by_number)
+ self.CheckDescriptorMapping(message_descriptor.fields_by_camelcase_name)
+
+ def CheckFieldDescriptor(self, field_descriptor):
+ # Basic properties
+ self.assertEqual(field_descriptor.name, 'optional_int32')
+ self.assertEqual(field_descriptor.camelcase_name, 'optionalInt32')
+ self.assertEqual(field_descriptor.full_name,
+ 'protobuf_unittest.TestAllTypes.optional_int32')
+ self.assertEqual(field_descriptor.containing_type.name, 'TestAllTypes')
+ # Test equality and hashability
+ self.assertEqual(field_descriptor, field_descriptor)
+ self.assertEqual(
+ field_descriptor.containing_type.fields_by_name['optional_int32'],
+ field_descriptor)
+ self.assertEqual(
+ field_descriptor.containing_type.fields_by_camelcase_name[
+ 'optionalInt32'],
+ field_descriptor)
+ self.assertIn(field_descriptor, [field_descriptor])
+ self.assertIn(field_descriptor, {field_descriptor: None})
+
+ def CheckDescriptorSequence(self, sequence):
+ # Verifies that a property like 'messageDescriptor.fields' has all the
+ # properties of an immutable abc.Sequence.
+ self.assertGreater(len(sequence), 0) # Sized
+ self.assertEqual(len(sequence), len(list(sequence))) # Iterable
+ item = sequence[0]
+ self.assertEqual(item, sequence[0])
+ self.assertIn(item, sequence) # Container
+ self.assertEqual(sequence.index(item), 0)
+ self.assertEqual(sequence.count(item), 1)
+ reversed_iterator = reversed(sequence)
+ self.assertEqual(list(reversed_iterator), list(sequence)[::-1])
+ self.assertRaises(StopIteration, next, reversed_iterator)
+
+ def CheckDescriptorMapping(self, mapping):
+ # Verifies that a property like 'messageDescriptor.fields' has all the
+ # properties of an immutable abc.Mapping.
+ self.assertGreater(len(mapping), 0) # Sized
+ self.assertEqual(len(mapping), len(list(mapping))) # Iterable
+ if sys.version_info >= (3,):
+ key, item = next(iter(mapping.items()))
+ else:
+ key, item = mapping.items()[0]
+ self.assertIn(key, mapping) # Container
+ self.assertEqual(mapping.get(key), item)
+ # keys(), iterkeys() &co
+ item = (next(iter(mapping.keys())), next(iter(mapping.values())))
+ self.assertEqual(item, next(iter(mapping.items())))
+ if sys.version_info < (3,):
+ def CheckItems(seq, iterator):
+ self.assertEqual(next(iterator), seq[0])
+ self.assertEqual(list(iterator), seq[1:])
+ CheckItems(mapping.keys(), mapping.iterkeys())
+ CheckItems(mapping.values(), mapping.itervalues())
+ CheckItems(mapping.items(), mapping.iteritems())
+
+ def testDescriptor(self):
+ message_descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
+ self.CheckMessageDescriptor(message_descriptor)
+ field_descriptor = message_descriptor.fields_by_name['optional_int32']
+ self.CheckFieldDescriptor(field_descriptor)
+ field_descriptor = message_descriptor.fields_by_camelcase_name[
+ 'optionalInt32']
+ self.CheckFieldDescriptor(field_descriptor)
+
+ def testCppDescriptorContainer(self):
+ # Check that the collection is still valid even if the parent disappeared.
+ enum = unittest_pb2.TestAllTypes.DESCRIPTOR.enum_types_by_name['NestedEnum']
+ values = enum.values
+ del enum
+ self.assertEqual('FOO', values[0].name)
+
+ def testCppDescriptorContainer_Iterator(self):
+ # Same test with the iterator
+ enum = unittest_pb2.TestAllTypes.DESCRIPTOR.enum_types_by_name['NestedEnum']
+ values_iter = iter(enum.values)
+ del enum
+ self.assertEqual('FOO', next(values_iter).name)
+
+
+class DescriptorCopyToProtoTest(unittest.TestCase):
"""Tests for CopyTo functions of Descriptor."""
def _AssertProtoEqual(self, actual_proto, expected_class, expected_ascii):
@@ -471,7 +597,7 @@ class DescriptorCopyToProtoTest(basetest.TestCase):
"""
self._InternalTestCopyToProto(
- unittest_pb2._FOREIGNENUM,
+ unittest_pb2.ForeignEnum.DESCRIPTOR,
descriptor_pb2.EnumDescriptorProto,
TEST_FOREIGN_ENUM_ASCII)
@@ -588,13 +714,15 @@ class DescriptorCopyToProtoTest(basetest.TestCase):
output_type: '.protobuf_unittest.BarResponse'
>
"""
- self._InternalTestCopyToProto(
- unittest_pb2.TestService.DESCRIPTOR,
- descriptor_pb2.ServiceDescriptorProto,
- TEST_SERVICE_ASCII)
+ # TODO(rocking): enable this test after the proto descriptor change is
+ # checked in.
+ #self._InternalTestCopyToProto(
+ # unittest_pb2.TestService.DESCRIPTOR,
+ # descriptor_pb2.ServiceDescriptorProto,
+ # TEST_SERVICE_ASCII)
-class MakeDescriptorTest(basetest.TestCase):
+class MakeDescriptorTest(unittest.TestCase):
def testMakeDescriptorWithNestedFields(self):
file_descriptor_proto = descriptor_pb2.FileDescriptorProto()
@@ -665,5 +793,30 @@ class MakeDescriptorTest(basetest.TestCase):
descriptor.FieldDescriptor.CPPTYPE_UINT64)
+ def testMakeDescriptorWithOptions(self):
+ descriptor_proto = descriptor_pb2.DescriptorProto()
+ aggregate_message = unittest_custom_options_pb2.AggregateMessage
+ aggregate_message.DESCRIPTOR.CopyToProto(descriptor_proto)
+ reformed_descriptor = descriptor.MakeDescriptor(descriptor_proto)
+
+ options = reformed_descriptor.GetOptions()
+ self.assertEqual(101,
+ options.Extensions[unittest_custom_options_pb2.msgopt].i)
+
+ def testCamelcaseName(self):
+ descriptor_proto = descriptor_pb2.DescriptorProto()
+ descriptor_proto.name = 'Bar'
+ names = ['foo_foo', 'FooBar', 'fooBaz', 'fooFoo', 'foobar']
+ camelcase_names = ['fooFoo', 'fooBar', 'fooBaz', 'fooFoo', 'foobar']
+ for index in range(len(names)):
+ field = descriptor_proto.field.add()
+ field.number = index + 1
+ field.name = names[index]
+ result = descriptor.MakeDescriptor(descriptor_proto)
+ for index in range(len(camelcase_names)):
+ self.assertEqual(result.fields[index].camelcase_name,
+ camelcase_names[index])
+
+
if __name__ == '__main__':
- basetest.main()
+ unittest.main()
diff --git a/python/google/protobuf/internal/encoder.py b/python/google/protobuf/internal/encoder.py
index 38a5138ae..48ef2df31 100755
--- a/python/google/protobuf/internal/encoder.py
+++ b/python/google/protobuf/internal/encoder.py
@@ -28,10 +28,6 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-#PY25 compatible for GAE.
-#
-# Copyright 2009 Google Inc. All Rights Reserved.
-
"""Code for encoding protocol message primitives.
Contains the logic for encoding every logical protocol field type
@@ -45,7 +41,7 @@ FieldDescriptor) we construct two functions: a "sizer" and an "encoder". The
sizer takes a value of this field's type and computes its byte size. The
encoder takes a writer function and a value. It encodes the value into byte
strings and invokes the writer function to write those strings. Typically the
-writer function is the write() method of a cStringIO.
+writer function is the write() method of a BytesIO.
We try to do as much work as possible when constructing the writer and the
sizer rather than when calling them. In particular:
@@ -71,8 +67,9 @@ sizer rather than when calling them. In particular:
__author__ = 'kenton@google.com (Kenton Varda)'
import struct
-import sys ##PY25
-_PY2 = sys.version_info[0] < 3 ##PY25
+
+import six
+
from google.protobuf.internal import wire_format
@@ -314,7 +311,7 @@ def MessageSizer(field_number, is_repeated, is_packed):
# --------------------------------------------------------------------
-# MessageSet is special.
+# MessageSet is special: it needs custom logic to compute its size properly.
def MessageSetItemSizer(field_number):
@@ -339,6 +336,32 @@ def MessageSetItemSizer(field_number):
return FieldSize
+# --------------------------------------------------------------------
+# Map is special: it needs custom logic to compute its size properly.
+
+
+def MapSizer(field_descriptor):
+ """Returns a sizer for a map field."""
+
+ # Can't look at field_descriptor.message_type._concrete_class because it may
+ # not have been initialized yet.
+ message_type = field_descriptor.message_type
+ message_sizer = MessageSizer(field_descriptor.number, False, False)
+
+ def FieldSize(map_value):
+ total = 0
+ for key in map_value:
+ value = map_value[key]
+ # It's wasteful to create the messages and throw them away one second
+ # later since we'll do the same for the actual encode. But there's not an
+ # obvious way to avoid this within the current design without tons of code
+ # duplication.
+ entry_msg = message_type._concrete_class(key=key, value=value)
+ total += message_sizer(entry_msg)
+ return total
+
+ return FieldSize
+
# ====================================================================
# Encoders!
@@ -346,16 +369,14 @@ def MessageSetItemSizer(field_number):
def _VarintEncoder():
"""Return an encoder for a basic varint value (does not include tag)."""
- local_chr = _PY2 and chr or (lambda x: bytes((x,))) ##PY25
-##!PY25 local_chr = chr if bytes is str else lambda x: bytes((x,))
def EncodeVarint(write, value):
bits = value & 0x7f
value >>= 7
while value:
- write(local_chr(0x80|bits))
+ write(six.int2byte(0x80|bits))
bits = value & 0x7f
value >>= 7
- return write(local_chr(bits))
+ return write(six.int2byte(bits))
return EncodeVarint
@@ -364,18 +385,16 @@ def _SignedVarintEncoder():
"""Return an encoder for a basic signed varint value (does not include
tag)."""
- local_chr = _PY2 and chr or (lambda x: bytes((x,))) ##PY25
-##!PY25 local_chr = chr if bytes is str else lambda x: bytes((x,))
def EncodeSignedVarint(write, value):
if value < 0:
value += (1 << 64)
bits = value & 0x7f
value >>= 7
while value:
- write(local_chr(0x80|bits))
+ write(six.int2byte(0x80|bits))
bits = value & 0x7f
value >>= 7
- return write(local_chr(bits))
+ return write(six.int2byte(bits))
return EncodeSignedVarint
@@ -390,8 +409,7 @@ def _VarintBytes(value):
pieces = []
_EncodeVarint(pieces.append, value)
- return "".encode("latin1").join(pieces) ##PY25
-##!PY25 return b"".join(pieces)
+ return b"".join(pieces)
def TagBytes(field_number, wire_type):
@@ -529,33 +547,26 @@ def _FloatingPointEncoder(wire_type, format):
format: The format string to pass to struct.pack().
"""
- b = _PY2 and (lambda x:x) or (lambda x:x.encode('latin1')) ##PY25
value_size = struct.calcsize(format)
if value_size == 4:
def EncodeNonFiniteOrRaise(write, value):
# Remember that the serialized form uses little-endian byte order.
if value == _POS_INF:
- write(b('\x00\x00\x80\x7F')) ##PY25
-##!PY25 write(b'\x00\x00\x80\x7F')
+ write(b'\x00\x00\x80\x7F')
elif value == _NEG_INF:
- write(b('\x00\x00\x80\xFF')) ##PY25
-##!PY25 write(b'\x00\x00\x80\xFF')
+ write(b'\x00\x00\x80\xFF')
elif value != value: # NaN
- write(b('\x00\x00\xC0\x7F')) ##PY25
-##!PY25 write(b'\x00\x00\xC0\x7F')
+ write(b'\x00\x00\xC0\x7F')
else:
raise
elif value_size == 8:
def EncodeNonFiniteOrRaise(write, value):
if value == _POS_INF:
- write(b('\x00\x00\x00\x00\x00\x00\xF0\x7F')) ##PY25
-##!PY25 write(b'\x00\x00\x00\x00\x00\x00\xF0\x7F')
+ write(b'\x00\x00\x00\x00\x00\x00\xF0\x7F')
elif value == _NEG_INF:
- write(b('\x00\x00\x00\x00\x00\x00\xF0\xFF')) ##PY25
-##!PY25 write(b'\x00\x00\x00\x00\x00\x00\xF0\xFF')
+ write(b'\x00\x00\x00\x00\x00\x00\xF0\xFF')
elif value != value: # NaN
- write(b('\x00\x00\x00\x00\x00\x00\xF8\x7F')) ##PY25
-##!PY25 write(b'\x00\x00\x00\x00\x00\x00\xF8\x7F')
+ write(b'\x00\x00\x00\x00\x00\x00\xF8\x7F')
else:
raise
else:
@@ -631,10 +642,8 @@ DoubleEncoder = _FloatingPointEncoder(wire_format.WIRETYPE_FIXED64, '<d')
def BoolEncoder(field_number, is_repeated, is_packed):
"""Returns an encoder for a boolean field."""
-##!PY25 false_byte = b'\x00'
-##!PY25 true_byte = b'\x01'
- false_byte = '\x00'.encode('latin1') ##PY25
- true_byte = '\x01'.encode('latin1') ##PY25
+ false_byte = b'\x00'
+ true_byte = b'\x01'
if is_packed:
tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
local_EncodeVarint = _EncodeVarint
@@ -770,8 +779,7 @@ def MessageSetItemEncoder(field_number):
}
}
"""
- start_bytes = "".encode("latin1").join([ ##PY25
-##!PY25 start_bytes = b"".join([
+ start_bytes = b"".join([
TagBytes(1, wire_format.WIRETYPE_START_GROUP),
TagBytes(2, wire_format.WIRETYPE_VARINT),
_VarintBytes(field_number),
@@ -786,3 +794,30 @@ def MessageSetItemEncoder(field_number):
return write(end_bytes)
return EncodeField
+
+
+# --------------------------------------------------------------------
+# As before, Map is special.
+
+
+def MapEncoder(field_descriptor):
+ """Encoder for extensions of MessageSet.
+
+ Maps always have a wire format like this:
+ message MapEntry {
+ key_type key = 1;
+ value_type value = 2;
+ }
+ repeated MapEntry map = N;
+ """
+ # Can't look at field_descriptor.message_type._concrete_class because it may
+ # not have been initialized yet.
+ message_type = field_descriptor.message_type
+ encode_message = MessageEncoder(field_descriptor.number, False, False)
+
+ def EncodeField(write, value):
+ for key in value:
+ entry_msg = message_type._concrete_class(key=key, value=value[key])
+ encode_message(write, entry_msg)
+
+ return EncodeField
diff --git a/python/google/protobuf/internal/factory_test1.proto b/python/google/protobuf/internal/factory_test1.proto
index 9f5a39192..d2fbbeecf 100644
--- a/python/google/protobuf/internal/factory_test1.proto
+++ b/python/google/protobuf/internal/factory_test1.proto
@@ -30,6 +30,7 @@
// Author: matthewtoia@google.com (Matt Toia)
+syntax = "proto2";
package google.protobuf.python.internal;
diff --git a/python/google/protobuf/internal/factory_test2.proto b/python/google/protobuf/internal/factory_test2.proto
index 27feb6ce4..bb1b54ada 100644
--- a/python/google/protobuf/internal/factory_test2.proto
+++ b/python/google/protobuf/internal/factory_test2.proto
@@ -30,6 +30,7 @@
// Author: matthewtoia@google.com (Matt Toia)
+syntax = "proto2";
package google.protobuf.python.internal;
@@ -87,6 +88,12 @@ message LoopMessage {
optional Factory2Message loop = 1;
}
+message MessageWithNestedEnumOnly {
+ enum NestedEnum {
+ NESTED_MESSAGE_ENUM_0 = 0;
+ }
+}
+
extend Factory1Message {
optional string another_field = 1002;
}
diff --git a/python/google/protobuf/internal/generator_test.py b/python/google/protobuf/internal/generator_test.py
index 422fa9a66..83ea5f509 100755
--- a/python/google/protobuf/internal/generator_test.py
+++ b/python/google/protobuf/internal/generator_test.py
@@ -1,4 +1,4 @@
-#! /usr/bin/python
+#! /usr/bin/env python
#
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
@@ -35,18 +35,23 @@
# indirect testing of the protocol compiler output.
"""Unittest that directly tests the output of the pure-Python protocol
-compiler. See //google/protobuf/reflection_test.py for a test which
+compiler. See //google/protobuf/internal/reflection_test.py for a test which
further ensures that we can use Python protocol message objects as we expect.
"""
__author__ = 'robinson@google.com (Will Robinson)'
-from google.apputils import basetest
+try:
+ import unittest2 as unittest #PY26
+except ImportError:
+ import unittest
+
from google.protobuf.internal import test_bad_identifiers_pb2
from google.protobuf import unittest_custom_options_pb2
from google.protobuf import unittest_import_pb2
from google.protobuf import unittest_import_public_pb2
from google.protobuf import unittest_mset_pb2
+from google.protobuf import unittest_mset_wire_format_pb2
from google.protobuf import unittest_no_generic_services_pb2
from google.protobuf import unittest_pb2
from google.protobuf import service
@@ -55,7 +60,7 @@ from google.protobuf import symbol_database
MAX_EXTENSION = 536870912
-class GeneratorTest(basetest.TestCase):
+class GeneratorTest(unittest.TestCase):
def testNestedMessageDescriptor(self):
field_name = 'optional_nested_message'
@@ -142,18 +147,18 @@ class GeneratorTest(basetest.TestCase):
self.assertTrue(not non_extension_descriptor.is_extension)
def testOptions(self):
- proto = unittest_mset_pb2.TestMessageSet()
+ proto = unittest_mset_wire_format_pb2.TestMessageSet()
self.assertTrue(proto.DESCRIPTOR.GetOptions().message_set_wire_format)
def testMessageWithCustomOptions(self):
proto = unittest_custom_options_pb2.TestMessageWithCustomOptions()
enum_options = proto.DESCRIPTOR.enum_types_by_name['AnEnum'].GetOptions()
self.assertTrue(enum_options is not None)
- # TODO(gps): We really should test for the presense of the enum_opt1
+ # TODO(gps): We really should test for the presence of the enum_opt1
# extension and for its value to be set to -789.
def testNestedTypes(self):
- self.assertEquals(
+ self.assertEqual(
set(unittest_pb2.TestAllTypes.DESCRIPTOR.nested_types),
set([
unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR,
@@ -291,53 +296,53 @@ class GeneratorTest(basetest.TestCase):
self.assertIs(desc.oneofs[0], desc.oneofs_by_name['oneof_field'])
nested_names = set(['oneof_uint32', 'oneof_nested_message',
'oneof_string', 'oneof_bytes'])
- self.assertSameElements(
+ self.assertEqual(
nested_names,
- [field.name for field in desc.oneofs[0].fields])
- for field_name, field_desc in desc.fields_by_name.iteritems():
+ set([field.name for field in desc.oneofs[0].fields]))
+ for field_name, field_desc in desc.fields_by_name.items():
if field_name in nested_names:
self.assertIs(desc.oneofs[0], field_desc.containing_oneof)
else:
self.assertIsNone(field_desc.containing_oneof)
-class SymbolDatabaseRegistrationTest(basetest.TestCase):
+class SymbolDatabaseRegistrationTest(unittest.TestCase):
"""Checks that messages, enums and files are correctly registered."""
def testGetSymbol(self):
- self.assertEquals(
+ self.assertEqual(
unittest_pb2.TestAllTypes, symbol_database.Default().GetSymbol(
'protobuf_unittest.TestAllTypes'))
- self.assertEquals(
+ self.assertEqual(
unittest_pb2.TestAllTypes.NestedMessage,
symbol_database.Default().GetSymbol(
'protobuf_unittest.TestAllTypes.NestedMessage'))
with self.assertRaises(KeyError):
symbol_database.Default().GetSymbol('protobuf_unittest.NestedMessage')
- self.assertEquals(
+ self.assertEqual(
unittest_pb2.TestAllTypes.OptionalGroup,
symbol_database.Default().GetSymbol(
'protobuf_unittest.TestAllTypes.OptionalGroup'))
- self.assertEquals(
+ self.assertEqual(
unittest_pb2.TestAllTypes.RepeatedGroup,
symbol_database.Default().GetSymbol(
'protobuf_unittest.TestAllTypes.RepeatedGroup'))
def testEnums(self):
- self.assertEquals(
+ self.assertEqual(
'protobuf_unittest.ForeignEnum',
symbol_database.Default().pool.FindEnumTypeByName(
'protobuf_unittest.ForeignEnum').full_name)
- self.assertEquals(
+ self.assertEqual(
'protobuf_unittest.TestAllTypes.NestedEnum',
symbol_database.Default().pool.FindEnumTypeByName(
'protobuf_unittest.TestAllTypes.NestedEnum').full_name)
def testFindFileByName(self):
- self.assertEquals(
+ self.assertEqual(
'google/protobuf/unittest.proto',
symbol_database.Default().pool.FindFileByName(
'google/protobuf/unittest.proto').name)
if __name__ == '__main__':
- basetest.main()
+ unittest.main()
diff --git a/python/google/protobuf/internal/message_python_test.py b/python/google/protobuf/internal/import_test_package/__init__.py
index c40623a8c..5121dd0ec 100644
--- a/python/google/protobuf/internal/message_python_test.py
+++ b/python/google/protobuf/internal/import_test_package/__init__.py
@@ -1,5 +1,3 @@
-#! /usr/bin/python
-#
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
@@ -30,25 +28,6 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-"""Tests for ..public.message for the pure Python implementation."""
-
-import os
-os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'
-
-# We must set the implementation version above before the google3 imports.
-# pylint: disable=g-import-not-at-top
-from google.apputils import basetest
-from google.protobuf.internal import api_implementation
-# Run all tests from the original module by putting them in our namespace.
-# pylint: disable=wildcard-import
-from google.protobuf.internal.message_test import *
-
-
-class ConfirmPurePythonTest(basetest.TestCase):
-
- def testImplementationSetting(self):
- self.assertEqual('python', api_implementation.Type())
-
+"""Sample module importing a nested proto from itself."""
-if __name__ == '__main__':
- basetest.main()
+from google.protobuf.internal.import_test_package import outer_pb2 as myproto
diff --git a/python/google/protobuf/internal/import_test_package/inner.proto b/python/google/protobuf/internal/import_test_package/inner.proto
new file mode 100644
index 000000000..2887c1230
--- /dev/null
+++ b/python/google/protobuf/internal/import_test_package/inner.proto
@@ -0,0 +1,37 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// https://developers.google.com/protocol-buffers/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+syntax = "proto2";
+
+package google.protobuf.python.internal.import_test_package;
+
+message Inner {
+ optional int32 value = 1 [default = 57];
+}
diff --git a/python/google/protobuf/internal/import_test_package/outer.proto b/python/google/protobuf/internal/import_test_package/outer.proto
new file mode 100644
index 000000000..a27fb5c8f
--- /dev/null
+++ b/python/google/protobuf/internal/import_test_package/outer.proto
@@ -0,0 +1,39 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// https://developers.google.com/protocol-buffers/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+syntax = "proto2";
+
+package google.protobuf.python.internal.import_test_package;
+
+import "google/protobuf/internal/import_test_package/inner.proto";
+
+message Outer {
+ optional Inner inner = 1;
+}
diff --git a/python/google/protobuf/internal/json_format_test.py b/python/google/protobuf/internal/json_format_test.py
new file mode 100644
index 000000000..bdc9f49a4
--- /dev/null
+++ b/python/google/protobuf/internal/json_format_test.py
@@ -0,0 +1,769 @@
+#! /usr/bin/env python
+#
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# https://developers.google.com/protocol-buffers/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Test for google.protobuf.json_format."""
+
+__author__ = 'jieluo@google.com (Jie Luo)'
+
+import json
+import math
+import sys
+
+try:
+ import unittest2 as unittest #PY26
+except ImportError:
+ import unittest
+
+from google.protobuf import any_pb2
+from google.protobuf import duration_pb2
+from google.protobuf import field_mask_pb2
+from google.protobuf import struct_pb2
+from google.protobuf import timestamp_pb2
+from google.protobuf import wrappers_pb2
+from google.protobuf.internal import well_known_types
+from google.protobuf import json_format
+from google.protobuf.util import json_format_proto3_pb2
+
+
+class JsonFormatBase(unittest.TestCase):
+
+ def FillAllFields(self, message):
+ message.int32_value = 20
+ message.int64_value = -20
+ message.uint32_value = 3120987654
+ message.uint64_value = 12345678900
+ message.float_value = float('-inf')
+ message.double_value = 3.1415
+ message.bool_value = True
+ message.string_value = 'foo'
+ message.bytes_value = b'bar'
+ message.message_value.value = 10
+ message.enum_value = json_format_proto3_pb2.BAR
+ # Repeated
+ message.repeated_int32_value.append(0x7FFFFFFF)
+ message.repeated_int32_value.append(-2147483648)
+ message.repeated_int64_value.append(9007199254740992)
+ message.repeated_int64_value.append(-9007199254740992)
+ message.repeated_uint32_value.append(0xFFFFFFF)
+ message.repeated_uint32_value.append(0x7FFFFFF)
+ message.repeated_uint64_value.append(9007199254740992)
+ message.repeated_uint64_value.append(9007199254740991)
+ message.repeated_float_value.append(0)
+
+ message.repeated_double_value.append(1E-15)
+ message.repeated_double_value.append(float('inf'))
+ message.repeated_bool_value.append(True)
+ message.repeated_bool_value.append(False)
+ message.repeated_string_value.append('Few symbols!#$,;')
+ message.repeated_string_value.append('bar')
+ message.repeated_bytes_value.append(b'foo')
+ message.repeated_bytes_value.append(b'bar')
+ message.repeated_message_value.add().value = 10
+ message.repeated_message_value.add().value = 11
+ message.repeated_enum_value.append(json_format_proto3_pb2.FOO)
+ message.repeated_enum_value.append(json_format_proto3_pb2.BAR)
+ self.message = message
+
+ def CheckParseBack(self, message, parsed_message):
+ json_format.Parse(json_format.MessageToJson(message),
+ parsed_message)
+ self.assertEqual(message, parsed_message)
+
+ def CheckError(self, text, error_message):
+ message = json_format_proto3_pb2.TestMessage()
+ self.assertRaisesRegexp(
+ json_format.ParseError,
+ error_message,
+ json_format.Parse, text, message)
+
+
+class JsonFormatTest(JsonFormatBase):
+
+ def testEmptyMessageToJson(self):
+ message = json_format_proto3_pb2.TestMessage()
+ self.assertEqual(json_format.MessageToJson(message),
+ '{}')
+ parsed_message = json_format_proto3_pb2.TestMessage()
+ self.CheckParseBack(message, parsed_message)
+
+ def testPartialMessageToJson(self):
+ message = json_format_proto3_pb2.TestMessage(
+ string_value='test',
+ repeated_int32_value=[89, 4])
+ self.assertEqual(json.loads(json_format.MessageToJson(message)),
+ json.loads('{"stringValue": "test", '
+ '"repeatedInt32Value": [89, 4]}'))
+ parsed_message = json_format_proto3_pb2.TestMessage()
+ self.CheckParseBack(message, parsed_message)
+
+ def testAllFieldsToJson(self):
+ message = json_format_proto3_pb2.TestMessage()
+ text = ('{"int32Value": 20, '
+ '"int64Value": "-20", '
+ '"uint32Value": 3120987654,'
+ '"uint64Value": "12345678900",'
+ '"floatValue": "-Infinity",'
+ '"doubleValue": 3.1415,'
+ '"boolValue": true,'
+ '"stringValue": "foo",'
+ '"bytesValue": "YmFy",'
+ '"messageValue": {"value": 10},'
+ '"enumValue": "BAR",'
+ '"repeatedInt32Value": [2147483647, -2147483648],'
+ '"repeatedInt64Value": ["9007199254740992", "-9007199254740992"],'
+ '"repeatedUint32Value": [268435455, 134217727],'
+ '"repeatedUint64Value": ["9007199254740992", "9007199254740991"],'
+ '"repeatedFloatValue": [0],'
+ '"repeatedDoubleValue": [1e-15, "Infinity"],'
+ '"repeatedBoolValue": [true, false],'
+ '"repeatedStringValue": ["Few symbols!#$,;", "bar"],'
+ '"repeatedBytesValue": ["Zm9v", "YmFy"],'
+ '"repeatedMessageValue": [{"value": 10}, {"value": 11}],'
+ '"repeatedEnumValue": ["FOO", "BAR"]'
+ '}')
+ self.FillAllFields(message)
+ self.assertEqual(
+ json.loads(json_format.MessageToJson(message)),
+ json.loads(text))
+ parsed_message = json_format_proto3_pb2.TestMessage()
+ json_format.Parse(text, parsed_message)
+ self.assertEqual(message, parsed_message)
+
+ def testJsonEscapeString(self):
+ message = json_format_proto3_pb2.TestMessage()
+ if sys.version_info[0] < 3:
+ message.string_value = '&\n<\"\r>\b\t\f\\\001/\xe2\x80\xa8\xe2\x80\xa9'
+ else:
+ message.string_value = '&\n<\"\r>\b\t\f\\\001/'
+ message.string_value += (b'\xe2\x80\xa8\xe2\x80\xa9').decode('utf-8')
+ self.assertEqual(
+ json_format.MessageToJson(message),
+ '{\n "stringValue": '
+ '"&\\n<\\\"\\r>\\b\\t\\f\\\\\\u0001/\\u2028\\u2029"\n}')
+ parsed_message = json_format_proto3_pb2.TestMessage()
+ self.CheckParseBack(message, parsed_message)
+ text = u'{"int32Value": "\u0031"}'
+ json_format.Parse(text, message)
+ self.assertEqual(message.int32_value, 1)
+
+ def testAlwaysSeriliaze(self):
+ message = json_format_proto3_pb2.TestMessage(
+ string_value='foo')
+ self.assertEqual(
+ json.loads(json_format.MessageToJson(message, True)),
+ json.loads('{'
+ '"repeatedStringValue": [],'
+ '"stringValue": "foo",'
+ '"repeatedBoolValue": [],'
+ '"repeatedUint32Value": [],'
+ '"repeatedInt32Value": [],'
+ '"enumValue": "FOO",'
+ '"int32Value": 0,'
+ '"floatValue": 0,'
+ '"int64Value": "0",'
+ '"uint32Value": 0,'
+ '"repeatedBytesValue": [],'
+ '"repeatedUint64Value": [],'
+ '"repeatedDoubleValue": [],'
+ '"bytesValue": "",'
+ '"boolValue": false,'
+ '"repeatedEnumValue": [],'
+ '"uint64Value": "0",'
+ '"doubleValue": 0,'
+ '"repeatedFloatValue": [],'
+ '"repeatedInt64Value": [],'
+ '"repeatedMessageValue": []}'))
+ parsed_message = json_format_proto3_pb2.TestMessage()
+ self.CheckParseBack(message, parsed_message)
+
+ def testMapFields(self):
+ message = json_format_proto3_pb2.TestMap()
+ message.bool_map[True] = 1
+ message.bool_map[False] = 2
+ message.int32_map[1] = 2
+ message.int32_map[2] = 3
+ message.int64_map[1] = 2
+ message.int64_map[2] = 3
+ message.uint32_map[1] = 2
+ message.uint32_map[2] = 3
+ message.uint64_map[1] = 2
+ message.uint64_map[2] = 3
+ message.string_map['1'] = 2
+ message.string_map['null'] = 3
+ self.assertEqual(
+ json.loads(json_format.MessageToJson(message, True)),
+ json.loads('{'
+ '"boolMap": {"false": 2, "true": 1},'
+ '"int32Map": {"1": 2, "2": 3},'
+ '"int64Map": {"1": 2, "2": 3},'
+ '"uint32Map": {"1": 2, "2": 3},'
+ '"uint64Map": {"1": 2, "2": 3},'
+ '"stringMap": {"1": 2, "null": 3}'
+ '}'))
+ parsed_message = json_format_proto3_pb2.TestMap()
+ self.CheckParseBack(message, parsed_message)
+
+ def testOneofFields(self):
+ message = json_format_proto3_pb2.TestOneof()
+ # Always print does not affect oneof fields.
+ self.assertEqual(
+ json_format.MessageToJson(message, True),
+ '{}')
+ message.oneof_int32_value = 0
+ self.assertEqual(
+ json_format.MessageToJson(message, True),
+ '{\n'
+ ' "oneofInt32Value": 0\n'
+ '}')
+ parsed_message = json_format_proto3_pb2.TestOneof()
+ self.CheckParseBack(message, parsed_message)
+
+ def testTimestampMessage(self):
+ message = json_format_proto3_pb2.TestTimestamp()
+ message.value.seconds = 0
+ message.value.nanos = 0
+ message.repeated_value.add().seconds = 20
+ message.repeated_value[0].nanos = 1
+ message.repeated_value.add().seconds = 0
+ message.repeated_value[1].nanos = 10000
+ message.repeated_value.add().seconds = 100000000
+ message.repeated_value[2].nanos = 0
+ # Maximum time
+ message.repeated_value.add().seconds = 253402300799
+ message.repeated_value[3].nanos = 999999999
+ # Minimum time
+ message.repeated_value.add().seconds = -62135596800
+ message.repeated_value[4].nanos = 0
+ self.assertEqual(
+ json.loads(json_format.MessageToJson(message, True)),
+ json.loads('{'
+ '"value": "1970-01-01T00:00:00Z",'
+ '"repeatedValue": ['
+ ' "1970-01-01T00:00:20.000000001Z",'
+ ' "1970-01-01T00:00:00.000010Z",'
+ ' "1973-03-03T09:46:40Z",'
+ ' "9999-12-31T23:59:59.999999999Z",'
+ ' "0001-01-01T00:00:00Z"'
+ ']'
+ '}'))
+ parsed_message = json_format_proto3_pb2.TestTimestamp()
+ self.CheckParseBack(message, parsed_message)
+ text = (r'{"value": "1970-01-01T00:00:00.01+08:00",'
+ r'"repeatedValue":['
+ r' "1970-01-01T00:00:00.01+08:30",'
+ r' "1970-01-01T00:00:00.01-01:23"]}')
+ json_format.Parse(text, parsed_message)
+ self.assertEqual(parsed_message.value.seconds, -8 * 3600)
+ self.assertEqual(parsed_message.value.nanos, 10000000)
+ self.assertEqual(parsed_message.repeated_value[0].seconds, -8.5 * 3600)
+ self.assertEqual(parsed_message.repeated_value[1].seconds, 3600 + 23 * 60)
+
+ def testDurationMessage(self):
+ message = json_format_proto3_pb2.TestDuration()
+ message.value.seconds = 1
+ message.repeated_value.add().seconds = 0
+ message.repeated_value[0].nanos = 10
+ message.repeated_value.add().seconds = -1
+ message.repeated_value[1].nanos = -1000
+ message.repeated_value.add().seconds = 10
+ message.repeated_value[2].nanos = 11000000
+ message.repeated_value.add().seconds = -315576000000
+ message.repeated_value.add().seconds = 315576000000
+ self.assertEqual(
+ json.loads(json_format.MessageToJson(message, True)),
+ json.loads('{'
+ '"value": "1s",'
+ '"repeatedValue": ['
+ ' "0.000000010s",'
+ ' "-1.000001s",'
+ ' "10.011s",'
+ ' "-315576000000s",'
+ ' "315576000000s"'
+ ']'
+ '}'))
+ parsed_message = json_format_proto3_pb2.TestDuration()
+ self.CheckParseBack(message, parsed_message)
+
+ def testFieldMaskMessage(self):
+ message = json_format_proto3_pb2.TestFieldMask()
+ message.value.paths.append('foo.bar')
+ message.value.paths.append('bar')
+ self.assertEqual(
+ json_format.MessageToJson(message, True),
+ '{\n'
+ ' "value": "foo.bar,bar"\n'
+ '}')
+ parsed_message = json_format_proto3_pb2.TestFieldMask()
+ self.CheckParseBack(message, parsed_message)
+
+ def testWrapperMessage(self):
+ message = json_format_proto3_pb2.TestWrapper()
+ message.bool_value.value = False
+ message.int32_value.value = 0
+ message.string_value.value = ''
+ message.bytes_value.value = b''
+ message.repeated_bool_value.add().value = True
+ message.repeated_bool_value.add().value = False
+ message.repeated_int32_value.add()
+ self.assertEqual(
+ json.loads(json_format.MessageToJson(message, True)),
+ json.loads('{\n'
+ ' "int32Value": 0,'
+ ' "boolValue": false,'
+ ' "stringValue": "",'
+ ' "bytesValue": "",'
+ ' "repeatedBoolValue": [true, false],'
+ ' "repeatedInt32Value": [0],'
+ ' "repeatedUint32Value": [],'
+ ' "repeatedFloatValue": [],'
+ ' "repeatedDoubleValue": [],'
+ ' "repeatedBytesValue": [],'
+ ' "repeatedInt64Value": [],'
+ ' "repeatedUint64Value": [],'
+ ' "repeatedStringValue": []'
+ '}'))
+ parsed_message = json_format_proto3_pb2.TestWrapper()
+ self.CheckParseBack(message, parsed_message)
+
+ def testStructMessage(self):
+ message = json_format_proto3_pb2.TestStruct()
+ message.value['name'] = 'Jim'
+ message.value['age'] = 10
+ message.value['attend'] = True
+ message.value['email'] = None
+ message.value.get_or_create_struct('address')['city'] = 'SFO'
+ message.value['address']['house_number'] = 1024
+ struct_list = message.value.get_or_create_list('list')
+ struct_list.extend([6, 'seven', True, False, None])
+ struct_list.add_struct()['subkey2'] = 9
+ message.repeated_value.add()['age'] = 11
+ message.repeated_value.add()
+ self.assertEqual(
+ json.loads(json_format.MessageToJson(message, False)),
+ json.loads(
+ '{'
+ ' "value": {'
+ ' "address": {'
+ ' "city": "SFO", '
+ ' "house_number": 1024'
+ ' }, '
+ ' "age": 10, '
+ ' "name": "Jim", '
+ ' "attend": true, '
+ ' "email": null, '
+ ' "list": [6, "seven", true, false, null, {"subkey2": 9}]'
+ ' },'
+ ' "repeatedValue": [{"age": 11}, {}]'
+ '}'))
+ parsed_message = json_format_proto3_pb2.TestStruct()
+ self.CheckParseBack(message, parsed_message)
+
+ def testValueMessage(self):
+ message = json_format_proto3_pb2.TestValue()
+ message.value.string_value = 'hello'
+ message.repeated_value.add().number_value = 11.1
+ message.repeated_value.add().bool_value = False
+ message.repeated_value.add().null_value = 0
+ self.assertEqual(
+ json.loads(json_format.MessageToJson(message, False)),
+ json.loads(
+ '{'
+ ' "value": "hello",'
+ ' "repeatedValue": [11.1, false, null]'
+ '}'))
+ parsed_message = json_format_proto3_pb2.TestValue()
+ self.CheckParseBack(message, parsed_message)
+ # Can't parse back if the Value message is not set.
+ message.repeated_value.add()
+ self.assertEqual(
+ json.loads(json_format.MessageToJson(message, False)),
+ json.loads(
+ '{'
+ ' "value": "hello",'
+ ' "repeatedValue": [11.1, false, null, null]'
+ '}'))
+
+ def testListValueMessage(self):
+ message = json_format_proto3_pb2.TestListValue()
+ message.value.values.add().number_value = 11.1
+ message.value.values.add().null_value = 0
+ message.value.values.add().bool_value = True
+ message.value.values.add().string_value = 'hello'
+ message.value.values.add().struct_value['name'] = 'Jim'
+ message.repeated_value.add().values.add().number_value = 1
+ message.repeated_value.add()
+ self.assertEqual(
+ json.loads(json_format.MessageToJson(message, False)),
+ json.loads(
+ '{"value": [11.1, null, true, "hello", {"name": "Jim"}]\n,'
+ '"repeatedValue": [[1], []]}'))
+ parsed_message = json_format_proto3_pb2.TestListValue()
+ self.CheckParseBack(message, parsed_message)
+
+ def testAnyMessage(self):
+ message = json_format_proto3_pb2.TestAny()
+ value1 = json_format_proto3_pb2.MessageType()
+ value2 = json_format_proto3_pb2.MessageType()
+ value1.value = 1234
+ value2.value = 5678
+ message.value.Pack(value1)
+ message.repeated_value.add().Pack(value1)
+ message.repeated_value.add().Pack(value2)
+ message.repeated_value.add()
+ self.assertEqual(
+ json.loads(json_format.MessageToJson(message, True)),
+ json.loads(
+ '{\n'
+ ' "repeatedValue": [ {\n'
+ ' "@type": "type.googleapis.com/proto3.MessageType",\n'
+ ' "value": 1234\n'
+ ' }, {\n'
+ ' "@type": "type.googleapis.com/proto3.MessageType",\n'
+ ' "value": 5678\n'
+ ' },\n'
+ ' {}],\n'
+ ' "value": {\n'
+ ' "@type": "type.googleapis.com/proto3.MessageType",\n'
+ ' "value": 1234\n'
+ ' }\n'
+ '}\n'))
+ parsed_message = json_format_proto3_pb2.TestAny()
+ self.CheckParseBack(message, parsed_message)
+
+ def testWellKnownInAnyMessage(self):
+ message = any_pb2.Any()
+ int32_value = wrappers_pb2.Int32Value()
+ int32_value.value = 1234
+ message.Pack(int32_value)
+ self.assertEqual(
+ json.loads(json_format.MessageToJson(message, True)),
+ json.loads(
+ '{\n'
+ ' "@type": \"type.googleapis.com/google.protobuf.Int32Value\",\n'
+ ' "value": 1234\n'
+ '}\n'))
+ parsed_message = any_pb2.Any()
+ self.CheckParseBack(message, parsed_message)
+
+ timestamp = timestamp_pb2.Timestamp()
+ message.Pack(timestamp)
+ self.assertEqual(
+ json.loads(json_format.MessageToJson(message, True)),
+ json.loads(
+ '{\n'
+ ' "@type": "type.googleapis.com/google.protobuf.Timestamp",\n'
+ ' "value": "1970-01-01T00:00:00Z"\n'
+ '}\n'))
+ self.CheckParseBack(message, parsed_message)
+
+ duration = duration_pb2.Duration()
+ duration.seconds = 1
+ message.Pack(duration)
+ self.assertEqual(
+ json.loads(json_format.MessageToJson(message, True)),
+ json.loads(
+ '{\n'
+ ' "@type": "type.googleapis.com/google.protobuf.Duration",\n'
+ ' "value": "1s"\n'
+ '}\n'))
+ self.CheckParseBack(message, parsed_message)
+
+ field_mask = field_mask_pb2.FieldMask()
+ field_mask.paths.append('foo.bar')
+ field_mask.paths.append('bar')
+ message.Pack(field_mask)
+ self.assertEqual(
+ json.loads(json_format.MessageToJson(message, True)),
+ json.loads(
+ '{\n'
+ ' "@type": "type.googleapis.com/google.protobuf.FieldMask",\n'
+ ' "value": "foo.bar,bar"\n'
+ '}\n'))
+ self.CheckParseBack(message, parsed_message)
+
+ struct_message = struct_pb2.Struct()
+ struct_message['name'] = 'Jim'
+ message.Pack(struct_message)
+ self.assertEqual(
+ json.loads(json_format.MessageToJson(message, True)),
+ json.loads(
+ '{\n'
+ ' "@type": "type.googleapis.com/google.protobuf.Struct",\n'
+ ' "value": {"name": "Jim"}\n'
+ '}\n'))
+ self.CheckParseBack(message, parsed_message)
+
+ nested_any = any_pb2.Any()
+ int32_value.value = 5678
+ nested_any.Pack(int32_value)
+ message.Pack(nested_any)
+ self.assertEqual(
+ json.loads(json_format.MessageToJson(message, True)),
+ json.loads(
+ '{\n'
+ ' "@type": "type.googleapis.com/google.protobuf.Any",\n'
+ ' "value": {\n'
+ ' "@type": "type.googleapis.com/google.protobuf.Int32Value",\n'
+ ' "value": 5678\n'
+ ' }\n'
+ '}\n'))
+ self.CheckParseBack(message, parsed_message)
+
+ def testParseNull(self):
+ message = json_format_proto3_pb2.TestMessage()
+ parsed_message = json_format_proto3_pb2.TestMessage()
+ self.FillAllFields(parsed_message)
+ json_format.Parse('{"int32Value": null, '
+ '"int64Value": null, '
+ '"uint32Value": null,'
+ '"uint64Value": null,'
+ '"floatValue": null,'
+ '"doubleValue": null,'
+ '"boolValue": null,'
+ '"stringValue": null,'
+ '"bytesValue": null,'
+ '"messageValue": null,'
+ '"enumValue": null,'
+ '"repeatedInt32Value": null,'
+ '"repeatedInt64Value": null,'
+ '"repeatedUint32Value": null,'
+ '"repeatedUint64Value": null,'
+ '"repeatedFloatValue": null,'
+ '"repeatedDoubleValue": null,'
+ '"repeatedBoolValue": null,'
+ '"repeatedStringValue": null,'
+ '"repeatedBytesValue": null,'
+ '"repeatedMessageValue": null,'
+ '"repeatedEnumValue": null'
+ '}',
+ parsed_message)
+ self.assertEqual(message, parsed_message)
+ self.assertRaisesRegexp(
+ json_format.ParseError,
+ 'Failed to parse repeatedInt32Value field: '
+ 'null is not allowed to be used as an element in a repeated field.',
+ json_format.Parse,
+ '{"repeatedInt32Value":[1, null]}',
+ parsed_message)
+
+ def testNanFloat(self):
+ message = json_format_proto3_pb2.TestMessage()
+ message.float_value = float('nan')
+ text = '{\n "floatValue": "NaN"\n}'
+ self.assertEqual(json_format.MessageToJson(message), text)
+ parsed_message = json_format_proto3_pb2.TestMessage()
+ json_format.Parse(text, parsed_message)
+ self.assertTrue(math.isnan(parsed_message.float_value))
+
+ def testParseEmptyText(self):
+ self.CheckError('',
+ r'Failed to load JSON: (Expecting value)|(No JSON).')
+
+ def testParseBadEnumValue(self):
+ self.CheckError(
+ '{"enumValue": 1}',
+ 'Enum value must be a string literal with double quotes. '
+ 'Type "proto3.EnumType" has no value named 1.')
+ self.CheckError(
+ '{"enumValue": "baz"}',
+ 'Enum value must be a string literal with double quotes. '
+ 'Type "proto3.EnumType" has no value named baz.')
+
+ def testParseBadIdentifer(self):
+ self.CheckError('{int32Value: 1}',
+ (r'Failed to load JSON: Expecting property name'
+ r'( enclosed in double quotes)?: line 1'))
+ self.CheckError('{"unknownName": 1}',
+ 'Message type "proto3.TestMessage" has no field named '
+ '"unknownName".')
+
+ def testDuplicateField(self):
+ # Duplicate key check is not supported for python2.6
+ if sys.version_info < (2, 7):
+ return
+ self.CheckError('{"int32Value": 1,\n"int32Value":2}',
+ 'Failed to load JSON: duplicate key int32Value.')
+
+ def testInvalidBoolValue(self):
+ self.CheckError('{"boolValue": 1}',
+ 'Failed to parse boolValue field: '
+ 'Expected true or false without quotes.')
+ self.CheckError('{"boolValue": "true"}',
+ 'Failed to parse boolValue field: '
+ 'Expected true or false without quotes.')
+
+ def testInvalidIntegerValue(self):
+ message = json_format_proto3_pb2.TestMessage()
+ text = '{"int32Value": 0x12345}'
+ self.assertRaises(json_format.ParseError,
+ json_format.Parse, text, message)
+ self.CheckError('{"int32Value": 012345}',
+ (r'Failed to load JSON: Expecting \'?,\'? delimiter: '
+ r'line 1.'))
+ self.CheckError('{"int32Value": 1.0}',
+ 'Failed to parse int32Value field: '
+ 'Couldn\'t parse integer: 1.0.')
+ self.CheckError('{"int32Value": " 1 "}',
+ 'Failed to parse int32Value field: '
+ 'Couldn\'t parse integer: " 1 ".')
+ self.CheckError('{"int32Value": "1 "}',
+ 'Failed to parse int32Value field: '
+ 'Couldn\'t parse integer: "1 ".')
+ self.CheckError('{"int32Value": 12345678901234567890}',
+ 'Failed to parse int32Value field: Value out of range: '
+ '12345678901234567890.')
+ self.CheckError('{"int32Value": 1e5}',
+ 'Failed to parse int32Value field: '
+ 'Couldn\'t parse integer: 100000.0.')
+ self.CheckError('{"uint32Value": -1}',
+ 'Failed to parse uint32Value field: '
+ 'Value out of range: -1.')
+
+ def testInvalidFloatValue(self):
+ self.CheckError('{"floatValue": "nan"}',
+ 'Failed to parse floatValue field: Couldn\'t '
+ 'parse float "nan", use "NaN" instead.')
+
+ def testInvalidBytesValue(self):
+ self.CheckError('{"bytesValue": "AQI"}',
+ 'Failed to parse bytesValue field: Incorrect padding.')
+ self.CheckError('{"bytesValue": "AQI*"}',
+ 'Failed to parse bytesValue field: Incorrect padding.')
+
+ def testInvalidMap(self):
+ message = json_format_proto3_pb2.TestMap()
+ text = '{"int32Map": {"null": 2, "2": 3}}'
+ self.assertRaisesRegexp(
+ json_format.ParseError,
+ 'Failed to parse int32Map field: invalid literal',
+ json_format.Parse, text, message)
+ text = '{"int32Map": {1: 2, "2": 3}}'
+ self.assertRaisesRegexp(
+ json_format.ParseError,
+ (r'Failed to load JSON: Expecting property name'
+ r'( enclosed in double quotes)?: line 1'),
+ json_format.Parse, text, message)
+ text = '{"boolMap": {"null": 1}}'
+ self.assertRaisesRegexp(
+ json_format.ParseError,
+ 'Failed to parse boolMap field: Expected "true" or "false", not null.',
+ json_format.Parse, text, message)
+ if sys.version_info < (2, 7):
+ return
+ text = r'{"stringMap": {"a": 3, "\u0061": 2}}'
+ self.assertRaisesRegexp(
+ json_format.ParseError,
+ 'Failed to load JSON: duplicate key a',
+ json_format.Parse, text, message)
+
+ def testInvalidTimestamp(self):
+ message = json_format_proto3_pb2.TestTimestamp()
+ text = '{"value": "10000-01-01T00:00:00.00Z"}'
+ self.assertRaisesRegexp(
+ json_format.ParseError,
+ 'time data \'10000-01-01T00:00:00\' does not match'
+ ' format \'%Y-%m-%dT%H:%M:%S\'.',
+ json_format.Parse, text, message)
+ text = '{"value": "1970-01-01T00:00:00.0123456789012Z"}'
+ self.assertRaisesRegexp(
+ well_known_types.ParseError,
+ 'nanos 0123456789012 more than 9 fractional digits.',
+ json_format.Parse, text, message)
+ text = '{"value": "1972-01-01T01:00:00.01+08"}'
+ self.assertRaisesRegexp(
+ well_known_types.ParseError,
+ (r'Invalid timezone offset value: \+08.'),
+ json_format.Parse, text, message)
+ # Time smaller than minimum time.
+ text = '{"value": "0000-01-01T00:00:00Z"}'
+ self.assertRaisesRegexp(
+ json_format.ParseError,
+ 'Failed to parse value field: year is out of range.',
+ json_format.Parse, text, message)
+ # Time bigger than maxinum time.
+ message.value.seconds = 253402300800
+ self.assertRaisesRegexp(
+ OverflowError,
+ 'date value out of range',
+ json_format.MessageToJson, message)
+
+ def testInvalidOneof(self):
+ message = json_format_proto3_pb2.TestOneof()
+ text = '{"oneofInt32Value": 1, "oneofStringValue": "2"}'
+ self.assertRaisesRegexp(
+ json_format.ParseError,
+ 'Message type "proto3.TestOneof"'
+ ' should not have multiple "oneof_value" oneof fields.',
+ json_format.Parse, text, message)
+
+ def testInvalidListValue(self):
+ message = json_format_proto3_pb2.TestListValue()
+ text = '{"value": 1234}'
+ self.assertRaisesRegexp(
+ json_format.ParseError,
+ r'Failed to parse value field: ListValue must be in \[\] which is 1234',
+ json_format.Parse, text, message)
+
+ def testInvalidStruct(self):
+ message = json_format_proto3_pb2.TestStruct()
+ text = '{"value": 1234}'
+ self.assertRaisesRegexp(
+ json_format.ParseError,
+ 'Failed to parse value field: Struct must be in a dict which is 1234',
+ json_format.Parse, text, message)
+
+ def testInvalidAny(self):
+ message = any_pb2.Any()
+ text = '{"@type": "type.googleapis.com/google.protobuf.Int32Value"}'
+ self.assertRaisesRegexp(
+ KeyError,
+ 'value',
+ json_format.Parse, text, message)
+ text = '{"value": 1234}'
+ self.assertRaisesRegexp(
+ json_format.ParseError,
+ '@type is missing when parsing any message.',
+ json_format.Parse, text, message)
+ text = '{"@type": "type.googleapis.com/MessageNotExist", "value": 1234}'
+ self.assertRaisesRegexp(
+ TypeError,
+ 'Can not find message descriptor by type_url: '
+ 'type.googleapis.com/MessageNotExist.',
+ json_format.Parse, text, message)
+ # Only last part is to be used: b/25630112
+ text = (r'{"@type": "incorrect.googleapis.com/google.protobuf.Int32Value",'
+ r'"value": 1234}')
+ json_format.Parse(text, message)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/python/google/protobuf/internal/message_factory_python_test.py b/python/google/protobuf/internal/message_factory_python_test.py
deleted file mode 100644
index 85e02b259..000000000
--- a/python/google/protobuf/internal/message_factory_python_test.py
+++ /dev/null
@@ -1,54 +0,0 @@
-#! /usr/bin/python
-#
-# Protocol Buffers - Google's data interchange format
-# Copyright 2008 Google Inc. All rights reserved.
-# https://developers.google.com/protocol-buffers/
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following disclaimer
-# in the documentation and/or other materials provided with the
-# distribution.
-# * Neither the name of Google Inc. nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-"""Tests for ..public.message_factory for the pure Python implementation."""
-
-import os
-os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'
-
-# We must set the implementation version above before the google3 imports.
-# pylint: disable=g-import-not-at-top
-from google.apputils import basetest
-from google.protobuf.internal import api_implementation
-# Run all tests from the original module by putting them in our namespace.
-# pylint: disable=wildcard-import
-from google.protobuf.internal.message_factory_test import *
-
-
-class ConfirmPurePythonTest(basetest.TestCase):
-
- def testImplementationSetting(self):
- self.assertEqual('python', api_implementation.Type())
-
-
-if __name__ == '__main__':
- basetest.main()
diff --git a/python/google/protobuf/internal/message_factory_test.py b/python/google/protobuf/internal/message_factory_test.py
index fcf134103..7bb7d1ace 100644
--- a/python/google/protobuf/internal/message_factory_test.py
+++ b/python/google/protobuf/internal/message_factory_test.py
@@ -1,4 +1,4 @@
-#! /usr/bin/python
+#! /usr/bin/env python
#
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
@@ -34,7 +34,11 @@
__author__ = 'matthewtoia@google.com (Matt Toia)'
-from google.apputils import basetest
+try:
+ import unittest2 as unittest #PY26
+except ImportError:
+ import unittest
+
from google.protobuf import descriptor_pb2
from google.protobuf.internal import factory_test1_pb2
from google.protobuf.internal import factory_test2_pb2
@@ -43,7 +47,7 @@ from google.protobuf import descriptor_pool
from google.protobuf import message_factory
-class MessageFactoryTest(basetest.TestCase):
+class MessageFactoryTest(unittest.TestCase):
def setUp(self):
self.factory_test1_fd = descriptor_pb2.FileDescriptorProto.FromString(
@@ -81,9 +85,9 @@ class MessageFactoryTest(basetest.TestCase):
serialized = msg.SerializeToString()
converted = factory_test2_pb2.Factory2Message.FromString(serialized)
reserialized = converted.SerializeToString()
- self.assertEquals(serialized, reserialized)
+ self.assertEqual(serialized, reserialized)
result = cls.FromString(reserialized)
- self.assertEquals(msg, result)
+ self.assertEqual(msg, result)
def testGetPrototype(self):
db = descriptor_database.DescriptorDatabase()
@@ -93,28 +97,29 @@ class MessageFactoryTest(basetest.TestCase):
factory = message_factory.MessageFactory()
cls = factory.GetPrototype(pool.FindMessageTypeByName(
'google.protobuf.python.internal.Factory2Message'))
- self.assertIsNot(cls, factory_test2_pb2.Factory2Message)
+ self.assertFalse(cls is factory_test2_pb2.Factory2Message)
self._ExerciseDynamicClass(cls)
cls2 = factory.GetPrototype(pool.FindMessageTypeByName(
'google.protobuf.python.internal.Factory2Message'))
- self.assertIs(cls, cls2)
+ self.assertTrue(cls is cls2)
def testGetMessages(self):
# performed twice because multiple calls with the same input must be allowed
for _ in range(2):
- messages = message_factory.GetMessages([self.factory_test2_fd,
- self.factory_test1_fd])
- self.assertContainsSubset(
- ['google.protobuf.python.internal.Factory2Message',
- 'google.protobuf.python.internal.Factory1Message'],
- messages.keys())
+ messages = message_factory.GetMessages([self.factory_test1_fd,
+ self.factory_test2_fd])
+ self.assertTrue(
+ set(['google.protobuf.python.internal.Factory2Message',
+ 'google.protobuf.python.internal.Factory1Message'],
+ ).issubset(set(messages.keys())))
self._ExerciseDynamicClass(
messages['google.protobuf.python.internal.Factory2Message'])
- self.assertContainsSubset(
- ['google.protobuf.python.internal.Factory2Message.one_more_field',
- 'google.protobuf.python.internal.another_field'],
- (messages['google.protobuf.python.internal.Factory1Message']
- ._extensions_by_name.keys()))
+ self.assertTrue(
+ set(['google.protobuf.python.internal.Factory2Message.one_more_field',
+ 'google.protobuf.python.internal.another_field'],
+ ).issubset(
+ set(messages['google.protobuf.python.internal.Factory1Message']
+ ._extensions_by_name.keys())))
factory_msg1 = messages['google.protobuf.python.internal.Factory1Message']
msg1 = messages['google.protobuf.python.internal.Factory1Message']()
ext1 = factory_msg1._extensions_by_name[
@@ -123,9 +128,63 @@ class MessageFactoryTest(basetest.TestCase):
'google.protobuf.python.internal.another_field']
msg1.Extensions[ext1] = 'test1'
msg1.Extensions[ext2] = 'test2'
- self.assertEquals('test1', msg1.Extensions[ext1])
- self.assertEquals('test2', msg1.Extensions[ext2])
+ self.assertEqual('test1', msg1.Extensions[ext1])
+ self.assertEqual('test2', msg1.Extensions[ext2])
+
+ def testDuplicateExtensionNumber(self):
+ pool = descriptor_pool.DescriptorPool()
+ factory = message_factory.MessageFactory(pool=pool)
+
+ # Add Container message.
+ f = descriptor_pb2.FileDescriptorProto()
+ f.name = 'google/protobuf/internal/container.proto'
+ f.package = 'google.protobuf.python.internal'
+ msg = f.message_type.add()
+ msg.name = 'Container'
+ rng = msg.extension_range.add()
+ rng.start = 1
+ rng.end = 10
+ pool.Add(f)
+ msgs = factory.GetMessages([f.name])
+ self.assertIn('google.protobuf.python.internal.Container', msgs)
+
+ # Extend container.
+ f = descriptor_pb2.FileDescriptorProto()
+ f.name = 'google/protobuf/internal/extension.proto'
+ f.package = 'google.protobuf.python.internal'
+ f.dependency.append('google/protobuf/internal/container.proto')
+ msg = f.message_type.add()
+ msg.name = 'Extension'
+ ext = msg.extension.add()
+ ext.name = 'extension_field'
+ ext.number = 2
+ ext.label = descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL
+ ext.type_name = 'Extension'
+ ext.extendee = 'Container'
+ pool.Add(f)
+ msgs = factory.GetMessages([f.name])
+ self.assertIn('google.protobuf.python.internal.Extension', msgs)
+
+ # Add Duplicate extending the same field number.
+ f = descriptor_pb2.FileDescriptorProto()
+ f.name = 'google/protobuf/internal/duplicate.proto'
+ f.package = 'google.protobuf.python.internal'
+ f.dependency.append('google/protobuf/internal/container.proto')
+ msg = f.message_type.add()
+ msg.name = 'Duplicate'
+ ext = msg.extension.add()
+ ext.name = 'extension_field'
+ ext.number = 2
+ ext.label = descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL
+ ext.type_name = 'Duplicate'
+ ext.extendee = 'Container'
+ pool.Add(f)
+
+ with self.assertRaises(Exception) as cm:
+ factory.GetMessages([f.name])
+
+ self.assertIsInstance(cm.exception, (AssertionError, ValueError))
if __name__ == '__main__':
- basetest.main()
+ unittest.main()
diff --git a/python/google/protobuf/internal/message_set_extensions.proto b/python/google/protobuf/internal/message_set_extensions.proto
new file mode 100644
index 000000000..14e5f1937
--- /dev/null
+++ b/python/google/protobuf/internal/message_set_extensions.proto
@@ -0,0 +1,74 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// https://developers.google.com/protocol-buffers/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+// This file contains messages that extend MessageSet.
+
+syntax = "proto2";
+package google.protobuf.internal;
+
+
+// A message with message_set_wire_format.
+message TestMessageSet {
+ option message_set_wire_format = true;
+ extensions 4 to max;
+}
+
+message TestMessageSetExtension1 {
+ extend TestMessageSet {
+ optional TestMessageSetExtension1 message_set_extension = 98418603;
+ }
+ optional int32 i = 15;
+}
+
+message TestMessageSetExtension2 {
+ extend TestMessageSet {
+ optional TestMessageSetExtension2 message_set_extension = 98418634;
+ }
+ optional string str = 25;
+}
+
+message TestMessageSetExtension3 {
+ optional string text = 35;
+}
+
+extend TestMessageSet {
+ optional TestMessageSetExtension3 message_set_extension3 = 98418655;
+}
+
+// This message was used to generate
+// //net/proto2/python/internal/testdata/message_set_message, but is commented
+// out since it must not actually exist in code, to simulate an "unknown"
+// extension.
+// message TestMessageSetUnknownExtension {
+// extend TestMessageSet {
+// optional TestMessageSetUnknownExtension message_set_extension = 56141421;
+// }
+// optional int64 a = 1;
+// }
diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py
index 48b7ffd4e..4ee31d8ed 100755
--- a/python/google/protobuf/internal/message_test.py
+++ b/python/google/protobuf/internal/message_test.py
@@ -1,4 +1,4 @@
-#! /usr/bin/python
+#! /usr/bin/env python
#
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
@@ -43,17 +43,36 @@ abstract interface.
__author__ = 'gps@google.com (Gregory P. Smith)'
+
+import collections
import copy
import math
import operator
import pickle
+import six
import sys
-from google.apputils import basetest
+try:
+ import unittest2 as unittest #PY26
+except ImportError:
+ import unittest
+
+from google.protobuf import map_unittest_pb2
from google.protobuf import unittest_pb2
+from google.protobuf import unittest_proto3_arena_pb2
+from google.protobuf import descriptor_pb2
+from google.protobuf import descriptor_pool
+from google.protobuf import message_factory
+from google.protobuf import text_format
from google.protobuf.internal import api_implementation
+from google.protobuf.internal import packed_field_test_pb2
from google.protobuf.internal import test_util
from google.protobuf import message
+from google.protobuf.internal import _parameterized
+
+if six.PY3:
+ long = int
+
# Python pre-2.6 does not have isinf() or isnan() functions, so we have
# to provide our own.
@@ -69,88 +88,72 @@ def IsNegInf(val):
return isinf(val) and (val < 0)
-class MessageTest(basetest.TestCase):
+@_parameterized.Parameters(
+ (unittest_pb2),
+ (unittest_proto3_arena_pb2))
+class MessageTest(unittest.TestCase):
- def testBadUtf8String(self):
+ def testBadUtf8String(self, message_module):
if api_implementation.Type() != 'python':
self.skipTest("Skipping testBadUtf8String, currently only the python "
"api implementation raises UnicodeDecodeError when a "
"string field contains bad utf-8.")
bad_utf8_data = test_util.GoldenFileData('bad_utf8_string')
with self.assertRaises(UnicodeDecodeError) as context:
- unittest_pb2.TestAllTypes.FromString(bad_utf8_data)
- self.assertIn('field: protobuf_unittest.TestAllTypes.optional_string',
- str(context.exception))
-
- def testGoldenMessage(self):
- golden_data = test_util.GoldenFileData(
- 'golden_message_oneof_implemented')
- golden_message = unittest_pb2.TestAllTypes()
- golden_message.ParseFromString(golden_data)
- test_util.ExpectAllFieldsSet(self, golden_message)
- self.assertEqual(golden_data, golden_message.SerializeToString())
- golden_copy = copy.deepcopy(golden_message)
- self.assertEqual(golden_data, golden_copy.SerializeToString())
+ message_module.TestAllTypes.FromString(bad_utf8_data)
+ self.assertIn('TestAllTypes.optional_string', str(context.exception))
+
+ def testGoldenMessage(self, message_module):
+ # Proto3 doesn't have the "default_foo" members or foreign enums,
+ # and doesn't preserve unknown fields, so for proto3 we use a golden
+ # message that doesn't have these fields set.
+ if message_module is unittest_pb2:
+ golden_data = test_util.GoldenFileData(
+ 'golden_message_oneof_implemented')
+ else:
+ golden_data = test_util.GoldenFileData('golden_message_proto3')
- def testGoldenExtensions(self):
- golden_data = test_util.GoldenFileData('golden_message')
- golden_message = unittest_pb2.TestAllExtensions()
+ golden_message = message_module.TestAllTypes()
golden_message.ParseFromString(golden_data)
- all_set = unittest_pb2.TestAllExtensions()
- test_util.SetAllExtensions(all_set)
- self.assertEquals(all_set, golden_message)
+ if message_module is unittest_pb2:
+ test_util.ExpectAllFieldsSet(self, golden_message)
self.assertEqual(golden_data, golden_message.SerializeToString())
golden_copy = copy.deepcopy(golden_message)
self.assertEqual(golden_data, golden_copy.SerializeToString())
- def testGoldenPackedMessage(self):
+ def testGoldenPackedMessage(self, message_module):
golden_data = test_util.GoldenFileData('golden_packed_fields_message')
- golden_message = unittest_pb2.TestPackedTypes()
+ golden_message = message_module.TestPackedTypes()
golden_message.ParseFromString(golden_data)
- all_set = unittest_pb2.TestPackedTypes()
+ all_set = message_module.TestPackedTypes()
test_util.SetAllPackedFields(all_set)
- self.assertEquals(all_set, golden_message)
+ self.assertEqual(all_set, golden_message)
self.assertEqual(golden_data, all_set.SerializeToString())
golden_copy = copy.deepcopy(golden_message)
self.assertEqual(golden_data, golden_copy.SerializeToString())
- def testGoldenPackedExtensions(self):
- golden_data = test_util.GoldenFileData('golden_packed_fields_message')
- golden_message = unittest_pb2.TestPackedExtensions()
- golden_message.ParseFromString(golden_data)
- all_set = unittest_pb2.TestPackedExtensions()
- test_util.SetAllPackedExtensions(all_set)
- self.assertEquals(all_set, golden_message)
- self.assertEqual(golden_data, all_set.SerializeToString())
- golden_copy = copy.deepcopy(golden_message)
- self.assertEqual(golden_data, golden_copy.SerializeToString())
-
- def testPickleSupport(self):
+ def testPickleSupport(self, message_module):
golden_data = test_util.GoldenFileData('golden_message')
- golden_message = unittest_pb2.TestAllTypes()
+ golden_message = message_module.TestAllTypes()
golden_message.ParseFromString(golden_data)
pickled_message = pickle.dumps(golden_message)
unpickled_message = pickle.loads(pickled_message)
- self.assertEquals(unpickled_message, golden_message)
-
-
- def testPickleIncompleteProto(self):
- golden_message = unittest_pb2.TestRequired(a=1)
- pickled_message = pickle.dumps(golden_message)
-
- unpickled_message = pickle.loads(pickled_message)
- self.assertEquals(unpickled_message, golden_message)
- self.assertEquals(unpickled_message.a, 1)
- # This is still an incomplete proto - so serializing should fail
- self.assertRaises(message.EncodeError, unpickled_message.SerializeToString)
+ self.assertEqual(unpickled_message, golden_message)
+
+ def testPositiveInfinity(self, message_module):
+ if message_module is unittest_pb2:
+ golden_data = (b'\x5D\x00\x00\x80\x7F'
+ b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F'
+ b'\xCD\x02\x00\x00\x80\x7F'
+ b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\x7F')
+ else:
+ golden_data = (b'\x5D\x00\x00\x80\x7F'
+ b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F'
+ b'\xCA\x02\x04\x00\x00\x80\x7F'
+ b'\xD2\x02\x08\x00\x00\x00\x00\x00\x00\xF0\x7F')
- def testPositiveInfinity(self):
- golden_data = (b'\x5D\x00\x00\x80\x7F'
- b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F'
- b'\xCD\x02\x00\x00\x80\x7F'
- b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\x7F')
- golden_message = unittest_pb2.TestAllTypes()
+ golden_message = message_module.TestAllTypes()
golden_message.ParseFromString(golden_data)
self.assertTrue(IsPosInf(golden_message.optional_float))
self.assertTrue(IsPosInf(golden_message.optional_double))
@@ -158,12 +161,19 @@ class MessageTest(basetest.TestCase):
self.assertTrue(IsPosInf(golden_message.repeated_double[0]))
self.assertEqual(golden_data, golden_message.SerializeToString())
- def testNegativeInfinity(self):
- golden_data = (b'\x5D\x00\x00\x80\xFF'
- b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF'
- b'\xCD\x02\x00\x00\x80\xFF'
- b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\xFF')
- golden_message = unittest_pb2.TestAllTypes()
+ def testNegativeInfinity(self, message_module):
+ if message_module is unittest_pb2:
+ golden_data = (b'\x5D\x00\x00\x80\xFF'
+ b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF'
+ b'\xCD\x02\x00\x00\x80\xFF'
+ b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\xFF')
+ else:
+ golden_data = (b'\x5D\x00\x00\x80\xFF'
+ b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF'
+ b'\xCA\x02\x04\x00\x00\x80\xFF'
+ b'\xD2\x02\x08\x00\x00\x00\x00\x00\x00\xF0\xFF')
+
+ golden_message = message_module.TestAllTypes()
golden_message.ParseFromString(golden_data)
self.assertTrue(IsNegInf(golden_message.optional_float))
self.assertTrue(IsNegInf(golden_message.optional_double))
@@ -171,12 +181,12 @@ class MessageTest(basetest.TestCase):
self.assertTrue(IsNegInf(golden_message.repeated_double[0]))
self.assertEqual(golden_data, golden_message.SerializeToString())
- def testNotANumber(self):
+ def testNotANumber(self, message_module):
golden_data = (b'\x5D\x00\x00\xC0\x7F'
b'\x61\x00\x00\x00\x00\x00\x00\xF8\x7F'
b'\xCD\x02\x00\x00\xC0\x7F'
b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF8\x7F')
- golden_message = unittest_pb2.TestAllTypes()
+ golden_message = message_module.TestAllTypes()
golden_message.ParseFromString(golden_data)
self.assertTrue(isnan(golden_message.optional_float))
self.assertTrue(isnan(golden_message.optional_double))
@@ -188,47 +198,47 @@ class MessageTest(basetest.TestCase):
# verify the serialized string can be converted into a correctly
# behaving protocol buffer.
serialized = golden_message.SerializeToString()
- message = unittest_pb2.TestAllTypes()
+ message = message_module.TestAllTypes()
message.ParseFromString(serialized)
self.assertTrue(isnan(message.optional_float))
self.assertTrue(isnan(message.optional_double))
self.assertTrue(isnan(message.repeated_float[0]))
self.assertTrue(isnan(message.repeated_double[0]))
- def testPositiveInfinityPacked(self):
+ def testPositiveInfinityPacked(self, message_module):
golden_data = (b'\xA2\x06\x04\x00\x00\x80\x7F'
b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\x7F')
- golden_message = unittest_pb2.TestPackedTypes()
+ golden_message = message_module.TestPackedTypes()
golden_message.ParseFromString(golden_data)
self.assertTrue(IsPosInf(golden_message.packed_float[0]))
self.assertTrue(IsPosInf(golden_message.packed_double[0]))
self.assertEqual(golden_data, golden_message.SerializeToString())
- def testNegativeInfinityPacked(self):
+ def testNegativeInfinityPacked(self, message_module):
golden_data = (b'\xA2\x06\x04\x00\x00\x80\xFF'
b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\xFF')
- golden_message = unittest_pb2.TestPackedTypes()
+ golden_message = message_module.TestPackedTypes()
golden_message.ParseFromString(golden_data)
self.assertTrue(IsNegInf(golden_message.packed_float[0]))
self.assertTrue(IsNegInf(golden_message.packed_double[0]))
self.assertEqual(golden_data, golden_message.SerializeToString())
- def testNotANumberPacked(self):
+ def testNotANumberPacked(self, message_module):
golden_data = (b'\xA2\x06\x04\x00\x00\xC0\x7F'
b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF8\x7F')
- golden_message = unittest_pb2.TestPackedTypes()
+ golden_message = message_module.TestPackedTypes()
golden_message.ParseFromString(golden_data)
self.assertTrue(isnan(golden_message.packed_float[0]))
self.assertTrue(isnan(golden_message.packed_double[0]))
serialized = golden_message.SerializeToString()
- message = unittest_pb2.TestPackedTypes()
+ message = message_module.TestPackedTypes()
message.ParseFromString(serialized)
self.assertTrue(isnan(message.packed_float[0]))
self.assertTrue(isnan(message.packed_double[0]))
- def testExtremeFloatValues(self):
- message = unittest_pb2.TestAllTypes()
+ def testExtremeFloatValues(self, message_module):
+ message = message_module.TestAllTypes()
# Most positive exponent, no significand bits set.
kMostPosExponentNoSigBits = math.pow(2, 127)
@@ -272,8 +282,8 @@ class MessageTest(basetest.TestCase):
message.ParseFromString(message.SerializeToString())
self.assertTrue(message.optional_float == -kMostNegExponentOneSigBit)
- def testExtremeDoubleValues(self):
- message = unittest_pb2.TestAllTypes()
+ def testExtremeDoubleValues(self, message_module):
+ message = message_module.TestAllTypes()
# Most positive exponent, no significand bits set.
kMostPosExponentNoSigBits = math.pow(2, 1023)
@@ -317,29 +327,43 @@ class MessageTest(basetest.TestCase):
message.ParseFromString(message.SerializeToString())
self.assertTrue(message.optional_double == -kMostNegExponentOneSigBit)
- def testFloatPrinting(self):
- message = unittest_pb2.TestAllTypes()
+ def testFloatPrinting(self, message_module):
+ message = message_module.TestAllTypes()
message.optional_float = 2.0
self.assertEqual(str(message), 'optional_float: 2.0\n')
- def testHighPrecisionFloatPrinting(self):
- message = unittest_pb2.TestAllTypes()
+ def testHighPrecisionFloatPrinting(self, message_module):
+ message = message_module.TestAllTypes()
message.optional_double = 0.12345678912345678
- if sys.version_info.major >= 3:
+ if sys.version_info >= (3,):
self.assertEqual(str(message), 'optional_double: 0.12345678912345678\n')
else:
self.assertEqual(str(message), 'optional_double: 0.123456789123\n')
- def testUnknownFieldPrinting(self):
- populated = unittest_pb2.TestAllTypes()
+ def testUnknownFieldPrinting(self, message_module):
+ populated = message_module.TestAllTypes()
test_util.SetAllNonLazyFields(populated)
- empty = unittest_pb2.TestEmptyMessage()
+ empty = message_module.TestEmptyMessage()
empty.ParseFromString(populated.SerializeToString())
self.assertEqual(str(empty), '')
- def testSortingRepeatedScalarFieldsDefaultComparator(self):
+ def testRepeatedNestedFieldIteration(self, message_module):
+ msg = message_module.TestAllTypes()
+ msg.repeated_nested_message.add(bb=1)
+ msg.repeated_nested_message.add(bb=2)
+ msg.repeated_nested_message.add(bb=3)
+ msg.repeated_nested_message.add(bb=4)
+
+ self.assertEqual([1, 2, 3, 4],
+ [m.bb for m in msg.repeated_nested_message])
+ self.assertEqual([4, 3, 2, 1],
+ [m.bb for m in reversed(msg.repeated_nested_message)])
+ self.assertEqual([4, 3, 2, 1],
+ [m.bb for m in msg.repeated_nested_message[::-1]])
+
+ def testSortingRepeatedScalarFieldsDefaultComparator(self, message_module):
"""Check some different types with the default comparator."""
- message = unittest_pb2.TestAllTypes()
+ message = message_module.TestAllTypes()
# TODO(mattp): would testing more scalar types strengthen test?
message.repeated_int32.append(1)
@@ -374,9 +398,9 @@ class MessageTest(basetest.TestCase):
self.assertEqual(message.repeated_bytes[1], b'b')
self.assertEqual(message.repeated_bytes[2], b'c')
- def testSortingRepeatedScalarFieldsCustomComparator(self):
+ def testSortingRepeatedScalarFieldsCustomComparator(self, message_module):
"""Check some different types with custom comparator."""
- message = unittest_pb2.TestAllTypes()
+ message = message_module.TestAllTypes()
message.repeated_int32.append(-3)
message.repeated_int32.append(-2)
@@ -394,9 +418,9 @@ class MessageTest(basetest.TestCase):
self.assertEqual(message.repeated_string[1], 'bb')
self.assertEqual(message.repeated_string[2], 'aaa')
- def testSortingRepeatedCompositeFieldsCustomComparator(self):
+ def testSortingRepeatedCompositeFieldsCustomComparator(self, message_module):
"""Check passing a custom comparator to sort a repeated composite field."""
- message = unittest_pb2.TestAllTypes()
+ message = message_module.TestAllTypes()
message.repeated_nested_message.add().bb = 1
message.repeated_nested_message.add().bb = 3
@@ -412,9 +436,34 @@ class MessageTest(basetest.TestCase):
self.assertEqual(message.repeated_nested_message[4].bb, 5)
self.assertEqual(message.repeated_nested_message[5].bb, 6)
- def testRepeatedCompositeFieldSortArguments(self):
+ def testSortingRepeatedCompositeFieldsStable(self, message_module):
+ """Check passing a custom comparator to sort a repeated composite field."""
+ message = message_module.TestAllTypes()
+
+ message.repeated_nested_message.add().bb = 21
+ message.repeated_nested_message.add().bb = 20
+ message.repeated_nested_message.add().bb = 13
+ message.repeated_nested_message.add().bb = 33
+ message.repeated_nested_message.add().bb = 11
+ message.repeated_nested_message.add().bb = 24
+ message.repeated_nested_message.add().bb = 10
+ message.repeated_nested_message.sort(key=lambda z: z.bb // 10)
+ self.assertEqual(
+ [13, 11, 10, 21, 20, 24, 33],
+ [n.bb for n in message.repeated_nested_message])
+
+ # Make sure that for the C++ implementation, the underlying fields
+ # are actually reordered.
+ pb = message.SerializeToString()
+ message.Clear()
+ message.MergeFromString(pb)
+ self.assertEqual(
+ [13, 11, 10, 21, 20, 24, 33],
+ [n.bb for n in message.repeated_nested_message])
+
+ def testRepeatedCompositeFieldSortArguments(self, message_module):
"""Check sorting a repeated composite field using list.sort() arguments."""
- message = unittest_pb2.TestAllTypes()
+ message = message_module.TestAllTypes()
get_bb = operator.attrgetter('bb')
cmp_bb = lambda a, b: cmp(a.bb, b.bb)
@@ -430,7 +479,7 @@ class MessageTest(basetest.TestCase):
message.repeated_nested_message.sort(key=get_bb, reverse=True)
self.assertEqual([k.bb for k in message.repeated_nested_message],
[6, 5, 4, 3, 2, 1])
- if sys.version_info.major >= 3: return # No cmp sorting in PY3.
+ if sys.version_info >= (3,): return # No cmp sorting in PY3.
message.repeated_nested_message.sort(sort_function=cmp_bb)
self.assertEqual([k.bb for k in message.repeated_nested_message],
[1, 2, 3, 4, 5, 6])
@@ -438,9 +487,9 @@ class MessageTest(basetest.TestCase):
self.assertEqual([k.bb for k in message.repeated_nested_message],
[6, 5, 4, 3, 2, 1])
- def testRepeatedScalarFieldSortArguments(self):
+ def testRepeatedScalarFieldSortArguments(self, message_module):
"""Check sorting a scalar field using list.sort() arguments."""
- message = unittest_pb2.TestAllTypes()
+ message = message_module.TestAllTypes()
message.repeated_int32.append(-3)
message.repeated_int32.append(-2)
@@ -449,7 +498,7 @@ class MessageTest(basetest.TestCase):
self.assertEqual(list(message.repeated_int32), [-1, -2, -3])
message.repeated_int32.sort(key=abs, reverse=True)
self.assertEqual(list(message.repeated_int32), [-3, -2, -1])
- if sys.version_info.major < 3: # No cmp sorting in PY3.
+ if sys.version_info < (3,): # No cmp sorting in PY3.
abs_cmp = lambda a, b: cmp(abs(a), abs(b))
message.repeated_int32.sort(sort_function=abs_cmp)
self.assertEqual(list(message.repeated_int32), [-1, -2, -3])
@@ -463,16 +512,16 @@ class MessageTest(basetest.TestCase):
self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa'])
message.repeated_string.sort(key=len, reverse=True)
self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c'])
- if sys.version_info.major < 3: # No cmp sorting in PY3.
+ if sys.version_info < (3,): # No cmp sorting in PY3.
len_cmp = lambda a, b: cmp(len(a), len(b))
message.repeated_string.sort(sort_function=len_cmp)
self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa'])
message.repeated_string.sort(cmp=len_cmp, reverse=True)
self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c'])
- def testRepeatedFieldsComparable(self):
- m1 = unittest_pb2.TestAllTypes()
- m2 = unittest_pb2.TestAllTypes()
+ def testRepeatedFieldsComparable(self, message_module):
+ m1 = message_module.TestAllTypes()
+ m2 = message_module.TestAllTypes()
m1.repeated_int32.append(0)
m1.repeated_int32.append(1)
m1.repeated_int32.append(2)
@@ -486,7 +535,7 @@ class MessageTest(basetest.TestCase):
m2.repeated_nested_message.add().bb = 2
m2.repeated_nested_message.add().bb = 3
- if sys.version_info.major >= 3: return # No cmp() in PY3.
+ if sys.version_info >= (3,): return # No cmp() in PY3.
# These comparisons should not raise errors.
_ = m1 < m2
@@ -505,54 +554,11 @@ class MessageTest(basetest.TestCase):
# TODO(anuraag): Implement extensiondict comparison in C++ and then add test
- def testParsingMerge(self):
- """Check the merge behavior when a required or optional field appears
- multiple times in the input."""
- messages = [
- unittest_pb2.TestAllTypes(),
- unittest_pb2.TestAllTypes(),
- unittest_pb2.TestAllTypes() ]
- messages[0].optional_int32 = 1
- messages[1].optional_int64 = 2
- messages[2].optional_int32 = 3
- messages[2].optional_string = 'hello'
-
- merged_message = unittest_pb2.TestAllTypes()
- merged_message.optional_int32 = 3
- merged_message.optional_int64 = 2
- merged_message.optional_string = 'hello'
-
- generator = unittest_pb2.TestParsingMerge.RepeatedFieldsGenerator()
- generator.field1.extend(messages)
- generator.field2.extend(messages)
- generator.field3.extend(messages)
- generator.ext1.extend(messages)
- generator.ext2.extend(messages)
- generator.group1.add().field1.MergeFrom(messages[0])
- generator.group1.add().field1.MergeFrom(messages[1])
- generator.group1.add().field1.MergeFrom(messages[2])
- generator.group2.add().field1.MergeFrom(messages[0])
- generator.group2.add().field1.MergeFrom(messages[1])
- generator.group2.add().field1.MergeFrom(messages[2])
-
- data = generator.SerializeToString()
- parsing_merge = unittest_pb2.TestParsingMerge()
- parsing_merge.ParseFromString(data)
-
- # Required and optional fields should be merged.
- self.assertEqual(parsing_merge.required_all_types, merged_message)
- self.assertEqual(parsing_merge.optional_all_types, merged_message)
- self.assertEqual(parsing_merge.optionalgroup.optional_group_all_types,
- merged_message)
- self.assertEqual(parsing_merge.Extensions[
- unittest_pb2.TestParsingMerge.optional_ext],
- merged_message)
-
- # Repeated fields should not be merged.
- self.assertEqual(len(parsing_merge.repeated_all_types), 3)
- self.assertEqual(len(parsing_merge.repeatedgroup), 3)
- self.assertEqual(len(parsing_merge.Extensions[
- unittest_pb2.TestParsingMerge.repeated_ext]), 3)
+ def testRepeatedFieldsAreSequences(self, message_module):
+ m = message_module.TestAllTypes()
+ self.assertIsInstance(m.repeated_int32, collections.MutableSequence)
+ self.assertIsInstance(m.repeated_nested_message,
+ collections.MutableSequence)
def ensureNestedMessageExists(self, msg, attribute):
"""Make sure that a nested message object exists.
@@ -563,12 +569,28 @@ class MessageTest(basetest.TestCase):
getattr(msg, attribute)
self.assertFalse(msg.HasField(attribute))
- def testOneofGetCaseNonexistingField(self):
- m = unittest_pb2.TestAllTypes()
+ def testOneofGetCaseNonexistingField(self, message_module):
+ m = message_module.TestAllTypes()
self.assertRaises(ValueError, m.WhichOneof, 'no_such_oneof_field')
- def testOneofSemantics(self):
- m = unittest_pb2.TestAllTypes()
+ def testOneofDefaultValues(self, message_module):
+ m = message_module.TestAllTypes()
+ self.assertIs(None, m.WhichOneof('oneof_field'))
+ self.assertFalse(m.HasField('oneof_uint32'))
+
+ # Oneof is set even when setting it to a default value.
+ m.oneof_uint32 = 0
+ self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
+ self.assertTrue(m.HasField('oneof_uint32'))
+ self.assertFalse(m.HasField('oneof_string'))
+
+ m.oneof_string = ""
+ self.assertEqual('oneof_string', m.WhichOneof('oneof_field'))
+ self.assertTrue(m.HasField('oneof_string'))
+ self.assertFalse(m.HasField('oneof_uint32'))
+
+ def testOneofSemantics(self, message_module):
+ m = message_module.TestAllTypes()
self.assertIs(None, m.WhichOneof('oneof_field'))
m.oneof_uint32 = 11
@@ -580,6 +602,18 @@ class MessageTest(basetest.TestCase):
self.assertFalse(m.HasField('oneof_uint32'))
self.assertTrue(m.HasField('oneof_string'))
+ # Read nested message accessor without accessing submessage.
+ m.oneof_nested_message
+ self.assertEqual('oneof_string', m.WhichOneof('oneof_field'))
+ self.assertTrue(m.HasField('oneof_string'))
+ self.assertFalse(m.HasField('oneof_nested_message'))
+
+ # Read accessor of nested message without accessing submessage.
+ m.oneof_nested_message.bb
+ self.assertEqual('oneof_string', m.WhichOneof('oneof_field'))
+ self.assertTrue(m.HasField('oneof_string'))
+ self.assertFalse(m.HasField('oneof_nested_message'))
+
m.oneof_nested_message.bb = 11
self.assertEqual('oneof_nested_message', m.WhichOneof('oneof_field'))
self.assertFalse(m.HasField('oneof_string'))
@@ -590,72 +624,1099 @@ class MessageTest(basetest.TestCase):
self.assertFalse(m.HasField('oneof_nested_message'))
self.assertTrue(m.HasField('oneof_bytes'))
- def testOneofCompositeFieldReadAccess(self):
- m = unittest_pb2.TestAllTypes()
+ def testOneofCompositeFieldReadAccess(self, message_module):
+ m = message_module.TestAllTypes()
m.oneof_uint32 = 11
self.ensureNestedMessageExists(m, 'oneof_nested_message')
self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
self.assertEqual(11, m.oneof_uint32)
- def testOneofHasField(self):
- m = unittest_pb2.TestAllTypes()
- self.assertFalse(m.HasField('oneof_field'))
+ def testOneofWhichOneof(self, message_module):
+ m = message_module.TestAllTypes()
+ self.assertIs(None, m.WhichOneof('oneof_field'))
+ if message_module is unittest_pb2:
+ self.assertFalse(m.HasField('oneof_field'))
+
m.oneof_uint32 = 11
- self.assertTrue(m.HasField('oneof_field'))
+ self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
+ if message_module is unittest_pb2:
+ self.assertTrue(m.HasField('oneof_field'))
+
m.oneof_bytes = b'bb'
- self.assertTrue(m.HasField('oneof_field'))
+ self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field'))
+
m.ClearField('oneof_bytes')
- self.assertFalse(m.HasField('oneof_field'))
+ self.assertIs(None, m.WhichOneof('oneof_field'))
+ if message_module is unittest_pb2:
+ self.assertFalse(m.HasField('oneof_field'))
- def testOneofClearField(self):
- m = unittest_pb2.TestAllTypes()
+ def testOneofClearField(self, message_module):
+ m = message_module.TestAllTypes()
m.oneof_uint32 = 11
m.ClearField('oneof_field')
- self.assertFalse(m.HasField('oneof_field'))
+ if message_module is unittest_pb2:
+ self.assertFalse(m.HasField('oneof_field'))
self.assertFalse(m.HasField('oneof_uint32'))
self.assertIs(None, m.WhichOneof('oneof_field'))
- def testOneofClearSetField(self):
- m = unittest_pb2.TestAllTypes()
+ def testOneofClearSetField(self, message_module):
+ m = message_module.TestAllTypes()
m.oneof_uint32 = 11
m.ClearField('oneof_uint32')
- self.assertFalse(m.HasField('oneof_field'))
+ if message_module is unittest_pb2:
+ self.assertFalse(m.HasField('oneof_field'))
self.assertFalse(m.HasField('oneof_uint32'))
self.assertIs(None, m.WhichOneof('oneof_field'))
- def testOneofClearUnsetField(self):
- m = unittest_pb2.TestAllTypes()
+ def testOneofClearUnsetField(self, message_module):
+ m = message_module.TestAllTypes()
m.oneof_uint32 = 11
self.ensureNestedMessageExists(m, 'oneof_nested_message')
m.ClearField('oneof_nested_message')
self.assertEqual(11, m.oneof_uint32)
- self.assertTrue(m.HasField('oneof_field'))
+ if message_module is unittest_pb2:
+ self.assertTrue(m.HasField('oneof_field'))
self.assertTrue(m.HasField('oneof_uint32'))
self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
- def testOneofDeserialize(self):
- m = unittest_pb2.TestAllTypes()
+ def testOneofDeserialize(self, message_module):
+ m = message_module.TestAllTypes()
m.oneof_uint32 = 11
- m2 = unittest_pb2.TestAllTypes()
+ m2 = message_module.TestAllTypes()
m2.ParseFromString(m.SerializeToString())
self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field'))
- def testSortEmptyRepeatedCompositeContainer(self):
+ def testOneofCopyFrom(self, message_module):
+ m = message_module.TestAllTypes()
+ m.oneof_uint32 = 11
+ m2 = message_module.TestAllTypes()
+ m2.CopyFrom(m)
+ self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field'))
+
+ def testOneofNestedMergeFrom(self, message_module):
+ m = message_module.NestedTestAllTypes()
+ m.payload.oneof_uint32 = 11
+ m2 = message_module.NestedTestAllTypes()
+ m2.payload.oneof_bytes = b'bb'
+ m2.child.payload.oneof_bytes = b'bb'
+ m2.MergeFrom(m)
+ self.assertEqual('oneof_uint32', m2.payload.WhichOneof('oneof_field'))
+ self.assertEqual('oneof_bytes', m2.child.payload.WhichOneof('oneof_field'))
+
+ def testOneofMessageMergeFrom(self, message_module):
+ m = message_module.NestedTestAllTypes()
+ m.payload.oneof_nested_message.bb = 11
+ m.child.payload.oneof_nested_message.bb = 12
+ m2 = message_module.NestedTestAllTypes()
+ m2.payload.oneof_uint32 = 13
+ m2.MergeFrom(m)
+ self.assertEqual('oneof_nested_message',
+ m2.payload.WhichOneof('oneof_field'))
+ self.assertEqual('oneof_nested_message',
+ m2.child.payload.WhichOneof('oneof_field'))
+
+ def testOneofNestedMessageInit(self, message_module):
+ m = message_module.TestAllTypes(
+ oneof_nested_message=message_module.TestAllTypes.NestedMessage())
+ self.assertEqual('oneof_nested_message', m.WhichOneof('oneof_field'))
+
+ def testOneofClear(self, message_module):
+ m = message_module.TestAllTypes()
+ m.oneof_uint32 = 11
+ m.Clear()
+ self.assertIsNone(m.WhichOneof('oneof_field'))
+ m.oneof_bytes = b'bb'
+ self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field'))
+
+ def testAssignByteStringToUnicodeField(self, message_module):
+ """Assigning a byte string to a string field should result
+ in the value being converted to a Unicode string."""
+ m = message_module.TestAllTypes()
+ m.optional_string = str('')
+ self.assertIsInstance(m.optional_string, six.text_type)
+
+ def testLongValuedSlice(self, message_module):
+ """It should be possible to use long-valued indicies in slices
+
+ This didn't used to work in the v2 C++ implementation.
+ """
+ m = message_module.TestAllTypes()
+
+ # Repeated scalar
+ m.repeated_int32.append(1)
+ sl = m.repeated_int32[long(0):long(len(m.repeated_int32))]
+ self.assertEqual(len(m.repeated_int32), len(sl))
+
+ # Repeated composite
+ m.repeated_nested_message.add().bb = 3
+ sl = m.repeated_nested_message[long(0):long(len(m.repeated_nested_message))]
+ self.assertEqual(len(m.repeated_nested_message), len(sl))
+
+ def testExtendShouldNotSwallowExceptions(self, message_module):
+ """This didn't use to work in the v2 C++ implementation."""
+ m = message_module.TestAllTypes()
+ with self.assertRaises(NameError) as _:
+ m.repeated_int32.extend(a for i in range(10)) # pylint: disable=undefined-variable
+ with self.assertRaises(NameError) as _:
+ m.repeated_nested_enum.extend(
+ a for i in range(10)) # pylint: disable=undefined-variable
+
+ FALSY_VALUES = [None, False, 0, 0.0, b'', u'', bytearray(), [], {}, set()]
+
+ def testExtendInt32WithNothing(self, message_module):
+ """Test no-ops extending repeated int32 fields."""
+ m = message_module.TestAllTypes()
+ self.assertSequenceEqual([], m.repeated_int32)
+
+ # TODO(ptucker): Deprecate this behavior. b/18413862
+ for falsy_value in MessageTest.FALSY_VALUES:
+ m.repeated_int32.extend(falsy_value)
+ self.assertSequenceEqual([], m.repeated_int32)
+
+ m.repeated_int32.extend([])
+ self.assertSequenceEqual([], m.repeated_int32)
+
+ def testExtendFloatWithNothing(self, message_module):
+ """Test no-ops extending repeated float fields."""
+ m = message_module.TestAllTypes()
+ self.assertSequenceEqual([], m.repeated_float)
+
+ # TODO(ptucker): Deprecate this behavior. b/18413862
+ for falsy_value in MessageTest.FALSY_VALUES:
+ m.repeated_float.extend(falsy_value)
+ self.assertSequenceEqual([], m.repeated_float)
+
+ m.repeated_float.extend([])
+ self.assertSequenceEqual([], m.repeated_float)
+
+ def testExtendStringWithNothing(self, message_module):
+ """Test no-ops extending repeated string fields."""
+ m = message_module.TestAllTypes()
+ self.assertSequenceEqual([], m.repeated_string)
+
+ # TODO(ptucker): Deprecate this behavior. b/18413862
+ for falsy_value in MessageTest.FALSY_VALUES:
+ m.repeated_string.extend(falsy_value)
+ self.assertSequenceEqual([], m.repeated_string)
+
+ m.repeated_string.extend([])
+ self.assertSequenceEqual([], m.repeated_string)
+
+ def testExtendInt32WithPythonList(self, message_module):
+ """Test extending repeated int32 fields with python lists."""
+ m = message_module.TestAllTypes()
+ self.assertSequenceEqual([], m.repeated_int32)
+ m.repeated_int32.extend([0])
+ self.assertSequenceEqual([0], m.repeated_int32)
+ m.repeated_int32.extend([1, 2])
+ self.assertSequenceEqual([0, 1, 2], m.repeated_int32)
+ m.repeated_int32.extend([3, 4])
+ self.assertSequenceEqual([0, 1, 2, 3, 4], m.repeated_int32)
+
+ def testExtendFloatWithPythonList(self, message_module):
+ """Test extending repeated float fields with python lists."""
+ m = message_module.TestAllTypes()
+ self.assertSequenceEqual([], m.repeated_float)
+ m.repeated_float.extend([0.0])
+ self.assertSequenceEqual([0.0], m.repeated_float)
+ m.repeated_float.extend([1.0, 2.0])
+ self.assertSequenceEqual([0.0, 1.0, 2.0], m.repeated_float)
+ m.repeated_float.extend([3.0, 4.0])
+ self.assertSequenceEqual([0.0, 1.0, 2.0, 3.0, 4.0], m.repeated_float)
+
+ def testExtendStringWithPythonList(self, message_module):
+ """Test extending repeated string fields with python lists."""
+ m = message_module.TestAllTypes()
+ self.assertSequenceEqual([], m.repeated_string)
+ m.repeated_string.extend([''])
+ self.assertSequenceEqual([''], m.repeated_string)
+ m.repeated_string.extend(['11', '22'])
+ self.assertSequenceEqual(['', '11', '22'], m.repeated_string)
+ m.repeated_string.extend(['33', '44'])
+ self.assertSequenceEqual(['', '11', '22', '33', '44'], m.repeated_string)
+
+ def testExtendStringWithString(self, message_module):
+ """Test extending repeated string fields with characters from a string."""
+ m = message_module.TestAllTypes()
+ self.assertSequenceEqual([], m.repeated_string)
+ m.repeated_string.extend('abc')
+ self.assertSequenceEqual(['a', 'b', 'c'], m.repeated_string)
+
+ class TestIterable(object):
+ """This iterable object mimics the behavior of numpy.array.
+
+ __nonzero__ fails for length > 1, and returns bool(item[0]) for length == 1.
+
+ """
+
+ def __init__(self, values=None):
+ self._list = values or []
+
+ def __nonzero__(self):
+ size = len(self._list)
+ if size == 0:
+ return False
+ if size == 1:
+ return bool(self._list[0])
+ raise ValueError('Truth value is ambiguous.')
+
+ def __len__(self):
+ return len(self._list)
+
+ def __iter__(self):
+ return self._list.__iter__()
+
+ def testExtendInt32WithIterable(self, message_module):
+ """Test extending repeated int32 fields with iterable."""
+ m = message_module.TestAllTypes()
+ self.assertSequenceEqual([], m.repeated_int32)
+ m.repeated_int32.extend(MessageTest.TestIterable([]))
+ self.assertSequenceEqual([], m.repeated_int32)
+ m.repeated_int32.extend(MessageTest.TestIterable([0]))
+ self.assertSequenceEqual([0], m.repeated_int32)
+ m.repeated_int32.extend(MessageTest.TestIterable([1, 2]))
+ self.assertSequenceEqual([0, 1, 2], m.repeated_int32)
+ m.repeated_int32.extend(MessageTest.TestIterable([3, 4]))
+ self.assertSequenceEqual([0, 1, 2, 3, 4], m.repeated_int32)
+
+ def testExtendFloatWithIterable(self, message_module):
+ """Test extending repeated float fields with iterable."""
+ m = message_module.TestAllTypes()
+ self.assertSequenceEqual([], m.repeated_float)
+ m.repeated_float.extend(MessageTest.TestIterable([]))
+ self.assertSequenceEqual([], m.repeated_float)
+ m.repeated_float.extend(MessageTest.TestIterable([0.0]))
+ self.assertSequenceEqual([0.0], m.repeated_float)
+ m.repeated_float.extend(MessageTest.TestIterable([1.0, 2.0]))
+ self.assertSequenceEqual([0.0, 1.0, 2.0], m.repeated_float)
+ m.repeated_float.extend(MessageTest.TestIterable([3.0, 4.0]))
+ self.assertSequenceEqual([0.0, 1.0, 2.0, 3.0, 4.0], m.repeated_float)
+
+ def testExtendStringWithIterable(self, message_module):
+ """Test extending repeated string fields with iterable."""
+ m = message_module.TestAllTypes()
+ self.assertSequenceEqual([], m.repeated_string)
+ m.repeated_string.extend(MessageTest.TestIterable([]))
+ self.assertSequenceEqual([], m.repeated_string)
+ m.repeated_string.extend(MessageTest.TestIterable(['']))
+ self.assertSequenceEqual([''], m.repeated_string)
+ m.repeated_string.extend(MessageTest.TestIterable(['1', '2']))
+ self.assertSequenceEqual(['', '1', '2'], m.repeated_string)
+ m.repeated_string.extend(MessageTest.TestIterable(['3', '4']))
+ self.assertSequenceEqual(['', '1', '2', '3', '4'], m.repeated_string)
+
+ def testPickleRepeatedScalarContainer(self, message_module):
+ # TODO(tibell): The pure-Python implementation support pickling of
+ # scalar containers in *some* cases. For now the cpp2 version
+ # throws an exception to avoid a segfault. Investigate if we
+ # want to support pickling of these fields.
+ #
+ # For more information see: https://b2.corp.google.com/u/0/issues/18677897
+ if (api_implementation.Type() != 'cpp' or
+ api_implementation.Version() == 2):
+ return
+ m = message_module.TestAllTypes()
+ with self.assertRaises(pickle.PickleError) as _:
+ pickle.dumps(m.repeated_int32, pickle.HIGHEST_PROTOCOL)
+
+ def testSortEmptyRepeatedCompositeContainer(self, message_module):
"""Exercise a scenario that has led to segfaults in the past.
"""
- m = unittest_pb2.TestAllTypes()
+ m = message_module.TestAllTypes()
m.repeated_nested_message.sort()
- def testHasFieldOnRepeatedField(self):
+ def testHasFieldOnRepeatedField(self, message_module):
"""Using HasField on a repeated field should raise an exception.
"""
- m = unittest_pb2.TestAllTypes()
+ m = message_module.TestAllTypes()
with self.assertRaises(ValueError) as _:
m.HasField('repeated_int32')
+ def testRepeatedScalarFieldPop(self, message_module):
+ m = message_module.TestAllTypes()
+ with self.assertRaises(IndexError) as _:
+ m.repeated_int32.pop()
+ m.repeated_int32.extend(range(5))
+ self.assertEqual(4, m.repeated_int32.pop())
+ self.assertEqual(0, m.repeated_int32.pop(0))
+ self.assertEqual(2, m.repeated_int32.pop(1))
+ self.assertEqual([1, 3], m.repeated_int32)
+
+ def testRepeatedCompositeFieldPop(self, message_module):
+ m = message_module.TestAllTypes()
+ with self.assertRaises(IndexError) as _:
+ m.repeated_nested_message.pop()
+ for i in range(5):
+ n = m.repeated_nested_message.add()
+ n.bb = i
+ self.assertEqual(4, m.repeated_nested_message.pop().bb)
+ self.assertEqual(0, m.repeated_nested_message.pop(0).bb)
+ self.assertEqual(2, m.repeated_nested_message.pop(1).bb)
+ self.assertEqual([1, 3], [n.bb for n in m.repeated_nested_message])
+
+
+# Class to test proto2-only features (required, extensions, etc.)
+class Proto2Test(unittest.TestCase):
+
+ def testFieldPresence(self):
+ message = unittest_pb2.TestAllTypes()
+
+ self.assertFalse(message.HasField("optional_int32"))
+ self.assertFalse(message.HasField("optional_bool"))
+ self.assertFalse(message.HasField("optional_nested_message"))
+
+ with self.assertRaises(ValueError):
+ message.HasField("field_doesnt_exist")
+
+ with self.assertRaises(ValueError):
+ message.HasField("repeated_int32")
+ with self.assertRaises(ValueError):
+ message.HasField("repeated_nested_message")
+
+ self.assertEqual(0, message.optional_int32)
+ self.assertEqual(False, message.optional_bool)
+ self.assertEqual(0, message.optional_nested_message.bb)
+
+ # Fields are set even when setting the values to default values.
+ message.optional_int32 = 0
+ message.optional_bool = False
+ message.optional_nested_message.bb = 0
+ self.assertTrue(message.HasField("optional_int32"))
+ self.assertTrue(message.HasField("optional_bool"))
+ self.assertTrue(message.HasField("optional_nested_message"))
+
+ # Set the fields to non-default values.
+ message.optional_int32 = 5
+ message.optional_bool = True
+ message.optional_nested_message.bb = 15
+
+ self.assertTrue(message.HasField("optional_int32"))
+ self.assertTrue(message.HasField("optional_bool"))
+ self.assertTrue(message.HasField("optional_nested_message"))
+
+ # Clearing the fields unsets them and resets their value to default.
+ message.ClearField("optional_int32")
+ message.ClearField("optional_bool")
+ message.ClearField("optional_nested_message")
+
+ self.assertFalse(message.HasField("optional_int32"))
+ self.assertFalse(message.HasField("optional_bool"))
+ self.assertFalse(message.HasField("optional_nested_message"))
+ self.assertEqual(0, message.optional_int32)
+ self.assertEqual(False, message.optional_bool)
+ self.assertEqual(0, message.optional_nested_message.bb)
+
+ # TODO(tibell): The C++ implementations actually allows assignment
+ # of unknown enum values to *scalar* fields (but not repeated
+ # fields). Once checked enum fields becomes the default in the
+ # Python implementation, the C++ implementation should follow suit.
+ def testAssignInvalidEnum(self):
+ """It should not be possible to assign an invalid enum number to an
+ enum field."""
+ m = unittest_pb2.TestAllTypes()
+
+ with self.assertRaises(ValueError) as _:
+ m.optional_nested_enum = 1234567
+ self.assertRaises(ValueError, m.repeated_nested_enum.append, 1234567)
+
+ def testGoldenExtensions(self):
+ golden_data = test_util.GoldenFileData('golden_message')
+ golden_message = unittest_pb2.TestAllExtensions()
+ golden_message.ParseFromString(golden_data)
+ all_set = unittest_pb2.TestAllExtensions()
+ test_util.SetAllExtensions(all_set)
+ self.assertEqual(all_set, golden_message)
+ self.assertEqual(golden_data, golden_message.SerializeToString())
+ golden_copy = copy.deepcopy(golden_message)
+ self.assertEqual(golden_data, golden_copy.SerializeToString())
+
+ def testGoldenPackedExtensions(self):
+ golden_data = test_util.GoldenFileData('golden_packed_fields_message')
+ golden_message = unittest_pb2.TestPackedExtensions()
+ golden_message.ParseFromString(golden_data)
+ all_set = unittest_pb2.TestPackedExtensions()
+ test_util.SetAllPackedExtensions(all_set)
+ self.assertEqual(all_set, golden_message)
+ self.assertEqual(golden_data, all_set.SerializeToString())
+ golden_copy = copy.deepcopy(golden_message)
+ self.assertEqual(golden_data, golden_copy.SerializeToString())
+
+ def testPickleIncompleteProto(self):
+ golden_message = unittest_pb2.TestRequired(a=1)
+ pickled_message = pickle.dumps(golden_message)
+
+ unpickled_message = pickle.loads(pickled_message)
+ self.assertEqual(unpickled_message, golden_message)
+ self.assertEqual(unpickled_message.a, 1)
+ # This is still an incomplete proto - so serializing should fail
+ self.assertRaises(message.EncodeError, unpickled_message.SerializeToString)
+
+
+ # TODO(haberman): this isn't really a proto2-specific test except that this
+ # message has a required field in it. Should probably be factored out so
+ # that we can test the other parts with proto3.
+ def testParsingMerge(self):
+ """Check the merge behavior when a required or optional field appears
+ multiple times in the input."""
+ messages = [
+ unittest_pb2.TestAllTypes(),
+ unittest_pb2.TestAllTypes(),
+ unittest_pb2.TestAllTypes() ]
+ messages[0].optional_int32 = 1
+ messages[1].optional_int64 = 2
+ messages[2].optional_int32 = 3
+ messages[2].optional_string = 'hello'
+
+ merged_message = unittest_pb2.TestAllTypes()
+ merged_message.optional_int32 = 3
+ merged_message.optional_int64 = 2
+ merged_message.optional_string = 'hello'
+
+ generator = unittest_pb2.TestParsingMerge.RepeatedFieldsGenerator()
+ generator.field1.extend(messages)
+ generator.field2.extend(messages)
+ generator.field3.extend(messages)
+ generator.ext1.extend(messages)
+ generator.ext2.extend(messages)
+ generator.group1.add().field1.MergeFrom(messages[0])
+ generator.group1.add().field1.MergeFrom(messages[1])
+ generator.group1.add().field1.MergeFrom(messages[2])
+ generator.group2.add().field1.MergeFrom(messages[0])
+ generator.group2.add().field1.MergeFrom(messages[1])
+ generator.group2.add().field1.MergeFrom(messages[2])
+
+ data = generator.SerializeToString()
+ parsing_merge = unittest_pb2.TestParsingMerge()
+ parsing_merge.ParseFromString(data)
+
+ # Required and optional fields should be merged.
+ self.assertEqual(parsing_merge.required_all_types, merged_message)
+ self.assertEqual(parsing_merge.optional_all_types, merged_message)
+ self.assertEqual(parsing_merge.optionalgroup.optional_group_all_types,
+ merged_message)
+ self.assertEqual(parsing_merge.Extensions[
+ unittest_pb2.TestParsingMerge.optional_ext],
+ merged_message)
+
+ # Repeated fields should not be merged.
+ self.assertEqual(len(parsing_merge.repeated_all_types), 3)
+ self.assertEqual(len(parsing_merge.repeatedgroup), 3)
+ self.assertEqual(len(parsing_merge.Extensions[
+ unittest_pb2.TestParsingMerge.repeated_ext]), 3)
+
+ def testPythonicInit(self):
+ message = unittest_pb2.TestAllTypes(
+ optional_int32=100,
+ optional_fixed32=200,
+ optional_float=300.5,
+ optional_bytes=b'x',
+ optionalgroup={'a': 400},
+ optional_nested_message={'bb': 500},
+ optional_nested_enum='BAZ',
+ repeatedgroup=[{'a': 600},
+ {'a': 700}],
+ repeated_nested_enum=['FOO', unittest_pb2.TestAllTypes.BAR],
+ default_int32=800,
+ oneof_string='y')
+ self.assertIsInstance(message, unittest_pb2.TestAllTypes)
+ self.assertEqual(100, message.optional_int32)
+ self.assertEqual(200, message.optional_fixed32)
+ self.assertEqual(300.5, message.optional_float)
+ self.assertEqual(b'x', message.optional_bytes)
+ self.assertEqual(400, message.optionalgroup.a)
+ self.assertIsInstance(message.optional_nested_message, unittest_pb2.TestAllTypes.NestedMessage)
+ self.assertEqual(500, message.optional_nested_message.bb)
+ self.assertEqual(unittest_pb2.TestAllTypes.BAZ,
+ message.optional_nested_enum)
+ self.assertEqual(2, len(message.repeatedgroup))
+ self.assertEqual(600, message.repeatedgroup[0].a)
+ self.assertEqual(700, message.repeatedgroup[1].a)
+ self.assertEqual(2, len(message.repeated_nested_enum))
+ self.assertEqual(unittest_pb2.TestAllTypes.FOO,
+ message.repeated_nested_enum[0])
+ self.assertEqual(unittest_pb2.TestAllTypes.BAR,
+ message.repeated_nested_enum[1])
+ self.assertEqual(800, message.default_int32)
+ self.assertEqual('y', message.oneof_string)
+ self.assertFalse(message.HasField('optional_int64'))
+ self.assertEqual(0, len(message.repeated_float))
+ self.assertEqual(42, message.default_int64)
+
+ message = unittest_pb2.TestAllTypes(optional_nested_enum=u'BAZ')
+ self.assertEqual(unittest_pb2.TestAllTypes.BAZ,
+ message.optional_nested_enum)
+
+ with self.assertRaises(ValueError):
+ unittest_pb2.TestAllTypes(
+ optional_nested_message={'INVALID_NESTED_FIELD': 17})
+
+ with self.assertRaises(TypeError):
+ unittest_pb2.TestAllTypes(
+ optional_nested_message={'bb': 'INVALID_VALUE_TYPE'})
+
+ with self.assertRaises(ValueError):
+ unittest_pb2.TestAllTypes(optional_nested_enum='INVALID_LABEL')
+
+ with self.assertRaises(ValueError):
+ unittest_pb2.TestAllTypes(repeated_nested_enum='FOO')
+
+
+
+# Class to test proto3-only features/behavior (updated field presence & enums)
+class Proto3Test(unittest.TestCase):
+
+ # Utility method for comparing equality with a map.
+ def assertMapIterEquals(self, map_iter, dict_value):
+ # Avoid mutating caller's copy.
+ dict_value = dict(dict_value)
+
+ for k, v in map_iter:
+ self.assertEqual(v, dict_value[k])
+ del dict_value[k]
+
+ self.assertEqual({}, dict_value)
+
+ def testFieldPresence(self):
+ message = unittest_proto3_arena_pb2.TestAllTypes()
+
+ # We can't test presence of non-repeated, non-submessage fields.
+ with self.assertRaises(ValueError):
+ message.HasField('optional_int32')
+ with self.assertRaises(ValueError):
+ message.HasField('optional_float')
+ with self.assertRaises(ValueError):
+ message.HasField('optional_string')
+ with self.assertRaises(ValueError):
+ message.HasField('optional_bool')
+
+ # But we can still test presence of submessage fields.
+ self.assertFalse(message.HasField('optional_nested_message'))
+
+ # As with proto2, we can't test presence of fields that don't exist, or
+ # repeated fields.
+ with self.assertRaises(ValueError):
+ message.HasField('field_doesnt_exist')
+
+ with self.assertRaises(ValueError):
+ message.HasField('repeated_int32')
+ with self.assertRaises(ValueError):
+ message.HasField('repeated_nested_message')
+
+ # Fields should default to their type-specific default.
+ self.assertEqual(0, message.optional_int32)
+ self.assertEqual(0, message.optional_float)
+ self.assertEqual('', message.optional_string)
+ self.assertEqual(False, message.optional_bool)
+ self.assertEqual(0, message.optional_nested_message.bb)
+
+ # Setting a submessage should still return proper presence information.
+ message.optional_nested_message.bb = 0
+ self.assertTrue(message.HasField('optional_nested_message'))
+
+ # Set the fields to non-default values.
+ message.optional_int32 = 5
+ message.optional_float = 1.1
+ message.optional_string = 'abc'
+ message.optional_bool = True
+ message.optional_nested_message.bb = 15
+
+ # Clearing the fields unsets them and resets their value to default.
+ message.ClearField('optional_int32')
+ message.ClearField('optional_float')
+ message.ClearField('optional_string')
+ message.ClearField('optional_bool')
+ message.ClearField('optional_nested_message')
+
+ self.assertEqual(0, message.optional_int32)
+ self.assertEqual(0, message.optional_float)
+ self.assertEqual('', message.optional_string)
+ self.assertEqual(False, message.optional_bool)
+ self.assertEqual(0, message.optional_nested_message.bb)
+
+ def testAssignUnknownEnum(self):
+ """Assigning an unknown enum value is allowed and preserves the value."""
+ m = unittest_proto3_arena_pb2.TestAllTypes()
+
+ m.optional_nested_enum = 1234567
+ self.assertEqual(1234567, m.optional_nested_enum)
+ m.repeated_nested_enum.append(22334455)
+ self.assertEqual(22334455, m.repeated_nested_enum[0])
+ # Assignment is a different code path than append for the C++ impl.
+ m.repeated_nested_enum[0] = 7654321
+ self.assertEqual(7654321, m.repeated_nested_enum[0])
+ serialized = m.SerializeToString()
+
+ m2 = unittest_proto3_arena_pb2.TestAllTypes()
+ m2.ParseFromString(serialized)
+ self.assertEqual(1234567, m2.optional_nested_enum)
+ self.assertEqual(7654321, m2.repeated_nested_enum[0])
+
+ # Map isn't really a proto3-only feature. But there is no proto2 equivalent
+ # of google/protobuf/map_unittest.proto right now, so it's not easy to
+ # test both with the same test like we do for the other proto2/proto3 tests.
+ # (google/protobuf/map_protobuf_unittest.proto is very different in the set
+ # of messages and fields it contains).
+ def testScalarMapDefaults(self):
+ msg = map_unittest_pb2.TestMap()
+
+ # Scalars start out unset.
+ self.assertFalse(-123 in msg.map_int32_int32)
+ self.assertFalse(-2**33 in msg.map_int64_int64)
+ self.assertFalse(123 in msg.map_uint32_uint32)
+ self.assertFalse(2**33 in msg.map_uint64_uint64)
+ self.assertFalse(123 in msg.map_int32_double)
+ self.assertFalse(False in msg.map_bool_bool)
+ self.assertFalse('abc' in msg.map_string_string)
+ self.assertFalse(111 in msg.map_int32_bytes)
+ self.assertFalse(888 in msg.map_int32_enum)
+
+ # Accessing an unset key returns the default.
+ self.assertEqual(0, msg.map_int32_int32[-123])
+ self.assertEqual(0, msg.map_int64_int64[-2**33])
+ self.assertEqual(0, msg.map_uint32_uint32[123])
+ self.assertEqual(0, msg.map_uint64_uint64[2**33])
+ self.assertEqual(0.0, msg.map_int32_double[123])
+ self.assertTrue(isinstance(msg.map_int32_double[123], float))
+ self.assertEqual(False, msg.map_bool_bool[False])
+ self.assertTrue(isinstance(msg.map_bool_bool[False], bool))
+ self.assertEqual('', msg.map_string_string['abc'])
+ self.assertEqual(b'', msg.map_int32_bytes[111])
+ self.assertEqual(0, msg.map_int32_enum[888])
+
+ # It also sets the value in the map
+ self.assertTrue(-123 in msg.map_int32_int32)
+ self.assertTrue(-2**33 in msg.map_int64_int64)
+ self.assertTrue(123 in msg.map_uint32_uint32)
+ self.assertTrue(2**33 in msg.map_uint64_uint64)
+ self.assertTrue(123 in msg.map_int32_double)
+ self.assertTrue(False in msg.map_bool_bool)
+ self.assertTrue('abc' in msg.map_string_string)
+ self.assertTrue(111 in msg.map_int32_bytes)
+ self.assertTrue(888 in msg.map_int32_enum)
+
+ self.assertIsInstance(msg.map_string_string['abc'], six.text_type)
+
+ # Accessing an unset key still throws TypeError if the type of the key
+ # is incorrect.
+ with self.assertRaises(TypeError):
+ msg.map_string_string[123]
+
+ with self.assertRaises(TypeError):
+ 123 in msg.map_string_string
+
+ def testMapGet(self):
+ # Need to test that get() properly returns the default, even though the dict
+ # has defaultdict-like semantics.
+ msg = map_unittest_pb2.TestMap()
+
+ self.assertIsNone(msg.map_int32_int32.get(5))
+ self.assertEqual(10, msg.map_int32_int32.get(5, 10))
+ self.assertIsNone(msg.map_int32_int32.get(5))
+
+ msg.map_int32_int32[5] = 15
+ self.assertEqual(15, msg.map_int32_int32.get(5))
+
+ self.assertIsNone(msg.map_int32_foreign_message.get(5))
+ self.assertEqual(10, msg.map_int32_foreign_message.get(5, 10))
+
+ submsg = msg.map_int32_foreign_message[5]
+ self.assertIs(submsg, msg.map_int32_foreign_message.get(5))
+
+ def testScalarMap(self):
+ msg = map_unittest_pb2.TestMap()
+
+ self.assertEqual(0, len(msg.map_int32_int32))
+ self.assertFalse(5 in msg.map_int32_int32)
-class ValidTypeNamesTest(basetest.TestCase):
+ msg.map_int32_int32[-123] = -456
+ msg.map_int64_int64[-2**33] = -2**34
+ msg.map_uint32_uint32[123] = 456
+ msg.map_uint64_uint64[2**33] = 2**34
+ msg.map_string_string['abc'] = '123'
+ msg.map_int32_enum[888] = 2
+
+ self.assertEqual([], msg.FindInitializationErrors())
+
+ self.assertEqual(1, len(msg.map_string_string))
+
+ # Bad key.
+ with self.assertRaises(TypeError):
+ msg.map_string_string[123] = '123'
+
+ # Verify that trying to assign a bad key doesn't actually add a member to
+ # the map.
+ self.assertEqual(1, len(msg.map_string_string))
+
+ # Bad value.
+ with self.assertRaises(TypeError):
+ msg.map_string_string['123'] = 123
+
+ serialized = msg.SerializeToString()
+ msg2 = map_unittest_pb2.TestMap()
+ msg2.ParseFromString(serialized)
+
+ # Bad key.
+ with self.assertRaises(TypeError):
+ msg2.map_string_string[123] = '123'
+
+ # Bad value.
+ with self.assertRaises(TypeError):
+ msg2.map_string_string['123'] = 123
+
+ self.assertEqual(-456, msg2.map_int32_int32[-123])
+ self.assertEqual(-2**34, msg2.map_int64_int64[-2**33])
+ self.assertEqual(456, msg2.map_uint32_uint32[123])
+ self.assertEqual(2**34, msg2.map_uint64_uint64[2**33])
+ self.assertEqual('123', msg2.map_string_string['abc'])
+ self.assertEqual(2, msg2.map_int32_enum[888])
+
+ def testStringUnicodeConversionInMap(self):
+ msg = map_unittest_pb2.TestMap()
+
+ unicode_obj = u'\u1234'
+ bytes_obj = unicode_obj.encode('utf8')
+
+ msg.map_string_string[bytes_obj] = bytes_obj
+
+ (key, value) = list(msg.map_string_string.items())[0]
+
+ self.assertEqual(key, unicode_obj)
+ self.assertEqual(value, unicode_obj)
+
+ self.assertIsInstance(key, six.text_type)
+ self.assertIsInstance(value, six.text_type)
+
+ def testMessageMap(self):
+ msg = map_unittest_pb2.TestMap()
+
+ self.assertEqual(0, len(msg.map_int32_foreign_message))
+ self.assertFalse(5 in msg.map_int32_foreign_message)
+
+ msg.map_int32_foreign_message[123]
+ # get_or_create() is an alias for getitem.
+ msg.map_int32_foreign_message.get_or_create(-456)
+
+ self.assertEqual(2, len(msg.map_int32_foreign_message))
+ self.assertIn(123, msg.map_int32_foreign_message)
+ self.assertIn(-456, msg.map_int32_foreign_message)
+ self.assertEqual(2, len(msg.map_int32_foreign_message))
+
+ # Bad key.
+ with self.assertRaises(TypeError):
+ msg.map_int32_foreign_message['123']
+
+ # Can't assign directly to submessage.
+ with self.assertRaises(ValueError):
+ msg.map_int32_foreign_message[999] = msg.map_int32_foreign_message[123]
+
+ # Verify that trying to assign a bad key doesn't actually add a member to
+ # the map.
+ self.assertEqual(2, len(msg.map_int32_foreign_message))
+
+ serialized = msg.SerializeToString()
+ msg2 = map_unittest_pb2.TestMap()
+ msg2.ParseFromString(serialized)
+
+ self.assertEqual(2, len(msg2.map_int32_foreign_message))
+ self.assertIn(123, msg2.map_int32_foreign_message)
+ self.assertIn(-456, msg2.map_int32_foreign_message)
+ self.assertEqual(2, len(msg2.map_int32_foreign_message))
+
+ def testMergeFrom(self):
+ msg = map_unittest_pb2.TestMap()
+ msg.map_int32_int32[12] = 34
+ msg.map_int32_int32[56] = 78
+ msg.map_int64_int64[22] = 33
+ msg.map_int32_foreign_message[111].c = 5
+ msg.map_int32_foreign_message[222].c = 10
+
+ msg2 = map_unittest_pb2.TestMap()
+ msg2.map_int32_int32[12] = 55
+ msg2.map_int64_int64[88] = 99
+ msg2.map_int32_foreign_message[222].c = 15
+
+ msg2.MergeFrom(msg)
+
+ self.assertEqual(34, msg2.map_int32_int32[12])
+ self.assertEqual(78, msg2.map_int32_int32[56])
+ self.assertEqual(33, msg2.map_int64_int64[22])
+ self.assertEqual(99, msg2.map_int64_int64[88])
+ self.assertEqual(5, msg2.map_int32_foreign_message[111].c)
+ self.assertEqual(10, msg2.map_int32_foreign_message[222].c)
+
+ # Verify that there is only one entry per key, even though the MergeFrom
+ # may have internally created multiple entries for a single key in the
+ # list representation.
+ as_dict = {}
+ for key in msg2.map_int32_foreign_message:
+ self.assertFalse(key in as_dict)
+ as_dict[key] = msg2.map_int32_foreign_message[key].c
+
+ self.assertEqual({111: 5, 222: 10}, as_dict)
+
+ # Special case: test that delete of item really removes the item, even if
+ # there might have physically been duplicate keys due to the previous merge.
+ # This is only a special case for the C++ implementation which stores the
+ # map as an array.
+ del msg2.map_int32_int32[12]
+ self.assertFalse(12 in msg2.map_int32_int32)
+
+ del msg2.map_int32_foreign_message[222]
+ self.assertFalse(222 in msg2.map_int32_foreign_message)
+
+ def testMergeFromBadType(self):
+ msg = map_unittest_pb2.TestMap()
+ with self.assertRaisesRegexp(
+ TypeError,
+ r'Parameter to MergeFrom\(\) must be instance of same class: expected '
+ r'.*TestMap got int\.'):
+ msg.MergeFrom(1)
+
+ def testCopyFromBadType(self):
+ msg = map_unittest_pb2.TestMap()
+ with self.assertRaisesRegexp(
+ TypeError,
+ r'Parameter to [A-Za-z]*From\(\) must be instance of same class: '
+ r'expected .*TestMap got int\.'):
+ msg.CopyFrom(1)
+
+ def testIntegerMapWithLongs(self):
+ msg = map_unittest_pb2.TestMap()
+ msg.map_int32_int32[long(-123)] = long(-456)
+ msg.map_int64_int64[long(-2**33)] = long(-2**34)
+ msg.map_uint32_uint32[long(123)] = long(456)
+ msg.map_uint64_uint64[long(2**33)] = long(2**34)
+
+ serialized = msg.SerializeToString()
+ msg2 = map_unittest_pb2.TestMap()
+ msg2.ParseFromString(serialized)
+
+ self.assertEqual(-456, msg2.map_int32_int32[-123])
+ self.assertEqual(-2**34, msg2.map_int64_int64[-2**33])
+ self.assertEqual(456, msg2.map_uint32_uint32[123])
+ self.assertEqual(2**34, msg2.map_uint64_uint64[2**33])
+
+ def testMapAssignmentCausesPresence(self):
+ msg = map_unittest_pb2.TestMapSubmessage()
+ msg.test_map.map_int32_int32[123] = 456
+
+ serialized = msg.SerializeToString()
+ msg2 = map_unittest_pb2.TestMapSubmessage()
+ msg2.ParseFromString(serialized)
+
+ self.assertEqual(msg, msg2)
+
+ # Now test that various mutations of the map properly invalidate the
+ # cached size of the submessage.
+ msg.test_map.map_int32_int32[888] = 999
+ serialized = msg.SerializeToString()
+ msg2.ParseFromString(serialized)
+ self.assertEqual(msg, msg2)
+
+ msg.test_map.map_int32_int32.clear()
+ serialized = msg.SerializeToString()
+ msg2.ParseFromString(serialized)
+ self.assertEqual(msg, msg2)
+
+ def testMapAssignmentCausesPresenceForSubmessages(self):
+ msg = map_unittest_pb2.TestMapSubmessage()
+ msg.test_map.map_int32_foreign_message[123].c = 5
+
+ serialized = msg.SerializeToString()
+ msg2 = map_unittest_pb2.TestMapSubmessage()
+ msg2.ParseFromString(serialized)
+
+ self.assertEqual(msg, msg2)
+
+ # Now test that various mutations of the map properly invalidate the
+ # cached size of the submessage.
+ msg.test_map.map_int32_foreign_message[888].c = 7
+ serialized = msg.SerializeToString()
+ msg2.ParseFromString(serialized)
+ self.assertEqual(msg, msg2)
+
+ msg.test_map.map_int32_foreign_message[888].MergeFrom(
+ msg.test_map.map_int32_foreign_message[123])
+ serialized = msg.SerializeToString()
+ msg2.ParseFromString(serialized)
+ self.assertEqual(msg, msg2)
+
+ msg.test_map.map_int32_foreign_message.clear()
+ serialized = msg.SerializeToString()
+ msg2.ParseFromString(serialized)
+ self.assertEqual(msg, msg2)
+
+ def testModifyMapWhileIterating(self):
+ msg = map_unittest_pb2.TestMap()
+
+ string_string_iter = iter(msg.map_string_string)
+ int32_foreign_iter = iter(msg.map_int32_foreign_message)
+
+ msg.map_string_string['abc'] = '123'
+ msg.map_int32_foreign_message[5].c = 5
+
+ with self.assertRaises(RuntimeError):
+ for key in string_string_iter:
+ pass
+
+ with self.assertRaises(RuntimeError):
+ for key in int32_foreign_iter:
+ pass
+
+ def testSubmessageMap(self):
+ msg = map_unittest_pb2.TestMap()
+
+ submsg = msg.map_int32_foreign_message[111]
+ self.assertIs(submsg, msg.map_int32_foreign_message[111])
+ self.assertIsInstance(submsg, unittest_pb2.ForeignMessage)
+
+ submsg.c = 5
+
+ serialized = msg.SerializeToString()
+ msg2 = map_unittest_pb2.TestMap()
+ msg2.ParseFromString(serialized)
+
+ self.assertEqual(5, msg2.map_int32_foreign_message[111].c)
+
+ # Doesn't allow direct submessage assignment.
+ with self.assertRaises(ValueError):
+ msg.map_int32_foreign_message[88] = unittest_pb2.ForeignMessage()
+
+ def testMapIteration(self):
+ msg = map_unittest_pb2.TestMap()
+
+ for k, v in msg.map_int32_int32.items():
+ # Should not be reached.
+ self.assertTrue(False)
+
+ msg.map_int32_int32[2] = 4
+ msg.map_int32_int32[3] = 6
+ msg.map_int32_int32[4] = 8
+ self.assertEqual(3, len(msg.map_int32_int32))
+
+ matching_dict = {2: 4, 3: 6, 4: 8}
+ self.assertMapIterEquals(msg.map_int32_int32.items(), matching_dict)
+
+ def testMapItems(self):
+ # Map items used to have strange behaviors when use c extension. Because
+ # [] may reorder the map and invalidate any exsting iterators.
+ # TODO(jieluo): Check if [] reordering the map is a bug or intended
+ # behavior.
+ msg = map_unittest_pb2.TestMap()
+ msg.map_string_string['local_init_op'] = ''
+ msg.map_string_string['trainable_variables'] = ''
+ msg.map_string_string['variables'] = ''
+ msg.map_string_string['init_op'] = ''
+ msg.map_string_string['summaries'] = ''
+ items1 = msg.map_string_string.items()
+ items2 = msg.map_string_string.items()
+ self.assertEqual(items1, items2)
+
+ def testMapIterationClearMessage(self):
+ # Iterator needs to work even if message and map are deleted.
+ msg = map_unittest_pb2.TestMap()
+
+ msg.map_int32_int32[2] = 4
+ msg.map_int32_int32[3] = 6
+ msg.map_int32_int32[4] = 8
+
+ it = msg.map_int32_int32.items()
+ del msg
+
+ matching_dict = {2: 4, 3: 6, 4: 8}
+ self.assertMapIterEquals(it, matching_dict)
+
+ def testMapConstruction(self):
+ msg = map_unittest_pb2.TestMap(map_int32_int32={1: 2, 3: 4})
+ self.assertEqual(2, msg.map_int32_int32[1])
+ self.assertEqual(4, msg.map_int32_int32[3])
+
+ msg = map_unittest_pb2.TestMap(
+ map_int32_foreign_message={3: unittest_pb2.ForeignMessage(c=5)})
+ self.assertEqual(5, msg.map_int32_foreign_message[3].c)
+
+ def testMapValidAfterFieldCleared(self):
+ # Map needs to work even if field is cleared.
+ # For the C++ implementation this tests the correctness of
+ # ScalarMapContainer::Release()
+ msg = map_unittest_pb2.TestMap()
+ int32_map = msg.map_int32_int32
+
+ int32_map[2] = 4
+ int32_map[3] = 6
+ int32_map[4] = 8
+
+ msg.ClearField('map_int32_int32')
+ self.assertEqual(b'', msg.SerializeToString())
+ matching_dict = {2: 4, 3: 6, 4: 8}
+ self.assertMapIterEquals(int32_map.items(), matching_dict)
+
+ def testMessageMapValidAfterFieldCleared(self):
+ # Map needs to work even if field is cleared.
+ # For the C++ implementation this tests the correctness of
+ # ScalarMapContainer::Release()
+ msg = map_unittest_pb2.TestMap()
+ int32_foreign_message = msg.map_int32_foreign_message
+
+ int32_foreign_message[2].c = 5
+
+ msg.ClearField('map_int32_foreign_message')
+ self.assertEqual(b'', msg.SerializeToString())
+ self.assertTrue(2 in int32_foreign_message.keys())
+
+ def testMapIterInvalidatedByClearField(self):
+ # Map iterator is invalidated when field is cleared.
+ # But this case does need to not crash the interpreter.
+ # For the C++ implementation this tests the correctness of
+ # ScalarMapContainer::Release()
+ msg = map_unittest_pb2.TestMap()
+
+ it = iter(msg.map_int32_int32)
+
+ msg.ClearField('map_int32_int32')
+ with self.assertRaises(RuntimeError):
+ for _ in it:
+ pass
+
+ it = iter(msg.map_int32_foreign_message)
+ msg.ClearField('map_int32_foreign_message')
+ with self.assertRaises(RuntimeError):
+ for _ in it:
+ pass
+
+ def testMapDelete(self):
+ msg = map_unittest_pb2.TestMap()
+
+ self.assertEqual(0, len(msg.map_int32_int32))
+
+ msg.map_int32_int32[4] = 6
+ self.assertEqual(1, len(msg.map_int32_int32))
+
+ with self.assertRaises(KeyError):
+ del msg.map_int32_int32[88]
+
+ del msg.map_int32_int32[4]
+ self.assertEqual(0, len(msg.map_int32_int32))
+
+ def testMapsAreMapping(self):
+ msg = map_unittest_pb2.TestMap()
+ self.assertIsInstance(msg.map_int32_int32, collections.Mapping)
+ self.assertIsInstance(msg.map_int32_int32, collections.MutableMapping)
+ self.assertIsInstance(msg.map_int32_foreign_message, collections.Mapping)
+ self.assertIsInstance(msg.map_int32_foreign_message,
+ collections.MutableMapping)
+
+ def testMapFindInitializationErrorsSmokeTest(self):
+ msg = map_unittest_pb2.TestMap()
+ msg.map_string_string['abc'] = '123'
+ msg.map_int32_int32[35] = 64
+ msg.map_string_foreign_message['foo'].c = 5
+ self.assertEqual(0, len(msg.FindInitializationErrors()))
+
+
+
+class ValidTypeNamesTest(unittest.TestCase):
def assertImportFromName(self, msg, base_name):
# Parse <type 'module.class_name'> to extra 'some.name' as a string.
@@ -676,6 +1737,116 @@ class ValidTypeNamesTest(basetest.TestCase):
self.assertImportFromName(pb.repeated_int32, 'Scalar')
self.assertImportFromName(pb.repeated_nested_message, 'Composite')
+class PackedFieldTest(unittest.TestCase):
+
+ def setMessage(self, message):
+ message.repeated_int32.append(1)
+ message.repeated_int64.append(1)
+ message.repeated_uint32.append(1)
+ message.repeated_uint64.append(1)
+ message.repeated_sint32.append(1)
+ message.repeated_sint64.append(1)
+ message.repeated_fixed32.append(1)
+ message.repeated_fixed64.append(1)
+ message.repeated_sfixed32.append(1)
+ message.repeated_sfixed64.append(1)
+ message.repeated_float.append(1.0)
+ message.repeated_double.append(1.0)
+ message.repeated_bool.append(True)
+ message.repeated_nested_enum.append(1)
+
+ def testPackedFields(self):
+ message = packed_field_test_pb2.TestPackedTypes()
+ self.setMessage(message)
+ golden_data = (b'\x0A\x01\x01'
+ b'\x12\x01\x01'
+ b'\x1A\x01\x01'
+ b'\x22\x01\x01'
+ b'\x2A\x01\x02'
+ b'\x32\x01\x02'
+ b'\x3A\x04\x01\x00\x00\x00'
+ b'\x42\x08\x01\x00\x00\x00\x00\x00\x00\x00'
+ b'\x4A\x04\x01\x00\x00\x00'
+ b'\x52\x08\x01\x00\x00\x00\x00\x00\x00\x00'
+ b'\x5A\x04\x00\x00\x80\x3f'
+ b'\x62\x08\x00\x00\x00\x00\x00\x00\xf0\x3f'
+ b'\x6A\x01\x01'
+ b'\x72\x01\x01')
+ self.assertEqual(golden_data, message.SerializeToString())
+
+ def testUnpackedFields(self):
+ message = packed_field_test_pb2.TestUnpackedTypes()
+ self.setMessage(message)
+ golden_data = (b'\x08\x01'
+ b'\x10\x01'
+ b'\x18\x01'
+ b'\x20\x01'
+ b'\x28\x02'
+ b'\x30\x02'
+ b'\x3D\x01\x00\x00\x00'
+ b'\x41\x01\x00\x00\x00\x00\x00\x00\x00'
+ b'\x4D\x01\x00\x00\x00'
+ b'\x51\x01\x00\x00\x00\x00\x00\x00\x00'
+ b'\x5D\x00\x00\x80\x3f'
+ b'\x61\x00\x00\x00\x00\x00\x00\xf0\x3f'
+ b'\x68\x01'
+ b'\x70\x01')
+ self.assertEqual(golden_data, message.SerializeToString())
+
+
+@unittest.skipIf(api_implementation.Type() != 'cpp',
+ 'explicit tests of the C++ implementation')
+class OversizeProtosTest(unittest.TestCase):
+
+ def setUp(self):
+ self.file_desc = """
+ name: "f/f.msg2"
+ package: "f"
+ message_type {
+ name: "msg1"
+ field {
+ name: "payload"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ }
+ message_type {
+ name: "msg2"
+ field {
+ name: "field"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: "msg1"
+ }
+ }
+ """
+ pool = descriptor_pool.DescriptorPool()
+ desc = descriptor_pb2.FileDescriptorProto()
+ text_format.Parse(self.file_desc, desc)
+ pool.Add(desc)
+ self.proto_cls = message_factory.MessageFactory(pool).GetPrototype(
+ pool.FindMessageTypeByName('f.msg2'))
+ self.p = self.proto_cls()
+ self.p.field.payload = 'c' * (1024 * 1024 * 64 + 1)
+ self.p_serialized = self.p.SerializeToString()
+
+ def testAssertOversizeProto(self):
+ from google.protobuf.pyext._message import SetAllowOversizeProtos
+ SetAllowOversizeProtos(False)
+ q = self.proto_cls()
+ try:
+ q.ParseFromString(self.p_serialized)
+ except message.DecodeError as e:
+ self.assertEqual(str(e), 'Error parsing message')
+
+ def testSucceedOversizeProto(self):
+ from google.protobuf.pyext._message import SetAllowOversizeProtos
+ SetAllowOversizeProtos(True)
+ q = self.proto_cls()
+ q.ParseFromString(self.p_serialized)
+ self.assertEqual(self.p.field.payload, q.field.payload)
if __name__ == '__main__':
- basetest.main()
+ unittest.main()
diff --git a/python/google/protobuf/internal/missing_enum_values.proto b/python/google/protobuf/internal/missing_enum_values.proto
index e90f0cd3e..1850be5bb 100644
--- a/python/google/protobuf/internal/missing_enum_values.proto
+++ b/python/google/protobuf/internal/missing_enum_values.proto
@@ -28,6 +28,8 @@
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+syntax = "proto2";
+
package google.protobuf.python.internal;
message TestEnumValues {
@@ -48,3 +50,7 @@ message TestMissingEnumValues {
repeated NestedEnum repeated_nested_enum = 2;
repeated NestedEnum packed_nested_enum = 3 [packed = true];
}
+
+message JustString {
+ required string dummy = 1;
+}
diff --git a/python/google/protobuf/internal/more_extensions.proto b/python/google/protobuf/internal/more_extensions.proto
index c04e597f9..78f146736 100644
--- a/python/google/protobuf/internal/more_extensions.proto
+++ b/python/google/protobuf/internal/more_extensions.proto
@@ -30,6 +30,7 @@
// Author: robinson@google.com (Will Robinson)
+syntax = "proto2";
package google.protobuf.internal;
diff --git a/python/google/protobuf/internal/more_extensions_dynamic.proto b/python/google/protobuf/internal/more_extensions_dynamic.proto
index 88bd9c1b2..11f85ef60 100644
--- a/python/google/protobuf/internal/more_extensions_dynamic.proto
+++ b/python/google/protobuf/internal/more_extensions_dynamic.proto
@@ -34,6 +34,7 @@
// generated C++ type is available for the extendee, but the extension is
// defined in a file whose C++ type is not in the binary.
+syntax = "proto2";
import "google/protobuf/internal/more_extensions.proto";
diff --git a/python/google/protobuf/internal/more_messages.proto b/python/google/protobuf/internal/more_messages.proto
index 61db66c56..2c6ab9efd 100644
--- a/python/google/protobuf/internal/more_messages.proto
+++ b/python/google/protobuf/internal/more_messages.proto
@@ -30,6 +30,7 @@
// Author: robinson@google.com (Will Robinson)
+syntax = "proto2";
package google.protobuf.internal;
diff --git a/python/google/protobuf/internal/packed_field_test.proto b/python/google/protobuf/internal/packed_field_test.proto
new file mode 100644
index 000000000..0dfdc10a8
--- /dev/null
+++ b/python/google/protobuf/internal/packed_field_test.proto
@@ -0,0 +1,73 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// https://developers.google.com/protocol-buffers/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+syntax = "proto3";
+
+package google.protobuf.python.internal;
+
+message TestPackedTypes {
+ enum NestedEnum {
+ FOO = 0;
+ BAR = 1;
+ BAZ = 2;
+ }
+
+ repeated int32 repeated_int32 = 1;
+ repeated int64 repeated_int64 = 2;
+ repeated uint32 repeated_uint32 = 3;
+ repeated uint64 repeated_uint64 = 4;
+ repeated sint32 repeated_sint32 = 5;
+ repeated sint64 repeated_sint64 = 6;
+ repeated fixed32 repeated_fixed32 = 7;
+ repeated fixed64 repeated_fixed64 = 8;
+ repeated sfixed32 repeated_sfixed32 = 9;
+ repeated sfixed64 repeated_sfixed64 = 10;
+ repeated float repeated_float = 11;
+ repeated double repeated_double = 12;
+ repeated bool repeated_bool = 13;
+ repeated NestedEnum repeated_nested_enum = 14;
+}
+
+message TestUnpackedTypes {
+ repeated int32 repeated_int32 = 1 [packed = false];
+ repeated int64 repeated_int64 = 2 [packed = false];
+ repeated uint32 repeated_uint32 = 3 [packed = false];
+ repeated uint64 repeated_uint64 = 4 [packed = false];
+ repeated sint32 repeated_sint32 = 5 [packed = false];
+ repeated sint64 repeated_sint64 = 6 [packed = false];
+ repeated fixed32 repeated_fixed32 = 7 [packed = false];
+ repeated fixed64 repeated_fixed64 = 8 [packed = false];
+ repeated sfixed32 repeated_sfixed32 = 9 [packed = false];
+ repeated sfixed64 repeated_sfixed64 = 10 [packed = false];
+ repeated float repeated_float = 11 [packed = false];
+ repeated double repeated_double = 12 [packed = false];
+ repeated bool repeated_bool = 13 [packed = false];
+ repeated TestPackedTypes.NestedEnum repeated_nested_enum = 14 [packed = false];
+}
diff --git a/python/google/protobuf/internal/proto_builder_test.py b/python/google/protobuf/internal/proto_builder_test.py
new file mode 100644
index 000000000..36dfbfded
--- /dev/null
+++ b/python/google/protobuf/internal/proto_builder_test.py
@@ -0,0 +1,96 @@
+#! /usr/bin/env python
+#
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# https://developers.google.com/protocol-buffers/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Tests for google.protobuf.proto_builder."""
+
+try:
+ from collections import OrderedDict
+except ImportError:
+ from ordereddict import OrderedDict #PY26
+try:
+ import unittest2 as unittest
+except ImportError:
+ import unittest
+
+from google.protobuf import descriptor_pb2
+from google.protobuf import descriptor_pool
+from google.protobuf import proto_builder
+from google.protobuf import text_format
+
+
+class ProtoBuilderTest(unittest.TestCase):
+
+ def setUp(self):
+ self.ordered_fields = OrderedDict([
+ ('foo', descriptor_pb2.FieldDescriptorProto.TYPE_INT64),
+ ('bar', descriptor_pb2.FieldDescriptorProto.TYPE_STRING),
+ ])
+ self._fields = dict(self.ordered_fields)
+
+ def testMakeSimpleProtoClass(self):
+ """Test that we can create a proto class."""
+ proto_cls = proto_builder.MakeSimpleProtoClass(
+ self._fields,
+ full_name='net.proto2.python.public.proto_builder_test.Test')
+ proto = proto_cls()
+ proto.foo = 12345
+ proto.bar = 'asdf'
+ self.assertMultiLineEqual(
+ 'bar: "asdf"\nfoo: 12345\n', text_format.MessageToString(proto))
+
+ def testOrderedFields(self):
+ """Test that the field order is maintained when given an OrderedDict."""
+ proto_cls = proto_builder.MakeSimpleProtoClass(
+ self.ordered_fields,
+ full_name='net.proto2.python.public.proto_builder_test.OrderedTest')
+ proto = proto_cls()
+ proto.foo = 12345
+ proto.bar = 'asdf'
+ self.assertMultiLineEqual(
+ 'foo: 12345\nbar: "asdf"\n', text_format.MessageToString(proto))
+
+ def testMakeSameProtoClassTwice(self):
+ """Test that the DescriptorPool is used."""
+ pool = descriptor_pool.DescriptorPool()
+ proto_cls1 = proto_builder.MakeSimpleProtoClass(
+ self._fields,
+ full_name='net.proto2.python.public.proto_builder_test.Test',
+ pool=pool)
+ proto_cls2 = proto_builder.MakeSimpleProtoClass(
+ self._fields,
+ full_name='net.proto2.python.public.proto_builder_test.Test',
+ pool=pool)
+ self.assertIs(proto_cls1.DESCRIPTOR, proto_cls2.DESCRIPTOR)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py
index a5c26f452..f8f73dd20 100755
--- a/python/google/protobuf/internal/python_message.py
+++ b/python/google/protobuf/internal/python_message.py
@@ -28,10 +28,6 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-# Keep it Python2.5 compatible for GAE.
-#
-# Copyright 2007 Google Inc. All Rights Reserved.
-#
# This code is meant to work on Python 2.4 and above only.
#
# TODO(robinson): Helpers for verbose, common checks like seeing if a
@@ -54,19 +50,21 @@ this file*.
__author__ = 'robinson@google.com (Will Robinson)'
+from io import BytesIO
import sys
-if sys.version_info[0] < 3:
- try:
- from cStringIO import StringIO as BytesIO
- except ImportError:
- from StringIO import StringIO as BytesIO
- import copy_reg as copyreg
-else:
- from io import BytesIO
- import copyreg
import struct
import weakref
+import six
+try:
+ import six.moves.copyreg as copyreg
+except ImportError:
+ # On some platforms, for example gMac, we run native Python because there is
+ # nothing like hermetic Python. This means lesser control on the system and
+ # the six.moves package may be missing (is missing on 20150321 on gMac). Be
+ # extra conservative and try to load the old replacement if it fails.
+ import copy_reg as copyreg
+
# We use "as" to avoid name collisions with variables.
from google.protobuf.internal import containers
from google.protobuf.internal import decoder
@@ -74,41 +72,121 @@ from google.protobuf.internal import encoder
from google.protobuf.internal import enum_type_wrapper
from google.protobuf.internal import message_listener as message_listener_mod
from google.protobuf.internal import type_checkers
+from google.protobuf.internal import well_known_types
from google.protobuf.internal import wire_format
from google.protobuf import descriptor as descriptor_mod
from google.protobuf import message as message_mod
+from google.protobuf import symbol_database
from google.protobuf import text_format
_FieldDescriptor = descriptor_mod.FieldDescriptor
+_AnyFullTypeName = 'google.protobuf.Any'
-def NewMessage(bases, descriptor, dictionary):
- _AddClassAttributesForNestedExtensions(descriptor, dictionary)
- _AddSlots(descriptor, dictionary)
- return bases
+class GeneratedProtocolMessageType(type):
+ """Metaclass for protocol message classes created at runtime from Descriptors.
-def InitMessage(descriptor, cls):
- cls._decoders_by_tag = {}
- cls._extensions_by_name = {}
- cls._extensions_by_number = {}
- if (descriptor.has_options and
- descriptor.GetOptions().message_set_wire_format):
- cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
- decoder.MessageSetItemDecoder(cls._extensions_by_number), None)
+ We add implementations for all methods described in the Message class. We
+ also create properties to allow getting/setting all fields in the protocol
+ message. Finally, we create slots to prevent users from accidentally
+ "setting" nonexistent fields in the protocol message, which then wouldn't get
+ serialized / deserialized properly.
- # Attach stuff to each FieldDescriptor for quick lookup later on.
- for field in descriptor.fields:
- _AttachFieldHelpers(cls, field)
+ The protocol compiler currently uses this metaclass to create protocol
+ message classes at runtime. Clients can also manually create their own
+ classes at runtime, as in this example:
- _AddEnumValues(descriptor, cls)
- _AddInitMethod(descriptor, cls)
- _AddPropertiesForFields(descriptor, cls)
- _AddPropertiesForExtensions(descriptor, cls)
- _AddStaticMethods(cls)
- _AddMessageMethods(descriptor, cls)
- _AddPrivateHelperMethods(descriptor, cls)
- copyreg.pickle(cls, lambda obj: (cls, (), obj.__getstate__()))
+ mydescriptor = Descriptor(.....)
+ class MyProtoClass(Message):
+ __metaclass__ = GeneratedProtocolMessageType
+ DESCRIPTOR = mydescriptor
+ myproto_instance = MyProtoClass()
+ myproto.foo_field = 23
+ ...
+
+ The above example will not work for nested types. If you wish to include them,
+ use reflection.MakeClass() instead of manually instantiating the class in
+ order to create the appropriate class structure.
+ """
+
+ # Must be consistent with the protocol-compiler code in
+ # proto2/compiler/internal/generator.*.
+ _DESCRIPTOR_KEY = 'DESCRIPTOR'
+
+ def __new__(cls, name, bases, dictionary):
+ """Custom allocation for runtime-generated class types.
+
+ We override __new__ because this is apparently the only place
+ where we can meaningfully set __slots__ on the class we're creating(?).
+ (The interplay between metaclasses and slots is not very well-documented).
+
+ Args:
+ name: Name of the class (ignored, but required by the
+ metaclass protocol).
+ bases: Base classes of the class we're constructing.
+ (Should be message.Message). We ignore this field, but
+ it's required by the metaclass protocol
+ dictionary: The class dictionary of the class we're
+ constructing. dictionary[_DESCRIPTOR_KEY] must contain
+ a Descriptor object describing this protocol message
+ type.
+
+ Returns:
+ Newly-allocated class.
+ """
+ descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]
+ if descriptor.full_name in well_known_types.WKTBASES:
+ bases += (well_known_types.WKTBASES[descriptor.full_name],)
+ _AddClassAttributesForNestedExtensions(descriptor, dictionary)
+ _AddSlots(descriptor, dictionary)
+
+ superclass = super(GeneratedProtocolMessageType, cls)
+ new_class = superclass.__new__(cls, name, bases, dictionary)
+ return new_class
+
+ def __init__(cls, name, bases, dictionary):
+ """Here we perform the majority of our work on the class.
+ We add enum getters, an __init__ method, implementations
+ of all Message methods, and properties for all fields
+ in the protocol type.
+
+ Args:
+ name: Name of the class (ignored, but required by the
+ metaclass protocol).
+ bases: Base classes of the class we're constructing.
+ (Should be message.Message). We ignore this field, but
+ it's required by the metaclass protocol
+ dictionary: The class dictionary of the class we're
+ constructing. dictionary[_DESCRIPTOR_KEY] must contain
+ a Descriptor object describing this protocol message
+ type.
+ """
+ descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]
+ cls._decoders_by_tag = {}
+ cls._extensions_by_name = {}
+ cls._extensions_by_number = {}
+ if (descriptor.has_options and
+ descriptor.GetOptions().message_set_wire_format):
+ cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
+ decoder.MessageSetItemDecoder(cls._extensions_by_number), None)
+
+ # Attach stuff to each FieldDescriptor for quick lookup later on.
+ for field in descriptor.fields:
+ _AttachFieldHelpers(cls, field)
+
+ descriptor._concrete_class = cls # pylint: disable=protected-access
+ _AddEnumValues(descriptor, cls)
+ _AddInitMethod(descriptor, cls)
+ _AddPropertiesForFields(descriptor, cls)
+ _AddPropertiesForExtensions(descriptor, cls)
+ _AddStaticMethods(cls)
+ _AddMessageMethods(descriptor, cls)
+ _AddPrivateHelperMethods(descriptor, cls)
+ copyreg.pickle(cls, lambda obj: (cls, (), obj.__getstate__()))
+
+ superclass = super(GeneratedProtocolMessageType, cls)
+ superclass.__init__(name, bases, dictionary)
# Stateless helpers for GeneratedProtocolMessageType below.
@@ -194,16 +272,40 @@ def _IsMessageSetExtension(field):
field.containing_type.has_options and
field.containing_type.GetOptions().message_set_wire_format and
field.type == _FieldDescriptor.TYPE_MESSAGE and
- field.message_type == field.extension_scope and
field.label == _FieldDescriptor.LABEL_OPTIONAL)
+def _IsMapField(field):
+ return (field.type == _FieldDescriptor.TYPE_MESSAGE and
+ field.message_type.has_options and
+ field.message_type.GetOptions().map_entry)
+
+
+def _IsMessageMapField(field):
+ value_type = field.message_type.fields_by_name["value"]
+ return value_type.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE
+
+
def _AttachFieldHelpers(cls, field_descriptor):
is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED)
- is_packed = (field_descriptor.has_options and
- field_descriptor.GetOptions().packed)
-
- if _IsMessageSetExtension(field_descriptor):
+ is_packable = (is_repeated and
+ wire_format.IsTypePackable(field_descriptor.type))
+ if not is_packable:
+ is_packed = False
+ elif field_descriptor.containing_type.syntax == "proto2":
+ is_packed = (field_descriptor.has_options and
+ field_descriptor.GetOptions().packed)
+ else:
+ has_packed_false = (field_descriptor.has_options and
+ field_descriptor.GetOptions().HasField("packed") and
+ field_descriptor.GetOptions().packed == False)
+ is_packed = not has_packed_false
+ is_map_entry = _IsMapField(field_descriptor)
+
+ if is_map_entry:
+ field_encoder = encoder.MapEncoder(field_descriptor)
+ sizer = encoder.MapSizer(field_descriptor)
+ elif _IsMessageSetExtension(field_descriptor):
field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number)
sizer = encoder.MessageSetItemSizer(field_descriptor.number)
else:
@@ -219,12 +321,27 @@ def _AttachFieldHelpers(cls, field_descriptor):
def AddDecoder(wiretype, is_packed):
tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype)
- cls._decoders_by_tag[tag_bytes] = (
- type_checkers.TYPE_TO_DECODER[field_descriptor.type](
- field_descriptor.number, is_repeated, is_packed,
- field_descriptor, field_descriptor._default_constructor),
- field_descriptor if field_descriptor.containing_oneof is not None
- else None)
+ decode_type = field_descriptor.type
+ if (decode_type == _FieldDescriptor.TYPE_ENUM and
+ type_checkers.SupportsOpenEnums(field_descriptor)):
+ decode_type = _FieldDescriptor.TYPE_INT32
+
+ oneof_descriptor = None
+ if field_descriptor.containing_oneof is not None:
+ oneof_descriptor = field_descriptor
+
+ if is_map_entry:
+ is_message_map = _IsMessageMapField(field_descriptor)
+
+ field_decoder = decoder.MapDecoder(
+ field_descriptor, _GetInitializeDefaultForMap(field_descriptor),
+ is_message_map)
+ else:
+ field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
+ field_descriptor.number, is_repeated, is_packed,
+ field_descriptor, field_descriptor._default_constructor)
+
+ cls._decoders_by_tag[tag_bytes] = (field_decoder, oneof_descriptor)
AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type],
False)
@@ -237,7 +354,7 @@ def _AttachFieldHelpers(cls, field_descriptor):
def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
extension_dict = descriptor.extensions_by_name
- for extension_name, extension_field in extension_dict.iteritems():
+ for extension_name, extension_field in extension_dict.items():
assert extension_name not in dictionary
dictionary[extension_name] = extension_field
@@ -257,6 +374,26 @@ def _AddEnumValues(descriptor, cls):
setattr(cls, enum_value.name, enum_value.number)
+def _GetInitializeDefaultForMap(field):
+ if field.label != _FieldDescriptor.LABEL_REPEATED:
+ raise ValueError('map_entry set on non-repeated field %s' % (
+ field.name))
+ fields_by_name = field.message_type.fields_by_name
+ key_checker = type_checkers.GetTypeChecker(fields_by_name['key'])
+
+ value_field = fields_by_name['value']
+ if _IsMessageMapField(field):
+ def MakeMessageMapDefault(message):
+ return containers.MessageMap(
+ message._listener_for_children, value_field.message_type, key_checker)
+ return MakeMessageMapDefault
+ else:
+ value_checker = type_checkers.GetTypeChecker(value_field)
+ def MakePrimitiveMapDefault(message):
+ return containers.ScalarMap(
+ message._listener_for_children, key_checker, value_checker)
+ return MakePrimitiveMapDefault
+
def _DefaultValueConstructorForField(field):
"""Returns a function which returns a default value for a field.
@@ -271,6 +408,9 @@ def _DefaultValueConstructorForField(field):
value may refer back to |message| via a weak reference.
"""
+ if _IsMapField(field):
+ return _GetInitializeDefaultForMap(field)
+
if field.label == _FieldDescriptor.LABEL_REPEATED:
if field.has_default_value and field.default_value != []:
raise ValueError('Repeated field default value not empty list: %s' % (
@@ -295,7 +435,10 @@ def _DefaultValueConstructorForField(field):
message_type = field.message_type
def MakeSubMessageDefault(message):
result = message_type._concrete_class()
- result._SetListener(message._listener_for_children)
+ result._SetListener(
+ _OneofListener(message, field)
+ if field.containing_oneof is not None
+ else message._listener_for_children)
return result
return MakeSubMessageDefault
@@ -306,9 +449,35 @@ def _DefaultValueConstructorForField(field):
return MakeScalarDefault
+def _ReraiseTypeErrorWithFieldName(message_name, field_name):
+ """Re-raise the currently-handled TypeError with the field name added."""
+ exc = sys.exc_info()[1]
+ if len(exc.args) == 1 and type(exc) is TypeError:
+ # simple TypeError; add field name to exception message
+ exc = TypeError('%s for field %s.%s' % (str(exc), message_name, field_name))
+
+ # re-raise possibly-amended exception with original traceback:
+ six.reraise(type(exc), exc, sys.exc_info()[2])
+
+
def _AddInitMethod(message_descriptor, cls):
"""Adds an __init__ method to cls."""
- fields = message_descriptor.fields
+
+ def _GetIntegerEnumValue(enum_type, value):
+ """Convert a string or integer enum value to an integer.
+
+ If the value is a string, it is converted to the enum value in
+ enum_type with the same name. If the value is not a string, it's
+ returned as-is. (No conversion or bounds-checking is done.)
+ """
+ if isinstance(value, six.string_types):
+ try:
+ return enum_type.values_by_name[value].number
+ except KeyError:
+ raise ValueError('Enum type %s: unknown label "%s"' % (
+ enum_type.full_name, value))
+ return value
+
def init(self, **kwargs):
self._cached_byte_size = 0
self._cached_byte_size_dirty = len(kwargs) > 0
@@ -323,25 +492,52 @@ def _AddInitMethod(message_descriptor, cls):
self._is_present_in_parent = False
self._listener = message_listener_mod.NullMessageListener()
self._listener_for_children = _Listener(self)
- for field_name, field_value in kwargs.iteritems():
+ for field_name, field_value in kwargs.items():
field = _GetFieldByName(message_descriptor, field_name)
if field is None:
raise TypeError("%s() got an unexpected keyword argument '%s'" %
(message_descriptor.name, field_name))
+ if field_value is None:
+ # field=None is the same as no field at all.
+ continue
if field.label == _FieldDescriptor.LABEL_REPEATED:
copy = field._default_constructor(self)
if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: # Composite
- for val in field_value:
- copy.add().MergeFrom(val)
+ if _IsMapField(field):
+ if _IsMessageMapField(field):
+ for key in field_value:
+ copy[key].MergeFrom(field_value[key])
+ else:
+ copy.update(field_value)
+ else:
+ for val in field_value:
+ if isinstance(val, dict):
+ copy.add(**val)
+ else:
+ copy.add().MergeFrom(val)
else: # Scalar
+ if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
+ field_value = [_GetIntegerEnumValue(field.enum_type, val)
+ for val in field_value]
copy.extend(field_value)
self._fields[field] = copy
elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
copy = field._default_constructor(self)
- copy.MergeFrom(field_value)
+ new_val = field_value
+ if isinstance(field_value, dict):
+ new_val = field.message_type._concrete_class(**field_value)
+ try:
+ copy.MergeFrom(new_val)
+ except TypeError:
+ _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)
self._fields[field] = copy
else:
- setattr(self, field_name, field_value)
+ if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
+ field_value = _GetIntegerEnumValue(field.enum_type, field_value)
+ try:
+ setattr(self, field_name, field_value)
+ except TypeError:
+ _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)
init.__module__ = None
init.__doc__ = None
@@ -360,7 +556,8 @@ def _GetFieldByName(message_descriptor, field_name):
try:
return message_descriptor.fields_by_name[field_name]
except KeyError:
- raise ValueError('Protocol message has no "%s" field.' % field_name)
+ raise ValueError('Protocol message %s has no "%s" field.' %
+ (message_descriptor.name, field_name))
def _AddPropertiesForFields(descriptor, cls):
@@ -459,6 +656,7 @@ def _AddPropertiesForNonRepeatedScalarField(field, cls):
type_checker = type_checkers.GetTypeChecker(field)
default_value = field.default_value
valid_values = set()
+ is_proto3 = field.containing_type.syntax == "proto3"
def getter(self):
# TODO(protobuf-team): This may be broken since there may not be
@@ -466,15 +664,24 @@ def _AddPropertiesForNonRepeatedScalarField(field, cls):
return self._fields.get(field, default_value)
getter.__module__ = None
getter.__doc__ = 'Getter for %s.' % proto_field_name
+
+ clear_when_set_to_default = is_proto3 and not field.containing_oneof
+
def field_setter(self, new_value):
# pylint: disable=protected-access
- self._fields[field] = type_checker.CheckValue(new_value)
+ # Testing the value for truthiness captures all of the proto3 defaults
+ # (0, 0.0, enum 0, and False).
+ new_value = type_checker.CheckValue(new_value)
+ if clear_when_set_to_default and not new_value:
+ self._fields.pop(field, None)
+ else:
+ self._fields[field] = new_value
# Check _cached_byte_size_dirty inline to improve performance, since scalar
# setters are called frequently.
if not self._cached_byte_size_dirty:
self._Modified()
- if field.containing_oneof is not None:
+ if field.containing_oneof:
def setter(self, new_value):
field_setter(self, new_value)
self._UpdateOneofState(field)
@@ -505,21 +712,11 @@ def _AddPropertiesForNonRepeatedCompositeField(field, cls):
proto_field_name = field.name
property_name = _PropertyName(proto_field_name)
- # TODO(komarek): Can anyone explain to me why we cache the message_type this
- # way, instead of referring to field.message_type inside of getter(self)?
- # What if someone sets message_type later on (which makes for simpler
- # dyanmic proto descriptor and class creation code).
- message_type = field.message_type
-
def getter(self):
field_value = self._fields.get(field)
if field_value is None:
# Construct a new object to represent this field.
- field_value = message_type._concrete_class() # use field.message_type?
- field_value._SetListener(
- _OneofListener(self, field)
- if field.containing_oneof is not None
- else self._listener_for_children)
+ field_value = field._default_constructor(self)
# Atomically check if another thread has preempted us and, if not, swap
# in the new object we just created. If someone has preempted us, we
@@ -546,7 +743,7 @@ def _AddPropertiesForNonRepeatedCompositeField(field, cls):
def _AddPropertiesForExtensions(descriptor, cls):
"""Adds properties for all fields in this protocol message type."""
extension_dict = descriptor.extensions_by_name
- for extension_name, extension_field in extension_dict.iteritems():
+ for extension_name, extension_field in extension_dict.items():
constant_name = extension_name.upper() + "_FIELD_NUMBER"
setattr(cls, constant_name, extension_field.number)
@@ -601,30 +798,41 @@ def _AddListFieldsMethod(message_descriptor, cls):
"""Helper for _AddMessageMethods()."""
def ListFields(self):
- all_fields = [item for item in self._fields.iteritems() if _IsPresent(item)]
+ all_fields = [item for item in self._fields.items() if _IsPresent(item)]
all_fields.sort(key = lambda item: item[0].number)
return all_fields
cls.ListFields = ListFields
+_Proto3HasError = 'Protocol message has no non-repeated submessage field "%s"'
+_Proto2HasError = 'Protocol message has no non-repeated field "%s"'
def _AddHasFieldMethod(message_descriptor, cls):
"""Helper for _AddMessageMethods()."""
- singular_fields = {}
+ is_proto3 = (message_descriptor.syntax == "proto3")
+ error_msg = _Proto3HasError if is_proto3 else _Proto2HasError
+
+ hassable_fields = {}
for field in message_descriptor.fields:
- if field.label != _FieldDescriptor.LABEL_REPEATED:
- singular_fields[field.name] = field
- # Fields inside oneofs are never repeated (enforced by the compiler).
- for field in message_descriptor.oneofs:
- singular_fields[field.name] = field
+ if field.label == _FieldDescriptor.LABEL_REPEATED:
+ continue
+ # For proto3, only submessages and fields inside a oneof have presence.
+ if (is_proto3 and field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE and
+ not field.containing_oneof):
+ continue
+ hassable_fields[field.name] = field
+
+ if not is_proto3:
+ # Fields inside oneofs are never repeated (enforced by the compiler).
+ for oneof in message_descriptor.oneofs:
+ hassable_fields[oneof.name] = oneof
def HasField(self, field_name):
try:
- field = singular_fields[field_name]
+ field = hassable_fields[field_name]
except KeyError:
- raise ValueError(
- 'Protocol message has no singular "%s" field.' % field_name)
+ raise ValueError(error_msg % field_name)
if isinstance(field, descriptor_mod.OneofDescriptor):
try:
@@ -654,9 +862,15 @@ def _AddClearFieldMethod(message_descriptor, cls):
else:
return
except KeyError:
- raise ValueError('Protocol message has no "%s" field.' % field_name)
+ raise ValueError('Protocol message %s() has no "%s" field.' %
+ (message_descriptor.name, field_name))
if field in self._fields:
+ # To match the C++ implementation, we need to invalidate iterators
+ # for map fields when ClearField() happens.
+ if hasattr(self._fields[field], 'InvalidateIterators'):
+ self._fields[field].InvalidateIterators()
+
# Note: If the field is a sub-message, its listener will still point
# at us. That's fine, because the worst than can happen is that it
# will call _Modified() and invalidate our byte size. Big deal.
@@ -685,16 +899,6 @@ def _AddClearExtensionMethod(cls):
cls.ClearExtension = ClearExtension
-def _AddClearMethod(message_descriptor, cls):
- """Helper for _AddMessageMethods()."""
- def Clear(self):
- # Clear fields.
- self._fields = {}
- self._unknown_fields = ()
- self._Modified()
- cls.Clear = Clear
-
-
def _AddHasExtensionMethod(cls):
"""Helper for _AddMessageMethods()."""
def HasExtension(self, extension_handle):
@@ -709,6 +913,38 @@ def _AddHasExtensionMethod(cls):
return extension_handle in self._fields
cls.HasExtension = HasExtension
+def _InternalUnpackAny(msg):
+ """Unpacks Any message and returns the unpacked message.
+
+ This internal method is differnt from public Any Unpack method which takes
+ the target message as argument. _InternalUnpackAny method does not have
+ target message type and need to find the message type in descriptor pool.
+
+ Args:
+ msg: An Any message to be unpacked.
+
+ Returns:
+ The unpacked message.
+ """
+ type_url = msg.type_url
+ db = symbol_database.Default()
+
+ if not type_url:
+ return None
+
+ # TODO(haberman): For now we just strip the hostname. Better logic will be
+ # required.
+ type_name = type_url.split("/")[-1]
+ descriptor = db.pool.FindMessageTypeByName(type_name)
+
+ if descriptor is None:
+ return None
+
+ message_class = db.GetPrototype(descriptor)
+ message = message_class()
+
+ message.ParseFromString(msg.value)
+ return message
def _AddEqualsMethod(message_descriptor, cls):
"""Helper for _AddMessageMethods()."""
@@ -720,6 +956,12 @@ def _AddEqualsMethod(message_descriptor, cls):
if self is other:
return True
+ if self.DESCRIPTOR.full_name == _AnyFullTypeName:
+ any_a = _InternalUnpackAny(self)
+ any_b = _InternalUnpackAny(other)
+ if any_a and any_b:
+ return any_a == any_b
+
if not self.ListFields() == other.ListFields():
return False
@@ -741,6 +983,13 @@ def _AddStrMethod(message_descriptor, cls):
cls.__str__ = __str__
+def _AddReprMethod(message_descriptor, cls):
+ """Helper for _AddMessageMethods()."""
+ def __repr__(self):
+ return text_format.MessageToString(self)
+ cls.__repr__ = __repr__
+
+
def _AddUnicodeMethod(unused_message_descriptor, cls):
"""Helper for _AddMessageMethods()."""
@@ -749,16 +998,6 @@ def _AddUnicodeMethod(unused_message_descriptor, cls):
cls.__unicode__ = __unicode__
-def _AddSetListenerMethod(cls):
- """Helper for _AddMessageMethods()."""
- def SetListener(self, listener):
- if listener is None:
- self._listener = message_listener_mod.NullMessageListener()
- else:
- self._listener = listener
- cls._SetListener = SetListener
-
-
def _BytesForNonRepeatedElement(value, field_number, field_type):
"""Returns the number of bytes needed to serialize a non-repeated element.
The returned byte count includes space for tag information and any
@@ -845,7 +1084,7 @@ def _AddMergeFromStringMethod(message_descriptor, cls):
except (IndexError, TypeError):
# Now ord(buf[p:p+1]) == ord('') gets TypeError.
raise message_mod.DecodeError('Truncated message.')
- except struct.error, e:
+ except struct.error as e:
raise message_mod.DecodeError(e)
return length # Return this for legacy reasons.
cls.MergeFromString = MergeFromString
@@ -853,6 +1092,7 @@ def _AddMergeFromStringMethod(message_descriptor, cls):
local_ReadTag = decoder.ReadTag
local_SkipField = decoder.SkipField
decoders_by_tag = cls._decoders_by_tag
+ is_proto3 = message_descriptor.syntax == "proto3"
def InternalParse(self, buffer, pos, end):
self._Modified()
@@ -866,9 +1106,11 @@ def _AddMergeFromStringMethod(message_descriptor, cls):
new_pos = local_SkipField(buffer, new_pos, end, tag_bytes)
if new_pos == -1:
return pos
- if not unknown_field_list:
- unknown_field_list = self._unknown_fields = []
- unknown_field_list.append((tag_bytes, buffer[value_start_pos:new_pos]))
+ if not is_proto3:
+ if not unknown_field_list:
+ unknown_field_list = self._unknown_fields = []
+ unknown_field_list.append(
+ (tag_bytes, buffer[value_start_pos:new_pos]))
pos = new_pos
else:
pos = field_decoder(buffer, new_pos, end, self, field_dict)
@@ -909,6 +1151,9 @@ def _AddIsInitializedMethod(message_descriptor, cls):
for field, value in list(self._fields.items()): # dict can change size!
if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
if field.label == _FieldDescriptor.LABEL_REPEATED:
+ if (field.message_type.has_options and
+ field.message_type.GetOptions().map_entry):
+ continue
for element in value:
if not element.IsInitialized():
if errors is not None:
@@ -944,16 +1189,26 @@ def _AddIsInitializedMethod(message_descriptor, cls):
else:
name = field.name
- if field.label == _FieldDescriptor.LABEL_REPEATED:
- for i in xrange(len(value)):
+ if _IsMapField(field):
+ if _IsMessageMapField(field):
+ for key in value:
+ element = value[key]
+ prefix = "%s[%s]." % (name, key)
+ sub_errors = element.FindInitializationErrors()
+ errors += [prefix + error for error in sub_errors]
+ else:
+ # ScalarMaps can't have any initialization errors.
+ pass
+ elif field.label == _FieldDescriptor.LABEL_REPEATED:
+ for i in range(len(value)):
element = value[i]
prefix = "%s[%d]." % (name, i)
sub_errors = element.FindInitializationErrors()
- errors += [ prefix + error for error in sub_errors ]
+ errors += [prefix + error for error in sub_errors]
else:
prefix = name + "."
sub_errors = value.FindInitializationErrors()
- errors += [ prefix + error for error in sub_errors ]
+ errors += [prefix + error for error in sub_errors]
return errors
@@ -975,7 +1230,7 @@ def _AddMergeFromMethod(cls):
fields = self._fields
- for field, value in msg._fields.iteritems():
+ for field, value in msg._fields.items():
if field.label == LABEL_REPEATED:
field_value = fields.get(field)
if field_value is None:
@@ -993,6 +1248,8 @@ def _AddMergeFromMethod(cls):
field_value.MergeFrom(value)
else:
self._fields[field] = value
+ if field.containing_oneof:
+ self._UpdateOneofState(field)
if msg._unknown_fields:
if not self._unknown_fields:
@@ -1020,6 +1277,32 @@ def _AddWhichOneofMethod(message_descriptor, cls):
cls.WhichOneof = WhichOneof
+def _Clear(self):
+ # Clear fields.
+ self._fields = {}
+ self._unknown_fields = ()
+ self._oneofs = {}
+ self._Modified()
+
+
+def _DiscardUnknownFields(self):
+ self._unknown_fields = []
+ for field, value in self.ListFields():
+ if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
+ if field.label == _FieldDescriptor.LABEL_REPEATED:
+ for sub_message in value:
+ sub_message.DiscardUnknownFields()
+ else:
+ value.DiscardUnknownFields()
+
+
+def _SetListener(self, listener):
+ if listener is None:
+ self._listener = message_listener_mod.NullMessageListener()
+ else:
+ self._listener = listener
+
+
def _AddMessageMethods(message_descriptor, cls):
"""Adds implementations of all Message methods to cls."""
_AddListFieldsMethod(message_descriptor, cls)
@@ -1028,11 +1311,10 @@ def _AddMessageMethods(message_descriptor, cls):
if message_descriptor.is_extendable:
_AddClearExtensionMethod(cls)
_AddHasExtensionMethod(cls)
- _AddClearMethod(message_descriptor, cls)
_AddEqualsMethod(message_descriptor, cls)
_AddStrMethod(message_descriptor, cls)
+ _AddReprMethod(message_descriptor, cls)
_AddUnicodeMethod(message_descriptor, cls)
- _AddSetListenerMethod(cls)
_AddByteSizeMethod(message_descriptor, cls)
_AddSerializeToStringMethod(message_descriptor, cls)
_AddSerializePartialToStringMethod(message_descriptor, cls)
@@ -1040,6 +1322,11 @@ def _AddMessageMethods(message_descriptor, cls):
_AddIsInitializedMethod(message_descriptor, cls)
_AddMergeFromMethod(cls)
_AddWhichOneofMethod(message_descriptor, cls)
+ # Adds methods which do not depend on cls.
+ cls.Clear = _Clear
+ cls.DiscardUnknownFields = _DiscardUnknownFields
+ cls._SetListener = _SetListener
+
def _AddPrivateHelperMethods(message_descriptor, cls):
"""Adds implementation of private helper methods to cls."""
@@ -1232,11 +1519,10 @@ class _ExtensionDict(object):
# It's slightly wasteful to lookup the type checker each time,
# but we expect this to be a vanishingly uncommon case anyway.
- type_checker = type_checkers.GetTypeChecker(
- extension_handle)
+ type_checker = type_checkers.GetTypeChecker(extension_handle)
# pylint: disable=protected-access
self._extended_message._fields[extension_handle] = (
- type_checker.CheckValue(value))
+ type_checker.CheckValue(value))
self._extended_message._Modified()
def _FindExtensionByName(self, name):
@@ -1249,3 +1535,14 @@ class _ExtensionDict(object):
Extension field descriptor.
"""
return self._extended_message._extensions_by_name.get(name, None)
+
+ def _FindExtensionByNumber(self, number):
+ """Tries to find a known extension with the field number.
+
+ Args:
+ number: Extension field number.
+
+ Returns:
+ Extension field descriptor.
+ """
+ return self._extended_message._extensions_by_number.get(number, None)
diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py
index d59815d00..6dc2fffe2 100755
--- a/python/google/protobuf/internal/reflection_test.py
+++ b/python/google/protobuf/internal/reflection_test.py
@@ -1,4 +1,4 @@
-#! /usr/bin/python
+#! /usr/bin/env python
# -*- coding: utf-8 -*-
#
# Protocol Buffers - Google's data interchange format
@@ -35,14 +35,17 @@
pure-Python protocol compiler.
"""
-__author__ = 'robinson@google.com (Will Robinson)'
-
import copy
import gc
import operator
+import six
import struct
-from google.apputils import basetest
+try:
+ import unittest2 as unittest #PY26
+except ImportError:
+ import unittest
+
from google.protobuf import unittest_import_pb2
from google.protobuf import unittest_mset_pb2
from google.protobuf import unittest_pb2
@@ -54,6 +57,7 @@ from google.protobuf import text_format
from google.protobuf.internal import api_implementation
from google.protobuf.internal import more_extensions_pb2
from google.protobuf.internal import more_messages_pb2
+from google.protobuf.internal import message_set_extensions_pb2
from google.protobuf.internal import wire_format
from google.protobuf.internal import test_util
from google.protobuf.internal import decoder
@@ -104,7 +108,7 @@ class _MiniDecoder(object):
return self._pos == len(self._bytes)
-class ReflectionTest(basetest.TestCase):
+class ReflectionTest(unittest.TestCase):
def assertListsEqual(self, values, others):
self.assertEqual(len(values), len(others))
@@ -116,11 +120,13 @@ class ReflectionTest(basetest.TestCase):
proto = unittest_pb2.TestAllTypes(
optional_int32=24,
optional_double=54.321,
- optional_string='optional_string')
+ optional_string='optional_string',
+ optional_float=None)
self.assertEqual(24, proto.optional_int32)
self.assertEqual(54.321, proto.optional_double)
self.assertEqual('optional_string', proto.optional_string)
+ self.assertFalse(proto.HasField("optional_float"))
def testRepeatedScalarConstructor(self):
# Constructor with only repeated scalar types should succeed.
@@ -128,12 +134,14 @@ class ReflectionTest(basetest.TestCase):
repeated_int32=[1, 2, 3, 4],
repeated_double=[1.23, 54.321],
repeated_bool=[True, False, False],
- repeated_string=["optional_string"])
+ repeated_string=["optional_string"],
+ repeated_float=None)
- self.assertEquals([1, 2, 3, 4], list(proto.repeated_int32))
- self.assertEquals([1.23, 54.321], list(proto.repeated_double))
- self.assertEquals([True, False, False], list(proto.repeated_bool))
- self.assertEquals(["optional_string"], list(proto.repeated_string))
+ self.assertEqual([1, 2, 3, 4], list(proto.repeated_int32))
+ self.assertEqual([1.23, 54.321], list(proto.repeated_double))
+ self.assertEqual([True, False, False], list(proto.repeated_bool))
+ self.assertEqual(["optional_string"], list(proto.repeated_string))
+ self.assertEqual([], list(proto.repeated_float))
def testRepeatedCompositeConstructor(self):
# Constructor with only repeated composite types should succeed.
@@ -152,18 +160,18 @@ class ReflectionTest(basetest.TestCase):
unittest_pb2.TestAllTypes.RepeatedGroup(a=1),
unittest_pb2.TestAllTypes.RepeatedGroup(a=2)])
- self.assertEquals(
+ self.assertEqual(
[unittest_pb2.TestAllTypes.NestedMessage(
bb=unittest_pb2.TestAllTypes.FOO),
unittest_pb2.TestAllTypes.NestedMessage(
bb=unittest_pb2.TestAllTypes.BAR)],
list(proto.repeated_nested_message))
- self.assertEquals(
+ self.assertEqual(
[unittest_pb2.ForeignMessage(c=-43),
unittest_pb2.ForeignMessage(c=45324),
unittest_pb2.ForeignMessage(c=12)],
list(proto.repeated_foreign_message))
- self.assertEquals(
+ self.assertEqual(
[unittest_pb2.TestAllTypes.RepeatedGroup(),
unittest_pb2.TestAllTypes.RepeatedGroup(a=1),
unittest_pb2.TestAllTypes.RepeatedGroup(a=2)],
@@ -184,23 +192,25 @@ class ReflectionTest(basetest.TestCase):
repeated_foreign_message=[
unittest_pb2.ForeignMessage(c=-43),
unittest_pb2.ForeignMessage(c=45324),
- unittest_pb2.ForeignMessage(c=12)])
+ unittest_pb2.ForeignMessage(c=12)],
+ optional_nested_message=None)
self.assertEqual(24, proto.optional_int32)
self.assertEqual('optional_string', proto.optional_string)
- self.assertEquals([1.23, 54.321], list(proto.repeated_double))
- self.assertEquals([True, False, False], list(proto.repeated_bool))
- self.assertEquals(
+ self.assertEqual([1.23, 54.321], list(proto.repeated_double))
+ self.assertEqual([True, False, False], list(proto.repeated_bool))
+ self.assertEqual(
[unittest_pb2.TestAllTypes.NestedMessage(
bb=unittest_pb2.TestAllTypes.FOO),
unittest_pb2.TestAllTypes.NestedMessage(
bb=unittest_pb2.TestAllTypes.BAR)],
list(proto.repeated_nested_message))
- self.assertEquals(
+ self.assertEqual(
[unittest_pb2.ForeignMessage(c=-43),
unittest_pb2.ForeignMessage(c=45324),
unittest_pb2.ForeignMessage(c=12)],
list(proto.repeated_foreign_message))
+ self.assertFalse(proto.HasField("optional_nested_message"))
def testConstructorTypeError(self):
self.assertRaises(
@@ -224,18 +234,18 @@ class ReflectionTest(basetest.TestCase):
def testConstructorInvalidatesCachedByteSize(self):
message = unittest_pb2.TestAllTypes(optional_int32 = 12)
- self.assertEquals(2, message.ByteSize())
+ self.assertEqual(2, message.ByteSize())
message = unittest_pb2.TestAllTypes(
optional_nested_message = unittest_pb2.TestAllTypes.NestedMessage())
- self.assertEquals(3, message.ByteSize())
+ self.assertEqual(3, message.ByteSize())
message = unittest_pb2.TestAllTypes(repeated_int32 = [12])
- self.assertEquals(3, message.ByteSize())
+ self.assertEqual(3, message.ByteSize())
message = unittest_pb2.TestAllTypes(
repeated_nested_message = [unittest_pb2.TestAllTypes.NestedMessage()])
- self.assertEquals(3, message.ByteSize())
+ self.assertEqual(3, message.ByteSize())
def testSimpleHasBits(self):
# Test a scalar.
@@ -469,7 +479,7 @@ class ReflectionTest(basetest.TestCase):
proto.repeated_string.extend(['foo', 'bar'])
proto.repeated_string.extend([])
proto.repeated_string.append('baz')
- proto.repeated_string.extend(str(x) for x in xrange(2))
+ proto.repeated_string.extend(str(x) for x in range(2))
proto.optional_int32 = 21
proto.repeated_bool # Access but don't set anything; should not be listed.
self.assertEqual(
@@ -611,14 +621,18 @@ class ReflectionTest(basetest.TestCase):
def TestGetAndDeserialize(field_name, value, expected_type):
proto = unittest_pb2.TestAllTypes()
setattr(proto, field_name, value)
- self.assertTrue(isinstance(getattr(proto, field_name), expected_type))
+ self.assertIsInstance(getattr(proto, field_name), expected_type)
proto2 = unittest_pb2.TestAllTypes()
proto2.ParseFromString(proto.SerializeToString())
- self.assertTrue(isinstance(getattr(proto2, field_name), expected_type))
+ self.assertIsInstance(getattr(proto2, field_name), expected_type)
TestGetAndDeserialize('optional_int32', 1, int)
TestGetAndDeserialize('optional_int32', 1 << 30, int)
TestGetAndDeserialize('optional_uint32', 1 << 30, int)
+ try:
+ integer_64 = long
+ except NameError: # Python3
+ integer_64 = int
if struct.calcsize('L') == 4:
# Python only has signed ints, so 32-bit python can't fit an uint32
# in an int.
@@ -626,10 +640,10 @@ class ReflectionTest(basetest.TestCase):
else:
# 64-bit python can fit uint32 inside an int
TestGetAndDeserialize('optional_uint32', 1 << 31, int)
- TestGetAndDeserialize('optional_int64', 1 << 30, long)
- TestGetAndDeserialize('optional_int64', 1 << 60, long)
- TestGetAndDeserialize('optional_uint64', 1 << 30, long)
- TestGetAndDeserialize('optional_uint64', 1 << 60, long)
+ TestGetAndDeserialize('optional_int64', 1 << 30, integer_64)
+ TestGetAndDeserialize('optional_int64', 1 << 60, integer_64)
+ TestGetAndDeserialize('optional_uint64', 1 << 30, integer_64)
+ TestGetAndDeserialize('optional_uint64', 1 << 60, integer_64)
def testSingleScalarBoundsChecking(self):
def TestMinAndMaxIntegers(field_name, expected_min, expected_max):
@@ -755,18 +769,18 @@ class ReflectionTest(basetest.TestCase):
def testEnum_KeysAndValues(self):
self.assertEqual(['FOREIGN_FOO', 'FOREIGN_BAR', 'FOREIGN_BAZ'],
- unittest_pb2.ForeignEnum.keys())
+ list(unittest_pb2.ForeignEnum.keys()))
self.assertEqual([4, 5, 6],
- unittest_pb2.ForeignEnum.values())
+ list(unittest_pb2.ForeignEnum.values()))
self.assertEqual([('FOREIGN_FOO', 4), ('FOREIGN_BAR', 5),
('FOREIGN_BAZ', 6)],
- unittest_pb2.ForeignEnum.items())
+ list(unittest_pb2.ForeignEnum.items()))
proto = unittest_pb2.TestAllTypes()
- self.assertEqual(['FOO', 'BAR', 'BAZ', 'NEG'], proto.NestedEnum.keys())
- self.assertEqual([1, 2, 3, -1], proto.NestedEnum.values())
+ self.assertEqual(['FOO', 'BAR', 'BAZ', 'NEG'], list(proto.NestedEnum.keys()))
+ self.assertEqual([1, 2, 3, -1], list(proto.NestedEnum.values()))
self.assertEqual([('FOO', 1), ('BAR', 2), ('BAZ', 3), ('NEG', -1)],
- proto.NestedEnum.items())
+ list(proto.NestedEnum.items()))
def testRepeatedScalars(self):
proto = unittest_pb2.TestAllTypes()
@@ -805,7 +819,7 @@ class ReflectionTest(basetest.TestCase):
self.assertEqual([5, 25, 20, 15, 30], proto.repeated_int32[:])
# Test slice assignment with an iterator
- proto.repeated_int32[1:4] = (i for i in xrange(3))
+ proto.repeated_int32[1:4] = (i for i in range(3))
self.assertEqual([5, 0, 1, 2, 30], proto.repeated_int32)
# Test slice assignment.
@@ -896,7 +910,7 @@ class ReflectionTest(basetest.TestCase):
self.assertTrue(proto.repeated_nested_message)
self.assertEqual(2, len(proto.repeated_nested_message))
self.assertListsEqual([m0, m1], proto.repeated_nested_message)
- self.assertTrue(isinstance(m0, unittest_pb2.TestAllTypes.NestedMessage))
+ self.assertIsInstance(m0, unittest_pb2.TestAllTypes.NestedMessage)
# Test out-of-bounds indices.
self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__,
@@ -1008,9 +1022,8 @@ class ReflectionTest(basetest.TestCase):
containing_type=None, nested_types=[], enum_types=[],
fields=[foo_field_descriptor], extensions=[],
options=descriptor_pb2.MessageOptions())
- class MyProtoClass(message.Message):
+ class MyProtoClass(six.with_metaclass(reflection.GeneratedProtocolMessageType, message.Message)):
DESCRIPTOR = mydescriptor
- __metaclass__ = reflection.GeneratedProtocolMessageType
myproto_instance = MyProtoClass()
self.assertEqual(0, myproto_instance.foo_field)
self.assertTrue(not myproto_instance.HasField('foo_field'))
@@ -1050,14 +1063,13 @@ class ReflectionTest(basetest.TestCase):
new_field.label = descriptor_pb2.FieldDescriptorProto.LABEL_REPEATED
desc = descriptor.MakeDescriptor(desc_proto)
- self.assertTrue(desc.fields_by_name.has_key('name'))
- self.assertTrue(desc.fields_by_name.has_key('year'))
- self.assertTrue(desc.fields_by_name.has_key('automatic'))
- self.assertTrue(desc.fields_by_name.has_key('price'))
- self.assertTrue(desc.fields_by_name.has_key('owners'))
-
- class CarMessage(message.Message):
- __metaclass__ = reflection.GeneratedProtocolMessageType
+ self.assertTrue('name' in desc.fields_by_name)
+ self.assertTrue('year' in desc.fields_by_name)
+ self.assertTrue('automatic' in desc.fields_by_name)
+ self.assertTrue('price' in desc.fields_by_name)
+ self.assertTrue('owners' in desc.fields_by_name)
+
+ class CarMessage(six.with_metaclass(reflection.GeneratedProtocolMessageType, message.Message)):
DESCRIPTOR = desc
prius = CarMessage()
@@ -1175,7 +1187,7 @@ class ReflectionTest(basetest.TestCase):
self.assertTrue(1 in unittest_pb2.TestAllExtensions._extensions_by_number)
# Make sure extensions haven't been registered into types that shouldn't
# have any.
- self.assertEquals(0, len(unittest_pb2.TestAllTypes._extensions_by_name))
+ self.assertEqual(0, len(unittest_pb2.TestAllTypes._extensions_by_name))
# If message A directly contains message B, and
# a.HasField('b') is currently False, then mutating any
@@ -1252,15 +1264,18 @@ class ReflectionTest(basetest.TestCase):
# Try something that *is* an extension handle, just not for
# this message...
- unknown_handle = more_extensions_pb2.optional_int_extension
- self.assertRaises(KeyError, extendee_proto.HasExtension,
- unknown_handle)
- self.assertRaises(KeyError, extendee_proto.ClearExtension,
- unknown_handle)
- self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__,
- unknown_handle)
- self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__,
- unknown_handle, 5)
+ for unknown_handle in (more_extensions_pb2.optional_int_extension,
+ more_extensions_pb2.optional_message_extension,
+ more_extensions_pb2.repeated_int_extension,
+ more_extensions_pb2.repeated_message_extension):
+ self.assertRaises(KeyError, extendee_proto.HasExtension,
+ unknown_handle)
+ self.assertRaises(KeyError, extendee_proto.ClearExtension,
+ unknown_handle)
+ self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__,
+ unknown_handle)
+ self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__,
+ unknown_handle, 5)
# Try call HasExtension() with a valid handle, but for a
# *repeated* field. (Just as with non-extension repeated
@@ -1496,18 +1511,18 @@ class ReflectionTest(basetest.TestCase):
test_util.SetAllNonLazyFields(proto)
# Clear the message.
proto.Clear()
- self.assertEquals(proto.ByteSize(), 0)
+ self.assertEqual(proto.ByteSize(), 0)
empty_proto = unittest_pb2.TestAllTypes()
- self.assertEquals(proto, empty_proto)
+ self.assertEqual(proto, empty_proto)
# Test if extensions which were set are cleared.
proto = unittest_pb2.TestAllExtensions()
test_util.SetAllExtensions(proto)
# Clear the message.
proto.Clear()
- self.assertEquals(proto.ByteSize(), 0)
+ self.assertEqual(proto.ByteSize(), 0)
empty_proto = unittest_pb2.TestAllExtensions()
- self.assertEquals(proto, empty_proto)
+ self.assertEqual(proto, empty_proto)
def testDisconnectingBeforeClear(self):
proto = unittest_pb2.TestAllTypes()
@@ -1618,7 +1633,7 @@ class ReflectionTest(basetest.TestCase):
self.assertFalse(proto.IsInitialized(errors))
self.assertEqual(errors, ['a', 'b', 'c'])
- @basetest.unittest.skipIf(
+ @unittest.skipIf(
api_implementation.Type() != 'cpp' or api_implementation.Version() != 2,
'Errors are only available from the most recent C++ implementation.')
def testFileDescriptorErrors(self):
@@ -1660,30 +1675,29 @@ class ReflectionTest(basetest.TestCase):
setattr, proto, 'optional_bytes', u'unicode object')
# Check that the default value is of python's 'unicode' type.
- self.assertEqual(type(proto.optional_string), unicode)
+ self.assertEqual(type(proto.optional_string), six.text_type)
- proto.optional_string = unicode('Testing')
+ proto.optional_string = six.text_type('Testing')
self.assertEqual(proto.optional_string, str('Testing'))
# Assign a value of type 'str' which can be encoded in UTF-8.
proto.optional_string = str('Testing')
- self.assertEqual(proto.optional_string, unicode('Testing'))
+ self.assertEqual(proto.optional_string, six.text_type('Testing'))
- # Try to assign a 'str' value which contains bytes that aren't 7-bit ASCII.
+ # Try to assign a 'bytes' object which contains non-UTF-8.
self.assertRaises(ValueError,
setattr, proto, 'optional_string', b'a\x80a')
- if str is bytes: # PY2
- # Assign a 'str' object which contains a UTF-8 encoded string.
- self.assertRaises(ValueError,
- setattr, proto, 'optional_string', 'Тест')
- else:
- proto.optional_string = 'Тест'
- # No exception thrown.
+ # No exception: Assign already encoded UTF-8 bytes to a string field.
+ utf8_bytes = u'Тест'.encode('utf-8')
+ proto.optional_string = utf8_bytes
+ # No exception: Assign the a non-ascii unicode object.
+ proto.optional_string = u'Тест'
+ # No exception thrown (normal str assignment containing ASCII).
proto.optional_string = 'abc'
def testStringUTF8Serialization(self):
- proto = unittest_mset_pb2.TestMessageSet()
- extension_message = unittest_mset_pb2.TestMessageSetExtension2
+ proto = message_set_extensions_pb2.TestMessageSet()
+ extension_message = message_set_extensions_pb2.TestMessageSetExtension2
extension = extension_message.message_set_extension
test_utf8 = u'Тест'
@@ -1703,19 +1717,18 @@ class ReflectionTest(basetest.TestCase):
bytes_read = raw.MergeFromString(serialized)
self.assertEqual(len(serialized), bytes_read)
- message2 = unittest_mset_pb2.TestMessageSetExtension2()
+ message2 = message_set_extensions_pb2.TestMessageSetExtension2()
self.assertEqual(1, len(raw.item))
# Check that the type_id is the same as the tag ID in the .proto file.
- self.assertEqual(raw.item[0].type_id, 1547769)
+ self.assertEqual(raw.item[0].type_id, 98418634)
# Check the actual bytes on the wire.
- self.assertTrue(
- raw.item[0].message.endswith(test_utf8_bytes))
+ self.assertTrue(raw.item[0].message.endswith(test_utf8_bytes))
bytes_read = message2.MergeFromString(raw.item[0].message)
self.assertEqual(len(raw.item[0].message), bytes_read)
- self.assertEqual(type(message2.str), unicode)
+ self.assertEqual(type(message2.str), six.text_type)
self.assertEqual(message2.str, test_utf8)
# The pure Python API throws an exception on MergeFromString(),
@@ -1739,7 +1752,7 @@ class ReflectionTest(basetest.TestCase):
def testBytesInTextFormat(self):
proto = unittest_pb2.TestAllTypes(optional_bytes=b'\x00\x7f\x80\xff')
self.assertEqual(u'optional_bytes: "\\000\\177\\200\\377"\n',
- unicode(proto))
+ six.text_type(proto))
def testEmptyNestedMessage(self):
proto = unittest_pb2.TestAllTypes()
@@ -1774,12 +1787,29 @@ class ReflectionTest(basetest.TestCase):
proto.optionalgroup.SetInParent()
self.assertTrue(proto.HasField('optionalgroup'))
+ def testPackageInitializationImport(self):
+ """Test that we can import nested messages from their __init__.py.
+
+ Such setup is not trivial since at the time of processing of __init__.py one
+ can't refer to its submodules by name in code, so expressions like
+ google.protobuf.internal.import_test_package.inner_pb2
+ don't work. They do work in imports, so we have assign an alias at import
+ and then use that alias in generated code.
+ """
+ # We import here since it's the import that used to fail, and we want
+ # the failure to have the right context.
+ # pylint: disable=g-import-not-at-top
+ from google.protobuf.internal import import_test_package
+ # pylint: enable=g-import-not-at-top
+ msg = import_test_package.myproto.Outer()
+ # Just check the default value.
+ self.assertEqual(57, msg.inner.value)
# Since we had so many tests for protocol buffer equality, we broke these out
# into separate TestCase classes.
-class TestAllTypesEqualityTest(basetest.TestCase):
+class TestAllTypesEqualityTest(unittest.TestCase):
def setUp(self):
self.first_proto = unittest_pb2.TestAllTypes()
@@ -1795,7 +1825,7 @@ class TestAllTypesEqualityTest(basetest.TestCase):
self.assertEqual(self.first_proto, self.second_proto)
-class FullProtosEqualityTest(basetest.TestCase):
+class FullProtosEqualityTest(unittest.TestCase):
"""Equality tests using completely-full protos as a starting point."""
@@ -1881,7 +1911,7 @@ class FullProtosEqualityTest(basetest.TestCase):
self.assertEqual(self.first_proto, self.second_proto)
-class ExtensionEqualityTest(basetest.TestCase):
+class ExtensionEqualityTest(unittest.TestCase):
def testExtensionEquality(self):
first_proto = unittest_pb2.TestAllExtensions()
@@ -1914,7 +1944,7 @@ class ExtensionEqualityTest(basetest.TestCase):
self.assertEqual(first_proto, second_proto)
-class MutualRecursionEqualityTest(basetest.TestCase):
+class MutualRecursionEqualityTest(unittest.TestCase):
def testEqualityWithMutualRecursion(self):
first_proto = unittest_pb2.TestMutualRecursionA()
@@ -1926,7 +1956,7 @@ class MutualRecursionEqualityTest(basetest.TestCase):
self.assertEqual(first_proto, second_proto)
-class ByteSizeTest(basetest.TestCase):
+class ByteSizeTest(unittest.TestCase):
def setUp(self):
self.proto = unittest_pb2.TestAllTypes()
@@ -2222,7 +2252,7 @@ class ByteSizeTest(basetest.TestCase):
# * Handling of empty submessages (with and without "has"
# bits set).
-class SerializationTest(basetest.TestCase):
+class SerializationTest(unittest.TestCase):
def testSerializeEmtpyMessage(self):
first_proto = unittest_pb2.TestAllTypes()
@@ -2289,7 +2319,7 @@ class SerializationTest(basetest.TestCase):
test_util.SetAllFields(first_proto)
serialized = first_proto.SerializeToString()
- for truncation_point in xrange(len(serialized) + 1):
+ for truncation_point in range(len(serialized) + 1):
try:
second_proto = unittest_pb2.TestAllTypes()
unknown_fields = unittest_pb2.TestEmptyMessage()
@@ -2366,13 +2396,15 @@ class SerializationTest(basetest.TestCase):
self.assertEqual(42, second_proto.optional_nested_message.bb)
def testMessageSetWireFormat(self):
- proto = unittest_mset_pb2.TestMessageSet()
- extension_message1 = unittest_mset_pb2.TestMessageSetExtension1
- extension_message2 = unittest_mset_pb2.TestMessageSetExtension2
+ proto = message_set_extensions_pb2.TestMessageSet()
+ extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1
+ extension_message2 = message_set_extensions_pb2.TestMessageSetExtension2
extension1 = extension_message1.message_set_extension
extension2 = extension_message2.message_set_extension
+ extension3 = message_set_extensions_pb2.message_set_extension3
proto.Extensions[extension1].i = 123
proto.Extensions[extension2].str = 'foo'
+ proto.Extensions[extension3].text = 'bar'
# Serialize using the MessageSet wire format (this is specified in the
# .proto file).
@@ -2384,27 +2416,34 @@ class SerializationTest(basetest.TestCase):
self.assertEqual(
len(serialized),
raw.MergeFromString(serialized))
- self.assertEqual(2, len(raw.item))
+ self.assertEqual(3, len(raw.item))
- message1 = unittest_mset_pb2.TestMessageSetExtension1()
+ message1 = message_set_extensions_pb2.TestMessageSetExtension1()
self.assertEqual(
len(raw.item[0].message),
message1.MergeFromString(raw.item[0].message))
self.assertEqual(123, message1.i)
- message2 = unittest_mset_pb2.TestMessageSetExtension2()
+ message2 = message_set_extensions_pb2.TestMessageSetExtension2()
self.assertEqual(
len(raw.item[1].message),
message2.MergeFromString(raw.item[1].message))
self.assertEqual('foo', message2.str)
+ message3 = message_set_extensions_pb2.TestMessageSetExtension3()
+ self.assertEqual(
+ len(raw.item[2].message),
+ message3.MergeFromString(raw.item[2].message))
+ self.assertEqual('bar', message3.text)
+
# Deserialize using the MessageSet wire format.
- proto2 = unittest_mset_pb2.TestMessageSet()
+ proto2 = message_set_extensions_pb2.TestMessageSet()
self.assertEqual(
len(serialized),
proto2.MergeFromString(serialized))
self.assertEqual(123, proto2.Extensions[extension1].i)
self.assertEqual('foo', proto2.Extensions[extension2].str)
+ self.assertEqual('bar', proto2.Extensions[extension3].text)
# Check byte size.
self.assertEqual(proto2.ByteSize(), len(serialized))
@@ -2417,39 +2456,39 @@ class SerializationTest(basetest.TestCase):
# Add an item.
item = raw.item.add()
- item.type_id = 1545008
- extension_message1 = unittest_mset_pb2.TestMessageSetExtension1
- message1 = unittest_mset_pb2.TestMessageSetExtension1()
+ item.type_id = 98418603
+ extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1
+ message1 = message_set_extensions_pb2.TestMessageSetExtension1()
message1.i = 12345
item.message = message1.SerializeToString()
# Add a second, unknown extension.
item = raw.item.add()
- item.type_id = 1545009
- extension_message1 = unittest_mset_pb2.TestMessageSetExtension1
- message1 = unittest_mset_pb2.TestMessageSetExtension1()
+ item.type_id = 98418604
+ extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1
+ message1 = message_set_extensions_pb2.TestMessageSetExtension1()
message1.i = 12346
item.message = message1.SerializeToString()
# Add another unknown extension.
item = raw.item.add()
- item.type_id = 1545010
- message1 = unittest_mset_pb2.TestMessageSetExtension2()
+ item.type_id = 98418605
+ message1 = message_set_extensions_pb2.TestMessageSetExtension2()
message1.str = 'foo'
item.message = message1.SerializeToString()
serialized = raw.SerializeToString()
# Parse message using the message set wire format.
- proto = unittest_mset_pb2.TestMessageSet()
+ proto = message_set_extensions_pb2.TestMessageSet()
self.assertEqual(
len(serialized),
proto.MergeFromString(serialized))
# Check that the message parsed well.
- extension_message1 = unittest_mset_pb2.TestMessageSetExtension1
+ extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1
extension1 = extension_message1.message_set_extension
- self.assertEquals(12345, proto.Extensions[extension1].i)
+ self.assertEqual(12345, proto.Extensions[extension1].i)
def testUnknownFields(self):
proto = unittest_pb2.TestAllTypes()
@@ -2734,9 +2773,10 @@ class SerializationTest(basetest.TestCase):
def testInitArgsUnknownFieldName(self):
def InitalizeEmptyMessageWithExtraKeywordArg():
unused_proto = unittest_pb2.TestEmptyMessage(unknown='unknown')
- self._CheckRaises(ValueError,
- InitalizeEmptyMessageWithExtraKeywordArg,
- 'Protocol message has no "unknown" field.')
+ self._CheckRaises(
+ ValueError,
+ InitalizeEmptyMessageWithExtraKeywordArg,
+ 'Protocol message TestEmptyMessage has no "unknown" field.')
def testInitRequiredKwargs(self):
proto = unittest_pb2.TestRequired(a=1, b=1, c=1)
@@ -2773,10 +2813,10 @@ class SerializationTest(basetest.TestCase):
self.assertEqual(3, proto.repeated_int32[2])
-class OptionsTest(basetest.TestCase):
+class OptionsTest(unittest.TestCase):
def testMessageOptions(self):
- proto = unittest_mset_pb2.TestMessageSet()
+ proto = message_set_extensions_pb2.TestMessageSet()
self.assertEqual(True,
proto.DESCRIPTOR.GetOptions().message_set_wire_format)
proto = unittest_pb2.TestAllTypes()
@@ -2795,13 +2835,16 @@ class OptionsTest(basetest.TestCase):
proto.packed_double.append(3.0)
for field_descriptor, _ in proto.ListFields():
self.assertEqual(True, field_descriptor.GetOptions().packed)
- self.assertEqual(reflection._FieldDescriptor.LABEL_REPEATED,
+ self.assertEqual(descriptor.FieldDescriptor.LABEL_REPEATED,
field_descriptor.label)
-class ClassAPITest(basetest.TestCase):
+class ClassAPITest(unittest.TestCase):
+ @unittest.skipIf(
+ api_implementation.Type() == 'cpp' and api_implementation.Version() == 2,
+ 'C++ implementation requires a call to MakeDescriptor()')
def testMakeClassWithNestedDescriptor(self):
leaf_desc = descriptor.Descriptor('leaf', 'package.parent.child.leaf', '',
containing_type=None, fields=[],
@@ -2882,13 +2925,19 @@ class ClassAPITest(basetest.TestCase):
def testParsingFlatClassWithExplicitClassDeclaration(self):
"""Test that the generated class can parse a flat message."""
+ # TODO(xiaofeng): This test fails with cpp implemetnation in the call
+ # of six.with_metaclass(). The other two callsites of with_metaclass
+ # in this file are both excluded from cpp test, so it might be expected
+ # to fail. Need someone more familiar with the python code to take a
+ # look at this.
+ if api_implementation.Type() != 'python':
+ return
file_descriptor = descriptor_pb2.FileDescriptorProto()
file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('A'))
msg_descriptor = descriptor.MakeDescriptor(
file_descriptor.message_type[0])
- class MessageClass(message.Message):
- __metaclass__ = reflection.GeneratedProtocolMessageType
+ class MessageClass(six.with_metaclass(reflection.GeneratedProtocolMessageType, message.Message)):
DESCRIPTOR = msg_descriptor
msg = MessageClass()
msg_str = (
@@ -2931,4 +2980,4 @@ class ClassAPITest(basetest.TestCase):
self.assertEqual(msg.bar.baz.deep, 4)
if __name__ == '__main__':
- basetest.main()
+ unittest.main()
diff --git a/python/google/protobuf/internal/service_reflection_test.py b/python/google/protobuf/internal/service_reflection_test.py
index 07dcf4453..62900b1d1 100755
--- a/python/google/protobuf/internal/service_reflection_test.py
+++ b/python/google/protobuf/internal/service_reflection_test.py
@@ -1,4 +1,4 @@
-#! /usr/bin/python
+#! /usr/bin/env python
#
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
@@ -34,13 +34,18 @@
__author__ = 'petar@google.com (Petar Petrov)'
-from google.apputils import basetest
+
+try:
+ import unittest2 as unittest #PY26
+except ImportError:
+ import unittest
+
from google.protobuf import unittest_pb2
from google.protobuf import service_reflection
from google.protobuf import service
-class FooUnitTest(basetest.TestCase):
+class FooUnitTest(unittest.TestCase):
def testService(self):
class MockRpcChannel(service.RpcChannel):
@@ -80,7 +85,7 @@ class FooUnitTest(basetest.TestCase):
self.assertEqual('Method Bar not implemented.',
rpc_controller.failure_message)
self.assertEqual(None, self.callback_response)
-
+
class MyServiceImpl(unittest_pb2.TestService):
def Foo(self, rpc_controller, request, done):
self.foo_called = True
@@ -118,19 +123,18 @@ class FooUnitTest(basetest.TestCase):
rpc_controller = 'controller'
request = 'request'
- # GetDescriptor now static, still works as instance method for compatability
+ # GetDescriptor now static, still works as instance method for compatibility
self.assertEqual(unittest_pb2.TestService_Stub.GetDescriptor(),
stub.GetDescriptor())
# Invoke method.
stub.Foo(rpc_controller, request, MyCallback)
- self.assertTrue(isinstance(self.callback_response,
- unittest_pb2.FooResponse))
+ self.assertIsInstance(self.callback_response, unittest_pb2.FooResponse)
self.assertEqual(request, channel.request)
self.assertEqual(rpc_controller, channel.controller)
self.assertEqual(stub.GetDescriptor().methods[0], channel.method)
if __name__ == '__main__':
- basetest.main()
+ unittest.main()
diff --git a/python/google/protobuf/internal/symbol_database_test.py b/python/google/protobuf/internal/symbol_database_test.py
index 47572d58c..c99b426dc 100644
--- a/python/google/protobuf/internal/symbol_database_test.py
+++ b/python/google/protobuf/internal/symbol_database_test.py
@@ -1,4 +1,4 @@
-#! /usr/bin/python
+#! /usr/bin/env python
#
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
@@ -32,24 +32,33 @@
"""Tests for google.protobuf.symbol_database."""
-from google.apputils import basetest
+try:
+ import unittest2 as unittest #PY26
+except ImportError:
+ import unittest
+
from google.protobuf import unittest_pb2
+from google.protobuf import descriptor
from google.protobuf import symbol_database
-
-class SymbolDatabaseTest(basetest.TestCase):
+class SymbolDatabaseTest(unittest.TestCase):
def _Database(self):
- db = symbol_database.SymbolDatabase()
- # Register representative types from unittest_pb2.
- db.RegisterFileDescriptor(unittest_pb2.DESCRIPTOR)
- db.RegisterMessage(unittest_pb2.TestAllTypes)
- db.RegisterMessage(unittest_pb2.TestAllTypes.NestedMessage)
- db.RegisterMessage(unittest_pb2.TestAllTypes.OptionalGroup)
- db.RegisterMessage(unittest_pb2.TestAllTypes.RepeatedGroup)
- db.RegisterEnumDescriptor(unittest_pb2.ForeignEnum.DESCRIPTOR)
- db.RegisterEnumDescriptor(unittest_pb2.TestAllTypes.NestedEnum.DESCRIPTOR)
- return db
+ # TODO(b/17734095): Remove this difference when the C++ implementation
+ # supports multiple databases.
+ if descriptor._USE_C_DESCRIPTORS:
+ return symbol_database.Default()
+ else:
+ db = symbol_database.SymbolDatabase()
+ # Register representative types from unittest_pb2.
+ db.RegisterFileDescriptor(unittest_pb2.DESCRIPTOR)
+ db.RegisterMessage(unittest_pb2.TestAllTypes)
+ db.RegisterMessage(unittest_pb2.TestAllTypes.NestedMessage)
+ db.RegisterMessage(unittest_pb2.TestAllTypes.OptionalGroup)
+ db.RegisterMessage(unittest_pb2.TestAllTypes.RepeatedGroup)
+ db.RegisterEnumDescriptor(unittest_pb2.ForeignEnum.DESCRIPTOR)
+ db.RegisterEnumDescriptor(unittest_pb2.TestAllTypes.NestedEnum.DESCRIPTOR)
+ return db
def testGetPrototype(self):
instance = self._Database().GetPrototype(
@@ -64,57 +73,57 @@ class SymbolDatabaseTest(basetest.TestCase):
messages['protobuf_unittest.TestAllTypes'])
def testGetSymbol(self):
- self.assertEquals(
+ self.assertEqual(
unittest_pb2.TestAllTypes, self._Database().GetSymbol(
'protobuf_unittest.TestAllTypes'))
- self.assertEquals(
+ self.assertEqual(
unittest_pb2.TestAllTypes.NestedMessage, self._Database().GetSymbol(
'protobuf_unittest.TestAllTypes.NestedMessage'))
- self.assertEquals(
+ self.assertEqual(
unittest_pb2.TestAllTypes.OptionalGroup, self._Database().GetSymbol(
'protobuf_unittest.TestAllTypes.OptionalGroup'))
- self.assertEquals(
+ self.assertEqual(
unittest_pb2.TestAllTypes.RepeatedGroup, self._Database().GetSymbol(
'protobuf_unittest.TestAllTypes.RepeatedGroup'))
def testEnums(self):
# Check registration of types in the pool.
- self.assertEquals(
+ self.assertEqual(
'protobuf_unittest.ForeignEnum',
self._Database().pool.FindEnumTypeByName(
'protobuf_unittest.ForeignEnum').full_name)
- self.assertEquals(
+ self.assertEqual(
'protobuf_unittest.TestAllTypes.NestedEnum',
self._Database().pool.FindEnumTypeByName(
'protobuf_unittest.TestAllTypes.NestedEnum').full_name)
def testFindMessageTypeByName(self):
- self.assertEquals(
+ self.assertEqual(
'protobuf_unittest.TestAllTypes',
self._Database().pool.FindMessageTypeByName(
'protobuf_unittest.TestAllTypes').full_name)
- self.assertEquals(
+ self.assertEqual(
'protobuf_unittest.TestAllTypes.NestedMessage',
self._Database().pool.FindMessageTypeByName(
'protobuf_unittest.TestAllTypes.NestedMessage').full_name)
def testFindFindContainingSymbol(self):
# Lookup based on either enum or message.
- self.assertEquals(
+ self.assertEqual(
'google/protobuf/unittest.proto',
self._Database().pool.FindFileContainingSymbol(
'protobuf_unittest.TestAllTypes.NestedEnum').name)
- self.assertEquals(
+ self.assertEqual(
'google/protobuf/unittest.proto',
self._Database().pool.FindFileContainingSymbol(
'protobuf_unittest.TestAllTypes').name)
def testFindFileByName(self):
- self.assertEquals(
+ self.assertEqual(
'google/protobuf/unittest.proto',
self._Database().pool.FindFileByName(
'google/protobuf/unittest.proto').name)
if __name__ == '__main__':
- basetest.main()
+ unittest.main()
diff --git a/python/google/protobuf/internal/test_bad_identifiers.proto b/python/google/protobuf/internal/test_bad_identifiers.proto
index 9eb18cb09..c4860ea88 100644
--- a/python/google/protobuf/internal/test_bad_identifiers.proto
+++ b/python/google/protobuf/internal/test_bad_identifiers.proto
@@ -30,6 +30,7 @@
// Author: kenton@google.com (Kenton Varda)
+syntax = "proto2";
package protobuf_unittest;
diff --git a/python/google/protobuf/internal/test_util.py b/python/google/protobuf/internal/test_util.py
index 787f46505..2c805599b 100755
--- a/python/google/protobuf/internal/test_util.py
+++ b/python/google/protobuf/internal/test_util.py
@@ -38,15 +38,23 @@ __author__ = 'robinson@google.com (Will Robinson)'
import os.path
+import sys
+
from google.protobuf import unittest_import_pb2
from google.protobuf import unittest_pb2
+from google.protobuf import descriptor_pb2
+# Tests whether the given TestAllTypes message is proto2 or not.
+# This is used to gate several fields/features that only exist
+# for the proto2 version of the message.
+def IsProto2(message):
+ return message.DESCRIPTOR.syntax == "proto2"
def SetAllNonLazyFields(message):
"""Sets every non-lazy field in the message to a unique value.
Args:
- message: A unittest_pb2.TestAllTypes instance.
+ message: A TestAllTypes instance.
"""
#
@@ -69,7 +77,8 @@ def SetAllNonLazyFields(message):
message.optional_string = u'115'
message.optional_bytes = b'116'
- message.optionalgroup.a = 117
+ if IsProto2(message):
+ message.optionalgroup.a = 117
message.optional_nested_message.bb = 118
message.optional_foreign_message.c = 119
message.optional_import_message.d = 120
@@ -77,7 +86,8 @@ def SetAllNonLazyFields(message):
message.optional_nested_enum = unittest_pb2.TestAllTypes.BAZ
message.optional_foreign_enum = unittest_pb2.FOREIGN_BAZ
- message.optional_import_enum = unittest_import_pb2.IMPORT_BAZ
+ if IsProto2(message):
+ message.optional_import_enum = unittest_import_pb2.IMPORT_BAZ
message.optional_string_piece = u'124'
message.optional_cord = u'125'
@@ -102,7 +112,8 @@ def SetAllNonLazyFields(message):
message.repeated_string.append(u'215')
message.repeated_bytes.append(b'216')
- message.repeatedgroup.add().a = 217
+ if IsProto2(message):
+ message.repeatedgroup.add().a = 217
message.repeated_nested_message.add().bb = 218
message.repeated_foreign_message.add().c = 219
message.repeated_import_message.add().d = 220
@@ -110,7 +121,8 @@ def SetAllNonLazyFields(message):
message.repeated_nested_enum.append(unittest_pb2.TestAllTypes.BAR)
message.repeated_foreign_enum.append(unittest_pb2.FOREIGN_BAR)
- message.repeated_import_enum.append(unittest_import_pb2.IMPORT_BAR)
+ if IsProto2(message):
+ message.repeated_import_enum.append(unittest_import_pb2.IMPORT_BAR)
message.repeated_string_piece.append(u'224')
message.repeated_cord.append(u'225')
@@ -132,7 +144,8 @@ def SetAllNonLazyFields(message):
message.repeated_string.append(u'315')
message.repeated_bytes.append(b'316')
- message.repeatedgroup.add().a = 317
+ if IsProto2(message):
+ message.repeatedgroup.add().a = 317
message.repeated_nested_message.add().bb = 318
message.repeated_foreign_message.add().c = 319
message.repeated_import_message.add().d = 320
@@ -140,7 +153,8 @@ def SetAllNonLazyFields(message):
message.repeated_nested_enum.append(unittest_pb2.TestAllTypes.BAZ)
message.repeated_foreign_enum.append(unittest_pb2.FOREIGN_BAZ)
- message.repeated_import_enum.append(unittest_import_pb2.IMPORT_BAZ)
+ if IsProto2(message):
+ message.repeated_import_enum.append(unittest_import_pb2.IMPORT_BAZ)
message.repeated_string_piece.append(u'324')
message.repeated_cord.append(u'325')
@@ -149,28 +163,29 @@ def SetAllNonLazyFields(message):
# Fields that have defaults.
#
- message.default_int32 = 401
- message.default_int64 = 402
- message.default_uint32 = 403
- message.default_uint64 = 404
- message.default_sint32 = 405
- message.default_sint64 = 406
- message.default_fixed32 = 407
- message.default_fixed64 = 408
- message.default_sfixed32 = 409
- message.default_sfixed64 = 410
- message.default_float = 411
- message.default_double = 412
- message.default_bool = False
- message.default_string = '415'
- message.default_bytes = b'416'
-
- message.default_nested_enum = unittest_pb2.TestAllTypes.FOO
- message.default_foreign_enum = unittest_pb2.FOREIGN_FOO
- message.default_import_enum = unittest_import_pb2.IMPORT_FOO
-
- message.default_string_piece = '424'
- message.default_cord = '425'
+ if IsProto2(message):
+ message.default_int32 = 401
+ message.default_int64 = 402
+ message.default_uint32 = 403
+ message.default_uint64 = 404
+ message.default_sint32 = 405
+ message.default_sint64 = 406
+ message.default_fixed32 = 407
+ message.default_fixed64 = 408
+ message.default_sfixed32 = 409
+ message.default_sfixed64 = 410
+ message.default_float = 411
+ message.default_double = 412
+ message.default_bool = False
+ message.default_string = '415'
+ message.default_bytes = b'416'
+
+ message.default_nested_enum = unittest_pb2.TestAllTypes.FOO
+ message.default_foreign_enum = unittest_pb2.FOREIGN_FOO
+ message.default_import_enum = unittest_import_pb2.IMPORT_FOO
+
+ message.default_string_piece = '424'
+ message.default_cord = '425'
message.oneof_uint32 = 601
message.oneof_nested_message.bb = 602
@@ -386,7 +401,8 @@ def ExpectAllFieldsSet(test_case, message):
test_case.assertTrue(message.HasField('optional_string'))
test_case.assertTrue(message.HasField('optional_bytes'))
- test_case.assertTrue(message.HasField('optionalgroup'))
+ if IsProto2(message):
+ test_case.assertTrue(message.HasField('optionalgroup'))
test_case.assertTrue(message.HasField('optional_nested_message'))
test_case.assertTrue(message.HasField('optional_foreign_message'))
test_case.assertTrue(message.HasField('optional_import_message'))
@@ -398,7 +414,8 @@ def ExpectAllFieldsSet(test_case, message):
test_case.assertTrue(message.HasField('optional_nested_enum'))
test_case.assertTrue(message.HasField('optional_foreign_enum'))
- test_case.assertTrue(message.HasField('optional_import_enum'))
+ if IsProto2(message):
+ test_case.assertTrue(message.HasField('optional_import_enum'))
test_case.assertTrue(message.HasField('optional_string_piece'))
test_case.assertTrue(message.HasField('optional_cord'))
@@ -419,7 +436,8 @@ def ExpectAllFieldsSet(test_case, message):
test_case.assertEqual('115', message.optional_string)
test_case.assertEqual(b'116', message.optional_bytes)
- test_case.assertEqual(117, message.optionalgroup.a)
+ if IsProto2(message):
+ test_case.assertEqual(117, message.optionalgroup.a)
test_case.assertEqual(118, message.optional_nested_message.bb)
test_case.assertEqual(119, message.optional_foreign_message.c)
test_case.assertEqual(120, message.optional_import_message.d)
@@ -430,8 +448,9 @@ def ExpectAllFieldsSet(test_case, message):
message.optional_nested_enum)
test_case.assertEqual(unittest_pb2.FOREIGN_BAZ,
message.optional_foreign_enum)
- test_case.assertEqual(unittest_import_pb2.IMPORT_BAZ,
- message.optional_import_enum)
+ if IsProto2(message):
+ test_case.assertEqual(unittest_import_pb2.IMPORT_BAZ,
+ message.optional_import_enum)
# -----------------------------------------------------------------
@@ -451,13 +470,15 @@ def ExpectAllFieldsSet(test_case, message):
test_case.assertEqual(2, len(message.repeated_string))
test_case.assertEqual(2, len(message.repeated_bytes))
- test_case.assertEqual(2, len(message.repeatedgroup))
+ if IsProto2(message):
+ test_case.assertEqual(2, len(message.repeatedgroup))
test_case.assertEqual(2, len(message.repeated_nested_message))
test_case.assertEqual(2, len(message.repeated_foreign_message))
test_case.assertEqual(2, len(message.repeated_import_message))
test_case.assertEqual(2, len(message.repeated_nested_enum))
test_case.assertEqual(2, len(message.repeated_foreign_enum))
- test_case.assertEqual(2, len(message.repeated_import_enum))
+ if IsProto2(message):
+ test_case.assertEqual(2, len(message.repeated_import_enum))
test_case.assertEqual(2, len(message.repeated_string_piece))
test_case.assertEqual(2, len(message.repeated_cord))
@@ -478,7 +499,8 @@ def ExpectAllFieldsSet(test_case, message):
test_case.assertEqual('215', message.repeated_string[0])
test_case.assertEqual(b'216', message.repeated_bytes[0])
- test_case.assertEqual(217, message.repeatedgroup[0].a)
+ if IsProto2(message):
+ test_case.assertEqual(217, message.repeatedgroup[0].a)
test_case.assertEqual(218, message.repeated_nested_message[0].bb)
test_case.assertEqual(219, message.repeated_foreign_message[0].c)
test_case.assertEqual(220, message.repeated_import_message[0].d)
@@ -488,8 +510,9 @@ def ExpectAllFieldsSet(test_case, message):
message.repeated_nested_enum[0])
test_case.assertEqual(unittest_pb2.FOREIGN_BAR,
message.repeated_foreign_enum[0])
- test_case.assertEqual(unittest_import_pb2.IMPORT_BAR,
- message.repeated_import_enum[0])
+ if IsProto2(message):
+ test_case.assertEqual(unittest_import_pb2.IMPORT_BAR,
+ message.repeated_import_enum[0])
test_case.assertEqual(301, message.repeated_int32[1])
test_case.assertEqual(302, message.repeated_int64[1])
@@ -507,7 +530,8 @@ def ExpectAllFieldsSet(test_case, message):
test_case.assertEqual('315', message.repeated_string[1])
test_case.assertEqual(b'316', message.repeated_bytes[1])
- test_case.assertEqual(317, message.repeatedgroup[1].a)
+ if IsProto2(message):
+ test_case.assertEqual(317, message.repeatedgroup[1].a)
test_case.assertEqual(318, message.repeated_nested_message[1].bb)
test_case.assertEqual(319, message.repeated_foreign_message[1].c)
test_case.assertEqual(320, message.repeated_import_message[1].d)
@@ -517,53 +541,55 @@ def ExpectAllFieldsSet(test_case, message):
message.repeated_nested_enum[1])
test_case.assertEqual(unittest_pb2.FOREIGN_BAZ,
message.repeated_foreign_enum[1])
- test_case.assertEqual(unittest_import_pb2.IMPORT_BAZ,
- message.repeated_import_enum[1])
+ if IsProto2(message):
+ test_case.assertEqual(unittest_import_pb2.IMPORT_BAZ,
+ message.repeated_import_enum[1])
# -----------------------------------------------------------------
- test_case.assertTrue(message.HasField('default_int32'))
- test_case.assertTrue(message.HasField('default_int64'))
- test_case.assertTrue(message.HasField('default_uint32'))
- test_case.assertTrue(message.HasField('default_uint64'))
- test_case.assertTrue(message.HasField('default_sint32'))
- test_case.assertTrue(message.HasField('default_sint64'))
- test_case.assertTrue(message.HasField('default_fixed32'))
- test_case.assertTrue(message.HasField('default_fixed64'))
- test_case.assertTrue(message.HasField('default_sfixed32'))
- test_case.assertTrue(message.HasField('default_sfixed64'))
- test_case.assertTrue(message.HasField('default_float'))
- test_case.assertTrue(message.HasField('default_double'))
- test_case.assertTrue(message.HasField('default_bool'))
- test_case.assertTrue(message.HasField('default_string'))
- test_case.assertTrue(message.HasField('default_bytes'))
-
- test_case.assertTrue(message.HasField('default_nested_enum'))
- test_case.assertTrue(message.HasField('default_foreign_enum'))
- test_case.assertTrue(message.HasField('default_import_enum'))
-
- test_case.assertEqual(401, message.default_int32)
- test_case.assertEqual(402, message.default_int64)
- test_case.assertEqual(403, message.default_uint32)
- test_case.assertEqual(404, message.default_uint64)
- test_case.assertEqual(405, message.default_sint32)
- test_case.assertEqual(406, message.default_sint64)
- test_case.assertEqual(407, message.default_fixed32)
- test_case.assertEqual(408, message.default_fixed64)
- test_case.assertEqual(409, message.default_sfixed32)
- test_case.assertEqual(410, message.default_sfixed64)
- test_case.assertEqual(411, message.default_float)
- test_case.assertEqual(412, message.default_double)
- test_case.assertEqual(False, message.default_bool)
- test_case.assertEqual('415', message.default_string)
- test_case.assertEqual(b'416', message.default_bytes)
-
- test_case.assertEqual(unittest_pb2.TestAllTypes.FOO,
- message.default_nested_enum)
- test_case.assertEqual(unittest_pb2.FOREIGN_FOO,
- message.default_foreign_enum)
- test_case.assertEqual(unittest_import_pb2.IMPORT_FOO,
- message.default_import_enum)
+ if IsProto2(message):
+ test_case.assertTrue(message.HasField('default_int32'))
+ test_case.assertTrue(message.HasField('default_int64'))
+ test_case.assertTrue(message.HasField('default_uint32'))
+ test_case.assertTrue(message.HasField('default_uint64'))
+ test_case.assertTrue(message.HasField('default_sint32'))
+ test_case.assertTrue(message.HasField('default_sint64'))
+ test_case.assertTrue(message.HasField('default_fixed32'))
+ test_case.assertTrue(message.HasField('default_fixed64'))
+ test_case.assertTrue(message.HasField('default_sfixed32'))
+ test_case.assertTrue(message.HasField('default_sfixed64'))
+ test_case.assertTrue(message.HasField('default_float'))
+ test_case.assertTrue(message.HasField('default_double'))
+ test_case.assertTrue(message.HasField('default_bool'))
+ test_case.assertTrue(message.HasField('default_string'))
+ test_case.assertTrue(message.HasField('default_bytes'))
+
+ test_case.assertTrue(message.HasField('default_nested_enum'))
+ test_case.assertTrue(message.HasField('default_foreign_enum'))
+ test_case.assertTrue(message.HasField('default_import_enum'))
+
+ test_case.assertEqual(401, message.default_int32)
+ test_case.assertEqual(402, message.default_int64)
+ test_case.assertEqual(403, message.default_uint32)
+ test_case.assertEqual(404, message.default_uint64)
+ test_case.assertEqual(405, message.default_sint32)
+ test_case.assertEqual(406, message.default_sint64)
+ test_case.assertEqual(407, message.default_fixed32)
+ test_case.assertEqual(408, message.default_fixed64)
+ test_case.assertEqual(409, message.default_sfixed32)
+ test_case.assertEqual(410, message.default_sfixed64)
+ test_case.assertEqual(411, message.default_float)
+ test_case.assertEqual(412, message.default_double)
+ test_case.assertEqual(False, message.default_bool)
+ test_case.assertEqual('415', message.default_string)
+ test_case.assertEqual(b'416', message.default_bytes)
+
+ test_case.assertEqual(unittest_pb2.TestAllTypes.FOO,
+ message.default_nested_enum)
+ test_case.assertEqual(unittest_pb2.FOREIGN_FOO,
+ message.default_foreign_enum)
+ test_case.assertEqual(unittest_import_pb2.IMPORT_FOO,
+ message.default_import_enum)
def GoldenFile(filename):
@@ -578,6 +604,14 @@ def GoldenFile(filename):
return open(full_path, 'rb')
path = os.path.join(path, '..')
+ # Search internally.
+ path = '.'
+ full_path = os.path.join(path, 'third_party/py/google/protobuf/testdata',
+ filename)
+ if os.path.exists(full_path):
+ # Found it. Load the golden file from the testdata directory.
+ return open(full_path, 'rb')
+
raise RuntimeError(
'Could not find golden files. This test must be run from within the '
'protobuf source package so that it can read test data files from the '
@@ -594,7 +628,7 @@ def SetAllPackedFields(message):
"""Sets every field in the message to a unique value.
Args:
- message: A unittest_pb2.TestPackedTypes instance.
+ message: A TestPackedTypes instance.
"""
message.packed_int32.extend([601, 701])
message.packed_int64.extend([602, 702])
diff --git a/python/google/protobuf/internal/text_encoding_test.py b/python/google/protobuf/internal/text_encoding_test.py
index db0222bd3..c7d182c44 100755
--- a/python/google/protobuf/internal/text_encoding_test.py
+++ b/python/google/protobuf/internal/text_encoding_test.py
@@ -1,4 +1,4 @@
-#! /usr/bin/python
+#! /usr/bin/env python
#
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
@@ -32,7 +32,11 @@
"""Tests for google.protobuf.text_encoding."""
-from google.apputils import basetest
+try:
+ import unittest2 as unittest #PY26
+except ImportError:
+ import unittest
+
from google.protobuf import text_encoding
TEST_VALUES = [
@@ -50,19 +54,19 @@ TEST_VALUES = [
b"\010\011\012\013\014\015")]
-class TextEncodingTestCase(basetest.TestCase):
+class TextEncodingTestCase(unittest.TestCase):
def testCEscape(self):
for escaped, escaped_utf8, unescaped in TEST_VALUES:
- self.assertEquals(escaped,
+ self.assertEqual(escaped,
text_encoding.CEscape(unescaped, as_utf8=False))
- self.assertEquals(escaped_utf8,
+ self.assertEqual(escaped_utf8,
text_encoding.CEscape(unescaped, as_utf8=True))
def testCUnescape(self):
for escaped, escaped_utf8, unescaped in TEST_VALUES:
- self.assertEquals(unescaped, text_encoding.CUnescape(escaped))
- self.assertEquals(unescaped, text_encoding.CUnescape(escaped_utf8))
+ self.assertEqual(unescaped, text_encoding.CUnescape(escaped))
+ self.assertEqual(unescaped, text_encoding.CUnescape(escaped_utf8))
if __name__ == "__main__":
- basetest.main()
+ unittest.main()
diff --git a/python/google/protobuf/internal/text_format_test.py b/python/google/protobuf/internal/text_format_test.py
index b0a3a5f72..ab2bf05b8 100755
--- a/python/google/protobuf/internal/text_format_test.py
+++ b/python/google/protobuf/internal/text_format_test.py
@@ -1,4 +1,4 @@
-#! /usr/bin/python
+#! /usr/bin/env python
#
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
@@ -34,16 +34,42 @@
__author__ = 'kenton@google.com (Kenton Varda)'
+
import re
+import six
+import string
-from google.apputils import basetest
-from google.protobuf import text_format
+try:
+ import unittest2 as unittest #PY26
+except ImportError:
+ import unittest
+
+from google.protobuf.internal import _parameterized
+
+from google.protobuf import map_unittest_pb2
+from google.protobuf import unittest_mset_pb2
+from google.protobuf import unittest_pb2
+from google.protobuf import unittest_proto3_arena_pb2
from google.protobuf.internal import api_implementation
from google.protobuf.internal import test_util
-from google.protobuf import unittest_pb2
-from google.protobuf import unittest_mset_pb2
+from google.protobuf.internal import message_set_extensions_pb2
+from google.protobuf import text_format
-class TextFormatTest(basetest.TestCase):
+
+# Low-level nuts-n-bolts tests.
+class SimpleTextFormatTests(unittest.TestCase):
+
+ # The members of _QUOTES are formatted into a regexp template that
+ # expects single characters. Therefore it's an error (in addition to being
+ # non-sensical in the first place) to try to specify a "quote mark" that is
+ # more than one character.
+ def testQuoteMarksAreSingleChars(self):
+ for quote in text_format._QUOTES:
+ self.assertEqual(1, len(quote))
+
+
+# Base class with some common functionality.
+class TextFormatBase(unittest.TestCase):
def ReadGolden(self, golden_filename):
with test_util.GoldenFile(golden_filename) as f:
@@ -55,70 +81,26 @@ class TextFormatTest(basetest.TestCase):
self.assertMultiLineEqual(text, ''.join(golden_lines))
def CompareToGoldenText(self, text, golden_text):
- self.assertMultiLineEqual(text, golden_text)
+ self.assertEqual(text, golden_text)
- def testPrintAllFields(self):
- message = unittest_pb2.TestAllTypes()
- test_util.SetAllFields(message)
- self.CompareToGoldenFile(
- self.RemoveRedundantZeros(text_format.MessageToString(message)),
- 'text_format_unittest_data_oneof_implemented.txt')
-
- def testPrintInIndexOrder(self):
- message = unittest_pb2.TestFieldOrderings()
- message.my_string = '115'
- message.my_int = 101
- message.my_float = 111
- self.CompareToGoldenText(
- self.RemoveRedundantZeros(text_format.MessageToString(
- message, use_index_order=True)),
- 'my_string: \"115\"\nmy_int: 101\nmy_float: 111\n')
- self.CompareToGoldenText(
- self.RemoveRedundantZeros(text_format.MessageToString(
- message)), 'my_int: 101\nmy_string: \"115\"\nmy_float: 111\n')
-
- def testPrintAllExtensions(self):
- message = unittest_pb2.TestAllExtensions()
- test_util.SetAllExtensions(message)
- self.CompareToGoldenFile(
- self.RemoveRedundantZeros(text_format.MessageToString(message)),
- 'text_format_unittest_extensions_data.txt')
+ def RemoveRedundantZeros(self, text):
+ # Some platforms print 1e+5 as 1e+005. This is fine, but we need to remove
+ # these zeros in order to match the golden file.
+ text = text.replace('e+0','e+').replace('e+0','e+') \
+ .replace('e-0','e-').replace('e-0','e-')
+ # Floating point fields are printed with .0 suffix even if they are
+ # actualy integer numbers.
+ text = re.compile('\.0$', re.MULTILINE).sub('', text)
+ return text
- def testPrintAllFieldsPointy(self):
- message = unittest_pb2.TestAllTypes()
- test_util.SetAllFields(message)
- self.CompareToGoldenFile(
- self.RemoveRedundantZeros(
- text_format.MessageToString(message, pointy_brackets=True)),
- 'text_format_unittest_data_pointy_oneof.txt')
- def testPrintAllExtensionsPointy(self):
- message = unittest_pb2.TestAllExtensions()
- test_util.SetAllExtensions(message)
- self.CompareToGoldenFile(
- self.RemoveRedundantZeros(text_format.MessageToString(
- message, pointy_brackets=True)),
- 'text_format_unittest_extensions_data_pointy.txt')
+@_parameterized.Parameters(
+ (unittest_pb2),
+ (unittest_proto3_arena_pb2))
+class TextFormatTest(TextFormatBase):
- def testPrintMessageSet(self):
- message = unittest_mset_pb2.TestMessageSetContainer()
- ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
- ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension
- message.message_set.Extensions[ext1].i = 23
- message.message_set.Extensions[ext2].str = 'foo'
- self.CompareToGoldenText(
- text_format.MessageToString(message),
- 'message_set {\n'
- ' [protobuf_unittest.TestMessageSetExtension1] {\n'
- ' i: 23\n'
- ' }\n'
- ' [protobuf_unittest.TestMessageSetExtension2] {\n'
- ' str: \"foo\"\n'
- ' }\n'
- '}\n')
-
- def testPrintExotic(self):
- message = unittest_pb2.TestAllTypes()
+ def testPrintExotic(self, message_module):
+ message = message_module.TestAllTypes()
message.repeated_int64.append(-9223372036854775808)
message.repeated_uint64.append(18446744073709551615)
message.repeated_double.append(123.456)
@@ -137,61 +119,44 @@ class TextFormatTest(basetest.TestCase):
' "\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\""\n'
'repeated_string: "\\303\\274\\352\\234\\237"\n')
- def testPrintExoticUnicodeSubclass(self):
- class UnicodeSub(unicode):
+ def testPrintExoticUnicodeSubclass(self, message_module):
+ class UnicodeSub(six.text_type):
pass
- message = unittest_pb2.TestAllTypes()
+ message = message_module.TestAllTypes()
message.repeated_string.append(UnicodeSub(u'\u00fc\ua71f'))
self.CompareToGoldenText(
text_format.MessageToString(message),
'repeated_string: "\\303\\274\\352\\234\\237"\n')
- def testPrintNestedMessageAsOneLine(self):
- message = unittest_pb2.TestAllTypes()
+ def testPrintNestedMessageAsOneLine(self, message_module):
+ message = message_module.TestAllTypes()
msg = message.repeated_nested_message.add()
msg.bb = 42
self.CompareToGoldenText(
text_format.MessageToString(message, as_one_line=True),
'repeated_nested_message { bb: 42 }')
- def testPrintRepeatedFieldsAsOneLine(self):
- message = unittest_pb2.TestAllTypes()
+ def testPrintRepeatedFieldsAsOneLine(self, message_module):
+ message = message_module.TestAllTypes()
message.repeated_int32.append(1)
message.repeated_int32.append(1)
message.repeated_int32.append(3)
- message.repeated_string.append("Google")
- message.repeated_string.append("Zurich")
+ message.repeated_string.append('Google')
+ message.repeated_string.append('Zurich')
self.CompareToGoldenText(
text_format.MessageToString(message, as_one_line=True),
'repeated_int32: 1 repeated_int32: 1 repeated_int32: 3 '
'repeated_string: "Google" repeated_string: "Zurich"')
- def testPrintNestedNewLineInStringAsOneLine(self):
- message = unittest_pb2.TestAllTypes()
- message.optional_string = "a\nnew\nline"
+ def testPrintNestedNewLineInStringAsOneLine(self, message_module):
+ message = message_module.TestAllTypes()
+ message.optional_string = 'a\nnew\nline'
self.CompareToGoldenText(
text_format.MessageToString(message, as_one_line=True),
'optional_string: "a\\nnew\\nline"')
- def testPrintMessageSetAsOneLine(self):
- message = unittest_mset_pb2.TestMessageSetContainer()
- ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
- ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension
- message.message_set.Extensions[ext1].i = 23
- message.message_set.Extensions[ext2].str = 'foo'
- self.CompareToGoldenText(
- text_format.MessageToString(message, as_one_line=True),
- 'message_set {'
- ' [protobuf_unittest.TestMessageSetExtension1] {'
- ' i: 23'
- ' }'
- ' [protobuf_unittest.TestMessageSetExtension2] {'
- ' str: \"foo\"'
- ' }'
- ' }')
-
- def testPrintExoticAsOneLine(self):
- message = unittest_pb2.TestAllTypes()
+ def testPrintExoticAsOneLine(self, message_module):
+ message = message_module.TestAllTypes()
message.repeated_int64.append(-9223372036854775808)
message.repeated_uint64.append(18446744073709551615)
message.repeated_double.append(123.456)
@@ -211,8 +176,8 @@ class TextFormatTest(basetest.TestCase):
'"\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\""'
' repeated_string: "\\303\\274\\352\\234\\237"')
- def testRoundTripExoticAsOneLine(self):
- message = unittest_pb2.TestAllTypes()
+ def testRoundTripExoticAsOneLine(self, message_module):
+ message = message_module.TestAllTypes()
message.repeated_int64.append(-9223372036854775808)
message.repeated_uint64.append(18446744073709551615)
message.repeated_double.append(123.456)
@@ -224,33 +189,33 @@ class TextFormatTest(basetest.TestCase):
# Test as_utf8 = False.
wire_text = text_format.MessageToString(
message, as_one_line=True, as_utf8=False)
- parsed_message = unittest_pb2.TestAllTypes()
+ parsed_message = message_module.TestAllTypes()
r = text_format.Parse(wire_text, parsed_message)
self.assertIs(r, parsed_message)
- self.assertEquals(message, parsed_message)
+ self.assertEqual(message, parsed_message)
# Test as_utf8 = True.
wire_text = text_format.MessageToString(
message, as_one_line=True, as_utf8=True)
- parsed_message = unittest_pb2.TestAllTypes()
+ parsed_message = message_module.TestAllTypes()
r = text_format.Parse(wire_text, parsed_message)
self.assertIs(r, parsed_message)
- self.assertEquals(message, parsed_message,
+ self.assertEqual(message, parsed_message,
'\n%s != %s' % (message, parsed_message))
- def testPrintRawUtf8String(self):
- message = unittest_pb2.TestAllTypes()
+ def testPrintRawUtf8String(self, message_module):
+ message = message_module.TestAllTypes()
message.repeated_string.append(u'\u00fc\ua71f')
text = text_format.MessageToString(message, as_utf8=True)
self.CompareToGoldenText(text, 'repeated_string: "\303\274\352\234\237"\n')
- parsed_message = unittest_pb2.TestAllTypes()
+ parsed_message = message_module.TestAllTypes()
text_format.Parse(text, parsed_message)
- self.assertEquals(message, parsed_message,
+ self.assertEqual(message, parsed_message,
'\n%s != %s' % (message, parsed_message))
- def testPrintFloatFormat(self):
+ def testPrintFloatFormat(self, message_module):
# Check that float_format argument is passed to sub-message formatting.
- message = unittest_pb2.NestedTestAllTypes()
+ message = message_module.NestedTestAllTypes()
# We use 1.25 as it is a round number in binary. The proto 32-bit float
# will not gain additional imprecise digits as a 64-bit Python float and
# show up in its str. 32-bit 1.2 is noisy when extended to 64-bit:
@@ -272,93 +237,62 @@ class TextFormatTest(basetest.TestCase):
text_message = text_format.MessageToString(message, float_format='.15g')
self.CompareToGoldenText(
self.RemoveRedundantZeros(text_message),
- 'payload {{\n {}\n {}\n {}\n {}\n}}\n'.format(*formatted_fields))
+ 'payload {{\n {0}\n {1}\n {2}\n {3}\n}}\n'.format(*formatted_fields))
# as_one_line=True is a separate code branch where float_format is passed.
text_message = text_format.MessageToString(message, as_one_line=True,
float_format='.15g')
self.CompareToGoldenText(
self.RemoveRedundantZeros(text_message),
- 'payload {{ {} {} {} {} }}'.format(*formatted_fields))
+ 'payload {{ {0} {1} {2} {3} }}'.format(*formatted_fields))
- def testMessageToString(self):
- message = unittest_pb2.ForeignMessage()
+ def testMessageToString(self, message_module):
+ message = message_module.ForeignMessage()
message.c = 123
self.assertEqual('c: 123\n', str(message))
- def RemoveRedundantZeros(self, text):
- # Some platforms print 1e+5 as 1e+005. This is fine, but we need to remove
- # these zeros in order to match the golden file.
- text = text.replace('e+0','e+').replace('e+0','e+') \
- .replace('e-0','e-').replace('e-0','e-')
- # Floating point fields are printed with .0 suffix even if they are
- # actualy integer numbers.
- text = re.compile('\.0$', re.MULTILINE).sub('', text)
- return text
-
- def testParseGolden(self):
- golden_text = '\n'.join(self.ReadGolden('text_format_unittest_data.txt'))
- parsed_message = unittest_pb2.TestAllTypes()
- r = text_format.Parse(golden_text, parsed_message)
- self.assertIs(r, parsed_message)
-
- message = unittest_pb2.TestAllTypes()
- test_util.SetAllFields(message)
- self.assertEquals(message, parsed_message)
-
- def testParseGoldenExtensions(self):
- golden_text = '\n'.join(self.ReadGolden(
- 'text_format_unittest_extensions_data.txt'))
- parsed_message = unittest_pb2.TestAllExtensions()
- text_format.Parse(golden_text, parsed_message)
-
- message = unittest_pb2.TestAllExtensions()
- test_util.SetAllExtensions(message)
- self.assertEquals(message, parsed_message)
-
- def testParseAllFields(self):
- message = unittest_pb2.TestAllTypes()
+ def testPrintField(self, message_module):
+ message = message_module.TestAllTypes()
+ field = message.DESCRIPTOR.fields_by_name['optional_float']
+ value = message.optional_float
+ out = text_format.TextWriter(False)
+ text_format.PrintField(field, value, out)
+ self.assertEqual('optional_float: 0.0\n', out.getvalue())
+ out.close()
+ # Test Printer
+ out = text_format.TextWriter(False)
+ printer = text_format._Printer(out)
+ printer.PrintField(field, value)
+ self.assertEqual('optional_float: 0.0\n', out.getvalue())
+ out.close()
+
+ def testPrintFieldValue(self, message_module):
+ message = message_module.TestAllTypes()
+ field = message.DESCRIPTOR.fields_by_name['optional_float']
+ value = message.optional_float
+ out = text_format.TextWriter(False)
+ text_format.PrintFieldValue(field, value, out)
+ self.assertEqual('0.0', out.getvalue())
+ out.close()
+ # Test Printer
+ out = text_format.TextWriter(False)
+ printer = text_format._Printer(out)
+ printer.PrintFieldValue(field, value)
+ self.assertEqual('0.0', out.getvalue())
+ out.close()
+
+ def testParseAllFields(self, message_module):
+ message = message_module.TestAllTypes()
test_util.SetAllFields(message)
ascii_text = text_format.MessageToString(message)
- parsed_message = unittest_pb2.TestAllTypes()
+ parsed_message = message_module.TestAllTypes()
text_format.Parse(ascii_text, parsed_message)
self.assertEqual(message, parsed_message)
- test_util.ExpectAllFieldsSet(self, message)
+ if message_module is unittest_pb2:
+ test_util.ExpectAllFieldsSet(self, message)
- def testParseAllExtensions(self):
- message = unittest_pb2.TestAllExtensions()
- test_util.SetAllExtensions(message)
- ascii_text = text_format.MessageToString(message)
-
- parsed_message = unittest_pb2.TestAllExtensions()
- text_format.Parse(ascii_text, parsed_message)
- self.assertEqual(message, parsed_message)
-
- def testParseMessageSet(self):
- message = unittest_pb2.TestAllTypes()
- text = ('repeated_uint64: 1\n'
- 'repeated_uint64: 2\n')
- text_format.Parse(text, message)
- self.assertEqual(1, message.repeated_uint64[0])
- self.assertEqual(2, message.repeated_uint64[1])
-
- message = unittest_mset_pb2.TestMessageSetContainer()
- text = ('message_set {\n'
- ' [protobuf_unittest.TestMessageSetExtension1] {\n'
- ' i: 23\n'
- ' }\n'
- ' [protobuf_unittest.TestMessageSetExtension2] {\n'
- ' str: \"foo\"\n'
- ' }\n'
- '}\n')
- text_format.Parse(text, message)
- ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
- ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension
- self.assertEquals(23, message.message_set.Extensions[ext1].i)
- self.assertEquals('foo', message.message_set.Extensions[ext2].str)
-
- def testParseExotic(self):
- message = unittest_pb2.TestAllTypes()
+ def testParseExotic(self, message_module):
+ message = message_module.TestAllTypes()
text = ('repeated_int64: -9223372036854775808\n'
'repeated_uint64: 18446744073709551615\n'
'repeated_double: 123.456\n'
@@ -383,8 +317,8 @@ class TextFormatTest(basetest.TestCase):
self.assertEqual(u'\u00fc\ua71f', message.repeated_string[2])
self.assertEqual(u'\u00fc', message.repeated_string[3])
- def testParseTrailingCommas(self):
- message = unittest_pb2.TestAllTypes()
+ def testParseTrailingCommas(self, message_module):
+ message = message_module.TestAllTypes()
text = ('repeated_int64: 100;\n'
'repeated_int64: 200;\n'
'repeated_int64: 300,\n'
@@ -398,101 +332,87 @@ class TextFormatTest(basetest.TestCase):
self.assertEqual(u'one', message.repeated_string[0])
self.assertEqual(u'two', message.repeated_string[1])
- def testParseEmptyText(self):
- message = unittest_pb2.TestAllTypes()
+ def testParseRepeatedScalarShortFormat(self, message_module):
+ message = message_module.TestAllTypes()
+ text = ('repeated_int64: [100, 200];\n'
+ 'repeated_int64: 300,\n'
+ 'repeated_string: ["one", "two"];\n')
+ text_format.Parse(text, message)
+
+ self.assertEqual(100, message.repeated_int64[0])
+ self.assertEqual(200, message.repeated_int64[1])
+ self.assertEqual(300, message.repeated_int64[2])
+ self.assertEqual(u'one', message.repeated_string[0])
+ self.assertEqual(u'two', message.repeated_string[1])
+
+ def testParseRepeatedMessageShortFormat(self, message_module):
+ message = message_module.TestAllTypes()
+ text = ('repeated_nested_message: [{bb: 100}, {bb: 200}],\n'
+ 'repeated_nested_message: {bb: 300}\n'
+ 'repeated_nested_message [{bb: 400}];\n')
+ text_format.Parse(text, message)
+
+ self.assertEqual(100, message.repeated_nested_message[0].bb)
+ self.assertEqual(200, message.repeated_nested_message[1].bb)
+ self.assertEqual(300, message.repeated_nested_message[2].bb)
+ self.assertEqual(400, message.repeated_nested_message[3].bb)
+
+ def testParseEmptyText(self, message_module):
+ message = message_module.TestAllTypes()
text = ''
text_format.Parse(text, message)
- self.assertEquals(unittest_pb2.TestAllTypes(), message)
+ self.assertEqual(message_module.TestAllTypes(), message)
- def testParseInvalidUtf8(self):
- message = unittest_pb2.TestAllTypes()
+ def testParseInvalidUtf8(self, message_module):
+ message = message_module.TestAllTypes()
text = 'repeated_string: "\\xc3\\xc3"'
self.assertRaises(text_format.ParseError, text_format.Parse, text, message)
- def testParseSingleWord(self):
- message = unittest_pb2.TestAllTypes()
+ def testParseSingleWord(self, message_module):
+ message = message_module.TestAllTypes()
text = 'foo'
- self.assertRaisesWithLiteralMatch(
+ six.assertRaisesRegex(self,
text_format.ParseError,
- ('1:1 : Message type "protobuf_unittest.TestAllTypes" has no field named '
- '"foo".'),
+ (r'1:1 : Message type "\w+.TestAllTypes" has no field named '
+ r'"foo".'),
text_format.Parse, text, message)
- def testParseUnknownField(self):
- message = unittest_pb2.TestAllTypes()
+ def testParseUnknownField(self, message_module):
+ message = message_module.TestAllTypes()
text = 'unknown_field: 8\n'
- self.assertRaisesWithLiteralMatch(
+ six.assertRaisesRegex(self,
text_format.ParseError,
- ('1:1 : Message type "protobuf_unittest.TestAllTypes" has no field named '
- '"unknown_field".'),
+ (r'1:1 : Message type "\w+.TestAllTypes" has no field named '
+ r'"unknown_field".'),
text_format.Parse, text, message)
- def testParseBadExtension(self):
- message = unittest_pb2.TestAllExtensions()
- text = '[unknown_extension]: 8\n'
- self.assertRaisesWithLiteralMatch(
- text_format.ParseError,
- '1:2 : Extension "unknown_extension" not registered.',
- text_format.Parse, text, message)
- message = unittest_pb2.TestAllTypes()
- self.assertRaisesWithLiteralMatch(
- text_format.ParseError,
- ('1:2 : Message type "protobuf_unittest.TestAllTypes" does not have '
- 'extensions.'),
- text_format.Parse, text, message)
-
- def testParseGroupNotClosed(self):
- message = unittest_pb2.TestAllTypes()
- text = 'RepeatedGroup: <'
- self.assertRaisesWithLiteralMatch(
- text_format.ParseError, '1:16 : Expected ">".',
- text_format.Parse, text, message)
-
- text = 'RepeatedGroup: {'
- self.assertRaisesWithLiteralMatch(
- text_format.ParseError, '1:16 : Expected "}".',
- text_format.Parse, text, message)
-
- def testParseEmptyGroup(self):
- message = unittest_pb2.TestAllTypes()
- text = 'OptionalGroup: {}'
- text_format.Parse(text, message)
- self.assertTrue(message.HasField('optionalgroup'))
-
- message.Clear()
-
- message = unittest_pb2.TestAllTypes()
- text = 'OptionalGroup: <>'
- text_format.Parse(text, message)
- self.assertTrue(message.HasField('optionalgroup'))
-
- def testParseBadEnumValue(self):
- message = unittest_pb2.TestAllTypes()
+ def testParseBadEnumValue(self, message_module):
+ message = message_module.TestAllTypes()
text = 'optional_nested_enum: BARR'
- self.assertRaisesWithLiteralMatch(
+ six.assertRaisesRegex(self,
text_format.ParseError,
- ('1:23 : Enum type "protobuf_unittest.TestAllTypes.NestedEnum" '
- 'has no value named BARR.'),
+ (r'1:23 : Enum type "\w+.TestAllTypes.NestedEnum" '
+ r'has no value named BARR.'),
text_format.Parse, text, message)
- message = unittest_pb2.TestAllTypes()
+ message = message_module.TestAllTypes()
text = 'optional_nested_enum: 100'
- self.assertRaisesWithLiteralMatch(
+ six.assertRaisesRegex(self,
text_format.ParseError,
- ('1:23 : Enum type "protobuf_unittest.TestAllTypes.NestedEnum" '
- 'has no value with number 100.'),
+ (r'1:23 : Enum type "\w+.TestAllTypes.NestedEnum" '
+ r'has no value with number 100.'),
text_format.Parse, text, message)
- def testParseBadIntValue(self):
- message = unittest_pb2.TestAllTypes()
+ def testParseBadIntValue(self, message_module):
+ message = message_module.TestAllTypes()
text = 'optional_int32: bork'
- self.assertRaisesWithLiteralMatch(
+ six.assertRaisesRegex(self,
text_format.ParseError,
('1:17 : Couldn\'t parse integer: bork'),
text_format.Parse, text, message)
- def testParseStringFieldUnescape(self):
- message = unittest_pb2.TestAllTypes()
+ def testParseStringFieldUnescape(self, message_module):
+ message = message_module.TestAllTypes()
text = r'''repeated_string: "\xf\x62"
repeated_string: "\\xf\\x62"
repeated_string: "\\\xf\\\x62"
@@ -511,43 +431,483 @@ class TextFormatTest(basetest.TestCase):
message.repeated_string[4])
self.assertEqual(SLASH + 'x20', message.repeated_string[5])
- def testMergeRepeatedScalars(self):
- message = unittest_pb2.TestAllTypes()
+ def testMergeDuplicateScalars(self, message_module):
+ message = message_module.TestAllTypes()
text = ('optional_int32: 42 '
'optional_int32: 67')
r = text_format.Merge(text, message)
self.assertIs(r, message)
self.assertEqual(67, message.optional_int32)
- def testParseRepeatedScalars(self):
- message = unittest_pb2.TestAllTypes()
- text = ('optional_int32: 42 '
- 'optional_int32: 67')
- self.assertRaisesWithLiteralMatch(
- text_format.ParseError,
- ('1:36 : Message type "protobuf_unittest.TestAllTypes" should not '
- 'have multiple "optional_int32" fields.'),
- text_format.Parse, text, message)
-
- def testMergeRepeatedNestedMessageScalars(self):
- message = unittest_pb2.TestAllTypes()
+ def testMergeDuplicateNestedMessageScalars(self, message_module):
+ message = message_module.TestAllTypes()
text = ('optional_nested_message { bb: 1 } '
'optional_nested_message { bb: 2 }')
r = text_format.Merge(text, message)
self.assertTrue(r is message)
self.assertEqual(2, message.optional_nested_message.bb)
- def testParseRepeatedNestedMessageScalars(self):
+ def testParseOneof(self, message_module):
+ m = message_module.TestAllTypes()
+ m.oneof_uint32 = 11
+ m2 = message_module.TestAllTypes()
+ text_format.Parse(text_format.MessageToString(m), m2)
+ self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field'))
+
+ def testParseMultipleOneof(self, message_module):
+ m_string = '\n'.join([
+ 'oneof_uint32: 11',
+ 'oneof_string: "foo"'])
+ m2 = message_module.TestAllTypes()
+ if message_module is unittest_pb2:
+ with self.assertRaisesRegexp(
+ text_format.ParseError, ' is specified along with field '):
+ text_format.Parse(m_string, m2)
+ else:
+ text_format.Parse(m_string, m2)
+ self.assertEqual('oneof_string', m2.WhichOneof('oneof_field'))
+
+
+# These are tests that aren't fundamentally specific to proto2, but are at
+# the moment because of differences between the proto2 and proto3 test schemas.
+# Ideally the schemas would be made more similar so these tests could pass.
+class OnlyWorksWithProto2RightNowTests(TextFormatBase):
+
+ def testPrintAllFieldsPointy(self):
message = unittest_pb2.TestAllTypes()
- text = ('optional_nested_message { bb: 1 } '
- 'optional_nested_message { bb: 2 }')
- self.assertRaisesWithLiteralMatch(
+ test_util.SetAllFields(message)
+ self.CompareToGoldenFile(
+ self.RemoveRedundantZeros(
+ text_format.MessageToString(message, pointy_brackets=True)),
+ 'text_format_unittest_data_pointy_oneof.txt')
+
+ def testParseGolden(self):
+ golden_text = '\n'.join(self.ReadGolden(
+ 'text_format_unittest_data_oneof_implemented.txt'))
+ parsed_message = unittest_pb2.TestAllTypes()
+ r = text_format.Parse(golden_text, parsed_message)
+ self.assertIs(r, parsed_message)
+
+ message = unittest_pb2.TestAllTypes()
+ test_util.SetAllFields(message)
+ self.assertEqual(message, parsed_message)
+
+ def testPrintAllFields(self):
+ message = unittest_pb2.TestAllTypes()
+ test_util.SetAllFields(message)
+ self.CompareToGoldenFile(
+ self.RemoveRedundantZeros(text_format.MessageToString(message)),
+ 'text_format_unittest_data_oneof_implemented.txt')
+
+ def testPrintAllFieldsPointy(self):
+ message = unittest_pb2.TestAllTypes()
+ test_util.SetAllFields(message)
+ self.CompareToGoldenFile(
+ self.RemoveRedundantZeros(
+ text_format.MessageToString(message, pointy_brackets=True)),
+ 'text_format_unittest_data_pointy_oneof.txt')
+
+ def testPrintInIndexOrder(self):
+ message = unittest_pb2.TestFieldOrderings()
+ message.my_string = '115'
+ message.my_int = 101
+ message.my_float = 111
+ message.optional_nested_message.oo = 0
+ message.optional_nested_message.bb = 1
+ self.CompareToGoldenText(
+ self.RemoveRedundantZeros(text_format.MessageToString(
+ message, use_index_order=True)),
+ 'my_string: \"115\"\nmy_int: 101\nmy_float: 111\n'
+ 'optional_nested_message {\n oo: 0\n bb: 1\n}\n')
+ self.CompareToGoldenText(
+ self.RemoveRedundantZeros(text_format.MessageToString(
+ message)),
+ 'my_int: 101\nmy_string: \"115\"\nmy_float: 111\n'
+ 'optional_nested_message {\n bb: 1\n oo: 0\n}\n')
+
+ def testMergeLinesGolden(self):
+ opened = self.ReadGolden('text_format_unittest_data_oneof_implemented.txt')
+ parsed_message = unittest_pb2.TestAllTypes()
+ r = text_format.MergeLines(opened, parsed_message)
+ self.assertIs(r, parsed_message)
+
+ message = unittest_pb2.TestAllTypes()
+ test_util.SetAllFields(message)
+ self.assertEqual(message, parsed_message)
+
+ def testParseLinesGolden(self):
+ opened = self.ReadGolden('text_format_unittest_data_oneof_implemented.txt')
+ parsed_message = unittest_pb2.TestAllTypes()
+ r = text_format.ParseLines(opened, parsed_message)
+ self.assertIs(r, parsed_message)
+
+ message = unittest_pb2.TestAllTypes()
+ test_util.SetAllFields(message)
+ self.assertEqual(message, parsed_message)
+
+ def testPrintMap(self):
+ message = map_unittest_pb2.TestMap()
+
+ message.map_int32_int32[-123] = -456
+ message.map_int64_int64[-2**33] = -2**34
+ message.map_uint32_uint32[123] = 456
+ message.map_uint64_uint64[2**33] = 2**34
+ message.map_string_string["abc"] = "123"
+ message.map_int32_foreign_message[111].c = 5
+
+ # Maps are serialized to text format using their underlying repeated
+ # representation.
+ self.CompareToGoldenText(
+ text_format.MessageToString(message),
+ 'map_int32_int32 {\n'
+ ' key: -123\n'
+ ' value: -456\n'
+ '}\n'
+ 'map_int64_int64 {\n'
+ ' key: -8589934592\n'
+ ' value: -17179869184\n'
+ '}\n'
+ 'map_uint32_uint32 {\n'
+ ' key: 123\n'
+ ' value: 456\n'
+ '}\n'
+ 'map_uint64_uint64 {\n'
+ ' key: 8589934592\n'
+ ' value: 17179869184\n'
+ '}\n'
+ 'map_string_string {\n'
+ ' key: "abc"\n'
+ ' value: "123"\n'
+ '}\n'
+ 'map_int32_foreign_message {\n'
+ ' key: 111\n'
+ ' value {\n'
+ ' c: 5\n'
+ ' }\n'
+ '}\n')
+
+ def testMapOrderEnforcement(self):
+ message = map_unittest_pb2.TestMap()
+ for letter in string.ascii_uppercase[13:26]:
+ message.map_string_string[letter] = 'dummy'
+ for letter in reversed(string.ascii_uppercase[0:13]):
+ message.map_string_string[letter] = 'dummy'
+ golden = ''.join((
+ 'map_string_string {\n key: "%c"\n value: "dummy"\n}\n' % (letter,)
+ for letter in string.ascii_uppercase))
+ self.CompareToGoldenText(text_format.MessageToString(message), golden)
+
+ def testMapOrderSemantics(self):
+ golden_lines = self.ReadGolden('map_test_data.txt')
+ # The C++ implementation emits defaulted-value fields, while the Python
+ # implementation does not. Adjusting for this is awkward, but it is
+ # valuable to test against a common golden file.
+ line_blacklist = (' key: 0\n',
+ ' value: 0\n',
+ ' key: false\n',
+ ' value: false\n')
+ golden_lines = [line for line in golden_lines if line not in line_blacklist]
+
+ message = map_unittest_pb2.TestMap()
+ text_format.ParseLines(golden_lines, message)
+ candidate = text_format.MessageToString(message)
+ # The Python implementation emits "1.0" for the double value that the C++
+ # implementation emits as "1".
+ candidate = candidate.replace('1.0', '1', 2)
+ self.assertMultiLineEqual(candidate, ''.join(golden_lines))
+
+
+# Tests of proto2-only features (MessageSet, extensions, etc.).
+class Proto2Tests(TextFormatBase):
+
+ def testPrintMessageSet(self):
+ message = unittest_mset_pb2.TestMessageSetContainer()
+ ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
+ ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension
+ message.message_set.Extensions[ext1].i = 23
+ message.message_set.Extensions[ext2].str = 'foo'
+ self.CompareToGoldenText(
+ text_format.MessageToString(message),
+ 'message_set {\n'
+ ' [protobuf_unittest.TestMessageSetExtension1] {\n'
+ ' i: 23\n'
+ ' }\n'
+ ' [protobuf_unittest.TestMessageSetExtension2] {\n'
+ ' str: \"foo\"\n'
+ ' }\n'
+ '}\n')
+
+ message = message_set_extensions_pb2.TestMessageSet()
+ ext = message_set_extensions_pb2.message_set_extension3
+ message.Extensions[ext].text = 'bar'
+ self.CompareToGoldenText(
+ text_format.MessageToString(message),
+ '[google.protobuf.internal.TestMessageSetExtension3] {\n'
+ ' text: \"bar\"\n'
+ '}\n')
+
+ def testPrintMessageSetByFieldNumber(self):
+ out = text_format.TextWriter(False)
+ message = unittest_mset_pb2.TestMessageSetContainer()
+ ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
+ ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension
+ message.message_set.Extensions[ext1].i = 23
+ message.message_set.Extensions[ext2].str = 'foo'
+ text_format.PrintMessage(message, out, use_field_number=True)
+ self.CompareToGoldenText(
+ out.getvalue(),
+ '1 {\n'
+ ' 1545008 {\n'
+ ' 15: 23\n'
+ ' }\n'
+ ' 1547769 {\n'
+ ' 25: \"foo\"\n'
+ ' }\n'
+ '}\n')
+ out.close()
+
+ def testPrintMessageSetAsOneLine(self):
+ message = unittest_mset_pb2.TestMessageSetContainer()
+ ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
+ ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension
+ message.message_set.Extensions[ext1].i = 23
+ message.message_set.Extensions[ext2].str = 'foo'
+ self.CompareToGoldenText(
+ text_format.MessageToString(message, as_one_line=True),
+ 'message_set {'
+ ' [protobuf_unittest.TestMessageSetExtension1] {'
+ ' i: 23'
+ ' }'
+ ' [protobuf_unittest.TestMessageSetExtension2] {'
+ ' str: \"foo\"'
+ ' }'
+ ' }')
+
+ def testParseMessageSet(self):
+ message = unittest_pb2.TestAllTypes()
+ text = ('repeated_uint64: 1\n'
+ 'repeated_uint64: 2\n')
+ text_format.Parse(text, message)
+ self.assertEqual(1, message.repeated_uint64[0])
+ self.assertEqual(2, message.repeated_uint64[1])
+
+ message = unittest_mset_pb2.TestMessageSetContainer()
+ text = ('message_set {\n'
+ ' [protobuf_unittest.TestMessageSetExtension1] {\n'
+ ' i: 23\n'
+ ' }\n'
+ ' [protobuf_unittest.TestMessageSetExtension2] {\n'
+ ' str: \"foo\"\n'
+ ' }\n'
+ '}\n')
+ text_format.Parse(text, message)
+ ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
+ ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension
+ self.assertEqual(23, message.message_set.Extensions[ext1].i)
+ self.assertEqual('foo', message.message_set.Extensions[ext2].str)
+
+ def testParseMessageByFieldNumber(self):
+ message = unittest_pb2.TestAllTypes()
+ text = ('34: 1\n'
+ 'repeated_uint64: 2\n')
+ text_format.Parse(text, message, allow_field_number=True)
+ self.assertEqual(1, message.repeated_uint64[0])
+ self.assertEqual(2, message.repeated_uint64[1])
+
+ message = unittest_mset_pb2.TestMessageSetContainer()
+ text = ('1 {\n'
+ ' 1545008 {\n'
+ ' 15: 23\n'
+ ' }\n'
+ ' 1547769 {\n'
+ ' 25: \"foo\"\n'
+ ' }\n'
+ '}\n')
+ text_format.Parse(text, message, allow_field_number=True)
+ ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
+ ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension
+ self.assertEqual(23, message.message_set.Extensions[ext1].i)
+ self.assertEqual('foo', message.message_set.Extensions[ext2].str)
+
+ # Can't parse field number without set allow_field_number=True.
+ message = unittest_pb2.TestAllTypes()
+ text = '34:1\n'
+ six.assertRaisesRegex(
+ self,
text_format.ParseError,
- ('1:65 : Message type "protobuf_unittest.TestAllTypes.NestedMessage" '
- 'should not have multiple "bb" fields.'),
+ (r'1:1 : Message type "\w+.TestAllTypes" has no field named '
+ r'"34".'),
text_format.Parse, text, message)
- def testMergeRepeatedExtensionScalars(self):
+ # Can't parse if field number is not found.
+ text = '1234:1\n'
+ six.assertRaisesRegex(
+ self,
+ text_format.ParseError,
+ (r'1:1 : Message type "\w+.TestAllTypes" has no field named '
+ r'"1234".'),
+ text_format.Parse, text, message, allow_field_number=True)
+
+ def testPrintAllExtensions(self):
+ message = unittest_pb2.TestAllExtensions()
+ test_util.SetAllExtensions(message)
+ self.CompareToGoldenFile(
+ self.RemoveRedundantZeros(text_format.MessageToString(message)),
+ 'text_format_unittest_extensions_data.txt')
+
+ def testPrintAllExtensionsPointy(self):
+ message = unittest_pb2.TestAllExtensions()
+ test_util.SetAllExtensions(message)
+ self.CompareToGoldenFile(
+ self.RemoveRedundantZeros(text_format.MessageToString(
+ message, pointy_brackets=True)),
+ 'text_format_unittest_extensions_data_pointy.txt')
+
+ def testParseGoldenExtensions(self):
+ golden_text = '\n'.join(self.ReadGolden(
+ 'text_format_unittest_extensions_data.txt'))
+ parsed_message = unittest_pb2.TestAllExtensions()
+ text_format.Parse(golden_text, parsed_message)
+
+ message = unittest_pb2.TestAllExtensions()
+ test_util.SetAllExtensions(message)
+ self.assertEqual(message, parsed_message)
+
+ def testParseAllExtensions(self):
+ message = unittest_pb2.TestAllExtensions()
+ test_util.SetAllExtensions(message)
+ ascii_text = text_format.MessageToString(message)
+
+ parsed_message = unittest_pb2.TestAllExtensions()
+ text_format.Parse(ascii_text, parsed_message)
+ self.assertEqual(message, parsed_message)
+
+ def testParseAllowedUnknownExtension(self):
+ # Skip over unknown extension correctly.
+ message = unittest_mset_pb2.TestMessageSetContainer()
+ text = ('message_set {\n'
+ ' [unknown_extension] {\n'
+ ' i: 23\n'
+ ' bin: "\xe0"'
+ ' [nested_unknown_ext]: {\n'
+ ' i: 23\n'
+ ' test: "test_string"\n'
+ ' floaty_float: -0.315\n'
+ ' num: -inf\n'
+ ' multiline_str: "abc"\n'
+ ' "def"\n'
+ ' "xyz."\n'
+ ' [nested_unknown_ext]: <\n'
+ ' i: 23\n'
+ ' i: 24\n'
+ ' pointfloat: .3\n'
+ ' test: "test_string"\n'
+ ' floaty_float: -0.315\n'
+ ' num: -inf\n'
+ ' long_string: "test" "test2" \n'
+ ' >\n'
+ ' }\n'
+ ' }\n'
+ ' [unknown_extension]: 5\n'
+ '}\n')
+ text_format.Parse(text, message, allow_unknown_extension=True)
+ golden = 'message_set {\n}\n'
+ self.CompareToGoldenText(text_format.MessageToString(message), golden)
+
+ # Catch parse errors in unknown extension.
+ message = unittest_mset_pb2.TestMessageSetContainer()
+ malformed = ('message_set {\n'
+ ' [unknown_extension] {\n'
+ ' i:\n' # Missing value.
+ ' }\n'
+ '}\n')
+ six.assertRaisesRegex(self,
+ text_format.ParseError,
+ 'Invalid field value: }',
+ text_format.Parse, malformed, message,
+ allow_unknown_extension=True)
+
+ message = unittest_mset_pb2.TestMessageSetContainer()
+ malformed = ('message_set {\n'
+ ' [unknown_extension] {\n'
+ ' str: "malformed string\n' # Missing closing quote.
+ ' }\n'
+ '}\n')
+ six.assertRaisesRegex(self,
+ text_format.ParseError,
+ 'Invalid field value: "',
+ text_format.Parse, malformed, message,
+ allow_unknown_extension=True)
+
+ message = unittest_mset_pb2.TestMessageSetContainer()
+ malformed = ('message_set {\n'
+ ' [unknown_extension] {\n'
+ ' str: "malformed\n multiline\n string\n'
+ ' }\n'
+ '}\n')
+ six.assertRaisesRegex(self,
+ text_format.ParseError,
+ 'Invalid field value: "',
+ text_format.Parse, malformed, message,
+ allow_unknown_extension=True)
+
+ message = unittest_mset_pb2.TestMessageSetContainer()
+ malformed = ('message_set {\n'
+ ' [malformed_extension] <\n'
+ ' i: -5\n'
+ ' \n' # Missing '>' here.
+ '}\n')
+ six.assertRaisesRegex(self,
+ text_format.ParseError,
+ '5:1 : Expected ">".',
+ text_format.Parse, malformed, message,
+ allow_unknown_extension=True)
+
+ # Don't allow unknown fields with allow_unknown_extension=True.
+ message = unittest_mset_pb2.TestMessageSetContainer()
+ malformed = ('message_set {\n'
+ ' unknown_field: true\n'
+ ' \n' # Missing '>' here.
+ '}\n')
+ six.assertRaisesRegex(self,
+ text_format.ParseError,
+ ('2:3 : Message type '
+ '"proto2_wireformat_unittest.TestMessageSet" has no'
+ ' field named "unknown_field".'),
+ text_format.Parse, malformed, message,
+ allow_unknown_extension=True)
+
+ # Parse known extension correcty.
+ message = unittest_mset_pb2.TestMessageSetContainer()
+ text = ('message_set {\n'
+ ' [protobuf_unittest.TestMessageSetExtension1] {\n'
+ ' i: 23\n'
+ ' }\n'
+ ' [protobuf_unittest.TestMessageSetExtension2] {\n'
+ ' str: \"foo\"\n'
+ ' }\n'
+ '}\n')
+ text_format.Parse(text, message, allow_unknown_extension=True)
+ ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
+ ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension
+ self.assertEqual(23, message.message_set.Extensions[ext1].i)
+ self.assertEqual('foo', message.message_set.Extensions[ext2].str)
+
+ def testParseBadExtension(self):
+ message = unittest_pb2.TestAllExtensions()
+ text = '[unknown_extension]: 8\n'
+ six.assertRaisesRegex(self,
+ text_format.ParseError,
+ '1:2 : Extension "unknown_extension" not registered.',
+ text_format.Parse, text, message)
+ message = unittest_pb2.TestAllTypes()
+ six.assertRaisesRegex(self,
+ text_format.ParseError,
+ ('1:2 : Message type "protobuf_unittest.TestAllTypes" does not have '
+ 'extensions.'),
+ text_format.Parse, text, message)
+
+ def testMergeDuplicateExtensionScalars(self):
message = unittest_pb2.TestAllExtensions()
text = ('[protobuf_unittest.optional_int32_extension]: 42 '
'[protobuf_unittest.optional_int32_extension]: 67')
@@ -556,46 +916,102 @@ class TextFormatTest(basetest.TestCase):
67,
message.Extensions[unittest_pb2.optional_int32_extension])
- def testParseRepeatedExtensionScalars(self):
+ def testParseDuplicateExtensionScalars(self):
message = unittest_pb2.TestAllExtensions()
text = ('[protobuf_unittest.optional_int32_extension]: 42 '
'[protobuf_unittest.optional_int32_extension]: 67')
- self.assertRaisesWithLiteralMatch(
+ six.assertRaisesRegex(self,
text_format.ParseError,
('1:96 : Message type "protobuf_unittest.TestAllExtensions" '
'should not have multiple '
'"protobuf_unittest.optional_int32_extension" extensions.'),
text_format.Parse, text, message)
- def testParseLinesGolden(self):
- opened = self.ReadGolden('text_format_unittest_data.txt')
- parsed_message = unittest_pb2.TestAllTypes()
- r = text_format.ParseLines(opened, parsed_message)
- self.assertIs(r, parsed_message)
+ def testParseDuplicateNestedMessageScalars(self):
+ message = unittest_pb2.TestAllTypes()
+ text = ('optional_nested_message { bb: 1 } '
+ 'optional_nested_message { bb: 2 }')
+ six.assertRaisesRegex(self,
+ text_format.ParseError,
+ ('1:65 : Message type "protobuf_unittest.TestAllTypes.NestedMessage" '
+ 'should not have multiple "bb" fields.'),
+ text_format.Parse, text, message)
+ def testParseDuplicateScalars(self):
message = unittest_pb2.TestAllTypes()
- test_util.SetAllFields(message)
- self.assertEquals(message, parsed_message)
+ text = ('optional_int32: 42 '
+ 'optional_int32: 67')
+ six.assertRaisesRegex(self,
+ text_format.ParseError,
+ ('1:36 : Message type "protobuf_unittest.TestAllTypes" should not '
+ 'have multiple "optional_int32" fields.'),
+ text_format.Parse, text, message)
- def testMergeLinesGolden(self):
- opened = self.ReadGolden('text_format_unittest_data.txt')
- parsed_message = unittest_pb2.TestAllTypes()
- r = text_format.MergeLines(opened, parsed_message)
- self.assertIs(r, parsed_message)
+ def testParseGroupNotClosed(self):
+ message = unittest_pb2.TestAllTypes()
+ text = 'RepeatedGroup: <'
+ six.assertRaisesRegex(self,
+ text_format.ParseError, '1:16 : Expected ">".',
+ text_format.Parse, text, message)
+ text = 'RepeatedGroup: {'
+ six.assertRaisesRegex(self,
+ text_format.ParseError, '1:16 : Expected "}".',
+ text_format.Parse, text, message)
+ def testParseEmptyGroup(self):
message = unittest_pb2.TestAllTypes()
- test_util.SetAllFields(message)
- self.assertEqual(message, parsed_message)
+ text = 'OptionalGroup: {}'
+ text_format.Parse(text, message)
+ self.assertTrue(message.HasField('optionalgroup'))
- def testParseOneof(self):
- m = unittest_pb2.TestAllTypes()
- m.oneof_uint32 = 11
- m2 = unittest_pb2.TestAllTypes()
- text_format.Parse(text_format.MessageToString(m), m2)
- self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field'))
+ message.Clear()
+
+ message = unittest_pb2.TestAllTypes()
+ text = 'OptionalGroup: <>'
+ text_format.Parse(text, message)
+ self.assertTrue(message.HasField('optionalgroup'))
+
+ # Maps aren't really proto2-only, but our test schema only has maps for
+ # proto2.
+ def testParseMap(self):
+ text = ('map_int32_int32 {\n'
+ ' key: -123\n'
+ ' value: -456\n'
+ '}\n'
+ 'map_int64_int64 {\n'
+ ' key: -8589934592\n'
+ ' value: -17179869184\n'
+ '}\n'
+ 'map_uint32_uint32 {\n'
+ ' key: 123\n'
+ ' value: 456\n'
+ '}\n'
+ 'map_uint64_uint64 {\n'
+ ' key: 8589934592\n'
+ ' value: 17179869184\n'
+ '}\n'
+ 'map_string_string {\n'
+ ' key: "abc"\n'
+ ' value: "123"\n'
+ '}\n'
+ 'map_int32_foreign_message {\n'
+ ' key: 111\n'
+ ' value {\n'
+ ' c: 5\n'
+ ' }\n'
+ '}\n')
+ message = map_unittest_pb2.TestMap()
+ text_format.Parse(text, message)
+
+ self.assertEqual(-456, message.map_int32_int32[-123])
+ self.assertEqual(-2**34, message.map_int64_int64[-2**33])
+ self.assertEqual(456, message.map_uint32_uint32[123])
+ self.assertEqual(2**34, message.map_uint64_uint64[2**33])
+ self.assertEqual("123", message.map_string_string["abc"])
+ self.assertEqual(5, message.map_int32_foreign_message[111].c)
-class TokenizerTest(basetest.TestCase):
+class TokenizerTest(unittest.TestCase):
def testSimpleTokenCases(self):
text = ('identifier1:"string1"\n \n\n'
@@ -740,4 +1156,4 @@ class TokenizerTest(basetest.TestCase):
if __name__ == '__main__':
- basetest.main()
+ unittest.main()
diff --git a/python/google/protobuf/internal/type_checkers.py b/python/google/protobuf/internal/type_checkers.py
index 56d264604..1be3ad9a3 100755
--- a/python/google/protobuf/internal/type_checkers.py
+++ b/python/google/protobuf/internal/type_checkers.py
@@ -28,10 +28,6 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-#PY25 compatible for GAE.
-#
-# Copyright 2008 Google Inc. All Rights Reserved.
-
"""Provides type checking routines.
This module defines type checking utilities in the forms of dictionaries:
@@ -49,8 +45,11 @@ TYPE_TO_DESERIALIZE_METHOD: A dictionary with field types and deserialization
__author__ = 'robinson@google.com (Will Robinson)'
-import sys ##PY25
-if sys.version < '2.6': bytes = str ##PY25
+import six
+
+if six.PY3:
+ long = int
+
from google.protobuf.internal import api_implementation
from google.protobuf.internal import decoder
from google.protobuf.internal import encoder
@@ -59,6 +58,8 @@ from google.protobuf import descriptor
_FieldDescriptor = descriptor.FieldDescriptor
+def SupportsOpenEnums(field_descriptor):
+ return field_descriptor.containing_type.syntax == "proto3"
def GetTypeChecker(field):
"""Returns a type checker for a message field of the specified types.
@@ -74,7 +75,11 @@ def GetTypeChecker(field):
field.type == _FieldDescriptor.TYPE_STRING):
return UnicodeValueChecker()
if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
- return EnumValueChecker(field.enum_type)
+ if SupportsOpenEnums(field):
+ # When open enums are supported, any int32 can be assigned.
+ return _VALUE_CHECKERS[_FieldDescriptor.CPPTYPE_INT32]
+ else:
+ return EnumValueChecker(field.enum_type)
return _VALUE_CHECKERS[field.cpp_type]
@@ -104,6 +109,16 @@ class TypeChecker(object):
return proposed_value
+class TypeCheckerWithDefault(TypeChecker):
+
+ def __init__(self, default_value, *acceptable_types):
+ TypeChecker.__init__(self, acceptable_types)
+ self._default_value = default_value
+
+ def DefaultValue(self):
+ return self._default_value
+
+
# IntValueChecker and its subclasses perform integer type-checks
# and bounds-checks.
class IntValueChecker(object):
@@ -111,9 +126,9 @@ class IntValueChecker(object):
"""Checker used for integer fields. Performs type-check and range check."""
def CheckValue(self, proposed_value):
- if not isinstance(proposed_value, (int, long)):
+ if not isinstance(proposed_value, six.integer_types):
message = ('%.1024r has type %s, but expected one of: %s' %
- (proposed_value, type(proposed_value), (int, long)))
+ (proposed_value, type(proposed_value), six.integer_types))
raise TypeError(message)
if not self._MIN <= proposed_value <= self._MAX:
raise ValueError('Value out of range: %d' % proposed_value)
@@ -123,6 +138,9 @@ class IntValueChecker(object):
proposed_value = self._TYPE(proposed_value)
return proposed_value
+ def DefaultValue(self):
+ return 0
+
class EnumValueChecker(object):
@@ -132,14 +150,17 @@ class EnumValueChecker(object):
self._enum_type = enum_type
def CheckValue(self, proposed_value):
- if not isinstance(proposed_value, (int, long)):
+ if not isinstance(proposed_value, six.integer_types):
message = ('%.1024r has type %s, but expected one of: %s' %
- (proposed_value, type(proposed_value), (int, long)))
+ (proposed_value, type(proposed_value), six.integer_types))
raise TypeError(message)
if proposed_value not in self._enum_type.values_by_number:
raise ValueError('Unknown enum value: %d' % proposed_value)
return proposed_value
+ def DefaultValue(self):
+ return self._enum_type.values[0].number
+
class UnicodeValueChecker(object):
@@ -149,23 +170,25 @@ class UnicodeValueChecker(object):
"""
def CheckValue(self, proposed_value):
- if not isinstance(proposed_value, (bytes, unicode)):
+ if not isinstance(proposed_value, (bytes, six.text_type)):
message = ('%.1024r has type %s, but expected one of: %s' %
- (proposed_value, type(proposed_value), (bytes, unicode)))
+ (proposed_value, type(proposed_value), (bytes, six.text_type)))
raise TypeError(message)
- # If the value is of type 'bytes' make sure that it is in 7-bit ASCII
- # encoding.
+ # If the value is of type 'bytes' make sure that it is valid UTF-8 data.
if isinstance(proposed_value, bytes):
try:
- proposed_value = proposed_value.decode('ascii')
+ proposed_value = proposed_value.decode('utf-8')
except UnicodeDecodeError:
- raise ValueError('%.1024r has type bytes, but isn\'t in 7-bit ASCII '
- 'encoding. Non-ASCII strings must be converted to '
+ raise ValueError('%.1024r has type bytes, but isn\'t valid UTF-8 '
+ 'encoding. Non-UTF-8 strings must be converted to '
'unicode objects before being added.' %
(proposed_value))
return proposed_value
+ def DefaultValue(self):
+ return u""
+
class Int32ValueChecker(IntValueChecker):
# We're sure to use ints instead of longs here since comparison may be more
@@ -199,12 +222,13 @@ _VALUE_CHECKERS = {
_FieldDescriptor.CPPTYPE_INT64: Int64ValueChecker(),
_FieldDescriptor.CPPTYPE_UINT32: Uint32ValueChecker(),
_FieldDescriptor.CPPTYPE_UINT64: Uint64ValueChecker(),
- _FieldDescriptor.CPPTYPE_DOUBLE: TypeChecker(
- float, int, long),
- _FieldDescriptor.CPPTYPE_FLOAT: TypeChecker(
- float, int, long),
- _FieldDescriptor.CPPTYPE_BOOL: TypeChecker(bool, int),
- _FieldDescriptor.CPPTYPE_STRING: TypeChecker(bytes),
+ _FieldDescriptor.CPPTYPE_DOUBLE: TypeCheckerWithDefault(
+ 0.0, float, int, long),
+ _FieldDescriptor.CPPTYPE_FLOAT: TypeCheckerWithDefault(
+ 0.0, float, int, long),
+ _FieldDescriptor.CPPTYPE_BOOL: TypeCheckerWithDefault(
+ False, bool, int),
+ _FieldDescriptor.CPPTYPE_STRING: TypeCheckerWithDefault(b'', bytes),
}
diff --git a/python/google/protobuf/internal/unknown_fields_test.py b/python/google/protobuf/internal/unknown_fields_test.py
index 71775609d..84073f1c5 100755
--- a/python/google/protobuf/internal/unknown_fields_test.py
+++ b/python/google/protobuf/internal/unknown_fields_test.py
@@ -1,4 +1,4 @@
-#! /usr/bin/python
+#! /usr/bin/env python
# -*- coding: utf-8 -*-
#
# Protocol Buffers - Google's data interchange format
@@ -35,16 +35,28 @@
__author__ = 'bohdank@google.com (Bohdan Koval)'
-from google.apputils import basetest
+try:
+ import unittest2 as unittest #PY26
+except ImportError:
+ import unittest
from google.protobuf import unittest_mset_pb2
from google.protobuf import unittest_pb2
+from google.protobuf import unittest_proto3_arena_pb2
+from google.protobuf.internal import api_implementation
from google.protobuf.internal import encoder
+from google.protobuf.internal import message_set_extensions_pb2
from google.protobuf.internal import missing_enum_values_pb2
from google.protobuf.internal import test_util
from google.protobuf.internal import type_checkers
-class UnknownFieldsTest(basetest.TestCase):
+def SkipIfCppImplementation(func):
+ return unittest.skipIf(
+ api_implementation.Type() == 'cpp' and api_implementation.Version() == 2,
+ 'C++ implementation does not expose unknown fields to Python')(func)
+
+
+class UnknownFieldsTest(unittest.TestCase):
def setUp(self):
self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
@@ -53,7 +65,98 @@ class UnknownFieldsTest(basetest.TestCase):
self.all_fields_data = self.all_fields.SerializeToString()
self.empty_message = unittest_pb2.TestEmptyMessage()
self.empty_message.ParseFromString(self.all_fields_data)
- self.unknown_fields = self.empty_message._unknown_fields
+
+ def testSerialize(self):
+ data = self.empty_message.SerializeToString()
+
+ # Don't use assertEqual because we don't want to dump raw binary data to
+ # stdout.
+ self.assertTrue(data == self.all_fields_data)
+
+ def testSerializeProto3(self):
+ # Verify that proto3 doesn't preserve unknown fields.
+ message = unittest_proto3_arena_pb2.TestEmptyMessage()
+ message.ParseFromString(self.all_fields_data)
+ self.assertEqual(0, len(message.SerializeToString()))
+
+ def testByteSize(self):
+ self.assertEqual(self.all_fields.ByteSize(), self.empty_message.ByteSize())
+
+ def testListFields(self):
+ # Make sure ListFields doesn't return unknown fields.
+ self.assertEqual(0, len(self.empty_message.ListFields()))
+
+ def testSerializeMessageSetWireFormatUnknownExtension(self):
+ # Create a message using the message set wire format with an unknown
+ # message.
+ raw = unittest_mset_pb2.RawMessageSet()
+
+ # Add an unknown extension.
+ item = raw.item.add()
+ item.type_id = 98418603
+ message1 = message_set_extensions_pb2.TestMessageSetExtension1()
+ message1.i = 12345
+ item.message = message1.SerializeToString()
+
+ serialized = raw.SerializeToString()
+
+ # Parse message using the message set wire format.
+ proto = message_set_extensions_pb2.TestMessageSet()
+ proto.MergeFromString(serialized)
+
+ # Verify that the unknown extension is serialized unchanged
+ reserialized = proto.SerializeToString()
+ new_raw = unittest_mset_pb2.RawMessageSet()
+ new_raw.MergeFromString(reserialized)
+ self.assertEqual(raw, new_raw)
+
+ def testEquals(self):
+ message = unittest_pb2.TestEmptyMessage()
+ message.ParseFromString(self.all_fields_data)
+ self.assertEqual(self.empty_message, message)
+
+ self.all_fields.ClearField('optional_string')
+ message.ParseFromString(self.all_fields.SerializeToString())
+ self.assertNotEqual(self.empty_message, message)
+
+ def testDiscardUnknownFields(self):
+ self.empty_message.DiscardUnknownFields()
+ self.assertEqual(b'', self.empty_message.SerializeToString())
+ # Test message field and repeated message field.
+ message = unittest_pb2.TestAllTypes()
+ other_message = unittest_pb2.TestAllTypes()
+ other_message.optional_string = 'discard'
+ message.optional_nested_message.ParseFromString(
+ other_message.SerializeToString())
+ message.repeated_nested_message.add().ParseFromString(
+ other_message.SerializeToString())
+ self.assertNotEqual(
+ b'', message.optional_nested_message.SerializeToString())
+ self.assertNotEqual(
+ b'', message.repeated_nested_message[0].SerializeToString())
+ message.DiscardUnknownFields()
+ self.assertEqual(b'', message.optional_nested_message.SerializeToString())
+ self.assertEqual(
+ b'', message.repeated_nested_message[0].SerializeToString())
+
+
+class UnknownFieldsAccessorsTest(unittest.TestCase):
+
+ def setUp(self):
+ self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
+ self.all_fields = unittest_pb2.TestAllTypes()
+ test_util.SetAllFields(self.all_fields)
+ self.all_fields_data = self.all_fields.SerializeToString()
+ self.empty_message = unittest_pb2.TestEmptyMessage()
+ self.empty_message.ParseFromString(self.all_fields_data)
+ if api_implementation.Type() != 'cpp':
+ # _unknown_fields is an implementation detail.
+ self.unknown_fields = self.empty_message._unknown_fields
+
+ # All the tests that use GetField() check an implementation detail of the
+ # Python implementation, which stores unknown fields as serialized strings.
+ # These tests are skipped by the C++ implementation: it's enough to check that
+ # the message is correctly serialized.
def GetField(self, name):
field_descriptor = self.descriptor.fields_by_name[name]
@@ -66,45 +169,45 @@ class UnknownFieldsTest(basetest.TestCase):
decoder(value, 0, len(value), self.all_fields, result_dict)
return result_dict[field_descriptor]
+ @SkipIfCppImplementation
def testEnum(self):
value = self.GetField('optional_nested_enum')
self.assertEqual(self.all_fields.optional_nested_enum, value)
+ @SkipIfCppImplementation
def testRepeatedEnum(self):
value = self.GetField('repeated_nested_enum')
self.assertEqual(self.all_fields.repeated_nested_enum, value)
+ @SkipIfCppImplementation
def testVarint(self):
value = self.GetField('optional_int32')
self.assertEqual(self.all_fields.optional_int32, value)
+ @SkipIfCppImplementation
def testFixed32(self):
value = self.GetField('optional_fixed32')
self.assertEqual(self.all_fields.optional_fixed32, value)
+ @SkipIfCppImplementation
def testFixed64(self):
value = self.GetField('optional_fixed64')
self.assertEqual(self.all_fields.optional_fixed64, value)
+ @SkipIfCppImplementation
def testLengthDelimited(self):
value = self.GetField('optional_string')
self.assertEqual(self.all_fields.optional_string, value)
+ @SkipIfCppImplementation
def testGroup(self):
value = self.GetField('optionalgroup')
self.assertEqual(self.all_fields.optionalgroup, value)
- def testSerialize(self):
- data = self.empty_message.SerializeToString()
-
- # Don't use assertEqual because we don't want to dump raw binary data to
- # stdout.
- self.assertTrue(data == self.all_fields_data)
-
def testCopyFrom(self):
message = unittest_pb2.TestEmptyMessage()
message.CopyFrom(self.empty_message)
- self.assertEqual(self.unknown_fields, message._unknown_fields)
+ self.assertEqual(message.SerializeToString(), self.all_fields_data)
def testMergeFrom(self):
message = unittest_pb2.TestAllTypes()
@@ -118,64 +221,27 @@ class UnknownFieldsTest(basetest.TestCase):
message.optional_uint32 = 4
destination = unittest_pb2.TestEmptyMessage()
destination.ParseFromString(message.SerializeToString())
- unknown_fields = destination._unknown_fields[:]
destination.MergeFrom(source)
- self.assertEqual(unknown_fields + source._unknown_fields,
- destination._unknown_fields)
+ # Check that the fields where correctly merged, even stored in the unknown
+ # fields set.
+ message.ParseFromString(destination.SerializeToString())
+ self.assertEqual(message.optional_int32, 1)
+ self.assertEqual(message.optional_uint32, 2)
+ self.assertEqual(message.optional_int64, 3)
def testClear(self):
self.empty_message.Clear()
- self.assertEqual(0, len(self.empty_message._unknown_fields))
-
- def testByteSize(self):
- self.assertEqual(self.all_fields.ByteSize(), self.empty_message.ByteSize())
+ # All cleared, even unknown fields.
+ self.assertEqual(self.empty_message.SerializeToString(), b'')
def testUnknownExtensions(self):
message = unittest_pb2.TestEmptyMessageWithExtensions()
message.ParseFromString(self.all_fields_data)
- self.assertEqual(self.empty_message._unknown_fields,
- message._unknown_fields)
+ self.assertEqual(message.SerializeToString(), self.all_fields_data)
- def testListFields(self):
- # Make sure ListFields doesn't return unknown fields.
- self.assertEqual(0, len(self.empty_message.ListFields()))
- def testSerializeMessageSetWireFormatUnknownExtension(self):
- # Create a message using the message set wire format with an unknown
- # message.
- raw = unittest_mset_pb2.RawMessageSet()
-
- # Add an unknown extension.
- item = raw.item.add()
- item.type_id = 1545009
- message1 = unittest_mset_pb2.TestMessageSetExtension1()
- message1.i = 12345
- item.message = message1.SerializeToString()
-
- serialized = raw.SerializeToString()
-
- # Parse message using the message set wire format.
- proto = unittest_mset_pb2.TestMessageSet()
- proto.MergeFromString(serialized)
-
- # Verify that the unknown extension is serialized unchanged
- reserialized = proto.SerializeToString()
- new_raw = unittest_mset_pb2.RawMessageSet()
- new_raw.MergeFromString(reserialized)
- self.assertEqual(raw, new_raw)
-
- def testEquals(self):
- message = unittest_pb2.TestEmptyMessage()
- message.ParseFromString(self.all_fields_data)
- self.assertEqual(self.empty_message, message)
-
- self.all_fields.ClearField('optional_string')
- message.ParseFromString(self.all_fields.SerializeToString())
- self.assertNotEqual(self.empty_message, message)
-
-
-class UnknownFieldsTest(basetest.TestCase):
+class UnknownEnumValuesTest(unittest.TestCase):
def setUp(self):
self.descriptor = missing_enum_values_pb2.TestEnumValues.DESCRIPTOR
@@ -194,7 +260,14 @@ class UnknownFieldsTest(basetest.TestCase):
self.message_data = self.message.SerializeToString()
self.missing_message = missing_enum_values_pb2.TestMissingEnumValues()
self.missing_message.ParseFromString(self.message_data)
- self.unknown_fields = self.missing_message._unknown_fields
+ if api_implementation.Type() != 'cpp':
+ # _unknown_fields is an implementation detail.
+ self.unknown_fields = self.missing_message._unknown_fields
+
+ # All the tests that use GetField() check an implementation detail of the
+ # Python implementation, which stores unknown fields as serialized strings.
+ # These tests are skipped by the C++ implementation: it's enough to check that
+ # the message is correctly serialized.
def GetField(self, name):
field_descriptor = self.descriptor.fields_by_name[name]
@@ -208,15 +281,31 @@ class UnknownFieldsTest(basetest.TestCase):
decoder(value, 0, len(value), self.message, result_dict)
return result_dict[field_descriptor]
+ def testUnknownParseMismatchEnumValue(self):
+ just_string = missing_enum_values_pb2.JustString()
+ just_string.dummy = 'blah'
+
+ missing = missing_enum_values_pb2.TestEnumValues()
+ # The parse is invalid, storing the string proto into the set of
+ # unknown fields.
+ missing.ParseFromString(just_string.SerializeToString())
+
+ # Fetching the enum field shouldn't crash, instead returning the
+ # default value.
+ self.assertEqual(missing.optional_nested_enum, 0)
+
+ @SkipIfCppImplementation
def testUnknownEnumValue(self):
self.assertFalse(self.missing_message.HasField('optional_nested_enum'))
value = self.GetField('optional_nested_enum')
self.assertEqual(self.message.optional_nested_enum, value)
+ @SkipIfCppImplementation
def testUnknownRepeatedEnumValue(self):
value = self.GetField('repeated_nested_enum')
self.assertEqual(self.message.repeated_nested_enum, value)
+ @SkipIfCppImplementation
def testUnknownPackedEnumValue(self):
value = self.GetField('packed_nested_enum')
self.assertEqual(self.message.packed_nested_enum, value)
@@ -228,4 +317,4 @@ class UnknownFieldsTest(basetest.TestCase):
if __name__ == '__main__':
- basetest.main()
+ unittest.main()
diff --git a/python/google/protobuf/internal/well_known_types.py b/python/google/protobuf/internal/well_known_types.py
new file mode 100644
index 000000000..7c5dffd0f
--- /dev/null
+++ b/python/google/protobuf/internal/well_known_types.py
@@ -0,0 +1,724 @@
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# https://developers.google.com/protocol-buffers/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Contains well known classes.
+
+This files defines well known classes which need extra maintenance including:
+ - Any
+ - Duration
+ - FieldMask
+ - Struct
+ - Timestamp
+"""
+
+__author__ = 'jieluo@google.com (Jie Luo)'
+
+from datetime import datetime
+from datetime import timedelta
+import six
+
+from google.protobuf.descriptor import FieldDescriptor
+
+_TIMESTAMPFOMAT = '%Y-%m-%dT%H:%M:%S'
+_NANOS_PER_SECOND = 1000000000
+_NANOS_PER_MILLISECOND = 1000000
+_NANOS_PER_MICROSECOND = 1000
+_MILLIS_PER_SECOND = 1000
+_MICROS_PER_SECOND = 1000000
+_SECONDS_PER_DAY = 24 * 3600
+
+
+class Error(Exception):
+ """Top-level module error."""
+
+
+class ParseError(Error):
+ """Thrown in case of parsing error."""
+
+
+class Any(object):
+ """Class for Any Message type."""
+
+ def Pack(self, msg, type_url_prefix='type.googleapis.com/'):
+ """Packs the specified message into current Any message."""
+ if len(type_url_prefix) < 1 or type_url_prefix[-1] != '/':
+ self.type_url = '%s/%s' % (type_url_prefix, msg.DESCRIPTOR.full_name)
+ else:
+ self.type_url = '%s%s' % (type_url_prefix, msg.DESCRIPTOR.full_name)
+ self.value = msg.SerializeToString()
+
+ def Unpack(self, msg):
+ """Unpacks the current Any message into specified message."""
+ descriptor = msg.DESCRIPTOR
+ if not self.Is(descriptor):
+ return False
+ msg.ParseFromString(self.value)
+ return True
+
+ def TypeName(self):
+ """Returns the protobuf type name of the inner message."""
+ # Only last part is to be used: b/25630112
+ return self.type_url.split('/')[-1]
+
+ def Is(self, descriptor):
+ """Checks if this Any represents the given protobuf type."""
+ return self.TypeName() == descriptor.full_name
+
+
+class Timestamp(object):
+ """Class for Timestamp message type."""
+
+ def ToJsonString(self):
+ """Converts Timestamp to RFC 3339 date string format.
+
+ Returns:
+ A string converted from timestamp. The string is always Z-normalized
+ and uses 3, 6 or 9 fractional digits as required to represent the
+ exact time. Example of the return format: '1972-01-01T10:00:20.021Z'
+ """
+ nanos = self.nanos % _NANOS_PER_SECOND
+ total_sec = self.seconds + (self.nanos - nanos) // _NANOS_PER_SECOND
+ seconds = total_sec % _SECONDS_PER_DAY
+ days = (total_sec - seconds) // _SECONDS_PER_DAY
+ dt = datetime(1970, 1, 1) + timedelta(days, seconds)
+
+ result = dt.isoformat()
+ if (nanos % 1e9) == 0:
+ # If there are 0 fractional digits, the fractional
+ # point '.' should be omitted when serializing.
+ return result + 'Z'
+ if (nanos % 1e6) == 0:
+ # Serialize 3 fractional digits.
+ return result + '.%03dZ' % (nanos / 1e6)
+ if (nanos % 1e3) == 0:
+ # Serialize 6 fractional digits.
+ return result + '.%06dZ' % (nanos / 1e3)
+ # Serialize 9 fractional digits.
+ return result + '.%09dZ' % nanos
+
+ def FromJsonString(self, value):
+ """Parse a RFC 3339 date string format to Timestamp.
+
+ Args:
+ value: A date string. Any fractional digits (or none) and any offset are
+ accepted as long as they fit into nano-seconds precision.
+ Example of accepted format: '1972-01-01T10:00:20.021-05:00'
+
+ Raises:
+ ParseError: On parsing problems.
+ """
+ timezone_offset = value.find('Z')
+ if timezone_offset == -1:
+ timezone_offset = value.find('+')
+ if timezone_offset == -1:
+ timezone_offset = value.rfind('-')
+ if timezone_offset == -1:
+ raise ParseError(
+ 'Failed to parse timestamp: missing valid timezone offset.')
+ time_value = value[0:timezone_offset]
+ # Parse datetime and nanos.
+ point_position = time_value.find('.')
+ if point_position == -1:
+ second_value = time_value
+ nano_value = ''
+ else:
+ second_value = time_value[:point_position]
+ nano_value = time_value[point_position + 1:]
+ date_object = datetime.strptime(second_value, _TIMESTAMPFOMAT)
+ td = date_object - datetime(1970, 1, 1)
+ seconds = td.seconds + td.days * _SECONDS_PER_DAY
+ if len(nano_value) > 9:
+ raise ParseError(
+ 'Failed to parse Timestamp: nanos {0} more than '
+ '9 fractional digits.'.format(nano_value))
+ if nano_value:
+ nanos = round(float('0.' + nano_value) * 1e9)
+ else:
+ nanos = 0
+ # Parse timezone offsets.
+ if value[timezone_offset] == 'Z':
+ if len(value) != timezone_offset + 1:
+ raise ParseError('Failed to parse timestamp: invalid trailing'
+ ' data {0}.'.format(value))
+ else:
+ timezone = value[timezone_offset:]
+ pos = timezone.find(':')
+ if pos == -1:
+ raise ParseError(
+ 'Invalid timezone offset value: {0}.'.format(timezone))
+ if timezone[0] == '+':
+ seconds -= (int(timezone[1:pos])*60+int(timezone[pos+1:]))*60
+ else:
+ seconds += (int(timezone[1:pos])*60+int(timezone[pos+1:]))*60
+ # Set seconds and nanos
+ self.seconds = int(seconds)
+ self.nanos = int(nanos)
+
+ def GetCurrentTime(self):
+ """Get the current UTC into Timestamp."""
+ self.FromDatetime(datetime.utcnow())
+
+ def ToNanoseconds(self):
+ """Converts Timestamp to nanoseconds since epoch."""
+ return self.seconds * _NANOS_PER_SECOND + self.nanos
+
+ def ToMicroseconds(self):
+ """Converts Timestamp to microseconds since epoch."""
+ return (self.seconds * _MICROS_PER_SECOND +
+ self.nanos // _NANOS_PER_MICROSECOND)
+
+ def ToMilliseconds(self):
+ """Converts Timestamp to milliseconds since epoch."""
+ return (self.seconds * _MILLIS_PER_SECOND +
+ self.nanos // _NANOS_PER_MILLISECOND)
+
+ def ToSeconds(self):
+ """Converts Timestamp to seconds since epoch."""
+ return self.seconds
+
+ def FromNanoseconds(self, nanos):
+ """Converts nanoseconds since epoch to Timestamp."""
+ self.seconds = nanos // _NANOS_PER_SECOND
+ self.nanos = nanos % _NANOS_PER_SECOND
+
+ def FromMicroseconds(self, micros):
+ """Converts microseconds since epoch to Timestamp."""
+ self.seconds = micros // _MICROS_PER_SECOND
+ self.nanos = (micros % _MICROS_PER_SECOND) * _NANOS_PER_MICROSECOND
+
+ def FromMilliseconds(self, millis):
+ """Converts milliseconds since epoch to Timestamp."""
+ self.seconds = millis // _MILLIS_PER_SECOND
+ self.nanos = (millis % _MILLIS_PER_SECOND) * _NANOS_PER_MILLISECOND
+
+ def FromSeconds(self, seconds):
+ """Converts seconds since epoch to Timestamp."""
+ self.seconds = seconds
+ self.nanos = 0
+
+ def ToDatetime(self):
+ """Converts Timestamp to datetime."""
+ return datetime.utcfromtimestamp(
+ self.seconds + self.nanos / float(_NANOS_PER_SECOND))
+
+ def FromDatetime(self, dt):
+ """Converts datetime to Timestamp."""
+ td = dt - datetime(1970, 1, 1)
+ self.seconds = td.seconds + td.days * _SECONDS_PER_DAY
+ self.nanos = td.microseconds * _NANOS_PER_MICROSECOND
+
+
+class Duration(object):
+ """Class for Duration message type."""
+
+ def ToJsonString(self):
+ """Converts Duration to string format.
+
+ Returns:
+ A string converted from self. The string format will contains
+ 3, 6, or 9 fractional digits depending on the precision required to
+ represent the exact Duration value. For example: "1s", "1.010s",
+ "1.000000100s", "-3.100s"
+ """
+ if self.seconds < 0 or self.nanos < 0:
+ result = '-'
+ seconds = - self.seconds + int((0 - self.nanos) // 1e9)
+ nanos = (0 - self.nanos) % 1e9
+ else:
+ result = ''
+ seconds = self.seconds + int(self.nanos // 1e9)
+ nanos = self.nanos % 1e9
+ result += '%d' % seconds
+ if (nanos % 1e9) == 0:
+ # If there are 0 fractional digits, the fractional
+ # point '.' should be omitted when serializing.
+ return result + 's'
+ if (nanos % 1e6) == 0:
+ # Serialize 3 fractional digits.
+ return result + '.%03ds' % (nanos / 1e6)
+ if (nanos % 1e3) == 0:
+ # Serialize 6 fractional digits.
+ return result + '.%06ds' % (nanos / 1e3)
+ # Serialize 9 fractional digits.
+ return result + '.%09ds' % nanos
+
+ def FromJsonString(self, value):
+ """Converts a string to Duration.
+
+ Args:
+ value: A string to be converted. The string must end with 's'. Any
+ fractional digits (or none) are accepted as long as they fit into
+ precision. For example: "1s", "1.01s", "1.0000001s", "-3.100s
+
+ Raises:
+ ParseError: On parsing problems.
+ """
+ if len(value) < 1 or value[-1] != 's':
+ raise ParseError(
+ 'Duration must end with letter "s": {0}.'.format(value))
+ try:
+ pos = value.find('.')
+ if pos == -1:
+ self.seconds = int(value[:-1])
+ self.nanos = 0
+ else:
+ self.seconds = int(value[:pos])
+ if value[0] == '-':
+ self.nanos = int(round(float('-0{0}'.format(value[pos: -1])) *1e9))
+ else:
+ self.nanos = int(round(float('0{0}'.format(value[pos: -1])) *1e9))
+ except ValueError:
+ raise ParseError(
+ 'Couldn\'t parse duration: {0}.'.format(value))
+
+ def ToNanoseconds(self):
+ """Converts a Duration to nanoseconds."""
+ return self.seconds * _NANOS_PER_SECOND + self.nanos
+
+ def ToMicroseconds(self):
+ """Converts a Duration to microseconds."""
+ micros = _RoundTowardZero(self.nanos, _NANOS_PER_MICROSECOND)
+ return self.seconds * _MICROS_PER_SECOND + micros
+
+ def ToMilliseconds(self):
+ """Converts a Duration to milliseconds."""
+ millis = _RoundTowardZero(self.nanos, _NANOS_PER_MILLISECOND)
+ return self.seconds * _MILLIS_PER_SECOND + millis
+
+ def ToSeconds(self):
+ """Converts a Duration to seconds."""
+ return self.seconds
+
+ def FromNanoseconds(self, nanos):
+ """Converts nanoseconds to Duration."""
+ self._NormalizeDuration(nanos // _NANOS_PER_SECOND,
+ nanos % _NANOS_PER_SECOND)
+
+ def FromMicroseconds(self, micros):
+ """Converts microseconds to Duration."""
+ self._NormalizeDuration(
+ micros // _MICROS_PER_SECOND,
+ (micros % _MICROS_PER_SECOND) * _NANOS_PER_MICROSECOND)
+
+ def FromMilliseconds(self, millis):
+ """Converts milliseconds to Duration."""
+ self._NormalizeDuration(
+ millis // _MILLIS_PER_SECOND,
+ (millis % _MILLIS_PER_SECOND) * _NANOS_PER_MILLISECOND)
+
+ def FromSeconds(self, seconds):
+ """Converts seconds to Duration."""
+ self.seconds = seconds
+ self.nanos = 0
+
+ def ToTimedelta(self):
+ """Converts Duration to timedelta."""
+ return timedelta(
+ seconds=self.seconds, microseconds=_RoundTowardZero(
+ self.nanos, _NANOS_PER_MICROSECOND))
+
+ def FromTimedelta(self, td):
+ """Convertd timedelta to Duration."""
+ self._NormalizeDuration(td.seconds + td.days * _SECONDS_PER_DAY,
+ td.microseconds * _NANOS_PER_MICROSECOND)
+
+ def _NormalizeDuration(self, seconds, nanos):
+ """Set Duration by seconds and nonas."""
+ # Force nanos to be negative if the duration is negative.
+ if seconds < 0 and nanos > 0:
+ seconds += 1
+ nanos -= _NANOS_PER_SECOND
+ self.seconds = seconds
+ self.nanos = nanos
+
+
+def _RoundTowardZero(value, divider):
+ """Truncates the remainder part after division."""
+ # For some languanges, the sign of the remainder is implementation
+ # dependent if any of the operands is negative. Here we enforce
+ # "rounded toward zero" semantics. For example, for (-5) / 2 an
+ # implementation may give -3 as the result with the remainder being
+ # 1. This function ensures we always return -2 (closer to zero).
+ result = value // divider
+ remainder = value % divider
+ if result < 0 and remainder > 0:
+ return result + 1
+ else:
+ return result
+
+
+class FieldMask(object):
+ """Class for FieldMask message type."""
+
+ def ToJsonString(self):
+ """Converts FieldMask to string according to proto3 JSON spec."""
+ return ','.join(self.paths)
+
+ def FromJsonString(self, value):
+ """Converts string to FieldMask according to proto3 JSON spec."""
+ self.Clear()
+ for path in value.split(','):
+ self.paths.append(path)
+
+ def IsValidForDescriptor(self, message_descriptor):
+ """Checks whether the FieldMask is valid for Message Descriptor."""
+ for path in self.paths:
+ if not _IsValidPath(message_descriptor, path):
+ return False
+ return True
+
+ def AllFieldsFromDescriptor(self, message_descriptor):
+ """Gets all direct fields of Message Descriptor to FieldMask."""
+ self.Clear()
+ for field in message_descriptor.fields:
+ self.paths.append(field.name)
+
+ def CanonicalFormFromMask(self, mask):
+ """Converts a FieldMask to the canonical form.
+
+ Removes paths that are covered by another path. For example,
+ "foo.bar" is covered by "foo" and will be removed if "foo"
+ is also in the FieldMask. Then sorts all paths in alphabetical order.
+
+ Args:
+ mask: The original FieldMask to be converted.
+ """
+ tree = _FieldMaskTree(mask)
+ tree.ToFieldMask(self)
+
+ def Union(self, mask1, mask2):
+ """Merges mask1 and mask2 into this FieldMask."""
+ _CheckFieldMaskMessage(mask1)
+ _CheckFieldMaskMessage(mask2)
+ tree = _FieldMaskTree(mask1)
+ tree.MergeFromFieldMask(mask2)
+ tree.ToFieldMask(self)
+
+ def Intersect(self, mask1, mask2):
+ """Intersects mask1 and mask2 into this FieldMask."""
+ _CheckFieldMaskMessage(mask1)
+ _CheckFieldMaskMessage(mask2)
+ tree = _FieldMaskTree(mask1)
+ intersection = _FieldMaskTree()
+ for path in mask2.paths:
+ tree.IntersectPath(path, intersection)
+ intersection.ToFieldMask(self)
+
+ def MergeMessage(
+ self, source, destination,
+ replace_message_field=False, replace_repeated_field=False):
+ """Merges fields specified in FieldMask from source to destination.
+
+ Args:
+ source: Source message.
+ destination: The destination message to be merged into.
+ replace_message_field: Replace message field if True. Merge message
+ field if False.
+ replace_repeated_field: Replace repeated field if True. Append
+ elements of repeated field if False.
+ """
+ tree = _FieldMaskTree(self)
+ tree.MergeMessage(
+ source, destination, replace_message_field, replace_repeated_field)
+
+
+def _IsValidPath(message_descriptor, path):
+ """Checks whether the path is valid for Message Descriptor."""
+ parts = path.split('.')
+ last = parts.pop()
+ for name in parts:
+ field = message_descriptor.fields_by_name[name]
+ if (field is None or
+ field.label == FieldDescriptor.LABEL_REPEATED or
+ field.type != FieldDescriptor.TYPE_MESSAGE):
+ return False
+ message_descriptor = field.message_type
+ return last in message_descriptor.fields_by_name
+
+
+def _CheckFieldMaskMessage(message):
+ """Raises ValueError if message is not a FieldMask."""
+ message_descriptor = message.DESCRIPTOR
+ if (message_descriptor.name != 'FieldMask' or
+ message_descriptor.file.name != 'google/protobuf/field_mask.proto'):
+ raise ValueError('Message {0} is not a FieldMask.'.format(
+ message_descriptor.full_name))
+
+
+class _FieldMaskTree(object):
+ """Represents a FieldMask in a tree structure.
+
+ For example, given a FieldMask "foo.bar,foo.baz,bar.baz",
+ the FieldMaskTree will be:
+ [_root] -+- foo -+- bar
+ | |
+ | +- baz
+ |
+ +- bar --- baz
+ In the tree, each leaf node represents a field path.
+ """
+
+ def __init__(self, field_mask=None):
+ """Initializes the tree by FieldMask."""
+ self._root = {}
+ if field_mask:
+ self.MergeFromFieldMask(field_mask)
+
+ def MergeFromFieldMask(self, field_mask):
+ """Merges a FieldMask to the tree."""
+ for path in field_mask.paths:
+ self.AddPath(path)
+
+ def AddPath(self, path):
+ """Adds a field path into the tree.
+
+ If the field path to add is a sub-path of an existing field path
+ in the tree (i.e., a leaf node), it means the tree already matches
+ the given path so nothing will be added to the tree. If the path
+ matches an existing non-leaf node in the tree, that non-leaf node
+ will be turned into a leaf node with all its children removed because
+ the path matches all the node's children. Otherwise, a new path will
+ be added.
+
+ Args:
+ path: The field path to add.
+ """
+ node = self._root
+ for name in path.split('.'):
+ if name not in node:
+ node[name] = {}
+ elif not node[name]:
+ # Pre-existing empty node implies we already have this entire tree.
+ return
+ node = node[name]
+ # Remove any sub-trees we might have had.
+ node.clear()
+
+ def ToFieldMask(self, field_mask):
+ """Converts the tree to a FieldMask."""
+ field_mask.Clear()
+ _AddFieldPaths(self._root, '', field_mask)
+
+ def IntersectPath(self, path, intersection):
+ """Calculates the intersection part of a field path with this tree.
+
+ Args:
+ path: The field path to calculates.
+ intersection: The out tree to record the intersection part.
+ """
+ node = self._root
+ for name in path.split('.'):
+ if name not in node:
+ return
+ elif not node[name]:
+ intersection.AddPath(path)
+ return
+ node = node[name]
+ intersection.AddLeafNodes(path, node)
+
+ def AddLeafNodes(self, prefix, node):
+ """Adds leaf nodes begin with prefix to this tree."""
+ if not node:
+ self.AddPath(prefix)
+ for name in node:
+ child_path = prefix + '.' + name
+ self.AddLeafNodes(child_path, node[name])
+
+ def MergeMessage(
+ self, source, destination,
+ replace_message, replace_repeated):
+ """Merge all fields specified by this tree from source to destination."""
+ _MergeMessage(
+ self._root, source, destination, replace_message, replace_repeated)
+
+
+def _StrConvert(value):
+ """Converts value to str if it is not."""
+ # This file is imported by c extension and some methods like ClearField
+ # requires string for the field name. py2/py3 has different text
+ # type and may use unicode.
+ if not isinstance(value, str):
+ return value.encode('utf-8')
+ return value
+
+
+def _MergeMessage(
+ node, source, destination, replace_message, replace_repeated):
+ """Merge all fields specified by a sub-tree from source to destination."""
+ source_descriptor = source.DESCRIPTOR
+ for name in node:
+ child = node[name]
+ field = source_descriptor.fields_by_name[name]
+ if field is None:
+ raise ValueError('Error: Can\'t find field {0} in message {1}.'.format(
+ name, source_descriptor.full_name))
+ if child:
+ # Sub-paths are only allowed for singular message fields.
+ if (field.label == FieldDescriptor.LABEL_REPEATED or
+ field.cpp_type != FieldDescriptor.CPPTYPE_MESSAGE):
+ raise ValueError('Error: Field {0} in message {1} is not a singular '
+ 'message field and cannot have sub-fields.'.format(
+ name, source_descriptor.full_name))
+ _MergeMessage(
+ child, getattr(source, name), getattr(destination, name),
+ replace_message, replace_repeated)
+ continue
+ if field.label == FieldDescriptor.LABEL_REPEATED:
+ if replace_repeated:
+ destination.ClearField(_StrConvert(name))
+ repeated_source = getattr(source, name)
+ repeated_destination = getattr(destination, name)
+ if field.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE:
+ for item in repeated_source:
+ repeated_destination.add().MergeFrom(item)
+ else:
+ repeated_destination.extend(repeated_source)
+ else:
+ if field.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE:
+ if replace_message:
+ destination.ClearField(_StrConvert(name))
+ if source.HasField(name):
+ getattr(destination, name).MergeFrom(getattr(source, name))
+ else:
+ setattr(destination, name, getattr(source, name))
+
+
+def _AddFieldPaths(node, prefix, field_mask):
+ """Adds the field paths descended from node to field_mask."""
+ if not node:
+ field_mask.paths.append(prefix)
+ return
+ for name in sorted(node):
+ if prefix:
+ child_path = prefix + '.' + name
+ else:
+ child_path = name
+ _AddFieldPaths(node[name], child_path, field_mask)
+
+
+_INT_OR_FLOAT = six.integer_types + (float,)
+
+
+def _SetStructValue(struct_value, value):
+ if value is None:
+ struct_value.null_value = 0
+ elif isinstance(value, bool):
+ # Note: this check must come before the number check because in Python
+ # True and False are also considered numbers.
+ struct_value.bool_value = value
+ elif isinstance(value, six.string_types):
+ struct_value.string_value = value
+ elif isinstance(value, _INT_OR_FLOAT):
+ struct_value.number_value = value
+ else:
+ raise ValueError('Unexpected type')
+
+
+def _GetStructValue(struct_value):
+ which = struct_value.WhichOneof('kind')
+ if which == 'struct_value':
+ return struct_value.struct_value
+ elif which == 'null_value':
+ return None
+ elif which == 'number_value':
+ return struct_value.number_value
+ elif which == 'string_value':
+ return struct_value.string_value
+ elif which == 'bool_value':
+ return struct_value.bool_value
+ elif which == 'list_value':
+ return struct_value.list_value
+ elif which is None:
+ raise ValueError('Value not set')
+
+
+class Struct(object):
+ """Class for Struct message type."""
+
+ __slots__ = []
+
+ def __getitem__(self, key):
+ return _GetStructValue(self.fields[key])
+
+ def __setitem__(self, key, value):
+ _SetStructValue(self.fields[key], value)
+
+ def get_or_create_list(self, key):
+ """Returns a list for this key, creating if it didn't exist already."""
+ return self.fields[key].list_value
+
+ def get_or_create_struct(self, key):
+ """Returns a struct for this key, creating if it didn't exist already."""
+ return self.fields[key].struct_value
+
+ # TODO(haberman): allow constructing/merging from dict.
+
+
+class ListValue(object):
+ """Class for ListValue message type."""
+
+ def __len__(self):
+ return len(self.values)
+
+ def append(self, value):
+ _SetStructValue(self.values.add(), value)
+
+ def extend(self, elem_seq):
+ for value in elem_seq:
+ self.append(value)
+
+ def __getitem__(self, index):
+ """Retrieves item by the specified index."""
+ return _GetStructValue(self.values.__getitem__(index))
+
+ def __setitem__(self, index, value):
+ _SetStructValue(self.values.__getitem__(index), value)
+
+ def items(self):
+ for i in range(len(self)):
+ yield self[i]
+
+ def add_struct(self):
+ """Appends and returns a struct value as the next value in the list."""
+ return self.values.add().struct_value
+
+ def add_list(self):
+ """Appends and returns a list value as the next value in the list."""
+ return self.values.add().list_value
+
+
+WKTBASES = {
+ 'google.protobuf.Any': Any,
+ 'google.protobuf.Duration': Duration,
+ 'google.protobuf.FieldMask': FieldMask,
+ 'google.protobuf.ListValue': ListValue,
+ 'google.protobuf.Struct': Struct,
+ 'google.protobuf.Timestamp': Timestamp,
+}
diff --git a/python/google/protobuf/internal/well_known_types_test.py b/python/google/protobuf/internal/well_known_types_test.py
new file mode 100644
index 000000000..2f32ac994
--- /dev/null
+++ b/python/google/protobuf/internal/well_known_types_test.py
@@ -0,0 +1,644 @@
+#! /usr/bin/env python
+#
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# https://developers.google.com/protocol-buffers/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Test for google.protobuf.internal.well_known_types."""
+
+__author__ = 'jieluo@google.com (Jie Luo)'
+
+from datetime import datetime
+
+try:
+ import unittest2 as unittest #PY26
+except ImportError:
+ import unittest
+
+from google.protobuf import any_pb2
+from google.protobuf import duration_pb2
+from google.protobuf import field_mask_pb2
+from google.protobuf import struct_pb2
+from google.protobuf import timestamp_pb2
+from google.protobuf import unittest_pb2
+from google.protobuf.internal import any_test_pb2
+from google.protobuf.internal import test_util
+from google.protobuf.internal import well_known_types
+from google.protobuf import descriptor
+from google.protobuf import text_format
+
+
+class TimeUtilTestBase(unittest.TestCase):
+
+ def CheckTimestampConversion(self, message, text):
+ self.assertEqual(text, message.ToJsonString())
+ parsed_message = timestamp_pb2.Timestamp()
+ parsed_message.FromJsonString(text)
+ self.assertEqual(message, parsed_message)
+
+ def CheckDurationConversion(self, message, text):
+ self.assertEqual(text, message.ToJsonString())
+ parsed_message = duration_pb2.Duration()
+ parsed_message.FromJsonString(text)
+ self.assertEqual(message, parsed_message)
+
+
+class TimeUtilTest(TimeUtilTestBase):
+
+ def testTimestampSerializeAndParse(self):
+ message = timestamp_pb2.Timestamp()
+ # Generated output should contain 3, 6, or 9 fractional digits.
+ message.seconds = 0
+ message.nanos = 0
+ self.CheckTimestampConversion(message, '1970-01-01T00:00:00Z')
+ message.nanos = 10000000
+ self.CheckTimestampConversion(message, '1970-01-01T00:00:00.010Z')
+ message.nanos = 10000
+ self.CheckTimestampConversion(message, '1970-01-01T00:00:00.000010Z')
+ message.nanos = 10
+ self.CheckTimestampConversion(message, '1970-01-01T00:00:00.000000010Z')
+ # Test min timestamps.
+ message.seconds = -62135596800
+ message.nanos = 0
+ self.CheckTimestampConversion(message, '0001-01-01T00:00:00Z')
+ # Test max timestamps.
+ message.seconds = 253402300799
+ message.nanos = 999999999
+ self.CheckTimestampConversion(message, '9999-12-31T23:59:59.999999999Z')
+ # Test negative timestamps.
+ message.seconds = -1
+ self.CheckTimestampConversion(message, '1969-12-31T23:59:59.999999999Z')
+
+ # Parsing accepts an fractional digits as long as they fit into nano
+ # precision.
+ message.FromJsonString('1970-01-01T00:00:00.1Z')
+ self.assertEqual(0, message.seconds)
+ self.assertEqual(100000000, message.nanos)
+ # Parsing accpets offsets.
+ message.FromJsonString('1970-01-01T00:00:00-08:00')
+ self.assertEqual(8 * 3600, message.seconds)
+ self.assertEqual(0, message.nanos)
+
+ def testDurationSerializeAndParse(self):
+ message = duration_pb2.Duration()
+ # Generated output should contain 3, 6, or 9 fractional digits.
+ message.seconds = 0
+ message.nanos = 0
+ self.CheckDurationConversion(message, '0s')
+ message.nanos = 10000000
+ self.CheckDurationConversion(message, '0.010s')
+ message.nanos = 10000
+ self.CheckDurationConversion(message, '0.000010s')
+ message.nanos = 10
+ self.CheckDurationConversion(message, '0.000000010s')
+
+ # Test min and max
+ message.seconds = 315576000000
+ message.nanos = 999999999
+ self.CheckDurationConversion(message, '315576000000.999999999s')
+ message.seconds = -315576000000
+ message.nanos = -999999999
+ self.CheckDurationConversion(message, '-315576000000.999999999s')
+
+ # Parsing accepts an fractional digits as long as they fit into nano
+ # precision.
+ message.FromJsonString('0.1s')
+ self.assertEqual(100000000, message.nanos)
+ message.FromJsonString('0.0000001s')
+ self.assertEqual(100, message.nanos)
+
+ def testTimestampIntegerConversion(self):
+ message = timestamp_pb2.Timestamp()
+ message.FromNanoseconds(1)
+ self.assertEqual('1970-01-01T00:00:00.000000001Z',
+ message.ToJsonString())
+ self.assertEqual(1, message.ToNanoseconds())
+
+ message.FromNanoseconds(-1)
+ self.assertEqual('1969-12-31T23:59:59.999999999Z',
+ message.ToJsonString())
+ self.assertEqual(-1, message.ToNanoseconds())
+
+ message.FromMicroseconds(1)
+ self.assertEqual('1970-01-01T00:00:00.000001Z',
+ message.ToJsonString())
+ self.assertEqual(1, message.ToMicroseconds())
+
+ message.FromMicroseconds(-1)
+ self.assertEqual('1969-12-31T23:59:59.999999Z',
+ message.ToJsonString())
+ self.assertEqual(-1, message.ToMicroseconds())
+
+ message.FromMilliseconds(1)
+ self.assertEqual('1970-01-01T00:00:00.001Z',
+ message.ToJsonString())
+ self.assertEqual(1, message.ToMilliseconds())
+
+ message.FromMilliseconds(-1)
+ self.assertEqual('1969-12-31T23:59:59.999Z',
+ message.ToJsonString())
+ self.assertEqual(-1, message.ToMilliseconds())
+
+ message.FromSeconds(1)
+ self.assertEqual('1970-01-01T00:00:01Z',
+ message.ToJsonString())
+ self.assertEqual(1, message.ToSeconds())
+
+ message.FromSeconds(-1)
+ self.assertEqual('1969-12-31T23:59:59Z',
+ message.ToJsonString())
+ self.assertEqual(-1, message.ToSeconds())
+
+ message.FromNanoseconds(1999)
+ self.assertEqual(1, message.ToMicroseconds())
+ # For negative values, Timestamp will be rounded down.
+ # For example, "1969-12-31T23:59:59.5Z" (i.e., -0.5s) rounded to seconds
+ # will be "1969-12-31T23:59:59Z" (i.e., -1s) rather than
+ # "1970-01-01T00:00:00Z" (i.e., 0s).
+ message.FromNanoseconds(-1999)
+ self.assertEqual(-2, message.ToMicroseconds())
+
+ def testDurationIntegerConversion(self):
+ message = duration_pb2.Duration()
+ message.FromNanoseconds(1)
+ self.assertEqual('0.000000001s',
+ message.ToJsonString())
+ self.assertEqual(1, message.ToNanoseconds())
+
+ message.FromNanoseconds(-1)
+ self.assertEqual('-0.000000001s',
+ message.ToJsonString())
+ self.assertEqual(-1, message.ToNanoseconds())
+
+ message.FromMicroseconds(1)
+ self.assertEqual('0.000001s',
+ message.ToJsonString())
+ self.assertEqual(1, message.ToMicroseconds())
+
+ message.FromMicroseconds(-1)
+ self.assertEqual('-0.000001s',
+ message.ToJsonString())
+ self.assertEqual(-1, message.ToMicroseconds())
+
+ message.FromMilliseconds(1)
+ self.assertEqual('0.001s',
+ message.ToJsonString())
+ self.assertEqual(1, message.ToMilliseconds())
+
+ message.FromMilliseconds(-1)
+ self.assertEqual('-0.001s',
+ message.ToJsonString())
+ self.assertEqual(-1, message.ToMilliseconds())
+
+ message.FromSeconds(1)
+ self.assertEqual('1s', message.ToJsonString())
+ self.assertEqual(1, message.ToSeconds())
+
+ message.FromSeconds(-1)
+ self.assertEqual('-1s',
+ message.ToJsonString())
+ self.assertEqual(-1, message.ToSeconds())
+
+ # Test truncation behavior.
+ message.FromNanoseconds(1999)
+ self.assertEqual(1, message.ToMicroseconds())
+
+ # For negative values, Duration will be rounded towards 0.
+ message.FromNanoseconds(-1999)
+ self.assertEqual(-1, message.ToMicroseconds())
+
+ def testDatetimeConverison(self):
+ message = timestamp_pb2.Timestamp()
+ dt = datetime(1970, 1, 1)
+ message.FromDatetime(dt)
+ self.assertEqual(dt, message.ToDatetime())
+
+ message.FromMilliseconds(1999)
+ self.assertEqual(datetime(1970, 1, 1, 0, 0, 1, 999000),
+ message.ToDatetime())
+
+ def testTimedeltaConversion(self):
+ message = duration_pb2.Duration()
+ message.FromNanoseconds(1999999999)
+ td = message.ToTimedelta()
+ self.assertEqual(1, td.seconds)
+ self.assertEqual(999999, td.microseconds)
+
+ message.FromNanoseconds(-1999999999)
+ td = message.ToTimedelta()
+ self.assertEqual(-1, td.days)
+ self.assertEqual(86398, td.seconds)
+ self.assertEqual(1, td.microseconds)
+
+ message.FromMicroseconds(-1)
+ td = message.ToTimedelta()
+ self.assertEqual(-1, td.days)
+ self.assertEqual(86399, td.seconds)
+ self.assertEqual(999999, td.microseconds)
+ converted_message = duration_pb2.Duration()
+ converted_message.FromTimedelta(td)
+ self.assertEqual(message, converted_message)
+
+ def testInvalidTimestamp(self):
+ message = timestamp_pb2.Timestamp()
+ self.assertRaisesRegexp(
+ ValueError,
+ 'time data \'10000-01-01T00:00:00\' does not match'
+ ' format \'%Y-%m-%dT%H:%M:%S\'',
+ message.FromJsonString, '10000-01-01T00:00:00.00Z')
+ self.assertRaisesRegexp(
+ well_known_types.ParseError,
+ 'nanos 0123456789012 more than 9 fractional digits.',
+ message.FromJsonString,
+ '1970-01-01T00:00:00.0123456789012Z')
+ self.assertRaisesRegexp(
+ well_known_types.ParseError,
+ (r'Invalid timezone offset value: \+08.'),
+ message.FromJsonString,
+ '1972-01-01T01:00:00.01+08',)
+ self.assertRaisesRegexp(
+ ValueError,
+ 'year is out of range',
+ message.FromJsonString,
+ '0000-01-01T00:00:00Z')
+ message.seconds = 253402300800
+ self.assertRaisesRegexp(
+ OverflowError,
+ 'date value out of range',
+ message.ToJsonString)
+
+ def testInvalidDuration(self):
+ message = duration_pb2.Duration()
+ self.assertRaisesRegexp(
+ well_known_types.ParseError,
+ 'Duration must end with letter "s": 1.',
+ message.FromJsonString, '1')
+ self.assertRaisesRegexp(
+ well_known_types.ParseError,
+ 'Couldn\'t parse duration: 1...2s.',
+ message.FromJsonString, '1...2s')
+
+
+class FieldMaskTest(unittest.TestCase):
+
+ def testStringFormat(self):
+ mask = field_mask_pb2.FieldMask()
+ self.assertEqual('', mask.ToJsonString())
+ mask.paths.append('foo')
+ self.assertEqual('foo', mask.ToJsonString())
+ mask.paths.append('bar')
+ self.assertEqual('foo,bar', mask.ToJsonString())
+
+ mask.FromJsonString('')
+ self.assertEqual('', mask.ToJsonString())
+ mask.FromJsonString('foo')
+ self.assertEqual(['foo'], mask.paths)
+ mask.FromJsonString('foo,bar')
+ self.assertEqual(['foo', 'bar'], mask.paths)
+
+ def testDescriptorToFieldMask(self):
+ mask = field_mask_pb2.FieldMask()
+ msg_descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
+ mask.AllFieldsFromDescriptor(msg_descriptor)
+ self.assertEqual(75, len(mask.paths))
+ self.assertTrue(mask.IsValidForDescriptor(msg_descriptor))
+ for field in msg_descriptor.fields:
+ self.assertTrue(field.name in mask.paths)
+ mask.paths.append('optional_nested_message.bb')
+ self.assertTrue(mask.IsValidForDescriptor(msg_descriptor))
+ mask.paths.append('repeated_nested_message.bb')
+ self.assertFalse(mask.IsValidForDescriptor(msg_descriptor))
+
+ def testCanonicalFrom(self):
+ mask = field_mask_pb2.FieldMask()
+ out_mask = field_mask_pb2.FieldMask()
+ # Paths will be sorted.
+ mask.FromJsonString('baz.quz,bar,foo')
+ out_mask.CanonicalFormFromMask(mask)
+ self.assertEqual('bar,baz.quz,foo', out_mask.ToJsonString())
+ # Duplicated paths will be removed.
+ mask.FromJsonString('foo,bar,foo')
+ out_mask.CanonicalFormFromMask(mask)
+ self.assertEqual('bar,foo', out_mask.ToJsonString())
+ # Sub-paths of other paths will be removed.
+ mask.FromJsonString('foo.b1,bar.b1,foo.b2,bar')
+ out_mask.CanonicalFormFromMask(mask)
+ self.assertEqual('bar,foo.b1,foo.b2', out_mask.ToJsonString())
+
+ # Test more deeply nested cases.
+ mask.FromJsonString(
+ 'foo.bar.baz1,foo.bar.baz2.quz,foo.bar.baz2')
+ out_mask.CanonicalFormFromMask(mask)
+ self.assertEqual('foo.bar.baz1,foo.bar.baz2',
+ out_mask.ToJsonString())
+ mask.FromJsonString(
+ 'foo.bar.baz1,foo.bar.baz2,foo.bar.baz2.quz')
+ out_mask.CanonicalFormFromMask(mask)
+ self.assertEqual('foo.bar.baz1,foo.bar.baz2',
+ out_mask.ToJsonString())
+ mask.FromJsonString(
+ 'foo.bar.baz1,foo.bar.baz2,foo.bar.baz2.quz,foo.bar')
+ out_mask.CanonicalFormFromMask(mask)
+ self.assertEqual('foo.bar', out_mask.ToJsonString())
+ mask.FromJsonString(
+ 'foo.bar.baz1,foo.bar.baz2,foo.bar.baz2.quz,foo')
+ out_mask.CanonicalFormFromMask(mask)
+ self.assertEqual('foo', out_mask.ToJsonString())
+
+ def testUnion(self):
+ mask1 = field_mask_pb2.FieldMask()
+ mask2 = field_mask_pb2.FieldMask()
+ out_mask = field_mask_pb2.FieldMask()
+ mask1.FromJsonString('foo,baz')
+ mask2.FromJsonString('bar,quz')
+ out_mask.Union(mask1, mask2)
+ self.assertEqual('bar,baz,foo,quz', out_mask.ToJsonString())
+ # Overlap with duplicated paths.
+ mask1.FromJsonString('foo,baz.bb')
+ mask2.FromJsonString('baz.bb,quz')
+ out_mask.Union(mask1, mask2)
+ self.assertEqual('baz.bb,foo,quz', out_mask.ToJsonString())
+ # Overlap with paths covering some other paths.
+ mask1.FromJsonString('foo.bar.baz,quz')
+ mask2.FromJsonString('foo.bar,bar')
+ out_mask.Union(mask1, mask2)
+ self.assertEqual('bar,foo.bar,quz', out_mask.ToJsonString())
+
+ def testIntersect(self):
+ mask1 = field_mask_pb2.FieldMask()
+ mask2 = field_mask_pb2.FieldMask()
+ out_mask = field_mask_pb2.FieldMask()
+ # Test cases without overlapping.
+ mask1.FromJsonString('foo,baz')
+ mask2.FromJsonString('bar,quz')
+ out_mask.Intersect(mask1, mask2)
+ self.assertEqual('', out_mask.ToJsonString())
+ # Overlap with duplicated paths.
+ mask1.FromJsonString('foo,baz.bb')
+ mask2.FromJsonString('baz.bb,quz')
+ out_mask.Intersect(mask1, mask2)
+ self.assertEqual('baz.bb', out_mask.ToJsonString())
+ # Overlap with paths covering some other paths.
+ mask1.FromJsonString('foo.bar.baz,quz')
+ mask2.FromJsonString('foo.bar,bar')
+ out_mask.Intersect(mask1, mask2)
+ self.assertEqual('foo.bar.baz', out_mask.ToJsonString())
+ mask1.FromJsonString('foo.bar,bar')
+ mask2.FromJsonString('foo.bar.baz,quz')
+ out_mask.Intersect(mask1, mask2)
+ self.assertEqual('foo.bar.baz', out_mask.ToJsonString())
+
+ def testMergeMessage(self):
+ # Test merge one field.
+ src = unittest_pb2.TestAllTypes()
+ test_util.SetAllFields(src)
+ for field in src.DESCRIPTOR.fields:
+ if field.containing_oneof:
+ continue
+ field_name = field.name
+ dst = unittest_pb2.TestAllTypes()
+ # Only set one path to mask.
+ mask = field_mask_pb2.FieldMask()
+ mask.paths.append(field_name)
+ mask.MergeMessage(src, dst)
+ # The expected result message.
+ msg = unittest_pb2.TestAllTypes()
+ if field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
+ repeated_src = getattr(src, field_name)
+ repeated_msg = getattr(msg, field_name)
+ if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
+ for item in repeated_src:
+ repeated_msg.add().CopyFrom(item)
+ else:
+ repeated_msg.extend(repeated_src)
+ elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
+ getattr(msg, field_name).CopyFrom(getattr(src, field_name))
+ else:
+ setattr(msg, field_name, getattr(src, field_name))
+ # Only field specified in mask is merged.
+ self.assertEqual(msg, dst)
+
+ # Test merge nested fields.
+ nested_src = unittest_pb2.NestedTestAllTypes()
+ nested_dst = unittest_pb2.NestedTestAllTypes()
+ nested_src.child.payload.optional_int32 = 1234
+ nested_src.child.child.payload.optional_int32 = 5678
+ mask = field_mask_pb2.FieldMask()
+ mask.FromJsonString('child.payload')
+ mask.MergeMessage(nested_src, nested_dst)
+ self.assertEqual(1234, nested_dst.child.payload.optional_int32)
+ self.assertEqual(0, nested_dst.child.child.payload.optional_int32)
+
+ mask.FromJsonString('child.child.payload')
+ mask.MergeMessage(nested_src, nested_dst)
+ self.assertEqual(1234, nested_dst.child.payload.optional_int32)
+ self.assertEqual(5678, nested_dst.child.child.payload.optional_int32)
+
+ nested_dst.Clear()
+ mask.FromJsonString('child.child.payload')
+ mask.MergeMessage(nested_src, nested_dst)
+ self.assertEqual(0, nested_dst.child.payload.optional_int32)
+ self.assertEqual(5678, nested_dst.child.child.payload.optional_int32)
+
+ nested_dst.Clear()
+ mask.FromJsonString('child')
+ mask.MergeMessage(nested_src, nested_dst)
+ self.assertEqual(1234, nested_dst.child.payload.optional_int32)
+ self.assertEqual(5678, nested_dst.child.child.payload.optional_int32)
+
+ # Test MergeOptions.
+ nested_dst.Clear()
+ nested_dst.child.payload.optional_int64 = 4321
+ # Message fields will be merged by default.
+ mask.FromJsonString('child.payload')
+ mask.MergeMessage(nested_src, nested_dst)
+ self.assertEqual(1234, nested_dst.child.payload.optional_int32)
+ self.assertEqual(4321, nested_dst.child.payload.optional_int64)
+ # Change the behavior to replace message fields.
+ mask.FromJsonString('child.payload')
+ mask.MergeMessage(nested_src, nested_dst, True, False)
+ self.assertEqual(1234, nested_dst.child.payload.optional_int32)
+ self.assertEqual(0, nested_dst.child.payload.optional_int64)
+
+ # By default, fields missing in source are not cleared in destination.
+ nested_dst.payload.optional_int32 = 1234
+ self.assertTrue(nested_dst.HasField('payload'))
+ mask.FromJsonString('payload')
+ mask.MergeMessage(nested_src, nested_dst)
+ self.assertTrue(nested_dst.HasField('payload'))
+ # But they are cleared when replacing message fields.
+ nested_dst.Clear()
+ nested_dst.payload.optional_int32 = 1234
+ mask.FromJsonString('payload')
+ mask.MergeMessage(nested_src, nested_dst, True, False)
+ self.assertFalse(nested_dst.HasField('payload'))
+
+ nested_src.payload.repeated_int32.append(1234)
+ nested_dst.payload.repeated_int32.append(5678)
+ # Repeated fields will be appended by default.
+ mask.FromJsonString('payload.repeated_int32')
+ mask.MergeMessage(nested_src, nested_dst)
+ self.assertEqual(2, len(nested_dst.payload.repeated_int32))
+ self.assertEqual(5678, nested_dst.payload.repeated_int32[0])
+ self.assertEqual(1234, nested_dst.payload.repeated_int32[1])
+ # Change the behavior to replace repeated fields.
+ mask.FromJsonString('payload.repeated_int32')
+ mask.MergeMessage(nested_src, nested_dst, False, True)
+ self.assertEqual(1, len(nested_dst.payload.repeated_int32))
+ self.assertEqual(1234, nested_dst.payload.repeated_int32[0])
+
+
+class StructTest(unittest.TestCase):
+
+ def testStruct(self):
+ struct = struct_pb2.Struct()
+ struct_class = struct.__class__
+
+ struct['key1'] = 5
+ struct['key2'] = 'abc'
+ struct['key3'] = True
+ struct.get_or_create_struct('key4')['subkey'] = 11.0
+ struct_list = struct.get_or_create_list('key5')
+ struct_list.extend([6, 'seven', True, False, None])
+ struct_list.add_struct()['subkey2'] = 9
+
+ self.assertTrue(isinstance(struct, well_known_types.Struct))
+ self.assertEquals(5, struct['key1'])
+ self.assertEquals('abc', struct['key2'])
+ self.assertIs(True, struct['key3'])
+ self.assertEquals(11, struct['key4']['subkey'])
+ inner_struct = struct_class()
+ inner_struct['subkey2'] = 9
+ self.assertEquals([6, 'seven', True, False, None, inner_struct],
+ list(struct['key5'].items()))
+
+ serialized = struct.SerializeToString()
+
+ struct2 = struct_pb2.Struct()
+ struct2.ParseFromString(serialized)
+
+ self.assertEquals(struct, struct2)
+
+ self.assertTrue(isinstance(struct2, well_known_types.Struct))
+ self.assertEquals(5, struct2['key1'])
+ self.assertEquals('abc', struct2['key2'])
+ self.assertIs(True, struct2['key3'])
+ self.assertEquals(11, struct2['key4']['subkey'])
+ self.assertEquals([6, 'seven', True, False, None, inner_struct],
+ list(struct2['key5'].items()))
+
+ struct_list = struct2['key5']
+ self.assertEquals(6, struct_list[0])
+ self.assertEquals('seven', struct_list[1])
+ self.assertEquals(True, struct_list[2])
+ self.assertEquals(False, struct_list[3])
+ self.assertEquals(None, struct_list[4])
+ self.assertEquals(inner_struct, struct_list[5])
+
+ struct_list[1] = 7
+ self.assertEquals(7, struct_list[1])
+
+ struct_list.add_list().extend([1, 'two', True, False, None])
+ self.assertEquals([1, 'two', True, False, None],
+ list(struct_list[6].items()))
+
+ text_serialized = str(struct)
+ struct3 = struct_pb2.Struct()
+ text_format.Merge(text_serialized, struct3)
+ self.assertEquals(struct, struct3)
+
+ struct.get_or_create_struct('key3')['replace'] = 12
+ self.assertEquals(12, struct['key3']['replace'])
+
+
+class AnyTest(unittest.TestCase):
+
+ def testAnyMessage(self):
+ # Creates and sets message.
+ msg = any_test_pb2.TestAny()
+ msg_descriptor = msg.DESCRIPTOR
+ all_types = unittest_pb2.TestAllTypes()
+ all_descriptor = all_types.DESCRIPTOR
+ all_types.repeated_string.append(u'\u00fc\ua71f')
+ # Packs to Any.
+ msg.value.Pack(all_types)
+ self.assertEqual(msg.value.type_url,
+ 'type.googleapis.com/%s' % all_descriptor.full_name)
+ self.assertEqual(msg.value.value,
+ all_types.SerializeToString())
+ # Tests Is() method.
+ self.assertTrue(msg.value.Is(all_descriptor))
+ self.assertFalse(msg.value.Is(msg_descriptor))
+ # Unpacks Any.
+ unpacked_message = unittest_pb2.TestAllTypes()
+ self.assertTrue(msg.value.Unpack(unpacked_message))
+ self.assertEqual(all_types, unpacked_message)
+ # Unpacks to different type.
+ self.assertFalse(msg.value.Unpack(msg))
+ # Only Any messages have Pack method.
+ try:
+ msg.Pack(all_types)
+ except AttributeError:
+ pass
+ else:
+ raise AttributeError('%s should not have Pack method.' %
+ msg_descriptor.full_name)
+
+ def testMessageName(self):
+ # Creates and sets message.
+ submessage = any_test_pb2.TestAny()
+ submessage.int_value = 12345
+ msg = any_pb2.Any()
+ msg.Pack(submessage)
+ self.assertEqual(msg.TypeName(), 'google.protobuf.internal.TestAny')
+
+ def testPackWithCustomTypeUrl(self):
+ submessage = any_test_pb2.TestAny()
+ submessage.int_value = 12345
+ msg = any_pb2.Any()
+ # Pack with a custom type URL prefix.
+ msg.Pack(submessage, 'type.myservice.com')
+ self.assertEqual(msg.type_url,
+ 'type.myservice.com/%s' % submessage.DESCRIPTOR.full_name)
+ # Pack with a custom type URL prefix ending with '/'.
+ msg.Pack(submessage, 'type.myservice.com/')
+ self.assertEqual(msg.type_url,
+ 'type.myservice.com/%s' % submessage.DESCRIPTOR.full_name)
+ # Pack with an empty type URL prefix.
+ msg.Pack(submessage, '')
+ self.assertEqual(msg.type_url,
+ '/%s' % submessage.DESCRIPTOR.full_name)
+ # Test unpacking the type.
+ unpacked_message = any_test_pb2.TestAny()
+ self.assertTrue(msg.Unpack(unpacked_message))
+ self.assertEqual(submessage, unpacked_message)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/python/google/protobuf/internal/wire_format_test.py b/python/google/protobuf/internal/wire_format_test.py
index f39035cae..da120f336 100755
--- a/python/google/protobuf/internal/wire_format_test.py
+++ b/python/google/protobuf/internal/wire_format_test.py
@@ -1,4 +1,4 @@
-#! /usr/bin/python
+#! /usr/bin/env python
#
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
@@ -34,12 +34,16 @@
__author__ = 'robinson@google.com (Will Robinson)'
-from google.apputils import basetest
+try:
+ import unittest2 as unittest #PY26
+except ImportError:
+ import unittest
+
from google.protobuf import message
from google.protobuf.internal import wire_format
-class WireFormatTest(basetest.TestCase):
+class WireFormatTest(unittest.TestCase):
def testPackTag(self):
field_number = 0xabc
@@ -250,4 +254,4 @@ class WireFormatTest(basetest.TestCase):
if __name__ == '__main__':
- basetest.main()
+ unittest.main()
diff --git a/python/google/protobuf/json_format.py b/python/google/protobuf/json_format.py
new file mode 100644
index 000000000..7921556e6
--- /dev/null
+++ b/python/google/protobuf/json_format.py
@@ -0,0 +1,645 @@
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# https://developers.google.com/protocol-buffers/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Contains routines for printing protocol messages in JSON format.
+
+Simple usage example:
+
+ # Create a proto object and serialize it to a json format string.
+ message = my_proto_pb2.MyMessage(foo='bar')
+ json_string = json_format.MessageToJson(message)
+
+ # Parse a json format string to proto object.
+ message = json_format.Parse(json_string, my_proto_pb2.MyMessage())
+"""
+
+__author__ = 'jieluo@google.com (Jie Luo)'
+
+import base64
+import json
+import math
+import six
+import sys
+
+from google.protobuf import descriptor
+from google.protobuf import symbol_database
+
+_TIMESTAMPFOMAT = '%Y-%m-%dT%H:%M:%S'
+_INT_TYPES = frozenset([descriptor.FieldDescriptor.CPPTYPE_INT32,
+ descriptor.FieldDescriptor.CPPTYPE_UINT32,
+ descriptor.FieldDescriptor.CPPTYPE_INT64,
+ descriptor.FieldDescriptor.CPPTYPE_UINT64])
+_INT64_TYPES = frozenset([descriptor.FieldDescriptor.CPPTYPE_INT64,
+ descriptor.FieldDescriptor.CPPTYPE_UINT64])
+_FLOAT_TYPES = frozenset([descriptor.FieldDescriptor.CPPTYPE_FLOAT,
+ descriptor.FieldDescriptor.CPPTYPE_DOUBLE])
+_INFINITY = 'Infinity'
+_NEG_INFINITY = '-Infinity'
+_NAN = 'NaN'
+
+
+class Error(Exception):
+ """Top-level module error for json_format."""
+
+
+class SerializeToJsonError(Error):
+ """Thrown if serialization to JSON fails."""
+
+
+class ParseError(Error):
+ """Thrown in case of parsing error."""
+
+
+def MessageToJson(message, including_default_value_fields=False):
+ """Converts protobuf message to JSON format.
+
+ Args:
+ message: The protocol buffers message instance to serialize.
+ including_default_value_fields: If True, singular primitive fields,
+ repeated fields, and map fields will always be serialized. If
+ False, only serialize non-empty fields. Singular message fields
+ and oneof fields are not affected by this option.
+
+ Returns:
+ A string containing the JSON formatted protocol buffer message.
+ """
+ js = _MessageToJsonObject(message, including_default_value_fields)
+ return json.dumps(js, indent=2)
+
+
+def _MessageToJsonObject(message, including_default_value_fields):
+ """Converts message to an object according to Proto3 JSON Specification."""
+ message_descriptor = message.DESCRIPTOR
+ full_name = message_descriptor.full_name
+ if _IsWrapperMessage(message_descriptor):
+ return _WrapperMessageToJsonObject(message)
+ if full_name in _WKTJSONMETHODS:
+ return _WKTJSONMETHODS[full_name][0](
+ message, including_default_value_fields)
+ js = {}
+ return _RegularMessageToJsonObject(
+ message, js, including_default_value_fields)
+
+
+def _IsMapEntry(field):
+ return (field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and
+ field.message_type.has_options and
+ field.message_type.GetOptions().map_entry)
+
+
+def _RegularMessageToJsonObject(message, js, including_default_value_fields):
+ """Converts normal message according to Proto3 JSON Specification."""
+ fields = message.ListFields()
+ include_default = including_default_value_fields
+
+ try:
+ for field, value in fields:
+ name = field.camelcase_name
+ if _IsMapEntry(field):
+ # Convert a map field.
+ v_field = field.message_type.fields_by_name['value']
+ js_map = {}
+ for key in value:
+ if isinstance(key, bool):
+ if key:
+ recorded_key = 'true'
+ else:
+ recorded_key = 'false'
+ else:
+ recorded_key = key
+ js_map[recorded_key] = _FieldToJsonObject(
+ v_field, value[key], including_default_value_fields)
+ js[name] = js_map
+ elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
+ # Convert a repeated field.
+ js[name] = [_FieldToJsonObject(field, k, include_default)
+ for k in value]
+ else:
+ js[name] = _FieldToJsonObject(field, value, include_default)
+
+ # Serialize default value if including_default_value_fields is True.
+ if including_default_value_fields:
+ message_descriptor = message.DESCRIPTOR
+ for field in message_descriptor.fields:
+ # Singular message fields and oneof fields will not be affected.
+ if ((field.label != descriptor.FieldDescriptor.LABEL_REPEATED and
+ field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE) or
+ field.containing_oneof):
+ continue
+ name = field.camelcase_name
+ if name in js:
+ # Skip the field which has been serailized already.
+ continue
+ if _IsMapEntry(field):
+ js[name] = {}
+ elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
+ js[name] = []
+ else:
+ js[name] = _FieldToJsonObject(field, field.default_value)
+
+ except ValueError as e:
+ raise SerializeToJsonError(
+ 'Failed to serialize {0} field: {1}.'.format(field.name, e))
+
+ return js
+
+
+def _FieldToJsonObject(
+ field, value, including_default_value_fields=False):
+ """Converts field value according to Proto3 JSON Specification."""
+ if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
+ return _MessageToJsonObject(value, including_default_value_fields)
+ elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_ENUM:
+ enum_value = field.enum_type.values_by_number.get(value, None)
+ if enum_value is not None:
+ return enum_value.name
+ else:
+ raise SerializeToJsonError('Enum field contains an integer value '
+ 'which can not mapped to an enum value.')
+ elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_STRING:
+ if field.type == descriptor.FieldDescriptor.TYPE_BYTES:
+ # Use base64 Data encoding for bytes
+ return base64.b64encode(value).decode('utf-8')
+ else:
+ return value
+ elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_BOOL:
+ return bool(value)
+ elif field.cpp_type in _INT64_TYPES:
+ return str(value)
+ elif field.cpp_type in _FLOAT_TYPES:
+ if math.isinf(value):
+ if value < 0.0:
+ return _NEG_INFINITY
+ else:
+ return _INFINITY
+ if math.isnan(value):
+ return _NAN
+ return value
+
+
+def _AnyMessageToJsonObject(message, including_default):
+ """Converts Any message according to Proto3 JSON Specification."""
+ if not message.ListFields():
+ return {}
+ js = {}
+ type_url = message.type_url
+ js['@type'] = type_url
+ sub_message = _CreateMessageFromTypeUrl(type_url)
+ sub_message.ParseFromString(message.value)
+ message_descriptor = sub_message.DESCRIPTOR
+ full_name = message_descriptor.full_name
+ if _IsWrapperMessage(message_descriptor):
+ js['value'] = _WrapperMessageToJsonObject(sub_message)
+ return js
+ if full_name in _WKTJSONMETHODS:
+ js['value'] = _WKTJSONMETHODS[full_name][0](sub_message, including_default)
+ return js
+ return _RegularMessageToJsonObject(sub_message, js, including_default)
+
+
+def _CreateMessageFromTypeUrl(type_url):
+ # TODO(jieluo): Should add a way that users can register the type resolver
+ # instead of the default one.
+ db = symbol_database.Default()
+ type_name = type_url.split('/')[-1]
+ try:
+ message_descriptor = db.pool.FindMessageTypeByName(type_name)
+ except KeyError:
+ raise TypeError(
+ 'Can not find message descriptor by type_url: {0}.'.format(type_url))
+ message_class = db.GetPrototype(message_descriptor)
+ return message_class()
+
+
+def _GenericMessageToJsonObject(message, unused_including_default):
+ """Converts message by ToJsonString according to Proto3 JSON Specification."""
+ # Duration, Timestamp and FieldMask have ToJsonString method to do the
+ # convert. Users can also call the method directly.
+ return message.ToJsonString()
+
+
+def _ValueMessageToJsonObject(message, unused_including_default=False):
+ """Converts Value message according to Proto3 JSON Specification."""
+ which = message.WhichOneof('kind')
+ # If the Value message is not set treat as null_value when serialize
+ # to JSON. The parse back result will be different from original message.
+ if which is None or which == 'null_value':
+ return None
+ if which == 'list_value':
+ return _ListValueMessageToJsonObject(message.list_value)
+ if which == 'struct_value':
+ value = message.struct_value
+ else:
+ value = getattr(message, which)
+ oneof_descriptor = message.DESCRIPTOR.fields_by_name[which]
+ return _FieldToJsonObject(oneof_descriptor, value)
+
+
+def _ListValueMessageToJsonObject(message, unused_including_default=False):
+ """Converts ListValue message according to Proto3 JSON Specification."""
+ return [_ValueMessageToJsonObject(value)
+ for value in message.values]
+
+
+def _StructMessageToJsonObject(message, unused_including_default=False):
+ """Converts Struct message according to Proto3 JSON Specification."""
+ fields = message.fields
+ ret = {}
+ for key in fields:
+ ret[key] = _ValueMessageToJsonObject(fields[key])
+ return ret
+
+
+def _IsWrapperMessage(message_descriptor):
+ return message_descriptor.file.name == 'google/protobuf/wrappers.proto'
+
+
+def _WrapperMessageToJsonObject(message):
+ return _FieldToJsonObject(
+ message.DESCRIPTOR.fields_by_name['value'], message.value)
+
+
+def _DuplicateChecker(js):
+ result = {}
+ for name, value in js:
+ if name in result:
+ raise ParseError('Failed to load JSON: duplicate key {0}.'.format(name))
+ result[name] = value
+ return result
+
+
+def Parse(text, message):
+ """Parses a JSON representation of a protocol message into a message.
+
+ Args:
+ text: Message JSON representation.
+ message: A protocol beffer message to merge into.
+
+ Returns:
+ The same message passed as argument.
+
+ Raises::
+ ParseError: On JSON parsing problems.
+ """
+ if not isinstance(text, six.text_type): text = text.decode('utf-8')
+ try:
+ if sys.version_info < (2, 7):
+ # object_pair_hook is not supported before python2.7
+ js = json.loads(text)
+ else:
+ js = json.loads(text, object_pairs_hook=_DuplicateChecker)
+ except ValueError as e:
+ raise ParseError('Failed to load JSON: {0}.'.format(str(e)))
+ _ConvertMessage(js, message)
+ return message
+
+
+def _ConvertFieldValuePair(js, message):
+ """Convert field value pairs into regular message.
+
+ Args:
+ js: A JSON object to convert the field value pairs.
+ message: A regular protocol message to record the data.
+
+ Raises:
+ ParseError: In case of problems converting.
+ """
+ names = []
+ message_descriptor = message.DESCRIPTOR
+ for name in js:
+ try:
+ field = message_descriptor.fields_by_camelcase_name.get(name, None)
+ if not field:
+ raise ParseError(
+ 'Message type "{0}" has no field named "{1}".'.format(
+ message_descriptor.full_name, name))
+ if name in names:
+ raise ParseError(
+ 'Message type "{0}" should not have multiple "{1}" fields.'.format(
+ message.DESCRIPTOR.full_name, name))
+ names.append(name)
+ # Check no other oneof field is parsed.
+ if field.containing_oneof is not None:
+ oneof_name = field.containing_oneof.name
+ if oneof_name in names:
+ raise ParseError('Message type "{0}" should not have multiple "{1}" '
+ 'oneof fields.'.format(
+ message.DESCRIPTOR.full_name, oneof_name))
+ names.append(oneof_name)
+
+ value = js[name]
+ if value is None:
+ message.ClearField(field.name)
+ continue
+
+ # Parse field value.
+ if _IsMapEntry(field):
+ message.ClearField(field.name)
+ _ConvertMapFieldValue(value, message, field)
+ elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
+ message.ClearField(field.name)
+ if not isinstance(value, list):
+ raise ParseError('repeated field {0} must be in [] which is '
+ '{1}.'.format(name, value))
+ if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
+ # Repeated message field.
+ for item in value:
+ sub_message = getattr(message, field.name).add()
+ # None is a null_value in Value.
+ if (item is None and
+ sub_message.DESCRIPTOR.full_name != 'google.protobuf.Value'):
+ raise ParseError('null is not allowed to be used as an element'
+ ' in a repeated field.')
+ _ConvertMessage(item, sub_message)
+ else:
+ # Repeated scalar field.
+ for item in value:
+ if item is None:
+ raise ParseError('null is not allowed to be used as an element'
+ ' in a repeated field.')
+ getattr(message, field.name).append(
+ _ConvertScalarFieldValue(item, field))
+ elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
+ sub_message = getattr(message, field.name)
+ _ConvertMessage(value, sub_message)
+ else:
+ setattr(message, field.name, _ConvertScalarFieldValue(value, field))
+ except ParseError as e:
+ if field and field.containing_oneof is None:
+ raise ParseError('Failed to parse {0} field: {1}'.format(name, e))
+ else:
+ raise ParseError(str(e))
+ except ValueError as e:
+ raise ParseError('Failed to parse {0} field: {1}.'.format(name, e))
+ except TypeError as e:
+ raise ParseError('Failed to parse {0} field: {1}.'.format(name, e))
+
+
+def _ConvertMessage(value, message):
+ """Convert a JSON object into a message.
+
+ Args:
+ value: A JSON object.
+ message: A WKT or regular protocol message to record the data.
+
+ Raises:
+ ParseError: In case of convert problems.
+ """
+ message_descriptor = message.DESCRIPTOR
+ full_name = message_descriptor.full_name
+ if _IsWrapperMessage(message_descriptor):
+ _ConvertWrapperMessage(value, message)
+ elif full_name in _WKTJSONMETHODS:
+ _WKTJSONMETHODS[full_name][1](value, message)
+ else:
+ _ConvertFieldValuePair(value, message)
+
+
+def _ConvertAnyMessage(value, message):
+ """Convert a JSON representation into Any message."""
+ if isinstance(value, dict) and not value:
+ return
+ try:
+ type_url = value['@type']
+ except KeyError:
+ raise ParseError('@type is missing when parsing any message.')
+
+ sub_message = _CreateMessageFromTypeUrl(type_url)
+ message_descriptor = sub_message.DESCRIPTOR
+ full_name = message_descriptor.full_name
+ if _IsWrapperMessage(message_descriptor):
+ _ConvertWrapperMessage(value['value'], sub_message)
+ elif full_name in _WKTJSONMETHODS:
+ _WKTJSONMETHODS[full_name][1](value['value'], sub_message)
+ else:
+ del value['@type']
+ _ConvertFieldValuePair(value, sub_message)
+ # Sets Any message
+ message.value = sub_message.SerializeToString()
+ message.type_url = type_url
+
+
+def _ConvertGenericMessage(value, message):
+ """Convert a JSON representation into message with FromJsonString."""
+ # Durantion, Timestamp, FieldMask have FromJsonString method to do the
+ # convert. Users can also call the method directly.
+ message.FromJsonString(value)
+
+
+_INT_OR_FLOAT = six.integer_types + (float,)
+
+
+def _ConvertValueMessage(value, message):
+ """Convert a JSON representation into Value message."""
+ if isinstance(value, dict):
+ _ConvertStructMessage(value, message.struct_value)
+ elif isinstance(value, list):
+ _ConvertListValueMessage(value, message.list_value)
+ elif value is None:
+ message.null_value = 0
+ elif isinstance(value, bool):
+ message.bool_value = value
+ elif isinstance(value, six.string_types):
+ message.string_value = value
+ elif isinstance(value, _INT_OR_FLOAT):
+ message.number_value = value
+ else:
+ raise ParseError('Unexpected type for Value message.')
+
+
+def _ConvertListValueMessage(value, message):
+ """Convert a JSON representation into ListValue message."""
+ if not isinstance(value, list):
+ raise ParseError(
+ 'ListValue must be in [] which is {0}.'.format(value))
+ message.ClearField('values')
+ for item in value:
+ _ConvertValueMessage(item, message.values.add())
+
+
+def _ConvertStructMessage(value, message):
+ """Convert a JSON representation into Struct message."""
+ if not isinstance(value, dict):
+ raise ParseError(
+ 'Struct must be in a dict which is {0}.'.format(value))
+ for key in value:
+ _ConvertValueMessage(value[key], message.fields[key])
+ return
+
+
+def _ConvertWrapperMessage(value, message):
+ """Convert a JSON representation into Wrapper message."""
+ field = message.DESCRIPTOR.fields_by_name['value']
+ setattr(message, 'value', _ConvertScalarFieldValue(value, field))
+
+
+def _ConvertMapFieldValue(value, message, field):
+ """Convert map field value for a message map field.
+
+ Args:
+ value: A JSON object to convert the map field value.
+ message: A protocol message to record the converted data.
+ field: The descriptor of the map field to be converted.
+
+ Raises:
+ ParseError: In case of convert problems.
+ """
+ if not isinstance(value, dict):
+ raise ParseError(
+ 'Map field {0} must be in a dict which is {1}.'.format(
+ field.name, value))
+ key_field = field.message_type.fields_by_name['key']
+ value_field = field.message_type.fields_by_name['value']
+ for key in value:
+ key_value = _ConvertScalarFieldValue(key, key_field, True)
+ if value_field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
+ _ConvertMessage(value[key], getattr(message, field.name)[key_value])
+ else:
+ getattr(message, field.name)[key_value] = _ConvertScalarFieldValue(
+ value[key], value_field)
+
+
+def _ConvertScalarFieldValue(value, field, require_str=False):
+ """Convert a single scalar field value.
+
+ Args:
+ value: A scalar value to convert the scalar field value.
+ field: The descriptor of the field to convert.
+ require_str: If True, the field value must be a str.
+
+ Returns:
+ The converted scalar field value
+
+ Raises:
+ ParseError: In case of convert problems.
+ """
+ if field.cpp_type in _INT_TYPES:
+ return _ConvertInteger(value)
+ elif field.cpp_type in _FLOAT_TYPES:
+ return _ConvertFloat(value)
+ elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_BOOL:
+ return _ConvertBool(value, require_str)
+ elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_STRING:
+ if field.type == descriptor.FieldDescriptor.TYPE_BYTES:
+ return base64.b64decode(value)
+ else:
+ return value
+ elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_ENUM:
+ # Convert an enum value.
+ enum_value = field.enum_type.values_by_name.get(value, None)
+ if enum_value is None:
+ raise ParseError(
+ 'Enum value must be a string literal with double quotes. '
+ 'Type "{0}" has no value named {1}.'.format(
+ field.enum_type.full_name, value))
+ return enum_value.number
+
+
+def _ConvertInteger(value):
+ """Convert an integer.
+
+ Args:
+ value: A scalar value to convert.
+
+ Returns:
+ The integer value.
+
+ Raises:
+ ParseError: If an integer couldn't be consumed.
+ """
+ if isinstance(value, float):
+ raise ParseError('Couldn\'t parse integer: {0}.'.format(value))
+
+ if isinstance(value, six.text_type) and value.find(' ') != -1:
+ raise ParseError('Couldn\'t parse integer: "{0}".'.format(value))
+
+ return int(value)
+
+
+def _ConvertFloat(value):
+ """Convert an floating point number."""
+ if value == 'nan':
+ raise ParseError('Couldn\'t parse float "nan", use "NaN" instead.')
+ try:
+ # Assume Python compatible syntax.
+ return float(value)
+ except ValueError:
+ # Check alternative spellings.
+ if value == _NEG_INFINITY:
+ return float('-inf')
+ elif value == _INFINITY:
+ return float('inf')
+ elif value == _NAN:
+ return float('nan')
+ else:
+ raise ParseError('Couldn\'t parse float: {0}.'.format(value))
+
+
+def _ConvertBool(value, require_str):
+ """Convert a boolean value.
+
+ Args:
+ value: A scalar value to convert.
+ require_str: If True, value must be a str.
+
+ Returns:
+ The bool parsed.
+
+ Raises:
+ ParseError: If a boolean value couldn't be consumed.
+ """
+ if require_str:
+ if value == 'true':
+ return True
+ elif value == 'false':
+ return False
+ else:
+ raise ParseError('Expected "true" or "false", not {0}.'.format(value))
+
+ if not isinstance(value, bool):
+ raise ParseError('Expected true or false without quotes.')
+ return value
+
+_WKTJSONMETHODS = {
+ 'google.protobuf.Any': [_AnyMessageToJsonObject,
+ _ConvertAnyMessage],
+ 'google.protobuf.Duration': [_GenericMessageToJsonObject,
+ _ConvertGenericMessage],
+ 'google.protobuf.FieldMask': [_GenericMessageToJsonObject,
+ _ConvertGenericMessage],
+ 'google.protobuf.ListValue': [_ListValueMessageToJsonObject,
+ _ConvertListValueMessage],
+ 'google.protobuf.Struct': [_StructMessageToJsonObject,
+ _ConvertStructMessage],
+ 'google.protobuf.Timestamp': [_GenericMessageToJsonObject,
+ _ConvertGenericMessage],
+ 'google.protobuf.Value': [_ValueMessageToJsonObject,
+ _ConvertValueMessage]
+}
diff --git a/python/google/protobuf/message.py b/python/google/protobuf/message.py
index c186452a7..606f735f3 100755
--- a/python/google/protobuf/message.py
+++ b/python/google/protobuf/message.py
@@ -36,7 +36,6 @@
__author__ = 'robinson@google.com (Will Robinson)'
-
class Error(Exception): pass
class DecodeError(Error): pass
class EncodeError(Error): pass
@@ -233,12 +232,21 @@ class Message(object):
raise NotImplementedError
def HasField(self, field_name):
- """Checks if a certain field is set for the message. Note if the
- field_name is not defined in the message descriptor, ValueError will be
- raised."""
+ """Checks if a certain field is set for the message, or if any field inside
+ a oneof group is set. Note that if the field_name is not defined in the
+ message descriptor, ValueError will be raised."""
raise NotImplementedError
def ClearField(self, field_name):
+ """Clears the contents of a given field, or the field set inside a oneof
+ group. If the name neither refers to a defined field or oneof group,
+ ValueError is raised."""
+ raise NotImplementedError
+
+ def WhichOneof(self, oneof_group):
+ """Returns the name of the field that is set inside a oneof group, or
+ None if no field is set. If no group with the given name exists, ValueError
+ will be raised."""
raise NotImplementedError
def HasExtension(self, extension_handle):
@@ -247,6 +255,9 @@ class Message(object):
def ClearExtension(self, extension_handle):
raise NotImplementedError
+ def DiscardUnknownFields(self):
+ raise NotImplementedError
+
def ByteSize(self):
"""Returns the serialized size of this message.
Recursively calls ByteSize() on all contained messages.
diff --git a/python/google/protobuf/message_factory.py b/python/google/protobuf/message_factory.py
index 7fd7bec0b..1b059d130 100644
--- a/python/google/protobuf/message_factory.py
+++ b/python/google/protobuf/message_factory.py
@@ -28,10 +28,6 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-#PY25 compatible for GAE.
-#
-# Copyright 2012 Google Inc. All Rights Reserved.
-
"""Provides a factory class for generating dynamic messages.
The easiest way to use this class is if you have access to the FileDescriptor
@@ -43,8 +39,6 @@ my_proto_instance = message_classes['some.proto.package.MessageName']()
__author__ = 'matthewtoia@google.com (Matt Toia)'
-import sys ##PY25
-from google.protobuf import descriptor_database
from google.protobuf import descriptor_pool
from google.protobuf import message
from google.protobuf import reflection
@@ -55,8 +49,7 @@ class MessageFactory(object):
def __init__(self, pool=None):
"""Initializes a new factory."""
- self.pool = (pool or descriptor_pool.DescriptorPool(
- descriptor_database.DescriptorDatabase()))
+ self.pool = pool or descriptor_pool.DescriptorPool()
# local cache of all classes built from protobuf descriptors
self._classes = {}
@@ -75,8 +68,7 @@ class MessageFactory(object):
"""
if descriptor.full_name not in self._classes:
descriptor_name = descriptor.name
- if sys.version_info[0] < 3: ##PY25
-##!PY25 if str is bytes: # PY2
+ if str is bytes: # PY2
descriptor_name = descriptor.name.encode('ascii', 'ignore')
result_class = reflection.GeneratedProtocolMessageType(
descriptor_name,
@@ -111,7 +103,7 @@ class MessageFactory(object):
result = {}
for file_name in files:
file_desc = self.pool.FindFileByName(file_name)
- for name, msg in file_desc.message_types_by_name.iteritems():
+ for name, msg in file_desc.message_types_by_name.items():
if file_desc.package:
full_name = '.'.join([file_desc.package, name])
else:
@@ -128,7 +120,7 @@ class MessageFactory(object):
# ignore the registration if the original was the same, or raise
# an error if they were different.
- for name, extension in file_desc.extensions_by_name.iteritems():
+ for name, extension in file_desc.extensions_by_name.items():
if extension.containing_type.full_name not in self._classes:
self.GetPrototype(extension.containing_type)
extended_class = self._classes[extension.containing_type.full_name]
diff --git a/python/google/protobuf/proto_builder.py b/python/google/protobuf/proto_builder.py
new file mode 100644
index 000000000..736caed38
--- /dev/null
+++ b/python/google/protobuf/proto_builder.py
@@ -0,0 +1,130 @@
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# https://developers.google.com/protocol-buffers/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Dynamic Protobuf class creator."""
+
+try:
+ from collections import OrderedDict
+except ImportError:
+ from ordereddict import OrderedDict #PY26
+import hashlib
+import os
+
+from google.protobuf import descriptor_pb2
+from google.protobuf import message_factory
+
+
+def _GetMessageFromFactory(factory, full_name):
+ """Get a proto class from the MessageFactory by name.
+
+ Args:
+ factory: a MessageFactory instance.
+ full_name: str, the fully qualified name of the proto type.
+ Returns:
+ A class, for the type identified by full_name.
+ Raises:
+ KeyError, if the proto is not found in the factory's descriptor pool.
+ """
+ proto_descriptor = factory.pool.FindMessageTypeByName(full_name)
+ proto_cls = factory.GetPrototype(proto_descriptor)
+ return proto_cls
+
+
+def MakeSimpleProtoClass(fields, full_name=None, pool=None):
+ """Create a Protobuf class whose fields are basic types.
+
+ Note: this doesn't validate field names!
+
+ Args:
+ fields: dict of {name: field_type} mappings for each field in the proto. If
+ this is an OrderedDict the order will be maintained, otherwise the
+ fields will be sorted by name.
+ full_name: optional str, the fully-qualified name of the proto type.
+ pool: optional DescriptorPool instance.
+ Returns:
+ a class, the new protobuf class with a FileDescriptor.
+ """
+ factory = message_factory.MessageFactory(pool=pool)
+
+ if full_name is not None:
+ try:
+ proto_cls = _GetMessageFromFactory(factory, full_name)
+ return proto_cls
+ except KeyError:
+ # The factory's DescriptorPool doesn't know about this class yet.
+ pass
+
+ # Get a list of (name, field_type) tuples from the fields dict. If fields was
+ # an OrderedDict we keep the order, but otherwise we sort the field to ensure
+ # consistent ordering.
+ field_items = fields.items()
+ if not isinstance(fields, OrderedDict):
+ field_items = sorted(field_items)
+
+ # Use a consistent file name that is unlikely to conflict with any imported
+ # proto files.
+ fields_hash = hashlib.sha1()
+ for f_name, f_type in field_items:
+ fields_hash.update(f_name.encode('utf-8'))
+ fields_hash.update(str(f_type).encode('utf-8'))
+ proto_file_name = fields_hash.hexdigest() + '.proto'
+
+ # If the proto is anonymous, use the same hash to name it.
+ if full_name is None:
+ full_name = ('net.proto2.python.public.proto_builder.AnonymousProto_' +
+ fields_hash.hexdigest())
+ try:
+ proto_cls = _GetMessageFromFactory(factory, full_name)
+ return proto_cls
+ except KeyError:
+ # The factory's DescriptorPool doesn't know about this class yet.
+ pass
+
+ # This is the first time we see this proto: add a new descriptor to the pool.
+ factory.pool.Add(
+ _MakeFileDescriptorProto(proto_file_name, full_name, field_items))
+ return _GetMessageFromFactory(factory, full_name)
+
+
+def _MakeFileDescriptorProto(proto_file_name, full_name, field_items):
+ """Populate FileDescriptorProto for MessageFactory's DescriptorPool."""
+ package, name = full_name.rsplit('.', 1)
+ file_proto = descriptor_pb2.FileDescriptorProto()
+ file_proto.name = os.path.join(package.replace('.', '/'), proto_file_name)
+ file_proto.package = package
+ desc_proto = file_proto.message_type.add()
+ desc_proto.name = name
+ for f_number, (f_name, f_type) in enumerate(field_items, 1):
+ field_proto = desc_proto.field.add()
+ field_proto.name = f_name
+ field_proto.number = f_number
+ field_proto.label = descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL
+ field_proto.type = f_type
+ return file_proto
diff --git a/python/google/protobuf/pyext/__init__.py b/python/google/protobuf/pyext/__init__.py
index e69de29bb..558561412 100644
--- a/python/google/protobuf/pyext/__init__.py
+++ b/python/google/protobuf/pyext/__init__.py
@@ -0,0 +1,4 @@
+try:
+ __import__('pkg_resources').declare_namespace(__name__)
+except ImportError:
+ __path__ = __import__('pkgutil').extend_path(__path__, __name__)
diff --git a/python/google/protobuf/pyext/cpp_message.py b/python/google/protobuf/pyext/cpp_message.py
index dcf34a023..b215211ee 100644
--- a/python/google/protobuf/pyext/cpp_message.py
+++ b/python/google/protobuf/pyext/cpp_message.py
@@ -37,25 +37,29 @@ Descriptor objects at runtime backed by the protocol buffer C++ API.
__author__ = 'tibell@google.com (Johan Tibell)'
from google.protobuf.pyext import _message
-from google.protobuf import message
-def NewMessage(bases, message_descriptor, dictionary):
- """Creates a new protocol message *class*."""
- new_bases = []
- for base in bases:
- if base is message.Message:
- # _message.Message must come before message.Message as it
- # overrides methods in that class.
- new_bases.append(_message.Message)
- new_bases.append(base)
- return tuple(new_bases)
+class GeneratedProtocolMessageType(_message.MessageMeta):
+ """Metaclass for protocol message classes created at runtime from Descriptors.
-def InitMessage(message_descriptor, cls):
- """Constructs a new message instance (called before instance's __init__)."""
+ The protocol compiler currently uses this metaclass to create protocol
+ message classes at runtime. Clients can also manually create their own
+ classes at runtime, as in this example:
- def SubInit(self, **kwargs):
- super(cls, self).__init__(message_descriptor, **kwargs)
- cls.__init__ = SubInit
- cls.AddDescriptors(message_descriptor)
+ mydescriptor = Descriptor(.....)
+ class MyProtoClass(Message):
+ __metaclass__ = GeneratedProtocolMessageType
+ DESCRIPTOR = mydescriptor
+ myproto_instance = MyProtoClass()
+ myproto.foo_field = 23
+ ...
+
+ The above example will not work for nested types. If you wish to include them,
+ use reflection.MakeClass() instead of manually instantiating the class in
+ order to create the appropriate class structure.
+ """
+
+ # Must be consistent with the protocol-compiler code in
+ # proto2/compiler/internal/generator.*.
+ _DESCRIPTOR_KEY = 'DESCRIPTOR'
diff --git a/python/google/protobuf/pyext/descriptor.cc b/python/google/protobuf/pyext/descriptor.cc
index 3f7be73cb..235575389 100644
--- a/python/google/protobuf/pyext/descriptor.cc
+++ b/python/google/protobuf/pyext/descriptor.cc
@@ -31,29 +31,119 @@
// Author: petar@google.com (Petar Petrov)
#include <Python.h>
+#include <frameobject.h>
#include <string>
+#include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/descriptor.pb.h>
+#include <google/protobuf/dynamic_message.h>
#include <google/protobuf/pyext/descriptor.h>
+#include <google/protobuf/pyext/descriptor_containers.h>
+#include <google/protobuf/pyext/descriptor_pool.h>
+#include <google/protobuf/pyext/message.h>
#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
-#define C(str) const_cast<char*>(str)
-
#if PY_MAJOR_VERSION >= 3
#define PyString_FromStringAndSize PyUnicode_FromStringAndSize
+ #define PyString_Check PyUnicode_Check
+ #define PyString_InternFromString PyUnicode_InternFromString
#define PyInt_FromLong PyLong_FromLong
+ #define PyInt_FromSize_t PyLong_FromSize_t
#if PY_VERSION_HEX < 0x03030000
#error "Python 3.0 - 3.2 are not supported."
- #else
- #define PyString_AsString(ob) \
- (PyUnicode_Check(ob)? PyUnicode_AsUTF8(ob): PyBytes_AS_STRING(ob))
#endif
+ #define PyString_AsStringAndSize(ob, charpp, sizep) \
+ (PyUnicode_Check(ob)? \
+ ((*(charpp) = PyUnicode_AsUTF8AndSize(ob, (sizep))) == NULL? -1: 0): \
+ PyBytes_AsStringAndSize(ob, (charpp), (sizep)))
#endif
namespace google {
namespace protobuf {
namespace python {
+// Store interned descriptors, so that the same C++ descriptor yields the same
+// Python object. Objects are not immortal: this map does not own the
+// references, and items are deleted when the last reference to the object is
+// released.
+// This is enough to support the "is" operator on live objects.
+// All descriptors are stored here.
+hash_map<const void*, PyObject*> interned_descriptors;
+
+PyObject* PyString_FromCppString(const string& str) {
+ return PyString_FromStringAndSize(str.c_str(), str.size());
+}
+
+// Check that the calling Python code is the global scope of a _pb2.py module.
+// This function is used to support the current code generated by the proto
+// compiler, which creates descriptors, then update some properties.
+// For example:
+// message_descriptor = Descriptor(
+// name='Message',
+// fields = [FieldDescriptor(name='field')]
+// message_descriptor.fields[0].containing_type = message_descriptor
+//
+// This code is still executed, but the descriptors now have no other storage
+// than the (const) C++ pointer, and are immutable.
+// So we let this code pass, by simply ignoring the new value.
+//
+// From user code, descriptors still look immutable.
+//
+// TODO(amauryfa): Change the proto2 compiler to remove the assignments, and
+// remove this hack.
+bool _CalledFromGeneratedFile(int stacklevel) {
+#ifndef PYPY_VERSION
+ // This check is not critical and is somewhat difficult to implement correctly
+ // in PyPy.
+ PyFrameObject* frame = PyEval_GetFrame();
+ if (frame == NULL) {
+ return false;
+ }
+ while (stacklevel-- > 0) {
+ frame = frame->f_back;
+ if (frame == NULL) {
+ return false;
+ }
+ }
+ if (frame->f_globals != frame->f_locals) {
+ // Not at global module scope
+ return false;
+ }
+
+ if (frame->f_code->co_filename == NULL) {
+ return false;
+ }
+ char* filename;
+ Py_ssize_t filename_size;
+ if (PyString_AsStringAndSize(frame->f_code->co_filename,
+ &filename, &filename_size) < 0) {
+ // filename is not a string.
+ PyErr_Clear();
+ return false;
+ }
+ if (filename_size < 7) {
+ // filename is too short.
+ return false;
+ }
+ if (strcmp(&filename[filename_size - 7], "_pb2.py") != 0) {
+ // Filename is not ending with _pb2.
+ return false;
+ }
+#endif
+ return true;
+}
+
+// If the calling code is not a _pb2.py file, raise AttributeError.
+// To be used in attribute setters.
+static int CheckCalledFromGeneratedFile(const char* attr_name) {
+ if (_CalledFromGeneratedFile(0)) {
+ return 0;
+ }
+ PyErr_Format(PyExc_AttributeError,
+ "attribute is not writable: %s", attr_name);
+ return -1;
+}
+
#ifndef PyVarObject_HEAD_INIT
#define PyVarObject_HEAD_INIT(type, size) PyObject_HEAD_INIT(type) size,
@@ -63,58 +153,484 @@ namespace python {
#endif
-static google::protobuf::DescriptorPool* g_descriptor_pool = NULL;
+// Helper functions for descriptor objects.
+
+// A set of templates to retrieve the C++ FileDescriptor of any descriptor.
+template<class DescriptorClass>
+const FileDescriptor* GetFileDescriptor(const DescriptorClass* descriptor) {
+ return descriptor->file();
+}
+template<>
+const FileDescriptor* GetFileDescriptor(const FileDescriptor* descriptor) {
+ return descriptor;
+}
+template<>
+const FileDescriptor* GetFileDescriptor(const EnumValueDescriptor* descriptor) {
+ return descriptor->type()->file();
+}
+template<>
+const FileDescriptor* GetFileDescriptor(const OneofDescriptor* descriptor) {
+ return descriptor->containing_type()->file();
+}
+
+// Converts options into a Python protobuf, and cache the result.
+//
+// This is a bit tricky because options can contain extension fields defined in
+// the same proto file. In this case the options parsed from the serialized_pb
+// have unkown fields, and we need to parse them again.
+//
+// Always returns a new reference.
+template<class DescriptorClass>
+static PyObject* GetOrBuildOptions(const DescriptorClass *descriptor) {
+ // Options (and their extensions) are completely resolved in the proto file
+ // containing the descriptor.
+ PyDescriptorPool* pool = GetDescriptorPool_FromPool(
+ GetFileDescriptor(descriptor)->pool());
+
+ hash_map<const void*, PyObject*>* descriptor_options =
+ pool->descriptor_options;
+ // First search in the cache.
+ if (descriptor_options->find(descriptor) != descriptor_options->end()) {
+ PyObject *value = (*descriptor_options)[descriptor];
+ Py_INCREF(value);
+ return value;
+ }
+
+ // Build the Options object: get its Python class, and make a copy of the C++
+ // read-only instance.
+ const Message& options(descriptor->options());
+ const Descriptor *message_type = options.GetDescriptor();
+ CMessageClass* message_class(
+ cdescriptor_pool::GetMessageClass(pool, message_type));
+ if (message_class == NULL) {
+ // The Options message was not found in the current DescriptorPool.
+ // In this case, there cannot be extensions to these options, and we can
+ // try to use the basic pool instead.
+ PyErr_Clear();
+ message_class = cdescriptor_pool::GetMessageClass(
+ GetDefaultDescriptorPool(), message_type);
+ }
+ if (message_class == NULL) {
+ PyErr_Format(PyExc_TypeError, "Could not retrieve class for Options: %s",
+ message_type->full_name().c_str());
+ return NULL;
+ }
+ ScopedPyObjectPtr value(
+ PyEval_CallObject(message_class->AsPyObject(), NULL));
+ if (value == NULL) {
+ return NULL;
+ }
+ if (!PyObject_TypeCheck(value.get(), &CMessage_Type)) {
+ PyErr_Format(PyExc_TypeError, "Invalid class for %s: %s",
+ message_type->full_name().c_str(),
+ Py_TYPE(value.get())->tp_name);
+ return NULL;
+ }
+ CMessage* cmsg = reinterpret_cast<CMessage*>(value.get());
+
+ const Reflection* reflection = options.GetReflection();
+ const UnknownFieldSet& unknown_fields(reflection->GetUnknownFields(options));
+ if (unknown_fields.empty()) {
+ cmsg->message->CopyFrom(options);
+ } else {
+ // Reparse options string! XXX call cmessage::MergeFromString
+ string serialized;
+ options.SerializeToString(&serialized);
+ io::CodedInputStream input(
+ reinterpret_cast<const uint8*>(serialized.c_str()), serialized.size());
+ input.SetExtensionRegistry(pool->pool, pool->message_factory);
+ bool success = cmsg->message->MergePartialFromCodedStream(&input);
+ if (!success) {
+ PyErr_Format(PyExc_ValueError, "Error parsing Options message");
+ return NULL;
+ }
+ }
+
+ // Cache the result.
+ Py_INCREF(value.get());
+ (*pool->descriptor_options)[descriptor] = value.get();
+
+ return value.release();
+}
+
+// Copy the C++ descriptor to a Python message.
+// The Python message is an instance of descriptor_pb2.DescriptorProto
+// or similar.
+template<class DescriptorProtoClass, class DescriptorClass>
+static PyObject* CopyToPythonProto(const DescriptorClass *descriptor,
+ PyObject *target) {
+ const Descriptor* self_descriptor =
+ DescriptorProtoClass::default_instance().GetDescriptor();
+ CMessage* message = reinterpret_cast<CMessage*>(target);
+ if (!PyObject_TypeCheck(target, &CMessage_Type) ||
+ message->message->GetDescriptor() != self_descriptor) {
+ PyErr_Format(PyExc_TypeError, "Not a %s message",
+ self_descriptor->full_name().c_str());
+ return NULL;
+ }
+ cmessage::AssureWritable(message);
+ DescriptorProtoClass* descriptor_message =
+ static_cast<DescriptorProtoClass*>(message->message);
+ descriptor->CopyTo(descriptor_message);
+ Py_RETURN_NONE;
+}
+
+// All Descriptors classes share the same memory layout.
+typedef struct PyBaseDescriptor {
+ PyObject_HEAD
+
+ // Pointer to the C++ proto2 descriptor.
+ // Like all descriptors, it is owned by the global DescriptorPool.
+ const void* descriptor;
+
+ // Owned reference to the DescriptorPool, to ensure it is kept alive.
+ PyDescriptorPool* pool;
+} PyBaseDescriptor;
+
+
+// FileDescriptor structure "inherits" from the base descriptor.
+typedef struct PyFileDescriptor {
+ PyBaseDescriptor base;
+
+ // The cached version of serialized pb. Either NULL, or a Bytes string.
+ // We own the reference.
+ PyObject *serialized_pb;
+} PyFileDescriptor;
-namespace cfield_descriptor {
-static void Dealloc(CFieldDescriptor* self) {
- Py_CLEAR(self->descriptor_field);
+namespace descriptor {
+
+// Creates or retrieve a Python descriptor of the specified type.
+// Objects are interned: the same descriptor will return the same object if it
+// was kept alive.
+// 'was_created' is an optional pointer to a bool, and is set to true if a new
+// object was allocated.
+// Always return a new reference.
+template<class DescriptorClass>
+PyObject* NewInternedDescriptor(PyTypeObject* type,
+ const DescriptorClass* descriptor,
+ bool* was_created) {
+ if (was_created) {
+ *was_created = false;
+ }
+ if (descriptor == NULL) {
+ PyErr_BadInternalCall();
+ return NULL;
+ }
+
+ // See if the object is in the map of interned descriptors
+ hash_map<const void*, PyObject*>::iterator it =
+ interned_descriptors.find(descriptor);
+ if (it != interned_descriptors.end()) {
+ GOOGLE_DCHECK(Py_TYPE(it->second) == type);
+ Py_INCREF(it->second);
+ return it->second;
+ }
+ // Create a new descriptor object
+ PyBaseDescriptor* py_descriptor = PyObject_New(
+ PyBaseDescriptor, type);
+ if (py_descriptor == NULL) {
+ return NULL;
+ }
+ py_descriptor->descriptor = descriptor;
+
+ // and cache it.
+ interned_descriptors.insert(
+ std::make_pair(descriptor, reinterpret_cast<PyObject*>(py_descriptor)));
+
+ // Ensures that the DescriptorPool stays alive.
+ PyDescriptorPool* pool = GetDescriptorPool_FromPool(
+ GetFileDescriptor(descriptor)->pool());
+ if (pool == NULL) {
+ // Don't DECREF, the object is not fully initialized.
+ PyObject_Del(py_descriptor);
+ return NULL;
+ }
+ Py_INCREF(pool);
+ py_descriptor->pool = pool;
+
+ if (was_created) {
+ *was_created = true;
+ }
+ return reinterpret_cast<PyObject*>(py_descriptor);
+}
+
+static void Dealloc(PyBaseDescriptor* self) {
+ // Remove from interned dictionary
+ interned_descriptors.erase(self->descriptor);
+ Py_CLEAR(self->pool);
Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
}
-static PyObject* GetFullName(CFieldDescriptor* self, void *closure) {
- return PyString_FromStringAndSize(
- self->descriptor->full_name().c_str(),
- self->descriptor->full_name().size());
+static PyGetSetDef Getters[] = {
+ {NULL}
+};
+
+PyTypeObject PyBaseDescriptor_Type = {
+ PyVarObject_HEAD_INIT(&PyType_Type, 0)
+ FULL_MODULE_NAME ".DescriptorBase", // tp_name
+ sizeof(PyBaseDescriptor), // tp_basicsize
+ 0, // tp_itemsize
+ (destructor)Dealloc, // tp_dealloc
+ 0, // tp_print
+ 0, // tp_getattr
+ 0, // tp_setattr
+ 0, // tp_compare
+ 0, // tp_repr
+ 0, // tp_as_number
+ 0, // tp_as_sequence
+ 0, // tp_as_mapping
+ 0, // tp_hash
+ 0, // tp_call
+ 0, // tp_str
+ 0, // tp_getattro
+ 0, // tp_setattro
+ 0, // tp_as_buffer
+ Py_TPFLAGS_DEFAULT, // tp_flags
+ "Descriptors base class", // tp_doc
+ 0, // tp_traverse
+ 0, // tp_clear
+ 0, // tp_richcompare
+ 0, // tp_weaklistoffset
+ 0, // tp_iter
+ 0, // tp_iternext
+ 0, // tp_methods
+ 0, // tp_members
+ Getters, // tp_getset
+};
+
+} // namespace descriptor
+
+const void* PyDescriptor_AsVoidPtr(PyObject* obj) {
+ if (!PyObject_TypeCheck(obj, &descriptor::PyBaseDescriptor_Type)) {
+ PyErr_SetString(PyExc_TypeError, "Not a BaseDescriptor");
+ return NULL;
+ }
+ return reinterpret_cast<PyBaseDescriptor*>(obj)->descriptor;
}
-static PyObject* GetName(CFieldDescriptor *self, void *closure) {
- return PyString_FromStringAndSize(
- self->descriptor->name().c_str(),
- self->descriptor->name().size());
+namespace message_descriptor {
+
+// Unchecked accessor to the C++ pointer.
+static const Descriptor* _GetDescriptor(PyBaseDescriptor* self) {
+ return reinterpret_cast<const Descriptor*>(self->descriptor);
}
-static PyObject* GetCppType(CFieldDescriptor *self, void *closure) {
- return PyInt_FromLong(self->descriptor->cpp_type());
+static PyObject* GetName(PyBaseDescriptor* self, void *closure) {
+ return PyString_FromCppString(_GetDescriptor(self)->name());
}
-static PyObject* GetLabel(CFieldDescriptor *self, void *closure) {
- return PyInt_FromLong(self->descriptor->label());
+static PyObject* GetFullName(PyBaseDescriptor* self, void *closure) {
+ return PyString_FromCppString(_GetDescriptor(self)->full_name());
}
-static PyObject* GetID(CFieldDescriptor *self, void *closure) {
- return PyLong_FromVoidPtr(self);
+static PyObject* GetFile(PyBaseDescriptor *self, void *closure) {
+ return PyFileDescriptor_FromDescriptor(_GetDescriptor(self)->file());
+}
+
+static PyObject* GetConcreteClass(PyBaseDescriptor* self, void *closure) {
+ // Retuns the canonical class for the given descriptor.
+ // This is the class that was registered with the primary descriptor pool
+ // which contains this descriptor.
+ // This might not be the one you expect! For example the returned object does
+ // not know about extensions defined in a custom pool.
+ CMessageClass* concrete_class(cdescriptor_pool::GetMessageClass(
+ GetDescriptorPool_FromPool(_GetDescriptor(self)->file()->pool()),
+ _GetDescriptor(self)));
+ Py_XINCREF(concrete_class);
+ return concrete_class->AsPyObject();
+}
+
+static PyObject* GetFieldsByName(PyBaseDescriptor* self, void *closure) {
+ return NewMessageFieldsByName(_GetDescriptor(self));
+}
+
+static PyObject* GetFieldsByCamelcaseName(PyBaseDescriptor* self,
+ void *closure) {
+ return NewMessageFieldsByCamelcaseName(_GetDescriptor(self));
+}
+
+static PyObject* GetFieldsByNumber(PyBaseDescriptor* self, void *closure) {
+ return NewMessageFieldsByNumber(_GetDescriptor(self));
+}
+
+static PyObject* GetFieldsSeq(PyBaseDescriptor* self, void *closure) {
+ return NewMessageFieldsSeq(_GetDescriptor(self));
+}
+
+static PyObject* GetNestedTypesByName(PyBaseDescriptor* self, void *closure) {
+ return NewMessageNestedTypesByName(_GetDescriptor(self));
+}
+
+static PyObject* GetNestedTypesSeq(PyBaseDescriptor* self, void *closure) {
+ return NewMessageNestedTypesSeq(_GetDescriptor(self));
+}
+
+static PyObject* GetExtensionsByName(PyBaseDescriptor* self, void *closure) {
+ return NewMessageExtensionsByName(_GetDescriptor(self));
+}
+
+static PyObject* GetExtensions(PyBaseDescriptor* self, void *closure) {
+ return NewMessageExtensionsSeq(_GetDescriptor(self));
+}
+
+static PyObject* GetEnumsSeq(PyBaseDescriptor* self, void *closure) {
+ return NewMessageEnumsSeq(_GetDescriptor(self));
+}
+
+static PyObject* GetEnumTypesByName(PyBaseDescriptor* self, void *closure) {
+ return NewMessageEnumsByName(_GetDescriptor(self));
+}
+
+static PyObject* GetEnumValuesByName(PyBaseDescriptor* self, void *closure) {
+ return NewMessageEnumValuesByName(_GetDescriptor(self));
+}
+
+static PyObject* GetOneofsByName(PyBaseDescriptor* self, void *closure) {
+ return NewMessageOneofsByName(_GetDescriptor(self));
+}
+
+static PyObject* GetOneofsSeq(PyBaseDescriptor* self, void *closure) {
+ return NewMessageOneofsSeq(_GetDescriptor(self));
+}
+
+static PyObject* IsExtendable(PyBaseDescriptor *self, void *closure) {
+ if (_GetDescriptor(self)->extension_range_count() > 0) {
+ Py_RETURN_TRUE;
+ } else {
+ Py_RETURN_FALSE;
+ }
+}
+
+static PyObject* GetExtensionRanges(PyBaseDescriptor *self, void *closure) {
+ const Descriptor* descriptor = _GetDescriptor(self);
+ PyObject* range_list = PyList_New(descriptor->extension_range_count());
+
+ for (int i = 0; i < descriptor->extension_range_count(); i++) {
+ const Descriptor::ExtensionRange* range = descriptor->extension_range(i);
+ PyObject* start = PyInt_FromLong(range->start);
+ PyObject* end = PyInt_FromLong(range->end);
+ PyList_SetItem(range_list, i, PyTuple_Pack(2, start, end));
+ }
+
+ return range_list;
+}
+
+static PyObject* GetContainingType(PyBaseDescriptor *self, void *closure) {
+ const Descriptor* containing_type =
+ _GetDescriptor(self)->containing_type();
+ if (containing_type) {
+ return PyMessageDescriptor_FromDescriptor(containing_type);
+ } else {
+ Py_RETURN_NONE;
+ }
+}
+
+static int SetContainingType(PyBaseDescriptor *self, PyObject *value,
+ void *closure) {
+ return CheckCalledFromGeneratedFile("containing_type");
+}
+
+static PyObject* GetHasOptions(PyBaseDescriptor *self, void *closure) {
+ const MessageOptions& options(_GetDescriptor(self)->options());
+ if (&options != &MessageOptions::default_instance()) {
+ Py_RETURN_TRUE;
+ } else {
+ Py_RETURN_FALSE;
+ }
+}
+static int SetHasOptions(PyBaseDescriptor *self, PyObject *value,
+ void *closure) {
+ return CheckCalledFromGeneratedFile("has_options");
+}
+
+static PyObject* GetOptions(PyBaseDescriptor *self) {
+ return GetOrBuildOptions(_GetDescriptor(self));
+}
+
+static int SetOptions(PyBaseDescriptor *self, PyObject *value,
+ void *closure) {
+ return CheckCalledFromGeneratedFile("_options");
+}
+
+static PyObject* CopyToProto(PyBaseDescriptor *self, PyObject *target) {
+ return CopyToPythonProto<DescriptorProto>(_GetDescriptor(self), target);
+}
+
+static PyObject* EnumValueName(PyBaseDescriptor *self, PyObject *args) {
+ const char *enum_name;
+ int number;
+ if (!PyArg_ParseTuple(args, "si", &enum_name, &number))
+ return NULL;
+ const EnumDescriptor *enum_type =
+ _GetDescriptor(self)->FindEnumTypeByName(enum_name);
+ if (enum_type == NULL) {
+ PyErr_SetString(PyExc_KeyError, enum_name);
+ return NULL;
+ }
+ const EnumValueDescriptor *enum_value =
+ enum_type->FindValueByNumber(number);
+ if (enum_value == NULL) {
+ PyErr_Format(PyExc_KeyError, "%d", number);
+ return NULL;
+ }
+ return PyString_FromCppString(enum_value->name());
+}
+
+static PyObject* GetSyntax(PyBaseDescriptor *self, void *closure) {
+ return PyString_InternFromString(
+ FileDescriptor::SyntaxName(_GetDescriptor(self)->file()->syntax()));
}
static PyGetSetDef Getters[] = {
- { C("full_name"), (getter)GetFullName, NULL, "Full name", NULL},
- { C("name"), (getter)GetName, NULL, "last name", NULL},
- { C("cpp_type"), (getter)GetCppType, NULL, "C++ Type", NULL},
- { C("label"), (getter)GetLabel, NULL, "Label", NULL},
- { C("id"), (getter)GetID, NULL, "ID", NULL},
+ { "name", (getter)GetName, NULL, "Last name"},
+ { "full_name", (getter)GetFullName, NULL, "Full name"},
+ { "_concrete_class", (getter)GetConcreteClass, NULL, "concrete class"},
+ { "file", (getter)GetFile, NULL, "File descriptor"},
+
+ { "fields", (getter)GetFieldsSeq, NULL, "Fields sequence"},
+ { "fields_by_name", (getter)GetFieldsByName, NULL, "Fields by name"},
+ { "fields_by_camelcase_name", (getter)GetFieldsByCamelcaseName, NULL,
+ "Fields by camelCase name"},
+ { "fields_by_number", (getter)GetFieldsByNumber, NULL, "Fields by number"},
+ { "nested_types", (getter)GetNestedTypesSeq, NULL, "Nested types sequence"},
+ { "nested_types_by_name", (getter)GetNestedTypesByName, NULL,
+ "Nested types by name"},
+ { "extensions", (getter)GetExtensions, NULL, "Extensions Sequence"},
+ { "extensions_by_name", (getter)GetExtensionsByName, NULL,
+ "Extensions by name"},
+ { "extension_ranges", (getter)GetExtensionRanges, NULL, "Extension ranges"},
+ { "enum_types", (getter)GetEnumsSeq, NULL, "Enum sequence"},
+ { "enum_types_by_name", (getter)GetEnumTypesByName, NULL,
+ "Enum types by name"},
+ { "enum_values_by_name", (getter)GetEnumValuesByName, NULL,
+ "Enum values by name"},
+ { "oneofs_by_name", (getter)GetOneofsByName, NULL, "Oneofs by name"},
+ { "oneofs", (getter)GetOneofsSeq, NULL, "Oneofs by name"},
+ { "containing_type", (getter)GetContainingType, (setter)SetContainingType,
+ "Containing type"},
+ { "is_extendable", (getter)IsExtendable, (setter)NULL},
+ { "has_options", (getter)GetHasOptions, (setter)SetHasOptions, "Has Options"},
+ { "_options", (getter)NULL, (setter)SetOptions, "Options"},
+ { "syntax", (getter)GetSyntax, (setter)NULL, "Syntax"},
{NULL}
};
-} // namespace cfield_descriptor
+static PyMethodDef Methods[] = {
+ { "GetOptions", (PyCFunction)GetOptions, METH_NOARGS, },
+ { "CopyToProto", (PyCFunction)CopyToProto, METH_O, },
+ { "EnumValueName", (PyCFunction)EnumValueName, METH_VARARGS, },
+ {NULL}
+};
+
+} // namespace message_descriptor
-PyTypeObject CFieldDescriptor_Type = {
+PyTypeObject PyMessageDescriptor_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0)
- C("google.protobuf.internal."
- "_net_proto2___python."
- "CFieldDescriptor"), // tp_name
- sizeof(CFieldDescriptor), // tp_basicsize
+ FULL_MODULE_NAME ".MessageDescriptor", // tp_name
+ sizeof(PyBaseDescriptor), // tp_basicsize
0, // tp_itemsize
- (destructor)cfield_descriptor::Dealloc, // tp_dealloc
+ 0, // tp_dealloc
0, // tp_print
0, // tp_getattr
0, // tp_setattr
@@ -130,223 +646,934 @@ PyTypeObject CFieldDescriptor_Type = {
0, // tp_setattro
0, // tp_as_buffer
Py_TPFLAGS_DEFAULT, // tp_flags
- C("A Field Descriptor"), // tp_doc
+ "A Message Descriptor", // tp_doc
0, // tp_traverse
0, // tp_clear
0, // tp_richcompare
0, // tp_weaklistoffset
0, // tp_iter
0, // tp_iternext
- 0, // tp_methods
+ message_descriptor::Methods, // tp_methods
0, // tp_members
- cfield_descriptor::Getters, // tp_getset
- 0, // tp_base
- 0, // tp_dict
- 0, // tp_descr_get
- 0, // tp_descr_set
- 0, // tp_dictoffset
- 0, // tp_init
- PyType_GenericAlloc, // tp_alloc
- PyType_GenericNew, // tp_new
- PyObject_Del, // tp_free
+ message_descriptor::Getters, // tp_getset
+ &descriptor::PyBaseDescriptor_Type, // tp_base
};
-namespace cdescriptor_pool {
-
-static void Dealloc(CDescriptorPool* self) {
- Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
+PyObject* PyMessageDescriptor_FromDescriptor(
+ const Descriptor* message_descriptor) {
+ return descriptor::NewInternedDescriptor(
+ &PyMessageDescriptor_Type, message_descriptor, NULL);
}
-static PyObject* NewCDescriptor(
- const google::protobuf::FieldDescriptor* field_descriptor) {
- CFieldDescriptor* cfield_descriptor = PyObject_New(
- CFieldDescriptor, &CFieldDescriptor_Type);
- if (cfield_descriptor == NULL) {
+const Descriptor* PyMessageDescriptor_AsDescriptor(PyObject* obj) {
+ if (!PyObject_TypeCheck(obj, &PyMessageDescriptor_Type)) {
+ PyErr_SetString(PyExc_TypeError, "Not a MessageDescriptor");
return NULL;
}
- cfield_descriptor->descriptor = field_descriptor;
- cfield_descriptor->descriptor_field = NULL;
+ return reinterpret_cast<const Descriptor*>(
+ reinterpret_cast<PyBaseDescriptor*>(obj)->descriptor);
+}
+
+namespace field_descriptor {
- return reinterpret_cast<PyObject*>(cfield_descriptor);
+// Unchecked accessor to the C++ pointer.
+static const FieldDescriptor* _GetDescriptor(
+ PyBaseDescriptor *self) {
+ return reinterpret_cast<const FieldDescriptor*>(self->descriptor);
}
-PyObject* FindFieldByName(CDescriptorPool* self, PyObject* name) {
- const char* full_field_name = PyString_AsString(name);
- if (full_field_name == NULL) {
- return NULL;
+static PyObject* GetFullName(PyBaseDescriptor* self, void *closure) {
+ return PyString_FromCppString(_GetDescriptor(self)->full_name());
+}
+
+static PyObject* GetName(PyBaseDescriptor *self, void *closure) {
+ return PyString_FromCppString(_GetDescriptor(self)->name());
+}
+
+static PyObject* GetCamelcaseName(PyBaseDescriptor* self, void *closure) {
+ return PyString_FromCppString(_GetDescriptor(self)->camelcase_name());
+}
+
+static PyObject* GetType(PyBaseDescriptor *self, void *closure) {
+ return PyInt_FromLong(_GetDescriptor(self)->type());
+}
+
+static PyObject* GetCppType(PyBaseDescriptor *self, void *closure) {
+ return PyInt_FromLong(_GetDescriptor(self)->cpp_type());
+}
+
+static PyObject* GetLabel(PyBaseDescriptor *self, void *closure) {
+ return PyInt_FromLong(_GetDescriptor(self)->label());
+}
+
+static PyObject* GetNumber(PyBaseDescriptor *self, void *closure) {
+ return PyInt_FromLong(_GetDescriptor(self)->number());
+}
+
+static PyObject* GetIndex(PyBaseDescriptor *self, void *closure) {
+ return PyInt_FromLong(_GetDescriptor(self)->index());
+}
+
+static PyObject* GetID(PyBaseDescriptor *self, void *closure) {
+ return PyLong_FromVoidPtr(self);
+}
+
+static PyObject* IsExtension(PyBaseDescriptor *self, void *closure) {
+ return PyBool_FromLong(_GetDescriptor(self)->is_extension());
+}
+
+static PyObject* HasDefaultValue(PyBaseDescriptor *self, void *closure) {
+ return PyBool_FromLong(_GetDescriptor(self)->has_default_value());
+}
+
+static PyObject* GetDefaultValue(PyBaseDescriptor *self, void *closure) {
+ PyObject *result;
+
+ switch (_GetDescriptor(self)->cpp_type()) {
+ case FieldDescriptor::CPPTYPE_INT32: {
+ int32 value = _GetDescriptor(self)->default_value_int32();
+ result = PyInt_FromLong(value);
+ break;
+ }
+ case FieldDescriptor::CPPTYPE_INT64: {
+ int64 value = _GetDescriptor(self)->default_value_int64();
+ result = PyLong_FromLongLong(value);
+ break;
+ }
+ case FieldDescriptor::CPPTYPE_UINT32: {
+ uint32 value = _GetDescriptor(self)->default_value_uint32();
+ result = PyInt_FromSize_t(value);
+ break;
+ }
+ case FieldDescriptor::CPPTYPE_UINT64: {
+ uint64 value = _GetDescriptor(self)->default_value_uint64();
+ result = PyLong_FromUnsignedLongLong(value);
+ break;
+ }
+ case FieldDescriptor::CPPTYPE_FLOAT: {
+ float value = _GetDescriptor(self)->default_value_float();
+ result = PyFloat_FromDouble(value);
+ break;
+ }
+ case FieldDescriptor::CPPTYPE_DOUBLE: {
+ double value = _GetDescriptor(self)->default_value_double();
+ result = PyFloat_FromDouble(value);
+ break;
+ }
+ case FieldDescriptor::CPPTYPE_BOOL: {
+ bool value = _GetDescriptor(self)->default_value_bool();
+ result = PyBool_FromLong(value);
+ break;
+ }
+ case FieldDescriptor::CPPTYPE_STRING: {
+ string value = _GetDescriptor(self)->default_value_string();
+ result = ToStringObject(_GetDescriptor(self), value);
+ break;
+ }
+ case FieldDescriptor::CPPTYPE_ENUM: {
+ const EnumValueDescriptor* value =
+ _GetDescriptor(self)->default_value_enum();
+ result = PyInt_FromLong(value->number());
+ break;
+ }
+ default:
+ PyErr_Format(PyExc_NotImplementedError, "default value for %s",
+ _GetDescriptor(self)->full_name().c_str());
+ return NULL;
+ }
+ return result;
+}
+
+static PyObject* GetCDescriptor(PyObject *self, void *closure) {
+ Py_INCREF(self);
+ return self;
+}
+
+static PyObject *GetEnumType(PyBaseDescriptor *self, void *closure) {
+ const EnumDescriptor* enum_type = _GetDescriptor(self)->enum_type();
+ if (enum_type) {
+ return PyEnumDescriptor_FromDescriptor(enum_type);
+ } else {
+ Py_RETURN_NONE;
}
+}
- const google::protobuf::FieldDescriptor* field_descriptor = NULL;
+static int SetEnumType(PyBaseDescriptor *self, PyObject *value, void *closure) {
+ return CheckCalledFromGeneratedFile("enum_type");
+}
- field_descriptor = self->pool->FindFieldByName(full_field_name);
+static PyObject *GetMessageType(PyBaseDescriptor *self, void *closure) {
+ const Descriptor* message_type = _GetDescriptor(self)->message_type();
+ if (message_type) {
+ return PyMessageDescriptor_FromDescriptor(message_type);
+ } else {
+ Py_RETURN_NONE;
+ }
+}
- if (field_descriptor == NULL) {
- PyErr_Format(PyExc_TypeError, "Couldn't find field %.200s",
- full_field_name);
- return NULL;
+static int SetMessageType(PyBaseDescriptor *self, PyObject *value,
+ void *closure) {
+ return CheckCalledFromGeneratedFile("message_type");
+}
+
+static PyObject* GetContainingType(PyBaseDescriptor *self, void *closure) {
+ const Descriptor* containing_type =
+ _GetDescriptor(self)->containing_type();
+ if (containing_type) {
+ return PyMessageDescriptor_FromDescriptor(containing_type);
+ } else {
+ Py_RETURN_NONE;
}
+}
- return NewCDescriptor(field_descriptor);
+static int SetContainingType(PyBaseDescriptor *self, PyObject *value,
+ void *closure) {
+ return CheckCalledFromGeneratedFile("containing_type");
}
-PyObject* FindExtensionByName(CDescriptorPool* self, PyObject* arg) {
- const char* full_field_name = PyString_AsString(arg);
- if (full_field_name == NULL) {
- return NULL;
+static PyObject* GetExtensionScope(PyBaseDescriptor *self, void *closure) {
+ const Descriptor* extension_scope =
+ _GetDescriptor(self)->extension_scope();
+ if (extension_scope) {
+ return PyMessageDescriptor_FromDescriptor(extension_scope);
+ } else {
+ Py_RETURN_NONE;
}
+}
+
+static PyObject* GetContainingOneof(PyBaseDescriptor *self, void *closure) {
+ const OneofDescriptor* containing_oneof =
+ _GetDescriptor(self)->containing_oneof();
+ if (containing_oneof) {
+ return PyOneofDescriptor_FromDescriptor(containing_oneof);
+ } else {
+ Py_RETURN_NONE;
+ }
+}
+
+static int SetContainingOneof(PyBaseDescriptor *self, PyObject *value,
+ void *closure) {
+ return CheckCalledFromGeneratedFile("containing_oneof");
+}
+
+static PyObject* GetHasOptions(PyBaseDescriptor *self, void *closure) {
+ const FieldOptions& options(_GetDescriptor(self)->options());
+ if (&options != &FieldOptions::default_instance()) {
+ Py_RETURN_TRUE;
+ } else {
+ Py_RETURN_FALSE;
+ }
+}
+static int SetHasOptions(PyBaseDescriptor *self, PyObject *value,
+ void *closure) {
+ return CheckCalledFromGeneratedFile("has_options");
+}
+
+static PyObject* GetOptions(PyBaseDescriptor *self) {
+ return GetOrBuildOptions(_GetDescriptor(self));
+}
+
+static int SetOptions(PyBaseDescriptor *self, PyObject *value,
+ void *closure) {
+ return CheckCalledFromGeneratedFile("_options");
+}
+
+
+static PyGetSetDef Getters[] = {
+ { "full_name", (getter)GetFullName, NULL, "Full name"},
+ { "name", (getter)GetName, NULL, "Unqualified name"},
+ { "camelcase_name", (getter)GetCamelcaseName, NULL, "Camelcase name"},
+ { "type", (getter)GetType, NULL, "C++ Type"},
+ { "cpp_type", (getter)GetCppType, NULL, "C++ Type"},
+ { "label", (getter)GetLabel, NULL, "Label"},
+ { "number", (getter)GetNumber, NULL, "Number"},
+ { "index", (getter)GetIndex, NULL, "Index"},
+ { "default_value", (getter)GetDefaultValue, NULL, "Default Value"},
+ { "has_default_value", (getter)HasDefaultValue},
+ { "is_extension", (getter)IsExtension, NULL, "ID"},
+ { "id", (getter)GetID, NULL, "ID"},
+ { "_cdescriptor", (getter)GetCDescriptor, NULL, "HAACK REMOVE ME"},
+
+ { "message_type", (getter)GetMessageType, (setter)SetMessageType,
+ "Message type"},
+ { "enum_type", (getter)GetEnumType, (setter)SetEnumType, "Enum type"},
+ { "containing_type", (getter)GetContainingType, (setter)SetContainingType,
+ "Containing type"},
+ { "extension_scope", (getter)GetExtensionScope, (setter)NULL,
+ "Extension scope"},
+ { "containing_oneof", (getter)GetContainingOneof, (setter)SetContainingOneof,
+ "Containing oneof"},
+ { "has_options", (getter)GetHasOptions, (setter)SetHasOptions, "Has Options"},
+ { "_options", (getter)NULL, (setter)SetOptions, "Options"},
+ {NULL}
+};
+
+static PyMethodDef Methods[] = {
+ { "GetOptions", (PyCFunction)GetOptions, METH_NOARGS, },
+ {NULL}
+};
+
+} // namespace field_descriptor
+
+PyTypeObject PyFieldDescriptor_Type = {
+ PyVarObject_HEAD_INIT(&PyType_Type, 0)
+ FULL_MODULE_NAME ".FieldDescriptor", // tp_name
+ sizeof(PyBaseDescriptor), // tp_basicsize
+ 0, // tp_itemsize
+ 0, // tp_dealloc
+ 0, // tp_print
+ 0, // tp_getattr
+ 0, // tp_setattr
+ 0, // tp_compare
+ 0, // tp_repr
+ 0, // tp_as_number
+ 0, // tp_as_sequence
+ 0, // tp_as_mapping
+ 0, // tp_hash
+ 0, // tp_call
+ 0, // tp_str
+ 0, // tp_getattro
+ 0, // tp_setattro
+ 0, // tp_as_buffer
+ Py_TPFLAGS_DEFAULT, // tp_flags
+ "A Field Descriptor", // tp_doc
+ 0, // tp_traverse
+ 0, // tp_clear
+ 0, // tp_richcompare
+ 0, // tp_weaklistoffset
+ 0, // tp_iter
+ 0, // tp_iternext
+ field_descriptor::Methods, // tp_methods
+ 0, // tp_members
+ field_descriptor::Getters, // tp_getset
+ &descriptor::PyBaseDescriptor_Type, // tp_base
+};
+
+PyObject* PyFieldDescriptor_FromDescriptor(
+ const FieldDescriptor* field_descriptor) {
+ return descriptor::NewInternedDescriptor(
+ &PyFieldDescriptor_Type, field_descriptor, NULL);
+}
- const google::protobuf::FieldDescriptor* field_descriptor =
- self->pool->FindExtensionByName(full_field_name);
- if (field_descriptor == NULL) {
- PyErr_Format(PyExc_TypeError, "Couldn't find field %.200s",
- full_field_name);
+const FieldDescriptor* PyFieldDescriptor_AsDescriptor(PyObject* obj) {
+ if (!PyObject_TypeCheck(obj, &PyFieldDescriptor_Type)) {
+ PyErr_SetString(PyExc_TypeError, "Not a FieldDescriptor");
return NULL;
}
+ return reinterpret_cast<const FieldDescriptor*>(
+ reinterpret_cast<PyBaseDescriptor*>(obj)->descriptor);
+}
+
+namespace enum_descriptor {
+
+// Unchecked accessor to the C++ pointer.
+static const EnumDescriptor* _GetDescriptor(
+ PyBaseDescriptor *self) {
+ return reinterpret_cast<const EnumDescriptor*>(self->descriptor);
+}
+
+static PyObject* GetFullName(PyBaseDescriptor* self, void *closure) {
+ return PyString_FromCppString(_GetDescriptor(self)->full_name());
+}
+
+static PyObject* GetName(PyBaseDescriptor *self, void *closure) {
+ return PyString_FromCppString(_GetDescriptor(self)->name());
+}
+
+static PyObject* GetFile(PyBaseDescriptor *self, void *closure) {
+ return PyFileDescriptor_FromDescriptor(_GetDescriptor(self)->file());
+}
+
+static PyObject* GetEnumvaluesByName(PyBaseDescriptor* self, void *closure) {
+ return NewEnumValuesByName(_GetDescriptor(self));
+}
+
+static PyObject* GetEnumvaluesByNumber(PyBaseDescriptor* self, void *closure) {
+ return NewEnumValuesByNumber(_GetDescriptor(self));
+}
+
+static PyObject* GetEnumvaluesSeq(PyBaseDescriptor* self, void *closure) {
+ return NewEnumValuesSeq(_GetDescriptor(self));
+}
+
+static PyObject* GetContainingType(PyBaseDescriptor *self, void *closure) {
+ const Descriptor* containing_type =
+ _GetDescriptor(self)->containing_type();
+ if (containing_type) {
+ return PyMessageDescriptor_FromDescriptor(containing_type);
+ } else {
+ Py_RETURN_NONE;
+ }
+}
+
+static int SetContainingType(PyBaseDescriptor *self, PyObject *value,
+ void *closure) {
+ return CheckCalledFromGeneratedFile("containing_type");
+}
+
+
+static PyObject* GetHasOptions(PyBaseDescriptor *self, void *closure) {
+ const EnumOptions& options(_GetDescriptor(self)->options());
+ if (&options != &EnumOptions::default_instance()) {
+ Py_RETURN_TRUE;
+ } else {
+ Py_RETURN_FALSE;
+ }
+}
+static int SetHasOptions(PyBaseDescriptor *self, PyObject *value,
+ void *closure) {
+ return CheckCalledFromGeneratedFile("has_options");
+}
+
+static PyObject* GetOptions(PyBaseDescriptor *self) {
+ return GetOrBuildOptions(_GetDescriptor(self));
+}
+
+static int SetOptions(PyBaseDescriptor *self, PyObject *value,
+ void *closure) {
+ return CheckCalledFromGeneratedFile("_options");
+}
- return NewCDescriptor(field_descriptor);
+static PyObject* CopyToProto(PyBaseDescriptor *self, PyObject *target) {
+ return CopyToPythonProto<EnumDescriptorProto>(_GetDescriptor(self), target);
}
static PyMethodDef Methods[] = {
- { C("FindFieldByName"),
- (PyCFunction)FindFieldByName,
- METH_O,
- C("Searches for a field descriptor by full name.") },
- { C("FindExtensionByName"),
- (PyCFunction)FindExtensionByName,
- METH_O,
- C("Searches for extension descriptor by full name.") },
+ { "GetOptions", (PyCFunction)GetOptions, METH_NOARGS, },
+ { "CopyToProto", (PyCFunction)CopyToProto, METH_O, },
{NULL}
};
-} // namespace cdescriptor_pool
+static PyGetSetDef Getters[] = {
+ { "full_name", (getter)GetFullName, NULL, "Full name"},
+ { "name", (getter)GetName, NULL, "last name"},
+ { "file", (getter)GetFile, NULL, "File descriptor"},
+ { "values", (getter)GetEnumvaluesSeq, NULL, "values"},
+ { "values_by_name", (getter)GetEnumvaluesByName, NULL,
+ "Enum values by name"},
+ { "values_by_number", (getter)GetEnumvaluesByNumber, NULL,
+ "Enum values by number"},
+
+ { "containing_type", (getter)GetContainingType, (setter)SetContainingType,
+ "Containing type"},
+ { "has_options", (getter)GetHasOptions, (setter)SetHasOptions, "Has Options"},
+ { "_options", (getter)NULL, (setter)SetOptions, "Options"},
+ {NULL}
+};
+
+} // namespace enum_descriptor
-PyTypeObject CDescriptorPool_Type = {
+PyTypeObject PyEnumDescriptor_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0)
- C("google.protobuf.internal."
- "_net_proto2___python."
- "CFieldDescriptor"), // tp_name
- sizeof(CDescriptorPool), // tp_basicsize
- 0, // tp_itemsize
- (destructor)cdescriptor_pool::Dealloc, // tp_dealloc
- 0, // tp_print
- 0, // tp_getattr
- 0, // tp_setattr
- 0, // tp_compare
- 0, // tp_repr
- 0, // tp_as_number
- 0, // tp_as_sequence
- 0, // tp_as_mapping
- 0, // tp_hash
- 0, // tp_call
- 0, // tp_str
- 0, // tp_getattro
- 0, // tp_setattro
- 0, // tp_as_buffer
- Py_TPFLAGS_DEFAULT, // tp_flags
- C("A Descriptor Pool"), // tp_doc
- 0, // tp_traverse
- 0, // tp_clear
- 0, // tp_richcompare
- 0, // tp_weaklistoffset
- 0, // tp_iter
- 0, // tp_iternext
- cdescriptor_pool::Methods, // tp_methods
- 0, // tp_members
- 0, // tp_getset
- 0, // tp_base
- 0, // tp_dict
- 0, // tp_descr_get
- 0, // tp_descr_set
- 0, // tp_dictoffset
- 0, // tp_init
- PyType_GenericAlloc, // tp_alloc
- PyType_GenericNew, // tp_new
- PyObject_Del, // tp_free
+ FULL_MODULE_NAME ".EnumDescriptor", // tp_name
+ sizeof(PyBaseDescriptor), // tp_basicsize
+ 0, // tp_itemsize
+ 0, // tp_dealloc
+ 0, // tp_print
+ 0, // tp_getattr
+ 0, // tp_setattr
+ 0, // tp_compare
+ 0, // tp_repr
+ 0, // tp_as_number
+ 0, // tp_as_sequence
+ 0, // tp_as_mapping
+ 0, // tp_hash
+ 0, // tp_call
+ 0, // tp_str
+ 0, // tp_getattro
+ 0, // tp_setattro
+ 0, // tp_as_buffer
+ Py_TPFLAGS_DEFAULT, // tp_flags
+ "A Enum Descriptor", // tp_doc
+ 0, // tp_traverse
+ 0, // tp_clear
+ 0, // tp_richcompare
+ 0, // tp_weaklistoffset
+ 0, // tp_iter
+ 0, // tp_iternext
+ enum_descriptor::Methods, // tp_getset
+ 0, // tp_members
+ enum_descriptor::Getters, // tp_getset
+ &descriptor::PyBaseDescriptor_Type, // tp_base
};
-google::protobuf::DescriptorPool* GetDescriptorPool() {
- if (g_descriptor_pool == NULL) {
- g_descriptor_pool = new google::protobuf::DescriptorPool(
- google::protobuf::DescriptorPool::generated_pool());
+PyObject* PyEnumDescriptor_FromDescriptor(
+ const EnumDescriptor* enum_descriptor) {
+ return descriptor::NewInternedDescriptor(
+ &PyEnumDescriptor_Type, enum_descriptor, NULL);
+}
+
+const EnumDescriptor* PyEnumDescriptor_AsDescriptor(PyObject* obj) {
+ if (!PyObject_TypeCheck(obj, &PyEnumDescriptor_Type)) {
+ PyErr_SetString(PyExc_TypeError, "Not an EnumDescriptor");
+ return NULL;
+ }
+ return reinterpret_cast<const EnumDescriptor*>(
+ reinterpret_cast<PyBaseDescriptor*>(obj)->descriptor);
+}
+
+namespace enumvalue_descriptor {
+
+// Unchecked accessor to the C++ pointer.
+static const EnumValueDescriptor* _GetDescriptor(
+ PyBaseDescriptor *self) {
+ return reinterpret_cast<const EnumValueDescriptor*>(self->descriptor);
+}
+
+static PyObject* GetName(PyBaseDescriptor *self, void *closure) {
+ return PyString_FromCppString(_GetDescriptor(self)->name());
+}
+
+static PyObject* GetNumber(PyBaseDescriptor *self, void *closure) {
+ return PyInt_FromLong(_GetDescriptor(self)->number());
+}
+
+static PyObject* GetIndex(PyBaseDescriptor *self, void *closure) {
+ return PyInt_FromLong(_GetDescriptor(self)->index());
+}
+
+static PyObject* GetType(PyBaseDescriptor *self, void *closure) {
+ return PyEnumDescriptor_FromDescriptor(_GetDescriptor(self)->type());
+}
+
+static PyObject* GetHasOptions(PyBaseDescriptor *self, void *closure) {
+ const EnumValueOptions& options(_GetDescriptor(self)->options());
+ if (&options != &EnumValueOptions::default_instance()) {
+ Py_RETURN_TRUE;
+ } else {
+ Py_RETURN_FALSE;
}
- return g_descriptor_pool;
+}
+static int SetHasOptions(PyBaseDescriptor *self, PyObject *value,
+ void *closure) {
+ return CheckCalledFromGeneratedFile("has_options");
+}
+
+static PyObject* GetOptions(PyBaseDescriptor *self) {
+ return GetOrBuildOptions(_GetDescriptor(self));
+}
+
+static int SetOptions(PyBaseDescriptor *self, PyObject *value,
+ void *closure) {
+ return CheckCalledFromGeneratedFile("_options");
+}
+
+
+static PyGetSetDef Getters[] = {
+ { "name", (getter)GetName, NULL, "name"},
+ { "number", (getter)GetNumber, NULL, "number"},
+ { "index", (getter)GetIndex, NULL, "index"},
+ { "type", (getter)GetType, NULL, "index"},
+
+ { "has_options", (getter)GetHasOptions, (setter)SetHasOptions, "Has Options"},
+ { "_options", (getter)NULL, (setter)SetOptions, "Options"},
+ {NULL}
+};
+
+static PyMethodDef Methods[] = {
+ { "GetOptions", (PyCFunction)GetOptions, METH_NOARGS, },
+ {NULL}
+};
+
+} // namespace enumvalue_descriptor
+
+PyTypeObject PyEnumValueDescriptor_Type = {
+ PyVarObject_HEAD_INIT(&PyType_Type, 0)
+ FULL_MODULE_NAME ".EnumValueDescriptor", // tp_name
+ sizeof(PyBaseDescriptor), // tp_basicsize
+ 0, // tp_itemsize
+ 0, // tp_dealloc
+ 0, // tp_print
+ 0, // tp_getattr
+ 0, // tp_setattr
+ 0, // tp_compare
+ 0, // tp_repr
+ 0, // tp_as_number
+ 0, // tp_as_sequence
+ 0, // tp_as_mapping
+ 0, // tp_hash
+ 0, // tp_call
+ 0, // tp_str
+ 0, // tp_getattro
+ 0, // tp_setattro
+ 0, // tp_as_buffer
+ Py_TPFLAGS_DEFAULT, // tp_flags
+ "A EnumValue Descriptor", // tp_doc
+ 0, // tp_traverse
+ 0, // tp_clear
+ 0, // tp_richcompare
+ 0, // tp_weaklistoffset
+ 0, // tp_iter
+ 0, // tp_iternext
+ enumvalue_descriptor::Methods, // tp_methods
+ 0, // tp_members
+ enumvalue_descriptor::Getters, // tp_getset
+ &descriptor::PyBaseDescriptor_Type, // tp_base
+};
+
+PyObject* PyEnumValueDescriptor_FromDescriptor(
+ const EnumValueDescriptor* enumvalue_descriptor) {
+ return descriptor::NewInternedDescriptor(
+ &PyEnumValueDescriptor_Type, enumvalue_descriptor, NULL);
+}
+
+namespace file_descriptor {
+
+// Unchecked accessor to the C++ pointer.
+static const FileDescriptor* _GetDescriptor(PyFileDescriptor *self) {
+ return reinterpret_cast<const FileDescriptor*>(self->base.descriptor);
+}
+
+static void Dealloc(PyFileDescriptor* self) {
+ Py_XDECREF(self->serialized_pb);
+ descriptor::Dealloc(&self->base);
+}
+
+static PyObject* GetPool(PyFileDescriptor *self, void *closure) {
+ PyObject* pool = reinterpret_cast<PyObject*>(
+ GetDescriptorPool_FromPool(_GetDescriptor(self)->pool()));
+ Py_XINCREF(pool);
+ return pool;
+}
+
+static PyObject* GetName(PyFileDescriptor *self, void *closure) {
+ return PyString_FromCppString(_GetDescriptor(self)->name());
}
-PyObject* Python_NewCDescriptorPool(PyObject* ignored, PyObject* args) {
- CDescriptorPool* cdescriptor_pool = PyObject_New(
- CDescriptorPool, &CDescriptorPool_Type);
- if (cdescriptor_pool == NULL) {
+static PyObject* GetPackage(PyFileDescriptor *self, void *closure) {
+ return PyString_FromCppString(_GetDescriptor(self)->package());
+}
+
+static PyObject* GetSerializedPb(PyFileDescriptor *self, void *closure) {
+ PyObject *serialized_pb = self->serialized_pb;
+ if (serialized_pb != NULL) {
+ Py_INCREF(serialized_pb);
+ return serialized_pb;
+ }
+ FileDescriptorProto file_proto;
+ _GetDescriptor(self)->CopyTo(&file_proto);
+ string contents;
+ file_proto.SerializePartialToString(&contents);
+ self->serialized_pb = PyBytes_FromStringAndSize(
+ contents.c_str(), contents.size());
+ if (self->serialized_pb == NULL) {
return NULL;
}
- cdescriptor_pool->pool = GetDescriptorPool();
- return reinterpret_cast<PyObject*>(cdescriptor_pool);
+ Py_INCREF(self->serialized_pb);
+ return self->serialized_pb;
}
+static PyObject* GetMessageTypesByName(PyFileDescriptor* self, void *closure) {
+ return NewFileMessageTypesByName(_GetDescriptor(self));
+}
-// Collects errors that occur during proto file building to allow them to be
-// propagated in the python exception instead of only living in ERROR logs.
-class BuildFileErrorCollector : public google::protobuf::DescriptorPool::ErrorCollector {
- public:
- BuildFileErrorCollector() : error_message(""), had_errors(false) {}
+static PyObject* GetEnumTypesByName(PyFileDescriptor* self, void *closure) {
+ return NewFileEnumTypesByName(_GetDescriptor(self));
+}
- void AddError(const string& filename, const string& element_name,
- const Message* descriptor, ErrorLocation location,
- const string& message) {
- // Replicates the logging behavior that happens in the C++ implementation
- // when an error collector is not passed in.
- if (!had_errors) {
- error_message +=
- ("Invalid proto descriptor for file \"" + filename + "\":\n");
- }
- // As this only happens on failure and will result in the program not
- // running at all, no effort is made to optimize this string manipulation.
- error_message += (" " + element_name + ": " + message + "\n");
+static PyObject* GetExtensionsByName(PyFileDescriptor* self, void *closure) {
+ return NewFileExtensionsByName(_GetDescriptor(self));
+}
+
+static PyObject* GetDependencies(PyFileDescriptor* self, void *closure) {
+ return NewFileDependencies(_GetDescriptor(self));
+}
+
+static PyObject* GetPublicDependencies(PyFileDescriptor* self, void *closure) {
+ return NewFilePublicDependencies(_GetDescriptor(self));
+}
+
+static PyObject* GetHasOptions(PyFileDescriptor *self, void *closure) {
+ const FileOptions& options(_GetDescriptor(self)->options());
+ if (&options != &FileOptions::default_instance()) {
+ Py_RETURN_TRUE;
+ } else {
+ Py_RETURN_FALSE;
}
+}
+static int SetHasOptions(PyFileDescriptor *self, PyObject *value,
+ void *closure) {
+ return CheckCalledFromGeneratedFile("has_options");
+}
+
+static PyObject* GetOptions(PyFileDescriptor *self) {
+ return GetOrBuildOptions(_GetDescriptor(self));
+}
+
+static int SetOptions(PyFileDescriptor *self, PyObject *value,
+ void *closure) {
+ return CheckCalledFromGeneratedFile("_options");
+}
- string error_message;
- bool had_errors;
+static PyObject* GetSyntax(PyFileDescriptor *self, void *closure) {
+ return PyString_InternFromString(
+ FileDescriptor::SyntaxName(_GetDescriptor(self)->syntax()));
+}
+
+static PyObject* CopyToProto(PyFileDescriptor *self, PyObject *target) {
+ return CopyToPythonProto<FileDescriptorProto>(_GetDescriptor(self), target);
+}
+
+static PyGetSetDef Getters[] = {
+ { "pool", (getter)GetPool, NULL, "pool"},
+ { "name", (getter)GetName, NULL, "name"},
+ { "package", (getter)GetPackage, NULL, "package"},
+ { "serialized_pb", (getter)GetSerializedPb},
+ { "message_types_by_name", (getter)GetMessageTypesByName, NULL,
+ "Messages by name"},
+ { "enum_types_by_name", (getter)GetEnumTypesByName, NULL, "Enums by name"},
+ { "extensions_by_name", (getter)GetExtensionsByName, NULL,
+ "Extensions by name"},
+ { "dependencies", (getter)GetDependencies, NULL, "Dependencies"},
+ { "public_dependencies", (getter)GetPublicDependencies, NULL, "Dependencies"},
+
+ { "has_options", (getter)GetHasOptions, (setter)SetHasOptions, "Has Options"},
+ { "_options", (getter)NULL, (setter)SetOptions, "Options"},
+ { "syntax", (getter)GetSyntax, (setter)NULL, "Syntax"},
+ {NULL}
+};
+
+static PyMethodDef Methods[] = {
+ { "GetOptions", (PyCFunction)GetOptions, METH_NOARGS, },
+ { "CopyToProto", (PyCFunction)CopyToProto, METH_O, },
+ {NULL}
};
-PyObject* Python_BuildFile(PyObject* ignored, PyObject* arg) {
- char* message_type;
- Py_ssize_t message_len;
+} // namespace file_descriptor
- if (PyBytes_AsStringAndSize(arg, &message_type, &message_len) < 0) {
+PyTypeObject PyFileDescriptor_Type = {
+ PyVarObject_HEAD_INIT(&PyType_Type, 0)
+ FULL_MODULE_NAME ".FileDescriptor", // tp_name
+ sizeof(PyFileDescriptor), // tp_basicsize
+ 0, // tp_itemsize
+ (destructor)file_descriptor::Dealloc, // tp_dealloc
+ 0, // tp_print
+ 0, // tp_getattr
+ 0, // tp_setattr
+ 0, // tp_compare
+ 0, // tp_repr
+ 0, // tp_as_number
+ 0, // tp_as_sequence
+ 0, // tp_as_mapping
+ 0, // tp_hash
+ 0, // tp_call
+ 0, // tp_str
+ 0, // tp_getattro
+ 0, // tp_setattro
+ 0, // tp_as_buffer
+ Py_TPFLAGS_DEFAULT, // tp_flags
+ "A File Descriptor", // tp_doc
+ 0, // tp_traverse
+ 0, // tp_clear
+ 0, // tp_richcompare
+ 0, // tp_weaklistoffset
+ 0, // tp_iter
+ 0, // tp_iternext
+ file_descriptor::Methods, // tp_methods
+ 0, // tp_members
+ file_descriptor::Getters, // tp_getset
+ &descriptor::PyBaseDescriptor_Type, // tp_base
+ 0, // tp_dict
+ 0, // tp_descr_get
+ 0, // tp_descr_set
+ 0, // tp_dictoffset
+ 0, // tp_init
+ 0, // tp_alloc
+ 0, // tp_new
+ PyObject_Del, // tp_free
+};
+
+PyObject* PyFileDescriptor_FromDescriptor(
+ const FileDescriptor* file_descriptor) {
+ return PyFileDescriptor_FromDescriptorWithSerializedPb(file_descriptor,
+ NULL);
+}
+
+PyObject* PyFileDescriptor_FromDescriptorWithSerializedPb(
+ const FileDescriptor* file_descriptor, PyObject *serialized_pb) {
+ bool was_created;
+ PyObject* py_descriptor = descriptor::NewInternedDescriptor(
+ &PyFileDescriptor_Type, file_descriptor, &was_created);
+ if (py_descriptor == NULL) {
return NULL;
}
+ if (was_created) {
+ PyFileDescriptor* cfile_descriptor =
+ reinterpret_cast<PyFileDescriptor*>(py_descriptor);
+ Py_XINCREF(serialized_pb);
+ cfile_descriptor->serialized_pb = serialized_pb;
+ }
+ // TODO(amauryfa): In the case of a cached object, check that serialized_pb
+ // is the same as before.
+
+ return py_descriptor;
+}
- google::protobuf::FileDescriptorProto file_proto;
- if (!file_proto.ParseFromArray(message_type, message_len)) {
- PyErr_SetString(PyExc_TypeError, "Couldn't parse file content!");
+const FileDescriptor* PyFileDescriptor_AsDescriptor(PyObject* obj) {
+ if (!PyObject_TypeCheck(obj, &PyFileDescriptor_Type)) {
+ PyErr_SetString(PyExc_TypeError, "Not a FileDescriptor");
return NULL;
}
+ return reinterpret_cast<const FileDescriptor*>(
+ reinterpret_cast<PyBaseDescriptor*>(obj)->descriptor);
+}
+
+namespace oneof_descriptor {
+
+// Unchecked accessor to the C++ pointer.
+static const OneofDescriptor* _GetDescriptor(
+ PyBaseDescriptor *self) {
+ return reinterpret_cast<const OneofDescriptor*>(self->descriptor);
+}
+
+static PyObject* GetName(PyBaseDescriptor* self, void *closure) {
+ return PyString_FromCppString(_GetDescriptor(self)->name());
+}
+
+static PyObject* GetFullName(PyBaseDescriptor* self, void *closure) {
+ return PyString_FromCppString(_GetDescriptor(self)->full_name());
+}
+
+static PyObject* GetIndex(PyBaseDescriptor *self, void *closure) {
+ return PyInt_FromLong(_GetDescriptor(self)->index());
+}
- if (google::protobuf::DescriptorPool::generated_pool()->FindFileByName(
- file_proto.name()) != NULL) {
+static PyObject* GetFields(PyBaseDescriptor* self, void *closure) {
+ return NewOneofFieldsSeq(_GetDescriptor(self));
+}
+
+static PyObject* GetContainingType(PyBaseDescriptor *self, void *closure) {
+ const Descriptor* containing_type =
+ _GetDescriptor(self)->containing_type();
+ if (containing_type) {
+ return PyMessageDescriptor_FromDescriptor(containing_type);
+ } else {
Py_RETURN_NONE;
}
+}
- BuildFileErrorCollector error_collector;
- const google::protobuf::FileDescriptor* descriptor =
- GetDescriptorPool()->BuildFileCollectingErrors(file_proto,
- &error_collector);
- if (descriptor == NULL) {
- PyErr_Format(PyExc_TypeError,
- "Couldn't build proto file into descriptor pool!\n%s",
- error_collector.error_message.c_str());
- return NULL;
+static PyGetSetDef Getters[] = {
+ { "name", (getter)GetName, NULL, "Name"},
+ { "full_name", (getter)GetFullName, NULL, "Full name"},
+ { "index", (getter)GetIndex, NULL, "Index"},
+
+ { "containing_type", (getter)GetContainingType, NULL, "Containing type"},
+ { "fields", (getter)GetFields, NULL, "Fields"},
+ {NULL}
+};
+
+} // namespace oneof_descriptor
+
+PyTypeObject PyOneofDescriptor_Type = {
+ PyVarObject_HEAD_INIT(&PyType_Type, 0)
+ FULL_MODULE_NAME ".OneofDescriptor", // tp_name
+ sizeof(PyBaseDescriptor), // tp_basicsize
+ 0, // tp_itemsize
+ 0, // tp_dealloc
+ 0, // tp_print
+ 0, // tp_getattr
+ 0, // tp_setattr
+ 0, // tp_compare
+ 0, // tp_repr
+ 0, // tp_as_number
+ 0, // tp_as_sequence
+ 0, // tp_as_mapping
+ 0, // tp_hash
+ 0, // tp_call
+ 0, // tp_str
+ 0, // tp_getattro
+ 0, // tp_setattro
+ 0, // tp_as_buffer
+ Py_TPFLAGS_DEFAULT, // tp_flags
+ "A Oneof Descriptor", // tp_doc
+ 0, // tp_traverse
+ 0, // tp_clear
+ 0, // tp_richcompare
+ 0, // tp_weaklistoffset
+ 0, // tp_iter
+ 0, // tp_iternext
+ 0, // tp_methods
+ 0, // tp_members
+ oneof_descriptor::Getters, // tp_getset
+ &descriptor::PyBaseDescriptor_Type, // tp_base
+};
+
+PyObject* PyOneofDescriptor_FromDescriptor(
+ const OneofDescriptor* oneof_descriptor) {
+ return descriptor::NewInternedDescriptor(
+ &PyOneofDescriptor_Type, oneof_descriptor, NULL);
+}
+
+// Add a enum values to a type dictionary.
+static bool AddEnumValues(PyTypeObject *type,
+ const EnumDescriptor* enum_descriptor) {
+ for (int i = 0; i < enum_descriptor->value_count(); ++i) {
+ const EnumValueDescriptor* value = enum_descriptor->value(i);
+ ScopedPyObjectPtr obj(PyInt_FromLong(value->number()));
+ if (obj == NULL) {
+ return false;
+ }
+ if (PyDict_SetItemString(type->tp_dict, value->name().c_str(), obj.get()) <
+ 0) {
+ return false;
+ }
}
+ return true;
+}
- Py_RETURN_NONE;
+static bool AddIntConstant(PyTypeObject *type, const char* name, int value) {
+ ScopedPyObjectPtr obj(PyInt_FromLong(value));
+ if (PyDict_SetItemString(type->tp_dict, name, obj.get()) < 0) {
+ return false;
+ }
+ return true;
}
+
bool InitDescriptor() {
- CFieldDescriptor_Type.tp_new = PyType_GenericNew;
- if (PyType_Ready(&CFieldDescriptor_Type) < 0)
+ if (PyType_Ready(&PyMessageDescriptor_Type) < 0)
+ return false;
+
+ if (PyType_Ready(&PyFieldDescriptor_Type) < 0)
+ return false;
+
+ if (!AddEnumValues(&PyFieldDescriptor_Type,
+ FieldDescriptorProto::Label_descriptor())) {
+ return false;
+ }
+ if (!AddEnumValues(&PyFieldDescriptor_Type,
+ FieldDescriptorProto::Type_descriptor())) {
+ return false;
+ }
+#define ADD_FIELDDESC_CONSTANT(NAME) AddIntConstant( \
+ &PyFieldDescriptor_Type, #NAME, FieldDescriptor::NAME)
+ if (!ADD_FIELDDESC_CONSTANT(CPPTYPE_INT32) ||
+ !ADD_FIELDDESC_CONSTANT(CPPTYPE_INT64) ||
+ !ADD_FIELDDESC_CONSTANT(CPPTYPE_UINT32) ||
+ !ADD_FIELDDESC_CONSTANT(CPPTYPE_UINT64) ||
+ !ADD_FIELDDESC_CONSTANT(CPPTYPE_DOUBLE) ||
+ !ADD_FIELDDESC_CONSTANT(CPPTYPE_FLOAT) ||
+ !ADD_FIELDDESC_CONSTANT(CPPTYPE_BOOL) ||
+ !ADD_FIELDDESC_CONSTANT(CPPTYPE_ENUM) ||
+ !ADD_FIELDDESC_CONSTANT(CPPTYPE_STRING) ||
+ !ADD_FIELDDESC_CONSTANT(CPPTYPE_MESSAGE)) {
+ return false;
+ }
+#undef ADD_FIELDDESC_CONSTANT
+
+ if (PyType_Ready(&PyEnumDescriptor_Type) < 0)
+ return false;
+
+ if (PyType_Ready(&PyEnumValueDescriptor_Type) < 0)
+ return false;
+
+ if (PyType_Ready(&PyFileDescriptor_Type) < 0)
+ return false;
+
+ if (PyType_Ready(&PyOneofDescriptor_Type) < 0)
return false;
- CDescriptorPool_Type.tp_new = PyType_GenericNew;
- if (PyType_Ready(&CDescriptorPool_Type) < 0)
+ if (!InitDescriptorMappingTypes())
return false;
return true;
diff --git a/python/google/protobuf/pyext/descriptor.h b/python/google/protobuf/pyext/descriptor.h
index ae7a1b9c6..eb99df182 100644
--- a/python/google/protobuf/pyext/descriptor.h
+++ b/python/google/protobuf/pyext/descriptor.h
@@ -34,60 +34,61 @@
#define GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_H__
#include <Python.h>
-#include <structmember.h>
#include <google/protobuf/descriptor.h>
-#if PY_VERSION_HEX < 0x02050000 && !defined(PY_SSIZE_T_MIN)
-typedef int Py_ssize_t;
-#define PY_SSIZE_T_MAX INT_MAX
-#define PY_SSIZE_T_MIN INT_MIN
-#endif
-
namespace google {
namespace protobuf {
namespace python {
-typedef struct CFieldDescriptor {
- PyObject_HEAD
-
- // The proto2 descriptor that this object represents.
- const google::protobuf::FieldDescriptor* descriptor;
-
- // Reference to the original field object in the Python DESCRIPTOR.
- PyObject* descriptor_field;
-} CFieldDescriptor;
-
-typedef struct {
- PyObject_HEAD
-
- const google::protobuf::DescriptorPool* pool;
-} CDescriptorPool;
-
-extern PyTypeObject CFieldDescriptor_Type;
-
-extern PyTypeObject CDescriptorPool_Type;
+extern PyTypeObject PyMessageDescriptor_Type;
+extern PyTypeObject PyFieldDescriptor_Type;
+extern PyTypeObject PyEnumDescriptor_Type;
+extern PyTypeObject PyEnumValueDescriptor_Type;
+extern PyTypeObject PyFileDescriptor_Type;
+extern PyTypeObject PyOneofDescriptor_Type;
-namespace cdescriptor_pool {
-
-// Looks up a field by name. Returns a CDescriptor corresponding to
-// the field on success, or NULL on failure.
-//
+// Wraps a Descriptor in a Python object.
+// The C++ pointer is usually borrowed from the global DescriptorPool.
+// In any case, it must stay alive as long as the Python object.
// Returns a new reference.
-PyObject* FindFieldByName(CDescriptorPool* self, PyObject* name);
-
-// Looks up an extension by name. Returns a CDescriptor corresponding
-// to the field on success, or NULL on failure.
-//
+PyObject* PyMessageDescriptor_FromDescriptor(const Descriptor* descriptor);
+PyObject* PyFieldDescriptor_FromDescriptor(const FieldDescriptor* descriptor);
+PyObject* PyEnumDescriptor_FromDescriptor(const EnumDescriptor* descriptor);
+PyObject* PyEnumValueDescriptor_FromDescriptor(
+ const EnumValueDescriptor* descriptor);
+PyObject* PyOneofDescriptor_FromDescriptor(const OneofDescriptor* descriptor);
+PyObject* PyFileDescriptor_FromDescriptor(
+ const FileDescriptor* file_descriptor);
+
+// Alternate constructor of PyFileDescriptor, used when we already have a
+// serialized FileDescriptorProto that can be cached.
// Returns a new reference.
-PyObject* FindExtensionByName(CDescriptorPool* self, PyObject* arg);
-
-} // namespace cdescriptor_pool
+PyObject* PyFileDescriptor_FromDescriptorWithSerializedPb(
+ const FileDescriptor* file_descriptor, PyObject* serialized_pb);
+
+// Return the C++ descriptor pointer.
+// This function checks the parameter type; on error, return NULL with a Python
+// exception set.
+const Descriptor* PyMessageDescriptor_AsDescriptor(PyObject* obj);
+const FieldDescriptor* PyFieldDescriptor_AsDescriptor(PyObject* obj);
+const EnumDescriptor* PyEnumDescriptor_AsDescriptor(PyObject* obj);
+const FileDescriptor* PyFileDescriptor_AsDescriptor(PyObject* obj);
+
+// Returns the raw C++ pointer.
+const void* PyDescriptor_AsVoidPtr(PyObject* obj);
+
+// Check that the calling Python code is the global scope of a _pb2.py module.
+// This function is used to support the current code generated by the proto
+// compiler, which insists on modifying descriptors after they have been
+// created.
+//
+// stacklevel indicates which Python frame should be the _pb2.py module.
+//
+// Don't use this function outside descriptor classes.
+bool _CalledFromGeneratedFile(int stacklevel);
-PyObject* Python_NewCDescriptorPool(PyObject* ignored, PyObject* args);
-PyObject* Python_BuildFile(PyObject* ignored, PyObject* args);
bool InitDescriptor();
-google::protobuf::DescriptorPool* GetDescriptorPool();
} // namespace python
} // namespace protobuf
diff --git a/python/google/protobuf/pyext/descriptor_containers.cc b/python/google/protobuf/pyext/descriptor_containers.cc
new file mode 100644
index 000000000..e505d8122
--- /dev/null
+++ b/python/google/protobuf/pyext/descriptor_containers.cc
@@ -0,0 +1,1652 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// https://developers.google.com/protocol-buffers/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+// Mappings and Sequences of descriptors.
+// Used by Descriptor.fields_by_name, EnumDescriptor.values...
+//
+// They avoid the allocation of a full dictionary or a full list: they simply
+// store a pointer to the parent descriptor, use the C++ Descriptor methods (see
+// google/protobuf/descriptor.h) to retrieve other descriptors, and create
+// Python objects on the fly.
+//
+// The containers fully conform to abc.Mapping and abc.Sequence, and behave just
+// like read-only dictionaries and lists.
+//
+// Because the interface of C++ Descriptors is quite regular, this file actually
+// defines only three types, the exact behavior of a container is controlled by
+// a DescriptorContainerDef structure, which contains functions that uses the
+// public Descriptor API.
+//
+// Note: This DescriptorContainerDef is similar to the "virtual methods table"
+// that a C++ compiler generates for a class. We have to make it explicit
+// because the Python API is based on C, and does not play well with C++
+// inheritance.
+
+#include <Python.h>
+
+#include <google/protobuf/descriptor.h>
+#include <google/protobuf/pyext/descriptor_containers.h>
+#include <google/protobuf/pyext/descriptor_pool.h>
+#include <google/protobuf/pyext/descriptor.h>
+#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
+
+#if PY_MAJOR_VERSION >= 3
+ #define PyString_FromStringAndSize PyUnicode_FromStringAndSize
+ #define PyString_FromFormat PyUnicode_FromFormat
+ #define PyInt_FromLong PyLong_FromLong
+ #if PY_VERSION_HEX < 0x03030000
+ #error "Python 3.0 - 3.2 are not supported."
+ #endif
+ #define PyString_AsStringAndSize(ob, charpp, sizep) \
+ (PyUnicode_Check(ob)? \
+ ((*(charpp) = PyUnicode_AsUTF8AndSize(ob, (sizep))) == NULL? -1: 0): \
+ PyBytes_AsStringAndSize(ob, (charpp), (sizep)))
+#endif
+
+namespace google {
+namespace protobuf {
+namespace python {
+
+struct PyContainer;
+
+typedef int (*CountMethod)(PyContainer* self);
+typedef const void* (*GetByIndexMethod)(PyContainer* self, int index);
+typedef const void* (*GetByNameMethod)(PyContainer* self, const string& name);
+typedef const void* (*GetByCamelcaseNameMethod)(PyContainer* self,
+ const string& name);
+typedef const void* (*GetByNumberMethod)(PyContainer* self, int index);
+typedef PyObject* (*NewObjectFromItemMethod)(const void* descriptor);
+typedef const string& (*GetItemNameMethod)(const void* descriptor);
+typedef const string& (*GetItemCamelcaseNameMethod)(const void* descriptor);
+typedef int (*GetItemNumberMethod)(const void* descriptor);
+typedef int (*GetItemIndexMethod)(const void* descriptor);
+
+struct DescriptorContainerDef {
+ const char* mapping_name;
+ // Returns the number of items in the container.
+ CountMethod count_fn;
+ // Retrieve item by index (usually the order of declaration in the proto file)
+ // Used by sequences, but also iterators. 0 <= index < Count().
+ GetByIndexMethod get_by_index_fn;
+ // Retrieve item by name (usually a call to some 'FindByName' method).
+ // Used by "by_name" mappings.
+ GetByNameMethod get_by_name_fn;
+ // Retrieve item by camelcase name (usually a call to some
+ // 'FindByCamelcaseName' method). Used by "by_camelcase_name" mappings.
+ GetByCamelcaseNameMethod get_by_camelcase_name_fn;
+ // Retrieve item by declared number (field tag, or enum value).
+ // Used by "by_number" mappings.
+ GetByNumberMethod get_by_number_fn;
+ // Converts a item C++ descriptor to a Python object. Returns a new reference.
+ NewObjectFromItemMethod new_object_from_item_fn;
+ // Retrieve the name of an item. Used by iterators on "by_name" mappings.
+ GetItemNameMethod get_item_name_fn;
+ // Retrieve the camelcase name of an item. Used by iterators on
+ // "by_camelcase_name" mappings.
+ GetItemCamelcaseNameMethod get_item_camelcase_name_fn;
+ // Retrieve the number of an item. Used by iterators on "by_number" mappings.
+ GetItemNumberMethod get_item_number_fn;
+ // Retrieve the index of an item for the container type.
+ // Used by "__contains__".
+ // If not set, "x in sequence" will do a linear search.
+ GetItemIndexMethod get_item_index_fn;
+};
+
+struct PyContainer {
+ PyObject_HEAD
+
+ // The proto2 descriptor this container belongs to the global DescriptorPool.
+ const void* descriptor;
+
+ // A pointer to a static structure with function pointers that control the
+ // behavior of the container. Very similar to the table of virtual functions
+ // of a C++ class.
+ const DescriptorContainerDef* container_def;
+
+ // The kind of container: list, or dict by name or value.
+ enum ContainerKind {
+ KIND_SEQUENCE,
+ KIND_BYNAME,
+ KIND_BYCAMELCASENAME,
+ KIND_BYNUMBER,
+ } kind;
+};
+
+struct PyContainerIterator {
+ PyObject_HEAD
+
+ // The container we are iterating over. Own a reference.
+ PyContainer* container;
+
+ // The current index in the iterator.
+ int index;
+
+ // The kind of container: list, or dict by name or value.
+ enum IterKind {
+ KIND_ITERKEY,
+ KIND_ITERVALUE,
+ KIND_ITERITEM,
+ KIND_ITERVALUE_REVERSED, // For sequences
+ } kind;
+};
+
+namespace descriptor {
+
+// Returns the C++ item descriptor for a given Python key.
+// When the descriptor is found, return true and set *item.
+// When the descriptor is not found, return true, but set *item to NULL.
+// On error, returns false with an exception set.
+static bool _GetItemByKey(PyContainer* self, PyObject* key, const void** item) {
+ switch (self->kind) {
+ case PyContainer::KIND_BYNAME:
+ {
+ char* name;
+ Py_ssize_t name_size;
+ if (PyString_AsStringAndSize(key, &name, &name_size) < 0) {
+ if (PyErr_ExceptionMatches(PyExc_TypeError)) {
+ // Not a string, cannot be in the container.
+ PyErr_Clear();
+ *item = NULL;
+ return true;
+ }
+ return false;
+ }
+ *item = self->container_def->get_by_name_fn(
+ self, string(name, name_size));
+ return true;
+ }
+ case PyContainer::KIND_BYCAMELCASENAME:
+ {
+ char* camelcase_name;
+ Py_ssize_t name_size;
+ if (PyString_AsStringAndSize(key, &camelcase_name, &name_size) < 0) {
+ if (PyErr_ExceptionMatches(PyExc_TypeError)) {
+ // Not a string, cannot be in the container.
+ PyErr_Clear();
+ *item = NULL;
+ return true;
+ }
+ return false;
+ }
+ *item = self->container_def->get_by_camelcase_name_fn(
+ self, string(camelcase_name, name_size));
+ return true;
+ }
+ case PyContainer::KIND_BYNUMBER:
+ {
+ Py_ssize_t number = PyNumber_AsSsize_t(key, NULL);
+ if (number == -1 && PyErr_Occurred()) {
+ if (PyErr_ExceptionMatches(PyExc_TypeError)) {
+ // Not a number, cannot be in the container.
+ PyErr_Clear();
+ *item = NULL;
+ return true;
+ }
+ return false;
+ }
+ *item = self->container_def->get_by_number_fn(self, number);
+ return true;
+ }
+ default:
+ PyErr_SetNone(PyExc_NotImplementedError);
+ return false;
+ }
+}
+
+// Returns the key of the object at the given index.
+// Used when iterating over mappings.
+static PyObject* _NewKey_ByIndex(PyContainer* self, Py_ssize_t index) {
+ const void* item = self->container_def->get_by_index_fn(self, index);
+ switch (self->kind) {
+ case PyContainer::KIND_BYNAME:
+ {
+ const string& name(self->container_def->get_item_name_fn(item));
+ return PyString_FromStringAndSize(name.c_str(), name.size());
+ }
+ case PyContainer::KIND_BYCAMELCASENAME:
+ {
+ const string& name(
+ self->container_def->get_item_camelcase_name_fn(item));
+ return PyString_FromStringAndSize(name.c_str(), name.size());
+ }
+ case PyContainer::KIND_BYNUMBER:
+ {
+ int value = self->container_def->get_item_number_fn(item);
+ return PyInt_FromLong(value);
+ }
+ default:
+ PyErr_SetNone(PyExc_NotImplementedError);
+ return NULL;
+ }
+}
+
+// Returns the object at the given index.
+// Also used when iterating over mappings.
+static PyObject* _NewObj_ByIndex(PyContainer* self, Py_ssize_t index) {
+ return self->container_def->new_object_from_item_fn(
+ self->container_def->get_by_index_fn(self, index));
+}
+
+static Py_ssize_t Length(PyContainer* self) {
+ return self->container_def->count_fn(self);
+}
+
+// The DescriptorMapping type.
+
+static PyObject* Subscript(PyContainer* self, PyObject* key) {
+ const void* item = NULL;
+ if (!_GetItemByKey(self, key, &item)) {
+ return NULL;
+ }
+ if (!item) {
+ PyErr_SetObject(PyExc_KeyError, key);
+ return NULL;
+ }
+ return self->container_def->new_object_from_item_fn(item);
+}
+
+static int AssSubscript(PyContainer* self, PyObject* key, PyObject* value) {
+ if (_CalledFromGeneratedFile(0)) {
+ return 0;
+ }
+ PyErr_Format(PyExc_TypeError,
+ "'%.200s' object does not support item assignment",
+ Py_TYPE(self)->tp_name);
+ return -1;
+}
+
+static PyMappingMethods MappingMappingMethods = {
+ (lenfunc)Length, // mp_length
+ (binaryfunc)Subscript, // mp_subscript
+ (objobjargproc)AssSubscript, // mp_ass_subscript
+};
+
+static int Contains(PyContainer* self, PyObject* key) {
+ const void* item = NULL;
+ if (!_GetItemByKey(self, key, &item)) {
+ return -1;
+ }
+ if (item) {
+ return 1;
+ } else {
+ return 0;
+ }
+}
+
+static PyObject* ContainerRepr(PyContainer* self) {
+ const char* kind = "";
+ switch (self->kind) {
+ case PyContainer::KIND_SEQUENCE:
+ kind = "sequence";
+ break;
+ case PyContainer::KIND_BYNAME:
+ kind = "mapping by name";
+ break;
+ case PyContainer::KIND_BYCAMELCASENAME:
+ kind = "mapping by camelCase name";
+ break;
+ case PyContainer::KIND_BYNUMBER:
+ kind = "mapping by number";
+ break;
+ }
+ return PyString_FromFormat(
+ "<%s %s>", self->container_def->mapping_name, kind);
+}
+
+extern PyTypeObject DescriptorMapping_Type;
+extern PyTypeObject DescriptorSequence_Type;
+
+// A sequence container can only be equal to another sequence container, or (for
+// backward compatibility) to a list containing the same items.
+// Returns 1 if equal, 0 if unequal, -1 on error.
+static int DescriptorSequence_Equal(PyContainer* self, PyObject* other) {
+ // Check the identity of C++ pointers.
+ if (PyObject_TypeCheck(other, &DescriptorSequence_Type)) {
+ PyContainer* other_container = reinterpret_cast<PyContainer*>(other);
+ if (self->descriptor == other_container->descriptor &&
+ self->container_def == other_container->container_def &&
+ self->kind == other_container->kind) {
+ return 1;
+ } else {
+ return 0;
+ }
+ }
+
+ // If other is a list
+ if (PyList_Check(other)) {
+ // return list(self) == other
+ int size = Length(self);
+ if (size != PyList_Size(other)) {
+ return false;
+ }
+ for (int index = 0; index < size; index++) {
+ ScopedPyObjectPtr value1(_NewObj_ByIndex(self, index));
+ if (value1 == NULL) {
+ return -1;
+ }
+ PyObject* value2 = PyList_GetItem(other, index);
+ if (value2 == NULL) {
+ return -1;
+ }
+ int cmp = PyObject_RichCompareBool(value1.get(), value2, Py_EQ);
+ if (cmp != 1) // error or not equal
+ return cmp;
+ }
+ // All items were found and equal
+ return 1;
+ }
+
+ // Any other object is different.
+ return 0;
+}
+
+// A mapping container can only be equal to another mapping container, or (for
+// backward compatibility) to a dict containing the same items.
+// Returns 1 if equal, 0 if unequal, -1 on error.
+static int DescriptorMapping_Equal(PyContainer* self, PyObject* other) {
+ // Check the identity of C++ pointers.
+ if (PyObject_TypeCheck(other, &DescriptorMapping_Type)) {
+ PyContainer* other_container = reinterpret_cast<PyContainer*>(other);
+ if (self->descriptor == other_container->descriptor &&
+ self->container_def == other_container->container_def &&
+ self->kind == other_container->kind) {
+ return 1;
+ } else {
+ return 0;
+ }
+ }
+
+ // If other is a dict
+ if (PyDict_Check(other)) {
+ // equivalent to dict(self.items()) == other
+ int size = Length(self);
+ if (size != PyDict_Size(other)) {
+ return false;
+ }
+ for (int index = 0; index < size; index++) {
+ ScopedPyObjectPtr key(_NewKey_ByIndex(self, index));
+ if (key == NULL) {
+ return -1;
+ }
+ ScopedPyObjectPtr value1(_NewObj_ByIndex(self, index));
+ if (value1 == NULL) {
+ return -1;
+ }
+ PyObject* value2 = PyDict_GetItem(other, key.get());
+ if (value2 == NULL) {
+ // Not found in the other dictionary
+ return 0;
+ }
+ int cmp = PyObject_RichCompareBool(value1.get(), value2, Py_EQ);
+ if (cmp != 1) // error or not equal
+ return cmp;
+ }
+ // All items were found and equal
+ return 1;
+ }
+
+ // Any other object is different.
+ return 0;
+}
+
+static PyObject* RichCompare(PyContainer* self, PyObject* other, int opid) {
+ if (opid != Py_EQ && opid != Py_NE) {
+ Py_INCREF(Py_NotImplemented);
+ return Py_NotImplemented;
+ }
+
+ int result;
+
+ if (self->kind == PyContainer::KIND_SEQUENCE) {
+ result = DescriptorSequence_Equal(self, other);
+ } else {
+ result = DescriptorMapping_Equal(self, other);
+ }
+ if (result < 0) {
+ return NULL;
+ }
+ if (result ^ (opid == Py_NE)) {
+ Py_RETURN_TRUE;
+ } else {
+ Py_RETURN_FALSE;
+ }
+}
+
+static PySequenceMethods MappingSequenceMethods = {
+ 0, // sq_length
+ 0, // sq_concat
+ 0, // sq_repeat
+ 0, // sq_item
+ 0, // sq_slice
+ 0, // sq_ass_item
+ 0, // sq_ass_slice
+ (objobjproc)Contains, // sq_contains
+};
+
+static PyObject* Get(PyContainer* self, PyObject* args) {
+ PyObject* key;
+ PyObject* default_value = Py_None;
+ if (!PyArg_UnpackTuple(args, "get", 1, 2, &key, &default_value)) {
+ return NULL;
+ }
+
+ const void* item;
+ if (!_GetItemByKey(self, key, &item)) {
+ return NULL;
+ }
+ if (item == NULL) {
+ Py_INCREF(default_value);
+ return default_value;
+ }
+ return self->container_def->new_object_from_item_fn(item);
+}
+
+static PyObject* Keys(PyContainer* self, PyObject* args) {
+ Py_ssize_t count = Length(self);
+ ScopedPyObjectPtr list(PyList_New(count));
+ if (list == NULL) {
+ return NULL;
+ }
+ for (Py_ssize_t index = 0; index < count; ++index) {
+ PyObject* key = _NewKey_ByIndex(self, index);
+ if (key == NULL) {
+ return NULL;
+ }
+ PyList_SET_ITEM(list.get(), index, key);
+ }
+ return list.release();
+}
+
+static PyObject* Values(PyContainer* self, PyObject* args) {
+ Py_ssize_t count = Length(self);
+ ScopedPyObjectPtr list(PyList_New(count));
+ if (list == NULL) {
+ return NULL;
+ }
+ for (Py_ssize_t index = 0; index < count; ++index) {
+ PyObject* value = _NewObj_ByIndex(self, index);
+ if (value == NULL) {
+ return NULL;
+ }
+ PyList_SET_ITEM(list.get(), index, value);
+ }
+ return list.release();
+}
+
+static PyObject* Items(PyContainer* self, PyObject* args) {
+ Py_ssize_t count = Length(self);
+ ScopedPyObjectPtr list(PyList_New(count));
+ if (list == NULL) {
+ return NULL;
+ }
+ for (Py_ssize_t index = 0; index < count; ++index) {
+ ScopedPyObjectPtr obj(PyTuple_New(2));
+ if (obj == NULL) {
+ return NULL;
+ }
+ PyObject* key = _NewKey_ByIndex(self, index);
+ if (key == NULL) {
+ return NULL;
+ }
+ PyTuple_SET_ITEM(obj.get(), 0, key);
+ PyObject* value = _NewObj_ByIndex(self, index);
+ if (value == NULL) {
+ return NULL;
+ }
+ PyTuple_SET_ITEM(obj.get(), 1, value);
+ PyList_SET_ITEM(list.get(), index, obj.release());
+ }
+ return list.release();
+}
+
+static PyObject* NewContainerIterator(PyContainer* mapping,
+ PyContainerIterator::IterKind kind);
+
+static PyObject* Iter(PyContainer* self) {
+ return NewContainerIterator(self, PyContainerIterator::KIND_ITERKEY);
+}
+static PyObject* IterKeys(PyContainer* self, PyObject* args) {
+ return NewContainerIterator(self, PyContainerIterator::KIND_ITERKEY);
+}
+static PyObject* IterValues(PyContainer* self, PyObject* args) {
+ return NewContainerIterator(self, PyContainerIterator::KIND_ITERVALUE);
+}
+static PyObject* IterItems(PyContainer* self, PyObject* args) {
+ return NewContainerIterator(self, PyContainerIterator::KIND_ITERITEM);
+}
+
+static PyMethodDef MappingMethods[] = {
+ { "get", (PyCFunction)Get, METH_VARARGS, },
+ { "keys", (PyCFunction)Keys, METH_NOARGS, },
+ { "values", (PyCFunction)Values, METH_NOARGS, },
+ { "items", (PyCFunction)Items, METH_NOARGS, },
+ { "iterkeys", (PyCFunction)IterKeys, METH_NOARGS, },
+ { "itervalues", (PyCFunction)IterValues, METH_NOARGS, },
+ { "iteritems", (PyCFunction)IterItems, METH_NOARGS, },
+ {NULL}
+};
+
+PyTypeObject DescriptorMapping_Type = {
+ PyVarObject_HEAD_INIT(&PyType_Type, 0)
+ "DescriptorMapping", // tp_name
+ sizeof(PyContainer), // tp_basicsize
+ 0, // tp_itemsize
+ 0, // tp_dealloc
+ 0, // tp_print
+ 0, // tp_getattr
+ 0, // tp_setattr
+ 0, // tp_compare
+ (reprfunc)ContainerRepr, // tp_repr
+ 0, // tp_as_number
+ &MappingSequenceMethods, // tp_as_sequence
+ &MappingMappingMethods, // tp_as_mapping
+ 0, // tp_hash
+ 0, // tp_call
+ 0, // tp_str
+ 0, // tp_getattro
+ 0, // tp_setattro
+ 0, // tp_as_buffer
+ Py_TPFLAGS_DEFAULT, // tp_flags
+ 0, // tp_doc
+ 0, // tp_traverse
+ 0, // tp_clear
+ (richcmpfunc)RichCompare, // tp_richcompare
+ 0, // tp_weaklistoffset
+ (getiterfunc)Iter, // tp_iter
+ 0, // tp_iternext
+ MappingMethods, // tp_methods
+ 0, // tp_members
+ 0, // tp_getset
+ 0, // tp_base
+ 0, // tp_dict
+ 0, // tp_descr_get
+ 0, // tp_descr_set
+ 0, // tp_dictoffset
+ 0, // tp_init
+ 0, // tp_alloc
+ 0, // tp_new
+ 0, // tp_free
+};
+
+// The DescriptorSequence type.
+
+static PyObject* GetItem(PyContainer* self, Py_ssize_t index) {
+ if (index < 0) {
+ index += Length(self);
+ }
+ if (index < 0 || index >= Length(self)) {
+ PyErr_SetString(PyExc_IndexError, "index out of range");
+ return NULL;
+ }
+ return _NewObj_ByIndex(self, index);
+}
+
+// Returns the position of the item in the sequence, of -1 if not found.
+// This function never fails.
+int Find(PyContainer* self, PyObject* item) {
+ // The item can only be in one position: item.index.
+ // Check that self[item.index] == item, it's faster than a linear search.
+ //
+ // This assumes that sequences are only defined by syntax of the .proto file:
+ // a specific item belongs to only one sequence, depending on its position in
+ // the .proto file definition.
+ const void* descriptor_ptr = PyDescriptor_AsVoidPtr(item);
+ if (descriptor_ptr == NULL) {
+ // Not a descriptor, it cannot be in the list.
+ return -1;
+ }
+ if (self->container_def->get_item_index_fn) {
+ int index = self->container_def->get_item_index_fn(descriptor_ptr);
+ if (index < 0 || index >= Length(self)) {
+ // This index is not from this collection.
+ return -1;
+ }
+ if (self->container_def->get_by_index_fn(self, index) != descriptor_ptr) {
+ // The descriptor at this index is not the same.
+ return -1;
+ }
+ // self[item.index] == item, so return the index.
+ return index;
+ } else {
+ // Fall back to linear search.
+ int length = Length(self);
+ for (int index=0; index < length; index++) {
+ if (self->container_def->get_by_index_fn(self, index) == descriptor_ptr) {
+ return index;
+ }
+ }
+ // Not found
+ return -1;
+ }
+}
+
+// Implements list.index(): the position of the item is in the sequence.
+static PyObject* Index(PyContainer* self, PyObject* item) {
+ int position = Find(self, item);
+ if (position < 0) {
+ // Not found
+ PyErr_SetNone(PyExc_ValueError);
+ return NULL;
+ } else {
+ return PyInt_FromLong(position);
+ }
+}
+// Implements "list.__contains__()": is the object in the sequence.
+static int SeqContains(PyContainer* self, PyObject* item) {
+ int position = Find(self, item);
+ if (position < 0) {
+ return 0;
+ } else {
+ return 1;
+ }
+}
+
+// Implements list.count(): number of occurrences of the item in the sequence.
+// An item can only appear once in a sequence. If it exists, return 1.
+static PyObject* Count(PyContainer* self, PyObject* item) {
+ int position = Find(self, item);
+ if (position < 0) {
+ return PyInt_FromLong(0);
+ } else {
+ return PyInt_FromLong(1);
+ }
+}
+
+static PyObject* Append(PyContainer* self, PyObject* args) {
+ if (_CalledFromGeneratedFile(0)) {
+ Py_RETURN_NONE;
+ }
+ PyErr_Format(PyExc_TypeError,
+ "'%.200s' object is not a mutable sequence",
+ Py_TYPE(self)->tp_name);
+ return NULL;
+}
+
+static PyObject* Reversed(PyContainer* self, PyObject* args) {
+ return NewContainerIterator(self,
+ PyContainerIterator::KIND_ITERVALUE_REVERSED);
+}
+
+static PyMethodDef SeqMethods[] = {
+ { "index", (PyCFunction)Index, METH_O, },
+ { "count", (PyCFunction)Count, METH_O, },
+ { "append", (PyCFunction)Append, METH_O, },
+ { "__reversed__", (PyCFunction)Reversed, METH_NOARGS, },
+ {NULL}
+};
+
+static PySequenceMethods SeqSequenceMethods = {
+ (lenfunc)Length, // sq_length
+ 0, // sq_concat
+ 0, // sq_repeat
+ (ssizeargfunc)GetItem, // sq_item
+ 0, // sq_slice
+ 0, // sq_ass_item
+ 0, // sq_ass_slice
+ (objobjproc)SeqContains, // sq_contains
+};
+
+PyTypeObject DescriptorSequence_Type = {
+ PyVarObject_HEAD_INIT(&PyType_Type, 0)
+ "DescriptorSequence", // tp_name
+ sizeof(PyContainer), // tp_basicsize
+ 0, // tp_itemsize
+ 0, // tp_dealloc
+ 0, // tp_print
+ 0, // tp_getattr
+ 0, // tp_setattr
+ 0, // tp_compare
+ (reprfunc)ContainerRepr, // tp_repr
+ 0, // tp_as_number
+ &SeqSequenceMethods, // tp_as_sequence
+ 0, // tp_as_mapping
+ 0, // tp_hash
+ 0, // tp_call
+ 0, // tp_str
+ 0, // tp_getattro
+ 0, // tp_setattro
+ 0, // tp_as_buffer
+ Py_TPFLAGS_DEFAULT, // tp_flags
+ 0, // tp_doc
+ 0, // tp_traverse
+ 0, // tp_clear
+ (richcmpfunc)RichCompare, // tp_richcompare
+ 0, // tp_weaklistoffset
+ 0, // tp_iter
+ 0, // tp_iternext
+ SeqMethods, // tp_methods
+ 0, // tp_members
+ 0, // tp_getset
+ 0, // tp_base
+ 0, // tp_dict
+ 0, // tp_descr_get
+ 0, // tp_descr_set
+ 0, // tp_dictoffset
+ 0, // tp_init
+ 0, // tp_alloc
+ 0, // tp_new
+ 0, // tp_free
+};
+
+static PyObject* NewMappingByName(
+ DescriptorContainerDef* container_def, const void* descriptor) {
+ PyContainer* self = PyObject_New(PyContainer, &DescriptorMapping_Type);
+ if (self == NULL) {
+ return NULL;
+ }
+ self->descriptor = descriptor;
+ self->container_def = container_def;
+ self->kind = PyContainer::KIND_BYNAME;
+ return reinterpret_cast<PyObject*>(self);
+}
+
+static PyObject* NewMappingByCamelcaseName(
+ DescriptorContainerDef* container_def, const void* descriptor) {
+ PyContainer* self = PyObject_New(PyContainer, &DescriptorMapping_Type);
+ if (self == NULL) {
+ return NULL;
+ }
+ self->descriptor = descriptor;
+ self->container_def = container_def;
+ self->kind = PyContainer::KIND_BYCAMELCASENAME;
+ return reinterpret_cast<PyObject*>(self);
+}
+
+static PyObject* NewMappingByNumber(
+ DescriptorContainerDef* container_def, const void* descriptor) {
+ if (container_def->get_by_number_fn == NULL ||
+ container_def->get_item_number_fn == NULL) {
+ PyErr_SetNone(PyExc_NotImplementedError);
+ return NULL;
+ }
+ PyContainer* self = PyObject_New(PyContainer, &DescriptorMapping_Type);
+ if (self == NULL) {
+ return NULL;
+ }
+ self->descriptor = descriptor;
+ self->container_def = container_def;
+ self->kind = PyContainer::KIND_BYNUMBER;
+ return reinterpret_cast<PyObject*>(self);
+}
+
+static PyObject* NewSequence(
+ DescriptorContainerDef* container_def, const void* descriptor) {
+ PyContainer* self = PyObject_New(PyContainer, &DescriptorSequence_Type);
+ if (self == NULL) {
+ return NULL;
+ }
+ self->descriptor = descriptor;
+ self->container_def = container_def;
+ self->kind = PyContainer::KIND_SEQUENCE;
+ return reinterpret_cast<PyObject*>(self);
+}
+
+// Implement iterators over PyContainers.
+
+static void Iterator_Dealloc(PyContainerIterator* self) {
+ Py_CLEAR(self->container);
+ Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
+}
+
+static PyObject* Iterator_Next(PyContainerIterator* self) {
+ int count = self->container->container_def->count_fn(self->container);
+ if (self->index >= count) {
+ // Return NULL with no exception to indicate the end.
+ return NULL;
+ }
+ int index = self->index;
+ self->index += 1;
+ switch (self->kind) {
+ case PyContainerIterator::KIND_ITERKEY:
+ return _NewKey_ByIndex(self->container, index);
+ case PyContainerIterator::KIND_ITERVALUE:
+ return _NewObj_ByIndex(self->container, index);
+ case PyContainerIterator::KIND_ITERVALUE_REVERSED:
+ return _NewObj_ByIndex(self->container, count - index - 1);
+ case PyContainerIterator::KIND_ITERITEM:
+ {
+ PyObject* obj = PyTuple_New(2);
+ if (obj == NULL) {
+ return NULL;
+ }
+ PyObject* key = _NewKey_ByIndex(self->container, index);
+ if (key == NULL) {
+ Py_DECREF(obj);
+ return NULL;
+ }
+ PyTuple_SET_ITEM(obj, 0, key);
+ PyObject* value = _NewObj_ByIndex(self->container, index);
+ if (value == NULL) {
+ Py_DECREF(obj);
+ return NULL;
+ }
+ PyTuple_SET_ITEM(obj, 1, value);
+ return obj;
+ }
+ default:
+ PyErr_SetNone(PyExc_NotImplementedError);
+ return NULL;
+ }
+}
+
+static PyTypeObject ContainerIterator_Type = {
+ PyVarObject_HEAD_INIT(&PyType_Type, 0)
+ "DescriptorContainerIterator", // tp_name
+ sizeof(PyContainerIterator), // tp_basicsize
+ 0, // tp_itemsize
+ (destructor)Iterator_Dealloc, // tp_dealloc
+ 0, // tp_print
+ 0, // tp_getattr
+ 0, // tp_setattr
+ 0, // tp_compare
+ 0, // tp_repr
+ 0, // tp_as_number
+ 0, // tp_as_sequence
+ 0, // tp_as_mapping
+ 0, // tp_hash
+ 0, // tp_call
+ 0, // tp_str
+ 0, // tp_getattro
+ 0, // tp_setattro
+ 0, // tp_as_buffer
+ Py_TPFLAGS_DEFAULT, // tp_flags
+ 0, // tp_doc
+ 0, // tp_traverse
+ 0, // tp_clear
+ 0, // tp_richcompare
+ 0, // tp_weaklistoffset
+ PyObject_SelfIter, // tp_iter
+ (iternextfunc)Iterator_Next, // tp_iternext
+ 0, // tp_methods
+ 0, // tp_members
+ 0, // tp_getset
+ 0, // tp_base
+ 0, // tp_dict
+ 0, // tp_descr_get
+ 0, // tp_descr_set
+ 0, // tp_dictoffset
+ 0, // tp_init
+ 0, // tp_alloc
+ 0, // tp_new
+ 0, // tp_free
+};
+
+static PyObject* NewContainerIterator(PyContainer* container,
+ PyContainerIterator::IterKind kind) {
+ PyContainerIterator* self = PyObject_New(PyContainerIterator,
+ &ContainerIterator_Type);
+ if (self == NULL) {
+ return NULL;
+ }
+ Py_INCREF(container);
+ self->container = container;
+ self->kind = kind;
+ self->index = 0;
+
+ return reinterpret_cast<PyObject*>(self);
+}
+
+} // namespace descriptor
+
+// Now define the real collections!
+
+namespace message_descriptor {
+
+typedef const Descriptor* ParentDescriptor;
+
+static ParentDescriptor GetDescriptor(PyContainer* self) {
+ return reinterpret_cast<ParentDescriptor>(self->descriptor);
+}
+
+namespace fields {
+
+typedef const FieldDescriptor* ItemDescriptor;
+
+static int Count(PyContainer* self) {
+ return GetDescriptor(self)->field_count();
+}
+
+static ItemDescriptor GetByName(PyContainer* self, const string& name) {
+ return GetDescriptor(self)->FindFieldByName(name);
+}
+
+static ItemDescriptor GetByCamelcaseName(PyContainer* self,
+ const string& name) {
+ return GetDescriptor(self)->FindFieldByCamelcaseName(name);
+}
+
+static ItemDescriptor GetByNumber(PyContainer* self, int number) {
+ return GetDescriptor(self)->FindFieldByNumber(number);
+}
+
+static ItemDescriptor GetByIndex(PyContainer* self, int index) {
+ return GetDescriptor(self)->field(index);
+}
+
+static PyObject* NewObjectFromItem(ItemDescriptor item) {
+ return PyFieldDescriptor_FromDescriptor(item);
+}
+
+static const string& GetItemName(ItemDescriptor item) {
+ return item->name();
+}
+
+static const string& GetItemCamelcaseName(ItemDescriptor item) {
+ return item->camelcase_name();
+}
+
+static int GetItemNumber(ItemDescriptor item) {
+ return item->number();
+}
+
+static int GetItemIndex(ItemDescriptor item) {
+ return item->index();
+}
+
+static DescriptorContainerDef ContainerDef = {
+ "MessageFields",
+ (CountMethod)Count,
+ (GetByIndexMethod)GetByIndex,
+ (GetByNameMethod)GetByName,
+ (GetByCamelcaseNameMethod)GetByCamelcaseName,
+ (GetByNumberMethod)GetByNumber,
+ (NewObjectFromItemMethod)NewObjectFromItem,
+ (GetItemNameMethod)GetItemName,
+ (GetItemCamelcaseNameMethod)GetItemCamelcaseName,
+ (GetItemNumberMethod)GetItemNumber,
+ (GetItemIndexMethod)GetItemIndex,
+};
+
+} // namespace fields
+
+PyObject* NewMessageFieldsByName(ParentDescriptor descriptor) {
+ return descriptor::NewMappingByName(&fields::ContainerDef, descriptor);
+}
+
+PyObject* NewMessageFieldsByCamelcaseName(ParentDescriptor descriptor) {
+ return descriptor::NewMappingByCamelcaseName(&fields::ContainerDef,
+ descriptor);
+}
+
+PyObject* NewMessageFieldsByNumber(ParentDescriptor descriptor) {
+ return descriptor::NewMappingByNumber(&fields::ContainerDef, descriptor);
+}
+
+PyObject* NewMessageFieldsSeq(ParentDescriptor descriptor) {
+ return descriptor::NewSequence(&fields::ContainerDef, descriptor);
+}
+
+namespace nested_types {
+
+typedef const Descriptor* ItemDescriptor;
+
+static int Count(PyContainer* self) {
+ return GetDescriptor(self)->nested_type_count();
+}
+
+static ItemDescriptor GetByName(PyContainer* self, const string& name) {
+ return GetDescriptor(self)->FindNestedTypeByName(name);
+}
+
+static ItemDescriptor GetByIndex(PyContainer* self, int index) {
+ return GetDescriptor(self)->nested_type(index);
+}
+
+static PyObject* NewObjectFromItem(ItemDescriptor item) {
+ return PyMessageDescriptor_FromDescriptor(item);
+}
+
+static const string& GetItemName(ItemDescriptor item) {
+ return item->name();
+}
+
+static int GetItemIndex(ItemDescriptor item) {
+ return item->index();
+}
+
+static DescriptorContainerDef ContainerDef = {
+ "MessageNestedTypes",
+ (CountMethod)Count,
+ (GetByIndexMethod)GetByIndex,
+ (GetByNameMethod)GetByName,
+ (GetByCamelcaseNameMethod)NULL,
+ (GetByNumberMethod)NULL,
+ (NewObjectFromItemMethod)NewObjectFromItem,
+ (GetItemNameMethod)GetItemName,
+ (GetItemCamelcaseNameMethod)NULL,
+ (GetItemNumberMethod)NULL,
+ (GetItemIndexMethod)GetItemIndex,
+};
+
+} // namespace nested_types
+
+PyObject* NewMessageNestedTypesSeq(ParentDescriptor descriptor) {
+ return descriptor::NewSequence(&nested_types::ContainerDef, descriptor);
+}
+
+PyObject* NewMessageNestedTypesByName(ParentDescriptor descriptor) {
+ return descriptor::NewMappingByName(&nested_types::ContainerDef, descriptor);
+}
+
+namespace enums {
+
+typedef const EnumDescriptor* ItemDescriptor;
+
+static int Count(PyContainer* self) {
+ return GetDescriptor(self)->enum_type_count();
+}
+
+static ItemDescriptor GetByName(PyContainer* self, const string& name) {
+ return GetDescriptor(self)->FindEnumTypeByName(name);
+}
+
+static ItemDescriptor GetByIndex(PyContainer* self, int index) {
+ return GetDescriptor(self)->enum_type(index);
+}
+
+static PyObject* NewObjectFromItem(ItemDescriptor item) {
+ return PyEnumDescriptor_FromDescriptor(item);
+}
+
+static const string& GetItemName(ItemDescriptor item) {
+ return item->name();
+}
+
+static int GetItemIndex(ItemDescriptor item) {
+ return item->index();
+}
+
+static DescriptorContainerDef ContainerDef = {
+ "MessageNestedEnums",
+ (CountMethod)Count,
+ (GetByIndexMethod)GetByIndex,
+ (GetByNameMethod)GetByName,
+ (GetByCamelcaseNameMethod)NULL,
+ (GetByNumberMethod)NULL,
+ (NewObjectFromItemMethod)NewObjectFromItem,
+ (GetItemNameMethod)GetItemName,
+ (GetItemCamelcaseNameMethod)NULL,
+ (GetItemNumberMethod)NULL,
+ (GetItemIndexMethod)GetItemIndex,
+};
+
+} // namespace enums
+
+PyObject* NewMessageEnumsByName(ParentDescriptor descriptor) {
+ return descriptor::NewMappingByName(&enums::ContainerDef, descriptor);
+}
+
+PyObject* NewMessageEnumsSeq(ParentDescriptor descriptor) {
+ return descriptor::NewSequence(&enums::ContainerDef, descriptor);
+}
+
+namespace enumvalues {
+
+// This is the "enum_values_by_name" mapping, which collects values from all
+// enum types in a message.
+//
+// Note that the behavior of the C++ descriptor is different: it will search and
+// return the first value that matches the name, whereas the Python
+// implementation retrieves the last one.
+
+typedef const EnumValueDescriptor* ItemDescriptor;
+
+static int Count(PyContainer* self) {
+ int count = 0;
+ for (int i = 0; i < GetDescriptor(self)->enum_type_count(); ++i) {
+ count += GetDescriptor(self)->enum_type(i)->value_count();
+ }
+ return count;
+}
+
+static ItemDescriptor GetByName(PyContainer* self, const string& name) {
+ return GetDescriptor(self)->FindEnumValueByName(name);
+}
+
+static ItemDescriptor GetByIndex(PyContainer* self, int index) {
+ // This is not optimal, but the number of enums *types* in a given message
+ // is small. This function is only used when iterating over the mapping.
+ const EnumDescriptor* enum_type = NULL;
+ int enum_type_count = GetDescriptor(self)->enum_type_count();
+ for (int i = 0; i < enum_type_count; ++i) {
+ enum_type = GetDescriptor(self)->enum_type(i);
+ int enum_value_count = enum_type->value_count();
+ if (index < enum_value_count) {
+ // Found it!
+ break;
+ }
+ index -= enum_value_count;
+ }
+ // The next statement cannot overflow, because this function is only called by
+ // internal iterators which ensure that 0 <= index < Count().
+ return enum_type->value(index);
+}
+
+static PyObject* NewObjectFromItem(ItemDescriptor item) {
+ return PyEnumValueDescriptor_FromDescriptor(item);
+}
+
+static const string& GetItemName(ItemDescriptor item) {
+ return item->name();
+}
+
+static DescriptorContainerDef ContainerDef = {
+ "MessageEnumValues",
+ (CountMethod)Count,
+ (GetByIndexMethod)GetByIndex,
+ (GetByNameMethod)GetByName,
+ (GetByCamelcaseNameMethod)NULL,
+ (GetByNumberMethod)NULL,
+ (NewObjectFromItemMethod)NewObjectFromItem,
+ (GetItemNameMethod)GetItemName,
+ (GetItemCamelcaseNameMethod)NULL,
+ (GetItemNumberMethod)NULL,
+ (GetItemIndexMethod)NULL,
+};
+
+} // namespace enumvalues
+
+PyObject* NewMessageEnumValuesByName(ParentDescriptor descriptor) {
+ return descriptor::NewMappingByName(&enumvalues::ContainerDef, descriptor);
+}
+
+namespace extensions {
+
+typedef const FieldDescriptor* ItemDescriptor;
+
+static int Count(PyContainer* self) {
+ return GetDescriptor(self)->extension_count();
+}
+
+static ItemDescriptor GetByName(PyContainer* self, const string& name) {
+ return GetDescriptor(self)->FindExtensionByName(name);
+}
+
+static ItemDescriptor GetByIndex(PyContainer* self, int index) {
+ return GetDescriptor(self)->extension(index);
+}
+
+static PyObject* NewObjectFromItem(ItemDescriptor item) {
+ return PyFieldDescriptor_FromDescriptor(item);
+}
+
+static const string& GetItemName(ItemDescriptor item) {
+ return item->name();
+}
+
+static int GetItemIndex(ItemDescriptor item) {
+ return item->index();
+}
+
+static DescriptorContainerDef ContainerDef = {
+ "MessageExtensions",
+ (CountMethod)Count,
+ (GetByIndexMethod)GetByIndex,
+ (GetByNameMethod)GetByName,
+ (GetByCamelcaseNameMethod)NULL,
+ (GetByNumberMethod)NULL,
+ (NewObjectFromItemMethod)NewObjectFromItem,
+ (GetItemNameMethod)GetItemName,
+ (GetItemCamelcaseNameMethod)NULL,
+ (GetItemNumberMethod)NULL,
+ (GetItemIndexMethod)GetItemIndex,
+};
+
+} // namespace extensions
+
+PyObject* NewMessageExtensionsByName(ParentDescriptor descriptor) {
+ return descriptor::NewMappingByName(&extensions::ContainerDef, descriptor);
+}
+
+PyObject* NewMessageExtensionsSeq(ParentDescriptor descriptor) {
+ return descriptor::NewSequence(&extensions::ContainerDef, descriptor);
+}
+
+namespace oneofs {
+
+typedef const OneofDescriptor* ItemDescriptor;
+
+static int Count(PyContainer* self) {
+ return GetDescriptor(self)->oneof_decl_count();
+}
+
+static ItemDescriptor GetByName(PyContainer* self, const string& name) {
+ return GetDescriptor(self)->FindOneofByName(name);
+}
+
+static ItemDescriptor GetByIndex(PyContainer* self, int index) {
+ return GetDescriptor(self)->oneof_decl(index);
+}
+
+static PyObject* NewObjectFromItem(ItemDescriptor item) {
+ return PyOneofDescriptor_FromDescriptor(item);
+}
+
+static const string& GetItemName(ItemDescriptor item) {
+ return item->name();
+}
+
+static int GetItemIndex(ItemDescriptor item) {
+ return item->index();
+}
+
+static DescriptorContainerDef ContainerDef = {
+ "MessageOneofs",
+ (CountMethod)Count,
+ (GetByIndexMethod)GetByIndex,
+ (GetByNameMethod)GetByName,
+ (GetByCamelcaseNameMethod)NULL,
+ (GetByNumberMethod)NULL,
+ (NewObjectFromItemMethod)NewObjectFromItem,
+ (GetItemNameMethod)GetItemName,
+ (GetItemCamelcaseNameMethod)NULL,
+ (GetItemNumberMethod)NULL,
+ (GetItemIndexMethod)GetItemIndex,
+};
+
+} // namespace oneofs
+
+PyObject* NewMessageOneofsByName(ParentDescriptor descriptor) {
+ return descriptor::NewMappingByName(&oneofs::ContainerDef, descriptor);
+}
+
+PyObject* NewMessageOneofsSeq(ParentDescriptor descriptor) {
+ return descriptor::NewSequence(&oneofs::ContainerDef, descriptor);
+}
+
+} // namespace message_descriptor
+
+namespace enum_descriptor {
+
+typedef const EnumDescriptor* ParentDescriptor;
+
+static ParentDescriptor GetDescriptor(PyContainer* self) {
+ return reinterpret_cast<ParentDescriptor>(self->descriptor);
+}
+
+namespace enumvalues {
+
+typedef const EnumValueDescriptor* ItemDescriptor;
+
+static int Count(PyContainer* self) {
+ return GetDescriptor(self)->value_count();
+}
+
+static ItemDescriptor GetByIndex(PyContainer* self, int index) {
+ return GetDescriptor(self)->value(index);
+}
+
+static ItemDescriptor GetByName(PyContainer* self, const string& name) {
+ return GetDescriptor(self)->FindValueByName(name);
+}
+
+static ItemDescriptor GetByNumber(PyContainer* self, int number) {
+ return GetDescriptor(self)->FindValueByNumber(number);
+}
+
+static PyObject* NewObjectFromItem(ItemDescriptor item) {
+ return PyEnumValueDescriptor_FromDescriptor(item);
+}
+
+static const string& GetItemName(ItemDescriptor item) {
+ return item->name();
+}
+
+static int GetItemNumber(ItemDescriptor item) {
+ return item->number();
+}
+
+static int GetItemIndex(ItemDescriptor item) {
+ return item->index();
+}
+
+static DescriptorContainerDef ContainerDef = {
+ "EnumValues",
+ (CountMethod)Count,
+ (GetByIndexMethod)GetByIndex,
+ (GetByNameMethod)GetByName,
+ (GetByCamelcaseNameMethod)NULL,
+ (GetByNumberMethod)GetByNumber,
+ (NewObjectFromItemMethod)NewObjectFromItem,
+ (GetItemNameMethod)GetItemName,
+ (GetItemCamelcaseNameMethod)NULL,
+ (GetItemNumberMethod)GetItemNumber,
+ (GetItemIndexMethod)GetItemIndex,
+};
+
+} // namespace enumvalues
+
+PyObject* NewEnumValuesByName(ParentDescriptor descriptor) {
+ return descriptor::NewMappingByName(&enumvalues::ContainerDef, descriptor);
+}
+
+PyObject* NewEnumValuesByNumber(ParentDescriptor descriptor) {
+ return descriptor::NewMappingByNumber(&enumvalues::ContainerDef, descriptor);
+}
+
+PyObject* NewEnumValuesSeq(ParentDescriptor descriptor) {
+ return descriptor::NewSequence(&enumvalues::ContainerDef, descriptor);
+}
+
+} // namespace enum_descriptor
+
+namespace oneof_descriptor {
+
+typedef const OneofDescriptor* ParentDescriptor;
+
+static ParentDescriptor GetDescriptor(PyContainer* self) {
+ return reinterpret_cast<ParentDescriptor>(self->descriptor);
+}
+
+namespace fields {
+
+typedef const FieldDescriptor* ItemDescriptor;
+
+static int Count(PyContainer* self) {
+ return GetDescriptor(self)->field_count();
+}
+
+static ItemDescriptor GetByIndex(PyContainer* self, int index) {
+ return GetDescriptor(self)->field(index);
+}
+
+static PyObject* NewObjectFromItem(ItemDescriptor item) {
+ return PyFieldDescriptor_FromDescriptor(item);
+}
+
+static int GetItemIndex(ItemDescriptor item) {
+ return item->index_in_oneof();
+}
+
+static DescriptorContainerDef ContainerDef = {
+ "OneofFields",
+ (CountMethod)Count,
+ (GetByIndexMethod)GetByIndex,
+ (GetByNameMethod)NULL,
+ (GetByCamelcaseNameMethod)NULL,
+ (GetByNumberMethod)NULL,
+ (NewObjectFromItemMethod)NewObjectFromItem,
+ (GetItemNameMethod)NULL,
+ (GetItemCamelcaseNameMethod)NULL,
+ (GetItemNumberMethod)NULL,
+ (GetItemIndexMethod)GetItemIndex,
+};
+
+} // namespace fields
+
+PyObject* NewOneofFieldsSeq(ParentDescriptor descriptor) {
+ return descriptor::NewSequence(&fields::ContainerDef, descriptor);
+}
+
+} // namespace oneof_descriptor
+
+namespace file_descriptor {
+
+typedef const FileDescriptor* ParentDescriptor;
+
+static ParentDescriptor GetDescriptor(PyContainer* self) {
+ return reinterpret_cast<ParentDescriptor>(self->descriptor);
+}
+
+namespace messages {
+
+typedef const Descriptor* ItemDescriptor;
+
+static int Count(PyContainer* self) {
+ return GetDescriptor(self)->message_type_count();
+}
+
+static ItemDescriptor GetByName(PyContainer* self, const string& name) {
+ return GetDescriptor(self)->FindMessageTypeByName(name);
+}
+
+static ItemDescriptor GetByIndex(PyContainer* self, int index) {
+ return GetDescriptor(self)->message_type(index);
+}
+
+static PyObject* NewObjectFromItem(ItemDescriptor item) {
+ return PyMessageDescriptor_FromDescriptor(item);
+}
+
+static const string& GetItemName(ItemDescriptor item) {
+ return item->name();
+}
+
+static int GetItemIndex(ItemDescriptor item) {
+ return item->index();
+}
+
+static DescriptorContainerDef ContainerDef = {
+ "FileMessages",
+ (CountMethod)Count,
+ (GetByIndexMethod)GetByIndex,
+ (GetByNameMethod)GetByName,
+ (GetByCamelcaseNameMethod)NULL,
+ (GetByNumberMethod)NULL,
+ (NewObjectFromItemMethod)NewObjectFromItem,
+ (GetItemNameMethod)GetItemName,
+ (GetItemCamelcaseNameMethod)NULL,
+ (GetItemNumberMethod)NULL,
+ (GetItemIndexMethod)GetItemIndex,
+};
+
+} // namespace messages
+
+PyObject* NewFileMessageTypesByName(const FileDescriptor* descriptor) {
+ return descriptor::NewMappingByName(&messages::ContainerDef, descriptor);
+}
+
+namespace enums {
+
+typedef const EnumDescriptor* ItemDescriptor;
+
+static int Count(PyContainer* self) {
+ return GetDescriptor(self)->enum_type_count();
+}
+
+static ItemDescriptor GetByName(PyContainer* self, const string& name) {
+ return GetDescriptor(self)->FindEnumTypeByName(name);
+}
+
+static ItemDescriptor GetByIndex(PyContainer* self, int index) {
+ return GetDescriptor(self)->enum_type(index);
+}
+
+static PyObject* NewObjectFromItem(ItemDescriptor item) {
+ return PyEnumDescriptor_FromDescriptor(item);
+}
+
+static const string& GetItemName(ItemDescriptor item) {
+ return item->name();
+}
+
+static int GetItemIndex(ItemDescriptor item) {
+ return item->index();
+}
+
+static DescriptorContainerDef ContainerDef = {
+ "FileEnums",
+ (CountMethod)Count,
+ (GetByIndexMethod)GetByIndex,
+ (GetByNameMethod)GetByName,
+ (GetByCamelcaseNameMethod)NULL,
+ (GetByNumberMethod)NULL,
+ (NewObjectFromItemMethod)NewObjectFromItem,
+ (GetItemNameMethod)GetItemName,
+ (GetItemCamelcaseNameMethod)NULL,
+ (GetItemNumberMethod)NULL,
+ (GetItemIndexMethod)GetItemIndex,
+};
+
+} // namespace enums
+
+PyObject* NewFileEnumTypesByName(const FileDescriptor* descriptor) {
+ return descriptor::NewMappingByName(&enums::ContainerDef, descriptor);
+}
+
+namespace extensions {
+
+typedef const FieldDescriptor* ItemDescriptor;
+
+static int Count(PyContainer* self) {
+ return GetDescriptor(self)->extension_count();
+}
+
+static ItemDescriptor GetByName(PyContainer* self, const string& name) {
+ return GetDescriptor(self)->FindExtensionByName(name);
+}
+
+static ItemDescriptor GetByIndex(PyContainer* self, int index) {
+ return GetDescriptor(self)->extension(index);
+}
+
+static PyObject* NewObjectFromItem(ItemDescriptor item) {
+ return PyFieldDescriptor_FromDescriptor(item);
+}
+
+static const string& GetItemName(ItemDescriptor item) {
+ return item->name();
+}
+
+static int GetItemIndex(ItemDescriptor item) {
+ return item->index();
+}
+
+static DescriptorContainerDef ContainerDef = {
+ "FileExtensions",
+ (CountMethod)Count,
+ (GetByIndexMethod)GetByIndex,
+ (GetByNameMethod)GetByName,
+ (GetByCamelcaseNameMethod)NULL,
+ (GetByNumberMethod)NULL,
+ (NewObjectFromItemMethod)NewObjectFromItem,
+ (GetItemNameMethod)GetItemName,
+ (GetItemCamelcaseNameMethod)NULL,
+ (GetItemNumberMethod)NULL,
+ (GetItemIndexMethod)GetItemIndex,
+};
+
+} // namespace extensions
+
+PyObject* NewFileExtensionsByName(const FileDescriptor* descriptor) {
+ return descriptor::NewMappingByName(&extensions::ContainerDef, descriptor);
+}
+
+namespace dependencies {
+
+typedef const FileDescriptor* ItemDescriptor;
+
+static int Count(PyContainer* self) {
+ return GetDescriptor(self)->dependency_count();
+}
+
+static ItemDescriptor GetByIndex(PyContainer* self, int index) {
+ return GetDescriptor(self)->dependency(index);
+}
+
+static PyObject* NewObjectFromItem(ItemDescriptor item) {
+ return PyFileDescriptor_FromDescriptor(item);
+}
+
+static DescriptorContainerDef ContainerDef = {
+ "FileDependencies",
+ (CountMethod)Count,
+ (GetByIndexMethod)GetByIndex,
+ (GetByNameMethod)NULL,
+ (GetByCamelcaseNameMethod)NULL,
+ (GetByNumberMethod)NULL,
+ (NewObjectFromItemMethod)NewObjectFromItem,
+ (GetItemNameMethod)NULL,
+ (GetItemCamelcaseNameMethod)NULL,
+ (GetItemNumberMethod)NULL,
+ (GetItemIndexMethod)NULL,
+};
+
+} // namespace dependencies
+
+PyObject* NewFileDependencies(const FileDescriptor* descriptor) {
+ return descriptor::NewSequence(&dependencies::ContainerDef, descriptor);
+}
+
+namespace public_dependencies {
+
+typedef const FileDescriptor* ItemDescriptor;
+
+static int Count(PyContainer* self) {
+ return GetDescriptor(self)->public_dependency_count();
+}
+
+static ItemDescriptor GetByIndex(PyContainer* self, int index) {
+ return GetDescriptor(self)->public_dependency(index);
+}
+
+static PyObject* NewObjectFromItem(ItemDescriptor item) {
+ return PyFileDescriptor_FromDescriptor(item);
+}
+
+static DescriptorContainerDef ContainerDef = {
+ "FilePublicDependencies",
+ (CountMethod)Count,
+ (GetByIndexMethod)GetByIndex,
+ (GetByNameMethod)NULL,
+ (GetByCamelcaseNameMethod)NULL,
+ (GetByNumberMethod)NULL,
+ (NewObjectFromItemMethod)NewObjectFromItem,
+ (GetItemNameMethod)NULL,
+ (GetItemCamelcaseNameMethod)NULL,
+ (GetItemNumberMethod)NULL,
+ (GetItemIndexMethod)NULL,
+};
+
+} // namespace public_dependencies
+
+PyObject* NewFilePublicDependencies(const FileDescriptor* descriptor) {
+ return descriptor::NewSequence(&public_dependencies::ContainerDef,
+ descriptor);
+}
+
+} // namespace file_descriptor
+
+
+// Register all implementations
+
+bool InitDescriptorMappingTypes() {
+ if (PyType_Ready(&descriptor::DescriptorMapping_Type) < 0)
+ return false;
+ if (PyType_Ready(&descriptor::DescriptorSequence_Type) < 0)
+ return false;
+ if (PyType_Ready(&descriptor::ContainerIterator_Type) < 0)
+ return false;
+ return true;
+}
+
+} // namespace python
+} // namespace protobuf
+} // namespace google
diff --git a/python/google/protobuf/pyext/descriptor_containers.h b/python/google/protobuf/pyext/descriptor_containers.h
new file mode 100644
index 000000000..ce40747d5
--- /dev/null
+++ b/python/google/protobuf/pyext/descriptor_containers.h
@@ -0,0 +1,101 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// https://developers.google.com/protocol-buffers/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_CONTAINERS_H__
+#define GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_CONTAINERS_H__
+
+// Mappings and Sequences of descriptors.
+// They implement containers like fields_by_name, EnumDescriptor.values...
+// See descriptor_containers.cc for more description.
+#include <Python.h>
+
+namespace google {
+namespace protobuf {
+
+class Descriptor;
+class FileDescriptor;
+class EnumDescriptor;
+class OneofDescriptor;
+
+namespace python {
+
+// Initialize the various types and objects.
+bool InitDescriptorMappingTypes();
+
+// Each function below returns a Mapping, or a Sequence of descriptors.
+// They all return a new reference.
+
+namespace message_descriptor {
+PyObject* NewMessageFieldsByName(const Descriptor* descriptor);
+PyObject* NewMessageFieldsByCamelcaseName(const Descriptor* descriptor);
+PyObject* NewMessageFieldsByNumber(const Descriptor* descriptor);
+PyObject* NewMessageFieldsSeq(const Descriptor* descriptor);
+
+PyObject* NewMessageNestedTypesSeq(const Descriptor* descriptor);
+PyObject* NewMessageNestedTypesByName(const Descriptor* descriptor);
+
+PyObject* NewMessageEnumsByName(const Descriptor* descriptor);
+PyObject* NewMessageEnumsSeq(const Descriptor* descriptor);
+PyObject* NewMessageEnumValuesByName(const Descriptor* descriptor);
+
+PyObject* NewMessageExtensionsByName(const Descriptor* descriptor);
+PyObject* NewMessageExtensionsSeq(const Descriptor* descriptor);
+
+PyObject* NewMessageOneofsByName(const Descriptor* descriptor);
+PyObject* NewMessageOneofsSeq(const Descriptor* descriptor);
+} // namespace message_descriptor
+
+namespace enum_descriptor {
+PyObject* NewEnumValuesByName(const EnumDescriptor* descriptor);
+PyObject* NewEnumValuesByNumber(const EnumDescriptor* descriptor);
+PyObject* NewEnumValuesSeq(const EnumDescriptor* descriptor);
+} // namespace enum_descriptor
+
+namespace oneof_descriptor {
+PyObject* NewOneofFieldsSeq(const OneofDescriptor* descriptor);
+} // namespace oneof_descriptor
+
+namespace file_descriptor {
+PyObject* NewFileMessageTypesByName(const FileDescriptor* descriptor);
+
+PyObject* NewFileEnumTypesByName(const FileDescriptor* descriptor);
+
+PyObject* NewFileExtensionsByName(const FileDescriptor* descriptor);
+
+PyObject* NewFileDependencies(const FileDescriptor* descriptor);
+PyObject* NewFilePublicDependencies(const FileDescriptor* descriptor);
+} // namespace file_descriptor
+
+
+} // namespace python
+} // namespace protobuf
+
+} // namespace google
+#endif // GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_CONTAINERS_H__
diff --git a/python/google/protobuf/pyext/descriptor_cpp2_test.py b/python/google/protobuf/pyext/descriptor_cpp2_test.py
deleted file mode 100644
index 3cf45a226..000000000
--- a/python/google/protobuf/pyext/descriptor_cpp2_test.py
+++ /dev/null
@@ -1,58 +0,0 @@
-#! /usr/bin/python
-#
-# Protocol Buffers - Google's data interchange format
-# Copyright 2008 Google Inc. All rights reserved.
-# https://developers.google.com/protocol-buffers/
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following disclaimer
-# in the documentation and/or other materials provided with the
-# distribution.
-# * Neither the name of Google Inc. nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-"""Tests for google.protobuf.pyext behavior."""
-
-__author__ = 'anuraag@google.com (Anuraag Agrawal)'
-
-import os
-os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'cpp'
-os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION'] = '2'
-
-# We must set the implementation version above before the google3 imports.
-# pylint: disable=g-import-not-at-top
-from google.apputils import basetest
-from google.protobuf.internal import api_implementation
-# Run all tests from the original module by putting them in our namespace.
-# pylint: disable=wildcard-import
-from google.protobuf.internal.descriptor_test import *
-
-
-class ConfirmCppApi2Test(basetest.TestCase):
-
- def testImplementationSetting(self):
- self.assertEqual('cpp', api_implementation.Type())
- self.assertEqual(2, api_implementation.Version())
-
-
-if __name__ == '__main__':
- basetest.main()
diff --git a/python/google/protobuf/pyext/descriptor_database.cc b/python/google/protobuf/pyext/descriptor_database.cc
new file mode 100644
index 000000000..daa40cc72
--- /dev/null
+++ b/python/google/protobuf/pyext/descriptor_database.cc
@@ -0,0 +1,148 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// https://developers.google.com/protocol-buffers/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+// This file defines a C++ DescriptorDatabase, which wraps a Python Database
+// and delegate all its operations to Python methods.
+
+#include <google/protobuf/pyext/descriptor_database.h>
+
+#include <google/protobuf/stubs/logging.h>
+#include <google/protobuf/stubs/common.h>
+#include <google/protobuf/descriptor.pb.h>
+#include <google/protobuf/pyext/message.h>
+#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
+
+namespace google {
+namespace protobuf {
+namespace python {
+
+PyDescriptorDatabase::PyDescriptorDatabase(PyObject* py_database)
+ : py_database_(py_database) {
+ Py_INCREF(py_database_);
+}
+
+PyDescriptorDatabase::~PyDescriptorDatabase() { Py_DECREF(py_database_); }
+
+// Convert a Python object to a FileDescriptorProto pointer.
+// Handles all kinds of Python errors, which are simply logged.
+static bool GetFileDescriptorProto(PyObject* py_descriptor,
+ FileDescriptorProto* output) {
+ if (py_descriptor == NULL) {
+ if (PyErr_ExceptionMatches(PyExc_KeyError)) {
+ // Expected error: item was simply not found.
+ PyErr_Clear();
+ } else {
+ GOOGLE_LOG(ERROR) << "DescriptorDatabase method raised an error";
+ PyErr_Print();
+ }
+ return false;
+ }
+ if (py_descriptor == Py_None) {
+ return false;
+ }
+ const Descriptor* filedescriptor_descriptor =
+ FileDescriptorProto::default_instance().GetDescriptor();
+ CMessage* message = reinterpret_cast<CMessage*>(py_descriptor);
+ if (PyObject_TypeCheck(py_descriptor, &CMessage_Type) &&
+ message->message->GetDescriptor() == filedescriptor_descriptor) {
+ // Fast path: Just use the pointer.
+ FileDescriptorProto* file_proto =
+ static_cast<FileDescriptorProto*>(message->message);
+ *output = *file_proto;
+ return true;
+ } else {
+ // Slow path: serialize the message. This allows to use databases which
+ // use a different implementation of FileDescriptorProto.
+ ScopedPyObjectPtr serialized_pb(
+ PyObject_CallMethod(py_descriptor, "SerializeToString", NULL));
+ if (serialized_pb == NULL) {
+ GOOGLE_LOG(ERROR)
+ << "DescriptorDatabase method did not return a FileDescriptorProto";
+ PyErr_Print();
+ return false;
+ }
+ char* str;
+ Py_ssize_t len;
+ if (PyBytes_AsStringAndSize(serialized_pb.get(), &str, &len) < 0) {
+ GOOGLE_LOG(ERROR)
+ << "DescriptorDatabase method did not return a FileDescriptorProto";
+ PyErr_Print();
+ return false;
+ }
+ FileDescriptorProto file_proto;
+ if (!file_proto.ParseFromArray(str, len)) {
+ GOOGLE_LOG(ERROR)
+ << "DescriptorDatabase method did not return a FileDescriptorProto";
+ return false;
+ }
+ *output = file_proto;
+ return true;
+ }
+}
+
+// Find a file by file name.
+bool PyDescriptorDatabase::FindFileByName(const string& filename,
+ FileDescriptorProto* output) {
+ ScopedPyObjectPtr py_descriptor(PyObject_CallMethod(
+ py_database_, "FindFileByName", "s#", filename.c_str(), filename.size()));
+ return GetFileDescriptorProto(py_descriptor.get(), output);
+}
+
+// Find the file that declares the given fully-qualified symbol name.
+bool PyDescriptorDatabase::FindFileContainingSymbol(
+ const string& symbol_name, FileDescriptorProto* output) {
+ ScopedPyObjectPtr py_descriptor(
+ PyObject_CallMethod(py_database_, "FindFileContainingSymbol", "s#",
+ symbol_name.c_str(), symbol_name.size()));
+ return GetFileDescriptorProto(py_descriptor.get(), output);
+}
+
+// Find the file which defines an extension extending the given message type
+// with the given field number.
+// Python DescriptorDatabases are not required to implement this method.
+bool PyDescriptorDatabase::FindFileContainingExtension(
+ const string& containing_type, int field_number,
+ FileDescriptorProto* output) {
+ ScopedPyObjectPtr py_method(
+ PyObject_GetAttrString(py_database_, "FindFileContainingExtension"));
+ if (py_method == NULL) {
+ // This method is not implemented, returns without error.
+ PyErr_Clear();
+ return false;
+ }
+ ScopedPyObjectPtr py_descriptor(
+ PyObject_CallFunction(py_method.get(), "s#i", containing_type.c_str(),
+ containing_type.size(), field_number));
+ return GetFileDescriptorProto(py_descriptor.get(), output);
+}
+
+} // namespace python
+} // namespace protobuf
+} // namespace google
diff --git a/python/google/protobuf/pyext/descriptor_database.h b/python/google/protobuf/pyext/descriptor_database.h
new file mode 100644
index 000000000..fc71c4bcb
--- /dev/null
+++ b/python/google/protobuf/pyext/descriptor_database.h
@@ -0,0 +1,75 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// https://developers.google.com/protocol-buffers/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_DATABASE_H__
+#define GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_DATABASE_H__
+
+#include <Python.h>
+
+#include <google/protobuf/descriptor_database.h>
+
+namespace google {
+namespace protobuf {
+namespace python {
+
+class PyDescriptorDatabase : public DescriptorDatabase {
+ public:
+ explicit PyDescriptorDatabase(PyObject* py_database);
+ ~PyDescriptorDatabase();
+
+ // Implement the abstract interface. All these functions fill the output
+ // with a copy of FileDescriptorProto.
+
+ // Find a file by file name.
+ bool FindFileByName(const string& filename,
+ FileDescriptorProto* output);
+
+ // Find the file that declares the given fully-qualified symbol name.
+ bool FindFileContainingSymbol(const string& symbol_name,
+ FileDescriptorProto* output);
+
+ // Find the file which defines an extension extending the given message type
+ // with the given field number.
+ // Containing_type must be a fully-qualified type name.
+ // Python objects are not required to implement this method.
+ bool FindFileContainingExtension(const string& containing_type,
+ int field_number,
+ FileDescriptorProto* output);
+
+ private:
+ // The python object that implements the database. The reference is owned.
+ PyObject* py_database_;
+};
+
+} // namespace python
+} // namespace protobuf
+
+} // namespace google
+#endif // GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_DATABASE_H__
diff --git a/python/google/protobuf/pyext/descriptor_pool.cc b/python/google/protobuf/pyext/descriptor_pool.cc
new file mode 100644
index 000000000..1faff96bc
--- /dev/null
+++ b/python/google/protobuf/pyext/descriptor_pool.cc
@@ -0,0 +1,593 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// https://developers.google.com/protocol-buffers/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+// Implements the DescriptorPool, which collects all descriptors.
+
+#include <Python.h>
+
+#include <google/protobuf/descriptor.pb.h>
+#include <google/protobuf/dynamic_message.h>
+#include <google/protobuf/pyext/descriptor.h>
+#include <google/protobuf/pyext/descriptor_database.h>
+#include <google/protobuf/pyext/descriptor_pool.h>
+#include <google/protobuf/pyext/message.h>
+#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
+
+#if PY_MAJOR_VERSION >= 3
+ #define PyString_FromStringAndSize PyUnicode_FromStringAndSize
+ #if PY_VERSION_HEX < 0x03030000
+ #error "Python 3.0 - 3.2 are not supported."
+ #endif
+ #define PyString_AsStringAndSize(ob, charpp, sizep) \
+ (PyUnicode_Check(ob)? \
+ ((*(charpp) = PyUnicode_AsUTF8AndSize(ob, (sizep))) == NULL? -1: 0): \
+ PyBytes_AsStringAndSize(ob, (charpp), (sizep)))
+#endif
+
+namespace google {
+namespace protobuf {
+namespace python {
+
+// A map to cache Python Pools per C++ pointer.
+// Pointers are not owned here, and belong to the PyDescriptorPool.
+static hash_map<const DescriptorPool*, PyDescriptorPool*> descriptor_pool_map;
+
+namespace cdescriptor_pool {
+
+// Create a Python DescriptorPool object, but does not fill the "pool"
+// attribute.
+static PyDescriptorPool* _CreateDescriptorPool() {
+ PyDescriptorPool* cpool = PyObject_New(
+ PyDescriptorPool, &PyDescriptorPool_Type);
+ if (cpool == NULL) {
+ return NULL;
+ }
+
+ cpool->underlay = NULL;
+ cpool->database = NULL;
+
+ DynamicMessageFactory* message_factory = new DynamicMessageFactory();
+ // This option might be the default some day.
+ message_factory->SetDelegateToGeneratedFactory(true);
+ cpool->message_factory = message_factory;
+
+ // TODO(amauryfa): Rewrite the SymbolDatabase in C so that it uses the same
+ // storage.
+ cpool->classes_by_descriptor =
+ new PyDescriptorPool::ClassesByMessageMap();
+ cpool->descriptor_options =
+ new hash_map<const void*, PyObject *>();
+
+ return cpool;
+}
+
+// Create a Python DescriptorPool, using the given pool as an underlay:
+// new messages will be added to a custom pool, not to the underlay.
+//
+// Ownership of the underlay is not transferred, its pointer should
+// stay alive.
+static PyDescriptorPool* PyDescriptorPool_NewWithUnderlay(
+ const DescriptorPool* underlay) {
+ PyDescriptorPool* cpool = _CreateDescriptorPool();
+ if (cpool == NULL) {
+ return NULL;
+ }
+ cpool->pool = new DescriptorPool(underlay);
+ cpool->underlay = underlay;
+
+ if (!descriptor_pool_map.insert(
+ std::make_pair(cpool->pool, cpool)).second) {
+ // Should never happen -- would indicate an internal error / bug.
+ PyErr_SetString(PyExc_ValueError, "DescriptorPool already registered");
+ return NULL;
+ }
+
+ return cpool;
+}
+
+static PyDescriptorPool* PyDescriptorPool_NewWithDatabase(
+ DescriptorDatabase* database) {
+ PyDescriptorPool* cpool = _CreateDescriptorPool();
+ if (cpool == NULL) {
+ return NULL;
+ }
+ if (database != NULL) {
+ cpool->pool = new DescriptorPool(database);
+ cpool->database = database;
+ } else {
+ cpool->pool = new DescriptorPool();
+ }
+
+ if (!descriptor_pool_map.insert(std::make_pair(cpool->pool, cpool)).second) {
+ // Should never happen -- would indicate an internal error / bug.
+ PyErr_SetString(PyExc_ValueError, "DescriptorPool already registered");
+ return NULL;
+ }
+
+ return cpool;
+}
+
+// The public DescriptorPool constructor.
+static PyObject* New(PyTypeObject* type,
+ PyObject* args, PyObject* kwargs) {
+ static char* kwlist[] = {"descriptor_db", 0};
+ PyObject* py_database = NULL;
+ if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O", kwlist, &py_database)) {
+ return NULL;
+ }
+ DescriptorDatabase* database = NULL;
+ if (py_database && py_database != Py_None) {
+ database = new PyDescriptorDatabase(py_database);
+ }
+ return reinterpret_cast<PyObject*>(
+ PyDescriptorPool_NewWithDatabase(database));
+}
+
+static void Dealloc(PyDescriptorPool* self) {
+ typedef PyDescriptorPool::ClassesByMessageMap::iterator iterator;
+ descriptor_pool_map.erase(self->pool);
+ for (iterator it = self->classes_by_descriptor->begin();
+ it != self->classes_by_descriptor->end(); ++it) {
+ Py_DECREF(it->second);
+ }
+ delete self->classes_by_descriptor;
+ for (hash_map<const void*, PyObject*>::iterator it =
+ self->descriptor_options->begin();
+ it != self->descriptor_options->end(); ++it) {
+ Py_DECREF(it->second);
+ }
+ delete self->descriptor_options;
+ delete self->message_factory;
+ delete self->database;
+ delete self->pool;
+ Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
+}
+
+PyObject* FindMessageByName(PyDescriptorPool* self, PyObject* arg) {
+ Py_ssize_t name_size;
+ char* name;
+ if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) {
+ return NULL;
+ }
+
+ const Descriptor* message_descriptor =
+ self->pool->FindMessageTypeByName(string(name, name_size));
+
+ if (message_descriptor == NULL) {
+ PyErr_Format(PyExc_KeyError, "Couldn't find message %.200s", name);
+ return NULL;
+ }
+
+ return PyMessageDescriptor_FromDescriptor(message_descriptor);
+}
+
+// Add a message class to our database.
+int RegisterMessageClass(PyDescriptorPool* self,
+ const Descriptor* message_descriptor,
+ CMessageClass* message_class) {
+ Py_INCREF(message_class);
+ typedef PyDescriptorPool::ClassesByMessageMap::iterator iterator;
+ std::pair<iterator, bool> ret = self->classes_by_descriptor->insert(
+ std::make_pair(message_descriptor, message_class));
+ if (!ret.second) {
+ // Update case: DECREF the previous value.
+ Py_DECREF(ret.first->second);
+ ret.first->second = message_class;
+ }
+ return 0;
+}
+
+// Retrieve the message class added to our database.
+CMessageClass* GetMessageClass(PyDescriptorPool* self,
+ const Descriptor* message_descriptor) {
+ typedef PyDescriptorPool::ClassesByMessageMap::iterator iterator;
+ iterator ret = self->classes_by_descriptor->find(message_descriptor);
+ if (ret == self->classes_by_descriptor->end()) {
+ PyErr_Format(PyExc_TypeError, "No message class registered for '%s'",
+ message_descriptor->full_name().c_str());
+ return NULL;
+ } else {
+ return ret->second;
+ }
+}
+
+PyObject* FindFileByName(PyDescriptorPool* self, PyObject* arg) {
+ Py_ssize_t name_size;
+ char* name;
+ if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) {
+ return NULL;
+ }
+
+ const FileDescriptor* file_descriptor =
+ self->pool->FindFileByName(string(name, name_size));
+ if (file_descriptor == NULL) {
+ PyErr_Format(PyExc_KeyError, "Couldn't find file %.200s",
+ name);
+ return NULL;
+ }
+
+ return PyFileDescriptor_FromDescriptor(file_descriptor);
+}
+
+PyObject* FindFieldByName(PyDescriptorPool* self, PyObject* arg) {
+ Py_ssize_t name_size;
+ char* name;
+ if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) {
+ return NULL;
+ }
+
+ const FieldDescriptor* field_descriptor =
+ self->pool->FindFieldByName(string(name, name_size));
+ if (field_descriptor == NULL) {
+ PyErr_Format(PyExc_KeyError, "Couldn't find field %.200s",
+ name);
+ return NULL;
+ }
+
+ return PyFieldDescriptor_FromDescriptor(field_descriptor);
+}
+
+PyObject* FindExtensionByName(PyDescriptorPool* self, PyObject* arg) {
+ Py_ssize_t name_size;
+ char* name;
+ if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) {
+ return NULL;
+ }
+
+ const FieldDescriptor* field_descriptor =
+ self->pool->FindExtensionByName(string(name, name_size));
+ if (field_descriptor == NULL) {
+ PyErr_Format(PyExc_KeyError, "Couldn't find extension field %.200s", name);
+ return NULL;
+ }
+
+ return PyFieldDescriptor_FromDescriptor(field_descriptor);
+}
+
+PyObject* FindEnumTypeByName(PyDescriptorPool* self, PyObject* arg) {
+ Py_ssize_t name_size;
+ char* name;
+ if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) {
+ return NULL;
+ }
+
+ const EnumDescriptor* enum_descriptor =
+ self->pool->FindEnumTypeByName(string(name, name_size));
+ if (enum_descriptor == NULL) {
+ PyErr_Format(PyExc_KeyError, "Couldn't find enum %.200s", name);
+ return NULL;
+ }
+
+ return PyEnumDescriptor_FromDescriptor(enum_descriptor);
+}
+
+PyObject* FindOneofByName(PyDescriptorPool* self, PyObject* arg) {
+ Py_ssize_t name_size;
+ char* name;
+ if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) {
+ return NULL;
+ }
+
+ const OneofDescriptor* oneof_descriptor =
+ self->pool->FindOneofByName(string(name, name_size));
+ if (oneof_descriptor == NULL) {
+ PyErr_Format(PyExc_KeyError, "Couldn't find oneof %.200s", name);
+ return NULL;
+ }
+
+ return PyOneofDescriptor_FromDescriptor(oneof_descriptor);
+}
+
+PyObject* FindFileContainingSymbol(PyDescriptorPool* self, PyObject* arg) {
+ Py_ssize_t name_size;
+ char* name;
+ if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) {
+ return NULL;
+ }
+
+ const FileDescriptor* file_descriptor =
+ self->pool->FindFileContainingSymbol(string(name, name_size));
+ if (file_descriptor == NULL) {
+ PyErr_Format(PyExc_KeyError, "Couldn't find symbol %.200s", name);
+ return NULL;
+ }
+
+ return PyFileDescriptor_FromDescriptor(file_descriptor);
+}
+
+// These functions should not exist -- the only valid way to create
+// descriptors is to call Add() or AddSerializedFile().
+// But these AddDescriptor() functions were created in Python and some people
+// call them, so we support them for now for compatibility.
+// However we do check that the existing descriptor already exists in the pool,
+// which appears to always be true for existing calls -- but then why do people
+// call a function that will just be a no-op?
+// TODO(amauryfa): Need to investigate further.
+
+PyObject* AddFileDescriptor(PyDescriptorPool* self, PyObject* descriptor) {
+ const FileDescriptor* file_descriptor =
+ PyFileDescriptor_AsDescriptor(descriptor);
+ if (!file_descriptor) {
+ return NULL;
+ }
+ if (file_descriptor !=
+ self->pool->FindFileByName(file_descriptor->name())) {
+ PyErr_Format(PyExc_ValueError,
+ "The file descriptor %s does not belong to this pool",
+ file_descriptor->name().c_str());
+ return NULL;
+ }
+ Py_RETURN_NONE;
+}
+
+PyObject* AddDescriptor(PyDescriptorPool* self, PyObject* descriptor) {
+ const Descriptor* message_descriptor =
+ PyMessageDescriptor_AsDescriptor(descriptor);
+ if (!message_descriptor) {
+ return NULL;
+ }
+ if (message_descriptor !=
+ self->pool->FindMessageTypeByName(message_descriptor->full_name())) {
+ PyErr_Format(PyExc_ValueError,
+ "The message descriptor %s does not belong to this pool",
+ message_descriptor->full_name().c_str());
+ return NULL;
+ }
+ Py_RETURN_NONE;
+}
+
+PyObject* AddEnumDescriptor(PyDescriptorPool* self, PyObject* descriptor) {
+ const EnumDescriptor* enum_descriptor =
+ PyEnumDescriptor_AsDescriptor(descriptor);
+ if (!enum_descriptor) {
+ return NULL;
+ }
+ if (enum_descriptor !=
+ self->pool->FindEnumTypeByName(enum_descriptor->full_name())) {
+ PyErr_Format(PyExc_ValueError,
+ "The enum descriptor %s does not belong to this pool",
+ enum_descriptor->full_name().c_str());
+ return NULL;
+ }
+ Py_RETURN_NONE;
+}
+
+// The code below loads new Descriptors from a serialized FileDescriptorProto.
+
+
+// Collects errors that occur during proto file building to allow them to be
+// propagated in the python exception instead of only living in ERROR logs.
+class BuildFileErrorCollector : public DescriptorPool::ErrorCollector {
+ public:
+ BuildFileErrorCollector() : error_message(""), had_errors(false) {}
+
+ void AddError(const string& filename, const string& element_name,
+ const Message* descriptor, ErrorLocation location,
+ const string& message) {
+ // Replicates the logging behavior that happens in the C++ implementation
+ // when an error collector is not passed in.
+ if (!had_errors) {
+ error_message +=
+ ("Invalid proto descriptor for file \"" + filename + "\":\n");
+ had_errors = true;
+ }
+ // As this only happens on failure and will result in the program not
+ // running at all, no effort is made to optimize this string manipulation.
+ error_message += (" " + element_name + ": " + message + "\n");
+ }
+
+ string error_message;
+ bool had_errors;
+};
+
+PyObject* AddSerializedFile(PyDescriptorPool* self, PyObject* serialized_pb) {
+ char* message_type;
+ Py_ssize_t message_len;
+
+ if (self->database != NULL) {
+ PyErr_SetString(
+ PyExc_ValueError,
+ "Cannot call Add on a DescriptorPool that uses a DescriptorDatabase. "
+ "Add your file to the underlying database.");
+ return NULL;
+ }
+
+ if (PyBytes_AsStringAndSize(serialized_pb, &message_type, &message_len) < 0) {
+ return NULL;
+ }
+
+ FileDescriptorProto file_proto;
+ if (!file_proto.ParseFromArray(message_type, message_len)) {
+ PyErr_SetString(PyExc_TypeError, "Couldn't parse file content!");
+ return NULL;
+ }
+
+ // If the file was already part of a C++ library, all its descriptors are in
+ // the underlying pool. No need to do anything else.
+ const FileDescriptor* generated_file = NULL;
+ if (self->underlay) {
+ generated_file = self->underlay->FindFileByName(file_proto.name());
+ }
+ if (generated_file != NULL) {
+ return PyFileDescriptor_FromDescriptorWithSerializedPb(
+ generated_file, serialized_pb);
+ }
+
+ BuildFileErrorCollector error_collector;
+ const FileDescriptor* descriptor =
+ self->pool->BuildFileCollectingErrors(file_proto,
+ &error_collector);
+ if (descriptor == NULL) {
+ PyErr_Format(PyExc_TypeError,
+ "Couldn't build proto file into descriptor pool!\n%s",
+ error_collector.error_message.c_str());
+ return NULL;
+ }
+
+ return PyFileDescriptor_FromDescriptorWithSerializedPb(
+ descriptor, serialized_pb);
+}
+
+PyObject* Add(PyDescriptorPool* self, PyObject* file_descriptor_proto) {
+ ScopedPyObjectPtr serialized_pb(
+ PyObject_CallMethod(file_descriptor_proto, "SerializeToString", NULL));
+ if (serialized_pb == NULL) {
+ return NULL;
+ }
+ return AddSerializedFile(self, serialized_pb.get());
+}
+
+static PyMethodDef Methods[] = {
+ { "Add", (PyCFunction)Add, METH_O,
+ "Adds the FileDescriptorProto and its types to this pool." },
+ { "AddSerializedFile", (PyCFunction)AddSerializedFile, METH_O,
+ "Adds a serialized FileDescriptorProto to this pool." },
+
+ // TODO(amauryfa): Understand why the Python implementation differs from
+ // this one, ask users to use another API and deprecate these functions.
+ { "AddFileDescriptor", (PyCFunction)AddFileDescriptor, METH_O,
+ "No-op. Add() must have been called before." },
+ { "AddDescriptor", (PyCFunction)AddDescriptor, METH_O,
+ "No-op. Add() must have been called before." },
+ { "AddEnumDescriptor", (PyCFunction)AddEnumDescriptor, METH_O,
+ "No-op. Add() must have been called before." },
+
+ { "FindFileByName", (PyCFunction)FindFileByName, METH_O,
+ "Searches for a file descriptor by its .proto name." },
+ { "FindMessageTypeByName", (PyCFunction)FindMessageByName, METH_O,
+ "Searches for a message descriptor by full name." },
+ { "FindFieldByName", (PyCFunction)FindFieldByName, METH_O,
+ "Searches for a field descriptor by full name." },
+ { "FindExtensionByName", (PyCFunction)FindExtensionByName, METH_O,
+ "Searches for extension descriptor by full name." },
+ { "FindEnumTypeByName", (PyCFunction)FindEnumTypeByName, METH_O,
+ "Searches for enum type descriptor by full name." },
+ { "FindOneofByName", (PyCFunction)FindOneofByName, METH_O,
+ "Searches for oneof descriptor by full name." },
+
+ { "FindFileContainingSymbol", (PyCFunction)FindFileContainingSymbol, METH_O,
+ "Gets the FileDescriptor containing the specified symbol." },
+ {NULL}
+};
+
+} // namespace cdescriptor_pool
+
+PyTypeObject PyDescriptorPool_Type = {
+ PyVarObject_HEAD_INIT(&PyType_Type, 0)
+ FULL_MODULE_NAME ".DescriptorPool", // tp_name
+ sizeof(PyDescriptorPool), // tp_basicsize
+ 0, // tp_itemsize
+ (destructor)cdescriptor_pool::Dealloc, // tp_dealloc
+ 0, // tp_print
+ 0, // tp_getattr
+ 0, // tp_setattr
+ 0, // tp_compare
+ 0, // tp_repr
+ 0, // tp_as_number
+ 0, // tp_as_sequence
+ 0, // tp_as_mapping
+ 0, // tp_hash
+ 0, // tp_call
+ 0, // tp_str
+ 0, // tp_getattro
+ 0, // tp_setattro
+ 0, // tp_as_buffer
+ Py_TPFLAGS_DEFAULT, // tp_flags
+ "A Descriptor Pool", // tp_doc
+ 0, // tp_traverse
+ 0, // tp_clear
+ 0, // tp_richcompare
+ 0, // tp_weaklistoffset
+ 0, // tp_iter
+ 0, // tp_iternext
+ cdescriptor_pool::Methods, // tp_methods
+ 0, // tp_members
+ 0, // tp_getset
+ 0, // tp_base
+ 0, // tp_dict
+ 0, // tp_descr_get
+ 0, // tp_descr_set
+ 0, // tp_dictoffset
+ 0, // tp_init
+ 0, // tp_alloc
+ cdescriptor_pool::New, // tp_new
+ PyObject_Del, // tp_free
+};
+
+// This is the DescriptorPool which contains all the definitions from the
+// generated _pb2.py modules.
+static PyDescriptorPool* python_generated_pool = NULL;
+
+bool InitDescriptorPool() {
+ if (PyType_Ready(&PyDescriptorPool_Type) < 0)
+ return false;
+
+ // The Pool of messages declared in Python libraries.
+ // generated_pool() contains all messages already linked in C++ libraries, and
+ // is used as underlay.
+ python_generated_pool = cdescriptor_pool::PyDescriptorPool_NewWithUnderlay(
+ DescriptorPool::generated_pool());
+ if (python_generated_pool == NULL) {
+ return false;
+ }
+ // Register this pool to be found for C++-generated descriptors.
+ descriptor_pool_map.insert(
+ std::make_pair(DescriptorPool::generated_pool(),
+ python_generated_pool));
+
+ return true;
+}
+
+// The default DescriptorPool used everywhere in this module.
+// Today it's the python_generated_pool.
+// TODO(amauryfa): Remove all usages of this function: the pool should be
+// derived from the context.
+PyDescriptorPool* GetDefaultDescriptorPool() {
+ return python_generated_pool;
+}
+
+PyDescriptorPool* GetDescriptorPool_FromPool(const DescriptorPool* pool) {
+ // Fast path for standard descriptors.
+ if (pool == python_generated_pool->pool ||
+ pool == DescriptorPool::generated_pool()) {
+ return python_generated_pool;
+ }
+ hash_map<const DescriptorPool*, PyDescriptorPool*>::iterator it =
+ descriptor_pool_map.find(pool);
+ if (it == descriptor_pool_map.end()) {
+ PyErr_SetString(PyExc_KeyError, "Unknown descriptor pool");
+ return NULL;
+ }
+ return it->second;
+}
+
+} // namespace python
+} // namespace protobuf
+} // namespace google
diff --git a/python/google/protobuf/pyext/descriptor_pool.h b/python/google/protobuf/pyext/descriptor_pool.h
new file mode 100644
index 000000000..2a42c1126
--- /dev/null
+++ b/python/google/protobuf/pyext/descriptor_pool.h
@@ -0,0 +1,167 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// https://developers.google.com/protocol-buffers/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_POOL_H__
+#define GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_POOL_H__
+
+#include <Python.h>
+
+#include <google/protobuf/stubs/hash.h>
+#include <google/protobuf/descriptor.h>
+
+namespace google {
+namespace protobuf {
+class MessageFactory;
+
+namespace python {
+
+// The (meta) type of all Messages classes.
+struct CMessageClass;
+
+// Wraps operations to the global DescriptorPool which contains information
+// about all messages and fields.
+//
+// There is normally one pool per process. We make it a Python object only
+// because it contains many Python references.
+// TODO(amauryfa): See whether such objects can appear in reference cycles, and
+// consider adding support for the cyclic GC.
+//
+// "Methods" that interacts with this DescriptorPool are in the cdescriptor_pool
+// namespace.
+typedef struct PyDescriptorPool {
+ PyObject_HEAD
+
+ // The C++ pool containing Descriptors.
+ DescriptorPool* pool;
+
+ // The C++ pool acting as an underlay. Can be NULL.
+ // This pointer is not owned and must stay alive.
+ const DescriptorPool* underlay;
+
+ // The C++ descriptor database used to fetch unknown protos. Can be NULL.
+ // This pointer is owned.
+ const DescriptorDatabase* database;
+
+ // DynamicMessageFactory used to create C++ instances of messages.
+ // This object cache the descriptors that were used, so the DescriptorPool
+ // needs to get rid of it before it can delete itself.
+ //
+ // Note: A C++ MessageFactory is different from the Python MessageFactory.
+ // The C++ one creates messages, when the Python one creates classes.
+ MessageFactory* message_factory;
+
+ // Make our own mapping to retrieve Python classes from C++ descriptors.
+ //
+ // Descriptor pointers stored here are owned by the DescriptorPool above.
+ // Python references to classes are owned by this PyDescriptorPool.
+ typedef hash_map<const Descriptor*, CMessageClass*> ClassesByMessageMap;
+ ClassesByMessageMap* classes_by_descriptor;
+
+ // Cache the options for any kind of descriptor.
+ // Descriptor pointers are owned by the DescriptorPool above.
+ // Python objects are owned by the map.
+ hash_map<const void*, PyObject*>* descriptor_options;
+} PyDescriptorPool;
+
+
+extern PyTypeObject PyDescriptorPool_Type;
+
+namespace cdescriptor_pool {
+
+// Looks up a message by name.
+// Returns a message Descriptor, or NULL if not found.
+const Descriptor* FindMessageTypeByName(PyDescriptorPool* self,
+ const string& name);
+
+// Registers a new Python class for the given message descriptor.
+// On error, returns -1 with a Python exception set.
+int RegisterMessageClass(PyDescriptorPool* self,
+ const Descriptor* message_descriptor,
+ CMessageClass* message_class);
+
+// Retrieves the Python class registered with the given message descriptor.
+//
+// Returns a *borrowed* reference if found, otherwise returns NULL with an
+// exception set.
+CMessageClass* GetMessageClass(PyDescriptorPool* self,
+ const Descriptor* message_descriptor);
+
+// The functions below are also exposed as methods of the DescriptorPool type.
+
+// Looks up a message by name. Returns a PyMessageDescriptor corresponding to
+// the field on success, or NULL on failure.
+//
+// Returns a new reference.
+PyObject* FindMessageByName(PyDescriptorPool* self, PyObject* name);
+
+// Looks up a field by name. Returns a PyFieldDescriptor corresponding to
+// the field on success, or NULL on failure.
+//
+// Returns a new reference.
+PyObject* FindFieldByName(PyDescriptorPool* self, PyObject* name);
+
+// Looks up an extension by name. Returns a PyFieldDescriptor corresponding
+// to the field on success, or NULL on failure.
+//
+// Returns a new reference.
+PyObject* FindExtensionByName(PyDescriptorPool* self, PyObject* arg);
+
+// Looks up an enum type by name. Returns a PyEnumDescriptor corresponding
+// to the field on success, or NULL on failure.
+//
+// Returns a new reference.
+PyObject* FindEnumTypeByName(PyDescriptorPool* self, PyObject* arg);
+
+// Looks up a oneof by name. Returns a COneofDescriptor corresponding
+// to the oneof on success, or NULL on failure.
+//
+// Returns a new reference.
+PyObject* FindOneofByName(PyDescriptorPool* self, PyObject* arg);
+
+} // namespace cdescriptor_pool
+
+// Retrieve the global descriptor pool owned by the _message module.
+// This is the one used by pb2.py generated modules.
+// Returns a *borrowed* reference.
+// "Default" pool used to register messages from _pb2.py modules.
+PyDescriptorPool* GetDefaultDescriptorPool();
+
+// Retrieve the python descriptor pool owning a C++ descriptor pool.
+// Returns a *borrowed* reference.
+PyDescriptorPool* GetDescriptorPool_FromPool(const DescriptorPool* pool);
+
+// Initialize objects used by this module.
+bool InitDescriptorPool();
+
+} // namespace python
+} // namespace protobuf
+
+} // namespace google
+#endif // GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_POOL_H__
diff --git a/python/google/protobuf/pyext/extension_dict.cc b/python/google/protobuf/pyext/extension_dict.cc
index 3861c7943..21bbb8c2b 100644
--- a/python/google/protobuf/pyext/extension_dict.cc
+++ b/python/google/protobuf/pyext/extension_dict.cc
@@ -33,11 +33,13 @@
#include <google/protobuf/pyext/extension_dict.h>
+#include <google/protobuf/stubs/logging.h>
#include <google/protobuf/stubs/common.h>
#include <google/protobuf/descriptor.h>
#include <google/protobuf/dynamic_message.h>
#include <google/protobuf/message.h>
#include <google/protobuf/pyext/descriptor.h>
+#include <google/protobuf/pyext/descriptor_pool.h>
#include <google/protobuf/pyext/message.h>
#include <google/protobuf/pyext/repeated_composite_container.h>
#include <google/protobuf/pyext/repeated_scalar_container.h>
@@ -48,36 +50,8 @@ namespace google {
namespace protobuf {
namespace python {
-extern google::protobuf::DynamicMessageFactory* global_message_factory;
-
namespace extension_dict {
-// TODO(tibell): Always use self->message for clarity, just like in
-// RepeatedCompositeContainer.
-static google::protobuf::Message* GetMessage(ExtensionDict* self) {
- if (self->parent != NULL) {
- return self->parent->message;
- } else {
- return self->message;
- }
-}
-
-CFieldDescriptor* InternalGetCDescriptorFromExtension(PyObject* extension) {
- PyObject* cdescriptor = PyObject_GetAttrString(extension, "_cdescriptor");
- if (cdescriptor == NULL) {
- PyErr_SetString(PyExc_KeyError, "Unregistered extension.");
- return NULL;
- }
- if (!PyObject_TypeCheck(cdescriptor, &CFieldDescriptor_Type)) {
- PyErr_SetString(PyExc_TypeError, "Not a CFieldDescriptor");
- Py_DECREF(cdescriptor);
- return NULL;
- }
- CFieldDescriptor* descriptor =
- reinterpret_cast<CFieldDescriptor*>(cdescriptor);
- return descriptor;
-}
-
PyObject* len(ExtensionDict* self) {
#if PY_MAJOR_VERSION >= 3
return PyLong_FromLong(PyDict_Size(self->values));
@@ -89,10 +63,9 @@ PyObject* len(ExtensionDict* self) {
// TODO(tibell): Use VisitCompositeField.
int ReleaseExtension(ExtensionDict* self,
PyObject* extension,
- const google::protobuf::FieldDescriptor* descriptor) {
- if (descriptor->label() == google::protobuf::FieldDescriptor::LABEL_REPEATED) {
- if (descriptor->cpp_type() ==
- google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) {
+ const FieldDescriptor* descriptor) {
+ if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
+ if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
if (repeated_composite_container::Release(
reinterpret_cast<RepeatedCompositeContainer*>(
extension)) < 0) {
@@ -105,10 +78,9 @@ int ReleaseExtension(ExtensionDict* self,
return -1;
}
}
- } else if (descriptor->cpp_type() ==
- google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) {
+ } else if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
if (cmessage::ReleaseSubMessage(
- GetMessage(self), descriptor,
+ self->parent, descriptor,
reinterpret_cast<CMessage*>(extension)) < 0) {
return -1;
}
@@ -118,19 +90,17 @@ int ReleaseExtension(ExtensionDict* self,
}
PyObject* subscript(ExtensionDict* self, PyObject* key) {
- CFieldDescriptor* cdescriptor = InternalGetCDescriptorFromExtension(
- key);
- if (cdescriptor == NULL) {
+ const FieldDescriptor* descriptor = cmessage::GetExtensionDescriptor(key);
+ if (descriptor == NULL) {
return NULL;
}
- ScopedPyObjectPtr py_cdescriptor(reinterpret_cast<PyObject*>(cdescriptor));
- const google::protobuf::FieldDescriptor* descriptor = cdescriptor->descriptor;
- if (descriptor == NULL) {
+ if (!CheckFieldBelongsToMessage(descriptor, self->message)) {
return NULL;
}
+
if (descriptor->label() != FieldDescriptor::LABEL_REPEATED &&
descriptor->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) {
- return cmessage::InternalGetScalar(self->parent, descriptor);
+ return cmessage::InternalGetScalar(self->message, descriptor);
}
PyObject* value = PyDict_GetItem(self->values, key);
@@ -139,10 +109,18 @@ PyObject* subscript(ExtensionDict* self, PyObject* key) {
return value;
}
+ if (self->parent == NULL) {
+ // We are in "detached" state. Don't allow further modifications.
+ // TODO(amauryfa): Support adding non-scalars to a detached extension dict.
+ // This probably requires to store the type of the main message.
+ PyErr_SetObject(PyExc_KeyError, key);
+ return NULL;
+ }
+
if (descriptor->label() != FieldDescriptor::LABEL_REPEATED &&
descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
PyObject* sub_message = cmessage::InternalGetSubMessage(
- self->parent, cdescriptor);
+ self->parent, descriptor);
if (sub_message == NULL) {
return NULL;
}
@@ -152,33 +130,22 @@ PyObject* subscript(ExtensionDict* self, PyObject* key) {
if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
- // COPIED
- PyObject* py_container = PyObject_CallObject(
- reinterpret_cast<PyObject*>(&RepeatedCompositeContainer_Type),
- NULL);
+ CMessageClass* message_class = cdescriptor_pool::GetMessageClass(
+ cmessage::GetDescriptorPoolForMessage(self->parent),
+ descriptor->message_type());
+ if (message_class == NULL) {
+ return NULL;
+ }
+ PyObject* py_container = repeated_composite_container::NewContainer(
+ self->parent, descriptor, message_class);
if (py_container == NULL) {
return NULL;
}
- RepeatedCompositeContainer* container =
- reinterpret_cast<RepeatedCompositeContainer*>(py_container);
- PyObject* field = cdescriptor->descriptor_field;
- PyObject* message_type = PyObject_GetAttrString(field, "message_type");
- PyObject* concrete_class = PyObject_GetAttrString(message_type,
- "_concrete_class");
- container->owner = self->owner;
- container->parent = self->parent;
- container->message = self->parent->message;
- container->parent_field = cdescriptor;
- container->subclass_init = concrete_class;
- Py_DECREF(message_type);
PyDict_SetItem(self->values, key, py_container);
return py_container;
} else {
- // COPIED
- ScopedPyObjectPtr init_args(PyTuple_Pack(2, self->parent, cdescriptor));
- PyObject* py_container = PyObject_CallObject(
- reinterpret_cast<PyObject*>(&RepeatedScalarContainer_Type),
- init_args);
+ PyObject* py_container = repeated_scalar_container::NewContainer(
+ self->parent, descriptor);
if (py_container == NULL) {
return NULL;
}
@@ -191,22 +158,25 @@ PyObject* subscript(ExtensionDict* self, PyObject* key) {
}
int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value) {
- CFieldDescriptor* cdescriptor = InternalGetCDescriptorFromExtension(
- key);
- if (cdescriptor == NULL) {
+ const FieldDescriptor* descriptor = cmessage::GetExtensionDescriptor(key);
+ if (descriptor == NULL) {
+ return -1;
+ }
+ if (!CheckFieldBelongsToMessage(descriptor, self->message)) {
return -1;
}
- ScopedPyObjectPtr py_cdescriptor(reinterpret_cast<PyObject*>(cdescriptor));
- const google::protobuf::FieldDescriptor* descriptor = cdescriptor->descriptor;
+
if (descriptor->label() != FieldDescriptor::LABEL_OPTIONAL ||
descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
PyErr_SetString(PyExc_TypeError, "Extension is repeated and/or composite "
"type");
return -1;
}
- cmessage::AssureWritable(self->parent);
- if (cmessage::InternalSetScalar(self->parent, descriptor, value) < 0) {
- return -1;
+ if (self->parent) {
+ cmessage::AssureWritable(self->parent);
+ if (cmessage::InternalSetScalar(self->parent, descriptor, value) < 0) {
+ return -1;
+ }
}
// TODO(tibell): We shouldn't write scalars to the cache.
PyDict_SetItem(self->values, key, value);
@@ -214,22 +184,23 @@ int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value) {
}
PyObject* ClearExtension(ExtensionDict* self, PyObject* extension) {
- CFieldDescriptor* cdescriptor = InternalGetCDescriptorFromExtension(
- extension);
- if (cdescriptor == NULL) {
+ const FieldDescriptor* descriptor =
+ cmessage::GetExtensionDescriptor(extension);
+ if (descriptor == NULL) {
return NULL;
}
- ScopedPyObjectPtr py_cdescriptor(reinterpret_cast<PyObject*>(cdescriptor));
PyObject* value = PyDict_GetItem(self->values, extension);
- if (value != NULL) {
- if (ReleaseExtension(self, value, cdescriptor->descriptor) < 0) {
+ if (self->parent) {
+ if (value != NULL) {
+ if (ReleaseExtension(self, value, descriptor) < 0) {
+ return NULL;
+ }
+ }
+ if (ScopedPyObjectPtr(cmessage::ClearFieldByDescriptor(
+ self->parent, descriptor)) == NULL) {
return NULL;
}
}
- if (cmessage::ClearFieldByDescriptor(self->parent,
- cdescriptor->descriptor) == NULL) {
- return NULL;
- }
if (PyDict_DelItem(self->values, extension) < 0) {
PyErr_Clear();
}
@@ -237,15 +208,20 @@ PyObject* ClearExtension(ExtensionDict* self, PyObject* extension) {
}
PyObject* HasExtension(ExtensionDict* self, PyObject* extension) {
- CFieldDescriptor* cdescriptor = InternalGetCDescriptorFromExtension(
- extension);
- if (cdescriptor == NULL) {
+ const FieldDescriptor* descriptor =
+ cmessage::GetExtensionDescriptor(extension);
+ if (descriptor == NULL) {
return NULL;
}
- ScopedPyObjectPtr py_cdescriptor(reinterpret_cast<PyObject*>(cdescriptor));
- PyObject* result = cmessage::HasFieldByDescriptor(
- self->parent, cdescriptor->descriptor);
- return result;
+ if (self->parent) {
+ return cmessage::HasFieldByDescriptor(self->parent, descriptor);
+ } else {
+ int exists = PyDict_Contains(self->values, extension);
+ if (exists < 0) {
+ return NULL;
+ }
+ return PyBool_FromLong(exists);
+ }
}
PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* name) {
@@ -254,7 +230,7 @@ PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* name) {
if (extensions_by_name == NULL) {
return NULL;
}
- PyObject* result = PyDict_GetItem(extensions_by_name, name);
+ PyObject* result = PyDict_GetItem(extensions_by_name.get(), name);
if (result == NULL) {
Py_RETURN_NONE;
} else {
@@ -263,11 +239,33 @@ PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* name) {
}
}
-int init(ExtensionDict* self, PyObject* args, PyObject* kwargs) {
- self->parent = NULL;
- self->message = NULL;
+PyObject* _FindExtensionByNumber(ExtensionDict* self, PyObject* number) {
+ ScopedPyObjectPtr extensions_by_number(PyObject_GetAttrString(
+ reinterpret_cast<PyObject*>(self->parent), "_extensions_by_number"));
+ if (extensions_by_number == NULL) {
+ return NULL;
+ }
+ PyObject* result = PyDict_GetItem(extensions_by_number.get(), number);
+ if (result == NULL) {
+ Py_RETURN_NONE;
+ } else {
+ Py_INCREF(result);
+ return result;
+ }
+}
+
+ExtensionDict* NewExtensionDict(CMessage *parent) {
+ ExtensionDict* self = reinterpret_cast<ExtensionDict*>(
+ PyType_GenericAlloc(&ExtensionDict_Type, 0));
+ if (self == NULL) {
+ return NULL;
+ }
+
+ self->parent = parent; // Store a borrowed reference.
+ self->message = parent->message;
+ self->owner = parent->owner;
self->values = PyDict_New();
- return 0;
+ return self;
}
void dealloc(ExtensionDict* self) {
@@ -288,6 +286,8 @@ static PyMethodDef Methods[] = {
EDMETHOD(HasExtension, METH_O, "Checks if the object has an extension."),
EDMETHOD(_FindExtensionByName, METH_O,
"Finds an extension by name."),
+ EDMETHOD(_FindExtensionByNumber, METH_O,
+ "Finds an extension by field number."),
{ NULL, NULL }
};
@@ -295,8 +295,7 @@ static PyMethodDef Methods[] = {
PyTypeObject ExtensionDict_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0)
- "google.protobuf.internal."
- "cpp._message.ExtensionDict", // tp_name
+ FULL_MODULE_NAME ".ExtensionDict", // tp_name
sizeof(ExtensionDict), // tp_basicsize
0, // tp_itemsize
(destructor)extension_dict::dealloc, // tp_dealloc
@@ -308,7 +307,7 @@ PyTypeObject ExtensionDict_Type = {
0, // tp_as_number
0, // tp_as_sequence
&extension_dict::MpMethods, // tp_as_mapping
- 0, // tp_hash
+ PyObject_HashNotImplemented, // tp_hash
0, // tp_call
0, // tp_str
0, // tp_getattro
@@ -330,7 +329,7 @@ PyTypeObject ExtensionDict_Type = {
0, // tp_descr_get
0, // tp_descr_set
0, // tp_dictoffset
- (initproc)extension_dict::init, // tp_init
+ 0, // tp_init
};
} // namespace python
diff --git a/python/google/protobuf/pyext/extension_dict.h b/python/google/protobuf/pyext/extension_dict.h
index 13c874a49..2456eda1e 100644
--- a/python/google/protobuf/pyext/extension_dict.h
+++ b/python/google/protobuf/pyext/extension_dict.h
@@ -41,25 +41,41 @@
#include <google/protobuf/stubs/shared_ptr.h>
#endif
-
namespace google {
namespace protobuf {
class Message;
class FieldDescriptor;
+#ifdef _SHARED_PTR_H
+using std::shared_ptr;
+#else
using internal::shared_ptr;
+#endif
namespace python {
struct CMessage;
-struct CFieldDescriptor;
typedef struct ExtensionDict {
PyObject_HEAD;
+
+ // This is the top-level C++ Message object that owns the whole
+ // proto tree. Every Python container class holds a
+ // reference to it in order to keep it alive as long as there's a
+ // Python object that references any part of the tree.
shared_ptr<Message> owner;
+
+ // Weak reference to parent message. Used to make sure
+ // the parent is writable when an extension field is modified.
CMessage* parent;
+
+ // Pointer to the C++ Message that this ExtensionDict extends.
+ // Not owned by us.
Message* message;
+
+ // A dict of child messages, indexed by Extension descriptors.
+ // Similar to CMessage::composite_fields.
PyObject* values;
} ExtensionDict;
@@ -67,11 +83,8 @@ extern PyTypeObject ExtensionDict_Type;
namespace extension_dict {
-// Gets the _cdescriptor reference to a CFieldDescriptor object given a
-// python descriptor object.
-//
-// Returns a new reference.
-CFieldDescriptor* InternalGetCDescriptorFromExtension(PyObject* extension);
+// Builds an Extensions dict for a specific message.
+ExtensionDict* NewExtensionDict(CMessage *parent);
// Gets the number of extension values in this ExtensionDict as a python object.
//
@@ -84,7 +97,7 @@ PyObject* len(ExtensionDict* self);
// Returns 0 on success, -1 on failure.
int ReleaseExtension(ExtensionDict* self,
PyObject* extension,
- const google::protobuf::FieldDescriptor* descriptor);
+ const FieldDescriptor* descriptor);
// Gets an extension from the dict for the given extension descriptor.
//
@@ -104,17 +117,18 @@ int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value);
PyObject* ClearExtension(ExtensionDict* self,
PyObject* extension);
-// Checks if the dict has an extension.
-//
-// Returns a new python boolean reference.
-PyObject* HasExtension(ExtensionDict* self, PyObject* extension);
-
// Gets an extension from the dict given the extension name as opposed to
// descriptor.
//
// Returns a new reference.
PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* name);
+// Gets an extension from the dict given the extension field number as
+// opposed to descriptor.
+//
+// Returns a new reference.
+PyObject* _FindExtensionByNumber(ExtensionDict* self, PyObject* number);
+
} // namespace extension_dict
} // namespace python
} // namespace protobuf
diff --git a/python/google/protobuf/pyext/map_container.cc b/python/google/protobuf/pyext/map_container.cc
new file mode 100644
index 000000000..e022406d1
--- /dev/null
+++ b/python/google/protobuf/pyext/map_container.cc
@@ -0,0 +1,970 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// https://developers.google.com/protocol-buffers/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+// Author: haberman@google.com (Josh Haberman)
+
+#include <google/protobuf/pyext/map_container.h>
+
+#include <memory>
+#ifndef _SHARED_PTR_H
+#include <google/protobuf/stubs/shared_ptr.h>
+#endif
+
+#include <google/protobuf/stubs/logging.h>
+#include <google/protobuf/stubs/common.h>
+#include <google/protobuf/stubs/scoped_ptr.h>
+#include <google/protobuf/map_field.h>
+#include <google/protobuf/map.h>
+#include <google/protobuf/message.h>
+#include <google/protobuf/pyext/message.h>
+#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
+
+#if PY_MAJOR_VERSION >= 3
+ #define PyInt_FromLong PyLong_FromLong
+ #define PyInt_FromSize_t PyLong_FromSize_t
+#endif
+
+namespace google {
+namespace protobuf {
+namespace python {
+
+// Functions that need access to map reflection functionality.
+// They need to be contained in this class because it is friended.
+class MapReflectionFriend {
+ public:
+ // Methods that are in common between the map types.
+ static PyObject* Contains(PyObject* _self, PyObject* key);
+ static Py_ssize_t Length(PyObject* _self);
+ static PyObject* GetIterator(PyObject *_self);
+ static PyObject* IterNext(PyObject* _self);
+
+ // Methods that differ between the map types.
+ static PyObject* ScalarMapGetItem(PyObject* _self, PyObject* key);
+ static PyObject* MessageMapGetItem(PyObject* _self, PyObject* key);
+ static int ScalarMapSetItem(PyObject* _self, PyObject* key, PyObject* v);
+ static int MessageMapSetItem(PyObject* _self, PyObject* key, PyObject* v);
+};
+
+struct MapIterator {
+ PyObject_HEAD;
+
+ google::protobuf::scoped_ptr< ::google::protobuf::MapIterator> iter;
+
+ // A pointer back to the container, so we can notice changes to the version.
+ // We own a ref on this.
+ MapContainer* container;
+
+ // We need to keep a ref on the Message* too, because
+ // MapIterator::~MapIterator() accesses it. Normally this would be ok because
+ // the ref on container (above) would guarantee outlive semantics. However in
+ // the case of ClearField(), InitializeAndCopyToParentContainer() resets the
+ // message pointer (and the owner) to a different message, a copy of the
+ // original. But our iterator still points to the original, which could now
+ // get deleted before us.
+ //
+ // To prevent this, we ensure that the Message will always stay alive as long
+ // as this iterator does. This is solely for the benefit of the MapIterator
+ // destructor -- we should never actually access the iterator in this state
+ // except to delete it.
+ shared_ptr<Message> owner;
+
+ // The version of the map when we took the iterator to it.
+ //
+ // We store this so that if the map is modified during iteration we can throw
+ // an error.
+ uint64 version;
+
+ // True if the container is empty. We signal this separately to avoid calling
+ // any of the iteration methods, which are non-const.
+ bool empty;
+};
+
+Message* MapContainer::GetMutableMessage() {
+ cmessage::AssureWritable(parent);
+ return const_cast<Message*>(message);
+}
+
+// Consumes a reference on the Python string object.
+static bool PyStringToSTL(PyObject* py_string, string* stl_string) {
+ char *value;
+ Py_ssize_t value_len;
+
+ if (!py_string) {
+ return false;
+ }
+ if (PyBytes_AsStringAndSize(py_string, &value, &value_len) < 0) {
+ Py_DECREF(py_string);
+ return false;
+ } else {
+ stl_string->assign(value, value_len);
+ Py_DECREF(py_string);
+ return true;
+ }
+}
+
+static bool PythonToMapKey(PyObject* obj,
+ const FieldDescriptor* field_descriptor,
+ MapKey* key) {
+ switch (field_descriptor->cpp_type()) {
+ case FieldDescriptor::CPPTYPE_INT32: {
+ GOOGLE_CHECK_GET_INT32(obj, value, false);
+ key->SetInt32Value(value);
+ break;
+ }
+ case FieldDescriptor::CPPTYPE_INT64: {
+ GOOGLE_CHECK_GET_INT64(obj, value, false);
+ key->SetInt64Value(value);
+ break;
+ }
+ case FieldDescriptor::CPPTYPE_UINT32: {
+ GOOGLE_CHECK_GET_UINT32(obj, value, false);
+ key->SetUInt32Value(value);
+ break;
+ }
+ case FieldDescriptor::CPPTYPE_UINT64: {
+ GOOGLE_CHECK_GET_UINT64(obj, value, false);
+ key->SetUInt64Value(value);
+ break;
+ }
+ case FieldDescriptor::CPPTYPE_BOOL: {
+ GOOGLE_CHECK_GET_BOOL(obj, value, false);
+ key->SetBoolValue(value);
+ break;
+ }
+ case FieldDescriptor::CPPTYPE_STRING: {
+ string str;
+ if (!PyStringToSTL(CheckString(obj, field_descriptor), &str)) {
+ return false;
+ }
+ key->SetStringValue(str);
+ break;
+ }
+ default:
+ PyErr_Format(
+ PyExc_SystemError, "Type %d cannot be a map key",
+ field_descriptor->cpp_type());
+ return false;
+ }
+ return true;
+}
+
+static PyObject* MapKeyToPython(const FieldDescriptor* field_descriptor,
+ const MapKey& key) {
+ switch (field_descriptor->cpp_type()) {
+ case FieldDescriptor::CPPTYPE_INT32:
+ return PyInt_FromLong(key.GetInt32Value());
+ case FieldDescriptor::CPPTYPE_INT64:
+ return PyLong_FromLongLong(key.GetInt64Value());
+ case FieldDescriptor::CPPTYPE_UINT32:
+ return PyInt_FromSize_t(key.GetUInt32Value());
+ case FieldDescriptor::CPPTYPE_UINT64:
+ return PyLong_FromUnsignedLongLong(key.GetUInt64Value());
+ case FieldDescriptor::CPPTYPE_BOOL:
+ return PyBool_FromLong(key.GetBoolValue());
+ case FieldDescriptor::CPPTYPE_STRING:
+ return ToStringObject(field_descriptor, key.GetStringValue());
+ default:
+ PyErr_Format(
+ PyExc_SystemError, "Couldn't convert type %d to value",
+ field_descriptor->cpp_type());
+ return NULL;
+ }
+}
+
+// This is only used for ScalarMap, so we don't need to handle the
+// CPPTYPE_MESSAGE case.
+PyObject* MapValueRefToPython(const FieldDescriptor* field_descriptor,
+ MapValueRef* value) {
+ switch (field_descriptor->cpp_type()) {
+ case FieldDescriptor::CPPTYPE_INT32:
+ return PyInt_FromLong(value->GetInt32Value());
+ case FieldDescriptor::CPPTYPE_INT64:
+ return PyLong_FromLongLong(value->GetInt64Value());
+ case FieldDescriptor::CPPTYPE_UINT32:
+ return PyInt_FromSize_t(value->GetUInt32Value());
+ case FieldDescriptor::CPPTYPE_UINT64:
+ return PyLong_FromUnsignedLongLong(value->GetUInt64Value());
+ case FieldDescriptor::CPPTYPE_FLOAT:
+ return PyFloat_FromDouble(value->GetFloatValue());
+ case FieldDescriptor::CPPTYPE_DOUBLE:
+ return PyFloat_FromDouble(value->GetDoubleValue());
+ case FieldDescriptor::CPPTYPE_BOOL:
+ return PyBool_FromLong(value->GetBoolValue());
+ case FieldDescriptor::CPPTYPE_STRING:
+ return ToStringObject(field_descriptor, value->GetStringValue());
+ case FieldDescriptor::CPPTYPE_ENUM:
+ return PyInt_FromLong(value->GetEnumValue());
+ default:
+ PyErr_Format(
+ PyExc_SystemError, "Couldn't convert type %d to value",
+ field_descriptor->cpp_type());
+ return NULL;
+ }
+}
+
+// This is only used for ScalarMap, so we don't need to handle the
+// CPPTYPE_MESSAGE case.
+static bool PythonToMapValueRef(PyObject* obj,
+ const FieldDescriptor* field_descriptor,
+ bool allow_unknown_enum_values,
+ MapValueRef* value_ref) {
+ switch (field_descriptor->cpp_type()) {
+ case FieldDescriptor::CPPTYPE_INT32: {
+ GOOGLE_CHECK_GET_INT32(obj, value, false);
+ value_ref->SetInt32Value(value);
+ return true;
+ }
+ case FieldDescriptor::CPPTYPE_INT64: {
+ GOOGLE_CHECK_GET_INT64(obj, value, false);
+ value_ref->SetInt64Value(value);
+ return true;
+ }
+ case FieldDescriptor::CPPTYPE_UINT32: {
+ GOOGLE_CHECK_GET_UINT32(obj, value, false);
+ value_ref->SetUInt32Value(value);
+ return true;
+ }
+ case FieldDescriptor::CPPTYPE_UINT64: {
+ GOOGLE_CHECK_GET_UINT64(obj, value, false);
+ value_ref->SetUInt64Value(value);
+ return true;
+ }
+ case FieldDescriptor::CPPTYPE_FLOAT: {
+ GOOGLE_CHECK_GET_FLOAT(obj, value, false);
+ value_ref->SetFloatValue(value);
+ return true;
+ }
+ case FieldDescriptor::CPPTYPE_DOUBLE: {
+ GOOGLE_CHECK_GET_DOUBLE(obj, value, false);
+ value_ref->SetDoubleValue(value);
+ return true;
+ }
+ case FieldDescriptor::CPPTYPE_BOOL: {
+ GOOGLE_CHECK_GET_BOOL(obj, value, false);
+ value_ref->SetBoolValue(value);
+ return true;;
+ }
+ case FieldDescriptor::CPPTYPE_STRING: {
+ string str;
+ if (!PyStringToSTL(CheckString(obj, field_descriptor), &str)) {
+ return false;
+ }
+ value_ref->SetStringValue(str);
+ return true;
+ }
+ case FieldDescriptor::CPPTYPE_ENUM: {
+ GOOGLE_CHECK_GET_INT32(obj, value, false);
+ if (allow_unknown_enum_values) {
+ value_ref->SetEnumValue(value);
+ return true;
+ } else {
+ const EnumDescriptor* enum_descriptor = field_descriptor->enum_type();
+ const EnumValueDescriptor* enum_value =
+ enum_descriptor->FindValueByNumber(value);
+ if (enum_value != NULL) {
+ value_ref->SetEnumValue(value);
+ return true;
+ } else {
+ PyErr_Format(PyExc_ValueError, "Unknown enum value: %d", value);
+ return false;
+ }
+ }
+ break;
+ }
+ default:
+ PyErr_Format(
+ PyExc_SystemError, "Setting value to a field of unknown type %d",
+ field_descriptor->cpp_type());
+ return false;
+ }
+}
+
+// Map methods common to ScalarMap and MessageMap //////////////////////////////
+
+static MapContainer* GetMap(PyObject* obj) {
+ return reinterpret_cast<MapContainer*>(obj);
+}
+
+Py_ssize_t MapReflectionFriend::Length(PyObject* _self) {
+ MapContainer* self = GetMap(_self);
+ const google::protobuf::Message* message = self->message;
+ return message->GetReflection()->MapSize(*message,
+ self->parent_field_descriptor);
+}
+
+PyObject* Clear(PyObject* _self) {
+ MapContainer* self = GetMap(_self);
+ Message* message = self->GetMutableMessage();
+ const Reflection* reflection = message->GetReflection();
+
+ reflection->ClearField(message, self->parent_field_descriptor);
+
+ Py_RETURN_NONE;
+}
+
+PyObject* MapReflectionFriend::Contains(PyObject* _self, PyObject* key) {
+ MapContainer* self = GetMap(_self);
+
+ const Message* message = self->message;
+ const Reflection* reflection = message->GetReflection();
+ MapKey map_key;
+
+ if (!PythonToMapKey(key, self->key_field_descriptor, &map_key)) {
+ return NULL;
+ }
+
+ if (reflection->ContainsMapKey(*message, self->parent_field_descriptor,
+ map_key)) {
+ Py_RETURN_TRUE;
+ } else {
+ Py_RETURN_FALSE;
+ }
+}
+
+// Initializes the underlying Message object of "to" so it becomes a new parent
+// repeated scalar, and copies all the values from "from" to it. A child scalar
+// container can be released by passing it as both from and to (e.g. making it
+// the recipient of the new parent message and copying the values from itself).
+static int InitializeAndCopyToParentContainer(MapContainer* from,
+ MapContainer* to) {
+ // For now we require from == to, re-evaluate if we want to support deep copy
+ // as in repeated_scalar_container.cc.
+ GOOGLE_DCHECK(from == to);
+ Message* new_message = from->message->New();
+
+ if (MapReflectionFriend::Length(reinterpret_cast<PyObject*>(from)) > 0) {
+ // A somewhat roundabout way of copying just one field from old_message to
+ // new_message. This is the best we can do with what Reflection gives us.
+ Message* mutable_old = from->GetMutableMessage();
+ vector<const FieldDescriptor*> fields;
+ fields.push_back(from->parent_field_descriptor);
+
+ // Move the map field into the new message.
+ mutable_old->GetReflection()->SwapFields(mutable_old, new_message, fields);
+
+ // If/when we support from != to, this will be required also to copy the
+ // map field back into the existing message:
+ // mutable_old->MergeFrom(*new_message);
+ }
+
+ // If from == to this could delete old_message.
+ to->owner.reset(new_message);
+
+ to->parent = NULL;
+ to->parent_field_descriptor = from->parent_field_descriptor;
+ to->message = new_message;
+
+ // Invalidate iterators, since they point to the old copy of the field.
+ to->version++;
+
+ return 0;
+}
+
+int MapContainer::Release() {
+ return InitializeAndCopyToParentContainer(this, this);
+}
+
+
+// ScalarMap ///////////////////////////////////////////////////////////////////
+
+PyObject *NewScalarMapContainer(
+ CMessage* parent, const google::protobuf::FieldDescriptor* parent_field_descriptor) {
+ if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) {
+ return NULL;
+ }
+
+#if PY_MAJOR_VERSION >= 3
+ ScopedPyObjectPtr obj(PyType_GenericAlloc(
+ reinterpret_cast<PyTypeObject *>(ScalarMapContainer_Type), 0));
+#else
+ ScopedPyObjectPtr obj(PyType_GenericAlloc(&ScalarMapContainer_Type, 0));
+#endif
+ if (obj.get() == NULL) {
+ return PyErr_Format(PyExc_RuntimeError,
+ "Could not allocate new container.");
+ }
+
+ MapContainer* self = GetMap(obj.get());
+
+ self->message = parent->message;
+ self->parent = parent;
+ self->parent_field_descriptor = parent_field_descriptor;
+ self->owner = parent->owner;
+ self->version = 0;
+
+ self->key_field_descriptor =
+ parent_field_descriptor->message_type()->FindFieldByName("key");
+ self->value_field_descriptor =
+ parent_field_descriptor->message_type()->FindFieldByName("value");
+
+ if (self->key_field_descriptor == NULL ||
+ self->value_field_descriptor == NULL) {
+ return PyErr_Format(PyExc_KeyError,
+ "Map entry descriptor did not have key/value fields");
+ }
+
+ return obj.release();
+}
+
+PyObject* MapReflectionFriend::ScalarMapGetItem(PyObject* _self,
+ PyObject* key) {
+ MapContainer* self = GetMap(_self);
+
+ Message* message = self->GetMutableMessage();
+ const Reflection* reflection = message->GetReflection();
+ MapKey map_key;
+ MapValueRef value;
+
+ if (!PythonToMapKey(key, self->key_field_descriptor, &map_key)) {
+ return NULL;
+ }
+
+ if (reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor,
+ map_key, &value)) {
+ self->version++;
+ }
+
+ return MapValueRefToPython(self->value_field_descriptor, &value);
+}
+
+int MapReflectionFriend::ScalarMapSetItem(PyObject* _self, PyObject* key,
+ PyObject* v) {
+ MapContainer* self = GetMap(_self);
+
+ Message* message = self->GetMutableMessage();
+ const Reflection* reflection = message->GetReflection();
+ MapKey map_key;
+ MapValueRef value;
+
+ if (!PythonToMapKey(key, self->key_field_descriptor, &map_key)) {
+ return -1;
+ }
+
+ self->version++;
+
+ if (v) {
+ // Set item to v.
+ reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor,
+ map_key, &value);
+
+ return PythonToMapValueRef(v, self->value_field_descriptor,
+ reflection->SupportsUnknownEnumValues(), &value)
+ ? 0
+ : -1;
+ } else {
+ // Delete key from map.
+ if (reflection->DeleteMapValue(message, self->parent_field_descriptor,
+ map_key)) {
+ return 0;
+ } else {
+ PyErr_Format(PyExc_KeyError, "Key not present in map");
+ return -1;
+ }
+ }
+}
+
+static PyObject* ScalarMapGet(PyObject* self, PyObject* args) {
+ PyObject* key;
+ PyObject* default_value = NULL;
+ if (PyArg_ParseTuple(args, "O|O", &key, &default_value) < 0) {
+ return NULL;
+ }
+
+ ScopedPyObjectPtr is_present(MapReflectionFriend::Contains(self, key));
+ if (is_present.get() == NULL) {
+ return NULL;
+ }
+
+ if (PyObject_IsTrue(is_present.get())) {
+ return MapReflectionFriend::ScalarMapGetItem(self, key);
+ } else {
+ if (default_value != NULL) {
+ Py_INCREF(default_value);
+ return default_value;
+ } else {
+ Py_RETURN_NONE;
+ }
+ }
+}
+
+static void ScalarMapDealloc(PyObject* _self) {
+ MapContainer* self = GetMap(_self);
+ self->owner.reset();
+ Py_TYPE(_self)->tp_free(_self);
+}
+
+static PyMethodDef ScalarMapMethods[] = {
+ { "__contains__", MapReflectionFriend::Contains, METH_O,
+ "Tests whether a key is a member of the map." },
+ { "clear", (PyCFunction)Clear, METH_NOARGS,
+ "Removes all elements from the map." },
+ { "get", ScalarMapGet, METH_VARARGS,
+ "Gets the value for the given key if present, or otherwise a default" },
+ /*
+ { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS,
+ "Makes a deep copy of the class." },
+ { "__reduce__", (PyCFunction)Reduce, METH_NOARGS,
+ "Outputs picklable representation of the repeated field." },
+ */
+ {NULL, NULL},
+};
+
+#if PY_MAJOR_VERSION >= 3
+ static PyType_Slot ScalarMapContainer_Type_slots[] = {
+ {Py_tp_dealloc, (void *)ScalarMapDealloc},
+ {Py_mp_length, (void *)MapReflectionFriend::Length},
+ {Py_mp_subscript, (void *)MapReflectionFriend::ScalarMapGetItem},
+ {Py_mp_ass_subscript, (void *)MapReflectionFriend::ScalarMapSetItem},
+ {Py_tp_methods, (void *)ScalarMapMethods},
+ {Py_tp_iter, (void *)MapReflectionFriend::GetIterator},
+ {0, 0},
+ };
+
+ PyType_Spec ScalarMapContainer_Type_spec = {
+ FULL_MODULE_NAME ".ScalarMapContainer",
+ sizeof(MapContainer),
+ 0,
+ Py_TPFLAGS_DEFAULT,
+ ScalarMapContainer_Type_slots
+ };
+ PyObject *ScalarMapContainer_Type;
+#else
+ static PyMappingMethods ScalarMapMappingMethods = {
+ MapReflectionFriend::Length, // mp_length
+ MapReflectionFriend::ScalarMapGetItem, // mp_subscript
+ MapReflectionFriend::ScalarMapSetItem, // mp_ass_subscript
+ };
+
+ PyTypeObject ScalarMapContainer_Type = {
+ PyVarObject_HEAD_INIT(&PyType_Type, 0)
+ FULL_MODULE_NAME ".ScalarMapContainer", // tp_name
+ sizeof(MapContainer), // tp_basicsize
+ 0, // tp_itemsize
+ ScalarMapDealloc, // tp_dealloc
+ 0, // tp_print
+ 0, // tp_getattr
+ 0, // tp_setattr
+ 0, // tp_compare
+ 0, // tp_repr
+ 0, // tp_as_number
+ 0, // tp_as_sequence
+ &ScalarMapMappingMethods, // tp_as_mapping
+ 0, // tp_hash
+ 0, // tp_call
+ 0, // tp_str
+ 0, // tp_getattro
+ 0, // tp_setattro
+ 0, // tp_as_buffer
+ Py_TPFLAGS_DEFAULT, // tp_flags
+ "A scalar map container", // tp_doc
+ 0, // tp_traverse
+ 0, // tp_clear
+ 0, // tp_richcompare
+ 0, // tp_weaklistoffset
+ MapReflectionFriend::GetIterator, // tp_iter
+ 0, // tp_iternext
+ ScalarMapMethods, // tp_methods
+ 0, // tp_members
+ 0, // tp_getset
+ 0, // tp_base
+ 0, // tp_dict
+ 0, // tp_descr_get
+ 0, // tp_descr_set
+ 0, // tp_dictoffset
+ 0, // tp_init
+ };
+#endif
+
+
+// MessageMap //////////////////////////////////////////////////////////////////
+
+static MessageMapContainer* GetMessageMap(PyObject* obj) {
+ return reinterpret_cast<MessageMapContainer*>(obj);
+}
+
+static PyObject* GetCMessage(MessageMapContainer* self, Message* message) {
+ // Get or create the CMessage object corresponding to this message.
+ ScopedPyObjectPtr key(PyLong_FromVoidPtr(message));
+ PyObject* ret = PyDict_GetItem(self->message_dict, key.get());
+
+ if (ret == NULL) {
+ CMessage* cmsg = cmessage::NewEmptyMessage(self->message_class);
+ ret = reinterpret_cast<PyObject*>(cmsg);
+
+ if (cmsg == NULL) {
+ return NULL;
+ }
+ cmsg->owner = self->owner;
+ cmsg->message = message;
+ cmsg->parent = self->parent;
+
+ if (PyDict_SetItem(self->message_dict, key.get(), ret) < 0) {
+ Py_DECREF(ret);
+ return NULL;
+ }
+ } else {
+ Py_INCREF(ret);
+ }
+
+ return ret;
+}
+
+PyObject* NewMessageMapContainer(
+ CMessage* parent, const google::protobuf::FieldDescriptor* parent_field_descriptor,
+ CMessageClass* message_class) {
+ if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) {
+ return NULL;
+ }
+
+#if PY_MAJOR_VERSION >= 3
+ PyObject* obj = PyType_GenericAlloc(
+ reinterpret_cast<PyTypeObject *>(MessageMapContainer_Type), 0);
+#else
+ PyObject* obj = PyType_GenericAlloc(&MessageMapContainer_Type, 0);
+#endif
+ if (obj == NULL) {
+ return PyErr_Format(PyExc_RuntimeError,
+ "Could not allocate new container.");
+ }
+
+ MessageMapContainer* self = GetMessageMap(obj);
+
+ self->message = parent->message;
+ self->parent = parent;
+ self->parent_field_descriptor = parent_field_descriptor;
+ self->owner = parent->owner;
+ self->version = 0;
+
+ self->key_field_descriptor =
+ parent_field_descriptor->message_type()->FindFieldByName("key");
+ self->value_field_descriptor =
+ parent_field_descriptor->message_type()->FindFieldByName("value");
+
+ self->message_dict = PyDict_New();
+ if (self->message_dict == NULL) {
+ return PyErr_Format(PyExc_RuntimeError,
+ "Could not allocate message dict.");
+ }
+
+ Py_INCREF(message_class);
+ self->message_class = message_class;
+
+ if (self->key_field_descriptor == NULL ||
+ self->value_field_descriptor == NULL) {
+ Py_DECREF(obj);
+ return PyErr_Format(PyExc_KeyError,
+ "Map entry descriptor did not have key/value fields");
+ }
+
+ return obj;
+}
+
+int MapReflectionFriend::MessageMapSetItem(PyObject* _self, PyObject* key,
+ PyObject* v) {
+ if (v) {
+ PyErr_Format(PyExc_ValueError,
+ "Direct assignment of submessage not allowed");
+ return -1;
+ }
+
+ // Now we know that this is a delete, not a set.
+
+ MessageMapContainer* self = GetMessageMap(_self);
+ Message* message = self->GetMutableMessage();
+ const Reflection* reflection = message->GetReflection();
+ MapKey map_key;
+ MapValueRef value;
+
+ self->version++;
+
+ if (!PythonToMapKey(key, self->key_field_descriptor, &map_key)) {
+ return -1;
+ }
+
+ // Delete key from map.
+ if (reflection->DeleteMapValue(message, self->parent_field_descriptor,
+ map_key)) {
+ return 0;
+ } else {
+ PyErr_Format(PyExc_KeyError, "Key not present in map");
+ return -1;
+ }
+}
+
+PyObject* MapReflectionFriend::MessageMapGetItem(PyObject* _self,
+ PyObject* key) {
+ MessageMapContainer* self = GetMessageMap(_self);
+
+ Message* message = self->GetMutableMessage();
+ const Reflection* reflection = message->GetReflection();
+ MapKey map_key;
+ MapValueRef value;
+
+ if (!PythonToMapKey(key, self->key_field_descriptor, &map_key)) {
+ return NULL;
+ }
+
+ if (reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor,
+ map_key, &value)) {
+ self->version++;
+ }
+
+ return GetCMessage(self, value.MutableMessageValue());
+}
+
+PyObject* MessageMapGet(PyObject* self, PyObject* args) {
+ PyObject* key;
+ PyObject* default_value = NULL;
+ if (PyArg_ParseTuple(args, "O|O", &key, &default_value) < 0) {
+ return NULL;
+ }
+
+ ScopedPyObjectPtr is_present(MapReflectionFriend::Contains(self, key));
+ if (is_present.get() == NULL) {
+ return NULL;
+ }
+
+ if (PyObject_IsTrue(is_present.get())) {
+ return MapReflectionFriend::MessageMapGetItem(self, key);
+ } else {
+ if (default_value != NULL) {
+ Py_INCREF(default_value);
+ return default_value;
+ } else {
+ Py_RETURN_NONE;
+ }
+ }
+}
+
+static void MessageMapDealloc(PyObject* _self) {
+ MessageMapContainer* self = GetMessageMap(_self);
+ self->owner.reset();
+ Py_DECREF(self->message_dict);
+ Py_DECREF(self->message_class);
+ Py_TYPE(_self)->tp_free(_self);
+}
+
+static PyMethodDef MessageMapMethods[] = {
+ { "__contains__", (PyCFunction)MapReflectionFriend::Contains, METH_O,
+ "Tests whether the map contains this element."},
+ { "clear", (PyCFunction)Clear, METH_NOARGS,
+ "Removes all elements from the map."},
+ { "get", MessageMapGet, METH_VARARGS,
+ "Gets the value for the given key if present, or otherwise a default" },
+ { "get_or_create", MapReflectionFriend::MessageMapGetItem, METH_O,
+ "Alias for getitem, useful to make explicit that the map is mutated." },
+ /*
+ { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS,
+ "Makes a deep copy of the class." },
+ { "__reduce__", (PyCFunction)Reduce, METH_NOARGS,
+ "Outputs picklable representation of the repeated field." },
+ */
+ {NULL, NULL},
+};
+
+#if PY_MAJOR_VERSION >= 3
+ static PyType_Slot MessageMapContainer_Type_slots[] = {
+ {Py_tp_dealloc, (void *)MessageMapDealloc},
+ {Py_mp_length, (void *)MapReflectionFriend::Length},
+ {Py_mp_subscript, (void *)MapReflectionFriend::MessageMapGetItem},
+ {Py_mp_ass_subscript, (void *)MapReflectionFriend::MessageMapSetItem},
+ {Py_tp_methods, (void *)MessageMapMethods},
+ {Py_tp_iter, (void *)MapReflectionFriend::GetIterator},
+ {0, 0}
+ };
+
+ PyType_Spec MessageMapContainer_Type_spec = {
+ FULL_MODULE_NAME ".MessageMapContainer",
+ sizeof(MessageMapContainer),
+ 0,
+ Py_TPFLAGS_DEFAULT,
+ MessageMapContainer_Type_slots
+ };
+
+ PyObject *MessageMapContainer_Type;
+#else
+ static PyMappingMethods MessageMapMappingMethods = {
+ MapReflectionFriend::Length, // mp_length
+ MapReflectionFriend::MessageMapGetItem, // mp_subscript
+ MapReflectionFriend::MessageMapSetItem, // mp_ass_subscript
+ };
+
+ PyTypeObject MessageMapContainer_Type = {
+ PyVarObject_HEAD_INIT(&PyType_Type, 0)
+ FULL_MODULE_NAME ".MessageMapContainer", // tp_name
+ sizeof(MessageMapContainer), // tp_basicsize
+ 0, // tp_itemsize
+ MessageMapDealloc, // tp_dealloc
+ 0, // tp_print
+ 0, // tp_getattr
+ 0, // tp_setattr
+ 0, // tp_compare
+ 0, // tp_repr
+ 0, // tp_as_number
+ 0, // tp_as_sequence
+ &MessageMapMappingMethods, // tp_as_mapping
+ 0, // tp_hash
+ 0, // tp_call
+ 0, // tp_str
+ 0, // tp_getattro
+ 0, // tp_setattro
+ 0, // tp_as_buffer
+ Py_TPFLAGS_DEFAULT, // tp_flags
+ "A map container for message", // tp_doc
+ 0, // tp_traverse
+ 0, // tp_clear
+ 0, // tp_richcompare
+ 0, // tp_weaklistoffset
+ MapReflectionFriend::GetIterator, // tp_iter
+ 0, // tp_iternext
+ MessageMapMethods, // tp_methods
+ 0, // tp_members
+ 0, // tp_getset
+ 0, // tp_base
+ 0, // tp_dict
+ 0, // tp_descr_get
+ 0, // tp_descr_set
+ 0, // tp_dictoffset
+ 0, // tp_init
+ };
+#endif
+
+// MapIterator /////////////////////////////////////////////////////////////////
+
+static MapIterator* GetIter(PyObject* obj) {
+ return reinterpret_cast<MapIterator*>(obj);
+}
+
+PyObject* MapReflectionFriend::GetIterator(PyObject *_self) {
+ MapContainer* self = GetMap(_self);
+
+ ScopedPyObjectPtr obj(PyType_GenericAlloc(&MapIterator_Type, 0));
+ if (obj == NULL) {
+ return PyErr_Format(PyExc_KeyError, "Could not allocate iterator");
+ }
+
+ MapIterator* iter = GetIter(obj.get());
+
+ Py_INCREF(self);
+ iter->container = self;
+ iter->version = self->version;
+ iter->owner = self->owner;
+
+ if (MapReflectionFriend::Length(_self) > 0) {
+ Message* message = self->GetMutableMessage();
+ const Reflection* reflection = message->GetReflection();
+
+ iter->iter.reset(new ::google::protobuf::MapIterator(
+ reflection->MapBegin(message, self->parent_field_descriptor)));
+ }
+
+ return obj.release();
+}
+
+PyObject* MapReflectionFriend::IterNext(PyObject* _self) {
+ MapIterator* self = GetIter(_self);
+
+ // This won't catch mutations to the map performed by MergeFrom(); no easy way
+ // to address that.
+ if (self->version != self->container->version) {
+ return PyErr_Format(PyExc_RuntimeError,
+ "Map modified during iteration.");
+ }
+
+ if (self->iter.get() == NULL) {
+ return NULL;
+ }
+
+ Message* message = self->container->GetMutableMessage();
+ const Reflection* reflection = message->GetReflection();
+
+ if (*self->iter ==
+ reflection->MapEnd(message, self->container->parent_field_descriptor)) {
+ return NULL;
+ }
+
+ PyObject* ret = MapKeyToPython(self->container->key_field_descriptor,
+ self->iter->GetKey());
+
+ ++(*self->iter);
+
+ return ret;
+}
+
+static void DeallocMapIterator(PyObject* _self) {
+ MapIterator* self = GetIter(_self);
+ self->iter.reset();
+ self->owner.reset();
+ Py_XDECREF(self->container);
+ Py_TYPE(_self)->tp_free(_self);
+}
+
+PyTypeObject MapIterator_Type = {
+ PyVarObject_HEAD_INIT(&PyType_Type, 0)
+ FULL_MODULE_NAME ".MapIterator", // tp_name
+ sizeof(MapIterator), // tp_basicsize
+ 0, // tp_itemsize
+ DeallocMapIterator, // tp_dealloc
+ 0, // tp_print
+ 0, // tp_getattr
+ 0, // tp_setattr
+ 0, // tp_compare
+ 0, // tp_repr
+ 0, // tp_as_number
+ 0, // tp_as_sequence
+ 0, // tp_as_mapping
+ 0, // tp_hash
+ 0, // tp_call
+ 0, // tp_str
+ 0, // tp_getattro
+ 0, // tp_setattro
+ 0, // tp_as_buffer
+ Py_TPFLAGS_DEFAULT, // tp_flags
+ "A scalar map iterator", // tp_doc
+ 0, // tp_traverse
+ 0, // tp_clear
+ 0, // tp_richcompare
+ 0, // tp_weaklistoffset
+ PyObject_SelfIter, // tp_iter
+ MapReflectionFriend::IterNext, // tp_iternext
+ 0, // tp_methods
+ 0, // tp_members
+ 0, // tp_getset
+ 0, // tp_base
+ 0, // tp_dict
+ 0, // tp_descr_get
+ 0, // tp_descr_set
+ 0, // tp_dictoffset
+ 0, // tp_init
+};
+
+} // namespace python
+} // namespace protobuf
+} // namespace google
diff --git a/python/google/protobuf/pyext/map_container.h b/python/google/protobuf/pyext/map_container.h
new file mode 100644
index 000000000..fbd6713f7
--- /dev/null
+++ b/python/google/protobuf/pyext/map_container.h
@@ -0,0 +1,142 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// https://developers.google.com/protocol-buffers/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_MAP_CONTAINER_H__
+#define GOOGLE_PROTOBUF_PYTHON_CPP_MAP_CONTAINER_H__
+
+#include <Python.h>
+
+#include <memory>
+#ifndef _SHARED_PTR_H
+#include <google/protobuf/stubs/shared_ptr.h>
+#endif
+
+#include <google/protobuf/descriptor.h>
+#include <google/protobuf/message.h>
+
+namespace google {
+namespace protobuf {
+
+class Message;
+
+#ifdef _SHARED_PTR_H
+using std::shared_ptr;
+#else
+using internal::shared_ptr;
+#endif
+
+namespace python {
+
+struct CMessage;
+struct CMessageClass;
+
+// This struct is used directly for ScalarMap, and is the base class of
+// MessageMapContainer, which is used for MessageMap.
+struct MapContainer {
+ PyObject_HEAD;
+
+ // This is the top-level C++ Message object that owns the whole
+ // proto tree. Every Python MapContainer holds a
+ // reference to it in order to keep it alive as long as there's a
+ // Python object that references any part of the tree.
+ shared_ptr<Message> owner;
+
+ // Pointer to the C++ Message that contains this container. The
+ // MapContainer does not own this pointer.
+ const Message* message;
+
+ // Use to get a mutable message when necessary.
+ Message* GetMutableMessage();
+
+ // Weak reference to a parent CMessage object (i.e. may be NULL.)
+ //
+ // Used to make sure all ancestors are also mutable when first
+ // modifying the container.
+ CMessage* parent;
+
+ // Pointer to the parent's descriptor that describes this
+ // field. Used together with the parent's message when making a
+ // default message instance mutable.
+ // The pointer is owned by the global DescriptorPool.
+ const FieldDescriptor* parent_field_descriptor;
+ const FieldDescriptor* key_field_descriptor;
+ const FieldDescriptor* value_field_descriptor;
+
+ // We bump this whenever we perform a mutation, to invalidate existing
+ // iterators.
+ uint64 version;
+
+ // Releases the messages in the container to a new message.
+ //
+ // Returns 0 on success, -1 on failure.
+ int Release();
+
+ // Set the owner field of self and any children of self.
+ void SetOwner(const shared_ptr<Message>& new_owner) {
+ owner = new_owner;
+ }
+};
+
+struct MessageMapContainer : public MapContainer {
+ // The type used to create new child messages.
+ CMessageClass* message_class;
+
+ // A dict mapping Message* -> CMessage.
+ PyObject* message_dict;
+};
+
+#if PY_MAJOR_VERSION >= 3
+ extern PyObject *MessageMapContainer_Type;
+ extern PyType_Spec MessageMapContainer_Type_spec;
+ extern PyObject *ScalarMapContainer_Type;
+ extern PyType_Spec ScalarMapContainer_Type_spec;
+#else
+ extern PyTypeObject MessageMapContainer_Type;
+ extern PyTypeObject ScalarMapContainer_Type;
+#endif
+
+extern PyTypeObject MapIterator_Type; // Both map types use the same iterator.
+
+// Builds a MapContainer object, from a parent message and a
+// field descriptor.
+extern PyObject* NewScalarMapContainer(
+ CMessage* parent, const FieldDescriptor* parent_field_descriptor);
+
+// Builds a MessageMap object, from a parent message and a
+// field descriptor.
+extern PyObject* NewMessageMapContainer(
+ CMessage* parent, const FieldDescriptor* parent_field_descriptor,
+ CMessageClass* message_class);
+
+} // namespace python
+} // namespace protobuf
+
+} // namespace google
+#endif // GOOGLE_PROTOBUF_PYTHON_CPP_MAP_CONTAINER_H__
diff --git a/python/google/protobuf/pyext/message.cc b/python/google/protobuf/pyext/message.cc
index 9fb7083f8..83c151ff6 100644
--- a/python/google/protobuf/pyext/message.cc
+++ b/python/google/protobuf/pyext/message.cc
@@ -33,12 +33,14 @@
#include <google/protobuf/pyext/message.h>
+#include <map>
#include <memory>
#ifndef _SHARED_PTR_H
#include <google/protobuf/stubs/shared_ptr.h>
#endif
#include <string>
#include <vector>
+#include <structmember.h> // A Python header file.
#ifndef PyVarObject_HEAD_INIT
#define PyVarObject_HEAD_INIT(type, size) PyObject_HEAD_INIT(type) size,
@@ -48,16 +50,21 @@
#endif
#include <google/protobuf/descriptor.pb.h>
#include <google/protobuf/stubs/common.h>
+#include <google/protobuf/stubs/logging.h>
#include <google/protobuf/io/coded_stream.h>
+#include <google/protobuf/util/message_differencer.h>
#include <google/protobuf/descriptor.h>
-#include <google/protobuf/dynamic_message.h>
#include <google/protobuf/message.h>
#include <google/protobuf/text_format.h>
+#include <google/protobuf/unknown_field_set.h>
#include <google/protobuf/pyext/descriptor.h>
+#include <google/protobuf/pyext/descriptor_pool.h>
#include <google/protobuf/pyext/extension_dict.h>
#include <google/protobuf/pyext/repeated_composite_container.h>
#include <google/protobuf/pyext/repeated_scalar_container.h>
+#include <google/protobuf/pyext/map_container.h>
#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
+#include <google/protobuf/stubs/strutil.h>
#if PY_MAJOR_VERSION >= 3
#define PyInt_Check PyLong_Check
@@ -71,7 +78,11 @@
#error "Python 3.0 - 3.2 are not supported."
#else
#define PyString_AsString(ob) \
- (PyUnicode_Check(ob)? PyUnicode_AsUTF8(ob): PyBytes_AS_STRING(ob))
+ (PyUnicode_Check(ob)? PyUnicode_AsUTF8(ob): PyBytes_AsString(ob))
+ #define PyString_AsStringAndSize(ob, charpp, sizep) \
+ (PyUnicode_Check(ob)? \
+ ((*(charpp) = PyUnicode_AsUTF8AndSize(ob, (sizep))) == NULL? -1: 0): \
+ PyBytes_AsStringAndSize(ob, (charpp), (sizep)))
#endif
#endif
@@ -79,14 +90,328 @@ namespace google {
namespace protobuf {
namespace python {
+static PyObject* kDESCRIPTOR;
+static PyObject* k_extensions_by_name;
+static PyObject* k_extensions_by_number;
+PyObject* EnumTypeWrapper_class;
+static PyObject* PythonMessage_class;
+static PyObject* kEmptyWeakref;
+static PyObject* WKT_classes = NULL;
+
+namespace message_meta {
+
+static int InsertEmptyWeakref(PyTypeObject* base);
+
+// Add the number of a field descriptor to the containing message class.
+// Equivalent to:
+// _cls.<field>_FIELD_NUMBER = <number>
+static bool AddFieldNumberToClass(
+ PyObject* cls, const FieldDescriptor* field_descriptor) {
+ string constant_name = field_descriptor->name() + "_FIELD_NUMBER";
+ UpperString(&constant_name);
+ ScopedPyObjectPtr attr_name(PyString_FromStringAndSize(
+ constant_name.c_str(), constant_name.size()));
+ if (attr_name == NULL) {
+ return false;
+ }
+ ScopedPyObjectPtr number(PyInt_FromLong(field_descriptor->number()));
+ if (number == NULL) {
+ return false;
+ }
+ if (PyObject_SetAttr(cls, attr_name.get(), number.get()) == -1) {
+ return false;
+ }
+ return true;
+}
+
+
+// Finalize the creation of the Message class.
+static int AddDescriptors(PyObject* cls, const Descriptor* descriptor) {
+ // If there are extension_ranges, the message is "extendable", and extension
+ // classes will register themselves in this class.
+ if (descriptor->extension_range_count() > 0) {
+ ScopedPyObjectPtr by_name(PyDict_New());
+ if (PyObject_SetAttr(cls, k_extensions_by_name, by_name.get()) < 0) {
+ return -1;
+ }
+ ScopedPyObjectPtr by_number(PyDict_New());
+ if (PyObject_SetAttr(cls, k_extensions_by_number, by_number.get()) < 0) {
+ return -1;
+ }
+ }
+
+ // For each field set: cls.<field>_FIELD_NUMBER = <number>
+ for (int i = 0; i < descriptor->field_count(); ++i) {
+ if (!AddFieldNumberToClass(cls, descriptor->field(i))) {
+ return -1;
+ }
+ }
+
+ // For each enum set cls.<enum name> = EnumTypeWrapper(<enum descriptor>).
+ for (int i = 0; i < descriptor->enum_type_count(); ++i) {
+ const EnumDescriptor* enum_descriptor = descriptor->enum_type(i);
+ ScopedPyObjectPtr enum_type(
+ PyEnumDescriptor_FromDescriptor(enum_descriptor));
+ if (enum_type == NULL) {
+ return -1;
+ }
+ // Add wrapped enum type to message class.
+ ScopedPyObjectPtr wrapped(PyObject_CallFunctionObjArgs(
+ EnumTypeWrapper_class, enum_type.get(), NULL));
+ if (wrapped == NULL) {
+ return -1;
+ }
+ if (PyObject_SetAttrString(
+ cls, enum_descriptor->name().c_str(), wrapped.get()) == -1) {
+ return -1;
+ }
+
+ // For each enum value add cls.<name> = <number>
+ for (int j = 0; j < enum_descriptor->value_count(); ++j) {
+ const EnumValueDescriptor* enum_value_descriptor =
+ enum_descriptor->value(j);
+ ScopedPyObjectPtr value_number(PyInt_FromLong(
+ enum_value_descriptor->number()));
+ if (value_number == NULL) {
+ return -1;
+ }
+ if (PyObject_SetAttrString(cls, enum_value_descriptor->name().c_str(),
+ value_number.get()) == -1) {
+ return -1;
+ }
+ }
+ }
+
+ // For each extension set cls.<extension name> = <extension descriptor>.
+ //
+ // Extension descriptors come from
+ // <message descriptor>.extensions_by_name[name]
+ // which was defined previously.
+ for (int i = 0; i < descriptor->extension_count(); ++i) {
+ const google::protobuf::FieldDescriptor* field = descriptor->extension(i);
+ ScopedPyObjectPtr extension_field(PyFieldDescriptor_FromDescriptor(field));
+ if (extension_field == NULL) {
+ return -1;
+ }
+
+ // Add the extension field to the message class.
+ if (PyObject_SetAttrString(
+ cls, field->name().c_str(), extension_field.get()) == -1) {
+ return -1;
+ }
+
+ // For each extension set cls.<extension name>_FIELD_NUMBER = <number>.
+ if (!AddFieldNumberToClass(cls, field)) {
+ return -1;
+ }
+ }
+
+ return 0;
+}
+
+static PyObject* New(PyTypeObject* type,
+ PyObject* args, PyObject* kwargs) {
+ static char *kwlist[] = {"name", "bases", "dict", 0};
+ PyObject *bases, *dict;
+ const char* name;
+
+ // Check arguments: (name, bases, dict)
+ if (!PyArg_ParseTupleAndKeywords(args, kwargs, "sO!O!:type", kwlist,
+ &name,
+ &PyTuple_Type, &bases,
+ &PyDict_Type, &dict)) {
+ return NULL;
+ }
+
+ // Check bases: only (), or (message.Message,) are allowed
+ if (!(PyTuple_GET_SIZE(bases) == 0 ||
+ (PyTuple_GET_SIZE(bases) == 1 &&
+ PyTuple_GET_ITEM(bases, 0) == PythonMessage_class))) {
+ PyErr_SetString(PyExc_TypeError,
+ "A Message class can only inherit from Message");
+ return NULL;
+ }
+
+ // Check dict['DESCRIPTOR']
+ PyObject* py_descriptor = PyDict_GetItem(dict, kDESCRIPTOR);
+ if (py_descriptor == NULL) {
+ PyErr_SetString(PyExc_TypeError, "Message class has no DESCRIPTOR");
+ return NULL;
+ }
+ if (!PyObject_TypeCheck(py_descriptor, &PyMessageDescriptor_Type)) {
+ PyErr_Format(PyExc_TypeError, "Expected a message Descriptor, got %s",
+ py_descriptor->ob_type->tp_name);
+ return NULL;
+ }
+
+ // Build the arguments to the base metaclass.
+ // We change the __bases__ classes.
+ ScopedPyObjectPtr new_args;
+ const Descriptor* message_descriptor =
+ PyMessageDescriptor_AsDescriptor(py_descriptor);
+ if (message_descriptor == NULL) {
+ return NULL;
+ }
+
+ if (WKT_classes == NULL) {
+ ScopedPyObjectPtr well_known_types(PyImport_ImportModule(
+ "google.protobuf.internal.well_known_types"));
+ GOOGLE_DCHECK(well_known_types != NULL);
+
+ WKT_classes = PyObject_GetAttrString(well_known_types.get(), "WKTBASES");
+ GOOGLE_DCHECK(WKT_classes != NULL);
+ }
+
+ PyObject* well_known_class = PyDict_GetItemString(
+ WKT_classes, message_descriptor->full_name().c_str());
+ if (well_known_class == NULL) {
+ new_args.reset(Py_BuildValue("s(OO)O", name, &CMessage_Type,
+ PythonMessage_class, dict));
+ } else {
+ new_args.reset(Py_BuildValue("s(OOO)O", name, &CMessage_Type,
+ PythonMessage_class, well_known_class, dict));
+ }
+
+ if (new_args == NULL) {
+ return NULL;
+ }
+ // Call the base metaclass.
+ ScopedPyObjectPtr result(PyType_Type.tp_new(type, new_args.get(), NULL));
+ if (result == NULL) {
+ return NULL;
+ }
+ CMessageClass* newtype = reinterpret_cast<CMessageClass*>(result.get());
+
+ // Insert the empty weakref into the base classes.
+ if (InsertEmptyWeakref(
+ reinterpret_cast<PyTypeObject*>(PythonMessage_class)) < 0 ||
+ InsertEmptyWeakref(&CMessage_Type) < 0) {
+ return NULL;
+ }
+
+ // Cache the descriptor, both as Python object and as C++ pointer.
+ const Descriptor* descriptor =
+ PyMessageDescriptor_AsDescriptor(py_descriptor);
+ if (descriptor == NULL) {
+ return NULL;
+ }
+ Py_INCREF(py_descriptor);
+ newtype->py_message_descriptor = py_descriptor;
+ newtype->message_descriptor = descriptor;
+ // TODO(amauryfa): Don't always use the canonical pool of the descriptor,
+ // use the MessageFactory optionally passed in the class dict.
+ newtype->py_descriptor_pool = GetDescriptorPool_FromPool(
+ descriptor->file()->pool());
+ if (newtype->py_descriptor_pool == NULL) {
+ return NULL;
+ }
+ Py_INCREF(newtype->py_descriptor_pool);
+
+ // Add the message to the DescriptorPool.
+ if (cdescriptor_pool::RegisterMessageClass(newtype->py_descriptor_pool,
+ descriptor, newtype) < 0) {
+ return NULL;
+ }
+
+ // Continue with type initialization: add other descriptors, enum values...
+ if (AddDescriptors(result.get(), descriptor) < 0) {
+ return NULL;
+ }
+ return result.release();
+}
+
+static void Dealloc(CMessageClass *self) {
+ Py_DECREF(self->py_message_descriptor);
+ Py_DECREF(self->py_descriptor_pool);
+ Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
+}
+
+
+// This function inserts and empty weakref at the end of the list of
+// subclasses for the main protocol buffer Message class.
+//
+// This eliminates a O(n^2) behaviour in the internal add_subclass
+// routine.
+static int InsertEmptyWeakref(PyTypeObject *base_type) {
+#if PY_MAJOR_VERSION >= 3
+ // Python 3.4 has already included the fix for the issue that this
+ // hack addresses. For further background and the fix please see
+ // https://bugs.python.org/issue17936.
+ return 0;
+#else
+ PyObject *subclasses = base_type->tp_subclasses;
+ if (subclasses && PyList_CheckExact(subclasses)) {
+ return PyList_Append(subclasses, kEmptyWeakref);
+ }
+ return 0;
+#endif // PY_MAJOR_VERSION >= 3
+}
+
+} // namespace message_meta
+
+PyTypeObject CMessageClass_Type = {
+ PyVarObject_HEAD_INIT(&PyType_Type, 0)
+ FULL_MODULE_NAME ".MessageMeta", // tp_name
+ sizeof(CMessageClass), // tp_basicsize
+ 0, // tp_itemsize
+ (destructor)message_meta::Dealloc, // tp_dealloc
+ 0, // tp_print
+ 0, // tp_getattr
+ 0, // tp_setattr
+ 0, // tp_compare
+ 0, // tp_repr
+ 0, // tp_as_number
+ 0, // tp_as_sequence
+ 0, // tp_as_mapping
+ 0, // tp_hash
+ 0, // tp_call
+ 0, // tp_str
+ 0, // tp_getattro
+ 0, // tp_setattro
+ 0, // tp_as_buffer
+ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, // tp_flags
+ "The metaclass of ProtocolMessages", // tp_doc
+ 0, // tp_traverse
+ 0, // tp_clear
+ 0, // tp_richcompare
+ 0, // tp_weaklistoffset
+ 0, // tp_iter
+ 0, // tp_iternext
+ 0, // tp_methods
+ 0, // tp_members
+ 0, // tp_getset
+ 0, // tp_base
+ 0, // tp_dict
+ 0, // tp_descr_get
+ 0, // tp_descr_set
+ 0, // tp_dictoffset
+ 0, // tp_init
+ 0, // tp_alloc
+ message_meta::New, // tp_new
+};
+
+static CMessageClass* CheckMessageClass(PyTypeObject* cls) {
+ if (!PyObject_TypeCheck(cls, &CMessageClass_Type)) {
+ PyErr_Format(PyExc_TypeError, "Class %s is not a Message", cls->tp_name);
+ return NULL;
+ }
+ return reinterpret_cast<CMessageClass*>(cls);
+}
+
+static const Descriptor* GetMessageDescriptor(PyTypeObject* cls) {
+ CMessageClass* type = CheckMessageClass(cls);
+ if (type == NULL) {
+ return NULL;
+ }
+ return type->message_descriptor;
+}
+
// Forward declarations
namespace cmessage {
-static PyObject* GetDescriptor(CMessage* self, PyObject* name);
-static string GetMessageName(CMessage* self);
int InternalReleaseFieldByDescriptor(
- const google::protobuf::FieldDescriptor* field_descriptor,
- PyObject* composite_field,
- google::protobuf::Message* parent_message);
+ CMessage* self,
+ const FieldDescriptor* field_descriptor,
+ PyObject* composite_field);
} // namespace cmessage
// ---------------------------------------------------------------------
@@ -105,7 +430,7 @@ struct ChildVisitor {
// Returns 0 on success, -1 on failure.
int VisitCMessage(CMessage* cmessage,
- const google::protobuf::FieldDescriptor* field_descriptor) {
+ const FieldDescriptor* field_descriptor) {
return 0;
}
};
@@ -116,20 +441,26 @@ template<class Visitor>
static int VisitCompositeField(const FieldDescriptor* descriptor,
PyObject* child,
Visitor visitor) {
- if (descriptor->label() == google::protobuf::FieldDescriptor::LABEL_REPEATED) {
- if (descriptor->cpp_type() == google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) {
- RepeatedCompositeContainer* container =
- reinterpret_cast<RepeatedCompositeContainer*>(child);
- if (visitor.VisitRepeatedCompositeContainer(container) == -1)
- return -1;
+ if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
+ if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
+ if (descriptor->is_map()) {
+ MapContainer* container = reinterpret_cast<MapContainer*>(child);
+ if (visitor.VisitMapContainer(container) == -1) {
+ return -1;
+ }
+ } else {
+ RepeatedCompositeContainer* container =
+ reinterpret_cast<RepeatedCompositeContainer*>(child);
+ if (visitor.VisitRepeatedCompositeContainer(container) == -1)
+ return -1;
+ }
} else {
RepeatedScalarContainer* container =
reinterpret_cast<RepeatedScalarContainer*>(child);
if (visitor.VisitRepeatedScalarContainer(container) == -1)
return -1;
}
- } else if (descriptor->cpp_type() ==
- google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) {
+ } else if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
CMessage* cmsg = reinterpret_cast<CMessage*>(child);
if (visitor.VisitCMessage(cmsg, descriptor) == -1)
return -1;
@@ -148,24 +479,33 @@ int ForEachCompositeField(CMessage* self, Visitor visitor) {
PyObject* field;
// Visit normal fields.
- while (PyDict_Next(self->composite_fields, &pos, &key, &field)) {
- PyObject* cdescriptor = cmessage::GetDescriptor(self, key);
- if (cdescriptor != NULL) {
- const google::protobuf::FieldDescriptor* descriptor =
- reinterpret_cast<CFieldDescriptor*>(cdescriptor)->descriptor;
- if (VisitCompositeField(descriptor, field, visitor) == -1)
+ if (self->composite_fields) {
+ // Never use self->message in this function, it may be already freed.
+ const Descriptor* message_descriptor =
+ GetMessageDescriptor(Py_TYPE(self));
+ while (PyDict_Next(self->composite_fields, &pos, &key, &field)) {
+ Py_ssize_t key_str_size;
+ char *key_str_data;
+ if (PyString_AsStringAndSize(key, &key_str_data, &key_str_size) != 0)
return -1;
+ const string key_str(key_str_data, key_str_size);
+ const FieldDescriptor* descriptor =
+ message_descriptor->FindFieldByName(key_str);
+ if (descriptor != NULL) {
+ if (VisitCompositeField(descriptor, field, visitor) == -1)
+ return -1;
+ }
}
}
// Visit extension fields.
if (self->extensions != NULL) {
+ pos = 0;
while (PyDict_Next(self->extensions->values, &pos, &key, &field)) {
- CFieldDescriptor* cdescriptor =
- extension_dict::InternalGetCDescriptorFromExtension(key);
- if (cdescriptor == NULL)
+ const FieldDescriptor* descriptor = cmessage::GetExtensionDescriptor(key);
+ if (descriptor == NULL)
return -1;
- if (VisitCompositeField(cdescriptor->descriptor, field, visitor) == -1)
+ if (VisitCompositeField(descriptor, field, visitor) == -1)
return -1;
}
}
@@ -184,25 +524,13 @@ PyObject* kint64min_py;
PyObject* kint64max_py;
PyObject* kuint64max_py;
-PyObject* EnumTypeWrapper_class;
PyObject* EncodeError_class;
PyObject* DecodeError_class;
PyObject* PickleError_class;
// Constant PyString values used for GetAttr/GetItem.
-static PyObject* kDESCRIPTOR;
-static PyObject* k__descriptors;
+static PyObject* k_cdescriptor;
static PyObject* kfull_name;
-static PyObject* kname;
-static PyObject* kmessage_type;
-static PyObject* kis_extendable;
-static PyObject* kextensions_by_name;
-static PyObject* k_extensions_by_name;
-static PyObject* k_extensions_by_number;
-static PyObject* k_concrete_class;
-static PyObject* kfields_by_name;
-
-static CDescriptorPool* descriptor_pool;
/* Is 64bit */
void FormatTypeError(PyObject* arg, char* expected_types) {
@@ -235,12 +563,14 @@ bool CheckAndGetInteger(
if (PyObject_RichCompareBool(min, arg, Py_LE) != 1 ||
PyObject_RichCompareBool(max, arg, Py_GE) != 1) {
#endif
- PyObject *s = PyObject_Str(arg);
- if (s) {
- PyErr_Format(PyExc_ValueError,
- "Value out of range: %s",
- PyString_AsString(s));
- Py_DECREF(s);
+ if (!PyErr_Occurred()) {
+ PyObject *s = PyObject_Str(arg);
+ if (s) {
+ PyErr_Format(PyExc_ValueError,
+ "Value out of range: %s",
+ PyString_AsString(s));
+ Py_DECREF(s);
+ }
}
return false;
}
@@ -298,49 +628,59 @@ bool CheckAndGetBool(PyObject* arg, bool* value) {
return true;
}
-bool CheckAndSetString(
- PyObject* arg, google::protobuf::Message* message,
- const google::protobuf::FieldDescriptor* descriptor,
- const google::protobuf::Reflection* reflection,
- bool append,
- int index) {
- GOOGLE_DCHECK(descriptor->type() == google::protobuf::FieldDescriptor::TYPE_STRING ||
- descriptor->type() == google::protobuf::FieldDescriptor::TYPE_BYTES);
- if (descriptor->type() == google::protobuf::FieldDescriptor::TYPE_STRING) {
+// Checks whether the given object (which must be "bytes" or "unicode") contains
+// valid UTF-8.
+bool IsValidUTF8(PyObject* obj) {
+ if (PyBytes_Check(obj)) {
+ PyObject* unicode = PyUnicode_FromEncodedObject(obj, "utf-8", NULL);
+
+ // Clear the error indicator; we report our own error when desired.
+ PyErr_Clear();
+
+ if (unicode) {
+ Py_DECREF(unicode);
+ return true;
+ } else {
+ return false;
+ }
+ } else {
+ // Unicode object, known to be valid UTF-8.
+ return true;
+ }
+}
+
+bool AllowInvalidUTF8(const FieldDescriptor* field) { return false; }
+
+PyObject* CheckString(PyObject* arg, const FieldDescriptor* descriptor) {
+ GOOGLE_DCHECK(descriptor->type() == FieldDescriptor::TYPE_STRING ||
+ descriptor->type() == FieldDescriptor::TYPE_BYTES);
+ if (descriptor->type() == FieldDescriptor::TYPE_STRING) {
if (!PyBytes_Check(arg) && !PyUnicode_Check(arg)) {
FormatTypeError(arg, "bytes, unicode");
- return false;
+ return NULL;
}
- if (PyBytes_Check(arg)) {
- PyObject* unicode = PyUnicode_FromEncodedObject(arg, "ascii", NULL);
- if (unicode == NULL) {
- PyObject* repr = PyObject_Repr(arg);
- PyErr_Format(PyExc_ValueError,
- "%s has type str, but isn't in 7-bit ASCII "
- "encoding. Non-ASCII strings must be converted to "
- "unicode objects before being added.",
- PyString_AsString(repr));
- Py_DECREF(repr);
- return false;
- } else {
- Py_DECREF(unicode);
- }
+ if (!IsValidUTF8(arg) && !AllowInvalidUTF8(descriptor)) {
+ PyObject* repr = PyObject_Repr(arg);
+ PyErr_Format(PyExc_ValueError,
+ "%s has type str, but isn't valid UTF-8 "
+ "encoding. Non-UTF-8 strings must be converted to "
+ "unicode objects before being added.",
+ PyString_AsString(repr));
+ Py_DECREF(repr);
+ return NULL;
}
} else if (!PyBytes_Check(arg)) {
FormatTypeError(arg, "bytes");
- return false;
+ return NULL;
}
PyObject* encoded_string = NULL;
- if (descriptor->type() == google::protobuf::FieldDescriptor::TYPE_STRING) {
+ if (descriptor->type() == FieldDescriptor::TYPE_STRING) {
if (PyBytes_Check(arg)) {
-#if PY_MAJOR_VERSION < 3
- encoded_string = PyString_AsEncodedObject(arg, "utf-8", NULL);
-#else
+ // The bytes were already validated as correctly encoded UTF-8 above.
encoded_string = arg; // Already encoded.
Py_INCREF(encoded_string);
-#endif
} else {
encoded_string = PyUnicode_AsEncodedObject(arg, "utf-8", NULL);
}
@@ -350,14 +690,24 @@ bool CheckAndSetString(
Py_INCREF(encoded_string);
}
- if (encoded_string == NULL) {
+ return encoded_string;
+}
+
+bool CheckAndSetString(
+ PyObject* arg, Message* message,
+ const FieldDescriptor* descriptor,
+ const Reflection* reflection,
+ bool append,
+ int index) {
+ ScopedPyObjectPtr encoded_string(CheckString(arg, descriptor));
+
+ if (encoded_string.get() == NULL) {
return false;
}
char* value;
Py_ssize_t value_len;
- if (PyBytes_AsStringAndSize(encoded_string, &value, &value_len) < 0) {
- Py_DECREF(encoded_string);
+ if (PyBytes_AsStringAndSize(encoded_string.get(), &value, &value_len) < 0) {
return false;
}
@@ -369,13 +719,11 @@ bool CheckAndSetString(
} else {
reflection->SetRepeatedString(message, descriptor, index, value_string);
}
- Py_DECREF(encoded_string);
return true;
}
-PyObject* ToStringObject(
- const google::protobuf::FieldDescriptor* descriptor, string value) {
- if (descriptor->type() != google::protobuf::FieldDescriptor::TYPE_STRING) {
+PyObject* ToStringObject(const FieldDescriptor* descriptor, string value) {
+ if (descriptor->type() != FieldDescriptor::TYPE_STRING) {
return PyBytes_FromStringAndSize(value.c_str(), value.length());
}
@@ -391,16 +739,36 @@ PyObject* ToStringObject(
return result;
}
-google::protobuf::DynamicMessageFactory* global_message_factory;
+bool CheckFieldBelongsToMessage(const FieldDescriptor* field_descriptor,
+ const Message* message) {
+ if (message->GetDescriptor() == field_descriptor->containing_type()) {
+ return true;
+ }
+ PyErr_Format(PyExc_KeyError, "Field '%s' does not belong to message '%s'",
+ field_descriptor->full_name().c_str(),
+ message->GetDescriptor()->full_name().c_str());
+ return false;
+}
namespace cmessage {
+PyDescriptorPool* GetDescriptorPoolForMessage(CMessage* message) {
+ // No need to check the type: the type of instances of CMessage is always
+ // an instance of CMessageClass. Let's prove it with a debug-only check.
+ GOOGLE_DCHECK(PyObject_TypeCheck(message, &CMessage_Type));
+ return reinterpret_cast<CMessageClass*>(Py_TYPE(message))->py_descriptor_pool;
+}
+
+MessageFactory* GetFactoryForMessage(CMessage* message) {
+ return GetDescriptorPoolForMessage(message)->message_factory;
+}
+
static int MaybeReleaseOverlappingOneofField(
CMessage* cmessage,
- const google::protobuf::FieldDescriptor* field) {
+ const FieldDescriptor* field) {
#ifdef GOOGLE_PROTOBUF_HAS_ONEOF
- google::protobuf::Message* message = cmessage->message;
- const google::protobuf::Reflection* reflection = message->GetReflection();
+ Message* message = cmessage->message;
+ const Reflection* reflection = message->GetReflection();
if (!field->containing_oneof() ||
!reflection->HasOneof(*message, field->containing_oneof()) ||
reflection->HasField(*message, field)) {
@@ -411,20 +779,20 @@ static int MaybeReleaseOverlappingOneofField(
const OneofDescriptor* oneof = field->containing_oneof();
const FieldDescriptor* existing_field =
reflection->GetOneofFieldDescriptor(*message, oneof);
- if (existing_field->cpp_type() != google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) {
+ if (existing_field->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) {
// Non-message fields don't need to be released.
return 0;
}
const char* field_name = existing_field->name().c_str();
- PyObject* child_message = PyDict_GetItemString(
- cmessage->composite_fields, field_name);
+ PyObject* child_message = cmessage->composite_fields ?
+ PyDict_GetItemString(cmessage->composite_fields, field_name) : NULL;
if (child_message == NULL) {
// No python reference to this field so no need to release.
return 0;
}
if (InternalReleaseFieldByDescriptor(
- existing_field, child_message, message) < 0) {
+ cmessage, existing_field, child_message) < 0) {
return -1;
}
return PyDict_DelItemString(cmessage->composite_fields, field_name);
@@ -436,21 +804,21 @@ static int MaybeReleaseOverlappingOneofField(
// ---------------------------------------------------------------------
// Making a message writable
-static google::protobuf::Message* GetMutableMessage(
+static Message* GetMutableMessage(
CMessage* parent,
- const google::protobuf::FieldDescriptor* parent_field) {
- google::protobuf::Message* parent_message = parent->message;
- const google::protobuf::Reflection* reflection = parent_message->GetReflection();
+ const FieldDescriptor* parent_field) {
+ Message* parent_message = parent->message;
+ const Reflection* reflection = parent_message->GetReflection();
if (MaybeReleaseOverlappingOneofField(parent, parent_field) < 0) {
return NULL;
}
return reflection->MutableMessage(
- parent_message, parent_field, global_message_factory);
+ parent_message, parent_field, GetFactoryForMessage(parent));
}
struct FixupMessageReference : public ChildVisitor {
// message must outlive this object.
- explicit FixupMessageReference(google::protobuf::Message* message) :
+ explicit FixupMessageReference(Message* message) :
message_(message) {}
int VisitRepeatedCompositeContainer(RepeatedCompositeContainer* container) {
@@ -463,8 +831,13 @@ struct FixupMessageReference : public ChildVisitor {
return 0;
}
+ int VisitMapContainer(MapContainer* container) {
+ container->message = message_;
+ return 0;
+ }
+
private:
- google::protobuf::Message* message_;
+ Message* message_;
};
int AssureWritable(CMessage* self) {
@@ -476,20 +849,20 @@ int AssureWritable(CMessage* self) {
// If parent is NULL but we are trying to modify a read-only message, this
// is a reference to a constant default instance that needs to be replaced
// with a mutable top-level message.
- const Message* prototype = global_message_factory->GetPrototype(
- self->message->GetDescriptor());
- self->message = prototype->New();
+ self->message = self->message->New();
self->owner.reset(self->message);
+ // Cascade the new owner to eventual children: even if this message is
+ // empty, some submessages or repeated containers might exist already.
+ SetOwner(self, self->owner);
} else {
// Otherwise, we need a mutable child message.
if (AssureWritable(self->parent) == -1)
return -1;
// Make self->message writable.
- google::protobuf::Message* parent_message = self->parent->message;
- google::protobuf::Message* mutable_message = GetMutableMessage(
+ Message* mutable_message = GetMutableMessage(
self->parent,
- self->parent_field->descriptor);
+ self->parent_field_descriptor);
if (mutable_message == NULL) {
return -1;
}
@@ -500,8 +873,8 @@ int AssureWritable(CMessage* self) {
// When a CMessage is made writable its Message pointer is updated
// to point to a new mutable Message. When that happens we need to
// update any references to the old, read-only CMessage. There are
- // three places such references occur: RepeatedScalarContainer,
- // RepeatedCompositeContainer, and ExtensionDict.
+ // four places such references occur: RepeatedScalarContainer,
+ // RepeatedCompositeContainer, MapContainer, and ExtensionDict.
if (self->extensions != NULL)
self->extensions->message = self->message;
if (ForEachCompositeField(self, FixupMessageReference(self->message)) == -1)
@@ -512,26 +885,65 @@ int AssureWritable(CMessage* self) {
// --- Globals:
-static PyObject* GetDescriptor(CMessage* self, PyObject* name) {
- PyObject* descriptors =
- PyDict_GetItem(Py_TYPE(self)->tp_dict, k__descriptors);
- if (descriptors == NULL) {
- PyErr_SetString(PyExc_TypeError, "No __descriptors");
+// Retrieve a C++ FieldDescriptor for a message attribute.
+// The C++ message must be valid.
+// TODO(amauryfa): This function should stay internal, because exception
+// handling is not consistent.
+static const FieldDescriptor* GetFieldDescriptor(
+ CMessage* self, PyObject* name) {
+ const Descriptor *message_descriptor = self->message->GetDescriptor();
+ char* field_name;
+ Py_ssize_t size;
+ if (PyString_AsStringAndSize(name, &field_name, &size) < 0) {
return NULL;
}
-
- return PyDict_GetItem(descriptors, name);
+ const FieldDescriptor *field_descriptor =
+ message_descriptor->FindFieldByName(string(field_name, size));
+ if (field_descriptor == NULL) {
+ // Note: No exception is set!
+ return NULL;
+ }
+ return field_descriptor;
}
-static const google::protobuf::Message* CreateMessage(const char* message_type) {
- string message_name(message_type);
- const google::protobuf::Descriptor* descriptor =
- GetDescriptorPool()->FindMessageTypeByName(message_name);
- if (descriptor == NULL) {
- PyErr_SetString(PyExc_TypeError, message_type);
+// Retrieve a C++ FieldDescriptor for an extension handle.
+const FieldDescriptor* GetExtensionDescriptor(PyObject* extension) {
+ ScopedPyObjectPtr cdescriptor;
+ if (!PyObject_TypeCheck(extension, &PyFieldDescriptor_Type)) {
+ // Most callers consider extensions as a plain dictionary. We should
+ // allow input which is not a field descriptor, and simply pretend it does
+ // not exist.
+ PyErr_SetObject(PyExc_KeyError, extension);
return NULL;
}
- return global_message_factory->GetPrototype(descriptor);
+ return PyFieldDescriptor_AsDescriptor(extension);
+}
+
+// If value is a string, convert it into an enum value based on the labels in
+// descriptor, otherwise simply return value. Always returns a new reference.
+static PyObject* GetIntegerEnumValue(const FieldDescriptor& descriptor,
+ PyObject* value) {
+ if (PyString_Check(value) || PyUnicode_Check(value)) {
+ const EnumDescriptor* enum_descriptor = descriptor.enum_type();
+ if (enum_descriptor == NULL) {
+ PyErr_SetString(PyExc_TypeError, "not an enum field");
+ return NULL;
+ }
+ char* enum_label;
+ Py_ssize_t size;
+ if (PyString_AsStringAndSize(value, &enum_label, &size) < 0) {
+ return NULL;
+ }
+ const EnumValueDescriptor* enum_value_descriptor =
+ enum_descriptor->FindValueByName(string(enum_label, size));
+ if (enum_value_descriptor == NULL) {
+ PyErr_SetString(PyExc_ValueError, "unknown enum label");
+ return NULL;
+ }
+ return PyInt_FromLong(enum_value_descriptor->number());
+ }
+ Py_INCREF(value);
+ return value;
}
// If cmessage_list is not NULL, this function releases values into the
@@ -539,12 +951,13 @@ static const google::protobuf::Message* CreateMessage(const char* message_type)
// needs to do this to make sure CMessages stay alive if they're still
// referenced after deletion. Repeated scalar container doesn't need to worry.
int InternalDeleteRepeatedField(
- google::protobuf::Message* message,
- const google::protobuf::FieldDescriptor* field_descriptor,
+ CMessage* self,
+ const FieldDescriptor* field_descriptor,
PyObject* slice,
PyObject* cmessage_list) {
+ Message* message = self->message;
Py_ssize_t length, from, to, step, slice_length;
- const google::protobuf::Reflection* reflection = message->GetReflection();
+ const Reflection* reflection = message->GetReflection();
int min, max;
length = reflection->FieldSize(*message, field_descriptor);
@@ -616,7 +1029,7 @@ int InternalDeleteRepeatedField(
CMessage* last_cmessage = reinterpret_cast<CMessage*>(
PyList_GET_ITEM(cmessage_list, PyList_GET_SIZE(cmessage_list) - 1));
repeated_composite_container::ReleaseLastTo(
- field_descriptor, message, last_cmessage);
+ self, field_descriptor, last_cmessage);
if (PySequence_DelItem(cmessage_list, -1) < 0) {
return -1;
}
@@ -627,39 +1040,8 @@ int InternalDeleteRepeatedField(
return 0;
}
-int InitAttributes(CMessage* self, PyObject* arg, PyObject* kwargs) {
- ScopedPyObjectPtr descriptor;
- if (arg == NULL) {
- descriptor.reset(
- PyObject_GetAttr(reinterpret_cast<PyObject*>(self), kDESCRIPTOR));
- if (descriptor == NULL) {
- return NULL;
- }
- } else {
- descriptor.reset(arg);
- descriptor.inc();
- }
- ScopedPyObjectPtr is_extendable(PyObject_GetAttr(descriptor, kis_extendable));
- if (is_extendable == NULL) {
- return NULL;
- }
- int retcode = PyObject_IsTrue(is_extendable);
- if (retcode == -1) {
- return NULL;
- }
- if (retcode) {
- PyObject* py_extension_dict = PyObject_CallObject(
- reinterpret_cast<PyObject*>(&ExtensionDict_Type), NULL);
- if (py_extension_dict == NULL) {
- return NULL;
- }
- ExtensionDict* extension_dict = reinterpret_cast<ExtensionDict*>(
- py_extension_dict);
- extension_dict->parent = self;
- extension_dict->message = self->message;
- self->extensions = extension_dict;
- }
-
+// Initializes fields of a message. Used in constructors.
+int InitAttributes(CMessage* self, PyObject* kwargs) {
if (kwargs == NULL) {
return 0;
}
@@ -672,46 +1054,138 @@ int InitAttributes(CMessage* self, PyObject* arg, PyObject* kwargs) {
PyErr_SetString(PyExc_ValueError, "Field name must be a string");
return -1;
}
- PyObject* py_cdescriptor = GetDescriptor(self, name);
- if (py_cdescriptor == NULL) {
- PyErr_Format(PyExc_ValueError, "Protocol message has no \"%s\" field.",
+ const FieldDescriptor* descriptor = GetFieldDescriptor(self, name);
+ if (descriptor == NULL) {
+ PyErr_Format(PyExc_ValueError, "Protocol message %s has no \"%s\" field.",
+ self->message->GetDescriptor()->name().c_str(),
PyString_AsString(name));
return -1;
}
- const google::protobuf::FieldDescriptor* descriptor =
- reinterpret_cast<CFieldDescriptor*>(py_cdescriptor)->descriptor;
- if (descriptor->label() == google::protobuf::FieldDescriptor::LABEL_REPEATED) {
+ if (value == Py_None) {
+ // field=None is the same as no field at all.
+ continue;
+ }
+ if (descriptor->is_map()) {
+ ScopedPyObjectPtr map(GetAttr(self, name));
+ const FieldDescriptor* value_descriptor =
+ descriptor->message_type()->FindFieldByName("value");
+ if (value_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
+ Py_ssize_t map_pos = 0;
+ PyObject* map_key;
+ PyObject* map_value;
+ while (PyDict_Next(value, &map_pos, &map_key, &map_value)) {
+ ScopedPyObjectPtr function_return;
+ function_return.reset(PyObject_GetItem(map.get(), map_key));
+ if (function_return.get() == NULL) {
+ return -1;
+ }
+ ScopedPyObjectPtr ok(PyObject_CallMethod(
+ function_return.get(), "MergeFrom", "O", map_value));
+ if (ok.get() == NULL) {
+ return -1;
+ }
+ }
+ } else {
+ ScopedPyObjectPtr function_return;
+ function_return.reset(
+ PyObject_CallMethod(map.get(), "update", "O", value));
+ if (function_return.get() == NULL) {
+ return -1;
+ }
+ }
+ } else if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
ScopedPyObjectPtr container(GetAttr(self, name));
if (container == NULL) {
return -1;
}
- if (descriptor->cpp_type() == google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) {
- if (repeated_composite_container::Extend(
- reinterpret_cast<RepeatedCompositeContainer*>(container.get()),
- value)
- == NULL) {
+ if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
+ RepeatedCompositeContainer* rc_container =
+ reinterpret_cast<RepeatedCompositeContainer*>(container.get());
+ ScopedPyObjectPtr iter(PyObject_GetIter(value));
+ if (iter == NULL) {
+ PyErr_SetString(PyExc_TypeError, "Value must be iterable");
+ return -1;
+ }
+ ScopedPyObjectPtr next;
+ while ((next.reset(PyIter_Next(iter.get()))) != NULL) {
+ PyObject* kwargs = (PyDict_Check(next.get()) ? next.get() : NULL);
+ ScopedPyObjectPtr new_msg(
+ repeated_composite_container::Add(rc_container, NULL, kwargs));
+ if (new_msg == NULL) {
+ return -1;
+ }
+ if (kwargs == NULL) {
+ // next was not a dict, it's a message we need to merge
+ ScopedPyObjectPtr merged(MergeFrom(
+ reinterpret_cast<CMessage*>(new_msg.get()), next.get()));
+ if (merged.get() == NULL) {
+ return -1;
+ }
+ }
+ }
+ if (PyErr_Occurred()) {
+ // Check to see how PyIter_Next() exited.
+ return -1;
+ }
+ } else if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) {
+ RepeatedScalarContainer* rs_container =
+ reinterpret_cast<RepeatedScalarContainer*>(container.get());
+ ScopedPyObjectPtr iter(PyObject_GetIter(value));
+ if (iter == NULL) {
+ PyErr_SetString(PyExc_TypeError, "Value must be iterable");
+ return -1;
+ }
+ ScopedPyObjectPtr next;
+ while ((next.reset(PyIter_Next(iter.get()))) != NULL) {
+ ScopedPyObjectPtr enum_value(
+ GetIntegerEnumValue(*descriptor, next.get()));
+ if (enum_value == NULL) {
+ return -1;
+ }
+ ScopedPyObjectPtr new_msg(repeated_scalar_container::Append(
+ rs_container, enum_value.get()));
+ if (new_msg == NULL) {
+ return -1;
+ }
+ }
+ if (PyErr_Occurred()) {
+ // Check to see how PyIter_Next() exited.
return -1;
}
} else {
- if (repeated_scalar_container::Extend(
+ if (ScopedPyObjectPtr(repeated_scalar_container::Extend(
reinterpret_cast<RepeatedScalarContainer*>(container.get()),
- value) ==
+ value)) ==
NULL) {
return -1;
}
}
- } else if (descriptor->cpp_type() ==
- google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) {
+ } else if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
ScopedPyObjectPtr message(GetAttr(self, name));
if (message == NULL) {
return -1;
}
- if (MergeFrom(reinterpret_cast<CMessage*>(message.get()),
- value) == NULL) {
- return -1;
+ CMessage* cmessage = reinterpret_cast<CMessage*>(message.get());
+ if (PyDict_Check(value)) {
+ if (InitAttributes(cmessage, value) < 0) {
+ return -1;
+ }
+ } else {
+ ScopedPyObjectPtr merged(MergeFrom(cmessage, value));
+ if (merged == NULL) {
+ return -1;
+ }
}
} else {
- if (SetAttr(self, name, value) < 0) {
+ ScopedPyObjectPtr new_val;
+ if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) {
+ new_val.reset(GetIntegerEnumValue(*descriptor, value));
+ if (new_val == NULL) {
+ return -1;
+ }
+ }
+ if (SetAttr(self, name, (new_val.get() == NULL) ? value : new_val.get()) <
+ 0) {
return -1;
}
}
@@ -719,59 +1193,64 @@ int InitAttributes(CMessage* self, PyObject* arg, PyObject* kwargs) {
return 0;
}
-static PyObject* New(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
- CMessage* self = reinterpret_cast<CMessage*>(type->tp_alloc(type, 0));
+// Allocates an incomplete Python Message: the caller must fill self->message,
+// self->owner and eventually self->parent.
+CMessage* NewEmptyMessage(CMessageClass* type) {
+ CMessage* self = reinterpret_cast<CMessage*>(
+ PyType_GenericAlloc(&type->super.ht_type, 0));
if (self == NULL) {
return NULL;
}
self->message = NULL;
self->parent = NULL;
- self->parent_field = NULL;
+ self->parent_field_descriptor = NULL;
self->read_only = false;
self->extensions = NULL;
- self->composite_fields = PyDict_New();
- if (self->composite_fields == NULL) {
- return NULL;
- }
- return reinterpret_cast<PyObject*>(self);
-}
+ self->composite_fields = NULL;
-PyObject* NewEmpty(PyObject* type) {
- return New(reinterpret_cast<PyTypeObject*>(type), NULL, NULL);
+ return self;
}
-static int Init(CMessage* self, PyObject* args, PyObject* kwargs) {
- if (kwargs == NULL) {
- // TODO(anuraag): Set error
- return -1;
+// The __new__ method of Message classes.
+// Creates a new C++ message and takes ownership.
+static PyObject* New(PyTypeObject* cls,
+ PyObject* unused_args, PyObject* unused_kwargs) {
+ CMessageClass* type = CheckMessageClass(cls);
+ if (type == NULL) {
+ return NULL;
}
-
- PyObject* descriptor = PyTuple_GetItem(args, 0);
- if (descriptor == NULL || PyTuple_Size(args) != 1) {
- PyErr_SetString(PyExc_ValueError, "args must contain one arg: descriptor");
- return -1;
+ // Retrieve the message descriptor and the default instance (=prototype).
+ const Descriptor* message_descriptor = type->message_descriptor;
+ if (message_descriptor == NULL) {
+ return NULL;
}
-
- ScopedPyObjectPtr py_message_type(PyObject_GetAttr(descriptor, kfull_name));
- if (py_message_type == NULL) {
- return -1;
+ const Message* default_message = type->py_descriptor_pool->message_factory
+ ->GetPrototype(message_descriptor);
+ if (default_message == NULL) {
+ PyErr_SetString(PyExc_TypeError, message_descriptor->full_name().c_str());
+ return NULL;
}
- const char* message_type = PyString_AsString(py_message_type.get());
- const google::protobuf::Message* message = CreateMessage(message_type);
- if (message == NULL) {
- return -1;
+ CMessage* self = NewEmptyMessage(type);
+ if (self == NULL) {
+ return NULL;
}
-
- self->message = message->New();
+ self->message = default_message->New();
self->owner.reset(self->message);
+ return reinterpret_cast<PyObject*>(self);
+}
- if (InitAttributes(self, descriptor, kwargs) < 0) {
+// The __init__ method of Message classes.
+// It initializes fields from keywords passed to the constructor.
+static int Init(CMessage* self, PyObject* args, PyObject* kwargs) {
+ if (PyTuple_Size(args) != 0) {
+ PyErr_SetString(PyExc_TypeError, "No positional arguments allowed");
return -1;
}
- return 0;
+
+ return InitAttributes(self, kwargs);
}
// ---------------------------------------------------------------------
@@ -800,8 +1279,13 @@ struct ClearWeakReferences : public ChildVisitor {
return 0;
}
+ int VisitMapContainer(MapContainer* container) {
+ container->parent = NULL;
+ return 0;
+ }
+
int VisitCMessage(CMessage* cmessage,
- const google::protobuf::FieldDescriptor* field_descriptor) {
+ const FieldDescriptor* field_descriptor) {
cmessage->parent = NULL;
return 0;
}
@@ -810,6 +1294,9 @@ struct ClearWeakReferences : public ChildVisitor {
static void Dealloc(CMessage* self) {
// Null out all weak references from children to this message.
GOOGLE_CHECK_EQ(0, ForEachCompositeField(self, ClearWeakReferences()));
+ if (self->extensions) {
+ self->extensions->parent = NULL;
+ }
Py_CLEAR(self->extensions);
Py_CLEAR(self->composite_fields);
@@ -851,14 +1338,12 @@ PyObject* IsInitialized(CMessage* self, PyObject* args) {
}
PyObject* HasFieldByDescriptor(
- CMessage* self, const google::protobuf::FieldDescriptor* field_descriptor) {
- google::protobuf::Message* message = self->message;
- if (!FIELD_BELONGS_TO_MESSAGE(field_descriptor, message)) {
- PyErr_SetString(PyExc_KeyError,
- "Field does not belong to message!");
+ CMessage* self, const FieldDescriptor* field_descriptor) {
+ Message* message = self->message;
+ if (!CheckFieldBelongsToMessage(field_descriptor, message)) {
return NULL;
}
- if (field_descriptor->label() == google::protobuf::FieldDescriptor::LABEL_REPEATED) {
+ if (field_descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
PyErr_SetString(PyExc_KeyError,
"Field is repeated. A singular method is required.");
return NULL;
@@ -868,42 +1353,78 @@ PyObject* HasFieldByDescriptor(
return PyBool_FromLong(has_field ? 1 : 0);
}
-const google::protobuf::FieldDescriptor* FindFieldWithOneofs(
- const google::protobuf::Message* message, const char* field_name, bool* in_oneof) {
- const google::protobuf::Descriptor* descriptor = message->GetDescriptor();
- const google::protobuf::FieldDescriptor* field_descriptor =
+const FieldDescriptor* FindFieldWithOneofs(
+ const Message* message, const string& field_name, bool* in_oneof) {
+ *in_oneof = false;
+ const Descriptor* descriptor = message->GetDescriptor();
+ const FieldDescriptor* field_descriptor =
descriptor->FindFieldByName(field_name);
- if (field_descriptor == NULL) {
- const google::protobuf::OneofDescriptor* oneof_desc =
- message->GetDescriptor()->FindOneofByName(field_name);
- if (oneof_desc == NULL) {
- *in_oneof = false;
- return NULL;
- } else {
- *in_oneof = true;
- return message->GetReflection()->GetOneofFieldDescriptor(
- *message, oneof_desc);
+ if (field_descriptor != NULL) {
+ return field_descriptor;
+ }
+ const OneofDescriptor* oneof_desc =
+ descriptor->FindOneofByName(field_name);
+ if (oneof_desc != NULL) {
+ *in_oneof = true;
+ return message->GetReflection()->GetOneofFieldDescriptor(*message,
+ oneof_desc);
+ }
+ return NULL;
+}
+
+bool CheckHasPresence(const FieldDescriptor* field_descriptor, bool in_oneof) {
+ if (field_descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
+ PyErr_Format(PyExc_ValueError,
+ "Protocol message has no singular \"%s\" field.",
+ field_descriptor->name().c_str());
+ return false;
+ }
+
+ if (field_descriptor->file()->syntax() == FileDescriptor::SYNTAX_PROTO3) {
+ // HasField() for a oneof *itself* isn't supported.
+ if (in_oneof) {
+ PyErr_Format(PyExc_ValueError,
+ "Can't test oneof field \"%s\" for presence in proto3, use "
+ "WhichOneof instead.",
+ field_descriptor->containing_oneof()->name().c_str());
+ return false;
+ }
+
+ // ...but HasField() for fields *in* a oneof is supported.
+ if (field_descriptor->containing_oneof() != NULL) {
+ return true;
+ }
+
+ if (field_descriptor->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) {
+ PyErr_Format(
+ PyExc_ValueError,
+ "Can't test non-submessage field \"%s\" for presence in proto3.",
+ field_descriptor->name().c_str());
+ return false;
}
}
- return field_descriptor;
+
+ return true;
}
PyObject* HasField(CMessage* self, PyObject* arg) {
-#if PY_MAJOR_VERSION < 3
char* field_name;
- if (PyString_AsStringAndSize(arg, &field_name, NULL) < 0) {
+ Py_ssize_t size;
+#if PY_MAJOR_VERSION < 3
+ if (PyString_AsStringAndSize(arg, &field_name, &size) < 0) {
+ return NULL;
+ }
#else
- char* field_name = PyUnicode_AsUTF8(arg);
+ field_name = PyUnicode_AsUTF8AndSize(arg, &size);
if (!field_name) {
-#endif
return NULL;
}
+#endif
- google::protobuf::Message* message = self->message;
- const google::protobuf::Descriptor* descriptor = message->GetDescriptor();
+ Message* message = self->message;
bool is_in_oneof;
- const google::protobuf::FieldDescriptor* field_descriptor =
- FindFieldWithOneofs(message, field_name, &is_in_oneof);
+ const FieldDescriptor* field_descriptor =
+ FindFieldWithOneofs(message, string(field_name, size), &is_in_oneof);
if (field_descriptor == NULL) {
if (!is_in_oneof) {
PyErr_Format(PyExc_ValueError, "Unknown field %s.", field_name);
@@ -913,44 +1434,51 @@ PyObject* HasField(CMessage* self, PyObject* arg) {
}
}
- if (field_descriptor->label() == google::protobuf::FieldDescriptor::LABEL_REPEATED) {
- PyErr_Format(PyExc_ValueError,
- "Protocol message has no singular \"%s\" field.", field_name);
+ if (!CheckHasPresence(field_descriptor, is_in_oneof)) {
return NULL;
}
- bool has_field =
- message->GetReflection()->HasField(*message, field_descriptor);
- if (!has_field && field_descriptor->cpp_type() ==
- google::protobuf::FieldDescriptor::CPPTYPE_ENUM) {
- // We may have an invalid enum value stored in the UnknownFieldSet and need
- // to check presence in there as well.
- const google::protobuf::UnknownFieldSet& unknown_field_set =
+ if (message->GetReflection()->HasField(*message, field_descriptor)) {
+ Py_RETURN_TRUE;
+ }
+ if (!message->GetReflection()->SupportsUnknownEnumValues() &&
+ field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) {
+ // Special case: Python HasField() differs in semantics from C++
+ // slightly: we return HasField('enum_field') == true if there is
+ // an unknown enum value present. To implement this we have to
+ // look in the UnknownFieldSet.
+ const UnknownFieldSet& unknown_field_set =
message->GetReflection()->GetUnknownFields(*message);
for (int i = 0; i < unknown_field_set.field_count(); ++i) {
if (unknown_field_set.field(i).number() == field_descriptor->number()) {
Py_RETURN_TRUE;
}
}
- Py_RETURN_FALSE;
}
- return PyBool_FromLong(has_field ? 1 : 0);
+ Py_RETURN_FALSE;
}
-PyObject* ClearExtension(CMessage* self, PyObject* arg) {
+PyObject* ClearExtension(CMessage* self, PyObject* extension) {
if (self->extensions != NULL) {
- return extension_dict::ClearExtension(self->extensions, arg);
+ return extension_dict::ClearExtension(self->extensions, extension);
+ } else {
+ const FieldDescriptor* descriptor = GetExtensionDescriptor(extension);
+ if (descriptor == NULL) {
+ return NULL;
+ }
+ if (ScopedPyObjectPtr(ClearFieldByDescriptor(self, descriptor)) == NULL) {
+ return NULL;
+ }
}
- PyErr_SetString(PyExc_TypeError, "Message is not extendable");
- return NULL;
+ Py_RETURN_NONE;
}
-PyObject* HasExtension(CMessage* self, PyObject* arg) {
- if (self->extensions != NULL) {
- return extension_dict::HasExtension(self->extensions, arg);
+PyObject* HasExtension(CMessage* self, PyObject* extension) {
+ const FieldDescriptor* descriptor = GetExtensionDescriptor(extension);
+ if (descriptor == NULL) {
+ return NULL;
}
- PyErr_SetString(PyExc_TypeError, "Message is not extendable");
- return NULL;
+ return HasFieldByDescriptor(self, descriptor);
}
// ---------------------------------------------------------------------
@@ -1000,8 +1528,13 @@ struct SetOwnerVisitor : public ChildVisitor {
return 0;
}
+ int VisitMapContainer(MapContainer* container) {
+ container->SetOwner(new_owner_);
+ return 0;
+ }
+
int VisitCMessage(CMessage* cmessage,
- const google::protobuf::FieldDescriptor* field_descriptor) {
+ const FieldDescriptor* field_descriptor) {
return SetOwner(cmessage, new_owner_);
}
@@ -1020,18 +1553,18 @@ int SetOwner(CMessage* self, const shared_ptr<Message>& new_owner) {
// Releases the message specified by 'field' and returns the
// pointer. If the field does not exist a new message is created using
// 'descriptor'. The caller takes ownership of the returned pointer.
-Message* ReleaseMessage(google::protobuf::Message* message,
- const google::protobuf::Descriptor* descriptor,
- const google::protobuf::FieldDescriptor* field_descriptor) {
- Message* released_message = message->GetReflection()->ReleaseMessage(
- message, field_descriptor, global_message_factory);
+Message* ReleaseMessage(CMessage* self,
+ const Descriptor* descriptor,
+ const FieldDescriptor* field_descriptor) {
+ MessageFactory* message_factory = GetFactoryForMessage(self);
+ Message* released_message = self->message->GetReflection()->ReleaseMessage(
+ self->message, field_descriptor, message_factory);
// ReleaseMessage will return NULL which differs from
// child_cmessage->message, if the field does not exist. In this case,
// the latter points to the default instance via a const_cast<>, so we
// have to reset it to a new mutable object since we are taking ownership.
if (released_message == NULL) {
- const Message* prototype = global_message_factory->GetPrototype(
- descriptor);
+ const Message* prototype = message_factory->GetPrototype(descriptor);
GOOGLE_DCHECK(prototype != NULL);
released_message = prototype->New();
}
@@ -1039,16 +1572,16 @@ Message* ReleaseMessage(google::protobuf::Message* message,
return released_message;
}
-int ReleaseSubMessage(google::protobuf::Message* message,
- const google::protobuf::FieldDescriptor* field_descriptor,
+int ReleaseSubMessage(CMessage* self,
+ const FieldDescriptor* field_descriptor,
CMessage* child_cmessage) {
// Release the Message
shared_ptr<Message> released_message(ReleaseMessage(
- message, child_cmessage->message->GetDescriptor(), field_descriptor));
+ self, child_cmessage->message->GetDescriptor(), field_descriptor));
child_cmessage->message = released_message.get();
child_cmessage->owner.swap(released_message);
child_cmessage->parent = NULL;
- child_cmessage->parent_field = NULL;
+ child_cmessage->parent_field_descriptor = NULL;
child_cmessage->read_only = false;
return ForEachCompositeField(child_cmessage,
SetOwnerVisitor(child_cmessage->owner));
@@ -1056,8 +1589,8 @@ int ReleaseSubMessage(google::protobuf::Message* message,
struct ReleaseChild : public ChildVisitor {
// message must outlive this object.
- explicit ReleaseChild(google::protobuf::Message* parent_message) :
- parent_message_(parent_message) {}
+ explicit ReleaseChild(CMessage* parent) :
+ parent_(parent) {}
int VisitRepeatedCompositeContainer(RepeatedCompositeContainer* container) {
return repeated_composite_container::Release(
@@ -1069,44 +1602,33 @@ struct ReleaseChild : public ChildVisitor {
reinterpret_cast<RepeatedScalarContainer*>(container));
}
+ int VisitMapContainer(MapContainer* container) {
+ return reinterpret_cast<MapContainer*>(container)->Release();
+ }
+
int VisitCMessage(CMessage* cmessage,
- const google::protobuf::FieldDescriptor* field_descriptor) {
- return ReleaseSubMessage(parent_message_, field_descriptor,
+ const FieldDescriptor* field_descriptor) {
+ return ReleaseSubMessage(parent_, field_descriptor,
reinterpret_cast<CMessage*>(cmessage));
}
- google::protobuf::Message* parent_message_;
+ CMessage* parent_;
};
int InternalReleaseFieldByDescriptor(
- const google::protobuf::FieldDescriptor* field_descriptor,
- PyObject* composite_field,
- google::protobuf::Message* parent_message) {
+ CMessage* self,
+ const FieldDescriptor* field_descriptor,
+ PyObject* composite_field) {
return VisitCompositeField(
field_descriptor,
composite_field,
- ReleaseChild(parent_message));
-}
-
-int InternalReleaseField(CMessage* self, PyObject* composite_field,
- PyObject* name) {
- PyObject* cdescriptor = GetDescriptor(self, name);
- if (cdescriptor != NULL) {
- const google::protobuf::FieldDescriptor* descriptor =
- reinterpret_cast<CFieldDescriptor*>(cdescriptor)->descriptor;
- return InternalReleaseFieldByDescriptor(
- descriptor, composite_field, self->message);
- }
-
- return 0;
+ ReleaseChild(self));
}
PyObject* ClearFieldByDescriptor(
CMessage* self,
- const google::protobuf::FieldDescriptor* descriptor) {
- if (!FIELD_BELONGS_TO_MESSAGE(descriptor, self->message)) {
- PyErr_SetString(PyExc_KeyError,
- "Field does not belong to message!");
+ const FieldDescriptor* descriptor) {
+ if (!CheckFieldBelongsToMessage(descriptor, self->message)) {
return NULL;
}
AssureWritable(self);
@@ -1115,25 +1637,23 @@ PyObject* ClearFieldByDescriptor(
}
PyObject* ClearField(CMessage* self, PyObject* arg) {
- char* field_name;
if (!PyString_Check(arg)) {
PyErr_SetString(PyExc_TypeError, "field name must be a string");
return NULL;
}
#if PY_MAJOR_VERSION < 3
- if (PyString_AsStringAndSize(arg, &field_name, NULL) < 0) {
- return NULL;
- }
+ const char* field_name = PyString_AS_STRING(arg);
+ Py_ssize_t size = PyString_GET_SIZE(arg);
#else
- field_name = PyUnicode_AsUTF8(arg);
+ Py_ssize_t size;
+ const char* field_name = PyUnicode_AsUTF8AndSize(arg, &size);
#endif
AssureWritable(self);
- google::protobuf::Message* message = self->message;
- const google::protobuf::Descriptor* descriptor = message->GetDescriptor();
+ Message* message = self->message;
ScopedPyObjectPtr arg_in_oneof;
bool is_in_oneof;
- const google::protobuf::FieldDescriptor* field_descriptor =
- FindFieldWithOneofs(message, field_name, &is_in_oneof);
+ const FieldDescriptor* field_descriptor =
+ FindFieldWithOneofs(message, string(field_name, size), &is_in_oneof);
if (field_descriptor == NULL) {
if (!is_in_oneof) {
PyErr_Format(PyExc_ValueError,
@@ -1143,24 +1663,27 @@ PyObject* ClearField(CMessage* self, PyObject* arg) {
Py_RETURN_NONE;
}
} else if (is_in_oneof) {
- arg_in_oneof.reset(PyString_FromString(field_descriptor->name().c_str()));
+ const string& name = field_descriptor->name();
+ arg_in_oneof.reset(PyString_FromStringAndSize(name.c_str(), name.size()));
arg = arg_in_oneof.get();
}
- PyObject* composite_field = PyDict_GetItem(self->composite_fields,
- arg);
+ PyObject* composite_field = self->composite_fields ?
+ PyDict_GetItem(self->composite_fields, arg) : NULL;
// Only release the field if there's a possibility that there are
// references to it.
if (composite_field != NULL) {
- if (InternalReleaseField(self, composite_field, arg) < 0) {
+ if (InternalReleaseFieldByDescriptor(self, field_descriptor,
+ composite_field) < 0) {
return NULL;
}
PyDict_DelItem(self->composite_fields, arg);
}
message->GetReflection()->ClearField(message, field_descriptor);
- if (field_descriptor->cpp_type() == google::protobuf::FieldDescriptor::CPPTYPE_ENUM) {
- google::protobuf::UnknownFieldSet* unknown_field_set =
+ if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_ENUM &&
+ !message->GetReflection()->SupportsUnknownEnumValues()) {
+ UnknownFieldSet* unknown_field_set =
message->GetReflection()->MutableUnknownFields(message);
unknown_field_set->DeleteByNumber(field_descriptor->number());
}
@@ -1170,25 +1693,12 @@ PyObject* ClearField(CMessage* self, PyObject* arg) {
PyObject* Clear(CMessage* self) {
AssureWritable(self);
- if (ForEachCompositeField(self, ReleaseChild(self->message)) == -1)
+ if (ForEachCompositeField(self, ReleaseChild(self)) == -1)
return NULL;
-
- // The old ExtensionDict still aliases this CMessage, but all its
- // fields have been released.
- if (self->extensions != NULL) {
- Py_CLEAR(self->extensions);
- PyObject* py_extension_dict = PyObject_CallObject(
- reinterpret_cast<PyObject*>(&ExtensionDict_Type), NULL);
- if (py_extension_dict == NULL) {
- return NULL;
- }
- ExtensionDict* extension_dict = reinterpret_cast<ExtensionDict*>(
- py_extension_dict);
- extension_dict->parent = self;
- extension_dict->message = self->message;
- self->extensions = extension_dict;
+ Py_CLEAR(self->extensions);
+ if (self->composite_fields) {
+ PyDict_Clear(self->composite_fields);
}
- PyDict_Clear(self->composite_fields);
self->message->Clear();
Py_RETURN_NONE;
}
@@ -1196,8 +1706,8 @@ PyObject* Clear(CMessage* self) {
// ---------------------------------------------------------------------
static string GetMessageName(CMessage* self) {
- if (self->parent_field != NULL) {
- return self->parent_field->descriptor->full_name();
+ if (self->parent_field_descriptor != NULL) {
+ return self->parent_field_descriptor->full_name();
} else {
return self->message->GetDescriptor()->full_name();
}
@@ -1218,7 +1728,27 @@ static PyObject* SerializeToString(CMessage* self, PyObject* args) {
if (joined == NULL) {
return NULL;
}
- PyErr_Format(EncodeError_class, "Message %s is missing required fields: %s",
+
+ // TODO(haberman): this is a (hopefully temporary) hack. The unit testing
+ // infrastructure reloads all pure-Python modules for every test, but not
+ // C++ modules (because that's generally impossible:
+ // http://bugs.python.org/issue1144263). But if we cache EncodeError, we'll
+ // return the EncodeError from a previous load of the module, which won't
+ // match a user's attempt to catch EncodeError. So we have to look it up
+ // again every time.
+ ScopedPyObjectPtr message_module(PyImport_ImportModule(
+ "google.protobuf.message"));
+ if (message_module.get() == NULL) {
+ return NULL;
+ }
+
+ ScopedPyObjectPtr encode_error(
+ PyObject_GetAttrString(message_module.get(), "EncodeError"));
+ if (encode_error.get() == NULL) {
+ return NULL;
+ }
+ PyErr_Format(encode_error.get(),
+ "Message %s is missing required fields: %s",
GetMessageName(self).c_str(), PyString_AsString(joined.get()));
return NULL;
}
@@ -1244,40 +1774,44 @@ static PyObject* SerializePartialToString(CMessage* self) {
// Formats proto fields for ascii dumps using python formatting functions where
// appropriate.
-class PythonFieldValuePrinter : public google::protobuf::TextFormat::FieldValuePrinter {
+class PythonFieldValuePrinter : public TextFormat::FieldValuePrinter {
public:
- PythonFieldValuePrinter() : float_holder_(PyFloat_FromDouble(0)) {}
-
// Python has some differences from C++ when printing floating point numbers.
//
// 1) Trailing .0 is always printed.
- // 2) Outputted is rounded to 12 digits.
+ // 2) (Python2) Output is rounded to 12 digits.
+ // 3) (Python3) The full precision of the double is preserved (and Python uses
+ // David M. Gay's dtoa(), when the C++ code uses SimpleDtoa. There are some
+ // differences, but they rarely happen)
//
// We override floating point printing with the C-API function for printing
// Python floats to ensure consistency.
string PrintFloat(float value) const { return PrintDouble(value); }
string PrintDouble(double value) const {
- reinterpret_cast<PyFloatObject*>(float_holder_.get())->ob_fval = value;
- ScopedPyObjectPtr s(PyObject_Str(float_holder_.get()));
- if (s == NULL) return string();
-#if PY_MAJOR_VERSION < 3
- char *cstr = PyBytes_AS_STRING(static_cast<PyObject*>(s));
-#else
- char *cstr = PyUnicode_AsUTF8(s);
-#endif
- return string(cstr);
- }
+ // This implementation is not highly optimized (it allocates two temporary
+ // Python objects) but it is simple and portable. If this is shown to be a
+ // performance bottleneck, we can optimize it, but the results will likely
+ // be more complicated to accommodate the differing behavior of double
+ // formatting between Python 2 and Python 3.
+ //
+ // (Though a valid question is: do we really want to make out output
+ // dependent on the Python version?)
+ ScopedPyObjectPtr py_value(PyFloat_FromDouble(value));
+ if (!py_value.get()) {
+ return string();
+ }
- private:
- // Holder for a python float object which we use to allow us to use
- // the Python API for printing doubles. We initialize once and then
- // directly modify it for every float printed to save on allocations
- // and refcounting.
- ScopedPyObjectPtr float_holder_;
+ ScopedPyObjectPtr py_str(PyObject_Str(py_value.get()));
+ if (!py_str.get()) {
+ return string();
+ }
+
+ return string(PyString_AsString(py_str.get()));
+ }
};
static PyObject* ToStr(CMessage* self) {
- google::protobuf::TextFormat::Printer printer;
+ TextFormat::Printer printer;
// Passes ownership
printer.SetDefaultFieldValuePrinter(new PythonFieldValuePrinter());
printer.SetHideUnknownFields(true);
@@ -1291,8 +1825,12 @@ static PyObject* ToStr(CMessage* self) {
PyObject* MergeFrom(CMessage* self, PyObject* arg) {
CMessage* other_message;
- if (!PyObject_TypeCheck(reinterpret_cast<PyObject *>(arg), &CMessage_Type)) {
- PyErr_SetString(PyExc_TypeError, "Must be a message");
+ if (!PyObject_TypeCheck(arg, &CMessage_Type)) {
+ PyErr_Format(PyExc_TypeError,
+ "Parameter to MergeFrom() must be instance of same class: "
+ "expected %s got %s.",
+ self->message->GetDescriptor()->full_name().c_str(),
+ Py_TYPE(arg)->tp_name);
return NULL;
}
@@ -1300,8 +1838,8 @@ PyObject* MergeFrom(CMessage* self, PyObject* arg) {
if (other_message->message->GetDescriptor() !=
self->message->GetDescriptor()) {
PyErr_Format(PyExc_TypeError,
- "Tried to merge from a message with a different type. "
- "to: %s, from: %s",
+ "Parameter to MergeFrom() must be instance of same class: "
+ "expected %s got %s.",
self->message->GetDescriptor()->full_name().c_str(),
other_message->message->GetDescriptor()->full_name().c_str());
return NULL;
@@ -1319,8 +1857,12 @@ PyObject* MergeFrom(CMessage* self, PyObject* arg) {
static PyObject* CopyFrom(CMessage* self, PyObject* arg) {
CMessage* other_message;
- if (!PyObject_TypeCheck(reinterpret_cast<PyObject *>(arg), &CMessage_Type)) {
- PyErr_SetString(PyExc_TypeError, "Must be a message");
+ if (!PyObject_TypeCheck(arg, &CMessage_Type)) {
+ PyErr_Format(PyExc_TypeError,
+ "Parameter to CopyFrom() must be instance of same class: "
+ "expected %s got %s.",
+ self->message->GetDescriptor()->full_name().c_str(),
+ Py_TYPE(arg)->tp_name);
return NULL;
}
@@ -1333,8 +1875,8 @@ static PyObject* CopyFrom(CMessage* self, PyObject* arg) {
if (other_message->message->GetDescriptor() !=
self->message->GetDescriptor()) {
PyErr_Format(PyExc_TypeError,
- "Tried to copy from a message with a different type. "
- "to: %s, from: %s",
+ "Parameter to CopyFrom() must be instance of same class: "
+ "expected %s got %s.",
self->message->GetDescriptor()->full_name().c_str(),
other_message->message->GetDescriptor()->full_name().c_str());
return NULL;
@@ -1344,13 +1886,37 @@ static PyObject* CopyFrom(CMessage* self, PyObject* arg) {
// CopyFrom on the message will not clean up self->composite_fields,
// which can leave us in an inconsistent state, so clear it out here.
- Clear(self);
+ (void)ScopedPyObjectPtr(Clear(self));
self->message->CopyFrom(*other_message->message);
Py_RETURN_NONE;
}
+// Protobuf has a 64MB limit built in, this variable will override this. Please
+// do not enable this unless you fully understand the implications: protobufs
+// must all be kept in memory at the same time, so if they grow too big you may
+// get OOM errors. The protobuf APIs do not provide any tools for processing
+// protobufs in chunks. If you have protos this big you should break them up if
+// it is at all convenient to do so.
+static bool allow_oversize_protos = false;
+
+// Provide a method in the module to set allow_oversize_protos to a boolean
+// value. This method returns the newly value of allow_oversize_protos.
+static PyObject* SetAllowOversizeProtos(PyObject* m, PyObject* arg) {
+ if (!arg || !PyBool_Check(arg)) {
+ PyErr_SetString(PyExc_TypeError,
+ "Argument to SetAllowOversizeProtos must be boolean");
+ return NULL;
+ }
+ allow_oversize_protos = PyObject_IsTrue(arg);
+ if (allow_oversize_protos) {
+ Py_RETURN_TRUE;
+ } else {
+ Py_RETURN_FALSE;
+ }
+}
+
static PyObject* MergeFromString(CMessage* self, PyObject* arg) {
const void* data;
Py_ssize_t data_length;
@@ -1359,9 +1925,13 @@ static PyObject* MergeFromString(CMessage* self, PyObject* arg) {
}
AssureWritable(self);
- google::protobuf::io::CodedInputStream input(
+ io::CodedInputStream input(
reinterpret_cast<const uint8*>(data), data_length);
- input.SetExtensionRegistry(GetDescriptorPool(), global_message_factory);
+ if (allow_oversize_protos) {
+ input.SetTotalBytesLimit(INT_MAX, INT_MAX);
+ }
+ PyDescriptorPool* pool = GetDescriptorPoolForMessage(self);
+ input.SetExtensionRegistry(pool->pool, pool->message_factory);
bool success = self->message->MergePartialFromCodedStream(&input);
if (success) {
return PyInt_FromLong(input.CurrentPosition());
@@ -1372,7 +1942,7 @@ static PyObject* MergeFromString(CMessage* self, PyObject* arg) {
}
static PyObject* ParseFromString(CMessage* self, PyObject* arg) {
- if (Clear(self) == NULL) {
+ if (ScopedPyObjectPtr(Clear(self)) == NULL) {
return NULL;
}
return MergeFromString(self, arg);
@@ -1384,14 +1954,12 @@ static PyObject* ByteSize(CMessage* self, PyObject* args) {
static PyObject* RegisterExtension(PyObject* cls,
PyObject* extension_handle) {
- ScopedPyObjectPtr message_descriptor(PyObject_GetAttr(cls, kDESCRIPTOR));
- if (message_descriptor == NULL) {
- return NULL;
- }
- if (PyObject_SetAttrString(extension_handle, "containing_type",
- message_descriptor) < 0) {
+ const FieldDescriptor* descriptor =
+ GetExtensionDescriptor(extension_handle);
+ if (descriptor == NULL) {
return NULL;
}
+
ScopedPyObjectPtr extensions_by_name(
PyObject_GetAttr(cls, k_extensions_by_name));
if (extensions_by_name == NULL) {
@@ -1402,7 +1970,23 @@ static PyObject* RegisterExtension(PyObject* cls,
if (full_name == NULL) {
return NULL;
}
- if (PyDict_SetItem(extensions_by_name, full_name, extension_handle) < 0) {
+
+ // If the extension was already registered, check that it is the same.
+ PyObject* existing_extension =
+ PyDict_GetItem(extensions_by_name.get(), full_name.get());
+ if (existing_extension != NULL) {
+ const FieldDescriptor* existing_extension_descriptor =
+ GetExtensionDescriptor(existing_extension);
+ if (existing_extension_descriptor != descriptor) {
+ PyErr_SetString(PyExc_ValueError, "Double registration of Extensions");
+ return NULL;
+ }
+ // Nothing else to do.
+ Py_RETURN_NONE;
+ }
+
+ if (PyDict_SetItem(extensions_by_name.get(), full_name.get(),
+ extension_handle) < 0) {
return NULL;
}
@@ -1413,36 +1997,52 @@ static PyObject* RegisterExtension(PyObject* cls,
PyErr_SetString(PyExc_TypeError, "no extensions_by_number on class");
return NULL;
}
+
ScopedPyObjectPtr number(PyObject_GetAttrString(extension_handle, "number"));
if (number == NULL) {
return NULL;
}
- if (PyDict_SetItem(extensions_by_number, number, extension_handle) < 0) {
- return NULL;
- }
- CFieldDescriptor* cdescriptor =
- extension_dict::InternalGetCDescriptorFromExtension(extension_handle);
- ScopedPyObjectPtr py_cdescriptor(reinterpret_cast<PyObject*>(cdescriptor));
- if (cdescriptor == NULL) {
+ // If the extension was already registered by number, check that it is the
+ // same.
+ existing_extension = PyDict_GetItem(extensions_by_number.get(), number.get());
+ if (existing_extension != NULL) {
+ const FieldDescriptor* existing_extension_descriptor =
+ GetExtensionDescriptor(existing_extension);
+ if (existing_extension_descriptor != descriptor) {
+ const Descriptor* msg_desc = GetMessageDescriptor(
+ reinterpret_cast<PyTypeObject*>(cls));
+ PyErr_Format(
+ PyExc_ValueError,
+ "Extensions \"%s\" and \"%s\" both try to extend message type "
+ "\"%s\" with field number %ld.",
+ existing_extension_descriptor->full_name().c_str(),
+ descriptor->full_name().c_str(),
+ msg_desc->full_name().c_str(),
+ PyInt_AsLong(number.get()));
+ return NULL;
+ }
+ // Nothing else to do.
+ Py_RETURN_NONE;
+ }
+ if (PyDict_SetItem(extensions_by_number.get(), number.get(),
+ extension_handle) < 0) {
return NULL;
}
- Py_INCREF(extension_handle);
- cdescriptor->descriptor_field = extension_handle;
- const google::protobuf::FieldDescriptor* descriptor = cdescriptor->descriptor;
+
// Check if it's a message set
if (descriptor->is_extension() &&
descriptor->containing_type()->options().message_set_wire_format() &&
- descriptor->type() == google::protobuf::FieldDescriptor::TYPE_MESSAGE &&
- descriptor->message_type() == descriptor->extension_scope() &&
- descriptor->label() == google::protobuf::FieldDescriptor::LABEL_OPTIONAL) {
+ descriptor->type() == FieldDescriptor::TYPE_MESSAGE &&
+ descriptor->label() == FieldDescriptor::LABEL_OPTIONAL) {
ScopedPyObjectPtr message_name(PyString_FromStringAndSize(
descriptor->message_type()->full_name().c_str(),
descriptor->message_type()->full_name().size()));
if (message_name == NULL) {
return NULL;
}
- PyDict_SetItem(extensions_by_name, message_name, extension_handle);
+ PyDict_SetItem(extensions_by_name.get(), message_name.get(),
+ extension_handle);
}
Py_RETURN_NONE;
@@ -1454,53 +2054,38 @@ static PyObject* SetInParent(CMessage* self, PyObject* args) {
}
static PyObject* WhichOneof(CMessage* self, PyObject* arg) {
- char* oneof_name;
- if (!PyString_Check(arg)) {
- PyErr_SetString(PyExc_TypeError, "field name must be a string");
+ Py_ssize_t name_size;
+ char *name_data;
+ if (PyString_AsStringAndSize(arg, &name_data, &name_size) < 0)
return NULL;
- }
- oneof_name = PyString_AsString(arg);
- if (oneof_name == NULL) {
- return NULL;
- }
- const google::protobuf::OneofDescriptor* oneof_desc =
+ string oneof_name = string(name_data, name_size);
+ const OneofDescriptor* oneof_desc =
self->message->GetDescriptor()->FindOneofByName(oneof_name);
if (oneof_desc == NULL) {
PyErr_Format(PyExc_ValueError,
- "Protocol message has no oneof \"%s\" field.", oneof_name);
+ "Protocol message has no oneof \"%s\" field.",
+ oneof_name.c_str());
return NULL;
}
- const google::protobuf::FieldDescriptor* field_in_oneof =
+ const FieldDescriptor* field_in_oneof =
self->message->GetReflection()->GetOneofFieldDescriptor(
*self->message, oneof_desc);
if (field_in_oneof == NULL) {
Py_RETURN_NONE;
} else {
- return PyString_FromString(field_in_oneof->name().c_str());
+ const string& name = field_in_oneof->name();
+ return PyString_FromStringAndSize(name.c_str(), name.size());
}
}
+static PyObject* GetExtensionDict(CMessage* self, void *closure);
+
static PyObject* ListFields(CMessage* self) {
- vector<const google::protobuf::FieldDescriptor*> fields;
+ vector<const FieldDescriptor*> fields;
self->message->GetReflection()->ListFields(*self->message, &fields);
- PyObject* descriptor = PyDict_GetItem(Py_TYPE(self)->tp_dict, kDESCRIPTOR);
- if (descriptor == NULL) {
- return NULL;
- }
- ScopedPyObjectPtr fields_by_name(
- PyObject_GetAttr(descriptor, kfields_by_name));
- if (fields_by_name == NULL) {
- return NULL;
- }
- ScopedPyObjectPtr extensions_by_name(PyObject_GetAttr(
- reinterpret_cast<PyObject*>(Py_TYPE(self)), k_extensions_by_name));
- if (extensions_by_name == NULL) {
- PyErr_SetString(PyExc_ValueError, "no extensionsbyname");
- return NULL;
- }
// Normally, the list will be exactly the size of the fields.
- PyObject* all_fields = PyList_New(fields.size());
+ ScopedPyObjectPtr all_fields(PyList_New(fields.size()));
if (all_fields == NULL) {
return NULL;
}
@@ -1509,75 +2094,84 @@ static PyObject* ListFields(CMessage* self) {
// the field information. Thus the actual size of the py list will be
// smaller than the size of fields. Set the actual size at the end.
Py_ssize_t actual_size = 0;
- for (Py_ssize_t i = 0; i < fields.size(); ++i) {
+ for (size_t i = 0; i < fields.size(); ++i) {
ScopedPyObjectPtr t(PyTuple_New(2));
if (t == NULL) {
- Py_DECREF(all_fields);
return NULL;
}
if (fields[i]->is_extension()) {
- const string& field_name = fields[i]->full_name();
- PyObject* extension_field = PyDict_GetItemString(extensions_by_name,
- field_name.c_str());
+ ScopedPyObjectPtr extension_field(
+ PyFieldDescriptor_FromDescriptor(fields[i]));
if (extension_field == NULL) {
- // If we couldn't fetch extension_field, it means the module that
- // defines this extension has not been explicitly imported in Python
- // code, and the extension hasn't been registered. There's nothing much
- // we can do about this, so just skip it in the output to match the
- // behavior of the python implementation.
+ return NULL;
+ }
+ // With C++ descriptors, the field can always be retrieved, but for
+ // unknown extensions which have not been imported in Python code, there
+ // is no message class and we cannot retrieve the value.
+ // TODO(amauryfa): consider building the class on the fly!
+ if (fields[i]->message_type() != NULL &&
+ cdescriptor_pool::GetMessageClass(
+ GetDescriptorPoolForMessage(self),
+ fields[i]->message_type()) == NULL) {
+ PyErr_Clear();
continue;
}
- PyObject* extensions = reinterpret_cast<PyObject*>(self->extensions);
+ ScopedPyObjectPtr extensions(GetExtensionDict(self, NULL));
if (extensions == NULL) {
- Py_DECREF(all_fields);
return NULL;
}
// 'extension' reference later stolen by PyTuple_SET_ITEM.
- PyObject* extension = PyObject_GetItem(extensions, extension_field);
+ PyObject* extension = PyObject_GetItem(
+ extensions.get(), extension_field.get());
if (extension == NULL) {
- Py_DECREF(all_fields);
return NULL;
}
- Py_INCREF(extension_field);
- PyTuple_SET_ITEM(t.get(), 0, extension_field);
+ PyTuple_SET_ITEM(t.get(), 0, extension_field.release());
// Steals reference to 'extension'
PyTuple_SET_ITEM(t.get(), 1, extension);
} else {
+ // Normal field
const string& field_name = fields[i]->name();
ScopedPyObjectPtr py_field_name(PyString_FromStringAndSize(
field_name.c_str(), field_name.length()));
if (py_field_name == NULL) {
PyErr_SetString(PyExc_ValueError, "bad string");
- Py_DECREF(all_fields);
return NULL;
}
- PyObject* field_descriptor =
- PyDict_GetItem(fields_by_name, py_field_name);
+ ScopedPyObjectPtr field_descriptor(
+ PyFieldDescriptor_FromDescriptor(fields[i]));
if (field_descriptor == NULL) {
- Py_DECREF(all_fields);
return NULL;
}
- PyObject* field_value = GetAttr(self, py_field_name);
+ PyObject* field_value = GetAttr(self, py_field_name.get());
if (field_value == NULL) {
- PyErr_SetObject(PyExc_ValueError, py_field_name);
- Py_DECREF(all_fields);
+ PyErr_SetObject(PyExc_ValueError, py_field_name.get());
return NULL;
}
- Py_INCREF(field_descriptor);
- PyTuple_SET_ITEM(t.get(), 0, field_descriptor);
+ PyTuple_SET_ITEM(t.get(), 0, field_descriptor.release());
PyTuple_SET_ITEM(t.get(), 1, field_value);
}
- PyList_SET_ITEM(all_fields, actual_size, t.release());
+ PyList_SET_ITEM(all_fields.get(), actual_size, t.release());
++actual_size;
}
- Py_SIZE(all_fields) = actual_size;
- return all_fields;
+ if (static_cast<size_t>(actual_size) != fields.size() &&
+ (PyList_SetSlice(all_fields.get(), actual_size, fields.size(), NULL) <
+ 0)) {
+ return NULL;
+ }
+ return all_fields.release();
+}
+
+static PyObject* DiscardUnknownFields(CMessage* self) {
+ AssureWritable(self);
+ self->message->DiscardUnknownFields();
+ Py_RETURN_NONE;
}
PyObject* FindInitializationErrors(CMessage* self) {
- google::protobuf::Message* message = self->message;
+ Message* message = self->message;
vector<string> errors;
message->FindInitializationErrors(&errors);
@@ -1585,7 +2179,7 @@ PyObject* FindInitializationErrors(CMessage* self) {
if (error_list == NULL) {
return NULL;
}
- for (Py_ssize_t i = 0; i < errors.size(); ++i) {
+ for (size_t i = 0; i < errors.size(); ++i) {
const string& error = errors[i];
PyObject* error_string = PyString_FromStringAndSize(
error.c_str(), error.length());
@@ -1599,94 +2193,105 @@ PyObject* FindInitializationErrors(CMessage* self) {
}
static PyObject* RichCompare(CMessage* self, PyObject* other, int opid) {
- if (!PyObject_TypeCheck(other, &CMessage_Type)) {
- if (opid == Py_EQ) {
- Py_RETURN_FALSE;
- } else if (opid == Py_NE) {
- Py_RETURN_TRUE;
- }
- }
- if (opid == Py_EQ || opid == Py_NE) {
- ScopedPyObjectPtr self_fields(ListFields(self));
- ScopedPyObjectPtr other_fields(ListFields(
- reinterpret_cast<CMessage*>(other)));
- return PyObject_RichCompare(self_fields, other_fields, opid);
- } else {
+ // Only equality comparisons are implemented.
+ if (opid != Py_EQ && opid != Py_NE) {
Py_INCREF(Py_NotImplemented);
return Py_NotImplemented;
}
+ bool equals = true;
+ // If other is not a message, it cannot be equal.
+ if (!PyObject_TypeCheck(other, &CMessage_Type)) {
+ equals = false;
+ }
+ const google::protobuf::Message* other_message =
+ reinterpret_cast<CMessage*>(other)->message;
+ // If messages don't have the same descriptors, they are not equal.
+ if (equals &&
+ self->message->GetDescriptor() != other_message->GetDescriptor()) {
+ equals = false;
+ }
+ // Check the message contents.
+ if (equals && !google::protobuf::util::MessageDifferencer::Equals(
+ *self->message,
+ *reinterpret_cast<CMessage*>(other)->message)) {
+ equals = false;
+ }
+ if (equals ^ (opid == Py_EQ)) {
+ Py_RETURN_FALSE;
+ } else {
+ Py_RETURN_TRUE;
+ }
}
-PyObject* InternalGetScalar(
- CMessage* self,
- const google::protobuf::FieldDescriptor* field_descriptor) {
- google::protobuf::Message* message = self->message;
- const google::protobuf::Reflection* reflection = message->GetReflection();
+PyObject* InternalGetScalar(const Message* message,
+ const FieldDescriptor* field_descriptor) {
+ const Reflection* reflection = message->GetReflection();
- if (!FIELD_BELONGS_TO_MESSAGE(field_descriptor, message)) {
- PyErr_SetString(
- PyExc_KeyError, "Field does not belong to message!");
+ if (!CheckFieldBelongsToMessage(field_descriptor, message)) {
return NULL;
}
PyObject* result = NULL;
switch (field_descriptor->cpp_type()) {
- case google::protobuf::FieldDescriptor::CPPTYPE_INT32: {
+ case FieldDescriptor::CPPTYPE_INT32: {
int32 value = reflection->GetInt32(*message, field_descriptor);
result = PyInt_FromLong(value);
break;
}
- case google::protobuf::FieldDescriptor::CPPTYPE_INT64: {
+ case FieldDescriptor::CPPTYPE_INT64: {
int64 value = reflection->GetInt64(*message, field_descriptor);
result = PyLong_FromLongLong(value);
break;
}
- case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: {
+ case FieldDescriptor::CPPTYPE_UINT32: {
uint32 value = reflection->GetUInt32(*message, field_descriptor);
result = PyInt_FromSize_t(value);
break;
}
- case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: {
+ case FieldDescriptor::CPPTYPE_UINT64: {
uint64 value = reflection->GetUInt64(*message, field_descriptor);
result = PyLong_FromUnsignedLongLong(value);
break;
}
- case google::protobuf::FieldDescriptor::CPPTYPE_FLOAT: {
+ case FieldDescriptor::CPPTYPE_FLOAT: {
float value = reflection->GetFloat(*message, field_descriptor);
result = PyFloat_FromDouble(value);
break;
}
- case google::protobuf::FieldDescriptor::CPPTYPE_DOUBLE: {
+ case FieldDescriptor::CPPTYPE_DOUBLE: {
double value = reflection->GetDouble(*message, field_descriptor);
result = PyFloat_FromDouble(value);
break;
}
- case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: {
+ case FieldDescriptor::CPPTYPE_BOOL: {
bool value = reflection->GetBool(*message, field_descriptor);
result = PyBool_FromLong(value);
break;
}
- case google::protobuf::FieldDescriptor::CPPTYPE_STRING: {
+ case FieldDescriptor::CPPTYPE_STRING: {
string value = reflection->GetString(*message, field_descriptor);
result = ToStringObject(field_descriptor, value);
break;
}
- case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: {
- if (!message->GetReflection()->HasField(*message, field_descriptor)) {
+ case FieldDescriptor::CPPTYPE_ENUM: {
+ if (!message->GetReflection()->SupportsUnknownEnumValues() &&
+ !message->GetReflection()->HasField(*message, field_descriptor)) {
// Look for the value in the unknown fields.
- google::protobuf::UnknownFieldSet* unknown_field_set =
- message->GetReflection()->MutableUnknownFields(message);
- for (int i = 0; i < unknown_field_set->field_count(); ++i) {
- if (unknown_field_set->field(i).number() ==
- field_descriptor->number()) {
- result = PyInt_FromLong(unknown_field_set->field(i).varint());
+ const UnknownFieldSet& unknown_field_set =
+ message->GetReflection()->GetUnknownFields(*message);
+ for (int i = 0; i < unknown_field_set.field_count(); ++i) {
+ if (unknown_field_set.field(i).number() ==
+ field_descriptor->number() &&
+ unknown_field_set.field(i).type() ==
+ google::protobuf::UnknownField::TYPE_VARINT) {
+ result = PyInt_FromLong(unknown_field_set.field(i).varint());
break;
}
}
}
if (result == NULL) {
- const google::protobuf::EnumValueDescriptor* enum_value =
+ const EnumValueDescriptor* enum_value =
message->GetReflection()->GetEnum(*message, field_descriptor);
result = PyInt_FromLong(enum_value->number());
}
@@ -1701,116 +2306,100 @@ PyObject* InternalGetScalar(
return result;
}
-PyObject* InternalGetSubMessage(CMessage* self,
- CFieldDescriptor* cfield_descriptor) {
- PyObject* field = cfield_descriptor->descriptor_field;
- ScopedPyObjectPtr message_type(PyObject_GetAttr(field, kmessage_type));
- if (message_type == NULL) {
- return NULL;
- }
- ScopedPyObjectPtr concrete_class(
- PyObject_GetAttr(message_type, k_concrete_class));
- if (concrete_class == NULL) {
+PyObject* InternalGetSubMessage(
+ CMessage* self, const FieldDescriptor* field_descriptor) {
+ const Reflection* reflection = self->message->GetReflection();
+ PyDescriptorPool* pool = GetDescriptorPoolForMessage(self);
+ const Message& sub_message = reflection->GetMessage(
+ *self->message, field_descriptor, pool->message_factory);
+
+ CMessageClass* message_class = cdescriptor_pool::GetMessageClass(
+ pool, field_descriptor->message_type());
+ if (message_class == NULL) {
return NULL;
}
- PyObject* py_cmsg = cmessage::NewEmpty(concrete_class);
- if (py_cmsg == NULL) {
+
+ CMessage* cmsg = cmessage::NewEmptyMessage(message_class);
+ if (cmsg == NULL) {
return NULL;
}
- if (!PyObject_TypeCheck(py_cmsg, &CMessage_Type)) {
- PyErr_SetString(PyExc_TypeError, "Not a CMessage!");
- }
- CMessage* cmsg = reinterpret_cast<CMessage*>(py_cmsg);
- const google::protobuf::FieldDescriptor* field_descriptor =
- cfield_descriptor->descriptor;
- const google::protobuf::Reflection* reflection = self->message->GetReflection();
- const google::protobuf::Message& sub_message = reflection->GetMessage(
- *self->message, field_descriptor, global_message_factory);
cmsg->owner = self->owner;
cmsg->parent = self;
- cmsg->parent_field = cfield_descriptor;
+ cmsg->parent_field_descriptor = field_descriptor;
cmsg->read_only = !reflection->HasField(*self->message, field_descriptor);
- cmsg->message = const_cast<google::protobuf::Message*>(&sub_message);
+ cmsg->message = const_cast<Message*>(&sub_message);
- if (InitAttributes(cmsg, NULL, NULL) < 0) {
- Py_DECREF(py_cmsg);
- return NULL;
- }
- return py_cmsg;
+ return reinterpret_cast<PyObject*>(cmsg);
}
-int InternalSetScalar(
- CMessage* self,
- const google::protobuf::FieldDescriptor* field_descriptor,
+int InternalSetNonOneofScalar(
+ Message* message,
+ const FieldDescriptor* field_descriptor,
PyObject* arg) {
- google::protobuf::Message* message = self->message;
- const google::protobuf::Reflection* reflection = message->GetReflection();
-
- if (!FIELD_BELONGS_TO_MESSAGE(field_descriptor, message)) {
- PyErr_SetString(
- PyExc_KeyError, "Field does not belong to message!");
- return -1;
- }
+ const Reflection* reflection = message->GetReflection();
- if (MaybeReleaseOverlappingOneofField(self, field_descriptor) < 0) {
+ if (!CheckFieldBelongsToMessage(field_descriptor, message)) {
return -1;
}
switch (field_descriptor->cpp_type()) {
- case google::protobuf::FieldDescriptor::CPPTYPE_INT32: {
+ case FieldDescriptor::CPPTYPE_INT32: {
GOOGLE_CHECK_GET_INT32(arg, value, -1);
reflection->SetInt32(message, field_descriptor, value);
break;
}
- case google::protobuf::FieldDescriptor::CPPTYPE_INT64: {
+ case FieldDescriptor::CPPTYPE_INT64: {
GOOGLE_CHECK_GET_INT64(arg, value, -1);
reflection->SetInt64(message, field_descriptor, value);
break;
}
- case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: {
+ case FieldDescriptor::CPPTYPE_UINT32: {
GOOGLE_CHECK_GET_UINT32(arg, value, -1);
reflection->SetUInt32(message, field_descriptor, value);
break;
}
- case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: {
+ case FieldDescriptor::CPPTYPE_UINT64: {
GOOGLE_CHECK_GET_UINT64(arg, value, -1);
reflection->SetUInt64(message, field_descriptor, value);
break;
}
- case google::protobuf::FieldDescriptor::CPPTYPE_FLOAT: {
+ case FieldDescriptor::CPPTYPE_FLOAT: {
GOOGLE_CHECK_GET_FLOAT(arg, value, -1);
reflection->SetFloat(message, field_descriptor, value);
break;
}
- case google::protobuf::FieldDescriptor::CPPTYPE_DOUBLE: {
+ case FieldDescriptor::CPPTYPE_DOUBLE: {
GOOGLE_CHECK_GET_DOUBLE(arg, value, -1);
reflection->SetDouble(message, field_descriptor, value);
break;
}
- case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: {
+ case FieldDescriptor::CPPTYPE_BOOL: {
GOOGLE_CHECK_GET_BOOL(arg, value, -1);
reflection->SetBool(message, field_descriptor, value);
break;
}
- case google::protobuf::FieldDescriptor::CPPTYPE_STRING: {
+ case FieldDescriptor::CPPTYPE_STRING: {
if (!CheckAndSetString(
arg, message, field_descriptor, reflection, false, -1)) {
return -1;
}
break;
}
- case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: {
+ case FieldDescriptor::CPPTYPE_ENUM: {
GOOGLE_CHECK_GET_INT32(arg, value, -1);
- const google::protobuf::EnumDescriptor* enum_descriptor =
- field_descriptor->enum_type();
- const google::protobuf::EnumValueDescriptor* enum_value =
- enum_descriptor->FindValueByNumber(value);
- if (enum_value != NULL) {
- reflection->SetEnum(message, field_descriptor, enum_value);
+ if (reflection->SupportsUnknownEnumValues()) {
+ reflection->SetEnumValue(message, field_descriptor, value);
} else {
- PyErr_Format(PyExc_ValueError, "Unknown enum value: %d", value);
- return -1;
+ const EnumDescriptor* enum_descriptor = field_descriptor->enum_type();
+ const EnumValueDescriptor* enum_value =
+ enum_descriptor->FindValueByNumber(value);
+ if (enum_value != NULL) {
+ reflection->SetEnum(message, field_descriptor, enum_value);
+ } else {
+ PyErr_Format(PyExc_ValueError, "Unknown enum value: %d", value);
+ return -1;
+ }
}
break;
}
@@ -1824,6 +2413,21 @@ int InternalSetScalar(
return 0;
}
+int InternalSetScalar(
+ CMessage* self,
+ const FieldDescriptor* field_descriptor,
+ PyObject* arg) {
+ if (!CheckFieldBelongsToMessage(field_descriptor, self->message)) {
+ return -1;
+ }
+
+ if (MaybeReleaseOverlappingOneofField(self, field_descriptor) < 0) {
+ return -1;
+ }
+
+ return InternalSetNonOneofScalar(self->message, field_descriptor, arg);
+}
+
PyObject* FromString(PyTypeObject* cls, PyObject* serialized) {
PyObject* py_cmsg = PyObject_CallObject(
reinterpret_cast<PyObject*>(cls), NULL);
@@ -1838,197 +2442,9 @@ PyObject* FromString(PyTypeObject* cls, PyObject* serialized) {
return NULL;
}
- if (InitAttributes(cmsg, NULL, NULL) < 0) {
- Py_DECREF(py_cmsg);
- return NULL;
- }
return py_cmsg;
}
-static PyObject* AddDescriptors(PyTypeObject* cls,
- PyObject* descriptor) {
- if (PyObject_SetAttr(reinterpret_cast<PyObject*>(cls),
- k_extensions_by_name, PyDict_New()) < 0) {
- return NULL;
- }
- if (PyObject_SetAttr(reinterpret_cast<PyObject*>(cls),
- k_extensions_by_number, PyDict_New()) < 0) {
- return NULL;
- }
-
- ScopedPyObjectPtr field_descriptors(PyDict_New());
-
- ScopedPyObjectPtr fields(PyObject_GetAttrString(descriptor, "fields"));
- if (fields == NULL) {
- return NULL;
- }
-
- ScopedPyObjectPtr _NUMBER_string(PyString_FromString("_FIELD_NUMBER"));
- if (_NUMBER_string == NULL) {
- return NULL;
- }
-
- const Py_ssize_t fields_size = PyList_GET_SIZE(fields.get());
- for (int i = 0; i < fields_size; ++i) {
- PyObject* field = PyList_GET_ITEM(fields.get(), i);
- ScopedPyObjectPtr field_name(PyObject_GetAttr(field, kname));
- ScopedPyObjectPtr full_field_name(PyObject_GetAttr(field, kfull_name));
- if (field_name == NULL || full_field_name == NULL) {
- PyErr_SetString(PyExc_TypeError, "Name is null");
- return NULL;
- }
-
- PyObject* field_descriptor =
- cdescriptor_pool::FindFieldByName(descriptor_pool, full_field_name);
- if (field_descriptor == NULL) {
- PyErr_SetString(PyExc_TypeError, "Couldn't find field");
- return NULL;
- }
- Py_INCREF(field);
- CFieldDescriptor* cfield_descriptor = reinterpret_cast<CFieldDescriptor*>(
- field_descriptor);
- cfield_descriptor->descriptor_field = field;
- if (PyDict_SetItem(field_descriptors, field_name, field_descriptor) < 0) {
- return NULL;
- }
-
- // The FieldDescriptor's name field might either be of type bytes or
- // of type unicode, depending on whether the FieldDescriptor was
- // parsed from a serialized message or read from the
- // <message>_pb2.py module.
- ScopedPyObjectPtr field_name_upcased(
- PyObject_CallMethod(field_name, "upper", NULL));
- if (field_name_upcased == NULL) {
- return NULL;
- }
-
- ScopedPyObjectPtr field_number_name(PyObject_CallMethod(
- field_name_upcased, "__add__", "(O)", _NUMBER_string.get()));
- if (field_number_name == NULL) {
- return NULL;
- }
-
- ScopedPyObjectPtr number(PyInt_FromLong(
- cfield_descriptor->descriptor->number()));
- if (number == NULL) {
- return NULL;
- }
- if (PyObject_SetAttr(reinterpret_cast<PyObject*>(cls),
- field_number_name, number) == -1) {
- return NULL;
- }
- }
-
- PyDict_SetItem(cls->tp_dict, k__descriptors, field_descriptors);
-
- // Enum Values
- ScopedPyObjectPtr enum_types(PyObject_GetAttrString(descriptor,
- "enum_types"));
- if (enum_types == NULL) {
- return NULL;
- }
- ScopedPyObjectPtr type_iter(PyObject_GetIter(enum_types));
- if (type_iter == NULL) {
- return NULL;
- }
- ScopedPyObjectPtr enum_type;
- while ((enum_type.reset(PyIter_Next(type_iter))) != NULL) {
- ScopedPyObjectPtr wrapped(PyObject_CallFunctionObjArgs(
- EnumTypeWrapper_class, enum_type.get(), NULL));
- if (wrapped == NULL) {
- return NULL;
- }
- ScopedPyObjectPtr enum_name(PyObject_GetAttr(enum_type, kname));
- if (enum_name == NULL) {
- return NULL;
- }
- if (PyObject_SetAttr(reinterpret_cast<PyObject*>(cls),
- enum_name, wrapped) == -1) {
- return NULL;
- }
-
- ScopedPyObjectPtr enum_values(PyObject_GetAttrString(enum_type, "values"));
- if (enum_values == NULL) {
- return NULL;
- }
- ScopedPyObjectPtr values_iter(PyObject_GetIter(enum_values));
- if (values_iter == NULL) {
- return NULL;
- }
- ScopedPyObjectPtr enum_value;
- while ((enum_value.reset(PyIter_Next(values_iter))) != NULL) {
- ScopedPyObjectPtr value_name(PyObject_GetAttr(enum_value, kname));
- if (value_name == NULL) {
- return NULL;
- }
- ScopedPyObjectPtr value_number(PyObject_GetAttrString(enum_value,
- "number"));
- if (value_number == NULL) {
- return NULL;
- }
- if (PyObject_SetAttr(reinterpret_cast<PyObject*>(cls),
- value_name, value_number) == -1) {
- return NULL;
- }
- }
- if (PyErr_Occurred()) { // If PyIter_Next failed
- return NULL;
- }
- }
- if (PyErr_Occurred()) { // If PyIter_Next failed
- return NULL;
- }
-
- ScopedPyObjectPtr extension_dict(
- PyObject_GetAttr(descriptor, kextensions_by_name));
- if (extension_dict == NULL || !PyDict_Check(extension_dict)) {
- PyErr_SetString(PyExc_TypeError, "extensions_by_name not a dict");
- return NULL;
- }
- Py_ssize_t pos = 0;
- PyObject* extension_name;
- PyObject* extension_field;
-
- while (PyDict_Next(extension_dict, &pos, &extension_name, &extension_field)) {
- if (PyObject_SetAttr(reinterpret_cast<PyObject*>(cls),
- extension_name, extension_field) == -1) {
- return NULL;
- }
- ScopedPyObjectPtr py_cfield_descriptor(
- PyObject_GetAttrString(extension_field, "_cdescriptor"));
- if (py_cfield_descriptor == NULL) {
- return NULL;
- }
- CFieldDescriptor* cfield_descriptor =
- reinterpret_cast<CFieldDescriptor*>(py_cfield_descriptor.get());
- Py_INCREF(extension_field);
- cfield_descriptor->descriptor_field = extension_field;
-
- ScopedPyObjectPtr field_name_upcased(
- PyObject_CallMethod(extension_name, "upper", NULL));
- if (field_name_upcased == NULL) {
- return NULL;
- }
- ScopedPyObjectPtr field_number_name(PyObject_CallMethod(
- field_name_upcased, "__add__", "(O)", _NUMBER_string.get()));
- if (field_number_name == NULL) {
- return NULL;
- }
- ScopedPyObjectPtr number(PyInt_FromLong(
- cfield_descriptor->descriptor->number()));
- if (number == NULL) {
- return NULL;
- }
- if (PyObject_SetAttr(reinterpret_cast<PyObject*>(cls),
- field_number_name, PyInt_FromLong(
- cfield_descriptor->descriptor->number())) == -1) {
- return NULL;
- }
- }
-
- Py_RETURN_NONE;
-}
-
PyObject* DeepCopy(CMessage* self, PyObject* arg) {
PyObject* clone = PyObject_CallObject(
reinterpret_cast<PyObject*>(Py_TYPE(self)), NULL);
@@ -2039,12 +2455,9 @@ PyObject* DeepCopy(CMessage* self, PyObject* arg) {
Py_DECREF(clone);
return NULL;
}
- if (InitAttributes(reinterpret_cast<CMessage*>(clone), NULL, NULL) < 0) {
- Py_DECREF(clone);
- return NULL;
- }
- if (MergeFrom(reinterpret_cast<CMessage*>(clone),
- reinterpret_cast<PyObject*>(self)) == NULL) {
+ if (ScopedPyObjectPtr(MergeFrom(
+ reinterpret_cast<CMessage*>(clone),
+ reinterpret_cast<PyObject*>(self))) == NULL) {
Py_DECREF(clone);
return NULL;
}
@@ -2063,16 +2476,16 @@ PyObject* ToUnicode(CMessage* self) {
return NULL;
}
Py_INCREF(Py_True);
- ScopedPyObjectPtr encoded(PyObject_CallMethodObjArgs(text_format, method_name,
- self, Py_True, NULL));
+ ScopedPyObjectPtr encoded(PyObject_CallMethodObjArgs(
+ text_format.get(), method_name.get(), self, Py_True, NULL));
Py_DECREF(Py_True);
if (encoded == NULL) {
return NULL;
}
#if PY_MAJOR_VERSION < 3
- PyObject* decoded = PyString_AsDecodedObject(encoded, "utf-8", NULL);
+ PyObject* decoded = PyString_AsDecodedObject(encoded.get(), "utf-8", NULL);
#else
- PyObject* decoded = PyUnicode_FromEncodedObject(encoded, "utf-8", NULL);
+ PyObject* decoded = PyUnicode_FromEncodedObject(encoded.get(), "utf-8", NULL);
#endif
if (decoded == NULL) {
return NULL;
@@ -2095,7 +2508,7 @@ PyObject* Reduce(CMessage* self) {
if (serialized == NULL) {
return NULL;
}
- if (PyDict_SetItemString(state, "serialized", serialized) < 0) {
+ if (PyDict_SetItemString(state.get(), "serialized", serialized.get()) < 0) {
return NULL;
}
return Py_BuildValue("OOO", constructor.get(), args.get(), state.get());
@@ -2110,24 +2523,49 @@ PyObject* SetState(CMessage* self, PyObject* state) {
if (serialized == NULL) {
return NULL;
}
- if (ParseFromString(self, serialized) == NULL) {
+ if (ScopedPyObjectPtr(ParseFromString(self, serialized)) == NULL) {
return NULL;
}
Py_RETURN_NONE;
}
// CMessage static methods:
-PyObject* _GetFieldDescriptor(PyObject* unused, PyObject* arg) {
- return cdescriptor_pool::FindFieldByName(descriptor_pool, arg);
+PyObject* _CheckCalledFromGeneratedFile(PyObject* unused,
+ PyObject* unused_arg) {
+ if (!_CalledFromGeneratedFile(1)) {
+ PyErr_SetString(PyExc_TypeError,
+ "Descriptors should not be created directly, "
+ "but only retrieved from their parent.");
+ return NULL;
+ }
+ Py_RETURN_NONE;
}
-PyObject* _GetExtensionDescriptor(PyObject* unused, PyObject* arg) {
- return cdescriptor_pool::FindExtensionByName(descriptor_pool, arg);
+static PyObject* GetExtensionDict(CMessage* self, void *closure) {
+ if (self->extensions) {
+ Py_INCREF(self->extensions);
+ return reinterpret_cast<PyObject*>(self->extensions);
+ }
+
+ // If there are extension_ranges, the message is "extendable". Allocate a
+ // dictionary to store the extension fields.
+ const Descriptor* descriptor = GetMessageDescriptor(Py_TYPE(self));
+ if (descriptor->extension_range_count() > 0) {
+ ExtensionDict* extension_dict = extension_dict::NewExtensionDict(self);
+ if (extension_dict == NULL) {
+ return NULL;
+ }
+ self->extensions = extension_dict;
+ Py_INCREF(self->extensions);
+ return reinterpret_cast<PyObject*>(self->extensions);
+ }
+
+ PyErr_SetNone(PyExc_AttributeError);
+ return NULL;
}
-static PyMemberDef Members[] = {
- {"Extensions", T_OBJECT_EX, offsetof(CMessage, extensions), 0,
- "Extension dict"},
+static PyGetSetDef Getters[] = {
+ {"Extensions", (getter)GetExtensionDict, NULL, "Extension dict"},
{NULL}
};
@@ -2140,8 +2578,6 @@ static PyMethodDef Methods[] = {
"Inputs picklable representation of the message." },
{ "__unicode__", (PyCFunction)ToUnicode, METH_NOARGS,
"Outputs a unicode representation of the message." },
- { "AddDescriptors", (PyCFunction)AddDescriptors, METH_O | METH_CLASS,
- "Adds field descriptors to the class" },
{ "ByteSize", (PyCFunction)ByteSize, METH_NOARGS,
"Returns the size of the message in bytes." },
{ "Clear", (PyCFunction)Clear, METH_NOARGS,
@@ -2152,6 +2588,8 @@ static PyMethodDef Methods[] = {
"Clears a message field." },
{ "CopyFrom", (PyCFunction)CopyFrom, METH_O,
"Copies a protocol message into the current message." },
+ { "DiscardUnknownFields", (PyCFunction)DiscardUnknownFields, METH_NOARGS,
+ "Discards the unknown fields." },
{ "FindInitializationErrors", (PyCFunction)FindInitializationErrors,
METH_NOARGS,
"Finds unset required fields." },
@@ -2185,113 +2623,117 @@ static PyMethodDef Methods[] = {
"or None if no field is set." },
// Static Methods.
- { "_BuildFile", (PyCFunction)Python_BuildFile, METH_O | METH_STATIC,
- "Registers a new protocol buffer file in the global C++ descriptor pool." },
- { "_GetFieldDescriptor", (PyCFunction)_GetFieldDescriptor,
- METH_O | METH_STATIC, "Finds a field descriptor in the message pool." },
- { "_GetExtensionDescriptor", (PyCFunction)_GetExtensionDescriptor,
- METH_O | METH_STATIC,
- "Finds a extension descriptor in the message pool." },
+ { "_CheckCalledFromGeneratedFile", (PyCFunction)_CheckCalledFromGeneratedFile,
+ METH_NOARGS | METH_STATIC,
+ "Raises TypeError if the caller is not in a _pb2.py file."},
{ NULL, NULL}
};
+static bool SetCompositeField(
+ CMessage* self, PyObject* name, PyObject* value) {
+ if (self->composite_fields == NULL) {
+ self->composite_fields = PyDict_New();
+ if (self->composite_fields == NULL) {
+ return false;
+ }
+ }
+ return PyDict_SetItem(self->composite_fields, name, value) == 0;
+}
+
PyObject* GetAttr(CMessage* self, PyObject* name) {
- PyObject* value = PyDict_GetItem(self->composite_fields, name);
+ PyObject* value = self->composite_fields ?
+ PyDict_GetItem(self->composite_fields, name) : NULL;
if (value != NULL) {
Py_INCREF(value);
return value;
}
- PyObject* descriptor = GetDescriptor(self, name);
- if (descriptor != NULL) {
- CFieldDescriptor* cdescriptor =
- reinterpret_cast<CFieldDescriptor*>(descriptor);
- const google::protobuf::FieldDescriptor* field_descriptor = cdescriptor->descriptor;
- if (field_descriptor->label() == google::protobuf::FieldDescriptor::LABEL_REPEATED) {
- if (field_descriptor->cpp_type() ==
- google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) {
- PyObject* py_container = PyObject_CallObject(
- reinterpret_cast<PyObject*>(&RepeatedCompositeContainer_Type),
- NULL);
- if (py_container == NULL) {
- return NULL;
- }
- RepeatedCompositeContainer* container =
- reinterpret_cast<RepeatedCompositeContainer*>(py_container);
- PyObject* field = cdescriptor->descriptor_field;
- PyObject* message_type = PyObject_GetAttr(field, kmessage_type);
- if (message_type == NULL) {
- return NULL;
- }
- PyObject* concrete_class =
- PyObject_GetAttr(message_type, k_concrete_class);
- if (concrete_class == NULL) {
- return NULL;
- }
- container->parent = self;
- container->parent_field = cdescriptor;
- container->message = self->message;
- container->owner = self->owner;
- container->subclass_init = concrete_class;
- Py_DECREF(message_type);
- if (PyDict_SetItem(self->composite_fields, name, py_container) < 0) {
- Py_DECREF(py_container);
- return NULL;
- }
- return py_container;
- } else {
- ScopedPyObjectPtr init_args(PyTuple_Pack(2, self, cdescriptor));
- PyObject* py_container = PyObject_CallObject(
- reinterpret_cast<PyObject*>(&RepeatedScalarContainer_Type),
- init_args);
- if (py_container == NULL) {
- return NULL;
- }
- if (PyDict_SetItem(self->composite_fields, name, py_container) < 0) {
- Py_DECREF(py_container);
- return NULL;
- }
- return py_container;
+ const FieldDescriptor* field_descriptor = GetFieldDescriptor(self, name);
+ if (field_descriptor == NULL) {
+ return CMessage_Type.tp_base->tp_getattro(
+ reinterpret_cast<PyObject*>(self), name);
+ }
+
+ if (field_descriptor->is_map()) {
+ PyObject* py_container = NULL;
+ const Descriptor* entry_type = field_descriptor->message_type();
+ const FieldDescriptor* value_type = entry_type->FindFieldByName("value");
+ if (value_type->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
+ CMessageClass* value_class = cdescriptor_pool::GetMessageClass(
+ GetDescriptorPoolForMessage(self), value_type->message_type());
+ if (value_class == NULL) {
+ return NULL;
}
+ py_container =
+ NewMessageMapContainer(self, field_descriptor, value_class);
} else {
- if (field_descriptor->cpp_type() ==
- google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) {
- PyObject* sub_message = InternalGetSubMessage(self, cdescriptor);
- if (PyDict_SetItem(self->composite_fields, name, sub_message) < 0) {
- Py_DECREF(sub_message);
- return NULL;
- }
- return sub_message;
- } else {
- return InternalGetScalar(self, field_descriptor);
+ py_container = NewScalarMapContainer(self, field_descriptor);
+ }
+ if (py_container == NULL) {
+ return NULL;
+ }
+ if (!SetCompositeField(self, name, py_container)) {
+ Py_DECREF(py_container);
+ return NULL;
+ }
+ return py_container;
+ }
+
+ if (field_descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
+ PyObject* py_container = NULL;
+ if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
+ CMessageClass* message_class = cdescriptor_pool::GetMessageClass(
+ GetDescriptorPoolForMessage(self), field_descriptor->message_type());
+ if (message_class == NULL) {
+ return NULL;
}
+ py_container = repeated_composite_container::NewContainer(
+ self, field_descriptor, message_class);
+ } else {
+ py_container = repeated_scalar_container::NewContainer(
+ self, field_descriptor);
+ }
+ if (py_container == NULL) {
+ return NULL;
+ }
+ if (!SetCompositeField(self, name, py_container)) {
+ Py_DECREF(py_container);
+ return NULL;
+ }
+ return py_container;
+ }
+
+ if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
+ PyObject* sub_message = InternalGetSubMessage(self, field_descriptor);
+ if (sub_message == NULL) {
+ return NULL;
+ }
+ if (!SetCompositeField(self, name, sub_message)) {
+ Py_DECREF(sub_message);
+ return NULL;
}
+ return sub_message;
}
- return CMessage_Type.tp_base->tp_getattro(reinterpret_cast<PyObject*>(self),
- name);
+ return InternalGetScalar(self->message, field_descriptor);
}
int SetAttr(CMessage* self, PyObject* name, PyObject* value) {
- if (PyDict_Contains(self->composite_fields, name)) {
+ if (self->composite_fields && PyDict_Contains(self->composite_fields, name)) {
PyErr_SetString(PyExc_TypeError, "Can't set composite field");
return -1;
}
- PyObject* descriptor = GetDescriptor(self, name);
- if (descriptor != NULL) {
+ const FieldDescriptor* field_descriptor = GetFieldDescriptor(self, name);
+ if (field_descriptor != NULL) {
AssureWritable(self);
- CFieldDescriptor* cdescriptor =
- reinterpret_cast<CFieldDescriptor*>(descriptor);
- const google::protobuf::FieldDescriptor* field_descriptor = cdescriptor->descriptor;
- if (field_descriptor->label() == google::protobuf::FieldDescriptor::LABEL_REPEATED) {
+ if (field_descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
PyErr_Format(PyExc_AttributeError, "Assignment not allowed to repeated "
"field \"%s\" in protocol message object.",
field_descriptor->name().c_str());
return -1;
} else {
- if (field_descriptor->cpp_type() ==
- google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) {
+ if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
PyErr_Format(PyExc_AttributeError, "Assignment not allowed to "
"field \"%s\" in protocol message object.",
field_descriptor->name().c_str());
@@ -2302,16 +2744,18 @@ int SetAttr(CMessage* self, PyObject* name, PyObject* value) {
}
}
- PyErr_Format(PyExc_AttributeError, "Assignment not allowed");
+ PyErr_Format(PyExc_AttributeError,
+ "Assignment not allowed "
+ "(no field \"%s\" in protocol message object).",
+ PyString_AsString(name));
return -1;
}
} // namespace cmessage
PyTypeObject CMessage_Type = {
- PyVarObject_HEAD_INIT(&PyType_Type, 0)
- "google.protobuf.internal."
- "cpp._message.CMessage", // tp_name
+ PyVarObject_HEAD_INIT(&CMessageClass_Type, 0)
+ FULL_MODULE_NAME ".CMessage", // tp_name
sizeof(CMessage), // tp_basicsize
0, // tp_itemsize
(destructor)cmessage::Dealloc, // tp_dealloc
@@ -2319,11 +2763,11 @@ PyTypeObject CMessage_Type = {
0, // tp_getattr
0, // tp_setattr
0, // tp_compare
- 0, // tp_repr
+ (reprfunc)cmessage::ToStr, // tp_repr
0, // tp_as_number
0, // tp_as_sequence
0, // tp_as_mapping
- 0, // tp_hash
+ PyObject_HashNotImplemented, // tp_hash
0, // tp_call
(reprfunc)cmessage::ToStr, // tp_str
(getattrofunc)cmessage::GetAttr, // tp_getattro
@@ -2338,8 +2782,8 @@ PyTypeObject CMessage_Type = {
0, // tp_iter
0, // tp_iternext
cmessage::Methods, // tp_methods
- cmessage::Members, // tp_members
- 0, // tp_getset
+ 0, // tp_members
+ cmessage::Getters, // tp_getset
0, // tp_base
0, // tp_dict
0, // tp_descr_get
@@ -2355,7 +2799,7 @@ PyTypeObject CMessage_Type = {
const Message* (*GetCProtoInsidePyProtoPtr)(PyObject* msg);
Message* (*MutableCProtoInsidePyProtoPtr)(PyObject* msg);
-static const google::protobuf::Message* GetCProtoInsidePyProtoImpl(PyObject* msg) {
+static const Message* GetCProtoInsidePyProtoImpl(PyObject* msg) {
if (!PyObject_TypeCheck(msg, &CMessage_Type)) {
return NULL;
}
@@ -2363,12 +2807,12 @@ static const google::protobuf::Message* GetCProtoInsidePyProtoImpl(PyObject* msg
return cmsg->message;
}
-static google::protobuf::Message* MutableCProtoInsidePyProtoImpl(PyObject* msg) {
+static Message* MutableCProtoInsidePyProtoImpl(PyObject* msg) {
if (!PyObject_TypeCheck(msg, &CMessage_Type)) {
return NULL;
}
CMessage* cmsg = reinterpret_cast<CMessage*>(msg);
- if (PyDict_Size(cmsg->composite_fields) != 0 ||
+ if ((cmsg->composite_fields && PyDict_Size(cmsg->composite_fields) != 0) ||
(cmsg->extensions != NULL &&
PyDict_Size(cmsg->extensions->values) != 0)) {
// There is currently no way of accurately syncing arbitrary changes to
@@ -2401,87 +2845,212 @@ void InitGlobals() {
kuint64max_py = PyLong_FromUnsignedLongLong(kuint64max);
kDESCRIPTOR = PyString_FromString("DESCRIPTOR");
- k__descriptors = PyString_FromString("__descriptors");
+ k_cdescriptor = PyString_FromString("_cdescriptor");
kfull_name = PyString_FromString("full_name");
- kis_extendable = PyString_FromString("is_extendable");
- kextensions_by_name = PyString_FromString("extensions_by_name");
k_extensions_by_name = PyString_FromString("_extensions_by_name");
k_extensions_by_number = PyString_FromString("_extensions_by_number");
- k_concrete_class = PyString_FromString("_concrete_class");
- kmessage_type = PyString_FromString("message_type");
- kname = PyString_FromString("name");
- kfields_by_name = PyString_FromString("fields_by_name");
-
- global_message_factory = new DynamicMessageFactory(GetDescriptorPool());
- global_message_factory->SetDelegateToGeneratedFactory(true);
- descriptor_pool = reinterpret_cast<google::protobuf::python::CDescriptorPool*>(
- Python_NewCDescriptorPool(NULL, NULL));
+ PyObject *dummy_obj = PySet_New(NULL);
+ kEmptyWeakref = PyWeakref_NewRef(dummy_obj, NULL);
+ Py_DECREF(dummy_obj);
}
bool InitProto2MessageModule(PyObject *m) {
- InitGlobals();
-
- google::protobuf::python::CMessage_Type.tp_hash = PyObject_HashNotImplemented;
- if (PyType_Ready(&google::protobuf::python::CMessage_Type) < 0) {
+ // Initialize types and globals in descriptor.cc
+ if (!InitDescriptor()) {
return false;
}
- // All three of these are actually set elsewhere, directly onto the child
- // protocol buffer message class, but set them here as well to document that
- // subclasses need to set these.
- PyDict_SetItem(google::protobuf::python::CMessage_Type.tp_dict, kDESCRIPTOR, Py_None);
- PyDict_SetItem(google::protobuf::python::CMessage_Type.tp_dict,
- k_extensions_by_name, Py_None);
- PyDict_SetItem(google::protobuf::python::CMessage_Type.tp_dict,
- k_extensions_by_number, Py_None);
+ // Initialize types and globals in descriptor_pool.cc
+ if (!InitDescriptorPool()) {
+ return false;
+ }
- PyModule_AddObject(m, "Message", reinterpret_cast<PyObject*>(
- &google::protobuf::python::CMessage_Type));
+ // Initialize constants defined in this file.
+ InitGlobals();
- google::protobuf::python::RepeatedScalarContainer_Type.tp_new = PyType_GenericNew;
- google::protobuf::python::RepeatedScalarContainer_Type.tp_hash =
- PyObject_HashNotImplemented;
- if (PyType_Ready(&google::protobuf::python::RepeatedScalarContainer_Type) < 0) {
+ CMessageClass_Type.tp_base = &PyType_Type;
+ if (PyType_Ready(&CMessageClass_Type) < 0) {
return false;
}
+ PyModule_AddObject(m, "MessageMeta",
+ reinterpret_cast<PyObject*>(&CMessageClass_Type));
- PyModule_AddObject(m, "RepeatedScalarContainer",
- reinterpret_cast<PyObject*>(
- &google::protobuf::python::RepeatedScalarContainer_Type));
+ if (PyType_Ready(&CMessage_Type) < 0) {
+ return false;
+ }
- google::protobuf::python::RepeatedCompositeContainer_Type.tp_new = PyType_GenericNew;
- google::protobuf::python::RepeatedCompositeContainer_Type.tp_hash =
- PyObject_HashNotImplemented;
- if (PyType_Ready(&google::protobuf::python::RepeatedCompositeContainer_Type) < 0) {
+ // DESCRIPTOR is set on each protocol buffer message class elsewhere, but set
+ // it here as well to document that subclasses need to set it.
+ PyDict_SetItem(CMessage_Type.tp_dict, kDESCRIPTOR, Py_None);
+ // Subclasses with message extensions will override _extensions_by_name and
+ // _extensions_by_number with fresh mutable dictionaries in AddDescriptors.
+ // All other classes can share this same immutable mapping.
+ ScopedPyObjectPtr empty_dict(PyDict_New());
+ if (empty_dict == NULL) {
+ return false;
+ }
+ ScopedPyObjectPtr immutable_dict(PyDictProxy_New(empty_dict.get()));
+ if (immutable_dict == NULL) {
+ return false;
+ }
+ if (PyDict_SetItem(CMessage_Type.tp_dict,
+ k_extensions_by_name, immutable_dict.get()) < 0) {
+ return false;
+ }
+ if (PyDict_SetItem(CMessage_Type.tp_dict,
+ k_extensions_by_number, immutable_dict.get()) < 0) {
return false;
}
- PyModule_AddObject(
- m, "RepeatedCompositeContainer",
- reinterpret_cast<PyObject*>(
- &google::protobuf::python::RepeatedCompositeContainer_Type));
+ PyModule_AddObject(m, "Message", reinterpret_cast<PyObject*>(&CMessage_Type));
- google::protobuf::python::ExtensionDict_Type.tp_new = PyType_GenericNew;
- google::protobuf::python::ExtensionDict_Type.tp_hash = PyObject_HashNotImplemented;
- if (PyType_Ready(&google::protobuf::python::ExtensionDict_Type) < 0) {
- return false;
+ // Initialize Repeated container types.
+ {
+ if (PyType_Ready(&RepeatedScalarContainer_Type) < 0) {
+ return false;
+ }
+
+ PyModule_AddObject(m, "RepeatedScalarContainer",
+ reinterpret_cast<PyObject*>(
+ &RepeatedScalarContainer_Type));
+
+ if (PyType_Ready(&RepeatedCompositeContainer_Type) < 0) {
+ return false;
+ }
+
+ PyModule_AddObject(
+ m, "RepeatedCompositeContainer",
+ reinterpret_cast<PyObject*>(
+ &RepeatedCompositeContainer_Type));
+
+ // Register them as collections.Sequence
+ ScopedPyObjectPtr collections(PyImport_ImportModule("collections"));
+ if (collections == NULL) {
+ return false;
+ }
+ ScopedPyObjectPtr mutable_sequence(
+ PyObject_GetAttrString(collections.get(), "MutableSequence"));
+ if (mutable_sequence == NULL) {
+ return false;
+ }
+ if (ScopedPyObjectPtr(
+ PyObject_CallMethod(mutable_sequence.get(), "register", "O",
+ &RepeatedScalarContainer_Type)) == NULL) {
+ return false;
+ }
+ if (ScopedPyObjectPtr(
+ PyObject_CallMethod(mutable_sequence.get(), "register", "O",
+ &RepeatedCompositeContainer_Type)) == NULL) {
+ return false;
+ }
}
- PyModule_AddObject(
- m, "ExtensionDict",
- reinterpret_cast<PyObject*>(&google::protobuf::python::ExtensionDict_Type));
+ // Initialize Map container types.
+ {
+ // ScalarMapContainer_Type derives from our MutableMapping type.
+ ScopedPyObjectPtr containers(PyImport_ImportModule(
+ "google.protobuf.internal.containers"));
+ if (containers == NULL) {
+ return false;
+ }
+
+ ScopedPyObjectPtr mutable_mapping(
+ PyObject_GetAttrString(containers.get(), "MutableMapping"));
+ if (mutable_mapping == NULL) {
+ return false;
+ }
+
+ if (!PyObject_TypeCheck(mutable_mapping.get(), &PyType_Type)) {
+ return false;
+ }
+
+ Py_INCREF(mutable_mapping.get());
+#if PY_MAJOR_VERSION >= 3
+ PyObject* bases = PyTuple_New(1);
+ PyTuple_SET_ITEM(bases, 0, mutable_mapping.get());
+
+ ScalarMapContainer_Type =
+ PyType_FromSpecWithBases(&ScalarMapContainer_Type_spec, bases);
+ PyModule_AddObject(m, "ScalarMapContainer", ScalarMapContainer_Type);
+#else
+ ScalarMapContainer_Type.tp_base =
+ reinterpret_cast<PyTypeObject*>(mutable_mapping.get());
+
+ if (PyType_Ready(&ScalarMapContainer_Type) < 0) {
+ return false;
+ }
- if (!google::protobuf::python::InitDescriptor()) {
+ PyModule_AddObject(m, "ScalarMapContainer",
+ reinterpret_cast<PyObject*>(&ScalarMapContainer_Type));
+#endif
+
+ if (PyType_Ready(&MapIterator_Type) < 0) {
+ return false;
+ }
+
+ PyModule_AddObject(m, "MapIterator",
+ reinterpret_cast<PyObject*>(&MapIterator_Type));
+
+
+#if PY_MAJOR_VERSION >= 3
+ MessageMapContainer_Type =
+ PyType_FromSpecWithBases(&MessageMapContainer_Type_spec, bases);
+ PyModule_AddObject(m, "MessageMapContainer", MessageMapContainer_Type);
+#else
+ Py_INCREF(mutable_mapping.get());
+ MessageMapContainer_Type.tp_base =
+ reinterpret_cast<PyTypeObject*>(mutable_mapping.get());
+
+ if (PyType_Ready(&MessageMapContainer_Type) < 0) {
+ return false;
+ }
+
+ PyModule_AddObject(m, "MessageMapContainer",
+ reinterpret_cast<PyObject*>(&MessageMapContainer_Type));
+#endif
+ }
+
+ if (PyType_Ready(&ExtensionDict_Type) < 0) {
return false;
}
+ PyModule_AddObject(
+ m, "ExtensionDict",
+ reinterpret_cast<PyObject*>(&ExtensionDict_Type));
+
+ // Expose the DescriptorPool used to hold all descriptors added from generated
+ // pb2.py files.
+ // PyModule_AddObject steals a reference.
+ Py_INCREF(GetDefaultDescriptorPool());
+ PyModule_AddObject(m, "default_pool",
+ reinterpret_cast<PyObject*>(GetDefaultDescriptorPool()));
+
+ PyModule_AddObject(m, "DescriptorPool", reinterpret_cast<PyObject*>(
+ &PyDescriptorPool_Type));
+
+ // This implementation provides full Descriptor types, we advertise it so that
+ // descriptor.py can use them in replacement of the Python classes.
+ PyModule_AddIntConstant(m, "_USE_C_DESCRIPTORS", 1);
+
+ PyModule_AddObject(m, "Descriptor", reinterpret_cast<PyObject*>(
+ &PyMessageDescriptor_Type));
+ PyModule_AddObject(m, "FieldDescriptor", reinterpret_cast<PyObject*>(
+ &PyFieldDescriptor_Type));
+ PyModule_AddObject(m, "EnumDescriptor", reinterpret_cast<PyObject*>(
+ &PyEnumDescriptor_Type));
+ PyModule_AddObject(m, "EnumValueDescriptor", reinterpret_cast<PyObject*>(
+ &PyEnumValueDescriptor_Type));
+ PyModule_AddObject(m, "FileDescriptor", reinterpret_cast<PyObject*>(
+ &PyFileDescriptor_Type));
+ PyModule_AddObject(m, "OneofDescriptor", reinterpret_cast<PyObject*>(
+ &PyOneofDescriptor_Type));
PyObject* enum_type_wrapper = PyImport_ImportModule(
"google.protobuf.internal.enum_type_wrapper");
if (enum_type_wrapper == NULL) {
return false;
}
- google::protobuf::python::EnumTypeWrapper_class =
+ EnumTypeWrapper_class =
PyObject_GetAttrString(enum_type_wrapper, "EnumTypeWrapper");
Py_DECREF(enum_type_wrapper);
@@ -2490,25 +3059,21 @@ bool InitProto2MessageModule(PyObject *m) {
if (message_module == NULL) {
return false;
}
- google::protobuf::python::EncodeError_class = PyObject_GetAttrString(message_module,
- "EncodeError");
- google::protobuf::python::DecodeError_class = PyObject_GetAttrString(message_module,
- "DecodeError");
+ EncodeError_class = PyObject_GetAttrString(message_module, "EncodeError");
+ DecodeError_class = PyObject_GetAttrString(message_module, "DecodeError");
+ PythonMessage_class = PyObject_GetAttrString(message_module, "Message");
Py_DECREF(message_module);
PyObject* pickle_module = PyImport_ImportModule("pickle");
if (pickle_module == NULL) {
return false;
}
- google::protobuf::python::PickleError_class = PyObject_GetAttrString(pickle_module,
- "PickleError");
+ PickleError_class = PyObject_GetAttrString(pickle_module, "PickleError");
Py_DECREF(pickle_module);
// Override {Get,Mutable}CProtoInsidePyProto.
- google::protobuf::python::GetCProtoInsidePyProtoPtr =
- google::protobuf::python::GetCProtoInsidePyProtoImpl;
- google::protobuf::python::MutableCProtoInsidePyProtoPtr =
- google::protobuf::python::MutableCProtoInsidePyProtoImpl;
+ GetCProtoInsidePyProtoPtr = GetCProtoInsidePyProtoImpl;
+ MutableCProtoInsidePyProtoPtr = MutableCProtoInsidePyProtoImpl;
return true;
}
@@ -2516,6 +3081,12 @@ bool InitProto2MessageModule(PyObject *m) {
} // namespace python
} // namespace protobuf
+static PyMethodDef ModuleMethods[] = {
+ {"SetAllowOversizeProtos",
+ (PyCFunction)google::protobuf::python::cmessage::SetAllowOversizeProtos,
+ METH_O, "Enable/disable oversize proto parsing."},
+ { NULL, NULL}
+};
#if PY_MAJOR_VERSION >= 3
static struct PyModuleDef _module = {
@@ -2523,7 +3094,7 @@ static struct PyModuleDef _module = {
"_message",
google::protobuf::python::module_docstring,
-1,
- NULL,
+ ModuleMethods, /* m_methods */
NULL,
NULL,
NULL,
@@ -2542,7 +3113,8 @@ extern "C" {
#if PY_MAJOR_VERSION >= 3
m = PyModule_Create(&_module);
#else
- m = Py_InitModule3("_message", NULL, google::protobuf::python::module_docstring);
+ m = Py_InitModule3("_message", ModuleMethods,
+ google::protobuf::python::module_docstring);
#endif
if (m == NULL) {
return INITFUNC_ERRORVAL;
diff --git a/python/google/protobuf/pyext/message.h b/python/google/protobuf/pyext/message.h
index 9f4978f44..3a4bec81c 100644
--- a/python/google/protobuf/pyext/message.h
+++ b/python/google/protobuf/pyext/message.h
@@ -42,20 +42,27 @@
#endif
#include <string>
-
namespace google {
namespace protobuf {
class Message;
class Reflection;
class FieldDescriptor;
-
+class Descriptor;
+class DescriptorPool;
+class MessageFactory;
+
+#ifdef _SHARED_PTR_H
+using std::shared_ptr;
+using ::std::string;
+#else
using internal::shared_ptr;
+#endif
namespace python {
-struct CFieldDescriptor;
struct ExtensionDict;
+struct PyDescriptorPool;
typedef struct CMessage {
PyObject_HEAD;
@@ -79,13 +86,11 @@ typedef struct CMessage {
// to use this pointer will result in a crash.
struct CMessage* parent;
- // Weak reference to the parent's descriptor that describes this submessage.
+ // Pointer to the parent's descriptor that describes this submessage.
// Used together with the parent's message when making a default message
// instance mutable.
- // TODO(anuraag): With a bit of work on the Python/C++ layer, it should be
- // possible to make this a direct pointer to a C++ FieldDescriptor, this would
- // be easier if this implementation replaces upstream.
- CFieldDescriptor* parent_field;
+ // The pointer is owned by the global DescriptorPool.
+ const FieldDescriptor* parent_field_descriptor;
// Pointer to the C++ Message object for this CMessage. The
// CMessage does not own this pointer.
@@ -111,49 +116,91 @@ typedef struct CMessage {
extern PyTypeObject CMessage_Type;
+
+// The (meta) type of all Messages classes.
+// It allows us to cache some C++ pointers in the class object itself, they are
+// faster to extract than from the type's dictionary.
+
+struct CMessageClass {
+ // This is how CPython subclasses C structures: the base structure must be
+ // the first member of the object.
+ PyHeapTypeObject super;
+
+ // C++ descriptor of this message.
+ const Descriptor* message_descriptor;
+
+ // Owned reference, used to keep the pointer above alive.
+ PyObject* py_message_descriptor;
+
+ // The Python DescriptorPool used to create the class. It is needed to resolve
+ // fields descriptors, including extensions fields; its C++ MessageFactory is
+ // used to instantiate submessages.
+ // This can be different from DESCRIPTOR.file.pool, in the case of a custom
+ // DescriptorPool which defines new extensions.
+ // We own the reference, because it's important to keep the descriptors and
+ // factory alive.
+ PyDescriptorPool* py_descriptor_pool;
+
+ PyObject* AsPyObject() {
+ return reinterpret_cast<PyObject*>(this);
+ }
+};
+
+
namespace cmessage {
-// Create a new empty message that can be populated by the parent.
-PyObject* NewEmpty(PyObject* type);
+// Internal function to create a new empty Message Python object, but with empty
+// pointers to the C++ objects.
+// The caller must fill self->message, self->owner and eventually self->parent.
+CMessage* NewEmptyMessage(CMessageClass* type);
// Release a submessage from its proto tree, making it a new top-level messgae.
// A new message will be created if this is a read-only default instance.
//
// Corresponds to reflection api method ReleaseMessage.
-int ReleaseSubMessage(google::protobuf::Message* message,
- const google::protobuf::FieldDescriptor* field_descriptor,
+int ReleaseSubMessage(CMessage* self,
+ const FieldDescriptor* field_descriptor,
CMessage* child_cmessage);
+// Retrieves the C++ descriptor of a Python Extension descriptor.
+// On error, return NULL with an exception set.
+const FieldDescriptor* GetExtensionDescriptor(PyObject* extension);
+
// Initializes a new CMessage instance for a submessage. Only called once per
// submessage as the result is cached in composite_fields.
//
// Corresponds to reflection api method GetMessage.
-PyObject* InternalGetSubMessage(CMessage* self,
- CFieldDescriptor* cfield_descriptor);
+PyObject* InternalGetSubMessage(
+ CMessage* self, const FieldDescriptor* field_descriptor);
// Deletes a range of C++ submessages in a repeated field (following a
// removal in a RepeatedCompositeContainer).
//
// Releases messages to the provided cmessage_list if it is not NULL rather
// than just removing them from the underlying proto. This cmessage_list must
-// have a CMessage for each underlying submessage. The CMessages refered to
+// have a CMessage for each underlying submessage. The CMessages referred to
// by slice will be removed from cmessage_list by this function.
//
// Corresponds to reflection api method RemoveLast.
-int InternalDeleteRepeatedField(google::protobuf::Message* message,
- const google::protobuf::FieldDescriptor* field_descriptor,
+int InternalDeleteRepeatedField(CMessage* self,
+ const FieldDescriptor* field_descriptor,
PyObject* slice, PyObject* cmessage_list);
// Sets the specified scalar value to the message.
int InternalSetScalar(CMessage* self,
- const google::protobuf::FieldDescriptor* field_descriptor,
+ const FieldDescriptor* field_descriptor,
PyObject* value);
+// Sets the specified scalar value to the message. Requires it is not a Oneof.
+int InternalSetNonOneofScalar(Message* message,
+ const FieldDescriptor* field_descriptor,
+ PyObject* arg);
+
// Retrieves the specified scalar value from the message.
//
// Returns a new python reference.
-PyObject* InternalGetScalar(CMessage* self,
- const google::protobuf::FieldDescriptor* field_descriptor);
+PyObject* InternalGetScalar(const Message* message,
+ const FieldDescriptor* field_descriptor);
// Clears the message, removing all contained data. Extension dictionary and
// submessages are released first if there are remaining external references.
@@ -169,8 +216,7 @@ PyObject* Clear(CMessage* self);
//
// Corresponds to reflection api method ClearField.
PyObject* ClearFieldByDescriptor(
- CMessage* self,
- const google::protobuf::FieldDescriptor* descriptor);
+ CMessage* self, const FieldDescriptor* descriptor);
// Clears the data for the given field name. The message is released if there
// are any external references.
@@ -183,17 +229,15 @@ PyObject* ClearField(CMessage* self, PyObject* arg);
//
// Corresponds to reflection api method HasField
PyObject* HasFieldByDescriptor(
- CMessage* self, const google::protobuf::FieldDescriptor* field_descriptor);
+ CMessage* self, const FieldDescriptor* field_descriptor);
// Checks if the message has the named field.
//
// Corresponds to reflection api method HasField.
PyObject* HasField(CMessage* self, PyObject* arg);
-// Initializes constants/enum values on a message. This is called by
-// RepeatedCompositeContainer and ExtensionDict after calling the constructor.
-// TODO(anuraag): Make it always called from within the constructor since it can
-int InitAttributes(CMessage* self, PyObject* descriptor, PyObject* kwargs);
+// Initializes values of fields on a newly constructed message.
+int InitAttributes(CMessage* self, PyObject* kwargs);
PyObject* MergeFrom(CMessage* self, PyObject* arg);
@@ -216,16 +260,23 @@ int SetOwner(CMessage* self, const shared_ptr<Message>& new_owner);
int AssureWritable(CMessage* self);
+// Returns the "best" DescriptorPool for the given message.
+// This is often equivalent to message.DESCRIPTOR.pool, but not always, when
+// the message class was created from a MessageFactory using a custom pool which
+// uses the generated pool as an underlay.
+//
+// The returned pool is suitable for finding fields and building submessages,
+// even in the case of extensions.
+PyDescriptorPool* GetDescriptorPoolForMessage(CMessage* message);
+
} // namespace cmessage
+
/* Is 64bit */
#define IS_64BIT (SIZEOF_LONG == 8)
-#define FIELD_BELONGS_TO_MESSAGE(field_descriptor, message) \
- ((message)->GetDescriptor() == (field_descriptor)->containing_type())
-
#define FIELD_IS_REPEATED(field_descriptor) \
- ((field_descriptor)->label() == google::protobuf::FieldDescriptor::LABEL_REPEATED)
+ ((field_descriptor)->label() == FieldDescriptor::LABEL_REPEATED)
#define GOOGLE_CHECK_GET_INT32(arg, value, err) \
int32 value; \
@@ -278,7 +329,7 @@ extern PyObject* kint64min_py;
extern PyObject* kint64max_py;
extern PyObject* kuint64max_py;
-#define C(str) const_cast<char*>(str)
+#define FULL_MODULE_NAME "google.protobuf.pyext._message"
void FormatTypeError(PyObject* arg, char* expected_types);
template<class T>
@@ -287,14 +338,19 @@ bool CheckAndGetInteger(
bool CheckAndGetDouble(PyObject* arg, double* value);
bool CheckAndGetFloat(PyObject* arg, float* value);
bool CheckAndGetBool(PyObject* arg, bool* value);
+PyObject* CheckString(PyObject* arg, const FieldDescriptor* descriptor);
bool CheckAndSetString(
- PyObject* arg, google::protobuf::Message* message,
- const google::protobuf::FieldDescriptor* descriptor,
- const google::protobuf::Reflection* reflection,
+ PyObject* arg, Message* message,
+ const FieldDescriptor* descriptor,
+ const Reflection* reflection,
bool append,
int index);
-PyObject* ToStringObject(
- const google::protobuf::FieldDescriptor* descriptor, string value);
+PyObject* ToStringObject(const FieldDescriptor* descriptor, string value);
+
+// Check if the passed field descriptor belongs to the given message.
+// If not, return false and set a Python exception (a KeyError)
+bool CheckFieldBelongsToMessage(const FieldDescriptor* field_descriptor,
+ const Message* message);
extern PyObject* PickleError_class;
diff --git a/python/google/protobuf/pyext/message_factory_cpp2_test.py b/python/google/protobuf/pyext/message_factory_cpp2_test.py
deleted file mode 100644
index 32ab4f852..000000000
--- a/python/google/protobuf/pyext/message_factory_cpp2_test.py
+++ /dev/null
@@ -1,56 +0,0 @@
-#! /usr/bin/python
-#
-# Protocol Buffers - Google's data interchange format
-# Copyright 2008 Google Inc. All rights reserved.
-# https://developers.google.com/protocol-buffers/
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following disclaimer
-# in the documentation and/or other materials provided with the
-# distribution.
-# * Neither the name of Google Inc. nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-"""Tests for google.protobuf.message_factory."""
-
-import os
-os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'cpp'
-os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION'] = '2'
-
-# We must set the implementation version above before the google3 imports.
-# pylint: disable=g-import-not-at-top
-from google.apputils import basetest
-from google.protobuf.internal import api_implementation
-# Run all tests from the original module by putting them in our namespace.
-# pylint: disable=wildcard-import
-from google.protobuf.internal.message_factory_test import *
-
-
-class ConfirmCppApi2Test(basetest.TestCase):
-
- def testImplementationSetting(self):
- self.assertEqual('cpp', api_implementation.Type())
- self.assertEqual(2, api_implementation.Version())
-
-
-if __name__ == '__main__':
- basetest.main()
diff --git a/python/google/protobuf/pyext/proto2_api_test.proto b/python/google/protobuf/pyext/proto2_api_test.proto
index 72c31b1cd..18aecfb7d 100644
--- a/python/google/protobuf/pyext/proto2_api_test.proto
+++ b/python/google/protobuf/pyext/proto2_api_test.proto
@@ -28,6 +28,8 @@
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+syntax = "proto2";
+
import "google/protobuf/internal/cpp/proto1_api_test.proto";
package google.protobuf.python.internal;
diff --git a/python/google/protobuf/pyext/python.proto b/python/google/protobuf/pyext/python.proto
index d47d402c6..cce645d71 100644
--- a/python/google/protobuf/pyext/python.proto
+++ b/python/google/protobuf/pyext/python.proto
@@ -33,6 +33,7 @@
// These message definitions are used to exercises known corner cases
// in the C++ implementation of the Python API.
+syntax = "proto2";
package google.protobuf.python.internal;
@@ -63,4 +64,5 @@ message TestAllExtensions {
extend TestAllExtensions {
optional TestAllTypes.NestedMessage optional_nested_message_extension = 1;
+ repeated TestAllTypes.NestedMessage repeated_nested_message_extension = 2;
}
diff --git a/python/google/protobuf/pyext/reflection_cpp2_generated_test.py b/python/google/protobuf/pyext/reflection_cpp2_generated_test.py
deleted file mode 100755
index 552efd482..000000000
--- a/python/google/protobuf/pyext/reflection_cpp2_generated_test.py
+++ /dev/null
@@ -1,94 +0,0 @@
-#! /usr/bin/python
-# -*- coding: utf-8 -*-
-#
-# Protocol Buffers - Google's data interchange format
-# Copyright 2008 Google Inc. All rights reserved.
-# https://developers.google.com/protocol-buffers/
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following disclaimer
-# in the documentation and/or other materials provided with the
-# distribution.
-# * Neither the name of Google Inc. nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-"""Unittest for reflection.py, which tests the generated C++ implementation."""
-
-__author__ = 'jasonh@google.com (Jason Hsueh)'
-
-import os
-os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'cpp'
-os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION'] = '2'
-
-from google.apputils import basetest
-from google.protobuf.internal import api_implementation
-from google.protobuf.internal import more_extensions_dynamic_pb2
-from google.protobuf.internal import more_extensions_pb2
-from google.protobuf.internal.reflection_test import *
-
-
-class ReflectionCppTest(basetest.TestCase):
- def testImplementationSetting(self):
- self.assertEqual('cpp', api_implementation.Type())
- self.assertEqual(2, api_implementation.Version())
-
- def testExtensionOfGeneratedTypeInDynamicFile(self):
- """Tests that a file built dynamically can extend a generated C++ type.
-
- The C++ implementation uses a DescriptorPool that has the generated
- DescriptorPool as an underlay. Typically, a type can only find
- extensions in its own pool. With the python C-extension, the generated C++
- extendee may be available, but not the extension. This tests that the
- C-extension implements the correct special handling to make such extensions
- available.
- """
- pb1 = more_extensions_pb2.ExtendedMessage()
- # Test that basic accessors work.
- self.assertFalse(
- pb1.HasExtension(more_extensions_dynamic_pb2.dynamic_int32_extension))
- self.assertFalse(
- pb1.HasExtension(more_extensions_dynamic_pb2.dynamic_message_extension))
- pb1.Extensions[more_extensions_dynamic_pb2.dynamic_int32_extension] = 17
- pb1.Extensions[more_extensions_dynamic_pb2.dynamic_message_extension].a = 24
- self.assertTrue(
- pb1.HasExtension(more_extensions_dynamic_pb2.dynamic_int32_extension))
- self.assertTrue(
- pb1.HasExtension(more_extensions_dynamic_pb2.dynamic_message_extension))
-
- # Now serialize the data and parse to a new message.
- pb2 = more_extensions_pb2.ExtendedMessage()
- pb2.MergeFromString(pb1.SerializeToString())
-
- self.assertTrue(
- pb2.HasExtension(more_extensions_dynamic_pb2.dynamic_int32_extension))
- self.assertTrue(
- pb2.HasExtension(more_extensions_dynamic_pb2.dynamic_message_extension))
- self.assertEqual(
- 17, pb2.Extensions[more_extensions_dynamic_pb2.dynamic_int32_extension])
- self.assertEqual(
- 24,
- pb2.Extensions[more_extensions_dynamic_pb2.dynamic_message_extension].a)
-
-
-
-if __name__ == '__main__':
- basetest.main()
diff --git a/python/google/protobuf/pyext/repeated_composite_container.cc b/python/google/protobuf/pyext/repeated_composite_container.cc
index 5c05b3d89..4f339e772 100644
--- a/python/google/protobuf/pyext/repeated_composite_container.cc
+++ b/python/google/protobuf/pyext/repeated_composite_container.cc
@@ -38,11 +38,13 @@
#include <google/protobuf/stubs/shared_ptr.h>
#endif
+#include <google/protobuf/stubs/logging.h>
#include <google/protobuf/stubs/common.h>
#include <google/protobuf/descriptor.h>
#include <google/protobuf/dynamic_message.h>
#include <google/protobuf/message.h>
#include <google/protobuf/pyext/descriptor.h>
+#include <google/protobuf/pyext/descriptor_pool.h>
#include <google/protobuf/pyext/message.h>
#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
@@ -56,8 +58,6 @@ namespace google {
namespace protobuf {
namespace python {
-extern google::protobuf::DynamicMessageFactory* global_message_factory;
-
namespace repeated_composite_container {
// TODO(tibell): We might also want to check:
@@ -65,144 +65,25 @@ namespace repeated_composite_container {
#define GOOGLE_CHECK_ATTACHED(self) \
do { \
GOOGLE_CHECK_NOTNULL((self)->message); \
- GOOGLE_CHECK_NOTNULL((self)->parent_field); \
+ GOOGLE_CHECK_NOTNULL((self)->parent_field_descriptor); \
} while (0);
#define GOOGLE_CHECK_RELEASED(self) \
do { \
GOOGLE_CHECK((self)->owner.get() == NULL); \
GOOGLE_CHECK((self)->message == NULL); \
- GOOGLE_CHECK((self)->parent_field == NULL); \
+ GOOGLE_CHECK((self)->parent_field_descriptor == NULL); \
GOOGLE_CHECK((self)->parent == NULL); \
} while (0);
-// Returns a new reference.
-static PyObject* GetKey(PyObject* x) {
- // Just the identity function.
- Py_INCREF(x);
- return x;
-}
-
-#define GET_KEY(keyfunc, value) \
- ((keyfunc) == NULL ? \
- GetKey((value)) : \
- PyObject_CallFunctionObjArgs((keyfunc), (value), NULL))
-
-// Converts a comparison function that returns -1, 0, or 1 into a
-// less-than predicate.
-//
-// Returns -1 on error, 1 if x < y, 0 if x >= y.
-static int islt(PyObject *x, PyObject *y, PyObject *compare) {
- if (compare == NULL)
- return PyObject_RichCompareBool(x, y, Py_LT);
-
- ScopedPyObjectPtr res(PyObject_CallFunctionObjArgs(compare, x, y, NULL));
- if (res == NULL)
- return -1;
- if (!PyInt_Check(res)) {
- PyErr_Format(PyExc_TypeError,
- "comparison function must return int, not %.200s",
- Py_TYPE(res)->tp_name);
- return -1;
- }
- return PyInt_AsLong(res) < 0;
-}
-
-// Copied from uarrsort.c but swaps memcpy swaps with protobuf/python swaps
-// TODO(anuraag): Is there a better way to do this then reinventing the wheel?
-static int InternalQuickSort(RepeatedCompositeContainer* self,
- Py_ssize_t start,
- Py_ssize_t limit,
- PyObject* cmp,
- PyObject* keyfunc) {
- if (limit - start <= 1)
- return 0; // Nothing to sort.
-
- GOOGLE_CHECK_ATTACHED(self);
-
- google::protobuf::Message* message = self->message;
- const google::protobuf::Reflection* reflection = message->GetReflection();
- const google::protobuf::FieldDescriptor* descriptor = self->parent_field->descriptor;
- Py_ssize_t left;
- Py_ssize_t right;
-
- PyObject* children = self->child_messages;
-
- do {
- left = start;
- right = limit;
- ScopedPyObjectPtr mid(
- GET_KEY(keyfunc, PyList_GET_ITEM(children, (start + limit) / 2)));
- do {
- ScopedPyObjectPtr key(GET_KEY(keyfunc, PyList_GET_ITEM(children, left)));
- int is_lt = islt(key, mid, cmp);
- if (is_lt == -1)
- return -1;
- /* array[left]<x */
- while (is_lt) {
- ++left;
- ScopedPyObjectPtr key(GET_KEY(keyfunc,
- PyList_GET_ITEM(children, left)));
- is_lt = islt(key, mid, cmp);
- if (is_lt == -1)
- return -1;
- }
- key.reset(GET_KEY(keyfunc, PyList_GET_ITEM(children, right - 1)));
- is_lt = islt(mid, key, cmp);
- if (is_lt == -1)
- return -1;
- while (is_lt) {
- --right;
- ScopedPyObjectPtr key(GET_KEY(keyfunc,
- PyList_GET_ITEM(children, right - 1)));
- is_lt = islt(mid, key, cmp);
- if (is_lt == -1)
- return -1;
- }
- if (left < right) {
- --right;
- if (left < right) {
- reflection->SwapElements(message, descriptor, left, right);
- PyObject* tmp = PyList_GET_ITEM(children, left);
- PyList_SET_ITEM(children, left, PyList_GET_ITEM(children, right));
- PyList_SET_ITEM(children, right, tmp);
- }
- ++left;
- }
- } while (left < right);
-
- if ((right - start) < (limit - left)) {
- /* sort [start..right[ */
- if (start < (right - 1)) {
- InternalQuickSort(self, start, right, cmp, keyfunc);
- }
-
- /* sort [left..limit[ */
- start = left;
- } else {
- /* sort [left..limit[ */
- if (left < (limit - 1)) {
- InternalQuickSort(self, left, limit, cmp, keyfunc);
- }
-
- /* sort [start..right[ */
- limit = right;
- }
- } while (start < (limit - 1));
-
- return 0;
-}
-
-#undef GET_KEY
-
// ---------------------------------------------------------------------
// len()
static Py_ssize_t Length(RepeatedCompositeContainer* self) {
- google::protobuf::Message* message = self->message;
+ Message* message = self->message;
if (message != NULL) {
return message->GetReflection()->FieldSize(*message,
- self->parent_field->descriptor);
+ self->parent_field_descriptor);
} else {
// The container has been released (i.e. by a call to Clear() or
// ClearField() on the parent) and thus there's no message.
@@ -221,23 +102,22 @@ static int UpdateChildMessages(RepeatedCompositeContainer* self) {
// be removed in such a way so there's no need to worry about that.
Py_ssize_t message_length = Length(self);
Py_ssize_t child_length = PyList_GET_SIZE(self->child_messages);
- google::protobuf::Message* message = self->message;
- const google::protobuf::Reflection* reflection = message->GetReflection();
+ Message* message = self->message;
+ const Reflection* reflection = message->GetReflection();
for (Py_ssize_t i = child_length; i < message_length; ++i) {
const Message& sub_message = reflection->GetRepeatedMessage(
- *(self->message), self->parent_field->descriptor, i);
- ScopedPyObjectPtr py_cmsg(cmessage::NewEmpty(self->subclass_init));
- if (py_cmsg == NULL) {
+ *(self->message), self->parent_field_descriptor, i);
+ CMessage* cmsg = cmessage::NewEmptyMessage(self->child_message_class);
+ ScopedPyObjectPtr py_cmsg(reinterpret_cast<PyObject*>(cmsg));
+ if (cmsg == NULL) {
return -1;
}
- CMessage* cmsg = reinterpret_cast<CMessage*>(py_cmsg.get());
cmsg->owner = self->owner;
- cmsg->message = const_cast<google::protobuf::Message*>(&sub_message);
+ cmsg->message = const_cast<Message*>(&sub_message);
cmsg->parent = self->parent;
- if (cmessage::InitAttributes(cmsg, NULL, NULL) < 0) {
+ if (PyList_Append(self->child_messages, py_cmsg.get()) < 0) {
return -1;
}
- PyList_Append(self->child_messages, py_cmsg);
}
return 0;
}
@@ -255,26 +135,27 @@ static PyObject* AddToAttached(RepeatedCompositeContainer* self,
}
if (cmessage::AssureWritable(self->parent) == -1)
return NULL;
- google::protobuf::Message* message = self->message;
- google::protobuf::Message* sub_message =
+ Message* message = self->message;
+ Message* sub_message =
message->GetReflection()->AddMessage(message,
- self->parent_field->descriptor);
- PyObject* py_cmsg = cmessage::NewEmpty(self->subclass_init);
- if (py_cmsg == NULL) {
+ self->parent_field_descriptor);
+ CMessage* cmsg = cmessage::NewEmptyMessage(self->child_message_class);
+ if (cmsg == NULL)
return NULL;
- }
- CMessage* cmsg = reinterpret_cast<CMessage*>(py_cmsg);
cmsg->owner = self->owner;
cmsg->message = sub_message;
cmsg->parent = self->parent;
- // cmessage::InitAttributes must be called after cmsg->message has
- // been set.
- if (cmessage::InitAttributes(cmsg, NULL, kwargs) < 0) {
+ if (cmessage::InitAttributes(cmsg, kwargs) < 0) {
+ Py_DECREF(cmsg);
+ return NULL;
+ }
+
+ PyObject* py_cmsg = reinterpret_cast<PyObject*>(cmsg);
+ if (PyList_Append(self->child_messages, py_cmsg) < 0) {
Py_DECREF(py_cmsg);
return NULL;
}
- PyList_Append(self->child_messages, py_cmsg);
return py_cmsg;
}
@@ -283,20 +164,16 @@ static PyObject* AddToReleased(RepeatedCompositeContainer* self,
PyObject* kwargs) {
GOOGLE_CHECK_RELEASED(self);
- // Create the CMessage
- PyObject* py_cmsg = PyObject_CallObject(self->subclass_init, NULL);
+ // Create a new Message detached from the rest.
+ PyObject* py_cmsg = PyEval_CallObjectWithKeywords(
+ self->child_message_class->AsPyObject(), NULL, kwargs);
if (py_cmsg == NULL)
return NULL;
- CMessage* cmsg = reinterpret_cast<CMessage*>(py_cmsg);
- if (cmessage::InitAttributes(cmsg, NULL, kwargs) < 0) {
+
+ if (PyList_Append(self->child_messages, py_cmsg) < 0) {
Py_DECREF(py_cmsg);
return NULL;
}
-
- // The Message got created by the call to subclass_init above and
- // it set self->owner to the newly allocated message.
-
- PyList_Append(self->child_messages, py_cmsg);
return py_cmsg;
}
@@ -323,8 +200,8 @@ PyObject* Extend(RepeatedCompositeContainer* self, PyObject* value) {
return NULL;
}
ScopedPyObjectPtr next;
- while ((next.reset(PyIter_Next(iter))) != NULL) {
- if (!PyObject_TypeCheck(next, &CMessage_Type)) {
+ while ((next.reset(PyIter_Next(iter.get()))) != NULL) {
+ if (!PyObject_TypeCheck(next.get(), &CMessage_Type)) {
PyErr_SetString(PyExc_TypeError, "Not a cmessage");
return NULL;
}
@@ -333,7 +210,8 @@ PyObject* Extend(RepeatedCompositeContainer* self, PyObject* value) {
return NULL;
}
CMessage* new_cmessage = reinterpret_cast<CMessage*>(new_message.get());
- if (cmessage::MergeFrom(new_cmessage, next) == NULL) {
+ if (ScopedPyObjectPtr(cmessage::MergeFrom(new_cmessage, next.get())) ==
+ NULL) {
return NULL;
}
}
@@ -354,35 +232,9 @@ PyObject* Subscript(RepeatedCompositeContainer* self, PyObject* slice) {
if (UpdateChildMessages(self) < 0) {
return NULL;
}
- Py_ssize_t from;
- Py_ssize_t to;
- Py_ssize_t step;
- Py_ssize_t length = Length(self);
- Py_ssize_t slicelength;
- if (PySlice_Check(slice)) {
-#if PY_MAJOR_VERSION >= 3
- if (PySlice_GetIndicesEx(slice,
-#else
- if (PySlice_GetIndicesEx(reinterpret_cast<PySliceObject*>(slice),
-#endif
- length, &from, &to, &step, &slicelength) == -1) {
- return NULL;
- }
- return PyList_GetSlice(self->child_messages, from, to);
- } else if (PyInt_Check(slice) || PyLong_Check(slice)) {
- from = to = PyLong_AsLong(slice);
- if (from < 0) {
- from = to = length + from;
- }
- PyObject* result = PyList_GetItem(self->child_messages, from);
- if (result == NULL) {
- return NULL;
- }
- Py_INCREF(result);
- return result;
- }
- PyErr_SetString(PyExc_TypeError, "index must be an integer or slice");
- return NULL;
+ // Just forward the call to the subscript-handling function of the
+ // list containing the child messages.
+ return PyObject_GetItem(self->child_messages, slice);
}
int AssignSubscript(RepeatedCompositeContainer* self,
@@ -397,9 +249,9 @@ int AssignSubscript(RepeatedCompositeContainer* self,
}
// Delete from the underlying Message, if any.
- if (self->message != NULL) {
- if (cmessage::InternalDeleteRepeatedField(self->message,
- self->parent_field->descriptor,
+ if (self->parent != NULL) {
+ if (cmessage::InternalDeleteRepeatedField(self->parent,
+ self->parent_field_descriptor,
slice,
self->child_messages) < 0) {
return -1;
@@ -441,7 +293,7 @@ static PyObject* Remove(RepeatedCompositeContainer* self, PyObject* value) {
return NULL;
}
ScopedPyObjectPtr py_index(PyLong_FromLong(index));
- if (AssignSubscript(self, py_index, NULL) < 0) {
+ if (AssignSubscript(self, py_index.get(), NULL) < 0) {
return NULL;
}
Py_RETURN_NONE;
@@ -465,17 +317,17 @@ static PyObject* RichCompare(RepeatedCompositeContainer* self,
if (full_slice == NULL) {
return NULL;
}
- ScopedPyObjectPtr list(Subscript(self, full_slice));
+ ScopedPyObjectPtr list(Subscript(self, full_slice.get()));
if (list == NULL) {
return NULL;
}
ScopedPyObjectPtr other_list(
- Subscript(
- reinterpret_cast<RepeatedCompositeContainer*>(other), full_slice));
+ Subscript(reinterpret_cast<RepeatedCompositeContainer*>(other),
+ full_slice.get()));
if (other_list == NULL) {
return NULL;
}
- return PyObject_RichCompare(list, other_list, opid);
+ return PyObject_RichCompare(list.get(), other_list.get(), opid);
} else {
Py_INCREF(Py_NotImplemented);
return Py_NotImplemented;
@@ -485,58 +337,39 @@ static PyObject* RichCompare(RepeatedCompositeContainer* self,
// ---------------------------------------------------------------------
// sort()
-static PyObject* SortAttached(RepeatedCompositeContainer* self,
- PyObject* args,
- PyObject* kwds) {
- // Sort the underlying Message array.
- PyObject *compare = NULL;
- int reverse = 0;
- PyObject *keyfunc = NULL;
- static char *kwlist[] = {"cmp", "key", "reverse", 0};
-
- if (args != NULL) {
- if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OOi:sort",
- kwlist, &compare, &keyfunc, &reverse))
- return NULL;
- }
- if (compare == Py_None)
- compare = NULL;
- if (keyfunc == Py_None)
- keyfunc = NULL;
-
+static void ReorderAttached(RepeatedCompositeContainer* self) {
+ Message* message = self->message;
+ const Reflection* reflection = message->GetReflection();
+ const FieldDescriptor* descriptor = self->parent_field_descriptor;
const Py_ssize_t length = Length(self);
- if (InternalQuickSort(self, 0, length, compare, keyfunc) < 0)
- return NULL;
-
- // Finally reverse the result if requested.
- if (reverse) {
- google::protobuf::Message* message = self->message;
- const google::protobuf::Reflection* reflection = message->GetReflection();
- const google::protobuf::FieldDescriptor* descriptor = self->parent_field->descriptor;
- // Reverse the Message array.
- for (int i = 0; i < length / 2; ++i)
- reflection->SwapElements(message, descriptor, i, length - i - 1);
+ // Since Python protobuf objects are never arena-allocated, adding and
+ // removing message pointers to the underlying array is just updating
+ // pointers.
+ for (Py_ssize_t i = 0; i < length; ++i)
+ reflection->ReleaseLast(message, descriptor);
- // Reverse the Python list.
- ScopedPyObjectPtr res(PyObject_CallMethod(self->child_messages,
- "reverse", NULL));
- if (res == NULL)
- return NULL;
+ for (Py_ssize_t i = 0; i < length; ++i) {
+ CMessage* py_cmsg = reinterpret_cast<CMessage*>(
+ PyList_GET_ITEM(self->child_messages, i));
+ reflection->AddAllocatedMessage(message, descriptor, py_cmsg->message);
}
-
- Py_RETURN_NONE;
}
-static PyObject* SortReleased(RepeatedCompositeContainer* self,
- PyObject* args,
- PyObject* kwds) {
+// Returns 0 if successful; returns -1 and sets an exception if
+// unsuccessful.
+static int SortPythonMessages(RepeatedCompositeContainer* self,
+ PyObject* args,
+ PyObject* kwds) {
ScopedPyObjectPtr m(PyObject_GetAttrString(self->child_messages, "sort"));
if (m == NULL)
- return NULL;
- if (PyObject_Call(m, args, kwds) == NULL)
- return NULL;
- Py_RETURN_NONE;
+ return -1;
+ if (PyObject_Call(m.get(), args, kwds) == NULL)
+ return -1;
+ if (self->message != NULL) {
+ ReorderAttached(self);
+ }
+ return 0;
}
static PyObject* Sort(RepeatedCompositeContainer* self,
@@ -554,13 +387,13 @@ static PyObject* Sort(RepeatedCompositeContainer* self,
}
}
- if (UpdateChildMessages(self) < 0)
+ if (UpdateChildMessages(self) < 0) {
return NULL;
- if (self->message == NULL) {
- return SortReleased(self, args, kwds);
- } else {
- return SortAttached(self, args, kwds);
}
+ if (SortPythonMessages(self, args, kwds) < 0) {
+ return NULL;
+ }
+ Py_RETURN_NONE;
}
// ---------------------------------------------------------------------
@@ -581,46 +414,43 @@ static PyObject* Item(RepeatedCompositeContainer* self, Py_ssize_t index) {
return item;
}
-// The caller takes ownership of the returned Message.
-Message* ReleaseLast(const FieldDescriptor* field,
- const Descriptor* type,
- Message* message) {
- GOOGLE_CHECK_NOTNULL(field);
- GOOGLE_CHECK_NOTNULL(type);
- GOOGLE_CHECK_NOTNULL(message);
-
- Message* released_message = message->GetReflection()->ReleaseLast(
- message, field);
- // TODO(tibell): Deal with proto1.
-
- // ReleaseMessage will return NULL which differs from
- // child_cmessage->message, if the field does not exist. In this case,
- // the latter points to the default instance via a const_cast<>, so we
- // have to reset it to a new mutable object since we are taking ownership.
- if (released_message == NULL) {
- const Message* prototype = global_message_factory->GetPrototype(type);
- GOOGLE_CHECK_NOTNULL(prototype);
- return prototype->New();
- } else {
- return released_message;
+static PyObject* Pop(RepeatedCompositeContainer* self,
+ PyObject* args) {
+ Py_ssize_t index = -1;
+ if (!PyArg_ParseTuple(args, "|n", &index)) {
+ return NULL;
}
+ PyObject* item = Item(self, index);
+ if (item == NULL) {
+ PyErr_Format(PyExc_IndexError,
+ "list index (%zd) out of range",
+ index);
+ return NULL;
+ }
+ ScopedPyObjectPtr py_index(PyLong_FromSsize_t(index));
+ if (AssignSubscript(self, py_index.get(), NULL) < 0) {
+ return NULL;
+ }
+ return item;
}
-// Release field of message and transfer the ownership to cmessage.
-void ReleaseLastTo(const FieldDescriptor* field,
- Message* message,
- CMessage* cmessage) {
+// Release field of parent message and transfer the ownership to target.
+void ReleaseLastTo(CMessage* parent,
+ const FieldDescriptor* field,
+ CMessage* target) {
+ GOOGLE_CHECK_NOTNULL(parent);
GOOGLE_CHECK_NOTNULL(field);
- GOOGLE_CHECK_NOTNULL(message);
- GOOGLE_CHECK_NOTNULL(cmessage);
+ GOOGLE_CHECK_NOTNULL(target);
shared_ptr<Message> released_message(
- ReleaseLast(field, cmessage->message->GetDescriptor(), message));
- cmessage->parent = NULL;
- cmessage->parent_field = NULL;
- cmessage->message = released_message.get();
- cmessage->read_only = false;
- cmessage::SetOwner(cmessage, released_message);
+ parent->message->GetReflection()->ReleaseLast(parent->message, field));
+ // TODO(tibell): Deal with proto1.
+
+ target->parent = NULL;
+ target->parent_field_descriptor = NULL;
+ target->message = released_message.get();
+ target->read_only = false;
+ cmessage::SetOwner(target, released_message);
}
// Called to release a container using
@@ -633,7 +463,7 @@ int Release(RepeatedCompositeContainer* self) {
}
Message* message = self->message;
- const FieldDescriptor* field = self->parent_field->descriptor;
+ const FieldDescriptor* field = self->parent_field_descriptor;
// The reflection API only lets us release the last message in a
// repeated field. Therefore we iterate through the children
@@ -643,12 +473,12 @@ int Release(RepeatedCompositeContainer* self) {
for (Py_ssize_t i = size - 1; i >= 0; --i) {
CMessage* child_cmessage = reinterpret_cast<CMessage*>(
PyList_GET_ITEM(self->child_messages, i));
- ReleaseLastTo(field, message, child_cmessage);
+ ReleaseLastTo(self->parent, field, child_cmessage);
}
// Detach from containing message.
self->parent = NULL;
- self->parent_field = NULL;
+ self->parent_field_descriptor = NULL;
self->message = NULL;
self->owner.reset();
@@ -670,22 +500,40 @@ int SetOwner(RepeatedCompositeContainer* self,
return 0;
}
-static int Init(RepeatedCompositeContainer* self,
- PyObject* args,
- PyObject* kwargs) {
- self->message = NULL;
- self->parent = NULL;
- self->parent_field = NULL;
- self->subclass_init = NULL;
+// The private constructor of RepeatedCompositeContainer objects.
+PyObject *NewContainer(
+ CMessage* parent,
+ const FieldDescriptor* parent_field_descriptor,
+ CMessageClass* concrete_class) {
+ if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) {
+ return NULL;
+ }
+
+ RepeatedCompositeContainer* self =
+ reinterpret_cast<RepeatedCompositeContainer*>(
+ PyType_GenericAlloc(&RepeatedCompositeContainer_Type, 0));
+ if (self == NULL) {
+ return NULL;
+ }
+
+ self->message = parent->message;
+ self->parent = parent;
+ self->parent_field_descriptor = parent_field_descriptor;
+ self->owner = parent->owner;
+ Py_INCREF(concrete_class);
+ self->child_message_class = concrete_class;
self->child_messages = PyList_New(0);
- return 0;
+
+ return reinterpret_cast<PyObject*>(self);
}
static void Dealloc(RepeatedCompositeContainer* self) {
Py_CLEAR(self->child_messages);
+ Py_CLEAR(self->child_message_class);
// TODO(tibell): Do we need to call delete on these objects to make
// sure their destructors are called?
self->owner.reset();
+
Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
}
@@ -707,6 +555,8 @@ static PyMethodDef Methods[] = {
"Adds an object to the repeated container." },
{ "extend", (PyCFunction) Extend, METH_O,
"Adds objects to the repeated container." },
+ { "pop", (PyCFunction)Pop, METH_VARARGS,
+ "Removes an object from the repeated container and returns it." },
{ "remove", (PyCFunction) Remove, METH_O,
"Removes an object from the repeated container." },
{ "sort", (PyCFunction) Sort, METH_VARARGS | METH_KEYWORDS,
@@ -720,9 +570,8 @@ static PyMethodDef Methods[] = {
PyTypeObject RepeatedCompositeContainer_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0)
- "google.protobuf.internal."
- "cpp._message.RepeatedCompositeContainer", // tp_name
- sizeof(RepeatedCompositeContainer), // tp_basicsize
+ FULL_MODULE_NAME ".RepeatedCompositeContainer", // tp_name
+ sizeof(RepeatedCompositeContainer), // tp_basicsize
0, // tp_itemsize
(destructor)repeated_composite_container::Dealloc, // tp_dealloc
0, // tp_print
@@ -733,7 +582,7 @@ PyTypeObject RepeatedCompositeContainer_Type = {
0, // tp_as_number
&repeated_composite_container::SqMethods, // tp_as_sequence
&repeated_composite_container::MpMethods, // tp_as_mapping
- 0, // tp_hash
+ PyObject_HashNotImplemented, // tp_hash
0, // tp_call
0, // tp_str
0, // tp_getattro
@@ -755,7 +604,7 @@ PyTypeObject RepeatedCompositeContainer_Type = {
0, // tp_descr_get
0, // tp_descr_set
0, // tp_dictoffset
- (initproc)repeated_composite_container::Init, // tp_init
+ 0, // tp_init
};
} // namespace python
diff --git a/python/google/protobuf/pyext/repeated_composite_container.h b/python/google/protobuf/pyext/repeated_composite_container.h
index 898ef5a71..a7b56b61b 100644
--- a/python/google/protobuf/pyext/repeated_composite_container.h
+++ b/python/google/protobuf/pyext/repeated_composite_container.h
@@ -43,30 +43,33 @@
#include <string>
#include <vector>
-
namespace google {
namespace protobuf {
class FieldDescriptor;
class Message;
+#ifdef _SHARED_PTR_H
+using std::shared_ptr;
+#else
using internal::shared_ptr;
+#endif
namespace python {
struct CMessage;
-struct CFieldDescriptor;
+struct CMessageClass;
// A RepeatedCompositeContainer can be in one of two states: attached
// or released.
//
// When in the attached state all modifications to the container are
// done both on the 'message' and on the 'child_messages'
-// list. In this state all Messages refered to by the children in
+// list. In this state all Messages referred to by the children in
// 'child_messages' are owner by the 'owner'.
//
// When in the released state 'message', 'owner', 'parent', and
-// 'parent_field' are NULL.
+// 'parent_field_descriptor' are NULL.
typedef struct RepeatedCompositeContainer {
PyObject_HEAD;
@@ -82,7 +85,8 @@ typedef struct RepeatedCompositeContainer {
CMessage* parent;
// A descriptor used to modify the underlying 'message'.
- CFieldDescriptor* parent_field;
+ // The pointer is owned by the global DescriptorPool.
+ const FieldDescriptor* parent_field_descriptor;
// Pointer to the C++ Message that contains this container. The
// RepeatedCompositeContainer does not own this pointer.
@@ -91,8 +95,8 @@ typedef struct RepeatedCompositeContainer {
// calling Clear() or ClearField() on the parent.
Message* message;
- // A callable that is used to create new child messages.
- PyObject* subclass_init;
+ // The type used to create new child messages.
+ CMessageClass* child_message_class;
// A list of child messages.
PyObject* child_messages;
@@ -102,8 +106,12 @@ extern PyTypeObject RepeatedCompositeContainer_Type;
namespace repeated_composite_container {
-// Returns the number of items in this repeated composite container.
-static Py_ssize_t Length(RepeatedCompositeContainer* self);
+// Builds a RepeatedCompositeContainer object, from a parent message and a
+// field descriptor.
+PyObject *NewContainer(
+ CMessage* parent,
+ const FieldDescriptor* parent_field_descriptor,
+ CMessageClass *child_message_class);
// Appends a new CMessage to the container and returns it. The
// CMessage is initialized using the content of kwargs.
@@ -143,8 +151,7 @@ int AssignSubscript(RepeatedCompositeContainer* self,
// Releases the messages in the container to the given message.
//
// Returns 0 on success, -1 on failure.
-int ReleaseToMessage(RepeatedCompositeContainer* self,
- google::protobuf::Message* new_message);
+int ReleaseToMessage(RepeatedCompositeContainer* self, Message* new_message);
// Releases the messages in the container to a new message.
//
@@ -156,13 +163,13 @@ int SetOwner(RepeatedCompositeContainer* self,
const shared_ptr<Message>& new_owner);
// Removes the last element of the repeated message field 'field' on
-// the Message 'message', and transfers the ownership of the released
-// Message to 'cmessage'.
+// the Message 'parent', and transfers the ownership of the released
+// Message to 'target'.
//
// Corresponds to reflection api method ReleaseMessage.
-void ReleaseLastTo(const FieldDescriptor* field,
- Message* message,
- CMessage* cmessage);
+void ReleaseLastTo(CMessage* parent,
+ const FieldDescriptor* field,
+ CMessage* target);
} // namespace repeated_composite_container
} // namespace python
diff --git a/python/google/protobuf/pyext/repeated_scalar_container.cc b/python/google/protobuf/pyext/repeated_scalar_container.cc
index e627d37dc..95da85f87 100644
--- a/python/google/protobuf/pyext/repeated_scalar_container.cc
+++ b/python/google/protobuf/pyext/repeated_scalar_container.cc
@@ -39,10 +39,12 @@
#endif
#include <google/protobuf/stubs/common.h>
+#include <google/protobuf/stubs/logging.h>
#include <google/protobuf/descriptor.h>
#include <google/protobuf/dynamic_message.h>
#include <google/protobuf/message.h>
#include <google/protobuf/pyext/descriptor.h>
+#include <google/protobuf/pyext/descriptor_pool.h>
#include <google/protobuf/pyext/message.h>
#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
@@ -52,7 +54,7 @@
#error "Python 3.0 - 3.2 are not supported."
#else
#define PyString_AsString(ob) \
- (PyUnicode_Check(ob)? PyUnicode_AsUTF8(ob): PyBytes_AS_STRING(ob))
+ (PyUnicode_Check(ob)? PyUnicode_AsUTF8(ob): PyBytes_AsString(ob))
#endif
#endif
@@ -60,17 +62,15 @@ namespace google {
namespace protobuf {
namespace python {
-extern google::protobuf::DynamicMessageFactory* global_message_factory;
-
namespace repeated_scalar_container {
static int InternalAssignRepeatedField(
RepeatedScalarContainer* self, PyObject* list) {
self->message->GetReflection()->ClearField(self->message,
- self->parent_field->descriptor);
+ self->parent_field_descriptor);
for (Py_ssize_t i = 0; i < PyList_GET_SIZE(list); ++i) {
PyObject* value = PyList_GET_ITEM(list, i);
- if (Append(self, value) == NULL) {
+ if (ScopedPyObjectPtr(Append(self, value)) == NULL) {
return -1;
}
}
@@ -78,25 +78,19 @@ static int InternalAssignRepeatedField(
}
static Py_ssize_t Len(RepeatedScalarContainer* self) {
- google::protobuf::Message* message = self->message;
+ Message* message = self->message;
return message->GetReflection()->FieldSize(*message,
- self->parent_field->descriptor);
+ self->parent_field_descriptor);
}
static int AssignItem(RepeatedScalarContainer* self,
Py_ssize_t index,
PyObject* arg) {
cmessage::AssureWritable(self->parent);
- google::protobuf::Message* message = self->message;
- const google::protobuf::FieldDescriptor* field_descriptor =
- self->parent_field->descriptor;
- if (!FIELD_BELONGS_TO_MESSAGE(field_descriptor, message)) {
- PyErr_SetString(
- PyExc_KeyError, "Field does not belong to message!");
- return -1;
- }
+ Message* message = self->message;
+ const FieldDescriptor* field_descriptor = self->parent_field_descriptor;
- const google::protobuf::Reflection* reflection = message->GetReflection();
+ const Reflection* reflection = message->GetReflection();
int field_size = reflection->FieldSize(*message, field_descriptor);
if (index < 0) {
index = field_size + index;
@@ -110,8 +104,8 @@ static int AssignItem(RepeatedScalarContainer* self,
if (arg == NULL) {
ScopedPyObjectPtr py_index(PyLong_FromLong(index));
- return cmessage::InternalDeleteRepeatedField(message, field_descriptor,
- py_index, NULL);
+ return cmessage::InternalDeleteRepeatedField(self->parent, field_descriptor,
+ py_index.get(), NULL);
}
if (PySequence_Check(arg) && !(PyBytes_Check(arg) || PyUnicode_Check(arg))) {
@@ -120,64 +114,68 @@ static int AssignItem(RepeatedScalarContainer* self,
}
switch (field_descriptor->cpp_type()) {
- case google::protobuf::FieldDescriptor::CPPTYPE_INT32: {
+ case FieldDescriptor::CPPTYPE_INT32: {
GOOGLE_CHECK_GET_INT32(arg, value, -1);
reflection->SetRepeatedInt32(message, field_descriptor, index, value);
break;
}
- case google::protobuf::FieldDescriptor::CPPTYPE_INT64: {
+ case FieldDescriptor::CPPTYPE_INT64: {
GOOGLE_CHECK_GET_INT64(arg, value, -1);
reflection->SetRepeatedInt64(message, field_descriptor, index, value);
break;
}
- case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: {
+ case FieldDescriptor::CPPTYPE_UINT32: {
GOOGLE_CHECK_GET_UINT32(arg, value, -1);
reflection->SetRepeatedUInt32(message, field_descriptor, index, value);
break;
}
- case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: {
+ case FieldDescriptor::CPPTYPE_UINT64: {
GOOGLE_CHECK_GET_UINT64(arg, value, -1);
reflection->SetRepeatedUInt64(message, field_descriptor, index, value);
break;
}
- case google::protobuf::FieldDescriptor::CPPTYPE_FLOAT: {
+ case FieldDescriptor::CPPTYPE_FLOAT: {
GOOGLE_CHECK_GET_FLOAT(arg, value, -1);
reflection->SetRepeatedFloat(message, field_descriptor, index, value);
break;
}
- case google::protobuf::FieldDescriptor::CPPTYPE_DOUBLE: {
+ case FieldDescriptor::CPPTYPE_DOUBLE: {
GOOGLE_CHECK_GET_DOUBLE(arg, value, -1);
reflection->SetRepeatedDouble(message, field_descriptor, index, value);
break;
}
- case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: {
+ case FieldDescriptor::CPPTYPE_BOOL: {
GOOGLE_CHECK_GET_BOOL(arg, value, -1);
reflection->SetRepeatedBool(message, field_descriptor, index, value);
break;
}
- case google::protobuf::FieldDescriptor::CPPTYPE_STRING: {
+ case FieldDescriptor::CPPTYPE_STRING: {
if (!CheckAndSetString(
arg, message, field_descriptor, reflection, false, index)) {
return -1;
}
break;
}
- case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: {
+ case FieldDescriptor::CPPTYPE_ENUM: {
GOOGLE_CHECK_GET_INT32(arg, value, -1);
- const google::protobuf::EnumDescriptor* enum_descriptor =
- field_descriptor->enum_type();
- const google::protobuf::EnumValueDescriptor* enum_value =
- enum_descriptor->FindValueByNumber(value);
- if (enum_value != NULL) {
- reflection->SetRepeatedEnum(message, field_descriptor, index,
- enum_value);
+ if (reflection->SupportsUnknownEnumValues()) {
+ reflection->SetRepeatedEnumValue(message, field_descriptor, index,
+ value);
} else {
- ScopedPyObjectPtr s(PyObject_Str(arg));
- if (s != NULL) {
- PyErr_Format(PyExc_ValueError, "Unknown enum value: %s",
- PyString_AsString(s.get()));
+ const EnumDescriptor* enum_descriptor = field_descriptor->enum_type();
+ const EnumValueDescriptor* enum_value =
+ enum_descriptor->FindValueByNumber(value);
+ if (enum_value != NULL) {
+ reflection->SetRepeatedEnum(message, field_descriptor, index,
+ enum_value);
+ } else {
+ ScopedPyObjectPtr s(PyObject_Str(arg));
+ if (s != NULL) {
+ PyErr_Format(PyExc_ValueError, "Unknown enum value: %s",
+ PyString_AsString(s.get()));
+ }
+ return -1;
}
- return -1;
}
break;
}
@@ -191,10 +189,9 @@ static int AssignItem(RepeatedScalarContainer* self,
}
static PyObject* Item(RepeatedScalarContainer* self, Py_ssize_t index) {
- google::protobuf::Message* message = self->message;
- const google::protobuf::FieldDescriptor* field_descriptor =
- self->parent_field->descriptor;
- const google::protobuf::Reflection* reflection = message->GetReflection();
+ Message* message = self->message;
+ const FieldDescriptor* field_descriptor = self->parent_field_descriptor;
+ const Reflection* reflection = message->GetReflection();
int field_size = reflection->FieldSize(*message, field_descriptor);
if (index < 0) {
@@ -202,80 +199,80 @@ static PyObject* Item(RepeatedScalarContainer* self, Py_ssize_t index) {
}
if (index < 0 || index >= field_size) {
PyErr_Format(PyExc_IndexError,
- "list assignment index (%d) out of range",
- static_cast<int>(index));
+ "list index (%zd) out of range",
+ index);
return NULL;
}
PyObject* result = NULL;
switch (field_descriptor->cpp_type()) {
- case google::protobuf::FieldDescriptor::CPPTYPE_INT32: {
+ case FieldDescriptor::CPPTYPE_INT32: {
int32 value = reflection->GetRepeatedInt32(
*message, field_descriptor, index);
result = PyInt_FromLong(value);
break;
}
- case google::protobuf::FieldDescriptor::CPPTYPE_INT64: {
+ case FieldDescriptor::CPPTYPE_INT64: {
int64 value = reflection->GetRepeatedInt64(
*message, field_descriptor, index);
result = PyLong_FromLongLong(value);
break;
}
- case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: {
+ case FieldDescriptor::CPPTYPE_UINT32: {
uint32 value = reflection->GetRepeatedUInt32(
*message, field_descriptor, index);
result = PyLong_FromLongLong(value);
break;
}
- case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: {
+ case FieldDescriptor::CPPTYPE_UINT64: {
uint64 value = reflection->GetRepeatedUInt64(
*message, field_descriptor, index);
result = PyLong_FromUnsignedLongLong(value);
break;
}
- case google::protobuf::FieldDescriptor::CPPTYPE_FLOAT: {
+ case FieldDescriptor::CPPTYPE_FLOAT: {
float value = reflection->GetRepeatedFloat(
*message, field_descriptor, index);
result = PyFloat_FromDouble(value);
break;
}
- case google::protobuf::FieldDescriptor::CPPTYPE_DOUBLE: {
+ case FieldDescriptor::CPPTYPE_DOUBLE: {
double value = reflection->GetRepeatedDouble(
*message, field_descriptor, index);
result = PyFloat_FromDouble(value);
break;
}
- case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: {
+ case FieldDescriptor::CPPTYPE_BOOL: {
bool value = reflection->GetRepeatedBool(
*message, field_descriptor, index);
result = PyBool_FromLong(value ? 1 : 0);
break;
}
- case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: {
- const google::protobuf::EnumValueDescriptor* enum_value =
+ case FieldDescriptor::CPPTYPE_ENUM: {
+ const EnumValueDescriptor* enum_value =
message->GetReflection()->GetRepeatedEnum(
*message, field_descriptor, index);
result = PyInt_FromLong(enum_value->number());
break;
}
- case google::protobuf::FieldDescriptor::CPPTYPE_STRING: {
+ case FieldDescriptor::CPPTYPE_STRING: {
string value = reflection->GetRepeatedString(
*message, field_descriptor, index);
result = ToStringObject(field_descriptor, value);
break;
}
- case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: {
+ case FieldDescriptor::CPPTYPE_MESSAGE: {
PyObject* py_cmsg = PyObject_CallObject(reinterpret_cast<PyObject*>(
&CMessage_Type), NULL);
if (py_cmsg == NULL) {
return NULL;
}
CMessage* cmsg = reinterpret_cast<CMessage*>(py_cmsg);
- const google::protobuf::Message& msg = reflection->GetRepeatedMessage(
+ const Message& msg = reflection->GetRepeatedMessage(
*message, field_descriptor, index);
cmsg->owner = self->owner;
cmsg->parent = self->parent;
- cmsg->message = const_cast<google::protobuf::Message*>(&msg);
+ cmsg->message = const_cast<Message*>(&msg);
cmsg->read_only = false;
result = reinterpret_cast<PyObject*>(py_cmsg);
break;
@@ -337,7 +334,7 @@ static PyObject* Subscript(RepeatedScalarContainer* self, PyObject* slice) {
break;
}
ScopedPyObjectPtr s(Item(self, index));
- PyList_Append(list, s);
+ PyList_Append(list, s.get());
}
} else {
if (step > 0) {
@@ -348,7 +345,7 @@ static PyObject* Subscript(RepeatedScalarContainer* self, PyObject* slice) {
break;
}
ScopedPyObjectPtr s(Item(self, index));
- PyList_Append(list, s);
+ PyList_Append(list, s.get());
}
}
return list;
@@ -356,75 +353,71 @@ static PyObject* Subscript(RepeatedScalarContainer* self, PyObject* slice) {
PyObject* Append(RepeatedScalarContainer* self, PyObject* item) {
cmessage::AssureWritable(self->parent);
- google::protobuf::Message* message = self->message;
- const google::protobuf::FieldDescriptor* field_descriptor =
- self->parent_field->descriptor;
-
- if (!FIELD_BELONGS_TO_MESSAGE(field_descriptor, message)) {
- PyErr_SetString(
- PyExc_KeyError, "Field does not belong to message!");
- return NULL;
- }
+ Message* message = self->message;
+ const FieldDescriptor* field_descriptor = self->parent_field_descriptor;
- const google::protobuf::Reflection* reflection = message->GetReflection();
+ const Reflection* reflection = message->GetReflection();
switch (field_descriptor->cpp_type()) {
- case google::protobuf::FieldDescriptor::CPPTYPE_INT32: {
+ case FieldDescriptor::CPPTYPE_INT32: {
GOOGLE_CHECK_GET_INT32(item, value, NULL);
reflection->AddInt32(message, field_descriptor, value);
break;
}
- case google::protobuf::FieldDescriptor::CPPTYPE_INT64: {
+ case FieldDescriptor::CPPTYPE_INT64: {
GOOGLE_CHECK_GET_INT64(item, value, NULL);
reflection->AddInt64(message, field_descriptor, value);
break;
}
- case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: {
+ case FieldDescriptor::CPPTYPE_UINT32: {
GOOGLE_CHECK_GET_UINT32(item, value, NULL);
reflection->AddUInt32(message, field_descriptor, value);
break;
}
- case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: {
+ case FieldDescriptor::CPPTYPE_UINT64: {
GOOGLE_CHECK_GET_UINT64(item, value, NULL);
reflection->AddUInt64(message, field_descriptor, value);
break;
}
- case google::protobuf::FieldDescriptor::CPPTYPE_FLOAT: {
+ case FieldDescriptor::CPPTYPE_FLOAT: {
GOOGLE_CHECK_GET_FLOAT(item, value, NULL);
reflection->AddFloat(message, field_descriptor, value);
break;
}
- case google::protobuf::FieldDescriptor::CPPTYPE_DOUBLE: {
+ case FieldDescriptor::CPPTYPE_DOUBLE: {
GOOGLE_CHECK_GET_DOUBLE(item, value, NULL);
reflection->AddDouble(message, field_descriptor, value);
break;
}
- case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: {
+ case FieldDescriptor::CPPTYPE_BOOL: {
GOOGLE_CHECK_GET_BOOL(item, value, NULL);
reflection->AddBool(message, field_descriptor, value);
break;
}
- case google::protobuf::FieldDescriptor::CPPTYPE_STRING: {
+ case FieldDescriptor::CPPTYPE_STRING: {
if (!CheckAndSetString(
item, message, field_descriptor, reflection, true, -1)) {
return NULL;
}
break;
}
- case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: {
+ case FieldDescriptor::CPPTYPE_ENUM: {
GOOGLE_CHECK_GET_INT32(item, value, NULL);
- const google::protobuf::EnumDescriptor* enum_descriptor =
- field_descriptor->enum_type();
- const google::protobuf::EnumValueDescriptor* enum_value =
- enum_descriptor->FindValueByNumber(value);
- if (enum_value != NULL) {
- reflection->AddEnum(message, field_descriptor, enum_value);
+ if (reflection->SupportsUnknownEnumValues()) {
+ reflection->AddEnumValue(message, field_descriptor, value);
} else {
- ScopedPyObjectPtr s(PyObject_Str(item));
- if (s != NULL) {
- PyErr_Format(PyExc_ValueError, "Unknown enum value: %s",
- PyString_AsString(s.get()));
+ const EnumDescriptor* enum_descriptor = field_descriptor->enum_type();
+ const EnumValueDescriptor* enum_value =
+ enum_descriptor->FindValueByNumber(value);
+ if (enum_value != NULL) {
+ reflection->AddEnum(message, field_descriptor, enum_value);
+ } else {
+ ScopedPyObjectPtr s(PyObject_Str(item));
+ if (s != NULL) {
+ PyErr_Format(PyExc_ValueError, "Unknown enum value: %s",
+ PyString_AsString(s.get()));
+ }
+ return NULL;
}
- return NULL;
}
break;
}
@@ -449,9 +442,9 @@ static int AssSubscript(RepeatedScalarContainer* self,
bool create_list = false;
cmessage::AssureWritable(self->parent);
- google::protobuf::Message* message = self->message;
- const google::protobuf::FieldDescriptor* field_descriptor =
- self->parent_field->descriptor;
+ Message* message = self->message;
+ const FieldDescriptor* field_descriptor =
+ self->parent_field_descriptor;
#if PY_MAJOR_VERSION < 3
if (PyInt_Check(slice)) {
@@ -461,7 +454,7 @@ static int AssSubscript(RepeatedScalarContainer* self,
if (PyLong_Check(slice)) {
from = to = PyLong_AsLong(slice);
} else if (PySlice_Check(slice)) {
- const google::protobuf::Reflection* reflection = message->GetReflection();
+ const Reflection* reflection = message->GetReflection();
length = reflection->FieldSize(*message, field_descriptor);
#if PY_MAJOR_VERSION >= 3
if (PySlice_GetIndicesEx(slice,
@@ -479,7 +472,7 @@ static int AssSubscript(RepeatedScalarContainer* self,
if (value == NULL) {
return cmessage::InternalDeleteRepeatedField(
- message, field_descriptor, slice, NULL);
+ self->parent, field_descriptor, slice, NULL);
}
if (!create_list) {
@@ -490,30 +483,36 @@ static int AssSubscript(RepeatedScalarContainer* self,
if (full_slice == NULL) {
return -1;
}
- ScopedPyObjectPtr new_list(Subscript(self, full_slice));
+ ScopedPyObjectPtr new_list(Subscript(self, full_slice.get()));
if (new_list == NULL) {
return -1;
}
- if (PySequence_SetSlice(new_list, from, to, value) < 0) {
+ if (PySequence_SetSlice(new_list.get(), from, to, value) < 0) {
return -1;
}
- return InternalAssignRepeatedField(self, new_list);
+ return InternalAssignRepeatedField(self, new_list.get());
}
PyObject* Extend(RepeatedScalarContainer* self, PyObject* value) {
cmessage::AssureWritable(self->parent);
- if (PyObject_Not(value)) {
+
+ // TODO(ptucker): Deprecate this behavior. b/18413862
+ if (value == Py_None) {
+ Py_RETURN_NONE;
+ }
+ if ((Py_TYPE(value)->tp_as_sequence == NULL) && PyObject_Not(value)) {
Py_RETURN_NONE;
}
+
ScopedPyObjectPtr iter(PyObject_GetIter(value));
if (iter == NULL) {
PyErr_SetString(PyExc_TypeError, "Value must be iterable");
return NULL;
}
ScopedPyObjectPtr next;
- while ((next.reset(PyIter_Next(iter))) != NULL) {
- if (Append(self, next) == NULL) {
+ while ((next.reset(PyIter_Next(iter.get()))) != NULL) {
+ if (ScopedPyObjectPtr(Append(self, next.get())) == NULL) {
return NULL;
}
}
@@ -530,11 +529,11 @@ static PyObject* Insert(RepeatedScalarContainer* self, PyObject* args) {
return NULL;
}
ScopedPyObjectPtr full_slice(PySlice_New(NULL, NULL, NULL));
- ScopedPyObjectPtr new_list(Subscript(self, full_slice));
- if (PyList_Insert(new_list, index, value) < 0) {
+ ScopedPyObjectPtr new_list(Subscript(self, full_slice.get()));
+ if (PyList_Insert(new_list.get(), index, value) < 0) {
return NULL;
}
- int ret = InternalAssignRepeatedField(self, new_list);
+ int ret = InternalAssignRepeatedField(self, new_list.get());
if (ret < 0) {
return NULL;
}
@@ -545,7 +544,7 @@ static PyObject* Remove(RepeatedScalarContainer* self, PyObject* value) {
Py_ssize_t match_index = -1;
for (Py_ssize_t i = 0; i < Len(self); ++i) {
ScopedPyObjectPtr elem(Item(self, i));
- if (PyObject_RichCompareBool(elem, value, Py_EQ)) {
+ if (PyObject_RichCompareBool(elem.get(), value, Py_EQ)) {
match_index = i;
break;
}
@@ -580,15 +579,15 @@ static PyObject* RichCompare(RepeatedScalarContainer* self,
ScopedPyObjectPtr other_list_deleter;
if (PyObject_TypeCheck(other, &RepeatedScalarContainer_Type)) {
other_list_deleter.reset(Subscript(
- reinterpret_cast<RepeatedScalarContainer*>(other), full_slice));
+ reinterpret_cast<RepeatedScalarContainer*>(other), full_slice.get()));
other = other_list_deleter.get();
}
- ScopedPyObjectPtr list(Subscript(self, full_slice));
+ ScopedPyObjectPtr list(Subscript(self, full_slice.get()));
if (list == NULL) {
return NULL;
}
- return PyObject_RichCompare(list, other, opid);
+ return PyObject_RichCompare(list.get(), other, opid);
}
PyObject* Reduce(RepeatedScalarContainer* unused_self) {
@@ -619,66 +618,63 @@ static PyObject* Sort(RepeatedScalarContainer* self,
if (full_slice == NULL) {
return NULL;
}
- ScopedPyObjectPtr list(Subscript(self, full_slice));
+ ScopedPyObjectPtr list(Subscript(self, full_slice.get()));
if (list == NULL) {
return NULL;
}
- ScopedPyObjectPtr m(PyObject_GetAttrString(list, "sort"));
+ ScopedPyObjectPtr m(PyObject_GetAttrString(list.get(), "sort"));
if (m == NULL) {
return NULL;
}
- ScopedPyObjectPtr res(PyObject_Call(m, args, kwds));
+ ScopedPyObjectPtr res(PyObject_Call(m.get(), args, kwds));
if (res == NULL) {
return NULL;
}
- int ret = InternalAssignRepeatedField(self, list);
+ int ret = InternalAssignRepeatedField(self, list.get());
if (ret < 0) {
return NULL;
}
Py_RETURN_NONE;
}
-static int Init(RepeatedScalarContainer* self,
- PyObject* args,
- PyObject* kwargs) {
- PyObject* py_parent;
- PyObject* py_parent_field;
- if (!PyArg_UnpackTuple(args, "__init__()", 2, 2, &py_parent,
- &py_parent_field)) {
- return -1;
+static PyObject* Pop(RepeatedScalarContainer* self,
+ PyObject* args) {
+ Py_ssize_t index = -1;
+ if (!PyArg_ParseTuple(args, "|n", &index)) {
+ return NULL;
}
-
- if (!PyObject_TypeCheck(py_parent, &CMessage_Type)) {
- PyErr_Format(PyExc_TypeError,
- "expect %s, but got %s",
- CMessage_Type.tp_name,
- Py_TYPE(py_parent)->tp_name);
- return -1;
+ PyObject* item = Item(self, index);
+ if (item == NULL) {
+ PyErr_Format(PyExc_IndexError,
+ "list index (%zd) out of range",
+ index);
+ return NULL;
}
-
- if (!PyObject_TypeCheck(py_parent_field, &CFieldDescriptor_Type)) {
- PyErr_Format(PyExc_TypeError,
- "expect %s, but got %s",
- CFieldDescriptor_Type.tp_name,
- Py_TYPE(py_parent_field)->tp_name);
- return -1;
+ if (AssignItem(self, index, NULL) < 0) {
+ return NULL;
}
+ return item;
+}
- CMessage* cmessage = reinterpret_cast<CMessage*>(py_parent);
- CFieldDescriptor* cdescriptor = reinterpret_cast<CFieldDescriptor*>(
- py_parent_field);
+// The private constructor of RepeatedScalarContainer objects.
+PyObject *NewContainer(
+ CMessage* parent, const FieldDescriptor* parent_field_descriptor) {
+ if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) {
+ return NULL;
+ }
- if (!FIELD_BELONGS_TO_MESSAGE(cdescriptor->descriptor, cmessage->message)) {
- PyErr_SetString(
- PyExc_KeyError, "Field does not belong to message!");
- return -1;
+ RepeatedScalarContainer* self = reinterpret_cast<RepeatedScalarContainer*>(
+ PyType_GenericAlloc(&RepeatedScalarContainer_Type, 0));
+ if (self == NULL) {
+ return NULL;
}
- self->message = cmessage->message;
- self->parent = cmessage;
- self->parent_field = cdescriptor;
- self->owner = cmessage->owner;
- return 0;
+ self->message = parent->message;
+ self->parent = parent;
+ self->parent_field_descriptor = parent_field_descriptor;
+ self->owner = parent->owner;
+
+ return reinterpret_cast<PyObject*>(self);
}
// Initializes the underlying Message object of "to" so it becomes a new parent
@@ -692,20 +688,16 @@ static int InitializeAndCopyToParentContainer(
if (full_slice == NULL) {
return -1;
}
- ScopedPyObjectPtr values(Subscript(from, full_slice));
+ ScopedPyObjectPtr values(Subscript(from, full_slice.get()));
if (values == NULL) {
return -1;
}
- google::protobuf::Message* new_message = global_message_factory->GetPrototype(
- from->message->GetDescriptor())->New();
+ Message* new_message = from->message->New();
to->parent = NULL;
- // TODO(anuraag): Document why it's OK to hang on to parent_field,
- // even though it's a weak reference. It ought to be enough to
- // hold on to the FieldDescriptor only.
- to->parent_field = from->parent_field;
+ to->parent_field_descriptor = from->parent_field_descriptor;
to->message = new_message;
to->owner.reset(new_message);
- if (InternalAssignRepeatedField(to, values) < 0) {
+ if (InternalAssignRepeatedField(to, values.get()) < 0) {
return -1;
}
return 0;
@@ -716,23 +708,17 @@ int Release(RepeatedScalarContainer* self) {
}
PyObject* DeepCopy(RepeatedScalarContainer* self, PyObject* arg) {
- ScopedPyObjectPtr init_args(
- PyTuple_Pack(2, self->parent, self->parent_field));
- PyObject* clone = PyObject_CallObject(
- reinterpret_cast<PyObject*>(&RepeatedScalarContainer_Type), init_args);
+ RepeatedScalarContainer* clone = reinterpret_cast<RepeatedScalarContainer*>(
+ PyType_GenericAlloc(&RepeatedScalarContainer_Type, 0));
if (clone == NULL) {
return NULL;
}
- if (!PyObject_TypeCheck(clone, &RepeatedScalarContainer_Type)) {
- Py_DECREF(clone);
- return NULL;
- }
- if (InitializeAndCopyToParentContainer(
- self, reinterpret_cast<RepeatedScalarContainer*>(clone)) < 0) {
+
+ if (InitializeAndCopyToParentContainer(self, clone) < 0) {
Py_DECREF(clone);
return NULL;
}
- return clone;
+ return reinterpret_cast<PyObject*>(clone);
}
static void Dealloc(RepeatedScalarContainer* self) {
@@ -771,6 +757,8 @@ static PyMethodDef Methods[] = {
"Appends objects to the repeated container." },
{ "insert", (PyCFunction)Insert, METH_VARARGS,
"Appends objects to the repeated container." },
+ { "pop", (PyCFunction)Pop, METH_VARARGS,
+ "Removes an object from the repeated container and returns it." },
{ "remove", (PyCFunction)Remove, METH_O,
"Removes an object from the repeated container." },
{ "sort", (PyCFunction)Sort, METH_VARARGS | METH_KEYWORDS,
@@ -782,8 +770,7 @@ static PyMethodDef Methods[] = {
PyTypeObject RepeatedScalarContainer_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0)
- "google.protobuf.internal."
- "cpp._message.RepeatedScalarContainer", // tp_name
+ FULL_MODULE_NAME ".RepeatedScalarContainer", // tp_name
sizeof(RepeatedScalarContainer), // tp_basicsize
0, // tp_itemsize
(destructor)repeated_scalar_container::Dealloc, // tp_dealloc
@@ -795,7 +782,7 @@ PyTypeObject RepeatedScalarContainer_Type = {
0, // tp_as_number
&repeated_scalar_container::SqMethods, // tp_as_sequence
&repeated_scalar_container::MpMethods, // tp_as_mapping
- 0, // tp_hash
+ PyObject_HashNotImplemented, // tp_hash
0, // tp_call
0, // tp_str
0, // tp_getattro
@@ -817,7 +804,7 @@ PyTypeObject RepeatedScalarContainer_Type = {
0, // tp_descr_get
0, // tp_descr_set
0, // tp_dictoffset
- (initproc)repeated_scalar_container::Init, // tp_init
+ 0, // tp_init
};
} // namespace python
diff --git a/python/google/protobuf/pyext/repeated_scalar_container.h b/python/google/protobuf/pyext/repeated_scalar_container.h
index 69d15d5cd..555e621c9 100644
--- a/python/google/protobuf/pyext/repeated_scalar_container.h
+++ b/python/google/protobuf/pyext/repeated_scalar_container.h
@@ -41,17 +41,21 @@
#include <google/protobuf/stubs/shared_ptr.h>
#endif
+#include <google/protobuf/descriptor.h>
namespace google {
namespace protobuf {
class Message;
+#ifdef _SHARED_PTR_H
+using std::shared_ptr;
+#else
using internal::shared_ptr;
+#endif
namespace python {
-struct CFieldDescriptor;
struct CMessage;
typedef struct RepeatedScalarContainer {
@@ -73,16 +77,22 @@ typedef struct RepeatedScalarContainer {
// modifying the container.
CMessage* parent;
- // Weak reference to the parent's descriptor that describes this
+ // Pointer to the parent's descriptor that describes this
// field. Used together with the parent's message when making a
// default message instance mutable.
- CFieldDescriptor* parent_field;
+ // The pointer is owned by the global DescriptorPool.
+ const FieldDescriptor* parent_field_descriptor;
} RepeatedScalarContainer;
extern PyTypeObject RepeatedScalarContainer_Type;
namespace repeated_scalar_container {
+// Builds a RepeatedScalarContainer object, from a parent message and a
+// field descriptor.
+extern PyObject *NewContainer(
+ CMessage* parent, const FieldDescriptor* parent_field_descriptor);
+
// Appends the scalar 'item' to the end of the container 'self'.
//
// Returns None if successful; returns NULL and sets an exception if
diff --git a/python/google/protobuf/pyext/scoped_pyobject_ptr.h b/python/google/protobuf/pyext/scoped_pyobject_ptr.h
index 9f337c3c8..a128cd4c6 100644
--- a/python/google/protobuf/pyext/scoped_pyobject_ptr.h
+++ b/python/google/protobuf/pyext/scoped_pyobject_ptr.h
@@ -33,12 +33,14 @@
#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_SCOPED_PYOBJECT_PTR_H__
#define GOOGLE_PROTOBUF_PYTHON_CPP_SCOPED_PYOBJECT_PTR_H__
+#include <google/protobuf/stubs/common.h>
+
#include <Python.h>
namespace google {
class ScopedPyObjectPtr {
public:
- // Constructor. Defaults to intializing with NULL.
+ // Constructor. Defaults to initializing with NULL.
// There is no way to create an uninitialized ScopedPyObjectPtr.
explicit ScopedPyObjectPtr(PyObject* p = NULL) : ptr_(p) { }
@@ -49,24 +51,23 @@ class ScopedPyObjectPtr {
// Reset. Deletes the current owned object, if any.
// Then takes ownership of a new object, if given.
- // this->reset(this->get()) works.
+ // This function must be called with a reference that you own.
+ // this->reset(this->get()) is wrong!
+ // this->reset(this->release()) is OK.
PyObject* reset(PyObject* p = NULL) {
- if (p != ptr_) {
- Py_XDECREF(ptr_);
- ptr_ = p;
- }
+ Py_XDECREF(ptr_);
+ ptr_ = p;
return ptr_;
}
// Releases ownership of the object.
+ // The caller now owns the returned reference.
PyObject* release() {
PyObject* p = ptr_;
ptr_ = NULL;
return p;
}
- operator PyObject*() { return ptr_; }
-
PyObject* operator->() const {
assert(ptr_ != NULL);
return ptr_;
diff --git a/python/google/protobuf/reflection.py b/python/google/protobuf/reflection.py
index 1fc704a23..0c757264f 100755
--- a/python/google/protobuf/reflection.py
+++ b/python/google/protobuf/reflection.py
@@ -49,108 +49,23 @@ __author__ = 'robinson@google.com (Will Robinson)'
from google.protobuf.internal import api_implementation
-from google.protobuf import descriptor as descriptor_mod
from google.protobuf import message
-_FieldDescriptor = descriptor_mod.FieldDescriptor
-
if api_implementation.Type() == 'cpp':
- if api_implementation.Version() == 2:
- from google.protobuf.pyext import cpp_message
- _NewMessage = cpp_message.NewMessage
- _InitMessage = cpp_message.InitMessage
- else:
- from google.protobuf.internal import cpp_message
- _NewMessage = cpp_message.NewMessage
- _InitMessage = cpp_message.InitMessage
+ from google.protobuf.pyext import cpp_message as message_impl
else:
- from google.protobuf.internal import python_message
- _NewMessage = python_message.NewMessage
- _InitMessage = python_message.InitMessage
-
-
-class GeneratedProtocolMessageType(type):
-
- """Metaclass for protocol message classes created at runtime from Descriptors.
-
- We add implementations for all methods described in the Message class. We
- also create properties to allow getting/setting all fields in the protocol
- message. Finally, we create slots to prevent users from accidentally
- "setting" nonexistent fields in the protocol message, which then wouldn't get
- serialized / deserialized properly.
+ from google.protobuf.internal import python_message as message_impl
- The protocol compiler currently uses this metaclass to create protocol
- message classes at runtime. Clients can also manually create their own
- classes at runtime, as in this example:
-
- mydescriptor = Descriptor(.....)
- class MyProtoClass(Message):
- __metaclass__ = GeneratedProtocolMessageType
- DESCRIPTOR = mydescriptor
- myproto_instance = MyProtoClass()
- myproto.foo_field = 23
- ...
-
- The above example will not work for nested types. If you wish to include them,
- use reflection.MakeClass() instead of manually instantiating the class in
- order to create the appropriate class structure.
- """
-
- # Must be consistent with the protocol-compiler code in
- # proto2/compiler/internal/generator.*.
- _DESCRIPTOR_KEY = 'DESCRIPTOR'
-
- def __new__(cls, name, bases, dictionary):
- """Custom allocation for runtime-generated class types.
-
- We override __new__ because this is apparently the only place
- where we can meaningfully set __slots__ on the class we're creating(?).
- (The interplay between metaclasses and slots is not very well-documented).
-
- Args:
- name: Name of the class (ignored, but required by the
- metaclass protocol).
- bases: Base classes of the class we're constructing.
- (Should be message.Message). We ignore this field, but
- it's required by the metaclass protocol
- dictionary: The class dictionary of the class we're
- constructing. dictionary[_DESCRIPTOR_KEY] must contain
- a Descriptor object describing this protocol message
- type.
-
- Returns:
- Newly-allocated class.
- """
- descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]
- bases = _NewMessage(bases, descriptor, dictionary)
- superclass = super(GeneratedProtocolMessageType, cls)
-
- new_class = superclass.__new__(cls, name, bases, dictionary)
- setattr(descriptor, '_concrete_class', new_class)
- return new_class
-
- def __init__(cls, name, bases, dictionary):
- """Here we perform the majority of our work on the class.
- We add enum getters, an __init__ method, implementations
- of all Message methods, and properties for all fields
- in the protocol type.
-
- Args:
- name: Name of the class (ignored, but required by the
- metaclass protocol).
- bases: Base classes of the class we're constructing.
- (Should be message.Message). We ignore this field, but
- it's required by the metaclass protocol
- dictionary: The class dictionary of the class we're
- constructing. dictionary[_DESCRIPTOR_KEY] must contain
- a Descriptor object describing this protocol message
- type.
- """
- descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]
- _InitMessage(descriptor, cls)
- superclass = super(GeneratedProtocolMessageType, cls)
- superclass.__init__(name, bases, dictionary)
+# The type of all Message classes.
+# Part of the public interface.
+#
+# Used by generated files, but clients can also use it at runtime:
+# mydescriptor = pool.FindDescriptor(.....)
+# class MyProtoClass(Message):
+# __metaclass__ = GeneratedProtocolMessageType
+# DESCRIPTOR = mydescriptor
+GeneratedProtocolMessageType = message_impl.GeneratedProtocolMessageType
def ParseMessage(descriptor, byte_str):
diff --git a/python/google/protobuf/symbol_database.py b/python/google/protobuf/symbol_database.py
index 4c70b3938..87760f263 100644
--- a/python/google/protobuf/symbol_database.py
+++ b/python/google/protobuf/symbol_database.py
@@ -72,12 +72,12 @@ class SymbolDatabase(object):
buffer types used within a program.
"""
- def __init__(self):
+ def __init__(self, pool=None):
"""Constructor."""
self._symbols = {}
self._symbols_by_file = {}
- self.pool = descriptor_pool.DescriptorPool()
+ self.pool = pool or descriptor_pool.Default()
def RegisterMessage(self, message):
"""Registers the given message type in the local database.
@@ -177,7 +177,7 @@ class SymbolDatabase(object):
result.update(self._symbols_by_file[f])
return result
-_DEFAULT = SymbolDatabase()
+_DEFAULT = SymbolDatabase(pool=descriptor_pool.Default())
def Default():
diff --git a/python/google/protobuf/text_encoding.py b/python/google/protobuf/text_encoding.py
index 2d86a67ca..989956382 100644
--- a/python/google/protobuf/text_encoding.py
+++ b/python/google/protobuf/text_encoding.py
@@ -28,15 +28,13 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-#PY25 compatible for GAE.
-#
"""Encoding related utilities."""
-
import re
-import sys ##PY25
+
+import six
# Lookup table for utf8
-_cescape_utf8_to_str = [chr(i) for i in xrange(0, 256)]
+_cescape_utf8_to_str = [chr(i) for i in range(0, 256)]
_cescape_utf8_to_str[9] = r'\t' # optional escape
_cescape_utf8_to_str[10] = r'\n' # optional escape
_cescape_utf8_to_str[13] = r'\r' # optional escape
@@ -46,9 +44,9 @@ _cescape_utf8_to_str[34] = r'\"' # necessary escape
_cescape_utf8_to_str[92] = r'\\' # necessary escape
# Lookup table for non-utf8, with necessary escapes at (o >= 127 or o < 32)
-_cescape_byte_to_str = ([r'\%03o' % i for i in xrange(0, 32)] +
- [chr(i) for i in xrange(32, 127)] +
- [r'\%03o' % i for i in xrange(127, 256)])
+_cescape_byte_to_str = ([r'\%03o' % i for i in range(0, 32)] +
+ [chr(i) for i in range(32, 127)] +
+ [r'\%03o' % i for i in range(127, 256)])
_cescape_byte_to_str[9] = r'\t' # optional escape
_cescape_byte_to_str[10] = r'\n' # optional escape
_cescape_byte_to_str[13] = r'\r' # optional escape
@@ -75,7 +73,7 @@ def CEscape(text, as_utf8):
"""
# PY3 hack: make Ord work for str and bytes:
# //platforms/networking/data uses unicode here, hence basestring.
- Ord = ord if isinstance(text, basestring) else lambda x: x
+ Ord = ord if isinstance(text, six.string_types) else lambda x: x
if as_utf8:
return ''.join(_cescape_utf8_to_str[Ord(c)] for c in text)
return ''.join(_cescape_byte_to_str[Ord(c)] for c in text)
@@ -100,8 +98,7 @@ def CUnescape(text):
# allow single-digit hex escapes (like '\xf').
result = _CUNESCAPE_HEX.sub(ReplaceHex, text)
- if sys.version_info[0] < 3: ##PY25
-##!PY25 if str is bytes: # PY2
+ if str is bytes: # PY2
return result.decode('string_escape')
result = ''.join(_cescape_highbit_to_str[ord(c)] for c in result)
return (result.encode('ascii') # Make it bytes to allow decode.
diff --git a/python/google/protobuf/text_format.py b/python/google/protobuf/text_format.py
index 2429fa59f..6f1e3c8b7 100755
--- a/python/google/protobuf/text_format.py
+++ b/python/google/protobuf/text_format.py
@@ -28,17 +28,28 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-#PY25 compatible for GAE.
-#
-# Copyright 2007 Google Inc. All Rights Reserved.
+"""Contains routines for printing protocol messages in text format.
+
+Simple usage example:
+
+ # Create a proto object and serialize it to a text proto string.
+ message = my_proto_pb2.MyMessage(foo='bar')
+ text_proto = text_format.MessageToString(message)
-"""Contains routines for printing protocol messages in text format."""
+ # Parse a text proto string.
+ message = text_format.Parse(text_proto, my_proto_pb2.MyMessage())
+"""
__author__ = 'kenton@google.com (Kenton Varda)'
-import cStringIO
+import io
import re
+import six
+
+if six.PY3:
+ long = int
+
from google.protobuf.internal import type_checkers
from google.protobuf import descriptor
from google.protobuf import text_encoding
@@ -55,6 +66,7 @@ _FLOAT_INFINITY = re.compile('-?inf(?:inity)?f?', re.IGNORECASE)
_FLOAT_NAN = re.compile('nanf?', re.IGNORECASE)
_FLOAT_TYPES = frozenset([descriptor.FieldDescriptor.CPPTYPE_FLOAT,
descriptor.FieldDescriptor.CPPTYPE_DOUBLE])
+_QUOTES = frozenset(("'", '"'))
class Error(Exception):
@@ -62,17 +74,38 @@ class Error(Exception):
class ParseError(Error):
- """Thrown in case of ASCII parsing error."""
+ """Thrown in case of text parsing error."""
+
+
+class TextWriter(object):
+ def __init__(self, as_utf8):
+ if six.PY2:
+ self._writer = io.BytesIO()
+ else:
+ self._writer = io.StringIO()
+
+ def write(self, val):
+ if six.PY2:
+ if isinstance(val, six.text_type):
+ val = val.encode('utf-8')
+ return self._writer.write(val)
+
+ def close(self):
+ return self._writer.close()
+
+ def getvalue(self):
+ return self._writer.getvalue()
def MessageToString(message, as_utf8=False, as_one_line=False,
pointy_brackets=False, use_index_order=False,
- float_format=None):
+ float_format=None, use_field_number=False):
"""Convert protobuf message to text format.
Floating point values can be formatted compactly with 15 digits of
precision (which is the most that IEEE 754 "double" can guarantee)
- using float_format='.15g'.
+ using float_format='.15g'. To ensure that converting to text and back to a
+ proto will result in an identical value, float_format='.17g' should be used.
Args:
message: The protocol buffers message.
@@ -85,15 +118,16 @@ def MessageToString(message, as_utf8=False, as_one_line=False,
field number order.
float_format: If set, use this to specify floating point number formatting
(per the "Format Specification Mini-Language"); otherwise, str() is used.
+ use_field_number: If True, print field numbers instead of names.
Returns:
A string of the text formatted protocol buffer message.
"""
- out = cStringIO.StringIO()
- PrintMessage(message, out, as_utf8=as_utf8, as_one_line=as_one_line,
- pointy_brackets=pointy_brackets,
- use_index_order=use_index_order,
- float_format=float_format)
+ out = TextWriter(as_utf8)
+ printer = _Printer(out, 0, as_utf8, as_one_line,
+ pointy_brackets, use_index_order, float_format,
+ use_field_number)
+ printer.PrintMessage(message)
result = out.getvalue()
out.close()
if as_one_line:
@@ -101,262 +135,446 @@ def MessageToString(message, as_utf8=False, as_one_line=False,
return result
+def _IsMapEntry(field):
+ return (field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and
+ field.message_type.has_options and
+ field.message_type.GetOptions().map_entry)
+
+
def PrintMessage(message, out, indent=0, as_utf8=False, as_one_line=False,
pointy_brackets=False, use_index_order=False,
- float_format=None):
- fields = message.ListFields()
- if use_index_order:
- fields.sort(key=lambda x: x[0].index)
- for field, value in fields:
- if field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
- for element in value:
- PrintField(field, element, out, indent, as_utf8, as_one_line,
- pointy_brackets=pointy_brackets,
- float_format=float_format)
- else:
- PrintField(field, value, out, indent, as_utf8, as_one_line,
- pointy_brackets=pointy_brackets,
- float_format=float_format)
+ float_format=None, use_field_number=False):
+ printer = _Printer(out, indent, as_utf8, as_one_line,
+ pointy_brackets, use_index_order, float_format,
+ use_field_number)
+ printer.PrintMessage(message)
def PrintField(field, value, out, indent=0, as_utf8=False, as_one_line=False,
- pointy_brackets=False, float_format=None):
- """Print a single field name/value pair. For repeated fields, the value
- should be a single element."""
-
- out.write(' ' * indent)
- if field.is_extension:
- out.write('[')
- if (field.containing_type.GetOptions().message_set_wire_format and
- field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and
- field.message_type == field.extension_scope and
- field.label == descriptor.FieldDescriptor.LABEL_OPTIONAL):
- out.write(field.message_type.full_name)
- else:
- out.write(field.full_name)
- out.write(']')
- elif field.type == descriptor.FieldDescriptor.TYPE_GROUP:
- # For groups, use the capitalized name.
- out.write(field.message_type.name)
- else:
- out.write(field.name)
-
- if field.cpp_type != descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
- # The colon is optional in this case, but our cross-language golden files
- # don't include it.
- out.write(': ')
-
- PrintFieldValue(field, value, out, indent, as_utf8, as_one_line,
- pointy_brackets=pointy_brackets,
- float_format=float_format)
- if as_one_line:
- out.write(' ')
- else:
- out.write('\n')
+ pointy_brackets=False, use_index_order=False, float_format=None):
+ """Print a single field name/value pair."""
+ printer = _Printer(out, indent, as_utf8, as_one_line,
+ pointy_brackets, use_index_order, float_format)
+ printer.PrintField(field, value)
def PrintFieldValue(field, value, out, indent=0, as_utf8=False,
as_one_line=False, pointy_brackets=False,
+ use_index_order=False,
float_format=None):
- """Print a single field value (not including name). For repeated fields,
- the value should be a single element."""
+ """Print a single field value (not including name)."""
+ printer = _Printer(out, indent, as_utf8, as_one_line,
+ pointy_brackets, use_index_order, float_format)
+ printer.PrintFieldValue(field, value)
- if pointy_brackets:
- openb = '<'
- closeb = '>'
- else:
- openb = '{'
- closeb = '}'
-
- if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
- if as_one_line:
- out.write(' %s ' % openb)
- PrintMessage(value, out, indent, as_utf8, as_one_line,
- pointy_brackets=pointy_brackets,
- float_format=float_format)
- out.write(closeb)
- else:
- out.write(' %s\n' % openb)
- PrintMessage(value, out, indent + 2, as_utf8, as_one_line,
- pointy_brackets=pointy_brackets,
- float_format=float_format)
- out.write(' ' * indent + closeb)
- elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_ENUM:
- enum_value = field.enum_type.values_by_number.get(value, None)
- if enum_value is not None:
- out.write(enum_value.name)
- else:
- out.write(str(value))
- elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_STRING:
- out.write('\"')
- if isinstance(value, unicode):
- out_value = value.encode('utf-8')
- else:
- out_value = value
- if field.type == descriptor.FieldDescriptor.TYPE_BYTES:
- # We need to escape non-UTF8 chars in TYPE_BYTES field.
- out_as_utf8 = False
- else:
- out_as_utf8 = as_utf8
- out.write(text_encoding.CEscape(out_value, out_as_utf8))
- out.write('\"')
- elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_BOOL:
- if value:
- out.write('true')
+
+class _Printer(object):
+ """Text format printer for protocol message."""
+
+ def __init__(self, out, indent=0, as_utf8=False, as_one_line=False,
+ pointy_brackets=False, use_index_order=False, float_format=None,
+ use_field_number=False):
+ """Initialize the Printer.
+
+ Floating point values can be formatted compactly with 15 digits of
+ precision (which is the most that IEEE 754 "double" can guarantee)
+ using float_format='.15g'. To ensure that converting to text and back to a
+ proto will result in an identical value, float_format='.17g' should be used.
+
+ Args:
+ out: To record the text format result.
+ indent: The indent level for pretty print.
+ as_utf8: Produce text output in UTF8 format.
+ as_one_line: Don't introduce newlines between fields.
+ pointy_brackets: If True, use angle brackets instead of curly braces for
+ nesting.
+ use_index_order: If True, print fields of a proto message using the order
+ defined in source code instead of the field number. By default, use the
+ field number order.
+ float_format: If set, use this to specify floating point number formatting
+ (per the "Format Specification Mini-Language"); otherwise, str() is
+ used.
+ use_field_number: If True, print field numbers instead of names.
+ """
+ self.out = out
+ self.indent = indent
+ self.as_utf8 = as_utf8
+ self.as_one_line = as_one_line
+ self.pointy_brackets = pointy_brackets
+ self.use_index_order = use_index_order
+ self.float_format = float_format
+ self.use_field_number = use_field_number
+
+ def PrintMessage(self, message):
+ """Convert protobuf message to text format.
+
+ Args:
+ message: The protocol buffers message.
+ """
+ fields = message.ListFields()
+ if self.use_index_order:
+ fields.sort(key=lambda x: x[0].index)
+ for field, value in fields:
+ if _IsMapEntry(field):
+ for key in sorted(value):
+ # This is slow for maps with submessage entires because it copies the
+ # entire tree. Unfortunately this would take significant refactoring
+ # of this file to work around.
+ #
+ # TODO(haberman): refactor and optimize if this becomes an issue.
+ entry_submsg = field.message_type._concrete_class(
+ key=key, value=value[key])
+ self.PrintField(field, entry_submsg)
+ elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
+ for element in value:
+ self.PrintField(field, element)
+ else:
+ self.PrintField(field, value)
+
+ def PrintField(self, field, value):
+ """Print a single field name/value pair."""
+ out = self.out
+ out.write(' ' * self.indent)
+ if self.use_field_number:
+ out.write(str(field.number))
else:
- out.write('false')
- elif field.cpp_type in _FLOAT_TYPES and float_format is not None:
- out.write('{1:{0}}'.format(float_format, value))
- else:
- out.write(str(value))
+ if field.is_extension:
+ out.write('[')
+ if (field.containing_type.GetOptions().message_set_wire_format and
+ field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and
+ field.label == descriptor.FieldDescriptor.LABEL_OPTIONAL):
+ out.write(field.message_type.full_name)
+ else:
+ out.write(field.full_name)
+ out.write(']')
+ elif field.type == descriptor.FieldDescriptor.TYPE_GROUP:
+ # For groups, use the capitalized name.
+ out.write(field.message_type.name)
+ else:
+ out.write(field.name)
+ if field.cpp_type != descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
+ # The colon is optional in this case, but our cross-language golden files
+ # don't include it.
+ out.write(': ')
-def _ParseOrMerge(lines, message, allow_multiple_scalars):
- """Converts an ASCII representation of a protocol message into a message.
+ self.PrintFieldValue(field, value)
+ if self.as_one_line:
+ out.write(' ')
+ else:
+ out.write('\n')
- Args:
- lines: Lines of a message's ASCII representation.
- message: A protocol buffer message to merge into.
- allow_multiple_scalars: Determines if repeated values for a non-repeated
- field are permitted, e.g., the string "foo: 1 foo: 2" for a
- required/optional field named "foo".
+ def PrintFieldValue(self, field, value):
+ """Print a single field value (not including name).
- Raises:
- ParseError: On ASCII parsing problems.
- """
- tokenizer = _Tokenizer(lines)
- while not tokenizer.AtEnd():
- _MergeField(tokenizer, message, allow_multiple_scalars)
+ For repeated fields, the value should be a single element.
+
+ Args:
+ field: The descriptor of the field to be printed.
+ value: The value of the field.
+ """
+ out = self.out
+ if self.pointy_brackets:
+ openb = '<'
+ closeb = '>'
+ else:
+ openb = '{'
+ closeb = '}'
+
+ if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
+ if self.as_one_line:
+ out.write(' %s ' % openb)
+ self.PrintMessage(value)
+ out.write(closeb)
+ else:
+ out.write(' %s\n' % openb)
+ self.indent += 2
+ self.PrintMessage(value)
+ self.indent -= 2
+ out.write(' ' * self.indent + closeb)
+ elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_ENUM:
+ enum_value = field.enum_type.values_by_number.get(value, None)
+ if enum_value is not None:
+ out.write(enum_value.name)
+ else:
+ out.write(str(value))
+ elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_STRING:
+ out.write('\"')
+ if isinstance(value, six.text_type):
+ out_value = value.encode('utf-8')
+ else:
+ out_value = value
+ if field.type == descriptor.FieldDescriptor.TYPE_BYTES:
+ # We need to escape non-UTF8 chars in TYPE_BYTES field.
+ out_as_utf8 = False
+ else:
+ out_as_utf8 = self.as_utf8
+ out.write(text_encoding.CEscape(out_value, out_as_utf8))
+ out.write('\"')
+ elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_BOOL:
+ if value:
+ out.write('true')
+ else:
+ out.write('false')
+ elif field.cpp_type in _FLOAT_TYPES and self.float_format is not None:
+ out.write('{1:{0}}'.format(self.float_format, value))
+ else:
+ out.write(str(value))
-def Parse(text, message):
- """Parses an ASCII representation of a protocol message into a message.
+def Parse(text, message,
+ allow_unknown_extension=False, allow_field_number=False):
+ """Parses an text representation of a protocol message into a message.
Args:
- text: Message ASCII representation.
+ text: Message text representation.
message: A protocol buffer message to merge into.
+ allow_unknown_extension: if True, skip over missing extensions and keep
+ parsing
+ allow_field_number: if True, both field number and field name are allowed.
Returns:
The same message passed as argument.
Raises:
- ParseError: On ASCII parsing problems.
+ ParseError: On text parsing problems.
"""
- if not isinstance(text, str): text = text.decode('utf-8')
- return ParseLines(text.split('\n'), message)
+ if not isinstance(text, str):
+ text = text.decode('utf-8')
+ return ParseLines(text.split('\n'), message, allow_unknown_extension,
+ allow_field_number)
-def Merge(text, message):
- """Parses an ASCII representation of a protocol message into a message.
+def Merge(text, message, allow_unknown_extension=False,
+ allow_field_number=False):
+ """Parses an text representation of a protocol message into a message.
Like Parse(), but allows repeated values for a non-repeated field, and uses
the last one.
Args:
- text: Message ASCII representation.
+ text: Message text representation.
message: A protocol buffer message to merge into.
+ allow_unknown_extension: if True, skip over missing extensions and keep
+ parsing
+ allow_field_number: if True, both field number and field name are allowed.
Returns:
The same message passed as argument.
Raises:
- ParseError: On ASCII parsing problems.
+ ParseError: On text parsing problems.
"""
- return MergeLines(text.split('\n'), message)
+ return MergeLines(text.split('\n'), message, allow_unknown_extension,
+ allow_field_number)
-def ParseLines(lines, message):
- """Parses an ASCII representation of a protocol message into a message.
+def ParseLines(lines, message, allow_unknown_extension=False,
+ allow_field_number=False):
+ """Parses an text representation of a protocol message into a message.
Args:
- lines: An iterable of lines of a message's ASCII representation.
+ lines: An iterable of lines of a message's text representation.
message: A protocol buffer message to merge into.
+ allow_unknown_extension: if True, skip over missing extensions and keep
+ parsing
+ allow_field_number: if True, both field number and field name are allowed.
Returns:
The same message passed as argument.
Raises:
- ParseError: On ASCII parsing problems.
+ ParseError: On text parsing problems.
"""
- _ParseOrMerge(lines, message, False)
- return message
+ parser = _Parser(allow_unknown_extension, allow_field_number)
+ return parser.ParseLines(lines, message)
-def MergeLines(lines, message):
- """Parses an ASCII representation of a protocol message into a message.
+def MergeLines(lines, message, allow_unknown_extension=False,
+ allow_field_number=False):
+ """Parses an text representation of a protocol message into a message.
Args:
- lines: An iterable of lines of a message's ASCII representation.
+ lines: An iterable of lines of a message's text representation.
message: A protocol buffer message to merge into.
+ allow_unknown_extension: if True, skip over missing extensions and keep
+ parsing
+ allow_field_number: if True, both field number and field name are allowed.
Returns:
The same message passed as argument.
Raises:
- ParseError: On ASCII parsing problems.
+ ParseError: On text parsing problems.
"""
- _ParseOrMerge(lines, message, True)
- return message
+ parser = _Parser(allow_unknown_extension, allow_field_number)
+ return parser.MergeLines(lines, message)
-def _MergeField(tokenizer, message, allow_multiple_scalars):
- """Merges a single protocol message field into a message.
+class _Parser(object):
+ """Text format parser for protocol message."""
- Args:
- tokenizer: A tokenizer to parse the field name and values.
- message: A protocol message to record the data.
- allow_multiple_scalars: Determines if repeated values for a non-repeated
- field are permitted, e.g., the string "foo: 1 foo: 2" for a
- required/optional field named "foo".
+ def __init__(self, allow_unknown_extension=False, allow_field_number=False):
+ self.allow_unknown_extension = allow_unknown_extension
+ self.allow_field_number = allow_field_number
- Raises:
- ParseError: In case of ASCII parsing problems.
- """
- message_descriptor = message.DESCRIPTOR
- if tokenizer.TryConsume('['):
- name = [tokenizer.ConsumeIdentifier()]
- while tokenizer.TryConsume('.'):
- name.append(tokenizer.ConsumeIdentifier())
- name = '.'.join(name)
-
- if not message_descriptor.is_extendable:
- raise tokenizer.ParseErrorPreviousToken(
- 'Message type "%s" does not have extensions.' %
- message_descriptor.full_name)
- # pylint: disable=protected-access
- field = message.Extensions._FindExtensionByName(name)
- # pylint: enable=protected-access
- if not field:
- raise tokenizer.ParseErrorPreviousToken(
- 'Extension "%s" not registered.' % name)
- elif message_descriptor != field.containing_type:
- raise tokenizer.ParseErrorPreviousToken(
- 'Extension "%s" does not extend message type "%s".' % (
- name, message_descriptor.full_name))
- tokenizer.Consume(']')
- else:
- name = tokenizer.ConsumeIdentifier()
- field = message_descriptor.fields_by_name.get(name, None)
+ def ParseFromString(self, text, message):
+ """Parses an text representation of a protocol message into a message."""
+ if not isinstance(text, str):
+ text = text.decode('utf-8')
+ return self.ParseLines(text.split('\n'), message)
+
+ def ParseLines(self, lines, message):
+ """Parses an text representation of a protocol message into a message."""
+ self._allow_multiple_scalars = False
+ self._ParseOrMerge(lines, message)
+ return message
+
+ def MergeFromString(self, text, message):
+ """Merges an text representation of a protocol message into a message."""
+ return self._MergeLines(text.split('\n'), message)
- # Group names are expected to be capitalized as they appear in the
- # .proto file, which actually matches their type names, not their field
- # names.
- if not field:
- field = message_descriptor.fields_by_name.get(name.lower(), None)
- if field and field.type != descriptor.FieldDescriptor.TYPE_GROUP:
- field = None
+ def MergeLines(self, lines, message):
+ """Merges an text representation of a protocol message into a message."""
+ self._allow_multiple_scalars = True
+ self._ParseOrMerge(lines, message)
+ return message
- if (field and field.type == descriptor.FieldDescriptor.TYPE_GROUP and
- field.message_type.name != name):
- field = None
+ def _ParseOrMerge(self, lines, message):
+ """Converts an text representation of a protocol message into a message.
- if not field:
- raise tokenizer.ParseErrorPreviousToken(
- 'Message type "%s" has no field named "%s".' % (
- message_descriptor.full_name, name))
+ Args:
+ lines: Lines of a message's text representation.
+ message: A protocol buffer message to merge into.
- if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
- tokenizer.TryConsume(':')
+ Raises:
+ ParseError: On text parsing problems.
+ """
+ tokenizer = _Tokenizer(lines)
+ while not tokenizer.AtEnd():
+ self._MergeField(tokenizer, message)
+
+ def _MergeField(self, tokenizer, message):
+ """Merges a single protocol message field into a message.
+
+ Args:
+ tokenizer: A tokenizer to parse the field name and values.
+ message: A protocol message to record the data.
+
+ Raises:
+ ParseError: In case of text parsing problems.
+ """
+ message_descriptor = message.DESCRIPTOR
+ if (hasattr(message_descriptor, 'syntax') and
+ message_descriptor.syntax == 'proto3'):
+ # Proto3 doesn't represent presence so we can't test if multiple
+ # scalars have occurred. We have to allow them.
+ self._allow_multiple_scalars = True
+ if tokenizer.TryConsume('['):
+ name = [tokenizer.ConsumeIdentifier()]
+ while tokenizer.TryConsume('.'):
+ name.append(tokenizer.ConsumeIdentifier())
+ name = '.'.join(name)
+
+ if not message_descriptor.is_extendable:
+ raise tokenizer.ParseErrorPreviousToken(
+ 'Message type "%s" does not have extensions.' %
+ message_descriptor.full_name)
+ # pylint: disable=protected-access
+ field = message.Extensions._FindExtensionByName(name)
+ # pylint: enable=protected-access
+ if not field:
+ if self.allow_unknown_extension:
+ field = None
+ else:
+ raise tokenizer.ParseErrorPreviousToken(
+ 'Extension "%s" not registered.' % name)
+ elif message_descriptor != field.containing_type:
+ raise tokenizer.ParseErrorPreviousToken(
+ 'Extension "%s" does not extend message type "%s".' % (
+ name, message_descriptor.full_name))
+
+ tokenizer.Consume(']')
+
+ else:
+ name = tokenizer.ConsumeIdentifier()
+ if self.allow_field_number and name.isdigit():
+ number = ParseInteger(name, True, True)
+ field = message_descriptor.fields_by_number.get(number, None)
+ if not field and message_descriptor.is_extendable:
+ field = message.Extensions._FindExtensionByNumber(number)
+ else:
+ field = message_descriptor.fields_by_name.get(name, None)
+
+ # Group names are expected to be capitalized as they appear in the
+ # .proto file, which actually matches their type names, not their field
+ # names.
+ if not field:
+ field = message_descriptor.fields_by_name.get(name.lower(), None)
+ if field and field.type != descriptor.FieldDescriptor.TYPE_GROUP:
+ field = None
+
+ if (field and field.type == descriptor.FieldDescriptor.TYPE_GROUP and
+ field.message_type.name != name):
+ field = None
+
+ if not field:
+ raise tokenizer.ParseErrorPreviousToken(
+ 'Message type "%s" has no field named "%s".' % (
+ message_descriptor.full_name, name))
+
+ if field:
+ if not self._allow_multiple_scalars and field.containing_oneof:
+ # Check if there's a different field set in this oneof.
+ # Note that we ignore the case if the same field was set before, and we
+ # apply _allow_multiple_scalars to non-scalar fields as well.
+ which_oneof = message.WhichOneof(field.containing_oneof.name)
+ if which_oneof is not None and which_oneof != field.name:
+ raise tokenizer.ParseErrorPreviousToken(
+ 'Field "%s" is specified along with field "%s", another member '
+ 'of oneof "%s" for message type "%s".' % (
+ field.name, which_oneof, field.containing_oneof.name,
+ message_descriptor.full_name))
+
+ if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
+ tokenizer.TryConsume(':')
+ merger = self._MergeMessageField
+ else:
+ tokenizer.Consume(':')
+ merger = self._MergeScalarField
+
+ if (field.label == descriptor.FieldDescriptor.LABEL_REPEATED
+ and tokenizer.TryConsume('[')):
+ # Short repeated format, e.g. "foo: [1, 2, 3]"
+ while True:
+ merger(tokenizer, message, field)
+ if tokenizer.TryConsume(']'): break
+ tokenizer.Consume(',')
+
+ else:
+ merger(tokenizer, message, field)
+
+ else: # Proto field is unknown.
+ assert self.allow_unknown_extension
+ _SkipFieldContents(tokenizer)
+
+ # For historical reasons, fields may optionally be separated by commas or
+ # semicolons.
+ if not tokenizer.TryConsume(','):
+ tokenizer.TryConsume(';')
+
+ def _MergeMessageField(self, tokenizer, message, field):
+ """Merges a single scalar field into a message.
+
+ Args:
+ tokenizer: A tokenizer to parse the field value.
+ message: The message of which field is a member.
+ field: The descriptor of the field to be merged.
+
+ Raises:
+ ParseError: In case of text parsing problems.
+ """
+ is_map_entry = _IsMapEntry(field)
if tokenizer.TryConsume('<'):
end_token = '>'
@@ -367,6 +585,9 @@ def _MergeField(tokenizer, message, allow_multiple_scalars):
if field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
if field.is_extension:
sub_message = message.Extensions[field].add()
+ elif is_map_entry:
+ # pylint: disable=protected-access
+ sub_message = field.message_type._concrete_class()
else:
sub_message = getattr(message, field.name).add()
else:
@@ -378,10 +599,117 @@ def _MergeField(tokenizer, message, allow_multiple_scalars):
while not tokenizer.TryConsume(end_token):
if tokenizer.AtEnd():
- raise tokenizer.ParseErrorPreviousToken('Expected "%s".' % (end_token))
- _MergeField(tokenizer, sub_message, allow_multiple_scalars)
+ raise tokenizer.ParseErrorPreviousToken('Expected "%s".' % (end_token,))
+ self._MergeField(tokenizer, sub_message)
+
+ if is_map_entry:
+ value_cpptype = field.message_type.fields_by_name['value'].cpp_type
+ if value_cpptype == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
+ value = getattr(message, field.name)[sub_message.key]
+ value.MergeFrom(sub_message.value)
+ else:
+ getattr(message, field.name)[sub_message.key] = sub_message.value
+
+ def _MergeScalarField(self, tokenizer, message, field):
+ """Merges a single scalar field into a message.
+
+ Args:
+ tokenizer: A tokenizer to parse the field value.
+ message: A protocol message to record the data.
+ field: The descriptor of the field to be merged.
+
+ Raises:
+ ParseError: In case of text parsing problems.
+ RuntimeError: On runtime errors.
+ """
+ _ = self.allow_unknown_extension
+ value = None
+
+ if field.type in (descriptor.FieldDescriptor.TYPE_INT32,
+ descriptor.FieldDescriptor.TYPE_SINT32,
+ descriptor.FieldDescriptor.TYPE_SFIXED32):
+ value = tokenizer.ConsumeInt32()
+ elif field.type in (descriptor.FieldDescriptor.TYPE_INT64,
+ descriptor.FieldDescriptor.TYPE_SINT64,
+ descriptor.FieldDescriptor.TYPE_SFIXED64):
+ value = tokenizer.ConsumeInt64()
+ elif field.type in (descriptor.FieldDescriptor.TYPE_UINT32,
+ descriptor.FieldDescriptor.TYPE_FIXED32):
+ value = tokenizer.ConsumeUint32()
+ elif field.type in (descriptor.FieldDescriptor.TYPE_UINT64,
+ descriptor.FieldDescriptor.TYPE_FIXED64):
+ value = tokenizer.ConsumeUint64()
+ elif field.type in (descriptor.FieldDescriptor.TYPE_FLOAT,
+ descriptor.FieldDescriptor.TYPE_DOUBLE):
+ value = tokenizer.ConsumeFloat()
+ elif field.type == descriptor.FieldDescriptor.TYPE_BOOL:
+ value = tokenizer.ConsumeBool()
+ elif field.type == descriptor.FieldDescriptor.TYPE_STRING:
+ value = tokenizer.ConsumeString()
+ elif field.type == descriptor.FieldDescriptor.TYPE_BYTES:
+ value = tokenizer.ConsumeByteString()
+ elif field.type == descriptor.FieldDescriptor.TYPE_ENUM:
+ value = tokenizer.ConsumeEnum(field)
+ else:
+ raise RuntimeError('Unknown field type %d' % field.type)
+
+ if field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
+ if field.is_extension:
+ message.Extensions[field].append(value)
+ else:
+ getattr(message, field.name).append(value)
+ else:
+ if field.is_extension:
+ if not self._allow_multiple_scalars and message.HasExtension(field):
+ raise tokenizer.ParseErrorPreviousToken(
+ 'Message type "%s" should not have multiple "%s" extensions.' %
+ (message.DESCRIPTOR.full_name, field.full_name))
+ else:
+ message.Extensions[field] = value
+ else:
+ if not self._allow_multiple_scalars and message.HasField(field.name):
+ raise tokenizer.ParseErrorPreviousToken(
+ 'Message type "%s" should not have multiple "%s" fields.' %
+ (message.DESCRIPTOR.full_name, field.name))
+ else:
+ setattr(message, field.name, value)
+
+
+def _SkipFieldContents(tokenizer):
+ """Skips over contents (value or message) of a field.
+
+ Args:
+ tokenizer: A tokenizer to parse the field name and values.
+ """
+ # Try to guess the type of this field.
+ # If this field is not a message, there should be a ":" between the
+ # field name and the field value and also the field value should not
+ # start with "{" or "<" which indicates the beginning of a message body.
+ # If there is no ":" or there is a "{" or "<" after ":", this field has
+ # to be a message or the input is ill-formed.
+ if tokenizer.TryConsume(':') and not tokenizer.LookingAt(
+ '{') and not tokenizer.LookingAt('<'):
+ _SkipFieldValue(tokenizer)
+ else:
+ _SkipFieldMessage(tokenizer)
+
+
+def _SkipField(tokenizer):
+ """Skips over a complete field (name and value/message).
+
+ Args:
+ tokenizer: A tokenizer to parse the field name and values.
+ """
+ if tokenizer.TryConsume('['):
+ # Consume extension name.
+ tokenizer.ConsumeIdentifier()
+ while tokenizer.TryConsume('.'):
+ tokenizer.ConsumeIdentifier()
+ tokenizer.Consume(']')
else:
- _MergeScalarField(tokenizer, message, field, allow_multiple_scalars)
+ tokenizer.ConsumeIdentifier()
+
+ _SkipFieldContents(tokenizer)
# For historical reasons, fields may optionally be separated by commas or
# semicolons.
@@ -389,76 +717,50 @@ def _MergeField(tokenizer, message, allow_multiple_scalars):
tokenizer.TryConsume(';')
-def _MergeScalarField(tokenizer, message, field, allow_multiple_scalars):
- """Merges a single protocol message scalar field into a message.
+def _SkipFieldMessage(tokenizer):
+ """Skips over a field message.
Args:
- tokenizer: A tokenizer to parse the field value.
- message: A protocol message to record the data.
- field: The descriptor of the field to be merged.
- allow_multiple_scalars: Determines if repeated values for a non-repeated
- field are permitted, e.g., the string "foo: 1 foo: 2" for a
- required/optional field named "foo".
+ tokenizer: A tokenizer to parse the field name and values.
+ """
+
+ if tokenizer.TryConsume('<'):
+ delimiter = '>'
+ else:
+ tokenizer.Consume('{')
+ delimiter = '}'
+
+ while not tokenizer.LookingAt('>') and not tokenizer.LookingAt('}'):
+ _SkipField(tokenizer)
+
+ tokenizer.Consume(delimiter)
+
+
+def _SkipFieldValue(tokenizer):
+ """Skips over a field value.
+
+ Args:
+ tokenizer: A tokenizer to parse the field name and values.
Raises:
- ParseError: In case of ASCII parsing problems.
- RuntimeError: On runtime errors.
+ ParseError: In case an invalid field value is found.
"""
- tokenizer.Consume(':')
- value = None
-
- if field.type in (descriptor.FieldDescriptor.TYPE_INT32,
- descriptor.FieldDescriptor.TYPE_SINT32,
- descriptor.FieldDescriptor.TYPE_SFIXED32):
- value = tokenizer.ConsumeInt32()
- elif field.type in (descriptor.FieldDescriptor.TYPE_INT64,
- descriptor.FieldDescriptor.TYPE_SINT64,
- descriptor.FieldDescriptor.TYPE_SFIXED64):
- value = tokenizer.ConsumeInt64()
- elif field.type in (descriptor.FieldDescriptor.TYPE_UINT32,
- descriptor.FieldDescriptor.TYPE_FIXED32):
- value = tokenizer.ConsumeUint32()
- elif field.type in (descriptor.FieldDescriptor.TYPE_UINT64,
- descriptor.FieldDescriptor.TYPE_FIXED64):
- value = tokenizer.ConsumeUint64()
- elif field.type in (descriptor.FieldDescriptor.TYPE_FLOAT,
- descriptor.FieldDescriptor.TYPE_DOUBLE):
- value = tokenizer.ConsumeFloat()
- elif field.type == descriptor.FieldDescriptor.TYPE_BOOL:
- value = tokenizer.ConsumeBool()
- elif field.type == descriptor.FieldDescriptor.TYPE_STRING:
- value = tokenizer.ConsumeString()
- elif field.type == descriptor.FieldDescriptor.TYPE_BYTES:
- value = tokenizer.ConsumeByteString()
- elif field.type == descriptor.FieldDescriptor.TYPE_ENUM:
- value = tokenizer.ConsumeEnum(field)
- else:
- raise RuntimeError('Unknown field type %d' % field.type)
+ # String/bytes tokens can come in multiple adjacent string literals.
+ # If we can consume one, consume as many as we can.
+ if tokenizer.TryConsumeByteString():
+ while tokenizer.TryConsumeByteString():
+ pass
+ return
- if field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
- if field.is_extension:
- message.Extensions[field].append(value)
- else:
- getattr(message, field.name).append(value)
- else:
- if field.is_extension:
- if not allow_multiple_scalars and message.HasExtension(field):
- raise tokenizer.ParseErrorPreviousToken(
- 'Message type "%s" should not have multiple "%s" extensions.' %
- (message.DESCRIPTOR.full_name, field.full_name))
- else:
- message.Extensions[field] = value
- else:
- if not allow_multiple_scalars and message.HasField(field.name):
- raise tokenizer.ParseErrorPreviousToken(
- 'Message type "%s" should not have multiple "%s" fields.' %
- (message.DESCRIPTOR.full_name, field.name))
- else:
- setattr(message, field.name, value)
+ if (not tokenizer.TryConsumeIdentifier() and
+ not tokenizer.TryConsumeInt64() and
+ not tokenizer.TryConsumeUint64() and
+ not tokenizer.TryConsumeFloat()):
+ raise ParseError('Invalid field value: ' + tokenizer.token)
class _Tokenizer(object):
- """Protocol buffer ASCII representation tokenizer.
+ """Protocol buffer text representation tokenizer.
This class handles the lower level string parsing by splitting it into
meaningful tokens.
@@ -467,11 +769,13 @@ class _Tokenizer(object):
"""
_WHITESPACE = re.compile('(\\s|(#.*$))+', re.MULTILINE)
- _TOKEN = re.compile(
- '[a-zA-Z_][0-9a-zA-Z_+-]*|' # an identifier
- '[0-9+-][0-9a-zA-Z_.+-]*|' # a number
- '\"([^\"\n\\\\]|\\\\.)*(\"|\\\\?$)|' # a double-quoted string
- '\'([^\'\n\\\\]|\\\\.)*(\'|\\\\?$)') # a single-quoted string
+ _TOKEN = re.compile('|'.join([
+ r'[a-zA-Z_][0-9a-zA-Z_+-]*', # an identifier
+ r'([0-9+-]|(\.[0-9]))[0-9a-zA-Z_.+-]*', # a number
+ ] + [ # quoted str for each quote mark
+ r'{qt}([^{qt}\n\\]|\\.)*({qt}|\\?$)'.format(qt=mark) for mark in _QUOTES
+ ]))
+
_IDENTIFIER = re.compile(r'\w+')
def __init__(self, lines):
@@ -488,6 +792,9 @@ class _Tokenizer(object):
self._SkipWhitespace()
self.NextToken()
+ def LookingAt(self, token):
+ return self.token == token
+
def AtEnd(self):
"""Checks the end of the text was reached.
@@ -499,7 +806,7 @@ class _Tokenizer(object):
def _PopLine(self):
while len(self._current_line) <= self._column:
try:
- self._current_line = self._lines.next()
+ self._current_line = next(self._lines)
except StopIteration:
self._current_line = ''
self._more_lines = False
@@ -543,6 +850,13 @@ class _Tokenizer(object):
if not self.TryConsume(token):
raise self._ParseError('Expected "%s".' % token)
+ def TryConsumeIdentifier(self):
+ try:
+ self.ConsumeIdentifier()
+ return True
+ except ParseError:
+ return False
+
def ConsumeIdentifier(self):
"""Consumes protocol message field identifier.
@@ -569,7 +883,7 @@ class _Tokenizer(object):
"""
try:
result = ParseInteger(self.token, is_signed=True, is_long=False)
- except ValueError, e:
+ except ValueError as e:
raise self._ParseError(str(e))
self.NextToken()
return result
@@ -585,11 +899,18 @@ class _Tokenizer(object):
"""
try:
result = ParseInteger(self.token, is_signed=False, is_long=False)
- except ValueError, e:
+ except ValueError as e:
raise self._ParseError(str(e))
self.NextToken()
return result
+ def TryConsumeInt64(self):
+ try:
+ self.ConsumeInt64()
+ return True
+ except ParseError:
+ return False
+
def ConsumeInt64(self):
"""Consumes a signed 64bit integer number.
@@ -601,11 +922,18 @@ class _Tokenizer(object):
"""
try:
result = ParseInteger(self.token, is_signed=True, is_long=True)
- except ValueError, e:
+ except ValueError as e:
raise self._ParseError(str(e))
self.NextToken()
return result
+ def TryConsumeUint64(self):
+ try:
+ self.ConsumeUint64()
+ return True
+ except ParseError:
+ return False
+
def ConsumeUint64(self):
"""Consumes an unsigned 64bit integer number.
@@ -617,11 +945,18 @@ class _Tokenizer(object):
"""
try:
result = ParseInteger(self.token, is_signed=False, is_long=True)
- except ValueError, e:
+ except ValueError as e:
raise self._ParseError(str(e))
self.NextToken()
return result
+ def TryConsumeFloat(self):
+ try:
+ self.ConsumeFloat()
+ return True
+ except ParseError:
+ return False
+
def ConsumeFloat(self):
"""Consumes an floating point number.
@@ -633,7 +968,7 @@ class _Tokenizer(object):
"""
try:
result = ParseFloat(self.token)
- except ValueError, e:
+ except ValueError as e:
raise self._ParseError(str(e))
self.NextToken()
return result
@@ -649,11 +984,18 @@ class _Tokenizer(object):
"""
try:
result = ParseBool(self.token)
- except ValueError, e:
+ except ValueError as e:
raise self._ParseError(str(e))
self.NextToken()
return result
+ def TryConsumeByteString(self):
+ try:
+ self.ConsumeByteString()
+ return True
+ except ParseError:
+ return False
+
def ConsumeString(self):
"""Consumes a string value.
@@ -665,8 +1007,8 @@ class _Tokenizer(object):
"""
the_bytes = self.ConsumeByteString()
try:
- return unicode(the_bytes, 'utf-8')
- except UnicodeDecodeError, e:
+ return six.text_type(the_bytes, 'utf-8')
+ except UnicodeDecodeError as e:
raise self._StringParseError(e)
def ConsumeByteString(self):
@@ -679,10 +1021,9 @@ class _Tokenizer(object):
ParseError: If a byte array value couldn't be consumed.
"""
the_list = [self._ConsumeSingleByteString()]
- while self.token and self.token[0] in ('\'', '"'):
+ while self.token and self.token[0] in _QUOTES:
the_list.append(self._ConsumeSingleByteString())
- return ''.encode('latin1').join(the_list) ##PY25
-##!PY25 return b''.join(the_list)
+ return b''.join(the_list)
def _ConsumeSingleByteString(self):
"""Consume one token of a string literal.
@@ -690,17 +1031,22 @@ class _Tokenizer(object):
String literals (whether bytes or text) can come in multiple adjacent
tokens which are automatically concatenated, like in C or Python. This
method only consumes one token.
+
+ Returns:
+ The token parsed.
+ Raises:
+ ParseError: When the wrong format data is found.
"""
text = self.token
- if len(text) < 1 or text[0] not in ('\'', '"'):
- raise self._ParseError('Expected string.')
+ if len(text) < 1 or text[0] not in _QUOTES:
+ raise self._ParseError('Expected string but found: %r' % (text,))
if len(text) < 2 or text[-1] != text[0]:
- raise self._ParseError('String missing ending quote.')
+ raise self._ParseError('String missing ending quote: %r' % (text,))
try:
result = text_encoding.CUnescape(text[1:-1])
- except ValueError, e:
+ except ValueError as e:
raise self._ParseError(str(e))
self.NextToken()
return result
@@ -708,7 +1054,7 @@ class _Tokenizer(object):
def ConsumeEnum(self, field):
try:
result = ParseEnum(field, self.token)
- except ValueError, e:
+ except ValueError as e:
raise self._ParseError(str(e))
self.NextToken()
return result
diff --git a/python/mox.py b/python/mox.py
index ce80ba505..257468e52 100755
--- a/python/mox.py
+++ b/python/mox.py
@@ -31,7 +31,7 @@ If an unexpected method (or an expected method with unexpected
parameters) is called, then an exception will be raised.
Once you are done interacting with the mock, you need to verify that
-all the expected interactions occured. (Maybe your code exited
+all the expected interactions occurred. (Maybe your code exited
prematurely without calling some cleanup method!) The verify phase
ensures that every expected method was called; otherwise, an exception
will be raised.
diff --git a/python/setup.py b/python/setup.py
index 2450a7743..0f4b53c4a 100755
--- a/python/setup.py
+++ b/python/setup.py
@@ -1,29 +1,24 @@
-#! /usr/bin/python
+#! /usr/bin/env python
#
# See README for usage instructions.
-import sys
+import glob
import os
import subprocess
+import sys
# We must use setuptools, not distutils, because we need to use the
# namespace_packages option for the "google" package.
-try:
- from setuptools import setup, Extension
-except ImportError:
- try:
- from ez_setup import use_setuptools
- use_setuptools()
- from setuptools import setup, Extension
- except ImportError:
- sys.stderr.write(
- "Could not import setuptools; make sure you have setuptools or "
- "ez_setup installed.\n")
- raise
+from setuptools import setup, Extension, find_packages
+
from distutils.command.clean import clean as _clean
-from distutils.command.build_py import build_py as _build_py
-from distutils.spawn import find_executable
-maintainer_email = "protobuf@googlegroups.com"
+if sys.version_info[0] == 3:
+ # Python 3
+ from distutils.command.build_py import build_py_2to3 as _build_py
+else:
+ # Python 2
+ from distutils.command.build_py import build_py as _build_py
+from distutils.spawn import find_executable
# Find the Protocol Compiler.
if 'PROTOC' in os.environ and os.path.exists(os.environ['PROTOC']):
@@ -39,23 +34,38 @@ elif os.path.exists("../vsprojects/Release/protoc.exe"):
else:
protoc = find_executable("protoc")
-def generate_proto(source):
+
+def GetVersion():
+ """Gets the version from google/protobuf/__init__.py
+
+ Do not import google.protobuf.__init__ directly, because an installed
+ protobuf library may be loaded instead."""
+
+ with open(os.path.join('google', 'protobuf', '__init__.py')) as version_file:
+ exec(version_file.read(), globals())
+ return __version__
+
+
+def generate_proto(source, require = True):
"""Invokes the Protocol Compiler to generate a _pb2.py from the given
.proto file. Does nothing if the output already exists and is newer than
the input."""
+ if not require and not os.path.exists(source):
+ return
+
output = source.replace(".proto", "_pb2.py").replace("../src/", "")
if (not os.path.exists(output) or
(os.path.exists(source) and
os.path.getmtime(source) > os.path.getmtime(output))):
- print ("Generating %s..." % output)
+ print("Generating %s..." % output)
if not os.path.exists(source):
sys.stderr.write("Can't find required file: %s\n" % source)
sys.exit(-1)
- if protoc == None:
+ if protoc is None:
sys.stderr.write(
"protoc is not installed nor found in ../src. Please compile it "
"or install the binary package.\n")
@@ -66,39 +76,35 @@ def generate_proto(source):
sys.exit(-1)
def GenerateUnittestProtos():
- generate_proto("../src/google/protobuf/unittest.proto")
- generate_proto("../src/google/protobuf/unittest_custom_options.proto")
- generate_proto("../src/google/protobuf/unittest_import.proto")
- generate_proto("../src/google/protobuf/unittest_import_public.proto")
- generate_proto("../src/google/protobuf/unittest_mset.proto")
- generate_proto("../src/google/protobuf/unittest_no_generic_services.proto")
- generate_proto("google/protobuf/internal/descriptor_pool_test1.proto")
- generate_proto("google/protobuf/internal/descriptor_pool_test2.proto")
- generate_proto("google/protobuf/internal/test_bad_identifiers.proto")
- generate_proto("google/protobuf/internal/missing_enum_values.proto")
- generate_proto("google/protobuf/internal/more_extensions.proto")
- generate_proto("google/protobuf/internal/more_extensions_dynamic.proto")
- generate_proto("google/protobuf/internal/more_messages.proto")
- generate_proto("google/protobuf/internal/factory_test1.proto")
- generate_proto("google/protobuf/internal/factory_test2.proto")
- generate_proto("google/protobuf/pyext/python.proto")
-
-def MakeTestSuite():
- # Test C++ implementation
- import unittest
- import google.protobuf.pyext.descriptor_cpp2_test as descriptor_cpp2_test
- import google.protobuf.pyext.message_factory_cpp2_test \
- as message_factory_cpp2_test
- import google.protobuf.pyext.reflection_cpp2_generated_test \
- as reflection_cpp2_generated_test
-
- loader = unittest.defaultTestLoader
- suite = unittest.TestSuite()
- for test in [ descriptor_cpp2_test,
- message_factory_cpp2_test,
- reflection_cpp2_generated_test]:
- suite.addTest(loader.loadTestsFromModule(test))
- return suite
+ generate_proto("../src/google/protobuf/map_unittest.proto", False)
+ generate_proto("../src/google/protobuf/unittest_arena.proto", False)
+ generate_proto("../src/google/protobuf/unittest_no_arena.proto", False)
+ generate_proto("../src/google/protobuf/unittest_no_arena_import.proto", False)
+ generate_proto("../src/google/protobuf/unittest.proto", False)
+ generate_proto("../src/google/protobuf/unittest_custom_options.proto", False)
+ generate_proto("../src/google/protobuf/unittest_import.proto", False)
+ generate_proto("../src/google/protobuf/unittest_import_public.proto", False)
+ generate_proto("../src/google/protobuf/unittest_mset.proto", False)
+ generate_proto("../src/google/protobuf/unittest_mset_wire_format.proto", False)
+ generate_proto("../src/google/protobuf/unittest_no_generic_services.proto", False)
+ generate_proto("../src/google/protobuf/unittest_proto3_arena.proto", False)
+ generate_proto("../src/google/protobuf/util/json_format_proto3.proto", False)
+ generate_proto("google/protobuf/internal/any_test.proto", False)
+ generate_proto("google/protobuf/internal/descriptor_pool_test1.proto", False)
+ generate_proto("google/protobuf/internal/descriptor_pool_test2.proto", False)
+ generate_proto("google/protobuf/internal/factory_test1.proto", False)
+ generate_proto("google/protobuf/internal/factory_test2.proto", False)
+ generate_proto("google/protobuf/internal/import_test_package/inner.proto", False)
+ generate_proto("google/protobuf/internal/import_test_package/outer.proto", False)
+ generate_proto("google/protobuf/internal/missing_enum_values.proto", False)
+ generate_proto("google/protobuf/internal/message_set_extensions.proto", False)
+ generate_proto("google/protobuf/internal/more_extensions.proto", False)
+ generate_proto("google/protobuf/internal/more_extensions_dynamic.proto", False)
+ generate_proto("google/protobuf/internal/more_messages.proto", False)
+ generate_proto("google/protobuf/internal/packed_field_test.proto", False)
+ generate_proto("google/protobuf/internal/test_bad_identifiers.proto", False)
+ generate_proto("google/protobuf/pyext/python.proto", False)
+
class clean(_clean):
def run(self):
@@ -108,7 +114,8 @@ class clean(_clean):
filepath = os.path.join(dirpath, filename)
if filepath.endswith("_pb2.py") or filepath.endswith(".pyc") or \
filepath.endswith(".so") or filepath.endswith(".o") or \
- filepath.endswith('google/protobuf/compiler/__init__.py'):
+ filepath.endswith('google/protobuf/compiler/__init__.py') or \
+ filepath.endswith('google/protobuf/util/__init__.py'):
os.remove(filepath)
# _clean is an old-style class, so super() doesn't work.
_clean.run(self)
@@ -118,84 +125,126 @@ class build_py(_build_py):
# Generate necessary .proto file if it doesn't exist.
generate_proto("../src/google/protobuf/descriptor.proto")
generate_proto("../src/google/protobuf/compiler/plugin.proto")
+ generate_proto("../src/google/protobuf/any.proto")
+ generate_proto("../src/google/protobuf/api.proto")
+ generate_proto("../src/google/protobuf/duration.proto")
+ generate_proto("../src/google/protobuf/empty.proto")
+ generate_proto("../src/google/protobuf/field_mask.proto")
+ generate_proto("../src/google/protobuf/source_context.proto")
+ generate_proto("../src/google/protobuf/struct.proto")
+ generate_proto("../src/google/protobuf/timestamp.proto")
+ generate_proto("../src/google/protobuf/type.proto")
+ generate_proto("../src/google/protobuf/wrappers.proto")
GenerateUnittestProtos()
# Make sure google.protobuf/** are valid packages.
- for path in ['', 'internal/', 'compiler/', 'pyext/']:
+ for path in ['', 'internal/', 'compiler/', 'pyext/', 'util/']:
try:
open('google/protobuf/%s__init__.py' % path, 'a').close()
except EnvironmentError:
pass
# _build_py is an old-style class, so super() doesn't work.
_build_py.run(self)
- # TODO(mrovner): Subclass to run 2to3 on some files only.
- # Tracing what https://wiki.python.org/moin/PortingPythonToPy3k's "Approach 2"
- # section on how to get 2to3 to run on source files during install under
- # Python 3. This class seems like a good place to put logic that calls
- # python3's distutils.util.run_2to3 on the subset of the files we have in our
- # release that are subject to conversion.
- # See code reference in previous code review.
+
+class test_conformance(_build_py):
+ target = 'test_python'
+ def run(self):
+ if sys.version_info >= (2, 7):
+ # Python 2.6 dodges these extra failures.
+ os.environ["CONFORMANCE_PYTHON_EXTRA_FAILURES"] = (
+ "--failure_list failure_list_python-post26.txt")
+ cmd = 'cd ../conformance && make %s' % (test_conformance.target)
+ status = subprocess.check_call(cmd, shell=True)
+
+
+def get_option_from_sys_argv(option_str):
+ if option_str in sys.argv:
+ sys.argv.remove(option_str)
+ return True
+ return False
+
if __name__ == '__main__':
ext_module_list = []
- cpp_impl = '--cpp_implementation'
- if cpp_impl in sys.argv:
- sys.argv.remove(cpp_impl)
+ warnings_as_errors = '--warnings_as_errors'
+ if get_option_from_sys_argv('--cpp_implementation'):
+ # Link libprotobuf.a and libprotobuf-lite.a statically with the
+ # extension. Note that those libraries have to be compiled with
+ # -fPIC for this to work.
+ compile_static_ext = get_option_from_sys_argv('--compile_static_extension')
+ extra_compile_args = ['-Wno-write-strings',
+ '-Wno-invalid-offsetof',
+ '-Wno-sign-compare']
+ libraries = ['protobuf']
+ extra_objects = None
+ if compile_static_ext:
+ libraries = None
+ extra_objects = ['../src/.libs/libprotobuf.a',
+ '../src/.libs/libprotobuf-lite.a']
+ test_conformance.target = 'test_python_cpp'
+
+ if "clang" in os.popen('$CC --version 2> /dev/null').read():
+ extra_compile_args.append('-Wno-shorten-64-to-32')
+
+ if warnings_as_errors in sys.argv:
+ extra_compile_args.append('-Werror')
+ sys.argv.remove(warnings_as_errors)
+
# C++ implementation extension
- ext_module_list.append(Extension(
- "google.protobuf.pyext._message",
- [ "google/protobuf/pyext/descriptor.cc",
- "google/protobuf/pyext/message.cc",
- "google/protobuf/pyext/extension_dict.cc",
- "google/protobuf/pyext/repeated_scalar_container.cc",
- "google/protobuf/pyext/repeated_composite_container.cc" ],
- define_macros=[('GOOGLE_PROTOBUF_HAS_ONEOF', '1')],
- include_dirs = [ ".", "../src"],
- libraries = [ "protobuf" ],
- library_dirs = [ '../src/.libs' ],
- ))
-
- setup(name = 'protobuf',
- version = '2.6.1',
- packages = [ 'google' ],
- namespace_packages = [ 'google' ],
- test_suite = 'setup.MakeTestSuite',
- google_test_dir = "google/protobuf/internal",
- # Must list modules explicitly so that we don't install tests.
- py_modules = [
- 'google.protobuf.internal.api_implementation',
- 'google.protobuf.internal.containers',
- 'google.protobuf.internal.cpp_message',
- 'google.protobuf.internal.decoder',
- 'google.protobuf.internal.encoder',
- 'google.protobuf.internal.enum_type_wrapper',
- 'google.protobuf.internal.message_listener',
- 'google.protobuf.internal.python_message',
- 'google.protobuf.internal.type_checkers',
- 'google.protobuf.internal.wire_format',
- 'google.protobuf.descriptor',
- 'google.protobuf.descriptor_pb2',
- 'google.protobuf.compiler.plugin_pb2',
- 'google.protobuf.message',
- 'google.protobuf.descriptor_database',
- 'google.protobuf.descriptor_pool',
- 'google.protobuf.message_factory',
- 'google.protobuf.pyext.cpp_message',
- 'google.protobuf.reflection',
- 'google.protobuf.service',
- 'google.protobuf.service_reflection',
- 'google.protobuf.symbol_database',
- 'google.protobuf.text_encoding',
- 'google.protobuf.text_format'],
- cmdclass = { 'clean': clean, 'build_py': build_py },
- install_requires = ['setuptools'],
- setup_requires = ['google-apputils'],
- ext_modules = ext_module_list,
- url = 'https://developers.google.com/protocol-buffers/',
- maintainer = maintainer_email,
- maintainer_email = 'protobuf@googlegroups.com',
- license = 'New BSD License',
- description = 'Protocol Buffers',
- long_description =
- "Protocol Buffers are Google's data interchange format.",
- )
+ ext_module_list.extend([
+ Extension(
+ "google.protobuf.pyext._message",
+ glob.glob('google/protobuf/pyext/*.cc'),
+ include_dirs=[".", "../src"],
+ libraries=libraries,
+ extra_objects=extra_objects,
+ library_dirs=['../src/.libs'],
+ extra_compile_args=extra_compile_args,
+ ),
+ Extension(
+ "google.protobuf.internal._api_implementation",
+ glob.glob('google/protobuf/internal/api_implementation.cc'),
+ extra_compile_args=['-DPYTHON_PROTO2_CPP_IMPL_V2'],
+ ),
+ ])
+ os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'cpp'
+
+ # Keep this list of dependencies in sync with tox.ini.
+ install_requires = ['six>=1.9', 'setuptools']
+ if sys.version_info <= (2,7):
+ install_requires.append('ordereddict')
+ install_requires.append('unittest2')
+
+ setup(
+ name='protobuf',
+ version=GetVersion(),
+ description='Protocol Buffers',
+ long_description="Protocol Buffers are Google's data interchange format",
+ url='https://developers.google.com/protocol-buffers/',
+ maintainer='protobuf@googlegroups.com',
+ maintainer_email='protobuf@googlegroups.com',
+ license='New BSD License',
+ classifiers=[
+ "Programming Language :: Python",
+ "Programming Language :: Python :: 2",
+ "Programming Language :: Python :: 2.6",
+ "Programming Language :: Python :: 2.7",
+ "Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3.3",
+ "Programming Language :: Python :: 3.4",
+ ],
+ namespace_packages=['google'],
+ packages=find_packages(
+ exclude=[
+ 'import_test_package',
+ ],
+ ),
+ test_suite='google.protobuf.internal',
+ cmdclass={
+ 'clean': clean,
+ 'build_py': build_py,
+ 'test_conformance': test_conformance,
+ },
+ install_requires=install_requires,
+ ext_modules=ext_module_list,
+ )
diff --git a/python/tox.ini b/python/tox.ini
new file mode 100644
index 000000000..cf8d54016
--- /dev/null
+++ b/python/tox.ini
@@ -0,0 +1,24 @@
+[tox]
+envlist =
+ py{26,27,33,34}-{cpp,python}
+
+[testenv]
+usedevelop=true
+passenv = CC
+setenv =
+ cpp: LD_LIBRARY_PATH={toxinidir}/../src/.libs
+ cpp: DYLD_LIBRARY_PATH={toxinidir}/../src/.libs
+ cpp: PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=cpp
+commands =
+ python setup.py -q build_py
+ python: python setup.py -q build
+ cpp: python setup.py -q build --cpp_implementation --warnings_as_errors
+ python: python setup.py -q test -q
+ cpp: python setup.py -q test -q --cpp_implementation
+ python: python setup.py -q test_conformance
+ cpp: python setup.py -q test_conformance --cpp_implementation
+deps =
+ # Keep this list of dependencies in sync with setup.py.
+ six>=1.9
+ py26: ordereddict
+ py26: unittest2