package com.alibaba.alink.operator.common.aps;

import com.alibaba.alink.common.MLEnvironmentFactory;
import com.alibaba.alink.common.exceptions.AkPreconditions;
import com.alibaba.alink.common.io.directreader.DefaultDistributedInfo;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.source.NumSeqSourceBatchOp;
import com.alibaba.alink.operator.batch.source.TableSourceBatchOp;
import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.flink.api.common.accumulators.IntCounter;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.operators.IterativeDataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.ParamInfoFactory;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/alibaba/alink/operator/common/aps/ApsContext.class */
/**
 * Holds the shared {@link Params} state of an APS job as a Flink {@link DataSet}.
 *
 * <p>The mutating operations ({@link #map(MapFunction)}, the {@code put} overloads and
 * {@link #updateLoopInfo(IterativeDataSet)}) replace the wrapped data set with a newly derived
 * data set and return {@code this}, so calls can be chained. Call {@link #clone()} first if the
 * current state must be kept unchanged.
 *
 * <p>The state can be written out with {@link #serialize(Long)} and read back with
 * {@link #deserilize(BatchOperator)}. The JSON form of each {@link Params} is split into rows of
 * at most 30000 characters. Each row carries a sequence id (starting at 1) in
 * {@code alinkmodelid} and one chunk in {@code alinkmodelinfo}; the {@code a1}/{@code a2}
 * columns are left null.
 */
public class ApsContext {
    public static final Logger LOG = LoggerFactory.getLogger(ApsContext.class);

    /** Maximum number of characters of serialized JSON stored in one row. */
    private static final int MAX_VARCHAR_SIZE = 30000;

    /** Key under which {@link #updateLoopInfo} stores the current superstep number. */
    private static final String alinkApsStepNum = "alinkApsStepNum";

    public static final String alinkApsBreakAll = "alinkApsBreakAll";
    public static ParamInfo<Boolean> ALINK_APS_BREAK_ALL = ParamInfoFactory
        .createParamInfo(alinkApsBreakAll, Boolean.class)
        .setDescription(alinkApsBreakAll)
        .setRequired()
        .build();

    public static final String alinkApsNumMiniBatch = "alinkApsNumMiniBatch";
    public static ParamInfo<Integer> ALINK_APS_NUM_MINI_BATCH = ParamInfoFactory
        .createParamInfo(alinkApsNumMiniBatch, Integer.class)
        .setDescription(alinkApsNumMiniBatch)
        .setRequired()
        .build();

    public static final String alinkApsNumCheckpoint = "alinkApsNumCheckpoint";
    public static ParamInfo<Integer> ALINK_APS_NUM_CHECK_POINT = ParamInfoFactory
        .createParamInfo(alinkApsNumCheckpoint, Integer.class)
        .setDescription(alinkApsNumCheckpoint)
        .setRequired()
        .build();

    public static final String alinkApsNumIter = "alinkApsNumIter";
    public static ParamInfo<Integer> ALINK_APS_NUM_ITER = ParamInfoFactory
        .createParamInfo(alinkApsNumIter, Integer.class)
        .setDescription(alinkApsNumIter)
        .setRequired()
        .build();

    public static ParamInfo<Long[]> SEEDS = ParamInfoFactory
        .createParamInfo("seeds", Long[].class)
        .setDescription("seeds")
        .setRequired()
        .build();

    private static final String alinkApsHasNextBlock = "alinkApsHasNextBlock";
    public static ParamInfo<Boolean> ALINK_APS_HAS_NEXT_BLOCK = ParamInfoFactory
        .createParamInfo(alinkApsHasNextBlock, Boolean.class)
        .setDescription(alinkApsHasNextBlock)
        .setRequired()
        .build();

    static final String alinkApsCurBlock = "alinkApsCurBlock";
    public static ParamInfo<Integer> ALINK_APS_CUR_BLOCK = ParamInfoFactory
        .createParamInfo(alinkApsCurBlock, Integer.class)
        .setDescription(alinkApsCurBlock)
        .setRequired()
        .build();

    public static final String alinkApsCurCheckpoint = "alinkApsCurCheckpoint";
    public static ParamInfo<Integer> ALINK_APS_CUR_CHECK_POINT = ParamInfoFactory
        .createParamInfo(alinkApsCurCheckpoint, Integer.class)
        .setDescription(alinkApsCurCheckpoint)
        .setRequired()
        .build();

    public static ParamInfo<Long> LIFECYCLE = ParamInfoFactory
        .createParamInfo("lifecycle", Long.class)
        .setDescription("lifecycle")
        .setRequired()
        .build();

    private static final String SCHEMA_PREFIX1 = "a1";
    private static final String SCHEMA_PREFIX2 = "a2";

    /** Schema of the serialized form produced by {@link #serialize(Long)}. */
    private static final TableSchema DEFAULT_MODEL_SCHEMA = new TableSchema(
        new String[] {"alinkmodelid", "alinkmodelinfo", SCHEMA_PREFIX1, SCHEMA_PREFIX2},
        new TypeInformation[] {Types.LONG, Types.STRING, Types.STRING, Types.STRING});

    /** The wrapped parameter state; replaced (not mutated) by every fluent operation. */
    protected DataSet<Params> context;

    /** Map function that merges a fixed {@link Params} instance into every context element. */
    public static class AppendParam2Context implements MapFunction<Params, Params> {
        private static final long serialVersionUID = -4020594804812453044L;
        private final Params params;

        /**
         * @param params the parameters to merge into each element
         */
        public AppendParam2Context(Params params) {
            this.params = params;
        }

        @Override
        public Params map(Params value) throws Exception {
            return value.merge(this.params);
        }
    }

    /**
     * Creates a context holding an empty {@link Params} in the default ML environment.
     *
     * @deprecated use {@link #ApsContext(long)} with an explicit ML environment id.
     */
    @Deprecated
    public ApsContext() {
        this(MLEnvironmentFactory.DEFAULT_ML_ENVIRONMENT_ID.longValue());
    }

    /**
     * Creates a context holding an empty {@link Params}.
     *
     * @param mlEnvironmentId id of the ML environment the backing source operator runs in
     */
    public ApsContext(long mlEnvironmentId) {
        // A number-sequence source from 1 to 1 seeds the context; presumably it yields exactly
        // one row (bounds assumed inclusive), which is turned into one empty Params.
        DataSet<Row> seed = ((NumSeqSourceBatchOp) new NumSeqSourceBatchOp()
            .setFrom(1L)
            .setTo(1L)
            .setMLEnvironmentId(Long.valueOf(mlEnvironmentId)))
            .getDataSet();
        this.context = seed.map(new MapFunction<Row, Params>() {
            private static final long serialVersionUID = -1788057684839131006L;

            @Override
            public Params map(Row row) throws Exception {
                return new Params();
            }
        });
    }

    /**
     * Creates a context that wraps a derived copy of {@code dataSet} in which every
     * {@link Params} element is cloned.
     *
     * @param dataSet the parameter state to wrap; rejected via {@code AkPreconditions.checkNotNull}
     *     when null
     */
    public ApsContext(DataSet<Params> dataSet) {
        AkPreconditions.checkNotNull(dataSet);
        this.context = dataSet.map(new MapFunction<Params, Params>() {
            private static final long serialVersionUID = 3148182956495365060L;

            @Override
            public Params map(Params params) throws Exception {
                return params.clone();
            }
        });
    }

    /**
     * Converts every {@link Params} element into rows holding its JSON form in chunks.
     *
     * <p>NOTE(review): sequence ids restart at 1 for every element, so the serialized form only
     * round-trips through {@link #deserilize(BatchOperator)} when the context has one element.
     */
    private static DataSet<Row> seriContext(DataSet<Params> dataSet) {
        return dataSet.flatMap(new FlatMapFunction<Params, Row>() {
            private static final long serialVersionUID = 426050062920251145L;

            @Override
            public void flatMap(Params params, Collector<Row> collector) throws Exception {
                List<Row> rows = new ArrayList<>();
                ApsContext.appendStringToModel(params.toJson(), rows);
                for (Row row : rows) {
                    collector.collect(row);
                }
            }
        });
    }

    /**
     * Rebuilds a context from the rows produced by {@link #serialize(Long)}.
     *
     * <p>All rows are gathered in a single group, re-assembled into one JSON string in id order,
     * and parsed with {@link Params#fromJson(String)}.
     *
     * @param batchOperator operator whose data set holds the serialized rows
     * @return a new context holding the restored parameters
     */
    public static ApsContext deserilize(BatchOperator<?> batchOperator) {
        DataSet<Params> restored = batchOperator.getDataSet().reduceGroup(
            new GroupReduceFunction<Row, Params>() {
                private static final long serialVersionUID = 9192617145467685004L;

                @Override
                public void reduce(Iterable<Row> rows, Collector<Params> collector) throws Exception {
                    List<Row> buffered = new ArrayList<>();
                    for (Row row : rows) {
                        buffered.add(row);
                    }
                    collector.collect(Params.fromJson(ApsContext.extractStringFromModel(buffered)));
                }
            });
        return new ApsContext(restored);
    }

    /**
     * Returns a new context whose data set is derived from this one with every {@link Params}
     * element cloned. Later fluent calls on either object replace only that object's data set.
     *
     * @return an independent context
     */
    @Override
    public ApsContext clone() {
        return new ApsContext(this.context);
    }

    /**
     * Same as {@link #clone()}.
     *
     * @deprecated decompiler-generated name kept for backward compatibility; use {@link #clone()}.
     */
    @Deprecated
    public ApsContext m307clone() {
        return clone();
    }

    /**
     * Applies {@code mapFunction} to every element of the context.
     *
     * @param mapFunction the transformation to apply
     * @return {@code this}, now wrapping the transformed data set
     */
    public ApsContext map(MapFunction<Params, Params> mapFunction) {
        this.context = this.context.map(mapFunction);
        return this;
    }

    /**
     * @return the currently wrapped data set
     */
    public DataSet<Params> getDataSet() {
        return this.context;
    }

    /**
     * Merges {@code params} into every element via {@link Params#merge(Params)}.
     *
     * @param params the parameters to merge
     * @return {@code this}
     */
    public ApsContext put(Params params) {
        this.context = this.context.map(new AppendParam2Context(params));
        return this;
    }

    /**
     * Merges the first element of {@code dataSet} into every element of the context.
     *
     * <p>{@code dataSet} is broadcast under the name {@code "info"}; only its element at index 0
     * is used, so it is expected to be non-empty.
     *
     * @param dataSet the parameters to merge
     * @return {@code this}
     */
    public ApsContext put(DataSet<Params> dataSet) {
        this.context = this.context.map(new RichMapFunction<Params, Params>() {
            private static final long serialVersionUID = -1634936932541926298L;
            private Params info;

            @Override
            public void open(Configuration configuration) throws Exception {
                this.info = getRuntimeContext().<Params>getBroadcastVariable("info").get(0);
            }

            @Override
            public Params map(Params params) throws Exception {
                return params.merge(this.info);
            }
        }).withBroadcastSet(dataSet, "info");
        return this;
    }

    /**
     * Merges the state of another context into this one.
     *
     * @param apsContext the context whose (first) element is merged
     * @return {@code this}
     */
    public ApsContext put(ApsContext apsContext) {
        return put(apsContext.context);
    }

    /**
     * Stores the first element of {@code dataSet} under {@code key} in every element of the
     * context.
     *
     * <p>{@code dataSet} is broadcast under the name {@code "dataSet"}; only its element at
     * index 0 is used, so it is expected to be non-empty.
     *
     * @param key the parameter name to set
     * @param dataSet the data set providing the value
     * @param <T> the value type
     * @return {@code this}
     */
    public <T> ApsContext put(final String key, DataSet<T> dataSet) {
        this.context = this.context.map(new RichMapFunction<Params, Params>() {
            private static final long serialVersionUID = 1639194499741682391L;
            private T data;

            @Override
            public void open(Configuration configuration) throws Exception {
                // Explicit type witness: the chained get(0) gives no target type for inference.
                this.data = getRuntimeContext().<T>getBroadcastVariable("dataSet").get(0);
            }

            @Override
            public Params map(Params params) throws Exception {
                return params.set(key, this.data);
            }
        }).withBroadcastSet(dataSet, "dataSet");
        return this;
    }

    /**
     * Serializes the context in the default ML environment.
     *
     * @deprecated use {@link #serialize(Long)} with an explicit ML environment id.
     */
    @Deprecated
    public BatchOperator serialize() {
        return serialize(MLEnvironmentFactory.DEFAULT_ML_ENVIRONMENT_ID);
    }

    /**
     * Serializes the context into a table with schema
     * {@code (alinkmodelid LONG, alinkmodelinfo STRING, a1 STRING, a2 STRING)}.
     *
     * @param mlEnvironmentId id of the ML environment for the resulting operator
     * @return a batch operator over the serialized rows
     */
    public BatchOperator<?> serialize(Long mlEnvironmentId) {
        return (BatchOperator<?>) new TableSourceBatchOp(
            DataSetConversionUtil.toTable(
                mlEnvironmentId,
                seriContext(this.context),
                DEFAULT_MODEL_SCHEMA.getFieldNames(),
                (TypeInformation<?>[]) DEFAULT_MODEL_SCHEMA.getFieldTypes()))
            .setMLEnvironmentId(mlEnvironmentId);
    }

    /**
     * Updates the per-superstep loop bookkeeping inside an iteration.
     *
     * <p>Records the current superstep number (emitted by subtask 0 of {@code iterativeDataSet})
     * under {@code alinkApsStepNum}. It then computes, from the number of mini-batches
     * {@code numMiniBatch}, iterations {@code numIter} (default 1) and checkpoints, which block
     * this superstep handles.
     * <ul>
     *   <li>{@link #ALINK_APS_CUR_BLOCK} is set to {@code step % numMiniBatch}, or -1 when the step
     *       lies outside the current range.</li>
     *   <li>{@link #ALINK_APS_HAS_NEXT_BLOCK} tells whether another step follows in the range.</li>
     * </ul>
     * When the current range ends at or beyond {@code numMiniBatch * numIter}, the
     * {@code alinkApsBreakAll} accumulator is incremented.
     *
     * @param iterativeDataSet the iteration head, used only to obtain the superstep number
     * @param <MT> the model element type carried by the iteration
     * @return {@code this}
     */
    public <MT> ApsContext updateLoopInfo(IterativeDataSet<Tuple2<Long, MT>> iterativeDataSet) {
        DataSet<Integer> superstep = iterativeDataSet.mapPartition(
            new RichMapPartitionFunction<Tuple2<Long, MT>, Integer>() {
                private static final long serialVersionUID = 4816930852791283240L;

                @Override
                public void mapPartition(Iterable<Tuple2<Long, MT>> values, Collector<Integer> out)
                    throws Exception {
                    // Only one subtask reports, so the broadcast value is unambiguous.
                    if (getRuntimeContext().getIndexOfThisSubtask() == 0) {
                        out.collect(getIterationRuntimeContext().getSuperstepNumber());
                    }
                }
            }).returns(Types.INT);
        put(alinkApsStepNum, superstep);

        this.context = this.context.map(new RichMapFunction<Params, Params>() {
            private static final long serialVersionUID = -7796590264875170687L;
            private final IntCounter counter = new IntCounter();

            @Override
            public void open(Configuration configuration) throws Exception {
                getRuntimeContext().addAccumulator(ApsContext.alinkApsBreakAll, this.counter);
            }

            @Override
            public Params map(Params params) throws Exception {
                int numCheckpoint = params.getInteger(ApsContext.alinkApsNumCheckpoint).intValue();
                int numIter = params.getIntegerOrDefault(ApsContext.alinkApsNumIter, 1).intValue();
                int numMiniBatch = params.getInteger(ApsContext.alinkApsNumMiniBatch).intValue();
                int totalSteps = numMiniBatch * numIter;

                // [rangeStart, rangeEnd) is the slice of steps handled before the next checkpoint.
                int rangeStart;
                int rangeEnd;
                if (numCheckpoint <= 0) {
                    rangeStart = 0;
                    rangeEnd = totalSteps;
                } else {
                    // Presumably DefaultDistributedInfo splits totalSteps into numCheckpoint
                    // contiguous slices, with curCheckpoint selecting one — TODO confirm.
                    int curCheckpoint = params.getInteger(ApsContext.alinkApsCurCheckpoint).intValue();
                    DefaultDistributedInfo distributedInfo = new DefaultDistributedInfo();
                    rangeStart = (int) distributedInfo.startPos(curCheckpoint, numCheckpoint, totalSteps);
                    rangeEnd = rangeStart
                        + (int) distributedInfo.localRowCnt(curCheckpoint, numCheckpoint, totalSteps);
                }

                // Supersteps are numbered from 1; convert to a global 0-based step index.
                int step = params.getInteger(ApsContext.alinkApsStepNum).intValue() - 1 + rangeStart;
                ApsContext.LOG.info("taskId:{}, stepNum:{}",
                    getRuntimeContext().getIndexOfThisSubtask(), step);

                boolean stepInRange = step < rangeEnd && step < totalSteps;
                int curBlock = stepInRange ? step % numMiniBatch : -1;
                boolean hasNextBlock = step < rangeEnd - 1 && step < totalSteps - 1;

                if (rangeEnd >= totalSteps) {
                    this.counter.add(1);
                }
                params.set(ApsContext.ALINK_APS_HAS_NEXT_BLOCK, hasNextBlock);
                params.set(ApsContext.ALINK_APS_CUR_BLOCK, curBlock);
                return params;
            }
        });
        return this;
    }

    /**
     * Builds a data set holding only the elements that still have a next block.
     *
     * <p>Elements whose {@code alinkApsBreakAll} flag is true are dropped and increment the
     * {@code alinkApsBreakAll} accumulator. Other elements are emitted only when
     * {@link #ALINK_APS_HAS_NEXT_BLOCK} is true. Presumably this is passed as the termination
     * criterion of a bulk iteration, where an empty result stops the loop — confirm at call sites.
     *
     * @return the filtered data set
     */
    public DataSet<Params> getCriterion() {
        return this.context.flatMap(new RichFlatMapFunction<Params, Params>() {
            private static final long serialVersionUID = -156438604518301112L;
            private final IntCounter counter = new IntCounter();

            @Override
            public void open(Configuration configuration) throws Exception {
                getRuntimeContext().addAccumulator(ApsContext.alinkApsBreakAll, this.counter);
            }

            @Override
            public void flatMap(Params params, Collector<Params> collector) throws Exception {
                if (params.getBoolOrDefault(ApsContext.alinkApsBreakAll, false)) {
                    this.counter.add(1);
                } else if (params.get(ApsContext.ALINK_APS_HAS_NEXT_BLOCK)) {
                    collector.collect(params);
                }
            }
        });
    }

    /**
     * Splits {@code str} into rows of at most 30000 characters, numbered from 1.
     *
     * @param str the string to split; null or empty appends nothing
     * @param list receives the new rows
     * @return the id that the next row would get
     */
    public static int appendStringToModel(String str, List<Row> list) {
        return appendStringToModel(str, list, 1);
    }

    private static int appendStringToModel(String str, List<Row> list, int startIndex) {
        return appendStringToModel(str, list, startIndex, MAX_VARCHAR_SIZE);
    }

    /**
     * Splits {@code str} into 4-field rows of at most {@code chunkSize} characters.
     *
     * <p>Field 0 holds the row id (as {@code Long}) and field 1 the chunk; fields 2 and 3 stay null.
     *
     * @param str the string to split; null or empty appends nothing
     * @param list receives the new rows
     * @param startIndex id of the first row
     * @param chunkSize maximum chunk length
     * @return the id that the next row would get
     * @throws IllegalArgumentException if {@code chunkSize} is not positive and {@code str} is
     *     non-empty
     */
    private static int appendStringToModel(String str, List<Row> list, int startIndex, int chunkSize) {
        if (str == null || str.isEmpty()) {
            return startIndex;
        }
        if (chunkSize <= 0) {
            // A non-positive chunk size would never advance and loop forever.
            throw new IllegalArgumentException("chunkSize must be positive, but got " + chunkSize);
        }
        int length = str.length();
        int nextIndex = startIndex;
        int begin = 0;
        while (begin < length) {
            // Computed as begin + min(...) so it cannot overflow for strings near Integer.MAX_VALUE.
            int end = begin + Math.min(chunkSize, length - begin);
            Row row = new Row(4);
            row.setField(0, Long.valueOf(nextIndex));
            row.setField(1, str.substring(begin, end));
            list.add(row);
            nextIndex++;
            begin = end;
        }
        return nextIndex;
    }

    /**
     * Re-assembles a string split by {@link #appendStringToModel(String, List)}.
     *
     * @param list the rows, in any order
     * @return the concatenated chunks
     */
    public static String extractStringFromModel(List<Row> list) {
        return extractStringFromModel(list, 1);
    }

    /**
     * Concatenates field 1 of the rows in ascending order of their field-0 id.
     *
     * <p>Rows with an id below {@code startIndex} are ignored. Concatenation stops at the first
     * missing id.
     *
     * @param list the rows, in any order
     * @param startIndex id of the first chunk
     * @return the concatenated chunks
     * @throws IllegalStateException if a row id is too large for the ids to be contiguous
     */
    private static String extractStringFromModel(List<Row> list, int startIndex) {
        int size = list.size();
        String[] segments = new String[size];
        for (Row row : list) {
            int id = ((Long) row.getField(0)).intValue();
            if (id < startIndex) {
                continue;
            }
            int slot = id - startIndex;
            if (slot >= size) {
                throw new IllegalStateException("Serialized row id " + id + " is out of range ["
                    + startIndex + ", " + (startIndex + size)
                    + "); the serialized rows are not a contiguous sequence.");
            }
            segments[slot] = (String) row.getField(1);
        }
        StringBuilder sb = new StringBuilder();
        for (String segment : segments) {
            if (segment == null) {
                break;
            }
            sb.append(segment);
        }
        return sb.toString();
    }
}
