package com.alibaba.alink.operator.common.feature;

import com.alibaba.alink.common.exceptions.AkIllegalArgumentException;
import com.alibaba.alink.common.exceptions.AkIllegalDataException;
import com.alibaba.alink.common.exceptions.AkIllegalOperatorParameterException;
import com.alibaba.alink.common.exceptions.AkPreconditions;
import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.mapper.Mapper;
import com.alibaba.alink.common.mapper.ModelMapper;
import com.alibaba.alink.common.type.AlinkTypes;
import com.alibaba.alink.operator.common.tree.Preprocessing;
import com.alibaba.alink.params.dataproc.HasHandleInvalid;
import com.alibaba.alink.params.feature.HasEncodeWithoutWoe;
import com.alibaba.alink.params.feature.QuantileDiscretizerPredictParams;
import com.alibaba.alink.params.shared.colname.HasOutputCol;
import com.alibaba.alink.params.shared.colname.HasSelectedCols;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.stream.IntStream;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

/* loaded from: input_file:com/alibaba/alink/operator/common/feature/QuantileDiscretizerModelMapper.class */
public class QuantileDiscretizerModelMapper extends ModelMapper implements Cloneable {
    private static final long serialVersionUID = 5400967430347827818L;
    private DiscreteMapperBuilder mapperBuilder;
    private final String[] inputPredictColNames;

    /**
     * Holds everything needed to turn the selected input values of one row into encoded outputs:
     * the parsed prediction parameters, one discretizer per selected column, and the per-column
     * vector sizes / dropped indices that the encoders consult.
     *
     * <p>{@link #predictIndices} is transient and only created by {@link #open()}, so
     * {@code open()} has to be called before {@code map(...)} is used.
     */
    public static class DiscreteMapperBuilder implements Serializable, Cloneable {
        private static final long serialVersionUID = -1726998479492235578L;

        /** Parsed prediction parameters (selected/output columns, encode, invalid handling). */
        DiscreteParamsBuilder paramsBuilder;
        /** Selected-column position -> size of the vector produced for that column. */
        Map<Integer, Long> vectorSize = new HashMap<>();
        /** Selected-column position -> index that is dropped when dropLast is enabled. */
        Map<Integer, Long> dropIndex = new HashMap<>();
        /** Size of the assembled vector, computed by {@link #setAssembledVectorSize()}. */
        Integer assembledVectorSize;
        /** One discretizer per selected column, filled in by the enclosing mapper. */
        NumericQuantileDiscretizer[] discretizers;
        /** Per-thread scratch array with one slot per selected column. */
        transient ThreadLocal<Long[]> predictIndices;

        /**
         * Parses the parameters and allocates one (still empty) discretizer slot per selected column.
         *
         * @param params      prediction parameters
         * @param tableSchema schema handed through to {@link DiscreteParamsBuilder}
         */
        public DiscreteMapperBuilder(Params params, TableSchema tableSchema) {
            HasEncodeWithoutWoe.Encode encodeType =
                (HasEncodeWithoutWoe.Encode) params.get(QuantileDiscretizerPredictParams.ENCODE);
            paramsBuilder = new DiscreteParamsBuilder(params, tableSchema, encodeType);
            discretizers = new NumericQuantileDiscretizer[paramsBuilder.selectedCols.length];
        }

        /** Creates the thread-local scratch array used by {@code map(...)}. */
        public void open() {
            predictIndices = ThreadLocal.withInitial(() -> new Long[paramsBuilder.selectedCols.length]);
        }

        /**
         * Computes the assembled vector size as the sum of all per-column vector sizes, minus one
         * entry per column when dropLast is enabled.
         */
        public void setAssembledVectorSize() {
            int total = 0;
            for (Long size : vectorSize.values()) {
                total += size.intValue();
            }
            if (paramsBuilder.dropLast) {
                total -= vectorSize.size();
            }
            assembledVectorSize = total;
        }

        /**
         * Discretizes every selected value of the sample, applies the invalid-value strategy and
         * writes the encoded result into {@code slicedResult}.
         *
         * @param slicedSelectedSample selected input values, addressed by selected-column position
         * @param slicedResult         receiver of the encoded output fields
         * @throws AkIllegalDataException if a value is invalid and the strategy is ERROR
         * @throws AkIllegalOperatorParameterException if the strategy is not KEEP, SKIP or ERROR
         */
        public void map(Mapper.SlicedSelectedSample slicedSelectedSample, Mapper.SlicedResult slicedResult) {
            Long[] indices = predictIndices.get();
            int columnCount = paramsBuilder.selectedCols.length;
            for (int col = 0; col < columnCount; col++) {
                Object value = slicedSelectedSample.get(col);
                NumericQuantileDiscretizer discretizer = discretizers[col];
                int index = discretizer.findIndex(value);
                indices[col] = (long) index;
                if (discretizer.isValid(index)) {
                    continue;
                }
                switch (paramsBuilder.handleInvalidStrategy) {
                    case KEEP:
                        // The index reported by the discretizer is kept as is.
                        break;
                    case SKIP:
                        indices[col] = null;
                        break;
                    case ERROR:
                        throw new AkIllegalDataException("Unseen token: " + value);
                    default:
                        throw new AkIllegalOperatorParameterException("Invalid handle invalid strategy.");
                }
            }
            Row encoded = QuantileDiscretizerModelMapper.setResultRow(
                indices,
                paramsBuilder.encode,
                dropIndex,
                vectorSize,
                paramsBuilder.dropLast,
                assembledVectorSize.intValue());
            int arity = encoded.getArity();
            for (int pos = 0; pos < arity; pos++) {
                slicedResult.set(pos, encoded.getField(pos));
            }
        }
    }

    /**
     * Parses the prediction parameters into plain fields: selected, output and reserved column
     * names, output column types, the encode type, the invalid-value strategy and the dropLast flag.
     */
    public static class DiscreteParamsBuilder implements Serializable {
        private static final long serialVersionUID = 8218203038244120910L;
        public HasEncodeWithoutWoe.Encode encode;
        public HasHandleInvalid.HandleInvalid handleInvalidStrategy;
        public String[] selectedCols;
        public String[] resultCols;
        public TypeInformation<?>[] resultColTypes;
        public String[] reservedCols;
        public boolean dropLast;

        /**
         * Reads all settings from {@code params}.
         *
         * <p>Side effect: when OUTPUT_COLS is absent but a single OUTPUT_COL is present, OUTPUT_COLS
         * is written into {@code params} as a one-element array.
         *
         * <p>NOTE(review): when SELECTED_COLS is absent (only accepted for ASSEMBLED_VECTOR),
         * {@link #selectedCols} stays {@code null}; the {@code DiscreteMapperBuilder} constructor
         * dereferences it, so that configuration looks unusable as written - confirm.
         *
         * @param params      prediction parameters; may be modified as described above
         * @param tableSchema not read by this constructor
         * @param encode      output encoding to configure for
         * @throws AkIllegalOperatorParameterException if {@code encode} is not INDEX, VECTOR or
         *                                             ASSEMBLED_VECTOR
         */
        public DiscreteParamsBuilder(Params params, TableSchema tableSchema, HasEncodeWithoutWoe.Encode encode) {
            this.reservedCols = (String[]) params.get(QuantileDiscretizerPredictParams.RESERVED_COLS);
            this.handleInvalidStrategy = (HasHandleInvalid.HandleInvalid) params.get(QuantileDiscretizerPredictParams.HANDLE_INVALID);
            this.encode = encode;
            if (!params.contains(QuantileDiscretizerPredictParams.OUTPUT_COLS) && params.contains(HasOutputCol.OUTPUT_COL)) {
                // Promote the single output column to the array-valued parameter.
                params.set(QuantileDiscretizerPredictParams.OUTPUT_COLS, new String[] {(String) params.get(HasOutputCol.OUTPUT_COL)});
            }
            if (params.contains(QuantileDiscretizerPredictParams.SELECTED_COLS)) {
                this.selectedCols = (String[]) params.get(QuantileDiscretizerPredictParams.SELECTED_COLS);
            } else {
                AkPreconditions.checkArgument(encode.equals(HasEncodeWithoutWoe.Encode.ASSEMBLED_VECTOR), "Not given selectedCols, encode must be ASSEMBLED_VECTOR!");
            }
            this.resultCols = (String[]) params.get(QuantileDiscretizerPredictParams.OUTPUT_COLS);
            switch (encode) {
                case INDEX:
                    initPerColumnOutputs(Types.LONG);
                    break;
                case VECTOR:
                    initPerColumnOutputs(AlinkTypes.SPARSE_VECTOR);
                    break;
                case ASSEMBLED_VECTOR:
                    AkPreconditions.checkArgument(null != this.resultCols && this.resultCols.length == 1, "When encode is ASSEMBLED_VECTOR, outputCols must be given and the length must be 1!");
                    this.resultColTypes = new TypeInformation[this.resultCols.length];
                    Arrays.fill(this.resultColTypes, AlinkTypes.SPARSE_VECTOR);
                    break;
                default:
                    throw new AkIllegalOperatorParameterException("Not support encode: " + encode.name());
            }
            this.dropLast = ((Boolean) params.get(QuantileDiscretizerPredictParams.DROP_LAST)).booleanValue();
        }

        /**
         * Configures one output column per selected column, all of type {@code outputType}.
         * Output names default to the selected column names when none were given.
         */
        private void initPerColumnOutputs(TypeInformation<?> outputType) {
            if (null == this.resultCols) {
                this.resultCols = this.selectedCols;
            }
            AkPreconditions.checkArgument(this.resultCols.length == this.selectedCols.length, "Input column name is not match output column name.");
            this.resultColTypes = new TypeInformation[this.resultCols.length];
            Arrays.fill(this.resultColTypes, outputType);
        }
    }

    /**
     * Discretizer for floating-point columns: locates a value in the sorted {@code bounds} array
     * and translates the resulting slot through {@code boundIndex} into a bin index.
     */
    public static class DoubleNumericQuantileDiscretizer implements NumericQuantileDiscretizer {
        private static final long serialVersionUID = -1681225445245237307L;
        /** Sorted interval boundaries. */
        double[] bounds;
        /** {@code true}: a value equal to a boundary belongs to the interval below it. */
        boolean isLeftOpen;
        /** Slot -> bin index lookup. */
        int[] boundIndex;
        /** Index reported for null / missing values; it is the only index {@link #isValid} rejects. */
        int nullIndex;
        /** Forwarded to {@code Preprocessing.isMissing}. */
        boolean zeroAsMissing;

        /**
         * @param dArr sorted interval boundaries
         * @param z    whether intervals are left-open
         * @param iArr slot -> bin index lookup
         * @param i    index to report for null / missing values
         * @param z2   flag forwarded to {@code Preprocessing.isMissing}
         */
        public DoubleNumericQuantileDiscretizer(double[] dArr, boolean z, int[] iArr, int i, boolean z2) {
            this.bounds = dArr;
            this.isLeftOpen = z;
            this.boundIndex = iArr;
            this.nullIndex = i;
            this.zeroAsMissing = z2;
        }

        /** Returns {@code true} unless {@code i} is the null/missing index. */
        @Override
        public boolean isValid(int i) {
            return i != this.nullIndex;
        }

        /**
         * Returns the bin index for {@code obj}, or the null index when {@code obj} is null or
         * reported missing by {@code Preprocessing.isMissing}.
         *
         * @param obj a {@link Number}, or null
         */
        @Override
        public int findIndex(Object obj) {
            if (obj == null) {
                return this.nullIndex;
            }
            double value = ((Number) obj).doubleValue();
            if (Preprocessing.isMissing(value, this.zeroAsMissing)) {
                return this.nullIndex;
            }
            int pos = Arrays.binarySearch(this.bounds, value);
            int slot;
            if (pos < 0) {
                // Not found: binarySearch returned -(insertionPoint) - 1; the interval is the one
                // just below the insertion point.
                slot = -pos - 2;
            } else {
                // Exact boundary hit: left-open intervals own their upper boundary.
                slot = this.isLeftOpen ? pos - 1 : pos;
            }
            // A value at (or below) bounds[0] would otherwise yield slot -1 and index out of the
            // lookup array; such a value belongs to the first interval.
            return this.boundIndex[Math.max(slot, 0)];
        }
    }

    /**
     * Discretizer for integral columns: locates a value in the sorted {@code bounds} array and
     * translates the resulting slot through {@code boundIndex} into a bin index.
     */
    public static class LongQuantileDiscretizer implements NumericQuantileDiscretizer {
        private static final long serialVersionUID = 8869074090757935247L;
        /** Sorted interval boundaries. */
        long[] bounds;
        /** {@code true}: a value equal to a boundary belongs to the interval below it. */
        boolean isLeftOpen;
        /** Slot -> bin index lookup. */
        int[] boundIndex;
        /** Index reported for null / missing values; it is the only index {@link #isValid} rejects. */
        int nullIndex;
        /** Forwarded to {@code Preprocessing.isMissing}. */
        boolean zeroAsMissing;

        /**
         * @param jArr sorted interval boundaries
         * @param z    whether intervals are left-open
         * @param iArr slot -> bin index lookup
         * @param i    index to report for null / missing values
         * @param z2   flag forwarded to {@code Preprocessing.isMissing}
         */
        public LongQuantileDiscretizer(long[] jArr, boolean z, int[] iArr, int i, boolean z2) {
            this.bounds = jArr;
            this.isLeftOpen = z;
            this.boundIndex = iArr;
            this.nullIndex = i;
            this.zeroAsMissing = z2;
        }

        /** Returns {@code true} unless {@code i} is the null/missing index. */
        @Override
        public boolean isValid(int i) {
            return i != this.nullIndex;
        }

        /**
         * Returns the bin index for {@code obj}, or the null index when {@code obj} is null or
         * reported missing by {@code Preprocessing.isMissing}.
         *
         * @param obj a {@link Number} (converted with {@code longValue()}), or null
         */
        @Override
        public int findIndex(Object obj) {
            if (obj == null) {
                return this.nullIndex;
            }
            long value = ((Number) obj).longValue();
            if (Preprocessing.isMissing(value, this.zeroAsMissing)) {
                return this.nullIndex;
            }
            int pos = Arrays.binarySearch(this.bounds, value);
            int slot;
            if (pos < 0) {
                // Not found: binarySearch returned -(insertionPoint) - 1; the interval is the one
                // just below the insertion point.
                slot = -pos - 2;
            } else {
                // Exact boundary hit: left-open intervals own their upper boundary.
                slot = this.isLeftOpen ? pos - 1 : pos;
            }
            // A value at or below bounds[0] (e.g. Long.MIN_VALUE when the lowest bound is
            // -Long.MAX_VALUE) would otherwise yield slot -1 and index out of the lookup array;
            // such a value belongs to the first interval.
            return this.boundIndex[Math.max(slot, 0)];
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/common/feature/QuantileDiscretizerModelMapper$NumericQuantileDiscretizer.class */
    /**
     * Maps a raw column value to a bin index and tells whether an index denotes a regular bin.
     * Both implementations in this file reserve one index (their {@code nullIndex}) for
     * null/missing values and treat every other index as valid.
     */
    public interface NumericQuantileDiscretizer extends Serializable {
        /**
         * @param i an index previously returned by {@link #findIndex(Object)}
         * @return whether {@code i} is a valid index; the implementations in this file return
         *         {@code false} only for their null/missing index
         */
        boolean isValid(int i);

        /**
         * @param obj the value to discretize; the implementations in this file accept null and
         *            otherwise cast it to {@link Number}
         * @return the index assigned to {@code obj}
         */
        int findIndex(Object obj);
    }

    /**
     * Creates the mapper and remembers the selected columns requested for prediction, or
     * {@code null} when the parameters do not contain SELECTED_COLS.
     *
     * @param tableSchema  first schema, forwarded to the super constructor
     * @param tableSchema2 second schema, forwarded to the super constructor
     * @param params       prediction parameters
     */
    public QuantileDiscretizerModelMapper(TableSchema tableSchema, TableSchema tableSchema2, Params params) {
        super(tableSchema, tableSchema2, params);
        this.inputPredictColNames = params.contains(QuantileDiscretizerPredictParams.SELECTED_COLS)
            ? (String[]) params.get(QuantileDiscretizerPredictParams.SELECTED_COLS)
            : null;
    }

    /**
     * Loads the trained model and prepares the per-column discretizers.
     *
     * <p>Steps: verify that every column requested for prediction is among the model's selected
     * columns, build a fresh {@link DiscreteMapperBuilder}, register vector size / dropped index /
     * discretizer for each selected column, then finalize the assembled vector size and open the
     * builder.
     *
     * @param list serialized model rows
     * @throws AkIllegalArgumentException if a requested column is not among the model's columns
     * @throws AkUnsupportedOperationException if the invalid-value strategy is not KEEP, SKIP or ERROR
     */
    @Override
    public void loadModel(List<Row> list) {
        QuantileDiscretizerModelDataConverter converter = new QuantileDiscretizerModelDataConverter();
        converter.load(list);
        String[] trainedCols = (String[]) converter.meta.get(HasSelectedCols.SELECTED_COLS);
        if (this.inputPredictColNames != null) {
            HashSet<String> trained = new HashSet<>(Arrays.asList(trainedCols));
            for (String requested : this.inputPredictColNames) {
                if (trained.contains(requested)) {
                    continue;
                }
                throw new AkIllegalArgumentException("Column '" + requested + "' has not been precessed in QuantileDiscretizer model training.");
            }
        }

        this.mapperBuilder = new DiscreteMapperBuilder(this.params, getDataSchema());
        DiscreteParamsBuilder config = this.mapperBuilder.paramsBuilder;
        for (int col = 0; col < config.selectedCols.length; col++) {
            String colName = config.selectedCols[col];
            ContinuousRanges ranges = converter.data.get(colName);
            AkPreconditions.checkNotNull(ranges, "%s not found in model", colName);

            int intervalCount = ranges.getIntervalNum();
            long lastInterval = intervalCount - 1;
            long allIntervals = intervalCount;

            long slots;
            switch (config.handleInvalidStrategy) {
                case KEEP:
                    // One slot more than the interval count.
                    slots = allIntervals + 1;
                    break;
                case SKIP:
                case ERROR:
                    slots = lastInterval + 1;
                    break;
                default:
                    throw new AkUnsupportedOperationException("Unsupported now.");
            }
            this.mapperBuilder.vectorSize.put(col, slots);

            if (config.dropLast) {
                this.mapperBuilder.dropIndex.put(col, lastInterval);
            }
            this.mapperBuilder.discretizers[col] = createQuantileDiscretizer(ranges, converter.meta);
        }
        this.mapperBuilder.setAssembledVectorSize();
        this.mapperBuilder.open();
    }

    /**
     * Derives the I/O column layout from the parameters.
     *
     * <p>The schema given to the builder is {@code getDataSchema()}; the two schema arguments are
     * not read here.
     *
     * @return (selected columns, output columns, output column types, reserved columns)
     */
    @Override
    protected Tuple4<String[], String[], TypeInformation<?>[], String[]> prepareIoSchema(TableSchema tableSchema, TableSchema tableSchema2, Params params) {
        DiscreteParamsBuilder io = new DiscreteMapperBuilder(params, getDataSchema()).paramsBuilder;
        return Tuple4.of(io.selectedCols, io.resultCols, io.resultColTypes, io.reservedCols);
    }

    /**
     * Encodes one sample by delegating to the {@link DiscreteMapperBuilder} created in
     * {@code loadModel}.
     *
     * @param slicedSelectedSample selected input values
     * @param slicedResult         receiver of the encoded output fields
     */
    @Override
    public void map(Mapper.SlicedSelectedSample slicedSelectedSample, Mapper.SlicedResult slicedResult) throws Exception {
        mapperBuilder.map(slicedSelectedSample, slicedResult);
    }

    /**
     * Builds the discretizer for one column from its trained split points.
     *
     * <p>The boundary array is the split points framed by a lower and an upper sentinel
     * (negative/positive infinity for floating-point ranges, {@code -Long.MAX_VALUE} /
     * {@code Long.MAX_VALUE} otherwise). The slot lookup is the identity except that the slot of
     * the upper sentinel is redirected to the last real interval.
     *
     * @param continuousRanges trained ranges of the column
     * @param params           parameters providing {@code Preprocessing.ZERO_AS_MISSING}
     * @return a double- or long-based discretizer, depending on {@code continuousRanges.isFloat()}
     */
    public static NumericQuantileDiscretizer createQuantileDiscretizer(ContinuousRanges continuousRanges, Params params) {
        final int binCount = continuousRanges.splitsArray.length + 1;
        final boolean leftOpen = continuousRanges.getLeftOpen().booleanValue();
        final int nullIndex = continuousRanges.getIntervalNum();

        int[] slotToBin = new int[binCount + 2];
        for (int slot = 0; slot < slotToBin.length; slot++) {
            slotToBin[slot] = slot;
        }
        // An exact hit on the upper sentinel lands in the last real interval.
        slotToBin[binCount] = binCount - 1;

        if (!continuousRanges.isFloat()) {
            long[] longBounds = new long[binCount + 1];
            longBounds[0] = -Long.MAX_VALUE;
            for (int k = 1; k < binCount; k++) {
                longBounds[k] = continuousRanges.splitsArray[k - 1].longValue();
            }
            longBounds[binCount] = Long.MAX_VALUE;
            boolean zeroAsMissing = ((Boolean) params.get(Preprocessing.ZERO_AS_MISSING)).booleanValue();
            return new LongQuantileDiscretizer(longBounds, leftOpen, slotToBin, nullIndex, zeroAsMissing);
        }

        double[] doubleBounds = new double[binCount + 1];
        doubleBounds[0] = Double.NEGATIVE_INFINITY;
        for (int k = 1; k < binCount; k++) {
            doubleBounds[k] = continuousRanges.splitsArray[k - 1].doubleValue();
        }
        doubleBounds[binCount] = Double.POSITIVE_INFINITY;
        boolean zeroAsMissing = ((Boolean) params.get(Preprocessing.ZERO_AS_MISSING)).booleanValue();
        return new DoubleNumericQuantileDiscretizer(doubleBounds, leftOpen, slotToBin, nullIndex, zeroAsMissing);
    }

    /**
     * Allocates a result row of the right arity for {@code encode} and fills it through the
     * eight-argument overload.
     *
     * @param lArr   bin index per selected column (an entry may be null)
     * @param encode output encoding; decides the row arity (one field per column for INDEX and
     *               VECTOR, a single field for ASSEMBLED_VECTOR)
     * @param map    dropped index per column position
     * @param map2   vector size per column position
     * @param z      whether dropLast is enabled
     * @param i      size of the assembled vector
     * @return the filled row
     * @throws AkUnsupportedOperationException for any other encode value
     */
    public static Row setResultRow(Long[] lArr, HasEncodeWithoutWoe.Encode encode, Map<Integer, Long> map, Map<Integer, Long> map2, boolean z, int i) {
        final Row result;
        final int[] positions;
        switch (encode) {
            case INDEX:
            case VECTOR:
                result = new Row(lArr.length);
                // Column k is written to field k.
                positions = IntStream.range(0, lArr.length).toArray();
                break;
            case ASSEMBLED_VECTOR:
                result = new Row(1);
                positions = new int[] {0};
                break;
            default:
                throw new AkUnsupportedOperationException("Not support encode type!");
        }
        setResultRow(lArr, encode, map, map2, z, i, result, positions);
        return result;
    }

    /**
     * Writes the encoded form of the bin indices into {@code row}.
     *
     * <ul>
     *   <li>INDEX: field {@code iArr[k]} receives {@code lArr[k]} unchanged.</li>
     *   <li>VECTOR: field {@code iArr[k]} receives a one-hot {@link SparseVector} for column
     *       {@code k}, or null when {@code lArr[k]} is null.</li>
     *   <li>ASSEMBLED_VECTOR: field {@code iArr[0]} receives the concatenation of all per-column
     *       one-hot vectors, or null as soon as any {@code lArr[k]} is null.</li>
     * </ul>
     *
     * @param lArr   bin index per selected column; a null entry marks a skipped value
     * @param encode output encoding
     * @param map    dropped index per column position (consulted only when {@code z} is true)
     * @param map2   vector size per column position
     * @param z      whether dropLast is enabled
     * @param i      size of the assembled vector
     * @param row    row to write into
     * @param iArr   target field positions in {@code row}
     * @throws AkUnsupportedOperationException for any other encode value
     */
    public static void setResultRow(Long[] lArr, HasEncodeWithoutWoe.Encode encode, Map<Integer, Long> map, Map<Integer, Long> map2, boolean z, int i, Row row, int[] iArr) {
        switch (encode) {
            case INDEX:
                for (int col = 0; col < iArr.length; col++) {
                    row.setField(iArr[col], lArr[col]);
                }
                return;
            case VECTOR:
                for (int col = 0; col < iArr.length; col++) {
                    if (lArr[col] == null) {
                        row.setField(iArr[col], null);
                        continue;
                    }
                    Tuple2<Integer, Integer> sizeAndIndex = getVectorSizeAndIndex(lArr[col], map.get(col), map2.get(col), z);
                    int size = sizeAndIndex.f0.intValue();
                    SparseVector oneHot = sizeAndIndex.f1 == null
                        ? new SparseVector(size)
                        : new SparseVector(size, new int[] {sizeAndIndex.f1.intValue()}, new double[] {1.0d});
                    row.setField(iArr[col], oneHot);
                }
                return;
            case ASSEMBLED_VECTOR:
                List<Integer> active = new ArrayList<>(lArr.length);
                int offset = 0;
                for (int col = 0; col < lArr.length; col++) {
                    if (lArr[col] == null) {
                        // A skipped value cannot be encoded. There is only one output field
                        // (iArr[0]), so emit null for it - as VECTOR does per column - instead of
                        // indexing iArr by column and dereferencing the null index.
                        row.setField(iArr[0], null);
                        return;
                    }
                    Tuple2<Integer, Integer> sizeAndIndex = getVectorSizeAndIndex(lArr[col], map.get(col), map2.get(col), z);
                    if (sizeAndIndex.f1 != null) {
                        active.add(offset + sizeAndIndex.f1.intValue());
                    }
                    // Advance past this column's segment even when nothing is set in it.
                    offset += sizeAndIndex.f0.intValue();
                }
                int[] indices = new int[active.size()];
                for (int k = 0; k < indices.length; k++) {
                    indices[k] = active.get(k).intValue();
                }
                double[] values = new double[indices.length];
                Arrays.fill(values, 1.0d);
                row.setField(iArr[0], new SparseVector(i, indices, values));
                return;
            default:
                throw new AkUnsupportedOperationException("Not support encode type!");
        }
    }

    /**
     * Computes the one-hot vector size and the active position for a single column.
     *
     * <p>Without dropLast the result is simply (vector size, index). With dropLast the vector is
     * one element shorter: the dropped index has no active position (null), and indices above the
     * dropped one shift down by one.
     *
     * @param l  bin index of the value; must not be null
     * @param l2 dropped index of the column; must not be null when {@code z} is true
     * @param l3 vector size of the column; must not be null
     * @param z  whether dropLast is enabled
     * @return (effective vector size, active position or null)
     */
    private static Tuple2<Integer, Integer> getVectorSizeAndIndex(Long l, Long l2, Long l3, boolean z) {
        if (!z) {
            return Tuple2.of(l3.intValue(), l.intValue());
        }
        int reducedSize = l3.intValue() - 1;
        if (l.equals(l2)) {
            // Typed null keeps the inferred tuple type at Tuple2<Integer, Integer>.
            return Tuple2.of(reducedSize, (Integer) null);
        }
        int position = l.longValue() > l2.longValue() ? l.intValue() - 1 : l.intValue();
        return Tuple2.of(reducedSize, position);
    }
}
