jieba/jieba/finalseg/__init__.py
2014-11-15 13:44:30 +08:00

102 lines
2.7 KiB
Python

import re
import os
import marshal
import sys
# Score substituted for any emission/transition missing from the model tables.
# viterbi() sums scores, so this acts as a "practically impossible" floor.
MIN_FLOAT = -3.14e100

# Marshal-serialized model tables read by load_model().
PROB_START_P = "prob_start.p"
PROB_TRANS_P = "prob_trans.p"
PROB_EMIT_P = "prob_emit.p"

# BMES tagging: B = begin of word, M = middle, E = end, S = single-char word.
# Maps each state to the states that may directly precede it; viterbi() only
# considers these predecessors when extending a path.
PrevStatus = dict(
    B=('E', 'S'),
    M=('M', 'B'),
    S=('S', 'E'),
    E=('B', 'M'),
)
def load_model():
    """Load the HMM start, transition and emission tables from marshal files.

    The files PROB_START_P, PROB_TRANS_P and PROB_EMIT_P are read, in that
    order, from the directory containing this module.

    Returns:
        tuple: (start_p, trans_p, emit_p), each the object produced by
        marshal.load on the corresponding file.
    """
    module_dir = os.path.dirname(__file__)
    base_dir = os.path.normpath(os.path.join(os.getcwd(), module_dir))
    tables = []
    for file_name in (PROB_START_P, PROB_TRANS_P, PROB_EMIT_P):
        with open(os.path.join(base_dir, file_name), 'rb') as fh:
            tables.append(marshal.load(fh))
    return tuple(tables)
# Module-level model tables consumed by __cut()/viterbi().
if not sys.platform.startswith("java"):
    # Default path: the tables live in sibling modules of this package.
    from . import prob_start, prob_trans, prob_emit
    start_P, trans_P, emit_P = prob_start.P, prob_trans.P, prob_emit.P
else:
    # Jython: read the tables from marshal files instead of importing the
    # prob_* modules (presumably those modules are unsuitable on Jython —
    # NOTE(review): confirm the original motivation).
    start_P, trans_P, emit_P = load_model()
def viterbi(obs, states, start_p, trans_p, emit_p):
    """Return the best-scoring state path for *obs* (Viterbi decoding).

    Scores are combined by addition, so the tables are presumably
    log-probabilities. Missing emission/transition entries score MIN_FLOAT.
    A state y can only be reached from the states listed in PrevStatus[y],
    and the final state must be 'E' or 'S'.

    Args:
        obs: Sequence of observations (indexable, e.g. a str of characters).
        states: Iterable of state labels to consider.
        start_p: Mapping state -> initial score.
        trans_p: Mapping state -> {next_state: transition score}.
        emit_p: Mapping state -> {observation: emission score}.

    Returns:
        tuple: (best_score, path) where path is a list with one state label
        per element of *obs*. Ties are broken by comparing (score, state)
        tuples, so the larger state label wins.

    Raises:
        IndexError: if *obs* is empty (and *states* is non-empty).
    """
    V = [{}]          # V[t][y]: best score of any path ending in y at step t
    backptr = [{}]    # backptr[t][y]: predecessor of y on that best path
    for y in states:
        V[0][y] = start_p[y] + emit_p[y].get(obs[0], MIN_FLOAT)
    for t in range(1, len(obs)):
        V.append({})
        backptr.append({})
        for y in states:
            em_p = emit_p[y].get(obs[t], MIN_FLOAT)
            prob, state = max(
                (V[t - 1][y0] + trans_p[y0].get(y, MIN_FLOAT) + em_p, y0)
                for y0 in PrevStatus[y]
            )
            V[t][y] = prob
            backptr[t][y] = state
    prob, state = max((V[len(obs) - 1][y], y) for y in ('E', 'S'))
    # Walk the back-pointers from the last step to rebuild the path in O(n),
    # instead of copying a growing path list for every state at every step.
    best_path = [state]
    for t in range(len(obs) - 1, 0, -1):
        state = backptr[t][state]
        best_path.append(state)
    best_path.reverse()
    return (prob, best_path)
def __cut(sentence):
    """Split a run of Chinese characters into words using the HMM.

    Tags every character of *sentence* with a BMES state via viterbi() and
    the module-level start_P / trans_P / emit_P tables, then yields words:
    'B' opens a multi-character word, 'E' closes it (yielding the slice from
    the opening 'B'), 'S' yields the character on its own, and 'M' continues
    the current word.

    Yields:
        str: consecutive words whose concatenation is *sentence*.
    """
    # Only the tag sequence is needed; the path score is discarded.
    _, pos_list = viterbi(sentence, ('B', 'M', 'E', 'S'), start_P, trans_P, emit_P)
    begin, next_start = 0, 0
    for i, char in enumerate(sentence):
        pos = pos_list[i]
        if pos == 'B':
            begin = i
        elif pos == 'E':
            yield sentence[begin:i + 1]
            next_start = i + 1
        elif pos == 'S':
            yield char
            next_start = i + 1
    # viterbi() always ends on 'E' or 'S', so this is a defensive fallback
    # that emits any characters not yet yielded as one final word.
    if next_start < len(sentence):
        yield sentence[next_start:]
# Runs of CJK unified ideographs (U+4E00..U+9FA5), captured so split() keeps them.
_RE_HAN = re.compile("([\u4E00-\u9FA5]+)")
# Decimal numbers or ASCII alphanumeric runs; raw string so \d and \. reach
# the regex engine without triggering invalid-escape warnings.
_RE_SKIP = re.compile(r"(\d+\.\d+|[a-zA-Z0-9]+)")


def cut(sentence):
    """Segment *sentence* into words, yielding them one at a time.

    Non-str input is decoded as UTF-8, falling back to GBK with undecodable
    bytes dropped. Runs of Chinese characters (U+4E00..U+9FA5) are segmented
    with the HMM via __cut(). All other text is split into decimal numbers,
    ASCII alphanumeric runs and the pieces between them. Empty pieces are
    dropped.

    Yields:
        str: the segmented words, in order.
    """
    if not isinstance(sentence, str):
        try:
            sentence = sentence.decode('utf-8')
        except UnicodeDecodeError:
            sentence = sentence.decode('gbk', 'ignore')
    for blk in _RE_HAN.split(sentence):
        if _RE_HAN.match(blk):
            for word in __cut(blk):
                yield word
        else:
            for x in _RE_SKIP.split(blk):
                if x:
                    yield x