package com.alibaba.alink.operator.batch.regression;

import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamCond;
import com.alibaba.alink.common.annotation.ParamMutexRule;
import com.alibaba.alink.common.annotation.ParamMutexRules;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpecs;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.TypeCollections;
import com.alibaba.alink.common.exceptions.AkIllegalDataException;
import com.alibaba.alink.common.exceptions.AkIllegalOperatorParameterException;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.common.regression.IsotonicRegressionConverter;
import com.alibaba.alink.operator.common.regression.IsotonicRegressionModelData;
import com.alibaba.alink.operator.common.regression.isotonicReg.LinkedData;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.params.regression.IsotonicRegTrainParams;
import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation;
import com.google.common.collect.Lists;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Iterator;
import org.apache.flink.api.common.functions.FilterFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/**
 * Batch operator that trains an isotonic regression model with the pool-adjacent-violators (PAV) algorithm.
 *
 * <p>The single feature is read either from {@code featureCol} or from element {@code featureIndex} of
 * {@code vectorCol}; setting both, or neither, is rejected. An optional {@code weightCol} supplies per-row
 * weights (1.0 when absent); rows whose weight is not above {@code Criteria.INVALID_GAIN} are dropped and
 * weights below it are rejected. When {@code isotonic} is false the labels are negated before pooling and
 * the fitted values are negated back when the model is built, producing an antitonic fit.
 *
 * <p>Training is distributed: rows are range-partitioned on the feature, PAV is run inside each partition,
 * and the partial results are concatenated in feature order and pooled once more in a single task that
 * emits the model (boundaries and fitted values).
 */
@OutputPorts(values = {@PortSpec(PortType.MODEL)})
@InputPorts(values = {@PortSpec(PortType.DATA)})
@ParamSelectColumnSpecs({@ParamSelectColumnSpec(name = "labelCol", allowedTypeCollections = {TypeCollections.NUMERIC_TYPES}), @ParamSelectColumnSpec(name = "weightCol", allowedTypeCollections = {TypeCollections.NUMERIC_TYPES}), @ParamSelectColumnSpec(name = "featureCol", allowedTypeCollections = {TypeCollections.NUMERIC_TYPES}), @ParamSelectColumnSpec(name = "vectorCol", allowedTypeCollections = {TypeCollections.VECTOR_TYPES})})
@NameCn("保序回归训练")
@ParamMutexRules({@ParamMutexRule(name = "vectorCol", type = ParamMutexRule.ActionType.DISABLE, cond = @ParamCond(name = "featureCol", type = ParamCond.CondType.WHEN_NOT_NULL)), @ParamMutexRule(name = "featureCol", type = ParamMutexRule.ActionType.DISABLE, cond = @ParamCond(name = "vectorCol", type = ParamCond.CondType.WHEN_NOT_NULL))})
@NameEn("Isotonic Regression Training")
@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.regression.IsotonicRegression")
public final class IsotonicRegTrainBatchOp extends BatchOperator<IsotonicRegTrainBatchOp> implements IsotonicRegTrainParams<IsotonicRegTrainBatchOp> {
    private static final long serialVersionUID = -1681187909098085588L;

    /** Bytes per pooled-block record in a LinkedData byte array, as used with {@code LinkedData.compact()}. */
    private static final int BLOCK_RECORD_BYTES = 24;

    /**
     * Byte offset of the first record's start-feature value inside a partial result. Presumably this sits
     * right after the record's leading 4-byte float label sum -- confirm against LinkedData's layout.
     */
    private static final int FIRST_START_FEATURE_OFFSET = 4;

    /**
     * Final single-task step: merges the per-partition PAV results, pools them once more across partition
     * borders, and serializes the resulting step function as an {@link IsotonicRegressionModelData}.
     */
    public static class BuildModel implements MapPartitionFunction<byte[], Row> {
        private static final long serialVersionUID = -8409030153684329440L;
        private final boolean isotonic;
        private final String featureColName;
        private final String vectorColName;
        private final int index;

        /**
         * @param isotonic       false if labels were negated for an antitonic fit (values are negated back)
         * @param featureColName feature column name stored in the model meta (may be null)
         * @param vectorColName  vector column name stored in the model meta (may be null)
         * @param index          vector element index stored in the model meta
         */
        BuildModel(boolean isotonic, String featureColName, String vectorColName, int index) {
            this.isotonic = isotonic;
            this.featureColName = featureColName;
            this.vectorColName = vectorColName;
            this.index = index;
        }

        @Override
        public void mapPartition(Iterable<byte[]> partials, Collector<Row> collector) {
            // Blocks from different partitions may still violate monotonicity at partition borders,
            // so the concatenated result goes through one more PAV pass.
            LinkedData blocks = new LinkedData(
                IsotonicRegTrainBatchOp.updateLinkedData(
                    new LinkedData(IsotonicRegTrainBatchOp.summarizeModelData(partials))));

            ArrayList<Double> boundaries = new ArrayList<>();
            ArrayList<Double> values = new ArrayList<>();
            while (blocks.hasNext()) {
                Tuple4<Float, Double, Double, Float> block = blocks.getData();
                float labelSum = block.f0;
                double start = block.f1;
                double end = block.f2;
                float weightSum = block.f3;
                // Mean is computed in float precision, matching the float accumulators; the sign undoes
                // the label negation applied for antitonic training.
                double value = isotonic ? labelSum / weightSum : (-labelSum) / weightSum;

                // Each block contributes its start boundary, plus its end boundary when it spans a range.
                boundaries.add(start);
                values.add(value);
                if (start != end) {
                    boundaries.add(end);
                    values.add(value);
                }
                blocks.advance();
            }

            IsotonicRegressionModelData modelData = new IsotonicRegressionModelData();
            modelData.boundaries = boundaries.toArray(new Double[0]);
            modelData.values = values.toArray(new Double[0]);
            modelData.meta.set(IsotonicRegTrainParams.FEATURE_COL, featureColName);
            modelData.meta.set(IsotonicRegTrainParams.VECTOR_COL, vectorColName);
            modelData.meta.set(IsotonicRegTrainParams.FEATURE_INDEX, Integer.valueOf(index));
            new IsotonicRegressionConverter().save(modelData, collector);
        }
    }

    /**
     * Per-partition step: sorts the partition's (label, feature, weight) samples by feature, runs PAV, and
     * emits the pooled blocks as a byte array (nothing is emitted for an empty result).
     */
    public static class PoolAdjacentViolators implements MapPartitionFunction<Tuple3<Double, Double, Double>, byte[]> {
        private static final long serialVersionUID = -212769047923494155L;

        @Override
        public void mapPartition(Iterable<Tuple3<Double, Double, Double>> samples, Collector<byte[]> collector) {
            if (samples == null) {
                return;
            }
            byte[] pooled = IsotonicRegTrainBatchOp.updateLinkedData(IsotonicRegTrainBatchOp.initLinkedData(samples));
            if (pooled.length > 0) {
                collector.collect(pooled);
            }
        }
    }

    /**
     * Converts a selected input row {@code (label, feature-or-vector[, weight])} into a
     * {@code (label, feature, weight)} tuple, negating the label for antitonic fits.
     */
    private static final class ParseSample implements MapFunction<Row, Tuple3<Double, Double, Double>> {
        private static final long serialVersionUID = 9034460902571223806L;
        private final boolean isotonic;
        private final boolean fromVector;
        private final boolean hasWeight;
        private final int featureIndex;

        ParseSample(boolean isotonic, boolean fromVector, boolean hasWeight, int featureIndex) {
            this.isotonic = isotonic;
            this.fromVector = fromVector;
            this.hasWeight = hasWeight;
            this.featureIndex = featureIndex;
        }

        @Override
        public Tuple3<Double, Double, Double> map(Row row) {
            Object labelField = row.getField(0);
            Object featureField = row.getField(1);
            if (labelField == null || featureField == null) {
                throw new AkIllegalDataException("Label and feature must not be null!");
            }
            double label = ((Number) labelField).doubleValue();
            double feature = fromVector
                ? VectorUtil.getVector(featureField).get(featureIndex)
                : ((Number) featureField).doubleValue();

            double weight = 1.0d;
            if (hasWeight) {
                Object weightField = row.getField(2);
                if (weightField == null) {
                    throw new AkIllegalDataException("Weight must not be null!");
                }
                weight = ((Number) weightField).doubleValue();
            }
            // NOTE(review): Criteria.INVALID_GAIN is presumably 0.0 here (the message speaks of
            // non-negative weights); the reference may be a decompiler constant substitution -- confirm.
            if (weight < Criteria.INVALID_GAIN) {
                throw new AkIllegalDataException("Weights must be non-negative!");
            }
            return Tuple3.of(isotonic ? label : -label, feature, weight);
        }
    }

    /** Keeps only samples whose weight is strictly above {@code Criteria.INVALID_GAIN}. */
    private static final class PositiveWeightFilter implements FilterFunction<Tuple3<Double, Double, Double>> {
        private static final long serialVersionUID = 234777408534250527L;

        @Override
        public boolean filter(Tuple3<Double, Double, Double> sample) {
            return sample.f2 > Criteria.INVALID_GAIN;
        }
    }

    public IsotonicRegTrainBatchOp() {
        super(new Params());
    }

    public IsotonicRegTrainBatchOp(Params params) {
        super(params);
    }

    /**
     * Builds the training pipeline on the first input.
     *
     * @throws AkIllegalOperatorParameterException if both or neither of featureCol and vectorCol are set
     */
    @Override
    public IsotonicRegTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
        BatchOperator<?> in = checkAndGetFirst(inputs);
        String labelCol = getLabelCol();
        String featureCol = getFeatureCol();
        String weightCol = getWeightCol();
        String vectorCol = getVectorCol();
        boolean isotonic = getIsotonic().booleanValue();
        int featureIndex = getFeatureIndex().intValue();

        if (featureCol != null && vectorCol != null) {
            throw new AkIllegalOperatorParameterException("featureCols and vectorCol cannot be set at the same time.");
        }
        if (featureCol == null && vectorCol == null) {
            throw new AkIllegalOperatorParameterException("Either featureColName or vectorColName is required!");
        }
        String inputCol = featureCol != null ? featureCol : vectorCol;
        String[] selectedCols = weightCol == null
            ? new String[] {labelCol, inputCol}
            : new String[] {labelCol, inputCol, weightCol};

        // Range-partition on the feature (tuple field 1) so each partition holds a contiguous feature range,
        // pool locally, then merge all partial results in a single task.
        DataSet<Row> model = in.select(selectedCols)
            .getDataSet()
            .map(new ParseSample(isotonic, vectorCol != null, weightCol != null, featureIndex))
            .filter(new PositiveWeightFilter())
            .rebalance()
            .partitionByRange(1)
            .mapPartition(new PoolAdjacentViolators())
            .mapPartition(new BuildModel(isotonic, featureCol, vectorCol, featureIndex))
            .setParallelism(1);
        setOutput(model, new IsotonicRegressionConverter().getModelSchema());
        return this;
    }

    /**
     * Sorts the samples by feature value (tuple field 1) and wraps them in a {@link LinkedData}.
     *
     * @param samples (label, feature, weight) tuples
     * @return a LinkedData built from the feature-sorted samples
     */
    public static LinkedData initLinkedData(Iterable<Tuple3<Double, Double, Double>> samples) {
        ArrayList<Tuple3<Double, Double, Double>> sorted = Lists.newArrayList(samples);
        sorted.sort(Comparator.comparing(sample -> sample.f1));
        return new LinkedData(sorted);
    }

    /**
     * Runs pool-adjacent-violators over the blocks of {@code linkedData}.
     *
     * <p>Whenever the previous block's mean (label sum / weight sum, in float precision) is greater than or
     * equal to the current block's mean, the current block is removed and the pooled block (summed labels,
     * previous start, current end, summed weights) is written at the position the cursor retreats to. The
     * cursor then steps back one more block, if one exists, so the pooled block is re-checked against its
     * predecessor. The exact cursor semantics live in LinkedData.
     *
     * @param linkedData the blocks to pool; may be null
     * @return the pooled blocks, truncated to {@code compact() * 24} bytes; an empty array for null or
     *     empty input
     */
    public static byte[] updateLinkedData(LinkedData linkedData) {
        if (linkedData == null || linkedData.getByteArray() == null || linkedData.getByteArray().length == 0) {
            return new byte[0];
        }
        Tuple4<Float, Double, Double, Float> first = linkedData.getData();
        float prevLabelSum = first.f0;
        double prevStart = first.f1;
        float prevWeightSum = first.f3;

        while (linkedData.hasNext()) {
            linkedData.advance();
            Tuple4<Float, Double, Double, Float> current = linkedData.getData();
            float curLabelSum = current.f0;
            double curStart = current.f1;
            double curEnd = current.f2;
            float curWeightSum = current.f3;

            if (prevLabelSum / prevWeightSum >= curLabelSum / curWeightSum) {
                // Violation: pool the current block into the previous one.
                linkedData.removeCurrentAndRetreat();
                linkedData.putData(curLabelSum + prevLabelSum, prevStart, curEnd, prevWeightSum + curWeightSum);
                if (linkedData.hasPrevious()) {
                    linkedData.retreat();
                }
                Tuple4<Float, Double, Double, Float> anchor = linkedData.getData();
                prevLabelSum = anchor.f0;
                prevStart = anchor.f1;
                prevWeightSum = anchor.f3;
            } else {
                prevLabelSum = curLabelSum;
                prevStart = curStart;
                prevWeightSum = curWeightSum;
            }
        }
        // getByteArray() is deliberately evaluated before compact(), as in the original expression.
        return Arrays.copyOfRange(linkedData.getByteArray(), 0, linkedData.compact() * BLOCK_RECORD_BYTES);
    }

    /**
     * Concatenates the per-partition PAV results into one byte array. Partitions are ordered by the start
     * feature value of their first record; empty partial arrays are skipped.
     *
     * @param partials per-partition pooled block arrays
     * @return the concatenated bytes (empty when there is nothing to merge)
     */
    public static byte[] summarizeModelData(Iterable<byte[]> partials) {
        ArrayList<ByteBuffer> buffers = new ArrayList<>();
        int totalBytes = 0;
        for (byte[] partial : partials) {
            if (partial.length == 0) {
                // An empty array has no first record to sort by and contributes no blocks.
                continue;
            }
            totalBytes += partial.length;
            buffers.add(ByteBuffer.wrap(partial));
        }
        // Absolute get: does not move the buffer position, so put() below still copies every byte.
        buffers.sort(Comparator.comparingDouble(buffer -> buffer.getDouble(FIRST_START_FEATURE_OFFSET)));

        ByteBuffer merged = ByteBuffer.allocate(totalBytes);
        for (ByteBuffer buffer : buffers) {
            merged.put(buffer);
        }
        return merged.array();
    }
}
