diff options
Diffstat (limited to 'actions')
-rw-r--r-- | actions/actions-suggestions.cc | 4 | ||||
-rw-r--r-- | actions/actions-suggestions.h | 2 | ||||
-rw-r--r-- | actions/actions-suggestions_test.cc | 18 | ||||
-rw-r--r-- | actions/actions_jni.cc | 3 | ||||
-rwxr-xr-x | actions/actions_model.fbs | 4 | ||||
-rw-r--r-- | actions/zlib-utils.cc | 18 |
6 files changed, 22 insertions, 27 deletions
diff --git a/actions/actions-suggestions.cc b/actions/actions-suggestions.cc index d8be1b0..1d4a70f 100644 --- a/actions/actions-suggestions.cc +++ b/actions/actions-suggestions.cc @@ -1035,7 +1035,7 @@ std::vector<int> ActionsSuggestions::DeduplicateAnnotations( bool ActionsSuggestions::FillAnnotationFromMatchGroup( const UniLib::RegexMatcher* matcher, - const RulesModel_::Rule_::RuleActionSpec_::CapturingGroup* group, + const RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroup* group, const int message_index, ActionSuggestionAnnotation* annotation) const { if (group->annotation_name() != nullptr || group->annotation_type() != nullptr) { @@ -1100,7 +1100,7 @@ bool ActionsSuggestions::SuggestActionsFromRules( // Add entity data from rule capturing groups. if (rule_action->capturing_group() != nullptr) { - for (const RulesModel_::Rule_::RuleActionSpec_::CapturingGroup* + for (const RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroup* group : *rule_action->capturing_group()) { if (group->entity_field() != nullptr) { TC3_CHECK(entity_data != nullptr); diff --git a/actions/actions-suggestions.h b/actions/actions-suggestions.h index 3c2ff09..61e052b 100644 --- a/actions/actions-suggestions.h +++ b/actions/actions-suggestions.h @@ -207,7 +207,7 @@ class ActionsSuggestions { bool FillAnnotationFromMatchGroup( const UniLib::RegexMatcher* matcher, - const RulesModel_::Rule_::RuleActionSpec_::CapturingGroup* group, + const RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroup* group, const int message_index, ActionSuggestionAnnotation* annotation) const; const ActionsModel* model_; diff --git a/actions/actions-suggestions_test.cc b/actions/actions-suggestions_test.cc index b510284..b0aeec6 100644 --- a/actions/actions-suggestions_test.cc +++ b/actions/actions-suggestions_test.cc @@ -554,6 +554,8 @@ TEST_F(ActionsSuggestionsTest, ReadFile(GetModelPath() + kModelFileName); std::unique_ptr<ActionsModelT> actions_model = UnPackActionsModel(actions_model_string.c_str()); + actions_model->low_confidence_rules.reset(); + // Add custom triggering rule. actions_model->rules.reset(new RulesModelT()); actions_model->rules->rule.emplace_back(new RulesModel_::RuleT); @@ -731,16 +733,16 @@ TEST_F(ActionsSuggestionsTest, CreateActionsFromRules) { // Set capturing groups for entity data. rule->actions.back()->capturing_group.emplace_back( - new RulesModel_::Rule_::RuleActionSpec_::CapturingGroupT); - RulesModel_::Rule_::RuleActionSpec_::CapturingGroupT* greeting_group = + new RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT); + RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT* greeting_group = rule->actions.back()->capturing_group.back().get(); greeting_group->group_id = 0; greeting_group->entity_field.reset(new FlatbufferFieldPathT); greeting_group->entity_field->field.emplace_back(new FlatbufferFieldT); greeting_group->entity_field->field.back()->field_name = "greeting"; rule->actions.back()->capturing_group.emplace_back( - new RulesModel_::Rule_::RuleActionSpec_::CapturingGroupT); - RulesModel_::Rule_::RuleActionSpec_::CapturingGroupT* location_group = + new RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT); + RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT* location_group = rule->actions.back()->capturing_group.back().get(); location_group->group_id = 1; location_group->entity_field.reset(new FlatbufferFieldPathT); @@ -802,8 +804,8 @@ TEST_F(ActionsSuggestionsTest, CreatesTextRepliesFromRules) { // Set capturing groups for entity data. rule->actions.back()->capturing_group.emplace_back( - new RulesModel_::Rule_::RuleActionSpec_::CapturingGroupT); - RulesModel_::Rule_::RuleActionSpec_::CapturingGroupT* code_group = + new RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT); + RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT* code_group = rule->actions.back()->capturing_group.back().get(); code_group->group_id = 1; code_group->text_reply.reset(new ActionSuggestionSpecT); @@ -915,8 +917,8 @@ TEST_F(ActionsSuggestionsTest, DeduplicateConflictingActions) { action->priority_score = 2.0f; action->type = "test_code"; rule->actions.back()->capturing_group.emplace_back( - new RulesModel_::Rule_::RuleActionSpec_::CapturingGroupT); - RulesModel_::Rule_::RuleActionSpec_::CapturingGroupT* code_group = + new RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT); + RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT* code_group = rule->actions.back()->capturing_group.back().get(); code_group->group_id = 1; code_group->annotation_name = "code"; diff --git a/actions/actions_jni.cc b/actions/actions_jni.cc index 1ee595f..20891fa 100644 --- a/actions/actions_jni.cc +++ b/actions/actions_jni.cc @@ -376,8 +376,7 @@ TC3_JNI_METHOD(jobjectArray, TC3_ACTIONS_CLASS_NAME, nativeSuggestActions) annotator ? annotator->entity_data_schema() : nullptr; return ActionSuggestionsToJObjectArray( env, context, app_context, anntotations_entity_data_schema, - response.actions, conversation, - /*device_locales=*/nullptr, generate_intents); + response.actions, conversation, device_locales, generate_intents); } TC3_JNI_METHOD(void, TC3_ACTIONS_CLASS_NAME, nativeCloseActionsModel) diff --git a/actions/actions_model.fbs b/actions/actions_model.fbs index 885a717..dc4520a 100755 --- a/actions/actions_model.fbs +++ b/actions/actions_model.fbs @@ -332,7 +332,7 @@ table RankingOptions { // Entity data to set from capturing groups. namespace libtextclassifier3.RulesModel_.Rule_.RuleActionSpec_; -table CapturingGroup { +table RuleCapturingGroup { // The id of group. group_id:int; @@ -357,7 +357,7 @@ table RuleActionSpec { // The action. action:ActionSuggestionSpec; - capturing_group:[RuleActionSpec_.CapturingGroup]; + capturing_group:[RuleActionSpec_.RuleCapturingGroup]; } // List of regular expression matchers. diff --git a/actions/zlib-utils.cc b/actions/zlib-utils.cc index 317623b..b1d997d 100644 --- a/actions/zlib-utils.cc +++ b/actions/zlib-utils.cc @@ -42,12 +42,9 @@ bool CompressActionsModel(ActionsModelT* model) { } } - if (model->preconditions != nullptr && - model->preconditions->low_confidence_rules != nullptr) { - for (int i = 0; i < model->preconditions->low_confidence_rules->rule.size(); - i++) { - RulesModel_::RuleT* rule = - model->preconditions->low_confidence_rules->rule[i].get(); + if (model->low_confidence_rules != nullptr) { + for (int i = 0; i < model->low_confidence_rules->rule.size(); i++) { + RulesModel_::RuleT* rule = model->low_confidence_rules->rule[i].get(); if (!rule->pattern.empty()) { rule->compressed_pattern.reset(new CompressedBufferT); zlib_compressor->Compress(rule->pattern, @@ -113,12 +110,9 @@ bool DecompressActionsModel(ActionsModelT* model) { } // Decompress low confidence rules. - if (model->preconditions != nullptr && - model->preconditions->low_confidence_rules != nullptr) { - for (int i = 0; i < model->preconditions->low_confidence_rules->rule.size(); - i++) { - RulesModel_::RuleT* rule = - model->preconditions->low_confidence_rules->rule[i].get(); + if (model->low_confidence_rules != nullptr) { + for (int i = 0; i < model->low_confidence_rules->rule.size(); i++) { + RulesModel_::RuleT* rule = model->low_confidence_rules->rule[i].get(); if (!zlib_decompressor->MaybeDecompress(rule->compressed_pattern.get(), &rule->pattern)) { TC3_LOG(ERROR) << "Cannot decompress pattern: " << i; |