package com.alibaba.alink.operator.batch.dataproc;

import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.PortDesc;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.ReservedColsWithFirstInputSpec;
import com.alibaba.alink.common.annotation.TypeCollections;
import com.alibaba.alink.common.exceptions.AkIllegalModelException;
import com.alibaba.alink.common.utils.OutputColsHelper;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.params.dataproc.HasHandleInvalid;
import com.alibaba.alink.params.dataproc.HugeMultiStringIndexerPredictParams;
import com.alibaba.alink.params.dataproc.MultiStringIndexerPredictParams;
import com.alibaba.alink.params.dataproc.StringIndexerPredictParams;
import com.alibaba.alink.params.shared.colname.HasSelectedCols;
import java.util.Arrays;
import java.util.List;
import org.apache.flink.api.common.functions.JoinFunction;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.functions.RichGroupReduceFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.api.java.utils.DataSetUtils;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/**
 * Batch operator that translates Long index columns back into their String tokens by joining the
 * prediction data against the rows of an indexer model, instead of loading the model into memory.
 *
 * <p>Inputs, in order: the model operator, then the data operator.
 *
 * <p>Model layout as read by this class: field 0 of every model row is a Long. A negative value
 * marks the meta row, whose field 1 is the JSON-serialized params of the model (it carries the
 * model's selected column names). A non-negative value is the position of a column inside the
 * model; such a row carries the token in field 1 and the token's Long index in field 2.
 *
 * <p>Every output column is of type STRING. An index without a matching model row, including a
 * {@code null} input value, is written as the literal string {@code "null"}.
 */
@InputPorts(values = {@PortSpec(value = PortType.MODEL, desc = PortDesc.PREDICT_INPUT_MODEL), @PortSpec(value = PortType.DATA, desc = PortDesc.PREDICT_INPUT_DATA)})
@OutputPorts(values = {@PortSpec(value = PortType.DATA, desc = PortDesc.OUTPUT_RESULT)})
@ParamSelectColumnSpec(name = "selectedCols", allowedTypeCollections = {TypeCollections.LONG_TYPES})
@NameCn("多列并行反ID化预测")
@ReservedColsWithFirstInputSpec
@NameEn("Huge Multi Indexer String Prediction")
public final class HugeMultiIndexerStringPredictBatchOp extends BatchOperator<HugeMultiIndexerStringPredictBatchOp> implements HugeMultiStringIndexerPredictParams<HugeMultiIndexerStringPredictBatchOp> {
    private static final long serialVersionUID = -1392825675494011436L;

    /** Name under which the model meta data set is broadcast to the model-data extractor. */
    private static final String MODEL_META_BROADCAST_NAME = "modelMeta";

    /** Creates the operator with empty params. */
    public HugeMultiIndexerStringPredictBatchOp() {
        this(new Params());
    }

    /**
     * Creates the operator with the given params.
     *
     * @param params operator parameters
     */
    public HugeMultiIndexerStringPredictBatchOp(Params params) {
        super(params);
    }

    /**
     * Extracts the meta entries of the model: field 1 of every model row whose field 0 is negative.
     *
     * @param model the model operator
     * @return data set of meta strings (the consumer in this class requires exactly one element)
     */
    private DataSet<String> getModelMeta(BatchOperator<?> model) {
        return model.getDataSet()
            .flatMap(new RichFlatMapFunction<Row, String>() {
                private static final long serialVersionUID = -4936189616000293070L;

                @Override
                public void flatMap(Row modelRow, Collector<String> out) throws Exception {
                    long marker = (Long) modelRow.getField(0);
                    if (marker < 0) {
                        out.collect((String) modelRow.getField(1));
                    }
                }
            })
            .name("get_model_meta")
            .returns(Types.STRING);
    }

    /**
     * Extracts the token/index pairs of the model for the selected columns.
     *
     * <p>Each emitted tuple is (position of the column inside {@code strArr}, token, index). Model
     * rows that belong to a column not listed in {@code strArr} are dropped.
     *
     * @param model     the model operator
     * @param modelMeta the model meta data set, broadcast to resolve column positions in the model
     * @param strArr    names of the columns selected for prediction
     * @return data set of (selected column position, token, index)
     */
    private DataSet<Tuple3<Integer, String, Long>> getModelData(BatchOperator<?> model, DataSet<String> modelMeta, final String[] strArr) {
        return model.getDataSet()
            .flatMap(new RichFlatMapFunction<Row, Tuple3<Integer, String, Long>>() {
                private static final long serialVersionUID = 3103936416747450973L;

                /** For each selected column, the position of that column inside the model. */
                transient int[] selectedColIdxInModel;

                @Override
                public void open(Configuration configuration) throws Exception {
                    List<String> metaEntries = getRuntimeContext().getBroadcastVariable(MODEL_META_BROADCAST_NAME);
                    if (metaEntries.size() != 1) {
                        throw new AkIllegalModelException("Invalid model.");
                    }
                    String[] modelColNames = Params.fromJson(metaEntries.get(0)).get(HasSelectedCols.SELECTED_COLS);
                    this.selectedColIdxInModel = new int[strArr.length];
                    for (int i = 0; i < strArr.length; i++) {
                        String colName = strArr[i];
                        int idxInModel = TableUtil.findColIndex(modelColNames, colName);
                        if (idxInModel < 0) {
                            throw new AkIllegalModelException("Can't find col in model: " + colName);
                        }
                        this.selectedColIdxInModel[i] = idxInModel;
                    }
                }

                @Override
                public void flatMap(Row modelRow, Collector<Tuple3<Integer, String, Long>> out) throws Exception {
                    Long marker = (Long) modelRow.getField(0);
                    if (marker < 0) {
                        // Negative marker: meta row, handled by getModelMeta.
                        return;
                    }
                    int colIdxInModel = marker.intValue();
                    for (int i = 0; i < this.selectedColIdxInModel.length; i++) {
                        if (this.selectedColIdxInModel[i] == colIdxInModel) {
                            // Only the first selected column that maps to this model column receives the row.
                            out.collect(Tuple3.of(i, (String) modelRow.getField(1), (Long) modelRow.getField(2)));
                            return;
                        }
                    }
                }
            })
            .withBroadcastSet(modelMeta, MODEL_META_BROADCAST_NAME)
            .name("get_model_data")
            .returns(new TupleTypeInfo<>(Types.INT, Types.STRING, Types.LONG));
    }

    /**
     * Builds the prediction pipeline.
     *
     * <p>Steps: tag every data row with a unique id; flatten each row into one
     * (row id, selected column position, index) tuple per selected column; left-outer-join those
     * tuples with the model data on (column position, index) to obtain the token; regroup the
     * tokens by row id into one result row; finally join the result rows back to the original rows
     * and let {@link OutputColsHelper} merge them.
     *
     * @param batchOperatorArr the model operator at position 0 and the data operator at position 1
     * @return this operator, with its output set
     */
    @Override
    public HugeMultiIndexerStringPredictBatchOp linkFrom(BatchOperator<?>... batchOperatorArr) {
        Params params = super.getParams();
        BatchOperator<?> model = batchOperatorArr[0];
        BatchOperator<?> data = batchOperatorArr[1];

        String[] selectedColNames = params.get(MultiStringIndexerPredictParams.SELECTED_COLS);
        String[] outputColNames = params.get(MultiStringIndexerPredictParams.OUTPUT_COLS);
        if (outputColNames == null) {
            // No explicit output columns: results go into the selected columns.
            outputColNames = selectedColNames;
        }
        String[] reservedColNames = params.get(StringIndexerPredictParams.RESERVED_COLS);

        final TypeInformation<?>[] outputColTypes = new TypeInformation<?>[outputColNames.length];
        Arrays.fill(outputColTypes, Types.STRING);

        final OutputColsHelper outputColsHelper = new OutputColsHelper(data.getSchema(), outputColNames, outputColTypes, reservedColNames);
        final int[] selectedColIdx = TableUtil.findColIndicesWithAssertAndHint(data.getSchema(), selectedColNames);

        // NOTE(review): the handle-invalid strategy is resolved here but its value is never consulted
        // by the pipeline below. The lookup is kept so parameter resolution behaves as before.
        HasHandleInvalid.HandleInvalid.valueOf(params.get(StringIndexerPredictParams.HANDLE_INVALID).toString());

        // (row id, original row)
        DataSet<Tuple2<Long, Row>> dataWithId = DataSetUtils.zipWithUniqueId(data.getDataSet());

        // (selected column position, token, index)
        DataSet<Tuple3<Integer, String, Long>> modelData = getModelData(model, getModelMeta(model), selectedColNames);

        // (row id, selected column position, index)
        DataSet<Tuple3<Long, Integer, Long>> flattened = dataWithId
            .flatMap(new RichFlatMapFunction<Tuple2<Long, Row>, Tuple3<Long, Integer, Long>>() {
                private static final long serialVersionUID = 7795878509849151894L;

                @Override
                public void flatMap(Tuple2<Long, Row> idAndRow, Collector<Tuple3<Long, Integer, Long>> out) throws Exception {
                    for (int i = 0; i < selectedColIdx.length; i++) {
                        Object index = idAndRow.f1.getField(selectedColIdx[i]);
                        if (index != null) {
                            out.collect(Tuple3.of(idAndRow.f0, i, (Long) index));
                        } else {
                            // null is replaced by -1 so the tuple still takes part in the outer join;
                            // presumably no model row carries index -1, so it ends up unmatched.
                            out.collect(Tuple3.of(idAndRow.f0, i, -1L));
                        }
                    }
                }
            })
            .name("flatten_pred_data")
            .returns(new TupleTypeInfo<>(Types.LONG, Types.INT, Types.LONG));

        // (row id, selected column position, token)
        DataSet<Tuple3<Long, Integer, String>> tokens = flattened
            .leftOuterJoin(modelData)
            .where(1, 2)
            .equalTo(0, 2)
            .with(new JoinFunction<Tuple3<Long, Integer, Long>, Tuple3<Integer, String, Long>, Tuple3<Long, Integer, String>>() {
                private static final long serialVersionUID = -3177975102816197011L;

                @Override
                public Tuple3<Long, Integer, String> join(Tuple3<Long, Integer, Long> indexed, Tuple3<Integer, String, Long> modelEntry) throws Exception {
                    // Left outer join: modelEntry is null when the index is unknown to the model.
                    String token = modelEntry == null ? "null" : modelEntry.f1;
                    return Tuple3.of(indexed.f0, indexed.f1, token);
                }
            })
            .name("map_index_to_token")
            .returns(new TupleTypeInfo<>(Types.LONG, Types.INT, Types.STRING));

        // (row id, row holding one token per selected column)
        DataSet<Tuple2<Long, Row>> predictions = tokens
            .groupBy(0)
            .reduceGroup(new RichGroupReduceFunction<Tuple3<Long, Integer, String>, Tuple2<Long, Row>>() {
                private static final long serialVersionUID = 2318140138585310686L;

                @Override
                public void reduce(Iterable<Tuple3<Long, Integer, String>> group, Collector<Tuple2<Long, Row>> out) throws Exception {
                    Long rowId = null;
                    Row predicted = new Row(selectedColIdx.length);
                    for (Tuple3<Long, Integer, String> entry : group) {
                        predicted.setField(entry.f1, entry.f2);
                        rowId = entry.f0;
                    }
                    out.collect(Tuple2.of(rowId, predicted));
                }
            })
            .name("aggregate_result")
            .returns(new TupleTypeInfo<>(Types.LONG, new RowTypeInfo(outputColTypes)));

        DataSet<Row> merged = dataWithId
            .join(predictions)
            .where(0)
            .equalTo(0)
            .with(new JoinFunction<Tuple2<Long, Row>, Tuple2<Long, Row>, Row>() {
                private static final long serialVersionUID = 3724539437313089427L;

                @Override
                public Row join(Tuple2<Long, Row> original, Tuple2<Long, Row> predicted) throws Exception {
                    return outputColsHelper.getResultRow(original.f1, predicted.f1);
                }
            })
            .name("merge_result")
            .returns(new RowTypeInfo(outputColsHelper.getResultSchema().getFieldTypes()));

        setOutput(merged, outputColsHelper.getResultSchema());
        return this;
    }
}
