diff --git a/src/HMMModel.hpp b/src/HMMModel.hpp new file mode 100644 index 0000000..a79e720 --- /dev/null +++ b/src/HMMModel.hpp @@ -0,0 +1,142 @@ +#ifndef CPPJIEBA_HMMMODEL_H +#define CPPJIEBA_HMMMODEL_H + +#include "Limonp/StringUtil.hpp" + +namespace CppJieba { + +using namespace Limonp; +typedef unordered_map EmitProbMap; + +struct HMMModel { + /* + * STATUS: + * 0: HMMModel::B, 1: HMMModel::E, 2: HMMModel::M, 3:HMMModel::S + * */ + enum {B = 0, E = 1, M = 2, S = 3, STATUS_SUM = 4}; + + HMMModel(const string& modelPath) { + memset(startProb, 0, sizeof(startProb)); + memset(transProb, 0, sizeof(transProb)); + statMap[0] = 'B'; + statMap[1] = 'E'; + statMap[2] = 'M'; + statMap[3] = 'S'; + emitProbVec.push_back(&emitProbB); + emitProbVec.push_back(&emitProbE); + emitProbVec.push_back(&emitProbM); + emitProbVec.push_back(&emitProbS); + loadModel(modelPath); + } + ~HMMModel() { + } + void loadModel(const string& filePath) { + ifstream ifile(filePath.c_str()); + if(!ifile.is_open()) { + LogFatal("open %s failed.", filePath.c_str()); + } + string line; + vector tmp; + vector tmp2; + //load startProb + if(!getLine(ifile, line)) { + LogFatal("load startProb"); + } + split(line, tmp, " "); + if(tmp.size() != STATUS_SUM) { + LogFatal("start_p illegal"); + } + for(size_t j = 0; j< tmp.size(); j++) { + startProb[j] = atof(tmp[j].c_str()); + } + + //load transProb + for(size_t i = 0; i < STATUS_SUM; i++) { + if(!getLine(ifile, line)) { + LogFatal("load transProb failed."); + } + split(line, tmp, " "); + if(tmp.size() != STATUS_SUM) { + LogFatal("trans_p illegal"); + } + for(size_t j =0; j < STATUS_SUM; j++) { + transProb[i][j] = atof(tmp[j].c_str()); + } + } + + //load emitProbB + if(!getLine(ifile, line) || !loadEmitProb(line, emitProbB)) { + LogFatal("load emitProbB failed."); + } + + //load emitProbE + if(!getLine(ifile, line) || !loadEmitProb(line, emitProbE)) { + LogFatal("load emitProbE failed."); + } + + //load emitProbM + if(!getLine(ifile, line) || !loadEmitProb(line, emitProbM)) { + LogFatal("load emitProbM failed."); + } + + //load emitProbS + if(!getLine(ifile, line) || !loadEmitProb(line, emitProbS)) { + LogFatal("load emitProbS failed."); + } + } + double getEmitProb(const EmitProbMap* ptMp, uint16_t key, + double defVal)const { + EmitProbMap::const_iterator cit = ptMp->find(key); + if(cit == ptMp->end()) { + return defVal; + } + return cit->second; + } + bool getLine(ifstream& ifile, string& line) { + while(getline(ifile, line)) { + trim(line); + if(line.empty()) { + continue; + } + if(startsWith(line, "#")) { + continue; + } + return true; + } + return false; + } + bool loadEmitProb(const string& line, EmitProbMap& mp) { + if(line.empty()) { + return false; + } + vector tmp, tmp2; + Unicode unicode; + split(line, tmp, ","); + for(size_t i = 0; i < tmp.size(); i++) { + split(tmp[i], tmp2, ":"); + if(2 != tmp2.size()) { + LogError("emitProb illegal."); + return false; + } + if(!TransCode::decode(tmp2[0], unicode) || unicode.size() != 1) { + LogError("TransCode failed."); + return false; + } + mp[unicode[0]] = atof(tmp2[1].c_str()); + } + return true; + } + + char statMap[STATUS_SUM]; + double startProb[STATUS_SUM]; + double transProb[STATUS_SUM][STATUS_SUM]; + EmitProbMap emitProbB; + EmitProbMap emitProbE; + EmitProbMap emitProbM; + EmitProbMap emitProbS; + vector emitProbVec; +}; // struct HMMModel + +} // namespace CppJieba + +#endif diff --git a/src/HMMSegment.hpp b/src/HMMSegment.hpp index 5cf0bf6..78c4dfd 100644 --- a/src/HMMSegment.hpp +++ b/src/HMMSegment.hpp @@ -5,47 +5,18 @@ #include #include #include -#include "Limonp/StringUtil.hpp" -#include "TransCode.hpp" -#include "ISegment.hpp" +#include "HMMModel.hpp" #include "SegmentBase.hpp" -#include "DictTrie.hpp" namespace CppJieba { -using namespace Limonp; -typedef unordered_map EmitProbMap; + class HMMSegment: public SegmentBase { public: - /* - * STATUS: - * 0:B, 1:E, 2:M, 3:S - * */ - enum {B = 0, E = 1, M = 2, S = 3, STATUS_SUM = 4}; - - public: - HMMSegment() {} - explicit HMMSegment(const string& filePath) { - init(filePath); + explicit HMMSegment(const string& filePath): model_(filePath) { } virtual ~HMMSegment() {} - public: - void init(const string& filePath) { - memset(startProb_, 0, sizeof(startProb_)); - memset(transProb_, 0, sizeof(transProb_)); - statMap_[0] = 'B'; - statMap_[1] = 'E'; - statMap_[2] = 'M'; - statMap_[3] = 'S'; - emitProbVec_.push_back(&emitProbB_); - emitProbVec_.push_back(&emitProbE_); - emitProbVec_.push_back(&emitProbM_); - emitProbVec_.push_back(&emitProbS_); - loadModel_(filePath.c_str()); - LogInfo("HMMSegment init(%s) ok.", filePath.c_str()); - } - public: + using SegmentBase::cut; - public: bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector& res)const { Unicode::const_iterator left = begin; Unicode::const_iterator right = begin; @@ -77,7 +48,6 @@ class HMMSegment: public SegmentBase { } return true; } - public: virtual bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector& res)const { if(begin == end) { return false; @@ -141,7 +111,7 @@ class HMMSegment: public SegmentBase { Unicode::const_iterator left = begin; Unicode::const_iterator right; for(size_t i = 0; i < status.size(); i++) { - if(status[i] % 2) { //if(E == status[i] || S == status[i]) + if(status[i] % 2) { //if(HMMModel::E == status[i] || HMMModel::S == status[i]) right = begin + i + 1; res.push_back(Unicode(left, right)); left = right; @@ -150,12 +120,13 @@ class HMMSegment: public SegmentBase { return true; } - bool viterbi_(Unicode::const_iterator begin, Unicode::const_iterator end, vector& status)const { + bool viterbi_(Unicode::const_iterator begin, Unicode::const_iterator end, + vector& status) const { if(begin == end) { return false; } - size_t Y = STATUS_SUM; + size_t Y = HMMModel::STATUS_SUM; size_t X = end - begin; size_t XYSize = X * Y; @@ -167,22 +138,21 @@ class HMMSegment: public SegmentBase { //start for(size_t y = 0; y < Y; y++) { - weight[0 + y * X] = startProb_[y] + getEmitProb_(emitProbVec_[y], *begin, MIN_DOUBLE); + weight[0 + y * X] = model_.startProb[y] + model_.getEmitProb(model_.emitProbVec[y], *begin, MIN_DOUBLE); path[0 + y * X] = -1; } - double emitProb; for(size_t x = 1; x < X; x++) { for(size_t y = 0; y < Y; y++) { now = x + y*X; weight[now] = MIN_DOUBLE; - path[now] = E; // warning - emitProb = getEmitProb_(emitProbVec_[y], *(begin+x), MIN_DOUBLE); + path[now] = HMMModel::E; // warning + emitProb = model_.getEmitProb(model_.emitProbVec[y], *(begin+x), MIN_DOUBLE); for(size_t preY = 0; preY < Y; preY++) { old = x - 1 + preY * X; - tmp = weight[old] + transProb_[preY][y] + emitProb; + tmp = weight[old] + model_.transProb[preY][y] + emitProb; if(tmp > weight[now]) { weight[now] = tmp; path[now] = preY; @@ -191,13 +161,13 @@ class HMMSegment: public SegmentBase { } } - endE = weight[X-1+E*X]; - endS = weight[X-1+S*X]; + endE = weight[X-1+HMMModel::E*X]; + endS = weight[X-1+HMMModel::S*X]; stat = 0; if(endE >= endS) { - stat = E; + stat = HMMModel::E; } else { - stat = S; + stat = HMMModel::S; } status.resize(X); @@ -208,114 +178,10 @@ class HMMSegment: public SegmentBase { return true; } - void loadModel_(const char* const filePath) { - ifstream ifile(filePath); - if(!ifile.is_open()) { - LogFatal("open %s failed.", filePath); - } - string line; - vector tmp; - vector tmp2; - //load startProb_ - if(!getLine_(ifile, line)) { - LogFatal("load startProb_"); - } - split(line, tmp, " "); - if(tmp.size() != STATUS_SUM) { - LogFatal("start_p illegal"); - } - for(size_t j = 0; j< tmp.size(); j++) { - startProb_[j] = atof(tmp[j].c_str()); - } - //load transProb_ - for(size_t i = 0; i < STATUS_SUM; i++) { - if(!getLine_(ifile, line)) { - LogFatal("load transProb_ failed."); - } - split(line, tmp, " "); - if(tmp.size() != STATUS_SUM) { - LogFatal("trans_p illegal"); - } - for(size_t j =0; j < STATUS_SUM; j++) { - transProb_[i][j] = atof(tmp[j].c_str()); - } - } + HMMModel model_; +}; // class HMMSegment - //load emitProbB_ - if(!getLine_(ifile, line) || !loadEmitProb_(line, emitProbB_)) { - LogFatal("load emitProbB_ failed."); - } - - //load emitProbE_ - if(!getLine_(ifile, line) || !loadEmitProb_(line, emitProbE_)) { - LogFatal("load emitProbE_ failed."); - } - - //load emitProbM_ - if(!getLine_(ifile, line) || !loadEmitProb_(line, emitProbM_)) { - LogFatal("load emitProbM_ failed."); - } - - //load emitProbS_ - if(!getLine_(ifile, line) || !loadEmitProb_(line, emitProbS_)) { - LogFatal("load emitProbS_ failed."); - } - } - bool getLine_(ifstream& ifile, string& line) { - while(getline(ifile, line)) { - trim(line); - if(line.empty()) { - continue; - } - if(startsWith(line, "#")) { - continue; - } - return true; - } - return false; - } - bool loadEmitProb_(const string& line, EmitProbMap& mp) { - if(line.empty()) { - return false; - } - vector tmp, tmp2; - Unicode unicode; - split(line, tmp, ","); - for(size_t i = 0; i < tmp.size(); i++) { - split(tmp[i], tmp2, ":"); - if(2 != tmp2.size()) { - LogError("emitProb_ illegal."); - return false; - } - if(!TransCode::decode(tmp2[0], unicode) || unicode.size() != 1) { - LogError("TransCode failed."); - return false; - } - mp[unicode[0]] = atof(tmp2[1].c_str()); - } - return true; - } - double getEmitProb_(const EmitProbMap* ptMp, uint16_t key, double defVal)const { - EmitProbMap::const_iterator cit = ptMp->find(key); - if(cit == ptMp->end()) { - return defVal; - } - return cit->second; - - } - - private: - char statMap_[STATUS_SUM]; - double startProb_[STATUS_SUM]; - double transProb_[STATUS_SUM][STATUS_SUM]; - EmitProbMap emitProbB_; - EmitProbMap emitProbE_; - EmitProbMap emitProbM_; - EmitProbMap emitProbS_; - vector emitProbVec_; - -}; -} +} // namespace CppJieba #endif diff --git a/src/KeywordExtractor.hpp b/src/KeywordExtractor.hpp index 05280bd..dee73c5 100644 --- a/src/KeywordExtractor.hpp +++ b/src/KeywordExtractor.hpp @@ -11,17 +11,17 @@ using namespace Limonp; /*utf8*/ class KeywordExtractor { public: - KeywordExtractor() {}; - KeywordExtractor(const string& dictPath, const string& hmmFilePath, const string& idfPath, const string& stopWordPath, const string& userDict = "") { - init(dictPath, hmmFilePath, idfPath, stopWordPath, userDict); - }; - ~KeywordExtractor() {}; - - void init(const string& dictPath, const string& hmmFilePath, const string& idfPath, const string& stopWordPath, const string& userDict = "") { + KeywordExtractor(const string& dictPath, + const string& hmmFilePath, + const string& idfPath, + const string& stopWordPath, + const string& userDict = "") + : segment_(dictPath, hmmFilePath, userDict) { loadIdfDict_(idfPath); loadStopWordDict_(stopWordPath); - segment_.init(dictPath, hmmFilePath, userDict); - }; + } + ~KeywordExtractor() { + } bool extract(const string& str, vector& keywords, size_t topN) const { vector > topWords; diff --git a/src/MixSegment.hpp b/src/MixSegment.hpp index aace901..589d1ff 100644 --- a/src/MixSegment.hpp +++ b/src/MixSegment.hpp @@ -9,18 +9,14 @@ namespace CppJieba { class MixSegment: public SegmentBase { public: - MixSegment() { - } - MixSegment(const string& mpSegDict, const string& hmmSegDict, const string& userDict = "") { - init(mpSegDict, hmmSegDict, userDict); + MixSegment(const string& mpSegDict, const string& hmmSegDict, + const string& userDict = "") + : mpSeg_(mpSegDict, userDict), + hmmSeg_(hmmSegDict) { + LogInfo("MixSegment init %s, %s", mpSegDict.c_str(), hmmSegDict.c_str()); } virtual ~MixSegment() { } - void init(const string& mpSegDict, const string& hmmSegDict, const string& userDict = "") { - mpSeg_.init(mpSegDict, userDict); - hmmSeg_.init(hmmSegDict); - LogInfo("MixSegment init(%s, %s)", mpSegDict.c_str(), hmmSegDict.c_str()); - } using SegmentBase::cut; virtual bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector& res) const { vector words; diff --git a/src/PosTagger.hpp b/src/PosTagger.hpp index dfeebf1..b64e437 100644 --- a/src/PosTagger.hpp +++ b/src/PosTagger.hpp @@ -14,27 +14,15 @@ static const char* const POS_X = "x"; class PosTagger { public: - PosTagger() { - } - PosTagger( - const string& dictPath, + PosTagger(const string& dictPath, const string& hmmFilePath, - const string& userDictPath = "" - ) { - init(dictPath, hmmFilePath, userDictPath); + const string& userDictPath = "") + : segment_(dictPath, hmmFilePath, userDictPath) { + dictTrie_ = segment_.getDictTrie(); + LIMONP_CHECK(dictTrie_); } ~PosTagger() { } - void init( - const string& dictPath, - const string& hmmFilePath, - const string& userDictPath = "" - ) { - segment_.init(dictPath, hmmFilePath, userDictPath); - dictTrie_ = segment_.getDictTrie(); - LIMONP_CHECK(dictTrie_); - }; - bool tag(const string& src, vector >& res) const { vector cutRes; diff --git a/src/QuerySegment.hpp b/src/QuerySegment.hpp index 06cc905..0c66481 100644 --- a/src/QuerySegment.hpp +++ b/src/QuerySegment.hpp @@ -16,17 +16,14 @@ namespace CppJieba { class QuerySegment: public SegmentBase { public: - QuerySegment() {}; - QuerySegment(const string& dict, const string& model, size_t maxWordLen = 4, const string& userDict = "") { - init(dict, model, maxWordLen, userDict); - }; - virtual ~QuerySegment() {}; - void init(const string& dict, const string& model, size_t maxWordLen, const string& userDict = "") { - mixSeg_.init(dict, model, userDict); - fullSeg_.init(mixSeg_.getDictTrie()); + QuerySegment(const string& dict, const string& model, size_t maxWordLen = 4, + const string& userDict = "") + : mixSeg_(dict, model, userDict), + fullSeg_(mixSeg_.getDictTrie()) { assert(maxWordLen); maxWordLen_ = maxWordLen; - } + }; + virtual ~QuerySegment() {}; using SegmentBase::cut; bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector& res) const { //use mix cut first