This repository was archived by the owner on Sep 3, 2024. It is now read-only.

Call optimizer.zero_grad() when profiling #12

Open

mostafaelhoushi wants to merge 2 commits into main from call-zero-grad-before-fwd

Conversation

@mostafaelhoushi (Contributor)

Didn't test locally... let's see if CI passes; I can check locally as well tomorrow.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jan 6, 2023
@mostafaelhoushi mostafaelhoushi marked this pull request as ready for review January 6, 2023 05:50
@mostafaelhoushi (Contributor, Author)

Error in CI... need to debug that

@jacobkahn (Member)

Discussed offline -- from a correctness perspective, we need to clear accumulated gradients on each iteration, so let's get this passing in CI 👍

@mostafaelhoushi (Contributor, Author)

So far in this PR, I was calling optim.zero_grad() inside the model wrapper before creating the fx graph (screenshot omitted).
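A rough sketch of that first approach (hypothetical names; the actual wrapper in this repo differs) might look like:

```python
import torch

class TrainStepWrapper(torch.nn.Module):
    """Hypothetical sketch: zero the gradients inside the wrapper's
    forward so the zero_grad call is captured when the fx graph is
    created from the wrapper."""

    def __init__(self, model, optim):
        super().__init__()
        self.model = model
        self.optim = optim

    def forward(self, x, y):
        self.optim.zero_grad()  # this call ends up in the trace
        loss = torch.nn.functional.mse_loss(self.model(x), y)
        loss.backward()
        self.optim.step()
        return loss
```

Tracing this whole step is what pulls the zeroing ops into the graph, which is exactly where the trouble described below starts.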

This caused things to break and requires more debugging. Also, I was surprised that the fx trace already had an instruction to zero out tensors (presumably gradient tensors). The comparison below shows the trace of testSimpleTrainSchedule (the test case that is failing in CI) before and after this PR:
(screenshot omitted)

Another option is to call optim.zero_grad() outside the wrapper, during profiling. This will require some refactoring, and it's probably the route I will take.

Having said that, the instructions that seem to be introduced by optim.zero_grad():

```python
zeros = torch.ops.aten.zeros.default([1], device = device(type='cpu'), pin_memory = False)
empty = torch.ops.aten.empty.memory_format([280], dtype = torch.uint8, device = device(type='cpu'))
```

are problematic when creating or modifying graphs. Because these instructions take no tensor as input and their outputs are not consumed by any other op, the ILP solver is free to move them to any location in the graph, whereas we need them to run before the model's forward pass.
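To make the hazard concrete, here is a small sketch (not code from this PR) that detects such "floating" nodes in a torch.fx graph: call_function nodes with no node inputs and no users, which a dependency-based scheduler may legally place anywhere.

```python
import torch
import torch.fx

def floating_nodes(graph: torch.fx.Graph):
    # call_function nodes with no Node inputs and no users carry no
    # data dependencies, so an ILP/topological scheduler can reorder
    # them freely relative to the rest of the graph
    return [n for n in graph.nodes
            if n.op == "call_function"
            and not n.all_input_nodes
            and not n.users]

# tiny graph mimicking the situation above: a zeros op whose
# output is never consumed by any other node
g = torch.fx.Graph()
x = g.placeholder("x")
z = g.call_function(torch.zeros, (2,))  # no inputs, result unused
g.output(x)
print([n.name for n in floating_nodes(g)])  # → ['zeros']
```

Pinning such nodes before the forward pass would require adding an artificial ordering constraint, since nothing in the dataflow forces their position.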

