package com.alibaba.alink.operator.batch.regression;

import com.alibaba.alink.common.MLEnvironmentFactory;
import com.alibaba.alink.common.annotation.FeatureColsVectorColMutexRule;
import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpecs;
import com.alibaba.alink.common.annotation.PortDesc;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.TypeCollections;
import com.alibaba.alink.common.exceptions.AkIllegalArgumentException;
import com.alibaba.alink.common.exceptions.AkIllegalModelException;
import com.alibaba.alink.common.exceptions.AkIllegalOperatorParameterException;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp;
import com.alibaba.alink.operator.batch.utils.WithTrainInfo;
import com.alibaba.alink.operator.common.linear.BaseLinearModelTrainBatchOp;
import com.alibaba.alink.operator.common.linear.LinearModelData;
import com.alibaba.alink.operator.common.linear.LinearModelDataConverter;
import com.alibaba.alink.operator.common.linear.LinearModelTrainInfo;
import com.alibaba.alink.operator.common.linear.LinearModelType;
import com.alibaba.alink.operator.common.linear.LinearRegressorModelInfo;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.params.regression.AftRegTrainParams;
import com.alibaba.alink.params.shared.colname.HasFeatureCols;
import com.alibaba.alink.params.shared.colname.HasVectorCol;
import com.alibaba.alink.params.shared.linear.HasWithIntercept;
import com.alibaba.alink.params.shared.linear.LinearTrainParams;
import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.flink.api.common.functions.AbstractRichFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.operators.DataSource;
import org.apache.flink.api.java.operators.MapOperator;
import org.apache.flink.api.java.operators.Operator;
import org.apache.flink.api.java.operators.SingleInputUdfOperator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/**
 * Batch training operator for AFT (accelerated failure time) survival regression.
 *
 * <p>The censor column is routed through the shared linear-model pipeline of
 * {@link BaseLinearModelTrainBatchOp} by registering it as the weight column. Each sample's label becomes
 * the natural log of its survival time. Features are divided by their standard deviations before
 * optimization, and the learned coefficients are divided by the same standard deviations before the
 * model rows are built.
 */
@OutputPorts(values = {@PortSpec(PortType.MODEL), @PortSpec(value = PortType.DATA, desc = PortDesc.MODEL_INFO), @PortSpec(value = PortType.DATA, desc = PortDesc.FEATURE_IMPORTANCE), @PortSpec(value = PortType.DATA, desc = PortDesc.MODEL_WEIGHT)})
@InputPorts(values = {@PortSpec(PortType.DATA)})
@FeatureColsVectorColMutexRule
@ParamSelectColumnSpecs({@ParamSelectColumnSpec(name = "featureCols", allowedTypeCollections = {TypeCollections.NUMERIC_TYPES}), @ParamSelectColumnSpec(name = "vectorCol", allowedTypeCollections = {TypeCollections.VECTOR_TYPES}), @ParamSelectColumnSpec(name = "labelCol"), @ParamSelectColumnSpec(name = "censorCol", allowedTypeCollections = {TypeCollections.NUMERIC_TYPES})})
@NameCn("生存回归训练")
@NameEn("Aft Survival Regression Training")
@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.regression.AftSurvivalRegression")
public class AftSurvivalRegTrainBatchOp extends BatchOperator<AftSurvivalRegTrainBatchOp> implements AftRegTrainParams<AftSurvivalRegTrainBatchOp>, WithTrainInfo<LinearModelTrainInfo, AftSurvivalRegTrainBatchOp>, WithModelInfoBatchOp<LinearRegressorModelInfo, AftSurvivalRegTrainBatchOp, AftSurvivalRegModelInfoBatchOp> {
    private static final long serialVersionUID = -7789949822832208166L;

    /**
     * Rescales optimized coefficients using the standard deviations broadcast under {@code "std"}.
     *
     * <p>Every coefficient except the last is divided by the matching standard deviation. A coefficient
     * whose standard deviation is not positive is set to 0. The last coefficient is copied through
     * unchanged (presumably the AFT scale term — TODO confirm). Field {@code f1} is passed through as-is.
     */
    public static class FormatCoef extends AbstractRichFunction implements MapFunction<Tuple2<DenseVector, double[]>, Tuple2<DenseVector, double[]>> {
        private static final long serialVersionUID = -5330623238434515619L;
        private double[] std;

        @Override
        public void open(Configuration configuration) throws Exception {
            this.std = (double[]) getRuntimeContext().getBroadcastVariable("std").get(0);
        }

        @Override
        public Tuple2<DenseVector, double[]> map(Tuple2<DenseVector, double[]> tuple2) throws Exception {
            // Write into a copy so the incoming coefficient vector is left untouched.
            DenseVector rescaled = tuple2.f0.mo136clone();
            double[] target = rescaled.getData();
            double[] source = tuple2.f0.getData();
            int numScaled = source.length - 1;
            for (int i = 0; i < numScaled; i++) {
                target[i] = this.std[i] > 0.0 ? source[i] / this.std[i] : 0.0;
            }
            return Tuple2.of(rescaled, tuple2.f1);
        }
    }

    /**
     * Converts rows produced by {@link BaseLinearModelTrainBatchOp#transform} into training samples.
     *
     * <p>Input field {@code f0} carries the censor value (it holds the configured weight column, which
     * {@code linkFrom} sets to the censor column). {@code f1} is the survival time, and {@code f2} is the
     * feature vector. Rows whose {@code f0} is within 1e-4 of -1 or -2 are dropped (presumably marker
     * rows emitted by {@code transform} — TODO confirm). Because those comparisons are false for NaN, a
     * NaN {@code f0} is dropped as well. Each remaining row is emitted as
     * {@code (censor, log(time), features / std)}. Note that the feature vector is scaled in place.
     */
    public static class FormatLabeledVector extends AbstractRichFunction implements MapPartitionFunction<Tuple3<Double, Object, Vector>, Tuple3<Double, Double, Vector>> {
        private static final long serialVersionUID = -1207608955281033320L;
        private double[] std;

        @Override
        public void open(Configuration configuration) throws Exception {
            this.std = (double[]) getRuntimeContext().getBroadcastVariable("std").get(0);
        }

        /**
         * @throws AkIllegalArgumentException if a survival time is not positive, or a censor value is
         *                                    neither 0.0 nor 1.0
         */
        @Override
        public void mapPartition(Iterable<Tuple3<Double, Object, Vector>> iterable, Collector<Tuple3<Double, Double, Vector>> collector) throws Exception {
            for (Tuple3<Double, Object, Vector> row : iterable) {
                Double censor = row.f0;
                double censorValue = censor;
                boolean isDataRow = Math.abs(censorValue + 1.0d) >= 1.0E-4d && Math.abs(censorValue + 2.0d) >= 1.0E-4d;
                if (!isDataRow) {
                    continue;
                }
                Vector features = row.f2;
                double time = ((Double) row.f1).doubleValue();
                if (time <= 0.0) {
                    throw new AkIllegalArgumentException("Survival Time must be greater than 0!");
                }
                // Primitive comparison so that -0.0 is accepted as 0.0 (Double.equals would reject it).
                if (censorValue != 0.0 && censorValue != 1.0) {
                    throw new AkIllegalArgumentException("Censor must be 1.0 or 0.0!");
                }
                Vector scaled = features instanceof SparseVector
                    ? AftSurvivalRegTrainBatchOp.svStd(features, this.std)
                    : AftSurvivalRegTrainBatchOp.dvStd(features, this.std);
                collector.collect(Tuple3.of(censor, Math.log(time), scaled));
            }
        }
    }

    public AftSurvivalRegTrainBatchOp() {
        this(null);
    }

    public AftSurvivalRegTrainBatchOp(Params params) {
        super(params);
    }

    /**
     * Builds the AFT training graph.
     *
     * <p>Side effect: the censor column is written into this operator's params as the weight column.
     *
     * @param inputs {@code inputs[0]} is the training data. An optional {@code inputs[1]} is an existing
     *               linear model whose coefficients are handed to the optimizer as an initial model.
     *               Inputs beyond the second are ignored.
     * @return this operator
     * @throws AkIllegalOperatorParameterException if both featureCols and vectorCol are set
     * @throws AkIllegalModelException             (at runtime) if the initial model's intercept setting
     *                                             differs from this operator's {@code withIntercept}
     */
    @Override
    public AftSurvivalRegTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
        BatchOperator<?> trainOp;
        BatchOperator<?> initModelOp = null;
        if (inputs.length <= 1) {
            // Let the framework's checked accessor report a missing input instead of indexing inputs[0].
            trainOp = checkAndGetFirst(inputs);
        } else {
            trainOp = inputs[0];
            initModelOp = inputs[1];
        }
        String[] featureCols = getFeatureCols();
        TypeInformation<Double> labelType = Types.DOUBLE;
        final Params params = getParams();
        if (params.contains(HasFeatureCols.FEATURE_COLS) && params.contains(HasVectorCol.VECTOR_COL)) {
            throw new AkIllegalOperatorParameterException("featureCols and vectorCol cannot be set at the same time.");
        }
        // Register the censor column as the weight column so transform() carries it in field f0.
        params.set(LinearTrainParams.WEIGHT_COL, getCensorCol());

        DataSet<Tuple3<Double, Object, Vector>> transform = BaseLinearModelTrainBatchOp.transform(trainOp, params, true, true);
        DataSet<Tuple3<DenseVector[], Object[], Integer[]>> utilInfo = BaseLinearModelTrainBatchOp.getUtilInfo(transform, true, true);

        // Standard deviations used to scale features and rescale coefficients (broadcast as "std").
        DataSet<double[]> std = utilInfo.map(new MapFunction<Tuple3<DenseVector[], Object[], Integer[]>, double[]>() {
            private static final long serialVersionUID = -7070926092286155032L;

            @Override
            public double[] map(Tuple3<DenseVector[], Object[], Integer[]> info) {
                return info.f0[1].getData();
            }
        });

        // Presumably the feature-vector size; shared by the optimizer and the side-table builder.
        DataSet<Integer> vectorSize = utilInfo.map(new MapFunction<Tuple3<DenseVector[], Object[], Integer[]>, Integer>() {
            private static final long serialVersionUID = 5463028282798602155L;

            @Override
            public Integer map(Tuple3<DenseVector[], Object[], Integer[]> info) {
                return info.f2[0];
            }
        });

        DataSet<Tuple3<Double, Double, Vector>> trainData = transform.mapPartition(new FormatLabeledVector()).withBroadcastSet(std, "std");

        DataSet<?> meta = MLEnvironmentFactory.get(getMLEnvironmentId()).getExecutionEnvironment()
            .fromElements(new Object())
            .map(new MapFunction<Object, Object[]>() {
                private static final long serialVersionUID = -1563051729748477019L;

                @Override
                public Object[] map(Object value) {
                    return new Object[0];
                }
            })
            .mapPartition(new BaseLinearModelTrainBatchOp.CreateMeta("AFTSurvivalRegTrainBatchOp", LinearModelType.AFT, params))
            .setParallelism(1);

        DataSet<DenseVector> initModelCoefs = null;
        if (initModelOp != null) {
            initModelCoefs = initModelOp.getDataSet().reduceGroup(new GroupReduceFunction<Row, DenseVector>() {
                private static final long serialVersionUID = -2893151468532786403L;

                @Override
                public void reduce(Iterable<Row> rows, Collector<DenseVector> collector) {
                    List<Row> modelRows = new ArrayList<>();
                    for (Row row : rows) {
                        modelRows.add(row);
                    }
                    LinearModelData initModel = new LinearModelDataConverter().load(modelRows);
                    // Explicit check: a Java assert would be skipped whenever assertions are disabled.
                    boolean withIntercept = params.get(HasWithIntercept.WITH_INTERCEPT);
                    if (initModel.hasInterceptItem != withIntercept) {
                        throw new AkIllegalModelException("initial model is not compatible with data and parameter setting.");
                    }
                    collector.collect(initModel.coefVector);
                }
            });
        }

        DataSet<Tuple2<DenseVector, double[]>> optimize = BaseLinearModelTrainBatchOp.optimize(params, vectorSize, trainData, initModelCoefs, LinearModelType.AFT, MLEnvironmentFactory.get(getMLEnvironmentId()));

        DataSet<Row> modelRows = optimize
            .map(new FormatCoef()).withBroadcastSet(std, "std")
            .mapPartition(new BaseLinearModelTrainBatchOp.BuildModelFromCoefs(labelType, featureCols, false, false, null))
            .withBroadcastSet(meta, "meta")
            .setParallelism(1);
        setOutput(modelRows, new LinearModelDataConverter(labelType).getModelSchema());

        String[] featureNames = params.get(LinearTrainParams.FEATURE_COLS);
        boolean withIntercept = params.get(LinearTrainParams.WITH_INTERCEPT);
        setSideOutputTables(BaseLinearModelTrainBatchOp.getSideTablesOfCoefficient(optimize.project(1), modelRows, transform, vectorSize, featureNames, withIntercept, getMLEnvironmentId().longValue()));
        return this;
    }

    @Override
    public AftSurvivalRegModelInfoBatchOp getModelInfoBatchOp() {
        return new AftSurvivalRegModelInfoBatchOp(getParams()).linkFrom(this);
    }

    /**
     * Divides each stored value of a sparse vector, in place, by the standard deviation of its index.
     * A non-positive standard deviation maps the value to 0, which matches {@link FormatCoef}.
     *
     * @param vector a {@link SparseVector}; it is mutated and returned
     * @param std    per-index standard deviations
     * @return the same (mutated) vector
     */
    public static Vector svStd(Vector vector, double[] std) {
        SparseVector sparse = (SparseVector) vector;
        int[] indices = sparse.getIndices();
        double[] values = sparse.getValues();
        for (int k = 0; k < indices.length; k++) {
            values[k] = divideByStd(values[k], std[indices[k]]);
        }
        return sparse;
    }

    /**
     * Divides each element of a dense vector, in place, by the matching standard deviation.
     * A non-positive standard deviation maps the element to 0, which matches {@link FormatCoef}.
     *
     * @param vector a {@link DenseVector}; it is mutated and returned
     * @param std    per-position standard deviations
     * @return the same (mutated) vector
     */
    public static Vector dvStd(Vector vector, double[] std) {
        DenseVector dense = (DenseVector) vector;
        double[] data = dense.getData();
        for (int k = 0; k < data.length; k++) {
            data[k] = divideByStd(data[k], std[k]);
        }
        return dense;
    }

    /** Returns {@code value / std}, or 0 when {@code std} is not positive (avoids Inf/NaN features). */
    private static double divideByStd(double value, double std) {
        return std > 0.0 ? value / std : 0.0;
    }

    @Override
    public LinearModelTrainInfo createTrainInfo(List<Row> list) {
        return new LinearModelTrainInfo(list);
    }

    @Override
    public BatchOperator<?> getSideOutputTrainInfo() {
        return getSideOutput(0);
    }
}
