diff options
author | Qijiang Fan <fqj@google.com> | 2020-07-01 03:06:58 +0000 |
---|---|---|
committer | Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com> | 2020-07-01 03:06:58 +0000 |
commit | d4ab54f37549cae4f62548313fdea4f1a3041613 (patch) | |
tree | b03c4daf3bbbcdc96829053f727f4d7ea25a36e0 | |
parent | 8a8fc45fe2c26ec3d9582f748bf9aa12c6fc0e7a (diff) | |
parent | e2e0f82d2ce199f17606c69b51b16b2d6793a433 (diff) | |
download | platform_external_libbrillo-d4ab54f37549cae4f62548313fdea4f1a3041613.tar.gz platform_external_libbrillo-d4ab54f37549cae4f62548313fdea4f1a3041613.tar.bz2 platform_external_libbrillo-d4ab54f37549cae4f62548313fdea4f1a3041613.zip |
Merge commit 'cf6c031ec38f6932125148cdaceb984b07ec5052' into HEAD am: e2e0f82d2c
Original change: https://android-review.googlesource.com/c/platform/external/libbrillo/+/1353562
Change-Id: I9fc110e7621f37f1cafb644b26d0a0f4c5fcbd4e
224 files changed, 9119 insertions, 1975 deletions
@@ -22,7 +22,6 @@ libbrillo_core_sources = [ "brillo/errors/error.cc", "brillo/errors/error_codes.cc", "brillo/flag_helper.cc", - "brillo/imageloader/manifest.cc", "brillo/key_value_store.cc", "brillo/message_loops/base_message_loop.cc", "brillo/message_loops/message_loop.cc", @@ -85,46 +84,48 @@ libbrillo_test_helpers_sources = [ ] libbrillo_test_sources = [ - "brillo/asynchronous_signal_handler_unittest.cc", - "brillo/backoff_entry_unittest.cc", - "brillo/data_encoding_unittest.cc", - "brillo/enum_flags_unittest.cc", - "brillo/errors/error_codes_unittest.cc", - "brillo/errors/error_unittest.cc", - "brillo/file_utils_unittest.cc", - "brillo/flag_helper_unittest.cc", - "brillo/http/http_connection_curl_unittest.cc", - "brillo/http/http_form_data_unittest.cc", - "brillo/http/http_request_unittest.cc", - "brillo/http/http_transport_curl_unittest.cc", - "brillo/http/http_utils_unittest.cc", - "brillo/imageloader/manifest_unittest.cc", - "brillo/key_value_store_unittest.cc", - "brillo/map_utils_unittest.cc", - "brillo/message_loops/base_message_loop_unittest.cc", - "brillo/message_loops/fake_message_loop_unittest.cc", - "brillo/mime_utils_unittest.cc", - "brillo/osrelease_reader_unittest.cc", - "brillo/process_reaper_unittest.cc", - "brillo/process_unittest.cc", - "brillo/secure_blob_unittest.cc", - "brillo/streams/fake_stream_unittest.cc", - "brillo/streams/file_stream_unittest.cc", - "brillo/streams/input_stream_set_unittest.cc", - "brillo/streams/memory_containers_unittest.cc", - "brillo/streams/memory_stream_unittest.cc", - "brillo/streams/openssl_stream_bio_unittests.cc", - "brillo/streams/stream_unittest.cc", - "brillo/streams/stream_utils_unittest.cc", - "brillo/strings/string_utils_unittest.cc", + "brillo/asynchronous_signal_handler_test.cc", + "brillo/backoff_entry_test.cc", + "brillo/data_encoding_test.cc", + "brillo/enum_flags_test.cc", + "brillo/errors/error_codes_test.cc", + "brillo/errors/error_test.cc", + "brillo/file_utils_test.cc", + "brillo/flag_helper_test.cc", + "brillo/http/http_connection_curl_test.cc", + "brillo/http/http_form_data_test.cc", + "brillo/http/http_request_test.cc", + "brillo/http/http_transport_curl_test.cc", + "brillo/http/http_utils_test.cc", + "brillo/key_value_store_test.cc", + "brillo/map_utils_test.cc", + "brillo/message_loops/base_message_loop_test.cc", + "brillo/message_loops/fake_message_loop_test.cc", + "brillo/mime_utils_test.cc", + "brillo/osrelease_reader_test.cc", + "brillo/process_reaper_test.cc", + "brillo/process_test.cc", + "brillo/secure_blob_test.cc", + "brillo/streams/fake_stream_test.cc", + "brillo/streams/file_stream_test.cc", + "brillo/streams/input_stream_set_test.cc", + "brillo/streams/memory_containers_test.cc", + "brillo/streams/memory_stream_test.cc", + "brillo/streams/openssl_stream_bio_test.cc", + "brillo/streams/stream_test.cc", + "brillo/streams/stream_utils_test.cc", + "brillo/strings/string_utils_test.cc", "brillo/unittest_utils.cc", - "brillo/url_utils_unittest.cc", - "brillo/value_conversion_unittest.cc", + "brillo/url_utils_test.cc", + "brillo/value_conversion_test.cc", ] libbrillo_CFLAGS = [ "-Wall", "-Werror", + "-Wno-non-virtual-dtor", + "-Wno-unused-parameter", + "-Wno-unused-variable", ] libbrillo_shared_libraries = ["libchrome"] diff --git a/BUILD.gn b/BUILD.gn new file mode 100644 index 0000000..f30f945 --- /dev/null +++ b/BUILD.gn @@ -0,0 +1,632 @@ +# Copyright 2019 The Chromium OS Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. + +import("//common-mk/deps.gni") +import("//common-mk/pkg_config.gni") +import("//common-mk/proto_library.gni") + +group("all") { + deps = [ + ":libbrillo-${libbase_ver}", + ":libbrillo-glib-${libbase_ver}", + ":libbrillo-test-${libbase_ver}", + ":libinstallattributes-${libbase_ver}", + ":libpolicy-${libbase_ver}", + ] + if (use.test) { + deps += [ + ":libbrillo-${libbase_ver}_tests", + ":libinstallattributes-${libbase_ver}_tests", + ":libpolicy-${libbase_ver}_tests", + ] + } + if (use.fuzzer) { + deps += [ + ":libbrillo_data_encoding_fuzzer", + ":libbrillo_dbus_data_serialization_fuzzer", + ":libbrillo_http_form_data_fuzzer", + ] + } +} + +default_pkg_deps = [ "libchrome-${libbase_ver}" ] +pkg_config("target_defaults_pkg_deps") { + pkg_deps = default_pkg_deps +} + +config("target_defaults") { + configs = [ ":target_defaults_pkg_deps" ] + include_dirs = [ "../libbrillo" ] + defines = [ + "USE_DBUS=${use.dbus}", + "USE_RTTI_FOR_TYPE_TAGS", + ] +} + +config("libbrillo-${libbase_ver}_configs") { + include_dirs = [ "../libbrillo" ] +} + +# Properties of shared libraries which libbrillo consists of. +# Stored to variables once before actually declaring the targets, so that +# another target can collect information for making the .pc and .so files. +libbrillo_sublibs = [ + { + # |library_name| is library file name without "lib" prefix. This is needed + # for composing -l*** flags in libbrillo-${libbasever}.so. + # (Current version of GN deployed to ChromeOS doesn't have string_replace.) + library_name = "brillo-core-${libbase_ver}" + if (use.dbus) { + all_dependent_pkg_deps = [ "dbus-1" ] + } + libs = [ "modp_b64" ] + sources = [ + "brillo/asynchronous_signal_handler.cc", + "brillo/backoff_entry.cc", + "brillo/daemons/daemon.cc", + "brillo/data_encoding.cc", + "brillo/errors/error.cc", + "brillo/errors/error_codes.cc", + "brillo/file_utils.cc", + "brillo/files/file_util.cc", + "brillo/files/safe_fd.cc", + "brillo/flag_helper.cc", + "brillo/key_value_store.cc", + "brillo/message_loops/base_message_loop.cc", + "brillo/message_loops/message_loop.cc", + "brillo/message_loops/message_loop_utils.cc", + "brillo/mime_utils.cc", + "brillo/osrelease_reader.cc", + "brillo/process.cc", + "brillo/process_information.cc", + "brillo/process_reaper.cc", + "brillo/scoped_mount_namespace.cc", + "brillo/scoped_umask.cc", + "brillo/secure_blob.cc", + "brillo/strings/string_utils.cc", + "brillo/syslog_logging.cc", + "brillo/type_name_undecorate.cc", + "brillo/url_utils.cc", + "brillo/userdb_utils.cc", + "brillo/value_conversion.cc", + ] + if (use.dbus) { + sources += [ + "brillo/any.cc", + "brillo/daemons/dbus_daemon.cc", + "brillo/dbus/async_event_sequencer.cc", + "brillo/dbus/data_serialization.cc", + "brillo/dbus/dbus_connection.cc", + "brillo/dbus/dbus_method_invoker.cc", + "brillo/dbus/dbus_method_response.cc", + "brillo/dbus/dbus_object.cc", + "brillo/dbus/dbus_service_watcher.cc", + "brillo/dbus/dbus_signal.cc", + "brillo/dbus/exported_object_manager.cc", + "brillo/dbus/exported_property_set.cc", + "brillo/dbus/introspectable_helper.cc", + "brillo/dbus/utils.cc", + ] + } + }, + + { + library_name = "brillo-blockdeviceutils-${libbase_ver}" + deps = [ + ":libbrillo-core-${libbase_ver}", + ] + sources = [ + "brillo/blkdev_utils/loop_device.cc", + ] + if (use.device_mapper) { + pkg_deps = [ "devmapper" ] + sources += [ + "brillo/blkdev_utils/device_mapper.cc", + "brillo/blkdev_utils/device_mapper_task.cc", + ] + } + }, + + { + library_name = "brillo-http-${libbase_ver}" + deps = [ + ":libbrillo-core-${libbase_ver}", + ":libbrillo-streams-${libbase_ver}", + ] + all_dependent_pkg_deps = [ "libcurl" ] + sources = [ + "brillo/http/curl_api.cc", + "brillo/http/http_connection_curl.cc", + "brillo/http/http_form_data.cc", + "brillo/http/http_request.cc", + "brillo/http/http_transport.cc", + "brillo/http/http_transport_curl.cc", + "brillo/http/http_utils.cc", + ] + if (use.dbus) { + sources += [ "brillo/http/http_proxy.cc" ] + } + }, + + { + library_name = "brillo-streams-${libbase_ver}" + deps = [ + ":libbrillo-core-${libbase_ver}", + ] + all_dependent_pkg_deps = [ "openssl" ] + sources = [ + "brillo/streams/file_stream.cc", + "brillo/streams/input_stream_set.cc", + "brillo/streams/memory_containers.cc", + "brillo/streams/memory_stream.cc", + "brillo/streams/openssl_stream_bio.cc", + "brillo/streams/stream.cc", + "brillo/streams/stream_errors.cc", + "brillo/streams/stream_utils.cc", + "brillo/streams/tls_stream.cc", + ] + }, + + { + library_name = "brillo-cryptohome-${libbase_ver}" + all_dependent_pkg_deps = [ "openssl" ] + sources = [ + "brillo/cryptohome.cc", + ] + }, + + { + library_name = "brillo-minijail-${libbase_ver}" + all_dependent_pkg_deps = [ "libminijail" ] + sources = [ + "brillo/minijail/minijail.cc", + ] + }, + + { + library_name = "brillo-protobuf-${libbase_ver}" + all_dependent_pkg_deps = [ "protobuf" ] + sources = [ + "brillo/proto_file_io.cc", + ] + }, +] + +if (use.udev) { + libbrillo_sublibs += [ + { + library_name = "brillo-udev-${libbase_ver}" + all_dependent_pkg_deps = [ "libudev" ] + sources = [ + "brillo/udev/udev.cc", + "brillo/udev/udev_device.cc", + "brillo/udev/udev_enumerate.cc", + "brillo/udev/udev_list_entry.cc", + "brillo/udev/udev_monitor.cc", + ] + }, + ] +} + +# Generate shared libraries. +foreach(attr, libbrillo_sublibs) { + shared_library("lib" + attr.library_name) { + sources = attr.sources + if (defined(attr.deps)) { + deps = attr.deps + } + if (defined(attr.libs)) { + libs = attr.libs + } + if (defined(attr.pkg_deps)) { + pkg_deps = attr.pkg_deps + } + if (defined(attr.public_pkg_deps)) { + public_pkg_deps = attr.public_pkg_deps + } + if (defined(attr.all_dependent_pkg_deps)) { + all_dependent_pkg_deps = attr.all_dependent_pkg_deps + } + if (defined(attr.cflags)) { + cflags = attr.cflags + } + if (defined(attr.configs)) { + configs += attr.configs + } + configs += [ ":target_defaults" ] + } +} + +generate_pkg_config("libbrillo_pc") { + name = "libbrillo" + output_name = "libbrillo-${libbase_ver}" + description = "brillo base library" + version = libbase_ver + requires_private = default_pkg_deps + foreach(sublib, libbrillo_sublibs) { + if (defined(sublib.pkg_deps)) { + requires_private += sublib.pkg_deps + } + if (defined(sublib.public_pkg_deps)) { + requires_private += sublib.public_pkg_deps + } + if (defined(sublib.all_dependent_pkg_deps)) { + requires_private += sublib.all_dependent_pkg_deps + } + } + defines = [ "USE_RTTI_FOR_TYPE_TAGS" ] + libs = [ "-lbrillo-${libbase_ver}" ] +} + +action("libbrillo-${libbase_ver}") { + deps = [ + ":libbrillo_pc", + ] + foreach(sublib, libbrillo_sublibs) { + deps += [ ":lib" + sublib.library_name ] + } + script = "//common-mk/write_args.py" + outputs = [ + "${root_out_dir}/lib/libbrillo-${libbase_ver}.so", + ] + args = [ "--output" ] + outputs + [ "--" ] + [ + "GROUP", + "(", + "AS_NEEDED", + "(", + ] + foreach(sublib, libbrillo_sublibs) { + args += [ "-l" + sublib.library_name ] + } + args += [ + ")", + ")", + ] +} + +libbrillo_test_deps = [ "libbrillo-http-${libbase_ver}" ] + +generate_pkg_config("libbrillo-test_pc") { + name = "libbrillo-test" + output_name = "libbrillo-test-${libbase_ver}" + description = "brillo test library" + version = libbase_ver + + # Because libbrillo-test is static, we have to depend directly on everything. + requires = [ "libbrillo-${libbase_ver}" ] + default_pkg_deps + foreach(name, libbrillo_test_deps) { + foreach(t, libbrillo_sublibs) { + if ("lib" + t.library_name == name) { + if (defined(t.pkg_deps)) { + requires += t.pkg_deps + } + if (defined(t.public_pkg_deps)) { + requires += t.public_pkg_deps + } + if (defined(t.all_dependent_pkg_deps)) { + requires += t.all_dependent_pkg_deps + } + } + } + } + libs = [ "-lbrillo-test-${libbase_ver}" ] +} + +static_library("libbrillo-test-${libbase_ver}") { + configs -= [ "//common-mk:use_thin_archive" ] + configs += [ + "//common-mk:nouse_thin_archive", + ":target_defaults", + ] + deps = [ + ":libbrillo-http-${libbase_ver}", + ":libbrillo-test_pc", + ] + foreach(name, libbrillo_test_deps) { + deps += [ ":" + name ] + } + sources = [ + "brillo/blkdev_utils/loop_device_fake.cc", + "brillo/http/http_connection_fake.cc", + "brillo/http/http_transport_fake.cc", + "brillo/message_loops/fake_message_loop.cc", + "brillo/streams/fake_stream.cc", + "brillo/unittest_utils.cc", + ] + if (use.device_mapper) { + sources += [ "brillo/blkdev_utils/device_mapper_fake.cc" ] + } +} + +shared_library("libinstallattributes-${libbase_ver}") { + configs += [ ":target_defaults" ] + deps = [ + ":libinstallattributes-includes", + "../common-mk/external_dependencies:install_attributes-proto", + ] + all_dependent_pkg_deps = [ "protobuf-lite" ] + sources = [ + "install_attributes/libinstallattributes.cc", + ] +} + +shared_library("libpolicy-${libbase_ver}") { + configs += [ ":target_defaults" ] + deps = [ + ":libinstallattributes-${libbase_ver}", + ":libpolicy-includes", + "../common-mk/external_dependencies:policy-protos", + ] + all_dependent_pkg_deps = [ + "openssl", + "protobuf-lite", + ] + ldflags = [ "-Wl,--version-script,${platform2_root}/libbrillo/libpolicy.ver" ] + sources = [ + "policy/device_policy.cc", + "policy/device_policy_impl.cc", + "policy/libpolicy.cc", + "policy/policy_util.cc", + "policy/resilient_policy_util.cc", + ] +} + +libbrillo_glib_pkg_deps = [ + "glib-2.0", + "gobject-2.0", +] +if (use.dbus) { + libbrillo_glib_pkg_deps += [ + "dbus-1", + "dbus-glib-1", + ] +} + +generate_pkg_config("libbrillo-glib_pc") { + name = "libbrillo-glib" + output_name = "libbrillo-glib-${libbase_ver}" + description = "brillo glib wrapper library" + version = libbase_ver + requires_private = libbrillo_glib_pkg_deps + libs = [ "-lbrillo-glib-${libbase_ver}" ] +} + +shared_library("libbrillo-glib-${libbase_ver}") { + configs += [ ":target_defaults" ] + deps = [ + ":libbrillo-${libbase_ver}", + ":libbrillo-glib_pc", + ] + all_dependent_pkg_deps = libbrillo_glib_pkg_deps + if (use.dbus) { + sources = [ + "brillo/glib/abstract_dbus_service.cc", + "brillo/glib/dbus.cc", + ] + } + cflags = [ + # glib uses the deprecated "register" attribute in some header files. + "-Wno-deprecated-register", + ] +} + +if (use.test) { + static_library("libbrillo-${libbase_ver}_static") { + configs += [ ":target_defaults" ] + deps = [ + ":libbrillo_pc", + ":libinstallattributes-${libbase_ver}", + ":libpolicy-${libbase_ver}", + ] + foreach(sublib, libbrillo_sublibs) { + deps += [ ":lib" + sublib.library_name ] + } + public_configs = [ ":libbrillo-${libbase_ver}_configs" ] + } + proto_library("libbrillo-${libbase_ver}_tests_proto") { + proto_in_dir = "brillo/dbus" + proto_out_dir = "include/brillo/dbus" + sources = [ + "${proto_in_dir}/test.proto", + ] + } + executable("libbrillo-${libbase_ver}_tests") { + configs += [ + "//common-mk:test", + ":target_defaults", + ] + deps = [ + ":libbrillo-${libbase_ver}_static", + ":libbrillo-${libbase_ver}_tests_proto", + ":libbrillo-glib-${libbase_ver}", + ":libbrillo-test-${libbase_ver}", + ] + pkg_deps = [ "libchrome-test-${libbase_ver}" ] + cflags = [ "-Wno-format-zero-length" ] + + if (is_debug) { + cflags += [ + "-fprofile-arcs", + "-ftest-coverage", + "-fno-inline", + ] + libs = [ "gcov" ] + } + sources = [ + "brillo/asynchronous_signal_handler_test.cc", + "brillo/backoff_entry_test.cc", + "brillo/blkdev_utils/loop_device_test.cc", + "brillo/data_encoding_test.cc", + "brillo/enum_flags_test.cc", + "brillo/errors/error_codes_test.cc", + "brillo/errors/error_test.cc", + "brillo/file_utils_test.cc", + "brillo/files/file_util_test.cc", + "brillo/files/safe_fd_test.cc", + "brillo/flag_helper_test.cc", + "brillo/glib/object_test.cc", + "brillo/http/http_connection_curl_test.cc", + "brillo/http/http_form_data_test.cc", + "brillo/http/http_request_test.cc", + "brillo/http/http_transport_curl_test.cc", + "brillo/http/http_utils_test.cc", + "brillo/key_value_store_test.cc", + "brillo/map_utils_test.cc", + "brillo/message_loops/base_message_loop_test.cc", + "brillo/message_loops/fake_message_loop_test.cc", + "brillo/message_loops/message_loop_test.cc", + "brillo/mime_utils_test.cc", + "brillo/osrelease_reader_test.cc", + "brillo/process_reaper_test.cc", + "brillo/process_test.cc", + "brillo/scoped_umask_test.cc", + "brillo/secure_blob_test.cc", + "brillo/streams/fake_stream_test.cc", + "brillo/streams/file_stream_test.cc", + "brillo/streams/input_stream_set_test.cc", + "brillo/streams/memory_containers_test.cc", + "brillo/streams/memory_stream_test.cc", + "brillo/streams/openssl_stream_bio_test.cc", + "brillo/streams/stream_test.cc", + "brillo/streams/stream_utils_test.cc", + "brillo/strings/string_utils_test.cc", + "brillo/unittest_utils.cc", + "brillo/url_utils_test.cc", + "brillo/value_conversion_test.cc", + "testrunner.cc", + ] + if (use.dbus) { + sources += [ + "brillo/any_internal_impl_test.cc", + "brillo/any_test.cc", + "brillo/dbus/async_event_sequencer_test.cc", + "brillo/dbus/data_serialization_test.cc", + "brillo/dbus/dbus_method_invoker_test.cc", + "brillo/dbus/dbus_object_test.cc", + "brillo/dbus/dbus_param_reader_test.cc", + "brillo/dbus/dbus_param_writer_test.cc", + "brillo/dbus/dbus_signal_handler_test.cc", + "brillo/dbus/exported_object_manager_test.cc", + "brillo/dbus/exported_property_set_test.cc", + "brillo/http/http_proxy_test.cc", + "brillo/type_name_undecorate_test.cc", + "brillo/variant_dictionary_test.cc", + ] + } + if (use.device_mapper) { + sources += [ "brillo/blkdev_utils/device_mapper_test.cc" ] + } + } + + executable("libinstallattributes-${libbase_ver}_tests") { + configs += [ + "//common-mk:test", + ":target_defaults", + ] + deps = [ + ":libinstallattributes-${libbase_ver}", + "../common-mk/external_dependencies:install_attributes-proto", + "../common-mk/testrunner:testrunner", + ] + sources = [ + "install_attributes/tests/libinstallattributes_test.cc", + ] + } + + executable("libpolicy-${libbase_ver}_tests") { + configs += [ + "//common-mk:test", + ":target_defaults", + ] + deps = [ + ":libinstallattributes-${libbase_ver}", + ":libpolicy-${libbase_ver}", + "../common-mk/external_dependencies:install_attributes-proto", + "../common-mk/external_dependencies:policy-protos", + "../common-mk/testrunner:testrunner", + ] + sources = [ + "install_attributes/mock_install_attributes_reader.cc", + "policy/tests/device_policy_impl_test.cc", + "policy/tests/libpolicy_test.cc", + "policy/tests/policy_util_test.cc", + "policy/tests/resilient_policy_util_test.cc", + ] + } +} + +if (use.fuzzer) { + executable("libbrillo_data_encoding_fuzzer") { + sources = [ + "brillo/data_encoding_fuzzer.cc", + ] + + configs += [ "//common-mk/common_fuzzer:common_fuzzer" ] + + pkg_deps = [ "libchrome-${libbase_ver}" ] + + include_dirs = [ "../libbrillo" ] + + deps = [ + ":libbrillo-core-${libbase_ver}", + ] + } + + executable("libbrillo_dbus_data_serialization_fuzzer") { + sources = [ + "brillo/dbus/data_serialization_fuzzer.cc", + ] + + configs += [ "//common-mk/common_fuzzer:common_fuzzer" ] + + pkg_deps = [ "libchrome-${libbase_ver}" ] + + include_dirs = [ "../libbrillo" ] + + deps = [ + ":libbrillo-core-${libbase_ver}", + ] + } + + executable("libbrillo_http_form_data_fuzzer") { + sources = [ + "brillo/http/http_form_data_fuzzer.cc", + ] + + configs += [ "//common-mk/common_fuzzer:common_fuzzer" ] + + pkg_deps = [ "libchrome-${libbase_ver}" ] + + include_dirs = [ "../libbrillo" ] + + deps = [ + ":libbrillo-http-${libbase_ver}", + ":libbrillo-streams-${libbase_ver}", + ] + } +} + +copy("libinstallattributes-includes") { + sources = [ + "install_attributes/libinstallattributes.h", + ] + outputs = [ + "${root_gen_dir}/include/install_attributes/{{source_file_part}}", + ] +} + +copy("libpolicy-includes") { + sources = [ + "policy/device_policy.h", + "policy/device_policy_impl.h", + "policy/libpolicy.h", + "policy/mock_device_policy.h", + "policy/mock_libpolicy.h", + "policy/policy_util.h", + "policy/resilient_policy_util.h", + ] + outputs = [ + "${root_gen_dir}/include/policy/{{source_file_part}}", + ] +} diff --git a/README.md b/README.md new file mode 100644 index 0000000..118e3f1 --- /dev/null +++ b/README.md @@ -0,0 +1,20 @@ +# libbrillo: platform utility library + +libbrillo is a shared library meant to hold common utility code that we deem +useful for platform projects. +It supplements the functionality provided by libbase/libchrome since that +project, by design, only holds functionality that Chromium (the browser) needs. +As a result, this tends to be more OS-centric code. + +## AOSP Usage + +This project is also used by [Update Engine] which is maintained in AOSP. +However, AOSP doesn't use this codebase directly, it maintains its own +[libbrillo fork](https://android.googlesource.com/platform/external/libbrillo/). + +To help keep the projects in sync, we have a gsubtree set up on our GoB: +https://chromium.googlesource.com/chromiumos/platform2/libbrillo/ + +This allows AOSP to cherry pick or merge changes directly back into their fork. + +[Update Engine]: https://android.googlesource.com/platform/system/update_engine/ diff --git a/brillo/any.cc b/brillo/any.cc index f84badf..b5ac84f 100644 --- a/brillo/any.cc +++ b/brillo/any.cc @@ -5,6 +5,7 @@ #include <brillo/any.h> #include <algorithm> +#include <utility> namespace brillo { diff --git a/brillo/any.h b/brillo/any.h index 51016b5..d41dd4a 100644 --- a/brillo/any.h +++ b/brillo/any.h @@ -18,7 +18,7 @@ // use helper functions std::ref() and std::cref() to create non-const and // const references respectively. In such a case, the type of contained data // will be std::reference_wrapper<T>. See 'References' unit tests in -// any_unittest.cc for examples. +// any_test.cc for examples. #ifndef LIBBRILLO_BRILLO_ANY_H_ #define LIBBRILLO_BRILLO_ANY_H_ @@ -26,6 +26,8 @@ #include <brillo/any_internal_impl.h> #include <algorithm> +#include <string> +#include <utility> #include <brillo/brillo_export.h> #include <brillo/type_name_undecorate.h> @@ -189,7 +191,7 @@ class BRILLO_EXPORT Any final { // (an appropriate specialization of AppendValueToWriter<T>() is available). // Returns false if the Any is empty or if there is no serialization method // defined for the contained data. - void AppendToDBusMessageWriter(dbus::MessageWriter* writer) const; + void AppendToDBusMessageWriter(::dbus::MessageWriter* writer) const; private: // Returns a pointer to a static buffer containing type tag (sort of a type diff --git a/brillo/any_internal_impl.h b/brillo/any_internal_impl.h index 9309f5d..f4114e6 100644 --- a/brillo/any_internal_impl.h +++ b/brillo/any_internal_impl.h @@ -154,7 +154,7 @@ struct Data { // Gets the contained integral value as an integer. virtual intmax_t GetAsInteger() const = 0; // Writes the contained value to the D-Bus message buffer. - virtual void AppendToDBusMessage(dbus::MessageWriter* writer) const = 0; + virtual void AppendToDBusMessage(::dbus::MessageWriter* writer) const = 0; // Compares if the two data containers have objects of the same value. virtual bool CompareEqual(const Data* other_data) const = 0; }; @@ -180,19 +180,19 @@ struct TypedData : public Data { return int_val; } - template<typename U> + template <typename U> static typename std::enable_if<dbus_utils::IsTypeSupported<U>::value>::type - AppendValueHelper(dbus::MessageWriter* writer, const U& value) { + AppendValueHelper(::dbus::MessageWriter* writer, const U& value) { brillo::dbus_utils::AppendValueToWriterAsVariant(writer, value); } - template<typename U> + template <typename U> static typename std::enable_if<!dbus_utils::IsTypeSupported<U>::value>::type - AppendValueHelper(dbus::MessageWriter* /* writer */, const U& /* value */) { + AppendValueHelper(::dbus::MessageWriter* /* writer */, const U& /* value */) { LOG(FATAL) << "Type '" << GetUndecoratedTypeName<U>() << "' is not supported by D-Bus"; } - void AppendToDBusMessage(dbus::MessageWriter* writer) const override { + void AppendToDBusMessage(::dbus::MessageWriter* writer) const override { return AppendValueHelper(writer, value_); } diff --git a/brillo/any_internal_impl_unittest.cc b/brillo/any_internal_impl_test.cc index 6f7f512..6f7f512 100644 --- a/brillo/any_internal_impl_unittest.cc +++ b/brillo/any_internal_impl_test.cc diff --git a/brillo/any_unittest.cc b/brillo/any_test.cc index db89884..936235e 100644 --- a/brillo/any_unittest.cc +++ b/brillo/any_test.cc @@ -5,6 +5,7 @@ #include <algorithm> #include <functional> #include <string> +#include <utility> #include <vector> #include <brillo/any.h> diff --git a/brillo/array_utils.h b/brillo/array_utils.h new file mode 100644 index 0000000..d180d35 --- /dev/null +++ b/brillo/array_utils.h @@ -0,0 +1,26 @@ +// Copyright 2019 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef LIBBRILLO_BRILLO_ARRAY_UTILS_H_ +#define LIBBRILLO_BRILLO_ARRAY_UTILS_H_ + +#include <array> +#include <utility> + +namespace brillo { + +// Create a std::array from a set of values without manually specifying the +// size of the array. Note that unlike the make_array likely to make its way +// into C++20, this function always requires the user to specify ElementType. +// This is done so that users are not surprised by the element type of resulting +// arrays when std::common_type is used. +template <typename ElementType, typename... T> +constexpr auto make_array(T&&... values) { + return std::array<ElementType, sizeof...(T)>{ + static_cast<ElementType>(std::forward<T>(values))...}; +} + +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_ARRAY_UTILS_H_ diff --git a/brillo/asan.h b/brillo/asan.h index 9a73202..d29932a 100644 --- a/brillo/asan.h +++ b/brillo/asan.h @@ -17,5 +17,4 @@ #define BRILLO_DISABLE_ASAN #endif -#endif - +#endif // LIBBRILLO_BRILLO_ASAN_H_ diff --git a/brillo/asynchronous_signal_handler_unittest.cc b/brillo/asynchronous_signal_handler_test.cc index ec3b061..2211b9c 100644 --- a/brillo/asynchronous_signal_handler_unittest.cc +++ b/brillo/asynchronous_signal_handler_test.cc @@ -113,6 +113,8 @@ TEST_F(AsynchronousSignalHandlerTest, CheckMultipleSignal) { } } +// TODO(crbug/1011829): This test is flaky. +#if 0 TEST_F(AsynchronousSignalHandlerTest, CheckChld) { handler_.RegisterHandler( SIGCHLD, @@ -134,5 +136,6 @@ TEST_F(AsynchronousSignalHandlerTest, CheckChld) { EXPECT_EQ(static_cast<int>(CLD_EXITED), infos_[0].ssi_code); EXPECT_EQ(EXIT_SUCCESS, infos_[0].ssi_status); } +#endif } // namespace brillo diff --git a/brillo/backoff_entry_test.cc b/brillo/backoff_entry_test.cc new file mode 100644 index 0000000..6a95bc0 --- /dev/null +++ b/brillo/backoff_entry_test.cc @@ -0,0 +1,311 @@ +// Copyright 2015 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <brillo/backoff_entry.h> +#include <gtest/gtest.h> + +using base::TimeDelta; +using base::TimeTicks; + +namespace brillo { + +BackoffEntry::Policy base_policy = { 0, 1000, 2.0, 0.0, 20000, 2000, false }; + +class TestBackoffEntry : public BackoffEntry { + public: + explicit TestBackoffEntry(const Policy* const policy) + : BackoffEntry(policy), + now_(TimeTicks()) { + // Work around initialization in constructor not picking up + // fake time. + SetCustomReleaseTime(TimeTicks()); + } + + ~TestBackoffEntry() override {} + + TimeTicks ImplGetTimeNow() const override { return now_; } + + void set_now(const TimeTicks& now) { + now_ = now; + } + + private: + TimeTicks now_; + + DISALLOW_COPY_AND_ASSIGN(TestBackoffEntry); +}; + +TEST(BackoffEntryTest, BaseTest) { + TestBackoffEntry entry(&base_policy); + EXPECT_FALSE(entry.ShouldRejectRequest()); + EXPECT_EQ(TimeDelta(), entry.GetTimeUntilRelease()); + + entry.InformOfRequest(false); + EXPECT_TRUE(entry.ShouldRejectRequest()); + EXPECT_EQ(TimeDelta::FromMilliseconds(1000), entry.GetTimeUntilRelease()); +} + +TEST(BackoffEntryTest, CanDiscardNeverExpires) { + BackoffEntry::Policy never_expires_policy = base_policy; + never_expires_policy.entry_lifetime_ms = -1; + TestBackoffEntry never_expires(&never_expires_policy); + EXPECT_FALSE(never_expires.CanDiscard()); + never_expires.set_now(TimeTicks() + TimeDelta::FromDays(100)); + EXPECT_FALSE(never_expires.CanDiscard()); +} + +TEST(BackoffEntryTest, CanDiscard) { + TestBackoffEntry entry(&base_policy); + // Because lifetime is non-zero, we shouldn't be able to discard yet. + EXPECT_FALSE(entry.CanDiscard()); + + // Test the "being used" case. + entry.InformOfRequest(false); + EXPECT_FALSE(entry.CanDiscard()); + + // Test the case where there are errors but we can time out. + entry.set_now( + entry.GetReleaseTime() + TimeDelta::FromMilliseconds(1)); + EXPECT_FALSE(entry.CanDiscard()); + entry.set_now(entry.GetReleaseTime() + TimeDelta::FromMilliseconds( + base_policy.maximum_backoff_ms + 1)); + EXPECT_TRUE(entry.CanDiscard()); + + // Test the final case (no errors, dependent only on specified lifetime). + entry.set_now(entry.GetReleaseTime() + TimeDelta::FromMilliseconds( + base_policy.entry_lifetime_ms - 1)); + entry.InformOfRequest(true); + EXPECT_FALSE(entry.CanDiscard()); + entry.set_now(entry.GetReleaseTime() + TimeDelta::FromMilliseconds( + base_policy.entry_lifetime_ms)); + EXPECT_TRUE(entry.CanDiscard()); +} + +TEST(BackoffEntryTest, CanDiscardAlwaysDelay) { + BackoffEntry::Policy always_delay_policy = base_policy; + always_delay_policy.always_use_initial_delay = true; + always_delay_policy.entry_lifetime_ms = 0; + + TestBackoffEntry entry(&always_delay_policy); + + // Because lifetime is non-zero, we shouldn't be able to discard yet. + entry.set_now(entry.GetReleaseTime() + TimeDelta::FromMilliseconds(2000)); + EXPECT_TRUE(entry.CanDiscard()); + + // Even with no failures, we wait until the delay before we allow discard. + entry.InformOfRequest(true); + EXPECT_FALSE(entry.CanDiscard()); + + // Wait until the delay expires, and we can discard the entry again. + entry.set_now(entry.GetReleaseTime() + TimeDelta::FromMilliseconds(1000)); + EXPECT_TRUE(entry.CanDiscard()); +} + +TEST(BackoffEntryTest, CanDiscardNotStored) { + BackoffEntry::Policy no_store_policy = base_policy; + no_store_policy.entry_lifetime_ms = 0; + TestBackoffEntry not_stored(&no_store_policy); + EXPECT_TRUE(not_stored.CanDiscard()); +} + +TEST(BackoffEntryTest, ShouldIgnoreFirstTwo) { + BackoffEntry::Policy lenient_policy = base_policy; + lenient_policy.num_errors_to_ignore = 2; + + BackoffEntry entry(&lenient_policy); + + entry.InformOfRequest(false); + EXPECT_FALSE(entry.ShouldRejectRequest()); + + entry.InformOfRequest(false); + EXPECT_FALSE(entry.ShouldRejectRequest()); + + entry.InformOfRequest(false); + EXPECT_TRUE(entry.ShouldRejectRequest()); +} + +TEST(BackoffEntryTest, ReleaseTimeCalculation) { + TestBackoffEntry entry(&base_policy); + + // With zero errors, should return "now". + TimeTicks result = entry.GetReleaseTime(); + EXPECT_EQ(entry.ImplGetTimeNow(), result); + + // 1 error. + entry.InformOfRequest(false); + result = entry.GetReleaseTime(); + EXPECT_EQ(entry.ImplGetTimeNow() + TimeDelta::FromMilliseconds(1000), result); + EXPECT_EQ(TimeDelta::FromMilliseconds(1000), entry.GetTimeUntilRelease()); + + // 2 errors. + entry.InformOfRequest(false); + result = entry.GetReleaseTime(); + EXPECT_EQ(entry.ImplGetTimeNow() + TimeDelta::FromMilliseconds(2000), result); + EXPECT_EQ(TimeDelta::FromMilliseconds(2000), entry.GetTimeUntilRelease()); + + // 3 errors. + entry.InformOfRequest(false); + result = entry.GetReleaseTime(); + EXPECT_EQ(entry.ImplGetTimeNow() + TimeDelta::FromMilliseconds(4000), result); + EXPECT_EQ(TimeDelta::FromMilliseconds(4000), entry.GetTimeUntilRelease()); + + // 6 errors (to check it doesn't pass maximum). + entry.InformOfRequest(false); + entry.InformOfRequest(false); + entry.InformOfRequest(false); + result = entry.GetReleaseTime(); + EXPECT_EQ( + entry.ImplGetTimeNow() + TimeDelta::FromMilliseconds(20000), result); +} + +TEST(BackoffEntryTest, ReleaseTimeCalculationAlwaysDelay) { + BackoffEntry::Policy always_delay_policy = base_policy; + always_delay_policy.always_use_initial_delay = true; + always_delay_policy.num_errors_to_ignore = 2; + + TestBackoffEntry entry(&always_delay_policy); + + // With previous requests, should return "now". + TimeTicks result = entry.GetReleaseTime(); + EXPECT_EQ(TimeDelta(), entry.GetTimeUntilRelease()); + + // 1 error. + entry.InformOfRequest(false); + EXPECT_EQ(TimeDelta::FromMilliseconds(1000), entry.GetTimeUntilRelease()); + + // 2 errors. + entry.InformOfRequest(false); + EXPECT_EQ(TimeDelta::FromMilliseconds(1000), entry.GetTimeUntilRelease()); + + // 3 errors, exponential backoff starts. + entry.InformOfRequest(false); + EXPECT_EQ(TimeDelta::FromMilliseconds(2000), entry.GetTimeUntilRelease()); + + // 4 errors. + entry.InformOfRequest(false); + EXPECT_EQ(TimeDelta::FromMilliseconds(4000), entry.GetTimeUntilRelease()); + + // 8 errors (to check it doesn't pass maximum). + entry.InformOfRequest(false); + entry.InformOfRequest(false); + entry.InformOfRequest(false); + entry.InformOfRequest(false); + result = entry.GetReleaseTime(); + EXPECT_EQ(TimeDelta::FromMilliseconds(20000), entry.GetTimeUntilRelease()); +} + +TEST(BackoffEntryTest, ReleaseTimeCalculationWithJitter) { + for (int i = 0; i < 10; ++i) { + BackoffEntry::Policy jittery_policy = base_policy; + jittery_policy.jitter_factor = 0.2; + + TestBackoffEntry entry(&jittery_policy); + + entry.InformOfRequest(false); + entry.InformOfRequest(false); + entry.InformOfRequest(false); + TimeTicks result = entry.GetReleaseTime(); + EXPECT_LE( + entry.ImplGetTimeNow() + TimeDelta::FromMilliseconds(3200), result); + EXPECT_GE( + entry.ImplGetTimeNow() + TimeDelta::FromMilliseconds(4000), result); + } +} + +TEST(BackoffEntryTest, FailureThenSuccess) { + TestBackoffEntry entry(&base_policy); + + // Failure count 1, establishes horizon. + entry.InformOfRequest(false); + TimeTicks release_time = entry.GetReleaseTime(); + EXPECT_EQ(TimeTicks() + TimeDelta::FromMilliseconds(1000), release_time); + + // Success, failure count 0, should not advance past + // the horizon that was already set. + entry.set_now(release_time - TimeDelta::FromMilliseconds(200)); + entry.InformOfRequest(true); + EXPECT_EQ(release_time, entry.GetReleaseTime()); + + // Failure, failure count 1. + entry.InformOfRequest(false); + EXPECT_EQ(release_time + TimeDelta::FromMilliseconds(800), + entry.GetReleaseTime()); +} + +TEST(BackoffEntryTest, FailureThenSuccessAlwaysDelay) { + BackoffEntry::Policy always_delay_policy = base_policy; + always_delay_policy.always_use_initial_delay = true; + always_delay_policy.num_errors_to_ignore = 1; + + TestBackoffEntry entry(&always_delay_policy); + + // Failure count 1. + entry.InformOfRequest(false); + EXPECT_EQ(TimeDelta::FromMilliseconds(1000), entry.GetTimeUntilRelease()); + + // Failure count 2. + entry.InformOfRequest(false); + EXPECT_EQ(TimeDelta::FromMilliseconds(2000), entry.GetTimeUntilRelease()); + entry.set_now(entry.GetReleaseTime() + TimeDelta::FromMilliseconds(2000)); + + // Success. We should go back to the original delay. + entry.InformOfRequest(true); + EXPECT_EQ(TimeDelta::FromMilliseconds(1000), entry.GetTimeUntilRelease()); + + // Failure count reaches 2 again. We should increase the delay once more. + entry.InformOfRequest(false); + EXPECT_EQ(TimeDelta::FromMilliseconds(2000), entry.GetTimeUntilRelease()); + entry.set_now(entry.GetReleaseTime() + TimeDelta::FromMilliseconds(2000)); +} + +TEST(BackoffEntryTest, RetainCustomHorizon) { + TestBackoffEntry custom(&base_policy); + TimeTicks custom_horizon = TimeTicks() + TimeDelta::FromDays(3); + custom.SetCustomReleaseTime(custom_horizon); + custom.InformOfRequest(false); + custom.InformOfRequest(true); + custom.set_now(TimeTicks() + TimeDelta::FromDays(2)); + custom.InformOfRequest(false); + custom.InformOfRequest(true); + EXPECT_EQ(custom_horizon, custom.GetReleaseTime()); + + // Now check that once we are at or past the custom horizon, + // we get normal behavior. + custom.set_now(TimeTicks() + TimeDelta::FromDays(3)); + custom.InformOfRequest(false); + EXPECT_EQ( + TimeTicks() + TimeDelta::FromDays(3) + TimeDelta::FromMilliseconds(1000), + custom.GetReleaseTime()); +} + +TEST(BackoffEntryTest, RetainCustomHorizonWhenInitialErrorsIgnored) { + // Regression test for a bug discovered during code review. + BackoffEntry::Policy lenient_policy = base_policy; + lenient_policy.num_errors_to_ignore = 1; + TestBackoffEntry custom(&lenient_policy); + TimeTicks custom_horizon = TimeTicks() + TimeDelta::FromDays(3); + custom.SetCustomReleaseTime(custom_horizon); + custom.InformOfRequest(false); // This must not reset the horizon. + EXPECT_EQ(custom_horizon, custom.GetReleaseTime()); +} + +TEST(BackoffEntryTest, OverflowProtection) { + BackoffEntry::Policy large_multiply_policy = base_policy; + large_multiply_policy.multiply_factor = 256; + TestBackoffEntry custom(&large_multiply_policy); + + // Trigger enough failures such that more than 11 bits of exponent are used + // to represent the exponential backoff intermediate values. Given a multiply + // factor of 256 (2^8), 129 iterations is enough: 2^(8*(129-1)) = 2^1024. + for (int i = 0; i < 129; ++i) { + custom.set_now(custom.ImplGetTimeNow() + custom.GetTimeUntilRelease()); + custom.InformOfRequest(false); + ASSERT_TRUE(custom.ShouldRejectRequest()); + } + + // Max delay should still be respected. + EXPECT_EQ(20000, custom.GetTimeUntilRelease().InMilliseconds()); +} + +} // namespace brillo diff --git a/brillo/bind_lambda.h b/brillo/bind_lambda.h deleted file mode 100644 index 50ac095..0000000 --- a/brillo/bind_lambda.h +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright 2014 The Chromium OS Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef LIBBRILLO_BRILLO_BIND_LAMBDA_H_ -#define LIBBRILLO_BRILLO_BIND_LAMBDA_H_ - -#include <base/bind.h> - -//////////////////////////////////////////////////////////////////////////////// -// This file is an extension to base/bind_internal.h and adds a RunnableAdapter -// class specialization that wraps a functor (including lambda objects), so -// they can be used in base::Callback/base::Bind constructs. -// By including this file you will gain the ability to write expressions like: -// base::Callback<int(int)> callback = base::Bind([](int value) { -// return value * value; -// }); -//////////////////////////////////////////////////////////////////////////////// -namespace base { -namespace internal { - -// LambdaAdapter is a helper class that specializes on different function call -// signatures and provides the RunType and Run() method required by -// RunnableAdapter<> class. -template <typename Lambda, typename Sig> -class LambdaAdapter; - -// R(...) -template <typename Lambda, typename R, typename... Args> -class LambdaAdapter<Lambda, R(Lambda::*)(Args... args)> { - public: - typedef R(RunType)(Args...); - explicit LambdaAdapter(Lambda lambda) : lambda_(lambda) {} - R Run(Args... args) { return lambda_(std::forward<Args>(args)...); } - - private: - Lambda lambda_; -}; - -// R(...) const -template <typename Lambda, typename R, typename... Args> -class LambdaAdapter<Lambda, R(Lambda::*)(Args... args) const> { - public: - typedef R(RunType)(Args...); - explicit LambdaAdapter(Lambda lambda) : lambda_(lambda) {} - R Run(Args... args) { return lambda_(std::forward<Args>(args)...); } - - private: - Lambda lambda_; -}; - -template <typename Lambda> -class RunnableAdapter - : public LambdaAdapter<Lambda, decltype(&Lambda::operator())> { - public: - explicit RunnableAdapter(Lambda lambda) - : LambdaAdapter<Lambda, decltype(&Lambda::operator())>(lambda) {} -}; - -} // namespace internal -} // namespace base - -#endif // LIBBRILLO_BRILLO_BIND_LAMBDA_H_ diff --git a/brillo/blkdev_utils/device_mapper.cc b/brillo/blkdev_utils/device_mapper.cc new file mode 100644 index 0000000..726cd94 --- /dev/null +++ b/brillo/blkdev_utils/device_mapper.cc @@ -0,0 +1,240 @@ +// Copyright 2018 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <brillo/blkdev_utils/device_mapper.h> + +#include <libdevmapper.h> +#include <algorithm> +#include <utility> + +#include <base/files/file_util.h> +#include <base/strings/string_number_conversions.h> +#include <base/strings/string_tokenizer.h> +#include <base/strings/stringprintf.h> +#include <brillo/blkdev_utils/device_mapper_task.h> +#include <brillo/secure_blob.h> + +namespace brillo { + +// Use a tokenizer to parse string data stored in SecureBlob. +// The tokenizer does not store internal state so it should be +// okay to use with SecureBlobs. +// DO NOT USE .toker() as that leaks contents of the SecureBlob. +using SecureBlobTokenizer = + base::StringTokenizerT<std::string, SecureBlob::const_iterator>; + +DevmapperTable::DevmapperTable(uint64_t start, + uint64_t size, + const std::string& type, + const SecureBlob& parameters) + : start_(start), size_(size), type_(type), parameters_(parameters) {} + +SecureBlob DevmapperTable::ToSecureBlob() { + SecureBlob table_blob(base::StringPrintf("%" PRIu64 " %" PRIu64 " %s ", + start_, size_, type_.c_str())); + + return SecureBlob::Combine(table_blob, parameters_); +} + +DevmapperTable DevmapperTable::CreateTableFromSecureBlob( + const SecureBlob& table) { + uint64_t start, size; + std::string type; + DevmapperTable invalid_table(0, 0, "", SecureBlob()); + + SecureBlobTokenizer tokenizer(table.begin(), table.end(), " "); + + // First parameter is start. + if (!tokenizer.GetNext() || + !base::StringToUint64( + std::string(tokenizer.token_begin(), tokenizer.token_end()), &start)) + return invalid_table; + + // Second parameter is size of the dm device. + if (!tokenizer.GetNext() || + !base::StringToUint64( + std::string(tokenizer.token_begin(), tokenizer.token_end()), &size)) + return invalid_table; + + // Third parameter is type of dm device. + if (!tokenizer.GetNext()) + return invalid_table; + + type = std::string(tokenizer.token_begin(), tokenizer.token_end()); + + // The remaining string is the parameters. + if (!tokenizer.GetNext()) + return invalid_table; + + // The remaining part is the parameters passed to the device. + SecureBlob target = SecureBlob(tokenizer.token_begin(), table.end()); + + return DevmapperTable(start, size, type, target); +} + +SecureBlob DevmapperTable::CryptGetKey() { + SecureBlobTokenizer tokenizer(parameters_.begin(), parameters_.end(), " "); + + // First field is the cipher. + if (!tokenizer.GetNext()) + return SecureBlob(); + + // The key is stored in the second field. + if (!tokenizer.GetNext()) + return SecureBlob(); + + SecureBlob hex_key(tokenizer.token_begin(), tokenizer.token_end()); + + SecureBlob key = SecureHexToSecureBlob(hex_key); + + if (key.empty()) { + LOG(ERROR) << "CryptExtractKey: HexStringToBytes failed"; + return SecureBlob(); + } + + return key; +} + +// In order to not leak the encryption key to non-SecureBlob managed memory, +// create the parameter blobs in three parts and combine. +SecureBlob DevmapperTable::CryptCreateParameters( + const std::string& cipher, + const SecureBlob& encryption_key, + const int iv_offset, + const base::FilePath& device, + int device_offset, + bool allow_discard) { + // First field is the cipher. + SecureBlob parameter_parts[3]; + + parameter_parts[0] = SecureBlob(cipher + " "); + parameter_parts[1] = SecureBlobToSecureHex(encryption_key); + parameter_parts[2] = SecureBlob(base::StringPrintf( + " %d %s %d%s", iv_offset, device.value().c_str(), device_offset, + (allow_discard ? " 1 allow_discards" : ""))); + + SecureBlob parameters; + for (auto param_part : parameter_parts) + parameters = SecureBlob::Combine(parameters, param_part); + + return parameters; +} + +std::unique_ptr<DevmapperTask> CreateDevmapperTask(int type) { + return std::make_unique<DevmapperTaskImpl>(type); +} + +DeviceMapper::DeviceMapper() { + dm_task_factory_ = base::Bind(&CreateDevmapperTask); +} + +DeviceMapper::DeviceMapper(const DevmapperTaskFactory& factory) + : dm_task_factory_(factory) {} + +bool DeviceMapper::Setup(const std::string& name, const DevmapperTable& table) { + auto task = dm_task_factory_.Run(DM_DEVICE_CREATE); + + if (!task->SetName(name)) { + LOG(ERROR) << "Setup: SetName failed."; + return false; + } + + if (!task->AddTarget(table.GetStart(), table.GetSize(), table.GetType(), + table.GetParameters())) { + LOG(ERROR) << "Setup: AddTarget failed"; + return false; + } + + if (!task->Run(true /* udev sync */)) { + LOG(ERROR) << "Setup: Run failed."; + return false; + } + + return true; +} + +bool DeviceMapper::Remove(const std::string& name) { + auto task = dm_task_factory_.Run(DM_DEVICE_REMOVE); + + if (!task->SetName(name)) { + LOG(ERROR) << "Remove: SetName failed."; + return false; + } + + if (!task->Run(true /* udev_sync */)) { + LOG(ERROR) << "Remove: Teardown failed."; + return false; + } + + return true; +} + +DevmapperTable DeviceMapper::GetTable(const std::string& name) { + auto task = dm_task_factory_.Run(DM_DEVICE_TABLE); + uint64_t start, size; + std::string type; + SecureBlob parameters; + + if (!task->SetName(name)) { + LOG(ERROR) << "GetTable: SetName failed."; + return DevmapperTable(0, 0, "", SecureBlob()); + } + + if (!task->Run()) { + LOG(ERROR) << "GetTable: Run failed."; + return DevmapperTable(0, 0, "", SecureBlob()); + } + + task->GetNextTarget(&start, &size, &type, ¶meters); + + return DevmapperTable(start, size, type, parameters); +} + +bool DeviceMapper::WipeTable(const std::string& name) { + auto size_task = dm_task_factory_.Run(DM_DEVICE_TABLE); + + if (!size_task->SetName(name)) { + LOG(ERROR) << "WipeTable: SetName failed."; + return false; + } + + if (!size_task->Run()) { + LOG(ERROR) << "WipeTable: RunTask failed."; + return false; + } + + // Arguments for fetching dm target. + bool ret = false; + uint64_t start = 0, size = 0, total_size = 0; + std::string type; + SecureBlob parameters; + + // Get maximum size of the device to be wiped. + do { + ret = size_task->GetNextTarget(&start, &size, &type, ¶meters); + total_size = std::max(start + size, total_size); + } while (ret); + + // Setup wipe task. + auto wipe_task = dm_task_factory_.Run(DM_DEVICE_RELOAD); + + if (!wipe_task->SetName(name)) { + LOG(ERROR) << "WipeTable: SetName failed."; + return false; + } + + if (!wipe_task->AddTarget(0, total_size, "error", SecureBlob())) { + LOG(ERROR) << "WipeTable: AddTarget failed."; + return false; + } + + if (!wipe_task->Run()) { + LOG(ERROR) << "WipeTable: RunTask failed."; + return false; + } + + return true; +} + +} // namespace brillo diff --git a/brillo/blkdev_utils/device_mapper.h b/brillo/blkdev_utils/device_mapper.h new file mode 100644 index 0000000..478b30a --- /dev/null +++ b/brillo/blkdev_utils/device_mapper.h @@ -0,0 +1,116 @@ +// Copyright 2018 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef LIBBRILLO_BRILLO_BLKDEV_UTILS_DEVICE_MAPPER_H_ +#define LIBBRILLO_BRILLO_BLKDEV_UTILS_DEVICE_MAPPER_H_ + +#include <functional> +#include <memory> +#include <string> + +#include <base/bind.h> +#include <base/callback.h> +#include <base/files/file_path.h> +#include <brillo/blkdev_utils/device_mapper_task.h> + +namespace brillo { + +// DevmapperTable manages device parameters. Contains helper +// functions to parse results from dmsetup. Since the table parameters +// may contain sensitive data eg. dm-crypt keys, we use SecureBlobs for +// the table parameters and as the table output format. + +class BRILLO_EXPORT DevmapperTable { + public: + // Create table from table parameters. + // Useful for setting up devices. + DevmapperTable(uint64_t start, + uint64_t size, + const std::string& type, + const SecureBlob& parameters); + + ~DevmapperTable() = default; + + // Returns the table as a SecureBlob. + SecureBlob ToSecureBlob(); + + // Getters for table components. + uint64_t GetStart() const { return start_; } + uint64_t GetSize() const { return size_; } + std::string GetType() const { return type_; } + SecureBlob GetParameters() const { return parameters_; } + + // Create table from table blob. + // Useful for parsing output from dmsetup. + // Using a static function to surface errors in parsing the blob. + static DevmapperTable CreateTableFromSecureBlob(const SecureBlob& table); + + // dm-crypt specific functions: + // ---------------------------- + // Extract key from (crypt) table. + SecureBlob CryptGetKey(); + + // Create crypt parameters . + // Useful for parsing output from dmsetup. + // Using a static function to surface errors in parsing the blob. + static SecureBlob CryptCreateParameters(const std::string& cipher, + const SecureBlob& encryption_key, + const int iv_offset, + const base::FilePath& device, + int device_offset, + bool allow_discard); + + private: + const uint64_t start_; + const uint64_t size_; + const std::string type_; + const SecureBlob parameters_; +}; + +// DevmapperTask is an abstract class so we wrap it in a unique_ptr. +using DevmapperTaskFactory = + base::Callback<std::unique_ptr<DevmapperTask>(int)>; + +// DeviceMapper handles the creation and removal of dm devices. +class BRILLO_EXPORT DeviceMapper { + public: + // Default constructor: sets up real devmapper devices. + DeviceMapper(); + + // Set a non-default dm task factory. + explicit DeviceMapper(const DevmapperTaskFactory& factory); + + // Default destructor. + ~DeviceMapper() = default; + + // Sets up device with table on /dev/mapper/<name>. + // Parameters + // name - Name of the devmapper device. + // table - Table for the devmapper device. + bool Setup(const std::string& name, const DevmapperTable& table); + + // Removes device. + // Parameters + // name - Name of the devmapper device. + bool Remove(const std::string& device); + + // Returns table for device. + // Parameters + // name - Name of the devmapper device. + DevmapperTable GetTable(const std::string& name); + + // Clears table for device. + // Parameters + // name - Name of the devmapper device. + bool WipeTable(const std::string& name); + + private: + // Devmapper task factory. + DevmapperTaskFactory dm_task_factory_; + DISALLOW_COPY_AND_ASSIGN(DeviceMapper); +}; + +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_BLKDEV_UTILS_DEVICE_MAPPER_H_ diff --git a/brillo/blkdev_utils/device_mapper_fake.cc b/brillo/blkdev_utils/device_mapper_fake.cc new file mode 100644 index 0000000..8126960 --- /dev/null +++ b/brillo/blkdev_utils/device_mapper_fake.cc @@ -0,0 +1,112 @@ +// Copyright 2018 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <brillo/blkdev_utils/device_mapper_fake.h> + +#include <memory> +#include <string> +#include <unordered_map> +#include <utility> +#include <vector> + +namespace brillo { +namespace fake { + +namespace { + +// Stub DmTask runs into a map for easy reference. +bool StubDmRunTask(DmTask* task, bool udev_sync) { + std::string dev_name = task->name; + std::string params; + int type = task->type; + static auto& dm_target_map_ = + *new std::unordered_map<std::string, std::vector<DmTarget>>(); + + switch (type) { + case DM_DEVICE_CREATE: + CHECK_EQ(udev_sync, true); + if (dm_target_map_.find(dev_name) != dm_target_map_.end()) + return false; + dm_target_map_.insert(std::make_pair(dev_name, task->targets)); + break; + case DM_DEVICE_REMOVE: + CHECK_EQ(udev_sync, true); + if (dm_target_map_.find(dev_name) == dm_target_map_.end()) + return false; + dm_target_map_.erase(dev_name); + break; + case DM_DEVICE_TABLE: + CHECK_EQ(udev_sync, false); + if (dm_target_map_.find(dev_name) == dm_target_map_.end()) + return false; + task->targets = dm_target_map_[dev_name]; + break; + case DM_DEVICE_RELOAD: + CHECK_EQ(udev_sync, false); + if (dm_target_map_.find(dev_name) == dm_target_map_.end()) + return false; + dm_target_map_.erase(dev_name); + dm_target_map_.insert(std::make_pair(dev_name, task->targets)); + break; + default: + return false; + } + return true; +} + +std::unique_ptr<DmTask> DmTaskCreate(int type) { + auto t = std::make_unique<DmTask>(); + t->type = type; + return t; +} + +} // namespace + +FakeDevmapperTask::FakeDevmapperTask(int type) : task_(DmTaskCreate(type)) {} + +bool FakeDevmapperTask::SetName(const std::string& name) { + task_->name = std::string(name); + return true; +} + +bool FakeDevmapperTask::AddTarget(uint64_t start, + uint64_t sectors, + const std::string& type, + const SecureBlob& parameters) { + DmTarget dmt; + dmt.start = start; + dmt.size = sectors; + dmt.type = type; + dmt.parameters = parameters; + task_->targets.push_back(dmt); + return true; +} + +bool FakeDevmapperTask::GetNextTarget(uint64_t* start, + uint64_t* sectors, + std::string* type, + SecureBlob* parameters) { + if (task_->targets.empty()) + return false; + + DmTarget dmt = task_->targets[0]; + *start = dmt.start; + *sectors = dmt.size; + *type = dmt.type; + *parameters = dmt.parameters; + task_->targets.erase(task_->targets.begin()); + + return !task_->targets.empty(); +} + +bool FakeDevmapperTask::Run(bool udev_sync) { + return StubDmRunTask(task_.get(), udev_sync); +} + +std::unique_ptr<DevmapperTask> CreateDevmapperTask(int type) { + return std::make_unique<FakeDevmapperTask>(type); +} + +} // namespace fake +} // namespace brillo diff --git a/brillo/blkdev_utils/device_mapper_fake.h b/brillo/blkdev_utils/device_mapper_fake.h new file mode 100644 index 0000000..bc4f28c --- /dev/null +++ b/brillo/blkdev_utils/device_mapper_fake.h @@ -0,0 +1,65 @@ +// Copyright 2018 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef LIBBRILLO_BRILLO_BLKDEV_UTILS_DEVICE_MAPPER_FAKE_H_ +#define LIBBRILLO_BRILLO_BLKDEV_UTILS_DEVICE_MAPPER_FAKE_H_ + +#include <memory> +#include <string> +#include <vector> + +#include <base/files/file_path.h> +#include <brillo/blkdev_utils/device_mapper.h> +#include <brillo/blkdev_utils/device_mapper_fake.h> +#include <brillo/blkdev_utils/device_mapper_task.h> +#include <brillo/secure_blob.h> + +namespace brillo { +namespace fake { + +// Fake implementation of dm_task primitives. +// ------------------------------------------ +// dm_task is an opaque type in libdevmapper so we +// define a minimal struct for DmTask and DmTarget +// to avoid linking in libdevmapper. +struct DmTarget { + uint64_t start; + uint64_t size; + std::string type; + SecureBlob parameters; +}; + +struct DmTask { + int type; + std::string name; + std::vector<DmTarget> targets; +}; + +// Fake task factory: creates fake tasks that +// stub task info into a map. +std::unique_ptr<DevmapperTask> CreateDevmapperTask(int type); + +class FakeDevmapperTask : public brillo::DevmapperTask { + public: + explicit FakeDevmapperTask(int type); + ~FakeDevmapperTask() override = default; + bool SetName(const std::string& name) override; + bool AddTarget(uint64_t start, + uint64_t sectors, + const std::string& target, + const SecureBlob& parameters) override; + bool GetNextTarget(uint64_t* start, + uint64_t* sectors, + std::string* target, + SecureBlob* parameters) override; + bool Run(bool udev_sync = true) override; + + private: + std::unique_ptr<DmTask> task_; +}; + +} // namespace fake +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_BLKDEV_UTILS_DEVICE_MAPPER_FAKE_H_ diff --git a/brillo/blkdev_utils/device_mapper_task.cc b/brillo/blkdev_utils/device_mapper_task.cc new file mode 100644 index 0000000..f2cbadd --- /dev/null +++ b/brillo/blkdev_utils/device_mapper_task.cc @@ -0,0 +1,95 @@ +// Copyright 2018 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <brillo/blkdev_utils/device_mapper_task.h> + +#include <libdevmapper.h> +#include <string> +#include <utility> + +#include <brillo/blkdev_utils/device_mapper.h> + +namespace brillo { + +DevmapperTaskImpl::DevmapperTaskImpl(int type) + : task_(DmTaskPtr(dm_task_create(type), &dm_task_destroy)) {} + +bool DevmapperTaskImpl::SetName(const std::string& name) { + if (!task_ || !dm_task_set_name(task_.get(), name.c_str())) { + LOG(ERROR) << "SetName failed"; + return false; + } + return true; +} + +bool DevmapperTaskImpl::AddTarget(uint64_t start, + uint64_t length, + const std::string& type, + const SecureBlob& parameters) { + // Strings stored in SecureBlob don't end with '\0'. Unfortunately, + // this causes accesses beyond the allocated storage space if any + // of the functions expecting a c-string get passed a SecureBlob.data(). + // Temporarily, assign to a string. + // TODO(sarthakkukreti): Evaluate creation of a SecureCString to keep + // string data safe. + std::string parameters_str = parameters.to_string(); + if (!task_ || + !dm_task_add_target(task_.get(), start, length, type.c_str(), + parameters_str.c_str())) { + LOG(ERROR) << "AddTarget failed"; + return false; + } + // Clear the string. + parameters_str.clear(); + return true; +} + +bool DevmapperTaskImpl::GetNextTarget(uint64_t* start, + uint64_t* length, + std::string* type, + SecureBlob* parameters) { + if (!task_) { + LOG(ERROR) << "GetNextTarget: invalid task."; + return false; + } + + char *type_cstr, *parameters_cstr; + next_target_ = dm_get_next_target(task_.get(), next_target_, start, length, + &type_cstr, ¶meters_cstr); + + if (type_cstr) + *type = std::string(type_cstr); + if (parameters_cstr) { + SecureBlob parameters_blob(parameters_cstr); + memset(parameters_cstr, 0, parameters_blob.size()); + *parameters = std::move(parameters_blob); + } + + return (next_target_ != nullptr); +} + +bool DevmapperTaskImpl::Run(bool udev_sync) { + uint32_t cookie = 0; + + if (!task_) { + LOG(ERROR) << "Invalid task."; + return false; + } + + if (udev_sync && !dm_task_set_cookie(task_.get(), &cookie, 0)) { + LOG(ERROR) << "dm_task_set_cookie failed"; + return false; + } + + if (!dm_task_run(task_.get())) { + LOG(ERROR) << "dm_task_run failed"; + return false; + } + + // Make sure the node exists before continuing. + // TODO(sarthakkukreti): move to dm_udev_wait_immediate() on uprevving lvm2. + return udev_sync ? (dm_udev_wait(cookie) != 0) : true; +} + +} // namespace brillo diff --git a/brillo/blkdev_utils/device_mapper_task.h b/brillo/blkdev_utils/device_mapper_task.h new file mode 100644 index 0000000..f8e45d4 --- /dev/null +++ b/brillo/blkdev_utils/device_mapper_task.h @@ -0,0 +1,101 @@ +// Copyright 2018 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef LIBBRILLO_BRILLO_BLKDEV_UTILS_DEVICE_MAPPER_TASK_H_ +#define LIBBRILLO_BRILLO_BLKDEV_UTILS_DEVICE_MAPPER_TASK_H_ + +#include <libdevmapper.h> +#include <memory> +#include <string> + +#include <brillo/secure_blob.h> + +namespace brillo { + +using DmTaskPtr = std::unique_ptr<dm_task, void (*)(dm_task*)>; + +// Abstract class to manage DM devices. +// This class implements the bare minimum set of functions +// required to create/remove DM devices. DevmapperTask is the equivalent +// of a command to the device mapper to set/get targets associated with a +// logical DM device, but omits, for now, finer-grained commands. +// A target represents a segment of a DM device. +// +// The abstract class is strictly based on the dm_task_* functions +// from libdevmapper, but the interface provides sufficient flexibility +// for other implementations (eg. invoking dmsetup) or testing facades. +// +// The task type enum is defined in libdevmapper.h: for simplicity, the same +// enum types are reused in fake implementations of DevmapperTask. +// The following task types have been tested with DeviceMapper functions: +// - DM_DEVICE_CREATE: used in DeviceMapper::Setup. +// - DM_DEVICE_REMOVE: used in DeviceMapper::Remove. +// - DM_DEVICE_TABLE: used in DeviceMapper::GetTable and +// DeviceMapper::WipeTable. +// - DM_DEVICE_RELOAD: used in DeviceMapper::WipeTable. +class DevmapperTask { + public: + virtual ~DevmapperTask() = default; + // Sets device name for the command. + virtual bool SetName(const std::string& name) = 0; + + // Adds a target to the command. Should be followed by a Run(); + // Parameters: + // start: start of target in device. + // sectors: number of sectors in the target. + // type: type of the target. + // parameters: target parameters. + virtual bool AddTarget(uint64_t start, + uint64_t sectors, + const std::string& type, + const SecureBlob& parameters) = 0; + // Gets the next target from the command. + // Returns true while another target exists. + // If no target exist for the device, GetNextTarget sets all + // parameters to 0 and returns false. + // + // Parameters: + // start: start of target in device. + // sectors: number of sectors in the target. + // type: type of the target. + // parameters: target parameters. + virtual bool GetNextTarget(uint64_t* start, + uint64_t* sectors, + std::string* type, + SecureBlob* parameters) = 0; + // Run the task. + // Returns true if the task succeeded. + // + // Parameters: + // udev_sync: Enable/Disable udev_synchronization. Defaults to false. + // Enable only for tasks that create/remove/rename files to + // prevent both udevd and libdevmapper from attempting to + // add or remove files. + virtual bool Run(bool udev_sync = false) = 0; +}; + +// Libdevmapper implementation for DevmapperTask. +class DevmapperTaskImpl : public DevmapperTask { + public: + explicit DevmapperTaskImpl(int type); + ~DevmapperTaskImpl() override = default; + bool SetName(const std::string& name) override; + bool AddTarget(uint64_t start, + uint64_t sectors, + const std::string& target, + const SecureBlob& parameters) override; + bool GetNextTarget(uint64_t* start, + uint64_t* sectors, + std::string* target, + SecureBlob* parameters) override; + bool Run(bool udev_sync = true) override; + + private: + DmTaskPtr task_; + void* next_target_ = nullptr; +}; + +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_BLKDEV_UTILS_DEVICE_MAPPER_TASK_H_ diff --git a/brillo/blkdev_utils/device_mapper_test.cc b/brillo/blkdev_utils/device_mapper_test.cc new file mode 100644 index 0000000..ab19092 --- /dev/null +++ b/brillo/blkdev_utils/device_mapper_test.cc @@ -0,0 +1,143 @@ +// Copyright 2018 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <libdevmapper.h> + +#include <base/files/file_util.h> +#include <base/memory/ptr_util.h> +#include <base/strings/string_number_conversions.h> +#include <base/strings/string_split.h> +#include <brillo/blkdev_utils/device_mapper_fake.h> +#include <gtest/gtest.h> + +namespace brillo { + +TEST(DevmapperTableTest, CreateTableFromSecureBlobTest) { + SecureBlob crypt_table_str("0 100 crypt"); + + DevmapperTable dm_table = + DevmapperTable::CreateTableFromSecureBlob(crypt_table_str); + EXPECT_EQ(DevmapperTable(0, 0, "", SecureBlob()).ToSecureBlob(), + dm_table.ToSecureBlob()); +} + +TEST(DevmapperTableTest, CryptCreateParametersTest) { + base::FilePath device("/some/random/filepath"); + + SecureBlob secret; + SecureBlob::HexStringToSecureBlob("0123456789ABCDEF", &secret); + + SecureBlob crypt_parameters = DevmapperTable::CryptCreateParameters( + "aes-cbc-essiv:sha256", secret, 0, device, 0, true); + + DevmapperTable crypt_table(0, 100, "crypt", crypt_parameters); + + SecureBlob crypt_table_str( + "0 100 crypt aes-cbc-essiv:sha256 " + "0123456789ABCDEF 0 /some/random/filepath 0 1 " + "allow_discards"); + + EXPECT_EQ(crypt_table.ToSecureBlob().to_string(), + crypt_table_str.to_string()); +} + +TEST(DevmapperTableTest, CryptCreateTableFromSecureBlobTest) { + base::FilePath device("/some/random/filepath"); + + SecureBlob secret; + SecureBlob::HexStringToSecureBlob("0123456789ABCDEF", &secret); + + SecureBlob crypt_parameters = DevmapperTable::CryptCreateParameters( + "aes-cbc-essiv:sha256", secret, 0, device, 0, true); + + DevmapperTable crypt_table(0, 100, "crypt", crypt_parameters); + + SecureBlob crypt_table_str( + "0 100 crypt aes-cbc-essiv:sha256 " + "0123456789ABCDEF 0 /some/random/filepath 0 1 " + "allow_discards"); + + DevmapperTable parsed_blob_table = + DevmapperTable::CreateTableFromSecureBlob(crypt_table_str); + + EXPECT_EQ(crypt_table.ToSecureBlob(), parsed_blob_table.ToSecureBlob()); +} + +TEST(DevmapperTableTest, CryptGetKeyTest) { + SecureBlob secret; + SecureBlob::HexStringToSecureBlob("0123456789ABCDEF", &secret); + SecureBlob crypt_table_str( + "0 100 crypt aes-cbc-essiv:sha256 " + "0123456789ABCDEF 0 /some/random/filepath 0 1 " + "allow_discards"); + + DevmapperTable dm_table = + DevmapperTable::CreateTableFromSecureBlob(crypt_table_str); + + EXPECT_EQ(secret, dm_table.CryptGetKey()); +} + +TEST(DevmapperTableTest, MalformedCryptTableTest) { + SecureBlob secret; + SecureBlob::HexStringToSecureBlob("0123456789ABCDEF", &secret); + // Pass malformed crypt table string. + SecureBlob crypt_table_str( + "0 100 crypt ABCDEFGHIJKLMNOPQRSTUVWXYZABCDEFGHIJKLMNOPQRSTUVWXYZ" + "ABCDEFGHIJKLMNOPQRSTUVWXYZABCDEFGHIJKLMNOPQRSTUVWXYZ" + "ABCDEFGHIJKLMNOPQRSTUVWXYZABCDEFGHIJKLMNOPQRSTUVWXYZ" + "ABCDEFGHIJKLMNOPQRSTUVWXYZABCDEFGHIJKLMNOPQRSTUVWXYZ"); + + DevmapperTable dm_table = + DevmapperTable::CreateTableFromSecureBlob(crypt_table_str); + + EXPECT_EQ(SecureBlob(), dm_table.CryptGetKey()); +} + +TEST(DevmapperTableTest, GetterTest) { + SecureBlob verity_table( + "0 40 verity payload=/dev/loop6 hashtree=/dev/loop6 " + "hashstart=40 alg=sha256 root_hexdigest=" + "01234567 " + "salt=89ABCDEF " + "error_behavior=eio"); + + DevmapperTable dm_table = + DevmapperTable::CreateTableFromSecureBlob(verity_table); + + EXPECT_EQ(dm_table.GetStart(), 0); + EXPECT_EQ(dm_table.GetSize(), 40); + EXPECT_EQ(dm_table.GetType(), "verity"); + EXPECT_EQ(dm_table.GetParameters(), + SecureBlob("payload=/dev/loop6 hashtree=/dev/loop6 " + "hashstart=40 alg=sha256 root_hexdigest=01234567 " + "salt=89ABCDEF error_behavior=eio")); +} + +TEST(DevmapperTest, FakeTaskConformance) { + SecureBlob secret; + SecureBlob::HexStringToSecureBlob("0123456789ABCDEF", &secret); + SecureBlob crypt_table_str( + "0 100 crypt aes-cbc-essiv:sha256 " + "0123456789ABCDEF 0 /some/random/filepath 0 1 " + "allow_discards"); + + DevmapperTable dm_table = + DevmapperTable::CreateTableFromSecureBlob(crypt_table_str); + + EXPECT_EQ(secret, dm_table.CryptGetKey()); + DeviceMapper dm(base::Bind(&fake::CreateDevmapperTask)); + + // Add device. + EXPECT_TRUE(dm.Setup("abcd", dm_table)); + EXPECT_FALSE(dm.Setup("abcd", dm_table)); + DevmapperTable table = dm.GetTable("abcd"); + // Expect tables to be the same. + EXPECT_EQ(table.ToSecureBlob(), dm_table.ToSecureBlob()); + // Expect key to match. + EXPECT_EQ(table.CryptGetKey(), secret); + EXPECT_TRUE(dm.Remove("abcd")); + EXPECT_FALSE(dm.Remove("abcd")); +} + +} // namespace brillo diff --git a/brillo/blkdev_utils/loop_device.cc b/brillo/blkdev_utils/loop_device.cc new file mode 100644 index 0000000..bd1b67c --- /dev/null +++ b/brillo/blkdev_utils/loop_device.cc @@ -0,0 +1,270 @@ +// Copyright 2018 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <brillo/blkdev_utils/loop_device.h> + +#include <fcntl.h> +#include <linux/major.h> +#include <sys/ioctl.h> +#include <sys/types.h> +#include <unistd.h> + +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include <base/files/file_enumerator.h> +#include <base/files/file_util.h> +#include <base/files/scoped_file.h> +#include <base/posix/eintr_wrapper.h> +#include <base/strings/string_number_conversions.h> +#include <base/strings/string_split.h> +#include <base/strings/string_util.h> +#include <base/strings/stringprintf.h> + +namespace brillo { + +namespace { + +constexpr char kLoopControl[] = "/dev/loop-control"; +constexpr char kSysBlockPath[] = "/sys/block"; +// File containing device id in /sys/block/loopX/. +constexpr char kDeviceIdPath[] = "dev"; +constexpr char kLoopBackingFile[] = "loop/backing_file"; +constexpr int kLoopDeviceIoctlFlags = O_RDWR | O_NOFOLLOW | O_CLOEXEC; +constexpr int kLoopControlIoctlFlags = O_RDONLY | O_NOFOLLOW | O_CLOEXEC; + +// ioctl runner for LoopDevice and LoopDeviceManager +int LoopDeviceIoctl(const base::FilePath& device, + int type, + uint64_t arg, + int open_flag) { + base::ScopedFD device_fd( + HANDLE_EINTR(open(device.value().c_str(), open_flag))); + + if (!device_fd.is_valid()) { + PLOG(ERROR) << "Unable to open loop device"; + return -EINVAL; + } + + int rc = ioctl(device_fd.get(), type, arg); + + if (rc < 0) + PLOG(ERROR) << "ioctl failed."; + + return rc; +} + +// Parse the device number for a valid /sys/block/loopX path +// or symlink to such a path. +// Returns -1 if invalid. +int GetDeviceNumber(const base::FilePath& sys_block_loopdev_path) { + std::string device_string; + int device_number = -1; + + base::FilePath device_file = sys_block_loopdev_path.Append(kDeviceIdPath); + + if (!base::ReadFileToString(device_file, &device_string)) + return -1; + + std::vector<std::string> device_ids = base::SplitString( + device_string, ":", base::TRIM_WHITESPACE, base::SPLIT_WANT_NONEMPTY); + + if (device_ids.size() != 2 || device_ids[0] != base::IntToString(LOOP_MAJOR)) + return -1; + + base::StringToInt(device_ids[1], &device_number); + return device_number; +} + +// For a validated loop device path, return the backing file path. +// Note that a pre-populated loop device path would return an empty +// backing file. +base::FilePath GetBackingFile(const base::FilePath& loopdev_path) { + // Backing file contains path to associated source for loop devices. + base::FilePath backing_file = loopdev_path.Append(kLoopBackingFile); + std::string backing_file_content; + // If the backing file doesn't exist, it's not an attached loop device. + if (!base::ReadFileToString(backing_file, &backing_file_content)) + return base::FilePath(); + base::FilePath backing_file_path( + base::TrimWhitespaceASCII(backing_file_content, base::TRIM_ALL)); + + return backing_file_path; +} + +base::FilePath CreateDevicePath(int device_number) { + return base::FilePath(base::StringPrintf("/dev/loop%d", device_number)); +} + +} // namespace + +LoopDevice::LoopDevice(int device_number, + const base::FilePath& backing_file, + const LoopIoctl& ioctl_runner) + : device_number_(device_number), + backing_file_(backing_file), + loop_ioctl_(ioctl_runner) {} + +bool LoopDevice::SetStatus(struct loop_info64 info) { + if (loop_ioctl_.Run(GetDevicePath(), LOOP_SET_STATUS64, + reinterpret_cast<uint64_t>(&info), + kLoopDeviceIoctlFlags) < 0) { + LOG(ERROR) << "ioctl(LOOP_SET_STATUS64) failed"; + return false; + } + return true; +} + +bool LoopDevice::GetStatus(struct loop_info64* info) { + if (loop_ioctl_.Run(GetDevicePath(), LOOP_GET_STATUS64, + reinterpret_cast<uint64_t>(info), + kLoopDeviceIoctlFlags) < 0) { + LOG(ERROR) << "ioctl(LOOP_GET_STATUS64) failed"; + return false; + } + return true; +} + +bool LoopDevice::SetName(const std::string& name) { + struct loop_info64 info; + + memset(&info, 0, sizeof(info)); + strncpy(reinterpret_cast<char*>(info.lo_file_name), name.c_str(), + LO_NAME_SIZE); + return SetStatus(info); +} + +bool LoopDevice::Detach() { + if (loop_ioctl_.Run(GetDevicePath(), LOOP_CLR_FD, 0, kLoopDeviceIoctlFlags) != + 0) { + LOG(ERROR) << "ioctl(LOOP_CLR_FD) failed"; + return false; + } + + return true; +} + +base::FilePath LoopDevice::GetDevicePath() { + return CreateDevicePath(device_number_); +} + +bool LoopDevice::IsValid() { + return device_number_ >= 0; +} + +LoopDeviceManager::LoopDeviceManager() + : loop_ioctl_(base::Bind(&LoopDeviceIoctl)) {} + +LoopDeviceManager::LoopDeviceManager(LoopIoctl ioctl_runner) + : loop_ioctl_(ioctl_runner) {} + +std::unique_ptr<LoopDevice> LoopDeviceManager::AttachDeviceToFile( + const base::FilePath& backing_file) { + int device_number = -1; + while (true) { + device_number = + loop_ioctl_.Run(base::FilePath(kLoopControl), LOOP_CTL_GET_FREE, 0, + kLoopControlIoctlFlags); + + if (device_number < 0) { + LOG(ERROR) << "ioctl(LOOP_CTL_GET_FREE) failed"; + return CreateLoopDevice(-1, base::FilePath()); + } + + base::ScopedFD backing_file_fd( + HANDLE_EINTR(open(backing_file.value().c_str(), O_RDWR))); + + if (!backing_file_fd.is_valid()) { + LOG(ERROR) << "Failed to open backing file."; + return CreateLoopDevice(-1, base::FilePath()); + } + + base::FilePath device_path = CreateDevicePath(device_number); + + if (loop_ioctl_.Run(device_path, LOOP_SET_FD, backing_file_fd.get(), + kLoopDeviceIoctlFlags) == 0) + break; + + if (errno != EBUSY) { + LOG(ERROR) << "ioctl(LOOP_SET_FD) failed"; + return CreateLoopDevice(-1, base::FilePath()); + } + } + // All steps of setting up the loop device succeeded. + return CreateLoopDevice(device_number, backing_file); +} + +std::vector<std::unique_ptr<LoopDevice>> +LoopDeviceManager::GetAttachedDevices() { + return SearchLoopDevicePaths(); +} + +std::unique_ptr<LoopDevice> LoopDeviceManager::GetAttachedDeviceByNumber( + int device_number) { + auto devices = SearchLoopDevicePaths(device_number); + + if (devices.empty()) + return CreateLoopDevice(-1, base::FilePath()); + + return std::move(devices[0]); +} + +std::unique_ptr<LoopDevice> LoopDeviceManager::GetAttachedDeviceByName( + const std::string& name) { + std::vector<std::unique_ptr<LoopDevice>> devices = GetAttachedDevices(); + + for (auto& attached_device : devices) { + struct loop_info64 device_info; + + if (!attached_device->GetStatus(&device_info)) { + LOG(ERROR) << "GetStatus failed"; + continue; + } + + if (strcmp(reinterpret_cast<char*>(device_info.lo_file_name), + name.c_str()) == 0) + return std::move(attached_device); + } + + return CreateLoopDevice(-1, base::FilePath()); +} + +// virtual +std::vector<std::unique_ptr<LoopDevice>> +LoopDeviceManager::SearchLoopDevicePaths(int device_number) { + std::vector<std::unique_ptr<LoopDevice>> devices; + base::FilePath rootdir(kSysBlockPath); + + if (device_number != -1) { + auto loopdev_path = + rootdir.Append(base::StringPrintf("loop%d", device_number)); + if (base::PathExists(loopdev_path)) + devices.push_back( + CreateLoopDevice(device_number, GetBackingFile(loopdev_path))); + } else { + // Read /sys/block to discover all loop devices. + base::FileEnumerator loopdev_enum( + rootdir, false /*recursive*/, + base::FileEnumerator::FILES | base::FileEnumerator::SHOW_SYM_LINKS, + "loop*"); + + for (auto loopdev = loopdev_enum.Next(); !loopdev.empty(); + loopdev = loopdev_enum.Next()) { + int dev_number = GetDeviceNumber(loopdev); + if (dev_number != -1) + devices.push_back( + CreateLoopDevice(dev_number, GetBackingFile(loopdev))); + } + } + return devices; +} + +std::unique_ptr<LoopDevice> LoopDeviceManager::CreateLoopDevice( + int device_number, const base::FilePath& backing_file) { + return std::make_unique<LoopDevice>(device_number, backing_file, loop_ioctl_); +} + +} // namespace brillo diff --git a/brillo/blkdev_utils/loop_device.h b/brillo/blkdev_utils/loop_device.h new file mode 100644 index 0000000..aba19cc --- /dev/null +++ b/brillo/blkdev_utils/loop_device.h @@ -0,0 +1,117 @@ +// Copyright 2018 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef LIBBRILLO_BRILLO_BLKDEV_UTILS_LOOP_DEVICE_H_ +#define LIBBRILLO_BRILLO_BLKDEV_UTILS_LOOP_DEVICE_H_ + +#include <linux/loop.h> +#include <memory> +#include <string> +#include <vector> + +#include <base/bind.h> +#include <base/callback.h> +#include <base/files/file_path.h> +#include <brillo/secure_blob.h> + +namespace brillo { + +// Forward declaration used by LoopDevice. +class LoopDeviceManager; + +using LoopIoctl = + base::Callback<int(const base::FilePath&, int, uint64_t, int)>; + +// LoopDevice provides an interface to attached loop devices. +// In order to simplify handling of loop devices, there +// is no inherent modifiable state associated within objects: +// the device number and backing file are consts. +// The intent here is for no class to create a LoopDevice +// directly; instead use LoopDeviceManager to get devices. +class BRILLO_EXPORT LoopDevice { + public: + // Create a loop device with a ioctl runner. + // Parameters + // device_number - loop device number. + // backing_file - backing file for the device. + // ioctl_runner - function to run loop ioctls. + LoopDevice(int device_number, + const base::FilePath& backing_file, + const LoopIoctl& ioctl_runner); + ~LoopDevice() = default; + + // Set device status. + // Parameters + // info - struct containing status. + bool SetStatus(struct loop_info64 info); + // Get device status. + // Parameters + // info - struct to populate. + bool GetStatus(struct loop_info64* info); + // Set device name. + // Parameters + // name - device name + bool SetName(const std::string& name); + // Detach device. + bool Detach(); + // Check if device is valid; + bool IsValid(); + + // Getters for device parameters. + base::FilePath GetBackingFilePath() { return backing_file_; } + base::FilePath GetDevicePath(); + + private: + const int device_number_; + const base::FilePath backing_file_; + // Ioctl runner. + LoopIoctl loop_ioctl_; +}; + +// Loop Device Manager handles requests for creating or fetching +// existing loop devices. If creation/fetch fails, the loop device +// manager returns nullptr. +class BRILLO_EXPORT LoopDeviceManager { + public: + LoopDeviceManager(); + // Create a loop device manager with a non-default ioctl runner. + // Parameters + // ioctl_runner - base::Callback to run ioctls. + explicit LoopDeviceManager(LoopIoctl ioctl_runner); + virtual ~LoopDeviceManager() = default; + + // Allocates a loop device and attaches it to a backing file. + // Parameters + // backing_file - file to attach device to. + virtual std::unique_ptr<LoopDevice> AttachDeviceToFile( + const base::FilePath& backing_file); + + // Fetches all attached loop devices. + std::vector<std::unique_ptr<LoopDevice>> GetAttachedDevices(); + + // Fetches a loop device by device number. + std::unique_ptr<LoopDevice> GetAttachedDeviceByNumber(int device_number); + + // Fetches a device number by name. + std::unique_ptr<LoopDevice> GetAttachedDeviceByName(const std::string& name); + + private: + // Search for loop devices by device number; if no device number is given, + // default to searaching and returning all loop devices. + virtual std::vector<std::unique_ptr<LoopDevice>> SearchLoopDevicePaths( + int device_number = -1); + // Create loop device with current ioctl runner. + // Parameters + // device_number - device number. + // backing_file - path to backing file. + std::unique_ptr<LoopDevice> CreateLoopDevice( + int device_number, const base::FilePath& backing_file); + + LoopIoctl loop_ioctl_; + DISALLOW_COPY_AND_ASSIGN(LoopDeviceManager); +}; + +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_BLKDEV_UTILS_LOOP_DEVICE_H_ diff --git a/brillo/blkdev_utils/loop_device_fake.cc b/brillo/blkdev_utils/loop_device_fake.cc new file mode 100644 index 0000000..a181aad --- /dev/null +++ b/brillo/blkdev_utils/loop_device_fake.cc @@ -0,0 +1,148 @@ +// Copyright 2018 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <brillo/blkdev_utils/loop_device_fake.h> + +#include <linux/loop.h> +#include <memory> +#include <string> +#include <vector> + +#include <base/strings/string_number_conversions.h> +#include <base/strings/string_split.h> +#include <base/strings/string_util.h> +#include <base/strings/stringprintf.h> +#include <brillo/blkdev_utils/loop_device.h> + +// Not a loop ioctl: we only use this to get the backing file from +// the stubbed function. All loop device ioctls start with 0x4c. +#define LOOP_GET_DEV 0x4cff + +namespace brillo { +namespace fake { + +namespace { + +int ParseLoopDeviceNumber(const base::FilePath& device_path) { + int device_number; + std::string path_string = device_path.value(); + return base::StartsWith(path_string, "/dev/loop", + base::CompareCase::SENSITIVE) && + base::StringToInt(path_string.substr(9), &device_number) + ? device_number + : -1; +} + +base::FilePath GetLoopDevicePath(int device_number) { + return base::FilePath(base::StringPrintf("/dev/loop%d", device_number)); +} + +int StubIoctlRunner(const base::FilePath& path, + int type, + uint64_t arg, + int flag) { + int device_number = ParseLoopDeviceNumber(path); + struct loop_info64* info; + struct LoopDev* device; + static std::vector<struct LoopDev>& loop_device_vector = + *new std::vector<struct LoopDev>(); + + switch (type) { + case LOOP_GET_STATUS64: + if (loop_device_vector.size() <= device_number || + loop_device_vector[device_number].valid == false) + return -1; + info = reinterpret_cast<struct loop_info64*>(arg); + memcpy(info, &loop_device_vector[device_number].info, + sizeof(struct loop_info64)); + return 0; + case LOOP_SET_STATUS64: + if (loop_device_vector.size() <= device_number || + loop_device_vector[device_number].valid == false) + return -1; + info = reinterpret_cast<struct loop_info64*>(arg); + memcpy(&loop_device_vector[device_number].info, info, + sizeof(struct loop_info64)); + return 0; + case LOOP_CLR_FD: + if (loop_device_vector.size() <= device_number || + loop_device_vector[device_number].valid == false) + return -1; + loop_device_vector[device_number].valid = false; + return 0; + case LOOP_CTL_GET_FREE: + device_number = loop_device_vector.size(); + loop_device_vector.push_back({true, base::FilePath(), {0}}); + return device_number; + // Instead of passing the fd here, we pass the FilePath of the backing + // file. + case LOOP_SET_FD: + if (loop_device_vector.size() <= device_number) + return -1; + loop_device_vector[device_number].backing_file = + *reinterpret_cast<const base::FilePath*>(arg); + return 0; + // Not a loop ioctl; Only used for conveniently checking the + // validity of the loop devices. + case LOOP_GET_DEV: + if (device_number >= loop_device_vector.size()) + return -1; + device = reinterpret_cast<struct LoopDev*>(arg); + device->valid = loop_device_vector[device_number].valid; + device->backing_file = loop_device_vector[device_number].backing_file; + memset(&(device->info), 0, sizeof(struct loop_info64)); + return 0; + default: + return -1; + } +} + +} // namespace + +FakeLoopDeviceManager::FakeLoopDeviceManager() + : LoopDeviceManager(base::Bind(&StubIoctlRunner)) {} + +std::unique_ptr<LoopDevice> FakeLoopDeviceManager::AttachDeviceToFile( + const base::FilePath& backing_file) { + int device_number = StubIoctlRunner(base::FilePath("/dev/loop-control"), + LOOP_CTL_GET_FREE, 0, 0); + + if (StubIoctlRunner(GetLoopDevicePath(device_number), LOOP_SET_FD, + reinterpret_cast<uint64_t>(&backing_file), 0) < 0) + return std::make_unique<LoopDevice>(-1, base::FilePath(), + base::Bind(&StubIoctlRunner)); + + return std::make_unique<LoopDevice>(device_number, backing_file, + base::Bind(&StubIoctlRunner)); +} + +std::vector<std::unique_ptr<LoopDevice>> +FakeLoopDeviceManager::SearchLoopDevicePaths(int device_number) { + std::vector<std::unique_ptr<LoopDevice>> devices; + struct LoopDev device; + + if (device_number != -1) { + if (StubIoctlRunner(GetLoopDevicePath(device_number), LOOP_GET_DEV, + reinterpret_cast<uint64_t>(&device), 0) < 0) + return devices; + + if (device.valid) + devices.push_back(std::make_unique<LoopDevice>( + device_number, device.backing_file, base::Bind(&StubIoctlRunner))); + return devices; + } + + int i = 0; + while (StubIoctlRunner(GetLoopDevicePath(i), LOOP_GET_DEV, + reinterpret_cast<uint64_t>(&device), 0) == 0) { + if (device.valid) + devices.push_back(std::make_unique<LoopDevice>( + i, device.backing_file, base::Bind(&StubIoctlRunner))); + i++; + } + return devices; +} + +} // namespace fake +} // namespace brillo diff --git a/brillo/blkdev_utils/loop_device_fake.h b/brillo/blkdev_utils/loop_device_fake.h new file mode 100644 index 0000000..751aa96 --- /dev/null +++ b/brillo/blkdev_utils/loop_device_fake.h @@ -0,0 +1,37 @@ +// Copyright 2018 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef LIBBRILLO_BRILLO_BLKDEV_UTILS_LOOP_DEVICE_FAKE_H_ +#define LIBBRILLO_BRILLO_BLKDEV_UTILS_LOOP_DEVICE_FAKE_H_ + +#include <memory> +#include <vector> + +#include <brillo/blkdev_utils/loop_device.h> + +namespace brillo { +namespace fake { + +struct LoopDev { + bool valid; + base::FilePath backing_file; + struct loop_info64 info; +}; + +class BRILLO_EXPORT FakeLoopDeviceManager : public brillo::LoopDeviceManager { + public: + FakeLoopDeviceManager(); + ~FakeLoopDeviceManager() override = default; + std::unique_ptr<LoopDevice> AttachDeviceToFile( + const base::FilePath& backing_file) override; + + private: + std::vector<std::unique_ptr<LoopDevice>> SearchLoopDevicePaths( + int device_number = -1) override; +}; + +} // namespace fake +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_BLKDEV_UTILS_LOOP_DEVICE_FAKE_H_ diff --git a/brillo/blkdev_utils/loop_device_test.cc b/brillo/blkdev_utils/loop_device_test.cc new file mode 100644 index 0000000..920ad68 --- /dev/null +++ b/brillo/blkdev_utils/loop_device_test.cc @@ -0,0 +1,57 @@ +// Copyright 2018 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <brillo/blkdev_utils/loop_device_fake.h> + +#include <base/files/file_util.h> +#include <gtest/gtest.h> + +namespace brillo { + +TEST(LoopDeviceTest, GeneralTest) { + base::FilePath loop_backing_file; + base::CreateTemporaryFile(&loop_backing_file); + fake::FakeLoopDeviceManager loop_manager; + + // Create a new device + std::unique_ptr<LoopDevice> device = + loop_manager.AttachDeviceToFile(loop_backing_file); + std::unique_ptr<LoopDevice> device1 = + loop_manager.AttachDeviceToFile(loop_backing_file); + std::unique_ptr<LoopDevice> device2 = + loop_manager.AttachDeviceToFile(loop_backing_file); + + EXPECT_TRUE(device->IsValid()); + EXPECT_TRUE(device1->IsValid()); + EXPECT_TRUE(device2->IsValid()); + + std::vector<std::unique_ptr<LoopDevice>> attached_devices = + loop_manager.GetAttachedDevices(); + + // Expect 3 devices + EXPECT_EQ(attached_devices.size(), 3); + + device2->SetName("Loopy"); + + std::unique_ptr<LoopDevice> device1_copy = + loop_manager.GetAttachedDeviceByNumber(1); + EXPECT_TRUE(device1_copy->IsValid()); + EXPECT_EQ(device1->GetDevicePath(), device1_copy->GetDevicePath()); + EXPECT_EQ(device1->GetBackingFilePath(), device1_copy->GetBackingFilePath()); + + std::unique_ptr<LoopDevice> device2_copy = + loop_manager.GetAttachedDeviceByName("Loopy"); + EXPECT_TRUE(device2_copy->IsValid()); + EXPECT_EQ(device2->GetDevicePath(), device2_copy->GetDevicePath()); + EXPECT_EQ(device2->GetBackingFilePath(), device2_copy->GetBackingFilePath()); + + // Check double detach + EXPECT_TRUE(device->Detach()); + EXPECT_TRUE(device1->Detach()); + EXPECT_FALSE(device1_copy->Detach()); + EXPECT_TRUE(device2->Detach()); + EXPECT_FALSE(device2_copy->Detach()); +} + +} // namespace brillo diff --git a/brillo/daemons/daemon.cc b/brillo/daemons/daemon.cc index 1b3d6d2..b706017 100644 --- a/brillo/daemons/daemon.cc +++ b/brillo/daemons/daemon.cc @@ -14,7 +14,7 @@ namespace brillo { -Daemon::Daemon() : exit_code_{EX_OK} { +Daemon::Daemon() : exit_code_{EX_OK}, exiting_(false) { message_loop_.SetAsCurrent(); } @@ -85,15 +85,27 @@ bool Daemon::OnRestart() { } bool Daemon::Shutdown(const signalfd_siginfo& /* info */) { - Quit(); - return true; // Unregister the signal handler. + // Only respond to the first call. + if (!exiting_) { + exiting_ = true; + Quit(); + } + // Always return false, to avoid unregistering the signal handler. We might + // receive multiple successive signals, and we don't want to take the default + // response (termination) while we're still tearing down. + return false; } bool Daemon::Restart(const signalfd_siginfo& /* info */) { - if (OnRestart()) - return false; // Keep listening to the signal. - Quit(); - return true; // Unregister the signal handler. + if (!exiting_ && !OnRestart()) { + // Only Quit() once. + exiting_ = true; + Quit(); + } + // Always return false, to avoid unregistering the signal handler. We might + // receive multiple successive signals, and we don't want to take the default + // response (termination) while we're still tearing down. + return false; } void Daemon::OnEventLoopStartedTask() { diff --git a/brillo/daemons/daemon.h b/brillo/daemons/daemon.h index a16e04a..499b609 100644 --- a/brillo/daemons/daemon.h +++ b/brillo/daemons/daemon.h @@ -114,6 +114,8 @@ class BRILLO_EXPORT Daemon : public AsynchronousSignalHandlerInterface { AsynchronousSignalHandler async_signal_handler_; // Process exit code specified in QuitWithExitCode() method call. int exit_code_; + // Daemon is in the process of exiting. + bool exiting_; DISALLOW_COPY_AND_ASSIGN(Daemon); }; diff --git a/brillo/daemons/dbus_daemon.h b/brillo/daemons/dbus_daemon.h index 25ce306..2017e7f 100644 --- a/brillo/daemons/dbus_daemon.h +++ b/brillo/daemons/dbus_daemon.h @@ -37,7 +37,7 @@ class BRILLO_EXPORT DBusDaemon : public Daemon { // A reference to the |dbus_connection_| bus object often used by derived // classes. - scoped_refptr<dbus::Bus> bus_; + scoped_refptr<::dbus::Bus> bus_; private: DBusConnection dbus_connection_; @@ -59,7 +59,7 @@ class BRILLO_EXPORT DBusServiceDaemon : public DBusDaemon { // not created and is not available as part of the D-Bus service. explicit DBusServiceDaemon(const std::string& service_name); DBusServiceDaemon(const std::string& service_name, - const dbus::ObjectPath& object_manager_path); + const ::dbus::ObjectPath& object_manager_path); DBusServiceDaemon(const std::string& service_name, base::StringPiece object_manager_path); @@ -76,7 +76,7 @@ class BRILLO_EXPORT DBusServiceDaemon : public DBusDaemon { dbus_utils::AsyncEventSequencer* sequencer); std::string service_name_; - dbus::ObjectPath object_manager_path_; + ::dbus::ObjectPath object_manager_path_; std::unique_ptr<dbus_utils::ExportedObjectManager> object_manager_; private: diff --git a/brillo/data_encoding_fuzzer.cc b/brillo/data_encoding_fuzzer.cc new file mode 100644 index 0000000..8d5d41e --- /dev/null +++ b/brillo/data_encoding_fuzzer.cc @@ -0,0 +1,72 @@ +// Copyright 2019 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <cstddef> +#include <cstdint> +#include <cstdio> + +#include <brillo/data_encoding.h> + +#include <base/logging.h> +#include <fuzzer/FuzzedDataProvider.h> + +namespace { +constexpr int kMaxStringLength = 256; +constexpr int kMaxParamsSize = 8; + +void FuzzUrlEncodeDecode(FuzzedDataProvider* provider) { + brillo::data_encoding::UrlEncode( + provider->ConsumeRandomLengthString(kMaxStringLength).c_str(), + provider->ConsumeBool()); + + brillo::data_encoding::UrlDecode( + provider->ConsumeRandomLengthString(kMaxStringLength).c_str()); +} + +void FuzzWebParamsEncodeDecode(FuzzedDataProvider* provider) { + brillo::data_encoding::WebParamList param_list; + const auto num_params = provider->ConsumeIntegralInRange(0, kMaxParamsSize); + for (auto i = 0; i < num_params; i++) { + param_list.push_back(std::pair<std::string, std::string>( + provider->ConsumeRandomLengthString(kMaxStringLength), + provider->ConsumeRandomLengthString(kMaxStringLength))); + } + brillo::data_encoding::WebParamsEncode(param_list, provider->ConsumeBool()); + + brillo::data_encoding::WebParamsDecode( + provider->ConsumeRandomLengthString(kMaxStringLength)); +} + +void FuzzBase64EncodeDecode(FuzzedDataProvider* provider) { + brillo::data_encoding::Base64Encode( + provider->ConsumeRandomLengthString(kMaxStringLength)); + brillo::Blob output; + brillo::data_encoding::Base64Decode( + provider->ConsumeRandomLengthString(kMaxStringLength), &output); +} + +bool IgnoreLogging(int, const char*, int, size_t, const std::string&) { + return true; +} + +} // namespace + +class Environment { + public: + Environment() { + // Disable logging. Normally this would be done with logging::SetMinLogLevel + // but that doesn't work for brillo::Error because it's not using the + // LOG(ERROR) macro which is where the actual log level check occurs. + logging::SetLogMessageHandler(&IgnoreLogging); + } +}; + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + static Environment env; + FuzzedDataProvider data_provider(data, size); + FuzzUrlEncodeDecode(&data_provider); + FuzzWebParamsEncodeDecode(&data_provider); + FuzzBase64EncodeDecode(&data_provider); + return 0; +} diff --git a/brillo/data_encoding_unittest.cc b/brillo/data_encoding_test.cc index cb73da6..cb73da6 100644 --- a/brillo/data_encoding_unittest.cc +++ b/brillo/data_encoding_test.cc diff --git a/brillo/dbus/async_event_sequencer.cc b/brillo/dbus/async_event_sequencer.cc index 8861e21..5cdf36d 100644 --- a/brillo/dbus/async_event_sequencer.cc +++ b/brillo/dbus/async_event_sequencer.cc @@ -4,6 +4,9 @@ #include <brillo/dbus/async_event_sequencer.h> +#include <base/bind.h> +#include <base/callback.h> + namespace brillo { namespace dbus_utils { diff --git a/brillo/dbus/async_event_sequencer.h b/brillo/dbus/async_event_sequencer.h index c817b55..cc532e6 100644 --- a/brillo/dbus/async_event_sequencer.h +++ b/brillo/dbus/async_event_sequencer.h @@ -9,7 +9,7 @@ #include <string> #include <vector> -#include <base/bind.h> +#include <base/callback_forward.h> #include <base/macros.h> #include <base/memory/ref_counted.h> #include <brillo/brillo_export.h> diff --git a/brillo/dbus/async_event_sequencer_unittest.cc b/brillo/dbus/async_event_sequencer_test.cc index 5f4c0e2..1026afe 100644 --- a/brillo/dbus/async_event_sequencer_unittest.cc +++ b/brillo/dbus/async_event_sequencer_test.cc @@ -4,6 +4,7 @@ #include <brillo/dbus/async_event_sequencer.h> +#include <base/bind.h> #include <base/bind_helpers.h> #include <gmock/gmock.h> #include <gtest/gtest.h> @@ -22,7 +23,7 @@ const char kTestMethod2[] = "TestMethod2"; class AsyncEventSequencerTest : public ::testing::Test { public: - MOCK_METHOD1(HandleCompletion, void(bool all_succeeded)); + MOCK_METHOD(void, HandleCompletion, (bool)); void SetUp() { aec_ = new AsyncEventSequencer(); diff --git a/brillo/dbus/data_serialization.cc b/brillo/dbus/data_serialization.cc index 4cae471..fa348a0 100644 --- a/brillo/dbus/data_serialization.cc +++ b/brillo/dbus/data_serialization.cc @@ -232,6 +232,9 @@ bool PopArrayValueFromReader(dbus::MessageReader* reader, else if (signature == "a(uu)") return PopTypedArrayFromReader< std::tuple<uint32_t, uint32_t>>(reader, value); + else if (signature == "a(ubay)") + return PopTypedArrayFromReader< + std::tuple<uint32_t, bool, std::vector<uint8_t>>>(reader, value); // When a use case for particular array signature is found, feel free // to add handing for it here. @@ -256,6 +259,9 @@ bool PopStructValueFromReader(dbus::MessageReader* reader, else if (signature == "(uu)") return PopTypedValueFromReader<std::tuple<uint32_t, uint32_t>>(reader, value); + else if (signature == "(ua{sv})") + return PopTypedValueFromReader< + std::tuple<uint32_t, brillo::VariantDictionary>>(reader, value); // When a use case for particular struct signature is found, feel free // to add handing for it here. diff --git a/brillo/dbus/data_serialization.h b/brillo/dbus/data_serialization.h index 1600919..a4f49c1 100644 --- a/brillo/dbus/data_serialization.h +++ b/brillo/dbus/data_serialization.h @@ -49,7 +49,7 @@ // - static void Write(dbus::MessageWriter* writer, const CustomType& value); // - static bool Read(dbus::MessageReader* reader, CustomType* value); // See an example in DBusUtils.CustomStruct unit test in -// brillo/dbus/data_serialization_unittest.cc. +// brillo/dbus/data_serialization_test.cc. #include <map> #include <memory> @@ -125,16 +125,16 @@ struct IsTypeSupported<> : public std::false_type {}; // Write the |value| of type T to D-Bus message. // Explicitly delete the overloads for scalar types that are not supported by // D-Bus. -void AppendValueToWriter(dbus::MessageWriter* writer, char value) = delete; -void AppendValueToWriter(dbus::MessageWriter* writer, float value) = delete; +void AppendValueToWriter(::dbus::MessageWriter* writer, char value) = delete; +void AppendValueToWriter(::dbus::MessageWriter* writer, float value) = delete; //---------------------------------------------------------------------------- // PopValueFromReader<T>(dbus::MessageWriter* writer, T* value) // Reads the |value| of type T from D-Bus message. // Explicitly delete the overloads for scalar types that are not supported by // D-Bus. -void PopValueFromReader(dbus::MessageReader* reader, char* value) = delete; -void PopValueFromReader(dbus::MessageReader* reader, float* value) = delete; +void PopValueFromReader(::dbus::MessageReader* reader, char* value) = delete; +void PopValueFromReader(::dbus::MessageReader* reader, float* value) = delete; //---------------------------------------------------------------------------- // Get D-Bus data signature from C++ data types. @@ -153,9 +153,9 @@ namespace details { // into the Variant and updates the |*reader_ref| with the transient // |variant_reader| MessageReader instance passed in. // Returns false if it fails to descend into the Variant. -inline bool DescendIntoVariantIfPresent(dbus::MessageReader** reader_ref, - dbus::MessageReader* variant_reader) { - if ((*reader_ref)->GetDataType() != dbus::Message::VARIANT) +inline bool DescendIntoVariantIfPresent(::dbus::MessageReader** reader_ref, + ::dbus::MessageReader* variant_reader) { + if ((*reader_ref)->GetDataType() != ::dbus::Message::VARIANT) return true; if (!(*reader_ref)->PopVariant(variant_reader)) return false; @@ -187,198 +187,198 @@ inline std::string GetDBusDictEntryType() { // DBusType<T> for various C++ types that can be serialized over D-Bus. // bool ----------------------------------------------------------------------- -BRILLO_EXPORT void AppendValueToWriter(dbus::MessageWriter* writer, - bool value); -BRILLO_EXPORT bool PopValueFromReader(dbus::MessageReader* reader, - bool* value); +BRILLO_EXPORT void AppendValueToWriter(::dbus::MessageWriter* writer, + bool value); +BRILLO_EXPORT bool PopValueFromReader(::dbus::MessageReader* reader, + bool* value); template<> struct DBusType<bool> { inline static std::string GetSignature() { return DBUS_TYPE_BOOLEAN_AS_STRING; } - inline static void Write(dbus::MessageWriter* writer, bool value) { + inline static void Write(::dbus::MessageWriter* writer, bool value) { AppendValueToWriter(writer, value); } - inline static bool Read(dbus::MessageReader* reader, bool* value) { + inline static bool Read(::dbus::MessageReader* reader, bool* value) { return PopValueFromReader(reader, value); } }; // uint8_t -------------------------------------------------------------------- -BRILLO_EXPORT void AppendValueToWriter(dbus::MessageWriter* writer, - uint8_t value); -BRILLO_EXPORT bool PopValueFromReader(dbus::MessageReader* reader, - uint8_t* value); +BRILLO_EXPORT void AppendValueToWriter(::dbus::MessageWriter* writer, + uint8_t value); +BRILLO_EXPORT bool PopValueFromReader(::dbus::MessageReader* reader, + uint8_t* value); template<> struct DBusType<uint8_t> { inline static std::string GetSignature() { return DBUS_TYPE_BYTE_AS_STRING; } - inline static void Write(dbus::MessageWriter* writer, uint8_t value) { + inline static void Write(::dbus::MessageWriter* writer, uint8_t value) { AppendValueToWriter(writer, value); } - inline static bool Read(dbus::MessageReader* reader, uint8_t* value) { + inline static bool Read(::dbus::MessageReader* reader, uint8_t* value) { return PopValueFromReader(reader, value); } }; // int16_t -------------------------------------------------------------------- -BRILLO_EXPORT void AppendValueToWriter(dbus::MessageWriter* writer, - int16_t value); -BRILLO_EXPORT bool PopValueFromReader(dbus::MessageReader* reader, - int16_t* value); +BRILLO_EXPORT void AppendValueToWriter(::dbus::MessageWriter* writer, + int16_t value); +BRILLO_EXPORT bool PopValueFromReader(::dbus::MessageReader* reader, + int16_t* value); template<> struct DBusType<int16_t> { inline static std::string GetSignature() { return DBUS_TYPE_INT16_AS_STRING; } - inline static void Write(dbus::MessageWriter* writer, int16_t value) { + inline static void Write(::dbus::MessageWriter* writer, int16_t value) { AppendValueToWriter(writer, value); } - inline static bool Read(dbus::MessageReader* reader, int16_t* value) { + inline static bool Read(::dbus::MessageReader* reader, int16_t* value) { return PopValueFromReader(reader, value); } }; // uint16_t ------------------------------------------------------------------- -BRILLO_EXPORT void AppendValueToWriter(dbus::MessageWriter* writer, - uint16_t value); -BRILLO_EXPORT bool PopValueFromReader(dbus::MessageReader* reader, - uint16_t* value); +BRILLO_EXPORT void AppendValueToWriter(::dbus::MessageWriter* writer, + uint16_t value); +BRILLO_EXPORT bool PopValueFromReader(::dbus::MessageReader* reader, + uint16_t* value); template<> struct DBusType<uint16_t> { inline static std::string GetSignature() { return DBUS_TYPE_UINT16_AS_STRING; } - inline static void Write(dbus::MessageWriter* writer, uint16_t value) { + inline static void Write(::dbus::MessageWriter* writer, uint16_t value) { AppendValueToWriter(writer, value); } - inline static bool Read(dbus::MessageReader* reader, uint16_t* value) { + inline static bool Read(::dbus::MessageReader* reader, uint16_t* value) { return PopValueFromReader(reader, value); } }; // int32_t -------------------------------------------------------------------- -BRILLO_EXPORT void AppendValueToWriter(dbus::MessageWriter* writer, - int32_t value); -BRILLO_EXPORT bool PopValueFromReader(dbus::MessageReader* reader, - int32_t* value); +BRILLO_EXPORT void AppendValueToWriter(::dbus::MessageWriter* writer, + int32_t value); +BRILLO_EXPORT bool PopValueFromReader(::dbus::MessageReader* reader, + int32_t* value); template<> struct DBusType<int32_t> { inline static std::string GetSignature() { return DBUS_TYPE_INT32_AS_STRING; } - inline static void Write(dbus::MessageWriter* writer, int32_t value) { + inline static void Write(::dbus::MessageWriter* writer, int32_t value) { AppendValueToWriter(writer, value); } - inline static bool Read(dbus::MessageReader* reader, int32_t* value) { + inline static bool Read(::dbus::MessageReader* reader, int32_t* value) { return PopValueFromReader(reader, value); } }; // uint32_t ------------------------------------------------------------------- -BRILLO_EXPORT void AppendValueToWriter(dbus::MessageWriter* writer, - uint32_t value); -BRILLO_EXPORT bool PopValueFromReader(dbus::MessageReader* reader, - uint32_t* value); +BRILLO_EXPORT void AppendValueToWriter(::dbus::MessageWriter* writer, + uint32_t value); +BRILLO_EXPORT bool PopValueFromReader(::dbus::MessageReader* reader, + uint32_t* value); template<> struct DBusType<uint32_t> { inline static std::string GetSignature() { return DBUS_TYPE_UINT32_AS_STRING; } - inline static void Write(dbus::MessageWriter* writer, uint32_t value) { + inline static void Write(::dbus::MessageWriter* writer, uint32_t value) { AppendValueToWriter(writer, value); } - inline static bool Read(dbus::MessageReader* reader, uint32_t* value) { + inline static bool Read(::dbus::MessageReader* reader, uint32_t* value) { return PopValueFromReader(reader, value); } }; // int64_t -------------------------------------------------------------------- -BRILLO_EXPORT void AppendValueToWriter(dbus::MessageWriter* writer, - int64_t value); -BRILLO_EXPORT bool PopValueFromReader(dbus::MessageReader* reader, - int64_t* value); +BRILLO_EXPORT void AppendValueToWriter(::dbus::MessageWriter* writer, + int64_t value); +BRILLO_EXPORT bool PopValueFromReader(::dbus::MessageReader* reader, + int64_t* value); template<> struct DBusType<int64_t> { inline static std::string GetSignature() { return DBUS_TYPE_INT64_AS_STRING; } - inline static void Write(dbus::MessageWriter* writer, int64_t value) { + inline static void Write(::dbus::MessageWriter* writer, int64_t value) { AppendValueToWriter(writer, value); } - inline static bool Read(dbus::MessageReader* reader, int64_t* value) { + inline static bool Read(::dbus::MessageReader* reader, int64_t* value) { return PopValueFromReader(reader, value); } }; // uint64_t ------------------------------------------------------------------- -BRILLO_EXPORT void AppendValueToWriter(dbus::MessageWriter* writer, - uint64_t value); -BRILLO_EXPORT bool PopValueFromReader(dbus::MessageReader* reader, - uint64_t* value); +BRILLO_EXPORT void AppendValueToWriter(::dbus::MessageWriter* writer, + uint64_t value); +BRILLO_EXPORT bool PopValueFromReader(::dbus::MessageReader* reader, + uint64_t* value); template<> struct DBusType<uint64_t> { inline static std::string GetSignature() { return DBUS_TYPE_UINT64_AS_STRING; } - inline static void Write(dbus::MessageWriter* writer, uint64_t value) { + inline static void Write(::dbus::MessageWriter* writer, uint64_t value) { AppendValueToWriter(writer, value); } - inline static bool Read(dbus::MessageReader* reader, uint64_t* value) { + inline static bool Read(::dbus::MessageReader* reader, uint64_t* value) { return PopValueFromReader(reader, value); } }; // double --------------------------------------------------------------------- -BRILLO_EXPORT void AppendValueToWriter(dbus::MessageWriter* writer, - double value); -BRILLO_EXPORT bool PopValueFromReader(dbus::MessageReader* reader, - double* value); +BRILLO_EXPORT void AppendValueToWriter(::dbus::MessageWriter* writer, + double value); +BRILLO_EXPORT bool PopValueFromReader(::dbus::MessageReader* reader, + double* value); template<> struct DBusType<double> { inline static std::string GetSignature() { return DBUS_TYPE_DOUBLE_AS_STRING; } - inline static void Write(dbus::MessageWriter* writer, double value) { + inline static void Write(::dbus::MessageWriter* writer, double value) { AppendValueToWriter(writer, value); } - inline static bool Read(dbus::MessageReader* reader, double* value) { + inline static bool Read(::dbus::MessageReader* reader, double* value) { return PopValueFromReader(reader, value); } }; // std::string ---------------------------------------------------------------- -BRILLO_EXPORT void AppendValueToWriter(dbus::MessageWriter* writer, - const std::string& value); -BRILLO_EXPORT bool PopValueFromReader(dbus::MessageReader* reader, - std::string* value); +BRILLO_EXPORT void AppendValueToWriter(::dbus::MessageWriter* writer, + const std::string& value); +BRILLO_EXPORT bool PopValueFromReader(::dbus::MessageReader* reader, + std::string* value); template<> struct DBusType<std::string> { inline static std::string GetSignature() { return DBUS_TYPE_STRING_AS_STRING; } - inline static void Write(dbus::MessageWriter* writer, + inline static void Write(::dbus::MessageWriter* writer, const std::string& value) { AppendValueToWriter(writer, value); } - inline static bool Read(dbus::MessageReader* reader, std::string* value) { + inline static bool Read(::dbus::MessageReader* reader, std::string* value) { return PopValueFromReader(reader, value); } }; // const char* -BRILLO_EXPORT void AppendValueToWriter(dbus::MessageWriter* writer, - const char* value); +BRILLO_EXPORT void AppendValueToWriter(::dbus::MessageWriter* writer, + const char* value); template<> struct DBusType<const char*> { inline static std::string GetSignature() { return DBUS_TYPE_STRING_AS_STRING; } - inline static void Write(dbus::MessageWriter* writer, const char* value) { + inline static void Write(::dbus::MessageWriter* writer, const char* value) { AppendValueToWriter(writer, value); } }; @@ -389,44 +389,44 @@ struct DBusType<const char[]> { inline static std::string GetSignature() { return DBUS_TYPE_STRING_AS_STRING; } - inline static void Write(dbus::MessageWriter* writer, const char* value) { + inline static void Write(::dbus::MessageWriter* writer, const char* value) { AppendValueToWriter(writer, value); } }; // dbus::ObjectPath ----------------------------------------------------------- -BRILLO_EXPORT void AppendValueToWriter(dbus::MessageWriter* writer, - const dbus::ObjectPath& value); -BRILLO_EXPORT bool PopValueFromReader(dbus::MessageReader* reader, - dbus::ObjectPath* value); +BRILLO_EXPORT void AppendValueToWriter(::dbus::MessageWriter* writer, + const ::dbus::ObjectPath& value); +BRILLO_EXPORT bool PopValueFromReader(::dbus::MessageReader* reader, + ::dbus::ObjectPath* value); -template<> -struct DBusType<dbus::ObjectPath> { +template <> +struct DBusType<::dbus::ObjectPath> { inline static std::string GetSignature() { return DBUS_TYPE_OBJECT_PATH_AS_STRING; } - inline static void Write(dbus::MessageWriter* writer, - const dbus::ObjectPath& value) { + inline static void Write(::dbus::MessageWriter* writer, + const ::dbus::ObjectPath& value) { AppendValueToWriter(writer, value); } - inline static bool Read(dbus::MessageReader* reader, - dbus::ObjectPath* value) { + inline static bool Read(::dbus::MessageReader* reader, + ::dbus::ObjectPath* value) { return PopValueFromReader(reader, value); } }; // brillo::dbus_utils::FileDescriptor/base::ScopedFD -------------------------- -BRILLO_EXPORT void AppendValueToWriter(dbus::MessageWriter* writer, - const FileDescriptor& value); -BRILLO_EXPORT bool PopValueFromReader(dbus::MessageReader* reader, - base::ScopedFD* value); +BRILLO_EXPORT void AppendValueToWriter(::dbus::MessageWriter* writer, + const FileDescriptor& value); +BRILLO_EXPORT bool PopValueFromReader(::dbus::MessageReader* reader, + base::ScopedFD* value); template<> struct DBusType<FileDescriptor> { inline static std::string GetSignature() { return DBUS_TYPE_UNIX_FD_AS_STRING; } - inline static void Write(dbus::MessageWriter* writer, + inline static void Write(::dbus::MessageWriter* writer, const FileDescriptor& value) { AppendValueToWriter(writer, value); } @@ -437,38 +437,37 @@ struct DBusType<base::ScopedFD> { inline static std::string GetSignature() { return DBUS_TYPE_UNIX_FD_AS_STRING; } - inline static bool Read(dbus::MessageReader* reader, + inline static bool Read(::dbus::MessageReader* reader, base::ScopedFD* value) { return PopValueFromReader(reader, value); } }; // brillo::Any -------------------------------------------------------------- -BRILLO_EXPORT void AppendValueToWriter(dbus::MessageWriter* writer, - const brillo::Any& value); -BRILLO_EXPORT bool PopValueFromReader(dbus::MessageReader* reader, - brillo::Any* value); +BRILLO_EXPORT void AppendValueToWriter(::dbus::MessageWriter* writer, + const brillo::Any& value); +BRILLO_EXPORT bool PopValueFromReader(::dbus::MessageReader* reader, + brillo::Any* value); template<> struct DBusType<brillo::Any> { inline static std::string GetSignature() { return DBUS_TYPE_VARIANT_AS_STRING; } - inline static void Write(dbus::MessageWriter* writer, + inline static void Write(::dbus::MessageWriter* writer, const brillo::Any& value) { AppendValueToWriter(writer, value); } - inline static bool Read(dbus::MessageReader* reader, brillo::Any* value) { + inline static bool Read(::dbus::MessageReader* reader, brillo::Any* value) { return PopValueFromReader(reader, value); } }; // std::vector = D-Bus ARRAY. ------------------------------------------------- -template<typename T, typename ALLOC> +template <typename T, typename ALLOC> typename std::enable_if<IsTypeSupported<T>::value>::type AppendValueToWriter( - dbus::MessageWriter* writer, - const std::vector<T, ALLOC>& value) { - dbus::MessageWriter array_writer(nullptr); + ::dbus::MessageWriter* writer, const std::vector<T, ALLOC>& value) { + ::dbus::MessageWriter array_writer(nullptr); writer->OpenArray(GetDBusSignature<T>(), &array_writer); for (const auto& element : value) { // Use DBusType<T>::Write() instead of AppendValueToWriter() to delay @@ -479,11 +478,12 @@ typename std::enable_if<IsTypeSupported<T>::value>::type AppendValueToWriter( writer->CloseContainer(&array_writer); } -template<typename T, typename ALLOC> +template <typename T, typename ALLOC> typename std::enable_if<IsTypeSupported<T>::value, bool>::type -PopValueFromReader(dbus::MessageReader* reader, std::vector<T, ALLOC>* value) { - dbus::MessageReader variant_reader(nullptr); - dbus::MessageReader array_reader(nullptr); +PopValueFromReader(::dbus::MessageReader* reader, + std::vector<T, ALLOC>* value) { + ::dbus::MessageReader variant_reader(nullptr); + ::dbus::MessageReader array_reader(nullptr); if (!details::DescendIntoVariantIfPresent(&reader, &variant_reader) || !reader->PopArray(&array_reader)) return false; @@ -510,11 +510,11 @@ struct DBusArrayType { inline static std::string GetSignature() { return GetArrayDBusSignature(GetDBusSignature<T>()); } - inline static void Write(dbus::MessageWriter* writer, + inline static void Write(::dbus::MessageWriter* writer, const std::vector<T, ALLOC>& value) { AppendValueToWriter(writer, value); } - inline static bool Read(dbus::MessageReader* reader, + inline static bool Read(::dbus::MessageReader* reader, std::vector<T, ALLOC>* value) { return PopValueFromReader(reader, value); } @@ -562,11 +562,10 @@ inline std::string GetStructDBusSignature() { DBUS_STRUCT_END_CHAR_AS_STRING; } -template<typename U, typename V> +template <typename U, typename V> typename std::enable_if<IsTypeSupported<U, V>::value>::type AppendValueToWriter( - dbus::MessageWriter* writer, - const std::pair<U, V>& value) { - dbus::MessageWriter struct_writer(nullptr); + ::dbus::MessageWriter* writer, const std::pair<U, V>& value) { + ::dbus::MessageWriter struct_writer(nullptr); writer->OpenStruct(&struct_writer); // Use DBusType<T>::Write() instead of AppendValueToWriter() to delay // binding to AppendValueToWriter() to the point of instantiation of this @@ -576,11 +575,11 @@ typename std::enable_if<IsTypeSupported<U, V>::value>::type AppendValueToWriter( writer->CloseContainer(&struct_writer); } -template<typename U, typename V> +template <typename U, typename V> typename std::enable_if<IsTypeSupported<U, V>::value, bool>::type -PopValueFromReader(dbus::MessageReader* reader, std::pair<U, V>* value) { - dbus::MessageReader variant_reader(nullptr); - dbus::MessageReader struct_reader(nullptr); +PopValueFromReader(::dbus::MessageReader* reader, std::pair<U, V>* value) { + ::dbus::MessageReader variant_reader(nullptr); + ::dbus::MessageReader struct_reader(nullptr); if (!details::DescendIntoVariantIfPresent(&reader, &variant_reader) || !reader->PopStruct(&struct_reader)) return false; @@ -602,11 +601,12 @@ struct DBusPairType { inline static std::string GetSignature() { return GetStructDBusSignature<U, V>(); } - inline static void Write(dbus::MessageWriter* writer, + inline static void Write(::dbus::MessageWriter* writer, const std::pair<U, V>& value) { AppendValueToWriter(writer, value); } - inline static bool Read(dbus::MessageReader* reader, std::pair<U, V>* value) { + inline static bool Read(::dbus::MessageReader* reader, + std::pair<U, V>* value) { return PopValueFromReader(reader, value); } }; @@ -636,7 +636,7 @@ struct TupleIterator { using ValueType = typename std::tuple_element<I, Tuple>::type; // Write the tuple element at index I to D-Bus message. - static void Write(dbus::MessageWriter* writer, const Tuple& value) { + static void Write(::dbus::MessageWriter* writer, const Tuple& value) { // Use DBusType<T>::Write() instead of AppendValueToWriter() to delay // binding to AppendValueToWriter() to the point of instantiation of this // template. @@ -645,7 +645,7 @@ struct TupleIterator { } // Read the tuple element at index I from D-Bus message. - static bool Read(dbus::MessageReader* reader, Tuple* value) { + static bool Read(::dbus::MessageReader* reader, Tuple* value) { // Use DBusType<T>::Read() instead of PopValueFromReader() to delay // binding to PopValueFromReader() to the point of instantiation of this // template. @@ -658,29 +658,29 @@ struct TupleIterator { template<size_t N, typename... T> struct TupleIterator<N, N, T...> { using Tuple = std::tuple<T...>; - static void Write(dbus::MessageWriter* /* writer */, + static void Write(::dbus::MessageWriter* /* writer */, const Tuple& /* value */) {} - static bool Read(dbus::MessageReader* /* reader */, - Tuple* /* value */) { return true; } + static bool Read(::dbus::MessageReader* /* reader */, Tuple* /* value */) { + return true; + } }; } // namespace details -template<typename... T> +template <typename... T> typename std::enable_if<IsTypeSupported<T...>::value>::type AppendValueToWriter( - dbus::MessageWriter* writer, - const std::tuple<T...>& value) { - dbus::MessageWriter struct_writer(nullptr); + ::dbus::MessageWriter* writer, const std::tuple<T...>& value) { + ::dbus::MessageWriter struct_writer(nullptr); writer->OpenStruct(&struct_writer); details::TupleIterator<0, sizeof...(T), T...>::Write(&struct_writer, value); writer->CloseContainer(&struct_writer); } -template<typename... T> +template <typename... T> typename std::enable_if<IsTypeSupported<T...>::value, bool>::type -PopValueFromReader(dbus::MessageReader* reader, std::tuple<T...>* value) { - dbus::MessageReader variant_reader(nullptr); - dbus::MessageReader struct_reader(nullptr); +PopValueFromReader(::dbus::MessageReader* reader, std::tuple<T...>* value) { + ::dbus::MessageReader variant_reader(nullptr); + ::dbus::MessageReader struct_reader(nullptr); if (!details::DescendIntoVariantIfPresent(&reader, &variant_reader) || !reader->PopStruct(&struct_reader)) return false; @@ -699,11 +699,11 @@ struct DBusTupleType { inline static std::string GetSignature() { return GetStructDBusSignature<T...>(); } - inline static void Write(dbus::MessageWriter* writer, + inline static void Write(::dbus::MessageWriter* writer, const std::tuple<T...>& value) { AppendValueToWriter(writer, value); } - inline static bool Read(dbus::MessageReader* reader, + inline static bool Read(::dbus::MessageReader* reader, std::tuple<T...>* value) { return PopValueFromReader(reader, value); } @@ -720,14 +720,14 @@ struct DBusType<std::tuple<T...>> : public details::DBusTupleType<IsTypeSupported<T...>::value, T...> {}; // std::map = D-Bus ARRAY of DICT_ENTRY. -------------------------------------- -template<typename KEY, typename VALUE, typename PRED, typename ALLOC> +template <typename KEY, typename VALUE, typename PRED, typename ALLOC> typename std::enable_if<IsTypeSupported<KEY, VALUE>::value>::type -AppendValueToWriter(dbus::MessageWriter* writer, +AppendValueToWriter(::dbus::MessageWriter* writer, const std::map<KEY, VALUE, PRED, ALLOC>& value) { - dbus::MessageWriter dict_writer(nullptr); + ::dbus::MessageWriter dict_writer(nullptr); writer->OpenArray(details::GetDBusDictEntryType<KEY, VALUE>(), &dict_writer); for (const auto& pair : value) { - dbus::MessageWriter entry_writer(nullptr); + ::dbus::MessageWriter entry_writer(nullptr); dict_writer.OpenDictEntry(&entry_writer); // Use DBusType<T>::Write() instead of AppendValueToWriter() to delay // binding to AppendValueToWriter() to the point of instantiation of this @@ -739,18 +739,18 @@ AppendValueToWriter(dbus::MessageWriter* writer, writer->CloseContainer(&dict_writer); } -template<typename KEY, typename VALUE, typename PRED, typename ALLOC> +template <typename KEY, typename VALUE, typename PRED, typename ALLOC> typename std::enable_if<IsTypeSupported<KEY, VALUE>::value, bool>::type -PopValueFromReader(dbus::MessageReader* reader, +PopValueFromReader(::dbus::MessageReader* reader, std::map<KEY, VALUE, PRED, ALLOC>* value) { - dbus::MessageReader variant_reader(nullptr); - dbus::MessageReader array_reader(nullptr); + ::dbus::MessageReader variant_reader(nullptr); + ::dbus::MessageReader array_reader(nullptr); if (!details::DescendIntoVariantIfPresent(&reader, &variant_reader) || !reader->PopArray(&array_reader)) return false; value->clear(); while (array_reader.HasMoreData()) { - dbus::MessageReader dict_entry_reader(nullptr); + ::dbus::MessageReader dict_entry_reader(nullptr); if (!array_reader.PopDictEntry(&dict_entry_reader)) return false; KEY key; @@ -782,11 +782,11 @@ struct DBusMapType { inline static std::string GetSignature() { return GetArrayDBusSignature(GetDBusDictEntryType<KEY, VALUE>()); } - inline static void Write(dbus::MessageWriter* writer, + inline static void Write(::dbus::MessageWriter* writer, const std::map<KEY, VALUE, PRED, ALLOC>& value) { AppendValueToWriter(writer, value); } - inline static bool Read(dbus::MessageReader* reader, + inline static bool Read(::dbus::MessageReader* reader, std::map<KEY, VALUE, PRED, ALLOC>* value) { return PopValueFromReader(reader, value); } @@ -807,12 +807,12 @@ struct DBusType<std::map<KEY, VALUE, PRED, ALLOC>> ALLOC> {}; // google::protobuf::MessageLite = D-Bus ARRAY of BYTE ------------------------ -inline void AppendValueToWriter(dbus::MessageWriter* writer, +inline void AppendValueToWriter(::dbus::MessageWriter* writer, const google::protobuf::MessageLite& value) { writer->AppendProtoAsArrayOfBytes(value); } -inline bool PopValueFromReader(dbus::MessageReader* reader, +inline bool PopValueFromReader(::dbus::MessageReader* reader, google::protobuf::MessageLite* value) { return reader->PopArrayOfBytesAsProto(value); } @@ -835,23 +835,23 @@ struct DBusType<T, typename std::enable_if<is_protobuf<T>::value>::type> { inline static std::string GetSignature() { return GetDBusSignature<std::vector<uint8_t>>(); } - inline static void Write(dbus::MessageWriter* writer, const T& value) { + inline static void Write(::dbus::MessageWriter* writer, const T& value) { AppendValueToWriter(writer, value); } - inline static bool Read(dbus::MessageReader* reader, T* value) { + inline static bool Read(::dbus::MessageReader* reader, T* value) { return PopValueFromReader(reader, value); } }; //---------------------------------------------------------------------------- -// AppendValueToWriterAsVariant<T>(dbus::MessageWriter* writer, const T& value) -// Write the |value| of type T to D-Bus message as a VARIANT. -// This overload is provided only if T is supported by D-Bus. -template<typename T> +// AppendValueToWriterAsVariant<T>(::dbus::MessageWriter* writer, const T& +// value) Write the |value| of type T to D-Bus message as a VARIANT. This +// overload is provided only if T is supported by D-Bus. +template <typename T> typename std::enable_if<IsTypeSupported<T>::value>::type -AppendValueToWriterAsVariant(dbus::MessageWriter* writer, const T& value) { +AppendValueToWriterAsVariant(::dbus::MessageWriter* writer, const T& value) { std::string data_type = GetDBusSignature<T>(); - dbus::MessageWriter variant_writer(nullptr); + ::dbus::MessageWriter variant_writer(nullptr); writer->OpenVariant(data_type, &variant_writer); // Use DBusType<T>::Write() instead of AppendValueToWriter() to delay // binding to AppendValueToWriter() to the point of instantiation of this @@ -862,13 +862,13 @@ AppendValueToWriterAsVariant(dbus::MessageWriter* writer, const T& value) { // Special case: do not allow to write a Variant containing a Variant. // Just redirect to normal AppendValueToWriter(). -inline void AppendValueToWriterAsVariant(dbus::MessageWriter* writer, +inline void AppendValueToWriterAsVariant(::dbus::MessageWriter* writer, const brillo::Any& value) { return AppendValueToWriter(writer, value); } //---------------------------------------------------------------------------- -// PopVariantValueFromReader<T>(dbus::MessageWriter* writer, T* value) +// PopVariantValueFromReader<T>(::dbus::MessageWriter* writer, T* value) // Reads a Variant containing the |value| of type T from D-Bus message. // Note that the generic PopValueFromReader<T>(...) can do this too. // This method is provided for two reasons: @@ -876,10 +876,10 @@ inline void AppendValueToWriterAsVariant(dbus::MessageWriter* writer, // 2. To be used when it is important to assert that the data was sent // specifically as a Variant. // This overload is provided only if T is supported by D-Bus. -template<typename T> +template <typename T> typename std::enable_if<IsTypeSupported<T>::value, bool>::type -PopVariantValueFromReader(dbus::MessageReader* reader, T* value) { - dbus::MessageReader variant_reader(nullptr); +PopVariantValueFromReader(::dbus::MessageReader* reader, T* value) { + ::dbus::MessageReader variant_reader(nullptr); if (!reader->PopVariant(&variant_reader)) return false; // Use DBusType<T>::Read() instead of PopValueFromReader() to delay @@ -889,7 +889,8 @@ PopVariantValueFromReader(dbus::MessageReader* reader, T* value) { } // Special handling of request to read a Variant of Variant. -inline bool PopVariantValueFromReader(dbus::MessageReader* reader, Any* value) { +inline bool PopVariantValueFromReader(::dbus::MessageReader* reader, + Any* value) { return PopValueFromReader(reader, value); } diff --git a/brillo/dbus/data_serialization_fuzzer.cc b/brillo/dbus/data_serialization_fuzzer.cc new file mode 100644 index 0000000..dd576a1 --- /dev/null +++ b/brillo/dbus/data_serialization_fuzzer.cc @@ -0,0 +1,334 @@ +// Copyright 2019 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <cmath> +#include <cstddef> +#include <cstdint> +#include <map> +#include <string> +#include <utility> +#include <vector> + +#include <base/logging.h> +#include <base/strings/string_util.h> +#include <brillo/dbus/data_serialization.h> +#include <dbus/string_util.h> +#include <fuzzer/FuzzedDataProvider.h> + +namespace { +constexpr int kRandomMaxContainerSize = 8; +constexpr int kRandomMaxDataLength = 128; + +typedef enum DataType { + kUint8 = 0, + kUint16, + kUint32, + kUint64, + kInt16, + kInt32, + kInt64, + kBool, + kDouble, + kString, + kObjectPath, + // A couple vector types. + kVectorInt16, + kVectorString, + // A couple pair types. + kPairBoolInt64, + kPairUint32String, + // A couple tuple types. + kTupleUint16StringBool, + kTupleDoubleInt32ObjectPath, + // A couple map types. + kMapInt32String, + kMapDoubleBool, + kMaxValue = kMapDoubleBool, +} DataType; + +template <typename T> +void AppendValue(dbus::MessageWriter* writer, bool variant, const T& value) { + if (variant) + brillo::dbus_utils::AppendValueToWriterAsVariant(writer, value); + else + brillo::dbus_utils::AppendValueToWriter(writer, value); +} + +template <typename T> +void GenerateIntAndAppendValue(FuzzedDataProvider* data_provider, + dbus::MessageWriter* writer, + bool variant) { + AppendValue(writer, variant, data_provider->ConsumeIntegral<T>()); +} + +template <typename T> +void PopValue(dbus::MessageReader* reader, bool variant, T* value) { + if (variant) + brillo::dbus_utils::PopVariantValueFromReader(reader, value); + else + brillo::dbus_utils::PopValueFromReader(reader, value); +} + +std::string GenerateValidUTF8(FuzzedDataProvider* data_provider) { + // >= 0x80 + // Generates a random string and returns it if it is valid UTF8, if it is not + // then it will strip it down to all the 7-bit ASCII chars and just return + // that string. + std::string str = + data_provider->ConsumeRandomLengthString(kRandomMaxDataLength); + if (base::IsStringUTF8(str)) + return str; + for (auto it = str.begin(); it != str.end(); it++) { + if (static_cast<uint8_t>(*it) >= 0x80) { + // Might be invalid, remove it. + it = str.erase(it); + it--; + } + } + return str; +} + +} // namespace + +class Environment { + public: + Environment() { + // Disable logging. + logging::SetMinLogLevel(logging::LOG_FATAL); + } +}; + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + static Environment env; + FuzzedDataProvider data_provider(data, size); + // Consume a random fraction of our data writing random things to a D-Bus + // message, and then consume the remaining data reading randomly from that + // same D-Bus message. Given the templated nature of these functions and that + // they support essentially an infinite amount of types, we are constraining + // this to a fixed set of types defined above. + std::unique_ptr<dbus::Response> message = dbus::Response::CreateEmpty(); + dbus::MessageWriter writer(message.get()); + + int bytes_left_for_read = + static_cast<int>(data_provider.ConsumeProbability<float>() * size); + while (data_provider.remaining_bytes() > bytes_left_for_read) { + DataType curr_type = data_provider.ConsumeEnum<DataType>(); + bool variant = data_provider.ConsumeBool(); + switch (curr_type) { + case kUint8: + GenerateIntAndAppendValue<uint8_t>(&data_provider, &writer, variant); + break; + case kUint16: + GenerateIntAndAppendValue<uint16_t>(&data_provider, &writer, variant); + break; + case kUint32: + GenerateIntAndAppendValue<uint32_t>(&data_provider, &writer, variant); + break; + case kUint64: + GenerateIntAndAppendValue<uint64_t>(&data_provider, &writer, variant); + break; + case kInt16: + GenerateIntAndAppendValue<int16_t>(&data_provider, &writer, variant); + break; + case kInt32: + GenerateIntAndAppendValue<int32_t>(&data_provider, &writer, variant); + break; + case kInt64: + GenerateIntAndAppendValue<int64_t>(&data_provider, &writer, variant); + break; + case kBool: + AppendValue(&writer, variant, data_provider.ConsumeBool()); + break; + case kDouble: + AppendValue(&writer, variant, + data_provider.ConsumeProbability<double>()); + break; + case kString: + AppendValue(&writer, variant, GenerateValidUTF8(&data_provider)); + break; + case kObjectPath: { + std::string object_path = + data_provider.ConsumeRandomLengthString(kRandomMaxDataLength); + // If this isn't valid we'll hit a CHECK failure. + if (dbus::IsValidObjectPath(object_path)) + AppendValue(&writer, variant, dbus::ObjectPath(object_path)); + break; + } + case kVectorInt16: { + int vec_size = data_provider.ConsumeIntegralInRange<int>( + 0, kRandomMaxContainerSize); + std::vector<int16_t> vec(vec_size); + for (int i = 0; i < vec_size; i++) + vec[i] = data_provider.ConsumeIntegral<int16_t>(); + AppendValue(&writer, variant, vec); + break; + } + case kVectorString: { + int vec_size = data_provider.ConsumeIntegralInRange<int>( + 0, kRandomMaxContainerSize); + std::vector<std::string> vec(vec_size); + for (int i = 0; i < vec_size; i++) + vec[i] = GenerateValidUTF8(&data_provider); + AppendValue(&writer, variant, vec); + break; + } + case kPairBoolInt64: + AppendValue( + &writer, variant, + std::pair<bool, int64_t>{data_provider.ConsumeBool(), + data_provider.ConsumeIntegral<int64_t>()}); + break; + case kPairUint32String: + AppendValue(&writer, variant, + std::pair<uint32_t, std::string>{ + data_provider.ConsumeIntegral<uint32_t>(), + GenerateValidUTF8(&data_provider)}); + break; + case kTupleUint16StringBool: + AppendValue(&writer, variant, + std::tuple<uint32_t, std::string, bool>{ + data_provider.ConsumeIntegral<uint32_t>(), + GenerateValidUTF8(&data_provider), + data_provider.ConsumeBool()}); + break; + case kTupleDoubleInt32ObjectPath: { + std::string object_path = + data_provider.ConsumeRandomLengthString(kRandomMaxDataLength); + // If this isn't valid we'll hit a CHECK failure. + if (dbus::IsValidObjectPath(object_path)) { + AppendValue(&writer, variant, + std::tuple<double, int32_t, dbus::ObjectPath>{ + data_provider.ConsumeProbability<double>(), + data_provider.ConsumeIntegral<int32_t>(), + dbus::ObjectPath(object_path)}); + } + break; + } + case kMapInt32String: { + int map_size = data_provider.ConsumeIntegralInRange<int>( + 0, kRandomMaxContainerSize); + std::map<int32_t, std::string> map; + for (int i = 0; i < map_size; i++) + map[data_provider.ConsumeIntegral<int32_t>()] = + GenerateValidUTF8(&data_provider); + AppendValue(&writer, variant, map); + break; + } + case kMapDoubleBool: { + int map_size = data_provider.ConsumeIntegralInRange<int>( + 0, kRandomMaxContainerSize); + std::map<double, bool> map; + for (int i = 0; i < map_size; i++) + map[data_provider.ConsumeProbability<double>()] = + data_provider.ConsumeBool(); + AppendValue(&writer, variant, map); + break; + } + } + } + + dbus::MessageReader reader(message.get()); + while (data_provider.remaining_bytes()) { + DataType curr_type = data_provider.ConsumeEnum<DataType>(); + bool variant = data_provider.ConsumeBool(); + switch (curr_type) { + case kUint8: { + uint8_t value; + PopValue(&reader, variant, &value); + break; + } + case kUint16: { + uint16_t value; + PopValue(&reader, variant, &value); + break; + } + case kUint32: { + uint32_t value; + PopValue(&reader, variant, &value); + break; + } + case kUint64: { + uint64_t value; + PopValue(&reader, variant, &value); + break; + } + case kInt16: { + int16_t value; + PopValue(&reader, variant, &value); + break; + } + case kInt32: { + int32_t value; + PopValue(&reader, variant, &value); + break; + } + case kInt64: { + int64_t value; + PopValue(&reader, variant, &value); + break; + } + case kBool: { + bool value; + PopValue(&reader, variant, &value); + break; + } + case kDouble: { + double value; + PopValue(&reader, variant, &value); + break; + } + case kString: { + std::string value; + PopValue(&reader, variant, &value); + break; + } + case kObjectPath: { + dbus::ObjectPath value; + PopValue(&reader, variant, &value); + break; + } + case kVectorInt16: { + std::vector<int16_t> value; + PopValue(&reader, variant, &value); + break; + } + case kVectorString: { + std::vector<std::string> value; + PopValue(&reader, variant, &value); + break; + } + case kPairBoolInt64: { + std::pair<bool, int64_t> value; + PopValue(&reader, variant, &value); + break; + } + case kPairUint32String: { + std::pair<uint32_t, std::string> value; + PopValue(&reader, variant, &value); + break; + } + case kTupleUint16StringBool: { + std::tuple<uint16_t, std::string, bool> value; + break; + } + case kTupleDoubleInt32ObjectPath: { + std::tuple<double, int32_t, dbus::ObjectPath> value; + PopValue(&reader, variant, &value); + break; + } + case kMapInt32String: { + std::map<int32_t, std::string> value; + PopValue(&reader, variant, &value); + break; + } + case kMapDoubleBool: { + std::map<double, bool> value; + PopValue(&reader, variant, &value); + break; + } + } + } + + return 0; +} diff --git a/brillo/dbus/data_serialization_unittest.cc b/brillo/dbus/data_serialization_test.cc index c7d5e0f..7e68af5 100644 --- a/brillo/dbus/data_serialization_unittest.cc +++ b/brillo/dbus/data_serialization_test.cc @@ -5,6 +5,7 @@ #include <brillo/dbus/data_serialization.h> #include <limits> +#include <tuple> #include <base/files/scoped_file.h> #include <brillo/variant_dictionary.h> @@ -473,19 +474,28 @@ TEST(DBusUtils, ArraysAsVariant) { std::vector<double> dbl_array_empty{}; std::map<std::string, std::string> dict_ss{{"k1", "v1"}, {"k2", "v2"}}; VariantDictionary dict_sv{{"k1", 1}, {"k2", "v2"}}; + using ComplexStructArray = + std::vector<std::tuple<uint32_t, bool, std::vector<uint8_t>>>; + ComplexStructArray complex_struct_array{ + {123, true, {0xaa, 0xbb, 0xcc}}, + {456, false, {0xdd}}, + {789, false, {}}, + }; AppendValueToWriterAsVariant(&writer, int_array); AppendValueToWriterAsVariant(&writer, str_array); AppendValueToWriterAsVariant(&writer, dbl_array_empty); AppendValueToWriterAsVariant(&writer, dict_ss); AppendValueToWriterAsVariant(&writer, dict_sv); + AppendValueToWriterAsVariant(&writer, complex_struct_array); - EXPECT_EQ("vvvvv", message->GetSignature()); + EXPECT_EQ("vvvvvv", message->GetSignature()); Any int_array_out; Any str_array_out; Any dbl_array_out; Any dict_ss_out; Any dict_sv_out; + Any complex_struct_array_out; MessageReader reader(message.get()); EXPECT_TRUE(PopValueFromReader(&reader, &int_array_out)); @@ -493,6 +503,7 @@ TEST(DBusUtils, ArraysAsVariant) { EXPECT_TRUE(PopValueFromReader(&reader, &dbl_array_out)); EXPECT_TRUE(PopValueFromReader(&reader, &dict_ss_out)); EXPECT_TRUE(PopValueFromReader(&reader, &dict_sv_out)); + EXPECT_TRUE(PopValueFromReader(&reader, &complex_struct_array_out)); EXPECT_FALSE(reader.HasMoreData()); EXPECT_EQ(int_array, int_array_out.Get<std::vector<int>>()); @@ -503,6 +514,35 @@ TEST(DBusUtils, ArraysAsVariant) { dict_sv_out.Get<VariantDictionary>().at("k1").Get<int>()); EXPECT_EQ(dict_sv["k2"].Get<const char*>(), dict_sv_out.Get<VariantDictionary>().at("k2").Get<std::string>()); + EXPECT_EQ(complex_struct_array, + complex_struct_array_out.Get<ComplexStructArray>()); +} + +TEST(DBusUtils, StructsAsVariant) { + std::unique_ptr<Response> message = Response::CreateEmpty(); + MessageWriter writer(message.get()); + VariantDictionary dict_sv{{"k1", 1}, {"k2", "v2"}}; + std::tuple<uint32_t, VariantDictionary> u32_dict_sv_struct = + std::make_tuple(1, dict_sv); + AppendValueToWriterAsVariant(&writer, u32_dict_sv_struct); + + EXPECT_EQ("v", message->GetSignature()); + + Any u32_dict_sv_struct_out_any; + + MessageReader reader(message.get()); + EXPECT_TRUE(PopValueFromReader(&reader, &u32_dict_sv_struct_out_any)); + EXPECT_FALSE(reader.HasMoreData()); + + auto u32_dict_sv_struct_out = + u32_dict_sv_struct_out_any.Get<std::tuple<uint32_t, VariantDictionary>>(); + EXPECT_EQ(std::get<0>(u32_dict_sv_struct), + std::get<0>(u32_dict_sv_struct_out)); + VariantDictionary dict_sv_out = std::get<1>(u32_dict_sv_struct_out); + EXPECT_EQ(dict_sv.size(), dict_sv_out.size()); + EXPECT_EQ(dict_sv["k1"].Get<int>(), dict_sv_out["k1"].Get<int>()); + EXPECT_EQ(dict_sv["k2"].Get<const char*>(), + dict_sv_out["k2"].Get<std::string>()); } TEST(DBusUtils, VariantDictionary) { diff --git a/brillo/dbus/dbus_connection.cc b/brillo/dbus/dbus_connection.cc index b60cf44..2773316 100644 --- a/brillo/dbus/dbus_connection.cc +++ b/brillo/dbus/dbus_connection.cc @@ -4,15 +4,6 @@ #include <brillo/dbus/dbus_connection.h> -#include <sysexits.h> - -#include <base/bind.h> -#include <brillo/dbus/async_event_sequencer.h> -#include <brillo/dbus/exported_object_manager.h> - -using brillo::dbus_utils::AsyncEventSequencer; -using brillo::dbus_utils::ExportedObjectManager; - namespace brillo { DBusConnection::DBusConnection() { diff --git a/brillo/dbus/dbus_connection.h b/brillo/dbus/dbus_connection.h index aecf434..5f08ef7 100644 --- a/brillo/dbus/dbus_connection.h +++ b/brillo/dbus/dbus_connection.h @@ -21,15 +21,15 @@ class BRILLO_EXPORT DBusConnection final { // Instantiates dbus::Bus and establishes a D-Bus connection. Returns a // reference to the connected bus, or an empty pointer in case of error. - scoped_refptr<dbus::Bus> Connect(); + scoped_refptr<::dbus::Bus> Connect(); // Instantiates dbus::Bus and tries to establish a D-Bus connection for up to // |timeout|. If the connection can't be established after the timeout, fails // returning an empty pointer. - scoped_refptr<dbus::Bus> ConnectWithTimeout(base::TimeDelta timeout); + scoped_refptr<::dbus::Bus> ConnectWithTimeout(base::TimeDelta timeout); private: - scoped_refptr<dbus::Bus> bus_; + scoped_refptr<::dbus::Bus> bus_; private: DISALLOW_COPY_AND_ASSIGN(DBusConnection); @@ -37,4 +37,4 @@ class BRILLO_EXPORT DBusConnection final { } // namespace brillo -#endif // LIBBRILLO_BRILLO_DAEMONS_DBUS_DAEMON_H_ +#endif // LIBBRILLO_BRILLO_DBUS_DBUS_CONNECTION_H_ diff --git a/brillo/dbus/dbus_method_invoker.h b/brillo/dbus/dbus_method_invoker.h index f8b6990..08f5781 100644 --- a/brillo/dbus/dbus_method_invoker.h +++ b/brillo/dbus/dbus_method_invoker.h @@ -65,6 +65,7 @@ #include <memory> #include <string> #include <tuple> +#include <utility> #include <base/bind.h> #include <base/files/scoped_file.h> @@ -91,19 +92,19 @@ namespace dbus_utils { // [dbus/dbus.h]). // Returns a dbus::Response object on success. On failure, returns nullptr and // fills in additional error details into the |error| object. -template<typename... Args> -inline std::unique_ptr<dbus::Response> CallMethodAndBlockWithTimeout( +template <typename... Args> +inline std::unique_ptr<::dbus::Response> CallMethodAndBlockWithTimeout( int timeout_ms, - dbus::ObjectProxy* object, + ::dbus::ObjectProxy* object, const std::string& interface_name, const std::string& method_name, ErrorPtr* error, const Args&... args) { - dbus::MethodCall method_call(interface_name, method_name); + ::dbus::MethodCall method_call(interface_name, method_name); // Add method arguments to the message buffer. - dbus::MessageWriter writer(&method_call); + ::dbus::MessageWriter writer(&method_call); DBusParamWriter::Append(&writer, args...); - dbus::ScopedDBusError dbus_error; + ::dbus::ScopedDBusError dbus_error; auto response = object->CallMethodAndBlockWithErrorDetails( &method_call, timeout_ms, &dbus_error); if (!response) { @@ -127,19 +128,16 @@ inline std::unique_ptr<dbus::Response> CallMethodAndBlockWithTimeout( } // Same as CallMethodAndBlockWithTimeout() but uses a default timeout value. -template<typename... Args> -inline std::unique_ptr<dbus::Response> CallMethodAndBlock( - dbus::ObjectProxy* object, +template <typename... Args> +inline std::unique_ptr<::dbus::Response> CallMethodAndBlock( + ::dbus::ObjectProxy* object, const std::string& interface_name, const std::string& method_name, ErrorPtr* error, const Args&... args) { - return CallMethodAndBlockWithTimeout(dbus::ObjectProxy::TIMEOUT_USE_DEFAULT, - object, - interface_name, - method_name, - error, - args...); + return CallMethodAndBlockWithTimeout(::dbus::ObjectProxy::TIMEOUT_USE_DEFAULT, + object, interface_name, method_name, + error, args...); } namespace internal { @@ -169,9 +167,9 @@ inline FileDescriptor HackMove(const FileDescriptor& val) { // Extracts the parameters of |ResultTypes...| types from the message reader // and puts the values in the resulting |tuple|. Returns false on error and // provides additional error details in |error| object. -template<typename... ResultTypes> +template <typename... ResultTypes> inline bool ExtractMessageParametersAsTuple( - dbus::MessageReader* reader, + ::dbus::MessageReader* reader, ErrorPtr* error, std::tuple<ResultTypes...>* val_tuple) { auto callback = [val_tuple](const ResultTypes&... params) { @@ -182,9 +180,9 @@ inline bool ExtractMessageParametersAsTuple( } // Overload of ExtractMessageParametersAsTuple to handle reference types in // tuples created with std::tie(). -template<typename... ResultTypes> +template <typename... ResultTypes> inline bool ExtractMessageParametersAsTuple( - dbus::MessageReader* reader, + ::dbus::MessageReader* reader, ErrorPtr* error, std::tuple<ResultTypes&...>* ref_tuple) { auto callback = [ref_tuple](const ResultTypes&... params) { @@ -207,8 +205,8 @@ inline bool ExtractMessageParametersAsTuple( // if (ExtractMessageParameters(reader, &error, &data1, &data2)) { ... } // // The above example extracts an Int32 and a String from D-Bus message buffer. -template<typename... ResultTypes> -inline bool ExtractMessageParameters(dbus::MessageReader* reader, +template <typename... ResultTypes> +inline bool ExtractMessageParameters(::dbus::MessageReader* reader, ErrorPtr* error, ResultTypes*... results) { auto ref_tuple = std::tie(*results...); @@ -225,14 +223,14 @@ inline bool ExtractMessageParameters(dbus::MessageReader* reader, // any return values. Just do not specify any output |results|. In this case, // ExtractMethodCallResults() will verify that the method didn't return any // data in the |message|. -template<typename... ResultTypes> -inline bool ExtractMethodCallResults(dbus::Message* message, +template <typename... ResultTypes> +inline bool ExtractMethodCallResults(::dbus::Message* message, ErrorPtr* error, ResultTypes*... results) { CHECK(message) << "Unable to extract parameters from a NULL message."; - dbus::MessageReader reader(message); - if (message->GetMessageType() == dbus::Message::MESSAGE_ERROR) { + ::dbus::MessageReader reader(message); + if (message->GetMessageType() == ::dbus::Message::MESSAGE_ERROR) { std::string error_message; if (ExtractMessageParameters(&reader, error, &error_message)) AddDBusError(error, message->GetErrorName(), error_message); @@ -249,24 +247,24 @@ using AsyncErrorCallback = base::Callback<void(Error* error)>; // A helper function that translates dbus::ErrorResponse response // from D-Bus into brillo::Error* and invokes the |callback|. void BRILLO_EXPORT TranslateErrorResponse(const AsyncErrorCallback& callback, - dbus::ErrorResponse* resp); + ::dbus::ErrorResponse* resp); // A helper function that translates dbus::Response from D-Bus into // a list of C++ values passed as parameters to |success_callback|. If the // response message doesn't have the correct number of parameters, or they // are of wrong types, an error is sent to |error_callback|. -template<typename... OutArgs> +template <typename... OutArgs> void TranslateSuccessResponse( const base::Callback<void(OutArgs...)>& success_callback, const AsyncErrorCallback& error_callback, - dbus::Response* resp) { + ::dbus::Response* resp) { auto callback = [&success_callback](const OutArgs&... params) { if (!success_callback.is_null()) { success_callback.Run(params...); } }; ErrorPtr error; - dbus::MessageReader reader(resp); + ::dbus::MessageReader reader(resp); if (!DBusParamReader<false, OutArgs...>::Invoke(callback, &reader, &error) && !error_callback.is_null()) { error_callback.Run(error.get()); @@ -283,43 +281,40 @@ void TranslateSuccessResponse( // a problem invoking a method (e.g. object or method doesn't exist). // If the response is not received within |timeout_ms|, an error callback is // called with DBUS_ERROR_NO_REPLY error code. -template<typename... InArgs, typename... OutArgs> +template <typename... InArgs, typename... OutArgs> inline void CallMethodWithTimeout( int timeout_ms, - dbus::ObjectProxy* object, + ::dbus::ObjectProxy* object, const std::string& interface_name, const std::string& method_name, const base::Callback<void(OutArgs...)>& success_callback, const AsyncErrorCallback& error_callback, const InArgs&... params) { - dbus::MethodCall method_call(interface_name, method_name); - dbus::MessageWriter writer(&method_call); + ::dbus::MethodCall method_call(interface_name, method_name); + ::dbus::MessageWriter writer(&method_call); DBusParamWriter::Append(&writer, params...); - dbus::ObjectProxy::ErrorCallback dbus_error_callback = + ::dbus::ObjectProxy::ErrorCallback dbus_error_callback = base::Bind(&TranslateErrorResponse, error_callback); - dbus::ObjectProxy::ResponseCallback dbus_success_callback = base::Bind( + ::dbus::ObjectProxy::ResponseCallback dbus_success_callback = base::Bind( &TranslateSuccessResponse<OutArgs...>, success_callback, error_callback); - object->CallMethodWithErrorCallback( - &method_call, timeout_ms, dbus_success_callback, dbus_error_callback); + object->CallMethodWithErrorCallback(&method_call, timeout_ms, + std::move(dbus_success_callback), + std::move(dbus_error_callback)); } // Same as CallMethodWithTimeout() but uses a default timeout value. -template<typename... InArgs, typename... OutArgs> -inline void CallMethod(dbus::ObjectProxy* object, +template <typename... InArgs, typename... OutArgs> +inline void CallMethod(::dbus::ObjectProxy* object, const std::string& interface_name, const std::string& method_name, const base::Callback<void(OutArgs...)>& success_callback, const AsyncErrorCallback& error_callback, const InArgs&... params) { - return CallMethodWithTimeout(dbus::ObjectProxy::TIMEOUT_USE_DEFAULT, - object, - interface_name, - method_name, - success_callback, - error_callback, - params...); + return CallMethodWithTimeout(::dbus::ObjectProxy::TIMEOUT_USE_DEFAULT, object, + interface_name, method_name, success_callback, + error_callback, params...); } } // namespace dbus_utils diff --git a/brillo/dbus/dbus_method_invoker_unittest.cc b/brillo/dbus/dbus_method_invoker_test.cc index 34f4c6f..9e6600a 100644 --- a/brillo/dbus/dbus_method_invoker_unittest.cc +++ b/brillo/dbus/dbus_method_invoker_test.cc @@ -6,8 +6,7 @@ #include <string> -#include <base/files/scoped_file.h> -#include <brillo/bind_lambda.h> +#include <base/bind.h> #include <dbus/mock_bus.h> #include <dbus/mock_object_proxy.h> #include <dbus/scoped_dbus_error.h> @@ -84,16 +83,18 @@ class DBusMethodInvokerTest : public testing::Test { GetObjectProxy(kTestServiceName, dbus::ObjectPath(kTestPath))) .WillRepeatedly(Return(mock_object_proxy_.get())); int def_timeout_ms = dbus::ObjectProxy::TIMEOUT_USE_DEFAULT; - EXPECT_CALL(*mock_object_proxy_, - MockCallMethodAndBlockWithErrorDetails(_, def_timeout_ms, _)) + EXPECT_CALL( + *mock_object_proxy_, + MIGRATE_MockCallMethodAndBlockWithErrorDetails(_, def_timeout_ms, _)) .WillRepeatedly(Invoke(this, &DBusMethodInvokerTest::CreateResponse)); } void TearDown() override { bus_ = nullptr; } - Response* CreateResponse(dbus::MethodCall* method_call, - int /* timeout_ms */, - dbus::ScopedDBusError* dbus_error) { + MIGRATE_WrapObjectProxyResponseType(Response) + CreateResponse(dbus::MethodCall* method_call, + int /* timeout_ms */, + dbus::ScopedDBusError* dbus_error) { if (method_call->GetInterface() == kTestInterface) { if (method_call->GetMember() == kTestMethod1) { MessageReader reader(method_call); @@ -104,12 +105,12 @@ class DBusMethodInvokerTest : public testing::Test { auto response = Response::CreateEmpty(); MessageWriter writer(response.get()); writer.AppendString(std::to_string(v1 + v2)); - return response.release(); + return MIGRATE_WrapObjectProxyResponseConversion(response); } } else if (method_call->GetMember() == kTestMethod2) { method_call->SetSerial(123); dbus_set_error(dbus_error->get(), "org.MyError", "My error message"); - return nullptr; + return MIGRATE_WrapObjectProxyResponseEmpty; } else if (method_call->GetMember() == kTestMethod3) { MessageReader reader(method_call); dbus_utils_test::TestMessage msg; @@ -117,7 +118,7 @@ class DBusMethodInvokerTest : public testing::Test { auto response = Response::CreateEmpty(); MessageWriter writer(response.get()); AppendValueToWriter(&writer, msg); - return response.release(); + return MIGRATE_WrapObjectProxyResponseConversion(response); } } else if (method_call->GetMember() == kTestMethod4) { method_call->SetSerial(123); @@ -127,13 +128,13 @@ class DBusMethodInvokerTest : public testing::Test { auto response = Response::CreateEmpty(); MessageWriter writer(response.get()); writer.AppendFileDescriptor(fd.get()); - return response.release(); + return MIGRATE_WrapObjectProxyResponseConversion(response); } } } LOG(ERROR) << "Unexpected method call: " << method_call->ToString(); - return nullptr; + return MIGRATE_WrapObjectProxyResponseEmpty; } std::string CallTestMethod(int v1, int v2) { @@ -244,7 +245,7 @@ class AsyncDBusMethodInvokerTest : public testing::Test { .WillRepeatedly(Return(mock_object_proxy_.get())); int def_timeout_ms = dbus::ObjectProxy::TIMEOUT_USE_DEFAULT; EXPECT_CALL(*mock_object_proxy_, - CallMethodWithErrorCallback(_, def_timeout_ms, _, _)) + MIGRATE_CallMethodWithErrorCallback(_, def_timeout_ms, _, _)) .WillRepeatedly(Invoke(this, &AsyncDBusMethodInvokerTest::HandleCall)); } @@ -252,8 +253,10 @@ class AsyncDBusMethodInvokerTest : public testing::Test { void HandleCall(dbus::MethodCall* method_call, int /* timeout_ms */, - dbus::ObjectProxy::ResponseCallback success_callback, - dbus::ObjectProxy::ErrorCallback error_callback) { + dbus::ObjectProxy::ResponseCallback + MIGRATE_WrapObjectProxyCallback(success_callback), + dbus::ObjectProxy::ErrorCallback + MIGRATE_WrapObjectProxyCallback(error_callback)) { if (method_call->GetInterface() == kTestInterface) { if (method_call->GetMember() == kTestMethod1) { MessageReader reader(method_call); @@ -264,14 +267,16 @@ class AsyncDBusMethodInvokerTest : public testing::Test { auto response = Response::CreateEmpty(); MessageWriter writer(response.get()); writer.AppendString(std::to_string(v1 + v2)); - success_callback.Run(response.get()); + std::move(MIGRATE_WrapObjectProxyCallback(success_callback)) + .Run(response.get()); } return; } else if (method_call->GetMember() == kTestMethod2) { method_call->SetSerial(123); auto error_response = dbus::ErrorResponse::FromMethodCall( method_call, "org.MyError", "My error message"); - error_callback.Run(error_response.get()); + std::move(MIGRATE_WrapObjectProxyCallback(error_callback)) + .Run(error_response.get()); return; } } diff --git a/brillo/dbus/dbus_method_response.h b/brillo/dbus/dbus_method_response.h index 289f11e..15df602 100644 --- a/brillo/dbus/dbus_method_response.h +++ b/brillo/dbus/dbus_method_response.h @@ -5,8 +5,12 @@ #ifndef LIBBRILLO_BRILLO_DBUS_DBUS_METHOD_RESPONSE_H_ #define LIBBRILLO_BRILLO_DBUS_DBUS_METHOD_RESPONSE_H_ +#include <memory> #include <string> +#include <utility> +#include <base/bind.h> +#include <base/location.h> #include <base/macros.h> #include <brillo/brillo_export.h> #include <brillo/dbus/dbus_param_writer.h> @@ -20,14 +24,25 @@ class Error; namespace dbus_utils { -using ResponseSender = dbus::ExportedObject::ResponseSender; +using ResponseSender = ::dbus::ExportedObject::ResponseSender; // DBusMethodResponseBase is a helper class used with asynchronous D-Bus method // handlers to encapsulate the information needed to send the method call // response when it is available. class BRILLO_EXPORT DBusMethodResponseBase { public: - DBusMethodResponseBase(dbus::MethodCall* method_call, ResponseSender sender); + DBusMethodResponseBase(::dbus::MethodCall* method_call, + ResponseSender sender); + DBusMethodResponseBase(DBusMethodResponseBase&& other) + : sender_(std::exchange( + other.sender_, + base::Bind([](std::unique_ptr<dbus::Response> response) { + LOG(DFATAL) + << "Empty DBusMethodResponseBase attempts to send a response"; + }))), + method_call_(std::exchange(other.method_call_, nullptr)) {} + DBusMethodResponseBase& operator=(DBusMethodResponseBase&& other) = delete; + virtual ~DBusMethodResponseBase(); // Sends an error response. Marshals the |error| object over D-Bus. @@ -36,20 +51,20 @@ class BRILLO_EXPORT DBusMethodResponseBase { // For error is from other domains, the full error information (domain, error // code, error message) is encoded into the D-Bus error message and returned // to the caller as "org.freedesktop.DBus.Failed". - void ReplyWithError(const brillo::Error* error); + virtual void ReplyWithError(const brillo::Error* error); // Constructs brillo::Error object from the parameters specified and send // the error information over D-Bus using the method above. - void ReplyWithError(const base::Location& location, - const std::string& error_domain, - const std::string& error_code, - const std::string& error_message); + virtual void ReplyWithError(const base::Location& location, + const std::string& error_domain, + const std::string& error_code, + const std::string& error_message); // Sends a raw D-Bus response message. - void SendRawResponse(std::unique_ptr<dbus::Response> response); + void SendRawResponse(std::unique_ptr<::dbus::Response> response); // Creates a custom response object for the current method call. - std::unique_ptr<dbus::Response> CreateCustomResponse() const; + std::unique_ptr<::dbus::Response> CreateCustomResponse() const; // Checks if the response has been sent already. bool IsResponseSent() const; @@ -67,9 +82,7 @@ class BRILLO_EXPORT DBusMethodResponseBase { // in the bound parameter list in the Callback). We set it to nullptr after // the method call response has been sent to ensure we can't possibly try // to send a response again somehow. - dbus::MethodCall* method_call_; - - DISALLOW_COPY_AND_ASSIGN(DBusMethodResponseBase); + ::dbus::MethodCall* method_call_; }; // DBusMethodResponse is an explicitly-typed version of DBusMethodResponse. @@ -83,10 +96,10 @@ class DBusMethodResponse : public DBusMethodResponseBase { // Sends the a successful response. |return_values| can contain a list // of return values to be sent to the caller. - inline void Return(const Types&... return_values) { + virtual void Return(const Types&... return_values) { CheckCanSendResponse(); auto response = CreateCustomResponse(); - dbus::MessageWriter writer(response.get()); + ::dbus::MessageWriter writer(response.get()); DBusParamWriter::Append(&writer, return_values...); SendRawResponse(std::move(response)); } diff --git a/brillo/dbus/dbus_object.cc b/brillo/dbus/dbus_object.cc index 512cd6f..12eb353 100644 --- a/brillo/dbus/dbus_object.cc +++ b/brillo/dbus/dbus_object.cc @@ -4,6 +4,8 @@ #include <brillo/dbus/dbus_object.h> +#include <memory> +#include <utility> #include <vector> #include <base/bind.h> @@ -37,8 +39,10 @@ void SetupDefaultPropertyHandlers(DBusInterface* prop_interface, DBusInterface::DBusInterface(DBusObject* dbus_object, const std::string& interface_name) - : dbus_object_(dbus_object), interface_name_(interface_name) { -} + : dbus_object_(dbus_object), + interface_name_(interface_name), + // TODO(crbug.com/909719): Use base::DoNothing() + release_interface_cb_(base::Bind([]() {})) {} void DBusInterface::AddProperty(const std::string& property_name, ExportedPropertyBase* prop_base) { @@ -115,6 +119,50 @@ void DBusInterface::ExportAndBlock( } } +void DBusInterface::UnexportAsync( + ExportedObjectManager* object_manager, + dbus::ExportedObject* exported_object, + const dbus::ObjectPath& object_path, + const AsyncEventSequencer::CompletionAction& completion_callback) { + VLOG(1) << "Unexporting D-Bus interface " << interface_name_ << " for " + << object_path.value(); + + // Release the interface. + release_interface_cb_.RunAndReset(); + + // Unexport all method handlers. + scoped_refptr<AsyncEventSequencer> sequencer(new AsyncEventSequencer()); + for (const auto& pair : handlers_) { + std::string method_name = pair.first; + VLOG(1) << "Unexporting method: " << interface_name_ << "." << method_name; + std::string export_error = "Failed unexporting " + method_name + " method"; + auto export_handler = sequencer->GetExportHandler( + interface_name_, method_name, export_error, true); + exported_object->UnexportMethod(interface_name_, method_name, + export_handler); + } + + sequencer->OnAllTasksCompletedCall({completion_callback}); +} + +void DBusInterface::UnexportAndBlock(ExportedObjectManager* object_manager, + dbus::ExportedObject* exported_object, + const dbus::ObjectPath& object_path) { + VLOG(1) << "Unexporting D-Bus interface " << interface_name_ << " for " + << object_path.value(); + + // Release the interface. + release_interface_cb_.RunAndReset(); + + // Unexport all method handlers. + for (const auto& pair : handlers_) { + std::string method_name = pair.first; + VLOG(1) << "Unexporting method: " << interface_name_ << "." << method_name; + if (!exported_object->UnexportMethodAndBlock(interface_name_, method_name)) + LOG(FATAL) << "Failed unexporting " << method_name << " method"; + } +} + void DBusInterface::ClaimInterface( base::WeakPtr<ExportedObjectManager> object_manager, const dbus::ObjectPath& object_path, @@ -125,6 +173,7 @@ void DBusInterface::ClaimInterface( return; } object_manager->ClaimInterface(object_path, interface_name_, writer); + release_interface_cb_.RunAndReset(); release_interface_cb_.ReplaceClosure( base::Bind(&ExportedObjectManager::ReleaseInterface, object_manager, object_path, interface_name_)); @@ -234,6 +283,25 @@ void DBusObject::ExportInterfaceAsync( object_path_, completion_callback); } +void DBusObject::ExportInterfaceAndBlock(const std::string& interface_name) { + AddOrGetInterface(interface_name) + ->ExportAndBlock(object_manager_.get(), bus_.get(), exported_object_, + object_path_); +} + +void DBusObject::UnexportInterfaceAsync( + const std::string& interface_name, + const AsyncEventSequencer::CompletionAction& completion_callback) { + AddOrGetInterface(interface_name) + ->UnexportAsync(object_manager_.get(), exported_object_, object_path_, + completion_callback); +} + +void DBusObject::UnexportInterfaceAndBlock(const std::string& interface_name) { + AddOrGetInterface(interface_name) + ->UnexportAndBlock(object_manager_.get(), exported_object_, object_path_); +} + void DBusObject::RegisterAsync( const AsyncEventSequencer::CompletionAction& completion_callback) { VLOG(1) << "Registering D-Bus object '" << object_path_.value() << "'."; diff --git a/brillo/dbus/dbus_object.h b/brillo/dbus/dbus_object.h index 61c954f..6ab0b23 100644 --- a/brillo/dbus/dbus_object.h +++ b/brillo/dbus/dbus_object.h @@ -45,7 +45,8 @@ class MyDbusObject { void Method3(std::unique_ptr<DBusMethodResponse<int_32>> response, const std::string& message) { if (message.empty()) { - response->ReplyWithError(brillo::errors::dbus::kDomain, + response->ReplyWithError(FROM_HERE, + brillo::errors::dbus::kDomain, DBUS_ERROR_INVALID_ARGS, "Message string cannot be empty"); return; @@ -62,7 +63,9 @@ class MyDbusObject { #define LIBBRILLO_BRILLO_DBUS_DBUS_OBJECT_H_ #include <map> +#include <memory> #include <string> +#include <utility> #include <base/bind.h> #include <base/callback_helpers.h> @@ -197,10 +200,10 @@ class BRILLO_EXPORT DBusInterface final { // Register sync DBus method handler for |method_name| as base::Callback. // Passing the method sender as a first parameter to the callback. - template<typename... Args> + template <typename... Args> inline void AddSimpleMethodHandlerWithErrorAndMessage( const std::string& method_name, - const base::Callback<bool(ErrorPtr*, dbus::Message*, Args...)>& + const base::Callback<bool(ErrorPtr*, ::dbus::Message*, Args...)>& handler) { Handler<SimpleDBusInterfaceMethodHandlerWithErrorAndMessage<Args...>>::Add( this, method_name, handler); @@ -209,10 +212,10 @@ class BRILLO_EXPORT DBusInterface final { // Register sync D-Bus method handler for |method_name| as a static // function. Passing the method D-Bus message as the second parameter to the // callback. - template<typename... Args> + template <typename... Args> inline void AddSimpleMethodHandlerWithErrorAndMessage( const std::string& method_name, - bool(*handler)(ErrorPtr*, dbus::Message*, Args...)) { + bool (*handler)(ErrorPtr*, ::dbus::Message*, Args...)) { Handler<SimpleDBusInterfaceMethodHandlerWithErrorAndMessage<Args...>>::Add( this, method_name, base::Bind(handler)); } @@ -220,21 +223,21 @@ class BRILLO_EXPORT DBusInterface final { // Register sync D-Bus method handler for |method_name| as a class member // function. Passing the method D-Bus message as the second parameter to the // callback. - template<typename Instance, typename Class, typename... Args> + template <typename Instance, typename Class, typename... Args> inline void AddSimpleMethodHandlerWithErrorAndMessage( const std::string& method_name, Instance instance, - bool(Class::*handler)(ErrorPtr*, dbus::Message*, Args...)) { + bool (Class::*handler)(ErrorPtr*, ::dbus::Message*, Args...)) { Handler<SimpleDBusInterfaceMethodHandlerWithErrorAndMessage<Args...>>::Add( this, method_name, base::Bind(handler, instance)); } // Same as above but for const-method of a class. - template<typename Instance, typename Class, typename... Args> + template <typename Instance, typename Class, typename... Args> inline void AddSimpleMethodHandlerWithErrorAndMessage( const std::string& method_name, Instance instance, - bool(Class::*handler)(ErrorPtr*, dbus::Message*, Args...) const) { + bool (Class::*handler)(ErrorPtr*, ::dbus::Message*, Args...) const) { Handler<SimpleDBusInterfaceMethodHandlerWithErrorAndMessage<Args...>>::Add( this, method_name, base::Bind(handler, instance)); } @@ -294,11 +297,11 @@ class BRILLO_EXPORT DBusInterface final { } // Register an async DBus method handler for |method_name| as base::Callback. - template<typename Response, typename... Args> + template <typename Response, typename... Args> inline void AddMethodHandlerWithMessage( const std::string& method_name, - const base::Callback<void(std::unique_ptr<Response>, dbus::Message*, - Args...)>& handler) { + const base::Callback<void( + std::unique_ptr<Response>, ::dbus::Message*, Args...)>& handler) { static_assert(std::is_base_of<DBusMethodResponseBase, Response>::value, "Response must be DBusMethodResponse<T...>"); Handler<DBusInterfaceMethodHandlerWithMessage<Response, Args...>>::Add( @@ -307,10 +310,10 @@ class BRILLO_EXPORT DBusInterface final { // Register an async D-Bus method handler for |method_name| as a static // function. - template<typename Response, typename... Args> + template <typename Response, typename... Args> inline void AddMethodHandlerWithMessage( const std::string& method_name, - void (*handler)(std::unique_ptr<Response>, dbus::Message*, Args...)) { + void (*handler)(std::unique_ptr<Response>, ::dbus::Message*, Args...)) { static_assert(std::is_base_of<DBusMethodResponseBase, Response>::value, "Response must be DBusMethodResponse<T...>"); Handler<DBusInterfaceMethodHandlerWithMessage<Response, Args...>>::Add( @@ -319,15 +322,16 @@ class BRILLO_EXPORT DBusInterface final { // Register an async D-Bus method handler for |method_name| as a class member // function. - template<typename Response, - typename Instance, - typename Class, - typename... Args> + template <typename Response, + typename Instance, + typename Class, + typename... Args> inline void AddMethodHandlerWithMessage( const std::string& method_name, Instance instance, - void(Class::*handler)(std::unique_ptr<Response>, - dbus::Message*, Args...)) { + void (Class::*handler)(std::unique_ptr<Response>, + ::dbus::Message*, + Args...)) { static_assert(std::is_base_of<DBusMethodResponseBase, Response>::value, "Response must be DBusMethodResponse<T...>"); Handler<DBusInterfaceMethodHandlerWithMessage<Response, Args...>>::Add( @@ -335,15 +339,16 @@ class BRILLO_EXPORT DBusInterface final { } // Same as above but for const-method of a class. - template<typename Response, - typename Instance, - typename Class, - typename... Args> + template <typename Response, + typename Instance, + typename Class, + typename... Args> inline void AddMethodHandlerWithMessage( const std::string& method_name, Instance instance, - void(Class::*handler)(std::unique_ptr<Response>, dbus::Message*, - Args...) const) { + void (Class::*handler)(std::unique_ptr<Response>, + ::dbus::Message*, + Args...) const) { static_assert(std::is_base_of<DBusMethodResponseBase, Response>::value, "Response must be DBusMethodResponse<T...>"); Handler<DBusInterfaceMethodHandlerWithMessage<Response, Args...>>::Add( @@ -353,17 +358,18 @@ class BRILLO_EXPORT DBusInterface final { // Register a raw D-Bus method handler for |method_name| as base::Callback. inline void AddRawMethodHandler( const std::string& method_name, - const base::Callback<void(dbus::MethodCall*, ResponseSender)>& handler) { + const base::Callback<void(::dbus::MethodCall*, ResponseSender)>& + handler) { Handler<RawDBusInterfaceMethodHandler>::Add(this, method_name, handler); } // Register a raw D-Bus method handler for |method_name| as a class member // function. - template<typename Instance, typename Class> - inline void AddRawMethodHandler( - const std::string& method_name, - Instance instance, - void(Class::*handler)(dbus::MethodCall*, ResponseSender)) { + template <typename Instance, typename Class> + inline void AddRawMethodHandler(const std::string& method_name, + Instance instance, + void (Class::*handler)(::dbus::MethodCall*, + ResponseSender)) { Handler<RawDBusInterfaceMethodHandler>::Add( this, method_name, base::Bind(handler, instance)); } @@ -444,7 +450,7 @@ class BRILLO_EXPORT DBusInterface final { // A generic D-Bus method handler for the interface. It extracts the method // name from |method_call|, looks up a registered handler from |handlers_| // map and dispatched the call to that handler. - void HandleMethodCall(dbus::MethodCall* method_call, ResponseSender sender); + void HandleMethodCall(::dbus::MethodCall* method_call, ResponseSender sender); // Helper to add a handler for method |method_name| to the |handlers_| map. // Not marked BRILLO_PRIVATE because it needs to be called by the inline // template functions AddMethodHandler(...) @@ -467,9 +473,9 @@ class BRILLO_EXPORT DBusInterface final { // registration operation is completed. BRILLO_PRIVATE void ExportAsync( ExportedObjectManager* object_manager, - dbus::Bus* bus, - dbus::ExportedObject* exported_object, - const dbus::ObjectPath& object_path, + ::dbus::Bus* bus, + ::dbus::ExportedObject* exported_object, + const ::dbus::ObjectPath& object_path, const AsyncEventSequencer::CompletionAction& completion_callback); // Exports all the methods and properties of this interface and claims the // D-Bus interface synchronously. @@ -478,15 +484,24 @@ class BRILLO_EXPORT DBusInterface final { // exported_object - instance of D-Bus object the interface is being added to. // object_path - D-Bus object path for the object instance. // interface_name - name of interface being registered. - BRILLO_PRIVATE void ExportAndBlock( + BRILLO_PRIVATE void ExportAndBlock(ExportedObjectManager* object_manager, + ::dbus::Bus* bus, + ::dbus::ExportedObject* exported_object, + const ::dbus::ObjectPath& object_path); + // Releases the D-Bus interface and unexports all the methods asynchronously. + BRILLO_PRIVATE void UnexportAsync( ExportedObjectManager* object_manager, - dbus::Bus* bus, - dbus::ExportedObject* exported_object, - const dbus::ObjectPath& object_path); + ::dbus::ExportedObject* exported_object, + const ::dbus::ObjectPath& object_path, + const AsyncEventSequencer::CompletionAction& completion_callback); + // Releases the D-Bus interface and unexports all the methods synchronously. + BRILLO_PRIVATE void UnexportAndBlock(ExportedObjectManager* object_manager, + ::dbus::ExportedObject* exported_object, + const ::dbus::ObjectPath& object_path); BRILLO_PRIVATE void ClaimInterface( base::WeakPtr<ExportedObjectManager> object_manager, - const dbus::ObjectPath& object_path, + const ::dbus::ObjectPath& object_path, const ExportedPropertySet::PropertyWriter& writer, bool all_succeeded); @@ -518,8 +533,8 @@ class BRILLO_EXPORT DBusObject { // changes on those interfaces. // object_path - D-Bus object path for the object instance. DBusObject(ExportedObjectManager* object_manager, - const scoped_refptr<dbus::Bus>& bus, - const dbus::ObjectPath& object_path); + const scoped_refptr<::dbus::Bus>& bus, + const ::dbus::ObjectPath& object_path); // property_handler_setup_callback - To be called when setting up property // method handlers. Clients can register @@ -527,8 +542,8 @@ class BRILLO_EXPORT DBusObject { // (GetAll/Get/Set) by passing in this // callback. DBusObject(ExportedObjectManager* object_manager, - const scoped_refptr<dbus::Bus>& bus, - const dbus::ObjectPath& object_path, + const scoped_refptr<::dbus::Bus>& bus, + const ::dbus::ObjectPath& object_path, PropertyHandlerSetupCallback property_handler_setup_callback); virtual ~DBusObject(); @@ -551,6 +566,28 @@ class BRILLO_EXPORT DBusObject { const std::string& interface_name, const AsyncEventSequencer::CompletionAction& completion_callback); + // Exports a proxy handler for the interface |interface_name|. If the + // interface proxy does not exist yet, it will be automatically created. This + // call is synchronous and will block until all methods of the interface are + // registered and the interface is claimed. + void ExportInterfaceAndBlock(const std::string& interface_name); + + // Unexports the interface |interface_name| and unexports all method handlers. + // In some cases, one may want to export an interface even after it's removed. + // In that case, they should call this method before removing the interface + // to make sure it will start with a clean state of method handlers. + void UnexportInterfaceAsync( + const std::string& interface_name, + const AsyncEventSequencer::CompletionAction& completion_callback); + + // Unexports the interface |interface_name| and unexports all method handlers. + // In some cases, one may want to export an interface even after it's removed. + // In that case, they should call this method before removing the interface + // to make sure it will start with a clean state of method handlers. + // This call is synchronous and will block until the interface is released and + // all of its methods of are unregistered. + void UnexportInterfaceAndBlock(const std::string& interface_name); + // Registers the object instance with D-Bus. This is an asynchronous call // that will call |completion_callback| when the object and all of its // interfaces are registered. @@ -576,10 +613,10 @@ class BRILLO_EXPORT DBusObject { } // Sends a signal from the exported D-Bus object. - bool SendSignal(dbus::Signal* signal); + bool SendSignal(::dbus::Signal* signal); // Returns the reference to dbus::Bus this object is associated with. - scoped_refptr<dbus::Bus> GetBus() { return bus_; } + scoped_refptr<::dbus::Bus> GetBus() { return bus_; } private: // Add the org.freedesktop.DBus.Properties interface to the object. @@ -593,11 +630,11 @@ class BRILLO_EXPORT DBusObject { // Delegate object implementing org.freedesktop.DBus.ObjectManager interface. base::WeakPtr<ExportedObjectManager> object_manager_; // D-Bus bus object. - scoped_refptr<dbus::Bus> bus_; + scoped_refptr<::dbus::Bus> bus_; // D-Bus object path for this object. - dbus::ObjectPath object_path_; + ::dbus::ObjectPath object_path_; // D-Bus object instance once this object is successfully exported. - dbus::ExportedObject* exported_object_ = nullptr; // weak; owned by |bus_|. + ::dbus::ExportedObject* exported_object_ = nullptr; // weak; owned by |bus_|. // Sets up property method handlers. PropertyHandlerSetupCallback property_handler_setup_callback_; diff --git a/brillo/dbus/dbus_object_internal_impl.h b/brillo/dbus/dbus_object_internal_impl.h index 3c5e8d7..a521776 100644 --- a/brillo/dbus/dbus_object_internal_impl.h +++ b/brillo/dbus/dbus_object_internal_impl.h @@ -32,6 +32,7 @@ #include <memory> #include <string> #include <type_traits> +#include <utility> #include <brillo/dbus/data_serialization.h> #include <brillo/dbus/dbus_method_response.h> @@ -52,7 +53,7 @@ class DBusInterfaceMethodHandlerInterface { // Returns true if the method has been handled synchronously (whether or not // a success or error response message had been sent). - virtual void HandleMethod(dbus::MethodCall* method_call, + virtual void HandleMethod(::dbus::MethodCall* method_call, ResponseSender sender) = 0; }; @@ -76,7 +77,7 @@ class SimpleDBusInterfaceMethodHandler explicit SimpleDBusInterfaceMethodHandler( const base::Callback<R(Args...)>& handler) : handler_(handler) {} - void HandleMethod(dbus::MethodCall* method_call, + void HandleMethod(::dbus::MethodCall* method_call, ResponseSender sender) override { DBusMethodResponse<R> method_response(method_call, sender); auto invoke_callback = [this, &method_response](const Args&... args) { @@ -84,7 +85,7 @@ class SimpleDBusInterfaceMethodHandler }; ErrorPtr param_reader_error; - dbus::MessageReader reader(method_call); + ::dbus::MessageReader reader(method_call); // The handler is expected a return value, don't allow output parameters. if (!DBusParamReader<false, Args...>::Invoke( invoke_callback, &reader, ¶m_reader_error)) { @@ -110,19 +111,19 @@ class SimpleDBusInterfaceMethodHandler<void, Args...> explicit SimpleDBusInterfaceMethodHandler( const base::Callback<void(Args...)>& handler) : handler_(handler) {} - void HandleMethod(dbus::MethodCall* method_call, + void HandleMethod(::dbus::MethodCall* method_call, ResponseSender sender) override { DBusMethodResponseBase method_response(method_call, sender); auto invoke_callback = [this, &method_response](const Args&... args) { handler_.Run(args...); auto response = method_response.CreateCustomResponse(); - dbus::MessageWriter writer(response.get()); + ::dbus::MessageWriter writer(response.get()); DBusParamWriter::AppendDBusOutParams(&writer, args...); method_response.SendRawResponse(std::move(response)); }; ErrorPtr param_reader_error; - dbus::MessageReader reader(method_call); + ::dbus::MessageReader reader(method_call); if (!DBusParamReader<true, Args...>::Invoke( invoke_callback, &reader, ¶m_reader_error)) { // Error parsing method arguments. @@ -156,7 +157,7 @@ class SimpleDBusInterfaceMethodHandlerWithError const base::Callback<bool(ErrorPtr*, Args...)>& handler) : handler_(handler) {} - void HandleMethod(dbus::MethodCall* method_call, + void HandleMethod(::dbus::MethodCall* method_call, ResponseSender sender) override { DBusMethodResponseBase method_response(method_call, sender); auto invoke_callback = [this, &method_response](const Args&... args) { @@ -165,14 +166,14 @@ class SimpleDBusInterfaceMethodHandlerWithError method_response.ReplyWithError(error.get()); } else { auto response = method_response.CreateCustomResponse(); - dbus::MessageWriter writer(response.get()); + ::dbus::MessageWriter writer(response.get()); DBusParamWriter::AppendDBusOutParams(&writer, args...); method_response.SendRawResponse(std::move(response)); } }; ErrorPtr param_reader_error; - dbus::MessageReader reader(method_call); + ::dbus::MessageReader reader(method_call); if (!DBusParamReader<true, Args...>::Invoke( invoke_callback, &reader, ¶m_reader_error)) { // Error parsing method arguments. @@ -204,10 +205,10 @@ class SimpleDBusInterfaceMethodHandlerWithErrorAndMessage // A constructor that takes a |handler| to be called when HandleMethod() // virtual function is invoked. explicit SimpleDBusInterfaceMethodHandlerWithErrorAndMessage( - const base::Callback<bool(ErrorPtr*, dbus::Message*, Args...)>& handler) + const base::Callback<bool(ErrorPtr*, ::dbus::Message*, Args...)>& handler) : handler_(handler) {} - void HandleMethod(dbus::MethodCall* method_call, + void HandleMethod(::dbus::MethodCall* method_call, ResponseSender sender) override { DBusMethodResponseBase method_response(method_call, sender); auto invoke_callback = @@ -217,14 +218,14 @@ class SimpleDBusInterfaceMethodHandlerWithErrorAndMessage method_response.ReplyWithError(error.get()); } else { auto response = method_response.CreateCustomResponse(); - dbus::MessageWriter writer(response.get()); + ::dbus::MessageWriter writer(response.get()); DBusParamWriter::AppendDBusOutParams(&writer, args...); method_response.SendRawResponse(std::move(response)); } }; ErrorPtr param_reader_error; - dbus::MessageReader reader(method_call); + ::dbus::MessageReader reader(method_call); if (!DBusParamReader<true, Args...>::Invoke( invoke_callback, &reader, ¶m_reader_error)) { // Error parsing method arguments. @@ -234,7 +235,7 @@ class SimpleDBusInterfaceMethodHandlerWithErrorAndMessage private: // C++ callback to be called when a DBus method is dispatched. - base::Callback<bool(ErrorPtr*, dbus::Message*, Args...)> handler_; + base::Callback<bool(ErrorPtr*, ::dbus::Message*, Args...)> handler_; DISALLOW_COPY_AND_ASSIGN(SimpleDBusInterfaceMethodHandlerWithErrorAndMessage); }; @@ -257,7 +258,7 @@ class DBusInterfaceMethodHandler : public DBusInterfaceMethodHandlerInterface { // This method forwards the call to |handler_| after extracting the required // arguments from the DBus message buffer specified in |method_call|. // The output parameters of |handler_| (if any) are sent back to the called. - void HandleMethod(dbus::MethodCall* method_call, + void HandleMethod(::dbus::MethodCall* method_call, ResponseSender sender) override { auto invoke_callback = [this, method_call, &sender](const Args&... args) { std::unique_ptr<Response> response(new Response(method_call, sender)); @@ -265,7 +266,7 @@ class DBusInterfaceMethodHandler : public DBusInterfaceMethodHandlerInterface { }; ErrorPtr param_reader_error; - dbus::MessageReader reader(method_call); + ::dbus::MessageReader reader(method_call); if (!DBusParamReader<false, Args...>::Invoke( invoke_callback, &reader, ¶m_reader_error)) { // Error parsing method arguments. @@ -297,14 +298,14 @@ class DBusInterfaceMethodHandlerWithMessage // A constructor that takes a |handler| to be called when HandleMethod() // virtual function is invoked. explicit DBusInterfaceMethodHandlerWithMessage( - const base::Callback<void(std::unique_ptr<Response>, dbus::Message*, - Args...)>& handler) + const base::Callback< + void(std::unique_ptr<Response>, ::dbus::Message*, Args...)>& handler) : handler_(handler) {} // This method forwards the call to |handler_| after extracting the required // arguments from the DBus message buffer specified in |method_call|. // The output parameters of |handler_| (if any) are sent back to the called. - void HandleMethod(dbus::MethodCall* method_call, + void HandleMethod(::dbus::MethodCall* method_call, ResponseSender sender) override { auto invoke_callback = [this, method_call, &sender](const Args&... args) { std::unique_ptr<Response> response(new Response(method_call, sender)); @@ -312,7 +313,7 @@ class DBusInterfaceMethodHandlerWithMessage }; ErrorPtr param_reader_error; - dbus::MessageReader reader(method_call); + ::dbus::MessageReader reader(method_call); if (!DBusParamReader<false, Args...>::Invoke( invoke_callback, &reader, ¶m_reader_error)) { // Error parsing method arguments. @@ -323,8 +324,8 @@ class DBusInterfaceMethodHandlerWithMessage private: // C++ callback to be called when a D-Bus method is dispatched. - base::Callback<void(std::unique_ptr<Response>, - dbus::Message*, Args...)> handler_; + base::Callback<void(std::unique_ptr<Response>, ::dbus::Message*, Args...)> + handler_; DISALLOW_COPY_AND_ASSIGN(DBusInterfaceMethodHandlerWithMessage); }; @@ -341,18 +342,18 @@ class RawDBusInterfaceMethodHandler public: // A constructor that takes a |handler| to be called when HandleMethod() // virtual function is invoked. - explicit RawDBusInterfaceMethodHandler( - const base::Callback<void(dbus::MethodCall*, ResponseSender)>& handler) + RawDBusInterfaceMethodHandler( + const base::Callback<void(::dbus::MethodCall*, ResponseSender)>& handler) : handler_(handler) {} - void HandleMethod(dbus::MethodCall* method_call, + void HandleMethod(::dbus::MethodCall* method_call, ResponseSender sender) override { handler_.Run(method_call, sender); } private: // C++ callback to be called when a D-Bus method is dispatched. - base::Callback<void(dbus::MethodCall*, ResponseSender)> handler_; + base::Callback<void(::dbus::MethodCall*, ResponseSender)> handler_; DISALLOW_COPY_AND_ASSIGN(RawDBusInterfaceMethodHandler); }; diff --git a/brillo/dbus/dbus_object_unittest.cc b/brillo/dbus/dbus_object_test.cc index 932a5c8..09615c8 100644 --- a/brillo/dbus/dbus_object_unittest.cc +++ b/brillo/dbus/dbus_object_test.cc @@ -17,8 +17,6 @@ using ::testing::AnyNumber; using ::testing::Return; -using ::testing::Invoke; -using ::testing::Mock; using ::testing::_; namespace brillo { @@ -335,7 +333,39 @@ TEST_F(DBusObjectTest, TestRemovedInterface) { EXPECT_EQ(DBUS_ERROR_UNKNOWN_INTERFACE, response->GetErrorName()); } -TEST_F(DBusObjectTest, TestInterfaceExportedLate) { +TEST_F(DBusObjectTest, TestUnexportInterfaceAsync) { + // Unexport the interface to be tested. It should unexport the methods on that + // interface. + EXPECT_CALL(*mock_exported_object_, + UnexportMethod(kTestInterface3, kTestMethod_NoOp, _)) + .Times(1); + EXPECT_CALL(*mock_exported_object_, + UnexportMethod(kTestInterface3, kTestMethod_WithMessage, _)) + .Times(1); + EXPECT_CALL(*mock_exported_object_, + UnexportMethod(kTestInterface3, kTestMethod_WithMessageAsync, _)) + .Times(1); + dbus_object_->UnexportInterfaceAsync(kTestInterface3, + base::Bind(&OnInterfaceExported)); +} + +TEST_F(DBusObjectTest, TestUnexportInterfaceBlocking) { + // Unexport the interface to be tested. It should unexport the methods on that + // interface. + EXPECT_CALL(*mock_exported_object_, + UnexportMethodAndBlock(kTestInterface3, kTestMethod_NoOp)) + .WillOnce(Return(true)); + EXPECT_CALL(*mock_exported_object_, + UnexportMethodAndBlock(kTestInterface3, kTestMethod_WithMessage)) + .WillOnce(Return(true)); + EXPECT_CALL( + *mock_exported_object_, + UnexportMethodAndBlock(kTestInterface3, kTestMethod_WithMessageAsync)) + .WillOnce(Return(true)); + dbus_object_->UnexportInterfaceAndBlock(kTestInterface3); +} + +TEST_F(DBusObjectTest, TestInterfaceExportedLateAsync) { // Registers a new interface late. dbus_object_->ExportInterfaceAsync(kTestInterface4, base::Bind(&OnInterfaceExported)); @@ -350,6 +380,20 @@ TEST_F(DBusObjectTest, TestInterfaceExportedLate) { EXPECT_EQ(DBUS_ERROR_UNKNOWN_METHOD, response->GetErrorName()); } +TEST_F(DBusObjectTest, TestInterfaceExportedLateBlocking) { + // Registers a new interface late. + dbus_object_->ExportInterfaceAndBlock(kTestInterface4); + + const std::string sender{":1.2345"}; + dbus::MethodCall method_call(kTestInterface4, kTestMethod_WithMessage); + method_call.SetSerial(123); + method_call.SetSender(sender); + auto response = testing::CallMethod(*dbus_object_, &method_call); + // The response should contain error UnknownMethod rather than + // UnknownInterface since the interface has been registered late. + EXPECT_EQ(DBUS_ERROR_UNKNOWN_METHOD, response->GetErrorName()); +} + TEST_F(DBusObjectTest, TooFewParams) { dbus::MethodCall method_call(kTestInterface1, kTestMethod_Add); method_call.SetSerial(123); diff --git a/brillo/dbus/dbus_object_test_helpers.h b/brillo/dbus/dbus_object_test_helpers.h index 59c4a06..4a9287f 100644 --- a/brillo/dbus/dbus_object_test_helpers.h +++ b/brillo/dbus/dbus_object_test_helpers.h @@ -12,6 +12,9 @@ #ifndef LIBBRILLO_BRILLO_DBUS_DBUS_OBJECT_TEST_HELPERS_H_ #define LIBBRILLO_BRILLO_DBUS_DBUS_OBJECT_TEST_HELPERS_H_ +#include <memory> +#include <utility> + #include <base/bind.h> #include <base/memory/weak_ptr.h> #include <brillo/dbus/dbus_method_invoker.h> @@ -25,7 +28,7 @@ namespace dbus_utils { class DBusInterfaceTestHelper final { public: static void HandleMethodCall(DBusInterface* itf, - dbus::MethodCall* method_call, + ::dbus::MethodCall* method_call, ResponseSender sender) { itf->HandleMethodCall(method_call, sender); } @@ -40,11 +43,11 @@ namespace testing { // ResponseHolder::ReceiveResponse() will not be called since we bind the // callback to the object instance via a weak pointer. struct ResponseHolder final : public base::SupportsWeakPtr<ResponseHolder> { - void ReceiveResponse(std::unique_ptr<dbus::Response> response) { + void ReceiveResponse(std::unique_ptr<::dbus::Response> response) { response_ = std::move(response); } - std::unique_ptr<dbus::Response> response_; + std::unique_ptr<::dbus::Response> response_; }; // Dispatches a D-Bus method call to the corresponding handler. @@ -53,10 +56,10 @@ struct ResponseHolder final : public base::SupportsWeakPtr<ResponseHolder> { // call sites. Returns a response from the method handler or nullptr if the // method hasn't provided the response message immediately // (i.e. it is asynchronous). -inline std::unique_ptr<dbus::Response> CallMethod( - const DBusObject& object, dbus::MethodCall* method_call) { +inline std::unique_ptr<::dbus::Response> CallMethod( + const DBusObject& object, ::dbus::MethodCall* method_call) { DBusInterface* itf = object.FindInterface(method_call->GetInterface()); - std::unique_ptr<dbus::Response> response; + std::unique_ptr<::dbus::Response> response; if (!itf) { response = CreateDBusErrorResponse( method_call, @@ -95,7 +98,7 @@ struct MethodHandlerInvoker { Params...), Args... args) { ResponseHolder response_holder; - dbus::MethodCall method_call("test.interface", "TestMethod"); + ::dbus::MethodCall method_call("test.interface", "TestMethod"); method_call.SetSerial(123); std::unique_ptr<DBusMethodResponse<RetType>> method_response{ new DBusMethodResponse<RetType>( @@ -122,7 +125,7 @@ struct MethodHandlerInvoker<void> { void(Class::*method)(std::unique_ptr<DBusMethodResponse<>>, Params...), Args... args) { ResponseHolder response_holder; - dbus::MethodCall method_call("test.interface", "TestMethod"); + ::dbus::MethodCall method_call("test.interface", "TestMethod"); method_call.SetSerial(123); std::unique_ptr<DBusMethodResponse<>> method_response{ new DBusMethodResponse<>(&method_call, diff --git a/brillo/dbus/dbus_param_reader.h b/brillo/dbus/dbus_param_reader.h index 228cfb6..f5c4541 100644 --- a/brillo/dbus/dbus_param_reader.h +++ b/brillo/dbus/dbus_param_reader.h @@ -51,9 +51,9 @@ struct DBusParamReader<allow_out_params, CurrentParam, RestOfParams...> { // method_call - D-Bus method call object we are processing. // reader - D-Bus message reader to pop the current argument value from. // args... - the callback parameters processed so far. - template<typename CallbackType, typename... Args> + template <typename CallbackType, typename... Args> static bool Invoke(const CallbackType& handler, - dbus::MessageReader* reader, + ::dbus::MessageReader* reader, ErrorPtr* error, const Args&... args) { return InvokeHelper<CurrentParam, CallbackType, Args...>( @@ -70,10 +70,10 @@ struct DBusParamReader<allow_out_params, CurrentParam, RestOfParams...> { // parameters should be sent back in the method call response message. // Overload 1: ParamType is not a pointer. - template<typename ParamType, typename CallbackType, typename... Args> + template <typename ParamType, typename CallbackType, typename... Args> static typename std::enable_if<!std::is_pointer<ParamType>::value, bool>::type InvokeHelper(const CallbackType& handler, - dbus::MessageReader* reader, + ::dbus::MessageReader* reader, ErrorPtr* error, const Args&... args) { if (!reader->HasMoreData()) { @@ -112,13 +112,14 @@ struct DBusParamReader<allow_out_params, CurrentParam, RestOfParams...> { } // Overload 2: ParamType is a pointer. - template<typename ParamType, typename CallbackType, typename... Args> + template <typename ParamType, typename CallbackType, typename... Args> static typename std::enable_if<allow_out_params && - std::is_pointer<ParamType>::value, bool>::type - InvokeHelper(const CallbackType& handler, - dbus::MessageReader* reader, - ErrorPtr* error, - const Args&... args) { + std::is_pointer<ParamType>::value, + bool>::type + InvokeHelper(const CallbackType& handler, + ::dbus::MessageReader* reader, + ErrorPtr* error, + const Args&... args) { // ParamType is a pointer. This is expected to be an output parameter. // Create storage for it and the handler will provide a value for it. using ParamValueType = typename std::remove_pointer<ParamType>::type; @@ -143,9 +144,9 @@ struct DBusParamReader<allow_out_params, CurrentParam, RestOfParams...> { // handler with all the accumulated arguments. template<bool allow_out_params> struct DBusParamReader<allow_out_params> { - template<typename CallbackType, typename... Args> + template <typename CallbackType, typename... Args> static bool Invoke(const CallbackType& handler, - dbus::MessageReader* reader, + ::dbus::MessageReader* reader, ErrorPtr* error, const Args&... args) { if (reader->HasMoreData()) { diff --git a/brillo/dbus/dbus_param_reader_unittest.cc b/brillo/dbus/dbus_param_reader_test.cc index fd9f243..abf1da3 100644 --- a/brillo/dbus/dbus_param_reader_unittest.cc +++ b/brillo/dbus/dbus_param_reader_test.cc @@ -4,6 +4,7 @@ #include <brillo/dbus/dbus_param_reader.h> +#include <memory> #include <string> #include <brillo/variant_dictionary.h> diff --git a/brillo/dbus/dbus_param_writer.h b/brillo/dbus/dbus_param_writer.h index 7c7f45e..779ea61 100644 --- a/brillo/dbus/dbus_param_writer.h +++ b/brillo/dbus/dbus_param_writer.h @@ -24,8 +24,8 @@ class DBusParamWriter final { public: // Generic writer method that takes 1 or more arguments. It recursively calls // itself (each time with one fewer arguments) until no more is left. - template<typename ParamType, typename... RestOfParams> - static void Append(dbus::MessageWriter* writer, + template <typename ParamType, typename... RestOfParams> + static void Append(::dbus::MessageWriter* writer, const ParamType& param, const RestOfParams&... rest) { // Append the current |param| to D-Bus, then call Append() with one @@ -38,13 +38,13 @@ class DBusParamWriter final { // The final overload of DBusParamWriter::Append() used when no more // parameters are remaining to be written. // Does nothing and finishes meta-recursion. - static void Append(dbus::MessageWriter* /*writer*/) {} + static void Append(::dbus::MessageWriter* /*writer*/) {} // Generic writer method that takes 1 or more arguments. It recursively calls // itself (each time with one fewer arguments) until no more is left. // Handles non-pointer parameter by just skipping over it. - template<typename ParamType, typename... RestOfParams> - static void AppendDBusOutParams(dbus::MessageWriter* writer, + template <typename ParamType, typename... RestOfParams> + static void AppendDBusOutParams(::dbus::MessageWriter* writer, const ParamType& /* param */, const RestOfParams&... rest) { // Skip the current |param| and call Append() with one fewer arguments, @@ -57,8 +57,8 @@ class DBusParamWriter final { // itself (each time with one fewer arguments) until no more is left. // Handles only a parameter of pointer type and writes the data pointed to // to the output message buffer. - template<typename ParamType, typename... RestOfParams> - static void AppendDBusOutParams(dbus::MessageWriter* writer, + template <typename ParamType, typename... RestOfParams> + static void AppendDBusOutParams(::dbus::MessageWriter* writer, ParamType* param, const RestOfParams&... rest) { // Append the current |param| to D-Bus, then call Append() with one @@ -71,7 +71,7 @@ class DBusParamWriter final { // The final overload of DBusParamWriter::AppendDBusOutParams() used when no // more parameters are remaining to be written. // Does nothing and finishes meta-recursion. - static void AppendDBusOutParams(dbus::MessageWriter* /*writer*/) {} + static void AppendDBusOutParams(::dbus::MessageWriter* /*writer*/) {} }; } // namespace dbus_utils diff --git a/brillo/dbus/dbus_param_writer_unittest.cc b/brillo/dbus/dbus_param_writer_test.cc index 6ab863a..2611ada 100644 --- a/brillo/dbus/dbus_param_writer_unittest.cc +++ b/brillo/dbus/dbus_param_writer_test.cc @@ -4,6 +4,7 @@ #include <brillo/dbus/dbus_param_writer.h> +#include <memory> #include <string> #include <brillo/any.h> diff --git a/brillo/dbus/dbus_property.h b/brillo/dbus/dbus_property.h index 01b850d..77b7328 100644 --- a/brillo/dbus/dbus_property.h +++ b/brillo/dbus/dbus_property.h @@ -16,8 +16,8 @@ namespace dbus_utils { // This class is pretty much a copy of dbus::Property<T> from dbus/property.h // except that it provides the implementations for PopValueFromReader and // AppendSetValueToWriter. -template<class T> -class Property : public dbus::PropertyBase { +template <class T> +class Property : public ::dbus::PropertyBase { public: Property() = default; @@ -27,7 +27,7 @@ class Property : public dbus::PropertyBase { // Requests an updated value from the remote object incurring a // round-trip. |callback| will be called when the new value is available. // This may not be implemented by some interfaces. - void Get(dbus::PropertySet::GetCallback callback) { + void Get(::dbus::PropertySet::GetCallback callback) { property_set()->Get(this, callback); } @@ -40,7 +40,7 @@ class Property : public dbus::PropertyBase { // |callback| will be called to indicate the success or failure of the // request, however the new value may not be available depending on the // remote object. - void Set(const T& value, dbus::PropertySet::SetCallback callback) { + void Set(const T& value, ::dbus::PropertySet::SetCallback callback) { set_value_ = value; property_set()->Set(this, callback); } @@ -54,14 +54,14 @@ class Property : public dbus::PropertyBase { // Method used by PropertySet to retrieve the value from a MessageReader, // no knowledge of the contained type is required, this method returns // true if its expected type was found, false if not. - bool PopValueFromReader(dbus::MessageReader* reader) override { + bool PopValueFromReader(::dbus::MessageReader* reader) override { return PopVariantValueFromReader(reader, &value_); } // Method used by PropertySet to append the set value to a MessageWriter, // no knowledge of the contained type is required. // Implementation provided by specialization. - void AppendSetValueToWriter(dbus::MessageWriter* writer) override { + void AppendSetValueToWriter(::dbus::MessageWriter* writer) override { AppendValueToWriterAsVariant(writer, set_value_); } diff --git a/brillo/dbus/dbus_service_watcher.h b/brillo/dbus/dbus_service_watcher.h index 0031771..b747161 100644 --- a/brillo/dbus/dbus_service_watcher.h +++ b/brillo/dbus/dbus_service_watcher.h @@ -29,7 +29,7 @@ namespace dbus_utils { // cause the Bus to crash the process on destruction. class BRILLO_EXPORT DBusServiceWatcher { public: - DBusServiceWatcher(scoped_refptr<dbus::Bus> bus, + DBusServiceWatcher(scoped_refptr<::dbus::Bus> bus, const std::string& connection_name, const base::Closure& on_connection_vanish); virtual ~DBusServiceWatcher(); @@ -38,9 +38,9 @@ class BRILLO_EXPORT DBusServiceWatcher { private: void OnServiceOwnerChange(const std::string& service_owner); - scoped_refptr<dbus::Bus> bus_; + scoped_refptr<::dbus::Bus> bus_; const std::string connection_name_; - dbus::Bus::GetServiceOwnerCallback monitoring_callback_; + ::dbus::Bus::GetServiceOwnerCallback monitoring_callback_; base::Closure on_connection_vanish_; base::WeakPtrFactory<DBusServiceWatcher> weak_factory_{this}; diff --git a/brillo/dbus/dbus_signal.h b/brillo/dbus/dbus_signal.h index bda322a..d1fcced 100644 --- a/brillo/dbus/dbus_signal.h +++ b/brillo/dbus/dbus_signal.h @@ -30,7 +30,7 @@ class BRILLO_EXPORT DBusSignalBase { virtual ~DBusSignalBase() = default; protected: - bool SendSignal(dbus::Signal* signal) const; + bool SendSignal(::dbus::Signal* signal) const; std::string interface_name_; std::string signal_name_; @@ -51,9 +51,11 @@ class DBusSignal : public DBusSignalBase { ~DBusSignal() override = default; // DBusSignal<...>::Send(...) dispatches the signal with the given arguments. + // Note: This function can be called from any thread/task runner, as it'll + // eventually post the actual signal sending to the DBus thread. bool Send(const Args&... args) const { - dbus::Signal signal(interface_name_, signal_name_); - dbus::MessageWriter signal_writer(&signal); + ::dbus::Signal signal(interface_name_, signal_name_); + ::dbus::MessageWriter signal_writer(&signal); DBusParamWriter::Append(&signal_writer, args...); return SendSignal(&signal); } diff --git a/brillo/dbus/dbus_signal_handler.h b/brillo/dbus/dbus_signal_handler.h index 15cdae1..e89f867 100644 --- a/brillo/dbus/dbus_signal_handler.h +++ b/brillo/dbus/dbus_signal_handler.h @@ -7,8 +7,9 @@ #include <functional> #include <string> +#include <utility> -#include <brillo/bind_lambda.h> +#include <base/bind.h> #include <brillo/dbus/dbus_param_reader.h> #include <dbus/message.h> #include <dbus/object_proxy.h> @@ -31,39 +32,36 @@ namespace dbus_utils { // If the signal message doesn't contain correct number or types of arguments, // an error message is logged to the system log and the signal is ignored // (|signal_callback| is not invoked). -template<typename... Args> +template <typename... Args> void ConnectToSignal( - dbus::ObjectProxy* object_proxy, + ::dbus::ObjectProxy* object_proxy, const std::string& interface_name, const std::string& signal_name, base::Callback<void(Args...)> signal_callback, - dbus::ObjectProxy::OnConnectedCallback on_connected_callback) { + ::dbus::ObjectProxy::OnConnectedCallback on_connected_callback) { + // DBusParamReader::Invoke() needs a functor object, not a base::Callback. + // Wrap the callback with lambda so we can redirect the call. + auto signal_callback_wrapper = [signal_callback](const Args&... args) { + if (!signal_callback.is_null()) { + signal_callback.Run(args...); + } + }; + // Raw signal handler stub method. When called, unpacks the signal arguments // from |signal| message buffer and redirects the call to // |signal_callback_wrapper| which, in turn, would call the user-provided // |signal_callback|. - auto dbus_signal_callback = []( - const base::Callback<void(Args...)>& signal_callback, - dbus::Signal* signal) { - // DBusParamReader::Invoke() needs a functor object, not a base::Callback. - // Wrap the callback with lambda so we can redirect the call. - auto signal_callback_wrapper = [signal_callback](const Args&... args) { - if (!signal_callback.is_null()) { - signal_callback.Run(args...); - } - }; - - dbus::MessageReader reader(signal); - DBusParamReader<false, Args...>::Invoke( - signal_callback_wrapper, &reader, nullptr); + auto dbus_signal_callback = [](std::function<void(const Args&...)> callback, + ::dbus::Signal* signal) { + ::dbus::MessageReader reader(signal); + DBusParamReader<false, Args...>::Invoke(callback, &reader, nullptr); }; // Register our stub handler with D-Bus ObjectProxy. object_proxy->ConnectToSignal( - interface_name, - signal_name, - base::Bind(dbus_signal_callback, signal_callback), - on_connected_callback); + interface_name, signal_name, + base::Bind(dbus_signal_callback, signal_callback_wrapper), + std::move(on_connected_callback)); } } // namespace dbus_utils diff --git a/brillo/dbus/dbus_signal_handler_unittest.cc b/brillo/dbus/dbus_signal_handler_test.cc index e0bea10..edd0eca 100644 --- a/brillo/dbus/dbus_signal_handler_unittest.cc +++ b/brillo/dbus/dbus_signal_handler_test.cc @@ -6,7 +6,7 @@ #include <string> -#include <brillo/bind_lambda.h> +#include <base/bind.h> #include <brillo/dbus/dbus_param_writer.h> #include <dbus/mock_bus.h> #include <dbus/mock_object_proxy.h> @@ -49,7 +49,8 @@ class DBusSignalHandlerTest : public testing::Test { template<typename SignalHandlerSink, typename... Args> void CallSignal(SignalHandlerSink* sink, Args... args) { dbus::ObjectProxy::SignalCallback signal_callback; - EXPECT_CALL(*mock_object_proxy_, ConnectToSignal(kInterface, kSignal, _, _)) + EXPECT_CALL(*mock_object_proxy_, + MIGRATE_ConnectToSignal(kInterface, kSignal, _, _)) .WillOnce(SaveArg<2>(&signal_callback)); brillo::dbus_utils::ConnectToSignal( @@ -70,7 +71,8 @@ class DBusSignalHandlerTest : public testing::Test { }; TEST_F(DBusSignalHandlerTest, ConnectToSignal) { - EXPECT_CALL(*mock_object_proxy_, ConnectToSignal(kInterface, kSignal, _, _)) + EXPECT_CALL(*mock_object_proxy_, + MIGRATE_ConnectToSignal(kInterface, kSignal, _, _)) .Times(1); brillo::dbus_utils::ConnectToSignal( @@ -80,7 +82,7 @@ TEST_F(DBusSignalHandlerTest, ConnectToSignal) { TEST_F(DBusSignalHandlerTest, CallSignal_3Args) { class SignalHandlerSink { public: - MOCK_METHOD3(Handler, void(int, int, double)); + MOCK_METHOD(void, Handler, (int, int, double)); } sink; EXPECT_CALL(sink, Handler(10, 20, 30.5)).Times(1); @@ -91,7 +93,7 @@ TEST_F(DBusSignalHandlerTest, CallSignal_2Args) { class SignalHandlerSink { public: // Take string both by reference and by value to make sure this works too. - MOCK_METHOD2(Handler, void(const std::string&, std::string)); + MOCK_METHOD(void, Handler, (const std::string&, std::string)); } sink; EXPECT_CALL(sink, Handler(std::string{"foo"}, std::string{"bar"})).Times(1); @@ -101,7 +103,7 @@ TEST_F(DBusSignalHandlerTest, CallSignal_2Args) { TEST_F(DBusSignalHandlerTest, CallSignal_NoArgs) { class SignalHandlerSink { public: - MOCK_METHOD0(Handler, void()); + MOCK_METHOD(void, Handler, ()); } sink; EXPECT_CALL(sink, Handler()).Times(1); @@ -111,7 +113,7 @@ TEST_F(DBusSignalHandlerTest, CallSignal_NoArgs) { TEST_F(DBusSignalHandlerTest, CallSignal_Error_TooManyArgs) { class SignalHandlerSink { public: - MOCK_METHOD0(Handler, void()); + MOCK_METHOD(void, Handler, ()); } sink; // Handler() expects no args, but we send an int. @@ -122,7 +124,7 @@ TEST_F(DBusSignalHandlerTest, CallSignal_Error_TooManyArgs) { TEST_F(DBusSignalHandlerTest, CallSignal_Error_TooFewArgs) { class SignalHandlerSink { public: - MOCK_METHOD2(Handler, void(std::string, bool)); + MOCK_METHOD(void, Handler, (std::string, bool)); } sink; // Handler() expects 2 args while we send it just one. @@ -133,7 +135,7 @@ TEST_F(DBusSignalHandlerTest, CallSignal_Error_TooFewArgs) { TEST_F(DBusSignalHandlerTest, CallSignal_Error_TypeMismatchArgs) { class SignalHandlerSink { public: - MOCK_METHOD2(Handler, void(std::string, bool)); + MOCK_METHOD(void, Handler, (std::string, bool)); } sink; // Handler() expects "sb" while we send it "ii". diff --git a/brillo/dbus/exported_object_manager.h b/brillo/dbus/exported_object_manager.h index ea68f33..9534009 100644 --- a/brillo/dbus/exported_object_manager.h +++ b/brillo/dbus/exported_object_manager.h @@ -6,6 +6,7 @@ #define LIBBRILLO_BRILLO_DBUS_EXPORTED_OBJECT_MANAGER_H_ #include <map> +#include <memory> #include <string> #include <vector> @@ -80,12 +81,12 @@ class BRILLO_EXPORT ExportedObjectManager : public base::SupportsWeakPtr<ExportedObjectManager> { public: using ObjectMap = - std::map<dbus::ObjectPath, std::map<std::string, VariantDictionary>>; + std::map<::dbus::ObjectPath, std::map<std::string, VariantDictionary>>; using InterfaceProperties = std::map<std::string, ExportedPropertySet::PropertyWriter>; - ExportedObjectManager(scoped_refptr<dbus::Bus> bus, - const dbus::ObjectPath& path); + ExportedObjectManager(scoped_refptr<::dbus::Bus> bus, + const ::dbus::ObjectPath& path); virtual ~ExportedObjectManager() = default; // Registers methods implementing the ObjectManager interface on the object @@ -98,35 +99,28 @@ class BRILLO_EXPORT ExportedObjectManager // Trigger a signal that |path| has added an interface |interface_name| // with properties as given by |writer|. virtual void ClaimInterface( - const dbus::ObjectPath& path, + const ::dbus::ObjectPath& path, const std::string& interface_name, const ExportedPropertySet::PropertyWriter& writer); // Trigger a signal that |path| has removed an interface |interface_name|. - virtual void ReleaseInterface(const dbus::ObjectPath& path, + virtual void ReleaseInterface(const ::dbus::ObjectPath& path, const std::string& interface_name); - const scoped_refptr<dbus::Bus>& GetBus() const { return bus_; } - - // Due to D-Bus forwarding, clients may need to access the underlying - // DBusObject to handle signals/methods. - // TODO(sonnysasaka): Refactor this accessor into a stricter API once we know - // what D-Bus forwarding needs when it's completed, without exposing - // DBusObject directly. - brillo::dbus_utils::DBusObject* dbus_object() { return &dbus_object_; }; + const scoped_refptr<::dbus::Bus>& GetBus() const { return bus_; } private: BRILLO_PRIVATE ObjectMap HandleGetManagedObjects(); - scoped_refptr<dbus::Bus> bus_; + scoped_refptr<::dbus::Bus> bus_; brillo::dbus_utils::DBusObject dbus_object_; // Tracks all objects currently known to the ExportedObjectManager. - std::map<dbus::ObjectPath, InterfaceProperties> registered_objects_; + std::map<::dbus::ObjectPath, InterfaceProperties> registered_objects_; using SignalInterfacesAdded = - DBusSignal<dbus::ObjectPath, std::map<std::string, VariantDictionary>>; + DBusSignal<::dbus::ObjectPath, std::map<std::string, VariantDictionary>>; using SignalInterfacesRemoved = - DBusSignal<dbus::ObjectPath, std::vector<std::string>>; + DBusSignal<::dbus::ObjectPath, std::vector<std::string>>; std::weak_ptr<SignalInterfacesAdded> signal_itf_added_; std::weak_ptr<SignalInterfacesRemoved> signal_itf_removed_; diff --git a/brillo/dbus/exported_object_manager_unittest.cc b/brillo/dbus/exported_object_manager_test.cc index 00fe108..6837399 100644 --- a/brillo/dbus/exported_object_manager_unittest.cc +++ b/brillo/dbus/exported_object_manager_test.cc @@ -4,6 +4,8 @@ #include <brillo/dbus/exported_object_manager.h> +#include <utility> + #include <base/bind.h> #include <brillo/dbus/dbus_object_test_helpers.h> #include <brillo/dbus/utils.h> diff --git a/brillo/dbus/exported_property_set.cc b/brillo/dbus/exported_property_set.cc index 018843e..c71aab6 100644 --- a/brillo/dbus/exported_property_set.cc +++ b/brillo/dbus/exported_property_set.cc @@ -4,16 +4,15 @@ #include <brillo/dbus/exported_property_set.h> +#include <utility> + #include <base/bind.h> #include <dbus/bus.h> #include <dbus/property.h> // For kPropertyInterface -#include <brillo/dbus/async_event_sequencer.h> #include <brillo/dbus/dbus_object.h> #include <brillo/errors/error_codes.h> -using brillo::dbus_utils::AsyncEventSequencer; - namespace brillo { namespace dbus_utils { diff --git a/brillo/dbus/exported_property_set.h b/brillo/dbus/exported_property_set.h index 971e932..08d0ae4 100644 --- a/brillo/dbus/exported_property_set.h +++ b/brillo/dbus/exported_property_set.h @@ -8,6 +8,7 @@ #include <stdint.h> #include <map> +#include <memory> #include <string> #include <vector> @@ -97,7 +98,7 @@ class BRILLO_EXPORT ExportedPropertySet { public: using PropertyWriter = base::Callback<void(VariantDictionary* dict)>; - explicit ExportedPropertySet(dbus::Bus* bus); + explicit ExportedPropertySet(::dbus::Bus* bus); virtual ~ExportedPropertySet() = default; // Called to notify ExportedPropertySet that the Properties interface of the @@ -148,7 +149,7 @@ class BRILLO_EXPORT ExportedPropertySet { const std::string& property_name, const ExportedPropertyBase* exported_property); - dbus::Bus* bus_; // weak; owned by outer DBusObject containing this object. + ::dbus::Bus* bus_; // weak; owned by outer DBusObject containing this object. // This is a map from interface name -> property name -> pointer to property. std::map<std::string, std::map<std::string, ExportedPropertyBase*>> properties_; diff --git a/brillo/dbus/exported_property_set_unittest.cc b/brillo/dbus/exported_property_set_test.cc index 93aceb4..6f9dbd7 100644 --- a/brillo/dbus/exported_property_set_unittest.cc +++ b/brillo/dbus/exported_property_set_test.cc @@ -177,8 +177,7 @@ class PropertyValidatorObserver { base::Unretained(this))) {} virtual ~PropertyValidatorObserver() {} - MOCK_METHOD2_T(ValidateProperty, - bool(brillo::ErrorPtr* error, const T& value)); + MOCK_METHOD(bool, ValidateProperty, (brillo::ErrorPtr*, const T&)); const base::Callback<bool(brillo::ErrorPtr*, const T&)>& validate_property_callback() const { diff --git a/brillo/dbus/file_descriptor.h b/brillo/dbus/file_descriptor.h index f7be44f..2cf1b02 100644 --- a/brillo/dbus/file_descriptor.h +++ b/brillo/dbus/file_descriptor.h @@ -5,6 +5,8 @@ #ifndef LIBBRILLO_BRILLO_DBUS_FILE_DESCRIPTOR_H_ #define LIBBRILLO_BRILLO_DBUS_FILE_DESCRIPTOR_H_ +#include <utility> + #include <base/files/scoped_file.h> #include <base/macros.h> diff --git a/brillo/dbus/introspectable_helper.cc b/brillo/dbus/introspectable_helper.cc new file mode 100644 index 0000000..68ec78c --- /dev/null +++ b/brillo/dbus/introspectable_helper.cc @@ -0,0 +1,81 @@ +// Copyright 2019 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <brillo/dbus/introspectable_helper.h> + +#include <memory> + +#include <base/bind.h> +#include <dbus/dbus-shared.h> + +namespace brillo { +namespace dbus_utils { + +using base::Bind; +using std::string; +using std::unique_ptr; + +void IntrospectableInterfaceHelper::AddInterfaceXml(string xml) { + interface_xmls.push_back(xml); +} + +void IntrospectableInterfaceHelper::RegisterWithDBusObject(DBusObject* object) { + DBusInterface* itf = object->AddOrGetInterface(DBUS_INTERFACE_INTROSPECTABLE); + + itf->AddMethodHandler("Introspect", GetHandler()); +} + +IntrospectableInterfaceHelper::IntrospectCallback +IntrospectableInterfaceHelper::GetHandler() { + return Bind( + [](const string& xml, StringResponse response) { response->Return(xml); }, + GetXmlString()); +} + +string IntrospectableInterfaceHelper::GetXmlString() { + constexpr const char header[] = + "<!DOCTYPE node PUBLIC " + "\"-//freedesktop//DTD D-BUS Object Introspection 1.0//EN\"\n" + "\"http://www.freedesktop.org/standards/dbus/1.0/introspect.dtd\">\n" + "\n" + "<node>\n" + " <interface name=\"org.freedesktop.DBus.Introspectable\">\n" + " <method name=\"Introspect\">\n" + " <arg name=\"data\" direction=\"out\" type=\"s\"/>\n" + " </method>\n" + " </interface>\n" + " <interface name=\"org.freedesktop.DBus.Properties\">\n" + " <method name=\"Get\">\n" + " <arg name=\"interface\" direction=\"in\" type=\"s\"/>\n" + " <arg name=\"propname\" direction=\"in\" type=\"s\"/>\n" + " <arg name=\"value\" direction=\"out\" type=\"v\"/>\n" + " </method>\n" + " <method name=\"Set\">\n" + " <arg name=\"interface\" direction=\"in\" type=\"s\"/>\n" + " <arg name=\"propname\" direction=\"in\" type=\"s\"/>\n" + " <arg name=\"value\" direction=\"in\" type=\"v\"/>\n" + " </method>\n" + " <method name=\"GetAll\">\n" + " <arg name=\"interface\" direction=\"in\" type=\"s\"/>\n" + " <arg name=\"props\" direction=\"out\" type=\"a{sv}\"/>\n" + " </method>\n" + " </interface>\n"; + constexpr const char footer[] = "</node>\n"; + + size_t result_len = strlen(header) + strlen(footer); + for (const string& xml : interface_xmls) { + result_len += xml.size(); + } + + string result = header; + result.reserve(result_len + 1); // +1 for null terminator + for (const string& xml : interface_xmls) { + result.append(xml); + } + result.append(footer); + return result; +} + +} // namespace dbus_utils +} // namespace brillo diff --git a/brillo/dbus/introspectable_helper.h b/brillo/dbus/introspectable_helper.h new file mode 100644 index 0000000..e1a398f --- /dev/null +++ b/brillo/dbus/introspectable_helper.h @@ -0,0 +1,68 @@ +// Copyright 2019 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef LIBBRILLO_BRILLO_DBUS_INTROSPECTABLE_HELPER_H_ +#define LIBBRILLO_BRILLO_DBUS_INTROSPECTABLE_HELPER_H_ + +#include <memory> +#include <string> +#include <vector> + +#include <brillo/brillo_export.h> +#include <brillo/dbus/dbus_method_response.h> +#include <brillo/dbus/dbus_object.h> + +namespace brillo { +namespace dbus_utils { + +// Note that brillo/dbus/dbus_object.h include files that include this file, so +// we'll need this forward declaration. +// class DBusObject; + +// This is a helper class that is used for creating the DBus Introspectable +// Interface. Each of the interfaces that is exported under a DBus Object will +// add its dbus interface introspection XML to this class, and then the user of +// this class will call RegisterWithDBusObject on the DBus object. Then this +// class can be freed. Note that this class is usually used in conjunction with +// the chromeos-dbus-bindings tool. Simply pass the string returned by +// GetIntrospectionXML() of the generated adaptor. Usage example: +// { +// IntrospectableInterfaceHelper helper; +// helper.AddInterfaceXML("<interface...> ...</interface>"); +// helper.AddInterfaceXML("<interface...> ...</interface>"); +// helper.AddInterfaceXML(XXXAdaptor::GetIntrospect()); +// helper.RegisterWithDBusObject(object); +// } +class BRILLO_EXPORT IntrospectableInterfaceHelper { + public: + IntrospectableInterfaceHelper() = default; + + // Add the Introspection XML for an interface to this class. The |xml| string + // should contain an interface XML tag and its content. + void AddInterfaceXml(std::string xml); + + // Register the Introspectable Interface with a DBus object. Note that this + // class can be freed after registering with DBus object. + void RegisterWithDBusObject(DBusObject* object); + + private: + // Internal alias for convenience. + using StringResponse = std::unique_ptr<DBusMethodResponse<std::string>>; + using IntrospectCallback = base::Callback<void(StringResponse)>; + + // Create the method handler for Introspect method call. + IntrospectCallback GetHandler(); + + // Get the complete introspection XML. + std::string GetXmlString(); + + // Stores the list of introspection XMLs for each of the interfaces that was + // added to this class. + std::vector<std::string> interface_xmls; +}; + +} // namespace dbus_utils +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_DBUS_INTROSPECTABLE_HELPER_H_ diff --git a/brillo/dbus/mock_dbus_object.h b/brillo/dbus/mock_dbus_object.h index 82e2fc7..d65f9ab 100644 --- a/brillo/dbus/mock_dbus_object.h +++ b/brillo/dbus/mock_dbus_object.h @@ -17,13 +17,15 @@ namespace dbus_utils { class MockDBusObject : public DBusObject { public: MockDBusObject(ExportedObjectManager* object_manager, - const scoped_refptr<dbus::Bus>& bus, - const dbus::ObjectPath& object_path) + const scoped_refptr<::dbus::Bus>& bus, + const ::dbus::ObjectPath& object_path) : DBusObject(object_manager, bus, object_path) {} ~MockDBusObject() override = default; - MOCK_METHOD1(RegisterAsync, - void(const AsyncEventSequencer::CompletionAction&)); + MOCK_METHOD(void, + RegisterAsync, + (const AsyncEventSequencer::CompletionAction&), + (override)); }; // class MockDBusObject } // namespace dbus_utils diff --git a/brillo/dbus/mock_exported_object_manager.h b/brillo/dbus/mock_exported_object_manager.h index d8abc0a..02bb073 100644 --- a/brillo/dbus/mock_exported_object_manager.h +++ b/brillo/dbus/mock_exported_object_manager.h @@ -24,15 +24,17 @@ class MockExportedObjectManager : public ExportedObjectManager { using ExportedObjectManager::ExportedObjectManager; ~MockExportedObjectManager() override = default; - MOCK_METHOD1(RegisterAsync, - void(const CompletionAction& completion_callback)); - MOCK_METHOD3(ClaimInterface, - void(const dbus::ObjectPath& path, - const std::string& interface_name, - const ExportedPropertySet::PropertyWriter& writer)); - MOCK_METHOD2(ReleaseInterface, - void(const dbus::ObjectPath& path, - const std::string& interface_name)); + MOCK_METHOD(void, RegisterAsync, (const CompletionAction&), (override)); + MOCK_METHOD(void, + ClaimInterface, + (const ::dbus::ObjectPath&, + const std::string&, + const ExportedPropertySet::PropertyWriter&), + (override)); + MOCK_METHOD(void, + ReleaseInterface, + (const ::dbus::ObjectPath&, const std::string&), + (override)); }; } // namespace dbus_utils diff --git a/brillo/dbus/test.proto b/brillo/dbus/test.proto index 84607a3..709bf71 100644 --- a/brillo/dbus/test.proto +++ b/brillo/dbus/test.proto @@ -1,3 +1,9 @@ +// Copyright 2015 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +syntax = "proto2"; + option optimize_for = LITE_RUNTIME; package dbus_utils_test; diff --git a/brillo/dbus/utils.h b/brillo/dbus/utils.h index a548756..163849e 100644 --- a/brillo/dbus/utils.h +++ b/brillo/dbus/utils.h @@ -18,8 +18,8 @@ namespace brillo { namespace dbus_utils { // A helper function to create a D-Bus error response object as unique_ptr<>. -BRILLO_EXPORT std::unique_ptr<dbus::Response> CreateDBusErrorResponse( - dbus::MethodCall* method_call, +BRILLO_EXPORT std::unique_ptr<::dbus::Response> CreateDBusErrorResponse( + ::dbus::MethodCall* method_call, const std::string& error_name, const std::string& error_message); @@ -28,9 +28,8 @@ BRILLO_EXPORT std::unique_ptr<dbus::Response> CreateDBusErrorResponse( // and message are directly translated to D-Bus error code and message. // Any inner errors are formatted as "domain/code:message" string and appended // to the D-Bus error message, delimited by semi-colons. -BRILLO_EXPORT std::unique_ptr<dbus::Response> GetDBusError( - dbus::MethodCall* method_call, - const brillo::Error* error); +BRILLO_EXPORT std::unique_ptr<::dbus::Response> GetDBusError( + ::dbus::MethodCall* method_call, const brillo::Error* error); // AddDBusError() is the opposite of GetDBusError(). It de-serializes the Error // object received over D-Bus. diff --git a/brillo/enum_flags.h b/brillo/enum_flags.h index 9630dd0..227cafd 100644 --- a/brillo/enum_flags.h +++ b/brillo/enum_flags.h @@ -57,7 +57,8 @@ template <typename T, typename = void> struct IsFlagEnum : std::false_type {}; template <typename T> -struct IsFlagEnum<T, Void<typename FlagEnumTraits<T>::EnumFlagType>> : std::true_type {}; +struct IsFlagEnum<T, Void<typename FlagEnumTraits<T>::EnumFlagType>> + : std::true_type {}; } // namespace enum_details @@ -68,7 +69,8 @@ struct IsFlagEnum<T, Void<typename FlagEnumTraits<T>::EnumFlagType>> : std::true template <typename T> constexpr typename std::enable_if<enum_details::IsFlagEnum<T>::value, T>::type operator~(const T& l) { - return static_cast<T>( ~static_cast<typename std::underlying_type<T>::type>(l)); + return static_cast<T>( + ~static_cast<typename std::underlying_type<T>::type>(l)); } // T operator|(T&, T&) @@ -91,37 +93,37 @@ operator&(const T& l, const T& r) { // T operator^(T&, T&) template <typename T> -constexpr typename std::enable_if<enum_details::IsFlagEnum<T>::value, T>::type operator^( - const T& l, const T& r) { +constexpr typename std::enable_if<enum_details::IsFlagEnum<T>::value, T>::type +operator^(const T& l, const T& r) { return static_cast<T>(static_cast<typename std::underlying_type<T>::type>(l) ^ static_cast<typename std::underlying_type<T>::type>(r)); -}; +} // T operator|=(T&, T&) template <typename T> -constexpr typename std::enable_if<enum_details::IsFlagEnum<T>::value, T>::type operator|=( - T& l, const T& r) { +constexpr typename std::enable_if<enum_details::IsFlagEnum<T>::value, T>::type +operator|=(T& l, const T& r) { return l = static_cast<T>( static_cast<typename std::underlying_type<T>::type>(l) | static_cast<typename std::underlying_type<T>::type>(r)); -}; +} // T operator&=(T&, T&) template <typename T> -constexpr typename std::enable_if<enum_details::IsFlagEnum<T>::value, T>::type operator&=( - T& l, const T& r) { +constexpr typename std::enable_if<enum_details::IsFlagEnum<T>::value, T>::type +operator&=(T& l, const T& r) { return l = static_cast<T>( static_cast<typename std::underlying_type<T>::type>(l) & static_cast<typename std::underlying_type<T>::type>(r)); -}; +} // T operator^=(T&, T&) template <typename T> -constexpr typename std::enable_if<enum_details::IsFlagEnum<T>::value, T>::type operator^=( - T& l, const T& r) { +constexpr typename std::enable_if<enum_details::IsFlagEnum<T>::value, T>::type +operator^=(T& l, const T& r) { return l = static_cast<T>( static_cast<typename std::underlying_type<T>::type>(l) ^ static_cast<typename std::underlying_type<T>::type>(r)); -}; +} #endif // LIBBRILLO_BRILLO_ENUM_FLAGS_H_ diff --git a/brillo/enum_flags_unittest.cc b/brillo/enum_flags_test.cc index e57b4ad..e57b4ad 100644 --- a/brillo/enum_flags_unittest.cc +++ b/brillo/enum_flags_test.cc diff --git a/brillo/errors/error.cc b/brillo/errors/error.cc index f229bd7..ccae1fa 100644 --- a/brillo/errors/error.cc +++ b/brillo/errors/error.cc @@ -4,6 +4,8 @@ #include <brillo/errors/error.h> +#include <utility> + #include <base/logging.h> #include <base/strings/stringprintf.h> @@ -19,16 +21,11 @@ inline void LogError(const base::Location& location, // the current error location with the location passed in to the Error object. // This way the log will contain the actual location of the error, and not // as if it always comes from brillo/errors/error.cc(22). - if (location.function_name() == nullptr) { - logging::LogMessage(location.file_name(), location.line_number(), - logging::LOG_ERROR) - .stream() - << "Domain=" << domain << ", Code=" << code << ", Message=" << message; - return; - } - logging::LogMessage( - location.file_name(), location.line_number(), logging::LOG_ERROR).stream() - << location.function_name() << "(...): " + logging::LogMessage(location.file_name(), location.line_number(), + logging::LOG_ERROR) + .stream() + << (location.function_name() ? location.function_name() : "unknown") + << "(...): " << "Domain=" << domain << ", Code=" << code << ", Message=" << message; } } // anonymous namespace diff --git a/brillo/errors/error.h b/brillo/errors/error.h index d08f0e7..1a6a91e 100644 --- a/brillo/errors/error.h +++ b/brillo/errors/error.h @@ -8,8 +8,8 @@ #include <memory> #include <string> -#include <base/macros.h> #include <base/location.h> +#include <base/macros.h> #include <brillo/brillo_export.h> namespace brillo { @@ -110,6 +110,7 @@ class BRILLO_EXPORT Error { // Human-readable error message. std::string message_; // Error origin in the source code. + // TODO(crbug.com/980935): Consider dropping this. base::Location location_; // Pointer to inner error, if any. This forms a chain of errors. ErrorPtr inner_error_; diff --git a/brillo/errors/error_codes.h b/brillo/errors/error_codes.h index 4f1bc09..664fb03 100644 --- a/brillo/errors/error_codes.h +++ b/brillo/errors/error_codes.h @@ -7,6 +7,7 @@ #include <string> +#include <base/location.h> #include <brillo/brillo_export.h> #include <brillo/errors/error.h> diff --git a/brillo/errors/error_codes_unittest.cc b/brillo/errors/error_codes_test.cc index 2baa28f..2baa28f 100644 --- a/brillo/errors/error_codes_unittest.cc +++ b/brillo/errors/error_codes_test.cc diff --git a/brillo/errors/error_unittest.cc b/brillo/errors/error_test.cc index 93f4372..7dd011e 100644 --- a/brillo/errors/error_unittest.cc +++ b/brillo/errors/error_test.cc @@ -4,6 +4,9 @@ #include <brillo/errors/error.h> +#include <utility> + +#include <base/location.h> #include <gtest/gtest.h> using brillo::Error; @@ -12,9 +15,9 @@ namespace { brillo::ErrorPtr GenerateNetworkError() { base::Location loc("GenerateNetworkError", - "error_unittest.cc", - 15, - ::base::GetProgramCounter()); + "error_test.cc", + 15, + ::base::GetProgramCounter()); return Error::Create(loc, "network", "not_found", "Resource not found"); } @@ -31,7 +34,7 @@ TEST(Error, Single) { EXPECT_EQ("not_found", err->GetCode()); EXPECT_EQ("Resource not found", err->GetMessage()); EXPECT_EQ("GenerateNetworkError", err->GetLocation().function_name()); - EXPECT_EQ("error_unittest.cc", err->GetLocation().file_name()); + EXPECT_EQ("error_test.cc", err->GetLocation().file_name()); EXPECT_EQ(15, err->GetLocation().line_number()); EXPECT_EQ(nullptr, err->GetInnerError()); EXPECT_TRUE(err->HasDomain("network")); @@ -73,7 +76,8 @@ TEST(Error, Clone) { EXPECT_EQ(error1->GetMessage(), error2->GetMessage()); EXPECT_EQ(error1->GetLocation().function_name(), error2->GetLocation().function_name()); - EXPECT_EQ(error1->GetLocation().file_name(), error2->GetLocation().file_name()); + EXPECT_EQ(error1->GetLocation().file_name(), + error2->GetLocation().file_name()); EXPECT_EQ(error1->GetLocation().line_number(), error2->GetLocation().line_number()); error1 = error1->GetInnerError(); diff --git a/brillo/file_utils.cc b/brillo/file_utils.cc index 8faa1b7..b4370d1 100644 --- a/brillo/file_utils.cc +++ b/brillo/file_utils.cc @@ -7,13 +7,17 @@ #include <fcntl.h> #include <unistd.h> +#include <limits> +#include <utility> +#include <vector> + #include <base/files/file_path.h> #include <base/files/file_util.h> -#include <base/files/scoped_file.h> #include <base/logging.h> #include <base/posix/eintr_wrapper.h> #include <base/rand_util.h> #include <base/strings/string_number_conversions.h> +#include <base/strings/stringprintf.h> #include <base/time/time.h> namespace brillo { @@ -25,7 +29,8 @@ constexpr const base::TimeDelta kLongSync = base::TimeDelta::FromSeconds(10); enum { kPermissions600 = S_IRUSR | S_IWUSR, - kPermissions777 = S_IRWXU | S_IRWXG | S_IRWXO + kPermissions777 = S_IRWXU | S_IRWXG | S_IRWXO, + kPermissions755 = S_IRWXU | S_IRGRP | S_IXGRP | S_IROTH | S_IXOTH }; // Verify that base file permission enums are compatible with S_Ixxx. If these @@ -150,6 +155,80 @@ std::string GetRandomSuffix() { return suffix; } +base::ScopedFD OpenPathComponentInternal(int parent_fd, + const std::string& file, + int flags, + mode_t mode) { + DCHECK(file == "/" || file.find("/") == std::string::npos); + base::ScopedFD fd; + + // O_NONBLOCK is used to avoid hanging on edge cases (e.g. a serial port with + // flow control, or a FIFO without a writer). + if (parent_fd >= 0 || parent_fd == AT_FDCWD) { + fd.reset(HANDLE_EINTR(openat(parent_fd, file.c_str(), + flags | O_NONBLOCK | O_NOFOLLOW | O_CLOEXEC, + mode))); + } else if (file == "/") { + fd.reset(HANDLE_EINTR(open( + file.c_str(), + flags | O_RDONLY | O_DIRECTORY | O_NONBLOCK | O_NOFOLLOW | O_CLOEXEC, + mode))); + } + + if (!fd.is_valid()) { + // open(2) fails with ELOOP when the last component of the |path| is a + // symlink. It fails with ENXIO when |path| is a FIFO and |flags| is for + // writing because of the O_NONBLOCK flag added above. + if (errno == ELOOP || errno == ENXIO) { + PLOG(WARNING) << "Failed to open " << file << " safely."; + } else { + PLOG(WARNING) << "Failed to open " << file << "."; + } + return base::ScopedFD(); + } + + // Remove the O_NONBLOCK flag unless the original |flags| have it. + if ((flags & O_NONBLOCK) == 0) { + flags = fcntl(fd.get(), F_GETFL); + if (flags == -1) { + PLOG(ERROR) << "Failed to get fd flags for " << file; + return base::ScopedFD(); + } + if (fcntl(fd.get(), F_SETFL, flags & ~O_NONBLOCK)) { + PLOG(ERROR) << "Failed to set fd flags for " << file; + return base::ScopedFD(); + } + } + + return fd; +} + +base::ScopedFD OpenSafelyInternal(int parent_fd, + const base::FilePath& path, + int flags, + mode_t mode) { + std::vector<std::string> components; + path.GetComponents(&components); + + auto itr = components.begin(); + if (itr == components.end()) { + LOG(ERROR) << "A path is required."; + return base::ScopedFD(); // This is an invalid fd. + } + + base::ScopedFD child_fd; + int parent_flags = flags | O_NONBLOCK | O_RDONLY | O_DIRECTORY | O_PATH; + for (; itr + 1 != components.end(); ++itr) { + child_fd = OpenPathComponentInternal(parent_fd, *itr, parent_flags, 0); + if (!child_fd.is_valid()) { + return base::ScopedFD(); + } + parent_fd = child_fd.get(); + } + + return OpenPathComponentInternal(parent_fd, *itr, flags, mode); +} + } // namespace bool TouchFile(const base::FilePath& path, @@ -184,9 +263,129 @@ bool TouchFile(const base::FilePath& path) { return TouchFile(path, kPermissions600, geteuid(), getegid()); } -bool WriteBlobToFile(const base::FilePath& path, const Blob& blob) { - return WriteToFile(path, reinterpret_cast<const char*>(blob.data()), - blob.size()); +base::ScopedFD OpenSafely(const base::FilePath& path, int flags, mode_t mode) { + if (!path.IsAbsolute()) { + LOG(ERROR) << "An absolute path is required."; + return base::ScopedFD(); // This is an invalid fd. + } + + base::ScopedFD fd(OpenSafelyInternal(-1, path, flags, mode)); + if (!fd.is_valid()) + return base::ScopedFD(); + + // Ensure the opened file is a regular file or directory. + struct stat st; + if (fstat(fd.get(), &st) < 0) { + PLOG(ERROR) << "Failed to fstat " << path.value(); + return base::ScopedFD(); + } + + // This detects a FIFO opened for reading, for example. + if (flags & O_DIRECTORY) { + if (!S_ISDIR(st.st_mode)) { + LOG(ERROR) << path.value() << " is not a directory: " << st.st_mode; + return base::ScopedFD(); + } + } else if (!S_ISREG(st.st_mode) && !S_ISDIR(st.st_mode)) { + LOG(ERROR) << path.value() + << " is not a regular file or directory: " << st.st_mode; + return base::ScopedFD(); + } + + return fd; +} + +base::ScopedFD OpenAtSafely(int parent_fd, + const base::FilePath& path, + int flags, + mode_t mode) { + base::ScopedFD fd(OpenSafelyInternal(parent_fd, path, flags, mode)); + if (!fd.is_valid()) + return base::ScopedFD(); + + // Ensure the opened file is a regular file or directory. + struct stat st; + if (fstat(fd.get(), &st) < 0) { + PLOG(ERROR) << "Failed to fstat " << path.value(); + return base::ScopedFD(); + } + + // This detects a FIFO opened for reading, for example. + if (flags & O_DIRECTORY) { + if (!S_ISDIR(st.st_mode)) { + LOG(ERROR) << path.value() << " is not a directory: " << st.st_mode; + return base::ScopedFD(); + } + } else if (!S_ISREG(st.st_mode)) { + LOG(ERROR) << path.value() << " is not a regular file: " << st.st_mode; + return base::ScopedFD(); + } + + return fd; +} + +base::ScopedFD OpenFifoSafely(const base::FilePath& path, + int flags, + mode_t mode) { + if (!path.IsAbsolute()) { + LOG(ERROR) << "An absolute path is required."; + return base::ScopedFD(); // This is an invalid fd. + } + + base::ScopedFD fd(OpenSafelyInternal(-1, path, flags, mode)); + if (!fd.is_valid()) + return base::ScopedFD(); + + // Ensure the opened file is a FIFO. + struct stat st; + if (fstat(fd.get(), &st) < 0) { + PLOG(ERROR) << "Failed to fstat " << path.value(); + return base::ScopedFD(); + } + + if (!S_ISFIFO(st.st_mode)) { + LOG(ERROR) << path.value() << " is not a FIFO: " << st.st_mode; + return base::ScopedFD(); + } + + return fd; +} + +base::ScopedFD MkdirRecursively(const base::FilePath& full_path, mode_t mode) { + std::vector<std::string> components; + full_path.GetComponents(&components); + + auto itr = components.begin(); + if (!full_path.IsAbsolute() || itr == components.end()) { + LOG(ERROR) << "An absolute path is required."; + return base::ScopedFD(); // This is an invalid fd. + } + + base::ScopedFD parent_fd; + int parent_flags = O_NONBLOCK | O_RDONLY | O_DIRECTORY | O_PATH; + while (itr + 1 != components.end()) { + base::ScopedFD child( + OpenPathComponentInternal(parent_fd.get(), *itr, parent_flags, 0)); + if (!child.is_valid()) { + return base::ScopedFD(); + } + parent_fd = std::move(child); + + ++itr; + + // Try to create the directory. Note that Chromium's MkdirRecursively() uses + // 0700, but we use 0755. + if (mkdirat(parent_fd.get(), itr->c_str(), mode) != 0) { + if (errno != EEXIST) { + PLOG(ERROR) << "Failed to mkdirat " << *itr + << ": full_path=" << full_path.value(); + return base::ScopedFD(); + } + } + } + + return OpenPathComponentInternal(parent_fd.get(), *itr, + O_RDONLY | O_DIRECTORY, 0); } bool WriteStringToFile(const base::FilePath& path, const std::string& data) { @@ -306,11 +505,4 @@ bool WriteToFileAtomic(const base::FilePath& path, return true; } -bool WriteBlobToFileAtomic(const base::FilePath& path, - const Blob& blob, - mode_t mode) { - return WriteToFileAtomic(path, reinterpret_cast<const char*>(blob.data()), - blob.size(), mode); -} - } // namespace brillo diff --git a/brillo/file_utils.h b/brillo/file_utils.h index 663d640..3862a43 100644 --- a/brillo/file_utils.h +++ b/brillo/file_utils.h @@ -7,7 +7,10 @@ #include <sys/types.h> +#include <string> + #include <base/files/file_path.h> +#include <base/files/scoped_file.h> #include <brillo/brillo_export.h> #include <brillo/secure_blob.h> @@ -31,6 +34,62 @@ BRILLO_EXPORT bool TouchFile(const base::FilePath& path, // bit set. BRILLO_EXPORT bool TouchFile(const base::FilePath& path); +// Opens the absolute |path| to a regular file or directory ensuring that none +// of the path components are symbolic links and returns a FD. If |path| is +// relative, or contains any symbolic links, or points to a non-regular file or +// directory, an invalid FD is returned instead. |mode| is ignored unless +// |flags| has either O_CREAT or O_TMPFILE. Note that O_CLOEXEC is set so the +// file descriptor will not be inherited across exec calls. +// +// Parameters +// path - An absolute path of the file to open +// flags - Flags to pass to open. +// mode - Mode to pass to open. +BRILLO_EXPORT base::ScopedFD OpenSafely(const base::FilePath& path, + int flags, + mode_t mode); + +// Opens the |path| relative to the |parent_fd| to a regular file or directory +// ensuring that none of the path components are symbolic links and returns a +// FD. If |path| contains any symbolic links, or points to a non-regular file or +// directory, an invalid FD is returned instead. |mode| is ignored unless +// |flags| has either O_CREAT or O_TMPFILE. Note that O_CLOEXEC is set so the +// file descriptor will not be inherited across exec calls. +// +// Parameters +// parent_fd - The file descriptor of the parent directory +// path - An absolute path of the file to open +// flags - Flags to pass to open. +// mode - Mode to pass to open. +BRILLO_EXPORT base::ScopedFD OpenAtSafely(int parent_fd, + const base::FilePath& path, + int flags, + mode_t mode); + +// Opens the absolute |path| to a FIFO ensuring that none of the path components +// are symbolic links and returns a FD. If |path| is relative, or contains any +// symbolic links, or points to a non-regular file or directory, an invalid FD +// is returned instead. |mode| is ignored unless |flags| has either O_CREAT or +// O_TMPFILE. +// +// Parameters +// path - An absolute path of the file to open +// flags - Flags to pass to open. +// mode - Mode to pass to open. +BRILLO_EXPORT base::ScopedFD OpenFifoSafely(const base::FilePath& path, + int flags, + mode_t mode); + +// Iterates through the path components and creates any missing ones. Guarantees +// the ancestor paths are not symlinks. This function returns an invalid FD on +// failure. Newly created directories will have |mode| permissions. The returned +// file descriptor was opened with both O_RDONLY and O_CLOEXEC. +// +// Parameters +// full_path - An absolute path of the directory to create and open. +BRILLO_EXPORT base::ScopedFD MkdirRecursively(const base::FilePath& full_path, + mode_t mode); + // Writes the entirety of the given data to |path| with 0640 permissions // (modulo umask). If missing, parent (and parent of parent etc.) directories // are created with 0700 permissions (modulo umask). Returns true on success. @@ -39,13 +98,16 @@ BRILLO_EXPORT bool TouchFile(const base::FilePath& path); // path - Path of the file to write // blob/data - blob/string/array to populate from // (size - array size) -BRILLO_EXPORT bool WriteBlobToFile(const base::FilePath& path, - const Blob& blob); BRILLO_EXPORT bool WriteStringToFile(const base::FilePath& path, const std::string& data); BRILLO_EXPORT bool WriteToFile(const base::FilePath& path, const char* data, size_t size); +template <class T> +BRILLO_EXPORT bool WriteBlobToFile(const base::FilePath& path, const T& blob) { + return WriteToFile(path, reinterpret_cast<const char*>(blob.data()), + blob.size()); +} // Calls fdatasync() on file if data_sync is true or fsync() on directory or // file when data_sync is false. Returns true on success. @@ -70,13 +132,17 @@ BRILLO_EXPORT bool SyncFileOrDirectory(const base::FilePath& path, // blob/data - blob/array to populate from // (size - array size) // mode - File permission bit-pattern, eg. 0644 for rw-r--r-- -BRILLO_EXPORT bool WriteBlobToFileAtomic(const base::FilePath& path, - const Blob& blob, - mode_t mode); BRILLO_EXPORT bool WriteToFileAtomic(const base::FilePath& path, const char* data, size_t size, mode_t mode); +template <class T> +BRILLO_EXPORT bool WriteBlobToFileAtomic(const base::FilePath& path, + const T& blob, + mode_t mode) { + return WriteToFileAtomic(path, reinterpret_cast<const char*>(blob.data()), + blob.size(), mode); +} } // namespace brillo diff --git a/brillo/file_utils_unittest.cc b/brillo/file_utils_test.cc index 7a730f0..9a1f646 100644 --- a/brillo/file_utils_unittest.cc +++ b/brillo/file_utils_test.cc @@ -4,6 +4,7 @@ #include "brillo/file_utils.h" +#include <fcntl.h> #include <sys/stat.h> #include <unistd.h> @@ -23,6 +24,7 @@ constexpr int kPermissions600 = base::FILE_PERMISSION_READ_BY_USER | base::FILE_PERMISSION_WRITE_BY_USER; constexpr int kPermissions700 = base::FILE_PERMISSION_USER_MASK; constexpr int kPermissions777 = base::FILE_PERMISSION_MASK; +constexpr int kPermissions755 = S_IRWXU | S_IRGRP | S_IXGRP | S_IROTH | S_IXOTH; std::string GetRandomSuffix() { const int kBufferSize = 6; @@ -31,6 +33,10 @@ std::string GetRandomSuffix() { return base::HexEncode(buffer, arraysize(buffer)); } +bool IsNonBlockingFD(int fd) { + return fcntl(fd, F_GETFL) & O_NONBLOCK; +} + } // namespace class FileUtilsTest : public testing::Test { @@ -144,6 +150,101 @@ TEST_F(FileUtilsTest, TouchFileExistingPermissionsUnchanged) { ExpectFilePermissions(kPermissions777); } +// Other parts of OpenSafely are tested in Arcsetup.TestInstallDirectory*. +TEST_F(FileUtilsTest, TestOpenSafelyWithoutNonblocking) { + ASSERT_TRUE(TouchFile(file_path_, kPermissions700, geteuid(), getegid())); + base::ScopedFD fd(OpenSafely(file_path_, O_RDONLY, 0)); + EXPECT_TRUE(fd.is_valid()); + EXPECT_FALSE(IsNonBlockingFD(fd.get())); +} + +TEST_F(FileUtilsTest, TestOpenSafelyWithNonblocking) { + ASSERT_TRUE(TouchFile(file_path_, kPermissions700, geteuid(), getegid())); + base::ScopedFD fd = OpenSafely(file_path_, O_RDONLY | O_NONBLOCK, 0); + EXPECT_TRUE(fd.is_valid()); + EXPECT_TRUE(IsNonBlockingFD(fd.get())); +} + +TEST_F(FileUtilsTest, TestOpenFifoSafelySuccess) { + ASSERT_EQ(0, mkfifo(file_path_.value().c_str(), kPermissions700)); + base::ScopedFD fd(OpenFifoSafely(file_path_, O_RDONLY, 0)); + EXPECT_TRUE(fd.is_valid()); + EXPECT_FALSE(IsNonBlockingFD(fd.get())); +} + +TEST_F(FileUtilsTest, TestOpenFifoSafelyRegularFile) { + ASSERT_TRUE(TouchFile(file_path_, kPermissions700, geteuid(), getegid())); + base::ScopedFD fd = OpenFifoSafely(file_path_, O_RDONLY, 0); + EXPECT_FALSE(fd.is_valid()); +} + +TEST_F(FileUtilsTest, TestMkdirRecursivelyRoot) { + // Try to create an existing directory ("/") should still succeed. + EXPECT_TRUE( + MkdirRecursively(base::FilePath("/"), kPermissions755).is_valid()); +} + +TEST_F(FileUtilsTest, TestMkdirRecursivelySuccess) { + // Set |temp_directory| to 0707. + EXPECT_TRUE(base::SetPosixFilePermissions(temp_dir_.GetPath(), 0707)); + + EXPECT_TRUE( + MkdirRecursively(temp_dir_.GetPath().Append("a/b/c"), kPermissions755) + .is_valid()); + // Confirm the 3 directories are there. + EXPECT_TRUE(base::DirectoryExists(temp_dir_.GetPath().Append("a"))); + EXPECT_TRUE(base::DirectoryExists(temp_dir_.GetPath().Append("a/b"))); + EXPECT_TRUE(base::DirectoryExists(temp_dir_.GetPath().Append("a/b/c"))); + + // Confirm that the newly created directories have 0755 mode. + int mode = 0; + EXPECT_TRUE( + base::GetPosixFilePermissions(temp_dir_.GetPath().Append("a"), &mode)); + EXPECT_EQ(kPermissions755, mode); + mode = 0; + EXPECT_TRUE( + base::GetPosixFilePermissions(temp_dir_.GetPath().Append("a/b"), &mode)); + EXPECT_EQ(kPermissions755, mode); + mode = 0; + EXPECT_TRUE(base::GetPosixFilePermissions(temp_dir_.GetPath().Append("a/b/c"), + &mode)); + EXPECT_EQ(kPermissions755, mode); + + // Confirm that the existing directory |temp_directory| still has 0707 mode. + mode = 0; + EXPECT_TRUE(base::GetPosixFilePermissions(temp_dir_.GetPath(), &mode)); + EXPECT_EQ(0707, mode); + + // Call the API again which should still succeed. + EXPECT_TRUE( + MkdirRecursively(temp_dir_.GetPath().Append("a/b/c"), kPermissions755) + .is_valid()); + EXPECT_TRUE( + MkdirRecursively(temp_dir_.GetPath().Append("a/b/c/d"), kPermissions755) + .is_valid()); + EXPECT_TRUE(base::DirectoryExists(temp_dir_.GetPath().Append("a/b/c/d"))); + mode = 0; + EXPECT_TRUE(base::GetPosixFilePermissions( + temp_dir_.GetPath().Append("a/b/c/d"), &mode)); + EXPECT_EQ(kPermissions755, mode); + + // Call the API again which should still succeed. + EXPECT_TRUE( + MkdirRecursively(temp_dir_.GetPath().Append("a/b"), kPermissions755) + .is_valid()); + EXPECT_TRUE(MkdirRecursively(temp_dir_.GetPath().Append("a"), kPermissions755) + .is_valid()); +} + +TEST_F(FileUtilsTest, TestMkdirRecursivelyRelativePath) { + // Try to pass a relative or empty directory. They should all fail. + EXPECT_FALSE( + MkdirRecursively(base::FilePath("foo"), kPermissions755).is_valid()); + EXPECT_FALSE( + MkdirRecursively(base::FilePath("bar/"), kPermissions755).is_valid()); + EXPECT_FALSE(MkdirRecursively(base::FilePath(), kPermissions755).is_valid()); +} + TEST_F(FileUtilsTest, WriteFileCanBeReadBack) { const base::FilePath filename(GetTempName()); const std::string content("blablabla"); diff --git a/brillo/files/OWNERS b/brillo/files/OWNERS new file mode 100644 index 0000000..da09356 --- /dev/null +++ b/brillo/files/OWNERS @@ -0,0 +1,3 @@ +allenwebb@chromium.org +jorgelo@chromium.org +mnissler@chromium.org diff --git a/brillo/files/file_util.cc b/brillo/files/file_util.cc new file mode 100644 index 0000000..c642d14 --- /dev/null +++ b/brillo/files/file_util.cc @@ -0,0 +1,112 @@ +// Copyright 2019 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "brillo/files/file_util.h" + +#include <fcntl.h> +#include <sys/stat.h> +#include <unistd.h> + +#include <utility> + +#include <base/files/file_util.h> +#include <base/logging.h> +#include <base/strings/stringprintf.h> +#include <brillo/syslog_logging.h> + +namespace brillo { + +namespace { + +enum class FSObjectType { + RegularFile = 0, + Directory, +}; + +SafeFD::SafeFDResult OpenOrRemake(SafeFD* parent, + const std::string& name, + FSObjectType type, + int permissions, + uid_t uid, + gid_t gid, + int flags) { + SafeFD::Error err = IsValidFilename(name); + if (SafeFD::IsError(err)) { + return std::make_pair(SafeFD(), err); + } + + SafeFD::SafeFDResult (SafeFD::*maker)(const base::FilePath&, mode_t, uid_t, + gid_t, int); + if (type == FSObjectType::Directory) { + maker = &SafeFD::MakeDir; + } else { + maker = &SafeFD::MakeFile; + } + + SafeFD child; + std::tie(child, err) = + (parent->*maker)(base::FilePath(name), permissions, uid, gid, flags); + if (child.is_valid()) { + return std::make_pair(std::move(child), err); + } + + // Rmdir should be used on directories. However, kWrongType indicates when + // a directory was expected and a non-directory was found or when a + // directory was found but not expected, so XOR was used. + if ((type == FSObjectType::Directory) ^ (err == SafeFD::Error::kWrongType)) { + err = parent->Rmdir(name, true /*recursive*/); + } else { + err = parent->Unlink(name); + } + if (SafeFD::IsError(err)) { + PLOG(ERROR) << "Failed to clean up \"" << name << "\""; + return std::make_pair(SafeFD(), err); + } + + std::tie(child, err) = + (parent->*maker)(base::FilePath(name), permissions, uid, gid, flags); + return std::make_pair(std::move(child), err); +} + +} // namespace + +SafeFD::Error IsValidFilename(const std::string& filename) { + if (filename == "." || filename == ".." || + filename.find("/") != std::string::npos) { + return SafeFD::Error::kBadArgument; + } + return SafeFD::Error::kNoError; +} + +base::FilePath GetFDPath(int fd) { + const base::FilePath proc_fd(base::StringPrintf("/proc/self/fd/%d", fd)); + base::FilePath resolved; + if (!base::ReadSymbolicLink(proc_fd, &resolved)) { + LOG(ERROR) << "Failed to read " << proc_fd.value(); + return base::FilePath(); + } + return resolved; +} + +SafeFD::SafeFDResult OpenOrRemakeDir(SafeFD* parent, + const std::string& name, + int permissions, + uid_t uid, + gid_t gid, + int flags) { + return OpenOrRemake(parent, name, FSObjectType::Directory, permissions, uid, + gid, flags); +} + +SafeFD::SafeFDResult OpenOrRemakeFile(SafeFD* parent, + const std::string& name, + int permissions, + uid_t uid, + gid_t gid, + int flags) { + return OpenOrRemake(parent, name, FSObjectType::RegularFile, permissions, uid, + gid, flags); +} + +} // namespace brillo diff --git a/brillo/files/file_util.h b/brillo/files/file_util.h new file mode 100644 index 0000000..c020667 --- /dev/null +++ b/brillo/files/file_util.h @@ -0,0 +1,62 @@ +// Copyright 2019 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// Filesystem-related utility functions. + +#ifndef LIBBRILLO_BRILLO_FILES_FILE_UTIL_H_ +#define LIBBRILLO_BRILLO_FILES_FILE_UTIL_H_ + +#include <string> + +#include <brillo/files/safe_fd.h> + +namespace brillo { + +SafeFD::Error IsValidFilename(const std::string& filename); + +// Obtain the canonical path of the file descriptor or base::FilePath() on +// failure. +BRILLO_EXPORT base::FilePath GetFDPath(int fd); + +// Open or create a child directory named |name| as a child of |parent| with +// the specified permissions and ownership. Custom open flags can be set with +// |flags|. The directory will be re-created if: +// * The open operation fails (e.g. if |name| is not a directory). +// * The permissions do not match. +// * The ownership is different. +// +// Parameters +// parent - An open SafeFD to the parent directory. +// name - the name of the directory being created. It cannot have more than one +// path component. +BRILLO_EXPORT SafeFD::SafeFDResult OpenOrRemakeDir( + SafeFD* parent, + const std::string& name, + int permissions = SafeFD::kDefaultDirPermissions, + uid_t uid = getuid(), + gid_t gid = getgid(), + int flags = O_RDONLY | O_CLOEXEC); + +// Open or create a file named |name| under the directory |parent| with +// the specified permissions and ownership. Custom open flags can be set with +// |flags|. The file will be re-created if: +// * The open operation fails (e.g. |name| is a directory). +// * The permissions do not match. +// * The ownership is different. +// +// Parameters +// parent - An open SafeFD to the parent directory. +// name - the name of the file being created. It cannot have more than one +// path component. +BRILLO_EXPORT SafeFD::SafeFDResult OpenOrRemakeFile( + SafeFD* parent, + const std::string& name, + int permissions = SafeFD::kDefaultFilePermissions, + uid_t uid = getuid(), + gid_t gid = getgid(), + int flags = O_RDWR | O_CLOEXEC); + +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_FILES_FILE_UTIL_H_ diff --git a/brillo/files/file_util_test.cc b/brillo/files/file_util_test.cc new file mode 100644 index 0000000..98d28fc --- /dev/null +++ b/brillo/files/file_util_test.cc @@ -0,0 +1,283 @@ +// Copyright 2019 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "brillo/files/file_util_test.h" + +#include <base/files/file_util.h> +#include <base/rand_util.h> +#include <base/strings/string_number_conversions.h> +#include <brillo/files/file_util.h> +#include <brillo/files/safe_fd.h> + +namespace brillo { + +#define TO_STRING_HELPER(x) \ + case brillo::SafeFD::Error::x: \ + return #x; +std::string to_string(brillo::SafeFD::Error err) { + switch (err) { + TO_STRING_HELPER(kNoError) + TO_STRING_HELPER(kBadArgument) + TO_STRING_HELPER(kNotInitialized) + TO_STRING_HELPER(kIOError) + TO_STRING_HELPER(kDoesNotExist) + TO_STRING_HELPER(kSymlinkDetected) + TO_STRING_HELPER(kWrongType) + TO_STRING_HELPER(kWrongUID) + TO_STRING_HELPER(kWrongGID) + TO_STRING_HELPER(kWrongPermissions) + TO_STRING_HELPER(kExceededMaximum) + default: + return std::string("unknown (") + std::to_string(static_cast<int>(err)) + + ")"; + } +} +#undef TO_STRING_HELPER + +std::ostream& operator<<(std::ostream& os, const brillo::SafeFD::Error err) { + return os << to_string(err); // whatever needed to print bar to os +} + +std::string GetRandomSuffix() { + const int kBufferSize = 6; + unsigned char buffer[kBufferSize]; + base::RandBytes(buffer, arraysize(buffer)); + return base::HexEncode(buffer, arraysize(buffer)); +} + +void FileTest::SetUpTestCase() { + umask(0); +} + +FileTest::FileTest() { + CHECK(temp_dir_.CreateUniqueTempDir()) << strerror(errno); + sub_dir_path_ = temp_dir_.GetPath().Append(kSubdirName); + file_path_ = sub_dir_path_.Append(kFileName); + + std::string path = temp_dir_.GetPath().value(); + temp_dir_path_.reserve(path.size() + 1); + temp_dir_path_.assign(temp_dir_.GetPath().value().begin(), + temp_dir_.GetPath().value().end()); + temp_dir_path_.push_back('\0'); + + CHECK_EQ(chmod(temp_dir_path_.data(), SafeFD::kDefaultDirPermissions), 0); + SafeFD::SetRootPathForTesting(temp_dir_path_.data()); + root_ = SafeFD::Root().first; + CHECK(root_.is_valid()); +} + +bool FileTest::SetupSubdir() { + if (!base::CreateDirectory(sub_dir_path_)) { + PLOG(ERROR) << "Failed to create '" << sub_dir_path_.value() << "'"; + return false; + } + if (chmod(sub_dir_path_.value().c_str(), SafeFD::kDefaultDirPermissions) != + 0) { + PLOG(ERROR) << "Failed to set permissions of '" << sub_dir_path_.value() + << "'"; + return false; + } + return true; +} + +bool FileTest::SetupSymlinks() { + symlink_file_path_ = temp_dir_.GetPath().Append(kSymbolicFileName); + symlink_dir_path_ = temp_dir_.GetPath().Append(kSymbolicDirName); + if (!base::CreateSymbolicLink(file_path_, symlink_file_path_)) { + PLOG(ERROR) << "Failed to create symlink to '" << symlink_file_path_.value() + << "'"; + return false; + } + if (!base::CreateSymbolicLink(temp_dir_.GetPath(), symlink_dir_path_)) { + PLOG(ERROR) << "Failed to create symlink to'" << symlink_dir_path_.value() + << "'"; + return false; + } + return true; +} + +bool FileTest::WriteFile(const std::string& contents) { + if (!SetupSubdir()) { + return false; + } + if (contents.length() != + base::WriteFile(file_path_, contents.c_str(), contents.length())) { + PLOG(ERROR) << "base::WriteFile failed"; + return false; + } + if (chmod(file_path_.value().c_str(), SafeFD::kDefaultFilePermissions) != 0) { + PLOG(ERROR) << "chmod failed"; + return false; + } + return true; +} + +void FileTest::ExpectFileContains(const std::string& contents) { + EXPECT_TRUE(base::PathExists(file_path_)); + std::string new_contents; + EXPECT_TRUE(base::ReadFileToString(file_path_, &new_contents)); + EXPECT_EQ(contents, new_contents); +} + +void FileTest::ExpectPermissions(base::FilePath path, int permissions) { + int actual_permissions = 0; + // This breaks out of the ExpectPermissions() call but not the test case. + ASSERT_TRUE(base::GetPosixFilePermissions(path, &actual_permissions)); + EXPECT_EQ(permissions, actual_permissions); +} + +// Creates a file with a random name in the temporary directory. +base::FilePath FileTest::GetTempName() { + return temp_dir_.GetPath().Append(GetRandomSuffix()); +} + +constexpr char FileTest::kFileName[]; +constexpr char FileTest::kSubdirName[]; +constexpr char FileTest::kSymbolicFileName[]; +constexpr char FileTest::kSymbolicDirName[]; + +class FileUtilTest : public FileTest {}; + +TEST_F(FileUtilTest, GetFDPath_SimpleSuccess) { + EXPECT_EQ(GetFDPath(root_.get()), temp_dir_.GetPath()); +} + +TEST_F(FileUtilTest, GetFDPath_BadFD) { + base::FilePath path = GetFDPath(-1); + EXPECT_TRUE(path.empty()); +} + +TEST_F(FileUtilTest, OpenOrRemakeDir_SimpleSuccess) { + SafeFD::Error err; + SafeFD dir; + + std::tie(dir, err) = root_.OpenExistingDir(temp_dir_.GetPath()); + EXPECT_EQ(err, SafeFD::Error::kNoError); + ASSERT_TRUE(dir.is_valid()); + + SafeFD subdir; + std::tie(subdir, err) = OpenOrRemakeDir(&dir, kSubdirName); + EXPECT_EQ(err, SafeFD::Error::kNoError); + EXPECT_TRUE(subdir.is_valid()); +} + +TEST_F(FileUtilTest, OpenOrRemakeDir_SuccessAfterRetry) { + ASSERT_NE(base::WriteFile(sub_dir_path_, "", 0), -1); + SafeFD::Error err; + SafeFD dir; + + std::tie(dir, err) = root_.OpenExistingDir(temp_dir_.GetPath()); + EXPECT_EQ(err, SafeFD::Error::kNoError); + ASSERT_TRUE(dir.is_valid()); + + SafeFD subdir; + std::tie(subdir, err) = OpenOrRemakeDir(&dir, kSubdirName); + EXPECT_EQ(err, SafeFD::Error::kNoError); + EXPECT_TRUE(subdir.is_valid()); +} + +TEST_F(FileUtilTest, OpenOrRemakeDir_BadArgument) { + SafeFD::Error err; + SafeFD dir; + + std::tie(dir, err) = root_.OpenExistingDir(temp_dir_.GetPath()); + EXPECT_EQ(err, SafeFD::Error::kNoError); + ASSERT_TRUE(dir.is_valid()); + + SafeFD subdir; + std::tie(subdir, err) = OpenOrRemakeDir(&dir, "."); + EXPECT_EQ(err, SafeFD::Error::kBadArgument); + EXPECT_FALSE(subdir.is_valid()); + std::tie(subdir, err) = OpenOrRemakeDir(&dir, ".."); + EXPECT_EQ(err, SafeFD::Error::kBadArgument); + EXPECT_FALSE(subdir.is_valid()); + std::tie(subdir, err) = OpenOrRemakeDir(&dir, "a/a"); + EXPECT_EQ(err, SafeFD::Error::kBadArgument); + EXPECT_FALSE(subdir.is_valid()); +} + +TEST_F(FileUtilTest, OpenOrRemakeDir_NotInitialized) { + SafeFD::Error err; + SafeFD dir; + + SafeFD subdir; + std::tie(subdir, err) = OpenOrRemakeDir(&dir, kSubdirName); + EXPECT_EQ(err, SafeFD::Error::kNotInitialized); + EXPECT_FALSE(subdir.is_valid()); +} + +TEST_F(FileUtilTest, OpenOrRemakeDir_IOError) { + SafeFD::Error err; + SafeFD dir; + + std::tie(dir, err) = root_.OpenExistingDir(temp_dir_.GetPath()); + EXPECT_EQ(err, SafeFD::Error::kNoError); + ASSERT_TRUE(dir.is_valid()); + ASSERT_EQ(chmod(temp_dir_path_.data(), 0000), 0); + + SafeFD subdir; + std::tie(subdir, err) = OpenOrRemakeDir(&dir, kSubdirName); + EXPECT_EQ(err, SafeFD::Error::kIOError); + EXPECT_FALSE(subdir.is_valid()); +} + +TEST_F(FileUtilTest, OpenOrRemakeFile_SimpleSuccess) { + ASSERT_TRUE(SetupSubdir()); + SafeFD::Error err; + SafeFD dir; + + std::tie(dir, err) = root_.OpenExistingDir(sub_dir_path_); + EXPECT_EQ(err, SafeFD::Error::kNoError); + ASSERT_TRUE(dir.is_valid()); + + SafeFD file; + std::tie(file, err) = OpenOrRemakeFile(&dir, kFileName); + EXPECT_EQ(err, SafeFD::Error::kNoError); + EXPECT_TRUE(file.is_valid()); +} + +TEST_F(FileUtilTest, OpenOrRemakeFile_SuccessAfterRetry) { + ASSERT_TRUE(SetupSubdir()); + ASSERT_TRUE(base::CreateDirectory(file_path_)); + SafeFD::Error err; + SafeFD dir; + + std::tie(dir, err) = root_.OpenExistingDir(sub_dir_path_); + EXPECT_EQ(err, SafeFD::Error::kNoError); + ASSERT_TRUE(dir.is_valid()); + + SafeFD file; + std::tie(file, err) = OpenOrRemakeFile(&dir, kFileName); + EXPECT_EQ(err, SafeFD::Error::kNoError); + EXPECT_TRUE(file.is_valid()); +} + +TEST_F(FileUtilTest, OpenOrRemakeFile_NotInitialized) { + ASSERT_TRUE(SetupSubdir()); + SafeFD::Error err; + SafeFD dir; + + SafeFD file; + std::tie(file, err) = OpenOrRemakeFile(&dir, kFileName); + EXPECT_EQ(err, SafeFD::Error::kNotInitialized); + EXPECT_FALSE(file.is_valid()); +} + +TEST_F(FileUtilTest, OpenOrRemakeFile_IOError) { + ASSERT_TRUE(SetupSubdir()); + SafeFD::Error err; + SafeFD dir; + + std::tie(dir, err) = root_.OpenExistingDir(sub_dir_path_); + EXPECT_EQ(err, SafeFD::Error::kNoError); + ASSERT_TRUE(dir.is_valid()); + ASSERT_EQ(chmod(sub_dir_path_.value().c_str(), 0000), 0); + + SafeFD file; + std::tie(file, err) = OpenOrRemakeFile(&dir, kFileName); + EXPECT_EQ(err, SafeFD::Error::kIOError); + EXPECT_FALSE(file.is_valid()); +} + +} // namespace brillo diff --git a/brillo/files/file_util_test.h b/brillo/files/file_util_test.h new file mode 100644 index 0000000..182cdf4 --- /dev/null +++ b/brillo/files/file_util_test.h @@ -0,0 +1,70 @@ +// Copyright 2019 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// Filesystem-related utility functions. + +#ifndef LIBBRILLO_BRILLO_FILES_FILE_UTIL_TEST_H_ +#define LIBBRILLO_BRILLO_FILES_FILE_UTIL_TEST_H_ + +#include <string> +#include <vector> + +#include <base/files/scoped_temp_dir.h> +#include <brillo/files/safe_fd.h> +#include <gtest/gtest.h> + +namespace brillo { + +// Convert the SafeFD::Error enum class to a string for readability of +// test results. +std::string to_string(brillo::SafeFD::Error err); + +// Helper to enable gtest to print SafeFD::Error results in a way that is easier +// to read. +std::ostream& operator<<(std::ostream& os, const brillo::SafeFD::Error err); + +// Gets a short random string that can be used as part of a file name. +std::string GetRandomSuffix(); + +class FileTest : public testing::Test { + public: + static constexpr char kFileName[] = "test.temp"; + static constexpr char kSubdirName[] = "test_dir"; + static constexpr char kSymbolicFileName[] = "sym_test.temp"; + static constexpr char kSymbolicDirName[] = "sym_dir"; + + static void SetUpTestCase(); + + FileTest(); + + protected: + std::vector<char> temp_dir_path_; + base::FilePath file_path_; + base::FilePath sub_dir_path_; + base::FilePath symlink_file_path_; + base::FilePath symlink_dir_path_; + base::ScopedTempDir temp_dir_; + SafeFD root_; + + bool SetupSubdir() WARN_UNUSED_RESULT; + + bool SetupSymlinks() WARN_UNUSED_RESULT; + + // Writes |contents| to |file_path_|. Pulled into a separate function just + // to improve readability of tests. + bool WriteFile(const std::string& contents) WARN_UNUSED_RESULT; + + // Verifies that the file at |file_path_| exists and contains |contents|. + void ExpectFileContains(const std::string& contents); + + // Verifies that the file at |file_path_| has |permissions|. + void ExpectPermissions(base::FilePath path, int permissions); + + // Creates a file with a random name in the temporary directory. + base::FilePath GetTempName() WARN_UNUSED_RESULT; +}; + +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_FILES_FILE_UTIL_TEST_H_ diff --git a/brillo/files/safe_fd.cc b/brillo/files/safe_fd.cc new file mode 100644 index 0000000..855207d --- /dev/null +++ b/brillo/files/safe_fd.cc @@ -0,0 +1,538 @@ +// Copyright 2019 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "brillo/files/safe_fd.h" + +#include <fcntl.h> +#include <sys/stat.h> +#include <unistd.h> + +#include <base/files/file_util.h> +#include <base/logging.h> +#include <base/posix/eintr_wrapper.h> +#include <brillo/files/file_util.h> +#include <brillo/files/scoped_dir.h> +#include <brillo/syslog_logging.h> + +namespace brillo { + +namespace { + +SafeFD::SafeFDResult MakeErrorResult(SafeFD::Error error) { + return std::make_pair(SafeFD(), error); +} + +SafeFD::SafeFDResult MakeSuccessResult(SafeFD&& fd) { + return std::make_pair(std::move(fd), SafeFD::Error::kNoError); +} + +SafeFD::SafeFDResult OpenPathComponentInternal(int parent_fd, + const std::string& file, + int flags, + mode_t mode) { + if (file != "/" && file.find("/") != std::string::npos) { + return MakeErrorResult(SafeFD::Error::kBadArgument); + } + SafeFD fd; + + // O_NONBLOCK is used to avoid hanging on edge cases (e.g. a serial port with + // flow control, or a FIFO without a writer). + if (parent_fd >= 0 || parent_fd == AT_FDCWD) { + fd.UnsafeReset(HANDLE_EINTR(openat(parent_fd, file.c_str(), + flags | O_NONBLOCK | O_NOFOLLOW, mode))); + } else if (file == "/") { + fd.UnsafeReset(HANDLE_EINTR(open( + file.c_str(), flags | O_DIRECTORY | O_NONBLOCK | O_NOFOLLOW, mode))); + } + + if (!fd.is_valid()) { + // open(2) fails with ELOOP when the last component of the |path| is a + // symlink. It fails with ENXIO when |path| is a FIFO and |flags| is for + // writing because of the O_NONBLOCK flag added above. + switch (errno) { + case ENOENT: + // Do not write to the log because opening a non-existent file is a + // frequent occurrence. + return MakeErrorResult(SafeFD::Error::kDoesNotExist); + case ELOOP: + // PLOG prints something along the lines of the symlink depth being too + // great which is is misleading so LOG is used instead. + LOG(ERROR) << "Symlink detected! failed to open \"" << file + << "\" safely."; + return MakeErrorResult(SafeFD::Error::kSymlinkDetected); + case EISDIR: + PLOG(ERROR) << "Directory detected! failed to open \"" << file + << "\" safely"; + return MakeErrorResult(SafeFD::Error::kWrongType); + case ENOTDIR: + PLOG(ERROR) << "Not a directory! failed to open \"" << file + << "\" safely"; + return MakeErrorResult(SafeFD::Error::kWrongType); + case ENXIO: + PLOG(ERROR) << "FIFO detected! failed to open \"" << file + << "\" safely"; + return MakeErrorResult(SafeFD::Error::kWrongType); + default: + PLOG(ERROR) << "Failed to open \"" << file << '"'; + return MakeErrorResult(SafeFD::Error::kIOError); + } + } + + // Remove the O_NONBLOCK flag unless the original |flags| have it. + if ((flags & O_NONBLOCK) == 0) { + flags = fcntl(fd.get(), F_GETFL); + if (flags == -1) { + PLOG(ERROR) << "Failed to get fd flags for " << file; + return MakeErrorResult(SafeFD::Error::kIOError); + } + if (fcntl(fd.get(), F_SETFL, flags & ~O_NONBLOCK)) { + PLOG(ERROR) << "Failed to set fd flags for " << file; + return MakeErrorResult(SafeFD::Error::kIOError); + } + } + + return MakeSuccessResult(std::move(fd)); +} + +SafeFD::SafeFDResult OpenSafelyInternal(int parent_fd, + const base::FilePath& path, + int flags, + mode_t mode) { + std::vector<std::string> components; + path.GetComponents(&components); + + auto itr = components.begin(); + if (itr == components.end()) { + LOG(ERROR) << "A path is required."; + return MakeErrorResult(SafeFD::Error::kBadArgument); + } + + SafeFD::SafeFDResult child_fd; + int parent_flags = flags | O_NONBLOCK | O_RDONLY | O_DIRECTORY | O_PATH; + for (; itr + 1 != components.end(); ++itr) { + child_fd = OpenPathComponentInternal(parent_fd, *itr, parent_flags, 0); + // Operation failed, so directly return the error result. + if (!child_fd.first.is_valid()) { + return child_fd; + } + parent_fd = child_fd.first.get(); + } + + return OpenPathComponentInternal(parent_fd, *itr, flags, mode); +} + +SafeFD::Error CheckAttributes(int fd, + mode_t permissions, + uid_t uid, + gid_t gid) { + struct stat fd_attributes; + if (fstat(fd, &fd_attributes) != 0) { + PLOG(ERROR) << "fstat failed"; + return SafeFD::Error::kIOError; + } + + if (fd_attributes.st_uid != uid) { + LOG(ERROR) << "Owner uid is " << fd_attributes.st_uid << " instead of " + << uid; + return SafeFD::Error::kWrongUID; + } + + if (fd_attributes.st_gid != gid) { + LOG(ERROR) << "Owner gid is " << fd_attributes.st_gid << " instead of " + << gid; + return SafeFD::Error::kWrongGID; + } + + if ((0777 & (fd_attributes.st_mode ^ permissions)) != 0) { + mode_t mask = umask(0); + umask(mask); + LOG(ERROR) << "Permissions are " << std::oct + << (0777 & fd_attributes.st_mode) << " instead of " + << (0777 & permissions) << ". Umask is " << std::oct << mask + << std::dec; + return SafeFD::Error::kWrongPermissions; + } + + return SafeFD::Error::kNoError; +} + +SafeFD::Error GetFileSize(int fd, size_t* file_size) { + struct stat fd_attributes; + if (fstat(fd, &fd_attributes) != 0) { + return SafeFD::Error::kIOError; + } + + *file_size = fd_attributes.st_size; + return SafeFD::Error::kNoError; +} + +} // namespace + +bool SafeFD::IsError(SafeFD::Error err) { + return err != Error::kNoError; +} + +const char* SafeFD::RootPath = "/"; + +SafeFD::SafeFDResult SafeFD::Root() { + SafeFD::SafeFDResult root = + OpenPathComponentInternal(-1, "/", O_DIRECTORY, 0); + if (strcmp(SafeFD::RootPath, "/") == 0) { + return root; + } + + if (!root.first.is_valid()) { + LOG(ERROR) << "Failed to open root directory!"; + return root; + } + return root.first.OpenExistingDir(base::FilePath(SafeFD::RootPath)); +} + +void SafeFD::SetRootPathForTesting(const char* new_root_path) { + SafeFD::RootPath = new_root_path; +} + +int SafeFD::get() const { + return fd_.get(); +} + +bool SafeFD::is_valid() const { + return fd_.is_valid(); +} + +void SafeFD::reset() { + return fd_.reset(); +} + +void SafeFD::UnsafeReset(int fd) { + return fd_.reset(fd); +} + +SafeFD::Error SafeFD::Write(const char* data, size_t size) { + if (!fd_.is_valid()) { + return SafeFD::Error::kNotInitialized; + } + errno = 0; + if (!base::WriteFileDescriptor(fd_.get(), data, size)) { + PLOG(ERROR) << "Failed to write to file"; + return SafeFD::Error::kIOError; + } + + if (HANDLE_EINTR(ftruncate(fd_.get(), size)) != 0) { + PLOG(ERROR) << "Failed to truncate file"; + return SafeFD::Error::kIOError; + } + return SafeFD::Error::kNoError; +} + +std::pair<std::vector<char>, SafeFD::Error> SafeFD::ReadContents( + size_t max_size) { + std::vector<char> buffer; + if (!fd_.is_valid()) { + return std::make_pair(std::move(buffer), SafeFD::Error::kNotInitialized); + } + + size_t file_size = 0; + SafeFD::Error err = GetFileSize(fd_.get(), &file_size); + if (IsError(err)) { + return std::make_pair(std::move(buffer), err); + } + + if (file_size > max_size) { + return std::make_pair(std::move(buffer), SafeFD::Error::kExceededMaximum); + } + + buffer.resize(file_size); + + err = Read(buffer.data(), buffer.size()); + if (IsError(err)) { + buffer.clear(); + } + return std::make_pair(std::move(buffer), err); +} + +SafeFD::Error SafeFD::Read(char* data, size_t size) { + if (!fd_.is_valid()) { + return SafeFD::Error::kNotInitialized; + } + + if (!base::ReadFromFD(fd_.get(), data, size)) { + PLOG(ERROR) << "Failed to read file"; + return SafeFD::Error::kIOError; + } + return SafeFD::Error::kNoError; +} + +SafeFD::SafeFDResult SafeFD::OpenExistingFile(const base::FilePath& path, + int flags) { + if (!fd_.is_valid()) { + return MakeErrorResult(SafeFD::Error::kNotInitialized); + } + + return OpenSafelyInternal(get(), path, flags, 0 /*mode*/); +} + +SafeFD::SafeFDResult SafeFD::OpenExistingDir(const base::FilePath& path, + int flags) { + if (!fd_.is_valid()) { + return MakeErrorResult(SafeFD::Error::kNotInitialized); + } + + return OpenSafelyInternal(get(), path, O_DIRECTORY | flags /*flags*/, + 0 /*mode*/); +} + +SafeFD::SafeFDResult SafeFD::MakeFile(const base::FilePath& path, + mode_t permissions, + uid_t uid, + gid_t gid, + int flags) { + if (!fd_.is_valid()) { + return MakeErrorResult(SafeFD::Error::kNotInitialized); + } + + // Open (and create if necessary) the parent directory. + base::FilePath dir_name = path.DirName(); + SafeFD::SafeFDResult parent_dir; + int parent_dir_fd = get(); + if (!dir_name.empty() && + dir_name.value() != base::FilePath::kCurrentDirectory) { + // Apply execute permission where read permission are present for parent + // directories. + int dir_permissions = permissions | ((permissions & 0444) >> 2); + parent_dir = + MakeDir(dir_name, dir_permissions, uid, gid, O_RDONLY | O_CLOEXEC); + if (!parent_dir.first.is_valid()) { + return parent_dir; + } + parent_dir_fd = parent_dir.first.get(); + } + + // If file already exists, validate permissions. + SafeFDResult file = OpenPathComponentInternal( + parent_dir_fd, path.BaseName().value(), flags, permissions /*mode*/); + if (file.first.is_valid()) { + SafeFD::Error err = + CheckAttributes(file.first.get(), permissions, uid, gid); + if (IsError(err)) { + return MakeErrorResult(err); + } + return file; + } else if (errno != ENOENT) { + return file; + } + + // The file does exist, create it and set the ownership. + file = + OpenPathComponentInternal(parent_dir_fd, path.BaseName().value(), + O_CREAT | O_EXCL | flags, permissions /*mode*/); + if (!file.first.is_valid()) { + return file; + } + if (HANDLE_EINTR(fchown(file.first.get(), uid, gid)) != 0) { + PLOG(ERROR) << "Failed to set ownership in MakeFile() for \"" + << path.value() << '"'; + return MakeErrorResult(SafeFD::Error::kIOError); + } + return file; +} + +SafeFD::SafeFDResult SafeFD::MakeDir(const base::FilePath& path, + mode_t permissions, + uid_t uid, + gid_t gid, + int flags) { + if (!fd_.is_valid()) { + return MakeErrorResult(SafeFD::Error::kNotInitialized); + } + + std::vector<std::string> components; + path.GetComponents(&components); + if (components.empty()) { + LOG(ERROR) << "Called MakeDir() with an empty path"; + return MakeErrorResult(SafeFD::Error::kBadArgument); + } + + // Walk the path creating directories as necessary. + SafeFD dir; + SafeFDResult child_dir; + int parent_dir_fd = get(); + int dir_flags = O_NONBLOCK | O_DIRECTORY | O_PATH; + bool made_dir = false; + for (const auto& component : components) { + if (mkdirat(parent_dir_fd, component.c_str(), permissions) != 0) { + if (errno != EEXIST) { + PLOG(ERROR) << "Failed to mkdirat() " << component << ": full_path=\"" + << path.value() << '"'; + return MakeErrorResult(SafeFD::Error::kIOError); + } + } else { + made_dir = true; + } + + // For the last component in the path, use the flags provided by the caller. + if (&component == &components.back()) { + dir_flags = flags | O_DIRECTORY; + } + child_dir = OpenPathComponentInternal(parent_dir_fd, component, dir_flags, + 0 /*mode*/); + if (!child_dir.first.is_valid()) { + return child_dir; + } + + dir = std::move(child_dir.first); + parent_dir_fd = dir.get(); + } + + if (made_dir) { + // If the directory was created, set the ownership. + if (HANDLE_EINTR(fchown(dir.get(), uid, gid)) != 0) { + PLOG(ERROR) << "Failed to set ownership in MakeDir() for \"" + << path.value() << '"'; + return MakeErrorResult(SafeFD::Error::kIOError); + } + } + // If the directory already existed, validate the permissions. + SafeFD::Error err = CheckAttributes(dir.get(), permissions, uid, gid); + if (IsError(err)) { + return MakeErrorResult(err); + } + + return MakeSuccessResult(std::move(dir)); +} + +SafeFD::Error SafeFD::Link(const SafeFD& source_dir, + const std::string& source_name, + const std::string& destination_name) { + if (!fd_.is_valid() || !source_dir.is_valid()) { + return SafeFD::Error::kNotInitialized; + } + + SafeFD::Error err = IsValidFilename(source_name); + if (IsError(err)) { + return err; + } + + err = IsValidFilename(destination_name); + if (IsError(err)) { + return err; + } + + if (HANDLE_EINTR(linkat(source_dir.get(), source_name.c_str(), fd_.get(), + destination_name.c_str(), 0)) != 0) { + PLOG(ERROR) << "Failed to link \"" << destination_name << "\""; + return SafeFD::Error::kIOError; + } + return SafeFD::Error::kNoError; +} + +SafeFD::Error SafeFD::Unlink(const std::string& name) { + if (!fd_.is_valid()) { + return SafeFD::Error::kNotInitialized; + } + + SafeFD::Error err = IsValidFilename(name); + if (IsError(err)) { + return err; + } + + if (HANDLE_EINTR(unlinkat(fd_.get(), name.c_str(), 0 /*flags*/)) != 0) { + PLOG(ERROR) << "Failed to unlink \"" << name << "\""; + return SafeFD::Error::kIOError; + } + return SafeFD::Error::kNoError; +} + +SafeFD::Error SafeFD::Rmdir(const std::string& name, + bool recursive, + size_t max_depth) { + if (!fd_.is_valid()) { + return SafeFD::Error::kNotInitialized; + } + + if (max_depth == 0) { + return SafeFD::Error::kExceededMaximum; + } + + SafeFD::Error err = IsValidFilename(name); + if (IsError(err)) { + return err; + } + + if (recursive) { + SafeFD dir_fd; + std::tie(dir_fd, err) = + OpenPathComponentInternal(fd_.get(), name, O_DIRECTORY, 0); + if (!dir_fd.is_valid()) { + return err; + } + + // The ScopedDIR takes ownership of this so dup_fd is not scoped on its own. + int dup_fd = dup(dir_fd.get()); + if (dup_fd < 0) { + PLOG(ERROR) << "dup failed"; + return SafeFD::Error::kIOError; + } + + ScopedDIR dir(fdopendir(dup_fd)); + if (!dir.is_valid()) { + PLOG(ERROR) << "fdopendir failed"; + close(dup_fd); + return SafeFD::Error::kIOError; + } + + struct stat dir_info; + if (fstat(dir_fd.get(), &dir_info) != 0) { + return SafeFD::Error::kIOError; + } + + errno = 0; + const dirent* entry = HANDLE_EINTR_IF_EQ(readdir(dir.get()), nullptr); + while (entry != nullptr) { + if (strcmp(entry->d_name, ".") == 0 || strcmp(entry->d_name, "..") == 0) { + goto continue_; + } + + struct stat child_info; + if (fstatat(dir_fd.get(), entry->d_name, &child_info, + AT_NO_AUTOMOUNT | AT_SYMLINK_NOFOLLOW) != 0) { + return SafeFD::Error::kIOError; + } + + if (child_info.st_dev != dir_info.st_dev) { + return SafeFD::Error::kBoundaryDetected; + } + + SafeFD::Error err; + if (entry->d_type == DT_DIR) { + err = dir_fd.Rmdir(entry->d_name, true, max_depth - 1); + } else { + err = dir_fd.Unlink(entry->d_name); + } + + if (IsError(err)) { + return err; + } + + continue_: + errno = 0; + entry = HANDLE_EINTR_IF_EQ(readdir(dir.get()), nullptr); + } + if (errno != 0) { + PLOG(ERROR) << "readdir failed"; + return SafeFD::Error::kIOError; + } + } + + if (HANDLE_EINTR(unlinkat(fd_.get(), name.c_str(), AT_REMOVEDIR)) != 0) { + PLOG(ERROR) << "unlinkat failed"; + if (errno == ENOTDIR) { + return SafeFD::Error::kWrongType; + } + return SafeFD::Error::kIOError; + } + return SafeFD::Error::kNoError; +} + +} // namespace brillo diff --git a/brillo/files/safe_fd.h b/brillo/files/safe_fd.h new file mode 100644 index 0000000..5e126b4 --- /dev/null +++ b/brillo/files/safe_fd.h @@ -0,0 +1,201 @@ +// Copyright 2019 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// This provides an API for performing typical filesystem related tasks while +// guaranteeing certain security properties are maintained. Specifically, checks +// are performed to disallow symbolic links, and exotic file objects. The goal +// behind these checks is to thwart attacks that rely on confusing system +// services to perform unintended file operations like ownership changes or +// copy-as-root attack primitives. To accomplish this these operations are +// written to avoid susceptibility to TOCTOU (time-of-check-time-of-use) +// attacks. + +// To use this API start with the root path and work from there. For example: +// SafeFD fd(SafeDirFD::Root().MakeFile(PATH).first); +// if (!fd.is_valid()) { +// LOG(ERROR) << "Failed to open " << PATH; +// return false; +// } +// if (fd.WriteString(CONTENTS) != SafeFD::kNoError) { +// LOG(ERROR) << "Failed to write to " << PATH; +// return false; +// } +// auto read_result = fd.ReadString(); +// if (!read_result.second != SafeFD::kNoError) { +// LOG(ERROR) << "Failed to read from " << PATH; +// return false; +// } + +#ifndef LIBBRILLO_BRILLO_FILES_SAFE_FD_H_ +#define LIBBRILLO_BRILLO_FILES_SAFE_FD_H_ + +#include <fcntl.h> + +#include <string> +#include <utility> +#include <vector> + +#include <base/files/file_path.h> +#include <base/files/scoped_file.h> +#include <base/optional.h> +#include <base/synchronization/lock.h> +#include <brillo/brillo_export.h> + +namespace brillo { + +class SafeFDTest; + +class SafeFD { + public: + enum class Error { + kNoError = 0, + kBadArgument, + kNotInitialized, // Invalid operation on a SafeFD that was not initialized. + kIOError, // Check errno for specific cause. + kDoesNotExist, // The specified path does not exist. + kSymlinkDetected, + kBoundaryDetected, // Detected a file system boundary during recursion. + kWrongType, // (e.g. got a directory and expected a file) + kWrongUID, + kWrongGID, + kWrongPermissions, + kExceededMaximum, // The maximum allowed read size was reached. + }; + + // Returns true if |err| denotes a failed operation. + BRILLO_EXPORT static bool IsError(SafeFD::Error err); + + typedef std::pair<SafeFD, Error> SafeFDResult; + + // 100 MiB + BRILLO_EXPORT static constexpr size_t kDefaultMaxRead = 100 << 20; + BRILLO_EXPORT static constexpr size_t kDefaultMaxPathDepth = 256; + // User read and write only. + BRILLO_EXPORT static constexpr size_t kDefaultFilePermissions = 0640; + // User read, write, and execute. Group read and execute. + BRILLO_EXPORT static constexpr size_t kDefaultDirPermissions = 0750; + + // Get a SafeFD to the root path. + BRILLO_EXPORT static SafeFDResult Root() WARN_UNUSED_RESULT; + BRILLO_EXPORT static void SetRootPathForTesting(const char* new_root_path); + + // Constructs an invalid fd; + BRILLO_EXPORT SafeFD() = default; + + // Move-based constructor and assignment. + BRILLO_EXPORT SafeFD(SafeFD&&) = default; + BRILLO_EXPORT SafeFD& operator=(SafeFD&&) = default; + + // Return the fd number. + BRILLO_EXPORT int get() const WARN_UNUSED_RESULT; + + // Check the validity of the file descriptor. + BRILLO_EXPORT bool is_valid() const WARN_UNUSED_RESULT; + + // Close the scoped file if one was open. + BRILLO_EXPORT void reset(); + + // Wrap |fd| with a SafeFD which will close the fd when this goes out of + // scope. This closes the original fd if one was open. + // This is named "Unsafe" because the recommended way to get a SafeFD + // instance is opening one from SafeFD::Root(). + BRILLO_EXPORT void UnsafeReset(int fd); + + // Writes |size| bytes from |data| into a file and returns kNoError on + // success. Note the file will be truncated to the size of the content. + // + // Parameters + // data - The buffer to write to the file. + // size - The number of bytes to write. + BRILLO_EXPORT Error Write(const char* data, size_t size) WARN_UNUSED_RESULT; + + // Read the contents of the file and return it as a string. + // + // Parameters + // size - The max number of bytes to read. + BRILLO_EXPORT std::pair<std::vector<char>, Error> ReadContents( + size_t max_size = kDefaultMaxRead) WARN_UNUSED_RESULT; + + // Reads exactly |size| bytes into |data|. + // + // Parameters + // data - The buffer to read the file into. + // size - The number of bytes to read. + BRILLO_EXPORT Error Read(char* data, size_t size) WARN_UNUSED_RESULT; + + // Open an existing file relative to this directory. + // + // Parameters + // path - The path to open relative to the current directory. + BRILLO_EXPORT SafeFDResult OpenExistingFile(const base::FilePath& path, + int flags = O_RDWR | O_CLOEXEC) + WARN_UNUSED_RESULT; + + // Open an existing directory relative to this directory. + // + // Parameters + // path - The path to open relative to the current directory. + BRILLO_EXPORT SafeFDResult OpenExistingDir(const base::FilePath& path, + int flags = O_RDONLY | O_CLOEXEC) + WARN_UNUSED_RESULT; + + // Open a file relative to this directory creating the parent directories and + // file if they don't already exist. + BRILLO_EXPORT SafeFDResult + MakeFile(const base::FilePath& path, + mode_t permissions = kDefaultFilePermissions, + uid_t uid = getuid(), + gid_t gid = getgid(), + int flags = O_RDWR | O_CLOEXEC) WARN_UNUSED_RESULT; + + // Create the directories in the relative path with the given ownership and + // permissions and return a file descriptor to the result. + BRILLO_EXPORT SafeFDResult + MakeDir(const base::FilePath& path, + mode_t permissions = kDefaultDirPermissions, + uid_t uid = getuid(), + gid_t gid = getgid(), + int flags = O_RDONLY | O_CLOEXEC) WARN_UNUSED_RESULT; + + // Hard link |fd| in the directory represented by |this| with the specified + // name |filename|. This requires CAP_DAC_READ_SEARCH. + // + // Parameters + // data - The buffer to write to the file. + // size - The number of bytes to write. + BRILLO_EXPORT Error Link(const SafeFD& source_dir, + const std::string& source_name, + const std::string& destination_name) + WARN_UNUSED_RESULT; + + // Deletes the child path named |name|. + // + // Parameters + // name - the name of the filesystem object to delete. + BRILLO_EXPORT Error Unlink(const std::string& name) WARN_UNUSED_RESULT; + + // Deletes a child directory. It will return kBoundaryDetected if a file + // system boundary is reached during recursion. + // + // Parameters + // name - the name of the directory to delete. + // recursive - if true also unlink child paths. + // max_depth - limit on recursion depth to prevent fd exhaustion and stack + // overflows. + BRILLO_EXPORT Error Rmdir(const std::string& name, + bool recursive = false, + size_t max_depth = kDefaultMaxPathDepth) + WARN_UNUSED_RESULT; + + private: + BRILLO_EXPORT static const char* RootPath; + + base::ScopedFD fd_; + + DISALLOW_COPY_AND_ASSIGN(SafeFD); +}; + +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_FILES_SAFE_FD_H_ diff --git a/brillo/files/safe_fd_test.cc b/brillo/files/safe_fd_test.cc new file mode 100644 index 0000000..fff3a6c --- /dev/null +++ b/brillo/files/safe_fd_test.cc @@ -0,0 +1,627 @@ +// Copyright 2019 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "brillo/files/safe_fd.h" + +#include <fcntl.h> +#include <sys/stat.h> + +#include <base/files/file_util.h> +#include <brillo/files/file_util_test.h> +#include <brillo/syslog_logging.h> +#include <gtest/gtest.h> + +namespace brillo { + +class SafeFDTest : public FileTest {}; + +TEST_F(SafeFDTest, SafeFD) { + EXPECT_FALSE(SafeFD().is_valid()); +} + +TEST_F(SafeFDTest, SafeFD_Move) { + SafeFD moved_root = std::move(root_); + EXPECT_FALSE(root_.is_valid()); + ASSERT_TRUE(moved_root.is_valid()); + + SafeFD moved_root2(std::move(moved_root)); + EXPECT_FALSE(moved_root.is_valid()); + ASSERT_TRUE(moved_root2.is_valid()); +} + +TEST_F(SafeFDTest, Root) { + SafeFD::SafeFDResult result = SafeFD::Root(); + EXPECT_TRUE(result.first.is_valid()); + EXPECT_EQ(result.second, SafeFD::Error::kNoError); +} + +TEST_F(SafeFDTest, reset) { + root_.reset(); + EXPECT_FALSE(root_.is_valid()); +} + +TEST_F(SafeFDTest, UnsafeReset) { + int fd = + HANDLE_EINTR(open(temp_dir_path_.data(), + O_NONBLOCK | O_DIRECTORY | O_RDONLY | O_CLOEXEC, 0777)); + ASSERT_GE(fd, 0); + + { + SafeFD safe_fd; + safe_fd.UnsafeReset(fd); + EXPECT_EQ(safe_fd.get(), fd); + } + + // Verify the file descriptor is closed. + int result = fcntl(fd, F_GETFD); + int error = errno; + EXPECT_EQ(result, -1); + EXPECT_EQ(error, EBADF); +} + +TEST_F(SafeFDTest, Write_Success) { + std::string random_data = GetRandomSuffix(); + { + SafeFD::SafeFDResult file = root_.MakeFile(file_path_); + EXPECT_EQ(file.second, SafeFD::Error::kNoError); + ASSERT_TRUE(file.first.is_valid()); + + EXPECT_EQ(file.first.Write(random_data.data(), random_data.size()), + SafeFD::Error::kNoError); + } + + ExpectFileContains(random_data); + ExpectPermissions(file_path_, SafeFD::kDefaultFilePermissions); +} + +TEST_F(SafeFDTest, Write_NotInitialized) { + SafeFD invalid; + ASSERT_FALSE(invalid.is_valid()); + + std::string random_data = GetRandomSuffix(); + EXPECT_EQ(invalid.Write(random_data.data(), random_data.size()), + SafeFD::Error::kNotInitialized); +} + +TEST_F(SafeFDTest, Write_VerifyTruncate) { + std::string random_data = GetRandomSuffix(); + ASSERT_TRUE(WriteFile(random_data)); + + { + SafeFD::SafeFDResult file = root_.OpenExistingFile(file_path_); + EXPECT_EQ(file.second, SafeFD::Error::kNoError); + ASSERT_TRUE(file.first.is_valid()); + + EXPECT_EQ(file.first.Write("", 0), SafeFD::Error::kNoError); + } + + ExpectFileContains(""); +} + +TEST_F(SafeFDTest, Write_Failure) { + std::string random_data = GetRandomSuffix(); + EXPECT_EQ(root_.Write("", 1), SafeFD::Error::kIOError); +} + +TEST_F(SafeFDTest, ReadContents_Success) { + std::string random_data = GetRandomSuffix(); + ASSERT_TRUE(WriteFile(random_data)); + + SafeFD::SafeFDResult file = root_.OpenExistingFile(file_path_); + EXPECT_EQ(file.second, SafeFD::Error::kNoError); + ASSERT_TRUE(file.first.is_valid()); + + auto result = file.first.ReadContents(); + EXPECT_EQ(result.second, SafeFD::Error::kNoError); + ASSERT_EQ(random_data.size(), result.first.size()); + EXPECT_EQ(memcmp(random_data.data(), result.first.data(), random_data.size()), + 0); +} + +TEST_F(SafeFDTest, ReadContents_ExceededMaximum) { + std::string random_data = GetRandomSuffix(); + ASSERT_TRUE(WriteFile(random_data)); + + SafeFD::SafeFDResult file = root_.OpenExistingFile(file_path_); + EXPECT_EQ(file.second, SafeFD::Error::kNoError); + ASSERT_TRUE(file.first.is_valid()); + + ASSERT_LT(1, random_data.size()); + auto result = file.first.ReadContents(1); + EXPECT_EQ(result.second, SafeFD::Error::kExceededMaximum); +} + +TEST_F(SafeFDTest, ReadContents_NotInitialized) { + SafeFD invalid; + ASSERT_FALSE(invalid.is_valid()); + + auto result = invalid.ReadContents(); + EXPECT_EQ(result.second, SafeFD::Error::kNotInitialized); +} + +TEST_F(SafeFDTest, Read_Success) { + std::string random_data = GetRandomSuffix(); + ASSERT_TRUE(WriteFile(random_data)); + + SafeFD::SafeFDResult file = root_.OpenExistingFile(file_path_); + EXPECT_EQ(file.second, SafeFD::Error::kNoError); + ASSERT_TRUE(file.first.is_valid()); + + std::vector<char> buffer(random_data.size(), '\0'); + ASSERT_EQ(file.first.Read(buffer.data(), buffer.size()), + SafeFD::Error::kNoError); + EXPECT_EQ(memcmp(random_data.data(), buffer.data(), random_data.size()), 0); +} + +TEST_F(SafeFDTest, Read_NotInitialized) { + SafeFD invalid; + ASSERT_FALSE(invalid.is_valid()); + + char to_read; + EXPECT_EQ(invalid.Read(&to_read, 1), SafeFD::Error::kNotInitialized); +} + +TEST_F(SafeFDTest, Read_IOError) { + std::string random_data = GetRandomSuffix(); + ASSERT_TRUE(WriteFile(random_data)); + + SafeFD::SafeFDResult file = root_.OpenExistingFile(file_path_); + EXPECT_EQ(file.second, SafeFD::Error::kNoError); + ASSERT_TRUE(file.first.is_valid()); + + std::vector<char> buffer(random_data.size() * 2, '\0'); + ASSERT_EQ(file.first.Read(buffer.data(), buffer.size()), + SafeFD::Error::kIOError); +} + +TEST_F(SafeFDTest, OpenExistingFile_Success) { + std::string data = GetRandomSuffix(); + ASSERT_TRUE(WriteFile(data)); + { + SafeFD::SafeFDResult file = root_.OpenExistingFile(file_path_); + EXPECT_EQ(file.second, SafeFD::Error::kNoError); + ASSERT_TRUE(file.first.is_valid()); + } + ExpectFileContains(data); +} + +TEST_F(SafeFDTest, OpenExistingFile_NotInitialized) { + SafeFD::SafeFDResult file = SafeFD().OpenExistingFile(file_path_); + EXPECT_EQ(file.second, SafeFD::Error::kNotInitialized); + ASSERT_FALSE(file.first.is_valid()); +} + +TEST_F(SafeFDTest, OpenExistingFile_DoesNotExist) { + SafeFD::SafeFDResult file = root_.OpenExistingFile(file_path_); + EXPECT_EQ(file.second, SafeFD::Error::kDoesNotExist); + ASSERT_FALSE(file.first.is_valid()); +} + +TEST_F(SafeFDTest, OpenExistingFile_IOError) { + ASSERT_TRUE(WriteFile("")); + EXPECT_EQ(chmod(file_path_.value().c_str(), 0000), 0) << strerror(errno); + + SafeFD::SafeFDResult file = root_.OpenExistingFile(file_path_); + EXPECT_EQ(file.second, SafeFD::Error::kIOError); + ASSERT_FALSE(file.first.is_valid()); +} + +TEST_F(SafeFDTest, OpenExistingFile_SymlinkDetected) { + ASSERT_TRUE(SetupSymlinks()); + ASSERT_TRUE(WriteFile("")); + SafeFD::SafeFDResult file = root_.OpenExistingFile(symlink_file_path_); + EXPECT_EQ(file.second, SafeFD::Error::kSymlinkDetected); + ASSERT_FALSE(file.first.is_valid()); +} + +TEST_F(SafeFDTest, OpenExistingFile_WrongType) { + ASSERT_TRUE(SetupSymlinks()); + ASSERT_TRUE(WriteFile("")); + SafeFD::SafeFDResult file = + root_.OpenExistingFile(symlink_dir_path_.Append(kFileName)); + EXPECT_EQ(file.second, SafeFD::Error::kWrongType); + ASSERT_FALSE(file.first.is_valid()); +} + +TEST_F(SafeFDTest, OpenExistingDir_Success) { + SafeFD::SafeFDResult dir = root_.OpenExistingDir(temp_dir_.GetPath()); + EXPECT_EQ(dir.second, SafeFD::Error::kNoError); + ASSERT_TRUE(dir.first.is_valid()); +} + +TEST_F(SafeFDTest, OpenExistingDir_NotInitialized) { + SafeFD::SafeFDResult dir = SafeFD().OpenExistingDir(temp_dir_.GetPath()); + EXPECT_EQ(dir.second, SafeFD::Error::kNotInitialized); + ASSERT_FALSE(dir.first.is_valid()); +} + +TEST_F(SafeFDTest, OpenExistingDir_DoesNotExist) { + SafeFD::SafeFDResult dir = root_.OpenExistingDir(sub_dir_path_); + EXPECT_EQ(dir.second, SafeFD::Error::kDoesNotExist); + ASSERT_FALSE(dir.first.is_valid()); +} + +TEST_F(SafeFDTest, OpenExistingDir_IOError) { + ASSERT_TRUE(WriteFile("")); + ASSERT_EQ(chmod(sub_dir_path_.value().c_str(), 0000), 0) << strerror(errno); + + SafeFD::SafeFDResult dir = root_.OpenExistingDir(sub_dir_path_); + EXPECT_EQ(dir.second, SafeFD::Error::kIOError); + ASSERT_FALSE(dir.first.is_valid()); +} + +TEST_F(SafeFDTest, OpenExistingDir_WrongType) { + ASSERT_TRUE(SetupSymlinks()); + SafeFD::SafeFDResult dir = root_.OpenExistingDir(symlink_dir_path_); + EXPECT_EQ(dir.second, SafeFD::Error::kWrongType); + ASSERT_FALSE(dir.first.is_valid()); +} + +TEST_F(SafeFDTest, MakeFile_DoesNotExistSuccess) { + { + SafeFD::SafeFDResult file = root_.MakeFile(file_path_); + EXPECT_EQ(file.second, SafeFD::Error::kNoError); + ASSERT_TRUE(file.first.is_valid()); + } + + ExpectPermissions(file_path_, SafeFD::kDefaultFilePermissions); +} + +TEST_F(SafeFDTest, MakeFile_LeadingSelfDirSuccess) { + ASSERT_TRUE(SetupSubdir()); + + SafeFD::Error err; + SafeFD dir; + std::tie(dir, err) = root_.OpenExistingDir(sub_dir_path_); + ASSERT_EQ(err, SafeFD::Error::kNoError); + + { + SafeFD file; + std::tie(file, err) = dir.MakeFile(file_path_.BaseName()); + EXPECT_EQ(err, SafeFD::Error::kNoError); + ASSERT_TRUE(file.is_valid()); + } + + ExpectPermissions(file_path_, SafeFD::kDefaultFilePermissions); +} + +TEST_F(SafeFDTest, MakeFile_ExistsSuccess) { + std::string data = GetRandomSuffix(); + ASSERT_TRUE(WriteFile(data)); + { + SafeFD::SafeFDResult file = root_.MakeFile(file_path_); + EXPECT_EQ(file.second, SafeFD::Error::kNoError); + ASSERT_TRUE(file.first.is_valid()); + } + ExpectPermissions(file_path_, SafeFD::kDefaultFilePermissions); + ExpectFileContains(data); +} + +TEST_F(SafeFDTest, MakeFile_IOError) { + ASSERT_TRUE(SetupSubdir()); + ASSERT_EQ(mkfifo(file_path_.value().c_str(), 0), 0); + SafeFD::SafeFDResult file = root_.MakeFile(file_path_); + EXPECT_EQ(file.second, SafeFD::Error::kIOError); + ASSERT_FALSE(file.first.is_valid()); +} + +TEST_F(SafeFDTest, MakeFile_SymlinkDetected) { + ASSERT_TRUE(SetupSymlinks()); + SafeFD::SafeFDResult file = root_.MakeFile(symlink_file_path_); + EXPECT_EQ(file.second, SafeFD::Error::kSymlinkDetected); + ASSERT_FALSE(file.first.is_valid()); +} + +TEST_F(SafeFDTest, MakeFile_WrongType) { + ASSERT_TRUE(SetupSubdir()); + SafeFD::SafeFDResult file = root_.MakeFile(sub_dir_path_); + EXPECT_EQ(file.second, SafeFD::Error::kWrongType); + ASSERT_FALSE(file.first.is_valid()); +} + +TEST_F(SafeFDTest, MakeFile_WrongGID) { + ASSERT_TRUE(WriteFile("")); + ASSERT_EQ(chown(file_path_.value().c_str(), getuid(), 0), 0) + << strerror(errno); + { + SafeFD::SafeFDResult file = root_.MakeFile(file_path_); + EXPECT_EQ(file.second, SafeFD::Error::kWrongGID); + ASSERT_FALSE(file.first.is_valid()); + } +} + +TEST_F(SafeFDTest, MakeFile_WrongPermissions) { + ASSERT_TRUE(WriteFile("")); + ASSERT_EQ(chmod(file_path_.value().c_str(), 0777), 0) << strerror(errno); + { + SafeFD::SafeFDResult file = root_.MakeFile(file_path_); + EXPECT_EQ(file.second, SafeFD::Error::kWrongPermissions); + ASSERT_FALSE(file.first.is_valid()); + } + ASSERT_EQ(chmod(file_path_.value().c_str(), SafeFD::kDefaultFilePermissions), + 0) + << strerror(errno); + + EXPECT_EQ(chmod(sub_dir_path_.value().c_str(), 0777), 0) << strerror(errno); + { + SafeFD::SafeFDResult file = root_.MakeFile(file_path_); + EXPECT_EQ(file.second, SafeFD::Error::kWrongPermissions); + ASSERT_FALSE(file.first.is_valid()); + } +} + +TEST_F(SafeFDTest, MakeDir_DoesNotExistSuccess) { + { + SafeFD::SafeFDResult dir = root_.MakeDir(sub_dir_path_); + EXPECT_EQ(dir.second, SafeFD::Error::kNoError); + ASSERT_TRUE(dir.first.is_valid()); + } + + ExpectPermissions(sub_dir_path_, SafeFD::kDefaultDirPermissions); +} + +TEST_F(SafeFDTest, MakeFile_SingleComponentSuccess) { + ASSERT_TRUE(SetupSubdir()); + + SafeFD::Error err; + SafeFD dir; + std::tie(dir, err) = root_.OpenExistingDir(temp_dir_.GetPath()); + ASSERT_EQ(err, SafeFD::Error::kNoError); + + { + SafeFD subdir; + std::tie(subdir, err) = dir.MakeDir(base::FilePath(kSubdirName)); + EXPECT_EQ(err, SafeFD::Error::kNoError); + ASSERT_TRUE(subdir.is_valid()); + } + + ExpectPermissions(sub_dir_path_, SafeFD::kDefaultDirPermissions); +} + +TEST_F(SafeFDTest, MakeDir_ExistsSuccess) { + ASSERT_TRUE(SetupSubdir()); + { + SafeFD::SafeFDResult dir = root_.MakeDir(sub_dir_path_); + EXPECT_EQ(dir.second, SafeFD::Error::kNoError); + ASSERT_TRUE(dir.first.is_valid()); + } + + ExpectPermissions(sub_dir_path_, SafeFD::kDefaultDirPermissions); +} + +TEST_F(SafeFDTest, MakeDir_WrongType) { + ASSERT_TRUE(SetupSymlinks()); + SafeFD::SafeFDResult dir = root_.MakeDir(symlink_dir_path_); + EXPECT_EQ(dir.second, SafeFD::Error::kWrongType); + ASSERT_FALSE(dir.first.is_valid()); +} + +TEST_F(SafeFDTest, MakeDir_WrongGID) { + ASSERT_TRUE(SetupSubdir()); + ASSERT_EQ(chown(sub_dir_path_.value().c_str(), getuid(), 0), 0) + << strerror(errno); + { + SafeFD::SafeFDResult dir = root_.MakeDir(sub_dir_path_); + EXPECT_EQ(dir.second, SafeFD::Error::kWrongGID); + ASSERT_FALSE(dir.first.is_valid()); + } +} + +TEST_F(SafeFDTest, MakeDir_WrongPermissions) { + ASSERT_TRUE(SetupSubdir()); + ASSERT_EQ(chmod(sub_dir_path_.value().c_str(), 0777), 0) << strerror(errno); + + SafeFD::SafeFDResult dir = root_.MakeDir(sub_dir_path_); + EXPECT_EQ(dir.second, SafeFD::Error::kWrongPermissions); + ASSERT_FALSE(dir.first.is_valid()); +} + +TEST_F(SafeFDTest, Link_Success) { + std::string data = GetRandomSuffix(); + ASSERT_TRUE(WriteFile(data)); + + SafeFD::SafeFDResult dir = root_.OpenExistingDir(temp_dir_.GetPath()); + ASSERT_EQ(dir.second, SafeFD::Error::kNoError); + + SafeFD::SafeFDResult subdir = root_.OpenExistingDir(sub_dir_path_); + ASSERT_EQ(subdir.second, SafeFD::Error::kNoError); + + EXPECT_EQ(dir.first.Link(subdir.first, kFileName, kFileName), + SafeFD::Error::kNoError); + + SafeFD::SafeFDResult new_file = dir.first.OpenExistingFile( + base::FilePath(kFileName), O_RDONLY | O_CLOEXEC); + EXPECT_EQ(new_file.second, SafeFD::Error::kNoError); + std::pair<std::vector<char>, SafeFD::Error> contents = + new_file.first.ReadContents(); + EXPECT_EQ(contents.second, SafeFD::Error::kNoError); + EXPECT_EQ(data.size(), contents.first.size()); + EXPECT_EQ(memcmp(data.data(), contents.first.data(), data.size()), 0); +} + +TEST_F(SafeFDTest, Link_NotInitialized) { + std::string data = GetRandomSuffix(); + ASSERT_TRUE(WriteFile(data)); + + SafeFD::SafeFDResult dir = root_.OpenExistingDir(temp_dir_.GetPath()); + ASSERT_EQ(dir.second, SafeFD::Error::kNoError); + + SafeFD::SafeFDResult subdir = root_.OpenExistingDir(sub_dir_path_); + ASSERT_EQ(subdir.second, SafeFD::Error::kNoError); + + EXPECT_EQ(SafeFD().Link(subdir.first, kFileName, kFileName), + SafeFD::Error::kNotInitialized); + + EXPECT_EQ(dir.first.Link(SafeFD(), kFileName, kFileName), + SafeFD::Error::kNotInitialized); +} + +TEST_F(SafeFDTest, Link_BadArgument) { + ASSERT_TRUE(WriteFile("")); + + SafeFD::SafeFDResult dir = root_.OpenExistingDir(temp_dir_.GetPath()); + ASSERT_EQ(dir.second, SafeFD::Error::kNoError); + + SafeFD::SafeFDResult subdir = root_.OpenExistingDir(sub_dir_path_); + ASSERT_EQ(subdir.second, SafeFD::Error::kNoError); + + EXPECT_EQ(dir.first.Link(subdir.first, "a/a", kFileName), + SafeFD::Error::kBadArgument); + EXPECT_EQ(dir.first.Link(subdir.first, ".", kFileName), + SafeFD::Error::kBadArgument); + EXPECT_EQ(dir.first.Link(subdir.first, "..", kFileName), + SafeFD::Error::kBadArgument); + EXPECT_EQ(dir.first.Link(subdir.first, kFileName, "a/a"), + SafeFD::Error::kBadArgument); + EXPECT_EQ(dir.first.Link(subdir.first, kFileName, "."), + SafeFD::Error::kBadArgument); + EXPECT_EQ(dir.first.Link(subdir.first, kFileName, ".."), + SafeFD::Error::kBadArgument); +} + +TEST_F(SafeFDTest, Link_IOError) { + ASSERT_TRUE(SetupSubdir()); + + SafeFD::SafeFDResult dir = root_.OpenExistingDir(temp_dir_.GetPath()); + ASSERT_EQ(dir.second, SafeFD::Error::kNoError); + + SafeFD::SafeFDResult subdir = root_.OpenExistingDir(sub_dir_path_); + ASSERT_EQ(subdir.second, SafeFD::Error::kNoError); + + EXPECT_EQ(dir.first.Link(subdir.first, kFileName, kFileName), + SafeFD::Error::kIOError); +} + +TEST_F(SafeFDTest, Unlink_Success) { + ASSERT_TRUE(WriteFile("")); + + SafeFD::SafeFDResult subdir = root_.OpenExistingDir(sub_dir_path_); + ASSERT_EQ(subdir.second, SafeFD::Error::kNoError); + + EXPECT_EQ(subdir.first.Unlink(kFileName), SafeFD::Error::kNoError); + EXPECT_FALSE(base::PathExists(file_path_)); +} + +TEST_F(SafeFDTest, Unlink_NotInitialized) { + ASSERT_TRUE(WriteFile("")); + + EXPECT_EQ(SafeFD().Unlink(kFileName), SafeFD::Error::kNotInitialized); +} + +TEST_F(SafeFDTest, Unlink_BadArgument) { + ASSERT_TRUE(WriteFile("")); + + SafeFD::SafeFDResult subdir = root_.OpenExistingDir(sub_dir_path_); + ASSERT_EQ(subdir.second, SafeFD::Error::kNoError); + + EXPECT_EQ(subdir.first.Unlink("a/a"), SafeFD::Error::kBadArgument); + EXPECT_EQ(subdir.first.Unlink("."), SafeFD::Error::kBadArgument); + EXPECT_EQ(subdir.first.Unlink(".."), SafeFD::Error::kBadArgument); +} + +TEST_F(SafeFDTest, Unlink_IOError_Nonexistent) { + ASSERT_TRUE(SetupSubdir()); + + SafeFD::SafeFDResult subdir = root_.OpenExistingDir(sub_dir_path_); + ASSERT_EQ(subdir.second, SafeFD::Error::kNoError); + + EXPECT_EQ(subdir.first.Unlink(kFileName), SafeFD::Error::kIOError); +} + +TEST_F(SafeFDTest, Unlink_IOError_IsADir) { + ASSERT_TRUE(SetupSubdir()); + + SafeFD::SafeFDResult dir = root_.OpenExistingDir(temp_dir_.GetPath()); + ASSERT_EQ(dir.second, SafeFD::Error::kNoError); + + EXPECT_EQ(dir.first.Unlink(kSubdirName), SafeFD::Error::kIOError); +} + +TEST_F(SafeFDTest, Rmdir_Recursive_Success) { + ASSERT_TRUE(WriteFile("")); + + SafeFD::SafeFDResult dir = root_.OpenExistingDir(temp_dir_.GetPath()); + ASSERT_EQ(dir.second, SafeFD::Error::kNoError); + + EXPECT_EQ(dir.first.Rmdir(kSubdirName, true /*recursive*/), + SafeFD::Error::kNoError); + EXPECT_FALSE(base::PathExists(file_path_)); + EXPECT_FALSE(base::PathExists(sub_dir_path_)); +} + +TEST_F(SafeFDTest, Rmdir_Recursive_SuccessMaxRecursion) { + SafeFD::Error err; + SafeFD dir; + + // Create directory with the maximum depth. + std::tie(dir, err) = root_.OpenExistingDir(temp_dir_.GetPath()); + EXPECT_EQ(err, SafeFD::Error::kNoError); + ASSERT_TRUE(dir.is_valid()); + for (size_t x = 0; x < SafeFD::kDefaultMaxPathDepth; ++x) { + std::tie(dir, err) = dir.MakeDir(base::FilePath(kSubdirName)); + EXPECT_EQ(err, SafeFD::Error::kNoError); + ASSERT_TRUE(dir.is_valid()); + } + + // Check if recursive Rmdir succeeds (i.e. there isn't a stack overflow). + std::tie(dir, err) = root_.OpenExistingDir(temp_dir_.GetPath()); + EXPECT_EQ(err, SafeFD::Error::kNoError); + + EXPECT_EQ(dir.Rmdir(kSubdirName, true /*recursive*/), + SafeFD::Error::kNoError); + EXPECT_FALSE(base::PathExists(file_path_)); + EXPECT_FALSE(base::PathExists(sub_dir_path_)); +} + +TEST_F(SafeFDTest, Rmdir_NotInitialized) { + ASSERT_TRUE(WriteFile("")); + + EXPECT_EQ(SafeFD().Rmdir(kSubdirName, true /*recursive*/), + SafeFD::Error::kNotInitialized); +} + +TEST_F(SafeFDTest, Rmdir_BadArgument) { + ASSERT_TRUE(WriteFile("")); + + SafeFD::SafeFDResult dir = root_.OpenExistingDir(temp_dir_.GetPath()); + EXPECT_EQ(dir.second, SafeFD::Error::kNoError); + + EXPECT_EQ(dir.first.Rmdir("a/a"), SafeFD::Error::kBadArgument); + EXPECT_EQ(dir.first.Rmdir("."), SafeFD::Error::kBadArgument); + EXPECT_EQ(dir.first.Rmdir(".."), SafeFD::Error::kBadArgument); +} + +TEST_F(SafeFDTest, Rmdir_ExceededMaximum) { + ASSERT_TRUE(SetupSubdir()); + ASSERT_TRUE(base::CreateDirectory(sub_dir_path_.Append(kSubdirName))); + + SafeFD::SafeFDResult dir = root_.OpenExistingDir(temp_dir_.GetPath()); + ASSERT_EQ(dir.second, SafeFD::Error::kNoError); + + EXPECT_EQ(dir.first.Rmdir(kSubdirName, true /*recursive*/, 1), + SafeFD::Error::kExceededMaximum); +} + +TEST_F(SafeFDTest, Rmdir_IOError) { + SafeFD::SafeFDResult dir = root_.OpenExistingDir(temp_dir_.GetPath()); + ASSERT_EQ(dir.second, SafeFD::Error::kNoError); + + // Dir doesn't exist. + EXPECT_EQ(dir.first.Rmdir(kSubdirName), SafeFD::Error::kIOError); + + // Dir not empty. + ASSERT_TRUE(WriteFile("")); + EXPECT_EQ(dir.first.Rmdir(kSubdirName), SafeFD::Error::kIOError); +} + +TEST_F(SafeFDTest, Rmdir_WrongType) { + ASSERT_TRUE(WriteFile("")); + + SafeFD::SafeFDResult subdir = root_.OpenExistingDir(sub_dir_path_); + ASSERT_EQ(subdir.second, SafeFD::Error::kNoError); + + EXPECT_EQ(subdir.first.Rmdir(kFileName), SafeFD::Error::kWrongType); +} + +} // namespace brillo diff --git a/brillo/files/scoped_dir.h b/brillo/files/scoped_dir.h new file mode 100644 index 0000000..b4ca6a9 --- /dev/null +++ b/brillo/files/scoped_dir.h @@ -0,0 +1,36 @@ +// Copyright 2019 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef LIBBRILLO_BRILLO_FILES_SCOPED_DIR_H_ +#define LIBBRILLO_BRILLO_FILES_SCOPED_DIR_H_ + +#include <dirent.h> + +#include <base/scoped_generic.h> + +#define HANDLE_EINTR_IF_EQ(x, val) \ + ({ \ + decltype(x) eintr_wrapper_result; \ + do { \ + eintr_wrapper_result = (x); \ + } while (eintr_wrapper_result == (val) && errno == EINTR); \ + eintr_wrapper_result; \ + }) + +namespace brillo { + +struct ScopedDIRCloseTraits { + static DIR* InvalidValue() { return nullptr; } + static void Free(DIR* dir) { + if (dir != nullptr) { + closedir(dir); + } + } +}; + +typedef base::ScopedGeneric<DIR*, ScopedDIRCloseTraits> ScopedDIR; + +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_FILES_SCOPED_DIR_H_ diff --git a/brillo/flag_helper.cc b/brillo/flag_helper.cc index bb51818..065a1c7 100644 --- a/brillo/flag_helper.cc +++ b/brillo/flag_helper.cc @@ -4,12 +4,14 @@ #include "brillo/flag_helper.h" -#include <memory> #include <stdio.h> #include <stdlib.h> -#include <string> #include <sysexits.h> +#include <memory> +#include <string> +#include <utility> + #include <base/base_switches.h> #include <base/command_line.h> #include <base/logging.h> diff --git a/brillo/flag_helper_unittest.cc b/brillo/flag_helper_test.cc index 29c6429..29c6429 100644 --- a/brillo/flag_helper_unittest.cc +++ b/brillo/flag_helper_test.cc diff --git a/brillo/glib/dbus.h b/brillo/glib/dbus.h index 7a28480..0e756bf 100644 --- a/brillo/glib/dbus.h +++ b/brillo/glib/dbus.h @@ -13,6 +13,7 @@ #include <algorithm> #include <string> +#include <utility> #include "base/logging.h" #include <brillo/brillo_export.h> diff --git a/brillo/glib/object.h b/brillo/glib/object.h index 15de52c..56d38a4 100644 --- a/brillo/glib/object.h +++ b/brillo/glib/object.h @@ -8,13 +8,14 @@ #include <glib-object.h> #include <stdint.h> -#include <base/logging.h> -#include <base/macros.h> - #include <algorithm> #include <cstddef> #include <memory> #include <string> +#include <utility> + +#include <base/logging.h> +#include <base/macros.h> namespace brillo { diff --git a/brillo/glib/object_unittest.cc b/brillo/glib/object_test.cc index a1ed408..a1ed408 100644 --- a/brillo/glib/object_unittest.cc +++ b/brillo/glib/object_test.cc diff --git a/brillo/http/http_connection_curl.cc b/brillo/http/http_connection_curl.cc index 3720330..6f1b3ed 100644 --- a/brillo/http/http_connection_curl.cc +++ b/brillo/http/http_connection_curl.cc @@ -4,6 +4,8 @@ #include <brillo/http/http_connection_curl.h> +#include <utility> + #include <base/logging.h> #include <brillo/http/http_request.h> #include <brillo/http/http_transport_curl.h> diff --git a/brillo/http/http_connection_curl.h b/brillo/http/http_connection_curl.h index c34de57..81008e1 100644 --- a/brillo/http/http_connection_curl.h +++ b/brillo/http/http_connection_curl.h @@ -6,6 +6,7 @@ #define LIBBRILLO_BRILLO_HTTP_HTTP_CONNECTION_CURL_H_ #include <map> +#include <memory> #include <string> #include <vector> diff --git a/brillo/http/http_connection_curl_unittest.cc b/brillo/http/http_connection_curl_test.cc index 90a5626..d908ac0 100644 --- a/brillo/http/http_connection_curl_unittest.cc +++ b/brillo/http/http_connection_curl_test.cc @@ -6,6 +6,7 @@ #include <algorithm> #include <set> +#include <utility> #include <base/callback.h> #include <brillo/http/http_request.h> diff --git a/brillo/http/http_connection_fake.cc b/brillo/http/http_connection_fake.cc index 15e5181..dbd9f90 100644 --- a/brillo/http/http_connection_fake.cc +++ b/brillo/http/http_connection_fake.cc @@ -4,8 +4,8 @@ #include <brillo/http/http_connection_fake.h> +#include <base/bind.h> #include <base/logging.h> -#include <brillo/bind_lambda.h> #include <brillo/http/http_request.h> #include <brillo/mime_utils.h> #include <brillo/streams/memory_stream.h> diff --git a/brillo/http/http_connection_fake.h b/brillo/http/http_connection_fake.h index a6ebeee..402d6f9 100644 --- a/brillo/http/http_connection_fake.h +++ b/brillo/http/http_connection_fake.h @@ -6,7 +6,9 @@ #define LIBBRILLO_BRILLO_HTTP_HTTP_CONNECTION_FAKE_H_ #include <map> +#include <memory> #include <string> +#include <utility> #include <vector> #include <base/macros.h> diff --git a/brillo/http/http_form_data.cc b/brillo/http/http_form_data.cc index 4d8f6f0..eb1d028 100644 --- a/brillo/http/http_form_data.cc +++ b/brillo/http/http_form_data.cc @@ -5,10 +5,12 @@ #include <brillo/http/http_form_data.h> #include <limits> +#include <utility> #include <base/format_macros.h> #include <base/rand_util.h> #include <base/strings/stringprintf.h> +#include <base/strings/string_util.h> #include <brillo/errors/error_codes.h> #include <brillo/http/http_transport.h> @@ -141,8 +143,18 @@ bool MultiPartFormField::ExtractDataStreams(std::vector<StreamPtr>* streams) { } std::string MultiPartFormField::GetContentType() const { + // Quote the boundary only if it has non-alphanumeric chars in it. + // https://www.w3.org/Protocols/rfc1341/7_2_Multipart.html + bool use_quotes = false; + for (auto ch : boundary_) { + if (!base::IsAsciiAlpha(ch) && !base::IsAsciiDigit(ch)) { + use_quotes = true; + break; + } + } return base::StringPrintf( - "%s; boundary=\"%s\"", content_type_.c_str(), boundary_.c_str()); + use_quotes ? "%s; boundary=\"%s\"" : "%s; boundary=%s", + content_type_.c_str(), boundary_.c_str()); } void MultiPartFormField::AddCustomField(std::unique_ptr<FormField> field) { @@ -180,7 +192,7 @@ std::string MultiPartFormField::GetBoundaryStart() const { } std::string MultiPartFormField::GetBoundaryEnd() const { - return base::StringPrintf("--%s--", boundary_.c_str()); + return base::StringPrintf("--%s--\r\n", boundary_.c_str()); } FormData::FormData() : FormData{std::string{}} { diff --git a/brillo/http/http_form_data_fuzzer.cc b/brillo/http/http_form_data_fuzzer.cc new file mode 100644 index 0000000..f73a89f --- /dev/null +++ b/brillo/http/http_form_data_fuzzer.cc @@ -0,0 +1,128 @@ +// Copyright 2019 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <stddef.h> +#include <stdint.h> + +#include <base/files/file_path.h> +#include <base/files/file_util.h> +#include <base/files/scoped_temp_dir.h> +#include <base/logging.h> +#include <brillo/http/http_form_data.h> +#include <brillo/streams/memory_stream.h> +#include <fuzzer/FuzzedDataProvider.h> + +namespace { +constexpr int kRandomDataMaxLength = 64; +constexpr int kMaxRecursionDepth = 256; + +std::unique_ptr<brillo::http::TextFormField> CreateTextFormField( + FuzzedDataProvider* data_provider) { + return std::make_unique<brillo::http::TextFormField>( + data_provider->ConsumeRandomLengthString(kRandomDataMaxLength), + data_provider->ConsumeRandomLengthString(kRandomDataMaxLength), + data_provider->ConsumeRandomLengthString(kRandomDataMaxLength), + data_provider->ConsumeRandomLengthString(kRandomDataMaxLength)); +} + +std::unique_ptr<brillo::http::FileFormField> CreateFileFormField( + FuzzedDataProvider* data_provider) { + brillo::StreamPtr mem_stream = brillo::MemoryStream::OpenCopyOf( + data_provider->ConsumeRandomLengthString(kRandomDataMaxLength), nullptr); + return std::make_unique<brillo::http::FileFormField>( + data_provider->ConsumeRandomLengthString(kRandomDataMaxLength), + std::move(mem_stream), + data_provider->ConsumeRandomLengthString(kRandomDataMaxLength), + data_provider->ConsumeRandomLengthString(kRandomDataMaxLength), + data_provider->ConsumeRandomLengthString(kRandomDataMaxLength), + data_provider->ConsumeRandomLengthString(kRandomDataMaxLength)); +} + +std::unique_ptr<brillo::http::MultiPartFormField> CreateMultipartFormField( + FuzzedDataProvider* data_provider, int depth) { + std::unique_ptr<brillo::http::MultiPartFormField> multipart_field = + std::make_unique<brillo::http::MultiPartFormField>( + data_provider->ConsumeRandomLengthString(kRandomDataMaxLength), + data_provider->ConsumeRandomLengthString(kRandomDataMaxLength), + data_provider->ConsumeRandomLengthString(kRandomDataMaxLength)); + + // Randomly add fields to this like we do the base FormData, but don't loop + // forever. + while (data_provider->ConsumeBool()) { + if (data_provider->ConsumeBool()) { + // Add a random text field to the form. + multipart_field->AddCustomField(CreateTextFormField(data_provider)); + } + if (data_provider->ConsumeBool()) { + // Add a random file field to the form. + multipart_field->AddCustomField(CreateFileFormField(data_provider)); + } + // Limit our recursion depth. We could make this part of our code iterative, + // but that won't help because in libbrillo we use recursion to generate the + // stream so we would hit a stack depth limit there as well. + if (depth < kMaxRecursionDepth && data_provider->ConsumeBool()) { + // Add a random multipart form field to the form. + multipart_field->AddCustomField( + CreateMultipartFormField(data_provider, depth + 1)); + } + } + + return multipart_field; +} + +} // namespace + +bool IgnoreLogging(int, const char*, int, size_t, const std::string&) { + return true; +} + +class Environment { + public: + Environment() { + // Disable logging. Normally this would be done with logging::SetMinLogLevel + // but that doesn't work for brillo::Error for because it's not using the + // LOG(ERROR) macro which is where the actual log level check occurs. + logging::SetLogMessageHandler(&IgnoreLogging); + } +}; + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + static Environment env; + FuzzedDataProvider data_provider(data, size); + // Randomly add a bunch of fields to the FormData and then when done extract + // and consume the data stream. + brillo::http::FormData form_data( + data_provider.ConsumeRandomLengthString(kRandomDataMaxLength)); + while (data_provider.remaining_bytes() > 0) { + if (data_provider.ConsumeBool()) { + // Add a random text field to the form. + form_data.AddCustomField(CreateTextFormField(&data_provider)); + } + if (data_provider.ConsumeBool()) { + // Add a random file field to the form. + form_data.AddCustomField(CreateFileFormField(&data_provider)); + } + if (data_provider.ConsumeBool()) { + // Add a random multipart form field to the form. + form_data.AddCustomField(CreateMultipartFormField(&data_provider, 0)); + } + } + + brillo::StreamPtr form_stream = form_data.ExtractDataStream(); + if (!form_stream) + return 0; + + // We need to use a decent sized buffer and call ReadAllBlocking to avoid + // excess overhead with reading here that can make the fuzzer timeout. + uint8_t buffer[32768]; + while (form_stream->GetRemainingSize() > 0) { + if (!form_stream->ReadAllBlocking(buffer, sizeof(buffer), nullptr)) { + // If there's an error reading from the stream, then bail since we'd + // likely just see repeated errors and never exit. + break; + } + } + + return 0; +} diff --git a/brillo/http/http_form_data_unittest.cc b/brillo/http/http_form_data_test.cc index 34288d0..80bf30a 100644 --- a/brillo/http/http_form_data_unittest.cc +++ b/brillo/http/http_form_data_test.cc @@ -5,6 +5,7 @@ #include <brillo/http/http_form_data.h> #include <set> +#include <utility> #include <base/files/file_util.h> #include <base/files/scoped_temp_dir.h> @@ -94,7 +95,7 @@ TEST(HttpFormData, MultiPartFormField) { nullptr)); const char expected_header[] = "Content-Disposition: form-data; name=\"foo\"\r\n" - "Content-Type: multipart/form-data; boundary=\"Delimiter\"\r\n" + "Content-Type: multipart/form-data; boundary=Delimiter\r\n" "\r\n"; EXPECT_EQ(expected_header, form_field.GetContentHeader()); const char expected_data[] = @@ -116,7 +117,7 @@ TEST(HttpFormData, MultiPartFormField) { "Content-Transfer-Encoding: binary\r\n" "\r\n" "\x01\x02\x03\x04\x05\r\n" - "--Delimiter--"; + "--Delimiter--\r\n"; EXPECT_EQ(expected_data, GetFormFieldData(&form_field)); } @@ -158,7 +159,7 @@ TEST(HttpFormData, FormData) { FormData form_data{"boundary1"}; form_data.AddTextField("name", "John Doe"); std::unique_ptr<MultiPartFormField> files{ - new MultiPartFormField{"files", "", "boundary2"}}; + new MultiPartFormField{"files", "", "boundary 2"}}; EXPECT_TRUE(files->AddFileField( "", filename1, content_disposition::kFile, mime::text::kPlain, nullptr)); EXPECT_TRUE(files->AddFileField("", @@ -167,7 +168,7 @@ TEST(HttpFormData, FormData) { mime::application::kOctet_stream, nullptr)); form_data.AddCustomField(std::move(files)); - EXPECT_EQ("multipart/form-data; boundary=\"boundary1\"", + EXPECT_EQ("multipart/form-data; boundary=boundary1", form_data.GetContentType()); StreamPtr stream = form_data.ExtractDataStream(); @@ -180,22 +181,22 @@ TEST(HttpFormData, FormData) { "John Doe\r\n" "--boundary1\r\n" "Content-Disposition: form-data; name=\"files\"\r\n" - "Content-Type: multipart/mixed; boundary=\"boundary2\"\r\n" + "Content-Type: multipart/mixed; boundary=\"boundary 2\"\r\n" "\r\n" - "--boundary2\r\n" + "--boundary 2\r\n" "Content-Disposition: file; filename=\"sample.txt\"\r\n" "Content-Type: text/plain\r\n" "Content-Transfer-Encoding: binary\r\n" "\r\n" "text line1\ntext line2\n\r\n" - "--boundary2\r\n" + "--boundary 2\r\n" "Content-Disposition: file; filename=\"test.bin\"\r\n" "Content-Type: application/octet-stream\r\n" "Content-Transfer-Encoding: binary\r\n" "\r\n" "\x01\x02\x03\x04\x05\r\n" - "--boundary2--\r\n" - "--boundary1--"; + "--boundary 2--\r\n\r\n" + "--boundary1--\r\n"; EXPECT_EQ(expected_data, (std::string{data.begin(), data.end()})); } } // namespace http diff --git a/brillo/http/http_proxy.cc b/brillo/http/http_proxy.cc index bf6a8af..b697518 100644 --- a/brillo/http/http_proxy.cc +++ b/brillo/http/http_proxy.cc @@ -6,6 +6,7 @@ #include <memory> #include <string> +#include <utility> #include <vector> #include <base/bind.h> diff --git a/brillo/http/http_proxy.h b/brillo/http/http_proxy.h index c142af2..46863b6 100644 --- a/brillo/http/http_proxy.h +++ b/brillo/http/http_proxy.h @@ -32,13 +32,13 @@ using GetChromeProxyServersCallback = // Even if this function returns false, it will still set |proxies_out| to be // just the direct proxy. This function will only return false if there is an // error in the D-Bus communication itself. -BRILLO_EXPORT bool GetChromeProxyServers(scoped_refptr<dbus::Bus> bus, +BRILLO_EXPORT bool GetChromeProxyServers(scoped_refptr<::dbus::Bus> bus, const std::string& url, std::vector<std::string>* proxies_out); // Async version of GetChromeProxyServers. BRILLO_EXPORT void GetChromeProxyServersAsync( - scoped_refptr<dbus::Bus> bus, + scoped_refptr<::dbus::Bus> bus, const std::string& url, const GetChromeProxyServersCallback& callback); diff --git a/brillo/http/http_proxy_unittest.cc b/brillo/http/http_proxy_test.cc index 4893a87..eb44263 100644 --- a/brillo/http/http_proxy_unittest.cc +++ b/brillo/http/http_proxy_test.cc @@ -4,7 +4,9 @@ #include <brillo/http/http_proxy.h> +#include <memory> #include <string> +#include <utility> #include <vector> #include <base/bind.h> @@ -30,25 +32,28 @@ class HttpProxyTest : public testing::Test { public: void ResolveProxyHandlerAsync(dbus::MethodCall* method_call, int timeout_msec, - dbus::ObjectProxy::ResponseCallback callback) { + dbus::ObjectProxy::ResponseCallback + MIGRATE_WrapObjectProxyCallback(callback)) { if (null_dbus_response_) { - callback.Run(nullptr); + std::move(MIGRATE_WrapObjectProxyCallback(callback)).Run(nullptr); return; } - callback.Run(CreateDBusResponse(method_call).get()); + std::move(MIGRATE_WrapObjectProxyCallback(callback)) + .Run(CreateDBusResponse(method_call).get()); } - dbus::Response* ResolveProxyHandler(dbus::MethodCall* method_call, - int timeout_msec) { + MIGRATE_WrapObjectProxyResponseType(dbus::Response) + ResolveProxyHandler(dbus::MethodCall* method_call, int timeout_msec) { if (null_dbus_response_) { - return nullptr; + return MIGRATE_WrapObjectProxyResponseEmpty; } - // The mock wraps this back into a std::unique_ptr in the function calling - // us. - return CreateDBusResponse(method_call).release(); + return MIGRATE_WrapObjectProxyResponseConversion( + CreateDBusResponse(method_call)); } - MOCK_METHOD2(GetProxiesCallback, void(bool, const std::vector<std::string>&)); + MOCK_METHOD(void, + GetProxiesCallback, + (bool, const std::vector<std::string>&)); protected: HttpProxyTest() { @@ -97,7 +102,7 @@ class HttpProxyTest : public testing::Test { TEST_F(HttpProxyTest, DBusNullResponseFails) { std::vector<std::string> proxies; null_dbus_response_ = true; - EXPECT_CALL(*object_proxy_, MockCallMethodAndBlock(_, _)) + EXPECT_CALL(*object_proxy_, MIGRATE_MockCallMethodAndBlock(_, _)) .WillOnce(Invoke(this, &HttpProxyTest::ResolveProxyHandler)); EXPECT_FALSE(GetChromeProxyServers(bus_, kTestUrl, &proxies)); } @@ -105,14 +110,14 @@ TEST_F(HttpProxyTest, DBusNullResponseFails) { TEST_F(HttpProxyTest, DBusInvalidResponseFails) { std::vector<std::string> proxies; invalid_dbus_response_ = true; - EXPECT_CALL(*object_proxy_, MockCallMethodAndBlock(_, _)) + EXPECT_CALL(*object_proxy_, MIGRATE_MockCallMethodAndBlock(_, _)) .WillOnce(Invoke(this, &HttpProxyTest::ResolveProxyHandler)); EXPECT_FALSE(GetChromeProxyServers(bus_, kTestUrl, &proxies)); } TEST_F(HttpProxyTest, NoProxies) { std::vector<std::string> proxies; - EXPECT_CALL(*object_proxy_, MockCallMethodAndBlock(_, _)) + EXPECT_CALL(*object_proxy_, MIGRATE_MockCallMethodAndBlock(_, _)) .WillOnce(Invoke(this, &HttpProxyTest::ResolveProxyHandler)); EXPECT_TRUE(GetChromeProxyServers(bus_, kTestUrl, &proxies)); EXPECT_THAT(proxies, ElementsAre(kDirectProxy)); @@ -121,7 +126,7 @@ TEST_F(HttpProxyTest, NoProxies) { TEST_F(HttpProxyTest, MultipleProxiesWithoutDirect) { proxy_info_ = "proxy example.com; socks5 foo.com;"; std::vector<std::string> proxies; - EXPECT_CALL(*object_proxy_, MockCallMethodAndBlock(_, _)) + EXPECT_CALL(*object_proxy_, MIGRATE_MockCallMethodAndBlock(_, _)) .WillOnce(Invoke(this, &HttpProxyTest::ResolveProxyHandler)); EXPECT_TRUE(GetChromeProxyServers(bus_, kTestUrl, &proxies)); EXPECT_THAT(proxies, ElementsAre("http://example.com", "socks5://foo.com", @@ -132,7 +137,7 @@ TEST_F(HttpProxyTest, MultipleProxiesWithDirect) { proxy_info_ = "socks foo.com; Https example.com ; badproxy example2.com ; " "socks5 test.com ; proxy foobar.com; DIRECT "; std::vector<std::string> proxies; - EXPECT_CALL(*object_proxy_, MockCallMethodAndBlock(_, _)) + EXPECT_CALL(*object_proxy_, MIGRATE_MockCallMethodAndBlock(_, _)) .WillOnce(Invoke(this, &HttpProxyTest::ResolveProxyHandler)); EXPECT_TRUE(GetChromeProxyServers(bus_, kTestUrl, &proxies)); EXPECT_THAT(proxies, ElementsAre("socks4://foo.com", "https://example.com", @@ -142,7 +147,7 @@ TEST_F(HttpProxyTest, MultipleProxiesWithDirect) { TEST_F(HttpProxyTest, DBusNullResponseFailsAsync) { null_dbus_response_ = true; - EXPECT_CALL(*object_proxy_, CallMethod(_, _, _)) + EXPECT_CALL(*object_proxy_, MIGRATE_CallMethod(_, _, _)) .WillOnce(Invoke(this, &HttpProxyTest::ResolveProxyHandlerAsync)); EXPECT_CALL(*this, GetProxiesCallback(false, _)); GetChromeProxyServersAsync( @@ -152,7 +157,7 @@ TEST_F(HttpProxyTest, DBusNullResponseFailsAsync) { TEST_F(HttpProxyTest, DBusInvalidResponseFailsAsync) { invalid_dbus_response_ = true; - EXPECT_CALL(*object_proxy_, CallMethod(_, _, _)) + EXPECT_CALL(*object_proxy_, MIGRATE_CallMethod(_, _, _)) .WillOnce(Invoke(this, &HttpProxyTest::ResolveProxyHandlerAsync)); EXPECT_CALL(*this, GetProxiesCallback(false, _)); GetChromeProxyServersAsync( @@ -168,7 +173,7 @@ TEST_F(HttpProxyTest, MultipleProxiesWithDirectAsync) { std::vector<std::string> expected = { "socks4://foo.com", "https://example.com", "socks5://test.com", "http://foobar.com", kDirectProxy}; - EXPECT_CALL(*object_proxy_, CallMethod(_, _, _)) + EXPECT_CALL(*object_proxy_, MIGRATE_CallMethod(_, _, _)) .WillOnce(Invoke(this, &HttpProxyTest::ResolveProxyHandlerAsync)); EXPECT_CALL(*this, GetProxiesCallback(true, expected)); GetChromeProxyServersAsync( diff --git a/brillo/http/http_request_unittest.cc b/brillo/http/http_request_test.cc index 39ccc18..e0be38b 100644 --- a/brillo/http/http_request_unittest.cc +++ b/brillo/http/http_request_test.cc @@ -6,8 +6,8 @@ #include <string> +#include <base/bind.h> #include <base/callback.h> -#include <brillo/bind_lambda.h> #include <brillo/http/mock_connection.h> #include <brillo/http/mock_transport.h> #include <brillo/mime_utils.h> diff --git a/brillo/http/http_transport.cc b/brillo/http/http_transport.cc index 0c27489..d713e50 100644 --- a/brillo/http/http_transport.cc +++ b/brillo/http/http_transport.cc @@ -26,5 +26,32 @@ std::shared_ptr<Transport> Transport::CreateDefaultWithProxy( } } +base::FilePath Transport::CertificateToPath(Transport::Certificate cert) { + const char* str; + switch (cert) { + case Certificate::kDefault: + str = +#ifdef __ANDROID__ + "/system/etc/security/cacerts_google"; +#else + "/usr/share/chromeos-ca-certificates"; +#endif + break; + case Certificate::kHermesProd: + str = "/usr/share/hermes-ca-certificates/prod"; + break; + case Certificate::kHermesTest: + str = "/usr/share/hermes-ca-certificates/test"; + break; + case Certificate::kNss: + str = "/etc/ssl/certs"; + break; + default: + CHECK(false) << "Invalid certificate"; + break; + } + return base::FilePath(str); +} + } // namespace http } // namespace brillo diff --git a/brillo/http/http_transport.h b/brillo/http/http_transport.h index e00166c..76ff901 100644 --- a/brillo/http/http_transport.h +++ b/brillo/http/http_transport.h @@ -11,6 +11,7 @@ #include <vector> #include <base/callback_forward.h> +#include <base/files/file_path.h> #include <base/location.h> #include <base/macros.h> #include <base/time/time.h> @@ -38,10 +39,26 @@ using ErrorCallback = base::Callback<void(RequestID, const brillo::Error*)>; /////////////////////////////////////////////////////////////////////////////// // Transport is a base class for specific implementation of HTTP communication. // This class (and its underlying implementation) is used by http::Request and -// http::Response classes to provide HTTP functionality to the clients. +// http::Response classes to provide HTTP functionality to the clients. By +// default, this interface will use CA certificates that only allow secure +// (HTTPS) communication with Google services. /////////////////////////////////////////////////////////////////////////////// class BRILLO_EXPORT Transport : public std::enable_shared_from_this<Transport> { public: + enum class Certificate { + // Default certificate; only allows communication with Google services. + kDefault, + // Certificates for communicating only with production SM-DP+ and SM-DS + // servers. + kHermesProd, + // Certificates for communicating only with test SM-DP+ and SM-DS servers. + kHermesTest, + // The NSS certificate store, which the curl command-line tool and libcurl + // library use by default. This set of certificates does not restrict + // secure communication to only Google services. + kNss, + }; + Transport() = default; virtual ~Transport() = default; @@ -87,6 +104,28 @@ class BRILLO_EXPORT Transport : public std::enable_shared_from_this<Transport> { // Set the local IP address of requests virtual void SetLocalIpAddress(const std::string& ip_address) = 0; + // Use the default CA certificate for certificate verification. This + // means that clients are only allowed to communicate with Google services. + virtual void UseDefaultCertificate() {} + + // Set the CA certificate to use for certificate verification. + // + // This call can allow a client to securly communicate with a different subset + // of services than it can otherwise. However, setting a custom certificate + // should be done only when necessary, and should be done with careful control + // over the certificates that are contained in the relevant path. See + // https://chromium.googlesource.com/chromiumos/docs/+/master/ca_certs.md for + // more information on certificates in Chrome OS. + virtual void UseCustomCertificate(Transport::Certificate cert) {} + + // Appends host entry to DNS cache. curl can only do HTTPS request to a custom + // IP if it resolves an HTTPS hostname to that IP. This is useful in + // forcing a particular mapping for an HTTPS host. See CURLOPT_RESOLVE for + // more details. + virtual void ResolveHostToIp(const std::string& host, + uint16_t port, + const std::string& ip_address) {} + // Creates a default http::Transport (currently, using http::curl::Transport). static std::shared_ptr<Transport> CreateDefault(); @@ -97,6 +136,12 @@ class BRILLO_EXPORT Transport : public std::enable_shared_from_this<Transport> { static std::shared_ptr<Transport> CreateDefaultWithProxy( const std::string& proxy); + protected: + // Clears the forced DNS mappings created by ResolveHostToIp. + virtual void ClearHost() {} + + static base::FilePath CertificateToPath(Certificate cert); + private: DISALLOW_COPY_AND_ASSIGN(Transport); }; diff --git a/brillo/http/http_transport_curl.cc b/brillo/http/http_transport_curl.cc index 9affc2a..45a28a3 100644 --- a/brillo/http/http_transport_curl.cc +++ b/brillo/http/http_transport_curl.cc @@ -7,23 +7,14 @@ #include <limits> #include <base/bind.h> +#include <base/files/file_util.h> #include <base/logging.h> #include <base/message_loop/message_loop.h> +#include <base/strings/stringprintf.h> #include <brillo/http/http_connection_curl.h> #include <brillo/http/http_request.h> #include <brillo/strings/string_utils.h> -namespace { - -const char kCACertificatePath[] = -#ifdef __ANDROID__ - "/system/etc/security/cacerts_google"; -#else - "/usr/share/brillo-ca-certificates"; -#endif - -} // namespace - namespace brillo { namespace http { namespace curl { @@ -101,15 +92,18 @@ struct Transport::AsyncRequestData { Transport::Transport(const std::shared_ptr<CurlInterface>& curl_interface) : curl_interface_{curl_interface} { VLOG(2) << "curl::Transport created"; + UseDefaultCertificate(); } Transport::Transport(const std::shared_ptr<CurlInterface>& curl_interface, const std::string& proxy) : curl_interface_{curl_interface}, proxy_{proxy} { VLOG(2) << "curl::Transport created with proxy " << proxy; + UseDefaultCertificate(); } Transport::~Transport() { + ClearHost(); ShutDownAsyncCurl(); VLOG(2) << "curl::Transport destroyed"; } @@ -134,8 +128,14 @@ std::shared_ptr<http::Connection> Transport::CreateConnection( CURLcode code = curl_interface_->EasySetOptStr(curl_handle, CURLOPT_URL, url); if (code == CURLE_OK) { + // CURLOPT_CAINFO is a string, but CurlApi::EasySetOptStr will never pass + // curl_easy_setopt a null pointer, so we use EasySetOptPtr instead. + code = curl_interface_->EasySetOptPtr(curl_handle, CURLOPT_CAINFO, nullptr); + } + if (code == CURLE_OK) { + CHECK(base::PathExists(certificate_path_)); code = curl_interface_->EasySetOptStr(curl_handle, CURLOPT_CAPATH, - kCACertificatePath); + certificate_path_.value()); } if (code == CURLE_OK) { code = @@ -169,6 +169,10 @@ std::shared_ptr<http::Connection> Transport::CreateConnection( code = curl_interface_->EasySetOptStr( curl_handle, CURLOPT_INTERFACE, ip_address_.c_str()); } + if (code == CURLE_OK && host_list_) { + code = curl_interface_->EasySetOptPtr(curl_handle, CURLOPT_RESOLVE, + host_list_); + } // Setup HTTP request method and optional request body. if (code == CURLE_OK) { @@ -274,6 +278,29 @@ void Transport::SetLocalIpAddress(const std::string& ip_address) { ip_address_ = "host!" + ip_address; } +void Transport::UseDefaultCertificate() { + UseCustomCertificate(Certificate::kDefault); +} + +void Transport::UseCustomCertificate(Transport::Certificate cert) { + certificate_path_ = CertificateToPath(cert); + CHECK(base::PathExists(certificate_path_)); +} + +void Transport::ResolveHostToIp(const std::string& host, + uint16_t port, + const std::string& ip_address) { + host_list_ = curl_slist_append( + host_list_, + base::StringPrintf("%s:%d:%s", host.c_str(), port, ip_address.c_str()) + .c_str()); +} + +void Transport::ClearHost() { + curl_slist_free_all(host_list_); + host_list_ = nullptr; +} + void Transport::AddEasyCurlError(brillo::ErrorPtr* error, const base::Location& location, CURLcode code, diff --git a/brillo/http/http_transport_curl.h b/brillo/http/http_transport_curl.h index 175a675..5af2c61 100644 --- a/brillo/http/http_transport_curl.h +++ b/brillo/http/http_transport_curl.h @@ -6,9 +6,11 @@ #define LIBBRILLO_BRILLO_HTTP_HTTP_TRANSPORT_CURL_H_ #include <map> +#include <memory> #include <string> #include <utility> +#include <base/location.h> #include <base/memory/weak_ptr.h> #include <brillo/brillo_export.h> #include <brillo/http/curl_api.h> @@ -61,6 +63,14 @@ class BRILLO_EXPORT Transport : public http::Transport { void SetLocalIpAddress(const std::string& ip_address) override; + void UseDefaultCertificate() override; + + void UseCustomCertificate(Certificate cert) override; + + void ResolveHostToIp(const std::string& host, + uint16_t port, + const std::string& ip_address) override; + // Helper methods to convert CURL error codes (CURLcode and CURLMcode) // into brillo::Error object. static void AddEasyCurlError(brillo::ErrorPtr* error, @@ -73,6 +83,9 @@ class BRILLO_EXPORT Transport : public http::Transport { CURLMcode code, CurlInterface* curl_interface); + protected: + void ClearHost() override; + private: // Forward-declaration of internal implementation structures. struct AsyncRequestData; @@ -130,6 +143,8 @@ class BRILLO_EXPORT Transport : public http::Transport { // The connection timeout for the requests made. base::TimeDelta connection_timeout_; std::string ip_address_; + base::FilePath certificate_path_; + curl_slist* host_list_{nullptr}; base::WeakPtrFactory<Transport> weak_ptr_factory_for_timer_{this}; base::WeakPtrFactory<Transport> weak_ptr_factory_{this}; diff --git a/brillo/http/http_transport_curl_unittest.cc b/brillo/http/http_transport_curl_test.cc index c05c81a..40ef23e 100644 --- a/brillo/http/http_transport_curl_unittest.cc +++ b/brillo/http/http_transport_curl_test.cc @@ -5,9 +5,9 @@ #include <brillo/http/http_transport_curl.h> #include <base/at_exit.h> +#include <base/bind.h> #include <base/message_loop/message_loop.h> #include <base/run_loop.h> -#include <brillo/bind_lambda.h> #include <brillo/http/http_connection_curl.h> #include <brillo/http/http_request.h> #include <brillo/http/mock_curl_api.h> @@ -33,6 +33,8 @@ class HttpCurlTransportTest : public testing::Test { transport_ = std::make_shared<Transport>(curl_api_); handle_ = reinterpret_cast<CURL*>(100); // Mock handle value. EXPECT_CALL(*curl_api_, EasyInit()).WillOnce(Return(handle_)); + EXPECT_CALL(*curl_api_, EasySetOptPtr(handle_, CURLOPT_CAINFO, _)) + .WillOnce(Return(CURLE_OK)); EXPECT_CALL(*curl_api_, EasySetOptStr(handle_, CURLOPT_CAPATH, _)) .WillOnce(Return(CURLE_OK)); EXPECT_CALL(*curl_api_, EasySetOptInt(handle_, CURLOPT_SSL_VERIFYPEER, 1)) @@ -197,6 +199,8 @@ class HttpCurlTransportAsyncTest : public testing::Test { curl_api_ = std::make_shared<MockCurlInterface>(); transport_ = std::make_shared<Transport>(curl_api_); EXPECT_CALL(*curl_api_, EasyInit()).WillOnce(Return(handle_)); + EXPECT_CALL(*curl_api_, EasySetOptPtr(handle_, CURLOPT_CAINFO, _)) + .WillOnce(Return(CURLE_OK)); EXPECT_CALL(*curl_api_, EasySetOptStr(handle_, CURLOPT_CAPATH, _)) .WillOnce(Return(CURLE_OK)); EXPECT_CALL(*curl_api_, EasySetOptInt(handle_, CURLOPT_SSL_VERIFYPEER, 1)) @@ -333,6 +337,23 @@ TEST_F(HttpCurlTransportTest, RequestGetTimeout) { connection.reset(); } +TEST_F(HttpCurlTransportTest, RequestGetResolveHost) { + transport_->ResolveHostToIp("foo.bar", 80, "127.0.0.1"); + EXPECT_CALL(*curl_api_, + EasySetOptStr(handle_, CURLOPT_URL, "http://foo.bar/get")) + .WillOnce(Return(CURLE_OK)); + EXPECT_CALL(*curl_api_, EasySetOptPtr(handle_, CURLOPT_RESOLVE, _)) + .WillOnce(Return(CURLE_OK)); + EXPECT_CALL(*curl_api_, EasySetOptInt(handle_, CURLOPT_HTTPGET, 1)) + .WillOnce(Return(CURLE_OK)); + auto connection = transport_->CreateConnection( + "http://foo.bar/get", request_type::kGet, {}, "", "", nullptr); + EXPECT_NE(nullptr, connection.get()); + + EXPECT_CALL(*curl_api_, EasyCleanup(handle_)).Times(1); + connection.reset(); +} + } // namespace curl } // namespace http } // namespace brillo diff --git a/brillo/http/http_transport_fake.cc b/brillo/http/http_transport_fake.cc index 224b5de..c4757f9 100644 --- a/brillo/http/http_transport_fake.cc +++ b/brillo/http/http_transport_fake.cc @@ -6,10 +6,10 @@ #include <utility> +#include <base/bind.h> #include <base/json/json_reader.h> #include <base/json/json_writer.h> #include <base/logging.h> -#include <brillo/bind_lambda.h> #include <brillo/http/http_connection_fake.h> #include <brillo/http/http_request.h> #include <brillo/mime_utils.h> diff --git a/brillo/http/http_transport_fake.h b/brillo/http/http_transport_fake.h index 0a2fe90..56351ec 100644 --- a/brillo/http/http_transport_fake.h +++ b/brillo/http/http_transport_fake.h @@ -6,12 +6,15 @@ #define LIBBRILLO_BRILLO_HTTP_HTTP_TRANSPORT_FAKE_H_ #include <map> +#include <memory> #include <queue> #include <string> #include <type_traits> +#include <utility> #include <vector> #include <base/callback.h> +#include <base/location.h> #include <base/values.h> #include <brillo/http/http_transport.h> #include <brillo/http/http_utils.h> @@ -104,6 +107,13 @@ class Transport : public http::Transport { void SetLocalIpAddress(const std::string& /* ip_address */) override {} + void ResolveHostToIp(const std::string& host, + uint16_t port, + const std::string& ip_address) override {} + + protected: + void ClearHost() override {} + private: // A list of user-supplied request handlers. std::map<std::string, HandlerCallback> handlers_; diff --git a/brillo/http/http_utils.h b/brillo/http/http_utils.h index e09bab8..0d4d109 100644 --- a/brillo/http/http_utils.h +++ b/brillo/http/http_utils.h @@ -5,6 +5,7 @@ #ifndef LIBBRILLO_BRILLO_HTTP_HTTP_UTILS_H_ #define LIBBRILLO_BRILLO_HTTP_HTTP_UTILS_H_ +#include <memory> #include <string> #include <utility> #include <vector> diff --git a/brillo/http/http_utils_unittest.cc b/brillo/http/http_utils_test.cc index 376ba53..409282c 100644 --- a/brillo/http/http_utils_unittest.cc +++ b/brillo/http/http_utils_test.cc @@ -6,8 +6,8 @@ #include <string> #include <vector> +#include <base/bind.h> #include <base/values.h> -#include <brillo/bind_lambda.h> #include <brillo/http/http_transport_fake.h> #include <brillo/http/http_utils.h> #include <brillo/mime_utils.h> @@ -366,7 +366,7 @@ TEST(HttpUtils, PostMultipartFormData) { "Content-Disposition: form-data; name=\"key2\"\r\n" "\r\n" "value2\r\n" - "--boundary123--"; + "--boundary123--\r\n"; EXPECT_EQ(expected_value, response->ExtractDataAsString()); } diff --git a/brillo/http/mock_connection.h b/brillo/http/mock_connection.h index 0796a7e..1810824 100644 --- a/brillo/http/mock_connection.h +++ b/brillo/http/mock_connection.h @@ -19,17 +19,22 @@ class MockConnection : public Connection { public: using Connection::Connection; - MOCK_METHOD2(SendHeaders, bool(const HeaderList&, ErrorPtr*)); - MOCK_METHOD2(MockSetRequestData, bool(Stream*, ErrorPtr*)); - MOCK_METHOD1(MockSetResponseData, void(Stream*)); - MOCK_METHOD1(FinishRequest, bool(ErrorPtr*)); - MOCK_METHOD2(FinishRequestAsync, - RequestID(const SuccessCallback&, const ErrorCallback&)); - MOCK_CONST_METHOD0(GetResponseStatusCode, int()); - MOCK_CONST_METHOD0(GetResponseStatusText, std::string()); - MOCK_CONST_METHOD0(GetProtocolVersion, std::string()); - MOCK_CONST_METHOD1(GetResponseHeader, std::string(const std::string&)); - MOCK_CONST_METHOD1(MockExtractDataStream, Stream*(brillo::ErrorPtr*)); + MOCK_METHOD(bool, SendHeaders, (const HeaderList&, ErrorPtr*), (override)); + MOCK_METHOD(bool, MockSetRequestData, (Stream*, ErrorPtr*)); + MOCK_METHOD(void, MockSetResponseData, (Stream*)); + MOCK_METHOD(bool, FinishRequest, (ErrorPtr*), (override)); + MOCK_METHOD(RequestID, + FinishRequestAsync, + (const SuccessCallback&, const ErrorCallback&), + (override)); + MOCK_METHOD(int, GetResponseStatusCode, (), (const, override)); + MOCK_METHOD(std::string, GetResponseStatusText, (), (const, override)); + MOCK_METHOD(std::string, GetProtocolVersion, (), (const, override)); + MOCK_METHOD(std::string, + GetResponseHeader, + (const std::string&), + (const, override)); + MOCK_METHOD(Stream*, MockExtractDataStream, (brillo::ErrorPtr*), (const)); private: bool SetRequestData(StreamPtr stream, brillo::ErrorPtr* error) override { diff --git a/brillo/http/mock_curl_api.h b/brillo/http/mock_curl_api.h index 32b6e0d..daac8c2 100644 --- a/brillo/http/mock_curl_api.h +++ b/brillo/http/mock_curl_api.h @@ -20,34 +20,67 @@ class MockCurlInterface : public CurlInterface { public: MockCurlInterface() = default; - MOCK_METHOD0(EasyInit, CURL*()); - MOCK_METHOD1(EasyCleanup, void(CURL*)); - MOCK_METHOD3(EasySetOptInt, CURLcode(CURL*, CURLoption, int)); - MOCK_METHOD3(EasySetOptStr, CURLcode(CURL*, CURLoption, const std::string&)); - MOCK_METHOD3(EasySetOptPtr, CURLcode(CURL*, CURLoption, void*)); - MOCK_METHOD3(EasySetOptCallback, CURLcode(CURL*, CURLoption, intptr_t)); - MOCK_METHOD3(EasySetOptOffT, CURLcode(CURL*, CURLoption, curl_off_t)); - MOCK_METHOD1(EasyPerform, CURLcode(CURL*)); - MOCK_CONST_METHOD3(EasyGetInfoInt, CURLcode(CURL*, CURLINFO, int*)); - MOCK_CONST_METHOD3(EasyGetInfoDbl, CURLcode(CURL*, CURLINFO, double*)); - MOCK_CONST_METHOD3(EasyGetInfoStr, CURLcode(CURL*, CURLINFO, std::string*)); - MOCK_CONST_METHOD3(EasyGetInfoPtr, CURLcode(CURL*, CURLINFO, void**)); - MOCK_CONST_METHOD1(EasyStrError, std::string(CURLcode)); - MOCK_METHOD0(MultiInit, CURLM*()); - MOCK_METHOD1(MultiCleanup, CURLMcode(CURLM*)); - MOCK_METHOD2(MultiInfoRead, CURLMsg*(CURLM*, int*)); - MOCK_METHOD2(MultiAddHandle, CURLMcode(CURLM*, CURL*)); - MOCK_METHOD2(MultiRemoveHandle, CURLMcode(CURLM*, CURL*)); - MOCK_METHOD3(MultiSetSocketCallback, - CURLMcode(CURLM*, curl_socket_callback, void*)); - MOCK_METHOD3(MultiSetTimerCallback, - CURLMcode(CURLM*, curl_multi_timer_callback, void*)); - MOCK_METHOD3(MultiAssign, CURLMcode(CURLM*, curl_socket_t, void*)); - MOCK_METHOD4(MultiSocketAction, CURLMcode(CURLM*, curl_socket_t, int, int*)); - MOCK_CONST_METHOD1(MultiStrError, std::string(CURLMcode)); - MOCK_METHOD2(MultiPerform, CURLMcode(CURLM*, int*)); - MOCK_METHOD5(MultiWait, - CURLMcode(CURLM*, curl_waitfd[], unsigned int, int, int*)); + MOCK_METHOD(CURL*, EasyInit, (), (override)); + MOCK_METHOD(void, EasyCleanup, (CURL*), (override)); + MOCK_METHOD(CURLcode, EasySetOptInt, (CURL*, CURLoption, int), (override)); + MOCK_METHOD(CURLcode, + EasySetOptStr, + (CURL*, CURLoption, const std::string&), + (override)); + MOCK_METHOD(CURLcode, EasySetOptPtr, (CURL*, CURLoption, void*), (override)); + MOCK_METHOD(CURLcode, + EasySetOptCallback, + (CURL*, CURLoption, intptr_t), + (override)); + MOCK_METHOD(CURLcode, + EasySetOptOffT, + (CURL*, CURLoption, curl_off_t), + (override)); + MOCK_METHOD(CURLcode, EasyPerform, (CURL*), (override)); + MOCK_METHOD(CURLcode, + EasyGetInfoInt, + (CURL*, CURLINFO, int*), + (const, override)); + MOCK_METHOD(CURLcode, + EasyGetInfoDbl, + (CURL*, CURLINFO, double*), + (const, override)); + MOCK_METHOD(CURLcode, + EasyGetInfoStr, + (CURL*, CURLINFO, std::string*), + (const, override)); + MOCK_METHOD(CURLcode, + EasyGetInfoPtr, + (CURL*, CURLINFO, void**), + (const, override)); + MOCK_METHOD(std::string, EasyStrError, (CURLcode), (const, override)); + MOCK_METHOD(CURLM*, MultiInit, (), (override)); + MOCK_METHOD(CURLMcode, MultiCleanup, (CURLM*), (override)); + MOCK_METHOD(CURLMsg*, MultiInfoRead, (CURLM*, int*), (override)); + MOCK_METHOD(CURLMcode, MultiAddHandle, (CURLM*, CURL*), (override)); + MOCK_METHOD(CURLMcode, MultiRemoveHandle, (CURLM*, CURL*), (override)); + MOCK_METHOD(CURLMcode, + MultiSetSocketCallback, + (CURLM*, curl_socket_callback, void*), + (override)); + MOCK_METHOD(CURLMcode, + MultiSetTimerCallback, + (CURLM*, curl_multi_timer_callback, void*), + (override)); + MOCK_METHOD(CURLMcode, + MultiAssign, + (CURLM*, curl_socket_t, void*), + (override)); + MOCK_METHOD(CURLMcode, + MultiSocketAction, + (CURLM*, curl_socket_t, int, int*), + (override)); + MOCK_METHOD(std::string, MultiStrError, (CURLMcode), (const, override)); + MOCK_METHOD(CURLMcode, MultiPerform, (CURLM*, int*), (override)); + MOCK_METHOD(CURLMcode, + MultiWait, + (CURLM*, curl_waitfd[], unsigned int, int, int*), + (override)); private: DISALLOW_COPY_AND_ASSIGN(MockCurlInterface); diff --git a/brillo/http/mock_transport.h b/brillo/http/mock_transport.h index 7504266..a9f5d46 100644 --- a/brillo/http/mock_transport.h +++ b/brillo/http/mock_transport.h @@ -8,6 +8,7 @@ #include <memory> #include <string> +#include <base/location.h> #include <base/macros.h> #include <brillo/http/http_transport.h> #include <gmock/gmock.h> @@ -19,21 +20,35 @@ class MockTransport : public Transport { public: MockTransport() = default; - MOCK_METHOD6(CreateConnection, - std::shared_ptr<Connection>(const std::string&, - const std::string&, - const HeaderList&, - const std::string&, - const std::string&, - brillo::ErrorPtr*)); - MOCK_METHOD2(RunCallbackAsync, - void(const base::Location&, const base::Closure&)); - MOCK_METHOD3(StartAsyncTransfer, RequestID(Connection*, - const SuccessCallback&, - const ErrorCallback&)); - MOCK_METHOD1(CancelRequest, bool(RequestID)); - MOCK_METHOD1(SetDefaultTimeout, void(base::TimeDelta)); - MOCK_METHOD1(SetLocalIpAddress, void(const std::string&)); + MOCK_METHOD(std::shared_ptr<Connection>, + CreateConnection, + (const std::string&, + const std::string&, + const HeaderList&, + const std::string&, + const std::string&, + brillo::ErrorPtr*), + (override)); + MOCK_METHOD(void, + RunCallbackAsync, + (const base::Location&, const base::Closure&), + (override)); + MOCK_METHOD(RequestID, + StartAsyncTransfer, + (Connection*, const SuccessCallback&, const ErrorCallback&), + (override)); + MOCK_METHOD(bool, CancelRequest, (RequestID), (override)); + MOCK_METHOD(void, SetDefaultTimeout, (base::TimeDelta), (override)); + MOCK_METHOD(void, SetLocalIpAddress, (const std::string&), (override)); + MOCK_METHOD(void, UseDefaultCertificate, (), (override)); + MOCK_METHOD(void, UseCustomCertificate, (Certificate), (override)); + MOCK_METHOD(void, + ResolveHostToIp, + (const std::string&, uint16_t, const std::string&), + (override)); + + protected: + MOCK_METHOD(void, ClearHost, (), (override)); private: DISALLOW_COPY_AND_ASSIGN(MockTransport); diff --git a/brillo/imageloader/manifest.cc b/brillo/imageloader/manifest.cc deleted file mode 100644 index 92789df..0000000 --- a/brillo/imageloader/manifest.cc +++ /dev/null @@ -1,188 +0,0 @@ -// Copyright 2018 The Chromium OS Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include <brillo/imageloader/manifest.h> - -#include <memory> -#include <utility> - -#include <base/json/json_string_value_serializer.h> -#include <base/strings/string_number_conversions.h> - -namespace brillo { -namespace imageloader { - -namespace { -// The current version of the manifest file. -constexpr int kCurrentManifestVersion = 1; -// The name of the version field in the manifest. -constexpr char kManifestVersionField[] = "manifest-version"; -// The name of the component version field in the manifest. -constexpr char kVersionField[] = "version"; -// The name of the field containing the image hash. -constexpr char kImageHashField[] = "image-sha256-hash"; -// The name of the bool field indicating whether component is removable. -constexpr char kIsRemovableField[] = "is-removable"; -// The name of the metadata field. -constexpr char kMetadataField[] = "metadata"; -// The name of the field containing the table hash. -constexpr char kTableHashField[] = "table-sha256-hash"; -// The name of the optional field containing the file system type. -constexpr char kFSType[] = "fs-type"; - -bool GetSHA256FromString(const std::string& hash_str, - std::vector<uint8_t>* bytes) { - if (!base::HexStringToBytes(hash_str, bytes)) - return false; - return bytes->size() == 32; -} - -// Ensure the metadata entry is a dictionary mapping strings to strings and -// parse it into |out_metadata| and return true if so. -bool ParseMetadata(const base::Value* metadata_element, - std::map<std::string, std::string>* out_metadata) { - DCHECK(out_metadata); - - const base::DictionaryValue* metadata_dict = nullptr; - if (!metadata_element->GetAsDictionary(&metadata_dict)) - return false; - - base::DictionaryValue::Iterator it(*metadata_dict); - for (; !it.IsAtEnd(); it.Advance()) { - std::string parsed_value; - if (!it.value().GetAsString(&parsed_value)) { - LOG(ERROR) << "Key \"" << it.key() << "\" did not map to string value"; - return false; - } - - (*out_metadata)[it.key()] = std::move(parsed_value); - } - - return true; -} - -} // namespace - -Manifest::Manifest() {} - -bool Manifest::ParseManifest(const std::string& manifest_raw) { - // Now deserialize the manifest json and read out the rest of the component. - int error_code; - std::string error_message; - JSONStringValueDeserializer deserializer(manifest_raw); - std::unique_ptr<base::Value> value = - deserializer.Deserialize(&error_code, &error_message); - - if (!value) { - LOG(ERROR) << "Could not deserialize the manifest file. Error " - << error_code << ": " << error_message; - return false; - } - - base::DictionaryValue* manifest_dict = nullptr; - if (!value->GetAsDictionary(&manifest_dict)) { - LOG(ERROR) << "Could not parse manifest file as JSON."; - return false; - } - - // This will have to be changed if the manifest version is bumped. - int version; - if (!manifest_dict->GetInteger(kManifestVersionField, &version)) { - LOG(ERROR) << "Could not parse manifest version field from manifest."; - return false; - } - if (version != kCurrentManifestVersion) { - LOG(ERROR) << "Unsupported version of the manifest."; - return false; - } - manifest_version_ = version; - - std::string image_hash_str; - if (!manifest_dict->GetString(kImageHashField, &image_hash_str)) { - LOG(ERROR) << "Could not parse image hash from manifest."; - return false; - } - - if (!GetSHA256FromString(image_hash_str, &(image_sha256_))) { - LOG(ERROR) << "Could not convert image hash to bytes."; - return false; - } - - std::string table_hash_str; - if (!manifest_dict->GetString(kTableHashField, &table_hash_str)) { - LOG(ERROR) << "Could not parse table hash from manifest."; - return false; - } - - if (!GetSHA256FromString(table_hash_str, &(table_sha256_))) { - LOG(ERROR) << "Could not convert table hash to bytes."; - return false; - } - - if (!manifest_dict->GetString(kVersionField, &(version_))) { - LOG(ERROR) << "Could not parse component version from manifest."; - return false; - } - - // The fs_type field is optional, and squashfs by default. - fs_type_ = FileSystem::kSquashFS; - std::string fs_type; - if (manifest_dict->GetString(kFSType, &fs_type)) { - if (fs_type == "ext4") { - fs_type_ = FileSystem::kExt4; - } else if (fs_type == "squashfs") { - fs_type_ = FileSystem::kSquashFS; - } else { - LOG(ERROR) << "Unsupported file system type: " << fs_type; - return false; - } - } - - if (!manifest_dict->GetBoolean(kIsRemovableField, &(is_removable_))) { - // If is_removable field does not exist, by default it is false. - is_removable_ = false; - } - - // Copy out the metadata, if it's there. - const base::Value* metadata = nullptr; - if (manifest_dict->Get(kMetadataField, &metadata)) { - if (!ParseMetadata(metadata, &(metadata_))) { - LOG(ERROR) << "Manifest metadata was malformed"; - return false; - } - } - - return true; -} - -int Manifest::manifest_version() const { - return manifest_version_; -} - -const std::vector<uint8_t>& Manifest::image_sha256() const { - return image_sha256_; -} - -const std::vector<uint8_t>& Manifest::table_sha256() const { - return table_sha256_; -} - -const std::string& Manifest::version() const { - return version_; -} - -FileSystem Manifest::fs_type() const { - return fs_type_; -} - -bool Manifest::is_removable() const { - return is_removable_; -} - -const std::map<std::string, std::string> Manifest::metadata() const { - return metadata_; -} - -} // namespace imageloader -} // namespace brillo diff --git a/brillo/imageloader/manifest.h b/brillo/imageloader/manifest.h deleted file mode 100644 index cfd7c3a..0000000 --- a/brillo/imageloader/manifest.h +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright 2018 The Chromium OS Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef LIBBRILLO_BRILLO_IMAGELOADER_MANIFEST_H_ -#define LIBBRILLO_BRILLO_IMAGELOADER_MANIFEST_H_ - -#include <map> -#include <string> -#include <vector> - -#include <base/macros.h> -#include <brillo/brillo_export.h> - -namespace brillo { -namespace imageloader { - -// The supported file systems for images. -enum class FileSystem { kExt4, kSquashFS }; - -// A class to parse and store imageloader.json manifest. -class BRILLO_EXPORT Manifest { - public: - Manifest(); - // Parse the manifest raw string. Return true if successful. - bool ParseManifest(const std::string& manifest_raw); - // Getters for manifest fields: - int manifest_version() const; - const std::vector<uint8_t>& image_sha256() const; - const std::vector<uint8_t>& table_sha256() const; - const std::string& version() const; - FileSystem fs_type() const; - bool is_removable() const; - const std::map<std::string, std::string> metadata() const; - - private: - // Manifest fields: - int manifest_version_; - std::vector<uint8_t> image_sha256_; - std::vector<uint8_t> table_sha256_; - std::string version_; - FileSystem fs_type_; - bool is_removable_; - std::map<std::string, std::string> metadata_; - - DISALLOW_COPY_AND_ASSIGN(Manifest); -}; - -} // namespace imageloader -} // namespace brillo - -#endif // LIBBRILLO_BRILLO_IMAGELOADER_MANIFEST_H_ diff --git a/brillo/imageloader/manifest_unittest.cc b/brillo/imageloader/manifest_unittest.cc deleted file mode 100644 index bca7e8b..0000000 --- a/brillo/imageloader/manifest_unittest.cc +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2018 The Chromium OS Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include <gtest/gtest.h> - -#include <brillo/imageloader/manifest.h> - -namespace brillo { -namespace imageloader { - -class ManifestTest : public testing::Test {}; - -TEST_F(ManifestTest, ParseManifest) { - const std::string fs_type = R"("ext4")"; - const std::string is_removable = R"(true)"; - const std::string image_sha256_hash = - R"("4CF41BD11362CCB4707FB93939DBB5AC48745EDFC9DC8D7702852FFAA81B3B3F")"; - const std::string table_sha256_hash = - R"("0E11DA3D7140C6B95496787F50D15152434EBA22B60443BFA7E054FF4C799276")"; - const std::string version = R"("9824.0.4")"; - const std::string manifest_version = R"(1)"; - const std::string manifest_raw = std::string() + R"( - { - "fs-type":)" + fs_type + R"(, - "is-removable":)" + is_removable + - R"(, - "image-sha256-hash":)" + image_sha256_hash + - R"(, - "table-sha256-hash":)" + table_sha256_hash + - R"(, - "version":)" + version + R"(, - "manifest-version":)" + manifest_version + - R"( - } - )"; - brillo::imageloader::Manifest manifest; - // Parse the manifest raw string. - ASSERT_TRUE(manifest.ParseManifest(manifest_raw)); - EXPECT_EQ(manifest.fs_type(), FileSystem::kExt4); - EXPECT_EQ(manifest.is_removable(), true); - EXPECT_NE(manifest.image_sha256().size(), 0); - EXPECT_NE(manifest.table_sha256().size(), 0); - EXPECT_NE(manifest.version().size(), 0); - EXPECT_EQ(manifest.manifest_version(), 1); -} - -} // namespace imageloader -} // namespace brillo diff --git a/brillo/key_value_store.cc b/brillo/key_value_store.cc index 7840427..46c1d5c 100644 --- a/brillo/key_value_store.cc +++ b/brillo/key_value_store.cc @@ -4,7 +4,6 @@ #include "brillo/key_value_store.h" -#include <map> #include <string> #include <vector> @@ -15,7 +14,6 @@ #include <brillo/strings/string_utils.h> #include <brillo/map_utils.h> -using std::map; using std::string; using std::vector; @@ -37,6 +35,11 @@ string TrimKey(const string& key) { } // namespace +KeyValueStore::KeyValueStore() = default; +KeyValueStore::~KeyValueStore() = default; +KeyValueStore::KeyValueStore(KeyValueStore&&) = default; +KeyValueStore& KeyValueStore::operator=(KeyValueStore&&) = default; + bool KeyValueStore::Load(const base::FilePath& path) { string file_data; if (!base::ReadFileToString(path, &file_data)) @@ -89,6 +92,10 @@ string KeyValueStore::SaveToString() const { return data; } +void KeyValueStore::Clear() { + store_.clear(); +} + bool KeyValueStore::GetString(const string& key, string* value) const { const auto key_value = store_.find(TrimKey(key)); if (key_value == store_.end()) diff --git a/brillo/key_value_store.h b/brillo/key_value_store.h index cc5fa40..0c8e614 100644 --- a/brillo/key_value_store.h +++ b/brillo/key_value_store.h @@ -21,15 +21,26 @@ namespace brillo { class BRILLO_EXPORT KeyValueStore { public: // Creates an empty KeyValueStore. - KeyValueStore() = default; - virtual ~KeyValueStore() = default; + KeyValueStore(); + virtual ~KeyValueStore(); + // Copying is expensive; disallow accidental copies. + KeyValueStore(const KeyValueStore&) = delete; + KeyValueStore& operator=(const KeyValueStore&) = delete; + KeyValueStore(KeyValueStore&&); + KeyValueStore& operator=(KeyValueStore&&); // Loads the key=value pairs from the given |path|. Lines starting with '#' // and empty lines are ignored, and whitespace around keys is trimmed. // Trailing backslashes may be used to extend values across multiple lines. // Adds all the read key=values to the store, overriding those already defined - // but persisting the ones that aren't present on the passed file. Returns - // whether reading the file succeeded. + // but persisting the ones that aren't present on the passed file. + // + // Returns true, if the entire file is loaded successfully. If an error occurs + // while loading, keeps the pairs that were loaded before the error, and + // returns false. + // + // This function does not clear its internal state before loading. To clear + // the internal state, call Clear(). bool Load(const base::FilePath& path); // Loads the key=value pairs parsing the text passed in |data|. See Load() for @@ -48,6 +59,9 @@ class BRILLO_EXPORT KeyValueStore { // these values will be rewritten on single lines), comments or empty lines. std::string SaveToString() const; + // Clears all the key-value pairs currently stored. + void Clear(); + // Getter for the given key. Returns whether the key was found on the store. bool GetString(const std::string& key, std::string* value) const; @@ -67,8 +81,6 @@ class BRILLO_EXPORT KeyValueStore { private: // The map storing all the key-value pairs. std::map<std::string, std::string> store_; - - DISALLOW_COPY_AND_ASSIGN(KeyValueStore); }; } // namespace brillo diff --git a/brillo/key_value_store_unittest.cc b/brillo/key_value_store_test.cc index 68875ef..ceb8df6 100644 --- a/brillo/key_value_store_unittest.cc +++ b/brillo/key_value_store_test.cc @@ -6,6 +6,7 @@ #include <map> #include <string> +#include <utility> #include <vector> #include <base/files/file_util.h> @@ -37,6 +38,35 @@ class KeyValueStoreTest : public ::testing::Test { KeyValueStore store_; // KeyValueStore under test. }; +TEST_F(KeyValueStoreTest, MoveConstructor) { + store_.SetBoolean("a_boolean", true); + store_.SetString("a_string", "a_value"); + + KeyValueStore moved_to(std::move(store_)); + bool b_value = false; + EXPECT_TRUE(moved_to.GetBoolean("a_boolean", &b_value)); + EXPECT_TRUE(b_value); + + std::string s_value; + EXPECT_TRUE(moved_to.GetString("a_string", &s_value)); + EXPECT_EQ(s_value, "a_value"); +} + +TEST_F(KeyValueStoreTest, MoveAssignmentOperator) { + store_.SetBoolean("a_boolean", true); + store_.SetString("a_string", "a_value"); + + KeyValueStore moved_to; + moved_to = std::move(store_); + bool b_value = false; + EXPECT_TRUE(moved_to.GetBoolean("a_boolean", &b_value)); + EXPECT_TRUE(b_value); + + std::string s_value; + EXPECT_TRUE(moved_to.GetString("a_string", &s_value)); + EXPECT_EQ(s_value, "a_value"); +} + TEST_F(KeyValueStoreTest, LoadAndSaveFromFile) { base::ScopedTempDir temp_dir_; CHECK(temp_dir_.CreateUniqueTempDir()); @@ -96,6 +126,26 @@ TEST_F(KeyValueStoreTest, LoadAndReloadTest) { } } +TEST_F(KeyValueStoreTest, MultipleLoads) { + // The internal state is not cleared before loading. + EXPECT_TRUE(store_.LoadFromString("A=B\n")); + EXPECT_TRUE(store_.LoadFromString("B=C\n")); + EXPECT_EQ(2, store_.GetKeys().size()); +} + +TEST_F(KeyValueStoreTest, PartialLoad) { + // The 2nd line is broken, but the pair from the first line should be kept. + EXPECT_FALSE(store_.LoadFromString("A=B\n=\n")); + EXPECT_EQ(1, store_.GetKeys().size()); +} + +TEST_F(KeyValueStoreTest, Clear) { + EXPECT_TRUE(store_.LoadFromString("A=B\n")); + EXPECT_EQ(1, store_.GetKeys().size()); + store_.Clear(); + EXPECT_EQ(0, store_.GetKeys().size()); +} + TEST_F(KeyValueStoreTest, SimpleBooleanTest) { bool result; EXPECT_FALSE(store_.GetBoolean("A", &result)); diff --git a/brillo/map_utils_unittest.cc b/brillo/map_utils_test.cc index 19bda1d..19bda1d 100644 --- a/brillo/map_utils_unittest.cc +++ b/brillo/map_utils_test.cc diff --git a/brillo/message_loops/base_message_loop.cc b/brillo/message_loops/base_message_loop.cc index 08465d7..9a9e43f 100644 --- a/brillo/message_loops/base_message_loop.cc +++ b/brillo/message_loops/base_message_loop.cc @@ -6,6 +6,7 @@ #include <fcntl.h> #include <sys/stat.h> +#include <sys/sysmacros.h> #include <sys/types.h> #include <unistd.h> @@ -19,6 +20,7 @@ #include <linux/major.h> #endif +#include <utility> #include <vector> #include <base/bind.h> @@ -50,12 +52,14 @@ BaseMessageLoop::BaseMessageLoop() { CHECK(!base::MessageLoop::current()) << "You can't create a base::MessageLoopForIO when another " "base::MessageLoop is already created for this thread."; - owned_base_loop_.reset(new base::MessageLoopForIO); + owned_base_loop_.reset(new base::MessageLoopForIO()); base_loop_ = owned_base_loop_.get(); + watcher_ = std::make_unique<base::FileDescriptorWatcher>(base_loop_); } BaseMessageLoop::BaseMessageLoop(base::MessageLoopForIO* base_loop) - : base_loop_(base_loop) {} + : base_loop_(base_loop), + watcher_(std::make_unique<base::FileDescriptorWatcher>(base_loop_)) {} BaseMessageLoop::~BaseMessageLoop() { for (auto& io_task : io_tasks_) { @@ -97,8 +101,7 @@ MessageLoop::TaskId BaseMessageLoop::PostDelayedTask( if (!base_scheduled) return MessageLoop::kTaskIdNull; - delayed_tasks_.emplace(task_id, - DelayedTask{from_here, task_id, std::move(task)}); + delayed_tasks_.emplace(task_id, DelayedTask{from_here, task_id, task}); return task_id; } diff --git a/brillo/message_loops/base_message_loop.h b/brillo/message_loops/base_message_loop.h index 163ea4f..c038ac7 100644 --- a/brillo/message_loops/base_message_loop.h +++ b/brillo/message_loops/base_message_loop.h @@ -16,6 +16,7 @@ #include <memory> #include <string> +#include <base/files/file_descriptor_watcher_posix.h> #include <base/location.h> #include <base/memory/weak_ptr.h> #include <base/message_loop/message_loop.h> @@ -120,7 +121,7 @@ class BRILLO_EXPORT BaseMessageLoop : public MessageLoop { // Sets the closure to be run immediately whenever the file descriptor // becomes ready. - void RunImmediately() { immediate_run_= true; } + void RunImmediately() { immediate_run_ = true; } private: base::Location location_; @@ -178,6 +179,9 @@ class BRILLO_EXPORT BaseMessageLoop : public MessageLoop { // point to that instance. base::MessageLoopForIO* base_loop_; + // FileDescriptorWatcher for |base_loop_|. This is used in AlarmTimer. + std::unique_ptr<base::FileDescriptorWatcher> watcher_; + // The RunLoop instance used to run the main loop from Run(). base::RunLoop* base_run_loop_{nullptr}; diff --git a/brillo/message_loops/base_message_loop_unittest.cc b/brillo/message_loops/base_message_loop_test.cc index 9e052a8..9e052a8 100644 --- a/brillo/message_loops/base_message_loop_unittest.cc +++ b/brillo/message_loops/base_message_loop_test.cc diff --git a/brillo/message_loops/fake_message_loop_unittest.cc b/brillo/message_loops/fake_message_loop_test.cc index 18f0b4b..b4b839c 100644 --- a/brillo/message_loops/fake_message_loop_unittest.cc +++ b/brillo/message_loops/fake_message_loop_test.cc @@ -13,7 +13,6 @@ #include <base/test/simple_test_clock.h> #include <gtest/gtest.h> -#include <brillo/bind_lambda.h> #include <brillo/message_loops/message_loop.h> using base::Bind; diff --git a/brillo/message_loops/message_loop.h b/brillo/message_loops/message_loop.h index 1f65d96..e9f804e 100644 --- a/brillo/message_loops/message_loop.h +++ b/brillo/message_loops/message_loop.h @@ -91,8 +91,7 @@ class BRILLO_EXPORT MessageLoop { WatchMode mode, bool persistent, const base::Closure& task) { - return WatchFileDescriptor( - base::Location(), fd, mode, persistent, task); + return WatchFileDescriptor(base::Location(), fd, mode, persistent, task); } // Cancel a scheduled task. Returns whether the task was canceled. For diff --git a/brillo/message_loops/message_loop_unittest.cc b/brillo/message_loops/message_loop_test.cc index bda3336..7b57015 100644 --- a/brillo/message_loops/message_loop_unittest.cc +++ b/brillo/message_loops/message_loop_test.cc @@ -7,7 +7,7 @@ // These are the common tests for all the brillo::MessageLoop implementations // that should conform to this interface's contracts. For extra // implementation-specific tests see the particular implementation unittests in -// the *_unittest.cc files. +// the *_test.cc files. #include <memory> #include <vector> @@ -18,10 +18,9 @@ #include <base/posix/eintr_wrapper.h> #include <gtest/gtest.h> -#include <brillo/bind_lambda.h> -#include <brillo/unittest_utils.h> #include <brillo/message_loops/base_message_loop.h> #include <brillo/message_loops/message_loop_utils.h> +#include <brillo/unittest_utils.h> using base::Bind; using base::TimeDelta; @@ -68,7 +67,8 @@ class MessageLoopTest : public ::testing::Test { template <> void MessageLoopTest<BaseMessageLoop>::MessageLoopSetUp() { base_loop_.reset(new base::MessageLoopForIO()); - loop_.reset(new BaseMessageLoop(base::MessageLoopForIO::current())); + loop_.reset(new BaseMessageLoop(base_loop_.get())); + loop_->SetAsCurrent(); } // This setups gtest to run each one of the following TYPED_TEST test cases on diff --git a/brillo/message_loops/message_loop_utils.cc b/brillo/message_loops/message_loop_utils.cc index 9ebe865..c16f268 100644 --- a/brillo/message_loops/message_loop_utils.cc +++ b/brillo/message_loops/message_loop_utils.cc @@ -4,8 +4,8 @@ #include <brillo/message_loops/message_loop_utils.h> +#include <base/bind.h> #include <base/location.h> -#include <brillo/bind_lambda.h> namespace brillo { diff --git a/brillo/message_loops/mock_message_loop.h b/brillo/message_loops/mock_message_loop.h index 9f9a1e4..c84e585 100644 --- a/brillo/message_loops/mock_message_loop.h +++ b/brillo/message_loops/mock_message_loop.h @@ -57,20 +57,19 @@ class BRILLO_EXPORT MockMessageLoop : public MessageLoop { } ~MockMessageLoop() override = default; - MOCK_METHOD3(PostDelayedTask, - TaskId(const base::Location& from_here, - const base::Closure& task, - base::TimeDelta delay)); + MOCK_METHOD(TaskId, + PostDelayedTask, + (const base::Location&, const base::Closure&, base::TimeDelta), + (override)); using MessageLoop::PostDelayedTask; - MOCK_METHOD5(WatchFileDescriptor, - TaskId(const base::Location& from_here, - int fd, - WatchMode mode, - bool persistent, - const base::Closure& task)); + MOCK_METHOD( + TaskId, + WatchFileDescriptor, + (const base::Location&, int, WatchMode, bool, const base::Closure&), + (override)); using MessageLoop::WatchFileDescriptor; - MOCK_METHOD1(CancelTask, bool(TaskId task_id)); - MOCK_METHOD1(RunOnce, bool(bool may_block)); + MOCK_METHOD(bool, CancelTask, (TaskId), (override)); + MOCK_METHOD(bool, RunOnce, (bool), (override)); // Returns the actual FakeMessageLoop instance so default actions can be // override with other actions or call diff --git a/brillo/mime_utils_unittest.cc b/brillo/mime_utils_test.cc index a7595dc..a7595dc 100644 --- a/brillo/mime_utils_unittest.cc +++ b/brillo/mime_utils_test.cc diff --git a/brillo/minijail/minijail.cc b/brillo/minijail/minijail.cc index 305f073..a08233d 100644 --- a/brillo/minijail/minijail.cc +++ b/brillo/minijail/minijail.cc @@ -11,6 +11,9 @@ using std::vector; namespace brillo { +static base::LazyInstance<Minijail>::DestructorAtExit g_minijail + = LAZY_INSTANCE_INITIALIZER; + Minijail::Minijail() {} Minijail::~Minijail() {} @@ -65,6 +68,14 @@ void Minijail::ResetSignalMask(struct minijail* jail) { minijail_reset_signal_mask(jail); } +void Minijail::CloseOpenFds(struct minijail* jail) { + minijail_close_open_fds(jail); +} + +void Minijail::PreserveFd(struct minijail* jail, int parent_fd, int child_fd) { + minijail_preserve_fd(jail, parent_fd, child_fd); +} + void Minijail::Enter(struct minijail* jail) { minijail_enter(jail); } diff --git a/brillo/minijail/minijail.h b/brillo/minijail/minijail.h index 15167cf..c71211d 100644 --- a/brillo/minijail/minijail.h +++ b/brillo/minijail/minijail.h @@ -12,6 +12,9 @@ extern "C" { #include <sys/types.h> } +#include <base/lazy_instance.h> +#include <brillo/brillo_export.h> + #include <libminijail.h> #include "base/macros.h" @@ -19,7 +22,7 @@ extern "C" { namespace brillo { // A Minijail abstraction allowing Minijail mocking in tests. -class Minijail { +class BRILLO_EXPORT Minijail { public: virtual ~Minijail(); @@ -55,6 +58,12 @@ class Minijail { // minijail_reset_signal_mask virtual void ResetSignalMask(struct minijail* jail); + // minijail_close_open_fds + virtual void CloseOpenFds(struct minijail* jail); + + // minijail_preserve_fd + virtual void PreserveFd(struct minijail* jail, int parent_fd, int child_fd); + // minijail_enter virtual void Enter(struct minijail* jail); @@ -108,6 +117,8 @@ class Minijail { Minijail(); private: + friend base::LazyInstanceTraitsBase<Minijail>; + DISALLOW_COPY_AND_ASSIGN(Minijail); }; diff --git a/brillo/minijail/mock_minijail.h b/brillo/minijail/mock_minijail.h index a855632..6c95405 100644 --- a/brillo/minijail/mock_minijail.h +++ b/brillo/minijail/mock_minijail.h @@ -19,45 +19,46 @@ class MockMinijail : public brillo::Minijail { MockMinijail() {} virtual ~MockMinijail() {} - MOCK_METHOD0(New, struct minijail*()); - MOCK_METHOD1(Destroy, void(struct minijail*)); + MOCK_METHOD(struct minijail*, New, (), (override)); + MOCK_METHOD(void, Destroy, (struct minijail*), (override)); - MOCK_METHOD3(DropRoot, - bool(struct minijail* jail, - const char* user, - const char* group)); - MOCK_METHOD2(UseSeccompFilter, void(struct minijail* jail, const char* path)); - MOCK_METHOD2(UseCapabilities, void(struct minijail* jail, uint64_t capmask)); - MOCK_METHOD1(ResetSignalMask, void(struct minijail* jail)); - MOCK_METHOD1(Enter, void(struct minijail* jail)); - MOCK_METHOD3(Run, - bool(struct minijail* jail, - std::vector<char*> args, - pid_t* pid)); - MOCK_METHOD3(RunSync, - bool(struct minijail* jail, - std::vector<char*> args, - int* status)); - MOCK_METHOD3(RunAndDestroy, - bool(struct minijail* jail, - std::vector<char*> args, - pid_t* pid)); - MOCK_METHOD3(RunSyncAndDestroy, - bool(struct minijail* jail, - std::vector<char*> args, - int* status)); - MOCK_METHOD4(RunPipeAndDestroy, - bool(struct minijail* jail, - std::vector<char*> args, - pid_t* pid, - int* stdin)); - MOCK_METHOD6(RunPipesAndDestroy, - bool(struct minijail* jail, - std::vector<char*> args, - pid_t* pid, - int* stdin, - int* stdout, - int* stderr)); + MOCK_METHOD(bool, + DropRoot, + (struct minijail*, const char*, const char*), + (override)); + MOCK_METHOD(void, + UseSeccompFilter, + (struct minijail*, const char*), + (override)); + MOCK_METHOD(void, UseCapabilities, (struct minijail*, uint64_t), (override)); + MOCK_METHOD(void, ResetSignalMask, (struct minijail*), (override)); + MOCK_METHOD(void, CloseOpenFds, (struct minijail*), (override)); + MOCK_METHOD(void, PreserveFd, (struct minijail*, int, int), (override)); + MOCK_METHOD(void, Enter, (struct minijail*), (override)); + MOCK_METHOD(bool, + Run, + (struct minijail*, std::vector<char*>, pid_t*), + (override)); + MOCK_METHOD(bool, + RunSync, + (struct minijail*, std::vector<char*>, int*), + (override)); + MOCK_METHOD(bool, + RunAndDestroy, + (struct minijail*, std::vector<char*>, pid_t*), + (override)); + MOCK_METHOD(bool, + RunSyncAndDestroy, + (struct minijail*, std::vector<char*>, int*), + (override)); + MOCK_METHOD(bool, + RunPipeAndDestroy, + (struct minijail*, std::vector<char*>, pid_t*, int*), + (override)); + MOCK_METHOD(bool, + RunPipesAndDestroy, + (struct minijail*, std::vector<char*>, pid_t*, int*, int*, int*), + (override)); private: DISALLOW_COPY_AND_ASSIGN(MockMinijail); diff --git a/brillo/osrelease_reader.cc b/brillo/osrelease_reader.cc index c8f660e..7e533c0 100644 --- a/brillo/osrelease_reader.cc +++ b/brillo/osrelease_reader.cc @@ -52,4 +52,9 @@ void OsReleaseReader::Load(const base::FilePath& root_dir) { initialized_ = true; } +std::vector<std::string> OsReleaseReader::GetKeys() const { + CHECK(initialized_) << "OsReleaseReader.Load() must be called first."; + return store_.GetKeys(); +} + } // namespace brillo diff --git a/brillo/osrelease_reader.h b/brillo/osrelease_reader.h index f29c14d..372d3a1 100644 --- a/brillo/osrelease_reader.h +++ b/brillo/osrelease_reader.h @@ -10,6 +10,7 @@ #define LIBBRILLO_BRILLO_OSRELEASE_READER_H_ #include <string> +#include <vector> #include <brillo/brillo_export.h> #include <brillo/key_value_store.h> @@ -36,6 +37,9 @@ class BRILLO_EXPORT OsReleaseReader final { // Getter for the given key. Returns whether the key was found on the store. bool GetString(const std::string& key, std::string* value) const; + // Getter for all the keys in /etc/os-release. + std::vector<std::string> GetKeys() const; + private: // The map storing all the key-value pairs. KeyValueStore store_; diff --git a/brillo/osrelease_reader_unittest.cc b/brillo/osrelease_reader_test.cc index 9381367..9381367 100644 --- a/brillo/osrelease_reader_unittest.cc +++ b/brillo/osrelease_reader_test.cc diff --git a/brillo/process.cc b/brillo/process.cc index ead6f20..54e91f0 100644 --- a/brillo/process.cc +++ b/brillo/process.cc @@ -195,7 +195,7 @@ void ProcessImpl::CloseUnusedFileDescriptors() { // Since we're just trying to close anything we can find, // ignore any error return values of close(). IGNORE_EINTR(close(fd)); - } + } } bool ProcessImpl::Start() { @@ -309,7 +309,7 @@ bool ProcessImpl::Start() { } else { execv(argv[0], &argv[0]); } - PLOG(ERROR) << "Exec of " << argv[0] << " failed:"; + PLOG(ERROR) << "Exec of " << argv[0] << " failed"; _exit(kErrorExitStatus); } else { // Still executing inside the parent process with known child pid. diff --git a/brillo/process_information.h b/brillo/process_information.h index 3f0a2c9..13134bd 100644 --- a/brillo/process_information.h +++ b/brillo/process_information.h @@ -31,8 +31,8 @@ class BRILLO_EXPORT ProcessInformation { const std::vector<std::string>& get_cmd_line() { return cmd_line_; } - // Set the command line array. This method DOES swap out the contents of - // |value|. The caller should expect an empty set on return. + // Set the collection of open files. This method DOES swap out the contents + // of |value|. The caller should expect an empty set on return. void set_open_files(std::set<std::string>* value) { open_files_.clear(); open_files_.swap(*value); @@ -40,8 +40,8 @@ class BRILLO_EXPORT ProcessInformation { const std::set<std::string>& get_open_files() { return open_files_; } - // Set the command line array. This method DOES swap out the contents of - // |value|. The caller should expect an empty string on return. + // Set the current working directory. This method DOES swap out the contents + // of |value|. The caller should expect an empty string on return. void set_cwd(std::string* value) { cwd_.clear(); cwd_.swap(*value); diff --git a/brillo/process_mock.h b/brillo/process_mock.h index 92ffa0a..cc33681 100644 --- a/brillo/process_mock.h +++ b/brillo/process_mock.h @@ -19,29 +19,29 @@ class ProcessMock : public Process { ProcessMock() {} virtual ~ProcessMock() {} - MOCK_METHOD1(AddArg, void(const std::string& arg)); - MOCK_METHOD1(RedirectInput, void(const std::string& input_file)); - MOCK_METHOD1(RedirectOutput, void(const std::string& output_file)); - MOCK_METHOD2(RedirectUsingPipe, void(int child_fd, bool is_input)); - MOCK_METHOD2(BindFd, void(int parent_fd, int child_fd)); - MOCK_METHOD1(SetUid, void(uid_t)); - MOCK_METHOD1(SetGid, void(gid_t)); - MOCK_METHOD1(SetCapabilities, void(uint64_t capmask)); - MOCK_METHOD1(ApplySyscallFilter, void(const std::string& path)); - MOCK_METHOD0(EnterNewPidNamespace, void()); - MOCK_METHOD1(SetInheritParentSignalMask, void(bool)); - MOCK_METHOD1(SetPreExecCallback, void(const PreExecCallback&)); - MOCK_METHOD1(SetSearchPath, void(bool)); - MOCK_METHOD1(GetPipe, int(int child_fd)); - MOCK_METHOD0(Start, bool()); - MOCK_METHOD0(Wait, int()); - MOCK_METHOD0(Run, int()); - MOCK_METHOD0(pid, pid_t()); - MOCK_METHOD2(Kill, bool(int signal, int timeout)); - MOCK_METHOD1(Reset, void(pid_t)); - MOCK_METHOD1(ResetPidByFile, bool(const std::string& pid_file)); - MOCK_METHOD0(Release, pid_t()); - MOCK_METHOD1(SetCloseUnusedFileDescriptors, void(bool close_unused_fds)); + MOCK_METHOD(void, AddArg, (const std::string&), (override)); + MOCK_METHOD(void, RedirectInput, (const std::string&), (override)); + MOCK_METHOD(void, RedirectOutput, (const std::string&), (override)); + MOCK_METHOD(void, RedirectUsingPipe, (int, bool), (override)); + MOCK_METHOD(void, BindFd, (int, int), (override)); + MOCK_METHOD(void, SetUid, (uid_t), (override)); + MOCK_METHOD(void, SetGid, (gid_t), (override)); + MOCK_METHOD(void, SetCapabilities, (uint64_t), (override)); + MOCK_METHOD(void, ApplySyscallFilter, (const std::string&), (override)); + MOCK_METHOD(void, EnterNewPidNamespace, (), (override)); + MOCK_METHOD(void, SetInheritParentSignalMask, (bool), (override)); + MOCK_METHOD(void, SetPreExecCallback, (const PreExecCallback&), (override)); + MOCK_METHOD(void, SetSearchPath, (bool), (override)); + MOCK_METHOD(int, GetPipe, (int), (override)); + MOCK_METHOD(bool, Start, (), (override)); + MOCK_METHOD(int, Wait, (), (override)); + MOCK_METHOD(int, Run, (), (override)); + MOCK_METHOD(pid_t, pid, (), (override)); + MOCK_METHOD(bool, Kill, (int, int), (override)); + MOCK_METHOD(void, Reset, (pid_t), (override)); + MOCK_METHOD(bool, ResetPidByFile, (const std::string&), (override)); + MOCK_METHOD(pid_t, Release, (), (override)); + MOCK_METHOD(void, SetCloseUnusedFileDescriptors, (bool), (override)); }; } // namespace brillo diff --git a/brillo/process_reaper.cc b/brillo/process_reaper.cc index 0da3b5d..82e3f56 100644 --- a/brillo/process_reaper.cc +++ b/brillo/process_reaper.cc @@ -8,6 +8,8 @@ #include <sys/types.h> #include <sys/wait.h> +#include <utility> + #include <base/bind.h> #include <base/posix/eintr_wrapper.h> #include <brillo/asynchronous_signal_handler.h> @@ -37,10 +39,11 @@ void ProcessReaper::Unregister() { bool ProcessReaper::WatchForChild(const base::Location& from_here, pid_t pid, - const ChildCallback& callback) { + ChildCallback callback) { if (watched_processes_.find(pid) != watched_processes_.end()) return false; - watched_processes_.emplace(pid, WatchedProcess{from_here, callback}); + watched_processes_.emplace( + pid, WatchedProcess{from_here, std::move(callback)}); return true; } @@ -79,7 +82,7 @@ bool ProcessReaper::HandleSIGCHLD( << info.si_status << " (code = " << info.si_code << ")"; ChildCallback callback = std::move(proc->second.callback); watched_processes_.erase(proc); - callback.Run(info); + std::move(callback).Run(info); } } diff --git a/brillo/process_reaper.h b/brillo/process_reaper.h index 7b70a8d..4e348a3 100644 --- a/brillo/process_reaper.h +++ b/brillo/process_reaper.h @@ -19,7 +19,7 @@ namespace brillo { class BRILLO_EXPORT ProcessReaper final { public: // The callback called when a child exits. - using ChildCallback = base::Callback<void(const siginfo_t&)>; + using ChildCallback = base::OnceCallback<void(const siginfo_t&)>; ProcessReaper() = default; ~ProcessReaper(); @@ -41,7 +41,7 @@ class BRILLO_EXPORT ProcessReaper final { // as a siginfo_t. See wait(2) for details about siginfo_t. bool WatchForChild(const base::Location& from_here, pid_t pid, - const ChildCallback& callback); + ChildCallback callback); // Stop watching child process |pid|. This is useful in situations // where the child process may have been reaped outside of the signal diff --git a/brillo/process_reaper_unittest.cc b/brillo/process_reaper_test.cc index 98498f7..7b68236 100644 --- a/brillo/process_reaper_unittest.cc +++ b/brillo/process_reaper_test.cc @@ -12,7 +12,6 @@ #include <base/location.h> #include <base/message_loop/message_loop.h> #include <brillo/asynchronous_signal_handler.h> -#include <brillo/bind_lambda.h> #include <brillo/message_loops/base_message_loop.h> #include <gtest/gtest.h> @@ -74,7 +73,7 @@ TEST_F(ProcessReaperTest, UnregisterAndReregister) { TEST_F(ProcessReaperTest, ReapExitedChild) { pid_t pid = ForkChildAndExit(123); - EXPECT_TRUE(process_reaper_.WatchForChild(FROM_HERE, pid, base::Bind( + EXPECT_TRUE(process_reaper_.WatchForChild(FROM_HERE, pid, base::BindOnce( [](MessageLoop* loop, const siginfo_t& info) { EXPECT_EQ(CLD_EXITED, info.si_code); EXPECT_EQ(123, info.si_status); @@ -91,7 +90,7 @@ TEST_F(ProcessReaperTest, ReapedChildrenMatchCallbacks) { // Different processes will have different exit values. int exit_value = 1 + i; pid_t pid = ForkChildAndExit(exit_value); - EXPECT_TRUE(process_reaper_.WatchForChild(FROM_HERE, pid, base::Bind( + EXPECT_TRUE(process_reaper_.WatchForChild(FROM_HERE, pid, base::BindOnce( [](MessageLoop* loop, int exit_value, int* running_children, const siginfo_t& info) { EXPECT_EQ(CLD_EXITED, info.si_code); @@ -110,7 +109,7 @@ TEST_F(ProcessReaperTest, ReapedChildrenMatchCallbacks) { TEST_F(ProcessReaperTest, ReapKilledChild) { pid_t pid = ForkChildAndKill(SIGKILL); - EXPECT_TRUE(process_reaper_.WatchForChild(FROM_HERE, pid, base::Bind( + EXPECT_TRUE(process_reaper_.WatchForChild(FROM_HERE, pid, base::BindOnce( [](MessageLoop* loop, const siginfo_t& info) { EXPECT_EQ(CLD_KILLED, info.si_code); EXPECT_EQ(SIGKILL, info.si_status); @@ -121,7 +120,7 @@ TEST_F(ProcessReaperTest, ReapKilledChild) { TEST_F(ProcessReaperTest, ReapKilledAndForgottenChild) { pid_t pid = ForkChildAndExit(0); - EXPECT_TRUE(process_reaper_.WatchForChild(FROM_HERE, pid, base::Bind( + EXPECT_TRUE(process_reaper_.WatchForChild(FROM_HERE, pid, base::BindOnce( [](MessageLoop* loop, const siginfo_t& /* info */) { ADD_FAILURE() << "Child process was still tracked."; loop->BreakLoop(); diff --git a/brillo/process_unittest.cc b/brillo/process_test.cc index f65cf34..533a8f0 100644 --- a/brillo/process_unittest.cc +++ b/brillo/process_test.cc @@ -12,8 +12,8 @@ #include <gtest/gtest.h> #include "brillo/process_mock.h" -#include "brillo/unittest_utils.h" #include "brillo/test_helpers.h" +#include "brillo/unittest_utils.h" using base::FilePath; diff --git a/brillo/proto_file_io.cc b/brillo/proto_file_io.cc new file mode 100644 index 0000000..47f3413 --- /dev/null +++ b/brillo/proto_file_io.cc @@ -0,0 +1,40 @@ +// Copyright 2017 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "brillo/proto_file_io.h" + +#include <utility> + +#include <base/files/file.h> +#include <base/macros.h> +#include <google/protobuf/io/zero_copy_stream_impl.h> +#include <google/protobuf/text_format.h> + +namespace brillo { + +bool ReadTextProtobuf(const base::FilePath& proto_file, + google::protobuf::Message* out_proto) { + DCHECK(out_proto); + + base::File file(proto_file, base::File::FLAG_OPEN | base::File::FLAG_READ); + if (!file.IsValid()) { + DLOG(ERROR) << "Could not open \"" << proto_file.value() + << "\": " << base::File::ErrorToString(file.error_details()); + return false; + } + + return ReadTextProtobuf(file.GetPlatformFile(), out_proto); +} + +bool ReadTextProtobuf(int fd, google::protobuf::Message* out_proto) { + google::protobuf::io::FileInputStream input_stream(fd); + return google::protobuf::TextFormat::Parse(&input_stream, out_proto); +} + +bool WriteTextProtobuf(int fd, const google::protobuf::Message& proto) { + google::protobuf::io::FileOutputStream output_stream(fd); + return google::protobuf::TextFormat::Print(proto, &output_stream); +} + +} // namespace brillo diff --git a/brillo/proto_file_io.h b/brillo/proto_file_io.h new file mode 100644 index 0000000..77051cc --- /dev/null +++ b/brillo/proto_file_io.h @@ -0,0 +1,29 @@ +// Copyright 2017 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef LIBBRILLO_BRILLO_PROTO_FILE_IO_H_ +#define LIBBRILLO_BRILLO_PROTO_FILE_IO_H_ + +#include <base/files/file_path.h> +#include <brillo/brillo_export.h> +#include <google/protobuf/message.h> + +namespace brillo { + +// Simple utilities for serializing and deserializing protobufs in +// text format. For an example of the format, see the docs at +// https://developers.google.com/protocol-buffers/docs/overview#whynotxml + +BRILLO_EXPORT bool ReadTextProtobuf(const base::FilePath& proto_file, + google::protobuf::Message* out_proto); + +BRILLO_EXPORT bool ReadTextProtobuf(int fd, + google::protobuf::Message* out_proto); + +BRILLO_EXPORT bool WriteTextProtobuf(int fd, + const google::protobuf::Message& proto); + +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_PROTO_FILE_IO_H_ diff --git a/brillo/scoped_mount_namespace.cc b/brillo/scoped_mount_namespace.cc new file mode 100644 index 0000000..0f35e82 --- /dev/null +++ b/brillo/scoped_mount_namespace.cc @@ -0,0 +1,66 @@ +// Copyright 2019 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "brillo/scoped_mount_namespace.h" + +#include <fcntl.h> +#include <sched.h> +#include <sys/stat.h> +#include <sys/types.h> + +#include <string> +#include <utility> + +#include <base/posix/eintr_wrapper.h> +#include <base/strings/stringprintf.h> + +namespace { +constexpr char kCurrentMountNamespacePath[] = "/proc/self/ns/mnt"; +} // anonymous namespace + +namespace brillo { + +ScopedMountNamespace::ScopedMountNamespace(base::ScopedFD mount_namespace_fd) + : mount_namespace_fd_(std::move(mount_namespace_fd)) {} + +ScopedMountNamespace::~ScopedMountNamespace() { + PLOG_IF(ERROR, setns(mount_namespace_fd_.get(), CLONE_NEWNS) != 0) + << "Ignoring failure to restore original mount namespace"; +} + +// static +std::unique_ptr<ScopedMountNamespace> ScopedMountNamespace::CreateForPid( + pid_t pid) { + std::string ns_path = base::StringPrintf("/proc/%d/ns/mnt", pid); + return CreateFromPath(base::FilePath(ns_path)); +} + +// static +std::unique_ptr<ScopedMountNamespace> ScopedMountNamespace::CreateFromPath( + base::FilePath ns_path) { + base::ScopedFD original_mount_namespace_fd( + HANDLE_EINTR(open(kCurrentMountNamespacePath, O_RDONLY))); + if (!original_mount_namespace_fd.is_valid()) { + PLOG(ERROR) << "Failed to open original mount namespace FD at " + << kCurrentMountNamespacePath; + return nullptr; + } + + base::ScopedFD mount_namespace_fd( + HANDLE_EINTR(open(ns_path.value().c_str(), O_RDONLY))); + if (!mount_namespace_fd.is_valid()) { + PLOG(ERROR) << "Failed to open mount namespace FD at " << ns_path.value(); + return nullptr; + } + + if (setns(mount_namespace_fd.get(), CLONE_NEWNS) != 0) { + PLOG(ERROR) << "Failed to enter mount namespace at " << ns_path.value(); + return nullptr; + } + + return std::make_unique<ScopedMountNamespace>( + std::move(original_mount_namespace_fd)); +} + +} // namespace brillo diff --git a/brillo/scoped_mount_namespace.h b/brillo/scoped_mount_namespace.h new file mode 100644 index 0000000..f360221 --- /dev/null +++ b/brillo/scoped_mount_namespace.h @@ -0,0 +1,44 @@ +// Copyright 2019 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef LIBBRILLO_BRILLO_SCOPED_MOUNT_NAMESPACE_H_ +#define LIBBRILLO_BRILLO_SCOPED_MOUNT_NAMESPACE_H_ + +#include <memory> + +#include <base/macros.h> +#include <base/files/file_path.h> +#include <base/files/scoped_file.h> + +#include <brillo/brillo_export.h> + +namespace brillo { + +// A class that restores a mount namespace when it goes out of scope. This can +// be done by entering another process' mount namespace by using +// CreateForPid(), or by supplying a mount namespace FD directly. +class BRILLO_EXPORT ScopedMountNamespace { + public: + // Enters the process identified by |pid|'s mount namespace and returns a + // unique_ptr that restores the original mount namespace when it goes out of + // scope. + static std::unique_ptr<ScopedMountNamespace> CreateForPid(pid_t pid); + + // Enters the mount namespace identified by |path| and returns a unique_ptr + // that restores the original mount namespace when it goes out of scope. + static std::unique_ptr<ScopedMountNamespace> CreateFromPath( + base::FilePath ns_path); + + explicit ScopedMountNamespace(base::ScopedFD mount_namespace_fd); + ~ScopedMountNamespace(); + + private: + base::ScopedFD mount_namespace_fd_; + + DISALLOW_COPY_AND_ASSIGN(ScopedMountNamespace); +}; + +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_SCOPED_MOUNT_NAMESPACE_H_ diff --git a/brillo/scoped_umask.cc b/brillo/scoped_umask.cc new file mode 100644 index 0000000..ac6b208 --- /dev/null +++ b/brillo/scoped_umask.cc @@ -0,0 +1,19 @@ +// Copyright 2019 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "brillo/scoped_umask.h" + +#include <sys/stat.h> + +namespace brillo { + +ScopedUmask::ScopedUmask(mode_t new_umask) { + saved_umask_ = umask(new_umask); +} + +ScopedUmask::~ScopedUmask() { + umask(saved_umask_); +} + +} // namespace brillo diff --git a/brillo/scoped_umask.h b/brillo/scoped_umask.h new file mode 100644 index 0000000..5369e83 --- /dev/null +++ b/brillo/scoped_umask.h @@ -0,0 +1,52 @@ +// Copyright 2019 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef LIBBRILLO_BRILLO_SCOPED_UMASK_H_ +#define LIBBRILLO_BRILLO_SCOPED_UMASK_H_ + +#include <sys/types.h> + +#include <base/macros.h> +#include <brillo/brillo_export.h> + +namespace brillo { + +// ScopedUmask is a helper class for temporarily setting the umask before a +// set of operations. umask(2) is never expected to fail. +class BRILLO_EXPORT ScopedUmask { + public: + explicit ScopedUmask(mode_t new_umask); + ~ScopedUmask(); + + private: + mode_t saved_umask_; + + // Avoid reusing ScopedUmask for multiple masks. DISALLOW_COPY_AND_ASSIGN + // deletes the copy constructor and operator=, but there are other situations + // where reassigning a new ScopedUmask to an existing ScopedUmask object + // is problematic: + // + // /* starting umask: default_value + // auto a = std::make_unique<ScopedUmask>(first_value); + // ... code here ... + // a.reset(ScopedUmask(new_value)); + // + // Here, the order of destruction of the old object and the construction of + // the new object is inverted. The recommended usage would be: + // + // { + // ScopedUmask a(old_value); + // ... code here ... + // } + // + // { + // ScopedUmask a(new_value); + // ... code here ... + // } + DISALLOW_COPY_AND_ASSIGN(ScopedUmask); +}; + +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_SCOPED_UMASK_H_ diff --git a/brillo/scoped_umask_test.cc b/brillo/scoped_umask_test.cc new file mode 100644 index 0000000..d1caa3c --- /dev/null +++ b/brillo/scoped_umask_test.cc @@ -0,0 +1,57 @@ +// Copyright 2019 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "brillo/scoped_umask.h" + +#include <fcntl.h> + +#include <base/files/file_path.h> +#include <base/files/file_util.h> +#include <base/files/scoped_file.h> +#include <base/files/scoped_temp_dir.h> +#include <gtest/gtest.h> + +namespace brillo { +namespace { + +constexpr int kPermissions600 = + base::FILE_PERMISSION_READ_BY_USER | base::FILE_PERMISSION_WRITE_BY_USER; +constexpr int kPermissions700 = base::FILE_PERMISSION_USER_MASK; +constexpr mode_t kMask700 = ~(0700); +constexpr mode_t kMask600 = ~(0600); + +void CheckFilePermissions(const base::FilePath& path, + int expected_permissions) { + int mode = 0; + // Try to create a file with broader permissions than the mask may provide. + base::ScopedFD fd( + HANDLE_EINTR(open(path.value().c_str(), O_WRONLY | O_CREAT, 0777))); + EXPECT_TRUE(fd.is_valid()); + EXPECT_TRUE(base::GetPosixFilePermissions(path, &mode)); + EXPECT_EQ(mode, expected_permissions); +} + +} // namespace + +TEST(ScopedUmask, CheckUmaskScope) { + base::ScopedTempDir tmpdir; + CHECK(tmpdir.CreateUniqueTempDir()); + + brillo::ScopedUmask outer_scoped_umask_(kMask700); + CheckFilePermissions(tmpdir.GetPath().AppendASCII("file1.txt"), + kPermissions700); + { + // A new scoped umask should result in different permissions for files + // created in this scope. + brillo::ScopedUmask inner_scoped_umask_(kMask600); + CheckFilePermissions(tmpdir.GetPath().AppendASCII("file2.txt"), + kPermissions600); + } + // Since inner_scoped_umask_ has been deconstructed, permissions on all new + // files should now use outer_scoped_umask_. + CheckFilePermissions(tmpdir.GetPath().AppendASCII("file3.txt"), + kPermissions700); +} + +} // namespace brillo diff --git a/brillo/secure_allocator.h b/brillo/secure_allocator.h new file mode 100644 index 0000000..0cbb8d9 --- /dev/null +++ b/brillo/secure_allocator.h @@ -0,0 +1,61 @@ +// Copyright 2018 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef LIBBRILLO_BRILLO_SECURE_ALLOCATOR_H_ +#define LIBBRILLO_BRILLO_SECURE_ALLOCATOR_H_ + +#include <memory> + +#include <brillo/brillo_export.h> + +namespace brillo { +// SecureAllocator is a stateless derivation of std::allocator that clears +// the contents of the object on deallocation. +template <typename T> +class BRILLO_PRIVATE SecureAllocator : public std::allocator<T> { + public: + using typename std::allocator<T>::pointer; + using typename std::allocator<T>::size_type; + using typename std::allocator<T>::value_type; + + // Implicit std::allocator constructors. + + template <typename U> struct rebind { + typedef SecureAllocator<U> other; + }; + + // Allocation/deallocation: use the std::allocation functions but make sure + // that on deallocation, the contents of the element are cleared out. + pointer allocate(size_type n, pointer = {}) { + return std::allocator<T>::allocate(n); + } + + virtual void deallocate(pointer p, size_type n) { + clear_contents(p, n * sizeof(value_type)); + std::allocator<T>::deallocate(p, n); + } + + protected: +// Force memset to not be optimized out. +// Original source commit: 31b02653c2560f8331934e879263beda44c6cc76 +// Repo: https://android.googlesource.com/platform/external/minijail +#if defined(__clang__) +#define __attribute_no_opt __attribute__((optnone)) +#else +#define __attribute_no_opt __attribute__((__optimize__(0))) +#endif + + // Zero-out all bytes in the allocated buffer. + virtual void __attribute_no_opt clear_contents(pointer v, size_type n) { + if (!v) + return; + memset(v, 0, n); + } + +#undef __attribute_no_opt +}; + +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_SECURE_ALLOCATOR_H_ diff --git a/brillo/secure_blob.cc b/brillo/secure_blob.cc index f4b797f..0c7026a 100644 --- a/brillo/secure_blob.cc +++ b/brillo/secure_blob.cc @@ -11,6 +11,23 @@ namespace brillo { +namespace { + +bool ConvertHexToBytes(char c, uint8_t* v) { + if (c >= '0' && c <='9') + *v = c - '0'; + else if (c >= 'a' && c <= 'f') + *v = c - 'a' + 10; + else if (c >= 'A' && c <= 'F') + *v = c - 'A' + 10; + else + return false; + + return true; +} + +} // namespace + std::string BlobToString(const Blob& blob) { return std::string(blob.begin(), blob.end()); } @@ -37,26 +54,25 @@ SecureBlob::SecureBlob(const std::string& data) : SecureBlob(data.begin(), data.end()) {} SecureBlob::~SecureBlob() { - clear(); + SecureVector::clear(); } void SecureBlob::resize(size_type count) { if (count < size()) { SecureMemset(data() + count, 0, capacity() - count); } - Blob::resize(count); + SecureVector::resize(count); } void SecureBlob::resize(size_type count, const value_type& value) { if (count < size()) { SecureMemset(data() + count, 0, capacity() - count); } - Blob::resize(count, value); + SecureVector::resize(count, value); } void SecureBlob::clear() { - SecureMemset(data(), 0, capacity()); - Blob::clear(); + SecureVector::clear(); } std::string SecureBlob::to_string() const { @@ -109,4 +125,38 @@ int SecureMemcmp(const void* s1, const void* s2, size_t n) { return result != 0; } +// base::HexEncode and base::HexStringToBytes use strings, which may leak +// contents. These functions are alternatives that keep all contents +// within secured memory. +SecureBlob SecureBlobToSecureHex(const SecureBlob& blob) { + std::string kHexChars("0123456789ABCDEF"); + SecureBlob hex(blob.size() * 2, 0); + const char* blob_char_data = blob.char_data(); + + // Each input byte creates two output hex characters. + for (size_t i = 0; i < blob.size(); ++i) { + hex[(i * 2)] = kHexChars[(blob_char_data[i] >> 4) & 0xf]; + hex[(i * 2) + 1] = kHexChars[blob_char_data[i] & 0xf]; + } + return hex; +} + +SecureBlob SecureHexToSecureBlob(const SecureBlob& hex) { + SecureBlob blob(hex.size()/2, 0); + + if (hex.size() == 0 || hex.size() % 2) + return SecureBlob(); + + for (size_t i = 0; i < hex.size(); i++) { + uint8_t v; + // Check for invalid characters. + if (!ConvertHexToBytes(hex[i], &v)) + return SecureBlob(); + + blob[i/2] = (blob[i/2] << 4) | (v & 0xf); + } + + return blob; +} + } // namespace brillo diff --git a/brillo/secure_blob.h b/brillo/secure_blob.h index 7b6d03c..e06646d 100644 --- a/brillo/secure_blob.h +++ b/brillo/secure_blob.h @@ -11,10 +11,16 @@ #include <brillo/asan.h> #include <brillo/brillo_export.h> +#include <brillo/secure_allocator.h> namespace brillo { using Blob = std::vector<uint8_t>; +// Define SecureVector as a vector using a SecureAllocator. +// Over time, the goal is to remove the differentiating functions +// from SecureBlob (to_string(), char_data()) till it converges with +// SecureVector. +using SecureVector = std::vector<uint8_t, SecureAllocator<uint8_t>>; // Conversion of Blob to/from std::string, where the string holds raw byte // contents. @@ -26,10 +32,11 @@ BRILLO_EXPORT Blob CombineBlobs(const std::initializer_list<Blob>& blobs); // SecureBlob erases the contents on destruction. It does not guarantee erasure // on resize, assign, etc. -class BRILLO_EXPORT SecureBlob : public Blob { +class BRILLO_EXPORT SecureBlob : public SecureVector { public: SecureBlob() = default; - using Blob::vector; // Inherit standard constructors from vector. + // Inherit standard constructors from SecureVector. + using SecureVector::vector; explicit SecureBlob(const Blob& blob); explicit SecureBlob(const std::string& data); ~SecureBlob(); @@ -69,6 +76,14 @@ BRILLO_EXPORT BRILLO_DISABLE_ASAN void* SecureMemset(void* v, int c, size_t n); // [n] and not on the relationship of the match between [s1] and [s2]. BRILLO_EXPORT int SecureMemcmp(const void* s1, const void* s2, size_t n); +// Conversion of SecureBlob data to/from SecureBlob hex. This is useful +// for sensitive data like encryption keys, that should, in the ideal case never +// be exposed as strings in the first place. In case the existing data or hex +// string is already exposed as a std::string, it is preferable to use the +// BlobToString variant. +BRILLO_EXPORT SecureBlob SecureBlobToSecureHex(const SecureBlob& blob); +BRILLO_EXPORT SecureBlob SecureHexToSecureBlob(const SecureBlob& hex); + } // namespace brillo #endif // LIBBRILLO_BRILLO_SECURE_BLOB_H_ diff --git a/brillo/secure_blob_unittest.cc b/brillo/secure_blob_test.cc index ff95d0f..75f3cfb 100644 --- a/brillo/secure_blob_unittest.cc +++ b/brillo/secure_blob_test.cc @@ -5,6 +5,7 @@ // Unit tests for SecureBlob. #include "brillo/asan.h" +#include "brillo/secure_allocator.h" #include "brillo/secure_blob.h" #include <algorithm> @@ -227,4 +228,38 @@ TEST_F(SecureBlobTest, HexStringToSecureBlob) { EXPECT_EQ(blob[15], 0x0f); } +// Override clear_contents() to check whether memory has been cleared. +template <typename T> +class TestSecureAllocator : public SecureAllocator<T> { + public: + using typename SecureAllocator<T>::pointer; + using typename SecureAllocator<T>::size_type; + + int GetErasedCount() { return erased_count; } + + protected: + void clear_contents(pointer p, size_type n) override { + SecureAllocator<T>::clear_contents(p, n); + for (int i = 0; i < n; i++) { + EXPECT_EQ(p[i], 0); + } + erased_count++; + } + + private: + int erased_count = 0; +}; + +TEST(SecureAllocator, ErasureOnDeallocation) { + // Make sure that the contents are cleared on deallocation. + TestSecureAllocator<char> e; + + char *test_string_addr = e.allocate(15); + snprintf(test_string_addr, sizeof(test_string_addr), "Test String"); + + // Deallocate memory; the mock class should check for cleared data. + e.deallocate(test_string_addr, 15); + EXPECT_EQ(e.GetErasedCount(), 1); +} + } // namespace brillo diff --git a/brillo/streams/fake_stream.cc b/brillo/streams/fake_stream.cc index 498b9d4..9d7a044 100644 --- a/brillo/streams/fake_stream.cc +++ b/brillo/streams/fake_stream.cc @@ -5,6 +5,7 @@ #include <brillo/streams/fake_stream.h> #include <algorithm> +#include <utility> #include <base/bind.h> #include <brillo/message_loops/message_loop.h> @@ -185,7 +186,7 @@ bool FakeStream::IsReadBufferEmpty() const { bool FakeStream::PopReadPacket() { if (incoming_queue_.empty()) return false; - const InputDataPacket& packet = incoming_queue_.front(); + InputDataPacket& packet = incoming_queue_.front(); input_ptr_ = 0; input_buffer_ = std::move(packet.data); delay_input_until_ = clock_->Now() + packet.delay_before; @@ -250,7 +251,7 @@ bool FakeStream::IsWriteBufferFull() const { bool FakeStream::PopWritePacket() { if (outgoing_queue_.empty()) return false; - const OutputDataPacket& packet = outgoing_queue_.front(); + OutputDataPacket& packet = outgoing_queue_.front(); expected_output_data_ = std::move(packet.data); delay_output_until_ = clock_->Now() + packet.delay_before; max_output_buffer_size_ = packet.expected_size; diff --git a/brillo/streams/fake_stream_unittest.cc b/brillo/streams/fake_stream_test.cc index 2404514..2e83e3b 100644 --- a/brillo/streams/fake_stream_unittest.cc +++ b/brillo/streams/fake_stream_test.cc @@ -4,11 +4,12 @@ #include <brillo/streams/fake_stream.h> +#include <memory> #include <vector> +#include <base/bind.h> #include <base/callback.h> #include <base/test/simple_test_clock.h> -#include <brillo/bind_lambda.h> #include <brillo/message_loops/mock_message_loop.h> #include <gmock/gmock.h> #include <gtest/gtest.h> diff --git a/brillo/streams/file_stream.cc b/brillo/streams/file_stream.cc index 7b28a5a..db22192 100644 --- a/brillo/streams/file_stream.cc +++ b/brillo/streams/file_stream.cc @@ -4,11 +4,13 @@ #include <brillo/streams/file_stream.h> -#include <algorithm> #include <fcntl.h> #include <sys/stat.h> #include <unistd.h> +#include <algorithm> +#include <utility> + #include <base/bind.h> #include <base/files/file_util.h> #include <base/posix/eintr_wrapper.h> diff --git a/brillo/streams/file_stream.h b/brillo/streams/file_stream.h index 1cf39b5..bf54617 100644 --- a/brillo/streams/file_stream.h +++ b/brillo/streams/file_stream.h @@ -5,6 +5,8 @@ #ifndef LIBBRILLO_BRILLO_STREAMS_FILE_STREAM_H_ #define LIBBRILLO_BRILLO_STREAMS_FILE_STREAM_H_ +#include <memory> + #include <base/files/file_path.h> #include <base/macros.h> #include <brillo/brillo_export.h> diff --git a/brillo/streams/file_stream_unittest.cc b/brillo/streams/file_stream_test.cc index 210725e..23ef64c 100644 --- a/brillo/streams/file_stream_unittest.cc +++ b/brillo/streams/file_stream_test.cc @@ -4,18 +4,20 @@ #include <brillo/streams/file_stream.h> +#include <sys/stat.h> + #include <limits> #include <numeric> #include <string> -#include <sys/stat.h> +#include <utility> #include <vector> +#include <base/bind.h> #include <base/files/file_util.h> #include <base/files/scoped_temp_dir.h> #include <base/message_loop/message_loop.h> #include <base/rand_util.h> #include <base/run_loop.h> -#include <brillo/bind_lambda.h> #include <brillo/errors/error_codes.h> #include <brillo/message_loops/base_message_loop.h> #include <brillo/message_loops/message_loop_utils.h> @@ -130,20 +132,23 @@ void SetToTrue(bool* target, const Error* /* error */) { // A mock file descriptor wrapper to test low-level file API used by FileStream. class MockFileDescriptor : public FileStream::FileDescriptorInterface { public: - MOCK_CONST_METHOD0(IsOpen, bool()); - MOCK_METHOD2(Read, ssize_t(void*, size_t)); - MOCK_METHOD2(Write, ssize_t(const void*, size_t)); - MOCK_METHOD2(Seek, off64_t(off64_t, int)); - MOCK_CONST_METHOD0(GetFileMode, mode_t()); - MOCK_CONST_METHOD0(GetSize, uint64_t()); - MOCK_CONST_METHOD1(Truncate, int(off64_t)); - MOCK_METHOD0(Flush, int()); - MOCK_METHOD0(Close, int()); - MOCK_METHOD3(WaitForData, - bool(Stream::AccessMode, const DataCallback&, ErrorPtr*)); - MOCK_METHOD3(WaitForDataBlocking, - int(Stream::AccessMode, base::TimeDelta, Stream::AccessMode*)); - MOCK_METHOD0(CancelPendingAsyncOperations, void()); + MOCK_METHOD(bool, IsOpen, (), (const, override)); + MOCK_METHOD(ssize_t, Read, (void*, size_t), (override)); + MOCK_METHOD(ssize_t, Write, (const void*, size_t), (override)); + MOCK_METHOD(off64_t, Seek, (off64_t, int), (override)); + MOCK_METHOD(mode_t, GetFileMode, (), (const, override)); + MOCK_METHOD(uint64_t, GetSize, (), (const, override)); + MOCK_METHOD(int, Truncate, (off64_t), (const, override)); + MOCK_METHOD(int, Close, (), (override)); + MOCK_METHOD(bool, + WaitForData, + (Stream::AccessMode, const DataCallback&, ErrorPtr*), + (override)); + MOCK_METHOD(int, + WaitForDataBlocking, + (Stream::AccessMode, base::TimeDelta, Stream::AccessMode*), + (override)); + MOCK_METHOD(void, CancelPendingAsyncOperations, (), (override)); }; class FileStreamTest : public testing::Test { diff --git a/brillo/streams/input_stream_set.cc b/brillo/streams/input_stream_set.cc index 986efac..0ceb5b4 100644 --- a/brillo/streams/input_stream_set.cc +++ b/brillo/streams/input_stream_set.cc @@ -4,6 +4,8 @@ #include <brillo/streams/input_stream_set.h> +#include <utility> + #include <base/bind.h> #include <brillo/message_loops/message_loop.h> #include <brillo/streams/stream_errors.h> diff --git a/brillo/streams/input_stream_set_unittest.cc b/brillo/streams/input_stream_set_test.cc index 3268d96..9a29248 100644 --- a/brillo/streams/input_stream_set_unittest.cc +++ b/brillo/streams/input_stream_set_test.cc @@ -4,13 +4,14 @@ #include <brillo/streams/input_stream_set.h> +#include <memory> + #include <brillo/errors/error_codes.h> #include <brillo/streams/mock_stream.h> #include <brillo/streams/stream_errors.h> #include <gmock/gmock.h> #include <gtest/gtest.h> -using testing::An; using testing::DoAll; using testing::InSequence; using testing::Return; diff --git a/brillo/streams/memory_containers.h b/brillo/streams/memory_containers.h index d3cb205..22488d8 100644 --- a/brillo/streams/memory_containers.h +++ b/brillo/streams/memory_containers.h @@ -6,6 +6,7 @@ #define LIBBRILLO_BRILLO_STREAMS_MEMORY_CONTAINERS_H_ #include <string> +#include <utility> #include <vector> #include <brillo/brillo_export.h> diff --git a/brillo/streams/memory_containers_unittest.cc b/brillo/streams/memory_containers_test.cc index 2f0bf38..8b56ade 100644 --- a/brillo/streams/memory_containers_unittest.cc +++ b/brillo/streams/memory_containers_test.cc @@ -26,14 +26,20 @@ class MockContiguousBuffer : public data_container::ContiguousBufferBase { public: MockContiguousBuffer() = default; - MOCK_METHOD2(Resize, bool(size_t, ErrorPtr*)); - MOCK_CONST_METHOD0(GetSize, size_t()); - MOCK_CONST_METHOD0(IsReadOnly, bool()); - - MOCK_CONST_METHOD2(GetReadOnlyBuffer, const void*(size_t, ErrorPtr*)); - MOCK_METHOD2(GetBuffer, void*(size_t, ErrorPtr*)); - - MOCK_CONST_METHOD3(CopyMemoryBlock, void(void*, const void*, size_t)); + MOCK_METHOD(bool, Resize, (size_t, ErrorPtr*), (override)); + MOCK_METHOD(size_t, GetSize, (), (const, override)); + MOCK_METHOD(bool, IsReadOnly, (), (const, override)); + + MOCK_METHOD(const void*, + GetReadOnlyBuffer, + (size_t, ErrorPtr*), + (const, override)); + MOCK_METHOD(void*, GetBuffer, (size_t, ErrorPtr*), (override)); + + MOCK_METHOD(void, + CopyMemoryBlock, + (void*, const void*, size_t), + (const, override)); private: DISALLOW_COPY_AND_ASSIGN(MockContiguousBuffer); diff --git a/brillo/streams/memory_stream.h b/brillo/streams/memory_stream.h index b4927a8..e748f47 100644 --- a/brillo/streams/memory_stream.h +++ b/brillo/streams/memory_stream.h @@ -5,7 +5,9 @@ #ifndef LIBBRILLO_BRILLO_STREAMS_MEMORY_STREAM_H_ #define LIBBRILLO_BRILLO_STREAMS_MEMORY_STREAM_H_ +#include <memory> #include <string> +#include <utility> #include <vector> #include <base/macros.h> diff --git a/brillo/streams/memory_stream_unittest.cc b/brillo/streams/memory_stream_test.cc index 75278f7..28a88fa 100644 --- a/brillo/streams/memory_stream_unittest.cc +++ b/brillo/streams/memory_stream_test.cc @@ -8,6 +8,7 @@ #include <limits> #include <numeric> #include <string> +#include <utility> #include <vector> #include <brillo/streams/stream_errors.h> @@ -32,11 +33,17 @@ class MockMemoryContainer : public data_container::DataContainerInterface { public: MockMemoryContainer() = default; - MOCK_METHOD5(Read, bool(void*, size_t, size_t, size_t*, ErrorPtr*)); - MOCK_METHOD5(Write, bool(const void*, size_t, size_t, size_t*, ErrorPtr*)); - MOCK_METHOD2(Resize, bool(size_t, ErrorPtr*)); - MOCK_CONST_METHOD0(GetSize, size_t()); - MOCK_CONST_METHOD0(IsReadOnly, bool()); + MOCK_METHOD(bool, + Read, + (void*, size_t, size_t, size_t*, ErrorPtr*), + (override)); + MOCK_METHOD(bool, + Write, + (const void*, size_t, size_t, size_t*, ErrorPtr*), + (override)); + MOCK_METHOD(bool, Resize, (size_t, ErrorPtr*), (override)); + MOCK_METHOD(size_t, GetSize, (), (const, override)); + MOCK_METHOD(bool, IsReadOnly, (), (const, override)); private: DISALLOW_COPY_AND_ASSIGN(MockMemoryContainer); diff --git a/brillo/streams/mock_stream.h b/brillo/streams/mock_stream.h index 934912a..45f83ed 100644 --- a/brillo/streams/mock_stream.h +++ b/brillo/streams/mock_stream.h @@ -16,55 +16,82 @@ class MockStream : public Stream { public: MockStream() = default; - MOCK_CONST_METHOD0(IsOpen, bool()); - MOCK_CONST_METHOD0(CanRead, bool()); - MOCK_CONST_METHOD0(CanWrite, bool()); - MOCK_CONST_METHOD0(CanSeek, bool()); - MOCK_CONST_METHOD0(CanGetSize, bool()); + MOCK_METHOD(bool, IsOpen, (), (const, override)); + MOCK_METHOD(bool, CanRead, (), (const, override)); + MOCK_METHOD(bool, CanWrite, (), (const, override)); + MOCK_METHOD(bool, CanSeek, (), (const, override)); + MOCK_METHOD(bool, CanGetSize, (), (const, override)); - MOCK_CONST_METHOD0(GetSize, uint64_t()); - MOCK_METHOD2(SetSizeBlocking, bool(uint64_t, ErrorPtr*)); - MOCK_CONST_METHOD0(GetRemainingSize, uint64_t()); + MOCK_METHOD(uint64_t, GetSize, (), (const, override)); + MOCK_METHOD(bool, SetSizeBlocking, (uint64_t, ErrorPtr*), (override)); + MOCK_METHOD(uint64_t, GetRemainingSize, (), (const, override)); - MOCK_CONST_METHOD0(GetPosition, uint64_t()); - MOCK_METHOD4(Seek, bool(int64_t, Whence, uint64_t*, ErrorPtr*)); + MOCK_METHOD(uint64_t, GetPosition, (), (const, override)); + MOCK_METHOD(bool, Seek, (int64_t, Whence, uint64_t*, ErrorPtr*), (override)); - MOCK_METHOD5(ReadAsync, bool(void*, - size_t, - const base::Callback<void(size_t)>&, - const ErrorCallback&, - ErrorPtr*)); - MOCK_METHOD5(ReadAllAsync, bool(void*, - size_t, - const base::Closure&, - const ErrorCallback&, - ErrorPtr*)); - MOCK_METHOD5(ReadNonBlocking, bool(void*, size_t, size_t*, bool*, ErrorPtr*)); - MOCK_METHOD4(ReadBlocking, bool(void*, size_t, size_t*, ErrorPtr*)); - MOCK_METHOD3(ReadAllBlocking, bool(void*, size_t, ErrorPtr*)); + MOCK_METHOD(bool, + ReadAsync, + (void*, + size_t, + const base::Callback<void(size_t)>&, + const ErrorCallback&, + ErrorPtr*), + (override)); + MOCK_METHOD( + bool, + ReadAllAsync, + (void*, size_t, const base::Closure&, const ErrorCallback&, ErrorPtr*), + (override)); + MOCK_METHOD(bool, + ReadNonBlocking, + (void*, size_t, size_t*, bool*, ErrorPtr*), + (override)); + MOCK_METHOD(bool, + ReadBlocking, + (void*, size_t, size_t*, ErrorPtr*), + (override)); + MOCK_METHOD(bool, ReadAllBlocking, (void*, size_t, ErrorPtr*), (override)); - MOCK_METHOD5(WriteAsync, bool(const void*, - size_t, - const base::Callback<void(size_t)>&, - const ErrorCallback&, - ErrorPtr*)); - MOCK_METHOD5(WriteAllAsync, bool(const void*, - size_t, - const base::Closure&, - const ErrorCallback&, - ErrorPtr*)); - MOCK_METHOD4(WriteNonBlocking, bool(const void*, size_t, size_t*, ErrorPtr*)); - MOCK_METHOD4(WriteBlocking, bool(const void*, size_t, size_t*, ErrorPtr*)); - MOCK_METHOD3(WriteAllBlocking, bool(const void*, size_t, ErrorPtr*)); + MOCK_METHOD(bool, + WriteAsync, + (const void*, + size_t, + const base::Callback<void(size_t)>&, + const ErrorCallback&, + ErrorPtr*), + (override)); + MOCK_METHOD(bool, + WriteAllAsync, + (const void*, + size_t, + const base::Closure&, + const ErrorCallback&, + ErrorPtr*), + (override)); + MOCK_METHOD(bool, + WriteNonBlocking, + (const void*, size_t, size_t*, ErrorPtr*), + (override)); + MOCK_METHOD(bool, + WriteBlocking, + (const void*, size_t, size_t*, ErrorPtr*), + (override)); + MOCK_METHOD(bool, + WriteAllBlocking, + (const void*, size_t, ErrorPtr*), + (override)); - MOCK_METHOD1(FlushBlocking, bool(ErrorPtr*)); - MOCK_METHOD1(CloseBlocking, bool(ErrorPtr*)); + MOCK_METHOD(bool, FlushBlocking, (ErrorPtr*), (override)); + MOCK_METHOD(bool, CloseBlocking, (ErrorPtr*), (override)); - MOCK_METHOD3(WaitForData, bool(AccessMode, - const base::Callback<void(AccessMode)>&, - ErrorPtr*)); - MOCK_METHOD4(WaitForDataBlocking, - bool(AccessMode, base::TimeDelta, AccessMode*, ErrorPtr*)); + MOCK_METHOD(bool, + WaitForData, + (AccessMode, const base::Callback<void(AccessMode)>&, ErrorPtr*), + (override)); + MOCK_METHOD(bool, + WaitForDataBlocking, + (AccessMode, base::TimeDelta, AccessMode*, ErrorPtr*), + (override)); private: DISALLOW_COPY_AND_ASSIGN(MockStream); diff --git a/brillo/streams/openssl_stream_bio.cc b/brillo/streams/openssl_stream_bio.cc index a63d9c0..478b112 100644 --- a/brillo/streams/openssl_stream_bio.cc +++ b/brillo/streams/openssl_stream_bio.cc @@ -13,9 +13,32 @@ namespace brillo { namespace { +// TODO(crbug.com/984789): Remove once support for OpenSSL <1.1 is dropped. +#if OPENSSL_VERSION_NUMBER < 0x10100000L +static void BIO_set_data(BIO* a, void* ptr) { + a->ptr = ptr; +} + +static void* BIO_get_data(BIO* a) { + return a->ptr; +} + +static void BIO_set_init(BIO* a, int init) { + a->init = init; +} + +static int BIO_get_init(BIO* a) { + return a->init; +} + +static void BIO_set_shutdown(BIO* a, int shut) { + a->shutdown = shut; +} +#endif + // Internal functions for implementing OpenSSL BIO on brillo::Stream. int stream_write(BIO* bio, const char* buf, int size) { - brillo::Stream* stream = static_cast<brillo::Stream*>(bio->ptr); + brillo::Stream* stream = static_cast<brillo::Stream*>(BIO_get_data(bio)); size_t written = 0; BIO_clear_retry_flags(bio); if (!stream->WriteNonBlocking(buf, size, &written, nullptr)) @@ -30,7 +53,7 @@ int stream_write(BIO* bio, const char* buf, int size) { } int stream_read(BIO* bio, char* buf, int size) { - brillo::Stream* stream = static_cast<brillo::Stream*>(bio->ptr); + brillo::Stream* stream = static_cast<brillo::Stream*>(BIO_get_data(bio)); size_t read = 0; BIO_clear_retry_flags(bio); bool eos = false; @@ -49,16 +72,16 @@ int stream_read(BIO* bio, char* buf, int size) { // NOLINTNEXTLINE(runtime/int) long stream_ctrl(BIO* bio, int cmd, long /* num */, void* /* ptr */) { if (cmd == BIO_CTRL_FLUSH) { - brillo::Stream* stream = static_cast<brillo::Stream*>(bio->ptr); + brillo::Stream* stream = static_cast<brillo::Stream*>(BIO_get_data(bio)); return stream->FlushBlocking(nullptr) ? 1 : 0; } return 0; } int stream_new(BIO* bio) { - bio->shutdown = 0; // By default do not close underlying stream on shutdown. - bio->init = 0; - bio->num = -1; // not used. + // By default do not close underlying stream on shutdown. + BIO_set_shutdown(bio, 0); + BIO_set_init(bio, 0); return 1; } @@ -66,13 +89,17 @@ int stream_free(BIO* bio) { if (!bio) return 0; - if (bio->init) { - bio->ptr = nullptr; - bio->init = 0; + if (BIO_get_init(bio)) { + BIO_set_data(bio, nullptr); + BIO_set_init(bio, 0); } return 1; } +#if OPENSSL_VERSION_NUMBER < 0x10100000L +// TODO(crbug.com/984789): Remove #ifdef once support for OpenSSL <1.1 is +// dropped. + // BIO_METHOD structure describing the BIO built on top of brillo::Stream. BIO_METHOD stream_method = { 0x7F | BIO_TYPE_SOURCE_SINK, // type: 0x7F is an arbitrary unused type ID. @@ -87,13 +114,37 @@ BIO_METHOD stream_method = { nullptr, // callback function, not used }; +BIO_METHOD* stream_get_method() { + return &stream_method; +} + +#else + +BIO_METHOD* stream_get_method() { + static BIO_METHOD* stream_method; + + if (!stream_method) { + stream_method = BIO_meth_new(BIO_get_new_index() | BIO_TYPE_SOURCE_SINK, + "stream"); + BIO_meth_set_write(stream_method, stream_write); + BIO_meth_set_read(stream_method, stream_read); + BIO_meth_set_ctrl(stream_method, stream_ctrl); + BIO_meth_set_create(stream_method, stream_new); + BIO_meth_set_destroy(stream_method, stream_free); + } + + return stream_method; +} + +#endif + } // anonymous namespace BIO* BIO_new_stream(brillo::Stream* stream) { - BIO* bio = BIO_new(&stream_method); + BIO* bio = BIO_new(stream_get_method()); if (bio) { - bio->ptr = stream; - bio->init = 1; + BIO_set_data(bio, stream); + BIO_set_init(bio, 1); } return bio; } diff --git a/brillo/streams/openssl_stream_bio_unittests.cc b/brillo/streams/openssl_stream_bio_test.cc index a80710d..a80710d 100644 --- a/brillo/streams/openssl_stream_bio_unittests.cc +++ b/brillo/streams/openssl_stream_bio_test.cc diff --git a/brillo/streams/stream_unittest.cc b/brillo/streams/stream_test.cc index c341cde..8cb99a9 100644 --- a/brillo/streams/stream_unittest.cc +++ b/brillo/streams/stream_test.cc @@ -6,11 +6,11 @@ #include <limits> +#include <base/bind.h> #include <base/callback.h> #include <gmock/gmock.h> #include <gtest/gtest.h> -#include <brillo/bind_lambda.h> #include <brillo/message_loops/fake_message_loop.h> #include <brillo/streams/stream_errors.h> @@ -42,39 +42,48 @@ class MockStreamImpl : public Stream { public: MockStreamImpl() = default; - MOCK_CONST_METHOD0(IsOpen, bool()); - MOCK_CONST_METHOD0(CanRead, bool()); - MOCK_CONST_METHOD0(CanWrite, bool()); - MOCK_CONST_METHOD0(CanSeek, bool()); - MOCK_CONST_METHOD0(CanGetSize, bool()); + MOCK_METHOD(bool, IsOpen, (), (const, override)); + MOCK_METHOD(bool, CanRead, (), (const, override)); + MOCK_METHOD(bool, CanWrite, (), (const, override)); + MOCK_METHOD(bool, CanSeek, (), (const, override)); + MOCK_METHOD(bool, CanGetSize, (), (const, override)); - MOCK_CONST_METHOD0(GetSize, uint64_t()); - MOCK_METHOD2(SetSizeBlocking, bool(uint64_t, ErrorPtr*)); - MOCK_CONST_METHOD0(GetRemainingSize, uint64_t()); + MOCK_METHOD(uint64_t, GetSize, (), (const, override)); + MOCK_METHOD(bool, SetSizeBlocking, (uint64_t, ErrorPtr*), (override)); + MOCK_METHOD(uint64_t, GetRemainingSize, (), (const, override)); - MOCK_CONST_METHOD0(GetPosition, uint64_t()); - MOCK_METHOD4(Seek, bool(int64_t, Whence, uint64_t*, ErrorPtr*)); + MOCK_METHOD(uint64_t, GetPosition, (), (const, override)); + MOCK_METHOD(bool, Seek, (int64_t, Whence, uint64_t*, ErrorPtr*), (override)); // Omitted: ReadAsync // Omitted: ReadAllAsync - MOCK_METHOD5(ReadNonBlocking, bool(void*, size_t, size_t*, bool*, ErrorPtr*)); + MOCK_METHOD(bool, + ReadNonBlocking, + (void*, size_t, size_t*, bool*, ErrorPtr*), + (override)); // Omitted: ReadBlocking // Omitted: ReadAllBlocking // Omitted: WriteAsync // Omitted: WriteAllAsync - MOCK_METHOD4(WriteNonBlocking, bool(const void*, size_t, size_t*, ErrorPtr*)); + MOCK_METHOD(bool, + WriteNonBlocking, + (const void*, size_t, size_t*, ErrorPtr*), + (override)); // Omitted: WriteBlocking // Omitted: WriteAllBlocking - MOCK_METHOD1(FlushBlocking, bool(ErrorPtr*)); - MOCK_METHOD1(CloseBlocking, bool(ErrorPtr*)); + MOCK_METHOD(bool, FlushBlocking, (ErrorPtr*), (override)); + MOCK_METHOD(bool, CloseBlocking, (ErrorPtr*), (override)); - MOCK_METHOD3(WaitForData, bool(AccessMode, - const base::Callback<void(AccessMode)>&, - ErrorPtr*)); - MOCK_METHOD4(WaitForDataBlocking, - bool(AccessMode, base::TimeDelta, AccessMode*, ErrorPtr*)); + MOCK_METHOD(bool, + WaitForData, + (AccessMode, const base::Callback<void(AccessMode)>&, ErrorPtr*), + (override)); + MOCK_METHOD(bool, + WaitForDataBlocking, + (AccessMode, base::TimeDelta, AccessMode*, ErrorPtr*), + (override)); private: DISALLOW_COPY_AND_ASSIGN(MockStreamImpl); @@ -333,7 +342,10 @@ TEST(Stream, ReadBlocking) { TEST(Stream, ReadAllBlocking) { class MockReadBlocking : public MockStreamImpl { public: - MOCK_METHOD4(ReadBlocking, bool(void*, size_t, size_t*, ErrorPtr*)); + MOCK_METHOD(bool, + ReadBlocking, + (void*, size_t, size_t*, ErrorPtr*), + (override)); } stream_mock; char buf[1024]; @@ -471,7 +483,10 @@ TEST(Stream, WriteBlocking) { TEST(Stream, WriteAllBlocking) { class MockWritelocking : public MockStreamImpl { public: - MOCK_METHOD4(WriteBlocking, bool(const void*, size_t, size_t*, ErrorPtr*)); + MOCK_METHOD(bool, + WriteBlocking, + (const void*, size_t, size_t*, ErrorPtr*), + (override)); } stream_mock; char buf[1024]; diff --git a/brillo/streams/stream_utils.cc b/brillo/streams/stream_utils.cc index 3f7a14a..6f8a1d0 100644 --- a/brillo/streams/stream_utils.cc +++ b/brillo/streams/stream_utils.cc @@ -4,7 +4,11 @@ #include <brillo/streams/stream_utils.h> +#include <algorithm> #include <limits> +#include <memory> +#include <utility> +#include <vector> #include <base/bind.h> #include <brillo/message_loops/message_loop.h> diff --git a/brillo/streams/stream_utils_unittest.cc b/brillo/streams/stream_utils_test.cc index f27d233..e0b327d 100644 --- a/brillo/streams/stream_utils_unittest.cc +++ b/brillo/streams/stream_utils_test.cc @@ -5,6 +5,9 @@ #include <brillo/streams/stream_utils.h> #include <limits> +#include <memory> +#include <string> +#include <utility> #include <base/bind.h> #include <brillo/message_loops/fake_message_loop.h> @@ -14,9 +17,7 @@ #include <gmock/gmock.h> #include <gtest/gtest.h> -using testing::DoAll; using testing::InSequence; -using testing::Return; using testing::StrictMock; using testing::_; diff --git a/brillo/streams/tls_stream.cc b/brillo/streams/tls_stream.cc index fde4193..cc63258 100644 --- a/brillo/streams/tls_stream.cc +++ b/brillo/streams/tls_stream.cc @@ -7,6 +7,7 @@ #include <algorithm> #include <limits> #include <string> +#include <utility> #include <vector> #include <openssl/err.h> @@ -67,6 +68,11 @@ const char kCACertificatePath[] = namespace brillo { +// TODO(crbug.com/984789): Remove once support for OpenSSL <1.1 is dropped. +#if OPENSSL_VERSION_NUMBER < 0x10100000L +#define TLS_client_method() TLSv1_2_client_method() +#endif + // Helper implementation of TLS stream used to hide most of OpenSSL inner // workings from the users of brillo::TlsStream. class TlsStream::TlsStreamImpl { @@ -341,7 +347,7 @@ bool TlsStream::TlsStreamImpl::Init(StreamPtr socket, const base::Closure& success_callback, const Stream::ErrorCallback& error_callback, ErrorPtr* error) { - ctx_.reset(SSL_CTX_new(TLSv1_2_client_method())); + ctx_.reset(SSL_CTX_new(TLS_client_method())); if (!ctx_) return ReportError(error, FROM_HERE, "Cannot create SSL_CTX"); diff --git a/brillo/strings/string_utils_unittest.cc b/brillo/strings/string_utils_test.cc index c554e74..c554e74 100644 --- a/brillo/strings/string_utils_unittest.cc +++ b/brillo/strings/string_utils_test.cc diff --git a/brillo/syslog_logging_unittest.cc b/brillo/syslog_logging_test.cc index e852e50..e852e50 100644 --- a/brillo/syslog_logging_unittest.cc +++ b/brillo/syslog_logging_test.cc diff --git a/brillo/type_list.h b/brillo/type_list.h new file mode 100644 index 0000000..c5ccc5e --- /dev/null +++ b/brillo/type_list.h @@ -0,0 +1,53 @@ +// Copyright 2019 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef LIBBRILLO_BRILLO_TYPE_LIST_H_ +#define LIBBRILLO_BRILLO_TYPE_LIST_H_ + +#include <type_traits> + +namespace brillo { + +template <typename... Ts> +struct TypeList {}; + +namespace type_list { + +template <typename... Ts> +struct is_one_of { + static constexpr bool value = false; +}; + +template <typename T, typename Head, typename... Tail> +struct is_one_of<T, TypeList<Head, Tail...>> { + static constexpr bool value = + std::is_same<T, Head>::value || is_one_of<T, TypeList<Tail...>>::value; +}; + +} // namespace type_list + +// Enables a template if the type T is in the typelist Types. Since std::same is +// used to determine equivalence of types, cv-qualifiers (const and volatile) +// *are* important. Note that typedefs and type aliases do not define new types. +// +// Example: +// using ValidTypes = TypeList<int32_t, float>; +// +// template <typename T, typename = EnableIfIsOneOf<T, ValidTypes>> +// void f(){} +// +// using integer = int32_t; +// ... +// f<int32_t>(); // Fine. +// f<float>(); // Fine. +// f<integer>(); // Fine. +// f<const int32_t>(); // Error; no matching function for call to 'f'. +// f<uint32_t>(); // Error; no matching function for call to 'f'. +template <typename T, typename Types> +using EnableIfIsOneOf = + std::enable_if_t<type_list::is_one_of<T, Types>::value>; + +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_TYPE_LIST_H_ diff --git a/brillo/type_name_undecorate.cc b/brillo/type_name_undecorate.cc index b588170..b24a746 100644 --- a/brillo/type_name_undecorate.cc +++ b/brillo/type_name_undecorate.cc @@ -5,6 +5,8 @@ #include <brillo/type_name_undecorate.h> #include <cstring> +#include <map> +#include <string> #ifdef __GNUG__ #include <cstdlib> diff --git a/brillo/type_name_undecorate_unittest.cc b/brillo/type_name_undecorate_test.cc index 604c0fb..a41c6cd 100644 --- a/brillo/type_name_undecorate_unittest.cc +++ b/brillo/type_name_undecorate_test.cc @@ -4,6 +4,8 @@ #include <brillo/type_name_undecorate.h> +#include <map> + #include <brillo/variant_dictionary.h> #include <gtest/gtest.h> diff --git a/brillo/udev/OWNERS b/brillo/udev/OWNERS new file mode 100644 index 0000000..f426deb --- /dev/null +++ b/brillo/udev/OWNERS @@ -0,0 +1,3 @@ +amistry@chromium.org +ejcaruso@chromium.org +wbbradley@chromium.org diff --git a/brillo/udev/mock_udev.h b/brillo/udev/mock_udev.h new file mode 100644 index 0000000..8494bab --- /dev/null +++ b/brillo/udev/mock_udev.h @@ -0,0 +1,48 @@ +// Copyright (c) 2013 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef LIBBRILLO_BRILLO_UDEV_MOCK_UDEV_H_ +#define LIBBRILLO_BRILLO_UDEV_MOCK_UDEV_H_ + +#include <memory> + +#include <brillo/brillo_export.h> +#include <brillo/udev/udev.h> +#include <brillo/udev/udev_device.h> +#include <brillo/udev/udev_enumerate.h> +#include <brillo/udev/udev_monitor.h> +#include <gmock/gmock.h> + +namespace brillo { + +class BRILLO_EXPORT MockUdev : public Udev { + public: + MockUdev() : Udev(nullptr) {} + ~MockUdev() override = default; + + MOCK_METHOD(std::unique_ptr<UdevDevice>, + CreateDeviceFromSysPath, + (const char*), + (override)); + MOCK_METHOD(std::unique_ptr<UdevDevice>, + CreateDeviceFromDeviceNumber, + (char, dev_t), + (override)); + MOCK_METHOD(std::unique_ptr<UdevDevice>, + CreateDeviceFromSubsystemSysName, + (const char*, const char*), + (override)); + MOCK_METHOD(std::unique_ptr<UdevEnumerate>, CreateEnumerate, (), (override)); + MOCK_METHOD(std::unique_ptr<UdevMonitor>, + CreateMonitorFromNetlink, + (const char*), + (override)); + + private: + DISALLOW_COPY_AND_ASSIGN(MockUdev); +}; + +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_UDEV_MOCK_UDEV_H_ diff --git a/brillo/udev/mock_udev_device.h b/brillo/udev/mock_udev_device.h new file mode 100644 index 0000000..6e812d1 --- /dev/null +++ b/brillo/udev/mock_udev_device.h @@ -0,0 +1,68 @@ +// Copyright (c) 2013 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef LIBBRILLO_BRILLO_UDEV_MOCK_UDEV_DEVICE_H_ +#define LIBBRILLO_BRILLO_UDEV_MOCK_UDEV_DEVICE_H_ + +#include <memory> + +#include <brillo/brillo_export.h> +#include <brillo/udev/udev_device.h> +#include <gmock/gmock.h> + +namespace brillo { + +class BRILLO_EXPORT MockUdevDevice : public UdevDevice { + public: + MockUdevDevice() = default; + ~MockUdevDevice() override = default; + + MOCK_METHOD(std::unique_ptr<UdevDevice>, GetParent, (), (const, override)); + MOCK_METHOD(std::unique_ptr<UdevDevice>, + GetParentWithSubsystemDeviceType, + (const char*, const char*), + (const, override)); + MOCK_METHOD(bool, IsInitialized, (), (const, override)); + MOCK_METHOD(uint64_t, GetMicrosecondsSinceInitialized, (), (const, override)); + MOCK_METHOD(uint64_t, GetSequenceNumber, (), (const, override)); + MOCK_METHOD(const char*, GetDevicePath, (), (const, override)); + MOCK_METHOD(const char*, GetDeviceNode, (), (const, override)); + MOCK_METHOD(dev_t, GetDeviceNumber, (), (const, override)); + MOCK_METHOD(const char*, GetDeviceType, (), (const, override)); + MOCK_METHOD(const char*, GetDriver, (), (const, override)); + MOCK_METHOD(const char*, GetSubsystem, (), (const, override)); + MOCK_METHOD(const char*, GetSysPath, (), (const, override)); + MOCK_METHOD(const char*, GetSysName, (), (const, override)); + MOCK_METHOD(const char*, GetSysNumber, (), (const, override)); + MOCK_METHOD(const char*, GetAction, (), (const, override)); + MOCK_METHOD(std::unique_ptr<UdevListEntry>, + GetDeviceLinksListEntry, + (), + (const, override)); + MOCK_METHOD(std::unique_ptr<UdevListEntry>, + GetPropertiesListEntry, + (), + (const, override)); + MOCK_METHOD(const char*, GetPropertyValue, (const char*), (const, override)); + MOCK_METHOD(std::unique_ptr<UdevListEntry>, + GetTagsListEntry, + (), + (const, override)); + MOCK_METHOD(std::unique_ptr<UdevListEntry>, + GetSysAttributeListEntry, + (), + (const, override)); + MOCK_METHOD(const char*, + GetSysAttributeValue, + (const char*), + (const, override)); + MOCK_METHOD(std::unique_ptr<UdevDevice>, Clone, (), (override)); + + private: + DISALLOW_COPY_AND_ASSIGN(MockUdevDevice); +}; + +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_UDEV_MOCK_UDEV_DEVICE_H_ diff --git a/brillo/udev/mock_udev_enumerate.h b/brillo/udev/mock_udev_enumerate.h new file mode 100644 index 0000000..faf94fc --- /dev/null +++ b/brillo/udev/mock_udev_enumerate.h @@ -0,0 +1,49 @@ +// Copyright (c) 2013 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef LIBBRILLO_BRILLO_UDEV_MOCK_UDEV_ENUMERATE_H_ +#define LIBBRILLO_BRILLO_UDEV_MOCK_UDEV_ENUMERATE_H_ + +#include <memory> + +#include <brillo/brillo_export.h> +#include <brillo/udev/udev_enumerate.h> +#include <gmock/gmock.h> + +namespace brillo { + +class BRILLO_EXPORT MockUdevEnumerate : public UdevEnumerate { + public: + MockUdevEnumerate() = default; + ~MockUdevEnumerate() override = default; + + MOCK_METHOD(bool, AddMatchSubsystem, (const char*), (override)); + MOCK_METHOD(bool, AddNoMatchSubsystem, (const char*), (override)); + MOCK_METHOD(bool, + AddMatchSysAttribute, + (const char*, const char*), + (override)); + MOCK_METHOD(bool, + AddNoMatchSysAttribute, + (const char*, const char*), + (override)); + MOCK_METHOD(bool, AddMatchProperty, (const char*, const char*), (override)); + MOCK_METHOD(bool, AddMatchSysName, (const char*), (override)); + MOCK_METHOD(bool, AddMatchTag, (const char*), (override)); + MOCK_METHOD(bool, AddMatchIsInitialized, (), (override)); + MOCK_METHOD(bool, AddSysPath, (const char*), (override)); + MOCK_METHOD(bool, ScanDevices, (), (override)); + MOCK_METHOD(bool, ScanSubsystems, (), (override)); + MOCK_METHOD(std::unique_ptr<UdevListEntry>, + GetListEntry, + (), + (const, override)); + + private: + DISALLOW_COPY_AND_ASSIGN(MockUdevEnumerate); +}; + +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_UDEV_MOCK_UDEV_ENUMERATE_H_ diff --git a/brillo/udev/mock_udev_list_entry.h b/brillo/udev/mock_udev_list_entry.h new file mode 100644 index 0000000..255b6e2 --- /dev/null +++ b/brillo/udev/mock_udev_list_entry.h @@ -0,0 +1,35 @@ +// Copyright (c) 2013 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef LIBBRILLO_BRILLO_UDEV_MOCK_UDEV_LIST_ENTRY_H_ +#define LIBBRILLO_BRILLO_UDEV_MOCK_UDEV_LIST_ENTRY_H_ + +#include <memory> + +#include <brillo/brillo_export.h> +#include <brillo/udev/udev_list_entry.h> +#include <gmock/gmock.h> + +namespace brillo { + +class BRILLO_EXPORT MockUdevListEntry : public UdevListEntry { + public: + MockUdevListEntry() = default; + ~MockUdevListEntry() override = default; + + MOCK_METHOD(std::unique_ptr<UdevListEntry>, GetNext, (), (const, override)); + MOCK_METHOD(std::unique_ptr<UdevListEntry>, + GetByName, + (const char*), + (const, override)); + MOCK_METHOD(const char*, GetName, (), (const, override)); + MOCK_METHOD(const char*, GetValue, (), (const, override)); + + private: + DISALLOW_COPY_AND_ASSIGN(MockUdevListEntry); +}; + +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_UDEV_MOCK_UDEV_LIST_ENTRY_H_ diff --git a/brillo/udev/mock_udev_monitor.h b/brillo/udev/mock_udev_monitor.h new file mode 100644 index 0000000..5854327 --- /dev/null +++ b/brillo/udev/mock_udev_monitor.h @@ -0,0 +1,38 @@ +// Copyright (c) 2013 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef LIBBRILLO_BRILLO_UDEV_MOCK_UDEV_MONITOR_H_ +#define LIBBRILLO_BRILLO_UDEV_MOCK_UDEV_MONITOR_H_ + +#include <memory> + +#include <brillo/brillo_export.h> +#include <brillo/udev/udev_monitor.h> +#include <gmock/gmock.h> + +namespace brillo { + +class BRILLO_EXPORT MockUdevMonitor : public UdevMonitor { + public: + MockUdevMonitor() = default; + ~MockUdevMonitor() override = default; + + MOCK_METHOD(bool, EnableReceiving, (), (override)); + MOCK_METHOD(int, GetFileDescriptor, (), (const, override)); + MOCK_METHOD(std::unique_ptr<UdevDevice>, ReceiveDevice, (), (override)); + MOCK_METHOD(bool, + FilterAddMatchSubsystemDeviceType, + (const char*, const char*), + (override)); + MOCK_METHOD(bool, FilterAddMatchTag, (const char*), (override)); + MOCK_METHOD(bool, FilterUpdate, (), (override)); + MOCK_METHOD(bool, FilterRemove, (), (override)); + + private: + DISALLOW_COPY_AND_ASSIGN(MockUdevMonitor); +}; + +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_UDEV_MOCK_UDEV_MONITOR_H_ diff --git a/brillo/udev/udev.cc b/brillo/udev/udev.cc new file mode 100644 index 0000000..78f9d72 --- /dev/null +++ b/brillo/udev/udev.cc @@ -0,0 +1,127 @@ +// Copyright (c) 2013 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <brillo/udev/udev.h> + +#include <libudev.h> + +#include <base/format_macros.h> +#include <base/logging.h> +#include <base/strings/stringprintf.h> +#include <brillo/udev/udev_device.h> +#include <brillo/udev/udev_enumerate.h> +#include <brillo/udev/udev_monitor.h> + +using base::StringPrintf; + +namespace brillo { + +Udev::Udev(struct udev* udev) : udev_(udev) {} + +Udev::~Udev() { + if (udev_) { + udev_unref(udev_); + udev_ = nullptr; + } +} + +// static +std::unique_ptr<Udev> Udev::Create() { + struct udev* udev = udev_new(); + if (!udev) + return nullptr; + + return std::unique_ptr<Udev>(new Udev(udev)); +} + +// static +std::unique_ptr<UdevDevice> Udev::CreateDevice(udev_device* device) { + auto device_to_return = std::make_unique<UdevDevice>(device); + + // UdevDevice increases the reference count of the udev_device struct by one. + // Thus, decrease the reference count of the udev_device struct by one before + // returning UdevDevice. + udev_device_unref(device); + + return device_to_return; +} + +std::unique_ptr<UdevDevice> Udev::CreateDeviceFromSysPath( + const char* sys_path) { + udev_device* device = udev_device_new_from_syspath(udev_, sys_path); + if (device) + return CreateDevice(device); + + VLOG(2) << StringPrintf( + "udev_device_new_from_syspath" + "(%p, \"%s\") returned nullptr.", + udev_, sys_path); + return nullptr; +} + +std::unique_ptr<UdevDevice> Udev::CreateDeviceFromDeviceNumber( + char type, dev_t device_number) { + udev_device* device = udev_device_new_from_devnum(udev_, type, device_number); + if (device) + return CreateDevice(device); + + VLOG(2) << StringPrintf( + "udev_device_new_from_devnum" + "(%p, %d, %" PRIu64 ") returned nullptr.", + udev_, type, device_number); + return nullptr; +} + +std::unique_ptr<UdevDevice> Udev::CreateDeviceFromSubsystemSysName( + const char* subsystem, const char* sys_name) { + udev_device* device = + udev_device_new_from_subsystem_sysname(udev_, subsystem, sys_name); + if (device) + return CreateDevice(device); + + VLOG(2) << StringPrintf( + "udev_device_new_from_subsystem_sysname" + "(%p, \"%s\", \"%s\") returned nullptr.", + udev_, subsystem, sys_name); + return nullptr; +} + +std::unique_ptr<UdevEnumerate> Udev::CreateEnumerate() { + udev_enumerate* enumerate = udev_enumerate_new(udev_); + if (enumerate) { + auto enumerate_to_return = std::make_unique<UdevEnumerate>(enumerate); + + // UdevEnumerate increases the reference count of the udev_enumerate struct + // by one. Thus, decrease the reference count of the udev_enumerate struct + // by one before returning UdevEnumerate. + udev_enumerate_unref(enumerate); + + return enumerate_to_return; + } + + VLOG(2) << StringPrintf("udev_enumerate_new(%p) returned nullptr.", udev_); + return nullptr; +} + +std::unique_ptr<UdevMonitor> Udev::CreateMonitorFromNetlink(const char* name) { + udev_monitor* monitor = udev_monitor_new_from_netlink(udev_, name); + if (monitor) { + auto monitor_to_return = std::make_unique<UdevMonitor>(monitor); + + // UdevMonitor increases the reference count of the udev_monitor struct by + // one. Thus, decrease the reference count of the udev_monitor struct by one + // before returning UdevMonitor. + udev_monitor_unref(monitor); + + return monitor_to_return; + } + + VLOG(2) << StringPrintf( + "udev_monitor_new_from_netlink" + "(%p, \"%s\") returned nullptr.", + udev_, name); + return nullptr; +} + +} // namespace brillo diff --git a/brillo/udev/udev.h b/brillo/udev/udev.h new file mode 100644 index 0000000..b2c6c60 --- /dev/null +++ b/brillo/udev/udev.h @@ -0,0 +1,69 @@ +// Copyright (c) 2013 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef LIBBRILLO_BRILLO_UDEV_UDEV_H_ +#define LIBBRILLO_BRILLO_UDEV_UDEV_H_ + +#include <sys/types.h> + +#include <memory> + +#include <base/macros.h> +#include <brillo/brillo_export.h> + +struct udev; +struct udev_device; + +namespace brillo { + +class UdevDevice; +class UdevEnumerate; +class UdevMonitor; + +// A udev library context, which wraps a udev C struct from libudev and related +// library functions into a C++ object. +class BRILLO_EXPORT Udev { + public: + // Creates and initializes a Udev object. Returns nullptr on failure. + static std::unique_ptr<Udev> Create(); + virtual ~Udev(); + + // Wraps udev_device_new_from_syspath(). + virtual std::unique_ptr<UdevDevice> CreateDeviceFromSysPath( + const char* sys_path); + + // Wraps udev_device_new_from_devnum(). + virtual std::unique_ptr<UdevDevice> CreateDeviceFromDeviceNumber( + char type, dev_t device_number); + + // Wraps udev_device_new_from_subsystem_sysname(). + virtual std::unique_ptr<UdevDevice> CreateDeviceFromSubsystemSysName( + const char* subsystem, const char* sys_name); + + // Wraps udev_enumerate_new(). + virtual std::unique_ptr<UdevEnumerate> CreateEnumerate(); + + // Wraps udev_monitor_new_from_netlink(). + virtual std::unique_ptr<UdevMonitor> CreateMonitorFromNetlink( + const char* name); + + private: + friend class MockUdev; + + // Creates a Udev by taking ownership of the |udev|. + explicit Udev(struct udev* udev); + + // Creates a UdevDevice object that wraps a given udev_device struct pointed + // by |device|. The ownership of |device| is transferred to returned + // UdevDevice object. + static std::unique_ptr<UdevDevice> CreateDevice(udev_device* device); + + struct udev* udev_; + + DISALLOW_COPY_AND_ASSIGN(Udev); +}; + +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_UDEV_UDEV_H_ diff --git a/brillo/udev/udev_device.cc b/brillo/udev/udev_device.cc new file mode 100644 index 0000000..2251699 --- /dev/null +++ b/brillo/udev/udev_device.cc @@ -0,0 +1,128 @@ +// Copyright (c) 2013 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <brillo/udev/udev_device.h> + +#include <libudev.h> + +#include <base/logging.h> + +namespace brillo { + +UdevDevice::UdevDevice() : device_(nullptr) {} + +UdevDevice::UdevDevice(udev_device* device) : device_(device) { + CHECK(device_); + + udev_device_ref(device_); +} + +UdevDevice::~UdevDevice() { + if (device_) { + udev_device_unref(device_); + device_ = nullptr; + } +} + +std::unique_ptr<UdevDevice> UdevDevice::GetParent() const { + // udev_device_get_parent does not increase the reference count of the + // returned udev_device struct. + udev_device* parent_device = udev_device_get_parent(device_); + return parent_device ? std::make_unique<UdevDevice>(parent_device) : nullptr; +} + +std::unique_ptr<UdevDevice> UdevDevice::GetParentWithSubsystemDeviceType( + const char* subsystem, const char* device_type) const { + // udev_device_get_parent_with_subsystem_devtype does not increase the + // reference count of the returned udev_device struct. + udev_device* parent_device = udev_device_get_parent_with_subsystem_devtype( + device_, subsystem, device_type); + return parent_device ? std::make_unique<UdevDevice>(parent_device) : nullptr; +} + +bool UdevDevice::IsInitialized() const { + return udev_device_get_is_initialized(device_); +} + +uint64_t UdevDevice::GetMicrosecondsSinceInitialized() const { + return udev_device_get_usec_since_initialized(device_); +} + +uint64_t UdevDevice::GetSequenceNumber() const { + return udev_device_get_seqnum(device_); +} + +const char* UdevDevice::GetDevicePath() const { + return udev_device_get_devpath(device_); +} + +const char* UdevDevice::GetDeviceNode() const { + return udev_device_get_devnode(device_); +} + +dev_t UdevDevice::GetDeviceNumber() const { + return udev_device_get_devnum(device_); +} + +const char* UdevDevice::GetDeviceType() const { + return udev_device_get_devtype(device_); +} + +const char* UdevDevice::GetDriver() const { + return udev_device_get_driver(device_); +} + +const char* UdevDevice::GetSubsystem() const { + return udev_device_get_subsystem(device_); +} + +const char* UdevDevice::GetSysPath() const { + return udev_device_get_syspath(device_); +} + +const char* UdevDevice::GetSysName() const { + return udev_device_get_sysname(device_); +} + +const char* UdevDevice::GetSysNumber() const { + return udev_device_get_sysnum(device_); +} + +const char* UdevDevice::GetAction() const { + return udev_device_get_action(device_); +} + +std::unique_ptr<UdevListEntry> UdevDevice::GetDeviceLinksListEntry() const { + udev_list_entry* list_entry = udev_device_get_devlinks_list_entry(device_); + return list_entry ? std::make_unique<UdevListEntry>(list_entry) : nullptr; +} + +std::unique_ptr<UdevListEntry> UdevDevice::GetPropertiesListEntry() const { + udev_list_entry* list_entry = udev_device_get_properties_list_entry(device_); + return list_entry ? std::make_unique<UdevListEntry>(list_entry) : nullptr; +} + +const char* UdevDevice::GetPropertyValue(const char* key) const { + return udev_device_get_property_value(device_, key); +} + +std::unique_ptr<UdevListEntry> UdevDevice::GetTagsListEntry() const { + udev_list_entry* list_entry = udev_device_get_tags_list_entry(device_); + return list_entry ? std::make_unique<UdevListEntry>(list_entry) : nullptr; +} + +std::unique_ptr<UdevListEntry> UdevDevice::GetSysAttributeListEntry() const { + udev_list_entry* list_entry = udev_device_get_sysattr_list_entry(device_); + return list_entry ? std::make_unique<UdevListEntry>(list_entry) : nullptr; +} + +const char* UdevDevice::GetSysAttributeValue(const char* attribute) const { + return udev_device_get_sysattr_value(device_, attribute); +} + +std::unique_ptr<UdevDevice> UdevDevice::Clone() { + return std::make_unique<UdevDevice>(device_); +} + +} // namespace brillo diff --git a/brillo/udev/udev_device.h b/brillo/udev/udev_device.h new file mode 100644 index 0000000..2704a22 --- /dev/null +++ b/brillo/udev/udev_device.h @@ -0,0 +1,117 @@ +// Copyright (c) 2013 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef LIBBRILLO_BRILLO_UDEV_UDEV_DEVICE_H_ +#define LIBBRILLO_BRILLO_UDEV_UDEV_DEVICE_H_ + +#include <stdint.h> +#include <sys/types.h> + +#include <memory> + +#include <base/macros.h> +#include <brillo/brillo_export.h> +#include <brillo/udev/udev_list_entry.h> + +struct udev_device; + +namespace brillo { + +// A udev device, which wraps a udev_device C struct from libudev and related +// library functions into a C++ object. +class BRILLO_EXPORT UdevDevice { + public: + // Constructs a UdevDevice object by taking a raw pointer to a udev_device + // struct as |device|. The ownership of |device| is not transferred, but its + // reference count is increased by one during the lifetime of this object. + explicit UdevDevice(udev_device* device); + + // Destructs this UdevDevice object and decreases the reference count of the + // underlying udev_device struct by one. + virtual ~UdevDevice(); + + // Wraps udev_device_get_parent(). + virtual std::unique_ptr<UdevDevice> GetParent() const; + + // Wraps udev_device_get_parent_with_subsystem_devtype(). + virtual std::unique_ptr<UdevDevice> GetParentWithSubsystemDeviceType( + const char* subsystem, const char* device_type) const; + + // Wraps udev_device_get_is_initialized(). + virtual bool IsInitialized() const; + + // Wraps udev_device_get_usec_since_initialized(). + virtual uint64_t GetMicrosecondsSinceInitialized() const; + + // Wraps udev_device_get_seqnum(). + virtual uint64_t GetSequenceNumber() const; + + // Wraps udev_device_get_devpath(). + virtual const char* GetDevicePath() const; + + // Wraps udev_device_get_devnode(). + virtual const char* GetDeviceNode() const; + + // Wraps udev_device_get_devnum(). + virtual dev_t GetDeviceNumber() const; + + // Wraps udev_device_get_devtype(). + virtual const char* GetDeviceType() const; + + // Wraps udev_device_get_driver(). + virtual const char* GetDriver() const; + + // Wraps udev_device_get_subsystem(). + virtual const char* GetSubsystem() const; + + // Wraps udev_device_get_syspath(). + virtual const char* GetSysPath() const; + + // Wraps udev_device_get_sysname(). + virtual const char* GetSysName() const; + + // Wraps udev_device_get_sysnum(). + virtual const char* GetSysNumber() const; + + // Wraps udev_device_get_action(). + virtual const char* GetAction() const; + + // Wraps udev_device_get_devlinks_list_entry(). + virtual std::unique_ptr<UdevListEntry> GetDeviceLinksListEntry() const; + + // Wraps udev_device_get_properties_list_entry(). + virtual std::unique_ptr<UdevListEntry> GetPropertiesListEntry() const; + + // Wraps udev_device_get_property_value(). + virtual const char* GetPropertyValue(const char* key) const; + + // Wraps udev_device_get_tags_list_entry(). + virtual std::unique_ptr<UdevListEntry> GetTagsListEntry() const; + + // Wraps udev_device_get_sysattr_list_entry(). + virtual std::unique_ptr<UdevListEntry> GetSysAttributeListEntry() const; + + // Wraps udev_device_get_sysattr_value(). + virtual const char* GetSysAttributeValue(const char* attribute) const; + + // Creates a copy of this UdevDevice pointing to the same underlying + // struct udev_device* (increasing its libudev reference count by 1). + virtual std::unique_ptr<UdevDevice> Clone(); + + private: + // Allows MockUdevDevice to invoke the private default constructor below. + friend class MockUdevDevice; + + // Constructs a UdevDevice object without referencing a udev_device struct, + // which is only allowed to be called by MockUdevDevice. + UdevDevice(); + + udev_device* device_; + + DISALLOW_COPY_AND_ASSIGN(UdevDevice); +}; + +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_UDEV_UDEV_DEVICE_H_ diff --git a/brillo/udev/udev_enumerate.cc b/brillo/udev/udev_enumerate.cc new file mode 100644 index 0000000..0ac59b9 --- /dev/null +++ b/brillo/udev/udev_enumerate.cc @@ -0,0 +1,158 @@ +// Copyright (c) 2013 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <brillo/udev/udev_enumerate.h> + +#include <libudev.h> + +#include <base/logging.h> +#include <base/strings/stringprintf.h> +#include <brillo/udev/udev_device.h> + +using base::StringPrintf; + +namespace brillo { + +UdevEnumerate::UdevEnumerate() : enumerate_(nullptr) {} + +UdevEnumerate::UdevEnumerate(udev_enumerate* enumerate) + : enumerate_(enumerate) { + CHECK(enumerate_); + + udev_enumerate_ref(enumerate_); +} + +UdevEnumerate::~UdevEnumerate() { + if (enumerate_) { + udev_enumerate_unref(enumerate_); + enumerate_ = nullptr; + } +} + +bool UdevEnumerate::AddMatchSubsystem(const char* subsystem) { + int result = udev_enumerate_add_match_subsystem(enumerate_, subsystem); + if (result == 0) + return true; + + VLOG(2) << StringPrintf( + "udev_enumerate_add_match_subsystem (%p, \"%s\") returned %d.", + enumerate_, subsystem, result); + return false; +} + +bool UdevEnumerate::AddNoMatchSubsystem(const char* subsystem) { + int result = udev_enumerate_add_nomatch_subsystem(enumerate_, subsystem); + if (result == 0) + return true; + + VLOG(2) << StringPrintf( + "udev_enumerate_add_nomatch_subsystem (%p, \"%s\") returned %d.", + enumerate_, subsystem, result); + return false; +} + +bool UdevEnumerate::AddMatchSysAttribute(const char* attribute, + const char* value) { + int result = udev_enumerate_add_match_sysattr(enumerate_, attribute, value); + if (result == 0) + return true; + + VLOG(2) << StringPrintf( + "udev_enumerate_add_match_sysattr (%p, \"%s\", \"%s\") returned %d.", + enumerate_, attribute, value, result); + return false; +} + +bool UdevEnumerate::AddNoMatchSysAttribute(const char* attribute, + const char* value) { + int result = udev_enumerate_add_nomatch_sysattr(enumerate_, attribute, value); + if (result == 0) + return true; + + VLOG(2) << StringPrintf( + "udev_enumerate_add_nomatch_sysattr (%p, \"%s\", \"%s\") returned %d.", + enumerate_, attribute, value, result); + return false; +} + +bool UdevEnumerate::AddMatchProperty(const char* property, const char* value) { + int result = udev_enumerate_add_match_property(enumerate_, property, value); + if (result == 0) + return true; + + VLOG(2) << StringPrintf( + "udev_enumerate_add_match_property (%p, \"%s\", \"%s\") returned %d.", + enumerate_, property, value, result); + return false; +} + +bool UdevEnumerate::AddMatchSysName(const char* sys_name) { + int result = udev_enumerate_add_match_sysname(enumerate_, sys_name); + if (result == 0) + return true; + + VLOG(2) << StringPrintf( + "udev_enumerate_add_match_sysname (%p, \"%s\") returned %d.", enumerate_, + sys_name, result); + return false; +} + +bool UdevEnumerate::AddMatchTag(const char* tag) { + int result = udev_enumerate_add_match_tag(enumerate_, tag); + if (result == 0) + return true; + + VLOG(2) << StringPrintf( + "udev_enumerate_add_match_tag (%p, \"%s\") returned %d.", enumerate_, tag, + result); + return false; +} + +bool UdevEnumerate::AddMatchIsInitialized() { + int result = udev_enumerate_add_match_is_initialized(enumerate_); + if (result == 0) + return true; + + VLOG(2) << StringPrintf( + "udev_enumerate_add_match_is_initialized (%p) returned %d.", enumerate_, + result); + return false; +} + +bool UdevEnumerate::AddSysPath(const char* sys_path) { + int result = udev_enumerate_add_syspath(enumerate_, sys_path); + if (result == 0) + return true; + + VLOG(2) << StringPrintf("udev_enumerate_add_syspath(%p, \"%s\") returned %d.", + enumerate_, sys_path, result); + return false; +} + +bool UdevEnumerate::ScanDevices() { + int result = udev_enumerate_scan_devices(enumerate_); + if (result == 0) + return true; + + VLOG(2) << StringPrintf("udev_enumerate_scan_devices(%p) returned %d.", + enumerate_, result); + return false; +} + +bool UdevEnumerate::ScanSubsystems() { + int result = udev_enumerate_scan_subsystems(enumerate_); + if (result == 0) + return true; + + VLOG(2) << StringPrintf("udev_enumerate_scan_subsystems(%p) returned %d.", + enumerate_, result); + return false; +} + +std::unique_ptr<UdevListEntry> UdevEnumerate::GetListEntry() const { + udev_list_entry* list_entry = udev_enumerate_get_list_entry(enumerate_); + return list_entry ? std::make_unique<UdevListEntry>(list_entry) : nullptr; +} + +} // namespace brillo diff --git a/brillo/udev/udev_enumerate.h b/brillo/udev/udev_enumerate.h new file mode 100644 index 0000000..50a6183 --- /dev/null +++ b/brillo/udev/udev_enumerate.h @@ -0,0 +1,83 @@ +// Copyright (c) 2013 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef LIBBRILLO_BRILLO_UDEV_UDEV_ENUMERATE_H_ +#define LIBBRILLO_BRILLO_UDEV_UDEV_ENUMERATE_H_ + +#include <memory> + +#include <base/macros.h> +#include <brillo/brillo_export.h> +#include <brillo/udev/udev_list_entry.h> + +struct udev_enumerate; + +namespace brillo { + +// A udev enumerate class, which wraps a udev_enumerate C struct from libudev +// and related library functions into a C++ object. +class BRILLO_EXPORT UdevEnumerate { + public: + // Constructs a UdevEnumerate object by taking a raw pointer to a + // udev_enumerate struct as |enumerate|. The ownership of |enumerate| is not + // transferred, but its reference count is increased by one during the + // lifetime of this object. + explicit UdevEnumerate(udev_enumerate* enumerate); + + // Destructs this UdevEnumerate object and decreases the reference count of + // the underlying udev_enumerate struct by one. + virtual ~UdevEnumerate(); + + // Wraps udev_enumerate_add_match_subsystem(). Returns true on success. + virtual bool AddMatchSubsystem(const char* subsystem); + + // Wraps udev_enumerate_add_nomatch_subsystem(). Returns true on success. + virtual bool AddNoMatchSubsystem(const char* subsystem); + + // Wraps udev_enumerate_add_match_sysattr(). Returns true on success. + virtual bool AddMatchSysAttribute(const char* attribute, const char* value); + + // Wraps udev_enumerate_add_nomatch_sysattr(). Returns true on success. + virtual bool AddNoMatchSysAttribute(const char* attribute, const char* value); + + // Wraps udev_enumerate_add_match_property(). Returns true on success. + virtual bool AddMatchProperty(const char* property, const char* value); + + // Wraps udev_enumerate_add_match_sysname(). Returns true on success. + virtual bool AddMatchSysName(const char* sys_name); + + // Wraps udev_enumerate_add_match_tag(). Returns true on success. + virtual bool AddMatchTag(const char* tag); + + // Wraps udev_enumerate_add_match_is_initialized(). Returns true on success. + virtual bool AddMatchIsInitialized(); + + // Wraps udev_enumerate_add_syspath(). Returns true on success. + virtual bool AddSysPath(const char* sys_path); + + // Wraps udev_enumerate_scan_devices(). Returns true on success. + virtual bool ScanDevices(); + + // Wraps udev_enumerate_scan_subsystems(). Returns true on success. + virtual bool ScanSubsystems(); + + // Wraps udev_enumerate_get_list_entry(). + virtual std::unique_ptr<UdevListEntry> GetListEntry() const; + + private: + // Allows MockUdevEnumerate to invoke the private default constructor below. + friend class MockUdevEnumerate; + + // Constructs a UdevEnumerate object without referencing a udev_enumerate + // struct, which is only allowed to be called by MockUdevEnumerate. + UdevEnumerate(); + + udev_enumerate* enumerate_; + + DISALLOW_COPY_AND_ASSIGN(UdevEnumerate); +}; + +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_UDEV_UDEV_ENUMERATE_H_ diff --git a/brillo/udev/udev_list_entry.cc b/brillo/udev/udev_list_entry.cc new file mode 100644 index 0000000..739c435 --- /dev/null +++ b/brillo/udev/udev_list_entry.cc @@ -0,0 +1,39 @@ +// Copyright (c) 2013 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <brillo/udev/udev_list_entry.h> + +#include <libudev.h> + +#include <base/logging.h> + +namespace brillo { + +UdevListEntry::UdevListEntry() : list_entry_(nullptr) {} + +UdevListEntry::UdevListEntry(udev_list_entry* list_entry) + : list_entry_(list_entry) { + CHECK(list_entry_); +} + +std::unique_ptr<UdevListEntry> UdevListEntry::GetNext() const { + udev_list_entry* list_entry = udev_list_entry_get_next(list_entry_); + return list_entry ? std::make_unique<UdevListEntry>(list_entry) : nullptr; +} + +std::unique_ptr<UdevListEntry> UdevListEntry::GetByName( + const char* name) const { + udev_list_entry* list_entry = udev_list_entry_get_by_name(list_entry_, name); + return list_entry ? std::make_unique<UdevListEntry>(list_entry) : nullptr; +} + +const char* UdevListEntry::GetName() const { + return udev_list_entry_get_name(list_entry_); +} + +const char* UdevListEntry::GetValue() const { + return udev_list_entry_get_value(list_entry_); +} + +} // namespace brillo diff --git a/brillo/udev/udev_list_entry.h b/brillo/udev/udev_list_entry.h new file mode 100644 index 0000000..ee61d18 --- /dev/null +++ b/brillo/udev/udev_list_entry.h @@ -0,0 +1,55 @@ +// Copyright (c) 2013 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef LIBBRILLO_BRILLO_UDEV_UDEV_LIST_ENTRY_H_ +#define LIBBRILLO_BRILLO_UDEV_UDEV_LIST_ENTRY_H_ + +#include <memory> + +#include <base/macros.h> +#include <brillo/brillo_export.h> + +struct udev_list_entry; + +namespace brillo { + +// A udev list entry, which wraps a udev_list_entry C struct from libudev and +// related library functions into a C++ object. +class BRILLO_EXPORT UdevListEntry { + public: + // Constructs a UdevListEntry object by taking a raw pointer to a + // udev_list_entry struct as |list_entry|. The ownership of |list_entry| is + // not transferred, and thus it should outlive this object. + explicit UdevListEntry(udev_list_entry* list_entry); + + virtual ~UdevListEntry() = default; + + // Wraps udev_list_entry_get_next(). + virtual std::unique_ptr<UdevListEntry> GetNext() const; + + // Wraps udev_list_entry_get_by_name(). + virtual std::unique_ptr<UdevListEntry> GetByName(const char* name) const; + + // Wraps udev_list_entry_get_name(). + virtual const char* GetName() const; + + // Wraps udev_list_entry_get_value(). + virtual const char* GetValue() const; + + private: + // Allows MockUdevListEntry to invoke the private default constructor below. + friend class MockUdevListEntry; + + // Constructs a UdevListEntry object without referencing a udev_list_entry + // struct, which is only allowed to be called by MockUdevListEntry. + UdevListEntry(); + + udev_list_entry* const list_entry_; + + DISALLOW_COPY_AND_ASSIGN(UdevListEntry); +}; + +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_UDEV_UDEV_LIST_ENTRY_H_ diff --git a/brillo/udev/udev_monitor.cc b/brillo/udev/udev_monitor.cc new file mode 100644 index 0000000..c4b63e5 --- /dev/null +++ b/brillo/udev/udev_monitor.cc @@ -0,0 +1,114 @@ +// Copyright (c) 2013 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <brillo/udev/udev_monitor.h> + +#include <libudev.h> + +#include <base/logging.h> +#include <base/strings/stringprintf.h> +#include <brillo/udev/udev_device.h> + +using base::StringPrintf; + +namespace brillo { + +UdevMonitor::UdevMonitor() : monitor_(nullptr) {} + +UdevMonitor::UdevMonitor(udev_monitor* monitor) : monitor_(monitor) { + CHECK(monitor_); + + udev_monitor_ref(monitor_); +} + +UdevMonitor::~UdevMonitor() { + if (monitor_) { + udev_monitor_unref(monitor_); + monitor_ = nullptr; + } +} + +bool UdevMonitor::EnableReceiving() { + int result = udev_monitor_enable_receiving(monitor_); + if (result == 0) + return true; + + VLOG(2) << StringPrintf("udev_monitor_enable_receiving(%p) returned %d.", + monitor_, result); + return false; +} + +int UdevMonitor::GetFileDescriptor() const { + int file_descriptor = udev_monitor_get_fd(monitor_); + if (file_descriptor >= 0) + return file_descriptor; + + VLOG(2) << StringPrintf("udev_monitor_get_fd(%p) returned %d.", monitor_, + file_descriptor); + return kInvalidFileDescriptor; +} + +std::unique_ptr<UdevDevice> UdevMonitor::ReceiveDevice() { + udev_device* received_device = udev_monitor_receive_device(monitor_); + if (received_device) { + auto device = std::make_unique<UdevDevice>(received_device); + // udev_monitor_receive_device increases the reference count of the returned + // udev_device struct, while UdevDevice also holds a reference count of the + // udev_device struct. Thus, decrease the reference count of the udev_device + // struct. + udev_device_unref(received_device); + return device; + } + + VLOG(2) << StringPrintf("udev_monitor_receive_device(%p) returned nullptr.", + monitor_); + return nullptr; +} + +bool UdevMonitor::FilterAddMatchSubsystemDeviceType(const char* subsystem, + const char* device_type) { + int result = udev_monitor_filter_add_match_subsystem_devtype( + monitor_, subsystem, device_type); + if (result == 0) + return true; + + VLOG(2) << StringPrintf( + "udev_monitor_filter_add_match_subsystem_devtype (%p, \"%s\", \"%s\") " + "returned %d.", + monitor_, subsystem, device_type, result); + return false; +} + +bool UdevMonitor::FilterAddMatchTag(const char* tag) { + int result = udev_monitor_filter_add_match_tag(monitor_, tag); + if (result == 0) + return true; + + VLOG(2) << StringPrintf( + "udev_monitor_filter_add_tag (%p, \"%s\") returned %d.", monitor_, tag, + result); + return false; +} + +bool UdevMonitor::FilterUpdate() { + int result = udev_monitor_filter_update(monitor_); + if (result == 0) + return true; + + VLOG(2) << StringPrintf("udev_monitor_filter_update(%p) returned %d.", + monitor_, result); + return false; +} + +bool UdevMonitor::FilterRemove() { + int result = udev_monitor_filter_remove(monitor_); + if (result == 0) + return true; + + VLOG(2) << StringPrintf("udev_monitor_filter_remove(%p) returned %d.", + monitor_, result); + return false; +} + +} // namespace brillo diff --git a/brillo/udev/udev_monitor.h b/brillo/udev/udev_monitor.h new file mode 100644 index 0000000..b9136f0 --- /dev/null +++ b/brillo/udev/udev_monitor.h @@ -0,0 +1,72 @@ +// Copyright (c) 2013 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef LIBBRILLO_BRILLO_UDEV_UDEV_MONITOR_H_ +#define LIBBRILLO_BRILLO_UDEV_UDEV_MONITOR_H_ + +#include <memory> + +#include <base/macros.h> +#include <brillo/brillo_export.h> + +struct udev_monitor; + +namespace brillo { + +class UdevDevice; + +// A udev monitor, which wraps a udev_monitor C struct from libudev and related +// library functions into a C++ object. +class BRILLO_EXPORT UdevMonitor { + public: + static const int kInvalidFileDescriptor = -1; + + // Constructs a UdevMonitor object by taking a raw pointer to a udev_monitor + // struct as |monitor|. The ownership of |monitor| is not transferred, but its + // reference count is increased by one during the lifetime of this object. + explicit UdevMonitor(udev_monitor* monitor); + + // Destructs this UdevMonitor object and decreases the reference count of the + // underlying udev_monitor struct by one. + virtual ~UdevMonitor(); + + // Wraps udev_monitor_enable_receiving(). Returns true on success. + virtual bool EnableReceiving(); + + // Wraps udev_monitor_get_fd(). + virtual int GetFileDescriptor() const; + + // Wraps udev_monitor_receive_device(). + virtual std::unique_ptr<UdevDevice> ReceiveDevice(); + + // Wraps udev_monitor_filter_add_match_subsystem_devtype(). Returns true on + // success. + virtual bool FilterAddMatchSubsystemDeviceType(const char* subsystem, + const char* device_type); + + // Wraps udev_monitor_filter_add_match_tag(). Returns true on success. + virtual bool FilterAddMatchTag(const char* tag); + + // Wraps udev_monitor_filter_update(). Returns true on success. + virtual bool FilterUpdate(); + + // Wraps udev_monitor_filter_remove(). Returns true on success. + virtual bool FilterRemove(); + + private: + // Allows MockUdevMonitor to invoke the private default constructor below. + friend class MockUdevMonitor; + + // Constructs a UdevMonitor object without referencing a udev_monitor struct, + // which is only allowed to be called by MockUdevMonitor. + UdevMonitor(); + + udev_monitor* monitor_; + + DISALLOW_COPY_AND_ASSIGN(UdevMonitor); +}; + +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_UDEV_UDEV_MONITOR_H_ diff --git a/brillo/url_utils_unittest.cc b/brillo/url_utils_test.cc index a2603cb..a2603cb 100644 --- a/brillo/url_utils_unittest.cc +++ b/brillo/url_utils_test.cc diff --git a/brillo/userdb_utils.cc b/brillo/userdb_utils.cc index 55c964c..1308fb7 100644 --- a/brillo/userdb_utils.cc +++ b/brillo/userdb_utils.cc @@ -4,6 +4,7 @@ #include "brillo/userdb_utils.h" +#include <errno.h> #include <grp.h> #include <pwd.h> #include <sys/types.h> @@ -12,6 +13,7 @@ #include <vector> #include <base/logging.h> +#include <base/posix/safe_strerror.h> namespace brillo { namespace userdb { @@ -23,8 +25,16 @@ bool GetUserInfo(const std::string& user, uid_t* uid, gid_t* gid) { passwd pwd_buf; passwd* pwd = nullptr; std::vector<char> buf(buf_len); - if (getpwnam_r(user.c_str(), &pwd_buf, buf.data(), buf_len, &pwd) || !pwd) { - PLOG(ERROR) << "Unable to find user " << user; + + int err_num; + do { + err_num = getpwnam_r(user.c_str(), &pwd_buf, buf.data(), buf_len, &pwd); + } while (err_num == EINTR); + + if (!pwd) { + LOG(ERROR) << "Unable to find user " << user << ": " + << (err_num ? base::safe_strerror(err_num) + : "No matching record"); return false; } @@ -42,8 +52,16 @@ bool GetGroupInfo(const std::string& group, gid_t* gid) { struct group grp_buf; struct group* grp = nullptr; std::vector<char> buf(buf_len); - if (getgrnam_r(group.c_str(), &grp_buf, buf.data(), buf_len, &grp) || !grp) { - PLOG(ERROR) << "Unable to find group " << group; + + int err_num; + do { + err_num = getgrnam_r(group.c_str(), &grp_buf, buf.data(), buf_len, &grp); + } while (err_num == EINTR); + + if (!grp) { + LOG(ERROR) << "Unable to find group " << group << ": " + << (err_num ? base::safe_strerror(err_num) + : "No matching record"); return false; } diff --git a/brillo/value_conversion.h b/brillo/value_conversion.h index b520a77..6cf7323 100644 --- a/brillo/value_conversion.h +++ b/brillo/value_conversion.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef BRILLO_VALUE_CONVERSION_H_ -#define BRILLO_VALUE_CONVERSION_H_ +#ifndef LIBBRILLO_BRILLO_VALUE_CONVERSION_H_ +#define LIBBRILLO_BRILLO_VALUE_CONVERSION_H_ // This file provides a set of helper functions to convert between base::Value // and native types. Apart from handling standard types such as 'int' and @@ -24,6 +24,7 @@ #include <map> #include <memory> #include <string> +#include <utility> #include <vector> #include <base/values.h> @@ -73,7 +74,7 @@ bool FromValue(const base::Value& in_value, std::vector<T, Alloc>* out_value) { return false; out_value->clear(); out_value->reserve(list->GetSize()); - for (const auto& item : *list) { + for (const base::Value& item : base::ValueReferenceAdapter(*list)) { T value{}; if (!FromValue(item, &value)) return false; @@ -134,4 +135,4 @@ std::unique_ptr<base::Value> ToValue( } // namespace brillo -#endif // BRILLO_VALUE_CONVERSION_H_ +#endif // LIBBRILLO_BRILLO_VALUE_CONVERSION_H_ diff --git a/brillo/value_conversion_unittest.cc b/brillo/value_conversion_test.cc index aa1be2a..fec4052 100644 --- a/brillo/value_conversion_unittest.cc +++ b/brillo/value_conversion_test.cc @@ -170,7 +170,7 @@ TEST(ValueConversionTest, FromValueVectorOfString) { TEST(ValueConversionTest, FromValueVectorOfVectors) { std::vector<std::vector<int>> actual; EXPECT_TRUE(FromValue(*ParseValue("[[1,2], [], [3]]"), &actual)); - EXPECT_EQ((std::vector<std::vector<int>>{{1,2}, {}, {3}}), actual); + EXPECT_EQ((std::vector<std::vector<int>>{{1, 2}, {}, {3}}), actual); EXPECT_TRUE(FromValue(*ParseValue("[]"), &actual)); EXPECT_TRUE(actual.empty()); diff --git a/brillo/variant_dictionary_unittest.cc b/brillo/variant_dictionary_test.cc index 73ead2c..73ead2c 100644 --- a/brillo/variant_dictionary_unittest.cc +++ b/brillo/variant_dictionary_test.cc diff --git a/gen_coverage_html.sh b/gen_coverage_html.sh deleted file mode 100755 index 9faf1a9..0000000 --- a/gen_coverage_html.sh +++ /dev/null @@ -1,22 +0,0 @@ -#!/bin/bash - -# Copyright (c) 2009 The Chromium OS Authors. All rights reserved. -# Use of this source code is governed by a BSD-style license that can be -# found in the LICENSE file. - -set -ex - -scons debug=1 -c -scons debug=1 -lcov -d . --zerocounters -./unittests -lcov --base-directory . --directory . --capture --output-file app.info - -# some versions of genhtml support the --no-function-coverage argument, -# which we want. The problem w/ function coverage is that every template -# instantiation of a method counts as a different method, so if we -# instantiate a method twice, once for testing and once for prod, the method -# is tested, but it shows only 50% function coverage b/c it thinks we didn't -# test the prod version. - -genhtml --no-function-coverage -o html ./app.info || genhtml -o html ./app.info diff --git a/install_attributes/libinstallattributes.h b/install_attributes/libinstallattributes.h index b947156..2bcbf0f 100644 --- a/install_attributes/libinstallattributes.h +++ b/install_attributes/libinstallattributes.h @@ -53,7 +53,7 @@ class BRILLO_EXPORT InstallAttributesReader { // successful, too. bool initialized_ = false; -private: + private: // Try to load the verified install attributes from disk. This is expected to // fail when install attributes haven't yet been finalized (OOBE) or verified // (early in the boot sequence). @@ -63,4 +63,4 @@ private: std::string empty_string_; }; -#endif // LIBBRILLO_LIBINSTALLATTRIBUTES_H_ +#endif // LIBBRILLO_INSTALL_ATTRIBUTES_LIBINSTALLATTRIBUTES_H_ diff --git a/install_attributes/mock_install_attributes_reader.h b/install_attributes/mock_install_attributes_reader.h index 5ccee02..0d2adcd 100644 --- a/install_attributes/mock_install_attributes_reader.h +++ b/install_attributes/mock_install_attributes_reader.h @@ -5,6 +5,8 @@ #ifndef LIBBRILLO_INSTALL_ATTRIBUTES_MOCK_INSTALL_ATTRIBUTES_READER_H_ #define LIBBRILLO_INSTALL_ATTRIBUTES_MOCK_INSTALL_ATTRIBUTES_READER_H_ +#include <string> + #include "libinstallattributes.h" #include "bindings/install_attributes.pb.h" diff --git a/install_attributes/tests/libinstallattributes_unittest.cc b/install_attributes/tests/libinstallattributes_test.cc index 45ff827..686e565 100644 --- a/install_attributes/tests/libinstallattributes_unittest.cc +++ b/install_attributes/tests/libinstallattributes_test.cc @@ -76,8 +76,3 @@ TEST(InstallAttributesTest, NoProgressionFromEmptyToManaged) { ASSERT_TRUE(reader.IsLocked()); ASSERT_EQ(std::string(), reader.GetAttribute("enterprise.mode")); } - -int main(int argc, char* argv[]) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/libbrillo-395517.gypi b/libbrillo-395517.gypi deleted file mode 100644 index a846c70..0000000 --- a/libbrillo-395517.gypi +++ /dev/null @@ -1,8 +0,0 @@ -{ - 'variables': { - 'libbase_ver': 395517, - }, - 'includes': [ - '../libbrillo/libbrillo.gypi', - ], -} diff --git a/libbrillo-glib.pc.in b/libbrillo-glib.pc.in deleted file mode 100644 index cfd9fc8..0000000 --- a/libbrillo-glib.pc.in +++ /dev/null @@ -1,8 +0,0 @@ -bslot=@BSLOT@ - -Name: libbrillo-glib -Description: brillo glib wrapper library -Version: ${bslot} -Requires.private: @PRIVATE_PC@ -Libs: -lbrillo-glib-${bslot} - diff --git a/libbrillo-test.pc.in b/libbrillo-test.pc.in deleted file mode 100644 index 4fece7c..0000000 --- a/libbrillo-test.pc.in +++ /dev/null @@ -1,8 +0,0 @@ -bslot=@BSLOT@ - -Name: libbrillo-test -Description: brillo test library -Version: ${bslot} -# Because libbrillo-test is static, we have to depend directly on everything. -Requires: @PRIVATE_PC@ -Libs: -lbrillo-test-${bslot} diff --git a/libbrillo.gyp b/libbrillo.gyp deleted file mode 100644 index 5a2bbe4..0000000 --- a/libbrillo.gyp +++ /dev/null @@ -1,7 +0,0 @@ -{ - 'includes': [ - 'libbrillo-395517.gypi', - 'libinstallattributes.gypi', - 'libpolicy.gypi', - ] -} diff --git a/libbrillo.gypi b/libbrillo.gypi deleted file mode 100644 index 05d95e6..0000000 --- a/libbrillo.gypi +++ /dev/null @@ -1,465 +0,0 @@ -{ - 'target_defaults': { - 'variables': { - 'deps': [ - 'libchrome-<(libbase_ver)' - ], - }, - 'include_dirs': [ - '../libbrillo', - ], - 'defines': [ - 'USE_DBUS=<(USE_dbus)', - 'USE_RTTI_FOR_TYPE_TAGS', - ], - }, - 'targets': [ - { - 'target_name': 'libbrillo-<(libbase_ver)', - 'type': 'none', - 'dependencies': [ - 'libbrillo-core-<(libbase_ver)', - 'libbrillo-cryptohome-<(libbase_ver)', - 'libbrillo-http-<(libbase_ver)', - 'libbrillo-minijail-<(libbase_ver)', - 'libbrillo-streams-<(libbase_ver)', - 'libinstallattributes-<(libbase_ver)', - 'libpolicy-<(libbase_ver)', - ], - 'direct_dependent_settings': { - 'include_dirs': [ - '../libbrillo', - ], - }, - 'includes': ['../common-mk/deps.gypi'], - }, - { - 'target_name': 'libbrillo-core-<(libbase_ver)', - 'type': 'shared_library', - 'variables': { - 'exported_deps': [ - ], - 'conditions': [ - ['USE_dbus == 1', { - 'exported_deps': [ - 'dbus-1', - ], - }], - ], - 'deps': ['<@(exported_deps)'], - }, - 'all_dependent_settings': { - 'variables': { - 'deps': [ - '<@(exported_deps)', - ], - }, - }, - 'libraries': ['-lmodp_b64'], - 'sources': [ - 'brillo/asynchronous_signal_handler.cc', - 'brillo/backoff_entry.cc', - 'brillo/daemons/daemon.cc', - 'brillo/data_encoding.cc', - 'brillo/errors/error.cc', - 'brillo/errors/error_codes.cc', - 'brillo/file_utils.cc', - 'brillo/flag_helper.cc', - 'brillo/imageloader/manifest.cc', - 'brillo/key_value_store.cc', - 'brillo/message_loops/base_message_loop.cc', - 'brillo/message_loops/message_loop.cc', - 'brillo/message_loops/message_loop_utils.cc', - 'brillo/mime_utils.cc', - 'brillo/osrelease_reader.cc', - 'brillo/process.cc', - 'brillo/process_reaper.cc', - 'brillo/process_information.cc', - 'brillo/secure_blob.cc', - 'brillo/strings/string_utils.cc', - 'brillo/syslog_logging.cc', - 'brillo/type_name_undecorate.cc', - 'brillo/url_utils.cc', - 'brillo/userdb_utils.cc', - 'brillo/value_conversion.cc', - ], - 'conditions': [ - ['USE_dbus == 1', { - 'sources': [ - 'brillo/any.cc', - 'brillo/daemons/dbus_daemon.cc', - 'brillo/dbus/async_event_sequencer.cc', - 'brillo/dbus/data_serialization.cc', - 'brillo/dbus/dbus_connection.cc', - 'brillo/dbus/dbus_method_invoker.cc', - 'brillo/dbus/dbus_method_response.cc', - 'brillo/dbus/dbus_object.cc', - 'brillo/dbus/dbus_service_watcher.cc', - 'brillo/dbus/dbus_signal.cc', - 'brillo/dbus/exported_object_manager.cc', - 'brillo/dbus/exported_property_set.cc', - 'brillo/dbus/utils.cc', - ], - }], - ], - }, - { - 'target_name': 'libbrillo-http-<(libbase_ver)', - 'type': 'shared_library', - 'dependencies': [ - 'libbrillo-core-<(libbase_ver)', - 'libbrillo-streams-<(libbase_ver)', - ], - 'variables': { - 'exported_deps': [ - 'libcurl', - ], - 'deps': ['<@(exported_deps)'], - }, - 'all_dependent_settings': { - 'variables': { - 'deps': [ - '<@(exported_deps)', - ], - }, - }, - 'sources': [ - 'brillo/http/curl_api.cc', - 'brillo/http/http_connection_curl.cc', - 'brillo/http/http_form_data.cc', - 'brillo/http/http_request.cc', - 'brillo/http/http_transport.cc', - 'brillo/http/http_transport_curl.cc', - 'brillo/http/http_utils.cc', - ], - 'conditions': [ - ['USE_dbus == 1', { - 'sources': [ - 'brillo/http/http_proxy.cc', - ], - }], - ], - }, - { - 'target_name': 'libbrillo-streams-<(libbase_ver)', - 'type': 'shared_library', - 'dependencies': [ - 'libbrillo-core-<(libbase_ver)', - ], - 'variables': { - 'exported_deps': [ - 'openssl', - ], - 'deps': ['<@(exported_deps)'], - }, - 'all_dependent_settings': { - 'variables': { - 'deps': [ - '<@(exported_deps)', - ], - }, - }, - 'sources': [ - 'brillo/streams/file_stream.cc', - 'brillo/streams/input_stream_set.cc', - 'brillo/streams/memory_containers.cc', - 'brillo/streams/memory_stream.cc', - 'brillo/streams/openssl_stream_bio.cc', - 'brillo/streams/stream.cc', - 'brillo/streams/stream_errors.cc', - 'brillo/streams/stream_utils.cc', - 'brillo/streams/tls_stream.cc', - ], - }, - { - 'target_name': 'libbrillo-test-<(libbase_ver)', - 'type': 'static_library', - 'standalone_static_library': 1, - 'dependencies': [ - 'libbrillo-http-<(libbase_ver)', - ], - 'sources': [ - 'brillo/http/http_connection_fake.cc', - 'brillo/http/http_transport_fake.cc', - 'brillo/message_loops/fake_message_loop.cc', - 'brillo/streams/fake_stream.cc', - 'brillo/unittest_utils.cc', - ], - 'includes': ['../common-mk/deps.gypi'], - }, - { - 'target_name': 'libbrillo-cryptohome-<(libbase_ver)', - 'type': 'shared_library', - 'variables': { - 'exported_deps': [ - 'openssl', - ], - 'deps': ['<@(exported_deps)'], - }, - 'all_dependent_settings': { - 'variables': { - 'deps': [ - '<@(exported_deps)', - ], - }, - }, - 'sources': [ - 'brillo/cryptohome.cc', - ], - }, - { - 'target_name': 'libbrillo-minijail-<(libbase_ver)', - 'type': 'shared_library', - 'variables': { - 'exported_deps': [ - 'libminijail', - ], - 'deps': ['<@(exported_deps)'], - }, - 'all_dependent_settings': { - 'variables': { - 'deps': [ - '<@(exported_deps)', - ], - }, - }, - 'cflags': [ - '-fvisibility=default', - ], - 'sources': [ - 'brillo/minijail/minijail.cc', - ], - }, - { - 'target_name': 'libinstallattributes-<(libbase_ver)', - 'type': 'shared_library', - 'dependencies': [ - 'libinstallattributes-includes', - '../common-mk/external_dependencies.gyp:install_attributes-proto', - ], - 'variables': { - 'exported_deps': [ - 'protobuf-lite', - ], - 'deps': ['<@(exported_deps)'], - }, - 'all_dependent_settings': { - 'variables': { - 'deps': [ - '<@(exported_deps)', - ], - }, - }, - 'sources': [ - 'install_attributes/libinstallattributes.cc', - ], - }, - { - 'target_name': 'libpolicy-<(libbase_ver)', - 'type': 'shared_library', - 'dependencies': [ - 'libinstallattributes-<(libbase_ver)', - 'libpolicy-includes', - '../common-mk/external_dependencies.gyp:policy-protos', - ], - 'variables': { - 'exported_deps': [ - 'openssl', - 'protobuf-lite', - ], - 'deps': ['<@(exported_deps)'], - }, - 'all_dependent_settings': { - 'variables': { - 'deps': [ - '<@(exported_deps)', - ], - }, - }, - 'ldflags': [ - '-Wl,--version-script,<(platform2_root)/libbrillo/libpolicy.ver', - ], - 'sources': [ - 'policy/device_policy.cc', - 'policy/device_policy_impl.cc', - 'policy/policy_util.cc', - 'policy/resilient_policy_util.cc', - 'policy/libpolicy.cc', - ], - }, - { - 'target_name': 'libbrillo-glib-<(libbase_ver)', - 'type': 'shared_library', - 'dependencies': [ - 'libbrillo-<(libbase_ver)', - ], - 'variables': { - 'exported_deps': [ - 'glib-2.0', - 'gobject-2.0', - ], - 'conditions': [ - ['USE_dbus == 1', { - 'exported_deps': [ - 'dbus-1', - 'dbus-glib-1', - ], - }], - ], - 'deps': ['<@(exported_deps)'], - }, - 'cflags': [ - # glib uses the deprecated "register" attribute in some header files. - '-Wno-deprecated-register', - ], - 'all_dependent_settings': { - 'variables': { - 'deps': [ - '<@(exported_deps)', - ], - }, - }, - 'includes': ['../common-mk/deps.gypi'], - 'conditions': [ - ['USE_dbus == 1', { - 'sources': [ - 'brillo/glib/abstract_dbus_service.cc', - 'brillo/glib/dbus.cc', - ], - }], - ], - }, - ], - 'conditions': [ - ['USE_test == 1', { - 'targets': [ - { - 'target_name': 'libbrillo-<(libbase_ver)_unittests', - 'type': 'executable', - 'dependencies': [ - 'libbrillo-<(libbase_ver)', - 'libbrillo-glib-<(libbase_ver)', - 'libbrillo-test-<(libbase_ver)', - ], - 'variables': { - 'deps': [ - 'libchrome-test-<(libbase_ver)', - ], - 'proto_in_dir': 'brillo/dbus', - 'proto_out_dir': 'include/brillo/dbus', - }, - 'includes': [ - '../common-mk/common_test.gypi', - '../common-mk/protoc.gypi', - ], - 'cflags': [ - '-Wno-format-zero-length', - ], - 'conditions': [ - ['debug == 1', { - 'cflags': [ - '-fprofile-arcs', - '-ftest-coverage', - '-fno-inline', - ], - 'libraries': [ - '-lgcov', - ], - }], - ], - 'sources': [ - 'brillo/asynchronous_signal_handler_unittest.cc', - 'brillo/backoff_entry_unittest.cc', - 'brillo/data_encoding_unittest.cc', - 'brillo/enum_flags_unittest.cc', - 'brillo/errors/error_codes_unittest.cc', - 'brillo/errors/error_unittest.cc', - 'brillo/file_utils_unittest.cc', - 'brillo/flag_helper_unittest.cc', - 'brillo/glib/object_unittest.cc', - 'brillo/http/http_connection_curl_unittest.cc', - 'brillo/http/http_form_data_unittest.cc', - 'brillo/http/http_request_unittest.cc', - 'brillo/http/http_transport_curl_unittest.cc', - 'brillo/http/http_utils_unittest.cc', - 'brillo/imageloader/manifest_unittest.cc', - 'brillo/key_value_store_unittest.cc', - 'brillo/map_utils_unittest.cc', - 'brillo/message_loops/base_message_loop_unittest.cc', - 'brillo/message_loops/fake_message_loop_unittest.cc', - 'brillo/message_loops/message_loop_unittest.cc', - 'brillo/mime_utils_unittest.cc', - 'brillo/osrelease_reader_unittest.cc', - 'brillo/process_reaper_unittest.cc', - 'brillo/process_unittest.cc', - 'brillo/secure_blob_unittest.cc', - 'brillo/streams/fake_stream_unittest.cc', - 'brillo/streams/file_stream_unittest.cc', - 'brillo/streams/input_stream_set_unittest.cc', - 'brillo/streams/memory_containers_unittest.cc', - 'brillo/streams/memory_stream_unittest.cc', - 'brillo/streams/openssl_stream_bio_unittests.cc', - 'brillo/streams/stream_unittest.cc', - 'brillo/streams/stream_utils_unittest.cc', - 'brillo/strings/string_utils_unittest.cc', - 'brillo/unittest_utils.cc', - 'brillo/url_utils_unittest.cc', - 'brillo/value_conversion_unittest.cc', - 'testrunner.cc', - ], - 'conditions': [ - ['USE_dbus == 1', { - 'sources': [ - 'brillo/any_unittest.cc', - 'brillo/any_internal_impl_unittest.cc', - 'brillo/dbus/async_event_sequencer_unittest.cc', - 'brillo/dbus/data_serialization_unittest.cc', - 'brillo/dbus/dbus_method_invoker_unittest.cc', - 'brillo/dbus/dbus_object_unittest.cc', - 'brillo/dbus/dbus_param_reader_unittest.cc', - 'brillo/dbus/dbus_param_writer_unittest.cc', - 'brillo/dbus/dbus_signal_handler_unittest.cc', - 'brillo/dbus/exported_object_manager_unittest.cc', - 'brillo/dbus/exported_property_set_unittest.cc', - 'brillo/http/http_proxy_unittest.cc', - 'brillo/type_name_undecorate_unittest.cc', - 'brillo/variant_dictionary_unittest.cc', - '<(proto_in_dir)/test.proto', - ], - }], - ], - }, - { - 'target_name': 'libinstallattributes-<(libbase_ver)_unittests', - 'type': 'executable', - 'dependencies': [ - '../common-mk/external_dependencies.gyp:install_attributes-proto', - 'libinstallattributes-<(libbase_ver)', - ], - 'includes': ['../common-mk/common_test.gypi'], - 'sources': [ - 'install_attributes/tests/libinstallattributes_unittest.cc', - ] - }, - { - 'target_name': 'libpolicy-<(libbase_ver)_unittests', - 'type': 'executable', - 'dependencies': [ - '../common-mk/external_dependencies.gyp:install_attributes-proto', - '../common-mk/external_dependencies.gyp:policy-protos', - 'libinstallattributes-<(libbase_ver)', - 'libpolicy-<(libbase_ver)', - ], - 'includes': ['../common-mk/common_test.gypi'], - 'sources': [ - 'install_attributes/mock_install_attributes_reader.cc', - 'policy/tests/device_policy_impl_unittest.cc', - 'policy/tests/libpolicy_unittest.cc', - 'policy/tests/policy_util_unittest.cc', - 'policy/tests/resilient_policy_util_unittest.cc', - ] - }, - ], - }], - ], -} diff --git a/libbrillo.pc.in b/libbrillo.pc.in deleted file mode 100644 index a3a9e07..0000000 --- a/libbrillo.pc.in +++ /dev/null @@ -1,8 +0,0 @@ -bslot=@BSLOT@ - -Name: libbrillo -Description: brillo base library -Version: ${bslot} -Requires.private: @PRIVATE_PC@ -Cflags: -DUSE_RTTI_FOR_TYPE_TAGS -Libs: -lbrillo-${bslot} diff --git a/libinstallattributes.gypi b/libinstallattributes.gypi deleted file mode 100644 index e0c0014..0000000 --- a/libinstallattributes.gypi +++ /dev/null @@ -1,16 +0,0 @@ -{ - 'targets': [ - { - 'target_name': 'libinstallattributes-includes', - 'type': 'none', - 'copies': [ - { - 'destination': '<(SHARED_INTERMEDIATE_DIR)/include/install_attributes', - 'files': [ - 'install_attributes/libinstallattributes.h', - ], - }, - ], - }, - ], -} diff --git a/libpolicy.gypi b/libpolicy.gypi deleted file mode 100644 index b3a3d49..0000000 --- a/libpolicy.gypi +++ /dev/null @@ -1,22 +0,0 @@ -{ - 'targets': [ - { - 'target_name': 'libpolicy-includes', - 'type': 'none', - 'copies': [ - { - 'destination': '<(SHARED_INTERMEDIATE_DIR)/include/policy', - 'files': [ - 'policy/device_policy.h', - 'policy/device_policy_impl.h', - 'policy/libpolicy.h', - 'policy/mock_libpolicy.h', - 'policy/mock_device_policy.h', - 'policy/policy_util.h', - 'policy/resilient_policy_util.h', - ], - }, - ], - }, - ], -} diff --git a/platform2_preinstall.sh b/platform2_preinstall.sh deleted file mode 100755 index 448a31a..0000000 --- a/platform2_preinstall.sh +++ /dev/null @@ -1,52 +0,0 @@ -#!/bin/bash - -# Copyright (c) 2013 The Chromium OS Authors. All rights reserved. -# Use of this source code is governed by a BSD-style license that can be -# found in the LICENSE file. - -set -e - -OUT=$1 -shift -for v; do - # Extract all the libbrillo sublibs from 'dependencies' section of - # 'libbrillo-<(libbase_ver)' target in libbrillo.gypi and convert them - # into an array of "-lbrillo-<sublib>-<v>" flags. - sublibs=($(sed -n " - /'target_name': 'libbrillo-<(libbase_ver)'/,/target_name/ { - /dependencies/,/],/ { - /libbrillo/ { - s:[',]::g - s:<(libbase_ver):${v}:g - s:libbrillo:-lbrillo: - p - } - } - }" libbrillo.gypi)) - - echo "GROUP ( AS_NEEDED ( ${sublibs[@]} ) )" > "${OUT}"/lib/libbrillo-${v}.so - - deps=$(<"${OUT}"/gen/libbrillo-${v}-deps.txt) - pc="${OUT}"/lib/libbrillo-${v}.pc - - sed \ - -e "s/@BSLOT@/${v}/g" \ - -e "s/@PRIVATE_PC@/${deps}/g" \ - "libbrillo.pc.in" > "${pc}" - - deps_test=$(<"${OUT}"/gen/libbrillo-test-${v}-deps.txt) - deps_test+=" libbrillo-${v}" - sed \ - -e "s/@BSLOT@/${v}/g" \ - -e "s/@PRIVATE_PC@/${deps_test}/g" \ - "libbrillo-test.pc.in" > "${OUT}/lib/libbrillo-test-${v}.pc" - - - deps_glib=$(<"${OUT}"/gen/libbrillo-glib-${v}-deps.txt) - pc_glib="${OUT}"/lib/libbrillo-glib-${v}.pc - - sed \ - -e "s/@BSLOT@/${v}/g" \ - -e "s/@PRIVATE_PC@/${deps_glib}/g" \ - "libbrillo-glib.pc.in" > "${pc_glib}" -done diff --git a/policy/OWNERS b/policy/OWNERS new file mode 100644 index 0000000..0208469 --- /dev/null +++ b/policy/OWNERS @@ -0,0 +1,8 @@ +emaxx@chromium.org +ljusten@chromium.org +pmarko@chromium.org +poromov@chromium.org +rsorokin@chromium.org + +# TEAM: managed-devices@google.com +# COMPONENT: Enterprise>CloudPolicy diff --git a/policy/device_policy.h b/policy/device_policy.h index 5913d8c..29e1ed3 100644 --- a/policy/device_policy.h +++ b/policy/device_policy.h @@ -69,6 +69,10 @@ class DevicePolicy { // Returns true unless there is a policy on disk and loading it fails. virtual bool LoadPolicy() = 0; + // Returns true if OOBE has been completed and if the device has been enrolled + // as an enterprise or enterpriseAD device. + virtual bool IsEnterpriseEnrolled() const = 0; + // Writes the value of the DevicePolicyRefreshRate policy in |rate|. Returns // true on success. virtual bool GetPolicyRefreshRate(int* rate) const = 0; @@ -224,6 +228,18 @@ class DevicePolicy { virtual bool GetDeviceUpdateStagingSchedule( std::vector<DayPercentagePair>* staging_schedule_out) const = 0; + // Writes the value of the DeviceQuickFixBuildToken to + // |device_quick_fix_build_token|. + // Returns true if it has been written, or false if the policy was not set. + virtual bool GetDeviceQuickFixBuildToken( + std::string* device_quick_fix_build_token) const = 0; + + // Writes the value of the Directory API ID to |directory_api_id_out|. + // Returns true on success, false if the ID is not available (eg if the device + // is not enrolled). + virtual bool GetDeviceDirectoryApiId( + std::string* directory_api_id_out) const = 0; + private: // Verifies that the policy signature is correct. virtual bool VerifyPolicySignature() = 0; diff --git a/policy/device_policy_impl.cc b/policy/device_policy_impl.cc index 76b82a1..3f96d12 100644 --- a/policy/device_policy_impl.cc +++ b/policy/device_policy_impl.cc @@ -5,6 +5,7 @@ #include "policy/device_policy_impl.h" #include <algorithm> +#include <map> #include <memory> #include <set> #include <string> @@ -29,6 +30,12 @@ namespace em = enterprise_management; namespace policy { +// TODO(crbug.com/984789): Remove once support for OpenSSL <1.1 is dropped. +#if OPENSSL_VERSION_NUMBER < 0x10100000L +#define EVP_MD_CTX_new EVP_MD_CTX_create +#define EVP_MD_CTX_free EVP_MD_CTX_destroy +#endif + // Maximum value of RollbackAllowedMilestones policy. const int kMaxRollbackAllowedMilestones = 4; @@ -54,36 +61,34 @@ bool ReadPublicKeyFromFile(const base::FilePath& key_file, bool VerifySignature(const std::string& signed_data, const std::string& signature, const std::string& public_key) { - EVP_MD_CTX ctx; - EVP_MD_CTX_init(&ctx); + std::unique_ptr<EVP_MD_CTX, void (*)(EVP_MD_CTX *)> ctx(EVP_MD_CTX_new(), + EVP_MD_CTX_free); + if (!ctx) + return false; const EVP_MD* digest = EVP_sha1(); char* key = const_cast<char*>(public_key.data()); BIO* bio = BIO_new_mem_buf(key, public_key.length()); - if (!bio) { - EVP_MD_CTX_cleanup(&ctx); + if (!bio) return false; - } EVP_PKEY* public_key_ssl = d2i_PUBKEY_bio(bio, nullptr); if (!public_key_ssl) { BIO_free_all(bio); - EVP_MD_CTX_cleanup(&ctx); return false; } const unsigned char* sig = reinterpret_cast<const unsigned char*>(signature.data()); - int rv = EVP_VerifyInit_ex(&ctx, digest, nullptr); + int rv = EVP_VerifyInit_ex(ctx.get(), digest, nullptr); if (rv == 1) { - EVP_VerifyUpdate(&ctx, signed_data.data(), signed_data.length()); - rv = EVP_VerifyFinal(&ctx, sig, signature.length(), public_key_ssl); + EVP_VerifyUpdate(ctx.get(), signed_data.data(), signed_data.length()); + rv = EVP_VerifyFinal(ctx.get(), sig, signature.length(), public_key_ssl); } EVP_PKEY_free(public_key_ssl); BIO_free_all(bio); - EVP_MD_CTX_cleanup(&ctx); return rv == 1; } @@ -196,6 +201,17 @@ bool DevicePolicyImpl::LoadPolicy() { return policy_loaded; } +bool DevicePolicyImpl::IsEnterpriseEnrolled() const { + DCHECK(install_attributes_reader_); + if (!install_attributes_reader_->IsLocked()) + return false; + + const std::string& device_mode = install_attributes_reader_->GetAttribute( + InstallAttributesReader::kAttrMode); + return device_mode == InstallAttributesReader::kDeviceModeEnterprise || + device_mode == InstallAttributesReader::kDeviceModeEnterpriseAD; +} + bool DevicePolicyImpl::GetPolicyRefreshRate(int* rate) const { if (!device_policy_.has_device_policy_refresh_rate()) return false; @@ -331,6 +347,9 @@ bool DevicePolicyImpl::GetReleaseChannelDelegated( } bool DevicePolicyImpl::GetUpdateDisabled(bool* update_disabled) const { + if (!IsEnterpriseEnrolled()) + return false; + if (!device_policy_.has_auto_update_settings()) return false; @@ -345,6 +364,9 @@ bool DevicePolicyImpl::GetUpdateDisabled(bool* update_disabled) const { bool DevicePolicyImpl::GetTargetVersionPrefix( std::string* target_version_prefix) const { + if (!IsEnterpriseEnrolled()) + return false; + if (!device_policy_.has_auto_update_settings()) return false; @@ -374,14 +396,7 @@ bool DevicePolicyImpl::GetRollbackToTargetVersion( bool DevicePolicyImpl::GetRollbackAllowedMilestones( int* rollback_allowed_milestones) const { // This policy can be only set for devices which are enterprise enrolled. - if (!install_attributes_reader_->IsLocked()) - return false; - if (install_attributes_reader_->GetAttribute( - InstallAttributesReader::kAttrMode) != - InstallAttributesReader::kDeviceModeEnterprise && - install_attributes_reader_->GetAttribute( - InstallAttributesReader::kAttrMode) != - InstallAttributesReader::kDeviceModeEnterpriseAD) + if (!IsEnterpriseEnrolled()) return false; if (device_policy_.has_auto_update_settings()) { @@ -398,8 +413,9 @@ bool DevicePolicyImpl::GetRollbackAllowedMilestones( } } // Policy is not present, use default for enterprise devices. - VLOG(1) << "RollbackAllowedMilestones policy is not set, using default 0."; - *rollback_allowed_milestones = 0; + VLOG(1) << "RollbackAllowedMilestones policy is not set, using default " + << kMaxRollbackAllowedMilestones << "."; + *rollback_allowed_milestones = kMaxRollbackAllowedMilestones; return true; } @@ -419,6 +435,9 @@ bool DevicePolicyImpl::GetScatterFactorInSeconds( bool DevicePolicyImpl::GetAllowedConnectionTypesForUpdate( std::set<std::string>* connection_types) const { + if (!IsEnterpriseEnrolled()) + return false; + if (!device_policy_.has_auto_update_settings()) return false; @@ -541,9 +560,9 @@ bool DevicePolicyImpl::GetDeviceUpdateStagingSchedule( if (!list_val) return false; - for (base::Value* const& pair_value : *list_val) { - base::DictionaryValue* day_percentage_pair; - if (!pair_value->GetAsDictionary(&day_percentage_pair)) + for (const auto& pair_value : base::ValueReferenceAdapter(*list_val)) { + const base::DictionaryValue* day_percentage_pair; + if (!pair_value.GetAsDictionary(&day_percentage_pair)) return false; int days, percentage; if (!day_percentage_pair->GetInteger("days", &days) || @@ -616,6 +635,8 @@ bool DevicePolicyImpl::GetSecondFactorAuthenticationMode(int* mode_out) const { bool DevicePolicyImpl::GetDisallowedTimeIntervals( std::vector<WeeklyTimeInterval>* intervals_out) const { intervals_out->clear(); + if (!IsEnterpriseEnrolled()) + return false; if (!device_policy_.has_auto_update_settings()) { return false; @@ -633,14 +654,14 @@ bool DevicePolicyImpl::GetDisallowedTimeIntervals( if (!list_val) return false; - for (base::Value* const& interval_value : *list_val) { - base::DictionaryValue* interval_dict; - if (!interval_value->GetAsDictionary(&interval_dict)) { + for (const auto& interval_value : base::ValueReferenceAdapter(*list_val)) { + const base::DictionaryValue* interval_dict; + if (!interval_value.GetAsDictionary(&interval_dict)) { LOG(ERROR) << "Invalid JSON string given. Interval is not a dict."; return false; } - base::DictionaryValue* start; - base::DictionaryValue* end; + const base::DictionaryValue* start; + const base::DictionaryValue* end; if (!interval_dict->GetDictionary("start", &start) || !interval_dict->GetDictionary("end", &end)) { LOG(ERROR) << "Interval is missing start/end."; @@ -659,6 +680,29 @@ bool DevicePolicyImpl::GetDisallowedTimeIntervals( return true; } +bool DevicePolicyImpl::GetDeviceQuickFixBuildToken( + std::string* device_quick_fix_build_token) const { + if (!IsEnterpriseEnrolled() || !device_policy_.has_auto_update_settings()) + return false; + + const em::AutoUpdateSettingsProto& proto = + device_policy_.auto_update_settings(); + if (!proto.has_device_quick_fix_build_token()) + return false; + + *device_quick_fix_build_token = proto.device_quick_fix_build_token(); + return true; +} + +bool DevicePolicyImpl::GetDeviceDirectoryApiId( + std::string* directory_api_id_out) const { + if (!policy_data_.has_directory_api_id()) + return false; + + *directory_api_id_out = policy_data_.directory_api_id(); + return true; +} + bool DevicePolicyImpl::VerifyPolicyFile(const base::FilePath& policy_path) { if (!verify_root_ownership_) { return true; diff --git a/policy/device_policy_impl.h b/policy/device_policy_impl.h index 6891312..47426df 100644 --- a/policy/device_policy_impl.h +++ b/policy/device_policy_impl.h @@ -40,6 +40,7 @@ class DevicePolicyImpl : public DevicePolicy { // DevicePolicy overrides: bool LoadPolicy() override; + bool IsEnterpriseEnrolled() const override; bool GetPolicyRefreshRate(int* rate) const override; bool GetUserWhitelist( std::vector<std::string>* user_whitelist) const override; @@ -83,6 +84,10 @@ class DevicePolicyImpl : public DevicePolicy { std::vector<WeeklyTimeInterval>* intervals_out) const override; bool GetDeviceUpdateStagingSchedule( std::vector<DayPercentagePair> *staging_schedule_out) const override; + bool GetDeviceQuickFixBuildToken( + std::string* device_quick_fix_build_token) const override; + bool GetDeviceDirectoryApiId( + std::string* device_directory_api_out) const override; // Methods that can be used only for testing. void set_policy_data_for_testing( diff --git a/policy/libpolicy.cc b/policy/libpolicy.cc index a0b7640..e972814 100644 --- a/policy/libpolicy.cc +++ b/policy/libpolicy.cc @@ -5,6 +5,7 @@ #include "policy/libpolicy.h" #include <memory> +#include <utility> #include <base/logging.h> diff --git a/policy/mock_device_policy.h b/policy/mock_device_policy.h index 90470e2..8bf4b07 100644 --- a/policy/mock_device_policy.h +++ b/policy/mock_device_policy.h @@ -52,62 +52,73 @@ class MockDevicePolicy : public DevicePolicy { } ~MockDevicePolicy() override = default; - MOCK_METHOD0(LoadPolicy, bool(void)); + MOCK_METHOD(bool, LoadPolicy, (), (override)); + MOCK_METHOD(bool, IsEnterpriseEnrolled, (), (const, override)); - MOCK_CONST_METHOD1(GetPolicyRefreshRate, - bool(int*)); // NOLINT(readability/function) - MOCK_CONST_METHOD1(GetUserWhitelist, bool(std::vector<std::string>*)); - MOCK_CONST_METHOD1(GetGuestModeEnabled, - bool(bool*)); // NOLINT(readability/function) - MOCK_CONST_METHOD1(GetCameraEnabled, - bool(bool*)); // NOLINT(readability/function) - MOCK_CONST_METHOD1(GetShowUserNames, - bool(bool*)); // NOLINT(readability/function) - MOCK_CONST_METHOD1(GetDataRoamingEnabled, - bool(bool*)); // NOLINT(readability/function) - MOCK_CONST_METHOD1(GetAllowNewUsers, - bool(bool*)); // NOLINT(readability/function) - MOCK_CONST_METHOD1(GetMetricsEnabled, - bool(bool*)); // NOLINT(readability/function) - MOCK_CONST_METHOD1(GetReportVersionInfo, - bool(bool*)); // NOLINT(readability/function) - MOCK_CONST_METHOD1(GetReportActivityTimes, - bool(bool*)); // NOLINT(readability/function) - MOCK_CONST_METHOD1(GetReportBootMode, - bool(bool*)); // NOLINT(readability/function) - MOCK_CONST_METHOD1(GetEphemeralUsersEnabled, - bool(bool*)); // NOLINT(readability/function) - MOCK_CONST_METHOD1(GetReleaseChannel, bool(std::string*)); - MOCK_CONST_METHOD1(GetReleaseChannelDelegated, - bool(bool*)); // NOLINT(readability/function) - MOCK_CONST_METHOD1(GetUpdateDisabled, - bool(bool*)); // NOLINT(readability/function) - MOCK_CONST_METHOD1(GetTargetVersionPrefix, bool(std::string*)); - MOCK_CONST_METHOD1(GetRollbackToTargetVersion, bool(int*)); - MOCK_CONST_METHOD1(GetRollbackAllowedMilestones, bool(int*)); - MOCK_CONST_METHOD1(GetScatterFactorInSeconds, - bool(int64_t*)); // NOLINT(readability/function) - MOCK_CONST_METHOD1(GetAllowedConnectionTypesForUpdate, - bool(std::set<std::string>*)); - MOCK_CONST_METHOD1(GetOpenNetworkConfiguration, bool(std::string*)); - MOCK_CONST_METHOD1(GetOwner, bool(std::string*)); - MOCK_CONST_METHOD1(GetHttpDownloadsEnabled, - bool(bool*)); // NOLINT(readability/function) - MOCK_CONST_METHOD1(GetAuP2PEnabled, - bool(bool*)); // NOLINT(readability/function) - MOCK_CONST_METHOD1(GetAllowKioskAppControlChromeVersion, - bool(bool*)); // NOLINT(readability/function) - MOCK_CONST_METHOD1(GetUsbDetachableWhitelist, - bool(std::vector<DevicePolicy::UsbDeviceId>*)); - MOCK_CONST_METHOD1(GetAutoLaunchedKioskAppId, bool(std::string*)); - MOCK_CONST_METHOD0(IsEnterpriseManaged, bool()); - MOCK_CONST_METHOD1(GetSecondFactorAuthenticationMode, bool(int*)); - MOCK_CONST_METHOD1(GetDisallowedTimeIntervals, - bool(std::vector<WeeklyTimeInterval>*)); - MOCK_CONST_METHOD1(GetDeviceUpdateStagingSchedule, - bool(std::vector<DayPercentagePair>*)); - MOCK_METHOD0(VerifyPolicyFiles, bool(void)); - MOCK_METHOD0(VerifyPolicySignature, bool(void)); + MOCK_METHOD(bool, GetPolicyRefreshRate, (int*), (const, override)); + MOCK_METHOD(bool, + GetUserWhitelist, + (std::vector<std::string>*), + (const, override)); + MOCK_METHOD(bool, GetGuestModeEnabled, (bool*), (const, override)); + MOCK_METHOD(bool, GetCameraEnabled, (bool*), (const, override)); + MOCK_METHOD(bool, GetShowUserNames, (bool*), (const, override)); + MOCK_METHOD(bool, GetDataRoamingEnabled, (bool*), (const, override)); + MOCK_METHOD(bool, GetAllowNewUsers, (bool*), (const, override)); + MOCK_METHOD(bool, GetMetricsEnabled, (bool*), (const, override)); + MOCK_METHOD(bool, GetReportVersionInfo, (bool*), (const, override)); + MOCK_METHOD(bool, GetReportActivityTimes, (bool*), (const, override)); + MOCK_METHOD(bool, GetReportBootMode, (bool*), (const, override)); + MOCK_METHOD(bool, GetEphemeralUsersEnabled, (bool*), (const, override)); + MOCK_METHOD(bool, GetReleaseChannel, (std::string*), (const, override)); + MOCK_METHOD(bool, GetReleaseChannelDelegated, (bool*), (const, override)); + MOCK_METHOD(bool, GetUpdateDisabled, (bool*), (const, override)); + MOCK_METHOD(bool, GetTargetVersionPrefix, (std::string*), (const, override)); + MOCK_METHOD(bool, GetRollbackToTargetVersion, (int*), (const, override)); + MOCK_METHOD(bool, GetRollbackAllowedMilestones, (int*), (const, override)); + MOCK_METHOD(bool, GetScatterFactorInSeconds, (int64_t*), (const, override)); + MOCK_METHOD(bool, + GetAllowedConnectionTypesForUpdate, + (std::set<std::string>*), + (const, override)); + MOCK_METHOD(bool, + GetOpenNetworkConfiguration, + (std::string*), + (const, override)); + MOCK_METHOD(bool, GetOwner, (std::string*), (const, override)); + MOCK_METHOD(bool, GetHttpDownloadsEnabled, (bool*), (const, override)); + MOCK_METHOD(bool, GetAuP2PEnabled, (bool*), (const, override)); + MOCK_METHOD(bool, + GetAllowKioskAppControlChromeVersion, + (bool*), + (const, override)); + MOCK_METHOD(bool, + GetUsbDetachableWhitelist, + (std::vector<DevicePolicy::UsbDeviceId>*), + (const, override)); + MOCK_METHOD(bool, + GetAutoLaunchedKioskAppId, + (std::string*), + (const, override)); + MOCK_METHOD(bool, IsEnterpriseManaged, (), (const, override)); + MOCK_METHOD(bool, + GetSecondFactorAuthenticationMode, + (int*), + (const, override)); + MOCK_METHOD(bool, + GetDisallowedTimeIntervals, + (std::vector<WeeklyTimeInterval>*), + (const, override)); + MOCK_METHOD(bool, + GetDeviceUpdateStagingSchedule, + (std::vector<DayPercentagePair>*), + (const, override)); + MOCK_METHOD(bool, + GetDeviceQuickFixBuildToken, + (std::string*), + (const, override)); + MOCK_METHOD(bool, GetDeviceDirectoryApiId, (std::string*), (const, override)); + MOCK_METHOD(bool, VerifyPolicySignature, (), (override)); }; } // namespace policy diff --git a/policy/mock_libpolicy.h b/policy/mock_libpolicy.h index a0f6920..a04af7b 100644 --- a/policy/mock_libpolicy.h +++ b/policy/mock_libpolicy.h @@ -20,10 +20,10 @@ class MockPolicyProvider : public PolicyProvider { MockPolicyProvider() = default; ~MockPolicyProvider() override = default; - MOCK_METHOD0(Reload, bool(void)); - MOCK_CONST_METHOD0(device_policy_is_loaded, bool(void)); - MOCK_CONST_METHOD0(GetDevicePolicy, const DevicePolicy&(void)); - MOCK_CONST_METHOD0(IsConsumerDevice, bool(void)); + MOCK_METHOD(bool, Reload, (), (override)); + MOCK_METHOD(bool, device_policy_is_loaded, (), (const, override)); + MOCK_METHOD(const DevicePolicy&, GetDevicePolicy, (), (const, override)); + MOCK_METHOD(bool, IsConsumerDevice, (), (const, override)); private: DISALLOW_COPY_AND_ASSIGN(MockPolicyProvider); diff --git a/policy/tests/device_policy_impl_unittest.cc b/policy/tests/device_policy_impl_test.cc index 37c3916..2e68eb7 100644 --- a/policy/tests/device_policy_impl_unittest.cc +++ b/policy/tests/device_policy_impl_test.cc @@ -22,8 +22,8 @@ class DevicePolicyImplTest : public testing::Test, public DevicePolicyImpl { const em::ChromeDeviceSettingsProto& proto) { device_policy_.set_policy_for_testing(proto); device_policy_.set_install_attributes_for_testing( - std::make_unique<MockInstallAttributesReader>( - device_mode, true /* initialized */)); + std::make_unique<MockInstallAttributesReader>(device_mode, + true /* initialized */)); } DevicePolicyImpl device_policy_; @@ -108,7 +108,7 @@ TEST_F(DevicePolicyImplTest, GetRollbackAllowedMilestones_NotSet) { int value = -1; ASSERT_TRUE(device_policy_.GetRollbackAllowedMilestones(&value)); - EXPECT_EQ(0, value); + EXPECT_EQ(4, value); } // RollbackAllowedMilestones is set to a valid value. @@ -183,7 +183,7 @@ TEST_F(DevicePolicyImplTest, GetRollbackAllowedMilestones_SetTooSmall) { // Update staging schedule has no values TEST_F(DevicePolicyImplTest, GetDeviceUpdateStagingSchedule_NoValues) { em::ChromeDeviceSettingsProto device_policy_proto; - em::AutoUpdateSettingsProto *auto_update_settings = + em::AutoUpdateSettingsProto* auto_update_settings = device_policy_proto.mutable_auto_update_settings(); auto_update_settings->set_staging_schedule("[]"); InitializePolicy(InstallAttributesReader::kDeviceModeEnterprise, @@ -197,7 +197,7 @@ TEST_F(DevicePolicyImplTest, GetDeviceUpdateStagingSchedule_NoValues) { // Update staging schedule has valid values TEST_F(DevicePolicyImplTest, GetDeviceUpdateStagingSchedule_Valid) { em::ChromeDeviceSettingsProto device_policy_proto; - em::AutoUpdateSettingsProto *auto_update_settings = + em::AutoUpdateSettingsProto* auto_update_settings = device_policy_proto.mutable_auto_update_settings(); auto_update_settings->set_staging_schedule( "[{\"days\": 4, \"percentage\": 40}, {\"days\": 10, \"percentage\": " @@ -214,7 +214,7 @@ TEST_F(DevicePolicyImplTest, GetDeviceUpdateStagingSchedule_Valid) { // Update staging schedule has valid values, set using AD. TEST_F(DevicePolicyImplTest, GetDeviceUpdateStagingSchedule_Valid_AD) { em::ChromeDeviceSettingsProto device_policy_proto; - em::AutoUpdateSettingsProto *auto_update_settings = + em::AutoUpdateSettingsProto* auto_update_settings = device_policy_proto.mutable_auto_update_settings(); auto_update_settings->set_staging_schedule( "[{\"days\": 4, \"percentage\": 40}, {\"days\": 10, \"percentage\": " @@ -233,7 +233,7 @@ TEST_F(DevicePolicyImplTest, GetDeviceUpdateStagingSchedule_Valid_AD) { TEST_F(DevicePolicyImplTest, GetDeviceUpdateStagingSchedule_SetOutsideAllowable) { em::ChromeDeviceSettingsProto device_policy_proto; - em::AutoUpdateSettingsProto *auto_update_settings = + em::AutoUpdateSettingsProto* auto_update_settings = device_policy_proto.mutable_auto_update_settings(); auto_update_settings->set_staging_schedule( "[{\"days\": -1, \"percentage\": -10}, {\"days\": 30, \"percentage\": " @@ -243,8 +243,118 @@ TEST_F(DevicePolicyImplTest, std::vector<DayPercentagePair> staging_schedule; ASSERT_TRUE(device_policy_.GetDeviceUpdateStagingSchedule(&staging_schedule)); - EXPECT_THAT(staging_schedule, ElementsAre(DayPercentagePair{1, 0}, - DayPercentagePair{28, 100})); + EXPECT_THAT(staging_schedule, + ElementsAre(DayPercentagePair{1, 0}, DayPercentagePair{28, 100})); +} + +// Updates should only be disabled for enterprise managed devices. +TEST_F(DevicePolicyImplTest, GetUpdateDisabled_SetConsumer) { + em::ChromeDeviceSettingsProto device_policy_proto; + em::AutoUpdateSettingsProto* auto_update_settings = + device_policy_proto.mutable_auto_update_settings(); + auto_update_settings->set_update_disabled(true); + InitializePolicy(InstallAttributesReader::kDeviceModeConsumer, + device_policy_proto); + + bool value; + ASSERT_FALSE(device_policy_.GetUpdateDisabled(&value)); +} + +// Updates should only be pinned on enterprise managed devices. +TEST_F(DevicePolicyImplTest, GetTargetVersionPrefix_SetConsumer) { + em::ChromeDeviceSettingsProto device_policy_proto; + em::AutoUpdateSettingsProto* auto_update_settings = + device_policy_proto.mutable_auto_update_settings(); + auto_update_settings->set_target_version_prefix("hello"); + InitializePolicy(InstallAttributesReader::kDeviceModeConsumer, + device_policy_proto); + + std::string value = ""; + ASSERT_FALSE(device_policy_.GetTargetVersionPrefix(&value)); +} + +// The allowed connection types should only be changed in enterprise devices. +TEST_F(DevicePolicyImplTest, GetAllowedConnectionTypesForUpdate_SetConsumer) { + em::ChromeDeviceSettingsProto device_policy_proto; + em::AutoUpdateSettingsProto* auto_update_settings = + device_policy_proto.mutable_auto_update_settings(); + auto_update_settings->add_allowed_connection_types( + em::AutoUpdateSettingsProto::CONNECTION_TYPE_ETHERNET); + InitializePolicy(InstallAttributesReader::kDeviceModeConsumer, + device_policy_proto); + + std::set<std::string> value; + ASSERT_FALSE(device_policy_.GetAllowedConnectionTypesForUpdate(&value)); +} + +// Update time restrictions should only be used in enterprise devices. +TEST_F(DevicePolicyImplTest, GetDisallowedTimeIntervals_SetConsumer) { + em::ChromeDeviceSettingsProto device_policy_proto; + em::AutoUpdateSettingsProto* auto_update_settings = + device_policy_proto.mutable_auto_update_settings(); + auto_update_settings->set_disallowed_time_intervals( + "[{\"start\": {\"day_of_week\": \"Monday\", \"hours\": 10, \"minutes\": " + "0}, \"end\": {\"day_of_week\": \"Monday\", \"hours\": 10, \"minutes\": " + "0}}]"); + InitializePolicy(InstallAttributesReader::kDeviceModeConsumer, + device_policy_proto); + + std::vector<WeeklyTimeInterval> value; + ASSERT_FALSE(device_policy_.GetDisallowedTimeIntervals(&value)); +} + +// |DeviceQuickFixBuildToken| is set when device is enterprise enrolled. +TEST_F(DevicePolicyImplTest, GetDeviceQuickFixBuildToken_Set) { + const char kToken[] = "some_token"; + + em::ChromeDeviceSettingsProto device_policy_proto; + em::AutoUpdateSettingsProto* auto_update_settings = + device_policy_proto.mutable_auto_update_settings(); + auto_update_settings->set_device_quick_fix_build_token(kToken); + InitializePolicy(InstallAttributesReader::kDeviceModeEnterprise, + device_policy_proto); + std::string value; + EXPECT_TRUE(device_policy_.GetDeviceQuickFixBuildToken(&value)); + EXPECT_EQ(value, kToken); +} + +// If the device is not enterprise-enrolled, |GetDeviceQuickFixBuildToken| +// does not provide a token even if it is present in local device settings. +TEST_F(DevicePolicyImplTest, GetDeviceQuickFixBuildToken_NotSet) { + const char kToken[] = "some_token"; + + em::ChromeDeviceSettingsProto device_policy_proto; + em::AutoUpdateSettingsProto* auto_update_settings = + device_policy_proto.mutable_auto_update_settings(); + auto_update_settings->set_device_quick_fix_build_token(kToken); + InitializePolicy(InstallAttributesReader::kDeviceModeConsumer, + device_policy_proto); + std::string value; + EXPECT_FALSE(device_policy_.GetDeviceQuickFixBuildToken(&value)); + EXPECT_TRUE(value.empty()); +} + +// Should only write a value and return true if the ID is present. +TEST_F(DevicePolicyImplTest, GetDeviceDirectoryApiId_Set) { + constexpr char kDummyDeviceId[] = "aa-bb-cc-dd"; + + em::PolicyData policy_data; + policy_data.set_directory_api_id(kDummyDeviceId); + + device_policy_.set_policy_data_for_testing(policy_data); + + std::string id; + EXPECT_TRUE(device_policy_.GetDeviceDirectoryApiId(&id)); + EXPECT_EQ(kDummyDeviceId, id); +} + +TEST_F(DevicePolicyImplTest, GetDeviceDirectoryApiId_NotSet) { + em::PolicyData policy_data; + device_policy_.set_policy_data_for_testing(policy_data); + + std::string id; + EXPECT_FALSE(device_policy_.GetDeviceDirectoryApiId(&id)); + EXPECT_TRUE(id.empty()); } } // namespace policy diff --git a/policy/tests/libpolicy_unittest.cc b/policy/tests/libpolicy_test.cc index aaf497c..b8414bb 100644 --- a/policy/tests/libpolicy_unittest.cc +++ b/policy/tests/libpolicy_test.cc @@ -132,7 +132,7 @@ TEST(PolicyTest, DevicePolicyAllSetTest) { int_value = -1; ASSERT_TRUE(policy.GetRollbackToTargetVersion(&int_value)); EXPECT_EQ(enterprise_management::AutoUpdateSettingsProto:: - ROLLBACK_WITH_FULL_POWERWASH, + ROLLBACK_AND_POWERWASH, int_value); int_value = -1; @@ -243,10 +243,10 @@ TEST(PolicyTest, DevicePolicyNoneSetTest) { EXPECT_FALSE(policy.GetUpdateDisabled(&bool_value)); EXPECT_FALSE(policy.GetTargetVersionPrefix(&string_value)); EXPECT_FALSE(policy.GetRollbackToTargetVersion(&int_value)); - // RollbackAllowedMilestones has the default value of 0 for enterprise + // RollbackAllowedMilestones has the default value of 4 for enterprise // devices. ASSERT_TRUE(policy.GetRollbackAllowedMilestones(&int_value)); - EXPECT_EQ(0, int_value); + EXPECT_EQ(4, int_value); EXPECT_FALSE(policy.GetScatterFactorInSeconds(&int64_value)); EXPECT_FALSE(policy.GetOpenNetworkConfiguration(&string_value)); EXPECT_FALSE(policy.GetHttpDownloadsEnabled(&bool_value)); @@ -358,8 +358,3 @@ TEST(PolicyTest, IsConsumerDeviceEnterpriseAd) { } } // namespace policy - -int main(int argc, char* argv[]) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/policy/tests/policy_util_unittest.cc b/policy/tests/policy_util_test.cc index f26622f..f26622f 100644 --- a/policy/tests/policy_util_unittest.cc +++ b/policy/tests/policy_util_test.cc diff --git a/policy/tests/resilient_policy_util_unittest.cc b/policy/tests/resilient_policy_util_test.cc index 0963b08..0963b08 100644 --- a/policy/tests/resilient_policy_util_unittest.cc +++ b/policy/tests/resilient_policy_util_test.cc |