Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 21 additions & 20 deletions mmaction/models/tenons/segmental_consensuses/simple_consensus.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,35 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from ...registry import SEGMENTAL_CONSENSUSES


class _SimpleConsensus(torch.autograd.Function):
"""Simplest segmental consensus module"""

def __init__(self,
consensus_type='avg',
dim=1):
super(_SimpleConsensus, self).__init__()

assert consensus_type in ['avg']
self.consensus_type = consensus_type
self.dim = dim
self.shape = None

def forward(self, x):
self.shape = x.size()
if self.consensus_type == 'avg':
output = x.mean(dim=self.dim, keepdim=True)
@staticmethod
def forward(ctx,x,dim,consensus_type):
ctx.dim = dim
ctx.consensus_type=consensus_type
ctx.save_for_backward(x)
if consensus_type == 'avg':
output = x.mean(dim=dim, keepdim=True)
else:
output = None
return output

def backward(self, grad_output):
if self.consensus_type == 'avg':
grad_in = grad_output.expand(self.shape) / float(self.shape[self.dim])

@staticmethod
def backward( ctx,grad_output):
x, = ctx.saved_tensors
dim = ctx.dim
consensus_type=ctx.consensus_type
shape = x.size()
if consensus_type == 'avg':
grad_in = grad_output.expand(shape) / float(shape[dim])
else:
grad_in = None
return grad_in
return grad_in, None , None


@SEGMENTAL_CONSENSUSES.register_module
Expand All @@ -46,4 +45,6 @@ def init_weights(self):
pass

def forward(self, input):
return _SimpleConsensus(self.consensus_type, self.dim)(input)
return _SimpleConsensus.apply(input,
self.dim,
self.consensus_type)