diff --git a/core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java b/core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java index 2225a9f8312..b1a43f2f194 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java @@ -338,7 +338,7 @@ public RexNode visitFunction(Function node, CalcitePlanContext context) { List arguments = node.getFuncArgs().stream().map(arg -> analyze(arg, context)).toList(); RexNode resolvedNode = - PPLFuncImpTable.INSTANCE.resolveSafe( + PPLFuncImpTable.INSTANCE.resolve( context.rexBuilder, node.getFuncName(), arguments.toArray(new RexNode[0])); if (resolvedNode != null) { return resolvedNode; diff --git a/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java b/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java index 40a42b48b0d..f5b42e9afdc 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java @@ -11,6 +11,7 @@ import com.google.common.collect.ImmutableMap; import java.math.BigDecimal; +import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.List; @@ -19,13 +20,12 @@ import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexNode; -import org.apache.calcite.runtime.PairList; import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.fun.SqlLibraryOperators; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.fun.SqlTrimFunction.Flag; import org.apache.calcite.sql.type.SqlTypeName; -import org.checkerframework.checker.nullness.qual.Nullable; +import org.apache.commons.lang3.tuple.Pair; import org.opensearch.sql.executor.QueryType; public class PPLFuncImpTable { @@ -91,21 +91,44 @@ 
default List getParams() { INSTANCE = new PPLFuncImpTable(builder); } - private final ImmutableMap> map; + /** + * The registry for built-in functions. Functions defined by the PPL specification, whose + * implementations are independent of any specific data storage, should be registered here + * internally. + */ + private final ImmutableMap>> + functionRegistry; + + /** + * The external function registry. Functions whose implementations depend on a specific data + * engine should be registered here. This reduces coupling between the core module and particular + * storage backends. + */ + private final Map>> + externalFunctionRegistry; private PPLFuncImpTable(Builder builder) { - final ImmutableMap.Builder> + final ImmutableMap.Builder>> mapBuilder = ImmutableMap.builder(); - builder.map.forEach((k, v) -> mapBuilder.put(k, v.immutable())); - this.map = ImmutableMap.copyOf(mapBuilder.build()); + builder.map.forEach((k, v) -> mapBuilder.put(k, List.copyOf(v))); + this.functionRegistry = ImmutableMap.copyOf(mapBuilder.build()); + this.externalFunctionRegistry = new HashMap<>(); } - public @Nullable RexNode resolveSafe( - final RexBuilder builder, final String functionName, RexNode... args) { - try { - return resolve(builder, functionName, args); - } catch (Exception e) { - return null; + /** + * Register a function implementation from external services dynamically. 
+ * + * @param functionName the name of the function, has to be defined in BuiltinFunctionName + * @param functionImp the implementation of the function + */ + public void registerExternalFunction(BuiltinFunctionName functionName, FunctionImp functionImp) { + CalciteFuncSignature signature = + new CalciteFuncSignature(functionName.getName(), functionImp.getParams()); + if (externalFunctionRegistry.containsKey(functionName)) { + externalFunctionRegistry.get(functionName).add(Pair.of(signature, functionImp)); + } else { + externalFunctionRegistry.put( + functionName, new ArrayList<>(List.of(Pair.of(signature, functionImp)))); } } @@ -119,7 +142,14 @@ public RexNode resolve(final RexBuilder builder, final String functionName, RexN public RexNode resolve( final RexBuilder builder, final BuiltinFunctionName functionName, RexNode... args) { - final PairList implementList = map.get(functionName); + // Check the external function registry first. This allows the data-storage-dependent + // function implementations to override the internal ones with the same name. + List> implementList = + externalFunctionRegistry.get(functionName); + // If the function is not part of the external registry, check the internal registry. 
+ if (implementList == null) { + implementList = functionRegistry.get(functionName); + } if (implementList == null || implementList.isEmpty()) { throw new IllegalStateException(String.format("Cannot resolve function: %s", functionName)); } @@ -401,7 +431,7 @@ void populate() { } private static class Builder extends AbstractBuilder { - private final Map> map = + private final Map>> map = new HashMap<>(); @Override @@ -409,9 +439,9 @@ void register(BuiltinFunctionName functionName, FunctionImp implement) { CalciteFuncSignature signature = new CalciteFuncSignature(functionName.getName(), implement.getParams()); if (map.containsKey(functionName)) { - map.get(functionName).add(signature, implement); + map.get(functionName).add(Pair.of(signature, implement)); } else { - map.put(functionName, PairList.of(signature, implement)); + map.put(functionName, new ArrayList<>(List.of(Pair.of(signature, implement)))); } } } diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteGeoIpFunctionsIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteGeoIpFunctionsIT.java index 247a8cc045a..20de0289d17 100644 --- a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteGeoIpFunctionsIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteGeoIpFunctionsIT.java @@ -12,7 +12,6 @@ public class CalciteGeoIpFunctionsIT extends GeoIpFunctionsIT { public void init() throws Exception { super.init(); enableCalcite(); - // TODO: "https://github.com/opensearch-project/sql/issues/3506" - // disallowCalciteFallback(); + disallowCalciteFallback(); } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/OpenSearchExecutionEngine.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/OpenSearchExecutionEngine.java index 177eaf88ee9..a0105e7ee32 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/OpenSearchExecutionEngine.java +++ 
b/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/OpenSearchExecutionEngine.java @@ -18,7 +18,6 @@ import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicReference; -import lombok.RequiredArgsConstructor; import org.apache.calcite.plan.RelOptUtil; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.RelRoot; @@ -38,14 +37,16 @@ import org.opensearch.sql.executor.ExecutionEngine.Schema.Column; import org.opensearch.sql.executor.Explain; import org.opensearch.sql.executor.pagination.PlanSerializer; +import org.opensearch.sql.expression.function.BuiltinFunctionName; +import org.opensearch.sql.expression.function.PPLFuncImpTable; import org.opensearch.sql.opensearch.client.OpenSearchClient; import org.opensearch.sql.opensearch.executor.protector.ExecutionProtector; +import org.opensearch.sql.opensearch.functions.GeoIpFunction; import org.opensearch.sql.opensearch.util.JdbcOpenSearchDataTypeConvertor; import org.opensearch.sql.planner.physical.PhysicalPlan; import org.opensearch.sql.storage.TableScanOperator; /** OpenSearch execution engine implementation. 
*/ -@RequiredArgsConstructor public class OpenSearchExecutionEngine implements ExecutionEngine { private final OpenSearchClient client; @@ -53,6 +54,16 @@ public class OpenSearchExecutionEngine implements ExecutionEngine { private final ExecutionProtector executionProtector; private final PlanSerializer planSerializer; + public OpenSearchExecutionEngine( + OpenSearchClient client, + ExecutionProtector executionProtector, + PlanSerializer planSerializer) { + this.client = client; + this.executionProtector = executionProtector; + this.planSerializer = planSerializer; + registerOpenSearchFunctions(); + } + @Override public void execute(PhysicalPlan physicalPlan, ResponseListener listener) { execute(physicalPlan, ExecutionContext.emptyExecutionContext(), listener); @@ -224,4 +235,12 @@ private void buildResultSet( QueryResponse response = new QueryResponse(schema, values, null); listener.onResponse(response); } + + /** Registers opensearch-dependent functions */ + private void registerOpenSearchFunctions() { + PPLFuncImpTable.FunctionImp geoIpImpl = + (builder, args) -> + builder.makeCall(new GeoIpFunction(client.getNodeClient()).toUDF("GEOIP"), args); + PPLFuncImpTable.INSTANCE.registerExternalFunction(BuiltinFunctionName.GEOIP, geoIpImpl); + } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/functions/GeoIpFunction.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/functions/GeoIpFunction.java new file mode 100644 index 00000000000..9d39c46c1d3 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/functions/GeoIpFunction.java @@ -0,0 +1,110 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.functions; + +import java.util.*; +import java.util.stream.Collectors; +import lombok.Getter; +import org.apache.calcite.adapter.enumerable.NotNullImplementor; +import org.apache.calcite.adapter.enumerable.NullPolicy; +import 
org.apache.calcite.adapter.enumerable.RexToLixTranslator; +import org.apache.calcite.linq4j.tree.Expression; +import org.apache.calcite.linq4j.tree.Expressions; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.sql.type.SqlReturnTypeInference; +import org.apache.calcite.sql.type.SqlTypeName; +import org.opensearch.geospatial.action.IpEnrichmentActionClient; +import org.opensearch.sql.common.utils.StringUtils; +import org.opensearch.sql.data.model.ExprStringValue; +import org.opensearch.sql.data.model.ExprTupleValue; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.expression.function.ImplementorUDF; +import org.opensearch.transport.client.node.NodeClient; + +/** + * {@code GEOIP(dataSourceName, ipAddress[, options])} looks up location information from given IP + * addresses via OpenSearch GeoSpatial plugin API. The options is a comma-separated list of fields + * to be returned. If not specified, all fields are returned. + * + *

<p>Signatures: + *
+ * <ul>
+ *   <li>(STRING, STRING) -> MAP
+ *   <li>(STRING, STRING, STRING) -> MAP
+ * </ul>
+ */ +public class GeoIpFunction extends ImplementorUDF { + public GeoIpFunction(NodeClient nodeClient) { + super(new GeoIPImplementor(nodeClient), NullPolicy.ANY); + } + + @Override + public SqlReturnTypeInference getReturnTypeInference() { + return op -> { + RelDataTypeFactory typeFactory = op.getTypeFactory(); + RelDataType varcharType = typeFactory.createSqlType(SqlTypeName.VARCHAR); + RelDataType anyType = typeFactory.createSqlType(SqlTypeName.ANY); + return typeFactory.createMapType(varcharType, anyType); + }; + } + + public static class GeoIPImplementor implements NotNullImplementor { + @Getter private static NodeClient nodeClient; + + public GeoIPImplementor(NodeClient nodeClient) { + GeoIPImplementor.nodeClient = nodeClient; + } + + @Override + public Expression implement( + RexToLixTranslator translator, RexCall call, List translatedOperands) { + if (getNodeClient() == null) { + throw new IllegalStateException("nodeClient is null."); + } + List operandsWithClient = new ArrayList<>(translatedOperands); + // Since a NodeClient cannot be passed as a parameter using Expressions.constant, + // it is instead provided through a function call. 
+ operandsWithClient.add(Expressions.call(GeoIPImplementor.class, "getNodeClient")); + return Expressions.call(GeoIPImplementor.class, "fetchIpEnrichment", operandsWithClient); + } + + public static Map fetchIpEnrichment( + String dataSource, String ipAddress, NodeClient nodeClient) { + return fetchIpEnrichment(dataSource, ipAddress, Collections.emptySet(), nodeClient); + } + + public static Map fetchIpEnrichment( + String dataSource, String ipAddress, String commaSeparatedOptions, NodeClient nodeClient) { + String unquotedOptions = StringUtils.unquoteText(commaSeparatedOptions); + final Set options = + Arrays.stream(unquotedOptions.split(",")).map(String::trim).collect(Collectors.toSet()); + return fetchIpEnrichment(dataSource, ipAddress, options, nodeClient); + } + + private static Map fetchIpEnrichment( + String dataSource, String ipAddress, Set options, NodeClient nodeClient) { + IpEnrichmentActionClient ipClient = new IpEnrichmentActionClient(nodeClient); + dataSource = StringUtils.unquoteText(dataSource); + try { + Map geoLocationData = ipClient.getGeoLocationData(ipAddress, dataSource); + Map enrichmentResult = + geoLocationData.entrySet().stream() + .filter(entry -> options.isEmpty() || options.contains(entry.getKey())) + .collect( + Collectors.toMap( + Map.Entry::getKey, v -> new ExprStringValue(v.getValue().toString()))); + @SuppressWarnings("unchecked") + Map result = + (Map) ExprTupleValue.fromExprValueMap(enrichmentResult).valueForCalcite(); + return result; + } catch (Exception e) { + throw new RuntimeException(e); + } + } + } +} diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index 679565e22f6..af821b6915c 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -233,7 +233,7 @@ public UnresolvedExpression 
visitTakeAggFunctionCall( /** Eval function. */ @Override public UnresolvedExpression visitBooleanFunctionCall(BooleanFunctionCallContext ctx) { - final String functionName = ctx.conditionFunctionName().getText().toLowerCase(); + final String functionName = ctx.conditionFunctionName().getText().toLowerCase(Locale.ROOT); return buildFunction( FUNCTION_NAME_MAPPING.getOrDefault(functionName, functionName), ctx.functionArgs().functionArg()); @@ -287,7 +287,7 @@ private Function buildFunction( public UnresolvedExpression visitSingleFieldRelevanceFunction( SingleFieldRelevanceFunctionContext ctx) { return new Function( - ctx.singleFieldRelevanceFunctionName().getText().toLowerCase(), + ctx.singleFieldRelevanceFunctionName().getText().toLowerCase(Locale.ROOT), singleFieldRelevanceArguments(ctx)); } @@ -295,7 +295,7 @@ public UnresolvedExpression visitSingleFieldRelevanceFunction( public UnresolvedExpression visitMultiFieldRelevanceFunction( MultiFieldRelevanceFunctionContext ctx) { return new Function( - ctx.multiFieldRelevanceFunctionName().getText().toLowerCase(), + ctx.multiFieldRelevanceFunctionName().getText().toLowerCase(Locale.ROOT), multiFieldRelevanceArguments(ctx)); } @@ -506,7 +506,7 @@ private List singleFieldRelevanceArguments( v -> builder.add( new UnresolvedArgument( - v.relevanceArgName().getText().toLowerCase(), + v.relevanceArgName().getText().toLowerCase(Locale.ROOT), new Literal( StringUtils.unquoteText(v.relevanceArgValue().getText()), DataType.STRING)))); @@ -534,7 +534,7 @@ private List multiFieldRelevanceArguments( v -> builder.add( new UnresolvedArgument( - v.relevanceArgName().getText().toLowerCase(), + v.relevanceArgName().getText().toLowerCase(Locale.ROOT), new Literal( StringUtils.unquoteText(v.relevanceArgValue().getText()), DataType.STRING)))); diff --git a/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java b/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java index 89edb0cfa27..346ef6660d7 
100644 --- a/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java +++ b/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java @@ -72,6 +72,7 @@ import java.util.Collections; import java.util.HashMap; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Optional; import java.util.stream.Collectors; @@ -152,7 +153,7 @@ public UnresolvedExpression visitHighlightFunctionCall(HighlightFunctionCallCont .forEach( v -> builder.put( - v.highlightArgName().getText().toLowerCase(), + v.highlightArgName().getText().toLowerCase(Locale.ROOT), new Literal( StringUtils.unquoteText(v.highlightArgValue().getText()), DataType.STRING))); @@ -416,14 +417,15 @@ public UnresolvedExpression visitPercentileApproxFunctionCall( @Override public UnresolvedExpression visitNoFieldRelevanceFunction(NoFieldRelevanceFunctionContext ctx) { return new Function( - ctx.noFieldRelevanceFunctionName().getText().toLowerCase(), noFieldRelevanceArguments(ctx)); + ctx.noFieldRelevanceFunctionName().getText().toLowerCase(Locale.ROOT), + noFieldRelevanceArguments(ctx)); } @Override public UnresolvedExpression visitSingleFieldRelevanceFunction( SingleFieldRelevanceFunctionContext ctx) { return new Function( - ctx.singleFieldRelevanceFunctionName().getText().toLowerCase(), + ctx.singleFieldRelevanceFunctionName().getText().toLowerCase(Locale.ROOT), singleFieldRelevanceArguments(ctx)); } @@ -431,7 +433,7 @@ public UnresolvedExpression visitSingleFieldRelevanceFunction( public UnresolvedExpression visitAltSingleFieldRelevanceFunction( AltSingleFieldRelevanceFunctionContext ctx) { return new Function( - ctx.altSyntaxFunctionName.getText().toLowerCase(), + ctx.altSyntaxFunctionName.getText().toLowerCase(Locale.ROOT), altSingleFieldRelevanceFunctionArguments(ctx)); } @@ -446,11 +448,11 @@ public UnresolvedExpression visitMultiFieldRelevanceFunction( || funcName.equalsIgnoreCase(BuiltinFunctionName.MULTIMATCHQUERY.toString())) && 
!ctx.getRuleContexts(AlternateMultiMatchQueryContext.class).isEmpty()) { return new Function( - ctx.multiFieldRelevanceFunctionName().getText().toLowerCase(), + ctx.multiFieldRelevanceFunctionName().getText().toLowerCase(Locale.ROOT), alternateMultiMatchArguments(ctx)); } else { return new Function( - ctx.multiFieldRelevanceFunctionName().getText().toLowerCase(), + ctx.multiFieldRelevanceFunctionName().getText().toLowerCase(Locale.ROOT), multiFieldRelevanceArguments(ctx)); } } @@ -459,7 +461,7 @@ public UnresolvedExpression visitMultiFieldRelevanceFunction( public UnresolvedExpression visitAltMultiFieldRelevanceFunction( AltMultiFieldRelevanceFunctionContext ctx) { return new Function( - ctx.altSyntaxFunctionName.getText().toLowerCase(), + ctx.altSyntaxFunctionName.getText().toLowerCase(Locale.ROOT), altMultiFieldRelevanceFunctionArguments(ctx)); } @@ -504,12 +506,12 @@ private void fillRelevanceArgs( builder.add( v.argName == null ? new UnresolvedArgument( - v.relevanceArgName().getText().toLowerCase(), + v.relevanceArgName().getText().toLowerCase(Locale.ROOT), new Literal( StringUtils.unquoteText(v.relevanceArgValue().getText()), DataType.STRING)) : new UnresolvedArgument( - StringUtils.unquoteText(v.argName.getText()).toLowerCase(), + StringUtils.unquoteText(v.argName.getText()).toLowerCase(Locale.ROOT), new Literal( StringUtils.unquoteText(v.argVal.getText()), DataType.STRING)))); }