add userdict loader

This commit is contained in:
wyy 2014-04-25 17:29:42 +08:00
parent 2f314ffdb1
commit dc96bb3795
3 changed files with 57 additions and 48 deletions

View File

@ -22,21 +22,21 @@ namespace CppJieba
const double MIN_DOUBLE = -3.14e+100; const double MIN_DOUBLE = -3.14e+100;
const double MAX_DOUBLE = 3.14e+100; const double MAX_DOUBLE = 3.14e+100;
const size_t DICT_COLUMN_NUM = 3; const size_t DICT_COLUMN_NUM = 3;
const char* const UNKNOWN_TAG = "x";
struct DictUnit struct DictUnit
{ {
Unicode word; Unicode word;
size_t freq; double weight;
string tag; string tag;
double logFreq; //logFreq = log(freq/sum(freq));
}; };
inline ostream & operator << (ostream& os, const DictUnit& unit) inline ostream & operator << (ostream& os, const DictUnit& unit)
{ {
string s; string s;
s << unit.word; s << unit.word;
return os << string_format("%s %u %s %.3lf", s.c_str(), unit.freq, unit.tag.c_str(), unit.logFreq); return os << string_format("%s %s %.3lf", s.c_str(), unit.tag.c_str(), unit.weight);
} }
typedef map<size_t, const DictUnit*> DagType; typedef map<size_t, const DictUnit*> DagType;
@ -49,15 +49,15 @@ namespace CppJieba
vector<DictUnit> _nodeInfos; vector<DictUnit> _nodeInfos;
TrieType * _trie; TrieType * _trie;
size_t _freqSum; double _minWeight;
double _minLogFreq; public:
double getMinWeight() const {return _minWeight;};
public: public:
DictTrie() DictTrie()
{ {
_trie = NULL; _trie = NULL;
_freqSum = 0; _minWeight = MAX_DOUBLE;
_minLogFreq = MAX_DOUBLE;
_setInitFlag(false); _setInitFlag(false);
} }
DictTrie(const string& filePath) DictTrie(const string& filePath)
@ -72,20 +72,22 @@ namespace CppJieba
delete _trie; delete _trie;
} }
} }
private:
public: public:
bool init(const string& filePath) bool init(const string& dictPath, const string& userDictPath = "")
{ {
assert(!_getInitFlag()); assert(!_getInitFlag());
_loadDict(filePath, _nodeInfos); _loadDict(dictPath, _nodeInfos);
_calculateWeight(_nodeInfos);
_minWeight = _findMinWeight(_nodeInfos);
if(userDictPath.size())
{
_loadUserDict(dictPath, _minWeight, UNKNOWN_TAG, _nodeInfos);
}
_shrink(_nodeInfos); _shrink(_nodeInfos);
_freqSum = _calculateFreqSum(_nodeInfos);
assert(_freqSum);
_minLogFreq = _calculateLogFreqAndGetMinValue(_nodeInfos, _freqSum);
_trie = _creatTrie(_nodeInfos); _trie = _creatTrie(_nodeInfos);
return _setInitFlag(_trie); assert(_trie);
return _setInitFlag(true);
} }
public: public:
@ -98,16 +100,11 @@ namespace CppJieba
return _trie->find(begin, end, dag, offset); return _trie->find(begin, end, dag, offset);
} }
public:
double getMinLogFreq() const {return _minLogFreq;};
private: private:
TrieType * _creatTrie(const vector<DictUnit>& dictUnits) TrieType * _creatTrie(const vector<DictUnit>& dictUnits)
{ {
if(dictUnits.empty()) assert(dictUnits.size());
{
return NULL;
}
vector<Unicode> words; vector<Unicode> words;
vector<const DictUnit*> valuePointers; vector<const DictUnit*> valuePointers;
for(size_t i = 0 ; i < dictUnits.size(); i ++) for(size_t i = 0 ; i < dictUnits.size(); i ++)
@ -119,18 +116,31 @@ namespace CppJieba
TrieType * trie = new TrieType(words, valuePointers); TrieType * trie = new TrieType(words, valuePointers);
return trie; return trie;
} }
void _loadUserDict(const string& filePath, double defaultWeight, const string& defaultTag, vector<DictUnit>& nodeInfos) const
{
ifstream ifs(filePath.c_str());
assert(ifs);
string line;
DictUnit nodeInfo;
for(size_t lineno = 0; getline(ifs, line); lineno++)
{
if(!TransCode::decode(line, nodeInfo.word))
{
LogError("line[%u:%s] illegal.", lineno, line.c_str());
continue;
}
nodeInfo.weight = defaultWeight;
nodeInfo.tag = defaultTag;
nodeInfos.push_back(nodeInfo);
}
}
void _loadDict(const string& filePath, vector<DictUnit>& nodeInfos) const void _loadDict(const string& filePath, vector<DictUnit>& nodeInfos) const
{ {
ifstream ifs(filePath.c_str()); ifstream ifs(filePath.c_str());
if(!ifs) assert(ifs);
{
LogFatal("open %s failed.", filePath.c_str());
exit(1);
}
string line; string line;
vector<string> buf; vector<string> buf;
nodeInfos.clear();
DictUnit nodeInfo; DictUnit nodeInfo;
for(size_t lineno = 0 ; getline(ifs, line); lineno++) for(size_t lineno = 0 ; getline(ifs, line); lineno++)
{ {
@ -142,36 +152,36 @@ namespace CppJieba
LogError("line[%u:%s] illegal.", lineno, line.c_str()); LogError("line[%u:%s] illegal.", lineno, line.c_str());
continue; continue;
} }
nodeInfo.freq = atoi(buf[1].c_str()); nodeInfo.weight = atof(buf[1].c_str());
nodeInfo.tag = buf[2]; nodeInfo.tag = buf[2];
nodeInfos.push_back(nodeInfo); nodeInfos.push_back(nodeInfo);
} }
} }
size_t _calculateFreqSum(const vector<DictUnit>& nodeInfos) const double _findMinWeight(const vector<DictUnit>& nodeInfos) const
{ {
size_t freqSum = 0; double ret = MAX_DOUBLE;
for(size_t i = 0; i < nodeInfos.size(); i++) for(size_t i = 0; i < nodeInfos.size(); i++)
{ {
freqSum += nodeInfos[i].freq; ret = min(nodeInfos[i].weight, ret);
} }
return freqSum; return ret;
} }
double _calculateLogFreqAndGetMinValue(vector<DictUnit>& nodeInfos, size_t freqSum) const
void _calculateWeight(vector<DictUnit>& nodeInfos) const
{ {
assert(freqSum); double sum = 0.0;
double minLogFreq = MAX_DOUBLE; for(size_t i = 0; i < nodeInfos.size(); i++)
{
sum += nodeInfos[i].weight;
}
assert(sum);
for(size_t i = 0; i < nodeInfos.size(); i++) for(size_t i = 0; i < nodeInfos.size(); i++)
{ {
DictUnit& nodeInfo = nodeInfos[i]; DictUnit& nodeInfo = nodeInfos[i];
assert(nodeInfo.freq); assert(nodeInfo.weight);
nodeInfo.logFreq = log(double(nodeInfo.freq)/double(freqSum)); nodeInfo.weight = log(double(nodeInfo.weight)/double(sum));
if(minLogFreq > nodeInfo.logFreq)
{
minLogFreq = nodeInfo.logFreq;
}
} }
return minLogFreq;
} }
void _shrink(vector<DictUnit>& units) const void _shrink(vector<DictUnit>& units) const

View File

@ -160,11 +160,11 @@ namespace CppJieba
if(p) if(p)
{ {
val += p->logFreq; val += p->weight;
} }
else else
{ {
val += _dictTrie.getMinLogFreq(); val += _dictTrie.getMinWeight();
} }
if(val > SegmentChars[i].weight) if(val > SegmentChars[i].weight)
{ {

View File

@ -20,19 +20,18 @@ TEST(DictTrieTest, Test1)
string s1, s2; string s1, s2;
DictTrie trie; DictTrie trie;
ASSERT_TRUE(trie.init(DICT_FILE)); ASSERT_TRUE(trie.init(DICT_FILE));
ASSERT_LT(trie.getMinLogFreq() + 15.6479, 0.001); ASSERT_LT(trie.getMinWeight() + 15.6479, 0.001);
string word("来到"); string word("来到");
Unicode uni; Unicode uni;
ASSERT_TRUE(TransCode::decode(word, uni)); ASSERT_TRUE(TransCode::decode(word, uni));
DictUnit nodeInfo; DictUnit nodeInfo;
nodeInfo.word = uni; nodeInfo.word = uni;
nodeInfo.freq = 8779;
nodeInfo.tag = "v"; nodeInfo.tag = "v";
nodeInfo.logFreq = -8.87033; nodeInfo.weight = -8.87033;
s1 << nodeInfo; s1 << nodeInfo;
s2 << (*trie.find(uni.begin(), uni.end())); s2 << (*trie.find(uni.begin(), uni.end()));
EXPECT_EQ("[\"26469\", \"21040\"] 8779 v -8.870", s2); EXPECT_EQ("[\"26469\", \"21040\"] v -8.870", s2);
word = "清华大学"; word = "清华大学";
vector<pair<size_t, const DictUnit*> > res; vector<pair<size_t, const DictUnit*> > res;
map<size_t, const DictUnit* > resMap; map<size_t, const DictUnit* > resMap;