package com.alibaba.alink.operator.common.linear;

import com.alibaba.alink.common.io.filesystem.copy.csv.CsvInputFormat;
import com.alibaba.alink.common.model.ModelParamName;
import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.operator.common.utils.PrettyDisplayUtils;
import java.io.Serializable;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;

/* loaded from: input_file:com/alibaba/alink/operator/common/linear/LinearModelTrainInfo.class */
/**
 * Training summary of a linear model, rebuilt from a list of rows.
 *
 * <p>Each row carries a {@code Long} id in field 0 and a JSON string in field 1. The id selects
 * what the JSON payload is decoded into:
 * <ul>
 *   <li>0 - model meta ({@link Params})</li>
 *   <li>1 - column names ({@code String[]})</li>
 *   <li>2 - weights ({@code double[]})</li>
 *   <li>3 - importance values ({@code double[]})</li>
 *   <li>4 - convergence data ({@code double[]}), read as consecutive
 *       (loss, gradNorm, learnRate) triples, one triple per step</li>
 * </ul>
 * Rows with any other id are ignored. A part whose row is absent stays {@code null}.
 */
public final class LinearModelTrainInfo implements Serializable {
    private static final long serialVersionUID = 7999201781768270042L;

    /** Pattern used to format every numeric value that is rendered as text. */
    private static final String NUMBER_PATTERN = "#0.00000000";

    /** Number of doubles stored per convergence step: loss, gradNorm, learnRate. */
    private static final int VALUES_PER_STEP = 3;

    /** Tables with fewer entries than this are printed in full. */
    private static final int FULL_TABLE_LIMIT = 6;

    /** Rows shown before and after the ellipsis row of an abbreviated table. */
    private static final int TABLE_EDGE_ROWS = 3;

    /** Number of columns of the importance/weight table. */
    private static final int TABLE_COLS = 4;

    /** Convergence logs with fewer steps than this are printed in full. */
    private static final int FULL_CONV_LIMIT = 20;

    /** Steps shown before and after the ellipsis line of an abbreviated convergence log. */
    private static final int CONV_EDGE_STEPS = 10;

    /** Orders (name, value) pairs by value, largest first. */
    private static final Comparator<Tuple2<String, Double>> DESCENDING_BY_VALUE =
        (left, right) -> right.f1.compareTo(left.f1);

    private String[] convInfo;
    private Params meta;
    private String[] colNames;
    private double[] weight;
    private double[] importance;

    /**
     * Decodes the given rows into the parts of the training summary.
     *
     * @param list rows whose field 0 is a {@code Long} id and whose field 1 is a JSON string
     */
    public LinearModelTrainInfo(List<Row> list) {
        DecimalFormat decimalFormat = new DecimalFormat(NUMBER_PATTERN);
        for (Row row : list) {
            long rowId = ((Long) row.getField(0)).longValue();
            // The payload is only read for known ids, so unknown rows are skipped untouched.
            if (rowId == 0L) {
                this.meta = JsonConverter.fromJson(payloadOf(row), Params.class);
            } else if (rowId == 1L) {
                this.colNames = JsonConverter.fromJson(payloadOf(row), String[].class);
            } else if (rowId == 2L) {
                this.weight = JsonConverter.fromJson(payloadOf(row), double[].class);
            } else if (rowId == 3L) {
                this.importance = JsonConverter.fromJson(payloadOf(row), double[].class);
            } else if (rowId == 4L) {
                double[] flat = JsonConverter.fromJson(payloadOf(row), double[].class);
                this.convInfo = describeSteps(flat, decimalFormat);
            }
        }
    }

    /** Returns field 1 of the row, which holds the JSON payload. */
    private static String payloadOf(Row row) {
        return (String) row.getField(1);
    }

    /**
     * Renders one text line per complete (loss, gradNorm, learnRate) triple. Trailing values
     * that do not form a complete triple are dropped by the integer division.
     */
    private static String[] describeSteps(double[] flat, DecimalFormat format) {
        int steps = flat.length / VALUES_PER_STEP;
        String[] lines = new String[steps];
        for (int step = 0; step < steps; step++) {
            int base = VALUES_PER_STEP * step;
            lines[step] = "step:" + step
                + " loss:" + format.format(flat[base])
                + " gradNorm:" + format.format(flat[base + 1])
                + " learnRate:" + format.format(flat[base + 2]);
        }
        return lines;
    }

    /**
     * @return one formatted line per training step, or {@code null} if no convergence row was
     *     supplied; the internal array is returned, not a copy
     */
    public String[] getConvInfo() {
        return this.convInfo;
    }

    /** @return the model meta, or {@code null} if no meta row was supplied */
    public Params getMeta() {
        return this.meta;
    }

    /**
     * @return the column names, or {@code null} if no such row was supplied; the internal array
     *     is returned, not a copy
     */
    public String[] getColNames() {
        return this.colNames;
    }

    /**
     * @return the weights, or {@code null} if no such row was supplied; the internal array is
     *     returned, not a copy
     */
    public double[] getWeight() {
        return this.weight;
    }

    /**
     * @return the importance values, or {@code null} if no such row was supplied; the internal
     *     array is returned, not a copy
     */
    public double[] getImportance() {
        return this.importance;
    }

    /**
     * Pairs each column name with its weight, sorted by weight in descending order.
     *
     * <p>When the weight array is not the same length as the importance array, its first entry
     * is skipped (presumably the intercept term - confirm against the trainer).
     */
    private List<Tuple2<String, Double>> getWeightList() {
        int offset = this.weight.length == this.importance.length ? 0 : 1;
        return pairWithColumns(this.weight, offset);
    }

    /** Pairs each column name with its importance, sorted by importance in descending order. */
    private List<Tuple2<String, Double>> getImportanceList() {
        return pairWithColumns(this.importance, 0);
    }

    /**
     * Builds one (column name, value) pair per importance entry, reading the value at
     * {@code index + offset}, and sorts the pairs by value in descending order.
     */
    private List<Tuple2<String, Double>> pairWithColumns(double[] values, int offset) {
        int count = this.importance.length;
        List<Tuple2<String, Double>> pairs = new ArrayList<>(count);
        for (int i = 0; i < count; i++) {
            pairs.add(Tuple2.of(this.colNames[i], Double.valueOf(values[i + offset])));
        }
        pairs.sort(DESCENDING_BY_VALUE);
        return pairs;
    }

    /** Tells whether everything needed to render the importance/weight table is present. */
    private boolean hasImportanceInfo() {
        return this.colNames != null && this.weight != null && this.importance != null;
    }

    /**
     * Renders the meta info, the importance/weight table (skipped for the "softmax" model or
     * when the data for it is missing) and the convergence log. Long tables and logs are
     * abbreviated to their first and last entries.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append(PrettyDisplayUtils.displayHeadline("train meta info", '-'));
        String modelName = this.meta.get(ModelParamName.MODEL_NAME);
        Integer vectorSize = this.meta.get(ModelParamName.VECTOR_SIZE);
        HashMap<String, String> metaInfo = new HashMap<>();
        metaInfo.put("model name", modelName);
        // String.valueOf keeps toString() from failing when the vector size is absent.
        metaInfo.put("num feature", String.valueOf(vectorSize));
        sb.append(PrettyDisplayUtils.displayMap(metaInfo, 2, false)).append(CsvInputFormat.DEFAULT_LINE_DELIMITER);

        if (!"softmax".equals(modelName) && hasImportanceInfo()) {
            sb.append(PrettyDisplayUtils.displayHeadline("train importance info", '-'));
            appendImportanceTable(sb);
        }

        sb.append(PrettyDisplayUtils.displayHeadline("train convergence info", '-'));
        appendConvergenceInfo(sb);
        return sb.toString();
    }

    /**
     * Appends the importance/weight table. The two lists are sorted independently, so a table
     * row shows the i-th most important column next to the i-th largest weight; the two names
     * in a row may differ.
     */
    private void appendImportanceTable(StringBuilder sb) {
        DecimalFormat decimalFormat = new DecimalFormat(NUMBER_PATTERN);
        List<Tuple2<String, Double>> weightList = getWeightList();
        List<Tuple2<String, Double>> importanceList = getImportanceList();
        int total = importanceList.size();

        Object[][] table;
        if (total < FULL_TABLE_LIMIT) {
            table = new Object[total][TABLE_COLS];
            for (int i = 0; i < total; i++) {
                fillRow(table[i], importanceList.get(i), weightList.get(i), decimalFormat);
            }
        } else {
            // Layout: first three entries, one ellipsis row, last three entries.
            table = new Object[2 * TABLE_EDGE_ROWS + 1][TABLE_COLS];
            for (int i = 0; i < TABLE_EDGE_ROWS; i++) {
                fillRow(table[i], importanceList.get(i), weightList.get(i), decimalFormat);
            }
            for (int col = 0; col < TABLE_COLS; col++) {
                table[TABLE_EDGE_ROWS][col] = "... ...";
            }
            for (int i = 0; i < TABLE_EDGE_ROWS; i++) {
                int source = total - TABLE_EDGE_ROWS + i;
                fillRow(table[TABLE_EDGE_ROWS + 1 + i], importanceList.get(source), weightList.get(source),
                    decimalFormat);
            }
        }
        sb.append(PrettyDisplayUtils.displayTable(table, table.length, TABLE_COLS, null,
            new String[]{"colName", "importanceValue", "colName", "weightValue"}, null, table.length, TABLE_COLS));
    }

    /** Writes one table row: importance name and value, then weight name and value. */
    private static void fillRow(Object[] target, Tuple2<String, Double> importanceEntry,
                                Tuple2<String, Double> weightEntry, DecimalFormat format) {
        target[0] = importanceEntry.f0;
        target[1] = format.format(importanceEntry.f1);
        target[2] = weightEntry.f0;
        target[3] = format.format(weightEntry.f1);
    }

    /**
     * Appends the convergence log, one step per line. With 20 or more steps only the first and
     * last ten are printed, separated by an ellipsis line. Nothing is appended when no
     * convergence row was supplied.
     */
    private void appendConvergenceInfo(StringBuilder sb) {
        String[] steps = this.convInfo == null ? new String[0] : this.convInfo;
        if (steps.length < FULL_CONV_LIMIT) {
            for (String step : steps) {
                sb.append(step).append(CsvInputFormat.DEFAULT_LINE_DELIMITER);
            }
            return;
        }
        for (int i = 0; i < CONV_EDGE_STEPS; i++) {
            sb.append(steps[i]).append(CsvInputFormat.DEFAULT_LINE_DELIMITER);
        }
        sb.append("... ... ... ...\n");
        for (int i = steps.length - CONV_EDGE_STEPS; i < steps.length; i++) {
            sb.append(steps[i]).append(CsvInputFormat.DEFAULT_LINE_DELIMITER);
        }
    }
}
