Skip to content

Conversation

@inducer
Copy link
Owner

@inducer inducer commented Oct 20, 2025

Very WIP right now.

cc @majosm

@majosm majosm force-pushed the flop-counting branch 6 times, most recently from 31da603 to 83629d9 Compare October 27, 2025 16:31
@majosm
Copy link
Collaborator

majosm commented Oct 27, 2025

ConditionalDependencyMapper addition is unused (I used it in an earlier draft but ended up not needing it). I left it in because it doesn't harm anything and seems potentially useful.

@inducer I think this is ready for a first look, apart from some basedpyright errors that I don't understand. Probably best to go commit by commit.

@majosm
Copy link
Collaborator

majosm commented Nov 7, 2025

@inducer I ran into a complication with the flop counting for conditionals. When mirgecom projects something from interior/boundary faces to all faces, meshmode's direct connection code creates an AdvancedIndexInContiguousAxes (code) shaped like the all-faces discretization, and then it follows up with an IndexLambda that zeros out the entries corresponding to elements that aren't part of the input discretization. You can see this for wave below:
all_faces_projection

Since the flop counting currently evaluates flops for Conditional as max(nflops_then, nflops_else), this is causing the number of flops to be overestimated for this part of the DAG. It's counting the flops (in the case of wave) over 1536 faces instead of the 64 that end up being nonzero. I'm not sure how to deal with this, since it's data-dependent. I could maybe hack in a check for condition being something easily evaluated and then just count the number of true/false cases, but beyond that I don't know how to resolve this. Any thoughts?

@majosm majosm force-pushed the flop-counting branch 4 times, most recently from 5236448 to 72b69ac Compare November 14, 2025 17:05
@majosm
Copy link
Collaborator

majosm commented Nov 14, 2025

@inducer Do you know what's causing these basedpyright errors for rec? Is this another instance of the thing that should be added to the baseline?

Also, about the Conditional flop count handling: if we're going to model it after GPU execution, should it be counted as max(nflops_then, nflops_else) or sum(nflops_then, nflops_else)?

@majosm majosm force-pushed the flop-counting branch 2 times, most recently from 789e7ac to f077dd1 Compare November 21, 2025 23:11
@majosm majosm marked this pull request as ready for review November 21, 2025 23:27
pytato/utils.py Outdated
return (
isinstance(expr, Array)
and not isinstance(expr, (
# FIXME: Is there a nice way to generalize this?
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure I understand the intent of this.

For example, DistributedRecv is always materialized (it ends up in a buffer after all).

What is this function trying to accomplish?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm using this in the flop counting code to tell me a couple of things:

  1. Am I allowed to materialize/unmaterialize this expression? (Can't for things like NamedArray and DistributedSendRefHolder.)
  2. Can this expression be lowered to an index lambda? (Can't for things like InputArgumentBase and DistributedRecv.)

Copy link
Collaborator

@majosm majosm Dec 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is 5e30e95 any better? It still has some isinstance() checks that seem kind of fragile, but I'm not sure how to avoid them.

@majosm majosm force-pushed the flop-counting branch 2 times, most recently from 8164074 to 75e5dbe Compare December 15, 2025 16:55
@majosm majosm mentioned this pull request Dec 15, 2025
1 task
@majosm
Copy link
Collaborator

majosm commented Jan 9, 2026

@inducer This is ready for a look again when you have a chance.

Copy link
Owner Author

@inducer inducer left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just the enum thing and then this is good to go from my perspective. 🎉

return {expr.name}


class FlopCounter(FlopCounterBase):
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's do a separate node type for nflops.

var_names: set[str] = set(ScalarInputGatherer()(nflops))
var_names.discard("nflops")
if var_names:
raise UndefinedOpFlopCountError(next(iter(var_names))) from None
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
raise UndefinedOpFlopCountError(next(iter(var_names))) from None
raise UndefinedOpFlopCountError(next(iter(var_names)))

if var_names:
raise UndefinedOpFlopCountError(next(iter(var_names))) from None
else:
raise AssertionError from None
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
raise AssertionError from None
raise AssertionError

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants