package com.alibaba.alink.operator.common.outlier;

import com.alibaba.alink.common.MTable;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.operator.common.outlier.OcsvmModelData;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.params.outlier.HaskernelType;
import com.alibaba.alink.params.outlier.OcsvmDetectorParams;
import java.util.Map;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

/* loaded from: input_file:com/alibaba/alink/operator/common/outlier/OcsvmDetector.class */
/**
 * Outlier detector based on a one-class SVM (OCSVM).
 *
 * <p>Each call to {@link #detect} fits a model on the selected columns of the given
 * {@link MTable} and then scores rows of that same table with it.
 */
public class OcsvmDetector extends OutlierDetector {

    /**
     * Threshold on the score returned by {@link OcsvmPredict#predict}: rows scoring strictly
     * above it are reported as outliers by {@link #detect}.
     *
     * <p>NOTE(review): the decompiled source referenced {@code Criteria.INVALID_GAIN} (a
     * decision-tree constant) here, which looks like a decompiler substituting a same-valued
     * literal, presumably {@code 0.0}. Kept as-is to preserve behavior; confirm and replace.
     */
    private static final double OUTLIER_SCORE_THRESHOLD = Criteria.INVALID_GAIN;

    /**
     * Scores rows with a trained {@link OcsvmModelData}. {@link #loadModel} must be called
     * before {@link #predict}.
     */
    public static final class OcsvmPredict implements OcsvmDetectorParams<OcsvmPredict> {
        private final Params params;
        private transient OcsvmModelData ocsvmModel;
        private String vectorColName;
        private double gamma;
        private double coef0;
        private int degree;
        private HaskernelType.KernelType kernelType;

        public OcsvmPredict() {
            this(new Params());
        }

        /**
         * @param params parameters for this predictor; {@code null} is replaced by an empty
         *     {@link Params}
         */
        public OcsvmPredict(Params params) {
            this.params = params == null ? new Params() : params;
        }

        /**
         * Installs a trained model and caches its vector column and kernel settings for use by
         * {@link #predict}.
         *
         * @param ocsvmModelData model, e.g. as produced by {@link OcsvmTrain#train}
         */
        public void loadModel(OcsvmModelData ocsvmModelData) {
            this.ocsvmModel = ocsvmModelData;
            this.vectorColName = ocsvmModelData.vectorCol;
            this.gamma = ocsvmModelData.gamma;
            this.kernelType = ocsvmModelData.kernelType;
            this.degree = ocsvmModelData.degree;
            this.coef0 = ocsvmModelData.coef0;
        }

        /**
         * Computes the outlier score of a single row.
         *
         * <p>If the loaded model names a vector column, field 0 of {@code row} is parsed with
         * {@link VectorUtil#getVector}. Otherwise every field of {@code row} is read as a
         * {@link Number} into a dense vector. The score is the negated sum of
         * {@link OcsvmKernel#svmPredict} over all sub-models of the loaded model.
         *
         * @param row either a single vector field, or only numeric fields
         * @return the outlier score; {@link OcsvmDetector#detect} flags rows whose score exceeds
         *     its threshold
         * @throws IllegalStateException if no model has been loaded
         */
        public double predict(Row row) {
            if (this.ocsvmModel == null) {
                throw new IllegalStateException("OcsvmPredict.predict called before loadModel");
            }
            // Held as the general Vector type: the vector-column branch yields whatever
            // VectorUtil.getVector returns, which the decompiled DenseVector local could not hold.
            Vector sample;
            if (this.vectorColName != null) {
                sample = VectorUtil.getVector(row.getField(0));
            } else {
                int arity = row.getArity();
                DenseVector dense = new DenseVector(arity);
                for (int i = 0; i < arity; i++) {
                    dense.set(i, ((Number) row.getField(i)).doubleValue());
                }
                sample = dense;
            }
            double score = 0.0d;
            for (OcsvmModelData.SvmModelData subModel : this.ocsvmModel.models) {
                score -= OcsvmKernel.svmPredict(
                    subModel, sample, this.kernelType, this.gamma, this.coef0, this.degree);
            }
            return score;
        }

        @Override // org.apache.flink.ml.api.misc.param.WithParams
        public Params getParams() {
            return this.params;
        }
    }

    /** Fits an {@link OcsvmModelData} from the rows of an {@link MTable}. */
    public static final class OcsvmTrain implements OcsvmDetectorParams<OcsvmTrain> {
        private final Params params;

        public OcsvmTrain() {
            this(new Params());
        }

        /**
         * @param params training parameters; {@code null} is replaced by an empty {@link Params}
         */
        public OcsvmTrain(Params params) {
            this.params = params == null ? new Params() : params;
        }

        /**
         * Trains a model with a single SVM sub-model on the vectors extracted from
         * {@code mTable}.
         *
         * <p>If the configured GAMMA is below {@code 1e-18}, {@code 1 / vectors[0].size()} is
         * used instead. That substitution is made on a private copy of the parameters, so this
         * trainer's {@link Params} (possibly shared with the caller) is left untouched.
         *
         * @param mTable training rows
         * @return the trained model, carrying the feature/vector columns and kernel settings
         * @throws IllegalArgumentException if {@code mTable} yields no vectors
         */
        public OcsvmModelData train(MTable mTable) {
            Vector[] vectors = OutlierUtil.getVectors(mTable, this.params);
            if (vectors.length == 0) {
                throw new IllegalArgumentException("OCSVM training requires at least one row.");
            }
            // Resolve the automatic gamma on a copy: writing it into this.params would make the
            // first resolved value stick for every later training run sharing the same Params.
            Params trainParams = this.params.clone();
            if (trainParams.get(OcsvmDetectorParams.GAMMA) < 1.0E-18d) {
                // NOTE(review): assumes vectors[0].size() > 0 is the feature dimension; confirm
                // for sparse input.
                trainParams.set(OcsvmDetectorParams.GAMMA, 1.0d / vectors[0].size());
            }
            OcsvmModelData ocsvmModelData = new OcsvmModelData();
            ocsvmModelData.featureColNames = getFeatureCols();
            ocsvmModelData.kernelType = getKernelType();
            ocsvmModelData.coef0 = getCoef0().doubleValue();
            ocsvmModelData.degree = getDegree().intValue();
            ocsvmModelData.gamma = trainParams.get(OcsvmDetectorParams.GAMMA);
            ocsvmModelData.vectorCol = getVectorCol();
            ocsvmModelData.models =
                new OcsvmModelData.SvmModelData[] {OcsvmKernel.svmTrain(vectors, trainParams)};
            return ocsvmModelData;
        }

        @Override // org.apache.flink.ml.api.misc.param.WithParams
        public Params getParams() {
            return this.params;
        }
    }

    public OcsvmDetector(TableSchema tableSchema, Params params) {
        super(tableSchema, params);
    }

    /**
     * Fits an OCSVM on all rows of {@code mTable} and scores them.
     *
     * <p>Only the vector column (if VECTOR_COL is set) or the FEATURE_COLS columns are used. A
     * row is flagged as an outlier when its score is strictly greater than
     * {@link #OUTLIER_SCORE_THRESHOLD}. The third tuple field is always {@code null}.
     *
     * @param mTable input rows
     * @param detectLast when {@code true}, only the last row is scored (the model is still
     *     fitted on all rows); otherwise every row is scored
     * @return one {@code (isOutlier, score, null)} tuple per scored row, in row order; empty if
     *     {@code mTable} has no rows
     */
    @Override // com.alibaba.alink.operator.common.outlier.OutlierDetector
    protected Tuple3<Boolean, Double, Map<String, String>>[] detect(MTable mTable, boolean detectLast) {
        String vectorCol = this.params.get(OcsvmDetectorParams.VECTOR_COL);
        MTable selected = vectorCol != null
            ? mTable.select(vectorCol)
            : mTable.select(this.params.get(OcsvmDetectorParams.FEATURE_COLS));
        int numRows = selected.getNumRow();
        if (numRows == 0) {
            // Nothing to fit or score.
            @SuppressWarnings("unchecked")
            Tuple3<Boolean, Double, Map<String, String>>[] empty = new Tuple3[0];
            return empty;
        }

        OcsvmTrain trainer = new OcsvmTrain(this.params);
        OcsvmPredict predictor = new OcsvmPredict(this.params);
        predictor.loadModel(trainer.train(selected));

        int start = detectLast ? numRows - 1 : 0;
        @SuppressWarnings("unchecked")
        Tuple3<Boolean, Double, Map<String, String>>[] results = new Tuple3[numRows - start];
        for (int i = start; i < numRows; i++) {
            double score = predictor.predict(selected.getRow(i));
            results[i - start] =
                Tuple3.of(score > OUTLIER_SCORE_THRESHOLD, score, (Map<String, String>) null);
        }
        return results;
    }
}
