Commit before refactoring the trie

yanyiwu 2015-06-26 14:29:44 +08:00
parent e0db070529
commit c5f7d4d670
4 changed files with 76 additions and 49 deletions

View File

@@ -57,8 +57,8 @@ class Application {
       LogError("argument method is illegal.");
     }
   }
-  void insertUserWord(const string& word, const string& tag = UNKNOWN_TAG) {
-    dictTrie_.insertUserWord(word, tag);
+  bool insertUserWord(const string& word, const string& tag = UNKNOWN_TAG) {
+    return dictTrie_.insertUserWord(word, tag);
   }
   void tag(const string& str, vector<pair<string, string> >& res) const {
    tagger_.tag(str, res);
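
Caller-side effect of this hunk: Application::insertUserWord now forwards DictTrie's success flag instead of discarding it. A minimal hedged sketch of checking that flag (the header name and dictionary paths are assumptions, copied from the unit test at the end of this commit):

#include "Application.hpp"   // assumed header for CppJieba::Application
#include <cstdio>

int main() {
  CppJieba::Application app("../dict/jieba.dict.utf8",
                            "../dict/hmm_model.utf8",
                            "../dict/user.dict.utf8",
                            "../dict/idf.utf8",
                            "../dict/stop_words.utf8");
  // Before this commit the call returned void, so a word that failed to
  // decode was silently dropped; now the failure is visible to the caller.
  if (!app.insertUserWord("男默女泪")) {
    std::fprintf(stderr, "insertUserWord failed\n");
    return 1;
  }
  return 0;
}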

View File

@@ -13,8 +13,6 @@
 #include "TransCode.hpp"
 #include "Trie.hpp"
 namespace CppJieba {
 using namespace Limonp;
 const double MIN_DOUBLE = -3.14e+100;
@@ -44,16 +42,25 @@ class DictTrie {
       LogFatal("trie already initted");
     }
     loadDict_(dictPath);
-    calculateWeight_(nodeInfos_);
-    minWeight_ = findMinWeight_(nodeInfos_);
-    maxWeight_ = findMaxWeight_(nodeInfos_);
+    calculateWeight_(staticNodeInfos_);
+    minWeight_ = findMinWeight_(staticNodeInfos_);
+    maxWeight_ = findMaxWeight_(staticNodeInfos_);
     if(userDictPath.size()) {
       loadUserDict_(userDictPath);
     }
-    shrink_(nodeInfos_);
-    trie_ = createTrie_(nodeInfos_);
-    assert(trie_);
+    shrink_(staticNodeInfos_);
+    createTrie_(staticNodeInfos_);
   }
+  bool insertUserWord(const string& word, const string& tag = UNKNOWN_TAG) {
+    DictUnit nodeInfo;
+    if(!makeUserNodeInfo_(nodeInfo, word, tag)) {
+      return false;
+    }
+    activeNodeInfos_.push_back(nodeInfo);
+    trie_->insertNode(nodeInfo.word, &activeNodeInfos_.back());
+    return true;
+  }
   const DictUnit* find(Unicode::const_iterator begin, Unicode::const_iterator end) const {
@@ -67,20 +74,6 @@ class DictTrie {
         vector<SegmentChar>& res) const {
     trie_->find(begin, end, res);
   }
-  bool insertUserWord(const string& word, const string& tag = UNKNOWN_TAG) {
-    DictUnit nodeInfo;
-    if(!TransCode::decode(word, nodeInfo.word)) {
-      LogError("decode %s failed.", word.c_str());
-      return false;
-    }
-    if(nodeInfo.word.size() == 1) {
-      userDictSingleChineseWord_.insert(nodeInfo.word[0]);
-    }
-    nodeInfo.weight = maxWeight_;
-    nodeInfo.tag = tag;
-    nodeInfos_.push_back(nodeInfo);
-    return true;
-  }
   bool isUserDictSingleChineseWord(const Unicode::value_type& word) const {
     return isIn(userDictSingleChineseWord_, word);
   }
@@ -88,9 +81,8 @@ class DictTrie {
     return minWeight_;
   };
 private:
-  Trie * createTrie_(const vector<DictUnit>& dictUnits) {
+  void createTrie_(const vector<DictUnit>& dictUnits) {
     assert(dictUnits.size());
     vector<Unicode> words;
     vector<const DictUnit*> valuePointers;
@@ -99,8 +91,7 @@ class DictTrie {
       valuePointers.push_back(&dictUnits[i]);
     }
-    Trie * trie = new Trie(words, valuePointers);
-    return trie;
+    trie_ = new Trie(words, valuePointers);
   }
   void loadUserDict_(const string& filePath) {
     ifstream ifs(filePath.c_str());
@@ -117,19 +108,37 @@ class DictTrie {
       if(buf.size() < 1) {
        LogFatal("split [%s] result illegal", line.c_str());
      }
-      insertUserWord(buf[0], (buf.size() == 2 ? buf[1] : UNKNOWN_TAG));
+      DictUnit nodeInfo;
+      makeUserNodeInfo_(nodeInfo, buf[0],
+            (buf.size() == 2 ? buf[1] : UNKNOWN_TAG));
+      staticNodeInfos_.push_back(nodeInfo);
     }
     LogInfo("load userdict[%s] ok. lines[%u]", filePath.c_str(), lineno);
   }
-  bool insertWord_(const string& word, double weight, const string& tag) {
-    DictUnit nodeInfo;
+  bool makeNodeInfo(DictUnit& nodeInfo,
+        const string& word,
+        double weight,
+        const string& tag) {
     if(!TransCode::decode(word, nodeInfo.word)) {
       LogError("decode %s failed.", word.c_str());
       return false;
     }
     nodeInfo.weight = weight;
     nodeInfo.tag = tag;
-    nodeInfos_.push_back(nodeInfo);
+    return true;
+  }
+  bool makeUserNodeInfo_(DictUnit& nodeInfo,
+        const string& word,
+        const string& tag = UNKNOWN_TAG) {
+    if(!TransCode::decode(word, nodeInfo.word)) {
+      LogError("decode %s failed.", word.c_str());
+      return false;
+    }
+    if(nodeInfo.word.size() == 1) {
+      userDictSingleChineseWord_.insert(nodeInfo.word[0]);
+    }
+    nodeInfo.weight = maxWeight_;
+    nodeInfo.tag = tag;
     return true;
   }
   void loadDict_(const string& filePath) {
@@ -146,7 +155,11 @@ class DictTrie {
       if(buf.size() != DICT_COLUMN_NUM) {
         LogFatal("split result illegal, line: %s, result size: %u", line.c_str(), buf.size());
       }
-      insertWord_(buf[0], atof(buf[1].c_str()), buf[2]);
+      makeNodeInfo(nodeInfo,
+            buf[0],
+            atof(buf[1].c_str()),
+            buf[2]);
+      staticNodeInfos_.push_back(nodeInfo);
     }
   }
   double findMinWeight_(const vector<DictUnit>& nodeInfos) const {
@@ -182,7 +195,8 @@ class DictTrie {
   }
 private:
-  vector<DictUnit> nodeInfos_;
+  vector<DictUnit> staticNodeInfos_;
+  deque<DictUnit> activeNodeInfos_; // must not be vector
   Trie * trie_;
   double minWeight_;
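
The "must not be vector" comment on activeNodeInfos_ is load-bearing: Trie::insertNode (next file) keeps the raw address of the deque element, and std::deque::push_back never invalidates pointers or references to existing elements, whereas std::vector::push_back may reallocate and leave every stored pointer dangling. A standalone sketch of that guarantee (Unit is a stand-in type, not the repo's DictUnit):

#include <cassert>
#include <deque>
#include <string>

struct Unit { std::string word; double weight; std::string tag; };

int main() {
  std::deque<Unit> active;                    // plays the role of activeNodeInfos_
  active.push_back(Unit());
  const Unit* handedToTrie = &active.back();  // pointer a trie node would store
  for (int i = 0; i < 100000; ++i) {
    active.push_back(Unit());                 // deque growth leaves old elements in place
  }
  assert(handedToTrie == &active.front());    // still valid; a vector could have moved it
  return 0;
}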

View File

@@ -85,13 +85,11 @@ class Trie {
     return ptNode->ptValue;
   }
   // aho-corasick-automation
-  void find(
-        Unicode::const_iterator begin,
+  void find(Unicode::const_iterator begin,
         Unicode::const_iterator end,
-        vector<struct SegmentChar>& res
-        ) const {
+        vector<struct SegmentChar>& res) const {
     res.resize(end - begin);
-    const TrieNode * now = root_;
+    const TrieNode* now = root_;
     const TrieNode* node;
     // compiler will complain warnings if only "i < end - begin" .
     for (size_t i = 0; i < size_t(end - begin); i++) {
@@ -134,8 +132,7 @@ class Trie {
       }
     }
   }
-  bool find(
-        Unicode::const_iterator begin,
+  bool find(Unicode::const_iterator begin,
         Unicode::const_iterator end,
         DagType & res,
         size_t offset = 0) const {
@@ -157,6 +154,9 @@ class Trie {
     }
     return !res.empty();
   }
+  void insertNode(const Unicode& key, const DictUnit* ptValue) {
+    insertNode_(key, ptValue);
+  }
 private:
   void build_() {
     queue<TrieNode*> que;
@@ -191,7 +191,8 @@ class Trie {
       }
     }
   }
-  void createTrie_(const vector<Unicode>& keys, const vector<const DictUnit*> & valuePointers) {
+  void createTrie_(const vector<Unicode>& keys,
+        const vector<const DictUnit*> & valuePointers) {
     if(valuePointers.empty() || keys.empty()) {
       return;
     }

View File

@@ -51,10 +51,22 @@ TEST(ApplicationTest, Test1) {
   ASSERT_EQ(result, "[\"CEO:11.7392\", \"升职:10.8562\", \"加薪:10.6426\", \"手扶拖拉机:10.0089\", \"巅峰:9.49396\"]");
 }
-//TEST(ApplicationTest, InsertUserWord) {
-//  CppJieba::Application app("../dict/jieba.dict.utf8",
-//        "../dict/hmm_model.utf8",
-//        "../dict/user.dict.utf8",
-//        "../dict/idf.utf8",
-//        "../dict/stop_words.utf8");
-//}
+TEST(ApplicationTest, InsertUserWord) {
+  CppJieba::Application app("../dict/jieba.dict.utf8",
+        "../dict/hmm_model.utf8",
+        "../dict/user.dict.utf8",
+        "../dict/idf.utf8",
+        "../dict/stop_words.utf8");
+  vector<string> words;
+  string result;
+  app.cut("男默女泪", words);
+  result << words;
+  ASSERT_EQ("[\"男默\", \"女泪\"]", result);
+  //ASSERT_TRUE(app.insertUserWord("男默女泪"));
+  //app.cut("男默女泪", words);
+  //result << words;
+  //ASSERT_EQ("[\"男默女泪\"]", result);
+}
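
The commented-out assertions mark where this commit stops short of the trie refactor named in the commit message. A hedged sketch of the follow-up test they point to, assuming a runtime-inserted word comes to affect subsequent cut() calls (the test name and the expectation are hypothetical, modeled on the commented lines above):

TEST(ApplicationTest, InsertUserWordTakesEffect) {
  CppJieba::Application app("../dict/jieba.dict.utf8",
        "../dict/hmm_model.utf8",
        "../dict/user.dict.utf8",
        "../dict/idf.utf8",
        "../dict/stop_words.utf8");
  ASSERT_TRUE(app.insertUserWord("男默女泪"));  // hypothetical: runtime insertion succeeds
  vector<string> words;
  string result;
  app.cut("男默女泪", words);
  result << words;                              // same operator<< usage as the test above
  ASSERT_EQ("[\"男默女泪\"]", result);           // the inserted word is kept whole
}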