diff options
Diffstat (limited to 'native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp')
-rw-r--r-- | native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp | 30 |
1 files changed, 19 insertions, 11 deletions
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp index 31b1ea696..6db7ea444 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp @@ -21,6 +21,7 @@ #include "suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.h" #include "suggest/policyimpl/dictionary/utils/probability_utils.h" +#include "utils/ngram_utils.h" namespace latinime { @@ -89,16 +90,17 @@ const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArr } contextCount = prevWordProbabilityEntry.getHistoricalInfo()->getCount(); } + const NgramType ngramType = NgramUtils::getNgramTypeFromWordCount(i + 1); const float rawProbability = DynamicLanguageModelProbabilityUtils::computeRawProbabilityFromCounts( - historicalInfo->getCount(), contextCount, i + 1); + historicalInfo->getCount(), contextCount, ngramType); const int encodedRawProbability = ProbabilityUtils::encodeRawProbability(rawProbability); const int decayedProbability = DynamicLanguageModelProbabilityUtils::getDecayedProbability( encodedRawProbability, *historicalInfo); probability = DynamicLanguageModelProbabilityUtils::backoff( - decayedProbability, i + 1 /* n */); + decayedProbability, ngramType); } else { probability = probabilityEntry.getProbability(); } @@ -198,18 +200,19 @@ bool LanguageModelDictContent::truncateEntries(const EntryCounts ¤tEntryCo MutableEntryCounters *const outEntryCounters) { for (int prevWordCount = 0; prevWordCount <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++prevWordCount) { const int totalWordCount = prevWordCount + 1; - if (currentEntryCounts.getNgramCount(totalWordCount) - <= maxEntryCounts.getNgramCount(totalWordCount)) { - outEntryCounters->setNgramCount(totalWordCount, - currentEntryCounts.getNgramCount(totalWordCount)); + const NgramType ngramType = NgramUtils::getNgramTypeFromWordCount(totalWordCount); + if (currentEntryCounts.getNgramCount(ngramType) + <= maxEntryCounts.getNgramCount(ngramType)) { + outEntryCounters->setNgramCount(ngramType, + currentEntryCounts.getNgramCount(ngramType)); continue; } int entryCount = 0; if (!turncateEntriesInSpecifiedLevel(headerPolicy, - maxEntryCounts.getNgramCount(totalWordCount), prevWordCount, &entryCount)) { + maxEntryCounts.getNgramCount(ngramType), prevWordCount, &entryCount)) { return false; } - outEntryCounters->setNgramCount(totalWordCount, entryCount); + outEntryCounters->setNgramCount(ngramType, entryCount); } return true; } @@ -246,7 +249,10 @@ bool LanguageModelDictContent::updateAllEntriesOnInputWord(const WordIdArrayView mGlobalCounters.updateMaxValueOfCounters( updatedNgramProbabilityEntry.getHistoricalInfo()->getCount()); if (!originalNgramProbabilityEntry.isValid()) { - entryCountersToUpdate->incrementNgramCount(i + 2); + // (i + 2) words are used in total because the prevWords consists of (i + 1) words when + // looking at its i-th element. + entryCountersToUpdate->incrementNgramCount( + NgramUtils::getNgramTypeFromWordCount(i + 2)); } } return true; @@ -369,7 +375,8 @@ bool LanguageModelDictContent::updateAllProbabilityEntriesForGCInner(const int b } } } - outEntryCounters->incrementNgramCount(prevWordCount + 1); + outEntryCounters->incrementNgramCount( + NgramUtils::getNgramTypeFromWordCount(prevWordCount + 1)); if (!entry.hasNextLevelMap()) { continue; } @@ -402,7 +409,8 @@ bool LanguageModelDictContent::turncateEntriesInSpecifiedLevel( for (int i = 0; i < entryCountToRemove; ++i) { const EntryInfoToTurncate &entryInfo = entryInfoVector[i]; if (!removeNgramProbabilityEntry( - WordIdArrayView(entryInfo.mPrevWordIds, entryInfo.mPrevWordCount), entryInfo.mKey)) { + WordIdArrayView(entryInfo.mPrevWordIds, entryInfo.mPrevWordCount), + entryInfo.mKey)) { return false; } } |