package com.alibaba.alink.operator.common.aps;

import com.alibaba.alink.common.MLEnvironmentFactory;
import com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.source.TableSourceBatchOp;
import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil;
import java.io.Serializable;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.ThreadLocalRandom;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.operators.IterativeDataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.api.java.utils.DataSetUtils;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/alibaba/alink/operator/common/aps/ApsEnv.class */
/**
 * Driver for checkpointed ("APS") iterative training on the Flink DataSet API.
 *
 * <p>{@link #iterate} runs {@code i2} consecutive Flink bulk iterations. After each one, the model
 * is written through {@link ApsCheckpoint}, the job is executed, and the model is read back, so
 * every checkpoint round starts from a persisted model. When no checkpoint backend is configured
 * ({@link #available()} is {@code false}), nothing is persisted and only the first round runs.
 *
 * @param <DT> type of a training data record
 * @param <MT> type of a model fragment; a model is a {@code DataSet<Tuple2<Long, MT>>}
 */
public class ApsEnv<DT, MT> implements Serializable {
    public static final Logger LOG = LoggerFactory.getLogger(ApsEnv.class);

    /** Maximum number of attempts made to persist one checkpoint before failing. */
    private static final int RETRY_TIMES = 10;

    /**
     * Value passed as {@code ApsContext.LIFECYCLE} when writing checkpoint tables.
     * NOTE(review): the unit is not visible here (presumably days) — confirm against ApsCheckpoint.
     */
    private static final long CHECKPOINT_LIFE_CYCLE = 28;

    /** Prefix of the temporary table names created for each checkpoint attempt. */
    private static final String CKPT_PREFIX = "alink_aps_tmp_ckpt";

    /** Slots 0..2 of a persisted batch hold model, data and context; extra operators follow. */
    private static final int NUM_FIXED_SLOTS = 3;

    private static final long serialVersionUID = 8477870238598909628L;

    private final ApsSerializeData<DT> apsSerializeData;
    private final ApsSerializeModel<MT> apsSerializeModel;

    /** ML environment id used for persisted/restored operators; transient, not serialized. */
    private final transient Long mlEnvId;

    /** Checkpoint backend; {@code null} disables persistence (see {@link #available()}). */
    private final transient ApsCheckpoint checkpoint;

    /** Updated from the {@code ApsContext.alinkApsBreakAll} accumulator after each persisted job. */
    private transient boolean isBreakAll;

    /**
     * Callback that {@link #iterate} applies to the model data set right after it has been
     * persisted and restored.
     *
     * @param <T> element type of the data set
     */
    public interface PersistentHook<T> {
        /** Returns the data set to continue with; the default returns {@code dataSet} unchanged. */
        default DataSet<T> hook(DataSet<T> dataSet) {
            return dataSet;
        }
    }

    /**
     * Creates an environment bound to the default ML environment.
     *
     * @deprecated use {@link #ApsEnv(ApsCheckpoint, ApsSerializeData, ApsSerializeModel, Long)} and
     *     pass the ML environment id explicitly.
     */
    @Deprecated
    public ApsEnv(ApsCheckpoint apsCheckpoint, ApsSerializeData<DT> apsSerializeData,
                  ApsSerializeModel<MT> apsSerializeModel) {
        this(apsCheckpoint, apsSerializeData, apsSerializeModel, MLEnvironmentFactory.DEFAULT_ML_ENVIRONMENT_ID);
    }

    /**
     * Creates an environment.
     *
     * @param apsCheckpoint     checkpoint backend, or {@code null} to disable persistence
     * @param apsSerializeData  converts training data to/from persistable operators
     * @param apsSerializeModel converts the model to/from persistable operators
     * @param l                 ML environment id for persisted/restored operators
     */
    public ApsEnv(ApsCheckpoint apsCheckpoint, ApsSerializeData<DT> apsSerializeData,
                  ApsSerializeModel<MT> apsSerializeModel, Long l) {
        this.apsSerializeData = apsSerializeData;
        this.checkpoint = apsCheckpoint;
        this.apsSerializeModel = apsSerializeModel;
        this.mlEnvId = l;
    }

    /**
     * Same as the full {@code iterate} overload, without extra operators or hook; returns only the
     * trained model.
     */
    public DataSet<Tuple2<Long, MT>> iterate(DataSet<Tuple2<Long, MT>> dataSet, DataSet<DT> dataSet2,
                                             ApsContext apsContext, boolean z, int i, int i2, Params params,
                                             ApsIterator<DT, MT> apsIterator) {
        return iterate(dataSet, dataSet2, apsContext, null, z, i, i2, params, apsIterator).f0;
    }

    /** Same as the full {@code iterate} overload, without a persistence hook. */
    public Tuple2<DataSet<Tuple2<Long, MT>>, BatchOperator[]> iterate(
            DataSet<Tuple2<Long, MT>> dataSet, DataSet<DT> dataSet2, ApsContext apsContext,
            BatchOperator[] batchOperatorArr, boolean z, int i, int i2, Params params,
            ApsIterator<DT, MT> apsIterator) {
        return iterate(dataSet, dataSet2, apsContext, batchOperatorArr, z, i, i2, params, apsIterator, null);
    }

    /**
     * Trains a model in {@code i2} checkpointed rounds.
     *
     * @param dataSet          initial model
     * @param dataSet2         training data; every record is indexed with {@code zipWithIndex}
     * @param apsContext       training context
     * @param batchOperatorArr extra operators handed to the iterator (may be {@code null})
     * @param z                whether to persist model, data, context and extra operators first
     * @param i                stored into the context as {@code ALINK_APS_NUM_ITER}
     * @param i2               number of checkpoint rounds; stored as {@code ALINK_APS_NUM_CHECK_POINT}
     * @param params           passed through to {@link ApsIterator#train}
     * @param apsIterator      performs one round of training
     * @param persistentHook   applied to the model after each persistence step (may be {@code null})
     * @return the final model and the (possibly restored) extra operators. Rounds stop early when
     *     {@link #breakAll()} becomes {@code true}.
     */
    @SuppressWarnings("unchecked")
    public Tuple2<DataSet<Tuple2<Long, MT>>, BatchOperator[]> iterate(
            DataSet<Tuple2<Long, MT>> dataSet, DataSet<DT> dataSet2, ApsContext apsContext,
            BatchOperator[] batchOperatorArr, boolean z, final int i, final int i2, Params params,
            ApsIterator<DT, MT> apsIterator, PersistentHook persistentHook) {
        DataSet<Tuple2<Long, MT>> model = dataSet;
        DataSet<Tuple2<Long, DT>> indexedData = DataSetUtils.zipWithIndex(dataSet2);
        BatchOperator[] sideOperators = batchOperatorArr;

        // Stamp the checkpoint count and the iteration count into the context.
        ApsContext context = apsContext.map(new MapFunction<Params, Params>() {
            private static final long serialVersionUID = 1571322331049253987L;

            @Override
            public Params map(Params ctx) throws Exception {
                return ctx.set(ApsContext.ALINK_APS_NUM_CHECK_POINT, i2)
                    .set(ApsContext.ALINK_APS_NUM_ITER, i);
            }
        });

        if (z) {
            Tuple4<DataSet<Tuple2<Long, MT>>, DataSet<Tuple2<Long, DT>>, ApsContext, BatchOperator[]> persisted =
                persistentAll(model, indexedData, context, sideOperators, "input");
            model = persisted.f0;
            indexedData = persisted.f1;
            context = persisted.f2;
            sideOperators = persisted.f3;
            if (persistentHook != null) {
                model = persistentHook.hook(model);
            }
        }

        for (int ckpt = 0; ckpt < i2; ckpt++) {
            LOG.info("ckptStart:{}", ckpt);
            ApsContext ckptContext = context.m307clone()
                .put(new Params().set(ApsContext.ALINK_APS_CUR_CHECK_POINT, ckpt));
            IterativeDataSet<Tuple2<Long, MT>> loop = model.iterate(Integer.MAX_VALUE);
            ckptContext.updateLoopInfo(loop);
            Tuple2<DataSet<Tuple2<Long, MT>>, ApsContext> trained = apsIterator.train(
                loop, getPartition(indexedData, ckptContext), ckptContext, sideOperators, params);
            // The iteration cap is Integer.MAX_VALUE, so termination is left to the criterion
            // data set returned by the trained context.
            model = persistentModel(loop.closeWith(trained.f0, trained.f1.getCriterion()), "model_" + ckpt);
            if (persistentHook != null) {
                model = persistentHook.hook(model);
            }
            if (breakAll()) {
                break;
            }
        }
        return new Tuple2<>(model, sideOperators);
    }

    /** Returns whether a checkpoint backend is configured. */
    public boolean available() {
        return null != this.checkpoint;
    }

    /**
     * Returns whether checkpoint rounds should stop. This is the case when there is no checkpoint
     * backend, or when the last persisted job reported a positive break-all accumulator.
     */
    public boolean breakAll() {
        return !available() || this.isBreakAll;
    }

    /**
     * Persists model, data, context and extra operators in a single job and reads them back.
     *
     * <p>Without a checkpoint backend the inputs are returned unchanged.
     *
     * @param dataSet          model to persist
     * @param dataSet2         indexed training data, or {@code null}
     * @param apsContext       context, or {@code null}
     * @param batchOperatorArr extra operators, or {@code null}
     * @param str              tag embedded in the checkpoint table names
     * @return restored model, data, context and extra operators. A component is {@code null} when
     *     its input was {@code null}; extras are {@code null} when none were given.
     */
    public Tuple4<DataSet<Tuple2<Long, MT>>, DataSet<Tuple2<Long, DT>>, ApsContext, BatchOperator[]> persistentAll(
            DataSet<Tuple2<Long, MT>> dataSet, DataSet<Tuple2<Long, DT>> dataSet2, ApsContext apsContext,
            BatchOperator[] batchOperatorArr, String str) {
        if (!available()) {
            return new Tuple4<>(dataSet, dataSet2, apsContext, batchOperatorArr);
        }
        // NOTE(review): getTypeAt is a CompositeType method; the decompiled source may have dropped a
        // TupleTypeInfo cast here — confirm this compiles against the project's Flink version.
        TypeInformation<DT> dataType = null == dataSet2 ? null : dataSet2.getType().getTypeAt(1);
        int extraCount = null == batchOperatorArr ? 0 : batchOperatorArr.length;

        BatchOperator[] toPersist = new BatchOperator[NUM_FIXED_SLOTS + extraCount];
        toPersist[0] = this.apsSerializeModel.serializeModel(dataSet, this.mlEnvId);
        toPersist[1] = null == dataSet2 ? null : this.apsSerializeData.serializeData(dataSet2, this.mlEnvId);
        toPersist[2] = null == apsContext ? null : apsContext.serialize(this.mlEnvId);
        if (extraCount > 0) {
            System.arraycopy(batchOperatorArr, 0, toPersist, NUM_FIXED_SLOTS, extraCount);
        }

        BatchOperator<?>[] restored = persistentWithRetry(str, toPersist);

        DataSet<Tuple2<Long, MT>> restoredModel = null == restored[0]
            ? null : this.apsSerializeModel.deserilizeModel(restored[0], dataSet.getType());
        DataSet<Tuple2<Long, DT>> restoredData = null == restored[1]
            ? null : this.apsSerializeData.deserializeData(restored[1], dataType);
        ApsContext restoredContext = null == restored[2] ? null : ApsContext.deserilize(restored[2]);
        BatchOperator[] restoredExtras = null;
        if (extraCount > 0) {
            restoredExtras = new BatchOperator[extraCount];
            System.arraycopy(restored, NUM_FIXED_SLOTS, restoredExtras, 0, extraCount);
        }
        return new Tuple4<>(restoredModel, restoredData, restoredContext, restoredExtras);
    }

    /**
     * Calls {@link #persistent} up to {@link #RETRY_TIMES} times. Each attempt uses fresh table
     * names so that a failed attempt does not collide with the next one.
     *
     * @throws AkUnclassifiedErrorException when every attempt failed; the cause is the last failure
     */
    private BatchOperator[] persistentWithRetry(String str, BatchOperator[] batchOperatorArr) {
        Exception lastFailure = null;
        for (int attempt = 0; attempt < RETRY_TIMES; attempt++) {
            try {
                return persistent(CKPT_PREFIX + "_" + str + "_" + attempt, batchOperatorArr);
            } catch (Exception e) {
                lastFailure = e;
                LOG.warn("Checkpoint '{}' failed on attempt {}/{}.", str, attempt + 1, RETRY_TIMES, e);
            }
        }
        throw new AkUnclassifiedErrorException("Error. ", lastFailure);
    }

    /**
     * Writes every non-null operator to a temp table, executes the job, and reads the tables back.
     *
     * <p>Rows of each restored table are redistributed over partitions by a random key. As a side
     * effect, {@link #isBreakAll} is refreshed from the job's accumulators.
     *
     * @return restored operators (position-aligned with the input; {@code null} stays {@code null}),
     *     the input itself when no backend is configured, or {@code null} for null/empty input
     */
    private BatchOperator[] persistent(String str, BatchOperator[] batchOperatorArr) throws Exception {
        if (!available()) {
            return batchOperatorArr;
        }
        if (batchOperatorArr == null || batchOperatorArr.length == 0) {
            return null;
        }

        String[] tableNames = new String[batchOperatorArr.length];
        for (int idx = 0; idx < batchOperatorArr.length; idx++) {
            if (batchOperatorArr[idx] != null) {
                tableNames[idx] = TableUtil.getTempTableName(str + "_" + idx);
                this.checkpoint.write(batchOperatorArr[idx], tableNames[idx], this.mlEnvId,
                    new Params().set(ApsContext.LIFECYCLE, CHECKPOINT_LIFE_CYCLE));
            }
        }

        Map<String, Object> accumulators = BatchOperator.getExecutionEnvironmentFromOps(batchOperatorArr)
            .execute()
            .getAllAccumulatorResults();
        LOG.info("{}:{}", ApsContext.alinkApsBreakAll, accumulators);
        // Tolerate a missing, null or non-Integer accumulator value instead of throwing.
        Object breakFlag = accumulators.get(ApsContext.alinkApsBreakAll);
        this.isBreakAll = breakFlag instanceof Number && ((Number) breakFlag).intValue() > 0;

        BatchOperator[] restored = new BatchOperator[batchOperatorArr.length];
        for (int idx = 0; idx < tableNames.length; idx++) {
            if (tableNames[idx] != null) {
                restored[idx] = randomRepartition(this.checkpoint.read(tableNames[idx], this.mlEnvId, new Params()));
            }
        }
        return restored;
    }

    /** Wraps {@code op} in a new table source whose rows are spread over partitions randomly. */
    private BatchOperator<?> randomRepartition(BatchOperator<?> op) {
        DataSet<Row> shuffled = op.getDataSet()
            .partitionCustom(new ModuloPartitioner(), new RandomKeySelector());
        return new TableSourceBatchOp(
            DataSetConversionUtil.toTable(this.mlEnvId, shuffled, op.getColNames(), op.getColTypes()))
            .setMLEnvironmentId(this.mlEnvId);
    }

    /** Persists only the model; shorthand for {@link #persistentAll} with everything else null. */
    public DataSet<Tuple2<Long, MT>> persistentModel(DataSet<Tuple2<Long, MT>> dataSet, String str) {
        return persistentAll(dataSet, null, null, null, str).f0;
    }

    /**
     * Selects the current mini-batch from the indexed data.
     *
     * <p>A record is kept when {@code index % numMiniBatch == curBlock}. Both values are read in
     * {@code open} from the first {@link Params} of the broadcast set "ApsContext".
     */
    private DataSet<DT> getPartition(DataSet<Tuple2<Long, DT>> dataSet, ApsContext apsContext) {
        return dataSet.flatMap(new RichFlatMapFunction<Tuple2<Long, DT>, DT>() {
            private static final long serialVersionUID = 343096718539266136L;
            private int numMiniBatch;
            private int curBlock;

            @Override
            public void open(Configuration configuration) throws Exception {
                LOG.info("{}:{}", Thread.currentThread().getName(), "open");
                Params ctx = (Params) getRuntimeContext().getBroadcastVariable("ApsContext").get(0);
                numMiniBatch = ctx.getInteger(ApsContext.alinkApsNumMiniBatch);
                curBlock = ctx.getInteger("alinkApsCurBlock");
            }

            @Override
            public void close() throws Exception {
                LOG.info("{}:{}", Thread.currentThread().getName(), "close");
            }

            @Override
            public void flatMap(Tuple2<Long, DT> record, Collector<DT> out) throws Exception {
                if (record.f0 % numMiniBatch == curBlock) {
                    out.collect(record.f1);
                }
            }
        }).withBroadcastSet(apsContext.getDataSet(), "ApsContext")
            .returns(dataSet.getType().getTypeAt(1))
            .name("SelectBlockData");
    }

    /** Routes a non-negative integer key to partition {@code key % numPartitions}. */
    private static final class ModuloPartitioner implements Partitioner<Integer> {
        private static final long serialVersionUID = -8803083408042377645L;

        @Override
        public int partition(Integer key, int numPartitions) {
            return key % numPartitions;
        }
    }

    /** Assigns each row a random key in {@code [0, Integer.MAX_VALUE)}. */
    private static final class RandomKeySelector implements KeySelector<Row, Integer> {
        private static final long serialVersionUID = 7152274198229092623L;

        @Override
        public Integer getKey(Row row) {
            // ThreadLocalRandom avoids allocating and seeding a new Random for every record.
            return ThreadLocalRandom.current().nextInt(Integer.MAX_VALUE);
        }
    }
}
