package com.alibaba.alink.operator.batch.evaluation;

import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpecs;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.TypeCollections;
import com.alibaba.alink.common.exceptions.AkIllegalDataException;
import com.alibaba.alink.common.exceptions.AkPreconditions;
import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException;
import com.alibaba.alink.common.exceptions.ExceptionWithErrorCode;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.common.evaluation.BaseMetricsSummary;
import com.alibaba.alink.operator.common.evaluation.ClassificationEvaluationUtil;
import com.alibaba.alink.operator.common.evaluation.EvaluationMetricsCollector;
import com.alibaba.alink.operator.common.evaluation.EvaluationUtil;
import com.alibaba.alink.operator.common.evaluation.MultiClassMetrics;
import com.alibaba.alink.params.evaluation.EvalBinaryClassParams;
import com.alibaba.alink.params.evaluation.EvalMultiClassParams;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

@InputPorts(values = {@PortSpec(PortType.DATA)})
@OutputPorts(values = {@PortSpec(PortType.EVAL_METRICS)})
@ParamSelectColumnSpecs({
    @ParamSelectColumnSpec(name = "labelCol"),
    @ParamSelectColumnSpec(name = "predictionCol"),
    @ParamSelectColumnSpec(name = "predictionDetailCol", allowedTypeCollections = {TypeCollections.STRING_TYPE})
})
@NameCn("多分类评估")
@NameEn("Eval Multi Class")
/**
 * Batch operator that evaluates multi-class classification output.
 *
 * <p>Depending on {@link ClassificationEvaluationUtil#judgeEvaluationType}, either the prediction column
 * ({@code PRED_RESULT}) or the prediction-detail column ({@code PRED_DETAIL}) is compared against the label
 * column. Each partition produces a {@link BaseMetricsSummary}. The summaries are reduced into one, which is
 * emitted as a single STRING column named {@code "Data"}.
 */
public class EvalMultiClassBatchOp extends BatchOperator<EvalMultiClassBatchOp>
    implements EvalMultiClassParams<EvalMultiClassBatchOp>,
    EvaluationMetricsCollector<MultiClassMetrics, EvalMultiClassBatchOp> {

    /** Broadcast-set name under which the distinct-label index is shipped to the partition functions. */
    private static final String LABELS = "labels";
    /** Name of the single output column holding the serialized metrics. */
    private static final String DATA_OUTPUT = "Data";
    private static final long serialVersionUID = -2027803227905959081L;

    /**
     * Validates the broadcast label index and returns its single element.
     *
     * <p>The element is produced by {@code EvaluationUtil.DistinctLabelIndexMap}. It is presumably a
     * label-to-index map plus the label array (TODO confirm against EvaluationUtil).
     *
     * @param broadcastLabels contents of the {@code "labels"} broadcast variable
     * @return the first (expected only) element of the broadcast set
     * @throws AkIllegalDataException if the broadcast set is empty, i.e. no input row had non-null fields
     */
    private static Tuple2<Map<Object, Integer>, Object[]> firstLabelIndex(
        List<Tuple2<Map<Object, Integer>, Object[]>> broadcastLabels) {
        AkPreconditions.checkState(broadcastLabels.size() > 0,
            (ExceptionWithErrorCode) new AkIllegalDataException(
                "Please check the evaluation input! there is no effective row!"));
        return broadcastLabels.get(0);
    }

    /**
     * Per-partition statistics for {@code PRED_DETAIL} mode.
     *
     * <p>Input rows are (label, predictionDetail). This delegates to
     * {@link EvaluationUtil#getDetailStatistics} with the broadcast label index.
     */
    public static class CalLabelDetailLocal extends RichMapPartitionFunction<Row, BaseMetricsSummary> {
        private static final long serialVersionUID = 5680342197308160013L;
        /** Label index loaded from the broadcast set in {@link #open}. */
        private Tuple2<Map<Object, Integer>, Object[]> labelIndex;
        /** Type of the label column, used to parse the prediction detail. */
        private final TypeInformation<?> labelType;

        public CalLabelDetailLocal(TypeInformation<?> labelType) {
            this.labelType = labelType;
        }

        @Override
        public void open(Configuration configuration) throws Exception {
            List<Tuple2<Map<Object, Integer>, Object[]>> labels = getRuntimeContext().getBroadcastVariable(LABELS);
            this.labelIndex = firstLabelIndex(labels);
        }

        @Override
        public void mapPartition(Iterable<Row> rows, Collector<BaseMetricsSummary> out) {
            out.collect(EvaluationUtil.getDetailStatistics(rows, false, this.labelIndex, this.labelType));
        }
    }

    /**
     * Per-partition statistics for {@code PRED_RESULT} mode.
     *
     * <p>Input rows are (label, prediction). This delegates to {@link EvaluationUtil#getMultiClassMetrics}
     * with the broadcast label index.
     */
    public static class CalLabelPredictionLocal extends RichMapPartitionFunction<Row, BaseMetricsSummary> {
        private static final long serialVersionUID = -2439136352527525005L;
        /** Label index loaded from the broadcast set in {@link #open}. */
        private Tuple2<Map<Object, Integer>, Object[]> labelIndex;

        CalLabelPredictionLocal() {
        }

        @Override
        public void open(Configuration configuration) throws Exception {
            List<Tuple2<Map<Object, Integer>, Object[]>> labels = getRuntimeContext().getBroadcastVariable(LABELS);
            this.labelIndex = firstLabelIndex(labels);
        }

        @Override
        public void mapPartition(Iterable<Row> rows, Collector<BaseMetricsSummary> out) {
            out.collect(EvaluationUtil.getMultiClassMetrics(rows, this.labelIndex));
        }
    }

    public EvalMultiClassBatchOp() {
        this(null);
    }

    public EvalMultiClassBatchOp(Params params) {
        super(params);
    }

    /**
     * Builds {@link MultiClassMetrics} from the first collected output row of this operator.
     *
     * @param rows collected output rows; only {@code rows.get(0)} is read
     * @return the metrics parsed from that row
     */
    @Override
    public MultiClassMetrics createMetrics(List<Row> rows) {
        return new MultiClassMetrics(rows.get(0));
    }

    /**
     * Wires the evaluation pipeline onto the first input.
     *
     * @param inputs upstream operators; only the first is used
     * @return this operator, with its output set to a single STRING column {@code "Data"}
     * @throws AkUnsupportedOperationException if the evaluation type is neither PRED_RESULT nor PRED_DETAIL
     */
    @Override
    public EvalMultiClassBatchOp linkFrom(BatchOperator<?>... inputs) {
        BatchOperator<?> in = checkAndGetFirst(inputs);
        String labelColName = get(EvalMultiClassParams.LABEL_COL);
        TypeInformation<?> labelType = TableUtil.findColTypeWithAssertAndHint(in.getSchema(), labelColName);
        // The positive-label parameter is shared with the binary evaluator and forwarded to the label indexer.
        String positiveLabel = get(EvalBinaryClassParams.POS_LABEL_VAL_STR);
        ClassificationEvaluationUtil.Type evalType = ClassificationEvaluationUtil.judgeEvaluationType(getParams());

        DataSet<BaseMetricsSummary> localSummaries;
        switch (evalType) {
            case PRED_RESULT: {
                String predictionColName = get(EvalMultiClassParams.PREDICTION_COL);
                TableUtil.assertSelectedColExist(in.getColNames(), labelColName, predictionColName);
                localSummaries = calLabelPredictionLocal(
                    in.select(new String[] {labelColName, predictionColName}).getDataSet(),
                    positiveLabel, labelType);
                break;
            }
            case PRED_DETAIL: {
                String detailColName = get(EvalMultiClassParams.PREDICTION_DETAIL_COL);
                TableUtil.assertSelectedColExist(in.getColNames(), labelColName, detailColName);
                localSummaries = calLabelPredDetailLocal(
                    in.select(new String[] {labelColName, detailColName}).getDataSet(),
                    positiveLabel, labelType);
                break;
            }
            default:
                throw new AkUnsupportedOperationException("Unsupported evaluation type: " + evalType);
        }

        setOutput(
            localSummaries
                .reduce(new EvaluationUtil.ReduceBaseMetrics())
                .flatMap(new EvaluationUtil.SaveDataAsParams()),
            new String[] {DATA_OUTPUT},
            new TypeInformation<?>[] {Types.STRING});
        return this;
    }

    /**
     * Computes per-partition summaries for (label, prediction) rows.
     *
     * <p>The label index is built from every label and prediction value of the rows that pass
     * {@link EvaluationUtil#checkRowFieldNotNull}. It is broadcast to each partition.
     */
    private static DataSet<BaseMetricsSummary> calLabelPredictionLocal(
        DataSet<Row> data, String positiveLabel, TypeInformation<?> labelType) {
        DataSet<?> labelIndex = data
            .flatMap(new FlatMapFunction<Row, Object>() {
                private static final long serialVersionUID = -120689740292597906L;

                @Override
                public void flatMap(Row row, Collector<Object> out) {
                    if (EvaluationUtil.checkRowFieldNotNull(row)) {
                        out.collect(row.getField(0));
                        out.collect(row.getField(1));
                    }
                }
            })
            .reduceGroup(new EvaluationUtil.DistinctLabelIndexMap(false, positiveLabel, labelType, false));

        return data
            .rebalance()
            .mapPartition(new CalLabelPredictionLocal())
            .withBroadcastSet(labelIndex, LABELS);
    }

    /**
     * Computes per-partition summaries for (label, predictionDetail) rows.
     *
     * <p>The label index is built from the label value plus every key of the parsed prediction-detail map,
     * taken from the rows that pass {@link EvaluationUtil#checkRowFieldNotNull}. It is broadcast to each
     * partition.
     */
    private static DataSet<BaseMetricsSummary> calLabelPredDetailLocal(
        DataSet<Row> data, String positiveLabel, final TypeInformation<?> labelType) {
        DataSet<?> labelIndex = data
            .flatMap(new FlatMapFunction<Row, Object>() {
                private static final long serialVersionUID = 7858786264569432008L;

                @Override
                public void flatMap(Row row, Collector<Object> out) {
                    if (EvaluationUtil.checkRowFieldNotNull(row)) {
                        EvaluationUtil.extractLabelProbMap(row, labelType).keySet().forEach(out::collect);
                        out.collect(row.getField(0));
                    }
                }
            })
            .reduceGroup(new EvaluationUtil.DistinctLabelIndexMap(false, positiveLabel, labelType, false));

        return data
            .rebalance()
            .mapPartition(new CalLabelDetailLocal(labelType))
            .withBroadcastSet(labelIndex, LABELS);
    }
}
