diff --git a/extensions/src/main/java/dev/cel/extensions/SetsExtensionsRuntimeImpl.java b/extensions/src/main/java/dev/cel/extensions/SetsExtensionsRuntimeImpl.java index a42fba189..a02fdba8a 100644 --- a/extensions/src/main/java/dev/cel/extensions/SetsExtensionsRuntimeImpl.java +++ b/extensions/src/main/java/dev/cel/extensions/SetsExtensionsRuntimeImpl.java @@ -45,28 +45,34 @@ ImmutableSet newFunctionBindings() { for (SetsFunction function : functions) { switch (function) { case CONTAINS: - bindingBuilder.add( - CelFunctionBinding.from( - "list_sets_contains_list", - Collection.class, - Collection.class, - this::containsAll)); + bindingBuilder.addAll( + CelFunctionBinding.fromOverloads( + function.getFunction(), + CelFunctionBinding.from( + "list_sets_contains_list", + Collection.class, + Collection.class, + this::containsAll))); break; case EQUIVALENT: - bindingBuilder.add( - CelFunctionBinding.from( - "list_sets_equivalent_list", - Collection.class, - Collection.class, - (listA, listB) -> containsAll(listA, listB) && containsAll(listB, listA))); + bindingBuilder.addAll( + CelFunctionBinding.fromOverloads( + function.getFunction(), + CelFunctionBinding.from( + "list_sets_equivalent_list", + Collection.class, + Collection.class, + (listA, listB) -> containsAll(listA, listB) && containsAll(listB, listA)))); break; case INTERSECTS: - bindingBuilder.add( - CelFunctionBinding.from( - "list_sets_intersects_list", - Collection.class, - Collection.class, - this::setIntersects)); + bindingBuilder.addAll( + CelFunctionBinding.fromOverloads( + function.getFunction(), + CelFunctionBinding.from( + "list_sets_intersects_list", + Collection.class, + Collection.class, + this::setIntersects))); break; } } diff --git a/extensions/src/test/java/dev/cel/extensions/CelSetsExtensionsTest.java b/extensions/src/test/java/dev/cel/extensions/CelSetsExtensionsTest.java index 1aac5a023..9007bba2e 100644 --- a/extensions/src/test/java/dev/cel/extensions/CelSetsExtensionsTest.java +++ b/extensions/src/test/java/dev/cel/extensions/CelSetsExtensionsTest.java @@ -19,8 +19,11 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.testing.junit.testparameterinjector.TestParameter; import com.google.testing.junit.testparameterinjector.TestParameterInjector; import com.google.testing.junit.testparameterinjector.TestParameters; +import dev.cel.bundle.Cel; +import dev.cel.bundle.CelBuilder; import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelContainer; import dev.cel.common.CelFunctionDecl; @@ -30,47 +33,34 @@ import dev.cel.common.CelValidationResult; import dev.cel.common.types.ListType; import dev.cel.common.types.SimpleType; -import dev.cel.compiler.CelCompiler; -import dev.cel.compiler.CelCompilerFactory; import dev.cel.expr.conformance.proto3.TestAllTypes; import dev.cel.runtime.CelEvaluationException; import dev.cel.runtime.CelFunctionBinding; import dev.cel.runtime.CelRuntime; -import dev.cel.runtime.CelRuntimeFactory; +import dev.cel.testing.CelRuntimeFlavor; import java.util.List; +import java.util.Map; +import org.junit.Assume; +import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @RunWith(TestParameterInjector.class) public final class CelSetsExtensionsTest { - private static final CelOptions CEL_OPTIONS = CelOptions.current().build(); - private static final CelCompiler COMPILER = - CelCompilerFactory.standardCelCompilerBuilder() - .addMessageTypes(TestAllTypes.getDescriptor()) - .setOptions(CEL_OPTIONS) - .setContainer(CelContainer.ofName("cel.expr.conformance.proto3")) - .addLibraries(CelExtensions.sets(CEL_OPTIONS)) - .addVar("list", ListType.create(SimpleType.INT)) - .addVar("subList", ListType.create(SimpleType.INT)) - .addFunctionDeclarations( - CelFunctionDecl.newFunctionDeclaration( - "new_int", - CelOverloadDecl.newGlobalOverload( - "new_int_int64", SimpleType.INT, SimpleType.INT))) - .build(); - - private static final CelRuntime RUNTIME = - CelRuntimeFactory.standardCelRuntimeBuilder() - .addMessageTypes(TestAllTypes.getDescriptor()) - .addLibraries(CelExtensions.sets(CEL_OPTIONS)) - .setOptions(CEL_OPTIONS) - .addFunctionBindings( - CelFunctionBinding.from( - "new_int_int64", - Long.class, - // Intentionally return java.lang.Integer to test primitive type adaptation - Math::toIntExact)) - .build(); + private static final CelOptions CEL_OPTIONS = + CelOptions.current().enableHeterogeneousNumericComparisons(true).build(); + + @TestParameter public CelRuntimeFlavor runtimeFlavor; + @TestParameter public boolean isParseOnly; + + private Cel cel; + + @Before + public void setUp() { + // Legacy runtime does not support parsed-only evaluation mode. + Assume.assumeFalse(runtimeFlavor.equals(CelRuntimeFlavor.LEGACY) && isParseOnly); + this.cel = setupEnv(runtimeFlavor.builder()); + } @Test public void library() { @@ -87,22 +77,14 @@ public void library() { public void contains_integerListWithSameValue_succeeds() throws Exception { ImmutableList list = ImmutableList.of(1, 2, 3, 4); ImmutableList subList = ImmutableList.of(1, 2, 3, 4); - CelAbstractSyntaxTree ast = COMPILER.compile("sets.contains(list, subList)").getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - Object result = program.eval(ImmutableMap.of("list", list, "subList", subList)); - - assertThat(result).isEqualTo(true); + assertThat( + eval("sets.contains(list, subList)", ImmutableMap.of("list", list, "subList", subList))) + .isEqualTo(true); } @Test public void contains_integerListAsExpression_succeeds() throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile("sets.contains([1, 1], [1])").getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - Object result = program.eval(); - - assertThat(result).isEqualTo(true); + assertThat(eval("sets.contains([1, 1], [1])")).isEqualTo(true); } @Test @@ -119,12 +101,7 @@ public void contains_integerListAsExpression_succeeds() throws Exception { + " [TestAllTypes{single_int64: 2, single_uint64: 3u}])', expected: false}") public void contains_withProtoMessage_succeeds(String expression, boolean expected) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile(expression).getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - boolean result = (boolean) program.eval(); - - assertThat(result).isEqualTo(expected); + assertThat(eval(expression)).isEqualTo(expected); } @Test @@ -133,12 +110,7 @@ public void contains_withProtoMessage_succeeds(String expression, boolean expect @TestParameters("{expression: 'sets.contains([new_int(2)], [1])', expected: false}") public void contains_withFunctionReturningInteger_succeeds(String expression, boolean expected) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile(expression).getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - boolean result = (boolean) program.eval(); - - assertThat(result).isEqualTo(expected); + assertThat(eval(expression)).isEqualTo(expected); } @Test @@ -157,12 +129,9 @@ public void contains_withFunctionReturningInteger_succeeds(String expression, bo @TestParameters("{list: [1], subList: [1, 2], expected: false}") public void contains_withIntTypes_succeeds( List list, List subList, boolean expected) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile("sets.contains(list, subList)").getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - Object result = program.eval(ImmutableMap.of("list", list, "subList", subList)); - - assertThat(result).isEqualTo(expected); + assertThat( + eval("sets.contains(list, subList)", ImmutableMap.of("list", list, "subList", subList))) + .isEqualTo(expected); } @Test @@ -177,12 +146,9 @@ public void contains_withIntTypes_succeeds( @TestParameters("{list: [2, 3.0], subList: [2, 3], expected: true}") public void contains_withDoubleTypes_succeeds( List list, List subList, boolean expected) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile("sets.contains(list, subList)").getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - Object result = program.eval(ImmutableMap.of("list", list, "subList", subList)); - - assertThat(result).isEqualTo(expected); + assertThat( + eval("sets.contains(list, subList)", ImmutableMap.of("list", list, "subList", subList))) + .isEqualTo(expected); } @Test @@ -193,12 +159,7 @@ public void contains_withDoubleTypes_succeeds( @TestParameters("{expression: 'sets.contains([[1], [2, 3.0]], [[2, 3]])', expected: true}") public void contains_withNestedLists_succeeds(String expression, boolean expected) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile(expression).getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - Object result = program.eval(); - - assertThat(result).isEqualTo(expected); + assertThat(eval(expression)).isEqualTo(expected); } @Test @@ -206,19 +167,16 @@ public void contains_withNestedLists_succeeds(String expression, boolean expecte @TestParameters("{expression: 'sets.contains([1], [1, \"1\"])', expected: false}") public void contains_withMixingIntAndString_succeeds(String expression, boolean expected) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile(expression).getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - Object result = program.eval(); - - assertThat(result).isEqualTo(expected); + assertThat(eval(expression)).isEqualTo(expected); } @Test - @TestParameters("{expression: 'sets.contains([1], [\"1\"])'}") - @TestParameters("{expression: 'sets.contains([\"1\"], [1])'}") - public void contains_withMixingIntAndString_throwsException(String expression) throws Exception { - CelValidationResult invalidData = COMPILER.compile(expression); + public void contains_withMixingIntAndString_throwsException( + @TestParameter({"sets.contains([1], [\"1\"])", "sets.contains([\"1\"], [1])"}) + String expression) + throws Exception { + Assume.assumeFalse(isParseOnly); + CelValidationResult invalidData = cel.compile(expression); assertThat(invalidData.getErrors()).hasSize(1); assertThat(invalidData.getErrors().get(0).getMessage()) @@ -227,12 +185,7 @@ public void contains_withMixingIntAndString_throwsException(String expression) t @Test public void contains_withMixedValues_succeeds() throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile("sets.contains([1, 2], [2u, 2.0])").getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - Object result = program.eval(); - - assertThat(result).isEqualTo(true); + assertThat(eval("sets.contains([1, 2], [2u, 2.0])")).isEqualTo(true); } @Test @@ -249,12 +202,7 @@ public void contains_withMixedValues_succeeds() throws Exception { "{expression: 'sets.contains([[[[[[5]]]]]], [[1], [2, 3.0], [[[[[5]]]]]])', expected: false}") public void contains_withMultiLevelNestedList_succeeds(String expression, boolean expected) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile(expression).getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - Object result = program.eval(); - - assertThat(result).isEqualTo(expected); + assertThat(eval(expression)).isEqualTo(expected); } @Test @@ -269,12 +217,7 @@ public void contains_withMultiLevelNestedList_succeeds(String expression, boolea + " false}") public void contains_withMapValues_succeeds(String expression, boolean expected) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile(expression).getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - boolean result = (boolean) program.eval(); - - assertThat(result).isEqualTo(expected); + assertThat(eval(expression)).isEqualTo(expected); } @Test @@ -289,12 +232,7 @@ public void contains_withMapValues_succeeds(String expression, boolean expected) @TestParameters("{expression: 'sets.equivalent([1, 2], [2, 2, 2])', expected: false}") public void equivalent_withIntTypes_succeeds(String expression, boolean expected) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile(expression).getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - Object result = program.eval(); - - assertThat(result).isEqualTo(expected); + assertThat(eval(expression)).isEqualTo(expected); } @Test @@ -308,12 +246,7 @@ public void equivalent_withIntTypes_succeeds(String expression, boolean expected @TestParameters("{expression: 'sets.equivalent([1, 2], [1u, 2, 2.3])', expected: false}") public void equivalent_withMixedTypes_succeeds(String expression, boolean expected) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile(expression).getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - Object result = program.eval(); - - assertThat(result).isEqualTo(expected); + assertThat(eval(expression)).isEqualTo(expected); } @Test @@ -338,12 +271,7 @@ public void equivalent_withMixedTypes_succeeds(String expression, boolean expect + " [TestAllTypes{single_int64: 2, single_uint64: 3u}])', expected: false}") public void equivalent_withProtoMessage_succeeds(String expression, boolean expected) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile(expression).getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - boolean result = (boolean) program.eval(); - - assertThat(result).isEqualTo(expected); + assertThat(eval(expression)).isEqualTo(expected); } @Test @@ -361,12 +289,7 @@ public void equivalent_withProtoMessage_succeeds(String expression, boolean expe + " expected: false}") public void equivalent_withMapValues_succeeds(String expression, boolean expected) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile(expression).getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - boolean result = (boolean) program.eval(); - - assertThat(result).isEqualTo(expected); + assertThat(eval(expression)).isEqualTo(expected); } @Test @@ -391,12 +314,7 @@ public void equivalent_withMapValues_succeeds(String expression, boolean expecte @TestParameters("{expression: 'sets.intersects([1], [1.1, 2u])', expected: false}") public void intersects_withMixedTypes_succeeds(String expression, boolean expected) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile(expression).getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - Object result = program.eval(); - - assertThat(result).isEqualTo(expected); + assertThat(eval(expression)).isEqualTo(expected); } @Test @@ -414,12 +332,11 @@ public void intersects_withMixedTypes_succeeds(String expression, boolean expect @TestParameters("{expression: 'sets.intersects([{2: 1}], [{1: 1}])', expected: false}") public void intersects_withMapValues_succeeds(String expression, boolean expected) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile(expression).getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - boolean result = (boolean) program.eval(); + // The LEGACY runtime is not spec compliant, because decimal keys are not allowed for maps. + Assume.assumeFalse( + runtimeFlavor.equals(CelRuntimeFlavor.PLANNER) && expression.contains("1.0:")); - assertThat(result).isEqualTo(expected); + assertThat(eval(expression)).isEqualTo(expected); } @Test @@ -444,25 +361,21 @@ public void intersects_withMapValues_succeeds(String expression, boolean expecte + " [TestAllTypes{single_int64: 2, single_uint64: 3u}])', expected: false}") public void intersects_withProtoMessage_succeeds(String expression, boolean expected) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile(expression).getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - boolean result = (boolean) program.eval(); - - assertThat(result).isEqualTo(expected); + assertThat(eval(expression)).isEqualTo(expected); } @Test public void setsExtension_containsFunctionSubset_succeeds() throws Exception { CelSetsExtensions setsExtensions = CelExtensions.sets(CelOptions.DEFAULT, SetsFunction.CONTAINS); - CelCompiler celCompiler = - CelCompilerFactory.standardCelCompilerBuilder().addLibraries(setsExtensions).build(); - CelRuntime celRuntime = - CelRuntimeFactory.standardCelRuntimeBuilder().addLibraries(setsExtensions).build(); + Cel cel = + runtimeFlavor + .builder() + .addCompilerLibraries(setsExtensions) + .addRuntimeLibraries(setsExtensions) + .build(); - Object evaluatedResult = - celRuntime.createProgram(celCompiler.compile("sets.contains([1, 2], [2])").getAst()).eval(); + Object evaluatedResult = eval(cel, "sets.contains([1, 2], [2])", ImmutableMap.of()); assertThat(evaluatedResult).isEqualTo(true); } @@ -471,15 +384,14 @@ public void setsExtension_containsFunctionSubset_succeeds() throws Exception { public void setsExtension_equivalentFunctionSubset_succeeds() throws Exception { CelSetsExtensions setsExtensions = CelExtensions.sets(CelOptions.DEFAULT, SetsFunction.EQUIVALENT); - CelCompiler celCompiler = - CelCompilerFactory.standardCelCompilerBuilder().addLibraries(setsExtensions).build(); - CelRuntime celRuntime = - CelRuntimeFactory.standardCelRuntimeBuilder().addLibraries(setsExtensions).build(); + Cel cel = + runtimeFlavor + .builder() + .addCompilerLibraries(setsExtensions) + .addRuntimeLibraries(setsExtensions) + .build(); - Object evaluatedResult = - celRuntime - .createProgram(celCompiler.compile("sets.equivalent([1, 1], [1])").getAst()) - .eval(); + Object evaluatedResult = eval(cel, "sets.equivalent([1, 1], [1])", ImmutableMap.of()); assertThat(evaluatedResult).isEqualTo(true); } @@ -488,44 +400,95 @@ public void setsExtension_equivalentFunctionSubset_succeeds() throws Exception { public void setsExtension_intersectsFunctionSubset_succeeds() throws Exception { CelSetsExtensions setsExtensions = CelExtensions.sets(CelOptions.DEFAULT, SetsFunction.INTERSECTS); - CelCompiler celCompiler = - CelCompilerFactory.standardCelCompilerBuilder().addLibraries(setsExtensions).build(); - CelRuntime celRuntime = - CelRuntimeFactory.standardCelRuntimeBuilder().addLibraries(setsExtensions).build(); + Cel cel = + runtimeFlavor + .builder() + .addCompilerLibraries(setsExtensions) + .addRuntimeLibraries(setsExtensions) + .build(); - Object evaluatedResult = - celRuntime - .createProgram(celCompiler.compile("sets.intersects([1, 1], [1])").getAst()) - .eval(); + Object evaluatedResult = eval(cel, "sets.intersects([1, 1], [1])", ImmutableMap.of()); assertThat(evaluatedResult).isEqualTo(true); } @Test public void setsExtension_compileUnallowedFunction_throws() { + Assume.assumeFalse(isParseOnly); CelSetsExtensions setsExtensions = CelExtensions.sets(CelOptions.DEFAULT, SetsFunction.EQUIVALENT); - CelCompiler celCompiler = - CelCompilerFactory.standardCelCompilerBuilder().addLibraries(setsExtensions).build(); + Cel cel = runtimeFlavor.builder().addCompilerLibraries(setsExtensions).build(); assertThrows( - CelValidationException.class, - () -> celCompiler.compile("sets.contains([1, 2], [2])").getAst()); + CelValidationException.class, () -> cel.compile("sets.contains([1, 2], [2])").getAst()); } @Test public void setsExtension_evaluateUnallowedFunction_throws() throws Exception { CelSetsExtensions setsExtensions = CelExtensions.sets(CelOptions.DEFAULT, SetsFunction.CONTAINS, SetsFunction.EQUIVALENT); - CelCompiler celCompiler = - CelCompilerFactory.standardCelCompilerBuilder().addLibraries(setsExtensions).build(); - CelRuntime celRuntime = - CelRuntimeFactory.standardCelRuntimeBuilder() - .addLibraries(CelExtensions.sets(CelOptions.DEFAULT, SetsFunction.EQUIVALENT)) + CelSetsExtensions runtimeLibrary = + CelExtensions.sets(CelOptions.DEFAULT, SetsFunction.EQUIVALENT); + Cel cel = + runtimeFlavor + .builder() + .addCompilerLibraries(setsExtensions) + .addRuntimeLibraries(runtimeLibrary) .build(); - CelAbstractSyntaxTree ast = celCompiler.compile("sets.contains([1, 2], [2])").getAst(); + CelAbstractSyntaxTree ast = + isParseOnly + ? cel.parse("sets.contains([1, 2], [2])").getAst() + : cel.compile("sets.contains([1, 2], [2])").getAst(); + + if (runtimeFlavor.equals(CelRuntimeFlavor.PLANNER) && !isParseOnly) { + // Fails at plan time + assertThrows(CelEvaluationException.class, () -> cel.createProgram(ast)); + } else { + CelRuntime.Program program = cel.createProgram(ast); + assertThrows(CelEvaluationException.class, () -> program.eval()); + } + } + + private Object eval(Cel cel, String expression, Map variables) throws Exception { + CelAbstractSyntaxTree ast; + if (isParseOnly) { + ast = cel.parse(expression).getAst(); + } else { + ast = cel.compile(expression).getAst(); + } + return cel.createProgram(ast).eval(variables); + } + + private Object eval(String expression) throws Exception { + return eval(this.cel, expression, ImmutableMap.of()); + } + + private Object eval(String expression, Map variables) throws Exception { + return eval(this.cel, expression, variables); + } - assertThrows(CelEvaluationException.class, () -> celRuntime.createProgram(ast).eval()); + private static Cel setupEnv(CelBuilder celBuilder) { + return celBuilder + .addMessageTypes(TestAllTypes.getDescriptor()) + .setOptions(CEL_OPTIONS) + .setContainer(CelContainer.ofName("cel.expr.conformance.proto3")) + .addCompilerLibraries(CelExtensions.sets(CEL_OPTIONS)) + .addRuntimeLibraries(CelExtensions.sets(CEL_OPTIONS)) + .addVar("list", ListType.create(SimpleType.INT)) + .addVar("subList", ListType.create(SimpleType.INT)) + .addFunctionDeclarations( + CelFunctionDecl.newFunctionDeclaration( + "new_int", + CelOverloadDecl.newGlobalOverload("new_int_int64", SimpleType.INT, SimpleType.INT))) + .addFunctionBindings( + CelFunctionBinding.fromOverloads( + "new_int", + CelFunctionBinding.from( + "new_int_int64", + Long.class, + // Intentionally return java.lang.Integer to test primitive type adaptation + Math::toIntExact))) + .build(); } } diff --git a/runtime/src/test/java/dev/cel/runtime/CelLiteRuntimeAndroidTest.java b/runtime/src/test/java/dev/cel/runtime/CelLiteRuntimeAndroidTest.java index 54ce24417..73492d126 100644 --- a/runtime/src/test/java/dev/cel/runtime/CelLiteRuntimeAndroidTest.java +++ b/runtime/src/test/java/dev/cel/runtime/CelLiteRuntimeAndroidTest.java @@ -149,9 +149,10 @@ public void toRuntimeBuilder_propertiesCopied() { assertThat(newRuntimeBuilder.standardFunctionBuilder.build()) .containsExactly(intFunction, equalsOperator) .inOrder(); - assertThat(newRuntimeBuilder.customFunctionBindings).hasSize(2); + assertThat(newRuntimeBuilder.customFunctionBindings).hasSize(3); assertThat(newRuntimeBuilder.customFunctionBindings).containsKey("string_isEmpty"); assertThat(newRuntimeBuilder.customFunctionBindings).containsKey("list_sets_intersects_list"); + assertThat(newRuntimeBuilder.customFunctionBindings).containsKey("sets.intersects"); } @Test