Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ Look at notebook files with full working [examples](https://github.com/stared/li

- [keras.ipynb](https://github.com/stared/livelossplot/blob/master/examples/keras.ipynb) - a Keras callback
- [minimal.ipynb](https://github.com/stared/livelossplot/blob/master/examples/minimal.ipynb) - a bare API, to use anywhere
- [script.py](https://github.com/stared/livelossplot/blob/master/examples/script.py) - to be run as a script, `python script.py`
- [bokeh.ipynb](https://github.com/stared/livelossplot/blob/master/examples/bokeh.ipynb) - a bare API, plots with Bokeh ([open it in Colab to see the plots](https://colab.research.google.com/github/stared/livelossplot/blob/master/examples/bokeh.ipynb))
- [pytorch.ipynb](https://github.com/stared/livelossplot/blob/master/examples/pytorch.ipynb) - a bare API, as applied to PyTorch
- [2d_prediction_maps.ipynb](https://github.com/stared/livelossplot/blob/master/examples/2d_prediction_maps.ipynb) - example of custom plots - 2d prediction maps (0.4.1+)
Expand Down
40 changes: 20 additions & 20 deletions examples/minimal.ipynb

Large diffs are not rendered by default.

19 changes: 19 additions & 0 deletions examples/script.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# from livelossplot 0.5.6

from time import sleep
import numpy as np
from livelossplot import PlotLosses

plotlosses = PlotLosses(mode="script")

for i in range(10):
plotlosses.update(
{
"acc": 1 - np.random.rand() / (i + 2.0),
"val_acc": 1 - np.random.rand() / (i + 0.5),
"loss": 1.0 / (i + 2.0),
"val_loss": 1.0 / (i + 0.5),
}
)
plotlosses.send()
sleep(0.5)
12 changes: 10 additions & 2 deletions livelossplot/outputs/matplotlib_plot.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import math
from typing import Tuple, List, Dict, Optional, Callable
from typing import Tuple, List, Dict, Optional, Callable, Literal

import warnings

Expand Down Expand Up @@ -50,6 +50,7 @@ def __init__(
self._before_plots = before_plots if before_plots else self._default_before_plots
self._after_plots = after_plots if after_plots else self._default_after_plots
self.figsize = figsize
self.output_mode: Literal['notebook', 'script'] = "notebook"

def send(self, logger: MainLogger):
"""Draw figures with metrics and show"""
Expand Down Expand Up @@ -110,7 +111,11 @@ def _default_after_plots(self, fig: plt.Figure):
if self.figpath is not None:
fig.savefig(self.figpath.format(i=self.file_idx))
self.file_idx += 1
plt.show()
if self.output_mode == "script":
plt.draw()
plt.pause(0.1)
else:
plt.show()

def _draw_metric_subplot(self, ax: plt.Axes, group_logs: Dict[str, List[LogItem]], group_name: str, x_label: str):
"""
Expand Down Expand Up @@ -139,3 +144,6 @@ def _not_inline_warning(self):
"livelossplot requires inline plots.\nYour current backend is: {}"
"\nRun in a Jupyter environment and execute '%matplotlib inline'.".format(backend)
)

def _set_output_mode(self, mode: Literal['notebook', 'script']):
self.output_mode = mode
24 changes: 22 additions & 2 deletions livelossplot/plot_losses.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
from typing import Type, TypeVar, List, Union, Optional, Tuple
from typing import Type, TypeVar, List, Union, Optional, Tuple, Literal

import livelossplot
from livelossplot.main_logger import MainLogger
Expand All @@ -9,14 +9,32 @@
BO = TypeVar('BO', bound=outputs.BaseOutput)


def get_mode() -> Literal['notebook', 'script']:
try:
from IPython import get_ipython
ipython = get_ipython()
if ipython is None:
return 'script'
name = ipython.__class__.__name__
if name == "ZMQInteractiveShell" or name == "Shell":
# Shell is in Colab
return "notebook"
elif name == "TerminalInteractiveShell":
return "script"
print(f"Unknown IPython mode: {name}. Assuming notebook mode.")
return "notebook"
except ImportError:
return "script"


class PlotLosses:
"""
Class collect metrics from the training engine and send it to plugins, when send is called
"""
def __init__(
self,
outputs: List[Union[Type[BO], str]] = ['MatplotlibPlot', 'ExtremaPrinter'],
mode: str = 'notebook',
mode: Optional[Literal['notebook', 'script']] = None,
figsize: Optional[Tuple[int, int]] = None,
**kwargs
):
Expand All @@ -31,6 +49,8 @@ def __init__(
"""
self.logger = MainLogger(**kwargs)
self.outputs = [getattr(livelossplot.outputs, out)() if isinstance(out, str) else out for out in outputs]
if mode is None:
mode = get_mode()
for out in self.outputs:
out.set_output_mode(mode)
if figsize is not None and isinstance(out, MatplotlibPlot):
Expand Down
Loading