diff --git a/tests/test_classify.py b/tests/test_classify.py index 0a1874a..9c92108 100644 --- a/tests/test_classify.py +++ b/tests/test_classify.py @@ -13,15 +13,18 @@ def setUp(self): pass # __init__() + @mock.patch('cherry.classifyer.Classify._classify') @mock.patch('cherry.classifyer.Classify._load_cache') def test_init(self, mock_load, mock_classify): mock_load.return_value = ('foo', 'bar') - cherry.classifyer.Classify(model='random', text=['random text']) - mock_load.assert_called_once_with('random') + res = cherry.classifyer.Classify(model='random', text=['random text']) + if res.get_CACHE() == False: + mock_load.assert_called_once_with('random') mock_classify.assert_called_once_with(['random text']) # _load_cache() + @mock.patch('cherry.classifyer.Classify._classify') @mock.patch('cherry.classifyer.load_cache') def test_load_cache(self, mock_load, mock_classify):