support optional user word freq weight

This commit is contained in:
yanyiwu 2015-10-08 20:05:27 +08:00
parent 98345d6aed
commit 4d56be920b
13 changed files with 85 additions and 105 deletions

View File

@ -2,7 +2,8 @@
## next version ## next version
1. 支持多个userdict载入，多词典路径用英文冒号(:)作为分隔符，就当坐是向环境变量PATH致敬，哈哈。 1. 支持多个userdict载入，多词典路径用英文冒号(:)作为分隔符，就当是向环境变量PATH致敬，哈哈。
2. userdict是不带权重的，之前对于新的userword默认设置词频权重为最大值，现已支持可配置，默认使用中位值。
## v3.2.1 ## v3.2.1

View File

@ -63,7 +63,7 @@ class Application {
vector<string>& words, size_t max_word_len) const { vector<string>& words, size_t max_word_len) const {
jieba_.CutSmall(sentence, words, max_word_len); jieba_.CutSmall(sentence, words, max_word_len);
} }
bool insertUserWord(const string& word, const string& tag = UNKNOWN_TAG) { bool InsertUserWord(const string& word, const string& tag = UNKNOWN_TAG) {
return jieba_.InsertUserWord(word, tag); return jieba_.InsertUserWord(word, tag);
} }
void tag(const string& str, vector<pair<string, string> >& res) const { void tag(const string& str, vector<pair<string, string> >& res) const {

View File

@ -30,25 +30,48 @@ class DictTrie {
Max, Max,
}; // enum UserWordWeightOption }; // enum UserWordWeightOption
DictTrie() { DictTrie(const string& dict_path, const string& user_dict_paths = "", UserWordWeightOption user_word_weight_opt = Median) {
trie_ = NULL; Init(dict_path, user_dict_paths, user_word_weight_opt);
min_weight_ = MAX_DOUBLE;
}
DictTrie(const string& dict_path, const string& user_dict_paths = "") {
new (this) DictTrie();
init(dict_path, user_dict_paths);
} }
~DictTrie() { ~DictTrie() {
delete trie_; delete trie_;
} }
void init(const string& dict_path, const string& user_dict_paths = "") { bool InsertUserWord(const string& word, const string& tag = UNKNOWN_TAG) {
if (trie_ != NULL) { DictUnit node_info;
LogFatal("trie already initted"); if (!MakeNodeInfo(node_info, word, max_weight_, tag)) {
return false;
} }
active_node_infos_.push_back(node_info);
trie_->insertNode(node_info.word, &active_node_infos_.back());
return true;
}
const DictUnit* Find(Unicode::const_iterator begin, Unicode::const_iterator end) const {
return trie_->Find(begin, end);
}
void Find(Unicode::const_iterator begin,
Unicode::const_iterator end,
vector<struct Dag>&res,
size_t max_word_len = MAX_WORD_LENGTH) const {
trie_->Find(begin, end, res, max_word_len);
}
bool IsUserDictSingleChineseWord(const Rune& word) const {
return isIn(user_dict_single_chinese_word_, word);
}
double GetMinWeight() const {
return min_weight_;
}
private:
void Init(const string& dict_path, const string& user_dict_paths, UserWordWeightOption user_word_weight_opt) {
LoadDict(dict_path); LoadDict(dict_path);
CalculateWeight(static_node_infos_); CalculateWeight(static_node_infos_);
SetStaticWordWeights(); SetStaticWordWeights(user_word_weight_opt);
if (user_dict_paths.size()) { if (user_dict_paths.size()) {
LoadUserDict(user_dict_paths); LoadUserDict(user_dict_paths);
@ -57,36 +80,6 @@ class DictTrie {
CreateTrie(static_node_infos_); CreateTrie(static_node_infos_);
} }
bool insertUserWord(const string& word, const string& tag = UNKNOWN_TAG) {
DictUnit node_info;
if (!MakeUserNodeInfo(node_info, word, tag)) {
return false;
}
active_node_infos_.push_back(node_info);
trie_->insertNode(node_info.word, &active_node_infos_.back());
return true;
}
const DictUnit* find(Unicode::const_iterator begin, Unicode::const_iterator end) const {
return trie_->find(begin, end);
}
void find(Unicode::const_iterator begin,
Unicode::const_iterator end,
vector<struct Dag>&res,
size_t max_word_len = MAX_WORD_LENGTH) const {
trie_->find(begin, end, res, max_word_len);
}
bool isUserDictSingleChineseWord(const Rune& word) const {
return isIn(user_dict_single_chinese_word_, word);
}
double getMinWeight() const {
return min_weight_;
}
private:
void CreateTrie(const vector<DictUnit>& dictUnits) { void CreateTrie(const vector<DictUnit>& dictUnits) {
assert(dictUnits.size()); assert(dictUnits.size());
vector<Unicode> words; vector<Unicode> words;
@ -98,6 +91,7 @@ class DictTrie {
trie_ = new Trie(words, valuePointers); trie_ = new Trie(words, valuePointers);
} }
void LoadUserDict(const string& filePaths) { void LoadUserDict(const string& filePaths) {
vector<string> files = limonp::split(filePaths, ":"); vector<string> files = limonp::split(filePaths, ":");
size_t lineno = 0; size_t lineno = 0;
@ -116,13 +110,19 @@ class DictTrie {
LogFatal("split [%s] result illegal", line.c_str()); LogFatal("split [%s] result illegal", line.c_str());
} }
DictUnit node_info; DictUnit node_info;
MakeUserNodeInfo(node_info, buf[0], MakeNodeInfo(node_info,
buf[0],
max_weight_,
(buf.size() == 2 ? buf[1] : UNKNOWN_TAG)); (buf.size() == 2 ? buf[1] : UNKNOWN_TAG));
static_node_infos_.push_back(node_info); static_node_infos_.push_back(node_info);
if (node_info.word.size() == 1) {
user_dict_single_chinese_word_.insert(node_info.word[0]);
}
} }
} }
LogInfo("load userdicts[%s] ok. lines[%u]", filePaths.c_str(), lineno); LogInfo("load userdicts[%s] ok. lines[%u]", filePaths.c_str(), lineno);
} }
bool MakeNodeInfo(DictUnit& node_info, bool MakeNodeInfo(DictUnit& node_info,
const string& word, const string& word,
double weight, double weight,
@ -135,20 +135,7 @@ class DictTrie {
node_info.tag = tag; node_info.tag = tag;
return true; return true;
} }
bool MakeUserNodeInfo(DictUnit& node_info,
const string& word,
const string& tag = UNKNOWN_TAG) {
if (!TransCode::decode(word, node_info.word)) {
LogError("decode %s failed.", word.c_str());
return false;
}
if (node_info.word.size() == 1) {
user_dict_single_chinese_word_.insert(node_info.word[0]);
}
node_info.weight = max_weight_;
node_info.tag = tag;
return true;
}
void LoadDict(const string& filePath) { void LoadDict(const string& filePath) {
ifstream ifs(filePath.c_str()); ifstream ifs(filePath.c_str());
if (!ifs.is_open()) { if (!ifs.is_open()) {
@ -175,7 +162,7 @@ class DictTrie {
return lhs.weight < rhs.weight; return lhs.weight < rhs.weight;
} }
void SetStaticWordWeights() { void SetStaticWordWeights(UserWordWeightOption option) {
if (static_node_infos_.empty()) { if (static_node_infos_.empty()) {
LogFatal("something must be wrong"); LogFatal("something must be wrong");
} }
@ -184,6 +171,17 @@ class DictTrie {
min_weight_ = x[0].weight; min_weight_ = x[0].weight;
max_weight_ = x[x.size() - 1].weight; max_weight_ = x[x.size() - 1].weight;
median_weight_ = x[x.size() / 2].weight; median_weight_ = x[x.size() / 2].weight;
switch (option) {
case Min:
user_word_default_weight_ = min_weight_;
break;
case Median:
user_word_default_weight_ = median_weight_;
break;
default:
user_word_default_weight_ = max_weight_;
break;
}
} }
void CalculateWeight(vector<DictUnit>& node_infos) const { void CalculateWeight(vector<DictUnit>& node_infos) const {
@ -210,6 +208,7 @@ class DictTrie {
double min_weight_; double min_weight_;
double max_weight_; double max_weight_;
double median_weight_; double median_weight_;
double user_word_default_weight_;
unordered_set<Rune> user_dict_single_chinese_word_; unordered_set<Rune> user_dict_single_chinese_word_;
}; };
} }

View File

@ -54,7 +54,7 @@ class FullSegment: public SegmentBase {
int wordLen = 0; int wordLen = 0;
assert(dictTrie_); assert(dictTrie_);
vector<struct Dag> dags; vector<struct Dag> dags;
dictTrie_->find(begin, end, dags); dictTrie_->Find(begin, end, dags);
for (size_t i = 0; i < dags.size(); i++) { for (size_t i = 0; i < dags.size(); i++) {
for (size_t j = 0; j < dags[i].nexts.size(); j++) { for (size_t j = 0; j < dags[i].nexts.size(); j++) {
const DictUnit* du = dags[i].nexts[j].second; const DictUnit* du = dags[i].nexts[j].second;

View File

@ -44,7 +44,7 @@ class Jieba {
mp_seg_.cut(sentence, words, max_word_len); mp_seg_.cut(sentence, words, max_word_len);
} }
bool InsertUserWord(const string& word, const string& tag = UNKNOWN_TAG) { bool InsertUserWord(const string& word, const string& tag = UNKNOWN_TAG) {
return dict_trie_.insertUserWord(word, tag); return dict_trie_.InsertUserWord(word, tag);
} }
const DictTrie* GetDictTrie() const { const DictTrie* GetDictTrie() const {

View File

@ -45,7 +45,7 @@ class MPSegment: public SegmentBase {
vector<Unicode>& words, vector<Unicode>& words,
size_t max_word_len = MAX_WORD_LENGTH) const { size_t max_word_len = MAX_WORD_LENGTH) const {
vector<Dag> dags; vector<Dag> dags;
dictTrie_->find(begin, dictTrie_->Find(begin,
end, end,
dags, dags,
max_word_len); max_word_len);
@ -57,8 +57,8 @@ class MPSegment: public SegmentBase {
return dictTrie_; return dictTrie_;
} }
bool isUserDictSingleChineseWord(const Rune & value) const { bool IsUserDictSingleChineseWord(const Rune& value) const {
return dictTrie_->isUserDictSingleChineseWord(value); return dictTrie_->IsUserDictSingleChineseWord(value);
} }
private: private:
void CalcDP(vector<Dag>& dags) const { void CalcDP(vector<Dag>& dags) const {
@ -81,7 +81,7 @@ class MPSegment: public SegmentBase {
if (p) { if (p) {
val += p->weight; val += p->weight;
} else { } else {
val += dictTrie_->getMinWeight(); val += dictTrie_->GetMinWeight();
} }
if (val > rit->weight) { if (val > rit->weight) {
rit->pInfo = p; rit->pInfo = p;

View File

@ -48,14 +48,14 @@ class MixSegment: public SegmentBase {
piece.reserve(end - begin); piece.reserve(end - begin);
for (size_t i = 0, j = 0; i < words.size(); i++) { for (size_t i = 0, j = 0; i < words.size(); i++) {
//if mp get a word, it's ok, put it into result //if mp get a word, it's ok, put it into result
if (1 != words[i].size() || (words[i].size() == 1 && mpSeg_.isUserDictSingleChineseWord(words[i][0]))) { if (1 != words[i].size() || (words[i].size() == 1 && mpSeg_.IsUserDictSingleChineseWord(words[i][0]))) {
res.push_back(words[i]); res.push_back(words[i]);
continue; continue;
} }
// if mp get a single one and it is not in userdict, collect it in sequence // if mp get a single one and it is not in userdict, collect it in sequence
j = i; j = i;
while (j < words.size() && 1 == words[j].size() && !mpSeg_.isUserDictSingleChineseWord(words[j][0])) { while (j < words.size() && 1 == words[j].size() && !mpSeg_.IsUserDictSingleChineseWord(words[j][0])) {
piece.push_back(words[j][0]); piece.push_back(words[j][0]);
j++; j++;
} }

View File

@ -38,7 +38,7 @@ class PosTagger {
LogError("decode failed."); LogError("decode failed.");
return false; return false;
} }
tmp = dict->find(unico.begin(), unico.end()); tmp = dict->Find(unico.begin(), unico.end());
if (tmp == NULL || tmp->tag.empty()) { if (tmp == NULL || tmp->tag.empty()) {
res.push_back(make_pair(*itr, SpecialRule(unico))); res.push_back(make_pair(*itr, SpecialRule(unico)));
} else { } else {

View File

@ -20,22 +20,6 @@ class SegmentBase {
} }
~SegmentBase() { ~SegmentBase() {
} }
/*
public:
void cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector<Unicode>& res) const = 0;
bool cut(const string& sentence, vector<string>& words) const {
PreFilter pre_filter(symbols_, sentence);
PreFilter::Range range;
vector<Unicode> uwords;
uwords.reserve(sentence.size());
while (pre_filter.HasNext()) {
range = pre_filter.Next();
cut(range.begin, range.end, uwords);
}
TransCode::encode(uwords, words);
return true;
}
*/
protected: protected:
void LoadSpecialSymbols() { void LoadSpecialSymbols() {

View File

@ -65,7 +65,7 @@ class Trie {
} }
} }
const DictUnit* find(Unicode::const_iterator begin, Unicode::const_iterator end) const { const DictUnit* Find(Unicode::const_iterator begin, Unicode::const_iterator end) const {
if (begin == end) { if (begin == end) {
return NULL; return NULL;
} }
@ -85,7 +85,7 @@ class Trie {
return ptNode->ptValue; return ptNode->ptValue;
} }
void find(Unicode::const_iterator begin, void Find(Unicode::const_iterator begin,
Unicode::const_iterator end, Unicode::const_iterator end,
vector<struct Dag>&res, vector<struct Dag>&res,
size_t max_word_len = MAX_WORD_LENGTH) const { size_t max_word_len = MAX_WORD_LENGTH) const {

View File

@ -51,7 +51,7 @@ int main(int argc, char** argv) {
cout << "[demo] Insert User Word" << endl; cout << "[demo] Insert User Word" << endl;
app.cut("男默女泪", words); app.cut("男默女泪", words);
cout << join(words.begin(), words.end(), "/") << endl; cout << join(words.begin(), words.end(), "/") << endl;
app.insertUserWord("男默女泪"); app.InsertUserWord("男默女泪");
app.cut("男默女泪", words); app.cut("男默女泪", words);
cout << join(words.begin(), words.end(), "/") << endl; cout << join(words.begin(), words.end(), "/") << endl;

View File

@ -76,7 +76,7 @@ TEST(ApplicationTest, InsertUserWord) {
result << words; result << words;
ASSERT_EQ("[\"男默\", \"女泪\"]", result); ASSERT_EQ("[\"男默\", \"女泪\"]", result);
ASSERT_TRUE(app.insertUserWord("男默女泪")); ASSERT_TRUE(app.InsertUserWord("男默女泪"));
app.cut("男默女泪", words); app.cut("男默女泪", words);
result << words; result << words;
@ -85,7 +85,7 @@ TEST(ApplicationTest, InsertUserWord) {
for (size_t i = 0; i < 100; i++) { for (size_t i = 0; i < 100; i++) {
string newWord; string newWord;
newWord << rand(); newWord << rand();
ASSERT_TRUE(app.insertUserWord(newWord)); ASSERT_TRUE(app.InsertUserWord(newWord));
app.cut(newWord, words); app.cut(newWord, words);
result << words; result << words;
ASSERT_EQ(result, string_format("[\"%s\"]", newWord.c_str())); ASSERT_EQ(result, string_format("[\"%s\"]", newWord.c_str()));

View File

@ -24,16 +24,12 @@ TEST(DictTrieTest, NewAndDelete) {
DictTrie * trie; DictTrie * trie;
trie = new DictTrie(DICT_FILE); trie = new DictTrie(DICT_FILE);
delete trie; delete trie;
trie = new DictTrie();
delete trie;
} }
TEST(DictTrieTest, Test1) { TEST(DictTrieTest, Test1) {
string s1, s2; string s1, s2;
DictTrie trie; DictTrie trie(DICT_FILE);
trie.init(DICT_FILE); ASSERT_LT(trie.GetMinWeight() + 15.6479, 0.001);
ASSERT_LT(trie.getMinWeight() + 15.6479, 0.001);
string word("来到"); string word("来到");
Unicode uni; Unicode uni;
ASSERT_TRUE(TransCode::decode(word, uni)); ASSERT_TRUE(TransCode::decode(word, uni));
@ -42,7 +38,7 @@ TEST(DictTrieTest, Test1) {
nodeInfo.tag = "v"; nodeInfo.tag = "v";
nodeInfo.weight = -8.87033; nodeInfo.weight = -8.87033;
s1 << nodeInfo; s1 << nodeInfo;
s2 << (*trie.find(uni.begin(), uni.end())); s2 << (*trie.Find(uni.begin(), uni.end()));
EXPECT_EQ("[\"26469\", \"21040\"] v -8.870", s2); EXPECT_EQ("[\"26469\", \"21040\"] v -8.870", s2);
word = "清华大学"; word = "清华大学";
@ -50,13 +46,13 @@ TEST(DictTrieTest, Test1) {
const char * words[] = {"", "清华", "清华大学"}; const char * words[] = {"", "清华", "清华大学"};
for (size_t i = 0; i < sizeof(words)/sizeof(words[0]); i++) { for (size_t i = 0; i < sizeof(words)/sizeof(words[0]); i++) {
ASSERT_TRUE(TransCode::decode(words[i], uni)); ASSERT_TRUE(TransCode::decode(words[i], uni));
res.push_back(make_pair(uni.size() - 1, trie.find(uni.begin(), uni.end()))); res.push_back(make_pair(uni.size() - 1, trie.Find(uni.begin(), uni.end())));
//resMap[uni.size() - 1] = trie.find(uni.begin(), uni.end()); //resMap[uni.size() - 1] = trie.Find(uni.begin(), uni.end());
} }
vector<pair<size_t, const DictUnit*> > vec; vector<pair<size_t, const DictUnit*> > vec;
vector<struct Dag> dags; vector<struct Dag> dags;
ASSERT_TRUE(TransCode::decode(word, uni)); ASSERT_TRUE(TransCode::decode(word, uni));
trie.find(uni.begin(), uni.end(), dags); trie.Find(uni.begin(), uni.end(), dags);
ASSERT_EQ(dags.size(), uni.size()); ASSERT_EQ(dags.size(), uni.size());
ASSERT_NE(dags.size(), 0u); ASSERT_NE(dags.size(), 0u);
s1 << res; s1 << res;
@ -70,7 +66,7 @@ TEST(DictTrieTest, UserDict) {
string word = "云计算"; string word = "云计算";
Unicode unicode; Unicode unicode;
ASSERT_TRUE(TransCode::decode(word, unicode)); ASSERT_TRUE(TransCode::decode(word, unicode));
const DictUnit * unit = trie.find(unicode.begin(), unicode.end()); const DictUnit * unit = trie.Find(unicode.begin(), unicode.end());
ASSERT_TRUE(unit); ASSERT_TRUE(unit);
string res ; string res ;
res << *unit; res << *unit;
@ -85,7 +81,7 @@ TEST(DictTrieTest, Dag) {
Unicode unicode; Unicode unicode;
ASSERT_TRUE(TransCode::decode(word, unicode)); ASSERT_TRUE(TransCode::decode(word, unicode));
vector<struct Dag> res; vector<struct Dag> res;
trie.find(unicode.begin(), unicode.end(), res); trie.Find(unicode.begin(), unicode.end(), res);
size_t nexts_sizes[] = {3, 2, 2, 1}; size_t nexts_sizes[] = {3, 2, 2, 1};
ASSERT_EQ(res.size(), sizeof(nexts_sizes)/sizeof(nexts_sizes[0])); ASSERT_EQ(res.size(), sizeof(nexts_sizes)/sizeof(nexts_sizes[0]));
@ -99,7 +95,7 @@ TEST(DictTrieTest, Dag) {
Unicode unicode; Unicode unicode;
ASSERT_TRUE(TransCode::decode(word, unicode)); ASSERT_TRUE(TransCode::decode(word, unicode));
vector<struct Dag> res; vector<struct Dag> res;
trie.find(unicode.begin(), unicode.end(), res); trie.Find(unicode.begin(), unicode.end(), res);
size_t nexts_sizes[] = {3, 1, 2, 2, 2, 1}; size_t nexts_sizes[] = {3, 1, 2, 2, 2, 1};
ASSERT_EQ(res.size(), sizeof(nexts_sizes)/sizeof(nexts_sizes[0])); ASSERT_EQ(res.size(), sizeof(nexts_sizes)/sizeof(nexts_sizes[0]));
@ -113,7 +109,7 @@ TEST(DictTrieTest, Dag) {
Unicode unicode; Unicode unicode;
ASSERT_TRUE(TransCode::decode(word, unicode)); ASSERT_TRUE(TransCode::decode(word, unicode));
vector<struct Dag> res; vector<struct Dag> res;
trie.find(unicode.begin(), unicode.end(), res); trie.Find(unicode.begin(), unicode.end(), res);
size_t nexts_sizes[] = {3, 1, 2, 1}; size_t nexts_sizes[] = {3, 1, 2, 1};
ASSERT_EQ(res.size(), sizeof(nexts_sizes)/sizeof(nexts_sizes[0])); ASSERT_EQ(res.size(), sizeof(nexts_sizes)/sizeof(nexts_sizes[0]));
@ -127,7 +123,7 @@ TEST(DictTrieTest, Dag) {
Unicode unicode; Unicode unicode;
ASSERT_TRUE(TransCode::decode(word, unicode)); ASSERT_TRUE(TransCode::decode(word, unicode));
vector<struct Dag> res; vector<struct Dag> res;
trie.find(unicode.begin(), unicode.end(), res, 3); trie.Find(unicode.begin(), unicode.end(), res, 3);
size_t nexts_sizes[] = {2, 1, 2, 1}; size_t nexts_sizes[] = {2, 1, 2, 1};
ASSERT_EQ(res.size(), sizeof(nexts_sizes)/sizeof(nexts_sizes[0])); ASSERT_EQ(res.size(), sizeof(nexts_sizes)/sizeof(nexts_sizes[0]));
@ -141,7 +137,7 @@ TEST(DictTrieTest, Dag) {
Unicode unicode; Unicode unicode;
ASSERT_TRUE(TransCode::decode(word, unicode)); ASSERT_TRUE(TransCode::decode(word, unicode));
vector<struct Dag> res; vector<struct Dag> res;
trie.find(unicode.begin(), unicode.end(), res, 4); trie.Find(unicode.begin(), unicode.end(), res, 4);
size_t nexts_sizes[] = {3, 1, 2, 1}; size_t nexts_sizes[] = {3, 1, 2, 1};
ASSERT_EQ(res.size(), sizeof(nexts_sizes)/sizeof(nexts_sizes[0])); ASSERT_EQ(res.size(), sizeof(nexts_sizes)/sizeof(nexts_sizes[0]));