package com.alibaba.alink.operator.batch.clustering;

import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpecs;
import com.alibaba.alink.common.annotation.PortDesc;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.TypeCollections;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.common.clustering.LocalKMeans;
import com.alibaba.alink.operator.common.clustering.common.Sample;
import com.alibaba.alink.operator.common.distance.ContinuousDistance;
import com.alibaba.alink.operator.common.distance.FastDistance;
import com.alibaba.alink.params.clustering.GroupKMeansParams;
import com.alibaba.alink.params.shared.clustering.HasClusteringDistanceType;
import com.alibaba.alink.params.shared.colname.HasPredictionCol;
import java.util.ArrayList;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.table.api.Types;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/**
 * Batch operator that clusters the rows of every group independently.
 *
 * <p>The input is projected to {@code groupCols} and {@code idCol} (both cast to VARCHAR) followed by
 * the feature columns (cast to double). Rows are grouped by the group-column values and each group
 * is handed to {@link LocalKMeans#clustering}. The output table contains, in order: the group
 * columns (string), the id column (string), the prediction column (long) and the feature columns
 * (double).
 *
 * <p>If no prediction column is configured, {@code "cluster_id"} is used. If no feature columns are
 * configured, all numeric columns of the input schema are used.
 */
@InputPorts(values = {@PortSpec(PortType.DATA)})
@OutputPorts(values = {@PortSpec(value = PortType.DATA, desc = PortDesc.OUTPUT_RESULT)})
@ParamSelectColumnSpecs({@ParamSelectColumnSpec(name = "featureCols", portIndices = {VectorUtil.VectorSerialType.DENSE_VECTOR}, allowedTypeCollections = {TypeCollections.NUMERIC_TYPES}), @ParamSelectColumnSpec(name = "groupCols", portIndices = {VectorUtil.VectorSerialType.DENSE_VECTOR}), @ParamSelectColumnSpec(name = "idCol", portIndices = {VectorUtil.VectorSerialType.DENSE_VECTOR})})
@NameCn("分组EM")
@NameEn("Group EM")
public final class GroupEmBatchOp extends BatchOperator<GroupEmBatchOp> implements GroupKMeansParams<GroupEmBatchOp> {
    private static final long serialVersionUID = 2403292854593151120L;

    /**
     * Group-reduce function that runs {@link LocalKMeans#clustering} over all samples of one group.
     */
    public static class Clustering implements GroupReduceFunction<Sample, Sample> {
        private static final long serialVersionUID = -6401148777324895859L;
        private final int k;
        private final double epsilon;
        private final ContinuousDistance distance;
        private final int maxIter;
        // Stored from the constructor but not read by reduce(); kept so the constructor
        // signature and the serialized layout stay as they were.
        private final int dim;

        /**
         * @param i                  number of clusters (k)
         * @param d                  epsilon passed through to the clustering routine
         * @param i2                 maximum number of iterations
         * @param i3                 number of feature dimensions
         * @param continuousDistance distance measure passed through to the clustering routine
         */
        public Clustering(int i, double d, int i2, int i3, ContinuousDistance continuousDistance) {
            this.k = i;
            this.epsilon = d;
            this.maxIter = i2;
            this.dim = i3;
            this.distance = continuousDistance;
        }

        /**
         * Delegates the samples of one group, and the collector for the results, to
         * {@link LocalKMeans#clustering}.
         */
        @Override
        public void reduce(Iterable<Sample> iterable, Collector<Sample> collector) throws Exception {
            LocalKMeans.clustering(iterable, collector, this.k, this.epsilon, this.maxIter, this.distance);
        }
    }

    /**
     * Converts a clustered {@link Sample} into an output {@link Row} laid out as
     * {@code [group values..., sample id, cluster id, vector components...]}.
     */
    public static class MapToRow implements MapFunction<Sample, Row> {
        private static final long serialVersionUID = 5045205789035382392L;
        private final int rowArity;
        private final int groupColNamesSize;

        /**
         * @param i  arity of the produced rows
         * @param i2 number of group columns
         */
        public MapToRow(int i, int i2) {
            this.rowArity = i;
            this.groupColNamesSize = i2;
        }

        @Override
        public Row map(Sample sample) throws Exception {
            Row row = new Row(this.rowArity);
            DenseVector vector = sample.getVector();
            for (int g = 0; g < this.groupColNamesSize; g++) {
                row.setField(g, sample.getGroupColNames()[g]);
            }
            // The id and the cluster id sit directly after the group values.
            row.setField(this.groupColNamesSize, sample.getSampleId());
            row.setField(this.groupColNamesSize + 1, Long.valueOf(sample.getClusterId()));
            // Vector components fill the remaining fields.
            int featureOffset = this.groupColNamesSize + 2;
            for (int f = 0; f < vector.size(); f++) {
                row.setField(featureOffset + f, Double.valueOf(vector.get(f)));
            }
            return row;
        }
    }

    /**
     * Key selector that groups samples by {@link Sample#getGroupColNamesString()}.
     *
     * <p>NOTE(review): this is a non-static inner class, so every instance keeps a reference to the
     * enclosing {@code GroupEmBatchOp}. It is left non-static so that existing
     * {@code outer.new SelectGroup()} callers keep compiling.
     */
    public class SelectGroup implements KeySelector<Sample, String> {
        private static final long serialVersionUID = 4582197026301874450L;

        public SelectGroup() {
        }

        @Override
        public String getKey(Sample sample) {
            return sample.getGroupColNamesString();
        }
    }

    /**
     * Converts an input {@link Row} laid out as
     * {@code [group values..., id, feature values...]} into a {@link Sample} with cluster id
     * {@code -1}.
     */
    public static class mapToDataSample implements MapFunction<Row, Sample> {
        private static final long serialVersionUID = -1252574650762802849L;
        private final int dim;
        private final int groupColNamesSize;

        /**
         * @param i  number of feature values per row
         * @param i2 number of group columns
         */
        public mapToDataSample(int i, int i2) {
            this.dim = i;
            this.groupColNamesSize = i2;
        }

        /**
         * @throws RuntimeException if any group, id or feature field of the row is {@code null}
         */
        @Override
        public Sample map(Row row) throws Exception {
            String[] groupValues = new String[this.groupColNamesSize];
            for (int g = 0; g < this.groupColNamesSize; g++) {
                Object groupValue = row.getField(g);
                if (null == groupValue) {
                    throw new RuntimeException("There is NULL value in group col!");
                }
                groupValues[g] = (String) groupValue;
            }

            Object idValue = row.getField(this.groupColNamesSize);
            if (null == idValue) {
                throw new RuntimeException("There is NULL value in id col!");
            }
            String sampleId = (String) idValue;

            // Feature values start right after the id field.
            int featureOffset = this.groupColNamesSize + 1;
            double[] features = new double[this.dim];
            for (int f = 0; f < features.length; f++) {
                Object featureValue = row.getField(featureOffset + f);
                if (null == featureValue) {
                    throw new RuntimeException("There is NULL value in value col!");
                }
                features[f] = ((Double) featureValue).doubleValue();
            }
            return new Sample(sampleId, new DenseVector(features), -1L, groupValues);
        }
    }

    public GroupEmBatchOp() {
        super(null);
    }

    public GroupEmBatchOp(Params params) {
        super(params);
    }

    /**
     * Builds the grouped clustering pipeline on the first input operator and sets it as this
     * operator's output.
     *
     * @param inputs upstream operators; the one returned by {@code checkAndGetFirst} is used
     * @return this operator
     * @throws RuntimeException if a group column or the id column is also a feature column, if the
     *                          id column is unset or is also a group column, or if building the
     *                          pipeline fails (the original exception is kept as the cause)
     */
    @Override // com.alibaba.alink.operator.batch.BatchOperator
    public GroupEmBatchOp linkFrom(BatchOperator<?>... inputs) {
        BatchOperator<?> in = checkAndGetFirst(inputs);
        if (!getParams().contains(HasPredictionCol.PREDICTION_COL)) {
            setPredictionCol("cluster_id");
        }
        String[] featureCols = resolveFeatureCols(in);
        int k = getK().intValue();
        double epsilon = getEpsilon().doubleValue();
        int maxIter = getMaxIter().intValue();
        HasClusteringDistanceType.DistanceType distanceType = getDistanceType();
        String[] groupCols = getGroupCols();
        String idCol = getIdCol();
        String predictionCol = getPredictionCol();
        FastDistance distance = distanceType.getFastDistance();

        validateColumns(featureCols, groupCols, idCol);
        String selectClause = buildSelectClause(groupCols, idCol, featureCols);

        // Output layout: group columns, id column, prediction column, feature columns.
        int dim = featureCols.length;
        int outputArity = groupCols.length + 2 + dim;
        ArrayList<String> outputNames = new ArrayList<>(outputArity);
        ArrayList<TypeInformation<?>> outputTypes = new ArrayList<>(outputArity);
        for (String groupCol : groupCols) {
            outputNames.add(groupCol);
            outputTypes.add(Types.STRING());
        }
        outputNames.add(idCol);
        outputTypes.add(Types.STRING());
        outputNames.add(predictionCol);
        outputTypes.add(Types.LONG());
        for (String featureCol : featureCols) {
            outputNames.add(featureCol);
            outputTypes.add(Types.DOUBLE());
        }

        try {
            DataSet<Sample> samples = in.select(selectClause)
                .getDataSet()
                .map(new mapToDataSample(dim, groupCols.length));
            DataSet<Row> clustered = samples
                .groupBy(new SelectGroup())
                .reduceGroup(new Clustering(k, epsilon, maxIter, dim, distance))
                .map(new MapToRow(outputArity, groupCols.length));
            TableSchema outputSchema = new TableSchema(
                outputNames.toArray(new String[0]),
                outputTypes.toArray(new TypeInformation<?>[0]));
            setOutput(clustered, outputSchema);
            return this;
        } catch (Exception e) {
            // The cause is preserved on the rethrown exception, so it is not printed here.
            throw new RuntimeException(e);
        }
    }

    /** Returns the configured feature columns, or all numeric input columns when none are set. */
    private String[] resolveFeatureCols(BatchOperator<?> in) {
        if (getParams().contains(FEATURE_COLS) && getFeatureCols() != null && getFeatureCols().length > 0) {
            return getFeatureCols();
        }
        return TableUtil.getNumericCols(in.getSchema());
    }

    /** Rejects overlapping group/id/feature columns and a missing id column. */
    private static void validateColumns(String[] featureCols, String[] groupCols, String idCol) {
        for (String groupCol : groupCols) {
            if (TableUtil.findColIndex(featureCols, groupCol) >= 0) {
                throw new RuntimeException("groupColNames should NOT be included in featureColNames!");
            }
        }
        if (null == idCol || "".equals(idCol)) {
            throw new RuntimeException("idCol column should be set!");
        }
        if (TableUtil.findColIndex(featureCols, idCol) >= 0) {
            throw new RuntimeException("idCol column should NOT be included in featureColNames !");
        }
        if (TableUtil.findColIndex(groupCols, idCol) >= 0) {
            throw new RuntimeException("idCol column should NOT be included in groupColNames !");
        }
    }

    /** Builds the select clause casting group/id columns to VARCHAR and feature columns to double. */
    private static String buildSelectClause(String[] groupCols, String idCol, String[] featureCols) {
        StringBuilder clause = new StringBuilder();
        for (String groupCol : groupCols) {
            clause.append(castExpr(groupCol, "VARCHAR")).append(", ");
        }
        clause.append(castExpr(idCol, "VARCHAR")).append(", ");
        for (int i = 0; i < featureCols.length; i++) {
            if (i > 0) {
                clause.append(", ");
            }
            clause.append(castExpr(featureCols[i], "double"));
        }
        return clause.toString();
    }

    /** Returns the expression that casts one back-quoted column to a type under its own name. */
    private static String castExpr(String col, String type) {
        return "cast(`" + col + "` as " + type + ") as `" + col + "`";
    }
}
