diff --git a/src/Application.hpp b/src/Application.hpp index 3e8f104..d7b7cc9 100644 --- a/src/Application.hpp +++ b/src/Application.hpp @@ -57,8 +57,8 @@ class Application { LogError("argument method is illegal."); } } - void insertUserWord(const string& word, const string& tag = UNKNOWN_TAG) { - dictTrie_.insertUserWord(word, tag); + bool insertUserWord(const string& word, const string& tag = UNKNOWN_TAG) { + return dictTrie_.insertUserWord(word, tag); } void tag(const string& str, vector >& res) const { tagger_.tag(str, res); diff --git a/src/DictTrie.hpp b/src/DictTrie.hpp index 0f02f4f..4fdd478 100644 --- a/src/DictTrie.hpp +++ b/src/DictTrie.hpp @@ -13,8 +13,6 @@ #include "TransCode.hpp" #include "Trie.hpp" - - namespace CppJieba { using namespace Limonp; const double MIN_DOUBLE = -3.14e+100; @@ -44,16 +42,25 @@ class DictTrie { LogFatal("trie already initted"); } loadDict_(dictPath); - calculateWeight_(nodeInfos_); - minWeight_ = findMinWeight_(nodeInfos_); - maxWeight_ = findMaxWeight_(nodeInfos_); + calculateWeight_(staticNodeInfos_); + minWeight_ = findMinWeight_(staticNodeInfos_); + maxWeight_ = findMaxWeight_(staticNodeInfos_); if(userDictPath.size()) { loadUserDict_(userDictPath); } - shrink_(nodeInfos_); - trie_ = createTrie_(nodeInfos_); - assert(trie_); + shrink_(staticNodeInfos_); + createTrie_(staticNodeInfos_); + } + + bool insertUserWord(const string& word, const string& tag = UNKNOWN_TAG) { + DictUnit nodeInfo; + if(!makeUserNodeInfo_(nodeInfo, word, tag)) { + return false; + } + activeNodeInfos_.push_back(nodeInfo); + trie_->insertNode(nodeInfo.word, &activeNodeInfos_.back()); + return true; } const DictUnit* find(Unicode::const_iterator begin, Unicode::const_iterator end) const { @@ -67,20 +74,6 @@ class DictTrie { vector& res) const { trie_->find(begin, end, res); } - bool insertUserWord(const string& word, const string& tag = UNKNOWN_TAG) { - DictUnit nodeInfo; - if(!TransCode::decode(word, nodeInfo.word)) { - LogError("decode %s failed.", word.c_str()); - return false; - } - if(nodeInfo.word.size() == 1) { - userDictSingleChineseWord_.insert(nodeInfo.word[0]); - } - nodeInfo.weight = maxWeight_; - nodeInfo.tag = tag; - nodeInfos_.push_back(nodeInfo); - return true; - } bool isUserDictSingleChineseWord(const Unicode::value_type& word) const { return isIn(userDictSingleChineseWord_, word); } @@ -88,9 +81,8 @@ class DictTrie { return minWeight_; }; - private: - Trie * createTrie_(const vector& dictUnits) { + void createTrie_(const vector& dictUnits) { assert(dictUnits.size()); vector words; vector valuePointers; @@ -99,8 +91,7 @@ class DictTrie { valuePointers.push_back(&dictUnits[i]); } - Trie * trie = new Trie(words, valuePointers); - return trie; + trie_ = new Trie(words, valuePointers); } void loadUserDict_(const string& filePath) { ifstream ifs(filePath.c_str()); @@ -117,19 +108,37 @@ class DictTrie { if(buf.size() < 1) { LogFatal("split [%s] result illegal", line.c_str()); } - insertUserWord(buf[0], (buf.size() == 2 ? buf[1] : UNKNOWN_TAG)); + DictUnit nodeInfo; + makeUserNodeInfo_(nodeInfo, buf[0], + (buf.size() == 2 ? buf[1] : UNKNOWN_TAG)); + staticNodeInfos_.push_back(nodeInfo); } LogInfo("load userdict[%s] ok. lines[%u]", filePath.c_str(), lineno); } - bool insertWord_(const string& word, double weight, const string& tag) { - DictUnit nodeInfo; + bool makeNodeInfo(DictUnit& nodeInfo, + const string& word, + double weight, + const string& tag) { if(!TransCode::decode(word, nodeInfo.word)) { LogError("decode %s failed.", word.c_str()); return false; } nodeInfo.weight = weight; nodeInfo.tag = tag; - nodeInfos_.push_back(nodeInfo); + return true; + } + bool makeUserNodeInfo_(DictUnit& nodeInfo, + const string& word, + const string& tag = UNKNOWN_TAG) { + if(!TransCode::decode(word, nodeInfo.word)) { + LogError("decode %s failed.", word.c_str()); + return false; + } + if(nodeInfo.word.size() == 1) { + userDictSingleChineseWord_.insert(nodeInfo.word[0]); + } + nodeInfo.weight = maxWeight_; + nodeInfo.tag = tag; return true; } void loadDict_(const string& filePath) { @@ -146,7 +155,11 @@ class DictTrie { if(buf.size() != DICT_COLUMN_NUM) { LogFatal("split result illegal, line: %s, result size: %u", line.c_str(), buf.size()); } - insertWord_(buf[0], atof(buf[1].c_str()), buf[2]); + makeNodeInfo(nodeInfo, + buf[0], + atof(buf[1].c_str()), + buf[2]); + staticNodeInfos_.push_back(nodeInfo); } } double findMinWeight_(const vector& nodeInfos) const { @@ -182,7 +195,8 @@ class DictTrie { } private: - vector nodeInfos_; + vector staticNodeInfos_; + deque activeNodeInfos_; // must not be vector Trie * trie_; double minWeight_; diff --git a/src/Trie.hpp b/src/Trie.hpp index 7363d10..cfc20e3 100644 --- a/src/Trie.hpp +++ b/src/Trie.hpp @@ -85,13 +85,11 @@ class Trie { return ptNode->ptValue; } // aho-corasick-automation - void find( - Unicode::const_iterator begin, + void find(Unicode::const_iterator begin, Unicode::const_iterator end, - vector& res - ) const { + vector& res) const { res.resize(end - begin); - const TrieNode * now = root_; + const TrieNode* now = root_; const TrieNode* node; // compiler will complain warnings if only "i < end - begin" . for (size_t i = 0; i < size_t(end - begin); i++) { @@ -134,8 +132,7 @@ class Trie { } } } - bool find( - Unicode::const_iterator begin, + bool find(Unicode::const_iterator begin, Unicode::const_iterator end, DagType & res, size_t offset = 0) const { @@ -157,6 +154,9 @@ class Trie { } return !res.empty(); } + void insertNode(const Unicode& key, const DictUnit* ptValue) { + insertNode_(key, ptValue); + } private: void build_() { queue que; @@ -191,7 +191,8 @@ class Trie { } } } - void createTrie_(const vector& keys, const vector & valuePointers) { + void createTrie_(const vector& keys, + const vector & valuePointers) { if(valuePointers.empty() || keys.empty()) { return; } diff --git a/test/unittest/TApplication.cpp b/test/unittest/TApplication.cpp index 3023a51..bfb888a 100644 --- a/test/unittest/TApplication.cpp +++ b/test/unittest/TApplication.cpp @@ -51,10 +51,22 @@ TEST(ApplicationTest, Test1) { ASSERT_EQ(result, "[\"CEO:11.7392\", \"升职:10.8562\", \"加薪:10.6426\", \"手扶拖拉机:10.0089\", \"巅峰:9.49396\"]"); } -//TEST(ApplicationTest, InsertUserWord) { -// CppJieba::Application app("../dict/jieba.dict.utf8", -// "../dict/hmm_model.utf8", -// "../dict/user.dict.utf8", -// "../dict/idf.utf8", -// "../dict/stop_words.utf8"); -//} +TEST(ApplicationTest, InsertUserWord) { + CppJieba::Application app("../dict/jieba.dict.utf8", + "../dict/hmm_model.utf8", + "../dict/user.dict.utf8", + "../dict/idf.utf8", + "../dict/stop_words.utf8"); + vector words; + string result; + + app.cut("男默女泪", words); + result << words; + ASSERT_EQ("[\"男默\", \"女泪\"]", result); + + //ASSERT_TRUE(app.insertUserWord("男默女泪")); + + //app.cut("男默女泪", words); + //result << words; + //ASSERT_EQ("[\"男默女泪\"]", result); +}