package com.alibaba.alink.operator.common.linear;

import com.alibaba.alink.common.exceptions.AkIllegalDataException;
import com.alibaba.alink.common.linalg.BLAS;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.operator.common.linear.LabelTypeEnum;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.types.Row;

import java.util.Locale;

/* loaded from: input_file:com/alibaba/alink/operator/common/linear/FeatureLabelUtil.class */
/**
 * Static helpers that extract feature vectors, labels and sample weights from Flink {@link Row}s
 * and compute dot products between feature vectors and a dense vector.
 */
public class FeatureLabelUtil {

    /**
     * Parses a vector-valued field into a {@link Vector}, optionally prepending an intercept term.
     *
     * @param input            raw field value, parsed with {@link VectorUtil#getVector(Object)}
     * @param hasInterceptItem if true, a leading 1.0 is prepended via {@link Vector#prefix(double)}
     * @param vectorSize       currently not used by this method
     * @return the parsed vector, prefixed with 1.0 when {@code hasInterceptItem} is true
     */
    public static Vector getVectorFeature(Object input, boolean hasInterceptItem, Integer vectorSize) {
        // NOTE(review): unlike getFeatureVector, sparse vectors are NOT resized to vectorSize here;
        // confirm with callers whether ignoring vectorSize is intentional.
        Vector vector = VectorUtil.getVector(input);
        return hasInterceptItem ? vector.prefix(1.0d) : vector;
    }

    /**
     * Builds a dense feature vector from individual numeric table columns.
     *
     * <p>Fields that are not {@link Number}s (including {@code null}) are not written, so they keep
     * the freshly allocated {@link DenseVector}'s initial value (presumably 0.0).
     *
     * @param row              the input row
     * @param hasInterceptItem if true, slot 0 is set to 1.0 and features start at slot 1
     * @param featureNum       number of feature columns to read
     * @param featureIndices   column indices; only the first {@code featureNum} entries are read
     * @return a dense vector of size {@code featureNum} (+1 when {@code hasInterceptItem})
     */
    public static Vector getTableFeature(Row row, boolean hasInterceptItem, int featureNum, int[] featureIndices) {
        // With an intercept, every feature is shifted one slot to the right.
        int offset = hasInterceptItem ? 1 : 0;
        DenseVector features = new DenseVector(featureNum + offset);
        if (hasInterceptItem) {
            features.set(0, 1.0d);
        }
        for (int k = 0; k < featureNum; k++) {
            Object field = row.getField(featureIndices[k]);
            if (field instanceof Number) {
                features.set(k + offset, ((Number) field).doubleValue());
            }
        }
        return features;
    }

    /**
     * Extracts the feature vector of a row, either from a single vector column or from a set of
     * numeric table columns.
     *
     * @param row              the input row
     * @param hasInterceptItem if true, a leading 1.0 is added to the result
     * @param featureNum       number of feature columns (used only when {@code vectorColIndex == -1})
     * @param featureIndices   feature column indices (used only when {@code vectorColIndex == -1})
     * @param vectorColIndex   index of the vector column, or -1 to read table columns instead
     * @param vectorSize       if non-null, a parsed sparse vector whose size is &gt; 0 is resized to it
     * @return the feature vector
     */
    public static Vector getFeatureVector(Row row, boolean hasInterceptItem, int featureNum,
                                          int[] featureIndices, int vectorColIndex, Integer vectorSize) {
        if (vectorColIndex == -1) {
            return getTableFeature(row, hasInterceptItem, featureNum, featureIndices);
        }
        Vector vector = VectorUtil.getVector(row.getField(vectorColIndex));
        if (vector instanceof SparseVector) {
            SparseVector sparse = (SparseVector) vector;
            // Resizes the parsed vector in place; sizes <= 0 are left untouched.
            if (null != vectorSize && sparse.size() > 0) {
                sparse.setSize(vectorSize.intValue());
            }
            return hasInterceptItem ? sparse.prefix(1.0d) : sparse;
        }
        // Any vector that is not sparse is expected to be dense; other types fail this cast.
        DenseVector dense = (DenseVector) vector;
        return hasInterceptItem ? dense.prefix(1.0d) : dense;
    }

    /**
     * Reads the label of a row as a double.
     *
     * @param row           the input row
     * @param numericLabel  if true, the label field is cast to {@link Number} and returned as-is
     * @param labelColIndex index of the label column
     * @param positiveLabel when {@code numericLabel} is false, the string that maps to 1.0
     * @return the numeric label, or 1.0 / -1.0 depending on whether the label's {@code toString()}
     *     equals {@code positiveLabel}
     * @throws NullPointerException if the label field is {@code null}
     * @throws ClassCastException   if {@code numericLabel} is true and the field is not a Number
     */
    public static double getLabelValue(Row row, boolean numericLabel, int labelColIndex, String positiveLabel) {
        Object label = row.getField(labelColIndex);
        if (numericLabel) {
            return ((Number) label).doubleValue();
        }
        return label.toString().equals(positiveLabel) ? 1.0d : -1.0d;
    }

    /**
     * Converts stored label values back to the label column's original type, in place.
     *
     * <p>{@link String} elements are converted through {@code LabelTypeEnum.StringTypeEnum} and
     * {@link Double} elements through {@code LabelTypeEnum.DoubleTypeEnum}. Both enums are looked
     * up by the upper-cased {@code toString()} of {@code labelType}. {@code null} and other element
     * types are left unchanged.
     *
     * @param labels    the labels to convert; modified in place (may be {@code null})
     * @param labelType the original label type
     * @return {@code labels} itself, or {@code null} if {@code labels} is {@code null}
     * @throws AkIllegalDataException if converting a String element fails for any reason
     */
    public static Object[] recoverLabelType(Object[] labels, TypeInformation<?> labelType) {
        if (labels == null) {
            return null;
        }
        for (int idx = 0; idx < labels.length; idx++) {
            Object label = labels[idx];
            if (label instanceof String) {
                try {
                    labels[idx] = LabelTypeEnum.StringTypeEnum.valueOf(upperCaseTypeName(labelType))
                        .getOperation().apply((String) label);
                } catch (Exception e) {
                    // NOTE(review): this also catches conversion failures of the value itself, and
                    // the cause is not chained; consider attaching `e` if the exception type allows.
                    throw new AkIllegalDataException("Unknown label type: " + labelType);
                }
            } else if (label instanceof Double) {
                // An unknown type name propagates the enum lookup's exception unwrapped.
                labels[idx] = LabelTypeEnum.DoubleTypeEnum.valueOf(upperCaseTypeName(labelType))
                    .getOperation().apply((Double) label);
            }
        }
        return labels;
    }

    /**
     * Returns the upper-cased {@code toString()} of a type, used as an enum constant name.
     * {@link Locale#ROOT} is used so the result does not depend on the JVM default locale
     * (e.g. in a Turkish locale "String" would otherwise become "STR\u0130NG").
     */
    private static String upperCaseTypeName(TypeInformation<?> typeInformation) {
        return typeInformation.toString().toUpperCase(Locale.ROOT);
    }

    /**
     * Computes the dot product of a (dense or sparse) vector with a dense vector, tolerating a
     * size mismatch.
     *
     * <p>For a dense {@code vector}, only the overlapping prefix of the two vectors contributes
     * when the sizes differ. For a sparse {@code vector}, entries whose index is outside
     * {@code other} are ignored.
     *
     * @param vector a {@link DenseVector} or {@link SparseVector}
     * @param other  the dense vector to multiply with
     * @return the (possibly truncated) dot product
     * @throws ClassCastException if {@code vector} is neither dense nor sparse
     */
    public static double dot(Vector vector, DenseVector other) {
        if (vector instanceof DenseVector) {
            if (vector.size() == other.size()) {
                return BLAS.dot((DenseVector) vector, other);
            }
            int overlap = Math.min(vector.size(), other.size());
            double sum = 0.0d;
            for (int k = 0; k < overlap; k++) {
                sum += vector.get(k) * other.get(k);
            }
            return sum;
        }
        SparseVector sparse = (SparseVector) vector;
        double[] values = sparse.getValues();
        int[] indices = sparse.getIndices();
        int limit = other.size();
        double sum = 0.0d;
        for (int k = 0; k < indices.length; k++) {
            if (indices[k] < limit) {
                sum += values[k] * other.get(indices[k]);
            }
        }
        return sum;
    }

    /**
     * Reads a sample weight.
     *
     * @param row            the input row
     * @param weightColIndex index of the weight column, or a negative value if there is none
     * @return the weight field as a double, or 1.0 when {@code weightColIndex} is negative
     */
    public static double getWeightValue(Row row, int weightColIndex) {
        return weightColIndex >= 0 ? ((Number) row.getField(weightColIndex)).doubleValue() : 1.0d;
    }
}
