package com.alibaba.alink.operator.common.recommendation;

import com.alibaba.alink.common.exceptions.AkIllegalDataException;
import com.alibaba.alink.common.exceptions.AkPreconditions;
import com.alibaba.alink.common.linalg.DenseMatrix;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.NormalEquation;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.dataproc.MultiStringIndexerPredictBatchOp;
import com.alibaba.alink.operator.batch.dataproc.MultiStringIndexerTrainBatchOp;
import com.alibaba.alink.operator.batch.sql.FullOuterJoinBatchOp;
import com.alibaba.alink.operator.batch.sql.LeftOuterJoinBatchOp;
import com.alibaba.alink.operator.batch.sql.UnionBatchOp;
import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.params.recommendation.AlsImplicitTrainParams;
import com.alibaba.alink.params.recommendation.AlsTrainParams;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Random;
import org.apache.flink.api.common.functions.FilterFunction;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichCoGroupFunction;
import org.apache.flink.api.common.functions.RichFilterFunction;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.functions.RichJoinFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.operators.Order;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.operators.FlatMapOperator;
import org.apache.flink.api.java.operators.IterativeDataSet;
import org.apache.flink.api.java.operators.MapOperator;
import org.apache.flink.api.java.operators.Operator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/alibaba/alink/operator/common/recommendation/HugeMfAlsImpl.class */
public class HugeMfAlsImpl {
    private static final Logger LOG = LoggerFactory.getLogger(HugeMfAlsImpl.class);

    /* loaded from: input_file:com/alibaba/alink/operator/common/recommendation/HugeMfAlsImpl$AlsTrain.class */
    public static class AlsTrain {
        // Rank of the factorization: the length of every user/item factor vector (see initFactors).
        private final int numFactors;
        // Number of ALS rounds; forwarded to updateFactors, whose mini-batch filter stops emitting once its
        // ALS step counter reaches this value.
        private final int numIters;
        // Regularization weight. NOTE(review): not read in the visible part of this class; presumably handed to
        // UpdateFactorsFunc further below — confirm.
        private final double lambda;
        // Whether the implicit-feedback objective is used. NOTE(review): not read in the visible code — confirm.
        private final boolean implicitPrefs;
        // Confidence scaling for implicit feedback. NOTE(review): not read in the visible code — confirm.
        private final double alpha;
        // Lower bound on the number of mini-batches per side (see DataProfile.decideUserMiniBatches).
        private final int numMiniBatches;
        // Non-negativity flag forwarded to updateFactors (presumably ending up in NormalEquation.solve).
        private final boolean nonnegative;

        /* loaded from: input_file:com/alibaba/alink/operator/common/recommendation/HugeMfAlsImpl$AlsTrain$DataProfile.class */
        /**
         * Global statistics of the rating graph, used to decide how many mini-batches each half-step of ALS is
         * split into.
         */
        public static class DataProfile implements Serializable {
            private static final long serialVersionUID = 1976492491732644585L;

            /**
             * Budget, in factor entries (number of neighbor nodes times {@code numFactors}), above which the
             * neighbor factors needed by one side are split across several mini-batches: 200 * 1024 * 1024.
             */
            static final long BATCH_VOLUME = 200L * 1024L * 1024L;

            /** Parallelism value recorded when the profile was built. */
            public long parallelism;
            /** Total number of ratings, each rating counted once. */
            public long numSamples;
            /** Number of nodes with identity 0 (users). */
            public long numUsers;
            /** Number of nodes with identity other than 0 (items). */
            public long numItems;
            /** Number of mini-batches used when updating user factors. */
            public int numUserBatches;
            /** Number of mini-batches used when updating item factors. */
            public int numItemBatches;

            /**
             * Fills {@link #numUserBatches} and {@link #numItemBatches} from the counts held by this profile.
             *
             * @param numFactors        rank of the factor vectors
             * @param numTasks          number of parallel tasks the work is spread over
             * @param minNumMiniBatches lower bound for both results
             */
            void decideNumMiniBatches(int numFactors, int numTasks, int minNumMiniBatches) {
                // Updating one side needs the factors of the other side, so the user batch count is driven by
                // the number of items and vice versa.
                this.numUserBatches = decideUserMiniBatches(this.numSamples, this.numItems, numFactors, numTasks, minNumMiniBatches);
                this.numItemBatches = decideUserMiniBatches(this.numSamples, this.numUsers, numFactors, numTasks, minNumMiniBatches);
            }

            /**
             * Decides the number of mini-batches for one half-step.
             *
             * @param numSamples        total number of ratings
             * @param numNeighborNodes  number of nodes on the opposite side, whose factors must be shipped
             * @param numFactors        rank of the factor vectors
             * @param numTasks          number of parallel tasks
             * @param minNumMiniBatches lower bound for the result
             * @return 1 (or the lower bound) when all neighbor factors fit in {@link #BATCH_VOLUME}, otherwise a
             *     count proportional to {@code numSamples * numFactors / (numTasks * BATCH_VOLUME)}
             */
            static int decideUserMiniBatches(long numSamples, long numNeighborNodes, int numFactors, int numTasks,
                                             int minNumMiniBatches) {
                long numBatches = 1L;
                if (numNeighborNodes * numFactors > BATCH_VOLUME) {
                    // Long arithmetic throughout: numTasks * BATCH_VOLUME overflows int once numTasks >= 11.
                    long divisor = Math.max(1L, (long) numTasks) * BATCH_VOLUME;
                    numBatches = (numSamples * numFactors) / divisor + 1;
                }
                long result = Math.max(numBatches, (long) minNumMiniBatches);
                return (int) Math.min(result, (long) Integer.MAX_VALUE);
            }
        }

        /* JADX INFO: Access modifiers changed from: private */
        /* loaded from: input_file:com/alibaba/alink/operator/common/recommendation/HugeMfAlsImpl$AlsTrain$Factors.class */
        /**
         * Factor vector of one node of the rating graph, stored in single precision.
         */
        public static class Factors implements Serializable {
            private static final long serialVersionUID = -616590158456104866L;
            /** Side of the node: 0 for the first id column (users), otherwise the second (items). */
            public byte identity;
            /** Id of the node within its side. */
            public long nodeId;
            /** The factor values. */
            public float[] factors;

            private Factors() {
            }

            /**
             * Widens every stored factor into {@code target}, starting at index 0.
             *
             * @param target destination; must be at least as long as {@link #factors}
             */
            void getFactorsAsDoubleArray(double[] target) {
                final float[] source = this.factors;
                int k = 0;
                while (k < source.length) {
                    target[k] = source[k];
                    k++;
                }
            }

            /**
             * Narrows {@code source} into {@link #factors}, allocating the array when it is still null.
             *
             * @param source values to store; an existing {@link #factors} array must be at least this long
             */
            void copyFactorsFromDoubleArray(double[] source) {
                if (this.factors == null) {
                    this.factors = new float[source.length];
                }
                int k = 0;
                for (double value : source) {
                    this.factors[k++] = (float) value;
                }
            }
        }

        /* JADX INFO: Access modifiers changed from: private */
        /* loaded from: input_file:com/alibaba/alink/operator/common/recommendation/HugeMfAlsImpl$AlsTrain$Ratings.class */
        /**
         * Adjacency list of one node of the bipartite rating graph, as produced by {@code initGraph}.
         */
        public static class Ratings implements Serializable {
            private static final long serialVersionUID = -5283706915605582930L;
            // 0 when nodeId comes from the first id column of the input triples (counted as users), 1 when it
            // comes from the second column (counted as items).
            public byte identity;
            // Id of this node within its side.
            public long nodeId;
            // Ids of the opposite-side nodes this node has ratings with; initGraph emits them in ascending order.
            public long[] neighbors;
            // Rating values, parallel to neighbors (ratings[k] belongs to neighbors[k]).
            public float[] ratings;

            private Ratings() {
            }
        }

        /* JADX INFO: Access modifiers changed from: private */
        /* loaded from: input_file:com/alibaba/alink/operator/common/recommendation/HugeMfAlsImpl$AlsTrain$UpdateFactorsFunc.class */
        public static class UpdateFactorsFunc extends RichCoGroupFunction<Tuple2<Integer, Ratings>, Tuple2<Integer, Factors>, Factors> {
            private static final long serialVersionUID = 4950896077295065254L;
            final int numFactors;
            final double lambda;
            final double alpha;
            final boolean explicit;
            final boolean nonnegative;
            private int numNodes;
            private long numEdges;
            private long numNeighbors;
            private transient double[] YtY;
            static final /* synthetic */ boolean $assertionsDisabled;

            UpdateFactorsFunc(boolean z, int i, double d, boolean z2) {
                this.numNodes = 0;
                this.numEdges = 0L;
                this.numNeighbors = 0L;
                this.YtY = null;
                this.explicit = z;
                this.numFactors = i;
                this.lambda = d;
                this.alpha = Criteria.INVALID_GAIN;
                this.nonnegative = z2;
            }

            UpdateFactorsFunc(boolean z, int i, double d, double d2, boolean z2) {
                this.numNodes = 0;
                this.numEdges = 0L;
                this.numNeighbors = 0L;
                this.YtY = null;
                this.explicit = z;
                this.numFactors = i;
                this.lambda = d;
                this.alpha = d2;
                this.nonnegative = z2;
            }

            public void open(Configuration configuration) {
                this.numNodes = 0;
                this.numEdges = 0L;
                this.numNeighbors = 0L;
                if (this.explicit) {
                    return;
                }
                this.YtY = (double[]) getRuntimeContext().getBroadcastVariable("YtY").get(0);
            }

            public void close() {
                HugeMfAlsImpl.LOG.info("Updated factors, num nodes {}, num edges {}, recv neighbors {}", new Object[]{Integer.valueOf(this.numNodes), Long.valueOf(this.numEdges), Long.valueOf(this.numNeighbors)});
            }

            public void coGroup(Iterable<Tuple2<Integer, Ratings>> iterable, Iterable<Tuple2<Integer, Factors>> iterable2, Collector<Factors> collector) {
                if (!$assertionsDisabled && (iterable == null || iterable2 == null)) {
                    throw new AssertionError();
                }
                ArrayList arrayList = new ArrayList();
                HashMap hashMap = new HashMap();
                for (Tuple2<Integer, Factors> tuple2 : iterable2) {
                    arrayList.add(tuple2);
                    hashMap.put(Long.valueOf(((Factors) tuple2.f1).nodeId), Integer.valueOf((int) this.numNeighbors));
                    this.numNeighbors++;
                }
                NormalEquation normalEquation = new NormalEquation(this.numFactors);
                DenseVector denseVector = new DenseVector(this.numFactors);
                DenseVector denseVector2 = new DenseVector(this.numFactors);
                for (Tuple2<Integer, Ratings> tuple22 : iterable) {
                    this.numNodes++;
                    this.numEdges += ((Ratings) tuple22.f1).neighbors.length;
                    normalEquation.reset();
                    if (this.explicit) {
                        long[] jArr = ((Ratings) tuple22.f1).neighbors;
                        float[] fArr = ((Ratings) tuple22.f1).ratings;
                        for (int i = 0; i < jArr.length; i++) {
                            ((Factors) ((Tuple2) arrayList.get(((Integer) hashMap.get(Long.valueOf(jArr[i]))).intValue())).f1).getFactorsAsDoubleArray(denseVector2.getData());
                            normalEquation.add(denseVector2, fArr[i], 1.0d);
                        }
                        normalEquation.regularize(jArr.length * this.lambda);
                        normalEquation.solve(denseVector, this.nonnegative);
                    } else {
                        normalEquation.merge(new DenseMatrix(this.numFactors, this.numFactors, this.YtY));
                        int i2 = 0;
                        long[] jArr2 = ((Ratings) tuple22.f1).neighbors;
                        float[] fArr2 = ((Ratings) tuple22.f1).ratings;
                        for (int i3 = 0; i3 < jArr2.length; i3++) {
                            Integer num = (Integer) hashMap.get(Long.valueOf(jArr2[i3]));
                            float f = fArr2[i3];
                            double d = 0.0d;
                            if (f > 0.0f) {
                                i2++;
                                d = this.alpha * f;
                            }
                            ((Factors) ((Tuple2) arrayList.get(num.intValue())).f1).getFactorsAsDoubleArray(denseVector2.getData());
                            normalEquation.add(denseVector2, ((double) f) > Criteria.INVALID_GAIN ? 1.0d + d : Criteria.INVALID_GAIN, d);
                        }
                        normalEquation.regularize(Math.max(i2, 1) * this.lambda);
                        normalEquation.solve(denseVector, this.nonnegative);
                    }
                    Factors factors = new Factors();
                    factors.identity = ((Ratings) tuple22.f1).identity;
                    factors.nodeId = ((Ratings) tuple22.f1).nodeId;
                    factors.copyFactorsFromDoubleArray(denseVector.getData());
                    collector.collect(factors);
                }
            }

            static {
                $assertionsDisabled = !HugeMfAlsImpl.class.desiredAssertionStatus();
            }
        }

        /**
         * Captures the ALS hyper-parameters.
         *
         * @param numFactors     rank of the factor vectors
         * @param numIters       number of ALS rounds
         * @param lambda         regularization weight
         * @param implicitPrefs  whether implicit feedback is used
         * @param alpha          implicit-feedback confidence scaling
         * @param numMiniBatches lower bound on the number of mini-batches per side
         * @param nonnegative    whether solutions are constrained to be non-negative
         */
        public AlsTrain(int numFactors, int numIters, double lambda, boolean implicitPrefs, double alpha,
                        int numMiniBatches, boolean nonnegative) {
            this.nonnegative = nonnegative;
            this.numMiniBatches = numMiniBatches;
            this.alpha = alpha;
            this.implicitPrefs = implicitPrefs;
            this.lambda = lambda;
            this.numIters = numIters;
            this.numFactors = numFactors;
        }

        /**
         * Trains user and item factors from rating triples, starting from random factors.
         *
         * @param dataSet rating triples (f0: id of an identity-0 node, f1: id of an identity-1 node, f2: rating)
         * @return one tuple per node: (identity, node id, learned factors)
         */
        public DataSet<Tuple3<Byte, Long, float[]>> fit(DataSet<Tuple3<Long, Long, Float>> dataSet) {
            DataSet<Ratings> graph = initGraph(dataSet);
            return iterateAndEmitFactors(initFactors(graph, this.numFactors), graph);
        }

        /**
         * Trains user and item factors from rating triples, warm-starting from previously learned factors.
         *
         * @param dataSet  existing factors of identity-0 nodes as rows of (id, dense vector)
         * @param dataSet2 existing factors of identity-1 nodes as rows of (id, dense vector)
         * @param dataSet3 rating triples, as in {@link #fit(DataSet)}
         * @return one tuple per node: (identity, node id, learned factors)
         */
        public DataSet<Tuple3<Byte, Long, float[]>> fit(DataSet<Row> dataSet, DataSet<Row> dataSet2, DataSet<Tuple3<Long, Long, Float>> dataSet3) {
            DataSet<Ratings> graph = initGraph(dataSet3);
            return iterateAndEmitFactors(initFactors(dataSet, dataSet2, graph, this.numFactors), graph);
        }

        /**
         * Runs the ALS bulk iteration from {@code initialFactors} over {@code graph} and flattens the final
         * factors to tuples. Shared by both {@code fit} overloads.
         */
        private DataSet<Tuple3<Byte, Long, float[]>> iterateAndEmitFactors(DataSet<Factors> initialFactors, DataSet<Ratings> graph) {
            // Integer.MAX_VALUE as the iteration cap: the run is ended by the termination criterion that
            // updateFactors returns as its second output.
            IterativeDataSet<Factors> loop = initialFactors.iterate(Integer.MAX_VALUE);
            Tuple2<DataSet<Factors>, DataSet<Integer>> step = updateFactors(loop, graph, this.numMiniBatches, this.numFactors, this.nonnegative, this.numIters);
            return loop.closeWith(step.f0, step.f1).map(new MapFunction<Factors, Tuple3<Byte, Long, float[]>>() {
                private static final long serialVersionUID = -8820639290497738048L;

                @Override
                public Tuple3<Byte, Long, float[]> map(Factors factors) {
                    return Tuple3.of(Byte.valueOf(factors.identity), Long.valueOf(factors.nodeId), factors.factors);
                }
            });
        }

        /**
         * Builds one adjacency list per node of the bipartite rating graph.
         *
         * <p>Every triple (a, b, r) contributes b to the list of identity-0 node a and a to the list of
         * identity-1 node b. Each list is sorted by neighbor id.
         *
         * @param dataSet rating triples (f0: identity-0 id, f1: identity-1 id, f2: rating)
         * @return one {@link Ratings} per (identity, node id)
         */
        private DataSet<Ratings> initGraph(DataSet<Tuple3<Long, Long, Float>> dataSet) {
            return dataSet.flatMap(new FlatMapFunction<Tuple3<Long, Long, Float>, Tuple4<Long, Long, Float, Byte>>() {
                private static final long serialVersionUID = 3894371804771123007L;

                // Note: no hand-written flatMap(Object, Collector) here. The compiler generates that bridge, and
                // declaring it explicitly is a same-erasure clash that does not compile.
                @Override
                public void flatMap(Tuple3<Long, Long, Float> edge, Collector<Tuple4<Long, Long, Float, Byte>> out) {
                    out.collect(Tuple4.of(edge.f0, edge.f1, edge.f2, (byte) 0));
                    out.collect(Tuple4.of(edge.f1, edge.f0, edge.f2, (byte) 1));
                }
            }).groupBy(3, 0).sortGroup(1, Order.ASCENDING).reduceGroup(new GroupReduceFunction<Tuple4<Long, Long, Float, Byte>, Ratings>() {
                private static final long serialVersionUID = 3391161187867934671L;

                @Override
                public void reduce(Iterable<Tuple4<Long, Long, Float, Byte>> group, Collector<Ratings> out) {
                    byte identity = -1;
                    long nodeId = -1L;
                    ArrayList<Long> neighborIds = new ArrayList<>();
                    ArrayList<Float> values = new ArrayList<>();
                    for (Tuple4<Long, Long, Float, Byte> edge : group) {
                        identity = edge.f3;
                        nodeId = edge.f0;
                        neighborIds.add(edge.f1);
                        values.add(edge.f2);
                    }
                    Ratings ratings = new Ratings();
                    ratings.nodeId = nodeId;
                    ratings.identity = identity;
                    ratings.neighbors = new long[neighborIds.size()];
                    ratings.ratings = new float[neighborIds.size()];
                    for (int k = 0; k < ratings.neighbors.length; k++) {
                        ratings.neighbors[k] = neighborIds.get(k);
                        ratings.ratings[k] = values.get(k);
                    }
                    out.collect(ratings);
                }
            }).name("init_graph");
        }

        /**
         * Assigns uniformly random initial factors in [0, 1) to every node of the graph.
         *
         * <p>Each subtask seeds its generator with its own subtask index, so the initialization is
         * deterministic for a given parallelism and data placement.
         *
         * @param dataSet the rating graph
         * @param rank    length of each factor vector
         */
        private DataSet<Factors> initFactors(DataSet<Ratings> dataSet, final int rank) {
            RichMapFunction<Ratings, Factors> randomInit = new RichMapFunction<Ratings, Factors>() {
                private static final long serialVersionUID = -6242580857177532093L;
                transient Random random;
                transient Factors reusedFactors;

                @Override
                public void open(Configuration configuration) {
                    random = new Random(getRuntimeContext().getIndexOfThisSubtask());
                    reusedFactors = new Factors();
                    reusedFactors.factors = new float[rank];
                }

                @Override
                public Factors map(Ratings ratings) {
                    // One output object is reused for every record.
                    Factors out = reusedFactors;
                    out.identity = ratings.identity;
                    out.nodeId = ratings.nodeId;
                    float[] values = out.factors;
                    for (int k = 0; k < rank; k++) {
                        values[k] = random.nextFloat();
                    }
                    return out;
                }
            };
            return dataSet.map(randomInit).name("InitFactors");
        }

        /**
         * Builds initial factors for every node of the graph, reusing previously learned factors when available.
         *
         * <p>Graph nodes are left-outer-joined on the key "identity_nodeId" with the existing models. Matched
         * nodes take the model's factors. Unmatched nodes get uniform random factors in [0, 1) from a generator
         * seeded with 2020.
         *
         * @param dataSet  existing factors of identity-0 nodes as rows of (id, dense vector)
         * @param dataSet2 existing factors of identity-1 nodes as rows of (id, dense vector)
         * @param dataSet3 the rating graph
         * @param rank     length of each factor vector; model vectors must have at least this many entries
         */
        private DataSet<Factors> initFactors(DataSet<Row> dataSet, DataSet<Row> dataSet2, DataSet<Ratings> dataSet3, final int rank) {
            DataSet<Tuple2<String, Factors>> graphNodes = dataSet3.map(new RichMapFunction<Ratings, Tuple2<String, Factors>>() {
                private static final long serialVersionUID = -6242580857177532093L;
                private transient Factors reusedFactors;

                @Override
                public void open(Configuration configuration) {
                    this.reusedFactors = new Factors();
                    this.reusedFactors.factors = null;
                }

                @Override
                public Tuple2<String, Factors> map(Ratings ratings) {
                    this.reusedFactors.identity = ratings.identity;
                    this.reusedFactors.nodeId = ratings.nodeId;
                    return Tuple2.of(((int) ratings.identity) + "_" + ratings.nodeId, this.reusedFactors);
                }
            }).name("InitFactors_hot_init");

            // NOTE(review): the join relies on the model id (row field 0) stringifying exactly like the Long node id.
            DataSet<Tuple2<String, float[]>> modelFactors =
                keyedModelFactors(dataSet, "0_", rank).union(keyedModelFactors(dataSet2, "1_", rank));

            return graphNodes.leftOuterJoin(modelFactors).where(0).equalTo(0).with(new RichJoinFunction<Tuple2<String, Factors>, Tuple2<String, float[]>, Factors>() {
                private static final long serialVersionUID = 1L;
                private transient Random random;

                @Override
                public void open(Configuration configuration) throws Exception {
                    super.open(configuration);
                    this.random = new Random(2020L);
                }

                @Override
                public Factors join(Tuple2<String, Factors> node, Tuple2<String, float[]> model) {
                    // Left outer join: the left side is never null, and only the model side can be missing.
                    Factors factors = node.f1;
                    if (model == null) {
                        factors.factors = new float[rank];
                        for (int k = 0; k < rank; k++) {
                            factors.factors[k] = this.random.nextFloat();
                        }
                    } else {
                        factors.factors = model.f1;
                    }
                    return factors;
                }
            });
        }

        /**
         * Maps model rows (field 0: id, field 1: dense factor vector) to (keyPrefix + id, first {@code rank}
         * factors as floats).
         */
        private static DataSet<Tuple2<String, float[]>> keyedModelFactors(DataSet<Row> model, final String keyPrefix, final int rank) {
            return model.map(new MapFunction<Row, Tuple2<String, float[]>>() {
                private static final long serialVersionUID = 1L;

                @Override
                public Tuple2<String, float[]> map(Row row) {
                    double[] data = VectorUtil.getDenseVector(row.getField(1)).getData();
                    // A fresh array per record, so emitted records never share one mutable buffer.
                    float[] factors = new float[rank];
                    for (int k = 0; k < rank; k++) {
                        factors[k] = (float) data[k];
                    }
                    return Tuple2.of(keyPrefix + row.getField(0), factors);
                }
            });
        }

        /**
         * Computes a single {@link DataProfile} for the rating graph: node counts per side, the number of ratings,
         * and the resulting mini-batch counts.
         *
         * @param dataSet           the rating graph
         * @param numFactors        rank of the factor vectors
         * @param minNumMiniBatches lower bound on the number of mini-batches per side
         */
        private static DataSet<DataProfile> generateDataProfile(DataSet<Ratings> dataSet, final int numFactors, final int minNumMiniBatches) {
            return dataSet.mapPartition(new MapPartitionFunction<Ratings, Tuple3<Long, Long, Long>>() {
                private static final long serialVersionUID = -3529850335007040435L;

                @Override
                public void mapPartition(Iterable<Ratings> iterable, Collector<Tuple3<Long, Long, Long>> collector) {
                    long numUsers = 0L;
                    long numItems = 0L;
                    long numSamples = 0L;
                    for (Ratings ratings : iterable) {
                        if (ratings.identity == 0) {
                            numUsers++;
                            // Every rating is stored on both sides of the graph; count it on the identity-0 side only.
                            numSamples += ratings.ratings.length;
                        } else {
                            numItems++;
                        }
                    }
                    collector.collect(Tuple3.of(Long.valueOf(numUsers), Long.valueOf(numItems), Long.valueOf(numSamples)));
                }
            }).reduce(new ReduceFunction<Tuple3<Long, Long, Long>>() {
                private static final long serialVersionUID = 3849683380245684843L;

                @Override
                public Tuple3<Long, Long, Long> reduce(Tuple3<Long, Long, Long> tuple3, Tuple3<Long, Long, Long> tuple32) {
                    tuple3.f0 = Long.valueOf(tuple3.f0.longValue() + tuple32.f0.longValue());
                    tuple3.f1 = Long.valueOf(tuple3.f1.longValue() + tuple32.f1.longValue());
                    tuple3.f2 = Long.valueOf(tuple3.f2.longValue() + tuple32.f2.longValue());
                    return tuple3;
                }
            }).map(new RichMapFunction<Tuple3<Long, Long, Long>, DataProfile>() {
                private static final long serialVersionUID = -2224348217053561771L;

                @Override
                public DataProfile map(Tuple3<Long, Long, Long> tuple3) {
                    // NOTE(review): this is the *max* number of parallel subtasks, not the actual parallelism;
                    // confirm this is the intended input for the batch-size heuristic.
                    int maxNumberOfParallelSubtasks = getRuntimeContext().getMaxNumberOfParallelSubtasks();
                    DataProfile dataProfile = new DataProfile();
                    dataProfile.parallelism = maxNumberOfParallelSubtasks;
                    dataProfile.numUsers = tuple3.f0.longValue();
                    dataProfile.numItems = tuple3.f1.longValue();
                    dataProfile.numSamples = tuple3.f2.longValue();
                    dataProfile.decideNumMiniBatches(numFactors, maxNumberOfParallelSubtasks, minNumMiniBatches);
                    return dataProfile;
                }
            }).name("data_profiling");
        }

        private Tuple2<DataSet<Factors>, DataSet<Integer>> updateFactors(DataSet<Factors> dataSet, DataSet<Ratings> dataSet2, int i, int i2, boolean z, final int i3) {
            FlatMapOperator flatMap = dataSet.flatMap(new FlatMapFunction<Factors, Factors>() { // from class: com.alibaba.alink.operator.common.recommendation.HugeMfAlsImpl.AlsTrain.13
                private static final long serialVersionUID = -4512655206561622474L;

                public void flatMap(Factors factors, Collector<Factors> collector) {
                }

                public /* bridge */ /* synthetic */ void flatMap(Object obj, Collector collector) throws Exception {
                    flatMap((Factors) obj, (Collector<Factors>) collector);
                }
            });
            DataSet<DataProfile> generateDataProfile = generateDataProfile(dataSet2, i2, i);
            MapOperator map = dataSet2.filter(new RichFilterFunction<Ratings>() { // from class: com.alibaba.alink.operator.common.recommendation.HugeMfAlsImpl.AlsTrain.15
                private static final long serialVersionUID = -6221088866110309923L;
                private transient DataProfile profile;
                private transient int alsStepNo;
                private transient int userOrItem;
                private transient int subStepNo;
                private transient int numSubsteps;

                public void open(Configuration configuration) {
                    if (getIterationRuntimeContext().getSuperstepNumber() == 1) {
                        this.profile = (DataProfile) getRuntimeContext().getBroadcastVariable("profile").get(0);
                        HugeMfAlsImpl.LOG.info("Data profile {}", JsonConverter.toJson(this.profile));
                        this.subStepNo = -1;
                        this.userOrItem = 0;
                        this.alsStepNo = 0;
                        this.numSubsteps = this.profile.numUserBatches;
                    }
                    this.subStepNo++;
                    if (this.userOrItem == 0) {
                        if (this.subStepNo >= this.numSubsteps) {
                            this.userOrItem = 1;
                            this.numSubsteps = this.profile.numItemBatches;
                            this.subStepNo = 0;
                        }
                    } else if (this.userOrItem == 1 && this.subStepNo >= this.numSubsteps) {
                        this.userOrItem = 0;
                        this.numSubsteps = this.profile.numUserBatches;
                        this.subStepNo = 0;
                        this.alsStepNo++;
                    }
                    HugeMfAlsImpl.LOG.info("ALS step no {}, user or item {}, sub step no {}", new Object[]{Integer.valueOf(this.alsStepNo), Integer.valueOf(this.userOrItem), Integer.valueOf(this.subStepNo)});
                }

                public boolean filter(Ratings ratings) {
                    return this.alsStepNo < i3 && ratings.identity == this.userOrItem && Math.abs(ratings.nodeId) % ((long) this.numSubsteps) == ((long) this.subStepNo);
                }
            }).name("createMiniBatch").withBroadcastSet(flatMap, "empty").withBroadcastSet(generateDataProfile, "profile").map(new RichMapFunction<Ratings, Tuple2<Integer, Ratings>>() { // from class: com.alibaba.alink.operator.common.recommendation.HugeMfAlsImpl.AlsTrain.14
                private static final long serialVersionUID = 2482586233207883428L;
                transient int partitionId;

                public void open(Configuration configuration) {
                    this.partitionId = getRuntimeContext().getIndexOfThisSubtask();
                }

                public Tuple2<Integer, Ratings> map(Ratings ratings) {
                    return Tuple2.of(Integer.valueOf(this.partitionId), ratings);
                }
            });
            Operator name = map.flatMap(new FlatMapFunction<Tuple2<Integer, Ratings>, Tuple3<Integer, Byte, Long>>() { // from class: com.alibaba.alink.operator.common.recommendation.HugeMfAlsImpl.AlsTrain.16
                private static final long serialVersionUID = -3507408658975187358L;

                public void flatMap(Tuple2<Integer, Ratings> tuple2, Collector<Tuple3<Integer, Byte, Long>> collector) {
                    int i4 = 1 - ((Ratings) tuple2.f1).identity;
                    int intValue = ((Integer) tuple2.f0).intValue();
                    for (long j : ((Ratings) tuple2.f1).neighbors) {
                        collector.collect(Tuple3.of(Integer.valueOf(intValue), Byte.valueOf((byte) i4), Long.valueOf(j)));
                    }
                }

                public /* bridge */ /* synthetic */ void flatMap(Object obj, Collector collector) throws Exception {
                    flatMap((Tuple2<Integer, Ratings>) obj, (Collector<Tuple3<Integer, Byte, Long>>) collector);
                }
            }).name("GenerateRequest").coGroup(dataSet).where(new KeySelector<Tuple3<Integer, Byte, Long>, Tuple2<Byte, Long>>() { // from class: com.alibaba.alink.operator.common.recommendation.HugeMfAlsImpl.AlsTrain.19
                private static final long serialVersionUID = 5984785442398189198L;

                public Tuple2<Byte, Long> getKey(Tuple3<Integer, Byte, Long> tuple3) {
                    return Tuple2.of(tuple3.f1, tuple3.f2);
                }
            }).equalTo(new KeySelector<Factors, Tuple2<Byte, Long>>() { // from class: com.alibaba.alink.operator.common.recommendation.HugeMfAlsImpl.AlsTrain.18
                private static final long serialVersionUID = -7009936622357038623L;

                public Tuple2<Byte, Long> getKey(Factors factors) {
                    return Tuple2.of(Byte.valueOf(factors.identity), Long.valueOf(factors.nodeId));
                }
            }).with(new RichCoGroupFunction<Tuple3<Integer, Byte, Long>, Factors, Tuple2<Integer, Factors>>() { // from class: com.alibaba.alink.operator.common.recommendation.HugeMfAlsImpl.AlsTrain.17
                private static final long serialVersionUID = 7541748515432588189L;
                private transient int[] flag = null;
                private transient int[] partitionsIds = null;
                static final /* synthetic */ boolean $assertionsDisabled;

                public void open(Configuration configuration) {
                    int numberOfParallelSubtasks = getRuntimeContext().getNumberOfParallelSubtasks();
                    this.flag = new int[numberOfParallelSubtasks];
                    this.partitionsIds = new int[numberOfParallelSubtasks];
                }

                public void close() {
                    this.flag = null;
                    this.partitionsIds = null;
                }

                public void coGroup(Iterable<Tuple3<Integer, Byte, Long>> iterable, Iterable<Factors> iterable2, Collector<Tuple2<Integer, Factors>> collector) {
                    if (iterable == null) {
                        return;
                    }
                    int i4 = 0;
                    byte b = -1;
                    long j = Long.MIN_VALUE;
                    int i5 = 0;
                    Arrays.fill(this.flag, 0);
                    for (Tuple3<Integer, Byte, Long> tuple3 : iterable) {
                        i4++;
                        b = ((Byte) tuple3.f1).byteValue();
                        j = ((Long) tuple3.f2).longValue();
                        int intValue = ((Integer) tuple3.f0).intValue();
                        if (this.flag[intValue] == 0) {
                            int i6 = i5;
                            i5++;
                            this.partitionsIds[i6] = intValue;
                            this.flag[intValue] = 1;
                        }
                    }
                    if (i4 == 0) {
                        return;
                    }
                    for (Factors factors : iterable2) {
                        if (!$assertionsDisabled && (factors.identity != b || factors.nodeId != j)) {
                            throw new AssertionError();
                        }
                        for (int i7 = 0; i7 < i5; i7++) {
                            collector.collect(Tuple2.of(Integer.valueOf(this.partitionsIds[i7]), factors));
                        }
                    }
                }

                static {
                    $assertionsDisabled = !HugeMfAlsImpl.class.desiredAssertionStatus();
                }
            }).name("GenerateResponse");
            return Tuple2.of(dataSet.coGroup(this.implicitPrefs ? map.coGroup(name).where(new int[]{0}).equalTo(new int[]{0}).withPartitioner(new Partitioner<Integer>() { // from class: com.alibaba.alink.operator.common.recommendation.HugeMfAlsImpl.AlsTrain.20
                private static final long serialVersionUID = 2820044599727883648L;

                public int partition(Integer num, int i4) {
                    return num.intValue() % i4;
                }
            }).with(new UpdateFactorsFunc(false, i2, this.lambda, this.alpha, z)).withBroadcastSet(computeYtY(dataSet, i2, i), "YtY").name("CalculateNewFactorsImplicit") : map.coGroup(name).where(new int[]{0}).equalTo(new int[]{0}).withPartitioner(new Partitioner<Integer>() { // from class: com.alibaba.alink.operator.common.recommendation.HugeMfAlsImpl.AlsTrain.21
                private static final long serialVersionUID = 1421529212117086604L;

                public int partition(Integer num, int i4) {
                    return num.intValue() % i4;
                }
            }).with(new UpdateFactorsFunc(true, i2, this.lambda, z)).name("CalculateNewFactorsExplicit")).where(new KeySelector<Factors, Tuple2<Byte, Long>>() { // from class: com.alibaba.alink.operator.common.recommendation.HugeMfAlsImpl.AlsTrain.24
                private static final long serialVersionUID = -2656531711247296641L;

                public Tuple2<Byte, Long> getKey(Factors factors) {
                    return Tuple2.of(Byte.valueOf(factors.identity), Long.valueOf(factors.nodeId));
                }
            }).equalTo(new KeySelector<Factors, Tuple2<Byte, Long>>() { // from class: com.alibaba.alink.operator.common.recommendation.HugeMfAlsImpl.AlsTrain.23
                private static final long serialVersionUID = -3261052949977562238L;

                public Tuple2<Byte, Long> getKey(Factors factors) {
                    return Tuple2.of(Byte.valueOf(factors.identity), Long.valueOf(factors.nodeId));
                }
            }).with(new RichCoGroupFunction<Factors, Factors, Factors>() { // from class: com.alibaba.alink.operator.common.recommendation.HugeMfAlsImpl.AlsTrain.22
                private static final long serialVersionUID = -1806297671515688974L;
                static final /* synthetic */ boolean $assertionsDisabled;

                public void coGroup(Iterable<Factors> iterable, Iterable<Factors> iterable2, Collector<Factors> collector) {
                    if (!$assertionsDisabled && iterable == null) {
                        throw new AssertionError();
                    }
                    if (iterable2 != null) {
                        Iterator<Factors> it = iterable2.iterator();
                        if (it.hasNext()) {
                            Factors next = it.next();
                            for (Factors factors : iterable) {
                                if (!$assertionsDisabled && (factors.identity != next.identity || factors.nodeId != next.nodeId)) {
                                    throw new AssertionError();
                                }
                                collector.collect(next);
                            }
                            return;
                        }
                    }
                    Iterator<Factors> it2 = iterable.iterator();
                    while (it2.hasNext()) {
                        collector.collect(it2.next());
                    }
                }

                static {
                    $assertionsDisabled = !HugeMfAlsImpl.class.desiredAssertionStatus();
                }
            }).name("UpdateFactors"), generateDataProfile.flatMap(new RichFlatMapFunction<DataProfile, Integer>() { // from class: com.alibaba.alink.operator.common.recommendation.HugeMfAlsImpl.AlsTrain.25
                private static final long serialVersionUID = 378571173957011355L;

                public void flatMap(DataProfile dataProfile, Collector<Integer> collector) {
                    if (getIterationRuntimeContext().getSuperstepNumber() < (dataProfile.numUserBatches + dataProfile.numItemBatches) * i3) {
                        collector.collect(0);
                    }
                }

                public /* bridge */ /* synthetic */ void flatMap(Object obj, Collector collector) throws Exception {
                    flatMap((DataProfile) obj, (Collector<Integer>) collector);
                }
            }).withBroadcastSet(flatMap, "empty").name("StopCriterion"));
        }

        /**
         * Builds the Gram matrix (sum of outer products {@code y * y^T}) of the factor vectors whose
         * {@code identity} is selected for the current superstep, flattened row-major into a {@code double[]}.
         *
         * <p>Each partition emits one partial matrix built from the vectors it holds. A reduce then adds the
         * partial matrices element-wise. The result feeds the "YtY" broadcast set of the implicit-feedback
         * update ({@code CalculateNewFactorsImplicit}).
         *
         * @param dataSet factors of both sides, distinguished by {@code Factors.identity}
         * @param i       number of leading vector components used; the matrix is {@code i x i}
         *                (presumably the rank)
         * @param i2      the selected identity is {@code 1 - ((superstep - 1) / i2) % 2}, so it flips every
         *                {@code i2} supersteps
         * @return data set holding the element-wise sum of all partition matrices
         */
        private DataSet<double[]> computeYtY(DataSet<Factors> dataSet, final int i, final int i2) {
            final int numFactors = i;
            final int numMiniBatch = i2;
            return dataSet.mapPartition(new RichMapPartitionFunction<Factors, double[]>() {
                private static final long serialVersionUID = 4337923793883497898L;

                @Override
                public void mapPartition(Iterable<Factors> iterable, Collector<double[]> collector) {
                    int stepNo = getIterationRuntimeContext().getSuperstepNumber() - 1;
                    // The identity whose vectors are summed flips every numMiniBatch supersteps.
                    int selectedIdentity = 1 - ((stepNo / numMiniBatch) % 2);
                    // Java zero-initializes new arrays, so no explicit fill is required.
                    double[] partial = new double[numFactors * numFactors];
                    for (Factors factors : iterable) {
                        if (factors.identity != selectedIdentity) {
                            continue;
                        }
                        float[] vec = factors.factors;
                        for (int row = 0; row < numFactors; row++) {
                            // Widen before multiplying: a float*float product is exact in double precision,
                            // whereas a float product would be rounded to float before accumulation.
                            double rowValue = vec[row];
                            int offset = row * numFactors;
                            for (int col = 0; col < numFactors; col++) {
                                partial[offset + col] += rowValue * vec[col];
                            }
                        }
                    }
                    collector.collect(partial);
                }
            }).reduce(new ReduceFunction<double[]>() {
                private static final long serialVersionUID = 3534712378694892154L;

                @Override
                public double[] reduce(double[] left, double[] right) {
                    int size = numFactors * numFactors;
                    double[] sum = new double[size];
                    for (int k = 0; k < size; k++) {
                        sum[k] = left[k] + right[k];
                    }
                    return sum;
                }
            }).name("YtY");
        }
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Trains an ALS model on a rating table and returns the learned user and item factors.
     *
     * <p>If both the user and the item column are of type LONG, their values are used directly as ids.
     * Otherwise both columns are mapped to indices with a {@link MultiStringIndexerTrainBatchOp} before
     * training. The original ids are then restored on the factor tables with left outer joins.
     *
     * @param batchOperator rating table containing the user, item and rating columns named in {@code params}
     * @param params        ALS parameters (column names, rank, lambda, iterations, alpha, blocks, non-negativity)
     * @param z             forwarded to {@code AlsTrain}; presumably the implicit-feedback switch (TODO confirm)
     * @return (user factors, item factors); each table has the id column and a {@code factors} column holding
     *         a serialized dense vector
     * @throws AkIllegalDataException when the job runs, if a user, item or rating value is null
     */
    public static Tuple2<BatchOperator<?>, BatchOperator<?>> factorize(BatchOperator<?> batchOperator, Params params, boolean z) {
        Long mlEnvId = batchOperator.getMLEnvironmentId();
        String userCol = (String) params.get(AlsTrainParams.USER_COL);
        String itemCol = (String) params.get(AlsTrainParams.ITEM_COL);
        String rateCol = (String) params.get(AlsTrainParams.RATE_COL);
        double lambda = ((Double) params.get(AlsTrainParams.LAMBDA)).doubleValue();
        int rank = ((Integer) params.get(AlsTrainParams.RANK)).intValue();
        int numIter = ((Integer) params.get(AlsTrainParams.NUM_ITER)).intValue();
        boolean nonNegative = ((Boolean) params.get(AlsTrainParams.NON_NEGATIVE)).booleanValue();
        double alpha = ((Double) params.get(AlsImplicitTrainParams.ALPHA)).doubleValue();
        int numBlocks = ((Integer) params.get(AlsTrainParams.NUM_BLOCKS)).intValue();

        final int userColIdx = TableUtil.findColIndexWithAssert(batchOperator.getColNames(), userCol);
        final int itemColIdx = TableUtil.findColIndexWithAssert(batchOperator.getColNames(), itemCol);
        final int rateColIdx = TableUtil.findColIndexWithAssert(batchOperator.getColNames(), rateCol);
        // Ids are used as-is only when both id columns are LONG; otherwise they are mapped to indices.
        boolean idsAreLong = batchOperator.getColTypes()[userColIdx].equals(Types.LONG)
            && batchOperator.getColTypes()[itemColIdx].equals(Types.LONG);

        MultiStringIndexerPredictBatchOp userIndexMapping = null;
        MultiStringIndexerPredictBatchOp itemIndexMapping = null;
        if (!idsAreLong) {
            BatchOperator<?> distinctUsers = batchOperator.select("`" + userCol + "`").distinct();
            BatchOperator<?> distinctItems = batchOperator.select("`" + itemCol + "`").distinct();
            MultiStringIndexerTrainBatchOp indexModel = ((MultiStringIndexerTrainBatchOp) new MultiStringIndexerTrainBatchOp()
                .setMLEnvironmentId(mlEnvId)).setSelectedCols(userCol, itemCol).linkFrom(batchOperator);
            // Replace the user and item columns of the rating table with their indices.
            batchOperator = ((MultiStringIndexerPredictBatchOp) new MultiStringIndexerPredictBatchOp()
                .setMLEnvironmentId(mlEnvId)).setSelectedCols(userCol, itemCol).linkFrom(indexModel, batchOperator);
            // (original id, index) lookup tables used to translate the trained factors back.
            userIndexMapping = ((MultiStringIndexerPredictBatchOp) new MultiStringIndexerPredictBatchOp()
                .setMLEnvironmentId(mlEnvId)).setSelectedCols(userCol).setOutputCols("__user_index")
                .linkFrom(indexModel, distinctUsers);
            itemIndexMapping = ((MultiStringIndexerPredictBatchOp) new MultiStringIndexerPredictBatchOp()
                .setMLEnvironmentId(mlEnvId)).setSelectedCols(itemCol).setOutputCols("__item_index")
                .linkFrom(indexModel, distinctItems);
        }

        DataSet<Tuple3<Long, Long, Float>> ratings = batchOperator.getDataSet().map(new MapFunction<Row, Tuple3<Long, Long, Float>>() {
            private static final long serialVersionUID = 6671683813980584160L;

            @Override
            public Tuple3<Long, Long, Float> map(Row row) {
                Object user = row.getField(userColIdx);
                Object item = row.getField(itemColIdx);
                Object rating = row.getField(rateColIdx);
                // Throw lazily: constructing the exception up front would allocate it (and capture a stack
                // trace) for every record, even the valid ones.
                if (user == null) {
                    throw new AkIllegalDataException("user is null");
                }
                if (item == null) {
                    throw new AkIllegalDataException("item is null");
                }
                if (rating == null) {
                    throw new AkIllegalDataException("rating is null");
                }
                return new Tuple3<>(Long.valueOf(((Number) user).longValue()), Long.valueOf(((Number) item).longValue()),
                    Float.valueOf(((Number) rating).floatValue()));
            }
        });
        DataSet<Tuple3<Byte, Long, float[]>> trained =
            new AlsTrain(rank, numIter, lambda, z, alpha, numBlocks, nonNegative).fit(ratings);

        // Identity 0 tags user factors and identity 1 tags item factors.
        BatchOperator<?> userFactors = getFactors(mlEnvId, trained, userCol, (byte) 0);
        BatchOperator<?> itemFactors = getFactors(mlEnvId, trained, itemCol, (byte) 1);
        if (!idsAreLong) {
            // Translate indices back to the original ids.
            userFactors = ((LeftOuterJoinBatchOp) new LeftOuterJoinBatchOp().setMLEnvironmentId(mlEnvId))
                .setJoinPredicate(String.format("a.`%s`=b.__user_index", userCol))
                .setSelectClause(String.format("b.`%s`, a.`%s`", userCol, "factors"))
                .linkFrom(userFactors, userIndexMapping);
            itemFactors = ((LeftOuterJoinBatchOp) new LeftOuterJoinBatchOp().setMLEnvironmentId(mlEnvId))
                .setJoinPredicate(String.format("a.`%s`=b.__item_index", itemCol))
                .setSelectClause(String.format("b.`%s`, a.`%s`", itemCol, "factors"))
                .linkFrom(itemFactors, itemIndexMapping);
        }
        return Tuple2.of(userFactors, itemFactors);
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Trains an ALS model on a rating table together with existing user and item factor tables, and merges the
     * newly learned factors with the existing ones.
     *
     * <p>The id-type check looks only at the rating table. If its user and item columns are both LONG, ids are
     * used as-is. Otherwise every id appearing in the rating table or in the matching factor table is mapped to an
     * index before training, and the original ids are restored afterwards with left outer joins.
     *
     * @param batchOperator  existing user factors (user column plus a {@code factors} column); passed to
     *                       {@code AlsTrain.fit}, presumably to initialize training
     * @param batchOperator2 existing item factors (item column plus a {@code factors} column)
     * @param batchOperator3 rating table containing the user, item and rating columns named in {@code params}
     * @param params         ALS parameters (column names, rank, lambda, iterations, alpha, blocks, non-negativity)
     * @param z              forwarded to {@code AlsTrain}; presumably the implicit-feedback switch (TODO confirm)
     * @return (merged user factors, merged item factors, trained user factors, trained item factors). In the
     *         merged tables an id present in both inputs takes the newly trained vector.
     * @throws AkIllegalDataException when the job runs, if a user, item or rating value is null
     */
    public static Tuple4<BatchOperator<?>, BatchOperator<?>, BatchOperator<?>, BatchOperator<?>> factorize(BatchOperator<?> batchOperator, BatchOperator<?> batchOperator2, BatchOperator<?> batchOperator3, Params params, boolean z) {
        Long mlEnvId = batchOperator3.getMLEnvironmentId();
        String userCol = (String) params.get(AlsTrainParams.USER_COL);
        String itemCol = (String) params.get(AlsTrainParams.ITEM_COL);
        String rateCol = (String) params.get(AlsTrainParams.RATE_COL);
        double lambda = ((Double) params.get(AlsTrainParams.LAMBDA)).doubleValue();
        int rank = ((Integer) params.get(AlsTrainParams.RANK)).intValue();
        int numIter = ((Integer) params.get(AlsTrainParams.NUM_ITER)).intValue();
        boolean nonNegative = ((Boolean) params.get(AlsTrainParams.NON_NEGATIVE)).booleanValue();
        double alpha = ((Double) params.get(AlsImplicitTrainParams.ALPHA)).doubleValue();
        int numBlocks = ((Integer) params.get(AlsTrainParams.NUM_BLOCKS)).intValue();

        final int userColIdx = TableUtil.findColIndexWithAssert(batchOperator3.getColNames(), userCol);
        final int itemColIdx = TableUtil.findColIndexWithAssert(batchOperator3.getColNames(), itemCol);
        final int rateColIdx = TableUtil.findColIndexWithAssert(batchOperator3.getColNames(), rateCol);
        boolean idsAreLong = batchOperator3.getColTypes()[userColIdx].equals(Types.LONG)
            && batchOperator3.getColTypes()[itemColIdx].equals(Types.LONG);

        MultiStringIndexerPredictBatchOp userIndexMapping = null;
        MultiStringIndexerPredictBatchOp itemIndexMapping = null;
        if (!idsAreLong) {
            // Index the union of ids from the ratings and from the existing factor tables. The union must live
            // in the same ML environment as its inputs, like every other operator here.
            BatchOperator<?> distinctUsers = ((UnionBatchOp) new UnionBatchOp().setMLEnvironmentId(mlEnvId))
                .linkFrom(batchOperator.select("`" + userCol + "`"), batchOperator3.select("`" + userCol + "`"))
                .distinct();
            BatchOperator<?> distinctItems = ((UnionBatchOp) new UnionBatchOp().setMLEnvironmentId(mlEnvId))
                .linkFrom(batchOperator2.select("`" + itemCol + "`"), batchOperator3.select("`" + itemCol + "`"))
                .distinct();
            MultiStringIndexerTrainBatchOp userIndexModel = ((MultiStringIndexerTrainBatchOp) new MultiStringIndexerTrainBatchOp()
                .setMLEnvironmentId(mlEnvId)).setSelectedCols(userCol).linkFrom(distinctUsers);
            MultiStringIndexerTrainBatchOp itemIndexModel = ((MultiStringIndexerTrainBatchOp) new MultiStringIndexerTrainBatchOp()
                .setMLEnvironmentId(mlEnvId)).setSelectedCols(itemCol).linkFrom(distinctItems);
            batchOperator3 = ((MultiStringIndexerPredictBatchOp) new MultiStringIndexerPredictBatchOp()
                .setMLEnvironmentId(mlEnvId)).setSelectedCols(itemCol).linkFrom(itemIndexModel,
                ((MultiStringIndexerPredictBatchOp) new MultiStringIndexerPredictBatchOp()
                    .setMLEnvironmentId(mlEnvId)).setSelectedCols(userCol).linkFrom(userIndexModel, batchOperator3));
            batchOperator = ((MultiStringIndexerPredictBatchOp) new MultiStringIndexerPredictBatchOp()
                .setMLEnvironmentId(mlEnvId)).setSelectedCols(userCol).linkFrom(userIndexModel, batchOperator);
            batchOperator2 = ((MultiStringIndexerPredictBatchOp) new MultiStringIndexerPredictBatchOp()
                .setMLEnvironmentId(mlEnvId)).setSelectedCols(itemCol).linkFrom(itemIndexModel, batchOperator2);
            // (original id, index) lookup tables used to translate factor tables back.
            userIndexMapping = ((MultiStringIndexerPredictBatchOp) new MultiStringIndexerPredictBatchOp()
                .setMLEnvironmentId(mlEnvId)).setSelectedCols(userCol).setOutputCols("__user_index")
                .linkFrom(userIndexModel, distinctUsers);
            itemIndexMapping = ((MultiStringIndexerPredictBatchOp) new MultiStringIndexerPredictBatchOp()
                .setMLEnvironmentId(mlEnvId)).setSelectedCols(itemCol).setOutputCols("__item_index")
                .linkFrom(itemIndexModel, distinctItems);
        }

        DataSet<Tuple3<Long, Long, Float>> ratings = batchOperator3.getDataSet().map(new MapFunction<Row, Tuple3<Long, Long, Float>>() {
            private static final long serialVersionUID = 6671683813980584160L;

            @Override
            public Tuple3<Long, Long, Float> map(Row row) {
                Object user = row.getField(userColIdx);
                Object item = row.getField(itemColIdx);
                Object rating = row.getField(rateColIdx);
                // Throw lazily: constructing the exception up front would allocate it (and capture a stack
                // trace) for every record, even the valid ones.
                if (user == null) {
                    throw new AkIllegalDataException("user is null");
                }
                if (item == null) {
                    throw new AkIllegalDataException("item is null");
                }
                if (rating == null) {
                    throw new AkIllegalDataException("rating is null");
                }
                return new Tuple3<>(Long.valueOf(((Number) user).longValue()), Long.valueOf(((Number) item).longValue()),
                    Float.valueOf(((Number) rating).floatValue()));
            }
        });
        DataSet<Tuple3<Byte, Long, float[]>> trained = new AlsTrain(rank, numIter, lambda, z, alpha, numBlocks, nonNegative)
            .fit(batchOperator.getDataSet(), batchOperator2.getDataSet(), ratings);

        // Identity 0 tags user factors and identity 1 tags item factors.
        BatchOperator<?> userFactors = getFactors(mlEnvId, trained, userCol, (byte) 0);
        BatchOperator<?> itemFactors = getFactors(mlEnvId, trained, itemCol, (byte) 1);
        if (!idsAreLong) {
            // Translate indices back to the original ids, for both trained and existing factors.
            userFactors = ((LeftOuterJoinBatchOp) new LeftOuterJoinBatchOp().setMLEnvironmentId(mlEnvId))
                .setJoinPredicate(String.format("a.`%s`=b.__user_index", userCol))
                .setSelectClause(String.format("b.`%s`, a.`%s`", userCol, "factors"))
                .linkFrom(userFactors, userIndexMapping);
            batchOperator = ((LeftOuterJoinBatchOp) new LeftOuterJoinBatchOp().setMLEnvironmentId(mlEnvId))
                .setJoinPredicate(String.format("a.`%s`=b.__user_index", userCol))
                .setSelectClause(String.format("b.`%s`, a.`%s`", userCol, "factors"))
                .linkFrom(batchOperator, userIndexMapping);
            itemFactors = ((LeftOuterJoinBatchOp) new LeftOuterJoinBatchOp().setMLEnvironmentId(mlEnvId))
                .setJoinPredicate(String.format("a.`%s`=b.__item_index", itemCol))
                .setSelectClause(String.format("b.`%s`, a.`%s`", itemCol, "factors"))
                .linkFrom(itemFactors, itemIndexMapping);
            batchOperator2 = ((LeftOuterJoinBatchOp) new LeftOuterJoinBatchOp().setMLEnvironmentId(mlEnvId))
                .setJoinPredicate(String.format("a.`%s`=b.__item_index", itemCol))
                .setSelectClause(String.format("b.`%s`, a.`%s`", itemCol, "factors"))
                .linkFrom(batchOperator2, itemIndexMapping);
        }

        // Merge existing (a) and trained (b) factors; when an id appears on both sides the trained vector wins.
        // Column names are backtick-quoted, consistent with the other joins in this file.
        String mergePredicateFmt = "a.`%1$s`=b.`%1$s`";
        String mergeSelectFmt = "case when a.`%1$s` is null then b.`%1$s` when b.`%1$s` is null then a.`%1$s` else b.`%1$s` end as `%1$s`, "
            + "case when a.`%1$s` is null then b.factors when b.`%1$s` is null then a.factors else b.factors end as factors";
        BatchOperator<?> mergedUsers = ((FullOuterJoinBatchOp) new FullOuterJoinBatchOp().setMLEnvironmentId(mlEnvId))
            .setJoinPredicate(String.format(mergePredicateFmt, userCol))
            .setSelectClause(String.format(mergeSelectFmt, userCol))
            .linkFrom(batchOperator, userFactors);
        BatchOperator<?> mergedItems = ((FullOuterJoinBatchOp) new FullOuterJoinBatchOp().setMLEnvironmentId(mlEnvId))
            .setJoinPredicate(String.format(mergePredicateFmt, itemCol))
            .setSelectClause(String.format(mergeSelectFmt, itemCol))
            .linkFrom(batchOperator2, itemFactors);
        return Tuple4.of(mergedUsers, mergedItems, userFactors, itemFactors);
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Keeps the factor vectors tagged with {@code b} and exposes them as a two-column batch operator.
     *
     * <p>Each matching {@code (identity, id, vector)} triple becomes a row {@code (id, serialized vector)}. The
     * float vector is widened into a {@link DenseVector} and serialized with {@link VectorUtil#serialize}.
     *
     * @param l       ML environment id used for the table conversion and set on the returned operator
     * @param dataSet trained factors as {@code (identity, id, vector)} triples
     * @param str     name given to the id column
     * @param b       identity tag to keep (callers in this file pass 0 for users and 1 for items)
     * @return operator with columns {@code str} (LONG) and {@code "factors"} (STRING)
     */
    private static BatchOperator<?> getFactors(Long l, DataSet<Tuple3<Byte, Long, float[]>> dataSet, String str, final byte b) {
        DataSet<Tuple3<Byte, Long, float[]>> tagged = dataSet.filter(new FilterFunction<Tuple3<Byte, Long, float[]>>() {
            private static final long serialVersionUID = -2198675502442522328L;

            @Override
            public boolean filter(Tuple3<Byte, Long, float[]> entry) {
                return entry.f0.byteValue() == b;
            }
        });

        DataSet<Row> rows = tagged.map(new MapFunction<Tuple3<Byte, Long, float[]>, Row>() {
            private static final long serialVersionUID = 3477932515673402769L;

            @Override
            public Row map(Tuple3<Byte, Long, float[]> entry) {
                float[] single = entry.f2;
                double[] widened = new double[single.length];
                int pos = 0;
                for (float component : single) {
                    widened[pos++] = component;
                }
                return Row.of(entry.f1, VectorUtil.serialize(new DenseVector(widened)));
            }
        });

        String[] colNames = {str, "factors"};
        TypeInformation<?>[] colTypes = {Types.LONG, Types.STRING};
        return (BatchOperator) BatchOperator.fromTable(DataSetConversionUtil.toTable(l, rows, colNames, colTypes)).setMLEnvironmentId(l);
    }
}
