LevelSegment

This commit is contained in:
yanyiwu 2015-08-11 00:53:06 +08:00
parent efd029c20b
commit 41e4300c9a
7 changed files with 155 additions and 82 deletions

View File

@ -66,21 +66,18 @@ class DictTrie {
const DictUnit* find(Unicode::const_iterator begin, Unicode::const_iterator end) const {
return trie_->find(begin, end);
}
void find(Unicode::const_iterator begin,
Unicode::const_iterator end,
vector<Dag>& res) const {
trie_->find(begin, end, res);
}
void findByLimit(Unicode::const_iterator begin,
Unicode::const_iterator end,
size_t min_word_len,
size_t max_word_len,
vector<struct Dag>&res) const {
trie_->findByLimit(begin, end, min_word_len, max_word_len, res);
vector<struct Dag>&res,
size_t max_word_len = MAX_WORD_LENGTH) const {
trie_->find(begin, end, res, max_word_len);
}
bool isUserDictSingleChineseWord(const Rune& word) const {
return isIn(userDictSingleChineseWord_, word);
}
double getMinWeight() const {
return minWeight_;
}

View File

@ -6,7 +6,7 @@ namespace CppJieba {
class ISegment {
public:
virtual ~ISegment() {
};
}
virtual bool cut(const string& str, vector<string>& res) const = 0;
};

79
src/LevelSegment.hpp Normal file
View File

@ -0,0 +1,79 @@
#ifndef CPPJIEBA_LEVELSEGMENT_H
#define CPPJIEBA_LEVELSEGMENT_H
#include "MPSegment.hpp"
namespace CppJieba {
class LevelSegment: public ISegment {
public:
LevelSegment(const string& dictPath,
const string& userDictPath = "")
: mpSeg_(dictPath, userDictPath) {
LogInfo("LevelSegment init");
}
LevelSegment(const DictTrie* dictTrie)
: mpSeg_(dictTrie) {
}
virtual ~LevelSegment() {
}
void cut(Unicode::const_iterator begin,
Unicode::const_iterator end,
vector<pair<Unicode, size_t> >& res) const {
vector<Unicode> words;
vector<Unicode> smallerWords;
words.reserve(end - begin);
mpSeg_.cut(begin, end, words);
smallerWords.reserve(words.size());
res.reserve(words.size());
size_t level = 0;
while (!words.empty()) {
smallerWords.clear();
for (size_t i = 0; i < words.size(); i++) {
if (words[i].size() >= 3) {
size_t len = words[i].size() - 1;
mpSeg_.cut(words[i].begin(), words[i].end(), smallerWords, len); // buffer.push_back without clear
}
if (words[i].size() > 1) {
res.push_back(pair<Unicode, size_t>(words[i], level));
}
}
words.swap(smallerWords);
level++;
}
}
void cut(const string& sentence,
vector<pair<string, size_t> >& words) const {
Unicode unicode;
TransCode::decode(sentence, unicode);
vector<pair<Unicode, size_t> > unicodeWords;
cut(unicode.begin(), unicode.end(), unicodeWords);
words.resize(unicodeWords.size());
for (size_t i = 0; i < words.size(); i++) {
TransCode::encode(unicodeWords[i].first, words[i].first);
words[i].second = unicodeWords[i].second;
}
}
bool cut(const string& sentence,
vector<string>& res) const {
vector<pair<string, size_t> > words;
cut(sentence, words);
res.reserve(words.size());
for (size_t i = 0; i < words.size(); i++) {
res.push_back(words[i].first);
}
return true;
}
private:
MPSegment mpSeg_;
}; // class LevelSegment
} // namespace CppJieba
#endif // CPPJIEBA_LEVELSEGMENT_H

View File

@ -44,14 +44,13 @@ class MPSegment: public SegmentBase {
}
void cut(Unicode::const_iterator begin,
Unicode::const_iterator end,
size_t min_word_len,
size_t max_word_len,
vector<Unicode>&res) const {
vector<Unicode>& res,
size_t max_word_len) const {
vector<Dag> dags;
dictTrie_->findByLimit(begin, end,
min_word_len,
max_word_len,
dags);
dictTrie_->find(begin,
end,
dags,
max_word_len);
calcDP_(dags);
cut_(dags, res);
}

View File

@ -8,7 +8,6 @@
namespace CppJieba {
using namespace std;
const size_t MIN_WORD_LENGTH = 1;
const size_t MAX_WORD_LENGTH = 512;
struct DictUnit {
@ -86,18 +85,12 @@ class Trie {
return ptNode->ptValue;
}
void findByLimit(Unicode::const_iterator begin,
void find(Unicode::const_iterator begin,
Unicode::const_iterator end,
size_t min_word_len,
size_t max_word_len,
vector<struct Dag>&res) const {
vector<struct Dag>&res,
size_t max_word_len = MAX_WORD_LENGTH) const {
res.resize(end - begin);
// min_word_len start from 1;
if (min_word_len < 1) {
min_word_len = 1;
}
const TrieNode *ptNode = NULL;
TrieNode::NextMap::const_iterator citer;
for (size_t i = 0; i < size_t(end - begin); i++) {
@ -106,11 +99,8 @@ class Trie {
res[i].rune = rune;
assert(res[i].nexts.empty());
if (min_word_len <= 1) {
res[i].nexts.push_back(pair<size_t, const DictUnit*>(i, ptNode->ptValue));
}
res[i].nexts.push_back(pair<size_t, const DictUnit*>(i, ptNode->ptValue));
// min_word_len start from 1;
for (size_t j = i + 1; j < size_t(end - begin) && (j - i + 1) <= max_word_len ; j++) {
if (ptNode->next == NULL) {
break;
@ -120,18 +110,13 @@ class Trie {
break;
}
ptNode = citer->second;
if (NULL != ptNode->ptValue && (j - i + 1) >= min_word_len) {
if (NULL != ptNode->ptValue) {
res[i].nexts.push_back(pair<size_t, const DictUnit*>(j, ptNode->ptValue));
}
}
}
}
void find(Unicode::const_iterator begin,
Unicode::const_iterator end,
vector<struct Dag>& res) const {
findByLimit(begin, end, MIN_WORD_LENGTH, MAX_WORD_LENGTH, res);
}
void insertNode(const Unicode& key, const DictUnit* ptValue) {
if (key.begin() == key.end()) {
return;

View File

@ -4,6 +4,7 @@
#include "src/HMMSegment.hpp"
#include "src/FullSegment.hpp"
#include "src/QuerySegment.hpp"
#include "src/LevelSegment.hpp"
#include "gtest/gtest.h"
using namespace CppJieba;
@ -86,50 +87,52 @@ TEST(MixSegmentTest, UserDict2) {
TEST(MPSegmentTest, Test1) {
MPSegment segment("../dict/jieba.dict.utf8");;
const char* str = "我来自北京邮电大学。";
const char* res[] = {"", "来自", "北京邮电大学", ""};
string s;
vector<string> words;
ASSERT_TRUE(segment.cut(str, words));
ASSERT_EQ(words, vector<string>(res, res + sizeof(res)/sizeof(res[0])));
ASSERT_TRUE(segment.cut("我来自北京邮电大学。", words));
ASSERT_EQ("[\"\", \"来自\", \"北京邮电大学\", \"\"]", s << words);
{
const char* str = "B超 T恤";
const char * res[] = {"B超", " ", "T恤"};
vector<string> words;
ASSERT_TRUE(segment.cut(str, words));
ASSERT_EQ(words, vector<string>(res, res + sizeof(res)/sizeof(res[0])));
}
ASSERT_TRUE(segment.cut("B超 T恤", words));
ASSERT_EQ(s << words, "[\"B超\", \" \", \"T恤\"]");
ASSERT_TRUE(segment.cut("南京市长江大桥", words));
ASSERT_EQ("[\"南京市\", \"长江大桥\"]", s << words);
// MaxWordLen
//ASSERT_TRUE(segment.cut("南京市长江大桥", words, 3));
//ASSERT_EQ("[\"南京市\", \"长江\", \"大桥\"]", s << words);
}
TEST(MPSegmentTest, Test2) {
MPSegment segment("../test/testdata/extra_dict/jieba.dict.small.utf8");
string line;
ifstream ifs("../test/testdata/review.100");
vector<string> words;
//TEST(MPSegmentTest, Test2) {
// MPSegment segment("../test/testdata/extra_dict/jieba.dict.small.utf8");
// string line;
// ifstream ifs("../test/testdata/review.100");
// vector<string> words;
//
// string eRes;
// {
// ifstream ifs("../test/testdata/review.100.res");
// ASSERT_TRUE(!!ifs);
// eRes << ifs;
// }
// string res;
//
// while(getline(ifs, line)) {
// res += line;
// res += '\n';
//
// segment.cut(line, words);
// string s;
// s << words;
// res += s;
// res += '\n';
// }
// ofstream ofs("../test/testdata/review.100.res");
// ASSERT_TRUE(!!ofs);
// ofs << res;
//
//}
string eRes;
{
ifstream ifs("../test/testdata/review.100.res");
ASSERT_TRUE(!!ifs);
eRes << ifs;
}
string res;
while(getline(ifs, line)) {
res += line;
res += '\n';
segment.cut(line, words);
string s;
s << words;
res += s;
res += '\n';
}
ofstream ofs("../test/testdata/review.100.res");
ASSERT_TRUE(!!ofs);
ofs << res;
}
TEST(HMMSegmentTest, Test1) {
HMMSegment segment("../dict/hmm_model.utf8");;
{
@ -203,3 +206,15 @@ TEST(QuerySegment, Test2) {
}
}
TEST(LevelSegmentTest, Test0) {
string s;
LevelSegment segment("../test/testdata/extra_dict/jieba.dict.small.utf8");
vector<pair<string, size_t> > words;
segment.cut("南京市长江大桥", words);
ASSERT_EQ("[\"南京市:0\", \"长江大桥:0\", \"南京:1\", \"长江:1\", \"大桥:1\"]", s << words);
vector<string> res;
segment.cut("南京市长江大桥", res);
ASSERT_EQ("[\"南京市\", \"长江大桥\", \"南京\", \"长江\", \"大桥\"]", s << res);
}

View File

@ -122,28 +122,26 @@ TEST(DictTrieTest, Dag) {
}
}
//findByLimit [2, 3]
{
string word = "长江大桥";
Unicode unicode;
ASSERT_TRUE(TransCode::decode(word, unicode));
vector<struct Dag> res;
trie.findByLimit(unicode.begin(), unicode.end(), 2, 3, res);
trie.find(unicode.begin(), unicode.end(), res, 3);
size_t nexts_sizes[] = {1, 0, 1, 0};
size_t nexts_sizes[] = {2, 1, 2, 1};
ASSERT_EQ(res.size(), sizeof(nexts_sizes)/sizeof(nexts_sizes[0]));
for (size_t i = 0; i < res.size(); i++) {
ASSERT_EQ(res[i].nexts.size(), nexts_sizes[i]);
}
}
//findByLimit [0, 4]
{
string word = "长江大桥";
Unicode unicode;
ASSERT_TRUE(TransCode::decode(word, unicode));
vector<struct Dag> res;
trie.findByLimit(unicode.begin(), unicode.end(), 0, 4, res);
trie.find(unicode.begin(), unicode.end(), res, 4);
size_t nexts_sizes[] = {3, 1, 2, 1};
ASSERT_EQ(res.size(), sizeof(nexts_sizes)/sizeof(nexts_sizes[0]));