diff --git a/DMF_setup.py b/DMF_setup.py new file mode 100644 index 00000000..33510031 --- /dev/null +++ b/DMF_setup.py @@ -0,0 +1,75 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +# !/usr/bin/env python + +import glob +import os + +import torch +from setuptools import find_packages +from setuptools import setup +from torch.utils.cpp_extension import CUDA_HOME +from torch.utils.cpp_extension import CppExtension +from torch.utils.cpp_extension import CUDAExtension + +requirements = ["torch", "torchvision"] + + +def get_extensions(): + this_dir = os.path.dirname(os.path.abspath(__file__)) + extensions_dir = os.path.join(this_dir, "dconv", "csrc") + + main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) + source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) + source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) + + sources = main_file + source_cpu + extension = CppExtension + + extra_compile_args = {"cxx": []} + define_macros = [] + + if (torch.cuda.is_available() and CUDA_HOME is not None) or os.getenv("FORCE_CUDA", "0") == "1": + extension = CUDAExtension + sources += source_cuda + define_macros += [("WITH_CUDA", None)] + extra_compile_args["nvcc"] = [ + "-DCUDA_HAS_FP16=1", + "-D__CUDA_NO_HALF_OPERATORS__", + "-D__CUDA_NO_HALF_CONVERSIONS__", + "-D__CUDA_NO_HALF2_OPERATORS__", + ] + + sources = [os.path.join(extensions_dir, s) for s in sources] + + include_dirs = [extensions_dir] + + ext_modules = [ + extension( + "dconv._C", + sources, + include_dirs=include_dirs, + define_macros=define_macros, + extra_compile_args=extra_compile_args, + ) + ] + + return ext_modules + + +setup( + name="dconv", + version="0.1", + author="fmassa", + url="https://github.com/facebookresearch/maskrcnn-benchmark", + description="object detection in pytorch", + packages=find_packages( + exclude=( + "configs", + "tests", + ) + ), + # install_requires=requirements, + ext_modules=get_extensions(), + cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, +) diff --git a/README.md b/README.md index 5f9ed997..5d1a1597 100644 --- a/README.md +++ b/README.md @@ -36,6 +36,7 @@ Wenbin Li, Ziyi Wang, Xuesong Yang, Chuanqi Dong, Pinzhuo Tian, Tiexin Qin, Jing + [FEAT (CVPR 2020)](http://arxiv.org/abs/1812.03664) + [RENet (ICCV 2021)](https://arxiv.org/abs/2108.09666) + [FRN (CVPR 2021)](https://arxiv.org/abs/2012.01506) ++ [DMF (CVPR 2021)](https://arxiv.org/pdf/2103.13582) + [DeepBDC (CVPR 2022)](https://arxiv.org/abs/2204.04567) + [CPEA (ICCV 2023)](https://openaccess.thecvf.com/content/ICCV2023/papers/Hao_Class-Aware_Patch_Embedding_Adaptation_for_Few-Shot_Image_Classification_ICCV_2023_paper.pdf) diff --git a/config/DMF.yaml b/config/DMF.yaml new file mode 100644 index 00000000..dc2055f8 --- /dev/null +++ b/config/DMF.yaml @@ -0,0 +1,39 @@ +includes: +- headers/data.yaml +- headers/device.yaml +- headers/misc.yaml +- headers/model.yaml +- headers/optimizer.yaml + +way_num: 5 +shot_num: 1 +query_num: 6 +episode_size: 8 +train_episode: 2000 +test_episode: 1200 + +device_ids: 0,1,2,3 +n_gpu: 4 +epoch: 120 + +optimizer: + name: SGD + kwargs: + lr: 0.05 + momentum: 0.9 + nesterov: true + weight_decay: 0.0005 + other: null + +backbone: + name: resnet12_drop + kwargs: + drop_block: true + +classifier: + name: DMF + kwargs: + num_class: 64 + nFeat: 640 + kernel: 1 + groups: 64 diff --git a/config/backbones/resnet12_drop.yaml b/config/backbones/resnet12_drop.yaml new file mode 100644 index 00000000..9dcff29b --- /dev/null 
+++ b/config/backbones/resnet12_drop.yaml @@ -0,0 +1,4 @@ +backbone: + name: resnet12_drop + kwargs: + drop_block: true diff --git a/config/classifiers/DMF.yaml b/config/classifiers/DMF.yaml new file mode 100644 index 00000000..6a1c5c09 --- /dev/null +++ b/config/classifiers/DMF.yaml @@ -0,0 +1,7 @@ +classifier: + name: DMF + kwargs: + num_class: 64 + nFeat: 64 + kernel: 3 + groups: 1 diff --git a/core/model/backbone/__init__.py b/core/model/backbone/__init__.py index 7f9e431b..9c693a2d 100644 --- a/core/model/backbone/__init__.py +++ b/core/model/backbone/__init__.py @@ -8,6 +8,7 @@ from .swin_transformer import swin_s, swin_l, swin_b, swin_t, swin_mini from .resnet_bdc import resnet12Bdc, resnet18Bdc from core.model.backbone.utils.maml_module import convert_maml_module +from .resnet12_drop import resnet12_drop def get_backbone(config): diff --git a/core/model/backbone/resnet12_drop.py b/core/model/backbone/resnet12_drop.py new file mode 100644 index 00000000..d349a552 --- /dev/null +++ b/core/model/backbone/resnet12_drop.py @@ -0,0 +1,256 @@ +# -*- coding: utf-8 -*- +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.distributions import Bernoulli + + +class DropBlock(nn.Module): + def __init__(self, block_size): + super(DropBlock, self).__init__() + + self.block_size = block_size + # self.gamma = gamma + # self.bernouli = Bernoulli(gamma) + + def forward(self, x, gamma): + # shape: (bsize, channels, height, width) + + if self.training: + batch_size, channels, height, width = x.shape + + bernoulli = Bernoulli(gamma) + mask = bernoulli.sample( + ( + batch_size, + channels, + height - (self.block_size - 1), + width - (self.block_size - 1), + ) + ).cuda() + # print((x.sample[-2], x.sample[-1])) + block_mask = self._compute_block_mask(mask) + # print (block_mask.size()) + # print (x.size()) + countM = ( + block_mask.size()[0] + * block_mask.size()[1] + * block_mask.size()[2] + * block_mask.size()[3] + ) + count_ones = block_mask.sum() + + return block_mask * x * (countM / count_ones) + else: + return x + + def _compute_block_mask(self, mask): + left_padding = int((self.block_size - 1) / 2) + right_padding = int(self.block_size / 2) + + batch_size, channels, height, width = mask.shape + # print ("mask", mask[0][0]) + non_zero_idxs = mask.nonzero() + nr_blocks = non_zero_idxs.shape[0] + + offsets = ( + torch.stack( + [ + torch.arange(self.block_size) + .view(-1, 1) + .expand(self.block_size, self.block_size) + .reshape(-1), # - left_padding, + torch.arange(self.block_size).repeat(self.block_size), # - left_padding + ] + ) + .t() + .cuda() + ) + offsets = torch.cat((torch.zeros(self.block_size**2, 2).cuda().long(), offsets.long()), 1) + + if nr_blocks > 0: + non_zero_idxs = non_zero_idxs.repeat(self.block_size**2, 1) + offsets = offsets.repeat(nr_blocks, 1).view(-1, 4) + offsets = offsets.long() + + block_idxs = non_zero_idxs + offsets + # block_idxs += left_padding + padded_mask = F.pad(mask, (left_padding, right_padding, left_padding, right_padding)) + padded_mask[block_idxs[:, 0], block_idxs[:, 1], block_idxs[:, 2], block_idxs[:, 3]] = ( + 1.0 + ) + else: + padded_mask = F.pad(mask, (left_padding, right_padding, left_padding, right_padding)) + + block_mask = 1 - padded_mask # [:height, :width] + return block_mask + + +# This ResNet network was designed following the practice of the following papers: +# TADAM: Task dependent adaptive metric for improved few-shot learning (Oreshkin et al., in NIPS 2018) and +# A Simple Neural Attentive Meta-Learner (Mishra et al., 
in ICLR 2018). + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__( + self, + inplanes, + planes, + stride=1, + downsample=None, + drop_rate=0.0, + drop_block=False, + block_size=1, + pool=True, + ): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.LeakyReLU(0.1) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = conv3x3(planes, planes) + self.bn3 = nn.BatchNorm2d(planes) + self.maxpool = nn.MaxPool2d(stride, ceil_mode=True) + self.downsample = downsample + self.stride = stride + self.drop_rate = drop_rate + self.register_buffer("num_batches_tracked", torch.tensor(0)) + self.drop_block = drop_block + self.block_size = block_size + self.DropBlock = DropBlock(block_size=self.block_size) + self.pool = pool + + def forward(self, x): + self.num_batches_tracked += 1 + + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + out += residual + out = self.relu(out) + if self.pool: + out = self.maxpool(out) + + if self.drop_rate > 0: + if self.drop_block is True: + feat_size = out.size()[2] + keep_rate = max( + 1.0 - self.drop_rate / (20 * 2000) * (self.num_batches_tracked), + 1.0 - self.drop_rate, + ) + gamma = ( + (1 - keep_rate) + / self.block_size**2 + * feat_size**2 + / (feat_size - self.block_size + 1) ** 2 + ) + out = self.DropBlock(out, gamma=gamma) + else: + out = F.dropout(out, p=self.drop_rate, training=self.training, inplace=True) + + return out + + +class ResNet(nn.Module): + + def __init__(self, block, drop_block=False, drop_rate=0.1, dropblock_size=5): + self.inplanes = 3 + self.nFeat = 640 + super(ResNet, self).__init__() + + self.layer1 = self._make_layer(block, 64, stride=2, drop_rate=drop_rate) + self.layer2 = self._make_layer(block, 160, stride=2, drop_rate=drop_rate) + self.layer3 = self._make_layer( + block, + 320, + stride=2, + drop_rate=drop_rate, + drop_block=drop_block, + block_size=dropblock_size, + ) + self.layer4 = self._make_layer( + block, + 640, + stride=2, + drop_rate=drop_rate, + drop_block=drop_block, + block_size=dropblock_size, + pool=False, + ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="leaky_relu") + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def _make_layer( + self, block, planes, stride=1, drop_rate=0.0, drop_block=False, block_size=1, pool=True + ): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + self.inplanes, planes * block.expansion, kernel_size=1, stride=1, bias=False + ), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append( + block( + self.inplanes, planes, stride, downsample, drop_rate, drop_block, block_size, pool + ) + ) + self.inplanes = planes * block.expansion + + return nn.Sequential(*layers) + + def forward(self, x, return_feat=False, return_both=False, return_map=False): + + x1 = self.layer1(x) + x2 = self.layer2(x1) + x3 = self.layer3(x2) + x4 = self.layer4(x3) + # return 
[x1, x2, x3, x4] + # return [x4] + return x4 + # return [x3, x4] + # return [x2, x3, x4] + + def forward_as_dict(self, x): + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + # x = self.avgpool(x) + # x = x.view(x.size(0), -1) + # result = self.fc(x) + return x + + +def resnet12_drop(drop_block=True, **kwargs): + """Constructs a ResNet-12 model.""" + model = ResNet(BasicBlock, drop_block=drop_block, **kwargs) + return model diff --git a/core/model/meta/__init__.py b/core/model/meta/__init__.py index 29db67dc..559cdcee 100644 --- a/core/model/meta/__init__.py +++ b/core/model/meta/__init__.py @@ -7,4 +7,4 @@ from .leo import LEO from .mtl import MTL from .boil import BOIL -from .matchingnet import DMatchingNet +from .matchingnet_ifsl import DMatchingNet diff --git a/core/model/metric/DMF.py b/core/model/metric/DMF.py new file mode 100644 index 00000000..3bd2b6e8 --- /dev/null +++ b/core/model/metric/DMF.py @@ -0,0 +1,352 @@ +# -*- coding: utf-8 -*- +""" +@inproceedings{xu2021dmf, + title={Learning Dynamic Alignment via Meta-filter for Few-shot Learning}, + author={Chengming Xu and Chen Liu and Li Zhang and Chengjie Wang and Jilin Li and Feiyue Huang and Xiangyang Xue and Yanwei Fu}, + booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, + year={2021} +} +https://arxiv.org/pdf/2103.13582 + +Adapted from https://github.com/chmxu/Dynamic-Meta-filter. +""" +from __future__ import absolute_import +from __future__ import division + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +from dconv.layers import DeformConv +from torchdiffeq import odeint as odeint + +from core.utils import accuracy +from core.model.metric.metric_model import MetricModel + + +class CrossEntropyLoss(nn.Module): + def __init__(self): + super(CrossEntropyLoss, self).__init__() + self.logsoftmax = nn.LogSoftmax(dim=1) + + def forward(self, inputs, targets): + input_ = inputs + input_ = input_.contiguous().view(input_.size(0), input_.size(1), -1) + + log_probs = self.logsoftmax(input_) + targets_ = torch.zeros(input_.size(0), input_.size(1)).scatter_( + 1, targets.unsqueeze(1).data.cpu(), 1 + ) + targets_ = targets_.unsqueeze(-1) + targets_ = targets_.cuda() + loss = (-targets_ * log_probs).mean(0).sum() + return loss / input_.size(2) + + +def one_hot(labels_train): + """ + Turn the labels_train to one-hot encoding. 
+    Args:
+        labels_train: [batch_size, num_train_examples]
+    Return:
+        labels_train_1hot: [batch_size, num_train_examples, K]
+    """
+    labels_train = labels_train.cpu()
+    nKnovel = 1 + labels_train.max()
+    labels_train_1hot_size = list(labels_train.size()) + [
+        nKnovel,
+    ]
+    labels_train_unsqueeze = labels_train.unsqueeze(dim=labels_train.dim())
+    labels_train_1hot = torch.zeros(labels_train_1hot_size).scatter_(
+        len(labels_train_1hot_size) - 1, labels_train_unsqueeze, 1
+    )
+    return labels_train_1hot
+
+
+def shuffle(images, targets, global_targets):
+    """
+    A trick for META_FILTER training
+    """
+    sample_num = images.shape[1]
+    for i in range(4):
+        indices = torch.randperm(sample_num).to(images.device)
+        images = images.index_select(1, indices)
+        targets = targets.index_select(1, indices)
+        global_targets = global_targets.index_select(1, indices)
+    return images, targets, global_targets
+
+
+# Dynamic Sampling
+class DynamicWeights_(nn.Module):
+    def __init__(self, channels, dilation=1, kernel=3, groups=1):
+        super(DynamicWeights_, self).__init__()
+        self.softmax = nn.Softmax(dim=-1)
+
+        # padding = 1 if kernel == 3 else 0
+        # offset_groups = 1
+        self.off_conv = nn.Conv2d(
+            channels * 2, 3 * 3 * 2, 5, padding=2, dilation=dilation, bias=False
+        )
+        self.kernel_conv = DeformConv(
+            channels,
+            groups * kernel * kernel,
+            kernel_size=3,
+            padding=dilation,
+            dilation=dilation,
+            bias=False,
+        )
+
+        self.K = kernel * kernel
+        self.group = groups
+
+    def forward(self, support, query):
+        N, C, H, W = support.size()
+        # R = C // self.group
+        offset = self.off_conv(torch.cat([query, support], 1))  # learn the offset field for the deformable convolution
+        dynamic_filter = self.kernel_conv(support, offset)  # apply the deformable convolution
+        dynamic_filter = F.sigmoid(dynamic_filter)
+        return dynamic_filter
+
+
+class DynamicWeights(nn.Module):
+    def __init__(self, channels, dilation=1, kernel=3, groups=1, nFeat=640):
+        super(DynamicWeights, self).__init__()
+        self.softmax = nn.Softmax(dim=-1)
+        padding = 1 if kernel == 3 else 0
+        # offset_groups = 1
+        self.unfold = nn.Unfold(
+            kernel_size=(kernel, kernel), padding=padding, dilation=1
+        )  # unfold local patches of the input feature map into columns
+
+        self.K = kernel * kernel  # total kernel size (kernel * kernel)
+        self.group = groups  # number of groups for grouped convolution
+        self.nFeat = nFeat  # number of feature channels
+
+    def forward(self, t=None, x=None):
+        query, dynamic_filter = x
+        N, C, H, W = query.size()
+        N_, C_, H_, W_ = dynamic_filter.size()
+        R = C // self.group
+        # reshape the dynamic filter to a tensor of shape (-1, self.K)
+        dynamic_filter = dynamic_filter.reshape(-1, self.K)
+
+        xd_unfold = self.unfold(query)
+
+        xd_unfold = xd_unfold.contiguous().view(N, C, self.K, H * W)
+        xd_unfold = (
+            xd_unfold.permute(0, 1, 3, 2)
+            .contiguous()
+            .view(N, self.group, R, H * W, self.K)
+            .permute(0, 1, 3, 2, 4)
+            .contiguous()
+            .view(N * self.group * H * W, R, self.K)
+        )  # reshape
+        # batched matrix multiplication over each pair in the batch; unsqueeze() inserts a dimension at the given position
+        out1 = torch.bmm(xd_unfold, dynamic_filter.unsqueeze(2))
+        out1 = (
+            out1.contiguous()
+            .view(N, self.group, H * W, R)
+            .permute(0, 1, 3, 2)
+            .contiguous()
+            .view(N, self.group * R, H * W)
+            .view(N, self.group * R, H, W)
+        )
+
+        out1 = F.relu(out1)
+        return (out1, torch.zeros([N_, C_, H_, W_]).cuda())
+
+
+class ODEBlock(nn.Module):
+    def __init__(self, odefunc):
+        super(ODEBlock, self).__init__()
+        self.odefunc = odefunc  # the ODE dynamics function
+        self.integration_time = torch.tensor([0, 1]).float()  # integration interval [0, 1]
+
+    def forward(self, x):
+        self.integration_time = self.integration_time.type_as(
+            x[0]
+        )  # cast the integration time to the same dtype as the input x
+        out = odeint(
+            self.odefunc, x, self.integration_time, rtol=1e-2, atol=1e-2, method="rk4"
+        )  # ODE solver; rtol/atol are the relative/absolute error tolerances, "rk4" is the fourth-order Runge-Kutta method
+        return out[0][1]  # return the state at the second (final) time point of the integration
+
+    @property
+    def nfe(self):
+        return self.odefunc.nfe
+
+    @nfe.setter
+    def nfe(self, value):
+        self.odefunc.nfe = value
+
+
+class Model(nn.Module):
+    def __init__(self, num_classes=64, nFeat=640, kernel=3, groups=1):
+        super(Model, self).__init__()
+        self.nFeat = nFeat
+        self.global_clasifier = nn.Conv2d(self.nFeat, num_classes, kernel_size=1)  # f_gc in the paper's figure
+
+        self.dw_gen = DynamicWeights_(self.nFeat, 1, kernel, groups)
+        self.dw = ODEBlock(DynamicWeights(self.nFeat, 1, kernel, groups, self.nFeat))
+
+    # expand the dimensions of ftrain and ftest so that they share a common shape
+
+    def reshape(self, ftrain, ftest):
+        b, n1, c, h, w = ftrain.shape
+        n2 = ftest.shape[1]
+        ftrain = ftrain.unsqueeze(2).repeat(1, 1, n2, 1, 1, 1)
+        ftest = ftest.unsqueeze(1).repeat(1, n1, 1, 1, 1, 1)
+        return ftrain, ftest
+
+    def get_score(self, ftrain, ftest, num_train, num_test, batch_size):
+        b, n2, n1, c, h, w = ftrain.shape
+
+        ftrain_ = ftrain.clone()
+        ftest_ = ftest.clone()
+        ftrain_ = ftrain_.contiguous().view(-1, *ftrain.size()[3:])
+        # flatten ftrain_ and ftest_ to 4-D tensors of shape (b*n1*n2, c, h, w)
+        ftest_ = ftest_.contiguous().view(-1, *ftest.size()[3:])
+
+        ftrain_norm = F.normalize(ftrain, p=2, dim=3, eps=1e-12)  # L2-normalize ftrain
+        # flatten ftrain_norm to a 4-D tensor of shape (b*n1*n2, c, h, w)
+        ftrain_norm = ftrain_norm.reshape(-1, *ftrain_norm.size()[3:])
+        # learn a meta classifier via global average pooling:
+        # average twice while keeping the original dimensions to obtain the meta-classifier weights
+        conv_weight = ftrain_norm.mean(-1, keepdim=True).mean(-2, keepdim=True)
+        # the first mean averages over columns, the second over rows, giving shape (b*n1*n2, c, 1, 1)
+
+        # dynamic sampling on ftrain and ftest to learn a dynamic filter
+        filter_weight = self.dw_gen(ftrain_, ftest_)
+        cls_scores = self.dw(x=(ftest_, filter_weight))  # run the dynamic convolution as a neural ODE
+        cls_scores = cls_scores.contiguous().view(b * n2, n1, *cls_scores.size()[1:])
+        cls_scores = cls_scores.contiguous().view(1, -1, *cls_scores.size()[3:])
+        cls_scores = F.conv2d(
+            cls_scores, conv_weight, groups=b * n1 * n2, padding=1
+        )  # convolve the computed scores with the meta-classifier weights
+        cls_scores = cls_scores.contiguous().view(b * n2, n1, *cls_scores.size()[2:])
+        return cls_scores
+
+    def get_global_pred(self, ftest, ytest, num_test, batch_size, K):
+        h = ftest.shape[-1]  # h is the last (spatial) dimension of ftest
+        # reshape ftest to (batch_size, num_test, K, -1); the -1 is inferred
+        ftest_ = ftest.contiguous().view(batch_size, num_test, K, -1)
+        ftest_ = ftest_.transpose(2, 3)  # transpose the last two dimensions of ftest_
+        ytest_ = ytest.unsqueeze(3)  # add a trailing dimension to ytest
+        ftest_ = torch.matmul(ftest_, ytest_)  # batched matrix multiplication of ftest_ and ytest_
+        # reshape to (batch_size * num_test, -1, h, h), i.e. back to square feature maps
+        ftest_ = ftest_.contiguous().view(batch_size * num_test, -1, h, h)
+        global_pred = self.global_clasifier(ftest_)  # apply f_gc (the global classifier)
+        return global_pred
+
+    def get_test_score(self, score_list):
+        return score_list.mean(-1).mean(-1)
+
+    def forward(self, support_feat, query_feat, support_targets, query_targets, global_labels=None):
+        original_feat_shape = support_feat.size()
+        batch_size = support_feat.size(0)
+        n_support = support_feat.size(1)
+        n_query = query_feat.size(1)
+        # way_num = support_targets.size(-1)
+        K = support_targets.size(2)
+
+        labels_train_transposed = support_targets.transpose(1, 2)
+
+        prototypes = support_feat.contiguous().view(batch_size, n_support, -1)
+        prototypes = torch.bmm(labels_train_transposed, prototypes)
+        prototypes = prototypes.div(
+            labels_train_transposed.sum(dim=2, keepdim=True).expand_as(prototypes)
+        )
+        prototypes = prototypes.contiguous().view(batch_size, -1, *original_feat_shape[2:])
+        query_feat = query_feat.contiguous().view(batch_size, n_query, *original_feat_shape[2:])
+        prototypes, query_feat = 
self.reshape(prototypes, query_feat) + prototypes = prototypes.transpose(1, 2) + query_feat = query_feat.transpose(1, 2) + + cls_scores = self.get_score(prototypes, query_feat, n_support, n_query, batch_size) + + if not self.training: + return self.get_test_score(cls_scores) + + global_pred = self.get_global_pred(query_feat, query_targets, n_query, batch_size, K) + return global_pred, cls_scores + + +class META_FILTER(MetricModel): + def __init__(self, num_classes=64, nFeat=640, kernel=3, groups=1, **kwargs): + super(META_FILTER, self).__init__(**kwargs) + self.model = Model(num_classes, nFeat, kernel, groups) + self.criterion = CrossEntropyLoss() + + def set_forward(self, batch): + images, global_targets = batch + images = images.to(self.device) + global_targets = global_targets.to(self.device) + episode_size = images.size(0) // (self.way_num * (self.shot_num + self.query_num)) + emb = self.emb_func(images) + ( + support_feat, + query_feat, + support_targets, + query_targets, + ) = self.split_by_episode(emb, mode=2) + + # convert to one-hot + labels_train_1hot = one_hot(support_targets).to(self.device) + labels_test_1hot = one_hot(query_targets).to(self.device) + + cls_scores = self.model(support_feat, query_feat, labels_train_1hot, labels_test_1hot) + + cls_scores = cls_scores.reshape(episode_size * self.way_num * self.query_num, -1) + acc = accuracy(cls_scores, query_targets.reshape(-1), topk=1) + return cls_scores, acc + + def set_forward_loss(self, batch): + images, global_targets = batch + images = images.to(self.device) + global_targets = global_targets.to(self.device) + episode_size = images.size(0) // (self.way_num * (self.shot_num + self.query_num)) + emb = self.emb_func(images) + ( + support_feat, + query_feat, + support_targets, + query_targets, + ) = self.split_by_episode(emb, mode=2) + + support_targets = support_targets.reshape( + episode_size, self.way_num * self.shot_num + ).contiguous() + support_global_targets, query_global_targets = ( + global_targets[:, :, : self.shot_num], + global_targets[:, :, self.shot_num :], + ) + + support_feat, support_targets, support_global_targets = shuffle( + support_feat, + support_targets, + support_global_targets.reshape(*support_targets.size()), + ) + query_feat, query_targets, query_global_targets = shuffle( + query_feat, + query_targets.reshape(*query_feat.size()[:2]), + query_global_targets.reshape(*query_feat.size()[:2]), + ) + + # convert to one-hot + labels_train_1hot = one_hot(support_targets).to(self.device) + labels_test_1hot = one_hot(query_targets).to(self.device) + + ytest, cls_scores = self.model( + support_feat, query_feat, labels_train_1hot, labels_test_1hot + ) + # print(ytest.size()) + # print(query_global_targets.size()) + + loss1 = self.criterion(ytest, query_global_targets.contiguous().reshape(-1)) + loss2 = self.criterion(cls_scores, query_targets.view(-1)) + loss = loss1 + 0.5 * loss2 + + cls_scores = torch.sum(cls_scores.reshape(*cls_scores.size()[:2], -1), dim=-1) + acc = accuracy(cls_scores, query_targets.reshape(-1), topk=1) + return cls_scores, acc, loss diff --git a/core/model/metric/__init__.py b/core/model/metric/__init__.py index f41b9861..e791a1c3 100644 --- a/core/model/metric/__init__.py +++ b/core/model/metric/__init__.py @@ -12,4 +12,5 @@ from .dsn import DSN from .deepbdc import DeepBDC from .frn import FRN -from .meta_baseline import MetaBaseline \ No newline at end of file +from .meta_baseline import MetaBaseline +from .DMF import META_FILTER diff --git a/dconv/_C.cpython-38-x86_64-linux-gnu.so 
b/dconv/_C.cpython-38-x86_64-linux-gnu.so new file mode 100644 index 00000000..ef3e0dc1 Binary files /dev/null and b/dconv/_C.cpython-38-x86_64-linux-gnu.so differ diff --git a/dconv/__init__.py b/dconv/__init__.py new file mode 100644 index 00000000..8eb75b90 --- /dev/null +++ b/dconv/__init__.py @@ -0,0 +1,2 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. diff --git a/dconv/csrc/cpu/vision.h b/dconv/csrc/cpu/vision.h new file mode 100644 index 00000000..e4cb39d0 --- /dev/null +++ b/dconv/csrc/cpu/vision.h @@ -0,0 +1,3 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +#pragma once +#include diff --git a/dconv/csrc/cuda/deform_conv_cuda.cu b/dconv/csrc/cuda/deform_conv_cuda.cu new file mode 100644 index 00000000..d858c34a --- /dev/null +++ b/dconv/csrc/cuda/deform_conv_cuda.cu @@ -0,0 +1,946 @@ +// modify from +// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c + +#include +#include + +// #include +// #include + +#include +#include +#include + +void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset, + const int channels, const int height, const int width, + const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + at::Tensor data_col); + +void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset, + const int channels, const int height, const int width, + const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + at::Tensor grad_im); + +void deformable_col2im_coord( + const at::Tensor data_col, const at::Tensor data_im, + const at::Tensor data_offset, const int channels, const int height, + const int width, const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int parallel_imgs, + const int deformable_group, at::Tensor grad_offset); + +void modulated_deformable_im2col_cuda( + const at::Tensor data_im, const at::Tensor data_offset, + const at::Tensor data_mask, const int batch_size, const int channels, + const int height_im, const int width_im, const int height_col, + const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int deformable_group, + at::Tensor data_col); + +void modulated_deformable_col2im_cuda( + const at::Tensor data_col, const at::Tensor data_offset, + const at::Tensor data_mask, const int batch_size, const int channels, + const int height_im, const int width_im, const int height_col, + const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int deformable_group, + at::Tensor grad_im); + +void modulated_deformable_col2im_coord_cuda( + const at::Tensor data_col, const at::Tensor data_im, + const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, + const int width_im, const int height_col, const int width_col, + const int kernel_h, const 
int kenerl_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, const int dilation_h, + const int dilation_w, const int deformable_group, at::Tensor grad_offset, + at::Tensor grad_mask); + +void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput, + at::Tensor weight, int kH, int kW, int dH, int dW, int padH, + int padW, int dilationH, int dilationW, int group, + int deformable_group) +{ + TORCH_CHECK(weight.ndimension() == 4, + "4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, " + "but got: %s", + weight.ndimension()); + + TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); + + TORCH_CHECK(kW > 0 && kH > 0, + "kernel size should be greater than zero, but got kH: %d kW: %d", kH, + kW); + + TORCH_CHECK((weight.size(2) == kH && weight.size(3) == kW), + "kernel size should be consistent with weight, ", + "but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", kH, + kW, weight.size(2), weight.size(3)); + + TORCH_CHECK(dW > 0 && dH > 0, + "stride should be greater than zero, but got dH: %d dW: %d", dH, dW); + + TORCH_CHECK( + dilationW > 0 && dilationH > 0, + "dilation should be greater than 0, but got dilationH: %d dilationW: %d", + dilationH, dilationW); + + int ndim = input.ndimension(); + int dimf = 0; + int dimh = 1; + int dimw = 2; + + if (ndim == 4) + { + dimf++; + dimh++; + dimw++; + } + + TORCH_CHECK(ndim == 3 || ndim == 4, "3D or 4D input tensor expected but got: %s", + ndim); + + long nInputPlane = weight.size(1) * group; + long inputHeight = input.size(dimh); + long inputWidth = input.size(dimw); + long nOutputPlane = weight.size(0); + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + + TORCH_CHECK(nInputPlane % deformable_group == 0, + "input channels must divide deformable group size"); + + if (outputWidth < 1 || outputHeight < 1) + AT_ERROR( + "Given input size: (%ld x %ld x %ld). " + "Calculated output size: (%ld x %ld x %ld). 
Output size is too small", + nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight, + outputWidth); + + TORCH_CHECK(input.size(1) == nInputPlane, + "invalid number of input planes, expected: %d, but got: %d", + nInputPlane, input.size(1)); + + TORCH_CHECK((inputHeight >= kH && inputWidth >= kW), + "input image is smaller than kernel"); + + TORCH_CHECK((offset.size(2) == outputHeight && offset.size(3) == outputWidth), + "invalid spatial size of offset, expected height: %d width: %d, but " + "got height: %d width: %d", + outputHeight, outputWidth, offset.size(2), offset.size(3)); + + TORCH_CHECK((offset.size(1) == deformable_group * 2 * kH * kW), + "invalid number of channels of offset"); + + if (gradOutput != NULL) + { + TORCH_CHECK(gradOutput->size(dimf) == nOutputPlane, + "invalid number of gradOutput planes, expected: %d, but got: %d", + nOutputPlane, gradOutput->size(dimf)); + + TORCH_CHECK((gradOutput->size(dimh) == outputHeight && + gradOutput->size(dimw) == outputWidth), + "invalid size of gradOutput, expected height: %d width: %d , but " + "got height: %d width: %d", + outputHeight, outputWidth, gradOutput->size(dimh), + gradOutput->size(dimw)); + } +} + +int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight, + at::Tensor offset, at::Tensor output, + at::Tensor columns, at::Tensor ones, int kW, + int kH, int dW, int dH, int padW, int padH, + int dilationW, int dilationH, int group, + int deformable_group, int im2col_step) +{ + // todo: resize columns to include im2col: done + // todo: add im2col_step as input + // todo: add new output buffer and transpose it to output (or directly + // transpose output) todo: possibly change data indexing because of + // parallel_imgs + + shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW, + dilationH, dilationW, group, deformable_group); + + input = input.contiguous(); + offset = offset.contiguous(); + weight = weight.contiguous(); + + int batch = 1; + if (input.ndimension() == 3) + { + // Force batch + batch = 0; + input.unsqueeze_(0); + offset.unsqueeze_(0); + } + + // todo: assert batchsize dividable by im2col_step + + long batchSize = input.size(0); + long nInputPlane = input.size(1); + long inputHeight = input.size(2); + long inputWidth = input.size(3); + + long nOutputPlane = weight.size(0); + + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset"); + + output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane, + outputHeight, outputWidth}); + columns = at::zeros( + {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, + input.options()); + + if (ones.ndimension() != 2 || + ones.size(0) * ones.size(1) < outputHeight * outputWidth) + { + ones = at::ones({outputHeight, outputWidth}, input.options()); + } + + input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + offset = + offset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + at::Tensor output_buffer = + at::zeros({batchSize / im2col_step, nOutputPlane, + im2col_step * outputHeight, outputWidth}, + output.options()); + + output_buffer = output_buffer.view( + {output_buffer.size(0), group, output_buffer.size(1) / group, + output_buffer.size(2), output_buffer.size(3)}); + + for (int elt = 0; elt < batchSize / im2col_step; elt++) 
+ { + deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight, + inputWidth, kH, kW, padH, padW, dH, dW, dilationH, + dilationW, im2col_step, deformable_group, columns); + + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + + for (int g = 0; g < group; g++) + { + output_buffer[elt][g] = output_buffer[elt][g] + .flatten(1) + .addmm_(weight[g].flatten(1), columns[g]) + .view_as(output_buffer[elt][g]); + } + } + + output_buffer = output_buffer.view( + {output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2), + output_buffer.size(3), output_buffer.size(4)}); + + output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane, + im2col_step, outputHeight, outputWidth}); + output_buffer.transpose_(1, 2); + output.copy_(output_buffer); + output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth}); + + input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); + offset = offset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + if (batch == 0) + { + output = output.view({nOutputPlane, outputHeight, outputWidth}); + input = input.view({nInputPlane, inputHeight, inputWidth}); + offset = offset.view({offset.size(1), offset.size(2), offset.size(3)}); + } + + return 1; +} + +int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset, + at::Tensor gradOutput, at::Tensor gradInput, + at::Tensor gradOffset, at::Tensor weight, + at::Tensor columns, int kW, int kH, int dW, + int dH, int padW, int padH, int dilationW, + int dilationH, int group, + int deformable_group, int im2col_step) +{ + shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW, + dilationH, dilationW, group, deformable_group); + + input = input.contiguous(); + offset = offset.contiguous(); + gradOutput = gradOutput.contiguous(); + weight = weight.contiguous(); + + int batch = 1; + + if (input.ndimension() == 3) + { + // Force batch + batch = 0; + input = input.view({1, input.size(0), input.size(1), input.size(2)}); + offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)}); + gradOutput = gradOutput.view( + {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)}); + } + + long batchSize = input.size(0); + long nInputPlane = input.size(1); + long inputHeight = input.size(2); + long inputWidth = input.size(3); + + long nOutputPlane = weight.size(0); + + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + TORCH_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset"); + gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth}); + columns = at::zeros( + {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, + input.options()); + + // change order of grad output + gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step, + nOutputPlane, outputHeight, outputWidth}); + gradOutput.transpose_(1, 2); + + gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, + outputWidth}); + offset = + offset.view({batchSize / 
im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + for (int elt = 0; elt < batchSize / im2col_step; elt++) + { + // divide into groups + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + gradOutput = gradOutput.view( + {gradOutput.size(0), group, gradOutput.size(1) / group, + gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)}); + + for (int g = 0; g < group; g++) + { + columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1), + gradOutput[elt][g].flatten(1), 0.0f, 1.0f); + } + + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + gradOutput = gradOutput.view( + {gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2), + gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)}); + + deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane, + inputHeight, inputWidth, kH, kW, padH, padW, dH, dW, + dilationH, dilationW, im2col_step, deformable_group, + gradOffset[elt]); + + deformable_col2im(columns, offset[elt], nInputPlane, inputHeight, + inputWidth, kH, kW, padH, padW, dH, dW, dilationH, + dilationW, im2col_step, deformable_group, gradInput[elt]); + } + + gradOutput.transpose_(1, 2); + gradOutput = + gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth}); + + gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth}); + input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); + gradOffset = gradOffset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + offset = offset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + if (batch == 0) + { + gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth}); + input = input.view({nInputPlane, inputHeight, inputWidth}); + gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth}); + offset = offset.view({offset.size(1), offset.size(2), offset.size(3)}); + gradOffset = + gradOffset.view({offset.size(1), offset.size(2), offset.size(3)}); + } + + return 1; +} + +int deform_conv_backward_parameters_cuda( + at::Tensor input, at::Tensor offset, at::Tensor gradOutput, + at::Tensor gradWeight, // at::Tensor gradBias, + at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH, + int padW, int padH, int dilationW, int dilationH, int group, + int deformable_group, float scale, int im2col_step) +{ + // todo: transpose and reshape outGrad + // todo: reshape columns + // todo: add im2col_step as input + + shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH, + padW, dilationH, dilationW, group, deformable_group); + + input = input.contiguous(); + offset = offset.contiguous(); + gradOutput = gradOutput.contiguous(); + + int batch = 1; + + if (input.ndimension() == 3) + { + // Force batch + batch = 0; + input = input.contiguous().view( + at::IntList({1, input.size(0), input.size(1), input.size(2)})); + gradOutput = gradOutput.contiguous().view( + {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)}); + } + + long batchSize = input.size(0); + long nInputPlane = input.size(1); + long inputHeight = input.size(2); + long inputWidth = input.size(3); + + long nOutputPlane = gradWeight.size(0); + + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / 
dH + 1; + + TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset"); + + columns = at::zeros( + {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, + input.options()); + + gradOutput = gradOutput.contiguous().view({batchSize / im2col_step, im2col_step, + nOutputPlane, outputHeight, outputWidth}); + gradOutput.transpose_(1, 2); + + at::Tensor gradOutputBuffer = at::zeros_like(gradOutput); + gradOutputBuffer = + gradOutputBuffer.contiguous().view({batchSize / im2col_step, nOutputPlane, im2col_step, + outputHeight, outputWidth}); + gradOutputBuffer.copy_(gradOutput); + gradOutputBuffer = + gradOutputBuffer.contiguous().view({batchSize / im2col_step, nOutputPlane, + im2col_step * outputHeight, outputWidth}); + + gradOutput.transpose_(1, 2); + gradOutput = + gradOutput.contiguous().view({batchSize, nOutputPlane, outputHeight, outputWidth}); + + input = input.contiguous().view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + offset = + offset.contiguous().view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + for (int elt = 0; elt < batchSize / im2col_step; elt++) + { + deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight, + inputWidth, kH, kW, padH, padW, dH, dW, dilationH, + dilationW, im2col_step, deformable_group, columns); + + // divide into group + gradOutputBuffer = gradOutputBuffer.contiguous().view( + {gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group, + gradOutputBuffer.size(2), gradOutputBuffer.size(3)}); + columns = columns.contiguous().view({group, columns.size(0) / group, columns.size(1)}); + gradWeight = + gradWeight.contiguous().view({group, gradWeight.size(0) / group, gradWeight.size(1), + gradWeight.size(2), gradWeight.size(3)}); + + for (int g = 0; g < group; g++) + { + gradWeight[g] = gradWeight[g] + .flatten(1) + .addmm_(gradOutputBuffer[elt][g].flatten(1), + columns[g].transpose(1, 0), 1.0, scale) + .view_as(gradWeight[g]); + } + gradOutputBuffer = gradOutputBuffer.contiguous().view( + {gradOutputBuffer.size(0), + gradOutputBuffer.size(1) * gradOutputBuffer.size(2), + gradOutputBuffer.size(3), gradOutputBuffer.size(4)}); + columns = + columns.contiguous().view({columns.size(0) * columns.size(1), columns.size(2)}); + gradWeight = gradWeight.contiguous().view({gradWeight.size(0) * gradWeight.size(1), + gradWeight.size(2), gradWeight.size(3), + gradWeight.size(4)}); + } + + input = input.contiguous().view({batchSize, nInputPlane, inputHeight, inputWidth}); + offset = offset.contiguous().view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + if (batch == 0) + { + gradOutput = gradOutput.contiguous().view({nOutputPlane, outputHeight, outputWidth}); + input = input.contiguous().view({nInputPlane, inputHeight, inputWidth}); + } + + return 1; +} + +void modulated_deform_conv_cuda_forward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns, + int kernel_h, int kernel_w, const int stride_h, const int stride_w, + const int pad_h, const int pad_w, const int dilation_h, + const int dilation_w, const int group, const int deformable_group, + const bool with_bias) +{ + TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); + TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); + + const int batch = input.size(0); + const int channels = input.size(1); + const int height = 
input.size(2); + const int width = input.size(3); + + const int channels_out = weight.size(0); + const int channels_kernel = weight.size(1); + const int kernel_h_ = weight.size(2); + const int kernel_w_ = weight.size(3); + + if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) + AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).", + kernel_h_, kernel_w, kernel_h_, kernel_w_); + if (channels != channels_kernel * group) + AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).", + channels, channels_kernel * group); + + const int height_out = + (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + const int width_out = + (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + + if (ones.ndimension() != 2 || + ones.size(0) * ones.size(1) < height_out * width_out) + { + // Resize plane and fill with ones... + ones = at::ones({height_out, width_out}, input.options()); + } + + // resize output + output = output.view({batch, channels_out, height_out, width_out}).zero_(); + // resize temporary columns + columns = + at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out}, + input.options()); + + output = output.view({output.size(0), group, output.size(1) / group, + output.size(2), output.size(3)}); + + for (int b = 0; b < batch; b++) + { + modulated_deformable_im2col_cuda( + input[b], offset[b], mask[b], 1, channels, height, width, height_out, + width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, columns); + + // divide into group + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + + for (int g = 0; g < group; g++) + { + output[b][g] = output[b][g] + .flatten(1) + .addmm_(weight[g].flatten(1), columns[g]) + .view_as(output[b][g]); + } + + weight = weight.view({weight.size(0) * weight.size(1), weight.size(2), + weight.size(3), weight.size(4)}); + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + } + + output = output.view({output.size(0), output.size(1) * output.size(2), + output.size(3), output.size(4)}); + + if (with_bias) + { + output += bias.view({1, bias.size(0), 1, 1}); + } +} + +void modulated_deform_conv_cuda_backward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor columns, + at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias, + at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output, + int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, + int pad_w, int dilation_h, int dilation_w, int group, int deformable_group, + const bool with_bias) +{ + TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); + TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); + + const int batch = input.size(0); + const int channels = input.size(1); + const int height = input.size(2); + const int width = input.size(3); + + const int channels_kernel = weight.size(1); + const int kernel_h_ = weight.size(2); + const int kernel_w_ = weight.size(3); + if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) + AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).", + kernel_h_, kernel_w, kernel_h_, kernel_w_); + if (channels != channels_kernel * group) + AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).", + channels, channels_kernel * 
group); + + const int height_out = + (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + const int width_out = + (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + + if (ones.ndimension() != 2 || + ones.size(0) * ones.size(1) < height_out * width_out) + { + // Resize plane and fill with ones... + ones = at::ones({height_out, width_out}, input.options()); + } + + grad_input = grad_input.view({batch, channels, height, width}); + columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out}, + input.options()); + + grad_output = + grad_output.view({grad_output.size(0), group, grad_output.size(1) / group, + grad_output.size(2), grad_output.size(3)}); + + for (int b = 0; b < batch; b++) + { + // divide int group + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + + for (int g = 0; g < group; g++) + { + columns[g].addmm_(weight[g].flatten(1).transpose(0, 1), + grad_output[b][g].flatten(1), 0.0f, 1.0f); + } + + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + weight = weight.view({weight.size(0) * weight.size(1), weight.size(2), + weight.size(3), weight.size(4)}); + + // gradient w.r.t. input coordinate data + modulated_deformable_col2im_coord_cuda( + columns, input[b], offset[b], mask[b], 1, channels, height, width, + height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, + stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b], + grad_mask[b]); + // gradient w.r.t. input data + modulated_deformable_col2im_cuda( + columns, offset[b], mask[b], 1, channels, height, width, height_out, + width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, grad_input[b]); + + // gradient w.r.t. 
weight, dWeight should accumulate across the batch and + // group + modulated_deformable_im2col_cuda( + input[b], offset[b], mask[b], 1, channels, height, width, height_out, + width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, columns); + + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + grad_weight = grad_weight.view({group, grad_weight.size(0) / group, + grad_weight.size(1), grad_weight.size(2), + grad_weight.size(3)}); + if (with_bias) + grad_bias = grad_bias.view({group, grad_bias.size(0) / group}); + + for (int g = 0; g < group; g++) + { + grad_weight[g] = + grad_weight[g] + .flatten(1) + .addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1)) + .view_as(grad_weight[g]); + if (with_bias) + { + grad_bias[g] = + grad_bias[g] + .view({-1, 1}) + .addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1})) + .view(-1); + } + } + + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1), + grad_weight.size(2), grad_weight.size(3), + grad_weight.size(4)}); + if (with_bias) + grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)}); + } + grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1), + grad_output.size(2), grad_output.size(3), + grad_output.size(4)}); +} + +void unfold_shape_check(at::Tensor input, at::Tensor offset, + int kH, int kW, + int dH, int dW, int padH, int padW, int dilationH, + int dilationW, int deformable_group) +{ + + TORCH_CHECK(kW > 0 && kH > 0, + "kernel size should be greater than zero, but got kH: %d kW: %d", + kH, kW); + + TORCH_CHECK(dW > 0 && dH > 0, + "stride should be greater than zero, but got dH: %d dW: %d", dH, dW); + + TORCH_CHECK(dilationW > 0 && dilationH > 0, + "dilation should be greater than 0, but got dilationH: %d dilationW: %d", + dilationH, dilationW); + + int ndim = input.ndimension(); + int dimf = 0; + int dimh = 1; + int dimw = 2; + + if (ndim == 4) + { + dimf++; + dimh++; + dimw++; + } + + TORCH_CHECK(ndim == 3 || ndim == 4, + "3D or 4D input tensor expected but got: %s", ndim); + + long nInputPlane = input.size(dimf); + long inputHeight = input.size(dimh); + long inputWidth = input.size(dimw); + long nOutputPlane = nInputPlane; + long outputHeight = (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + long outputWidth = (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + + TORCH_CHECK(nInputPlane % deformable_group == 0, + "input channels must divide deformable group size"); + + if (outputWidth < 1 || outputHeight < 1) + AT_ERROR( + "Given input size: (%ld x %ld x %ld). " + "Calculated output size: (%ld x %ld x %ld). 
Output size is too small", + nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight, + outputWidth); + + TORCH_CHECK(input.size(1) == nInputPlane, + "invalid number of input planes, expected: %d, but got: %d", + nInputPlane, input.size(1)); + + TORCH_CHECK((inputHeight >= kH && inputWidth >= kW), + "input image is smaller than kernel"); + + TORCH_CHECK( + (offset.size(2) == outputHeight && offset.size(3) == outputWidth), + "invalid spatial size of offset, expected height: %d width: %d, but got height: %d width: %d", + outputHeight, outputWidth, offset.size(2), offset.size(3)); + + TORCH_CHECK((offset.size(1) == deformable_group * 2 * kH * kW), + "invalid number of channels of offset"); +} + +int deform_unfold_forward_cuda(at::Tensor input, + at::Tensor offset, at::Tensor output, + at::Tensor columns, int kW, + int kH, int dW, int dH, int padW, int padH, + int dilationW, int dilationH, + int deformable_group, int im2col_step) +{ + + // todo: resize columns to include im2col: done + // todo: add im2col_step as input + // todo: add new output buffer and transpose it to output (or directly transpose output) + // todo: possibly change data indexing because of parallel_imgs + unfold_shape_check(input, offset, kH, kW, dH, dW, padH, padW, dilationH, + dilationW, deformable_group); + + input = input.contiguous(); + offset = offset.contiguous(); + + int batch = 1; + if (input.ndimension() == 3) + { + // Force batch + batch = 0; + input.unsqueeze_(0); + offset.unsqueeze_(0); + } + + // todo: assert batchsize dividable by im2col_step + + long batchSize = input.size(0); + long nInputPlane = input.size(1); + long inputHeight = input.size(2); + long inputWidth = input.size(3); + + long nOutputPlane = nInputPlane; + + long outputWidth = (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight = (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset"); + + output = output.view({batchSize / im2col_step, im2col_step, nInputPlane * kW * kH, outputHeight * outputWidth}); + columns = at::zeros({nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, input.type()); + + input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, inputHeight, inputWidth}); + offset = offset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + at::Tensor output_buffer = at::zeros({batchSize / im2col_step, nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, output.type()); + + for (int elt = 0; elt < batchSize / im2col_step; elt++) + { + deformable_im2col( + input[elt], offset[elt], nInputPlane, inputHeight, + inputWidth, kH, kW, padH, padW, dH, dW, dilationH, dilationW, + im2col_step, deformable_group, columns); + output_buffer[elt] = columns; + } + + input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); + offset = offset.view({batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + output_buffer = output_buffer.view( + {batchSize / im2col_step, nInputPlane * kW * kH, im2col_step, outputHeight * outputWidth}); + output_buffer.transpose_(1, 2); + output.copy_(output_buffer); + output = output.view({batchSize, nInputPlane * kW * kH, outputHeight * outputWidth}); + + if (batch == 0) + { + output = output.view({nOutputPlane * kH * kW, outputHeight, outputWidth}); + input = input.view({nInputPlane, inputHeight, inputWidth}); + offset = offset.view({offset.size(1), offset.size(2), 
offset.size(3)}); + } + + return 1; +} + +int deform_unfold_backward_input_cuda( + at::Tensor input, at::Tensor offset, at::Tensor gradOutput, + at::Tensor gradInput, at::Tensor gradOffset, + at::Tensor columns, int kW, int kH, int dW, int dH, int padW, int padH, + int dilationW, int dilationH, int deformable_group, int im2col_step) +{ + unfold_shape_check(input, offset, kH, kW, dH, dW, padH, padW, dilationH, + dilationW, deformable_group); + + input = input.contiguous(); + offset = offset.contiguous(); + gradOutput = gradOutput.contiguous(); + + int batch = 1; + + if (input.ndimension() == 3) + { + // Force batch + batch = 0; + input = input.view({1, input.size(0), input.size(1), input.size(2)}); + offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)}); + gradOutput = gradOutput.view({1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)}); + } + + long batchSize = input.size(0); + long nInputPlane = input.size(1); + long inputHeight = input.size(2); + long inputWidth = input.size(3); + + long nOutputPlane = nInputPlane; + + long outputWidth = (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight = (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + TORCH_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset"); + gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth}); + columns = at::zeros({nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, input.type()); + + // change order of grad output + gradOutput = gradOutput.view( + {batchSize / im2col_step, im2col_step, nInputPlane * kW * kH, outputHeight * outputWidth}); + gradOutput.transpose_(1, 2); + gradOutput = gradOutput.contiguous(); + gradOutput = gradOutput.view({batchSize / im2col_step, nInputPlane * kH * kW, im2col_step * outputHeight * outputWidth}); + + gradInput = gradInput.view( + {batchSize / im2col_step, im2col_step, nInputPlane, inputHeight, inputWidth}); + input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, inputHeight, inputWidth}); + gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + offset = offset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + for (int elt = 0; elt < batchSize / im2col_step; elt++) + { + columns.copy_(gradOutput[elt]); + + deformable_col2im_coord( + columns, input[elt], offset[elt], + nInputPlane, inputHeight, inputWidth, kH, kW, padH, padW, dH, dW, + dilationH, dilationW, im2col_step, deformable_group, gradOffset[elt]); + + deformable_col2im( + columns, offset[elt], nInputPlane, inputHeight, + inputWidth, kH, kW, padH, padW, dH, dW, dilationH, dilationW, im2col_step, + deformable_group, gradInput[elt]); + } + + gradOutput.transpose_(1, 2); + gradOutput = gradOutput.contiguous(); + gradOutput = gradOutput.view({batchSize, nInputPlane * kH * kW, outputHeight * outputWidth}); + + gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth}); + input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); + gradOffset = gradOffset.view({batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + offset = offset.view({batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + if (batch == 0) + { + gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth}); + input = input.view({nInputPlane, inputHeight, inputWidth}); + gradInput = 
gradInput.view({nInputPlane, inputHeight, inputWidth}); + offset = offset.view({offset.size(1), offset.size(2), offset.size(3)}); + gradOffset = gradOffset.view({offset.size(1), offset.size(2), offset.size(3)}); + } + + return 1; +} diff --git a/dconv/csrc/cuda/deform_conv_kernel_cuda.cu b/dconv/csrc/cuda/deform_conv_kernel_cuda.cu new file mode 100644 index 00000000..870fab11 --- /dev/null +++ b/dconv/csrc/cuda/deform_conv_kernel_cuda.cu @@ -0,0 +1,873 @@ +/*! + ******************* BEGIN Caffe Copyright Notice and Disclaimer **************** + * + * COPYRIGHT + * + * All contributions by the University of California: + * Copyright (c) 2014-2017 The Regents of the University of California (Regents) + * All rights reserved. + * + * All other contributions: + * Copyright (c) 2014-2017, the respective contributors + * All rights reserved. + * + * Caffe uses a shared copyright model: each contributor holds copyright over + * their contributions to Caffe. The project versioning records all such + * contribution and copyright details. If a contributor wants to further mark + * their specific copyright on a particular contribution, they should indicate + * their copyright solely in the commit message of the change when it is + * committed. + * + * LICENSE + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR + * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * CONTRIBUTION AGREEMENT + * + * By contributing to the BVLC/caffe repository through pull-request, comment, + * or otherwise, the contributor releases their content to the + * license and copyright terms herein. + * + ***************** END Caffe Copyright Notice and Disclaimer ******************** + * + * Copyright (c) 2018 Microsoft + * Licensed under The MIT License [see LICENSE for details] + * \file modulated_deformable_im2col.cuh + * \brief Function definitions of converting an image to + * column matrix based on kernel, padding, dilation, and offset. + * These functions are mainly used in deformable convolution operators. 
+ * \ref: https://arxiv.org/abs/1703.06211 + * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng + */ + +// modify from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu + +#include +#include +#include +#include +#include + +using namespace at; + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + +const int CUDA_NUM_THREADS = 1024; +const int kMaxGridNum = 65535; +inline int GET_BLOCKS(const int N) +{ + return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS); +} + +/* +const int CUDA_NUM_THREADS = 1024; + +inline int GET_BLOCKS(const int N) +{ + return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; +}*/ + +template +__device__ scalar_t deformable_im2col_bilinear(const scalar_t *bottom_data, const int data_width, + const int height, const int width, scalar_t h, scalar_t w) +{ + + int h_low = floor(h); + int w_low = floor(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + scalar_t lh = h - h_low; + scalar_t lw = w - w_low; + scalar_t hh = 1 - lh, hw = 1 - lw; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + v1 = bottom_data[h_low * data_width + w_low]; + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + v2 = bottom_data[h_low * data_width + w_high]; + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + v3 = bottom_data[h_high * data_width + w_low]; + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + v4 = bottom_data[h_high * data_width + w_high]; + + scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +template +__device__ scalar_t get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w, + const int h, const int w, const int height, const int width) +{ + + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + // empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + if (h == argmax_h_low && w == argmax_w_low) + weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); + if (h == argmax_h_low && w == argmax_w_high) + weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); + if (h == argmax_h_high && w == argmax_w_low) + weight = (argmax_h + 1 - h) * (w + 1 - argmax_w); + if (h == argmax_h_high && w == argmax_w_high) + weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); + return weight; +} + +template +__device__ scalar_t get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w, + const int height, const int width, const scalar_t *im_data, + const int data_width, const int bp_dir) +{ + + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + // empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + + if (bp_dir == 0) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += (argmax_w_low + 1 - 
argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + else if (bp_dir == 1) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + + return weight; +} + +template +__global__ void deformable_im2col_gpu_kernel(const int n, const scalar_t *data_im, const scalar_t *data_offset, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int channel_per_deformable_group, + const int batch_size, const int num_channels, const int deformable_group, + const int height_col, const int width_col, + scalar_t *data_col) +{ + CUDA_KERNEL_LOOP(index, n) + { + // index index of output matrix + const int w_col = index % width_col; + const int h_col = (index / width_col) % height_col; + const int b_col = (index / width_col / height_col) % batch_size; + const int c_im = (index / width_col / height_col) / batch_size; + const int c_col = c_im * kernel_h * kernel_w; + + // compute deformable group index + const int deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col * stride_w - pad_w; + scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; + // const scalar_t* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in; + const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width; + const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + + for (int i = 0; i < kernel_h; ++i) + { + for (int j = 0; j < kernel_w; ++j) + { + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + scalar_t val = static_cast(0); + const scalar_t h_im = h_in + i * dilation_h + offset_h; + const scalar_t w_im = w_in + j * dilation_w + offset_w; + if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) + { + // const scalar_t map_h = i * dilation_h + offset_h; + // const scalar_t map_w = j * dilation_w + offset_w; + // const int cur_height = height - h_in; + // const int cur_width = width - w_in; + // val = deformable_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w); + val = deformable_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im); + } + *data_col_ptr = val; + data_col_ptr += batch_size * height_col * width_col; + } + } + } +} + +void deformable_im2col( + const 
at::Tensor data_im, const at::Tensor data_offset, const int channels, + const int height, const int width, const int ksize_h, const int ksize_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int parallel_imgs, + const int deformable_group, at::Tensor data_col) +{ + // num_axes should be smaller than block size + // todo: check parallel_imgs is correctly passed in + int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = channels * height_col * width_col * parallel_imgs; + int channel_per_deformable_group = channels / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_im.type(), "deformable_im2col_gpu", ([&] + { + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + scalar_t *data_col_ = data_col.data_ptr(); + + deformable_im2col_gpu_kernel<<>>( + num_kernels, data_im_, data_offset_, height, width, ksize_h, ksize_w, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + channel_per_deformable_group, parallel_imgs, channels, deformable_group, + height_col, width_col, data_col_); })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in deformable_im2col: %s\n", cudaGetErrorString(err)); + } +} + +template +__global__ void deformable_col2im_gpu_kernel( + const int n, const scalar_t *data_col, const scalar_t *data_offset, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int deformable_group, + const int height_col, const int width_col, + scalar_t *grad_im) +{ + CUDA_KERNEL_LOOP(index, n) + { + const int j = (index / width_col / height_col / batch_size) % kernel_w; + const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h; + const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h; + // compute the start and end of the output + + const int deformable_group_index = c / channel_per_deformable_group; + + int w_out = index % width_col; + int h_out = (index / width_col) % height_col; + int b = (index / width_col / height_col) % batch_size; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * + 2 * kernel_h * kernel_w * height_col * width_col; + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h; + const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w; + + const scalar_t cur_top_grad = data_col[index]; + const int cur_h = (int)cur_inv_h_data; + const int cur_w = (int)cur_inv_w_data; + for (int dy = -2; dy <= 2; dy++) + { + for (int dx = -2; dx <= 2; dx++) + { + if (cur_h + dy >= 0 && cur_h + dy < height && + cur_w + dx >= 0 && cur_w + dx < width && + abs(cur_inv_h_data - (cur_h + dy)) < 1 && + 
abs(cur_inv_w_data - (cur_w + dx)) < 1) + { + int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; + scalar_t weight = get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width); + atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); + } + } + } + } +} + +void deformable_col2im( + const at::Tensor data_col, const at::Tensor data_offset, const int channels, + const int height, const int width, const int ksize_h, + const int ksize_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + at::Tensor grad_im) +{ + + // todo: make sure parallel_imgs is passed in correctly + int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = channels * ksize_h * ksize_w * height_col * width_col * parallel_imgs; + int channel_per_deformable_group = channels / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.type(), "deformable_col2im_gpu", ([&] + { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + scalar_t *grad_im_ = grad_im.data_ptr(); + + deformable_col2im_gpu_kernel<<>>( + num_kernels, data_col_, data_offset_, channels, height, width, ksize_h, + ksize_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + parallel_imgs, deformable_group, height_col, width_col, grad_im_); })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in deformable_col2im: %s\n", cudaGetErrorString(err)); + } +} + +template +__global__ void deformable_col2im_coord_gpu_kernel(const int n, const scalar_t *data_col, + const scalar_t *data_im, const scalar_t *data_offset, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int offset_channels, const int deformable_group, + const int height_col, const int width_col, scalar_t *grad_offset) +{ + CUDA_KERNEL_LOOP(index, n) + { + scalar_t val = 0; + int w = index % width_col; + int h = (index / width_col) % height_col; + int c = (index / width_col / height_col) % offset_channels; + int b = (index / width_col / height_col) / offset_channels; + // compute the start and end of the output + + const int deformable_group_index = c / (2 * kernel_h * kernel_w); + const int col_step = kernel_h * kernel_w; + int cnt = 0; + const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * + batch_size * width_col * height_col; + const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * + channel_per_deformable_group / kernel_h / kernel_w * height * width; + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * + kernel_h * kernel_w * height_col * width_col; + + const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + + for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step) + { + const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w; + const int bp_dir = offset_c % 2; + + 
int j = (col_pos / width_col / height_col / batch_size) % kernel_w; + int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; + int w_out = col_pos % width_col; + int h_out = (col_pos / width_col) % height_col; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); + const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out); + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + scalar_t inv_h = h_in + i * dilation_h + offset_h; + scalar_t inv_w = w_in + j * dilation_w + offset_w; + if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) + { + inv_h = inv_w = -2; + } + const scalar_t weight = get_coordinate_weight( + inv_h, inv_w, + height, width, data_im_ptr + cnt * height * width, width, bp_dir); + val += weight * data_col_ptr[col_pos]; + cnt += 1; + } + + grad_offset[index] = val; + } +} + +void deformable_col2im_coord( + const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, + const int channels, const int height, const int width, const int ksize_h, + const int ksize_w, const int pad_h, const int pad_w, const int stride_h, + const int stride_w, const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, at::Tensor grad_offset) +{ + + int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w * deformable_group * parallel_imgs; + int channel_per_deformable_group = channels * ksize_h * ksize_w / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.type(), "deformable_col2im_coord_gpu", ([&] + { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + scalar_t *grad_offset_ = grad_offset.data_ptr(); + + deformable_col2im_coord_gpu_kernel<<>>( + num_kernels, data_col_, data_im_, data_offset_, channels, height, width, + ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + parallel_imgs, 2 * ksize_h * ksize_w * deformable_group, deformable_group, + height_col, width_col, grad_offset_); })); +} + +template +__device__ scalar_t dmcn_im2col_bilinear(const scalar_t *bottom_data, const int data_width, + const int height, const int width, scalar_t h, scalar_t w) +{ + int h_low = floor(h); + int w_low = floor(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + scalar_t lh = h - h_low; + scalar_t lw = w - w_low; + scalar_t hh = 1 - lh, hw = 1 - lw; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + v1 = bottom_data[h_low * data_width + w_low]; + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + v2 = bottom_data[h_low * data_width + w_high]; + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + v3 = bottom_data[h_high * data_width + w_low]; + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + v4 = bottom_data[h_high * data_width + w_high]; + + scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +template +__device__ scalar_t dmcn_get_gradient_weight(scalar_t 
argmax_h, scalar_t argmax_w, + const int h, const int w, const int height, const int width) +{ + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + // empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + if (h == argmax_h_low && w == argmax_w_low) + weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); + if (h == argmax_h_low && w == argmax_w_high) + weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); + if (h == argmax_h_high && w == argmax_w_low) + weight = (argmax_h + 1 - h) * (w + 1 - argmax_w); + if (h == argmax_h_high && w == argmax_w_high) + weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); + return weight; +} + +template +__device__ scalar_t dmcn_get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w, + const int height, const int width, const scalar_t *im_data, + const int data_width, const int bp_dir) +{ + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + // empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + + if (bp_dir == 0) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + else if (bp_dir == 1) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + + return weight; +} + +template +__global__ void modulated_deformable_im2col_gpu_kernel(const int n, + const scalar_t *data_im, const scalar_t *data_offset, const scalar_t *data_mask, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int num_channels, const int deformable_group, + const int height_col, const int width_col, + scalar_t *data_col) +{ + CUDA_KERNEL_LOOP(index, n) + { + // index index of output matrix + const int w_col = index % width_col; + const int h_col = (index / width_col) % height_col; + const int b_col = (index / width_col / height_col) % batch_size; + const int c_im = (index / width_col / height_col) / batch_size; + const int c_col = c_im * kernel_h * kernel_w; + + // compute deformable group index + const int 
deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col * stride_w - pad_w; + + scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; + // const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in; + const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width; + const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + + const scalar_t *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + + for (int i = 0; i < kernel_h; ++i) + { + for (int j = 0; j < kernel_w; ++j) + { + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; + const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; + scalar_t val = static_cast(0); + const scalar_t h_im = h_in + i * dilation_h + offset_h; + const scalar_t w_im = w_in + j * dilation_w + offset_w; + // if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) { + if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) + { + // const float map_h = i * dilation_h + offset_h; + // const float map_w = j * dilation_w + offset_w; + // const int cur_height = height - h_in; + // const int cur_width = width - w_in; + // val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w); + val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im); + } + *data_col_ptr = val * mask; + data_col_ptr += batch_size * height_col * width_col; + // data_col_ptr += height_col * width_col; + } + } + } +} + +template +__global__ void modulated_deformable_col2im_gpu_kernel(const int n, + const scalar_t *data_col, const scalar_t *data_offset, const scalar_t *data_mask, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int deformable_group, + const int height_col, const int width_col, + scalar_t *grad_im) +{ + CUDA_KERNEL_LOOP(index, n) + { + const int j = (index / width_col / height_col / batch_size) % kernel_w; + const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h; + const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h; + // compute the start and end of the output + + const int deformable_group_index = c / channel_per_deformable_group; + + int w_out = index % width_col; + int h_out = (index / width_col) % height_col; + int b = (index / width_col / height_col) % batch_size; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * 
height_col * width_col; + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; + const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; + const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h; + const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w; + + const scalar_t cur_top_grad = data_col[index] * mask; + const int cur_h = (int)cur_inv_h_data; + const int cur_w = (int)cur_inv_w_data; + for (int dy = -2; dy <= 2; dy++) + { + for (int dx = -2; dx <= 2; dx++) + { + if (cur_h + dy >= 0 && cur_h + dy < height && + cur_w + dx >= 0 && cur_w + dx < width && + abs(cur_inv_h_data - (cur_h + dy)) < 1 && + abs(cur_inv_w_data - (cur_w + dx)) < 1) + { + int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; + scalar_t weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width); + atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); + } + } + } + } +} + +template +__global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n, + const scalar_t *data_col, const scalar_t *data_im, + const scalar_t *data_offset, const scalar_t *data_mask, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int offset_channels, const int deformable_group, + const int height_col, const int width_col, + scalar_t *grad_offset, scalar_t *grad_mask) +{ + CUDA_KERNEL_LOOP(index, n) + { + scalar_t val = 0, mval = 0; + int w = index % width_col; + int h = (index / width_col) % height_col; + int c = (index / width_col / height_col) % offset_channels; + int b = (index / width_col / height_col) / offset_channels; + // compute the start and end of the output + + const int deformable_group_index = c / (2 * kernel_h * kernel_w); + const int col_step = kernel_h * kernel_w; + int cnt = 0; + const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col; + const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width; + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + + const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + + for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step) + { + const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w; + const int bp_dir = offset_c % 2; + + int j = (col_pos / width_col / height_col / batch_size) % kernel_w; + int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; + int w_out = col_pos % width_col; + int h_out = (col_pos / width_col) % height_col; + int w_in = w_out 
* stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); + const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out); + const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out); + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; + scalar_t inv_h = h_in + i * dilation_h + offset_h; + scalar_t inv_w = w_in + j * dilation_w + offset_w; + if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) + { + inv_h = inv_w = -2; + } + else + { + mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w); + } + const scalar_t weight = dmcn_get_coordinate_weight( + inv_h, inv_w, + height, width, data_im_ptr + cnt * height * width, width, bp_dir); + val += weight * data_col_ptr[col_pos] * mask; + cnt += 1; + } + // KERNEL_ASSIGN(grad_offset[index], offset_req, val); + grad_offset[index] = val; + if (offset_c % 2 == 0) + // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval); + grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval; + } +} + +void modulated_deformable_im2col_cuda( + const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, at::Tensor data_col) +{ + // num_axes should be smaller than block size + const int channel_per_deformable_group = channels / deformable_group; + const int num_kernels = channels * batch_size * height_col * width_col; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_im.type(), "modulated_deformable_im2col_gpu", ([&] + { + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + const scalar_t *data_mask_ = data_mask.data_ptr(); + scalar_t *data_col_ = data_col.data_ptr(); + + modulated_deformable_im2col_gpu_kernel<<>>( + num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im, kernel_h, kenerl_w, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group, + batch_size, channels, deformable_group, height_col, width_col, data_col_); })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err)); + } +} + +void modulated_deformable_col2im_cuda( + const at::Tensor data_col, const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, at::Tensor grad_im) +{ + + const int channel_per_deformable_group = channels / deformable_group; + const int 
num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.type(), "modulated_deformable_col2im_gpu", ([&] + { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + const scalar_t *data_mask_ = data_mask.data_ptr(); + scalar_t *grad_im_ = grad_im.data_ptr(); + + modulated_deformable_col2im_gpu_kernel<<>>( + num_kernels, data_col_, data_offset_, data_mask_, channels, height_im, width_im, + kernel_h, kernel_w, pad_h, pad_h, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + batch_size, deformable_group, height_col, width_col, grad_im_); })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err)); + } +} + +void modulated_deformable_col2im_coord_cuda( + const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, + at::Tensor grad_offset, at::Tensor grad_mask) +{ + const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group; + const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.type(), "modulated_deformable_col2im_coord_gpu", ([&] + { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + const scalar_t *data_mask_ = data_mask.data_ptr(); + scalar_t *grad_offset_ = grad_offset.data_ptr(); + scalar_t *grad_mask_ = grad_mask.data_ptr(); + + modulated_deformable_col2im_coord_gpu_kernel<<>>( + num_kernels, data_col_, data_im_, data_offset_, data_mask_, channels, height_im, width_im, + kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col, + grad_offset_, grad_mask_); })); + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err)); + } +} diff --git a/dconv/csrc/cuda/vision.h b/dconv/csrc/cuda/vision.h new file mode 100644 index 00000000..6e6509da --- /dev/null +++ b/dconv/csrc/cuda/vision.h @@ -0,0 +1,44 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
+#pragma once +#include <torch/extension.h> + + + +int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight, + at::Tensor offset, at::Tensor output, + at::Tensor columns, at::Tensor ones, int kW, + int kH, int dW, int dH, int padW, int padH, + int dilationW, int dilationH, int group, + int deformable_group, int im2col_step); + +int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset, + at::Tensor gradOutput, at::Tensor gradInput, + at::Tensor gradOffset, at::Tensor weight, + at::Tensor columns, int kW, int kH, int dW, + int dH, int padW, int padH, int dilationW, + int dilationH, int group, + int deformable_group, int im2col_step); + +int deform_conv_backward_parameters_cuda( + at::Tensor input, at::Tensor offset, at::Tensor gradOutput, + at::Tensor gradWeight, // at::Tensor gradBias, + at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH, + int padW, int padH, int dilationW, int dilationH, int group, + int deformable_group, float scale, int im2col_step); + +void modulated_deform_conv_cuda_forward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns, + int kernel_h, int kernel_w, const int stride_h, const int stride_w, + const int pad_h, const int pad_w, const int dilation_h, + const int dilation_w, const int group, const int deformable_group, + const bool with_bias); + +void modulated_deform_conv_cuda_backward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor columns, + at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias, + at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output, + int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, + int pad_w, int dilation_h, int dilation_w, int group, int deformable_group, + const bool with_bias); diff --git a/dconv/csrc/deform_conv.h b/dconv/csrc/deform_conv.h new file mode 100644 index 00000000..aeec3c23 --- /dev/null +++ b/dconv/csrc/deform_conv.h @@ -0,0 +1,186 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+#pragma once +#include "cpu/vision.h" + +#ifdef WITH_CUDA +#include "cuda/vision.h" +#endif + +// Interface for Python +int deform_conv_forward( + at::Tensor input, + at::Tensor weight, + at::Tensor offset, + at::Tensor output, + at::Tensor columns, + at::Tensor ones, + int kW, + int kH, + int dW, + int dH, + int padW, + int padH, + int dilationW, + int dilationH, + int group, + int deformable_group, + int im2col_step) +{ + if (input.is_cuda()) + { +#ifdef WITH_CUDA + return deform_conv_forward_cuda( + input, weight, offset, output, columns, ones, + kW, kH, dW, dH, padW, padH, dilationW, dilationH, + group, deformable_group, im2col_step); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + +int deform_conv_backward_input( + at::Tensor input, + at::Tensor offset, + at::Tensor gradOutput, + at::Tensor gradInput, + at::Tensor gradOffset, + at::Tensor weight, + at::Tensor columns, + int kW, + int kH, + int dW, + int dH, + int padW, + int padH, + int dilationW, + int dilationH, + int group, + int deformable_group, + int im2col_step) +{ + if (input.is_cuda()) + { +#ifdef WITH_CUDA + return deform_conv_backward_input_cuda( + input, offset, gradOutput, gradInput, gradOffset, weight, columns, + kW, kH, dW, dH, padW, padH, dilationW, dilationH, + group, deformable_group, im2col_step); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + +int deform_conv_backward_parameters( + at::Tensor input, + at::Tensor offset, + at::Tensor gradOutput, + at::Tensor gradWeight, // at::Tensor gradBias, + at::Tensor columns, + at::Tensor ones, + int kW, + int kH, + int dW, + int dH, + int padW, + int padH, + int dilationW, + int dilationH, + int group, + int deformable_group, + float scale, + int im2col_step) +{ + if (input.is_cuda()) + { +#ifdef WITH_CUDA + return deform_conv_backward_parameters_cuda( + input, offset, gradOutput, gradWeight, columns, ones, + kW, kH, dW, dH, padW, padH, dilationW, dilationH, + group, deformable_group, scale, im2col_step); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + +void modulated_deform_conv_forward( + at::Tensor input, + at::Tensor weight, + at::Tensor bias, + at::Tensor ones, + at::Tensor offset, + at::Tensor mask, + at::Tensor output, + at::Tensor columns, + int kernel_h, + int kernel_w, + const int stride_h, + const int stride_w, + const int pad_h, + const int pad_w, + const int dilation_h, + const int dilation_w, + const int group, + const int deformable_group, + const bool with_bias) +{ + if (input.is_cuda()) + { +#ifdef WITH_CUDA + return modulated_deform_conv_cuda_forward( + input, weight, bias, ones, offset, mask, output, columns, + kernel_h, kernel_w, stride_h, stride_w, + pad_h, pad_w, dilation_h, dilation_w, + group, deformable_group, with_bias); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + +void modulated_deform_conv_backward( + at::Tensor input, + at::Tensor weight, + at::Tensor bias, + at::Tensor ones, + at::Tensor offset, + at::Tensor mask, + at::Tensor columns, + at::Tensor grad_input, + at::Tensor grad_weight, + at::Tensor grad_bias, + at::Tensor grad_offset, + at::Tensor grad_mask, + at::Tensor grad_output, + int kernel_h, + int kernel_w, + int stride_h, + int stride_w, + int pad_h, + int pad_w, + int dilation_h, + int dilation_w, + int group, + int deformable_group, + const bool with_bias) +{ + if (input.is_cuda()) 
+ { +#ifdef WITH_CUDA + return modulated_deform_conv_cuda_backward( + input, weight, bias, ones, offset, mask, columns, + grad_input, grad_weight, grad_bias, grad_offset, grad_mask, grad_output, + kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w, + group, deformable_group, with_bias); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} diff --git a/dconv/csrc/vision.cpp b/dconv/csrc/vision.cpp new file mode 100644 index 00000000..8615bc71 --- /dev/null +++ b/dconv/csrc/vision.cpp @@ -0,0 +1,11 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +#include "deform_conv.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + // dcn-v2 + m.def("deform_conv_forward", &deform_conv_forward, "deform_conv_forward"); + m.def("deform_conv_backward_input", &deform_conv_backward_input, "deform_conv_backward_input"); + m.def("deform_conv_backward_parameters", &deform_conv_backward_parameters, "deform_conv_backward_parameters"); + m.def("modulated_deform_conv_forward", &modulated_deform_conv_forward, "modulated_deform_conv_forward"); + m.def("modulated_deform_conv_backward", &modulated_deform_conv_backward, "modulated_deform_conv_backward"); +} diff --git a/dconv/layers/__init__.py b/dconv/layers/__init__.py new file mode 100644 index 00000000..9b640438 --- /dev/null +++ b/dconv/layers/__init__.py @@ -0,0 +1,18 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import torch + +from .dcn.deform_conv_func import deform_conv, modulated_deform_conv +from .dcn.deform_conv_module import DeformConv, ModulatedDeformConv, ModulatedDeformConvPack + +# from .dualgraph import GloReLocalModule + + +__all__ = [ + "deform_conv", + "modulated_deform_conv", + "DeformConv", + "ModulatedDeformConv", + "ModulatedDeformConvPack", + # 'GloReLocalModule', +] diff --git a/dconv/layers/_utils.py b/dconv/layers/_utils.py new file mode 100644 index 00000000..747864d4 --- /dev/null +++ b/dconv/layers/_utils.py @@ -0,0 +1,40 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
+import glob +import os.path + +import torch + +try: + from torch.utils.cpp_extension import load as load_ext + from torch.utils.cpp_extension import CUDA_HOME +except ImportError: + raise ImportError("The cpp layer extensions requires PyTorch 0.4 or higher") + + +def _load_C_extensions(): + this_dir = os.path.dirname(os.path.abspath(__file__)) + this_dir = os.path.dirname(this_dir) + this_dir = os.path.join(this_dir, "csrc") + + main_file = glob.glob(os.path.join(this_dir, "*.cpp")) + source_cpu = glob.glob(os.path.join(this_dir, "cpu", "*.cpp")) + source_cuda = glob.glob(os.path.join(this_dir, "cuda", "*.cu")) + + source = main_file + source_cpu + + extra_cflags = [] + if torch.cuda.is_available() and CUDA_HOME is not None: + source.extend(source_cuda) + extra_cflags = ["-DWITH_CUDA"] + source = [os.path.join(this_dir, s) for s in source] + extra_include_paths = [this_dir] + return load_ext( + "torchvision", + source, + extra_cflags=extra_cflags, + extra_include_paths=extra_include_paths, + ) + + +_C = _load_C_extensions() diff --git a/dconv/layers/dcn/__init__.py b/dconv/layers/dcn/__init__.py new file mode 100644 index 00000000..9a63fdfb --- /dev/null +++ b/dconv/layers/dcn/__init__.py @@ -0,0 +1,4 @@ +# -*- coding: utf-8 -*- +# +# Copied From [mmdetection](https://github.com/open-mmlab/mmdetection/tree/master/mmdet/ops/dcn) +# diff --git a/dconv/layers/dcn/deform_conv_func.py b/dconv/layers/dcn/deform_conv_func.py new file mode 100644 index 00000000..59f546ba --- /dev/null +++ b/dconv/layers/dcn/deform_conv_func.py @@ -0,0 +1,278 @@ +# -*- coding: utf-8 -*- +import torch +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.nn.modules.utils import _pair + +from dconv import _C + + +class DeformConvFunction(Function): + + @staticmethod + def forward( + ctx, + input, + offset, + weight, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1, + im2col_step=2000, + ): + if input is not None and input.dim() != 4: + raise ValueError( + "Expected 4D tensor as input, got {}D tensor instead.".format(input.dim()) + ) + ctx.stride = _pair(stride) + ctx.padding = _pair(padding) + ctx.dilation = _pair(dilation) + ctx.groups = groups + ctx.deformable_groups = deformable_groups + ctx.im2col_step = im2col_step + + ctx.save_for_backward(input, offset, weight) + + output = input.new_empty( + DeformConvFunction._output_size(input, weight, ctx.padding, ctx.dilation, ctx.stride) + ) + + ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones + + if not input.is_cuda: + raise NotImplementedError + else: + cur_im2col_step = min(ctx.im2col_step, input.shape[0]) + assert (input.shape[0] % cur_im2col_step) == 0, "im2col step must divide batchsize" + _C.deform_conv_forward( + input, + weight, + offset, + output, + ctx.bufs_[0], + ctx.bufs_[1], + weight.size(3), + weight.size(2), + ctx.stride[1], + ctx.stride[0], + ctx.padding[1], + ctx.padding[0], + ctx.dilation[1], + ctx.dilation[0], + ctx.groups, + ctx.deformable_groups, + cur_im2col_step, + ) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + input, offset, weight = ctx.saved_tensors + + grad_input = grad_offset = grad_weight = None + + if not grad_output.is_cuda: + raise NotImplementedError + else: + cur_im2col_step = min(ctx.im2col_step, input.shape[0]) + assert (input.shape[0] % cur_im2col_step) == 0, "im2col step must divide batchsize" + + if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: + grad_input = torch.zeros_like(input) 
+ grad_offset = torch.zeros_like(offset) + _C.deform_conv_backward_input( + input, + offset, + grad_output, + grad_input, + grad_offset, + weight, + ctx.bufs_[0], + weight.size(3), + weight.size(2), + ctx.stride[1], + ctx.stride[0], + ctx.padding[1], + ctx.padding[0], + ctx.dilation[1], + ctx.dilation[0], + ctx.groups, + ctx.deformable_groups, + cur_im2col_step, + ) + + if ctx.needs_input_grad[2]: + grad_weight = torch.zeros_like(weight) + _C.deform_conv_backward_parameters( + input, + offset, + grad_output, + grad_weight, + ctx.bufs_[0], + ctx.bufs_[1], + weight.size(3), + weight.size(2), + ctx.stride[1], + ctx.stride[0], + ctx.padding[1], + ctx.padding[0], + ctx.dilation[1], + ctx.dilation[0], + ctx.groups, + ctx.deformable_groups, + 1, + cur_im2col_step, + ) + + return (grad_input, grad_offset, grad_weight, None, None, None, None, None) + + @staticmethod + def _output_size(input, weight, padding, dilation, stride): + channels = weight.size(0) + output_size = (input.size(0), channels) + for d in range(input.dim() - 2): + in_size = input.size(d + 2) + pad = padding[d] + kernel = dilation[d] * (weight.size(d + 2) - 1) + 1 + stride_ = stride[d] + output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1,) + if not all(map(lambda s: s > 0, output_size)): + raise ValueError( + "convolution input is too small (output would be {})".format( + "x".join(map(str, output_size)) + ) + ) + return output_size + + +class ModulatedDeformConvFunction(Function): + + @staticmethod + def forward( + ctx, + input, + offset, + mask, + weight, + bias=None, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1, + ): + ctx.stride = stride + ctx.padding = padding + ctx.dilation = dilation + ctx.groups = groups + ctx.deformable_groups = deformable_groups + ctx.with_bias = bias is not None + if not ctx.with_bias: + bias = input.new_empty(1) # fake tensor + if not input.is_cuda: + raise NotImplementedError + if ( + weight.requires_grad + or mask.requires_grad + or offset.requires_grad + or input.requires_grad + ): + ctx.save_for_backward(input, offset, mask, weight, bias) + output = input.new_empty(ModulatedDeformConvFunction._infer_shape(ctx, input, weight)) + ctx._bufs = [input.new_empty(0), input.new_empty(0)] + _C.modulated_deform_conv_forward( + input, + weight, + bias, + ctx._bufs[0], + offset, + mask, + output, + ctx._bufs[1], + weight.shape[2], + weight.shape[3], + ctx.stride, + ctx.stride, + ctx.padding, + ctx.padding, + ctx.dilation, + ctx.dilation, + ctx.groups, + ctx.deformable_groups, + ctx.with_bias, + ) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + if not grad_output.is_cuda: + raise NotImplementedError + input, offset, mask, weight, bias = ctx.saved_tensors + grad_input = torch.zeros_like(input) + grad_offset = torch.zeros_like(offset) + grad_mask = torch.zeros_like(mask) + grad_weight = torch.zeros_like(weight) + grad_bias = torch.zeros_like(bias) + _C.modulated_deform_conv_backward( + input, + weight, + bias, + ctx._bufs[0], + offset, + mask, + ctx._bufs[1], + grad_input, + grad_weight, + grad_bias, + grad_offset, + grad_mask, + grad_output, + weight.shape[2], + weight.shape[3], + ctx.stride, + ctx.stride, + ctx.padding, + ctx.padding, + ctx.dilation, + ctx.dilation, + ctx.groups, + ctx.deformable_groups, + ctx.with_bias, + ) + if not ctx.with_bias: + grad_bias = None + + return ( + grad_input, + grad_offset, + grad_mask, + grad_weight, + grad_bias, + None, + None, + None, + None, + None, + ) + + @staticmethod + def 
_infer_shape(ctx, input, weight): + n = input.size(0) + channels_out = weight.size(0) + height, width = input.shape[2:4] + kernel_h, kernel_w = weight.shape[2:4] + height_out = ( + height + 2 * ctx.padding - (ctx.dilation * (kernel_h - 1) + 1) + ) // ctx.stride + 1 + width_out = ( + width + 2 * ctx.padding - (ctx.dilation * (kernel_w - 1) + 1) + ) // ctx.stride + 1 + return n, channels_out, height_out, width_out + + +deform_conv = DeformConvFunction.apply +modulated_deform_conv = ModulatedDeformConvFunction.apply diff --git a/dconv/layers/dcn/deform_conv_module.py b/dconv/layers/dcn/deform_conv_module.py new file mode 100644 index 00000000..bfea38f0 --- /dev/null +++ b/dconv/layers/dcn/deform_conv_module.py @@ -0,0 +1,216 @@ +# -*- coding: utf-8 -*- +import math + +import torch +import torch.nn as nn +from torch.nn.modules.utils import _pair + +from .deform_conv_func import deform_conv, modulated_deform_conv + + +class DeformConv(nn.Module): + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1, + bias=False, + ): + assert not bias + super(DeformConv, self).__init__() + self.with_bias = bias + + assert in_channels % groups == 0, "in_channels {} cannot be divisible by groups {}".format( + in_channels, groups + ) + assert ( + out_channels % groups == 0 + ), "out_channels {} cannot be divisible by groups {}".format(out_channels, groups) + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = _pair(stride) + self.padding = _pair(padding) + self.dilation = _pair(dilation) + self.groups = groups + self.deformable_groups = deformable_groups + + self.weight = nn.Parameter( + torch.Tensor(out_channels, in_channels // self.groups, *self.kernel_size) + ) + + self.reset_parameters() + + def reset_parameters(self): + n = self.in_channels + for k in self.kernel_size: + n *= k + stdv = 1.0 / math.sqrt(n) + self.weight.data.uniform_(-stdv, stdv) + + def forward(self, input, offset): + return deform_conv( + input, + offset, + self.weight, + self.stride, + self.padding, + self.dilation, + self.groups, + self.deformable_groups, + ) + + def __repr__(self): + return "".join( + [ + "{}(".format(self.__class__.__name__), + "in_channels={}, ".format(self.in_channels), + "out_channels={}, ".format(self.out_channels), + "kernel_size={}, ".format(self.kernel_size), + "stride={}, ".format(self.stride), + "dilation={}, ".format(self.dilation), + "padding={}, ".format(self.padding), + "groups={}, ".format(self.groups), + "deformable_groups={}, ".format(self.deformable_groups), + "bias={})".format(self.with_bias), + ] + ) + + +class ModulatedDeformConv(nn.Module): + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1, + bias=True, + ): + super(ModulatedDeformConv, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + self.deformable_groups = deformable_groups + self.with_bias = bias + + self.weight = nn.Parameter( + torch.Tensor(out_channels, in_channels // groups, *self.kernel_size) + ) + if bias: + self.bias = nn.Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self): + n = self.in_channels + for k in self.kernel_size: 
+ n *= k + stdv = 1.0 / math.sqrt(n) + self.weight.data.uniform_(-stdv, stdv) + if self.bias is not None: + self.bias.data.zero_() + + def forward(self, input, offset, mask): + return modulated_deform_conv( + input, + offset, + mask, + self.weight, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + self.deformable_groups, + ) + + def __repr__(self): + return "".join( + [ + "{}(".format(self.__class__.__name__), + "in_channels={}, ".format(self.in_channels), + "out_channels={}, ".format(self.out_channels), + "kernel_size={}, ".format(self.kernel_size), + "stride={}, ".format(self.stride), + "dilation={}, ".format(self.dilation), + "padding={}, ".format(self.padding), + "groups={}, ".format(self.groups), + "deformable_groups={}, ".format(self.deformable_groups), + "bias={})".format(self.with_bias), + ] + ) + + +class ModulatedDeformConvPack(ModulatedDeformConv): + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1, + bias=True, + ): + super(ModulatedDeformConvPack, self).__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + deformable_groups, + bias, + ) + + self.conv_offset_mask = nn.Conv2d( + self.in_channels // self.groups, + self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1], + kernel_size=self.kernel_size, + stride=_pair(self.stride), + padding=_pair(self.padding), + bias=True, + ) + self.init_offset() + + def init_offset(self): + self.conv_offset_mask.weight.data.zero_() + self.conv_offset_mask.bias.data.zero_() + + def forward(self, input): + out = self.conv_offset_mask(input) + o1, o2, mask = torch.chunk(out, 3, dim=1) + offset = torch.cat((o1, o2), dim=1) + mask = torch.sigmoid(mask) + return modulated_deform_conv( + input, + offset, + mask, + self.weight, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + self.deformable_groups, + ) diff --git a/reproduce/DMF/META_FILTER-miniImagenet-resnet12_drop-5-1.yaml b/reproduce/DMF/META_FILTER-miniImagenet-resnet12_drop-5-1.yaml new file mode 100644 index 00000000..770daf58 --- /dev/null +++ b/reproduce/DMF/META_FILTER-miniImagenet-resnet12_drop-5-1.yaml @@ -0,0 +1,70 @@ +augment: true +augment_times: 1 +augment_times_query: 1 +backbone: + kwargs: + drop_block: true + name: resnet12_drop +batch_size: 128 +classifier: + kwargs: + num_class: 64 + nFeat: 640 + kernel: 1 + groups: 64 + name: DMF +data_root: /home/qinhaonan/.vscode-server/data/dataset/miniImageNet--ravi +deterministic: true +device_ids: 0,1,2,3 +episode_size: 8 +epoch: 120 +image_size: 84 +includes: +- headers/data.yaml +- headers/device.yaml +- headers/misc.yaml +- headers/model.yaml +- headers/optimizer.yaml +- classifiers/DMF.yaml +- backbones/resnet12_drop.yaml +log_interval: 100 +log_level: info +log_name: null +log_paramerter: false +lr_scheduler: + kwargs: + T_max: 120 + eta_min: 0 + name: CosineAnnealingLR +warmup: 10 +n_gpu: 4 +optimizer: + kwargs: + lr: 0.35 + momentum: 0.9 + nesterov: true + weight_decay: 0.0005 + name: SGD + other: null +parallel_part: +- emb_func +pretrain_path: null +query_num: 6 +rank: 0 +result_root: ./results +resume: false +save_interval: 10 +save_part: +- emb_func +seed: 2147483647 +shot_num: 1 +tag: null +tb_scale: 3.3333333333333335 +test_episode: 1200 +test_epoch: 5 +test_query: 6 +test_shot: 1 +test_way: 5 +train_episode: 2000 +use_memory: false +way_num: 5 diff --git a/reproduce/DMF/META_FILTER-miniImagenet-resnet12_drop-5-5.yaml 
diff --git a/reproduce/DMF/META_FILTER-miniImagenet-resnet12_drop-5-1.yaml b/reproduce/DMF/META_FILTER-miniImagenet-resnet12_drop-5-1.yaml
new file mode 100644
index 00000000..770daf58
--- /dev/null
+++ b/reproduce/DMF/META_FILTER-miniImagenet-resnet12_drop-5-1.yaml
@@ -0,0 +1,70 @@
+augment: true
+augment_times: 1
+augment_times_query: 1
+backbone:
+  kwargs:
+    drop_block: true
+  name: resnet12_drop
+batch_size: 128
+classifier:
+  kwargs:
+    num_class: 64
+    nFeat: 640
+    kernel: 1
+    groups: 64
+  name: DMF
+data_root: /home/qinhaonan/.vscode-server/data/dataset/miniImageNet--ravi
+deterministic: true
+device_ids: 0,1,2,3
+episode_size: 8
+epoch: 120
+image_size: 84
+includes:
+- headers/data.yaml
+- headers/device.yaml
+- headers/misc.yaml
+- headers/model.yaml
+- headers/optimizer.yaml
+- classifiers/DMF.yaml
+- backbones/resnet12_drop.yaml
+log_interval: 100
+log_level: info
+log_name: null
+log_paramerter: false
+lr_scheduler:
+  kwargs:
+    T_max: 120
+    eta_min: 0
+  name: CosineAnnealingLR
+warmup: 10
+n_gpu: 4
+optimizer:
+  kwargs:
+    lr: 0.35
+    momentum: 0.9
+    nesterov: true
+    weight_decay: 0.0005
+  name: SGD
+  other: null
+parallel_part:
+- emb_func
+pretrain_path: null
+query_num: 6
+rank: 0
+result_root: ./results
+resume: false
+save_interval: 10
+save_part:
+- emb_func
+seed: 2147483647
+shot_num: 1
+tag: null
+tb_scale: 3.3333333333333335
+test_episode: 1200
+test_epoch: 5
+test_query: 6
+test_shot: 1
+test_way: 5
+train_episode: 2000
+use_memory: false
+way_num: 5
diff --git a/reproduce/DMF/META_FILTER-miniImagenet-resnet12_drop-5-5.yaml b/reproduce/DMF/META_FILTER-miniImagenet-resnet12_drop-5-5.yaml
new file mode 100644
index 00000000..9d3b91c3
--- /dev/null
+++ b/reproduce/DMF/META_FILTER-miniImagenet-resnet12_drop-5-5.yaml
@@ -0,0 +1,70 @@
+augment: true
+augment_times: 1
+augment_times_query: 1
+backbone:
+  kwargs:
+    drop_block: true
+  name: resnet12_drop
+batch_size: 128
+classifier:
+  kwargs:
+    num_class: 64
+    nFeat: 640
+    kernel: 1
+    groups: 160
+  name: DMF
+data_root: /home/qinhaonan/.vscode-server/data/dataset/miniImageNet--ravi
+deterministic: true
+device_ids: 0,1,2,3
+episode_size: 8
+epoch: 120
+image_size: 84
+includes:
+- headers/data.yaml
+- headers/device.yaml
+- headers/misc.yaml
+- headers/model.yaml
+- headers/optimizer.yaml
+- classifiers/DMF.yaml
+- backbones/resnet12_drop.yaml
+log_interval: 100
+log_level: info
+log_name: null
+log_paramerter: false
+lr_scheduler:
+  kwargs:
+    T_max: 120
+    eta_min: 0
+  name: CosineAnnealingLR
+warmup: 10
+n_gpu: 4
+optimizer:
+  kwargs:
+    lr: 0.35
+    momentum: 0.9
+    nesterov: true
+    weight_decay: 0.0005
+  name: SGD
+  other: null
+parallel_part:
+- emb_func
+pretrain_path: null
+query_num: 6
+rank: 0
+result_root: ./results
+resume: false
+save_interval: 10
+save_part:
+- emb_func
+seed: 2147483647
+shot_num: 5
+tag: null
+tb_scale: 3.3333333333333335
+test_episode: 1200
+test_epoch: 5
+test_query: 6
+test_shot: 5
+test_way: 5
+train_episode: 2000
+use_memory: false
+way_num: 5
diff --git a/reproduce/DMF/META_FILTER-tiered_imagenet-resnet12_drop-5-1.yaml b/reproduce/DMF/META_FILTER-tiered_imagenet-resnet12_drop-5-1.yaml
new file mode 100644
index 00000000..744785c9
--- /dev/null
+++ b/reproduce/DMF/META_FILTER-tiered_imagenet-resnet12_drop-5-1.yaml
@@ -0,0 +1,70 @@
+augment: true
+augment_times: 1
+augment_times_query: 1
+backbone:
+  kwargs:
+    drop_block: true
+  name: resnet12_drop
+batch_size: 128
+classifier:
+  kwargs:
+    num_class: 351
+    nFeat: 640
+    kernel: 1
+    groups: 64
+  name: DMF
+data_root: /home/qinhaonan/.vscode-server/data/dataset/tiered_imagenet
+deterministic: true
+device_ids: 0,1,2,3
+episode_size: 8
+epoch: 120
+image_size: 84
+includes:
+- headers/data.yaml
+- headers/device.yaml
+- headers/misc.yaml
+- headers/model.yaml
+- headers/optimizer.yaml
+- classifiers/DMF.yaml
+- backbones/resnet12_drop.yaml
+log_interval: 100
+log_level: info
+log_name: null
+log_paramerter: false
+lr_scheduler:
+  kwargs:
+    T_max: 120
+    eta_min: 0
+  name: CosineAnnealingLR
+warmup: 10
+n_gpu: 4
+optimizer:
+  kwargs:
+    lr: 0.05
+    momentum: 0.9
+    nesterov: true
+    weight_decay: 0.0005
+  name: SGD
+  other: null
+parallel_part:
+- emb_func
+pretrain_path: null
+query_num: 6
+rank: 0
+result_root: ./results
+resume: false
+save_interval: 10
+save_part:
+- emb_func
+seed: 2147483647
+shot_num: 1
+tag: null
+tb_scale: 3.3333333333333335
+test_episode: 1200
+test_epoch: 5
+test_query: 6
+test_shot: 1
+test_way: 5
+train_episode: 2000
+use_memory: false
+way_num: 5
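Across the reproduce configs in this patch (the tiered-ImageNet 5-shot file follows below), only the dataset path, `num_class` (64 for *mini*ImageNet, 351 for *tiered*ImageNet), the DMF `groups`/learning-rate values, and the shot count differ; backbone, scheduler, and episode layout are shared. To eyeball one of them without the framework, plain PyYAML is enough. This is a sketch, assuming PyYAML is installed and the path is relative to the repository root; the project's own Config loader additionally merges the files listed under `includes:`, which this snippet skips.

```python
# Quick sanity check of a reproduce config with plain PyYAML -- a sketch, not the
# framework's own loader (the real Config class also resolves `includes:`).
import yaml

with open("reproduce/DMF/META_FILTER-miniImagenet-resnet12_drop-5-1.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["classifier"])  # {'kwargs': {'num_class': 64, 'nFeat': 640, 'kernel': 1, 'groups': 64}, 'name': 'DMF'}
print(cfg["way_num"], cfg["shot_num"], cfg["query_num"])  # 5 1 6
```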
diff --git a/reproduce/DMF/META_FILTER-tiered_imagenet-resnet12_drop-5-5.yaml b/reproduce/DMF/META_FILTER-tiered_imagenet-resnet12_drop-5-5.yaml
new file mode 100644
index 00000000..aebb853a
--- /dev/null
+++ b/reproduce/DMF/META_FILTER-tiered_imagenet-resnet12_drop-5-5.yaml
@@ -0,0 +1,70 @@
+augment: true
+augment_times: 1
+augment_times_query: 1
+backbone:
+  kwargs:
+    drop_block: true
+  name: resnet12_drop
+batch_size: 128
+classifier:
+  kwargs:
+    num_class: 351
+    nFeat: 640
+    kernel: 1
+    groups: 320
+  name: DMF
+data_root: /home/qinhaonan/.vscode-server/data/dataset/tiered_imagenet
+deterministic: true
+device_ids: 0,1,2,3
+episode_size: 8
+epoch: 120
+image_size: 84
+includes:
+- headers/data.yaml
+- headers/device.yaml
+- headers/misc.yaml
+- headers/model.yaml
+- headers/optimizer.yaml
+- classifiers/DMF.yaml
+- backbones/resnet12_drop.yaml
+log_interval: 100
+log_level: info
+log_name: null
+log_paramerter: false
+lr_scheduler:
+  kwargs:
+    T_max: 120
+    eta_min: 0
+  name: CosineAnnealingLR
+warmup: 10
+n_gpu: 4
+optimizer:
+  kwargs:
+    lr: 0.05
+    momentum: 0.9
+    nesterov: true
+    weight_decay: 0.0005
+  name: SGD
+  other: null
+parallel_part:
+- emb_func
+pretrain_path: null
+query_num: 6
+rank: 0
+result_root: ./results
+resume: false
+save_interval: 10
+save_part:
+- emb_func
+seed: 2147483647
+shot_num: 5
+tag: null
+tb_scale: 3.3333333333333335
+test_episode: 1200
+test_epoch: 5
+test_query: 6
+test_shot: 5
+test_way: 5
+train_episode: 2000
+use_memory: false
+way_num: 5
diff --git a/reproduce/DMF/README.md b/reproduce/DMF/README.md
new file mode 100644
index 00000000..dcd3aa88
--- /dev/null
+++ b/reproduce/DMF/README.md
@@ -0,0 +1,36 @@
+# Learning Dynamic Alignment via Meta-filter for Few-shot Learning
+## Introduction
+| Name: | [DMF](https://arxiv.org/pdf/2103.13582) |
+|----------|-------------------------------|
+| Embed.: | Conv64F/ResNet12 |
+| Type: | Metric |
+| Venue: | CVPR'21 |
+| Codes: | [**DMF**](https://github.com/chmxu/Dynamic-Meta-filter) |
+
+
+Cite this work with:
+```bibtex
+@inproceedings{xu2021dmf,
+  title={Learning Dynamic Alignment via Meta-filter for Few-shot Learning},
+  author={Chengming Xu and Chen Liu and Li Zhang and Chengjie Wang and Jilin Li and Feiyue Huang and Xiangyang Xue and Yanwei Fu},
+  booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
+  year={2021}
+}
+```
+---
+## Setup
+DMF relies on a custom deformable-convolution extension that must be compiled before training. Build it by running:
+```
+python DMF_setup.py develop build
+```
+## Results and Models
+
+**Classification**
+
+| | Embedding | :book: *mini*ImageNet (5,1) | :computer: *mini*ImageNet (5,1) | :book: *mini*ImageNet (5,5) | :computer: *mini*ImageNet (5,5) | :memo: Comments |
+|---|-----------|--------------------|--------------------|--------------------|--------------------|---|
+| 1 | ResNet12 | 67.76 ± 0.46% | 67.185% [:arrow_down:](https://pan.baidu.com/s/1H_Y7G4BH-OnbU5hl54ljww?pwd=s3ut) [:clipboard:](./META_FILTER-miniImagenet-resnet12_drop-5-1.yaml) | 82.71 ± 0.31% | 81.997% [:arrow_down:](https://pan.baidu.com/s/1X99febUlbV7WE6IYkNeGfQ?pwd=8rv5) [:clipboard:](./META_FILTER-miniImagenet-resnet12_drop-5-5.yaml) | Comments |
+
+| | Embedding | :book: *tiered*ImageNet (5,1) | :computer: *tiered*ImageNet (5,1) | :book: *tiered*ImageNet (5,5) | :computer: *tiered*ImageNet (5,5) | :memo: Comments |
+|---|-----------|--------------------|--------------------|--------------------|--------------------|---|
+| 1 | ResNet12 | 71.89 ± 0.52% | 71.369% [:arrow_down:](https://pan.baidu.com/s/1pYD9H7SOuw0BYIQerYjOhA?pwd=546y) [:clipboard:](./META_FILTER-tiered_imagenet-resnet12_drop-5-1.yaml) | 85.96 ± 0.35% | 85.350% [:arrow_down:](https://pan.baidu.com/s/1Jf1XKcXziEdcXGZjFwQdeg?pwd=hy7j) [:clipboard:](./META_FILTER-tiered_imagenet-resnet12_drop-5-5.yaml) | Comments |
\ No newline at end of file
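As a follow-up to the Setup section of `reproduce/DMF/README.md`: after `python DMF_setup.py develop build`, the compiled extension should be importable as `dconv._C`. The snippet below is a minimal post-build check and not part of the patch; it assumes the develop install put the `dconv` package on the Python path of the current environment.

```python
# Post-build sanity check for the dconv extension -- a sketch.
# Assumes `python DMF_setup.py develop build` has already been run so that the
# `dconv` package (and its compiled `_C` module) is importable.
import torch

try:
    from dconv import _C  # noqa: F401
    print("dconv._C loaded; CUDA available:", torch.cuda.is_available())
except ImportError as err:
    print("dconv extension is missing or was not built:", err)
```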