package com.alibaba.alink.operator.common.optim.local;

import com.alibaba.alink.common.linalg.DenseMatrix;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.operator.common.linear.BaseConstrainedLinearModelTrainBatchOp;
import com.alibaba.alink.operator.common.linear.LinearModelType;
import com.alibaba.alink.operator.common.optim.FeatureConstraint;
import com.alibaba.alink.operator.common.optim.activeSet.ConstraintObjFunc;
import com.alibaba.alink.operator.common.optim.activeSet.Sqp;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.params.feature.HasConstraint;
import com.alibaba.alink.params.shared.linear.LinearTrainParams;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.ml.api.misc.param.Params;

/* loaded from: input_file:com/alibaba/alink/operator/common/optim/local/ConstrainedLocalOptimizer.class */
/**
 * Helpers for running the constrained linear-model optimizer on in-memory data.
 *
 * <p>The class wires a {@link FeatureConstraint} into a {@link ConstraintObjFunc}, asks
 * {@link Sqp#phaseOne} for a starting point and delegates the actual optimization to
 * {@code LocalSqp.sqpWithHessian}. It also offers a simple sign-based {@link #predict}.
 *
 * <p>NOTE(review): the samples are {@code Tuple3<Double, Double, Vector>}; presumably
 * (label, weight, feature vector) — confirm against the callers.
 */
public class ConstrainedLocalOptimizer {

    /**
     * Builds the constrained objective for {@code linearModelType} and optimizes it locally.
     *
     * <p>The constraint definition is read as JSON from {@link HasConstraint#CONSTRAINT}; when the
     * parameter is absent an empty string is parsed instead. Constraints are always extracted with
     * the intercept flag set, using the size of the first sample's vector as the dimension.
     *
     * @param list            training samples; must contain at least one element, whose vector
     *                        size is taken as the problem dimension
     * @param linearModelType model type forwarded to
     *                        {@code BaseConstrainedLinearModelTrainBatchOp.getObjFunction}
     * @param params          training parameters, forwarded to {@code LocalSqp.sqpWithHessian}
     * @return the tuple produced by {@code LocalSqp.sqpWithHessian}
     * @throws IllegalArgumentException if {@code list} is {@code null} or empty
     */
    public static Tuple4<DenseVector, DenseVector, DenseMatrix, Double> optimizeWithHessian(List<Tuple3<Double, Double, Vector>> list, LinearModelType linearModelType, Params params) {
        // Fail with a descriptive message instead of a bare IndexOutOfBoundsException/NPE.
        if (list == null || list.isEmpty()) {
            throw new IllegalArgumentException("Training data must contain at least one sample.");
        }
        int size = list.get(0).f2.size();
        ConstraintObjFunc constraintObjFunc = (ConstraintObjFunc) BaseConstrainedLinearModelTrainBatchOp.getObjFunction(linearModelType, null);

        String constraintJson = params.contains(HasConstraint.CONSTRAINT) ? (String) params.get(HasConstraint.CONSTRAINT) : "";
        // No feature names, no bin counts and no "has constraint" map are available locally.
        // NOTE(review): if the JSON declares bin constraints, setCountZero(null) is invoked below;
        // confirm that FeatureConstraint tolerates a null vector.
        extractConstraintsForFeatureAndBin(FeatureConstraint.fromJson(constraintJson), constraintObjFunc, null, size, true, null, null);

        // Phase one derives the starting point from the constraint system only.
        DenseVector startPoint = new DenseVector(Sqp.phaseOne(
            constraintObjFunc.equalityConstraint.getArrayCopy2D(),
            constraintObjFunc.equalityItem.getData(),
            constraintObjFunc.inequalityConstraint.getArrayCopy2D(),
            constraintObjFunc.inequalityItem.getData(),
            size));
        return LocalSqp.sqpWithHessian(list, startPoint, constraintObjFunc, params);
    }

    /**
     * Converts {@code featureConstraint} into matrices/vectors and stores them on
     * {@code constraintObjFunc} ({@code inequalityConstraint}, {@code inequalityItem},
     * {@code equalityConstraint}, {@code equalityItem}).
     *
     * <p>The source of the constraints is chosen in this order:
     * <ol>
     *   <li>bin constraints present: {@code setCountZero(denseVector)} followed by
     *       {@code getConstraintsForFeatureWithBin()};</li>
     *   <li>feature names given: constraints resolved through a name-to-index map built from
     *       {@code strArr};</li>
     *   <li>otherwise: constraints resolved by dimension, which is {@code i - 1} when {@code z}
     *       is set and {@code i} when it is not.</li>
     * </ol>
     * When {@code z} is set, both constraint matrices additionally receive a leading zero column.
     *
     * @param featureConstraint parsed constraint definition
     * @param constraintObjFunc objective function whose constraint fields are overwritten
     * @param strArr            feature names, or {@code null} to address features by position
     * @param i                 vector dimension; only used when there are no bin constraints and
     *                          {@code strArr} is {@code null}
     * @param z                 whether the model has an intercept term
     * @param denseVector       passed to {@code setCountZero}; only used with bin constraints
     * @param map               not used by this method
     */
    public static void extractConstraintsForFeatureAndBin(FeatureConstraint featureConstraint, ConstraintObjFunc constraintObjFunc, String[] strArr, int i, boolean z, DenseVector denseVector, Map<String, Boolean> map) {
        Tuple4<double[][], double[], double[][], double[]> constraints;
        if (featureConstraint.getBinConstraintSize() != 0) {
            featureConstraint.setCountZero(denseVector);
            constraints = featureConstraint.getConstraintsForFeatureWithBin();
        } else if (strArr != null) {
            // Declared as HashMap on purpose: that is the type handed to getConstraintsForFeatures.
            HashMap<String, Integer> indexByName = new HashMap<>(strArr.length);
            for (int position = 0; position < strArr.length; position++) {
                indexByName.put(strArr[position], position);
            }
            constraints = featureConstraint.getConstraintsForFeatures(indexByName);
        } else {
            // With an intercept the dimension handed on is one less than the vector size.
            int featureCount = z ? i - 1 : i;
            constraints = featureConstraint.getConstraintsForFeatures(featureCount);
        }

        if (z) {
            addIntercept(constraints);
        }

        constraintObjFunc.inequalityConstraint = new DenseMatrix(constraints.f0);
        constraintObjFunc.inequalityItem = new DenseVector(constraints.f1);
        constraintObjFunc.equalityConstraint = new DenseMatrix(constraints.f2);
        constraintObjFunc.equalityItem = new DenseVector(constraints.f3);
    }

    /**
     * Prepends a zero column to the inequality ({@code f0}) and equality ({@code f2}) matrices of
     * {@code tuple4}. The right-hand sides ({@code f1}, {@code f3}) are left untouched.
     */
    private static void addIntercept(Tuple4<double[][], double[], double[][], double[]> tuple4) {
        tuple4.f0 = prefixMatrix(tuple4.f0);
        tuple4.f2 = prefixMatrix(tuple4.f2);
    }

    /**
     * Replaces every row of {@code dArr} by a copy that is one element longer and starts with
     * {@code 0.0}. The outer array is modified in place and returned; an empty matrix is returned
     * unchanged.
     *
     * <p>The width of the first row is used for every row, i.e. the matrix is treated as
     * rectangular.
     */
    private static double[][] prefixMatrix(double[][] dArr) {
        if (dArr.length == 0) {
            return dArr;
        }
        int width = dArr[0].length;
        for (int row = 0; row < dArr.length; row++) {
            double[] widened = new double[width + 1];
            // Slot 0 keeps its default 0.0; the original values move one position to the right.
            System.arraycopy(dArr[row], 0, widened, 1, width);
            dArr[row] = widened;
        }
        return dArr;
    }

    /**
     * Prepends a constant {@code 1.0} to the vector of every sample when an intercept is requested.
     * The samples are modified in place.
     *
     * @param list samples whose vectors are replaced
     * @param z    whether to prepend the intercept entry; when {@code false} nothing is done
     * @param z2   has no influence on the result; kept for signature compatibility
     */
    @Deprecated
    public static void preProcess(List<Tuple3<Double, Double, Vector>> list, boolean z, boolean z2) {
        if (!z) {
            return;
        }
        for (Tuple3<Double, Double, Vector> sample : list) {
            sample.f2 = sample.f2.prefix(1.0d);
        }
    }

    /**
     * Scores every sample with {@code denseVector} and writes the prediction into field
     * {@code f0} of the sample, overwriting its previous value.
     *
     * <p>The samples are first passed through {@link #preProcess} using
     * {@link LinearTrainParams#WITH_INTERCEPT} and {@link LinearTrainParams#STANDARDIZATION}.
     * For {@link LinearModelType#LinearReg} the prediction is the dot product itself; for every
     * other model type it is {@code 1.0} when the dot product exceeds
     * {@code Criteria.INVALID_GAIN} and {@code Criteria.INVALID_GAIN} otherwise.
     *
     * @param list            samples to score; modified in place
     * @param linearModelType model type that selects the prediction rule
     * @param denseVector     coefficient vector used in the dot product
     * @param params          parameters providing the intercept and standardization flags
     * @return the same {@code list} instance
     */
    public List<Tuple3<Double, Double, Vector>> predict(List<Tuple3<Double, Double, Vector>> list, LinearModelType linearModelType, DenseVector denseVector, Params params) {
        boolean withIntercept = (Boolean) params.get(LinearTrainParams.WITH_INTERCEPT);
        boolean standardization = (Boolean) params.get(LinearTrainParams.STANDARDIZATION);
        preProcess(list, withIntercept, standardization);

        if (linearModelType.equals(LinearModelType.LinearReg)) {
            for (Tuple3<Double, Double, Vector> sample : list) {
                sample.f0 = Double.valueOf(sample.f2.dot(denseVector));
            }
            return list;
        }

        // NOTE(review): the threshold and the negative label both reuse Criteria.INVALID_GAIN from
        // the tree package; presumably 0.0 — confirm and consider a dedicated local constant.
        for (Tuple3<Double, Double, Vector> sample : list) {
            sample.f0 = Double.valueOf(sample.f2.dot(denseVector) > Criteria.INVALID_GAIN ? 1.0d : Criteria.INVALID_GAIN);
        }
        return list;
    }
}
