package com.alibaba.alink.operator.batch.evaluation;

import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil;
import com.alibaba.alink.operator.common.evaluation.BaseMetricsSummary;
import com.alibaba.alink.operator.common.evaluation.EvaluationMetricsCollector;
import com.alibaba.alink.operator.common.evaluation.EvaluationUtil;
import com.alibaba.alink.operator.common.evaluation.RankingMetrics;
import com.alibaba.alink.params.evaluation.EvalRankingParams;
import java.util.List;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/**
 * Batch operator that evaluates ranking predictions.
 *
 * <p>The single input must contain the label column ({@code getLabelCol()}) and the prediction column
 * ({@code getPredictionCol()}); both are checked with {@link TableUtil#assertSelectedColExist}. Per-partition
 * {@link BaseMetricsSummary} objects are computed by {@link CalcLocal}, merged with
 * {@link EvaluationUtil.ReduceBaseMetrics}, and written by {@link EvaluationUtil.SaveDataAsParams} into an output
 * table with a single {@code STRING} column named {@code data}. {@link #createMetrics(List)} turns the collected
 * output rows back into {@link RankingMetrics}.
 */
@InputPorts(values = {@PortSpec(PortType.DATA)})
@OutputPorts(values = {@PortSpec(PortType.EVAL_METRICS)})
@NameCn("排序评估")
@NameEn("Eval Ranking")
public class EvalRankingBatchOp extends BatchOperator<EvalRankingBatchOp>
    implements EvalRankingParams<EvalRankingBatchOp>, EvaluationMetricsCollector<RankingMetrics, EvalRankingBatchOp> {
    private static final long serialVersionUID = 4418406919511122133L;

    /**
     * Computes one partial ranking-metrics summary per partition.
     *
     * <p>Reads the first element of the broadcast variable {@code EvalMultiLabelBatchOp.LABELS} (built by
     * {@code EvalMultiLabelBatchOp.getLabelNumberAndMaxK}) and delegates the per-row work to
     * {@code EvaluationUtil.getRankingMetrics}. Exactly one summary is emitted for each partition.
     */
    public static class CalcLocal extends RichMapPartitionFunction<Row, BaseMetricsSummary> {
        private static final long serialVersionUID = -2274636215166393789L;
        String labelKObject;
        String predictionKObject;

        /**
         * @param labelKObject      label ranking info, forwarded to {@code EvaluationUtil.getRankingMetrics}
         * @param predictionKObject prediction ranking info, forwarded to {@code EvaluationUtil.getRankingMetrics}
         */
        public CalcLocal(String labelKObject, String predictionKObject) {
            this.labelKObject = labelKObject;
            this.predictionKObject = predictionKObject;
        }

        /**
         * Summarizes all rows of this partition into a single {@link BaseMetricsSummary}.
         *
         * @param rows      (label, prediction) rows of this partition
         * @param collector receives exactly one summary
         */
        // The broadcast element's generic type is owned by EvalMultiLabelBatchOp; it is read as a raw Tuple3
        // and handed straight to EvaluationUtil, so the unchecked conversion is confined to this method.
        @SuppressWarnings({"rawtypes", "unchecked"})
        @Override
        public void mapPartition(Iterable<Row> rows, Collector<BaseMetricsSummary> collector) throws Exception {
            Tuple3 labelInfo = (Tuple3) getRuntimeContext().getBroadcastVariable(EvalMultiLabelBatchOp.LABELS).get(0);
            collector.collect(EvaluationUtil.getRankingMetrics(rows, labelInfo, labelKObject, predictionKObject));
        }
    }

    /** Creates the operator without explicit parameters ({@code null} is passed to the superclass). */
    public EvalRankingBatchOp() {
        super(null);
    }

    /**
     * Creates the operator with the given parameters.
     *
     * @param params operator parameters (see {@link EvalRankingParams})
     */
    public EvalRankingBatchOp(Params params) {
        super(params);
    }

    /**
     * Builds the evaluation pipeline for the first input operator.
     *
     * @param inputs upstream operators; only the one returned by {@code checkAndGetFirst} is used
     * @return this operator, whose output table holds the serialized metrics in column {@code data}
     */
    @Override
    public EvalRankingBatchOp linkFrom(BatchOperator<?>... inputs) {
        BatchOperator<?> in = checkAndGetFirst(inputs);
        String labelCol = getLabelCol();
        String predictionCol = getPredictionCol();
        TableUtil.assertSelectedColExist(in.getColNames(), labelCol);
        TableUtil.assertSelectedColExist(in.getColNames(), predictionCol);

        String labelRankingInfo = getLabelRankingInfo();
        String predictionRankingInfo = getPredictionRankingInfo();
        DataSet<Row> data = in.select(new String[] {labelCol, predictionCol}).getDataSet();

        // Global label information, broadcast to every CalcLocal instance under the LABELS name.
        DataSet<?> labelInfo = EvalMultiLabelBatchOp.getLabelNumberAndMaxK(data, labelRankingInfo, predictionRankingInfo);

        // One summary per (rebalanced) partition, merged into a single summary by the reduce step.
        DataSet<BaseMetricsSummary> summary = data
            .rebalance()
            .mapPartition(new CalcLocal(labelRankingInfo, predictionRankingInfo))
            .withBroadcastSet(labelInfo, EvalMultiLabelBatchOp.LABELS)
            .reduce(new EvaluationUtil.ReduceBaseMetrics());

        DataSet<Row> metrics = summary.flatMap(new EvaluationUtil.SaveDataAsParams());
        setOutputTable(DataSetConversionUtil.toTable(getMLEnvironmentId(), metrics,
            new TableSchema(new String[] {"data"}, new TypeInformation[] {Types.STRING})));
        return this;
    }

    /**
     * Builds {@link RankingMetrics} from rows collected from this operator's output.
     *
     * @param rows collected output rows; only the first row is used
     * @return ranking metrics constructed from {@code rows.get(0)}
     * @throws IndexOutOfBoundsException if {@code rows} is empty
     */
    @Override
    public RankingMetrics createMetrics(List<Row> rows) {
        return new RankingMetrics(rows.get(0));
    }
}
