package com.alibaba.alink.operator.common.dataproc.vector;

import com.alibaba.alink.common.exceptions.AkIllegalDataException;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.mapper.MISOMapper;
import com.alibaba.alink.common.type.AlinkTypes;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;

/* loaded from: input_file:com/alibaba/alink/operator/common/dataproc/vector/VectorInteractionMapper.class */
/**
 * Computes the interaction (flattened outer product) of two vector columns.
 *
 * <p>Both inputs must use the same storage: two dense vectors yield a dense result and two
 * sparse vectors yield a sparse result. If either input is {@code null}, the output is
 * {@code null}.
 *
 * <p>NOTE(review): the two branches flatten the outer product in different orders. For inputs
 * {@code a} (size {@code n1}) and {@code b} (size {@code n2}), the dense branch stores
 * {@code a[i] * b[j]} at position {@code i * n2 + j}. The sparse branch stores it at
 * {@code j * n1 + i}. The same logical vectors therefore produce differently ordered outputs
 * depending on storage. Both layouts are kept as-is because downstream consumers may depend on
 * them; confirm which one is intended before unifying.
 */
public class VectorInteractionMapper extends MISOMapper {
    private static final long serialVersionUID = 5122592154123233560L;

    /**
     * @param tableSchema schema of the input table
     * @param params      mapper parameters, forwarded unchanged to {@link MISOMapper}
     */
    public VectorInteractionMapper(TableSchema tableSchema, Params params) {
        super(tableSchema, params);
    }

    /** The output column holds a vector. */
    @Override
    protected TypeInformation<?> initOutputColType() {
        return AlinkTypes.VECTOR;
    }

    /**
     * Multiplies every entry of the first input vector with every entry of the second.
     *
     * @param objArr the selected input values; exactly two are expected
     * @return the interaction vector, or {@code null} if either input is {@code null}
     * @throws AkIllegalDataException if the number of inputs is not two, if the vectors mix
     *     dense and sparse storage, if the first sparse vector has a negative (undetermined)
     *     size, or if the result would not fit in an {@code int}-indexed vector
     */
    @Override
    protected Object map(Object[] objArr) {
        if (objArr.length != 2) {
            throw new AkIllegalDataException("VectorInteraction only support two input columns.");
        }
        if (objArr[0] == null || objArr[1] == null) {
            return null;
        }
        Vector vector = VectorUtil.getVector(objArr[0]);
        Vector vector2 = VectorUtil.getVector(objArr[1]);
        if (vector instanceof SparseVector) {
            if (vector2 instanceof DenseVector) {
                throw new AkIllegalDataException("Make sure the two input vectors are both dense or sparse.");
            }
            return interactSparse((SparseVector) vector, (SparseVector) vector2);
        }
        if (vector2 instanceof SparseVector) {
            throw new AkIllegalDataException("Make sure the two input vectors are both dense or sparse.");
        }
        return interactDense((DenseVector) vector, (DenseVector) vector2);
    }

    /**
     * Dense interaction: {@code left[i] * right[j]} is stored at position
     * {@code i * right.length + j}.
     */
    private static DenseVector interactDense(DenseVector left, DenseVector right) {
        double[] leftData = left.getData();
        double[] rightData = right.getData();
        DenseVector result = new DenseVector(checkedProduct(leftData.length, rightData.length));
        double[] out = result.getData();
        for (int i = 0; i < leftData.length; i++) {
            int offset = i * rightData.length;
            for (int j = 0; j < rightData.length; j++) {
                out[offset + j] = leftData[i] * rightData[j];
            }
        }
        return result;
    }

    /**
     * Sparse interaction: {@code left[p] * right[q]} is stored at flattened index
     * {@code left.size() * q + p}, in a vector of size {@code left.size() * right.size()}.
     */
    private static SparseVector interactSparse(SparseVector left, SparseVector right) {
        int leftSize = left.size();
        int rightSize = right.size();
        if (leftSize < 0) {
            // The flattened index is scaled by the first vector's dimension; a negative
            // (presumably "undetermined") size would yield negative or colliding indices.
            throw new AkIllegalDataException(
                "VectorInteraction requires the first sparse vector to have a determined size.");
        }
        int[] leftIndices = left.getIndices();
        double[] leftValues = left.getValues();
        int[] rightIndices = right.getIndices();
        double[] rightValues = right.getValues();

        int nnz = checkedProduct(leftIndices.length, rightIndices.length);
        int resultSize = checkedProduct(leftSize, rightSize);
        int[] indices = new int[nnz];
        double[] values = new double[nnz];

        // Iterate the second vector in the outer loop so that flattened indices
        // (leftSize * rightIdx + leftIdx) come out ascending when both input index arrays are
        // sorted; the set of (index, value) pairs is the same as with the opposite loop order.
        int k = 0;
        for (int j = 0; j < rightIndices.length; j++) {
            int base = leftSize * rightIndices[j];
            for (int i = 0; i < leftIndices.length; i++) {
                indices[k] = base + leftIndices[i];
                values[k] = leftValues[i] * rightValues[j];
                k++;
            }
        }
        return new SparseVector(resultSize, indices, values);
    }

    /** Returns {@code a * b}, rejecting products that do not fit in an {@code int}. */
    private static int checkedProduct(int a, int b) {
        long product = (long) a * b;
        if (product != (int) product) {
            throw new AkIllegalDataException(
                "VectorInteraction result is too large: " + a + " * " + b + " exceeds the int range.");
        }
        return (int) product;
    }
}
