diff --git a/server/server.cpp b/server/server.cpp
index 8db24d7..93a1e2a 100644
--- a/server/server.cpp
+++ b/server/server.cpp
@@ -12,80 +12,70 @@
 using namespace Husky;
 using namespace CppJieba;
-class ReqHandler: public IRequestHandler
-{
-    public:
-        ReqHandler(const string& dictPath, const string& modelPath, const string& userDictPath): _segment(dictPath, modelPath, userDictPath){};
-        virtual ~ReqHandler(){};
-    public:
-        virtual bool do_GET(const HttpReqInfo& httpReq, string& strSnd) const
-        {
-            string sentence, tmp;
-            vector<string> words;
-            httpReq.GET("key", tmp);
-            URLDecode(tmp, sentence);
-            _segment.cut(sentence, words);
-            if(httpReq.GET("format", tmp) && tmp == "simple")
-            {
-                join(words.begin(), words.end(), strSnd, " ");
-                return true;
-            }
-            strSnd << words;
-            return true;
-        }
-        virtual bool do_POST(const HttpReqInfo& httpReq, string& strSnd) const
-        {
-            vector<string> words;
-            _segment.cut(httpReq.getBody(), words);
-            strSnd << words;
-            return true;
-        }
-    private:
-        MixSegment _segment;
+class ReqHandler: public IRequestHandler {
+ public:
+  ReqHandler(const string& dictPath, const string& modelPath, const string& userDictPath): _segment(dictPath, modelPath, userDictPath) {};
+  virtual ~ReqHandler() {};
+ public:
+  virtual bool do_GET(const HttpReqInfo& httpReq, string& strSnd) const {
+    string sentence, tmp;
+    vector<string> words;
+    httpReq.GET("key", tmp);
+    URLDecode(tmp, sentence);
+    _segment.cut(sentence, words);
+    if(httpReq.GET("format", tmp) && tmp == "simple") {
+      join(words.begin(), words.end(), strSnd, " ");
+      return true;
+    }
+    strSnd << words;
+    return true;
+  }
+  virtual bool do_POST(const HttpReqInfo& httpReq, string& strSnd) const {
+    vector<string> words;
+    _segment.cut(httpReq.getBody(), words);
+    strSnd << words;
+    return true;
+  }
+ private:
+  MixSegment _segment;
 };
-bool run(int argc, char** argv)
-{
-    if(argc < 2)
-    {
-        return false;
-    }
-    Config conf(argv[1]);
-    if(!conf)
-    {
-        return false;
-    }
-    int port = 0;
-    int threadNumber = 0;
-    int queueMaxSize = 0;
-    string dictPath;
-    string modelPath;
-    string userDictPath;
-    LIMONP_CHECK(conf.get("port", port));
-    LIMONP_CHECK(conf.get("thread_number", threadNumber));
-    LIMONP_CHECK(conf.get("queue_max_size", queueMaxSize));
-    LIMONP_CHECK(conf.get("dict_path", dictPath));
-    LIMONP_CHECK(conf.get("model_path", modelPath));
-    if(!conf.get("user_dict_path", userDictPath)) //optional
-    {
-        userDictPath = "";
-    }
+bool run(int argc, char** argv) {
+  if(argc < 2) {
+    return false;
+  }
+  Config conf(argv[1]);
+  if(!conf) {
+    return false;
+  }
+  int port = 0;
+  int threadNumber = 0;
+  int queueMaxSize = 0;
+  string dictPath;
+  string modelPath;
+  string userDictPath;
+  LIMONP_CHECK(conf.get("port", port));
+  LIMONP_CHECK(conf.get("thread_number", threadNumber));
+  LIMONP_CHECK(conf.get("queue_max_size", queueMaxSize));
+  LIMONP_CHECK(conf.get("dict_path", dictPath));
+  LIMONP_CHECK(conf.get("model_path", modelPath));
+  if(!conf.get("user_dict_path", userDictPath)) { //optional
+    userDictPath = "";
+  }
-    LogInfo("config info: %s", conf.getConfigInfo().c_str());
+  LogInfo("config info: %s", conf.getConfigInfo().c_str());
-    ReqHandler reqHandler(dictPath, modelPath, userDictPath);
-    ThreadPoolServer sf(threadNumber, queueMaxSize, port, reqHandler);
-    return sf.start();
+  ReqHandler reqHandler(dictPath, modelPath, userDictPath);
+  ThreadPoolServer sf(threadNumber, queueMaxSize, port, reqHandler);
+  return sf.start();
 }
-int main(int argc, char* argv[])
-{
-    if(!run(argc, argv))
-    {
-        printf("usage: %s \n", argv[0]);
-        return EXIT_FAILURE;
-    }
-    return EXIT_SUCCESS;
+int main(int argc, char* argv[]) {
+  if(!run(argc, argv)) {
+    printf("usage: %s \n", argv[0]);
+    return EXIT_FAILURE;
+  }
+  return EXIT_SUCCESS;
 }
diff --git a/src/DictTrie.hpp b/src/DictTrie.hpp
index 2e27646..adbd67e 100644
--- a/src/DictTrie.hpp
+++ b/src/DictTrie.hpp
@@ -15,206 +15,174 @@
-namespace CppJieba
-{
-    using namespace Limonp;
-    const double MIN_DOUBLE = -3.14e+100;
-    const double MAX_DOUBLE = 3.14e+100;
-    const size_t DICT_COLUMN_NUM = 3;
-    const char* const UNKNOWN_TAG = "";
+namespace CppJieba {
+using namespace Limonp;
+const double MIN_DOUBLE = -3.14e+100;
+const double MAX_DOUBLE = 3.14e+100;
+const size_t DICT_COLUMN_NUM = 3;
+const char* const UNKNOWN_TAG = "";
-    class DictTrie
-    {
-        public:
+class DictTrie {
+ public:
-            DictTrie()
-            {
-                _trie = NULL;
-                _minWeight = MAX_DOUBLE;
-            }
-            DictTrie(const string& dictPath, const string& userDictPath = "")
-            {
-                new (this) DictTrie();
-                init(dictPath, userDictPath);
-            }
-            ~DictTrie()
-            {
-                if(_trie)
-                {
-                    delete _trie;
-                }
-            }
-
-            bool init(const string& dictPath, const string& userDictPath = "")
-            {
-                if(_trie != NULL)
-                {
-                    LogFatal("trie already initted");
-                }
-                _loadDict(dictPath);
-                _calculateWeight(_nodeInfos);
-                _minWeight = _findMinWeight(_nodeInfos);
-
-                if(userDictPath.size())
-                {
-                    double maxWeight = _findMaxWeight(_nodeInfos);
-                    _loadUserDict(userDictPath, maxWeight, UNKNOWN_TAG);
-                }
-                _shrink(_nodeInfos);
-                _trie = _createTrie(_nodeInfos);
-                assert(_trie);
-                return true;
-            }
+  DictTrie() {
+    _trie = NULL;
+    _minWeight = MAX_DOUBLE;
+  }
+  DictTrie(const string& dictPath, const string& userDictPath = "") {
+    new (this) DictTrie();
+    init(dictPath, userDictPath);
+  }
+  ~DictTrie() {
+    if(_trie) {
+      delete _trie;
+    }
+  }
-            const DictUnit* find(Unicode::const_iterator begin, Unicode::const_iterator end) const
-            {
-                return _trie->find(begin, end);
-            }
-            bool find(Unicode::const_iterator begin, Unicode::const_iterator end, DagType& dag, size_t offset = 0) const
-            {
-                return _trie->find(begin, end, dag, offset);
-            }
-            void find(
-                Unicode::const_iterator begin,
-                Unicode::const_iterator end,
-                vector<SegmentChar>& res
-            ) const
-            {
-                _trie->find(begin, end, res);
-            }
-            bool isUserDictSingleChineseWord(const Unicode::value_type& word) const
-            {
-                return isIn(_userDictSingleChineseWord, word);
-            }
-            double getMinWeight() const {return _minWeight;};
+  bool init(const string& dictPath, const string& userDictPath = "") {
+    if(_trie != NULL) {
+      LogFatal("trie already initted");
+    }
+    _loadDict(dictPath);
+    _calculateWeight(_nodeInfos);
+    _minWeight = _findMinWeight(_nodeInfos);
+
+    if(userDictPath.size()) {
+      double maxWeight = _findMaxWeight(_nodeInfos);
+      _loadUserDict(userDictPath, maxWeight, UNKNOWN_TAG);
+    }
+    _shrink(_nodeInfos);
+    _trie = _createTrie(_nodeInfos);
+    assert(_trie);
+    return true;
+  }
+
+  const DictUnit* find(Unicode::const_iterator begin, Unicode::const_iterator end) const {
+    return _trie->find(begin, end);
+  }
+  bool find(Unicode::const_iterator begin, Unicode::const_iterator end, DagType& dag, size_t offset = 0) const {
+    return _trie->find(begin, end, dag, offset);
+  }
+  void find(
+    Unicode::const_iterator begin,
+    Unicode::const_iterator end,
+    vector<SegmentChar>& res
+  ) const {
+    _trie->find(begin, end, res);
+  }
+  bool isUserDictSingleChineseWord(const Unicode::value_type& word) const {
+    return isIn(_userDictSingleChineseWord, word);
+  }
+  double getMinWeight() const {
+    return _minWeight;
+  };
-        private:
-            Trie * _createTrie(const vector<DictUnit>& dictUnits)
-            {
-                assert(dictUnits.size());
-                vector<Unicode>
words; - vector valuePointers; - for(size_t i = 0 ; i < dictUnits.size(); i ++) - { - words.push_back(dictUnits[i].word); - valuePointers.push_back(&dictUnits[i]); - } + private: + Trie * _createTrie(const vector& dictUnits) { + assert(dictUnits.size()); + vector words; + vector valuePointers; + for(size_t i = 0 ; i < dictUnits.size(); i ++) { + words.push_back(dictUnits[i].word); + valuePointers.push_back(&dictUnits[i]); + } - Trie * trie = new Trie(words, valuePointers); - return trie; - } - void _loadUserDict(const string& filePath, double defaultWeight, const string& defaultTag) - { - ifstream ifs(filePath.c_str()); - if(!ifs.is_open()) - { - LogFatal("file %s open failed.", filePath.c_str()); - } - string line; - DictUnit nodeInfo; - vector buf; - size_t lineno; - for(lineno = 0; getline(ifs, line); lineno++) - { - buf.clear(); - split(line, buf, " "); - if(buf.size() < 1) - { - LogFatal("split [%s] result illegal", line.c_str()); - } - if(!TransCode::decode(buf[0], nodeInfo.word)) - { - LogError("line[%u:%s] illegal.", lineno, line.c_str()); - continue; - } - if(nodeInfo.word.size() == 1) - { - _userDictSingleChineseWord.insert(nodeInfo.word[0]); - } - nodeInfo.weight = defaultWeight; - nodeInfo.tag = (buf.size() == 2 ? buf[1] : defaultTag); - _nodeInfos.push_back(nodeInfo); - } - LogInfo("load userdict[%s] ok. lines[%u]", filePath.c_str(), lineno); - } - void _loadDict(const string& filePath) - { - ifstream ifs(filePath.c_str()); - if(!ifs.is_open()) - { - LogFatal("file %s open failed.", filePath.c_str()); - } - string line; - vector buf; + Trie * trie = new Trie(words, valuePointers); + return trie; + } + void _loadUserDict(const string& filePath, double defaultWeight, const string& defaultTag) { + ifstream ifs(filePath.c_str()); + if(!ifs.is_open()) { + LogFatal("file %s open failed.", filePath.c_str()); + } + string line; + DictUnit nodeInfo; + vector buf; + size_t lineno; + for(lineno = 0; getline(ifs, line); lineno++) { + buf.clear(); + split(line, buf, " "); + if(buf.size() < 1) { + LogFatal("split [%s] result illegal", line.c_str()); + } + if(!TransCode::decode(buf[0], nodeInfo.word)) { + LogError("line[%u:%s] illegal.", lineno, line.c_str()); + continue; + } + if(nodeInfo.word.size() == 1) { + _userDictSingleChineseWord.insert(nodeInfo.word[0]); + } + nodeInfo.weight = defaultWeight; + nodeInfo.tag = (buf.size() == 2 ? buf[1] : defaultTag); + _nodeInfos.push_back(nodeInfo); + } + LogInfo("load userdict[%s] ok. 
lines[%u]", filePath.c_str(), lineno); + } + void _loadDict(const string& filePath) { + ifstream ifs(filePath.c_str()); + if(!ifs.is_open()) { + LogFatal("file %s open failed.", filePath.c_str()); + } + string line; + vector buf; - DictUnit nodeInfo; - for(size_t lineno = 0; getline(ifs, line); lineno++) - { - split(line, buf, " "); - if(buf.size() != DICT_COLUMN_NUM) - { - LogFatal("split result illegal, line: %s, result size: %u", line.c_str(), buf.size()); - } - - if(!TransCode::decode(buf[0], nodeInfo.word)) - { - LogError("line[%u:%s] illegal.", lineno, line.c_str()); - continue; - } - nodeInfo.weight = atof(buf[1].c_str()); - nodeInfo.tag = buf[2]; - - _nodeInfos.push_back(nodeInfo); - } - } - double _findMinWeight(const vector& nodeInfos) const - { - double ret = MAX_DOUBLE; - for(size_t i = 0; i < nodeInfos.size(); i++) - { - ret = min(nodeInfos[i].weight, ret); - } - return ret; - } - double _findMaxWeight(const vector& nodeInfos) const - { - double ret = MIN_DOUBLE; - for(size_t i = 0; i < nodeInfos.size(); i++) - { - ret = max(nodeInfos[i].weight, ret); - } - return ret; - } + DictUnit nodeInfo; + for(size_t lineno = 0; getline(ifs, line); lineno++) { + split(line, buf, " "); + if(buf.size() != DICT_COLUMN_NUM) { + LogFatal("split result illegal, line: %s, result size: %u", line.c_str(), buf.size()); + } - void _calculateWeight(vector& nodeInfos) const - { - double sum = 0.0; - for(size_t i = 0; i < nodeInfos.size(); i++) - { - sum += nodeInfos[i].weight; - } - assert(sum); - for(size_t i = 0; i < nodeInfos.size(); i++) - { - DictUnit& nodeInfo = nodeInfos[i]; - assert(nodeInfo.weight); - nodeInfo.weight = log(double(nodeInfo.weight)/double(sum)); - } - } + if(!TransCode::decode(buf[0], nodeInfo.word)) { + LogError("line[%u:%s] illegal.", lineno, line.c_str()); + continue; + } + nodeInfo.weight = atof(buf[1].c_str()); + nodeInfo.tag = buf[2]; - void _shrink(vector& units) const - { - vector(units.begin(), units.end()).swap(units); - } + _nodeInfos.push_back(nodeInfo); + } + } + double _findMinWeight(const vector& nodeInfos) const { + double ret = MAX_DOUBLE; + for(size_t i = 0; i < nodeInfos.size(); i++) { + ret = min(nodeInfos[i].weight, ret); + } + return ret; + } + double _findMaxWeight(const vector& nodeInfos) const { + double ret = MIN_DOUBLE; + for(size_t i = 0; i < nodeInfos.size(); i++) { + ret = max(nodeInfos[i].weight, ret); + } + return ret; + } - private: - vector _nodeInfos; - Trie * _trie; + void _calculateWeight(vector& nodeInfos) const { + double sum = 0.0; + for(size_t i = 0; i < nodeInfos.size(); i++) { + sum += nodeInfos[i].weight; + } + assert(sum); + for(size_t i = 0; i < nodeInfos.size(); i++) { + DictUnit& nodeInfo = nodeInfos[i]; + assert(nodeInfo.weight); + nodeInfo.weight = log(double(nodeInfo.weight)/double(sum)); + } + } - double _minWeight; - unordered_set _userDictSingleChineseWord; - }; + void _shrink(vector& units) const { + vector(units.begin(), units.end()).swap(units); + } + + private: + vector _nodeInfos; + Trie * _trie; + + double _minWeight; + unordered_set _userDictSingleChineseWord; +}; } #endif diff --git a/src/FullSegment.hpp b/src/FullSegment.hpp index 0a3e747..a8b60a1 100644 --- a/src/FullSegment.hpp +++ b/src/FullSegment.hpp @@ -10,140 +10,116 @@ #include "SegmentBase.hpp" #include "TransCode.hpp" -namespace CppJieba -{ - class FullSegment: public SegmentBase - { - public: - FullSegment() - { - _dictTrie = NULL; - _isBorrowed = false; - } - explicit FullSegment(const string& dictPath) - { - _dictTrie = NULL; - init(dictPath); - } - 
explicit FullSegment(const DictTrie* dictTrie) - { - _dictTrie = NULL; - init(dictTrie); - } - virtual ~FullSegment() - { - if(_dictTrie && ! _isBorrowed) - { - delete _dictTrie; - } +namespace CppJieba { +class FullSegment: public SegmentBase { + public: + FullSegment() { + _dictTrie = NULL; + _isBorrowed = false; + } + explicit FullSegment(const string& dictPath) { + _dictTrie = NULL; + init(dictPath); + } + explicit FullSegment(const DictTrie* dictTrie) { + _dictTrie = NULL; + init(dictTrie); + } + virtual ~FullSegment() { + if(_dictTrie && ! _isBorrowed) { + delete _dictTrie; + } - }; - bool init(const string& dictPath) - { - assert(_dictTrie == NULL); - _dictTrie = new DictTrie(dictPath); - _isBorrowed = false; - return true; - } - bool init(const DictTrie* dictTrie) - { - assert(_dictTrie == NULL); - assert(dictTrie); - _dictTrie = dictTrie; - _isBorrowed = true; - return true; - } + }; + bool init(const string& dictPath) { + assert(_dictTrie == NULL); + _dictTrie = new DictTrie(dictPath); + _isBorrowed = false; + return true; + } + bool init(const DictTrie* dictTrie) { + assert(_dictTrie == NULL); + assert(dictTrie); + _dictTrie = dictTrie; + _isBorrowed = true; + return true; + } - using SegmentBase::cut; - bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector& res) const - { - assert(_dictTrie); - if (begin >= end) - { - LogError("begin >= end"); - return false; - } + using SegmentBase::cut; + bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector& res) const { + assert(_dictTrie); + if (begin >= end) { + LogError("begin >= end"); + return false; + } - //resut of searching in trie tree - DagType tRes; + //resut of searching in trie tree + DagType tRes; - //max index of res's words - int maxIdx = 0; + //max index of res's words + int maxIdx = 0; - // always equals to (uItr - begin) - int uIdx = 0; + // always equals to (uItr - begin) + int uIdx = 0; - //tmp variables - int wordLen = 0; - for (Unicode::const_iterator uItr = begin; uItr != end; uItr++) - { - //find word start from uItr - if (_dictTrie->find(uItr, end, tRes, 0)) - { - for(DagType::const_iterator itr = tRes.begin(); itr != tRes.end(); itr++) - //for (vector >::const_iterator itr = tRes.begin(); itr != tRes.end(); itr++) - { - wordLen = itr->second->word.size(); - if (wordLen >= 2 || (tRes.size() == 1 && maxIdx <= uIdx)) - { - res.push_back(itr->second->word); - } - maxIdx = uIdx+wordLen > maxIdx ? uIdx+wordLen : maxIdx; - } - tRes.clear(); - } - else // not found word start from uItr - { - if (maxIdx <= uIdx) // never exist in prev results - { - //put itr itself in res - res.push_back(Unicode(1, *uItr)); + //tmp variables + int wordLen = 0; + for (Unicode::const_iterator uItr = begin; uItr != end; uItr++) { + //find word start from uItr + if (_dictTrie->find(uItr, end, tRes, 0)) { + for(DagType::const_iterator itr = tRes.begin(); itr != tRes.end(); itr++) + //for (vector >::const_iterator itr = tRes.begin(); itr != tRes.end(); itr++) + { + wordLen = itr->second->word.size(); + if (wordLen >= 2 || (tRes.size() == 1 && maxIdx <= uIdx)) { + res.push_back(itr->second->word); + } + maxIdx = uIdx+wordLen > maxIdx ? 
uIdx+wordLen : maxIdx; + } + tRes.clear(); + } else { // not found word start from uItr + if (maxIdx <= uIdx) { // never exist in prev results + //put itr itself in res + res.push_back(Unicode(1, *uItr)); - //mark it exits - ++maxIdx; - } - } - ++uIdx; - } + //mark it exits + ++maxIdx; + } + } + ++uIdx; + } - return true; - } + return true; + } - bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector& res) const - { - assert(_dictTrie); - if (begin >= end) - { - LogError("begin >= end"); - return false; - } + bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector& res) const { + assert(_dictTrie); + if (begin >= end) { + LogError("begin >= end"); + return false; + } - vector uRes; - if (!cut(begin, end, uRes)) - { - LogError("get unicode cut result error."); - return false; - } + vector uRes; + if (!cut(begin, end, uRes)) { + LogError("get unicode cut result error."); + return false; + } - string tmp; - for (vector::const_iterator uItr = uRes.begin(); uItr != uRes.end(); uItr++) - { - if (TransCode::encode(*uItr, tmp)) - { - res.push_back(tmp); - } - else - { - LogError("encode failed."); - } - } + string tmp; + for (vector::const_iterator uItr = uRes.begin(); uItr != uRes.end(); uItr++) { + if (TransCode::encode(*uItr, tmp)) { + res.push_back(tmp); + } else { + LogError("encode failed."); + } + } - return true; - } - private: - const DictTrie* _dictTrie; - bool _isBorrowed; - }; + return true; + } + private: + const DictTrie* _dictTrie; + bool _isBorrowed; +}; } #endif diff --git a/src/HMMSegment.hpp b/src/HMMSegment.hpp index d7c8c89..d000bce 100644 --- a/src/HMMSegment.hpp +++ b/src/HMMSegment.hpp @@ -12,387 +12,315 @@ #include "SegmentBase.hpp" #include "DictTrie.hpp" -namespace CppJieba -{ - using namespace Limonp; - typedef unordered_map EmitProbMap; - class HMMSegment: public SegmentBase - { - public: - /* - * STATUS: - * 0:B, 1:E, 2:M, 3:S - * */ - enum {B = 0, E = 1, M = 2, S = 3, STATUS_SUM = 4}; +namespace CppJieba { +using namespace Limonp; +typedef unordered_map EmitProbMap; +class HMMSegment: public SegmentBase { + public: + /* + * STATUS: + * 0:B, 1:E, 2:M, 3:S + * */ + enum {B = 0, E = 1, M = 2, S = 3, STATUS_SUM = 4}; - public: - HMMSegment(){} - explicit HMMSegment(const string& filePath) - { - LIMONP_CHECK(init(filePath)); - } - virtual ~HMMSegment(){} - public: - bool init(const string& filePath) - { - memset(_startProb, 0, sizeof(_startProb)); - memset(_transProb, 0, sizeof(_transProb)); - _statMap[0] = 'B'; - _statMap[1] = 'E'; - _statMap[2] = 'M'; - _statMap[3] = 'S'; - _emitProbVec.push_back(&_emitProbB); - _emitProbVec.push_back(&_emitProbE); - _emitProbVec.push_back(&_emitProbM); - _emitProbVec.push_back(&_emitProbS); - LIMONP_CHECK(_loadModel(filePath.c_str())); - LogInfo("HMMSegment init(%s) ok.", filePath.c_str()); - return true; - } - public: - using SegmentBase::cut; - public: - bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector& res)const - { - Unicode::const_iterator left = begin; - Unicode::const_iterator right = begin; - while(right != end) - { - if(*right < 0x80) - { - if(left != right && !_cut(left, right, res)) - { - return false; - } - left = right; - do { - right = _sequentialLetterRule(left, end); - if(right != left) - { - break; - } - right = _numbersRule(left, end); - if(right != left) - { - break; - } - right ++; - } while(false); - res.push_back(Unicode(left, right)); - left = right; - } - else - { - right++; - } - } - if(left != right && !_cut(left, right, res)) - { - return 
false; - } - return true; - } - public: - virtual bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector& res)const - { - if(begin == end) - { - return false; - } - vector words; - words.reserve(end - begin); - if(!cut(begin, end, words)) - { - return false; - } - size_t offset = res.size(); - res.resize(res.size() + words.size()); - for(size_t i = 0; i < words.size(); i++) - { - if(!TransCode::encode(words[i], res[offset + i])) - { - LogError("encode failed."); - } - } - return true; - } - private: - // sequential letters rule - Unicode::const_iterator _sequentialLetterRule(Unicode::const_iterator begin, Unicode::const_iterator end) const - { - Unicode::value_type x = *begin; - if (('a' <= x && x <= 'z') || ('A' <= x && x <= 'Z')) - { - begin ++; - } - else - { - return begin; - } - while(begin != end) - { - x = *begin; - if(('a' <= x && x <= 'z') || ('A' <= x && x <= 'Z') || ('0' <= x && x <= '9')) - { - begin ++; - } - else - { - break; - } - } - return begin; - } - // - Unicode::const_iterator _numbersRule(Unicode::const_iterator begin, Unicode::const_iterator end) const - { - Unicode::value_type x = *begin; - if('0' <= x && x <= '9') - { - begin ++; - } - else - { - return begin; - } - while(begin != end) - { - x = *begin; - if( ('0' <= x && x <= '9') || x == '.') - { - begin++; - } - else - { - break; - } - } - return begin; - } - bool _cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector& res) const - { - vector status; - if(!_viterbi(begin, end, status)) - { - LogError("_viterbi failed."); - return false; - } + public: + HMMSegment() {} + explicit HMMSegment(const string& filePath) { + LIMONP_CHECK(init(filePath)); + } + virtual ~HMMSegment() {} + public: + bool init(const string& filePath) { + memset(_startProb, 0, sizeof(_startProb)); + memset(_transProb, 0, sizeof(_transProb)); + _statMap[0] = 'B'; + _statMap[1] = 'E'; + _statMap[2] = 'M'; + _statMap[3] = 'S'; + _emitProbVec.push_back(&_emitProbB); + _emitProbVec.push_back(&_emitProbE); + _emitProbVec.push_back(&_emitProbM); + _emitProbVec.push_back(&_emitProbS); + LIMONP_CHECK(_loadModel(filePath.c_str())); + LogInfo("HMMSegment init(%s) ok.", filePath.c_str()); + return true; + } + public: + using SegmentBase::cut; + public: + bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector& res)const { + Unicode::const_iterator left = begin; + Unicode::const_iterator right = begin; + while(right != end) { + if(*right < 0x80) { + if(left != right && !_cut(left, right, res)) { + return false; + } + left = right; + do { + right = _sequentialLetterRule(left, end); + if(right != left) { + break; + } + right = _numbersRule(left, end); + if(right != left) { + break; + } + right ++; + } while(false); + res.push_back(Unicode(left, right)); + left = right; + } else { + right++; + } + } + if(left != right && !_cut(left, right, res)) { + return false; + } + return true; + } + public: + virtual bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector& res)const { + if(begin == end) { + return false; + } + vector words; + words.reserve(end - begin); + if(!cut(begin, end, words)) { + return false; + } + size_t offset = res.size(); + res.resize(res.size() + words.size()); + for(size_t i = 0; i < words.size(); i++) { + if(!TransCode::encode(words[i], res[offset + i])) { + LogError("encode failed."); + } + } + return true; + } + private: + // sequential letters rule + Unicode::const_iterator _sequentialLetterRule(Unicode::const_iterator begin, Unicode::const_iterator end) 
const { + Unicode::value_type x = *begin; + if (('a' <= x && x <= 'z') || ('A' <= x && x <= 'Z')) { + begin ++; + } else { + return begin; + } + while(begin != end) { + x = *begin; + if(('a' <= x && x <= 'z') || ('A' <= x && x <= 'Z') || ('0' <= x && x <= '9')) { + begin ++; + } else { + break; + } + } + return begin; + } + // + Unicode::const_iterator _numbersRule(Unicode::const_iterator begin, Unicode::const_iterator end) const { + Unicode::value_type x = *begin; + if('0' <= x && x <= '9') { + begin ++; + } else { + return begin; + } + while(begin != end) { + x = *begin; + if( ('0' <= x && x <= '9') || x == '.') { + begin++; + } else { + break; + } + } + return begin; + } + bool _cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector& res) const { + vector status; + if(!_viterbi(begin, end, status)) { + LogError("_viterbi failed."); + return false; + } - Unicode::const_iterator left = begin; - Unicode::const_iterator right; - for(size_t i = 0; i < status.size(); i++) - { - if(status[i] % 2) //if(E == status[i] || S == status[i]) - { - right = begin + i + 1; - res.push_back(Unicode(left, right)); - left = right; - } - } - return true; - } + Unicode::const_iterator left = begin; + Unicode::const_iterator right; + for(size_t i = 0; i < status.size(); i++) { + if(status[i] % 2) { //if(E == status[i] || S == status[i]) + right = begin + i + 1; + res.push_back(Unicode(left, right)); + left = right; + } + } + return true; + } - bool _viterbi(Unicode::const_iterator begin, Unicode::const_iterator end, vector& status)const - { - if(begin == end) - { - return false; - } + bool _viterbi(Unicode::const_iterator begin, Unicode::const_iterator end, vector& status)const { + if(begin == end) { + return false; + } - size_t Y = STATUS_SUM; - size_t X = end - begin; + size_t Y = STATUS_SUM; + size_t X = end - begin; - size_t XYSize = X * Y; - size_t now, old, stat; - double tmp, endE, endS; + size_t XYSize = X * Y; + size_t now, old, stat; + double tmp, endE, endS; - vector path(XYSize); - vector weight(XYSize); + vector path(XYSize); + vector weight(XYSize); - //start - for(size_t y = 0; y < Y; y++) - { - weight[0 + y * X] = _startProb[y] + _getEmitProb(_emitProbVec[y], *begin, MIN_DOUBLE); - path[0 + y * X] = -1; - } + //start + for(size_t y = 0; y < Y; y++) { + weight[0 + y * X] = _startProb[y] + _getEmitProb(_emitProbVec[y], *begin, MIN_DOUBLE); + path[0 + y * X] = -1; + } - double emitProb; + double emitProb; - for(size_t x = 1; x < X; x++) - { - for(size_t y = 0; y < Y; y++) - { - now = x + y*X; - weight[now] = MIN_DOUBLE; - path[now] = E; // warning - emitProb = _getEmitProb(_emitProbVec[y], *(begin+x), MIN_DOUBLE); - for(size_t preY = 0; preY < Y; preY++) - { - old = x - 1 + preY * X; - tmp = weight[old] + _transProb[preY][y] + emitProb; - if(tmp > weight[now]) - { - weight[now] = tmp; - path[now] = preY; - } - } - } - } + for(size_t x = 1; x < X; x++) { + for(size_t y = 0; y < Y; y++) { + now = x + y*X; + weight[now] = MIN_DOUBLE; + path[now] = E; // warning + emitProb = _getEmitProb(_emitProbVec[y], *(begin+x), MIN_DOUBLE); + for(size_t preY = 0; preY < Y; preY++) { + old = x - 1 + preY * X; + tmp = weight[old] + _transProb[preY][y] + emitProb; + if(tmp > weight[now]) { + weight[now] = tmp; + path[now] = preY; + } + } + } + } - endE = weight[X-1+E*X]; - endS = weight[X-1+S*X]; - stat = 0; - if(endE >= endS) - { - stat = E; - } - else - { - stat = S; - } + endE = weight[X-1+E*X]; + endS = weight[X-1+S*X]; + stat = 0; + if(endE >= endS) { + stat = E; + } else { + stat = S; + } - 
status.resize(X); - for(int x = X -1 ; x >= 0; x--) - { - status[x] = stat; - stat = path[x + stat*X]; - } + status.resize(X); + for(int x = X -1 ; x >= 0; x--) { + status[x] = stat; + stat = path[x + stat*X]; + } - return true; - } - bool _loadModel(const char* const filePath) - { - ifstream ifile(filePath); - string line; - vector tmp; - vector tmp2; - //load _startProb - if(!_getLine(ifile, line)) - { - return false; - } - split(line, tmp, " "); - if(tmp.size() != STATUS_SUM) - { - LogError("start_p illegal"); - return false; - } - for(size_t j = 0; j< tmp.size(); j++) - { - _startProb[j] = atof(tmp[j].c_str()); - } + return true; + } + bool _loadModel(const char* const filePath) { + ifstream ifile(filePath); + string line; + vector tmp; + vector tmp2; + //load _startProb + if(!_getLine(ifile, line)) { + return false; + } + split(line, tmp, " "); + if(tmp.size() != STATUS_SUM) { + LogError("start_p illegal"); + return false; + } + for(size_t j = 0; j< tmp.size(); j++) { + _startProb[j] = atof(tmp[j].c_str()); + } - //load _transProb - for(size_t i = 0; i < STATUS_SUM; i++) - { - if(!_getLine(ifile, line)) - { - return false; - } - split(line, tmp, " "); - if(tmp.size() != STATUS_SUM) - { - LogError("trans_p illegal"); - return false; - } - for(size_t j =0; j < STATUS_SUM; j++) - { - _transProb[i][j] = atof(tmp[j].c_str()); - } - } + //load _transProb + for(size_t i = 0; i < STATUS_SUM; i++) { + if(!_getLine(ifile, line)) { + return false; + } + split(line, tmp, " "); + if(tmp.size() != STATUS_SUM) { + LogError("trans_p illegal"); + return false; + } + for(size_t j =0; j < STATUS_SUM; j++) { + _transProb[i][j] = atof(tmp[j].c_str()); + } + } - //load _emitProbB - if(!_getLine(ifile, line) || !_loadEmitProb(line, _emitProbB)) - { - return false; - } + //load _emitProbB + if(!_getLine(ifile, line) || !_loadEmitProb(line, _emitProbB)) { + return false; + } - //load _emitProbE - if(!_getLine(ifile, line) || !_loadEmitProb(line, _emitProbE)) - { - return false; - } + //load _emitProbE + if(!_getLine(ifile, line) || !_loadEmitProb(line, _emitProbE)) { + return false; + } - //load _emitProbM - if(!_getLine(ifile, line) || !_loadEmitProb(line, _emitProbM)) - { - return false; - } + //load _emitProbM + if(!_getLine(ifile, line) || !_loadEmitProb(line, _emitProbM)) { + return false; + } - //load _emitProbS - if(!_getLine(ifile, line) || !_loadEmitProb(line, _emitProbS)) - { - return false; - } + //load _emitProbS + if(!_getLine(ifile, line) || !_loadEmitProb(line, _emitProbS)) { + return false; + } - return true; - } - bool _getLine(ifstream& ifile, string& line) - { - while(getline(ifile, line)) - { - trim(line); - if(line.empty()) - { - continue; - } - if(startsWith(line, "#")) - { - continue; - } - return true; - } - return false; - } - bool _loadEmitProb(const string& line, EmitProbMap& mp) - { - if(line.empty()) - { - return false; - } - vector tmp, tmp2; - Unicode unicode; - split(line, tmp, ","); - for(size_t i = 0; i < tmp.size(); i++) - { - split(tmp[i], tmp2, ":"); - if(2 != tmp2.size()) - { - LogError("_emitProb illegal."); - return false; - } - if(!TransCode::decode(tmp2[0], unicode) || unicode.size() != 1) - { - LogError("TransCode failed."); - return false; - } - mp[unicode[0]] = atof(tmp2[1].c_str()); - } - return true; - } - double _getEmitProb(const EmitProbMap* ptMp, uint16_t key, double defVal)const - { - EmitProbMap::const_iterator cit = ptMp->find(key); - if(cit == ptMp->end()) - { - return defVal; - } - return cit->second; + return true; + } + bool _getLine(ifstream& ifile, 
string& line) { + while(getline(ifile, line)) { + trim(line); + if(line.empty()) { + continue; + } + if(startsWith(line, "#")) { + continue; + } + return true; + } + return false; + } + bool _loadEmitProb(const string& line, EmitProbMap& mp) { + if(line.empty()) { + return false; + } + vector tmp, tmp2; + Unicode unicode; + split(line, tmp, ","); + for(size_t i = 0; i < tmp.size(); i++) { + split(tmp[i], tmp2, ":"); + if(2 != tmp2.size()) { + LogError("_emitProb illegal."); + return false; + } + if(!TransCode::decode(tmp2[0], unicode) || unicode.size() != 1) { + LogError("TransCode failed."); + return false; + } + mp[unicode[0]] = atof(tmp2[1].c_str()); + } + return true; + } + double _getEmitProb(const EmitProbMap* ptMp, uint16_t key, double defVal)const { + EmitProbMap::const_iterator cit = ptMp->find(key); + if(cit == ptMp->end()) { + return defVal; + } + return cit->second; - } + } - private: - char _statMap[STATUS_SUM]; - double _startProb[STATUS_SUM]; - double _transProb[STATUS_SUM][STATUS_SUM]; - EmitProbMap _emitProbB; - EmitProbMap _emitProbE; - EmitProbMap _emitProbM; - EmitProbMap _emitProbS; - vector _emitProbVec; + private: + char _statMap[STATUS_SUM]; + double _startProb[STATUS_SUM]; + double _transProb[STATUS_SUM][STATUS_SUM]; + EmitProbMap _emitProbB; + EmitProbMap _emitProbE; + EmitProbMap _emitProbM; + EmitProbMap _emitProbS; + vector _emitProbVec; - }; +}; } #endif diff --git a/src/ISegment.hpp b/src/ISegment.hpp index 167e2f9..4faded5 100644 --- a/src/ISegment.hpp +++ b/src/ISegment.hpp @@ -2,15 +2,13 @@ #define CPPJIEBA_SEGMENTINTERFACE_H -namespace CppJieba -{ - class ISegment - { - public: - virtual ~ISegment(){}; - virtual bool cut(Unicode::const_iterator begin , Unicode::const_iterator end, vector& res) const = 0; - virtual bool cut(const string& str, vector& res) const = 0; - }; +namespace CppJieba { +class ISegment { + public: + virtual ~ISegment() {}; + virtual bool cut(Unicode::const_iterator begin , Unicode::const_iterator end, vector& res) const = 0; + virtual bool cut(const string& str, vector& res) const = 0; +}; } #endif diff --git a/src/KeywordExtractor.hpp b/src/KeywordExtractor.hpp index a6ed647..3a4f1c5 100644 --- a/src/KeywordExtractor.hpp +++ b/src/KeywordExtractor.hpp @@ -5,160 +5,134 @@ #include #include -namespace CppJieba -{ - using namespace Limonp; +namespace CppJieba { +using namespace Limonp; - /*utf8*/ - class KeywordExtractor - { - public: - KeywordExtractor(){}; - KeywordExtractor(const string& dictPath, const string& hmmFilePath, const string& idfPath, const string& stopWordPath, const string& userDict = "") - { - init(dictPath, hmmFilePath, idfPath, stopWordPath, userDict); - }; - ~KeywordExtractor(){}; +/*utf8*/ +class KeywordExtractor { + public: + KeywordExtractor() {}; + KeywordExtractor(const string& dictPath, const string& hmmFilePath, const string& idfPath, const string& stopWordPath, const string& userDict = "") { + init(dictPath, hmmFilePath, idfPath, stopWordPath, userDict); + }; + ~KeywordExtractor() {}; - void init(const string& dictPath, const string& hmmFilePath, const string& idfPath, const string& stopWordPath, const string& userDict = "") - { - _loadIdfDict(idfPath); - _loadStopWordDict(stopWordPath); - LIMONP_CHECK(_segment.init(dictPath, hmmFilePath, userDict)); - }; + void init(const string& dictPath, const string& hmmFilePath, const string& idfPath, const string& stopWordPath, const string& userDict = "") { + _loadIdfDict(idfPath); + _loadStopWordDict(stopWordPath); + LIMONP_CHECK(_segment.init(dictPath, hmmFilePath, 
userDict)); + }; - bool extract(const string& str, vector& keywords, size_t topN) const - { - vector > topWords; - if(!extract(str, topWords, topN)) - { - return false; - } - for(size_t i = 0; i < topWords.size(); i++) - { - keywords.push_back(topWords[i].first); - } - return true; - } + bool extract(const string& str, vector& keywords, size_t topN) const { + vector > topWords; + if(!extract(str, topWords, topN)) { + return false; + } + for(size_t i = 0; i < topWords.size(); i++) { + keywords.push_back(topWords[i].first); + } + return true; + } - bool extract(const string& str, vector >& keywords, size_t topN) const - { - vector words; - if(!_segment.cut(str, words)) - { - LogError("segment cut(%s) failed.", str.c_str()); - return false; - } + bool extract(const string& str, vector >& keywords, size_t topN) const { + vector words; + if(!_segment.cut(str, words)) { + LogError("segment cut(%s) failed.", str.c_str()); + return false; + } - map wordmap; - for(vector::iterator iter = words.begin(); iter != words.end(); iter++) - { - if(_isSingleWord(*iter)) - { - continue; - } - wordmap[*iter] += 1.0; - } + map wordmap; + for(vector::iterator iter = words.begin(); iter != words.end(); iter++) { + if(_isSingleWord(*iter)) { + continue; + } + wordmap[*iter] += 1.0; + } - for(map::iterator itr = wordmap.begin(); itr != wordmap.end(); ) - { - if(_stopWords.end() != _stopWords.find(itr->first)) - { - wordmap.erase(itr++); - continue; - } + for(map::iterator itr = wordmap.begin(); itr != wordmap.end(); ) { + if(_stopWords.end() != _stopWords.find(itr->first)) { + wordmap.erase(itr++); + continue; + } - unordered_map::const_iterator cit = _idfMap.find(itr->first); - if(cit != _idfMap.end()) - { - itr->second *= cit->second; - } - else - { - itr->second *= _idfAverage; - } - itr ++; - } + unordered_map::const_iterator cit = _idfMap.find(itr->first); + if(cit != _idfMap.end()) { + itr->second *= cit->second; + } else { + itr->second *= _idfAverage; + } + itr ++; + } - keywords.clear(); - std::copy(wordmap.begin(), wordmap.end(), std::inserter(keywords, keywords.begin())); - topN = min(topN, keywords.size()); - partial_sort(keywords.begin(), keywords.begin() + topN, keywords.end(), _cmp); - keywords.resize(topN); - return true; - } - private: - void _loadIdfDict(const string& idfPath) - { - ifstream ifs(idfPath.c_str()); - if(!ifs.is_open()) - { - LogFatal("open %s failed.", idfPath.c_str()); - } - string line ; - vector buf; - double idf = 0.0; - double idfSum = 0.0; - size_t lineno = 0; - for(;getline(ifs, line); lineno++) - { - buf.clear(); - if(line.empty()) - { - LogError("line[%d] empty. skipped.", lineno); - continue; - } - if(!split(line, buf, " ") || buf.size() != 2) - { - LogError("line %d [%s] illegal. skipped.", lineno, line.c_str()); - continue; - } - idf = atof(buf[1].c_str()); - _idfMap[buf[0]] = idf; - idfSum += idf; + keywords.clear(); + std::copy(wordmap.begin(), wordmap.end(), std::inserter(keywords, keywords.begin())); + topN = min(topN, keywords.size()); + partial_sort(keywords.begin(), keywords.begin() + topN, keywords.end(), _cmp); + keywords.resize(topN); + return true; + } + private: + void _loadIdfDict(const string& idfPath) { + ifstream ifs(idfPath.c_str()); + if(!ifs.is_open()) { + LogFatal("open %s failed.", idfPath.c_str()); + } + string line ; + vector buf; + double idf = 0.0; + double idfSum = 0.0; + size_t lineno = 0; + for(; getline(ifs, line); lineno++) { + buf.clear(); + if(line.empty()) { + LogError("line[%d] empty. 
skipped.", lineno); + continue; + } + if(!split(line, buf, " ") || buf.size() != 2) { + LogError("line %d [%s] illegal. skipped.", lineno, line.c_str()); + continue; + } + idf = atof(buf[1].c_str()); + _idfMap[buf[0]] = idf; + idfSum += idf; - } + } - assert(lineno); - _idfAverage = idfSum / lineno; - assert(_idfAverage > 0.0); - } - void _loadStopWordDict(const string& filePath) - { - ifstream ifs(filePath.c_str()); - if(!ifs.is_open()) - { - LogFatal("open %s failed.", filePath.c_str()); - } - string line ; - while(getline(ifs, line)) - { - _stopWords.insert(line); - } - assert(_stopWords.size()); - } + assert(lineno); + _idfAverage = idfSum / lineno; + assert(_idfAverage > 0.0); + } + void _loadStopWordDict(const string& filePath) { + ifstream ifs(filePath.c_str()); + if(!ifs.is_open()) { + LogFatal("open %s failed.", filePath.c_str()); + } + string line ; + while(getline(ifs, line)) { + _stopWords.insert(line); + } + assert(_stopWords.size()); + } - bool _isSingleWord(const string& str) const - { - Unicode unicode; - TransCode::decode(str, unicode); - if(unicode.size() == 1) - return true; - return false; - } + bool _isSingleWord(const string& str) const { + Unicode unicode; + TransCode::decode(str, unicode); + if(unicode.size() == 1) + return true; + return false; + } - static bool _cmp(const pair& lhs, const pair& rhs) - { - return lhs.second > rhs.second; - } - - private: - MixSegment _segment; - unordered_map _idfMap; - double _idfAverage; + static bool _cmp(const pair& lhs, const pair& rhs) { + return lhs.second > rhs.second; + } - unordered_set _stopWords; - }; + private: + MixSegment _segment; + unordered_map _idfMap; + double _idfAverage; + + unordered_set _stopWords; +}; } #endif diff --git a/src/MPSegment.hpp b/src/MPSegment.hpp index 36b756a..971da1a 100644 --- a/src/MPSegment.hpp +++ b/src/MPSegment.hpp @@ -9,140 +9,114 @@ #include "ISegment.hpp" #include "SegmentBase.hpp" -namespace CppJieba -{ +namespace CppJieba { - class MPSegment: public SegmentBase - { +class MPSegment: public SegmentBase { - public: - MPSegment(){}; - MPSegment(const string& dictPath, const string& userDictPath = "") - { - LIMONP_CHECK(init(dictPath, userDictPath)); - }; - virtual ~MPSegment(){}; + public: + MPSegment() {}; + MPSegment(const string& dictPath, const string& userDictPath = "") { + LIMONP_CHECK(init(dictPath, userDictPath)); + }; + virtual ~MPSegment() {}; - bool init(const string& dictPath, const string& userDictPath = "") - { - LIMONP_CHECK(_dictTrie.init(dictPath, userDictPath)); - LogInfo("MPSegment init(%s) ok", dictPath.c_str()); - return true; - } - bool isUserDictSingleChineseWord(const Unicode::value_type & value) const - { - return _dictTrie.isUserDictSingleChineseWord(value); - } + bool init(const string& dictPath, const string& userDictPath = "") { + LIMONP_CHECK(_dictTrie.init(dictPath, userDictPath)); + LogInfo("MPSegment init(%s) ok", dictPath.c_str()); + return true; + } + bool isUserDictSingleChineseWord(const Unicode::value_type & value) const { + return _dictTrie.isUserDictSingleChineseWord(value); + } - using SegmentBase::cut; - virtual bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector& res)const - { - if(begin == end) - { - return false; - } + using SegmentBase::cut; + virtual bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector& res)const { + if(begin == end) { + return false; + } - vector words; - words.reserve(end - begin); - if(!cut(begin, end, words)) - { - return false; - } - size_t offset = res.size(); - 
res.resize(res.size() + words.size()); - for(size_t i = 0; i < words.size(); i++) - { - if(!TransCode::encode(words[i], res[i + offset])) - { - LogError("encode failed."); - res[i + offset].clear(); - } - } - return true; - } + vector words; + words.reserve(end - begin); + if(!cut(begin, end, words)) { + return false; + } + size_t offset = res.size(); + res.resize(res.size() + words.size()); + for(size_t i = 0; i < words.size(); i++) { + if(!TransCode::encode(words[i], res[i + offset])) { + LogError("encode failed."); + res[i + offset].clear(); + } + } + return true; + } - bool cut(Unicode::const_iterator begin , Unicode::const_iterator end, vector& res) const - { - if(end == begin) - { - return false; - } - vector segmentChars; + bool cut(Unicode::const_iterator begin , Unicode::const_iterator end, vector& res) const { + if(end == begin) { + return false; + } + vector segmentChars; - _dictTrie.find(begin, end, segmentChars); + _dictTrie.find(begin, end, segmentChars); - _calcDP(segmentChars); + _calcDP(segmentChars); - _cut(segmentChars, res); + _cut(segmentChars, res); - return true; - } - const DictTrie* getDictTrie() const - { - return &_dictTrie; - } + return true; + } + const DictTrie* getDictTrie() const { + return &_dictTrie; + } - private: - void _calcDP(vector& segmentChars) const - { - size_t nextPos; - const DictUnit* p; - double val; + private: + void _calcDP(vector& segmentChars) const { + size_t nextPos; + const DictUnit* p; + double val; - for(ssize_t i = segmentChars.size() - 1; i >= 0; i--) - { - segmentChars[i].pInfo = NULL; - segmentChars[i].weight = MIN_DOUBLE; - assert(!segmentChars[i].dag.empty()); - for(DagType::const_iterator it = segmentChars[i].dag.begin(); it != segmentChars[i].dag.end(); it++) - { - nextPos = it->first; - p = it->second; - val = 0.0; - if(nextPos + 1 < segmentChars.size()) - { - val += segmentChars[nextPos + 1].weight; - } + for(ssize_t i = segmentChars.size() - 1; i >= 0; i--) { + segmentChars[i].pInfo = NULL; + segmentChars[i].weight = MIN_DOUBLE; + assert(!segmentChars[i].dag.empty()); + for(DagType::const_iterator it = segmentChars[i].dag.begin(); it != segmentChars[i].dag.end(); it++) { + nextPos = it->first; + p = it->second; + val = 0.0; + if(nextPos + 1 < segmentChars.size()) { + val += segmentChars[nextPos + 1].weight; + } - if(p) - { - val += p->weight; - } - else - { - val += _dictTrie.getMinWeight(); - } - if(val > segmentChars[i].weight) - { - segmentChars[i].pInfo = p; - segmentChars[i].weight = val; - } - } - } - } - void _cut(const vector& segmentChars, vector& res) const - { - size_t i = 0; - while(i < segmentChars.size()) - { - const DictUnit* p = segmentChars[i].pInfo; - if(p) - { - res.push_back(p->word); - i += p->word.size(); - } - else//single chinese word - { - res.push_back(Unicode(1, segmentChars[i].uniCh)); - i++; - } - } - } + if(p) { + val += p->weight; + } else { + val += _dictTrie.getMinWeight(); + } + if(val > segmentChars[i].weight) { + segmentChars[i].pInfo = p; + segmentChars[i].weight = val; + } + } + } + } + void _cut(const vector& segmentChars, vector& res) const { + size_t i = 0; + while(i < segmentChars.size()) { + const DictUnit* p = segmentChars[i].pInfo; + if(p) { + res.push_back(p->word); + i += p->word.size(); + } else { //single chinese word + res.push_back(Unicode(1, segmentChars[i].uniCh)); + i++; + } + } + } - private: - DictTrie _dictTrie; + private: + DictTrie _dictTrie; - }; +}; } #endif diff --git a/src/MixSegment.hpp b/src/MixSegment.hpp index 80e6615..2cc5a53 100644 --- a/src/MixSegment.hpp 
+++ b/src/MixSegment.hpp @@ -6,117 +6,98 @@ #include "HMMSegment.hpp" #include "Limonp/StringUtil.hpp" -namespace CppJieba -{ - class MixSegment: public SegmentBase - { - public: - MixSegment() - { - } - MixSegment(const string& mpSegDict, const string& hmmSegDict, const string& userDict = "") - { - LIMONP_CHECK(init(mpSegDict, hmmSegDict, userDict)); - } - virtual ~MixSegment() - { - } - bool init(const string& mpSegDict, const string& hmmSegDict, const string& userDict = "") - { - LIMONP_CHECK(_mpSeg.init(mpSegDict, userDict)); - LIMONP_CHECK(_hmmSeg.init(hmmSegDict)); - LogInfo("MixSegment init(%s, %s)", mpSegDict.c_str(), hmmSegDict.c_str()); - return true; - } - using SegmentBase::cut; - virtual bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector& res) const - { - vector words; - words.reserve(end - begin); - if(!_mpSeg.cut(begin, end, words)) - { - LogError("mpSeg cutDAG failed."); - return false; - } +namespace CppJieba { +class MixSegment: public SegmentBase { + public: + MixSegment() { + } + MixSegment(const string& mpSegDict, const string& hmmSegDict, const string& userDict = "") { + LIMONP_CHECK(init(mpSegDict, hmmSegDict, userDict)); + } + virtual ~MixSegment() { + } + bool init(const string& mpSegDict, const string& hmmSegDict, const string& userDict = "") { + LIMONP_CHECK(_mpSeg.init(mpSegDict, userDict)); + LIMONP_CHECK(_hmmSeg.init(hmmSegDict)); + LogInfo("MixSegment init(%s, %s)", mpSegDict.c_str(), hmmSegDict.c_str()); + return true; + } + using SegmentBase::cut; + virtual bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector& res) const { + vector words; + words.reserve(end - begin); + if(!_mpSeg.cut(begin, end, words)) { + LogError("mpSeg cutDAG failed."); + return false; + } - vector hmmRes; - hmmRes.reserve(end - begin); - Unicode piece; - piece.reserve(end - begin); - for (size_t i = 0, j = 0; i < words.size(); i++) - { - //if mp get a word, it's ok, put it into result - if (1 != words[i].size() || (words[i].size() == 1 && _mpSeg.isUserDictSingleChineseWord(words[i][0]))) - { - res.push_back(words[i]); - continue; - } + vector hmmRes; + hmmRes.reserve(end - begin); + Unicode piece; + piece.reserve(end - begin); + for (size_t i = 0, j = 0; i < words.size(); i++) { + //if mp get a word, it's ok, put it into result + if (1 != words[i].size() || (words[i].size() == 1 && _mpSeg.isUserDictSingleChineseWord(words[i][0]))) { + res.push_back(words[i]); + continue; + } - // if mp get a single one and it is not in userdict, collect it in sequence - j = i; - while (j < words.size() && 1 == words[j].size() && !_mpSeg.isUserDictSingleChineseWord(words[j][0])) - { - piece.push_back(words[j][0]); - j++; - } + // if mp get a single one and it is not in userdict, collect it in sequence + j = i; + while (j < words.size() && 1 == words[j].size() && !_mpSeg.isUserDictSingleChineseWord(words[j][0])) { + piece.push_back(words[j][0]); + j++; + } - // cut the sequence with hmm - if (!_hmmSeg.cut(piece.begin(), piece.end(), hmmRes)) - { - LogError("_hmmSeg cut failed."); - return false; - } + // cut the sequence with hmm + if (!_hmmSeg.cut(piece.begin(), piece.end(), hmmRes)) { + LogError("_hmmSeg cut failed."); + return false; + } - //put hmm result to result - for (size_t k = 0; k < hmmRes.size(); k++) - { - res.push_back(hmmRes[k]); - } + //put hmm result to result + for (size_t k = 0; k < hmmRes.size(); k++) { + res.push_back(hmmRes[k]); + } - //clear tmp vars - piece.clear(); - hmmRes.clear(); + //clear tmp vars + piece.clear(); + 
hmmRes.clear(); - //let i jump over this piece - i = j - 1; - } - return true; - } + //let i jump over this piece + i = j - 1; + } + return true; + } - virtual bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector& res)const - { - if(begin == end) - { - return false; - } + virtual bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector& res)const { + if(begin == end) { + return false; + } - vector uRes; - uRes.reserve(end - begin); - if (!cut(begin, end, uRes)) - { - return false; - } + vector uRes; + uRes.reserve(end - begin); + if (!cut(begin, end, uRes)) { + return false; + } - size_t offset = res.size(); - res.resize(res.size() + uRes.size()); - for(size_t i = 0; i < uRes.size(); i ++, offset++) - { - if(!TransCode::encode(uRes[i], res[offset])) - { - LogError("encode failed."); - } - } - return true; - } + size_t offset = res.size(); + res.resize(res.size() + uRes.size()); + for(size_t i = 0; i < uRes.size(); i ++, offset++) { + if(!TransCode::encode(uRes[i], res[offset])) { + LogError("encode failed."); + } + } + return true; + } - const DictTrie* getDictTrie() const - { - return _mpSeg.getDictTrie(); - } - private: - MPSegment _mpSeg; - HMMSegment _hmmSeg; - }; + const DictTrie* getDictTrie() const { + return _mpSeg.getDictTrie(); + } + private: + MPSegment _mpSeg; + HMMSegment _hmmSeg; +}; } #endif diff --git a/src/PosTagger.hpp b/src/PosTagger.hpp index 6d16695..e11f1df 100644 --- a/src/PosTagger.hpp +++ b/src/PosTagger.hpp @@ -5,106 +5,87 @@ #include "Limonp/StringUtil.hpp" #include "DictTrie.hpp" -namespace CppJieba -{ - using namespace Limonp; +namespace CppJieba { +using namespace Limonp; - static const char* const POS_M = "m"; - static const char* const POS_ENG = "eng"; - static const char* const POS_X = "x"; +static const char* const POS_M = "m"; +static const char* const POS_ENG = "eng"; +static const char* const POS_X = "x"; - class PosTagger - { - public: - PosTagger() - { - } - PosTagger( - const string& dictPath, - const string& hmmFilePath, - const string& userDictPath = "" - ) - { - init(dictPath, hmmFilePath, userDictPath); - } - ~PosTagger() - { - } - void init( - const string& dictPath, - const string& hmmFilePath, - const string& userDictPath = "" - ) - { - LIMONP_CHECK(_segment.init(dictPath, hmmFilePath, userDictPath)); - _dictTrie = _segment.getDictTrie(); - LIMONP_CHECK(_dictTrie); - }; - +class PosTagger { + public: + PosTagger() { + } + PosTagger( + const string& dictPath, + const string& hmmFilePath, + const string& userDictPath = "" + ) { + init(dictPath, hmmFilePath, userDictPath); + } + ~PosTagger() { + } + void init( + const string& dictPath, + const string& hmmFilePath, + const string& userDictPath = "" + ) { + LIMONP_CHECK(_segment.init(dictPath, hmmFilePath, userDictPath)); + _dictTrie = _segment.getDictTrie(); + LIMONP_CHECK(_dictTrie); + }; - bool tag(const string& src, vector >& res) const - { - vector cutRes; - if (!_segment.cut(src, cutRes)) - { - LogError("_mixSegment cut failed"); - return false; - } - const DictUnit *tmp = NULL; - Unicode unico; - for (vector::iterator itr = cutRes.begin(); itr != cutRes.end(); ++itr) - { - if (!TransCode::decode(*itr, unico)) - { - LogError("decode failed."); - return false; - } - tmp = _dictTrie->find(unico.begin(), unico.end()); - if(tmp == NULL || tmp->tag.empty()) - { - res.push_back(make_pair(*itr, _specialRule(unico))); - } - else - { - res.push_back(make_pair(*itr, tmp->tag)); - } - } - return !res.empty(); - } - private: - const char* _specialRule(const 
Unicode& unicode) const - { - size_t m = 0; - size_t eng = 0; - for(size_t i = 0; i < unicode.size() && eng < unicode.size() / 2; i++) - { - if(unicode[i] < 0x80) - { - eng ++; - if('0' <= unicode[i] && unicode[i] <= '9') - { - m++; - } - } - } - // ascii char is not found - if(eng == 0) - { - return POS_X; - } - // all the ascii is number char - if(m == eng) - { - return POS_M; - } - // the ascii chars contain english letter - return POS_ENG; - } - private: - MixSegment _segment; - const DictTrie * _dictTrie; - }; + bool tag(const string& src, vector >& res) const { + vector cutRes; + if (!_segment.cut(src, cutRes)) { + LogError("_mixSegment cut failed"); + return false; + } + + const DictUnit *tmp = NULL; + Unicode unico; + for (vector::iterator itr = cutRes.begin(); itr != cutRes.end(); ++itr) { + if (!TransCode::decode(*itr, unico)) { + LogError("decode failed."); + return false; + } + tmp = _dictTrie->find(unico.begin(), unico.end()); + if(tmp == NULL || tmp->tag.empty()) { + res.push_back(make_pair(*itr, _specialRule(unico))); + } else { + res.push_back(make_pair(*itr, tmp->tag)); + } + } + return !res.empty(); + } + private: + const char* _specialRule(const Unicode& unicode) const { + size_t m = 0; + size_t eng = 0; + for(size_t i = 0; i < unicode.size() && eng < unicode.size() / 2; i++) { + if(unicode[i] < 0x80) { + eng ++; + if('0' <= unicode[i] && unicode[i] <= '9') { + m++; + } + } + } + // ascii char is not found + if(eng == 0) { + return POS_X; + } + // all the ascii is number char + if(m == eng) { + return POS_M; + } + // the ascii chars contain english letter + return POS_ENG; + } + private: + MixSegment _segment; + const DictTrie * _dictTrie; +}; } #endif diff --git a/src/QuerySegment.hpp b/src/QuerySegment.hpp index 76a6c0e..c787d24 100644 --- a/src/QuerySegment.hpp +++ b/src/QuerySegment.hpp @@ -13,106 +13,86 @@ #include "TransCode.hpp" #include "DictTrie.hpp" -namespace CppJieba -{ - class QuerySegment: public SegmentBase - { - public: - QuerySegment(){}; - QuerySegment(const string& dict, const string& model, size_t maxWordLen, const string& userDict = "") - { - init(dict, model, maxWordLen, userDict); - }; - virtual ~QuerySegment(){}; - bool init(const string& dict, const string& model, size_t maxWordLen, const string& userDict = "") - { - LIMONP_CHECK(_mixSeg.init(dict, model, userDict)); - LIMONP_CHECK(_fullSeg.init(_mixSeg.getDictTrie())); - assert(maxWordLen); - _maxWordLen = maxWordLen; - return true; +namespace CppJieba { +class QuerySegment: public SegmentBase { + public: + QuerySegment() {}; + QuerySegment(const string& dict, const string& model, size_t maxWordLen, const string& userDict = "") { + init(dict, model, maxWordLen, userDict); + }; + virtual ~QuerySegment() {}; + bool init(const string& dict, const string& model, size_t maxWordLen, const string& userDict = "") { + LIMONP_CHECK(_mixSeg.init(dict, model, userDict)); + LIMONP_CHECK(_fullSeg.init(_mixSeg.getDictTrie())); + assert(maxWordLen); + _maxWordLen = maxWordLen; + return true; + } + using SegmentBase::cut; + bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector& res) const { + if (begin >= end) { + LogError("begin >= end"); + return false; + } + + //use mix cut first + vector mixRes; + if (!_mixSeg.cut(begin, end, mixRes)) { + LogError("_mixSeg cut failed."); + return false; + } + + vector fullRes; + for (vector::const_iterator mixResItr = mixRes.begin(); mixResItr != mixRes.end(); mixResItr++) { + + // if it's too long, cut with _fullSeg, put fullRes in res + if 
(mixResItr->size() > _maxWordLen) { + if (_fullSeg.cut(mixResItr->begin(), mixResItr->end(), fullRes)) { + for (vector::const_iterator fullResItr = fullRes.begin(); fullResItr != fullRes.end(); fullResItr++) { + res.push_back(*fullResItr); + } + + //clear tmp res + fullRes.clear(); } - using SegmentBase::cut; - bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector& res) const - { - if (begin >= end) - { - LogError("begin >= end"); - return false; - } + } else { // just use the mix result + res.push_back(*mixResItr); + } + } - //use mix cut first - vector mixRes; - if (!_mixSeg.cut(begin, end, mixRes)) - { - LogError("_mixSeg cut failed."); - return false; - } - - vector fullRes; - for (vector::const_iterator mixResItr = mixRes.begin(); mixResItr != mixRes.end(); mixResItr++) - { - - // if it's too long, cut with _fullSeg, put fullRes in res - if (mixResItr->size() > _maxWordLen) - { - if (_fullSeg.cut(mixResItr->begin(), mixResItr->end(), fullRes)) - { - for (vector::const_iterator fullResItr = fullRes.begin(); fullResItr != fullRes.end(); fullResItr++) - { - res.push_back(*fullResItr); - } - - //clear tmp res - fullRes.clear(); - } - } - else // just use the mix result - { - res.push_back(*mixResItr); - } - } - - return true; - } + return true; + } - bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector& res) const - { - if (begin >= end) - { - LogError("begin >= end"); - return false; - } + bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector& res) const { + if (begin >= end) { + LogError("begin >= end"); + return false; + } - vector uRes; - if (!cut(begin, end, uRes)) - { - LogError("get unicode cut result error."); - return false; - } + vector uRes; + if (!cut(begin, end, uRes)) { + LogError("get unicode cut result error."); + return false; + } - string tmp; - for (vector::const_iterator uItr = uRes.begin(); uItr != uRes.end(); uItr++) - { - if (TransCode::encode(*uItr, tmp)) - { - res.push_back(tmp); - } - else - { - LogError("encode failed."); - } - } + string tmp; + for (vector::const_iterator uItr = uRes.begin(); uItr != uRes.end(); uItr++) { + if (TransCode::encode(*uItr, tmp)) { + res.push_back(tmp); + } else { + LogError("encode failed."); + } + } - return true; - } - private: - MixSegment _mixSeg; - FullSegment _fullSeg; - size_t _maxWordLen; + return true; + } + private: + MixSegment _mixSeg; + FullSegment _fullSeg; + size_t _maxWordLen; - }; +}; } #endif diff --git a/src/SegmentBase.hpp b/src/SegmentBase.hpp index 55c881d..4288a31 100644 --- a/src/SegmentBase.hpp +++ b/src/SegmentBase.hpp @@ -9,70 +9,63 @@ #include -namespace CppJieba -{ - using namespace Limonp; +namespace CppJieba { +using namespace Limonp; - //const char* const SPECIAL_CHARS = " \t\n"; +//const char* const SPECIAL_CHARS = " \t\n"; #ifndef CPPJIEBA_GBK - const UnicodeValueType SPECIAL_SYMBOL[] = {32u, 9u, 10u, 12290u, 65292u}; +const UnicodeValueType SPECIAL_SYMBOL[] = {32u, 9u, 10u, 12290u, 65292u}; #else - const UnicodeValueType SPECIAL_SYMBOL[] = {32u, 9u, 10u}; +const UnicodeValueType SPECIAL_SYMBOL[] = {32u, 9u, 10u}; #endif - class SegmentBase: public ISegment, public NonCopyable - { - public: - SegmentBase(){_loadSpecialSymbols();}; - virtual ~SegmentBase(){}; - public: - virtual bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector& res) const = 0; - virtual bool cut(const string& str, vector& res) const - { - res.clear(); +class SegmentBase: public ISegment, public NonCopyable { + public: + SegmentBase() { + 
_loadSpecialSymbols(); + }; + virtual ~SegmentBase() {}; + public: + virtual bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector& res) const = 0; + virtual bool cut(const string& str, vector& res) const { + res.clear(); - Unicode unicode; - unicode.reserve(str.size()); + Unicode unicode; + unicode.reserve(str.size()); - TransCode::decode(str, unicode); - - Unicode::const_iterator left = unicode.begin(); - Unicode::const_iterator right; - - for(right = unicode.begin(); right != unicode.end(); right++) - { - if(isIn(_specialSymbols, *right)) - { - if(left != right) - { - cut(left, right, res); - } - res.resize(res.size() + 1); - TransCode::encode(right, right + 1, res.back()); - left = right + 1; - } - } - if(left != right) - { - cut(left, right, res); - } - - return true; - } - private: - void _loadSpecialSymbols() - { - size_t size = sizeof(SPECIAL_SYMBOL)/sizeof(*SPECIAL_SYMBOL); - for(size_t i = 0; i < size; i ++) - { - _specialSymbols.insert(SPECIAL_SYMBOL[i]); - } - assert(_specialSymbols.size()); - } - private: - unordered_set _specialSymbols; + TransCode::decode(str, unicode); - }; + Unicode::const_iterator left = unicode.begin(); + Unicode::const_iterator right; + + for(right = unicode.begin(); right != unicode.end(); right++) { + if(isIn(_specialSymbols, *right)) { + if(left != right) { + cut(left, right, res); + } + res.resize(res.size() + 1); + TransCode::encode(right, right + 1, res.back()); + left = right + 1; + } + } + if(left != right) { + cut(left, right, res); + } + + return true; + } + private: + void _loadSpecialSymbols() { + size_t size = sizeof(SPECIAL_SYMBOL)/sizeof(*SPECIAL_SYMBOL); + for(size_t i = 0; i < size; i ++) { + _specialSymbols.insert(SPECIAL_SYMBOL[i]); + } + assert(_specialSymbols.size()); + } + private: + unordered_set _specialSymbols; + +}; } #endif diff --git a/src/TransCode.hpp b/src/TransCode.hpp index 6b7c734..d46bfbb 100644 --- a/src/TransCode.hpp +++ b/src/TransCode.hpp @@ -9,55 +9,48 @@ #include "Limonp/StringUtil.hpp" #include "Limonp/LocalVector.hpp" -namespace CppJieba -{ +namespace CppJieba { - using namespace Limonp; - typedef uint16_t UnicodeValueType; - typedef Limonp::LocalVector Unicode; - namespace TransCode - { - inline bool decode(const string& str, Unicode& res) - { +using namespace Limonp; +typedef uint16_t UnicodeValueType; +typedef Limonp::LocalVector Unicode; +namespace TransCode { +inline bool decode(const string& str, Unicode& res) { #ifdef CPPJIEBA_GBK - return gbkTrans(str, res); + return gbkTrans(str, res); #else - return utf8ToUnicode(str, res); + return utf8ToUnicode(str, res); #endif - } +} - inline bool encode(Unicode::const_iterator begin, Unicode::const_iterator end, string& res) - { +inline bool encode(Unicode::const_iterator begin, Unicode::const_iterator end, string& res) { #ifdef CPPJIEBA_GBK - return gbkTrans(begin, end, res); + return gbkTrans(begin, end, res); #else - return unicodeToUtf8(begin, end, res); + return unicodeToUtf8(begin, end, res); #endif - } - - inline bool encode(const Unicode& uni, string& res) - { - return encode(uni.begin(), uni.end(), res); - } +} - // compiler is expected to optimized this function to avoid return value copy - inline string encode(Unicode::const_iterator begin, Unicode::const_iterator end) - { - string res; - res.reserve(end - begin); - encode(begin, end, res); - return res; - } +inline bool encode(const Unicode& uni, string& res) { + return encode(uni.begin(), uni.end(), res); +} - // compiler is expected to optimized this function to avoid return 
value copy - inline Unicode decode(const string& str) - { - Unicode unicode; - unicode.reserve(str.size()); - decode(str, unicode); - return unicode; - } - } +// compiler is expected to optimized this function to avoid return value copy +inline string encode(Unicode::const_iterator begin, Unicode::const_iterator end) { + string res; + res.reserve(end - begin); + encode(begin, end, res); + return res; +} + +// compiler is expected to optimized this function to avoid return value copy +inline Unicode decode(const string& str) { + Unicode unicode; + unicode.reserve(str.size()); + decode(str, unicode); + return unicode; +} +} } #endif diff --git a/src/Trie.hpp b/src/Trie.hpp index 6297443..1a35973 100644 --- a/src/Trie.hpp +++ b/src/Trie.hpp @@ -5,290 +5,241 @@ #include #include -namespace CppJieba -{ - using namespace std; +namespace CppJieba { +using namespace std; - struct DictUnit - { - Unicode word; - double weight; - string tag; - }; +struct DictUnit { + Unicode word; + double weight; + string tag; +}; - // for debugging - inline ostream & operator << (ostream& os, const DictUnit& unit) - { - string s; - s << unit.word; - return os << string_format("%s %s %.3lf", s.c_str(), unit.tag.c_str(), unit.weight); +// for debugging +inline ostream & operator << (ostream& os, const DictUnit& unit) { + string s; + s << unit.word; + return os << string_format("%s %s %.3lf", s.c_str(), unit.tag.c_str(), unit.weight); +} + +typedef LocalVector > DagType; + +struct SegmentChar { + uint16_t uniCh; + DagType dag; + const DictUnit * pInfo; + double weight; + size_t nextPos; + SegmentChar():uniCh(0), pInfo(NULL), weight(0.0), nextPos(0) { + } + ~SegmentChar() { + } +}; + +typedef Unicode::value_type TrieKey; + +class TrieNode { + public: + TrieNode(): fail(NULL), next(NULL), ptValue(NULL) { + } + const TrieNode * findNext(TrieKey key) const { + if(next == NULL) { + return NULL; } + NextMap::const_iterator iter = next->find(key); + if(iter == next->end()) { + return NULL; + } + return iter->second; + } + public: + typedef unordered_map NextMap; + TrieNode * fail; + NextMap * next; + const DictUnit * ptValue; +}; - typedef LocalVector > DagType; +class Trie { + public: + Trie(const vector& keys, const vector & valuePointers) { + _root = new TrieNode; + _createTrie(keys, valuePointers); + _build();// build automation + } + ~Trie() { + if(_root) { + _deleteNode(_root); + } + } + public: + const DictUnit* find(Unicode::const_iterator begin, Unicode::const_iterator end) const { + TrieNode::NextMap::const_iterator citer; + const TrieNode* ptNode = _root; + for(Unicode::const_iterator it = begin; it != end; it++) { + // build automation + assert(ptNode); + if(NULL == ptNode->next || ptNode->next->end() == (citer = ptNode->next->find(*it))) { + return NULL; + } + ptNode = citer->second; + } + return ptNode->ptValue; + } + // aho-corasick-automation + void find( + Unicode::const_iterator begin, + Unicode::const_iterator end, + vector& res + ) const { + res.resize(end - begin); + const TrieNode * now = _root; + const TrieNode* node; + // compiler will complain warnings if only "i < end - begin" . 
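// Note on the loop bound below: the iterator difference `end - begin` is a signed
// ptrdiff_t while `i` is a size_t, so comparing them directly draws a signed/unsigned
// comparison warning (e.g. -Wsign-compare on GCC/Clang); the size_t cast keeps the
// comparison unsigned. Each iteration feeds one character to the automaton and only
// walks back along `fail` links on a mismatch, so the scan over [begin, end) stays
// roughly linear in the input length plus the number of matches reported.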
+ for (size_t i = 0; i < size_t(end - begin); i++) { + Unicode::value_type ch = *(begin + i); + res[i].uniCh = ch; + assert(res[i].dag.empty()); + res[i].dag.push_back(pair::size_type, const DictUnit* >(i, NULL)); + bool flag = false; - struct SegmentChar - { - uint16_t uniCh; - DagType dag; - const DictUnit * pInfo; - double weight; - size_t nextPos; - SegmentChar():uniCh(0), pInfo(NULL), weight(0.0), nextPos(0) - {} - ~SegmentChar() - {} - }; + // rollback + while( now != _root ) { + node = now->findNext(ch); + if (node != NULL) { + flag = true; + break; + } else { + now = now->fail; + } + } - typedef Unicode::value_type TrieKey; - - class TrieNode - { - public: - TrieNode(): fail(NULL), next(NULL), ptValue(NULL) - {} - const TrieNode * findNext(TrieKey key) const - { - if(next == NULL) - { - return NULL; - } - NextMap::const_iterator iter = next->find(key); - if(iter == next->end()) - { - return NULL; - } - return iter->second; + if(!flag) { + node = now->findNext(ch); + } + if(node == NULL) { + now = _root; + } else { + now = node; + const TrieNode * temp = now; + while(temp != _root) { + if (temp->ptValue) { + size_t pos = i - temp->ptValue->word.size() + 1; + res[pos].dag.push_back(pair::size_type, const DictUnit* >(i, temp->ptValue)); + if(pos == i) { + res[pos].dag[0].second = temp->ptValue; } - public: - typedef unordered_map NextMap; - TrieNode * fail; - NextMap * next; - const DictUnit * ptValue; - }; + } + temp = temp->fail; + assert(temp); + } + } + } + } + bool find( + Unicode::const_iterator begin, + Unicode::const_iterator end, + DagType & res, + size_t offset = 0) const { + const TrieNode * ptNode = _root; + TrieNode::NextMap::const_iterator citer; + for(Unicode::const_iterator itr = begin; itr != end ; itr++) { + assert(ptNode); + if(NULL == ptNode->next || ptNode->next->end() == (citer = ptNode->next->find(*itr))) { + break; + } + ptNode = citer->second; + if(ptNode->ptValue) { + if(itr == begin && res.size() == 1) { // first singleword + res[0].second = ptNode->ptValue; + } else { + res.push_back(pair::size_type, const DictUnit* >(itr - begin + offset, ptNode->ptValue)); + } + } + } + return !res.empty(); + } + private: + void _build() { + queue que; + assert(_root->ptValue == NULL); + assert(_root->next); + _root->fail = NULL; + for(TrieNode::NextMap::iterator iter = _root->next->begin(); iter != _root->next->end(); iter++) { + iter->second->fail = _root; + que.push(iter->second); + } + TrieNode* back = NULL; + TrieNode::NextMap::iterator backiter; + while(!que.empty()) { + TrieNode * now = que.front(); + que.pop(); + if(now->next == NULL) { + continue; + } + for(TrieNode::NextMap::iterator iter = now->next->begin(); iter != now->next->end(); iter++) { + back = now->fail; + while(back != NULL) { + if(back->next && (backiter = back->next->find(iter->first)) != back->next->end()) { + iter->second->fail = backiter->second; + break; + } + back = back->fail; + } + if(back == NULL) { + iter->second->fail = _root; + } + que.push(iter->second); + } + } + } + void _createTrie(const vector& keys, const vector & valuePointers) { + if(valuePointers.empty() || keys.empty()) { + return; + } + assert(keys.size() == valuePointers.size()); - class Trie - { - public: - Trie(const vector& keys, const vector & valuePointers) - { - _root = new TrieNode; - _createTrie(keys, valuePointers); - _build();// build automation - } - ~Trie() - { - if(_root) - { - _deleteNode(_root); - } - } - public: - const DictUnit* find(Unicode::const_iterator begin, Unicode::const_iterator end) const - { - 
TrieNode::NextMap::const_iterator citer; - const TrieNode* ptNode = _root; - for(Unicode::const_iterator it = begin; it != end; it++) - {// build automation - assert(ptNode); - if(NULL == ptNode->next || ptNode->next->end() == (citer = ptNode->next->find(*it))) - { - return NULL; - } - ptNode = citer->second; - } - return ptNode->ptValue; - } - // aho-corasick-automation - void find( - Unicode::const_iterator begin, - Unicode::const_iterator end, - vector& res - ) const - { - res.resize(end - begin); - const TrieNode * now = _root; - const TrieNode* node; - // compiler will complain warnings if only "i < end - begin" . - for (size_t i = 0; i < size_t(end - begin); i++) - { - Unicode::value_type ch = *(begin + i); - res[i].uniCh = ch; - assert(res[i].dag.empty()); - res[i].dag.push_back(pair::size_type, const DictUnit* >(i, NULL)); - bool flag = false; + for(size_t i = 0; i < keys.size(); i++) { + _insertNode(keys[i], valuePointers[i]); + } + } + void _insertNode(const Unicode& key, const DictUnit* ptValue) { + TrieNode* ptNode = _root; - // rollback - while( now != _root ) - { - node = now->findNext(ch); - if (node != NULL) - { - flag = true; - break; - } - else - { - now = now->fail; - } - } + TrieNode::NextMap::const_iterator kmIter; - if(!flag) - { - node = now->findNext(ch); - } - if(node == NULL) - { - now = _root; - } - else - { - now = node; - const TrieNode * temp = now; - while(temp != _root) - { - if (temp->ptValue) - { - size_t pos = i - temp->ptValue->word.size() + 1; - res[pos].dag.push_back(pair::size_type, const DictUnit* >(i, temp->ptValue)); - if(pos == i) - { - res[pos].dag[0].second = temp->ptValue; - } - } - temp = temp->fail; - assert(temp); - } - } - } - } - bool find( - Unicode::const_iterator begin, - Unicode::const_iterator end, - DagType & res, - size_t offset = 0) const - { - const TrieNode * ptNode = _root; - TrieNode::NextMap::const_iterator citer; - for(Unicode::const_iterator itr = begin; itr != end ; itr++) - { - assert(ptNode); - if(NULL == ptNode->next || ptNode->next->end() == (citer = ptNode->next->find(*itr))) - { - break; - } - ptNode = citer->second; - if(ptNode->ptValue) - { - if(itr == begin && res.size() == 1) // first singleword - { - res[0].second = ptNode->ptValue; - } - else - { - res.push_back(pair::size_type, const DictUnit* >(itr - begin + offset, ptNode->ptValue)); - } - } - } - return !res.empty(); - } - private: - void _build() - { - queue que; - assert(_root->ptValue == NULL); - assert(_root->next); - _root->fail = NULL; - for(TrieNode::NextMap::iterator iter = _root->next->begin(); iter != _root->next->end(); iter++) { - iter->second->fail = _root; - que.push(iter->second); - } - TrieNode* back = NULL; - TrieNode::NextMap::iterator backiter; - while(!que.empty()) { - TrieNode * now = que.front(); - que.pop(); - if(now->next == NULL) { - continue; - } - for(TrieNode::NextMap::iterator iter = now->next->begin(); iter != now->next->end(); iter++) { - back = now->fail; - while(back != NULL) { - if(back->next && (backiter = back->next->find(iter->first)) != back->next->end()) - { - iter->second->fail = backiter->second; - break; - } - back = back->fail; - } - if(back == NULL) { - iter->second->fail = _root; - } - que.push(iter->second); - } - } - } - void _createTrie(const vector& keys, const vector & valuePointers) - { - if(valuePointers.empty() || keys.empty()) - { - return; - } - assert(keys.size() == valuePointers.size()); + for(Unicode::const_iterator citer = key.begin(); citer != key.end(); citer++) { + if(NULL == ptNode->next) { + 
ptNode->next = new TrieNode::NextMap; + } + kmIter = ptNode->next->find(*citer); + if(ptNode->next->end() == kmIter) { + TrieNode * nextNode = new TrieNode; + nextNode->next = NULL; + nextNode->ptValue = NULL; - for(size_t i = 0; i < keys.size(); i++) - { - _insertNode(keys[i], valuePointers[i]); - } - } - void _insertNode(const Unicode& key, const DictUnit* ptValue) - { - TrieNode* ptNode = _root; - - TrieNode::NextMap::const_iterator kmIter; - - for(Unicode::const_iterator citer = key.begin(); citer != key.end(); citer++) - { - if(NULL == ptNode->next) - { - ptNode->next = new TrieNode::NextMap; - } - kmIter = ptNode->next->find(*citer); - if(ptNode->next->end() == kmIter) - { - TrieNode * nextNode = new TrieNode; - nextNode->next = NULL; - nextNode->ptValue = NULL; - - (*ptNode->next)[*citer] = nextNode; - ptNode = nextNode; - } - else - { - ptNode = kmIter->second; - } - } - ptNode->ptValue = ptValue; - } - void _deleteNode(TrieNode* node) - { - if(!node) - { - return; - } - if(node->next) - { - TrieNode::NextMap::iterator it; - for(it = node->next->begin(); it != node->next->end(); it++) - { - _deleteNode(it->second); - } - delete node->next; - } - delete node; - } - private: - TrieNode* _root; - }; + (*ptNode->next)[*citer] = nextNode; + ptNode = nextNode; + } else { + ptNode = kmIter->second; + } + } + ptNode->ptValue = ptValue; + } + void _deleteNode(TrieNode* node) { + if(!node) { + return; + } + if(node->next) { + TrieNode::NextMap::iterator it; + for(it = node->next->begin(); it != node->next->end(); it++) { + _deleteNode(it->second); + } + delete node->next; + } + delete node; + } + private: + TrieNode* _root; +}; } #endif diff --git a/test/keyword_demo.cpp b/test/keyword_demo.cpp index e48fc4a..4588480 100644 --- a/test/keyword_demo.cpp +++ b/test/keyword_demo.cpp @@ -1,17 +1,16 @@ #include "../src/KeywordExtractor.hpp" using namespace CppJieba; -int main(int argc, char ** argv) -{ - KeywordExtractor extractor("../dict/jieba.dict.utf8", "../dict/hmm_model.utf8", "../dict/idf.utf8", "../dict/stop_words.utf8"); - //KeywordExtractor extractor("../dict/jieba.dict.utf8", "../dict/hmm_model.utf8", "../dict/idf.utf8", "../dict/stop_words.utf8", "../dict/user.dict.utf8"); - string s("我是拖拉机学院手扶拖拉机专业的。不用多久,我就会升职加薪,当上CEO,走上人生巅峰。"); - vector > wordweights; - vector words; - size_t topN = 5; - extractor.extract(s, wordweights, topN); - cout<< s << '\n' << wordweights << endl; - extractor.extract(s, words, topN); - cout<< s << '\n' << words << endl; - return EXIT_SUCCESS; +int main(int argc, char ** argv) { + KeywordExtractor extractor("../dict/jieba.dict.utf8", "../dict/hmm_model.utf8", "../dict/idf.utf8", "../dict/stop_words.utf8"); + //KeywordExtractor extractor("../dict/jieba.dict.utf8", "../dict/hmm_model.utf8", "../dict/idf.utf8", "../dict/stop_words.utf8", "../dict/user.dict.utf8"); + string s("我是拖拉机学院手扶拖拉机专业的。不用多久,我就会升职加薪,当上CEO,走上人生巅峰。"); + vector > wordweights; + vector words; + size_t topN = 5; + extractor.extract(s, wordweights, topN); + cout<< s << '\n' << wordweights << endl; + extractor.extract(s, words, topN); + cout<< s << '\n' << words << endl; + return EXIT_SUCCESS; } diff --git a/test/load_test.cpp b/test/load_test.cpp index 6ac3ef6..f6f9116 100644 --- a/test/load_test.cpp +++ b/test/load_test.cpp @@ -9,51 +9,46 @@ using namespace CppJieba; -void cut(size_t times = 50) -{ - MixSegment seg("../dict/jieba.dict.utf8", "../dict/hmm_model.utf8"); - vector res; - string doc; - ifstream ifs("../test/testdata/weicheng.utf8"); - assert(ifs); - doc << ifs; - long beginTime = 
clock(); - for(size_t i = 0; i < times; i ++) - { - printf("process [%3.0lf %%]\r", 100.0*(i+1)/times); - fflush(stdout); - res.clear(); - seg.cut(doc, res); - } - printf("\n"); - long endTime = clock(); - ColorPrintln(GREEN, "cut: [%.3lf seconds]time consumed.", double(endTime - beginTime)/CLOCKS_PER_SEC); +void cut(size_t times = 50) { + MixSegment seg("../dict/jieba.dict.utf8", "../dict/hmm_model.utf8"); + vector res; + string doc; + ifstream ifs("../test/testdata/weicheng.utf8"); + assert(ifs); + doc << ifs; + long beginTime = clock(); + for(size_t i = 0; i < times; i ++) { + printf("process [%3.0lf %%]\r", 100.0*(i+1)/times); + fflush(stdout); + res.clear(); + seg.cut(doc, res); + } + printf("\n"); + long endTime = clock(); + ColorPrintln(GREEN, "cut: [%.3lf seconds]time consumed.", double(endTime - beginTime)/CLOCKS_PER_SEC); } -void extract(size_t times = 400) -{ - KeywordExtractor extractor("../dict/jieba.dict.utf8", "../dict/hmm_model.utf8", "../dict/idf.utf8", "../dict/stop_words.utf8"); - vector words; - string doc; - ifstream ifs("../test/testdata/review.100"); - assert(ifs); - doc << ifs; - long beginTime = clock(); - for(size_t i = 0; i < times; i ++) - { - printf("process [%3.0lf %%]\r", 100.0*(i+1)/times); - fflush(stdout); - words.clear(); - extractor.extract(doc, words, 5); - } - printf("\n"); - long endTime = clock(); - ColorPrintln(GREEN, "extract: [%.3lf seconds]time consumed.", double(endTime - beginTime)/CLOCKS_PER_SEC); +void extract(size_t times = 400) { + KeywordExtractor extractor("../dict/jieba.dict.utf8", "../dict/hmm_model.utf8", "../dict/idf.utf8", "../dict/stop_words.utf8"); + vector words; + string doc; + ifstream ifs("../test/testdata/review.100"); + assert(ifs); + doc << ifs; + long beginTime = clock(); + for(size_t i = 0; i < times; i ++) { + printf("process [%3.0lf %%]\r", 100.0*(i+1)/times); + fflush(stdout); + words.clear(); + extractor.extract(doc, words, 5); + } + printf("\n"); + long endTime = clock(); + ColorPrintln(GREEN, "extract: [%.3lf seconds]time consumed.", double(endTime - beginTime)/CLOCKS_PER_SEC); } -int main(int argc, char ** argv) -{ - cut(); - extract(); - return EXIT_SUCCESS; +int main(int argc, char ** argv) { + cut(); + extract(); + return EXIT_SUCCESS; } diff --git a/test/segment_demo.cpp b/test/segment_demo.cpp index 36d8e5d..a803f75 100644 --- a/test/segment_demo.cpp +++ b/test/segment_demo.cpp @@ -14,46 +14,42 @@ const char * const JIEBA_DICT_FILE = "../dict/jieba.dict.utf8"; const char * const HMM_DICT_FILE = "../dict/hmm_model.utf8"; const char * const USER_DICT_FILE = "../dict/user.dict.utf8"; -void cut(const ISegment& seg, const char * const filePath) -{ - ifstream ifile(filePath); - vector words; - string line; - string res; - while(getline(ifile, line)) - { - if(!line.empty()) - { - words.clear(); - seg.cut(line, words); - join(words.begin(), words.end(), res, "/"); - cout<< res < words; + string line; + string res; + while(getline(ifile, line)) { + if(!line.empty()) { + words.clear(); + seg.cut(line, words); + join(words.begin(), words.end(), res, "/"); + cout<< res < > res; - tagger.tag(s, res); - cout << res << endl; - return EXIT_SUCCESS; +int main(int argc, char ** argv) { + PosTagger tagger("../dict/jieba.dict.utf8", "../dict/hmm_model.utf8", "../dict/user.dict.utf8"); + string s("我是蓝翔技工拖拉机学院手扶拖拉机专业的。不用多久,我就会升职加薪,当上总经理,出任CEO,迎娶白富美,走上人生巅峰。"); + vector > res; + tagger.tag(s, res); + cout << res << endl; + return EXIT_SUCCESS; } diff --git a/test/unittest/TKeywordExtractor.cpp b/test/unittest/TKeywordExtractor.cpp 
index 9ea203c..25dfaa5 100644 --- a/test/unittest/TKeywordExtractor.cpp +++ b/test/unittest/TKeywordExtractor.cpp @@ -5,52 +5,50 @@ using namespace CppJieba; -TEST(KeywordExtractorTest, Test1) -{ - KeywordExtractor extractor("../dict/extra_dict/jieba.dict.small.utf8", "../dict/hmm_model.utf8", "../dict/idf.utf8", "../dict/stop_words.utf8"); +TEST(KeywordExtractorTest, Test1) { + KeywordExtractor extractor("../dict/extra_dict/jieba.dict.small.utf8", "../dict/hmm_model.utf8", "../dict/idf.utf8", "../dict/stop_words.utf8"); - { - string s("我是拖拉机学院手扶拖拉机专业的。不用多久,我就会升职加薪,当上CEO,走上人生巅峰。"); - string res; - vector > wordweights; - size_t topN = 5; - extractor.extract(s, wordweights, topN); - res << wordweights; - ASSERT_EQ(res, "[\"CEO:11.7392\", \"升职:10.8562\", \"加薪:10.6426\", \"手扶拖拉机:10.0089\", \"巅峰:9.49396\"]"); - } + { + string s("我是拖拉机学院手扶拖拉机专业的。不用多久,我就会升职加薪,当上CEO,走上人生巅峰。"); + string res; + vector > wordweights; + size_t topN = 5; + extractor.extract(s, wordweights, topN); + res << wordweights; + ASSERT_EQ(res, "[\"CEO:11.7392\", \"升职:10.8562\", \"加薪:10.6426\", \"手扶拖拉机:10.0089\", \"巅峰:9.49396\"]"); + } - { - string s("一部iPhone6"); - string res; - vector > wordweights; - size_t topN = 5; - extractor.extract(s, wordweights, topN); - res << wordweights; - ASSERT_EQ(res, "[\"iPhone6:11.7392\", \"一部:6.47592\"]"); - } + { + string s("一部iPhone6"); + string res; + vector > wordweights; + size_t topN = 5; + extractor.extract(s, wordweights, topN); + res << wordweights; + ASSERT_EQ(res, "[\"iPhone6:11.7392\", \"一部:6.47592\"]"); + } } -TEST(KeywordExtractorTest, Test2) -{ - KeywordExtractor extractor("../dict/extra_dict/jieba.dict.small.utf8", "../dict/hmm_model.utf8", "../dict/idf.utf8", "../dict/stop_words.utf8", "../test/testdata/userdict.utf8"); +TEST(KeywordExtractorTest, Test2) { + KeywordExtractor extractor("../dict/extra_dict/jieba.dict.small.utf8", "../dict/hmm_model.utf8", "../dict/idf.utf8", "../dict/stop_words.utf8", "../test/testdata/userdict.utf8"); - { - string s("蓝翔优秀毕业生"); - string res; - vector > wordweights; - size_t topN = 5; - extractor.extract(s, wordweights, topN); - res << wordweights; - ASSERT_EQ(res, "[\"蓝翔:11.7392\", \"毕业生:8.13549\", \"优秀:6.78347\"]"); - } + { + string s("蓝翔优秀毕业生"); + string res; + vector > wordweights; + size_t topN = 5; + extractor.extract(s, wordweights, topN); + res << wordweights; + ASSERT_EQ(res, "[\"蓝翔:11.7392\", \"毕业生:8.13549\", \"优秀:6.78347\"]"); + } - { - string s("一部iPhone6"); - string res; - vector > wordweights; - size_t topN = 5; - extractor.extract(s, wordweights, topN); - res << wordweights; - ASSERT_EQ(res, "[\"iPhone6:11.7392\", \"一部:6.47592\"]"); - } + { + string s("一部iPhone6"); + string res; + vector > wordweights; + size_t topN = 5; + extractor.extract(s, wordweights, topN); + res << wordweights; + ASSERT_EQ(res, "[\"iPhone6:11.7392\", \"一部:6.47592\"]"); + } } diff --git a/test/unittest/TPosTagger.cpp b/test/unittest/TPosTagger.cpp index 89e0b37..a157a2b 100644 --- a/test/unittest/TPosTagger.cpp +++ b/test/unittest/TPosTagger.cpp @@ -12,32 +12,30 @@ static const char * const QUERY_TEST3 = "iPhone6手机的最大特点是很容 static const char * const ANS_TEST3 = "[\"iPhone6:eng\", \"手机:n\", \"的:uj\", \"最大:a\", \"特点:n\", \"是:v\", \"很:zg\", \"容易:a\", \"弯曲:v\", \"。:x\"]"; //static const char * const ANS_TEST3 = ""; -TEST(PosTaggerTest, Test) -{ - PosTagger tagger("../dict/jieba.dict.utf8", "../dict/hmm_model.utf8"); - { - vector > res; - tagger.tag(QUERY_TEST1, res); - string s; - s << res; - ASSERT_TRUE(s == ANS_TEST1); - } +TEST(PosTaggerTest, Test) { + 
PosTagger tagger("../dict/jieba.dict.utf8", "../dict/hmm_model.utf8"); + { + vector > res; + tagger.tag(QUERY_TEST1, res); + string s; + s << res; + ASSERT_TRUE(s == ANS_TEST1); + } } -TEST(PosTagger, TestUserDict) -{ - PosTagger tagger("../dict/jieba.dict.utf8", "../dict/hmm_model.utf8", "../test/testdata/userdict.utf8"); - { - vector > res; - tagger.tag(QUERY_TEST2, res); - string s; - s << res; - ASSERT_EQ(s, ANS_TEST2); - } - { - vector > res; - tagger.tag(QUERY_TEST3, res); - string s; - s << res; - ASSERT_EQ(s, ANS_TEST3); - } +TEST(PosTagger, TestUserDict) { + PosTagger tagger("../dict/jieba.dict.utf8", "../dict/hmm_model.utf8", "../test/testdata/userdict.utf8"); + { + vector > res; + tagger.tag(QUERY_TEST2, res); + string s; + s << res; + ASSERT_EQ(s, ANS_TEST2); + } + { + vector > res; + tagger.tag(QUERY_TEST3, res); + string s; + s << res; + ASSERT_EQ(s, ANS_TEST3); + } } diff --git a/test/unittest/TSegments.cpp b/test/unittest/TSegments.cpp index 54e5f64..f557a0b 100644 --- a/test/unittest/TSegments.cpp +++ b/test/unittest/TSegments.cpp @@ -9,170 +9,176 @@ using namespace CppJieba; -TEST(MixSegmentTest, Test1) -{ - MixSegment segment("../dict/jieba.dict.utf8", "../dict/hmm_model.utf8");; - const char* str = "我来自北京邮电大学。。。学号123456,用AK47"; - const char* res[] = {"我", "来自", "北京邮电大学", "。","。","。", "学号", "123456",",","用","AK47"}; - const char* str2 = "B超 T恤"; - const char* res2[] = {"B超"," ", "T恤"}; - vector words; - ASSERT_TRUE(segment.cut(str, words)); - ASSERT_EQ(words, vector(res, res + sizeof(res)/sizeof(res[0]))); - ASSERT_TRUE(segment.cut(str2, words)); - ASSERT_EQ(words, vector(res2, res2 + sizeof(res2)/sizeof(res2[0]))); +TEST(MixSegmentTest, Test1) { + MixSegment segment("../dict/jieba.dict.utf8", "../dict/hmm_model.utf8");; + const char* str = "我来自北京邮电大学。。。学号123456,用AK47"; + const char* res[] = {"我", "来自", "北京邮电大学", "。","。","。", "学号", "123456",",","用","AK47"}; + const char* str2 = "B超 T恤"; + const char* res2[] = {"B超"," ", "T恤"}; + vector words; + ASSERT_TRUE(segment.cut(str, words)); + ASSERT_EQ(words, vector(res, res + sizeof(res)/sizeof(res[0]))); + ASSERT_TRUE(segment.cut(str2, words)); + ASSERT_EQ(words, vector(res2, res2 + sizeof(res2)/sizeof(res2[0]))); } -TEST(MixSegmentTest, NoUserDict) -{ - MixSegment segment("../dict/extra_dict/jieba.dict.small.utf8", "../dict/hmm_model.utf8"); +TEST(MixSegmentTest, NoUserDict) { + MixSegment segment("../dict/extra_dict/jieba.dict.small.utf8", "../dict/hmm_model.utf8"); + const char* str = "令狐冲是云计算方面的专家"; + vector words; + ASSERT_TRUE(segment.cut(str, words)); + string res; + ASSERT_EQ("[\"令狐冲\", \"是\", \"云\", \"计算\", \"方面\", \"的\", \"专家\"]", res << words); + +} +TEST(MixSegmentTest, UserDict) { + MixSegment segment("../dict/extra_dict/jieba.dict.small.utf8", "../dict/hmm_model.utf8", "../dict/user.dict.utf8"); + { const char* str = "令狐冲是云计算方面的专家"; vector words; ASSERT_TRUE(segment.cut(str, words)); string res; - ASSERT_EQ("[\"令狐冲\", \"是\", \"云\", \"计算\", \"方面\", \"的\", \"专家\"]", res << words); - + ASSERT_EQ("[\"令狐冲\", \"是\", \"云计算\", \"方面\", \"的\", \"专家\"]", res << words); + } + { + const char* str = "小明先就职于IBM,后在日本京都大学深造"; + vector words; + ASSERT_TRUE(segment.cut(str, words)); + string res; + res << words; + ASSERT_EQ("[\"小明\", \"先\", \"就职\", \"于\", \"IBM\", \",\", \"后\", \"在\", \"日本\", \"京都大学\", \"深造\"]", res); + } + { + const char* str = "IBM,3.14"; + vector words; + ASSERT_TRUE(segment.cut(str, words)); + string res; + res << words; + ASSERT_EQ("[\"IBM\", \",\", \"3.14\"]", res); + } } -TEST(MixSegmentTest, UserDict) -{ - 
MixSegment segment("../dict/extra_dict/jieba.dict.small.utf8", "../dict/hmm_model.utf8", "../dict/user.dict.utf8"); - { - const char* str = "令狐冲是云计算方面的专家"; - vector words; - ASSERT_TRUE(segment.cut(str, words)); - string res; - ASSERT_EQ("[\"令狐冲\", \"是\", \"云计算\", \"方面\", \"的\", \"专家\"]", res << words); - } - { - const char* str = "小明先就职于IBM,后在日本京都大学深造"; - vector words; - ASSERT_TRUE(segment.cut(str, words)); - string res; - res << words; - ASSERT_EQ("[\"小明\", \"先\", \"就职\", \"于\", \"IBM\", \",\", \"后\", \"在\", \"日本\", \"京都大学\", \"深造\"]", res); - } - { - const char* str = "IBM,3.14"; - vector words; - ASSERT_TRUE(segment.cut(str, words)); - string res; - res << words; - ASSERT_EQ("[\"IBM\", \",\", \"3.14\"]", res); - } -} -TEST(MixSegmentTest, UserDict2) -{ - MixSegment segment("../dict/extra_dict/jieba.dict.small.utf8", "../dict/hmm_model.utf8", "../test/testdata/userdict.utf8"); - { - const char* str = "令狐冲是云计算方面的专家"; - vector words; - ASSERT_TRUE(segment.cut(str, words)); - string res; - ASSERT_EQ("[\"令狐冲\", \"是\", \"云计算\", \"方面\", \"的\", \"专家\"]", res << words); - } - { - const char* str = "小明先就职于IBM,后在日本京都大学深造"; - vector words; - ASSERT_TRUE(segment.cut(str, words)); - string res; - res << words; - ASSERT_EQ("[\"小明\", \"先\", \"就职\", \"于\", \"I\", \"B\", \"M\", \",\", \"后\", \"在\", \"日本\", \"京都大学\", \"深造\"]", res); - } - { - const char* str = "IBM,3.14"; - vector words; - ASSERT_TRUE(segment.cut(str, words)); - string res; - res << words; - ASSERT_EQ("[\"I\", \"B\", \"M\", \",\", \"3.14\"]", res); - } +TEST(MixSegmentTest, UserDict2) { + MixSegment segment("../dict/extra_dict/jieba.dict.small.utf8", "../dict/hmm_model.utf8", "../test/testdata/userdict.utf8"); + { + const char* str = "令狐冲是云计算方面的专家"; + vector words; + ASSERT_TRUE(segment.cut(str, words)); + string res; + ASSERT_EQ("[\"令狐冲\", \"是\", \"云计算\", \"方面\", \"的\", \"专家\"]", res << words); + } + { + const char* str = "小明先就职于IBM,后在日本京都大学深造"; + vector words; + ASSERT_TRUE(segment.cut(str, words)); + string res; + res << words; + ASSERT_EQ("[\"小明\", \"先\", \"就职\", \"于\", \"I\", \"B\", \"M\", \",\", \"后\", \"在\", \"日本\", \"京都大学\", \"深造\"]", res); + } + { + const char* str = "IBM,3.14"; + vector words; + ASSERT_TRUE(segment.cut(str, words)); + string res; + res << words; + ASSERT_EQ("[\"I\", \"B\", \"M\", \",\", \"3.14\"]", res); + } } -TEST(MPSegmentTest, Test1) -{ - MPSegment segment("../dict/jieba.dict.utf8");; - const char* str = "我来自北京邮电大学。"; - const char* res[] = {"我", "来自", "北京邮电大学", "。"}; +TEST(MPSegmentTest, Test1) { + MPSegment segment("../dict/jieba.dict.utf8");; + const char* str = "我来自北京邮电大学。"; + const char* res[] = {"我", "来自", "北京邮电大学", "。"}; + vector words; + ASSERT_TRUE(segment.cut(str, words)); + ASSERT_EQ(words, vector(res, res + sizeof(res)/sizeof(res[0]))); + + { + const char* str = "B超 T恤"; + const char * res[] = {"B超", " ", "T恤"}; vector words; ASSERT_TRUE(segment.cut(str, words)); ASSERT_EQ(words, vector(res, res + sizeof(res)/sizeof(res[0]))); - - { - const char* str = "B超 T恤"; - const char * res[] = {"B超", " ", "T恤"}; - vector words; - ASSERT_TRUE(segment.cut(str, words)); - ASSERT_EQ(words, vector(res, res + sizeof(res)/sizeof(res[0]))); - } + } } -TEST(MPSegmentTest, Test2) -{ - MPSegment segment("../dict/extra_dict/jieba.dict.small.utf8"); - string line; - ifstream ifs("../test/testdata/review.100"); - vector words; +TEST(MPSegmentTest, Test2) { + MPSegment segment("../dict/extra_dict/jieba.dict.small.utf8"); + string line; + ifstream ifs("../test/testdata/review.100"); + vector words; - string eRes; - { - 
ifstream ifs("../test/testdata/review.100.res"); - ASSERT_TRUE(!!ifs); - eRes << ifs; - } - string res; - - while(getline(ifs, line)) - { - res += line; - res += '\n'; - - segment.cut(line, words); - string s; - s << words; - res += s; - res += '\n'; - } - ofstream ofs("../test/testdata/review.100.res"); - ASSERT_TRUE(!!ofs); - ofs << res; - -} -TEST(HMMSegmentTest, Test1) -{ - HMMSegment segment("../dict/hmm_model.utf8");; - { - const char* str = "我来自北京邮电大学。。。学号123456"; - const char* res[] = {"我来", "自北京", "邮电大学", "。", "。", "。", "学号", "123456"}; - vector words; - ASSERT_TRUE(segment.cut(str, words)); - ASSERT_EQ(words, vector(res, res + sizeof(res)/sizeof(res[0]))); - } - - { - const char* str = "IBM,1.2,123"; - const char* res[] = {"IBM", ",", "1.2", ",", "123"}; - vector words; - ASSERT_TRUE(segment.cut(str, words)); - ASSERT_EQ(words, vector(res, res + sizeof(res)/sizeof(res[0]))); - } -} + string eRes; + { + ifstream ifs("../test/testdata/review.100.res"); + ASSERT_TRUE(!!ifs); + eRes << ifs; + } + string res; -TEST(FullSegment, Test1) -{ - FullSegment segment("../dict/extra_dict/jieba.dict.small.utf8"); - const char* str = "我来自北京邮电大学"; - vector words; - - ASSERT_EQ(segment.cut(str, words), true); + while(getline(ifs, line)) { + res += line; + res += '\n'; + segment.cut(line, words); string s; s << words; - ASSERT_EQ(s, "[\"我\", \"来自\", \"北京\", \"北京邮电大学\", \"邮电\", \"电大\", \"大学\"]"); + res += s; + res += '\n'; + } + ofstream ofs("../test/testdata/review.100.res"); + ASSERT_TRUE(!!ofs); + ofs << res; + +} +TEST(HMMSegmentTest, Test1) { + HMMSegment segment("../dict/hmm_model.utf8");; + { + const char* str = "我来自北京邮电大学。。。学号123456"; + const char* res[] = {"我来", "自北京", "邮电大学", "。", "。", "。", "学号", "123456"}; + vector words; + ASSERT_TRUE(segment.cut(str, words)); + ASSERT_EQ(words, vector(res, res + sizeof(res)/sizeof(res[0]))); + } + + { + const char* str = "IBM,1.2,123"; + const char* res[] = {"IBM", ",", "1.2", ",", "123"}; + vector words; + ASSERT_TRUE(segment.cut(str, words)); + ASSERT_EQ(words, vector(res, res + sizeof(res)/sizeof(res[0]))); + } } -TEST(QuerySegment, Test1) -{ - QuerySegment segment("../dict/extra_dict/jieba.dict.small.utf8", "../dict/hmm_model.utf8", 3); +TEST(FullSegment, Test1) { + FullSegment segment("../dict/extra_dict/jieba.dict.small.utf8"); + const char* str = "我来自北京邮电大学"; + vector words; + + ASSERT_EQ(segment.cut(str, words), true); + + string s; + s << words; + ASSERT_EQ(s, "[\"我\", \"来自\", \"北京\", \"北京邮电大学\", \"邮电\", \"电大\", \"大学\"]"); +} + +TEST(QuerySegment, Test1) { + QuerySegment segment("../dict/extra_dict/jieba.dict.small.utf8", "../dict/hmm_model.utf8", 3); + const char* str = "小明硕士毕业于中国科学院计算所,后在日本京都大学深造"; + vector words; + + ASSERT_TRUE(segment.cut(str, words)); + + string s1, s2; + s1 << words; + s2 = "[\"小明\", \"硕士\", \"毕业\", \"于\", \"中国\", \"中国科学院\", \"科学\", \"科学院\", \"学院\", \"计算所\", \",\", \"后\", \"在\", \"日本\", \"京都\", \"京都大学\", \"大学\", \"深造\"]"; + ASSERT_EQ(s1, s2); + +} + +TEST(QuerySegment, Test2) { + QuerySegment segment("../dict/extra_dict/jieba.dict.small.utf8", "../dict/hmm_model.utf8", 3, "../test/testdata/userdict.utf8"); + + { const char* str = "小明硕士毕业于中国科学院计算所,后在日本京都大学深造"; vector words; @@ -182,35 +188,18 @@ TEST(QuerySegment, Test1) s1 << words; s2 = "[\"小明\", \"硕士\", \"毕业\", \"于\", \"中国\", \"中国科学院\", \"科学\", \"科学院\", \"学院\", \"计算所\", \",\", \"后\", \"在\", \"日本\", \"京都\", \"京都大学\", \"大学\", \"深造\"]"; ASSERT_EQ(s1, s2); + } + + { + const char* str = "小明硕士毕业于中国科学院计算所iPhone6"; + vector words; + + ASSERT_TRUE(segment.cut(str, words)); + + 
string s1, s2; + s1 << words; + s2 = "[\"小明\", \"硕士\", \"毕业\", \"于\", \"中国\", \"中国科学院\", \"科学\", \"科学院\", \"学院\", \"计算所\", \"iPhone6\"]"; + ASSERT_EQ(s1, s2); + } } - -TEST(QuerySegment, Test2) -{ - QuerySegment segment("../dict/extra_dict/jieba.dict.small.utf8", "../dict/hmm_model.utf8", 3, "../test/testdata/userdict.utf8"); - - { - const char* str = "小明硕士毕业于中国科学院计算所,后在日本京都大学深造"; - vector words; - - ASSERT_TRUE(segment.cut(str, words)); - - string s1, s2; - s1 << words; - s2 = "[\"小明\", \"硕士\", \"毕业\", \"于\", \"中国\", \"中国科学院\", \"科学\", \"科学院\", \"学院\", \"计算所\", \",\", \"后\", \"在\", \"日本\", \"京都\", \"京都大学\", \"大学\", \"深造\"]"; - ASSERT_EQ(s1, s2); - } - - { - const char* str = "小明硕士毕业于中国科学院计算所iPhone6"; - vector words; - - ASSERT_TRUE(segment.cut(str, words)); - - string s1, s2; - s1 << words; - s2 = "[\"小明\", \"硕士\", \"毕业\", \"于\", \"中国\", \"中国科学院\", \"科学\", \"科学院\", \"学院\", \"计算所\", \"iPhone6\"]"; - ASSERT_EQ(s1, s2); - } - -} diff --git a/test/unittest/TTrie.cpp b/test/unittest/TTrie.cpp index 49ce021..7449b8a 100644 --- a/test/unittest/TTrie.cpp +++ b/test/unittest/TTrie.cpp @@ -6,75 +6,70 @@ using namespace CppJieba; static const char* const DICT_FILE = "../dict/extra_dict/jieba.dict.small.utf8"; -TEST(DictTrieTest, NewAndDelete) -{ - DictTrie * trie; - trie = new DictTrie(DICT_FILE); - delete trie; - trie = new DictTrie(); - delete trie; +TEST(DictTrieTest, NewAndDelete) { + DictTrie * trie; + trie = new DictTrie(DICT_FILE); + delete trie; + trie = new DictTrie(); + delete trie; } -TEST(DictTrieTest, Test1) -{ +TEST(DictTrieTest, Test1) { - string s1, s2; - DictTrie trie; - ASSERT_TRUE(trie.init(DICT_FILE)); - ASSERT_LT(trie.getMinWeight() + 15.6479, 0.001); - string word("来到"); - Unicode uni; - ASSERT_TRUE(TransCode::decode(word, uni)); - DictUnit nodeInfo; - nodeInfo.word = uni; - nodeInfo.tag = "v"; - nodeInfo.weight = -8.87033; - s1 << nodeInfo; - s2 << (*trie.find(uni.begin(), uni.end())); - - EXPECT_EQ("[\"26469\", \"21040\"] v -8.870", s2); - word = "清华大学"; - LocalVector > res; - //vector resMap; - LocalVector > res2; - const char * words[] = {"清", "清华", "清华大学"}; - for(size_t i = 0; i < sizeof(words)/sizeof(words[0]); i++) - { - ASSERT_TRUE(TransCode::decode(words[i], uni)); - res.push_back(make_pair(uni.size() - 1, trie.find(uni.begin(), uni.end()))); - //resMap[uni.size() - 1] = trie.find(uni.begin(), uni.end()); - } - //DictUnit - //res.push_back(make_pair(0, )) + string s1, s2; + DictTrie trie; + ASSERT_TRUE(trie.init(DICT_FILE)); + ASSERT_LT(trie.getMinWeight() + 15.6479, 0.001); + string word("来到"); + Unicode uni; + ASSERT_TRUE(TransCode::decode(word, uni)); + DictUnit nodeInfo; + nodeInfo.word = uni; + nodeInfo.tag = "v"; + nodeInfo.weight = -8.87033; + s1 << nodeInfo; + s2 << (*trie.find(uni.begin(), uni.end())); - vector > vec; - ASSERT_TRUE(TransCode::decode(word, uni)); - ASSERT_TRUE(trie.find(uni.begin(), uni.end(), res2, 0)); - s1 << res; - s2 << res; - ASSERT_EQ(s1, s2); + EXPECT_EQ("[\"26469\", \"21040\"] v -8.870", s2); + word = "清华大学"; + LocalVector > res; + //vector resMap; + LocalVector > res2; + const char * words[] = {"清", "清华", "清华大学"}; + for(size_t i = 0; i < sizeof(words)/sizeof(words[0]); i++) { + ASSERT_TRUE(TransCode::decode(words[i], uni)); + res.push_back(make_pair(uni.size() - 1, trie.find(uni.begin(), uni.end()))); + //resMap[uni.size() - 1] = trie.find(uni.begin(), uni.end()); + } + //DictUnit + //res.push_back(make_pair(0, )) + + vector > vec; + ASSERT_TRUE(TransCode::decode(word, uni)); + ASSERT_TRUE(trie.find(uni.begin(), uni.end(), res2, 0)); 
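// At this point the DagType overload of find() has filled res2 with every dictionary
// word that starts at uni.begin() of "清华大学" (清, 清华, 清华大学), each stored as a
// pair of the index of its last character (plus the supplied offset, 0 here) and a
// pointer to its DictUnit.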
+ s1 << res; + s2 << res; + ASSERT_EQ(s1, s2); } -TEST(DictTrieTest, UserDict) -{ - DictTrie trie(DICT_FILE, "../test/testdata/userdict.utf8"); - string word = "云计算"; - Unicode unicode; - ASSERT_TRUE(TransCode::decode(word, unicode)); - const DictUnit * unit = trie.find(unicode.begin(), unicode.end()); - ASSERT_TRUE(unit); - string res ; - res << *unit; - ASSERT_EQ("[\"20113\", \"35745\", \"31639\"] -2.975", res); +TEST(DictTrieTest, UserDict) { + DictTrie trie(DICT_FILE, "../test/testdata/userdict.utf8"); + string word = "云计算"; + Unicode unicode; + ASSERT_TRUE(TransCode::decode(word, unicode)); + const DictUnit * unit = trie.find(unicode.begin(), unicode.end()); + ASSERT_TRUE(unit); + string res ; + res << *unit; + ASSERT_EQ("[\"20113\", \"35745\", \"31639\"] -2.975", res); } -TEST(DictTrieTest, automation) -{ - DictTrie trie(DICT_FILE, "../test/testdata/userdict.utf8"); - //string word = "yasherhs"; - string word = "abcderf"; - Unicode unicode; - ASSERT_TRUE(TransCode::decode(word, unicode)); - vector res; - trie.find(unicode.begin(), unicode.end(), res); +TEST(DictTrieTest, automation) { + DictTrie trie(DICT_FILE, "../test/testdata/userdict.utf8"); + //string word = "yasherhs"; + string word = "abcderf"; + Unicode unicode; + ASSERT_TRUE(TransCode::decode(word, unicode)); + vector res; + trie.find(unicode.begin(), unicode.end(), res); }
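The DictTrie calls exercised by TTrie.cpp above can also be driven from a small standalone program, in the same spirit as the demos under test/. The sketch below is illustrative only: the include path, the small-dictionary path and the sample text are assumptions borrowed from those tests, not part of this change.

#include <cstdio>
#include <vector>
#include "../src/DictTrie.hpp"

using namespace CppJieba;

int main() {
    // Same small dictionary that DictTrieTest loads; adjust the path to your layout.
    DictTrie trie("../dict/extra_dict/jieba.dict.small.utf8");

    // Exact lookup: decode a UTF-8 word into code points, then query the trie.
    Unicode word;
    if (!TransCode::decode("清华大学", word)) {
        return 1;
    }
    const DictUnit* unit = trie.find(word.begin(), word.end());
    if (unit != NULL) {
        printf("tag=%s weight=%.3lf\n", unit->tag.c_str(), unit->weight);
    }

    // Aho-Corasick scan: one SegmentChar per input character; chars[i].dag holds the
    // candidate spans starting at position i (a one-character fallback plus any
    // dictionary words), each as a pair of the index of its last character and a
    // DictUnit pointer.
    Unicode sentence;
    if (!TransCode::decode("我来到北京清华大学", sentence)) {
        return 1;
    }
    std::vector<SegmentChar> chars;
    trie.find(sentence.begin(), sentence.end(), chars);
    for (size_t i = 0; i < chars.size(); i++) {
        printf("pos %lu: %lu candidate word(s)\n",
               (unsigned long)i, (unsigned long)chars[i].dag.size());
    }
    return 0;
}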