package com.alibaba.alink.operator.batch.clustering;

import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.TypeCollections;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.common.clustering.dbscan.Dbscan;
import com.alibaba.alink.operator.common.clustering.kmodes.KModesModel;
import com.alibaba.alink.operator.common.clustering.kmodes.KModesModelData;
import com.alibaba.alink.operator.common.distance.OneZeroDistance;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.params.clustering.KModesTrainParams;
import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.operators.IterativeDataSet;
import org.apache.flink.api.java.operators.MapOperator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.utils.DataSetUtils;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

@InputPorts(values = {@PortSpec(PortType.DATA)})
@OutputPorts(values = {@PortSpec(PortType.MODEL)})
@ParamSelectColumnSpec(name = "featureCols", portIndices = {VectorUtil.VectorSerialType.DENSE_VECTOR}, allowedTypeCollections = {TypeCollections.NUMERIC_TYPES})
@NameCn("Kmodes训练")
@NameEn("KModes Training")
@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.clustering.KModes")
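/*
 * Batch operator that trains a k-modes model: the feature columns are cast to strings, cluster
 * centroids are refined in a Flink bulk iteration, and the result is emitted as KModesModel rows.
 * NOTE(review): `portIndices` in the @ParamSelectColumnSpec above is written as
 * VectorUtil.VectorSerialType.DENSE_VECTOR, which looks like a decompiler constant substitution
 * for a literal port index - confirm against the original source.
 */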
public final class KModesTrainBatchOp extends BatchOperator<KModesTrainBatchOp> implements KModesTrainParams<KModesTrainBatchOp> {
    private static final long serialVersionUID = 7392501340000162512L;

    /* loaded from: input_file:com/alibaba/alink/operator/batch/clustering/KModesTrainBatchOp$DataPartition.class */
    public static class DataPartition implements MapPartitionFunction<Tuple2<Long, String[]>, Tuple3<Long, Double, Map<String, Integer>[]>> {
        private static final long serialVersionUID = 4053491536690724820L;
        private int k;
        private int dim;

        public DataPartition(int i, int i2) {
            this.k = i;
            this.dim = i2;
        }

        public void mapPartition(Iterable<Tuple2<Long, String[]>> iterable, Collector<Tuple3<Long, Double, Map<String, Integer>[]>> collector) throws Exception {
            HashMap[][] hashMapArr = new HashMap[this.k][this.dim];
            for (int i = 0; i < this.k; i++) {
                for (int i2 = 0; i2 < this.dim; i2++) {
                    hashMapArr[i][i2] = new HashMap(32);
                }
            }
            double[] dArr = new double[this.k];
            Arrays.fill(dArr, Criteria.INVALID_GAIN);
            for (Tuple2<Long, String[]> tuple2 : iterable) {
                int intValue = ((Long) tuple2.f0).intValue();
                dArr[intValue] = dArr[intValue] + 1.0d;
                for (int i3 = 0; i3 < this.dim; i3++) {
                    if (hashMapArr[intValue][i3].containsKey(((String[]) tuple2.f1)[i3])) {
                        hashMapArr[intValue][i3].put(((String[]) tuple2.f1)[i3], Integer.valueOf(((Integer) hashMapArr[intValue][i3].get(((String[]) tuple2.f1)[i3])).intValue() + 1));
                    } else {
                        hashMapArr[intValue][i3].put(((String[]) tuple2.f1)[i3], 1);
                    }
                }
            }
            for (int i4 = 0; i4 < hashMapArr.length; i4++) {
                collector.collect(new Tuple3(Long.valueOf(i4), Double.valueOf(dArr[i4]), hashMapArr[i4]));
            }
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/batch/clustering/KModesTrainBatchOp$DataReduce.class */
    public static class DataReduce implements ReduceFunction<Tuple3<Long, Double, Map<String, Integer>[]>> {
        private static final long serialVersionUID = -7472289261425686956L;
        private int dim;

        public DataReduce(int i) {
            this.dim = i;
        }

        public Tuple3<Long, Double, Map<String, Integer>[]> reduce(Tuple3<Long, Double, Map<String, Integer>[]> tuple3, Tuple3<Long, Double, Map<String, Integer>[]> tuple32) {
            return new Tuple3<>(tuple3.f0, Double.valueOf(((Double) tuple3.f1).doubleValue() + ((Double) tuple32.f1).doubleValue()), KModesTrainBatchOp.unionMaps((Map[]) tuple3.f2, (Map[]) tuple32.f2, this.dim));
        }
    }

    /** Creates the operator with a new, default-constructed {@link Params}. */
    public KModesTrainBatchOp() {
        this(new Params());
    }

    /**
     * Creates the operator with the supplied parameters.
     *
     * @param params parameters passed on to the {@link BatchOperator} base constructor
     */
    public KModesTrainBatchOp(Params params) {
        super(params);
    }

    /**
     * Merges two value-count histograms into a new map; neither input is modified.
     *
     * <p>Keys present in both maps get the sum of their counts, all other keys keep their count.
     * Counts are expected to be non-null.
     *
     * @param map  first histogram
     * @param map2 second histogram
     * @return a new {@link HashMap} with the combined counts
     */
    private static Map<String, Integer> unionMaps(Map<String, Integer> map, Map<String, Integer> map2) {
        Map<String, Integer> merged = new HashMap<>(map);
        for (Map.Entry<String, Integer> entry : map2.entrySet()) {
            merged.merge(entry.getKey(), entry.getValue(), Integer::sum);
        }
        return merged;
    }

    /**
     * Merges two histogram arrays position by position.
     *
     * @param mapArr  first array of histograms
     * @param mapArr2 second array of histograms
     * @param i       number of leading positions to merge; both arrays must hold at least this many
     * @return a new array of length {@code i} whose element {@code d} combines {@code mapArr[d]}
     *         and {@code mapArr2[d]}
     */
    public static Map<String, Integer>[] unionMaps(Map<String, Integer>[] mapArr, Map<String, Integer>[] mapArr2, int i) {
        // Generic arrays cannot be created directly; the runtime component type stays HashMap.
        @SuppressWarnings("unchecked")
        Map<String, Integer>[] merged = new HashMap[i];
        for (int d = 0; d < i; d++) {
            merged[d] = unionMaps(mapArr[d], mapArr2[d]);
        }
        return merged;
    }

    /**
     * Returns the key with the largest count in a histogram.
     *
     * <p>The maximum is tracked without an external sentinel constant: the first entry always becomes
     * the current best, and a later entry replaces it only when its count is strictly greater, so ties
     * go to the entry met first in the map's iteration order.
     *
     * @param map histogram of value to count; counts are expected to be non-null
     * @return the key with the highest count, or {@code null} when the map is empty
     */
    private static String getKOfMaxV(Map<String, Integer> map) {
        String bestKey = null;
        int bestCount = 0;
        // A separate flag is needed because null is a legal key and may itself be the best one.
        boolean found = false;
        for (Map.Entry<String, Integer> entry : map.entrySet()) {
            int count = entry.getValue();
            if (!found || count > bestCount) {
                bestKey = entry.getKey();
                bestCount = count;
                found = true;
            }
        }
        return bestKey;
    }

    /**
     * Picks the most frequent value of every histogram in the array.
     *
     * @param mapArr one histogram per feature dimension
     * @return an array of the same length whose element {@code i} is the key with the highest count
     *         in {@code mapArr[i]}, or {@code null} for an empty histogram
     */
    public static String[] getKOfMaxV(Map<String, Integer>[] mapArr) {
        String[] modes = new String[mapArr.length];
        for (int i = 0; i < mapArr.length; i++) {
            modes[i] = getKOfMaxV(mapArr[i]);
        }
        return modes;
    }

    /**
     * Wires the k-modes training job on the first input operator and sets the model as this
     * operator's output.
     *
     * <p>Pipeline, as built below:
     * <ol>
     *   <li>resolve the feature columns (the configured ones when present and non-empty, otherwise
     *       every column of the input schema) and cast each of them to VARCHAR;</li>
     *   <li>sample {@code k} rows, index them and use them as initial centroids;</li>
     *   <li>run a bulk iteration of {@code numIter} rounds, each assigning samples to centroids
     *       ({@code assignClusterId}) and recomputing the centroids ({@code updateCentroid});</li>
     *   <li>collect the final centroids in a single task and serialize them with {@link KModesModel}.</li>
     * </ol>
     *
     * @param batchOperatorArr input operators; the one returned by {@code checkAndGetFirst} is used
     * @return this operator
     * @throws RuntimeException if no feature column names could be resolved
     */
    @Override // com.alibaba.alink.operator.batch.BatchOperator
    public KModesTrainBatchOp linkFrom(BatchOperator<?>... batchOperatorArr) {
        BatchOperator<?> input = checkAndGetFirst(batchOperatorArr);

        // Configured feature columns win when present and non-empty; otherwise use all input columns.
        final String[] fieldNames;
        if (getParams().contains(FEATURE_COLS) && getFeatureCols() != null && getFeatureCols().length > 0) {
            fieldNames = getFeatureCols();
        } else {
            fieldNames = input.getSchema().getFieldNames();
        }
        int numIter = getNumIter().intValue();
        int k = getK().intValue();
        if (fieldNames == null || fieldNames.length == 0) {
            throw new RuntimeException("featureColNames should be set !");
        }

        // Select clause that casts every feature column to VARCHAR under its original name.
        StringBuilder selectClause = new StringBuilder();
        for (int i = 0; i < fieldNames.length; i++) {
            if (i > 0) {
                selectClause.append(", ");
            }
            selectClause.append("cast(`").append(fieldNames[i]).append("` as VARCHAR) as `").append(fieldNames[i]).append("`");
        }

        // One String[] per input row, in select order.
        DataSet<String[]> data = input.select(selectClause.toString()).getDataSet().map(new MapFunction<Row, String[]>() {
            private static final long serialVersionUID = 8380190916941374707L;

            @Override
            public String[] map(Row row) throws Exception {
                String[] values = new String[row.getArity()];
                for (int i = 0; i < values.length; i++) {
                    values[i] = (String) row.getField(i);
                }
                return values;
            }
        });

        // Initial centroids: k sampled rows, indexed, as (index, weight, values).
        DataSet<Tuple3<Long, Double, String[]>> initialCentroids = DataSetUtils
            .zipWithIndex(DataSetUtils.sampleWithSize(data, false, k))
            .map(new MapFunction<Tuple2<Long, String[]>, Tuple3<Long, Double, String[]>>() {
                private static final long serialVersionUID = -6852532761276146862L;

                @Override
                public Tuple3<Long, Double, String[]> map(Tuple2<Long, String[]> tuple2) throws Exception {
                    // NOTE(review): the initial weight is Criteria.INVALID_GAIN, a constant from the
                    // tree package; presumably 0.0 (decompiler constant substitution) - confirm.
                    return new Tuple3<>(tuple2.f0, Double.valueOf(Criteria.INVALID_GAIN), tuple2.f1);
                }
            })
            .withForwardedFields("f0->f0;f1->f2");

        // Bulk iteration: each round assigns samples to the current centroids, then recomputes them.
        IterativeDataSet<Tuple3<Long, Double, String[]>> loop = initialCentroids.iterate(numIter);
        DataSet<Tuple2<Long, String[]>> assigned = assignClusterId(data, loop);
        DataSet<Tuple3<Long, Double, String[]>> updatedCentroids = updateCentroid(assigned, k, fieldNames.length);
        DataSet<Tuple3<Long, Double, String[]>> finalCentroids = loop.closeWith(updatedCentroids);

        // Gather all centroids in one task so a single model is written.
        DataSet<Row> modelRows = finalCentroids.mapPartition(new MapPartitionFunction<Tuple3<Long, Double, String[]>, Row>() {
            private static final long serialVersionUID = -3961032097333930998L;

            @Override
            public void mapPartition(Iterable<Tuple3<Long, Double, String[]>> iterable, Collector<Row> collector) throws Exception {
                KModesModelData modelData = new KModesModelData();
                modelData.centroids = new ArrayList<>();
                for (Tuple3<Long, Double, String[]> centroid : iterable) {
                    modelData.centroids.add(centroid);
                }
                modelData.featureColNames = fieldNames;
                new KModesModel().save(modelData, collector);
            }
        }).setParallelism(1);

        setOutput(modelRows, new KModesModel().getModelSchema());
        return this;
    }

    /**
     * Tags every sample with the cluster id chosen by {@link KModesModel#findCluster} for the
     * centroids received through the {@code "centroids"} broadcast set.
     *
     * @param dataSet  samples, one array of feature values each
     * @param dataSet2 current centroids, broadcast to every mapper
     * @return (cluster id, feature values) pairs
     */
    private DataSet<Tuple2<Long, String[]>> assignClusterId(DataSet<String[]> dataSet, DataSet<Tuple3<Long, Double, String[]>> dataSet2) {
        RichMapFunction<String[], Tuple2<Long, String[]>> findClusterOp = new RichMapFunction<String[], Tuple2<Long, String[]>>() {
            private static final long serialVersionUID = 6305282153314372806L;
            List<Tuple3<Long, Double, String[]>> centroids;
            OneZeroDistance distance;

            /** Reads the broadcast centroids and creates the distance object once per task. */
            @Override
            public void open(Configuration configuration) throws Exception {
                super.open(configuration);
                this.centroids = getRuntimeContext().getBroadcastVariable("centroids");
                this.distance = new OneZeroDistance();
            }

            @Override
            public Tuple2<Long, String[]> map(String[] sample) throws Exception {
                return new Tuple2<>(Long.valueOf(KModesModel.findCluster(this.centroids, sample, this.distance)), sample);
            }
        };
        return dataSet.map(findClusterOp).withBroadcastSet(dataSet2, "centroids");
    }

    /**
     * Recomputes the centroids from the samples assigned to each cluster.
     *
     * <p>Samples are pre-aggregated per partition ({@link DataPartition}), the partial aggregates are
     * grouped by cluster id and merged ({@link DataReduce}), and each merged histogram array is
     * reduced to its most frequent value per dimension via {@code getKOfMaxV}.
     *
     * @param dataSet (cluster id, feature values) pairs
     * @param k       number of clusters
     * @param dim     number of feature dimensions
     * @return (cluster id, sample count, most frequent value per dimension) tuples
     */
    private DataSet<Tuple3<Long, Double, String[]>> updateCentroid(DataSet<Tuple2<Long, String[]>> dataSet, int k, int dim) {
        DataSet<Tuple3<Long, Double, Map<String, Integer>[]>> partialAggregates = dataSet.mapPartition(new DataPartition(k, dim));
        DataSet<Tuple3<Long, Double, Map<String, Integer>[]>> clusterAggregates = partialAggregates.groupBy(0).reduce(new DataReduce(dim));
        return clusterAggregates.map(new MapFunction<Tuple3<Long, Double, Map<String, Integer>[]>, Tuple3<Long, Double, String[]>>() {
            private static final long serialVersionUID = -6833217715929845251L;

            @Override
            public Tuple3<Long, Double, String[]> map(Tuple3<Long, Double, Map<String, Integer>[]> aggregate) {
                String[] modes = KModesTrainBatchOp.getKOfMaxV(aggregate.f2);
                return new Tuple3<>(aggregate.f0, aggregate.f1, modes);
            }
        }).withForwardedFields("f0;f1");
    }

    // The decompiler-rendered synthetic bridge overload linkFrom(BatchOperator[]) was removed here.
    // It has the same erasure as linkFrom(BatchOperator<?>...) declared above, so it cannot exist
    // as a separate method in source; callers passing a BatchOperator[] bind to the varargs method.
}
