package com.alibaba.alink.operator.batch.feature;

import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil;
import com.alibaba.alink.operator.common.dataproc.MultiStringIndexerModelData;
import com.alibaba.alink.operator.common.dataproc.MultiStringIndexerModelDataConverter;
import com.alibaba.alink.operator.common.dataproc.StringIndexerUtil;
import com.alibaba.alink.operator.common.io.types.FlinkTypeConverter;
import com.alibaba.alink.params.dataproc.HasSelectedColTypes;
import com.alibaba.alink.params.feature.CrossFeatureTrainParams;
import com.alibaba.alink.params.shared.colname.HasSelectedCols;
import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation;
import com.google.common.collect.Lists;
import java.util.List;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.operators.Operator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.Table;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/**
 * Training operator for the cross feature.
 *
 * <p>Each token of each selected column gets an index via {@link StringIndexerUtil#indexRandom}. The
 * result is written as string-indexer model data through {@link MultiStringIndexerModelDataConverter}.
 * A side output table with columns {@code index} (INT) and {@code value} (STRING) enumerates every
 * combination of per-column tokens; see {@link BuildSideOutput}.
 */
@InputPorts(values = {@PortSpec(PortType.DATA)})
@OutputPorts(values = {@PortSpec(PortType.MODEL)})
@ParamSelectColumnSpec(name = "selectedCols")
@NameCn("Cross特征训练")
@NameEn("Cross Feature Training")
@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.feature.CrossFeature")
public class CrossFeatureTrainBatchOp extends BatchOperator<CrossFeatureTrainBatchOp>
    implements CrossFeatureTrainParams<CrossFeatureTrainBatchOp> {

    /**
     * Expands the trained model into a side table listing every cross-feature combination.
     *
     * <p>Each output row is {@code (index, value)}:
     * <ul>
     *   <li>{@code index} is the position of the combination in mixed-radix order, where the first
     *       selected column varies fastest.</li>
     *   <li>{@code value} joins the token chosen for each column with {@code ", "}.</li>
     * </ul>
     * {@link CrossFeatureTrainBatchOp#linkFrom} runs this function with parallelism 1, so its single
     * partition receives all model rows.
     */
    public static class BuildSideOutput extends RichMapPartitionFunction<Row, Row> {
        private BuildSideOutput() {
        }

        /**
         * Loads the model from {@code modelRows} and emits one row per token combination.
         *
         * @param modelRows the rows of the model data set produced by the "build_model" stage
         * @param collector receives {@code (Integer index, String value)} rows
         * @throws IllegalArgumentException if the number of combinations exceeds
         *     {@link Integer#MAX_VALUE}, which the INT {@code index} column cannot represent
         */
        @Override
        public void mapPartition(Iterable<Row> modelRows, Collector<Row> collector) throws Exception {
            List<Row> rows = Lists.newArrayList(modelRows);
            MultiStringIndexerModelData modelData = new MultiStringIndexerModelDataConverter().load(rows);

            int numCols = modelData.tokenNumber.size();
            if (numCols == 0) {
                // No columns to cross. Without this guard, carry() would index into empty arrays.
                return;
            }

            // tokens[col][tokenIndex] = token string. radix[col] = number of tokens in column col.
            String[][] tokens = new String[numCols][];
            int[] radix = new int[numCols];
            for (int col = 0; col < numCols; col++) {
                radix[col] = modelData.tokenNumber.get(col).intValue();
                tokens[col] = new String[radix[col]];
            }
            for (Tuple3<Integer, String, Long> entry : modelData.tokenAndIndex) {
                tokens[entry.f0.intValue()][entry.f2.intValue()] = entry.f1;
            }

            // Compute the product in a long and reject anything the INT index column cannot hold.
            // An int product would silently wrap around instead.
            long numCombinations = 1L;
            for (int r : radix) {
                numCombinations *= r;
                if (numCombinations > Integer.MAX_VALUE) {
                    throw new IllegalArgumentException(
                        "Too many cross-feature combinations: the product of the token counts of the "
                            + "selected columns exceeds Integer.MAX_VALUE.");
                }
            }

            // digits[col] is the current token index for column col (mixed-radix counter).
            int[] digits = new int[numCols];
            for (int index = 0; index < numCombinations; index++) {
                StringBuilder sb = new StringBuilder();
                for (int col = 0; col < numCols; col++) {
                    if (col != 0) {
                        sb.append(", ");
                    }
                    sb.append(tokens[col][digits[col]]);
                }
                collector.collect(Row.of(index, sb.toString()));
                carry(digits, radix);
            }
        }

        /**
         * Advances {@code digits} to the next mixed-radix value.
         *
         * <p>{@code digits[i]} counts modulo {@code radix[i]}, and index 0 is the least significant
         * digit. After the largest value, the counter wraps back to all zeros.
         */
        private static void carry(int[] digits, int[] radix) {
            for (int i = 0; i < radix.length; i++) {
                digits[i]++;
                if (digits[i] < radix[i]) {
                    return;
                }
                digits[i] = 0;
            }
        }
    }

    /** Creates the operator with empty params. */
    public CrossFeatureTrainBatchOp() {
        super(new Params());
    }

    /**
     * Creates the operator with the given params.
     *
     * @param params operator params (selected columns, etc.)
     */
    public CrossFeatureTrainBatchOp(Params params) {
        super(params);
    }

    /**
     * Builds the cross-feature model from the first input.
     *
     * <p>The main output is the string-indexer model data. The single side output is the combination
     * table produced by {@link BuildSideOutput}.
     *
     * @param inputs the input operators; only the first one is used
     * @return this operator
     */
    @Override
    public CrossFeatureTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
        BatchOperator<?> in = checkAndGetFirst(inputs);
        final Long mlEnvId = getMLEnvironmentId();
        final String[] selectedCols = getSelectedCols();

        // Record the type string of every selected column. It is stored in the model meta.
        final String[] selectedColTypes = new String[selectedCols.length];
        for (int i = 0; i < selectedCols.length; i++) {
            selectedColTypes[i] = FlinkTypeConverter.getTypeString(
                TableUtil.findColTypeWithAssertAndHint(in.getSchema(), selectedCols[i]));
        }

        DataSet<Row> modelData = StringIndexerUtil
            .indexRandom(in.select(selectedCols).getDataSet(), 0L, false)
            .mapPartition(new RichMapPartitionFunction<Tuple3<Integer, String, Long>, Row>() {
                private static final long serialVersionUID = 2876851020570715540L;

                @Override
                public void mapPartition(Iterable<Tuple3<Integer, String, Long>> indexedTokens,
                                         Collector<Row> out) throws Exception {
                    // Only subtask 0 attaches the meta params; the others pass null.
                    // NOTE(review): presumably this writes the meta exactly once. Confirm against
                    // MultiStringIndexerModelDataConverter.
                    Params meta = null;
                    if (getRuntimeContext().getIndexOfThisSubtask() == 0) {
                        meta = new Params()
                            .set(HasSelectedCols.SELECTED_COLS, selectedCols)
                            .set(HasSelectedColTypes.SELECTED_COL_TYPES, selectedColTypes);
                    }
                    new MultiStringIndexerModelDataConverter().save2(Tuple2.of(meta, indexedTokens), out);
                }
            })
            .name("build_model");
        setOutput(modelData, new MultiStringIndexerModelDataConverter().getModelSchema());

        // Parallelism 1 so BuildSideOutput sees every model row in a single partition.
        DataSet<Row> sideData = modelData.mapPartition(new BuildSideOutput()).setParallelism(1);
        setSideOutputTables(new Table[] {
            DataSetConversionUtil.toTable(mlEnvId, sideData, new String[] {"index", "value"},
                new TypeInformation<?>[] {Types.INT, Types.STRING})
        });
        return this;
    }
}
