package com.alibaba.alink.operator.batch.dataproc;

import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.PortDesc;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.ReservedColsWithFirstInputSpec;
import com.alibaba.alink.common.annotation.TypeCollections;
import com.alibaba.alink.common.utils.OutputColsHelper;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.params.dataproc.HasHandleInvalid;
import com.alibaba.alink.params.dataproc.HugeMultiStringIndexerPredictParams;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Iterator;
import org.apache.commons.lang.StringUtils;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.JoinFunction;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.api.java.utils.DataSetUtils;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/**
 * Maps long indices back to string tokens using a string-indexer model, entirely through
 * distributed joins (the model is never broadcast).
 *
 * <p>Input 0 is the model: each row carries the token ({@code String}) in field 0 and its index
 * ({@code Long}) in field 1. Input 1 is the data to predict. Each selected column holds either a
 * single {@code Long} index (basic type) or a {@code Long[]} of indices.
 *
 * <p>Each selected column is replaced by one {@code String} output column:
 * <ul>
 *   <li>The tokens of its indices are joined with {@code ","} in array order.</li>
 *   <li>Indices with no model entry become {@code "notFound"}.</li>
 *   <li>A null cell or null array element is looked up as index {@code -1}, so it yields
 *       {@code "notFound"} unless the model contains index {@code -1}.</li>
 *   <li>An empty array yields {@code ""}.</li>
 * </ul>
 */
@InputPorts(values = {@PortSpec(value = PortType.MODEL, desc = PortDesc.PREDICT_INPUT_MODEL), @PortSpec(value = PortType.DATA, desc = PortDesc.PREDICT_INPUT_DATA)})
@OutputPorts(values = {@PortSpec(value = PortType.DATA, desc = PortDesc.OUTPUT_RESULT)})
@ParamSelectColumnSpec(name = "selectedCols", allowedTypeCollections = {TypeCollections.LONG_TYPES})
@NameCn("超大ID化预测")
@ReservedColsWithFirstInputSpec
@NameEn("Huge Indexer String Prediction")
public final class HugeIndexerStringPredictBatchOp extends BatchOperator<HugeIndexerStringPredictBatchOp>
    implements HugeMultiStringIndexerPredictParams<HugeIndexerStringPredictBatchOp> {
    private static final long serialVersionUID = -794572755838107745L;

    /** Token produced for an index that has no entry in the model. */
    private static final String NOT_FOUND_TOKEN = "notFound";

    /** Index looked up for a null cell or a null array element. */
    private static final long MISSING_INDEX = -1L;

    /**
     * Position of a placeholder record emitted for an empty array. It keeps the row (and column)
     * alive through the group-by and join, but contributes no token.
     */
    private static final int EMPTY_ARRAY_POSITION = -1;

    public HugeIndexerStringPredictBatchOp() {
        this(new Params());
    }

    public HugeIndexerStringPredictBatchOp(Params params) {
        super(params);
    }

    /**
     * Converts model rows into (token, index) pairs.
     *
     * @param model the model operator; field 0 of each row is read as the {@code String} token and
     *              field 1 as its {@code Long} index
     * @return a data set of (token, index) pairs
     */
    private static DataSet<Tuple2<String, Long>> getModelData(BatchOperator<?> model) {
        return model.getDataSet()
            .flatMap(new ModelRowParser())
            .name("get_model_data")
            .returns(new TupleTypeInfo<Tuple2<String, Long>>(Types.STRING, Types.LONG));
    }

    /**
     * Builds the prediction pipeline.
     *
     * @param inputs {@code inputs[0]} is the model operator, {@code inputs[1]} the data operator
     * @return this operator, with its output set to the reserved columns plus one {@code String}
     *         column per selected column
     * @throws IllegalArgumentException if {@code outputCols} is set and its length differs from
     *                                  {@code selectedCols}
     */
    @Override
    public HugeIndexerStringPredictBatchOp linkFrom(BatchOperator<?>... inputs) {
        Params params = super.getParams();
        BatchOperator<?> model = inputs[0];
        BatchOperator<?> data = inputs[1];

        String[] selectedColNames = params.get(HugeMultiStringIndexerPredictParams.SELECTED_COLS);
        String[] outputColNames = params.get(HugeMultiStringIndexerPredictParams.OUTPUT_COLS);
        if (outputColNames == null) {
            outputColNames = selectedColNames;
        }
        // Every selected column yields exactly one output column; a mismatch would produce rows
        // whose arity disagrees with the declared output type.
        if (outputColNames.length != selectedColNames.length) {
            throw new IllegalArgumentException("outputCols and selectedCols must have the same length, but got "
                + outputColNames.length + " and " + selectedColNames.length);
        }
        String[] reservedColNames = params.get(HugeMultiStringIndexerPredictParams.RESERVED_COLS);

        TypeInformation<?>[] selectedColTypes = TableUtil.findColTypesWithAssert(data.getSchema(), selectedColNames);
        TypeInformation<?>[] outputColTypes = new TypeInformation<?>[outputColNames.length];
        Arrays.fill(outputColTypes, Types.STRING);
        final OutputColsHelper outputColsHelper =
            new OutputColsHelper(data.getSchema(), outputColNames, outputColTypes, reservedColNames);
        int[] selectedColIndices = TableUtil.findColIndicesWithAssertAndHint(data.getSchema(), selectedColNames);

        // Basic-typed columns hold one Long index; any other type is treated as a Long[].
        boolean[] isScalarCol = new boolean[selectedColTypes.length];
        for (int col = 0; col < isScalarCol.length; col++) {
            isScalarCol[col] = selectedColTypes[col].isBasicType();
        }

        // NOTE(review): the handle-invalid strategy is parsed here but never applied; indices absent
        // from the model always become "notFound". Confirm whether the strategy should be honoured.
        HasHandleInvalid.HandleInvalid.valueOf(
            params.get(HugeMultiStringIndexerPredictParams.HANDLE_INVALID).toString());

        // Tag every data row with an id so predictions can be joined back to it later.
        DataSet<Tuple2<Long, Row>> indexedData = DataSetUtils.zipWithUniqueId(data.getDataSet());

        DataSet<Tuple4<Long, Integer, Integer, Long>> flattened = indexedData
            .flatMap(new FlattenIndices(selectedColIndices, isScalarCol))
            .name("flatten_pred_data")
            .returns(new TupleTypeInfo<Tuple4<Long, Integer, Integer, Long>>(
                Types.LONG, Types.INT, Types.INT, Types.LONG));

        DataSet<Tuple4<Long, Integer, Integer, String>> tokens = flattened
            .leftOuterJoin(getModelData(model))
            .where(3)
            .equalTo(1)
            .with(new IndexToToken())
            .name("map_index_to_token")
            .returns(new TupleTypeInfo<Tuple4<Long, Integer, Integer, String>>(
                Types.LONG, Types.INT, Types.INT, Types.STRING));

        DataSet<Tuple2<Long, Row>> aggregated = tokens
            .groupBy(0)
            .reduceGroup(new AggregateTokens(selectedColIndices.length))
            .name("aggregate_result")
            .returns(new TupleTypeInfo<Tuple2<Long, Row>>(Types.LONG, new RowTypeInfo(outputColTypes)));

        DataSet<Row> result = indexedData
            .join(aggregated)
            .where(0)
            .equalTo(0)
            .with(new MergeResult(outputColsHelper))
            .name("merge_result")
            .returns(new RowTypeInfo(outputColsHelper.getResultSchema().getFieldTypes()));

        setOutput(result, outputColsHelper.getResultSchema());
        return this;
    }

    /** Reads a model row as a (token, index) pair from fields 0 and 1. */
    private static final class ModelRowParser implements FlatMapFunction<Row, Tuple2<String, Long>> {
        private static final long serialVersionUID = 7697943140162154366L;

        @Override
        public void flatMap(Row row, Collector<Tuple2<String, Long>> out) {
            out.collect(Tuple2.of((String) row.getField(0), (Long) row.getField(1)));
        }
    }

    /**
     * Emits one (rowId, columnOrdinal, position, index) record per index to look up. Every selected
     * column emits at least one record, so no row can vanish before the final join.
     */
    private static final class FlattenIndices
        implements FlatMapFunction<Tuple2<Long, Row>, Tuple4<Long, Integer, Integer, Long>> {
        private static final long serialVersionUID = -8382461068855755626L;

        private final int[] selectedColIndices;
        private final boolean[] isScalarCol;

        FlattenIndices(int[] selectedColIndices, boolean[] isScalarCol) {
            this.selectedColIndices = selectedColIndices;
            this.isScalarCol = isScalarCol;
        }

        @Override
        public void flatMap(Tuple2<Long, Row> value, Collector<Tuple4<Long, Integer, Integer, Long>> out) {
            Long rowId = value.f0;
            Row row = value.f1;
            for (int col = 0; col < selectedColIndices.length; col++) {
                Object field = row.getField(selectedColIndices[col]);
                if (field == null) {
                    out.collect(Tuple4.of(rowId, col, 0, MISSING_INDEX));
                } else if (isScalarCol[col]) {
                    out.collect(Tuple4.of(rowId, col, 0, (Long) field));
                } else {
                    Long[] indices = (Long[]) field;
                    if (indices.length == 0) {
                        // Placeholder so the row still reaches the aggregation and renders as "".
                        out.collect(Tuple4.of(rowId, col, EMPTY_ARRAY_POSITION, MISSING_INDEX));
                    }
                    for (int pos = 0; pos < indices.length; pos++) {
                        Long index = indices[pos];
                        // Null elements are looked up like null cells instead of emitting a null tuple field.
                        out.collect(Tuple4.of(rowId, col, pos, index == null ? MISSING_INDEX : index));
                    }
                }
            }
        }
    }

    /** Replaces the index of each record with its model token, or "notFound" when unmatched. */
    private static final class IndexToToken implements JoinFunction<
        Tuple4<Long, Integer, Integer, Long>, Tuple2<String, Long>, Tuple4<Long, Integer, Integer, String>> {
        private static final long serialVersionUID = 2270459281179536013L;

        @Override
        public Tuple4<Long, Integer, Integer, String> join(
            Tuple4<Long, Integer, Integer, Long> record, Tuple2<String, Long> modelEntry) {
            // modelEntry is null when the left outer join found no matching model index.
            String token = modelEntry == null ? NOT_FOUND_TOKEN : modelEntry.f0;
            return Tuple4.of(record.f0, record.f1, record.f2, token);
        }
    }

    /**
     * Collects all token records of one row and joins each column's tokens with "," in array order.
     * A column that received only the empty-array placeholder becomes "".
     */
    private static final class AggregateTokens
        implements GroupReduceFunction<Tuple4<Long, Integer, Integer, String>, Tuple2<Long, Row>> {
        private static final long serialVersionUID = -1581264399340055162L;

        private final int numCols;

        AggregateTokens(int numCols) {
            this.numCols = numCols;
        }

        @Override
        public void reduce(Iterable<Tuple4<Long, Integer, Integer, String>> values,
                           Collector<Tuple2<Long, Row>> out) {
            Long rowId = null;
            ArrayList<Tuple3<Integer, Integer, String>> entries = new ArrayList<>();
            for (Tuple4<Long, Integer, Integer, String> value : values) {
                rowId = value.f0;
                entries.add(Tuple3.of(value.f1, value.f2, value.f3));
            }

            // Order by column, then by position inside the column.
            entries.sort(new Comparator<Tuple3<Integer, Integer, String>>() {
                @Override
                public int compare(Tuple3<Integer, Integer, String> a, Tuple3<Integer, Integer, String> b) {
                    int byColumn = Integer.compare(a.f0, b.f0);
                    return byColumn != 0 ? byColumn : Integer.compare(a.f1, b.f1);
                }
            });

            StringBuilder[] joined = new StringBuilder[numCols];
            boolean[] hasToken = new boolean[numCols];
            for (Tuple3<Integer, Integer, String> entry : entries) {
                int col = entry.f0;
                if (joined[col] == null) {
                    joined[col] = new StringBuilder();
                }
                if (entry.f1 == EMPTY_ARRAY_POSITION) {
                    continue;
                }
                if (hasToken[col]) {
                    joined[col].append(',');
                }
                joined[col].append(entry.f2);
                hasToken[col] = true;
            }

            Row row = new Row(numCols);
            for (int col = 0; col < numCols; col++) {
                row.setField(col, joined[col] == null ? null : joined[col].toString());
            }
            out.collect(Tuple2.of(rowId, row));
        }
    }

    /** Combines an original data row with its predicted token columns. */
    private static final class MergeResult implements JoinFunction<Tuple2<Long, Row>, Tuple2<Long, Row>, Row> {
        private static final long serialVersionUID = 3724539437313089427L;

        private final OutputColsHelper outputColsHelper;

        MergeResult(OutputColsHelper outputColsHelper) {
            this.outputColsHelper = outputColsHelper;
        }

        @Override
        public Row join(Tuple2<Long, Row> input, Tuple2<Long, Row> predicted) {
            return outputColsHelper.getResultRow(input.f1, predicted.f1);
        }
    }
}
