From 4d56be920b6212032f5024008a5c2fa5a2944d8a Mon Sep 17 00:00:00 2001
From: yanyiwu
Date: Thu, 8 Oct 2015 20:05:27 +0800
Subject: [PATCH] support optional user word freq weight

---
 ChangeLog.md                       |   3 +-
 src/Application.hpp                |   2 +-
 src/DictTrie.hpp                   | 113 ++++++++++++++---------------
 src/FullSegment.hpp                |   2 +-
 src/Jieba.hpp                      |   2 +-
 src/MPSegment.hpp                  |   8 +-
 src/MixSegment.hpp                 |   4 +-
 src/PosTagger.hpp                  |   2 +-
 src/SegmentBase.hpp                |  16 ----
 src/Trie.hpp                       |   4 +-
 test/demo.cpp                      |   2 +-
 test/unittest/application_test.cpp |   4 +-
 test/unittest/trie_test.cpp        |  28 +++----
 13 files changed, 85 insertions(+), 105 deletions(-)

diff --git a/ChangeLog.md b/ChangeLog.md
index 11b482c..65fd53c 100644
--- a/ChangeLog.md
+++ b/ChangeLog.md
@@ -2,7 +2,8 @@

 ## next version

-1. 支持多个userdict载入,多词典路径用英文冒号(:)作为分隔符,就当坐是向环境变量PATH致敬,哈哈。
+1. 支持多个userdict载入,多词典路径用英文冒号(:)作为分隔符,就当是向环境变量PATH致敬,哈哈。
+2. userdict是不带权重的,之前对于新的userword默认设置词频权重为最大值,现已支持可配置,默认使用中位值。

 ## v3.2.1

diff --git a/src/Application.hpp b/src/Application.hpp
index 1903b1d..d7d8789 100644
--- a/src/Application.hpp
+++ b/src/Application.hpp
@@ -63,7 +63,7 @@ class Application {
                 vector<string>& words,
                 size_t max_word_len) const {
     jieba_.CutSmall(sentence, words, max_word_len);
   }
-  bool insertUserWord(const string& word, const string& tag = UNKNOWN_TAG) {
+  bool InsertUserWord(const string& word, const string& tag = UNKNOWN_TAG) {
     return jieba_.InsertUserWord(word, tag);
   }
   void tag(const string& str, vector<pair<string, string> >& res) const {
diff --git a/src/DictTrie.hpp b/src/DictTrie.hpp
index 9c24069..39b6b8a 100644
--- a/src/DictTrie.hpp
+++ b/src/DictTrie.hpp
@@ -30,25 +30,48 @@ class DictTrie {
     Max,
   }; // enum UserWordWeightOption

-  DictTrie() {
-    trie_ = NULL;
-    min_weight_ = MAX_DOUBLE;
-  }
-  DictTrie(const string& dict_path, const string& user_dict_paths = "") {
-    new (this) DictTrie();
-    init(dict_path, user_dict_paths);
+  DictTrie(const string& dict_path, const string& user_dict_paths = "", UserWordWeightOption user_word_weight_opt = Median) {
+    Init(dict_path, user_dict_paths, user_word_weight_opt);
   }
+
   ~DictTrie() {
     delete trie_;
   }

-  void init(const string& dict_path, const string& user_dict_paths = "") {
-    if (trie_ != NULL) {
-      LogFatal("trie already initted");
+  bool InsertUserWord(const string& word, const string& tag = UNKNOWN_TAG) {
+    DictUnit node_info;
+    if (!MakeNodeInfo(node_info, word, user_word_default_weight_, tag)) {
+      return false;
     }
+    active_node_infos_.push_back(node_info);
+    trie_->insertNode(node_info.word, &active_node_infos_.back());
+    return true;
+  }
+
+  const DictUnit* Find(Unicode::const_iterator begin, Unicode::const_iterator end) const {
+    return trie_->Find(begin, end);
+  }
+
+  void Find(Unicode::const_iterator begin,
+            Unicode::const_iterator end,
+            vector<struct Dag>& res,
+            size_t max_word_len = MAX_WORD_LENGTH) const {
+    trie_->Find(begin, end, res, max_word_len);
+  }
+
+  bool IsUserDictSingleChineseWord(const Rune& word) const {
+    return isIn(user_dict_single_chinese_word_, word);
+  }
+
+  double GetMinWeight() const {
+    return min_weight_;
+  }
+
+ private:
+  void Init(const string& dict_path, const string& user_dict_paths, UserWordWeightOption user_word_weight_opt) {
     LoadDict(dict_path);
     CalculateWeight(static_node_infos_);
-    SetStaticWordWeights();
+    SetStaticWordWeights(user_word_weight_opt);

     if (user_dict_paths.size()) {
       LoadUserDict(user_dict_paths);
     }
     CreateTrie(static_node_infos_);
   }
-
-  bool insertUserWord(const string& word, const string& tag = UNKNOWN_TAG) {
-    DictUnit node_info;
-    if (!MakeUserNodeInfo(node_info, word, tag)) {
-      return false;
-    }
-    active_node_infos_.push_back(node_info);
-    trie_->insertNode(node_info.word, &active_node_infos_.back());
-    return true;
-  }
-
-  const DictUnit* find(Unicode::const_iterator begin, Unicode::const_iterator end) const {
-    return trie_->find(begin, end);
-  }
-
-  void find(Unicode::const_iterator begin,
-            Unicode::const_iterator end,
-            vector<struct Dag>& res,
-            size_t max_word_len = MAX_WORD_LENGTH) const {
-    trie_->find(begin, end, res, max_word_len);
-  }
-
-  bool isUserDictSingleChineseWord(const Rune& word) const {
-    return isIn(user_dict_single_chinese_word_, word);
-  }
-
-  double getMinWeight() const {
-    return min_weight_;
-  }
-
- private:
   void CreateTrie(const vector<DictUnit>& dictUnits) {
     assert(dictUnits.size());
     vector<Unicode> words;
@@ -98,6 +91,7 @@
     trie_ = new Trie(words, valuePointers);
   }
+
   void LoadUserDict(const string& filePaths) {
     vector<string> files = limonp::split(filePaths, ":");
     size_t lineno = 0;
@@ -116,13 +110,19 @@
         LogFatal("split [%s] result illegal", line.c_str());
       }
       DictUnit node_info;
-      MakeUserNodeInfo(node_info, buf[0],
+      MakeNodeInfo(node_info,
+                   buf[0],
+                   user_word_default_weight_,
                    (buf.size() == 2 ? buf[1] : UNKNOWN_TAG));
       static_node_infos_.push_back(node_info);
+      if (node_info.word.size() == 1) {
+        user_dict_single_chinese_word_.insert(node_info.word[0]);
+      }
     }
   }
   LogInfo("load userdicts[%s] ok. lines[%u]", filePaths.c_str(), lineno);
 }
+
 bool MakeNodeInfo(DictUnit& node_info,
                   const string& word,
                   double weight,
@@ -135,20 +135,7 @@
   node_info.tag = tag;
   return true;
 }
-  bool MakeUserNodeInfo(DictUnit& node_info,
-                        const string& word,
-                        const string& tag = UNKNOWN_TAG) {
-    if (!TransCode::decode(word, node_info.word)) {
-      LogError("decode %s failed.", word.c_str());
-      return false;
-    }
-    if (node_info.word.size() == 1) {
-      user_dict_single_chinese_word_.insert(node_info.word[0]);
-    }
-    node_info.weight = max_weight_;
-    node_info.tag = tag;
-    return true;
-  }
+
 void LoadDict(const string& filePath) {
   ifstream ifs(filePath.c_str());
   if (!ifs.is_open()) {
@@ -175,7 +162,7 @@
     return lhs.weight < rhs.weight;
   }

-  void SetStaticWordWeights() {
+  void SetStaticWordWeights(UserWordWeightOption option) {
     if (static_node_infos_.empty()) {
       LogFatal("something must be wrong");
     }
@@ -184,6 +171,17 @@
     min_weight_ = x[0].weight;
     max_weight_ = x[x.size() - 1].weight;
     median_weight_ = x[x.size() / 2].weight;
+    switch (option) {
+      case Min:
+        user_word_default_weight_ = min_weight_;
+        break;
+      case Median:
+        user_word_default_weight_ = median_weight_;
+        break;
+      default:
+        user_word_default_weight_ = max_weight_;
+        break;
+    }
   }

   void CalculateWeight(vector<DictUnit>& node_infos) const {
@@ -210,6 +208,7 @@
   double min_weight_;
   double max_weight_;
   double median_weight_;
+  double user_word_default_weight_;
   unordered_set<Rune> user_dict_single_chinese_word_;
 };
 }
diff --git a/src/FullSegment.hpp b/src/FullSegment.hpp
index 245efeb..e3c2fc1 100644
--- a/src/FullSegment.hpp
+++ b/src/FullSegment.hpp
@@ -54,7 +54,7 @@ class FullSegment: public SegmentBase {
     int wordLen = 0;
     assert(dictTrie_);
     vector<struct Dag> dags;
-    dictTrie_->find(begin, end, dags);
+    dictTrie_->Find(begin, end, dags);
     for (size_t i = 0; i < dags.size(); i++) {
       for (size_t j = 0; j < dags[i].nexts.size(); j++) {
         const DictUnit* du = dags[i].nexts[j].second;
diff --git a/src/Jieba.hpp b/src/Jieba.hpp
index 57aed80..879134b 100644
--- a/src/Jieba.hpp
+++ b/src/Jieba.hpp
@@ -44,7 +44,7 @@ class Jieba {
     mp_seg_.cut(sentence, words, max_word_len);
   }
   bool InsertUserWord(const string& word, const string& tag = UNKNOWN_TAG) {
-    return dict_trie_.insertUserWord(word, tag);
+    return dict_trie_.InsertUserWord(word, tag);
   }

   const DictTrie* GetDictTrie() const {
diff --git a/src/MPSegment.hpp b/src/MPSegment.hpp
index e683095..bcf5c99 100644
--- a/src/MPSegment.hpp
+++ b/src/MPSegment.hpp
@@ -45,7 +45,7 @@ class MPSegment: public SegmentBase {
            vector<Unicode>& words,
            size_t max_word_len = MAX_WORD_LENGTH) const {
     vector<struct Dag> dags;
-    dictTrie_->find(begin,
+    dictTrie_->Find(begin,
                     end,
                     dags,
                     max_word_len);
@@ -57,8 +57,8 @@ class MPSegment: public SegmentBase {
     return dictTrie_;
   }

-  bool isUserDictSingleChineseWord(const Rune & value) const {
-    return dictTrie_->isUserDictSingleChineseWord(value);
+  bool IsUserDictSingleChineseWord(const Rune& value) const {
+    return dictTrie_->IsUserDictSingleChineseWord(value);
   }
 private:
   void CalcDP(vector<struct Dag>& dags) const {
@@ -81,7 +81,7 @@ class MPSegment: public SegmentBase {
         if (p) {
           val += p->weight;
         } else {
-          val += dictTrie_->getMinWeight();
+          val += dictTrie_->GetMinWeight();
         }
         if (val > rit->weight) {
           rit->pInfo = p;
diff --git a/src/MixSegment.hpp b/src/MixSegment.hpp
index 7b72289..8d8a8b6 100644
--- a/src/MixSegment.hpp
+++ b/src/MixSegment.hpp
@@ -48,14 +48,14 @@ class MixSegment: public SegmentBase {
     piece.reserve(end - begin);
     for (size_t i = 0, j = 0; i < words.size(); i++) {
       //if mp get a word, it's ok, put it into result
-      if (1 != words[i].size() || (words[i].size() == 1 && mpSeg_.isUserDictSingleChineseWord(words[i][0]))) {
+      if (1 != words[i].size() || (words[i].size() == 1 && mpSeg_.IsUserDictSingleChineseWord(words[i][0]))) {
        res.push_back(words[i]);
        continue;
      }

      // if mp get a single one and it is not in userdict, collect it in sequence
      j = i;
-      while (j < words.size() && 1 == words[j].size() && !mpSeg_.isUserDictSingleChineseWord(words[j][0])) {
+      while (j < words.size() && 1 == words[j].size() && !mpSeg_.IsUserDictSingleChineseWord(words[j][0])) {
        piece.push_back(words[j][0]);
        j++;
      }
diff --git a/src/PosTagger.hpp b/src/PosTagger.hpp
index eb33410..076d9f3 100644
--- a/src/PosTagger.hpp
+++ b/src/PosTagger.hpp
@@ -38,7 +38,7 @@ class PosTagger {
       LogError("decode failed.");
       return false;
     }
-    tmp = dict->find(unico.begin(), unico.end());
+    tmp = dict->Find(unico.begin(), unico.end());
     if (tmp == NULL || tmp->tag.empty()) {
       res.push_back(make_pair(*itr, SpecialRule(unico)));
     } else {
diff --git a/src/SegmentBase.hpp b/src/SegmentBase.hpp
index 4fcbcad..f1ece55 100644
--- a/src/SegmentBase.hpp
+++ b/src/SegmentBase.hpp
@@ -20,22 +20,6 @@ class SegmentBase {
   }
   ~SegmentBase() {
   }
-  /*
-  public:
-    void cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector<Unicode>& res) const = 0;
-    bool cut(const string& sentence, vector<string>& words) const {
-      PreFilter pre_filter(symbols_, sentence);
-      PreFilter::Range range;
-      vector<Unicode> uwords;
-      uwords.reserve(sentence.size());
-      while (pre_filter.HasNext()) {
-        range = pre_filter.Next();
-        cut(range.begin, range.end, uwords);
-      }
-      TransCode::encode(uwords, words);
-      return true;
-    }
-  */

 protected:
   void LoadSpecialSymbols() {
diff --git a/src/Trie.hpp b/src/Trie.hpp
index 6052232..8ab502b 100644
--- a/src/Trie.hpp
+++ b/src/Trie.hpp
@@ -65,7 +65,7 @@ class Trie {
     }
   }

-  const DictUnit* find(Unicode::const_iterator begin, Unicode::const_iterator end) const {
+  const DictUnit* Find(Unicode::const_iterator begin, Unicode::const_iterator end) const {
     if (begin == end) {
       return NULL;
     }
@@ -85,7 +85,7 @@ class Trie {
     return ptNode->ptValue;
   }

-  void find(Unicode::const_iterator begin,
+  void Find(Unicode::const_iterator begin,
             Unicode::const_iterator end,
             vector<struct Dag>& res,
             size_t max_word_len = MAX_WORD_LENGTH) const {
diff --git a/test/demo.cpp b/test/demo.cpp
index db0dcbc..7464bae 100644
--- a/test/demo.cpp
+++ b/test/demo.cpp
@@ -51,7 +51,7 @@ int main(int argc, char** argv) {
   cout << "[demo] Insert User Word" << endl;
   app.cut("男默女泪", words);
   cout << join(words.begin(), words.end(), "/") << endl;
-  app.insertUserWord("男默女泪");
+  app.InsertUserWord("男默女泪");
   app.cut("男默女泪", words);
   cout << join(words.begin(), words.end(), "/") << endl;

diff --git a/test/unittest/application_test.cpp b/test/unittest/application_test.cpp
index f688037..bd34d8c 100644
--- a/test/unittest/application_test.cpp
+++ b/test/unittest/application_test.cpp
@@ -76,7 +76,7 @@ TEST(ApplicationTest, InsertUserWord) {
   result << words;
   ASSERT_EQ("[\"男默\", \"女泪\"]", result);

-  ASSERT_TRUE(app.insertUserWord("男默女泪"));
+  ASSERT_TRUE(app.InsertUserWord("男默女泪"));

   app.cut("男默女泪", words);
   result << words;
@@ -85,7 +85,7 @@ TEST(ApplicationTest, InsertUserWord) {
   for (size_t i = 0; i < 100; i++) {
     string newWord;
     newWord << rand();
-    ASSERT_TRUE(app.insertUserWord(newWord));
+    ASSERT_TRUE(app.InsertUserWord(newWord));
     app.cut(newWord, words);
     result << words;
     ASSERT_EQ(result, string_format("[\"%s\"]", newWord.c_str()));
diff --git a/test/unittest/trie_test.cpp b/test/unittest/trie_test.cpp
index 17de031..035be26 100644
--- a/test/unittest/trie_test.cpp
+++ b/test/unittest/trie_test.cpp
@@ -24,16 +24,12 @@ TEST(DictTrieTest, NewAndDelete) {
   DictTrie * trie;
   trie = new DictTrie(DICT_FILE);
   delete trie;
-  trie = new DictTrie();
-  delete trie;
 }
-
 TEST(DictTrieTest, Test1) {
   string s1, s2;
-  DictTrie trie;
-  trie.init(DICT_FILE);
-  ASSERT_LT(trie.getMinWeight() + 15.6479, 0.001);
+  DictTrie trie(DICT_FILE);
+  ASSERT_LT(trie.GetMinWeight() + 15.6479, 0.001);
   string word("来到");
   Unicode uni;
   ASSERT_TRUE(TransCode::decode(word, uni));
@@ -42,7 +38,7 @@ TEST(DictTrieTest, Test1) {
   nodeInfo.tag = "v";
   nodeInfo.weight = -8.87033;
   s1 << nodeInfo;
-  s2 << (*trie.find(uni.begin(), uni.end()));
+  s2 << (*trie.Find(uni.begin(), uni.end()));
   EXPECT_EQ("[\"26469\", \"21040\"] v -8.870", s2);

   word = "清华大学";
@@ -50,13 +46,13 @@ TEST(DictTrieTest, Test1) {
   const char * words[] = {"清", "清华", "清华大学"};
   for (size_t i = 0; i < sizeof(words)/sizeof(words[0]); i++) {
     ASSERT_TRUE(TransCode::decode(words[i], uni));
-    res.push_back(make_pair(uni.size() - 1, trie.find(uni.begin(), uni.end())));
-    //resMap[uni.size() - 1] = trie.find(uni.begin(), uni.end());
+    res.push_back(make_pair(uni.size() - 1, trie.Find(uni.begin(), uni.end())));
+    //resMap[uni.size() - 1] = trie.Find(uni.begin(), uni.end());
   }
   vector<pair<size_t, const DictUnit*> > vec;
   vector<struct Dag> dags;
   ASSERT_TRUE(TransCode::decode(word, uni));
-  trie.find(uni.begin(), uni.end(), dags);
+  trie.Find(uni.begin(), uni.end(), dags);
   ASSERT_EQ(dags.size(), uni.size());
   ASSERT_NE(dags.size(), 0u);
   s1 << res;
@@ -70,7 +66,7 @@ TEST(DictTrieTest, UserDict) {
   string word = "云计算";
   Unicode unicode;
   ASSERT_TRUE(TransCode::decode(word, unicode));
-  const DictUnit * unit = trie.find(unicode.begin(), unicode.end());
+  const DictUnit * unit = trie.Find(unicode.begin(), unicode.end());
   ASSERT_TRUE(unit);
   string res ;
   res << *unit;
@@ -85,7 +81,7 @@ TEST(DictTrieTest, Dag) {
     Unicode unicode;
     ASSERT_TRUE(TransCode::decode(word, unicode));
     vector<struct Dag> res;
-    trie.find(unicode.begin(), unicode.end(), res);
+    trie.Find(unicode.begin(), unicode.end(), res);
     size_t nexts_sizes[] = {3, 2, 2, 1};
     ASSERT_EQ(res.size(), sizeof(nexts_sizes)/sizeof(nexts_sizes[0]));
@@ -99,7 +95,7 @@ TEST(DictTrieTest, Dag) {
     Unicode unicode;
     ASSERT_TRUE(TransCode::decode(word, unicode));
     vector<struct Dag> res;
-    trie.find(unicode.begin(), unicode.end(), res);
+    trie.Find(unicode.begin(), unicode.end(), res);
     size_t nexts_sizes[] = {3, 1, 2, 2, 2, 1};
     ASSERT_EQ(res.size(), sizeof(nexts_sizes)/sizeof(nexts_sizes[0]));
@@ -113,7 +109,7 @@ TEST(DictTrieTest, Dag) {
     Unicode unicode;
     ASSERT_TRUE(TransCode::decode(word, unicode));
     vector<struct Dag> res;
-    trie.find(unicode.begin(), unicode.end(), res);
+    trie.Find(unicode.begin(), unicode.end(), res);
     size_t nexts_sizes[] = {3, 1, 2, 1};
     ASSERT_EQ(res.size(), sizeof(nexts_sizes)/sizeof(nexts_sizes[0]));
@@ -127,7 +123,7 @@ TEST(DictTrieTest, Dag) {
     Unicode unicode;
     ASSERT_TRUE(TransCode::decode(word, unicode));
     vector<struct Dag> res;
-    trie.find(unicode.begin(), unicode.end(), res, 3);
+    trie.Find(unicode.begin(), unicode.end(), res, 3);
     size_t nexts_sizes[] = {2, 1, 2, 1};
     ASSERT_EQ(res.size(), sizeof(nexts_sizes)/sizeof(nexts_sizes[0]));
@@ -141,7 +137,7 @@ TEST(DictTrieTest, Dag) {
     Unicode unicode;
     ASSERT_TRUE(TransCode::decode(word, unicode));
     vector<struct Dag> res;
-    trie.find(unicode.begin(), unicode.end(), res, 4);
+    trie.Find(unicode.begin(), unicode.end(), res, 4);
     size_t nexts_sizes[] = {3, 1, 2, 1};
     ASSERT_EQ(res.size(), sizeof(nexts_sizes)/sizeof(nexts_sizes[0]));
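
Usage note (not part of the patch): a minimal sketch of how the new weight option might be exercised from calling code. Only the DictTrie constructor signature, the UserWordWeightOption values (Min/Median/Max) and InsertUserWord() come from the diff above; the include path, dictionary paths and the namespace handling are illustrative assumptions.

    #include <cassert>
    #include "DictTrie.hpp"   // assumed include path; adjust to the project layout

    // Assumes the project's namespace (if any) is already in scope and that the
    // nested enum is publicly accessible as DictTrie::Median.
    void InsertUserWordExample() {
      // The third argument is the new UserWordWeightOption. Median is already the
      // default, it is spelled out here only to show where the knob sits; Min and
      // Max pick the smallest/largest static dictionary weight, as the switch in
      // SetStaticWordWeights() shows.
      DictTrie dict_trie("dict/jieba.dict.utf8",   // hypothetical main dictionary
                         "dict/userdict.utf8",     // hypothetical user dict(s), ':'-separated
                         DictTrie::Median);

      // A word inserted at runtime now gets the configurable default weight
      // (the median of the static weights) rather than the former maximum.
      bool inserted = dict_trie.InsertUserWord("男默女泪");
      assert(inserted);
    }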