From 8d9a2254dddf61de4baad4b40a4dd96ddba67f9c Mon Sep 17 00:00:00 2001 From: Aleksey Kobylin Date: Sat, 29 Apr 2023 23:11:33 +0300 Subject: [PATCH] BASIL-21-1: attention-vector MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. Исправил причину утечки памяти - visualize.py/plt.close() 2. Поправил опечатку в readme 3. Добавил использование attention_context в последующих итерациях декодера 4. Добавил возможность продолжать обучение из чекпоинта 5. Убрал дублирование log("Resource/RAM") 6. Сделал логирование текста более приятным --- README.md | 2 +- models/decoder.py | 6 ++++++ pl/modules.py | 15 +++++++++++---- utils/visualization.py | 5 +++-- 4 files changed, 21 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index a9df959..fa24ef4 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,7 @@ make dataset_advanced \ RAW_DATASET= \ DIALOG_SECONDS_COOLDOWN= \ DIALOG_MEMORY= \ - DAATSET_OUTPUT= + DATASET_OUTPUT= ``` ## Fit tokenizer diff --git a/models/decoder.py b/models/decoder.py index 50c9536..bdfbdcf 100644 --- a/models/decoder.py +++ b/models/decoder.py @@ -50,6 +50,7 @@ def __init__( self.embedding = nn.Embedding( num_embeddings=num_embeddings, embedding_dim=embedding_dim, max_norm=True ) + self.combine_hiddens = nn.Linear(hidden_size * 2, hidden_size) self._dec_lstm = nn.LSTM( input_size=embedding_dim, hidden_size=hidden_size, @@ -74,8 +75,12 @@ def forward(self, tokens, context, mask=None, *args, **kwargs): dec_in_h = context["hidden"] dec_in_c = context["context"] encoder_outputs = context["encoder_output"] + attention_context = context.get("attention_context", None) embeddings = self.embedding(tokens) + if attention_context is not None: + attention_context = attention_context.transpose(0, 1).repeat(dec_in_h.shape[0], 1, 1).contiguous() + dec_in_h = self.combine_hiddens(torch.cat([attention_context, dec_in_h], dim=-1)) dec_out, (dec_h, dec_c) = self._dec_lstm(embeddings, (dec_in_h, dec_in_c)) attention_context, attention_weights = self._attention( @@ -95,5 +100,6 @@ def forward(self, tokens, context, mask=None, *args, **kwargs): "context": dec_c, "encoder_output": encoder_outputs, "attention": attention_weights, + "attention_context": attention_context, }, ) diff --git a/pl/modules.py b/pl/modules.py index c2974d9..503d7ce 100644 --- a/pl/modules.py +++ b/pl/modules.py @@ -1,8 +1,8 @@ from typing import Any, Dict, Optional, Sequence, Union +import os import lightning as pl import torch -import torchtext import torchvision from tokenizers import BaseTokenizer @@ -37,6 +37,14 @@ def __init__(self, config: Dict[str, Any]) -> None: start_token=self.tokenizer.start_token, stop_token=self.tokenizer.stop_token, ) + checkpoint_path = config.get("checkpoint", None) + if checkpoint_path is not None and os.path.isfile(checkpoint_path): + checkpoint = torch.load(checkpoint_path, map_location="cpu") + model_state_dict = { + key.replace("model.", ""): val + for key, val in checkpoint["state_dict"].items() if "model." in key + } + self.model.load_state_dict(model_state_dict) self._log_attention = self.model_config['parameters'].get('use_attention', False) self.train_batch_size: Optional[int] = self.dataset_config.get( "train_batch", None @@ -99,7 +107,6 @@ def __step(self, batch: Dict[str, Any], batch_idx: int, stage: str) -> torch.Ten { f"Loss/{stage}": loss, f"Accuracy/{stage}": accuracy, - f"Resources/RAM": get_ram_consumption_mb(), f"BLEU/{stage}": bleu, f"Resources/RAM": get_ram_consumption_mb(), }, @@ -141,8 +148,8 @@ def validation_step(self, batch: Dict[str, Any], batch_idx: int) -> None: ) val_result = "" - for sample_idx in range(max(len(in_tokens), 16)): - prediction = self.tokenizer.decode(tokens[sample_idx].tolist()).replace("PAD", " ") + for sample_idx in range(min(len(in_tokens), 16)): + prediction = self.tokenizer.decode(tokens[sample_idx].tolist()).split("STOP", 1)[0].replace("PAD", "").replace("STOP", "") val_result += ( f"Sample #{sample_idx}:\n\ninput: {text_input[sample_idx]}\n\npredicted: " f"{prediction}\n\ntarget: {target_output[sample_idx]}\n\n" diff --git a/utils/visualization.py b/utils/visualization.py index eb10431..5b6cdbf 100644 --- a/utils/visualization.py +++ b/utils/visualization.py @@ -26,9 +26,10 @@ def plot_attention_scores(input_words, output_words, scores, figsize=(4, 4), dpi plt.tight_layout(pad=0.5, w_pad=0.5, h_pad=1.0) - image = None with tempfile.NamedTemporaryFile(suffix='.png', delete=True) as temp_file: fig.savefig(temp_file.name) - image = plt.imread(temp_file.name)[:,:,:3] + image = plt.imread(temp_file.name)[:, :, :3] + + plt.close(fig) return image