prettify Trie.hpp ing

This commit is contained in:
wyy 2014-03-16 20:20:37 +08:00
parent 582d61e3e8
commit fe7e3ff807
9 changed files with 76 additions and 150 deletions

View File

@ -48,7 +48,7 @@ namespace CppJieba
}
//result of searching in trie tree
vector<pair<uint, const TrieNodeInfo*> > tRes;
vector<pair<size_t, const TrieNodeInfo*> > tRes;
//max index of res's words
int maxIdx = 0;
@ -63,7 +63,7 @@ namespace CppJieba
//find word start from uItr
if (_trie.find(uItr, end, tRes))
{
for (vector<pair<uint, const TrieNodeInfo*> >::const_iterator itr = tRes.begin(); itr != tRes.end(); itr++)
for (vector<pair<size_t, const TrieNodeInfo*> >::const_iterator itr = tRes.begin(); itr != tRes.end(); itr++)
{
wordLen = itr->second->word.size();
if (wordLen >= 2 || (tRes.size() == 1 && maxIdx <= uIdx))

View File

@ -76,7 +76,7 @@ namespace CppJieba
LogError("not inited.");
return false;
}
vector<uint> status;
vector<size_t> status;
if(!_viterbi(begin, end, status))
{
LogError("_viterbi failed.");
@ -85,7 +85,7 @@ namespace CppJieba
Unicode::const_iterator left = begin;
Unicode::const_iterator right;
for(uint i =0; i< status.size(); i++)
for(size_t i =0; i< status.size(); i++)
{
if(status[i] % 2) //if(E == status[i] || S == status[i])
{
@ -110,7 +110,7 @@ namespace CppJieba
return false;
}
string tmp;
for(uint i = 0; i < words.size(); i++)
for(size_t i = 0; i < words.size(); i++)
{
if(TransCode::encode(words[i], tmp))
{
@ -121,7 +121,7 @@ namespace CppJieba
}
private:
bool _viterbi(Unicode::const_iterator begin, Unicode::const_iterator end, vector<uint>& status)const
bool _viterbi(Unicode::const_iterator begin, Unicode::const_iterator end, vector<size_t>& status)const
{
if(begin == end)
{
@ -133,7 +133,7 @@ namespace CppJieba
size_t XYSize = X * Y;
int * path;
double * weight;
uint now, old, stat;
size_t now, old, stat;
double tmp, endE, endS;
try
@ -153,21 +153,21 @@ namespace CppJieba
}
//start
for(uint y = 0; y < Y; y++)
for(size_t y = 0; y < Y; y++)
{
weight[0 + y * X] = _startProb[y] + _getEmitProb(_emitProbVec[y], *begin, MIN_DOUBLE);
path[0 + y * X] = -1;
}
//process
//for(; begin != end; begin++)
for(uint x = 1; x < X; x++)
for(size_t x = 1; x < X; x++)
{
for(uint y = 0; y < Y; y++)
for(size_t y = 0; y < Y; y++)
{
now = x + y*X;
weight[now] = MIN_DOUBLE;
path[now] = E; // warning
for(uint preY = 0; preY < Y; preY++)
for(size_t preY = 0; preY < Y; preY++)
{
old = x - 1 + preY * X;
tmp = weight[old] + _transProb[preY][y] + _getEmitProb(_emitProbVec[y], *(begin+x), MIN_DOUBLE);
@ -221,14 +221,14 @@ namespace CppJieba
LogError("start_p illegal");
return false;
}
for(uint j = 0; j< tmp.size(); j++)
for(size_t j = 0; j< tmp.size(); j++)
{
_startProb[j] = atof(tmp[j].c_str());
//cout<<_startProb[j]<<endl;
}
//load _transProb
for(uint i = 0; i < STATUS_SUM; i++)
for(size_t i = 0; i < STATUS_SUM; i++)
{
if(!_getLine(ifile, line))
{
@ -240,7 +240,7 @@ namespace CppJieba
LogError("trans_p illegal");
return false;
}
for(uint j =0; j < STATUS_SUM; j++)
for(size_t j =0; j < STATUS_SUM; j++)
{
_transProb[i][j] = atof(tmp[j].c_str());
//cout<<_transProb[i][j]<<endl;
@ -301,7 +301,7 @@ namespace CppJieba
vector<string> tmp, tmp2;
uint16_t unico = 0;
split(line, tmp, ",");
for(uint i = 0; i < tmp.size(); i++)
for(size_t i = 0; i < tmp.size(); i++)
{
split(tmp[i], tmp2, ":");
if(2 != tmp2.size())

View File

@ -37,7 +37,7 @@ namespace CppJieba
};
public:
bool extract(const string& str, vector<string>& keywords, uint topN) const
bool extract(const string& str, vector<string>& keywords, size_t topN) const
{
assert(_getInitFlag());
vector<pair<string, double> > topWords;
@ -45,14 +45,14 @@ namespace CppJieba
{
return false;
}
for(uint i = 0; i < topWords.size(); i++)
for(size_t i = 0; i < topWords.size(); i++)
{
keywords.push_back(topWords[i].first);
}
return true;
}
bool extract(const string& str, vector<pair<string, double> >& keywords, uint topN) const
bool extract(const string& str, vector<pair<string, double> >& keywords, size_t topN) const
{
vector<string> words;
if(!_segment.cut(str, words))
@ -75,7 +75,7 @@ namespace CppJieba
}
map<string, double> wordmap;
for(uint i = 0; i < words.size(); i ++)
for(size_t i = 0; i < words.size(); i ++)
{
wordmap[ words[i] ] += 1.0;
}

View File

@ -66,7 +66,7 @@ namespace CppJieba
return false;
}
string tmp;
for(uint i = 0; i < segWordInfos.size(); i++)
for(size_t i = 0; i < segWordInfos.size(); i++)
{
if(TransCode::encode(segWordInfos[i].word, tmp))
{
@ -123,7 +123,7 @@ namespace CppJieba
for(Unicode::const_iterator it = begin; it != end; it++)
{
SegmentChar schar(*it);
uint i = it - begin;
size_t i = it - begin;
_trie.find(it, end, i, schar.dag);
//DagType::iterator dagIter;
if(schar.dag.end() == schar.dag.find(i))
@ -148,7 +148,7 @@ namespace CppJieba
segContext[i].weight = MIN_DOUBLE;
for(DagType::const_iterator it = segContext[i].dag.begin(); it != segContext[i].dag.end(); it++)
{
uint nextPos = it->first;
size_t nextPos = it->first;
const TrieNodeInfo* p = it->second;
double val = 0.0;
if(nextPos + 1 < segContext.size())
@ -176,7 +176,7 @@ namespace CppJieba
}
bool _cut(SegmentContext& segContext, vector<TrieNodeInfo>& res)const
{
uint i = 0;
size_t i = 0;
while(i < segContext.size())
{
const TrieNodeInfo* p = segContext[i].pInfo;

View File

@ -59,7 +59,7 @@ namespace CppJieba
vector<Unicode> hmmRes;
Unicode piece;
for (uint i = 0, j = 0; i < infos.size(); i++)
for (size_t i = 0, j = 0; i < infos.size(); i++)
{
//if mp gets a word, it's ok; put it into the result
if (1 != infos[i].word.size())
@ -84,7 +84,7 @@ namespace CppJieba
}
//put hmm result to return
for (uint k = 0; k < hmmRes.size(); k++)
for (size_t k = 0; k < hmmRes.size(); k++)
{
res.push_back(hmmRes[k]);
}

View File

@ -37,15 +37,15 @@ namespace CppJieba
return cut(unico.begin(), unico.end(), res);
#else
const char * const cstr = str.c_str();
uint size = str.size();
uint offset = 0;
size_t size = str.size();
size_t offset = 0;
string subs;
int ret;
uint len;
size_t len;
while(offset < size)
{
const char * const nstr = cstr + offset;
uint nsize = size - offset;
size_t nsize = size - offset;
if(-1 == (ret = filterAscii(nstr, nsize, len)) || 0 == len || len > nsize)
{
LogFatal("str[%s] illegal.", cstr);
@ -78,7 +78,7 @@ namespace CppJieba
* else count the nonascii string's length and return 1;
* if errors, return -1;
* */
static int filterAscii(const char* str, uint len, uint& resLen)
static int filterAscii(const char* str, size_t len, size_t& resLen)
{
if(!str || !len)
{

View File

@ -24,12 +24,13 @@ namespace CppJieba
using namespace Limonp;
const double MIN_DOUBLE = -3.14e+100;
const double MAX_DOUBLE = 3.14e+100;
const size_t DICT_COLUMN_NUM = 3;
typedef unordered_map<uint16_t, struct TrieNode*> TrieNodeMap;
struct TrieNode
{
TrieNodeMap hmap;
bool isLeaf;
uint nodeInfoVecPos;
size_t nodeInfoVecPos;
TrieNode()
{
isLeaf = false;
@ -44,18 +45,11 @@ namespace CppJieba
string tag;
double logFreq; //logFreq = log(freq/sum(freq));
TrieNodeInfo():freq(0),logFreq(0.0)
{
}
{}
TrieNodeInfo(const TrieNodeInfo& nodeInfo):word(nodeInfo.word), freq(nodeInfo.freq), tag(nodeInfo.tag), logFreq(nodeInfo.logFreq)
{
}
{}
TrieNodeInfo(const Unicode& _word):word(_word),freq(0),logFreq(MIN_DOUBLE)
{
}
bool operator == (const TrieNodeInfo & rhs) const
{
return word == rhs.word && freq == rhs.freq && tag == rhs.tag && abs(logFreq - rhs.logFreq) < 0.001;
}
{}
};
inline ostream& operator << (ostream& os, const TrieNodeInfo & nodeInfo)
@ -63,7 +57,7 @@ namespace CppJieba
return os << nodeInfo.word << ":" << nodeInfo.freq << ":" << nodeInfo.tag << ":" << nodeInfo.logFreq ;
}
typedef map<uint, const TrieNodeInfo*> DagType;
typedef map<size_t, const TrieNodeInfo*> DagType;
class Trie: public InitOnOff
{
@ -89,10 +83,6 @@ namespace CppJieba
}
~Trie()
{
if(!_getInitFlag())
{
return;
}
_deleteNode(_root);
}
public:
@ -102,7 +92,7 @@ namespace CppJieba
_root = new TrieNode;
assert(_root);
if(!_trieInsert(filePath.c_str()))
if(!_trieInsert(filePath))
{
LogError("_trieInsert failed.");
return false;
@ -118,16 +108,6 @@ namespace CppJieba
public:
const TrieNodeInfo* find(Unicode::const_iterator begin, Unicode::const_iterator end)const
{
if(!_getInitFlag())
{
LogFatal("trie not initted!");
return NULL;
}
if(begin >= end)
{
return NULL;
}
TrieNode* p = _root;
for(Unicode::const_iterator it = begin; it != end; it++)
{
@ -143,7 +123,7 @@ namespace CppJieba
}
if(p->isLeaf)
{
uint pos = p->nodeInfoVecPos;
size_t pos = p->nodeInfoVecPos;
if(pos < _nodeInfoVec.size())
{
return &(_nodeInfoVec[pos]);
@ -157,18 +137,8 @@ namespace CppJieba
return NULL;
}
bool find(Unicode::const_iterator begin, Unicode::const_iterator end, vector<pair<uint, const TrieNodeInfo*> >& res) const
bool find(Unicode::const_iterator begin, Unicode::const_iterator end, vector<pair<size_t, const TrieNodeInfo*> >& res) const
{
if(!_getInitFlag())
{
LogFatal("trie not initted!");
return false;
}
if (begin >= end)
{
LogFatal("begin >= end");
return false;
}
TrieNode* p = _root;
for (Unicode::const_iterator itr = begin; itr != end; itr++)
{
@ -179,7 +149,7 @@ namespace CppJieba
p = p->hmap[*itr];
if(p->isLeaf)
{
uint pos = p->nodeInfoVecPos;
size_t pos = p->nodeInfoVecPos;
if(pos < _nodeInfoVec.size())
{
res.push_back(make_pair(itr-begin, &_nodeInfoVec[pos]));
@ -194,18 +164,8 @@ namespace CppJieba
return !res.empty();
}
bool find(Unicode::const_iterator begin, Unicode::const_iterator end, uint offset, DagType & res) const
bool find(Unicode::const_iterator begin, Unicode::const_iterator end, size_t offset, DagType & res) const
{
if(!_getInitFlag())
{
LogFatal("trie not initted!");
return false;
}
if (begin >= end)
{
LogFatal("begin >= end");
return false;
}
TrieNode* p = _root;
for (Unicode::const_iterator itr = begin; itr != end; itr++)
{
@ -216,10 +176,9 @@ namespace CppJieba
p = p->hmap[*itr];
if(p->isLeaf)
{
uint pos = p->nodeInfoVecPos;
size_t pos = p->nodeInfoVecPos;
if(pos < _nodeInfoVec.size())
{
//res.push_back(make_pair(itr-begin, &_nodeInfoVec[pos]));
res[itr-begin + offset] = &_nodeInfoVec[pos];
}
else
@ -233,32 +192,22 @@ namespace CppJieba
}
public:
double getMinLogFreq()const{return _minLogFreq;};
double getMinLogFreq() const {return _minLogFreq;};
private:
bool _insert(const TrieNodeInfo& nodeInfo)
void _insert(const TrieNodeInfo& nodeInfo)
{
const Unicode& uintVec = nodeInfo.word;
TrieNode* p = _root;
for(uint i = 0; i < uintVec.size(); i++)
for(size_t i = 0; i < uintVec.size(); i++)
{
uint16_t cu = uintVec[i];
if(NULL == p)
{
return false;
}
assert(p);
if(p->hmap.end() == p->hmap.find(cu))
{
TrieNode * next = NULL;
try
{
next = new TrieNode;
}
catch(const bad_alloc& e)
{
return false;
}
TrieNode * next = new TrieNode;
assert(next);
p->hmap[cu] = next;
p = next;
}
@ -267,62 +216,41 @@ namespace CppJieba
p = p->hmap[cu];
}
}
if(NULL == p)
{
return false;
}
if(p->isLeaf)
{
LogError("this node already _inserted");
return false;
}
assert(p);
assert(!p->isLeaf);
p->isLeaf = true;
_nodeInfoVec.push_back(nodeInfo);
p->nodeInfoVecPos = _nodeInfoVec.size() - 1;
return true;
}
private:
bool _trieInsert(const char * const filePath)
bool _trieInsert(const string& filePath)
{
ifstream ifs(filePath);
ifstream ifs(filePath.c_str());
if(!ifs)
{
LogError("open %s failed.", filePath);
LogError("open %s failed.", filePath.c_str());
return false;
}
string line;
vector<string> vecBuf;
TrieNodeInfo nodeInfo;
size_t lineno = 0;
while(getline(ifs, line))
for(size_t lineno = 0 ; getline(ifs, line); lineno++)
{
vecBuf.clear();
lineno ++;
split(line, vecBuf, " ");
if(3 < vecBuf.size())
{
LogError("line[%u:%s] illegal.", lineno, line.c_str());
return false;
}
assert(vecBuf.size() == DICT_COLUMN_NUM);
if(!TransCode::decode(vecBuf[0], nodeInfo.word))
{
LogError("line[%u:%s] illegal.", lineno, line.c_str());
return false;
}
nodeInfo.freq = atoi(vecBuf[1].c_str());
if(3 == vecBuf.size())
{
nodeInfo.tag = vecBuf[2];
}
nodeInfo.tag = vecBuf[2];
if(!_insert(nodeInfo))
{
assert(false);
}
_insert(nodeInfo);
}
return true;
}
@ -340,21 +268,13 @@ namespace CppJieba
_freqSum += _nodeInfoVec[i].freq;
}
if(0 == _freqSum)
{
LogError("_freqSum == 0 .");
return false;
}
assert(_freqSum);
//normalize
for(uint i = 0; i < _nodeInfoVec.size(); i++)
for(size_t i = 0; i < _nodeInfoVec.size(); i++)
{
TrieNodeInfo& nodeInfo = _nodeInfoVec[i];
if(0 == nodeInfo.freq)
{
LogFatal("nodeInfo.freq == 0!");
return false;
}
assert(nodeInfo.freq);
nodeInfo.logFreq = log(double(nodeInfo.freq)/double(_freqSum));
if(_minLogFreq > nodeInfo.logFreq)
{
@ -367,12 +287,15 @@ namespace CppJieba
// Recursively frees the subtree rooted at `node`.
// Children reachable through node->hmap are released first, then the node itself.
// A null `node` is accepted and ignored.
void _deleteNode(TrieNode* node)
{
if(!node)
{
return;
}
// Walk the child map and release each child subtree before its parent.
TrieNodeMap::iterator child = node->hmap.begin();
while(child != node->hmap.end())
{
_deleteNode(child->second);
++child;
}
delete node;
}

View File

@ -12,11 +12,11 @@ TEST(SegmentBaseTest, Test1)
buf.push_back("你好");
buf.push_back("...hh");
vector<string> res;
uint size = strlen(str);
uint offset = 0;
size_t size = strlen(str);
size_t offset = 0;
while(offset < size)
{
uint len = 0;
size_t len = 0;
const char* t = str + offset;
SegmentBase::filterAscii(t, size - offset, len);
s.assign(t, len);

View File

@ -7,6 +7,7 @@ static const char* const DICT_FILE = "../dict/extra_dict/jieba.dict.small.utf8";
TEST(TrieTest, Test1)
{
string s1, s2;
Trie trie;
ASSERT_TRUE(trie.init(DICT_FILE));
ASSERT_LT(trie.getMinLogFreq() + 15.6479, 0.001);
@ -18,14 +19,16 @@ TEST(TrieTest, Test1)
nodeInfo.freq = 8779;
nodeInfo.tag = "v";
nodeInfo.logFreq = -8.87033;
s1 << nodeInfo;
s2 << (*trie.find(uni.begin(), uni.end()));
EXPECT_EQ(nodeInfo, *trie.find(uni.begin(), uni.end()));
EXPECT_EQ("[\"26469\", \"21040\"]:8779:v:-8.87033", s2);
word = "清华大学";
vector<pair<uint, const TrieNodeInfo*> > res;
map<uint, const TrieNodeInfo* > resMap;
map<uint, const TrieNodeInfo* > map;
vector<pair<size_t, const TrieNodeInfo*> > res;
map<size_t, const TrieNodeInfo* > resMap;
map<size_t, const TrieNodeInfo* > map;
const char * words[] = {"", "清华", "清华大学"};
for(uint i = 0; i < sizeof(words)/sizeof(words[0]); i++)
for(size_t i = 0; i < sizeof(words)/sizeof(words[0]); i++)
{
ASSERT_TRUE(TransCode::decode(words[i], uni));
res.push_back(make_pair(uni.size() - 1, trie.find(uni.begin(), uni.end())));
@ -34,7 +37,7 @@ TEST(TrieTest, Test1)
//TrieNodeInfo
//res.push_back(make_pair(0, ))
vector<pair<uint, const TrieNodeInfo*> > vec;
vector<pair<size_t, const TrieNodeInfo*> > vec;
ASSERT_TRUE(TransCode::decode(word, uni));
//print(uni);
ASSERT_TRUE(trie.find(uni.begin(), uni.end(), vec));