From d4b69f9e5894e81dea673e906d7b7750ae3bd686 Mon Sep 17 00:00:00 2001 From: gwdwyy Date: Sat, 24 Aug 2013 22:24:54 +0800 Subject: [PATCH] finished load hmm_model --- src/HMMSegment.cpp | 131 ++++++++++++++++++++++++++++++++++- src/HMMSegment.h | 27 +++++++- src/Makefile | 2 + src/Segment.cpp | 2 +- src/Trie.cpp | 2 +- src/cppcommon/str_functs.cpp | 3 +- src/globals.h | 7 +- src/structs.h | 2 +- 8 files changed, 166 insertions(+), 10 deletions(-) diff --git a/src/HMMSegment.cpp b/src/HMMSegment.cpp index f6291da..cddb320 100644 --- a/src/HMMSegment.cpp +++ b/src/HMMSegment.cpp @@ -4,10 +4,13 @@ namespace CppJieba { HMMSegment::HMMSegment() { + memset(_startProb, 0, sizeof(_startProb)); + memset(_transProb, 0, sizeof(_transProb)); } HMMSegment::~HMMSegment() { + } bool HMMSegment::init() @@ -20,10 +23,133 @@ namespace CppJieba return true; } - bool HMMSegment::loadModel() + bool HMMSegment::loadModel(const char* const filePath) { + ifstream ifile(filePath); + string line; + vector tmp; + vector tmp2; + + //load _startProb + if(!_getLine(ifile, line)) + { + return false; + } + splitStr(line, tmp, " "); + if(tmp.size() != STATUS_SUM) + { + LogError("start_p illegal"); + return false; + } + for(uint j = 0; j< tmp.size(); j++) + { + _startProb[j] = atof(tmp[j].c_str()); + //cout<<_startProb[j]< tmp, tmp2; + uint16_t unico; + splitStr(line, tmp, ","); + for(uint i = 0; i < tmp.size(); i++) + { + splitStr(tmp[i], tmp2, ":"); + if(2 != tmp2.size()) + { + LogError("_emitProb illegal."); + return false; + } + if(!_decodeOne(tmp2[0], unico)) + { + LogError("TransCode failed."); + return false; + } + mp[unico] = atof(tmp2[1].c_str()); + } + return true; + } + + bool HMMSegment::_decodeOne(const string& str, uint16_t& res) + { + vector ui16; + if(!TransCode::strToVec(str, ui16) || ui16.size() != 1) + { + return false; + } + res = ui16[0]; + return true; + } + } @@ -32,7 +158,10 @@ using namespace CppJieba; int main() { + TransCode::setUtf8Enc(); HMMSegment hmm; + hmm.loadModel("../dicts/hmm_model.utf8"); + //cout< +#include +#include +#include "cppcommon/headers.h" +#include "globals.h" +#include "TransCode.h" + namespace CppJieba { + using namespace CPPCOMMON; class HMMSegment { + private: + /* + * STATUS: + * 0:B, 1:E, 2:M, 3:S + * */ + enum {STATUS_SUM = 4}; + double _startProb[STATUS_SUM]; + double _transProb[STATUS_SUM][STATUS_SUM]; + EmitProbMap _emitProbB; + EmitProbMap _emitProbE; + EmitProbMap _emitProbM; + EmitProbMap _emitProbS; + public: HMMSegment(); ~HMMSegment(); @@ -12,7 +33,11 @@ namespace CppJieba bool init(); bool dispose(); public: - bool loadModel(); + bool loadModel(const char* const filePath); + private: + bool _getLine(ifstream& ifile, string& line); + bool _loadEmitProb(const string& line, EmitProbMap& mp); + bool _decodeOne(const string& str, uint16_t& res); }; } diff --git a/src/Makefile b/src/Makefile index 9de50cf..fe55b4c 100644 --- a/src/Makefile +++ b/src/Makefile @@ -61,6 +61,8 @@ KeyWordExt.ut: KeyWordExt.cpp KeyWordExt.h Segment.h Trie.h globals.h TransCode. TransCode.ut: TransCode.cpp TransCode.h globals.h $(CMLIB) $(CXX) -o $@ $(CXXFLAGS) TransCode.cpp -DCPPJIEBA_TRANSCODE_UT $(CMLIB) +HMMSegment.ut: HMMSegment.cpp TransCode.cpp TransCode.h HMMSegment.h $(CMLIB) + $(CXX) -o $@ $(CXXFLAGS) TransCode.cpp HMMSegment.cpp -DHMMSEGMENT_UT $(CMLIB) clean: rm -f *.o *.d *.ut $(LIBA) diff --git a/src/Segment.cpp b/src/Segment.cpp index 68167fb..3cb7047 100644 --- a/src/Segment.cpp +++ b/src/Segment.cpp @@ -138,7 +138,7 @@ namespace CppJieba { // calc max segContext.dp[i].first = NULL; - segContext.dp[i].second = -(numeric_limits::max()); + segContext.dp[i].second = MIN_DOUBLE; for(uint j = 0; j < segContext.dag[i].size(); j++) { const pair& p = segContext.dag[i][j]; diff --git a/src/Trie.cpp b/src/Trie.cpp index 52182de..5ebec64 100644 --- a/src/Trie.cpp +++ b/src/Trie.cpp @@ -21,7 +21,7 @@ namespace CppJieba _root = NULL; _freqSum = 0; - _minLogFreq = numeric_limits::max(); + _minLogFreq = MAX_DOUBLE; _initFlag = false; } diff --git a/src/cppcommon/str_functs.cpp b/src/cppcommon/str_functs.cpp index 5f70a02..8d41ce9 100644 --- a/src/cppcommon/str_functs.cpp +++ b/src/cppcommon/str_functs.cpp @@ -53,10 +53,11 @@ namespace CPPCOMMON } void splitStr(const string& source, vector& out_vec, const string& pattern) { - if(0 == pattern.size()) + if(source.empty()) { return; } + out_vec.clear(); string s = source + pattern; string::size_type pos; uint length = s.size(); diff --git a/src/globals.h b/src/globals.h index 1bdd50c..894c1b6 100644 --- a/src/globals.h +++ b/src/globals.h @@ -20,16 +20,15 @@ namespace CppJieba using __gnu_cxx::hash_map; //using namespace stdext; //typedefs + const double MIN_DOUBLE = -3.14e+100; + const double MAX_DOUBLE = 3.14e+100; typedef unsigned int uint; typedef std::vector::iterator VSI; typedef std::vector VUINT16; typedef std::vector::const_iterator VUINT16_CONST_ITER; typedef hash_map TrieNodeMap; + typedef hash_map EmitProbMap; - namespace HMMDict - { - - } } diff --git a/src/structs.h b/src/structs.h index 983c75e..4793348 100644 --- a/src/structs.h +++ b/src/structs.h @@ -26,7 +26,7 @@ namespace CppJieba word = _word; wLen = TransCode::getWordLength(_word); freq = 0; - logFreq = -numeric_limits::max(); + logFreq = MIN_DOUBLE; } };