diff --git a/src/FullSegment.hpp b/src/FullSegment.hpp index 624383e..55594be 100644 --- a/src/FullSegment.hpp +++ b/src/FullSegment.hpp @@ -13,40 +13,23 @@ namespace CppJieba { class FullSegment: public SegmentBase { public: - FullSegment() { - dictTrie_ = NULL; - isBorrowed_ = false; + FullSegment(const string& dictPath) { + dictTrie_ = new DictTrie(dictPath); + isNeedDestroy_ = true; + LogInfo("FullSegment init %s ok", dictPath.c_str()); } - explicit FullSegment(const string& dictPath) { - dictTrie_ = NULL; - init(dictPath); - } - explicit FullSegment(const DictTrie* dictTrie) { - dictTrie_ = NULL; - init(dictTrie); + FullSegment(const DictTrie* dictTrie) + : dictTrie_(dictTrie), isNeedDestroy_(false) { + assert(dictTrie_); } virtual ~FullSegment() { - if(dictTrie_ && ! isBorrowed_) { + if(isNeedDestroy_) { delete dictTrie_; } - - }; - bool init(const string& dictPath) { - assert(dictTrie_ == NULL); - dictTrie_ = new DictTrie(dictPath); - isBorrowed_ = false; - return true; } - bool init(const DictTrie* dictTrie) { - assert(dictTrie_ == NULL); - assert(dictTrie); - dictTrie_ = dictTrie; - isBorrowed_ = true; - return true; - } - using SegmentBase::cut; - bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector& res) const { + bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, + vector& res) const { //resut of searching in trie tree DagType tRes; @@ -87,7 +70,8 @@ class FullSegment: public SegmentBase { return true; } - bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector& res) const { + bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, + vector& res) const { vector uRes; if (!cut(begin, end, uRes)) { LogError("get unicode cut result error."); @@ -95,7 +79,8 @@ class FullSegment: public SegmentBase { } string tmp; - for (vector::const_iterator uItr = uRes.begin(); uItr != uRes.end(); uItr++) { + for (vector::const_iterator uItr = uRes.begin(); + uItr != uRes.end(); uItr++) { TransCode::encode(*uItr, tmp); res.push_back(tmp); } @@ -104,7 +89,7 @@ class FullSegment: public SegmentBase { } private: const DictTrie* dictTrie_; - bool isBorrowed_; + bool isNeedDestroy_; }; } diff --git a/src/HMMSegment.hpp b/src/HMMSegment.hpp index 78c4dfd..f903b02 100644 --- a/src/HMMSegment.hpp +++ b/src/HMMSegment.hpp @@ -12,9 +12,17 @@ namespace CppJieba { class HMMSegment: public SegmentBase { public: - explicit HMMSegment(const string& filePath): model_(filePath) { + HMMSegment(const string& filePath) { + model_ = new HMMModel(filePath); + } + HMMSegment(const HMMModel* model) + : model_(model), isNeedDestroy_(false) { + } + virtual ~HMMSegment() { + if(isNeedDestroy_) { + delete model_; + } } - virtual ~HMMSegment() {} using SegmentBase::cut; bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector& res)const { @@ -138,7 +146,7 @@ class HMMSegment: public SegmentBase { //start for(size_t y = 0; y < Y; y++) { - weight[0 + y * X] = model_.startProb[y] + model_.getEmitProb(model_.emitProbVec[y], *begin, MIN_DOUBLE); + weight[0 + y * X] = model_->startProb[y] + model_->getEmitProb(model_->emitProbVec[y], *begin, MIN_DOUBLE); path[0 + y * X] = -1; } @@ -149,10 +157,10 @@ class HMMSegment: public SegmentBase { now = x + y*X; weight[now] = MIN_DOUBLE; path[now] = HMMModel::E; // warning - emitProb = model_.getEmitProb(model_.emitProbVec[y], *(begin+x), MIN_DOUBLE); + emitProb = model_->getEmitProb(model_->emitProbVec[y], *(begin+x), MIN_DOUBLE); for(size_t preY = 0; preY < Y; preY++) { old = x - 1 + preY * X; - tmp = weight[old] + model_.transProb[preY][y] + emitProb; + tmp = weight[old] + model_->transProb[preY][y] + emitProb; if(tmp > weight[now]) { weight[now] = tmp; path[now] = preY; @@ -179,7 +187,9 @@ class HMMSegment: public SegmentBase { return true; } - HMMModel model_; + private: + const HMMModel* model_; + bool isNeedDestroy_; }; // class HMMSegment } // namespace CppJieba diff --git a/src/KeywordExtractor.hpp b/src/KeywordExtractor.hpp index dee73c5..11fcea8 100644 --- a/src/KeywordExtractor.hpp +++ b/src/KeywordExtractor.hpp @@ -20,6 +20,14 @@ class KeywordExtractor { loadIdfDict_(idfPath); loadStopWordDict_(stopWordPath); } + KeywordExtractor(const DictTrie* dictTrie, + const HMMModel* model, + const string& idfPath, + const string& stopWordPath) + : segment_(dictTrie, model){ + loadIdfDict_(idfPath); + loadStopWordDict_(stopWordPath); + } ~KeywordExtractor() { } diff --git a/src/MPSegment.hpp b/src/MPSegment.hpp index a0d3a57..0e30baf 100644 --- a/src/MPSegment.hpp +++ b/src/MPSegment.hpp @@ -12,28 +12,28 @@ namespace CppJieba { class MPSegment: public SegmentBase { - public: - MPSegment() {}; MPSegment(const string& dictPath, const string& userDictPath = "") { - init(dictPath, userDictPath); - }; - virtual ~MPSegment() {}; - - void init(const string& dictPath, const string& userDictPath = "") { - dictTrie_.init(dictPath, userDictPath); + dictTrie_ = new DictTrie(dictPath, userDictPath); + isNeedDestroy_ = true; LogInfo("MPSegment init(%s) ok", dictPath.c_str()); } + MPSegment(const DictTrie* dictTrie) + : dictTrie_(dictTrie), isNeedDestroy_(false) { + assert(dictTrie_); + } + virtual ~MPSegment() { + if(isNeedDestroy_) { + delete dictTrie_; + } + } + bool isUserDictSingleChineseWord(const Unicode::value_type & value) const { - return dictTrie_.isUserDictSingleChineseWord(value); + return dictTrie_->isUserDictSingleChineseWord(value); } using SegmentBase::cut; virtual bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector& res)const { - if(begin == end) { - return false; - } - vector words; words.reserve(end - begin); if(!cut(begin, end, words)) { @@ -48,12 +48,9 @@ class MPSegment: public SegmentBase { } bool cut(Unicode::const_iterator begin , Unicode::const_iterator end, vector& res) const { - if(end == begin) { - return false; - } vector segmentChars; - dictTrie_.find(begin, end, segmentChars); + dictTrie_->find(begin, end, segmentChars); calcDP_(segmentChars); @@ -62,7 +59,7 @@ class MPSegment: public SegmentBase { return true; } const DictTrie* getDictTrie() const { - return &dictTrie_; + return dictTrie_; } private: @@ -86,7 +83,7 @@ class MPSegment: public SegmentBase { if(p) { val += p->weight; } else { - val += dictTrie_.getMinWeight(); + val += dictTrie_->getMinWeight(); } if(val > rit->weight) { rit->pInfo = p; @@ -95,7 +92,8 @@ class MPSegment: public SegmentBase { } } } - void cut_(const vector& segmentChars, vector& res) const { + void cut_(const vector& segmentChars, + vector& res) const { size_t i = 0; while(i < segmentChars.size()) { const DictUnit* p = segmentChars[i].pInfo; @@ -110,9 +108,10 @@ class MPSegment: public SegmentBase { } private: - DictTrie dictTrie_; + const DictTrie* dictTrie_; + bool isNeedDestroy_; +}; // class MPSegment -}; -} +} // namespace CppJieba #endif diff --git a/src/MixSegment.hpp b/src/MixSegment.hpp index 589d1ff..afe536d 100644 --- a/src/MixSegment.hpp +++ b/src/MixSegment.hpp @@ -15,6 +15,9 @@ class MixSegment: public SegmentBase { hmmSeg_(hmmSegDict) { LogInfo("MixSegment init %s, %s", mpSegDict.c_str(), hmmSegDict.c_str()); } + MixSegment(const DictTrie* dictTrie, const HMMModel* model) + : mpSeg_(dictTrie), hmmSeg_(model) { + } virtual ~MixSegment() { } using SegmentBase::cut; @@ -90,7 +93,9 @@ class MixSegment: public SegmentBase { private: MPSegment mpSeg_; HMMSegment hmmSeg_; -}; -} + +}; // class MixSegment + +} // namespace CppJieba #endif diff --git a/src/PosTagger.hpp b/src/PosTagger.hpp index b64e437..2bb9e76 100644 --- a/src/PosTagger.hpp +++ b/src/PosTagger.hpp @@ -18,8 +18,9 @@ class PosTagger { const string& hmmFilePath, const string& userDictPath = "") : segment_(dictPath, hmmFilePath, userDictPath) { - dictTrie_ = segment_.getDictTrie(); - LIMONP_CHECK(dictTrie_); + } + PosTagger(const DictTrie* dictTrie, const HMMModel* model) + : segment_(dictTrie, model) { } ~PosTagger() { } @@ -33,12 +34,14 @@ class PosTagger { const DictUnit *tmp = NULL; Unicode unico; + const DictTrie * dict = segment_.getDictTrie(); + assert(dict != NULL); for (vector::iterator itr = cutRes.begin(); itr != cutRes.end(); ++itr) { if (!TransCode::decode(*itr, unico)) { LogError("decode failed."); return false; } - tmp = dictTrie_->find(unico.begin(), unico.end()); + tmp = dict->find(unico.begin(), unico.end()); if(tmp == NULL || tmp->tag.empty()) { res.push_back(make_pair(*itr, specialRule_(unico))); } else { @@ -72,8 +75,8 @@ class PosTagger { } private: MixSegment segment_; - const DictTrie * dictTrie_; -}; -} +}; // class PosTagger + +} // namespace CppJieba #endif diff --git a/src/QuerySegment.hpp b/src/QuerySegment.hpp index 0c66481..be2f6d6 100644 --- a/src/QuerySegment.hpp +++ b/src/QuerySegment.hpp @@ -19,11 +19,15 @@ class QuerySegment: public SegmentBase { QuerySegment(const string& dict, const string& model, size_t maxWordLen = 4, const string& userDict = "") : mixSeg_(dict, model, userDict), - fullSeg_(mixSeg_.getDictTrie()) { - assert(maxWordLen); - maxWordLen_ = maxWordLen; - }; - virtual ~QuerySegment() {}; + fullSeg_(mixSeg_.getDictTrie()), + maxWordLen_(maxWordLen) { + assert(maxWordLen_); + } + QuerySegment(const DictTrie* dictTrie, const HMMModel* model) + : mixSeg_(dictTrie, model), fullSeg_(dictTrie) { + } + virtual ~QuerySegment() { + } using SegmentBase::cut; bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector& res) const { //use mix cut first