From 78212a6d3de2c1fdaa394c58a16cbdee3ad5d046 Mon Sep 17 00:00:00 2001 From: Keisuke Kuroyanagi Date: Thu, 20 Nov 2014 15:27:30 +0900 Subject: Use enum to specify ngram type. Change-Id: Ie28768ceadcd7a2d940c57eb30be7d4c364e509f --- .../policyimpl/dictionary/header/header_policy.cpp | 58 ++++++++++----- .../policyimpl/dictionary/header/header_policy.h | 81 +++++---------------- .../backward/v402/ver4_patricia_trie_policy.cpp | 17 +++-- .../backward/v402/ver4_patricia_trie_policy.h | 3 +- .../v402/ver4_patricia_trie_writing_helper.cpp | 14 ++-- .../dynamic_language_model_probability_utils.cpp | 12 ++-- .../dynamic_language_model_probability_utils.h | 53 +++----------- .../v4/content/language_model_dict_content.cpp | 30 +++++--- .../structure/v4/ver4_patricia_trie_policy.cpp | 22 +++--- .../structure/v4/ver4_patricia_trie_policy.h | 3 +- .../v4/ver4_patricia_trie_writing_helper.cpp | 9 +-- .../policyimpl/dictionary/utils/entry_counters.h | 84 ++++++---------------- .../dictionary/utils/forgetting_curve_utils.cpp | 21 ++---- native/jni/src/utils/ngram_utils.h | 62 ++++++++++++++++ 14 files changed, 218 insertions(+), 251 deletions(-) create mode 100644 native/jni/src/utils/ngram_utils.h diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp index 300e96c4e..a2a0f11b4 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp @@ -18,6 +18,8 @@ #include +#include "utils/ngram_utils.h" + namespace latinime { // Note that these are corresponding definitions in Java side in DictionaryHeader. @@ -28,9 +30,11 @@ const char *const HeaderPolicy::REQUIRES_GERMAN_UMLAUT_PROCESSING_KEY = const char *const HeaderPolicy::IS_DECAYING_DICT_KEY = "USES_FORGETTING_CURVE"; const char *const HeaderPolicy::DATE_KEY = "date"; const char *const HeaderPolicy::LAST_DECAYED_TIME_KEY = "LAST_DECAYED_TIME"; -const char *const HeaderPolicy::UNIGRAM_COUNT_KEY = "UNIGRAM_COUNT"; -const char *const HeaderPolicy::BIGRAM_COUNT_KEY = "BIGRAM_COUNT"; -const char *const HeaderPolicy::TRIGRAM_COUNT_KEY = "TRIGRAM_COUNT"; +const char *const HeaderPolicy::NGRAM_COUNT_KEYS[] = + {"UNIGRAM_COUNT", "BIGRAM_COUNT", "TRIGRAM_COUNT"}; +const char *const HeaderPolicy::MAX_NGRAM_COUNT_KEYS[] = + {"MAX_UNIGRAM_ENTRY_COUNT", "MAX_BIGRAM_ENTRY_COUNT", "MAX_TRIGRAM_ENTRY_COUNT"}; +const int HeaderPolicy::DEFAULT_MAX_NGRAM_COUNTS[] = {10000, 30000, 30000}; const char *const HeaderPolicy::EXTENDED_REGION_SIZE_KEY = "EXTENDED_REGION_SIZE"; // Historical info is information that is needed to support decaying such as timestamp, level and // count. @@ -39,18 +43,10 @@ const char *const HeaderPolicy::LOCALE_KEY = "locale"; // match Java declaration const char *const HeaderPolicy::FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID_KEY = "FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID"; -const char *const HeaderPolicy::MAX_UNIGRAM_COUNT_KEY = "MAX_UNIGRAM_ENTRY_COUNT"; -const char *const HeaderPolicy::MAX_BIGRAM_COUNT_KEY = "MAX_BIGRAM_ENTRY_COUNT"; -const char *const HeaderPolicy::MAX_TRIGRAM_COUNT_KEY = "MAX_TRIGRAM_ENTRY_COUNT"; - const int HeaderPolicy::DEFAULT_MULTIPLE_WORDS_DEMOTION_RATE = 100; const float HeaderPolicy::MULTIPLE_WORD_COST_MULTIPLIER_SCALE = 100.0f; const int HeaderPolicy::DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID = 3; -const int HeaderPolicy::DEFAULT_MAX_UNIGRAM_COUNT = 10000; -const int HeaderPolicy::DEFAULT_MAX_BIGRAM_COUNT = 30000; -const int HeaderPolicy::DEFAULT_MAX_TRIGRAM_COUNT = 30000; - // Used for logging. Question mark is used to indicate that the key is not found. void HeaderPolicy::readHeaderValueOrQuestionMark(const char *const key, int *outValue, int outValueSize) const { @@ -126,15 +122,22 @@ bool HeaderPolicy::fillInAndWriteHeaderToBuffer(const bool updatesLastDecayedTim return true; } +namespace { + +int getIndexFromNgramType(const NgramType ngramType) { + return static_cast(ngramType); +} + +} // namespace + void HeaderPolicy::fillInHeader(const bool updatesLastDecayedTime, const EntryCounts &entryCounts, const int extendedRegionSize, DictionaryHeaderStructurePolicy::AttributeMap *outAttributeMap) const { - HeaderReadWriteUtils::setIntAttribute(outAttributeMap, UNIGRAM_COUNT_KEY, - entryCounts.getUnigramCount()); - HeaderReadWriteUtils::setIntAttribute(outAttributeMap, BIGRAM_COUNT_KEY, - entryCounts.getBigramCount()); - HeaderReadWriteUtils::setIntAttribute(outAttributeMap, TRIGRAM_COUNT_KEY, - entryCounts.getTrigramCount()); + for (const auto ngramType : AllNgramTypes::ASCENDING) { + HeaderReadWriteUtils::setIntAttribute(outAttributeMap, + NGRAM_COUNT_KEYS[getIndexFromNgramType(ngramType)], + entryCounts.getNgramCount(ngramType)); + } HeaderReadWriteUtils::setIntAttribute(outAttributeMap, EXTENDED_REGION_SIZE_KEY, extendedRegionSize); // Set the current time as the generation time. @@ -155,4 +158,25 @@ void HeaderPolicy::fillInHeader(const bool updatesLastDecayedTime, return attributeMap; } +/* static */ const EntryCounts HeaderPolicy::readNgramCounts() const { + MutableEntryCounters entryCounters; + for (const auto ngramType : AllNgramTypes::ASCENDING) { + const int entryCount = HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap, + NGRAM_COUNT_KEYS[getIndexFromNgramType(ngramType)], 0 /* defaultValue */); + entryCounters.setNgramCount(ngramType, entryCount); + } + return entryCounters.getEntryCounts(); +} + +/* static */ const EntryCounts HeaderPolicy::readMaxNgramCounts() const { + MutableEntryCounters entryCounters; + for (const auto ngramType : AllNgramTypes::ASCENDING) { + const int index = getIndexFromNgramType(ngramType); + const int maxEntryCount = HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap, + MAX_NGRAM_COUNT_KEYS[index], DEFAULT_MAX_NGRAM_COUNTS[index]); + entryCounters.setNgramCount(ngramType, maxEntryCount); + } + return entryCounters.getEntryCounts(); +} + } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h index 7a5acd7d5..f76931baa 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h @@ -46,12 +46,7 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { DATE_KEY, TimeKeeper::peekCurrentTime() /* defaultValue */)), mLastDecayedTime(HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap, LAST_DECAYED_TIME_KEY, TimeKeeper::peekCurrentTime() /* defaultValue */)), - mUnigramCount(HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap, - UNIGRAM_COUNT_KEY, 0 /* defaultValue */)), - mBigramCount(HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap, - BIGRAM_COUNT_KEY, 0 /* defaultValue */)), - mTrigramCount(HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap, - TRIGRAM_COUNT_KEY, 0 /* defaultValue */)), + mNgramCounts(readNgramCounts()), mMaxNgramCounts(readMaxNgramCounts()), mExtendedRegionSize(HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap, EXTENDED_REGION_SIZE_KEY, 0 /* defaultValue */)), mHasHistoricalInfoOfWords(HeaderReadWriteUtils::readBoolAttributeValue( @@ -59,12 +54,6 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { mForgettingCurveProbabilityValuesTableId(HeaderReadWriteUtils::readIntAttributeValue( &mAttributeMap, FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID_KEY, DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID)), - mMaxUnigramCount(HeaderReadWriteUtils::readIntAttributeValue( - &mAttributeMap, MAX_UNIGRAM_COUNT_KEY, DEFAULT_MAX_UNIGRAM_COUNT)), - mMaxBigramCount(HeaderReadWriteUtils::readIntAttributeValue( - &mAttributeMap, MAX_BIGRAM_COUNT_KEY, DEFAULT_MAX_BIGRAM_COUNT)), - mMaxTrigramCount(HeaderReadWriteUtils::readIntAttributeValue( - &mAttributeMap, MAX_TRIGRAM_COUNT_KEY, DEFAULT_MAX_TRIGRAM_COUNT)), mCodePointTable(HeaderReadWriteUtils::readCodePointTable(&mAttributeMap)) {} // Constructs header information using an attribute map. @@ -82,18 +71,13 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { DATE_KEY, TimeKeeper::peekCurrentTime() /* defaultValue */)), mLastDecayedTime(HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap, DATE_KEY, TimeKeeper::peekCurrentTime() /* defaultValue */)), - mUnigramCount(0), mBigramCount(0), mTrigramCount(0), mExtendedRegionSize(0), + mNgramCounts(readNgramCounts()), mMaxNgramCounts(readMaxNgramCounts()), + mExtendedRegionSize(0), mHasHistoricalInfoOfWords(HeaderReadWriteUtils::readBoolAttributeValue( &mAttributeMap, HAS_HISTORICAL_INFO_KEY, false /* defaultValue */)), mForgettingCurveProbabilityValuesTableId(HeaderReadWriteUtils::readIntAttributeValue( &mAttributeMap, FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID_KEY, DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID)), - mMaxUnigramCount(HeaderReadWriteUtils::readIntAttributeValue( - &mAttributeMap, MAX_UNIGRAM_COUNT_KEY, DEFAULT_MAX_UNIGRAM_COUNT)), - mMaxBigramCount(HeaderReadWriteUtils::readIntAttributeValue( - &mAttributeMap, MAX_BIGRAM_COUNT_KEY, DEFAULT_MAX_BIGRAM_COUNT)), - mMaxTrigramCount(HeaderReadWriteUtils::readIntAttributeValue( - &mAttributeMap, MAX_TRIGRAM_COUNT_KEY, DEFAULT_MAX_TRIGRAM_COUNT)), mCodePointTable(HeaderReadWriteUtils::readCodePointTable(&mAttributeMap)) {} // Copy header information @@ -105,15 +89,12 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { mRequiresGermanUmlautProcessing(headerPolicy->mRequiresGermanUmlautProcessing), mIsDecayingDict(headerPolicy->mIsDecayingDict), mDate(headerPolicy->mDate), mLastDecayedTime(headerPolicy->mLastDecayedTime), - mUnigramCount(headerPolicy->mUnigramCount), mBigramCount(headerPolicy->mBigramCount), - mTrigramCount(headerPolicy->mTrigramCount), + mNgramCounts(headerPolicy->mNgramCounts), + mMaxNgramCounts(headerPolicy->mMaxNgramCounts), mExtendedRegionSize(headerPolicy->mExtendedRegionSize), mHasHistoricalInfoOfWords(headerPolicy->mHasHistoricalInfoOfWords), mForgettingCurveProbabilityValuesTableId( headerPolicy->mForgettingCurveProbabilityValuesTableId), - mMaxUnigramCount(headerPolicy->mMaxUnigramCount), - mMaxBigramCount(headerPolicy->mMaxBigramCount), - mMaxTrigramCount(headerPolicy->mMaxTrigramCount), mCodePointTable(headerPolicy->mCodePointTable) {} // Temporary dummy header. @@ -121,10 +102,9 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { : mDictFormatVersion(FormatUtils::UNKNOWN_VERSION), mDictionaryFlags(0), mSize(0), mAttributeMap(), mLocale(CharUtils::EMPTY_STRING), mMultiWordCostMultiplier(0.0f), mRequiresGermanUmlautProcessing(false), mIsDecayingDict(false), - mDate(0), mLastDecayedTime(0), mUnigramCount(0), mBigramCount(0), mTrigramCount(0), + mDate(0), mLastDecayedTime(0), mNgramCounts(), mMaxNgramCounts(), mExtendedRegionSize(0), mHasHistoricalInfoOfWords(false), - mForgettingCurveProbabilityValuesTableId(0), mMaxUnigramCount(0), mMaxBigramCount(0), - mMaxTrigramCount(0), mCodePointTable(nullptr) {} + mForgettingCurveProbabilityValuesTableId(0), mCodePointTable(nullptr) {} ~HeaderPolicy() {} @@ -186,16 +166,12 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { return mLastDecayedTime; } - AK_FORCE_INLINE int getUnigramCount() const { - return mUnigramCount; + AK_FORCE_INLINE const EntryCounts &getNgramCounts() const { + return mNgramCounts; } - AK_FORCE_INLINE int getBigramCount() const { - return mBigramCount; - } - - AK_FORCE_INLINE int getTrigramCount() const { - return mTrigramCount; + AK_FORCE_INLINE const EntryCounts getMaxNgramCounts() const { + return mMaxNgramCounts; } AK_FORCE_INLINE int getExtendedRegionSize() const { @@ -219,18 +195,6 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { return mForgettingCurveProbabilityValuesTableId; } - AK_FORCE_INLINE int getMaxUnigramCount() const { - return mMaxUnigramCount; - } - - AK_FORCE_INLINE int getMaxBigramCount() const { - return mMaxBigramCount; - } - - AK_FORCE_INLINE int getMaxTrigramCount() const { - return mMaxTrigramCount; - } - void readHeaderValueOrQuestionMark(const char *const key, int *outValue, int outValueSize) const; @@ -262,24 +226,18 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { static const char *const IS_DECAYING_DICT_KEY; static const char *const DATE_KEY; static const char *const LAST_DECAYED_TIME_KEY; - static const char *const UNIGRAM_COUNT_KEY; - static const char *const BIGRAM_COUNT_KEY; - static const char *const TRIGRAM_COUNT_KEY; + static const char *const NGRAM_COUNT_KEYS[]; + static const char *const MAX_NGRAM_COUNT_KEYS[]; + static const int DEFAULT_MAX_NGRAM_COUNTS[]; static const char *const EXTENDED_REGION_SIZE_KEY; static const char *const HAS_HISTORICAL_INFO_KEY; static const char *const LOCALE_KEY; static const char *const FORGETTING_CURVE_OCCURRENCES_TO_LEVEL_UP_KEY; static const char *const FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID_KEY; static const char *const FORGETTING_CURVE_DURATION_TO_LEVEL_DOWN_IN_SECONDS_KEY; - static const char *const MAX_UNIGRAM_COUNT_KEY; - static const char *const MAX_BIGRAM_COUNT_KEY; - static const char *const MAX_TRIGRAM_COUNT_KEY; static const int DEFAULT_MULTIPLE_WORDS_DEMOTION_RATE; static const float MULTIPLE_WORD_COST_MULTIPLIER_SCALE; static const int DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID; - static const int DEFAULT_MAX_UNIGRAM_COUNT; - static const int DEFAULT_MAX_BIGRAM_COUNT; - static const int DEFAULT_MAX_TRIGRAM_COUNT; const FormatUtils::FORMAT_VERSION mDictFormatVersion; const HeaderReadWriteUtils::DictionaryFlags mDictionaryFlags; @@ -291,21 +249,18 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { const bool mIsDecayingDict; const int mDate; const int mLastDecayedTime; - const int mUnigramCount; - const int mBigramCount; - const int mTrigramCount; + const EntryCounts mNgramCounts; + const EntryCounts mMaxNgramCounts; const int mExtendedRegionSize; const bool mHasHistoricalInfoOfWords; const int mForgettingCurveProbabilityValuesTableId; - const int mMaxUnigramCount; - const int mMaxBigramCount; - const int mMaxTrigramCount; const int *const mCodePointTable; const std::vector readLocale() const; float readMultipleWordCostMultiplier() const; bool readRequiresGermanUmlautProcessing() const; - + const EntryCounts readNgramCounts() const; + const EntryCounts readMaxNgramCounts() const; static DictionaryHeaderStructurePolicy::AttributeMap createAttributeMapAndReadAllAttributes( const uint8_t *const dictBuf); }; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp index ca7d93b0e..051aed45a 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp @@ -303,7 +303,7 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const CodePointArrayView wordCodePo if (mUpdatingHelper.addUnigramWord(&readingHelper, codePointArrayView, unigramProperty, &addedNewUnigram)) { if (addedNewUnigram && !unigramProperty->representsBeginningOfSentence()) { - mEntryCounters.incrementUnigramCount(); + mEntryCounters.incrementNgramCount(NgramType::Unigram); } if (unigramProperty->getShortcuts().size() > 0) { // Add shortcut target. @@ -397,7 +397,7 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const NgramProperty *const ngramPrope if (mUpdatingHelper.addNgramEntry(PtNodePosArrayView::singleElementView(&prevWordPtNodePos), wordPos, ngramProperty, &addedNewBigram)) { if (addedNewBigram) { - mEntryCounters.incrementBigramCount(); + mEntryCounters.incrementNgramCount(NgramType::Bigram); } return true; } else { @@ -438,7 +438,7 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const NgramContext *const ngramCon const int prevWordPtNodePos = getTerminalPtNodePosFromWordId(prevWordIds[0]); if (mUpdatingHelper.removeNgramEntry( PtNodePosArrayView::singleElementView(&prevWordPtNodePos), wordPos)) { - mEntryCounters.decrementBigramCount(); + mEntryCounters.decrementNgramCount(NgramType::Bigram); return true; } else { return false; @@ -525,20 +525,23 @@ void Ver4PatriciaTriePolicy::getProperty(const char *const query, const int quer char *const outResult, const int maxResultLength) { const int compareLength = queryLength + 1 /* terminator */; if (strncmp(query, UNIGRAM_COUNT_QUERY, compareLength) == 0) { - snprintf(outResult, maxResultLength, "%d", mEntryCounters.getUnigramCount()); + snprintf(outResult, maxResultLength, "%d", + mEntryCounters.getNgramCount(NgramType::Unigram)); } else if (strncmp(query, BIGRAM_COUNT_QUERY, compareLength) == 0) { - snprintf(outResult, maxResultLength, "%d", mEntryCounters.getBigramCount()); + snprintf(outResult, maxResultLength, "%d", mEntryCounters.getNgramCount(NgramType::Bigram)); } else if (strncmp(query, MAX_UNIGRAM_COUNT_QUERY, compareLength) == 0) { snprintf(outResult, maxResultLength, "%d", mHeaderPolicy->isDecayingDict() ? ForgettingCurveUtils::getEntryCountHardLimit( - mHeaderPolicy->getMaxUnigramCount()) : + mHeaderPolicy->getMaxNgramCounts().getNgramCount( + NgramType::Unigram)) : static_cast(Ver4DictConstants::MAX_DICTIONARY_SIZE)); } else if (strncmp(query, MAX_BIGRAM_COUNT_QUERY, compareLength) == 0) { snprintf(outResult, maxResultLength, "%d", mHeaderPolicy->isDecayingDict() ? ForgettingCurveUtils::getEntryCountHardLimit( - mHeaderPolicy->getMaxBigramCount()) : + mHeaderPolicy->getMaxNgramCounts().getNgramCount( + NgramType::Bigram)) : static_cast(Ver4DictConstants::MAX_DICTIONARY_SIZE)); } } diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h index 0480876ed..80b1111b4 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h @@ -76,8 +76,7 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { &mPtNodeArrayReader, &mBigramPolicy, &mShortcutPolicy), mUpdatingHelper(mDictBuffer, &mNodeReader, &mNodeWriter), mWritingHelper(mBuffers.get()), - mEntryCounters(mHeaderPolicy->getUnigramCount(), mHeaderPolicy->getBigramCount(), - mHeaderPolicy->getTrigramCount()), + mEntryCounters(mHeaderPolicy->getNgramCounts().getCountArray()), mTerminalPtNodePositionsForIteratingWords(), mIsCorrupted(false) {}; virtual int getRootPosition() const { diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_writing_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_writing_helper.cpp index a033d396b..985c16803 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_writing_helper.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_writing_helper.cpp @@ -53,8 +53,8 @@ bool Ver4PatriciaTrieWritingHelper::writeToDictFile(const char *const dictDirPat entryCounts, extendedRegionSize, &headerBuffer)) { AKLOGE("Cannot write header structure to buffer. " "updatesLastDecayedTime: %d, unigramCount: %d, bigramCount: %d, " - "extendedRegionSize: %d", false, entryCounts.getUnigramCount(), - entryCounts.getBigramCount(), extendedRegionSize); + "extendedRegionSize: %d", false, entryCounts.getNgramCount(NgramType::Unigram), + entryCounts.getNgramCount(NgramType::Bigram), extendedRegionSize); return false; } return mBuffers->flushHeaderAndDictBuffers(dictDirPath, &headerBuffer); @@ -73,9 +73,11 @@ bool Ver4PatriciaTrieWritingHelper::writeToDictFileWithGC(const int rootPtNodeAr } BufferWithExtendableBuffer headerBuffer( BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE); + MutableEntryCounters entryCounters; + entryCounters.setNgramCount(NgramType::Unigram, unigramCount); + entryCounters.setNgramCount(NgramType::Bigram, bigramCount); if (!headerPolicy->fillInAndWriteHeaderToBuffer(true /* updatesLastDecayedTime */, - EntryCounts(unigramCount, bigramCount, 0 /* trigramCount */), - 0 /* extendedRegionSize */, &headerBuffer)) { + entryCounters.getEntryCounts(), 0 /* extendedRegionSize */, &headerBuffer)) { return false; } return dictBuffers->flushHeaderAndDictBuffers(dictDirPath, &headerBuffer); @@ -107,7 +109,7 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos, } const int unigramCount = traversePolicyToUpdateUnigramProbabilityAndMarkUselessPtNodesAsDeleted .getValidUnigramCount(); - const int maxUnigramCount = headerPolicy->getMaxUnigramCount(); + const int maxUnigramCount = headerPolicy->getMaxNgramCounts().getNgramCount(NgramType::Unigram); if (headerPolicy->isDecayingDict() && unigramCount > maxUnigramCount) { if (!truncateUnigrams(&ptNodeReader, &ptNodeWriter, maxUnigramCount)) { AKLOGE("Cannot remove unigrams. current: %d, max: %d", unigramCount, @@ -124,7 +126,7 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos, return false; } const int bigramCount = traversePolicyToUpdateBigramProbability.getValidBigramEntryCount(); - const int maxBigramCount = headerPolicy->getMaxBigramCount(); + const int maxBigramCount = headerPolicy->getMaxNgramCounts().getNgramCount(NgramType::Bigram); if (headerPolicy->isDecayingDict() && bigramCount > maxBigramCount) { if (!truncateBigrams(maxBigramCount)) { AKLOGE("Cannot remove bigrams. current: %d, max: %d", bigramCount, maxBigramCount); diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.cpp index b0fbb3e72..29bc7f719 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.cpp @@ -18,17 +18,13 @@ namespace latinime { -// These counts are used to provide stable probabilities even if the user's input count is small. -const int DynamicLanguageModelProbabilityUtils::ASSUMED_MIN_COUNT_FOR_UNIGRAMS = 8192; -const int DynamicLanguageModelProbabilityUtils::ASSUMED_MIN_COUNT_FOR_BIGRAMS = 2; -const int DynamicLanguageModelProbabilityUtils::ASSUMED_MIN_COUNT_FOR_TRIGRAMS = 2; +// Used to provide stable probabilities even if the user's input count is small. +const int DynamicLanguageModelProbabilityUtils::ASSUMED_MIN_COUNTS[] = {8192, 2, 2}; -// These are encoded backoff weights. +// Encoded backoff weights. // Note that we give positive value for trigrams that means the weight is more than 1. // TODO: Apply backoff for main dictionaries and quit giving a positive backoff weight. -const int DynamicLanguageModelProbabilityUtils::ENCODED_BACKOFF_WEIGHT_FOR_UNIGRAMS = -32; -const int DynamicLanguageModelProbabilityUtils::ENCODED_BACKOFF_WEIGHT_FOR_BIGRAMS = 0; -const int DynamicLanguageModelProbabilityUtils::ENCODED_BACKOFF_WEIGHT_FOR_TRIGRAMS = 8; +const int DynamicLanguageModelProbabilityUtils::ENCODED_BACKOFF_WEIGHTS[] = {-32, 0, 8}; // This value is used to remove too old entries from the dictionary. const int DynamicLanguageModelProbabilityUtils::DURATION_TO_DISCARD_ENTRY_IN_SECONDS = diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.h index 88bc58fe8..b38047f49 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.h @@ -21,6 +21,7 @@ #include "defines.h" #include "suggest/core/dictionary/property/historical_info.h" +#include "utils/ngram_utils.h" #include "utils/time_keeper.h" namespace latinime { @@ -28,46 +29,14 @@ namespace latinime { class DynamicLanguageModelProbabilityUtils { public: static float computeRawProbabilityFromCounts(const int count, const int contextCount, - const int matchedWordCountInContext) { - int minCount = 0; - switch (matchedWordCountInContext) { - case 1: - minCount = ASSUMED_MIN_COUNT_FOR_UNIGRAMS; - break; - case 2: - minCount = ASSUMED_MIN_COUNT_FOR_BIGRAMS; - break; - case 3: - minCount = ASSUMED_MIN_COUNT_FOR_TRIGRAMS; - break; - default: - AKLOGE("computeRawProbabilityFromCounts is called with invalid " - "matchedWordCountInContext (%d).", matchedWordCountInContext); - ASSERT(false); - return 0.0f; - } + const NgramType ngramType) { + const int minCount = ASSUMED_MIN_COUNTS[static_cast(ngramType)]; return static_cast(count) / static_cast(std::max(contextCount, minCount)); } - static float backoff(const int ngramProbability, const int matchedWordCountInContext) { - int probability = NOT_A_PROBABILITY; - - switch (matchedWordCountInContext) { - case 1: - probability = ngramProbability + ENCODED_BACKOFF_WEIGHT_FOR_UNIGRAMS; - break; - case 2: - probability = ngramProbability + ENCODED_BACKOFF_WEIGHT_FOR_BIGRAMS; - break; - case 3: - probability = ngramProbability + ENCODED_BACKOFF_WEIGHT_FOR_TRIGRAMS; - break; - default: - AKLOGE("backoff is called with invalid matchedWordCountInContext (%d).", - matchedWordCountInContext); - ASSERT(false); - return NOT_A_PROBABILITY; - } + static float backoff(const int ngramProbability, const NgramType ngramType) { + const int probability = + ngramProbability + ENCODED_BACKOFF_WEIGHTS[static_cast(ngramType)]; return std::min(std::max(probability, NOT_A_PROBABILITY), MAX_PROBABILITY); } @@ -99,14 +68,8 @@ private: static_assert(MAX_PREV_WORD_COUNT_FOR_N_GRAM <= 2, "Max supported Ngram is Trigram."); - static const int ASSUMED_MIN_COUNT_FOR_UNIGRAMS; - static const int ASSUMED_MIN_COUNT_FOR_BIGRAMS; - static const int ASSUMED_MIN_COUNT_FOR_TRIGRAMS; - - static const int ENCODED_BACKOFF_WEIGHT_FOR_UNIGRAMS; - static const int ENCODED_BACKOFF_WEIGHT_FOR_BIGRAMS; - static const int ENCODED_BACKOFF_WEIGHT_FOR_TRIGRAMS; - + static const int ASSUMED_MIN_COUNTS[]; + static const int ENCODED_BACKOFF_WEIGHTS[]; static const int DURATION_TO_DISCARD_ENTRY_IN_SECONDS; }; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp index 31b1ea696..6db7ea444 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp @@ -21,6 +21,7 @@ #include "suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.h" #include "suggest/policyimpl/dictionary/utils/probability_utils.h" +#include "utils/ngram_utils.h" namespace latinime { @@ -89,16 +90,17 @@ const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArr } contextCount = prevWordProbabilityEntry.getHistoricalInfo()->getCount(); } + const NgramType ngramType = NgramUtils::getNgramTypeFromWordCount(i + 1); const float rawProbability = DynamicLanguageModelProbabilityUtils::computeRawProbabilityFromCounts( - historicalInfo->getCount(), contextCount, i + 1); + historicalInfo->getCount(), contextCount, ngramType); const int encodedRawProbability = ProbabilityUtils::encodeRawProbability(rawProbability); const int decayedProbability = DynamicLanguageModelProbabilityUtils::getDecayedProbability( encodedRawProbability, *historicalInfo); probability = DynamicLanguageModelProbabilityUtils::backoff( - decayedProbability, i + 1 /* n */); + decayedProbability, ngramType); } else { probability = probabilityEntry.getProbability(); } @@ -198,18 +200,19 @@ bool LanguageModelDictContent::truncateEntries(const EntryCounts ¤tEntryCo MutableEntryCounters *const outEntryCounters) { for (int prevWordCount = 0; prevWordCount <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++prevWordCount) { const int totalWordCount = prevWordCount + 1; - if (currentEntryCounts.getNgramCount(totalWordCount) - <= maxEntryCounts.getNgramCount(totalWordCount)) { - outEntryCounters->setNgramCount(totalWordCount, - currentEntryCounts.getNgramCount(totalWordCount)); + const NgramType ngramType = NgramUtils::getNgramTypeFromWordCount(totalWordCount); + if (currentEntryCounts.getNgramCount(ngramType) + <= maxEntryCounts.getNgramCount(ngramType)) { + outEntryCounters->setNgramCount(ngramType, + currentEntryCounts.getNgramCount(ngramType)); continue; } int entryCount = 0; if (!turncateEntriesInSpecifiedLevel(headerPolicy, - maxEntryCounts.getNgramCount(totalWordCount), prevWordCount, &entryCount)) { + maxEntryCounts.getNgramCount(ngramType), prevWordCount, &entryCount)) { return false; } - outEntryCounters->setNgramCount(totalWordCount, entryCount); + outEntryCounters->setNgramCount(ngramType, entryCount); } return true; } @@ -246,7 +249,10 @@ bool LanguageModelDictContent::updateAllEntriesOnInputWord(const WordIdArrayView mGlobalCounters.updateMaxValueOfCounters( updatedNgramProbabilityEntry.getHistoricalInfo()->getCount()); if (!originalNgramProbabilityEntry.isValid()) { - entryCountersToUpdate->incrementNgramCount(i + 2); + // (i + 2) words are used in total because the prevWords consists of (i + 1) words when + // looking at its i-th element. + entryCountersToUpdate->incrementNgramCount( + NgramUtils::getNgramTypeFromWordCount(i + 2)); } } return true; @@ -369,7 +375,8 @@ bool LanguageModelDictContent::updateAllProbabilityEntriesForGCInner(const int b } } } - outEntryCounters->incrementNgramCount(prevWordCount + 1); + outEntryCounters->incrementNgramCount( + NgramUtils::getNgramTypeFromWordCount(prevWordCount + 1)); if (!entry.hasNextLevelMap()) { continue; } @@ -402,7 +409,8 @@ bool LanguageModelDictContent::turncateEntriesInSpecifiedLevel( for (int i = 0; i < entryCountToRemove; ++i) { const EntryInfoToTurncate &entryInfo = entryInfoVector[i]; if (!removeNgramProbabilityEntry( - WordIdArrayView(entryInfo.mPrevWordIds, entryInfo.mPrevWordCount), entryInfo.mKey)) { + WordIdArrayView(entryInfo.mPrevWordIds, entryInfo.mPrevWordCount), + entryInfo.mKey)) { return false; } } diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp index 7449cd02b..a96719533 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp @@ -31,6 +31,7 @@ #include "suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.h" #include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h" #include "suggest/policyimpl/dictionary/utils/probability_utils.h" +#include "utils/ngram_utils.h" namespace latinime { @@ -215,7 +216,7 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const CodePointArrayView wordCodePo if (mUpdatingHelper.addUnigramWord(&readingHelper, codePointArrayView, unigramProperty, &addedNewUnigram)) { if (addedNewUnigram && !unigramProperty->representsBeginningOfSentence()) { - mEntryCounters.incrementUnigramCount(); + mEntryCounters.incrementNgramCount(NgramType::Unigram); } if (unigramProperty->getShortcuts().size() > 0) { // Add shortcut target. @@ -263,7 +264,7 @@ bool Ver4PatriciaTriePolicy::removeUnigramEntry(const CodePointArrayView wordCod return false; } if (!ptNodeParams.representsNonWordInfo()) { - mEntryCounters.decrementUnigramCount(); + mEntryCounters.decrementNgramCount(NgramType::Unigram); } return true; } @@ -321,7 +322,8 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const NgramProperty *const ngramPrope bool addedNewEntry = false; if (mNodeWriter.addNgramEntry(prevWordIds, wordId, ngramProperty, &addedNewEntry)) { if (addedNewEntry) { - mEntryCounters.incrementNgramCount(prevWordIds.size() + 1); + mEntryCounters.incrementNgramCount( + NgramUtils::getNgramTypeFromWordCount(prevWordIds.size() + 1)); } return true; } else { @@ -359,7 +361,8 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const NgramContext *const ngramCon return false; } if (mNodeWriter.removeNgramEntry(prevWordIds, wordId)) { - mEntryCounters.decrementNgramCount(prevWordIds.size()); + mEntryCounters.decrementNgramCount( + NgramUtils::getNgramTypeFromWordCount(prevWordIds.size() + 1)); return true; } else { return false; @@ -477,20 +480,23 @@ void Ver4PatriciaTriePolicy::getProperty(const char *const query, const int quer char *const outResult, const int maxResultLength) { const int compareLength = queryLength + 1 /* terminator */; if (strncmp(query, UNIGRAM_COUNT_QUERY, compareLength) == 0) { - snprintf(outResult, maxResultLength, "%d", mEntryCounters.getUnigramCount()); + snprintf(outResult, maxResultLength, "%d", + mEntryCounters.getNgramCount(NgramType::Unigram)); } else if (strncmp(query, BIGRAM_COUNT_QUERY, compareLength) == 0) { - snprintf(outResult, maxResultLength, "%d", mEntryCounters.getBigramCount()); + snprintf(outResult, maxResultLength, "%d", mEntryCounters.getNgramCount(NgramType::Bigram)); } else if (strncmp(query, MAX_UNIGRAM_COUNT_QUERY, compareLength) == 0) { snprintf(outResult, maxResultLength, "%d", mHeaderPolicy->isDecayingDict() ? ForgettingCurveUtils::getEntryCountHardLimit( - mHeaderPolicy->getMaxUnigramCount()) : + mHeaderPolicy->getMaxNgramCounts().getNgramCount( + NgramType::Unigram)) : static_cast(Ver4DictConstants::MAX_DICTIONARY_SIZE)); } else if (strncmp(query, MAX_BIGRAM_COUNT_QUERY, compareLength) == 0) { snprintf(outResult, maxResultLength, "%d", mHeaderPolicy->isDecayingDict() ? ForgettingCurveUtils::getEntryCountHardLimit( - mHeaderPolicy->getMaxBigramCount()) : + mHeaderPolicy->getMaxNgramCounts().getNgramCount( + NgramType::Bigram)) : static_cast(Ver4DictConstants::MAX_DICTIONARY_SIZE)); } } diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h index 13700b390..93faa83a0 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h @@ -51,8 +51,7 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { &mShortcutPolicy), mUpdatingHelper(mDictBuffer, &mNodeReader, &mNodeWriter), mWritingHelper(mBuffers.get()), - mEntryCounters(mHeaderPolicy->getUnigramCount(), mHeaderPolicy->getBigramCount(), - mHeaderPolicy->getTrigramCount()), + mEntryCounters(mHeaderPolicy->getNgramCounts().getCountArray()), mTerminalPtNodePositionsForIteratingWords(), mIsCorrupted(false) {}; AK_FORCE_INLINE int getRootPosition() const { diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp index 7f0604ce8..34af76c5d 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp @@ -29,6 +29,7 @@ #include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h" #include "suggest/policyimpl/dictionary/utils/file_utils.h" #include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h" +#include "utils/ngram_utils.h" namespace latinime { @@ -43,8 +44,9 @@ bool Ver4PatriciaTrieWritingHelper::writeToDictFile(const char *const dictDirPat entryCounts, extendedRegionSize, &headerBuffer)) { AKLOGE("Cannot write header structure to buffer. " "updatesLastDecayedTime: %d, unigramCount: %d, bigramCount: %d, trigramCount: %d," - "extendedRegionSize: %d", false, entryCounts.getUnigramCount(), - entryCounts.getBigramCount(), entryCounts.getTrigramCount(), + "extendedRegionSize: %d", false, entryCounts.getNgramCount(NgramType::Unigram), + entryCounts.getNgramCount(NgramType::Bigram), + entryCounts.getNgramCount(NgramType::Trigram), extendedRegionSize); return false; } @@ -86,8 +88,7 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos, return false; } if (headerPolicy->isDecayingDict()) { - const EntryCounts maxEntryCounts(headerPolicy->getMaxUnigramCount(), - headerPolicy->getMaxBigramCount(), headerPolicy->getMaxTrigramCount()); + const EntryCounts &maxEntryCounts = headerPolicy->getMaxNgramCounts(); if (!mBuffers->getMutableLanguageModelDictContent()->truncateEntries( outEntryCounters->getEntryCounts(), maxEntryCounts, headerPolicy, outEntryCounters)) { diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/entry_counters.h b/native/jni/src/suggest/policyimpl/dictionary/utils/entry_counters.h index 73dc42a18..7269913e8 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/entry_counters.h +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/entry_counters.h @@ -20,6 +20,7 @@ #include #include "defines.h" +#include "utils/ngram_utils.h" namespace latinime { @@ -28,34 +29,22 @@ class EntryCounts final { public: EntryCounts() : mEntryCounts({{0, 0, 0}}) {} - EntryCounts(const int unigramCount, const int bigramCount, const int trigramCount) - : mEntryCounts({{unigramCount, bigramCount, trigramCount}}) {} - explicit EntryCounts(const std::array &counters) : mEntryCounts(counters) {} - int getUnigramCount() const { - return mEntryCounts[0]; - } - - int getBigramCount() const { - return mEntryCounts[1]; - } - - int getTrigramCount() const { - return mEntryCounts[2]; + int getNgramCount(const NgramType ngramType) const { + return mEntryCounts[static_cast(ngramType)]; } - int getNgramCount(const size_t n) const { - if (n < 1 || n > mEntryCounts.size()) { - return 0; - } - return mEntryCounts[n - 1]; + const std::array &getCountArray() const { + return mEntryCounts; } private: DISALLOW_ASSIGNMENT_OPERATOR(EntryCounts); + // Counts from Unigram (0-th element) to (MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1)-gram + // (MAX_PREV_WORD_COUNT_FOR_N_GRAM-th element) const std::array mEntryCounts; }; @@ -65,68 +54,35 @@ class MutableEntryCounters final { mEntryCounters.fill(0); } - MutableEntryCounters(const int unigramCount, const int bigramCount, const int trigramCount) - : mEntryCounters({{unigramCount, bigramCount, trigramCount}}) {} + explicit MutableEntryCounters( + const std::array &counters) + : mEntryCounters(counters) {} const EntryCounts getEntryCounts() const { return EntryCounts(mEntryCounters); } - int getUnigramCount() const { - return mEntryCounters[0]; - } - - int getBigramCount() const { - return mEntryCounters[1]; - } - - int getTrigramCount() const { - return mEntryCounters[2]; - } - - void incrementUnigramCount() { - ++mEntryCounters[0]; - } - - void decrementUnigramCount() { - ASSERT(mEntryCounters[0] != 0); - --mEntryCounters[0]; - } - - void incrementBigramCount() { - ++mEntryCounters[1]; - } - - void decrementBigramCount() { - ASSERT(mEntryCounters[1] != 0); - --mEntryCounters[1]; + void incrementNgramCount(const NgramType ngramType) { + ++mEntryCounters[static_cast(ngramType)]; } - void incrementNgramCount(const size_t n) { - if (n < 1 || n > mEntryCounters.size()) { - return; - } - ++mEntryCounters[n - 1]; + void decrementNgramCount(const NgramType ngramType) { + --mEntryCounters[static_cast(ngramType)]; } - void decrementNgramCount(const size_t n) { - if (n < 1 || n > mEntryCounters.size()) { - return; - } - ASSERT(mEntryCounters[n - 1] != 0); - --mEntryCounters[n - 1]; + int getNgramCount(const NgramType ngramType) const { + return mEntryCounters[static_cast(ngramType)]; } - void setNgramCount(const size_t n, const int count) { - if (n < 1 || n > mEntryCounters.size()) { - return; - } - mEntryCounters[n - 1] = count; + void setNgramCount(const NgramType ngramType, const int count) { + mEntryCounters[static_cast(ngramType)] = count; } private: DISALLOW_COPY_AND_ASSIGN(MutableEntryCounters); + // Counters from Unigram (0-th element) to (MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1)-gram + // (MAX_PREV_WORD_COUNT_FOR_N_GRAM-th element) std::array mEntryCounters; }; } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.cpp index 9055f7bfc..f05c6149e 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.cpp @@ -126,20 +126,13 @@ const ForgettingCurveUtils::ProbabilityTable ForgettingCurveUtils::sProbabilityT /* static */ bool ForgettingCurveUtils::needsToDecay(const bool mindsBlockByDecay, const EntryCounts &entryCounts, const HeaderPolicy *const headerPolicy) { - if (entryCounts.getUnigramCount() - >= getEntryCountHardLimit(headerPolicy->getMaxUnigramCount())) { - // Unigram count exceeds the limit. - return true; - } - if (entryCounts.getBigramCount() - >= getEntryCountHardLimit(headerPolicy->getMaxBigramCount())) { - // Bigram count exceeds the limit. - return true; - } - if (entryCounts.getTrigramCount() - >= getEntryCountHardLimit(headerPolicy->getMaxTrigramCount())) { - // Trigram count exceeds the limit. - return true; + const EntryCounts &maxNgramCounts = headerPolicy->getMaxNgramCounts(); + for (const auto ngramType : AllNgramTypes::ASCENDING) { + if (entryCounts.getNgramCount(ngramType) + >= getEntryCountHardLimit(maxNgramCounts.getNgramCount(ngramType))) { + // Unigram count exceeds the limit. + return true; + } } if (mindsBlockByDecay) { return false; diff --git a/native/jni/src/utils/ngram_utils.h b/native/jni/src/utils/ngram_utils.h new file mode 100644 index 000000000..6227812d4 --- /dev/null +++ b/native/jni/src/utils/ngram_utils.h @@ -0,0 +1,62 @@ +/* + * Copyright (C) 2014, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_NGRAM_UTILS_H +#define LATINIME_NGRAM_UTILS_H + +#include "defines.h" + +namespace latinime { + +enum class NgramType : int { + Unigram = 0, + Bigram = 1, + Trigram = 2, + NotANgramType = -1, +}; + +namespace AllNgramTypes { +// Use anonymous namespace to avoid ODR (One Definition Rule) violation. +namespace { + +const NgramType ASCENDING[] = { + NgramType::Unigram, NgramType::Bigram, NgramType::Trigram +}; + +const NgramType DESCENDING[] = { + NgramType::Trigram, NgramType::Bigram, NgramType::Unigram +}; + +} // namespace +} // namespace AllNgramTypes + +class NgramUtils final { + public: + static AK_FORCE_INLINE NgramType getNgramTypeFromWordCount(const int wordCount) { + // Max supported ngram is (MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1)-gram. + if (wordCount <= 0 || wordCount > MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1) { + return NgramType::NotANgramType; + } + // Convert word count to 0-origin enum value. + return static_cast(wordCount - 1); + } + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(NgramUtils); + +}; +} +#endif /* LATINIME_NGRAM_UTILS_H */ -- cgit v1.2.3