package com.alibaba.alink.operator.common.similarity;

import com.alibaba.alink.common.MLEnvironmentFactory;
import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.statistics.utils.StatisticsHelper;
import com.alibaba.alink.operator.common.similarity.lsh.BaseLSH;
import com.alibaba.alink.operator.common.similarity.lsh.BucketRandomProjectionLSH;
import com.alibaba.alink.operator.common.similarity.lsh.MinHashLSH;
import com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary;
import com.alibaba.alink.params.shared.HasMLEnvironmentId;
import com.alibaba.alink.params.similarity.VectorApproxNearestNeighborTrainParams;
import java.util.ArrayList;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.operators.DataSource;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.util.Collector;

/* loaded from: input_file:com/alibaba/alink/operator/common/similarity/LocalitySensitiveHashApproxFunctions.class */
/**
 * Static helpers that create the {@link BaseLSH} implementation selected by
 * {@link VectorApproxNearestNeighborTrainParams#METRIC}.
 *
 * <p>Two metrics are handled: {@code JACCARD} yields a {@link MinHashLSH} and {@code EUCLIDEAN}
 * yields a {@link BucketRandomProjectionLSH}. Any other metric is rejected with an exception.
 */
public class LocalitySensitiveHashApproxFunctions {
    /**
     * Builds the LSH instance for the configured metric when the vector size is already known.
     *
     * @param params training parameters; {@code METRIC}, {@code SEED},
     *               {@code NUM_PROJECTIONS_PER_TABLE} and {@code NUM_HASH_TABLES} are always read,
     *               {@code PROJECTION_WIDTH} only for the {@code EUCLIDEAN} metric
     * @param i      vector size handed to {@link BucketRandomProjectionLSH}; not used for
     *               {@code JACCARD}
     * @return the LSH instance matching the metric
     * @throws AkUnsupportedOperationException if the metric is neither {@code JACCARD} nor
     *                                         {@code EUCLIDEAN}
     */
    public static BaseLSH buildLSH(Params params, int i) {
        VectorApproxNearestNeighborTrainParams.Metric metric = params.get(VectorApproxNearestNeighborTrainParams.METRIC);
        long seed = params.get(VectorApproxNearestNeighborTrainParams.SEED);
        int numProjectionsPerTable = params.get(VectorApproxNearestNeighborTrainParams.NUM_PROJECTIONS_PER_TABLE);
        int numHashTables = params.get(VectorApproxNearestNeighborTrainParams.NUM_HASH_TABLES);
        switch (metric) {
            case JACCARD:
                return new MinHashLSH(seed, numProjectionsPerTable, numHashTables);
            case EUCLIDEAN:
                double projectionWidth = params.get(VectorApproxNearestNeighborTrainParams.PROJECTION_WIDTH);
                return new BucketRandomProjectionLSH(seed, i, numProjectionsPerTable, numHashTables, projectionWidth);
            default:
                throw new AkUnsupportedOperationException("Metric not supported: " + metric);
        }
    }

    /**
     * Builds a data set holding the LSH instance for the configured metric.
     *
     * <p>For {@code JACCARD} the instance is created eagerly and wrapped with
     * {@code fromElements} on the execution environment identified by
     * {@link HasMLEnvironmentId#ML_ENVIRONMENT_ID}. For {@code EUCLIDEAN} the vector size is not
     * known up front, so it is taken from the first {@link BaseVectorSummary} produced by
     * {@link StatisticsHelper#summaryHelper} and the instance is created inside a
     * {@code mapPartition} function.
     *
     * @param batchOperator input operator that is summarized for the {@code EUCLIDEAN} metric
     * @param params        training parameters, see {@link #buildLSH(Params, int)}
     * @param str           name of the vector column passed to the summary helper
     * @return a data set containing the LSH instance
     * @throws IllegalArgumentException if the metric is neither {@code JACCARD} nor
     *                                  {@code EUCLIDEAN}
     */
    public static DataSet<BaseLSH> buildLSH(BatchOperator batchOperator, final Params params, String str) {
        VectorApproxNearestNeighborTrainParams.Metric metric = params.get(VectorApproxNearestNeighborTrainParams.METRIC);
        switch (metric) {
            case JACCARD:
                BaseLSH minHash = new MinHashLSH(
                    params.get(VectorApproxNearestNeighborTrainParams.SEED),
                    params.get(VectorApproxNearestNeighborTrainParams.NUM_PROJECTIONS_PER_TABLE),
                    params.get(VectorApproxNearestNeighborTrainParams.NUM_HASH_TABLES));
                // The explicit BaseLSH[] pins the element type to BaseLSH rather than MinHashLSH.
                return MLEnvironmentFactory.get(params.get(HasMLEnvironmentId.ML_ENVIRONMENT_ID))
                    .getExecutionEnvironment()
                    .fromElements(new BaseLSH[]{minHash});
            case EUCLIDEAN:
                // f1 of the helper's result is consumed as the data set of vector summaries.
                @SuppressWarnings("unchecked")
                DataSet<BaseVectorSummary> summaries =
                    (DataSet<BaseVectorSummary>) StatisticsHelper.summaryHelper(batchOperator, null, str).f1;
                return summaries.mapPartition(new MapPartitionFunction<BaseVectorSummary, BaseLSH>() {
                    private static final long serialVersionUID = -3698577489884292933L;

                    @Override
                    public void mapPartition(Iterable<BaseVectorSummary> iterable, Collector<BaseLSH> collector) {
                        ArrayList<BaseVectorSummary> collected = new ArrayList<>();
                        iterable.forEach(collected::add);
                        // NOTE(review): get(0) throws IndexOutOfBoundsException when this partition
                        // receives no summary; presumably exactly one summary arrives - confirm.
                        int vectorSize = collected.get(0).vectorSize();
                        // The metric is EUCLIDEAN on this path, so the shared factory yields the
                        // BucketRandomProjectionLSH built from the same parameters.
                        collector.collect(buildLSH(params, vectorSize));
                    }
                });
            default:
                throw new IllegalArgumentException("Not support " + metric);
        }
    }
}
