From 32a0e92a09614cf5c72f87b1a59a5c4369200516 Mon Sep 17 00:00:00 2001
From: Dingyuan Wang
Date: Tue, 10 Feb 2015 21:22:34 +0800
Subject: [PATCH] don't compile re every time; autopep8

---
 jieba/__init__.py          | 95 +++++++++++++++++++++++++-------------
 jieba/finalseg/__init__.py | 45 ++++++++++--------
 jieba/posseg/__init__.py   | 83 ++++++++++++++++++++-------------
 jieba/posseg/viterbi.py    | 22 +++++----
 4 files changed, 155 insertions(+), 90 deletions(-)

diff --git a/jieba/__init__.py b/jieba/__init__.py
index 6803d65..4561e39 100644
--- a/jieba/__init__.py
+++ b/jieba/__init__.py
@@ -18,24 +18,27 @@ from . import finalseg
 
 DICTIONARY = "dict.txt"
 DICT_LOCK = threading.RLock()
-pfdict = None # to be initialized
+pfdict = None  # to be initialized
 FREQ = {}
 total = 0
 user_word_tag_tab = {}
 initialized = False
 pool = None
 
-_curpath = os.path.normpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
+_curpath = os.path.normpath(
+    os.path.join(os.getcwd(), os.path.dirname(__file__)))
 
 log_console = logging.StreamHandler(sys.stderr)
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
 logger.addHandler(log_console)
 
+
 def setLogLevel(log_level):
     global logger
     logger.setLevel(log_level)
 
+
 def gen_pfdict(f_name):
     lfreq = {}
     pfdict = set()
@@ -50,12 +53,13 @@ def gen_pfdict(f_name):
                 lfreq[word] = freq
                 ltotal += freq
                 for ch in xrange(len(word)):
-                    pfdict.add(word[:ch+1])
+                    pfdict.add(word[:ch + 1])
             except ValueError as e:
                 logger.debug('%s at line %s %s' % (f_name, lineno, line))
                 raise e
     return pfdict, lfreq, ltotal
 
+
 def initialize(dictionary=None):
     global pfdict, FREQ, total, initialized, DICTIONARY, DICT_LOCK
     if not dictionary:
@@ -67,10 +71,12 @@ def initialize(dictionary=None):
         abs_path = os.path.join(_curpath, dictionary)
         logger.debug("Building prefix dict from %s ..." % abs_path)
         t1 = time.time()
-        if abs_path == os.path.join(_curpath, "dict.txt"): #default dictionary
+        # default dictionary
+        if abs_path == os.path.join(_curpath, "dict.txt"):
             cache_file = os.path.join(tempfile.gettempdir(), "jieba.cache")
-        else: #custom dictionary
-            cache_file = os.path.join(tempfile.gettempdir(), "jieba.u%s.cache" % md5(abs_path.encode('utf-8', 'replace')).hexdigest())
+        else:  # custom dictionary
+            cache_file = os.path.join(tempfile.gettempdir(), "jieba.u%s.cache" % md5(
+                abs_path.encode('utf-8', 'replace')).hexdigest())
 
         load_from_cache_fail = True
         if os.path.isfile(cache_file) and os.path.getmtime(cache_file) > os.path.getmtime(abs_path):
@@ -121,22 +127,24 @@ def require_initialized(fn):
 
 def __cut_all(sentence):
     dag = get_DAG(sentence)
     old_j = -1
-    for k,L in iteritems(dag):
+    for k, L in iteritems(dag):
         if len(L) == 1 and k > old_j:
-            yield sentence[k:L[0]+1]
+            yield sentence[k:L[0] + 1]
            old_j = L[0]
         else:
             for j in L:
                 if j > k:
-                    yield sentence[k:j+1]
+                    yield sentence[k:j + 1]
                     old_j = j
 
 def calc(sentence, DAG, route):
     N = len(sentence)
     route[N] = (0.0, '')
-    for idx in xrange(N-1, -1, -1):
-        route[idx] = max((log(FREQ.get(sentence[idx:x+1], 1)) - log(total) + route[x+1][0], x) for x in DAG[idx])
+    for idx in xrange(N - 1, -1, -1):
+        route[idx] = max((log(FREQ.get(sentence[idx:x + 1], 1)) -
+                          log(total) + route[x + 1][0], x) for x in DAG[idx])
+
 
 @require_initialized
 def get_DAG(sentence):
@@ -151,14 +159,16 @@ def get_DAG(sentence):
             if frag in FREQ:
                 tmplist.append(i)
             i += 1
-            frag = sentence[k:i+1]
+            frag = sentence[k:i + 1]
         if not tmplist:
             tmplist.append(k)
         DAG[k] = tmplist
     return DAG
 
+re_eng = re.compile(r'[a-zA-Z0-9]', re.U)
+
+
 def __cut_DAG_NO_HMM(sentence):
-    re_eng = re.compile(r'[a-zA-Z0-9]',re.U)
     DAG = get_DAG(sentence)
     route = {}
     calc(sentence, DAG, route)
@@ -181,6 +191,7 @@ def __cut_DAG_NO_HMM(sentence):
         yield buf
         buf = ''
 
+
 def __cut_DAG(sentence):
     DAG = get_DAG(sentence)
     route = {}
@@ -189,9 +200,9 @@ def __cut_DAG(sentence):
     buf = ''
     N = len(sentence)
     while x < N:
-        y = route[x][1]+1
+        y = route[x][1] + 1
         l_word = sentence[x:y]
-        if y-x == 1:
+        if y - x == 1:
             buf += l_word
         else:
             if buf:
@@ -221,6 +232,12 @@ def __cut_DAG(sentence):
             for elem in buf:
                 yield elem
 
+re_han_default = re.compile("([\u4E00-\u9FA5a-zA-Z0-9+#&\._]+)", re.U)
+re_skip_default = re.compile("(\r\n|\s)", re.U)
+re_han_cut_all = re.compile("([\u4E00-\u9FA5]+)", re.U)
+re_skip_cut_all = re.compile("[^a-zA-Z0-9+#\n]", re.U)
+
+
 def cut(sentence, cut_all=False, HMM=True):
     '''The main function that segments an entire sentence that contains
     Chinese characters into seperated words.
@@ -235,9 +252,11 @@ def cut(sentence, cut_all=False, HMM=True):
     # \r\n|\s : whitespace characters. Will not be handled.
 
     if cut_all:
-        re_han, re_skip = re.compile("([\u4E00-\u9FA5]+)", re.U), re.compile("[^a-zA-Z0-9+#\n]", re.U)
+        re_han = re_han_cut_all
+        re_skip = re_skip_cut_all
     else:
-        re_han, re_skip = re.compile("([\u4E00-\u9FA5a-zA-Z0-9+#&\._]+)", re.U), re.compile("(\r\n|\s)", re.U)
+        re_han = re_han_default
+        re_skip = re_skip_default
     blocks = re_han.split(sentence)
     if cut_all:
         cut_block = __cut_all
@@ -262,21 +281,23 @@ def cut(sentence, cut_all=False, HMM=True):
             else:
                 yield x
 
+
 def cut_for_search(sentence, HMM=True):
     words = cut(sentence, HMM=HMM)
     for w in words:
         if len(w) > 2:
-            for i in xrange(len(w)-1):
-                gram2 = w[i:i+2]
+            for i in xrange(len(w) - 1):
+                gram2 = w[i:i + 2]
                 if gram2 in FREQ:
                     yield gram2
         if len(w) > 3:
-            for i in xrange(len(w)-2):
-                gram3 = w[i:i+3]
+            for i in xrange(len(w) - 2):
+                gram3 = w[i:i + 3]
                 if gram3 in FREQ:
                     yield gram3
         yield w
 
+
 @require_initialized
 def load_userdict(f):
     ''' Load personalized dict to improve detect rate.
@@ -300,6 +321,7 @@ def load_userdict(f):
         if tup[1].isdigit():
             add_word(*tup)
 
+
 @require_initialized
 def add_word(word, freq, tag=None):
     global FREQ, pfdict, total, user_word_tag_tab
@@ -309,17 +331,24 @@ def add_word(word, freq, tag=None):
     if tag is not None:
         user_word_tag_tab[word] = tag
     for ch in xrange(len(word)):
-        pfdict.add(word[:ch+1])
+        pfdict.add(word[:ch + 1])
 
 __ref_cut = cut
 __ref_cut_for_search = cut_for_search
 
+
 def __lcut(sentence):
     return list(__ref_cut(sentence, False))
+
+
 def __lcut_no_hmm(sentence):
     return list(__ref_cut(sentence, False, False))
+
+
 def __lcut_all(sentence):
     return list(__ref_cut(sentence, True))
+
+
 def __lcut_for_search(sentence):
     return list(__ref_cut_for_search(sentence))
 
@@ -334,7 +363,7 @@ def enable_parallel(processnum=None):
         processnum = cpu_count()
     pool = Pool(processnum)
 
-    def pcut(sentence,cut_all=False,HMM=True):
+    def pcut(sentence, cut_all=False, HMM=True):
         parts = strdecode(sentence).split('\n')
         if cut_all:
             result = pool.map(__lcut_all, parts)
@@ -356,6 +385,7 @@ def enable_parallel(processnum=None):
     cut = pcut
     cut_for_search = pcut_for_search
 
+
 def disable_parallel():
     global pool, cut, cut_for_search
     if pool:
@@ -364,6 +394,7 @@ def disable_parallel():
     cut = __ref_cut
     cut_for_search = __ref_cut_for_search
 
+
 def set_dictionary(dictionary_path):
     global initialized, DICTIONARY
     with DICT_LOCK:
@@ -373,9 +404,11 @@ def set_dictionary(dictionary_path):
         DICTIONARY = abs_path
         initialized = False
 
+
 def get_abs_path_dict():
     return os.path.join(_curpath, DICTIONARY)
 
+
 def tokenize(unicode_sentence, mode="default", HMM=True):
     """Tokenize a sentence and yields tuples of (word, start, end)
     Parameter:
@@ -389,20 +422,20 @@ def tokenize(unicode_sentence, mode="default", HMM=True):
     if mode == 'default':
         for w in cut(unicode_sentence, HMM=HMM):
             width = len(w)
-            yield (w, start, start+width)
+            yield (w, start, start + width)
             start += width
     else:
         for w in cut(unicode_sentence, HMM=HMM):
             width = len(w)
             if len(w) > 2:
-                for i in xrange(len(w)-1):
-                    gram2 = w[i:i+2]
+                for i in xrange(len(w) - 1):
+                    gram2 = w[i:i + 2]
                     if gram2 in FREQ:
-                        yield (gram2, start+i, start+i+2)
+                        yield (gram2, start + i, start + i + 2)
             if len(w) > 3:
-                for i in xrange(len(w)-2):
-                    gram3 = w[i:i+3]
+                for i in xrange(len(w) - 2):
+                    gram3 = w[i:i + 3]
                     if gram3 in FREQ:
-                        yield (gram3, start+i, start+i+3)
-            yield (w, start, start+width)
+                        yield (gram3, start + i, start + i + 3)
+            yield (w, start, start + width)
             start += width
diff --git a/jieba/finalseg/__init__.py b/jieba/finalseg/__init__.py
index 50bebb3..a780cff 100644
--- a/jieba/finalseg/__init__.py
+++ b/jieba/finalseg/__init__.py
@@ -13,14 +13,16 @@ PROB_EMIT_P = "prob_emit.p"
 
 PrevStatus = {
-    'B':('E','S'),
-    'M':('M','B'),
-    'S':('S','E'),
-    'E':('B','M')
+    'B': 'ES',
+    'M': 'MB',
+    'S': 'SE',
+    'E': 'BM'
 }
 
+
 def load_model():
-    _curpath = os.path.normpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
+    _curpath = os.path.normpath(
+        os.path.join(os.getcwd(), os.path.dirname(__file__)))
 
     start_p = {}
     abs_path = os.path.join(_curpath, PROB_START_P)
@@ -44,50 +46,55 @@ if sys.platform.startswith("java"):
 else:
     from .prob_start import P as start_P
     from .prob_trans import P as trans_P
-    from .prob_emit import P as emit_P
+    from .prob_emit import P as emit_P
+
 
 def viterbi(obs, states, start_p, trans_p, emit_p):
-    V = [{}] #tabular
+    V = [{}]  # tabular
     path = {}
-    for y in states: #init
-        V[0][y] = start_p[y] + emit_p[y].get(obs[0],MIN_FLOAT)
+    for y in states:  # init
+        V[0][y] = start_p[y] + emit_p[y].get(obs[0], MIN_FLOAT)
         path[y] = [y]
     for t in xrange(1, len(obs)):
         V.append({})
         newpath = {}
         for y in states:
-            em_p = emit_p[y].get(obs[t],MIN_FLOAT)
-            (prob,state ) = max([(V[t-1][y0] + trans_p[y0].get(y, MIN_FLOAT) + em_p, y0) for y0 in PrevStatus[y]])
+            em_p = emit_p[y].get(obs[t], MIN_FLOAT)
+            (prob, state) = max(
+                [(V[t - 1][y0] + trans_p[y0].get(y, MIN_FLOAT) + em_p, y0) for y0 in PrevStatus[y]])
             V[t][y] = prob
             newpath[y] = path[state] + [y]
         path = newpath
-    (prob, state) = max([(V[len(obs)-1][y], y) for y in ('E','S')])
+    (prob, state) = max((V[len(obs) - 1][y], y) for y in 'ES')
 
     return (prob, path[state])
 
 
 def __cut(sentence):
     global emit_P
-    prob, pos_list = viterbi(sentence, ('B','M','E','S'), start_P, trans_P, emit_P)
+    prob, pos_list = viterbi(sentence, 'BMES', start_P, trans_P, emit_P)
     begin, nexti = 0, 0
-    #print pos_list, sentence
-    for i,char in enumerate(sentence):
+    # print pos_list, sentence
+    for i, char in enumerate(sentence):
         pos = pos_list[i]
         if pos == 'B':
             begin = i
         elif pos == 'E':
-            yield sentence[begin:i+1]
-            nexti = i+1
+            yield sentence[begin:i + 1]
+            nexti = i + 1
         elif pos == 'S':
             yield char
-            nexti = i+1
+            nexti = i + 1
     if nexti < len(sentence):
         yield sentence[nexti:]
 
+re_han = re.compile("([\u4E00-\u9FA5]+)")
+re_skip = re.compile("(\d+\.\d+|[a-zA-Z0-9]+)")
+
+
 def cut(sentence):
     sentence = strdecode(sentence)
-    re_han, re_skip = re.compile("([\u4E00-\u9FA5]+)"), re.compile("(\d+\.\d+|[a-zA-Z0-9]+)")
     blocks = re_han.split(sentence)
     for blk in blocks:
         if re_han.match(blk):
diff --git a/jieba/posseg/__init__.py b/jieba/posseg/__init__.py
index 7e6b45a..d648f28 100644
--- a/jieba/posseg/__init__.py
+++ b/jieba/posseg/__init__.py
@@ -13,8 +13,20 @@ PROB_TRANS_P = "prob_trans.p"
 PROB_EMIT_P = "prob_emit.p"
 CHAR_STATE_TAB_P = "char_state_tab.p"
 
+re_han_detail = re.compile("([\u4E00-\u9FA5]+)")
+re_skip_detail = re.compile("([\.0-9]+|[a-zA-Z0-9]+)")
+re_han_internal = re.compile("([\u4E00-\u9FA5a-zA-Z0-9+#&\._]+)")
+re_skip_internal = re.compile("(\r\n|\s)")
+
+re_eng = re.compile("[a-zA-Z0-9]+")
+re_num = re.compile("[\.0-9]+")
+
+re_eng1 = re.compile('^[a-zA-Z0-9]$', re.U)
+
+
 def load_model(f_name, isJython=True):
-    _curpath = os.path.normpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
+    _curpath = os.path.normpath(
+        os.path.join(os.getcwd(), os.path.dirname(__file__)))
 
     result = {}
     with open(f_name, "rb") as f:
@@ -53,28 +65,32 @@ def load_model(f_name, isJython=True):
     return state, start_p, trans_p, emit_p, result
 
 if sys.platform.startswith("java"):
-    char_state_tab_P, start_P, trans_P, emit_P, word_tag_tab = load_model(jieba.get_abs_path_dict())
+    char_state_tab_P, start_P, trans_P, emit_P, word_tag_tab = load_model(
+        jieba.get_abs_path_dict())
 else:
     from .char_state_tab import P as char_state_tab_P
     from .prob_start import P as start_P
     from .prob_trans import P as trans_P
-    from .prob_emit import P as emit_P
+    from .prob_emit import P as emit_P
     word_tag_tab = load_model(jieba.get_abs_path_dict(), isJython=False)
 
+
 def makesure_userdict_loaded(fn):
     @wraps(fn)
-    def wrapped(*args,**kwargs):
+    def wrapped(*args, **kwargs):
         if jieba.user_word_tag_tab:
             word_tag_tab.update(jieba.user_word_tag_tab)
             jieba.user_word_tag_tab = {}
-        return fn(*args,**kwargs)
+        return fn(*args, **kwargs)
     return wrapped
 
+
 class pair(object):
-    def __init__(self,word,flag):
+
+    def __init__(self, word, flag):
         self.word = word
         self.flag = flag
@@ -90,36 +106,37 @@ class pair(object):
         else:
             return self.__unicode__()
 
-    def encode(self,arg):
+    def encode(self, arg):
         return self.__unicode__().encode(arg)
 
+
 def __cut(sentence):
-    prob, pos_list = viterbi(sentence, char_state_tab_P, start_P, trans_P, emit_P)
+    prob, pos_list = viterbi(
+        sentence, char_state_tab_P, start_P, trans_P, emit_P)
     begin, nexti = 0, 0
-    for i,char in enumerate(sentence):
+    for i, char in enumerate(sentence):
         pos = pos_list[i][0]
         if pos == 'B':
             begin = i
         elif pos == 'E':
-            yield pair(sentence[begin:i+1], pos_list[i][1])
-            nexti = i+1
+            yield pair(sentence[begin:i + 1], pos_list[i][1])
+            nexti = i + 1
         elif pos == 'S':
             yield pair(char, pos_list[i][1])
-            nexti = i+1
+            nexti = i + 1
     if nexti < len(sentence):
         yield pair(sentence[nexti:], pos_list[nexti][1])
 
+
 def __cut_detail(sentence):
-    re_han, re_skip = re.compile("([\u4E00-\u9FA5]+)"), re.compile("([\.0-9]+|[a-zA-Z0-9]+)")
-    re_eng, re_num = re.compile("[a-zA-Z0-9]+"), re.compile("[\.0-9]+")
-    blocks = re_han.split(sentence)
+    blocks = re_han_detail.split(sentence)
     for blk in blocks:
-        if re_han.match(blk):
+        if re_han_detail.match(blk):
             for word in __cut(blk):
                 yield word
         else:
-            tmp = re_skip.split(blk)
+            tmp = re_skip_detail.split(blk)
             for x in tmp:
                 if x:
                     if re_num.match(x):
@@ -129,6 +146,7 @@ def __cut_detail(sentence):
                     else:
                         yield pair(x, 'x')
 
+
 def __cut_DAG_NO_HMM(sentence):
     DAG = jieba.get_DAG(sentence)
     route = {}
@@ -136,23 +154,23 @@ def __cut_DAG_NO_HMM(sentence):
     x = 0
     N = len(sentence)
     buf = ''
-    re_eng = re.compile('[a-zA-Z0-9]',re.U)
     while x < N:
-        y = route[x][1]+1
+        y = route[x][1] + 1
         l_word = sentence[x:y]
-        if re_eng.match(l_word) and len(l_word) == 1:
+        if re_eng1.match(l_word):
             buf += l_word
             x = y
         else:
             if buf:
-                yield pair(buf,'eng')
+                yield pair(buf, 'eng')
                 buf = ''
             yield pair(l_word, word_tag_tab.get(l_word, 'x'))
             x = y
     if buf:
-        yield pair(buf,'eng')
+        yield pair(buf, 'eng')
         buf = ''
 
+
 def __cut_DAG(sentence):
     DAG = jieba.get_DAG(sentence)
     route = {}
@@ -163,9 +181,9 @@ def __cut_DAG(sentence):
     buf = ''
     N = len(sentence)
     while x < N:
-        y = route[x][1]+1
+        y = route[x][1] + 1
         l_word = sentence[x:y]
-        if y-x == 1:
+        if y - x == 1:
             buf += l_word
         else:
             if buf:
@@ -193,24 +211,23 @@ def __cut_DAG(sentence):
             for elem in buf:
                 yield pair(elem, word_tag_tab.get(elem, 'x'))
 
+
 def __cut_internal(sentence, HMM=True):
     sentence = strdecode(sentence)
-    re_han, re_skip = re.compile("([\u4E00-\u9FA5a-zA-Z0-9+#&\._]+)"), re.compile("(\r\n|\s)")
-    re_eng, re_num = re.compile("[a-zA-Z0-9]+"), re.compile("[\.0-9]+")
-    blocks = re_han.split(sentence)
+    blocks = re_han_internal.split(sentence)
     if HMM:
         __cut_blk = __cut_DAG
     else:
         __cut_blk = __cut_DAG_NO_HMM
     for blk in blocks:
-        if re_han.match(blk):
+        if re_han_internal.match(blk):
             for word in __cut_blk(blk):
                 yield word
         else:
-            tmp = re_skip.split(blk)
+            tmp = re_skip_internal.split(blk)
             for x in tmp:
-                if re_skip.match(x):
+                if re_skip_internal.match(x):
                     yield pair(x, 'x')
                 else:
                     for xx in x:
@@ -221,10 +238,13 @@ def __cut_internal(sentence, HMM=True):
                         else:
                             yield pair(xx, 'x')
 
+
 def __lcut_internal(sentence):
     return list(__cut_internal(sentence))
+
+
 def __lcut_internal_no_hmm(sentence):
-    return list(__cut_internal(sentence,False))
+    return list(__cut_internal(sentence, False))
 
 
 @makesure_userdict_loaded
@@ -241,4 +261,3 @@ def cut(sentence, HMM=True):
     for r in result:
         for w in r:
             yield w
-
diff --git a/jieba/posseg/viterbi.py b/jieba/posseg/viterbi.py
index ce9e928..5f0682d 100644
--- a/jieba/posseg/viterbi.py
+++ b/jieba/posseg/viterbi.py
@@ -6,36 +6,42 @@ MIN_INF = float("-inf")
 if sys.version_info[0] > 2:
     xrange = range
 
+
 def get_top_states(t_state_v, K=4):
     return sorted(t_state_v, key=t_state_v.__getitem__, reverse=True)[:K]
 
+
 def viterbi(obs, states, start_p, trans_p, emit_p):
-    V = [{}] #tabular
+    V = [{}]  # tabular
     mem_path = [{}]
     all_states = trans_p.keys()
-    for y in states.get(obs[0], all_states): #init
+    for y in states.get(obs[0], all_states):  # init
         V[0][y] = start_p[y] + emit_p[y].get(obs[0], MIN_FLOAT)
         mem_path[0][y] = ''
     for t in xrange(1, len(obs)):
         V.append({})
         mem_path.append({})
         #prev_states = get_top_states(V[t-1])
-        prev_states = [x for x in mem_path[t-1].keys() if len(trans_p[x]) > 0]
+        prev_states = [
+            x for x in mem_path[t - 1].keys() if len(trans_p[x]) > 0]
 
-        prev_states_expect_next = set((y for x in prev_states for y in trans_p[x].keys()))
-        obs_states = set(states.get(obs[t], all_states)) & prev_states_expect_next
+        prev_states_expect_next = set(
+            (y for x in prev_states for y in trans_p[x].keys()))
+        obs_states = set(
+            states.get(obs[t], all_states)) & prev_states_expect_next
 
         if not obs_states:
             obs_states = prev_states_expect_next if prev_states_expect_next else all_states
 
         for y in obs_states:
-            prob, state = max((V[t-1][y0] + trans_p[y0].get(y,MIN_INF) + emit_p[y].get(obs[t],MIN_FLOAT), y0) for y0 in prev_states)
+            prob, state = max((V[t - 1][y0] + trans_p[y0].get(y, MIN_INF) +
+                               emit_p[y].get(obs[t], MIN_FLOAT), y0) for y0 in prev_states)
             V[t][y] = prob
             mem_path[t][y] = state
 
     last = [(V[-1][y], y) for y in mem_path[-1].keys()]
-    #if len(last)==0:
-    #print obs
+    # if len(last)==0:
+    # print obs
     prob, state = max(last)
 
     route = [None] * len(obs)
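
[Reviewer note, not part of the patch] The core change hoists every re.compile() call to module level (re_eng, re_han_default, re_han_detail, re_eng1, and so on), so patterns are compiled once at import time instead of on every call to cut(), __cut_detail() or __cut_internal(). Python's re module does cache compiled patterns internally, but the per-call re.compile() invocation and cache lookup still cost time on hot paths. A minimal sketch of the effect; the names below are illustrative only and do not exist in jieba:

    import re
    import timeit

    # compiled once at import time, as the patched modules now do
    RE_HAN = re.compile("([\u4E00-\u9FA5]+)", re.U)

    def split_precompiled(s):
        # reuses the module-level pattern object
        return RE_HAN.split(s)

    def split_per_call(s):
        # pays for the re.compile() call (plus cache lookup) on every
        # invocation, which is what the old code inside cut() effectively did
        return re.compile("([\u4E00-\u9FA5]+)", re.U).split(s)

    text = "jieba支持Python3的分词" * 20
    print(timeit.timeit(lambda: split_precompiled(text), number=10000))
    print(timeit.timeit(lambda: split_per_call(text), number=10000))

The PrevStatus rewrite in finalseg is behavior-preserving for the same reason 'BMES' can replace ('B', 'M', 'E', 'S'): iterating a Python string yields its characters, so "for y0 in PrevStatus[y]" sees 'E', 'S' whether PrevStatus['B'] is the tuple ('E', 'S') or the string 'ES', and the string is cheaper to build and store.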