From 0db2dfa6b83b000efe1fd68c07f60ea8c82c0f3b Mon Sep 17 00:00:00 2001 From: wyy Date: Tue, 24 Dec 2013 01:22:02 -0800 Subject: [PATCH] finished KeywordExtractor and its ut --- src/KeywordExtractor.hpp | 76 +++++++++++++++++++++++++---- test/unittest/CMakeLists.txt | 1 + test/unittest/TKeywordExtractor.cpp | 29 +++++++++-- 3 files changed, 92 insertions(+), 14 deletions(-) diff --git a/src/KeywordExtractor.hpp b/src/KeywordExtractor.hpp index f677a22..c4917bd 100644 --- a/src/KeywordExtractor.hpp +++ b/src/KeywordExtractor.hpp @@ -2,6 +2,7 @@ #define CPPJIEBA_KEYWORD_EXTRACTOR_H #include "MPSegment.hpp" +#include namespace CppJieba { @@ -9,18 +10,24 @@ namespace CppJieba struct KeyWordInfo { - + string word; uint freq; - double weight; + double idf; }; - class KeywordExtractor//: public MPSegment + inline ostream& operator << (ostream& os, const KeyWordInfo & keyword) + { + return os << keyword.word << "," << keyword.freq << "," << keyword.idf; + } + + class KeywordExtractor { private: MPSegment _segment; private: - unordered_map _wordIndex; - vector _words; + unordered_map _wordIndex; + vector _wordinfos; + size_t _totalFreq; protected: bool _isInited; bool _getInitFlag()const{return _isInited;}; @@ -40,16 +47,51 @@ namespace CppJieba LogError("open %s failed.", dictPath.c_str()); return false; } + _totalFreq = 0; + int tfreq; string line ; vector buf; + KeyWordInfo keywordInfo; for(uint lineno = 0; getline(ifs, line); lineno++) { buf.clear(); + if(line.empty()) + { + LogError("line[%d] empty. skipped.", lineno); + continue; + } + if(!split(line, buf, " ") || buf.size() != 3) + { + LogError("line %d [%s] illegal. skipped.", lineno, line.c_str()); + continue; + } + keywordInfo.word = buf[0]; + tfreq= atoi(buf[1].c_str()); + if(tfreq <= 0) + { + LogError("line %d [%s] illegal. skipped.", lineno, line.c_str()); + continue; + } + keywordInfo.freq = tfreq; + _totalFreq += tfreq; + _wordinfos.push_back(keywordInfo); + } + + // calculate idf & make index. + for(uint i = 0; i < _wordinfos.size(); i++) + { + if(_wordinfos[i].freq <= 0) + { + LogFatal("freq value is not positive."); + return false; + } + _wordinfos[i].idf = -log(_wordinfos[i].freq); + _wordIndex[_wordinfos[i].word] = &(_wordinfos[i]); } return _setInitFlag(_segment.init(dictPath)); }; public: - bool extract(const string& str, vector& keywords, uint topN) + bool extract(const string& str, vector& keywords, uint topN) const { assert(_getInitFlag()); @@ -60,14 +102,28 @@ namespace CppJieba return false; } - unordered_map wordcnt; + unordered_map wordmap; for(uint i = 0; i < words.size(); i ++) { - wordcnt[ words[i] ] ++; + wordmap[ words[i] ] += 1.0; } - vector > topWords(topN); - partial_sort_copy(wordcnt.begin(), wordcnt.end(), topWords.begin(), topWords.end(), _cmp); + for(unordered_map::iterator itr = wordmap.begin(); itr != wordmap.end();) + { + unordered_map::const_iterator cit = _wordIndex.find(itr->first); + if(cit != _wordIndex.end()) + { + itr->second *= cit->second->idf; + itr ++; + } + else + { + itr = wordmap.erase(itr); + } + } + + vector > topWords(min(topN, wordmap.size())); + partial_sort_copy(wordmap.begin(), wordmap.end(), topWords.begin(), topWords.end(), _cmp); keywords.clear(); for(uint i = 0; i < topWords.size(); i++) diff --git a/test/unittest/CMakeLists.txt b/test/unittest/CMakeLists.txt index f10a406..56bf087 100644 --- a/test/unittest/CMakeLists.txt +++ b/test/unittest/CMakeLists.txt @@ -3,6 +3,7 @@ SET(LIBRARY_OUTPUT_PATH ${PROJECT_BINARY_DIR}/test/lib) SET(GTEST_ROOT_DIR gtest-1.6.0) +ADD_DEFINITIONS(-DLOGGER_LEVEL=LL_WARN) INCLUDE_DIRECTORIES(${GTEST_ROOT_DIR} ${GTEST_ROOT_DIR}/include ${PROJECT_SOURCE_DIR}) ADD_LIBRARY(gtest STATIC ${GTEST_ROOT_DIR}/src/gtest-all.cc) FILE(GLOB SRCFILES *.cpp) diff --git a/test/unittest/TKeywordExtractor.cpp b/test/unittest/TKeywordExtractor.cpp index fd331a6..d6a3469 100644 --- a/test/unittest/TKeywordExtractor.cpp +++ b/test/unittest/TKeywordExtractor.cpp @@ -7,13 +7,34 @@ TEST(KeywordExtractorTest, Test1) { KeywordExtractor extractor("../dicts/jieba.dict.utf8"); const char* str = "我来自北京邮电大学。。。 学号 123456"; - const char* res[] = {"我", "来自", "北京邮电大学", "。","。","。"," ","学","号", " 123456"}; + const char* res[] = {"北京邮电大学", "来自"}; vector words; ASSERT_TRUE(extractor); ASSERT_TRUE(extractor.extract(str, words, 2)); - //print(words); - //exit(0); - //print(words); ASSERT_EQ(words, vector(res, res + sizeof(res)/sizeof(res[0]))); } +TEST(KeywordExtractorTest, Test2) +{ + KeywordExtractor extractor("../dicts/jieba.dict.utf8"); + const char* str = "我来自北京邮电大学。。。 学号 123456"; + const char* res[] = {"北京邮电大学", "来自", "学", "号", "我"}; + vector words; + ASSERT_TRUE(extractor); + ASSERT_TRUE(extractor.extract(str, words, 9)); + ASSERT_EQ(words, vector(res, res + sizeof(res)/sizeof(res[0]))); +} + + +TEST(KeywordExtractorTest, Test3) +{ + ifstream ifs("../test/testdata/weicheng.utf8"); + ASSERT_TRUE(ifs); + string str((istreambuf_iterator(ifs)), (istreambuf_iterator())); + KeywordExtractor extractor("../dicts/jieba.dict.utf8"); + vector keywords; + string res; + extractor.extract(str, keywords, 5); + res << keywords; + ASSERT_EQ("[\"第三性\", \"多愁多病\", \"记挂着\", \"揭去\", \"贫血症\"]", res); +}