Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 176 additions & 15 deletions comprehensions.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,18 @@ func (con *converter) identifyComprehension(expr *exprpb.Expr) (*ComprehensionIn
}

// analyzeComprehensionPattern examines the comprehension AST structure to identify
// which CEL macro it represents by pattern matching the characteristic expressions
// which CEL macro it represents by pattern matching the characteristic expressions.
//
// IMPORTANT: This relies on CEL's stable macro expansion patterns:
// - all(x, pred) → accuInit=true, step=accu&&pred, result=accu
// - exists(x, pred) → accuInit=false, step=accu||pred, result=accu
// - exists_one(x, pred) → accuInit=0, step=?:(pred,accu+1,accu), result=accu==1
// - map(x, t) → accuInit=[], step=accu+[t], result=accu
// - map(x, t, filter) → accuInit=[], step=?:(filter,accu+[t],accu), result=accu
// - filter(x, pred) → accuInit=[], step=?:(pred,accu+[x],accu), result=accu
//
// These patterns are defined by the CEL specification and are stable across versions.
// If CEL changes its macro expansion, this code will need updates.
func (con *converter) analyzeComprehensionPattern(comp *exprpb.Expr_Comprehension) (*ComprehensionInfo, error) {
info := &ComprehensionInfo{
IterVar: comp.GetIterVar(),
Expand All @@ -94,6 +105,13 @@ func (con *converter) analyzeComprehensionPattern(comp *exprpb.Expr_Comprehensio
// All: accuInit = true, step = accu && predicate, result = accu
if con.isBoolTrue(accuInit) {
if con.isLogicalAndStep(comp.GetLoopStep(), comp.GetAccuVar()) {
// Validate result expression matches expected pattern
if !con.isIdentityResult(comp.GetResult(), comp.GetAccuVar()) {
con.logger.LogAttrs(context.Background(), slog.LevelWarn,
"comprehension result doesn't match expected all() pattern",
slog.String("accu_var", comp.GetAccuVar()),
)
}
info.Type = ComprehensionAll
info.Predicate = con.extractPredicateFromAndStep(comp.GetLoopStep(), comp.GetAccuVar())
con.logger.LogAttrs(context.Background(), slog.LevelDebug,
Expand All @@ -109,6 +127,13 @@ func (con *converter) analyzeComprehensionPattern(comp *exprpb.Expr_Comprehensio
// Exists: accuInit = false, step = accu || predicate, result = accu
if con.isBoolFalse(accuInit) {
if con.isLogicalOrStep(comp.GetLoopStep(), comp.GetAccuVar()) {
// Validate result expression matches expected pattern
if !con.isIdentityResult(comp.GetResult(), comp.GetAccuVar()) {
con.logger.LogAttrs(context.Background(), slog.LevelWarn,
"comprehension result doesn't match expected exists() pattern",
slog.String("accu_var", comp.GetAccuVar()),
)
}
info.Type = ComprehensionExists
info.Predicate = con.extractPredicateFromOrStep(comp.GetLoopStep(), comp.GetAccuVar())
con.logger.LogAttrs(context.Background(), slog.LevelDebug,
Expand All @@ -123,7 +148,14 @@ func (con *converter) analyzeComprehensionPattern(comp *exprpb.Expr_Comprehensio

// ExistsOne: accuInit = 0, step = conditional(predicate, accu + 1, accu), result = accu == 1
if con.isIntZero(accuInit) {
if con.isConditionalCountStep(comp.GetLoopStep(), comp.GetAccuVar()) && con.isEqualsOneResult(comp.GetResult(), comp.GetAccuVar()) {
if con.isConditionalCountStep(comp.GetLoopStep(), comp.GetAccuVar()) {
// Validate result expression is accu == 1
if !con.isEqualsOneResult(comp.GetResult(), comp.GetAccuVar()) {
con.logger.LogAttrs(context.Background(), slog.LevelWarn,
"comprehension result doesn't match expected exists_one() pattern (should be accu == 1)",
slog.String("accu_var", comp.GetAccuVar()),
)
}
info.Type = ComprehensionExistsOne
info.Predicate = con.extractPredicateFromConditionalStep(comp.GetLoopStep())
con.logger.LogAttrs(context.Background(), slog.LevelDebug,
Expand All @@ -139,6 +171,13 @@ func (con *converter) analyzeComprehensionPattern(comp *exprpb.Expr_Comprehensio
// Map: accuInit = [], step = accu + [transform], result = accu
if con.isEmptyList(accuInit) {
if con.isListAppendStep(comp.GetLoopStep(), comp.GetAccuVar()) {
// Validate result expression matches expected pattern
if !con.isIdentityResult(comp.GetResult(), comp.GetAccuVar()) {
con.logger.LogAttrs(context.Background(), slog.LevelWarn,
"comprehension result doesn't match expected map() pattern",
slog.String("accu_var", comp.GetAccuVar()),
)
}
info.Type = ComprehensionMap
info.Transform = con.extractTransformFromAppendStep(comp.GetLoopStep(), comp.GetAccuVar())
con.logger.LogAttrs(context.Background(), slog.LevelDebug,
Expand All @@ -150,7 +189,14 @@ func (con *converter) analyzeComprehensionPattern(comp *exprpb.Expr_Comprehensio
return info, nil
}
// Map with filter: step = conditional(filter, accu + [transform], accu)
if con.isConditionalAppendStep(comp.GetLoopStep(), comp.GetAccuVar()) {
if con.isConditionalAppendStep(comp.GetLoopStep(), comp.GetAccuVar(), comp.GetIterVar()) {
// Validate result expression matches expected pattern
if !con.isIdentityResult(comp.GetResult(), comp.GetAccuVar()) {
con.logger.LogAttrs(context.Background(), slog.LevelWarn,
"comprehension result doesn't match expected map(filter) pattern",
slog.String("accu_var", comp.GetAccuVar()),
)
}
info.Type = ComprehensionMap
info.HasFilter = true
filter, transform := con.extractFilterAndTransformFromConditionalStep(comp.GetLoopStep(), comp.GetAccuVar())
Expand All @@ -169,6 +215,13 @@ func (con *converter) analyzeComprehensionPattern(comp *exprpb.Expr_Comprehensio
// Filter: accuInit = [], step = conditional(predicate, accu + [iterVar], accu), result = accu
if con.isEmptyList(accuInit) {
if con.isConditionalFilterStep(comp.GetLoopStep(), comp.GetAccuVar(), comp.GetIterVar()) {
// Validate result expression matches expected pattern
if !con.isIdentityResult(comp.GetResult(), comp.GetAccuVar()) {
con.logger.LogAttrs(context.Background(), slog.LevelWarn,
"comprehension result doesn't match expected filter() pattern",
slog.String("accu_var", comp.GetAccuVar()),
)
}
info.Type = ComprehensionFilter
info.Predicate = con.extractPredicateFromConditionalStep(comp.GetLoopStep())
con.logger.LogAttrs(context.Background(), slog.LevelDebug,
Expand Down Expand Up @@ -242,30 +295,138 @@ func (con *converter) isListAppendStep(step *exprpb.Expr, accuVar string) bool {
return false
}

func (con *converter) isConditionalCountStep(step *exprpb.Expr, _ string) bool {
if call := step.GetCallExpr(); call != nil {
return call.Function == operators.Conditional && len(call.Args) == 3
func (con *converter) isConditionalCountStep(step *exprpb.Expr, accuVar string) bool {
call := step.GetCallExpr()
if call == nil || call.Function != operators.Conditional || len(call.Args) != 3 {
return false
}

// Validate structure: conditional(predicate, accu + 1, accu)
// Then-branch should be: accu + 1
thenExpr := call.Args[1]
if addCall := thenExpr.GetCallExpr(); addCall != nil {
if addCall.Function == operators.Add && len(addCall.Args) == 2 {
// Check one arg is accu, other is constant 1
hasAccu := false
hasOne := false
for _, arg := range addCall.Args {
if ident := arg.GetIdentExpr(); ident != nil && ident.Name == accuVar {
hasAccu = true
}
if constant := arg.GetConstExpr(); constant != nil && constant.GetInt64Value() == 1 {
hasOne = true
}
}
if hasAccu && hasOne {
// Validate else-branch is: accu
elseExpr := call.Args[2]
if ident := elseExpr.GetIdentExpr(); ident != nil && ident.Name == accuVar {
return true
}
}
}
}
return false
}

func (con *converter) isConditionalAppendStep(step *exprpb.Expr, _ string) bool {
if call := step.GetCallExpr(); call != nil {
return call.Function == operators.Conditional && len(call.Args) == 3
func (con *converter) isConditionalAppendStep(step *exprpb.Expr, accuVar string, iterVar string) bool {
call := step.GetCallExpr()
if call == nil || call.Function != operators.Conditional || len(call.Args) != 3 {
return false
}

// Validate structure: conditional(filter, accu + [transform], accu)
// Then-branch should contain list append operation with a TRANSFORM (not just iterVar)
thenExpr := call.Args[1]
if addCall := thenExpr.GetCallExpr(); addCall != nil {
if addCall.Function == operators.Add && len(addCall.Args) == 2 {
hasAccu := false
hasTransformList := false
for _, arg := range addCall.Args {
if ident := arg.GetIdentExpr(); ident != nil && ident.Name == accuVar {
hasAccu = true
}
// Check if it's a list, but NOT [iterVar] (that would be a filter)
if listExpr := arg.GetListExpr(); listExpr != nil && len(listExpr.Elements) == 1 {
// If the list contains just the iteration variable, this is a filter, not map
if elemIdent := listExpr.Elements[0].GetIdentExpr(); elemIdent != nil && elemIdent.Name == iterVar {
return false // This is a filter pattern, not map-with-filter
}
hasTransformList = true
}
}
if hasAccu && hasTransformList {
// Validate else-branch is: accu
elseExpr := call.Args[2]
if ident := elseExpr.GetIdentExpr(); ident != nil && ident.Name == accuVar {
return true
}
}
}
}
return false
}

func (con *converter) isConditionalFilterStep(step *exprpb.Expr, _, _ string) bool {
if call := step.GetCallExpr(); call != nil {
return call.Function == operators.Conditional && len(call.Args) == 3
func (con *converter) isConditionalFilterStep(step *exprpb.Expr, accuVar string, iterVar string) bool {
call := step.GetCallExpr()
if call == nil || call.Function != operators.Conditional || len(call.Args) != 3 {
return false
}

// Validate structure: conditional(predicate, accu + [iterVar], accu)
// Then-branch should append the iteration variable
thenExpr := call.Args[1]
if addCall := thenExpr.GetCallExpr(); addCall != nil {
if addCall.Function == operators.Add && len(addCall.Args) == 2 {
hasAccu := false
hasIterVarList := false
for _, arg := range addCall.Args {
if ident := arg.GetIdentExpr(); ident != nil && ident.Name == accuVar {
hasAccu = true
}
// Check if it's a list containing just the iteration variable
if listExpr := arg.GetListExpr(); listExpr != nil && len(listExpr.Elements) == 1 {
if elemIdent := listExpr.Elements[0].GetIdentExpr(); elemIdent != nil && elemIdent.Name == iterVar {
hasIterVarList = true
}
}
}
if hasAccu && hasIterVarList {
// Validate else-branch is: accu
elseExpr := call.Args[2]
if ident := elseExpr.GetIdentExpr(); ident != nil && ident.Name == accuVar {
return true
}
}
}
}
return false
}

func (con *converter) isEqualsOneResult(result *exprpb.Expr, _ string) bool {
if call := result.GetCallExpr(); call != nil {
return call.Function == operators.Equals
func (con *converter) isEqualsOneResult(result *exprpb.Expr, accuVar string) bool {
call := result.GetCallExpr()
if call == nil || call.Function != operators.Equals || len(call.Args) != 2 {
return false
}

// Validate structure: accu == 1
hasAccu := false
hasOne := false
for _, arg := range call.Args {
if ident := arg.GetIdentExpr(); ident != nil && ident.Name == accuVar {
hasAccu = true
}
if constant := arg.GetConstExpr(); constant != nil && constant.GetInt64Value() == 1 {
hasOne = true
}
}
return hasAccu && hasOne
}

// isIdentityResult validates that the result expression is just the accumulator variable
func (con *converter) isIdentityResult(result *exprpb.Expr, accuVar string) bool {
if ident := result.GetIdentExpr(); ident != nil {
return ident.Name == accuVar
}
return false
}
Expand Down
Loading
Loading