package com.alibaba.alink.operator.common.evaluation;

import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.ParamInfoFactory;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;

/* loaded from: input_file:com/alibaba/alink/operator/common/evaluation/RankingMetrics.class */
/**
 * Ranking evaluation metrics (hit rate, ARHR, MAP, and precision/recall/NDCG at cutoff k).
 *
 * <p>The per-cutoff metrics are stored as arrays in which element {@code j} is assumed to hold
 * the metric at cutoff {@code j + 1} (the accessors read index {@code k - 1} for cutoff {@code k}).
 */
public class RankingMetrics extends BaseSimpleMultiLabelMetrics<RankingMetrics> {
    private static final long serialVersionUID = 3357268529296753541L;

    /** Hit rate. */
    static final ParamInfo<Double> HIT_RATE = ParamInfoFactory.createParamInfo("hitRate", Double.class).setDescription("hit rate").setRequired().build();
    /** Average reciprocal hit rank (ARHR). */
    static final ParamInfo<Double> AVE_RECIPRO_HIT_RANK = ParamInfoFactory.createParamInfo("averageReciprocalHitRank", Double.class).setDescription("Average Reciprocal Hit Rank").setRequired().build();
    /** Precision per cutoff; PRECISION = TP / (TP + FP). */
    static final ParamInfo<double[]> PRECISION_ARRAY = ParamInfoFactory.createParamInfo("precisionArray", double[].class).setDescription("precision list, PRECISION: TP / (TP + FP)").setRequired().build();
    /** Recall per cutoff; recall == TPR. */
    static final ParamInfo<double[]> RECALL_ARRAY = ParamInfoFactory.createParamInfo("RecallArray", double[].class).setDescription("recall list, recall == TPR").setRequired().build();
    /** Mean average precision. */
    static final ParamInfo<Double> MAP = ParamInfoFactory.createParamInfo("map", Double.class).setDescription("map").setRequired().build();
    /** NDCG per cutoff. */
    static final ParamInfo<double[]> NDCG_ARRAY = ParamInfoFactory.createParamInfo("ndcgArray", double[].class).setDescription("ndcg").setRequired().build();

    public RankingMetrics(Row row) {
        super(row);
    }

    public RankingMetrics(Params params) {
        super(params);
    }

    /**
     * Returns NDCG at cutoff {@code k}.
     *
     * @param k cutoff, must be at least 1; cutoffs beyond the stored list return the last stored value
     * @return NDCG@k
     * @throws IllegalArgumentException if {@code k < 1}
     * @throws IllegalStateException if no NDCG values are stored
     */
    public double getNdcg(int k) {
        double[] ndcg = perCutoffValues(NDCG_ARRAY, k);
        return ndcg[Math.min(k, ndcg.length) - 1];
    }

    /**
     * Returns the mean average precision.
     *
     * @return MAP
     */
    public double getMap() {
        return get(MAP);
    }

    /**
     * Returns precision at cutoff {@code k}.
     *
     * <p>For {@code k} larger than the stored list length {@code n}, the value is extrapolated as
     * {@code precision@n * n / k}. In other words, the number of hits found in the top {@code n}
     * is kept fixed while the denominator grows to {@code k}.
     *
     * @param k cutoff, must be at least 1
     * @return precision@k
     * @throws IllegalArgumentException if {@code k < 1}
     * @throws IllegalStateException if no precision values are stored
     */
    public double getPrecisionAtK(int k) {
        double[] precision = perCutoffValues(PRECISION_ARRAY, k);
        int n = precision.length;
        if (k <= n) {
            return precision[k - 1];
        }
        // Beyond the stored cutoffs no further hits are known, so scale by n / k.
        return precision[n - 1] * n / k;
    }

    /**
     * Returns the hit rate.
     *
     * @return hit rate
     */
    public double getHitRate() {
        return get(HIT_RATE);
    }

    /**
     * Returns the average reciprocal hit rank (ARHR).
     *
     * @return ARHR
     */
    public double getArHr() {
        return get(AVE_RECIPRO_HIT_RANK);
    }

    /**
     * Returns recall at cutoff {@code k}.
     *
     * @param k cutoff, must be at least 1; cutoffs beyond the stored list return the last stored value
     * @return recall@k
     * @throws IllegalArgumentException if {@code k < 1}
     * @throws IllegalStateException if no recall values are stored
     */
    public double getRecallAtK(int k) {
        double[] recall = perCutoffValues(RECALL_ARRAY, k);
        return recall[Math.min(k, recall.length) - 1];
    }

    /** Validates cutoff {@code k} and returns the non-empty per-cutoff array stored under {@code info}. */
    private double[] perCutoffValues(ParamInfo<double[]> info, int k) {
        if (k < 1) {
            throw new IllegalArgumentException("Cutoff k must be >= 1, but got " + k);
        }
        double[] values = get(info);
        if (values == null || values.length == 0) {
            throw new IllegalStateException("No values stored for metric '" + info.getName() + "'");
        }
        return values;
    }
}
