From bc049090a5eeb7e00d60c3ae014646442921777d Mon Sep 17 00:00:00 2001 From: fxsjy Date: Fri, 26 Apr 2013 12:54:05 +0800 Subject: [PATCH] make lazy load thread safe --- jieba/__init__.py | 75 +++++++++++++++++++++------------------- test/test_file.py | 2 +- test/test_multithread.py | 29 ++++++++++++++++ 3 files changed, 70 insertions(+), 36 deletions(-) create mode 100644 test/test_multithread.py diff --git a/jieba/__init__.py b/jieba/__init__.py index f814b7a..a2c866e 100644 --- a/jieba/__init__.py +++ b/jieba/__init__.py @@ -9,9 +9,10 @@ import tempfile import marshal from math import log import random +import threading DICTIONARY = "dict.txt" - +DICT_LOCK = threading.RLock() trie = None # to be initialized FREQ = {} min_freq = 0.0 @@ -45,38 +46,44 @@ def gen_trie(f_name): def initialize(dictionary=DICTIONARY): global trie, FREQ, total, min_freq, initialized - _curpath=os.path.normpath( os.path.join( os.getcwd(), os.path.dirname(__file__) ) ) + with DICT_LOCK: + if initialized: + return + if trie: + del trie + trie = None + _curpath=os.path.normpath( os.path.join( os.getcwd(), os.path.dirname(__file__) ) ) - print >> sys.stderr, "Building Trie..." - t1 = time.time() - cache_file = os.path.join(tempfile.gettempdir(),"jieba.cache") - load_from_cache_fail = True - if os.path.exists(cache_file) and os.path.getmtime(cache_file)>os.path.getmtime(os.path.join(_curpath,"dict.txt")): - print >> sys.stderr, "loading model from cache" - try: - trie,FREQ,total,min_freq = marshal.load(open(cache_file,'rb')) - load_from_cache_fail = False - except: - load_from_cache_fail = True + print >> sys.stderr, "Building Trie..." + t1 = time.time() + cache_file = os.path.join(tempfile.gettempdir(),"jieba.cache") + load_from_cache_fail = True + if os.path.exists(cache_file) and os.path.getmtime(cache_file)>os.path.getmtime(os.path.join(_curpath,dictionary)): + print >> sys.stderr, "loading model from cache" + try: + trie,FREQ,total,min_freq = marshal.load(open(cache_file,'rb')) + load_from_cache_fail = False + except: + load_from_cache_fail = True - if load_from_cache_fail: - trie,FREQ,total = gen_trie(os.path.join(_curpath, dictionary)) - FREQ = dict([(k,log(float(v)/total)) for k,v in FREQ.iteritems()]) #normalize - min_freq = min(FREQ.itervalues()) - print >> sys.stderr, "dumping model to file cache" - tmp_suffix = "."+str(random.random()) - marshal.dump((trie,FREQ,total,min_freq),open(cache_file+tmp_suffix,'wb')) - if os.name=='nt': - import shutil - replace_file = shutil.move - else: - replace_file = os.rename - replace_file(cache_file+tmp_suffix,cache_file) + if load_from_cache_fail: + trie,FREQ,total = gen_trie(os.path.join(_curpath, dictionary)) + FREQ = dict([(k,log(float(v)/total)) for k,v in FREQ.iteritems()]) #normalize + min_freq = min(FREQ.itervalues()) + print >> sys.stderr, "dumping model to file cache" + tmp_suffix = "."+str(random.random()) + marshal.dump((trie,FREQ,total,min_freq),open(cache_file+tmp_suffix,'wb')) + if os.name=='nt': + import shutil + replace_file = shutil.move + else: + replace_file = os.rename + replace_file(cache_file+tmp_suffix,cache_file) - initialized = True + initialized = True - print >> sys.stderr, "loading model cost ", time.time() - t1, "seconds." - print >> sys.stderr, "Trie has been built succesfully." + print >> sys.stderr, "loading model cost ", time.time() - t1, "seconds." + print >> sys.stderr, "Trie has been built succesfully." def require_initialized(fn): @@ -111,7 +118,6 @@ def calc(sentence,DAG,idx,route): candidates = [ ( FREQ.get(sentence[idx:x+1],min_freq) + route[x+1][0],x ) for x in DAG[idx] ] route[idx] = max(candidates) - @require_initialized def get_DAG(sentence): N = len(sentence) @@ -173,7 +179,6 @@ def __cut_DAG(sentence): regognized = finalseg.cut(buf) for t in regognized: yield t - def cut(sentence,cut_all=False): if not ( type(sentence) is unicode): try: @@ -201,7 +206,6 @@ def cut(sentence,cut_all=False): else: for xx in x: yield xx - def cut_for_search(sentence): words = cut(sentence) for w in words: @@ -252,6 +256,7 @@ def __lcut_all(sentence): def __lcut_for_search(sentence): return list(__ref_cut_for_search(sentence)) +@require_initialized def enable_parallel(processnum): global pool,cut,cut_for_search if os.name=='nt': @@ -290,6 +295,6 @@ def disable_parallel(): def set_dictionary(dictionary_path): global initialized, DICTIONARY - DICTIONARY = dictionary_path - if initialized: - initialize() + with DICT_LOCK: + DICTIONARY = dictionary_path + initialized = False diff --git a/test/test_file.py b/test/test_file.py index 2107c36..fe2d93a 100644 --- a/test/test_file.py +++ b/test/test_file.py @@ -15,6 +15,6 @@ tm_cost = t2-t1 log_f = open("1.log","wb") for w in words: print >> log_f, w.encode("gbk"), "/" , - +print 'cost',tm_cost print 'speed' , len(content)/tm_cost, " bytes/second" diff --git a/test/test_multithread.py b/test/test_multithread.py new file mode 100644 index 0000000..e54310a --- /dev/null +++ b/test/test_multithread.py @@ -0,0 +1,29 @@ +#encoding=utf-8 +import sys +import threading +sys.path.append("../") + +import jieba + +class Worker(threading.Thread): + def run(self): + seg_list = jieba.cut("我来到北京清华大学",cut_all=True) + print "Full Mode:" + "/ ".join(seg_list) #全模式 + + seg_list = jieba.cut("我来到北京清华大学",cut_all=False) + print "Default Mode:" + "/ ".join(seg_list) #默认模式 + + seg_list = jieba.cut("他来到了网易杭研大厦") + print ", ".join(seg_list) + + seg_list = jieba.cut_for_search("小明硕士毕业于中国科学院计算所,后在日本京都大学深造") #搜索引擎模式 + print ", ".join(seg_list) +workers = [] +for i in xrange(10): + worker = Worker() + workers.append(worker) + worker.start() + +for worker in workers: + worker.join() +