diff --git a/lightfm/lightfm.py b/lightfm/lightfm.py index 3034b49c..4e4b5f5e 100644 --- a/lightfm/lightfm.py +++ b/lightfm/lightfm.py @@ -453,8 +453,10 @@ def fit(self, interactions, num_threads: int, optional Number of parallel computation threads to use. Should not be higher than the number of physical cores. - verbose: bool, optional - whether to print progress messages. + verbose: bool or function, optional + if it's a bool, whether to print progress messages; + if it's a function, the print or logging function for printing + progress messages. Returns ------- @@ -513,8 +515,10 @@ def fit_partial(self, interactions, num_threads: int, optional Number of parallel computation threads to use. Should not be higher than the number of physical cores. - verbose: bool, optional - whether to print progress messages. + verbose: bool or function, optional + if it's a bool, whether to print progress messages; + if it's a function, the print or logging function for printing + progress messages. Returns ------- @@ -567,7 +571,9 @@ def fit_partial(self, interactions, for epoch in range(epochs): - if verbose: + if hasattr(verbose, '__call__'): + verbose('Epoch %s' % epoch) + elif verbose: print('Epoch %s' % epoch) self._run_epoch(item_features,