package com.alibaba.alink.common.linalg.tensor;

import com.alibaba.alink.common.exceptions.AkPreconditions;
import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.tensor.TensorUtil;
import com.alibaba.alink.operator.common.tree.Criteria;
import java.util.Arrays;
import org.tensorflow.ndarray.DoubleNdArray;
import org.tensorflow.ndarray.NdArray;
import org.tensorflow.ndarray.NdArrays;
import org.tensorflow.ndarray.StdArrays;
import org.tensorflow.ndarray.buffer.DataBuffer;
import org.tensorflow.ndarray.buffer.DataBuffers;
import org.tensorflow.ndarray.buffer.DoubleDataBuffer;

/* loaded from: input_file:com/alibaba/alink/common/linalg/tensor/DoubleTensor.class */
/**
 * A {@link NumericalTensor} whose elements are {@code double} values, backed by a {@link DoubleNdArray}
 * (registered with the parent as {@link DataType#DOUBLE}).
 *
 * <p>Reductions ({@link #min}, {@link #max}, {@link #sum}, {@link #mean}) are delegated to
 * {@code TensorUtil.doCalc} with per-operation accumulator callbacks.
 */
public final class DoubleTensor extends NumericalTensor<Double> {

    /**
     * Creates a tensor of the given shape, allocated with {@code NdArrays.ofDoubles}.
     *
     * @param shape shape of the new tensor
     */
    public DoubleTensor(Shape shape) {
        this(NdArrays.ofDoubles(shape.toNdArrayShape()));
    }

    /**
     * Creates a scalar tensor holding {@code d}.
     *
     * @param d the scalar value
     */
    public DoubleTensor(double d) {
        this(NdArrays.scalarOf(d));
    }

    /**
     * Creates a 1-D tensor from a copy of {@code dArr}.
     *
     * @param dArr source values; copied, not wrapped
     */
    public DoubleTensor(double[] dArr) {
        this(StdArrays.ndCopyOf(dArr));
    }

    /**
     * Creates a 2-D tensor from a copy of {@code dArr}.
     *
     * @param dArr source values; copied, not wrapped
     */
    public DoubleTensor(double[][] dArr) {
        this(StdArrays.ndCopyOf(dArr));
    }

    /**
     * Creates a 3-D tensor from a copy of {@code dArr}.
     *
     * @param dArr source values; copied, not wrapped
     */
    public DoubleTensor(double[][][] dArr) {
        this(StdArrays.ndCopyOf(dArr));
    }

    /**
     * Wraps an existing {@link DoubleNdArray} as a tensor of type {@link DataType#DOUBLE}.
     *
     * @param doubleNdArray backing storage for this tensor
     */
    public DoubleTensor(DoubleNdArray doubleNdArray) {
        super(doubleNdArray, DataType.DOUBLE);
    }

    /**
     * Parses each string with {@link Double#parseDouble} and writes the values into this tensor's data,
     * in the order the backing array's {@code write} consumes a buffer.
     *
     * <p>Declared public by the decompiler; the original source declared it package-private.
     *
     * @param strArr textual element values
     * @throws NumberFormatException if any string is not a valid double
     */
    @Override
    public void parseFromValueStrings(String[] strArr) {
        double[] values = new double[strArr.length];
        for (int k = 0; k < values.length; k++) {
            values[k] = Double.parseDouble(strArr[k]);
        }
        this.data.write(DataBuffers.of(values));
    }

    /**
     * Returns every element rendered with {@link Double#toString(double)}.
     *
     * <p>Declared public by the decompiler; the original source declared it package-private.
     *
     * @return one string per element
     * @throws ArithmeticException if the element count does not fit in an {@code int}
     */
    @Override
    public String[] getValueStrings() {
        double[] values = readAllDoubles();
        String[] strings = new String[values.length];
        for (int k = 0; k < values.length; k++) {
            strings[k] = Double.toString(values[k]);
        }
        return strings;
    }

    /**
     * Returns a new tensor with the same elements laid out in {@code shape}.
     *
     * <p>The decompiler renamed this method from {@code reshape} to resolve a covariant-return collision.
     *
     * @param shape target shape; its total size must equal {@link #size()}
     * @return a new tensor backed by a copy of this tensor's data
     */
    @Override
    public DoubleTensor reshape2(Shape shape) {
        AkPreconditions.checkArgument(shape.size() == size(), "Shape not matched.");
        DoubleDataBuffer copy = DataBuffers.ofDoubles(size());
        this.data.read(copy);
        return new DoubleTensor(NdArrays.wrap(shape.toNdArrayShape(), copy));
    }

    /**
     * Minimum reduction.
     *
     * @param dim forwarded to {@code TensorUtil.doCalc}; presumably the dimension to reduce
     * @param keepDim forwarded to {@code TensorUtil.doCalc}; presumably whether the reduced dimension is kept
     * @return the reduced tensor
     */
    @Override
    public DoubleTensor min(int dim, boolean keepDim) {
        return (DoubleTensor) TensorUtil.doCalc(this, dim, keepDim, new TensorUtil.DoCalcFunctions<Double, double[]>() {
            @Override
            public double[] createArray(int size) {
                return new double[size];
            }

            @Override
            public void initial(double[] array) {
                // +Infinity is the identity of min; Double.MAX_VALUE would wrongly survive a slice of all +Infinity.
                Arrays.fill(array, Double.POSITIVE_INFINITY);
            }

            @Override
            public void calc(double[] array, NdArray<Double> ndArray, long[] coords, int index) {
                double value = ndArray.getObject(coords);
                array[index] = Math.min(value, array[index]);
            }

            @Override
            public DataBuffer<Double> write(double[] array) {
                return DataBuffers.of(array);
            }
        });
    }

    /**
     * Maximum reduction.
     *
     * @param dim forwarded to {@code TensorUtil.doCalc}; presumably the dimension to reduce
     * @param keepDim forwarded to {@code TensorUtil.doCalc}; presumably whether the reduced dimension is kept
     * @return the reduced tensor
     */
    @Override
    public DoubleTensor max(int dim, boolean keepDim) {
        return (DoubleTensor) TensorUtil.doCalc(this, dim, keepDim, new TensorUtil.DoCalcFunctions<Double, double[]>() {
            @Override
            public double[] createArray(int size) {
                return new double[size];
            }

            @Override
            public void initial(double[] array) {
                // -Infinity is the identity of max; -Double.MAX_VALUE would wrongly survive a slice of all -Infinity.
                Arrays.fill(array, Double.NEGATIVE_INFINITY);
            }

            @Override
            public void calc(double[] array, NdArray<Double> ndArray, long[] coords, int index) {
                double value = ndArray.getObject(coords);
                array[index] = Math.max(value, array[index]);
            }

            @Override
            public DataBuffer<Double> write(double[] array) {
                return DataBuffers.of(array);
            }
        });
    }

    /**
     * Sum reduction.
     *
     * @param dim forwarded to {@code TensorUtil.doCalc}; presumably the dimension to reduce
     * @param keepDim forwarded to {@code TensorUtil.doCalc}; presumably whether the reduced dimension is kept
     * @return the reduced tensor
     */
    @Override
    public DoubleTensor sum(int dim, boolean keepDim) {
        return (DoubleTensor) TensorUtil.doCalc(this, dim, keepDim, new TensorUtil.DoCalcFunctions<Double, double[]>() {
            @Override
            public double[] createArray(int size) {
                return new double[size];
            }

            @Override
            public void initial(double[] array) {
                // 0.0 is the additive identity (previously seeded via an unrelated tree-module constant).
                Arrays.fill(array, 0.0);
            }

            @Override
            public void calc(double[] array, NdArray<Double> ndArray, long[] coords, int index) {
                double value = ndArray.getObject(coords);
                array[index] += value;
            }

            @Override
            public DataBuffer<Double> write(double[] array) {
                return DataBuffers.of(array);
            }
        });
    }

    /**
     * Mean reduction: accumulates a sum, then divides by the count supplied to {@code post}.
     *
     * @param dim forwarded to {@code TensorUtil.doCalc}; presumably the dimension to reduce
     * @param keepDim forwarded to {@code TensorUtil.doCalc}; presumably whether the reduced dimension is kept
     * @return the reduced tensor
     */
    @Override
    public DoubleTensor mean(int dim, boolean keepDim) {
        return (DoubleTensor) TensorUtil.doCalc(this, dim, keepDim, new TensorUtil.DoCalcFunctions<Double, double[]>() {
            @Override
            public double[] createArray(int size) {
                return new double[size];
            }

            @Override
            public void initial(double[] array) {
                Arrays.fill(array, 0.0);
            }

            @Override
            public void calc(double[] array, NdArray<Double> ndArray, long[] coords, int index) {
                double value = ndArray.getObject(coords);
                array[index] += value;
            }

            @Override
            public DataBuffer<Double> write(double[] array) {
                return DataBuffers.of(array);
            }

            @Override
            public void post(double[] array, int num) {
                // NOTE(review): num is presumably the number of elements folded into each slot — confirm in TensorUtil.
                // Divide directly rather than multiplying by 1.0 / num, which adds a rounding step.
                for (int k = 0; k < array.length; k++) {
                    array[k] /= num;
                }
            }
        });
    }

    /**
     * Reads one element.
     *
     * @param coords element coordinates
     * @return the element value
     */
    public double getDouble(long... coords) {
        return this.data.getDouble(coords);
    }

    /**
     * Writes one element in place.
     *
     * @param d value to store
     * @param coords element coordinates
     * @return this tensor, for chaining
     */
    public DoubleTensor setDouble(double d, long... coords) {
        this.data.setDouble(d, coords);
        return this;
    }

    /**
     * Copies all elements into a new {@link DenseVector}.
     *
     * @return a dense vector holding a copy of this tensor's elements
     * @throws ArithmeticException if the element count does not fit in an {@code int}
     */
    public DenseVector toVector() {
        return new DenseVector(readAllDoubles());
    }

    /**
     * Converts a numerical tensor to a {@code DoubleTensor}.
     *
     * <p>A {@code DOUBLE} tensor is returned as-is (same instance, no copy). {@code FLOAT}, {@code INT} and
     * {@code LONG} tensors are copied element by element into a new tensor of the same shape.
     *
     * @param tensor the tensor to convert
     * @return a double tensor with the same shape and values
     * @throws AkUnsupportedOperationException if {@code tensor} is not a supported numerical tensor
     */
    public static DoubleTensor of(Tensor<?> tensor) {
        if (!(tensor instanceof NumericalTensor)) {
            throw unsupportedCast(tensor);
        }
        switch (tensor.getType()) {
            case DOUBLE:
                return (DoubleTensor) tensor;
            case FLOAT: {
                FloatTensor source = (FloatTensor) tensor;
                DoubleTensor result = new DoubleTensor(new Shape(source.shape()));
                source.getData().scalars().forEachIndexed(
                    (coords, scalar) -> result.setDouble(scalar.getFloat(new long[0]), coords));
                return result;
            }
            case INT: {
                IntTensor source = (IntTensor) tensor;
                DoubleTensor result = new DoubleTensor(new Shape(source.shape()));
                source.getData().scalars().forEachIndexed(
                    (coords, scalar) -> result.setDouble(scalar.getInt(new long[0]), coords));
                return result;
            }
            case LONG: {
                LongTensor source = (LongTensor) tensor;
                DoubleTensor result = new DoubleTensor(new Shape(source.shape()));
                source.getData().scalars().forEachIndexed(
                    (coords, scalar) -> result.setDouble(scalar.getLong(new long[0]), coords));
                return result;
            }
            default:
                throw unsupportedCast(tensor);
        }
    }

    /**
     * Copies every element into a fresh {@code double[]}.
     *
     * <p>The array is wrapped writable and without copying ({@code false, false}) so that {@code read} fills it
     * directly.
     */
    private double[] readAllDoubles() {
        double[] values = new double[Math.toIntExact(size())];
        this.data.read(DataBuffers.of(values, false, false));
        return values;
    }

    /** Builds the exception thrown when {@link #of} receives a tensor it cannot convert. */
    private static AkUnsupportedOperationException unsupportedCast(Tensor<?> tensor) {
        return new AkUnsupportedOperationException(String.format(
            "Only numerical tensor can be cast to double tensor. tensor type: %s", tensor.getType()));
    }
}
