package com.alibaba.alink.operator.batch.recommendation;

import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpecs;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.exceptions.AkIllegalDataException;
import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.dataproc.HugeIndexerStringPredictBatchOp;
import com.alibaba.alink.operator.batch.dataproc.HugeStringIndexerPredictBatchOp;
import com.alibaba.alink.operator.batch.dataproc.StringIndexerTrainBatchOp;
import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil;
import com.alibaba.alink.operator.common.io.types.FlinkTypeConverter;
import com.alibaba.alink.operator.common.recommendation.SwingRecommKernel;
import com.alibaba.alink.operator.common.recommendation.SwingRecommModelConverter;
import com.alibaba.alink.operator.common.recommendation.SwingResData;
import com.alibaba.alink.params.recommendation.SwingTrainParams;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ThreadLocalRandom;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.RichGroupReduceFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/**
 * Trains a Swing item-to-item similarity model from (user, item) interaction rows.
 *
 * <p>Pipeline:
 * <ol>
 *   <li>Only the user and item columns are kept. Items must be INT, LONG or STRING; string items are first
 *       mapped to numeric ids with a random-ordered {@link StringIndexerTrainBatchOp} and a
 *       {@link HugeStringIndexerPredictBatchOp}, whose output is appended as an extra column.</li>
 *   <li>{@link BuildSwingData} groups rows by user, de-duplicates the user's items and drops users whose
 *       distinct-item count lies outside {@code [minUserItems, maxUserItems]}.</li>
 *   <li>{@link CalcSimilarity} groups the per-user item lists by item id, samples at most
 *       {@code maxItemNumber} users per item and scores every other item shared by pairs of sampled users.</li>
 *   <li>For string items the similar-item column is passed through {@link HugeIndexerStringPredictBatchOp}
 *       (presumably mapping ids back to the original strings — confirm), then {@link BuildModelData} writes
 *       one meta row followed by one JSON row per item.</li>
 * </ol>
 */
@InputPorts(values = {@PortSpec(PortType.DATA)})
@OutputPorts(values = {@PortSpec(PortType.MODEL)})
@ParamSelectColumnSpecs({@ParamSelectColumnSpec(name = "userCol"), @ParamSelectColumnSpec(name = "itemCol")})
@NameCn("swing训练")
@NameEn("Swing Recommendation Training")
public class SwingTrainBatchOp extends BatchOperator<SwingTrainBatchOp> implements SwingTrainParams<SwingTrainBatchOp> {
    private static final long serialVersionUID = 6094224433980263495L;

    /** Name of the temporary numeric-id column appended when the item column is a string. */
    private static final String ITEM_ID_COLNAME = "alink_itemID_in_swing";

    /**
     * Serializes the similarity rows into model rows.
     *
     * <p>Subtask 0 first emits a meta row {@code (null, meta.toJson())}. Every input row
     * {@code (item, similarItems, scores)} then becomes {@code (item, json(SwingResData))}. The similar items
     * arrive either as a {@code Long[]} or as a single comma-separated {@code String}.
     */
    public static class BuildModelData extends RichMapPartitionFunction<Row, Row> {
        private final String itemCol;
        private final Params meta;

        BuildModelData(String itemCol, Params meta) {
            this.itemCol = itemCol;
            this.meta = meta;
        }

        @Override
        public void mapPartition(Iterable<Row> rows, Collector<Row> out) throws Exception {
            // Exactly one subtask writes the meta row so the model contains it once.
            if (getRuntimeContext().getIndexOfThisSubtask() == 0) {
                out.collect(Row.of(new Object[] {null, meta.toJson()}));
            }
            for (Row row : rows) {
                Object similarItems = row.getField(1);
                Float[] scores = (Float[]) row.getField(2);
                SwingResData resData = new SwingResData(
                    similarItems instanceof String ? ((String) similarItems).split(",") : (Long[]) similarItems,
                    scores,
                    itemCol);
                out.collect(Row.of(new Object[] {row.getField(0), JsonConverter.toJson(resData)}));
            }
        }
    }

    /**
     * Per-user reducer: collapses one user's rows into the set of distinct item ids.
     *
     * <p>Users with fewer than {@code minUserItems} or more than {@code maxUserItems} distinct items are
     * dropped. For each remaining item it emits {@code (originalItem, itemId, allItemIdsOfUser)}. The same
     * array instance is shared by all tuples emitted for one user.
     */
    public static class BuildSwingData implements GroupReduceFunction<Row, Tuple3<Comparable<?>, Long, Long[]>> {
        private static final long serialVersionUID = 6417591701594465880L;
        int maxUserItems;
        int minUserItems;
        /** Row position of the numeric item id (1 = raw item column, 2 = appended index column). */
        int idIndex;

        BuildSwingData(int maxUserItems, int minUserItems, int idIndex) {
            this.maxUserItems = maxUserItems;
            this.minUserItems = minUserItems;
            this.idIndex = idIndex;
        }

        @Override
        public void reduce(Iterable<Row> rows, Collector<Tuple3<Comparable<?>, Long, Long[]>> out) throws Exception {
            Map<Long, Comparable<?>> idToItem = new HashMap<>();
            for (Row row : rows) {
                Object rawId = row.getField(idIndex);
                if (rawId == null) {
                    // A missing item cannot be turned into an id; skip it instead of failing on Long.parseLong("null").
                    continue;
                }
                idToItem.put(Long.parseLong(String.valueOf(rawId)), (Comparable<?>) row.getField(1));
            }
            if (idToItem.size() < minUserItems || idToItem.size() > maxUserItems) {
                return;
            }
            Long[] userItems = idToItem.keySet().toArray(new Long[0]);
            for (Long itemId : userItems) {
                out.collect(new Tuple3<Comparable<?>, Long, Long[]>(idToItem.get(itemId), itemId, userItems));
            }
        }
    }

    /**
     * Per-item reducer: computes Swing similarities from the item-id lists of the users who touched the item.
     *
     * <p>At most {@code maxItemNumber} user lists are kept, chosen by uniform reservoir sampling. For every
     * pair of sampled users (u, v) with c common items, each common item other than the grouped item gets
     * {@code w(u) * w(v) / (alpha + c)} added to its score, where {@code w(x) = 1 / (userAlpha + |x|)^userBeta}.
     * The output row is {@code (item, similarItemIds, scores)}, sorted by descending score. When
     * {@code normalize} is set, scores are divided by the top score. No row is emitted when nothing scored.
     */
    public static class CalcSimilarity extends RichGroupReduceFunction<Tuple3<Comparable<?>, Long, Long[]>, Row> {
        private static final long serialVersionUID = -2438120820385058339L;
        private final float alpha;
        int maxItemNumber;
        int maxUserItems;
        float userAlpha;
        float userBeta;
        boolean normalize;

        CalcSimilarity(float alpha, int maxItemNumber, int maxUserItems, float userAlpha, float userBeta,
                       boolean normalize) {
            this.alpha = alpha;
            this.userAlpha = userAlpha;
            this.userBeta = userBeta;
            this.maxItemNumber = maxItemNumber;
            this.maxUserItems = maxUserItems;
            this.normalize = normalize;
        }

        /** Weight of a user that owns {@code itemCount} distinct items. */
        private float computeUserWeight(int itemCount) {
            return (float) (1.0d / Math.pow(this.userAlpha + itemCount, this.userBeta));
        }

        @Override
        public void reduce(Iterable<Tuple3<Comparable<?>, Long, Long[]>> values, Collector<Row> out) throws Exception {
            Comparable<?> mainItem = null;
            Long mainItemId = null;
            List<Long[]> sampledUsers = new ArrayList<>();
            long seen = 0;
            for (Tuple3<Comparable<?>, Long, Long[]> value : values) {
                mainItem = value.f0;
                mainItemId = value.f1;
                if (sampledUsers.size() == maxItemNumber) {
                    // Algorithm R: the (seen+1)-th element replaces a reservoir slot with probability max/(seen+1),
                    // which keeps every user equally likely to be sampled.
                    long slot = ThreadLocalRandom.current().nextLong(seen + 1);
                    if (slot < maxItemNumber) {
                        sampledUsers.set((int) slot, value.f2);
                    }
                } else {
                    sampledUsers.add(value.f2);
                }
                seen++;
            }

            int userCount = sampledUsers.size();
            float[] userWeights = new float[userCount];
            int longest = 0;
            for (int u = 0; u < userCount; u++) {
                Long[] items = sampledUsers.get(u);
                // countCommonItems relies on sorted input.
                Arrays.sort(items);
                userWeights[u] = computeUserWeight(items.length);
                longest = Math.max(longest, items.length);
            }

            Map<Long, Float> scores = new HashMap<>();
            // Intersections never exceed the longest sampled list, so size the scratch buffer to that.
            long[] common = new long[longest];
            for (int a = 0; a < userCount; a++) {
                for (int b = a + 1; b < userCount; b++) {
                    int numCommon = countCommonItems(sampledUsers.get(a), sampledUsers.get(b), common);
                    if (numCommon == 0) {
                        continue;
                    }
                    float contribution = (userWeights[a] * userWeights[b]) / (alpha + numCommon);
                    for (int k = 0; k < numCommon; k++) {
                        Long candidate = common[k];
                        if (!candidate.equals(mainItemId)) {
                            scores.merge(candidate, contribution, Float::sum);
                        }
                    }
                }
            }
            if (scores.isEmpty()) {
                return;
            }

            List<Map.Entry<Long, Float>> ranked = new ArrayList<>(scores.entrySet());
            Comparator<Map.Entry<Long, Float>> byScoreDesc = Map.Entry.<Long, Float>comparingByValue().reversed();
            ranked.sort(byScoreDesc);

            Long[] similarIds = new Long[ranked.size()];
            Float[] similarScores = new Float[ranked.size()];
            float denominator = normalize ? ranked.get(0).getValue() : 1.0f;
            for (int r = 0; r < ranked.size(); r++) {
                Map.Entry<Long, Float> entry = ranked.get(r);
                similarIds[r] = entry.getKey();
                similarScores[r] = entry.getValue() / denominator;
            }
            out.collect(Row.of(new Object[] {mainItem, similarIds, similarScores}));
        }

        /**
         * Writes the intersection of two ascending-sorted id arrays into {@code buffer}.
         *
         * @return the number of common ids written
         */
        private static int countCommonItems(Long[] left, Long[] right, long[] buffer) {
            int i = 0;
            int j = 0;
            int found = 0;
            while (i < left.length && j < right.length) {
                long l = left[i];
                long r = right[j];
                if (l == r) {
                    buffer[found++] = l;
                    i++;
                    j++;
                } else if (l < r) {
                    i++;
                } else {
                    j++;
                }
            }
            return found;
        }
    }

    /** Extracts field {@code index} of a row as the grouping key. */
    public static class RowKeySelector implements KeySelector<Row, Comparable<?>> {
        private static final long serialVersionUID = 7514280642434354647L;
        int index;

        public RowKeySelector(int index) {
            this.index = index;
        }

        @Override
        public Comparable<?> getKey(Row row) {
            return (Comparable<?>) row.getField(this.index);
        }
    }

    public SwingTrainBatchOp(Params params) {
        super(params);
    }

    public SwingTrainBatchOp() {
        this(new Params());
    }

    /**
     * Builds the Swing model from the first input.
     *
     * <p>Side effect: the item type string is stored in this operator's params under
     * {@code SwingRecommKernel.ITEM_TYPE}, and those params are written as the model meta.
     *
     * @param inputs operators whose first element provides the user and item columns
     * @return this operator
     * @throws AkIllegalDataException if the item column is not INT, LONG or STRING
     */
    @Override
    public SwingTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
        String userCol = getUserCol();
        String itemCol = getItemCol();
        int maxUserItems = getMaxUserItems();
        int minUserItems = getMinUserItems();
        int maxItemNumber = getMaxItemNumber();
        boolean normalize = getResultNormalize();
        Long mlEnvId = getMLEnvironmentId();

        BatchOperator<?> input = checkAndGetFirst(inputs).select(new String[] {userCol, itemCol});
        TypeInformation<?> itemType = TableUtil.findColType(input.getSchema(), itemCol);
        boolean isStringItem = Types.STRING.equals(itemType);
        if (!isStringItem && !Types.INT.equals(itemType) && !Types.LONG.equals(itemType)) {
            throw new AkIllegalDataException("not supported item type:" + itemType + ", should be int,long or string");
        }

        // Position of the numeric item id: the item column itself, or the index column appended for strings.
        int itemIdIndex = 1;
        StringIndexerTrainBatchOp indexer = null;
        if (isStringItem) {
            indexer = new StringIndexerTrainBatchOp()
                .setSelectedCol(itemCol)
                .setStringOrderType("random")
                .setMLEnvironmentId(mlEnvId);
            indexer.linkFrom(input);
            input = new HugeStringIndexerPredictBatchOp()
                .setSelectedCols(itemCol)
                .setOutputCols(ITEM_ID_COLNAME)
                .setMLEnvironmentId(mlEnvId)
                .linkFrom(indexer, input);
            itemIdIndex = 2;
        }

        DataSet<Tuple3<Comparable<?>, Long, Long[]>> userItemLists = input.getDataSet()
            .groupBy(new RowKeySelector(0))
            .reduceGroup(new BuildSwingData(maxUserItems, minUserItems, itemIdIndex))
            .name("build_main_item_data");

        DataSet<Row> similarities = userItemLists
            .groupBy(1)
            .reduceGroup(new CalcSimilarity(getAlpha().floatValue(), maxItemNumber, maxUserItems,
                getUserAlpha().floatValue(), getUserBeta().floatValue(), normalize))
            .name("compute_similarity");

        BatchOperator<?> similarityOp = BatchOperator.fromTable(DataSetConversionUtil.toTable(
            mlEnvId,
            similarities,
            new String[] {itemCol, "swing_items", "swing_scores"},
            new TypeInformation<?>[] {itemType, Types.OBJECT_ARRAY(Types.LONG), Types.OBJECT_ARRAY(Types.FLOAT)}));

        if (isStringItem) {
            similarityOp = new HugeIndexerStringPredictBatchOp()
                .setSelectedCols("swing_items")
                .setMLEnvironmentId(mlEnvId)
                .linkFrom(indexer, similarityOp);
        }

        Params meta = getParams().set(SwingRecommKernel.ITEM_TYPE, FlinkTypeConverter.getTypeString(itemType));
        DataSet<Row> modelData = similarityOp.getDataSet()
            .mapPartition(new BuildModelData(itemCol, meta))
            .name("build_model_data");
        setOutput(modelData, new SwingRecommModelConverter(itemType).getModelSchema());
        return this;
    }
}
