package com.alibaba.alink.operator.common.dataproc.vector;

import com.alibaba.alink.common.exceptions.AkIllegalDataException;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.mapper.SISOMapper;
import com.alibaba.alink.common.type.AlinkTypes;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.params.dataproc.vector.VectorPolynomialExpandParams;
import org.apache.commons.math3.util.ArithmeticUtils;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;

/* loaded from: input_file:com/alibaba/alink/operator/common/dataproc/vector/PolynomialExpansionMapper.class */
/**
 * Single-input/single-output mapper that expands a vector into its polynomial terms up to the
 * degree configured through {@link VectorPolynomialExpandParams#DEGREE}.
 *
 * <p>Dense inputs produce a {@link DenseVector}; sparse inputs produce a {@link SparseVector}.
 * The constant term is never emitted, so the output size is {@code getPolySize(size, degree) - 1}.
 */
public class PolynomialExpansionMapper extends SISOMapper {
    private static final long serialVersionUID = -706089902874084729L;

    /** Polynomial degree read from the mapper parameters at construction time. */
    private final int degree;

    /**
     * When {@code num + degree} is below this bound, {@link #getPolySize(int, int)} uses plain
     * {@code long} arithmetic; otherwise it switches to overflow-checked arithmetic.
     * NOTE(review): presumably chosen so that the unchecked intermediate products always fit in a
     * {@code long} - confirm before changing.
     */
    private static final int CONSTANT = 61;

    /**
     * Creates the mapper and caches the configured degree.
     *
     * @param tableSchema schema of the input table, forwarded to {@link SISOMapper}
     * @param params      mapper parameters; must contain {@link VectorPolynomialExpandParams#DEGREE}
     */
    public PolynomialExpansionMapper(TableSchema tableSchema, Params params) {
        super(tableSchema, params);
        this.degree = ((Integer) this.params.get(VectorPolynomialExpandParams.DEGREE)).intValue();
    }

    /**
     * Computes the binomial coefficient {@code C(num + degree, degree)}, i.e. the number of
     * polynomial terms (including the constant term) of {@code num} variables up to {@code degree}.
     *
     * @param num    number of variables
     * @param degree polynomial degree
     * @return the number of terms as an {@code int}
     * @throws AkIllegalDataException if the result does not fit in an {@code int}
     */
    static int getPolySize(int num, int degree) {
        if (num == 0) {
            return 1;
        }
        if (num == 1 || degree == 1) {
            return num + degree;
        }
        if (degree > num) {
            // C(n + d, d) is symmetric in n and d; iterate over the smaller of the two.
            return getPolySize(degree, num);
        }
        long result = 1;
        int numerator = num + 1;
        if (num + degree < CONSTANT) {
            // result * numerator is always divisible by k here: the running value is C(num + k, k).
            for (int k = 1; k <= degree; k++) {
                result = (result * numerator) / k;
                numerator++;
            }
        } else {
            // Reduce by the gcd first and let mulAndCheck signal a long overflow.
            for (int k = 1; k <= degree; k++) {
                int gcd = ArithmeticUtils.gcd(numerator, k);
                result = ArithmeticUtils.mulAndCheck(result / (k / gcd), numerator / gcd);
                numerator++;
            }
        }
        if (result > Integer.MAX_VALUE) {
            throw new AkIllegalDataException("The expended polynomial size is too large.");
        }
        return (int) result;
    }

    /** Declares that the output column holds a vector. */
    @Override // com.alibaba.alink.common.mapper.SISOMapper
    protected TypeInformation initOutputColType() {
        return AlinkTypes.VECTOR;
    }

    /**
     * Expands one input value.
     *
     * @param obj the raw column value, converted with {@link VectorUtil#getVector(Object)}
     * @return the expanded vector, or {@code null} when the input converts to {@code null}
     */
    @Override // com.alibaba.alink.common.mapper.SISOMapper
    protected Object mapColumn(Object obj) {
        Vector vector = VectorUtil.getVector(obj);
        if (null == vector) {
            return null;
        }
        if (vector instanceof SparseVector) {
            return sparsePE((SparseVector) vector, this.degree);
        }
        return densePE((DenseVector) vector, this.degree);
    }

    /**
     * Expands a dense vector into a dense vector of {@code getPolySize(size, degree) - 1} entries.
     */
    private DenseVector densePE(DenseVector denseVector, int degree) {
        int size = denseVector.size();
        double[] polyValues = new double[getPolySize(size, degree) - 1];
        // Start at position -1 so that the constant term is skipped (positions < 0 are not written).
        expandDense(denseVector.getData(), size - 1, degree, 1.0d, polyValues, -1);
        return new DenseVector(polyValues);
    }

    /**
     * Expands a sparse vector. Only the stored entries are combined; the buffers are sized for the
     * case where every term is emitted and are trimmed afterwards to the number of terms actually
     * written, so that unused buffer slots never leak into the result as bogus
     * {@code (index 0, value 0.0)} entries.
     */
    private SparseVector sparsePE(SparseVector sparseVector, int degree) {
        int[] indices = sparseVector.getIndices();
        double[] values = sparseVector.getValues();
        int size = sparseVector.size();
        int nnz = values.length;
        int capacity = getPolySize(nnz, degree) - 1;

        // f0 is the number of entries written so far, f1 the backing buffer.
        Tuple2<Integer, int[]> polyIndices = Tuple2.of(0, new int[capacity]);
        Tuple2<Integer, double[]> polyValues = Tuple2.of(0, new double[capacity]);
        expandSparse(indices, values, nnz - 1, size - 1, degree, 1.0d, polyIndices, polyValues, -1);

        int written = polyIndices.f0;
        int[] outIndices = polyIndices.f1;
        double[] outValues = polyValues.f1;
        if (written < capacity) {
            // Some terms were skipped by expandSparse; drop the untouched tail.
            outIndices = new int[written];
            outValues = new double[written];
            System.arraycopy(polyIndices.f1, 0, outIndices, 0, written);
            System.arraycopy(polyValues.f1, 0, outValues, 0, written);
        }
        return new SparseVector(getPolySize(size, degree) - 1, outIndices, outValues);
    }

    /**
     * Recursively writes the terms built from {@code values[0..lastIdx]} with total degree at most
     * {@code degree}, each multiplied by {@code factor}, into {@code polyValues}.
     *
     * <p>NOTE(review): {@code Criteria.INVALID_GAIN} is presumably {@code 0.0}, making the guards
     * below "skip zero factors" - confirm against {@code Criteria}.
     *
     * @param values     input values
     * @param lastIdx    index of the last input value still available for this sub-expansion
     * @param degree     remaining degree budget
     * @param factor     product accumulated by the enclosing recursion levels
     * @param polyValues output buffer
     * @param curPolyIdx output position of this sub-expansion's first term; negative positions
     *                   are not written
     * @return the output position immediately after this sub-expansion's block of terms
     */
    private int expandDense(double[] values, int lastIdx, int degree, double factor, double[] polyValues,
                            int curPolyIdx) {
        // Double.compare(a, b) == 0 matches Double.equals semantics (bitwise, so -0.0 != 0.0).
        if (Double.compare(factor, Criteria.INVALID_GAIN) != 0) {
            if (degree != 0 && lastIdx >= 0) {
                double v = values[lastIdx];
                int nextLastIdx = lastIdx - 1;
                int cursor = curPolyIdx;
                double alpha = factor;
                // alpha walks through factor * v^0, factor * v^1, ... until the degree budget is
                // used up or the product no longer exceeds INVALID_GAIN in magnitude.
                for (int power = 0; power <= degree && Math.abs(alpha) > Criteria.INVALID_GAIN; power++) {
                    cursor = expandDense(values, nextLastIdx, degree - power, alpha, polyValues, cursor);
                    alpha *= v;
                }
            } else if (curPolyIdx >= 0) {
                polyValues[curPolyIdx] = factor;
            }
        }
        // The block size is fixed by (lastIdx + 1, degree) regardless of how many terms were written.
        return curPolyIdx + getPolySize(lastIdx + 1, degree);
    }

    /**
     * Sparse counterpart of {@link #expandDense}: walks only the stored entries but tracks output
     * positions as if the vector were dense, appending each produced term to the buffers.
     *
     * @param indices     stored indices of the input vector
     * @param values      stored values of the input vector
     * @param lastIdx     position (into {@code indices}/{@code values}) of the last stored entry
     *                    still available
     * @param lastFeatureIdx index of the last input dimension covered by this sub-expansion
     * @param degree      remaining degree budget
     * @param factor      product accumulated by the enclosing recursion levels
     * @param polyIndices pair of (entries written so far, output index buffer)
     * @param polyValues  pair of (entries written so far, output value buffer)
     * @param curPolyIdx  output position of this sub-expansion's first term; negative positions
     *                    are not emitted
     * @return the output position immediately after this sub-expansion's block of terms
     */
    private int expandSparse(int[] indices, double[] values, int lastIdx, int lastFeatureIdx, int degree,
                             double factor, Tuple2<Integer, int[]> polyIndices,
                             Tuple2<Integer, double[]> polyValues, int curPolyIdx) {
        // Double.compare(a, b) == 0 matches Double.equals semantics (bitwise, so -0.0 != 0.0).
        if (Double.compare(factor, Criteria.INVALID_GAIN) != 0) {
            if (degree != 0 && lastIdx >= 0) {
                double v = values[lastIdx];
                int nextLastIdx = lastIdx - 1;
                int nextLastFeatureIdx = indices[lastIdx] - 1;
                int cursor = curPolyIdx;
                double alpha = factor;
                for (int power = 0; power <= degree && Math.abs(alpha) > Criteria.INVALID_GAIN; power++) {
                    cursor = expandSparse(indices, values, nextLastIdx, nextLastFeatureIdx, degree - power,
                        alpha, polyIndices, polyValues, cursor);
                    alpha *= v;
                }
            } else if (curPolyIdx >= 0) {
                // Append (position, value) and advance both write counters in lockstep.
                polyIndices.f1[polyIndices.f0] = curPolyIdx;
                polyValues.f1[polyValues.f0] = factor;
                polyIndices.f0 = polyIndices.f0 + 1;
                polyValues.f0 = polyValues.f0 + 1;
            }
        }
        return curPolyIdx + getPolySize(lastFeatureIdx + 1, degree);
    }
}
