package com.alibaba.alink.operator.batch.dataproc;

import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.TypeCollections;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.common.dataproc.HugeStringIndexerUtil;
import com.alibaba.alink.operator.common.dataproc.MultiStringIndexerModelDataConverter;
import com.alibaba.alink.operator.common.io.types.FlinkTypeConverter;
import com.alibaba.alink.params.dataproc.HasSelectedColTypes;
import com.alibaba.alink.params.dataproc.HasStringOrderTypeDefaultAsRandom;
import com.alibaba.alink.params.dataproc.MultiStringIndexerTrainParams;
import com.alibaba.alink.params.shared.colname.HasSelectedCols;
import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/**
 * Batch operator that trains a multi-column string indexer model.
 *
 * <p>For every selected column, each non-null value is converted to its string form and
 * tagged with the position of its column within {@code selectedCols}. The resulting
 * (column position, token) pairs are handed to {@link HugeStringIndexerUtil#indexTokens}
 * and the indexed tokens are serialized into model rows by
 * {@link MultiStringIndexerModelDataConverter}.
 */
@InputPorts(values = {@PortSpec(value = PortType.DATA, opType = PortSpec.OpType.BATCH)})
@OutputPorts(values = {@PortSpec(PortType.MODEL)})
@ParamSelectColumnSpec(name = "selectedCols", allowedTypeCollections = {TypeCollections.INT_LONG_STRING_TYPES})
@NameCn("多字段字符串编码训练")
@NameEn("Multiple String Indexer Train")
@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.dataproc.MultiStringIndexer")
public final class MultiStringIndexerTrainBatchOp extends BatchOperator<MultiStringIndexerTrainBatchOp> implements MultiStringIndexerTrainParams<MultiStringIndexerTrainBatchOp> {
    private static final long serialVersionUID = 3760905390429627737L;

    /** Creates the operator with an empty parameter set. */
    public MultiStringIndexerTrainBatchOp() {
        this(new Params());
    }

    /**
     * Creates the operator with the given parameters.
     *
     * @param params operator parameters, forwarded to the {@link BatchOperator} constructor
     */
    public MultiStringIndexerTrainBatchOp(Params params) {
        super(params);
    }

    /**
     * Builds the training pipeline on top of the first input operator and registers the
     * resulting model rows as this operator's output.
     *
     * @param inputs upstream operators; only the one returned by {@code checkAndGetFirst} is used
     * @return this operator, with its output set to the model data set
     */
    @Override
    public MultiStringIndexerTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
        BatchOperator<?> in = checkAndGetFirst(inputs);

        final String[] selectedColNames = getSelectedCols();
        HasStringOrderTypeDefaultAsRandom.StringOrderType orderType = getStringOrderType();

        // Record the type string of every selected column so it can be stored in the model meta.
        final String[] selectedColTypes = new String[selectedColNames.length];
        for (int i = 0; i < selectedColNames.length; i++) {
            selectedColTypes[i] = FlinkTypeConverter.getTypeString(
                TableUtil.findColTypeWithAssertAndHint(in.getSchema(), selectedColNames[i]));
        }

        // Flatten each row into (column position, token) pairs; null fields are skipped.
        DataSet<Tuple2<Integer, String>> tokens = in
            .select(selectedColNames)
            .getDataSet()
            .flatMap(new FlatMapFunction<Row, Tuple2<Integer, String>>() {
                @Override
                public void flatMap(Row row, Collector<Tuple2<Integer, String>> out) throws Exception {
                    for (int col = 0; col < selectedColNames.length; col++) {
                        Object value = row.getField(col);
                        if (value != null) {
                            out.collect(Tuple2.of(col, String.valueOf(value)));
                        }
                    }
                }
            })
            .returns(new TupleTypeInfo<Tuple2<Integer, String>>(Types.INT, Types.STRING));

        // NOTE(review): the 0L argument is presumably the first index to assign - confirm
        // against HugeStringIndexerUtil.indexTokens.
        DataSet<Tuple3<Integer, String, Long>> indexedTokens =
            HugeStringIndexerUtil.indexTokens(tokens, orderType, 0L);

        MultiStringIndexerModelDataConverter schemaSource = new MultiStringIndexerModelDataConverter();

        DataSet<Row> modelRows = indexedTokens
            .mapPartition(new RichMapPartitionFunction<Tuple3<Integer, String, Long>, Row>() {
                private static final long serialVersionUID = 2876851020570715540L;

                @Override
                public void mapPartition(Iterable<Tuple3<Integer, String, Long>> values, Collector<Row> out)
                    throws Exception {
                    // Only subtask 0 passes the meta parameters; every other subtask passes null.
                    Params meta = null;
                    if (getRuntimeContext().getIndexOfThisSubtask() == 0) {
                        meta = new Params()
                            .set(HasSelectedCols.SELECTED_COLS, selectedColNames)
                            .set(HasSelectedColTypes.SELECTED_COL_TYPES, selectedColTypes);
                    }
                    // NOTE(review): the decompiled source called "save2", which looks like a
                    // decompiler collision rename of the converter's save method - confirm.
                    new MultiStringIndexerModelDataConverter().save(Tuple2.of(meta, values), out);
                }
            })
            .name("build_model")
            .returns(new RowTypeInfo(schemaSource.getModelSchema().getFieldTypes()));

        setOutput(modelRows, schemaSource.getModelSchema());
        return this;
    }
}
