package com.alibaba.alink.operator.common.recommendation;

import com.alibaba.alink.common.MTable;
import com.alibaba.alink.common.exceptions.AkIllegalArgumentException;
import com.alibaba.alink.common.exceptions.AkPreconditions;
import com.alibaba.alink.common.exceptions.ExceptionWithErrorCode;
import com.alibaba.alink.common.linalg.BLAS;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.operator.common.io.types.FlinkTypeConverter;
import com.alibaba.alink.operator.common.recommendation.RecommUtils;
import com.alibaba.alink.operator.common.utils.PackBatchOperatorUtil;
import com.alibaba.alink.params.recommendation.BaseItemsPerUserRecommParams;
import com.alibaba.alink.params.recommendation.BaseRateRecommParams;
import com.alibaba.alink.params.recommendation.BaseUsersPerItemRecommParams;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

/* loaded from: input_file:com/alibaba/alink/operator/common/recommendation/AlsRecommKernel.class */
/**
 * Recommendation kernel backed by an ALS model. Every user and every item owns a dense factor
 * vector, and the score of a (user, item) pair — or of two users / two items — is the dot product
 * of their factor vectors.
 *
 * <p>Supports rating prediction ({@link #rate}), top-K items per user, top-K users per item and
 * top-K similar users / items. The factor tables (and, when {@code excludeKnown} is enabled, the
 * known user-item history) are kept in transient maps that are populated by {@link #loadModel}.
 */
public class AlsRecommKernel extends RecommKernel {
    private static final long serialVersionUID = 2716744280007281817L;
    private static final String MISSING_TOP_K = "Missing param topK";

    /** User id -> user factor vector; filled by {@link #loadModel}. */
    protected transient Map<Object, DenseVector> userFactors;
    /** Item id -> item factor vector; filled by {@link #loadModel}. */
    protected transient Map<Object, DenseVector> itemFactors;
    /** User id -> items seen for that user; only filled when {@code excludeKnown} is true. */
    protected transient Map<Object, Set<Object>> historyUserItems;
    /** Item id -> users seen for that item; only filled when {@code excludeKnown} is true. */
    protected transient Map<Object, Set<Object>> historyItemUsers;
    /** Number of results to return; may be null for recommendation types that do not need it. */
    private final Integer topK;
    /** Whether already-known pairs are filtered out of items-per-user / users-per-item results. */
    private final boolean excludeKnown;

    /**
     * Creates a kernel for the given recommendation type.
     *
     * @param tableSchema  model schema, forwarded to {@link RecommKernel}
     * @param tableSchema2 data schema, forwarded to {@link RecommKernel}
     * @param params       operator parameters (user/item column, K, excludeKnown)
     * @param recommType   which kind of recommendation this kernel serves
     * @throws AkIllegalArgumentException (via {@link AkPreconditions}) if a top-K style
     *                                    recommendation is requested without a K value
     */
    public AlsRecommKernel(TableSchema tableSchema, TableSchema tableSchema2, Params params, RecommType recommType) {
        super(tableSchema, tableSchema2, params, recommType);
        this.userColName = (String) getParamDefaultAsNull(params, BaseRateRecommParams.USER_COL);
        this.itemColName = (String) getParamDefaultAsNull(params, BaseRateRecommParams.ITEM_COL);
        this.topK = (Integer) getParamDefaultAsNull(params, BaseItemsPerUserRecommParams.K);

        // Every recommendation type except plain rating needs K.
        boolean needsTopK = recommType == RecommType.ITEMS_PER_USER
            || recommType == RecommType.USERS_PER_ITEM
            || recommType == RecommType.SIMILAR_USERS
            || recommType == RecommType.SIMILAR_ITEMS;
        if (needsTopK) {
            AkPreconditions.checkArgument(this.topK != null,
                (ExceptionWithErrorCode) new AkIllegalArgumentException(MISSING_TOP_K));
        }

        // Only the user/item cross recommendations honour excludeKnown.
        boolean exclude = false;
        if (recommType == RecommType.ITEMS_PER_USER) {
            exclude = (Boolean) params.get(BaseItemsPerUserRecommParams.EXCLUDE_KNOWN);
        } else if (recommType == RecommType.USERS_PER_ITEM) {
            exclude = (Boolean) params.get(BaseUsersPerItemRecommParams.EXCLUDE_KNOWN);
        }
        this.excludeKnown = exclude;
    }

    /**
     * Reads a parameter, falling back to its declared default, and finally to {@code null} when
     * the parameter is neither set nor has a default.
     */
    private static <T> T getParamDefaultAsNull(Params params, ParamInfo<T> paramInfo) {
        if (params.contains(paramInfo)) {
            return params.get(paramInfo);
        }
        return paramInfo.hasDefaultValue() ? paramInfo.getDefaultValue() : null;
    }

    /**
     * Loads the packed model rows: the meta row (first field == -1), the user factor table
     * (packed table 0), the item factor table (packed table 1) and, when {@code excludeKnown} is
     * enabled, the user-item history (packed table 2).
     */
    @Override // com.alibaba.alink.operator.common.recommendation.RecommKernel
    public void loadModel(List<Row> list) {
        for (Row row : list) {
            if (((Number) row.getField(0)).intValue() == -1) {
                applyMeta((String) row.getField(1));
            }
        }
        this.userFactors = toFactorMap(PackBatchOperatorUtil.unpackRows(list, 0));
        this.itemFactors = toFactorMap(PackBatchOperatorUtil.unpackRows(list, 1));
        if (this.excludeKnown) {
            buildHistory(PackBatchOperatorUtil.unpackRows(list, 2));
        }
    }

    /**
     * Parses the meta JSON. Entry 2 of {@code f0} holds [userColName, itemColName]. Entry 2 of
     * {@code f1} holds the indices of those columns in the model schema, which are used to pick
     * the type of the recommended object.
     */
    private void applyMeta(String metaJson) {
        Tuple2<?, ?> meta = (Tuple2<?, ?>) JsonConverter.fromJson(metaJson, Tuple2.class);
        List<?> colNames = thirdEntry(meta.f0);
        List<?> colIndices = thirdEntry(meta.f1);
        this.userColName = (String) colNames.get(0);
        this.itemColName = (String) colNames.get(1);
        // Number rather than Integer so that any numeric JSON decoding is accepted.
        int userColIdx = ((Number) colIndices.get(0)).intValue();
        int itemColIdx = ((Number) colIndices.get(1)).intValue();

        if (this.recommType == RecommType.ITEMS_PER_USER || this.recommType == RecommType.SIMILAR_ITEMS) {
            this.recommObjType = getModelSchema().getFieldTypes()[itemColIdx];
        } else if (this.recommType == RecommType.USERS_PER_ITEM || this.recommType == RecommType.SIMILAR_USERS) {
            this.recommObjType = getModelSchema().getFieldTypes()[userColIdx];
        }
    }

    /** Returns element 2 of a meta list, itself expected to be a list. */
    private static List<?> thirdEntry(Object metaPart) {
        return (List<?>) ((List<?>) metaPart).get(2);
    }

    /** Builds an id -> factor map from rows of (id, vector). */
    private static Map<Object, DenseVector> toFactorMap(List<Row> rows) {
        Map<Object, DenseVector> factors = new HashMap<>();
        for (Row row : rows) {
            factors.put(row.getField(0), VectorUtil.getDenseVector(row.getField(1)));
        }
        return factors;
    }

    /** Builds both history directions from rows of (user, item). */
    private void buildHistory(List<Row> rows) {
        Map<Object, Set<Object>> userItems = new HashMap<>();
        Map<Object, Set<Object>> itemUsers = new HashMap<>();
        for (Row row : rows) {
            Object user = row.getField(0);
            Object item = row.getField(1);
            userItems.computeIfAbsent(user, k -> new HashSet<>()).add(item);
            itemUsers.computeIfAbsent(item, k -> new HashSet<>()).add(user);
        }
        this.historyItemUsers = itemUsers;
        this.historyUserItems = userItems;
    }

    /**
     * Predicts the rating of a (user, item) pair as the dot product of their factors.
     *
     * @param objArr {@code [userId, itemId]}
     * @return the predicted rating, or {@code null} if either id has no factor vector
     */
    @Override // com.alibaba.alink.operator.common.recommendation.RecommKernel
    public Double rate(Object[] objArr) {
        return predictRating(objArr);
    }

    /** Recommends the top-K items for a user, optionally skipping items the user already has. */
    @Override // com.alibaba.alink.operator.common.recommendation.RecommKernel
    public MTable recommendItemsPerUser(Object obj) {
        Set<Object> known = this.excludeKnown ? this.historyUserItems.get(obj) : null;
        return recommend(this.itemColName, this.userFactors.get(obj), known, this.itemFactors, KObjectUtil.RATING_NAME);
    }

    /** Recommends the top-K users for an item, optionally skipping users already seen with it. */
    @Override // com.alibaba.alink.operator.common.recommendation.RecommKernel
    public MTable recommendUsersPerItem(Object obj) {
        Set<Object> known = this.excludeKnown ? this.historyItemUsers.get(obj) : null;
        return recommend(this.userColName, this.itemFactors.get(obj), known, this.userFactors, KObjectUtil.RATING_NAME);
    }

    /** Recommends the top-K items most similar to the given item, excluding the item itself. */
    @Override // com.alibaba.alink.operator.common.recommendation.RecommKernel
    public MTable recommendSimilarItems(Object obj) {
        Set<Object> self = new HashSet<>();
        self.add(obj);
        return recommend(this.itemColName, this.itemFactors.get(obj), self, this.itemFactors, KObjectUtil.SCORE_NAME);
    }

    /** Recommends the top-K users most similar to the given user, excluding the user itself. */
    @Override // com.alibaba.alink.operator.common.recommendation.RecommKernel
    public MTable recommendSimilarUsers(Object obj) {
        Set<Object> self = new HashSet<>();
        self.add(obj);
        return recommend(this.userColName, this.userFactors.get(obj), self, this.userFactors, KObjectUtil.SCORE_NAME);
    }

    /** Dot product of the user and item factors, or null when either factor is missing. */
    private Double predictRating(Object[] objArr) {
        DenseVector userVec = this.userFactors.get(objArr[0]);
        DenseVector itemVec = this.itemFactors.get(objArr[1]);
        if (userVec == null || itemVec == null) {
            return null;
        }
        return BLAS.dot(userVec, itemVec);
    }

    /**
     * Scores every candidate against the query vector and keeps the results in a priority queue
     * sized by topK.
     *
     * @param objColName  name of the output id column
     * @param query       factor vector of the query object; if null, an empty table is returned
     * @param excluded    candidate ids to skip; may be null
     * @param candidates  candidate id -> factor vector
     * @param scoreColName name of the output score column
     * @return an MTable with columns {@code (objColName <recommObjType>, scoreColName DOUBLE)}
     */
    private MTable recommend(String objColName, DenseVector query, Set<Object> excluded,
                             Map<Object, DenseVector> candidates, String scoreColName) {
        RecommUtils.RecommPriorityQueue queue = new RecommUtils.RecommPriorityQueue(this.topK.intValue());
        if (query != null) {
            for (Map.Entry<Object, DenseVector> entry : candidates.entrySet()) {
                Object id = entry.getKey();
                if (excluded == null || !excluded.contains(id)) {
                    queue.addOrReplace(id, BLAS.dot(query, entry.getValue()));
                }
            }
        }
        String schema = objColName + " " + FlinkTypeConverter.getTypeString(this.recommObjType)
            + "," + scoreColName + " DOUBLE";
        return new MTable(queue.getOrderedRows(), schema);
    }

    /** Creates a fresh kernel with the same schemas, a copy of the params and the same type. */
    @Override // com.alibaba.alink.operator.common.recommendation.RecommKernel
    public RecommKernel createNew() {
        // Params.clone() copies the parameter map so the new kernel does not share mutable state.
        return new AlsRecommKernel(getModelSchema(), getDataSchema(), this.params.clone(), this.recommType);
    }
}
