From bda660dc66c78eb64afa3ca693e3b2f4a02aa034 Mon Sep 17 00:00:00 2001 From: gwdwyy Date: Sun, 25 Aug 2013 21:31:18 +0800 Subject: [PATCH] finished viterbi --- src/HMMSegment.cpp | 175 ++++++++++++++++++++++++++++++++++++++++++++- src/HMMSegment.h | 10 ++- src/globals.h | 14 ++-- 3 files changed, 190 insertions(+), 9 deletions(-) diff --git a/src/HMMSegment.cpp b/src/HMMSegment.cpp index cddb320..afa26d4 100644 --- a/src/HMMSegment.cpp +++ b/src/HMMSegment.cpp @@ -6,6 +6,14 @@ namespace CppJieba { memset(_startProb, 0, sizeof(_startProb)); memset(_transProb, 0, sizeof(_transProb)); + _statMap[0] = 'B'; + _statMap[1] = 'E'; + _statMap[2] = 'M'; + _statMap[3] = 'S'; + _emitProbVec.push_back(&_emitProbB); + _emitProbVec.push_back(&_emitProbE); + _emitProbVec.push_back(&_emitProbM); + _emitProbVec.push_back(&_emitProbS); } HMMSegment::~HMMSegment() @@ -29,7 +37,6 @@ namespace CppJieba string line; vector tmp; vector tmp2; - //load _startProb if(!_getLine(ifile, line)) { @@ -94,6 +101,142 @@ namespace CppJieba return true; } + bool HMMSegment::cut(const string& str, vector& res) + { + if(str.empty()) + { + return false; + } + vector unico; + vector status; + vector::iterator begin, left, right; + if(!TransCode::strToVec(str, unico)) + + { + LogError("TransCode failed."); + return false; + } + + if(!viterbi(unico, status)) + { + LogError("viterbi failed."); + return false; + } + begin = unico.begin(); + res.clear(); + for(uint i =0; i< status.size(); i++) + { + switch(status[i]) + { + case B: + left = begin + i; + break; + case E: + right = begin + i + 1; + res.push_back(TransCode::vecToStr(left, right)); + break; + case S: + res.push_back(TransCode::vecToStr(begin + i, begin + i +1)); + + break; + + } + } + + return true; + } + + bool HMMSegment::viterbi(const vector& unico, vector& status) + { + if(unico.empty()) + { + return false; + } + + size_t Y = STATUS_SUM; + size_t X = unico.size(); + size_t XYSize = X * Y; + int * path; + double * weight; + uint now, old, stat; + double tmp, endE, endS; + + try + { + path = new int [XYSize]; + weight = new double [XYSize]; + } + catch(const std::bad_alloc&) + { + LogError("bad_alloc"); + return false; + } + if(NULL == path || NULL == weight) + { + LogError("bad_alloc"); + return false; + } + + //start + for(uint y = 0; y < Y; y++) + { + weight[0 + y * X] = _startProb[y] + _getEmitProb(_emitProbVec[y], unico[0], MIN_DOUBLE); + path[0 + y * X] = -1; + } + + //process + for(uint x = 1; x < X; x++) + { + for(uint y = 0; y < Y; y++) + { + now = x + y*X; + weight[now] = MIN_DOUBLE; + for(uint preY = 0; preY < Y; preY++) + { + old = x - 1 + preY * X; + tmp = weight[old] + _transProb[preY][y] + _getEmitProb(_emitProbVec[y], unico[x], MIN_DOUBLE); + //cout<<__FILE__<<__LINE__< weight[now]) + { + weight[now] = tmp; + path[now] = preY; + } + } + //cout<<__FILE__<<__LINE__< endS) + { + stat = E; + } + else + { + stat = S; + } + + status.assign(X, 0); + for(int x = X -1 ; x >= 0; x--) + { + status[x] = stat; + stat = path[x + stat*X]; + } + + delete [] path; + delete [] weight; + return true; + } + bool HMMSegment::_getLine(ifstream& ifile, string& line) { while(getline(ifile, line)) @@ -150,18 +293,46 @@ namespace CppJieba return true; } + double HMMSegment::_getEmitProb(const EmitProbMap& mp, uint16_t key, double defVal) + { + EmitProbMap::const_iterator cit = mp.find(key); + if(cit == mp.end()) + { + return defVal; + } + return cit->second; + } + + double HMMSegment::_getEmitProb(const EmitProbMap* ptMp, uint16_t key, double defVal) + { + EmitProbMap::const_iterator cit = ptMp->find(key); + if(cit == ptMp->end()) + { + return defVal; + } + return cit->second; + + } } #ifdef HMMSEGMENT_UT using namespace CppJieba; + +size_t add(size_t a, size_t b) +{ + return a*b; +} int main() { TransCode::setUtf8Enc(); HMMSegment hmm; hmm.loadModel("../dicts/hmm_model.utf8"); - //cout< res; + hmm.cut("小明硕士毕业于北邮网络研究院", res); + cout< _emitProbVec; public: HMMSegment(); @@ -34,10 +36,16 @@ namespace CppJieba bool dispose(); public: bool loadModel(const char* const filePath); + bool cut(const string& str, vector& res); + bool viterbi(const vector& unico, vector& status); private: bool _getLine(ifstream& ifile, string& line); bool _loadEmitProb(const string& line, EmitProbMap& mp); bool _decodeOne(const string& str, uint16_t& res); + double _getEmitProb(const EmitProbMap& mp, uint16_t key, double defVal); + double _getEmitProb(const EmitProbMap* ptMp, uint16_t key, double defVal); + + }; } diff --git a/src/globals.h b/src/globals.h index 894c1b6..c61440f 100644 --- a/src/globals.h +++ b/src/globals.h @@ -10,26 +10,28 @@ #include #include #include -#include +//#include +#include //#include namespace CppJieba { using namespace std; - using __gnu_cxx::hash_map; + using std::tr1::unordered_map; + //using __gnu_cxx::hash_map; //using namespace stdext; //typedefs - const double MIN_DOUBLE = -3.14e+100; - const double MAX_DOUBLE = 3.14e+100; typedef unsigned int uint; typedef std::vector::iterator VSI; typedef std::vector VUINT16; typedef std::vector::const_iterator VUINT16_CONST_ITER; - typedef hash_map TrieNodeMap; - typedef hash_map EmitProbMap; + typedef unordered_map TrieNodeMap; + typedef unordered_map EmitProbMap; + const double MIN_DOUBLE = -3.14e+100; + const double MAX_DOUBLE = 3.14e+100; } #endif