package com.alibaba.alink.common.linalg.tensor;

import com.alibaba.alink.common.exceptions.AkPreconditions;
import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException;
import com.alibaba.alink.common.linalg.tensor.TensorUtil;
import java.util.Arrays;
import org.tensorflow.ndarray.FloatNdArray;
import org.tensorflow.ndarray.NdArray;
import org.tensorflow.ndarray.NdArrays;
import org.tensorflow.ndarray.StdArrays;
import org.tensorflow.ndarray.buffer.DataBuffer;
import org.tensorflow.ndarray.buffer.DataBuffers;
import org.tensorflow.ndarray.buffer.FloatDataBuffer;

/* loaded from: input_file:com/alibaba/alink/common/linalg/tensor/FloatTensor.class */
/**
 * A {@link NumericalTensor} whose elements are 32-bit {@code float} values, backed by a
 * {@link FloatNdArray} and tagged with {@link DataType#FLOAT}.
 */
public final class FloatTensor extends NumericalTensor<Float> {

    /**
     * Creates a tensor of the given shape, allocated through {@code NdArrays.ofFloats}.
     *
     * @param shape the tensor shape
     */
    public FloatTensor(Shape shape) {
        this(NdArrays.ofFloats(shape.toNdArrayShape()));
    }

    /**
     * Creates a scalar tensor holding {@code f}.
     *
     * @param f the scalar value
     */
    public FloatTensor(float f) {
        this(NdArrays.scalarOf(f));
    }

    /**
     * Creates a 1-D tensor from a copy of {@code fArr}.
     *
     * @param fArr source values; copied via {@code StdArrays.ndCopyOf}
     */
    public FloatTensor(float[] fArr) {
        this(StdArrays.ndCopyOf(fArr));
    }

    /**
     * Creates a 2-D tensor from a copy of {@code fArr}.
     *
     * @param fArr source values; copied via {@code StdArrays.ndCopyOf}
     */
    public FloatTensor(float[][] fArr) {
        this(StdArrays.ndCopyOf(fArr));
    }

    /**
     * Creates a 3-D tensor from a copy of {@code fArr}.
     *
     * @param fArr source values; copied via {@code StdArrays.ndCopyOf}
     */
    public FloatTensor(float[][][] fArr) {
        this(StdArrays.ndCopyOf(fArr));
    }

    /**
     * Wraps an existing float array as a tensor of type {@link DataType#FLOAT}.
     *
     * <p>NOTE(review): the decompiler reports this constructor was package-private in the
     * original source; it is kept public here so no existing caller breaks. Whether the array
     * is copied depends on the superclass constructor (not visible here) — confirm.
     *
     * @param floatNdArray the backing array
     */
    public FloatTensor(FloatNdArray floatNdArray) {
        super(floatNdArray, DataType.FLOAT);
    }

    /**
     * Overwrites this tensor's contents with values parsed by {@link Float#parseFloat(String)}.
     *
     * @param strArr one string per element, in the order consumed by {@code NdArray.write}
     * @throws NumberFormatException if any string is not a parseable float
     */
    @Override
    public void parseFromValueStrings(String[] strArr) {
        float[] values = new float[strArr.length];
        for (int k = 0; k < values.length; k++) {
            values[k] = Float.parseFloat(strArr[k]);
        }
        this.data.write(DataBuffers.of(values));
    }

    /**
     * Renders every element with {@link Float#toString(float)}, in the order produced by
     * {@code NdArray.read}.
     *
     * @return one string per element
     * @throws ArithmeticException if the element count does not fit in an {@code int}
     */
    @Override
    public String[] getValueStrings() {
        int count = Math.toIntExact(size());
        float[] values = new float[count];
        // Wrap the array without copying so read() fills `values` directly.
        this.data.read(DataBuffers.of(values, false, false));
        String[] rendered = new String[count];
        for (int k = 0; k < count; k++) {
            rendered[k] = Float.toString(values[k]);
        }
        return rendered;
    }

    /**
     * Reads one element.
     *
     * @param jArr element coordinates
     * @return the element value
     */
    public float getFloat(long... jArr) {
        return this.data.getFloat(jArr);
    }

    /**
     * Writes one element.
     *
     * @param f the value to store
     * @param jArr element coordinates
     * @return this tensor, for chaining
     */
    public FloatTensor setFloat(float f, long... jArr) {
        this.data.setFloat(f, jArr);
        return this;
    }

    /**
     * Multiplies every element by {@code factor} through the scalar views of
     * {@code getData()}, then returns this tensor.
     *
     * @param factor the multiplier
     * @return this tensor
     */
    public FloatTensor scale(float factor) {
        getData().scalars().forEachIndexed((coords, scalar) ->
            scalar.setFloat(scalar.getFloat() * factor));
        return this;
    }

    /**
     * Returns a new tensor with the given shape that holds a copy of this tensor's elements;
     * this tensor itself is not modified.
     *
     * <p>The decompiler renamed this method from {@code reshape}; the name is kept as-is.
     *
     * @param shape target shape; must have the same total size as this tensor
     * @return the reshaped copy
     */
    @Override
    public FloatTensor reshape2(Shape shape) {
        AkPreconditions.checkArgument(shape.size() == size(), "Shape not matched.");
        FloatDataBuffer copy = DataBuffers.ofFloats(size());
        this.data.read(copy);
        return new FloatTensor(NdArrays.wrap(shape.toNdArrayShape(), copy));
    }

    /**
     * Minimum reduction delegated to {@code TensorUtil.doCalc}.
     *
     * @param dim forwarded unchanged to {@code TensorUtil.doCalc} (presumably the axis to reduce
     *     over — confirm)
     * @param keepDim forwarded unchanged to {@code TensorUtil.doCalc} (presumably whether the
     *     reduced axis is kept — confirm)
     * @return the reduced tensor
     */
    @Override
    public FloatTensor min(int dim, boolean keepDim) {
        return (FloatTensor) TensorUtil.doCalc(this, dim, keepDim,
            new TensorUtil.DoCalcFunctions<Float, float[]>() {
                @Override
                public float[] createArray(int size) {
                    return new float[size];
                }

                @Override
                public void initial(float[] acc) {
                    // +Infinity is the identity for min; seeding with Float.MAX_VALUE would
                    // report MAX_VALUE for a slice made only of +Infinity values.
                    Arrays.fill(acc, Float.POSITIVE_INFINITY);
                }

                @Override
                public void calc(float[] acc, NdArray<Float> src, long[] coords, int slot) {
                    acc[slot] = Math.min(src.getObject(coords).floatValue(), acc[slot]);
                }

                @Override
                public DataBuffer<Float> write(float[] acc) {
                    return DataBuffers.of(acc);
                }
            });
    }

    /**
     * Maximum reduction delegated to {@code TensorUtil.doCalc}.
     *
     * @param dim forwarded unchanged to {@code TensorUtil.doCalc} (presumably the axis to reduce
     *     over — confirm)
     * @param keepDim forwarded unchanged to {@code TensorUtil.doCalc} (presumably whether the
     *     reduced axis is kept — confirm)
     * @return the reduced tensor
     */
    @Override
    public FloatTensor max(int dim, boolean keepDim) {
        return (FloatTensor) TensorUtil.doCalc(this, dim, keepDim,
            new TensorUtil.DoCalcFunctions<Float, float[]>() {
                @Override
                public float[] createArray(int size) {
                    return new float[size];
                }

                @Override
                public void initial(float[] acc) {
                    // -Infinity is the identity for max; seeding with -Float.MAX_VALUE would
                    // report -MAX_VALUE for a slice made only of -Infinity values.
                    Arrays.fill(acc, Float.NEGATIVE_INFINITY);
                }

                @Override
                public void calc(float[] acc, NdArray<Float> src, long[] coords, int slot) {
                    acc[slot] = Math.max(src.getObject(coords).floatValue(), acc[slot]);
                }

                @Override
                public DataBuffer<Float> write(float[] acc) {
                    return DataBuffers.of(acc);
                }
            });
    }

    /**
     * Sum reduction delegated to {@code TensorUtil.doCalc}.
     *
     * @param dim forwarded unchanged to {@code TensorUtil.doCalc} (presumably the axis to reduce
     *     over — confirm)
     * @param keepDim forwarded unchanged to {@code TensorUtil.doCalc} (presumably whether the
     *     reduced axis is kept — confirm)
     * @return the reduced tensor
     */
    @Override
    public FloatTensor sum(int dim, boolean keepDim) {
        return (FloatTensor) TensorUtil.doCalc(this, dim, keepDim,
            new TensorUtil.DoCalcFunctions<Float, float[]>() {
                @Override
                public float[] createArray(int size) {
                    return new float[size];
                }

                @Override
                public void initial(float[] acc) {
                    Arrays.fill(acc, 0.0f);
                }

                @Override
                public void calc(float[] acc, NdArray<Float> src, long[] coords, int slot) {
                    acc[slot] = src.getObject(coords).floatValue() + acc[slot];
                }

                @Override
                public DataBuffer<Float> write(float[] acc) {
                    return DataBuffers.of(acc);
                }
            });
    }

    /**
     * Mean reduction delegated to {@code TensorUtil.doCalc}: sums each slot, then divides by
     * the count passed to {@code post}.
     *
     * @param dim forwarded unchanged to {@code TensorUtil.doCalc} (presumably the axis to reduce
     *     over — confirm)
     * @param keepDim forwarded unchanged to {@code TensorUtil.doCalc} (presumably whether the
     *     reduced axis is kept — confirm)
     * @return the reduced tensor
     */
    @Override
    public FloatTensor mean(int dim, boolean keepDim) {
        return (FloatTensor) TensorUtil.doCalc(this, dim, keepDim,
            new TensorUtil.DoCalcFunctions<Float, float[]>() {
                @Override
                public float[] createArray(int size) {
                    return new float[size];
                }

                @Override
                public void initial(float[] acc) {
                    Arrays.fill(acc, 0.0f);
                }

                @Override
                public void calc(float[] acc, NdArray<Float> src, long[] coords, int slot) {
                    acc[slot] = src.getObject(coords).floatValue() + acc[slot];
                }

                @Override
                public DataBuffer<Float> write(float[] acc) {
                    return DataBuffers.of(acc);
                }

                // NOTE(review): `count` is presumably the number of elements folded into each
                // slot — confirm against TensorUtil.doCalc.
                @Override
                public void post(float[] acc, int count) {
                    // Divide directly instead of multiplying by a precomputed reciprocal: one
                    // correctly-rounded step rather than two.
                    for (int k = 0; k < acc.length; k++) {
                        acc[k] /= count;
                    }
                }
            });
    }

    /**
     * Views or converts a tensor as a {@code FloatTensor}.
     *
     * <p>A {@code FloatTensor} is returned as-is; a {@code DoubleTensor} is copied element by
     * element into a new tensor of the same shape, narrowing each value with a
     * {@code (float)} cast.
     *
     * @param tensor the tensor to convert
     * @return a float tensor with the same shape and (narrowed) values
     * @throws AkUnsupportedOperationException for any other tensor type
     */
    public static FloatTensor of(Tensor<?> tensor) {
        if (tensor instanceof FloatTensor) {
            return (FloatTensor) tensor;
        }
        if (tensor instanceof DoubleTensor) {
            DoubleTensor source = (DoubleTensor) tensor;
            FloatTensor result = new FloatTensor(new Shape(source.shape()));
            source.getData().scalars().forEachIndexed((coords, scalar) ->
                result.setFloat((float) scalar.getDouble(), coords));
            return result;
        }
        throw new AkUnsupportedOperationException(
            String.format("Failed to cast to float tensor. tensor type: %s", tensor.getType()));
    }
}
