package com.alibaba.alink.common.linalg.tensor;

import com.alibaba.alink.common.exceptions.AkIllegalArgumentException;
import com.alibaba.alink.common.exceptions.AkIllegalDataException;
import com.alibaba.alink.common.exceptions.AkParseErrorException;
import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.utils.TableUtil;
import com.google.common.primitives.Longs;
import java.util.Arrays;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.StringUtils;
import org.tensorflow.ndarray.NdArray;
import org.tensorflow.ndarray.buffer.DataBuffer;

/* loaded from: input_file:com/alibaba/alink/common/linalg/tensor/TensorUtil.class */
public class TensorUtil {
    /** Separates individual element values (and, for string tensors, individual lengths). */
    static final char ELEMENT_DELIMITER = ' ';
    static final String ELEMENT_DELIMITER_STR = "" + ELEMENT_DELIMITER;

    /** Separates the type, shape and value sections of a serialized tensor. */
    static final char HEADER_DELIMITER = '#';
    static final String HEADER_DELIMITER_STR = "" + HEADER_DELIMITER;

    /** Separates dimensions inside the shape section. */
    static final char SHAPE_DELIMITER = ',';
    static final String SHAPE_DELIMITER_STR = "" + SHAPE_DELIMITER;

    /** Length written into the lengths section of a string tensor for a {@code null} element. */
    static final int NULL_STRING_LENGTH = -1;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: com.alibaba.alink.common.linalg.tensor.TensorUtil$1, reason: invalid class name */
    /* loaded from: input_file:com/alibaba/alink/common/linalg/tensor/TensorUtil$1.class */
    /**
     * Decompiler-materialized form of the lookup table javac synthesizes for a {@code switch} over
     * {@link DataType}.
     *
     * <p>A switch indexes {@code $SwitchMap...[dataType.ordinal()]}. The table maps FLOAT, DOUBLE,
     * INT, LONG, BOOLEAN, BYTE, UBYTE and STRING to the case labels 1..8. Any other ordinal keeps
     * the array's default value 0, which matches no case label and so reaches {@code default}.
     *
     * <p>NOTE(review): once no switch in this class reads this table, it can be removed in favour of
     * switching on the enum directly.
     */
    public static /* synthetic */ class AnonymousClass1 {
        static final /* synthetic */ int[] $SwitchMap$com$alibaba$alink$common$linalg$tensor$DataType = new int[DataType.values().length];

        // Each entry is guarded separately. If a DataType constant is missing at runtime
        // (NoSuchFieldError), only that entry is skipped and its slot stays 0.
        static {
            try {
                $SwitchMap$com$alibaba$alink$common$linalg$tensor$DataType[DataType.FLOAT.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
            }
            try {
                $SwitchMap$com$alibaba$alink$common$linalg$tensor$DataType[DataType.DOUBLE.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
            }
            try {
                $SwitchMap$com$alibaba$alink$common$linalg$tensor$DataType[DataType.INT.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
            }
            try {
                $SwitchMap$com$alibaba$alink$common$linalg$tensor$DataType[DataType.LONG.ordinal()] = 4;
            } catch (NoSuchFieldError e4) {
            }
            try {
                $SwitchMap$com$alibaba$alink$common$linalg$tensor$DataType[DataType.BOOLEAN.ordinal()] = 5;
            } catch (NoSuchFieldError e5) {
            }
            try {
                $SwitchMap$com$alibaba$alink$common$linalg$tensor$DataType[DataType.BYTE.ordinal()] = 6;
            } catch (NoSuchFieldError e6) {
            }
            try {
                $SwitchMap$com$alibaba$alink$common$linalg$tensor$DataType[DataType.UBYTE.ordinal()] = 7;
            } catch (NoSuchFieldError e7) {
            }
            try {
                $SwitchMap$com$alibaba$alink$common$linalg$tensor$DataType[DataType.STRING.ordinal()] = 8;
            } catch (NoSuchFieldError e8) {
            }
        }
    }

    /* loaded from: input_file:com/alibaba/alink/common/linalg/tensor/TensorUtil$CoordInc.class */
    /**
     * Odometer-style counter over the dimensions {@code [start, end)} of a shape.
     *
     * <p>The caller-owned array {@code refs} holds the current coordinates. {@code refs[k]} is the
     * index along dimension {@code start + k}. {@link #inc()} advances the last dimension first and
     * carries into earlier ones. After the final coordinate it wraps every position back to 0.
     */
    static class CoordInc {
        private final long[] refs;
        private final long[] shapes;
        private final int start;
        private final int end;

        /** Counts over dimensions {@code [0, end)} of {@code shapes}, writing into {@code refs}. */
        public CoordInc(long[] shapes, int end, long[] refs) {
            this(shapes, 0, end, refs);
        }

        /** Counts over dimensions {@code [start, end)} of {@code shapes}, writing into {@code refs}. */
        public CoordInc(long[] shapes, int start, int end, long[] refs) {
            this.shapes = shapes;
            this.refs = refs;
            this.start = start;
            this.end = end;
        }

        /** Moves {@code refs} to the next coordinate in row-major order (wrapping after the last). */
        public void inc() {
            for (int dim = end - 1; dim >= start; dim--) {
                final int pos = dim - start;
                if (refs[pos] < shapes[dim] - 1) {
                    refs[pos]++;
                    return;
                }
                // This position overflowed: zero it and carry into the previous dimension.
                refs[pos] = 0;
            }
        }

        /** Sets every tracked coordinate back to 0. */
        public void reset() {
            Arrays.fill(refs, 0, end - start, 0L);
        }
    }

    /* loaded from: input_file:com/alibaba/alink/common/linalg/tensor/TensorUtil$DoCalcFunctions.class */
    /**
     * Callbacks that plug a concrete reduction into {@code doCalc}.
     *
     * @param <DT>      element type of the tensor being reduced
     * @param <DTARRAY> accumulator type used while reducing one outer position
     */
    interface DoCalcFunctions<DT, DTARRAY> {
        /** Allocates an accumulator; {@code doCalc} passes the number of inner elements. */
        DTARRAY createArray(int size);

        /** Prepares the accumulator before each outer position is processed. */
        void initial(DTARRAY acc);

        /**
         * Folds one element into the accumulator. {@code doCalc} passes the sub-array for the current
         * index along the reduced dimension, the current inner coordinates, and the flat inner index.
         */
        void calc(DTARRAY acc, NdArray<DT> slice, long[] innerCoord, int index);

        /** Produces the buffer that {@code doCalc} writes into the output slice. */
        DataBuffer<DT> write(DTARRAY acc);

        /**
         * Called once per outer position after all indices along the reduced dimension were folded.
         * {@code doCalc} passes the length of that dimension. Does nothing by default.
         */
        default void post(DTARRAY acc, int reducedLength) {
            // intentionally empty
        }
    }

    /**
     * Converts an arbitrary value into a tensor.
     *
     * <ul>
     *   <li>{@code null} gives {@code null}.</li>
     *   <li>A {@link Tensor} is returned as is.</li>
     *   <li>A {@link Vector} is densified and wrapped as a {@link DoubleTensor}.</li>
     *   <li>A {@link String} containing {@code '#'} is parsed with {@link #parseTensor(String)}.
     *       Any other string is first tried as a dense vector. If that throws, the string is
     *       wrapped in a {@link StringTensor}.</li>
     *   <li>A {@link Number} gives a {@link DoubleTensor} built from a one-element array holding its
     *       double value.</li>
     * </ul>
     *
     * <p>The vector paths may also yield {@code null} if {@code VectorUtil.getDenseVector} returns
     * {@code null}.
     *
     * @throws AkIllegalArgumentException for any other input type
     */
    public static Tensor<?> getTensor(Object obj) {
        if (obj == null) {
            return null;
        }
        if (obj instanceof Tensor) {
            return (Tensor<?>) obj;
        }
        if (obj instanceof Vector) {
            return fromDenseVector(VectorUtil.getDenseVector(obj));
        }
        if (obj instanceof String) {
            final String text = (String) obj;
            if (isTensor(text)) {
                return parseTensor(text);
            }
            try {
                return fromDenseVector(VectorUtil.getDenseVector(text));
            } catch (Exception ignored) {
                // Deliberate fallback: a string that is not a parsable vector is kept as text.
                return new StringTensor(text);
            }
        }
        if (obj instanceof Number) {
            final double scalar = ((Number) obj).doubleValue();
            return fromDenseVector(new DenseVector(new double[] {scalar}));
        }
        throw new AkIllegalArgumentException("Can not get the tensor from " + obj);
    }

    /**
     * Parses the serialized form produced by {@code toString(Tensor)}.
     *
     * <p>The expected layout is {@code TYPE#shape#values}:
     * <ul>
     *   <li>{@code TYPE} is a {@link DataType} constant name.</li>
     *   <li>{@code shape} is a comma-separated list of dimensions; it is empty for a shape without
     *       dimensions.</li>
     *   <li>{@code values} is a space-separated element list. For {@code STRING} tensors it is
     *       instead a lengths section, a {@code '#'}, and then the concatenated values.</li>
     * </ul>
     *
     * @param str the serialized tensor
     * @return a newly created tensor of the encoded type and shape, filled from the encoded values
     * @throws AkParseErrorException           if {@code str} is {@code null} or does not contain the
     *                                         three '#'-separated sections
     * @throws IllegalArgumentException        if the type section is not a {@link DataType} name
     * @throws AkUnsupportedOperationException if the data type has no tensor implementation here
     */
    public static Tensor<?> parseTensor(String str) {
        final String[] sections = StringUtils.splitPreserveAllTokens(str, HEADER_DELIMITER_STR, 3);
        if (sections == null || sections.length != 3) {
            throw new AkParseErrorException("Illegal tensor string: " + str);
        }
        final DataType type = DataType.valueOf(sections[0]);
        final Shape shape = parseShapeStr(sections[1]);

        // StringUtils.split returns an empty array for an empty value section (e.g. a zero-size
        // tensor), whereas String.split(" ") would return {""}. It also ignores the trailing
        // delimiter that toString(Tensor) writes after every value.
        final String[] valueStrings = DataType.STRING.equals(type)
            ? parseStringValueStr(sections[2], Math.toIntExact(shape.size()))
            : StringUtils.split(sections[2], ELEMENT_DELIMITER);

        final Tensor<?> tensor;
        switch (type) {
            case FLOAT:
                tensor = new FloatTensor(shape);
                break;
            case DOUBLE:
                tensor = new DoubleTensor(shape);
                break;
            case INT:
                tensor = new IntTensor(shape);
                break;
            case LONG:
                tensor = new LongTensor(shape);
                break;
            case BOOLEAN:
                tensor = new BoolTensor(shape);
                break;
            case BYTE:
                tensor = new ByteTensor(shape);
                break;
            case UBYTE:
                tensor = new UByteTensor(shape);
                break;
            case STRING:
                tensor = new StringTensor(shape);
                break;
            default:
                throw new AkUnsupportedOperationException("Data type is not supported: " + type);
        }
        tensor.parseFromValueStrings(valueStrings);
        return tensor;
    }

    /**
     * Decodes the value section of a {@code STRING} tensor.
     *
     * <p>The section is {@code "len_0 len_1 ... #v_0 v_1 ... "}. Each {@code len_k} is the length of
     * value {@code k}, or {@link #NULL_STRING_LENGTH} for a {@code null} value. Non-null values are
     * concatenated in order, each followed by one delimiter character. Only the first {@code '#'}
     * separates the two parts, so values may themselves start with or contain {@code '#'}.
     *
     * @param str the value section (everything after the shape section)
     * @param i   the expected number of values (the tensor size)
     * @return the decoded values; entries with length {@link #NULL_STRING_LENGTH} are {@code null}
     * @throws AkIllegalDataException if the number of lengths differs from {@code i}, or a length is
     *                                negative (other than the null marker) or runs past the values
     * @throws NumberFormatException  if a length is not an integer
     */
    private static String[] parseStringValueStr(String str, int i) {
        // Split at the FIRST '#' only. StringUtils.split would skip a leading '#' of the value part
        // (corrupting values that start with '#'), and it returns no tokens at all for a bare "#".
        final int sep = str.indexOf(HEADER_DELIMITER);
        final String lengthSection = sep < 0 ? str : str.substring(0, sep);
        final String valueSection = sep < 0 ? "" : str.substring(sep + 1);

        final int[] lengths = Arrays.stream(StringUtils.split(lengthSection, ELEMENT_DELIMITER))
            .mapToInt(Integer::parseInt)
            .toArray();
        if (lengths.length != i) {
            throw new AkIllegalDataException("Illegal lengths section in tensor string: " + str);
        }

        final String[] values = new String[i];
        int offset = 0;
        for (int k = 0; k < i; k++) {
            final int len = lengths[k];
            if (len == NULL_STRING_LENGTH) {
                continue; // null element: nothing was written to the value part
            }
            if (len < 0 || len > valueSection.length() - offset) {
                throw new AkIllegalDataException(
                    "Illegal string length " + len + " in tensor string: " + str);
            }
            values[k] = valueSection.substring(offset, offset + len);
            offset += len + 1; // skip the value and its trailing delimiter
        }
        return values;
    }

    /**
     * Serializes a tensor as {@code TYPE#shape#values}, the format read by {@code parseTensor}.
     *
     * <p>Every non-null value is written followed by one space. For a {@link StringTensor}, a
     * lengths section precedes the values. It contains each value's length (or
     * {@link #NULL_STRING_LENGTH} for {@code null}), each followed by a space, and is terminated by
     * {@code '#'}. {@code null} values are then omitted from the value section.
     */
    public static String toString(Tensor<?> tensor) {
        final StringBuilder out = new StringBuilder()
            .append(tensor.type.name())
            .append(HEADER_DELIMITER)
            .append(toString(Shape.fromNdArrayShape(tensor.data.shape())))
            .append(HEADER_DELIMITER);
        final String[] values = tensor.getValueStrings();

        if (tensor instanceof StringTensor) {
            for (String value : values) {
                final int len = value == null ? NULL_STRING_LENGTH : value.length();
                out.append(len).append(ELEMENT_DELIMITER);
            }
            out.append(HEADER_DELIMITER);
        }

        for (String value : values) {
            if (value == null) {
                continue;
            }
            out.append(value).append(ELEMENT_DELIMITER);
        }
        return out.toString();
    }

    /**
     * Serializes {@code obj}, which must be a non-null {@link Tensor}, using
     * {@code toString(Tensor)}. Any other type fails with a {@link ClassCastException}.
     */
    public static String serialize(Object obj) {
        final Tensor<?> tensor = (Tensor<?>) obj;
        return toString(tensor);
    }

    /**
     * Parses a shape section such as {@code "2,3"}.
     *
     * <p>An empty string yields a shape with no dimensions. Non-numeric components raise
     * {@link NumberFormatException}.
     */
    static Shape parseShapeStr(String str) {
        if (str.isEmpty()) {
            return new Shape(new long[0]);
        }
        final String[] parts = str.split(SHAPE_DELIMITER_STR);
        final long[] dims = new long[parts.length];
        for (int k = 0; k < parts.length; k++) {
            dims[k] = Long.parseLong(parts[k]);
        }
        return new Shape(dims);
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Renders a shape as its dimensions joined by {@link #SHAPE_DELIMITER}, e.g. {@code "2,3"}.
     * An empty dimension array yields the empty string.
     */
    public static String toString(Shape shape) {
        final long[] dims = shape.asArray();
        final StringBuilder text = new StringBuilder();
        for (int k = 0; k < dims.length; k++) {
            if (k > 0) {
                text.append(SHAPE_DELIMITER);
            }
            text.append(dims[k]);
        }
        return text.toString();
    }

    /** Wraps the vector's backing array in a {@link DoubleTensor}; {@code null} maps to {@code null}. */
    private static Tensor<?> fromDenseVector(DenseVector denseVector) {
        return denseVector == null ? null : new DoubleTensor(denseVector.getData());
    }

    /** Heuristic used by {@code getTensor}: any string containing {@code '#'} is treated as a serialized tensor. */
    private static boolean isTensor(String str) {
        return str.indexOf(HEADER_DELIMITER) >= 0;
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Creates a new tensor of the given data type.
     *
     * <p>The tensor is built with the concrete tensor class's {@code Shape} constructor, so its
     * contents are whatever that constructor initializes.
     *
     * @param dims     the dimensions of the new tensor
     * @param dataType the element type, which selects the concrete tensor class
     * @return the new tensor, unchecked-cast to the caller's expected tensor type
     * @throws AkUnsupportedOperationException if {@code dataType} has no tensor implementation here
     */
    @SuppressWarnings("unchecked") // the concrete class is chosen by dataType; callers pick a matching T
    public static <DT, T extends Tensor<DT>> T of(long[] dims, DataType dataType) {
        final Shape shape = new Shape(dims);
        switch (dataType) {
            case FLOAT:
                return (T) new FloatTensor(shape);
            case DOUBLE:
                return (T) new DoubleTensor(shape);
            case INT:
                return (T) new IntTensor(shape);
            case LONG:
                return (T) new LongTensor(shape);
            case BOOLEAN:
                return (T) new BoolTensor(shape);
            case BYTE:
                return (T) new ByteTensor(shape);
            case UBYTE:
                return (T) new UByteTensor(shape);
            case STRING:
                return (T) new StringTensor(shape);
            default:
                throw new AkUnsupportedOperationException("Failed to cast to tensor, unsupported DataType");
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Normalizes a possibly negative dimension index.
     *
     * <p>For {@code nDims > 0}, {@code j} must lie in {@code [-nDims, nDims - 1]}; negative values
     * count from the end. For {@code nDims == 0}, only {@code j == 0} is accepted and returned
     * unchanged.
     *
     * @param j     the requested dimension
     * @param nDims the number of dimensions
     * @return the equivalent non-negative dimension
     * @throws AkIllegalDataException if {@code j} is out of range
     */
    public static long wrapDim(long j, long nDims) {
        if (nDims == 0) {
            if (j != 0) {
                throw new AkIllegalDataException("Dim is not 0 when nDims is 0.");
            }
            return j;
        }
        final long lowest = -nDims;
        final long highest = nDims - 1;
        if (j < lowest || j > highest) {
            throw new AkIllegalDataException(String.format(
                "Dimension is outbound. Dim: %d, min: %d, max: %d.", j, lowest, highest));
        }
        if (j >= 0) {
            return j;
        }
        return j + nDims;
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Reduces {@code t} along one dimension, delegating the element arithmetic to {@code fns}.
     *
     * <p>Outer coordinates are the indices over the dimensions before the reduced one. For each
     * combination of outer coordinates (in row-major order):
     * <ol>
     *   <li>{@link DoCalcFunctions#initial} resets the accumulator. The accumulator is allocated
     *       once via {@link DoCalcFunctions#createArray} with the number of inner elements.</li>
     *   <li>For every index along the reduced dimension, {@link DoCalcFunctions#calc} is called once
     *       per inner element, in row-major order of the trailing dimensions.</li>
     *   <li>{@link DoCalcFunctions#post} receives the length of the reduced dimension.</li>
     *   <li>The buffer from {@link DoCalcFunctions#write} is written into the matching output
     *       slice.</li>
     * </ol>
     *
     * @param t        the input tensor
     * @param dim      the dimension to reduce; negative values count from the end
     * @param keepDims if true the reduced dimension is kept with size 1, otherwise it is removed
     * @param fns      the reduction callbacks
     * @return a new tensor of the input's data type holding the reduced values
     * @throws AkIllegalDataException if {@code dim} is out of range or {@code t} has no dimensions
     * @throws ArithmeticException    if an element count does not fit in an {@code int}
     */
    public static <DT, DTARRAY, T extends Tensor<DT>> T doCalc(T t, int dim, boolean keepDims,
                                                             DoCalcFunctions<DT, DTARRAY> fns) {
        final long[] shape = t.shape();
        final int rank = shape.length;
        final int axis = (int) wrapDim(dim, rank);
        if (rank == 0) {
            // wrapDim accepts dim 0 for rank 0, but there is no axis to reduce. Without this guard
            // the inner-coordinate array below would get a negative length.
            throw new AkIllegalDataException("Can not reduce along a dimension of a tensor without dimensions.");
        }

        // Element counts before/after the axis. Exact arithmetic replaces the former (int) casts,
        // which silently truncated oversized products.
        int outerCount = 1;
        for (int d = 0; d < axis; d++) {
            outerCount = Math.toIntExact(Math.multiplyExact(outerCount, shape[d]));
        }
        int innerCount = 1;
        for (int d = axis + 1; d < rank; d++) {
            innerCount = Math.toIntExact(Math.multiplyExact(innerCount, shape[d]));
        }

        // outerCoord holds the outer coordinates plus, in its last slot, the index along the axis.
        // outerInc only advances the first `axis` slots.
        final long[] outerCoord = new long[axis + 1];
        final CoordInc outerInc = new CoordInc(shape, axis, outerCoord);
        final long[] innerCoord = new long[rank - axis - 1];
        final CoordInc innerInc = new CoordInc(shape, axis + 1, rank, innerCoord);

        final long[] outShape;
        if (keepDims) {
            outShape = Arrays.copyOf(shape, rank);
            outShape[axis] = 1;
        } else {
            outShape = ArrayUtils.remove(shape, axis);
        }
        final int reducedLength = Math.toIntExact(shape[axis]);
        final T result = (T) of(outShape, t.getType());
        final DTARRAY acc = fns.createArray(innerCount);

        for (int outer = 0; outer < outerCount; outer++) {
            fns.initial(acc);
            for (int r = 0; r < reducedLength; r++) {
                outerCoord[axis] = r;
                final NdArray<DT> slice = t.getData().get(outerCoord);
                innerInc.reset();
                for (int inner = 0; inner < innerCount; inner++) {
                    fns.calc(acc, slice, innerCoord, inner);
                    innerInc.inc();
                }
            }
            fns.post(acc, reducedLength);

            // Output coordinate: the kept axis is pinned to 0, or the axis slot is dropped.
            final long[] outCoord;
            if (keepDims) {
                outerCoord[axis] = 0;
                outCoord = outerCoord;
            } else {
                outCoord = ArrayUtils.remove(outerCoord, axis);
            }
            result.getData().get(outCoord).write(fns.write(acc));
            outerInc.inc();
        }
        return result;
    }
}
