package com.alibaba.alink.pipeline.classification;

import com.alibaba.alink.common.MLEnvironmentFactory;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.model.ModelParamName;
import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.source.DataSetWrapperBatchOp;
import com.alibaba.alink.operator.batch.sql.UnionAllBatchOp;
import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil;
import com.alibaba.alink.operator.common.io.types.FlinkTypeConverter;
import com.alibaba.alink.operator.common.io.types.JdbcTypeConverter;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.params.classification.OneVsRestPredictParams;
import com.alibaba.alink.params.classification.OneVsRestTrainParams;
import com.alibaba.alink.params.shared.colname.HasLabelCol;
import com.alibaba.alink.params.shared.linear.HasPositiveLabelValueString;
import com.alibaba.alink.pipeline.EstimatorBase;
import com.alibaba.alink.pipeline.ModelBase;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.operators.Operator;
import org.apache.flink.api.java.operators.UnionOperator;
import org.apache.flink.api.java.tuple.Tuple1;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.Table;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/**
 * One-vs-rest multi-class classification estimator.
 *
 * <p>For {@code numClass} classes it trains {@code numClass} copies of the configured binary
 * classifier. The i-th copy is trained on the input with the label column rewritten to 1.0 for the
 * i-th smallest distinct label (by natural ordering) and to the negative label value for every
 * other row. The produced model data is one table that concatenates, tagged by a leading
 * {@code table_id} column: the model meta (JSON), all binary models (tagged with {@code ovr_id}),
 * and the distinct labels.
 */
@NameCn("OneVsRest")
public class OneVsRest extends EstimatorBase<OneVsRest, OneVsRestModel> implements OneVsRestTrainParams<OneVsRest>, OneVsRestPredictParams<OneVsRest> {
    private static final long serialVersionUID = -5340633471006011434L;

    /** Label value written for rows that belong to the class being trained. */
    private static final double POSITIVE_LABEL = 1.0;

    /**
     * Label value written for all other rows.
     * NOTE(review): the decompiled source referenced Criteria.INVALID_GAIN here; this is presumably
     * the literal 0.0 inlined by the compiler — confirm before replacing it with a literal.
     */
    private static final double NEGATIVE_LABEL = Criteria.INVALID_GAIN;

    /** Binary classifier that is re-trained once per class. */
    private EstimatorBase<?, ?> classifier;

    /**
     * Trains one binary model per class and packs them into a {@link OneVsRestModel}.
     *
     * @param input training data; must contain the classifier's label column
     * @return the one-vs-rest model wrapping all binary models and the label metadata
     * @throws IllegalStateException if no classifier was set via {@link #setClassifier}
     * @throws IllegalArgumentException if {@code numClass} is not positive
     */
    @Override
    public OneVsRestModel fit(BatchOperator<?> input) {
        if (classifier == null) {
            throw new IllegalStateException("No binary classifier set; call setClassifier() before fit().");
        }
        String labelColName = classifier.getParams().get(HasLabelCol.LABEL_COL);
        BatchOperator<?> allLabels = getAllLabels(input, labelColName);
        int numClasses = getNumClass();
        if (numClasses <= 0) {
            throw new IllegalArgumentException("numClass must be a positive integer, but got " + numClasses);
        }
        int labelColIdx = TableUtil.findColIndexWithAssertAndHint(input.getColNames(), labelColName);
        TypeInformation<?> labelType = input.getColTypes()[labelColIdx];

        // Every binary sub-problem encodes "this class" as 1.0, so the positive label is identical
        // for all of them and only needs to be configured once.
        classifier.set(HasPositiveLabelValueString.POS_LABEL_VAL_STR, "1");
        ModelBase<?>[] models = new ModelBase<?>[numClasses];
        for (int classIdx = 0; classIdx < numClasses; classIdx++) {
            models[classIdx] = classifier.fit(generateTrainData(input, allLabels, classIdx, labelColIdx));
        }

        Table modelsTable = unionAllModels(models);
        Params meta = new Params()
            .set(ModelParamName.NUM_CLASSES, numClasses)
            .set(ModelParamName.BIN_CLS_CLASS_NAME, classifier.getClass().getCanonicalName())
            .set(ModelParamName.BIN_CLS_PARAMS, classifier.getParams().toJson())
            .set(ModelParamName.LABEL_TYPE_NAME, FlinkTypeConverter.getTypeString(labelType))
            .set(ModelParamName.MODEL_COL_NAMES, models[0].getModelData().getSchema().getFieldNames())
            .set(ModelParamName.MODEL_COL_TYPES,
                toJdbcColTypes(models[0].getModelData().getSchema().getFieldTypes()));
        Table metaTable = createModelMeta(meta, allLabels);

        OneVsRestModel model = new OneVsRestModel(classifier.getParams().clone().merge(getParams()));
        Table[] parts = new Table[] {metaTable, modelsTable, allLabels.getOutputTable()};
        model.setModelData(BatchOperator.fromTable(concatTables(parts, getMLEnvironmentId())));
        return model;
    }

    /**
     * Sets the binary classifier that is trained once per class.
     *
     * @param estimatorBase the binary classifier estimator
     * @return this estimator, for chaining
     */
    public OneVsRest setClassifier(EstimatorBase<?, ?> estimatorBase) {
        this.classifier = estimatorBase;
        return this;
    }

    /**
     * Concatenates tables with different schemas into one wide table.
     *
     * <p>The output starts with a {@code table_id} (LONG) column holding the index of the source table,
     * followed by every column of every non-null table, renamed to {@code t<index>_<name>}. Each row of
     * table i fills only its own column range; all other cells are null. Null entries in
     * {@code tables} contribute no columns and no rows.
     *
     * @param tables tables to concatenate; entries may be null
     * @param mlEnvId ML environment id used to read the tables and build the result
     * @return the concatenated table, or null if no table contributes any column
     */
    private static Table concatTables(Table[] tables, Long mlEnvId) {
        final int[] numCols = new int[tables.length];
        List<String> colNames = new ArrayList<>();
        List<TypeInformation<?>> colTypes = new ArrayList<>();
        colNames.add("table_id");
        colTypes.add(Types.LONG);
        for (int t = 0; t < tables.length; t++) {
            if (tables[t] == null) {
                numCols[t] = 0;
                continue;
            }
            String[] names = tables[t].getSchema().getFieldNames();
            numCols[t] = names.length;
            for (String name : names) {
                colNames.add(String.format("t%d_%s", t, name));
            }
            colTypes.addAll(Arrays.asList(tables[t].getSchema().getFieldTypes()));
        }
        if (colNames.size() == 1) {
            return null;
        }

        final int totalCols = colNames.size();
        DataSet<Row> unioned = null;
        int offset = 1;
        for (int t = 0; t < tables.length; t++) {
            if (tables[t] == null) {
                continue;
            }
            final long tableId = t;
            final int colOffset = offset;
            final int width = numCols[t];
            DataSet<Row> widened = ((BatchOperator<?>) BatchOperator.fromTable(tables[t]).setMLEnvironmentId(mlEnvId))
                .getDataSet()
                .map(new RichMapFunction<Row, Row>() {
                    private static final long serialVersionUID = -8085823678072944808L;
                    transient Row reused;

                    @Override
                    public void open(Configuration parameters) {
                        this.reused = new Row(totalCols);
                    }

                    @Override
                    public Row map(Row row) {
                        // The output row is reused, so every cell must be cleared first.
                        for (int c = 0; c < totalCols; c++) {
                            this.reused.setField(c, null);
                        }
                        this.reused.setField(0, tableId);
                        for (int c = 0; c < width; c++) {
                            this.reused.setField(colOffset + c, row.getField(c));
                        }
                        return this.reused;
                    }
                });
            unioned = (unioned == null) ? widened : unioned.union(widened);
            offset += width;
        }
        return DataSetConversionUtil.toTable(mlEnvId, unioned,
            colNames.toArray(new String[0]), colTypes.toArray(new TypeInformation<?>[0]));
    }

    /**
     * Returns the distinct labels (column 0 of each row) sorted by their natural ordering.
     *
     * @param labelRows rows whose first field is a label; labels are assumed mutually Comparable
     * @return a new list of the labels in ascending order
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private static ArrayList<Object> sortedLabels(List<Row> labelRows) {
        ArrayList<Object> labels = new ArrayList<>(labelRows.size());
        for (Row row : labelRows) {
            labels.add(row.getField(0));
        }
        labels.sort((a, b) -> ((Comparable) a).compareTo(b));
        return labels;
    }

    /**
     * Builds the binary training set for one class.
     *
     * <p>The label column is replaced by {@link #POSITIVE_LABEL} where it equals the
     * {@code classIdx}-th smallest distinct label, and by {@link #NEGATIVE_LABEL} otherwise. Its type
     * becomes DOUBLE; all other columns are passed through unchanged.
     *
     * @param input original training data
     * @param allLabels one-column operator holding the distinct labels (broadcast as "allLabels")
     * @param classIdx index of the positive class in the sorted distinct labels
     * @param labelColIdx index of the label column in {@code input}
     * @return the relabelled training data
     */
    private static BatchOperator<?> generateTrainData(BatchOperator<?> input, BatchOperator<?> allLabels,
                                                      final int classIdx, final int labelColIdx) {
        DataSet<Row> trainRows = input.getDataSet()
            .map(new RichMapFunction<Row, Row>() {
                private static final long serialVersionUID = 6739243159349750842L;
                transient Object label;

                @Override
                public void open(Configuration parameters) throws Exception {
                    List<Row> labelRows = getRuntimeContext().getBroadcastVariable("allLabels");
                    if (classIdx >= labelRows.size()) {
                        throw new RuntimeException("the specified numClasses is larger than the number of distinct labels.: "
                            + String.format("iLabel = %d, num lables = %d", classIdx, labelRows.size()));
                    }
                    this.label = sortedLabels(labelRows).get(classIdx);
                }

                @Override
                public Row map(Row row) {
                    boolean isPositive = row.getField(labelColIdx).equals(this.label);
                    row.setField(labelColIdx, isPositive ? POSITIVE_LABEL : NEGATIVE_LABEL);
                    return row;
                }
            })
            .withBroadcastSet(allLabels.getDataSet(), "allLabels")
            .name("CreateTrainData#" + classIdx);
        TypeInformation<?>[] colTypes = input.getColTypes().clone();
        colTypes[labelColIdx] = Types.DOUBLE;
        return new DataSetWrapperBatchOp(trainRows, input.getColNames(), colTypes)
            .setMLEnvironmentId(input.getMLEnvironmentId());
    }

    /**
     * Builds the single-row, single-column ("meta", STRING) table that holds the model meta as JSON.
     *
     * <p>At runtime the sorted distinct labels are added to {@code meta} under
     * {@code ModelParamName.LABELS}, serialized with Gson.
     *
     * @param meta model meta parameters collected at fit time
     * @param allLabels one-column operator holding the distinct labels (broadcast as "allLabels")
     * @return the meta table
     */
    private Table createModelMeta(Params meta, BatchOperator<?> allLabels) {
        DataSet<Row> metaRows = MLEnvironmentFactory.get(getMLEnvironmentId())
            .getExecutionEnvironment()
            .fromElements(meta.toJson())
            .map(new RichMapFunction<String, String>() {
                private static final long serialVersionUID = 6749489554360703883L;

                @Override
                public String map(String json) throws Exception {
                    Params withLabels = Params.fromJson(json);
                    List<Row> labelRows = getRuntimeContext().getBroadcastVariable("allLabels");
                    ArrayList<Object> labels = sortedLabels(labelRows);
                    withLabels.set(ModelParamName.LABELS, JsonConverter.gson.toJson(labels, ArrayList.class));
                    return withLabels.toJson();
                }
            })
            .withBroadcastSet(allLabels.getDataSet(), "allLabels")
            .map(new MapFunction<String, Row>() {
                private static final long serialVersionUID = 7457969876114448730L;

                @Override
                public Row map(String json) throws Exception {
                    return Row.of(json);
                }
            });
        return DataSetConversionUtil.toTable(getMLEnvironmentId(), metaRows,
            new String[] {"meta"}, new TypeInformation<?>[] {Types.STRING});
    }

    /**
     * Unions the model data of all binary models, prefixing each row with its model index as a
     * BIGINT column named {@code ovr_id}.
     *
     * @param models the binary models, in class order; must be non-empty
     * @return the unioned model table
     * @throws IllegalArgumentException if {@code models} is empty
     */
    private Table unionAllModels(ModelBase<?>[] models) {
        if (models.length == 0) {
            throw new IllegalArgumentException("At least one binary model is required.");
        }
        BatchOperator<?> unioned = null;
        for (int i = 0; i < models.length; i++) {
            BatchOperator<?> tagged = models[i].getModelData()
                .select(String.format("CAST(%d as bigint) AS ovr_id, *", i));
            unioned = (i == 0)
                ? tagged
                : new UnionAllBatchOp().setMLEnvironmentId(getMLEnvironmentId()).linkFrom(unioned, tagged);
        }
        return unioned.getOutputTable();
    }

    /**
     * Maps Flink column types to their JDBC integer SQL type codes.
     *
     * @param colTypes Flink column types
     * @return JDBC type codes, one per input type
     */
    private static Integer[] toJdbcColTypes(TypeInformation<?>[] colTypes) {
        Integer[] sqlTypes = new Integer[colTypes.length];
        for (int i = 0; i < colTypes.length; i++) {
            sqlTypes[i] = JdbcTypeConverter.getIntegerSqlType(colTypes[i]);
        }
        return sqlTypes;
    }

    /**
     * Collects the distinct values of the label column.
     *
     * <p>Each partition first de-duplicates locally, then a global group-by on the value keeps one
     * row per distinct label.
     *
     * @param input training data
     * @param labelColName name of the label column
     * @return one-column operator (named {@code labelColName}, with the label's type) of distinct labels
     */
    @SuppressWarnings("rawtypes")
    private static BatchOperator<?> getAllLabels(BatchOperator<?> input, String labelColName) {
        final int labelColIdx = TableUtil.findColIndexWithAssertAndHint(input.getColNames(), labelColName);
        DataSet<Row> distinctLabels = input.getDataSet()
            .mapPartition(new MapPartitionFunction<Row, Tuple1<Comparable>>() {
                private static final long serialVersionUID = 8885186467183165174L;

                @Override
                public void mapPartition(Iterable<Row> rows, Collector<Tuple1<Comparable>> out) {
                    HashSet<Object> seen = new HashSet<>();
                    for (Row row : rows) {
                        seen.add(row.getField(labelColIdx));
                    }
                    for (Object label : seen) {
                        out.collect(Tuple1.of((Comparable) label));
                    }
                }
            })
            .groupBy(0)
            .reduceGroup(new GroupReduceFunction<Tuple1<Comparable>, Row>() {
                private static final long serialVersionUID = -3338967361764454088L;

                @Override
                public void reduce(Iterable<Tuple1<Comparable>> group, Collector<Row> out) throws Exception {
                    out.collect(Row.of(group.iterator().next().f0));
                }
            });
        return new DataSetWrapperBatchOp(distinctLabels, new String[] {labelColName},
            new TypeInformation<?>[] {input.getColTypes()[labelColIdx]})
            .setMLEnvironmentId(input.getMLEnvironmentId());
    }
}
