use prefix dict instead of trie, add a command line interface, and a few small improvements

This commit is contained in:
Dingyuan Wang 2014-10-18 22:22:14 +08:00
parent eb98eb9248
commit 51df77831b
8 changed files with 331 additions and 317 deletions
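The central change is replacing the character-by-character trie with a flat "prefix dict": a set containing every prefix of every dictionary word. Membership of a fragment in that set answers the same question the trie did (can this fragment still grow into a dictionary word?) using plain set lookups. A minimal sketch of the idea, with toy placeholder words rather than real dict.txt entries:

    # Toy illustration of the prefix-dict idea (not the patch itself)
    words = [u"ABC", u"AB", u"X"]        # stand-ins for dictionary entries

    pfdict = set()
    for w in words:
        for i in xrange(len(w)):         # store every prefix: A, AB, ABC, X
            pfdict.add(w[:i+1])

    assert u"A" in pfdict and u"AB" in pfdict   # fragments that can still be extended
    assert u"AX" not in pfdict                  # dead end: the scan can stop here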

View File

@@ -17,14 +17,13 @@ import logging
 DICTIONARY = "dict.txt"
 DICT_LOCK = threading.RLock()
-trie = None # to be initialized
+pfdict = None # to be initialized
 FREQ = {}
 min_freq = 0.0
 total = 0.0
 user_word_tag_tab = {}
 initialized = False
 log_console = logging.StreamHandler(sys.stderr)
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
@@ -34,84 +33,79 @@ def setLogLevel(log_level):
     global logger
     logger.setLevel(log_level)

-def gen_trie(f_name):
+def gen_pfdict(f_name):
     lfreq = {}
-    trie = {}
+    pfdict = set()
     ltotal = 0.0
     with open(f_name, 'rb') as f:
         lineno = 0
         for line in f.read().rstrip().decode('utf-8').split('\n'):
             lineno += 1
             try:
-                word,freq,_ = line.split(' ')
+                word,freq = line.split(' ')[:2]
                 freq = float(freq)
                 lfreq[word] = freq
                 ltotal += freq
-                p = trie
-                for c in word:
-                    if c not in p:
-                        p[c] ={}
-                    p = p[c]
-                p['']='' #ending flag
+                for ch in xrange(len(word)):
+                    pfdict.add(word[:ch+1])
             except ValueError, e:
                 logger.debug('%s at line %s %s' % (f_name, lineno, line))
                 raise ValueError, e
-    return trie, lfreq,ltotal
+    return pfdict, lfreq, ltotal

 def initialize(*args):
-    global trie, FREQ, total, min_freq, initialized
-    if len(args)==0:
+    global pfdict, FREQ, total, min_freq, initialized
+    if not args:
         dictionary = DICTIONARY
     else:
         dictionary = args[0]
     with DICT_LOCK:
         if initialized:
             return
-        if trie:
-            del trie
-            trie = None
+        if pfdict:
+            del pfdict
+            pfdict = None
         _curpath = os.path.normpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
         abs_path = os.path.join(_curpath,dictionary)
-        logger.debug("Building Trie..., from %s" % abs_path)
+        logger.debug("Building prefix dict from %s ..." % abs_path)
         t1 = time.time()
-        if abs_path == os.path.join(_curpath,"dict.txt"): #defautl dictionary
+        if abs_path == os.path.join(_curpath, "dict.txt"): #default dictionary
            cache_file = os.path.join(tempfile.gettempdir(), "jieba.cache")
-        else: #customer dictionary
-            cache_file = os.path.join(tempfile.gettempdir(),"jieba.user."+str(hash(abs_path))+".cache")
+        else: #custom dictionary
+            cache_file = os.path.join(tempfile.gettempdir(), "jieba.user.%s.cache" % hash(abs_path))
         load_from_cache_fail = True
         if os.path.exists(cache_file) and os.path.getmtime(cache_file) > os.path.getmtime(abs_path):
-            logger.debug("loading model from cache %s" % cache_file)
+            logger.debug("Loading model from cache %s" % cache_file)
             try:
-                trie,FREQ,total,min_freq = marshal.load(open(cache_file,'rb'))
-                load_from_cache_fail = False
+                pfdict,FREQ,total,min_freq = marshal.load(open(cache_file,'rb'))
+                # prevent conflict with old version
+                load_from_cache_fail = not isinstance(pfdict, set)
             except:
                 load_from_cache_fail = True
         if load_from_cache_fail:
-            trie,FREQ,total = gen_trie(abs_path)
+            pfdict,FREQ,total = gen_pfdict(abs_path)
             FREQ = dict([(k,log(float(v)/total)) for k,v in FREQ.iteritems()]) #normalize
             min_freq = min(FREQ.itervalues())
-            logger.debug("dumping model to file cache %s" % cache_file)
+            logger.debug("Dumping model to file cache %s" % cache_file)
             try:
                 tmp_suffix = "."+str(random.random())
                 with open(cache_file+tmp_suffix,'wb') as temp_cache_file:
-                    marshal.dump((trie,FREQ,total,min_freq),temp_cache_file)
+                    marshal.dump((pfdict,FREQ,total,min_freq), temp_cache_file)
                 if os.name == 'nt':
-                    import shutil
-                    replace_file = shutil.move
+                    from shutil import move as replace_file
                 else:
                     replace_file = os.rename
                 replace_file(cache_file + tmp_suffix, cache_file)
             except:
-                logger.error("dump cache file failed.")
-                logger.exception("")
+                logger.exception("Dump cache file failed.")

         initialized = True
-        logger.debug("loading model cost %s seconds." % (time.time() - t1))
-        logger.debug("Trie has been built succesfully.")
+        logger.debug("Loading model cost %s seconds." % (time.time() - t1))
+        logger.debug("Prefix dict has been built succesfully.")

 def require_initialized(fn):
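The cache block above marshals the tuple (pfdict, FREQ, total, min_freq) into a temporary file and renames it into place, and on load it rejects caches produced by older trie-based releases because their first element is a dict rather than a set. A rough, self-contained sketch of that round trip with toy values (the cache path here is hypothetical, chosen so as not to touch the real jieba.cache):

    import marshal, os, tempfile

    pfdict, FREQ, total, min_freq = set([u"A"]), {u"A": -8.0}, 1.0, -8.0   # toy model
    cache_file = os.path.join(tempfile.gettempdir(), "jieba.cache.example")

    # dump to a temporary name first, then rename, so readers never see a partial cache
    with open(cache_file + ".tmp", 'wb') as f:
        marshal.dump((pfdict, FREQ, total, min_freq), f)
    os.rename(cache_file + ".tmp", cache_file)

    # load; an old-format cache stored a trie (dict) in the first slot, forcing a rebuild
    pfdict, FREQ, total, min_freq = marshal.load(open(cache_file, 'rb'))
    rebuild_needed = not isinstance(pfdict, set)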
@@ -151,30 +145,21 @@ def calc(sentence,DAG,idx,route):
 @require_initialized
 def get_DAG(sentence):
-    N = len(sentence)
-    i,j=0,0
-    p = trie
+    global pfdict, FREQ
     DAG = {}
-    while i<N:
-        c = sentence[j]
-        if c in p:
-            p = p[c]
-            if '' in p:
-                if i not in DAG:
-                    DAG[i]=[]
-                DAG[i].append(j)
-            j+=1
-            if j>=N:
-                i += 1
-                j=i
-                p=trie
-        else:
-            p = trie
-            i+=1
-            j=i
-    for i in xrange(len(sentence)):
-        if i not in DAG:
-            DAG[i] =[i]
+    N = len(sentence)
+    for k in xrange(N):
+        tmplist = []
+        i = k
+        frag = sentence[k]
+        while i < N and frag in pfdict:
+            if frag in FREQ:
+                tmplist.append(i)
+            i += 1
+            frag = sentence[k:i+1]
+        if not tmplist:
+            tmplist.append(k)
+        DAG[k] = tmplist
     return DAG
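With a prefix set, get_DAG simply extends the fragment sentence[k:i+1] while it remains a known prefix and records every end position whose fragment is an actual word. A self-contained toy version (one-letter placeholder "words") shows the shape of the result:

    def toy_get_DAG(sentence, pfdict, FREQ):
        # same control flow as the new get_DAG, with the data passed in explicitly
        DAG = {}
        N = len(sentence)
        for k in xrange(N):
            tmplist = []
            i = k
            frag = sentence[k]
            while i < N and frag in pfdict:
                if frag in FREQ:
                    tmplist.append(i)
                i += 1
                frag = sentence[k:i+1]
            if not tmplist:
                tmplist.append(k)
            DAG[k] = tmplist
        return DAG

    FREQ = {u"AB": 1, u"ABC": 1, u"C": 1}           # toy dictionary words
    pfdict = set([u"A", u"AB", u"ABC", u"C"])       # every prefix of every word
    print toy_get_DAG(u"ABC", pfdict, FREQ)         # {0: [1, 2], 1: [1], 2: [2]}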
def __cut_DAG_NO_HMM(sentence):
@@ -192,12 +177,12 @@ def __cut_DAG_NO_HMM(sentence):
             buf += l_word
             x = y
         else:
-            if len(buf)>0:
+            if buf:
                 yield buf
                 buf = u''
             yield l_word
             x = y

-    if len(buf)>0:
+    if buf:
         yield buf
         buf = u''
@@ -214,14 +199,14 @@ def __cut_DAG(sentence):
         if y-x == 1:
             buf += l_word
         else:
-            if len(buf)>0:
+            if buf:
                 if len(buf) == 1:
                     yield buf
                     buf = u''
                 else:
                     if (buf not in FREQ):
-                        regognized = finalseg.cut(buf)
-                        for t in regognized:
+                        recognized = finalseg.cut(buf)
+                        for t in recognized:
                             yield t
                     else:
                         for elem in buf:
@@ -230,13 +215,12 @@ def __cut_DAG(sentence):
             yield l_word
             x = y

-    if len(buf)>0:
+    if buf:
         if len(buf) == 1:
             yield buf
-        else:
-            if (buf not in FREQ):
-                regognized = finalseg.cut(buf)
-                for t in regognized:
-                    yield t
-            else:
-                for elem in buf:
+        elif (buf not in FREQ):
+            recognized = finalseg.cut(buf)
+            for t in recognized:
+                yield t
+        else:
+            for elem in buf:
@@ -246,31 +230,32 @@ def cut(sentence,cut_all=False,HMM=True):
     '''The main function that segments an entire sentence that contains
     Chinese characters into seperated words.
     Parameter:
-        - sentence: The String to be segmented
-        - cut_all: Model. True means full pattern, false means accurate pattern.
-        - HMM: Whether use Hidden Markov Model.
+        - sentence: The str/unicode to be segmented.
+        - cut_all: Model type. True for full pattern, False for accurate pattern.
+        - HMM: Whether to use the Hidden Markov Model.
     '''
     if not isinstance(sentence, unicode):
         try:
             sentence = sentence.decode('utf-8')
         except UnicodeDecodeError:
             sentence = sentence.decode('gbk', 'ignore')
-    '''
-    \u4E00-\u9FA5a-zA-Z0-9+#&\._ : All non-space characters. Will be handled with re_han
-    \r\n|\s : whitespace characters. Will not be Handled.
-    '''
-    re_han, re_skip = re.compile(ur"([\u4E00-\u9FA5a-zA-Z0-9+#&\._]+)", re.U), re.compile(ur"(\r\n|\s)", re.U)
+    # \u4E00-\u9FA5a-zA-Z0-9+#&\._ : All non-space characters. Will be handled with re_han
+    # \r\n|\s : whitespace characters. Will not be handled.
     if cut_all:
         re_han, re_skip = re.compile(ur"([\u4E00-\u9FA5]+)", re.U), re.compile(ur"[^a-zA-Z0-9+#\n]", re.U)
+    else:
+        re_han, re_skip = re.compile(ur"([\u4E00-\u9FA5a-zA-Z0-9+#&\._]+)", re.U), re.compile(ur"(\r\n|\s)", re.U)
     blocks = re_han.split(sentence)
-    if HMM:
+    if cut_all:
+        cut_block = __cut_all
+    elif HMM:
         cut_block = __cut_DAG
     else:
         cut_block = __cut_DAG_NO_HMM
-    if cut_all:
-        cut_block = __cut_all
     for blk in blocks:
-        if len(blk)==0:
+        if not blk:
             continue
         if re_han.match(blk):
             for word in cut_block(blk):
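The public cut() API is unchanged; the dispatch above just selects __cut_all, __cut_DAG or __cut_DAG_NO_HMM. A small Python 2 usage sketch of the three modes (exact output depends on the bundled dictionary):

    # -*- coding: utf-8 -*-
    import jieba

    sent = u"我来到北京清华大学"
    print "/ ".join(jieba.cut(sent))                 # accurate mode, HMM on (default)
    print "/ ".join(jieba.cut(sent, HMM=False))      # accurate mode, no HMM for unknown words
    print "/ ".join(jieba.cut(sent, cut_all=True))   # full mode: every dictionary word found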
@@ -312,37 +297,30 @@ def load_userdict(f):
     ...
     Word type may be ignored
     '''
-    global trie,total,FREQ
     if isinstance(f, (str, unicode)):
         f = open(f, 'rb')
     content = f.read().decode('utf-8')
     line_no = 0
     for line in content.split("\n"):
         line_no += 1
-        if line.rstrip()=='': continue
+        if not line.rstrip():
+            continue
         tup = line.split(" ")
         word, freq = tup[0], tup[1]
-        if freq.isdigit() is False: continue
+        if freq.isdigit() is False:
+            continue
         if line_no == 1:
             word = word.replace(u'\ufeff',u"") #remove bom flag if it exists
-        if len(tup)==3:
-            add_word(word, freq, tup[2])
-        else:
-            add_word(word, freq)
+        add_word(*tup)

 @require_initialized
 def add_word(word, freq, tag=None):
-    global FREQ, trie, total, user_word_tag_tab
-    freq = float(freq)
-    FREQ[word] = log(freq / total)
+    global FREQ, pfdict, total, user_word_tag_tab
+    FREQ[word] = log(float(freq) / total)
     if tag is not None:
         user_word_tag_tab[word] = tag.strip()
-    p = trie
-    for c in word:
-        if c not in p:
-            p[c] = {}
-        p = p[c]
-    p[''] = '' # ending flag
+    for ch in xrange(len(word)):
+        pfdict.add(word[:ch+1])

 __ref_cut = cut
 __ref_cut_for_search = cut_for_search
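load_userdict now forwards the whole split line to add_word, so both "word freq" and "word freq tag" entries go through one call, and add_word feeds the new prefix set directly. A hedged usage sketch (userdict.txt is a hypothetical file in the "word freq [tag]" format the docstring describes):

    # -*- coding: utf-8 -*-
    import jieba

    jieba.add_word(u"云计算", 5, tag="n")      # register one word programmatically
    jieba.load_userdict("userdict.txt")        # or load a whole "word freq [tag]" file
    print "/ ".join(jieba.cut(u"云计算时代来临"))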
@@ -365,7 +343,7 @@ def enable_parallel(processnum=None):
     if sys.version_info[0]==2 and sys.version_info[1]<6:
         raise Exception("jieba: the parallel feature needs Python version>2.5")
     from multiprocessing import Pool, cpu_count
-    if processnum==None:
+    if processnum is None:
         processnum = cpu_count()
     pool = Pool(processnum)
@@ -373,8 +351,7 @@ def enable_parallel(processnum=None):
         parts = re.compile('([\r\n]+)').split(sentence)
         if cut_all:
             result = pool.map(__lcut_all, parts)
-        else:
-            if HMM:
-                result = pool.map(__lcut, parts)
-            else:
-                result = pool.map(__lcut_no_hmm, parts)
+        elif HMM:
+            result = pool.map(__lcut, parts)
+        else:
+            result = pool.map(__lcut_no_hmm, parts)
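enable_parallel splits the input on line breaks and maps the per-line cut functions over a multiprocessing pool; the nested else/if above is flattened into the same cut_all/HMM dispatch used elsewhere. A minimal usage sketch (parallel mode relies on fork-style multiprocessing, so it is not expected to work on Windows; big_corpus.txt is a hypothetical input file):

    # -*- coding: utf-8 -*-
    import jieba

    jieba.enable_parallel(4)     # pool of 4 worker processes
    text = open("big_corpus.txt", "rb").read().decode("utf-8")
    result = " / ".join(jieba.cut(text))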
@@ -415,9 +392,14 @@ def get_abs_path_dict():
     return abs_path

 def tokenize(unicode_sentence, mode="default", HMM=True):
-    #mode ("default" or "search")
+    """Tokenize a sentence and yields tuples of (word, start, end)
+    Parameter:
+        - sentence: the unicode to be segmented.
+        - mode: "default" or "search", "search" is for finer segmentation.
+        - HMM: whether to use the Hidden Markov Model.
+    """
     if not isinstance(unicode_sentence, unicode):
-        raise Exception("jieba: the input parameter should unicode.")
+        raise Exception("jieba: the input parameter should be unicode.")
     start = 0
     if mode == 'default':
         for w in cut(unicode_sentence, HMM=HMM):
@@ -439,4 +421,3 @@ def tokenize(unicode_sentence,mode="default",HMM=True):
                 yield (gram3, start+i, start+i+3)
         yield (w, start, start+width)
         start += width
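tokenize gains a real docstring in place of the bare mode comment. A short usage sketch; positions are character offsets into the unicode input, and "search" mode additionally yields the shorter sub-words contained in long tokens:

    # -*- coding: utf-8 -*-
    import jieba

    for word, start, end in jieba.tokenize(u"永和服装饰品有限公司"):
        print "%s\t start: %d\t end: %d" % (word, start, end)

    for word, start, end in jieba.tokenize(u"永和服装饰品有限公司", mode="search"):
        print "%s\t start: %d\t end: %d" % (word, start, end)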

jieba/__main__.py (new file, 35 lines)
View File

@@ -0,0 +1,35 @@
+"""Jieba command line interface."""
+import sys
+import jieba
+from argparse import ArgumentParser
+
+parser = ArgumentParser(usage="%s -m jieba [options] filename" % sys.executable, description="Jieba command line interface.", version="Jieba " + jieba.__version__, epilog="If no filename specified, use STDIN instead.")
+parser.add_argument("-d", "--delimiter", metavar="DELIM", default=' / ',
+                    nargs='?', const=' ',
+                    help="use DELIM instead of ' / ' for word delimiter; use a space if it is without DELIM")
+parser.add_argument("-a", "--cut-all",
+                    action="store_true", dest="cutall", default=False,
+                    help="full pattern cutting")
+parser.add_argument("-n", "--no-hmm", dest="hmm", action="store_false",
+                    default=True, help="don't use the Hidden Markov Model")
+parser.add_argument("-q", "--quiet", action="store_true", default=False,
+                    help="don't print loading messages to stderr")
+parser.add_argument("filename", nargs='?', help="input file")
+args = parser.parse_args()
+
+if args.quiet:
+    jieba.setLogLevel(60)
+delim = unicode(args.delimiter)
+cutall = args.cutall
+hmm = args.hmm
+fp = open(args.filename, 'r') if args.filename else sys.stdin
+
+jieba.initialize()
+ln = fp.readline()
+while ln:
+    l = ln.rstrip('\r\n')
+    print(delim.join(jieba.cut(ln.rstrip('\r\n'), cutall, hmm)).encode('utf-8'))
+    ln = fp.readline()
+
+fp.close()
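With this module in place the segmenter can be driven from a shell as python -m jieba. A few illustrative invocations built only from the options defined above (news.txt is a hypothetical input file):

    python -m jieba news.txt > cut_result.txt    # segment a file, ' / ' delimiter by default
    python -m jieba -d ' ' -a -q news.txt        # space delimiter, full mode, quiet logging
    python -m jieba < news.txt                   # no filename given: read from STDIN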

View File

@@ -9,9 +9,11 @@ except ImportError:
 _curpath = os.path.normpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
 abs_path = os.path.join(_curpath, "idf.txt")

-STOP_WORDS = set([
-    "the","of","is","and","to","in","that","we","for","an","are","by","be","as","on","with","can","if","from","which","you","it","this","then","at","have","all","not","one","has","or","that"
-])
+STOP_WORDS = set((
+    "the","of","is","and","to","in","that","we","for","an","are",
+    "by","be","as","on","with","can","if","from","which","you","it",
+    "this","then","at","have","all","not","one","has","or","that"
+))

 class IDFLoader:
     def __init__(self):
@@ -45,7 +47,6 @@ def set_idf_path(idf_path):
     if not os.path.exists(new_abs_path):
         raise Exception("jieba: path does not exist: " + new_abs_path)
     idf_loader.set_new_path(new_abs_path)
-    return

 def set_stop_words(stop_words_path):
     global STOP_WORDS
@@ -56,7 +57,6 @@ def set_stop_words(stop_words_path):
     lines = content.split('\n')
     for line in lines:
         STOP_WORDS.add(line)
-    return

 def extract_tags(sentence, topK=20):
     global STOP_WORDS
@@ -66,8 +66,10 @@ def extract_tags(sentence,topK=20):
     words = jieba.cut(sentence)
     freq = {}
     for w in words:
-        if len(w.strip())<2: continue
-        if w.lower() in STOP_WORDS: continue
+        if len(w.strip()) < 2:
+            continue
+        if w.lower() in STOP_WORDS:
+            continue
         freq[w] = freq.get(w, 0.0) + 1.0
     total = sum(freq.values())
     freq = [(k,v/total) for k,v in freq.iteritems()]
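extract_tags builds a normalized term-frequency table over the cut words, skipping stop words and single characters; the IDF weighting follows further down in the function. A hedged usage sketch:

    # -*- coding: utf-8 -*-
    import jieba.analyse

    text = u"我来到北京清华大学,清华大学是著名的高等学府"
    for tag in jieba.analyse.extract_tags(text, topK=5):
        print tag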

View File

@@ -1,4 +1,4 @@
-#encoding=utf-8
+##encoding=utf-8
 from whoosh.analysis import RegexAnalyzer,LowercaseFilter,StopFilter,StemFilter
 from whoosh.analysis import Tokenizer,Token
 from whoosh.lang.porter import stem
@@ -19,10 +19,7 @@ class ChineseTokenizer(Tokenizer):
         words = jieba.tokenize(text, mode="search")
         token = Token()
         for (w,start_pos,stop_pos) in words:
-            if not accepted_chars.match(w):
-                if len(w)>1:
-                    pass
-                else:
-                    continue
+            if not accepted_chars.match(w) and len(w)<=1:
+                continue
             token.original = token.text = w
             token.pos = start_pos
@@ -31,5 +28,6 @@ class ChineseTokenizer(Tokenizer):
             yield token

 def ChineseAnalyzer(stoplist=STOP_WORDS, minsize=1, stemfn=stem, cachesize=50000):
-    return ChineseTokenizer() | LowercaseFilter() | StopFilter(stoplist=stoplist,minsize=minsize)\
-        |StemFilter(stemfn=stemfn, ignore=None,cachesize=cachesize)
+    return (ChineseTokenizer() | LowercaseFilter() |
+            StopFilter(stoplist=stoplist,minsize=minsize) |
+            StemFilter(stemfn=stemfn, ignore=None,cachesize=cachesize))
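The analyzer pipeline itself is unchanged, only rewritten as a parenthesized expression instead of a backslash continuation. A usage sketch for Whoosh, assuming ChineseAnalyzer is importable from jieba.analyse as in jieba's whoosh test script:

    # -*- coding: utf-8 -*-
    from whoosh.fields import Schema, TEXT, ID
    from jieba.analyse import ChineseAnalyzer

    analyzer = ChineseAnalyzer()
    schema = Schema(title=TEXT(stored=True),
                    path=ID(stored=True),
                    content=TEXT(stored=True, analyzer=analyzer))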

View File

@@ -53,7 +53,7 @@ def viterbi(obs, states, start_p, trans_p, emit_p):
     for y in states: #init
         V[0][y] = start_p[y] + emit_p[y].get(obs[0],MIN_FLOAT)
         path[y] = [y]
-    for t in range(1,len(obs)):
+    for t in xrange(1,len(obs)):
         V.append({})
         newpath = {}
         for y in states:
@@ -87,10 +87,10 @@ def __cut(sentence):
         yield sentence[next:]

 def cut(sentence):
-    if not ( type(sentence) is unicode):
+    if not isinstance(sentence, unicode):
         try:
             sentence = sentence.decode('utf-8')
-        except:
+        except UnicodeDecodeError:
             sentence = sentence.decode('gbk', 'ignore')
     re_han, re_skip = re.compile(ur"([\u4E00-\u9FA5]+)"), re.compile(ur"(\d+\.\d+|[a-zA-Z0-9]+)")
     blocks = re_han.split(sentence)
@@ -101,5 +101,5 @@ def cut(sentence):
         else:
             tmp = re_skip.split(blk)
             for x in tmp:
-                if x!="":
+                if x:
                     yield x
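finalseg is the HMM fallback that the main segmenter applies to character runs the dictionary cannot explain; it can also be called directly. A small sketch:

    # -*- coding: utf-8 -*-
    from jieba import finalseg

    print "/ ".join(finalseg.cut(u"他来到了网易杭研大厦"))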

View File

@@ -78,7 +78,7 @@ class pair(object):
         self.flag = flag

     def __unicode__(self):
-        return self.word+u"/"+self.flag
+        return u'%s/%s' % (self.word, self.flag)

     def __repr__(self):
         return self.__str__()
@@ -117,7 +117,7 @@ def __cut_detail(sentence):
         else:
             tmp = re_skip.split(blk)
             for x in tmp:
-                if x!="":
+                if x:
                     if re_num.match(x):
                         yield pair(x, 'm')
                     elif re_eng.match(x):
@@ -140,12 +140,12 @@ def __cut_DAG_NO_HMM(sentence):
             buf += l_word
             x = y
         else:
-            if len(buf)>0:
+            if buf:
                 yield pair(buf,'eng')
                 buf = u''
             yield pair(l_word, word_tag_tab.get(l_word, 'x'))
             x = y

-    if len(buf)>0:
+    if buf:
         yield pair(buf,'eng')
         buf = u''
@@ -164,14 +164,14 @@ def __cut_DAG(sentence):
         if y-x == 1:
             buf += l_word
         else:
-            if len(buf)>0:
+            if buf:
                 if len(buf) == 1:
                     yield pair(buf, word_tag_tab.get(buf, 'x'))
                     buf = u''
                 else:
                     if (buf not in jieba.FREQ):
-                        regognized = __cut_detail(buf)
-                        for t in regognized:
+                        recognized = __cut_detail(buf)
+                        for t in recognized:
                             yield t
                     else:
                         for elem in buf:
@@ -180,23 +180,22 @@ def __cut_DAG(sentence):
             yield pair(l_word, word_tag_tab.get(l_word, 'x'))
             x = y

-    if len(buf)>0:
+    if buf:
         if len(buf) == 1:
             yield pair(buf, word_tag_tab.get(buf, 'x'))
-        else:
-            if (buf not in jieba.FREQ):
-                regognized = __cut_detail(buf)
-                for t in regognized:
-                    yield t
-            else:
-                for elem in buf:
-                    yield pair(elem, word_tag_tab.get(elem, 'x'))
+        elif (buf not in jieba.FREQ):
+            recognized = __cut_detail(buf)
+            for t in recognized:
+                yield t
+        else:
+            for elem in buf:
+                yield pair(elem, word_tag_tab.get(elem, 'x'))

 def __cut_internal(sentence, HMM=True):
-    if not ( type(sentence) is unicode):
+    if not isinstance(sentence, unicode):
         try:
             sentence = sentence.decode('utf-8')
-        except:
+        except UnicodeDecodeError:
             sentence = sentence.decode('gbk', 'ignore')
     re_han, re_skip = re.compile(ur"([\u4E00-\u9FA5a-zA-Z0-9+#&\._]+)"), re.compile(ur"(\r\n|\s)")
     re_eng, re_num = re.compile(ur"[a-zA-Z0-9]+"), re.compile(ur"[\.0-9]+")
@@ -232,7 +231,7 @@ def __lcut_internal_no_hmm(sentence):
 @makesure_userdict_loaded
 def cut(sentence, HMM=True):
-    if (not hasattr(jieba,'pool')) or (jieba.pool==None):
+    if (not hasattr(jieba, 'pool')) or (jieba.pool is None):
         for w in __cut_internal(sentence, HMM=HMM):
             yield w
     else:
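pair.__unicode__ now uses a format string, and the POS-tagging cut() only delegates to the process pool when parallel mode is enabled. A usage sketch for the module:

    # -*- coding: utf-8 -*-
    import jieba.posseg as pseg

    for w in pseg.cut(u"我爱北京天安门"):
        print w.word, w.flag     # each item is a pair(word, flag)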

View File

@@ -14,28 +14,27 @@ def viterbi(obs, states, start_p, trans_p, emit_p):
     for y in states.get(obs[0], all_states): #init
         V[0][y] = start_p[y] + emit_p[y].get(obs[0], MIN_FLOAT)
         mem_path[0][y] = ''
-    for t in range(1,len(obs)):
+    for t in xrange(1, len(obs)):
         V.append({})
         mem_path.append({})
         #prev_states = get_top_states(V[t-1])
         prev_states = [x for x in mem_path[t-1].keys() if len(trans_p[x]) > 0]
         prev_states_expect_next = set((y for x in prev_states for y in trans_p[x].keys()))
-        obs_states = states.get(obs[t],all_states)
-        obs_states = set(obs_states) & set(prev_states_expect_next)
-        if len(obs_states)==0: obs_states = prev_states_expect_next
-        if len(obs_states)==0: obs_states = all_states
+        obs_states = set(states.get(obs[t], all_states)) & prev_states_expect_next
+        if not obs_states:
+            obs_states = prev_states_expect_next if prev_states_expect_next else all_states
         for y in obs_states:
-            (prob,state ) = max([(V[t-1][y0] + trans_p[y0].get(y,MIN_INF) + emit_p[y].get(obs[t],MIN_FLOAT) ,y0) for y0 in prev_states])
+            prob, state = max([(V[t-1][y0] + trans_p[y0].get(y,MIN_INF) + emit_p[y].get(obs[t],MIN_FLOAT), y0) for y0 in prev_states])
             V[t][y] = prob
             mem_path[t][y] = state

     last = [(V[-1][y], y) for y in mem_path[-1].keys()]
     #if len(last)==0:
     #    print obs
-    (prob, state) = max(last)
+    prob, state = max(last)
     route = [None] * len(obs)
     i = len(obs) - 1

View File

@@ -4,14 +4,14 @@ sys.path.append("../")
 import jieba

-seg_list = jieba.cut("我来到北京清华大学", cut_all=True)
-print "Full Mode:", "/ ".join(seg_list) # 全模式
+seg_list = jieba.cut(u"我来到北京清华大学", cut_all=True)
+print u"Full Mode:", u"/ ".join(seg_list) # 全模式

-seg_list = jieba.cut("我来到北京清华大学", cut_all=False)
-print "Default Mode:", "/ ".join(seg_list) # 默认模式
+seg_list = jieba.cut(u"我来到北京清华大学", cut_all=False)
+print u"Default Mode:", u"/ ".join(seg_list) # 默认模式

-seg_list = jieba.cut("他来到了网易杭研大厦")
-print ", ".join(seg_list)
+seg_list = jieba.cut(u"他来到了网易杭研大厦")
+print u", ".join(seg_list)

-seg_list = jieba.cut_for_search("小明硕士毕业于中国科学院计算所,后在日本京都大学深造") # 搜索引擎模式
-print ", ".join(seg_list)
+seg_list = jieba.cut_for_search(u"小明硕士毕业于中国科学院计算所,后在日本京都大学深造") # 搜索引擎模式
+print u", ".join(seg_list)