mirror of
https://github.com/fxsjy/jieba.git
synced 2025-07-10 00:01:33 +08:00
don't compile re every time; autopep8
This commit is contained in:
parent
22bcf8be7a
commit
32a0e92a09
@ -18,24 +18,27 @@ from . import finalseg
|
|||||||
|
|
||||||
DICTIONARY = "dict.txt"
|
DICTIONARY = "dict.txt"
|
||||||
DICT_LOCK = threading.RLock()
|
DICT_LOCK = threading.RLock()
|
||||||
pfdict = None # to be initialized
|
pfdict = None # to be initialized
|
||||||
FREQ = {}
|
FREQ = {}
|
||||||
total = 0
|
total = 0
|
||||||
user_word_tag_tab = {}
|
user_word_tag_tab = {}
|
||||||
initialized = False
|
initialized = False
|
||||||
pool = None
|
pool = None
|
||||||
|
|
||||||
_curpath = os.path.normpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
|
_curpath = os.path.normpath(
|
||||||
|
os.path.join(os.getcwd(), os.path.dirname(__file__)))
|
||||||
|
|
||||||
log_console = logging.StreamHandler(sys.stderr)
|
log_console = logging.StreamHandler(sys.stderr)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
logger.setLevel(logging.DEBUG)
|
logger.setLevel(logging.DEBUG)
|
||||||
logger.addHandler(log_console)
|
logger.addHandler(log_console)
|
||||||
|
|
||||||
|
|
||||||
def setLogLevel(log_level):
|
def setLogLevel(log_level):
|
||||||
global logger
|
global logger
|
||||||
logger.setLevel(log_level)
|
logger.setLevel(log_level)
|
||||||
|
|
||||||
|
|
||||||
def gen_pfdict(f_name):
|
def gen_pfdict(f_name):
|
||||||
lfreq = {}
|
lfreq = {}
|
||||||
pfdict = set()
|
pfdict = set()
|
||||||
@ -50,12 +53,13 @@ def gen_pfdict(f_name):
|
|||||||
lfreq[word] = freq
|
lfreq[word] = freq
|
||||||
ltotal += freq
|
ltotal += freq
|
||||||
for ch in xrange(len(word)):
|
for ch in xrange(len(word)):
|
||||||
pfdict.add(word[:ch+1])
|
pfdict.add(word[:ch + 1])
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.debug('%s at line %s %s' % (f_name, lineno, line))
|
logger.debug('%s at line %s %s' % (f_name, lineno, line))
|
||||||
raise e
|
raise e
|
||||||
return pfdict, lfreq, ltotal
|
return pfdict, lfreq, ltotal
|
||||||
|
|
||||||
|
|
||||||
def initialize(dictionary=None):
|
def initialize(dictionary=None):
|
||||||
global pfdict, FREQ, total, initialized, DICTIONARY, DICT_LOCK
|
global pfdict, FREQ, total, initialized, DICTIONARY, DICT_LOCK
|
||||||
if not dictionary:
|
if not dictionary:
|
||||||
@ -67,10 +71,12 @@ def initialize(dictionary=None):
|
|||||||
abs_path = os.path.join(_curpath, dictionary)
|
abs_path = os.path.join(_curpath, dictionary)
|
||||||
logger.debug("Building prefix dict from %s ..." % abs_path)
|
logger.debug("Building prefix dict from %s ..." % abs_path)
|
||||||
t1 = time.time()
|
t1 = time.time()
|
||||||
if abs_path == os.path.join(_curpath, "dict.txt"): #default dictionary
|
# default dictionary
|
||||||
|
if abs_path == os.path.join(_curpath, "dict.txt"):
|
||||||
cache_file = os.path.join(tempfile.gettempdir(), "jieba.cache")
|
cache_file = os.path.join(tempfile.gettempdir(), "jieba.cache")
|
||||||
else: #custom dictionary
|
else: # custom dictionary
|
||||||
cache_file = os.path.join(tempfile.gettempdir(), "jieba.u%s.cache" % md5(abs_path.encode('utf-8', 'replace')).hexdigest())
|
cache_file = os.path.join(tempfile.gettempdir(), "jieba.u%s.cache" % md5(
|
||||||
|
abs_path.encode('utf-8', 'replace')).hexdigest())
|
||||||
|
|
||||||
load_from_cache_fail = True
|
load_from_cache_fail = True
|
||||||
if os.path.isfile(cache_file) and os.path.getmtime(cache_file) > os.path.getmtime(abs_path):
|
if os.path.isfile(cache_file) and os.path.getmtime(cache_file) > os.path.getmtime(abs_path):
|
||||||
@ -121,22 +127,24 @@ def require_initialized(fn):
|
|||||||
def __cut_all(sentence):
|
def __cut_all(sentence):
|
||||||
dag = get_DAG(sentence)
|
dag = get_DAG(sentence)
|
||||||
old_j = -1
|
old_j = -1
|
||||||
for k,L in iteritems(dag):
|
for k, L in iteritems(dag):
|
||||||
if len(L) == 1 and k > old_j:
|
if len(L) == 1 and k > old_j:
|
||||||
yield sentence[k:L[0]+1]
|
yield sentence[k:L[0] + 1]
|
||||||
old_j = L[0]
|
old_j = L[0]
|
||||||
else:
|
else:
|
||||||
for j in L:
|
for j in L:
|
||||||
if j > k:
|
if j > k:
|
||||||
yield sentence[k:j+1]
|
yield sentence[k:j + 1]
|
||||||
old_j = j
|
old_j = j
|
||||||
|
|
||||||
|
|
||||||
def calc(sentence, DAG, route):
|
def calc(sentence, DAG, route):
|
||||||
N = len(sentence)
|
N = len(sentence)
|
||||||
route[N] = (0.0, '')
|
route[N] = (0.0, '')
|
||||||
for idx in xrange(N-1, -1, -1):
|
for idx in xrange(N - 1, -1, -1):
|
||||||
route[idx] = max((log(FREQ.get(sentence[idx:x+1], 1)) - log(total) + route[x+1][0], x) for x in DAG[idx])
|
route[idx] = max((log(FREQ.get(sentence[idx:x + 1], 1)) -
|
||||||
|
log(total) + route[x + 1][0], x) for x in DAG[idx])
|
||||||
|
|
||||||
|
|
||||||
@require_initialized
|
@require_initialized
|
||||||
def get_DAG(sentence):
|
def get_DAG(sentence):
|
||||||
@ -151,14 +159,16 @@ def get_DAG(sentence):
|
|||||||
if frag in FREQ:
|
if frag in FREQ:
|
||||||
tmplist.append(i)
|
tmplist.append(i)
|
||||||
i += 1
|
i += 1
|
||||||
frag = sentence[k:i+1]
|
frag = sentence[k:i + 1]
|
||||||
if not tmplist:
|
if not tmplist:
|
||||||
tmplist.append(k)
|
tmplist.append(k)
|
||||||
DAG[k] = tmplist
|
DAG[k] = tmplist
|
||||||
return DAG
|
return DAG
|
||||||
|
|
||||||
|
re_eng = re.compile(r'[a-zA-Z0-9]', re.U)
|
||||||
|
|
||||||
|
|
||||||
def __cut_DAG_NO_HMM(sentence):
|
def __cut_DAG_NO_HMM(sentence):
|
||||||
re_eng = re.compile(r'[a-zA-Z0-9]',re.U)
|
|
||||||
DAG = get_DAG(sentence)
|
DAG = get_DAG(sentence)
|
||||||
route = {}
|
route = {}
|
||||||
calc(sentence, DAG, route)
|
calc(sentence, DAG, route)
|
||||||
@ -181,6 +191,7 @@ def __cut_DAG_NO_HMM(sentence):
|
|||||||
yield buf
|
yield buf
|
||||||
buf = ''
|
buf = ''
|
||||||
|
|
||||||
|
|
||||||
def __cut_DAG(sentence):
|
def __cut_DAG(sentence):
|
||||||
DAG = get_DAG(sentence)
|
DAG = get_DAG(sentence)
|
||||||
route = {}
|
route = {}
|
||||||
@ -189,9 +200,9 @@ def __cut_DAG(sentence):
|
|||||||
buf = ''
|
buf = ''
|
||||||
N = len(sentence)
|
N = len(sentence)
|
||||||
while x < N:
|
while x < N:
|
||||||
y = route[x][1]+1
|
y = route[x][1] + 1
|
||||||
l_word = sentence[x:y]
|
l_word = sentence[x:y]
|
||||||
if y-x == 1:
|
if y - x == 1:
|
||||||
buf += l_word
|
buf += l_word
|
||||||
else:
|
else:
|
||||||
if buf:
|
if buf:
|
||||||
@ -221,6 +232,12 @@ def __cut_DAG(sentence):
|
|||||||
for elem in buf:
|
for elem in buf:
|
||||||
yield elem
|
yield elem
|
||||||
|
|
||||||
|
re_han_default = re.compile("([\u4E00-\u9FA5a-zA-Z0-9+#&\._]+)", re.U)
|
||||||
|
re_skip_default = re.compile("(\r\n|\s)", re.U)
|
||||||
|
re_han_cut_all = re.compile("([\u4E00-\u9FA5]+)", re.U)
|
||||||
|
re_skip_cut_all = re.compile("[^a-zA-Z0-9+#\n]", re.U)
|
||||||
|
|
||||||
|
|
||||||
def cut(sentence, cut_all=False, HMM=True):
|
def cut(sentence, cut_all=False, HMM=True):
|
||||||
'''The main function that segments an entire sentence that contains
|
'''The main function that segments an entire sentence that contains
|
||||||
Chinese characters into seperated words.
|
Chinese characters into seperated words.
|
||||||
@ -235,9 +252,11 @@ def cut(sentence, cut_all=False, HMM=True):
|
|||||||
# \r\n|\s : whitespace characters. Will not be handled.
|
# \r\n|\s : whitespace characters. Will not be handled.
|
||||||
|
|
||||||
if cut_all:
|
if cut_all:
|
||||||
re_han, re_skip = re.compile("([\u4E00-\u9FA5]+)", re.U), re.compile("[^a-zA-Z0-9+#\n]", re.U)
|
re_han = re_han_cut_all
|
||||||
|
re_skip = re_skip_cut_all
|
||||||
else:
|
else:
|
||||||
re_han, re_skip = re.compile("([\u4E00-\u9FA5a-zA-Z0-9+#&\._]+)", re.U), re.compile("(\r\n|\s)", re.U)
|
re_han = re_han_default
|
||||||
|
re_skip = re_skip_default
|
||||||
blocks = re_han.split(sentence)
|
blocks = re_han.split(sentence)
|
||||||
if cut_all:
|
if cut_all:
|
||||||
cut_block = __cut_all
|
cut_block = __cut_all
|
||||||
@ -262,21 +281,23 @@ def cut(sentence, cut_all=False, HMM=True):
|
|||||||
else:
|
else:
|
||||||
yield x
|
yield x
|
||||||
|
|
||||||
|
|
||||||
def cut_for_search(sentence, HMM=True):
|
def cut_for_search(sentence, HMM=True):
|
||||||
words = cut(sentence, HMM=HMM)
|
words = cut(sentence, HMM=HMM)
|
||||||
for w in words:
|
for w in words:
|
||||||
if len(w) > 2:
|
if len(w) > 2:
|
||||||
for i in xrange(len(w)-1):
|
for i in xrange(len(w) - 1):
|
||||||
gram2 = w[i:i+2]
|
gram2 = w[i:i + 2]
|
||||||
if gram2 in FREQ:
|
if gram2 in FREQ:
|
||||||
yield gram2
|
yield gram2
|
||||||
if len(w) > 3:
|
if len(w) > 3:
|
||||||
for i in xrange(len(w)-2):
|
for i in xrange(len(w) - 2):
|
||||||
gram3 = w[i:i+3]
|
gram3 = w[i:i + 3]
|
||||||
if gram3 in FREQ:
|
if gram3 in FREQ:
|
||||||
yield gram3
|
yield gram3
|
||||||
yield w
|
yield w
|
||||||
|
|
||||||
|
|
||||||
@require_initialized
|
@require_initialized
|
||||||
def load_userdict(f):
|
def load_userdict(f):
|
||||||
''' Load personalized dict to improve detect rate.
|
''' Load personalized dict to improve detect rate.
|
||||||
@ -300,6 +321,7 @@ def load_userdict(f):
|
|||||||
if tup[1].isdigit():
|
if tup[1].isdigit():
|
||||||
add_word(*tup)
|
add_word(*tup)
|
||||||
|
|
||||||
|
|
||||||
@require_initialized
|
@require_initialized
|
||||||
def add_word(word, freq, tag=None):
|
def add_word(word, freq, tag=None):
|
||||||
global FREQ, pfdict, total, user_word_tag_tab
|
global FREQ, pfdict, total, user_word_tag_tab
|
||||||
@ -309,17 +331,24 @@ def add_word(word, freq, tag=None):
|
|||||||
if tag is not None:
|
if tag is not None:
|
||||||
user_word_tag_tab[word] = tag
|
user_word_tag_tab[word] = tag
|
||||||
for ch in xrange(len(word)):
|
for ch in xrange(len(word)):
|
||||||
pfdict.add(word[:ch+1])
|
pfdict.add(word[:ch + 1])
|
||||||
|
|
||||||
__ref_cut = cut
|
__ref_cut = cut
|
||||||
__ref_cut_for_search = cut_for_search
|
__ref_cut_for_search = cut_for_search
|
||||||
|
|
||||||
|
|
||||||
def __lcut(sentence):
|
def __lcut(sentence):
|
||||||
return list(__ref_cut(sentence, False))
|
return list(__ref_cut(sentence, False))
|
||||||
|
|
||||||
|
|
||||||
def __lcut_no_hmm(sentence):
|
def __lcut_no_hmm(sentence):
|
||||||
return list(__ref_cut(sentence, False, False))
|
return list(__ref_cut(sentence, False, False))
|
||||||
|
|
||||||
|
|
||||||
def __lcut_all(sentence):
|
def __lcut_all(sentence):
|
||||||
return list(__ref_cut(sentence, True))
|
return list(__ref_cut(sentence, True))
|
||||||
|
|
||||||
|
|
||||||
def __lcut_for_search(sentence):
|
def __lcut_for_search(sentence):
|
||||||
return list(__ref_cut_for_search(sentence))
|
return list(__ref_cut_for_search(sentence))
|
||||||
|
|
||||||
@ -334,7 +363,7 @@ def enable_parallel(processnum=None):
|
|||||||
processnum = cpu_count()
|
processnum = cpu_count()
|
||||||
pool = Pool(processnum)
|
pool = Pool(processnum)
|
||||||
|
|
||||||
def pcut(sentence,cut_all=False,HMM=True):
|
def pcut(sentence, cut_all=False, HMM=True):
|
||||||
parts = strdecode(sentence).split('\n')
|
parts = strdecode(sentence).split('\n')
|
||||||
if cut_all:
|
if cut_all:
|
||||||
result = pool.map(__lcut_all, parts)
|
result = pool.map(__lcut_all, parts)
|
||||||
@ -356,6 +385,7 @@ def enable_parallel(processnum=None):
|
|||||||
cut = pcut
|
cut = pcut
|
||||||
cut_for_search = pcut_for_search
|
cut_for_search = pcut_for_search
|
||||||
|
|
||||||
|
|
||||||
def disable_parallel():
|
def disable_parallel():
|
||||||
global pool, cut, cut_for_search
|
global pool, cut, cut_for_search
|
||||||
if pool:
|
if pool:
|
||||||
@ -364,6 +394,7 @@ def disable_parallel():
|
|||||||
cut = __ref_cut
|
cut = __ref_cut
|
||||||
cut_for_search = __ref_cut_for_search
|
cut_for_search = __ref_cut_for_search
|
||||||
|
|
||||||
|
|
||||||
def set_dictionary(dictionary_path):
|
def set_dictionary(dictionary_path):
|
||||||
global initialized, DICTIONARY
|
global initialized, DICTIONARY
|
||||||
with DICT_LOCK:
|
with DICT_LOCK:
|
||||||
@ -373,9 +404,11 @@ def set_dictionary(dictionary_path):
|
|||||||
DICTIONARY = abs_path
|
DICTIONARY = abs_path
|
||||||
initialized = False
|
initialized = False
|
||||||
|
|
||||||
|
|
||||||
def get_abs_path_dict():
|
def get_abs_path_dict():
|
||||||
return os.path.join(_curpath, DICTIONARY)
|
return os.path.join(_curpath, DICTIONARY)
|
||||||
|
|
||||||
|
|
||||||
def tokenize(unicode_sentence, mode="default", HMM=True):
|
def tokenize(unicode_sentence, mode="default", HMM=True):
|
||||||
"""Tokenize a sentence and yields tuples of (word, start, end)
|
"""Tokenize a sentence and yields tuples of (word, start, end)
|
||||||
Parameter:
|
Parameter:
|
||||||
@ -389,20 +422,20 @@ def tokenize(unicode_sentence, mode="default", HMM=True):
|
|||||||
if mode == 'default':
|
if mode == 'default':
|
||||||
for w in cut(unicode_sentence, HMM=HMM):
|
for w in cut(unicode_sentence, HMM=HMM):
|
||||||
width = len(w)
|
width = len(w)
|
||||||
yield (w, start, start+width)
|
yield (w, start, start + width)
|
||||||
start += width
|
start += width
|
||||||
else:
|
else:
|
||||||
for w in cut(unicode_sentence, HMM=HMM):
|
for w in cut(unicode_sentence, HMM=HMM):
|
||||||
width = len(w)
|
width = len(w)
|
||||||
if len(w) > 2:
|
if len(w) > 2:
|
||||||
for i in xrange(len(w)-1):
|
for i in xrange(len(w) - 1):
|
||||||
gram2 = w[i:i+2]
|
gram2 = w[i:i + 2]
|
||||||
if gram2 in FREQ:
|
if gram2 in FREQ:
|
||||||
yield (gram2, start+i, start+i+2)
|
yield (gram2, start + i, start + i + 2)
|
||||||
if len(w) > 3:
|
if len(w) > 3:
|
||||||
for i in xrange(len(w)-2):
|
for i in xrange(len(w) - 2):
|
||||||
gram3 = w[i:i+3]
|
gram3 = w[i:i + 3]
|
||||||
if gram3 in FREQ:
|
if gram3 in FREQ:
|
||||||
yield (gram3, start+i, start+i+3)
|
yield (gram3, start + i, start + i + 3)
|
||||||
yield (w, start, start+width)
|
yield (w, start, start + width)
|
||||||
start += width
|
start += width
|
||||||
|
@ -13,14 +13,16 @@ PROB_EMIT_P = "prob_emit.p"
|
|||||||
|
|
||||||
|
|
||||||
PrevStatus = {
|
PrevStatus = {
|
||||||
'B':('E','S'),
|
'B': 'ES',
|
||||||
'M':('M','B'),
|
'M': 'MB',
|
||||||
'S':('S','E'),
|
'S': 'SE',
|
||||||
'E':('B','M')
|
'E': 'BM'
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def load_model():
|
def load_model():
|
||||||
_curpath = os.path.normpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
|
_curpath = os.path.normpath(
|
||||||
|
os.path.join(os.getcwd(), os.path.dirname(__file__)))
|
||||||
|
|
||||||
start_p = {}
|
start_p = {}
|
||||||
abs_path = os.path.join(_curpath, PROB_START_P)
|
abs_path = os.path.join(_curpath, PROB_START_P)
|
||||||
@ -44,50 +46,55 @@ if sys.platform.startswith("java"):
|
|||||||
else:
|
else:
|
||||||
from .prob_start import P as start_P
|
from .prob_start import P as start_P
|
||||||
from .prob_trans import P as trans_P
|
from .prob_trans import P as trans_P
|
||||||
from .prob_emit import P as emit_P
|
from .prob_emit import P as emit_P
|
||||||
|
|
||||||
|
|
||||||
def viterbi(obs, states, start_p, trans_p, emit_p):
|
def viterbi(obs, states, start_p, trans_p, emit_p):
|
||||||
V = [{}] #tabular
|
V = [{}] # tabular
|
||||||
path = {}
|
path = {}
|
||||||
for y in states: #init
|
for y in states: # init
|
||||||
V[0][y] = start_p[y] + emit_p[y].get(obs[0],MIN_FLOAT)
|
V[0][y] = start_p[y] + emit_p[y].get(obs[0], MIN_FLOAT)
|
||||||
path[y] = [y]
|
path[y] = [y]
|
||||||
for t in xrange(1, len(obs)):
|
for t in xrange(1, len(obs)):
|
||||||
V.append({})
|
V.append({})
|
||||||
newpath = {}
|
newpath = {}
|
||||||
for y in states:
|
for y in states:
|
||||||
em_p = emit_p[y].get(obs[t],MIN_FLOAT)
|
em_p = emit_p[y].get(obs[t], MIN_FLOAT)
|
||||||
(prob,state ) = max([(V[t-1][y0] + trans_p[y0].get(y, MIN_FLOAT) + em_p, y0) for y0 in PrevStatus[y]])
|
(prob, state) = max(
|
||||||
|
[(V[t - 1][y0] + trans_p[y0].get(y, MIN_FLOAT) + em_p, y0) for y0 in PrevStatus[y]])
|
||||||
V[t][y] = prob
|
V[t][y] = prob
|
||||||
newpath[y] = path[state] + [y]
|
newpath[y] = path[state] + [y]
|
||||||
path = newpath
|
path = newpath
|
||||||
|
|
||||||
(prob, state) = max([(V[len(obs)-1][y], y) for y in ('E','S')])
|
(prob, state) = max((V[len(obs) - 1][y], y) for y in 'ES')
|
||||||
|
|
||||||
return (prob, path[state])
|
return (prob, path[state])
|
||||||
|
|
||||||
|
|
||||||
def __cut(sentence):
|
def __cut(sentence):
|
||||||
global emit_P
|
global emit_P
|
||||||
prob, pos_list = viterbi(sentence, ('B','M','E','S'), start_P, trans_P, emit_P)
|
prob, pos_list = viterbi(sentence, 'BMES', start_P, trans_P, emit_P)
|
||||||
begin, nexti = 0, 0
|
begin, nexti = 0, 0
|
||||||
#print pos_list, sentence
|
# print pos_list, sentence
|
||||||
for i,char in enumerate(sentence):
|
for i, char in enumerate(sentence):
|
||||||
pos = pos_list[i]
|
pos = pos_list[i]
|
||||||
if pos == 'B':
|
if pos == 'B':
|
||||||
begin = i
|
begin = i
|
||||||
elif pos == 'E':
|
elif pos == 'E':
|
||||||
yield sentence[begin:i+1]
|
yield sentence[begin:i + 1]
|
||||||
nexti = i+1
|
nexti = i + 1
|
||||||
elif pos == 'S':
|
elif pos == 'S':
|
||||||
yield char
|
yield char
|
||||||
nexti = i+1
|
nexti = i + 1
|
||||||
if nexti < len(sentence):
|
if nexti < len(sentence):
|
||||||
yield sentence[nexti:]
|
yield sentence[nexti:]
|
||||||
|
|
||||||
|
re_han = re.compile("([\u4E00-\u9FA5]+)")
|
||||||
|
re_skip = re.compile("(\d+\.\d+|[a-zA-Z0-9]+)")
|
||||||
|
|
||||||
|
|
||||||
def cut(sentence):
|
def cut(sentence):
|
||||||
sentence = strdecode(sentence)
|
sentence = strdecode(sentence)
|
||||||
re_han, re_skip = re.compile("([\u4E00-\u9FA5]+)"), re.compile("(\d+\.\d+|[a-zA-Z0-9]+)")
|
|
||||||
blocks = re_han.split(sentence)
|
blocks = re_han.split(sentence)
|
||||||
for blk in blocks:
|
for blk in blocks:
|
||||||
if re_han.match(blk):
|
if re_han.match(blk):
|
||||||
|
@ -13,8 +13,20 @@ PROB_TRANS_P = "prob_trans.p"
|
|||||||
PROB_EMIT_P = "prob_emit.p"
|
PROB_EMIT_P = "prob_emit.p"
|
||||||
CHAR_STATE_TAB_P = "char_state_tab.p"
|
CHAR_STATE_TAB_P = "char_state_tab.p"
|
||||||
|
|
||||||
|
re_han_detail = re.compile("([\u4E00-\u9FA5]+)")
|
||||||
|
re_skip_detail = re.compile("([\.0-9]+|[a-zA-Z0-9]+)")
|
||||||
|
re_han_internal = re.compile("([\u4E00-\u9FA5a-zA-Z0-9+#&\._]+)")
|
||||||
|
re_skip_internal = re.compile("(\r\n|\s)")
|
||||||
|
|
||||||
|
re_eng = re.compile("[a-zA-Z0-9]+")
|
||||||
|
re_num = re.compile("[\.0-9]+")
|
||||||
|
|
||||||
|
re_eng1 = re.compile('^[a-zA-Z0-9]$', re.U)
|
||||||
|
|
||||||
|
|
||||||
def load_model(f_name, isJython=True):
|
def load_model(f_name, isJython=True):
|
||||||
_curpath = os.path.normpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
|
_curpath = os.path.normpath(
|
||||||
|
os.path.join(os.getcwd(), os.path.dirname(__file__)))
|
||||||
|
|
||||||
result = {}
|
result = {}
|
||||||
with open(f_name, "rb") as f:
|
with open(f_name, "rb") as f:
|
||||||
@ -53,28 +65,32 @@ def load_model(f_name, isJython=True):
|
|||||||
return state, start_p, trans_p, emit_p, result
|
return state, start_p, trans_p, emit_p, result
|
||||||
|
|
||||||
if sys.platform.startswith("java"):
|
if sys.platform.startswith("java"):
|
||||||
char_state_tab_P, start_P, trans_P, emit_P, word_tag_tab = load_model(jieba.get_abs_path_dict())
|
char_state_tab_P, start_P, trans_P, emit_P, word_tag_tab = load_model(
|
||||||
|
jieba.get_abs_path_dict())
|
||||||
else:
|
else:
|
||||||
from .char_state_tab import P as char_state_tab_P
|
from .char_state_tab import P as char_state_tab_P
|
||||||
from .prob_start import P as start_P
|
from .prob_start import P as start_P
|
||||||
from .prob_trans import P as trans_P
|
from .prob_trans import P as trans_P
|
||||||
from .prob_emit import P as emit_P
|
from .prob_emit import P as emit_P
|
||||||
|
|
||||||
word_tag_tab = load_model(jieba.get_abs_path_dict(), isJython=False)
|
word_tag_tab = load_model(jieba.get_abs_path_dict(), isJython=False)
|
||||||
|
|
||||||
|
|
||||||
def makesure_userdict_loaded(fn):
|
def makesure_userdict_loaded(fn):
|
||||||
|
|
||||||
@wraps(fn)
|
@wraps(fn)
|
||||||
def wrapped(*args,**kwargs):
|
def wrapped(*args, **kwargs):
|
||||||
if jieba.user_word_tag_tab:
|
if jieba.user_word_tag_tab:
|
||||||
word_tag_tab.update(jieba.user_word_tag_tab)
|
word_tag_tab.update(jieba.user_word_tag_tab)
|
||||||
jieba.user_word_tag_tab = {}
|
jieba.user_word_tag_tab = {}
|
||||||
return fn(*args,**kwargs)
|
return fn(*args, **kwargs)
|
||||||
|
|
||||||
return wrapped
|
return wrapped
|
||||||
|
|
||||||
|
|
||||||
class pair(object):
|
class pair(object):
|
||||||
def __init__(self,word,flag):
|
|
||||||
|
def __init__(self, word, flag):
|
||||||
self.word = word
|
self.word = word
|
||||||
self.flag = flag
|
self.flag = flag
|
||||||
|
|
||||||
@ -90,36 +106,37 @@ class pair(object):
|
|||||||
else:
|
else:
|
||||||
return self.__unicode__()
|
return self.__unicode__()
|
||||||
|
|
||||||
def encode(self,arg):
|
def encode(self, arg):
|
||||||
return self.__unicode__().encode(arg)
|
return self.__unicode__().encode(arg)
|
||||||
|
|
||||||
|
|
||||||
def __cut(sentence):
|
def __cut(sentence):
|
||||||
prob, pos_list = viterbi(sentence, char_state_tab_P, start_P, trans_P, emit_P)
|
prob, pos_list = viterbi(
|
||||||
|
sentence, char_state_tab_P, start_P, trans_P, emit_P)
|
||||||
begin, nexti = 0, 0
|
begin, nexti = 0, 0
|
||||||
|
|
||||||
for i,char in enumerate(sentence):
|
for i, char in enumerate(sentence):
|
||||||
pos = pos_list[i][0]
|
pos = pos_list[i][0]
|
||||||
if pos == 'B':
|
if pos == 'B':
|
||||||
begin = i
|
begin = i
|
||||||
elif pos == 'E':
|
elif pos == 'E':
|
||||||
yield pair(sentence[begin:i+1], pos_list[i][1])
|
yield pair(sentence[begin:i + 1], pos_list[i][1])
|
||||||
nexti = i+1
|
nexti = i + 1
|
||||||
elif pos == 'S':
|
elif pos == 'S':
|
||||||
yield pair(char, pos_list[i][1])
|
yield pair(char, pos_list[i][1])
|
||||||
nexti = i+1
|
nexti = i + 1
|
||||||
if nexti < len(sentence):
|
if nexti < len(sentence):
|
||||||
yield pair(sentence[nexti:], pos_list[nexti][1])
|
yield pair(sentence[nexti:], pos_list[nexti][1])
|
||||||
|
|
||||||
|
|
||||||
def __cut_detail(sentence):
|
def __cut_detail(sentence):
|
||||||
re_han, re_skip = re.compile("([\u4E00-\u9FA5]+)"), re.compile("([\.0-9]+|[a-zA-Z0-9]+)")
|
blocks = re_han_detail.split(sentence)
|
||||||
re_eng, re_num = re.compile("[a-zA-Z0-9]+"), re.compile("[\.0-9]+")
|
|
||||||
blocks = re_han.split(sentence)
|
|
||||||
for blk in blocks:
|
for blk in blocks:
|
||||||
if re_han.match(blk):
|
if re_han_detail.match(blk):
|
||||||
for word in __cut(blk):
|
for word in __cut(blk):
|
||||||
yield word
|
yield word
|
||||||
else:
|
else:
|
||||||
tmp = re_skip.split(blk)
|
tmp = re_skip_detail.split(blk)
|
||||||
for x in tmp:
|
for x in tmp:
|
||||||
if x:
|
if x:
|
||||||
if re_num.match(x):
|
if re_num.match(x):
|
||||||
@ -129,6 +146,7 @@ def __cut_detail(sentence):
|
|||||||
else:
|
else:
|
||||||
yield pair(x, 'x')
|
yield pair(x, 'x')
|
||||||
|
|
||||||
|
|
||||||
def __cut_DAG_NO_HMM(sentence):
|
def __cut_DAG_NO_HMM(sentence):
|
||||||
DAG = jieba.get_DAG(sentence)
|
DAG = jieba.get_DAG(sentence)
|
||||||
route = {}
|
route = {}
|
||||||
@ -136,23 +154,23 @@ def __cut_DAG_NO_HMM(sentence):
|
|||||||
x = 0
|
x = 0
|
||||||
N = len(sentence)
|
N = len(sentence)
|
||||||
buf = ''
|
buf = ''
|
||||||
re_eng = re.compile('[a-zA-Z0-9]',re.U)
|
|
||||||
while x < N:
|
while x < N:
|
||||||
y = route[x][1]+1
|
y = route[x][1] + 1
|
||||||
l_word = sentence[x:y]
|
l_word = sentence[x:y]
|
||||||
if re_eng.match(l_word) and len(l_word) == 1:
|
if re_eng1.match(l_word):
|
||||||
buf += l_word
|
buf += l_word
|
||||||
x = y
|
x = y
|
||||||
else:
|
else:
|
||||||
if buf:
|
if buf:
|
||||||
yield pair(buf,'eng')
|
yield pair(buf, 'eng')
|
||||||
buf = ''
|
buf = ''
|
||||||
yield pair(l_word, word_tag_tab.get(l_word, 'x'))
|
yield pair(l_word, word_tag_tab.get(l_word, 'x'))
|
||||||
x = y
|
x = y
|
||||||
if buf:
|
if buf:
|
||||||
yield pair(buf,'eng')
|
yield pair(buf, 'eng')
|
||||||
buf = ''
|
buf = ''
|
||||||
|
|
||||||
|
|
||||||
def __cut_DAG(sentence):
|
def __cut_DAG(sentence):
|
||||||
DAG = jieba.get_DAG(sentence)
|
DAG = jieba.get_DAG(sentence)
|
||||||
route = {}
|
route = {}
|
||||||
@ -163,9 +181,9 @@ def __cut_DAG(sentence):
|
|||||||
buf = ''
|
buf = ''
|
||||||
N = len(sentence)
|
N = len(sentence)
|
||||||
while x < N:
|
while x < N:
|
||||||
y = route[x][1]+1
|
y = route[x][1] + 1
|
||||||
l_word = sentence[x:y]
|
l_word = sentence[x:y]
|
||||||
if y-x == 1:
|
if y - x == 1:
|
||||||
buf += l_word
|
buf += l_word
|
||||||
else:
|
else:
|
||||||
if buf:
|
if buf:
|
||||||
@ -193,24 +211,23 @@ def __cut_DAG(sentence):
|
|||||||
for elem in buf:
|
for elem in buf:
|
||||||
yield pair(elem, word_tag_tab.get(elem, 'x'))
|
yield pair(elem, word_tag_tab.get(elem, 'x'))
|
||||||
|
|
||||||
|
|
||||||
def __cut_internal(sentence, HMM=True):
|
def __cut_internal(sentence, HMM=True):
|
||||||
sentence = strdecode(sentence)
|
sentence = strdecode(sentence)
|
||||||
re_han, re_skip = re.compile("([\u4E00-\u9FA5a-zA-Z0-9+#&\._]+)"), re.compile("(\r\n|\s)")
|
blocks = re_han_internal.split(sentence)
|
||||||
re_eng, re_num = re.compile("[a-zA-Z0-9]+"), re.compile("[\.0-9]+")
|
|
||||||
blocks = re_han.split(sentence)
|
|
||||||
if HMM:
|
if HMM:
|
||||||
__cut_blk = __cut_DAG
|
__cut_blk = __cut_DAG
|
||||||
else:
|
else:
|
||||||
__cut_blk = __cut_DAG_NO_HMM
|
__cut_blk = __cut_DAG_NO_HMM
|
||||||
|
|
||||||
for blk in blocks:
|
for blk in blocks:
|
||||||
if re_han.match(blk):
|
if re_han_internal.match(blk):
|
||||||
for word in __cut_blk(blk):
|
for word in __cut_blk(blk):
|
||||||
yield word
|
yield word
|
||||||
else:
|
else:
|
||||||
tmp = re_skip.split(blk)
|
tmp = re_skip_internal.split(blk)
|
||||||
for x in tmp:
|
for x in tmp:
|
||||||
if re_skip.match(x):
|
if re_skip_internal.match(x):
|
||||||
yield pair(x, 'x')
|
yield pair(x, 'x')
|
||||||
else:
|
else:
|
||||||
for xx in x:
|
for xx in x:
|
||||||
@ -221,10 +238,13 @@ def __cut_internal(sentence, HMM=True):
|
|||||||
else:
|
else:
|
||||||
yield pair(xx, 'x')
|
yield pair(xx, 'x')
|
||||||
|
|
||||||
|
|
||||||
def __lcut_internal(sentence):
|
def __lcut_internal(sentence):
|
||||||
return list(__cut_internal(sentence))
|
return list(__cut_internal(sentence))
|
||||||
|
|
||||||
|
|
||||||
def __lcut_internal_no_hmm(sentence):
|
def __lcut_internal_no_hmm(sentence):
|
||||||
return list(__cut_internal(sentence,False))
|
return list(__cut_internal(sentence, False))
|
||||||
|
|
||||||
|
|
||||||
@makesure_userdict_loaded
|
@makesure_userdict_loaded
|
||||||
@ -241,4 +261,3 @@ def cut(sentence, HMM=True):
|
|||||||
for r in result:
|
for r in result:
|
||||||
for w in r:
|
for w in r:
|
||||||
yield w
|
yield w
|
||||||
|
|
||||||
|
@ -6,36 +6,42 @@ MIN_INF = float("-inf")
|
|||||||
if sys.version_info[0] > 2:
|
if sys.version_info[0] > 2:
|
||||||
xrange = range
|
xrange = range
|
||||||
|
|
||||||
|
|
||||||
def get_top_states(t_state_v, K=4):
|
def get_top_states(t_state_v, K=4):
|
||||||
return sorted(t_state_v, key=t_state_v.__getitem__, reverse=True)[:K]
|
return sorted(t_state_v, key=t_state_v.__getitem__, reverse=True)[:K]
|
||||||
|
|
||||||
|
|
||||||
def viterbi(obs, states, start_p, trans_p, emit_p):
|
def viterbi(obs, states, start_p, trans_p, emit_p):
|
||||||
V = [{}] #tabular
|
V = [{}] # tabular
|
||||||
mem_path = [{}]
|
mem_path = [{}]
|
||||||
all_states = trans_p.keys()
|
all_states = trans_p.keys()
|
||||||
for y in states.get(obs[0], all_states): #init
|
for y in states.get(obs[0], all_states): # init
|
||||||
V[0][y] = start_p[y] + emit_p[y].get(obs[0], MIN_FLOAT)
|
V[0][y] = start_p[y] + emit_p[y].get(obs[0], MIN_FLOAT)
|
||||||
mem_path[0][y] = ''
|
mem_path[0][y] = ''
|
||||||
for t in xrange(1, len(obs)):
|
for t in xrange(1, len(obs)):
|
||||||
V.append({})
|
V.append({})
|
||||||
mem_path.append({})
|
mem_path.append({})
|
||||||
#prev_states = get_top_states(V[t-1])
|
#prev_states = get_top_states(V[t-1])
|
||||||
prev_states = [x for x in mem_path[t-1].keys() if len(trans_p[x]) > 0]
|
prev_states = [
|
||||||
|
x for x in mem_path[t - 1].keys() if len(trans_p[x]) > 0]
|
||||||
|
|
||||||
prev_states_expect_next = set((y for x in prev_states for y in trans_p[x].keys()))
|
prev_states_expect_next = set(
|
||||||
obs_states = set(states.get(obs[t], all_states)) & prev_states_expect_next
|
(y for x in prev_states for y in trans_p[x].keys()))
|
||||||
|
obs_states = set(
|
||||||
|
states.get(obs[t], all_states)) & prev_states_expect_next
|
||||||
|
|
||||||
if not obs_states:
|
if not obs_states:
|
||||||
obs_states = prev_states_expect_next if prev_states_expect_next else all_states
|
obs_states = prev_states_expect_next if prev_states_expect_next else all_states
|
||||||
|
|
||||||
for y in obs_states:
|
for y in obs_states:
|
||||||
prob, state = max((V[t-1][y0] + trans_p[y0].get(y,MIN_INF) + emit_p[y].get(obs[t],MIN_FLOAT), y0) for y0 in prev_states)
|
prob, state = max((V[t - 1][y0] + trans_p[y0].get(y, MIN_INF) +
|
||||||
|
emit_p[y].get(obs[t], MIN_FLOAT), y0) for y0 in prev_states)
|
||||||
V[t][y] = prob
|
V[t][y] = prob
|
||||||
mem_path[t][y] = state
|
mem_path[t][y] = state
|
||||||
|
|
||||||
last = [(V[-1][y], y) for y in mem_path[-1].keys()]
|
last = [(V[-1][y], y) for y in mem_path[-1].keys()]
|
||||||
#if len(last)==0:
|
# if len(last)==0:
|
||||||
#print obs
|
# print obs
|
||||||
prob, state = max(last)
|
prob, state = max(last)
|
||||||
|
|
||||||
route = [None] * len(obs)
|
route = [None] * len(obs)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user