summary | refs | log | tree | commit | diff | stats
path: root/native
diff options
context:
space:
mode:
author: Keisuke Kuroyanagi <ksk@google.com> 2014-10-30 23:38:19 +0900
committer: Keisuke Kuroyanagi <ksk@google.com> 2014-10-30 23:38:19 +0900
commit: bcb52d73e206cee86a2ea126a5c3f948103057c6 (patch)
tree: 5149ec9918ec829b898b1285ece4eba300da0d3f /native
parent: 660b00477c980d74be48529b9de70d9725ffc72b (diff)
download:
android_packages_inputmethods_LatinIME-bcb52d73e206cee86a2ea126a5c3f948103057c6.tar.gz
android_packages_inputmethods_LatinIME-bcb52d73e206cee86a2ea126a5c3f948103057c6.tar.bz2
android_packages_inputmethods_LatinIME-bcb52d73e206cee86a2ea126a5c3f948103057c6.zip
Enable count based dynamic ngram language model for v403.

Bug: 14425059
Change-Id: Icc15e14cfd77d37cd75f75318fd0fa36f9ca7a5b
Diffstat (limited to 'native')
-rw-r--r-- native/jni/src/suggest/core/dictionary/dictionary.cpp | 3
-rw-r--r-- native/jni/src/suggest/core/dictionary/ngram_listener.h | 2
-rw-r--r-- native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp | 127
-rw-r--r-- native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h | 26
-rw-r--r-- native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_global_counters.h | 4
-rw-r--r-- native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp | 55
6 files changed, 128 insertions, 89 deletions
diff --git a/native/jni/src/suggest/core/dictionary/dictionary.cpp b/native/jni/src/suggest/core/dictionary/dictionary.cpp
index bfe17cc4c..6a5df9d95 100644
--- a/native/jni/src/suggest/core/dictionary/dictionary.cpp
+++ b/native/jni/src/suggest/core/dictionary/dictionary.cpp
@@ -81,6 +81,9 @@ void Dictionary::NgramListenerForPrediction::onVisitEntry(const int ngramProbabi
}
const WordAttributes wordAttributes = mDictStructurePolicy->getWordAttributesInContext(
mPrevWordIds, targetWordId, nullptr /* multiBigramMap */);
+ if (wordAttributes.getProbability() == NOT_A_PROBABILITY) {
+ return;
+ }
mSuggestionResults->addPrediction(targetWordCodePoints, codePointCount,
wordAttributes.getProbability());
}
diff --git a/native/jni/src/suggest/core/dictionary/ngram_listener.h b/native/jni/src/suggest/core/dictionary/ngram_listener.h
index e9b3c1aaf..2eb5e9fd1 100644
--- a/native/jni/src/suggest/core/dictionary/ngram_listener.h
+++ b/native/jni/src/suggest/core/dictionary/ngram_listener.h
@@ -26,6 +26,8 @@ namespace latinime {
*/
class NgramListener {
public:
+ // ngramProbability is always 0 for v403 decaying dictionary.
+ // TODO: Remove ngramProbability.
virtual void onVisitEntry(const int ngramProbability, const int targetWordId) = 0;
virtual ~NgramListener() {};
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
index 05a3a6356..31b1ea696 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
@@ -19,11 +19,11 @@
#include <algorithm>
#include <cstring>
-#include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h"
+#include "suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.h"
+#include "suggest/policyimpl/dictionary/utils/probability_utils.h"
namespace latinime {
-const int LanguageModelDictContent::DUMMY_PROBABILITY_FOR_VALID_WORDS = 1;
const int LanguageModelDictContent::TRIE_MAP_BUFFER_INDEX = 0;
const int LanguageModelDictContent::GLOBAL_COUNTERS_BUFFER_INDEX = 1;
@@ -39,7 +39,8 @@ bool LanguageModelDictContent::runGC(
}
const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArrayView prevWordIds,
- const int wordId, const HeaderPolicy *const headerPolicy) const {
+ const int wordId, const bool mustMatchAllPrevWords,
+ const HeaderPolicy *const headerPolicy) const {
int bitmapEntryIndices[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
bitmapEntryIndices[0] = mTrieMap.getRootBitmapEntryIndex();
int maxPrevWordCount = 0;
@@ -53,7 +54,15 @@ const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArr
bitmapEntryIndices[i + 1] = nextBitmapEntryIndex;
}
+ const ProbabilityEntry unigramProbabilityEntry = getProbabilityEntry(wordId);
+ if (mHasHistoricalInfo && unigramProbabilityEntry.getHistoricalInfo()->getCount() == 0) {
+ // The word should be treated as a invalid word.
+ return WordAttributes();
+ }
for (int i = maxPrevWordCount; i >= 0; --i) {
+ if (mustMatchAllPrevWords && prevWordIds.size() > static_cast<size_t>(i)) {
+ break;
+ }
const TrieMap::Result result = mTrieMap.get(wordId, bitmapEntryIndices[i]);
if (!result.mIsValid) {
continue;
@@ -62,36 +71,39 @@ const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArr
ProbabilityEntry::decode(result.mValue, mHasHistoricalInfo);
int probability = NOT_A_PROBABILITY;
if (mHasHistoricalInfo) {
- const int rawProbability = ForgettingCurveUtils::decodeProbability(
- probabilityEntry.getHistoricalInfo(), headerPolicy);
- if (rawProbability == NOT_A_PROBABILITY) {
- // The entry should not be treated as a valid entry.
- continue;
- }
+ const HistoricalInfo *const historicalInfo = probabilityEntry.getHistoricalInfo();
+ int contextCount = 0;
if (i == 0) {
// unigram
- probability = rawProbability;
+ contextCount = mGlobalCounters.getTotalCount();
} else {
const ProbabilityEntry prevWordProbabilityEntry = getNgramProbabilityEntry(
prevWordIds.skip(1 /* n */).limit(i - 1), prevWordIds[0]);
if (!prevWordProbabilityEntry.isValid()) {
continue;
}
- if (prevWordProbabilityEntry.representsBeginningOfSentence()) {
- probability = rawProbability;
- } else {
- const int prevWordRawProbability = ForgettingCurveUtils::decodeProbability(
- prevWordProbabilityEntry.getHistoricalInfo(), headerPolicy);
- probability = std::min(MAX_PROBABILITY - prevWordRawProbability
- + rawProbability, MAX_PROBABILITY);
+ if (prevWordProbabilityEntry.representsBeginningOfSentence()
+ && historicalInfo->getCount() == 1) {
+ // BoS ngram requires multiple contextCount.
+ continue;
}
+ contextCount = prevWordProbabilityEntry.getHistoricalInfo()->getCount();
}
+ const float rawProbability =
+ DynamicLanguageModelProbabilityUtils::computeRawProbabilityFromCounts(
+ historicalInfo->getCount(), contextCount, i + 1);
+ const int encodedRawProbability =
+ ProbabilityUtils::encodeRawProbability(rawProbability);
+ const int decayedProbability =
+ DynamicLanguageModelProbabilityUtils::getDecayedProbability(
+ encodedRawProbability, *historicalInfo);
+ probability = DynamicLanguageModelProbabilityUtils::backoff(
+ decayedProbability, i + 1 /* n */);
} else {
probability = probabilityEntry.getProbability();
}
// TODO: Some flags in unigramProbabilityEntry should be overwritten by flags in
// probabilityEntry.
- const ProbabilityEntry unigramProbabilityEntry = getProbabilityEntry(wordId);
return WordAttributes(probability, unigramProbabilityEntry.isBlacklisted(),
unigramProbabilityEntry.isNotAWord(),
unigramProbabilityEntry.isPossiblyOffensive());
@@ -167,7 +179,8 @@ void LanguageModelDictContent::exportAllNgramEntriesRelatedToWordInner(
ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo);
if (probabilityEntry.isValid()) {
const WordAttributes wordAttributes = getWordAttributes(
- WordIdArrayView(*prevWordIds), wordId, headerPolicy);
+ WordIdArrayView(*prevWordIds), wordId, true /* mustMatchAllPrevWords */,
+ headerPolicy);
outBummpedFullEntryInfo->emplace_back(*prevWordIds, wordId,
wordAttributes, probabilityEntry);
}
@@ -231,7 +244,7 @@ bool LanguageModelDictContent::updateAllEntriesOnInputWord(const WordIdArrayView
return false;
}
mGlobalCounters.updateMaxValueOfCounters(
- updatedUnigramProbabilityEntry.getHistoricalInfo()->getCount());
+ updatedNgramProbabilityEntry.getHistoricalInfo()->getCount());
if (!originalNgramProbabilityEntry.isValid()) {
entryCountersToUpdate->incrementNgramCount(i + 2);
}
@@ -242,10 +255,9 @@ bool LanguageModelDictContent::updateAllEntriesOnInputWord(const WordIdArrayView
const ProbabilityEntry LanguageModelDictContent::createUpdatedEntryFrom(
const ProbabilityEntry &originalProbabilityEntry, const bool isValid,
const HistoricalInfo historicalInfo, const HeaderPolicy *const headerPolicy) const {
- const HistoricalInfo updatedHistoricalInfo = ForgettingCurveUtils::createUpdatedHistoricalInfo(
- originalProbabilityEntry.getHistoricalInfo(), isValid ?
- DUMMY_PROBABILITY_FOR_VALID_WORDS : NOT_A_PROBABILITY,
- &historicalInfo, headerPolicy);
+ const HistoricalInfo updatedHistoricalInfo = HistoricalInfo(historicalInfo.getTimestamp(),
+ 0 /* level */, originalProbabilityEntry.getHistoricalInfo()->getCount()
+ + historicalInfo.getCount());
if (originalProbabilityEntry.isValid()) {
return ProbabilityEntry(originalProbabilityEntry.getFlags(), &updatedHistoricalInfo);
} else {
@@ -311,7 +323,7 @@ int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWord
bool LanguageModelDictContent::updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex,
const int prevWordCount, const HeaderPolicy *const headerPolicy,
- MutableEntryCounters *const outEntryCounters) {
+ const bool needsToHalveCounters, MutableEntryCounters *const outEntryCounters) {
for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) {
if (prevWordCount > MAX_PREV_WORD_COUNT_FOR_N_GRAM) {
AKLOGE("Invalid prevWordCount. prevWordCount: %d, MAX_PREV_WORD_COUNT_FOR_N_GRAM: %d.",
@@ -328,33 +340,41 @@ bool LanguageModelDictContent::updateAllProbabilityEntriesForGCInner(const int b
}
continue;
}
- if (mHasHistoricalInfo && !probabilityEntry.representsBeginningOfSentence()
- && probabilityEntry.isValid()) {
- const HistoricalInfo historicalInfo = ForgettingCurveUtils::createHistoricalInfoToSave(
- probabilityEntry.getHistoricalInfo(), headerPolicy);
- if (ForgettingCurveUtils::needsToKeep(&historicalInfo, headerPolicy)) {
- // Update the entry.
- const ProbabilityEntry updatedEntry(probabilityEntry.getFlags(), &historicalInfo);
- if (!mTrieMap.put(entry.key(), updatedEntry.encode(mHasHistoricalInfo),
- bitmapEntryIndex)) {
- return false;
- }
- } else {
+ if (mHasHistoricalInfo && probabilityEntry.isValid()) {
+ const HistoricalInfo *originalHistoricalInfo = probabilityEntry.getHistoricalInfo();
+ if (DynamicLanguageModelProbabilityUtils::shouldRemoveEntryDuringGC(
+ *originalHistoricalInfo)) {
// Remove the entry.
if (!mTrieMap.remove(entry.key(), bitmapEntryIndex)) {
return false;
}
continue;
}
+ if (needsToHalveCounters) {
+ const int updatedCount = originalHistoricalInfo->getCount() / 2;
+ if (updatedCount == 0) {
+ // Remove the entry.
+ if (!mTrieMap.remove(entry.key(), bitmapEntryIndex)) {
+ return false;
+ }
+ continue;
+ }
+ const HistoricalInfo historicalInfoToSave(originalHistoricalInfo->getTimestamp(),
+ originalHistoricalInfo->getLevel(), updatedCount);
+ const ProbabilityEntry updatedEntry(probabilityEntry.getFlags(),
+ &historicalInfoToSave);
+ if (!mTrieMap.put(entry.key(), updatedEntry.encode(mHasHistoricalInfo),
+ bitmapEntryIndex)) {
+ return false;
+ }
+ }
}
- if (!probabilityEntry.representsBeginningOfSentence()) {
- outEntryCounters->incrementNgramCount(prevWordCount + 1);
- }
+ outEntryCounters->incrementNgramCount(prevWordCount + 1);
if (!entry.hasNextLevelMap()) {
continue;
}
if (!updateAllProbabilityEntriesForGCInner(entry.getNextLevelBitmapEntryIndex(),
- prevWordCount + 1, headerPolicy, outEntryCounters)) {
+ prevWordCount + 1, headerPolicy, needsToHalveCounters, outEntryCounters)) {
return false;
}
}
@@ -408,11 +428,11 @@ bool LanguageModelDictContent::getEntryInfo(const HeaderPolicy *const headerPoli
}
const ProbabilityEntry probabilityEntry =
ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo);
- const int probability = (mHasHistoricalInfo) ?
- ForgettingCurveUtils::decodeProbability(probabilityEntry.getHistoricalInfo(),
- headerPolicy) : probabilityEntry.getProbability();
- outEntryInfo->emplace_back(probability,
- probabilityEntry.getHistoricalInfo()->getTimestamp(),
+ const int priority = mHasHistoricalInfo
+ ? DynamicLanguageModelProbabilityUtils::getPriorityToPreventFromEviction(
+ *probabilityEntry.getHistoricalInfo())
+ : probabilityEntry.getProbability();
+ outEntryInfo->emplace_back(priority, probabilityEntry.getHistoricalInfo()->getCount(),
entry.key(), targetLevel, prevWordIds->data());
}
return true;
@@ -420,11 +440,11 @@ bool LanguageModelDictContent::getEntryInfo(const HeaderPolicy *const headerPoli
bool LanguageModelDictContent::EntryInfoToTurncate::Comparator::operator()(
const EntryInfoToTurncate &left, const EntryInfoToTurncate &right) const {
- if (left.mProbability != right.mProbability) {
- return left.mProbability < right.mProbability;
+ if (left.mPriority != right.mPriority) {
+ return left.mPriority < right.mPriority;
}
- if (left.mTimestamp != right.mTimestamp) {
- return left.mTimestamp > right.mTimestamp;
+ if (left.mCount != right.mCount) {
+ return left.mCount < right.mCount;
}
if (left.mKey != right.mKey) {
return left.mKey < right.mKey;
@@ -441,10 +461,9 @@ bool LanguageModelDictContent::EntryInfoToTurncate::Comparator::operator()(
return false;
}
-LanguageModelDictContent::EntryInfoToTurncate::EntryInfoToTurncate(const int probability,
- const int timestamp, const int key, const int prevWordCount, const int *const prevWordIds)
- : mProbability(probability), mTimestamp(timestamp), mKey(key),
- mPrevWordCount(prevWordCount) {
+LanguageModelDictContent::EntryInfoToTurncate::EntryInfoToTurncate(const int priority,
+ const int count, const int key, const int prevWordCount, const int *const prevWordIds)
+ : mPriority(priority), mCount(count), mKey(key), mPrevWordCount(prevWordCount) {
memmove(mPrevWordIds, prevWordIds, mPrevWordCount * sizeof(mPrevWordIds[0]));
}
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
index 5b92b96e3..9678c35f9 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
@@ -151,13 +151,14 @@ class LanguageModelDictContent {
const LanguageModelDictContent *const originalContent);
const WordAttributes getWordAttributes(const WordIdArrayView prevWordIds, const int wordId,
- const HeaderPolicy *const headerPolicy) const;
+ const bool mustMatchAllPrevWords, const HeaderPolicy *const headerPolicy) const;
ProbabilityEntry getProbabilityEntry(const int wordId) const {
return getNgramProbabilityEntry(WordIdArrayView(), wordId);
}
bool setProbabilityEntry(const int wordId, const ProbabilityEntry *const probabilityEntry) {
+ mGlobalCounters.addToTotalCount(probabilityEntry->getHistoricalInfo()->getCount());
return setNgramProbabilityEntry(WordIdArrayView(), wordId, probabilityEntry);
}
@@ -180,8 +181,15 @@ class LanguageModelDictContent {
bool updateAllProbabilityEntriesForGC(const HeaderPolicy *const headerPolicy,
MutableEntryCounters *const outEntryCounters) {
- return updateAllProbabilityEntriesForGCInner(mTrieMap.getRootBitmapEntryIndex(),
- 0 /* prevWordCount */, headerPolicy, outEntryCounters);
+ if (!updateAllProbabilityEntriesForGCInner(mTrieMap.getRootBitmapEntryIndex(),
+ 0 /* prevWordCount */, headerPolicy, mGlobalCounters.needsToHalveCounters(),
+ outEntryCounters)) {
+ return false;
+ }
+ if (mGlobalCounters.needsToHalveCounters()) {
+ mGlobalCounters.halveCounters();
+ }
+ return true;
}
// entryCounts should be created by updateAllProbabilityEntries.
@@ -206,11 +214,12 @@ class LanguageModelDictContent {
DISALLOW_ASSIGNMENT_OPERATOR(Comparator);
};
- EntryInfoToTurncate(const int probability, const int timestamp, const int key,
+ EntryInfoToTurncate(const int priority, const int count, const int key,
const int prevWordCount, const int *const prevWordIds);
- int mProbability;
- int mTimestamp;
+ int mPriority;
+ // TODO: Remove.
+ int mCount;
int mKey;
int mPrevWordCount;
int mPrevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
@@ -219,8 +228,6 @@ class LanguageModelDictContent {
DISALLOW_DEFAULT_CONSTRUCTOR(EntryInfoToTurncate);
};
- // TODO: Remove
- static const int DUMMY_PROBABILITY_FOR_VALID_WORDS;
static const int TRIE_MAP_BUFFER_INDEX;
static const int GLOBAL_COUNTERS_BUFFER_INDEX;
@@ -233,7 +240,8 @@ class LanguageModelDictContent {
int createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds);
int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const;
bool updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex, const int prevWordCount,
- const HeaderPolicy *const headerPolicy, MutableEntryCounters *const outEntryCounters);
+ const HeaderPolicy *const headerPolicy, const bool needsToHalveCounters,
+ MutableEntryCounters *const outEntryCounters);
bool turncateEntriesInSpecifiedLevel(const HeaderPolicy *const headerPolicy,
const int maxEntryCount, const int targetLevel, int *const outEntryCount);
bool getEntryInfo(const HeaderPolicy *const headerPolicy, const int targetLevel,
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_global_counters.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_global_counters.h
index 9953aa425..283c2691a 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_global_counters.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_global_counters.h
@@ -63,6 +63,10 @@ class LanguageModelDictContentGlobalCounters {
mTotalCount += 1;
}
+ void addToTotalCount(const int count) {
+ mTotalCount += count;
+ }
+
void updateMaxValueOfCounters(const int count) {
mMaxValueOfCounters = std::max(count, mMaxValueOfCounters);
}
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
index d3de322f9..96d789f58 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
@@ -110,7 +110,7 @@ const WordAttributes Ver4PatriciaTriePolicy::getWordAttributesInContext(
return WordAttributes();
}
return mBuffers->getLanguageModelDictContent()->getWordAttributes(prevWordIds, wordId,
- mHeaderPolicy);
+ false /* mustMatchAllPrevWords */, mHeaderPolicy);
}
int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordIds,
@@ -118,18 +118,13 @@ int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordI
if (wordId == NOT_A_WORD_ID || prevWordIds.contains(NOT_A_WORD_ID)) {
return NOT_A_PROBABILITY;
}
- const ProbabilityEntry probabilityEntry =
- mBuffers->getLanguageModelDictContent()->getNgramProbabilityEntry(prevWordIds, wordId);
- if (!probabilityEntry.isValid() || probabilityEntry.isBlacklisted()
- || probabilityEntry.isNotAWord()) {
+ const WordAttributes wordAttributes =
+ mBuffers->getLanguageModelDictContent()->getWordAttributes(prevWordIds, wordId,
+ true /* mustMatchAllPrevWords */, mHeaderPolicy);
+ if (wordAttributes.isBlacklisted() || wordAttributes.isNotAWord()) {
return NOT_A_PROBABILITY;
}
- if (mHeaderPolicy->hasHistoricalInfoOfWords()) {
- return ForgettingCurveUtils::decodeProbability(probabilityEntry.getHistoricalInfo(),
- mHeaderPolicy);
- } else {
- return probabilityEntry.getProbability();
- }
+ return wordAttributes.getProbability();
}
BinaryDictionaryShortcutIterator Ver4PatriciaTriePolicy::getShortcutIterator(
@@ -152,9 +147,7 @@ void Ver4PatriciaTriePolicy::iterateNgramEntries(const WordIdArrayView prevWordI
continue;
}
const int probability = probabilityEntry.hasHistoricalInfo() ?
- ForgettingCurveUtils::decodeProbability(
- probabilityEntry.getHistoricalInfo(), mHeaderPolicy) :
- probabilityEntry.getProbability();
+ 0 : probabilityEntry.getProbability();
listener->onVisitEntry(probability, entry.getWordId());
}
}
@@ -386,25 +379,35 @@ bool Ver4PatriciaTriePolicy::updateEntriesForWordWithNgramContext(
AKLOGE("Cannot add unigarm entry in updateEntriesForWordWithNgramContext().");
return false;
}
+ if (!isValidWord) {
+ return true;
+ }
wordId = getWordId(wordCodePoints, false /* tryLowerCaseSearch */);
}
WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray;
const WordIdArrayView prevWordIds = ngramContext->getPrevWordIds(this, &prevWordIdArray,
false /* tryLowerCaseSearch */);
- if (prevWordIds.firstOrDefault(NOT_A_WORD_ID) == NOT_A_WORD_ID
- && ngramContext->isNthPrevWordBeginningOfSentence(1 /* n */)) {
- const UnigramProperty beginningOfSentenceUnigramProperty(
- true /* representsBeginningOfSentence */,
- true /* isNotAWord */, false /* isPossiblyOffensive */, NOT_A_PROBABILITY,
- HistoricalInfo(historicalInfo.getTimestamp(), 0 /* level */, 0 /* count */));
- if (!addUnigramEntry(ngramContext->getNthPrevWordCodePoints(1 /* n */),
- &beginningOfSentenceUnigramProperty)) {
- AKLOGE("Cannot add BoS entry in updateEntriesForWordWithNgramContext().");
+ if (ngramContext->isNthPrevWordBeginningOfSentence(1 /* n */)) {
+ if (prevWordIds.firstOrDefault(NOT_A_WORD_ID) == NOT_A_WORD_ID) {
+ const UnigramProperty beginningOfSentenceUnigramProperty(
+ true /* representsBeginningOfSentence */,
+ true /* isNotAWord */, false /* isPossiblyOffensive */, NOT_A_PROBABILITY,
+ HistoricalInfo(historicalInfo.getTimestamp(), 0 /* level */, 0 /* count */));
+ if (!addUnigramEntry(ngramContext->getNthPrevWordCodePoints(1 /* n */),
+ &beginningOfSentenceUnigramProperty)) {
+ AKLOGE("Cannot add BoS entry in updateEntriesForWordWithNgramContext().");
+ return false;
+ }
+ // Refresh word ids.
+ ngramContext->getPrevWordIds(this, &prevWordIdArray, false /* tryLowerCaseSearch */);
+ }
+ // Update entries for beginning of sentence.
+ if (!mBuffers->getMutableLanguageModelDictContent()->updateAllEntriesOnInputWord(
+ prevWordIds.skip(1 /* n */), prevWordIds[0], true /* isVaild */, historicalInfo,
+ mHeaderPolicy, &mEntryCounters)) {
return false;
}
- // Refresh word ids.
- ngramContext->getPrevWordIds(this, &prevWordIdArray, false /* tryLowerCaseSearch */);
}
if (!mBuffers->getMutableLanguageModelDictContent()->updateAllEntriesOnInputWord(prevWordIds,
wordId, updateAsAValidWord, historicalInfo, mHeaderPolicy, &mEntryCounters)) {
@@ -542,7 +545,7 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(
}
}
const WordAttributes wordAttributes = languageModelDictContent->getWordAttributes(
- WordIdArrayView(), wordId, mHeaderPolicy);
+ WordIdArrayView(), wordId, true /* mustMatchAllPrevWords */, mHeaderPolicy);
const ProbabilityEntry probabilityEntry = languageModelDictContent->getProbabilityEntry(wordId);
const HistoricalInfo *const historicalInfo = probabilityEntry.getHistoricalInfo();
const UnigramProperty unigramProperty(probabilityEntry.representsBeginningOfSentence(),