refactor trie.hpp's loading and building

This commit is contained in:
wyy 2014-04-02 09:08:45 -07:00
parent 86de722888
commit 467fcf8434

View File

@ -26,16 +26,14 @@ namespace CppJieba
const double MAX_DOUBLE = 3.14e+100; const double MAX_DOUBLE = 3.14e+100;
const size_t DICT_COLUMN_NUM = 3; const size_t DICT_COLUMN_NUM = 3;
typedef unordered_map<uint16_t, struct TrieNode*> TrieNodeMap; typedef unordered_map<uint16_t, struct TrieNode*> TrieNodeMap;
struct TrieNodeInfo;
struct TrieNode struct TrieNode
{ {
TrieNodeMap hmap; TrieNodeMap hmap;
bool isLeaf; bool isLeaf;
size_t nodeInfoPos; const TrieNodeInfo * ptTrieNodeInfo;
TrieNode() TrieNode(): isLeaf(false), ptTrieNodeInfo(NULL)
{ {}
isLeaf = false;
nodeInfoPos = 0;
}
}; };
struct TrieNodeInfo struct TrieNodeInfo
@ -72,7 +70,7 @@ namespace CppJieba
public: public:
Trie() Trie()
{ {
_root = NULL; _root = new TrieNode;
_freqSum = 0; _freqSum = 0;
_minLogFreq = MAX_DOUBLE; _minLogFreq = MAX_DOUBLE;
_setInitFlag(false); _setInitFlag(false);
@ -90,15 +88,11 @@ namespace CppJieba
bool init(const string& filePath) bool init(const string& filePath)
{ {
assert(!_getInitFlag()); assert(!_getInitFlag());
_loadDict(filePath, _nodeInfos);
_root = new TrieNode; _createTrie(_nodeInfos, _root);
assert(_root); _freqSum = _calculateFreqSum(_nodeInfos);
if(!_trieInsert(filePath)) assert(_freqSum);
{ _minLogFreq = _calculateLogFreqAndGetMinValue(_nodeInfos, _freqSum);
LogError("_trieInsert failed.");
return false;
}
_countWeight();
return _setInitFlag(true); return _setInitFlag(true);
} }
@ -119,7 +113,7 @@ namespace CppJieba
} }
if(p->isLeaf) if(p->isLeaf)
{ {
return &(_nodeInfos[p->nodeInfoPos]); return p->ptTrieNodeInfo;
} }
return NULL; return NULL;
} }
@ -138,7 +132,7 @@ namespace CppJieba
p = citer->second; p = citer->second;
if(p->isLeaf) if(p->isLeaf)
{ {
res.push_back(make_pair(itr-begin, &_nodeInfos[p->nodeInfoPos])); res.push_back(make_pair(itr-begin, p->ptTrieNodeInfo));
} }
} }
return !res.empty(); return !res.empty();
@ -158,7 +152,7 @@ namespace CppJieba
p = citer->second; p = citer->second;
if(p->isLeaf) if(p->isLeaf)
{ {
res[itr - begin + offset] = &_nodeInfos[p->nodeInfoPos]; res[itr - begin + offset] = p->ptTrieNodeInfo;
} }
} }
return !res.empty(); return !res.empty();
@ -168,43 +162,43 @@ namespace CppJieba
double getMinLogFreq() const {return _minLogFreq;}; double getMinLogFreq() const {return _minLogFreq;};
private: private:
void _insert(const TrieNodeInfo& nodeInfo, size_t nodeInfoPos) void _insertNode(const TrieNodeInfo& nodeInfo, TrieNode* ptNode) const
{ {
const Unicode& unico = nodeInfo.word; const Unicode& unico = nodeInfo.word;
TrieNode* p = _root;
for(size_t i = 0; i < unico.size(); i++) for(size_t i = 0; i < unico.size(); i++)
{ {
uint16_t cu = unico[i]; uint16_t cu = unico[i];
assert(p); assert(ptNode);
if(!isIn(p->hmap, cu)) if(!isIn(ptNode->hmap, cu))
{ {
TrieNode * next = new TrieNode; TrieNode * next = new TrieNode;
assert(next); assert(next);
p->hmap[cu] = next; ptNode->hmap[cu] = next;
p = next; ptNode = next;
} }
else else
{ {
p = p->hmap[cu]; ptNode = ptNode->hmap[cu];
} }
} }
p->isLeaf = true; ptNode->isLeaf = true;
p->nodeInfoPos = nodeInfoPos; ptNode->ptTrieNodeInfo = &nodeInfo;
} }
private: private:
bool _trieInsert(const string& filePath) void _loadDict(const string& filePath, vector<TrieNodeInfo>& nodeInfos) const
{ {
ifstream ifs(filePath.c_str()); ifstream ifs(filePath.c_str());
if(!ifs) if(!ifs)
{ {
LogError("open %s failed.", filePath.c_str()); LogFatal("open %s failed.", filePath.c_str());
return false; exit(1);
} }
string line; string line;
vector<string> buf; vector<string> buf;
nodeInfos.clear();
TrieNodeInfo nodeInfo; TrieNodeInfo nodeInfo;
for(size_t lineno = 0 ; getline(ifs, line); lineno++) for(size_t lineno = 0 ; getline(ifs, line); lineno++)
{ {
@ -213,43 +207,46 @@ namespace CppJieba
if(!TransCode::decode(buf[0], nodeInfo.word)) if(!TransCode::decode(buf[0], nodeInfo.word))
{ {
LogError("line[%u:%s] illegal.", lineno, line.c_str()); LogError("line[%u:%s] illegal.", lineno, line.c_str());
return false; continue;
} }
nodeInfo.freq = atoi(buf[1].c_str()); nodeInfo.freq = atoi(buf[1].c_str());
nodeInfo.tag = buf[2]; nodeInfo.tag = buf[2];
_nodeInfos.push_back(nodeInfo); nodeInfos.push_back(nodeInfo);
} }
}
bool _createTrie(const vector<TrieNodeInfo>& nodeInfos, TrieNode * ptNode)
{
for(size_t i = 0; i < _nodeInfos.size(); i++) for(size_t i = 0; i < _nodeInfos.size(); i++)
{ {
_insert(_nodeInfos[i], i); _insertNode(_nodeInfos[i], ptNode);
} }
return true; return true;
} }
void _countWeight() size_t _calculateFreqSum(const vector<TrieNodeInfo>& nodeInfos) const
{ {
//freq total freq size_t freqSum = 0;
_freqSum = 0; for(size_t i = 0; i < nodeInfos.size(); i++)
for(size_t i = 0; i < _nodeInfos.size(); i++)
{ {
_freqSum += _nodeInfos[i].freq; freqSum += nodeInfos[i].freq;
} }
return freqSum;
assert(_freqSum); }
double _calculateLogFreqAndGetMinValue(vector<TrieNodeInfo>& nodeInfos, size_t freqSum) const
//normalize {
for(size_t i = 0; i < _nodeInfos.size(); i++) assert(freqSum);
double minLogFreq = MAX_DOUBLE;
for(size_t i = 0; i < nodeInfos.size(); i++)
{ {
TrieNodeInfo& nodeInfo = _nodeInfos[i]; TrieNodeInfo& nodeInfo = nodeInfos[i];
assert(nodeInfo.freq); assert(nodeInfo.freq);
nodeInfo.logFreq = log(double(nodeInfo.freq)/double(_freqSum)); nodeInfo.logFreq = log(double(nodeInfo.freq)/double(freqSum));
if(_minLogFreq > nodeInfo.logFreq) if(minLogFreq > nodeInfo.logFreq)
{ {
_minLogFreq = nodeInfo.logFreq; minLogFreq = nodeInfo.logFreq;
} }
} }
return minLogFreq;
} }
void _deleteNode(TrieNode* node) void _deleteNode(TrieNode* node)