package com.alibaba.alink.operator.common.similarity.dataConverter;

import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.common.similarity.LocalitySensitiveHashApproxFunctions;
import com.alibaba.alink.operator.common.similarity.lsh.BaseLSH;
import com.alibaba.alink.operator.common.similarity.lsh.BucketRandomProjectionLSH;
import com.alibaba.alink.operator.common.similarity.lsh.MinHashLSH;
import com.alibaba.alink.operator.common.similarity.modeldata.LSHModelData;
import com.alibaba.alink.params.similarity.VectorApproxNearestNeighborTrainParams;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.operators.SingleInputUdfOperator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.type.TypeReference;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/* loaded from: input_file:com/alibaba/alink/operator/common/similarity/dataConverter/LSHModelDataConverter.class */
public class LSHModelDataConverter extends NearestNeighborDataConverter<LSHModelData> {
    private static final long serialVersionUID = -6846015825612538416L;
    private static int ROW_SIZE = 2;
    private static int BUCKETS_INDEX = 0;
    private static int DATA_IDNEX = 1;

    public LSHModelDataConverter() {
        this.rowSize = ROW_SIZE;
    }

    @Override // com.alibaba.alink.operator.common.similarity.dataConverter.NearestNeighborDataConverter
    public TableSchema getModelDataSchema() {
        return new TableSchema(new String[]{"BUCKETS", "DATA"}, new TypeInformation[]{Types.STRING, Types.STRING});
    }

    /* JADX WARN: Can't rename method to resolve collision */
    /* JADX WARN: Type inference failed for: r1v15, types: [com.alibaba.alink.operator.common.similarity.dataConverter.LSHModelDataConverter$2] */
    /* JADX WARN: Type inference failed for: r1v9, types: [com.alibaba.alink.operator.common.similarity.dataConverter.LSHModelDataConverter$1] */
    @Override // com.alibaba.alink.operator.common.similarity.dataConverter.NearestNeighborDataConverter
    public LSHModelData loadModelData(List<Row> list) {
        HashMap hashMap = new HashMap();
        HashMap hashMap2 = new HashMap();
        for (Row row : list) {
            if (row.getField(BUCKETS_INDEX) != null) {
                Tuple2 tuple2 = (Tuple2) JsonConverter.fromJson((String) row.getField(BUCKETS_INDEX), new TypeReference<Tuple2<Integer, List<Object>>>() { // from class: com.alibaba.alink.operator.common.similarity.dataConverter.LSHModelDataConverter.1
                }.getType());
                hashMap.put(tuple2.f0, tuple2.f1);
            } else if (row.getField(DATA_IDNEX) != null) {
                Tuple2 tuple22 = (Tuple2) JsonConverter.fromJson((String) row.getField(DATA_IDNEX), new TypeReference<Tuple2<Object, String>>() { // from class: com.alibaba.alink.operator.common.similarity.dataConverter.LSHModelDataConverter.2
                }.getType());
                hashMap2.put(tuple22.f0, VectorUtil.getVector(tuple22.f1));
            }
        }
        return new LSHModelData(hashMap, hashMap2, ((VectorApproxNearestNeighborTrainParams.Metric) this.meta.get(VectorApproxNearestNeighborTrainParams.METRIC)).equals(VectorApproxNearestNeighborTrainParams.Metric.JACCARD) ? new MinHashLSH((int[][]) this.meta.get(MinHashLSH.RAND_COEFFICIENTS_A), (int[][]) this.meta.get(MinHashLSH.RAND_COEFFICIENTS_B)) : new BucketRandomProjectionLSH((DenseVector[][]) this.meta.get(BucketRandomProjectionLSH.RAND_VECTORS), (double[][]) this.meta.get(BucketRandomProjectionLSH.RAND_NUMBER), ((Double) this.meta.get(BucketRandomProjectionLSH.PROJECTION_WIDTH)).doubleValue()));
    }

    @Override // com.alibaba.alink.operator.common.similarity.dataConverter.NearestNeighborDataConverter
    public DataSet<Row> buildIndex(BatchOperator batchOperator, final Params params) {
        DataSet<BaseLSH> buildLSH = LocalitySensitiveHashApproxFunctions.buildLSH(batchOperator, params, (String) params.get(VectorApproxNearestNeighborTrainParams.SELECTED_COL));
        SingleInputUdfOperator withBroadcastSet = batchOperator.getDataSet().map(new RichMapFunction<Row, Tuple3<Object, Vector, int[]>>() { // from class: com.alibaba.alink.operator.common.similarity.dataConverter.LSHModelDataConverter.3
            private static final long serialVersionUID = 9119201008956936115L;

            public Tuple3<Object, Vector, int[]> map(Row row) throws Exception {
                BaseLSH baseLSH = (BaseLSH) getRuntimeContext().getBroadcastVariable("lsh").get(0);
                Vector vector = VectorUtil.getVector(row.getField(1));
                return Tuple3.of(row.getField(0), vector, baseLSH.hashFunction(vector));
            }
        }).withBroadcastSet(buildLSH, "lsh");
        return withBroadcastSet.map(new MapFunction<Tuple3<Object, Vector, int[]>, Row>() { // from class: com.alibaba.alink.operator.common.similarity.dataConverter.LSHModelDataConverter.6
            private static final long serialVersionUID = 7915820872982890995L;

            public Row map(Tuple3<Object, Vector, int[]> tuple3) throws Exception {
                Row row = new Row(LSHModelDataConverter.ROW_SIZE);
                row.setField(LSHModelDataConverter.DATA_IDNEX, JsonConverter.toJson(Tuple2.of(tuple3.f0, VectorUtil.serialize(tuple3.f1))));
                return row;
            }
        }).union(withBroadcastSet.flatMap(new FlatMapFunction<Tuple3<Object, Vector, int[]>, Tuple2<Object, Integer>>() { // from class: com.alibaba.alink.operator.common.similarity.dataConverter.LSHModelDataConverter.5
            private static final long serialVersionUID = 7401684044391240070L;

            public void flatMap(Tuple3<Object, Vector, int[]> tuple3, Collector<Tuple2<Object, Integer>> collector) throws Exception {
                for (int i : (int[]) tuple3.f2) {
                    collector.collect(Tuple2.of(tuple3.f0, Integer.valueOf(i)));
                }
            }

            public /* bridge */ /* synthetic */ void flatMap(Object obj, Collector collector) throws Exception {
                flatMap((Tuple3<Object, Vector, int[]>) obj, (Collector<Tuple2<Object, Integer>>) collector);
            }
        }).groupBy(new int[]{1}).reduceGroup(new GroupReduceFunction<Tuple2<Object, Integer>, Row>() { // from class: com.alibaba.alink.operator.common.similarity.dataConverter.LSHModelDataConverter.4
            private static final long serialVersionUID = -4976135470912551698L;

            public void reduce(Iterable<Tuple2<Object, Integer>> iterable, Collector<Row> collector) throws Exception {
                ArrayList arrayList = new ArrayList();
                Integer num = null;
                for (Tuple2<Object, Integer> tuple2 : iterable) {
                    arrayList.add(tuple2.f0);
                    if (null == num) {
                        num = (Integer) tuple2.f1;
                    }
                }
                Row row = new Row(LSHModelDataConverter.ROW_SIZE);
                row.setField(LSHModelDataConverter.BUCKETS_INDEX, JsonConverter.toJson(Tuple2.of(num, arrayList)));
                collector.collect(row);
            }
        })).mapPartition(new RichMapPartitionFunction<Row, Row>() { // from class: com.alibaba.alink.operator.common.similarity.dataConverter.LSHModelDataConverter.7
            private static final long serialVersionUID = 1398487522497229248L;

            public void mapPartition(Iterable<Row> iterable, Collector<Row> collector) {
                Params params2 = null;
                BaseLSH baseLSH = (BaseLSH) getRuntimeContext().getBroadcastVariable("lsh").get(0);
                if (getRuntimeContext().getIndexOfThisSubtask() == 0) {
                    params2 = params;
                    if (baseLSH instanceof BucketRandomProjectionLSH) {
                        BucketRandomProjectionLSH bucketRandomProjectionLSH = (BucketRandomProjectionLSH) baseLSH;
                        params2.set((ParamInfo<ParamInfo<DenseVector[][]>>) BucketRandomProjectionLSH.RAND_VECTORS, (ParamInfo<DenseVector[][]>) bucketRandomProjectionLSH.getRandVectors()).set((ParamInfo<ParamInfo<double[][]>>) BucketRandomProjectionLSH.RAND_NUMBER, (ParamInfo<double[][]>) bucketRandomProjectionLSH.getRandNumber()).set((ParamInfo<ParamInfo<Double>>) BucketRandomProjectionLSH.PROJECTION_WIDTH, (ParamInfo<Double>) Double.valueOf(bucketRandomProjectionLSH.getProjectionWidth()));
                    } else {
                        MinHashLSH minHashLSH = (MinHashLSH) baseLSH;
                        params2.set((ParamInfo<ParamInfo<int[][]>>) MinHashLSH.RAND_COEFFICIENTS_A, (ParamInfo<int[][]>) minHashLSH.getRandCoefficientsA()).set((ParamInfo<ParamInfo<int[][]>>) MinHashLSH.RAND_COEFFICIENTS_B, (ParamInfo<int[][]>) minHashLSH.getRandCoefficientsB());
                    }
                }
                new LSHModelDataConverter().save2(Tuple2.of(params2, iterable), collector);
            }
        }).withBroadcastSet(buildLSH, "lsh").name("build_model");
    }

    @Override // com.alibaba.alink.operator.common.similarity.dataConverter.NearestNeighborDataConverter
    public /* bridge */ /* synthetic */ LSHModelData loadModelData(List list) {
        return loadModelData((List<Row>) list);
    }
}
