add userdict loader

This commit is contained in:
wyy 2014-04-25 17:29:42 +08:00
parent 2f314ffdb1
commit dc96bb3795
3 changed files with 57 additions and 48 deletions

View File

@ -22,21 +22,21 @@ namespace CppJieba
const double MIN_DOUBLE = -3.14e+100;
const double MAX_DOUBLE = 3.14e+100;
const size_t DICT_COLUMN_NUM = 3;
const char* const UNKNOWN_TAG = "x";
// One dictionary entry: a decoded word together with the numeric score and tag stored for it.
// NOTE(review): freq/logFreq and weight are all listed here; this looks like the removed and
// added sides of a diff shown together -- confirm which fields the file actually keeps.
struct DictUnit
{
Unicode word; // decoded text of the entry (filled by TransCode::decode in _loadUserDict)
size_t freq; // raw count parsed with atoi from the second dictionary column
double weight; // parsed with atof from the second column, later replaced by log(weight / sum of weights)
string tag; // third dictionary column; user-dictionary entries receive the caller's default tag
double logFreq; //logFreq = log(freq/sum(freq));
};
// Streams a one-line textual form of a DictUnit: the word (rendered through the project's
// string << Unicode operator) followed by the remaining fields formatted with string_format.
// NOTE(review): two consecutive return statements follow; as written only the first one runs and
// the second is unreachable. They look like the before/after pair of a diff -- confirm which
// format ("%s %u %s %.3lf" with freq/logFreq, or "%s %s %.3lf" with weight) the file should keep.
inline ostream & operator << (ostream& os, const DictUnit& unit)
{
string s;
s << unit.word;
return os << string_format("%s %u %s %.3lf", s.c_str(), unit.freq, unit.tag.c_str(), unit.logFreq);
return os << string_format("%s %s %.3lf", s.c_str(), unit.tag.c_str(), unit.weight);
}
typedef map<size_t, const DictUnit*> DagType;
@ -49,15 +49,15 @@ namespace CppJieba
vector<DictUnit> _nodeInfos;
TrieType * _trie;
size_t _freqSum;
double _minLogFreq;
double _minWeight;
public:
// Returns the stored minimum weight (_minWeight); callers use it as the score for words with no dictionary entry.
double getMinWeight() const {return _minWeight;};
public:
// Default constructor: no trie is built yet. The trie pointer is NULL, the frequency sum is 0,
// the minimum values start at MAX_DOUBLE and the init flag is cleared.
// NOTE(review): both _minLogFreq and _minWeight are assigned here; this looks like the removed
// and added sides of a diff shown together -- confirm which member the class keeps.
DictTrie()
{
_trie = NULL;
_freqSum = 0;
_minLogFreq = MAX_DOUBLE;
_minWeight = MAX_DOUBLE;
_setInitFlag(false);
}
DictTrie(const string& filePath)
@ -72,20 +72,22 @@ namespace CppJieba
delete _trie;
}
}
private:
public:
// Loads the main dictionary, replaces each weight with log(weight / sum of weights), records the
// smallest resulting weight, optionally appends a user dictionary whose entries all receive that
// smallest weight and UNKNOWN_TAG, then builds the trie.
//   dictPath     - path of the main dictionary file.
//   userDictPath - path of an optional user dictionary; an empty string skips it.
// Returns the value of _setInitFlag. Must not be called on an already initialised object
// (asserted on entry).
// NOTE(review): this span carries both the pre-change and post-change lines of a diff (two
// signatures, two _loadDict calls, two returns). Those lines are left untouched; only the
// argument of the _loadUserDict call is changed.
bool init(const string& filePath)
bool init(const string& dictPath, const string& userDictPath = "")
{
assert(!_getInitFlag());
_loadDict(filePath, _nodeInfos);
_loadDict(dictPath, _nodeInfos);
_calculateWeight(_nodeInfos);
_minWeight = _findMinWeight(_nodeInfos);
if(userDictPath.size())
{
// Fix: read the user dictionary from userDictPath. Passing dictPath here re-read the main
// dictionary as if it were a user dictionary and never opened the user's file.
_loadUserDict(userDictPath, _minWeight, UNKNOWN_TAG, _nodeInfos);
}
_shrink(_nodeInfos);
_freqSum = _calculateFreqSum(_nodeInfos);
assert(_freqSum);
_minLogFreq = _calculateLogFreqAndGetMinValue(_nodeInfos, _freqSum);
_trie = _creatTrie(_nodeInfos);
return _setInitFlag(_trie);
assert(_trie);
return _setInitFlag(true);
}
public:
@ -98,16 +100,11 @@ namespace CppJieba
return _trie->find(begin, end, dag, offset);
}
public:
// Returns the stored minimum log-frequency (_minLogFreq).
double getMinLogFreq() const {return _minLogFreq;};
private:
TrieType * _creatTrie(const vector<DictUnit>& dictUnits)
{
if(dictUnits.empty())
{
return NULL;
}
assert(dictUnits.size());
vector<Unicode> words;
vector<const DictUnit*> valuePointers;
for(size_t i = 0 ; i < dictUnits.size(); i ++)
@ -119,18 +116,31 @@ namespace CppJieba
TrieType * trie = new TrieType(words, valuePointers);
return trie;
}
void _loadUserDict(const string& filePath, double defaultWeight, const string& defaultTag, vector<DictUnit>& nodeInfos) const
{
ifstream ifs(filePath.c_str());
assert(ifs);
string line;
DictUnit nodeInfo;
for(size_t lineno = 0; getline(ifs, line); lineno++)
{
if(!TransCode::decode(line, nodeInfo.word))
{
LogError("line[%u:%s] illegal.", lineno, line.c_str());
continue;
}
nodeInfo.weight = defaultWeight;
nodeInfo.tag = defaultTag;
nodeInfos.push_back(nodeInfo);
}
}
void _loadDict(const string& filePath, vector<DictUnit>& nodeInfos) const
{
ifstream ifs(filePath.c_str());
if(!ifs)
{
LogFatal("open %s failed.", filePath.c_str());
exit(1);
}
assert(ifs);
string line;
vector<string> buf;
nodeInfos.clear();
DictUnit nodeInfo;
for(size_t lineno = 0 ; getline(ifs, line); lineno++)
{
@ -142,36 +152,36 @@ namespace CppJieba
LogError("line[%u:%s] illegal.", lineno, line.c_str());
continue;
}
nodeInfo.freq = atoi(buf[1].c_str());
nodeInfo.weight = atof(buf[1].c_str());
nodeInfo.tag = buf[2];
nodeInfos.push_back(nodeInfo);
}
}
// NOTE(review): two signatures and paired statements are interleaved below; they look like the
// removed (_calculateFreqSum) and added (_findMinWeight) sides of a diff shown together.
// _calculateFreqSum: returns the sum of the freq fields of all elements of nodeInfos.
// _findMinWeight: returns the smallest weight among nodeInfos, or MAX_DOUBLE when it is empty.
size_t _calculateFreqSum(const vector<DictUnit>& nodeInfos) const
double _findMinWeight(const vector<DictUnit>& nodeInfos) const
{
size_t freqSum = 0;
double ret = MAX_DOUBLE;
for(size_t i = 0; i < nodeInfos.size(); i++)
{
freqSum += nodeInfos[i].freq;
ret = min(nodeInfos[i].weight, ret);
}
return freqSum;
return ret;
}
// NOTE(review): two signatures and paired statements are interleaved below; they look like the
// removed (_calculateLogFreqAndGetMinValue) and added (_calculateWeight) sides of a diff.
// _calculateLogFreqAndGetMinValue: sets each logFreq to log(freq / freqSum) and returns the
//   smallest value seen.
// _calculateWeight: sums every weight, then replaces each weight in place with
//   log(weight / sum).
// NOTE(review): the non-zero checks are asserts only; with NDEBUG defined a zero sum or a zero
//   weight would reach the division / log() unchecked -- confirm inputs are always positive.
double _calculateLogFreqAndGetMinValue(vector<DictUnit>& nodeInfos, size_t freqSum) const
void _calculateWeight(vector<DictUnit>& nodeInfos) const
{
assert(freqSum);
double minLogFreq = MAX_DOUBLE;
double sum = 0.0;
// First pass: total of all raw weights.
for(size_t i = 0; i < nodeInfos.size(); i++)
{
sum += nodeInfos[i].weight;
}
assert(sum);
// Second pass: normalise each entry against the total and take the logarithm.
for(size_t i = 0; i < nodeInfos.size(); i++)
{
DictUnit& nodeInfo = nodeInfos[i];
assert(nodeInfo.freq);
nodeInfo.logFreq = log(double(nodeInfo.freq)/double(freqSum));
if(minLogFreq > nodeInfo.logFreq)
{
minLogFreq = nodeInfo.logFreq;
}
assert(nodeInfo.weight);
nodeInfo.weight = log(double(nodeInfo.weight)/double(sum));
}
return minLogFreq;
}
void _shrink(vector<DictUnit>& units) const

View File

@ -160,11 +160,11 @@ namespace CppJieba
if(p)
{
val += p->logFreq;
val += p->weight;
}
else
{
val += _dictTrie.getMinLogFreq();
val += _dictTrie.getMinWeight();
}
if(val > SegmentChars[i].weight)
{

View File

@ -20,19 +20,18 @@ TEST(DictTrieTest, Test1)
string s1, s2;
DictTrie trie;
ASSERT_TRUE(trie.init(DICT_FILE));
ASSERT_LT(trie.getMinLogFreq() + 15.6479, 0.001);
ASSERT_LT(trie.getMinWeight() + 15.6479, 0.001);
string word("来到");
Unicode uni;
ASSERT_TRUE(TransCode::decode(word, uni));
DictUnit nodeInfo;
nodeInfo.word = uni;
nodeInfo.freq = 8779;
nodeInfo.tag = "v";
nodeInfo.logFreq = -8.87033;
nodeInfo.weight = -8.87033;
s1 << nodeInfo;
s2 << (*trie.find(uni.begin(), uni.end()));
EXPECT_EQ("[\"26469\", \"21040\"] 8779 v -8.870", s2);
EXPECT_EQ("[\"26469\", \"21040\"] v -8.870", s2);
word = "清华大学";
vector<pair<size_t, const DictUnit*> > res;
map<size_t, const DictUnit* > resMap;