package com.alibaba.alink.operator.batch.regression;

import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpecs;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.TypeCollections;
import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil;
import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp;
import com.alibaba.alink.operator.common.optim.activeSet.ConstraintVariable;
import com.alibaba.alink.operator.common.regression.GlmModelData;
import com.alibaba.alink.operator.common.regression.GlmModelDataConverter;
import com.alibaba.alink.operator.common.regression.glm.FamilyLink;
import com.alibaba.alink.operator.common.regression.glm.GlmModelInfo;
import com.alibaba.alink.operator.common.regression.glm.GlmUtil;
import com.alibaba.alink.params.regression.GlmTrainParams;
import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.Table;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/**
 * Batch operator that trains a generalized linear model (GLM).
 *
 * <p>{@code linkFrom} publishes the serialized model as the main output and two side-output tables:
 * the residual data set produced by {@code GlmUtil.residual}, and a single-column table named
 * {@code "summary"} holding the JSON form of the {@code GlmUtil.GlmModelSummary}.
 */
@InputPorts(values = {@PortSpec(PortType.DATA)})
@OutputPorts(values = {@PortSpec(PortType.MODEL), @PortSpec(PortType.DATA), @PortSpec(PortType.DATA)})
@ParamSelectColumnSpecs({@ParamSelectColumnSpec(name = "featureCols", allowedTypeCollections = {TypeCollections.NUMERIC_TYPES}), @ParamSelectColumnSpec(name = "labelCol", allowedTypeCollections = {TypeCollections.NUMERIC_TYPES}), @ParamSelectColumnSpec(name = "weightCol", allowedTypeCollections = {TypeCollections.NUMERIC_TYPES})})
@NameCn("广义线性回归训练")
@NameEn("GLM Training")
@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.regression.GeneralizedLinearRegression")
public final class GlmTrainBatchOp extends BatchOperator<GlmTrainBatchOp> implements GlmTrainParams<GlmTrainBatchOp>, WithModelInfoBatchOp<GlmModelInfo, GlmTrainBatchOp, GlmModelInfoBatchOp> {
    private static final long serialVersionUID = -4589724139230288132L;

    /**
     * Turns the trained weighted-least-squares model into model rows.
     *
     * <p>The training configuration is captured at construction time; the model summary is read
     * from the broadcast variable {@code "summary"} when the function is opened.
     */
    public static class BuildModel extends RichMapPartitionFunction<GlmUtil.WeightedLeastSquaresModel, Row> {
        private static final long serialVersionUID = -1832850372942112225L;
        private final String[] featureColNames;
        private final String offsetColName;
        private final String weightColName;
        private final String labelColName;
        private final GlmTrainParams.Family familyName;
        private final double variancePower;
        private final GlmTrainParams.Link linkName;
        private final double linkPower;
        private final boolean fitIntercept;
        private final int numIter;
        private final double epsilon;
        // Populated in open() from the broadcast set, hence not final.
        private GlmUtil.GlmModelSummary summary;

        /**
         * Captures the training configuration that is copied verbatim into the saved model.
         *
         * @param featureColNames names of the feature columns
         * @param offsetColName   name of the offset column
         * @param weightColName   name of the weight column
         * @param labelColName    name of the label column
         * @param familyName      GLM family
         * @param variancePower   variance power stored in the model
         * @param linkName        GLM link
         * @param linkPower       link power stored in the model
         * @param fitIntercept    whether an intercept is fitted; must not be {@code null} (it is unboxed)
         * @param numIter         iteration count stored in the model
         * @param epsilon         epsilon stored in the model
         */
        public BuildModel(String[] featureColNames, String offsetColName, String weightColName, String labelColName, GlmTrainParams.Family familyName, double variancePower, GlmTrainParams.Link linkName, double linkPower, Boolean fitIntercept, int numIter, double epsilon) {
            this.featureColNames = featureColNames;
            this.offsetColName = offsetColName;
            this.weightColName = weightColName;
            this.labelColName = labelColName;
            this.familyName = familyName;
            this.variancePower = variancePower;
            this.linkName = linkName;
            this.linkPower = linkPower;
            this.fitIntercept = fitIntercept;
            this.numIter = numIter;
            this.epsilon = epsilon;
        }

        /** Reads the first element of the {@code "summary"} broadcast variable. */
        @Override
        public void open(Configuration configuration) throws Exception {
            this.summary = (GlmUtil.GlmModelSummary) getRuntimeContext().getBroadcastVariable("summary").get(0);
        }

        /**
         * Combines the first model of the partition with the captured configuration and the
         * broadcast summary, then serializes it through {@link GlmModelDataConverter#save}.
         *
         * <p>Only the first element is consumed; an empty partition makes
         * {@code Iterator.next()} throw.
         */
        @Override
        public void mapPartition(Iterable<GlmUtil.WeightedLeastSquaresModel> iterable, Collector<Row> collector) throws Exception {
            GlmUtil.WeightedLeastSquaresModel trained = iterable.iterator().next();

            GlmModelData modelData = new GlmModelData();
            modelData.featureColNames = this.featureColNames;
            modelData.offsetColName = this.offsetColName;
            modelData.weightColName = this.weightColName;
            modelData.labelColName = this.labelColName;
            modelData.familyName = this.familyName;
            modelData.variancePower = this.variancePower;
            modelData.linkName = this.linkName;
            modelData.linkPower = this.linkPower;
            modelData.coefficients = trained.coefficients;
            modelData.intercept = trained.intercept;
            modelData.diagInvAtWA = trained.diagInvAtWA;
            modelData.fitIntercept = this.fitIntercept;
            modelData.numIter = this.numIter;
            modelData.epsilon = this.epsilon;
            modelData.modelSummary = this.summary;

            new GlmModelDataConverter().save(modelData, collector);
        }
    }

    /** Creates the operator with default parameters. */
    public GlmTrainBatchOp() {
    }

    /**
     * Creates the operator with the given parameters.
     *
     * @param params parameters forwarded to the {@link BatchOperator} constructor
     */
    public GlmTrainBatchOp(Params params) {
        super(params);
    }

    /**
     * Wires the training pipeline onto the first input operator.
     *
     * <p>Steps: pre-process the input, train, compute residuals, aggregate the summary, then
     * publish the model (main output) and the residual and summary tables (side outputs).
     *
     * @param inputs upstream operators; the first one is obtained through {@code checkAndGetFirst}
     * @return this operator
     */
    @Override
    public GlmTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
        BatchOperator<?> in = checkAndGetFirst(inputs);

        String[] featureCols = getFeatureCols();
        String labelCol = getLabelCol();
        String weightCol = getWeightCol();
        String offsetCol = getOffsetCol();
        GlmTrainParams.Family family = getFamily();
        GlmTrainParams.Link link = getLink();
        double variancePower = getVariancePower();
        double linkPower = getLinkPower();
        int maxIter = getMaxIter();
        double epsilon = getEpsilon();
        boolean fitIntercept = getFitIntercept();
        double regParam = getRegParam();
        int numFeature = featureCols.length;

        FamilyLink familyLink = new FamilyLink(family, variancePower, link, linkPower);
        DataSet<Row> data = GlmUtil.preProc(in, featureCols, offsetCol, weightCol, labelCol);
        DataSet<GlmUtil.WeightedLeastSquaresModel> fitted = GlmUtil.train(data, numFeature, familyLink, regParam, fitIntercept, maxIter, epsilon);
        DataSet<Row> residual = GlmUtil.residual(fitted, data, numFeature, familyLink);
        DataSet<GlmUtil.GlmModelSummary> summary = GlmUtil.aggSummary(residual, fitted, numFeature, familyLink, regParam, maxIter, epsilon, fitIntercept);

        // The model is built by a single task that receives the summary as a broadcast set.
        DataSet<Row> modelRows = fitted
            .mapPartition(new BuildModel(featureCols, offsetCol, weightCol, labelCol, family, variancePower, link, linkPower, fitIntercept, maxIter, epsilon))
            .setParallelism(1)
            .withBroadcastSet(summary, "summary");
        setOutput(modelRows, new GlmModelDataConverter().getModelSchema());

        String[] residualNames = residualColNames(featureCols);
        Table residualTable = DataSetConversionUtil.toTable(getMLEnvironmentId(), residual, residualNames, allDoubleTypes(residualNames.length));

        DataSet<Row> summaryRows = summary.map(new MapFunction<GlmUtil.GlmModelSummary, Row>() {
            @Override
            public Row map(GlmUtil.GlmModelSummary modelSummary) {
                return Row.of(JsonConverter.toJson(modelSummary));
            }
        });
        Table summaryTable = DataSetConversionUtil.toTable(getMLEnvironmentId(), summaryRows, new String[]{"summary"}, new TypeInformation<?>[]{Types.STRING});

        setSideOutputTables(new Table[]{residualTable, summaryTable});
        return this;
    }

    /**
     * Creates a {@link GlmModelInfoBatchOp} with this operator's parameters and links it to this
     * operator.
     *
     * @return the linked model-info operator
     */
    @Override
    public GlmModelInfoBatchOp getModelInfoBatchOp() {
        return new GlmModelInfoBatchOp(getParams()).linkFrom(this);
    }

    /** Builds the residual table column names: the feature columns followed by eight fixed columns. */
    private static String[] residualColNames(String[] featureCols) {
        String[] trailing = {
            "label",
            // NOTE(review): this constant lives in the active-set optimizer package and is
            // presumably just the literal "weight" substituted by the decompiler - confirm
            // before inlining.
            ConstraintVariable.weight,
            "offset",
            "pred",
            // NOTE(review): looks like an accidental concatenation of "residual" and
            // "devianceResiduals"; kept unchanged because consumers may rely on the name.
            "residualdevianceResiduals",
            "pearsonResiduals",
            "workingResiduals",
            "responseResiduals"
        };
        String[] names = new String[featureCols.length + trailing.length];
        System.arraycopy(featureCols, 0, names, 0, featureCols.length);
        System.arraycopy(trailing, 0, names, featureCols.length, trailing.length);
        return names;
    }

    /** Returns {@code count} column types, every one of them {@code Types.DOUBLE}. */
    private static TypeInformation<?>[] allDoubleTypes(int count) {
        TypeInformation<?>[] types = new TypeInformation<?>[count];
        for (int i = 0; i < count; i++) {
            types[i] = Types.DOUBLE;
        }
        return types;
    }
}
