cppjieba/src/DictTrie.hpp
2014-11-13 01:16:38 +08:00

212 lines
6.8 KiB
C++

#ifndef CPPJIEBA_DICT_TRIE_HPP
#define CPPJIEBA_DICT_TRIE_HPP
#include <iostream>
#include <fstream>
#include <map>
#include <cstring>
#include <stdint.h>
#include <cmath>
#include <limits>
#include "Limonp/StringUtil.hpp"
#include "Limonp/Logger.hpp"
#include "TransCode.hpp"
#include "Trie.hpp"
namespace CppJieba
{
using namespace Limonp;
const double MIN_DOUBLE = -3.14e+100;
const double MAX_DOUBLE = 3.14e+100;
const size_t DICT_COLUMN_NUM = 3;
const char* const UNKNOWN_TAG = "x";
class DictTrie
{
private:
vector<DictUnit> _nodeInfos;
Trie * _trie;
double _minWeight;
private:
unordered_set<Unicode::value_type> _userDictSingleChineseWord;
public:
bool isUserDictSingleChineseWord(const Unicode::value_type& word) const
{
return isIn(_userDictSingleChineseWord, word);
}
public:
double getMinWeight() const {return _minWeight;};
public:
DictTrie()
{
_trie = NULL;
_minWeight = MAX_DOUBLE;
}
DictTrie(const string& dictPath, const string& userDictPath = "")
{
new (this) DictTrie();
init(dictPath, userDictPath);
}
~DictTrie()
{
if(_trie)
{
delete _trie;
}
}
public:
bool init(const string& dictPath, const string& userDictPath = "")
{
assert(!_trie);
_loadDict(dictPath);
_calculateWeight(_nodeInfos);
_minWeight = _findMinWeight(_nodeInfos);
if(userDictPath.size())
{
double maxWeight = _findMaxWeight(_nodeInfos);
_loadUserDict(userDictPath, maxWeight, UNKNOWN_TAG);
}
_shrink(_nodeInfos);
_trie = _createTrie(_nodeInfos);
assert(_trie);
return true;
}
public:
const DictUnit* find(Unicode::const_iterator begin, Unicode::const_iterator end) const
{
return _trie->find(begin, end);
}
bool find(Unicode::const_iterator begin, Unicode::const_iterator end, DagType& dag, size_t offset = 0) const
{
return _trie->find(begin, end, dag, offset);
}
void find(
Unicode::const_iterator begin,
Unicode::const_iterator end,
vector<SegmentChar>& res
) const
{
_trie->find(begin, end, res);
}
private:
Trie * _createTrie(const vector<DictUnit>& dictUnits)
{
assert(dictUnits.size());
vector<Unicode> words;
vector<const DictUnit*> valuePointers;
for(size_t i = 0 ; i < dictUnits.size(); i ++)
{
words.push_back(dictUnits[i].word);
valuePointers.push_back(&dictUnits[i]);
}
Trie * trie = new Trie(words, valuePointers);
return trie;
}
void _loadUserDict(const string& filePath, double defaultWeight, const string& defaultTag)
{
ifstream ifs(filePath.c_str());
assert(ifs);
string line;
DictUnit nodeInfo;
vector<string> buf;
size_t lineno;
for(lineno = 0; getline(ifs, line); lineno++)
{
buf.clear();
split(line, buf, " ");
assert(buf.size() >= 1);
if(!TransCode::decode(buf[0], nodeInfo.word))
{
LogError("line[%u:%s] illegal.", lineno, line.c_str());
continue;
}
if(nodeInfo.word.size() == 1)
{
_userDictSingleChineseWord.insert(nodeInfo.word[0]);
}
nodeInfo.weight = defaultWeight;
nodeInfo.tag = (buf.size() == 2 ? buf[1] : defaultTag);
_nodeInfos.push_back(nodeInfo);
}
LogInfo("load userdict[%s] ok. lines[%u]", filePath.c_str(), lineno);
}
void _loadDict(const string& filePath)
{
ifstream ifs(filePath.c_str());
assert(ifs);
string line;
vector<string> buf;
DictUnit nodeInfo;
for(size_t lineno = 0 ; getline(ifs, line); lineno++)
{
split(line, buf, " ");
assert(buf.size() == DICT_COLUMN_NUM);
if(!TransCode::decode(buf[0], nodeInfo.word))
{
LogError("line[%u:%s] illegal.", lineno, line.c_str());
continue;
}
nodeInfo.weight = atof(buf[1].c_str());
nodeInfo.tag = buf[2];
_nodeInfos.push_back(nodeInfo);
}
}
double _findMinWeight(const vector<DictUnit>& nodeInfos) const
{
double ret = MAX_DOUBLE;
for(size_t i = 0; i < nodeInfos.size(); i++)
{
ret = min(nodeInfos[i].weight, ret);
}
return ret;
}
double _findMaxWeight(const vector<DictUnit>& nodeInfos) const
{
double ret = MIN_DOUBLE;
for(size_t i = 0; i < nodeInfos.size(); i++)
{
ret = max(nodeInfos[i].weight, ret);
}
return ret;
}
void _calculateWeight(vector<DictUnit>& nodeInfos) const
{
double sum = 0.0;
for(size_t i = 0; i < nodeInfos.size(); i++)
{
sum += nodeInfos[i].weight;
}
assert(sum);
for(size_t i = 0; i < nodeInfos.size(); i++)
{
DictUnit& nodeInfo = nodeInfos[i];
assert(nodeInfo.weight);
nodeInfo.weight = log(double(nodeInfo.weight)/double(sum));
}
}
void _shrink(vector<DictUnit>& units) const
{
vector<DictUnit>(units.begin(), units.end()).swap(units);
}
};
}
#endif