package com.alibaba.alink.operator.batch.recommendation;

import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpecs;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.TypeCollections;
import com.alibaba.alink.common.exceptions.AkIllegalArgumentException;
import com.alibaba.alink.common.exceptions.AkIllegalDataException;
import com.alibaba.alink.common.exceptions.AkPreconditions;
import com.alibaba.alink.common.exceptions.ExceptionWithErrorCode;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.type.AlinkTypes;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.feature.OneHotPredictBatchOp;
import com.alibaba.alink.operator.batch.feature.OneHotTrainBatchOp;
import com.alibaba.alink.operator.batch.similarity.VectorNearestNeighborPredictBatchOp;
import com.alibaba.alink.operator.batch.similarity.VectorNearestNeighborTrainBatchOp;
import com.alibaba.alink.operator.batch.source.DataSetWrapperBatchOp;
import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp;
import com.alibaba.alink.operator.common.dataproc.MultiStringIndexerModelData;
import com.alibaba.alink.operator.common.feature.OneHotModelDataConverter;
import com.alibaba.alink.operator.common.io.types.FlinkTypeConverter;
import com.alibaba.alink.operator.common.recommendation.ItemCfModelInfo;
import com.alibaba.alink.operator.common.recommendation.ItemCfRecommModelDataConverter;
import com.alibaba.alink.operator.common.similarity.NearestNeighborsMapper;
import com.alibaba.alink.params.feature.HasEncodeWithoutWoe;
import com.alibaba.alink.params.recommendation.HasSimilarityType;
import com.alibaba.alink.params.recommendation.ItemCfRecommTrainParams;
import com.alibaba.alink.params.shared.clustering.HasFastMetric;
import java.util.List;
import java.util.TreeMap;
import org.apache.flink.api.common.functions.FilterFunction;
import org.apache.flink.api.common.functions.RichGroupReduceFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.operators.FilterOperator;
import org.apache.flink.api.java.operators.Operator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

@InputPorts(values = {@PortSpec(PortType.DATA)})
@OutputPorts(values = {@PortSpec(PortType.MODEL)})
@ParamSelectColumnSpecs({@ParamSelectColumnSpec(name = "userCol"), @ParamSelectColumnSpec(name = "itemCol"), @ParamSelectColumnSpec(name = "rateCol", allowedTypeCollections = {TypeCollections.NUMERIC_TYPES})})
@NameCn("ItemCf训练")
@NameEn("ItemCf Training")
/* loaded from: input_file:com/alibaba/alink/operator/batch/recommendation/ItemCfTrainBatchOp.class */
public class ItemCfTrainBatchOp extends BatchOperator<ItemCfTrainBatchOp> implements ItemCfRecommTrainParams<ItemCfTrainBatchOp>, WithModelInfoBatchOp<ItemCfModelInfo, ItemCfTrainBatchOp, ItemCfModelInfoBatchOp> {
    private static final long serialVersionUID = -5873113492724718667L;
    private static final String USER_NUM = "userNum";
    private static final String[] COL_NAMES = {"itemId", "itemVector"};

    /* loaded from: input_file:com/alibaba/alink/operator/batch/recommendation/ItemCfTrainBatchOp$ItemSimilarityVectorGenerator.class */
    public static class ItemSimilarityVectorGenerator extends RichMapPartitionFunction<Row, Row> {
        private static final long serialVersionUID = 4250780052412233802L;
        private final String itemCol;
        private long itemNum;

        public ItemSimilarityVectorGenerator(String str) {
            this.itemCol = str;
        }

        public void open(Configuration configuration) {
            for (Row row : getRuntimeContext().getBroadcastVariable(ItemCfTrainBatchOp.USER_NUM)) {
                if (row.getField(0).equals(this.itemCol)) {
                    this.itemNum = ((Long) row.getField(1)).longValue();
                    return;
                }
            }
        }

        public void mapPartition(Iterable<Row> iterable, Collector<Row> collector) {
            for (Row row : iterable) {
                TreeMap treeMap = new TreeMap();
                Object field = row.getField(0);
                Tuple2<List<Object>, List<Object>> extractKObject = NearestNeighborsMapper.extractKObject((String) row.getField(1), Long.class);
                for (int i = 0; i < ((List) extractKObject.f0).size(); i++) {
                    long longValue = ((Long) ((List) extractKObject.f0).get(i)).longValue();
                    double doubleValue = 1.0d - ((Double) ((List) extractKObject.f1).get(i)).doubleValue();
                    if (!field.equals(Long.valueOf(longValue))) {
                        treeMap.put(Integer.valueOf((int) longValue), Double.valueOf(doubleValue));
                    }
                }
                collector.collect(Row.of(new Object[]{null, field, new SparseVector((int) this.itemNum, treeMap)}));
            }
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/batch/recommendation/ItemCfTrainBatchOp$ItemVectorGenerator.class */
    public static class ItemVectorGenerator extends RichGroupReduceFunction<Row, Row> {
        private static final long serialVersionUID = 1783010539701599910L;
        private final String rateCol;
        private final String userCol;
        private long userNum;

        public ItemVectorGenerator(String str, String str2) {
            this.rateCol = str;
            this.userCol = str2;
        }

        public void open(Configuration configuration) {
            for (Row row : getRuntimeContext().getBroadcastVariable(ItemCfTrainBatchOp.USER_NUM)) {
                if (row.getField(0).equals(this.userCol)) {
                    this.userNum = ((Long) row.getField(1)).longValue();
                    return;
                }
            }
        }

        public void reduce(Iterable<Row> iterable, Collector<Row> collector) {
            TreeMap treeMap = new TreeMap();
            Object obj = null;
            for (Row row : iterable) {
                if (null == obj) {
                    AkPreconditions.checkNotNull(row.getField(1), new AkIllegalDataException("Item column is null!"));
                    obj = row.getField(1);
                }
                AkPreconditions.checkNotNull(row.getField(0), new AkIllegalDataException("User column is null!"));
                long longValue = ((Long) row.getField(0)).longValue();
                double d = 1.0d;
                if (null != this.rateCol) {
                    d = ((Number) row.getField(2)).doubleValue();
                }
                treeMap.put(Integer.valueOf((int) longValue), Double.valueOf(d));
            }
            collector.collect(Row.of(new Object[]{obj, new SparseVector((int) this.userNum, treeMap)}));
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/batch/recommendation/ItemCfTrainBatchOp$UserItemVectorGenerator.class */
    public static class UserItemVectorGenerator extends RichGroupReduceFunction<Row, Row> {
        private static final long serialVersionUID = 4250780052412233802L;
        private final String rateCol;
        private final String itemCol;
        private long itemNum;

        public UserItemVectorGenerator(String str, String str2) {
            this.rateCol = str;
            this.itemCol = str2;
        }

        public void open(Configuration configuration) {
            for (Row row : getRuntimeContext().getBroadcastVariable(ItemCfTrainBatchOp.USER_NUM)) {
                if (row.getField(0).equals(this.itemCol)) {
                    this.itemNum = ((Long) row.getField(1)).longValue();
                    return;
                }
            }
        }

        public void reduce(Iterable<Row> iterable, Collector<Row> collector) {
            TreeMap treeMap = new TreeMap();
            Object obj = null;
            for (Row row : iterable) {
                if (null == obj) {
                    obj = row.getField(0);
                }
                AkPreconditions.checkNotNull(row.getField(0), new AkIllegalDataException("User column is null!"));
                long longValue = ((Long) row.getField(1)).longValue();
                double d = 1.0d;
                if (null != this.rateCol) {
                    d = ((Number) row.getField(2)).doubleValue();
                }
                treeMap.put(Integer.valueOf((int) longValue), Double.valueOf(d));
            }
            collector.collect(Row.of(new Object[]{obj, null, new SparseVector((int) this.itemNum, treeMap)}));
        }
    }

    /** Creates the operator with a fresh, empty {@link Params} instance. */
    public ItemCfTrainBatchOp() {
        this(new Params());
    }

    /**
     * Creates the operator with the given parameters.
     *
     * @param params parameters handed to the {@code BatchOperator} superclass constructor
     */
    public ItemCfTrainBatchOp(Params params) {
        super(params);
    }

    /* JADX WARN: Can't rename method to resolve collision */
    /**
     * Wires up the ItemCF training pipeline on the first input operator and sets the model as output.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>Read the user, item and (optional) rate column names; without a rate column the
     *       similarity type is required to be JACCARD.</li>
     *   <li>Train a one-hot model on the user and item columns and index-encode both.</li>
     *   <li>Group by item to build item vectors, and by user to build per-user item vectors.</li>
     *   <li>Run a vector nearest-neighbor search over the item vectors and turn its result into
     *       item similarity vectors.</li>
     *   <li>Union the per-user vectors with the similarity vectors and serialize them, together with
     *       meta parameters built on subtask 0, through {@code ItemCfRecommModelDataConverter}.</li>
     * </ol>
     *
     * @param batchOperatorArr input operators; only the one returned by {@code checkAndGetFirst} is used
     * @return this operator
     */
    @Override // com.alibaba.alink.operator.batch.BatchOperator
    public ItemCfTrainBatchOp linkFrom(BatchOperator<?>... batchOperatorArr) {
        BatchOperator<?> checkAndGetFirst = checkAndGetFirst(batchOperatorArr);
        final String userCol = getUserCol();
        final String itemCol = getItemCol();
        final String rateCol = getRateCol();
        // Column types are captured up front because they are stored in the model meta below.
        final TypeInformation<?> findColTypeWithAssertAndHint = TableUtil.findColTypeWithAssertAndHint(checkAndGetFirst.getSchema(), userCol);
        final String typeString = FlinkTypeConverter.getTypeString(TableUtil.findColTypeWithAssertAndHint(checkAndGetFirst.getSchema(), itemCol));
        // Without ratings only the JACCARD similarity type passes the precondition.
        if (null == rateCol) {
            AkPreconditions.checkArgument(getSimilarityType().equals(HasSimilarityType.SimilarityType.JACCARD), (ExceptionWithErrorCode) new AkIllegalArgumentException("When rateCol is not given, only Jaccard calc is supported!"));
        }
        // Keep only the columns that take part in training.
        BatchOperator<?> select = checkAndGetFirst.select(null == rateCol ? new String[]{userCol, itemCol} : new String[]{userCol, itemCol, rateCol});
        // One-hot model over user and item; prediction writes the user index to "userEncode" and
        // overwrites the item column with its index.
        OneHotTrainBatchOp linkFrom = new OneHotTrainBatchOp().setSelectedCols(userCol, itemCol).linkFrom(select);
        OneHotPredictBatchOp linkFrom2 = new OneHotPredictBatchOp().setSelectedCols(userCol, itemCol).setOutputCols("userEncode", itemCol).setEncode(HasEncodeWithoutWoe.Encode.INDEX).linkFrom(linkFrom, select);
        // Item vectors: group by field 1 (item); side output 0 of the one-hot trainer is broadcast under USER_NUM.
        Operator name = linkFrom2.select(null == rateCol ? new String[]{"userEncode", itemCol} : new String[]{"userEncode", itemCol, rateCol}).getDataSet().groupBy(new int[]{1}).reduceGroup(new ItemVectorGenerator(rateCol, userCol)).withBroadcastSet(linkFrom.getSideOutput(0).getDataSet(), USER_NUM).name("GenerateItemVector");
        // Per-user item vectors: group by field 0 (the original user column).
        Operator name2 = linkFrom2.select(null == rateCol ? new String[]{userCol, itemCol} : new String[]{userCol, itemCol, rateCol}).getDataSet().groupBy(new int[]{0}).reduceGroup(new UserItemVectorGenerator(rateCol, itemCol)).withBroadcastSet(linkFrom.getSideOutput(0).getDataSet(), USER_NUM).name("GetUserItems");
        // Expose the item vectors as a (LONG, SPARSE_VECTOR) table for the nearest-neighbor operators.
        DataSetWrapperBatchOp dataSetWrapperBatchOp = new DataSetWrapperBatchOp(name, COL_NAMES, new TypeInformation[]{Types.LONG, AlinkTypes.SPARSE_VECTOR});
        // Nearest-neighbor search of the item vectors against themselves: topN is maxNeighborNumber + 1
        // and radius is 1 - similarityThreshold; the result is mapped to item similarity vectors.
        Operator name3 = new VectorNearestNeighborPredictBatchOp().setSelectedCol(COL_NAMES[1]).setReservedCols(COL_NAMES[0]).setTopN(Integer.valueOf(getMaxNeighborNumber().intValue() + 1)).setRadius(Double.valueOf(1.0d - getSimilarityThreshold().doubleValue())).linkFrom(new VectorNearestNeighborTrainBatchOp().setIdCol(COL_NAMES[0]).setSelectedCol(COL_NAMES[1]).setMetric(HasFastMetric.Metric.valueOf(getSimilarityType().name())).linkFrom(dataSetWrapperBatchOp), dataSetWrapperBatchOp).select(new String[]{COL_NAMES[0], COL_NAMES[1]}).getDataSet().mapPartition(new ItemSimilarityVectorGenerator(itemCol)).withBroadcastSet(linkFrom.getSideOutput(0).getDataSet(), USER_NUM).name("CalcItemSimilarity");
        // Drop one-hot model rows whose field 0 equals 0L; the rest is broadcast as "ITEM_MAP".
        // NOTE(review): presumably this removes the rows belonging to the first encoded column (user) — confirm.
        FilterOperator filter = linkFrom.getDataSet().filter(new FilterFunction<Row>() { // from class: com.alibaba.alink.operator.batch.recommendation.ItemCfTrainBatchOp.1
            private static final long serialVersionUID = 7406134775433418651L;

            public boolean filter(Row row) {
                return !row.getField(0).equals(0L);
            }
        });
        final Params params = getParams();
        setOutput((DataSet<Row>) name2.union(name3).mapPartition(new RichMapPartitionFunction<Row, Row>() { // from class: com.alibaba.alink.operator.batch.recommendation.ItemCfTrainBatchOp.2
            private static final long serialVersionUID = 3779020277896699637L;

            public void mapPartition(Iterable<Row> iterable, Collector<Row> collector) {
                // Only subtask 0 builds the meta params; every other subtask passes null to the converter.
                Params params2 = null;
                if (getRuntimeContext().getIndexOfThisSubtask() == 0) {
                    MultiStringIndexerModelData multiStringIndexerModelData = new OneHotModelDataConverter().load(getRuntimeContext().getBroadcastVariable("ITEM_MAP")).modelData;
                    // Rebuild the index -> token table for the item column.
                    String[] strArr = new String[(int) multiStringIndexerModelData.getNumberOfTokensOfColumn(itemCol)];
                    for (int i = 0; i < strArr.length; i++) {
                        strArr[i] = multiStringIndexerModelData.getToken(itemCol, Long.valueOf(i));
                    }
                    params2 = params.set((ParamInfo<ParamInfo<String>>) ItemCfRecommTrainParams.RATE_COL, (ParamInfo<String>) rateCol).set((ParamInfo<ParamInfo<String[]>>) ItemCfRecommModelDataConverter.ITEMS, (ParamInfo<String[]>) strArr).set((ParamInfo<ParamInfo<String>>) ItemCfRecommModelDataConverter.ITEM_TYPE, (ParamInfo<String>) typeString).set((ParamInfo<ParamInfo<String>>) ItemCfRecommModelDataConverter.USER_TYPE, (ParamInfo<String>) FlinkTypeConverter.getTypeString((TypeInformation<?>) findColTypeWithAssertAndHint));
                }
                new ItemCfRecommModelDataConverter(userCol, findColTypeWithAssertAndHint, itemCol).save2(Tuple2.of(params2, iterable), collector);
            }
        }).withBroadcastSet(filter, "ITEM_MAP").name("build_model"), new ItemCfRecommModelDataConverter(userCol, findColTypeWithAssertAndHint, itemCol).getModelSchema());
        return this;
    }

    /* JADX WARN: Can't rename method to resolve collision */
    /**
     * Creates a model-info operator configured with this operator's params and links it to this
     * operator.
     *
     * @return the result of linking the new {@code ItemCfModelInfoBatchOp} to this operator
     */
    @Override // com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp
    public ItemCfModelInfoBatchOp getModelInfoBatchOp() {
        ItemCfModelInfoBatchOp modelInfoOp = new ItemCfModelInfoBatchOp(getParams());
        return modelInfoOp.linkFrom(this);
    }

    // Bridge/synthetic overload as surfaced by the decompiler (see the inline markers below).
    // It only forwards the array to the typed linkFrom(BatchOperator<?>...) method above.
    @Override // com.alibaba.alink.operator.batch.BatchOperator
    public /* bridge */ /* synthetic */ ItemCfTrainBatchOp linkFrom(BatchOperator[] batchOperatorArr) {
        return linkFrom((BatchOperator<?>[]) batchOperatorArr);
    }
}
