package com.alibaba.alink.operator.common.classification;

import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.model.LabeledModelDataConverter;
import com.alibaba.alink.common.utils.AlinkSerializable;
import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.operator.batch.classification.NaiveBayesModelInfo;
import com.alibaba.alink.params.shared.colname.HasFeatureColsDefaultAsNull;
import com.google.common.collect.Iterables;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.ParamInfoFactory;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;

/**
 * Converts {@link NaiveBayesModelData} to and from the (meta params, string data, labels) triple
 * handled by {@link LabeledModelDataConverter}.
 *
 * <p>Layout written by {@link #serializeModel} and read back by {@link #deserializeModel}:
 * <ul>
 *   <li>meta params: feature column names, per-feature categorical flags, label weights, and the
 *       total number of serialized strings;</li>
 *   <li>string 0: JSON of {@link NaiveBayesProbInfo} ({@code pi} and {@code theta});</li>
 *   <li>string 1: JSON of {@code weightSum} ({@code double[][]});</li>
 *   <li>string 2: JSON of {@code featureInfo} ({@code SparseVector[][]});</li>
 *   <li>strings 3..n: one JSON array per row of {@code stringIndexerModelSerialized}.</li>
 * </ul>
 * The labels travel in the separate label iterable.
 */
public class NaiveBayesModelDataConverter extends LabeledModelDataConverter<NaiveBayesModelData, NaiveBayesModelData> {
    /** Meta key holding the label weights. */
    private static final ParamInfo<double[]> LABEL_WEIGHTS = ParamInfoFactory
        .createParamInfo("labelWeights", double[].class)
        .setDescription("the label weights.")
        .build();

    /** Meta key holding one categorical/non-categorical flag per feature column. */
    private static final ParamInfo<boolean[]> IS_CATE = ParamInfoFactory
        .createParamInfo("isCate", boolean[].class)
        .setDescription("judge whether the feature columns are categorical or not")
        .build();

    /**
     * Meta key holding the total number of serialized strings. This includes the fixed leading
     * entries, not only the string-indexer rows. That meaning is kept so models persisted earlier
     * still load.
     */
    private static final ParamInfo<Integer> STRING_INDEXER_MODEL_SIZE = ParamInfoFactory
        .createParamInfo("stringIndexerModelSize", Integer.class)
        .setDescription("stringIndexerModelSize")
        .build();

    /**
     * JSON carrier for the probability tables. The field names {@code theta} and {@code pi} are
     * part of the persisted JSON and must not be renamed.
     */
    public static class NaiveBayesProbInfo implements AlinkSerializable {
        public Number[][][] theta;
        public double[] pi;

        private NaiveBayesProbInfo() {
        }
    }

    /** Creates a converter using the superclass's no-argument constructor. */
    public NaiveBayesModelDataConverter() {
    }

    /**
     * Creates a converter for labels of the given type.
     *
     * @param typeInformation label type, forwarded unchanged to the superclass
     */
    public NaiveBayesModelDataConverter(TypeInformation typeInformation) {
        super(typeInformation);
    }

    /**
     * Serializes the model into meta params, JSON strings and labels (see the class Javadoc for
     * the layout).
     *
     * @param naiveBayesModelData the model to serialize
     * @return (meta params, serialized strings, labels)
     */
    @Override
    public Tuple3<Params, Iterable<String>, Iterable<Object>> serializeModel(NaiveBayesModelData naiveBayesModelData) {
        Params meta = new Params();
        meta.set(HasFeatureColsDefaultAsNull.FEATURE_COLS, naiveBayesModelData.featureNames);
        meta.set(IS_CATE, naiveBayesModelData.isCate);
        meta.set(LABEL_WEIGHTS, naiveBayesModelData.labelWeights);

        NaiveBayesProbInfo probInfo = new NaiveBayesProbInfo();
        probInfo.pi = naiveBayesModelData.piArray;
        probInfo.theta = naiveBayesModelData.theta;

        List<String> serialized = new ArrayList<>();
        serialized.add(JsonConverter.toJson(probInfo));
        serialized.add(JsonConverter.toJson(naiveBayesModelData.weightSum));
        serialized.add(JsonConverter.toJson(naiveBayesModelData.featureInfo));
        if (naiveBayesModelData.stringIndexerModelSerialized != null) {
            for (Row row : naiveBayesModelData.stringIndexerModelSerialized) {
                serialized.add(JsonConverter.toJson(rowFields(row)));
            }
        }
        // Total count of serialized strings (fixed entries included), kept for format compatibility.
        meta.set(STRING_INDEXER_MODEL_SIZE, serialized.size());
        return Tuple3.of(meta, serialized, Arrays.asList(naiveBayesModelData.label));
    }

    /**
     * Rebuilds a model from the triple produced by {@link #serializeModel}.
     *
     * @param meta meta params written by {@link #serializeModel}
     * @param data serialized strings, in the order written by {@link #serializeModel}
     * @param labels the model labels
     * @return the reconstructed model; string-indexer rows come back as (Long, field1, field2)
     */
    @Override
    protected NaiveBayesModelData deserializeModel(Params meta, Iterable<String> data, Iterable<Object> labels) {
        // The stored size counts every serialized string, so it is only an upper-bound capacity hint.
        int serializedSize = meta.get(STRING_INDEXER_MODEL_SIZE);
        NaiveBayesModelData modelData = new NaiveBayesModelData();
        modelData.stringIndexerModelSerialized = new ArrayList<>(serializedSize);

        int index = 0;
        for (String json : data) {
            if (index == 0) {
                NaiveBayesProbInfo probInfo = JsonConverter.fromJson(json, NaiveBayesProbInfo.class);
                modelData.piArray = probInfo.pi;
                modelData.theta = probInfo.theta;
            } else if (index == 1) {
                modelData.weightSum = JsonConverter.fromJson(json, double[][].class);
            } else if (index == 2) {
                modelData.featureInfo = JsonConverter.fromJson(json, SparseVector[][].class);
            } else {
                Object[] fields = JsonConverter.fromJson(json, Object[].class);
                // The JSON layer may decode numbers as Integer, Long or Double; widen via Number
                // instead of assuming Integer (which threw ClassCastException for other types).
                modelData.stringIndexerModelSerialized.add(
                    Row.of(((Number) fields[0]).longValue(), fields[1], fields[2]));
            }
            index++;
        }

        modelData.featureNames = meta.get(HasFeatureColsDefaultAsNull.FEATURE_COLS);
        modelData.isCate = meta.get(IS_CATE);
        modelData.labelWeights = meta.get(LABEL_WEIGHTS);
        modelData.label = Iterables.toArray(labels, Object.class);
        return modelData;
    }

    /**
     * Builds a {@link NaiveBayesModelInfo} from model rows.
     *
     * <p>Parses field 1 of the first row as the JSON of the model info. All remaining rows are
     * attached as {@code stringIndexerModelSerialized}. NOTE(review): assumes row 0 carries the
     * info JSON in field 1; confirm against the writer of these rows.
     *
     * @param list model rows; must be non-empty
     * @return the parsed model info
     */
    public NaiveBayesModelInfo loadModelInfo(List<Row> list) {
        String infoJson = (String) list.get(0).getField(1);
        NaiveBayesModelInfo modelInfo = JsonConverter.fromJson(infoJson, NaiveBayesModelInfo.class);
        // Copy rather than keep a subList view, so later changes to the caller's list cannot
        // corrupt the info object or trigger ConcurrentModificationException.
        modelInfo.stringIndexerModelSerialized = new ArrayList<>(list.subList(1, list.size()));
        return modelInfo;
    }

    /** Copies the fields of {@code row} into a plain array so it can be JSON-encoded. */
    private static Object[] rowFields(Row row) {
        Object[] fields = new Object[row.getArity()];
        for (int i = 0; i < fields.length; i++) {
            fields[i] = row.getField(i);
        }
        return fields;
    }
}
