package com.alibaba.alink.operator.common.similarity.dataConverter;

import com.alibaba.alink.common.exceptions.AkPreconditions;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.clustering.KMeansTrainBatchOp;
import com.alibaba.alink.operator.batch.statistics.utils.StatisticsHelper;
import com.alibaba.alink.operator.common.distance.EuclideanDistance;
import com.alibaba.alink.operator.common.distance.FastDistanceVectorData;
import com.alibaba.alink.operator.common.similarity.KDTree;
import com.alibaba.alink.operator.common.similarity.modeldata.KDTreeModelData;
import com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary;
import com.alibaba.alink.params.similarity.VectorApproxNearestNeighborTrainParams;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.ParamInfoFactory;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.type.TypeReference;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/**
 * Converts KD-tree nearest-neighbor model data to and from model rows.
 *
 * <p>Every model row has four fields: {@code TASKID}, {@code DATAID}, {@code DATA} and {@code ROOT}.
 * For each training subtask that received data, {@link #buildIndex} emits one row per sample
 * (task id, sample position, serialized sample) followed by a single row that carries the task id
 * and the JSON form of the tree root built from those samples.
 */
public class KDTreeModelDataConverter extends NearestNeighborDataConverter<KDTreeModelData> {
    private static final long serialVersionUID = 7886707008132061008L;

    /** Number of fields in every model row. */
    private static final int ROW_SIZE = 4;
    /** Position of the training subtask index. */
    private static final int TASKID_INDEX = 0;
    /** Position of the sample's index inside its subtask. */
    private static final int DATA_ID_INDEX = 1;
    /** Position of the serialized sample. */
    private static final int DATA_INDEX = 2;
    /** Position of the JSON-serialized tree root. */
    private static final int ROOT_INDEX = 3;

    /** Meta parameter holding the vector size; it is written by {@link #buildIndex}. */
    private static final ParamInfo<Integer> VECTOR_SIZE = ParamInfoFactory
        .createParamInfo(KMeansTrainBatchOp.VECTOR_SIZE, Integer.class)
        .setRequired()
        .build();

    public KDTreeModelDataConverter() {
        this.rowSize = ROW_SIZE;
    }

    /**
     * Returns the schema of the model data rows.
     *
     * @return a schema with the columns {@code TASKID} (long), {@code DATAID} (long),
     *     {@code DATA} (string) and {@code ROOT} (string)
     */
    @Override
    public TableSchema getModelDataSchema() {
        return new TableSchema(
            new String[] {"TASKID", "DATAID", "DATA", "ROOT"},
            new TypeInformation<?>[] {Types.LONG, Types.LONG, Types.STRING, Types.STRING});
    }

    /**
     * Rebuilds one {@link KDTree} per task id from the model rows.
     *
     * <p>Rows whose task id is {@code null} are ignored. A row with a non-null {@code DATA} field
     * contributes a sample; otherwise a row with a non-null {@code ROOT} field contributes the root
     * of its task's tree. A tree is only created for task ids that have at least one sample row.
     *
     * @param list the model rows
     * @return the model data wrapping the rebuilt trees
     */
    @Override
    public KDTreeModelData loadModelData(List<Row> list) {
        // Samples are kept in a TreeMap so that they end up ordered by their data id.
        Map<Long, TreeMap<Integer, FastDistanceVectorData>> samplesByTask = new HashMap<>();
        Map<Long, KDTree.TreeNode> rootByTask = new HashMap<>();

        for (Row row : list) {
            Object taskField = row.getField(TASKID_INDEX);
            if (taskField == null) {
                continue;
            }
            Long taskId = (Long) taskField;

            if (row.getField(DATA_INDEX) != null) {
                TreeMap<Integer, FastDistanceVectorData> samples = samplesByTask.get(taskId);
                if (samples == null) {
                    samples = new TreeMap<>();
                    samplesByTask.put(taskId, samples);
                }
                samples.put(
                    ((Number) row.getField(DATA_ID_INDEX)).intValue(),
                    FastDistanceVectorData.fromString((String) row.getField(DATA_INDEX)));
            } else if (row.getField(ROOT_INDEX) != null) {
                KDTree.TreeNode root = (KDTree.TreeNode) JsonConverter.fromJson(
                    (String) row.getField(ROOT_INDEX),
                    new TypeReference<KDTree.TreeNode>() {
                    }.getType());
                rootByTask.put(taskId, root);
            }
        }

        int vectorSize = this.meta.get(VECTOR_SIZE);
        EuclideanDistance distance = new EuclideanDistance();
        List<KDTree> trees = new ArrayList<>();
        for (Map.Entry<Long, TreeMap<Integer, FastDistanceVectorData>> entry : samplesByTask.entrySet()) {
            FastDistanceVectorData[] samples =
                entry.getValue().values().toArray(new FastDistanceVectorData[0]);
            KDTree tree = new KDTree(samples, vectorSize, distance);
            // The root is null when no ROOT row was found for this task id.
            tree.setRoot(rootByTask.get(entry.getKey()));
            trees.add(tree);
        }
        return new KDTreeModelData(trees);
    }

    /**
     * Builds one KD-tree per parallel subtask and serializes the samples and tree roots as model rows.
     *
     * @param batchOperator the training data
     * @param params the training parameters; the metric must be {@code EUCLIDEAN}
     * @return the model rows produced by the converter's save step
     */
    @Override
    public DataSet<Row> buildIndex(BatchOperator batchOperator, final Params params) {
        AkPreconditions.checkArgument(
            params.get(VectorApproxNearestNeighborTrainParams.METRIC)
                .equals(VectorApproxNearestNeighborTrainParams.Metric.EUCLIDEAN),
            "KDTree solver only supports Euclidean distance!");

        final EuclideanDistance distance = new EuclideanDistance();
        Tuple2<DataSet<Vector>, DataSet<BaseVectorSummary>> statistics = StatisticsHelper.summaryHelper(
            batchOperator, null, params.get(VectorApproxNearestNeighborTrainParams.SELECTED_COL));

        DataSet<Row> indexRows = batchOperator.getDataSet()
            .rebalance()
            .mapPartition(new RichMapPartitionFunction<Row, Row>() {
                private static final long serialVersionUID = 6654757741959479783L;

                @Override
                public void mapPartition(Iterable<Row> rows, Collector<Row> out) throws Exception {
                    BaseVectorSummary summary = (BaseVectorSummary) getRuntimeContext()
                        .getBroadcastVariable(KMeansTrainBatchOp.VECTOR_SIZE)
                        .get(0);
                    int vectorSize = summary.vectorSize();

                    List<FastDistanceVectorData> samples = new ArrayList<>();
                    for (Row row : rows) {
                        FastDistanceVectorData sample = distance.prepareVectorData(row, 1, 0);
                        samples.add(sample);
                        // The size of the last vector seen replaces the broadcast summary value.
                        vectorSize = sample.getVector().size();
                    }
                    if (samples.isEmpty()) {
                        return;
                    }

                    FastDistanceVectorData[] sampleArray = samples.toArray(new FastDistanceVectorData[0]);
                    KDTree tree = new KDTree(sampleArray, vectorSize, distance);
                    tree.buildTree();

                    // A single Row instance is mutated and re-emitted for every output record.
                    Row outRow = new Row(ROW_SIZE);
                    outRow.setField(TASKID_INDEX, Long.valueOf(getRuntimeContext().getIndexOfThisSubtask()));
                    for (int i = 0; i < sampleArray.length; i++) {
                        outRow.setField(DATA_ID_INDEX, Long.valueOf(i));
                        outRow.setField(DATA_INDEX, sampleArray[i].toString());
                        out.collect(outRow);
                    }

                    // The final row of the subtask carries only the task id and the tree root.
                    outRow.setField(DATA_ID_INDEX, null);
                    outRow.setField(DATA_INDEX, null);
                    outRow.setField(ROOT_INDEX, JsonConverter.toJson(tree.getRoot()));
                    out.collect(outRow);
                }
            })
            .withBroadcastSet(statistics.f1, KMeansTrainBatchOp.VECTOR_SIZE);

        return indexRows
            .mapPartition(new RichMapPartitionFunction<Row, Row>() {
                private static final long serialVersionUID = 6849403933586157611L;

                @Override
                public void mapPartition(Iterable<Row> rows, Collector<Row> out) throws Exception {
                    // Only subtask 0 passes the meta parameters on; the other subtasks pass null.
                    Params meta = null;
                    if (getRuntimeContext().getIndexOfThisSubtask() == 0) {
                        BaseVectorSummary summary = (BaseVectorSummary) getRuntimeContext()
                            .getBroadcastVariable(KMeansTrainBatchOp.VECTOR_SIZE)
                            .get(0);
                        meta = params;
                        meta.set(VECTOR_SIZE, summary.vectorSize());
                    }
                    // NOTE(review): save2 is declared outside this file; it looks like a
                    // decompiler-renamed save(...) - confirm against the parent class.
                    new KDTreeModelDataConverter().save2(Tuple2.of(meta, rows), out);
                }
            })
            .withBroadcastSet(statistics.f1, KMeansTrainBatchOp.VECTOR_SIZE);
    }
}
