重构:增加让各个分词类的构造函数,为后面的憋大招做准备。

This commit is contained in:
yanyiwu 2015-06-04 22:38:55 +08:00
parent b99d0698f0
commit d56bf2cc68
7 changed files with 86 additions and 72 deletions

View File

@ -13,40 +13,23 @@
namespace CppJieba { namespace CppJieba {
class FullSegment: public SegmentBase { class FullSegment: public SegmentBase {
public: public:
FullSegment() { FullSegment(const string& dictPath) {
dictTrie_ = NULL; dictTrie_ = new DictTrie(dictPath);
isBorrowed_ = false; isNeedDestroy_ = true;
LogInfo("FullSegment init %s ok", dictPath.c_str());
} }
explicit FullSegment(const string& dictPath) { FullSegment(const DictTrie* dictTrie)
dictTrie_ = NULL; : dictTrie_(dictTrie), isNeedDestroy_(false) {
init(dictPath); assert(dictTrie_);
}
explicit FullSegment(const DictTrie* dictTrie) {
dictTrie_ = NULL;
init(dictTrie);
} }
virtual ~FullSegment() { virtual ~FullSegment() {
if(dictTrie_ && ! isBorrowed_) { if(isNeedDestroy_) {
delete dictTrie_; delete dictTrie_;
} }
};
bool init(const string& dictPath) {
assert(dictTrie_ == NULL);
dictTrie_ = new DictTrie(dictPath);
isBorrowed_ = false;
return true;
} }
bool init(const DictTrie* dictTrie) {
assert(dictTrie_ == NULL);
assert(dictTrie);
dictTrie_ = dictTrie;
isBorrowed_ = true;
return true;
}
using SegmentBase::cut; using SegmentBase::cut;
bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector<Unicode>& res) const { bool cut(Unicode::const_iterator begin, Unicode::const_iterator end,
vector<Unicode>& res) const {
//resut of searching in trie tree //resut of searching in trie tree
DagType tRes; DagType tRes;
@ -87,7 +70,8 @@ class FullSegment: public SegmentBase {
return true; return true;
} }
bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector<string>& res) const { bool cut(Unicode::const_iterator begin, Unicode::const_iterator end,
vector<string>& res) const {
vector<Unicode> uRes; vector<Unicode> uRes;
if (!cut(begin, end, uRes)) { if (!cut(begin, end, uRes)) {
LogError("get unicode cut result error."); LogError("get unicode cut result error.");
@ -95,7 +79,8 @@ class FullSegment: public SegmentBase {
} }
string tmp; string tmp;
for (vector<Unicode>::const_iterator uItr = uRes.begin(); uItr != uRes.end(); uItr++) { for (vector<Unicode>::const_iterator uItr = uRes.begin();
uItr != uRes.end(); uItr++) {
TransCode::encode(*uItr, tmp); TransCode::encode(*uItr, tmp);
res.push_back(tmp); res.push_back(tmp);
} }
@ -104,7 +89,7 @@ class FullSegment: public SegmentBase {
} }
private: private:
const DictTrie* dictTrie_; const DictTrie* dictTrie_;
bool isBorrowed_; bool isNeedDestroy_;
}; };
} }

View File

@ -12,9 +12,17 @@ namespace CppJieba {
class HMMSegment: public SegmentBase { class HMMSegment: public SegmentBase {
public: public:
explicit HMMSegment(const string& filePath): model_(filePath) { HMMSegment(const string& filePath) {
model_ = new HMMModel(filePath);
}
HMMSegment(const HMMModel* model)
: model_(model), isNeedDestroy_(false) {
}
virtual ~HMMSegment() {
if(isNeedDestroy_) {
delete model_;
}
} }
virtual ~HMMSegment() {}
using SegmentBase::cut; using SegmentBase::cut;
bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector<Unicode>& res)const { bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector<Unicode>& res)const {
@ -138,7 +146,7 @@ class HMMSegment: public SegmentBase {
//start //start
for(size_t y = 0; y < Y; y++) { for(size_t y = 0; y < Y; y++) {
weight[0 + y * X] = model_.startProb[y] + model_.getEmitProb(model_.emitProbVec[y], *begin, MIN_DOUBLE); weight[0 + y * X] = model_->startProb[y] + model_->getEmitProb(model_->emitProbVec[y], *begin, MIN_DOUBLE);
path[0 + y * X] = -1; path[0 + y * X] = -1;
} }
@ -149,10 +157,10 @@ class HMMSegment: public SegmentBase {
now = x + y*X; now = x + y*X;
weight[now] = MIN_DOUBLE; weight[now] = MIN_DOUBLE;
path[now] = HMMModel::E; // warning path[now] = HMMModel::E; // warning
emitProb = model_.getEmitProb(model_.emitProbVec[y], *(begin+x), MIN_DOUBLE); emitProb = model_->getEmitProb(model_->emitProbVec[y], *(begin+x), MIN_DOUBLE);
for(size_t preY = 0; preY < Y; preY++) { for(size_t preY = 0; preY < Y; preY++) {
old = x - 1 + preY * X; old = x - 1 + preY * X;
tmp = weight[old] + model_.transProb[preY][y] + emitProb; tmp = weight[old] + model_->transProb[preY][y] + emitProb;
if(tmp > weight[now]) { if(tmp > weight[now]) {
weight[now] = tmp; weight[now] = tmp;
path[now] = preY; path[now] = preY;
@ -179,7 +187,9 @@ class HMMSegment: public SegmentBase {
return true; return true;
} }
HMMModel model_; private:
const HMMModel* model_;
bool isNeedDestroy_;
}; // class HMMSegment }; // class HMMSegment
} // namespace CppJieba } // namespace CppJieba

View File

@ -20,6 +20,14 @@ class KeywordExtractor {
loadIdfDict_(idfPath); loadIdfDict_(idfPath);
loadStopWordDict_(stopWordPath); loadStopWordDict_(stopWordPath);
} }
KeywordExtractor(const DictTrie* dictTrie,
const HMMModel* model,
const string& idfPath,
const string& stopWordPath)
: segment_(dictTrie, model){
loadIdfDict_(idfPath);
loadStopWordDict_(stopWordPath);
}
~KeywordExtractor() { ~KeywordExtractor() {
} }

View File

@ -12,28 +12,28 @@
namespace CppJieba { namespace CppJieba {
class MPSegment: public SegmentBase { class MPSegment: public SegmentBase {
public: public:
MPSegment() {};
MPSegment(const string& dictPath, const string& userDictPath = "") { MPSegment(const string& dictPath, const string& userDictPath = "") {
init(dictPath, userDictPath); dictTrie_ = new DictTrie(dictPath, userDictPath);
}; isNeedDestroy_ = true;
virtual ~MPSegment() {};
void init(const string& dictPath, const string& userDictPath = "") {
dictTrie_.init(dictPath, userDictPath);
LogInfo("MPSegment init(%s) ok", dictPath.c_str()); LogInfo("MPSegment init(%s) ok", dictPath.c_str());
} }
MPSegment(const DictTrie* dictTrie)
: dictTrie_(dictTrie), isNeedDestroy_(false) {
assert(dictTrie_);
}
virtual ~MPSegment() {
if(isNeedDestroy_) {
delete dictTrie_;
}
}
bool isUserDictSingleChineseWord(const Unicode::value_type & value) const { bool isUserDictSingleChineseWord(const Unicode::value_type & value) const {
return dictTrie_.isUserDictSingleChineseWord(value); return dictTrie_->isUserDictSingleChineseWord(value);
} }
using SegmentBase::cut; using SegmentBase::cut;
virtual bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector<string>& res)const { virtual bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector<string>& res)const {
if(begin == end) {
return false;
}
vector<Unicode> words; vector<Unicode> words;
words.reserve(end - begin); words.reserve(end - begin);
if(!cut(begin, end, words)) { if(!cut(begin, end, words)) {
@ -48,12 +48,9 @@ class MPSegment: public SegmentBase {
} }
bool cut(Unicode::const_iterator begin , Unicode::const_iterator end, vector<Unicode>& res) const { bool cut(Unicode::const_iterator begin , Unicode::const_iterator end, vector<Unicode>& res) const {
if(end == begin) {
return false;
}
vector<SegmentChar> segmentChars; vector<SegmentChar> segmentChars;
dictTrie_.find(begin, end, segmentChars); dictTrie_->find(begin, end, segmentChars);
calcDP_(segmentChars); calcDP_(segmentChars);
@ -62,7 +59,7 @@ class MPSegment: public SegmentBase {
return true; return true;
} }
const DictTrie* getDictTrie() const { const DictTrie* getDictTrie() const {
return &dictTrie_; return dictTrie_;
} }
private: private:
@ -86,7 +83,7 @@ class MPSegment: public SegmentBase {
if(p) { if(p) {
val += p->weight; val += p->weight;
} else { } else {
val += dictTrie_.getMinWeight(); val += dictTrie_->getMinWeight();
} }
if(val > rit->weight) { if(val > rit->weight) {
rit->pInfo = p; rit->pInfo = p;
@ -95,7 +92,8 @@ class MPSegment: public SegmentBase {
} }
} }
} }
void cut_(const vector<SegmentChar>& segmentChars, vector<Unicode>& res) const { void cut_(const vector<SegmentChar>& segmentChars,
vector<Unicode>& res) const {
size_t i = 0; size_t i = 0;
while(i < segmentChars.size()) { while(i < segmentChars.size()) {
const DictUnit* p = segmentChars[i].pInfo; const DictUnit* p = segmentChars[i].pInfo;
@ -110,9 +108,10 @@ class MPSegment: public SegmentBase {
} }
private: private:
DictTrie dictTrie_; const DictTrie* dictTrie_;
bool isNeedDestroy_;
}; // class MPSegment
}; } // namespace CppJieba
}
#endif #endif

View File

@ -15,6 +15,9 @@ class MixSegment: public SegmentBase {
hmmSeg_(hmmSegDict) { hmmSeg_(hmmSegDict) {
LogInfo("MixSegment init %s, %s", mpSegDict.c_str(), hmmSegDict.c_str()); LogInfo("MixSegment init %s, %s", mpSegDict.c_str(), hmmSegDict.c_str());
} }
MixSegment(const DictTrie* dictTrie, const HMMModel* model)
: mpSeg_(dictTrie), hmmSeg_(model) {
}
virtual ~MixSegment() { virtual ~MixSegment() {
} }
using SegmentBase::cut; using SegmentBase::cut;
@ -90,7 +93,9 @@ class MixSegment: public SegmentBase {
private: private:
MPSegment mpSeg_; MPSegment mpSeg_;
HMMSegment hmmSeg_; HMMSegment hmmSeg_;
};
} }; // class MixSegment
} // namespace CppJieba
#endif #endif

View File

@ -18,8 +18,9 @@ class PosTagger {
const string& hmmFilePath, const string& hmmFilePath,
const string& userDictPath = "") const string& userDictPath = "")
: segment_(dictPath, hmmFilePath, userDictPath) { : segment_(dictPath, hmmFilePath, userDictPath) {
dictTrie_ = segment_.getDictTrie(); }
LIMONP_CHECK(dictTrie_); PosTagger(const DictTrie* dictTrie, const HMMModel* model)
: segment_(dictTrie, model) {
} }
~PosTagger() { ~PosTagger() {
} }
@ -33,12 +34,14 @@ class PosTagger {
const DictUnit *tmp = NULL; const DictUnit *tmp = NULL;
Unicode unico; Unicode unico;
const DictTrie * dict = segment_.getDictTrie();
assert(dict != NULL);
for (vector<string>::iterator itr = cutRes.begin(); itr != cutRes.end(); ++itr) { for (vector<string>::iterator itr = cutRes.begin(); itr != cutRes.end(); ++itr) {
if (!TransCode::decode(*itr, unico)) { if (!TransCode::decode(*itr, unico)) {
LogError("decode failed."); LogError("decode failed.");
return false; return false;
} }
tmp = dictTrie_->find(unico.begin(), unico.end()); tmp = dict->find(unico.begin(), unico.end());
if(tmp == NULL || tmp->tag.empty()) { if(tmp == NULL || tmp->tag.empty()) {
res.push_back(make_pair(*itr, specialRule_(unico))); res.push_back(make_pair(*itr, specialRule_(unico)));
} else { } else {
@ -72,8 +75,8 @@ class PosTagger {
} }
private: private:
MixSegment segment_; MixSegment segment_;
const DictTrie * dictTrie_; }; // class PosTagger
};
} } // namespace CppJieba
#endif #endif

View File

@ -19,11 +19,15 @@ class QuerySegment: public SegmentBase {
QuerySegment(const string& dict, const string& model, size_t maxWordLen = 4, QuerySegment(const string& dict, const string& model, size_t maxWordLen = 4,
const string& userDict = "") const string& userDict = "")
: mixSeg_(dict, model, userDict), : mixSeg_(dict, model, userDict),
fullSeg_(mixSeg_.getDictTrie()) { fullSeg_(mixSeg_.getDictTrie()),
assert(maxWordLen); maxWordLen_(maxWordLen) {
maxWordLen_ = maxWordLen; assert(maxWordLen_);
}; }
virtual ~QuerySegment() {}; QuerySegment(const DictTrie* dictTrie, const HMMModel* model)
: mixSeg_(dictTrie, model), fullSeg_(dictTrie) {
}
virtual ~QuerySegment() {
}
using SegmentBase::cut; using SegmentBase::cut;
bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector<Unicode>& res) const { bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector<Unicode>& res) const {
//use mix cut first //use mix cut first