From 0984c9ed3fe55dc8fa91b17d70fd931ef11d0970 Mon Sep 17 00:00:00 2001 From: yanyiwu Date: Fri, 22 Jul 2016 23:53:49 +0800 Subject: [PATCH] update user dict loading method about word weight, and add unit tests --- dict/user.dict.utf8 | 1 + include/cppjieba/DictTrie.hpp | 27 +++++++++++++++++++-------- test/testdata/userdict.utf8 | 1 + test/unittest/trie_test.cpp | 16 +++++++++++++++- 4 files changed, 36 insertions(+), 9 deletions(-) diff --git a/dict/user.dict.utf8 b/dict/user.dict.utf8 index a422594..6a42ee4 100644 --- a/dict/user.dict.utf8 +++ b/dict/user.dict.utf8 @@ -1,3 +1,4 @@ 云计算 韩玉鉴赏 蓝翔 nz +区块链 10 nz diff --git a/include/cppjieba/DictTrie.hpp b/include/cppjieba/DictTrie.hpp index d0b2d2b..1e22ecc 100644 --- a/include/cppjieba/DictTrie.hpp +++ b/include/cppjieba/DictTrie.hpp @@ -72,7 +72,8 @@ class DictTrie { private: void Init(const string& dict_path, const string& user_dict_paths, UserWordWeightOption user_word_weight_opt) { LoadDict(dict_path); - CalculateWeight(static_node_infos_); + freq_sum_ = CalcFreqSum(static_node_infos_); + CalculateWeight(static_node_infos_, freq_sum_); SetStaticWordWeights(user_word_weight_opt); if (user_dict_paths.size()) { @@ -115,11 +116,16 @@ class DictTrie { buf[0], user_word_default_weight_, UNKNOWN_TAG); - } else { + } else if (buf.size() == 2) { MakeNodeInfo(node_info, buf[0], - (buf.size() == 2 ? user_word_default_weight_ : atoi(buf[1].c_str())), - (buf.size() == 3 ? buf[2] : buf[1])); + user_word_default_weight_, + buf[1]); + } else if (buf.size() == 3) { + int freq = atoi(buf[1].c_str()); + assert(freq_sum_ > 0.0); + double weight = log(1.0 * freq / freq_sum_); + MakeNodeInfo(node_info, buf[0], weight, buf[2]); } static_node_infos_.push_back(node_info); if (node_info.word.size() == 1) { @@ -184,16 +190,20 @@ class DictTrie { } } - void CalculateWeight(vector& node_infos) const { + double CalcFreqSum(const vector& node_infos) const { double sum = 0.0; for (size_t i = 0; i < node_infos.size(); i++) { sum += node_infos[i].weight; } - assert(sum); + return sum; + } + + void CalculateWeight(vector& node_infos, double sum) const { + assert(sum > 0.0); for (size_t i = 0; i < node_infos.size(); i++) { DictUnit& node_info = node_infos[i]; - assert(node_info.weight); - node_info.weight = log(double(node_info.weight)/double(sum)); + assert(node_info.weight > 0.0); + node_info.weight = log(double(node_info.weight)/sum); } } @@ -205,6 +215,7 @@ class DictTrie { deque active_node_infos_; // must not be vector Trie * trie_; + double freq_sum_; double min_weight_; double max_weight_; double median_weight_; diff --git a/test/testdata/userdict.utf8 b/test/testdata/userdict.utf8 index 6477fef..688ff55 100644 --- a/test/testdata/userdict.utf8 +++ b/test/testdata/userdict.utf8 @@ -5,3 +5,4 @@ B iPhone6 蓝翔 nz 忽如一夜春风来 +区块链 10 nz diff --git a/test/unittest/trie_test.cpp b/test/unittest/trie_test.cpp index 137b213..1f03540 100644 --- a/test/unittest/trie_test.cpp +++ b/test/unittest/trie_test.cpp @@ -74,8 +74,22 @@ TEST(DictTrieTest, UserDict) { cppjieba::RuneStrArray unicode; ASSERT_TRUE(DecodeRunesInString(word, unicode)); const DictUnit * unit = trie.Find(unicode.begin(), unicode.end()); - ASSERT_TRUE(unit); + ASSERT_TRUE(unit != NULL); ASSERT_NEAR(unit->weight, -14.100, 0.001); + + word = "蓝翔"; + ASSERT_TRUE(DecodeRunesInString(word, unicode)); + unit = trie.Find(unicode.begin(), unicode.end()); + ASSERT_TRUE(unit != NULL); + ASSERT_EQ(unit->tag, "nz"); + ASSERT_NEAR(unit->weight, -14.100, 0.001); + + word = "区块链"; + ASSERT_TRUE(DecodeRunesInString(word, unicode)); + unit = trie.Find(unicode.begin(), unicode.end()); + ASSERT_TRUE(unit != NULL); + ASSERT_EQ(unit->tag, "nz"); + ASSERT_NEAR(unit->weight, -15.6478, 0.001); } TEST(DictTrieTest, UserDictWithMaxWeight) {