package com.alibaba.alink.operator.common.linear;

import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.model.LabeledModelDataConverter;
import com.alibaba.alink.common.model.ModelParamName;
import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.params.shared.colname.HasLabelCol;
import com.alibaba.alink.params.shared.colname.HasVectorCol;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;

/**
 * Converts {@link LinearModelData} to and from the representation handled by
 * {@link LabeledModelDataConverter}.
 *
 * <p>A model is serialized as three parts:
 * <ul>
 *   <li>meta information in a {@link Params} object (model name, intercept flag, model type,
 *       vector column, vector size, label column);</li>
 *   <li>a single JSON string holding a {@link ModelData};</li>
 *   <li>the label values, or {@code null} when the model has none.</li>
 * </ul>
 */
public class LinearModelDataConverter extends LabeledModelDataConverter<LinearModelData, LinearModelData> {

    /** Row arity that identifies the old model format, which is loaded via {@code loadOldFromatModel}. */
    private static final int OLD_FORMAT_ROW_ARITY = 4;

    /** Model name whose flattened coefficient vector is split into per-class vectors when deserialized. */
    private static final String SOFTMAX_MODEL_NAME = "softmax";

    /** Payload of a linear model that is written to / read from JSON. */
    public static class ModelData implements Serializable {
        private static final long serialVersionUID = -1529656006252686121L;
        /** Feature column names, copied from {@code LinearModelData.featureNames}. */
        public String[] featureColNames = null;
        /** Feature column types, copied from {@code LinearModelData.featureTypes}. */
        public String[] featureColTypes = null;
        /** Coefficient vector, copied from {@code LinearModelData.coefVector}. */
        public DenseVector coefVector = null;
        /** Coefficient vectors, copied from {@code LinearModelData.coefVectors} on serialization. */
        public DenseVector[] coefVectors = null;
        /** Not written or read by this converter. */
        public double[] convergenceInfo = null;
    }

    /** Creates a converter with a {@code null} label type. */
    public LinearModelDataConverter() {
        this(null);
    }

    /**
     * Creates a converter.
     *
     * @param typeInformation label type handed to the base converter; it is used to restore the
     *                        label values stored in the model meta
     */
    public LinearModelDataConverter(TypeInformation typeInformation) {
        super(typeInformation);
    }

    /**
     * Serializes a linear model.
     *
     * @param linearModelData the model to serialize
     * @return a tuple of (meta params, a one-element list holding the JSON-encoded {@link ModelData},
     *         the label values or {@code null} if the model has none)
     */
    @Override
    public Tuple3<Params, Iterable<String>, Iterable<Object>> serializeModel(LinearModelData linearModelData) {
        List<String> data = new ArrayList<>(1);
        data.add(JsonConverter.toJson(getModelData(linearModelData)));
        Iterable<Object> labels = linearModelData.labelValues == null
            ? null
            : Arrays.asList(linearModelData.labelValues);
        return Tuple3.of(getMetaInfo(linearModelData), data, labels);
    }

    /**
     * Rebuilds a linear model from its serialized parts.
     *
     * @param meta           meta params written by {@link #serializeModel}
     * @param data           model data; only the first element (the JSON {@link ModelData}) is read
     * @param distinctLabels label values, or {@code null}; when non-null they replace any label
     *                       values recovered from {@code meta}
     * @return the reconstructed model
     */
    @Override
    public LinearModelData deserializeModel(Params meta, Iterable<String> data, Iterable<Object> distinctLabels) {
        LinearModelData linearModelData = new LinearModelData();
        if (meta.contains(ModelParamName.LABEL_VALUES)) {
            linearModelData.labelValues = FeatureLabelUtil.recoverLabelType(
                (Object[]) meta.get(ModelParamName.LABEL_VALUES), this.labelType);
        }
        setMetaInfo(meta, linearModelData);
        if (distinctLabels != null) {
            List<Object> labels = new ArrayList<>();
            distinctLabels.forEach(labels::add);
            linearModelData.labelValues = labels.toArray();
        }
        // Labels must be set before this call: the softmax split depends on the number of labels.
        ModelData modelData = (ModelData) JsonConverter.fromJson(data.iterator().next(), ModelData.class);
        setModelData(modelData, linearModelData);
        return linearModelData;
    }

    /**
     * Loads a model from rows.
     *
     * <p>Rows whose first element has arity {@value #OLD_FORMAT_ROW_ARITY} are read with the old
     * format loader. Everything else, including an empty list, is delegated to the base converter.
     *
     * @param list the model rows
     * @return the loaded model
     */
    @Override
    public LinearModelData load(List<Row> list) {
        if (list.isEmpty() || list.get(0).getArity() != OLD_FORMAT_ROW_ARITY) {
            return super.load(list);
        }
        LinearModelData linearModelData = new LinearModelData();
        linearModelData.loadOldFromatModel(list);
        return linearModelData;
    }

    /** Collects the scalar model properties into a {@link Params} object. */
    private Params getMetaInfo(LinearModelData linearModelData) {
        Params params = new Params();
        params.set(ModelParamName.MODEL_NAME, linearModelData.modelName);
        params.set(ModelParamName.HAS_INTERCEPT_ITEM, linearModelData.hasInterceptItem);
        params.set(ModelParamName.LINEAR_MODEL_TYPE, linearModelData.linearModelType);
        if (linearModelData.vectorColName != null) {
            params.set(HasVectorCol.VECTOR_COL, linearModelData.vectorColName);
        }
        params.set(ModelParamName.VECTOR_SIZE, linearModelData.vectorSize);
        params.set(HasLabelCol.LABEL_COL, linearModelData.labelName);
        return params;
    }

    /**
     * Copies the scalar model properties from {@code params}.
     *
     * <p>When a key is absent, the following defaults apply: model type {@code null}, intercept
     * {@code true}, vector size {@code 0}, vector column {@code null}, label column {@code null}.
     */
    private void setMetaInfo(Params params, LinearModelData linearModelData) {
        linearModelData.modelName = params.get(ModelParamName.MODEL_NAME);
        linearModelData.linearModelType = params.contains(ModelParamName.LINEAR_MODEL_TYPE)
            ? params.get(ModelParamName.LINEAR_MODEL_TYPE) : null;
        linearModelData.hasInterceptItem = !params.contains(ModelParamName.HAS_INTERCEPT_ITEM)
            || params.get(ModelParamName.HAS_INTERCEPT_ITEM);
        linearModelData.vectorSize = params.contains(ModelParamName.VECTOR_SIZE)
            ? params.get(ModelParamName.VECTOR_SIZE) : 0;
        linearModelData.vectorColName = params.contains(HasVectorCol.VECTOR_COL)
            ? params.get(HasVectorCol.VECTOR_COL) : null;
        linearModelData.labelName = params.contains(HasLabelCol.LABEL_COL)
            ? params.get(HasLabelCol.LABEL_COL) : null;
    }

    /** Copies the array-valued model fields into a {@link ModelData} for JSON serialization. */
    private ModelData getModelData(LinearModelData linearModelData) {
        ModelData modelData = new ModelData();
        modelData.featureColNames = linearModelData.featureNames;
        modelData.featureColTypes = linearModelData.featureTypes;
        modelData.coefVector = linearModelData.coefVector;
        modelData.coefVectors = linearModelData.coefVectors;
        return modelData;
    }

    /**
     * Copies the deserialized {@link ModelData} into the model.
     *
     * <p>For softmax models, {@code coefVectors} is rebuilt from the flattened {@code coefVector}.
     * For other models it is left untouched.
     */
    private void setModelData(ModelData modelData, LinearModelData linearModelData) {
        linearModelData.featureNames = modelData.featureColNames;
        linearModelData.featureTypes = modelData.featureColTypes;
        linearModelData.coefVector = modelData.coefVector;
        // Constant-first comparison: modelName is null when the meta carries no model name.
        if (SOFTMAX_MODEL_NAME.equals(linearModelData.modelName)) {
            int numLabels = linearModelData.labelValues == null ? 0 : linearModelData.labelValues.length;
            linearModelData.coefVectors = splitSoftmaxCoefficients(modelData.coefVector, numLabels);
        }
    }

    /**
     * Splits a flattened softmax coefficient vector into {@code numLabels - 1} equal-length,
     * contiguous vectors.
     *
     * @throws IllegalStateException if the coefficients are missing, there are fewer than two
     *                               labels, or the coefficient length is not divisible by
     *                               {@code numLabels - 1}
     */
    private static DenseVector[] splitSoftmaxCoefficients(DenseVector flatCoef, int numLabels) {
        if (flatCoef == null) {
            throw new IllegalStateException("Softmax model data has no coefficient vector.");
        }
        if (numLabels < 2) {
            throw new IllegalStateException(
                "Softmax model requires at least 2 label values, but got " + numLabels + ".");
        }
        double[] flat = flatCoef.getData();
        int numVectors = numLabels - 1;
        if (flat.length % numVectors != 0) {
            throw new IllegalStateException("Softmax coefficient length " + flat.length
                + " is not divisible by the number of label values minus one (" + numVectors + ").");
        }
        int dim = flat.length / numVectors;
        DenseVector[] vectors = new DenseVector[numVectors];
        for (int i = 0; i < numVectors; i++) {
            vectors[i] = new DenseVector(Arrays.copyOfRange(flat, i * dim, (i + 1) * dim));
        }
        return vectors;
    }
}
