Merge pull request #44 from aholic/master

提升Trie的效率
This commit is contained in:
Yanyi Wu 2015-07-21 11:15:26 +08:00
commit 5296a83823
2 changed files with 102 additions and 152 deletions

View File

@ -9,7 +9,6 @@
#include "SegmentBase.hpp" #include "SegmentBase.hpp"
namespace CppJieba { namespace CppJieba {
class HMMSegment: public SegmentBase { class HMMSegment: public SegmentBase {
public: public:
HMMSegment(const string& filePath) { HMMSegment(const string& filePath) {

View File

@ -29,228 +29,179 @@ struct SegmentChar {
const DictUnit * pInfo; const DictUnit * pInfo;
double weight; double weight;
size_t nextPos; size_t nextPos;
SegmentChar():uniCh(0), pInfo(NULL), weight(0.0), nextPos(0) { SegmentChar() : uniCh(), pInfo(NULL), weight(0.0), nextPos(0) {}
} ~SegmentChar() {}
~SegmentChar() {
}
}; };
typedef Unicode::value_type TrieKey; typedef Unicode::value_type TrieKey;
class TrieNode { class TrieNode {
public :
TrieNode(): next(NULL), ptValue(NULL) {}
public: public:
TrieNode(): fail(NULL), next(NULL), ptValue(NULL) { typedef unordered_map<TrieKey, TrieNode*> NextMap;
} NextMap *next;
const TrieNode * findNext(TrieKey key) const { const DictUnit *ptValue;
if(next == NULL) {
return NULL;
}
NextMap::const_iterator iter = next->find(key);
if(iter == next->end()) {
return NULL;
}
return iter->second;
}
public:
typedef unordered_map<TrieKey, TrieNode*> NextMap;
TrieNode * fail;
NextMap * next;
const DictUnit * ptValue;
}; };
class Trie { class Trie {
public: public:
Trie(const vector<Unicode>& keys, const vector<const DictUnit*> & valuePointers) { static const size_t BASE_SIZE = (1 << (8 * (sizeof(TrieKey))));
root_ = new TrieNode;
createTrie_(keys, valuePointers);
build_();// build automation
}
~Trie() {
if(root_) {
deleteNode_(root_);
}
}
public: public:
Trie(const vector<Unicode>& keys, const vector<const DictUnit*>& valuePointers) {
_createTrie(keys, valuePointers);
}
const DictUnit* find(Unicode::const_iterator begin, Unicode::const_iterator end) const { const DictUnit* find(Unicode::const_iterator begin, Unicode::const_iterator end) const {
if (begin == end) {
return NULL;
}
const TrieNode* ptNode = _base + (*(begin++));
TrieNode::NextMap::const_iterator citer; TrieNode::NextMap::const_iterator citer;
const TrieNode* ptNode = root_; for (Unicode::const_iterator it = begin; it != end; it++) {
for(Unicode::const_iterator it = begin; it != end; it++) { if (NULL == ptNode->next) {
// build automation return NULL;
assert(ptNode); }
if(NULL == ptNode->next || ptNode->next->end() == (citer = ptNode->next->find(*it))) { citer = ptNode->next->find(*it);
if (ptNode->next->end() == citer) {
return NULL; return NULL;
} }
ptNode = citer->second; ptNode = citer->second;
} }
return ptNode->ptValue; return ptNode->ptValue;
} }
// aho-corasick-automation
void find(Unicode::const_iterator begin, void find(
Unicode::const_iterator begin,
Unicode::const_iterator end, Unicode::const_iterator end,
vector<struct SegmentChar>& res) const { vector<struct SegmentChar>& res
) const {
res.resize(end - begin); res.resize(end - begin);
const TrieNode* now = root_;
const TrieNode* node; const TrieNode *ptNode = NULL;
// compiler will complain warnings if only "i < end - begin" . TrieNode::NextMap::const_iterator citer;
for (size_t i = 0; i < size_t(end - begin); i++) { for (size_t i = 0; i < size_t(end - begin); i++) {
Unicode::value_type ch = *(begin + i); Unicode::value_type ch = *(begin + i);
ptNode = _base + ch;
res[i].uniCh = ch; res[i].uniCh = ch;
assert(res[i].dag.empty()); assert(res[i].dag.empty());
res[i].dag.push_back(pair<vector<Unicode >::size_type, const DictUnit* >(i, (const DictUnit*)NULL));
bool flag = false;
// rollback res[i].dag.push_back(DagType::value_type(i, ptNode->ptValue));
while( now != root_ ) {
node = now->findNext(ch); for (size_t j = i + 1; j < size_t(end - begin); j++) {
if (node != NULL) { if (ptNode->next == NULL) {
flag = true;
break; break;
} else {
now = now->fail;
} }
} citer = ptNode->next->find(*(begin + j));
if (ptNode->next->end() == citer) {
if(!flag) { break;
node = now->findNext(ch); }
} ptNode = citer->second;
if(node == NULL) { if (NULL != ptNode->ptValue) {
now = root_; res[i].dag.push_back(DagType::value_type(j, ptNode->ptValue));
} else {
now = node;
const TrieNode * temp = now;
while(temp != root_) {
if (temp->ptValue) {
size_t pos = i - temp->ptValue->word.size() + 1;
res[pos].dag.push_back(pair<vector<Unicode >::size_type, const DictUnit* >(i, temp->ptValue));
if(pos == i) {
res[pos].dag[0].second = temp->ptValue;
}
}
temp = temp->fail;
assert(temp);
} }
} }
} }
} }
bool find(Unicode::const_iterator begin, bool find(
Unicode::const_iterator begin,
Unicode::const_iterator end, Unicode::const_iterator end,
DagType & res, DagType & res,
size_t offset = 0) const { size_t offset = 0) const {
const TrieNode * ptNode = root_; if (begin == end) {
return !res.empty();
}
const TrieNode* ptNode = _base + (*(begin++));
if (ptNode->ptValue != NULL && res.size() == 1) {
res[0].second = ptNode->ptValue;
} else if (ptNode->ptValue != NULL) {
res.push_back(DagType::value_type(offset, ptNode->ptValue));
}
TrieNode::NextMap::const_iterator citer; TrieNode::NextMap::const_iterator citer;
for(Unicode::const_iterator itr = begin; itr != end ; itr++) { for (Unicode::const_iterator itr = begin; itr != end; itr++) {
assert(ptNode); if (NULL == ptNode->next) {
if(NULL == ptNode->next || ptNode->next->end() == (citer = ptNode->next->find(*itr))) { break;
}
citer = ptNode->next->find(*itr);
if (citer == ptNode->next->end()) {
break; break;
} }
ptNode = citer->second; ptNode = citer->second;
if(ptNode->ptValue) { if (NULL != ptNode->ptValue) {
if(itr == begin && res.size() == 1) { // first singleword res.push_back(DagType::value_type(itr - begin + offset, ptNode->ptValue));
res[0].second = ptNode->ptValue;
} else {
res.push_back(pair<vector<Unicode >::size_type, const DictUnit* >(itr - begin + offset, ptNode->ptValue));
}
} }
} }
return !res.empty(); return !res.empty();
} }
void insertNode(const Unicode& key, const DictUnit* ptValue) { ~Trie() {
TrieNode* newAddedNode = insertNode_(key, ptValue); for (size_t i = 0; i < BASE_SIZE; i++) {
if (newAddedNode) { if (_base[i].next == NULL) {
build_(newAddedNode);
}
}
private:
void build_() {
assert(root_->ptValue == NULL);
assert(root_->next);
root_->fail = NULL;
for(TrieNode::NextMap::iterator iter = root_->next->begin(); iter != root_->next->end(); iter++) {
build_(iter->second);
}
}
void build_(TrieNode* node) {
node->fail = root_;
queue<TrieNode*> que;
que.push(node);
TrieNode* back = NULL;
TrieNode::NextMap::iterator backiter;
while(!que.empty()) {
TrieNode * now = que.front();
que.pop();
if(now->next == NULL) {
continue; continue;
} }
for(TrieNode::NextMap::iterator iter = now->next->begin(); iter != now->next->end(); iter++) { for (TrieNode::NextMap::iterator it = _base[i].next->begin(); it != _base[i].next->end(); it++) {
back = now->fail; _deleteNode(it->second);
while(back != NULL) { it->second = NULL;
if(back->next && (backiter = back->next->find(iter->first)) != back->next->end()) {
iter->second->fail = backiter->second;
break;
}
back = back->fail;
}
if(back == NULL) {
iter->second->fail = root_;
}
que.push(iter->second);
} }
delete _base[i].next;
_base[i].next = NULL;
} }
} }
void createTrie_(const vector<Unicode>& keys,
const vector<const DictUnit*> & valuePointers) { void insertNode(const Unicode& key, const DictUnit* ptValue) {
if(valuePointers.empty() || keys.empty()) { if (key.begin() == key.end()) {
return; return;
} }
assert(keys.size() == valuePointers.size());
for(size_t i = 0; i < keys.size(); i++) {
insertNode_(keys[i], valuePointers[i]);
}
}
TrieNode* insertNode_(const Unicode& key, const DictUnit* ptValue) {
TrieNode* ptNode = root_;
TrieNode* newAddedNode = NULL;
TrieNode::NextMap::const_iterator kmIter; TrieNode::NextMap::const_iterator kmIter;
Unicode::const_iterator citer= key.begin();
for(Unicode::const_iterator citer = key.begin(); citer != key.end(); citer++) { TrieNode *ptNode = _base + (*(citer++));
if(NULL == ptNode->next) { for (; citer != key.end(); citer++) {
if (NULL == ptNode->next) {
ptNode->next = new TrieNode::NextMap; ptNode->next = new TrieNode::NextMap;
} }
kmIter = ptNode->next->find(*citer); kmIter = ptNode->next->find(*citer);
if(ptNode->next->end() == kmIter) { if (ptNode->next->end() == kmIter) {
TrieNode * nextNode = new TrieNode; TrieNode *nextNode = new TrieNode;
nextNode->next = NULL;
nextNode->ptValue = NULL;
if(newAddedNode == NULL) { (*(ptNode->next))[*citer] = nextNode;
newAddedNode = nextNode;
}
(*ptNode->next)[*citer] = nextNode;
ptNode = nextNode; ptNode = nextNode;
} else { } else {
ptNode = kmIter->second; ptNode = kmIter->second;
} }
} }
ptNode->ptValue = ptValue; ptNode->ptValue = ptValue;
return newAddedNode;
} }
void deleteNode_(TrieNode* node) {
if(!node) { private:
void _createTrie(const vector<Unicode>& keys, const vector<const DictUnit*>& valuePointers) {
if (valuePointers.empty() || keys.empty()) {
return; return;
} }
if(node->next) { assert(keys.size() == valuePointers.size());
for (size_t i = 0; i < keys.size(); i++) {
insertNode(keys[i], valuePointers[i]);
}
}
void _deleteNode(TrieNode* node) {
if (NULL == node) {
return;
}
if (NULL != node->next) {
TrieNode::NextMap::iterator it; TrieNode::NextMap::iterator it;
for(it = node->next->begin(); it != node->next->end(); it++) { for (it = node->next->begin(); it != node->next->end(); it++) {
deleteNode_(it->second); _deleteNode(it->second);
} }
delete node->next; delete node->next;
node->next = NULL;
} }
delete node; delete node;
} }
private:
TrieNode* root_; TrieNode _base[BASE_SIZE];
}; };
} }