聊聊flink Table的groupBy操作

本文主要研究一下flink Table的groupBy操作

Table.groupBy

flink-table_2.11-1.7.0-sources.jar!/org/apache/flink/table/api/table.scala

/**
  * Excerpt of the Table API `Table` class, reduced to its `groupBy` overloads.
  *
  * @param tableEnv    the [[TableEnvironment]] this table carries
  * @param logicalPlan the [[LogicalNode]] this table carries
  */
class Table(
    private[flink] val tableEnv: TableEnvironment,
    private[flink] val logicalPlan: LogicalNode) {

  //......

  /**
    * Groups this table by the fields described in a string.
    *
    * The string is parsed into an expression list with
    * `ExpressionParser.parseExpressionList` and the call is forwarded to the
    * `Expression*` overload.
    *
    * @param fields the grouping fields as a string
    * @return the resulting [[GroupedTable]]
    */
  def groupBy(fields: String): GroupedTable = {
    val fieldsExpr = ExpressionParser.parseExpressionList(fields)
    groupBy(fieldsExpr: _*)
  }

  /**
    * Groups this table by the given expressions.
    *
    * No plan node is built here; the method only wraps this table and the
    * grouping expressions into a new [[GroupedTable]].
    *
    * @param fields the grouping expressions
    * @return a [[GroupedTable]] holding this table and `fields` as its group key
    */
  def groupBy(fields: Expression*): GroupedTable = {
    new GroupedTable(this, fields)
  }

  //......
}
  • Table的groupBy操作支持两种参数,一种是String类型,一种是Expression类型(可变参数);String参数的方法先通过ExpressionParser.parseExpressionList将String解析为Expression列表,最后调用的是Expression参数的groupBy方法,该方法创建了GroupedTable

GroupedTable

flink-table_2.11-1.7.0-sources.jar!/org/apache/flink/table/api/table.scala

/**
  * A table together with the expressions it has been grouped by.
  *
  * @param table    the table being grouped
  * @param groupKey the grouping expressions
  */
class GroupedTable(
  private[flink] val table: Table,
  private[flink] val groupKey: Seq[Expression]) {

  /**
    * Builds the result table of a grouped select.
    *
    * The resulting logical plan is nested as
    * `Project(projectsOnAgg, Aggregate(groupKey, aggregates, Project(projectFields, input)))`,
    * and every level is validated against the table environment.
    *
    * @param fields the expressions to select
    * @return a new [[Table]] on the same table environment wrapping the plan above
    * @throws ValidationException if `fields` contain window properties
    */
  def select(fields: Expression*): Table = {
    // expandProjectList / extractAggregationsAndProperties are helpers defined outside this excerpt.
    val expandedFields = expandProjectList(fields, table.logicalPlan, table.tableEnv)
    val (aggNames, propNames) = extractAggregationsAndProperties(expandedFields, table.tableEnv)
    // A plain (non-windowed) grouping must not reference window properties.
    if (propNames.nonEmpty) {
      throw new ValidationException("Window properties can only be used on windowed tables.")
    }

    // Outer projection: the selected fields with aggregations/properties replaced
    // (replacement helper defined outside this excerpt).
    val projectsOnAgg = replaceAggregationsAndProperties(
      expandedFields, table.tableEnv, aggNames, propNames)
    // Inner projection: field references extracted from the selected fields plus the group key.
    val projectFields = extractFieldReferences(expandedFields ++ groupKey)

    // Each aggNames entry is a pair that becomes Alias(first, second) in the Aggregate node.
    new Table(table.tableEnv,
      Project(projectsOnAgg,
        Aggregate(groupKey, aggNames.map(a => Alias(a._1, a._2)).toSeq,
          Project(projectFields, table.logicalPlan).validate(table.tableEnv)
        ).validate(table.tableEnv)
      ).validate(table.tableEnv))
  }

  /**
    * String variant of `select`.
    *
    * The string is parsed into expressions, each expression is passed through
    * `replaceAggFunctionCall`, and the call is forwarded to the `Expression*` overload.
    *
    * @param fields the expressions to select, as a string
    * @return the result of the `Expression*` overload
    */
  def select(fields: String): Table = {
    val fieldExprs = ExpressionParser.parseExpressionList(fields)
    // Get the correct expression for AggFunctionCall.
    val withResolvedAggFunctionCall = fieldExprs.map(replaceAggFunctionCall(_, table.tableEnv))
    select(withResolvedAggFunctionCall: _*)
  }
}
  • GroupedTable有两个属性,一个是原始的Table,一个是Seq[Expression]类型的groupKey
  • GroupedTable提供两个select方法,参数类型分别为String、Expression,String类型的参数最后也是转为Expression类型
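  • select方法最终以一个Project作为logicalPlan来创建新的Table;该Project的child是Aggregate(由groupKey及带别名的聚合表达式构成),而Aggregate的child又是对原始table.logicalPlan的Project,每一层都调用了validate方法进行校验;另外如果select的字段中包含window属性,则会抛出ValidationException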

Aggregate

flink-table_2.11-1.7.0-sources.jar!/org/apache/flink/table/plan/logical/operators.scala

/**
  * Logical plan node for a grouped aggregation.
  *
  * @param groupingExpressions  the expressions to group by
  * @param aggregateExpressions the (named) aggregate expressions
  * @param child                the input node
  */
case class Aggregate(
    groupingExpressions: Seq[Expression],
    aggregateExpressions: Seq[NamedExpression],
    child: LogicalNode) extends UnaryNode {

  /**
    * Output attributes: the grouping expressions followed by the aggregate expressions.
    * Named expressions contribute their own attribute; any other expression is
    * aliased with its `toString` first.
    */
  override def output: Seq[Attribute] = {
    (groupingExpressions ++ aggregateExpressions) map {
      case ne: NamedExpression => ne.toAttribute
      case e => Alias(e, e.toString).toAttribute
    }
  }

  /**
    * Translates this node onto the given `RelBuilder`.
    *
    * The child is constructed first, then `relBuilder.aggregate` is called with a
    * group key built from the grouping expressions and one aggregate call per
    * aggregate expression.
    *
    * @throws RuntimeException if an aggregate expression is not an `Alias` over an `Aggregation`
    */
  override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = {
    child.construct(relBuilder)
    relBuilder.aggregate(
      relBuilder.groupKey(groupingExpressions.map(_.toRexNode(relBuilder)).asJava),
      aggregateExpressions.map {
        case Alias(agg: Aggregation, name, _) => agg.toAggCall(name)(relBuilder)
        case _ => throw new RuntimeException("This should never happen.")
      }.asJava)
  }

  /**
    * Validates this node and returns the `Aggregate` produced by `super.validate`.
    *
    * Every aggregate expression is checked with `validateAggregateExpression` and
    * every grouping expression with `validateGroupingExpression`; a failed check
    * is reported through `failValidation`.
    */
  override def validate(tableEnv: TableEnvironment): LogicalNode = {
    implicit val relBuilder: RelBuilder = tableEnv.getRelBuilder
    val resolvedAggregate = super.validate(tableEnv).asInstanceOf[Aggregate]
    val groupingExprs = resolvedAggregate.groupingExpressions
    val aggregateExprs = resolvedAggregate.aggregateExpressions
    // The two local validators are defined further down in this method body.
    aggregateExprs.foreach(validateAggregateExpression)
    groupingExprs.foreach(validateGroupingExpression)

    // Case order matters: DistinctAgg is matched before the general Aggregation cases.
    def validateAggregateExpression(expr: Expression): Unit = expr match {
      // distinct: the child must be a non-distinct aggregation, which is then validated itself
      case distinctExpr: DistinctAgg =>
        distinctExpr.child match {
          case _: DistinctAgg => failValidation(
            "Chained distinct operators are not supported!")
          case aggExpr: Aggregation => validateAggregateExpression(aggExpr)
          case _ => failValidation(
            "Distinct operator can only be applied to aggregation expressions!")
        }
      // check aggregate function
      case aggExpr: Aggregation
        if aggExpr.getSqlAggFunction.requiresOver =>
        failValidation(s"OVER clause is necessary for window functions: [${aggExpr.getClass}].")
      // check no nested aggregation exists.
      case aggExpr: Aggregation =>
        aggExpr.children.foreach { child =>
          child.preOrderVisit {
            case agg: Aggregation =>
              failValidation(
                "It's not allowed to use an aggregate function as " +
                  "input of another aggregate function")
            case _ => // OK
          }
        }
      // an attribute that no grouping expression equals is rejected
      case a: Attribute if !groupingExprs.exists(_.checkEquals(a)) =>
        failValidation(
          s"expression '$a' is invalid because it is neither" +
            " present in group by nor an aggregate function")
      // an expression equal to a grouping expression is accepted as is
      case e if groupingExprs.exists(_.checkEquals(e)) => // OK
      // anything else: validate its children recursively
      case e => e.children.foreach(validateAggregateExpression)
    }

    // A grouping expression must have a result type for which isKeyType is true.
    def validateGroupingExpression(expr: Expression): Unit = {
      if (!expr.resultType.isKeyType) {
        failValidation(
          s"expression $expr cannot be used as a grouping expression " +
            "because it's not a valid key type which must be hashable and comparable")
      }
    }
    resolvedAggregate
  }
}
  • Aggregate继承了UnaryNode,它接收三个参数,一个是Seq[Expression]类型的groupingExpressions,一个是Seq[NamedExpression]类型的aggregateExpressions,一个是LogicalNode类型的child;construct方法调用了relBuilder.aggregate,传入的RelBuilder.GroupKey参数是通过relBuilder.groupKey构建,而传入的RelBuilder.AggCall参数则是通过aggregateExpressions.map构造而来

RelBuilder.groupKey

calcite-core-1.18.0-sources.jar!/org/apache/calcite/tools/RelBuilder.java

/** Excerpt of Calcite's {@code RelBuilder}, reduced to its fields and the
 * {@code groupKey} factory methods. */
public class RelBuilder {
  protected final RelOptCluster cluster;
  protected final RelOptSchema relOptSchema;
  private final RelFactories.FilterFactory filterFactory;
  private final RelFactories.ProjectFactory projectFactory;
  private final RelFactories.AggregateFactory aggregateFactory;
  private final RelFactories.SortFactory sortFactory;
  private final RelFactories.ExchangeFactory exchangeFactory;
  private final RelFactories.SortExchangeFactory sortExchangeFactory;
  private final RelFactories.SetOpFactory setOpFactory;
  private final RelFactories.JoinFactory joinFactory;
  private final RelFactories.SemiJoinFactory semiJoinFactory;
  private final RelFactories.CorrelateFactory correlateFactory;
  private final RelFactories.ValuesFactory valuesFactory;
  private final RelFactories.TableScanFactory scanFactory;
  private final RelFactories.MatchFactory matchFactory;
  private final Deque<Frame> stack = new ArrayDeque<>();
  private final boolean simplify;
  private final RexSimplify simplifier;

  //......

  /** Creates an empty group key. */
  public GroupKey groupKey() {
    return groupKey(ImmutableList.of());
  }

  /** Creates a group key. */
  public GroupKey groupKey(RexNode... nodes) {
    return groupKey(ImmutableList.copyOf(nodes));
  }

  /** Creates a group key.
   *
   * <p>Builds the {@code GroupKeyImpl} directly, with the indicator set to
   * {@code false} and neither grouping sets nor an alias. */
  public GroupKey groupKey(Iterable<? extends RexNode> nodes) {
    return new GroupKeyImpl(ImmutableList.copyOf(nodes), false, null, null);
  }

  /** Creates a group key with grouping sets. */
  public GroupKey groupKey(Iterable<? extends RexNode> nodes,
      Iterable<? extends Iterable<? extends RexNode>> nodeLists) {
    return groupKey_(nodes, false, nodeLists);
  }

  /** Creates a group key of fields identified by ordinal. */
  public GroupKey groupKey(int... fieldOrdinals) {
    return groupKey(fields(ImmutableIntList.of(fieldOrdinals)));
  }

  /** Creates a group key of fields identified by name. */
  public GroupKey groupKey(String... fieldNames) {
    return groupKey(fields(ImmutableList.copyOf(fieldNames)));
  }

  /** Creates a group key from a bit set, using that bit set as the only
   * grouping set. */
  public GroupKey groupKey(@Nonnull ImmutableBitSet groupSet) {
    return groupKey(groupSet, ImmutableList.of(groupSet));
  }

  /** Creates a group key from a bit set and a collection of grouping sets
   * given as bit sets. */
  public GroupKey groupKey(ImmutableBitSet groupSet,
      @Nonnull Iterable<? extends ImmutableBitSet> groupSets) {
    return groupKey_(groupSet, false, ImmutableList.copyOf(groupSets));
  }

  /** Converts the bit sets into field expressions via {@code fields(...)} and
   * delegates to the expression-based {@code groupKey_}.
   *
   * @throws IllegalArgumentException if {@code groupSet.length()} exceeds the
   *   field count of the row type of {@code peek()}
   * @throws NullPointerException if {@code groupSets} is null */
  private GroupKey groupKey_(ImmutableBitSet groupSet, boolean indicator,
      @Nonnull ImmutableList<ImmutableBitSet> groupSets) {
    if (groupSet.length() > peek().getRowType().getFieldCount()) {
      throw new IllegalArgumentException("out of bounds: " + groupSet);
    }
    Objects.requireNonNull(groupSets);
    final ImmutableList<RexNode> nodes =
        fields(ImmutableIntList.of(groupSet.toArray()));
    final List<ImmutableList<RexNode>> nodeLists =
        Util.transform(groupSets,
            bitSet -> fields(ImmutableIntList.of(bitSet.toArray())));
    return groupKey_(nodes, indicator, nodeLists);
  }

  /** Copies every node list into an immutable list and creates the
   * {@code GroupKeyImpl}, without an alias. */
  private GroupKey groupKey_(Iterable<? extends RexNode> nodes,
      boolean indicator,
      Iterable<? extends Iterable<? extends RexNode>> nodeLists) {
    final ImmutableList.Builder<ImmutableList<RexNode>> builder =
        ImmutableList.builder();
    for (Iterable<? extends RexNode> nodeList : nodeLists) {
      builder.add(ImmutableList.copyOf(nodeList));
    }
    return new GroupKeyImpl(ImmutableList.copyOf(nodes), indicator, builder.build(), null);
  }

  //......
}
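  • RelBuilder提供了诸多groupKey方法用于创建GroupKey;其中接收Iterable<? extends RexNode>参数的groupKey方法直接创建GroupKeyImpl(Flink的Aggregate.construct中传入的是经asJava转换的RexNode集合);而带grouping sets(nodeLists)或ImmutableBitSet参数的方法最后调用的是私有方法groupKey_,该方法同样创建了GroupKeyImpl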

GroupKey

calcite-core-1.18.0-sources.jar!/org/apache/calcite/tools/RelBuilder.java

/** Group key as produced by the {@code groupKey} factory methods and consumed
 * by {@code aggregate}. */
public interface GroupKey {
    /** Assigns an alias to this group key.
     *
     * <p>Used to assign field names in the {@code group} operation. */
    GroupKey alias(String alias);
  }

  /** Implementation of {@link GroupKey}.
   *
   * <p>Holds the grouping expressions, the indicator flag, the optional
   * grouping sets and the optional alias. */
  protected static class GroupKeyImpl implements GroupKey {
    final ImmutableList<RexNode> nodes;
    final boolean indicator;
    // Grouping sets; may be null (the single-list groupKey overload passes null).
    final ImmutableList<ImmutableList<RexNode>> nodeLists;
    // Alias; null when none has been assigned.
    final String alias;

    /** Creates a group key; {@code nodes} must not be null and
     * {@code indicator} is asserted to be {@code false}. */
    GroupKeyImpl(ImmutableList<RexNode> nodes, boolean indicator,
        ImmutableList<ImmutableList<RexNode>> nodeLists, String alias) {
      this.nodes = Objects.requireNonNull(nodes);
      assert !indicator;
      this.indicator = indicator;
      this.nodeLists = nodeLists;
      this.alias = alias;
    }

    /** Returns the nodes, followed by {@code " as <alias>"} when an alias is set. */
    @Override public String toString() {
      return alias == null ? nodes.toString() : nodes + " as " + alias;
    }

    /** Returns this instance when the alias is unchanged, otherwise a copy
     * carrying the new alias. */
    public GroupKey alias(String alias) {
      return Objects.equals(this.alias, alias)
          ? this
          : new GroupKeyImpl(nodes, indicator, nodeLists, alias);
    }
  }
  • GroupKey接口定义了alias方法,用于给group操作的字段指定别名;GroupKeyImpl是GroupKey接口的实现类,其alias方法在别名未变化时返回自身,否则返回带有新别名的新GroupKeyImpl

RelBuilder.aggregate

calcite-core-1.18.0-sources.jar!/org/apache/calcite/tools/RelBuilder.java

/** Excerpt of Calcite's {@code RelBuilder}, reduced to its fields and the
 * {@code aggregate} methods. */
public class RelBuilder {
  protected final RelOptCluster cluster;
  protected final RelOptSchema relOptSchema;
  private final RelFactories.FilterFactory filterFactory;
  private final RelFactories.ProjectFactory projectFactory;
  private final RelFactories.AggregateFactory aggregateFactory;
  private final RelFactories.SortFactory sortFactory;
  private final RelFactories.ExchangeFactory exchangeFactory;
  private final RelFactories.SortExchangeFactory sortExchangeFactory;
  private final RelFactories.SetOpFactory setOpFactory;
  private final RelFactories.JoinFactory joinFactory;
  private final RelFactories.SemiJoinFactory semiJoinFactory;
  private final RelFactories.CorrelateFactory correlateFactory;
  private final RelFactories.ValuesFactory valuesFactory;
  private final RelFactories.TableScanFactory scanFactory;
  private final RelFactories.MatchFactory matchFactory;
  // aggregate() pops the top frame from this stack and pushes the frame of the new aggregate.
  private final Deque<Frame> stack = new ArrayDeque<>();
  private final boolean simplify;
  private final RexSimplify simplifier;

  //......

  /** Creates an {@link Aggregate} with an array of
   * calls. */
  public RelBuilder aggregate(GroupKey groupKey, AggCall... aggCalls) {
    return aggregate(groupKey, ImmutableList.copyOf(aggCalls));
  }

  /** Creates an {@link Aggregate} from ready-made {@code AggregateCall}s, each
   * wrapped in an {@code AggCallImpl2}. */
  public RelBuilder aggregate(GroupKey groupKey,
      List<AggregateCall> aggregateCalls) {
    return aggregate(groupKey,
        Lists.transform(aggregateCalls, AggCallImpl2::new));
  }

  /** Creates an {@link Aggregate} with a list of
   * calls.
   *
   * <p>Unless one of the early returns below applies, the current top frame is
   * replaced by a frame holding the new aggregate.
   *
   * @throws IllegalArgumentException if a grouping set is not a subset of the
   *   group key or contains expressions that are not in the group key, or if
   *   DISTINCT / FILTER is used with a function that does not allow it */
  public RelBuilder aggregate(GroupKey groupKey, Iterable<AggCall> aggCalls) {
    // Seed the registrar with the current fields and their names, then register the group key.
    final Registrar registrar = new Registrar();
    registrar.extraNodes.addAll(fields());
    registrar.names.addAll(peek().getRowType().getFieldNames());
    final GroupKeyImpl groupKey_ = (GroupKeyImpl) groupKey;
    final ImmutableBitSet groupSet =
        ImmutableBitSet.of(registrar.registerExpressions(groupKey_.nodes));
    // Without aggregate calls the aggregate may be unnecessary; "break label"
    // leaves this block and falls through to the regular construction below.
  label:
    if (Iterables.isEmpty(aggCalls) && !groupKey_.indicator) {
      final RelMetadataQuery mq = peek().getCluster().getMetadataQuery();
      if (groupSet.isEmpty()) {
        final Double minRowCount = mq.getMinRowCount(peek());
        if (minRowCount == null || minRowCount < 1D) {
          // We can't remove "GROUP BY ()" if there's a chance the rel could be
          // empty.
          break label;
        }
      }
      // Only when registering the group key added no expression beyond the existing fields.
      if (registrar.extraNodes.size() == fields().size()) {
        final Boolean unique = mq.areColumnsUnique(peek(), groupSet);
        if (unique != null && unique) {
          // Rel is already unique.
          return project(fields(groupSet.asList()));
        }
      }
      final Double maxRowCount = mq.getMaxRowCount(peek());
      if (maxRowCount != null && maxRowCount <= 1D) {
        // If there is at most one row, rel is already unique.
        return this;
      }
    }
    // Resolve the grouping sets; without explicit node lists the group set is the only one.
    final ImmutableList<ImmutableBitSet> groupSets;
    if (groupKey_.nodeLists != null) {
      final int sizeBefore = registrar.extraNodes.size();
      final SortedSet<ImmutableBitSet> groupSetSet =
          new TreeSet<>(ImmutableBitSet.ORDERING);
      for (ImmutableList<RexNode> nodeList : groupKey_.nodeLists) {
        final ImmutableBitSet groupSet2 =
            ImmutableBitSet.of(registrar.registerExpressions(nodeList));
        if (!groupSet.contains(groupSet2)) {
          throw new IllegalArgumentException("group set element " + nodeList
              + " must be a subset of group key");
        }
        groupSetSet.add(groupSet2);
      }
      groupSets = ImmutableList.copyOf(groupSetSet);
      // Registering the grouping sets must not have introduced new expressions.
      if (registrar.extraNodes.size() > sizeBefore) {
        throw new IllegalArgumentException(
            "group sets contained expressions not in group key: "
                + registrar.extraNodes.subList(sizeBefore,
                registrar.extraNodes.size()));
      }
    } else {
      groupSets = ImmutableList.of(groupSet);
    }
    // Register operands and filters of the calls before projecting.
    for (AggCall aggCall : aggCalls) {
      if (aggCall instanceof AggCallImpl) {
        final AggCallImpl aggCall1 = (AggCallImpl) aggCall;
        registrar.registerExpressions(aggCall1.operands);
        if (aggCall1.filter != null) {
          registrar.registerExpression(aggCall1.filter);
        }
      }
    }
    // Project all registered expressions, rename, then take the resulting frame off the stack.
    project(registrar.extraNodes);
    rename(registrar.names);
    final Frame frame = stack.pop();
    final RelNode r = frame.rel;
    // Turn every AggCall into an AggregateCall.
    final List<AggregateCall> aggregateCalls = new ArrayList<>();
    for (AggCall aggCall : aggCalls) {
      final AggregateCall aggregateCall;
      if (aggCall instanceof AggCallImpl) {
        final AggCallImpl aggCall1 = (AggCallImpl) aggCall;
        final List<Integer> args =
            registrar.registerExpressions(aggCall1.operands);
        // -1 is passed as the filter argument when the call has no filter.
        final int filterArg = aggCall1.filter == null ? -1
            : registrar.registerExpression(aggCall1.filter);
        if (aggCall1.distinct && !aggCall1.aggFunction.isQuantifierAllowed()) {
          throw new IllegalArgumentException("DISTINCT not allowed");
        }
        if (aggCall1.filter != null && !aggCall1.aggFunction.allowsFilter()) {
          throw new IllegalArgumentException("FILTER not allowed");
        }
        RelCollation collation =
            RelCollations.of(aggCall1.orderKeys
                .stream()
                .map(orderKey ->
                    collation(orderKey, RelFieldCollation.Direction.ASCENDING,
                        null, Collections.emptyList()))
                .collect(Collectors.toList()));
        aggregateCall =
            AggregateCall.create(aggCall1.aggFunction, aggCall1.distinct,
                aggCall1.approximate, args, filterArg, collation,
                groupSet.cardinality(), r, null, aggCall1.alias);
      } else {
        // Otherwise the call is an AggCallImpl2 that already wraps an AggregateCall.
        aggregateCall = ((AggCallImpl2) aggCall).aggregateCall;
      }
      aggregateCalls.add(aggregateCall);
    }

    assert ImmutableBitSet.ORDERING.isStrictlyOrdered(groupSets) : groupSets;
    for (ImmutableBitSet set : groupSets) {
      assert groupSet.contains(set);
    }
    RelNode aggregate = aggregateFactory.createAggregate(r,
        groupKey_.indicator, groupSet, groupSets, aggregateCalls);

    // build field list
    final ImmutableList.Builder<Field> fields = ImmutableList.builder();
    final List<RelDataTypeField> aggregateFields =
        aggregate.getRowType().getFieldList();
    int i = 0;
    // first, group fields
    for (Integer groupField : groupSet.asList()) {
      RexNode node = registrar.extraNodes.get(groupField);
      final SqlKind kind = node.getKind();
      switch (kind) {
      case INPUT_REF:
        // A plain input reference reuses the field of the popped frame.
        fields.add(frame.fields.get(((RexInputRef) node).getIndex()));
        break;
      default:
        String name = aggregateFields.get(i).getName();
        RelDataTypeField fieldType =
            new RelDataTypeFieldImpl(name, i, node.getType());
        fields.add(new Field(ImmutableSet.of(), fieldType));
        break;
      }
      i++;
    }
    // second, indicator fields (copy from aggregate rel type)
    if (groupKey_.indicator) {
      for (int j = 0; j < groupSet.cardinality(); ++j) {
        final RelDataTypeField field = aggregateFields.get(i);
        final RelDataTypeField fieldType =
            new RelDataTypeFieldImpl(field.getName(), i, field.getType());
        fields.add(new Field(ImmutableSet.of(), fieldType));
        i++;
      }
    }
    // third, aggregate fields. retain `i' as field index
    for (int j = 0; j < aggregateCalls.size(); ++j) {
      final AggregateCall call = aggregateCalls.get(j);
      final RelDataTypeField fieldType =
          new RelDataTypeFieldImpl(aggregateFields.get(i + j).getName(), i + j,
              call.getType());
      fields.add(new Field(ImmutableSet.of(), fieldType));
    }
    // Push the new aggregate together with its field list as the new top frame.
    stack.push(new Frame(aggregate, fields.build()));
    return this;
  }

  //......
}
  • RelBuilder的aggregate操作接收两个参数,一个是GroupKey,一个是集合类型的AggCall;其中AggCall最后是转换为AggregateCall;方法先通过stack.pop()取出栈顶的Frame,再通过aggregateFactory.createAggregate方法基于该Frame的rel创建新的RelNode,然后构造新的Frame并重新push到stack的栈顶

RelFactories.AggregateFactory.createAggregate

calcite-core-1.18.0-sources.jar!/org/apache/calcite/rel/core/RelFactories.java

/** Excerpt of Calcite's {@code RelFactories}, reduced to the aggregate factory. */
public class RelFactories {
  //......
  /** Default aggregate factory, an {@code AggregateFactoryImpl}. */
  public static final AggregateFactory DEFAULT_AGGREGATE_FACTORY =
      new AggregateFactoryImpl();

  /** Factory for aggregate relational expressions. */
  public interface AggregateFactory {
    /** Creates an aggregate. */
    RelNode createAggregate(RelNode input, boolean indicator,
        ImmutableBitSet groupSet, ImmutableList<ImmutableBitSet> groupSets,
        List<AggregateCall> aggCalls);
  }

  /** Implementation of {@link AggregateFactory} that returns the result of
   * {@code LogicalAggregate.create}. */
  private static class AggregateFactoryImpl implements AggregateFactory {
    // Calls the LogicalAggregate.create overload that takes an indicator flag,
    // hence the deprecation suppression.
    @SuppressWarnings("deprecation")
    public RelNode createAggregate(RelNode input, boolean indicator,
        ImmutableBitSet groupSet, ImmutableList<ImmutableBitSet> groupSets,
        List<AggregateCall> aggCalls) {
      return LogicalAggregate.create(input, indicator,
          groupSet, groupSets, aggCalls);
    }
  }

  //......
}
  • RelFactories定义了AggregateFactory接口,该接口定义了createAggregate方法,用于将一系列的AggregateCall操作转为新的RelNode;AggregateFactoryImpl是AggregateFactory接口的实现类,它的createAggregate方法调用的是LogicalAggregate.create方法

LogicalAggregate.create

calcite-core-1.18.0-sources.jar!/org/apache/calcite/rel/logical/LogicalAggregate.java

/** Excerpt of Calcite's {@code LogicalAggregate}, reduced to its static
 * factory methods. */
public final class LogicalAggregate extends Aggregate {
  //......

  /** Creates a LogicalAggregate with the indicator flag set to {@code false}. */
  public static LogicalAggregate create(final RelNode input,
      ImmutableBitSet groupSet, List<ImmutableBitSet> groupSets,
      List<AggregateCall> aggCalls) {
    return create_(input, false, groupSet, groupSets, aggCalls);
  }

  /** Creates a LogicalAggregate with an explicit indicator flag. */
  @Deprecated // to be removed before 2.0
  public static LogicalAggregate create(final RelNode input,
      boolean indicator,
      ImmutableBitSet groupSet,
      List<ImmutableBitSet> groupSets,
      List<AggregateCall> aggCalls) {
    return create_(input, indicator, groupSet, groupSets, aggCalls);
  }

  /** Shared implementation of both {@code create} methods: takes the cluster
   * from the input and uses a trait set of {@code Convention.NONE}. */
  private static LogicalAggregate create_(final RelNode input,
      boolean indicator,
      ImmutableBitSet groupSet,
      List<ImmutableBitSet> groupSets,
      List<AggregateCall> aggCalls) {
    final RelOptCluster cluster = input.getCluster();
    final RelTraitSet traitSet = cluster.traitSetOf(Convention.NONE);
    return new LogicalAggregate(cluster, traitSet, input, indicator, groupSet,
        groupSets, aggCalls);
  }

  //......
}
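  • LogicalAggregate的两个create方法都委托给私有的create_方法,该方法使用input的cluster以及Convention.NONE的traitSet创建LogicalAggregate;其中带indicator参数的create方法已被标记为@Deprecated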

小结

  • Table的groupBy操作支持两种参数,一种是String类型,一种是Expression类型(可变参数);String参数的方法先通过ExpressionParser.parseExpressionList将String解析为Expression列表,最后调用的是Expression参数的groupBy方法,该方法创建了GroupedTable
  • GroupedTable有两个属性,一个是原始的Table,一个是Seq[Expression]类型的groupKey;它提供两个select方法,参数类型分别为String、Expression,String类型的参数最后也是转为Expression类型;select方法最终以一个Project作为logicalPlan来创建新的Table,该Project的child是Aggregate,而Aggregate的child又是对原始table.logicalPlan的Project,每一层都调用了validate方法进行校验
  • Aggregate继承了UnaryNode,它接收三个参数,一个是Seq[Expression]类型的groupingExpressions,一个是Seq[NamedExpression]类型的aggregateExpressions,一个是LogicalNode类型的child;construct方法调用了relBuilder.aggregate,传入的RelBuilder.GroupKey参数是通过relBuilder.groupKey构建,而传入的RelBuilder.AggCall参数则是通过aggregateExpressions.map构造而来
  • RelBuilder的aggregate操作接收两个参数,一个是GroupKey(GroupKey接口定义了alias方法,用于给group操作的字段指定别名;GroupKeyImpl是GroupKey接口的实现类,其alias方法在别名未变化时返回自身,否则返回带有新别名的新GroupKeyImpl),一个是集合类型的AggCall;其中AggCall最后是转换为AggregateCall;方法先通过stack.pop()取出栈顶的Frame,再通过aggregateFactory.createAggregate方法基于该Frame的rel创建新的RelNode,然后构造新的Frame并重新push到stack的栈顶
  • RelFactories定义了AggregateFactory接口,该接口定义了createAggregate方法,用于将一系列的AggregateCall操作转为新的RelNode;AggregateFactoryImpl是AggregateFactory接口的实现类,它的createAggregate方法调用的是LogicalAggregate.create方法,创建的是LogicalAggregate

doc

相关推荐