diff options
| author | Bruce Chen <chenbruce@google.com> | 2021-09-03 02:38:04 +0000 |
|---|---|---|
| committer | Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com> | 2021-09-03 02:38:04 +0000 |
| commit | b121cbf0355ae9d6dc1aba11ac2dc107cebc8ab4 (patch) | |
| tree | 1b85f87ab0a6e6e20c013d08565e8f5095802ef7 | |
| parent | 5d7251433327be0304a6075c25ddb9a77c41ebb6 (diff) | |
| parent | 094d9ab72cc2dd9c4d11e31a27ffc6bab0cf9733 (diff) | |
| download | platform_packages_modules_DnsResolver-b121cbf0355ae9d6dc1aba11ac2dc107cebc8ab4.tar.gz platform_packages_modules_DnsResolver-b121cbf0355ae9d6dc1aba11ac2dc107cebc8ab4.tar.bz2 platform_packages_modules_DnsResolver-b121cbf0355ae9d6dc1aba11ac2dc107cebc8ab4.zip | |
Merge "Replace manual buffer handling with std::span" am: 094d9ab72c
Original change: https://android-review.googlesource.com/c/platform/packages/modules/DnsResolver/+/1790127
Change-Id: I35c2c70c5af1eb311f50b0afa530d09bb4d110bd
| -rw-r--r-- | DnsProxyListener.cpp | 45 | ||||
| -rw-r--r-- | DnsTlsDispatcher.cpp | 3 | ||||
| -rw-r--r-- | getaddrinfo.cpp | 21 | ||||
| -rw-r--r-- | gethnamaddr.cpp | 4 | ||||
| -rw-r--r-- | res_cache.cpp | 47 | ||||
| -rw-r--r-- | res_mkquery.cpp | 58 | ||||
| -rw-r--r-- | res_query.cpp | 33 | ||||
| -rw-r--r-- | res_send.cpp | 185 | ||||
| -rw-r--r-- | res_send.h | 6 | ||||
| -rw-r--r-- | resolv_cache.h | 13 | ||||
| -rw-r--r-- | resolv_private.h | 17 | ||||
| -rw-r--r-- | tests/resolv_cache_unit_test.cpp | 56 |
12 files changed, 237 insertions, 251 deletions
diff --git a/DnsProxyListener.cpp b/DnsProxyListener.cpp index db23cfc2..d6f90ccd 100644 --- a/DnsProxyListener.cpp +++ b/DnsProxyListener.cpp @@ -39,7 +39,6 @@ #include <cutils/multiuser.h> #include <netdutils/InternetAddresses.h> #include <netdutils/ResponseCode.h> -#include <netdutils/Slice.h> #include <netdutils/Stopwatch.h> #include <netdutils/ThreadUtil.h> #include <private/android_filesystem_config.h> // AID_SYSTEM @@ -65,6 +64,7 @@ using aidl::android::net::metrics::INetdEventListener; using aidl::android::net::resolv::aidl::DnsHealthEventParcel; using aidl::android::net::resolv::aidl::IDnsResolverUnsolicitedEventListener; using android::net::NetworkDnsEventReported; +using std::span; namespace android { @@ -147,11 +147,11 @@ void maybeFixupNetContext(android_net_context* ctx, pid_t pid) { void addIpAddrWithinLimit(std::vector<std::string>* ip_addrs, const sockaddr* addr, socklen_t addrlen); -int extractResNsendAnswers(const uint8_t* answer, size_t anslen, int ipType, +int extractResNsendAnswers(std::span<const uint8_t> answer, int ipType, std::vector<std::string>* ip_addrs) { int total_ip_addr_count = 0; ns_msg handle; - if (ns_initparse((const uint8_t*)answer, anslen, &handle) < 0) { + if (ns_initparse(answer.data(), answer.size(), &handle) < 0) { return 0; } int ancount = ns_msg_count(handle, ns_s_an); @@ -250,21 +250,20 @@ bool simpleStrtoul(const char* input, IntegralType* output, int base = 10) { return true; } -bool setQueryId(uint8_t* msg, size_t msgLen, uint16_t query_id) { - if (msgLen < sizeof(HEADER)) { +bool setQueryId(span<uint8_t> msg, uint16_t query_id) { + if ((size_t)msg.size() < sizeof(HEADER)) { errno = EINVAL; return false; } - auto hp = reinterpret_cast<HEADER*>(msg); + auto hp = reinterpret_cast<HEADER*>(msg.data()); hp->id = htons(query_id); return true; } -bool parseQuery(const uint8_t* msg, size_t msgLen, uint16_t* query_id, int* rr_type, - std::string* rr_name) { +bool parseQuery(span<const uint8_t> msg, uint16_t* query_id, int* rr_type, std::string* rr_name) { ns_msg handle; ns_rr rr; - if (ns_initparse((const uint8_t*)msg, msgLen, &handle) < 0 || + if (ns_initparse(msg.data(), msg.size(), &handle) < 0 || ns_parserr(&handle, ns_s_qd, 0, &rr) < 0) { return false; } @@ -927,8 +926,8 @@ void DnsProxyListener::ResNSendHandler::run() { uint16_t original_query_id = 0; // TODO: Handle the case which is msg contains more than one query - if (!parseQuery(msg.data(), msgLen, &original_query_id, &rr_type, &rr_name) || - !setQueryId(msg.data(), msgLen, arc4random_uniform(65536))) { + if (!parseQuery({msg.data(), msgLen}, &original_query_id, &rr_type, &rr_name) || + !setQueryId({msg.data(), msgLen}, arc4random_uniform(65536))) { // If the query couldn't be parsed, block the request. LOG(WARNING) << "ResNSendHandler::run: resnsend: from UID " << uid << ", invalid query"; sendBE32(mClient, -EINVAL); @@ -938,21 +937,21 @@ void DnsProxyListener::ResNSendHandler::run() { // Send DNS query std::vector<uint8_t> ansBuf(MAXPACKET, 0); int rcode = ns_r_noerror; - int nsendAns = -1; + int ansLen = -1; NetworkDnsEventReported event; initDnsEvent(&event, mNetContext); if (queryLimiter.start(uid)) { if (evaluate_domain_name(mNetContext, rr_name.c_str())) { - nsendAns = resolv_res_nsend(&mNetContext, msg.data(), msgLen, ansBuf.data(), MAXPACKET, - &rcode, static_cast<ResNsendFlags>(mFlags), &event); + ansLen = resolv_res_nsend(&mNetContext, {msg.data(), msgLen}, ansBuf, &rcode, + static_cast<ResNsendFlags>(mFlags), &event); } else { - nsendAns = -EAI_SYSTEM; + ansLen = -EAI_SYSTEM; } queryLimiter.finish(uid); } else { LOG(WARNING) << "ResNSendHandler::run: resnsend: from UID " << uid << ", max concurrent queries reached"; - nsendAns = -EBUSY; + ansLen = -EBUSY; } const int32_t latencyUs = saturate_cast<int32_t>(s.timeTakenUs()); @@ -961,14 +960,14 @@ void DnsProxyListener::ResNSendHandler::run() { event.set_res_nsend_flags(static_cast<ResNsendFlags>(mFlags)); // Fail, send -errno - if (nsendAns < 0) { - if (!sendBE32(mClient, nsendAns)) { + if (ansLen < 0) { + if (!sendBE32(mClient, ansLen)) { PLOG(WARNING) << "ResNSendHandler::run: resnsend: failed to send errno to uid " << uid << " pid " << mClient->getPid(); } if (rr_type == ns_t_a || rr_type == ns_t_aaaa) { reportDnsEvent(INetdEventListener::EVENT_RES_NSEND, mNetContext, latencyUs, - resNSendToAiError(nsendAns, rcode), event, rr_name); + resNSendToAiError(ansLen, rcode), event, rr_name); } return; } @@ -981,8 +980,8 @@ void DnsProxyListener::ResNSendHandler::run() { } // Restore query id and send answer - if (!setQueryId(ansBuf.data(), nsendAns, original_query_id) || - !sendLenAndData(mClient, nsendAns, ansBuf.data())) { + if (!setQueryId({ansBuf.data(), ansLen}, original_query_id) || + !sendLenAndData(mClient, ansLen, ansBuf.data())) { PLOG(WARNING) << "ResNSendHandler::run: resnsend: failed to send answer to uid " << uid << " pid " << mClient->getPid(); return; @@ -991,9 +990,9 @@ void DnsProxyListener::ResNSendHandler::run() { if (rr_type == ns_t_a || rr_type == ns_t_aaaa) { std::vector<std::string> ip_addrs; const int total_ip_addr_count = - extractResNsendAnswers((uint8_t*)ansBuf.data(), nsendAns, rr_type, &ip_addrs); + extractResNsendAnswers({ansBuf.data(), ansLen}, rr_type, &ip_addrs); reportDnsEvent(INetdEventListener::EVENT_RES_NSEND, mNetContext, latencyUs, - resNSendToAiError(nsendAns, rcode), event, rr_name, ip_addrs, + resNSendToAiError(ansLen, rcode), event, rr_name, ip_addrs, total_ip_addr_count); } } diff --git a/DnsTlsDispatcher.cpp b/DnsTlsDispatcher.cpp index 67f98576..b6f766bb 100644 --- a/DnsTlsDispatcher.cpp +++ b/DnsTlsDispatcher.cpp @@ -128,7 +128,8 @@ DnsTlsTransport::Response DnsTlsDispatcher::query(const std::list<DnsTlsServer>& dnsQueryEvent->set_dns_server_index(serverCount++); dnsQueryEvent->set_ip_version(ipFamilyToIPVersion(server.ss.ss_family)); dnsQueryEvent->set_protocol(PROTO_DOT); - dnsQueryEvent->set_type(getQueryType(query.base(), query.size())); + std::span<const uint8_t> msg(query.base(), query.size()); + dnsQueryEvent->set_type(getQueryType(msg)); dnsQueryEvent->set_connected(connectTriggered); switch (code) { diff --git a/getaddrinfo.cpp b/getaddrinfo.cpp index 2e407f99..cb95311c 100644 --- a/getaddrinfo.cpp +++ b/getaddrinfo.cpp @@ -1627,13 +1627,11 @@ QueryResult doQuery(const char* name, res_target* t, ResState* res, LOG(DEBUG) << __func__ << ": (" << cl << ", " << type << ")"; uint8_t buf[MAXPACKET]; - - int n = res_nmkquery(QUERY, name, cl, type, /*data=*/nullptr, /*datalen=*/0, buf, sizeof(buf), - res->netcontext_flags); + int n = res_nmkquery(QUERY, name, cl, type, {}, buf, res->netcontext_flags); if (n > 0 && (res->netcontext_flags & (NET_CONTEXT_FLAG_USE_DNS_OVER_TLS | NET_CONTEXT_FLAG_USE_EDNS))) { - n = res_nopt(res, n, buf, sizeof(buf), anslen); + n = res_nopt(res, n, buf, anslen); } NetworkDnsEventReported event; @@ -1651,7 +1649,7 @@ QueryResult doQuery(const char* name, res_target* t, ResState* res, ResState res_temp = res->clone(&event); int rcode = NOERROR; - n = res_nsend(&res_temp, buf, n, t->answer.data(), anslen, &rcode, 0, sleepTimeMs); + n = res_nsend(&res_temp, {buf, n}, {t->answer.data(), anslen}, &rcode, 0, sleepTimeMs); if (n < 0 || hp->rcode != NOERROR || ntohs(hp->ancount) == 0) { // To ensure that the rcode handling is identical to res_queryN(). if (rcode != RCODE_TIMEOUT) rcode = hp->rcode; @@ -1660,9 +1658,8 @@ QueryResult doQuery(const char* name, res_target* t, ResState* res, (NET_CONTEXT_FLAG_USE_DNS_OVER_TLS | NET_CONTEXT_FLAG_USE_EDNS)) && (res_temp.flags & RES_F_EDNS0ERR)) { LOG(DEBUG) << __func__ << ": retry without EDNS0"; - n = res_nmkquery(QUERY, name, cl, type, /*data=*/nullptr, /*datalen=*/0, buf, - sizeof(buf), res_temp.netcontext_flags); - n = res_nsend(&res_temp, buf, n, t->answer.data(), anslen, &rcode, 0); + n = res_nmkquery(QUERY, name, cl, type, {}, buf, res_temp.netcontext_flags); + n = res_nsend(&res_temp, {buf, n}, {t->answer.data(), anslen}, &rcode, 0); } } @@ -1761,21 +1758,19 @@ static int res_queryN(const char* name, res_target* target, ResState* res, int* const int anslen = t->answer.size(); LOG(DEBUG) << __func__ << ": (" << cl << ", " << type << ")"; - - n = res_nmkquery(QUERY, name, cl, type, /*data=*/nullptr, /*datalen=*/0, buf, sizeof(buf), - res->netcontext_flags); + n = res_nmkquery(QUERY, name, cl, type, {}, buf, res->netcontext_flags); if (n > 0 && (res->netcontext_flags & (NET_CONTEXT_FLAG_USE_DNS_OVER_TLS | NET_CONTEXT_FLAG_USE_EDNS)) && !retried) // TODO: remove the retry flag and provide a sufficient test coverage. - n = res_nopt(res, n, buf, sizeof(buf), anslen); + n = res_nopt(res, n, buf, anslen); if (n <= 0) { LOG(ERROR) << __func__ << ": res_nmkquery failed"; *herrno = NO_RECOVERY; return n; } - n = res_nsend(res, buf, n, t->answer.data(), anslen, &rcode, 0); + n = res_nsend(res, {buf, n}, {t->answer.data(), anslen}, &rcode, 0); if (n < 0 || hp->rcode != NOERROR || ntohs(hp->ancount) == 0) { // Record rcode from DNS response header only if no timeout. // Keep rcode timeout for reporting later if any. diff --git a/gethnamaddr.cpp b/gethnamaddr.cpp index bbc12a2e..5e35b46a 100644 --- a/gethnamaddr.cpp +++ b/gethnamaddr.cpp @@ -632,7 +632,7 @@ static int dns_gethtbyname(ResState* res, const char* name, int addr_type, getna int he; const unsigned qclass = isMdnsResolution(res->flags) ? C_IN | C_UNICAST : C_IN; - n = res_nsearch(res, name, qclass, type, buf->buf, (int)sizeof(buf->buf), &he); + n = res_nsearch(res, name, qclass, type, {buf->buf, (int)sizeof(buf->buf)}, &he); if (n < 0) { LOG(DEBUG) << __func__ << ": res_nsearch failed (" << n << ")"; // Return h_errno (he) to catch more detailed errors rather than EAI_NODATA. @@ -694,7 +694,7 @@ static int dns_gethtbyaddr(const unsigned char* uaddr, int len, int af, ResState res(netcontext, event); int he; - n = res_nquery(&res, qbuf, C_IN, T_PTR, buf->buf, (int)sizeof(buf->buf), &he); + n = res_nquery(&res, qbuf, C_IN, T_PTR, {buf->buf, (int)sizeof(buf->buf)}, &he); if (n < 0) { LOG(DEBUG) << __func__ << ": res_nquery failed (" << n << ")"; // Note that res_nquery() doesn't set the pair NETDB_INTERNAL and errno. diff --git a/res_cache.cpp b/res_cache.cpp index 0a1bd5c9..fe0885a5 100644 --- a/res_cache.cpp +++ b/res_cache.cpp @@ -78,6 +78,7 @@ using android::net::PROTO_UDP; using android::net::Protocol; using android::netdutils::DumpWriter; using android::netdutils::IPSockAddr; +using std::span; /* This code implements a small and *simple* DNS resolver cache. * @@ -773,14 +774,14 @@ static uint32_t answer_getNegativeTTL(ns_msg handle) { * In case of parse error zero (0) is returned which * indicates that the answer shall not be cached. */ -static uint32_t answer_getTTL(const void* answer, int answerlen) { +static uint32_t answer_getTTL(span<const uint8_t> answer) { ns_msg handle; int ancount, n; uint32_t result, ttl; ns_rr rr; result = 0; - if (ns_initparse((const uint8_t*) answer, answerlen, &handle) >= 0) { + if (ns_initparse(answer.data(), answer.size(), &handle) >= 0) { // get number of answer records ancount = ns_msg_count(handle, ns_s_an); @@ -840,13 +841,13 @@ static unsigned entry_hash(const Entry* e) { /* initialize an Entry as a search key, this also checks the input query packet * returns 1 on success, or 0 in case of unsupported/malformed data */ -static int entry_init_key(Entry* e, const void* query, int querylen) { +static int entry_init_key(Entry* e, span<const uint8_t> query) { DnsPacket pack[1]; memset(e, 0, sizeof(*e)); - e->query = (const uint8_t*) query; - e->querylen = querylen; + e->query = query.data(); + e->querylen = query.size(); e->hash = entry_hash(e); _dnsPacket_init(pack, e->query, e->querylen); @@ -855,11 +856,11 @@ static int entry_init_key(Entry* e, const void* query, int querylen) { } /* allocate a new entry as a cache node */ -static Entry* entry_alloc(const Entry* init, const void* answer, int answerlen) { +static Entry* entry_alloc(const Entry* init, span<const uint8_t> answer) { Entry* e; int size; - size = sizeof(*e) + init->querylen + answerlen; + size = sizeof(*e) + init->querylen + answer.size(); e = (Entry*) calloc(size, 1); if (e == NULL) return e; @@ -870,9 +871,9 @@ static Entry* entry_alloc(const Entry* init, const void* answer, int answerlen) memcpy((char*) e->query, init->query, e->querylen); e->answer = e->query + e->querylen; - e->answerlen = answerlen; + e->answerlen = answer.size(); - memcpy((char*) e->answer, answer, e->answerlen); + memcpy((char*)e->answer, answer.data(), e->answerlen); return e; } @@ -1101,14 +1102,14 @@ static void cache_notify_waiting_tid_locked(struct Cache* cache, const Entry* ke } } -void _resolv_cache_query_failed(unsigned netid, const void* query, int querylen, uint32_t flags) { +void _resolv_cache_query_failed(unsigned netid, span<const uint8_t> query, uint32_t flags) { // We should not notify with these flags. if (flags & (ANDROID_RESOLV_NO_CACHE_STORE | ANDROID_RESOLV_NO_CACHE_LOOKUP)) { return; } Entry key[1]; - if (!entry_init_key(key, query, querylen)) return; + if (!entry_init_key(key, query)) return; std::lock_guard guard(cache_mutex); @@ -1228,8 +1229,8 @@ static void _cache_remove_expired(Cache* cache) { // Get a NetConfig associated with a network, or nullptr if not found. static NetConfig* find_netconfig_locked(unsigned netid) REQUIRES(cache_mutex); -ResolvCacheStatus resolv_cache_lookup(unsigned netid, const void* query, int querylen, void* answer, - int answersize, int* answerlen, uint32_t flags) { +ResolvCacheStatus resolv_cache_lookup(unsigned netid, span<const uint8_t> query, + span<uint8_t> answer, int* answerlen, uint32_t flags) { // Skip cache lookup, return RESOLV_CACHE_NOTFOUND directly so that it is // possible to cache the answer of this query. // If ANDROID_RESOLV_NO_CACHE_STORE is set, return RESOLV_CACHE_SKIP to skip possible cache @@ -1247,7 +1248,7 @@ ResolvCacheStatus resolv_cache_lookup(unsigned netid, const void* query, int que LOG(INFO) << __func__ << ": lookup"; /* we don't cache malformed queries */ - if (!entry_init_key(&key, query, querylen)) { + if (!entry_init_key(&key, query)) { LOG(INFO) << __func__ << ": unsupported query"; return RESOLV_CACHE_UNSUPPORTED; } @@ -1310,13 +1311,13 @@ ResolvCacheStatus resolv_cache_lookup(unsigned netid, const void* query, int que } *answerlen = e->answerlen; - if (e->answerlen > answersize) { + if (e->answerlen > answer.size()) { /* NOTE: we return UNSUPPORTED if the answer buffer is too short */ LOG(INFO) << __func__ << ": ANSWER TOO LONG"; return RESOLV_CACHE_UNSUPPORTED; } - memcpy(answer, e->answer, e->answerlen); + memcpy(answer.data(), e->answer, e->answerlen); /* bump up this entry to the top of the MRU list */ if (e != cache->mru_list.mru_next) { @@ -1328,8 +1329,7 @@ ResolvCacheStatus resolv_cache_lookup(unsigned netid, const void* query, int que return RESOLV_CACHE_FOUND; } -int resolv_cache_add(unsigned netid, const void* query, int querylen, const void* answer, - int answerlen) { +int resolv_cache_add(unsigned netid, span<const uint8_t> query, span<const uint8_t> answer) { Entry key[1]; Entry* e; Entry** lookup; @@ -1338,7 +1338,7 @@ int resolv_cache_add(unsigned netid, const void* query, int querylen, const void /* don't assume that the query has already been cached */ - if (!entry_init_key(key, query, querylen)) { + if (!entry_init_key(key, query)) { LOG(INFO) << __func__ << ": passed invalid query?"; return -EINVAL; } @@ -1375,9 +1375,9 @@ int resolv_cache_add(unsigned netid, const void* query, int querylen, const void } } - ttl = answer_getTTL(answer, answerlen); + ttl = answer_getTTL(answer); if (ttl > 0) { - e = entry_alloc(key, answer, answerlen); + e = entry_alloc(key, answer); if (e != NULL) { e->expires = ttl + _time_now(); _cache_add_p(cache, lookup, e); @@ -1886,13 +1886,12 @@ bool has_named_cache(unsigned netid) { return find_named_cache_locked(netid) != nullptr; } -int resolv_cache_get_expiration(unsigned netid, const std::vector<char>& query, - time_t* expiration) { +int resolv_cache_get_expiration(unsigned netid, span<const uint8_t> query, time_t* expiration) { Entry key; *expiration = -1; // A malformed query is not allowed. - if (!entry_init_key(&key, query.data(), query.size())) { + if (!entry_init_key(&key, query)) { LOG(WARNING) << __func__ << ": unsupported query"; return -EINVAL; } diff --git a/res_mkquery.cpp b/res_mkquery.cpp index 939c4c29..d02cdb35 100644 --- a/res_mkquery.cpp +++ b/res_mkquery.cpp @@ -97,13 +97,11 @@ extern const char* const _res_opcodes[] = { }; // Form all types of queries. Returns the size of the result or -1. -int res_nmkquery(int op, // opcode of query - const char* dname, // domain name - int cl, int type, // class and type of query - const uint8_t* data, // resource record data - int datalen, // length of data - uint8_t* buf, // buffer to put query - int buflen, // size of buffer +int res_nmkquery(int op, // opcode of query + const char* dname, // domain name + int cl, int type, // class and type of query + std::span<const uint8_t> data, // resource record data + std::span<uint8_t> buf, // buffer to put query int netcontext_flags) { HEADER* hp; uint8_t *cp, *ep; @@ -116,18 +114,18 @@ int res_nmkquery(int op, // opcode of query /* * Initialize header fields. */ - if ((buf == NULL) || (buflen < HFIXEDSZ)) return (-1); - memset(buf, 0, HFIXEDSZ); - hp = (HEADER*) (void*) buf; + if (buf.empty() || (buf.size() < HFIXEDSZ)) return (-1); + memset(buf.data(), 0, HFIXEDSZ); + hp = (HEADER*)(void*)buf.data(); hp->id = htons(arc4random_uniform(65536)); hp->opcode = op; hp->rd = true; hp->ad = (netcontext_flags & NET_CONTEXT_FLAG_USE_DNS_OVER_TLS) != 0U; hp->rcode = NOERROR; - cp = buf + HFIXEDSZ; - ep = buf + buflen; + cp = buf.data() + HFIXEDSZ; + ep = buf.data() + buf.size(); dpp = dnptrs; - *dpp++ = buf; + *dpp++ = buf.data(); *dpp++ = NULL; lastdnptr = dnptrs + sizeof dnptrs / sizeof dnptrs[0]; /* @@ -145,12 +143,12 @@ int res_nmkquery(int op, // opcode of query *reinterpret_cast<uint16_t*>(cp) = htons(cl); cp += INT16SZ; hp->qdcount = htons(1); - if (op == QUERY || data == NULL) break; + if (op == QUERY || data.empty()) break; /* * Make an additional record for completion domain. */ if ((ep - cp) < RRFIXEDSZ) return (-1); - n = dn_comp((const char*) data, cp, ep - cp - RRFIXEDSZ, dnptrs, lastdnptr); + n = dn_comp((const char*)data.data(), cp, ep - cp - RRFIXEDSZ, dnptrs, lastdnptr); if (n < 0) return (-1); cp += n; *reinterpret_cast<uint16_t*>(cp) = htons(ns_t_null); @@ -168,7 +166,7 @@ int res_nmkquery(int op, // opcode of query /* * Initialize answer section */ - if (ep - cp < 1 + RRFIXEDSZ + datalen) return (-1); + if (ep - cp < 1 + RRFIXEDSZ + data.size()) return (-1); *cp++ = '\0'; /* no domain name */ *reinterpret_cast<uint16_t*>(cp) = htons(type); cp += INT16SZ; @@ -176,11 +174,11 @@ int res_nmkquery(int op, // opcode of query cp += INT16SZ; *reinterpret_cast<uint32_t*>(cp) = htonl(0); cp += INT32SZ; - *reinterpret_cast<uint16_t*>(cp) = htons(datalen); + *reinterpret_cast<uint16_t*>(cp) = htons(data.size()); cp += INT16SZ; - if (datalen) { - memcpy(cp, data, (size_t) datalen); - cp += datalen; + if (data.size()) { + memcpy(cp, data.data(), data.size()); + cp += data.size(); } hp->ancount = htons(1); break; @@ -188,23 +186,21 @@ int res_nmkquery(int op, // opcode of query default: return (-1); } - return (cp - buf); + return (cp - buf.data()); } int res_nopt(ResState* statp, int n0, /* current offset in buffer */ - uint8_t* buf, /* buffer to put query */ - int buflen, /* size of buffer */ + std::span<uint8_t> buf, /* buffer to put query */ int anslen) /* UDP answer buffer size */ { - HEADER* hp; + HEADER* hp = reinterpret_cast<HEADER*>(buf.data()); uint8_t *cp, *ep; uint16_t flags = 0; LOG(DEBUG) << __func__; - hp = (HEADER*) (void*) buf; - cp = buf + n0; - ep = buf + buflen; + cp = buf.data() + n0; + ep = buf.data() + buf.size(); if ((ep - cp) < 1 + RRFIXEDSZ) return (-1); @@ -226,13 +222,13 @@ int res_nopt(ResState* statp, int n0, /* current offset in buffer */ cp += INT16SZ; // EDNS0 padding - const uint16_t minlen = static_cast<uint16_t>(cp - buf) + 3 * INT16SZ; + const uint16_t minlen = static_cast<uint16_t>(cp - buf.data()) + 3 * INT16SZ; const uint16_t extra = minlen % kEdns0Padding; uint16_t padlen = (kEdns0Padding - extra) % kEdns0Padding; - if (minlen > buflen) { + if (minlen > buf.size()) { return -1; } - padlen = std::min(padlen, static_cast<uint16_t>(buflen - minlen)); + padlen = std::min(padlen, static_cast<uint16_t>(buf.size() - minlen)); *reinterpret_cast<uint16_t*>(cp) = htons(padlen + 2 * INT16SZ); /* RDLEN */ cp += INT16SZ; *reinterpret_cast<uint16_t*>(cp) = htons(NS_OPT_PADDING); /* OPTION-CODE */ @@ -243,5 +239,5 @@ int res_nopt(ResState* statp, int n0, /* current offset in buffer */ cp += padlen; hp->arcount = htons(ntohs(hp->arcount) + 1); - return (cp - buf); + return (cp - buf.data()); } diff --git a/res_query.cpp b/res_query.cpp index 5019ac30..2359b32b 100644 --- a/res_query.cpp +++ b/res_query.cpp @@ -101,13 +101,12 @@ */ int res_nquery(ResState* statp, const char* name, // domain name int cl, int type, // class and type of query - uint8_t* answer, // buffer to put answer - int anslen, // size of answer buffer + std::span<uint8_t> answer, // buffer to put answer int* herrno) // legacy and extended h_errno // NETD_RESOLV_H_ERRNO_EXT_* { uint8_t buf[MAXPACKET]; - HEADER* hp = (HEADER*) (void*) answer; + HEADER* hp = reinterpret_cast<HEADER*>(answer.data()); int n; int rcode = NOERROR; bool retried = false; @@ -116,20 +115,18 @@ again: hp->rcode = NOERROR; // default LOG(DEBUG) << __func__ << ": (" << cl << ", " << type << ")"; - - n = res_nmkquery(QUERY, name, cl, type, /*data=*/nullptr, 0, buf, sizeof(buf), - statp->netcontext_flags); + n = res_nmkquery(QUERY, name, cl, type, {}, buf, statp->netcontext_flags); if (n > 0 && (statp->netcontext_flags & (NET_CONTEXT_FLAG_USE_DNS_OVER_TLS | NET_CONTEXT_FLAG_USE_EDNS)) && !retried) - n = res_nopt(statp, n, buf, sizeof(buf), anslen); + n = res_nopt(statp, n, buf, answer.size()); if (n <= 0) { LOG(DEBUG) << __func__ << ": mkquery failed"; *herrno = NO_RECOVERY; return n; } - n = res_nsend(statp, buf, n, answer, anslen, &rcode, 0); + n = res_nsend(statp, {buf, n}, answer, &rcode, 0); if (n < 0) { // If the query choked with EDNS0, retry without EDNS0 that when the server // has no response, resovler won't retry and do nothing. Even fallback to UDP, @@ -203,13 +200,12 @@ again: */ int res_nsearch(ResState* statp, const char* name, /* domain name */ int cl, int type, /* class and type of query */ - uint8_t* answer, /* buffer to put answer */ - int anslen, /* size of answer */ + std::span<uint8_t> answer, /* buffer to put answer */ int* herrno) /* legacy and extended h_errno NETD_RESOLV_H_ERRNO_EXT_* */ { const char* cp; - HEADER* hp = (HEADER*) (void*) answer; + HEADER* hp = reinterpret_cast<HEADER*>(answer.data()); uint32_t dots; int ret, saved_herrno; int got_nodata = 0, got_servfail = 0, root_on_list = 0; @@ -229,7 +225,7 @@ int res_nsearch(ResState* statp, const char* name, /* domain name */ */ saved_herrno = -1; if (dots >= statp->ndots || trailing_dot) { - ret = res_nquerydomain(statp, name, NULL, cl, type, answer, anslen, herrno); + ret = res_nquerydomain(statp, name, NULL, cl, type, answer, herrno); if (ret > 0 || trailing_dot) return ret; saved_herrno = *herrno; tried_as_is++; @@ -255,7 +251,7 @@ int res_nsearch(ResState* statp, const char* name, /* domain name */ for (const auto& domain : statp->search_domains) { if (domain == "." || domain == "") ++root_on_list; - ret = res_nquerydomain(statp, name, domain.c_str(), cl, type, answer, anslen, herrno); + ret = res_nquerydomain(statp, name, domain.c_str(), cl, type, answer, herrno); if (ret > 0) return ret; /* @@ -301,7 +297,7 @@ int res_nsearch(ResState* statp, const char* name, /* domain name */ // note that we do this regardless of how many dots were in the // name or whether it ends with a dot. if (!tried_as_is && !root_on_list) { - ret = res_nquerydomain(statp, name, NULL, cl, type, answer, anslen, herrno); + ret = res_nquerydomain(statp, name, NULL, cl, type, answer, herrno); if (ret > 0) return ret; } @@ -326,10 +322,9 @@ int res_nsearch(ResState* statp, const char* name, /* domain name */ * removing a trailing dot from name if domain is NULL. */ int res_nquerydomain(ResState* statp, const char* name, const char* domain, int cl, - int type, /* class and type of query */ - uint8_t* answer, /* buffer to put answer */ - int anslen, /* size of answer */ - int* herrno) /* legacy and extended h_errno NETD_RESOLV_H_ERRNO_EXT_* */ + int type, /* class and type of query */ + std::span<uint8_t> answer, /* buffer to put answer */ + int* herrno) /* legacy and extended h_errno NETD_RESOLV_H_ERRNO_EXT_* */ { char nbuf[MAXDNAME]; const char* longname = nbuf; @@ -362,5 +357,5 @@ int res_nquerydomain(ResState* statp, const char* name, const char* domain, int } snprintf(nbuf, sizeof(nbuf), "%s.%s", name, domain); } - return res_nquery(statp, longname, cl, type, answer, anslen, herrno); + return res_nquery(statp, longname, cl, type, answer, herrno); } diff --git a/res_send.cpp b/res_send.cpp index 82677959..c5eb6270 100644 --- a/res_send.cpp +++ b/res_send.cpp @@ -150,19 +150,19 @@ using android::net::PROTO_UDP; using android::netdutils::IPSockAddr; using android::netdutils::Slice; using android::netdutils::Stopwatch; +using std::span; const std::vector<IPSockAddr> mdns_addrs = {IPSockAddr::toIPSockAddr("ff02::fb", 5353), IPSockAddr::toIPSockAddr("224.0.0.251", 5353)}; static int setupUdpSocket(ResState* statp, const sockaddr* sockap, unique_fd* fd_out, int* terrno); -static int send_dg(ResState* statp, res_params* params, const uint8_t* buf, int buflen, - uint8_t* ans, int anssiz, int* terrno, size_t* ns, int* v_circuit, - int* gotsomewhere, time_t* at, int* rcode, int* delay); -static int send_vc(ResState* statp, res_params* params, const uint8_t* buf, int buflen, - uint8_t* ans, int anssiz, int* terrno, size_t ns, time_t* at, int* rcode, - int* delay); -static int send_mdns(ResState* statp, std::span<const uint8_t> buf, uint8_t* ans, int anssiz, - int* terrno, int* rcode); +static int send_dg(ResState* statp, res_params* params, span<const uint8_t> msg, span<uint8_t> ans, + int* terrno, size_t* ns, int* v_circuit, int* gotsomewhere, time_t* at, + int* rcode, int* delay); +static int send_vc(ResState* statp, res_params* params, span<const uint8_t> msg, span<uint8_t> ans, + int* terrno, size_t ns, time_t* at, int* rcode, int* delay); +static int send_mdns(ResState* statp, span<const uint8_t> msg, span<uint8_t> ans, int* terrno, + int* rcode); static void dump_error(const char*, const struct sockaddr*); static int sock_eq(struct sockaddr*, struct sockaddr*); @@ -175,10 +175,10 @@ static int res_tls_send(const std::list<DnsTlsServer>& tlsServers, ResState*, co const Slice answer, int* rcode, PrivateDnsMode mode); static ssize_t res_doh_send(ResState*, const Slice query, const Slice answer, int* rcode); -NsType getQueryType(const uint8_t* msg, size_t msgLen) { +NsType getQueryType(span<const uint8_t> msg) { ns_msg handle; ns_rr rr; - if (ns_initparse((const uint8_t*)msg, msgLen, &handle) < 0 || + if (ns_initparse(msg.data(), msg.size(), &handle) < 0 || ns_parserr(&handle, ns_s_qd, 0, &rr) < 0) { return NS_T_INVALID; } @@ -348,10 +348,10 @@ static int res_ourserver_p(ResState* statp, const sockaddr* sa) { } /* int - * res_nameinquery(name, type, cl, buf, eom) - * look for (name, type, cl) in the query section of packet (buf, eom) + * res_nameinquery(name, type, cl, msg, eom) + * look for (name, type, cl) in the query section of packet (msg, eom) * requires: - * buf + HFIXEDSZ <= eom + * msg + HFIXEDSZ <= eom * returns: * -1 : format error * 0 : not found @@ -359,13 +359,13 @@ static int res_ourserver_p(ResState* statp, const sockaddr* sa) { * author: * paul vixie, 29may94 */ -int res_nameinquery(const char* name, int type, int cl, const uint8_t* buf, const uint8_t* eom) { - const uint8_t* cp = buf + HFIXEDSZ; - int qdcount = ntohs(((const HEADER*) (const void*) buf)->qdcount); +int res_nameinquery(const char* name, int type, int cl, const uint8_t* msg, const uint8_t* eom) { + const uint8_t* cp = msg + HFIXEDSZ; + int qdcount = ntohs(((const HEADER*)(const void*)msg)->qdcount); while (qdcount-- > 0) { char tname[MAXDNAME + 1]; - int n = dn_expand(buf, eom, cp, tname, sizeof tname); + int n = dn_expand(msg, eom, cp, tname, sizeof tname); if (n < 0) return (-1); cp += n; if (cp + 2 * INT16SZ > eom) return (-1); @@ -433,30 +433,29 @@ static bool isNetworkRestricted(int terrno) { return (terrno == EPERM); } -int res_nsend(ResState* statp, const uint8_t* buf, int buflen, uint8_t* ans, int anssiz, int* rcode, +int res_nsend(ResState* statp, span<const uint8_t> msg, span<uint8_t> ans, int* rcode, uint32_t flags, std::chrono::milliseconds sleepTimeMs) { LOG(DEBUG) << __func__; // Should not happen - if (anssiz < HFIXEDSZ) { + if (ans.size() < HFIXEDSZ) { // TODO: Remove errno once callers stop using it errno = EINVAL; return -EINVAL; } - res_pquery({buf, buflen}); + res_pquery(msg); int anslen = 0; Stopwatch cacheStopwatch; - ResolvCacheStatus cache_status = - resolv_cache_lookup(statp->netid, buf, buflen, ans, anssiz, &anslen, flags); + ResolvCacheStatus cache_status = resolv_cache_lookup(statp->netid, msg, ans, &anslen, flags); const int32_t cacheLatencyUs = saturate_cast<int32_t>(cacheStopwatch.timeTakenUs()); if (cache_status == RESOLV_CACHE_FOUND) { - HEADER* hp = (HEADER*)(void*)ans; + HEADER* hp = (HEADER*)(void*)ans.data(); *rcode = hp->rcode; DnsQueryEvent* dnsQueryEvent = addDnsQueryEvent(statp->event); dnsQueryEvent->set_latency_micros(cacheLatencyUs); dnsQueryEvent->set_cache_hit(static_cast<CacheStatus>(cache_status)); - dnsQueryEvent->set_type(getQueryType(buf, buflen)); + dnsQueryEvent->set_type(getQueryType(msg)); return anslen; } else if (cache_status != RESOLV_CACHE_UNSUPPORTED) { // had a cache miss for a known network, so populate the thread private @@ -470,30 +469,29 @@ int res_nsend(ResState* statp, const uint8_t* buf, int buflen, uint8_t* ans, int int terrno = ETIME; int resplen = 0; *rcode = RCODE_INTERNAL_ERROR; - std::span<const uint8_t> buffer(buf, buflen); Stopwatch queryStopwatch; - resplen = send_mdns(statp, buffer, ans, anssiz, &terrno, rcode); + resplen = send_mdns(statp, msg, ans, &terrno, rcode); const IPSockAddr& receivedMdnsAddr = - (getQueryType(buf, buflen) == NS_T_AAAA) ? mdns_addrs[0] : mdns_addrs[1]; + (getQueryType(msg) == NS_T_AAAA) ? mdns_addrs[0] : mdns_addrs[1]; DnsQueryEvent* mDnsQueryEvent = addDnsQueryEvent(statp->event); mDnsQueryEvent->set_cache_hit(static_cast<CacheStatus>(cache_status)); mDnsQueryEvent->set_latency_micros(saturate_cast<int32_t>(queryStopwatch.timeTakenUs())); mDnsQueryEvent->set_ip_version(ipFamilyToIPVersion(receivedMdnsAddr.family())); mDnsQueryEvent->set_rcode(static_cast<NsRcode>(*rcode)); mDnsQueryEvent->set_protocol(PROTO_MDNS); - mDnsQueryEvent->set_type(getQueryType(buf, buflen)); + mDnsQueryEvent->set_type(getQueryType(msg)); mDnsQueryEvent->set_linux_errno(static_cast<LinuxErrno>(terrno)); resolv_stats_add(statp->netid, receivedMdnsAddr, mDnsQueryEvent); if (resplen <= 0) { - _resolv_cache_query_failed(statp->netid, buf, buflen, flags); + _resolv_cache_query_failed(statp->netid, msg, flags); return -terrno; } LOG(DEBUG) << __func__ << ": got answer:"; - res_pquery({ans, (resplen > anssiz) ? anssiz : resplen}); + res_pquery(ans.first(resplen)); if (cache_status == RESOLV_CACHE_NOTFOUND) { - resolv_cache_add(statp->netid, buf, buflen, ans, resplen); + resolv_cache_add(statp->netid, msg, {ans.data(), resplen}); } return resplen; } @@ -503,7 +501,7 @@ int res_nsend(ResState* statp, const uint8_t* buf, int buflen, uint8_t* ans, int // point trying. Tell the cache the query failed, or any retries and anyone else // asking the same question will block for PENDING_REQUEST_TIMEOUT seconds instead // of failing fast. - _resolv_cache_query_failed(statp->netid, buf, buflen, flags); + _resolv_cache_query_failed(statp->netid, msg, flags); // TODO: Remove errno once callers stop using it errno = ESRCH; @@ -513,18 +511,19 @@ int res_nsend(ResState* statp, const uint8_t* buf, int buflen, uint8_t* ans, int // Private DNS if (!(statp->netcontext_flags & NET_CONTEXT_FLAG_USE_LOCAL_NAMESERVERS)) { bool fallback = false; - int resplen = res_private_dns_send(statp, Slice(const_cast<uint8_t*>(buf), buflen), - Slice(ans, anssiz), rcode, &fallback); + int resplen = + res_private_dns_send(statp, Slice(const_cast<uint8_t*>(msg.data()), msg.size()), + Slice(ans.data(), ans.size()), rcode, &fallback); if (resplen > 0) { LOG(DEBUG) << __func__ << ": got answer from Private DNS"; - res_pquery({ans, resplen}); + res_pquery(ans.first(resplen)); if (cache_status == RESOLV_CACHE_NOTFOUND) { - resolv_cache_add(statp->netid, buf, buflen, ans, resplen); + resolv_cache_add(statp->netid, msg, ans.first(resplen)); } return resplen; } if (!fallback) { - _resolv_cache_query_failed(statp->netid, buf, buflen, flags); + _resolv_cache_query_failed(statp->netid, msg, flags); return -ETIMEDOUT; } } @@ -558,7 +557,7 @@ int res_nsend(ResState* statp, const uint8_t* buf, int buflen, uint8_t* ans, int // TODO: Let it always choose the first nameserver when sort_nameservers is enabled. if ((flags & ANDROID_RESOLV_NO_RETRY) && usableServersCount > 1) { - auto hp = reinterpret_cast<const HEADER*>(buf); + auto hp = reinterpret_cast<const HEADER*>(msg.data()); // Select a random server based on the query id int selectedServer = (hp->id % usableServersCount) + 1; @@ -567,7 +566,7 @@ int res_nsend(ResState* statp, const uint8_t* buf, int buflen, uint8_t* ans, int // Send request, RETRY times, or until successful. int retryTimes = (flags & ANDROID_RESOLV_NO_RETRY) ? 1 : params.retry_count; - int useTcp = buflen > PACKETSZ; + int useTcp = msg.size() > PACKETSZ; int gotsomewhere = 0; // Use an impossible error code as default value @@ -598,10 +597,10 @@ int res_nsend(ResState* statp, const uint8_t* buf, int buflen, uint8_t* ans, int if (useTcp) { // TCP; at most one attempt per server. attempt = retryTimes; - resplen = send_vc(statp, ¶ms, buf, buflen, ans, anssiz, &terrno, ns, - &query_time, rcode, &delay); + resplen = + send_vc(statp, ¶ms, msg, ans, &terrno, ns, &query_time, rcode, &delay); - if (buflen <= PACKETSZ && resplen <= 0 && + if (msg.size() <= PACKETSZ && resplen <= 0 && statp->tc_mode == aidl::android::net::IDnsResolver::TC_MODE_UDP_TCP) { // reset to UDP for next query on next DNS server if resolver is currently doing // TCP fallback retry and current server does not support TCP connectin @@ -610,8 +609,8 @@ int res_nsend(ResState* statp, const uint8_t* buf, int buflen, uint8_t* ans, int LOG(INFO) << __func__ << ": used send_vc " << resplen << " terrno: " << terrno; } else { // UDP - resplen = send_dg(statp, ¶ms, buf, buflen, ans, anssiz, &terrno, &actualNs, - &useTcp, &gotsomewhere, &query_time, rcode, &delay); + resplen = send_dg(statp, ¶ms, msg, ans, &terrno, &actualNs, &useTcp, + &gotsomewhere, &query_time, rcode, &delay); fallbackTCP = useTcp ? true : false; retry_count_for_event = attempt; LOG(INFO) << __func__ << ": used send_dg " << resplen << " terrno: " << terrno; @@ -631,7 +630,7 @@ int res_nsend(ResState* statp, const uint8_t* buf, int buflen, uint8_t* ans, int dnsQueryEvent->set_retry_times(retry_count_for_event); dnsQueryEvent->set_rcode(static_cast<NsRcode>(*rcode)); dnsQueryEvent->set_protocol(query_proto); - dnsQueryEvent->set_type(getQueryType(buf, buflen)); + dnsQueryEvent->set_type(getQueryType(msg)); dnsQueryEvent->set_linux_errno(static_cast<LinuxErrno>(terrno)); // Only record stats the first time we try a query. This ensures that @@ -660,16 +659,16 @@ int res_nsend(ResState* statp, const uint8_t* buf, int buflen, uint8_t* ans, int continue; } if (resplen < 0) { - _resolv_cache_query_failed(statp->netid, buf, buflen, flags); + _resolv_cache_query_failed(statp->netid, msg, flags); statp->closeSockets(); return -terrno; } LOG(DEBUG) << __func__ << ": got answer:"; - res_pquery({ans, (resplen > anssiz) ? anssiz : resplen}); + res_pquery(ans.first(resplen)); if (cache_status == RESOLV_CACHE_NOTFOUND) { - resolv_cache_add(statp->netid, buf, buflen, ans, resplen); + resolv_cache_add(statp->netid, msg, {ans.data(), resplen}); } statp->closeSockets(); return (resplen); @@ -682,7 +681,7 @@ int res_nsend(ResState* statp, const uint8_t* buf, int buflen, uint8_t* ans, int : gotsomewhere ? ETIMEDOUT /* no answer obtained */ : ECONNREFUSED /* no nameservers found */; - _resolv_cache_query_failed(statp->netid, buf, buflen, flags); + _resolv_cache_query_failed(statp->netid, msg, flags); return -terrno; } @@ -707,13 +706,12 @@ static struct timespec get_timeout(ResState* statp, const res_params* params, co return result; } -static int send_vc(ResState* statp, res_params* params, const uint8_t* buf, int buflen, - uint8_t* ans, int anssiz, int* terrno, size_t ns, time_t* at, int* rcode, - int* delay) { +static int send_vc(ResState* statp, res_params* params, span<const uint8_t> msg, span<uint8_t> ans, + int* terrno, size_t ns, time_t* at, int* rcode, int* delay) { *at = time(NULL); *delay = 0; - const HEADER* hp = (const HEADER*) (const void*) buf; - HEADER* anhp = (HEADER*) (void*) ans; + const HEADER* hp = (const HEADER*)(const void*)msg.data(); + HEADER* anhp = (HEADER*)(void*)ans.data(); struct sockaddr* nsap; int nsaplen; int truncating, connreset, n; @@ -807,12 +805,13 @@ same_ns: /* * Send length & message */ - uint16_t len = htons(static_cast<uint16_t>(buflen)); + uint16_t len = htons(static_cast<uint16_t>(msg.size())); const iovec iov[] = { {.iov_base = &len, .iov_len = INT16SZ}, - {.iov_base = const_cast<uint8_t*>(buf), .iov_len = static_cast<size_t>(buflen)}, + {.iov_base = const_cast<uint8_t*>(msg.data()), + .iov_len = static_cast<size_t>(msg.size())}, }; - if (writev(statp->tcp_nssock, iov, 2) != (INT16SZ + buflen)) { + if (writev(statp->tcp_nssock, iov, 2) != (INT16SZ + msg.size())) { *terrno = errno; PLOG(DEBUG) << __func__ << ": write failed: "; statp->closeSockets(); @@ -822,7 +821,7 @@ same_ns: * Receive length & response */ read_len: - cp = ans; + cp = ans.data(); len = INT16SZ; while ((n = read(statp->tcp_nssock, (char*)cp, (size_t)len)) > 0) { cp += n; @@ -847,11 +846,11 @@ read_len: } return (0); } - uint16_t resplen = ntohs(*reinterpret_cast<const uint16_t*>(ans)); - if (resplen > anssiz) { + uint16_t resplen = ntohs(*reinterpret_cast<const uint16_t*>(ans.data())); + if (resplen > ans.size()) { LOG(DEBUG) << __func__ << ": response truncated"; truncating = 1; - len = anssiz; + len = ans.size(); } else len = resplen; if (len < HFIXEDSZ) { @@ -863,7 +862,7 @@ read_len: statp->closeSockets(); return (0); } - cp = ans; + cp = ans.data(); while (len != 0 && (n = read(statp->tcp_nssock, (char*)cp, (size_t)len)) > 0) { cp += n; len -= n; @@ -880,7 +879,7 @@ read_len: * Flush rest of answer so connection stays in synch. */ anhp->tc = 1; - len = resplen - anssiz; + len = resplen - ans.size(); while (len != 0) { char junk[PACKETSZ]; @@ -890,9 +889,9 @@ read_len: else break; } - LOG(WARNING) << __func__ << ": resplen " << resplen << " exceeds buf size " << anssiz; + LOG(WARNING) << __func__ << ": resplen " << resplen << " exceeds buf size " << ans.size(); // return size should never exceed container size - resplen = anssiz; + resplen = ans.size(); } /* * If the calling application has bailed out of @@ -903,7 +902,7 @@ read_len: */ if (hp->id != anhp->id) { LOG(DEBUG) << __func__ << ": ld answer (unexpected):"; - res_pquery({ans, resplen}); + res_pquery({ans.data(), resplen}); goto read_len; } @@ -1030,10 +1029,10 @@ static Result<std::vector<int>> udpRetryingPollWrapper(ResState* statp, int addr return std::vector<int>{statp->udpsocks[addrInfo]}; } -bool ignoreInvalidAnswer(ResState* statp, const sockaddr_storage& from, const uint8_t* buf, - int buflen, uint8_t* ans, int anssiz, int* receivedFromNs) { - const HEADER* hp = (const HEADER*)(const void*)buf; - HEADER* anhp = (HEADER*)(void*)ans; +bool ignoreInvalidAnswer(ResState* statp, const sockaddr_storage& from, span<const uint8_t> msg, + span<uint8_t> ans, int* receivedFromNs) { + const HEADER* hp = (const HEADER*)(const void*)msg.data(); + HEADER* anhp = (HEADER*)(void*)ans.data(); if (hp->id != anhp->id) { // response from old query, ignore it. LOG(DEBUG) << __func__ << ": old answer:"; @@ -1044,7 +1043,8 @@ bool ignoreInvalidAnswer(ResState* statp, const sockaddr_storage& from, const ui LOG(DEBUG) << __func__ << ": not our server:"; return true; } - if (!res_queriesmatch(buf, buf + buflen, ans, ans + anssiz)) { + if (!res_queriesmatch(msg.data(), msg.data() + msg.size(), ans.data(), + ans.data() + ans.size())) { // response contains wrong query? ignore it. LOG(DEBUG) << __func__ << ": wrong query name:"; return true; @@ -1088,9 +1088,9 @@ static int setupUdpSocket(ResState* statp, const sockaddr* sockap, unique_fd* fd return 1; } -static int send_dg(ResState* statp, res_params* params, const uint8_t* buf, int buflen, - uint8_t* ans, int anssiz, int* terrno, size_t* ns, int* v_circuit, - int* gotsomewhere, time_t* at, int* rcode, int* delay) { +static int send_dg(ResState* statp, res_params* params, span<const uint8_t> msg, span<uint8_t> ans, + int* terrno, size_t* ns, int* v_circuit, int* gotsomewhere, time_t* at, + int* rcode, int* delay) { // It should never happen, but just in case. if (*ns >= statp->nsaddrs.size()) { LOG(ERROR) << __func__ << ": Out-of-bound indexing: " << ns; @@ -1119,7 +1119,7 @@ static int send_dg(ResState* statp, res_params* params, const uint8_t* buf, int } LOG(DEBUG) << __func__ << ": new DG socket"; } - if (send(statp->udpsocks[*ns], (const char*)buf, (size_t)buflen, 0) != buflen) { + if (send(statp->udpsocks[*ns], msg.data(), msg.size(), 0) != msg.size()) { *terrno = errno; PLOG(DEBUG) << __func__ << ": send: "; statp->closeSockets(); @@ -1150,7 +1150,7 @@ static int send_dg(ResState* statp, res_params* params, const uint8_t* buf, int sockaddr_storage from; socklen_t fromlen = sizeof(from); int resplen = - recvfrom(fd, (char*)ans, (size_t)anssiz, 0, (sockaddr*)(void*)&from, &fromlen); + recvfrom(fd, ans.data(), ans.size(), 0, (sockaddr*)(void*)&from, &fromlen); if (resplen <= 0) { *terrno = errno; PLOG(DEBUG) << __func__ << ": recvfrom: "; @@ -1165,20 +1165,19 @@ static int send_dg(ResState* statp, res_params* params, const uint8_t* buf, int } int receivedFromNs = *ns; - if (needRetry = - ignoreInvalidAnswer(statp, from, buf, buflen, ans, anssiz, &receivedFromNs); + if (needRetry = ignoreInvalidAnswer(statp, from, msg, ans, &receivedFromNs); needRetry) { - res_pquery({ans, (resplen > anssiz) ? anssiz : resplen}); + res_pquery({ans.data(), (resplen > ans.size()) ? ans.size() : resplen}); continue; } - HEADER* anhp = (HEADER*)(void*)ans; + HEADER* anhp = (HEADER*)(void*)ans.data(); if (anhp->rcode == FORMERR && (statp->netcontext_flags & NET_CONTEXT_FLAG_USE_EDNS)) { // Do not retry if the server do not understand EDNS0. // The case has to be captured here, as FORMERR packet do not // carry query section, hence res_queriesmatch() returns 0. LOG(DEBUG) << __func__ << ": server rejected query with EDNS0:"; - res_pquery({ans, (resplen > anssiz) ? anssiz : resplen}); + res_pquery({ans.data(), (resplen > ans.size()) ? ans.size() : resplen}); // record the error statp->flags |= RES_F_EDNS0ERR; *terrno = EREMOTEIO; @@ -1189,7 +1188,7 @@ static int send_dg(ResState* statp, res_params* params, const uint8_t* buf, int *delay = res_stats_calculate_rtt(&done, &start_time); if (anhp->rcode == SERVFAIL || anhp->rcode == NOTIMP || anhp->rcode == REFUSED) { LOG(DEBUG) << __func__ << ": server rejected query:"; - res_pquery({ans, (resplen > anssiz) ? anssiz : resplen}); + res_pquery({ans.data(), (resplen > ans.size()) ? ans.size() : resplen}); *rcode = anhp->rcode; continue; } @@ -1215,16 +1214,15 @@ static int send_dg(ResState* statp, res_params* params, const uint8_t* buf, int // return length - when receiving valid packets. // return 0 - when mdns packets transfer error. -static int send_mdns(ResState* statp, std::span<const uint8_t> buf, uint8_t* ans, int anssiz, - int* terrno, int* rcode) { - const sockaddr_storage ss = - (getQueryType(buf.data(), buf.size()) == NS_T_AAAA) ? mdns_addrs[0] : mdns_addrs[1]; +static int send_mdns(ResState* statp, span<const uint8_t> msg, span<uint8_t> ans, int* terrno, + int* rcode) { + const sockaddr_storage ss = (getQueryType(msg) == NS_T_AAAA) ? mdns_addrs[0] : mdns_addrs[1]; const sockaddr* mdnsap = reinterpret_cast<const sockaddr*>(&ss); unique_fd fd; if (setupUdpSocket(statp, mdnsap, &fd, terrno) <= 0) return 0; - if (sendto(fd, buf.data(), buf.size(), 0, mdnsap, sockaddrSize(mdnsap)) != buf.size()) { + if (sendto(fd, msg.data(), msg.size(), 0, mdnsap, sockaddrSize(mdnsap)) != msg.size()) { *terrno = errno; return 0; } @@ -1241,7 +1239,7 @@ static int send_mdns(ResState* statp, std::span<const uint8_t> buf, uint8_t* ans sockaddr_storage from; socklen_t fromlen = sizeof(from); - int resplen = recvfrom(fd, (char*)ans, (size_t)anssiz, 0, (sockaddr*)(void*)&from, &fromlen); + int resplen = recvfrom(fd, ans.data(), ans.size(), 0, (sockaddr*)(void*)&from, &fromlen); if (resplen <= 0) { *terrno = errno; @@ -1255,7 +1253,7 @@ static int send_mdns(ResState* statp, std::span<const uint8_t> buf, uint8_t* ans return 0; } - HEADER* anhp = (HEADER*)(void*)ans; + HEADER* anhp = (HEADER*)(void*)ans.data(); if (anhp->tc) { LOG(DEBUG) << __func__ << ": truncated answer"; *terrno = E2BIG; @@ -1415,7 +1413,8 @@ ssize_t res_doh_send(ResState* statp, const Slice query, const Slice answer, int } dnsQueryEvent->set_rcode(static_cast<NsRcode>(*rcode)); dnsQueryEvent->set_protocol(PROTO_DOH); - dnsQueryEvent->set_type(getQueryType(query.base(), query.size())); + span<const uint8_t> msg(query.base(), query.size()); + dnsQueryEvent->set_type(getQueryType(msg)); return result; } @@ -1463,12 +1462,12 @@ int res_tls_send(const std::list<DnsTlsServer>& tlsServers, ResState* statp, con } } -int resolv_res_nsend(const android_net_context* netContext, const uint8_t* msg, int msgLen, - uint8_t* ans, int ansLen, int* rcode, uint32_t flags, +int resolv_res_nsend(const android_net_context* netContext, span<const uint8_t> msg, + span<uint8_t> ans, int* rcode, uint32_t flags, NetworkDnsEventReported* event) { assert(event != nullptr); ResState res(netContext, event); resolv_populate_res_for_net(&res); *rcode = NOERROR; - return res_nsend(&res, msg, msgLen, ans, ansLen, rcode, flags); + return res_nsend(&res, msg, ans, rcode, flags); } @@ -16,10 +16,12 @@ #pragma once +#include <span> + #include "netd_resolv/resolv.h" // struct android_net_context #include "stats.pb.h" // Query dns with raw msg -int resolv_res_nsend(const android_net_context* netContext, const uint8_t* msg, int msgLen, - uint8_t* ans, int ansLen, int* rcode, uint32_t flags, +int resolv_res_nsend(const android_net_context* netContext, std::span<const uint8_t> msg, + std::span<uint8_t> ans, int* rcode, uint32_t flags, android::net::NetworkDnsEventReported* event); diff --git a/resolv_cache.h b/resolv_cache.h index 7fba618b..9b267915 100644 --- a/resolv_cache.h +++ b/resolv_cache.h @@ -28,6 +28,7 @@ #pragma once +#include <span> #include <unordered_map> #include <vector> @@ -61,16 +62,16 @@ typedef enum { RESOLV_CACHE_SKIP /* Don't do anything on cache */ } ResolvCacheStatus; -ResolvCacheStatus resolv_cache_lookup(unsigned netid, const void* query, int querylen, void* answer, - int answersize, int* answerlen, uint32_t flags); +ResolvCacheStatus resolv_cache_lookup(unsigned netid, std::span<const uint8_t> query, + std::span<uint8_t> answer, int* answerlen, uint32_t flags); // add a (query,answer) to the cache. If the pair has been in the cache, no new entry will be added // in the cache. -int resolv_cache_add(unsigned netid, const void* query, int querylen, const void* answer, - int answerlen); +int resolv_cache_add(unsigned netid, std::span<const uint8_t> query, + std::span<const uint8_t> answer); /* Notify the cache a request failed */ -void _resolv_cache_query_failed(unsigned netid, const void* query, int querylen, uint32_t flags); +void _resolv_cache_query_failed(unsigned netid, std::span<const uint8_t> query, uint32_t flags); // Get a customized table for a given network. std::vector<std::string> getCustomizedTableByName(const size_t netid, const char* hostname); @@ -103,7 +104,7 @@ bool has_named_cache(unsigned netid); // For test only. // Get the expiration time of a cache entry. Return 0 on success; otherwise, an negative error is // returned if the expiration time can't be acquired. -int resolv_cache_get_expiration(unsigned netid, const std::vector<char>& query, time_t* expiration); +int resolv_cache_get_expiration(unsigned netid, std::span<const uint8_t> query, time_t* expiration); // Set addresses to DnsStats for a given network. int resolv_stats_set_addrs(unsigned netid, android::net::Protocol proto, diff --git a/resolv_private.h b/resolv_private.h index a35dd5da..31c541f4 100644 --- a/resolv_private.h +++ b/resolv_private.h @@ -54,6 +54,7 @@ #include <net/if.h> #include <time.h> +#include <span> #include <string> #include <vector> @@ -168,14 +169,14 @@ extern const char* const _res_opcodes[]; int res_nameinquery(const char*, int, int, const uint8_t*, const uint8_t*); int res_queriesmatch(const uint8_t*, const uint8_t*, const uint8_t*, const uint8_t*); -int res_nquery(ResState*, const char*, int, int, uint8_t*, int, int*); -int res_nsearch(ResState*, const char*, int, int, uint8_t*, int, int*); -int res_nquerydomain(ResState*, const char*, const char*, int, int, uint8_t*, int, int*); -int res_nmkquery(int op, const char* qname, int cl, int type, const uint8_t* data, int datalen, - uint8_t* buf, int buflen, int netcontext_flags); -int res_nsend(ResState* statp, const uint8_t* buf, int buflen, uint8_t* ans, int anssiz, int* rcode, +int res_nquery(ResState*, const char*, int, int, std::span<uint8_t>, int*); +int res_nsearch(ResState*, const char*, int, int, std::span<uint8_t>, int*); +int res_nquerydomain(ResState*, const char*, const char*, int, int, std::span<uint8_t>, int*); +int res_nmkquery(int op, const char* qname, int cl, int type, std::span<const uint8_t> data, + std::span<uint8_t> msg, int netcontext_flags); +int res_nsend(ResState* statp, std::span<const uint8_t> msg, std::span<uint8_t> ans, int* rcode, uint32_t flags, std::chrono::milliseconds sleepTimeMs = {}); -int res_nopt(ResState*, int, uint8_t*, int, int); +int res_nopt(ResState*, int, std::span<uint8_t>, int); int getaddrinfo_numeric(const char* hostname, const char* servname, addrinfo hints, addrinfo** result); @@ -227,7 +228,7 @@ constexpr T* align_ptr(T* const p) { // static_assert(align_ptr<sizeof(uint32_t)>((char*)1004) == (char*)1004); // static_assert(align_ptr<sizeof(uint64_t)>((char*)1004) == (char*)1008); -android::net::NsType getQueryType(const uint8_t* msg, size_t msgLen); +android::net::NsType getQueryType(std::span<const uint8_t> msg); android::net::IpVersion ipFamilyToIPVersion(int ipFamily); diff --git a/tests/resolv_cache_unit_test.cpp b/tests/resolv_cache_unit_test.cpp index a9a0f721..050964e0 100644 --- a/tests/resolv_cache_unit_test.cpp +++ b/tests/resolv_cache_unit_test.cpp @@ -20,6 +20,7 @@ #include <atomic> #include <chrono> #include <ctime> +#include <span> #include <thread> #include <android-base/logging.h> @@ -49,8 +50,8 @@ constexpr int MAX_ENTRIES = 64 * 2 * 5; namespace { struct CacheEntry { - std::vector<char> query; - std::vector<char> answer; + std::vector<uint8_t> query; + std::vector<uint8_t> answer; }; struct SetupParams { @@ -67,18 +68,17 @@ struct CacheStats { int pendingReqTimeoutCount; }; -std::vector<char> makeQuery(int op, const char* qname, int qclass, int qtype) { +std::vector<uint8_t> makeQuery(int op, const char* qname, int qclass, int qtype) { uint8_t buf[MAXPACKET] = {}; - const int len = res_nmkquery(op, qname, qclass, qtype, /*data=*/nullptr, /*datalen=*/0, buf, - sizeof(buf), - /*netcontext_flags=*/0); - return std::vector<char>(buf, buf + len); + const int len = res_nmkquery(op, qname, qclass, qtype, {}, buf, /*netcontext_flags=*/0); + return std::vector<uint8_t>(buf, buf + len); } -std::vector<char> makeAnswer(const std::vector<char>& query, const char* rdata_str, - const unsigned ttl) { +std::vector<uint8_t> makeAnswer(const std::vector<uint8_t>& query, const char* rdata_str, + const unsigned ttl) { test::DNSHeader header; - header.read(query.data(), query.data() + query.size()); + header.read(reinterpret_cast<const char*>(query.data()), + reinterpret_cast<const char*>(query.data()) + query.size()); for (const test::DNSQuestion& question : header.questions) { std::string rname(question.qname.name); @@ -94,7 +94,7 @@ std::vector<char> makeAnswer(const std::vector<char>& query, const char* rdata_s char answer[MAXPACKET] = {}; char* answer_end = header.write(answer, answer + sizeof(answer)); - return std::vector<char>(answer, answer_end); + return std::vector<uint8_t>(answer, answer_end); } // Get the current time in unix timestamp since the Epoch. @@ -155,9 +155,8 @@ class ResolvCacheTest : public ::testing::Test { [[nodiscard]] bool cacheLookup(ResolvCacheStatus expectedCacheStatus, uint32_t netId, const CacheEntry& ce, uint32_t flags = 0) { int anslen = 0; - std::vector<char> answer(MAXPACKET); - const auto cacheStatus = resolv_cache_lookup(netId, ce.query.data(), ce.query.size(), - answer.data(), answer.size(), &anslen, flags); + std::vector<uint8_t> answer(MAXPACKET); + const auto cacheStatus = resolv_cache_lookup(netId, ce.query, answer, &anslen, flags); if (cacheStatus != expectedCacheStatus) { ADD_FAILURE() << "cacheStatus: expected = " << expectedCacheStatus << ", actual =" << cacheStatus; @@ -183,20 +182,20 @@ class ResolvCacheTest : public ::testing::Test { } int cacheAdd(uint32_t netId, const CacheEntry& ce) { - return resolv_cache_add(netId, ce.query.data(), ce.query.size(), ce.answer.data(), - ce.answer.size()); + return resolv_cache_add(netId, ce.query, ce.answer); } - int cacheAdd(uint32_t netId, const std::vector<char>& query, const std::vector<char>& answer) { - return resolv_cache_add(netId, query.data(), query.size(), answer.data(), answer.size()); + int cacheAdd(uint32_t netId, const std::vector<uint8_t>& query, + const std::vector<uint8_t>& answer) { + return resolv_cache_add(netId, query, answer); } - int cacheGetExpiration(uint32_t netId, const std::vector<char>& query, time_t* expiration) { + int cacheGetExpiration(uint32_t netId, const std::vector<uint8_t>& query, time_t* expiration) { return resolv_cache_get_expiration(netId, query, expiration); } void cacheQueryFailed(uint32_t netId, const CacheEntry& ce, uint32_t flags) { - _resolv_cache_query_failed(netId, ce.query.data(), ce.query.size(), flags); + _resolv_cache_query_failed(netId, ce.query, flags); } int cacheSetupResolver(uint32_t netId, const SetupParams& setup) { @@ -284,8 +283,8 @@ TEST_F(ResolvCacheTest, CreateAndDeleteCache) { TEST_F(ResolvCacheTest, CacheAdd_InvalidArgs) { EXPECT_EQ(0, cacheCreate(TEST_NETID)); - const std::vector<char> queryEmpty(MAXPACKET, 0); - const std::vector<char> queryTooSmall(DNS_HEADER_SIZE - 1, 0); + const std::vector<uint8_t> queryEmpty(MAXPACKET, 0); + const std::vector<uint8_t> queryTooSmall(DNS_HEADER_SIZE - 1, 0); CacheEntry ce = makeCacheEntry(QUERY, "valid.cache", ns_c_in, ns_t_a, "1.2.3.4"); EXPECT_EQ(-EINVAL, cacheAdd(TEST_NETID, queryEmpty, ce.answer)); @@ -395,15 +394,14 @@ TEST_F(ResolvCacheTest, CacheLookup_Types) { TEST_F(ResolvCacheTest, CacheLookup_InvalidArgs) { EXPECT_EQ(0, cacheCreate(TEST_NETID)); - const std::vector<char> queryEmpty(MAXPACKET, 0); - const std::vector<char> queryTooSmall(DNS_HEADER_SIZE - 1, 0); - std::vector<char> answerTooSmall(DNS_HEADER_SIZE - 1, 0); + const std::vector<uint8_t> queryEmpty(MAXPACKET, 0); + const std::vector<uint8_t> queryTooSmall(DNS_HEADER_SIZE - 1, 0); + std::vector<uint8_t> answerTooSmall(DNS_HEADER_SIZE - 1, 0); const CacheEntry ce = makeCacheEntry(QUERY, "valid.cache", ns_c_in, ns_t_a, "1.2.3.4"); - auto cacheLookupFn = [](const std::vector<char>& query, - std::vector<char> answer) -> ResolvCacheStatus { + auto cacheLookupFn = [](const std::vector<uint8_t>& query, + std::vector<uint8_t> answer) -> ResolvCacheStatus { int anslen = 0; - return resolv_cache_lookup(TEST_NETID, query.data(), query.size(), answer.data(), - answer.size(), &anslen, 0); + return resolv_cache_lookup(TEST_NETID, query, answer, &anslen, 0); }; EXPECT_EQ(0, cacheAdd(TEST_NETID, ce)); |
