package com.alibaba.alink.operator.local.feature;

import com.alibaba.alink.common.MTable;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpecs;
import com.alibaba.alink.common.annotation.ReservedColsWithFirstInputSpec;
import com.alibaba.alink.common.exceptions.AkIllegalArgumentException;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.feature.BaseCrossTrainBatchOp;
import com.alibaba.alink.operator.common.dataproc.MultiStringIndexerModelData;
import com.alibaba.alink.operator.common.feature.OneHotModelDataConverter;
import com.alibaba.alink.operator.common.feature.OneHotModelMapper;
import com.alibaba.alink.operator.common.recommendation.KObjectUtil;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.operator.local.LocalOperator;
import com.alibaba.alink.operator.local.feature.BaseCrossTrainLocalOp;
import com.alibaba.alink.operator.local.source.TableSourceLocalOp;
import com.alibaba.alink.params.classification.RandomForestTrainParams;
import com.alibaba.alink.params.feature.AutoCrossTrainParams;
import com.alibaba.alink.params.feature.HasDropLast;
import com.alibaba.alink.params.shared.colname.HasOutputColsDefaultAsNull;
import com.alibaba.alink.params.shared.colname.HasReservedColsDefaultAsNull;
import com.alibaba.alink.pipeline.PipelineModel;
import com.alibaba.alink.pipeline.TransformerBase;
import com.alibaba.alink.pipeline.feature.AutoCrossAlgoModel;
import com.alibaba.alink.pipeline.feature.OneHotEncoderModel;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

/**
 * Base class of the local-mode auto-cross training operators.
 *
 * <p>The categorical columns among {@code selectedCols} are one-hot encoded (without dropping the last
 * category) into a single vector column, the remaining selected columns are used as numerical features,
 * and every input row becomes one sparse training sample. Subclasses turn these samples into the
 * cross-feature model data ({@link #buildAcModelData}); {@link #buildSideOutput} is invoked afterwards.
 * The operator's output is a saved {@link PipelineModel} made of the one-hot encoder model followed by an
 * {@link AutoCrossAlgoModel}.
 */
@NameCn("")
@ReservedColsWithFirstInputSpec
@ParamSelectColumnSpecs({@ParamSelectColumnSpec(name = "selectedCols"), @ParamSelectColumnSpec(name = "labelCol")})
abstract class BaseCrossTrainLocalOp<T extends BaseCrossTrainLocalOp<T>> extends LocalOperator<T>
    implements AutoCrossTrainParams<T> {

    /** Name of the vector column produced by the internal one-hot encoding step. */
    static final String oneHotVectorCol = "oneHotVectorCol";

    /**
     * Column layout handed to {@link #buildAcModelData}: the categorical column names, the numerical
     * column names, and the indices of the numerical columns in the one-hot encoded table.
     */
    public static class DataColumnsSaver {
        String[] categoricalCols;
        String[] numericalCols;
        int[] numericalIndices;

        DataColumnsSaver(String[] categoricalCols, String[] numericalCols, int[] numericalIndices) {
            this.categoricalCols = categoricalCols;
            this.numericalCols = numericalCols;
            this.numericalIndices = numericalIndices;
        }
    }

    public BaseCrossTrainLocalOp(Params params) {
        super(params);
    }

    /**
     * Trains the auto-cross model on the first input operator.
     *
     * @param inputs the input operators; only the first one is used.
     * @return this operator; its output table holds the saved pipeline model rows, each padded with
     *     {@code selectedCols.length + 2} trailing null fields.
     * @throws AkIllegalArgumentException if no categorical column can be determined or the input is empty.
     */
    @Override
    public T linkFrom(LocalOperator<?>... inputs) {
        LocalOperator<?> input = checkAndGetFirst(inputs);

        // Columns kept by the final model; default to every input column.
        String[] reservedCols = getParams().get(HasReservedColsDefaultAsNull.RESERVED_COLS);
        if (reservedCols == null) {
            reservedCols = input.getColNames();
        }

        String[] selectedCols = getSelectedCols();
        String labelCol = getLabelCol();
        String[] trainCols = ArrayUtils.add(selectedCols, labelCol);
        LocalOperator<?> data = input.select(trainCols);
        TableSchema dataSchema = data.getSchema();

        String[] categoricalCols = TableUtil.getCategoricalCols(dataSchema, selectedCols,
            getParams().contains(RandomForestTrainParams.CATEGORICAL_COLS)
                ? getParams().get(RandomForestTrainParams.CATEGORICAL_COLS)
                : null);
        if (null == categoricalCols || categoricalCols.length == 0) {
            throw new AkIllegalArgumentException("Please input param CategoricalCols!");
        }
        String[] numericalCols = ArrayUtils.removeElements(selectedCols, categoricalCols);

        // One-hot encode the categorical columns into a single vector column, keeping every category.
        Params oneHotParams = new Params().set(AutoCrossTrainParams.SELECTED_COLS, categoricalCols);
        if (getParams().contains(AutoCrossTrainParams.DISCRETE_THRESHOLDS_ARRAY)) {
            oneHotParams.set(AutoCrossTrainParams.DISCRETE_THRESHOLDS_ARRAY, getDiscreteThresholdsArray());
        } else if (getParams().contains(AutoCrossTrainParams.DISCRETE_THRESHOLDS)) {
            oneHotParams.set(AutoCrossTrainParams.DISCRETE_THRESHOLDS, getDiscreteThresholds());
        }
        oneHotParams
            .set(HasDropLast.DROP_LAST, false)
            .set(HasOutputColsDefaultAsNull.OUTPUT_COLS, new String[] {oneHotVectorCol});

        OneHotTrainLocalOp oneHotModel = new OneHotTrainLocalOp(oneHotParams).linkFrom(data);
        OneHotEncoderModel oneHotEncoderModel = new OneHotEncoderModel(oneHotParams);
        oneHotEncoderModel.setModelData(oneHotModel);
        OneHotPredictLocalOp encoded = new OneHotPredictLocalOp(oneHotParams).linkFrom(oneHotModel, data);

        int[] featureSize = computeOneHotFeatureSizes(oneHotModel, oneHotParams);

        List<Row> encodedRows = encoded.getOutputTable().getRows();
        if (encodedRows.isEmpty()) {
            throw new AkIllegalArgumentException("The input data of auto cross training is empty.");
        }
        // The first distinct label value is treated as the positive class (label 1.0); others get 0.0.
        List<Row> labelRows = encoded.select(labelCol).distinct().getOutput().getRows();
        Object positiveLabel = labelRows.get(0).getField(0);

        String[] encodedCols = encoded.getColNames();
        int oneHotColIdx = TableUtil.findColIndex(encodedCols, oneHotVectorCol);
        int labelColIdx = TableUtil.findColIndex(encodedCols, labelCol);
        int[] numericalIndices = TableUtil.findColIndicesWithAssert(encoded.getSchema(), numericalCols);
        DataColumnsSaver columnsSaver = new DataColumnsSaver(categoricalCols, numericalCols, numericalIndices);

        // All samples share one dimension: numerical features first, then the shifted one-hot slots.
        int vectorSize = numericalIndices.length
            + VectorUtil.getSparseVector(encodedRows.get(0).getField(oneHotColIdx)).size();

        List<Tuple3<Double, Double, Vector>> samples = new ArrayList<>(encodedRows.size());
        for (Row row : encodedRows) {
            double label = isSameLabel(positiveLabel, row.getField(labelColIdx)) ? 1.0 : 0.0;
            Vector features = toSample(row, numericalIndices, oneHotColIdx, vectorSize);
            samples.add(Tuple3.of(1.0, label, features));
        }

        List<Row> acModelData = buildAcModelData(samples, featureSize, columnsSaver);
        TableSourceLocalOp acModel = new TableSourceLocalOp(new MTable(acModelData,
            new String[] {"feature_id", "cross_feature", KObjectUtil.SCORE_NAME},
            new TypeInformation<?>[] {Types.LONG, Types.STRING, Types.DOUBLE}));

        // NOTE: this writes the resolved reserved columns back into this operator's own params.
        Params acParams = getParams();
        acParams.set(HasReservedColsDefaultAsNull.RESERVED_COLS, reservedCols);
        TransformerBase<?>[] stages = new TransformerBase<?>[] {
            oneHotEncoderModel,
            (AutoCrossAlgoModel) new AutoCrossAlgoModel(acParams).setModelData(acModel)
        };
        LocalOperator<?> savedModel = new PipelineModel(stages).saveLocal();

        // Pad every model row with (trainCols.length + 1) trailing null fields; presumably this matches
        // the layout produced by BaseCrossTrainBatchOp.getAutoCrossModelSchema — confirm there.
        int paddingCols = trainCols.length + 1;
        List<Row> outputRows = new ArrayList<>();
        for (Row modelRow : savedModel.getOutputTable().getRows()) {
            Row padded = new Row(modelRow.getArity() + paddingCols);
            for (int k = 0; k < modelRow.getArity(); k++) {
                padded.setField(k, modelRow.getField(k));
            }
            outputRows.add(padded);
        }
        setOutputTable(new MTable(outputRows,
            BaseCrossTrainBatchOp.getAutoCrossModelSchema(dataSchema, savedModel.getSchema(), trainCols)));
        buildSideOutput(oneHotModel, acModelData, Arrays.asList(numericalCols));

        // Safe by the self-type bound T extends BaseCrossTrainLocalOp<T>.
        @SuppressWarnings("unchecked")
        T self = (T) this;
        return self;
    }

    /**
     * Returns, for each categorical column, its token count from the one-hot model plus one extra slot,
     * or two extra slots when {@code OneHotModelMapper.isEnableElse(oneHotParams)} is true.
     */
    private static int[] computeOneHotFeatureSizes(OneHotTrainLocalOp oneHotModel, Params oneHotParams) {
        int extraSlots = OneHotModelMapper.isEnableElse(oneHotParams) ? 2 : 1;
        MultiStringIndexerModelData modelData =
            new OneHotModelDataConverter().load(oneHotModel.getOutput().getRows()).modelData;
        int[] sizes = new int[modelData.tokenNumber.size()];
        for (int i = 0; i < sizes.length; i++) {
            sizes[i] = (int) (modelData.tokenNumber.get(i) + extraSlots);
        }
        return sizes;
    }

    /**
     * Builds one training vector of dimension {@code vectorSize}. The numerical columns occupy indices
     * {@code [0, n)} and the row's one-hot indices are shifted by {@code n}, where n is the number of
     * numerical columns. Arrays are sized per row so rows with a different number of non-zeros stay valid.
     */
    private static SparseVector toSample(Row row, int[] numericalIndices, int oneHotColIdx, int vectorSize) {
        SparseVector oneHot = VectorUtil.getSparseVector(row.getField(oneHotColIdx));
        int[] oneHotIndices = oneHot.getIndices();
        double[] oneHotValues = oneHot.getValues();
        int offset = numericalIndices.length;

        int[] indices = new int[offset + oneHotIndices.length];
        double[] values = new double[offset + oneHotIndices.length];
        for (int i = 0; i < offset; i++) {
            indices[i] = i;
            values[i] = ((Number) row.getField(numericalIndices[i])).doubleValue();
        }
        for (int i = 0; i < oneHotIndices.length; i++) {
            indices[offset + i] = oneHotIndices[i] + offset;
            values[offset + i] = oneHotValues[i];
        }
        return new SparseVector(vectorSize, indices, values);
    }

    /** Null-safe equality of two label values. */
    private static boolean isSameLabel(Object expected, Object actual) {
        return expected == null ? actual == null : expected.equals(actual);
    }

    /**
     * Builds the auto-cross model rows from the training samples.
     *
     * @param samples      one tuple per input row: first field always 1.0 (presumably a sample weight),
     *                     second the label (1.0 if the row's label equals the first distinct label, else
     *                     0.0), third the feature vector.
     * @param featureSize  per categorical column, its one-hot token count plus the extra slot(s).
     * @param columnsSaver the categorical/numerical column layout of the samples.
     * @return rows matching the schema (feature_id LONG, cross_feature STRING, score column DOUBLE).
     */
    abstract List<Row> buildAcModelData(List<Tuple3<Double, Double, Vector>> samples, int[] featureSize,
                                        DataColumnsSaver columnsSaver);

    /**
     * Hook invoked at the end of {@link #linkFrom}; presumably builds the side outputs (per its name).
     *
     * @param oneHotModel   the trained one-hot encoder model operator.
     * @param acModelData   the rows returned by {@link #buildAcModelData}.
     * @param numericalCols the selected columns treated as numerical.
     */
    abstract void buildSideOutput(OneHotTrainLocalOp oneHotModel, List<Row> acModelData, List<String> numericalCols);
}
