package com.alibaba.alink.operator.common.dataproc;

import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException;
import com.alibaba.alink.common.linalg.tensor.DoubleTensor;
import com.alibaba.alink.common.linalg.tensor.NumericalTensor;
import com.alibaba.alink.common.linalg.tensor.Tensor;
import com.alibaba.alink.common.linalg.tensor.TensorUtil;
import com.alibaba.alink.common.mapper.SISOMapper;
import com.alibaba.alink.common.type.AlinkTypes;
import com.alibaba.alink.params.dataproc.TensorToVectorParams;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;

/* loaded from: input_file:com/alibaba/alink/operator/common/dataproc/TensorToVectorMapper.class */
/**
 * Single-input/single-output mapper that converts a tensor cell into a dense vector.
 *
 * <p>The conversion is chosen by {@link TensorToVectorParams#CONVERT_METHOD}:
 * <ul>
 *   <li>{@code FLATTEN}: {@code tensor.flatten(0, -1)}, presumably collapsing every dimension
 *       into one (confirm the meaning of {@code -1} in {@link Tensor#flatten}).</li>
 *   <li>{@code SUM}, {@code MEAN}, {@code MAX}, {@code MIN}: the matching {@link NumericalTensor}
 *       reduction along dimension 0, called with a second argument of {@code false} (presumably
 *       "do not keep the reduced dimension"; verify in {@link NumericalTensor}).</li>
 * </ul>
 * In every case the result is wrapped with {@link DoubleTensor#of} and returned through
 * {@code toVector()}. Only numerical tensors are accepted, and a {@code null} cell maps to
 * {@code null}. The declared output column type is {@link AlinkTypes#DENSE_VECTOR}.
 */
public class TensorToVectorMapper extends SISOMapper {
    /** Conversion strategy, read once from the params at construction time. */
    private final TensorToVectorParams.ConvertMethod method;

    /**
     * Creates the mapper.
     *
     * @param tableSchema schema of the input table, forwarded to {@link SISOMapper}
     * @param params      mapper params; {@link TensorToVectorParams#CONVERT_METHOD} selects the conversion
     */
    public TensorToVectorMapper(TableSchema tableSchema, Params params) {
        super(tableSchema, params);
        this.method = (TensorToVectorParams.ConvertMethod) params.get(TensorToVectorParams.CONVERT_METHOD);
    }

    /**
     * Converts one cell value into a dense vector using the configured convert method.
     *
     * @param obj the input cell: any value {@link TensorUtil#getTensor} accepts, or {@code null}
     * @return the converted vector, or {@code null} when {@code obj} is {@code null}
     * @throws IllegalStateException           if the value does not resolve to a {@link NumericalTensor}
     * @throws AkUnsupportedOperationException if the configured convert method is not handled below
     */
    @Override // com.alibaba.alink.common.mapper.SISOMapper
    protected Object mapColumn(Object obj) {
        if (null == obj) {
            return null;
        }
        Tensor<?> tensor = TensorUtil.getTensor(obj);
        if (!(tensor instanceof NumericalTensor)) {
            // getTensor's null contract is not visible from here. Guard the message so that a null
            // result still yields this diagnostic instead of a NullPointerException from getClass().
            String tensorType = tensor == null ? "null" : tensor.getClass().getName();
            throw new IllegalStateException(String.format("Only numerical tensor could be converted to vector. Tensor type: %s", tensorType));
        }
        // Same (raw) cast the reductions below have always used, hoisted so it is done once.
        NumericalTensor numericalTensor = (NumericalTensor) tensor;
        switch (this.method) {
            case FLATTEN:
                return DoubleTensor.of(tensor.flatten(0, -1)).toVector();
            case SUM:
                return DoubleTensor.of(numericalTensor.sum(0, false)).toVector();
            case MEAN:
                return DoubleTensor.of(numericalTensor.mean(0, false)).toVector();
            case MAX:
                return DoubleTensor.of(numericalTensor.max(0, false)).toVector();
            case MIN:
                return DoubleTensor.of(numericalTensor.min(0, false)).toVector();
            default:
                // Name the offending method so the failure is actionable.
                throw new AkUnsupportedOperationException(
                    String.format("Not support convert method: %s", this.method));
        }
    }

    /**
     * Declares the output column type.
     *
     * @return {@link AlinkTypes#DENSE_VECTOR}, matching the {@code toVector()} results of {@link #mapColumn}
     */
    @Override // com.alibaba.alink.common.mapper.SISOMapper
    protected TypeInformation initOutputColType() {
        return AlinkTypes.DENSE_VECTOR;
    }
}
