diff options
author | Josh Gao <jmgao@google.com> | 2018-01-28 20:32:46 -0800 |
---|---|---|
committer | Josh Gao <jmgao@google.com> | 2018-01-30 15:22:41 -0800 |
commit | b800d88b346fb7bd0afd0d36dba7ab5b5f3f6744 (patch) | |
tree | 41fb8004c920841e360f1edc137e394b3c78db04 /adb | |
parent | fb413a230411c6944876fbeffbfbaef507c5014b (diff) | |
download | core-b800d88b346fb7bd0afd0d36dba7ab5b5f3f6744.tar.gz core-b800d88b346fb7bd0afd0d36dba7ab5b5f3f6744.tar.bz2 core-b800d88b346fb7bd0afd0d36dba7ab5b5f3f6744.zip |
adb: extract atransport's connection interface.
As step one of refactoring atransport to separate out protocol handling
from its underlying connection, extract atransport's existing
hand-rolled connection vtable out to its own abstract interface.
This should not change behavior except in one case: emulators are
now treated as TCP devices for the purposes of `adb disconnect`.
Test: python test_device.py, with walleye over USB + TCP
Test: manually connecting and disconnecting devices/emulators
Change-Id: I877b8027e567cc6a7461749432b49f6cb2c2f0d7
Diffstat (limited to 'adb')
-rw-r--r-- | adb/Android.mk | 1 | ||||
-rw-r--r-- | adb/adb.h | 3 | ||||
-rw-r--r-- | adb/services.cpp | 8 | ||||
-rw-r--r-- | adb/transport.cpp | 56 | ||||
-rw-r--r-- | adb/transport.h | 78 | ||||
-rw-r--r-- | adb/transport_local.cpp | 163 | ||||
-rw-r--r-- | adb/transport_test.cpp | 16 | ||||
-rw-r--r-- | adb/transport_usb.cpp | 81 |
8 files changed, 173 insertions, 233 deletions
diff --git a/adb/Android.mk b/adb/Android.mk index 0eeafb63c..e52f0cbef 100644 --- a/adb/Android.mk +++ b/adb/Android.mk @@ -11,6 +11,7 @@ adb_host_sanitize := adb_target_sanitize := ADB_COMMON_CFLAGS := \ + -frtti \ -Wall -Wextra -Werror \ -Wno-unused-parameter \ -Wno-missing-field-initializers \ @@ -136,9 +136,6 @@ int launch_server(const std::string& socket_spec); int adb_server_main(int is_daemon, const std::string& socket_spec, int ack_reply_fd); /* initialize a transport object's func pointers and state */ -#if ADB_HOST -int get_available_local_transport_index(); -#endif int init_socket_transport(atransport* t, int s, int port, int local); void init_usb_transport(atransport* t, usb_handle* usb); diff --git a/adb/services.cpp b/adb/services.cpp index aff7012ee..6dc71cfc4 100644 --- a/adb/services.cpp +++ b/adb/services.cpp @@ -407,14 +407,6 @@ void connect_emulator(const std::string& port_spec, std::string* response) { return; } - // Check if more emulators can be registered. Similar unproblematic - // race condition as above. - int candidate_slot = get_available_local_transport_index(); - if (candidate_slot < 0) { - *response = "Cannot accept more emulators"; - return; - } - // Preconditions met, try to connect to the emulator. std::string error; if (!local_connect_arbitrary_ports(console_port, adb_port, &error)) { diff --git a/adb/transport.cpp b/adb/transport.cpp index 4a9d91a1c..5acaaece6 100644 --- a/adb/transport.cpp +++ b/adb/transport.cpp @@ -41,6 +41,7 @@ #include "adb.h" #include "adb_auth.h" +#include "adb_io.h" #include "adb_trace.h" #include "adb_utils.h" #include "diagnose_usb.h" @@ -65,6 +66,36 @@ TransportId NextTransportId() { return next++; } +bool FdConnection::Read(apacket* packet) { + if (!ReadFdExactly(fd_.get(), &packet->msg, sizeof(amessage))) { + D("remote local: read terminated (message)"); + return false; + } + + if (!ReadFdExactly(fd_.get(), &packet->data, packet->msg.data_length)) { + D("remote local: terminated (data)"); + return false; + } + + return true; +} + +bool FdConnection::Write(apacket* packet) { + uint32_t length = packet->msg.data_length; + + if (!WriteFdExactly(fd_.get(), &packet->msg, sizeof(amessage) + length)) { + D("remote local: write terminated"); + return false; + } + + return true; +} + +void FdConnection::Close() { + adb_shutdown(fd_.get()); + fd_.reset(); +} + static std::string dump_packet(const char* name, const char* func, apacket* p) { unsigned command = p->msg.command; int len = p->msg.data_length; @@ -220,11 +251,18 @@ static void read_transport_thread(void* _t) { { ATRACE_NAME("read_transport read_remote"); - if (t->read_from_remote(p, t) != 0) { + if (!t->connection->Read(p)) { D("%s: remote read failed for transport", t->serial); put_apacket(p); break; } + + if (!check_header(p, t)) { + D("%s: remote read: bad header", t->serial); + put_apacket(p); + break; + } + #if ADB_HOST if (p->msg.command == 0) { put_apacket(p); @@ -626,7 +664,7 @@ static void transport_unref(atransport* t) { t->ref_count--; if (t->ref_count == 0) { D("transport: %s unref (kicking and closing)", t->serial); - t->close(t); + t->connection->Close(); remove_transport(t); } else { D("transport: %s unref (count=%zu)", t->serial, t->ref_count); @@ -754,14 +792,14 @@ atransport* acquire_one_transport(TransportType type, const char* serial, Transp } int atransport::Write(apacket* p) { - return write_func_(p, this); + return this->connection->Write(p) ? 0 : -1; } void atransport::Kick() { if (!kicked_) { + D("kicking transport %s", this->serial); kicked_ = true; - CHECK(kick_func_ != nullptr); - kick_func_(this); + this->connection->Close(); } } @@ -1083,8 +1121,12 @@ void register_usb_transport(usb_handle* usb, const char* serial, const char* dev // This should only be used for transports with connection_state == kCsNoPerm. void unregister_usb_transport(usb_handle* usb) { std::lock_guard<std::recursive_mutex> lock(transport_lock); - transport_list.remove_if( - [usb](atransport* t) { return t->usb == usb && t->GetConnectionState() == kCsNoPerm; }); + transport_list.remove_if([usb](atransport* t) { + if (auto connection = dynamic_cast<UsbConnection*>(t->connection.get())) { + return connection->handle_ == usb && t->GetConnectionState() == kCsNoPerm; + } + return false; + }); } bool check_header(apacket* p, atransport* t) { diff --git a/adb/transport.h b/adb/transport.h index 86cd9928c..9700f445b 100644 --- a/adb/transport.h +++ b/adb/transport.h @@ -28,10 +28,11 @@ #include <string> #include <unordered_set> -#include "adb.h" - #include <openssl/rsa.h> +#include "adb.h" +#include "adb_unique_fd.h" + typedef std::unordered_set<std::string> FeatureSet; const FeatureSet& supported_features(); @@ -56,6 +57,50 @@ extern const char* const kFeaturePushSync; TransportId NextTransportId(); +// Abstraction for a blocking packet transport. +struct Connection { + Connection() = default; + Connection(const Connection& copy) = delete; + Connection(Connection&& move) = delete; + + // Destroy a Connection. Formerly known as 'Close' in atransport. + virtual ~Connection() = default; + + // Read/Write a packet. These functions are concurrently called from a transport's reader/writer + // threads. + virtual bool Read(apacket* packet) = 0; + virtual bool Write(apacket* packet) = 0; + + // Terminate a connection. + // This method must be thread-safe, and must cause concurrent Reads/Writes to terminate. + // Formerly known as 'Kick' in atransport. + virtual void Close() = 0; +}; + +struct FdConnection : public Connection { + explicit FdConnection(unique_fd fd) : fd_(std::move(fd)) {} + + bool Read(apacket* packet) override final; + bool Write(apacket* packet) override final; + + void Close() override; + + private: + unique_fd fd_; +}; + +struct UsbConnection : public Connection { + explicit UsbConnection(usb_handle* handle) : handle_(handle) {} + ~UsbConnection(); + + bool Read(apacket* packet) override final; + bool Write(apacket* packet) override final; + + void Close() override final; + + usb_handle* handle_; +}; + class atransport { public: // TODO(danalbert): We expose waaaaaaay too much stuff because this was @@ -73,12 +118,6 @@ class atransport { } virtual ~atransport() {} - int (*read_from_remote)(apacket* p, atransport* t) = nullptr; - void (*close)(atransport* t) = nullptr; - - void SetWriteFunction(int (*write_func)(apacket*, atransport*)) { write_func_ = write_func; } - void SetKickFunction(void (*kick_func)(atransport*)) { kick_func_ = kick_func; } - bool IsKicked() { return kicked_; } int Write(apacket* p); void Kick(); @@ -95,9 +134,7 @@ class atransport { bool online = false; TransportType type = kTransportAny; - // USB handle or socket fd as needed. - usb_handle* usb = nullptr; - int sfd = -1; + std::unique_ptr<Connection> connection; // Used to identify transports for clients. char* serial = nullptr; @@ -105,22 +142,8 @@ class atransport { char* model = nullptr; char* device = nullptr; char* devpath = nullptr; - void SetLocalPortForEmulator(int port) { - CHECK_EQ(local_port_for_emulator_, -1); - local_port_for_emulator_ = port; - } - bool GetLocalPortForEmulator(int* port) const { - if (type == kTransportLocal && local_port_for_emulator_ != -1) { - *port = local_port_for_emulator_; - return true; - } - return false; - } - - bool IsTcpDevice() const { - return type == kTransportLocal && local_port_for_emulator_ == -1; - } + bool IsTcpDevice() const { return type == kTransportLocal; } #if ADB_HOST std::shared_ptr<RSA> NextKey(); @@ -165,10 +188,7 @@ class atransport { bool MatchesTarget(const std::string& target) const; private: - int local_port_for_emulator_ = -1; bool kicked_ = false; - void (*kick_func_)(atransport*) = nullptr; - int (*write_func_)(apacket*, atransport*) = nullptr; // A set of features transmitted in the banner with the initial connection. // This is stored in the banner as 'features=feature0,feature1,etc'. diff --git a/adb/transport_local.cpp b/adb/transport_local.cpp index d6c84dac5..560a0312b 100644 --- a/adb/transport_local.cpp +++ b/adb/transport_local.cpp @@ -28,10 +28,12 @@ #include <condition_variable> #include <mutex> #include <thread> +#include <unordered_map> #include <vector> #include <android-base/parsenetaddress.h> #include <android-base/stringprintf.h> +#include <android-base/thread_annotations.h> #include <cutils/sockets.h> #if !ADB_HOST @@ -40,6 +42,7 @@ #include "adb.h" #include "adb_io.h" +#include "adb_unique_fd.h" #include "adb_utils.h" #include "sysdeps/chrono.h" @@ -53,48 +56,15 @@ static std::mutex& local_transports_lock = *new std::mutex(); -/* we keep a list of opened transports. The atransport struct knows to which - * local transport it is connected. The list is used to detect when we're - * trying to connect twice to a given local transport. - */ -static atransport* local_transports[ ADB_LOCAL_TRANSPORT_MAX ]; +// We keep a map from emulator port to transport. +// TODO: weak_ptr? +static auto& local_transports GUARDED_BY(local_transports_lock) = + *new std::unordered_map<int, atransport*>(); #endif /* ADB_HOST */ -static int remote_read(apacket *p, atransport *t) -{ - if (!ReadFdExactly(t->sfd, &p->msg, sizeof(amessage))) { - D("remote local: read terminated (message)"); - return -1; - } - - if (!check_header(p, t)) { - D("bad header: terminated (data)"); - return -1; - } - - if (!ReadFdExactly(t->sfd, p->data, p->msg.data_length)) { - D("remote local: terminated (data)"); - return -1; - } - - return 0; -} - -static int remote_write(apacket *p, atransport *t) -{ - int length = p->msg.data_length; - - if(!WriteFdExactly(t->sfd, &p->msg, sizeof(amessage) + length)) { - D("remote local: write terminated"); - return -1; - } - - return 0; -} - bool local_connect(int port) { std::string dummy; - return local_connect_arbitrary_ports(port-1, port, &dummy) == 0; + return local_connect_arbitrary_ports(port - 1, port, &dummy) == 0; } void connect_device(const std::string& address, std::string* response) { @@ -423,130 +393,83 @@ void local_init(int port) std::thread(func, port).detach(); } -static void remote_kick(atransport *t) -{ - int fd = t->sfd; - t->sfd = -1; - adb_shutdown(fd); - adb_close(fd); - #if ADB_HOST - int nn; - std::lock_guard<std::mutex> lock(local_transports_lock); - for (nn = 0; nn < ADB_LOCAL_TRANSPORT_MAX; nn++) { - if (local_transports[nn] == t) { - local_transports[nn] = NULL; - break; - } - } -#endif -} +struct EmulatorConnection : public FdConnection { + EmulatorConnection(unique_fd fd, int local_port) + : FdConnection(std::move(fd)), local_port_(local_port) {} -static void remote_close(atransport *t) -{ - int fd = t->sfd; - if (fd != -1) { - t->sfd = -1; - adb_close(fd); - } -#if ADB_HOST - int local_port; - if (t->GetLocalPortForEmulator(&local_port)) { - VLOG(TRANSPORT) << "remote_close, local_port = " << local_port; + ~EmulatorConnection() { + VLOG(TRANSPORT) << "remote_close, local_port = " << local_port_; std::unique_lock<std::mutex> lock(retry_ports_lock); RetryPort port; - port.port = local_port; + port.port = local_port_; port.retry_count = LOCAL_PORT_RETRY_COUNT; retry_ports.push_back(port); retry_ports_cond.notify_one(); } -#endif -} + void Close() override { + std::lock_guard<std::mutex> lock(local_transports_lock); + local_transports.erase(local_port_); + FdConnection::Close(); + } + + int local_port_; +}; -#if ADB_HOST /* Only call this function if you already hold local_transports_lock. */ static atransport* find_emulator_transport_by_adb_port_locked(int adb_port) -{ - int i; - for (i = 0; i < ADB_LOCAL_TRANSPORT_MAX; i++) { - int local_port; - if (local_transports[i] && local_transports[i]->GetLocalPortForEmulator(&local_port)) { - if (local_port == adb_port) { - return local_transports[i]; - } - } + REQUIRES(local_transports_lock) { + auto it = local_transports.find(adb_port); + if (it == local_transports.end()) { + return nullptr; } - return NULL; + return it->second; } -std::string getEmulatorSerialString(int console_port) -{ +std::string getEmulatorSerialString(int console_port) { return android::base::StringPrintf("emulator-%d", console_port); } -atransport* find_emulator_transport_by_adb_port(int adb_port) -{ +atransport* find_emulator_transport_by_adb_port(int adb_port) { std::lock_guard<std::mutex> lock(local_transports_lock); - atransport* result = find_emulator_transport_by_adb_port_locked(adb_port); - return result; + return find_emulator_transport_by_adb_port_locked(adb_port); } -atransport* find_emulator_transport_by_console_port(int console_port) -{ +atransport* find_emulator_transport_by_console_port(int console_port) { return find_transport(getEmulatorSerialString(console_port).c_str()); } - - -/* Only call this function if you already hold local_transports_lock. */ -int get_available_local_transport_index_locked() -{ - int i; - for (i = 0; i < ADB_LOCAL_TRANSPORT_MAX; i++) { - if (local_transports[i] == NULL) { - return i; - } - } - return -1; -} - -int get_available_local_transport_index() -{ - std::lock_guard<std::mutex> lock(local_transports_lock); - int result = get_available_local_transport_index_locked(); - return result; -} #endif -int init_socket_transport(atransport *t, int s, int adb_port, int local) -{ - int fail = 0; +int init_socket_transport(atransport* t, int s, int adb_port, int local) { + int fail = 0; - t->SetKickFunction(remote_kick); - t->SetWriteFunction(remote_write); - t->close = remote_close; - t->read_from_remote = remote_read; - t->sfd = s; + unique_fd fd(s); t->sync_token = 1; t->type = kTransportLocal; #if ADB_HOST + // Emulator connection. if (local) { + t->connection.reset(new EmulatorConnection(std::move(fd), adb_port)); std::lock_guard<std::mutex> lock(local_transports_lock); - t->SetLocalPortForEmulator(adb_port); atransport* existing_transport = find_emulator_transport_by_adb_port_locked(adb_port); - int index = get_available_local_transport_index_locked(); if (existing_transport != NULL) { D("local transport for port %d already registered (%p)?", adb_port, existing_transport); fail = -1; - } else if (index < 0) { + } else if (local_transports.size() >= ADB_LOCAL_TRANSPORT_MAX) { // Too many emulators. D("cannot register more emulators. Maximum is %d", ADB_LOCAL_TRANSPORT_MAX); fail = -1; } else { - local_transports[index] = t; + local_transports[adb_port] = t; } + + return fail; } #endif + + // Regular tcp connection. + t->connection.reset(new FdConnection(std::move(fd))); return fail; } diff --git a/adb/transport_test.cpp b/adb/transport_test.cpp index 68689d4a6..d987d4fa5 100644 --- a/adb/transport_test.cpp +++ b/adb/transport_test.cpp @@ -20,22 +20,6 @@ #include "adb.h" -TEST(transport, kick_transport) { - atransport t; - static size_t kick_count; - kick_count = 0; - // Mutate some member so we can test that the function is run. - t.SetKickFunction([](atransport* trans) { kick_count++; }); - ASSERT_FALSE(t.IsKicked()); - t.Kick(); - ASSERT_TRUE(t.IsKicked()); - ASSERT_EQ(1u, kick_count); - // A transport can only be kicked once. - t.Kick(); - ASSERT_TRUE(t.IsKicked()); - ASSERT_EQ(1u, kick_count); -} - static void DisconnectFunc(void* arg, atransport*) { int* count = reinterpret_cast<int*>(arg); ++*count; diff --git a/adb/transport_usb.cpp b/adb/transport_usb.cpp index 347482016..73e8e15df 100644 --- a/adb/transport_usb.cpp +++ b/adb/transport_usb.cpp @@ -80,25 +80,18 @@ static int UsbReadPayload(usb_handle* h, apacket* p) { #endif } -static int remote_read(apacket* p, atransport* t) { - int n = UsbReadMessage(t->usb, &p->msg); +static int remote_read(apacket* p, usb_handle* usb) { + int n = UsbReadMessage(usb, &p->msg); if (n < 0) { D("remote usb: read terminated (message)"); return -1; } - if (static_cast<size_t>(n) != sizeof(p->msg) || !check_header(p, t)) { - D("remote usb: check_header failed, skip it"); - goto err_msg; - } - if (t->GetConnectionState() == kCsOffline) { - // If we read a wrong msg header declaring a large message payload, don't read its payload. - // Otherwise we may miss true messages from the device. - if (p->msg.command != A_CNXN && p->msg.command != A_AUTH) { - goto err_msg; - } + if (static_cast<size_t>(n) != sizeof(p->msg)) { + D("remote usb: read received unexpected header length %d", n); + return -1; } if (p->msg.data_length) { - n = UsbReadPayload(t->usb, p); + n = UsbReadPayload(usb, p); if (n < 0) { D("remote usb: terminated (data)"); return -1; @@ -106,34 +99,24 @@ static int remote_read(apacket* p, atransport* t) { if (static_cast<uint32_t>(n) != p->msg.data_length) { D("remote usb: read payload failed (need %u bytes, give %d bytes), skip it", p->msg.data_length, n); - goto err_msg; + return -1; } } return 0; - -err_msg: - p->msg.command = 0; - return 0; } #else // On Android devices, we rely on the kernel to provide buffered read. // So we can recover automatically from EOVERFLOW. -static int remote_read(apacket *p, atransport *t) -{ - if (usb_read(t->usb, &p->msg, sizeof(amessage))) { +static int remote_read(apacket* p, usb_handle* usb) { + if (usb_read(usb, &p->msg, sizeof(amessage))) { PLOG(ERROR) << "remote usb: read terminated (message)"; return -1; } - if (!check_header(p, t)) { - LOG(ERROR) << "remote usb: check_header failed"; - return -1; - } - if (p->msg.data_length) { - if (usb_read(t->usb, p->data, p->msg.data_length)) { + if (usb_read(usb, p->data, p->msg.data_length)) { PLOG(ERROR) << "remote usb: terminated (data)"; return -1; } @@ -143,45 +126,43 @@ static int remote_read(apacket *p, atransport *t) } #endif -static int remote_write(apacket *p, atransport *t) -{ - unsigned size = p->msg.data_length; +UsbConnection::~UsbConnection() { + usb_close(handle_); +} - if (usb_write(t->usb, &p->msg, sizeof(amessage))) { +bool UsbConnection::Read(apacket* packet) { + int rc = remote_read(packet, handle_); + return rc == 0; +} + +bool UsbConnection::Write(apacket* packet) { + unsigned size = packet->msg.data_length; + + if (usb_write(handle_, &packet->msg, sizeof(packet->msg)) != 0) { PLOG(ERROR) << "remote usb: 1 - write terminated"; - return -1; + return false; } - if (p->msg.data_length == 0) return 0; - if (usb_write(t->usb, &p->data, size)) { + + if (packet->msg.data_length != 0 && usb_write(handle_, &packet->data, size) != 0) { PLOG(ERROR) << "remote usb: 2 - write terminated"; - return -1; + return false; } - return 0; -} - -static void remote_close(atransport* t) { - usb_close(t->usb); - t->usb = 0; + return true; } -static void remote_kick(atransport* t) { - usb_kick(t->usb); +void UsbConnection::Close() { + usb_kick(handle_); } void init_usb_transport(atransport* t, usb_handle* h) { D("transport: usb"); - t->close = remote_close; - t->SetKickFunction(remote_kick); - t->SetWriteFunction(remote_write); - t->read_from_remote = remote_read; + t->connection.reset(new UsbConnection(h)); t->sync_token = 1; t->type = kTransportUsb; - t->usb = h; } -int is_adb_interface(int usb_class, int usb_subclass, int usb_protocol) -{ +int is_adb_interface(int usb_class, int usb_subclass, int usb_protocol) { return (usb_class == ADB_CLASS && usb_subclass == ADB_SUBCLASS && usb_protocol == ADB_PROTOCOL); } |