package com.alibaba.alink.operator.common.dataproc.vector;

import com.alibaba.alink.common.exceptions.AkIllegalDataException;
import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.mapper.MISOMapper;
import com.alibaba.alink.common.type.AlinkTypes;
import com.alibaba.alink.params.dataproc.vector.VectorAssemblerParams;
import com.alibaba.alink.params.shared.HasHandleInvalid;
import java.util.HashMap;
import java.util.Map;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;

/* loaded from: input_file:com/alibaba/alink/operator/common/dataproc/vector/VectorAssemblerMapper.class */
/**
 * Mapper that concatenates several input columns into a single vector.
 *
 * <p>Each input value may be a {@link Number} (one slot), a {@link String} (parsed with
 * {@link VectorUtil#getVector}) or a {@link Vector}; the values are laid out one after another in
 * input order. The result is a {@link SparseVector} unless it is dense enough (see {@link #RATIO}),
 * in which case it is converted to a {@link DenseVector}.
 */
public class VectorAssemblerMapper extends MISOMapper {
    /**
     * Density threshold for the output format: when {@code nnz * RATIO > size} the assembled vector
     * is returned dense, otherwise sparse.
     */
    private static final double RATIO = 1.5d;
    private static final long serialVersionUID = -8419340084734506661L;

    /** Strategy applied when one of the input values is {@code null}. */
    private final HasHandleInvalid.HandleInvalidMethod handleInvalid;

    /**
     * Creates the mapper.
     *
     * @param tableSchema schema of the input table
     * @param params      operator parameters; {@link VectorAssemblerParams#HANDLE_INVALID} is read here
     */
    public VectorAssemblerMapper(TableSchema tableSchema, Params params) {
        super(tableSchema, params);
        this.handleInvalid = (HasHandleInvalid.HandleInvalidMethod) params.get(VectorAssemblerParams.HANDLE_INVALID);
    }

    @Override // com.alibaba.alink.common.mapper.MISOMapper
    protected TypeInformation<?> initOutputColType() {
        return AlinkTypes.VECTOR;
    }

    @Override // com.alibaba.alink.common.mapper.MISOMapper
    protected Object map(Object[] objArr) {
        return assembler(objArr, this.handleInvalid);
    }

    /**
     * Assembles the given values into one vector, failing on any {@code null} value.
     *
     * @param objArr values to concatenate; may be {@code null}
     * @return the assembled vector, or {@code null} when {@code objArr} is {@code null}
     * @throws AkIllegalDataException          if a value is {@code null}
     * @throws AkUnsupportedOperationException if a value is not a number, string or vector
     */
    public static Object assembler(Object[] objArr) {
        return assembler(objArr, HasHandleInvalid.HandleInvalidMethod.ERROR);
    }

    /**
     * Assembles the given values into one vector.
     *
     * @param objArr              values to concatenate; may be {@code null}
     * @param handleInvalidMethod what to do when a value is {@code null}: ERROR throws, SKIP makes
     *                            the whole row {@code null}
     * @return the assembled {@link SparseVector} or {@link DenseVector}; {@code null} when
     *     {@code objArr} is {@code null} or a {@code null} value is skipped
     */
    private static Object assembler(Object[] objArr, HasHandleInvalid.HandleInvalidMethod handleInvalidMethod) {
        if (null == objArr) {
            return null;
        }
        Map<Integer, Double> entries = new HashMap<>(estimateCapacity(objArr));
        // Total length of the assembled vector so far; also the offset of the next value.
        int size = 0;
        for (Object value : objArr) {
            if (null == value) {
                switch (handleInvalidMethod) {
                    case ERROR:
                        throw new AkIllegalDataException("null value is found in vector assembler inputs.");
                    case SKIP:
                        return null;
                    default:
                        // Previously a null fell through to the type check below and was
                        // misreported as an unsupported type; report the real cause instead.
                        throw new AkUnsupportedOperationException(
                            "null value is found in vector assembler inputs, and handleInvalid method "
                                + handleInvalidMethod + " is not supported.");
                }
            }
            if (value instanceof Number) {
                entries.put(size++, ((Number) value).doubleValue());
            } else if (value instanceof String) {
                size = appendVector(VectorUtil.getVector(value), entries, size);
            } else if (value instanceof Vector) {
                size = appendVector((Vector) value, entries, size);
            } else {
                throw new AkUnsupportedOperationException("only support number, string and vector, other types will cause exception");
            }
        }
        SparseVector assembled = new SparseVector(size, entries);
        // Switch to dense storage once more than 1/RATIO of the slots hold an entry.
        return entries.size() * RATIO > size ? assembled.toDenseVector() : assembled;
    }

    /**
     * Returns an initial capacity for the entry map: one slot per input value plus the stored
     * entries of every dense/sparse vector input, scaled so the map does not resize at HashMap's
     * default 0.75 load factor. String inputs are not parsed here, so this is only a hint.
     */
    private static int estimateCapacity(Object[] values) {
        int expectedEntries = values.length;
        for (Object value : values) {
            if (value instanceof DenseVector) {
                expectedEntries += ((DenseVector) value).size();
            } else if (value instanceof SparseVector) {
                expectedEntries += ((SparseVector) value).getIndices().length;
            }
        }
        return (int) (expectedEntries / 0.75f) + 1;
    }

    /**
     * Copies the entries of {@code vector} into {@code map}, shifted by {@code offset}.
     *
     * @param vector vector to append; vectors that are neither sparse nor dense contribute nothing
     * @param map    destination of (position, value) entries
     * @param offset position at which the vector starts
     * @return the offset just past the appended vector
     * @throws AkIllegalDataException if a sparse vector has a non-positive size
     */
    private static int appendVector(Vector vector, Map<Integer, Double> map, int offset) {
        if (vector instanceof SparseVector) {
            SparseVector sparse = (SparseVector) vector;
            if (sparse.size() <= 0) {
                throw new AkIllegalDataException("The append sparse vector must have size.");
            }
            int[] indices = sparse.getIndices();
            double[] values = sparse.getValues();
            for (int k = 0; k < indices.length; k++) {
                map.put(offset + indices[k], values[k]);
            }
            return offset + sparse.size();
        }
        if (vector instanceof DenseVector) {
            DenseVector dense = (DenseVector) vector;
            int n = dense.size();
            for (int k = 0; k < n; k++) {
                map.put(offset + k, dense.get(k));
            }
            return offset + n;
        }
        return offset;
    }
}
