package com.alibaba.alink.operator.common.classification.tensorflow;

import com.alibaba.alink.common.exceptions.AkPreconditions;
import com.alibaba.alink.common.linalg.tensor.FloatTensor;
import java.util.List;
import java.util.Map;

/* JADX INFO: Access modifiers changed from: package-private */
/* loaded from: input_file:com/alibaba/alink/operator/common/classification/tensorflow/PredictionExtractUtils.class */
/**
 * Helpers that turn the raw output tensor of a classification model into a predicted label
 * plus a per-label score map.
 */
public class PredictionExtractUtils {
    /** Not meant to be instantiated; only static helpers live here. */
    PredictionExtractUtils() {
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Picks the predicted entry of {@code list} from a rank-0 or rank-1 tensor and records a
     * score for every candidate in {@code map}.
     *
     * <ul>
     *   <li>If the tensor holds exactly one value, that value (passed through a sigmoid when
     *       {@code z} is true) is stored as the score of {@code list.get(1)}, and one minus it as
     *       the score of {@code list.get(0)}. {@code list.get(1)} is returned when the score is
     *       at least 0.5, otherwise {@code list.get(0)}.</li>
     *   <li>Otherwise the first {@code list.size()} tensor values are used, one per list entry.
     *       When {@code z} is true they are normalised with a softmax first; when false they are
     *       stored unchanged. The entry with the largest score is returned (the first one wins
     *       on ties).</li>
     * </ul>
     *
     * @param floatTensor model output; must be rank-0 or rank-1
     * @param list        candidate results, indexed in the same order as the tensor values
     * @param map         output parameter; receives one score per entry of {@code list}
     * @param z           true to apply sigmoid/softmax to the tensor values, false to use them
     *                    as they are
     * @return the chosen entry of {@code list}
     * @throws IndexOutOfBoundsException if {@code list} is too short for the branch taken
     *                                   (fewer than two entries for a single-value tensor, or
     *                                   empty otherwise), per the {@link List#get(int)} contract
     */
    public static Object extractFromTensor(FloatTensor floatTensor, List<Object> list, Map<Object, Double> map, boolean z) {
        AkPreconditions.checkState(floatTensor.shape().length <= 1, "The prediction tensor must be rank-0 or rank-1");
        if (floatTensor.size() == 1) {
            return extractFromSingleValue(floatTensor, list, map, z);
        }
        if (z) {
            return extractWithSoftmax(floatTensor, list, map);
        }
        return extractFromRawScores(floatTensor, list, map);
    }

    /** Handles a tensor with exactly one value: a two-way decision thresholded at 0.5. */
    private static Object extractFromSingleValue(FloatTensor floatTensor, List<Object> list, Map<Object, Double> map, boolean z) {
        // A rank-0 tensor is addressed with an empty coordinate array, a rank-1 tensor with index 0.
        double score = floatTensor.shape().length == 0
            ? floatTensor.getFloat(new long[0])
            : floatTensor.getFloat(0);
        if (z) {
            // Sigmoid; Math.exp saturating to Infinity/0 still yields the correct limits 0 and 1.
            score = 1.0d / (1.0d + Math.exp(-score));
        }
        Object lowChoice = list.get(0);
        Object highChoice = list.get(1);
        // Insertion order is kept as-is so the result is unchanged if both entries are equal.
        map.put(highChoice, Double.valueOf(score));
        map.put(lowChoice, Double.valueOf(1.0d - score));
        return score >= 0.5d ? highChoice : lowChoice;
    }

    /** Applies a numerically stable softmax over the first {@code list.size()} tensor values. */
    private static Object extractWithSoftmax(FloatTensor floatTensor, List<Object> list, Map<Object, Double> map) {
        int count = list.size();
        double[] probs = new double[count];
        double maxValue = Double.NEGATIVE_INFINITY;
        for (int k = 0; k < count; k++) {
            probs[k] = floatTensor.getFloat(k);
            if (probs[k] > maxValue) {
                maxValue = probs[k];
            }
        }

        // Shifting by the maximum leaves the softmax mathematically unchanged but keeps every
        // exponent <= 0, so Math.exp cannot overflow to Infinity, and the largest term is
        // exactly 1, so the sum cannot underflow to 0. Without the shift, large-magnitude
        // inputs produced NaN for every class and the argmax silently fell back to index 0.
        double sum = 0.0d;
        for (int k = 0; k < count; k++) {
            probs[k] = Math.exp(probs[k] - maxValue);
            sum += probs[k];
        }

        int best = 0;
        for (int k = 0; k < count; k++) {
            probs[k] = probs[k] / sum;
            map.put(list.get(k), Double.valueOf(probs[k]));
            // Strict comparison: the earliest maximum wins on ties.
            if (probs[k] > probs[best]) {
                best = k;
            }
        }
        return list.get(best);
    }

    /** Stores the first {@code list.size()} tensor values unchanged and returns the arg-max entry. */
    private static Object extractFromRawScores(FloatTensor floatTensor, List<Object> list, Map<Object, Double> map) {
        int best = 0;
        double bestScore = 0.0d;
        for (int k = 0; k < list.size(); k++) {
            double score = floatTensor.getFloat(k);
            map.put(list.get(k), Double.valueOf(score));
            // Seed with the first value, then replace only on a strictly larger one
            // (earliest maximum wins on ties).
            if (k == 0 || score > bestScore) {
                best = k;
                bestScore = score;
            }
        }
        return list.get(best);
    }
}
