Update _SimpleConsensus to use static autograd methods (for PyTorch >1.3) #30
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Thank you so much for sharing your implementation of TPN!
Problem
I've been trying to get it to work in one of my own projects -- however, I ran into the same issue as mentioned #28, in which the user pastes a stack trace with error message
"Legacy autograd function with non-static forward method is deprecated."This occurs when you try to callforward()with the old code when the averaging consensus (_SimpleConsensus) is used.Environment
Summary of changes
In order to make the
_SimpleConsensusclass (subclassingtorch.Autograd.Function) compatible with PyTorch >1.3:__init__method from_SimpleConsensusapplystatic method instead offorwardfor passing input tensor through the_SimpleConsensusobjectforward()method of_SimpleConsensususesctx.save_for_backward(args)to cache input tensorx,dim, andconsensus_type.self.shapeis no longer a member of_SimpleConsensus; it is reconstructed by retrievingxfromctx.saved_tensorsand callingx.size()in each call tobackward().This is consistent with the template given in the PyTorch docs, which I referenced.
Discussion
The changes in this PR work for me -- I am able to run
forward()without issue now. However, as a disclaimer, due to the nature of my project, I'm using my own testing script instead of the provided testing framework in this repo. For completeness, my model loading code looks like this:Please let me know if there's any additional testing (suites or otherwise) I should run, or if there's a contributing guide that I've overlooked. Furthermore, I'm happy to provide more details as needed. Thanks!