package com.alibaba.alink.operator.common.dataproc.vector;

import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.mapper.SISOMapper;
import com.alibaba.alink.common.type.AlinkTypes;
import com.alibaba.alink.params.dataproc.vector.VectorElementwiseProductParams;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;

/* loaded from: input_file:com/alibaba/alink/operator/common/dataproc/vector/VectorElementwiseProductMapper.class */
/**
 * Single-input/single-output mapper that multiplies a vector column element-wise by a fixed
 * scaling vector (a Hadamard product).
 *
 * <p>The scaling vector is read once, at construction, from
 * {@link VectorElementwiseProductParams#SCALING_VECTOR}. The output column is typed as
 * {@link AlinkTypes#VECTOR}.
 */
public class VectorElementwiseProductMapper extends SISOMapper {
    private static final long serialVersionUID = -8030477987774641696L;

    /** Factor vector; element {@code i} of each input is multiplied by {@code scalingVector.get(i)}. */
    private final Vector scalingVector;

    /**
     * Creates the mapper and parses the scaling vector from the parameters.
     *
     * @param tableSchema schema of the input table
     * @param params      mapper parameters; must carry {@code SCALING_VECTOR}
     */
    public VectorElementwiseProductMapper(TableSchema tableSchema, Params params) {
        super(tableSchema, params);
        this.scalingVector = VectorUtil.getVector(this.params.get(VectorElementwiseProductParams.SCALING_VECTOR));
    }

    /** The result of the element-wise product is always a vector. */
    @Override // com.alibaba.alink.common.mapper.SISOMapper
    protected TypeInformation initOutputColType() {
        return AlinkTypes.VECTOR;
    }

    /**
     * Multiplies the input vector element-wise by the scaling vector.
     *
     * <p>Dense input: every element {@code i} is multiplied by {@code scalingVector.get(i)}.
     * Sparse input: only the stored entries are scaled, using the scaling factor at each entry's
     * index, so implicit zeros stay zero. Indices beyond the scaling vector's size are resolved
     * by {@code scalingVector.get}, and its out-of-range behaviour is not defined here
     * (NOTE(review): a dense scaling vector shorter than the input presumably throws — confirm).
     *
     * <p>The input object is never modified. If {@code VectorUtil.getVector} hands back the input
     * object itself, a copy is scaled instead.
     *
     * @param obj the input cell: a vector, or anything {@code VectorUtil.getVector} accepts;
     *            may be {@code null}
     * @return the scaled vector, or {@code null} when {@code obj} is {@code null}
     */
    @Override // com.alibaba.alink.common.mapper.SISOMapper
    protected Object mapColumn(Object obj) {
        if (null == obj) {
            return null;
        }
        Vector vector = VectorUtil.getVector(obj);
        // If getVector returned the caller's own object, scaling it in place would leak the result
        // into every other holder of that object (e.g. a reserved input column). Copy first.
        boolean sharedWithInput = vector == obj;
        if (vector instanceof DenseVector) {
            DenseVector dense = sharedWithInput ? ((DenseVector) vector).clone() : (DenseVector) vector;
            double[] data = dense.getData();
            for (int i = 0; i < data.length; i++) {
                data[i] *= this.scalingVector.get(i);
            }
            return dense;
        }
        // Any non-dense vector is treated as sparse, as in the original mapper.
        SparseVector sparse = sharedWithInput ? ((SparseVector) vector).clone() : (SparseVector) vector;
        double[] values = sparse.getValues();
        int[] indices = sparse.getIndices();
        for (int k = 0; k < values.length; k++) {
            values[k] *= this.scalingVector.get(indices[k]);
        }
        return sparse;
    }
}
