package com.alibaba.alink.operator.common.aps;

import com.alibaba.alink.common.AlinkGlobalConfiguration;
import com.alibaba.alink.common.comqueue.IterTaskObjKeeper;
import com.alibaba.alink.common.exceptions.AkPreconditions;
import com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException;
import com.alibaba.alink.operator.common.optim.subfunc.OptimVariable;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.flink.api.common.functions.Function;
import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.functions.RichCoGroupFunction;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.functions.RichGroupReduceFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.operators.Order;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.operators.SingleInputUdfOperator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.util.Collector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

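/* loaded from: input_file:com/alibaba/alink/operator/common/aps/ApsOpExpr.class */
/**
 * Static builders for Flink {@code DataSet} pipelines that tag data records with the index of the
 * subtask holding them, look up the model entries those records request, and co-group the
 * looked-up entries with the records for user-supplied functions.
 */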
public class ApsOpExpr {
    private static final Logger LOG = LoggerFactory.getLogger(ApsOpExpr.class);

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/common/aps/ApsOpExpr$ModelPartitioner.class */
    public static class ModelPartitioner implements Partitioner<Long> {
        private static final long serialVersionUID = -6766513370607697864L;

        private ModelPartitioner() {
        }

        public int partition(Long l, int i) {
            return (int) (Math.abs(l.longValue()) % i);
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/common/aps/ApsOpExpr$RequestIndexFunction.class */
    /**
     * Serializable user function that lists the {@code Long} indices requested by a single record.
     *
     * @param <DT> type of the record
     */
    public interface RequestIndexFunction<DT> extends Function, Serializable {
        /**
         * Returns the indices requested by the given record.
         *
         * @param dt the record to inspect
         * @return the requested indices
         */
        List<Long> requestIndex(DT dt);
    }

    /**
     * Builds the "pull" stage on top of model state kept in {@link IterTaskObjKeeper}.
     *
     * <p>Every record is tagged with the index of the subtask that holds it and a running
     * per-subtask counter. The distinct model indices requested by each subtask are then routed
     * with {@link ModelPartitioner} and resolved against the state stored under handle {@code j},
     * which is expected to be a pair of (list of (index, value) entries, map from index to the
     * position in that list).
     *
     * <p>At runtime the lookup fails with an {@link AkUnclassifiedErrorException} if a requested
     * index is not present in the state of the subtask it was routed to.
     *
     * @param dataSet              the records
     * @param dataSet2             model data set; it supplies the value type of the result and is
     *                             attached as a broadcast set to the lookup operator
     * @param requestIndexFunction lists the model indices requested by one record
     * @param j                    handle under which the model state is stored in
     *                             {@link IterTaskObjKeeper}
     * @return {@code f0}: (subtask index, per-subtask record index, record);
     *         {@code f1}: (requesting subtask index, model index, model value)
     */
    private static <DT, MT> Tuple2<DataSet<Tuple3<Integer, Integer, DT>>, DataSet<Tuple3<Integer, Long, MT>>> pullBaseWithState(
            DataSet<DT> dataSet,
            DataSet<Tuple2<Long, MT>> dataSet2,
            final RequestIndexFunction<DT> requestIndexFunction,
            final long j) {
        final TypeInformation<DT> recordType = dataSet.getType();
        // getTypeAt is declared on the tuple type information, not on TypeInformation itself.
        final TypeInformation<MT> modelValueType =
            ((TupleTypeInfo<Tuple2<Long, MT>>) dataSet2.getType()).getTypeAt(1);

        DataSet<Tuple3<Integer, Integer, DT>> indexedData = dataSet
            .flatMap(new RichFlatMapFunction<DT, Tuple3<Integer, Integer, DT>>() {
                private static final long serialVersionUID = -3104031715299687938L;
                transient int pid;
                transient int idx;

                @Override
                public void open(Configuration configuration) throws Exception {
                    this.pid = getRuntimeContext().getIndexOfThisSubtask();
                    this.idx = 0;
                }

                @Override
                public void flatMap(DT dt, Collector<Tuple3<Integer, Integer, DT>> collector) throws Exception {
                    collector.collect(Tuple3.of(this.pid, this.idx, dt));
                    this.idx++;
                }
            })
            .name("ApsPartitionTrainData")
            .returns(new TupleTypeInfo<>(Types.INT, Types.INT, recordType));

        DataSet<Tuple3<Integer, Long, MT>> pulledModel = indexedData
            .mapPartition(new RichMapPartitionFunction<Tuple3<Integer, Integer, DT>, Tuple2<Integer, Long>>() {
                private static final long serialVersionUID = -266021351887071153L;

                @Override
                public void mapPartition(Iterable<Tuple3<Integer, Integer, DT>> iterable,
                                         Collector<Tuple2<Integer, Long>> collector) throws Exception {
                    // De-duplicate so that each index is requested once per partition.
                    HashSet<Long> requested = new HashSet<>();
                    Integer subtaskId = null;
                    for (Tuple3<Integer, Integer, DT> record : iterable) {
                        subtaskId = record.f0;
                        requested.addAll(requestIndexFunction.requestIndex(record.f2));
                    }
                    for (Long index : requested) {
                        collector.collect(Tuple2.of(subtaskId, index));
                    }
                }
            })
            .returns(new TupleTypeInfo<>(Types.INT, Types.LONG))
            .name("ApsRequestIndex")
            .partitionCustom(new ModelPartitioner(), 1)
            .mapPartition(new RichMapPartitionFunction<Tuple2<Integer, Long>, Tuple3<Integer, Long, MT>>() {
                private static final long serialVersionUID = 1064875426581432405L;

                @Override
                public void open(Configuration configuration) throws Exception {
                    ApsOpExpr.LOG.info("{}:{}", Thread.currentThread().getName(), "open");
                }

                @Override
                public void close() throws Exception {
                    ApsOpExpr.LOG.info("{}:{}", Thread.currentThread().getName(), "close");
                }

                @Override
                public void mapPartition(Iterable<Tuple2<Integer, Long>> iterable,
                                         Collector<Tuple3<Integer, Long, MT>> collector) throws Exception {
                    final int subtask = getRuntimeContext().getIndexOfThisSubtask();
                    @SuppressWarnings("unchecked")
                    Tuple2<List<Tuple2<Long, MT>>, Map<Long, Integer>> state =
                        (Tuple2<List<Tuple2<Long, MT>>, Map<Long, Integer>>) IterTaskObjKeeper.get(j, subtask);
                    AkPreconditions.checkArgument(state != null, "Can't get model from state.");
                    List<Tuple2<Long, MT>> entries = state.f0;
                    Map<Long, Integer> positionOf = state.f1;
                    for (Tuple2<Integer, Long> request : iterable) {
                        Integer position = positionOf.get(request.f1);
                        if (position == null) {
                            // Fail with context instead of a bare NullPointerException on unboxing.
                            throw new AkUnclassifiedErrorException(
                                "Requested index " + request.f1 + " is not in the model state of task " + subtask);
                        }
                        collector.collect(Tuple3.of(request.f0, request.f1, entries.get(position).f1));
                    }
                }
            })
            .name("ApsGetPartitionFeatureValue")
            // NOTE(review): the broadcast variable is never read by the function above; presumably
            // it only makes this operator depend on dataSet2 being computed first - confirm.
            .withBroadcastSet(dataSet2, OptimVariable.model)
            .returns(new TupleTypeInfo<>(Types.INT, Types.LONG, modelValueType));

        return new Tuple2<>(indexedData, pulledModel);
    }

    /**
     * Builds the "pull" stage against a model held in a data set.
     *
     * <p>Every record is tagged with the index of the subtask that holds it and a running
     * per-subtask counter. The tagged records are grouped by subtask index, sorted by the counter
     * and handed to {@code apsFuncIndex4Pull}, which emits (subtask index, model index) requests.
     * The requests are co-grouped with the model on the model index; every model entry with that
     * index is emitted once per request. A request whose index has no model entry produces no
     * output.
     *
     * @param dataSet           the records
     * @param dataSet2          model as (index, value) pairs; must have a two-field tuple type
     * @param apsFuncIndex4Pull group-reduce function producing the index requests
     * @param apsContext        context whose data set is broadcast to {@code apsFuncIndex4Pull}
     *                          under the name "RequestIndex"
     * @return {@code f0}: (subtask index, per-subtask record index, record);
     *         {@code f1}: (requesting subtask index, model index, model value)
     * @throws AkUnclassifiedErrorException if the model type is not a two-field tuple type
     */
    private static <DT, MT> Tuple2<DataSet<Tuple3<Integer, Integer, DT>>, DataSet<Tuple3<Integer, Long, MT>>> pullBase(
            DataSet<DT> dataSet,
            DataSet<Tuple2<Long, MT>> dataSet2,
            ApsFuncIndex4Pull<DT> apsFuncIndex4Pull,
            ApsContext apsContext) {
        final TypeInformation<DT> recordType = dataSet.getType();
        final TypeInformation<Tuple2<Long, MT>> modelType = dataSet2.getType();
        // The instanceof test guards the cast below: getTypeAt is only available on the tuple
        // type information, not on TypeInformation itself.
        if (!modelType.isTupleType() || modelType.getArity() != 2 || !(modelType instanceof TupleTypeInfo)) {
            throw new AkUnclassifiedErrorException("Unsupported model type. type: " + modelType.toString());
        }
        final TypeInformation<MT> modelValueType = ((TupleTypeInfo<Tuple2<Long, MT>>) modelType).getTypeAt(1);

        DataSet<Tuple3<Integer, Integer, DT>> indexedData = dataSet
            .mapPartition(new RichMapPartitionFunction<DT, Tuple3<Integer, Integer, DT>>() {
                private static final long serialVersionUID = 4694589881648017911L;

                @Override
                public void mapPartition(Iterable<DT> iterable,
                                         Collector<Tuple3<Integer, Integer, DT>> collector) throws Exception {
                    final int subtaskId = getRuntimeContext().getIndexOfThisSubtask();
                    int localIndex = 0;
                    for (DT record : iterable) {
                        collector.collect(Tuple3.of(subtaskId, localIndex, record));
                        localIndex++;
                    }
                }
            })
            .name("ApsPartitionTrainData")
            .returns(new TupleTypeInfo<>(Types.INT, Types.INT, recordType));

        DataSet<Tuple3<Integer, Long, MT>> pulledModel = indexedData
            .groupBy(0)
            .withPartitioner(new Partitioner<Integer>() {
                private static final long serialVersionUID = 5245425039908176380L;

                @Override
                public int partition(Integer key, int numPartitions) {
                    // Keys are the subtask indices written above.
                    return Math.abs(key.intValue()) % numPartitions;
                }
            })
            .sortGroup(1, Order.ASCENDING)
            .reduceGroup(apsFuncIndex4Pull)
            .withBroadcastSet(apsContext.getDataSet(), "RequestIndex")
            .name("ApsRequestIndex")
            .coGroup(dataSet2)
            .where(1)
            .equalTo(0)
            .with(new RichCoGroupFunction<Tuple2<Integer, Long>, Tuple2<Long, MT>, Tuple3<Integer, Long, MT>>() {
                private static final long serialVersionUID = 7387404099507374176L;

                @Override
                public void open(Configuration configuration) throws Exception {
                    ApsOpExpr.LOG.info("{}:{}", Thread.currentThread().getName(), "open");
                }

                @Override
                public void close() throws Exception {
                    ApsOpExpr.LOG.info("{}:{}", Thread.currentThread().getName(), "close");
                }

                @Override
                public void coGroup(Iterable<Tuple2<Integer, Long>> iterable,
                                    Iterable<Tuple2<Long, MT>> iterable2,
                                    Collector<Tuple3<Integer, Long, MT>> collector) throws Exception {
                    // Buffer the requests: they are replayed once per matching model entry.
                    List<Tuple2<Integer, Long>> requests = new ArrayList<>();
                    for (Tuple2<Integer, Long> request : iterable) {
                        requests.add(request);
                    }
                    if (requests.isEmpty()) {
                        return;
                    }
                    for (Tuple2<Long, MT> modelEntry : iterable2) {
                        for (Tuple2<Integer, Long> request : requests) {
                            collector.collect(Tuple3.of(request.f0, request.f1, modelEntry.f1));
                        }
                    }
                }
            })
            .name("ApsGetPartitionFeatureValue")
            .returns(new TupleTypeInfo<>(Types.INT, Types.LONG, modelValueType));

        return new Tuple2<>(indexedData, pulledModel);
    }

    /**
     * Pulls the model entries requested by the records and processes each subtask's records
     * together with the entries pulled for it.
     *
     * <p>The pulled entries and the tagged records produced by {@code pullBase} are co-grouped on
     * the subtask index (field 0 on both sides); the record side is sorted by its per-subtask
     * index before {@code apsFuncProc} is applied.
     *
     * @param dataSet           the records
     * @param dataSet2          model as (index, value) pairs
     * @param apsFuncIndex4Pull function producing the index requests, passed to {@code pullBase}
     * @param apsContext        context passed to {@code pullBase}
     * @param apsFuncProc       co-group function applied to (pulled entries, records) per subtask
     * @param apsContext2       context whose data set is broadcast to {@code apsFuncProc} under
     *                          the name "TrainSubset"
     * @return the output of {@code apsFuncProc}
     */
    public static <DT, MT, OT> DataSet<OT> pullProc(DataSet<DT> dataSet, DataSet<Tuple2<Long, MT>> dataSet2, ApsFuncIndex4Pull<DT> apsFuncIndex4Pull, ApsContext apsContext, ApsFuncProc<DT, MT, OT> apsFuncProc, ApsContext apsContext2) {
        Tuple2<DataSet<Tuple3<Integer, Integer, DT>>, DataSet<Tuple3<Integer, Long, MT>>> pulled =
            pullBase(dataSet, dataSet2, apsFuncIndex4Pull, apsContext);
        DataSet<Tuple3<Integer, Integer, DT>> indexedData = pulled.f0;
        DataSet<Tuple3<Integer, Long, MT>> pulledModel = pulled.f1;

        // Routes each co-group key (a subtask index) to a partition.
        Partitioner<Integer> bySubtaskIndex = new Partitioner<Integer>() {
            private static final long serialVersionUID = 2862199187133765866L;

            @Override
            public int partition(Integer key, int numPartitions) {
                return Math.abs(key.intValue()) % numPartitions;
            }
        };

        return pulledModel
            .coGroup(indexedData)
            .where(0)
            .equalTo(0)
            .sortSecondGroup(1, Order.ASCENDING)
            .withPartitioner(bySubtaskIndex)
            .with(apsFuncProc)
            .withBroadcastSet(apsContext2.getDataSet(), "TrainSubset")
            .name("ApsTrainSubset");
    }

    /**
     * Builds one pull / train / push round whose model lives in {@link IterTaskObjKeeper} instead
     * of flowing through the data set. The pipeline calls {@code getIterationRuntimeContext()},
     * so it has to run inside a Flink iteration.
     *
     * <ol>
     *   <li>"InitState": the model is partitioned with {@link ModelPartitioner}; in superstep 1
     *       each subtask stores its entries, plus a map from index to list position, under a
     *       fresh handle. This operator emits nothing.</li>
     *   <li>{@code pullBaseWithState} resolves the indices requested by the records against that
     *       state.</li>
     *   <li>{@code apsFuncTrain} is applied to (pulled entries, records) per subtask.</li>
     *   <li>"ApsUpdateModel": the training output is grouped by index and merged into the state
     *       with {@code apsFuncUpdateModel}. While the "stopCriterion" broadcast set is non-empty
     *       the state is updated in place and nothing is emitted. Once it is empty (last step)
     *       the state is removed from the keeper and the model is emitted: updated entries from
     *       {@code reduce}, untouched entries from {@code close}.</li>
     * </ol>
     *
     * <p>At runtime the update operator fails with an {@link AkUnclassifiedErrorException} if the
     * state of a subtask is missing or an updated index is not part of it.
     *
     * @param dataSet              the records
     * @param dataSet2             initial model as (index, value) pairs
     * @param apsContext           supplies the "curContext" broadcast set and the stop criterion
     * @param requestIndexFunction lists the model indices requested by one record
     * @param apsFuncTrain         co-group function producing (index, value) updates
     * @param apsFuncUpdateModel   merges the updates of one index into its current value
     * @return (index, value) pairs of the model; only populated on the last step
     */
    public static <DT, MT> DataSet<Tuple2<Long, MT>> pullTrainPushWithState(
            DataSet<DT> dataSet,
            DataSet<Tuple2<Long, MT>> dataSet2,
            ApsContext apsContext,
            RequestIndexFunction<DT> requestIndexFunction,
            ApsFuncTrain<DT, MT> apsFuncTrain,
            final ApsFuncUpdateModel<MT> apsFuncUpdateModel) {
        // getTypeAt is declared on the tuple type information, not on TypeInformation itself.
        final TypeInformation<MT> modelValueType =
            ((TupleTypeInfo<Tuple2<Long, MT>>) dataSet2.getType()).getTypeAt(1);
        final long newHandle = IterTaskObjKeeper.getNewHandle();

        DataSet<Tuple2<Long, MT>> initState = dataSet2
            .partitionCustom(new ModelPartitioner(), 0)
            .mapPartition(new RichMapPartitionFunction<Tuple2<Long, MT>, Tuple2<Long, MT>>() {
                private static final long serialVersionUID = -4095984329300130615L;

                @Override
                public void open(Configuration configuration) throws Exception {
                    Params params = (Params) getRuntimeContext().getBroadcastVariable("curContext").get(0);
                    if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
                        System.out.println("\n** " + params.toJson());
                        ApsOpExpr.LOG.info("init state:{}", Thread.currentThread().getName());
                    }
                }

                @Override
                public void mapPartition(Iterable<Tuple2<Long, MT>> iterable,
                                         Collector<Tuple2<Long, MT>> collector) throws Exception {
                    // The state is built once; later supersteps reuse what is in the keeper.
                    if (getIterationRuntimeContext().getSuperstepNumber() != 1) {
                        return;
                    }
                    final int subtask = getRuntimeContext().getIndexOfThisSubtask();
                    List<Tuple2<Long, MT>> entries = new ArrayList<>();
                    Map<Long, Integer> positionOf = new HashMap<>();
                    int count = 0;
                    for (Tuple2<Long, MT> entry : iterable) {
                        entries.add(entry);
                        positionOf.put(entry.f0, count);
                        count++;
                    }
                    if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
                        System.out.println("** # feature in model " + count);
                    }
                    IterTaskObjKeeper.put(newHandle, subtask, Tuple2.of(entries, positionOf));
                }
            })
            .name("InitState")
            .withBroadcastSet(apsContext.getDataSet(), "curContext")
            .returns(new TupleTypeInfo<>(Types.LONG, modelValueType));

        Tuple2<DataSet<Tuple3<Integer, Integer, DT>>, DataSet<Tuple3<Integer, Long, MT>>> pulled =
            pullBaseWithState(dataSet, initState, requestIndexFunction, newHandle);

        return pulled.f1
            .coGroup(pulled.f0)
            .where(0)
            .equalTo(0)
            .sortSecondGroup(1, Order.ASCENDING)
            .withPartitioner(new Partitioner<Integer>() {
                private static final long serialVersionUID = 3944668070162348246L;

                @Override
                public int partition(Integer key, int numPartitions) {
                    // Keys are subtask indices.
                    return Math.abs(key.intValue()) % numPartitions;
                }
            })
            .with(apsFuncTrain)
            .returns(new TupleTypeInfo<>(Types.LONG, modelValueType))
            .name("ApsTrainSubset")
            .groupBy(0)
            .withPartitioner(new ModelPartitioner())
            .reduceGroup(new RichGroupReduceFunction<Tuple2<Long, MT>, Tuple2<Long, MT>>() {
                private static final long serialVersionUID = 9134864393612660156L;
                transient List<Tuple2<Long, MT>> model;
                transient boolean[] collectFlag;
                transient Map<Long, Integer> fid2lid;
                transient boolean isLastStep;
                transient Collector<Tuple2<Long, MT>> collector;

                @Override
                public void open(Configuration configuration) throws Exception {
                    ApsOpExpr.LOG.info("{}:{}", Thread.currentThread().getName(), "open");
                    this.isLastStep = getRuntimeContext().getBroadcastVariable("stopCriterion").isEmpty();
                    final int subtask = getRuntimeContext().getIndexOfThisSubtask();

                    // On the last step the state is taken out of the keeper, otherwise only read.
                    Object stored;
                    if (this.isLastStep) {
                        stored = IterTaskObjKeeper.remove(newHandle, subtask);
                    } else {
                        stored = IterTaskObjKeeper.get(newHandle, subtask);
                    }
                    // Checked on both paths so a missing state never surfaces as a bare NPE.
                    if (stored == null) {
                        throw new AkUnclassifiedErrorException("Fail to get model for task " + subtask);
                    }
                    @SuppressWarnings("unchecked")
                    Tuple2<List<Tuple2<Long, MT>>, Map<Long, Integer>> state =
                        (Tuple2<List<Tuple2<Long, MT>>, Map<Long, Integer>>) stored;
                    this.model = state.f0;
                    this.fid2lid = state.f1;
                    if (this.isLastStep) {
                        // A new boolean[] is all false: no entry has been emitted yet.
                        this.collectFlag = new boolean[this.model.size()];
                    }
                }

                @Override
                public void close() throws Exception {
                    if (this.isLastStep) {
                        if (this.collector != null) {
                            // Emit the entries that received no update in the last step.
                            for (int i = 0; i < this.collectFlag.length; i++) {
                                if (!this.collectFlag[i]) {
                                    this.collector.collect(this.model.get(i));
                                }
                            }
                        } else if (this.model != null && !this.model.isEmpty()) {
                            // No group reached reduce() on this task, so no collector was
                            // captured and the entries held here cannot be emitted.
                            ApsOpExpr.LOG.warn("{}: no update reached this task in the last step; {} model entries are not emitted",
                                Thread.currentThread().getName(), this.model.size());
                        }
                    }
                    ApsOpExpr.LOG.info("{}:{}", Thread.currentThread().getName(), "close");
                }

                @Override
                public void reduce(Iterable<Tuple2<Long, MT>> iterable,
                                   Collector<Tuple2<Long, MT>> collector) throws Exception {
                    List<MT> newValues = new ArrayList<>();
                    Long featureId = null;
                    for (Tuple2<Long, MT> update : iterable) {
                        newValues.add(update.f1);
                        featureId = update.f0;
                    }
                    if (featureId == null) {
                        return;
                    }
                    Integer localId = this.fid2lid.get(featureId);
                    if (localId == null) {
                        // Fail with context instead of a bare NullPointerException on unboxing.
                        throw new AkUnclassifiedErrorException(
                            "Updated index " + featureId + " is not in the model state of task "
                                + getRuntimeContext().getIndexOfThisSubtask());
                    }
                    Tuple2<Long, MT> entry = this.model.get(localId);
                    MT updated = apsFuncUpdateModel.update(entry.f1, newValues);
                    if (!this.isLastStep) {
                        entry.f1 = updated;
                        return;
                    }
                    collector.collect(Tuple2.of(featureId, updated));
                    // Kept so close() can emit the entries that were not updated.
                    this.collector = collector;
                    this.collectFlag[localId] = true;
                }
            })
            .name("ApsUpdateModel")
            .withBroadcastSet(apsContext.getCriterion(), "stopCriterion")
            .returns(new TupleTypeInfo<>(Types.LONG, modelValueType));
    }
}
