Merge pull request #187 from gumblex/master

Stop using the Trie to reduce memory use and speed things up; tidy up code details
Sun Junyi 2014-10-19 19:43:30 +08:00
commit 4a93f21918
14 changed files with 382 additions and 365 deletions


@ -17,14 +17,13 @@ import logging
DICTIONARY = "dict.txt"
DICT_LOCK = threading.RLock()
trie = None # to be initialized
pfdict = None # to be initialized
FREQ = {}
min_freq = 0.0
total = 0.0
user_word_tag_tab = {}
initialized = False
log_console = logging.StreamHandler(sys.stderr)
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
@ -34,84 +33,79 @@ def setLogLevel(log_level):
global logger
logger.setLevel(log_level)
def gen_trie(f_name):
def gen_pfdict(f_name):
lfreq = {}
trie = {}
pfdict = set()
ltotal = 0.0
with open(f_name, 'rb') as f:
lineno = 0
for line in f.read().rstrip().decode('utf-8').split('\n'):
lineno += 1
try:
word,freq,_ = line.split(' ')
word,freq = line.split(' ')[:2]
freq = float(freq)
lfreq[word] = freq
ltotal += freq
p = trie
for c in word:
if c not in p:
p[c] ={}
p = p[c]
p['']='' #ending flag
for ch in xrange(len(word)):
pfdict.add(word[:ch+1])
except ValueError, e:
logger.debug('%s at line %s %s' % (f_name, lineno, line))
raise ValueError, e
return trie, lfreq,ltotal
return pfdict, lfreq, ltotal
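
The new gen_pfdict replaces the nested-dict trie with a flat set containing every prefix of every dictionary word, which is what the commit title refers to: membership tests stay constant-time while the per-node dict overhead of the trie disappears. A minimal sketch of the same idea, with a hypothetical helper name (build_prefix_set is not part of jieba):

# Minimal sketch of the prefix-set idea; build_prefix_set is a hypothetical helper.
def build_prefix_set(words):
    prefixes = set()
    for word in words:
        for i in range(1, len(word) + 1):
            prefixes.add(word[:i])   # every prefix, including the full word
    return prefixes

pfdict = build_prefix_set([u"清华", u"清华大学", u"大学"])
assert u"清" in pfdict        # prefix of a dictionary word
assert u"清华大" in pfdict    # intermediate prefix, kept so a scan can keep extending
assert u"华大" not in pfdict  # not a prefix of any entry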
def initialize(*args):
global trie, FREQ, total, min_freq, initialized
if len(args)==0:
global pfdict, FREQ, total, min_freq, initialized
if not args:
dictionary = DICTIONARY
else:
dictionary = args[0]
with DICT_LOCK:
if initialized:
return
if trie:
del trie
trie = None
if pfdict:
del pfdict
pfdict = None
_curpath = os.path.normpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
abs_path = os.path.join(_curpath,dictionary)
logger.debug("Building Trie..., from %s" % abs_path)
logger.debug("Building prefix dict from %s ..." % abs_path)
t1 = time.time()
if abs_path == os.path.join(_curpath,"dict.txt"): #defautl dictionary
if abs_path == os.path.join(_curpath, "dict.txt"): #default dictionary
cache_file = os.path.join(tempfile.gettempdir(), "jieba.cache")
else: #customer dictionary
cache_file = os.path.join(tempfile.gettempdir(),"jieba.user."+str(hash(abs_path))+".cache")
else: #custom dictionary
cache_file = os.path.join(tempfile.gettempdir(), "jieba.user.%s.cache" % hash(abs_path))
load_from_cache_fail = True
if os.path.exists(cache_file) and os.path.getmtime(cache_file) > os.path.getmtime(abs_path):
logger.debug("loading model from cache %s" % cache_file)
logger.debug("Loading model from cache %s" % cache_file)
try:
trie,FREQ,total,min_freq = marshal.load(open(cache_file,'rb'))
load_from_cache_fail = False
pfdict,FREQ,total,min_freq = marshal.load(open(cache_file,'rb'))
# prevent conflict with old version
load_from_cache_fail = not isinstance(pfdict, set)
except:
load_from_cache_fail = True
if load_from_cache_fail:
trie,FREQ,total = gen_trie(abs_path)
pfdict,FREQ,total = gen_pfdict(abs_path)
FREQ = dict([(k,log(float(v)/total)) for k,v in FREQ.iteritems()]) #normalize
min_freq = min(FREQ.itervalues())
logger.debug("dumping model to file cache %s" % cache_file)
logger.debug("Dumping model to file cache %s" % cache_file)
try:
tmp_suffix = "."+str(random.random())
with open(cache_file+tmp_suffix,'wb') as temp_cache_file:
marshal.dump((trie,FREQ,total,min_freq),temp_cache_file)
marshal.dump((pfdict,FREQ,total,min_freq), temp_cache_file)
if os.name == 'nt':
import shutil
replace_file = shutil.move
from shutil import move as replace_file
else:
replace_file = os.rename
replace_file(cache_file + tmp_suffix, cache_file)
except:
logger.error("dump cache file failed.")
logger.exception("")
logger.exception("Dump cache file failed.")
initialized = True
logger.debug("loading model cost %s seconds." % (time.time() - t1))
logger.debug("Trie has been built succesfully.")
logger.debug("Loading model cost %s seconds." % (time.time() - t1))
logger.debug("Prefix dict has been built succesfully.")
def require_initialized(fn):
@ -151,30 +145,21 @@ def calc(sentence,DAG,idx,route):
@require_initialized
def get_DAG(sentence):
N = len(sentence)
i,j=0,0
p = trie
global pfdict, FREQ
DAG = {}
while i<N:
c = sentence[j]
if c in p:
p = p[c]
if '' in p:
if i not in DAG:
DAG[i]=[]
DAG[i].append(j)
j+=1
if j>=N:
N = len(sentence)
for k in xrange(N):
tmplist = []
i = k
frag = sentence[k]
while i < N and frag in pfdict:
if frag in FREQ:
tmplist.append(i)
i += 1
j=i
p=trie
else:
p = trie
i+=1
j=i
for i in xrange(len(sentence)):
if i not in DAG:
DAG[i] =[i]
frag = sentence[k:i+1]
if not tmplist:
tmplist.append(k)
DAG[k] = tmplist
return DAG
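
get_DAG now walks forward from each position k, growing the fragment while it is still a known prefix and recording an end index whenever the fragment is an actual dictionary word; a position with no match falls back to the single character. A worked toy example of that construction (the dictionary here is made up):

# Worked toy example of the DAG construction above. FREQ holds real words;
# pfdict holds every prefix of every word, as built by gen_pfdict.
FREQ = {u"北京": 1, u"清华": 1, u"清华大学": 1, u"大学": 1}
pfdict = set()
for w in FREQ:
    for i in range(1, len(w) + 1):
        pfdict.add(w[:i])

def toy_get_DAG(sentence):
    DAG = {}
    N = len(sentence)
    for k in range(N):
        tmplist = []
        i = k
        frag = sentence[k]
        while i < N and frag in pfdict:
            if frag in FREQ:
                tmplist.append(i)          # sentence[k:i+1] is a dictionary word
            i += 1
            frag = sentence[k:i + 1]
        if not tmplist:
            tmplist.append(k)              # fall back to the single character
        DAG[k] = tmplist
    return DAG

print(toy_get_DAG(u"北京清华大学"))
# expected: {0: [1], 1: [1], 2: [3, 5], 3: [3], 4: [5], 5: [5]}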
def __cut_DAG_NO_HMM(sentence):
@ -192,12 +177,12 @@ def __cut_DAG_NO_HMM(sentence):
buf += l_word
x = y
else:
if len(buf)>0:
if buf:
yield buf
buf = u''
yield l_word
x = y
if len(buf)>0:
if buf:
yield buf
buf = u''
@ -214,14 +199,14 @@ def __cut_DAG(sentence):
if y-x == 1:
buf += l_word
else:
if len(buf)>0:
if buf:
if len(buf) == 1:
yield buf
buf = u''
else:
if (buf not in FREQ):
regognized = finalseg.cut(buf)
for t in regognized:
recognized = finalseg.cut(buf)
for t in recognized:
yield t
else:
for elem in buf:
@ -230,13 +215,12 @@ def __cut_DAG(sentence):
yield l_word
x = y
if len(buf)>0:
if buf:
if len(buf) == 1:
yield buf
else:
if (buf not in FREQ):
regognized = finalseg.cut(buf)
for t in regognized:
elif (buf not in FREQ):
recognized = finalseg.cut(buf)
for t in recognized:
yield t
else:
for elem in buf:
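
In __cut_DAG, runs of adjacent single characters that the dictionary cannot account for are collected in buf and, when the buffer is flushed, handed to finalseg.cut, the character-level HMM, so unseen words can still come out as whole tokens; the elif rewrite above only flattens a nested branch and fixes the regognized typo without changing that behaviour. A hedged illustration of the HMM fallback called on its own (output not shown, since it depends on the trained model):

# Hedged illustration of the HMM fallback that __cut_DAG feeds buffered
# out-of-vocabulary characters into.
from jieba import finalseg

for w in finalseg.cut(u"杭研大厦"):
    print(w)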
@ -246,31 +230,32 @@ def cut(sentence,cut_all=False,HMM=True):
'''The main function that segments an entire sentence that contains
Chinese characters into separated words.
Parameter:
- sentence: The String to be segmented
- cut_all: Model. True means full pattern, false means accurate pattern.
- HMM: Whether use Hidden Markov Model.
- sentence: The str/unicode to be segmented.
- cut_all: Model type. True for full pattern, False for accurate pattern.
- HMM: Whether to use the Hidden Markov Model.
'''
if not isinstance(sentence, unicode):
try:
sentence = sentence.decode('utf-8')
except UnicodeDecodeError:
sentence = sentence.decode('gbk', 'ignore')
'''
\u4E00-\u9FA5a-zA-Z0-9+#&\._ : All non-space characters. Will be handled with re_han
\r\n|\s : whitespace characters. Will not be Handled.
'''
re_han, re_skip = re.compile(ur"([\u4E00-\u9FA5a-zA-Z0-9+#&\._]+)", re.U), re.compile(ur"(\r\n|\s)", re.U)
# \u4E00-\u9FA5a-zA-Z0-9+#&\._ : All non-space characters. Will be handled with re_han
# \r\n|\s : whitespace characters. Will not be handled.
if cut_all:
re_han, re_skip = re.compile(ur"([\u4E00-\u9FA5]+)", re.U), re.compile(ur"[^a-zA-Z0-9+#\n]", re.U)
else:
re_han, re_skip = re.compile(ur"([\u4E00-\u9FA5a-zA-Z0-9+#&\._]+)", re.U), re.compile(ur"(\r\n|\s)", re.U)
blocks = re_han.split(sentence)
if HMM:
if cut_all:
cut_block = __cut_all
elif HMM:
cut_block = __cut_DAG
else:
cut_block = __cut_DAG_NO_HMM
if cut_all:
cut_block = __cut_all
for blk in blocks:
if len(blk)==0:
if not blk:
continue
if re_han.match(blk):
for word in cut_block(blk):
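
Whatever the mode, cut first chops the input into blocks: re_han captures the runs it knows how to segment and everything else is left to re_skip or passed through as-is; only blocks matching re_han go through cut_block. A small illustration of that splitting, with the accurate-mode patterns rewritten as escaped (non-raw) unicode literals so the sketch is version-neutral:

# Illustration of the block splitting done in cut(); patterns as in accurate mode.
import re

re_han = re.compile(u"([\u4E00-\u9FA5a-zA-Z0-9+#&\\._]+)", re.U)
re_skip = re.compile(u"(\\r\\n|\\s)", re.U)

for blk in re_han.split(u"jieba 0.34, 结巴分词!"):
    if not blk:
        continue
    if re_han.match(blk):
        print(u"segment: " + blk)            # this block would go through cut_block()
    else:
        print(u"passthrough: " + repr(blk))  # whitespace and punctuation, handled via re_skip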
@ -312,37 +297,30 @@ def load_userdict(f):
...
Word type may be ignored
'''
global trie,total,FREQ
if isinstance(f, (str, unicode)):
f = open(f, 'rb')
content = f.read().decode('utf-8')
line_no = 0
for line in content.split("\n"):
line_no += 1
if line.rstrip()=='': continue
if not line.rstrip():
continue
tup = line.split(" ")
word, freq = tup[0], tup[1]
if freq.isdigit() is False: continue
if freq.isdigit() is False:
continue
if line_no == 1:
word = word.replace(u'\ufeff',u"") #remove bom flag if it exists
if len(tup)==3:
add_word(word, freq, tup[2])
else:
add_word(word, freq)
add_word(*tup)
@require_initialized
def add_word(word, freq, tag=None):
global FREQ, trie, total, user_word_tag_tab
freq = float(freq)
FREQ[word] = log(freq / total)
global FREQ, pfdict, total, user_word_tag_tab
FREQ[word] = log(float(freq) / total)
if tag is not None:
user_word_tag_tab[word] = tag.strip()
p = trie
for c in word:
if c not in p:
p[c] = {}
p = p[c]
p[''] = '' # ending flag
for ch in xrange(len(word)):
pfdict.add(word[:ch+1])
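
add_word now keeps both structures in sync: FREQ gets the log-normalized frequency used when routing through the DAG, and pfdict gets every prefix of the new word so get_DAG can actually reach it. A hedged usage example (the word, frequency, and tag are arbitrary):

# Hedged usage example; the word, frequency and tag are arbitrary values.
import jieba

jieba.add_word(u"云计算", 3000, tag="n")
print(u"/".join(jieba.cut(u"他在研究云计算")))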
__ref_cut = cut
__ref_cut_for_search = cut_for_search
@ -365,7 +343,7 @@ def enable_parallel(processnum=None):
if sys.version_info[0]==2 and sys.version_info[1]<6:
raise Exception("jieba: the parallel feature needs Python version>2.5")
from multiprocessing import Pool, cpu_count
if processnum==None:
if processnum is None:
processnum = cpu_count()
pool = Pool(processnum)
@ -373,8 +351,7 @@ def enable_parallel(processnum=None):
parts = re.compile('([\r\n]+)').split(sentence)
if cut_all:
result = pool.map(__lcut_all, parts)
else:
if HMM:
elif HMM:
result = pool.map(__lcut, parts)
else:
result = pool.map(__lcut_no_hmm, parts)
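
enable_parallel swaps cut and cut_for_search for pool-backed versions that split the text on line breaks and map the chosen per-block function over a multiprocessing.Pool; the elif above is the same mode dispatch as in cut, just flattened. A hedged usage example (the worker count is arbitrary; cpu_count() is used when none is given):

# Hedged usage example of the parallel mode; 4 workers is an arbitrary choice.
import jieba

jieba.enable_parallel(4)
words = list(jieba.cut(u"第一行文本\n第二行文本"))
jieba.disable_parallel()   # restores the serial cut functions saved in __ref_cut above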
@ -415,9 +392,14 @@ def get_abs_path_dict():
return abs_path
def tokenize(unicode_sentence, mode="default", HMM=True):
#mode ("default" or "search")
"""Tokenize a sentence and yields tuples of (word, start, end)
Parameter:
- sentence: the unicode to be segmented.
- mode: "default" or "search", "search" is for finer segmentation.
- HMM: whether to use the Hidden Markov Model.
"""
if not isinstance(unicode_sentence, unicode):
raise Exception("jieba: the input parameter should unicode.")
raise Exception("jieba: the input parameter should be unicode.")
start = 0
if mode == 'default':
for w in cut(unicode_sentence, HMM=HMM):
@ -439,4 +421,3 @@ def tokenize(unicode_sentence,mode="default",HMM=True):
yield (gram3, start+i, start+i+3)
yield (w, start, start+width)
start += width
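
The new docstring documents tokenize, which yields (word, start, end) offsets into the original unicode string; in "search" mode long words are additionally reported as 2- and 3-grams with their own offsets. A short hedged usage example (the sentence is arbitrary):

# Hedged usage example; the input must be unicode in this version.
import jieba

for word, start, end in jieba.tokenize(u"永和服装饰品有限公司"):
    print(u"%s\t%d:%d" % (word, start, end))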

jieba/__main__.py (new file, 37 lines)

@ -0,0 +1,37 @@
"""Jieba command line interface."""
import sys
import jieba
from argparse import ArgumentParser
parser = ArgumentParser(usage="%s -m jieba [options] filename" % sys.executable, description="Jieba command line interface.", epilog="If no filename specified, use STDIN instead.")
parser.add_argument("-d", "--delimiter", metavar="DELIM", default=' / ',
nargs='?', const=' ',
help="use DELIM instead of ' / ' for word delimiter; use a space if it is without DELIM")
parser.add_argument("-a", "--cut-all",
action="store_true", dest="cutall", default=False,
help="full pattern cutting")
parser.add_argument("-n", "--no-hmm", dest="hmm", action="store_false",
default=True, help="don't use the Hidden Markov Model")
parser.add_argument("-q", "--quiet", action="store_true", default=False,
help="don't print loading messages to stderr")
parser.add_argument("-V", '--version', action='version',
version="Jieba " + jieba.__version__)
parser.add_argument("filename", nargs='?', help="input file")
args = parser.parse_args()
if args.quiet:
jieba.setLogLevel(60)
delim = unicode(args.delimiter)
cutall = args.cutall
hmm = args.hmm
fp = open(args.filename, 'r') if args.filename else sys.stdin
jieba.initialize()
ln = fp.readline()
while ln:
l = ln.rstrip('\r\n')
print(delim.join(jieba.cut(ln.rstrip('\r\n'), cutall, hmm)).encode('utf-8'))
ln = fp.readline()
fp.close()
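
With this entry point the segmenter is usable straight from a shell: something along the lines of "python -m jieba -a somefile.txt" full-pattern-cuts a file (the file name is only an example), omitting the filename reads from STDIN, -d changes the default " / " delimiter, -n disables the HMM, and -q raises the log level above CRITICAL so the "Building prefix dict" messages stay off stderr.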


@ -9,9 +9,11 @@ except ImportError:
_curpath = os.path.normpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
abs_path = os.path.join(_curpath, "idf.txt")
STOP_WORDS = set([
"the","of","is","and","to","in","that","we","for","an","are","by","be","as","on","with","can","if","from","which","you","it","this","then","at","have","all","not","one","has","or","that"
])
STOP_WORDS = set((
"the","of","is","and","to","in","that","we","for","an","are",
"by","be","as","on","with","can","if","from","which","you","it",
"this","then","at","have","all","not","one","has","or","that"
))
class IDFLoader:
def __init__(self):
@ -45,7 +47,6 @@ def set_idf_path(idf_path):
if not os.path.exists(new_abs_path):
raise Exception("jieba: path does not exist: " + new_abs_path)
idf_loader.set_new_path(new_abs_path)
return
def set_stop_words(stop_words_path):
global STOP_WORDS
@ -56,7 +57,6 @@ def set_stop_words(stop_words_path):
lines = content.split('\n')
for line in lines:
STOP_WORDS.add(line)
return
def extract_tags(sentence, topK=20):
global STOP_WORDS
@ -66,8 +66,10 @@ def extract_tags(sentence,topK=20):
words = jieba.cut(sentence)
freq = {}
for w in words:
if len(w.strip())<2: continue
if w.lower() in STOP_WORDS: continue
if len(w.strip()) < 2:
continue
if w.lower() in STOP_WORDS:
continue
freq[w] = freq.get(w, 0.0) + 1.0
total = sum(freq.values())
freq = [(k,v/total) for k,v in freq.iteritems()]
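
extract_tags segments the sentence, skips stop words and single characters, and normalizes the term frequencies for the IDF weighting that the idf.txt loader above feeds, before returning the top K words. A hedged usage example (the sentence and topK value are arbitrary):

# Hedged usage example; the sentence and topK are arbitrary.
import jieba.analyse

text = u"我来到北京清华大学,清华大学是著名的高等学府"
print(u"/".join(jieba.analyse.extract_tags(text, topK=5)))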


@ -1,4 +1,4 @@
#encoding=utf-8
##encoding=utf-8
from whoosh.analysis import RegexAnalyzer,LowercaseFilter,StopFilter,StemFilter
from whoosh.analysis import Tokenizer,Token
from whoosh.lang.porter import stem
@ -19,10 +19,7 @@ class ChineseTokenizer(Tokenizer):
words = jieba.tokenize(text, mode="search")
token = Token()
for (w,start_pos,stop_pos) in words:
if not accepted_chars.match(w):
if len(w)>1:
pass
else:
if not accepted_chars.match(w) and len(w)<=1:
continue
token.original = token.text = w
token.pos = start_pos
@ -31,5 +28,6 @@ class ChineseTokenizer(Tokenizer):
yield token
def ChineseAnalyzer(stoplist=STOP_WORDS, minsize=1, stemfn=stem, cachesize=50000):
return ChineseTokenizer() | LowercaseFilter() | StopFilter(stoplist=stoplist,minsize=minsize)\
|StemFilter(stemfn=stemfn, ignore=None,cachesize=cachesize)
return (ChineseTokenizer() | LowercaseFilter() |
StopFilter(stoplist=stoplist,minsize=minsize) |
StemFilter(stemfn=stemfn, ignore=None,cachesize=cachesize))
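
The analyzer chain is now a single parenthesized expression instead of a backslash continuation, but it is still tokenizer, lowercase filter, stop filter, stemmer in that order. A hedged sketch of plugging it into a Whoosh schema, using the import path from jieba's own Whoosh example (field names are illustrative):

# Hedged sketch; field names are illustrative and Whoosh must be installed.
from whoosh.fields import Schema, TEXT, ID
from jieba.analyse import ChineseAnalyzer

analyzer = ChineseAnalyzer()
schema = Schema(path=ID(stored=True), content=TEXT(stored=True, analyzer=analyzer))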


@ -23,19 +23,19 @@ def load_model():
start_p = {}
abs_path = os.path.join(_curpath, PROB_START_P)
with open(abs_path, mode='rb') as f:
with open(abs_path, mode='r') as f:
start_p = marshal.load(f)
f.closed
trans_p = {}
abs_path = os.path.join(_curpath, PROB_TRANS_P)
with open(abs_path, 'rb') as f:
with open(abs_path, 'r') as f:
trans_p = marshal.load(f)
f.closed
emit_p = {}
abs_path = os.path.join(_curpath, PROB_EMIT_P)
with file(abs_path, 'rb') as f:
with open(abs_path, 'r') as f:
emit_p = marshal.load(f)
f.closed
@ -53,7 +53,7 @@ def viterbi(obs, states, start_p, trans_p, emit_p):
for y in states: #init
V[0][y] = start_p[y] + emit_p[y].get(obs[0],MIN_FLOAT)
path[y] = [y]
for t in range(1,len(obs)):
for t in xrange(1,len(obs)):
V.append({})
newpath = {}
for y in states:
@ -87,10 +87,10 @@ def __cut(sentence):
yield sentence[next:]
def cut(sentence):
if not ( type(sentence) is unicode):
if not isinstance(sentence, unicode):
try:
sentence = sentence.decode('utf-8')
except:
except UnicodeDecodeError:
sentence = sentence.decode('gbk', 'ignore')
re_han, re_skip = re.compile(ur"([\u4E00-\u9FA5]+)"), re.compile(ur"(\d+\.\d+|[a-zA-Z0-9]+)")
blocks = re_han.split(sentence)
@ -101,5 +101,5 @@ def cut(sentence):
else:
tmp = re_skip.split(blk)
for x in tmp:
if x!="":
if x:
yield x


@ -18,10 +18,11 @@ def load_model(f_name,isJython=True):
_curpath = os.path.normpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
result = {}
with file(f_name, "rb") as f:
for line in open(f_name,"rb"):
with open(f_name, "r") as f:
for line in f:
line = line.strip()
if line=="":continue
if not line:
continue
word, _, tag = line.split(' ')
result[word.decode('utf-8')] = tag
f.closed
@ -30,25 +31,25 @@ def load_model(f_name,isJython=True):
start_p = {}
abs_path = os.path.join(_curpath, PROB_START_P)
with open(abs_path, mode='rb') as f:
with open(abs_path, mode='r') as f:
start_p = marshal.load(f)
f.closed
trans_p = {}
abs_path = os.path.join(_curpath, PROB_TRANS_P)
with open(abs_path, 'rb') as f:
with open(abs_path, 'r') as f:
trans_p = marshal.load(f)
f.closed
emit_p = {}
abs_path = os.path.join(_curpath, PROB_EMIT_P)
with file(abs_path, 'rb') as f:
with open(abs_path, 'r') as f:
emit_p = marshal.load(f)
f.closed
state = {}
abs_path = os.path.join(_curpath, CHAR_STATE_TAB_P)
with file(abs_path, 'rb') as f:
with open(abs_path, 'r') as f:
state = marshal.load(f)
f.closed
@ -65,7 +66,7 @@ def makesure_userdict_loaded(fn):
@wraps(fn)
def wrapped(*args,**kwargs):
if len(jieba.user_word_tag_tab)>0:
if jieba.user_word_tag_tab:
word_tag_tab.update(jieba.user_word_tag_tab)
jieba.user_word_tag_tab = {}
return fn(*args,**kwargs)
@ -78,7 +79,7 @@ class pair(object):
self.flag = flag
def __unicode__(self):
return self.word+u"/"+self.flag
return u'%s/%s' % (self.word, self.flag)
def __repr__(self):
return self.__str__()
@ -117,7 +118,7 @@ def __cut_detail(sentence):
else:
tmp = re_skip.split(blk)
for x in tmp:
if x!="":
if x:
if re_num.match(x):
yield pair(x, 'm')
elif re_eng.match(x):
@ -140,12 +141,12 @@ def __cut_DAG_NO_HMM(sentence):
buf += l_word
x = y
else:
if len(buf)>0:
if buf:
yield pair(buf,'eng')
buf = u''
yield pair(l_word, word_tag_tab.get(l_word, 'x'))
x = y
if len(buf)>0:
if buf:
yield pair(buf,'eng')
buf = u''
@ -164,14 +165,14 @@ def __cut_DAG(sentence):
if y-x == 1:
buf += l_word
else:
if len(buf)>0:
if buf:
if len(buf) == 1:
yield pair(buf, word_tag_tab.get(buf, 'x'))
buf = u''
else:
if (buf not in jieba.FREQ):
regognized = __cut_detail(buf)
for t in regognized:
recognized = __cut_detail(buf)
for t in recognized:
yield t
else:
for elem in buf:
@ -180,23 +181,22 @@ def __cut_DAG(sentence):
yield pair(l_word, word_tag_tab.get(l_word, 'x'))
x = y
if len(buf)>0:
if buf:
if len(buf) == 1:
yield pair(buf, word_tag_tab.get(buf, 'x'))
else:
if (buf not in jieba.FREQ):
regognized = __cut_detail(buf)
for t in regognized:
elif (buf not in jieba.FREQ):
recognized = __cut_detail(buf)
for t in recognized:
yield t
else:
for elem in buf:
yield pair(elem, word_tag_tab.get(elem, 'x'))
def __cut_internal(sentence, HMM=True):
if not ( type(sentence) is unicode):
if not isinstance(sentence, unicode):
try:
sentence = sentence.decode('utf-8')
except:
except UnicodeDecodeError:
sentence = sentence.decode('gbk', 'ignore')
re_han, re_skip = re.compile(ur"([\u4E00-\u9FA5a-zA-Z0-9+#&\._]+)"), re.compile(ur"(\r\n|\s)")
re_eng, re_num = re.compile(ur"[a-zA-Z0-9]+"), re.compile(ur"[\.0-9]+")
@ -232,7 +232,7 @@ def __lcut_internal_no_hmm(sentence):
@makesure_userdict_loaded
def cut(sentence, HMM=True):
if (not hasattr(jieba,'pool')) or (jieba.pool==None):
if (not hasattr(jieba, 'pool')) or (jieba.pool is None):
for w in __cut_internal(sentence, HMM=HMM):
yield w
else:


@ -14,28 +14,27 @@ def viterbi(obs, states, start_p, trans_p, emit_p):
for y in states.get(obs[0], all_states): #init
V[0][y] = start_p[y] + emit_p[y].get(obs[0], MIN_FLOAT)
mem_path[0][y] = ''
for t in range(1,len(obs)):
for t in xrange(1, len(obs)):
V.append({})
mem_path.append({})
#prev_states = get_top_states(V[t-1])
prev_states = [x for x in mem_path[t-1].keys() if len(trans_p[x]) > 0]
prev_states_expect_next = set((y for x in prev_states for y in trans_p[x].keys()))
obs_states = states.get(obs[t],all_states)
obs_states = set(obs_states) & set(prev_states_expect_next)
obs_states = set(states.get(obs[t], all_states)) & prev_states_expect_next
if len(obs_states)==0: obs_states = prev_states_expect_next
if len(obs_states)==0: obs_states = all_states
if not obs_states:
obs_states = prev_states_expect_next if prev_states_expect_next else all_states
for y in obs_states:
(prob,state ) = max([(V[t-1][y0] + trans_p[y0].get(y,MIN_INF) + emit_p[y].get(obs[t],MIN_FLOAT) ,y0) for y0 in prev_states])
prob, state = max([(V[t-1][y0] + trans_p[y0].get(y,MIN_INF) + emit_p[y].get(obs[t],MIN_FLOAT), y0) for y0 in prev_states])
V[t][y] = prob
mem_path[t][y] = state
last = [(V[-1][y], y) for y in mem_path[-1].keys()]
#if len(last)==0:
#print obs
(prob, state) = max(last)
prob, state = max(last)
route = [None] * len(obs)
i = len(obs) - 1
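
The pruning logic itself is unchanged: candidate states for character t are those allowed for that character by the states mapping, intersected with the states reachable from step t-1, with fallbacks when the intersection comes up empty; the rewrite folds the two len()==0 checks into one conditional expression. A small self-contained restatement of that selection rule (candidate_states is an illustrative name, not a jieba function):

# Restatement of the state-pruning rule above; candidate_states is an illustrative name.
def candidate_states(obs_char, states, all_states, prev_states, trans_p):
    reachable = set(y for x in prev_states for y in trans_p[x])   # states reachable from t-1
    obs_states = set(states.get(obs_char, all_states)) & reachable
    if not obs_states:
        obs_states = reachable or all_states                      # same fallbacks as above
    return obs_states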


@ -4,14 +4,14 @@ sys.path.append("../")
import jieba
seg_list = jieba.cut("我来到北京清华大学", cut_all=True)
print "Full Mode:", "/ ".join(seg_list) # 全模式
seg_list = jieba.cut(u"我来到北京清华大学", cut_all=True)
print u"Full Mode:", u"/ ".join(seg_list) # 全模式
seg_list = jieba.cut("我来到北京清华大学", cut_all=False)
print "Default Mode:", "/ ".join(seg_list) # 默认模式
seg_list = jieba.cut(u"我来到北京清华大学", cut_all=False)
print u"Default Mode:", u"/ ".join(seg_list) # 默认模式
seg_list = jieba.cut("他来到了网易杭研大厦")
print ", ".join(seg_list)
seg_list = jieba.cut(u"他来到了网易杭研大厦")
print u", ".join(seg_list)
seg_list = jieba.cut_for_search("小明硕士毕业于中国科学院计算所,后在日本京都大学深造") # 搜索引擎模式
print ", ".join(seg_list)
seg_list = jieba.cut_for_search(u"小明硕士毕业于中国科学院计算所,后在日本京都大学深造") # 搜索引擎模式
print u", ".join(seg_list)