don't compile re every time; autopep8

Dingyuan Wang 2015-02-10 21:22:34 +08:00
parent 22bcf8be7a
commit 32a0e92a09
4 changed files with 155 additions and 90 deletions
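For context, a minimal sketch (not part of the commit) of the pattern applied throughout this change: re.compile calls are hoisted out of the functions to module level, so each pattern object is built once at import time instead of on every call. The names below are illustrative only; note that Python's re module also caches compiled patterns internally, so the saving is mainly the repeated compile call and cache-lookup overhead.

    # sketch only: precompiled module-level pattern vs. compiling inside the call
    import re
    import timeit

    re_eng_example = re.compile(r'[a-zA-Z0-9]+', re.U)  # compiled once, at import

    def match_precompiled(text):
        # reuses the module-level pattern object
        return re_eng_example.match(text)

    def match_recompiled(text):
        # rebuilds (or re-looks-up) the pattern on every call
        return re.compile(r'[a-zA-Z0-9]+', re.U).match(text)

    if __name__ == '__main__':
        # rough comparison; exact numbers depend on interpreter and re's cache
        print(timeit.timeit(lambda: match_precompiled('jieba2015'), number=100000))
        print(timeit.timeit(lambda: match_recompiled('jieba2015'), number=100000))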

View File

@ -25,17 +25,20 @@ user_word_tag_tab = {}
initialized = False
pool = None
_curpath = os.path.normpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
_curpath = os.path.normpath(
os.path.join(os.getcwd(), os.path.dirname(__file__)))
log_console = logging.StreamHandler(sys.stderr)
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logger.addHandler(log_console)
def setLogLevel(log_level):
global logger
logger.setLevel(log_level)
def gen_pfdict(f_name):
lfreq = {}
pfdict = set()
@ -56,6 +59,7 @@ def gen_pfdict(f_name):
raise e
return pfdict, lfreq, ltotal
def initialize(dictionary=None):
global pfdict, FREQ, total, initialized, DICTIONARY, DICT_LOCK
if not dictionary:
@ -67,10 +71,12 @@ def initialize(dictionary=None):
abs_path = os.path.join(_curpath, dictionary)
logger.debug("Building prefix dict from %s ..." % abs_path)
t1 = time.time()
if abs_path == os.path.join(_curpath, "dict.txt"): #default dictionary
# default dictionary
if abs_path == os.path.join(_curpath, "dict.txt"):
cache_file = os.path.join(tempfile.gettempdir(), "jieba.cache")
else: # custom dictionary
cache_file = os.path.join(tempfile.gettempdir(), "jieba.u%s.cache" % md5(abs_path.encode('utf-8', 'replace')).hexdigest())
cache_file = os.path.join(tempfile.gettempdir(), "jieba.u%s.cache" % md5(
abs_path.encode('utf-8', 'replace')).hexdigest())
load_from_cache_fail = True
if os.path.isfile(cache_file) and os.path.getmtime(cache_file) > os.path.getmtime(abs_path):
@ -136,7 +142,9 @@ def calc(sentence, DAG, route):
N = len(sentence)
route[N] = (0.0, '')
for idx in xrange(N - 1, -1, -1):
route[idx] = max((log(FREQ.get(sentence[idx:x+1], 1)) - log(total) + route[x+1][0], x) for x in DAG[idx])
route[idx] = max((log(FREQ.get(sentence[idx:x + 1], 1)) -
log(total) + route[x + 1][0], x) for x in DAG[idx])
@require_initialized
def get_DAG(sentence):
@ -157,8 +165,10 @@ def get_DAG(sentence):
DAG[k] = tmplist
return DAG
def __cut_DAG_NO_HMM(sentence):
re_eng = re.compile(r'[a-zA-Z0-9]', re.U)
def __cut_DAG_NO_HMM(sentence):
DAG = get_DAG(sentence)
route = {}
calc(sentence, DAG, route)
@ -181,6 +191,7 @@ def __cut_DAG_NO_HMM(sentence):
yield buf
buf = ''
def __cut_DAG(sentence):
DAG = get_DAG(sentence)
route = {}
@ -221,6 +232,12 @@ def __cut_DAG(sentence):
for elem in buf:
yield elem
re_han_default = re.compile("([\u4E00-\u9FA5a-zA-Z0-9+#&\._]+)", re.U)
re_skip_default = re.compile("(\r\n|\s)", re.U)
re_han_cut_all = re.compile("([\u4E00-\u9FA5]+)", re.U)
re_skip_cut_all = re.compile("[^a-zA-Z0-9+#\n]", re.U)
def cut(sentence, cut_all=False, HMM=True):
'''The main function that segments an entire sentence that contains
Chinese characters into separated words.
@ -235,9 +252,11 @@ def cut(sentence, cut_all=False, HMM=True):
# \r\n|\s : whitespace characters. Will not be handled.
if cut_all:
re_han, re_skip = re.compile("([\u4E00-\u9FA5]+)", re.U), re.compile("[^a-zA-Z0-9+#\n]", re.U)
re_han = re_han_cut_all
re_skip = re_skip_cut_all
else:
re_han, re_skip = re.compile("([\u4E00-\u9FA5a-zA-Z0-9+#&\._]+)", re.U), re.compile("(\r\n|\s)", re.U)
re_han = re_han_default
re_skip = re_skip_default
blocks = re_han.split(sentence)
if cut_all:
cut_block = __cut_all
@ -262,6 +281,7 @@ def cut(sentence, cut_all=False, HMM=True):
else:
yield x
def cut_for_search(sentence, HMM=True):
words = cut(sentence, HMM=HMM)
for w in words:
@ -277,6 +297,7 @@ def cut_for_search(sentence, HMM=True):
yield gram3
yield w
@require_initialized
def load_userdict(f):
''' Load personalized dict to improve detection rate.
@ -300,6 +321,7 @@ def load_userdict(f):
if tup[1].isdigit():
add_word(*tup)
@require_initialized
def add_word(word, freq, tag=None):
global FREQ, pfdict, total, user_word_tag_tab
@ -314,12 +336,19 @@ def add_word(word, freq, tag=None):
__ref_cut = cut
__ref_cut_for_search = cut_for_search
def __lcut(sentence):
return list(__ref_cut(sentence, False))
def __lcut_no_hmm(sentence):
return list(__ref_cut(sentence, False, False))
def __lcut_all(sentence):
return list(__ref_cut(sentence, True))
def __lcut_for_search(sentence):
return list(__ref_cut_for_search(sentence))
@ -356,6 +385,7 @@ def enable_parallel(processnum=None):
cut = pcut
cut_for_search = pcut_for_search
def disable_parallel():
global pool, cut, cut_for_search
if pool:
@ -364,6 +394,7 @@ def disable_parallel():
cut = __ref_cut
cut_for_search = __ref_cut_for_search
def set_dictionary(dictionary_path):
global initialized, DICTIONARY
with DICT_LOCK:
@ -373,9 +404,11 @@ def set_dictionary(dictionary_path):
DICTIONARY = abs_path
initialized = False
def get_abs_path_dict():
return os.path.join(_curpath, DICTIONARY)
def tokenize(unicode_sentence, mode="default", HMM=True):
"""Tokenize a sentence and yields tuples of (word, start, end)
Parameter:

View File

@ -13,14 +13,16 @@ PROB_EMIT_P = "prob_emit.p"
PrevStatus = {
'B':('E','S'),
'M':('M','B'),
'S':('S','E'),
'E':('B','M')
'B': 'ES',
'M': 'MB',
'S': 'SE',
'E': 'BM'
}
def load_model():
_curpath = os.path.normpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
_curpath = os.path.normpath(
os.path.join(os.getcwd(), os.path.dirname(__file__)))
start_p = {}
abs_path = os.path.join(_curpath, PROB_START_P)
@ -46,6 +48,7 @@ else:
from .prob_trans import P as trans_P
from .prob_emit import P as emit_P
def viterbi(obs, states, start_p, trans_p, emit_p):
V = [{}] # tabular
path = {}
@ -57,19 +60,20 @@ def viterbi(obs, states, start_p, trans_p, emit_p):
newpath = {}
for y in states:
em_p = emit_p[y].get(obs[t], MIN_FLOAT)
(prob,state ) = max([(V[t-1][y0] + trans_p[y0].get(y, MIN_FLOAT) + em_p, y0) for y0 in PrevStatus[y]])
(prob, state) = max(
[(V[t - 1][y0] + trans_p[y0].get(y, MIN_FLOAT) + em_p, y0) for y0 in PrevStatus[y]])
V[t][y] = prob
newpath[y] = path[state] + [y]
path = newpath
(prob, state) = max([(V[len(obs)-1][y], y) for y in ('E','S')])
(prob, state) = max((V[len(obs) - 1][y], y) for y in 'ES')
return (prob, path[state])
def __cut(sentence):
global emit_P
prob, pos_list = viterbi(sentence, ('B','M','E','S'), start_P, trans_P, emit_P)
prob, pos_list = viterbi(sentence, 'BMES', start_P, trans_P, emit_P)
begin, nexti = 0, 0
# print pos_list, sentence
for i, char in enumerate(sentence):
@ -85,9 +89,12 @@ def __cut(sentence):
if nexti < len(sentence):
yield sentence[nexti:]
re_han = re.compile("([\u4E00-\u9FA5]+)")
re_skip = re.compile("(\d+\.\d+|[a-zA-Z0-9]+)")
def cut(sentence):
sentence = strdecode(sentence)
re_han, re_skip = re.compile("([\u4E00-\u9FA5]+)"), re.compile("(\d+\.\d+|[a-zA-Z0-9]+)")
blocks = re_han.split(sentence)
for blk in blocks:
if re_han.match(blk):

View File

@ -13,8 +13,20 @@ PROB_TRANS_P = "prob_trans.p"
PROB_EMIT_P = "prob_emit.p"
CHAR_STATE_TAB_P = "char_state_tab.p"
re_han_detail = re.compile("([\u4E00-\u9FA5]+)")
re_skip_detail = re.compile("([\.0-9]+|[a-zA-Z0-9]+)")
re_han_internal = re.compile("([\u4E00-\u9FA5a-zA-Z0-9+#&\._]+)")
re_skip_internal = re.compile("(\r\n|\s)")
re_eng = re.compile("[a-zA-Z0-9]+")
re_num = re.compile("[\.0-9]+")
re_eng1 = re.compile('^[a-zA-Z0-9]$', re.U)
def load_model(f_name, isJython=True):
_curpath = os.path.normpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
_curpath = os.path.normpath(
os.path.join(os.getcwd(), os.path.dirname(__file__)))
result = {}
with open(f_name, "rb") as f:
@ -53,7 +65,8 @@ def load_model(f_name, isJython=True):
return state, start_p, trans_p, emit_p, result
if sys.platform.startswith("java"):
char_state_tab_P, start_P, trans_P, emit_P, word_tag_tab = load_model(jieba.get_abs_path_dict())
char_state_tab_P, start_P, trans_P, emit_P, word_tag_tab = load_model(
jieba.get_abs_path_dict())
else:
from .char_state_tab import P as char_state_tab_P
from .prob_start import P as start_P
@ -62,6 +75,7 @@ else:
word_tag_tab = load_model(jieba.get_abs_path_dict(), isJython=False)
def makesure_userdict_loaded(fn):
@wraps(fn)
@ -73,7 +87,9 @@ def makesure_userdict_loaded(fn):
return wrapped
class pair(object):
def __init__(self, word, flag):
self.word = word
self.flag = flag
@ -93,8 +109,10 @@ class pair(object):
def encode(self, arg):
return self.__unicode__().encode(arg)
def __cut(sentence):
prob, pos_list = viterbi(sentence, char_state_tab_P, start_P, trans_P, emit_P)
prob, pos_list = viterbi(
sentence, char_state_tab_P, start_P, trans_P, emit_P)
begin, nexti = 0, 0
for i, char in enumerate(sentence):
@ -110,16 +128,15 @@ def __cut(sentence):
if nexti < len(sentence):
yield pair(sentence[nexti:], pos_list[nexti][1])
def __cut_detail(sentence):
re_han, re_skip = re.compile("([\u4E00-\u9FA5]+)"), re.compile("([\.0-9]+|[a-zA-Z0-9]+)")
re_eng, re_num = re.compile("[a-zA-Z0-9]+"), re.compile("[\.0-9]+")
blocks = re_han.split(sentence)
blocks = re_han_detail.split(sentence)
for blk in blocks:
if re_han.match(blk):
if re_han_detail.match(blk):
for word in __cut(blk):
yield word
else:
tmp = re_skip.split(blk)
tmp = re_skip_detail.split(blk)
for x in tmp:
if x:
if re_num.match(x):
@ -129,6 +146,7 @@ def __cut_detail(sentence):
else:
yield pair(x, 'x')
def __cut_DAG_NO_HMM(sentence):
DAG = jieba.get_DAG(sentence)
route = {}
@ -136,11 +154,10 @@ def __cut_DAG_NO_HMM(sentence):
x = 0
N = len(sentence)
buf = ''
re_eng = re.compile('[a-zA-Z0-9]',re.U)
while x < N:
y = route[x][1] + 1
l_word = sentence[x:y]
if re_eng.match(l_word) and len(l_word) == 1:
if re_eng1.match(l_word):
buf += l_word
x = y
else:
@ -153,6 +170,7 @@ def __cut_DAG_NO_HMM(sentence):
yield pair(buf, 'eng')
buf = ''
def __cut_DAG(sentence):
DAG = jieba.get_DAG(sentence)
route = {}
@ -193,24 +211,23 @@ def __cut_DAG(sentence):
for elem in buf:
yield pair(elem, word_tag_tab.get(elem, 'x'))
def __cut_internal(sentence, HMM=True):
sentence = strdecode(sentence)
re_han, re_skip = re.compile("([\u4E00-\u9FA5a-zA-Z0-9+#&\._]+)"), re.compile("(\r\n|\s)")
re_eng, re_num = re.compile("[a-zA-Z0-9]+"), re.compile("[\.0-9]+")
blocks = re_han.split(sentence)
blocks = re_han_internal.split(sentence)
if HMM:
__cut_blk = __cut_DAG
else:
__cut_blk = __cut_DAG_NO_HMM
for blk in blocks:
if re_han.match(blk):
if re_han_internal.match(blk):
for word in __cut_blk(blk):
yield word
else:
tmp = re_skip.split(blk)
tmp = re_skip_internal.split(blk)
for x in tmp:
if re_skip.match(x):
if re_skip_internal.match(x):
yield pair(x, 'x')
else:
for xx in x:
@ -221,8 +238,11 @@ def __cut_internal(sentence, HMM=True):
else:
yield pair(xx, 'x')
def __lcut_internal(sentence):
return list(__cut_internal(sentence))
def __lcut_internal_no_hmm(sentence):
return list(__cut_internal(sentence, False))
@ -241,4 +261,3 @@ def cut(sentence, HMM=True):
for r in result:
for w in r:
yield w

View File

@ -6,9 +6,11 @@ MIN_INF = float("-inf")
if sys.version_info[0] > 2:
xrange = range
def get_top_states(t_state_v, K=4):
return sorted(t_state_v, key=t_state_v.__getitem__, reverse=True)[:K]
def viterbi(obs, states, start_p, trans_p, emit_p):
V = [{}] # tabular
mem_path = [{}]
@ -20,16 +22,20 @@ def viterbi(obs, states, start_p, trans_p, emit_p):
V.append({})
mem_path.append({})
#prev_states = get_top_states(V[t-1])
prev_states = [x for x in mem_path[t-1].keys() if len(trans_p[x]) > 0]
prev_states = [
x for x in mem_path[t - 1].keys() if len(trans_p[x]) > 0]
prev_states_expect_next = set((y for x in prev_states for y in trans_p[x].keys()))
obs_states = set(states.get(obs[t], all_states)) & prev_states_expect_next
prev_states_expect_next = set(
(y for x in prev_states for y in trans_p[x].keys()))
obs_states = set(
states.get(obs[t], all_states)) & prev_states_expect_next
if not obs_states:
obs_states = prev_states_expect_next if prev_states_expect_next else all_states
for y in obs_states:
prob, state = max((V[t-1][y0] + trans_p[y0].get(y,MIN_INF) + emit_p[y].get(obs[t],MIN_FLOAT), y0) for y0 in prev_states)
prob, state = max((V[t - 1][y0] + trans_p[y0].get(y, MIN_INF) +
emit_p[y].get(obs[t], MIN_FLOAT), y0) for y0 in prev_states)
V[t][y] = prob
mem_path[t][y] = state