package com.alibaba.alink.operator.common.classification.ann;

import com.alibaba.alink.common.exceptions.AkIllegalModelException;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.common.optim.Lbfgs;
import com.alibaba.alink.operator.common.optim.objfunc.OptimObjFunc;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.Random;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.ParamInfoFactory;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.util.Collector;

/* loaded from: input_file:com/alibaba/alink/operator/common/classification/ann/FeedForwardTrainer.class */
/**
 * Trains the weight vector of a network described by a {@link Topology}.
 *
 * <p>Training samples are grouped into blocks by a {@link Stacker}, wrapped in an
 * {@link AnnObjFunc} objective and optimized with {@link Lbfgs}. The starting point of the
 * optimization is, in order of precedence: the initial model passed to {@link #train}, the
 * initial weights passed to the constructor, or randomly generated weights.
 */
public class FeedForwardTrainer implements Serializable {
    private static final long serialVersionUID = -1664569158355844836L;

    /** Network layout; supplies the expected length of the weight vector. */
    private final Topology topology;
    /** Input size forwarded to {@link Stacker} and {@link AnnObjFunc}. */
    private final int inputSize;
    /** Output size forwarded to {@link Stacker} and {@link AnnObjFunc}. */
    private final int outputSize;
    /** Maximum number of samples handed to the {@link Stacker} in one call. */
    private final int blockSize;
    /** One-hot label flag forwarded to {@link Stacker} and {@link AnnObjFunc}. */
    private final boolean onehotLabel;
    /** Optional user-supplied starting weights; {@code null} means "generate random weights". */
    private final DenseVector initialWeights;

    /**
     * Creates a trainer.
     *
     * @param topology       network layout
     * @param inputSize      input size passed on to the stacker and the objective function
     * @param outputSize     output size passed on to the stacker and the objective function
     * @param onehotLabel    one-hot label flag passed on to the stacker and the objective function
     * @param blockSize      maximum number of samples stacked together into one block
     * @param initialWeights starting weights, or {@code null} to initialize randomly; when given,
     *                       its size must equal {@code topology.getWeightSize()} (checked lazily,
     *                       when the initial model is actually built)
     */
    public FeedForwardTrainer(Topology topology, int inputSize, int outputSize, boolean onehotLabel, int blockSize, DenseVector initialWeights) {
        this.topology = topology;
        this.inputSize = inputSize;
        this.outputSize = outputSize;
        this.onehotLabel = onehotLabel;
        this.blockSize = blockSize;
        this.initialWeights = initialWeights;
    }

    /**
     * Builds the training plan and returns the optimized weights.
     *
     * <p>Side effect: the {@code numSearchStep} parameter is set to {@code 3} on the supplied
     * {@code optimizationParams} object, replacing any value the caller may have stored under
     * that name.
     *
     * @param data               training samples as (label, features) pairs
     * @param initialModel       starting weights for the optimizer, or {@code null} to fall back
     *                           to the constructor's initial weights / random initialization
     * @param optimizationParams parameters handed to both the objective function and the optimizer
     * @return a data set holding the weight vector produced by the optimizer
     * @throws AkIllegalModelException if {@code initialModel} is {@code null} and the constructor's
     *                                 initial weights do not match the topology's weight size
     */
    public DataSet<DenseVector> train(DataSet<Tuple2<Double, DenseVector>> data, DataSet<DenseVector> initialModel, Params optimizationParams) {
        ParamInfo<Integer> numSearchStep = ParamInfoFactory
            .createParamInfo("numSearchStep", Integer.class)
            .setDescription("num search step")
            .setRequired()
            .build();

        DataSet<DenseVector> initCoef = initialModel != null ? initialModel : initModel(data, this.topology);
        DataSet<Tuple3<Double, Double, Vector>> trainData =
            stack(data, this.blockSize, this.inputSize, this.outputSize, this.onehotLabel);

        // Must happen before the objective function and the optimizer are created, since both
        // receive this params object.
        optimizationParams.set(numSearchStep, 3);

        DataSet<OptimObjFunc> objFunc = data.getExecutionEnvironment().fromElements(
            new OptimObjFunc[] {
                new AnnObjFunc(this.topology, this.inputSize, this.outputSize, this.onehotLabel, optimizationParams)
            });
        // NOTE(review): the third Lbfgs argument is fed with inputSize rather than the weight
        // size; kept as is -- confirm against the Lbfgs contract.
        DataSet<Integer> coefDim =
            BatchOperator.getExecutionEnvironmentFromDataSets(data).fromElements(new Integer[] {this.inputSize});

        Lbfgs optimizer = new Lbfgs(objFunc, trainData, coefDim, optimizationParams);
        optimizer.initCoefWith(initCoef);

        // The optimizer yields (weights, double[]) pairs; only the weights are returned.
        return optimizer.optimize().map(new MapFunction<Tuple2<DenseVector, double[]>, DenseVector>() {
            private static final long serialVersionUID = -6247802998516251320L;

            @Override
            public DenseVector map(Tuple2<DenseVector, double[]> result) {
                return result.f0;
            }
        });
    }

    /**
     * Groups the samples of every partition into blocks of at most {@code blockSize} elements and
     * converts each block with {@link Stacker#stack}.
     *
     * @param data        training samples as (label, features) pairs
     * @param blockSize   maximum number of samples per block
     * @param inputSize   input size given to the {@link Stacker}
     * @param outputSize  output size given to the {@link Stacker}
     * @param onehotLabel one-hot label flag given to the {@link Stacker}
     * @return one stacked tuple per block
     */
    private static DataSet<Tuple3<Double, Double, Vector>> stack(DataSet<Tuple2<Double, DenseVector>> data, final int blockSize, final int inputSize, final int outputSize, final boolean onehotLabel) {
        return data.mapPartition(new MapPartitionFunction<Tuple2<Double, DenseVector>, Tuple3<Double, Double, Vector>>() {
            private static final long serialVersionUID = -4065550804759190453L;

            @Override
            public void mapPartition(Iterable<Tuple2<Double, DenseVector>> samples, Collector<Tuple3<Double, Double, Vector>> out) {
                // Fixed-size buffer that is reused for every block of this partition.
                ArrayList<Tuple2<Double, DenseVector>> buffer = new ArrayList<>(blockSize);
                for (int k = 0; k < blockSize; k++) {
                    buffer.add(null);
                }

                int filled = 0;
                Stacker stacker = new Stacker(inputSize, outputSize, onehotLabel);
                for (Tuple2<Double, DenseVector> sample : samples) {
                    buffer.set(filled, sample);
                    filled++;
                    if (filled >= blockSize) {
                        out.collect(stacker.stack(buffer, filled));
                        filled = 0;
                    }
                }

                // Flush the trailing, partially filled block. Slots at index >= filled still hold
                // samples of the previous block; the count passed along marks the valid prefix.
                if (filled > 0) {
                    out.collect(stacker.stack(buffer, filled));
                }
            }
        }).name("stack_data");
    }

    /**
     * Produces the single-element data set holding the optimizer's starting weights.
     *
     * <p>Without user-supplied initial weights, a vector of {@code topology.getWeightSize()}
     * Gaussian values scaled by a fixed factor is generated from a fixed seed, so repeated runs
     * start from the same weights.
     *
     * @param data     any data set of the job; only used to obtain the execution environment
     * @param topology network layout supplying the weight vector size
     * @return data set with exactly one weight vector
     * @throws AkIllegalModelException if user-supplied initial weights do not match the
     *                                 topology's weight size
     */
    private DataSet<DenseVector> initModel(DataSet<?> data, final Topology topology) {
        if (this.initialWeights == null) {
            return BatchOperator.getExecutionEnvironmentFromDataSets(data).fromElements(new Integer[] {0}).map(new RichMapFunction<Integer, DenseVector>() {
                private static final long serialVersionUID = 8668081633098768854L;
                /** Scale applied to every Gaussian sample. */
                final double initStdev = 0.05d;
                /** Fixed seed of the random generator. */
                final long seed = 1;
                /** Created in {@link #open}; not shipped with the serialized function. */
                transient Random random;

                @Override
                public void open(Configuration configuration) {
                    this.random = new Random(this.seed);
                }

                @Override
                public DenseVector map(Integer ignored) {
                    DenseVector weights = DenseVector.zeros(topology.getWeightSize());
                    for (int k = 0; k < weights.size(); k++) {
                        weights.set(k, this.random.nextGaussian() * this.initStdev);
                    }
                    return weights;
                }
            }).name("init_weights");
        }
        if (this.initialWeights.size() != topology.getWeightSize()) {
            throw new AkIllegalModelException("Invalid initial weights, size mismatch");
        }
        return BatchOperator.getExecutionEnvironmentFromDataSets(data).fromElements(new DenseVector[] {this.initialWeights});
    }
}
