package com.alibaba.alink.operator.batch.dataproc;

import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.PortDesc;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.ReservedColsWithFirstInputSpec;
import com.alibaba.alink.common.annotation.SelectedColsWithFirstInputSpec;
import com.alibaba.alink.common.exceptions.AkIllegalDataException;
import com.alibaba.alink.common.exceptions.AkIllegalModelException;
import com.alibaba.alink.common.utils.OutputColsHelper;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.params.dataproc.HasHandleInvalid;
import com.alibaba.alink.params.dataproc.HugeMultiStringIndexerPredictParams;
import com.alibaba.alink.params.dataproc.MultiStringIndexerPredictParams;
import com.alibaba.alink.params.dataproc.StringIndexerPredictParams;
import com.alibaba.alink.params.shared.colname.HasSelectedCols;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.JoinFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.functions.RichGroupReduceFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.api.java.utils.DataSetUtils;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

@InputPorts(values = {@PortSpec(value = PortType.MODEL, desc = PortDesc.PREDICT_INPUT_MODEL), @PortSpec(value = PortType.DATA, desc = PortDesc.PREDICT_INPUT_DATA)})
@OutputPorts(values = {@PortSpec(value = PortType.DATA, desc = PortDesc.OUTPUT_RESULT)})
@SelectedColsWithFirstInputSpec
@NameCn("HugeMultiStringIndexer预测")
@ReservedColsWithFirstInputSpec
@NameEn("Huge Multi String Indexer Prediction")
/* loaded from: input_file:com/alibaba/alink/operator/batch/dataproc/HugeMultiStringIndexerPredictBatchOp.class */
public final class HugeMultiStringIndexerPredictBatchOp extends BatchOperator<HugeMultiStringIndexerPredictBatchOp> implements HugeMultiStringIndexerPredictParams<HugeMultiStringIndexerPredictBatchOp> {
    private static final long serialVersionUID = 2965070988618051205L;

    /**
     * Creates the operator with a newly constructed {@link Params} instance by delegating to
     * {@link #HugeMultiStringIndexerPredictBatchOp(Params)}.
     */
    public HugeMultiStringIndexerPredictBatchOp() {
        this(new Params());
    }

    /**
     * Creates the operator with the given parameters.
     *
     * @param params the parameters handed to the {@link BatchOperator} super constructor
     */
    public HugeMultiStringIndexerPredictBatchOp(Params params) {
        super(params);
    }

    /**
     * Extracts the model meta string(s) from the model table.
     *
     * <p>A model row is treated as a meta row when its first field (a {@code Long}) is negative;
     * the second field of such a row is emitted as the meta string.
     *
     * @param batchOperator the operator holding the model rows
     * @return a data set with the second field of every meta row
     */
    private DataSet<String> getModelMeta(BatchOperator<?> batchOperator) {
        return batchOperator.getDataSet()
            .flatMap(new RichFlatMapFunction<Row, String>() {
                private static final long serialVersionUID = 7056981277013604788L;

                @Override
                public void flatMap(Row row, Collector<String> collector) throws Exception {
                    long rowId = (Long) row.getField(0);
                    // Negative ids mark meta rows; everything else is model data.
                    if (rowId < 0) {
                        collector.collect((String) row.getField(1));
                    }
                }
            })
            .name("get_model_meta")
            .returns(Types.STRING);
    }

    /**
     * Extracts the token entries of the requested columns from the model table.
     *
     * <p>Model rows whose first field is non-negative are treated as data rows. That first field
     * is matched against the position of each requested column within the selected columns
     * stored in the model meta (read from the "modelMeta" broadcast variable). A matching row is
     * emitted once, keyed by the position of the first matching column in {@code strArr}.
     *
     * @param batchOperator the operator holding the model rows
     * @param dataSet       the model meta; must contain exactly one JSON string at runtime
     * @param strArr        names of the columns to look up in the model
     * @return tuples of (position in {@code strArr}, field 1 of the model row, field 2 of the
     *     model row)
     */
    private DataSet<Tuple3<Integer, String, Long>> getModelData(BatchOperator<?> batchOperator, DataSet<String> dataSet, final String[] strArr) {
        return batchOperator.getDataSet()
            .flatMap(new RichFlatMapFunction<Row, Tuple3<Integer, String, Long>>() {
                private static final long serialVersionUID = 6211320484769507949L;
                transient int[] selectedColIdxInModel;

                /**
                 * Resolves, for every requested column, its position among the model's selected columns.
                 *
                 * @throws AkIllegalModelException if the meta is not a single element or a column is missing
                 */
                @Override
                public void open(Configuration configuration) throws Exception {
                    List<String> modelMeta = getRuntimeContext().getBroadcastVariable("modelMeta");
                    if (modelMeta.size() != 1) {
                        throw new AkIllegalModelException("Invalid model.");
                    }
                    String[] modelColNames = (String[]) Params.fromJson(modelMeta.get(0)).get(HasSelectedCols.SELECTED_COLS);
                    this.selectedColIdxInModel = new int[strArr.length];
                    for (int i = 0; i < strArr.length; i++) {
                        String colName = strArr[i];
                        int idxInModel = TableUtil.findColIndex(modelColNames, colName);
                        if (idxInModel < 0) {
                            throw new AkIllegalModelException("Can't find col in model: " + colName);
                        }
                        this.selectedColIdxInModel[i] = idxInModel;
                    }
                }

                @Override
                public void flatMap(Row row, Collector<Tuple3<Integer, String, Long>> collector) throws Exception {
                    Long rowId = (Long) row.getField(0);
                    if (rowId < 0) {
                        // Meta rows carry no token data.
                        return;
                    }
                    int colIdxInModel = rowId.intValue();
                    for (int i = 0; i < this.selectedColIdxInModel.length; i++) {
                        if (this.selectedColIdxInModel[i] == colIdxInModel) {
                            collector.collect(Tuple3.of(i, (String) row.getField(1), (Long) row.getField(2)));
                            // Each model row is emitted at most once.
                            return;
                        }
                    }
                }
            })
            .withBroadcastSet(dataSet, "modelMeta")
            .name("get_model_data")
            .returns(new TupleTypeInfo<>(Types.INT, Types.STRING, Types.LONG));
    }

    /**
     * Builds the prediction pipeline that replaces the values of the selected columns with the
     * indices stored in the model.
     *
     * <p>Pipeline overview:
     * <ol>
     *   <li>Every data row gets a unique id.</li>
     *   <li>Non-null values of the selected columns are flattened to (row id, column position,
     *       string value) and left-outer-joined with the model entries on (column position,
     *       string value); values without a match get the index {@code -1}.</li>
     *   <li>Null values are emitted directly with the index {@code -1}.</li>
     *   <li>All indices of a row are grouped into one result row. An index of {@code -1} is
     *       resolved according to the handle-invalid strategy: {@code KEEP} uses the column's
     *       default index (largest index of that column in the model plus one, or {@code 0} if
     *       the column has no default), {@code SKIP} writes {@code null}, {@code ERROR} throws
     *       {@link AkIllegalDataException}.</li>
     *   <li>The result rows are joined back to the original rows by id and merged through
     *       {@link OutputColsHelper}.</li>
     * </ol>
     *
     * @param batchOperatorArr the inputs; index 0 is the model, index 1 is the data to predict on
     * @return this operator, with its output set to the merged result
     * @throws IllegalArgumentException if fewer than two inputs are given
     */
    @Override
    public HugeMultiStringIndexerPredictBatchOp linkFrom(BatchOperator<?>... batchOperatorArr) {
        if (batchOperatorArr == null || batchOperatorArr.length < 2) {
            throw new IllegalArgumentException(
                "HugeMultiStringIndexerPredictBatchOp requires two inputs: the model and the data to predict on.");
        }
        Params params = super.getParams();
        BatchOperator<?> model = batchOperatorArr[0];
        BatchOperator<?> data = batchOperatorArr[1];

        String[] selectedColNames = (String[]) params.get(MultiStringIndexerPredictParams.SELECTED_COLS);
        String[] outputColNames = (String[]) params.get(MultiStringIndexerPredictParams.OUTPUT_COLS);
        if (outputColNames == null) {
            // Without explicit output columns the selected columns are overwritten.
            outputColNames = selectedColNames;
        }
        String[] reservedColNames = (String[]) params.get(StringIndexerPredictParams.RESERVED_COLS);
        final TypeInformation<?>[] outputColTypes = new TypeInformation<?>[outputColNames.length];
        Arrays.fill(outputColTypes, Types.LONG);

        final OutputColsHelper outputColsHelper =
            new OutputColsHelper(data.getSchema(), outputColNames, outputColTypes, reservedColNames);
        final int[] selectedColIdx = TableUtil.findColIndicesWithAssertAndHint(data.getSchema(), selectedColNames);
        final HasHandleInvalid.HandleInvalid handleInvalid = HasHandleInvalid.HandleInvalid.valueOf(
            ((HasHandleInvalid.HandleInvalid) params.get(StringIndexerPredictParams.HANDLE_INVALID)).toString());

        // Tuple: row id, original row.
        DataSet<Tuple2<Long, Row>> dataWithId = DataSetUtils.zipWithUniqueId(data.getDataSet());

        // Tuple: column position in selectedColNames, token, token index.
        DataSet<Tuple3<Integer, String, Long>> modelData = getModelData(model, getModelMeta(model), selectedColNames);

        // Tuple: row id, column position, string value (non-null values only).
        DataSet<Tuple3<Long, Integer, String>> flattened = dataWithId
            .flatMap(new RichFlatMapFunction<Tuple2<Long, Row>, Tuple3<Long, Integer, String>>() {
                private static final long serialVersionUID = -8382461068855755626L;

                @Override
                public void flatMap(Tuple2<Long, Row> idAndRow, Collector<Tuple3<Long, Integer, String>> collector) throws Exception {
                    for (int i = 0; i < selectedColIdx.length; i++) {
                        Object value = idAndRow.f1.getField(selectedColIdx[i]);
                        if (value != null) {
                            collector.collect(Tuple3.of(idAndRow.f0, i, String.valueOf(value)));
                        }
                    }
                }
            })
            .name("flatten_pred_data")
            .returns(new TupleTypeInfo<>(Types.LONG, Types.INT, Types.STRING));

        // Tuple: row id, column position, token index (-1 when the token is not in the model).
        DataSet<Tuple3<Long, Integer, Long>> indexed = flattened
            .leftOuterJoin(modelData)
            .where(1, 2)
            .equalTo(0, 1)
            .with(new JoinFunction<Tuple3<Long, Integer, String>, Tuple3<Integer, String, Long>, Tuple3<Long, Integer, Long>>() {
                private static final long serialVersionUID = -2049684621727655644L;

                @Override
                public Tuple3<Long, Integer, Long> join(Tuple3<Long, Integer, String> token, Tuple3<Integer, String, Long> modelEntry) throws Exception {
                    if (modelEntry == null) {
                        return Tuple3.of(token.f0, token.f1, -1L);
                    }
                    return Tuple3.of(token.f0, token.f1, modelEntry.f2);
                }
            })
            .name("map_token_to_index")
            .returns(new TupleTypeInfo<>(Types.LONG, Types.INT, Types.LONG));

        // Null values never reach the join above; they are emitted with index -1 here.
        DataSet<Tuple3<Long, Integer, Long>> indexedNullTokens = dataWithId
            .flatMap(new FlatMapFunction<Tuple2<Long, Row>, Tuple3<Long, Integer, Long>>() {
                private static final long serialVersionUID = 4078100010408649546L;

                @Override
                public void flatMap(Tuple2<Long, Row> idAndRow, Collector<Tuple3<Long, Integer, Long>> collector) throws Exception {
                    for (int i = 0; i < selectedColIdx.length; i++) {
                        if (idAndRow.f1.getField(selectedColIdx[i]) == null) {
                            collector.collect(Tuple3.of(idAndRow.f0, i, -1L));
                        }
                    }
                }
            })
            .name("map_null_token_to_index")
            .returns(new TupleTypeInfo<>(Types.LONG, Types.INT, Types.LONG));

        // Tuple: column position, default index (largest model index of the column + 1).
        DataSet<Tuple2<Integer, Long>> defaultIndices = modelData
            .<Tuple2<Integer, Long>>project(0, 2)
            .mapPartition(new MapPartitionFunction<Tuple2<Integer, Long>, Tuple2<Integer, Long>>() {
                @Override
                public void mapPartition(Iterable<Tuple2<Integer, Long>> iterable, Collector<Tuple2<Integer, Long>> collector) throws Exception {
                    // Pre-aggregate per partition so that the grouped reduce sees one entry per column.
                    Map<Integer, Long> maxIndexPerCol = new HashMap<>();
                    for (Tuple2<Integer, Long> colAndIndex : iterable) {
                        long knownMax = maxIndexPerCol.getOrDefault(colAndIndex.f0, 0L);
                        maxIndexPerCol.put(colAndIndex.f0, Math.max(knownMax, colAndIndex.f1));
                    }
                    for (Map.Entry<Integer, Long> entry : maxIndexPerCol.entrySet()) {
                        collector.collect(Tuple2.of(entry.getKey(), entry.getValue()));
                    }
                }
            })
            .groupBy(0)
            .reduce(new ReduceFunction<Tuple2<Integer, Long>>() {
                private static final long serialVersionUID = 5053931294560858595L;

                @Override
                public Tuple2<Integer, Long> reduce(Tuple2<Integer, Long> left, Tuple2<Integer, Long> right) throws Exception {
                    return Tuple2.of(left.f0, Math.max(left.f1, right.f1));
                }
            })
            .map(new MapFunction<Tuple2<Integer, Long>, Tuple2<Integer, Long>>() {
                private static final long serialVersionUID = 2371384596429653822L;

                @Override
                public Tuple2<Integer, Long> map(Tuple2<Integer, Long> colAndMax) throws Exception {
                    return Tuple2.of(colAndMax.f0, colAndMax.f1 + 1);
                }
            })
            .name("get_default_index")
            .returns(new TupleTypeInfo<>(Types.INT, Types.LONG));

        // Tuple: row id, row holding one index per selected column.
        DataSet<Tuple2<Long, Row>> aggregateResult = indexed
            .union(indexedNullTokens)
            .groupBy(0)
            .reduceGroup(new RichGroupReduceFunction<Tuple3<Long, Integer, Long>, Tuple2<Long, Row>>() {
                private static final long serialVersionUID = -1581264399340055162L;
                transient Map<Integer, Long> defaultIndex;

                @Override
                public void open(Configuration configuration) throws Exception {
                    // Only strategies other than SKIP and ERROR need the default indices.
                    if (handleInvalid.equals(HasHandleInvalid.HandleInvalid.SKIP)
                        || handleInvalid.equals(HasHandleInvalid.HandleInvalid.ERROR)) {
                        return;
                    }
                    List<Tuple2<Integer, Long>> broadcast = getRuntimeContext().getBroadcastVariable("defaultIndex");
                    this.defaultIndex = new HashMap<>();
                    for (Tuple2<Integer, Long> colAndDefault : broadcast) {
                        this.defaultIndex.put(colAndDefault.f0, colAndDefault.f1);
                    }
                }

                @Override
                public void reduce(Iterable<Tuple3<Long, Integer, Long>> iterable, Collector<Tuple2<Long, Row>> collector) throws Exception {
                    Long rowId = null;
                    Row indices = new Row(selectedColIdx.length);
                    for (Tuple3<Long, Integer, Long> entry : iterable) {
                        Long index = entry.f2;
                        if (index == -1L) {
                            // -1 marks a null value or a token that is not in the model.
                            switch (handleInvalid) {
                                case KEEP:
                                    Long fallback = this.defaultIndex.get(entry.f1);
                                    index = fallback == null ? 0L : fallback;
                                    break;
                                case SKIP:
                                    index = null;
                                    break;
                                case ERROR:
                                    throw new AkIllegalDataException("Unknown token.");
                                default:
                                    break;
                            }
                        }
                        indices.setField(entry.f1, index);
                        rowId = entry.f0;
                    }
                    collector.collect(Tuple2.of(rowId, indices));
                }
            })
            .withBroadcastSet(defaultIndices, "defaultIndex")
            .name("aggregate_result")
            .returns(new TupleTypeInfo<>(Types.LONG, new RowTypeInfo(outputColTypes)));

        // Attach the computed indices to the original rows.
        DataSet<Row> result = dataWithId
            .join(aggregateResult)
            .where(0)
            .equalTo(0)
            .with(new JoinFunction<Tuple2<Long, Row>, Tuple2<Long, Row>, Row>() {
                private static final long serialVersionUID = 3724539437313089427L;

                @Override
                public Row join(Tuple2<Long, Row> original, Tuple2<Long, Row> predicted) throws Exception {
                    return outputColsHelper.getResultRow(original.f1, predicted.f1);
                }
            })
            .name("merge_result")
            .returns(new RowTypeInfo(outputColsHelper.getResultSchema().getFieldTypes()));

        setOutput(result, outputColsHelper.getResultSchema());
        return this;
    }
}
