
Example code not working #1

@guytenn

import medkit as mk
synthetic_dataset = mk.batch_generate(
    domain="Ward",
    environment="CRN",
    policy="LSTM",
    size=1000,
    test_size=200,
    max_length=10,
    scale=True,
)

This fails with the following error:

Traceback (most recent call last):
  File "<stdin>", line 8, in <module>
  File "/home/gtennenholtz/medkit-learn/medkit/api.py", line 58, in batch_generate
    env = env_dict[environment](dom)
  File "/home/gtennenholtz/medkit-learn/medkit/environments/CounterfactualRNN.py", line 118, in __init__
    self.model = CRN_env(domain)
  File "/home/gtennenholtz/medkit-learn/medkit/environments/CounterfactualRNN.py", line 15, in __init__
    self.lstm_layers = self.hyper["lstm_layers"]
KeyError: 'lstm_layers'
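The `KeyError` means the hyperparameter dictionary that `CRN_env` reads (`self.hyper`) has no `"lstm_layers"` entry. Below is a minimal sketch for reporting every missing key at once instead of failing on the first lookup. It assumes the hyperparameters are a plain `dict`. The helpers `check_hyperparams` and `require_hyperparams` are hypothetical and not part of medkit, and the example config is illustrative only.

```python
from typing import Iterable, Mapping


def check_hyperparams(hyper: Mapping[str, object], required: Iterable[str]) -> list:
    """Return the required keys absent from `hyper`, in the order given.

    Hypothetical helper, not part of medkit; the required key names
    passed in are illustrative assumptions.
    """
    return [key for key in required if key not in hyper]


def require_hyperparams(hyper: Mapping[str, object], required: Iterable[str]):
    """Raise one KeyError listing every missing key, not just the first."""
    missing = check_hyperparams(hyper, required)
    if missing:
        raise KeyError(f"missing hyperparameters: {missing}")
    return hyper


# Illustrative config in which "lstm_layers" is absent, as in the traceback above.
hyper = {"hidden_size": 128}
print(check_hyperparams(hyper, ["hidden_size", "lstm_layers"]))  # ['lstm_layers']
```

Running such a check when the config is loaded would point directly at the incomplete hyperparameter file, rather than surfacing the problem deep inside the environment constructor.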

Separately, this call:

env = mk.live_simulate(domain="ICU", environment="SVAE")

fails with a different error:

Traceback (most recent call last):
  File "<stdin>", line 3, in <module>
  File "/home/gtennenholtz/medkit-learn/medkit/api.py", line 214, in live_simulate
    env = env_dict[environment](dom)
  File "/home/gtennenholtz/medkit-learn/medkit/environments/SequentialVAE.py", line 188, in __init__
    self.load_pretrained()
  File "/home/gtennenholtz/medkit-learn/medkit/bases/base_env.py", line 30, in load_pretrained
    self.model.load_state_dict(torch.load(path))
  File "/home/gtennenholtz/venv/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1224, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for SVAE_env:
        Missing key(s) in state_dict: "lstm.weight_ih_l0", "lstm.bias_ih_l0", "lstm.weight_hh_l0", "lstm.bias_hh_l0", "lstm.weight_ih_l1", "lstm.bias_ih_l1", "lstm.weight_hh_l1", "lstm.bias_hh_l1", "lstm.layers.0.cell.ih.weight", "lstm.layers.0.cell.ih.bias", "lstm.layers.0.cell.hh.weight", "lstm.layers.0.cell.hh.bias", "lstm.layers.1.cell.ih.weight", "lstm.layers.1.cell.ih.bias", "lstm.layers.1.cell.hh.weight", "lstm.layers.1.cell.hh.bias".
        Unexpected key(s) in state_dict: "lstm.ih.weight", "lstm.ih.bias", "lstm.hh.weight", "lstm.hh.bias".
        size mismatch for encoder.linear1.weight: copying a param with shape torch.Size([128, 24]) from checkpoint, the shape in current model is torch.Size([128, 37]).
        size mismatch for decoder.series_cont_mean.weight: copying a param with shape torch.Size([23, 128]) from checkpoint, the shape in current model is torch.Size([37, 128]).
        size mismatch for decoder.series_cont_mean.bias: copying a param with shape torch.Size([23]) from checkpoint, the shape in current model is torch.Size([37]).
        size mismatch for decoder.series_cont_lstd.weight: copying a param with shape torch.Size([23, 128]) from checkpoint, the shape in current model is torch.Size([37, 128]).
        size mismatch for decoder.series_cont_lstd.bias: copying a param with shape torch.Size([23]) from checkpoint, the shape in current model is torch.Size([37]).
        size mismatch for decoder.series_bin.weight: copying a param with shape torch.Size([1, 128]) from checkpoint, the shape in current model is torch.Size([0, 128]).
        size mismatch for decoder.series_bin.bias: copying a param with shape torch.Size([1]) from checkpoint, the shape in current model is torch.Size([0]).
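This `RuntimeError` combines three separate problems. Some parameter names are expected by the model but missing from the checkpoint, some names are in the checkpoint but unknown to the model, and some tensors have different shapes (for example, `encoder.linear1.weight` is `[128, 24]` in the checkpoint but `[128, 37]` in the model). Below is a minimal sketch for listing all three before calling `load_state_dict`. It works on plain name-to-shape dictionaries, so it runs without PyTorch. `diff_state_dicts` is a hypothetical helper, not part of medkit or PyTorch. With real tensors, each dictionary could be built as `{k: tuple(v.shape) for k, v in state_dict.items()}`.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def diff_state_dicts(model_shapes: Dict[str, Shape], ckpt_shapes: Dict[str, Shape]) -> dict:
    """Compare parameter names and shapes between a model and a checkpoint.

    Hypothetical helper. Returns:
      missing    -- names the model expects but the checkpoint lacks
      unexpected -- names in the checkpoint that the model does not have
      mismatched -- name -> (checkpoint shape, model shape) where they differ
    """
    missing = sorted(set(model_shapes) - set(ckpt_shapes))
    unexpected = sorted(set(ckpt_shapes) - set(model_shapes))
    mismatched = {
        name: (ckpt_shapes[name], model_shapes[name])
        for name in sorted(set(model_shapes) & set(ckpt_shapes))
        if ckpt_shapes[name] != model_shapes[name]
    }
    return {"missing": missing, "unexpected": unexpected, "mismatched": mismatched}


# Two of the shape pairs reported in the error above.
model_shapes = {"encoder.linear1.weight": (128, 37), "decoder.series_bin.bias": (0,)}
ckpt_shapes = {"encoder.linear1.weight": (128, 24), "decoder.series_bin.bias": (1,)}
print(diff_state_dicts(model_shapes, ckpt_shapes)["mismatched"])
# {'decoder.series_bin.bias': ((1,), (0,)), 'encoder.linear1.weight': ((128, 24), (128, 37))}
```

Keeping the three categories separate makes it easier to tell a renamed module (missing plus unexpected keys) apart from a model built with a different input or output dimensionality (shape mismatches).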
