package com.alibaba.alink.operator.common.dataproc;

import com.alibaba.alink.common.exceptions.AkIllegalOperatorParameterException;
import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException;
import com.alibaba.alink.common.linalg.tensor.DataType;
import com.alibaba.alink.common.linalg.tensor.DoubleTensor;
import com.alibaba.alink.common.linalg.tensor.FloatTensor;
import com.alibaba.alink.common.linalg.tensor.Shape;
import com.alibaba.alink.common.linalg.tensor.StringTensor;
import com.alibaba.alink.common.linalg.tensor.Tensor;
import com.alibaba.alink.common.linalg.tensor.TensorUtil;
import com.alibaba.alink.common.mapper.SISOMapper;
import com.alibaba.alink.common.type.AlinkTypes;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.params.dataproc.ToTensorParams;
import com.alibaba.alink.params.shared.HasHandleInvalid;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;

/* loaded from: input_file:com/alibaba/alink/operator/common/dataproc/ToTensorMapper.class */
/**
 * Single-input / single-output mapper that turns a column value into a {@link Tensor}.
 *
 * <p>Behaviour is controlled by {@link ToTensorParams}:
 * <ul>
 *   <li>{@code TENSOR_SHAPE} (optional): when present, the resulting tensor is reshaped to this
 *       shape via {@code reshape2}.</li>
 *   <li>{@code TENSOR_DATA_TYPE} (optional): when present, the result must be a tensor of this
 *       type. DOUBLE and FLOAT targets are converted with {@link DoubleTensor#of} /
 *       {@link FloatTensor#of}. A {@link String} input with a STRING target is wrapped directly in
 *       a {@link StringTensor}. For every other target, the tensor produced by
 *       {@link TensorUtil#getTensor} must already have the requested type.</li>
 *   <li>{@code HANDLE_INVALID}: how a value that cannot be parsed or converted is handled.
 *       ERROR rethrows or throws; SKIP maps the value to {@code null}.</li>
 * </ul>
 *
 * <p>A {@code null} input always maps to {@code null}.
 */
public class ToTensorMapper extends SISOMapper {
    /** Target shape, or {@code null} to keep the shape of the parsed tensor. */
    private final long[] shape;
    /** Required element type of the output, or {@code null} to accept whatever type is parsed. */
    private final DataType targetDataType;
    /** Strategy applied when a value cannot be parsed into, or converted to, the target tensor. */
    private final HasHandleInvalid.HandleInvalidMethod handleInvalidMethod;

    /**
     * Creates the mapper and caches the shape, data-type and invalid-handling parameters.
     *
     * @param tableSchema schema of the input table
     * @param params      operator parameters, see {@link ToTensorParams}
     */
    public ToTensorMapper(TableSchema tableSchema, Params params) {
        super(tableSchema, params);
        Long[] boxedShape = (Long[]) params.get(ToTensorParams.TENSOR_SHAPE);
        this.shape = boxedShape == null ? null : ArrayUtils.toPrimitive(boxedShape);
        this.targetDataType = params.contains(ToTensorParams.TENSOR_DATA_TYPE)
            ? (DataType) params.get(ToTensorParams.TENSOR_DATA_TYPE)
            : null;
        this.handleInvalidMethod = (HasHandleInvalid.HandleInvalidMethod) params.get(ToTensorParams.HANDLE_INVALID);
    }

    /**
     * Converts one input value into a tensor.
     *
     * @param obj the input column value; may be {@code null}
     * @return the (optionally type-converted and reshaped) tensor, or {@code null} when the input is
     *     {@code null} or was skipped as invalid
     * @throws AkIllegalOperatorParameterException if the tensor cannot be converted to the target
     *     type and the invalid-handling method is ERROR
     * @throws AkUnsupportedOperationException if an invalid value is met and the invalid-handling
     *     method is neither ERROR nor SKIP
     */
    @Override // com.alibaba.alink.common.mapper.SISOMapper
    protected Object mapColumn(Object obj) {
        if (null == obj) {
            return null;
        }
        Tensor<?> tensor = null;
        if (DataType.STRING.equals(this.targetDataType) && (obj instanceof String)) {
            // A STRING target wraps the raw string directly instead of parsing it with TensorUtil.
            tensor = new StringTensor((String) obj);
        } else {
            try {
                tensor = TensorUtil.getTensor(obj);
            } catch (Exception e) {
                // Kept inline (not in a helper) so that `throw e` stays a precise rethrow.
                switch (this.handleInvalidMethod) {
                    case ERROR:
                        throw e;
                    case SKIP:
                        break;
                    default:
                        throw new AkUnsupportedOperationException("Not support exception. ");
                }
            }
        }
        if (tensor == null) {
            return null;
        }
        if (this.targetDataType != null) {
            tensor = convertToTargetType(tensor);
            if (tensor == null) {
                return null;
            }
        }
        return reshapeIfNeeded(tensor);
    }

    /**
     * Converts {@code tensor} to {@link #targetDataType} and checks that the result has that type.
     *
     * <p>Only DOUBLE and FLOAT targets trigger an explicit conversion. Other targets are checked
     * against the tensor's current type as-is.
     *
     * @param tensor the parsed, non-null tensor; {@link #targetDataType} must be non-null
     * @return the tensor of the target type, or {@code null} if it does not match and the
     *     invalid-handling method is SKIP
     */
    private Tensor<?> convertToTargetType(Tensor<?> tensor) {
        Tensor<?> converted;
        switch (this.targetDataType) {
            case DOUBLE:
                converted = DoubleTensor.of(tensor);
                break;
            case FLOAT:
                converted = FloatTensor.of(tensor);
                break;
            default:
                converted = tensor;
                break;
        }
        if (converted.getType().equals(this.targetDataType)) {
            return converted;
        }
        switch (this.handleInvalidMethod) {
            case ERROR:
                throw new AkIllegalOperatorParameterException(String.format("Could not convert tensor %s to tensor type %s", converted, this.targetDataType));
            case SKIP:
                return null;
            default:
                throw new AkUnsupportedOperationException("Not support exception. ");
        }
    }

    /** Reshapes {@code tensor} to {@link #shape} when a shape was configured; otherwise returns it unchanged. */
    private Object reshapeIfNeeded(Tensor<?> tensor) {
        return this.shape == null ? tensor : tensor.reshape2(new Shape(this.shape));
    }

    /**
     * Declares the output column type: a type-specific tensor type when {@code TENSOR_DATA_TYPE} is
     * set, otherwise the generic {@link AlinkTypes#TENSOR}.
     *
     * @throws AkUnsupportedOperationException for a data type that has no matching tensor type
     */
    @Override // com.alibaba.alink.common.mapper.SISOMapper
    protected TypeInformation<?> initOutputColType() {
        // NOTE(review): reads this.params rather than the targetDataType field, presumably because
        // this can be called during super-construction before the fields are assigned; confirm
        // against SISOMapper/Mapper.
        DataType dataType = null;
        if (this.params.contains(ToTensorParams.TENSOR_DATA_TYPE)) {
            dataType = (DataType) this.params.get(ToTensorParams.TENSOR_DATA_TYPE);
        }
        if (dataType == null) {
            return AlinkTypes.TENSOR;
        }
        // Switch on the enum constants directly. The decompiled form relied on a synthetic
        // $SwitchMap class that does not exist in source and therefore did not compile.
        switch (dataType) {
            case DOUBLE:
                return AlinkTypes.DOUBLE_TENSOR;
            case FLOAT:
                return AlinkTypes.FLOAT_TENSOR;
            case STRING:
                return AlinkTypes.STRING_TENSOR;
            case INT:
                return AlinkTypes.INT_TENSOR;
            case LONG:
                return AlinkTypes.LONG_TENSOR;
            case BOOLEAN:
                return AlinkTypes.BOOL_TENSOR;
            case UBYTE:
                return AlinkTypes.UBYTE_TENSOR;
            case BYTE:
                return AlinkTypes.BYTE_TENSOR;
            default:
                throw new AkUnsupportedOperationException("Unsupported tensor data type: " + dataType);
        }
    }
}
