path: root/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
Diffstat (limited to 'native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp')
-rw-r--r--  native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp  30
1 file changed, 19 insertions(+), 11 deletions(-)
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
index 31b1ea696..6db7ea444 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
@@ -21,6 +21,7 @@
#include "suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.h"
#include "suggest/policyimpl/dictionary/utils/probability_utils.h"
+#include "utils/ngram_utils.h"
namespace latinime {
@@ -89,16 +90,17 @@ const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArr
}
contextCount = prevWordProbabilityEntry.getHistoricalInfo()->getCount();
}
+ const NgramType ngramType = NgramUtils::getNgramTypeFromWordCount(i + 1);
const float rawProbability =
DynamicLanguageModelProbabilityUtils::computeRawProbabilityFromCounts(
- historicalInfo->getCount(), contextCount, i + 1);
+ historicalInfo->getCount(), contextCount, ngramType);
const int encodedRawProbability =
ProbabilityUtils::encodeRawProbability(rawProbability);
const int decayedProbability =
DynamicLanguageModelProbabilityUtils::getDecayedProbability(
encodedRawProbability, *historicalInfo);
probability = DynamicLanguageModelProbabilityUtils::backoff(
- decayedProbability, i + 1 /* n */);
+ decayedProbability, ngramType);
} else {
probability = probabilityEntry.getProbability();
}
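
The hunk above replaces the raw word count (i + 1) with an NgramType obtained from the newly included utils/ngram_utils.h helper, so the raw-probability computation and the backoff are keyed by n-gram level rather than by a bare integer. For readers without that header in front of them, a minimal sketch of the mapping the helper is assumed to perform (enum values and bounds here are illustrative, not the actual header contents):

    // Illustrative sketch only: map "total words in the n-gram" to an n-gram level.
    // The real definitions live in utils/ngram_utils.h and may differ.
    enum class NgramType : int {
        Unigram = 0,
        Bigram = 1,
        Trigram = 2,
        Quadgram = 3,
        NotANgramType = -1,
    };

    static NgramType getNgramTypeFromWordCount(const int wordCount) {
        // wordCount counts the target word plus its context words, i.e. (i + 1) above.
        switch (wordCount) {
            case 1: return NgramType::Unigram;
            case 2: return NgramType::Bigram;
            case 3: return NgramType::Trigram;
            case 4: return NgramType::Quadgram;
            default: return NgramType::NotANgramType;
        }
    }
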
@@ -198,18 +200,19 @@ bool LanguageModelDictContent::truncateEntries(const EntryCounts &currentEntryCo
MutableEntryCounters *const outEntryCounters) {
for (int prevWordCount = 0; prevWordCount <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++prevWordCount) {
const int totalWordCount = prevWordCount + 1;
- if (currentEntryCounts.getNgramCount(totalWordCount)
- <= maxEntryCounts.getNgramCount(totalWordCount)) {
- outEntryCounters->setNgramCount(totalWordCount,
- currentEntryCounts.getNgramCount(totalWordCount));
+ const NgramType ngramType = NgramUtils::getNgramTypeFromWordCount(totalWordCount);
+ if (currentEntryCounts.getNgramCount(ngramType)
+ <= maxEntryCounts.getNgramCount(ngramType)) {
+ outEntryCounters->setNgramCount(ngramType,
+ currentEntryCounts.getNgramCount(ngramType));
continue;
}
int entryCount = 0;
if (!turncateEntriesInSpecifiedLevel(headerPolicy,
- maxEntryCounts.getNgramCount(totalWordCount), prevWordCount, &entryCount)) {
+ maxEntryCounts.getNgramCount(ngramType), prevWordCount, &entryCount)) {
return false;
}
- outEntryCounters->setNgramCount(totalWordCount, entryCount);
+ outEntryCounters->setNgramCount(ngramType, entryCount);
}
return true;
}
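
With this change, truncateEntries compares and stores per-level counts through NgramType lookups instead of indexing by totalWordCount. A hedged sketch of the per-level decision, with EntryCounts reduced to an illustrative stand-in (the real class lives elsewhere in the v4 content code and may store its counters differently):

    // Illustrative stand-in for EntryCounts: one counter per n-gram level,
    // indexed by the NgramType value sketched earlier.
    struct IllustrativeEntryCounts {
        int mCounts[4] = {0, 0, 0, 0};  // unigram, bigram, trigram, quadgram

        int getNgramCount(const NgramType ngramType) const {
            return mCounts[static_cast<int>(ngramType)];
        }
    };

    // Mirrors the early-continue branch above: a level needs truncation only
    // when its current count exceeds the configured maximum for that level.
    static bool needsTruncation(const IllustrativeEntryCounts &currentEntryCounts,
            const IllustrativeEntryCounts &maxEntryCounts, const NgramType ngramType) {
        return currentEntryCounts.getNgramCount(ngramType)
                > maxEntryCounts.getNgramCount(ngramType);
    }
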
@@ -246,7 +249,10 @@ bool LanguageModelDictContent::updateAllEntriesOnInputWord(const WordIdArrayView
mGlobalCounters.updateMaxValueOfCounters(
updatedNgramProbabilityEntry.getHistoricalInfo()->getCount());
if (!originalNgramProbabilityEntry.isValid()) {
- entryCountersToUpdate->incrementNgramCount(i + 2);
+ // (i + 2) words are used in total because the prevWords consists of (i + 1) words when
+ // looking at its i-th element.
+ entryCountersToUpdate->incrementNgramCount(
+ NgramUtils::getNgramTypeFromWordCount(i + 2));
}
}
return true;
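
The (i + 2) in the new comment can be read off directly: the i-th element of prevWords implies (i + 1) context words, and the current input word adds one more. A small hypothetical helper spelling out that arithmetic (not part of the patch):

    // Hypothetical helper mirroring the counter update above.
    static NgramType ngramTypeForNewEntry(const int prevWordIndex /* i */) {
        // prevWordIndex == 0 -> one context word + current word  -> 2 words (bigram)
        // prevWordIndex == 1 -> two context words + current word -> 3 words (trigram)
        return getNgramTypeFromWordCount(prevWordIndex + 2);
    }
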
@@ -369,7 +375,8 @@ bool LanguageModelDictContent::updateAllProbabilityEntriesForGCInner(const int b
}
}
}
- outEntryCounters->incrementNgramCount(prevWordCount + 1);
+ outEntryCounters->incrementNgramCount(
+ NgramUtils::getNgramTypeFromWordCount(prevWordCount + 1));
if (!entry.hasNextLevelMap()) {
continue;
}
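
During GC, each surviving entry at depth prevWordCount is re-counted at the (prevWordCount + 1)-word level. A hedged sketch of the counter side, again with the real MutableEntryCounters replaced by an illustrative stand-in:

    // Illustrative stand-in for MutableEntryCounters.
    struct IllustrativeMutableEntryCounters {
        int mCounters[4] = {0, 0, 0, 0};

        void incrementNgramCount(const NgramType ngramType) {
            // One surviving entry of this n-gram level was kept during GC.
            ++mCounters[static_cast<int>(ngramType)];
        }
    };

    // Usage corresponding to the hunk above, for entries whose context has
    // prevWordCount words:
    //   counters.incrementNgramCount(getNgramTypeFromWordCount(prevWordCount + 1));
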
@@ -402,7 +409,8 @@ bool LanguageModelDictContent::turncateEntriesInSpecifiedLevel(
for (int i = 0; i < entryCountToRemove; ++i) {
const EntryInfoToTurncate &entryInfo = entryInfoVector[i];
if (!removeNgramProbabilityEntry(
- WordIdArrayView(entryInfo.mPrevWordIds, entryInfo.mPrevWordCount), entryInfo.mKey)) {
+ WordIdArrayView(entryInfo.mPrevWordIds, entryInfo.mPrevWordCount),
+ entryInfo.mKey)) {
return false;
}
}