diff options
Diffstat (limited to 'brillo/http/http_transport_curl.cc')
-rw-r--r-- | brillo/http/http_transport_curl.cc | 97 |
1 files changed, 43 insertions, 54 deletions
diff --git a/brillo/http/http_transport_curl.cc b/brillo/http/http_transport_curl.cc index 45a28a3..de6899a 100644 --- a/brillo/http/http_transport_curl.cc +++ b/brillo/http/http_transport_curl.cc @@ -7,10 +7,11 @@ #include <limits> #include <base/bind.h> +#include <base/files/file_descriptor_watcher_posix.h> #include <base/files/file_util.h> #include <base/logging.h> -#include <base/message_loop/message_loop.h> #include <base/strings/stringprintf.h> +#include <base/threading/thread_task_runner_handle.h> #include <brillo/http/http_connection_curl.h> #include <brillo/http/http_request.h> #include <brillo/strings/string_utils.h> @@ -22,7 +23,7 @@ namespace curl { // This is a class that stores connection data on particular CURL socket // and provides file descriptor watcher to monitor read and/or write operations // on the socket's file descriptor. -class Transport::SocketPollData : public base::MessagePumpForIO::FdWatcher { +class Transport::SocketPollData { public: SocketPollData(const std::shared_ptr<CurlInterface>& curl_interface, CURLM* curl_multi_handle, @@ -31,27 +32,35 @@ class Transport::SocketPollData : public base::MessagePumpForIO::FdWatcher { : curl_interface_(curl_interface), curl_multi_handle_(curl_multi_handle), transport_(transport), - socket_fd_(socket_fd), - file_descriptor_watcher_(FROM_HERE) {} + socket_fd_(socket_fd) {} - // Returns the pointer for the socket-specific file descriptor watcher. - base::MessagePumpForIO::FdWatchController* GetWatcher() { - return &file_descriptor_watcher_; + void StopWatcher() { + read_watcher_ = nullptr; + write_watcher_ = nullptr; } - private: - // Overrides from base::MessagePumpForIO::Watcher. - void OnFileCanReadWithoutBlocking(int fd) override { - OnSocketReady(fd, CURL_CSELECT_IN); + bool WatchReadable() { + read_watcher_ = base::FileDescriptorWatcher::WatchReadable( + socket_fd_, + base::BindRepeating(&Transport::SocketPollData::OnSocketReady, + base::Unretained(this), + CURL_CSELECT_IN)); + return read_watcher_.get(); } - void OnFileCanWriteWithoutBlocking(int fd) override { - OnSocketReady(fd, CURL_CSELECT_OUT); + + bool WatchWritable() { + write_watcher_ = base::FileDescriptorWatcher::WatchWritable( + socket_fd_, + base::BindRepeating(&Transport::SocketPollData::OnSocketReady, + base::Unretained(this), + CURL_CSELECT_OUT)); + return write_watcher_.get(); } + private: // Data on the socket is available to be read from or written to. // Notify CURL of the action it needs to take on the socket file descriptor. - void OnSocketReady(int fd, int action) { - CHECK_EQ(socket_fd_, fd) << "Unexpected socket file descriptor"; + void OnSocketReady(int action) { int still_running_count = 0; CURLMcode code = curl_interface_->MultiSocketAction( curl_multi_handle_, socket_fd_, action, &still_running_count); @@ -70,8 +79,9 @@ class Transport::SocketPollData : public base::MessagePumpForIO::FdWatcher { Transport* transport_; // The socket file descriptor for the connection. curl_socket_t socket_fd_; - // File descriptor watcher to notify us of asynchronous I/O on the FD. - base::MessagePumpForIO::FdWatchController file_descriptor_watcher_; + + std::unique_ptr<base::FileDescriptorWatcher::Controller> read_watcher_; + std::unique_ptr<base::FileDescriptorWatcher::Controller> write_watcher_; DISALLOW_COPY_AND_ASSIGN(SocketPollData); }; @@ -212,8 +222,7 @@ std::shared_ptr<http::Connection> Transport::CreateConnection( void Transport::RunCallbackAsync(const base::Location& from_here, const base::Closure& callback) { - base::MessageLoopForIO::current()->task_runner()->PostTask( - from_here, callback); + base::ThreadTaskRunnerHandle::Get()->PostTask(from_here, callback); } RequestID Transport::StartAsyncTransfer(http::Connection* connection, @@ -386,42 +395,22 @@ int Transport::MultiSocketCallback(CURL* easy, // Make sure we stop watching the socket file descriptor now, before // we schedule the SocketPollData for deletion. - poll_data->GetWatcher()->StopWatchingFileDescriptor(); + poll_data->StopWatcher(); // This method can be called indirectly from SocketPollData::OnSocketReady, // so delay destruction of SocketPollData object till the next loop cycle. - base::MessageLoopForIO::current()->task_runner()->DeleteSoon(FROM_HERE, - poll_data); + base::ThreadTaskRunnerHandle::Get()->DeleteSoon(FROM_HERE, poll_data); return 0; } - base::MessagePumpForIO::Mode watch_mode = base::MessagePumpForIO::WATCH_READ; - switch (what) { - case CURL_POLL_IN: - watch_mode = base::MessagePumpForIO::WATCH_READ; - break; - case CURL_POLL_OUT: - watch_mode = base::MessagePumpForIO::WATCH_WRITE; - break; - case CURL_POLL_INOUT: - watch_mode = base::MessagePumpForIO::WATCH_READ_WRITE; - break; - default: - LOG(FATAL) << "Unknown CURL socket action: " << what; - break; - } + poll_data->StopWatcher(); + + bool success = true; + if (what == CURL_POLL_IN || what == CURL_POLL_INOUT) + success = poll_data->WatchReadable() && success; + if (what == CURL_POLL_OUT || what == CURL_POLL_INOUT) + success = poll_data->WatchWritable() && success; - // WatchFileDescriptor() can be called with the same controller object - // (watcher) to amend the watch mode, however this has cumulative effect. - // For example, if we were watching a file descriptor for READ operations - // and now call it to watch for WRITE, it will end up watching for both - // READ and WRITE. This is not what we want here, so stop watching the - // file descriptor on previous controller before starting with a different - // mode. - if (!poll_data->GetWatcher()->StopWatchingFileDescriptor()) - LOG(WARNING) << "Failed to stop watching the previous socket descriptor"; - CHECK(base::MessageLoopForIO::current()->WatchFileDescriptor( - s, true, watch_mode, poll_data->GetWatcher(), poll_data)) - << "Failed to watch the CURL socket."; + CHECK(success) << "Failed to watch the CURL socket."; return 0; } @@ -433,11 +422,11 @@ int Transport::MultiTimerCallback(CURLM* /* multi */, // Cancel any previous timer callbacks. transport->weak_ptr_factory_for_timer_.InvalidateWeakPtrs(); if (timeout_ms >= 0) { - base::MessageLoopForIO::current()->task_runner()->PostDelayedTask( - FROM_HERE, - base::Bind(&Transport::OnTimer, - transport->weak_ptr_factory_for_timer_.GetWeakPtr()), - base::TimeDelta::FromMilliseconds(timeout_ms)); + base::ThreadTaskRunnerHandle::Get()->PostDelayedTask( + FROM_HERE, + base::Bind(&Transport::OnTimer, + transport->weak_ptr_factory_for_timer_.GetWeakPtr()), + base::TimeDelta::FromMilliseconds(timeout_ms)); } return 0; } |