package com.alibaba.alink.common.linalg.tensor;

import com.alibaba.alink.common.exceptions.AkPreconditions;
import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException;
import com.alibaba.alink.common.linalg.tensor.TensorUtil;
import java.util.Arrays;
import org.tensorflow.ndarray.LongNdArray;
import org.tensorflow.ndarray.NdArray;
import org.tensorflow.ndarray.NdArrays;
import org.tensorflow.ndarray.StdArrays;
import org.tensorflow.ndarray.buffer.DataBuffer;
import org.tensorflow.ndarray.buffer.DataBuffers;
import org.tensorflow.ndarray.buffer.LongDataBuffer;

/* loaded from: input_file:com/alibaba/alink/common/linalg/tensor/LongTensor.class */
/**
 * A {@link NumericalTensor} of {@code Long} elements whose storage is a {@link LongNdArray}
 * (registered with the superclass as {@code DataType.LONG}).
 *
 * <p>Offers construction from a shape, a scalar, or 1- to 3-dimensional {@code long} arrays,
 * per-element access, reshaping into a new tensor, and min / max / sum reductions delegated to
 * {@code TensorUtil.doCalc}. {@link #mean(int, boolean)} is not supported.
 */
public final class LongTensor extends NumericalTensor<Long> {

    /**
     * Creates a tensor whose storage is allocated by {@code NdArrays.ofLongs} for the given shape.
     *
     * @param shape shape of the new tensor
     */
    public LongTensor(Shape shape) {
        this(NdArrays.ofLongs(shape.toNdArrayShape()));
    }

    /**
     * Creates a scalar tensor holding a single value.
     *
     * @param j the value to store
     */
    public LongTensor(long j) {
        this(NdArrays.scalarOf(j));
    }

    /**
     * Creates a tensor from a copy of a one-dimensional array.
     *
     * @param jArr source values; copied via {@code StdArrays.ndCopyOf}
     */
    public LongTensor(long[] jArr) {
        this(StdArrays.ndCopyOf(jArr));
    }

    /**
     * Creates a tensor from a copy of a two-dimensional array.
     *
     * @param jArr source values; copied via {@code StdArrays.ndCopyOf}
     */
    public LongTensor(long[][] jArr) {
        this(StdArrays.ndCopyOf(jArr));
    }

    /**
     * Creates a tensor from a copy of a three-dimensional array.
     *
     * @param jArr source values; copied via {@code StdArrays.ndCopyOf}
     */
    public LongTensor(long[][][] jArr) {
        this(StdArrays.ndCopyOf(jArr));
    }

    /**
     * Wraps an existing {@link LongNdArray} as the tensor's data.
     *
     * <p>NOTE(review): decompiler output marks this constructor as originally package-private.
     *
     * @param longNdArray the array handed to the superclass as this tensor's storage
     */
    public LongTensor(LongNdArray longNdArray) {
        super(longNdArray, DataType.LONG);
    }

    /**
     * Parses every string with {@link Long#parseLong(String)} and writes the resulting values
     * into this tensor's data.
     *
     * <p>NOTE(review): decompiler output marks this method as originally package-private.
     *
     * @param strArr decimal representations of the values, in the order they are written
     * @throws NumberFormatException if any entry is not a parsable {@code long}
     */
    @Override
    public void parseFromValueStrings(String[] strArr) {
        final long[] parsed = new long[strArr.length];
        int next = 0;
        for (String text : strArr) {
            parsed[next++] = Long.parseLong(text);
        }
        this.data.write(DataBuffers.of(parsed));
    }

    /**
     * Reads all elements of this tensor and renders each one with {@link Long#toString(long)}.
     *
     * <p>NOTE(review): decompiler output marks this method as originally package-private.
     *
     * @return one string per element, in the order the data is read into a flat buffer
     * @throws ArithmeticException if {@code size()} does not fit in an {@code int}
     */
    @Override
    public String[] getValueStrings() {
        final long[] values = new long[Math.toIntExact(size())];
        // The buffer wraps `values` directly, so the read below fills the array in place.
        this.data.read(DataBuffers.of(values, false, false));
        return Arrays.stream(values)
            .mapToObj(Long::toString)
            .toArray(String[]::new);
    }

    /**
     * Returns the element at the given coordinates.
     *
     * @param jArr coordinates of the element, forwarded to the underlying array
     * @return the stored value
     */
    public long getLong(long... jArr) {
        return this.data.getLong(jArr);
    }

    /**
     * Stores a value at the given coordinates.
     *
     * @param j the value to store
     * @param jArr coordinates of the element, forwarded to the underlying array
     * @return this tensor, to allow chaining
     */
    public LongTensor setLong(long j, long... jArr) {
        this.data.setLong(j, jArr);
        return this;
    }

    /**
     * Returns a new tensor with the requested shape holding a copy of this tensor's elements.
     *
     * <p>NOTE(review): the decompiler reports this method was renamed from {@code reshape};
     * confirm which name the rest of the code base expects before changing it.
     *
     * @param shape target shape; its size must equal this tensor's size
     * @return a new tensor backed by a freshly allocated buffer
     */
    @Override
    public LongTensor reshape2(Shape shape) {
        final boolean sameElementCount = shape.size() == size();
        AkPreconditions.checkArgument(sameElementCount, "Shape not matched.");
        // Copy the elements out into a flat buffer, then view that buffer under the new shape.
        final LongDataBuffer copied = DataBuffers.ofLongs(size());
        this.data.read(copied);
        final LongNdArray reshaped = NdArrays.wrap(shape.toNdArrayShape(), copied);
        return new LongTensor(reshaped);
    }

    /**
     * Minimum reduction: accumulators start at {@link Long#MAX_VALUE} and are combined with
     * {@link Math#min(long, long)}.
     *
     * @param i forwarded to {@code TensorUtil.doCalc} (presumably the dimension to reduce; confirm)
     * @param z forwarded to {@code TensorUtil.doCalc} (presumably "keep dimension"; confirm)
     * @return the tensor produced by {@code TensorUtil.doCalc}
     */
    @Override
    public LongTensor min(int i, boolean z) {
        return reduce(i, z, Long.MAX_VALUE, Math::min);
    }

    /**
     * Maximum reduction: accumulators start at {@link Long#MIN_VALUE} and are combined with
     * {@link Math#max(long, long)}.
     *
     * @param i forwarded to {@code TensorUtil.doCalc} (presumably the dimension to reduce; confirm)
     * @param z forwarded to {@code TensorUtil.doCalc} (presumably "keep dimension"; confirm)
     * @return the tensor produced by {@code TensorUtil.doCalc}
     */
    @Override
    public LongTensor max(int i, boolean z) {
        return reduce(i, z, Long.MIN_VALUE, Math::max);
    }

    /**
     * Sum reduction: accumulators start at {@code 0} and are combined with plain {@code long}
     * addition (no overflow check).
     *
     * @param i forwarded to {@code TensorUtil.doCalc} (presumably the dimension to reduce; confirm)
     * @param z forwarded to {@code TensorUtil.doCalc} (presumably "keep dimension"; confirm)
     * @return the tensor produced by {@code TensorUtil.doCalc}
     */
    @Override
    public LongTensor sum(int i, boolean z) {
        return reduce(i, z, 0L, Long::sum);
    }

    /**
     * Not supported for long tensors.
     *
     * @throws AkUnsupportedOperationException always
     */
    @Override
    public LongTensor mean(int i, boolean z) {
        throw new AkUnsupportedOperationException("Not support exception. ");
    }

    /**
     * Converts a tensor to a {@code LongTensor}.
     *
     * <ul>
     *   <li>An {@code IntTensor} is copied element by element into a new tensor of the same shape.
     *   <li>A {@code LongTensor} is returned as-is (same instance, no copy).
     * </ul>
     *
     * @param tensor the tensor to convert
     * @return a long tensor with the same shape and values
     * @throws AkUnsupportedOperationException if the tensor is of any other type
     */
    public static LongTensor of(Tensor<?> tensor) {
        if (tensor instanceof IntTensor) {
            final IntTensor source = (IntTensor) tensor;
            final LongTensor widened = new LongTensor(new Shape(source.shape()));
            // Visit every scalar with its coordinates and store its int value as a long.
            source.getData().scalars().forEachIndexed(
                (coords, scalar) -> widened.setLong(scalar.getInt(new long[0]), coords));
            return widened;
        }
        if (tensor instanceof LongTensor) {
            return (LongTensor) tensor;
        }
        throw new AkUnsupportedOperationException(
            String.format("Failed to cast to long tensor. tensor type: %s", tensor.getType()));
    }

    /** Combines a running accumulator with one tensor element. */
    private interface LongFold {
        long apply(long accumulator, long element);
    }

    /**
     * Shared implementation of the reductions: supplies {@code TensorUtil.doCalc} with callbacks
     * that allocate a {@code long[]} of accumulators, fill it with {@code identity}, fold each
     * visited element into its slot, and finally wrap the array as a data buffer.
     *
     * @param dim forwarded unchanged to {@code TensorUtil.doCalc}
     * @param keepDim forwarded unchanged to {@code TensorUtil.doCalc}
     * @param identity initial value of every accumulator
     * @param fold how an element is merged into its accumulator
     * @return the result of {@code TensorUtil.doCalc}, cast to {@code LongTensor}
     */
    private LongTensor reduce(int dim, boolean keepDim, final long identity, final LongFold fold) {
        return (LongTensor) TensorUtil.doCalc(this, dim, keepDim,
            new TensorUtil.DoCalcFunctions<Long, long[]>() {
                @Override
                public long[] createArray(int length) {
                    return new long[length];
                }

                @Override
                public void initial(long[] accumulators) {
                    Arrays.fill(accumulators, identity);
                }

                @Override
                public void calc(long[] accumulators, NdArray<Long> source, long[] coords, int slot) {
                    final long element = source.getObject(coords);
                    accumulators[slot] = fold.apply(accumulators[slot], element);
                }

                @Override
                public DataBuffer<Long> write(long[] accumulators) {
                    return DataBuffers.of(accumulators);
                }
            });
    }
}
