package com.alibaba.alink.common.linalg.tensor;

import com.alibaba.alink.common.io.filesystem.copy.csv.CsvInputFormat;
import com.alibaba.alink.common.linalg.tensor.TensorUtil;
import com.alibaba.alink.common.viz.DataTypeDisplayInterface;
import com.alibaba.alink.operator.common.utils.PrettyDisplayUtils;
import java.io.Serializable;
import java.lang.reflect.Array;
import java.util.Locale;
import java.util.Objects;
import org.apache.commons.lang3.ArrayUtils;
import org.tensorflow.ndarray.NdArray;

/* loaded from: input_file:com/alibaba/alink/common/linalg/tensor/Tensor.class */
public abstract class Tensor<DT> implements Serializable, DataTypeDisplayInterface {
    protected DataType type;
    protected NdArray<DT> data;

    public long[] shape() {
        return this.data.shape().asArray();
    }

    public long size() {
        return this.data.shape().size();
    }

    public DT getObject(long... jArr) {
        return (DT) this.data.getObject(jArr);
    }

    public Tensor<DT> setObject(DT dt, long... jArr) {
        this.data.setObject(dt, jArr);
        return this;
    }

    public String toString() {
        return toDisplaySummary() + CsvInputFormat.DEFAULT_LINE_DELIMITER + toShortDisplayData();
    }

    public abstract Tensor<DT> reshape(Shape shape);

    public Tensor<DT> flatten(int i, int i2) {
        long[] shape = shape();
        int length = shape.length;
        int wrapDim = (int) TensorUtil.wrapDim(i, length);
        int wrapDim2 = (int) TensorUtil.wrapDim(i2, length);
        long[] jArr = new long[length - (wrapDim2 - wrapDim)];
        if (wrapDim > 0) {
            System.arraycopy(shape, 0, jArr, 0, wrapDim);
        }
        long j = 1;
        for (int i3 = wrapDim; i3 <= wrapDim2; i3++) {
            j *= shape[i3];
        }
        jArr[wrapDim] = j;
        int i4 = wrapDim + 1;
        int i5 = wrapDim2 + 1;
        if (i5 < length) {
            System.arraycopy(shape, i5, jArr, i4, length - i5);
        }
        return reshape(new Shape(jArr));
    }

    public Tensor<DT> flatten() {
        return flatten(0, -1);
    }

    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof Tensor)) {
            return false;
        }
        Tensor tensor = (Tensor) obj;
        if (this.type.equals(tensor.type)) {
            return this.data.equals(tensor.data);
        }
        return false;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v37, types: [com.alibaba.alink.common.linalg.tensor.Tensor] */
    public static <DT, T extends Tensor<DT>> T stack(T[] tArr, int i, T t) {
        if (tArr == null || tArr.length == 0) {
            return t;
        }
        T t2 = tArr[0];
        int wrapDim = (int) TensorUtil.wrapDim(i, t2.shape().length);
        long[] add = ArrayUtils.add(t2.shape(), wrapDim, tArr.length);
        if (t == null) {
            t = TensorUtil.of(add, t2.type);
        }
        long[] jArr = new long[wrapDim];
        int i2 = 1;
        for (int i3 = 0; i3 < wrapDim; i3++) {
            i2 = (int) (i2 * add[i3]);
        }
        TensorUtil.CoordInc coordInc = new TensorUtil.CoordInc(add, wrapDim, jArr);
        for (int i4 = 0; i4 < i2; i4++) {
            long[] add2 = ArrayUtils.add(jArr, 0L);
            for (int i5 = 0; i5 < tArr.length; i5++) {
                add2[wrapDim] = i5;
                t.getData().set(tArr[i5].getData().get(jArr), add2);
            }
            coordInc.inc();
        }
        return t;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r1v9 */
    /* JADX WARN: Type inference failed for: r8v0, types: [T extends com.alibaba.alink.common.linalg.tensor.Tensor<DT>[]] */
    /* JADX WARN: Type inference failed for: r8v1 */
    /* JADX WARN: Type inference failed for: r8v2 */
    public static <DT, T extends Tensor<DT>> T[] unstack(T t, int i, T[] tArr) {
        long[] shape = t.shape();
        int wrapDim = (int) TensorUtil.wrapDim(i, shape.length);
        long[] remove = ArrayUtils.remove(shape, wrapDim);
        if (tArr == 0) {
            tArr = (T[]) ((Tensor[]) Array.newInstance(t.getClass(), (int) shape[wrapDim]));
            for (int i2 = 0; i2 < shape[wrapDim]; i2++) {
                tArr[i2] = TensorUtil.of(remove, t.getType());
            }
        }
        long[] jArr = new long[wrapDim + 1];
        int i3 = 1;
        for (int i4 = 0; i4 < wrapDim; i4++) {
            i3 = (int) (i3 * shape[i4]);
        }
        TensorUtil.CoordInc coordInc = new TensorUtil.CoordInc(shape, wrapDim, jArr);
        for (int i5 = 0; i5 < i3; i5++) {
            long[] subarray = ArrayUtils.subarray(jArr, 0, wrapDim);
            for (int i6 = 0; i6 < tArr.length; i6++) {
                jArr[wrapDim] = i6;
                tArr[i6].getData().set(t.getData().get(jArr), subarray);
            }
            coordInc.inc();
        }
        return tArr;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v12, types: [long[]] */
    /* JADX WARN: Type inference failed for: r0v17, types: [java.lang.Object] */
    /* JADX WARN: Type inference failed for: r0v53, types: [com.alibaba.alink.common.linalg.tensor.Tensor] */
    /* JADX WARN: Type inference failed for: r1v17 */
    /* JADX WARN: Type inference failed for: r1v18, types: [long] */
    /* JADX WARN: Type inference failed for: r1v31 */
    /* JADX WARN: Type inference failed for: r1v32, types: [long] */
    public static <DT, T extends Tensor<DT>> T cat(T[] tArr, int i, T t) {
        if (tArr == null || tArr.length == 0) {
            return t;
        }
        T t2 = tArr[0];
        int wrapDim = (int) TensorUtil.wrapDim(i, t2.shape().length);
        ?? r0 = new long[tArr.length];
        int i2 = 0;
        for (int i3 = 0; i3 < tArr.length; i3++) {
            r0[i3] = tArr[i3].shape();
            i2 = (int) (i2 + r0[i3][wrapDim]);
        }
        long[] jArr = (long[]) r0[0].clone();
        jArr[wrapDim] = i2;
        if (t == null) {
            t = TensorUtil.of(jArr, t2.type);
        }
        long[] jArr2 = new long[wrapDim + 1];
        int i4 = 1;
        for (int i5 = 0; i5 < wrapDim; i5++) {
            i4 = (int) (i4 * jArr[i5]);
        }
        TensorUtil.CoordInc coordInc = new TensorUtil.CoordInc(jArr, wrapDim, jArr2);
        for (int i6 = 0; i6 < i4; i6++) {
            long[] jArr3 = (long[]) jArr2.clone();
            jArr3[wrapDim] = 0;
            for (int i7 = 0; i7 < tArr.length; i7++) {
                int i8 = 0;
                while (i8 < r0[i7][wrapDim]) {
                    jArr2[wrapDim] = i8;
                    t.getData().set(tArr[i7].getData().get(jArr2), jArr3);
                    i8++;
                    jArr3[wrapDim] = jArr3[wrapDim] + 1;
                }
            }
            coordInc.inc();
        }
        return t;
    }

    public static <DT, T extends Tensor<DT>> T permute(T t, long... jArr) {
        return t;
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    public Tensor(NdArray<DT> ndArray, DataType dataType) {
        this.data = ndArray;
        this.type = dataType;
    }

    public DataType getType() {
        return this.type;
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    public NdArray<DT> getData() {
        return this.data;
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    public abstract void parseFromValueStrings(String[] strArr);

    /* JADX INFO: Access modifiers changed from: package-private */
    public abstract String[] getValueStrings();

    @Override // com.alibaba.alink.common.viz.DataTypeDisplayInterface
    public String toDisplaySummary() {
        return this.type.toString().substring(0, 1) + this.type.toString().substring(1).toLowerCase() + "Tensor(" + TensorUtil.toString(Shape.fromNdArrayShape(this.data.shape())) + ")";
    }

    @Override // com.alibaba.alink.common.viz.DataTypeDisplayInterface
    public String toDisplayData(int i) {
        return PrettyDisplayUtils.displayTensor(shape(), getValueStrings(), i);
    }

    @Override // com.alibaba.alink.common.viz.DataTypeDisplayInterface
    public String toShortDisplayData() {
        return toDisplayData(3);
    }
}
