refactor: improvement check_paddle_installed (#806)

vissssa authored on 2020-01-09 19:23:11 +08:00, committed by Sun Junyi
parent 0868c323d9
commit dc2b788eb3
3 changed files with 49 additions and 63 deletions
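In short: the old check_paddle_install() helper in jieba/_compat.py re-imported paddle and re-ran the version check on every cut(..., use_paddle=True) call. This commit deletes that helper in favor of a module-level flag dict, check_paddle_install = {'is_paddle_installed': False}, which enable_paddle() flips to True once the paddle import and version check succeed; jieba.cut and jieba.posseg.cut now just read the cached flag. The commit also sorts imports, turns gen_pfdict into a @staticmethod, and cleans up PEP 8 spacing.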

jieba/__init__.py

@@ -1,26 +1,24 @@
 from __future__ import absolute_import, unicode_literals
 __version__ = '0.41'
 __license__ = 'MIT'
-import re
-import os
-import sys
-import time
-import logging
 import marshal
+import re
 import tempfile
 import threading
-from math import log
+import time
 from hashlib import md5
-from ._compat import *
+from math import log
+
 from . import finalseg
+from ._compat import *
 if os.name == 'nt':
     from shutil import move as _replace_file
 else:
     _replace_file = os.rename
 _get_abs_path = lambda path: os.path.normpath(os.path.join(os.getcwd(), path))
 DEFAULT_DICT = None
@@ -47,10 +45,11 @@ re_han_default = re.compile("([\u4E00-\u9FD5a-zA-Z0-9+#&\._%\-]+)", re.U)
 re_skip_default = re.compile("(\r\n|\s)", re.U)
 
+
 def setLogLevel(log_level):
-    global logger
     default_logger.setLevel(log_level)
 
+
 class Tokenizer(object):
 
     def __init__(self, dictionary=DEFAULT_DICT):
@@ -69,7 +68,8 @@ class Tokenizer(object):
     def __repr__(self):
         return '<Tokenizer dictionary=%r>' % self.dictionary
 
-    def gen_pfdict(self, f):
+    @staticmethod
+    def gen_pfdict(f):
         lfreq = {}
         ltotal = 0
         f_name = resolve_filename(f)
@@ -128,7 +128,7 @@ class Tokenizer(object):
             load_from_cache_fail = True
             if os.path.isfile(cache_file) and (abs_path == DEFAULT_DICT or
                     os.path.getmtime(cache_file) > os.path.getmtime(abs_path)):
                 default_logger.debug(
                     "Loading model from cache %s" % cache_file)
                 try:
@@ -201,7 +201,7 @@ class Tokenizer(object):
         eng_scan = 0
         eng_buf = u''
         for k, L in iteritems(dag):
-            if eng_scan==1 and not re_eng.match(sentence[k]):
+            if eng_scan == 1 and not re_eng.match(sentence[k]):
                 eng_scan = 0
                 yield eng_buf
             if len(L) == 1 and k > old_j:
@@ -219,7 +219,7 @@ class Tokenizer(object):
             if j > k:
                 yield sentence[k:j + 1]
             old_j = j
-        if eng_scan==1:
+        if eng_scan == 1:
             yield eng_buf
 
     def __cut_DAG_NO_HMM(self, sentence):
@@ -285,8 +285,8 @@ class Tokenizer(object):
                     for elem in buf:
                         yield elem
 
-    def cut(self, sentence, cut_all = False, HMM = True,use_paddle = False):
-        '''
+    def cut(self, sentence, cut_all=False, HMM=True, use_paddle=False):
+        """
         The main function that segments an entire sentence that contains
         Chinese characters into separated words.
@@ -294,15 +294,12 @@ class Tokenizer(object):
         - sentence: The str(unicode) to be segmented.
         - cut_all: Model type. True for full pattern, False for accurate pattern.
        - HMM: Whether to use the Hidden Markov Model.
-        '''
-        is_paddle_installed = False
-        if use_paddle == True:
-            is_paddle_installed = check_paddle_install()
+        """
+        is_paddle_installed = check_paddle_install['is_paddle_installed']
         sentence = strdecode(sentence)
-        if use_paddle == True and is_paddle_installed == True:
+        if use_paddle and is_paddle_installed:
             if sentence is None or sentence == "" or sentence == u"":
                 yield sentence
-                return
             import jieba.lac_small.predict as predict
             results = predict.get_sent(sentence)
             for sent in results:
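A consequence of the new flag worth noting: the paddle branch of Tokenizer.cut now only runs if enable_paddle() has already set the shared flag; passing use_paddle=True by itself no longer triggers an import or install check. A minimal usage sketch of the intended call sequence (the sample sentence is only illustrative):

import jieba

jieba.enable_paddle()  # one-time setup: imports paddle and sets the shared flag
# With the flag set, cut() takes the paddle-based path; if paddle is missing
# or too old, the flag stays False and this falls back to jieba's basic cut.
words = list(jieba.cut("我来到北京清华大学", use_paddle=True))
print(words)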

jieba/_compat.py

@@ -1,49 +1,55 @@
 # -*- coding: utf-8 -*-
+import logging
 import os
 import sys
-import logging
+
 log_console = logging.StreamHandler(sys.stderr)
 default_logger = logging.getLogger(__name__)
 default_logger.setLevel(logging.DEBUG)
+
 def setLogLevel(log_level):
-    global logger
     default_logger.setLevel(log_level)
+
+check_paddle_install = {'is_paddle_installed': False}
 try:
     import pkg_resources
     get_module_res = lambda *res: pkg_resources.resource_stream(__name__,
                                                                 os.path.join(*res))
 except ImportError:
     get_module_res = lambda *res: open(os.path.normpath(os.path.join(
         os.getcwd(), os.path.dirname(__file__), *res)), 'rb')
+
 def enable_paddle():
-    import_paddle_check = False
     try:
         import paddle
     except ImportError:
         default_logger.debug("Installing paddle-tiny, please wait a minute......")
         os.system("pip install paddlepaddle-tiny")
         try:
             import paddle
         except ImportError:
-            default_logger.debug("Import paddle error, please use command to install: pip install paddlepaddle-tiny==1.6.1."
-                                 "Now, back to jieba basic cut......")
+            default_logger.debug(
+                "Import paddle error, please use command to install: pip install paddlepaddle-tiny==1.6.1."
+                "Now, back to jieba basic cut......")
     if paddle.__version__ < '1.6.1':
         default_logger.debug("Find your own paddle version doesn't satisfy the minimum requirement (1.6.1), "
                              "please install paddle tiny by 'pip install --upgrade paddlepaddle-tiny', "
-                             "or upgrade paddle full version by 'pip install --upgrade paddlepaddle (-gpu for GPU version)' ")
+                             "or upgrade paddle full version by "
+                             "'pip install --upgrade paddlepaddle (-gpu for GPU version)' ")
     else:
         try:
             import jieba.lac_small.predict as predict
-            import_paddle_check = True
             default_logger.debug("Paddle enabled successfully......")
+            check_paddle_install['is_paddle_installed'] = True
         except ImportError:
             default_logger.debug("Import error, cannot find paddle.fluid and jieba.lac_small.predict module. "
                                  "Now, back to jieba basic cut......")
+
 PY2 = sys.version_info[0] == 2
@@ -66,6 +72,7 @@ else:
     itervalues = lambda d: iter(d.values())
     iteritems = lambda d: iter(d.items())
 
+
 def strdecode(sentence):
     if not isinstance(sentence, text_type):
         try:
@@ -74,25 +81,9 @@ def strdecode(sentence):
             sentence = sentence.decode('gbk', 'ignore')
     return sentence
 
 def resolve_filename(f):
     try:
         return f.name
     except AttributeError:
         return repr(f)
-
-def check_paddle_install():
-    is_paddle_installed = False
-    try:
-        import paddle
-        if paddle.__version__ >= '1.6.1':
-            is_paddle_installed = True
-        else:
-            is_paddle_installed = False
-            default_logger.debug("Check the paddle version is not correct, the current version is "+ paddle.__version__+","
-                                 "please use command to install paddle: pip uninstall paddlepaddle(-gpu), "
-                                 "pip install paddlepaddle-tiny==1.6.1. Now, back to jieba basic cut......")
-    except ImportError:
-        default_logger.debug("Import paddle error, back to jieba basic cut......")
-        is_paddle_installed = False
-    return is_paddle_installed
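The deleted helper above is what the new flag dict replaces. A dict is used rather than a bare module-level boolean so that consumers of `from ._compat import *` see updates: a star-import copies name bindings, so rebinding a plain boolean inside _compat would not propagate to importers, while mutating the one shared dict does. A stripped-down sketch of the pattern, reusing the names from this diff (logging and the auto-install path omitted):

# Shared module-level cache; importers all reference this one dict object,
# so an in-place mutation made by enable_paddle() is visible everywhere.
check_paddle_install = {'is_paddle_installed': False}

def enable_paddle():
    try:
        import paddle  # the heavy import and version probe happen only here
    except ImportError:
        return  # flag stays False; callers fall back to jieba's basic cut
    if paddle.__version__ >= '1.6.1':  # string compare, as in the original code
        check_paddle_install['is_paddle_installed'] = True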

jieba/posseg/__init__.py

@@ -1,11 +1,11 @@
 from __future__ import absolute_import, unicode_literals
-import os
-import re
-import sys
-import jieba
+
 import pickle
-from .._compat import *
+import re
+
+import jieba
 from .viterbi import viterbi
+from .._compat import *
 
 PROB_START_P = "prob_start.p"
 PROB_TRANS_P = "prob_trans.p"
@@ -252,6 +252,7 @@ class POSTokenizer(object):
     def lcut(self, *args, **kwargs):
         return list(self.cut(*args, **kwargs))
 
+
 # default Tokenizer instance
 dt = POSTokenizer(jieba.dt)
@@ -276,19 +277,16 @@ def cut(sentence, HMM=True, use_paddle=False):
     Note that this only works using dt, custom POSTokenizer
     instances are not supported.
     """
-    is_paddle_installed = False
-    if use_paddle == True:
-        is_paddle_installed = check_paddle_install()
-    if use_paddle==True and is_paddle_installed == True:
+    is_paddle_installed = check_paddle_install['is_paddle_installed']
+    if use_paddle and is_paddle_installed:
         if sentence is None or sentence == "" or sentence == u"":
             yield pair(None, None)
-            return
         import jieba.lac_small.predict as predict
-        sents,tags = predict.get_result(strdecode(sentence))
-        for i,sent in enumerate(sents):
+        sents, tags = predict.get_result(strdecode(sentence))
+        for i, sent in enumerate(sents):
             if sent is None or tags[i] is None:
                 continue
-            yield pair(sent,tags[i])
+            yield pair(sent, tags[i])
         return
     global dt
     if jieba.pool is None:
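The same flag gates the paddle path for part-of-speech tagging. A short sketch of the intended flow after this commit (the sentence is invented for illustration):

import jieba
import jieba.posseg as pseg

jieba.enable_paddle()  # sets check_paddle_install['is_paddle_installed'] on success
for w in pseg.cut("我爱北京天安门", use_paddle=True):
    # each result is a pair object carrying the token and its POS tag
    print(w.word, w.flag)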