Add option to use tensor hooks for Dynamic Linear #198
Conversation
# option is useful for safety, but not strictly necessary.
enable_pre_and_post_forward = True

# If True, dynamic linear uses hooks for activation casting
need to figure out if we want this
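A hypothetical sketch of how such a flag could gate hook registration. The flag name matches the config snippet above; `cast_input_pre_hook`, `make_dynamic_linear`, and the bfloat16 round-trip are stand-ins for the PR's actual float8 casting helpers, not its real API:

```python
import torch
import torch.nn as nn

# Flag from the config snippet above: if True, use hooks for activation casting.
use_activation_hooks = True

def cast_input_pre_hook(module, args):
    # Forward pre-hook: cast the activation before the linear runs.
    # A bfloat16 round-trip stands in for the real float8_e4m3fn cast.
    x = args[0]
    return (x.to(torch.bfloat16).to(x.dtype),)

def make_dynamic_linear(in_features, out_features, use_hooks):
    m = nn.Linear(in_features, out_features)
    if use_hooks:
        # Hook-based casting: forward stays the vanilla nn.Linear.forward.
        m.register_forward_pre_hook(cast_input_pre_hook)
    # else: a subclass would override forward and cast inline instead.
    return m

m = make_dynamic_linear(4, 4, use_activation_hooks)
y = m(torch.randn(2, 4))
print(y.shape)  # torch.Size([2, 4])
```

Keeping the casting in hooks leaves the module's forward untouched, which is what makes the two code paths (hooks vs. modified forward) easy to toggle with a single config flag.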
@pytest.mark.parametrize(
    "linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
)
@pytest.mark.parametrize("use_activation_hooks", [True, False])
There was some testing that was globbed together before; this splits the test into two.
wanchaol
left a comment
Looks great!! I have only one comment: I think we can avoid the backward pre-hook by just using a forward hook plus the existing autograd function.
    return module.cast_to_float8_e4m3fn(args[0])


def cast_dldy_to_float8_e5m2_backward_pre_hook(module, grad_output):
testing my understanding:
It looks to me that the torch.compile issue is related to the backward_prehook and tensor subclass interactions.
But could this be solved by using module.register_forward_hook and then inside the forward hook we call y = self.cast_to_float8_e5m2_bw(y) just like the current casting?
This would solve the backward prehook issue I believe, but not sure if it would hit subclass issues (hopefully not)
ahhh great idea, unfortunately still erroring for compile..
I think that is a little cleaner, is that enough for a good interaction with DTensor?
I have this locally and can push it up, but I don't know which one ultimately gets us closer, forward_hook or full_backward_pre_hook
Yeah, I think a forward_hook might be a little cleaner and easier for DTensor to interact with. We would need to actually try composing those hooks to see if there's any gap. I feel we should try to land the forward_hook approach in this PR, and if we later find we need full_backward_pre_hook instead of forward_hook, we can always change it back?
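A minimal sketch of the forward-hook approach settled on above: an autograd function applied inside a `register_forward_hook`, so the backward-direction cast is inserted into the graph without registering a `full_backward_pre_hook`. The class name, hook name, and the bfloat16 round-trip are stand-ins for the PR's actual `cast_to_float8_e5m2_bw` casting:

```python
import torch
import torch.nn as nn

class CastDLdYBackward(torch.autograd.Function):
    # Identity in forward; casts the incoming gradient in backward.
    # A bfloat16 round-trip stands in for the real float8_e5m2 cast.
    @staticmethod
    def forward(ctx, x):
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.to(torch.bfloat16).to(grad_output.dtype)

def cast_dldy_forward_hook(module, args, output):
    # A forward hook (not a backward pre-hook): the returned tensor
    # replaces the module output, so the backward cast runs when
    # autograd reaches this node, with no hook firing during backward.
    return CastDLdYBackward.apply(output)

lin = nn.Linear(4, 4)
lin.register_forward_hook(cast_dldy_forward_hook)
x = torch.randn(2, 4, requires_grad=True)
lin(x).sum().backward()
print(x.grad.shape)  # torch.Size([2, 4])
```

Because the cast becomes an ordinary node in the autograd graph rather than a module-level backward pre-hook, this sidesteps the backward-pre-hook/tensor-subclass interaction discussed above, though whether torch.compile traces it cleanly still has to be verified.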
wanchaol
left a comment
stamp to unblock, thanks for getting this to work!
@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Summary
This is a duplicate of: #170
With more testing, ideally I think we wouldn't offer the choice between hooks and modified forwards and would just use hooks. However, compile does not appear to support this yet.