remove InitOnOff to make code lighter

This commit is contained in:
wyy 2014-08-12 00:34:37 +08:00
parent 5bfd3d0c49
commit 9571a4d0d5
12 changed files with 37 additions and 96 deletions

View File

@ -10,7 +10,6 @@
#include <limits> #include <limits>
#include "Limonp/StringUtil.hpp" #include "Limonp/StringUtil.hpp"
#include "Limonp/Logger.hpp" #include "Limonp/Logger.hpp"
#include "Limonp/InitOnOff.hpp"
#include "TransCode.hpp" #include "TransCode.hpp"
#include "Trie.hpp" #include "Trie.hpp"
@ -41,7 +40,7 @@ namespace CppJieba
typedef map<size_t, const DictUnit*> DagType; typedef map<size_t, const DictUnit*> DagType;
class DictTrie: public InitOnOff class DictTrie
{ {
public: public:
typedef Trie<Unicode::value_type, DictUnit, Unicode, vector<Unicode>, vector<const DictUnit*> > TrieType; typedef Trie<Unicode::value_type, DictUnit, Unicode, vector<Unicode>, vector<const DictUnit*> > TrieType;
@ -65,12 +64,11 @@ namespace CppJieba
{ {
_trie = NULL; _trie = NULL;
_minWeight = MAX_DOUBLE; _minWeight = MAX_DOUBLE;
_setInitFlag(false);
} }
DictTrie(const string& dictPath, const string& userDictPath = "") DictTrie(const string& dictPath, const string& userDictPath = "")
{ {
new (this) DictTrie(); new (this) DictTrie();
_setInitFlag(init(dictPath, userDictPath)); init(dictPath, userDictPath);
} }
~DictTrie() ~DictTrie()
{ {
@ -83,7 +81,7 @@ namespace CppJieba
public: public:
bool init(const string& dictPath, const string& userDictPath = "") bool init(const string& dictPath, const string& userDictPath = "")
{ {
assert(!_getInitFlag()); assert(!_trie);
_loadDict(dictPath, _nodeInfos); _loadDict(dictPath, _nodeInfos);
_calculateWeight(_nodeInfos); _calculateWeight(_nodeInfos);
_minWeight = _findMinWeight(_nodeInfos); _minWeight = _findMinWeight(_nodeInfos);
@ -96,7 +94,7 @@ namespace CppJieba
_shrink(_nodeInfos); _shrink(_nodeInfos);
_trie = _creatTrie(_nodeInfos); _trie = _creatTrie(_nodeInfos);
assert(_trie); assert(_trie);
return _setInitFlag(true); return true;
} }
public: public:

View File

@ -35,20 +35,15 @@ namespace CppJieba
vector<EmitProbMap* > _emitProbVec; vector<EmitProbMap* > _emitProbVec;
public: public:
HMMSegment(){_setInitFlag(false);} HMMSegment(){}
explicit HMMSegment(const string& filePath) explicit HMMSegment(const string& filePath)
{ {
_setInitFlag(init(filePath)); LIMONP_CHECK(init(filePath));
} }
virtual ~HMMSegment(){} virtual ~HMMSegment(){}
public: public:
bool init(const string& filePath) bool init(const string& filePath)
{ {
if(_getInitFlag())
{
LogError("inited already.");
return false;
}
memset(_startProb, 0, sizeof(_startProb)); memset(_startProb, 0, sizeof(_startProb));
memset(_transProb, 0, sizeof(_transProb)); memset(_transProb, 0, sizeof(_transProb));
_statMap[0] = 'B'; _statMap[0] = 'B';
@ -59,11 +54,7 @@ namespace CppJieba
_emitProbVec.push_back(&_emitProbE); _emitProbVec.push_back(&_emitProbE);
_emitProbVec.push_back(&_emitProbM); _emitProbVec.push_back(&_emitProbM);
_emitProbVec.push_back(&_emitProbS); _emitProbVec.push_back(&_emitProbS);
if(!_setInitFlag(_loadModel(filePath.c_str()))) LIMONP_CHECK(_loadModel(filePath.c_str()));
{
LogError("_loadModel(%s) failed.", filePath.c_str());
return false;
}
LogInfo("HMMSegment init(%s) ok.", filePath.c_str()); LogInfo("HMMSegment init(%s) ok.", filePath.c_str());
return true; return true;
} }
@ -104,7 +95,6 @@ namespace CppJieba
private: private:
bool _cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector<Unicode>& res) const bool _cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector<Unicode>& res) const
{ {
assert(_getInitFlag());
vector<size_t> status; vector<size_t> status;
if(!_viterbi(begin, end, status)) if(!_viterbi(begin, end, status))
{ {
@ -128,7 +118,6 @@ namespace CppJieba
public: public:
virtual bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector<string>& res)const virtual bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector<string>& res)const
{ {
assert(_getInitFlag());
if(begin == end) if(begin == end)
{ {
return false; return false;

View File

@ -10,7 +10,7 @@ namespace CppJieba
using namespace Limonp; using namespace Limonp;
/*utf8*/ /*utf8*/
class KeywordExtractor: public InitOnOff class KeywordExtractor
{ {
private: private:
MixSegment _segment; MixSegment _segment;
@ -20,10 +20,10 @@ namespace CppJieba
unordered_set<string> _stopWords; unordered_set<string> _stopWords;
public: public:
KeywordExtractor(){_setInitFlag(false);}; KeywordExtractor(){};
KeywordExtractor(const string& dictPath, const string& hmmFilePath, const string& idfPath, const string& stopWordPath) KeywordExtractor(const string& dictPath, const string& hmmFilePath, const string& idfPath, const string& stopWordPath)
{ {
_setInitFlag(init(dictPath, hmmFilePath, idfPath, stopWordPath)); LIMONP_CHECK(init(dictPath, hmmFilePath, idfPath, stopWordPath));
}; };
~KeywordExtractor(){}; ~KeywordExtractor(){};
@ -32,13 +32,13 @@ namespace CppJieba
{ {
_loadIdfDict(idfPath); _loadIdfDict(idfPath);
_loadStopWordDict(stopWordPath); _loadStopWordDict(stopWordPath);
return _setInitFlag(_segment.init(dictPath, hmmFilePath)); LIMONP_CHECK(_segment.init(dictPath, hmmFilePath));
return true;
}; };
public: public:
bool extract(const string& str, vector<string>& keywords, size_t topN) const bool extract(const string& str, vector<string>& keywords, size_t topN) const
{ {
assert(_getInitFlag());
vector<pair<string, double> > topWords; vector<pair<string, double> > topWords;
if(!extract(str, topWords, topN)) if(!extract(str, topWords, topN))
{ {

View File

@ -4,7 +4,7 @@
#include <stdio.h> #include <stdio.h>
#define LIMONP_CHECK(exp) \ #define LIMONP_CHECK(exp) \
if(exp){fprintf(stderr, "File:%s, Line:%d Exp:[" #exp "] is true, abort.\n", __FILE__, __LINE__); abort();} if(!(exp)){fprintf(stderr, "File:%s, Line:%d Exp:[" #exp "] is true, abort.\n", __FILE__, __LINE__); abort();}
#define print(x) cout<< #x": " << x <<endl #define print(x) cout<< #x": " << x <<endl
/* /*

View File

@ -34,24 +34,18 @@ namespace CppJieba
DictTrie _dictTrie; DictTrie _dictTrie;
public: public:
MPSegment(){_setInitFlag(false);}; MPSegment(){};
explicit MPSegment(const string& dictPath, const string& userDictPath = "") MPSegment(const string& dictPath, const string& userDictPath = "")
{ {
_setInitFlag(init(dictPath, userDictPath)); LIMONP_CHECK(init(dictPath, userDictPath));
}; };
virtual ~MPSegment(){}; virtual ~MPSegment(){};
public: public:
bool init(const string& dictPath, const string& userDictPath = "") bool init(const string& dictPath, const string& userDictPath = "")
{ {
if(_getInitFlag()) LIMONP_CHECK(_dictTrie.init(dictPath, userDictPath));
{
LogError("already inited before now.");
return false;
}
_dictTrie.init(dictPath, userDictPath);
assert(_dictTrie);
LogInfo("MPSegment init(%s) ok", dictPath.c_str()); LogInfo("MPSegment init(%s) ok", dictPath.c_str());
return _setInitFlag(true); return true;
} }
bool isUserDictSingleChineseWord(const Unicode::value_type & value) const bool isUserDictSingleChineseWord(const Unicode::value_type & value) const
{ {
@ -61,7 +55,6 @@ namespace CppJieba
using SegmentBase::cut; using SegmentBase::cut;
virtual bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector<string>& res)const virtual bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector<string>& res)const
{ {
assert(_getInitFlag());
if(begin == end) if(begin == end)
{ {
return false; return false;
@ -92,7 +85,6 @@ namespace CppJieba
{ {
return false; return false;
} }
assert(_getInitFlag());
vector<SegmentChar> segmentChars(end - begin); vector<SegmentChar> segmentChars(end - begin);
//calc DAG //calc DAG

View File

@ -14,36 +14,25 @@ namespace CppJieba
MPSegment _mpSeg; MPSegment _mpSeg;
HMMSegment _hmmSeg; HMMSegment _hmmSeg;
public: public:
MixSegment(){_setInitFlag(false);}; MixSegment(){};
MixSegment(const string& mpSegDict, const string& hmmSegDict, const string& userDict = "") MixSegment(const string& mpSegDict, const string& hmmSegDict, const string& userDict = "")
{ {
_setInitFlag(init(mpSegDict, hmmSegDict, userDict)); LIMONP_CHECK(init(mpSegDict, hmmSegDict, userDict));
assert(_getInitFlag());
} }
virtual ~MixSegment(){} virtual ~MixSegment(){}
public: public:
bool init(const string& mpSegDict, const string& hmmSegDict, const string& userDict = "") bool init(const string& mpSegDict, const string& hmmSegDict, const string& userDict = "")
{ {
assert(!_getInitFlag()); LIMONP_CHECK(_mpSeg.init(mpSegDict, userDict));
if(!_mpSeg.init(mpSegDict, userDict)) LIMONP_CHECK(_hmmSeg.init(hmmSegDict));
{
LogError("_mpSeg init");
return false;
}
if(!_hmmSeg.init(hmmSegDict))
{
LogError("_hmmSeg init");
return false;
}
LogInfo("MixSegment init(%s, %s)", mpSegDict.c_str(), hmmSegDict.c_str()); LogInfo("MixSegment init(%s, %s)", mpSegDict.c_str(), hmmSegDict.c_str());
return _setInitFlag(true); return true;
} }
public: public:
using SegmentBase::cut; using SegmentBase::cut;
public: public:
virtual bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector<Unicode>& res) const virtual bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector<Unicode>& res) const
{ {
assert(_getInitFlag());
vector<Unicode> words; vector<Unicode> words;
words.reserve(end - begin); words.reserve(end - begin);
if(!_mpSeg.cut(begin, end, words)) if(!_mpSeg.cut(begin, end, words))
@ -98,7 +87,6 @@ namespace CppJieba
virtual bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector<string>& res)const virtual bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector<string>& res)const
{ {
assert(_getInitFlag());
if(begin == end) if(begin == end)
{ {
return false; return false;

View File

@ -9,32 +9,29 @@ namespace CppJieba
{ {
using namespace Limonp; using namespace Limonp;
class PosTagger: public InitOnOff class PosTagger
{ {
private: private:
MixSegment _segment; MixSegment _segment;
DictTrie _dictTrie; DictTrie _dictTrie;
public: public:
PosTagger(){_setInitFlag(false);}; PosTagger(){};
PosTagger(const string& dictPath, const string& hmmFilePath, const string& charStatus, const string& startProb, const string& emitProb, const string& endProb, const string& transProb) PosTagger(const string& dictPath, const string& hmmFilePath, const string& charStatus, const string& startProb, const string& emitProb, const string& endProb, const string& transProb)
{ {
_setInitFlag(init(dictPath, hmmFilePath, charStatus, startProb, emitProb, endProb, transProb)); LIMONP_CHECK(init(dictPath, hmmFilePath, charStatus, startProb, emitProb, endProb, transProb));
}; };
~PosTagger(){}; ~PosTagger(){};
public: public:
bool init(const string& dictPath, const string& hmmFilePath, const string& charStatus, const string& startProb, const string& emitProb, const string& endProb, const string& transProb) bool init(const string& dictPath, const string& hmmFilePath, const string& charStatus, const string& startProb, const string& emitProb, const string& endProb, const string& transProb)
{ {
LIMONP_CHECK(_dictTrie.init(dictPath));
assert(!_getInitFlag()); LIMONP_CHECK(_segment.init(dictPath, hmmFilePath));
_dictTrie.init(dictPath); return true;
assert(_dictTrie);
return _setInitFlag(_segment.init(dictPath, hmmFilePath));
}; };
bool tag(const string& src, vector<pair<string, string> >& res) bool tag(const string& src, vector<pair<string, string> >& res)
{ {
assert(_getInitFlag());
vector<string> cutRes; vector<string> cutRes;
if (!_segment.cut(src, cutRes)) if (!_segment.cut(src, cutRes))
{ {

View File

@ -23,32 +23,20 @@ namespace CppJieba
size_t _maxWordLen; size_t _maxWordLen;
public: public:
QuerySegment(){_setInitFlag(false);}; QuerySegment(){};
QuerySegment(const string& dict, const string& model, size_t maxWordLen) QuerySegment(const string& dict, const string& model, size_t maxWordLen)
{ {
_setInitFlag(init(dict, model, maxWordLen)); init(dict, model, maxWordLen);
}; };
virtual ~QuerySegment(){}; virtual ~QuerySegment(){};
public: public:
bool init(const string& dict, const string& model, size_t maxWordLen) bool init(const string& dict, const string& model, size_t maxWordLen)
{ {
if (_getInitFlag()) LIMONP_CHECK(_mixSeg.init(dict, model));
{ LIMONP_CHECK(_fullSeg.init(_mixSeg.getDictTrie()));
LogError("inited already."); assert(maxWordLen);
return false;
}
if (!_mixSeg.init(dict, model))
{
LogError("_mixSeg init");
return false;
}
if (!_fullSeg.init(_mixSeg.getDictTrie()))
{
LogError("_fullSeg init");
return false;
}
_maxWordLen = maxWordLen; _maxWordLen = maxWordLen;
return _setInitFlag(true); return true;
} }
public: public:
@ -57,7 +45,6 @@ namespace CppJieba
public: public:
bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector<Unicode>& res) const bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector<Unicode>& res) const
{ {
assert(_getInitFlag());
if (begin >= end) if (begin >= end)
{ {
LogError("begin >= end"); LogError("begin >= end");
@ -102,7 +89,6 @@ namespace CppJieba
bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector<string>& res) const bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector<string>& res) const
{ {
assert(_getInitFlag());
if (begin >= end) if (begin >= end)
{ {
LogError("begin >= end"); LogError("begin >= end");

View File

@ -3,8 +3,8 @@
#include "TransCode.hpp" #include "TransCode.hpp"
#include "Limonp/Logger.hpp" #include "Limonp/Logger.hpp"
#include "Limonp/InitOnOff.hpp"
#include "Limonp/NonCopyable.hpp" #include "Limonp/NonCopyable.hpp"
#include "Limonp/HandyMacro.hpp"
#include "ISegment.hpp" #include "ISegment.hpp"
#include <cassert> #include <cassert>
@ -20,7 +20,7 @@ namespace CppJieba
const UnicodeValueType SPECIAL_SYMBOL[] = {32u, 9u, 10u}; const UnicodeValueType SPECIAL_SYMBOL[] = {32u, 9u, 10u};
#endif #endif
class SegmentBase: public ISegment, public InitOnOff, public NonCopyable class SegmentBase: public ISegment, public NonCopyable
{ {
public: public:
SegmentBase(){_loadSpecialSymbols();}; SegmentBase(){_loadSpecialSymbols();};

View File

@ -11,7 +11,6 @@ using namespace CppJieba;
void cut(size_t times = 20) void cut(size_t times = 20)
{ {
MixSegment seg("../dict/jieba.dict.utf8", "../dict/hmm_model.utf8"); MixSegment seg("../dict/jieba.dict.utf8", "../dict/hmm_model.utf8");
assert(seg);
vector<string> res; vector<string> res;
string doc; string doc;
ifstream ifs("../test/testdata/weicheng.utf8"); ifstream ifs("../test/testdata/weicheng.utf8");
@ -32,7 +31,6 @@ void cut(size_t times = 20)
void extract(size_t times = 400) void extract(size_t times = 400)
{ {
KeywordExtractor extractor("../dict/jieba.dict.utf8", "../dict/hmm_model.utf8", "../dict/idf.utf8", "../dict/stop_words.utf8"); KeywordExtractor extractor("../dict/jieba.dict.utf8", "../dict/hmm_model.utf8", "../dict/idf.utf8", "../dict/stop_words.utf8");
assert(extractor);
vector<string> words; vector<string> words;
string doc; string doc;
ifstream ifs("../test/testdata/review.100"); ifstream ifs("../test/testdata/review.100");

View File

@ -17,8 +17,6 @@ TEST(MixSegmentTest, Test1)
const char* str2 = "B超 T恤"; const char* str2 = "B超 T恤";
const char* res2[] = {"B超"," ", "T恤"}; const char* res2[] = {"B超"," ", "T恤"};
vector<string> words; vector<string> words;
ASSERT_TRUE(segment);
ASSERT_TRUE(segment.cut(str, words)); ASSERT_TRUE(segment.cut(str, words));
ASSERT_EQ(words, vector<string>(res, res + sizeof(res)/sizeof(res[0]))); ASSERT_EQ(words, vector<string>(res, res + sizeof(res)/sizeof(res[0])));
ASSERT_TRUE(segment.cut(str2, words)); ASSERT_TRUE(segment.cut(str2, words));
@ -29,7 +27,6 @@ TEST(MixSegmentTest, Test1)
TEST(MixSegmentTest, NoUserDict) TEST(MixSegmentTest, NoUserDict)
{ {
MixSegment segment("../dict/extra_dict/jieba.dict.small.utf8", "../dict/hmm_model.utf8"); MixSegment segment("../dict/extra_dict/jieba.dict.small.utf8", "../dict/hmm_model.utf8");
ASSERT_TRUE(segment);
const char* str = "令狐冲是云计算方面的专家"; const char* str = "令狐冲是云计算方面的专家";
vector<string> words; vector<string> words;
ASSERT_TRUE(segment.cut(str, words)); ASSERT_TRUE(segment.cut(str, words));
@ -40,7 +37,6 @@ TEST(MixSegmentTest, NoUserDict)
TEST(MixSegmentTest, UserDict) TEST(MixSegmentTest, UserDict)
{ {
MixSegment segment("../dict/extra_dict/jieba.dict.small.utf8", "../dict/hmm_model.utf8", "../test/testdata/userdict.utf8"); MixSegment segment("../dict/extra_dict/jieba.dict.small.utf8", "../dict/hmm_model.utf8", "../test/testdata/userdict.utf8");
ASSERT_TRUE(segment);
const char* str = "令狐冲是云计算方面的专家"; const char* str = "令狐冲是云计算方面的专家";
vector<string> words; vector<string> words;
ASSERT_TRUE(segment.cut(str, words)); ASSERT_TRUE(segment.cut(str, words));
@ -55,7 +51,6 @@ TEST(MPSegmentTest, Test1)
const char* str = "我来自北京邮电大学。"; const char* str = "我来自北京邮电大学。";
const char* res[] = {"", "来自", "北京邮电大学", ""}; const char* res[] = {"", "来自", "北京邮电大学", ""};
vector<string> words; vector<string> words;
ASSERT_TRUE(segment);
ASSERT_TRUE(segment.cut(str, words)); ASSERT_TRUE(segment.cut(str, words));
ASSERT_EQ(words, vector<string>(res, res + sizeof(res)/sizeof(res[0]))); ASSERT_EQ(words, vector<string>(res, res + sizeof(res)/sizeof(res[0])));
@ -105,7 +100,6 @@ TEST(HMMSegmentTest, Test1)
const char* str = "我来自北京邮电大学。。。学号123456"; const char* str = "我来自北京邮电大学。。。学号123456";
const char* res[] = {"我来", "自北京", "邮电大学", "", "", "", "学号", "123456"}; const char* res[] = {"我来", "自北京", "邮电大学", "", "", "", "学号", "123456"};
vector<string> words; vector<string> words;
ASSERT_TRUE(segment);
ASSERT_TRUE(segment.cut(str, words)); ASSERT_TRUE(segment.cut(str, words));
ASSERT_EQ(words, vector<string>(res, res + sizeof(res)/sizeof(res[0]))); ASSERT_EQ(words, vector<string>(res, res + sizeof(res)/sizeof(res[0])));
} }

View File

@ -55,7 +55,6 @@ TEST(DictTrieTest, Test1)
TEST(DictTrieTest, UserDict) TEST(DictTrieTest, UserDict)
{ {
DictTrie trie(DICT_FILE, "../test/testdata/userdict.utf8"); DictTrie trie(DICT_FILE, "../test/testdata/userdict.utf8");
ASSERT_TRUE(trie);
string word = "云计算"; string word = "云计算";
Unicode unicode; Unicode unicode;
ASSERT_TRUE(TransCode::decode(word, unicode)); ASSERT_TRUE(TransCode::decode(word, unicode));