package com.alibaba.alink.operator.batch.statistics;

import com.alibaba.alink.common.AlinkGlobalConfiguration;
import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.TypeCollections;
import com.alibaba.alink.common.io.filesystem.copy.csv.CsvInputFormat;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Tensor;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.type.AlinkTypes;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil;
import com.alibaba.alink.operator.common.clustering.dbscan.DbscanConstant;
import com.alibaba.alink.operator.common.clustering.lda.LdaVariable;
import com.alibaba.alink.operator.common.nlp.WordCountUtil;
import com.alibaba.alink.operator.common.optim.subfunc.OptimVariable;
import com.alibaba.alink.operator.common.statistics.SomJni;
import com.alibaba.alink.params.statistics.SomParams;
import java.io.FileOutputStream;
import java.util.ArrayList;
import java.util.Date;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.operators.IterativeDataSet;
import org.apache.flink.api.java.operators.ReduceOperator;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.Table;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

@InputPorts(values = {@PortSpec(PortType.DATA)})
@OutputPorts(values = {@PortSpec(PortType.DATA), @PortSpec(PortType.DATA)})
@ParamSelectColumnSpec(name = "vectorCol", allowedTypeCollections = {TypeCollections.VECTOR_TYPES})
@NameCn("Som")
@NameEn("Som")
/* loaded from: input_file:com/alibaba/alink/operator/batch/statistics/SomBatchOp.class */
public final class SomBatchOp extends BatchOperator<SomBatchOp> implements SomParams<SomBatchOp> {
    /**
     * Column names of the main output table produced by {@code linkFrom}: a meta string, the neuron's
     * x and y grid indices, the neuron weight string (name taken from {@code OptimVariable.weights}) and
     * the per-neuron sample count (name taken from {@code WordCountUtil.COUNT_COL_NAME}).
     */
    public static final String[] COL_NAMES = {"meta", "xidx", "yidx", OptimVariable.weights, WordCountUtil.COUNT_COL_NAME};
    /** Column types of the main output table, position-aligned with {@link #COL_NAMES}. */
    public static final TypeInformation[] COL_TYPES = {AlinkTypes.STRING, AlinkTypes.LONG, AlinkTypes.LONG, AlinkTypes.STRING, AlinkTypes.LONG};
    // NOTE(review): not referenced anywhere else in this file; presumably a leftover switch for the
    // prediction side output, which linkFrom always builds - confirm before removing.
    public static final boolean DO_PREDICTION = true;
    private static final long serialVersionUID = -6014481798410706652L;

    /* loaded from: input_file:com/alibaba/alink/operator/batch/statistics/SomBatchOp$SomModel.class */
    /**
     * In-memory self-organizing map used for lookup after training.
     *
     * <p>The weight vectors of an {@code xdim x ydim} grid of neurons, each of length {@code vdim}, are
     * kept in a single flat {@code float} array (see {@link #getNeuronPos(int, int)} for the layout).
     * The model can additionally hold per-neuron hit counts ({@link #initCount()}) and a U-matrix
     * ({@link #createUMatrix()}).
     */
    public static class SomModel {
        private final int xdim;
        private final int ydim;
        private final int vdim;
        private final float[] weights;
        private long[][] counts;
        private float[][] umatrix;
        /** Reused output buffer handed to {@code SomJni.findBmuJava}. */
        private final int[] bmu0 = new int[2];
        private final SomJni somJni = new SomJni();

        /**
         * Creates a model with all weights set to zero.
         *
         * @param xdim number of neurons along x
         * @param ydim number of neurons along y
         * @param vdim length of every weight vector
         */
        public SomModel(int xdim, int ydim, int vdim) {
            this.xdim = xdim;
            this.ydim = ydim;
            this.vdim = vdim;
            this.weights = new float[xdim * ydim * vdim];
        }

        /** Sums the squared component differences of two {@code length}-long slices. */
        private static float squaredDistance(float[] a, int aOffset, float[] b, int bOffset, int length) {
            float sum = 0.0f;
            for (int k = 0; k < length; k++) {
                float diff = a[aOffset + k] - b[bOffset + k];
                sum += diff * diff;
            }
            return sum;
        }

        /**
         * Returns the offset of neuron {@code (x, y)} in the flat weight array:
         * {@code ((y * xdim) + x) * vdim}.
         */
        public int getNeuronPos(int x, int y) {
            return ((y * this.xdim) + x) * this.vdim;
        }

        /**
         * Overwrites the weights of neuron {@code (x, y)} with the values parsed from a tensor string.
         *
         * @throws RuntimeException if the parsed data does not have exactly {@code vdim} entries
         */
        public void setNeuron(int x, int y, String tensorStr) {
            double[] data = Tensor.parse(tensorStr).getData();
            if (data.length != this.vdim) {
                throw new RuntimeException("invalid data length: " + data.length);
            }
            storeNeuron(x, y, data);
        }

        /**
         * Loads neuron weights from {@code (x, y, denseVectorString)} tuples.
         *
         * @throws RuntimeException if any parsed vector does not have exactly {@code vdim} entries
         */
        public void init(List<Tuple3<Long, Long, String>> neurons) {
            for (Tuple3<Long, Long, String> neuron : neurons) {
                int x = neuron.f0.intValue();
                int y = neuron.f1.intValue();
                double[] data = VectorUtil.parseDense(neuron.f2).getData();
                if (data.length != this.vdim) {
                    throw new RuntimeException("Invalid data length: " + data.length);
                }
                storeNeuron(x, y, data);
            }
        }

        /** Copies the first {@code vdim} values into the slot of neuron {@code (x, y)}, narrowing to float. */
        private void storeNeuron(int x, int y, double[] values) {
            int base = getNeuronPos(x, y);
            for (int k = 0; k < this.vdim; k++) {
                this.weights[base + k] = (float) values[k];
            }
        }

        /** Allocates (or resets) the {@code xdim x ydim} hit-count matrix to all zeros. */
        public void initCount() {
            this.counts = new long[this.xdim][this.ydim];
        }

        /**
         * Adds {@code delta} to the hit count of the neuron at {@code (bmu[0], bmu[1])}.
         * {@link #initCount()} must have been called first.
         */
        public void increaseCount(int[] bmu, long delta) {
            this.counts[bmu[0]][bmu[1]] += delta;
        }

        /** Returns the internal hit-count matrix (not a copy); {@code null} before {@link #initCount()}. */
        public long[][] getCounts() {
            return this.counts;
        }

        /**
         * Looks up the best matching unit for {@code sample} by delegating to
         * {@code SomJni.findBmuJava}, passing a freshly allocated scratch array of
         * {@code xdim * ydim} floats on every call.
         *
         * @param sample input vector
         * @param bmuOut if non-null, receives the two entries of the internal BMU buffer
         * @return the value returned by {@code findBmuJava} (presumably the distance to the best
         *     matching unit - TODO confirm against SomJni)
         */
        public float findBMU(float[] sample, int[] bmuOut) {
            float[] scratch = new float[this.xdim * this.ydim];
            float result = this.somJni.findBmuJava(this.weights, scratch, sample, this.bmu0, this.xdim, this.ydim, this.vdim);
            if (bmuOut != null) {
                bmuOut[0] = this.bmu0[0];
                bmuOut[1] = this.bmu0[1];
            }
            return result;
        }

        /**
         * Exports every neuron as {@code (x, y, "w0,w1,...")}, iterating x in the outer loop and y in
         * the inner loop.
         */
        public List<Tuple3<Long, Long, String>> getWeights() {
            List<Tuple3<Long, Long, String>> result = new ArrayList<>(this.xdim * this.ydim);
            for (int x = 0; x < this.xdim; x++) {
                for (int y = 0; y < this.ydim; y++) {
                    int base = getNeuronPos(x, y);
                    StringBuilder text = new StringBuilder();
                    for (int k = 0; k < this.vdim; k++) {
                        if (k > 0) {
                            text.append(",");
                        }
                        text.append(this.weights[base + k]);
                    }
                    result.add(Tuple3.of(Long.valueOf(x), Long.valueOf(y), text.toString()));
                }
            }
            return result;
        }

        /**
         * Builds the U-matrix: for every neuron, the square root of the mean squared distance to its
         * in-grid diagonal neighbours (offsets of -1/+1 in both x and y; the offset loops step by 2,
         * so axis-aligned neighbours are not visited).
         *
         * <p>A neuron without any in-grid diagonal neighbour (e.g. when {@code xdim} or {@code ydim}
         * is 1) gets the value 0 instead of the NaN that a 0/0 division would produce.
         */
        public void createUMatrix() {
            this.umatrix = new float[this.xdim][this.ydim];
            for (int x = 0; x < this.xdim; x++) {
                for (int y = 0; y < this.ydim; y++) {
                    float sum = 0.0f;
                    int neighbours = 0;
                    for (int dx = -1; dx <= 1; dx += 2) {
                        for (int dy = -1; dy <= 1; dy += 2) {
                            int nx = x + dx;
                            int ny = y + dy;
                            if (nx >= 0 && nx < this.xdim && ny >= 0 && ny < this.ydim) {
                                sum += squaredDistance(this.weights, getNeuronPos(x, y), this.weights, getNeuronPos(nx, ny), this.vdim);
                                neighbours++;
                            }
                        }
                    }
                    // Guard the division: without neighbours sum / neighbours would be 0/0 = NaN.
                    this.umatrix[x][y] = neighbours == 0 ? 0.0f : (float) Math.sqrt(sum / neighbours);
                }
            }
        }

        /** Returns the U-matrix entry of neuron {@code (x, y)}; requires {@link #createUMatrix()}. */
        public float getUMatrixValue(int x, int y) {
            return this.umatrix[x][y];
        }

        /**
         * Dumps the first weight component of every neuron to {@code firstPath} and the second
         * component to {@code secondPath}. Each file holds one line per y index with the values for
         * all x indices separated by a single space; lines end with
         * {@code CsvInputFormat.DEFAULT_LINE_DELIMITER}. Text is encoded with the platform default
         * charset.
         *
         * <p>Reads components 0 and 1 of every neuron, so it assumes {@code vdim >= 2}.
         *
         * @throws Exception if a file cannot be opened or written
         */
        public void writeToFile(String firstPath, String secondPath) throws Exception {
            StringBuilder first = new StringBuilder();
            StringBuilder second = new StringBuilder();
            for (int y = 0; y < this.ydim; y++) {
                for (int x = 0; x < this.xdim; x++) {
                    if (x > 0) {
                        first.append(" ");
                        second.append(" ");
                    }
                    int base = getNeuronPos(x, y);
                    first.append(this.weights[base]);
                    second.append(this.weights[base + 1]);
                }
                first.append(CsvInputFormat.DEFAULT_LINE_DELIMITER);
                second.append(CsvInputFormat.DEFAULT_LINE_DELIMITER);
            }
            // try-with-resources closes each stream even when write() fails.
            try (FileOutputStream out = new FileOutputStream(firstPath)) {
                out.write(first.toString().getBytes());
            }
            try (FileOutputStream out = new FileOutputStream(secondPath)) {
                out.write(second.toString().getBytes());
            }
        }

        /**
         * Writes {@code xs.size()} lines of the form {@code "x y"} to {@code path}, pairing the lists
         * by index; lines end with {@code CsvInputFormat.DEFAULT_LINE_DELIMITER}.
         *
         * @throws Exception if the file cannot be opened or written, or {@code ys} is shorter than
         *     {@code xs}
         */
        public void writePointsToFile(String path, List<Double> xs, List<Double> ys) throws Exception {
            int size = xs.size();
            StringBuilder text = new StringBuilder();
            for (int k = 0; k < size; k++) {
                text.append(xs.get(k)).append(" ").append(ys.get(k)).append(CsvInputFormat.DEFAULT_LINE_DELIMITER);
            }
            // try-with-resources closes the stream even when write() fails.
            try (FileOutputStream out = new FileOutputStream(path)) {
                out.write(text.toString().getBytes());
            }
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/batch/statistics/SomBatchOp$SomSolver.class */
    /**
     * Training-side state of the self-organizing map: the flat weight array of an
     * {@code xdim x ydim} neuron grid (vectors of length {@code vdim}) together with the decaying
     * learning rate / sigma schedule. The actual weight update is delegated to
     * {@code SomJni.updateBatchJava}.
     */
    public static class SomSolver {
        private int xdim;
        private int ydim;
        private int vdim;
        private double learnRate;
        private double sigma;
        private long currStepNo;
        private long maxStepNo;
        private float[] weights;
        private SomJni somJni = new SomJni();

        /**
         * Creates a solver with zeroed weights and the step counter at 0. When {@code maxStepNo} is
         * positive and process info printing is enabled, the configuration is echoed to stdout.
         *
         * @param xdim number of neurons along x
         * @param ydim number of neurons along y
         * @param vdim length of every weight vector
         * @param learnRate initial learning rate
         * @param sigma initial sigma
         * @param maxStepNo step count used as the denominator of the decay schedule
         */
        public SomSolver(int xdim, int ydim, int vdim, double learnRate, double sigma, long maxStepNo) {
            this.xdim = xdim;
            this.ydim = ydim;
            this.vdim = vdim;
            this.learnRate = learnRate;
            this.sigma = sigma;
            this.maxStepNo = maxStepNo;
            this.currStepNo = 0L;
            this.weights = new float[xdim * ydim * vdim];
            if (maxStepNo > 0 && AlinkGlobalConfiguration.isPrintProcessInfo()) {
                System.out.println(String.format("xdim=%d,ydim=%d,vdim=%d,learnRate=%f,sigma=%f,maxStepNo=%d", xdim, ydim, vdim, learnRate, sigma, maxStepNo));
            }
        }

        /** Decays {@code initial} as {@code initial / (1 + 2 * step / maxStep)}. */
        private static double decayFunction(double initial, double step, double maxStep) {
            double progress = (2.0d * step) / maxStep;
            return initial / (1.0d + progress);
        }

        /**
         * Returns the offset of neuron {@code (x, y)} in a flat weight array laid out as
         * {@code ((y * xdim) + x) * vdim}. The {@code ydim} argument is accepted but not used.
         */
        public static int getNeuronPos(int x, int y, int xdim, int ydim, int vdim) {
            return ((y * xdim) + x) * vdim;
        }

        /**
         * Loads neuron weights from {@code (x, y, vectorString)} tuples.
         *
         * @throws RuntimeException if any parsed vector does not have exactly {@code vdim} entries
         */
        public void init(List<Tuple3<Long, Long, String>> neurons) {
            for (Tuple3<Long, Long, String> neuron : neurons) {
                int x = neuron.f0.intValue();
                int y = neuron.f1.intValue();
                double[] data = VectorUtil.getDenseVector(neuron.f2).getData();
                if (data.length != this.vdim) {
                    throw new RuntimeException("Invalid data length: " + data.length);
                }
                int base = getNeuronPos(x, y, this.xdim, this.ydim, this.vdim);
                for (int k = 0; k < this.vdim; k++) {
                    this.weights[base + k] = (float) data[k];
                }
            }
        }

        /**
         * Feeds {@code count} samples from {@code batch} to {@code SomJni.updateBatchJava} using the
         * learning rate and sigma decayed for the current step, then advances the step counter by
         * {@code count}.
         */
        public void updateBatch(float[] batch, int count) {
            float decayedRate = (float) decayFunction(this.learnRate, this.currStepNo, this.maxStepNo);
            float decayedSigma = (float) decayFunction(this.sigma, this.currStepNo, this.maxStepNo);
            this.somJni.updateBatchJava(this.weights, batch, count, decayedRate, decayedSigma, this.xdim, this.ydim, this.vdim);
            this.currStepNo += count;
        }

        /**
         * Exports every neuron as {@code (x, y, "w0,w1,...")}, iterating x in the outer loop and y in
         * the inner loop.
         */
        public List<Tuple3<Long, Long, String>> getWeights() {
            List<Tuple3<Long, Long, String>> result = new ArrayList<>(this.xdim * this.ydim);
            for (int x = 0; x < this.xdim; x++) {
                for (int y = 0; y < this.ydim; y++) {
                    result.add(Tuple3.of(Long.valueOf(x), Long.valueOf(y), formatNeuron(x, y)));
                }
            }
            return result;
        }

        /** Joins the weight components of neuron {@code (x, y)} with commas. */
        private String formatNeuron(int x, int y) {
            int base = getNeuronPos(x, y, this.xdim, this.ydim, this.vdim);
            StringBuilder text = new StringBuilder();
            for (int k = 0; k < this.vdim; k++) {
                if (k > 0) {
                    text.append(",");
                }
                text.append(this.weights[base + k]);
            }
            return text.toString();
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/batch/statistics/SomBatchOp$SomTask.class */
    /**
     * Map-partition function that runs one training pass over its partition and emits the updated
     * neuron weights as {@code (x, y, weightString)} tuples.
     *
     * <p>The solver is created in {@link #open(Configuration)} only while the transient
     * {@code solver} field is still {@code null} (presumably so one solver instance is reused when
     * {@code open} is called again for later supersteps - TODO confirm against Flink's iteration
     * lifecycle).
     */
    public static class SomTask extends RichMapPartitionFunction<Row, Tuple3<Long, Long, String>> {
        private static final long serialVersionUID = 6117856294526477050L;
        Params params;
        transient SomSolver solver = null;

        public SomTask(Params params) {
            this.params = params;
        }

        /**
         * Logs the superstep (when process info printing is on) and, if no solver exists yet, builds
         * one from the operator params, the broadcast initial model and the broadcast sample count.
         *
         * @throws RuntimeException if the broadcast model does not contain {@code xdim * ydim} neurons
         */
        public void open(Configuration configuration) throws Exception {
            super.open(configuration);
            if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
                System.out.println("\n ** step " + getIterationRuntimeContext().getSuperstepNumber());
                System.out.println(new Date().toString() + ": start step ...");
            }
            if (this.solver != null) {
                return;
            }
            int numIters = (Integer) this.params.get(SomParams.NUM_ITERS);
            int xdim = (Integer) this.params.get(SomParams.XDIM);
            int ydim = (Integer) this.params.get(SomParams.YDIM);
            int vdim = (Integer) this.params.get(SomParams.VDIM);
            double learnRate = (Double) this.params.get(SomParams.LEARN_RATE);
            double sigma = (Double) this.params.get(SomParams.SIGMA);
            List<Tuple3<Long, Long, String>> initialModel = getRuntimeContext().getBroadcastVariable(LdaVariable.initModel);
            if (initialModel.size() != xdim * ydim) {
                throw new RuntimeException("unexpected");
            }
            List<Long> sampleCounts = getRuntimeContext().getBroadcastVariable("numSamples");
            // The decay schedule spans numIters passes over all samples.
            long maxSteps = numIters * sampleCounts.get(0).longValue();
            this.solver = new SomSolver(xdim, ydim, vdim, learnRate, sigma, maxSteps);
            this.solver.init(initialModel);
        }

        /** Logs the end of the step when process info printing is enabled. */
        public void close() throws Exception {
            if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
                System.out.println(new Date().toString() + ": close step ...");
            }
        }

        /**
         * Copies the vectors in field 0 of the incoming rows into a float buffer, hands the buffer
         * to the solver every 65536 samples (plus once for the remainder), and finally collects the
         * solver's current weights.
         */
        public void mapPartition(Iterable<Row> iterable, Collector<Tuple3<Long, Long, String>> collector) throws Exception {
            final int vdim = (Integer) this.params.get(SomParams.VDIM);
            final int batchCapacity = 65536;
            float[] buffer = new float[batchCapacity * vdim];
            int buffered = 0;
            for (Row row : iterable) {
                double[] data = VectorUtil.getDenseVector(row.getField(0)).getData();
                int base = buffered * vdim;
                for (int k = 0; k < vdim; k++) {
                    buffer[base + k] = (float) data[k];
                }
                buffered++;
                if (buffered >= batchCapacity) {
                    this.solver.updateBatch(buffer, buffered);
                    buffered = 0;
                }
            }
            if (buffered > 0) {
                this.solver.updateBatch(buffer, buffered);
            }
            for (Tuple3<Long, Long, String> neuron : this.solver.getWeights()) {
                collector.collect(neuron);
            }
        }
    }

    /** Creates the operator with a new, empty {@link Params} instance. */
    public SomBatchOp() {
        this(new Params());
    }

    /**
     * Creates the operator with the given parameters, which are passed to the superclass constructor.
     *
     * @param params operator parameters
     */
    public SomBatchOp(Params params) {
        super(params);
    }

    /* JADX WARN: Can't rename method to resolve collision */
    /**
     * Trains a self-organizing map on the vector column of the first input and wires up two results.
     *
     * <ol>
     *   <li>Counts the input rows ("numSamples" broadcast).</li>
     *   <li>Picks {@code xdim * ydim} rows by reservoir sampling as the initial neurons
     *       ("init_model", parallelism 1).</li>
     *   <li>Runs {@code numIters} bulk iterations of {@link SomTask} ("som_train", parallelism 1).</li>
     *   <li>Main output ({@link #COL_NAMES}): one row per neuron with the meta string
     *       {@code "xdim,ydim,vdim,r"}, grid indices, weight string and the number of input rows whose
     *       best matching unit is that neuron.</li>
     *   <li>Side output 0: every input row with two extra LONG columns {@code xidx}, {@code yidx}
     *       holding its best matching unit.</li>
     * </ol>
     *
     * <p>The side output works on complete input rows, so the vector is read from the position of
     * the configured vector column in the input schema rather than from field 0.
     *
     * @param batchOperatorArr upstream operators; only the first one is used
     * @return this operator
     */
    @Override // com.alibaba.alink.operator.batch.BatchOperator
    public SomBatchOp linkFrom(BatchOperator<?>... batchOperatorArr) {
        BatchOperator<?> in = checkAndGetFirst(batchOperatorArr);
        String vectorCol = getVectorCol();
        int numIters = getNumIters().intValue();
        final int xdim = getXdim().intValue();
        final int ydim = getYdim().intValue();
        final int vdim = getVdim().intValue();
        final int numNeurons = xdim * ydim;
        final String meta = String.format("%d,%d,%d,r", xdim, ydim, vdim);
        // NOTE(review): the evaluation flag is read but its value is never used in this method;
        // the call is kept so that any parameter lookup failure still surfaces here.
        getEvaluation().booleanValue();
        final int vectorColIdx = resolveVectorColIndex(in.getColNames(), vectorCol);

        // Total number of input rows, broadcast to the init and training tasks.
        DataSet<Long> numSamples = in.getDataSet()
            .mapPartition(new MapPartitionFunction<Row, Long>() {
                private static final long serialVersionUID = -4852925590649190739L;

                public void mapPartition(Iterable<Row> rows, Collector<Long> out) throws Exception {
                    long rowCount = 0;
                    for (Row ignored : rows) {
                        rowCount++;
                    }
                    out.collect(Long.valueOf(rowCount));
                }
            })
            .reduce(new ReduceFunction<Long>() {
                private static final long serialVersionUID = -6343518193952236485L;

                public Long reduce(Long left, Long right) throws Exception {
                    return Long.valueOf(left.longValue() + right.longValue());
                }
            });

        // Initial model: xdim * ydim input vectors chosen by reservoir sampling.
        IterativeDataSet<Tuple3<Long, Long, String>> loop = in.select(vectorCol).getDataSet()
            .mapPartition(new RichMapPartitionFunction<Row, Tuple3<Long, Long, String>>() {
                private static final long serialVersionUID = -1154161394939821199L;
                List<Row> selectedRows;

                public void open(Configuration configuration) throws Exception {
                    this.selectedRows = new ArrayList<>(numNeurons);
                    List<Long> sampleCounts = getRuntimeContext().getBroadcastVariable("numSamples");
                    long total = sampleCounts.get(0).longValue();
                    if (total < numNeurons) {
                        throw new RuntimeException("xdim * ydim > num training samples");
                    }
                    if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
                        System.out.println("Initializing model, num training samples: " + total);
                    }
                }

                public void mapPartition(Iterable<Row> rows, Collector<Tuple3<Long, Long, String>> out) throws Exception {
                    Random random = new Random();
                    int seen = 0;
                    for (Row row : rows) {
                        if (seen < numNeurons) {
                            this.selectedRows.add(row);
                        } else if (random.nextDouble() < ((double) numNeurons) / ((double) (seen + 1))) {
                            // Replace a uniformly chosen reservoir slot.
                            this.selectedRows.set(random.nextInt(numNeurons), row);
                        }
                        seen++;
                    }
                    int next = 0;
                    for (int x = 0; x < xdim; x++) {
                        for (int y = 0; y < ydim; y++) {
                            Object field = this.selectedRows.get(next).getField(0);
                            if (field instanceof DenseVector || field instanceof SparseVector) {
                                field = VectorUtil.serialize(field);
                            }
                            out.collect(Tuple3.of(Long.valueOf(x), Long.valueOf(y), (String) field));
                            next++;
                        }
                    }
                }
            })
            .withBroadcastSet(numSamples, "numSamples")
            .setParallelism(1)
            .name("init_model")
            .iterate(numIters)
            .setParallelism(1);

        // One SomTask pass per iteration; its output becomes the model of the next iteration.
        DataSet<Tuple3<Long, Long, String>> trainedModel = loop.closeWith(
            in.select(vectorCol).getDataSet()
                .mapPartition(new SomTask(getParams()))
                .withBroadcastSet(loop, LdaVariable.initModel)
                .withBroadcastSet(numSamples, "numSamples")
                .setParallelism(1)
                .name("som_train"));

        // Main output: neuron weights plus the number of samples mapped to each neuron.
        DataSet<Row> neuronRows = in.select(vectorCol).getDataSet()
            .mapPartition(new RichMapPartitionFunction<Row, Row>() {
                private static final long serialVersionUID = -7426117483145285343L;

                public void open(Configuration configuration) throws Exception {
                    if (getRuntimeContext().getNumberOfParallelSubtasks() != 1) {
                        throw new RuntimeException("parallelism should be 1");
                    }
                }

                public void mapPartition(Iterable<Row> rows, Collector<Row> out) throws Exception {
                    List<Tuple3<Long, Long, String>> neurons = getRuntimeContext().getBroadcastVariable("somModel");
                    if (neurons.size() != numNeurons) {
                        throw new RuntimeException("unexpected");
                    }
                    SomModel somModel = new SomModel(xdim, ydim, vdim);
                    somModel.init(neurons);
                    somModel.initCount();
                    float[] sample = new float[vdim];
                    int[] bmu = new int[2];
                    for (Row row : rows) {
                        double[] data = VectorUtil.getDenseVector(row.getField(0)).getData();
                        for (int k = 0; k < vdim; k++) {
                            sample[k] = (float) data[k];
                        }
                        somModel.findBMU(sample, bmu);
                        somModel.increaseCount(bmu, 1L);
                    }
                    long[][] counts = somModel.getCounts();
                    for (Tuple3<Long, Long, String> neuron : somModel.getWeights()) {
                        long hits = counts[neuron.f0.intValue()][neuron.f1.intValue()];
                        out.collect(Row.of(new Object[]{meta, neuron.f0, neuron.f1, neuron.f2, Long.valueOf(hits)}));
                    }
                }
            })
            .withBroadcastSet(trainedModel, "somModel")
            .setParallelism(1)
            .name(DbscanConstant.COUNT);
        setOutput(neuronRows, COL_NAMES, COL_TYPES);

        // Side output: every input row extended with the grid position of its best matching unit.
        DataSet<Row> predictedRows = in.getDataSet()
            .map(new RichMapFunction<Row, Row>() {
                private static final long serialVersionUID = 5561628297750641436L;
                transient SomModel model;

                public void open(Configuration configuration) throws Exception {
                    List<Tuple3<Long, Long, String>> neurons = getRuntimeContext().getBroadcastVariable("somModel");
                    if (neurons.size() != numNeurons) {
                        throw new RuntimeException("unexpected");
                    }
                    this.model = new SomModel(xdim, ydim, vdim);
                    this.model.init(neurons);
                }

                public Row map(Row row) throws Exception {
                    float[] sample = new float[vdim];
                    int[] bmu = new int[2];
                    // Rows here carry ALL input columns, so the vector must be read from the
                    // vector column's own position, not from field 0.
                    double[] data = VectorUtil.getDenseVector(row.getField(vectorColIdx)).getData();
                    for (int k = 0; k < vdim; k++) {
                        sample[k] = (float) data[k];
                    }
                    this.model.findBMU(sample, bmu);
                    int arity = row.getArity();
                    Row extended = new Row(arity + 2);
                    for (int k = 0; k < arity; k++) {
                        extended.setField(k, row.getField(k));
                    }
                    extended.setField(arity, Long.valueOf(bmu[0]));
                    extended.setField(arity + 1, Long.valueOf(bmu[1]));
                    return extended;
                }
            })
            .withBroadcastSet(trainedModel, "somModel");
        String[] predictedColNames = (String[]) ArrayUtils.addAll(in.getColNames(), new String[]{"xidx", "yidx"});
        TypeInformation<?>[] predictedColTypes = (TypeInformation<?>[]) ArrayUtils.addAll(in.getColTypes(), new TypeInformation[]{AlinkTypes.LONG, AlinkTypes.LONG});
        setSideOutputTables(new Table[]{DataSetConversionUtil.toTable(getMLEnvironmentId(), predictedRows, predictedColNames, predictedColTypes)});
        return this;
    }

    /**
     * Finds the position of {@code vectorCol} in {@code colNames}: exact match first, then a
     * case-insensitive match. Returns 0 (the position that was previously always used) when the
     * name cannot be resolved, so inputs that worked before keep working.
     */
    private static int resolveVectorColIndex(String[] colNames, String vectorCol) {
        int exact = ArrayUtils.indexOf(colNames, vectorCol);
        if (exact >= 0) {
            return exact;
        }
        if (colNames != null && vectorCol != null) {
            for (int k = 0; k < colNames.length; k++) {
                if (vectorCol.equalsIgnoreCase(colNames[k])) {
                    return k;
                }
            }
        }
        return 0;
    }

    /**
     * Bridge overload (marked bridge/synthetic in the signature below) that simply forwards to
     * {@code linkFrom(BatchOperator<?>...)} and returns its result.
     * NOTE(review): looks like a decompiler-emitted compiler bridge rather than hand-written code -
     * confirm before relying on it as a separate entry point.
     */
    @Override // com.alibaba.alink.operator.batch.BatchOperator
    public /* bridge */ /* synthetic */ SomBatchOp linkFrom(BatchOperator[] batchOperatorArr) {
        return linkFrom((BatchOperator<?>[]) batchOperatorArr);
    }
}
