package com.alibaba.alink.operator.stream.evaluation;

import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.Internal;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.exceptions.AkIllegalOperatorParameterException;
import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.common.evaluation.BaseMetricsSummary;
import com.alibaba.alink.operator.common.evaluation.ClassificationEvaluationUtil;
import com.alibaba.alink.operator.common.evaluation.EvaluationUtil;
import com.alibaba.alink.operator.stream.StreamOperator;
import com.alibaba.alink.operator.stream.evaluation.BaseEvalClassStreamOp;
import com.alibaba.alink.operator.stream.utils.TimeUtil;
import com.alibaba.alink.params.evaluation.EvalBinaryClassParams;
import com.alibaba.alink.params.evaluation.EvalBinaryClassStreamParams;
import com.alibaba.alink.params.evaluation.EvalMultiClassStreamParams;
import java.util.HashSet;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
import org.apache.flink.streaming.api.windowing.windows.Window;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/**
 * Base stream operator that evaluates classification results over tumbling time windows
 * created with {@code timeWindowAll}.
 *
 * <p>Depending on {@code ClassificationEvaluationUtil.judgeEvaluationType(getParams())}, the
 * operator reads either (label, prediction) columns or (label, prediction detail) columns.
 * For every window it emits one metrics summary tagged with
 * {@code ClassificationEvaluationUtil.WINDOW.f0}. The same summaries are also passed through
 * {@code EvaluationUtil.AllDataMerge} on a single parallel instance and emitted tagged with
 * {@code ClassificationEvaluationUtil.ALL.f0}. The output has two string columns:
 * {@code ClassificationEvaluationUtil.STATISTICS_OUTPUT} and {@code "Data"}.
 *
 * @param <T> the concrete operator type, returned from {@link #linkFrom} for fluent chaining
 */
@InputPorts(values = {@PortSpec(PortType.DATA)})
@OutputPorts(values = {@PortSpec(PortType.EVAL_METRICS)})
@Internal
public class BaseEvalClassStreamOp<T extends BaseEvalClassStreamOp<T>> extends StreamOperator<T> {
    private static final String DATA_OUTPUT = "Data";
    private static final long serialVersionUID = -6277527784116345678L;

    /** True for binary evaluation, false for multi-class evaluation. */
    private final boolean binary;

    /**
     * Window function that turns (label, prediction) rows into a multi-class metrics summary.
     *
     * <p>NOTE(review): the decompiler reports this class was package-private in the original
     * source; it is kept public here so existing callers do not break.
     */
    public static class LabelPredictionWindow implements AllWindowFunction<Row, BaseMetricsSummary, TimeWindow> {
        private static final long serialVersionUID = -4426213828656690161L;
        private final boolean binary;
        private final String positiveValue;
        private final TypeInformation<?> labelType;

        /**
         * @param binary        whether the label array is built for binary evaluation
         * @param positiveValue the positive label value (string form), forwarded to
         *                      {@code buildLabelIndexLabelArray}
         * @param labelType     type of the label column
         */
        LabelPredictionWindow(boolean binary, String positiveValue, TypeInformation<?> labelType) {
            this.binary = binary;
            this.positiveValue = positiveValue;
            this.labelType = labelType;
        }

        /**
         * Collects the label space of the window and emits one metrics summary.
         * Nothing is emitted if no row passes {@code EvaluationUtil.checkRowFieldNotNull}.
         */
        @Override
        public void apply(TimeWindow window, Iterable<Row> rows, Collector<BaseMetricsSummary> out) throws Exception {
            HashSet<Object> labels = new HashSet<>();
            for (Row row : rows) {
                // presumably skips rows with a null label or prediction -- see checkRowFieldNotNull
                if (EvaluationUtil.checkRowFieldNotNull(row)) {
                    // Field 0 is the true label, field 1 the predicted label; both belong to the label space.
                    labels.add(row.getField(0));
                    labels.add(row.getField(1));
                }
            }
            if (labels.isEmpty()) {
                return;
            }
            out.collect(EvaluationUtil.getMultiClassMetrics(rows,
                ClassificationEvaluationUtil.buildLabelIndexLabelArray(labels, binary, positiveValue, labelType, true)));
        }
    }

    /**
     * Window function that turns (label, prediction detail) rows into a detailed metrics summary.
     *
     * <p>NOTE(review): the decompiler reports this class was package-private in the original
     * source; it is kept public here so existing callers do not break.
     */
    public static class PredDetailLabel implements AllWindowFunction<Row, BaseMetricsSummary, TimeWindow> {
        private static final long serialVersionUID = 8057305974098408321L;
        private final String positiveValue;
        private final boolean binary;
        private final TypeInformation<?> labelType;

        /**
         * @param positiveValue the positive label value (string form), forwarded to
         *                      {@code buildLabelIndexLabelArray}
         * @param binary        whether this is a binary evaluation
         * @param labelType     type of the label column, used when parsing prediction details
         */
        PredDetailLabel(String positiveValue, boolean binary, TypeInformation<?> labelType) {
            this.positiveValue = positiveValue;
            this.binary = binary;
            this.labelType = labelType;
        }

        /**
         * Collects the label space of the window and emits one detail-statistics summary.
         * Nothing is emitted if no row passes {@code EvaluationUtil.checkRowFieldNotNull}.
         */
        @Override
        public void apply(TimeWindow window, Iterable<Row> rows, Collector<BaseMetricsSummary> out) throws Exception {
            HashSet<Object> labels = new HashSet<>();
            for (Row row : rows) {
                if (EvaluationUtil.checkRowFieldNotNull(row)) {
                    // Label space = keys of the parsed prediction detail plus the true label (field 0).
                    labels.addAll(EvaluationUtil.extractLabelProbMap(row, labelType).keySet());
                    labels.add(row.getField(0));
                }
            }
            if (labels.isEmpty()) {
                return;
            }
            out.collect(EvaluationUtil.getDetailStatistics(rows, binary,
                ClassificationEvaluationUtil.buildLabelIndexLabelArray(labels, binary, positiveValue, labelType, true),
                labelType));
        }
    }

    /**
     * @param params operator parameters
     * @param binary true for binary evaluation, false for multi-class evaluation
     */
    public BaseEvalClassStreamOp(Params params, boolean binary) {
        super(params);
        this.binary = binary;
    }

    /**
     * Builds the windowed evaluation pipeline on the first input.
     *
     * @throws AkIllegalOperatorParameterException if binary evaluation is requested without a
     *                                             prediction detail column
     * @throws AkUnsupportedOperationException     if the evaluation type is neither
     *                                             PRED_RESULT nor PRED_DETAIL
     */
    @Override
    @SuppressWarnings("unchecked")
    public T linkFrom(StreamOperator<?>... inputs) {
        StreamOperator<?> in = checkAndGetFirst(inputs);
        String labelColName = get(EvalMultiClassStreamParams.LABEL_COL);
        TypeInformation<?> labelType = TableUtil.findColTypeWithAssertAndHint(in.getSchema(), labelColName);
        String positiveValue = get(EvalBinaryClassStreamParams.POS_LABEL_VAL_STR);
        double timeInterval = get(EvalMultiClassStreamParams.TIME_INTERVAL);

        if (binary && !getParams().contains(EvalBinaryClassParams.PREDICTION_DETAIL_COL)) {
            throw new AkIllegalOperatorParameterException("Binary Evaluation must give predictionDetailCol!");
        }

        ClassificationEvaluationUtil.Type type = ClassificationEvaluationUtil.judgeEvaluationType(getParams());
        DataStream<BaseMetricsSummary> statistics;
        switch (type) {
            case PRED_RESULT: {
                String predictionColName = get(EvalMultiClassStreamParams.PREDICTION_COL);
                TableUtil.assertSelectedColExist(in.getColNames(), labelColName, predictionColName);
                statistics = in.select(new String[] {labelColName, predictionColName})
                    .getDataStream()
                    .timeWindowAll(TimeUtil.convertTime(timeInterval))
                    .apply(new LabelPredictionWindow(binary, positiveValue, labelType));
                break;
            }
            case PRED_DETAIL: {
                String predDetailColName = get(EvalMultiClassStreamParams.PREDICTION_DETAIL_COL);
                TableUtil.assertSelectedColExist(in.getColNames(), labelColName, predDetailColName);
                statistics = in.select(new String[] {labelColName, predDetailColName})
                    .getDataStream()
                    .timeWindowAll(TimeUtil.convertTime(timeInterval))
                    .apply(new PredDetailLabel(positiveValue, binary, labelType));
                break;
            }
            default:
                throw new AkUnsupportedOperationException("Unsupported evaluation type: " + type);
        }

        // Per-window metrics, tagged as "window" output.
        DataStream<Row> windowOutput = statistics
            .map(new EvaluationUtil.prependTagMapFunction((String) ClassificationEvaluationUtil.WINDOW.f0));
        // Merged metrics. AllDataMerge runs at parallelism 1 so that a single instance sees every
        // window summary; presumably it keeps a running total -- confirm in EvaluationUtil.
        DataStream<Row> allOutput = statistics
            .map(new EvaluationUtil.AllDataMerge())
            .setParallelism(1)
            .map(new EvaluationUtil.prependTagMapFunction((String) ClassificationEvaluationUtil.ALL.f0));

        setOutput(windowOutput.union(allOutput),
            new String[] {ClassificationEvaluationUtil.STATISTICS_OUTPUT, DATA_OUTPUT},
            new TypeInformation<?>[] {Types.STRING, Types.STRING});
        // Self-bounded type parameter: by convention T is the concrete subclass of this operator.
        return (T) this;
    }
}
