Conversation
|
|
||
| def _get_module_stack(node): | ||
| if "nn_module_stack" not in node.meta: | ||
| # if 'fwd_nn_module_stack' in node.meta: |
There was a problem hiding this comment.
We need to enable bwd nodes as well
|
Nice! "module hierarchy to organise the nodes / subcontainers in a tree" is similar to exported_program.unflattener idea. cc @angelayi I feel there could be some useful interaction between GraphViewer and User Annotation introduced in pytorch/pytorch#163673 |
|
Yeah! This seems very similar to the unflattener which given an exported program which contains the flattened graph, will create an "UnflattenedModule" which organizes the nodes into submodules with the same module hierarchy as the original eager program. This module can then be run eagerly. Here's an example of how the module looks like (sorry I only have a large example on hand): P1909223007 Unsure how it could fit into autoparallel, but let me know if you would like to learn more!
I'm not familiar with the full context of the PR, but it seems like it's introducing an easier way of querying nodes within specific layers. I wonder if it's enough to just have a function that just partitions nodes containing specific metadata. That could compose with the user annotations in Sherlock's PR. |
eellison
left a comment
There was a problem hiding this comment.
Cool! I like the idea.
I'm not sure how to fit in things like re-computation from the backwards fits in exactly, because that disrupts the nn.Module hierarchy a bit. Similarly, any sort of reordering or pattern matching. Those are details but I think this is valuable anyway.
yes! in fact I want to this in pytorch/pytorch#169426 |
This PR proposes a thin wrapper around a
fx.Graphthat allows for a simpler experience for querying / manipulating graphs.It uses the module hierarchy to organise the nodes / subcontainers in a tree, similar to what
nn.Moduledo.Thus, for a model like
nn.Sequential(nn.Linear(2,2), nn.Sequential(nn.Linear(2,2), nn.Linear(2,2)), one can get quick information about submodules viagraph_view['1']['0'], or for more complex modules, viamodel.block['0'].attention.Things that this might make it easier:
For now, only forward nodes are handled, but we should handle backward nodes as well.