From a11aa1e01c53fbf3796a7a7eccd1c9b69db8aedf Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Mon, 22 Dec 2025 22:44:42 -0800 Subject: [PATCH] Plan Comprehensions PiperOrigin-RevId: 848026572 --- runtime/BUILD.bazel | 8 + .../src/main/java/dev/cel/runtime/BUILD.bazel | 4 +- .../dev/cel/runtime/ConcatenatedListView.java | 11 +- .../java/dev/cel/runtime/planner/BUILD.bazel | 17 ++ .../dev/cel/runtime/planner/EvalFold.java | 204 ++++++++++++++++++ .../cel/runtime/planner/ProgramPlanner.java | 26 ++- .../runtime/planner/ProgramPlannerTest.java | 43 +++- 7 files changed, 299 insertions(+), 14 deletions(-) create mode 100644 runtime/src/main/java/dev/cel/runtime/planner/EvalFold.java diff --git a/runtime/BUILD.bazel b/runtime/BUILD.bazel index 7760d96b8..07bfdebbc 100644 --- a/runtime/BUILD.bazel +++ b/runtime/BUILD.bazel @@ -255,3 +255,11 @@ java_library( visibility = ["//:internal"], exports = ["//runtime/src/main/java/dev/cel/runtime:metadata"], ) + +java_library( + name = "concatenated_list_view", + visibility = ["//:internal"], + exports = [ + "//runtime/src/main/java/dev/cel/runtime:concatenated_list_view", + ], +) diff --git a/runtime/src/main/java/dev/cel/runtime/BUILD.bazel b/runtime/src/main/java/dev/cel/runtime/BUILD.bazel index 847092cc3..b55aec00f 100644 --- a/runtime/src/main/java/dev/cel/runtime/BUILD.bazel +++ b/runtime/src/main/java/dev/cel/runtime/BUILD.bazel @@ -1142,7 +1142,9 @@ java_library( name = "concatenated_list_view", srcs = ["ConcatenatedListView.java"], # used_by_android - visibility = ["//visibility:private"], + tags = [ + ], + deps = ["//common/annotations"], ) java_library( diff --git a/runtime/src/main/java/dev/cel/runtime/ConcatenatedListView.java b/runtime/src/main/java/dev/cel/runtime/ConcatenatedListView.java index ac7696751..c15e76f77 100644 --- a/runtime/src/main/java/dev/cel/runtime/ConcatenatedListView.java +++ b/runtime/src/main/java/dev/cel/runtime/ConcatenatedListView.java @@ -2,7 +2,7 @@ // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. -// You may obtain a copy of the License aj +// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // @@ -14,6 +14,7 @@ package dev.cel.runtime; +import dev.cel.common.annotations.Internal; import java.util.AbstractList; import java.util.ArrayList; import java.util.Collection; @@ -27,8 +28,12 @@ * comprehensions that dispatch `add_list` to concat N lists together). * *

This does not support any of the standard list operations from {@link java.util.List}. + * + + *

CEL Library Internals. Do Not Use. */ -final class ConcatenatedListView extends AbstractList { +@Internal +public final class ConcatenatedListView extends AbstractList { private final List> sourceLists; private int totalSize = 0; @@ -36,7 +41,7 @@ final class ConcatenatedListView extends AbstractList { this.sourceLists = new ArrayList<>(); } - ConcatenatedListView(Collection collection) { + public ConcatenatedListView(Collection collection) { this(); addAll(collection); } diff --git a/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel b/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel index 13a8d5759..09a082e87 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel +++ b/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel @@ -22,6 +22,7 @@ java_library( ":eval_create_list", ":eval_create_map", ":eval_create_struct", + ":eval_fold", ":eval_or", ":eval_test_only", ":eval_unary", @@ -309,6 +310,22 @@ java_library( ], ) +java_library( + name = "eval_fold", + srcs = ["EvalFold.java"], + deps = [ + ":planned_interpretable", + "//runtime:concatenated_list_view", + "//runtime:evaluation_exception", + "//runtime:evaluation_listener", + "//runtime:function_resolver", + "//runtime:interpretable", + "@maven//:com_google_errorprone_error_prone_annotations", + "@maven//:com_google_guava_guava", + "@maven//:org_jspecify_jspecify", + ], +) + java_library( name = "eval_helpers", srcs = ["EvalHelpers.java"], diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalFold.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalFold.java new file mode 100644 index 000000000..2a8ba1603 --- /dev/null +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalFold.java @@ -0,0 +1,204 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.runtime.planner; + +import com.google.common.collect.ImmutableList; +import com.google.errorprone.annotations.Immutable; +import dev.cel.runtime.CelEvaluationException; +import dev.cel.runtime.CelEvaluationListener; +import dev.cel.runtime.CelFunctionResolver; +import dev.cel.runtime.ConcatenatedListView; +import dev.cel.runtime.GlobalResolver; +import java.util.Collection; +import java.util.Map; +import org.jspecify.annotations.Nullable; + +@Immutable +final class EvalFold extends PlannedInterpretable { + + private final String accuVar; + private final PlannedInterpretable accuInit; + private final String iterVar; + private final String iterVar2; + private final PlannedInterpretable iterRange; + private final PlannedInterpretable condition; + private final PlannedInterpretable loopStep; + private final PlannedInterpretable result; + + static EvalFold create( + long exprId, + String accuVar, + PlannedInterpretable accuInit, + String iterVar, + String iterVar2, + PlannedInterpretable iterRange, + PlannedInterpretable loopCondition, + PlannedInterpretable loopStep, + PlannedInterpretable result) { + return new EvalFold( + exprId, accuVar, accuInit, iterVar, iterVar2, iterRange, loopCondition, loopStep, result); + } + + private EvalFold( + long exprId, + String accuVar, + PlannedInterpretable accuInit, + String iterVar, + String iterVar2, + PlannedInterpretable iterRange, + PlannedInterpretable condition, + PlannedInterpretable loopStep, + PlannedInterpretable result) { + super(exprId); + this.accuVar = accuVar; + this.accuInit = accuInit; + this.iterVar = iterVar; + this.iterVar2 = iterVar2; + this.iterRange = iterRange; + this.condition = condition; + this.loopStep = loopStep; + this.result = result; + } + + @Override + public Object eval(GlobalResolver resolver) throws CelEvaluationException { + Object iterRangeRaw = iterRange.eval(resolver); + Folder folder = new Folder(resolver, accuVar, iterVar, iterVar2); + folder.accuVal = maybeWrapAccumulator(accuInit.eval(folder)); + + Object result; + if (iterRangeRaw instanceof Map) { + result = evalMap((Map) iterRangeRaw, folder); + } else if (iterRangeRaw instanceof Collection) { + result = evalList((Collection) iterRangeRaw, folder); + } else { + throw new IllegalArgumentException("Unexpected iter_range type: " + iterRangeRaw.getClass()); + } + + return maybeUnwrapAccumulator(result); + } + + @Override + public Object eval(GlobalResolver resolver, CelEvaluationListener listener) { + // TODO: Implement support + throw new UnsupportedOperationException("Not yet supported"); + } + + @Override + public Object eval(GlobalResolver resolver, CelFunctionResolver lateBoundFunctionResolver) { + // TODO: Implement support + throw new UnsupportedOperationException("Not yet supported"); + } + + @Override + public Object eval( + GlobalResolver resolver, + CelFunctionResolver lateBoundFunctionResolver, + CelEvaluationListener listener) { + // TODO: Implement support + throw new UnsupportedOperationException("Not yet supported"); + } + + private Object evalMap(Map iterRange, Folder folder) throws CelEvaluationException { + for (Map.Entry entry : iterRange.entrySet()) { + folder.iterVarVal = entry.getKey(); + if (!iterVar2.isEmpty()) { + folder.iterVar2Val = entry.getValue(); + } + + boolean cond = (boolean) condition.eval(folder); + if (!cond) { + return result.eval(folder); + } + + // TODO: Introduce comprehension safety controls, such as iteration limit. + folder.accuVal = loopStep.eval(folder); + } + return result.eval(folder); + } + + private Object evalList(Collection iterRange, Folder folder) throws CelEvaluationException { + int index = 0; + for (Object item : iterRange) { + if (iterVar2.isEmpty()) { + folder.iterVarVal = item; + } else { + folder.iterVarVal = (long) index; + folder.iterVar2Val = item; + } + + boolean cond = (boolean) condition.eval(folder); + if (!cond) { + return result.eval(folder); + } + + folder.accuVal = loopStep.eval(folder); + index++; + } + return result.eval(folder); + } + + private static Object maybeWrapAccumulator(Object val) { + if (val instanceof Collection) { + return new ConcatenatedListView<>((Collection) val); + } + // TODO: Introduce mutable map support (for comp v2) + return val; + } + + private static Object maybeUnwrapAccumulator(Object val) { + if (val instanceof ConcatenatedListView) { + return ImmutableList.copyOf((ConcatenatedListView) val); + } + + // TODO: Introduce mutable map support (for comp v2) + return val; + } + + private static class Folder implements GlobalResolver { + private final GlobalResolver resolver; + private final String accuVar; + private final String iterVar; + private final String iterVar2; + + private Object iterVarVal; + private Object iterVar2Val; + private Object accuVal; + + private Folder(GlobalResolver resolver, String accuVar, String iterVar, String iterVar2) { + this.resolver = resolver; + this.accuVar = accuVar; + this.iterVar = iterVar; + this.iterVar2 = iterVar2; + } + + @Override + public @Nullable Object resolve(String name) { + if (name.equals(accuVar)) { + return accuVal; + } + + if (name.equals(iterVar)) { + return this.iterVarVal; + } + + if (name.equals(iterVar2)) { + return this.iterVar2Val; + } + + return resolver.resolve(name); + } + } +} diff --git a/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java b/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java index 13d0d10ce..c751ba88c 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java @@ -26,6 +26,7 @@ import dev.cel.common.ast.CelConstant; import dev.cel.common.ast.CelExpr; import dev.cel.common.ast.CelExpr.CelCall; +import dev.cel.common.ast.CelExpr.CelComprehension; import dev.cel.common.ast.CelExpr.CelList; import dev.cel.common.ast.CelExpr.CelMap; import dev.cel.common.ast.CelExpr.CelSelect; @@ -94,10 +95,12 @@ private PlannedInterpretable plan(CelExpr celExpr, PlannerContext ctx) { return planCreateStruct(celExpr, ctx); case MAP: return planCreateMap(celExpr, ctx); + case COMPREHENSION: + return planComprehension(celExpr, ctx); case NOT_SET: throw new UnsupportedOperationException("Unsupported kind: " + celExpr.getKind()); default: - throw new IllegalArgumentException("Not yet implemented kind: " + celExpr.getKind()); + throw new UnsupportedOperationException("Unexpected kind: " + celExpr.getKind()); } } @@ -280,6 +283,27 @@ private PlannedInterpretable planCreateMap(CelExpr celExpr, PlannerContext ctx) return EvalCreateMap.create(celExpr.id(), keys, values); } + private PlannedInterpretable planComprehension(CelExpr expr, PlannerContext ctx) { + CelComprehension comprehension = expr.comprehension(); + + PlannedInterpretable accuInit = plan(comprehension.accuInit(), ctx); + PlannedInterpretable iterRange = plan(comprehension.iterRange(), ctx); + PlannedInterpretable loopCondition = plan(comprehension.loopCondition(), ctx); + PlannedInterpretable loopStep = plan(comprehension.loopStep(), ctx); + PlannedInterpretable result = plan(comprehension.result(), ctx); + + return EvalFold.create( + expr.id(), + comprehension.accuVar(), + accuInit, + comprehension.iterVar(), + comprehension.iterVar2(), + iterRange, + loopCondition, + loopStep, + result); + } + /** * resolveFunction determines the call target, function name, and overload name (when unambiguous) * from the given call expr. diff --git a/runtime/src/test/java/dev/cel/runtime/planner/ProgramPlannerTest.java b/runtime/src/test/java/dev/cel/runtime/planner/ProgramPlannerTest.java index 871b5aa95..642f3e5ca 100644 --- a/runtime/src/test/java/dev/cel/runtime/planner/ProgramPlannerTest.java +++ b/runtime/src/test/java/dev/cel/runtime/planner/ProgramPlannerTest.java @@ -144,7 +144,7 @@ public final class ProgramPlannerTest { newMemberOverload( "bytes_concat_bytes", SimpleType.BYTES, SimpleType.BYTES, SimpleType.BYTES))) .addMessageTypes(TestAllTypes.getDescriptor()) - .addLibraries(CelExtensions.optional()) + .addLibraries(CelExtensions.optional(), CelExtensions.comprehensions()) .setContainer(CEL_CONTAINER) .build(); @@ -178,10 +178,7 @@ private static DefaultDispatcher newDispatcher() { builder, Operator.NOT_STRICTLY_FALSE.getFunction(), fromStandardFunction(NotStrictlyFalseFunction.create())); - addBindings( - builder, - "dyn", - fromStandardFunction(DynFunction.create())); + addBindings(builder, "dyn", fromStandardFunction(DynFunction.create())); // Custom functions addBindings( @@ -663,7 +660,7 @@ public void plan_select_safeTraversal() throws Exception { CelAbstractSyntaxTree ast = compile("msg.single_nested_message.bb"); Program program = PLANNER.plan(ast); - Object result = program.eval(ImmutableMap.of("msg", TestAllTypes.newBuilder().build())); + Object result = program.eval(ImmutableMap.of("msg", TestAllTypes.getDefaultInstance())); assertThat(result).isEqualTo(0L); } @@ -775,6 +772,35 @@ public void plan_select_badPresenceTest_throws() throws Exception { + " maps."); } + @Test + @TestParameters("{expression: '[1,2,3].exists(x, x > 0) == true'}") + @TestParameters("{expression: '[1,2,3].exists(x, x < 0) == false'}") + @TestParameters("{expression: '[1,2,3].exists(i, v, i >= 0 && v > 0) == true'}") + @TestParameters("{expression: '[1,2,3].exists(i, v, i < 0 || v < 0) == false'}") + @TestParameters("{expression: '[1,2,3].map(x, x + 1) == [2,3,4]'}") + public void plan_comprehension_lists(String expression) throws Exception { + CelAbstractSyntaxTree ast = compile(expression); + Program program = PLANNER.plan(ast); + + boolean result = (boolean) program.eval(); + + assertThat(result).isTrue(); + } + + @Test + @TestParameters("{expression: '{\"a\": 1, \"b\": 2}.exists(k, k == \"a\")'}") + @TestParameters("{expression: '{\"a\": 1, \"b\": 2}.exists(k, k == \"c\") == false'}") + @TestParameters("{expression: '{\"a\": \"b\", \"c\": \"c\"}.exists(k, v, k == v)'}") + @TestParameters("{expression: '{\"a\": 1, \"b\": 2}.exists(k, v, v == 3) == false'}") + public void plan_comprehension_maps(String expression) throws Exception { + CelAbstractSyntaxTree ast = compile(expression); + Program program = PLANNER.plan(ast); + + boolean result = (boolean) program.eval(); + + assertThat(result).isTrue(); + } + private CelAbstractSyntaxTree compile(String expression) throws Exception { CelAbstractSyntaxTree ast = CEL_COMPILER.parse(expression).getAst(); if (isParseOnly) { @@ -848,12 +874,11 @@ private enum TypeLiteralTestCase { } } - @SuppressWarnings("Immutable") // Test only private enum PresenceTestCase { PROTO_FIELD_PRESENT( "has(msg.single_string)", TestAllTypes.newBuilder().setSingleString("foo").build(), true), - PROTO_FIELD_ABSENT("has(msg.single_string)", TestAllTypes.newBuilder().build(), false), + PROTO_FIELD_ABSENT("has(msg.single_string)", TestAllTypes.getDefaultInstance(), false), PROTO_NESTED_FIELD_PRESENT( "has(msg.single_nested_message.bb)", TestAllTypes.newBuilder() @@ -861,7 +886,7 @@ private enum PresenceTestCase { .build(), true), PROTO_NESTED_FIELD_ABSENT( - "has(msg.single_nested_message.bb)", TestAllTypes.newBuilder().build(), false), + "has(msg.single_nested_message.bb)", TestAllTypes.getDefaultInstance(), false), PROTO_MAP_KEY_PRESENT("has(map_var.foo)", ImmutableMap.of("foo", "1"), true), PROTO_MAP_KEY_ABSENT("has(map_var.bar)", ImmutableMap.of(), false);