package com.alibaba.alink.operator.common.tensorflow;

import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil;
import com.alibaba.alink.operator.common.classification.tensorflow.TFTableModelClassificationModelData;
import com.alibaba.alink.operator.common.classification.tensorflow.TFTableModelClassificationModelDataConverter;
import com.alibaba.alink.operator.common.regression.tensorflow.TFTableModelRegressionModelData;
import com.alibaba.alink.operator.common.regression.tensorflow.TFTableModelRegressionModelDataConverter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.TreeSet;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.operators.SingleInputUdfOperator;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/* loaded from: input_file:com/alibaba/alink/operator/common/tensorflow/CommonUtils.class */
public class CommonUtils {
    /** Broadcast-variable name under which the TensorFlow model rows are distributed. */
    public static final String TF_MODEL_BC_NAME = "TF_MODEL";
    /** Broadcast-variable name for the preprocessing pipeline model rows; readers treat it as optional. */
    public static final String PREPROCESS_PIPELINE_MODEL_BC_NAME = "PREPROCESS_PIPELINE_MODEL";
    /**
     * Broadcast-variable name for the sorted label list. The broadcast holds a single element, which is
     * itself the {@code List<Object>} of labels.
     */
    public static final String SORTED_LABELS_BC_NAME = "SORTED_LABELS";

    /* loaded from: input_file:com/alibaba/alink/operator/common/tensorflow/CommonUtils$ConstructModelFlatMapFunction.class */
    public static class ConstructModelFlatMapFunction extends RichFlatMapFunction<Row, Row> {
        private final Params params;
        private final String[] featureCols;
        private final String tfOutputSignatureDef;
        private final TypeInformation<?> tfOutputSignatureType;
        private final String preprocessPipelineModelSchemaStr;
        private final boolean isOutputLogits;
        private List<Row> preprocessPipelineModelRows;
        private List<Row> tfModelRows;
        private List<Object> sortedLabels;

        public ConstructModelFlatMapFunction(Params params, String[] strArr, String str, TypeInformation<?> typeInformation, String str2) {
            this(params, strArr, str, typeInformation, str2, false);
        }

        public ConstructModelFlatMapFunction(Params params, String[] strArr, String str, TypeInformation<?> typeInformation, String str2, boolean z) {
            this.params = params;
            this.featureCols = strArr;
            this.tfOutputSignatureDef = str;
            this.tfOutputSignatureType = typeInformation;
            this.preprocessPipelineModelSchemaStr = str2;
            this.isOutputLogits = z;
        }

        public void open(Configuration configuration) throws Exception {
            super.open(configuration);
            RuntimeContext runtimeContext = getRuntimeContext();
            this.sortedLabels = runtimeContext.hasBroadcastVariable(CommonUtils.SORTED_LABELS_BC_NAME) ? (List) runtimeContext.getBroadcastVariable(CommonUtils.SORTED_LABELS_BC_NAME).get(0) : Collections.emptyList();
            this.tfModelRows = runtimeContext.getBroadcastVariable(CommonUtils.TF_MODEL_BC_NAME);
            this.preprocessPipelineModelRows = runtimeContext.hasBroadcastVariable(CommonUtils.PREPROCESS_PIPELINE_MODEL_BC_NAME) ? runtimeContext.getBroadcastVariable(CommonUtils.PREPROCESS_PIPELINE_MODEL_BC_NAME) : Collections.emptyList();
        }

        public void flatMap(Row row, Collector<Row> collector) throws Exception {
            if (this.sortedLabels.size() > 0) {
                new TFTableModelClassificationModelDataConverter().save(new TFTableModelClassificationModelData(this.params, this.featureCols, this.tfModelRows, this.tfOutputSignatureDef, this.tfOutputSignatureType, this.preprocessPipelineModelSchemaStr, this.preprocessPipelineModelRows, this.sortedLabels, this.isOutputLogits), collector);
            } else {
                new TFTableModelRegressionModelDataConverter().save(new TFTableModelRegressionModelData(this.params, this.featureCols, this.tfModelRows, this.tfOutputSignatureDef, this.tfOutputSignatureType, this.preprocessPipelineModelSchemaStr, this.preprocessPipelineModelRows), collector);
            }
        }

        public /* bridge */ /* synthetic */ void flatMap(Object obj, Collector collector) throws Exception {
            flatMap((Row) obj, (Collector<Row>) collector);
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/common/tensorflow/CommonUtils$ConstructModelMapPartitionFunction.class */
    public static class ConstructModelMapPartitionFunction extends RichMapPartitionFunction<Row, Row> {
        private final Params params;
        private final String[] featureCols;
        private final String tfOutputSignatureDef;
        private final TypeInformation<?> tfOutputSignatureType;
        private final String preprocessPipelineModelSchemaStr;
        private final boolean isOutputLogits;
        private List<Row> preprocessPipelineModelRows;
        private List<Object> sortedLabels;

        public ConstructModelMapPartitionFunction(Params params, String[] strArr, String str, TypeInformation<?> typeInformation, String str2) {
            this(params, strArr, str, typeInformation, str2, false);
        }

        public ConstructModelMapPartitionFunction(Params params, String[] strArr, String str, TypeInformation<?> typeInformation, String str2, boolean z) {
            this.params = params;
            this.featureCols = strArr;
            this.tfOutputSignatureDef = str;
            this.tfOutputSignatureType = typeInformation;
            this.preprocessPipelineModelSchemaStr = str2;
            this.isOutputLogits = z;
        }

        public void open(Configuration configuration) throws Exception {
            super.open(configuration);
            RuntimeContext runtimeContext = getRuntimeContext();
            this.sortedLabels = runtimeContext.hasBroadcastVariable(CommonUtils.SORTED_LABELS_BC_NAME) ? (List) runtimeContext.getBroadcastVariable(CommonUtils.SORTED_LABELS_BC_NAME).get(0) : Collections.emptyList();
            this.preprocessPipelineModelRows = runtimeContext.hasBroadcastVariable(CommonUtils.PREPROCESS_PIPELINE_MODEL_BC_NAME) ? runtimeContext.getBroadcastVariable(CommonUtils.PREPROCESS_PIPELINE_MODEL_BC_NAME) : Collections.emptyList();
        }

        public void mapPartition(Iterable<Row> iterable, Collector<Row> collector) throws Exception {
            if (getRuntimeContext().getIndexOfThisSubtask() != 0) {
                return;
            }
            if (this.sortedLabels.size() > 0) {
                new TFTableModelClassificationModelDataConverter().save(new TFTableModelClassificationModelData(this.params, this.featureCols, iterable, this.tfOutputSignatureDef, this.tfOutputSignatureType, this.preprocessPipelineModelSchemaStr, this.preprocessPipelineModelRows, this.sortedLabels, this.isOutputLogits), collector);
            } else {
                new TFTableModelRegressionModelDataConverter().save(new TFTableModelRegressionModelData(this.params, this.featureCols, iterable, this.tfOutputSignatureDef, this.tfOutputSignatureType, this.preprocessPipelineModelSchemaStr, this.preprocessPipelineModelRows), collector);
            }
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/common/tensorflow/CommonUtils$CountLabelsMapFunction.class */
    public static class CountLabelsMapFunction implements MapFunction<List<Object>, Row> {
        public Row map(List<Object> list) throws Exception {
            return Row.of(new Object[]{Integer.valueOf(list.size())});
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:com/alibaba/alink/operator/common/tensorflow/CommonUtils$LabelToIndexMapper.class */
    public static class LabelToIndexMapper extends RichMapFunction<Row, Row> {
        private final int labelColId;
        private Map<Object, Float> labelIndexMap;

        public LabelToIndexMapper(int i) {
            this.labelColId = i;
        }

        public void open(Configuration configuration) throws Exception {
            super.open(configuration);
            List list = (List) getRuntimeContext().getBroadcastVariable(CommonUtils.SORTED_LABELS_BC_NAME).get(0);
            this.labelIndexMap = new HashMap();
            for (int i = 0; i < list.size(); i++) {
                this.labelIndexMap.put(list.get(i), Float.valueOf((float) (i * 1.0d)));
            }
        }

        public Row map(Row row) throws Exception {
            Row copy = Row.copy(row);
            copy.setField(this.labelColId, this.labelIndexMap.get(copy.getField(this.labelColId)));
            return copy;
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/common/tensorflow/CommonUtils$SortLabelsReduceGroupFunction.class */
    public static class SortLabelsReduceGroupFunction implements GroupReduceFunction<Row, List<Object>> {
        public void reduce(Iterable<Row> iterable, Collector<List<Object>> collector) throws Exception {
            TreeSet treeSet = new TreeSet();
            Iterator<Row> it = iterable.iterator();
            while (it.hasNext()) {
                treeSet.add(it.next().getField(0));
            }
            collector.collect(new ArrayList(treeSet));
        }
    }

    /**
     * Returns an operator whose label column holds the label's index in a broadcast sorted-label list
     * instead of the label itself. All other columns keep their values.
     *
     * <p>The label column's type in the output schema becomes {@link Types#FLOAT}. The result is built in
     * the same ML environment as {@code batchOperator}.
     *
     * @param batchOperator input operator
     * @param str           name of the label column; looked up with
     *                      {@link TableUtil#findColIndexWithAssertAndHint}
     * @param dataSet       data set whose single element is the sorted label list. It is presumably
     *                      produced by {@link SortLabelsReduceGroupFunction}; confirm with callers.
     * @return the re-typed operator
     */
    public static BatchOperator<?> mapLabelToIndex(BatchOperator<?> batchOperator, String str,
                                                   DataSet<List<Object>> dataSet) {
        TableSchema schema = batchOperator.getSchema();
        int labelColIdx = TableUtil.findColIndexWithAssertAndHint(schema, str);

        DataSet<Row> indexedRows = batchOperator.getDataSet()
            .map(new LabelToIndexMapper(labelColIdx))
            .withBroadcastSet(dataSet, SORTED_LABELS_BC_NAME);

        String[] fieldNames = schema.getFieldNames();
        TypeInformation<?>[] fieldTypes = schema.getFieldTypes();
        fieldTypes[labelColIdx] = Types.FLOAT;

        Long envId = batchOperator.getMLEnvironmentId();
        return (BatchOperator<?>) BatchOperator
            .fromTable(DataSetConversionUtil.toTable(envId, indexedRows, new TableSchema(fieldNames, fieldTypes)))
            .setMLEnvironmentId(envId);
    }
}
