don't compile re every time; autopep8

Dingyuan Wang 2015-02-10 21:22:34 +08:00
parent 22bcf8be7a
commit 32a0e92a09
4 changed files with 155 additions and 90 deletions
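The change itself is simple: regular expressions that used to be built with re.compile() inside function bodies are hoisted to module level, so each pattern is compiled once at import time instead of on every call, and autopep8 normalizes the surrounding whitespace. A minimal sketch of that pattern follows; the function and pattern names are illustrative only, not jieba's actual identifiers.

import re

# Before: the pattern is recompiled on every invocation.
def split_tokens_old(text):
    re_word = re.compile(r"[a-zA-Z0-9]+", re.U)  # compiled each call
    return re_word.findall(text)

# After: compiled once at module level and reused by every call.
re_word = re.compile(r"[a-zA-Z0-9]+", re.U)

def split_tokens_new(text):
    return re_word.findall(text)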

View File

@ -18,24 +18,27 @@ from . import finalseg
DICTIONARY = "dict.txt"
DICT_LOCK = threading.RLock()
pfdict = None # to be initialized
pfdict = None # to be initialized
FREQ = {}
total = 0
user_word_tag_tab = {}
initialized = False
pool = None
_curpath = os.path.normpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
_curpath = os.path.normpath(
os.path.join(os.getcwd(), os.path.dirname(__file__)))
log_console = logging.StreamHandler(sys.stderr)
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logger.addHandler(log_console)
def setLogLevel(log_level):
global logger
logger.setLevel(log_level)
def gen_pfdict(f_name):
lfreq = {}
pfdict = set()
@ -50,12 +53,13 @@ def gen_pfdict(f_name):
lfreq[word] = freq
ltotal += freq
for ch in xrange(len(word)):
pfdict.add(word[:ch+1])
pfdict.add(word[:ch + 1])
except ValueError as e:
logger.debug('%s at line %s %s' % (f_name, lineno, line))
raise e
return pfdict, lfreq, ltotal
def initialize(dictionary=None):
global pfdict, FREQ, total, initialized, DICTIONARY, DICT_LOCK
if not dictionary:
@ -67,10 +71,12 @@ def initialize(dictionary=None):
abs_path = os.path.join(_curpath, dictionary)
logger.debug("Building prefix dict from %s ..." % abs_path)
t1 = time.time()
if abs_path == os.path.join(_curpath, "dict.txt"): #default dictionary
# default dictionary
if abs_path == os.path.join(_curpath, "dict.txt"):
cache_file = os.path.join(tempfile.gettempdir(), "jieba.cache")
else: #custom dictionary
cache_file = os.path.join(tempfile.gettempdir(), "jieba.u%s.cache" % md5(abs_path.encode('utf-8', 'replace')).hexdigest())
else: # custom dictionary
cache_file = os.path.join(tempfile.gettempdir(), "jieba.u%s.cache" % md5(
abs_path.encode('utf-8', 'replace')).hexdigest())
load_from_cache_fail = True
if os.path.isfile(cache_file) and os.path.getmtime(cache_file) > os.path.getmtime(abs_path):
@ -121,22 +127,24 @@ def require_initialized(fn):
def __cut_all(sentence):
dag = get_DAG(sentence)
old_j = -1
for k,L in iteritems(dag):
for k, L in iteritems(dag):
if len(L) == 1 and k > old_j:
yield sentence[k:L[0]+1]
yield sentence[k:L[0] + 1]
old_j = L[0]
else:
for j in L:
if j > k:
yield sentence[k:j+1]
yield sentence[k:j + 1]
old_j = j
def calc(sentence, DAG, route):
N = len(sentence)
route[N] = (0.0, '')
for idx in xrange(N-1, -1, -1):
route[idx] = max((log(FREQ.get(sentence[idx:x+1], 1)) - log(total) + route[x+1][0], x) for x in DAG[idx])
for idx in xrange(N - 1, -1, -1):
route[idx] = max((log(FREQ.get(sentence[idx:x + 1], 1)) -
log(total) + route[x + 1][0], x) for x in DAG[idx])
@require_initialized
def get_DAG(sentence):
@ -151,14 +159,16 @@ def get_DAG(sentence):
if frag in FREQ:
tmplist.append(i)
i += 1
frag = sentence[k:i+1]
frag = sentence[k:i + 1]
if not tmplist:
tmplist.append(k)
DAG[k] = tmplist
return DAG
re_eng = re.compile(r'[a-zA-Z0-9]', re.U)
def __cut_DAG_NO_HMM(sentence):
re_eng = re.compile(r'[a-zA-Z0-9]',re.U)
DAG = get_DAG(sentence)
route = {}
calc(sentence, DAG, route)
@ -181,6 +191,7 @@ def __cut_DAG_NO_HMM(sentence):
yield buf
buf = ''
def __cut_DAG(sentence):
DAG = get_DAG(sentence)
route = {}
@ -189,9 +200,9 @@ def __cut_DAG(sentence):
buf = ''
N = len(sentence)
while x < N:
y = route[x][1]+1
y = route[x][1] + 1
l_word = sentence[x:y]
if y-x == 1:
if y - x == 1:
buf += l_word
else:
if buf:
@ -221,6 +232,12 @@ def __cut_DAG(sentence):
for elem in buf:
yield elem
re_han_default = re.compile("([\u4E00-\u9FA5a-zA-Z0-9+#&\._]+)", re.U)
re_skip_default = re.compile("(\r\n|\s)", re.U)
re_han_cut_all = re.compile("([\u4E00-\u9FA5]+)", re.U)
re_skip_cut_all = re.compile("[^a-zA-Z0-9+#\n]", re.U)
def cut(sentence, cut_all=False, HMM=True):
'''The main function that segments an entire sentence that contains
Chinese characters into separated words.
@ -235,9 +252,11 @@ def cut(sentence, cut_all=False, HMM=True):
# \r\n|\s : whitespace characters. Will not be handled.
if cut_all:
re_han, re_skip = re.compile("([\u4E00-\u9FA5]+)", re.U), re.compile("[^a-zA-Z0-9+#\n]", re.U)
re_han = re_han_cut_all
re_skip = re_skip_cut_all
else:
re_han, re_skip = re.compile("([\u4E00-\u9FA5a-zA-Z0-9+#&\._]+)", re.U), re.compile("(\r\n|\s)", re.U)
re_han = re_han_default
re_skip = re_skip_default
blocks = re_han.split(sentence)
if cut_all:
cut_block = __cut_all
@ -262,21 +281,23 @@ def cut(sentence, cut_all=False, HMM=True):
else:
yield x
def cut_for_search(sentence, HMM=True):
words = cut(sentence, HMM=HMM)
for w in words:
if len(w) > 2:
for i in xrange(len(w)-1):
gram2 = w[i:i+2]
for i in xrange(len(w) - 1):
gram2 = w[i:i + 2]
if gram2 in FREQ:
yield gram2
if len(w) > 3:
for i in xrange(len(w)-2):
gram3 = w[i:i+3]
for i in xrange(len(w) - 2):
gram3 = w[i:i + 3]
if gram3 in FREQ:
yield gram3
yield w
@require_initialized
def load_userdict(f):
''' Load personalized dict to improve detection rate.
@ -300,6 +321,7 @@ def load_userdict(f):
if tup[1].isdigit():
add_word(*tup)
@require_initialized
def add_word(word, freq, tag=None):
global FREQ, pfdict, total, user_word_tag_tab
@ -309,17 +331,24 @@ def add_word(word, freq, tag=None):
if tag is not None:
user_word_tag_tab[word] = tag
for ch in xrange(len(word)):
pfdict.add(word[:ch+1])
pfdict.add(word[:ch + 1])
__ref_cut = cut
__ref_cut_for_search = cut_for_search
def __lcut(sentence):
return list(__ref_cut(sentence, False))
def __lcut_no_hmm(sentence):
return list(__ref_cut(sentence, False, False))
def __lcut_all(sentence):
return list(__ref_cut(sentence, True))
def __lcut_for_search(sentence):
return list(__ref_cut_for_search(sentence))
@ -334,7 +363,7 @@ def enable_parallel(processnum=None):
processnum = cpu_count()
pool = Pool(processnum)
def pcut(sentence,cut_all=False,HMM=True):
def pcut(sentence, cut_all=False, HMM=True):
parts = strdecode(sentence).split('\n')
if cut_all:
result = pool.map(__lcut_all, parts)
@ -356,6 +385,7 @@ def enable_parallel(processnum=None):
cut = pcut
cut_for_search = pcut_for_search
def disable_parallel():
global pool, cut, cut_for_search
if pool:
@ -364,6 +394,7 @@ def disable_parallel():
cut = __ref_cut
cut_for_search = __ref_cut_for_search
def set_dictionary(dictionary_path):
global initialized, DICTIONARY
with DICT_LOCK:
@ -373,9 +404,11 @@ def set_dictionary(dictionary_path):
DICTIONARY = abs_path
initialized = False
def get_abs_path_dict():
return os.path.join(_curpath, DICTIONARY)
def tokenize(unicode_sentence, mode="default", HMM=True):
"""Tokenize a sentence and yields tuples of (word, start, end)
Parameter:
@ -389,20 +422,20 @@ def tokenize(unicode_sentence, mode="default", HMM=True):
if mode == 'default':
for w in cut(unicode_sentence, HMM=HMM):
width = len(w)
yield (w, start, start+width)
yield (w, start, start + width)
start += width
else:
for w in cut(unicode_sentence, HMM=HMM):
width = len(w)
if len(w) > 2:
for i in xrange(len(w)-1):
gram2 = w[i:i+2]
for i in xrange(len(w) - 1):
gram2 = w[i:i + 2]
if gram2 in FREQ:
yield (gram2, start+i, start+i+2)
yield (gram2, start + i, start + i + 2)
if len(w) > 3:
for i in xrange(len(w)-2):
gram3 = w[i:i+3]
for i in xrange(len(w) - 2):
gram3 = w[i:i + 3]
if gram3 in FREQ:
yield (gram3, start+i, start+i+3)
yield (w, start, start+width)
yield (gram3, start + i, start + i + 3)
yield (w, start, start + width)
start += width
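For context, the docstrings above describe the public entry points this file touches: cut(), cut_for_search() and tokenize(). A short usage sketch, assuming jieba is importable; segmentation depends on the loaded dictionary and the HMM flag, so the output noted in comments is indicative only.

import jieba

sent = "我来到北京清华大学"
print("/".join(jieba.cut(sent)))                # accurate mode (default)
print("/".join(jieba.cut(sent, cut_all=True)))  # full mode
print("/".join(jieba.cut_for_search(sent)))     # search-engine mode

for word, start, end in jieba.tokenize("永和服装饰品有限公司"):
    print(word, start, end)                     # (word, start, end) offsets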

View File

@ -13,14 +13,16 @@ PROB_EMIT_P = "prob_emit.p"
PrevStatus = {
'B':('E','S'),
'M':('M','B'),
'S':('S','E'),
'E':('B','M')
'B': 'ES',
'M': 'MB',
'S': 'SE',
'E': 'BM'
}
def load_model():
_curpath = os.path.normpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
_curpath = os.path.normpath(
os.path.join(os.getcwd(), os.path.dirname(__file__)))
start_p = {}
abs_path = os.path.join(_curpath, PROB_START_P)
@ -44,50 +46,55 @@ if sys.platform.startswith("java"):
else:
from .prob_start import P as start_P
from .prob_trans import P as trans_P
from .prob_emit import P as emit_P
from .prob_emit import P as emit_P
def viterbi(obs, states, start_p, trans_p, emit_p):
V = [{}] #tabular
V = [{}] # tabular
path = {}
for y in states: #init
V[0][y] = start_p[y] + emit_p[y].get(obs[0],MIN_FLOAT)
for y in states: # init
V[0][y] = start_p[y] + emit_p[y].get(obs[0], MIN_FLOAT)
path[y] = [y]
for t in xrange(1, len(obs)):
V.append({})
newpath = {}
for y in states:
em_p = emit_p[y].get(obs[t],MIN_FLOAT)
(prob,state ) = max([(V[t-1][y0] + trans_p[y0].get(y, MIN_FLOAT) + em_p, y0) for y0 in PrevStatus[y]])
em_p = emit_p[y].get(obs[t], MIN_FLOAT)
(prob, state) = max(
[(V[t - 1][y0] + trans_p[y0].get(y, MIN_FLOAT) + em_p, y0) for y0 in PrevStatus[y]])
V[t][y] = prob
newpath[y] = path[state] + [y]
path = newpath
(prob, state) = max([(V[len(obs)-1][y], y) for y in ('E','S')])
(prob, state) = max((V[len(obs) - 1][y], y) for y in 'ES')
return (prob, path[state])
def __cut(sentence):
global emit_P
prob, pos_list = viterbi(sentence, ('B','M','E','S'), start_P, trans_P, emit_P)
prob, pos_list = viterbi(sentence, 'BMES', start_P, trans_P, emit_P)
begin, nexti = 0, 0
#print pos_list, sentence
for i,char in enumerate(sentence):
# print pos_list, sentence
for i, char in enumerate(sentence):
pos = pos_list[i]
if pos == 'B':
begin = i
elif pos == 'E':
yield sentence[begin:i+1]
nexti = i+1
yield sentence[begin:i + 1]
nexti = i + 1
elif pos == 'S':
yield char
nexti = i+1
nexti = i + 1
if nexti < len(sentence):
yield sentence[nexti:]
re_han = re.compile("([\u4E00-\u9FA5]+)")
re_skip = re.compile("(\d+\.\d+|[a-zA-Z0-9]+)")
def cut(sentence):
sentence = strdecode(sentence)
re_han, re_skip = re.compile("([\u4E00-\u9FA5]+)"), re.compile("(\d+\.\d+|[a-zA-Z0-9]+)")
blocks = re_han.split(sentence)
for blk in blocks:
if re_han.match(blk):
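One detail worth noting in this file: the PrevStatus values change from tuples such as ('E', 'S') to the string 'ES', and the calls into viterbi() and __cut() likewise pass 'ES' and 'BMES' instead of tuples. Iterating a Python string yields its characters, so the loops behave exactly as before; the strings are just cheaper to write and build. A tiny check of that equivalence:

# Iterating a str yields its characters, so 'ES' drives the same loop
# in viterbi() as the old tuple ('E', 'S').
PrevStatus_old = {'B': ('E', 'S'), 'M': ('M', 'B'), 'S': ('S', 'E'), 'E': ('B', 'M')}
PrevStatus_new = {'B': 'ES', 'M': 'MB', 'S': 'SE', 'E': 'BM'}

for state in 'BMES':
    assert list(PrevStatus_old[state]) == list(PrevStatus_new[state])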

View File

@ -13,8 +13,20 @@ PROB_TRANS_P = "prob_trans.p"
PROB_EMIT_P = "prob_emit.p"
CHAR_STATE_TAB_P = "char_state_tab.p"
re_han_detail = re.compile("([\u4E00-\u9FA5]+)")
re_skip_detail = re.compile("([\.0-9]+|[a-zA-Z0-9]+)")
re_han_internal = re.compile("([\u4E00-\u9FA5a-zA-Z0-9+#&\._]+)")
re_skip_internal = re.compile("(\r\n|\s)")
re_eng = re.compile("[a-zA-Z0-9]+")
re_num = re.compile("[\.0-9]+")
re_eng1 = re.compile('^[a-zA-Z0-9]$', re.U)
def load_model(f_name, isJython=True):
_curpath = os.path.normpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
_curpath = os.path.normpath(
os.path.join(os.getcwd(), os.path.dirname(__file__)))
result = {}
with open(f_name, "rb") as f:
@ -53,28 +65,32 @@ def load_model(f_name, isJython=True):
return state, start_p, trans_p, emit_p, result
if sys.platform.startswith("java"):
char_state_tab_P, start_P, trans_P, emit_P, word_tag_tab = load_model(jieba.get_abs_path_dict())
char_state_tab_P, start_P, trans_P, emit_P, word_tag_tab = load_model(
jieba.get_abs_path_dict())
else:
from .char_state_tab import P as char_state_tab_P
from .prob_start import P as start_P
from .prob_trans import P as trans_P
from .prob_emit import P as emit_P
from .prob_emit import P as emit_P
word_tag_tab = load_model(jieba.get_abs_path_dict(), isJython=False)
def makesure_userdict_loaded(fn):
@wraps(fn)
def wrapped(*args,**kwargs):
def wrapped(*args, **kwargs):
if jieba.user_word_tag_tab:
word_tag_tab.update(jieba.user_word_tag_tab)
jieba.user_word_tag_tab = {}
return fn(*args,**kwargs)
return fn(*args, **kwargs)
return wrapped
class pair(object):
def __init__(self,word,flag):
def __init__(self, word, flag):
self.word = word
self.flag = flag
@ -90,36 +106,37 @@ class pair(object):
else:
return self.__unicode__()
def encode(self,arg):
def encode(self, arg):
return self.__unicode__().encode(arg)
def __cut(sentence):
prob, pos_list = viterbi(sentence, char_state_tab_P, start_P, trans_P, emit_P)
prob, pos_list = viterbi(
sentence, char_state_tab_P, start_P, trans_P, emit_P)
begin, nexti = 0, 0
for i,char in enumerate(sentence):
for i, char in enumerate(sentence):
pos = pos_list[i][0]
if pos == 'B':
begin = i
elif pos == 'E':
yield pair(sentence[begin:i+1], pos_list[i][1])
nexti = i+1
yield pair(sentence[begin:i + 1], pos_list[i][1])
nexti = i + 1
elif pos == 'S':
yield pair(char, pos_list[i][1])
nexti = i+1
nexti = i + 1
if nexti < len(sentence):
yield pair(sentence[nexti:], pos_list[nexti][1])
def __cut_detail(sentence):
re_han, re_skip = re.compile("([\u4E00-\u9FA5]+)"), re.compile("([\.0-9]+|[a-zA-Z0-9]+)")
re_eng, re_num = re.compile("[a-zA-Z0-9]+"), re.compile("[\.0-9]+")
blocks = re_han.split(sentence)
blocks = re_han_detail.split(sentence)
for blk in blocks:
if re_han.match(blk):
if re_han_detail.match(blk):
for word in __cut(blk):
yield word
else:
tmp = re_skip.split(blk)
tmp = re_skip_detail.split(blk)
for x in tmp:
if x:
if re_num.match(x):
@ -129,6 +146,7 @@ def __cut_detail(sentence):
else:
yield pair(x, 'x')
def __cut_DAG_NO_HMM(sentence):
DAG = jieba.get_DAG(sentence)
route = {}
@ -136,23 +154,23 @@ def __cut_DAG_NO_HMM(sentence):
x = 0
N = len(sentence)
buf = ''
re_eng = re.compile('[a-zA-Z0-9]',re.U)
while x < N:
y = route[x][1]+1
y = route[x][1] + 1
l_word = sentence[x:y]
if re_eng.match(l_word) and len(l_word) == 1:
if re_eng1.match(l_word):
buf += l_word
x = y
else:
if buf:
yield pair(buf,'eng')
yield pair(buf, 'eng')
buf = ''
yield pair(l_word, word_tag_tab.get(l_word, 'x'))
x = y
if buf:
yield pair(buf,'eng')
yield pair(buf, 'eng')
buf = ''
def __cut_DAG(sentence):
DAG = jieba.get_DAG(sentence)
route = {}
@ -163,9 +181,9 @@ def __cut_DAG(sentence):
buf = ''
N = len(sentence)
while x < N:
y = route[x][1]+1
y = route[x][1] + 1
l_word = sentence[x:y]
if y-x == 1:
if y - x == 1:
buf += l_word
else:
if buf:
@ -193,24 +211,23 @@ def __cut_DAG(sentence):
for elem in buf:
yield pair(elem, word_tag_tab.get(elem, 'x'))
def __cut_internal(sentence, HMM=True):
sentence = strdecode(sentence)
re_han, re_skip = re.compile("([\u4E00-\u9FA5a-zA-Z0-9+#&\._]+)"), re.compile("(\r\n|\s)")
re_eng, re_num = re.compile("[a-zA-Z0-9]+"), re.compile("[\.0-9]+")
blocks = re_han.split(sentence)
blocks = re_han_internal.split(sentence)
if HMM:
__cut_blk = __cut_DAG
else:
__cut_blk = __cut_DAG_NO_HMM
for blk in blocks:
if re_han.match(blk):
if re_han_internal.match(blk):
for word in __cut_blk(blk):
yield word
else:
tmp = re_skip.split(blk)
tmp = re_skip_internal.split(blk)
for x in tmp:
if re_skip.match(x):
if re_skip_internal.match(x):
yield pair(x, 'x')
else:
for xx in x:
@ -221,10 +238,13 @@ def __cut_internal(sentence, HMM=True):
else:
yield pair(xx, 'x')
def __lcut_internal(sentence):
return list(__cut_internal(sentence))
def __lcut_internal_no_hmm(sentence):
return list(__cut_internal(sentence,False))
return list(__cut_internal(sentence, False))
@makesure_userdict_loaded
@ -241,4 +261,3 @@ def cut(sentence, HMM=True):
for r in result:
for w in r:
yield w
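The pair objects defined above carry a word and its part-of-speech flag, and are what the posseg interface yields. A short usage sketch of that interface; the tags come from the dictionary and the HMM model, so the commented output is indicative only.

import jieba.posseg as pseg

for w in pseg.cut("我爱北京天安门"):
    print("%s %s" % (w.word, w.flag))  # e.g. 我 r / 爱 v / 北京 ns / 天安门 ns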

View File

@ -6,36 +6,42 @@ MIN_INF = float("-inf")
if sys.version_info[0] > 2:
xrange = range
def get_top_states(t_state_v, K=4):
return sorted(t_state_v, key=t_state_v.__getitem__, reverse=True)[:K]
def viterbi(obs, states, start_p, trans_p, emit_p):
V = [{}] #tabular
V = [{}] # tabular
mem_path = [{}]
all_states = trans_p.keys()
for y in states.get(obs[0], all_states): #init
for y in states.get(obs[0], all_states): # init
V[0][y] = start_p[y] + emit_p[y].get(obs[0], MIN_FLOAT)
mem_path[0][y] = ''
for t in xrange(1, len(obs)):
V.append({})
mem_path.append({})
#prev_states = get_top_states(V[t-1])
prev_states = [x for x in mem_path[t-1].keys() if len(trans_p[x]) > 0]
prev_states = [
x for x in mem_path[t - 1].keys() if len(trans_p[x]) > 0]
prev_states_expect_next = set((y for x in prev_states for y in trans_p[x].keys()))
obs_states = set(states.get(obs[t], all_states)) & prev_states_expect_next
prev_states_expect_next = set(
(y for x in prev_states for y in trans_p[x].keys()))
obs_states = set(
states.get(obs[t], all_states)) & prev_states_expect_next
if not obs_states:
obs_states = prev_states_expect_next if prev_states_expect_next else all_states
for y in obs_states:
prob, state = max((V[t-1][y0] + trans_p[y0].get(y,MIN_INF) + emit_p[y].get(obs[t],MIN_FLOAT), y0) for y0 in prev_states)
prob, state = max((V[t - 1][y0] + trans_p[y0].get(y, MIN_INF) +
emit_p[y].get(obs[t], MIN_FLOAT), y0) for y0 in prev_states)
V[t][y] = prob
mem_path[t][y] = state
last = [(V[-1][y], y) for y in mem_path[-1].keys()]
#if len(last)==0:
#print obs
# if len(last)==0:
# print obs
prob, state = max(last)
route = [None] * len(obs)