diff --git a/.gitignore b/.gitignore
index 2db7330a..d73bc7a1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -135,3 +135,6 @@ dmypy.json
 .idea/
 .vscode/
 results/
+
+# Datasets
+datasets/
\ No newline at end of file
diff --git a/README.md b/README.md
index 1cb91aff..536018e7 100644
--- a/README.md
+++ b/README.md
@@ -21,6 +21,8 @@ Wenbin Li, Ziyi Wang, Xuesong Yang, Chuanqi Dong, Pinzhuo Tian, Tiexin Qin, Jing
 + [MTL (CVPR 2019)](https://arxiv.org/abs/1812.02391)
 + [ANIL (ICLR 2020)](https://arxiv.org/abs/1909.09157)
 + [BOIL (ICLR 2021)](https://arxiv.org/abs/2008.08882)
++ [MeTAL (ICCV 2021)](https://arxiv.org/abs/2110.03909)
+
 ### Metric-learning based methods
 + [ProtoNet (NeurIPS 2017)](https://arxiv.org/abs/1703.05175)
 + [RelationNet (CVPR 2018)](https://arxiv.org/abs/1711.06025)
diff --git a/config/classifiers/METAL.yaml b/config/classifiers/METAL.yaml
new file mode 100644
index 00000000..bbcf7f5c
--- /dev/null
+++ b/config/classifiers/METAL.yaml
@@ -0,0 +1,8 @@
+classifier:
+  name: METAL
+  kwargs:
+    inner_param:
+      lr: 1e-2
+      train_iter: 5
+      test_iter: 5  # must match train_iter
+    feat_dim: 1600
diff --git a/config/getting_started.yaml b/config/getting_started.yaml
new file mode 100644
index 00000000..dd5a6323
--- /dev/null
+++ b/config/getting_started.yaml
@@ -0,0 +1,9 @@
+includes:
+  - headers/data.yaml
+  - headers/device.yaml
+  - headers/losses.yaml
+  - headers/misc.yaml
+  - headers/model.yaml
+  - headers/optimizer.yaml
+  - classifiers/Proto.yaml
+  - backbones/Conv64FLeakyReLU.yaml
\ No newline at end of file
diff --git a/config/headers/data.yaml b/config/headers/data.yaml
index b59b6b05..f6316c43 100644
--- a/config/headers/data.yaml
+++ b/config/headers/data.yaml
@@ -1,8 +1,8 @@
-data_root: /data/fewshot/miniImageNet--ravi
+data_root: datasets/miniImageNet--ravi
 image_size: 84
 use_memory: False
 augment: True
 augment_times: 1
 augment_times_query: 1
-workers: 8 # number of workers for dataloader in all threads
+workers: 1 # number of workers for dataloader in all threads
 dataloader_num: 1
diff --git a/config/maml.yaml b/config/maml.yaml
index 3b8ca0a6..09329419 100644
--- a/config/maml.yaml
+++ b/config/maml.yaml
@@ -12,7 +12,7 @@
 episode_size: 2
 train_episode: 2000
 test_episode: 600
 
-device_ids: 5
+device_ids: 0
 n_gpu: 1
 epoch: 100
diff --git a/config/metal.yaml b/config/metal.yaml
new file mode 100644
index 00000000..f5e7a0bf
--- /dev/null
+++ b/config/metal.yaml
@@ -0,0 +1,86 @@
+includes:
+  - headers/data.yaml
+  - headers/device.yaml
+  - headers/misc.yaml
+  - headers/model.yaml
+  - headers/optimizer.yaml
+
+way_num: 5
+shot_num: 1
+query_num: 15
+
+episode_size: 2
+train_episode: 2000
+test_episode: 600
+
+device_ids: 0
+n_gpu: 1
+epoch: 100
+
+optimizer:
+  name: Adam
+  kwargs:
+    lr: 0.001
+  other: ~
+
+backbone:
+  name: Conv64F
+  kwargs:
+    is_flatten: True
+    is_feature: False
+    leaky_relu: False
+    negative_slope: 0.2
+    last_pool: True
+
+classifier:
+  name: METAL
+  kwargs:
+    inner_param:
+      lr: 1e-2
+      train_iter: 5
+      test_iter: 5  # must match train_iter: the per-step loss networks are indexed by step
+    feat_dim: 1600
+
+
+#backbone:
+#  name: resnet12
+#  kwargs: ~
+#
+#classifier:
+#  name: METAL
+#  kwargs:
+#    inner_param:
+#      lr: 1e-2
+#      train_iter: 5
+#      test_iter: 5  # must match train_iter
+#    feat_dim: 640
+
+
+# backbone:
+#   name: resnet18
+#   kwargs: ~
+
+# classifier:
+#   name: METAL
+#   kwargs:
+#     inner_param:
+#       lr: 1e-2
+#       train_iter: 5
+#       test_iter: 5  # must match train_iter
+#     feat_dim: 512
+
+
+# backbone:
+#   name: WRN
+#   kwargs:
+#     depth: 28
+#     widen_factor: 10
+
+# classifier:
+#   name: METAL
+#   kwargs:
+#     inner_param:
+#       lr: 1e-2
+#       train_iter: 5
+#       test_iter: 5  # must match train_iter
+#     feat_dim: 640
diff --git a/config/proto.yaml b/config/proto.yaml
index 6bbb205d..b5e1af42 100644
--- a/config/proto.yaml
+++ b/config/proto.yaml
@@ -9,7 +9,7 @@ includes:
 
 device_ids: 0,1
-n_gpu: 2
+n_gpu: 1
 way_num: 5
 shot_num: 1
 query_num: 15
diff --git a/config/renet.yaml b/config/renet.yaml
index 74bebbb2..29152c37 100644
--- a/config/renet.yaml
+++ b/config/renet.yaml
@@ -20,9 +20,8 @@ classifier:
     temperature: 0.2
     temperature_attn: 5.0
   name: RENet
-data_root: /data/fewshot/miniImageNet--ravi
 deterministic: true
-device_ids: 3
+device_ids: 0
 episode_size: 1
 epoch: 100
 image_size: 84
diff --git a/config/skd.yaml b/config/skd.yaml
deleted file mode 100644
index 864be758..00000000
--- a/config/skd.yaml
+++ /dev/null
@@ -1,63 +0,0 @@
-includes:
-  - headers/data.yaml
-  - headers/device.yaml
-  - headers/misc.yaml
-  - headers/model.yaml
-  - headers/optimizer.yaml
-  - classifiers/SKD.yaml
-  - backbones/resnet12.yaml
-
-
-device_ids: 0
-way_num: 5
-shot_num: 1
-query_num: 15
-episode_size: 1
-train_episode: 100
-test_episode: 100
-
-batch_size: 128
-
-save_part:
-  - emb_func
-  - cls_classifier
-
-classifier:
-  name: SKDModel
-  kwargs:
-    feat_dim: 1600
-    num_class: 64
-    gamma: 1.0
-    alpha: 0.1
-    is_distill: False
-    emb_func_path: ./results/SKDModel-miniImageNet--ravi-Conv64F-5-1-Sep-23-2021-15-16-27/checkpoints/emb_func_best.pth
-    cls_classifier_path: ./results/SKDModel-miniImageNet--ravi-Conv64F-5-1-Sep-23-2021-15-16-27/checkpoints/cls_classifier_best.pth
-
-
-backbone:
-  name: Conv64F
-  kwargs:
-    is_flatten: True
-    is_feature: False
-    leaky_relu: False
-    negative_slope: 0.2
-    last_pool: True
-    maxpool_last2: True
-
-# backbone:
-#   name: resnet12
-#   kwargs:
-#     keep_prob: 0.0
-
-# backbone:
-#   name: resnet18
-#   kwargs:
-
-# backbone:
-#   name: WRN
-#   kwargs:
-#     depth: 10
-#     widen_factor: 10
-#     dropRate: 0.0
-#     avg_pool: True
-#     is_flatten: True
diff --git a/config/versa.yaml b/config/versa.yaml
index a60e71d0..66c3c3df 100644
--- a/config/versa.yaml
+++ b/config/versa.yaml
@@ -8,10 +8,10 @@ includes:
 
 deterministic: False
 way_num: 5
-shot_num: 5
+shot_num: 1
 query_num: 15
 test_way: 5  # use ~ -> test_* = *_num
-test_shot: 5
+test_shot: 1
 test_query: 15
 
 episode_size: 1
diff --git a/core/model/meta/__init__.py b/core/model/meta/__init__.py
index c13ffc26..dbdaea36 100644
--- a/core/model/meta/__init__.py
+++ b/core/model/meta/__init__.py
@@ -7,3 +7,4 @@
 from .leo import LEO
 from .mtl import MTL
 from .boil import BOIL
+from .metal import METAL
diff --git a/core/model/meta/metal.py b/core/model/meta/metal.py
new file mode 100644
index 00000000..e66c6a12
--- /dev/null
+++ b/core/model/meta/metal.py
@@ -0,0 +1,417 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .meta_model import MetaModel
+from ..backbone.utils import convert_maml_module
+from .maml import MAMLLayer
+from core.utils import accuracy
+
+
+class METAL(MetaModel):
+
+    def __init__(self, inner_param, feat_dim, **kwargs):
+        super(METAL, self).__init__(**kwargs)
+        self.classifier = MAMLLayer(feat_dim, way_num=self.way_num)
+        base_learner_num_layers = len(list(self.classifier.named_parameters()))
+        support_meta_loss_num_dim = base_learner_num_layers + 2 * self.way_num + 1
+        support_adapter_num_dim = base_learner_num_layers + 1
+        query_num_dim = base_learner_num_layers + 1 + self.way_num
+
+        self.loss_func = MetaLossNetwork(support_meta_loss_num_dim, inner_param)
+
+        self.query_loss_func = MetaLossNetwork(query_num_dim, inner_param)
+
+        self.loss_adapter = LossAdapter(support_adapter_num_dim, args=inner_param, num_loss_net_layers=2)
+
+        self.query_loss_adapter = LossAdapter(query_num_dim, args=inner_param, num_loss_net_layers=2)
+
+        self.feat_dim = feat_dim
+        self.inner_param = inner_param
+        convert_maml_module(self)
+
+    def forward_output(self, x):
+        out1 = self.emb_func(x)
+        out2 = self.classifier(out1)
+        return out2
+
+    def set_forward(self, batch):
+        image, global_target = batch  # unused global_target
+        image = image.to(self.device)
+        (
+            support_image,
+            query_image,
+            support_target,
+            query_target,
+        ) = self.split_by_episode(image, mode=2)
+        episode_size, _, c, h, w = support_image.size()
+
+        output_list = []
+        for i in range(episode_size):
+            # The reference implementation flattens the episode tensors the same way:
+            #   x_support_set_task = x_support_set_task.view(-1, c, h, w)
+            #   x_target_set_task = x_target_set_task.view(-1, c, h, w)
+            #   y_support_set_task = y_support_set_task.view(-1)
+            #   y_target_set_task = y_target_set_task.view(-1)
+            # images (x)
+            episode_support_image = support_image[i].contiguous().reshape(-1, c, h, w)
+            episode_query_image = query_image[i].contiguous().reshape(-1, c, h, w)
+            # targets (y)
+            episode_support_targets = support_target[i].reshape(-1)
+            episode_query_targets = query_target[i].reshape(-1)
+
+            self.set_forward_adaptation(episode_support_image, episode_query_image, episode_support_targets,
+                                        episode_query_targets)
+
+            output = self.forward_output(episode_query_image)
+
+            output_list.append(output)
+
+        output = torch.cat(output_list, dim=0)
+        acc = accuracy(output, query_target.contiguous().view(-1))
+        return output, acc
+
+    def set_forward_loss(self, batch):
+        image, global_target = batch  # unused global_target
+        image = image.to(self.device)
+        (
+            support_image,
+            query_image,
+            support_target,
+            query_target,
+        ) = self.split_by_episode(image, mode=2)
+        episode_size, _, c, h, w = support_image.size()
+        output_list = []
+        for i in range(episode_size):
+            episode_support_image = support_image[i].contiguous().reshape(-1, c, h, w)
+            episode_query_image = query_image[i].contiguous().reshape(-1, c, h, w)
+            episode_support_targets = support_target[i].reshape(-1)
+            episode_query_targets = query_target[i].reshape(-1)
+            self.set_forward_adaptation(episode_support_image, episode_query_image, episode_support_targets,
+                                        episode_query_targets)
+
+            output = self.forward_output(episode_query_image)
+
+            output_list.append(output)
+
+        output = torch.cat(output_list, dim=0)
+        loss = F.cross_entropy(output, query_target.contiguous().view(-1))
+        acc = accuracy(output, query_target.contiguous().view(-1))
+        return output, acc, loss
+
+    def set_forward_adaptation(self, support_set, query_set, support_target, query_target):
+        lr = self.inner_param["lr"]
+        fast_parameters = list(self.classifier.parameters())
+        for parameter in self.classifier.parameters():
+            parameter.fast = None
+
+        self.emb_func.train()
+        self.classifier.train()
+        for i in range(
+            self.inner_param["train_iter"]
+            if self.training
+            else self.inner_param["test_iter"]
+        ):  # i is the inner-loop step index (num_step)
+            # adapt the loss weights;
+            # support_set/query_set and support_target/query_target correspond to
+            # x/x_t and y/y_t in the reference implementation
+            tmp_preds = self.forward_output(x=torch.cat((support_set, query_set), 0))
+            support_preds = tmp_preds[:-query_set.size(0)]
+            query_preds = tmp_preds[-query_set.size(0):]
+            weights = dict(self.classifier.named_parameters())  # named parameters of the classifier
+            meta_loss_weights = dict(self.loss_func.named_parameters())  # named parameters of loss_func
+            meta_query_loss_weights = dict(self.query_loss_func.named_parameters())  # named parameters of query_loss_func
+
+            support_task_state = []
+
+            support_loss = F.cross_entropy(input=support_preds,
+                                           target=support_target)
+            support_task_state.append(support_loss)
+
+            for v in weights.values():
+                support_task_state.append(v.mean())
+
+            support_task_state = torch.stack(support_task_state)
+            adapt_support_task_state = (support_task_state - support_task_state.mean()) / (
+                    support_task_state.std() + 1e-12)
+
+            updated_meta_loss_weights = self.loss_adapter(adapt_support_task_state, i, meta_loss_weights)
+
+            support_y = torch.zeros(support_preds.shape).to(support_preds.device)
+            support_y[torch.arange(support_y.size(0)), support_target] = 1
+            support_task_state = torch.cat((
+                support_task_state.view(1, -1).expand(support_preds.size(0), -1),
+                support_preds,
+                support_y
+            ), -1)
+
+            support_task_state = (support_task_state - support_task_state.mean()) / (support_task_state.std() + 1e-12)
+            meta_support_loss = self.loss_func(support_task_state, i,
+                                               params=updated_meta_loss_weights).mean().squeeze()
+
+            query_task_state = []
+            for v in weights.values():
+                query_task_state.append(v.mean())
+            out_prob = F.log_softmax(query_preds, dim=-1)
+            instance_entropy = torch.sum(torch.exp(out_prob) * out_prob, dim=-1)
+            query_task_state = torch.stack(query_task_state)
+            query_task_state = torch.cat((
+                query_task_state.view(1, -1).expand(instance_entropy.size(0), -1),
+                query_preds,
+                instance_entropy.view(-1, 1)
+            ), -1)
+
+            query_task_state = (query_task_state - query_task_state.mean()) / (query_task_state.std() + 1e-12)
+            updated_meta_query_loss_weights = self.query_loss_adapter(query_task_state.mean(0), i,
+                                                                      meta_query_loss_weights)
+
+            meta_query_loss = self.query_loss_func(query_task_state, i,
+                                                   params=updated_meta_query_loss_weights).mean().squeeze()
+
+            loss = support_loss + meta_query_loss + meta_support_loss
+
+            # inner-loop update: differentiate the combined loss w.r.t. the fast parameters
+            grad = torch.autograd.grad(loss, fast_parameters, create_graph=True, allow_unused=True)
+            fast_parameters = []
+
+            for k, weight in enumerate(list(self.classifier.parameters())):
+                if grad[k] is not None:
+                    if weight.fast is None:
+                        weight.fast = weight - lr * grad[k]
+                    else:
+                        weight.fast = weight.fast - lr * grad[k]
+                    fast_parameters.append(weight.fast)
+
+
+def extract_top_level_dict(current_dict):
+    output_dict = dict()
+    for key in current_dict.keys():
+        name = key.replace("layer_dict.", "")
+        name = name.replace("block_dict.", "")
+        name = name.replace("module-", "")
+        top_level = name.split(".")[0]
+        sub_level = ".".join(name.split(".")[1:])
+
+        if top_level not in output_dict:
+            if sub_level == "":
+                output_dict[top_level] = current_dict[key]
+            else:
+                output_dict[top_level] = {sub_level: current_dict[key]}
+        else:
+            new_item = {key: value for key, value in output_dict[top_level].items()}
+            new_item[sub_level] = current_dict[key]
+            output_dict[top_level] = new_item
+
+    return output_dict
+
+
+class MetaLinearLayer(nn.Module):
+    def __init__(self, input_shape, num_filters, use_bias):
+        """
+        A MetaLinear layer. Applies the same functionality of a standard linear layer with the added functionality of
+        being able to receive a parameter dictionary at the forward pass which allows the layer to use external
+        weights instead of the internal ones stored in the linear layer. Useful for inner loop optimization in the meta
+        learning setting.
+        :param input_shape: The shape of the input data, in the form (b, f)
+        :param num_filters: Number of output filters
+        :param use_bias: Whether to use biases or not.
+ """ + super(MetaLinearLayer, self).__init__() + b, c = input_shape + + self.use_bias = use_bias + self.weights = nn.Parameter(torch.ones(num_filters, c)) + nn.init.xavier_uniform_(self.weights) + if self.use_bias: + self.bias = nn.Parameter(torch.zeros(num_filters)) + + def forward(self, x, params=None): + + if params is not None: + params = extract_top_level_dict(current_dict=params) + if self.use_bias: + (weight, bias) = params["weights"], params["bias"] + else: + (weight) = params["weights"] + bias = None + else: + pass + if self.use_bias: + weight, bias = self.weights, self.bias + else: + weight = self.weights + bias = None + out = F.linear(input=x, weight=weight, bias=bias) + return out + + +class MetaStepLossNetwork(nn.Module): + def __init__(self, input_dim, args): + super(MetaStepLossNetwork, self).__init__() + + self.args = args + self.input_dim = input_dim + self.input_shape = (1, input_dim) + + self.build_network() + print("meta network params") + for name, param in self.named_parameters(): + print(name, param.shape) + + def build_network(self): + """ + Builds the network before inference is required by creating some dummy inputs with the same input as the + self.im_shape tuple. Then passes that through the network and dynamically computes input shapes and + sets output shapes for each layer. + """ + x = torch.zeros(self.input_shape) + out = x + + self.linear1 = MetaLinearLayer(input_shape=self.input_shape, + num_filters=self.input_dim, use_bias=True) + + self.linear2 = MetaLinearLayer(input_shape=(1, self.input_dim), + num_filters=1, use_bias=True) + + out = self.linear1(out) + out = F.relu_(out) + out = self.linear2(out) + + def forward(self, x, params=None): + + linear1_params = None + linear2_params = None + + if params is not None: + params = extract_top_level_dict(current_dict=params) + + linear1_params = params['linear1'] + linear2_params = params['linear2'] + + out = x + + out = self.linear1(out, linear1_params) + out = F.relu_(out) + out = self.linear2(out, linear2_params) + + return out + + def restore_backup_stats(self): + """ + Reset stored batch statistics from the stored backup. + """ + for i in range(self.num_stages): + self.layer_dict['conv{}'.format(i)].restore_backup_stats() + + +class MetaLossNetwork(nn.Module): + def __init__(self, input_dim, args): + """ + Builds a multilayer convolutional network. It also provides functionality for passing external parameters to be + used at inference time. Enables inner loop optimization readily. + :param input_dim: The input image batch shape. + :param args: A named tuple containing the system's hyperparameters. + """ + super(MetaLossNetwork, self).__init__() + + self.args = args + self.input_dim = input_dim + self.input_shape = (1, input_dim) + + self.num_steps = args['train_iter'] # number of inner-loop steps + + self.build_network() + print("meta network params") + for name, param in self.named_parameters(): + print(name, param.shape) + + def build_network(self): + """ + Builds the network before inference is required by creating some dummy inputs with the same input as the + self.im_shape tuple. Then passes that through the network and dynamically computes input shapes and + sets output shapes for each layer. 
+ """ + x = torch.zeros(self.input_shape) + self.layer_dict = nn.ModuleDict() + + for i in range(self.num_steps): + self.layer_dict['step{}'.format(i)] = MetaStepLossNetwork(self.input_dim, args=self.args) + + out = self.layer_dict['step{}'.format(i)](x) + + def forward(self, x, num_step, params=None): + param_dict = dict() + + if params is not None: + params = {key: value for key, value in params.items()} + param_dict = extract_top_level_dict(current_dict=params) + + for name, param in self.layer_dict.named_parameters(): + path_bits = name.split(".") + layer_name = path_bits[0] + if layer_name not in param_dict: + param_dict[layer_name] = None + + out = x + + out = self.layer_dict['step{}'.format(num_step)](out, param_dict['step{}'.format(num_step)]) + + return out + + def restore_backup_stats(self): + """ + Reset stored batch statistics from the stored backup. + """ + for i in range(self.num_stages): + self.layer_dict['conv{}'.format(i)].restore_backup_stats() + + +class StepLossAdapter(nn.Module): + def __init__(self, input_dim, num_loss_net_layers, args): + super(StepLossAdapter, self).__init__() + + self.args = args + output_dim = num_loss_net_layers * 2 * 2 # 2 for weight and bias, another 2 for multiplier and offset + + self.linear1 = nn.Linear(input_dim, input_dim) + self.activation = nn.ReLU(inplace=True) + self.linear2 = nn.Linear(input_dim, output_dim) + + self.multiplier_bias = nn.Parameter(torch.zeros(output_dim // 2)) + self.offset_bias = nn.Parameter(torch.zeros(output_dim // 2)) + + def forward(self, task_state, num_step, loss_params): + + out = self.linear1(task_state) + out = F.relu_(out) + out = self.linear2(out) + + generated_multiplier, generated_offset = torch.chunk(out, chunks=2, dim=-1) + + i = 0 + updated_loss_weights = dict() + for key, val in loss_params.items(): + if 'step{}'.format(num_step) in key: + updated_loss_weights[key] = (1 + self.multiplier_bias[i] * generated_multiplier[i]) * val + \ + self.offset_bias[i] * generated_offset[i] + i += 1 + + return updated_loss_weights + + +class LossAdapter(nn.Module): + def __init__(self, input_dim, num_loss_net_layers, args): + super(LossAdapter, self).__init__() + + self.args = args + + self.num_steps = args['train_iter'] # number of inner-loop steps + self.loss_adapter = nn.ModuleList() + for i in range(self.num_steps): + self.loss_adapter.append(StepLossAdapter(input_dim, num_loss_net_layers, args)) + + def forward(self, task_state, num_step, loss_params): + return self.loss_adapter[num_step](task_state, num_step, loss_params) diff --git a/core/model/meta/metal_util.py b/core/model/meta/metal_util.py new file mode 100644 index 00000000..c3991aca --- /dev/null +++ b/core/model/meta/metal_util.py @@ -0,0 +1,282 @@ +import torch.nn as nn +import torch.nn.functional as F +import torch +import numpy as np + + +class MetaLinearLayer(nn.Module): + def __init__(self, input_shape, num_filters, use_bias): + """ + A MetaLinear layer. Applies the same functionality of a standard linearlayer with the added functionality of + being able to receive a parameter dictionary at the forward pass which allows the convolution to use external + weights instead of the internal ones stored in the linear layer. Useful for inner loop optimization in the meta + learning setting. + :param input_shape: The shape of the input data, in the form (b, f) + :param num_filters: Number of output filters + :param use_bias: Whether to use biases or not. 
+ """ + super(MetaLinearLayer, self).__init__() + b, c = input_shape + + self.use_bias = use_bias + self.weights = nn.Parameter(torch.ones(num_filters, c)) + nn.init.xavier_uniform_(self.weights) + if self.use_bias: + self.bias = nn.Parameter(torch.zeros(num_filters)) + + def forward(self, x, params=None): + """ + Forward propagates by applying a linear function (Wx + b). If params are none then internal params are used. + Otherwise passed params will be used to execute the function. + :param x: Input data batch, in the form (b, f) + :param params: A dictionary containing 'weights' and 'bias'. If params are none then internal params are used. + Otherwise the external are used. + :return: The result of the linear function. + """ + if params is not None: + params = extract_top_level_dict(current_dict=params) + if self.use_bias: + (weight, bias) = params["weights"], params["bias"] + else: + (weight) = params["weights"] + bias = None + else: + pass + # print('no inner loop params', self) + + if self.use_bias: + weight, bias = self.weights, self.bias + else: + weight = self.weights + bias = None + # print(x.shape) + out = F.linear(input=x, weight=weight, bias=bias) + return out + + +def extract_top_level_dict(current_dict): + output_dict = dict() + for key in current_dict.keys(): + name = key.replace("layer_dict.", "") + name = name.replace("layer_dict.", "") + name = name.replace("block_dict.", "") + name = name.replace("module-", "") + top_level = name.split(".")[0] + sub_level = ".".join(name.split(".")[1:]) + + if top_level not in output_dict: + if sub_level == "": + output_dict[top_level] = current_dict[key] + else: + output_dict[top_level] = {sub_level: current_dict[key]} + else: + new_item = {key: value for key, value in output_dict[top_level].items()} + new_item[sub_level] = current_dict[key] + output_dict[top_level] = new_item + + # print(current_dict.keys(), output_dict.keys()) + return output_dict + + +class MetaStepLossNetwork(nn.Module): + def __init__(self, input_dim, device): + super(MetaStepLossNetwork, self).__init__() + + self.linear2 = None + self.linear1 = None + self.device = device + self.input_dim = input_dim + self.input_shape = (1, input_dim) + + self.build_network() + print("meta network params") + for name, param in self.named_parameters(): + print(name, param.shape) + + def build_network(self): + """ + Builds the network before inference is required by creating some dummy inputs with the same input as the + self.im_shape tuple. Then passes that through the network and dynamically computes input shapes and + sets output shapes for each layer. 
+ """ + x = torch.zeros(self.input_shape) + out = x + + self.linear1 = MetaLinearLayer(input_shape=self.input_shape, + num_filters=self.input_dim, use_bias=True) + + self.linear2 = MetaLinearLayer(input_shape=(1, self.input_dim), + num_filters=1, use_bias=True) + + out = self.linear1(out) + out = F.relu_(out) + out = self.linear2(out) + + def forward(self, x, params=None): + + linear1_params = None + linear2_params = None + + if params is not None: + params = extract_top_level_dict(current_dict=params) + + linear1_params = params['linear1'] + linear2_params = params['linear2'] + + out = x + + out = self.linear1(out, linear1_params) + out = F.relu_(out) + out = self.linear2(out, linear2_params) + + return out + + def zero_grad(self, params=None): + if params is None: + for param in self.parameters(): + if param.requires_grad: + if param.grad is not None: + if torch.sum(param.grad) > 0: + print(param.grad) + param.grad.zero_() + else: + for name, param in params.items(): + if param.requires_grad: + if param.grad is not None: + if torch.sum(param.grad) > 0: + print(param.grad) + param.grad.zero_() + params[name].grad = None + + def restore_backup_stats(self): + """ + Reset stored batch statistics from the stored backup. + """ + for i in range(self.num_stages): + self.layer_dict['conv{}'.format(i)].restore_backup_stats() + + +class MetaLossNetwork(nn.Module): + def __init__(self, input_dim, device): + + super(MetaLossNetwork, self).__init__() + + self.layer_dict = None + self.device = device + self.input_dim = input_dim + self.input_shape = (1, input_dim) + # TODO 修改成配置文件 num_steps + self.num_steps = 5 + + self.build_network() + print("meta network params") + for name, param in self.named_parameters(): + print(name, param.shape) + + def build_network(self): + """ + Builds the network before inference is required by creating some dummy inputs with the same input as the + self.im_shape tuple. Then passes that through the network and dynamically computes input shapes and + sets output shapes for each layer. + """ + x = torch.zeros(self.input_shape) + self.layer_dict = nn.ModuleDict() + + for i in range(self.num_steps): + self.layer_dict['step{}'.format(i)] = MetaStepLossNetwork(self.input_dim, + device=self.device) + + out = self.layer_dict['step{}'.format(i)](x) + + def forward(self, x, num_step, params=None): + param_dict = dict() + + if params is not None: + params = {key: value[0] for key, value in params.items()} + param_dict = extract_top_level_dict(current_dict=params) + + for name, param in self.layer_dict.named_parameters(): + path_bits = name.split(".") + layer_name = path_bits[0] + if layer_name not in param_dict: + param_dict[layer_name] = None + + out = x + + out = self.layer_dict['step{}'.format(num_step)](out, param_dict['step{}'.format(num_step)]) + + return out + + def zero_grad(self, params=None): + if params is None: + for param in self.parameters(): + if param.requires_grad: + if param.grad is not None: + if torch.sum(param.grad) > 0: + print(param.grad) + param.grad.zero_() + else: + for name, param in params.items(): + if param.requires_grad: + if param.grad is not None: + if torch.sum(param.grad) > 0: + print(param.grad) + param.grad.zero_() + params[name].grad = None + + def restore_backup_stats(self): + """ + Reset stored batch statistics from the stored backup. 
+ """ + for i in range(self.num_stages): + self.layer_dict['conv{}'.format(i)].restore_backup_stats() + + +class StepLossAdapter(nn.Module): + def __init__(self, input_dim, num_loss_net_layers, device): + super(StepLossAdapter, self).__init__() + + self.device = device + output_dim = num_loss_net_layers * 2 * 2 # 2 for weight and bias, another 2 for multiplier and offset + + self.linear1 = nn.Linear(input_dim, input_dim) + self.activation = nn.ReLU(inplace=True) + self.linear2 = nn.Linear(input_dim, output_dim) + + self.multiplier_bias = nn.Parameter(torch.zeros(output_dim // 2)) + self.offset_bias = nn.Parameter(torch.zeros(output_dim // 2)) + + def forward(self, task_state, num_step, loss_params): + + out = self.linear1(task_state) + out = F.relu_(out) + out = self.linear2(out) + + generated_multiplier, generated_offset = torch.chunk(out, chunks=2, dim=-1) + + i = 0 + updated_loss_weights = dict() + for key, val in loss_params.items(): + if 'step{}'.format(num_step) in key: + updated_loss_weights[key] = (1 + self.multiplier_bias[i] * generated_multiplier[i]) * val + \ + self.offset_bias[i] * generated_offset[i] + i += 1 + + return updated_loss_weights + + +class LossAdapter(nn.Module): + def __init__(self, input_dim, num_loss_net_layers, device): + super(LossAdapter, self).__init__() + + self.device = device + # TODO 修改成配置文件 num_steps + + self.num_steps = 5 # number of inn r-loop steps + + self.loss_adapter = nn.ModuleList() + for i in range(self.num_steps): + self.loss_adapter.append(StepLossAdapter(input_dim, num_loss_net_layers, device=device)) + + def forward(self, task_state, num_step, loss_params): + return self.loss_adapter[num_step](task_state, num_step, loss_params) diff --git a/core/model/meta/r2d2.py b/core/model/meta/r2d2.py index a1f2735e..ac409088 100644 --- a/core/model/meta/r2d2.py +++ b/core/model/meta/r2d2.py @@ -54,7 +54,7 @@ def binv(b_mat): """ id_matrix = b_mat.new_ones(b_mat.size(-1)).diag().expand_as(b_mat).to(b_mat.device) - b_inv, _ = torch.solve(id_matrix, b_mat) + b_inv = torch.linalg.solve(id_matrix, b_mat) return b_inv diff --git a/reproduce/MeTaL/METAL-miniImageNet--ravi-Conv64F-5-1-Table2.yaml b/reproduce/MeTaL/METAL-miniImageNet--ravi-Conv64F-5-1-Table2.yaml new file mode 100644 index 00000000..de095164 --- /dev/null +++ b/reproduce/MeTaL/METAL-miniImageNet--ravi-Conv64F-5-1-Table2.yaml @@ -0,0 +1,74 @@ +augment: true +augment_times: 1 +augment_times_query: 1 +backbone: + kwargs: + is_feature: false + is_flatten: true + last_pool: true + leaky_relu: false + negative_slope: 0.2 + name: Conv64F +batch_size: 128 +classifier: + kwargs: + feat_dim: 1600 + inner_param: + lr: 0.01 + test_iter: 5 + train_iter: 5 + name: METAL +data_root: datasets/miniImageNet--ravi +dataloader_num: 1 +deterministic: true +device_ids: 0 +episode_size: 2 +epoch: 100 +image_size: 84 +includes: +- headers/data.yaml +- headers/device.yaml +- headers/misc.yaml +- headers/model.yaml +- headers/optimizer.yaml +log_interval: 100 +log_level: info +log_name: null +log_paramerter: false +lr_scheduler: + kwargs: + gamma: 1.0 + step_size: 20 + name: StepLR +n_gpu: 1 +optimizer: + kwargs: + lr: 0.001 + name: Adam + other: null +parallel_part: +- emb_func +port: 31594 +pretrain_path: null +query_num: 15 +rank: 0 +result_root: ./results +resume: false +save_interval: 10 +save_part: +- emb_func +seed: 2147483647 +shot_num: 1 +tag: null +tb_scale: 3.3333333333333335 +test_episode: 600 +test_epoch: 5 +test_query: 15 +test_shot: 1 +test_way: 5 +train_episode: 2000 +use_memory: false +val_per_epoch: 1 
+warmup: 0
+way_num: 5
+workers: 1
diff --git a/reproduce/MeTaL/METAL-miniImageNet--ravi-Conv64F-5-5-Table2.yaml b/reproduce/MeTaL/METAL-miniImageNet--ravi-Conv64F-5-5-Table2.yaml
new file mode 100644
index 00000000..6d68290a
--- /dev/null
+++ b/reproduce/MeTaL/METAL-miniImageNet--ravi-Conv64F-5-5-Table2.yaml
@@ -0,0 +1,74 @@
+augment: true
+augment_times: 1
+augment_times_query: 1
+backbone:
+  kwargs:
+    is_feature: false
+    is_flatten: true
+    last_pool: true
+    leaky_relu: false
+    negative_slope: 0.2
+  name: Conv64F
+batch_size: 128
+classifier:
+  kwargs:
+    feat_dim: 1600
+    inner_param:
+      lr: 0.01
+      test_iter: 5
+      train_iter: 5
+  name: METAL
+data_root: datasets/miniImageNet--ravi
+dataloader_num: 1
+deterministic: true
+device_ids: 0
+episode_size: 2
+epoch: 100
+image_size: 84
+includes:
+- headers/data.yaml
+- headers/device.yaml
+- headers/misc.yaml
+- headers/model.yaml
+- headers/optimizer.yaml
+log_interval: 100
+log_level: info
+log_name: null
+log_paramerter: false
+lr_scheduler:
+  kwargs:
+    gamma: 1.0
+    step_size: 20
+  name: StepLR
+n_gpu: 1
+optimizer:
+  kwargs:
+    lr: 0.001
+  name: Adam
+  other: null
+parallel_part:
+- emb_func
+port: 25269
+pretrain_path: null
+query_num: 15
+rank: 0
+result_root: ./results
+resume: false
+save_interval: 10
+save_part:
+- emb_func
+seed: 2147483647
+shot_num: 5
+tag: null
+tb_scale: 3.3333333333333335
+test_episode: 600
+test_epoch: 5
+test_query: 15
+test_shot: 5
+test_way: 5
+train_episode: 2000
+use_memory: false
+val_per_epoch: 1
+warmup: 0
+way_num: 5
+workers: 1
diff --git a/reproduce/MeTaL/METAL-miniImageNet--ravi-resnet12-5-1-Table2.yaml b/reproduce/MeTaL/METAL-miniImageNet--ravi-resnet12-5-1-Table2.yaml
new file mode 100644
index 00000000..36f989ab
--- /dev/null
+++ b/reproduce/MeTaL/METAL-miniImageNet--ravi-resnet12-5-1-Table2.yaml
@@ -0,0 +1,69 @@
+augment: true
+augment_times: 1
+augment_times_query: 1
+backbone:
+  kwargs: null
+  name: resnet12
+batch_size: 128
+classifier:
+  kwargs:
+    feat_dim: 640
+    inner_param:
+      lr: 0.01
+      test_iter: 5
+      train_iter: 5
+  name: METAL
+data_root: datasets/miniImageNet--ravi
+dataloader_num: 1
+deterministic: true
+device_ids: 0
+episode_size: 2
+epoch: 100
+image_size: 84
+includes:
+- headers/data.yaml
+- headers/device.yaml
+- headers/misc.yaml
+- headers/model.yaml
+- headers/optimizer.yaml
+log_interval: 100
+log_level: info
+log_name: null
+log_paramerter: false
+lr_scheduler:
+  kwargs:
+    gamma: 1.0
+    step_size: 20
+  name: StepLR
+n_gpu: 1
+optimizer:
+  kwargs:
+    lr: 0.001
+  name: Adam
+  other: null
+parallel_part:
+- emb_func
+port: 46621
+pretrain_path: null
+query_num: 3
+rank: 0
+result_root: ./results
+resume: false
+save_interval: 10
+save_part:
+- emb_func
+seed: 2147483647
+shot_num: 1
+tag: null
+tb_scale: 3.3333333333333335
+test_episode: 600
+test_epoch: 5
+test_query: 3
+test_shot: 1
+test_way: 5
+train_episode: 2000
+use_memory: false
+val_per_epoch: 1
+warmup: 0
+way_num: 5
+workers: 1
diff --git a/reproduce/MeTaL/METAL-miniImageNet--ravi-resnet12-5-5-Table2.yaml b/reproduce/MeTaL/METAL-miniImageNet--ravi-resnet12-5-5-Table2.yaml
new file mode 100644
index 00000000..6b7f9fb5
--- /dev/null
+++ b/reproduce/MeTaL/METAL-miniImageNet--ravi-resnet12-5-5-Table2.yaml
@@ -0,0 +1,69 @@
+augment: true
+augment_times: 1
+augment_times_query: 1
+backbone:
+  kwargs: null
+  name: resnet12
+batch_size: 64
+classifier:
+  kwargs:
+    feat_dim: 640
+    inner_param:
+      lr: 0.01
+      test_iter: 5
+      train_iter: 5
+  name: METAL
+data_root: datasets/miniImageNet--ravi
+dataloader_num: 1
+deterministic: true
+device_ids: 0
+episode_size: 1
+epoch: 100
+image_size: 84
+includes:
+- headers/data.yaml
+- headers/device.yaml
+- headers/misc.yaml
+- headers/model.yaml
+- headers/optimizer.yaml
+log_interval: 100
+log_level: info
+log_name: null
+log_paramerter: false
+lr_scheduler:
+  kwargs:
+    gamma: 1.0
+    step_size: 20
+  name: StepLR
+n_gpu: 1
+optimizer:
+  kwargs:
+    lr: 0.001
+  name: Adam
+  other: null
+parallel_part:
+- emb_func
+port: 44341
+pretrain_path: null
+query_num: 3
+rank: 0
+result_root: ./results
+resume: false
+save_interval: 10
+save_part:
+- emb_func
+seed: 2147483647
+shot_num: 5
+tag: null
+tb_scale: 3.3333333333333335
+test_episode: 600
+test_epoch: 5
+test_query: 3
+test_shot: 5
+test_way: 5
+train_episode: 2000
+use_memory: false
+val_per_epoch: 1
+warmup: 0
+way_num: 5
+workers: 1
diff --git a/reproduce/MeTaL/README.md b/reproduce/MeTaL/README.md
new file mode 100644
index 00000000..0e33b53f
--- /dev/null
+++ b/reproduce/MeTaL/README.md
@@ -0,0 +1,27 @@
+# MeTAL
+## Introduction
+| Name:   | [MeTAL](https://arxiv.org/abs/2110.03909) |
+|---------|-------------------------------------------|
+| Embed.: | Conv64F, ResNet12 |
+| Type:   | Meta |
+| Venue:  | ICCV'21 |
+| Codes:  | [**MeTAL**](https://github.com/baiksung/MeTAL) |
+
+Cite this work with:
+```bibtex
+@InProceedings{baik2021meta,
+  title={Meta-Learning with Task-Adaptive Loss Function for Few-Shot Learning},
+  author={Baik, Sungyong and Choi, Janghoon and Kim, Heewon and Cho, Dohee and Min, Jaesik and Lee, Kyoung Mu},
+  booktitle={International Conference on Computer Vision (ICCV)},
+  year={2021}
+}
+```
+---
+## Results and Models
+
+**Classification**
+
+|   | Embedding | :book: *mini*ImageNet (5,1) | :computer: *mini*ImageNet (5,1) | :book: *mini*ImageNet (5,5) | :computer: *mini*ImageNet (5,5) | :memo: Comments |
+|---|-----------|-----------------------------|---------------------------------|-----------------------------|---------------------------------|-----------------|
+| 1 | Conv64F | - | 52.364 [:arrow_down:](https://drive.google.com/file/d/1ljtq5PH7VywDh2ZInqWzCOn5Lowu0zyC/view?usp=drive_link) [:clipboard:](./METAL-miniImageNet--ravi-Conv64F-5-1-Table2.yaml) | - | 70.421 [:arrow_down:](https://drive.google.com/file/d/1lzgeg4ckxSP1Zu-E_f4gfMenkK49_2tV/view?usp=drive_link) [:clipboard:](./METAL-miniImageNet--ravi-Conv64F-5-5-Table2.yaml) | Table.2 |
+| 2 | ResNet12 | - | 60.542 [:arrow_down:](https://drive.google.com/file/d/1qLrWig2eq85wxXkZrP6XGzKqnL6RO3IS/view?usp=drive_link) [:clipboard:](./METAL-miniImageNet--ravi-resnet12-5-1-Table2.yaml) | - | 76.880 [:arrow_down:](https://drive.google.com/file/d/1fNUAd9gpKHUeoOSkkVzQj9BPmITnFQEx/view?usp=drive_link) [:clipboard:](./METAL-miniImageNet--ravi-resnet12-5-5-Table2.yaml) | Table.2 |
diff --git a/reproduce/README.md b/reproduce/README.md
index a014fa7e..406659de 100644
--- a/reproduce/README.md
+++ b/reproduce/README.md
@@ -175,6 +175,7 @@ This folder contains:
 67.65
 68.17
+
 DN4
 Conv64F
@@ -206,6 +207,21 @@ This folder contains:
 82.58
 82.13
+
+MeTAL
+Conv64F
+52.63
+52.36
+70.52
+70.42
+
+
+ResNet12
+59.64
+60.54
+76.20
+76.88
+
diff --git a/results/README.md b/results/README.md
deleted file mode 100644
index 092f9b26..00000000
--- a/results/README.md
+++ /dev/null
@@ -1 +0,0 @@
-This folder contains all training and testing results.
diff --git a/run_test.py b/run_test.py
index 958c87f2..54896126 100644
--- a/run_test.py
+++ b/run_test.py
@@ -9,11 +9,11 @@ from core import Test
 
-PATH = "./results/DN4-miniImageNet--ravi-Conv64F-5-1-Dec-01-2021-06-05-20"
+PATH = "./results/DN4-WebCaricature-Conv64F-5-5-Nov-17-2023-19-42-01"
 VAR_DICT = {
     "test_epoch": 5,
-    "device_ids": "4,5",
-    "n_gpu": 2,
+    "device_ids": "0",
+    "n_gpu": 1,
     "test_episode": 600,
     "episode_size": 2,
 }
diff --git a/run_trainer.py b/run_trainer.py
index be786ee7..5b1cc0e9 100644
--- a/run_trainer.py
+++ b/run_trainer.py
@@ -15,10 +15,9 @@ def main(rank, config):
 
 if __name__ == "__main__":
-    config = Config("./config/proto.yaml").get_config_dict()
-
+    config = Config("./config/metal.yaml").get_config_dict()
     if config["n_gpu"] > 1:
         os.environ["CUDA_VISIBLE_DEVICES"] = config["device_ids"]
         torch.multiprocessing.spawn(main, nprocs=config["n_gpu"], args=(config,))
     else:
-        main(0, config)
\ No newline at end of file
+        main(0, config)
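A note on the inner-loop plumbing in `metal.py`: `extract_top_level_dict` peels one nesting level off a flat `named_parameters()`-style dict so that each sub-module (`linear1`, `linear2`, `step{i}`) can pick out its own slice of the adapted fast weights. A minimal sketch of the expected behavior, not part of the patch; the tensors and shapes here are illustrative only:

```python
import torch

from core.model.meta.metal import extract_top_level_dict

# flat parameter dict, as produced by named_parameters() on MetaStepLossNetwork
flat = {
    "linear1.weights": torch.zeros(8, 8),
    "linear1.bias": torch.zeros(8),
    "linear2.weights": torch.zeros(1, 8),
    "linear2.bias": torch.zeros(1),
}

nested = extract_top_level_dict(flat)

# one nesting level is peeled off: each sub-module gets its own sub-dict
assert set(nested) == {"linear1", "linear2"}
assert set(nested["linear1"]) == {"weights", "bias"}
```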
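The `r2d2.py` change deserves care: the deprecated `torch.solve(B, A)` solved `A X = B`, while `torch.linalg.solve(A, B)` takes the matrix first, so the arguments must be swapped when migrating in order to keep computing `b_mat^{-1}`. A quick self-check of the migrated call:

```python
import torch

# batch of well-conditioned (SPD) matrices
b_mat = torch.randn(4, 3, 3)
b_mat = b_mat @ b_mat.transpose(-1, -2) + 3 * torch.eye(3)
id_matrix = torch.eye(3).expand_as(b_mat)

# torch.linalg.solve(A, B) solves A X = B, so this yields b_mat^{-1}
b_inv = torch.linalg.solve(b_mat, id_matrix)

assert torch.allclose(b_mat @ b_inv, id_matrix, atol=1e-4)
```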
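With `run_trainer.py` now pointing at `config/metal.yaml`, training can be launched with `python run_trainer.py`, or equivalently from an interactive session via the single-GPU path, reusing the `main` shown in the diff (the `Config` import location is assumed from the repo layout; adjust if `Config` is exposed elsewhere):

```python
from core.config import Config
from run_trainer import main

# build the merged config (header includes + metal.yaml overrides) and run rank 0
config = Config("./config/metal.yaml").get_config_dict()
main(0, config)
```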