summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTony Mak <tonymak@google.com>2019-03-29 14:05:34 +0000
committerTony Mak <tonymak@google.com>2019-03-29 14:16:00 +0000
commit2753d687954df8a5e4a03776daa3d18c73894358 (patch)
tree145b66684d05996067b5d1e2e5f5fb08284d463b
parentdf54e74fd621225ad3ac7c77c89d0e1c965ba01e (diff)
downloadandroid_external_libtextclassifier-2753d687954df8a5e4a03776daa3d18c73894358.tar.gz
android_external_libtextclassifier-2753d687954df8a5e4a03776daa3d18c73894358.tar.bz2
android_external_libtextclassifier-2753d687954df8a5e4a03776daa3d18c73894358.zip
Import libtextclassifier
Not exporting the model file, as the content description feature needs some integration work which takes time to review. BUG: 129481059 Test: atest frameworks/base/core/tests/coretests/src/android/view/textclassifier/ Change-Id: I24c7dcaffe79c523b23b591767e2aeb3f581b3f7
-rw-r--r--Android.bp1
-rw-r--r--actions/actions-suggestions.cc4
-rw-r--r--actions/actions-suggestions.h2
-rw-r--r--actions/actions-suggestions_test.cc18
-rw-r--r--actions/actions_jni.cc3
-rwxr-xr-xactions/actions_model.fbs4
-rw-r--r--actions/zlib-utils.cc18
-rw-r--r--annotator/number/number.cc2
-rw-r--r--annotator/number/number_test.cc16
-rw-r--r--utils/sentencepiece/double_array_trie.cc2
-rw-r--r--utils/sentencepiece/double_array_trie.h13
-rw-r--r--utils/sentencepiece/encoder_test.cc7
-rw-r--r--utils/sentencepiece/sorted_strings_table.cc16
-rw-r--r--utils/sentencepiece/sorted_strings_table.h5
-rw-r--r--utils/sentencepiece/sorted_strings_table_test.cc3
-rw-r--r--utils/tflite/text_encoder_config.fbs2
16 files changed, 68 insertions, 48 deletions
diff --git a/Android.bp b/Android.bp
index c1d9cb5..66b8365 100644
--- a/Android.bp
+++ b/Android.bp
@@ -74,6 +74,7 @@ cc_defaults {
"-Wno-unused-parameter",
"-Wno-extern-c-compat",
+ "-funsigned-char",
"-fvisibility=hidden",
"-DLIBTEXTCLASSIFIER_UNILIB_ICU",
"-DZLIB_CONST",
diff --git a/actions/actions-suggestions.cc b/actions/actions-suggestions.cc
index d8be1b0..1d4a70f 100644
--- a/actions/actions-suggestions.cc
+++ b/actions/actions-suggestions.cc
@@ -1035,7 +1035,7 @@ std::vector<int> ActionsSuggestions::DeduplicateAnnotations(
bool ActionsSuggestions::FillAnnotationFromMatchGroup(
const UniLib::RegexMatcher* matcher,
- const RulesModel_::Rule_::RuleActionSpec_::CapturingGroup* group,
+ const RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroup* group,
const int message_index, ActionSuggestionAnnotation* annotation) const {
if (group->annotation_name() != nullptr ||
group->annotation_type() != nullptr) {
@@ -1100,7 +1100,7 @@ bool ActionsSuggestions::SuggestActionsFromRules(
// Add entity data from rule capturing groups.
if (rule_action->capturing_group() != nullptr) {
- for (const RulesModel_::Rule_::RuleActionSpec_::CapturingGroup*
+ for (const RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroup*
group : *rule_action->capturing_group()) {
if (group->entity_field() != nullptr) {
TC3_CHECK(entity_data != nullptr);
diff --git a/actions/actions-suggestions.h b/actions/actions-suggestions.h
index 3c2ff09..61e052b 100644
--- a/actions/actions-suggestions.h
+++ b/actions/actions-suggestions.h
@@ -207,7 +207,7 @@ class ActionsSuggestions {
bool FillAnnotationFromMatchGroup(
const UniLib::RegexMatcher* matcher,
- const RulesModel_::Rule_::RuleActionSpec_::CapturingGroup* group,
+ const RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroup* group,
const int message_index, ActionSuggestionAnnotation* annotation) const;
const ActionsModel* model_;
diff --git a/actions/actions-suggestions_test.cc b/actions/actions-suggestions_test.cc
index b510284..b0aeec6 100644
--- a/actions/actions-suggestions_test.cc
+++ b/actions/actions-suggestions_test.cc
@@ -554,6 +554,8 @@ TEST_F(ActionsSuggestionsTest,
ReadFile(GetModelPath() + kModelFileName);
std::unique_ptr<ActionsModelT> actions_model =
UnPackActionsModel(actions_model_string.c_str());
+ actions_model->low_confidence_rules.reset();
+
// Add custom triggering rule.
actions_model->rules.reset(new RulesModelT());
actions_model->rules->rule.emplace_back(new RulesModel_::RuleT);
@@ -731,16 +733,16 @@ TEST_F(ActionsSuggestionsTest, CreateActionsFromRules) {
// Set capturing groups for entity data.
rule->actions.back()->capturing_group.emplace_back(
- new RulesModel_::Rule_::RuleActionSpec_::CapturingGroupT);
- RulesModel_::Rule_::RuleActionSpec_::CapturingGroupT* greeting_group =
+ new RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT);
+ RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT* greeting_group =
rule->actions.back()->capturing_group.back().get();
greeting_group->group_id = 0;
greeting_group->entity_field.reset(new FlatbufferFieldPathT);
greeting_group->entity_field->field.emplace_back(new FlatbufferFieldT);
greeting_group->entity_field->field.back()->field_name = "greeting";
rule->actions.back()->capturing_group.emplace_back(
- new RulesModel_::Rule_::RuleActionSpec_::CapturingGroupT);
- RulesModel_::Rule_::RuleActionSpec_::CapturingGroupT* location_group =
+ new RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT);
+ RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT* location_group =
rule->actions.back()->capturing_group.back().get();
location_group->group_id = 1;
location_group->entity_field.reset(new FlatbufferFieldPathT);
@@ -802,8 +804,8 @@ TEST_F(ActionsSuggestionsTest, CreatesTextRepliesFromRules) {
// Set capturing groups for entity data.
rule->actions.back()->capturing_group.emplace_back(
- new RulesModel_::Rule_::RuleActionSpec_::CapturingGroupT);
- RulesModel_::Rule_::RuleActionSpec_::CapturingGroupT* code_group =
+ new RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT);
+ RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT* code_group =
rule->actions.back()->capturing_group.back().get();
code_group->group_id = 1;
code_group->text_reply.reset(new ActionSuggestionSpecT);
@@ -915,8 +917,8 @@ TEST_F(ActionsSuggestionsTest, DeduplicateConflictingActions) {
action->priority_score = 2.0f;
action->type = "test_code";
rule->actions.back()->capturing_group.emplace_back(
- new RulesModel_::Rule_::RuleActionSpec_::CapturingGroupT);
- RulesModel_::Rule_::RuleActionSpec_::CapturingGroupT* code_group =
+ new RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT);
+ RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT* code_group =
rule->actions.back()->capturing_group.back().get();
code_group->group_id = 1;
code_group->annotation_name = "code";
diff --git a/actions/actions_jni.cc b/actions/actions_jni.cc
index 1ee595f..20891fa 100644
--- a/actions/actions_jni.cc
+++ b/actions/actions_jni.cc
@@ -376,8 +376,7 @@ TC3_JNI_METHOD(jobjectArray, TC3_ACTIONS_CLASS_NAME, nativeSuggestActions)
annotator ? annotator->entity_data_schema() : nullptr;
return ActionSuggestionsToJObjectArray(
env, context, app_context, anntotations_entity_data_schema,
- response.actions, conversation,
- /*device_locales=*/nullptr, generate_intents);
+ response.actions, conversation, device_locales, generate_intents);
}
TC3_JNI_METHOD(void, TC3_ACTIONS_CLASS_NAME, nativeCloseActionsModel)
diff --git a/actions/actions_model.fbs b/actions/actions_model.fbs
index 885a717..dc4520a 100755
--- a/actions/actions_model.fbs
+++ b/actions/actions_model.fbs
@@ -332,7 +332,7 @@ table RankingOptions {
// Entity data to set from capturing groups.
namespace libtextclassifier3.RulesModel_.Rule_.RuleActionSpec_;
-table CapturingGroup {
+table RuleCapturingGroup {
// The id of group.
group_id:int;
@@ -357,7 +357,7 @@ table RuleActionSpec {
// The action.
action:ActionSuggestionSpec;
- capturing_group:[RuleActionSpec_.CapturingGroup];
+ capturing_group:[RuleActionSpec_.RuleCapturingGroup];
}
// List of regular expression matchers.
diff --git a/actions/zlib-utils.cc b/actions/zlib-utils.cc
index 317623b..b1d997d 100644
--- a/actions/zlib-utils.cc
+++ b/actions/zlib-utils.cc
@@ -42,12 +42,9 @@ bool CompressActionsModel(ActionsModelT* model) {
}
}
- if (model->preconditions != nullptr &&
- model->preconditions->low_confidence_rules != nullptr) {
- for (int i = 0; i < model->preconditions->low_confidence_rules->rule.size();
- i++) {
- RulesModel_::RuleT* rule =
- model->preconditions->low_confidence_rules->rule[i].get();
+ if (model->low_confidence_rules != nullptr) {
+ for (int i = 0; i < model->low_confidence_rules->rule.size(); i++) {
+ RulesModel_::RuleT* rule = model->low_confidence_rules->rule[i].get();
if (!rule->pattern.empty()) {
rule->compressed_pattern.reset(new CompressedBufferT);
zlib_compressor->Compress(rule->pattern,
@@ -113,12 +110,9 @@ bool DecompressActionsModel(ActionsModelT* model) {
}
// Decompress low confidence rules.
- if (model->preconditions != nullptr &&
- model->preconditions->low_confidence_rules != nullptr) {
- for (int i = 0; i < model->preconditions->low_confidence_rules->rule.size();
- i++) {
- RulesModel_::RuleT* rule =
- model->preconditions->low_confidence_rules->rule[i].get();
+ if (model->low_confidence_rules != nullptr) {
+ for (int i = 0; i < model->low_confidence_rules->rule.size(); i++) {
+ RulesModel_::RuleT* rule = model->low_confidence_rules->rule[i].get();
if (!zlib_decompressor->MaybeDecompress(rule->compressed_pattern.get(),
&rule->pattern)) {
TC3_LOG(ERROR) << "Cannot decompress pattern: " << i;
diff --git a/annotator/number/number.cc b/annotator/number/number.cc
index e58e682..bc3a2fe 100644
--- a/annotator/number/number.cc
+++ b/annotator/number/number.cc
@@ -69,7 +69,7 @@ bool NumberAnnotator::FindAll(const UnicodeText& context,
classification.priority_score = options_->priority_score();
AnnotatedSpan annotated_span;
- annotated_span.span = {token.start - num_prefix_codepoints,
+ annotated_span.span = {token.start + num_prefix_codepoints,
token.end - num_suffix_codepoints};
annotated_span.classification.push_back(classification);
diff --git a/annotator/number/number_test.cc b/annotator/number/number_test.cc
index afa9444..d3b2e8c 100644
--- a/annotator/number/number_test.cc
+++ b/annotator/number/number_test.cc
@@ -145,6 +145,22 @@ TEST_F(NumberAnnotatorTest, FindsNumberWithPunctuation) {
Field(&ClassificationResult::numeric_value, 9)))))));
}
+TEST_F(NumberAnnotatorTest, HandlesNumbersAtBeginning) {
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(number_annotator_.FindAll(
+ UTF8ToUnicodeText("-5"), AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ &result));
+
+ EXPECT_THAT(
+ result,
+ ElementsAre(
+ AllOf(Field(&AnnotatedSpan::span, CodepointSpan(0, 2)),
+ Field(&AnnotatedSpan::classification,
+ ElementsAre(AllOf(
+ Field(&ClassificationResult::collection, "number"),
+ Field(&ClassificationResult::numeric_value, -5)))))));
+}
+
TEST_F(NumberAnnotatorTest, WhenLowestSupportedNumberParsesIt) {
ClassificationResult classification_result;
EXPECT_TRUE(number_annotator_.ClassifyText(
diff --git a/utils/sentencepiece/double_array_trie.cc b/utils/sentencepiece/double_array_trie.cc
index b6ed1e1..a2b66ea 100644
--- a/utils/sentencepiece/double_array_trie.cc
+++ b/utils/sentencepiece/double_array_trie.cc
@@ -21,7 +21,7 @@ namespace libtextclassifier3 {
bool DoubleArrayTrie::GatherPrefixMatches(
StringPiece input, const std::function<void(TrieMatch)>& update_fn) const {
- unsigned int pos = 0;
+ uint32 pos = 0;
if (nodes_length_ == 0) {
TC3_LOG(WARNING) << "Trie is empty. Skipping.";
return true;
diff --git a/utils/sentencepiece/double_array_trie.h b/utils/sentencepiece/double_array_trie.h
index e88819a..0614fb4 100644
--- a/utils/sentencepiece/double_array_trie.h
+++ b/utils/sentencepiece/double_array_trie.h
@@ -21,6 +21,7 @@
#include <vector>
#include "utils/base/endian.h"
+#include "utils/base/integral_types.h"
#include "utils/sentencepiece/matcher.h"
#include "utils/strings/stringpiece.h"
@@ -35,7 +36,7 @@ namespace libtextclassifier3 {
// character during matching.
// We account for endianness when using the node values, as they are serialized
// (in little endian) as bytes in the flatbuffer model.
-typedef unsigned int TrieNode;
+typedef uint32 TrieNode;
// A memory mappable trie, compatible with Darts::DoubleArray.
class DoubleArrayTrie : public SentencePieceMatcher {
@@ -53,22 +54,22 @@ class DoubleArrayTrie : public SentencePieceMatcher {
private:
// Returns whether a node as a leaf as a child.
- bool has_leaf(unsigned int i) const { return nodes_[i] & 0x100; }
+ bool has_leaf(uint32 i) const { return nodes_[i] & 0x100; }
// Available when a node is a leaf.
- int value(unsigned int i) const {
+ int value(uint32 i) const {
return static_cast<int>(LittleEndian::ToHost32(nodes_[i]) & 0x7fffffff);
}
// Label associated with a node.
// A leaf node will have the MSB set and thus return an invalid label.
- unsigned int label(unsigned int i) const {
+ uint32 label(uint32 i) const {
return LittleEndian::ToHost32(nodes_[i]) & 0x800000ff;
}
// Returns offset to children.
- unsigned int offset(unsigned int i) const {
- const unsigned int node = LittleEndian::ToHost32(nodes_[i]);
+ uint32 offset(uint32 i) const {
+ const uint32 node = LittleEndian::ToHost32(nodes_[i]);
return (node >> 10) << ((node & 0x200) >> 6);
}
diff --git a/utils/sentencepiece/encoder_test.cc b/utils/sentencepiece/encoder_test.cc
index 3cdb3e3..9082cca 100644
--- a/utils/sentencepiece/encoder_test.cc
+++ b/utils/sentencepiece/encoder_test.cc
@@ -20,6 +20,7 @@
#include "gmock/gmock.h"
#include "gtest/gtest.h"
+#include "utils/base/integral_types.h"
#include "utils/sentencepiece/encoder.h"
#include "utils/sentencepiece/sorted_strings_table.h"
@@ -30,7 +31,7 @@ using testing::ElementsAre;
TEST(EncoderTest, SimpleTokenization) {
const char pieces[] = "hell\0hello\0o\0there\0";
- const int offsets[] = {0, 5, 11, 13};
+ const uint32 offsets[] = {0, 5, 11, 13};
float scores[] = {-0.5, -1.0, -10.0, -1.0};
std::unique_ptr<SentencePieceMatcher> matcher(new SortedStringsTable(
/*num_pieces=*/4, offsets, StringPiece(pieces, 18)));
@@ -55,7 +56,7 @@ TEST(EncoderTest, SimpleTokenization) {
TEST(EncoderTest, HandlesEdgeCases) {
const char pieces[] = "hell\0hello\0o\0there\0";
- const int offsets[] = {0, 5, 11, 13};
+ const uint32 offsets[] = {0, 5, 11, 13};
float scores[] = {-0.5, -1.0, -10.0, -1.0};
std::unique_ptr<SentencePieceMatcher> matcher(new SortedStringsTable(
/*num_pieces=*/4, offsets, StringPiece(pieces, 18)));
@@ -85,7 +86,7 @@ TEST(EncoderTest, HandlesEdgeCases) {
TEST(EncoderTest, HandlesOutOfDictionary) {
const char pieces[] = "hell\0hello\0o\0there\0";
- const int offsets[] = {0, 5, 11, 13};
+ const uint32 offsets[] = {0, 5, 11, 13};
float scores[] = {-0.5, -1.0, -10.0, -1.0};
std::unique_ptr<SentencePieceMatcher> matcher(new SortedStringsTable(
/*num_pieces=*/4, offsets, StringPiece(pieces, 18)));
diff --git a/utils/sentencepiece/sorted_strings_table.cc b/utils/sentencepiece/sorted_strings_table.cc
index 96637d8..8e7e9ba 100644
--- a/utils/sentencepiece/sorted_strings_table.cc
+++ b/utils/sentencepiece/sorted_strings_table.cc
@@ -41,15 +41,19 @@ void SortedStringsTable::GatherPrefixMatches(
// `lower_bound` to find the start of the range of matching pieces.
// `upper_bound` to find the non-inclusive end of the range.
left = (std::lower_bound(
- offsets_ + left, offsets_ + right, input[match_length],
- [this, match_length](int piece_offset, int c) -> bool {
- return pieces_[piece_offset + match_length] < c;
+ offsets_ + left, offsets_ + right,
+ static_cast<unsigned char>(input[match_length]),
+ [this, match_length](uint32 piece_offset, uint32 c) -> bool {
+ return static_cast<unsigned char>(
+ pieces_[piece_offset + match_length]) < c;
}) -
offsets_);
right = (std::upper_bound(
- offsets_ + left, offsets_ + right, input[match_length],
- [this, match_length](int c, int piece_offset) -> bool {
- return c < pieces_[piece_offset + match_length];
+ offsets_ + left, offsets_ + right,
+ static_cast<unsigned char>(input[match_length]),
+ [this, match_length](uint32 c, uint32 piece_offset) -> bool {
+ return c < static_cast<unsigned char>(
+ pieces_[piece_offset + match_length]);
}) -
offsets_);
span_size = right - left;
diff --git a/utils/sentencepiece/sorted_strings_table.h b/utils/sentencepiece/sorted_strings_table.h
index 61a2239..69f638a 100644
--- a/utils/sentencepiece/sorted_strings_table.h
+++ b/utils/sentencepiece/sorted_strings_table.h
@@ -20,6 +20,7 @@
#include <functional>
#include <vector>
+#include "utils/base/integral_types.h"
#include "utils/sentencepiece/matcher.h"
#include "utils/strings/stringpiece.h"
@@ -36,7 +37,7 @@ namespace libtextclassifier3 {
// switching to a linear sweep for prefix match testing.
class SortedStringsTable : public SentencePieceMatcher {
public:
- SortedStringsTable(const int num_pieces, const int* offsets,
+ SortedStringsTable(const int num_pieces, const uint32* offsets,
StringPiece pieces,
const int use_linear_scan_threshold = 10)
: num_pieces_(num_pieces),
@@ -56,7 +57,7 @@ class SortedStringsTable : public SentencePieceMatcher {
StringPiece input, const std::function<void(TrieMatch)>& update_fn) const;
const int num_pieces_;
- const int* offsets_;
+ const uint32* offsets_;
const StringPiece pieces_;
const int use_linear_scan_threshold_;
};
diff --git a/utils/sentencepiece/sorted_strings_table_test.cc b/utils/sentencepiece/sorted_strings_table_test.cc
index 10824d9..4dff29d 100644
--- a/utils/sentencepiece/sorted_strings_table_test.cc
+++ b/utils/sentencepiece/sorted_strings_table_test.cc
@@ -19,6 +19,7 @@
#include "gmock/gmock.h"
#include "gtest/gtest.h"
+#include "utils/base/integral_types.h"
#include "utils/sentencepiece/sorted_strings_table.h"
namespace libtextclassifier3 {
@@ -26,7 +27,7 @@ namespace {
TEST(SortedStringsTest, Lookup) {
const char pieces[] = "hell\0hello\0o\0there\0";
- const int offsets[] = {0, 5, 11, 13};
+ const uint32 offsets[] = {0, 5, 11, 13};
SortedStringsTable table(/*num_pieces=*/4, offsets, StringPiece(pieces, 18),
/*use_linear_scan_threshold=*/1);
diff --git a/utils/tflite/text_encoder_config.fbs b/utils/tflite/text_encoder_config.fbs
index 8ae8fc5..4ffade4 100644
--- a/utils/tflite/text_encoder_config.fbs
+++ b/utils/tflite/text_encoder_config.fbs
@@ -60,6 +60,6 @@ table TextEncoderConfig {
// Serialized sentence pieces.
pieces:string;
- pieces_offsets:[int32];
+ pieces_offsets:[uint32];
matcher_type: SentencePieceMatcherType = MAPPED_TRIE;
}