Add SegmentContext to Segment; runs OK

This commit is contained in:
gwdwyy 2013-08-19 01:29:46 +08:00
parent 8f06d1340a
commit 346bc54c35
10 changed files with 156 additions and 157 deletions

View File

@ -29,7 +29,7 @@ void testKeyWordExt(const char * dictPath, const char * filePath)
if(!line.empty()) if(!line.empty())
{ {
ext.extract(line, res, 20); ext.extract(line, res, 20);
cout<<line<<"\n"<<joinStr(res," ")<<endl; cout<<line<<"\n"<<joinStr(res,",")<<endl;
} }
} }

View File

@ -82,17 +82,17 @@ namespace CppJieba
return true; return true;
} }
bool KeyWordExt::_wordInfoCompare(const WordInfo& a, const WordInfo& b) bool KeyWordExt::_wordInfoCompare(const KeyWordInfo& a, const KeyWordInfo& b)
{ {
return a.weight > b.weight; return a.weight > b.weight;
} }
bool KeyWordExt::_sortWLIDF(vector<WordInfo>& wordInfos) bool KeyWordExt::_sortWLIDF(vector<KeyWordInfo>& wordInfos)
{ {
for(uint i = 0; i < wordInfos.size(); i++) for(uint i = 0; i < wordInfos.size(); i++)
{ {
WordInfo& wInfo = wordInfos[i]; KeyWordInfo& wInfo = wordInfos[i];
double logWordFreq = _segment.getWordWeight(wInfo.word); double logWordFreq = 1.0;//_segment.getWordWeight(wInfo.word);
wInfo.idf = -logWordFreq; wInfo.idf = -logWordFreq;
size_t wLen = TransCode::getWordLength(wInfo.word); size_t wLen = TransCode::getWordLength(wInfo.word);
if(0 == wLen) if(0 == wLen)
@ -108,10 +108,10 @@ namespace CppJieba
bool KeyWordExt::_extractTopN(const vector<string>& words, vector<string>& keywords, uint topN) bool KeyWordExt::_extractTopN(const vector<string>& words, vector<string>& keywords, uint topN)
{ {
keywords.clear(); keywords.clear();
vector<WordInfo> wordInfos; vector<KeyWordInfo> wordInfos;
for(uint i = 0; i < words.size(); i++) for(uint i = 0; i < words.size(); i++)
{ {
WordInfo wInfo; KeyWordInfo wInfo;
wInfo.word = words[i]; wInfo.word = words[i];
wordInfos.push_back(wInfo); wordInfos.push_back(wInfo);
} }
@ -358,16 +358,16 @@ namespace CppJieba
return false; return false;
} }
bool KeyWordExt::_prioritizeSubWords(vector<WordInfo>& wordInfos) bool KeyWordExt::_prioritizeSubWords(vector<KeyWordInfo>& wordInfos)
{ {
if(2 > wordInfos.size()) if(2 > wordInfos.size())
{ {
return true; return true;
} }
WordInfo prior; KeyWordInfo prior;
bool flag = false; bool flag = false;
for(vector<WordInfo>::iterator it = wordInfos.begin(); it != wordInfos.end(); ) for(vector<KeyWordInfo>::iterator it = wordInfos.begin(); it != wordInfos.end(); )
{ {
if(_isContainSubWords(it->word)) if(_isContainSubWords(it->word))
{ {

View File

@ -36,12 +36,12 @@ namespace CppJieba
bool extract(const string& title, vector<string>& keywords, uint topN); bool extract(const string& title, vector<string>& keywords, uint topN);
bool extract(const vector<string>& words, vector<string>& keywords, uint topN); bool extract(const vector<string>& words, vector<string>& keywords, uint topN);
private: private:
static bool _wordInfoCompare(const WordInfo& a, const WordInfo& b); static bool _wordInfoCompare(const KeyWordInfo& a, const KeyWordInfo& b);
private: private:
bool _extractTopN(const vector<string>& words, vector<string>& keywords, uint topN); bool _extractTopN(const vector<string>& words, vector<string>& keywords, uint topN);
private: private:
//sort by word len - idf //sort by word len - idf
bool _sortWLIDF(vector<WordInfo>& wordInfos); bool _sortWLIDF(vector<KeyWordInfo>& wordInfos);
private: private:
bool _filter(vector<string>& strs); bool _filter(vector<string>& strs);
bool _filterDuplicate(vector<string>& strs); bool _filterDuplicate(vector<string>& strs);
@ -49,7 +49,7 @@ namespace CppJieba
bool _filterSubstr(vector<string>& strs); bool _filterSubstr(vector<string>& strs);
bool _filterStopWords(vector<string>& strs); bool _filterStopWords(vector<string>& strs);
private: private:
bool _prioritizeSubWords(vector<WordInfo>& wordInfos); bool _prioritizeSubWords(vector<KeyWordInfo>& wordInfos);
bool _isContainSubWords(const string& word); bool _isContainSubWords(const string& word);
}; };

View File

@ -16,8 +16,7 @@ namespace CppJieba
bool Segment::init() bool Segment::init()
{ {
bool retFlag = _trie.init(); if(!_trie.init())
if(!retFlag)
{ {
LogError("_trie.init failed."); LogError("_trie.init failed.");
return false; return false;
@ -39,179 +38,166 @@ namespace CppJieba
return _trie.dispose(); return _trie.dispose();
} }
double Segment::getWordWeight(const string& str) bool Segment::cutDAG(const string& str, vector<string>& res)
{ {
return _trie.getWeight(str); vector<TrieNodeInfo> segWordInfos;
if(!cutDAG(str, segWordInfos))
{
return false;
}
res.clear();
for(uint i = 0; i < segWordInfos.size(); i++)
{
res.push_back(segWordInfos[i].word);
}
return true;
} }
bool Segment::cutDAG(const string& str, vector<string>& res) bool Segment::cutDAG(const string& str, vector<TrieNodeInfo>& segWordInfos)
{ {
if(str.empty()) if(str.empty())
{ {
return false; return false;
} }
res.clear(); segWordInfos.clear();
SegmentContext segContext;
bool retFlag;
VUINT16 unicode; if(!TransCode::strToVec(str, segContext.uintVec))
retFlag = TransCode::strToVec(str, unicode);
if(!retFlag)
{ {
LogError("TransCode::strToVec failed."); LogError("TransCode::strToVec failed.");
return false; return false;
} }
//calc DAG //calc DAG
vector<vector<uint> > dag; if(!_calcDAG(segContext))
retFlag = _calcDAG(unicode, dag);
if(!retFlag)
{ {
LogError("_calcDAG failed."); LogError("_calcDAG failed.");
return false; return false;
} }
#ifdef DEBUG if(!_calcDP(segContext))
{
string tmp("{");
FOR_VECTOR(dag, i)
{
tmp += "[";
FOR_VECTOR(dag[i], j)
{
tmp += string_format("%d,", dag[i][j]);
}
tmp += "],";
}
tmp += "}";
LogDebug(tmp);
}
#endif
vector<pair<int, double> > dp;
retFlag = _calcDP(unicode, dag, dp);
if(!retFlag)
{ {
LogError("_calcDP failed."); LogError("_calcDP failed.");
return false; return false;
} }
if(!_cutDAG(segContext, segWordInfos))
retFlag = _cutDAG(unicode, dp, res);
if(!retFlag)
{ {
LogError("_cutDAG failed."); LogError("_cutDAG failed.");
return false; return false;
} }
return true; return true;
} }
bool Segment::_calcDAG(const VUINT16& unicode, vector<vector<uint> >& dag) bool Segment::_calcDAG(SegmentContext& segContext)
{ {
if(unicode.empty()) if(segContext.uintVec.empty())
{ {
return false; return false;
} }
VUINT16_CONST_ITER beginIter = unicode.begin(); vector<pair<uint, const TrieNodeInfo*> > vec;
for(VUINT16_CONST_ITER iterI = unicode.begin(); iterI != unicode.end(); iterI++) VUINT16_CONST_ITER beginIter = segContext.uintVec.begin();
for(VUINT16_CONST_ITER iterI = segContext.uintVec.begin(); iterI != segContext.uintVec.end(); iterI++)
{ {
vector<uint> vec; vec.clear();
vec.push_back(iterI - beginIter); vec.push_back(pair<uint, const TrieNodeInfo*>(iterI - beginIter, NULL));
for(VUINT16_CONST_ITER iterJ = iterI + 1; iterJ != unicode.end(); iterJ++) for(VUINT16_CONST_ITER iterJ = iterI + 1; iterJ != segContext.uintVec.end(); iterJ++)
{ {
//care: the iterJ exceed iterEnd //care: the iterJ exceed iterEnd
if(NULL != _trie.find(iterI, iterJ + 1)) const TrieNodeInfo* ptNodeInfo = _trie.find(iterI, iterJ + 1);
if(NULL != ptNodeInfo)
{ {
vec.push_back(iterJ - beginIter); vec.push_back(pair<uint, const TrieNodeInfo*>(iterJ - beginIter, ptNodeInfo));
} }
} }
dag.push_back(vec); segContext.dag.push_back(vec);
} }
return true; return true;
} }
bool Segment::_calcDP(const VUINT16& unicode, const vector<vector<uint> >& dag, vector<pair<int, double> >& res) bool Segment::_calcDP(SegmentContext& segContext)
{ {
if(unicode.empty()) if(segContext.uintVec.empty())
{ {
LogError("unicode illegal"); LogError("uintVec illegal");
return false; return false;
} }
if(unicode.size() != dag.size()) if(segContext.uintVec.size() != segContext.dag.size())
{ {
LogError("dag is illegal!"); LogError("dag is illegal!");
return false; return false;
} }
res.clear(); segContext.dp.assign(segContext.uintVec.size() + 1, pair<const TrieNodeInfo*, double>(NULL, 0.0));
res.assign(unicode.size() + 1, pair<int, double>(-1, 0.0)); segContext.dp[segContext.uintVec.size()].first = NULL;
res[unicode.size()].first = -1; segContext.dp[segContext.uintVec.size()].second = 0.0;
res[unicode.size()].second = 0.0;
VUINT16_CONST_ITER iterBegin = unicode.begin(); for(int i = segContext.uintVec.size() - 1; i >= 0; i--)
for(int i = unicode.size() - 1; i >= 0; i--)
{ {
// calc max // calc max
res[i].first = -1; segContext.dp[i].first = NULL;
res[i].second = -(numeric_limits<double>::max()); segContext.dp[i].second = -(numeric_limits<double>::max());
for(uint j = 0; j < dag[i].size(); j++) for(uint j = 0; j < segContext.dag[i].size(); j++)
{ {
//cout<<(i/2)<<","<<dag[i/2].size()<<","<<j<<endl; const pair<uint , const TrieNodeInfo*>& p = segContext.dag[i][j];
int pos = dag[i][j]; int pos = p.first;
double val = _trie.getWeight(iterBegin + i, iterBegin + pos + 1) + res[pos + 1].second; double val = segContext.dp[pos+1].second;
//cout<<i<<","<<pos<<","<<val<<endl; if(NULL != p.second)
//double val = _trie.getWeight(uniStr.substr(i, pos * 2 - i + 2)) + res[pos + 1].second;
//cout<<pos<<","<<pos * 2 - i + 2<<","<<val<<endl;
if(val > res[i].second)
{ {
res[i].first = pos; val += (p.second)->logFreq;
res[i].second = val; }
else
{
val += _trie.getMinLogFreq();
}
if(val > segContext.dp[i].second)
{
segContext.dp[i].first = p.second;
segContext.dp[i].second = val;
} }
} }
} }
//FOR_VECTOR(res, i) segContext.dp.pop_back();
//{
// cout<<i<<","<<res[i].first<<","<<res[i].second<<endl;
//}
res.pop_back();
return true; return true;
} }
bool Segment::_cutDAG(const VUINT16& unicode, const vector<pair<int, double> >& dp, vector<string>& res)
bool Segment::_cutDAG(SegmentContext& segContext, vector<TrieNodeInfo>& res)
{ {
if(dp.size() != unicode.size()) if(segContext.dp.empty() || segContext.uintVec.empty() || segContext.dp.size() != segContext.uintVec.size())
{ {
LogError("dp or unicode illegal!"); LogError("dp or uintVec illegal!");
return false; return false;
} }
res.clear(); res.clear();
uint begin = 0, end = 0; VUINT16_CONST_ITER iterBegin = segContext.uintVec.begin();
VUINT16_CONST_ITER iterBegin = unicode.begin(); uint i = 0;
//for(uint i = 0; i < dp.size(); i++) while(i < segContext.dp.size())
while(begin < dp.size() && end <= dp.size())
{ {
//cout<<begin<<"," const TrieNodeInfo* p = segContext.dp[i].first;
// <<dp[i].first<<"," if(NULL == p)
// <<dp[i].second<<endl;
end = dp[begin].first + 1;
//cout<<begin<<","<<end<<endl;
//if(end <= begin)
//{
// continue;
// }
//cout<<begin<<","<<end<<endl;
//string tmp = TransCode::vecToStr(uniStr.substr(begin, end - begin));
string tmp = TransCode::vecToStr(iterBegin + begin, iterBegin + end);
if(tmp.empty())
{ {
LogError("TransCode::vecToStr failed."); TrieNodeInfo nodeInfo;
return false; nodeInfo.word = TransCode::vecToStr(iterBegin + i, iterBegin + i +1);
nodeInfo.wLen = 1;
nodeInfo.freq = 0;
nodeInfo.logFreq = _trie.getMinLogFreq();
res.push_back(nodeInfo);
i ++;
}
else
{
res.push_back(*p);
if(0 == p->wLen)
{
LogFatal("TrieNodeInfo's wLen is 0!");
return false;
}
i += p->wLen;
} }
res.push_back(tmp);
begin = end;
} }
return true; return true;
} }

View File

@ -23,14 +23,16 @@ namespace CppJieba
bool init(); bool init();
bool loadSegDict(const char * const filePath); bool loadSegDict(const char * const filePath);
bool dispose(); bool dispose();
double getWordWeight(const string& str);
public: public:
bool cutDAG(const string& chStr, vector<string>& res); bool cutDAG(const string& str, vector<TrieNodeInfo>& segWordInfos);
bool cutDAG(const string& str, vector<string>& res);
private: private:
bool _calcDAG(const VUINT16& unicode, vector<vector<uint> >& dag); bool _calcDAG(SegmentContext& segContext);
bool _calcDP(const VUINT16& unicode, const vector<vector<uint> >& dag, vector<pair<int, double> >& res); bool _calcDP(SegmentContext& segContext);
bool _cutDAG(const VUINT16& unicode, const vector<pair<int, double> >& dp, vector<string>& res); bool _cutDAG(SegmentContext& segContext, vector<TrieNodeInfo>& res);
//bool _fill(const string& )
}; };
} }

View File

@ -44,11 +44,6 @@ namespace CppJieba
return true; return true;
} }
bool TransCode::a(const string& str, vector<uint16_t>& vec)
{
return true;
}
bool TransCode::strToVec(const string& str, vector<uint16_t>& vec) bool TransCode::strToVec(const string& str, vector<uint16_t>& vec)
{ {
if(NULL == _pf_strToVec) if(NULL == _pf_strToVec)

View File

@ -36,7 +36,6 @@ namespace CppJieba
public: public:
static bool init(); static bool init();
public: public:
static bool a(const string& str, vector<uint16_t>& vec);
static bool strToVec(const string& str, vector<uint16_t>& vec); static bool strToVec(const string& str, vector<uint16_t>& vec);
static string vecToStr(VUINT16_CONST_ITER begin, VUINT16_CONST_ITER end); static string vecToStr(VUINT16_CONST_ITER begin, VUINT16_CONST_ITER end);
static size_t getWordLength(const string& str); static size_t getWordLength(const string& str);

View File

@ -152,9 +152,9 @@ namespace CppJieba
LogFatal("trie not initted!"); LogFatal("trie not initted!");
return NULL; return NULL;
} }
VUINT16 unicode; VUINT16 uintVec;
bool retFlag = TransCode::strToVec(str, unicode); bool retFlag = TransCode::strToVec(str, uintVec);
if(retFlag) if(retFlag)
{ {
LogError("TransCode::strToVec failed."); LogError("TransCode::strToVec failed.");
@ -164,9 +164,9 @@ namespace CppJieba
//find //find
TrieNode* p = _root; TrieNode* p = _root;
TrieNodeInfo * res = NULL; TrieNodeInfo * res = NULL;
for(uint i = 0; i < unicode.size(); i++) for(uint i = 0; i < uintVec.size(); i++)
{ {
uint16_t chUni = unicode[i]; uint16_t chUni = uintVec[i];
if(p->isLeaf) if(p->isLeaf)
{ {
uint pos = p->nodeInfoVecPos; uint pos = p->nodeInfoVecPos;
@ -195,22 +195,22 @@ namespace CppJieba
const TrieNodeInfo* Trie::find(const string& str) const TrieNodeInfo* Trie::find(const string& str)
{ {
VUINT16 unicode; VUINT16 uintVec;
bool retFlag = TransCode::strToVec(str, unicode); bool retFlag = TransCode::strToVec(str, uintVec);
if(!retFlag) if(!retFlag)
{ {
return NULL; return NULL;
} }
return find(unicode); return find(uintVec);
} }
const TrieNodeInfo* Trie::find(const VUINT16& unicode) const TrieNodeInfo* Trie::find(const VUINT16& uintVec)
{ {
if(unicode.empty()) if(uintVec.empty())
{ {
return NULL; return NULL;
} }
return find(unicode.begin(), unicode.end()); return find(uintVec.begin(), uintVec.end());
} }
const TrieNodeInfo* Trie::find(VUINT16_CONST_ITER begin, VUINT16_CONST_ITER end) const TrieNodeInfo* Trie::find(VUINT16_CONST_ITER begin, VUINT16_CONST_ITER end)
@ -257,25 +257,25 @@ namespace CppJieba
double Trie::getWeight(const string& str) double Trie::getWeight(const string& str)
{ {
VUINT16 unicode; VUINT16 uintVec;
TransCode::strToVec(str, unicode); TransCode::strToVec(str, uintVec);
return getWeight(unicode); return getWeight(uintVec);
} }
double Trie::getWeight(const VUINT16& unicode) double Trie::getWeight(const VUINT16& uintVec)
{ {
if(unicode.empty()) if(uintVec.empty())
{ {
return getMinWeight(); return getMinLogFreq();
} }
const TrieNodeInfo * p = find(unicode); const TrieNodeInfo * p = find(uintVec);
if(NULL != p) if(NULL != p)
{ {
return p->logFreq; return p->logFreq;
} }
else else
{ {
return getMinWeight(); return getMinLogFreq();
} }
} }
@ -289,11 +289,11 @@ namespace CppJieba
} }
else else
{ {
return getMinWeight(); return getMinLogFreq();
} }
} }
double Trie::getMinWeight() double Trie::getMinLogFreq()
{ {
return _minLogFreq; return _minLogFreq;
} }
@ -326,8 +326,8 @@ namespace CppJieba
const string& word = nodeInfo.word; const string& word = nodeInfo.word;
VUINT16 unicode; VUINT16 uintVec;
bool retFlag = TransCode::strToVec(word, unicode); bool retFlag = TransCode::strToVec(word, uintVec);
if(!retFlag) if(!retFlag)
{ {
LogError("TransCode::strToVec error."); LogError("TransCode::strToVec error.");
@ -335,9 +335,9 @@ namespace CppJieba
} }
TrieNode* p = _root; TrieNode* p = _root;
for(uint i = 0; i < unicode.size(); i++) for(uint i = 0; i < uintVec.size(); i++)
{ {
uint16_t cu = unicode[i]; uint16_t cu = uintVec[i];
if(NULL == p) if(NULL == p)
{ {
return false; return false;
@ -426,7 +426,7 @@ int main()
trie.init(); trie.init();
trie.loadDict("../dicts/segdict.gbk.v2.1"); trie.loadDict("../dicts/segdict.gbk.v2.1");
//trie.loadDict("tmp"); //trie.loadDict("tmp");
cout<<trie.getMinWeight()<<endl; cout<<trie.getMinLogFreq()<<endl;
cout<<trie.getTotalCount()<<endl; cout<<trie.getTotalCount()<<endl;
trie.dispose(); trie.dispose();
return 0; return 0;

View File

@ -67,15 +67,15 @@ namespace CppJieba
public: public:
const TrieNodeInfo* find(const string& str); const TrieNodeInfo* find(const string& str);
const TrieNodeInfo* find(const VUINT16& unicode); const TrieNodeInfo* find(const VUINT16& uintVec);
const TrieNodeInfo* find(VUINT16_CONST_ITER begin, VUINT16_CONST_ITER end); const TrieNodeInfo* find(VUINT16_CONST_ITER begin, VUINT16_CONST_ITER end);
const TrieNodeInfo* findPrefix(const string& str); const TrieNodeInfo* findPrefix(const string& str);
public: public:
double getWeight(const string& str); double getWeight(const string& str);
double getWeight(const VUINT16& unicode); double getWeight(const VUINT16& uintVec);
double getWeight(VUINT16_CONST_ITER begin, VUINT16_CONST_ITER end); double getWeight(VUINT16_CONST_ITER begin, VUINT16_CONST_ITER end);
double getMinWeight(); double getMinLogFreq();
int64_t getTotalCount(); int64_t getTotalCount();

View File

@ -12,7 +12,7 @@ namespace CppJieba
size_t wLen;// the word's len , not string.length(), size_t wLen;// the word's len , not string.length(),
size_t freq; size_t freq;
string tag; string tag;
double logFreq;//log(freq/sum(freq)); double logFreq; //logFreq = log(freq/sum(freq));
TrieNodeInfo() TrieNodeInfo()
{ {
wLen = 0; wLen = 0;
@ -20,12 +20,29 @@ namespace CppJieba
logFreq = 0.0; logFreq = 0.0;
} }
}; };
struct SegmentContext//: public TrieNodeInfo
{
vector<uint16_t> uintVec;
vector< vector<pair<uint, const TrieNodeInfo*> > > dag;
vector< pair<const TrieNodeInfo*, double> > dp;
//vector<string> words;
};
/*
struct SegmentWordInfo: public TrieNodeInfo
{
};
*/
struct WordInfo: public TrieNodeInfo struct KeyWordInfo: public TrieNodeInfo
{ {
double idf; double idf;
double weight;// log(wLen+1)*logFreq; double weight;// log(wLen+1)*logFreq;
WordInfo() KeyWordInfo()
{ {
idf = 0.0; idf = 0.0;
weight = 0.0; weight = 0.0;
@ -36,7 +53,7 @@ namespace CppJieba
} }
}; };
inline string joinWordInfos(const vector<WordInfo>& vec) inline string joinWordInfos(const vector<KeyWordInfo>& vec)
{ {
vector<string> tmp; vector<string> tmp;
for(uint i = 0; i < vec.size(); i++) for(uint i = 0; i < vec.size(); i++)