package com.alibaba.alink.operator.common.outlier;

import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.mapper.Mapper;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.common.outlier.OcsvmModelData;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.params.outlier.HaskernelType;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

/* loaded from: input_file:com/alibaba/alink/operator/common/outlier/OcsvmModelDetector.class */
/**
 * Outlier detector that scores each sample against every SVM sub-model stored in an
 * {@link OcsvmModelData}.
 *
 * <p>Input features come either from a single vector column (when the model records a non-empty
 * {@code vectorCol}) or from the individual numeric columns listed in the model's
 * {@code featureColNames}. When both are present, the vector column takes precedence.
 *
 * <p>NOTE(review): {@code localX} is a reusable scratch buffer written on every call to
 * {@link #predictSingle}. A single instance therefore must not be used by several threads at once.
 */
public class OcsvmModelDetector extends ModelOutlierDetector {
    private static final long serialVersionUID = 6504098446269455446L;

    /** Indices (within the selected columns) of the scalar feature columns; null when the model has none. */
    private int[] featureIdx;
    /** Index (within the selected columns) of the vector column, or -1 when the model has none. */
    private int vectorIndex;
    private OcsvmModelData modelData;
    /** Scratch vector refilled from the scalar feature columns for every prediction. */
    private DenseVector localX;
    // Kernel parameters copied from the model data and passed to OcsvmKernel.svmPredict.
    private double gamma;
    private double coef0;
    private int degree;
    private HaskernelType.KernelType kernelType;

    /**
     * Creates the detector. All arguments are forwarded unchanged to {@link ModelOutlierDetector}.
     */
    public OcsvmModelDetector(TableSchema tableSchema, TableSchema tableSchema2, Params params) {
        super(tableSchema, tableSchema2, params);
        this.vectorIndex = -1;
    }

    /**
     * Decodes the model rows and binds the model's input columns to the selected columns.
     *
     * @param list serialized model rows, decoded with {@link OcsvmModelDataConverter}
     */
    @Override
    public void loadModel(List<Row> list) {
        this.modelData = new OcsvmModelDataConverter().load(list);
        this.gamma = this.modelData.gamma;
        this.coef0 = this.modelData.coef0;
        this.degree = this.modelData.degree;
        this.kernelType = this.modelData.kernelType;

        // Drop bindings from any previously loaded model so a reload cannot keep stale column indices.
        this.featureIdx = null;
        this.localX = null;
        this.vectorIndex = -1;

        if (this.modelData.featureColNames != null) {
            this.featureIdx = TableUtil.findColIndicesWithAssertAndHint(getSelectedCols(), this.modelData.featureColNames);
            this.localX = new DenseVector(this.featureIdx.length);
        }
        String vectorCol = this.modelData.vectorCol;
        if (vectorCol != null && !vectorCol.isEmpty()) {
            this.vectorIndex = TableUtil.findColIndexWithAssertAndHint(getSelectedCols(), vectorCol);
        }
    }

    /**
     * Scores one sample.
     *
     * <p>The score is the negated sum of the values returned by {@link #predictSingle} for every
     * sub-model. The sample is flagged as an outlier when the score is at least
     * {@code Criteria.INVALID_GAIN}.
     *
     * @param sample the selected columns of the current row
     * @return (isOutlier, score, details), where details maps
     *     {@link OutlierDetector#OUTLIER_SCORE_KEY} to the score's string form
     */
    @Override
    protected Tuple3<Boolean, Double, Map<String, String>> detect(Mapper.SlicedSelectedSample sample) {
        double score = 0.0d;
        for (OcsvmModelData.SvmModelData subModel : this.modelData.models) {
            score -= predictSingle(sample, subModel);
        }
        // NOTE(review): this class is decompiled, and the tree-package constant here is most likely the
        // decompiler substituting a same-valued literal. Confirm its value and replace it with that literal.
        boolean isOutlier = score >= Criteria.INVALID_GAIN;
        Map<String, String> details = new HashMap<>();
        details.put(OutlierDetector.OUTLIER_SCORE_KEY, String.valueOf(score));
        return Tuple3.of(isOutlier, score, details);
    }

    /**
     * Evaluates a single SVM sub-model on a sample using the kernel parameters from the loaded model.
     *
     * @param slicedSelectedSample the selected columns of the current row
     * @param svmModelData the sub-model to evaluate
     * @return the value returned by {@code OcsvmKernel.svmPredict}
     * @throws IllegalStateException if the loaded model defines neither a vector column nor feature columns
     */
    public double predictSingle(Mapper.SlicedSelectedSample slicedSelectedSample, OcsvmModelData.SvmModelData svmModelData) {
        if (this.vectorIndex != -1) {
            return OcsvmKernel.svmPredict(svmModelData, VectorUtil.getVector(slicedSelectedSample.get(this.vectorIndex)),
                this.kernelType, this.gamma, this.coef0, this.degree);
        }
        if (this.featureIdx == null) {
            throw new IllegalStateException("OCSVM model defines neither a vector column nor feature columns.");
        }
        // Copy the scalar feature columns into the reusable scratch vector.
        for (int i = 0; i < this.featureIdx.length; i++) {
            this.localX.set(i, ((Number) slicedSelectedSample.get(this.featureIdx[i])).doubleValue());
        }
        return OcsvmKernel.svmPredict(svmModelData, this.localX, this.kernelType, this.gamma, this.coef0, this.degree);
    }
}
