package com.alibaba.alink.common.pyrunner.fn.conversion;

import com.alibaba.alink.common.exceptions.AkPreconditions;
import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException;
import com.alibaba.alink.common.exceptions.ExceptionWithErrorCode;
import com.alibaba.alink.common.linalg.tensor.DataType;
import com.alibaba.alink.common.linalg.tensor.Tensor;
import com.alibaba.alink.common.linalg.tensor.TensorInternalUtils;
import com.alibaba.alink.common.pyrunner.fn.JavaObjectWrapper;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.tensorflow.ndarray.Shape;

/* loaded from: input_file:com/alibaba/alink/common/pyrunner/fn/conversion/TensorWrapper.class */
/**
 * Carries a {@link Tensor} together with the pieces needed to rebuild it as a NumPy array:
 * its serialized bytes, a NumPy dtype string and its shape.
 *
 * <p>Instances are created through {@link #fromJava(Tensor)} (Java tensor to wrapper) or one of
 * the {@code fromPy} factories (NumPy-side data to wrapper). The byte and shape arrays are stored
 * and returned by reference, without copying.
 */
public class TensorWrapper implements JavaObjectWrapper<Tensor<?>> {
    /** Dtype string reported for each tensor {@link DataType} when wrapping a Java tensor. */
    private static final Map<DataType, String> TO_NUMPY_DTYPE = new HashMap<>();

    /** Dtype strings accepted by {@code fromPy}, mapped to the tensor {@link DataType}. */
    private static final Map<String, DataType> FROM_NUMPY_DTYPE = new HashMap<>();

    static {
        TO_NUMPY_DTYPE.put(DataType.FLOAT, "<f4");
        TO_NUMPY_DTYPE.put(DataType.DOUBLE, "<f8");
        TO_NUMPY_DTYPE.put(DataType.INT, "<i4");
        TO_NUMPY_DTYPE.put(DataType.LONG, "<i8");
        TO_NUMPY_DTYPE.put(DataType.BOOLEAN, "<?");
        TO_NUMPY_DTYPE.put(DataType.BYTE, "<b");
        TO_NUMPY_DTYPE.put(DataType.UBYTE, "<B");
        TO_NUMPY_DTYPE.put(DataType.STRING, "<U");

        // Every dtype string produced on the Java side is also accepted on the way back.
        for (Map.Entry<DataType, String> entry : TO_NUMPY_DTYPE.entrySet()) {
            FROM_NUMPY_DTYPE.put(entry.getValue(), entry.getKey());
        }

        // Additional spellings accepted only as input for the single-byte types.
        FROM_NUMPY_DTYPE.put("|b1", DataType.BOOLEAN);
        FROM_NUMPY_DTYPE.put("|i1", DataType.BYTE);
        FROM_NUMPY_DTYPE.put("|u1", DataType.UBYTE);
    }

    private final Tensor<?> tensor;
    private final byte[] bytes;
    private final String dtypeStr;
    private final long[] shape;

    private TensorWrapper(Tensor<?> tensor, byte[] bytes, String dtypeStr, long[] shape) {
        this.tensor = tensor;
        this.bytes = bytes;
        this.dtypeStr = dtypeStr;
        this.shape = shape;
    }

    /**
     * Wraps a Java tensor, serializing it with {@link TensorInternalUtils#tensorToBytes}.
     *
     * @param tensor the tensor to wrap
     * @return a wrapper holding the tensor, its bytes, its dtype string and its shape; the dtype
     *     string is {@code null} if the tensor's type has no entry in the dtype table
     */
    public static TensorWrapper fromJava(Tensor<?> tensor) {
        String dtype = TO_NUMPY_DTYPE.get(tensor.getType());
        long[] dims = tensor.shape();
        byte[] data = TensorInternalUtils.tensorToBytes(tensor);
        return new TensorWrapper(tensor, data, dtype, dims);
    }

    /**
     * Same as {@link #fromPy(byte[], String, long[])}, with the shape given as a list of
     * integers.
     *
     * @param bArr serialized tensor content
     * @param str dtype string of the data
     * @param list dimensions of the tensor
     * @return a wrapper holding the rebuilt tensor
     */
    public static TensorWrapper fromPy(byte[] bArr, String str, List<Integer> list) {
        long[] dims = list.stream().mapToLong(Integer::longValue).toArray();
        return fromPy(bArr, str, dims);
    }

    /**
     * Rebuilds a tensor from serialized bytes via {@link TensorInternalUtils#bytesToTensor} and
     * wraps it.
     *
     * <p>An unrecognized dtype string is reported through
     * {@link AkPreconditions#checkArgument} with an {@link AkUnsupportedOperationException}.
     *
     * @param bArr serialized tensor content
     * @param str dtype string of the data; must be one of the supported dtype strings
     * @param jArr dimensions of the tensor
     * @return a wrapper holding the rebuilt tensor and the given arguments
     */
    public static TensorWrapper fromPy(byte[] bArr, String str, long[] jArr) {
        DataType dataType = FROM_NUMPY_DTYPE.get(str);
        if (dataType == null) {
            // Build the exception (message + stack trace) only on the failure path instead of on
            // every call; it is still raised through AkPreconditions as before.
            AkPreconditions.checkArgument(false, (ExceptionWithErrorCode) new AkUnsupportedOperationException(String.format("Numpy array of type %s is not supported.", str)));
        }
        Tensor<?> rebuilt = TensorInternalUtils.bytesToTensor(bArr, dataType, Shape.of(jArr));
        return new TensorWrapper(rebuilt, bArr, str, jArr);
    }

    /** Returns a debug representation that includes the full byte content and shape. */
    @Override
    public String toString() {
        return "TensorWrapper{tensor=" + this.tensor
            + ", bytes=" + Arrays.toString(this.bytes)
            + ", dtypeStr='" + this.dtypeStr
            + "', shape=" + Arrays.toString(this.shape)
            + '}';
    }

    /** Returns the wrapped tensor. */
    @Override // com.alibaba.alink.common.pyrunner.fn.JavaObjectWrapper
    public Tensor<?> getJavaObject() {
        return this.tensor;
    }

    /** Returns the serialized tensor bytes (the internal array, not a copy). */
    public byte[] getBytes() {
        return this.bytes;
    }

    /** Returns the tensor dimensions (the internal array, not a copy). */
    public long[] getShape() {
        return this.shape;
    }

    /** Returns the dtype string associated with the tensor data. */
    public String getDtypeStr() {
        return this.dtypeStr;
    }
}
