finished load hmm_model

This commit is contained in:
gwdwyy 2013-08-24 22:24:54 +08:00
parent 8ab598eaeb
commit d4b69f9e58
8 changed files with 166 additions and 10 deletions

View File

@ -4,10 +4,13 @@ namespace CppJieba
{
HMMSegment::HMMSegment()
{
memset(_startProb, 0, sizeof(_startProb));
memset(_transProb, 0, sizeof(_transProb));
}
HMMSegment::~HMMSegment()
{
}
bool HMMSegment::init()
@ -20,10 +23,133 @@ namespace CppJieba
return true;
}
bool HMMSegment::loadModel()
bool HMMSegment::loadModel(const char* const filePath)
{
ifstream ifile(filePath);
string line;
vector<string> tmp;
vector<string> tmp2;
//load _startProb
if(!_getLine(ifile, line))
{
return false;
}
splitStr(line, tmp, " ");
if(tmp.size() != STATUS_SUM)
{
LogError("start_p illegal");
return false;
}
for(uint j = 0; j< tmp.size(); j++)
{
_startProb[j] = atof(tmp[j].c_str());
//cout<<_startProb[j]<<endl;
}
//load _transProb
for(uint i = 0; i < STATUS_SUM; i++)
{
if(!_getLine(ifile, line))
{
return false;
}
splitStr(line, tmp, " ");
if(tmp.size() != STATUS_SUM)
{
LogError("trans_p illegal");
return false;
}
for(uint j =0; j < STATUS_SUM; j++)
{
_transProb[i][j] = atof(tmp[j].c_str());
//cout<<_transProb[i][j]<<endl;
}
}
//load _emitProbB
if(!_getLine(ifile, line) || !_loadEmitProb(line, _emitProbB))
{
return false;
}
//load _emitProbE
if(!_getLine(ifile, line) || !_loadEmitProb(line, _emitProbE))
{
return false;
}
//load _emitProbM
if(!_getLine(ifile, line) || !_loadEmitProb(line, _emitProbM))
{
return false;
}
//load _emitProbS
if(!_getLine(ifile, line) || !_loadEmitProb(line, _emitProbS))
{
return false;
}
return true;
}
bool HMMSegment::_getLine(ifstream& ifile, string& line)
{
while(getline(ifile, line))
{
trim(line);
if(line.empty())
{
continue;
}
if(strStartsWith(line, "#"))
{
continue;
}
return true;
}
return false;
}
bool HMMSegment::_loadEmitProb(const string& line, EmitProbMap& mp)
{
if(line.empty())
{
return false;
}
vector<string> tmp, tmp2;
uint16_t unico;
splitStr(line, tmp, ",");
for(uint i = 0; i < tmp.size(); i++)
{
splitStr(tmp[i], tmp2, ":");
if(2 != tmp2.size())
{
LogError("_emitProb illegal.");
return false;
}
if(!_decodeOne(tmp2[0], unico))
{
LogError("TransCode failed.");
return false;
}
mp[unico] = atof(tmp2[1].c_str());
}
return true;
}
bool HMMSegment::_decodeOne(const string& str, uint16_t& res)
{
vector<uint16_t> ui16;
if(!TransCode::strToVec(str, ui16) || ui16.size() != 1)
{
return false;
}
res = ui16[0];
return true;
}
}
@ -32,7 +158,10 @@ using namespace CppJieba;
int main()
{
TransCode::setUtf8Enc();
HMMSegment hmm;
hmm.loadModel("../dicts/hmm_model.utf8");
//cout<<MIN_DOUBLE<<endl;
return 0;
}

View File

@ -1,10 +1,31 @@
#ifndef CPPJIBEA_HMMSEGMENT_H
#define CPPJIBEA_HMMSEGMENT_H
#include <iostream>
#include <fstream>
#include <memory.h>
#include "cppcommon/headers.h"
#include "globals.h"
#include "TransCode.h"
namespace CppJieba
{
using namespace CPPCOMMON;
class HMMSegment
{
private:
/*
* STATUS:
* 0:B, 1:E, 2:M, 3:S
* */
enum {STATUS_SUM = 4};
double _startProb[STATUS_SUM];
double _transProb[STATUS_SUM][STATUS_SUM];
EmitProbMap _emitProbB;
EmitProbMap _emitProbE;
EmitProbMap _emitProbM;
EmitProbMap _emitProbS;
public:
HMMSegment();
~HMMSegment();
@ -12,7 +33,11 @@ namespace CppJieba
bool init();
bool dispose();
public:
bool loadModel();
bool loadModel(const char* const filePath);
private:
bool _getLine(ifstream& ifile, string& line);
bool _loadEmitProb(const string& line, EmitProbMap& mp);
bool _decodeOne(const string& str, uint16_t& res);
};
}

View File

@ -61,6 +61,8 @@ KeyWordExt.ut: KeyWordExt.cpp KeyWordExt.h Segment.h Trie.h globals.h TransCode.
TransCode.ut: TransCode.cpp TransCode.h globals.h $(CMLIB)
$(CXX) -o $@ $(CXXFLAGS) TransCode.cpp -DCPPJIEBA_TRANSCODE_UT $(CMLIB)
HMMSegment.ut: HMMSegment.cpp TransCode.cpp TransCode.h HMMSegment.h $(CMLIB)
$(CXX) -o $@ $(CXXFLAGS) TransCode.cpp HMMSegment.cpp -DHMMSEGMENT_UT $(CMLIB)
clean:
rm -f *.o *.d *.ut $(LIBA)

View File

@ -138,7 +138,7 @@ namespace CppJieba
{
// calc max
segContext.dp[i].first = NULL;
segContext.dp[i].second = -(numeric_limits<double>::max());
segContext.dp[i].second = MIN_DOUBLE;
for(uint j = 0; j < segContext.dag[i].size(); j++)
{
const pair<uint , const TrieNodeInfo*>& p = segContext.dag[i][j];

View File

@ -21,7 +21,7 @@ namespace CppJieba
_root = NULL;
_freqSum = 0;
_minLogFreq = numeric_limits<double>::max();
_minLogFreq = MAX_DOUBLE;
_initFlag = false;
}

View File

@ -53,10 +53,11 @@ namespace CPPCOMMON
}
void splitStr(const string& source, vector<string>& out_vec, const string& pattern)
{
if(0 == pattern.size())
if(source.empty())
{
return;
}
out_vec.clear();
string s = source + pattern;
string::size_type pos;
uint length = s.size();

View File

@ -20,16 +20,15 @@ namespace CppJieba
using __gnu_cxx::hash_map;
//using namespace stdext;
//typedefs
const double MIN_DOUBLE = -3.14e+100;
const double MAX_DOUBLE = 3.14e+100;
typedef unsigned int uint;
typedef std::vector<std::string>::iterator VSI;
typedef std::vector<uint16_t> VUINT16;
typedef std::vector<uint16_t>::const_iterator VUINT16_CONST_ITER;
typedef hash_map<uint16_t, struct TrieNode*> TrieNodeMap;
typedef hash_map<uint16_t, double> EmitProbMap;
namespace HMMDict
{
}
}

View File

@ -26,7 +26,7 @@ namespace CppJieba
word = _word;
wLen = TransCode::getWordLength(_word);
freq = 0;
logFreq = -numeric_limits<double>::max();
logFreq = MIN_DOUBLE;
}
};