package com.alibaba.alink.operator.common.optim.subfunc;

import com.alibaba.alink.common.comqueue.ComContext;
import com.alibaba.alink.common.comqueue.CompleteResultFunction;
import com.alibaba.alink.common.exceptions.AkIllegalDataException;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.model.ModelParamName;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;

/* loaded from: input_file:com/alibaba/alink/operator/common/optim/subfunc/OutputModel.class */
/**
 * Completion step of an optimization queue: packs the coefficient vector stored under
 * {@code OptimVariable.minCoef} and the loss values stored under
 * {@code OptimVariable.convergenceInfo} into a single JSON-serialized {@link Params} row.
 *
 * <p>Only the task with id 0 produces output; every other task returns {@code null}.
 */
public class OutputModel extends CompleteResultFunction {
    private static final long serialVersionUID = 7674917765793850275L;

    /**
     * Builds the model row from the optimizer state held in the communication context.
     *
     * @param comContext context holding {@code OptimVariable.minCoef} (a {@code Tuple2} whose
     *     {@code f0} is the coefficient {@link DenseVector}) and
     *     {@code OptimVariable.convergenceInfo} (a {@code double[]} of loss values)
     * @return a one-element list whose row contains the JSON of a {@link Params} with
     *     {@link ModelParamName#COEF} and {@link ModelParamName#LOSS_CURVE} set, or {@code null}
     *     on every task other than task 0
     * @throws AkIllegalDataException if any coefficient is NaN or infinite
     */
    @Override
    public List<Row> calc(ComContext comContext) {
        if (comContext.getTaskId() != 0) {
            return null;
        }
        Tuple2<?, ?> minCoef = (Tuple2<?, ?>) comContext.getObj(OptimVariable.minCoef);
        double[] convergenceInfo = (double[]) comContext.getObj(OptimVariable.convergenceInfo);

        double[] lossCurve = trimLossCurve(convergenceInfo);
        DenseVector coef = (DenseVector) minCoef.f0;
        validateCoefficients(coef);

        // Pass the ParamInfo keys and plain values directly; the previous casts of the values to
        // ParamInfo could not compile (double[]) or would fail at runtime (DenseVector).
        Params params = new Params();
        params.set(ModelParamName.COEF, coef);
        params.set(ModelParamName.LOSS_CURVE, lossCurve);

        List<Row> model = new ArrayList<>(1);
        model.add(Row.of(params.toJson()));
        return model;
    }

    /**
     * Returns a copy of {@code lossCurve} truncated just before its first infinite entry, or a
     * copy of the whole array when no entry is infinite.
     *
     * <p>NOTE(review): infinite entries presumably mark iteration slots that were never filled
     * (e.g. after early convergence) — confirm against the code that writes convergenceInfo.
     */
    private static double[] trimLossCurve(double[] lossCurve) {
        int effectiveLength = lossCurve.length;
        for (int i = 0; i < lossCurve.length; i++) {
            if (Double.isInfinite(lossCurve[i])) {
                effectiveLength = i;
                break;
            }
        }
        return Arrays.copyOf(lossCurve, effectiveLength);
    }

    /** Throws {@link AkIllegalDataException} if any entry of {@code coef} is NaN or infinite. */
    private static void validateCoefficients(DenseVector coef) {
        for (int i = 0; i < coef.size(); i++) {
            if (!Double.isFinite(coef.get(i))) {
                throw new AkIllegalDataException("Optimization result has NAN or infinite value, please check your input data and train parameters.");
            }
        }
    }
}
