From dc96bb3795aa33e2ef583292e5e953d3dc3408aa Mon Sep 17 00:00:00 2001 From: wyy Date: Fri, 25 Apr 2014 17:29:42 +0800 Subject: [PATCH] add userdict loader --- src/DictTrie.hpp | 94 +++++++++++++++++++++++------------------ src/MPSegment.hpp | 4 +- test/unittest/TTrie.cpp | 7 ++- 3 files changed, 57 insertions(+), 48 deletions(-) diff --git a/src/DictTrie.hpp b/src/DictTrie.hpp index 6d9ed04..7bce4ae 100644 --- a/src/DictTrie.hpp +++ b/src/DictTrie.hpp @@ -22,21 +22,21 @@ namespace CppJieba const double MIN_DOUBLE = -3.14e+100; const double MAX_DOUBLE = 3.14e+100; const size_t DICT_COLUMN_NUM = 3; + const char* const UNKNOWN_TAG = "x"; struct DictUnit { Unicode word; - size_t freq; + double weight; string tag; - double logFreq; //logFreq = log(freq/sum(freq)); }; inline ostream & operator << (ostream& os, const DictUnit& unit) { string s; s << unit.word; - return os << string_format("%s %u %s %.3lf", s.c_str(), unit.freq, unit.tag.c_str(), unit.logFreq); + return os << string_format("%s %s %.3lf", s.c_str(), unit.tag.c_str(), unit.weight); } typedef map DagType; @@ -49,15 +49,15 @@ namespace CppJieba vector _nodeInfos; TrieType * _trie; - size_t _freqSum; - double _minLogFreq; + double _minWeight; + public: + double getMinWeight() const {return _minWeight;}; public: DictTrie() { _trie = NULL; - _freqSum = 0; - _minLogFreq = MAX_DOUBLE; + _minWeight = MAX_DOUBLE; _setInitFlag(false); } DictTrie(const string& filePath) @@ -72,20 +72,22 @@ namespace CppJieba delete _trie; } } - private: - public: - bool init(const string& filePath) + bool init(const string& dictPath, const string& userDictPath = "") { assert(!_getInitFlag()); - _loadDict(filePath, _nodeInfos); + _loadDict(dictPath, _nodeInfos); + _calculateWeight(_nodeInfos); + _minWeight = _findMinWeight(_nodeInfos); + if(userDictPath.size()) + { + _loadUserDict(dictPath, _minWeight, UNKNOWN_TAG, _nodeInfos); + } _shrink(_nodeInfos); - _freqSum = _calculateFreqSum(_nodeInfos); - assert(_freqSum); - _minLogFreq = _calculateLogFreqAndGetMinValue(_nodeInfos, _freqSum); _trie = _creatTrie(_nodeInfos); - return _setInitFlag(_trie); + assert(_trie); + return _setInitFlag(true); } public: @@ -98,16 +100,11 @@ namespace CppJieba return _trie->find(begin, end, dag, offset); } - public: - double getMinLogFreq() const {return _minLogFreq;}; private: TrieType * _creatTrie(const vector& dictUnits) { - if(dictUnits.empty()) - { - return NULL; - } + assert(dictUnits.size()); vector words; vector valuePointers; for(size_t i = 0 ; i < dictUnits.size(); i ++) @@ -119,18 +116,31 @@ namespace CppJieba TrieType * trie = new TrieType(words, valuePointers); return trie; } + void _loadUserDict(const string& filePath, double defaultWeight, const string& defaultTag, vector& nodeInfos) const + { + ifstream ifs(filePath.c_str()); + assert(ifs); + string line; + DictUnit nodeInfo; + for(size_t lineno = 0; getline(ifs, line); lineno++) + { + if(!TransCode::decode(line, nodeInfo.word)) + { + LogError("line[%u:%s] illegal.", lineno, line.c_str()); + continue; + } + nodeInfo.weight = defaultWeight; + nodeInfo.tag = defaultTag; + nodeInfos.push_back(nodeInfo); + } + } void _loadDict(const string& filePath, vector& nodeInfos) const { ifstream ifs(filePath.c_str()); - if(!ifs) - { - LogFatal("open %s failed.", filePath.c_str()); - exit(1); - } + assert(ifs); string line; vector buf; - nodeInfos.clear(); DictUnit nodeInfo; for(size_t lineno = 0 ; getline(ifs, line); lineno++) { @@ -142,36 +152,36 @@ namespace CppJieba LogError("line[%u:%s] illegal.", lineno, line.c_str()); continue; } - nodeInfo.freq = atoi(buf[1].c_str()); + nodeInfo.weight = atof(buf[1].c_str()); nodeInfo.tag = buf[2]; nodeInfos.push_back(nodeInfo); } } - size_t _calculateFreqSum(const vector& nodeInfos) const + double _findMinWeight(const vector& nodeInfos) const { - size_t freqSum = 0; + double ret = MAX_DOUBLE; for(size_t i = 0; i < nodeInfos.size(); i++) { - freqSum += nodeInfos[i].freq; + ret = min(nodeInfos[i].weight, ret); } - return freqSum; + return ret; } - double _calculateLogFreqAndGetMinValue(vector& nodeInfos, size_t freqSum) const + + void _calculateWeight(vector& nodeInfos) const { - assert(freqSum); - double minLogFreq = MAX_DOUBLE; + double sum = 0.0; + for(size_t i = 0; i < nodeInfos.size(); i++) + { + sum += nodeInfos[i].weight; + } + assert(sum); for(size_t i = 0; i < nodeInfos.size(); i++) { DictUnit& nodeInfo = nodeInfos[i]; - assert(nodeInfo.freq); - nodeInfo.logFreq = log(double(nodeInfo.freq)/double(freqSum)); - if(minLogFreq > nodeInfo.logFreq) - { - minLogFreq = nodeInfo.logFreq; - } + assert(nodeInfo.weight); + nodeInfo.weight = log(double(nodeInfo.weight)/double(sum)); } - return minLogFreq; } void _shrink(vector& units) const diff --git a/src/MPSegment.hpp b/src/MPSegment.hpp index e266f02..5c43ffc 100644 --- a/src/MPSegment.hpp +++ b/src/MPSegment.hpp @@ -160,11 +160,11 @@ namespace CppJieba if(p) { - val += p->logFreq; + val += p->weight; } else { - val += _dictTrie.getMinLogFreq(); + val += _dictTrie.getMinWeight(); } if(val > SegmentChars[i].weight) { diff --git a/test/unittest/TTrie.cpp b/test/unittest/TTrie.cpp index 5ae91b2..d2c68ae 100644 --- a/test/unittest/TTrie.cpp +++ b/test/unittest/TTrie.cpp @@ -20,19 +20,18 @@ TEST(DictTrieTest, Test1) string s1, s2; DictTrie trie; ASSERT_TRUE(trie.init(DICT_FILE)); - ASSERT_LT(trie.getMinLogFreq() + 15.6479, 0.001); + ASSERT_LT(trie.getMinWeight() + 15.6479, 0.001); string word("来到"); Unicode uni; ASSERT_TRUE(TransCode::decode(word, uni)); DictUnit nodeInfo; nodeInfo.word = uni; - nodeInfo.freq = 8779; nodeInfo.tag = "v"; - nodeInfo.logFreq = -8.87033; + nodeInfo.weight = -8.87033; s1 << nodeInfo; s2 << (*trie.find(uni.begin(), uni.end())); - EXPECT_EQ("[\"26469\", \"21040\"] 8779 v -8.870", s2); + EXPECT_EQ("[\"26469\", \"21040\"] v -8.870", s2); word = "清华大学"; vector > res; map resMap;