diff options
author | Tony Mak <tonymak@google.com> | 2019-03-29 14:05:34 +0000 |
---|---|---|
committer | Tony Mak <tonymak@google.com> | 2019-03-29 14:16:00 +0000 |
commit | 2753d687954df8a5e4a03776daa3d18c73894358 (patch) | |
tree | 145b66684d05996067b5d1e2e5f5fb08284d463b | |
parent | df54e74fd621225ad3ac7c77c89d0e1c965ba01e (diff) | |
download | android_external_libtextclassifier-2753d687954df8a5e4a03776daa3d18c73894358.tar.gz android_external_libtextclassifier-2753d687954df8a5e4a03776daa3d18c73894358.tar.bz2 android_external_libtextclassifier-2753d687954df8a5e4a03776daa3d18c73894358.zip |
Import libtextclassifier
Not exporting the model file, as the content description feature needs
some integration work which takes time to review.
BUG: 129481059
Test: atest frameworks/base/core/tests/coretests/src/android/view/textclassifier/
Change-Id: I24c7dcaffe79c523b23b591767e2aeb3f581b3f7
-rw-r--r-- | Android.bp | 1 | ||||
-rw-r--r-- | actions/actions-suggestions.cc | 4 | ||||
-rw-r--r-- | actions/actions-suggestions.h | 2 | ||||
-rw-r--r-- | actions/actions-suggestions_test.cc | 18 | ||||
-rw-r--r-- | actions/actions_jni.cc | 3 | ||||
-rwxr-xr-x | actions/actions_model.fbs | 4 | ||||
-rw-r--r-- | actions/zlib-utils.cc | 18 | ||||
-rw-r--r-- | annotator/number/number.cc | 2 | ||||
-rw-r--r-- | annotator/number/number_test.cc | 16 | ||||
-rw-r--r-- | utils/sentencepiece/double_array_trie.cc | 2 | ||||
-rw-r--r-- | utils/sentencepiece/double_array_trie.h | 13 | ||||
-rw-r--r-- | utils/sentencepiece/encoder_test.cc | 7 | ||||
-rw-r--r-- | utils/sentencepiece/sorted_strings_table.cc | 16 | ||||
-rw-r--r-- | utils/sentencepiece/sorted_strings_table.h | 5 | ||||
-rw-r--r-- | utils/sentencepiece/sorted_strings_table_test.cc | 3 | ||||
-rw-r--r-- | utils/tflite/text_encoder_config.fbs | 2 |
16 files changed, 68 insertions, 48 deletions
@@ -74,6 +74,7 @@ cc_defaults { "-Wno-unused-parameter", "-Wno-extern-c-compat", + "-funsigned-char", "-fvisibility=hidden", "-DLIBTEXTCLASSIFIER_UNILIB_ICU", "-DZLIB_CONST", diff --git a/actions/actions-suggestions.cc b/actions/actions-suggestions.cc index d8be1b0..1d4a70f 100644 --- a/actions/actions-suggestions.cc +++ b/actions/actions-suggestions.cc @@ -1035,7 +1035,7 @@ std::vector<int> ActionsSuggestions::DeduplicateAnnotations( bool ActionsSuggestions::FillAnnotationFromMatchGroup( const UniLib::RegexMatcher* matcher, - const RulesModel_::Rule_::RuleActionSpec_::CapturingGroup* group, + const RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroup* group, const int message_index, ActionSuggestionAnnotation* annotation) const { if (group->annotation_name() != nullptr || group->annotation_type() != nullptr) { @@ -1100,7 +1100,7 @@ bool ActionsSuggestions::SuggestActionsFromRules( // Add entity data from rule capturing groups. if (rule_action->capturing_group() != nullptr) { - for (const RulesModel_::Rule_::RuleActionSpec_::CapturingGroup* + for (const RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroup* group : *rule_action->capturing_group()) { if (group->entity_field() != nullptr) { TC3_CHECK(entity_data != nullptr); diff --git a/actions/actions-suggestions.h b/actions/actions-suggestions.h index 3c2ff09..61e052b 100644 --- a/actions/actions-suggestions.h +++ b/actions/actions-suggestions.h @@ -207,7 +207,7 @@ class ActionsSuggestions { bool FillAnnotationFromMatchGroup( const UniLib::RegexMatcher* matcher, - const RulesModel_::Rule_::RuleActionSpec_::CapturingGroup* group, + const RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroup* group, const int message_index, ActionSuggestionAnnotation* annotation) const; const ActionsModel* model_; diff --git a/actions/actions-suggestions_test.cc b/actions/actions-suggestions_test.cc index b510284..b0aeec6 100644 --- a/actions/actions-suggestions_test.cc +++ b/actions/actions-suggestions_test.cc @@ -554,6 
+554,8 @@ TEST_F(ActionsSuggestionsTest, ReadFile(GetModelPath() + kModelFileName); std::unique_ptr<ActionsModelT> actions_model = UnPackActionsModel(actions_model_string.c_str()); + actions_model->low_confidence_rules.reset(); + // Add custom triggering rule. actions_model->rules.reset(new RulesModelT()); actions_model->rules->rule.emplace_back(new RulesModel_::RuleT); @@ -731,16 +733,16 @@ TEST_F(ActionsSuggestionsTest, CreateActionsFromRules) { // Set capturing groups for entity data. rule->actions.back()->capturing_group.emplace_back( - new RulesModel_::Rule_::RuleActionSpec_::CapturingGroupT); - RulesModel_::Rule_::RuleActionSpec_::CapturingGroupT* greeting_group = + new RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT); + RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT* greeting_group = rule->actions.back()->capturing_group.back().get(); greeting_group->group_id = 0; greeting_group->entity_field.reset(new FlatbufferFieldPathT); greeting_group->entity_field->field.emplace_back(new FlatbufferFieldT); greeting_group->entity_field->field.back()->field_name = "greeting"; rule->actions.back()->capturing_group.emplace_back( - new RulesModel_::Rule_::RuleActionSpec_::CapturingGroupT); - RulesModel_::Rule_::RuleActionSpec_::CapturingGroupT* location_group = + new RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT); + RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT* location_group = rule->actions.back()->capturing_group.back().get(); location_group->group_id = 1; location_group->entity_field.reset(new FlatbufferFieldPathT); @@ -802,8 +804,8 @@ TEST_F(ActionsSuggestionsTest, CreatesTextRepliesFromRules) { // Set capturing groups for entity data. 
rule->actions.back()->capturing_group.emplace_back( - new RulesModel_::Rule_::RuleActionSpec_::CapturingGroupT); - RulesModel_::Rule_::RuleActionSpec_::CapturingGroupT* code_group = + new RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT); + RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT* code_group = rule->actions.back()->capturing_group.back().get(); code_group->group_id = 1; code_group->text_reply.reset(new ActionSuggestionSpecT); @@ -915,8 +917,8 @@ TEST_F(ActionsSuggestionsTest, DeduplicateConflictingActions) { action->priority_score = 2.0f; action->type = "test_code"; rule->actions.back()->capturing_group.emplace_back( - new RulesModel_::Rule_::RuleActionSpec_::CapturingGroupT); - RulesModel_::Rule_::RuleActionSpec_::CapturingGroupT* code_group = + new RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT); + RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT* code_group = rule->actions.back()->capturing_group.back().get(); code_group->group_id = 1; code_group->annotation_name = "code"; diff --git a/actions/actions_jni.cc b/actions/actions_jni.cc index 1ee595f..20891fa 100644 --- a/actions/actions_jni.cc +++ b/actions/actions_jni.cc @@ -376,8 +376,7 @@ TC3_JNI_METHOD(jobjectArray, TC3_ACTIONS_CLASS_NAME, nativeSuggestActions) annotator ? annotator->entity_data_schema() : nullptr; return ActionSuggestionsToJObjectArray( env, context, app_context, anntotations_entity_data_schema, - response.actions, conversation, - /*device_locales=*/nullptr, generate_intents); + response.actions, conversation, device_locales, generate_intents); } TC3_JNI_METHOD(void, TC3_ACTIONS_CLASS_NAME, nativeCloseActionsModel) diff --git a/actions/actions_model.fbs b/actions/actions_model.fbs index 885a717..dc4520a 100755 --- a/actions/actions_model.fbs +++ b/actions/actions_model.fbs @@ -332,7 +332,7 @@ table RankingOptions { // Entity data to set from capturing groups. 
namespace libtextclassifier3.RulesModel_.Rule_.RuleActionSpec_; -table CapturingGroup { +table RuleCapturingGroup { // The id of group. group_id:int; @@ -357,7 +357,7 @@ table RuleActionSpec { // The action. action:ActionSuggestionSpec; - capturing_group:[RuleActionSpec_.CapturingGroup]; + capturing_group:[RuleActionSpec_.RuleCapturingGroup]; } // List of regular expression matchers. diff --git a/actions/zlib-utils.cc b/actions/zlib-utils.cc index 317623b..b1d997d 100644 --- a/actions/zlib-utils.cc +++ b/actions/zlib-utils.cc @@ -42,12 +42,9 @@ bool CompressActionsModel(ActionsModelT* model) { } } - if (model->preconditions != nullptr && - model->preconditions->low_confidence_rules != nullptr) { - for (int i = 0; i < model->preconditions->low_confidence_rules->rule.size(); - i++) { - RulesModel_::RuleT* rule = - model->preconditions->low_confidence_rules->rule[i].get(); + if (model->low_confidence_rules != nullptr) { + for (int i = 0; i < model->low_confidence_rules->rule.size(); i++) { + RulesModel_::RuleT* rule = model->low_confidence_rules->rule[i].get(); if (!rule->pattern.empty()) { rule->compressed_pattern.reset(new CompressedBufferT); zlib_compressor->Compress(rule->pattern, @@ -113,12 +110,9 @@ bool DecompressActionsModel(ActionsModelT* model) { } // Decompress low confidence rules. 
- if (model->preconditions != nullptr && - model->preconditions->low_confidence_rules != nullptr) { - for (int i = 0; i < model->preconditions->low_confidence_rules->rule.size(); - i++) { - RulesModel_::RuleT* rule = - model->preconditions->low_confidence_rules->rule[i].get(); + if (model->low_confidence_rules != nullptr) { + for (int i = 0; i < model->low_confidence_rules->rule.size(); i++) { + RulesModel_::RuleT* rule = model->low_confidence_rules->rule[i].get(); if (!zlib_decompressor->MaybeDecompress(rule->compressed_pattern.get(), &rule->pattern)) { TC3_LOG(ERROR) << "Cannot decompress pattern: " << i; diff --git a/annotator/number/number.cc b/annotator/number/number.cc index e58e682..bc3a2fe 100644 --- a/annotator/number/number.cc +++ b/annotator/number/number.cc @@ -69,7 +69,7 @@ bool NumberAnnotator::FindAll(const UnicodeText& context, classification.priority_score = options_->priority_score(); AnnotatedSpan annotated_span; - annotated_span.span = {token.start - num_prefix_codepoints, + annotated_span.span = {token.start + num_prefix_codepoints, token.end - num_suffix_codepoints}; annotated_span.classification.push_back(classification); diff --git a/annotator/number/number_test.cc b/annotator/number/number_test.cc index afa9444..d3b2e8c 100644 --- a/annotator/number/number_test.cc +++ b/annotator/number/number_test.cc @@ -145,6 +145,22 @@ TEST_F(NumberAnnotatorTest, FindsNumberWithPunctuation) { Field(&ClassificationResult::numeric_value, 9))))))); } +TEST_F(NumberAnnotatorTest, HandlesNumbersAtBeginning) { + std::vector<AnnotatedSpan> result; + EXPECT_TRUE(number_annotator_.FindAll( + UTF8ToUnicodeText("-5"), AnnotationUsecase_ANNOTATION_USECASE_RAW, + &result)); + + EXPECT_THAT( + result, + ElementsAre( + AllOf(Field(&AnnotatedSpan::span, CodepointSpan(0, 2)), + Field(&AnnotatedSpan::classification, + ElementsAre(AllOf( + Field(&ClassificationResult::collection, "number"), + Field(&ClassificationResult::numeric_value, -5))))))); +} + 
TEST_F(NumberAnnotatorTest, WhenLowestSupportedNumberParsesIt) { ClassificationResult classification_result; EXPECT_TRUE(number_annotator_.ClassifyText( diff --git a/utils/sentencepiece/double_array_trie.cc b/utils/sentencepiece/double_array_trie.cc index b6ed1e1..a2b66ea 100644 --- a/utils/sentencepiece/double_array_trie.cc +++ b/utils/sentencepiece/double_array_trie.cc @@ -21,7 +21,7 @@ namespace libtextclassifier3 { bool DoubleArrayTrie::GatherPrefixMatches( StringPiece input, const std::function<void(TrieMatch)>& update_fn) const { - unsigned int pos = 0; + uint32 pos = 0; if (nodes_length_ == 0) { TC3_LOG(WARNING) << "Trie is empty. Skipping."; return true; diff --git a/utils/sentencepiece/double_array_trie.h b/utils/sentencepiece/double_array_trie.h index e88819a..0614fb4 100644 --- a/utils/sentencepiece/double_array_trie.h +++ b/utils/sentencepiece/double_array_trie.h @@ -21,6 +21,7 @@ #include <vector> #include "utils/base/endian.h" +#include "utils/base/integral_types.h" #include "utils/sentencepiece/matcher.h" #include "utils/strings/stringpiece.h" @@ -35,7 +36,7 @@ namespace libtextclassifier3 { // character during matching. // We account for endianness when using the node values, as they are serialized // (in little endian) as bytes in the flatbuffer model. -typedef unsigned int TrieNode; +typedef uint32 TrieNode; // A memory mappable trie, compatible with Darts::DoubleArray. class DoubleArrayTrie : public SentencePieceMatcher { @@ -53,22 +54,22 @@ class DoubleArrayTrie : public SentencePieceMatcher { private: // Returns whether a node as a leaf as a child. - bool has_leaf(unsigned int i) const { return nodes_[i] & 0x100; } + bool has_leaf(uint32 i) const { return nodes_[i] & 0x100; } // Available when a node is a leaf. - int value(unsigned int i) const { + int value(uint32 i) const { return static_cast<int>(LittleEndian::ToHost32(nodes_[i]) & 0x7fffffff); } // Label associated with a node. 
// A leaf node will have the MSB set and thus return an invalid label. - unsigned int label(unsigned int i) const { + uint32 label(uint32 i) const { return LittleEndian::ToHost32(nodes_[i]) & 0x800000ff; } // Returns offset to children. - unsigned int offset(unsigned int i) const { - const unsigned int node = LittleEndian::ToHost32(nodes_[i]); + uint32 offset(uint32 i) const { + const uint32 node = LittleEndian::ToHost32(nodes_[i]); return (node >> 10) << ((node & 0x200) >> 6); } diff --git a/utils/sentencepiece/encoder_test.cc b/utils/sentencepiece/encoder_test.cc index 3cdb3e3..9082cca 100644 --- a/utils/sentencepiece/encoder_test.cc +++ b/utils/sentencepiece/encoder_test.cc @@ -20,6 +20,7 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "utils/base/integral_types.h" #include "utils/sentencepiece/encoder.h" #include "utils/sentencepiece/sorted_strings_table.h" @@ -30,7 +31,7 @@ using testing::ElementsAre; TEST(EncoderTest, SimpleTokenization) { const char pieces[] = "hell\0hello\0o\0there\0"; - const int offsets[] = {0, 5, 11, 13}; + const uint32 offsets[] = {0, 5, 11, 13}; float scores[] = {-0.5, -1.0, -10.0, -1.0}; std::unique_ptr<SentencePieceMatcher> matcher(new SortedStringsTable( /*num_pieces=*/4, offsets, StringPiece(pieces, 18))); @@ -55,7 +56,7 @@ TEST(EncoderTest, SimpleTokenization) { TEST(EncoderTest, HandlesEdgeCases) { const char pieces[] = "hell\0hello\0o\0there\0"; - const int offsets[] = {0, 5, 11, 13}; + const uint32 offsets[] = {0, 5, 11, 13}; float scores[] = {-0.5, -1.0, -10.0, -1.0}; std::unique_ptr<SentencePieceMatcher> matcher(new SortedStringsTable( /*num_pieces=*/4, offsets, StringPiece(pieces, 18))); @@ -85,7 +86,7 @@ TEST(EncoderTest, HandlesEdgeCases) { TEST(EncoderTest, HandlesOutOfDictionary) { const char pieces[] = "hell\0hello\0o\0there\0"; - const int offsets[] = {0, 5, 11, 13}; + const uint32 offsets[] = {0, 5, 11, 13}; float scores[] = {-0.5, -1.0, -10.0, -1.0}; std::unique_ptr<SentencePieceMatcher> matcher(new 
SortedStringsTable( /*num_pieces=*/4, offsets, StringPiece(pieces, 18))); diff --git a/utils/sentencepiece/sorted_strings_table.cc b/utils/sentencepiece/sorted_strings_table.cc index 96637d8..8e7e9ba 100644 --- a/utils/sentencepiece/sorted_strings_table.cc +++ b/utils/sentencepiece/sorted_strings_table.cc @@ -41,15 +41,19 @@ void SortedStringsTable::GatherPrefixMatches( // `lower_bound` to find the start of the range of matching pieces. // `upper_bound` to find the non-inclusive end of the range. left = (std::lower_bound( - offsets_ + left, offsets_ + right, input[match_length], - [this, match_length](int piece_offset, int c) -> bool { - return pieces_[piece_offset + match_length] < c; + offsets_ + left, offsets_ + right, + static_cast<unsigned char>(input[match_length]), + [this, match_length](uint32 piece_offset, uint32 c) -> bool { + return static_cast<unsigned char>( + pieces_[piece_offset + match_length]) < c; }) - offsets_); right = (std::upper_bound( - offsets_ + left, offsets_ + right, input[match_length], - [this, match_length](int c, int piece_offset) -> bool { - return c < pieces_[piece_offset + match_length]; + offsets_ + left, offsets_ + right, + static_cast<unsigned char>(input[match_length]), + [this, match_length](uint32 c, uint32 piece_offset) -> bool { + return c < static_cast<unsigned char>( + pieces_[piece_offset + match_length]); }) - offsets_); span_size = right - left; diff --git a/utils/sentencepiece/sorted_strings_table.h b/utils/sentencepiece/sorted_strings_table.h index 61a2239..69f638a 100644 --- a/utils/sentencepiece/sorted_strings_table.h +++ b/utils/sentencepiece/sorted_strings_table.h @@ -20,6 +20,7 @@ #include <functional> #include <vector> +#include "utils/base/integral_types.h" #include "utils/sentencepiece/matcher.h" #include "utils/strings/stringpiece.h" @@ -36,7 +37,7 @@ namespace libtextclassifier3 { // switching to a linear sweep for prefix match testing. 
class SortedStringsTable : public SentencePieceMatcher { public: - SortedStringsTable(const int num_pieces, const int* offsets, + SortedStringsTable(const int num_pieces, const uint32* offsets, StringPiece pieces, const int use_linear_scan_threshold = 10) : num_pieces_(num_pieces), @@ -56,7 +57,7 @@ class SortedStringsTable : public SentencePieceMatcher { StringPiece input, const std::function<void(TrieMatch)>& update_fn) const; const int num_pieces_; - const int* offsets_; + const uint32* offsets_; const StringPiece pieces_; const int use_linear_scan_threshold_; }; diff --git a/utils/sentencepiece/sorted_strings_table_test.cc b/utils/sentencepiece/sorted_strings_table_test.cc index 10824d9..4dff29d 100644 --- a/utils/sentencepiece/sorted_strings_table_test.cc +++ b/utils/sentencepiece/sorted_strings_table_test.cc @@ -19,6 +19,7 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "utils/base/integral_types.h" #include "utils/sentencepiece/sorted_strings_table.h" namespace libtextclassifier3 { @@ -26,7 +27,7 @@ namespace { TEST(SortedStringsTest, Lookup) { const char pieces[] = "hell\0hello\0o\0there\0"; - const int offsets[] = {0, 5, 11, 13}; + const uint32 offsets[] = {0, 5, 11, 13}; SortedStringsTable table(/*num_pieces=*/4, offsets, StringPiece(pieces, 18), /*use_linear_scan_threshold=*/1); diff --git a/utils/tflite/text_encoder_config.fbs b/utils/tflite/text_encoder_config.fbs index 8ae8fc5..4ffade4 100644 --- a/utils/tflite/text_encoder_config.fbs +++ b/utils/tflite/text_encoder_config.fbs @@ -60,6 +60,6 @@ table TextEncoderConfig { // Serialized sentence pieces. pieces:string; - pieces_offsets:[int32]; + pieces_offsets:[uint32]; matcher_type: SentencePieceMatcherType = MAPPED_TRIE; } |