package com.alibaba.alink.operator.common.timeseries;

import com.alibaba.alink.common.AlinkGlobalConfiguration;
import com.alibaba.alink.common.MTable;
import com.alibaba.alink.common.exceptions.AkIllegalDataException;
import com.alibaba.alink.common.io.plugin.ResourcePluginFactory;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.pyrunner.PyMIMOCalcHandle;
import com.alibaba.alink.common.pyrunner.PyMIMOCalcRunner;
import com.alibaba.alink.common.pyrunner.bridge.BasePythonBridge;
import com.alibaba.alink.common.type.AlinkTypes;
import com.alibaba.alink.common.utils.CloseableThreadLocal;
import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.params.dl.HasPythonEnv;
import com.alibaba.alink.params.timeseries.ProphetParams;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import java.lang.invoke.SerializedLambda;
import java.math.BigDecimal;
import java.sql.Timestamp;
import java.text.SimpleDateFormat;
import java.time.Instant;
import java.time.LocalDateTime;
import java.time.ZoneOffset;
import java.time.format.DateTimeFormatter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/alibaba/alink/operator/common/timeseries/ProphetMapper.class */
/**
 * Single-variable time-series mapper that forecasts each series with Prophet, executed in Python through a
 * {@link PyMIMOCalcRunner} running {@code algo.prophet.PyProphetCalc2}.
 *
 * <p>The Python runners are held in a {@link CloseableThreadLocal} that is created in {@link #open()} and
 * closed in {@link #close()}; presumably each worker thread gets its own runner.
 */
public class ProphetMapper extends TimeSeriesSingleMapper {
    private static final Logger LOG = LoggerFactory.getLogger(ProphetMapper.class);

    /** Python class that performs the Prophet fit and forecast. */
    private static final String PY_CALC_CLASS = "algo.prophet.PyProphetCalc2";

    /** Output column whose integral values are converted back into time stamps. */
    private static final String DS_COLUMN = "ds";

    /** Wall-clock pattern of the time stamps sent to Python (rendered in the JVM default zone). */
    private static final DateTimeFormatter DS_FORMATTER = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");

    /** Per-thread Python runners; created in {@link #open()}, so {@code null} before that. */
    private transient CloseableThreadLocal<PyMIMOCalcRunner<PyMIMOCalcHandle>> runner;

    /** Number of future points to forecast, read from {@link ProphetParams#PREDICT_NUM}. */
    private final int predictNum;

    private final ResourcePluginFactory factory;

    public ProphetMapper(TableSchema tableSchema, Params params) {
        super(tableSchema, params);
        this.predictNum = params.get(ProphetParams.PREDICT_NUM);
        this.factory = new ResourcePluginFactory();
    }

    @Override
    public void open() {
        super.open();
        this.runner = new CloseableThreadLocal<>(this::createPythonRunner, this::destroyPythonRunner);
    }

    /** Creates and opens one Python runner configured with logging and the optional Python environment. */
    private PyMIMOCalcRunner<PyMIMOCalcHandle> createPythonRunner() {
        Map<String, String> config = new HashMap<>();
        config.put(BasePythonBridge.PY_TURN_ON_LOGGING_KEY,
            String.valueOf(AlinkGlobalConfiguration.isPrintProcessInfo()));
        if (this.params.contains(HasPythonEnv.PYTHON_ENV)) {
            config.put("py_virtual_env", this.params.get(HasPythonEnv.PYTHON_ENV));
        }
        // The runner takes the config lookup as a serializable BiFunction; a method reference on the local
        // map satisfies it, and javac generates the lambda-deserialization support itself.
        PyMIMOCalcRunner<PyMIMOCalcHandle> pyRunner =
            new PyMIMOCalcRunner<>(PY_CALC_CLASS, config::getOrDefault, this.factory);
        pyRunner.open();
        return pyRunner;
    }

    /** Closes a runner created by {@link #createPythonRunner()}. */
    private void destroyPythonRunner(PyMIMOCalcRunner<PyMIMOCalcHandle> pyRunner) {
        pyRunner.close();
    }

    /** Closes all per-thread Python runners; safe to call even if {@link #open()} never ran. */
    @Override
    public void close() {
        if (this.runner != null) {
            this.runner.close();
            this.runner = null;
        }
    }

    /**
     * Forecasts one series with Prophet.
     *
     * @param historyTimes  time stamps of the history, expected ascending with the same length as historyVals
     * @param historyVals   observed values of the history
     * @param predictNumArg NOTE(review): ignored; the forecast length comes from {@link ProphetParams#PREDICT_NUM}
     * @return (forecast values, JSON of the Prophet output table), or (null, null) when the history has at
     *     most two points
     * @throws AkIllegalDataException if the last two time stamps are not strictly increasing
     */
    @Override
    public Tuple2<double[], String> predictSingleVar(Timestamp[] historyTimes, double[] historyVals,
                                                     int predictNumArg) {
        LOG.info("Entering predictSingleVar");
        if (historyVals.length <= 2) {
            LOG.info("historyVals.length <= 2");
            return Tuple2.of(null, null);
        }
        Map<String, String> conf = buildProphetConf(historyTimes);

        List<Row> inputs = new ArrayList<>(historyTimes.length);
        for (int idx = 0; idx < historyTimes.length; idx++) {
            // Timestamp.toLocalDateTime() yields the wall clock in the JVM default zone.
            String ds = historyTimes[idx].toLocalDateTime().format(DS_FORMATTER);
            inputs.add(Row.of(ds, historyVals[idx]));
        }

        Tuple3<String, String, double[]> result = warmStartProphet(this.runner, conf, inputs, null);
        LOG.info("Leaving predictSingleVar");
        return Tuple2.of(result.f2, result.f1);
    }

    /** Translates the mapper parameters into the string configuration consumed by the Python side. */
    private Map<String, String> buildProphetConf(Timestamp[] historyTimes) {
        Map<String, String> conf = new HashMap<>();
        conf.put("periods", String.valueOf(this.predictNum));
        conf.put("freq", getFreq(historyTimes));
        conf.put("uncertainty_samples", String.valueOf(this.params.get(ProphetParams.UNCERTAINTY_SAMPLES)));
        conf.put("init_model", this.params.get(ProphetParams.STAN_INIT));
        if (this.params.contains(ProphetParams.HOLIDAYS)) {
            conf.put("holidays", this.params.get(ProphetParams.HOLIDAYS));
        }
        if (this.params.contains(ProphetParams.CAP)) {
            conf.put("cap", String.valueOf(this.params.get(ProphetParams.CAP)));
        }
        if (this.params.contains(ProphetParams.FLOOR)) {
            conf.put("floor", String.valueOf(this.params.get(ProphetParams.FLOOR)));
        }
        if (this.params.contains(ProphetParams.CHANGE_POINTS)) {
            conf.put("changepoints", this.params.get(ProphetParams.CHANGE_POINTS));
        }
        // Locale.ROOT: default-locale lowercasing breaks values containing 'I' (e.g. "LINEAR") in Turkish locales.
        conf.put("growth", String.valueOf(this.params.get(ProphetParams.GROWTH)).toLowerCase(Locale.ROOT));
        conf.put("holidays_prior_scale", String.valueOf(this.params.get(ProphetParams.HOLIDAYS_PRIOR_SCALE)));
        conf.put("n_change_point", String.valueOf(this.params.get(ProphetParams.N_CHANGE_POINT)));
        conf.put("change_point_range", String.valueOf(this.params.get(ProphetParams.CHANGE_POINT_RANGE)));
        conf.put("changepoint_prior_scale",
            String.valueOf(this.params.get(ProphetParams.CHANGE_POINT_PRIOR_SCALE)));
        conf.put("interval_width", String.valueOf(this.params.get(ProphetParams.INTERVAL_WIDTH)));
        conf.put("seasonality_mode",
            String.valueOf(this.params.get(ProphetParams.SEASONALITY_MODE)).toLowerCase(Locale.ROOT));
        conf.put("seasonality_prior_scale",
            String.valueOf(this.params.get(ProphetParams.SEASONALITY_PRIOR_SCALE)));
        conf.put("mcmc_samples", String.valueOf(this.params.get(ProphetParams.MCMC_SAMPLES)));
        conf.put("yearly_seasonality", this.params.get(ProphetParams.YEARLY_SEASONALITY));
        conf.put("weekly_seasonality", this.params.get(ProphetParams.WEEKLY_SEASONALITY));
        conf.put("daily_seasonality", this.params.get(ProphetParams.DAILY_SEASONALITY));
        conf.put("include_history",
            String.valueOf(this.params.get(ProphetParams.INCLUDE_HISTORY)).toLowerCase(Locale.ROOT));
        return conf;
    }

    /**
     * Infers the sampling step from the last two time stamps.
     *
     * @param timestampArr history time stamps; at least two are required
     * @return the step in milliseconds followed by {@code "L"} (presumably a pandas offset alias, where "L"
     *     means milliseconds — TODO confirm against the Python side)
     * @throws AkIllegalDataException if fewer than two time stamps are given or the last step is not positive
     */
    public static String getFreq(Timestamp[] timestampArr) {
        if (timestampArr == null || timestampArr.length < 2) {
            throw new AkIllegalDataException("at least two history times are required to infer the frequency.");
        }
        int last = timestampArr.length - 1;
        long stepMillis = timestampArr[last].getTime() - timestampArr[last - 1].getTime();
        if (stepMillis <= 0) {
            throw new AkIllegalDataException("history times must be acs, and not equal.");
        }
        return stepMillis + "L";
    }

    /**
     * Runs the Python Prophet calculation and decodes its single output row.
     *
     * @param runner    per-thread Python runners
     * @param map       string configuration for the Python side
     * @param list      input rows of (wall-clock string, value)
     * @param str       optional initial model; when non-null it is passed to Python as a one-row model
     *                  input, otherwise no model input is passed
     * @return (first output field — presumably the fitted model, TODO confirm; JSON of the output table as an
     *     {@link MTable}; forecast values)
     */
    public static Tuple3<String, String, double[]> warmStartProphet(
            CloseableThreadLocal<PyMIMOCalcRunner<PyMIMOCalcHandle>> runner,
            Map<String, String> map, List<Row> list, String str) {
        LOG.info("Entering warmStartProphet");
        List<Row> modelRows = null;
        if (str != null) {
            LOG.info("initModel != null");
            modelRows = new ArrayList<>(1);
            modelRows.add(Row.of(str));
        } else {
            LOG.info("initModel == null");
        }
        List<Row> outputs = runner.get().calc(map, list, modelRows);
        LOG.info("after call calc");

        Row out = outputs.get(0);
        String modelOut = (String) out.getField(0);
        String tableJson = toOutputTableJson((String) out.getField(1));
        double[] forecast = (double[]) JsonConverter.fromJson((String) out.getField(2), double[].class);
        LOG.info("Leaving warmStartProphet");
        return Tuple3.of(modelOut, tableJson, forecast);
    }

    /**
     * Converts the column-oriented JSON from Python ({@code {column: {rowIndex: value}}}) into the JSON of an
     * equivalent {@link MTable}. Decimal cells become DOUBLE, integral cells LONG, except that integral cells
     * of the {@value #DS_COLUMN} column become SQL_TIMESTAMP. JSON nulls leave the cell empty.
     */
    private static String toOutputTableJson(String columnsJson) {
        JSONObject columns = (JSONObject) JSON.parse(columnsJson);
        String[] names = columns.keySet().toArray(new String[0]);
        int numCols = names.length;
        TypeInformation<?>[] types = new TypeInformation<?>[numCols];
        Arrays.fill(types, AlinkTypes.DOUBLE);

        Row[] rows = null;
        for (int col = 0; col < numCols; col++) {
            JSONObject cells = (JSONObject) columns.get(names[col]);
            if (rows == null) {
                // The table height is taken from the first column; presumably all columns share one index.
                rows = new Row[cells.size()];
                for (int r = 0; r < rows.length; r++) {
                    rows[r] = new Row(numCols);
                }
            }
            for (Map.Entry<String, Object> cell : cells.entrySet()) {
                int rowIdx = Integer.parseInt(cell.getKey());
                Object value = cell.getValue();
                if (value == null) {
                    continue;
                }
                if (value instanceof BigDecimal || value instanceof Double || value instanceof Float) {
                    rows[rowIdx].setField(col, ((Number) value).doubleValue());
                } else if (value instanceof Number) {
                    // fastjson may yield Integer or Long for integral JSON numbers.
                    long longValue = ((Number) value).longValue();
                    if (DS_COLUMN.equals(names[col])) {
                        types[col] = AlinkTypes.SQL_TIMESTAMP;
                        rows[rowIdx].setField(col, wallClockMillisToTimestamp(longValue));
                    } else {
                        types[col] = Types.LONG;
                        rows[rowIdx].setField(col, longValue);
                    }
                } else {
                    throw new AkIllegalDataException(
                        "Unsupported value type " + value.getClass().getName() + " in column " + names[col]);
                }
            }
        }
        return JsonConverter.toJson(new MTable(rows, names, types));
    }

    /**
     * Maps an epoch-millisecond value back to a local {@link Timestamp}.
     *
     * <p>Assumption (TODO confirm): the Python side returns the naive wall-clock times it received encoded as
     * if they were UTC. Reading the value as a UTC wall clock and re-interpreting it in the JVM default zone is
     * then the exact inverse of the input formatting for any zone (equal to subtracting 8 hours on UTC+8).
     */
    private static Timestamp wallClockMillisToTimestamp(long utcWallClockMillis) {
        return Timestamp.valueOf(LocalDateTime.ofInstant(Instant.ofEpochMilli(utcWallClockMillis), ZoneOffset.UTC));
    }
}
