diff options
| -rw-r--r-- | fastboot/Android.mk | 1 | ||||
| -rw-r--r-- | fastboot/socket.cpp | 70 | ||||
| -rw-r--r-- | fastboot/socket.h | 25 | ||||
| -rw-r--r-- | fastboot/socket_mock.cpp | 31 | ||||
| -rw-r--r-- | fastboot/socket_mock.h | 6 | ||||
| -rw-r--r-- | fastboot/socket_test.cpp | 116 | ||||
| -rw-r--r-- | include/cutils/sockets.h | 24 | ||||
| -rw-r--r-- | libcutils/Android.mk | 4 | ||||
| -rw-r--r-- | libcutils/sockets_unix.cpp (renamed from libcutils/sockets_unix.c) | 34 | ||||
| -rw-r--r-- | libcutils/sockets_windows.cpp (renamed from libcutils/sockets_windows.c) | 35 | ||||
| -rw-r--r-- | libcutils/tests/sockets_test.cpp | 8 |
11 files changed, 275 insertions, 79 deletions
diff --git a/fastboot/Android.mk b/fastboot/Android.mk index 65f4e01fa..11d769bad 100644 --- a/fastboot/Android.mk +++ b/fastboot/Android.mk @@ -65,6 +65,7 @@ LOCAL_STATIC_LIBRARIES := \ libdiagnose_usb \ libbase \ libcutils \ + libgtest_host \ # libf2fs_dlutils_host will dlopen("libf2fs_fmt_host_dyn") LOCAL_CFLAGS_linux := -DUSE_F2FS diff --git a/fastboot/socket.cpp b/fastboot/socket.cpp index 0a3ddfa2f..d49f47ff2 100644 --- a/fastboot/socket.cpp +++ b/fastboot/socket.cpp @@ -89,7 +89,8 @@ class UdpSocket : public Socket { UdpSocket(Type type, cutils_socket_t sock); - ssize_t Send(const void* data, size_t length) override; + bool Send(const void* data, size_t length) override; + bool Send(std::vector<cutils_socket_buffer_t> buffers) override; ssize_t Receive(void* data, size_t length, int timeout_ms) override; private: @@ -109,9 +110,20 @@ UdpSocket::UdpSocket(Type type, cutils_socket_t sock) : Socket(sock) { } } -ssize_t UdpSocket::Send(const void* data, size_t length) { +bool UdpSocket::Send(const void* data, size_t length) { return TEMP_FAILURE_RETRY(sendto(sock_, reinterpret_cast<const char*>(data), length, 0, - reinterpret_cast<sockaddr*>(addr_.get()), addr_size_)); + reinterpret_cast<sockaddr*>(addr_.get()), addr_size_)) == + static_cast<ssize_t>(length); +} + +bool UdpSocket::Send(std::vector<cutils_socket_buffer_t> buffers) { + size_t total_length = 0; + for (const auto& buffer : buffers) { + total_length += buffer.length; + } + + return TEMP_FAILURE_RETRY(socket_send_buffers_function_( + sock_, buffers.data(), buffers.size())) == static_cast<ssize_t>(total_length); } ssize_t UdpSocket::Receive(void* data, size_t length, int timeout_ms) { @@ -135,7 +147,8 @@ class TcpSocket : public Socket { public: TcpSocket(cutils_socket_t sock) : Socket(sock) {} - ssize_t Send(const void* data, size_t length) override; + bool Send(const void* data, size_t length) override; + bool Send(std::vector<cutils_socket_buffer_t> buffers) override; ssize_t Receive(void* data, size_t length, int timeout_ms) override; std::unique_ptr<Socket> Accept() override; @@ -144,23 +157,52 @@ class TcpSocket : public Socket { DISALLOW_COPY_AND_ASSIGN(TcpSocket); }; -ssize_t TcpSocket::Send(const void* data, size_t length) { - size_t total = 0; +bool TcpSocket::Send(const void* data, size_t length) { + while (length > 0) { + ssize_t sent = + TEMP_FAILURE_RETRY(send(sock_, reinterpret_cast<const char*>(data), length, 0)); - while (total < length) { - ssize_t bytes = TEMP_FAILURE_RETRY( - send(sock_, reinterpret_cast<const char*>(data) + total, length - total, 0)); + if (sent == -1) { + return false; + } + length -= sent; + } - if (bytes == -1) { - if (total == 0) { - return -1; + return true; +} + +bool TcpSocket::Send(std::vector<cutils_socket_buffer_t> buffers) { + while (!buffers.empty()) { + ssize_t sent = TEMP_FAILURE_RETRY( + socket_send_buffers_function_(sock_, buffers.data(), buffers.size())); + + if (sent == -1) { + return false; + } + + // Adjust the buffers to skip past the bytes we've just sent. + auto iter = buffers.begin(); + while (sent > 0) { + if (iter->length > static_cast<size_t>(sent)) { + // Incomplete buffer write; adjust the buffer to point to the next byte to send. + iter->length -= sent; + iter->data = reinterpret_cast<const char*>(iter->data) + sent; + break; } + + // Complete buffer write; move on to the next buffer. + sent -= iter->length; + ++iter; + } + + // Shortcut the common case: we've written everything remaining. + if (iter == buffers.end()) { break; } - total += bytes; + buffers.erase(buffers.begin(), iter); } - return total; + return true; } ssize_t TcpSocket::Receive(void* data, size_t length, int timeout_ms) { diff --git a/fastboot/socket.h b/fastboot/socket.h index a7481dba4..c0bd7c96c 100644 --- a/fastboot/socket.h +++ b/fastboot/socket.h @@ -33,11 +33,15 @@ #ifndef SOCKET_H_ #define SOCKET_H_ +#include <functional> #include <memory> #include <string> +#include <utility> +#include <vector> #include <android-base/macros.h> #include <cutils/sockets.h> +#include <gtest/gtest_prod.h> // Socket interface to be implemented for each platform. class Socket { @@ -64,8 +68,17 @@ class Socket { virtual ~Socket(); // Sends |length| bytes of |data|. For TCP sockets this will continue trying to send until all - // bytes are transmitted. Returns the number of bytes actually sent or -1 on error. - virtual ssize_t Send(const void* data, size_t length) = 0; + // bytes are transmitted. Returns true on success. + virtual bool Send(const void* data, size_t length) = 0; + + // Sends |buffers| using multi-buffer write, which can be significantly faster than making + // multiple calls. For UDP sockets |buffers| are all combined into a single datagram; for + // TCP sockets this will continue sending until all buffers are fully transmitted. Returns true + // on success. + // + // Note: This is non-functional for UDP server Sockets because it's not currently needed and + // would require an additional sendto() variation of multi-buffer write. + virtual bool Send(std::vector<cutils_socket_buffer_t> buffers) = 0; // Waits up to |timeout_ms| to receive up to |length| bytes of data. |timout_ms| of 0 will // block forever. Returns the number of bytes received or -1 on error/timeout. On timeout @@ -94,9 +107,17 @@ class Socket { cutils_socket_t sock_ = INVALID_SOCKET; + // Non-class functions we want to override during tests to verify functionality. Implementation + // should call this rather than using socket_send_buffers() directly. + std::function<ssize_t(cutils_socket_t, cutils_socket_buffer_t*, size_t)> + socket_send_buffers_function_ = &socket_send_buffers; + private: int receive_timeout_ms_ = 0; + FRIEND_TEST(SocketTest, TestTcpSendBuffers); + FRIEND_TEST(SocketTest, TestUdpSendBuffers); + DISALLOW_COPY_AND_ASSIGN(Socket); }; diff --git a/fastboot/socket_mock.cpp b/fastboot/socket_mock.cpp index 8fea55466..bcb91ecf3 100644 --- a/fastboot/socket_mock.cpp +++ b/fastboot/socket_mock.cpp @@ -38,26 +38,35 @@ SocketMock::~SocketMock() { } } -ssize_t SocketMock::Send(const void* data, size_t length) { +bool SocketMock::Send(const void* data, size_t length) { if (events_.empty()) { ADD_FAILURE() << "Send() was called when no message was expected"; - return -1; + return false; } if (events_.front().type != EventType::kSend) { ADD_FAILURE() << "Send() was called out-of-order"; - return -1; + return false; } std::string message(reinterpret_cast<const char*>(data), length); if (events_.front().message != message) { ADD_FAILURE() << "Send() expected " << events_.front().message << ", but got " << message; - return -1; + return false; } - ssize_t return_value = events_.front().return_value; events_.pop(); - return return_value; + return true; +} + +// Mock out multi-buffer send to be one large send, since that's what it should looks like from +// the user's perspective. +bool SocketMock::Send(std::vector<cutils_socket_buffer_t> buffers) { + std::string data; + for (const auto& buffer : buffers) { + data.append(reinterpret_cast<const char*>(buffer.data), buffer.length); + } + return Send(data.data(), data.size()); } ssize_t SocketMock::Receive(void* data, size_t length, int /*timeout_ms*/) { @@ -106,13 +115,13 @@ std::unique_ptr<Socket> SocketMock::Accept() { } void SocketMock::ExpectSend(std::string message) { - ssize_t return_value = message.length(); - events_.push(Event(EventType::kSend, std::move(message), return_value, nullptr)); + events_.push(Event(EventType::kSend, std::move(message), 0, nullptr)); } -void SocketMock::ExpectSendFailure(std::string message) { - events_.push(Event(EventType::kSend, std::move(message), -1, nullptr)); -} +// TODO: make this properly return false to the caller. +//void SocketMock::ExpectSendFailure(std::string message) { +// events_.push(Event(EventType::kSend, std::move(message), 0, nullptr)); +//} void SocketMock::AddReceive(std::string message) { ssize_t return_value = message.length(); diff --git a/fastboot/socket_mock.h b/fastboot/socket_mock.h index 3e62b330e..c48aa7bd9 100644 --- a/fastboot/socket_mock.h +++ b/fastboot/socket_mock.h @@ -56,7 +56,8 @@ class SocketMock : public Socket { SocketMock(); ~SocketMock() override; - ssize_t Send(const void* data, size_t length) override; + bool Send(const void* data, size_t length) override; + bool Send(std::vector<cutils_socket_buffer_t> buffers) override; ssize_t Receive(void* data, size_t length, int timeout_ms) override; int Close() override; virtual std::unique_ptr<Socket> Accept(); @@ -64,9 +65,6 @@ class SocketMock : public Socket { // Adds an expectation for Send(). void ExpectSend(std::string message); - // Adds an expectation for Send() that returns -1. - void ExpectSendFailure(std::string message); - // Adds data to provide for Receive(). void AddReceive(std::string message); diff --git a/fastboot/socket_test.cpp b/fastboot/socket_test.cpp index 7bfe96714..9365792a1 100644 --- a/fastboot/socket_test.cpp +++ b/fastboot/socket_test.cpp @@ -23,8 +23,10 @@ #include "socket.h" #include "socket_mock.h" -#include <gtest/gtest.h> +#include <list> + #include <gtest/gtest-spi.h> +#include <gtest/gtest.h> enum { kTestTimeoutMs = 3000 }; @@ -59,7 +61,7 @@ bool MakeConnectedSockets(Socket::Protocol protocol, std::unique_ptr<Socket>* se // Sends a string over a Socket. Returns true if the full string (without terminating char) // was sent. static bool SendString(Socket* sock, const std::string& message) { - return sock->Send(message.c_str(), message.length()) == static_cast<ssize_t>(message.length()); + return sock->Send(message.c_str(), message.length()); } // Receives a string from a Socket. Returns true if the full string (without terminating char) @@ -123,6 +125,116 @@ TEST(SocketTest, TestUdpReceiveOverflow) { } } +// Tests UDP multi-buffer send. +TEST(SocketTest, TestUdpSendBuffers) { + std::unique_ptr<Socket> sock = Socket::NewServer(Socket::Protocol::kUdp, 0); + std::vector<std::string> data{"foo", "bar", "12345"}; + std::vector<cutils_socket_buffer_t> buffers{{data[0].data(), data[0].length()}, + {data[1].data(), data[1].length()}, + {data[2].data(), data[2].length()}}; + ssize_t mock_return_value = 0; + + // Mock out socket_send_buffers() to verify we're sending in the correct buffers and + // return |mock_return_value|. + sock->socket_send_buffers_function_ = [&buffers, &mock_return_value]( + cutils_socket_t /*cutils_sock*/, cutils_socket_buffer_t* sent_buffers, + size_t num_sent_buffers) -> ssize_t { + EXPECT_EQ(buffers.size(), num_sent_buffers); + for (size_t i = 0; i < num_sent_buffers; ++i) { + EXPECT_EQ(buffers[i].data, sent_buffers[i].data); + EXPECT_EQ(buffers[i].length, sent_buffers[i].length); + } + return mock_return_value; + }; + + mock_return_value = strlen("foobar12345"); + EXPECT_TRUE(sock->Send(buffers)); + + mock_return_value -= 1; + EXPECT_FALSE(sock->Send(buffers)); + + mock_return_value = 0; + EXPECT_FALSE(sock->Send(buffers)); + + mock_return_value = -1; + EXPECT_FALSE(sock->Send(buffers)); +} + +// Tests TCP re-sending until socket_send_buffers() sends all data. This is a little complicated, +// but the general idea is that we intercept calls to socket_send_buffers() using a lambda mock +// function that simulates partial writes. +TEST(SocketTest, TestTcpSendBuffers) { + std::unique_ptr<Socket> sock = Socket::NewServer(Socket::Protocol::kTcp, 0); + std::vector<std::string> data{"foo", "bar", "12345"}; + std::vector<cutils_socket_buffer_t> buffers{{data[0].data(), data[0].length()}, + {data[1].data(), data[1].length()}, + {data[2].data(), data[2].length()}}; + + // Test breaking up the buffered send at various points. + std::list<std::string> test_sends[] = { + // Successes. + {"foobar12345"}, + {"f", "oob", "ar12345"}, + {"fo", "obar12", "345"}, + {"foo", "bar12345"}, + {"foob", "ar123", "45"}, + {"f", "o", "o", "b", "a", "r", "1", "2", "3", "4", "5"}, + + // Failures. + {}, + {"f"}, + {"foo", "bar"}, + {"fo", "obar12"}, + {"foobar1234"} + }; + + for (auto& test : test_sends) { + ssize_t bytes_sent = 0; + bool expect_success = true; + + // Create a mock function for custom socket_send_buffers() behavior. This function will + // check to make sure the input buffers start at the next unsent byte, then return the + // number of bytes indicated by the next entry in |test|. + sock->socket_send_buffers_function_ = [&bytes_sent, &data, &expect_success, &test]( + cutils_socket_t /*cutils_sock*/, cutils_socket_buffer_t* buffers, + size_t num_buffers) -> ssize_t { + EXPECT_TRUE(num_buffers > 0); + + // Failure case - pretend we errored out before sending all the buffers. + if (test.empty()) { + expect_success = false; + return -1; + } + + // Count the bytes we've sent to find where the next buffer should start and how many + // bytes should be left in it. + size_t byte_count = bytes_sent, data_index = 0; + while (data_index < data.size()) { + if (byte_count >= data[data_index].length()) { + byte_count -= data[data_index].length(); + ++data_index; + } else { + break; + } + } + void* expected_next_byte = &data[data_index][byte_count]; + size_t expected_next_size = data[data_index].length() - byte_count; + + EXPECT_EQ(data.size() - data_index, num_buffers); + EXPECT_EQ(expected_next_byte, buffers[0].data); + EXPECT_EQ(expected_next_size, buffers[0].length); + + std::string to_send = std::move(test.front()); + test.pop_front(); + bytes_sent += to_send.length(); + return to_send.length(); + }; + + EXPECT_EQ(expect_success, sock->Send(buffers)); + EXPECT_TRUE(test.empty()); + } +} + TEST(SocketMockTest, TestSendSuccess) { SocketMock mock; diff --git a/include/cutils/sockets.h b/include/cutils/sockets.h index cb9b3ff8e..783bd0bea 100644 --- a/include/cutils/sockets.h +++ b/include/cutils/sockets.h @@ -30,15 +30,12 @@ typedef int socklen_t; typedef SOCKET cutils_socket_t; -typedef WSABUF cutils_socket_buffer_t; #else #include <sys/socket.h> -#include <sys/uio.h> typedef int cutils_socket_t; -typedef struct iovec cutils_socket_buffer_t; #define INVALID_SOCKET (-1) #endif @@ -144,21 +141,24 @@ int socket_get_local_port(cutils_socket_t sock); * on Windows. This can give significant speedup compared to calling send() * multiple times. * - * Because Unix and Windows use different structs to hold buffers, we also - * need a generic function to set up the buffers. - * * Example usage: - * cutils_socket_buffer_t buffers[2] = { - * make_cutils_socket_buffer(data0, len0), - * make_cutils_socket_buffer(data1, len1) - * }; + * cutils_socket_buffer_t buffers[2] = { {data0, len0}, {data1, len1} }; * socket_send_buffers(sock, buffers, 2); * + * If you try to pass more than SOCKET_SEND_BUFFERS_MAX_BUFFERS buffers into + * this function it will return -1 without sending anything. + * * Returns the number of bytes written or -1 on error. */ -cutils_socket_buffer_t make_cutils_socket_buffer(void* data, size_t length); +typedef struct { + const void* data; + size_t length; +} cutils_socket_buffer_t; + +#define SOCKET_SEND_BUFFERS_MAX_BUFFERS 16 + ssize_t socket_send_buffers(cutils_socket_t sock, - cutils_socket_buffer_t* buffers, + const cutils_socket_buffer_t* buffers, size_t num_buffers); /* diff --git a/libcutils/Android.mk b/libcutils/Android.mk index 482e4dd73..51c6d9d1f 100644 --- a/libcutils/Android.mk +++ b/libcutils/Android.mk @@ -46,7 +46,7 @@ libcutils_nonwindows_sources := \ socket_loopback_client_unix.c \ socket_loopback_server_unix.c \ socket_network_client_unix.c \ - sockets_unix.c \ + sockets_unix.cpp \ str_parms.c \ libcutils_nonwindows_host_sources := \ @@ -56,7 +56,7 @@ libcutils_nonwindows_host_sources := \ libcutils_windows_host_sources := \ socket_inaddr_any_server_windows.c \ socket_network_client_windows.c \ - sockets_windows.c \ + sockets_windows.cpp \ # Shared and static library for host # Note: when linking this library on Windows, you must also link to Winsock2 diff --git a/libcutils/sockets_unix.c b/libcutils/sockets_unix.cpp index 3e7cea087..8747d696f 100644 --- a/libcutils/sockets_unix.c +++ b/libcutils/sockets_unix.cpp @@ -15,6 +15,9 @@ */ #include <cutils/sockets.h> + +#include <sys/uio.h> + #include <log/log.h> #if defined(__ANDROID__) @@ -25,10 +28,9 @@ #define __android_unused __attribute__((__unused__)) #endif -bool socket_peer_is_trusted(int fd __android_unused) -{ +bool socket_peer_is_trusted(int fd __android_unused) { #if defined(__ANDROID__) - struct ucred cr; + ucred cr; socklen_t len = sizeof(cr); int n = getsockopt(fd, SOL_SOCKET, SO_PEERCRED, &cr, &len); @@ -51,21 +53,27 @@ int socket_close(int sock) { } int socket_set_receive_timeout(cutils_socket_t sock, int timeout_ms) { - struct timeval tv; + timeval tv; tv.tv_sec = timeout_ms / 1000; tv.tv_usec = (timeout_ms % 1000) * 1000; return setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)); } -cutils_socket_buffer_t make_cutils_socket_buffer(void* data, size_t length) { - cutils_socket_buffer_t buffer; - buffer.iov_base = data; - buffer.iov_len = length; - return buffer; -} - ssize_t socket_send_buffers(cutils_socket_t sock, - cutils_socket_buffer_t* buffers, + const cutils_socket_buffer_t* buffers, size_t num_buffers) { - return writev(sock, buffers, num_buffers); + if (num_buffers > SOCKET_SEND_BUFFERS_MAX_BUFFERS) { + return -1; + } + + iovec iovec_buffers[SOCKET_SEND_BUFFERS_MAX_BUFFERS]; + for (size_t i = 0; i < num_buffers; ++i) { + // It's safe to cast away const here; iovec declares non-const + // void* because it's used for both send and receive, but since + // we're only sending, the data won't be modified. + iovec_buffers[i].iov_base = const_cast<void*>(buffers[i].data); + iovec_buffers[i].iov_len = buffers[i].length; + } + + return writev(sock, iovec_buffers, num_buffers); } diff --git a/libcutils/sockets_windows.c b/libcutils/sockets_windows.cpp index 815368822..ed6b1a781 100644 --- a/libcutils/sockets_windows.c +++ b/libcutils/sockets_windows.cpp @@ -37,7 +37,7 @@ // Both adb (1) and Chrome (2) purposefully avoid WSACleanup() with no issues. // (1) https://android.googlesource.com/platform/system/core.git/+/master/adb/sysdeps_win32.cpp // (2) https://code.google.com/p/chromium/codesearch#chromium/src/net/base/winsock_init.cc -bool initialize_windows_sockets() { +extern "C" bool initialize_windows_sockets() { // There's no harm in calling WSAStartup() multiple times but no benefit // either, we may as well skip it after the first. static bool init_success = false; @@ -55,25 +55,32 @@ int socket_close(cutils_socket_t sock) { } int socket_set_receive_timeout(cutils_socket_t sock, int timeout_ms) { - return setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, (char*)&timeout_ms, - sizeof(timeout_ms)); -} - -cutils_socket_buffer_t make_cutils_socket_buffer(void* data, size_t length) { - cutils_socket_buffer_t buffer; - buffer.buf = data; - buffer.len = length; - return buffer; + return setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, + reinterpret_cast<char*>(&timeout_ms), sizeof(timeout_ms)); } ssize_t socket_send_buffers(cutils_socket_t sock, - cutils_socket_buffer_t* buffers, + const cutils_socket_buffer_t* buffers, size_t num_buffers) { - DWORD bytes_sent = 0; + if (num_buffers > SOCKET_SEND_BUFFERS_MAX_BUFFERS) { + return -1; + } - if (WSASend(sock, buffers, num_buffers, &bytes_sent, 0, NULL, NULL) != - SOCKET_ERROR) { + WSABUF wsa_buffers[SOCKET_SEND_BUFFERS_MAX_BUFFERS]; + for (size_t i = 0; i < num_buffers; ++i) { + // It's safe to cast away const here; WSABUF declares non-const + // void* because it's used for both send and receive, but since + // we're only sending, the data won't be modified. + wsa_buffers[i].buf = + reinterpret_cast<char*>(const_cast<void*>(buffers[i].data)); + wsa_buffers[i].len = buffers[i].length; + } + + DWORD bytes_sent = 0; + if (WSASend(sock, wsa_buffers, num_buffers, &bytes_sent, 0, nullptr, + nullptr) != SOCKET_ERROR) { return bytes_sent; } + return -1; } diff --git a/libcutils/tests/sockets_test.cpp b/libcutils/tests/sockets_test.cpp index 40fa9b110..0f682a2fa 100644 --- a/libcutils/tests/sockets_test.cpp +++ b/libcutils/tests/sockets_test.cpp @@ -60,11 +60,9 @@ static void TestConnectedSockets(cutils_socket_t server, cutils_socket_t client, // Send multiple buffers using socket_send_buffers(). std::string data[] = {"foo", "bar", "12345"}; - cutils_socket_buffer_t socket_buffers[3]; - for (int i = 0; i < 3; ++i) { - socket_buffers[i] = make_cutils_socket_buffer(&data[i][0], - data[i].length()); - } + cutils_socket_buffer_t socket_buffers[] = { {data[0].data(), data[0].length()}, + {data[1].data(), data[1].length()}, + {data[2].data(), data[2].length()} }; EXPECT_EQ(11, socket_send_buffers(client, socket_buffers, 3)); EXPECT_EQ(11, recv(server, buffer, sizeof(buffer), 0)); EXPECT_EQ(0, memcmp(buffer, "foobar12345", 11)); |
