aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorBruce Chen <chenbruce@google.com>2021-09-03 02:38:04 +0000
committerAutomerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com>2021-09-03 02:38:04 +0000
commitb121cbf0355ae9d6dc1aba11ac2dc107cebc8ab4 (patch)
tree1b85f87ab0a6e6e20c013d08565e8f5095802ef7
parent5d7251433327be0304a6075c25ddb9a77c41ebb6 (diff)
parent094d9ab72cc2dd9c4d11e31a27ffc6bab0cf9733 (diff)
downloadplatform_packages_modules_DnsResolver-b121cbf0355ae9d6dc1aba11ac2dc107cebc8ab4.tar.gz
platform_packages_modules_DnsResolver-b121cbf0355ae9d6dc1aba11ac2dc107cebc8ab4.tar.bz2
platform_packages_modules_DnsResolver-b121cbf0355ae9d6dc1aba11ac2dc107cebc8ab4.zip
Merge "Replace manual buffer handling with std::span" am: 094d9ab72c
Original change: https://android-review.googlesource.com/c/platform/packages/modules/DnsResolver/+/1790127 Change-Id: I35c2c70c5af1eb311f50b0afa530d09bb4d110bd
-rw-r--r--DnsProxyListener.cpp45
-rw-r--r--DnsTlsDispatcher.cpp3
-rw-r--r--getaddrinfo.cpp21
-rw-r--r--gethnamaddr.cpp4
-rw-r--r--res_cache.cpp47
-rw-r--r--res_mkquery.cpp58
-rw-r--r--res_query.cpp33
-rw-r--r--res_send.cpp185
-rw-r--r--res_send.h6
-rw-r--r--resolv_cache.h13
-rw-r--r--resolv_private.h17
-rw-r--r--tests/resolv_cache_unit_test.cpp56
12 files changed, 237 insertions, 251 deletions
diff --git a/DnsProxyListener.cpp b/DnsProxyListener.cpp
index db23cfc2..d6f90ccd 100644
--- a/DnsProxyListener.cpp
+++ b/DnsProxyListener.cpp
@@ -39,7 +39,6 @@
#include <cutils/multiuser.h>
#include <netdutils/InternetAddresses.h>
#include <netdutils/ResponseCode.h>
-#include <netdutils/Slice.h>
#include <netdutils/Stopwatch.h>
#include <netdutils/ThreadUtil.h>
#include <private/android_filesystem_config.h> // AID_SYSTEM
@@ -65,6 +64,7 @@ using aidl::android::net::metrics::INetdEventListener;
using aidl::android::net::resolv::aidl::DnsHealthEventParcel;
using aidl::android::net::resolv::aidl::IDnsResolverUnsolicitedEventListener;
using android::net::NetworkDnsEventReported;
+using std::span;
namespace android {
@@ -147,11 +147,11 @@ void maybeFixupNetContext(android_net_context* ctx, pid_t pid) {
void addIpAddrWithinLimit(std::vector<std::string>* ip_addrs, const sockaddr* addr,
socklen_t addrlen);
-int extractResNsendAnswers(const uint8_t* answer, size_t anslen, int ipType,
+int extractResNsendAnswers(std::span<const uint8_t> answer, int ipType,
std::vector<std::string>* ip_addrs) {
int total_ip_addr_count = 0;
ns_msg handle;
- if (ns_initparse((const uint8_t*)answer, anslen, &handle) < 0) {
+ if (ns_initparse(answer.data(), answer.size(), &handle) < 0) {
return 0;
}
int ancount = ns_msg_count(handle, ns_s_an);
@@ -250,21 +250,20 @@ bool simpleStrtoul(const char* input, IntegralType* output, int base = 10) {
return true;
}
-bool setQueryId(uint8_t* msg, size_t msgLen, uint16_t query_id) {
- if (msgLen < sizeof(HEADER)) {
+bool setQueryId(span<uint8_t> msg, uint16_t query_id) {
+ if ((size_t)msg.size() < sizeof(HEADER)) {
errno = EINVAL;
return false;
}
- auto hp = reinterpret_cast<HEADER*>(msg);
+ auto hp = reinterpret_cast<HEADER*>(msg.data());
hp->id = htons(query_id);
return true;
}
-bool parseQuery(const uint8_t* msg, size_t msgLen, uint16_t* query_id, int* rr_type,
- std::string* rr_name) {
+bool parseQuery(span<const uint8_t> msg, uint16_t* query_id, int* rr_type, std::string* rr_name) {
ns_msg handle;
ns_rr rr;
- if (ns_initparse((const uint8_t*)msg, msgLen, &handle) < 0 ||
+ if (ns_initparse(msg.data(), msg.size(), &handle) < 0 ||
ns_parserr(&handle, ns_s_qd, 0, &rr) < 0) {
return false;
}
@@ -927,8 +926,8 @@ void DnsProxyListener::ResNSendHandler::run() {
uint16_t original_query_id = 0;
// TODO: Handle the case which is msg contains more than one query
- if (!parseQuery(msg.data(), msgLen, &original_query_id, &rr_type, &rr_name) ||
- !setQueryId(msg.data(), msgLen, arc4random_uniform(65536))) {
+ if (!parseQuery({msg.data(), msgLen}, &original_query_id, &rr_type, &rr_name) ||
+ !setQueryId({msg.data(), msgLen}, arc4random_uniform(65536))) {
// If the query couldn't be parsed, block the request.
LOG(WARNING) << "ResNSendHandler::run: resnsend: from UID " << uid << ", invalid query";
sendBE32(mClient, -EINVAL);
@@ -938,21 +937,21 @@ void DnsProxyListener::ResNSendHandler::run() {
// Send DNS query
std::vector<uint8_t> ansBuf(MAXPACKET, 0);
int rcode = ns_r_noerror;
- int nsendAns = -1;
+ int ansLen = -1;
NetworkDnsEventReported event;
initDnsEvent(&event, mNetContext);
if (queryLimiter.start(uid)) {
if (evaluate_domain_name(mNetContext, rr_name.c_str())) {
- nsendAns = resolv_res_nsend(&mNetContext, msg.data(), msgLen, ansBuf.data(), MAXPACKET,
- &rcode, static_cast<ResNsendFlags>(mFlags), &event);
+ ansLen = resolv_res_nsend(&mNetContext, {msg.data(), msgLen}, ansBuf, &rcode,
+ static_cast<ResNsendFlags>(mFlags), &event);
} else {
- nsendAns = -EAI_SYSTEM;
+ ansLen = -EAI_SYSTEM;
}
queryLimiter.finish(uid);
} else {
LOG(WARNING) << "ResNSendHandler::run: resnsend: from UID " << uid
<< ", max concurrent queries reached";
- nsendAns = -EBUSY;
+ ansLen = -EBUSY;
}
const int32_t latencyUs = saturate_cast<int32_t>(s.timeTakenUs());
@@ -961,14 +960,14 @@ void DnsProxyListener::ResNSendHandler::run() {
event.set_res_nsend_flags(static_cast<ResNsendFlags>(mFlags));
// Fail, send -errno
- if (nsendAns < 0) {
- if (!sendBE32(mClient, nsendAns)) {
+ if (ansLen < 0) {
+ if (!sendBE32(mClient, ansLen)) {
PLOG(WARNING) << "ResNSendHandler::run: resnsend: failed to send errno to uid " << uid
<< " pid " << mClient->getPid();
}
if (rr_type == ns_t_a || rr_type == ns_t_aaaa) {
reportDnsEvent(INetdEventListener::EVENT_RES_NSEND, mNetContext, latencyUs,
- resNSendToAiError(nsendAns, rcode), event, rr_name);
+ resNSendToAiError(ansLen, rcode), event, rr_name);
}
return;
}
@@ -981,8 +980,8 @@ void DnsProxyListener::ResNSendHandler::run() {
}
// Restore query id and send answer
- if (!setQueryId(ansBuf.data(), nsendAns, original_query_id) ||
- !sendLenAndData(mClient, nsendAns, ansBuf.data())) {
+ if (!setQueryId({ansBuf.data(), ansLen}, original_query_id) ||
+ !sendLenAndData(mClient, ansLen, ansBuf.data())) {
PLOG(WARNING) << "ResNSendHandler::run: resnsend: failed to send answer to uid " << uid
<< " pid " << mClient->getPid();
return;
@@ -991,9 +990,9 @@ void DnsProxyListener::ResNSendHandler::run() {
if (rr_type == ns_t_a || rr_type == ns_t_aaaa) {
std::vector<std::string> ip_addrs;
const int total_ip_addr_count =
- extractResNsendAnswers((uint8_t*)ansBuf.data(), nsendAns, rr_type, &ip_addrs);
+ extractResNsendAnswers({ansBuf.data(), ansLen}, rr_type, &ip_addrs);
reportDnsEvent(INetdEventListener::EVENT_RES_NSEND, mNetContext, latencyUs,
- resNSendToAiError(nsendAns, rcode), event, rr_name, ip_addrs,
+ resNSendToAiError(ansLen, rcode), event, rr_name, ip_addrs,
total_ip_addr_count);
}
}
diff --git a/DnsTlsDispatcher.cpp b/DnsTlsDispatcher.cpp
index 67f98576..b6f766bb 100644
--- a/DnsTlsDispatcher.cpp
+++ b/DnsTlsDispatcher.cpp
@@ -128,7 +128,8 @@ DnsTlsTransport::Response DnsTlsDispatcher::query(const std::list<DnsTlsServer>&
dnsQueryEvent->set_dns_server_index(serverCount++);
dnsQueryEvent->set_ip_version(ipFamilyToIPVersion(server.ss.ss_family));
dnsQueryEvent->set_protocol(PROTO_DOT);
- dnsQueryEvent->set_type(getQueryType(query.base(), query.size()));
+ std::span<const uint8_t> msg(query.base(), query.size());
+ dnsQueryEvent->set_type(getQueryType(msg));
dnsQueryEvent->set_connected(connectTriggered);
switch (code) {
diff --git a/getaddrinfo.cpp b/getaddrinfo.cpp
index 2e407f99..cb95311c 100644
--- a/getaddrinfo.cpp
+++ b/getaddrinfo.cpp
@@ -1627,13 +1627,11 @@ QueryResult doQuery(const char* name, res_target* t, ResState* res,
LOG(DEBUG) << __func__ << ": (" << cl << ", " << type << ")";
uint8_t buf[MAXPACKET];
-
- int n = res_nmkquery(QUERY, name, cl, type, /*data=*/nullptr, /*datalen=*/0, buf, sizeof(buf),
- res->netcontext_flags);
+ int n = res_nmkquery(QUERY, name, cl, type, {}, buf, res->netcontext_flags);
if (n > 0 &&
(res->netcontext_flags & (NET_CONTEXT_FLAG_USE_DNS_OVER_TLS | NET_CONTEXT_FLAG_USE_EDNS))) {
- n = res_nopt(res, n, buf, sizeof(buf), anslen);
+ n = res_nopt(res, n, buf, anslen);
}
NetworkDnsEventReported event;
@@ -1651,7 +1649,7 @@ QueryResult doQuery(const char* name, res_target* t, ResState* res,
ResState res_temp = res->clone(&event);
int rcode = NOERROR;
- n = res_nsend(&res_temp, buf, n, t->answer.data(), anslen, &rcode, 0, sleepTimeMs);
+ n = res_nsend(&res_temp, {buf, n}, {t->answer.data(), anslen}, &rcode, 0, sleepTimeMs);
if (n < 0 || hp->rcode != NOERROR || ntohs(hp->ancount) == 0) {
// To ensure that the rcode handling is identical to res_queryN().
if (rcode != RCODE_TIMEOUT) rcode = hp->rcode;
@@ -1660,9 +1658,8 @@ QueryResult doQuery(const char* name, res_target* t, ResState* res,
(NET_CONTEXT_FLAG_USE_DNS_OVER_TLS | NET_CONTEXT_FLAG_USE_EDNS)) &&
(res_temp.flags & RES_F_EDNS0ERR)) {
LOG(DEBUG) << __func__ << ": retry without EDNS0";
- n = res_nmkquery(QUERY, name, cl, type, /*data=*/nullptr, /*datalen=*/0, buf,
- sizeof(buf), res_temp.netcontext_flags);
- n = res_nsend(&res_temp, buf, n, t->answer.data(), anslen, &rcode, 0);
+ n = res_nmkquery(QUERY, name, cl, type, {}, buf, res_temp.netcontext_flags);
+ n = res_nsend(&res_temp, {buf, n}, {t->answer.data(), anslen}, &rcode, 0);
}
}
@@ -1761,21 +1758,19 @@ static int res_queryN(const char* name, res_target* target, ResState* res, int*
const int anslen = t->answer.size();
LOG(DEBUG) << __func__ << ": (" << cl << ", " << type << ")";
-
- n = res_nmkquery(QUERY, name, cl, type, /*data=*/nullptr, /*datalen=*/0, buf, sizeof(buf),
- res->netcontext_flags);
+ n = res_nmkquery(QUERY, name, cl, type, {}, buf, res->netcontext_flags);
if (n > 0 &&
(res->netcontext_flags &
(NET_CONTEXT_FLAG_USE_DNS_OVER_TLS | NET_CONTEXT_FLAG_USE_EDNS)) &&
!retried) // TODO: remove the retry flag and provide a sufficient test coverage.
- n = res_nopt(res, n, buf, sizeof(buf), anslen);
+ n = res_nopt(res, n, buf, anslen);
if (n <= 0) {
LOG(ERROR) << __func__ << ": res_nmkquery failed";
*herrno = NO_RECOVERY;
return n;
}
- n = res_nsend(res, buf, n, t->answer.data(), anslen, &rcode, 0);
+ n = res_nsend(res, {buf, n}, {t->answer.data(), anslen}, &rcode, 0);
if (n < 0 || hp->rcode != NOERROR || ntohs(hp->ancount) == 0) {
// Record rcode from DNS response header only if no timeout.
// Keep rcode timeout for reporting later if any.
diff --git a/gethnamaddr.cpp b/gethnamaddr.cpp
index bbc12a2e..5e35b46a 100644
--- a/gethnamaddr.cpp
+++ b/gethnamaddr.cpp
@@ -632,7 +632,7 @@ static int dns_gethtbyname(ResState* res, const char* name, int addr_type, getna
int he;
const unsigned qclass = isMdnsResolution(res->flags) ? C_IN | C_UNICAST : C_IN;
- n = res_nsearch(res, name, qclass, type, buf->buf, (int)sizeof(buf->buf), &he);
+ n = res_nsearch(res, name, qclass, type, {buf->buf, (int)sizeof(buf->buf)}, &he);
if (n < 0) {
LOG(DEBUG) << __func__ << ": res_nsearch failed (" << n << ")";
// Return h_errno (he) to catch more detailed errors rather than EAI_NODATA.
@@ -694,7 +694,7 @@ static int dns_gethtbyaddr(const unsigned char* uaddr, int len, int af,
ResState res(netcontext, event);
int he;
- n = res_nquery(&res, qbuf, C_IN, T_PTR, buf->buf, (int)sizeof(buf->buf), &he);
+ n = res_nquery(&res, qbuf, C_IN, T_PTR, {buf->buf, (int)sizeof(buf->buf)}, &he);
if (n < 0) {
LOG(DEBUG) << __func__ << ": res_nquery failed (" << n << ")";
// Note that res_nquery() doesn't set the pair NETDB_INTERNAL and errno.
diff --git a/res_cache.cpp b/res_cache.cpp
index 0a1bd5c9..fe0885a5 100644
--- a/res_cache.cpp
+++ b/res_cache.cpp
@@ -78,6 +78,7 @@ using android::net::PROTO_UDP;
using android::net::Protocol;
using android::netdutils::DumpWriter;
using android::netdutils::IPSockAddr;
+using std::span;
/* This code implements a small and *simple* DNS resolver cache.
*
@@ -773,14 +774,14 @@ static uint32_t answer_getNegativeTTL(ns_msg handle) {
* In case of parse error zero (0) is returned which
* indicates that the answer shall not be cached.
*/
-static uint32_t answer_getTTL(const void* answer, int answerlen) {
+static uint32_t answer_getTTL(span<const uint8_t> answer) {
ns_msg handle;
int ancount, n;
uint32_t result, ttl;
ns_rr rr;
result = 0;
- if (ns_initparse((const uint8_t*) answer, answerlen, &handle) >= 0) {
+ if (ns_initparse(answer.data(), answer.size(), &handle) >= 0) {
// get number of answer records
ancount = ns_msg_count(handle, ns_s_an);
@@ -840,13 +841,13 @@ static unsigned entry_hash(const Entry* e) {
/* initialize an Entry as a search key, this also checks the input query packet
* returns 1 on success, or 0 in case of unsupported/malformed data */
-static int entry_init_key(Entry* e, const void* query, int querylen) {
+static int entry_init_key(Entry* e, span<const uint8_t> query) {
DnsPacket pack[1];
memset(e, 0, sizeof(*e));
- e->query = (const uint8_t*) query;
- e->querylen = querylen;
+ e->query = query.data();
+ e->querylen = query.size();
e->hash = entry_hash(e);
_dnsPacket_init(pack, e->query, e->querylen);
@@ -855,11 +856,11 @@ static int entry_init_key(Entry* e, const void* query, int querylen) {
}
/* allocate a new entry as a cache node */
-static Entry* entry_alloc(const Entry* init, const void* answer, int answerlen) {
+static Entry* entry_alloc(const Entry* init, span<const uint8_t> answer) {
Entry* e;
int size;
- size = sizeof(*e) + init->querylen + answerlen;
+ size = sizeof(*e) + init->querylen + answer.size();
e = (Entry*) calloc(size, 1);
if (e == NULL) return e;
@@ -870,9 +871,9 @@ static Entry* entry_alloc(const Entry* init, const void* answer, int answerlen)
memcpy((char*) e->query, init->query, e->querylen);
e->answer = e->query + e->querylen;
- e->answerlen = answerlen;
+ e->answerlen = answer.size();
- memcpy((char*) e->answer, answer, e->answerlen);
+ memcpy((char*)e->answer, answer.data(), e->answerlen);
return e;
}
@@ -1101,14 +1102,14 @@ static void cache_notify_waiting_tid_locked(struct Cache* cache, const Entry* ke
}
}
-void _resolv_cache_query_failed(unsigned netid, const void* query, int querylen, uint32_t flags) {
+void _resolv_cache_query_failed(unsigned netid, span<const uint8_t> query, uint32_t flags) {
// We should not notify with these flags.
if (flags & (ANDROID_RESOLV_NO_CACHE_STORE | ANDROID_RESOLV_NO_CACHE_LOOKUP)) {
return;
}
Entry key[1];
- if (!entry_init_key(key, query, querylen)) return;
+ if (!entry_init_key(key, query)) return;
std::lock_guard guard(cache_mutex);
@@ -1228,8 +1229,8 @@ static void _cache_remove_expired(Cache* cache) {
// Get a NetConfig associated with a network, or nullptr if not found.
static NetConfig* find_netconfig_locked(unsigned netid) REQUIRES(cache_mutex);
-ResolvCacheStatus resolv_cache_lookup(unsigned netid, const void* query, int querylen, void* answer,
- int answersize, int* answerlen, uint32_t flags) {
+ResolvCacheStatus resolv_cache_lookup(unsigned netid, span<const uint8_t> query,
+ span<uint8_t> answer, int* answerlen, uint32_t flags) {
// Skip cache lookup, return RESOLV_CACHE_NOTFOUND directly so that it is
// possible to cache the answer of this query.
// If ANDROID_RESOLV_NO_CACHE_STORE is set, return RESOLV_CACHE_SKIP to skip possible cache
@@ -1247,7 +1248,7 @@ ResolvCacheStatus resolv_cache_lookup(unsigned netid, const void* query, int que
LOG(INFO) << __func__ << ": lookup";
/* we don't cache malformed queries */
- if (!entry_init_key(&key, query, querylen)) {
+ if (!entry_init_key(&key, query)) {
LOG(INFO) << __func__ << ": unsupported query";
return RESOLV_CACHE_UNSUPPORTED;
}
@@ -1310,13 +1311,13 @@ ResolvCacheStatus resolv_cache_lookup(unsigned netid, const void* query, int que
}
*answerlen = e->answerlen;
- if (e->answerlen > answersize) {
+ if (e->answerlen > answer.size()) {
/* NOTE: we return UNSUPPORTED if the answer buffer is too short */
LOG(INFO) << __func__ << ": ANSWER TOO LONG";
return RESOLV_CACHE_UNSUPPORTED;
}
- memcpy(answer, e->answer, e->answerlen);
+ memcpy(answer.data(), e->answer, e->answerlen);
/* bump up this entry to the top of the MRU list */
if (e != cache->mru_list.mru_next) {
@@ -1328,8 +1329,7 @@ ResolvCacheStatus resolv_cache_lookup(unsigned netid, const void* query, int que
return RESOLV_CACHE_FOUND;
}
-int resolv_cache_add(unsigned netid, const void* query, int querylen, const void* answer,
- int answerlen) {
+int resolv_cache_add(unsigned netid, span<const uint8_t> query, span<const uint8_t> answer) {
Entry key[1];
Entry* e;
Entry** lookup;
@@ -1338,7 +1338,7 @@ int resolv_cache_add(unsigned netid, const void* query, int querylen, const void
/* don't assume that the query has already been cached
*/
- if (!entry_init_key(key, query, querylen)) {
+ if (!entry_init_key(key, query)) {
LOG(INFO) << __func__ << ": passed invalid query?";
return -EINVAL;
}
@@ -1375,9 +1375,9 @@ int resolv_cache_add(unsigned netid, const void* query, int querylen, const void
}
}
- ttl = answer_getTTL(answer, answerlen);
+ ttl = answer_getTTL(answer);
if (ttl > 0) {
- e = entry_alloc(key, answer, answerlen);
+ e = entry_alloc(key, answer);
if (e != NULL) {
e->expires = ttl + _time_now();
_cache_add_p(cache, lookup, e);
@@ -1886,13 +1886,12 @@ bool has_named_cache(unsigned netid) {
return find_named_cache_locked(netid) != nullptr;
}
-int resolv_cache_get_expiration(unsigned netid, const std::vector<char>& query,
- time_t* expiration) {
+int resolv_cache_get_expiration(unsigned netid, span<const uint8_t> query, time_t* expiration) {
Entry key;
*expiration = -1;
// A malformed query is not allowed.
- if (!entry_init_key(&key, query.data(), query.size())) {
+ if (!entry_init_key(&key, query)) {
LOG(WARNING) << __func__ << ": unsupported query";
return -EINVAL;
}
diff --git a/res_mkquery.cpp b/res_mkquery.cpp
index 939c4c29..d02cdb35 100644
--- a/res_mkquery.cpp
+++ b/res_mkquery.cpp
@@ -97,13 +97,11 @@ extern const char* const _res_opcodes[] = {
};
// Form all types of queries. Returns the size of the result or -1.
-int res_nmkquery(int op, // opcode of query
- const char* dname, // domain name
- int cl, int type, // class and type of query
- const uint8_t* data, // resource record data
- int datalen, // length of data
- uint8_t* buf, // buffer to put query
- int buflen, // size of buffer
+int res_nmkquery(int op, // opcode of query
+ const char* dname, // domain name
+ int cl, int type, // class and type of query
+ std::span<const uint8_t> data, // resource record data
+ std::span<uint8_t> buf, // buffer to put query
int netcontext_flags) {
HEADER* hp;
uint8_t *cp, *ep;
@@ -116,18 +114,18 @@ int res_nmkquery(int op, // opcode of query
/*
* Initialize header fields.
*/
- if ((buf == NULL) || (buflen < HFIXEDSZ)) return (-1);
- memset(buf, 0, HFIXEDSZ);
- hp = (HEADER*) (void*) buf;
+ if (buf.empty() || (buf.size() < HFIXEDSZ)) return (-1);
+ memset(buf.data(), 0, HFIXEDSZ);
+ hp = (HEADER*)(void*)buf.data();
hp->id = htons(arc4random_uniform(65536));
hp->opcode = op;
hp->rd = true;
hp->ad = (netcontext_flags & NET_CONTEXT_FLAG_USE_DNS_OVER_TLS) != 0U;
hp->rcode = NOERROR;
- cp = buf + HFIXEDSZ;
- ep = buf + buflen;
+ cp = buf.data() + HFIXEDSZ;
+ ep = buf.data() + buf.size();
dpp = dnptrs;
- *dpp++ = buf;
+ *dpp++ = buf.data();
*dpp++ = NULL;
lastdnptr = dnptrs + sizeof dnptrs / sizeof dnptrs[0];
/*
@@ -145,12 +143,12 @@ int res_nmkquery(int op, // opcode of query
*reinterpret_cast<uint16_t*>(cp) = htons(cl);
cp += INT16SZ;
hp->qdcount = htons(1);
- if (op == QUERY || data == NULL) break;
+ if (op == QUERY || data.empty()) break;
/*
* Make an additional record for completion domain.
*/
if ((ep - cp) < RRFIXEDSZ) return (-1);
- n = dn_comp((const char*) data, cp, ep - cp - RRFIXEDSZ, dnptrs, lastdnptr);
+ n = dn_comp((const char*)data.data(), cp, ep - cp - RRFIXEDSZ, dnptrs, lastdnptr);
if (n < 0) return (-1);
cp += n;
*reinterpret_cast<uint16_t*>(cp) = htons(ns_t_null);
@@ -168,7 +166,7 @@ int res_nmkquery(int op, // opcode of query
/*
* Initialize answer section
*/
- if (ep - cp < 1 + RRFIXEDSZ + datalen) return (-1);
+ if (ep - cp < 1 + RRFIXEDSZ + data.size()) return (-1);
*cp++ = '\0'; /* no domain name */
*reinterpret_cast<uint16_t*>(cp) = htons(type);
cp += INT16SZ;
@@ -176,11 +174,11 @@ int res_nmkquery(int op, // opcode of query
cp += INT16SZ;
*reinterpret_cast<uint32_t*>(cp) = htonl(0);
cp += INT32SZ;
- *reinterpret_cast<uint16_t*>(cp) = htons(datalen);
+ *reinterpret_cast<uint16_t*>(cp) = htons(data.size());
cp += INT16SZ;
- if (datalen) {
- memcpy(cp, data, (size_t) datalen);
- cp += datalen;
+ if (data.size()) {
+ memcpy(cp, data.data(), data.size());
+ cp += data.size();
}
hp->ancount = htons(1);
break;
@@ -188,23 +186,21 @@ int res_nmkquery(int op, // opcode of query
default:
return (-1);
}
- return (cp - buf);
+ return (cp - buf.data());
}
int res_nopt(ResState* statp, int n0, /* current offset in buffer */
- uint8_t* buf, /* buffer to put query */
- int buflen, /* size of buffer */
+ std::span<uint8_t> buf, /* buffer to put query */
int anslen) /* UDP answer buffer size */
{
- HEADER* hp;
+ HEADER* hp = reinterpret_cast<HEADER*>(buf.data());
uint8_t *cp, *ep;
uint16_t flags = 0;
LOG(DEBUG) << __func__;
- hp = (HEADER*) (void*) buf;
- cp = buf + n0;
- ep = buf + buflen;
+ cp = buf.data() + n0;
+ ep = buf.data() + buf.size();
if ((ep - cp) < 1 + RRFIXEDSZ) return (-1);
@@ -226,13 +222,13 @@ int res_nopt(ResState* statp, int n0, /* current offset in buffer */
cp += INT16SZ;
// EDNS0 padding
- const uint16_t minlen = static_cast<uint16_t>(cp - buf) + 3 * INT16SZ;
+ const uint16_t minlen = static_cast<uint16_t>(cp - buf.data()) + 3 * INT16SZ;
const uint16_t extra = minlen % kEdns0Padding;
uint16_t padlen = (kEdns0Padding - extra) % kEdns0Padding;
- if (minlen > buflen) {
+ if (minlen > buf.size()) {
return -1;
}
- padlen = std::min(padlen, static_cast<uint16_t>(buflen - minlen));
+ padlen = std::min(padlen, static_cast<uint16_t>(buf.size() - minlen));
*reinterpret_cast<uint16_t*>(cp) = htons(padlen + 2 * INT16SZ); /* RDLEN */
cp += INT16SZ;
*reinterpret_cast<uint16_t*>(cp) = htons(NS_OPT_PADDING); /* OPTION-CODE */
@@ -243,5 +239,5 @@ int res_nopt(ResState* statp, int n0, /* current offset in buffer */
cp += padlen;
hp->arcount = htons(ntohs(hp->arcount) + 1);
- return (cp - buf);
+ return (cp - buf.data());
}
diff --git a/res_query.cpp b/res_query.cpp
index 5019ac30..2359b32b 100644
--- a/res_query.cpp
+++ b/res_query.cpp
@@ -101,13 +101,12 @@
*/
int res_nquery(ResState* statp, const char* name, // domain name
int cl, int type, // class and type of query
- uint8_t* answer, // buffer to put answer
- int anslen, // size of answer buffer
+ std::span<uint8_t> answer, // buffer to put answer
int* herrno) // legacy and extended h_errno
// NETD_RESOLV_H_ERRNO_EXT_*
{
uint8_t buf[MAXPACKET];
- HEADER* hp = (HEADER*) (void*) answer;
+ HEADER* hp = reinterpret_cast<HEADER*>(answer.data());
int n;
int rcode = NOERROR;
bool retried = false;
@@ -116,20 +115,18 @@ again:
hp->rcode = NOERROR; // default
LOG(DEBUG) << __func__ << ": (" << cl << ", " << type << ")";
-
- n = res_nmkquery(QUERY, name, cl, type, /*data=*/nullptr, 0, buf, sizeof(buf),
- statp->netcontext_flags);
+ n = res_nmkquery(QUERY, name, cl, type, {}, buf, statp->netcontext_flags);
if (n > 0 &&
(statp->netcontext_flags &
(NET_CONTEXT_FLAG_USE_DNS_OVER_TLS | NET_CONTEXT_FLAG_USE_EDNS)) &&
!retried)
- n = res_nopt(statp, n, buf, sizeof(buf), anslen);
+ n = res_nopt(statp, n, buf, answer.size());
if (n <= 0) {
LOG(DEBUG) << __func__ << ": mkquery failed";
*herrno = NO_RECOVERY;
return n;
}
- n = res_nsend(statp, buf, n, answer, anslen, &rcode, 0);
+ n = res_nsend(statp, {buf, n}, answer, &rcode, 0);
if (n < 0) {
// If the query choked with EDNS0, retry without EDNS0 that when the server
// has no response, resovler won't retry and do nothing. Even fallback to UDP,
@@ -203,13 +200,12 @@ again:
*/
int res_nsearch(ResState* statp, const char* name, /* domain name */
int cl, int type, /* class and type of query */
- uint8_t* answer, /* buffer to put answer */
- int anslen, /* size of answer */
+ std::span<uint8_t> answer, /* buffer to put answer */
int* herrno) /* legacy and extended
h_errno NETD_RESOLV_H_ERRNO_EXT_* */
{
const char* cp;
- HEADER* hp = (HEADER*) (void*) answer;
+ HEADER* hp = reinterpret_cast<HEADER*>(answer.data());
uint32_t dots;
int ret, saved_herrno;
int got_nodata = 0, got_servfail = 0, root_on_list = 0;
@@ -229,7 +225,7 @@ int res_nsearch(ResState* statp, const char* name, /* domain name */
*/
saved_herrno = -1;
if (dots >= statp->ndots || trailing_dot) {
- ret = res_nquerydomain(statp, name, NULL, cl, type, answer, anslen, herrno);
+ ret = res_nquerydomain(statp, name, NULL, cl, type, answer, herrno);
if (ret > 0 || trailing_dot) return ret;
saved_herrno = *herrno;
tried_as_is++;
@@ -255,7 +251,7 @@ int res_nsearch(ResState* statp, const char* name, /* domain name */
for (const auto& domain : statp->search_domains) {
if (domain == "." || domain == "") ++root_on_list;
- ret = res_nquerydomain(statp, name, domain.c_str(), cl, type, answer, anslen, herrno);
+ ret = res_nquerydomain(statp, name, domain.c_str(), cl, type, answer, herrno);
if (ret > 0) return ret;
/*
@@ -301,7 +297,7 @@ int res_nsearch(ResState* statp, const char* name, /* domain name */
// note that we do this regardless of how many dots were in the
// name or whether it ends with a dot.
if (!tried_as_is && !root_on_list) {
- ret = res_nquerydomain(statp, name, NULL, cl, type, answer, anslen, herrno);
+ ret = res_nquerydomain(statp, name, NULL, cl, type, answer, herrno);
if (ret > 0) return ret;
}
@@ -326,10 +322,9 @@ int res_nsearch(ResState* statp, const char* name, /* domain name */
* removing a trailing dot from name if domain is NULL.
*/
int res_nquerydomain(ResState* statp, const char* name, const char* domain, int cl,
- int type, /* class and type of query */
- uint8_t* answer, /* buffer to put answer */
- int anslen, /* size of answer */
- int* herrno) /* legacy and extended h_errno NETD_RESOLV_H_ERRNO_EXT_* */
+ int type, /* class and type of query */
+ std::span<uint8_t> answer, /* buffer to put answer */
+ int* herrno) /* legacy and extended h_errno NETD_RESOLV_H_ERRNO_EXT_* */
{
char nbuf[MAXDNAME];
const char* longname = nbuf;
@@ -362,5 +357,5 @@ int res_nquerydomain(ResState* statp, const char* name, const char* domain, int
}
snprintf(nbuf, sizeof(nbuf), "%s.%s", name, domain);
}
- return res_nquery(statp, longname, cl, type, answer, anslen, herrno);
+ return res_nquery(statp, longname, cl, type, answer, herrno);
}
diff --git a/res_send.cpp b/res_send.cpp
index 82677959..c5eb6270 100644
--- a/res_send.cpp
+++ b/res_send.cpp
@@ -150,19 +150,19 @@ using android::net::PROTO_UDP;
using android::netdutils::IPSockAddr;
using android::netdutils::Slice;
using android::netdutils::Stopwatch;
+using std::span;
const std::vector<IPSockAddr> mdns_addrs = {IPSockAddr::toIPSockAddr("ff02::fb", 5353),
IPSockAddr::toIPSockAddr("224.0.0.251", 5353)};
static int setupUdpSocket(ResState* statp, const sockaddr* sockap, unique_fd* fd_out, int* terrno);
-static int send_dg(ResState* statp, res_params* params, const uint8_t* buf, int buflen,
- uint8_t* ans, int anssiz, int* terrno, size_t* ns, int* v_circuit,
- int* gotsomewhere, time_t* at, int* rcode, int* delay);
-static int send_vc(ResState* statp, res_params* params, const uint8_t* buf, int buflen,
- uint8_t* ans, int anssiz, int* terrno, size_t ns, time_t* at, int* rcode,
- int* delay);
-static int send_mdns(ResState* statp, std::span<const uint8_t> buf, uint8_t* ans, int anssiz,
- int* terrno, int* rcode);
+static int send_dg(ResState* statp, res_params* params, span<const uint8_t> msg, span<uint8_t> ans,
+ int* terrno, size_t* ns, int* v_circuit, int* gotsomewhere, time_t* at,
+ int* rcode, int* delay);
+static int send_vc(ResState* statp, res_params* params, span<const uint8_t> msg, span<uint8_t> ans,
+ int* terrno, size_t ns, time_t* at, int* rcode, int* delay);
+static int send_mdns(ResState* statp, span<const uint8_t> msg, span<uint8_t> ans, int* terrno,
+ int* rcode);
static void dump_error(const char*, const struct sockaddr*);
static int sock_eq(struct sockaddr*, struct sockaddr*);
@@ -175,10 +175,10 @@ static int res_tls_send(const std::list<DnsTlsServer>& tlsServers, ResState*, co
const Slice answer, int* rcode, PrivateDnsMode mode);
static ssize_t res_doh_send(ResState*, const Slice query, const Slice answer, int* rcode);
-NsType getQueryType(const uint8_t* msg, size_t msgLen) {
+NsType getQueryType(span<const uint8_t> msg) {
ns_msg handle;
ns_rr rr;
- if (ns_initparse((const uint8_t*)msg, msgLen, &handle) < 0 ||
+ if (ns_initparse(msg.data(), msg.size(), &handle) < 0 ||
ns_parserr(&handle, ns_s_qd, 0, &rr) < 0) {
return NS_T_INVALID;
}
@@ -348,10 +348,10 @@ static int res_ourserver_p(ResState* statp, const sockaddr* sa) {
}
/* int
- * res_nameinquery(name, type, cl, buf, eom)
- * look for (name, type, cl) in the query section of packet (buf, eom)
+ * res_nameinquery(name, type, cl, msg, eom)
+ * look for (name, type, cl) in the query section of packet (msg, eom)
* requires:
- * buf + HFIXEDSZ <= eom
+ * msg + HFIXEDSZ <= eom
* returns:
* -1 : format error
* 0 : not found
@@ -359,13 +359,13 @@ static int res_ourserver_p(ResState* statp, const sockaddr* sa) {
* author:
* paul vixie, 29may94
*/
-int res_nameinquery(const char* name, int type, int cl, const uint8_t* buf, const uint8_t* eom) {
- const uint8_t* cp = buf + HFIXEDSZ;
- int qdcount = ntohs(((const HEADER*) (const void*) buf)->qdcount);
+int res_nameinquery(const char* name, int type, int cl, const uint8_t* msg, const uint8_t* eom) {
+ const uint8_t* cp = msg + HFIXEDSZ;
+ int qdcount = ntohs(((const HEADER*)(const void*)msg)->qdcount);
while (qdcount-- > 0) {
char tname[MAXDNAME + 1];
- int n = dn_expand(buf, eom, cp, tname, sizeof tname);
+ int n = dn_expand(msg, eom, cp, tname, sizeof tname);
if (n < 0) return (-1);
cp += n;
if (cp + 2 * INT16SZ > eom) return (-1);
@@ -433,30 +433,29 @@ static bool isNetworkRestricted(int terrno) {
return (terrno == EPERM);
}
-int res_nsend(ResState* statp, const uint8_t* buf, int buflen, uint8_t* ans, int anssiz, int* rcode,
+int res_nsend(ResState* statp, span<const uint8_t> msg, span<uint8_t> ans, int* rcode,
uint32_t flags, std::chrono::milliseconds sleepTimeMs) {
LOG(DEBUG) << __func__;
// Should not happen
- if (anssiz < HFIXEDSZ) {
+ if (ans.size() < HFIXEDSZ) {
// TODO: Remove errno once callers stop using it
errno = EINVAL;
return -EINVAL;
}
- res_pquery({buf, buflen});
+ res_pquery(msg);
int anslen = 0;
Stopwatch cacheStopwatch;
- ResolvCacheStatus cache_status =
- resolv_cache_lookup(statp->netid, buf, buflen, ans, anssiz, &anslen, flags);
+ ResolvCacheStatus cache_status = resolv_cache_lookup(statp->netid, msg, ans, &anslen, flags);
const int32_t cacheLatencyUs = saturate_cast<int32_t>(cacheStopwatch.timeTakenUs());
if (cache_status == RESOLV_CACHE_FOUND) {
- HEADER* hp = (HEADER*)(void*)ans;
+ HEADER* hp = (HEADER*)(void*)ans.data();
*rcode = hp->rcode;
DnsQueryEvent* dnsQueryEvent = addDnsQueryEvent(statp->event);
dnsQueryEvent->set_latency_micros(cacheLatencyUs);
dnsQueryEvent->set_cache_hit(static_cast<CacheStatus>(cache_status));
- dnsQueryEvent->set_type(getQueryType(buf, buflen));
+ dnsQueryEvent->set_type(getQueryType(msg));
return anslen;
} else if (cache_status != RESOLV_CACHE_UNSUPPORTED) {
// had a cache miss for a known network, so populate the thread private
@@ -470,30 +469,29 @@ int res_nsend(ResState* statp, const uint8_t* buf, int buflen, uint8_t* ans, int
int terrno = ETIME;
int resplen = 0;
*rcode = RCODE_INTERNAL_ERROR;
- std::span<const uint8_t> buffer(buf, buflen);
Stopwatch queryStopwatch;
- resplen = send_mdns(statp, buffer, ans, anssiz, &terrno, rcode);
+ resplen = send_mdns(statp, msg, ans, &terrno, rcode);
const IPSockAddr& receivedMdnsAddr =
- (getQueryType(buf, buflen) == NS_T_AAAA) ? mdns_addrs[0] : mdns_addrs[1];
+ (getQueryType(msg) == NS_T_AAAA) ? mdns_addrs[0] : mdns_addrs[1];
DnsQueryEvent* mDnsQueryEvent = addDnsQueryEvent(statp->event);
mDnsQueryEvent->set_cache_hit(static_cast<CacheStatus>(cache_status));
mDnsQueryEvent->set_latency_micros(saturate_cast<int32_t>(queryStopwatch.timeTakenUs()));
mDnsQueryEvent->set_ip_version(ipFamilyToIPVersion(receivedMdnsAddr.family()));
mDnsQueryEvent->set_rcode(static_cast<NsRcode>(*rcode));
mDnsQueryEvent->set_protocol(PROTO_MDNS);
- mDnsQueryEvent->set_type(getQueryType(buf, buflen));
+ mDnsQueryEvent->set_type(getQueryType(msg));
mDnsQueryEvent->set_linux_errno(static_cast<LinuxErrno>(terrno));
resolv_stats_add(statp->netid, receivedMdnsAddr, mDnsQueryEvent);
if (resplen <= 0) {
- _resolv_cache_query_failed(statp->netid, buf, buflen, flags);
+ _resolv_cache_query_failed(statp->netid, msg, flags);
return -terrno;
}
LOG(DEBUG) << __func__ << ": got answer:";
- res_pquery({ans, (resplen > anssiz) ? anssiz : resplen});
+ res_pquery(ans.first(resplen));
if (cache_status == RESOLV_CACHE_NOTFOUND) {
- resolv_cache_add(statp->netid, buf, buflen, ans, resplen);
+ resolv_cache_add(statp->netid, msg, {ans.data(), resplen});
}
return resplen;
}
@@ -503,7 +501,7 @@ int res_nsend(ResState* statp, const uint8_t* buf, int buflen, uint8_t* ans, int
// point trying. Tell the cache the query failed, or any retries and anyone else
// asking the same question will block for PENDING_REQUEST_TIMEOUT seconds instead
// of failing fast.
- _resolv_cache_query_failed(statp->netid, buf, buflen, flags);
+ _resolv_cache_query_failed(statp->netid, msg, flags);
// TODO: Remove errno once callers stop using it
errno = ESRCH;
@@ -513,18 +511,19 @@ int res_nsend(ResState* statp, const uint8_t* buf, int buflen, uint8_t* ans, int
// Private DNS
if (!(statp->netcontext_flags & NET_CONTEXT_FLAG_USE_LOCAL_NAMESERVERS)) {
bool fallback = false;
- int resplen = res_private_dns_send(statp, Slice(const_cast<uint8_t*>(buf), buflen),
- Slice(ans, anssiz), rcode, &fallback);
+ int resplen =
+ res_private_dns_send(statp, Slice(const_cast<uint8_t*>(msg.data()), msg.size()),
+ Slice(ans.data(), ans.size()), rcode, &fallback);
if (resplen > 0) {
LOG(DEBUG) << __func__ << ": got answer from Private DNS";
- res_pquery({ans, resplen});
+ res_pquery(ans.first(resplen));
if (cache_status == RESOLV_CACHE_NOTFOUND) {
- resolv_cache_add(statp->netid, buf, buflen, ans, resplen);
+ resolv_cache_add(statp->netid, msg, ans.first(resplen));
}
return resplen;
}
if (!fallback) {
- _resolv_cache_query_failed(statp->netid, buf, buflen, flags);
+ _resolv_cache_query_failed(statp->netid, msg, flags);
return -ETIMEDOUT;
}
}
@@ -558,7 +557,7 @@ int res_nsend(ResState* statp, const uint8_t* buf, int buflen, uint8_t* ans, int
// TODO: Let it always choose the first nameserver when sort_nameservers is enabled.
if ((flags & ANDROID_RESOLV_NO_RETRY) && usableServersCount > 1) {
- auto hp = reinterpret_cast<const HEADER*>(buf);
+ auto hp = reinterpret_cast<const HEADER*>(msg.data());
// Select a random server based on the query id
int selectedServer = (hp->id % usableServersCount) + 1;
@@ -567,7 +566,7 @@ int res_nsend(ResState* statp, const uint8_t* buf, int buflen, uint8_t* ans, int
// Send request, RETRY times, or until successful.
int retryTimes = (flags & ANDROID_RESOLV_NO_RETRY) ? 1 : params.retry_count;
- int useTcp = buflen > PACKETSZ;
+ int useTcp = msg.size() > PACKETSZ;
int gotsomewhere = 0;
// Use an impossible error code as default value
@@ -598,10 +597,10 @@ int res_nsend(ResState* statp, const uint8_t* buf, int buflen, uint8_t* ans, int
if (useTcp) {
// TCP; at most one attempt per server.
attempt = retryTimes;
- resplen = send_vc(statp, &params, buf, buflen, ans, anssiz, &terrno, ns,
- &query_time, rcode, &delay);
+ resplen =
+ send_vc(statp, &params, msg, ans, &terrno, ns, &query_time, rcode, &delay);
- if (buflen <= PACKETSZ && resplen <= 0 &&
+ if (msg.size() <= PACKETSZ && resplen <= 0 &&
statp->tc_mode == aidl::android::net::IDnsResolver::TC_MODE_UDP_TCP) {
// reset to UDP for next query on next DNS server if resolver is currently doing
// TCP fallback retry and current server does not support TCP connectin
@@ -610,8 +609,8 @@ int res_nsend(ResState* statp, const uint8_t* buf, int buflen, uint8_t* ans, int
LOG(INFO) << __func__ << ": used send_vc " << resplen << " terrno: " << terrno;
} else {
// UDP
- resplen = send_dg(statp, &params, buf, buflen, ans, anssiz, &terrno, &actualNs,
- &useTcp, &gotsomewhere, &query_time, rcode, &delay);
+ resplen = send_dg(statp, &params, msg, ans, &terrno, &actualNs, &useTcp,
+ &gotsomewhere, &query_time, rcode, &delay);
fallbackTCP = useTcp ? true : false;
retry_count_for_event = attempt;
LOG(INFO) << __func__ << ": used send_dg " << resplen << " terrno: " << terrno;
@@ -631,7 +630,7 @@ int res_nsend(ResState* statp, const uint8_t* buf, int buflen, uint8_t* ans, int
dnsQueryEvent->set_retry_times(retry_count_for_event);
dnsQueryEvent->set_rcode(static_cast<NsRcode>(*rcode));
dnsQueryEvent->set_protocol(query_proto);
- dnsQueryEvent->set_type(getQueryType(buf, buflen));
+ dnsQueryEvent->set_type(getQueryType(msg));
dnsQueryEvent->set_linux_errno(static_cast<LinuxErrno>(terrno));
// Only record stats the first time we try a query. This ensures that
@@ -660,16 +659,16 @@ int res_nsend(ResState* statp, const uint8_t* buf, int buflen, uint8_t* ans, int
continue;
}
if (resplen < 0) {
- _resolv_cache_query_failed(statp->netid, buf, buflen, flags);
+ _resolv_cache_query_failed(statp->netid, msg, flags);
statp->closeSockets();
return -terrno;
}
LOG(DEBUG) << __func__ << ": got answer:";
- res_pquery({ans, (resplen > anssiz) ? anssiz : resplen});
+ res_pquery(ans.first(resplen));
if (cache_status == RESOLV_CACHE_NOTFOUND) {
- resolv_cache_add(statp->netid, buf, buflen, ans, resplen);
+ resolv_cache_add(statp->netid, msg, {ans.data(), resplen});
}
statp->closeSockets();
return (resplen);
@@ -682,7 +681,7 @@ int res_nsend(ResState* statp, const uint8_t* buf, int buflen, uint8_t* ans, int
: gotsomewhere ? ETIMEDOUT /* no answer obtained */
: ECONNREFUSED /* no nameservers found */;
- _resolv_cache_query_failed(statp->netid, buf, buflen, flags);
+ _resolv_cache_query_failed(statp->netid, msg, flags);
return -terrno;
}
@@ -707,13 +706,12 @@ static struct timespec get_timeout(ResState* statp, const res_params* params, co
return result;
}
-static int send_vc(ResState* statp, res_params* params, const uint8_t* buf, int buflen,
- uint8_t* ans, int anssiz, int* terrno, size_t ns, time_t* at, int* rcode,
- int* delay) {
+static int send_vc(ResState* statp, res_params* params, span<const uint8_t> msg, span<uint8_t> ans,
+ int* terrno, size_t ns, time_t* at, int* rcode, int* delay) {
*at = time(NULL);
*delay = 0;
- const HEADER* hp = (const HEADER*) (const void*) buf;
- HEADER* anhp = (HEADER*) (void*) ans;
+ const HEADER* hp = (const HEADER*)(const void*)msg.data();
+ HEADER* anhp = (HEADER*)(void*)ans.data();
struct sockaddr* nsap;
int nsaplen;
int truncating, connreset, n;
@@ -807,12 +805,13 @@ same_ns:
/*
* Send length & message
*/
- uint16_t len = htons(static_cast<uint16_t>(buflen));
+ uint16_t len = htons(static_cast<uint16_t>(msg.size()));
const iovec iov[] = {
{.iov_base = &len, .iov_len = INT16SZ},
- {.iov_base = const_cast<uint8_t*>(buf), .iov_len = static_cast<size_t>(buflen)},
+ {.iov_base = const_cast<uint8_t*>(msg.data()),
+ .iov_len = static_cast<size_t>(msg.size())},
};
- if (writev(statp->tcp_nssock, iov, 2) != (INT16SZ + buflen)) {
+ if (writev(statp->tcp_nssock, iov, 2) != (INT16SZ + msg.size())) {
*terrno = errno;
PLOG(DEBUG) << __func__ << ": write failed: ";
statp->closeSockets();
@@ -822,7 +821,7 @@ same_ns:
* Receive length & response
*/
read_len:
- cp = ans;
+ cp = ans.data();
len = INT16SZ;
while ((n = read(statp->tcp_nssock, (char*)cp, (size_t)len)) > 0) {
cp += n;
@@ -847,11 +846,11 @@ read_len:
}
return (0);
}
- uint16_t resplen = ntohs(*reinterpret_cast<const uint16_t*>(ans));
- if (resplen > anssiz) {
+ uint16_t resplen = ntohs(*reinterpret_cast<const uint16_t*>(ans.data()));
+ if (resplen > ans.size()) {
LOG(DEBUG) << __func__ << ": response truncated";
truncating = 1;
- len = anssiz;
+ len = ans.size();
} else
len = resplen;
if (len < HFIXEDSZ) {
@@ -863,7 +862,7 @@ read_len:
statp->closeSockets();
return (0);
}
- cp = ans;
+ cp = ans.data();
while (len != 0 && (n = read(statp->tcp_nssock, (char*)cp, (size_t)len)) > 0) {
cp += n;
len -= n;
@@ -880,7 +879,7 @@ read_len:
* Flush rest of answer so connection stays in synch.
*/
anhp->tc = 1;
- len = resplen - anssiz;
+ len = resplen - ans.size();
while (len != 0) {
char junk[PACKETSZ];
@@ -890,9 +889,9 @@ read_len:
else
break;
}
- LOG(WARNING) << __func__ << ": resplen " << resplen << " exceeds buf size " << anssiz;
+ LOG(WARNING) << __func__ << ": resplen " << resplen << " exceeds buf size " << ans.size();
// return size should never exceed container size
- resplen = anssiz;
+ resplen = ans.size();
}
/*
* If the calling application has bailed out of
@@ -903,7 +902,7 @@ read_len:
*/
if (hp->id != anhp->id) {
LOG(DEBUG) << __func__ << ": ld answer (unexpected):";
- res_pquery({ans, resplen});
+ res_pquery({ans.data(), resplen});
goto read_len;
}
@@ -1030,10 +1029,10 @@ static Result<std::vector<int>> udpRetryingPollWrapper(ResState* statp, int addr
return std::vector<int>{statp->udpsocks[addrInfo]};
}
-bool ignoreInvalidAnswer(ResState* statp, const sockaddr_storage& from, const uint8_t* buf,
- int buflen, uint8_t* ans, int anssiz, int* receivedFromNs) {
- const HEADER* hp = (const HEADER*)(const void*)buf;
- HEADER* anhp = (HEADER*)(void*)ans;
+bool ignoreInvalidAnswer(ResState* statp, const sockaddr_storage& from, span<const uint8_t> msg,
+ span<uint8_t> ans, int* receivedFromNs) {
+ const HEADER* hp = (const HEADER*)(const void*)msg.data();
+ HEADER* anhp = (HEADER*)(void*)ans.data();
if (hp->id != anhp->id) {
// response from old query, ignore it.
LOG(DEBUG) << __func__ << ": old answer:";
@@ -1044,7 +1043,8 @@ bool ignoreInvalidAnswer(ResState* statp, const sockaddr_storage& from, const ui
LOG(DEBUG) << __func__ << ": not our server:";
return true;
}
- if (!res_queriesmatch(buf, buf + buflen, ans, ans + anssiz)) {
+ if (!res_queriesmatch(msg.data(), msg.data() + msg.size(), ans.data(),
+ ans.data() + ans.size())) {
// response contains wrong query? ignore it.
LOG(DEBUG) << __func__ << ": wrong query name:";
return true;
@@ -1088,9 +1088,9 @@ static int setupUdpSocket(ResState* statp, const sockaddr* sockap, unique_fd* fd
return 1;
}
-static int send_dg(ResState* statp, res_params* params, const uint8_t* buf, int buflen,
- uint8_t* ans, int anssiz, int* terrno, size_t* ns, int* v_circuit,
- int* gotsomewhere, time_t* at, int* rcode, int* delay) {
+static int send_dg(ResState* statp, res_params* params, span<const uint8_t> msg, span<uint8_t> ans,
+ int* terrno, size_t* ns, int* v_circuit, int* gotsomewhere, time_t* at,
+ int* rcode, int* delay) {
// It should never happen, but just in case.
if (*ns >= statp->nsaddrs.size()) {
LOG(ERROR) << __func__ << ": Out-of-bound indexing: " << ns;
@@ -1119,7 +1119,7 @@ static int send_dg(ResState* statp, res_params* params, const uint8_t* buf, int
}
LOG(DEBUG) << __func__ << ": new DG socket";
}
- if (send(statp->udpsocks[*ns], (const char*)buf, (size_t)buflen, 0) != buflen) {
+ if (send(statp->udpsocks[*ns], msg.data(), msg.size(), 0) != msg.size()) {
*terrno = errno;
PLOG(DEBUG) << __func__ << ": send: ";
statp->closeSockets();
@@ -1150,7 +1150,7 @@ static int send_dg(ResState* statp, res_params* params, const uint8_t* buf, int
sockaddr_storage from;
socklen_t fromlen = sizeof(from);
int resplen =
- recvfrom(fd, (char*)ans, (size_t)anssiz, 0, (sockaddr*)(void*)&from, &fromlen);
+ recvfrom(fd, ans.data(), ans.size(), 0, (sockaddr*)(void*)&from, &fromlen);
if (resplen <= 0) {
*terrno = errno;
PLOG(DEBUG) << __func__ << ": recvfrom: ";
@@ -1165,20 +1165,19 @@ static int send_dg(ResState* statp, res_params* params, const uint8_t* buf, int
}
int receivedFromNs = *ns;
- if (needRetry =
- ignoreInvalidAnswer(statp, from, buf, buflen, ans, anssiz, &receivedFromNs);
+ if (needRetry = ignoreInvalidAnswer(statp, from, msg, ans, &receivedFromNs);
needRetry) {
- res_pquery({ans, (resplen > anssiz) ? anssiz : resplen});
+ res_pquery({ans.data(), (resplen > ans.size()) ? ans.size() : resplen});
continue;
}
- HEADER* anhp = (HEADER*)(void*)ans;
+ HEADER* anhp = (HEADER*)(void*)ans.data();
if (anhp->rcode == FORMERR && (statp->netcontext_flags & NET_CONTEXT_FLAG_USE_EDNS)) {
// Do not retry if the server do not understand EDNS0.
// The case has to be captured here, as FORMERR packet do not
// carry query section, hence res_queriesmatch() returns 0.
LOG(DEBUG) << __func__ << ": server rejected query with EDNS0:";
- res_pquery({ans, (resplen > anssiz) ? anssiz : resplen});
+ res_pquery({ans.data(), (resplen > ans.size()) ? ans.size() : resplen});
// record the error
statp->flags |= RES_F_EDNS0ERR;
*terrno = EREMOTEIO;
@@ -1189,7 +1188,7 @@ static int send_dg(ResState* statp, res_params* params, const uint8_t* buf, int
*delay = res_stats_calculate_rtt(&done, &start_time);
if (anhp->rcode == SERVFAIL || anhp->rcode == NOTIMP || anhp->rcode == REFUSED) {
LOG(DEBUG) << __func__ << ": server rejected query:";
- res_pquery({ans, (resplen > anssiz) ? anssiz : resplen});
+ res_pquery({ans.data(), (resplen > ans.size()) ? ans.size() : resplen});
*rcode = anhp->rcode;
continue;
}
@@ -1215,16 +1214,15 @@ static int send_dg(ResState* statp, res_params* params, const uint8_t* buf, int
// return length - when receiving valid packets.
// return 0 - when mdns packets transfer error.
-static int send_mdns(ResState* statp, std::span<const uint8_t> buf, uint8_t* ans, int anssiz,
- int* terrno, int* rcode) {
- const sockaddr_storage ss =
- (getQueryType(buf.data(), buf.size()) == NS_T_AAAA) ? mdns_addrs[0] : mdns_addrs[1];
+static int send_mdns(ResState* statp, span<const uint8_t> msg, span<uint8_t> ans, int* terrno,
+ int* rcode) {
+ const sockaddr_storage ss = (getQueryType(msg) == NS_T_AAAA) ? mdns_addrs[0] : mdns_addrs[1];
const sockaddr* mdnsap = reinterpret_cast<const sockaddr*>(&ss);
unique_fd fd;
if (setupUdpSocket(statp, mdnsap, &fd, terrno) <= 0) return 0;
- if (sendto(fd, buf.data(), buf.size(), 0, mdnsap, sockaddrSize(mdnsap)) != buf.size()) {
+ if (sendto(fd, msg.data(), msg.size(), 0, mdnsap, sockaddrSize(mdnsap)) != msg.size()) {
*terrno = errno;
return 0;
}
@@ -1241,7 +1239,7 @@ static int send_mdns(ResState* statp, std::span<const uint8_t> buf, uint8_t* ans
sockaddr_storage from;
socklen_t fromlen = sizeof(from);
- int resplen = recvfrom(fd, (char*)ans, (size_t)anssiz, 0, (sockaddr*)(void*)&from, &fromlen);
+ int resplen = recvfrom(fd, ans.data(), ans.size(), 0, (sockaddr*)(void*)&from, &fromlen);
if (resplen <= 0) {
*terrno = errno;
@@ -1255,7 +1253,7 @@ static int send_mdns(ResState* statp, std::span<const uint8_t> buf, uint8_t* ans
return 0;
}
- HEADER* anhp = (HEADER*)(void*)ans;
+ HEADER* anhp = (HEADER*)(void*)ans.data();
if (anhp->tc) {
LOG(DEBUG) << __func__ << ": truncated answer";
*terrno = E2BIG;
@@ -1415,7 +1413,8 @@ ssize_t res_doh_send(ResState* statp, const Slice query, const Slice answer, int
}
dnsQueryEvent->set_rcode(static_cast<NsRcode>(*rcode));
dnsQueryEvent->set_protocol(PROTO_DOH);
- dnsQueryEvent->set_type(getQueryType(query.base(), query.size()));
+ span<const uint8_t> msg(query.base(), query.size());
+ dnsQueryEvent->set_type(getQueryType(msg));
return result;
}
@@ -1463,12 +1462,12 @@ int res_tls_send(const std::list<DnsTlsServer>& tlsServers, ResState* statp, con
}
}
-int resolv_res_nsend(const android_net_context* netContext, const uint8_t* msg, int msgLen,
- uint8_t* ans, int ansLen, int* rcode, uint32_t flags,
+int resolv_res_nsend(const android_net_context* netContext, span<const uint8_t> msg,
+ span<uint8_t> ans, int* rcode, uint32_t flags,
NetworkDnsEventReported* event) {
assert(event != nullptr);
ResState res(netContext, event);
resolv_populate_res_for_net(&res);
*rcode = NOERROR;
- return res_nsend(&res, msg, msgLen, ans, ansLen, rcode, flags);
+ return res_nsend(&res, msg, ans, rcode, flags);
}
diff --git a/res_send.h b/res_send.h
index fb80160f..f3c0dfdf 100644
--- a/res_send.h
+++ b/res_send.h
@@ -16,10 +16,12 @@
#pragma once
+#include <span>
+
#include "netd_resolv/resolv.h" // struct android_net_context
#include "stats.pb.h"
// Query dns with raw msg
-int resolv_res_nsend(const android_net_context* netContext, const uint8_t* msg, int msgLen,
- uint8_t* ans, int ansLen, int* rcode, uint32_t flags,
+int resolv_res_nsend(const android_net_context* netContext, std::span<const uint8_t> msg,
+ std::span<uint8_t> ans, int* rcode, uint32_t flags,
android::net::NetworkDnsEventReported* event);
diff --git a/resolv_cache.h b/resolv_cache.h
index 7fba618b..9b267915 100644
--- a/resolv_cache.h
+++ b/resolv_cache.h
@@ -28,6 +28,7 @@
#pragma once
+#include <span>
#include <unordered_map>
#include <vector>
@@ -61,16 +62,16 @@ typedef enum {
RESOLV_CACHE_SKIP /* Don't do anything on cache */
} ResolvCacheStatus;
-ResolvCacheStatus resolv_cache_lookup(unsigned netid, const void* query, int querylen, void* answer,
- int answersize, int* answerlen, uint32_t flags);
+ResolvCacheStatus resolv_cache_lookup(unsigned netid, std::span<const uint8_t> query,
+ std::span<uint8_t> answer, int* answerlen, uint32_t flags);
// add a (query,answer) to the cache. If the pair has been in the cache, no new entry will be added
// in the cache.
-int resolv_cache_add(unsigned netid, const void* query, int querylen, const void* answer,
- int answerlen);
+int resolv_cache_add(unsigned netid, std::span<const uint8_t> query,
+ std::span<const uint8_t> answer);
/* Notify the cache a request failed */
-void _resolv_cache_query_failed(unsigned netid, const void* query, int querylen, uint32_t flags);
+void _resolv_cache_query_failed(unsigned netid, std::span<const uint8_t> query, uint32_t flags);
// Get a customized table for a given network.
std::vector<std::string> getCustomizedTableByName(const size_t netid, const char* hostname);
@@ -103,7 +104,7 @@ bool has_named_cache(unsigned netid);
// For test only.
// Get the expiration time of a cache entry. Return 0 on success; otherwise, an negative error is
// returned if the expiration time can't be acquired.
-int resolv_cache_get_expiration(unsigned netid, const std::vector<char>& query, time_t* expiration);
+int resolv_cache_get_expiration(unsigned netid, std::span<const uint8_t> query, time_t* expiration);
// Set addresses to DnsStats for a given network.
int resolv_stats_set_addrs(unsigned netid, android::net::Protocol proto,
diff --git a/resolv_private.h b/resolv_private.h
index a35dd5da..31c541f4 100644
--- a/resolv_private.h
+++ b/resolv_private.h
@@ -54,6 +54,7 @@
#include <net/if.h>
#include <time.h>
+#include <span>
#include <string>
#include <vector>
@@ -168,14 +169,14 @@ extern const char* const _res_opcodes[];
int res_nameinquery(const char*, int, int, const uint8_t*, const uint8_t*);
int res_queriesmatch(const uint8_t*, const uint8_t*, const uint8_t*, const uint8_t*);
-int res_nquery(ResState*, const char*, int, int, uint8_t*, int, int*);
-int res_nsearch(ResState*, const char*, int, int, uint8_t*, int, int*);
-int res_nquerydomain(ResState*, const char*, const char*, int, int, uint8_t*, int, int*);
-int res_nmkquery(int op, const char* qname, int cl, int type, const uint8_t* data, int datalen,
- uint8_t* buf, int buflen, int netcontext_flags);
-int res_nsend(ResState* statp, const uint8_t* buf, int buflen, uint8_t* ans, int anssiz, int* rcode,
+int res_nquery(ResState*, const char*, int, int, std::span<uint8_t>, int*);
+int res_nsearch(ResState*, const char*, int, int, std::span<uint8_t>, int*);
+int res_nquerydomain(ResState*, const char*, const char*, int, int, std::span<uint8_t>, int*);
+int res_nmkquery(int op, const char* qname, int cl, int type, std::span<const uint8_t> data,
+ std::span<uint8_t> msg, int netcontext_flags);
+int res_nsend(ResState* statp, std::span<const uint8_t> msg, std::span<uint8_t> ans, int* rcode,
uint32_t flags, std::chrono::milliseconds sleepTimeMs = {});
-int res_nopt(ResState*, int, uint8_t*, int, int);
+int res_nopt(ResState*, int, std::span<uint8_t>, int);
int getaddrinfo_numeric(const char* hostname, const char* servname, addrinfo hints,
addrinfo** result);
@@ -227,7 +228,7 @@ constexpr T* align_ptr(T* const p) {
// static_assert(align_ptr<sizeof(uint32_t)>((char*)1004) == (char*)1004);
// static_assert(align_ptr<sizeof(uint64_t)>((char*)1004) == (char*)1008);
-android::net::NsType getQueryType(const uint8_t* msg, size_t msgLen);
+android::net::NsType getQueryType(std::span<const uint8_t> msg);
android::net::IpVersion ipFamilyToIPVersion(int ipFamily);
diff --git a/tests/resolv_cache_unit_test.cpp b/tests/resolv_cache_unit_test.cpp
index a9a0f721..050964e0 100644
--- a/tests/resolv_cache_unit_test.cpp
+++ b/tests/resolv_cache_unit_test.cpp
@@ -20,6 +20,7 @@
#include <atomic>
#include <chrono>
#include <ctime>
+#include <span>
#include <thread>
#include <android-base/logging.h>
@@ -49,8 +50,8 @@ constexpr int MAX_ENTRIES = 64 * 2 * 5;
namespace {
struct CacheEntry {
- std::vector<char> query;
- std::vector<char> answer;
+ std::vector<uint8_t> query;
+ std::vector<uint8_t> answer;
};
struct SetupParams {
@@ -67,18 +68,17 @@ struct CacheStats {
int pendingReqTimeoutCount;
};
-std::vector<char> makeQuery(int op, const char* qname, int qclass, int qtype) {
+std::vector<uint8_t> makeQuery(int op, const char* qname, int qclass, int qtype) {
uint8_t buf[MAXPACKET] = {};
- const int len = res_nmkquery(op, qname, qclass, qtype, /*data=*/nullptr, /*datalen=*/0, buf,
- sizeof(buf),
- /*netcontext_flags=*/0);
- return std::vector<char>(buf, buf + len);
+ const int len = res_nmkquery(op, qname, qclass, qtype, {}, buf, /*netcontext_flags=*/0);
+ return std::vector<uint8_t>(buf, buf + len);
}
-std::vector<char> makeAnswer(const std::vector<char>& query, const char* rdata_str,
- const unsigned ttl) {
+std::vector<uint8_t> makeAnswer(const std::vector<uint8_t>& query, const char* rdata_str,
+ const unsigned ttl) {
test::DNSHeader header;
- header.read(query.data(), query.data() + query.size());
+ header.read(reinterpret_cast<const char*>(query.data()),
+ reinterpret_cast<const char*>(query.data()) + query.size());
for (const test::DNSQuestion& question : header.questions) {
std::string rname(question.qname.name);
@@ -94,7 +94,7 @@ std::vector<char> makeAnswer(const std::vector<char>& query, const char* rdata_s
char answer[MAXPACKET] = {};
char* answer_end = header.write(answer, answer + sizeof(answer));
- return std::vector<char>(answer, answer_end);
+ return std::vector<uint8_t>(answer, answer_end);
}
// Get the current time in unix timestamp since the Epoch.
@@ -155,9 +155,8 @@ class ResolvCacheTest : public ::testing::Test {
[[nodiscard]] bool cacheLookup(ResolvCacheStatus expectedCacheStatus, uint32_t netId,
const CacheEntry& ce, uint32_t flags = 0) {
int anslen = 0;
- std::vector<char> answer(MAXPACKET);
- const auto cacheStatus = resolv_cache_lookup(netId, ce.query.data(), ce.query.size(),
- answer.data(), answer.size(), &anslen, flags);
+ std::vector<uint8_t> answer(MAXPACKET);
+ const auto cacheStatus = resolv_cache_lookup(netId, ce.query, answer, &anslen, flags);
if (cacheStatus != expectedCacheStatus) {
ADD_FAILURE() << "cacheStatus: expected = " << expectedCacheStatus
<< ", actual =" << cacheStatus;
@@ -183,20 +182,20 @@ class ResolvCacheTest : public ::testing::Test {
}
int cacheAdd(uint32_t netId, const CacheEntry& ce) {
- return resolv_cache_add(netId, ce.query.data(), ce.query.size(), ce.answer.data(),
- ce.answer.size());
+ return resolv_cache_add(netId, ce.query, ce.answer);
}
- int cacheAdd(uint32_t netId, const std::vector<char>& query, const std::vector<char>& answer) {
- return resolv_cache_add(netId, query.data(), query.size(), answer.data(), answer.size());
+ int cacheAdd(uint32_t netId, const std::vector<uint8_t>& query,
+ const std::vector<uint8_t>& answer) {
+ return resolv_cache_add(netId, query, answer);
}
- int cacheGetExpiration(uint32_t netId, const std::vector<char>& query, time_t* expiration) {
+ int cacheGetExpiration(uint32_t netId, const std::vector<uint8_t>& query, time_t* expiration) {
return resolv_cache_get_expiration(netId, query, expiration);
}
void cacheQueryFailed(uint32_t netId, const CacheEntry& ce, uint32_t flags) {
- _resolv_cache_query_failed(netId, ce.query.data(), ce.query.size(), flags);
+ _resolv_cache_query_failed(netId, ce.query, flags);
}
int cacheSetupResolver(uint32_t netId, const SetupParams& setup) {
@@ -284,8 +283,8 @@ TEST_F(ResolvCacheTest, CreateAndDeleteCache) {
TEST_F(ResolvCacheTest, CacheAdd_InvalidArgs) {
EXPECT_EQ(0, cacheCreate(TEST_NETID));
- const std::vector<char> queryEmpty(MAXPACKET, 0);
- const std::vector<char> queryTooSmall(DNS_HEADER_SIZE - 1, 0);
+ const std::vector<uint8_t> queryEmpty(MAXPACKET, 0);
+ const std::vector<uint8_t> queryTooSmall(DNS_HEADER_SIZE - 1, 0);
CacheEntry ce = makeCacheEntry(QUERY, "valid.cache", ns_c_in, ns_t_a, "1.2.3.4");
EXPECT_EQ(-EINVAL, cacheAdd(TEST_NETID, queryEmpty, ce.answer));
@@ -395,15 +394,14 @@ TEST_F(ResolvCacheTest, CacheLookup_Types) {
TEST_F(ResolvCacheTest, CacheLookup_InvalidArgs) {
EXPECT_EQ(0, cacheCreate(TEST_NETID));
- const std::vector<char> queryEmpty(MAXPACKET, 0);
- const std::vector<char> queryTooSmall(DNS_HEADER_SIZE - 1, 0);
- std::vector<char> answerTooSmall(DNS_HEADER_SIZE - 1, 0);
+ const std::vector<uint8_t> queryEmpty(MAXPACKET, 0);
+ const std::vector<uint8_t> queryTooSmall(DNS_HEADER_SIZE - 1, 0);
+ std::vector<uint8_t> answerTooSmall(DNS_HEADER_SIZE - 1, 0);
const CacheEntry ce = makeCacheEntry(QUERY, "valid.cache", ns_c_in, ns_t_a, "1.2.3.4");
- auto cacheLookupFn = [](const std::vector<char>& query,
- std::vector<char> answer) -> ResolvCacheStatus {
+ auto cacheLookupFn = [](const std::vector<uint8_t>& query,
+ std::vector<uint8_t> answer) -> ResolvCacheStatus {
int anslen = 0;
- return resolv_cache_lookup(TEST_NETID, query.data(), query.size(), answer.data(),
- answer.size(), &anslen, 0);
+ return resolv_cache_lookup(TEST_NETID, query, answer, &anslen, 0);
};
EXPECT_EQ(0, cacheAdd(TEST_NETID, ce));