package com.alibaba.alink.operator.batch.feature;

import com.alibaba.alink.common.MLEnvironmentFactory;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpecs;
import com.alibaba.alink.common.annotation.ReservedColsWithFirstInputSpec;
import com.alibaba.alink.common.exceptions.AkIllegalArgumentException;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.mapper.PipelineModelMapper;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.feature.BaseCrossTrainBatchOp;
import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil;
import com.alibaba.alink.operator.common.dataproc.FirstReducer;
import com.alibaba.alink.operator.common.dataproc.MultiStringIndexerModelData;
import com.alibaba.alink.operator.common.feature.OneHotModelDataConverter;
import com.alibaba.alink.operator.common.feature.OneHotModelMapper;
import com.alibaba.alink.operator.common.recommendation.KObjectUtil;
import com.alibaba.alink.operator.common.slidingwindow.SessionSharedData;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.params.classification.RandomForestTrainParams;
import com.alibaba.alink.params.feature.AutoCrossTrainParams;
import com.alibaba.alink.params.feature.HasDropLast;
import com.alibaba.alink.params.shared.colname.HasOutputColsDefaultAsNull;
import com.alibaba.alink.params.shared.colname.HasReservedColsDefaultAsNull;
import com.alibaba.alink.pipeline.PipelineModel;
import com.alibaba.alink.pipeline.TransformerBase;
import com.alibaba.alink.pipeline.feature.AutoCrossAlgoModel;
import com.alibaba.alink.pipeline.feature.OneHotEncoderModel;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.operators.MapOperator;
import org.apache.flink.api.java.operators.SingleInputUdfOperator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

@NameCn("")
@ReservedColsWithFirstInputSpec
@ParamSelectColumnSpecs({@ParamSelectColumnSpec(name = "selectedCols"), @ParamSelectColumnSpec(name = "labelCol")})
/* loaded from: input_file:com/alibaba/alink/operator/batch/feature/BaseCrossTrainBatchOp.class */
public abstract class BaseCrossTrainBatchOp<T extends BaseCrossTrainBatchOp<T>> extends BatchOperator<T> implements AutoCrossTrainParams<T> {
    static final String oneHotVectorCol = "oneHotVectorCol";
    boolean hasDiscrete;

    /* loaded from: input_file:com/alibaba/alink/operator/batch/feature/BaseCrossTrainBatchOp$BuildFeatureSize.class */
    public static class BuildFeatureSize extends RichMapFunction<int[], int[]> {
        private static final long serialVersionUID = 873642749154257046L;
        private final int additionalSize;
        private int[] featureSize;

        BuildFeatureSize(boolean z) {
            this.additionalSize = z ? 2 : 1;
        }

        public void open(Configuration configuration) throws Exception {
            super.open(configuration);
            MultiStringIndexerModelData multiStringIndexerModelData = new OneHotModelDataConverter().load(getRuntimeContext().getBroadcastVariable("oneHotModel")).modelData;
            int size = multiStringIndexerModelData.tokenNumber.size();
            this.featureSize = new int[size];
            for (int i = 0; i < size; i++) {
                this.featureSize[i] = (int) (multiStringIndexerModelData.tokenNumber.get(Integer.valueOf(i)).longValue() + this.additionalSize);
            }
        }

        public int[] map(int[] iArr) throws Exception {
            return this.featureSize;
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:com/alibaba/alink/operator/batch/feature/BaseCrossTrainBatchOp$DataColumnsSaver.class */
    /**
     * Plain holder for the column layout handed to {@code buildAcModelData}: the categorical
     * column names, the numerical column names, and the indices of the numerical columns.
     * The arrays are stored by reference, not copied.
     */
    public static class DataColumnsSaver {
        String[] categoricalCols;
        String[] numericalCols;
        int[] numericalIndices;

        /**
         * @param categoricalColNames names of the categorical columns
         * @param numericalColNames   names of the numerical columns
         * @param numericalColIndices indices of the numerical columns
         */
        DataColumnsSaver(String[] categoricalColNames, String[] numericalColNames, int[] numericalColIndices) {
            categoricalCols = categoricalColNames;
            numericalCols = numericalColNames;
            numericalIndices = numericalColIndices;
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/batch/feature/BaseCrossTrainBatchOp$GetTrainData.class */
    public static class GetTrainData extends RichMapPartitionFunction<Row, Tuple2<Integer, Tuple3<Double, Double, Vector>>> {
        private static final long serialVersionUID = -4406174781328407356L;
        private int[] numericalIndices;
        private int svIndex;
        private int labelIndex;
        private Object positiveLabel;

        GetTrainData(int[] iArr, int i, int i2) {
            this.svIndex = i;
            this.labelIndex = i2;
            this.numericalIndices = iArr;
        }

        public void mapPartition(Iterable<Row> iterable, Collector<Tuple2<Integer, Tuple3<Double, Double, Vector>>> collector) throws Exception {
            int numberOfParallelSubtasks = getRuntimeContext().getNumberOfParallelSubtasks();
            int i = -1;
            int i2 = -1;
            int[] iArr = null;
            double[] dArr = null;
            for (Row row : iterable) {
                if (i == -1) {
                    SparseVector sparseVector = VectorUtil.getSparseVector(row.getField(this.svIndex));
                    i = this.numericalIndices.length + sparseVector.getIndices().length;
                    i2 = this.numericalIndices.length + sparseVector.size();
                    iArr = new int[i];
                    dArr = new double[i];
                }
                for (int i3 = 0; i3 < this.numericalIndices.length; i3++) {
                    iArr[i3] = i3;
                    dArr[i3] = ((Number) row.getField(this.numericalIndices[i3])).doubleValue();
                }
                SparseVector sparseVector2 = VectorUtil.getSparseVector(row.getField(this.svIndex));
                int[] iArr2 = new int[sparseVector2.getIndices().length];
                for (int i4 = 0; i4 < iArr2.length; i4++) {
                    iArr2[i4] = sparseVector2.getIndices()[i4] + this.numericalIndices.length;
                }
                System.arraycopy(iArr2, 0, iArr, this.numericalIndices.length, sparseVector2.getIndices().length);
                System.arraycopy(sparseVector2.getValues(), 0, dArr, this.numericalIndices.length, sparseVector2.getValues().length);
                for (int i5 = 0; i5 < numberOfParallelSubtasks; i5++) {
                    collector.collect(Tuple2.of(Integer.valueOf(i5), Tuple3.of(Double.valueOf(1.0d), Double.valueOf(this.positiveLabel.equals(row.getField(this.labelIndex)) ? 1.0d : Criteria.INVALID_GAIN), new SparseVector(i2, iArr, dArr))));
                }
            }
        }

        public void open(Configuration configuration) throws Exception {
            super.open(configuration);
            this.positiveLabel = getRuntimeContext().getBroadcastVariable("positiveLabel").get(0);
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Creates the operator with the given parameters and initialises {@code hasDiscrete} to
     * {@code true}; {@code linkFrom} overwrites that flag later.
     *
     * @param params operator parameters, passed on to the super constructor
     */
    public BaseCrossTrainBatchOp(Params params) {
        super(params);
        hasDiscrete = true;
    }

    /**
     * Trains the cross-feature model on the first input operator and sets the saved pipeline
     * model (one-hot encoder followed by the auto-cross model) as this operator's output.
     *
     * <p>Steps:
     * <ol>
     *   <li>Select the feature columns plus the label column from the input.</li>
     *   <li>Determine the categorical columns; an {@link AkIllegalArgumentException} is thrown when
     *       there are none. The remaining selected columns are treated as numerical.</li>
     *   <li>Train a one-hot model on the categorical columns (drop-last disabled, single output
     *       column {@code oneHotVectorCol}) and encode the selected data with it.</li>
     *   <li>Derive the per-feature sizes ({@link BuildFeatureSize}) and the positive label (the
     *       label of the first row returned by {@code FirstReducer(1)}).</li>
     *   <li>Convert every encoded row into a training sample ({@link GetTrainData}), route each
     *       copy to the subtask named by its key, and store the samples of each subtask in
     *       {@code SessionSharedData} under {@code "AC_TRAIN_DATA"}.</li>
     *   <li>Let the subclass build the model rows ({@code buildAcModelData}), wrap them in a
     *       pipeline model, save it, and publish it with the schema from
     *       {@code getAutoCrossModelSchema}.</li>
     *   <li>Let the subclass publish its side outputs ({@code buildSideOutput}).</li>
     * </ol>
     *
     * @param batchOperatorArr input operators; only the first one is used
     * @return this operator
     * @throws AkIllegalArgumentException if no categorical column can be determined
     */
    @Override // com.alibaba.alink.operator.batch.BatchOperator
    public T linkFrom(BatchOperator<?>... batchOperatorArr) {
        BatchOperator<?> in = checkAndGetFirst(batchOperatorArr);

        // Reserved columns default to all input columns when the parameter is not set.
        String[] reservedCols = getParams().get(HasReservedColsDefaultAsNull.RESERVED_COLS);
        if (reservedCols == null) {
            reservedCols = in.getColNames();
        }

        long mlEnvId = getMLEnvironmentId().longValue();
        String[] featureCols = getSelectedCols();
        String labelCol = getLabelCol();
        String[] selectedCols = ArrayUtils.add(featureCols, labelCol);
        BatchOperator<?> data = in.select(selectedCols);
        TableSchema dataSchema = data.getSchema();
        ExecutionEnvironment env = MLEnvironmentFactory.get(Long.valueOf(mlEnvId)).getExecutionEnvironment();

        String[] declaredCategoricalCols = getParams().contains(RandomForestTrainParams.CATEGORICAL_COLS)
            ? getParams().get(RandomForestTrainParams.CATEGORICAL_COLS)
            : null;
        String[] categoricalCols = TableUtil.getCategoricalCols(data.getSchema(), featureCols, declaredCategoricalCols);
        if (categoricalCols == null || categoricalCols.length == 0) {
            throw new AkIllegalArgumentException("Please input param CategoricalCols!");
        }
        String[] numericalCols = ArrayUtils.removeElements(featureCols, categoricalCols);

        // Parameters shared by the one-hot trainer, the one-hot model and the one-hot predictor.
        Params oneHotParams = new Params().set(AutoCrossTrainParams.SELECTED_COLS, categoricalCols);
        if (getParams().contains(AutoCrossTrainParams.DISCRETE_THRESHOLDS_ARRAY)) {
            oneHotParams.set(AutoCrossTrainParams.DISCRETE_THRESHOLDS_ARRAY, getDiscreteThresholdsArray());
        } else if (getParams().contains(AutoCrossTrainParams.DISCRETE_THRESHOLDS)) {
            oneHotParams.set(AutoCrossTrainParams.DISCRETE_THRESHOLDS, getDiscreteThresholds());
        }
        oneHotParams
            .set(HasDropLast.DROP_LAST, Boolean.FALSE)
            .set(HasOutputColsDefaultAsNull.OUTPUT_COLS, new String[] {oneHotVectorCol});

        OneHotTrainBatchOp oneHotModel =
            ((OneHotTrainBatchOp) new OneHotTrainBatchOp(oneHotParams).setMLEnvironmentId(Long.valueOf(mlEnvId)))
                .linkFrom(data);
        OneHotEncoderModel oneHotEncoderModel =
            (OneHotEncoderModel) new OneHotEncoderModel(oneHotParams).setMLEnvironmentId(Long.valueOf(mlEnvId));
        oneHotEncoderModel.setModelData(oneHotModel);
        OneHotPredictBatchOp encoded =
            ((OneHotPredictBatchOp) new OneHotPredictBatchOp(oneHotParams).setMLEnvironmentId(getMLEnvironmentId()))
                .linkFrom(oneHotModel, data);

        this.hasDiscrete = OneHotModelMapper.isEnableElse(oneHotParams);

        // A single dummy element drives BuildFeatureSize exactly once; the sizes come from the broadcast model.
        DataSet<int[]> featureSizes = env
            .fromElements(new int[][] {new int[0]})
            .map(new BuildFeatureSize(this.hasDiscrete))
            .withBroadcastSet(oneHotModel.getDataSet(), "oneHotModel");

        // The label of the first row returned by FirstReducer(1) is used as the positive label.
        @SuppressWarnings({"unchecked", "rawtypes"})
        DataSet<Object> positiveLabel = encoded
            .select(labelCol)
            .getDataSet()
            .reduceGroup(new FirstReducer(1))
            .map(new MapFunction<Row, Object>() { // from class: com.alibaba.alink.operator.batch.feature.BaseCrossTrainBatchOp.1
                private static final long serialVersionUID = 110081999458221448L;

                @Override
                public Object map(Row row) throws Exception {
                    return row.getField(0);
                }
            });

        int svIndex = TableUtil.findColIndex(encoded.getColNames(), oneHotVectorCol);
        int labelIndex = TableUtil.findColIndex(encoded.getColNames(), labelCol);
        int[] numericalIndices = TableUtil.findColIndicesWithAssert(encoded.getSchema(), numericalCols);

        DataSet<Tuple3<Double, Double, Vector>> trainData = encoded
            .getDataSet()
            .rebalance()
            .mapPartition(new GetTrainData(numericalIndices, svIndex, labelIndex))
            .withBroadcastSet(positiveLabel, "positiveLabel")
            // The key produced by GetTrainData is the index of the target subtask.
            .partitionCustom(new Partitioner<Integer>() { // from class: com.alibaba.alink.operator.batch.feature.BaseCrossTrainBatchOp.2
                private static final long serialVersionUID = 5552966434608252752L;

                @Override
                public int partition(Integer num, int i) {
                    return num.intValue();
                }
            }, 0)
            // Emits nothing: the samples of each subtask are parked in SessionSharedData instead.
            .mapPartition(new RichMapPartitionFunction<Tuple2<Integer, Tuple3<Double, Double, Vector>>, Tuple3<Double, Double, Vector>>() { // from class: com.alibaba.alink.operator.batch.feature.BaseCrossTrainBatchOp.3
                @Override
                public void mapPartition(Iterable<Tuple2<Integer, Tuple3<Double, Double, Vector>>> iterable, Collector<Tuple3<Double, Double, Vector>> collector) throws Exception {
                    List<Tuple3<Double, Double, Vector>> samples = new ArrayList<>();
                    for (Tuple2<Integer, Tuple3<Double, Double, Vector>> keyed : iterable) {
                        samples.add(keyed.f1);
                    }
                    SessionSharedData.put("AC_TRAIN_DATA", AutoCrossTrainBatchOp.SESSION_ID, getRuntimeContext().getIndexOfThisSubtask(), samples);
                }
            });

        DataSet<Row> acModelRows = buildAcModelData(
            trainData, featureSizes, new DataColumnsSaver(categoricalCols, numericalCols, numericalIndices));

        BatchOperator<?> acModelOp = BatchOperator.fromTable(DataSetConversionUtil.toTable(
            Long.valueOf(mlEnvId),
            acModelRows,
            new String[] {"feature_id", "cross_feature", KObjectUtil.SCORE_NAME},
            new TypeInformation<?>[] {Types.LONG, Types.STRING, Types.DOUBLE}));

        // NOTE(review): this writes the resolved reserved columns into the Params object returned by
        // getParams() - presumably this operator's own parameters. Kept as in the original; confirm it is intended.
        Params autoCrossParams = getParams();
        autoCrossParams.set(HasReservedColsDefaultAsNull.RESERVED_COLS, reservedCols);

        AutoCrossAlgoModel autoCrossModel = (AutoCrossAlgoModel)
            ((AutoCrossAlgoModel) new AutoCrossAlgoModel(autoCrossParams).setModelData(acModelOp))
                .setMLEnvironmentId(Long.valueOf(mlEnvId));
        BatchOperator<?> savedPipeline =
            new PipelineModel(new TransformerBase<?>[] {oneHotEncoderModel, autoCrossModel}).save();

        DataSet<Row> outputRows = savedPipeline
            .getDataSet()
            .map(new PipelineModelMapper.ExtendPipelineModelRow(selectedCols.length + 1));
        setOutput(outputRows, getAutoCrossModelSchema(dataSchema, savedPipeline.getSchema(), selectedCols));

        buildSideOutput(oneHotModel, acModelRows, Arrays.asList(numericalCols), mlEnvId);

        // T is the self type of this operator, so the cast is the usual self-typed return.
        @SuppressWarnings("unchecked")
        T self = (T) this;
        return self;
    }

    /**
     * Builds the rows of the cross-feature model; implemented by subclasses.
     *
     * <p>As called from {@code linkFrom}, the returned rows are converted to a table with the
     * columns {@code "feature_id"}, {@code "cross_feature"} and {@code KObjectUtil.SCORE_NAME} of
     * types LONG, STRING and DOUBLE.
     *
     * @param dataSet          training-sample data set; in {@code linkFrom} the operator that produces it
     *                         emits no elements and instead stores the samples of each subtask in
     *                         {@code SessionSharedData} under {@code "AC_TRAIN_DATA"}
     * @param dataSet2         per-feature sizes as produced by {@link BuildFeatureSize}
     * @param dataColumnsSaver categorical column names, numerical column names and numerical column indices
     * @return the model rows
     */
    abstract DataSet<Row> buildAcModelData(DataSet<Tuple3<Double, Double, Vector>> dataSet, DataSet<int[]> dataSet2, DataColumnsSaver dataColumnsSaver);

    /**
     * Hook for subclasses to publish side outputs; invoked at the end of {@code linkFrom}, after
     * the main output has been set.
     *
     * @param oneHotTrainBatchOp the trained one-hot model operator
     * @param dataSet            the model rows returned by {@code buildAcModelData}
     * @param list               names of the numerical (non-categorical) selected columns
     * @param j                  ML environment id of this operator
     */
    abstract void buildSideOutput(OneHotTrainBatchOp oneHotTrainBatchOp, DataSet<Row> dataSet, List<String> list, long j);

    /**
     * Builds the schema of the extended model table: all fields of the model schema, then the
     * splitter column ({@code PipelineModelMapper.SPLITER_COL_NAME} / {@code SPLITER_COL_TYPE}),
     * then the given columns with the types they have in the data schema.
     *
     * @param tableSchema  data schema used to look up the types of {@code strArr}
     * @param tableSchema2 model schema whose fields come first
     * @param strArr       names of the data columns appended after the splitter column
     * @return the combined schema
     */
    public static TableSchema getAutoCrossModelSchema(TableSchema tableSchema, TableSchema tableSchema2, String[] strArr) {
        final String[] modelFieldNames = tableSchema2.getFieldNames();
        final int modelFieldCount = modelFieldNames.length;
        final int totalFieldCount = modelFieldCount + 1 + strArr.length;

        String[] names = new String[totalFieldCount];
        TypeInformation<?>[] types = new TypeInformation<?>[totalFieldCount];

        // Model fields occupy the leading positions.
        System.arraycopy(modelFieldNames, 0, names, 0, modelFieldCount);
        System.arraycopy(tableSchema2.getFieldTypes(), 0, types, 0, modelFieldCount);

        // The splitter column sits directly after the model fields.
        int cursor = modelFieldCount;
        names[cursor] = PipelineModelMapper.SPLITER_COL_NAME;
        types[cursor] = PipelineModelMapper.SPLITER_COL_TYPE;
        cursor++;

        // Data columns follow, typed as in the data schema.
        for (String colName : strArr) {
            names[cursor] = colName;
            types[cursor] = tableSchema.getFieldTypes()[TableUtil.findColIndex(tableSchema, colName)];
            cursor++;
        }
        return new TableSchema(names, types);
    }

    /**
     * Bridge overload emitted by the decompiler (see the bridge/synthetic markers on the
     * declaration); it only forwards to {@code linkFrom(BatchOperator<?>...)}.
     *
     * NOTE(review): this signature has the same erasure as the varargs {@code linkFrom} declared
     * above, so it presumably has to be removed when this file is maintained as source rather
     * than decompiled output - confirm before compiling.
     */
    @Override // com.alibaba.alink.operator.batch.BatchOperator
    public /* bridge */ /* synthetic */ BatchOperator linkFrom(BatchOperator[] batchOperatorArr) {
        return linkFrom((BatchOperator<?>[]) batchOperatorArr);
    }
}
