summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorKeisuke Kuroyanagi <ksk@google.com>2014-11-25 10:34:15 +0000
committerAndroid (Google) Code Review <android-gerrit@google.com>2014-11-25 10:34:15 +0000
commit20da4f07be9cdf58835a79e619785b4cafd428ff (patch)
tree8e19b1d6029a7a8029d3cc472400415f20740371
parentb479a56f281d7b20abbb05369a5d6d0a6ce9dc67 (diff)
parent78212a6d3de2c1fdaa394c58a16cbdee3ad5d046 (diff)
downloadandroid_packages_inputmethods_LatinIME-20da4f07be9cdf58835a79e619785b4cafd428ff.tar.gz
android_packages_inputmethods_LatinIME-20da4f07be9cdf58835a79e619785b4cafd428ff.tar.bz2
android_packages_inputmethods_LatinIME-20da4f07be9cdf58835a79e619785b4cafd428ff.zip
Merge "Use enum to specify ngram type."
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp58
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h81
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp17
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h3
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_writing_helper.cpp14
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.cpp12
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.h53
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp30
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp22
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h3
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp9
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/utils/entry_counters.h84
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.cpp21
-rw-r--r--native/jni/src/utils/ngram_utils.h62
14 files changed, 218 insertions, 251 deletions
diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp
index 300e96c4e..a2a0f11b4 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp
@@ -18,6 +18,8 @@
#include <algorithm>
+#include "utils/ngram_utils.h"
+
namespace latinime {
// Note that these are corresponding definitions in Java side in DictionaryHeader.
@@ -28,9 +30,11 @@ const char *const HeaderPolicy::REQUIRES_GERMAN_UMLAUT_PROCESSING_KEY =
const char *const HeaderPolicy::IS_DECAYING_DICT_KEY = "USES_FORGETTING_CURVE";
const char *const HeaderPolicy::DATE_KEY = "date";
const char *const HeaderPolicy::LAST_DECAYED_TIME_KEY = "LAST_DECAYED_TIME";
-const char *const HeaderPolicy::UNIGRAM_COUNT_KEY = "UNIGRAM_COUNT";
-const char *const HeaderPolicy::BIGRAM_COUNT_KEY = "BIGRAM_COUNT";
-const char *const HeaderPolicy::TRIGRAM_COUNT_KEY = "TRIGRAM_COUNT";
+const char *const HeaderPolicy::NGRAM_COUNT_KEYS[] =
+ {"UNIGRAM_COUNT", "BIGRAM_COUNT", "TRIGRAM_COUNT"};
+const char *const HeaderPolicy::MAX_NGRAM_COUNT_KEYS[] =
+ {"MAX_UNIGRAM_ENTRY_COUNT", "MAX_BIGRAM_ENTRY_COUNT", "MAX_TRIGRAM_ENTRY_COUNT"};
+const int HeaderPolicy::DEFAULT_MAX_NGRAM_COUNTS[] = {10000, 30000, 30000};
const char *const HeaderPolicy::EXTENDED_REGION_SIZE_KEY = "EXTENDED_REGION_SIZE";
// Historical info is information that is needed to support decaying such as timestamp, level and
// count.
@@ -39,18 +43,10 @@ const char *const HeaderPolicy::LOCALE_KEY = "locale"; // match Java declaration
const char *const HeaderPolicy::FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID_KEY =
"FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID";
-const char *const HeaderPolicy::MAX_UNIGRAM_COUNT_KEY = "MAX_UNIGRAM_ENTRY_COUNT";
-const char *const HeaderPolicy::MAX_BIGRAM_COUNT_KEY = "MAX_BIGRAM_ENTRY_COUNT";
-const char *const HeaderPolicy::MAX_TRIGRAM_COUNT_KEY = "MAX_TRIGRAM_ENTRY_COUNT";
-
const int HeaderPolicy::DEFAULT_MULTIPLE_WORDS_DEMOTION_RATE = 100;
const float HeaderPolicy::MULTIPLE_WORD_COST_MULTIPLIER_SCALE = 100.0f;
const int HeaderPolicy::DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID = 3;
-const int HeaderPolicy::DEFAULT_MAX_UNIGRAM_COUNT = 10000;
-const int HeaderPolicy::DEFAULT_MAX_BIGRAM_COUNT = 30000;
-const int HeaderPolicy::DEFAULT_MAX_TRIGRAM_COUNT = 30000;
-
// Used for logging. Question mark is used to indicate that the key is not found.
void HeaderPolicy::readHeaderValueOrQuestionMark(const char *const key, int *outValue,
int outValueSize) const {
@@ -126,15 +122,22 @@ bool HeaderPolicy::fillInAndWriteHeaderToBuffer(const bool updatesLastDecayedTim
return true;
}
+namespace {
+
+int getIndexFromNgramType(const NgramType ngramType) {
+ return static_cast<int>(ngramType);
+}
+
+} // namespace
+
void HeaderPolicy::fillInHeader(const bool updatesLastDecayedTime,
const EntryCounts &entryCounts, const int extendedRegionSize,
DictionaryHeaderStructurePolicy::AttributeMap *outAttributeMap) const {
- HeaderReadWriteUtils::setIntAttribute(outAttributeMap, UNIGRAM_COUNT_KEY,
- entryCounts.getUnigramCount());
- HeaderReadWriteUtils::setIntAttribute(outAttributeMap, BIGRAM_COUNT_KEY,
- entryCounts.getBigramCount());
- HeaderReadWriteUtils::setIntAttribute(outAttributeMap, TRIGRAM_COUNT_KEY,
- entryCounts.getTrigramCount());
+ for (const auto ngramType : AllNgramTypes::ASCENDING) {
+ HeaderReadWriteUtils::setIntAttribute(outAttributeMap,
+ NGRAM_COUNT_KEYS[getIndexFromNgramType(ngramType)],
+ entryCounts.getNgramCount(ngramType));
+ }
HeaderReadWriteUtils::setIntAttribute(outAttributeMap, EXTENDED_REGION_SIZE_KEY,
extendedRegionSize);
// Set the current time as the generation time.
@@ -155,4 +158,25 @@ void HeaderPolicy::fillInHeader(const bool updatesLastDecayedTime,
return attributeMap;
}
+/* static */ const EntryCounts HeaderPolicy::readNgramCounts() const {
+ MutableEntryCounters entryCounters;
+ for (const auto ngramType : AllNgramTypes::ASCENDING) {
+ const int entryCount = HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap,
+ NGRAM_COUNT_KEYS[getIndexFromNgramType(ngramType)], 0 /* defaultValue */);
+ entryCounters.setNgramCount(ngramType, entryCount);
+ }
+ return entryCounters.getEntryCounts();
+}
+
+/* static */ const EntryCounts HeaderPolicy::readMaxNgramCounts() const {
+ MutableEntryCounters entryCounters;
+ for (const auto ngramType : AllNgramTypes::ASCENDING) {
+ const int index = getIndexFromNgramType(ngramType);
+ const int maxEntryCount = HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap,
+ MAX_NGRAM_COUNT_KEYS[index], DEFAULT_MAX_NGRAM_COUNTS[index]);
+ entryCounters.setNgramCount(ngramType, maxEntryCount);
+ }
+ return entryCounters.getEntryCounts();
+}
+
} // namespace latinime
diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h
index 7a5acd7d5..f76931baa 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h
@@ -46,12 +46,7 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
DATE_KEY, TimeKeeper::peekCurrentTime() /* defaultValue */)),
mLastDecayedTime(HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap,
LAST_DECAYED_TIME_KEY, TimeKeeper::peekCurrentTime() /* defaultValue */)),
- mUnigramCount(HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap,
- UNIGRAM_COUNT_KEY, 0 /* defaultValue */)),
- mBigramCount(HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap,
- BIGRAM_COUNT_KEY, 0 /* defaultValue */)),
- mTrigramCount(HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap,
- TRIGRAM_COUNT_KEY, 0 /* defaultValue */)),
+ mNgramCounts(readNgramCounts()), mMaxNgramCounts(readMaxNgramCounts()),
mExtendedRegionSize(HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap,
EXTENDED_REGION_SIZE_KEY, 0 /* defaultValue */)),
mHasHistoricalInfoOfWords(HeaderReadWriteUtils::readBoolAttributeValue(
@@ -59,12 +54,6 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
mForgettingCurveProbabilityValuesTableId(HeaderReadWriteUtils::readIntAttributeValue(
&mAttributeMap, FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID_KEY,
DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID)),
- mMaxUnigramCount(HeaderReadWriteUtils::readIntAttributeValue(
- &mAttributeMap, MAX_UNIGRAM_COUNT_KEY, DEFAULT_MAX_UNIGRAM_COUNT)),
- mMaxBigramCount(HeaderReadWriteUtils::readIntAttributeValue(
- &mAttributeMap, MAX_BIGRAM_COUNT_KEY, DEFAULT_MAX_BIGRAM_COUNT)),
- mMaxTrigramCount(HeaderReadWriteUtils::readIntAttributeValue(
- &mAttributeMap, MAX_TRIGRAM_COUNT_KEY, DEFAULT_MAX_TRIGRAM_COUNT)),
mCodePointTable(HeaderReadWriteUtils::readCodePointTable(&mAttributeMap)) {}
// Constructs header information using an attribute map.
@@ -82,18 +71,13 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
DATE_KEY, TimeKeeper::peekCurrentTime() /* defaultValue */)),
mLastDecayedTime(HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap,
DATE_KEY, TimeKeeper::peekCurrentTime() /* defaultValue */)),
- mUnigramCount(0), mBigramCount(0), mTrigramCount(0), mExtendedRegionSize(0),
+ mNgramCounts(readNgramCounts()), mMaxNgramCounts(readMaxNgramCounts()),
+ mExtendedRegionSize(0),
mHasHistoricalInfoOfWords(HeaderReadWriteUtils::readBoolAttributeValue(
&mAttributeMap, HAS_HISTORICAL_INFO_KEY, false /* defaultValue */)),
mForgettingCurveProbabilityValuesTableId(HeaderReadWriteUtils::readIntAttributeValue(
&mAttributeMap, FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID_KEY,
DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID)),
- mMaxUnigramCount(HeaderReadWriteUtils::readIntAttributeValue(
- &mAttributeMap, MAX_UNIGRAM_COUNT_KEY, DEFAULT_MAX_UNIGRAM_COUNT)),
- mMaxBigramCount(HeaderReadWriteUtils::readIntAttributeValue(
- &mAttributeMap, MAX_BIGRAM_COUNT_KEY, DEFAULT_MAX_BIGRAM_COUNT)),
- mMaxTrigramCount(HeaderReadWriteUtils::readIntAttributeValue(
- &mAttributeMap, MAX_TRIGRAM_COUNT_KEY, DEFAULT_MAX_TRIGRAM_COUNT)),
mCodePointTable(HeaderReadWriteUtils::readCodePointTable(&mAttributeMap)) {}
// Copy header information
@@ -105,15 +89,12 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
mRequiresGermanUmlautProcessing(headerPolicy->mRequiresGermanUmlautProcessing),
mIsDecayingDict(headerPolicy->mIsDecayingDict),
mDate(headerPolicy->mDate), mLastDecayedTime(headerPolicy->mLastDecayedTime),
- mUnigramCount(headerPolicy->mUnigramCount), mBigramCount(headerPolicy->mBigramCount),
- mTrigramCount(headerPolicy->mTrigramCount),
+ mNgramCounts(headerPolicy->mNgramCounts),
+ mMaxNgramCounts(headerPolicy->mMaxNgramCounts),
mExtendedRegionSize(headerPolicy->mExtendedRegionSize),
mHasHistoricalInfoOfWords(headerPolicy->mHasHistoricalInfoOfWords),
mForgettingCurveProbabilityValuesTableId(
headerPolicy->mForgettingCurveProbabilityValuesTableId),
- mMaxUnigramCount(headerPolicy->mMaxUnigramCount),
- mMaxBigramCount(headerPolicy->mMaxBigramCount),
- mMaxTrigramCount(headerPolicy->mMaxTrigramCount),
mCodePointTable(headerPolicy->mCodePointTable) {}
// Temporary dummy header.
@@ -121,10 +102,9 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
: mDictFormatVersion(FormatUtils::UNKNOWN_VERSION), mDictionaryFlags(0), mSize(0),
mAttributeMap(), mLocale(CharUtils::EMPTY_STRING), mMultiWordCostMultiplier(0.0f),
mRequiresGermanUmlautProcessing(false), mIsDecayingDict(false),
- mDate(0), mLastDecayedTime(0), mUnigramCount(0), mBigramCount(0), mTrigramCount(0),
+ mDate(0), mLastDecayedTime(0), mNgramCounts(), mMaxNgramCounts(),
mExtendedRegionSize(0), mHasHistoricalInfoOfWords(false),
- mForgettingCurveProbabilityValuesTableId(0), mMaxUnigramCount(0), mMaxBigramCount(0),
- mMaxTrigramCount(0), mCodePointTable(nullptr) {}
+ mForgettingCurveProbabilityValuesTableId(0), mCodePointTable(nullptr) {}
~HeaderPolicy() {}
@@ -186,16 +166,12 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
return mLastDecayedTime;
}
- AK_FORCE_INLINE int getUnigramCount() const {
- return mUnigramCount;
+ AK_FORCE_INLINE const EntryCounts &getNgramCounts() const {
+ return mNgramCounts;
}
- AK_FORCE_INLINE int getBigramCount() const {
- return mBigramCount;
- }
-
- AK_FORCE_INLINE int getTrigramCount() const {
- return mTrigramCount;
+ AK_FORCE_INLINE const EntryCounts getMaxNgramCounts() const {
+ return mMaxNgramCounts;
}
AK_FORCE_INLINE int getExtendedRegionSize() const {
@@ -219,18 +195,6 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
return mForgettingCurveProbabilityValuesTableId;
}
- AK_FORCE_INLINE int getMaxUnigramCount() const {
- return mMaxUnigramCount;
- }
-
- AK_FORCE_INLINE int getMaxBigramCount() const {
- return mMaxBigramCount;
- }
-
- AK_FORCE_INLINE int getMaxTrigramCount() const {
- return mMaxTrigramCount;
- }
-
void readHeaderValueOrQuestionMark(const char *const key,
int *outValue, int outValueSize) const;
@@ -262,24 +226,18 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
static const char *const IS_DECAYING_DICT_KEY;
static const char *const DATE_KEY;
static const char *const LAST_DECAYED_TIME_KEY;
- static const char *const UNIGRAM_COUNT_KEY;
- static const char *const BIGRAM_COUNT_KEY;
- static const char *const TRIGRAM_COUNT_KEY;
+ static const char *const NGRAM_COUNT_KEYS[];
+ static const char *const MAX_NGRAM_COUNT_KEYS[];
+ static const int DEFAULT_MAX_NGRAM_COUNTS[];
static const char *const EXTENDED_REGION_SIZE_KEY;
static const char *const HAS_HISTORICAL_INFO_KEY;
static const char *const LOCALE_KEY;
static const char *const FORGETTING_CURVE_OCCURRENCES_TO_LEVEL_UP_KEY;
static const char *const FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID_KEY;
static const char *const FORGETTING_CURVE_DURATION_TO_LEVEL_DOWN_IN_SECONDS_KEY;
- static const char *const MAX_UNIGRAM_COUNT_KEY;
- static const char *const MAX_BIGRAM_COUNT_KEY;
- static const char *const MAX_TRIGRAM_COUNT_KEY;
static const int DEFAULT_MULTIPLE_WORDS_DEMOTION_RATE;
static const float MULTIPLE_WORD_COST_MULTIPLIER_SCALE;
static const int DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID;
- static const int DEFAULT_MAX_UNIGRAM_COUNT;
- static const int DEFAULT_MAX_BIGRAM_COUNT;
- static const int DEFAULT_MAX_TRIGRAM_COUNT;
const FormatUtils::FORMAT_VERSION mDictFormatVersion;
const HeaderReadWriteUtils::DictionaryFlags mDictionaryFlags;
@@ -291,21 +249,18 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
const bool mIsDecayingDict;
const int mDate;
const int mLastDecayedTime;
- const int mUnigramCount;
- const int mBigramCount;
- const int mTrigramCount;
+ const EntryCounts mNgramCounts;
+ const EntryCounts mMaxNgramCounts;
const int mExtendedRegionSize;
const bool mHasHistoricalInfoOfWords;
const int mForgettingCurveProbabilityValuesTableId;
- const int mMaxUnigramCount;
- const int mMaxBigramCount;
- const int mMaxTrigramCount;
const int *const mCodePointTable;
const std::vector<int> readLocale() const;
float readMultipleWordCostMultiplier() const;
bool readRequiresGermanUmlautProcessing() const;
-
+ const EntryCounts readNgramCounts() const;
+ const EntryCounts readMaxNgramCounts() const;
static DictionaryHeaderStructurePolicy::AttributeMap createAttributeMapAndReadAllAttributes(
const uint8_t *const dictBuf);
};
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp
index ca7d93b0e..051aed45a 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp
@@ -303,7 +303,7 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const CodePointArrayView wordCodePo
if (mUpdatingHelper.addUnigramWord(&readingHelper, codePointArrayView, unigramProperty,
&addedNewUnigram)) {
if (addedNewUnigram && !unigramProperty->representsBeginningOfSentence()) {
- mEntryCounters.incrementUnigramCount();
+ mEntryCounters.incrementNgramCount(NgramType::Unigram);
}
if (unigramProperty->getShortcuts().size() > 0) {
// Add shortcut target.
@@ -397,7 +397,7 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const NgramProperty *const ngramPrope
if (mUpdatingHelper.addNgramEntry(PtNodePosArrayView::singleElementView(&prevWordPtNodePos),
wordPos, ngramProperty, &addedNewBigram)) {
if (addedNewBigram) {
- mEntryCounters.incrementBigramCount();
+ mEntryCounters.incrementNgramCount(NgramType::Bigram);
}
return true;
} else {
@@ -438,7 +438,7 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const NgramContext *const ngramCon
const int prevWordPtNodePos = getTerminalPtNodePosFromWordId(prevWordIds[0]);
if (mUpdatingHelper.removeNgramEntry(
PtNodePosArrayView::singleElementView(&prevWordPtNodePos), wordPos)) {
- mEntryCounters.decrementBigramCount();
+ mEntryCounters.decrementNgramCount(NgramType::Bigram);
return true;
} else {
return false;
@@ -525,20 +525,23 @@ void Ver4PatriciaTriePolicy::getProperty(const char *const query, const int quer
char *const outResult, const int maxResultLength) {
const int compareLength = queryLength + 1 /* terminator */;
if (strncmp(query, UNIGRAM_COUNT_QUERY, compareLength) == 0) {
- snprintf(outResult, maxResultLength, "%d", mEntryCounters.getUnigramCount());
+ snprintf(outResult, maxResultLength, "%d",
+ mEntryCounters.getNgramCount(NgramType::Unigram));
} else if (strncmp(query, BIGRAM_COUNT_QUERY, compareLength) == 0) {
- snprintf(outResult, maxResultLength, "%d", mEntryCounters.getBigramCount());
+ snprintf(outResult, maxResultLength, "%d", mEntryCounters.getNgramCount(NgramType::Bigram));
} else if (strncmp(query, MAX_UNIGRAM_COUNT_QUERY, compareLength) == 0) {
snprintf(outResult, maxResultLength, "%d",
mHeaderPolicy->isDecayingDict() ?
ForgettingCurveUtils::getEntryCountHardLimit(
- mHeaderPolicy->getMaxUnigramCount()) :
+ mHeaderPolicy->getMaxNgramCounts().getNgramCount(
+ NgramType::Unigram)) :
static_cast<int>(Ver4DictConstants::MAX_DICTIONARY_SIZE));
} else if (strncmp(query, MAX_BIGRAM_COUNT_QUERY, compareLength) == 0) {
snprintf(outResult, maxResultLength, "%d",
mHeaderPolicy->isDecayingDict() ?
ForgettingCurveUtils::getEntryCountHardLimit(
- mHeaderPolicy->getMaxBigramCount()) :
+ mHeaderPolicy->getMaxNgramCounts().getNgramCount(
+ NgramType::Bigram)) :
static_cast<int>(Ver4DictConstants::MAX_DICTIONARY_SIZE));
}
}
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h
index 0480876ed..80b1111b4 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h
@@ -76,8 +76,7 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
&mPtNodeArrayReader, &mBigramPolicy, &mShortcutPolicy),
mUpdatingHelper(mDictBuffer, &mNodeReader, &mNodeWriter),
mWritingHelper(mBuffers.get()),
- mEntryCounters(mHeaderPolicy->getUnigramCount(), mHeaderPolicy->getBigramCount(),
- mHeaderPolicy->getTrigramCount()),
+ mEntryCounters(mHeaderPolicy->getNgramCounts().getCountArray()),
mTerminalPtNodePositionsForIteratingWords(), mIsCorrupted(false) {};
virtual int getRootPosition() const {
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_writing_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_writing_helper.cpp
index a033d396b..985c16803 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_writing_helper.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_writing_helper.cpp
@@ -53,8 +53,8 @@ bool Ver4PatriciaTrieWritingHelper::writeToDictFile(const char *const dictDirPat
entryCounts, extendedRegionSize, &headerBuffer)) {
AKLOGE("Cannot write header structure to buffer. "
"updatesLastDecayedTime: %d, unigramCount: %d, bigramCount: %d, "
- "extendedRegionSize: %d", false, entryCounts.getUnigramCount(),
- entryCounts.getBigramCount(), extendedRegionSize);
+ "extendedRegionSize: %d", false, entryCounts.getNgramCount(NgramType::Unigram),
+ entryCounts.getNgramCount(NgramType::Bigram), extendedRegionSize);
return false;
}
return mBuffers->flushHeaderAndDictBuffers(dictDirPath, &headerBuffer);
@@ -73,9 +73,11 @@ bool Ver4PatriciaTrieWritingHelper::writeToDictFileWithGC(const int rootPtNodeAr
}
BufferWithExtendableBuffer headerBuffer(
BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE);
+ MutableEntryCounters entryCounters;
+ entryCounters.setNgramCount(NgramType::Unigram, unigramCount);
+ entryCounters.setNgramCount(NgramType::Bigram, bigramCount);
if (!headerPolicy->fillInAndWriteHeaderToBuffer(true /* updatesLastDecayedTime */,
- EntryCounts(unigramCount, bigramCount, 0 /* trigramCount */),
- 0 /* extendedRegionSize */, &headerBuffer)) {
+ entryCounters.getEntryCounts(), 0 /* extendedRegionSize */, &headerBuffer)) {
return false;
}
return dictBuffers->flushHeaderAndDictBuffers(dictDirPath, &headerBuffer);
@@ -107,7 +109,7 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
}
const int unigramCount = traversePolicyToUpdateUnigramProbabilityAndMarkUselessPtNodesAsDeleted
.getValidUnigramCount();
- const int maxUnigramCount = headerPolicy->getMaxUnigramCount();
+ const int maxUnigramCount = headerPolicy->getMaxNgramCounts().getNgramCount(NgramType::Unigram);
if (headerPolicy->isDecayingDict() && unigramCount > maxUnigramCount) {
if (!truncateUnigrams(&ptNodeReader, &ptNodeWriter, maxUnigramCount)) {
AKLOGE("Cannot remove unigrams. current: %d, max: %d", unigramCount,
@@ -124,7 +126,7 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
return false;
}
const int bigramCount = traversePolicyToUpdateBigramProbability.getValidBigramEntryCount();
- const int maxBigramCount = headerPolicy->getMaxBigramCount();
+ const int maxBigramCount = headerPolicy->getMaxNgramCounts().getNgramCount(NgramType::Bigram);
if (headerPolicy->isDecayingDict() && bigramCount > maxBigramCount) {
if (!truncateBigrams(maxBigramCount)) {
AKLOGE("Cannot remove bigrams. current: %d, max: %d", bigramCount, maxBigramCount);
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.cpp
index b0fbb3e72..29bc7f719 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.cpp
@@ -18,17 +18,13 @@
namespace latinime {
-// These counts are used to provide stable probabilities even if the user's input count is small.
-const int DynamicLanguageModelProbabilityUtils::ASSUMED_MIN_COUNT_FOR_UNIGRAMS = 8192;
-const int DynamicLanguageModelProbabilityUtils::ASSUMED_MIN_COUNT_FOR_BIGRAMS = 2;
-const int DynamicLanguageModelProbabilityUtils::ASSUMED_MIN_COUNT_FOR_TRIGRAMS = 2;
+// Used to provide stable probabilities even if the user's input count is small.
+const int DynamicLanguageModelProbabilityUtils::ASSUMED_MIN_COUNTS[] = {8192, 2, 2};
-// These are encoded backoff weights.
+// Encoded backoff weights.
// Note that we give positive value for trigrams that means the weight is more than 1.
// TODO: Apply backoff for main dictionaries and quit giving a positive backoff weight.
-const int DynamicLanguageModelProbabilityUtils::ENCODED_BACKOFF_WEIGHT_FOR_UNIGRAMS = -32;
-const int DynamicLanguageModelProbabilityUtils::ENCODED_BACKOFF_WEIGHT_FOR_BIGRAMS = 0;
-const int DynamicLanguageModelProbabilityUtils::ENCODED_BACKOFF_WEIGHT_FOR_TRIGRAMS = 8;
+const int DynamicLanguageModelProbabilityUtils::ENCODED_BACKOFF_WEIGHTS[] = {-32, 0, 8};
// This value is used to remove too old entries from the dictionary.
const int DynamicLanguageModelProbabilityUtils::DURATION_TO_DISCARD_ENTRY_IN_SECONDS =
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.h
index 88bc58fe8..b38047f49 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.h
@@ -21,6 +21,7 @@
#include "defines.h"
#include "suggest/core/dictionary/property/historical_info.h"
+#include "utils/ngram_utils.h"
#include "utils/time_keeper.h"
namespace latinime {
@@ -28,46 +29,14 @@ namespace latinime {
class DynamicLanguageModelProbabilityUtils {
public:
static float computeRawProbabilityFromCounts(const int count, const int contextCount,
- const int matchedWordCountInContext) {
- int minCount = 0;
- switch (matchedWordCountInContext) {
- case 1:
- minCount = ASSUMED_MIN_COUNT_FOR_UNIGRAMS;
- break;
- case 2:
- minCount = ASSUMED_MIN_COUNT_FOR_BIGRAMS;
- break;
- case 3:
- minCount = ASSUMED_MIN_COUNT_FOR_TRIGRAMS;
- break;
- default:
- AKLOGE("computeRawProbabilityFromCounts is called with invalid "
- "matchedWordCountInContext (%d).", matchedWordCountInContext);
- ASSERT(false);
- return 0.0f;
- }
+ const NgramType ngramType) {
+ const int minCount = ASSUMED_MIN_COUNTS[static_cast<int>(ngramType)];
return static_cast<float>(count) / static_cast<float>(std::max(contextCount, minCount));
}
- static float backoff(const int ngramProbability, const int matchedWordCountInContext) {
- int probability = NOT_A_PROBABILITY;
-
- switch (matchedWordCountInContext) {
- case 1:
- probability = ngramProbability + ENCODED_BACKOFF_WEIGHT_FOR_UNIGRAMS;
- break;
- case 2:
- probability = ngramProbability + ENCODED_BACKOFF_WEIGHT_FOR_BIGRAMS;
- break;
- case 3:
- probability = ngramProbability + ENCODED_BACKOFF_WEIGHT_FOR_TRIGRAMS;
- break;
- default:
- AKLOGE("backoff is called with invalid matchedWordCountInContext (%d).",
- matchedWordCountInContext);
- ASSERT(false);
- return NOT_A_PROBABILITY;
- }
+ static float backoff(const int ngramProbability, const NgramType ngramType) {
+ const int probability =
+ ngramProbability + ENCODED_BACKOFF_WEIGHTS[static_cast<int>(ngramType)];
return std::min(std::max(probability, NOT_A_PROBABILITY), MAX_PROBABILITY);
}
@@ -99,14 +68,8 @@ private:
static_assert(MAX_PREV_WORD_COUNT_FOR_N_GRAM <= 2, "Max supported Ngram is Trigram.");
- static const int ASSUMED_MIN_COUNT_FOR_UNIGRAMS;
- static const int ASSUMED_MIN_COUNT_FOR_BIGRAMS;
- static const int ASSUMED_MIN_COUNT_FOR_TRIGRAMS;
-
- static const int ENCODED_BACKOFF_WEIGHT_FOR_UNIGRAMS;
- static const int ENCODED_BACKOFF_WEIGHT_FOR_BIGRAMS;
- static const int ENCODED_BACKOFF_WEIGHT_FOR_TRIGRAMS;
-
+ static const int ASSUMED_MIN_COUNTS[];
+ static const int ENCODED_BACKOFF_WEIGHTS[];
static const int DURATION_TO_DISCARD_ENTRY_IN_SECONDS;
};
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
index 31b1ea696..6db7ea444 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
@@ -21,6 +21,7 @@
#include "suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.h"
#include "suggest/policyimpl/dictionary/utils/probability_utils.h"
+#include "utils/ngram_utils.h"
namespace latinime {
@@ -89,16 +90,17 @@ const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArr
}
contextCount = prevWordProbabilityEntry.getHistoricalInfo()->getCount();
}
+ const NgramType ngramType = NgramUtils::getNgramTypeFromWordCount(i + 1);
const float rawProbability =
DynamicLanguageModelProbabilityUtils::computeRawProbabilityFromCounts(
- historicalInfo->getCount(), contextCount, i + 1);
+ historicalInfo->getCount(), contextCount, ngramType);
const int encodedRawProbability =
ProbabilityUtils::encodeRawProbability(rawProbability);
const int decayedProbability =
DynamicLanguageModelProbabilityUtils::getDecayedProbability(
encodedRawProbability, *historicalInfo);
probability = DynamicLanguageModelProbabilityUtils::backoff(
- decayedProbability, i + 1 /* n */);
+ decayedProbability, ngramType);
} else {
probability = probabilityEntry.getProbability();
}
@@ -198,18 +200,19 @@ bool LanguageModelDictContent::truncateEntries(const EntryCounts &currentEntryCo
MutableEntryCounters *const outEntryCounters) {
for (int prevWordCount = 0; prevWordCount <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++prevWordCount) {
const int totalWordCount = prevWordCount + 1;
- if (currentEntryCounts.getNgramCount(totalWordCount)
- <= maxEntryCounts.getNgramCount(totalWordCount)) {
- outEntryCounters->setNgramCount(totalWordCount,
- currentEntryCounts.getNgramCount(totalWordCount));
+ const NgramType ngramType = NgramUtils::getNgramTypeFromWordCount(totalWordCount);
+ if (currentEntryCounts.getNgramCount(ngramType)
+ <= maxEntryCounts.getNgramCount(ngramType)) {
+ outEntryCounters->setNgramCount(ngramType,
+ currentEntryCounts.getNgramCount(ngramType));
continue;
}
int entryCount = 0;
if (!turncateEntriesInSpecifiedLevel(headerPolicy,
- maxEntryCounts.getNgramCount(totalWordCount), prevWordCount, &entryCount)) {
+ maxEntryCounts.getNgramCount(ngramType), prevWordCount, &entryCount)) {
return false;
}
- outEntryCounters->setNgramCount(totalWordCount, entryCount);
+ outEntryCounters->setNgramCount(ngramType, entryCount);
}
return true;
}
@@ -246,7 +249,10 @@ bool LanguageModelDictContent::updateAllEntriesOnInputWord(const WordIdArrayView
mGlobalCounters.updateMaxValueOfCounters(
updatedNgramProbabilityEntry.getHistoricalInfo()->getCount());
if (!originalNgramProbabilityEntry.isValid()) {
- entryCountersToUpdate->incrementNgramCount(i + 2);
+ // (i + 2) words are used in total because the prevWords consists of (i + 1) words when
+ // looking at its i-th element.
+ entryCountersToUpdate->incrementNgramCount(
+ NgramUtils::getNgramTypeFromWordCount(i + 2));
}
}
return true;
@@ -369,7 +375,8 @@ bool LanguageModelDictContent::updateAllProbabilityEntriesForGCInner(const int b
}
}
}
- outEntryCounters->incrementNgramCount(prevWordCount + 1);
+ outEntryCounters->incrementNgramCount(
+ NgramUtils::getNgramTypeFromWordCount(prevWordCount + 1));
if (!entry.hasNextLevelMap()) {
continue;
}
@@ -402,7 +409,8 @@ bool LanguageModelDictContent::turncateEntriesInSpecifiedLevel(
for (int i = 0; i < entryCountToRemove; ++i) {
const EntryInfoToTurncate &entryInfo = entryInfoVector[i];
if (!removeNgramProbabilityEntry(
- WordIdArrayView(entryInfo.mPrevWordIds, entryInfo.mPrevWordCount), entryInfo.mKey)) {
+ WordIdArrayView(entryInfo.mPrevWordIds, entryInfo.mPrevWordCount),
+ entryInfo.mKey)) {
return false;
}
}
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
index 7449cd02b..a96719533 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
@@ -31,6 +31,7 @@
#include "suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.h"
#include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h"
#include "suggest/policyimpl/dictionary/utils/probability_utils.h"
+#include "utils/ngram_utils.h"
namespace latinime {
@@ -215,7 +216,7 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const CodePointArrayView wordCodePo
if (mUpdatingHelper.addUnigramWord(&readingHelper, codePointArrayView, unigramProperty,
&addedNewUnigram)) {
if (addedNewUnigram && !unigramProperty->representsBeginningOfSentence()) {
- mEntryCounters.incrementUnigramCount();
+ mEntryCounters.incrementNgramCount(NgramType::Unigram);
}
if (unigramProperty->getShortcuts().size() > 0) {
// Add shortcut target.
@@ -263,7 +264,7 @@ bool Ver4PatriciaTriePolicy::removeUnigramEntry(const CodePointArrayView wordCod
return false;
}
if (!ptNodeParams.representsNonWordInfo()) {
- mEntryCounters.decrementUnigramCount();
+ mEntryCounters.decrementNgramCount(NgramType::Unigram);
}
return true;
}
@@ -321,7 +322,8 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const NgramProperty *const ngramPrope
bool addedNewEntry = false;
if (mNodeWriter.addNgramEntry(prevWordIds, wordId, ngramProperty, &addedNewEntry)) {
if (addedNewEntry) {
- mEntryCounters.incrementNgramCount(prevWordIds.size() + 1);
+ mEntryCounters.incrementNgramCount(
+ NgramUtils::getNgramTypeFromWordCount(prevWordIds.size() + 1));
}
return true;
} else {
@@ -359,7 +361,8 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const NgramContext *const ngramCon
return false;
}
if (mNodeWriter.removeNgramEntry(prevWordIds, wordId)) {
- mEntryCounters.decrementNgramCount(prevWordIds.size());
+ mEntryCounters.decrementNgramCount(
+ NgramUtils::getNgramTypeFromWordCount(prevWordIds.size() + 1));
return true;
} else {
return false;
@@ -477,20 +480,23 @@ void Ver4PatriciaTriePolicy::getProperty(const char *const query, const int quer
char *const outResult, const int maxResultLength) {
const int compareLength = queryLength + 1 /* terminator */;
if (strncmp(query, UNIGRAM_COUNT_QUERY, compareLength) == 0) {
- snprintf(outResult, maxResultLength, "%d", mEntryCounters.getUnigramCount());
+ snprintf(outResult, maxResultLength, "%d",
+ mEntryCounters.getNgramCount(NgramType::Unigram));
} else if (strncmp(query, BIGRAM_COUNT_QUERY, compareLength) == 0) {
- snprintf(outResult, maxResultLength, "%d", mEntryCounters.getBigramCount());
+ snprintf(outResult, maxResultLength, "%d", mEntryCounters.getNgramCount(NgramType::Bigram));
} else if (strncmp(query, MAX_UNIGRAM_COUNT_QUERY, compareLength) == 0) {
snprintf(outResult, maxResultLength, "%d",
mHeaderPolicy->isDecayingDict() ?
ForgettingCurveUtils::getEntryCountHardLimit(
- mHeaderPolicy->getMaxUnigramCount()) :
+ mHeaderPolicy->getMaxNgramCounts().getNgramCount(
+ NgramType::Unigram)) :
static_cast<int>(Ver4DictConstants::MAX_DICTIONARY_SIZE));
} else if (strncmp(query, MAX_BIGRAM_COUNT_QUERY, compareLength) == 0) {
snprintf(outResult, maxResultLength, "%d",
mHeaderPolicy->isDecayingDict() ?
ForgettingCurveUtils::getEntryCountHardLimit(
- mHeaderPolicy->getMaxBigramCount()) :
+ mHeaderPolicy->getMaxNgramCounts().getNgramCount(
+ NgramType::Bigram)) :
static_cast<int>(Ver4DictConstants::MAX_DICTIONARY_SIZE));
}
}
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h
index 13700b390..93faa83a0 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h
@@ -51,8 +51,7 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
&mShortcutPolicy),
mUpdatingHelper(mDictBuffer, &mNodeReader, &mNodeWriter),
mWritingHelper(mBuffers.get()),
- mEntryCounters(mHeaderPolicy->getUnigramCount(), mHeaderPolicy->getBigramCount(),
- mHeaderPolicy->getTrigramCount()),
+ mEntryCounters(mHeaderPolicy->getNgramCounts().getCountArray()),
mTerminalPtNodePositionsForIteratingWords(), mIsCorrupted(false) {};
AK_FORCE_INLINE int getRootPosition() const {
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp
index 7f0604ce8..34af76c5d 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp
@@ -29,6 +29,7 @@
#include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h"
#include "suggest/policyimpl/dictionary/utils/file_utils.h"
#include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h"
+#include "utils/ngram_utils.h"
namespace latinime {
@@ -43,8 +44,9 @@ bool Ver4PatriciaTrieWritingHelper::writeToDictFile(const char *const dictDirPat
entryCounts, extendedRegionSize, &headerBuffer)) {
AKLOGE("Cannot write header structure to buffer. "
"updatesLastDecayedTime: %d, unigramCount: %d, bigramCount: %d, trigramCount: %d,"
- "extendedRegionSize: %d", false, entryCounts.getUnigramCount(),
- entryCounts.getBigramCount(), entryCounts.getTrigramCount(),
+ "extendedRegionSize: %d", false, entryCounts.getNgramCount(NgramType::Unigram),
+ entryCounts.getNgramCount(NgramType::Bigram),
+ entryCounts.getNgramCount(NgramType::Trigram),
extendedRegionSize);
return false;
}
@@ -86,8 +88,7 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
return false;
}
if (headerPolicy->isDecayingDict()) {
- const EntryCounts maxEntryCounts(headerPolicy->getMaxUnigramCount(),
- headerPolicy->getMaxBigramCount(), headerPolicy->getMaxTrigramCount());
+ const EntryCounts &maxEntryCounts = headerPolicy->getMaxNgramCounts();
if (!mBuffers->getMutableLanguageModelDictContent()->truncateEntries(
outEntryCounters->getEntryCounts(), maxEntryCounts, headerPolicy,
outEntryCounters)) {
diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/entry_counters.h b/native/jni/src/suggest/policyimpl/dictionary/utils/entry_counters.h
index 73dc42a18..7269913e8 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/utils/entry_counters.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/utils/entry_counters.h
@@ -20,6 +20,7 @@
#include <array>
#include "defines.h"
+#include "utils/ngram_utils.h"
namespace latinime {
@@ -28,34 +29,22 @@ class EntryCounts final {
public:
EntryCounts() : mEntryCounts({{0, 0, 0}}) {}
- EntryCounts(const int unigramCount, const int bigramCount, const int trigramCount)
- : mEntryCounts({{unigramCount, bigramCount, trigramCount}}) {}
-
explicit EntryCounts(const std::array<int, MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1> &counters)
: mEntryCounts(counters) {}
- int getUnigramCount() const {
- return mEntryCounts[0];
- }
-
- int getBigramCount() const {
- return mEntryCounts[1];
- }
-
- int getTrigramCount() const {
- return mEntryCounts[2];
+ int getNgramCount(const NgramType ngramType) const {
+ return mEntryCounts[static_cast<int>(ngramType)];
}
- int getNgramCount(const size_t n) const {
- if (n < 1 || n > mEntryCounts.size()) {
- return 0;
- }
- return mEntryCounts[n - 1];
+ const std::array<int, MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1> &getCountArray() const {
+ return mEntryCounts;
}
private:
DISALLOW_ASSIGNMENT_OPERATOR(EntryCounts);
+ // Counts from Unigram (0-th element) to (MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1)-gram
+ // (MAX_PREV_WORD_COUNT_FOR_N_GRAM-th element)
const std::array<int, MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1> mEntryCounts;
};
@@ -65,68 +54,35 @@ class MutableEntryCounters final {
mEntryCounters.fill(0);
}
- MutableEntryCounters(const int unigramCount, const int bigramCount, const int trigramCount)
- : mEntryCounters({{unigramCount, bigramCount, trigramCount}}) {}
+ explicit MutableEntryCounters(
+ const std::array<int, MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1> &counters)
+ : mEntryCounters(counters) {}
const EntryCounts getEntryCounts() const {
return EntryCounts(mEntryCounters);
}
- int getUnigramCount() const {
- return mEntryCounters[0];
- }
-
- int getBigramCount() const {
- return mEntryCounters[1];
- }
-
- int getTrigramCount() const {
- return mEntryCounters[2];
- }
-
- void incrementUnigramCount() {
- ++mEntryCounters[0];
- }
-
- void decrementUnigramCount() {
- ASSERT(mEntryCounters[0] != 0);
- --mEntryCounters[0];
- }
-
- void incrementBigramCount() {
- ++mEntryCounters[1];
- }
-
- void decrementBigramCount() {
- ASSERT(mEntryCounters[1] != 0);
- --mEntryCounters[1];
+ void incrementNgramCount(const NgramType ngramType) {
+ ++mEntryCounters[static_cast<int>(ngramType)];
}
- void incrementNgramCount(const size_t n) {
- if (n < 1 || n > mEntryCounters.size()) {
- return;
- }
- ++mEntryCounters[n - 1];
+ void decrementNgramCount(const NgramType ngramType) {
+ --mEntryCounters[static_cast<int>(ngramType)];
}
- void decrementNgramCount(const size_t n) {
- if (n < 1 || n > mEntryCounters.size()) {
- return;
- }
- ASSERT(mEntryCounters[n - 1] != 0);
- --mEntryCounters[n - 1];
+ int getNgramCount(const NgramType ngramType) const {
+ return mEntryCounters[static_cast<int>(ngramType)];
}
- void setNgramCount(const size_t n, const int count) {
- if (n < 1 || n > mEntryCounters.size()) {
- return;
- }
- mEntryCounters[n - 1] = count;
+ void setNgramCount(const NgramType ngramType, const int count) {
+ mEntryCounters[static_cast<int>(ngramType)] = count;
}
private:
DISALLOW_COPY_AND_ASSIGN(MutableEntryCounters);
+ // Counters from Unigram (0-th element) to (MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1)-gram
+ // (MAX_PREV_WORD_COUNT_FOR_N_GRAM-th element)
std::array<int, MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1> mEntryCounters;
};
} // namespace latinime
diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.cpp
index 9055f7bfc..f05c6149e 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.cpp
@@ -126,20 +126,13 @@ const ForgettingCurveUtils::ProbabilityTable ForgettingCurveUtils::sProbabilityT
/* static */ bool ForgettingCurveUtils::needsToDecay(const bool mindsBlockByDecay,
const EntryCounts &entryCounts, const HeaderPolicy *const headerPolicy) {
- if (entryCounts.getUnigramCount()
- >= getEntryCountHardLimit(headerPolicy->getMaxUnigramCount())) {
- // Unigram count exceeds the limit.
- return true;
- }
- if (entryCounts.getBigramCount()
- >= getEntryCountHardLimit(headerPolicy->getMaxBigramCount())) {
- // Bigram count exceeds the limit.
- return true;
- }
- if (entryCounts.getTrigramCount()
- >= getEntryCountHardLimit(headerPolicy->getMaxTrigramCount())) {
- // Trigram count exceeds the limit.
- return true;
+ const EntryCounts &maxNgramCounts = headerPolicy->getMaxNgramCounts();
+ for (const auto ngramType : AllNgramTypes::ASCENDING) {
+ if (entryCounts.getNgramCount(ngramType)
+ >= getEntryCountHardLimit(maxNgramCounts.getNgramCount(ngramType))) {
+ // Unigram count exceeds the limit.
+ return true;
+ }
}
if (mindsBlockByDecay) {
return false;
diff --git a/native/jni/src/utils/ngram_utils.h b/native/jni/src/utils/ngram_utils.h
new file mode 100644
index 000000000..6227812d4
--- /dev/null
+++ b/native/jni/src/utils/ngram_utils.h
@@ -0,0 +1,62 @@
+/*
+ * Copyright (C) 2014, The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LATINIME_NGRAM_UTILS_H
+#define LATINIME_NGRAM_UTILS_H
+
+#include "defines.h"
+
+namespace latinime {
+
+enum class NgramType : int {
+ Unigram = 0,
+ Bigram = 1,
+ Trigram = 2,
+ NotANgramType = -1,
+};
+
+namespace AllNgramTypes {
+// Use anonymous namespace to avoid ODR (One Definition Rule) violation.
+namespace {
+
+const NgramType ASCENDING[] = {
+ NgramType::Unigram, NgramType::Bigram, NgramType::Trigram
+};
+
+const NgramType DESCENDING[] = {
+ NgramType::Trigram, NgramType::Bigram, NgramType::Unigram
+};
+
+} // namespace
+} // namespace AllNgramTypes
+
+class NgramUtils final {
+ public:
+ static AK_FORCE_INLINE NgramType getNgramTypeFromWordCount(const int wordCount) {
+ // Max supported ngram is (MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1)-gram.
+ if (wordCount <= 0 || wordCount > MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1) {
+ return NgramType::NotANgramType;
+ }
+ // Convert word count to 0-origin enum value.
+ return static_cast<NgramType>(wordCount - 1);
+ }
+
+ private:
+ DISALLOW_IMPLICIT_CONSTRUCTORS(NgramUtils);
+
+};
+}
+#endif /* LATINIME_NGRAM_UTILS_H */