Add SetStaticWordWeights and UserWordWeightOption

This commit is contained in:
yanyiwu 2015-10-08 17:36:52 +08:00
parent b28d6db574
commit 98345d6aed

View File

@ -14,7 +14,9 @@
#include "Trie.hpp" #include "Trie.hpp"
namespace CppJieba { namespace CppJieba {
using namespace limonp; using namespace limonp;
const double MIN_DOUBLE = -3.14e+100; const double MIN_DOUBLE = -3.14e+100;
const double MAX_DOUBLE = 3.14e+100; const double MAX_DOUBLE = 3.14e+100;
const size_t DICT_COLUMN_NUM = 3; const size_t DICT_COLUMN_NUM = 3;
@ -22,42 +24,46 @@ const char* const UNKNOWN_TAG = "";
class DictTrie { class DictTrie {
public: public:
// Names one of three weight choices: minimum, median or maximum.
// NOTE(review): presumably selects which static-dictionary weight is given to
// user-supplied words; nothing in the visible code reads it yet -- confirm at
// the call sites.
enum UserWordWeightOption {
  Min = 0,
  Median = 1,
  Max = 2
}; // enum UserWordWeightOption
// Default constructor: marks the trie as not yet built (trie_ = NULL) and
// seeds the minimum weight with MAX_DOUBLE. No other weight member is
// assigned here.
DictTrie() { DictTrie() {
trie_ = NULL; trie_ = NULL;
minWeight_ = MAX_DOUBLE; min_weight_ = MAX_DOUBLE;
} }
// Convenience constructor: applies the default constructor's initialisation
// and then calls init() with the main dictionary path and the (optional)
// user dictionary path string.
// NOTE(review): the default initialisation is obtained by placement-new on
// `this`, which re-constructs an object whose members already exist --
// confirm this is intended.
DictTrie(const string& dictPath, const string& userDictPaths = "") { DictTrie(const string& dict_path, const string& user_dict_paths = "") {
new (this) DictTrie(); new (this) DictTrie();
init(dictPath, userDictPaths); init(dict_path, user_dict_paths);
} }
// Destructor: releases the trie held through trie_.
~DictTrie() { ~DictTrie() {
delete trie_; delete trie_;
} }
// init: one-time setup of the dictionary.
//   1. Logs a fatal error ("trie already initted") if trie_ is already set.
//   2. Loads the main dictionary and converts its weights with
//      CalculateWeight().
//   3. Records weight statistics of the static entries. This happens before
//      user dictionaries are read because user entries are assigned the
//      maximum weight computed here.
//   4. Loads user dictionaries only when the path string is non-empty.
//   5. Calls Shrink() and then CreateTrie() on the static entries.
void init(const string& dictPath, const string& userDictPaths = "") { void init(const string& dict_path, const string& user_dict_paths = "") {
if (trie_ != NULL) { if (trie_ != NULL) {
LogFatal("trie already initted"); LogFatal("trie already initted");
} }
LoadDict(dictPath); LoadDict(dict_path);
CalculateWeight(staticNodeInfos_); CalculateWeight(static_node_infos_);
minWeight_ = FindMinWeight(staticNodeInfos_); SetStaticWordWeights();
maxWeight_ = FindMaxWeight(staticNodeInfos_);
if (userDictPaths.size()) { if (user_dict_paths.size()) {
LoadUserDict(userDictPaths); LoadUserDict(user_dict_paths);
} }
Shrink(staticNodeInfos_); Shrink(static_node_infos_);
CreateTrie(staticNodeInfos_); CreateTrie(static_node_infos_);
} }
// insertUserWord: adds one user word (with an optional tag) after the trie
// exists. Returns false when MakeUserNodeInfo() rejects the word, true
// otherwise. The new unit is appended to the deque of active entries and the
// trie is handed the address of that stored copy, not of the local variable.
bool insertUserWord(const string& word, const string& tag = UNKNOWN_TAG) { bool insertUserWord(const string& word, const string& tag = UNKNOWN_TAG) {
DictUnit nodeInfo; DictUnit node_info;
if (!MakeUserNodeInfo(nodeInfo, word, tag)) { if (!MakeUserNodeInfo(node_info, word, tag)) {
return false; return false;
} }
activeNodeInfos_.push_back(nodeInfo); active_node_infos_.push_back(node_info);
trie_->insertNode(nodeInfo.word, &activeNodeInfos_.back()); trie_->insertNode(node_info.word, &active_node_infos_.back());
return true; return true;
} }
@ -73,11 +79,11 @@ class DictTrie {
} }
// isUserDictSingleChineseWord: true when `word` is in the set of runes that
// MakeUserNodeInfo() recorded for user words decoding to exactly one rune.
bool isUserDictSingleChineseWord(const Rune& word) const { bool isUserDictSingleChineseWord(const Rune& word) const {
return isIn(userDictSingleChineseWord_, word); return isIn(user_dict_single_chinese_word_, word);
} }
// getMinWeight: returns the stored minimum weight (MAX_DOUBLE until init()
// has computed the real value).
double getMinWeight() const { double getMinWeight() const {
return minWeight_; return min_weight_;
} }
private: private:
@ -101,7 +107,7 @@ class DictTrie {
LogFatal("file %s open failed.", files[i].c_str()); LogFatal("file %s open failed.", files[i].c_str());
} }
string line; string line;
DictUnit nodeInfo; DictUnit node_info;
vector<string> buf; vector<string> buf;
for (; getline(ifs, line); lineno++) { for (; getline(ifs, line); lineno++) {
buf.clear(); buf.clear();
@ -109,38 +115,38 @@ class DictTrie {
if (buf.size() < 1) { if (buf.size() < 1) {
LogFatal("split [%s] result illegal", line.c_str()); LogFatal("split [%s] result illegal", line.c_str());
} }
DictUnit nodeInfo; DictUnit node_info;
MakeUserNodeInfo(nodeInfo, buf[0], MakeUserNodeInfo(node_info, buf[0],
(buf.size() == 2 ? buf[1] : UNKNOWN_TAG)); (buf.size() == 2 ? buf[1] : UNKNOWN_TAG));
staticNodeInfos_.push_back(nodeInfo); static_node_infos_.push_back(node_info);
} }
} }
LogInfo("load userdicts[%s] ok. lines[%u]", filePaths.c_str(), lineno); LogInfo("load userdicts[%s] ok. lines[%u]", filePaths.c_str(), lineno);
} }
// MakeNodeInfo: fills a DictUnit from one dictionary entry.
// Decodes `word` into the unit's word field; on decode failure logs an error
// and returns false without setting weight or tag. On success stores the
// given weight and tag and returns true.
bool MakeNodeInfo(DictUnit& nodeInfo, bool MakeNodeInfo(DictUnit& node_info,
const string& word, const string& word,
double weight, double weight,
const string& tag) { const string& tag) {
if (!TransCode::decode(word, nodeInfo.word)) { if (!TransCode::decode(word, node_info.word)) {
LogError("decode %s failed.", word.c_str()); LogError("decode %s failed.", word.c_str());
return false; return false;
} }
nodeInfo.weight = weight; node_info.weight = weight;
nodeInfo.tag = tag; node_info.tag = tag;
return true; return true;
} }
// MakeUserNodeInfo: fills a DictUnit for a user-supplied word.
// Decodes `word`; on failure logs an error and returns false. On success:
//   - a word that decodes to a single rune is also recorded in the
//     single-rune user-word set,
//   - the unit's weight is set to the stored maximum weight,
//   - the tag is stored, and true is returned.
bool MakeUserNodeInfo(DictUnit& nodeInfo, bool MakeUserNodeInfo(DictUnit& node_info,
const string& word, const string& word,
const string& tag = UNKNOWN_TAG) { const string& tag = UNKNOWN_TAG) {
if (!TransCode::decode(word, nodeInfo.word)) { if (!TransCode::decode(word, node_info.word)) {
LogError("decode %s failed.", word.c_str()); LogError("decode %s failed.", word.c_str());
return false; return false;
} }
if (nodeInfo.word.size() == 1) { if (node_info.word.size() == 1) {
userDictSingleChineseWord_.insert(nodeInfo.word[0]); user_dict_single_chinese_word_.insert(node_info.word[0]);
} }
nodeInfo.weight = maxWeight_; node_info.weight = max_weight_;
nodeInfo.tag = tag; node_info.tag = tag;
return true; return true;
} }
void LoadDict(const string& filePath) { void LoadDict(const string& filePath) {
@ -151,44 +157,45 @@ class DictTrie {
string line; string line;
vector<string> buf; vector<string> buf;
DictUnit nodeInfo; DictUnit node_info;
for (size_t lineno = 0; getline(ifs, line); lineno++) { for (size_t lineno = 0; getline(ifs, line); lineno++) {
split(line, buf, " "); split(line, buf, " ");
if (buf.size() != DICT_COLUMN_NUM) { if (buf.size() != DICT_COLUMN_NUM) {
LogFatal("split result illegal, line: %s, result size: %u", line.c_str(), buf.size()); LogFatal("split result illegal, line: %s, result size: %u", line.c_str(), buf.size());
} }
MakeNodeInfo(nodeInfo, MakeNodeInfo(node_info,
buf[0], buf[0],
atof(buf[1].c_str()), atof(buf[1].c_str()),
buf[2]); buf[2]);
staticNodeInfos_.push_back(nodeInfo); static_node_infos_.push_back(node_info);
} }
} }
double FindMinWeight(const vector<DictUnit>& nodeInfos) const {
double ret = MAX_DOUBLE;
for (size_t i = 0; i < nodeInfos.size(); i++) {
ret = min(nodeInfos[i].weight, ret);
}
return ret;
}
double FindMaxWeight(const vector<DictUnit>& nodeInfos) const {
double ret = MIN_DOUBLE;
for (size_t i = 0; i < nodeInfos.size(); i++) {
ret = max(nodeInfos[i].weight, ret);
}
return ret;
}
// WeightCompare: ordering predicate on dictionary units; true when lhs.weight
// is strictly less than rhs.weight.
void CalculateWeight(vector<DictUnit>& nodeInfos) const { static bool WeightCompare(const DictUnit& lhs, const DictUnit& rhs) {
return lhs.weight < rhs.weight;
}
void SetStaticWordWeights() {
if (static_node_infos_.empty()) {
LogFatal("something must be wrong");
}
vector<DictUnit> x = static_node_infos_;
sort(x.begin(), x.end(), WeightCompare);
min_weight_ = x[0].weight;
max_weight_ = x[x.size() - 1].weight;
median_weight_ = x[x.size() / 2].weight;
}
// CalculateWeight: rewrites every unit's weight in place as
// log(weight / sum), where sum is the total of all weights in the vector.
// Asserts that the total and each individual weight are non-zero.
void CalculateWeight(vector<DictUnit>& node_infos) const {
double sum = 0.0; double sum = 0.0;
for (size_t i = 0; i < nodeInfos.size(); i++) { for (size_t i = 0; i < node_infos.size(); i++) {
sum += nodeInfos[i].weight; sum += node_infos[i].weight;
} }
assert(sum); assert(sum);
for (size_t i = 0; i < nodeInfos.size(); i++) { for (size_t i = 0; i < node_infos.size(); i++) {
DictUnit& nodeInfo = nodeInfos[i]; DictUnit& node_info = node_infos[i];
assert(nodeInfo.weight); assert(node_info.weight);
nodeInfo.weight = log(double(nodeInfo.weight)/double(sum)); node_info.weight = log(double(node_info.weight)/double(sum));
} }
} }
@ -196,13 +203,14 @@ class DictTrie {
vector<DictUnit>(units.begin(), units.end()).swap(units); vector<DictUnit>(units.begin(), units.end()).swap(units);
} }
vector<DictUnit> staticNodeInfos_; vector<DictUnit> static_node_infos_;
deque<DictUnit> activeNodeInfos_; // must not be vector deque<DictUnit> active_node_infos_; // must not be vector
Trie * trie_; Trie * trie_;
double minWeight_; double min_weight_;
double maxWeight_; double max_weight_;
unordered_set<Rune> userDictSingleChineseWord_; double median_weight_;
unordered_set<Rune> user_dict_single_chinese_word_;
}; };
} }