mirror of https://github.com/fxsjy/jieba.git
make lazy load thread safe
commit bc049090a5 (parent: d2460029d5)
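The change below makes jieba's lazy model build thread safe: a module-level threading.RLock guards initialize(), and the initialized flag is re-checked inside the locked section, so several threads calling cut() for the first time build the trie exactly once. A minimal sketch of that pattern, with illustrative names (_lock, _model, get_model are not jieba's API):

    import threading

    _lock = threading.RLock()   # re-entrant, so code already holding it may call initialize() again
    _initialized = False
    _model = None               # stands in for jieba's trie/FREQ globals

    def initialize():
        # Build the shared state exactly once, even with many concurrent callers.
        global _initialized, _model
        with _lock:
            if _initialized:    # another thread may have finished while we waited for the lock
                return
            _model = {"ready": True}   # placeholder for the expensive dictionary/trie load
            _initialized = True

    def get_model():
        if not _initialized:    # cheap unlocked check; initialize() re-checks under the lock
            initialize()
        return _model

Only the first caller pays for the build; later callers either skip initialize() entirely or return as soon as they see the flag already set under the lock.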
jieba/__init__.py

@@ -9,9 +9,10 @@ import tempfile
 import marshal
 from math import log
 import random
+import threading

 DICTIONARY = "dict.txt"
-
+DICT_LOCK = threading.RLock()
 trie = None # to be initialized
 FREQ = {}
 min_freq = 0.0
@@ -45,38 +46,44 @@ def gen_trie(f_name):

 def initialize(dictionary=DICTIONARY):
     global trie, FREQ, total, min_freq, initialized
-    _curpath=os.path.normpath( os.path.join( os.getcwd(), os.path.dirname(__file__) ) )
+    with DICT_LOCK:
+        if initialized:
+            return
+        if trie:
+            del trie
+            trie = None
+        _curpath=os.path.normpath( os.path.join( os.getcwd(), os.path.dirname(__file__) ) )

-    print >> sys.stderr, "Building Trie..."
-    t1 = time.time()
-    cache_file = os.path.join(tempfile.gettempdir(),"jieba.cache")
-    load_from_cache_fail = True
-    if os.path.exists(cache_file) and os.path.getmtime(cache_file)>os.path.getmtime(os.path.join(_curpath,"dict.txt")):
-        print >> sys.stderr, "loading model from cache"
-        try:
-            trie,FREQ,total,min_freq = marshal.load(open(cache_file,'rb'))
-            load_from_cache_fail = False
-        except:
-            load_from_cache_fail = True
+        print >> sys.stderr, "Building Trie..."
+        t1 = time.time()
+        cache_file = os.path.join(tempfile.gettempdir(),"jieba.cache")
+        load_from_cache_fail = True
+        if os.path.exists(cache_file) and os.path.getmtime(cache_file)>os.path.getmtime(os.path.join(_curpath,dictionary)):
+            print >> sys.stderr, "loading model from cache"
+            try:
+                trie,FREQ,total,min_freq = marshal.load(open(cache_file,'rb'))
+                load_from_cache_fail = False
+            except:
+                load_from_cache_fail = True

-    if load_from_cache_fail:
-        trie,FREQ,total = gen_trie(os.path.join(_curpath, dictionary))
-        FREQ = dict([(k,log(float(v)/total)) for k,v in FREQ.iteritems()]) #normalize
-        min_freq = min(FREQ.itervalues())
-        print >> sys.stderr, "dumping model to file cache"
-        tmp_suffix = "."+str(random.random())
-        marshal.dump((trie,FREQ,total,min_freq),open(cache_file+tmp_suffix,'wb'))
-        if os.name=='nt':
-            import shutil
-            replace_file = shutil.move
-        else:
-            replace_file = os.rename
-        replace_file(cache_file+tmp_suffix,cache_file)
+        if load_from_cache_fail:
+            trie,FREQ,total = gen_trie(os.path.join(_curpath, dictionary))
+            FREQ = dict([(k,log(float(v)/total)) for k,v in FREQ.iteritems()]) #normalize
+            min_freq = min(FREQ.itervalues())
+            print >> sys.stderr, "dumping model to file cache"
+            tmp_suffix = "."+str(random.random())
+            marshal.dump((trie,FREQ,total,min_freq),open(cache_file+tmp_suffix,'wb'))
+            if os.name=='nt':
+                import shutil
+                replace_file = shutil.move
+            else:
+                replace_file = os.rename
+            replace_file(cache_file+tmp_suffix,cache_file)

-    initialized = True
+        initialized = True

-    print >> sys.stderr, "loading model cost ", time.time() - t1, "seconds."
-    print >> sys.stderr, "Trie has been built succesfully."
+        print >> sys.stderr, "loading model cost ", time.time() - t1, "seconds."
+        print >> sys.stderr, "Trie has been built succesfully."


 def require_initialized(fn):
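A detail worth noting in the hunk above: the cache dump never writes jieba.cache in place. The model is marshalled to a randomly suffixed sibling file which is then moved over the real cache path (on POSIX the final os.rename is an atomic replace; on Windows shutil.move is used because os.rename there refuses to overwrite an existing file). Restated as a standalone helper with an illustrative name and signature:

    import os
    import marshal
    import random

    def dump_cache(obj, cache_file):
        # Write the full cache next to its final location first, then move it over
        # the shared path, so the slow marshal.dump never runs directly on cache_file.
        tmp_path = cache_file + "." + str(random.random())
        with open(tmp_path, 'wb') as f:
            marshal.dump(obj, f)
        if os.name == 'nt':
            import shutil
            replace_file = shutil.move   # os.rename cannot overwrite an existing file on Windows
        else:
            replace_file = os.rename
        replace_file(tmp_path, cache_file)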
@@ -111,7 +118,6 @@ def calc(sentence,DAG,idx,route):
     candidates = [ ( FREQ.get(sentence[idx:x+1],min_freq) + route[x+1][0],x ) for x in DAG[idx] ]
     route[idx] = max(candidates)

-
 @require_initialized
 def get_DAG(sentence):
     N = len(sentence)
@@ -173,7 +179,6 @@ def __cut_DAG(sentence):
             regognized = finalseg.cut(buf)
             for t in regognized:
                 yield t
-
 def cut(sentence,cut_all=False):
     if not ( type(sentence) is unicode):
         try:
@@ -201,7 +206,6 @@ def cut(sentence,cut_all=False):
         else:
             for xx in x:
                 yield xx
-
 def cut_for_search(sentence):
     words = cut(sentence)
     for w in words:
@@ -252,6 +256,7 @@ def __lcut_all(sentence):
 def __lcut_for_search(sentence):
     return list(__ref_cut_for_search(sentence))

+@require_initialized
 def enable_parallel(processnum):
     global pool,cut,cut_for_search
     if os.name=='nt':
@@ -290,6 +295,6 @@ def disable_parallel():

 def set_dictionary(dictionary_path):
     global initialized, DICTIONARY
-    DICTIONARY = dictionary_path
-    if initialized:
-        initialize()
+    with DICT_LOCK:
+        DICTIONARY = dictionary_path
+        initialized = False
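With this hunk, set_dictionary() no longer rebuilds the model eagerly. It only records the new path and clears the initialized flag, both under DICT_LOCK, so the next call that needs the model (anything behind @require_initialized) runs initialize() again with the new dictionary. A small usage sketch; "big_dict.txt" is a hypothetical path:

    import jieba

    jieba.set_dictionary("big_dict.txt")            # cheap: just swaps the path and resets the flag
    words = list(jieba.cut("我来到北京清华大学"))    # the first cut() afterwards rebuilds the trie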
test/test_file.py

@@ -15,6 +15,6 @@ tm_cost = t2-t1
 log_f = open("1.log","wb")
 for w in words:
     print >> log_f, w.encode("gbk"), "/" ,

 print 'cost',tm_cost
 print 'speed' , len(content)/tm_cost, " bytes/second"
test/test_multithread.py (new file, 29 lines)

@@ -0,0 +1,29 @@
+#encoding=utf-8
+import sys
+import threading
+sys.path.append("../")
+
+import jieba
+
+class Worker(threading.Thread):
+    def run(self):
+        seg_list = jieba.cut("我来到北京清华大学",cut_all=True)
+        print "Full Mode:" + "/ ".join(seg_list) #全模式
+
+        seg_list = jieba.cut("我来到北京清华大学",cut_all=False)
+        print "Default Mode:" + "/ ".join(seg_list) #默认模式
+
+        seg_list = jieba.cut("他来到了网易杭研大厦")
+        print ", ".join(seg_list)
+
+        seg_list = jieba.cut_for_search("小明硕士毕业于中国科学院计算所,后在日本京都大学深造") #搜索引擎模式
+        print ", ".join(seg_list)
+workers = []
+for i in xrange(10):
+    worker = Worker()
+    workers.append(worker)
+    worker.start()
+
+for worker in workers:
+    worker.join()
+
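The new test starts ten Worker threads that call jieba.cut and jieba.cut_for_search at the same time. With DICT_LOCK in place, whichever thread reaches initialize() first builds the trie while the others block on the lock, see the flag already set, and reuse the shared model; without the lock, concurrent first calls could race on the half-built globals. The script is meant to be run with Python 2 from the test directory, since it does sys.path.append("../") to pick up the in-tree jieba package.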