Mirror of https://github.com/fxsjy/jieba.git (synced 2025-07-10 00:01:33 +08:00)
refactor: improvement check_paddle_installed (#806)
This commit is contained in:
Parent: 0868c323d9
Commit: dc2b788eb3
@ -1,26 +1,24 @@
|
|||||||
from __future__ import absolute_import, unicode_literals
|
from __future__ import absolute_import, unicode_literals
|
||||||
|
|
||||||
__version__ = '0.41'
|
__version__ = '0.41'
|
||||||
__license__ = 'MIT'
|
__license__ = 'MIT'
|
||||||
|
|
||||||
import re
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
import logging
|
|
||||||
import marshal
|
import marshal
|
||||||
|
import re
|
||||||
import tempfile
|
import tempfile
|
||||||
import threading
|
import threading
|
||||||
from math import log
|
import time
|
||||||
from hashlib import md5
|
from hashlib import md5
|
||||||
from ._compat import *
|
from math import log
|
||||||
|
|
||||||
from . import finalseg
|
from . import finalseg
|
||||||
|
from ._compat import *
|
||||||
|
|
||||||
if os.name == 'nt':
|
if os.name == 'nt':
|
||||||
from shutil import move as _replace_file
|
from shutil import move as _replace_file
|
||||||
else:
|
else:
|
||||||
_replace_file = os.rename
|
_replace_file = os.rename
|
||||||
|
|
||||||
|
|
||||||
_get_abs_path = lambda path: os.path.normpath(os.path.join(os.getcwd(), path))
|
_get_abs_path = lambda path: os.path.normpath(os.path.join(os.getcwd(), path))
|
||||||
|
|
||||||
DEFAULT_DICT = None
|
DEFAULT_DICT = None
|
||||||
@ -47,10 +45,11 @@ re_han_default = re.compile("([\u4E00-\u9FD5a-zA-Z0-9+#&\._%\-]+)", re.U)
|
|||||||
|
|
||||||
re_skip_default = re.compile("(\r\n|\s)", re.U)
|
re_skip_default = re.compile("(\r\n|\s)", re.U)
|
||||||
|
|
||||||
|
|
||||||
def setLogLevel(log_level):
|
def setLogLevel(log_level):
|
||||||
global logger
|
|
||||||
default_logger.setLevel(log_level)
|
default_logger.setLevel(log_level)
|
||||||
|
|
||||||
|
|
||||||
class Tokenizer(object):
|
class Tokenizer(object):
|
||||||
|
|
||||||
def __init__(self, dictionary=DEFAULT_DICT):
|
def __init__(self, dictionary=DEFAULT_DICT):
|
||||||
@ -69,7 +68,8 @@ class Tokenizer(object):
|
|||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return '<Tokenizer dictionary=%r>' % self.dictionary
|
return '<Tokenizer dictionary=%r>' % self.dictionary
|
||||||
|
|
||||||
def gen_pfdict(self, f):
|
@staticmethod
|
||||||
|
def gen_pfdict(f):
|
||||||
lfreq = {}
|
lfreq = {}
|
||||||
ltotal = 0
|
ltotal = 0
|
||||||
f_name = resolve_filename(f)
|
f_name = resolve_filename(f)
|
||||||
@ -286,7 +286,7 @@ class Tokenizer(object):
|
|||||||
yield elem
|
yield elem
|
||||||
|
|
||||||
def cut(self, sentence, cut_all=False, HMM=True, use_paddle=False):
|
def cut(self, sentence, cut_all=False, HMM=True, use_paddle=False):
|
||||||
'''
|
"""
|
||||||
The main function that segments an entire sentence that contains
|
The main function that segments an entire sentence that contains
|
||||||
Chinese characters into separated words.
|
Chinese characters into separated words.
|
||||||
|
|
||||||
@ -294,15 +294,12 @@ class Tokenizer(object):
|
|||||||
- sentence: The str(unicode) to be segmented.
|
- sentence: The str(unicode) to be segmented.
|
||||||
- cut_all: Model type. True for full pattern, False for accurate pattern.
|
- cut_all: Model type. True for full pattern, False for accurate pattern.
|
||||||
- HMM: Whether to use the Hidden Markov Model.
|
- HMM: Whether to use the Hidden Markov Model.
|
||||||
'''
|
"""
|
||||||
is_paddle_installed = False
|
is_paddle_installed = check_paddle_install['is_paddle_installed']
|
||||||
if use_paddle == True:
|
|
||||||
is_paddle_installed = check_paddle_install()
|
|
||||||
sentence = strdecode(sentence)
|
sentence = strdecode(sentence)
|
||||||
if use_paddle == True and is_paddle_installed == True:
|
if use_paddle and is_paddle_installed:
|
||||||
if sentence is None or sentence == "" or sentence == u"":
|
if sentence is None or sentence == "" or sentence == u"":
|
||||||
yield sentence
|
yield sentence
|
||||||
return
|
|
||||||
import jieba.lac_small.predict as predict
|
import jieba.lac_small.predict as predict
|
||||||
results = predict.get_sent(sentence)
|
results = predict.get_sent(sentence)
|
||||||
for sent in results:
|
for sent in results:
|
||||||
|
@ -1,18 +1,22 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import logging
|
|
||||||
|
|
||||||
log_console = logging.StreamHandler(sys.stderr)
|
log_console = logging.StreamHandler(sys.stderr)
|
||||||
default_logger = logging.getLogger(__name__)
|
default_logger = logging.getLogger(__name__)
|
||||||
default_logger.setLevel(logging.DEBUG)
|
default_logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
|
|
||||||
def setLogLevel(log_level):
|
def setLogLevel(log_level):
|
||||||
global logger
|
|
||||||
default_logger.setLevel(log_level)
|
default_logger.setLevel(log_level)
|
||||||
|
|
||||||
|
|
||||||
|
check_paddle_install = {'is_paddle_installed': False}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import pkg_resources
|
import pkg_resources
|
||||||
|
|
||||||
get_module_res = lambda *res: pkg_resources.resource_stream(__name__,
|
get_module_res = lambda *res: pkg_resources.resource_stream(__name__,
|
||||||
os.path.join(*res))
|
os.path.join(*res))
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@ -21,7 +25,6 @@ except ImportError:
|
|||||||
|
|
||||||
|
|
||||||
def enable_paddle():
|
def enable_paddle():
|
||||||
import_paddle_check = False
|
|
||||||
try:
|
try:
|
||||||
import paddle
|
import paddle
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@ -30,21 +33,24 @@ def enable_paddle():
|
|||||||
try:
|
try:
|
||||||
import paddle
|
import paddle
|
||||||
except ImportError:
|
except ImportError:
|
||||||
default_logger.debug("Import paddle error, please use command to install: pip install paddlepaddle-tiny==1.6.1."
|
default_logger.debug(
|
||||||
|
"Import paddle error, please use command to install: pip install paddlepaddle-tiny==1.6.1."
|
||||||
"Now, back to jieba basic cut......")
|
"Now, back to jieba basic cut......")
|
||||||
if paddle.__version__ < '1.6.1':
|
if paddle.__version__ < '1.6.1':
|
||||||
default_logger.debug("Find your own paddle version doesn't satisfy the minimum requirement (1.6.1), "
|
default_logger.debug("Find your own paddle version doesn't satisfy the minimum requirement (1.6.1), "
|
||||||
"please install paddle tiny by 'pip install --upgrade paddlepaddle-tiny', "
|
"please install paddle tiny by 'pip install --upgrade paddlepaddle-tiny', "
|
||||||
"or upgrade paddle full version by 'pip install --upgrade paddlepaddle (-gpu for GPU version)' ")
|
"or upgrade paddle full version by "
|
||||||
|
"'pip install --upgrade paddlepaddle (-gpu for GPU version)' ")
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
import jieba.lac_small.predict as predict
|
import jieba.lac_small.predict as predict
|
||||||
import_paddle_check = True
|
|
||||||
default_logger.debug("Paddle enabled successfully......")
|
default_logger.debug("Paddle enabled successfully......")
|
||||||
|
check_paddle_install['is_paddle_installed'] = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
default_logger.debug("Import error, cannot find paddle.fluid and jieba.lac_small.predict module. "
|
default_logger.debug("Import error, cannot find paddle.fluid and jieba.lac_small.predict module. "
|
||||||
"Now, back to jieba basic cut......")
|
"Now, back to jieba basic cut......")
|
||||||
|
|
||||||
|
|
||||||
PY2 = sys.version_info[0] == 2
|
PY2 = sys.version_info[0] == 2
|
||||||
|
|
||||||
default_encoding = sys.getfilesystemencoding()
|
default_encoding = sys.getfilesystemencoding()
|
||||||
@ -66,6 +72,7 @@ else:
|
|||||||
itervalues = lambda d: iter(d.values())
|
itervalues = lambda d: iter(d.values())
|
||||||
iteritems = lambda d: iter(d.items())
|
iteritems = lambda d: iter(d.items())
|
||||||
|
|
||||||
|
|
||||||
def strdecode(sentence):
|
def strdecode(sentence):
|
||||||
if not isinstance(sentence, text_type):
|
if not isinstance(sentence, text_type):
|
||||||
try:
|
try:
|
||||||
@ -74,25 +81,9 @@ def strdecode(sentence):
|
|||||||
sentence = sentence.decode('gbk', 'ignore')
|
sentence = sentence.decode('gbk', 'ignore')
|
||||||
return sentence
|
return sentence
|
||||||
|
|
||||||
|
|
||||||
def resolve_filename(f):
|
def resolve_filename(f):
|
||||||
try:
|
try:
|
||||||
return f.name
|
return f.name
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
return repr(f)
|
return repr(f)
|
||||||
|
|
||||||
|
|
||||||
def check_paddle_install():
|
|
||||||
is_paddle_installed = False
|
|
||||||
try:
|
|
||||||
import paddle
|
|
||||||
if paddle.__version__ >= '1.6.1':
|
|
||||||
is_paddle_installed = True
|
|
||||||
else:
|
|
||||||
is_paddle_installed = False
|
|
||||||
default_logger.debug("Check the paddle version is not correct, the current version is "+ paddle.__version__+","
|
|
||||||
"please use command to install paddle: pip uninstall paddlepaddle(-gpu), "
|
|
||||||
"pip install paddlepaddle-tiny==1.6.1. Now, back to jieba basic cut......")
|
|
||||||
except ImportError:
|
|
||||||
default_logger.debug("Import paddle error, back to jieba basic cut......")
|
|
||||||
is_paddle_installed = False
|
|
||||||
return is_paddle_installed
|
|
||||||
|
@ -1,11 +1,11 @@
|
|||||||
from __future__ import absolute_import, unicode_literals
|
from __future__ import absolute_import, unicode_literals
|
||||||
import os
|
|
||||||
import re
|
|
||||||
import sys
|
|
||||||
import jieba
|
|
||||||
import pickle
|
import pickle
|
||||||
from .._compat import *
|
import re
|
||||||
|
|
||||||
|
import jieba
|
||||||
from .viterbi import viterbi
|
from .viterbi import viterbi
|
||||||
|
from .._compat import *
|
||||||
|
|
||||||
PROB_START_P = "prob_start.p"
|
PROB_START_P = "prob_start.p"
|
||||||
PROB_TRANS_P = "prob_trans.p"
|
PROB_TRANS_P = "prob_trans.p"
|
||||||
@ -252,6 +252,7 @@ class POSTokenizer(object):
|
|||||||
def lcut(self, *args, **kwargs):
|
def lcut(self, *args, **kwargs):
|
||||||
return list(self.cut(*args, **kwargs))
|
return list(self.cut(*args, **kwargs))
|
||||||
|
|
||||||
|
|
||||||
# default Tokenizer instance
|
# default Tokenizer instance
|
||||||
|
|
||||||
dt = POSTokenizer(jieba.dt)
|
dt = POSTokenizer(jieba.dt)
|
||||||
@ -276,13 +277,10 @@ def cut(sentence, HMM=True, use_paddle=False):
|
|||||||
Note that this only works using dt, custom POSTokenizer
|
Note that this only works using dt, custom POSTokenizer
|
||||||
instances are not supported.
|
instances are not supported.
|
||||||
"""
|
"""
|
||||||
is_paddle_installed = False
|
is_paddle_installed = check_paddle_install['is_paddle_installed']
|
||||||
if use_paddle == True:
|
if use_paddle and is_paddle_installed:
|
||||||
is_paddle_installed = check_paddle_install()
|
|
||||||
if use_paddle==True and is_paddle_installed == True:
|
|
||||||
if sentence is None or sentence == "" or sentence == u"":
|
if sentence is None or sentence == "" or sentence == u"":
|
||||||
yield pair(None, None)
|
yield pair(None, None)
|
||||||
return
|
|
||||||
import jieba.lac_small.predict as predict
|
import jieba.lac_small.predict as predict
|
||||||
sents, tags = predict.get_result(strdecode(sentence))
|
sents, tags = predict.get_result(strdecode(sentence))
|
||||||
for i, sent in enumerate(sents):
|
for i, sent in enumerate(sents):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user