Add SetStaticWordWeights and UserWordWeightOption

This commit is contained in:
yanyiwu 2015-10-08 17:36:52 +08:00
parent b28d6db574
commit 98345d6aed

View File

@ -14,7 +14,9 @@
#include "Trie.hpp"
namespace CppJieba {
using namespace limonp;
// Sentinel extremes for word weights: a running maximum can start from
// MIN_DOUBLE and a running minimum from MAX_DOUBLE. DictTrie's default
// constructor seeds its minimum weight with MAX_DOUBLE.
const double MIN_DOUBLE = -3.14e+100;
const double MAX_DOUBLE = 3.14e+100;
// Number of fields LoadDict requires on every dictionary line; lines that
// split into a different count are reported through LogFatal.
const size_t DICT_COLUMN_NUM = 3;
@ -22,42 +24,46 @@ const char* const UNKNOWN_TAG = "";
class DictTrie {
public:
// Enumerators mirror the min / median / max weights tracked by this class.
// NOTE(review): presumably meant to select which of those weights is given
// to user-dictionary words; nothing in the visible code reads it yet — confirm.
enum UserWordWeightOption {
Min,
Median,
Max,
}; // enum UserWordWeightOption
// Default constructor: leaves the trie unbuilt (NULL) and seeds the minimum
// weight with MAX_DOUBLE. No other weight member is assigned here.
// NOTE(review): diff view — the removed camelCase assignment is shown next to
// its snake_case replacement.
DictTrie() {
trie_ = NULL;
minWeight_ = MAX_DOUBLE;
min_weight_ = MAX_DOUBLE;
}
// Convenience constructor: resets the object to its default state and then
// loads the dictionaries through init().
//   first argument  - path handed to init() as the main dictionary
//   second argument - user dictionary path(s); the default "" makes init()
//                     skip user dictionaries
// NOTE(review): diff view — old camelCase lines appear beside their
// snake_case replacements.
DictTrie(const string& dictPath, const string& userDictPaths = "") {
DictTrie(const string& dict_path, const string& user_dict_paths = "") {
// Placement-new runs the default constructor on this same object so that
// trie_ is NULL before init() checks it.
new (this) DictTrie();
init(dictPath, userDictPaths);
init(dict_path, user_dict_paths);
}
// Destroys the trie held in trie_. A default-constructed object still has
// trie_ == NULL, for which delete is a no-op.
~DictTrie() {
delete trie_;
}
// Loads the main dictionary, rewrites its weights as log-probabilities,
// records the weight statistics, appends any user dictionaries, and finally
// builds the trie from the collected units.
// Calling it when trie_ is already non-NULL is reported through LogFatal.
// NOTE(review): diff view — removed camelCase lines are shown beside their
// snake_case replacements; the FindMinWeight/FindMaxWeight pair is what
// SetStaticWordWeights() replaces.
void init(const string& dictPath, const string& userDictPaths = "") {
void init(const string& dict_path, const string& user_dict_paths = "") {
if (trie_ != NULL) {
LogFatal("trie already initted");
}
LoadDict(dictPath);
CalculateWeight(staticNodeInfos_);
minWeight_ = FindMinWeight(staticNodeInfos_);
maxWeight_ = FindMaxWeight(staticNodeInfos_);
LoadDict(dict_path);
CalculateWeight(static_node_infos_);
// Must run before LoadUserDict: user words are given max_weight_, and they
// are appended to the same vector, so the statistics have to be taken while
// it still holds only main-dictionary entries.
SetStaticWordWeights();
if (userDictPaths.size()) {
LoadUserDict(userDictPaths);
if (user_dict_paths.size()) {
LoadUserDict(user_dict_paths);
}
Shrink(staticNodeInfos_);
CreateTrie(staticNodeInfos_);
Shrink(static_node_infos_);
CreateTrie(static_node_infos_);
}
// Adds one user word to the trie at runtime.
//   word - text passed to MakeUserNodeInfo for decoding
//   tag  - tag stored with the word; defaults to UNKNOWN_TAG
// Returns false when MakeUserNodeInfo fails (the word could not be decoded),
// true once the unit has been stored and inserted.
// trie_ is dereferenced without a NULL check, so init() must have run first.
// NOTE(review): diff view — removed camelCase lines are shown beside their
// snake_case replacements.
bool insertUserWord(const string& word, const string& tag = UNKNOWN_TAG) {
DictUnit nodeInfo;
if (!MakeUserNodeInfo(nodeInfo, word, tag)) {
DictUnit node_info;
if (!MakeUserNodeInfo(node_info, word, tag)) {
return false;
}
activeNodeInfos_.push_back(nodeInfo);
trie_->insertNode(nodeInfo.word, &activeNodeInfos_.back());
// The trie receives a pointer to the stored copy, not to the local, so the
// unit is pushed into the container first.
active_node_infos_.push_back(node_info);
trie_->insertNode(node_info.word, &active_node_infos_.back());
return true;
}
@ -73,11 +79,11 @@ class DictTrie {
}
// Looks `word` up in the set of single-rune words that MakeUserNodeInfo
// collects from user-supplied entries.
// NOTE(review): isIn is assumed to be a plain membership test — confirm.
// Diff view: the removed camelCase return is shown beside its replacement.
bool isUserDictSingleChineseWord(const Rune& word) const {
return isIn(userDictSingleChineseWord_, word);
return isIn(user_dict_single_chinese_word_, word);
}
// Returns the stored minimum weight. It is MAX_DOUBLE after default
// construction and is replaced with the smallest dictionary weight by init().
// Diff view: the removed camelCase return is shown beside its replacement.
double getMinWeight() const {
return minWeight_;
return min_weight_;
}
private:
@ -101,7 +107,7 @@ class DictTrie {
LogFatal("file %s open failed.", files[i].c_str());
}
string line;
DictUnit nodeInfo;
DictUnit node_info;
vector<string> buf;
for (; getline(ifs, line); lineno++) {
buf.clear();
@ -109,38 +115,38 @@ class DictTrie {
if (buf.size() < 1) {
LogFatal("split [%s] result illegal", line.c_str());
}
DictUnit nodeInfo;
MakeUserNodeInfo(nodeInfo, buf[0],
DictUnit node_info;
MakeUserNodeInfo(node_info, buf[0],
(buf.size() == 2 ? buf[1] : UNKNOWN_TAG));
staticNodeInfos_.push_back(nodeInfo);
static_node_infos_.push_back(node_info);
}
}
LogInfo("load userdicts[%s] ok. lines[%u]", filePaths.c_str(), lineno);
}
// Fills a DictUnit from one dictionary entry.
//   first argument - unit to fill (output)
//   word           - encoded text, decoded into the unit's word field
//   weight         - stored as given
//   tag            - stored as given
// Returns false, after logging an error, when TransCode::decode rejects the
// word; in that case weight and tag are left untouched. Returns true otherwise.
// NOTE(review): diff view — removed camelCase lines are shown beside their
// snake_case replacements.
bool MakeNodeInfo(DictUnit& nodeInfo,
bool MakeNodeInfo(DictUnit& node_info,
const string& word,
double weight,
const string& tag) {
if (!TransCode::decode(word, nodeInfo.word)) {
if (!TransCode::decode(word, node_info.word)) {
LogError("decode %s failed.", word.c_str());
return false;
}
nodeInfo.weight = weight;
nodeInfo.tag = tag;
node_info.weight = weight;
node_info.tag = tag;
return true;
}
// Fills a DictUnit for a user-supplied word.
//   first argument - unit to fill (output)
//   word           - encoded text, decoded into the unit's word field
//   tag            - stored as given; defaults to UNKNOWN_TAG
// Returns false, after logging an error, when TransCode::decode rejects the
// word; returns true otherwise.
// Side effect: a word that decodes to exactly one rune is also recorded in
// the single-rune set queried by isUserDictSingleChineseWord().
// NOTE(review): diff view — removed camelCase lines are shown beside their
// snake_case replacements.
bool MakeUserNodeInfo(DictUnit& nodeInfo,
bool MakeUserNodeInfo(DictUnit& node_info,
const string& word,
const string& tag = UNKNOWN_TAG) {
if (!TransCode::decode(word, nodeInfo.word)) {
if (!TransCode::decode(word, node_info.word)) {
LogError("decode %s failed.", word.c_str());
return false;
}
if (nodeInfo.word.size() == 1) {
userDictSingleChineseWord_.insert(nodeInfo.word[0]);
if (node_info.word.size() == 1) {
user_dict_single_chinese_word_.insert(node_info.word[0]);
}
// User words always take the largest weight found in the main dictionary,
// so max_weight_ must already be set when this runs.
nodeInfo.weight = maxWeight_;
nodeInfo.tag = tag;
node_info.weight = max_weight_;
node_info.tag = tag;
return true;
}
// Reads the main dictionary and appends one DictUnit per line to the static
// unit list. Each line is split on a single space and must yield exactly
// DICT_COLUMN_NUM fields, used as word, weight and tag in that order.
// NOTE(review): the statements that open the file fall in an elided part of
// this diff view, and removed camelCase lines are shown beside their
// snake_case replacements.
void LoadDict(const string& filePath) {
@ -151,44 +157,45 @@ class DictTrie {
string line;
vector<string> buf;
DictUnit nodeInfo;
DictUnit node_info;
for (size_t lineno = 0; getline(ifs, line); lineno++) {
split(line, buf, " ");
if (buf.size() != DICT_COLUMN_NUM) {
// NOTE(review): buf.size() is a size_t but is formatted with %u — assuming
// LogFatal is printf-style, this mismatches on LP64 targets; confirm.
LogFatal("split result illegal, line: %s, result size: %u", line.c_str(), buf.size());
}
// NOTE(review): the bool result of MakeNodeInfo is ignored, so an entry
// whose word fails to decode is still appended below.
// atof yields 0.0 when the weight column is not numeric.
MakeNodeInfo(nodeInfo,
MakeNodeInfo(node_info,
buf[0],
atof(buf[1].c_str()),
buf[2]);
staticNodeInfos_.push_back(nodeInfo);
static_node_infos_.push_back(node_info);
}
}
double FindMinWeight(const vector<DictUnit>& nodeInfos) const {
double ret = MAX_DOUBLE;
for (size_t i = 0; i < nodeInfos.size(); i++) {
ret = min(nodeInfos[i].weight, ret);
}
return ret;
}
double FindMaxWeight(const vector<DictUnit>& nodeInfos) const {
double ret = MIN_DOUBLE;
for (size_t i = 0; i < nodeInfos.size(); i++) {
ret = max(nodeInfos[i].weight, ret);
}
return ret;
}
void CalculateWeight(vector<DictUnit>& nodeInfos) const {
// Ordering predicate: true when lhs has a strictly smaller weight than rhs,
// i.e. it orders dictionary units by ascending weight.
static bool WeightCompare(const DictUnit& lhs, const DictUnit& rhs) {
  const bool lhs_is_lighter = lhs.weight < rhs.weight;
  return lhs_is_lighter;
}
void SetStaticWordWeights() {
if (static_node_infos_.empty()) {
LogFatal("something must be wrong");
}
vector<DictUnit> x = static_node_infos_;
sort(x.begin(), x.end(), WeightCompare);
min_weight_ = x[0].weight;
max_weight_ = x[x.size() - 1].weight;
median_weight_ = x[x.size() / 2].weight;
}
// Rewrites every weight in place as log(weight / sum of all weights).
// The asserts require a non-zero sum and non-zero individual weights; when
// asserts are compiled out (NDEBUG) a zero weight would reach log(0).
// NOTE(review): diff view — removed camelCase lines are shown beside their
// snake_case replacements, so each loop appears twice.
void CalculateWeight(vector<DictUnit>& node_infos) const {
double sum = 0.0;
for (size_t i = 0; i < nodeInfos.size(); i++) {
sum += nodeInfos[i].weight;
// First pass: total of the raw weights.
for (size_t i = 0; i < node_infos.size(); i++) {
sum += node_infos[i].weight;
}
assert(sum);
for (size_t i = 0; i < nodeInfos.size(); i++) {
DictUnit& nodeInfo = nodeInfos[i];
assert(nodeInfo.weight);
nodeInfo.weight = log(double(nodeInfo.weight)/double(sum));
// Second pass: replace each raw weight with its log relative frequency.
for (size_t i = 0; i < node_infos.size(); i++) {
DictUnit& node_info = node_infos[i];
assert(node_info.weight);
node_info.weight = log(double(node_info.weight)/double(sum));
}
}
@ -196,13 +203,14 @@ class DictTrie {
vector<DictUnit>(units.begin(), units.end()).swap(units);
}
// NOTE(review): diff view — each removed camelCase member is shown beside its
// snake_case replacement.
vector<DictUnit> staticNodeInfos_;
deque<DictUnit> activeNodeInfos_; // must not be vector
// Units read from the main and user dictionaries; init() passes this list to
// Shrink() and CreateTrie().
vector<DictUnit> static_node_infos_;
// Units added at runtime by insertUserWord(). The trie is handed pointers to
// these elements, and deque::push_back leaves references to existing elements
// valid, whereas a vector may reallocate and invalidate them.
deque<DictUnit> active_node_infos_; // must not be vector
// NULL until init() has run; deleted in the destructor.
Trie * trie_;
double minWeight_;
double maxWeight_;
unordered_set<Rune> userDictSingleChineseWord_;
// Smallest, largest and middle weight of the main dictionary, taken after the
// weights have been converted to log form.
double min_weight_;
double max_weight_;
double median_weight_;
// First rune of every user word that decodes to exactly one rune.
unordered_set<Rune> user_dict_single_chinese_word_;
};
}