package com.alibaba.alink.operator.common.tree;

import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.source.TableSourceBatchOp;
import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil;
import com.alibaba.alink.operator.batch.utils.ExtractModelInfoBatchOp;
import com.alibaba.alink.operator.common.tree.TreeModelDataConverter;
import com.alibaba.alink.operator.common.tree.TreeModelInfoBatchOp;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.flink.api.common.functions.BroadcastVariableInitializer;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.RichGroupReduceFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/**
 * Base class of batch operators that extract model information from tree model data.
 *
 * <p>{@link #processModel()} computes the feature importance of the input model with
 * {@link TreeModelDataConverter.FeatureImportanceReducer}, serializes it to a JSON object
 * (first importance column, as a string, mapped to the second importance column, as a double) and
 * stores that JSON in the model meta under {@link TreeModelInfo#FEATURE_IMPORTANCE}. The returned
 * operator holds the re-saved model rows with the original model column names and types.
 *
 * @param <S> see {@link ExtractModelInfoBatchOp}
 * @param <T> the concrete operator type (self type for fluent setters)
 */
public abstract class TreeModelInfoBatchOp<S, T extends TreeModelInfoBatchOp<S, T>> extends ExtractModelInfoBatchOp<S, T> {
    private static final long serialVersionUID = 1735133462550836751L;

    /** Broadcast-variable name under which the feature-importance JSON reaches the model reducer. */
    private static final String IMPORTANCE_JSON_BROADCAST = "importanceJson";

    /** Creates the operator with no explicit parameters. */
    public TreeModelInfoBatchOp() {
        this(null);
    }

    /**
     * Creates the operator with the given parameters.
     *
     * @param params operator parameters; {@code null} is passed through to the super constructor
     */
    public TreeModelInfoBatchOp(Params params) {
        super(params);
    }

    /**
     * Computes the feature importance of the input model and attaches it to the model meta.
     *
     * @return an operator holding the model rows with {@link TreeModelInfo#FEATURE_IMPORTANCE} set
     */
    @Override
    protected BatchOperator<?> processModel() {
        DataSet<Row> importanceRows = getDataSet()
            .reduceGroup(new TreeModelDataConverter.FeatureImportanceReducer());

        String[] importanceColNames = new String[] {
            getParams().get(TreeModelDataConverter.IMPORTANCE_FIRST_COL),
            getParams().get(TreeModelDataConverter.IMPORTANCE_SECOND_COL)
        };
        TypeInformation<?>[] importanceColTypes = new TypeInformation<?>[] {Types.STRING, Types.DOUBLE};

        BatchOperator<?> importance = new TableSourceBatchOp(
            DataSetConversionUtil.toTable(
                getMLEnvironmentId(), importanceRows, importanceColNames, importanceColTypes))
            .setMLEnvironmentId(getMLEnvironmentId());

        return combinedTreeModelFeatureImportance(this, importance);
    }

    /**
     * Rewrites the model rows of {@code model} so that their meta carries the feature importance
     * given by {@code featureImportance}.
     *
     * @param model             operator holding tree model rows
     * @param featureImportance operator whose rows are (key, numeric importance) pairs
     * @return an operator with the re-saved model rows, using the model's column names, column types
     *     and ML environment id
     */
    private static BatchOperator<?> combinedTreeModelFeatureImportance(
            BatchOperator<?> model, BatchOperator<?> featureImportance) {
        DataSet<String> importanceJson = featureImportance.getDataSet()
            .reduceGroup(new ImportanceToJsonReducer());

        DataSet<Row> combined = model.getDataSet()
            .reduceGroup(new FeatureImportanceAttacher())
            .withBroadcastSet(importanceJson, IMPORTANCE_JSON_BROADCAST);

        // The table is created in the model's ML environment, so the operator must carry the same
        // environment id (as processModel does for the importance operator).
        return new TableSourceBatchOp(
            DataSetConversionUtil.toTable(
                model.getMLEnvironmentId(), combined, model.getColNames(), model.getColTypes()))
            .setMLEnvironmentId(model.getMLEnvironmentId());
    }

    /** Folds all (key, importance) rows into a single JSON object keyed by {@code String.valueOf(key)}. */
    private static final class ImportanceToJsonReducer implements GroupReduceFunction<Row, String> {
        private static final long serialVersionUID = -1576541700351312745L;

        @Override
        public void reduce(Iterable<Row> rows, Collector<String> out) throws Exception {
            Map<String, Double> importance = new HashMap<>();
            for (Row row : rows) {
                importance.put(String.valueOf(row.getField(0)), ((Number) row.getField(1)).doubleValue());
            }
            out.collect(JsonConverter.toJson(importance));
        }
    }

    /**
     * Loads all model rows, sets {@link TreeModelInfo#FEATURE_IMPORTANCE} in the model meta to the
     * broadcast importance JSON, and re-saves the model.
     */
    private static final class FeatureImportanceAttacher extends RichGroupReduceFunction<Row, Row> {
        private static final long serialVersionUID = -1576541700351312745L;

        /** Importance JSON read from the broadcast set in {@link #open(Configuration)}. */
        private transient String featureImportanceJson;

        @Override
        public void open(Configuration parameters) throws Exception {
            super.open(parameters);
            featureImportanceJson = getRuntimeContext().getBroadcastVariableWithInitializer(
                IMPORTANCE_JSON_BROADCAST,
                new BroadcastVariableInitializer<String, String>() {
                    @Override
                    public String initializeBroadcastVariable(Iterable<String> data) {
                        Iterator<String> it = data.iterator();
                        // An empty broadcast set would make a bare next() throw
                        // NoSuchElementException; fall back to the JSON of an empty importance map,
                        // i.e. what ImportanceToJsonReducer would produce for zero rows.
                        return it.hasNext() ? it.next() : JsonConverter.toJson(new HashMap<String, Double>());
                    }
                });
        }

        @Override
        public void reduce(Iterable<Row> modelRows, Collector<Row> out) throws Exception {
            // The converter needs the complete model, so all rows are buffered before loading.
            List<Row> rows = new ArrayList<>();
            for (Row row : modelRows) {
                rows.add(row);
            }
            TreeModelDataConverter converter = new TreeModelDataConverter().load(rows);
            converter.meta.set(TreeModelInfo.FEATURE_IMPORTANCE, featureImportanceJson);
            converter.save(converter, out);
        }
    }
}
