package com.alibaba.alink.operator.common.classification;

import com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.mapper.ComboModelMapper;
import com.alibaba.alink.common.mapper.Mapper;
import com.alibaba.alink.common.mapper.RichModelMapper;
import com.alibaba.alink.common.model.ModelParamName;
import com.alibaba.alink.common.type.AlinkTypes;
import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.operator.common.io.types.FlinkTypeConverter;
import com.alibaba.alink.operator.common.io.types.JdbcTypeConverter;
import com.alibaba.alink.operator.common.linear.LinearModelMapper;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.operator.common.tree.predictors.GbdtModelMapper;
import com.alibaba.alink.params.classification.OneVsRestPredictParams;
import com.alibaba.alink.params.shared.colname.HasOutputCol;
import com.alibaba.alink.params.shared.colname.HasReservedColsDefaultAsNull;
import com.alibaba.alink.params.shared.colname.HasVectorCol;
import com.alibaba.alink.pipeline.classification.GbdtClassifier;
import com.alibaba.alink.pipeline.classification.LinearSvm;
import com.alibaba.alink.pipeline.classification.LogisticRegression;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.ParamInfoFactory;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.type.TypeReference;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

/**
 * Model mapper for OneVsRest classification models.
 *
 * <p>Prediction is assembled as a chain of mappers, exposed through {@link #getLoadedMapperList()}:
 * <ol>
 *   <li>{@code InitialResultMapper} appends an all-zero score vector with one slot per class;</li>
 *   <li>for every class index, a binary sub-model predictor followed by a {@code SetPositiveResultMapper}
 *       that writes that sub-model's positive-class probability into the class's slot of the score vector;</li>
 *   <li>{@code VoteMapper} emits the label whose slot holds the highest score and, when a prediction
 *       detail column is requested, the scores divided by their sum.</li>
 * </ol>
 */
public class OneVsRestModelMapper extends ComboModelMapper {
    private static final long serialVersionUID = 7008077848896699027L;

    /** Internal column carrying the per-class score vector through the mapper chain. */
    private static final String ONE_VS_REST_RESULT_VECTOR_COL_NAME = "one_vs_rest_result_vector_internal_implement";
    /** Internal prediction column written by each binary sub-model predictor. */
    private static final String ONE_VS_REST_PRED_RESULT_COL_NAME = "one_vs_rest_pred_result_internal_implement";
    /** Internal prediction-detail column written by each binary sub-model predictor. */
    private static final String ONE_VS_REST_PRED_DETAIL_COL_NAME = "one_vs_rest_pred_detail_internal_implement";

    /** Mapper chain built by {@link #loadModel(List)}; {@code null} until a model has been loaded. */
    private List<Mapper> mapperList;

    /** Appends an all-zero dense vector of length {@code numClasses} (the per-class score vector) to each row. */
    private static class InitialResultMapper extends Mapper {
        private final int numClasses;

        public InitialResultMapper(TableSchema dataSchema, Params params) {
            super(dataSchema, params);
            this.numClasses = params.get(ModelParamName.NUM_CLASSES);
        }

        @Override
        protected void map(Mapper.SlicedSelectedSample selection, Mapper.SlicedResult result) throws Exception {
            result.set(0, new DenseVector(numClasses));
        }

        @Override
        protected Tuple4<String[], String[], TypeInformation<?>[], String[]> prepareIoSchema(
                TableSchema dataSchema, Params params) {
            return Tuple4.of(
                dataSchema.getFieldNames(),
                new String[] {params.get(HasOutputCol.OUTPUT_COL)},
                new TypeInformation<?>[] {AlinkTypes.DENSE_VECTOR},
                params.get(HasReservedColsDefaultAsNull.RESERVED_COLS));
        }
    }

    /**
     * Copies one binary sub-model's positive-class probability (the {@code "1.0"} entry of its
     * prediction-detail JSON) into slot {@code classIndex} of the score vector.
     */
    private static class SetPositiveResultMapper extends Mapper {
        /** Slot of the score vector this mapper writes. */
        public static final ParamInfo<Integer> CLASS_INDEX =
            ParamInfoFactory.createParamInfo("classIndex", Integer.class).build();

        /** Key of the positive class in a binary sub-model's prediction detail. */
        private static final String POSITIVE_LABEL_KEY = "1.0";

        private final int classIndex;

        public SetPositiveResultMapper(TableSchema dataSchema, Params params) {
            super(dataSchema, params);
            this.classIndex = params.get(CLASS_INDEX);
        }

        /**
         * Reads the detail JSON (selected column 0) and the score vector (selected column 1), sets the
         * positive-class probability at {@code classIndex}, and outputs the updated vector.
         *
         * @throws IllegalStateException if the detail has no entry for the positive label
         */
        @Override
        protected void map(Mapper.SlicedSelectedSample selection, Mapper.SlicedResult result) throws Exception {
            String detailJson = (String) selection.get(0);
            @SuppressWarnings("unchecked")
            Map<String, String> detail = (Map<String, String>) JsonConverter.fromJson(
                detailJson, new TypeReference<Map<String, String>>() {}.getType());
            String positiveProbability = detail == null ? null : detail.get(POSITIVE_LABEL_KEY);
            if (positiveProbability == null) {
                // Fail with context instead of an anonymous NPE inside Double.parseDouble.
                throw new IllegalStateException("Prediction detail of the binary sub-model has no entry for "
                    + "positive label \"" + POSITIVE_LABEL_KEY + "\": " + detailJson);
            }
            DenseVector scores = VectorUtil.getDenseVector(selection.get(1));
            scores.set(classIndex, Double.parseDouble(positiveProbability));
            result.set(0, scores);
        }

        @Override
        protected Tuple4<String[], String[], TypeInformation<?>[], String[]> prepareIoSchema(
                TableSchema dataSchema, Params params) {
            return Tuple4.of(
                new String[] {params.get(OneVsRestPredictParams.PREDICTION_DETAIL_COL), params.get(HasVectorCol.VECTOR_COL)},
                new String[] {params.get(HasOutputCol.OUTPUT_COL)},
                new TypeInformation<?>[] {AlinkTypes.DENSE_VECTOR},
                params.get(HasReservedColsDefaultAsNull.RESERVED_COLS));
        }
    }

    /**
     * Turns the score vector into the final prediction: the label at the index of the highest score
     * (ties go to the lowest index) and, optionally, a JSON detail mapping each label to its share of
     * the total score.
     */
    private static class VoteMapper extends Mapper {
        private final boolean predDetail;
        private final List<Object> labels;

        public VoteMapper(TableSchema dataSchema, Params params) {
            super(dataSchema, params);
            this.predDetail = params.contains(OneVsRestPredictParams.PREDICTION_DETAIL_COL);
            this.labels = OneVsRestModelMapper.recoverLabel(
                params.get(ModelParamName.LABELS), params.get(ModelParamName.LABEL_TYPE_NAME));
        }

        @Override
        protected void map(Mapper.SlicedSelectedSample selection, Mapper.SlicedResult result) throws Exception {
            double[] scores = VectorUtil.getDenseVector(selection.get(0)).getData();
            int bestIndex = -1;
            double bestScore = -Double.MAX_VALUE;
            double total = 0.0;
            for (int c = 0; c < scores.length; c++) {
                total += scores[c];
                if (scores[c] > bestScore) {
                    bestIndex = c;
                    bestScore = scores[c];
                }
            }
            if (predDetail) {
                // If every sub-model scored 0, dividing would yield NaN (0/0); report the raw scores instead.
                double normalizer = total != 0.0 ? total : 1.0;
                // Same initial capacity and insertion order as before, so the serialized JSON is unchanged.
                Map<Object, Double> detail = new HashMap<>(labels.size());
                for (Object label : labels) {
                    detail.put(label, 0.0);
                }
                for (int c = 0; c < scores.length; c++) {
                    detail.replace(labels.get(c), scores[c] / normalizer);
                }
                result.set(1, JsonConverter.gson.toJson(detail));
            }
            result.set(0, labels.get(bestIndex));
        }

        @Override
        protected Tuple4<String[], String[], TypeInformation<?>[], String[]> prepareIoSchema(
                TableSchema dataSchema, Params params) {
            String vectorCol = params.get(HasVectorCol.VECTOR_COL);
            String predictionCol = params.get(OneVsRestPredictParams.PREDICTION_COL);
            String[] reservedCols = params.get(OneVsRestPredictParams.RESERVED_COLS);
            TypeInformation<?> labelType = FlinkTypeConverter.getFlinkType(params.get(ModelParamName.LABEL_TYPE_NAME));
            if (params.contains(OneVsRestPredictParams.PREDICTION_DETAIL_COL)) {
                return Tuple4.of(
                    new String[] {vectorCol},
                    new String[] {predictionCol, params.get(OneVsRestPredictParams.PREDICTION_DETAIL_COL)},
                    new TypeInformation<?>[] {labelType, Types.STRING},
                    reservedCols);
            }
            return Tuple4.of(
                new String[] {vectorCol},
                new String[] {predictionCol},
                new TypeInformation<?>[] {labelType},
                reservedCols);
        }
    }

    public OneVsRestModelMapper(TableSchema modelSchema, TableSchema dataSchema, Params params) {
        super(modelSchema, dataSchema, params);
    }

    /**
     * Builds the mapper chain described on the class from the OneVsRest model rows.
     *
     * @param modelRows model rows; the first row with a non-null field 1 carries the meta JSON, rows with a
     *                  non-null field 2 carry sub-model data for that class index
     * @throws IllegalStateException         if no meta row is present
     * @throws AkUnclassifiedErrorException  if a binary sub-model cannot be created or loaded
     */
    @Override
    public void loadModel(List<Row> modelRows) {
        Params meta = extractMeta(modelRows);
        if (meta == null) {
            throw new IllegalStateException("OneVsRest model contains no meta row (a row with a non-null field 1).");
        }
        int numClasses = meta.get(ModelParamName.NUM_CLASSES);
        String labels = meta.get(ModelParamName.LABELS);
        String labelTypeName = meta.get(ModelParamName.LABEL_TYPE_NAME);
        String binClsClassName = meta.get(ModelParamName.BIN_CLS_CLASS_NAME);
        String[] modelColNames = meta.get(ModelParamName.MODEL_COL_NAMES);
        Integer[] modelColTypes = meta.get(ModelParamName.MODEL_COL_TYPES);

        TypeInformation<?>[] modelColFlinkTypes = new TypeInformation<?>[modelColTypes.length];
        for (int i = 0; i < modelColTypes.length; i++) {
            modelColFlinkTypes[i] = JdbcTypeConverter.getFlinkType(modelColTypes[i].intValue());
        }
        // Identical for every sub-model, so build it once.
        TableSchema subModelSchema = new TableSchema(modelColNames, modelColFlinkTypes);

        String[] dataFieldNames = getDataSchema().getFieldNames();
        // Each binary predictor keeps the input columns plus the score vector and writes its prediction and
        // detail into internal columns that the following SetPositiveResultMapper reads.
        Params predictorParams = this.params.clone()
            .set(OneVsRestPredictParams.RESERVED_COLS, ArrayUtils.add(dataFieldNames, ONE_VS_REST_RESULT_VECTOR_COL_NAME))
            .set(OneVsRestPredictParams.PREDICTION_COL, ONE_VS_REST_PRED_RESULT_COL_NAME)
            .set(OneVsRestPredictParams.PREDICTION_DETAIL_COL, ONE_VS_REST_PRED_DETAIL_COL_NAME);

        List<Mapper> mappers = new ArrayList<>();
        InitialResultMapper initialResultMapper = new InitialResultMapper(getDataSchema(),
            this.params.clone()
                .set(OneVsRestPredictParams.RESERVED_COLS, dataFieldNames)
                .set(ModelParamName.NUM_CLASSES, numClasses)
                .set(HasOutputCol.OUTPUT_COL, ONE_VS_REST_RESULT_VECTOR_COL_NAME));
        mappers.add(initialResultMapper);

        Map<Long, List<Row>> subModelRowsByClass = groupSubModelRows(modelRows);
        // Schema feeding VoteMapper. It starts as the initial mapper's output so it is never null (previously
        // a zero-class model passed null); each SetPositiveResultMapper then replaces it with its own output.
        TableSchema scoreSchema = initialResultMapper.getOutputSchema();
        for (int classIndex = 0; classIndex < numClasses; classIndex++) {
            try {
                List<Row> subModelRows = subModelRowsByClass.getOrDefault((long) classIndex, new ArrayList<>());
                RichModelMapper predictor = createModelPredictor(binClsClassName, subModelSchema,
                    initialResultMapper.getOutputSchema(), predictorParams, subModelRows);
                SetPositiveResultMapper setPositive = new SetPositiveResultMapper(predictor.getOutputSchema(),
                    this.params.clone().merge(predictorParams)
                        .set(SetPositiveResultMapper.CLASS_INDEX, classIndex)
                        .set(HasVectorCol.VECTOR_COL, ONE_VS_REST_RESULT_VECTOR_COL_NAME)
                        .set(HasOutputCol.OUTPUT_COL, ONE_VS_REST_RESULT_VECTOR_COL_NAME)
                        .set(OneVsRestPredictParams.RESERVED_COLS, dataFieldNames));
                mappers.add(predictor);
                mappers.add(setPositive);
                scoreSchema = setPositive.getOutputSchema();
            } catch (Exception e) {
                throw new AkUnclassifiedErrorException("Failed to load OneVsRest sub-model for class index "
                    + classIndex + " (classifier: " + binClsClassName + ").", e);
            }
        }

        String[] reservedCols = this.params.get(OneVsRestPredictParams.RESERVED_COLS);
        mappers.add(new VoteMapper(scoreSchema,
            this.params.clone()
                .set(ModelParamName.LABELS, labels)
                .set(ModelParamName.LABEL_TYPE_NAME, labelTypeName)
                .set(HasVectorCol.VECTOR_COL, ONE_VS_REST_RESULT_VECTOR_COL_NAME)
                .set(OneVsRestPredictParams.RESERVED_COLS, reservedCols == null ? dataFieldNames : reservedCols)));
        this.mapperList = mappers;
    }

    /** Groups rows that have a non-null class index in field 2 by that index, keeping their original order. */
    private static Map<Long, List<Row>> groupSubModelRows(List<Row> modelRows) {
        Map<Long, List<Row>> grouped = new HashMap<>();
        for (Row row : modelRows) {
            Object classIndex = row.getField(2);
            if (classIndex != null) {
                grouped.computeIfAbsent(((Number) classIndex).longValue(), k -> new ArrayList<>())
                    .add(toSubModelRow(row));
            }
        }
        return grouped;
    }

    /**
     * Extracts the binary sub-model's own row: fields {@code 3 .. arity-2} of a OneVsRest model row.
     * The first three fields (field 1: meta JSON, field 2: class index) and the last field are skipped.
     */
    private static Row toSubModelRow(Row row) {
        Row subModelRow = new Row(row.getArity() - 4);
        for (int i = 0; i < subModelRow.getArity(); i++) {
            subModelRow.setField(i, row.getField(3 + i));
        }
        return subModelRow;
    }

    @Override
    public List<Mapper> getLoadedMapperList() {
        return this.mapperList;
    }

    /**
     * Declares the output columns: the prediction column (typed like the model schema's last column) and,
     * if requested, a STRING prediction-detail column.
     */
    @Override
    protected Tuple4<String[], String[], TypeInformation<?>[], String[]> prepareIoSchema(
            TableSchema modelSchema, TableSchema dataSchema, Params params) {
        String predictionCol = params.get(OneVsRestPredictParams.PREDICTION_COL);
        String[] reservedCols = params.get(OneVsRestPredictParams.RESERVED_COLS);
        TypeInformation<?> labelType = modelSchema.getFieldTypes()[modelSchema.getFieldNames().length - 1];
        if (params.contains(OneVsRestPredictParams.PREDICTION_DETAIL_COL)) {
            return Tuple4.of(
                dataSchema.getFieldNames(),
                new String[] {predictionCol, params.get(OneVsRestPredictParams.PREDICTION_DETAIL_COL)},
                new TypeInformation<?>[] {labelType, Types.STRING},
                reservedCols);
        }
        return Tuple4.of(
            dataSchema.getFieldNames(),
            new String[] {predictionCol},
            new TypeInformation<?>[] {labelType},
            reservedCols);
    }

    /**
     * Converts the numeric labels parsed from JSON to the declared label type (LONG, INT or FLOAT);
     * any other type name leaves the list untouched. Uses {@link Number} so it does not depend on the
     * JSON parser producing {@code Double} specifically.
     */
    private static void recoverLabelType(List<Object> labels, String labelTypeName) {
        if (labelTypeName.equals(FlinkTypeConverter.getTypeString((TypeInformation<?>) Types.LONG))) {
            labels.replaceAll(label -> ((Number) label).longValue());
        } else if (labelTypeName.equals(FlinkTypeConverter.getTypeString((TypeInformation<?>) Types.INT))) {
            labels.replaceAll(label -> ((Number) label).intValue());
        } else if (labelTypeName.equals(FlinkTypeConverter.getTypeString((TypeInformation<?>) Types.FLOAT))) {
            labels.replaceAll(label -> ((Number) label).floatValue());
        }
    }

    /**
     * Instantiates and loads the binary classifier's model mapper.
     *
     * @throws UnsupportedOperationException if {@code binClsClassName} is not LogisticRegression,
     *                                       LinearSvm or GbdtClassifier
     */
    private static RichModelMapper createModelPredictor(String binClsClassName, TableSchema modelSchema,
            TableSchema dataSchema, Params params, List<Row> subModelRows) {
        RichModelMapper predictor;
        if (binClsClassName.equals(LogisticRegression.class.getCanonicalName())
            || binClsClassName.equals(LinearSvm.class.getCanonicalName())) {
            predictor = new LinearModelMapper(modelSchema, dataSchema, params);
        } else if (binClsClassName.equals(GbdtClassifier.class.getCanonicalName())) {
            predictor = new GbdtModelMapper(modelSchema, dataSchema, params);
        } else {
            throw new UnsupportedOperationException("OneVsRest does not support classifier: " + binClsClassName);
        }
        predictor.loadModel(subModelRows);
        return predictor;
    }

    /**
     * Parses the model meta from field 1 of the first row where that field is non-null.
     *
     * @return the meta params, or {@code null} if no row carries meta
     */
    public Params extractMeta(List<Row> modelRows) {
        for (Row row : modelRows) {
            Object metaJson = row.getField(1);
            if (metaJson != null) {
                return Params.fromJson((String) metaJson);
            }
        }
        return null;
    }

    /** Parses the JSON array of labels and converts them to the declared label type. */
    public static List<Object> recoverLabel(String labelsJson, String labelTypeName) {
        @SuppressWarnings("unchecked")
        List<Object> labels = (List<Object>) JsonConverter.gson.fromJson(labelsJson, ArrayList.class);
        recoverLabelType(labels, labelTypeName);
        return labels;
    }
}
