package com.alibaba.alink.operator.common.timeseries;

import com.alibaba.alink.common.AlinkGlobalConfiguration;
import com.alibaba.alink.common.MTable;
import com.alibaba.alink.common.exceptions.AkIllegalDataException;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.mapper.Mapper;
import com.alibaba.alink.common.type.AlinkTypes;
import com.alibaba.alink.common.utils.Functional;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.params.timeseries.TimeSeriesPredictParams;
import java.io.Serializable;
import java.lang.invoke.SerializedLambda;
import java.sql.Timestamp;
import java.time.Duration;
import java.time.LocalDateTime;
import java.time.temporal.ChronoUnit;
import java.util.ArrayList;
import java.util.List;
import java.util.function.BiFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base {@link Mapper} for time-series forecasting operators.
 *
 * <p>The input column named by {@code valueCol} holds one time series per row, either as an {@link MTable} or as
 * the JSON string of one. Inside that table the first {@code SQL_TIMESTAMP} column is the time axis. If the table
 * has exactly one numeric column it is forecast through {@link #predictSingleVar}; otherwise its first vector
 * column, if any, is forecast through {@link #predictMultiVar}. The forecast is written as an {@code MTable} to the
 * prediction column; the optional prediction-detail column receives the detail string returned by the subclass,
 * or the error message when prediction fails.
 */
public abstract class TimeSeriesMapper extends Mapper {
    private static final Logger LOG = LoggerFactory.getLogger(TimeSeriesMapper.class);

    /** Created in {@link #open()}; binds the prediction parameters to this mapper's prediction methods. */
    private TimeSeries<TimeSeriesMapper> timeSeries;

    /**
     * Forecasts one time series stored in an {@link MTable}.
     *
     * <p>The prediction functions are invoked on {@code mapper}; they are kept as serializable functions so the
     * predictor can be shipped together with the mapper.
     *
     * @param <M> type of the object the prediction functions are invoked on
     */
    public static class Predictor<M> implements BiFunction<MTable, Integer, Tuple2<MTable, String>>, Serializable {
        private final M mapper;
        private final Functional.SerializableQuadFunction<M, Timestamp[], double[], Integer, Tuple2<double[], String>> predictSingleVar;
        private final Functional.SerializableQuadFunction<M, Timestamp[], Vector[], Integer, Tuple2<Vector[], String>> predictMultiVar;

        public Predictor(
            M mapper,
            Functional.SerializableQuadFunction<M, Timestamp[], double[], Integer, Tuple2<double[], String>> predictSingleVar,
            Functional.SerializableQuadFunction<M, Timestamp[], Vector[], Integer, Tuple2<Vector[], String>> predictMultiVar) {
            this.mapper = mapper;
            this.predictSingleVar = predictSingleVar;
            this.predictMultiVar = predictMultiVar;
        }

        /**
         * Forecasts {@code predictNum} future steps of the series in {@code mTable}.
         *
         * <p>Note: {@code mTable} is sorted in place by its timestamp column.
         *
         * @param mTable     the series; must contain a {@code SQL_TIMESTAMP} column when it has two or more rows
         * @param predictNum number of future steps to forecast
         * @return (forecast table, detail message). Both are {@code null} when there are fewer than two rows, when
         *     there is neither exactly one numeric column nor a vector column, or when the prediction function
         *     returns no values. When prediction throws, the forecast is {@code null} and the detail is the
         *     exception message.
         * @throws AkIllegalDataException if the table has two or more rows but no {@code SQL_TIMESTAMP} column
         */
        @Override
        public Tuple2<MTable, String> apply(MTable mTable, Integer predictNum) {
            TableSchema schema = mTable.getSchema();
            String[] fieldNames = schema.getFieldNames();
            TypeInformation<?>[] fieldTypes = schema.getFieldTypes();

            int numRow = mTable.getNumRow();
            // At least two observations are needed to infer the interval of the forecast timestamps.
            if (numRow < 2) {
                return Tuple2.of(null, null);
            }

            int timeColIdx = -1;
            for (int c = 0; c < fieldTypes.length && timeColIdx < 0; c++) {
                if (Types.SQL_TIMESTAMP.equals(fieldTypes[c])) {
                    timeColIdx = c;
                }
            }
            if (timeColIdx < 0) {
                throw new AkIllegalDataException("There is no timestamp column in the time series MTable.");
            }
            String timeCol = fieldNames[timeColIdx];

            mTable.orderBy(timeColIdx);
            Timestamp[] historyTimes = new Timestamp[numRow];
            for (int r = 0; r < numRow; r++) {
                historyTimes[r] = (Timestamp) mTable.getEntry(r, timeColIdx);
            }
            Timestamp[] predictTimes = getPredictTimes(historyTimes, predictNum);

            String[] numericCols = TableUtil.getNumericCols(schema);
            if (numericCols.length == 1) {
                int valueColIdx = TableUtil.findColIndex(schema, numericCols[0]);
                return predictSingle(mTable, timeCol, valueColIdx, numericCols[0], historyTimes, predictTimes,
                    predictNum);
            }

            int vectorColIdx = -1;
            for (int c = 0; c < fieldTypes.length && vectorColIdx < 0; c++) {
                if (TableUtil.isVector(fieldTypes[c])) {
                    vectorColIdx = c;
                }
            }
            if (vectorColIdx < 0) {
                return Tuple2.of(null, null);
            }
            return predictMulti(mTable, timeCol, vectorColIdx, fieldNames[vectorColIdx], historyTimes, predictTimes,
                predictNum);
        }

        /** Forecasts the single numeric column {@code valueCol} (at index {@code valueColIdx}). */
        private Tuple2<MTable, String> predictSingle(MTable mTable, String timeCol, int valueColIdx, String valueCol,
                                                     Timestamp[] historyTimes, Timestamp[] predictTimes,
                                                     Integer predictNum) {
            double[] history = new double[historyTimes.length];
            for (int r = 0; r < history.length; r++) {
                history[r] = ((Number) mTable.getEntry(r, valueColIdx)).doubleValue();
            }
            try {
                Tuple2<double[], String> predicted = predictSingleVar.apply(mapper, historyTimes, history, predictNum);
                if (predicted.f0 == null) {
                    return Tuple2.of(null, null);
                }
                List<Row> rows = new ArrayList<>(predicted.f0.length);
                for (int k = 0; k < predicted.f0.length; k++) {
                    rows.add(Row.of(predictTimes[k], predicted.f0[k]));
                }
                MTable forecast = new MTable(rows, new String[] {timeCol, valueCol},
                    new TypeInformation<?>[] {Types.SQL_TIMESTAMP, Types.DOUBLE});
                return Tuple2.of(forecast, predicted.f1);
            } catch (Throwable th) {
                return failure(th);
            }
        }

        /** Forecasts the vector column {@code vectorCol} (at index {@code vectorColIdx}). */
        private Tuple2<MTable, String> predictMulti(MTable mTable, String timeCol, int vectorColIdx, String vectorCol,
                                                    Timestamp[] historyTimes, Timestamp[] predictTimes,
                                                    Integer predictNum) {
            Vector[] history = new Vector[historyTimes.length];
            for (int r = 0; r < history.length; r++) {
                history[r] = VectorUtil.getVector(mTable.getEntry(r, vectorColIdx));
            }
            try {
                Tuple2<Vector[], String> predicted = predictMultiVar.apply(mapper, historyTimes, history, predictNum);
                if (predicted.f0 == null) {
                    return Tuple2.of(null, null);
                }
                List<Row> rows = new ArrayList<>(predicted.f0.length);
                for (int k = 0; k < predicted.f0.length; k++) {
                    rows.add(Row.of(predictTimes[k], predicted.f0[k]));
                }
                MTable forecast = new MTable(rows, new String[] {timeCol, vectorCol},
                    new TypeInformation<?>[] {Types.SQL_TIMESTAMP, AlinkTypes.VECTOR});
                return Tuple2.of(forecast, predicted.f1);
            } catch (Throwable th) {
                return failure(th);
            }
        }

        /**
         * Turns a failed prediction into a (null, message) result so that one bad series does not fail the job.
         * Throwable (not just Exception) is handled on purpose: prediction code may raise errors as well.
         */
        private static Tuple2<MTable, String> failure(Throwable th) {
            LOG.info("Exception caught: ", th);
            if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
                th.printStackTrace();
            }
            return Tuple2.of(null, th.getMessage());
        }
    }

    /** Adapts a {@link Predictor} to the sliced sample/result interface of {@link Mapper}. */
    static class TimeSeries<M> implements Serializable {
        private final int predictNum;
        private final boolean withPredDetail;
        private final Predictor<M> predictor;

        public TimeSeries(Params params, Predictor<M> predictor) {
            this.predictNum = params.get(TimeSeriesPredictParams.PREDICT_NUM);
            this.withPredDetail = params.contains(TimeSeriesPredictParams.PREDICTION_DETAIL_COL);
            this.predictor = predictor;
        }

        /**
         * Forecasts the series in selected column 0 and writes the forecast to result column 0. When the detail
         * column is configured, the detail message goes to result column 1.
         * Accepts an {@link MTable} or its JSON string; any other value yields a null forecast.
         */
        public final void map(Mapper.SlicedSelectedSample slicedSelectedSample, Mapper.SlicedResult slicedResult)
            throws Exception {
            Object value = slicedSelectedSample.get(0);
            final MTable series;
            if (value instanceof MTable) {
                series = (MTable) value;
            } else if (value instanceof String) {
                series = MTable.fromJson((String) value);
            } else {
                slicedResult.set(0, null);
                if (withPredDetail) {
                    slicedResult.set(1, "data is not MTable.");
                }
                return;
            }
            Tuple2<MTable, String> predicted = predictor.apply(series, predictNum);
            slicedResult.set(0, predicted.f0);
            if (withPredDetail) {
                slicedResult.set(1, predicted.f1);
            }
        }
    }

    public TimeSeriesMapper(TableSchema tableSchema, Params params) {
        super(tableSchema, params);
    }

    /**
     * Binds this mapper's prediction methods into a serializable {@link Predictor}.
     * javac synthesizes the {@code $deserializeLambda$} hook needed by these serializable method references;
     * it must not be declared by hand.
     */
    @Override
    public void open() {
        this.timeSeries = new TimeSeries<>(this.params,
            new Predictor<TimeSeriesMapper>(this, TimeSeriesMapper::predictSingleVar,
                TimeSeriesMapper::predictMultiVar));
    }

    @Override
    public final void map(Mapper.SlicedSelectedSample slicedSelectedSample, Mapper.SlicedResult slicedResult)
        throws Exception {
        this.timeSeries.map(slicedSelectedSample, slicedResult);
    }

    /**
     * Forecasts a univariate series.
     *
     * @param timestampArr observation times, sorted ascending
     * @param dArr         observed values aligned with {@code timestampArr}
     * @param i            number of future steps to forecast
     * @return (forecast values, detail message). A {@code null} forecast array produces an empty prediction.
     *     Values are paired with the timestamps from {@link #getPredictTimes}, so at most {@code i} values should
     *     be returned.
     */
    public abstract Tuple2<double[], String> predictSingleVar(Timestamp[] timestampArr, double[] dArr, int i);

    /**
     * Forecasts a multivariate (vector-valued) series.
     *
     * @param timestampArr observation times, sorted ascending
     * @param vectorArr    observed vectors aligned with {@code timestampArr}
     * @param i            number of future steps to forecast
     * @return (forecast vectors, detail message). A {@code null} forecast array produces an empty prediction;
     *     at most {@code i} vectors should be returned.
     */
    protected abstract Tuple2<Vector[], String> predictMultiVar(Timestamp[] timestampArr, Vector[] vectorArr, int i);

    /**
     * Validates that the value column holds an {@code MTable} or a string and builds the I/O schema.
     *
     * @throws AkIllegalDataException if the value column is missing or has another type
     */
    @Override
    protected final Tuple4<String[], String[], TypeInformation<?>[], String[]> prepareIoSchema(
        TableSchema tableSchema, Params params) {
        String valueCol = params.get(TimeSeriesPredictParams.VALUE_COL);
        TypeInformation<?> valueType = TableUtil.findColType(tableSchema, valueCol);
        if (valueType == null) {
            throw new AkIllegalDataException("valCol " + valueCol + " is not found in the input table.");
        }
        if (AlinkTypes.M_TABLE.equals(valueType) || Types.STRING.equals(valueType)) {
            return prepareTimeSeriesIoSchema(params);
        }
        throw new AkIllegalDataException("valCol must be mtable or string.");
    }

    /**
     * Builds the mapper I/O schema: the value column as input, the prediction column ({@code MTable}) plus the
     * optional prediction-detail column ({@code String}) as output, and the configured reserved columns.
     */
    public static Tuple4<String[], String[], TypeInformation<?>[], String[]> prepareTimeSeriesIoSchema(
        Params params) {
        List<String> outputCols = new ArrayList<>();
        List<TypeInformation<?>> outputTypes = new ArrayList<>();
        outputCols.add(params.get(TimeSeriesPredictParams.PREDICTION_COL));
        outputTypes.add(AlinkTypes.M_TABLE);
        if (params.contains(TimeSeriesPredictParams.PREDICTION_DETAIL_COL)) {
            outputCols.add(params.get(TimeSeriesPredictParams.PREDICTION_DETAIL_COL));
            outputTypes.add(Types.STRING);
        }
        return Tuple4.of(
            new String[] {params.get(TimeSeriesPredictParams.VALUE_COL)},
            outputCols.toArray(new String[0]),
            outputTypes.toArray(new TypeInformation<?>[0]),
            params.get(TimeSeriesPredictParams.RESERVED_COLS));
    }

    /**
     * Generates the timestamps of the next {@code i} steps after an observed series.
     *
     * <p>If all consecutive observations are the same number of milliseconds apart, that interval is repeated.
     * Otherwise the step between the last two observations is split into whole calendar months plus an exact
     * remaining duration. This keeps e.g. monthly or yearly series calendar-aligned.
     *
     * @param timestampArr observed timestamps (in ascending order, as {@link Predictor} sorts them); at least two
     * @param i            number of future timestamps to generate; must be non-negative
     * @return the {@code i} timestamps following the last observation
     * @throws IllegalArgumentException if fewer than two timestamps are given or {@code i} is negative
     */
    public static Timestamp[] getPredictTimes(Timestamp[] timestampArr, int i) {
        if (timestampArr == null || timestampArr.length < 2) {
            throw new IllegalArgumentException(
                "At least two timestamps are required to infer the prediction interval.");
        }
        if (i < 0) {
            throw new IllegalArgumentException(
                "The number of predicted timestamps must be non-negative, but was " + i + ".");
        }
        int n = timestampArr.length;
        long lastMillis = timestampArr[n - 1].getTime();
        long stepMillis = lastMillis - timestampArr[n - 2].getTime();

        boolean equallySpaced = true;
        for (int k = 0; k + 1 < n && equallySpaced; k++) {
            equallySpaced = timestampArr[k + 1].getTime() - timestampArr[k].getTime() == stepMillis;
        }

        Timestamp[] predictTimes = new Timestamp[i];
        if (equallySpaced) {
            for (int k = 0; k < i; k++) {
                predictTimes[k] = new Timestamp(lastMillis + stepMillis * (k + 1));
            }
            return predictTimes;
        }

        // Irregular spacing (e.g. monthly data with varying day counts). Field-wise differences (year, month, day,
        // hour, ...) must not be applied one after another: for a step such as 01-31 23:59 -> 02-01 00:00 that
        // gives a day difference of -30 and yields timestamps before the last observation.
        LocalDateTime previous = timestampArr[n - 2].toLocalDateTime();
        LocalDateTime cursor = timestampArr[n - 1].toLocalDateTime();
        long monthStep = ChronoUnit.MONTHS.between(previous, cursor);
        Duration remainder = Duration.between(previous.plusMonths(monthStep), cursor);
        for (int k = 0; k < i; k++) {
            cursor = cursor.plusMonths(monthStep).plus(remainder);
            predictTimes[k] = Timestamp.valueOf(cursor);
        }
        return predictTimes;
    }
}
