package com.alibaba.alink.operator.local.evaluation;

import com.alibaba.alink.common.MTable;
import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.exceptions.AkIllegalOperatorParameterException;
import com.alibaba.alink.common.exceptions.AkPreconditions;
import com.alibaba.alink.common.exceptions.ExceptionWithErrorCode;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.common.evaluation.EvaluationUtil;
import com.alibaba.alink.operator.common.evaluation.MultiLabelMetrics;
import com.alibaba.alink.operator.local.LocalOperator;
import com.alibaba.alink.params.evaluation.EvalMultiLabelParams;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

/**
 * Local operator that evaluates multi-label classification results.
 *
 * <p>It reads the label column and the prediction column of its (single) input and computes
 * {@link MultiLabelMetrics}. The serialized metrics are emitted as a one-row table with a single
 * STRING column {@code multilabel_eval_result}.
 */
@InputPorts(values = {@PortSpec(PortType.DATA)})
@OutputPorts(values = {@PortSpec(PortType.EVAL_METRICS)})
@NameCn("多标签分类评估")
public class EvalMultiLabelLocalOp extends LocalOperator<EvalMultiLabelLocalOp>
    implements EvalMultiLabelParams<EvalMultiLabelLocalOp>,
    EvaluationMetricsCollector<MultiLabelMetrics, EvalMultiLabelLocalOp> {

    /** Public constant with the value {@code "labels"} (not referenced inside this class). */
    public static final String LABELS = "labels";

    /** Metrics computed by the last {@link #linkFrom} call; {@code null} before that. */
    private MultiLabelMetrics metrics;

    public EvalMultiLabelLocalOp() {
        super(null);
    }

    public EvalMultiLabelLocalOp(Params params) {
        super(params);
    }

    /**
     * Computes multi-label evaluation metrics from the first input operator.
     *
     * @param inputs input operators; only the first one is used (via {@code checkAndGetFirst})
     * @return this operator, whose output table holds the serialized metrics
     * @throws AkIllegalOperatorParameterException if the label or prediction column is missing
     */
    @Override
    public EvalMultiLabelLocalOp linkFrom(LocalOperator<?>... inputs) {
        LocalOperator<?> in = checkAndGetFirst(inputs);
        String labelCol = getLabelCol();
        String predictionCol = getPredictionCol();
        AkPreconditions.checkArgument(
            TableUtil.findColIndex(in.getColNames(), labelCol) >= 0
                && TableUtil.findColIndex(in.getColNames(), predictionCol) >= 0,
            (ExceptionWithErrorCode) new AkIllegalOperatorParameterException(
                "Can not find label column or prediction column!"));

        // Project once and reuse: field 0 is the label, field 1 is the prediction. The helpers
        // below parse field 0 with the label ranking info and field 1 with the prediction one.
        List<Row> rows = in.getOutputTable().select(labelCol, predictionCol).getRows();
        String labelRankingInfo = getLabelRankingInfo();
        String predictionRankingInfo = getPredictionRankingInfo();

        // The label column must be parsed with the label ranking info (previously the prediction
        // ranking info was passed for both columns, inconsistent with getMultiLabelMetrics below).
        Tuple3<Integer, Class, Integer> labelNumberAndMaxK =
            getLabelNumberAndMaxK(rows, labelRankingInfo, predictionRankingInfo);
        this.metrics = EvaluationUtil.getMultiLabelMetrics(
            rows, labelNumberAndMaxK, labelRankingInfo, predictionRankingInfo).toMetrics();

        setOutputTable(new MTable(
            new Row[] {this.metrics.serialize()},
            new TableSchema(new String[] {"multilabel_eval_result"}, new TypeInformation<?>[] {Types.STRING})));
        return this;
    }

    /**
     * Rebuilds metrics from collected rows; only the first row is used.
     *
     * @param list rows holding serialized metrics; must be non-empty
     * @return the deserialized metrics
     */
    @Override
    public MultiLabelMetrics createMetrics(List<Row> list) {
        AkPreconditions.checkArgument(list != null && !list.isEmpty(), "Metrics rows must not be empty!");
        return new MultiLabelMetrics(list.get(0));
    }

    /**
     * Returns the metrics computed by the last {@link #linkFrom} call, or {@code null} if
     * {@code linkFrom} has not been called yet.
     */
    @Override
    public MultiLabelMetrics collectMetrics() {
        return this.metrics;
    }

    /**
     * Extracts the distinct labels of one (label, prediction) row.
     *
     * <p>The label field (0) and prediction field (1) are parsed with
     * {@code EvaluationUtil.extractDistinctLabel}. The element class is taken from the first
     * element of each parsed list. If both sides are non-empty and their classes differ, every
     * label is converted to its {@code toString()} form and the class becomes {@code String}.
     *
     * @param row a two-field row: label first, prediction second
     * @param labelRankingInfo parsing info passed to {@code extractDistinctLabel} for field 0
     * @param predictionRankingInfo parsing info passed to {@code extractDistinctLabel} for field 1
     * @return (distinct label set, label class or {@code null} if nothing was parsed,
     *     max number of labels on either side). For a row rejected by
     *     {@code EvaluationUtil.checkRowFieldNotNull}, returns (empty set, {@code null}, 0).
     */
    public static Tuple3<HashSet<Object>, Class, Integer> subGetLabelNumberAndMaxK(
            Row row, String labelRankingInfo, String predictionRankingInfo) {
        HashSet<Object> labelSet = new HashSet<>();
        // NOTE(review): presumably rejects rows containing null fields — confirm in EvaluationUtil.
        if (!EvaluationUtil.checkRowFieldNotNull(row)) {
            return Tuple3.of(labelSet, null, 0);
        }
        List<Object> labels = EvaluationUtil.extractDistinctLabel((String) row.getField(0), labelRankingInfo);
        List<Object> predictions = EvaluationUtil.extractDistinctLabel((String) row.getField(1), predictionRankingInfo);
        Class<?> labelClass = labels.isEmpty() ? null : labels.get(0).getClass();
        Class<?> predictionClass = predictions.isEmpty() ? null : predictions.get(0).getClass();

        Class<?> resultClass;
        if (labelClass == null) {
            resultClass = predictionClass;
            labelSet.addAll(predictions);
        } else if (predictionClass == null) {
            resultClass = labelClass;
            labelSet.addAll(labels);
        } else if (labelClass.equals(predictionClass)) {
            resultClass = labelClass;
            labelSet.addAll(labels);
            labelSet.addAll(predictions);
        } else {
            // Incompatible element types: fall back to comparing labels by their string form.
            resultClass = String.class;
            addAllAsStrings(labelSet, labels);
            addAllAsStrings(labelSet, predictions);
        }
        return Tuple3.of(labelSet, resultClass, Math.max(labels.size(), predictions.size()));
    }

    /**
     * Folds {@link #subGetLabelNumberAndMaxK} over all rows.
     *
     * <p>Rows with a {@code null} label class must contribute no labels. When two partial results
     * have different label classes, the union is converted to strings with class {@code String}.
     *
     * @param list two-field rows: label first, prediction second
     * @param labelRankingInfo parsing info for the label field
     * @param predictionRankingInfo parsing info for the prediction field
     * @return (number of distinct labels, label class, max number of labels in any row)
     * @throws RuntimeException via {@code AkPreconditions.checkState} if no valid label was found,
     *     including when {@code list} is empty
     */
    public static Tuple3<Integer, Class, Integer> getLabelNumberAndMaxK(
            List<Row> list, String labelRankingInfo, String predictionRankingInfo) {
        // Start from an empty accumulator so an empty input reaches the "no valid data" check
        // instead of failing with IndexOutOfBoundsException.
        HashSet<Object> emptySet = new HashSet<>();
        Tuple3<HashSet<Object>, Class, Integer> merged = Tuple3.of(emptySet, null, 0);
        for (Row row : list) {
            Tuple3<HashSet<Object>, Class, Integer> current =
                subGetLabelNumberAndMaxK(row, labelRankingInfo, predictionRankingInfo);
            if (merged.f1 == null) {
                AkPreconditions.checkArgument(merged.f0.isEmpty() && merged.f2 == 0,
                    "LabelClass is null but label size is not 0!");
                merged = current;
            } else if (current.f1 == null) {
                AkPreconditions.checkArgument(current.f0.isEmpty() && current.f2 == 0,
                    "LabelClass is null but label size is not 0!");
            } else if (merged.f1.equals(current.f1)) {
                merged.f0.addAll(current.f0);
                merged.f2 = Math.max(merged.f2, current.f2);
            } else {
                HashSet<Object> stringSet = new HashSet<>();
                addAllAsStrings(stringSet, merged.f0);
                addAllAsStrings(stringSet, current.f0);
                merged = Tuple3.of(stringSet, String.class, Math.max(merged.f2, current.f2));
            }
        }
        AkPreconditions.checkState(!merged.f0.isEmpty(),
            "There is no valid data in the whole dataSet, please check the input for evaluation!");
        return Tuple3.of(merged.f0.size(), merged.f1, merged.f2);
    }

    /** Adds the {@code toString()} form of every element of {@code source} to {@code target}. */
    private static void addAllAsStrings(HashSet<Object> target, Collection<?> source) {
        for (Object element : source) {
            target.add(element.toString());
        }
    }
}
