mirror of
https://github.com/yanyiwu/cppjieba.git
synced 2025-07-18 00:00:12 +08:00
add userdict loader
This commit is contained in:
parent
2f314ffdb1
commit
dc96bb3795
@ -22,21 +22,21 @@ namespace CppJieba
|
||||
const double MIN_DOUBLE = -3.14e+100;
|
||||
const double MAX_DOUBLE = 3.14e+100;
|
||||
const size_t DICT_COLUMN_NUM = 3;
|
||||
const char* const UNKNOWN_TAG = "x";
|
||||
|
||||
|
||||
struct DictUnit
|
||||
{
|
||||
Unicode word;
|
||||
size_t freq;
|
||||
double weight;
|
||||
string tag;
|
||||
double logFreq; //logFreq = log(freq/sum(freq));
|
||||
};
|
||||
|
||||
inline ostream & operator << (ostream& os, const DictUnit& unit)
|
||||
{
|
||||
string s;
|
||||
s << unit.word;
|
||||
return os << string_format("%s %u %s %.3lf", s.c_str(), unit.freq, unit.tag.c_str(), unit.logFreq);
|
||||
return os << string_format("%s %s %.3lf", s.c_str(), unit.tag.c_str(), unit.weight);
|
||||
}
|
||||
|
||||
typedef map<size_t, const DictUnit*> DagType;
|
||||
@ -49,15 +49,15 @@ namespace CppJieba
|
||||
vector<DictUnit> _nodeInfos;
|
||||
TrieType * _trie;
|
||||
|
||||
size_t _freqSum;
|
||||
double _minLogFreq;
|
||||
double _minWeight;
|
||||
public:
|
||||
double getMinWeight() const {return _minWeight;};
|
||||
|
||||
public:
|
||||
DictTrie()
|
||||
{
|
||||
_trie = NULL;
|
||||
_freqSum = 0;
|
||||
_minLogFreq = MAX_DOUBLE;
|
||||
_minWeight = MAX_DOUBLE;
|
||||
_setInitFlag(false);
|
||||
}
|
||||
DictTrie(const string& filePath)
|
||||
@ -72,20 +72,22 @@ namespace CppJieba
|
||||
delete _trie;
|
||||
}
|
||||
}
|
||||
private:
|
||||
|
||||
|
||||
public:
|
||||
bool init(const string& filePath)
|
||||
bool init(const string& dictPath, const string& userDictPath = "")
|
||||
{
|
||||
assert(!_getInitFlag());
|
||||
_loadDict(filePath, _nodeInfos);
|
||||
_loadDict(dictPath, _nodeInfos);
|
||||
_calculateWeight(_nodeInfos);
|
||||
_minWeight = _findMinWeight(_nodeInfos);
|
||||
if(userDictPath.size())
|
||||
{
|
||||
_loadUserDict(dictPath, _minWeight, UNKNOWN_TAG, _nodeInfos);
|
||||
}
|
||||
_shrink(_nodeInfos);
|
||||
_freqSum = _calculateFreqSum(_nodeInfos);
|
||||
assert(_freqSum);
|
||||
_minLogFreq = _calculateLogFreqAndGetMinValue(_nodeInfos, _freqSum);
|
||||
_trie = _creatTrie(_nodeInfos);
|
||||
return _setInitFlag(_trie);
|
||||
assert(_trie);
|
||||
return _setInitFlag(true);
|
||||
}
|
||||
|
||||
public:
|
||||
@ -98,16 +100,11 @@ namespace CppJieba
|
||||
return _trie->find(begin, end, dag, offset);
|
||||
}
|
||||
|
||||
public:
|
||||
double getMinLogFreq() const {return _minLogFreq;};
|
||||
|
||||
private:
|
||||
TrieType * _creatTrie(const vector<DictUnit>& dictUnits)
|
||||
{
|
||||
if(dictUnits.empty())
|
||||
{
|
||||
return NULL;
|
||||
}
|
||||
assert(dictUnits.size());
|
||||
vector<Unicode> words;
|
||||
vector<const DictUnit*> valuePointers;
|
||||
for(size_t i = 0 ; i < dictUnits.size(); i ++)
|
||||
@ -119,18 +116,31 @@ namespace CppJieba
|
||||
TrieType * trie = new TrieType(words, valuePointers);
|
||||
return trie;
|
||||
}
|
||||
void _loadUserDict(const string& filePath, double defaultWeight, const string& defaultTag, vector<DictUnit>& nodeInfos) const
|
||||
{
|
||||
ifstream ifs(filePath.c_str());
|
||||
assert(ifs);
|
||||
string line;
|
||||
DictUnit nodeInfo;
|
||||
for(size_t lineno = 0; getline(ifs, line); lineno++)
|
||||
{
|
||||
if(!TransCode::decode(line, nodeInfo.word))
|
||||
{
|
||||
LogError("line[%u:%s] illegal.", lineno, line.c_str());
|
||||
continue;
|
||||
}
|
||||
nodeInfo.weight = defaultWeight;
|
||||
nodeInfo.tag = defaultTag;
|
||||
nodeInfos.push_back(nodeInfo);
|
||||
}
|
||||
}
|
||||
void _loadDict(const string& filePath, vector<DictUnit>& nodeInfos) const
|
||||
{
|
||||
ifstream ifs(filePath.c_str());
|
||||
if(!ifs)
|
||||
{
|
||||
LogFatal("open %s failed.", filePath.c_str());
|
||||
exit(1);
|
||||
}
|
||||
assert(ifs);
|
||||
string line;
|
||||
vector<string> buf;
|
||||
|
||||
nodeInfos.clear();
|
||||
DictUnit nodeInfo;
|
||||
for(size_t lineno = 0 ; getline(ifs, line); lineno++)
|
||||
{
|
||||
@ -142,36 +152,36 @@ namespace CppJieba
|
||||
LogError("line[%u:%s] illegal.", lineno, line.c_str());
|
||||
continue;
|
||||
}
|
||||
nodeInfo.freq = atoi(buf[1].c_str());
|
||||
nodeInfo.weight = atof(buf[1].c_str());
|
||||
nodeInfo.tag = buf[2];
|
||||
|
||||
nodeInfos.push_back(nodeInfo);
|
||||
}
|
||||
}
|
||||
size_t _calculateFreqSum(const vector<DictUnit>& nodeInfos) const
|
||||
double _findMinWeight(const vector<DictUnit>& nodeInfos) const
|
||||
{
|
||||
size_t freqSum = 0;
|
||||
double ret = MAX_DOUBLE;
|
||||
for(size_t i = 0; i < nodeInfos.size(); i++)
|
||||
{
|
||||
freqSum += nodeInfos[i].freq;
|
||||
ret = min(nodeInfos[i].weight, ret);
|
||||
}
|
||||
return freqSum;
|
||||
return ret;
|
||||
}
|
||||
double _calculateLogFreqAndGetMinValue(vector<DictUnit>& nodeInfos, size_t freqSum) const
|
||||
|
||||
void _calculateWeight(vector<DictUnit>& nodeInfos) const
|
||||
{
|
||||
assert(freqSum);
|
||||
double minLogFreq = MAX_DOUBLE;
|
||||
double sum = 0.0;
|
||||
for(size_t i = 0; i < nodeInfos.size(); i++)
|
||||
{
|
||||
sum += nodeInfos[i].weight;
|
||||
}
|
||||
assert(sum);
|
||||
for(size_t i = 0; i < nodeInfos.size(); i++)
|
||||
{
|
||||
DictUnit& nodeInfo = nodeInfos[i];
|
||||
assert(nodeInfo.freq);
|
||||
nodeInfo.logFreq = log(double(nodeInfo.freq)/double(freqSum));
|
||||
if(minLogFreq > nodeInfo.logFreq)
|
||||
{
|
||||
minLogFreq = nodeInfo.logFreq;
|
||||
}
|
||||
assert(nodeInfo.weight);
|
||||
nodeInfo.weight = log(double(nodeInfo.weight)/double(sum));
|
||||
}
|
||||
return minLogFreq;
|
||||
}
|
||||
|
||||
void _shrink(vector<DictUnit>& units) const
|
||||
|
@ -160,11 +160,11 @@ namespace CppJieba
|
||||
|
||||
if(p)
|
||||
{
|
||||
val += p->logFreq;
|
||||
val += p->weight;
|
||||
}
|
||||
else
|
||||
{
|
||||
val += _dictTrie.getMinLogFreq();
|
||||
val += _dictTrie.getMinWeight();
|
||||
}
|
||||
if(val > SegmentChars[i].weight)
|
||||
{
|
||||
|
@ -20,19 +20,18 @@ TEST(DictTrieTest, Test1)
|
||||
string s1, s2;
|
||||
DictTrie trie;
|
||||
ASSERT_TRUE(trie.init(DICT_FILE));
|
||||
ASSERT_LT(trie.getMinLogFreq() + 15.6479, 0.001);
|
||||
ASSERT_LT(trie.getMinWeight() + 15.6479, 0.001);
|
||||
string word("来到");
|
||||
Unicode uni;
|
||||
ASSERT_TRUE(TransCode::decode(word, uni));
|
||||
DictUnit nodeInfo;
|
||||
nodeInfo.word = uni;
|
||||
nodeInfo.freq = 8779;
|
||||
nodeInfo.tag = "v";
|
||||
nodeInfo.logFreq = -8.87033;
|
||||
nodeInfo.weight = -8.87033;
|
||||
s1 << nodeInfo;
|
||||
s2 << (*trie.find(uni.begin(), uni.end()));
|
||||
|
||||
EXPECT_EQ("[\"26469\", \"21040\"] 8779 v -8.870", s2);
|
||||
EXPECT_EQ("[\"26469\", \"21040\"] v -8.870", s2);
|
||||
word = "清华大学";
|
||||
vector<pair<size_t, const DictUnit*> > res;
|
||||
map<size_t, const DictUnit* > resMap;
|
||||
|
Loading…
x
Reference in New Issue
Block a user