package com.alibaba.alink.operator.batch.classification;

import com.alibaba.alink.common.MLEnvironmentFactory;
import com.alibaba.alink.common.annotation.FeatureColsVectorColMutexRule;
import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpecs;
import com.alibaba.alink.common.annotation.PortDesc;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.TypeCollections;
import com.alibaba.alink.common.exceptions.AkIllegalArgumentException;
import com.alibaba.alink.common.exceptions.AkIllegalDataException;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.model.ModelParamName;
import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil;
import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp;
import com.alibaba.alink.operator.batch.utils.WithTrainInfo;
import com.alibaba.alink.operator.common.dataproc.SortUtils;
import com.alibaba.alink.operator.common.linear.BaseLinearModelTrainBatchOp;
import com.alibaba.alink.operator.common.linear.LinearModelData;
import com.alibaba.alink.operator.common.linear.LinearModelDataConverter;
import com.alibaba.alink.operator.common.linear.LinearModelTrainInfo;
import com.alibaba.alink.operator.common.linear.SoftmaxModelInfo;
import com.alibaba.alink.operator.common.linear.SoftmaxObjFunc;
import com.alibaba.alink.operator.common.optim.Lbfgs;
import com.alibaba.alink.operator.common.optim.Optimizer;
import com.alibaba.alink.operator.common.optim.OptimizerFactory;
import com.alibaba.alink.operator.common.optim.Owlqn;
import com.alibaba.alink.operator.common.optim.objfunc.OptimObjFunc;
import com.alibaba.alink.operator.common.optim.subfunc.OptimVariable;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.params.classification.SoftmaxTrainParams;
import com.alibaba.alink.params.shared.colname.HasFeatureCols;
import com.alibaba.alink.params.shared.colname.HasVectorCol;
import com.alibaba.alink.params.shared.linear.HasL1;
import com.alibaba.alink.params.shared.linear.LinearTrainParams;
import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import org.apache.flink.api.common.functions.AbstractRichFunction;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.RichGroupReduceFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.operators.FlatMapOperator;
import org.apache.flink.api.java.operators.GroupReduceOperator;
import org.apache.flink.api.java.operators.MapOperator;
import org.apache.flink.api.java.operators.Operator;
import org.apache.flink.api.java.operators.SingleInputUdfOperator;
import org.apache.flink.api.java.tuple.Tuple1;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/**
 * Trains a softmax (multi-class linear) classification model.
 *
 * <p>Inputs (see {@link #linkFrom}): port 0 is the training data; optional port 1 is an existing
 * model whose coefficients are used to initialize the optimizer.
 *
 * <p>Outputs: the model table, whose schema comes from
 * {@code new LinearModelDataConverter(labelType).getModelSchema()}, and one side output table with
 * columns {@code (title LONG, info STRING)} holding the model meta as JSON (title 0) and the
 * optimizer loss curve as JSON (title 4).
 *
 * <p>NOTE(review): this file is decompiler (JADX) output. The reduce function that turns the
 * optional initial model into starting coefficients was not decompiled and always throws
 * {@link UnsupportedOperationException}, so linking a second input makes the job fail when it
 * runs. Restore the original body before relying on warm starts.
 */
@OutputPorts(values = {@PortSpec(PortType.MODEL), @PortSpec(value = PortType.DATA, desc = PortDesc.MODEL_INFO)})
@InputPorts(values = {@PortSpec(PortType.DATA), @PortSpec(value = PortType.MODEL, isOptional = true)})
@FeatureColsVectorColMutexRule
@ParamSelectColumnSpecs({@ParamSelectColumnSpec(name = "featureCols", allowedTypeCollections = {TypeCollections.NUMERIC_TYPES}), @ParamSelectColumnSpec(name = "vectorCol", allowedTypeCollections = {TypeCollections.VECTOR_TYPES}), @ParamSelectColumnSpec(name = "labelCol"), @ParamSelectColumnSpec(name = "weightCol", allowedTypeCollections = {TypeCollections.NUMERIC_TYPES})})
@NameCn("Softmax训练")
@NameEn("Softmax Training")
@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.classification.Softmax")
public final class SoftmaxTrainBatchOp extends BatchOperator<SoftmaxTrainBatchOp> implements SoftmaxTrainParams<SoftmaxTrainBatchOp>, WithTrainInfo<LinearModelTrainInfo, SoftmaxTrainBatchOp>, WithModelInfoBatchOp<SoftmaxModelInfo, SoftmaxTrainBatchOp, SoftmaxModelInfoBatchOp> {
    private static final long serialVersionUID = 2291776467437931890L;

    /**
     * Converts the optimized coefficient vector into model rows.
     *
     * <p>Two broadcast variables are read in {@link #open}:
     * <ul>
     *   <li>{@code "meta"}: the {@link Params} built by {@link CreateMeta};</li>
     *   <li>{@code "meanVar"}: a {@code DenseVector[]}. Element 0 is used as a per-feature shift
     *   and element 1 as a per-feature divisor. These are presumably the feature means and scales
     *   from {@code BaseLinearModelTrainBatchOp.getUtilInfo} — confirm there.</li>
     * </ul>
     *
     * <p>When standardization is enabled, the coefficients (learned on standardized features) are
     * rewritten so that they apply to the unstandardized features, and then the model is saved.
     */
    public static class BuildModelFromCoefs extends AbstractRichFunction implements MapPartitionFunction<Tuple2<DenseVector, double[]>, Row> {
        private static final long serialVersionUID = -5211654314835044657L;
        // Feature column names stored in the model; null when the features come from a vector column.
        private final String[] featureNames;
        // Model meta from the "meta" broadcast; mutated below (VECTOR_SIZE, LOSS_CURVE).
        private Params meta;
        // Number of classes, read from meta in open().
        private int labelSize;
        // Type of the label column, needed to (de)serialize label values in the model.
        private final TypeInformation<?> labelType;
        private final boolean standardization;
        // [0] = per-feature shift, [1] = per-feature divisor (see class comment).
        private DenseVector[] meanVar;

        /**
         * @param typeInformation type of the label column
         * @param strArr feature column names, or null when a vector column is used
         * @param z whether training features were standardized
         */
        private BuildModelFromCoefs(TypeInformation<?> typeInformation, String[] strArr, boolean z) {
            this.featureNames = strArr;
            this.labelType = typeInformation;
            this.standardization = z;
        }

        /** Loads the "meta" and "meanVar" broadcast variables and caches the number of classes. */
        public void open(Configuration configuration) throws Exception {
            this.meta = (Params) getRuntimeContext().getBroadcastVariable("meta").get(0);
            this.meanVar = (DenseVector[]) getRuntimeContext().getBroadcastVariable("meanVar").get(0);
            this.labelSize = ((Integer) this.meta.get(ModelParamName.NUM_CLASSES)).intValue();
        }

        /**
         * Undoes standardization on the coefficients (if enabled) and saves one model.
         *
         * <p>The coefficient vector is treated as {@code labelSize - 1} consecutive blocks of equal
         * length, one per class except one. With an intercept, entry 0 of each block is the
         * intercept and the remaining entries are feature weights.
         *
         * <p>Only the first received tuple is saved. The iterable is assumed to be non-empty,
         * since {@code get(0)} throws otherwise. This operator runs with parallelism 1 in
         * {@code linkFrom}.
         *
         * @param iterable (coefficients, loss curve) tuples from the optimizer
         * @param collector receives the serialized model rows
         */
        public void mapPartition(Iterable<Tuple2<DenseVector, double[]>> iterable, Collector<Row> collector) throws Exception {
            ArrayList arrayList = new ArrayList();
            boolean booleanValue = ((Boolean) this.meta.get(ModelParamName.HAS_INTERCEPT_ITEM)).booleanValue();
            for (Tuple2<DenseVector, double[]> tuple2 : iterable) {
                // Per-class block length, minus one slot for the intercept when present.
                this.meta.set((ParamInfo<ParamInfo<Integer>>) ModelParamName.VECTOR_SIZE, (ParamInfo<Integer>) Integer.valueOf((((DenseVector) tuple2.f0).size() / (this.labelSize - 1)) - (booleanValue ? 1 : 0)));
                this.meta.set((ParamInfo<ParamInfo<double[]>>) ModelParamName.LOSS_CURVE, (ParamInfo<double[]>) tuple2.f1);
                if (this.standardization) {
                    if (booleanValue) {
                        // Block length equals the length of the shift vector.
                        int size = this.meanVar[0].size();
                        for (int i = 0; i < this.labelSize - 1; i++) {
                            // For block i: w_j <- w_j / scale_j, and
                            // intercept <- intercept - sum_j (w_j * mean_j / scale_j), for j >= 1.
                            double d = 0.0d;
                            for (int i2 = 1; i2 < size; i2++) {
                                int i3 = (i * size) + i2;
                                d += (((DenseVector) tuple2.f0).get(i3) * this.meanVar[0].get(i2)) / this.meanVar[1].get(i2);
                                ((DenseVector) tuple2.f0).set(i3, ((DenseVector) tuple2.f0).get(i3) / this.meanVar[1].get(i2));
                            }
                            ((DenseVector) tuple2.f0).set(i * size, ((DenseVector) tuple2.f0).get(i * size) - d);
                        }
                    } else {
                        // No intercept: PreProcess only divided by the scale, so only rescale here;
                        // the modulo maps each coefficient to its feature within its class block.
                        for (int i4 = 0; i4 < ((DenseVector) tuple2.f0).size(); i4++) {
                            ((DenseVector) tuple2.f0).set(i4, ((DenseVector) tuple2.f0).get(i4) / this.meanVar[1].get(i4 % this.meanVar[1].size()));
                        }
                    }
                }
                arrayList.add(tuple2.f0);
            }
            new LinearModelDataConverter(this.labelType).save(new LinearModelData(this.labelType, this.meta, this.featureNames, (DenseVector) arrayList.get(0)), collector);
        }
    }

    /**
     * Builds the model meta {@link Params} from the {@code (label, id)} rows produced in
     * {@link #linkFrom}.
     *
     * <p>Label values are stored as strings, indexed by their id; ids are 0..n-1 as assigned in
     * {@code linkFrom}. The number of classes is the number of rows seen. This function must see
     * every label row, so {@code linkFrom} runs it with parallelism 1.
     */
    public static class CreateMeta implements MapPartitionFunction<Row, Params> {
        private static final long serialVersionUID = 8430372703655142394L;
        private final String modelName;
        private final boolean hasInterceptItem;
        // Null when feature columns are used instead of a vector column.
        private final String vectorColName;
        private final String labelColName;

        /**
         * @param str model name stored in the meta
         * @param z whether an intercept is fitted
         * @param str2 vector column name, or null
         * @param str3 label column name
         */
        private CreateMeta(String str, boolean z, String str2, String str3) {
            this.modelName = str;
            this.hasInterceptItem = z;
            this.vectorColName = str2;
            this.labelColName = str3;
        }

        /**
         * Emits a single {@link Params} holding the model name, label values, label column,
         * intercept flag, vector column and number of classes.
         *
         * @param iterable rows of {@code (label value, Long id)}
         * @param collector receives the meta params
         */
        public void mapPartition(Iterable<Row> iterable, Collector<Params> collector) throws Exception {
            ArrayList<Row> arrayList = new ArrayList();
            Iterator<Row> it = iterable.iterator();
            while (it.hasNext()) {
                arrayList.add(it.next());
            }
            // Place each label at the position given by its id.
            String[] strArr = new String[arrayList.size()];
            for (Row row : arrayList) {
                strArr[((Long) row.getField(1)).intValue()] = row.getField(0).toString();
            }
            Params params = new Params();
            params.set((ParamInfo<ParamInfo<String>>) ModelParamName.MODEL_NAME, (ParamInfo<String>) this.modelName);
            params.set((ParamInfo<ParamInfo<Object[]>>) ModelParamName.LABEL_VALUES, (ParamInfo<Object[]>) strArr);
            params.set((ParamInfo<ParamInfo<String>>) ModelParamName.LABEL_COL_NAME, (ParamInfo<String>) this.labelColName);
            params.set((ParamInfo<ParamInfo<Boolean>>) ModelParamName.HAS_INTERCEPT_ITEM, (ParamInfo<Boolean>) Boolean.valueOf(this.hasInterceptItem));
            params.set((ParamInfo<ParamInfo<String>>) ModelParamName.VECTOR_COL_NAME, (ParamInfo<String>) this.vectorColName);
            params.set((ParamInfo<ParamInfo<Integer>>) ModelParamName.NUM_CLASSES, (ParamInfo<Integer>) Integer.valueOf(arrayList.size()));
            collector.collect(params);
        }
    }

    /**
     * Prepares training samples for the optimizer. For each sample it:
     * <ul>
     *   <li>drops the sample unless its weight is greater than {@code Criteria.INVALID_GAIN};</li>
     *   <li>replaces the label with its numeric id;</li>
     *   <li>standardizes the feature vector in place when enabled.</li>
     * </ul>
     *
     * <p>Reads the broadcast variables {@code "labelIDs"} ({@code (label, Long id)} rows) and
     * {@code "meanVar"} ({@code [0]} per-feature shift, {@code [1]} per-feature divisor).
     */
    public static class PreProcess extends AbstractRichFunction implements MapPartitionFunction<Tuple3<Double, Object, Vector>, Tuple3<Double, Double, Vector>> {
        private static final long serialVersionUID = -5610968130256583178L;
        private final boolean hasInterceptItem;
        private final boolean standardization;
        // label.toString() -> label id (as a double, the form the optimizer consumes).
        private final HashMap<String, Double> labelMap = new HashMap<>();
        private DenseVector[] meanVar;

        /**
         * @param z whether an intercept is fitted
         * @param z2 whether features are standardized
         */
        public PreProcess(boolean z, boolean z2) {
            this.hasInterceptItem = z;
            this.standardization = z2;
        }

        /** Builds the label-to-id map and loads the standardization vectors from broadcasts. */
        public void open(Configuration configuration) throws Exception {
            for (Row row : getRuntimeContext().getBroadcastVariable("labelIDs")) {
                this.labelMap.put(row.getField(0).toString(), Double.valueOf(((Long) row.getField(1)).doubleValue()));
            }
            this.meanVar = (DenseVector[]) getRuntimeContext().getBroadcastVariable("meanVar").get(0);
        }

        /**
         * Emits {@code (weight, labelId, features)} for each retained sample.
         *
         * @param iterable {@code (weight, label, features)} samples
         * @param collector receives the preprocessed samples
         */
        public void mapPartition(Iterable<Tuple3<Double, Object, Vector>> iterable, Collector<Tuple3<Double, Double, Vector>> collector) throws Exception {
            for (Tuple3<Double, Object, Vector> tuple3 : iterable) {
                Double d = (Double) tuple3.f0;
                Vector vector = (Vector) tuple3.f2;
                // Null if the label is not in labelIDs; presumably impossible, because labelIDs is
                // derived from the same data — confirm.
                Double d2 = this.labelMap.get(tuple3.f1.toString());
                // NOTE(review): Criteria.INVALID_GAIN is a decision-tree constant. The decompiler
                // likely substituted it for a same-valued literal (presumably 0.0); confirm the
                // intended weight threshold.
                if (((Double) tuple3.f0).doubleValue() > Criteria.INVALID_GAIN) {
                    if (vector instanceof DenseVector) {
                        if (this.standardization) {
                            if (this.hasInterceptItem) {
                                // Center and scale every entry, including index 0 (presumably the
                                // constant intercept slot, whose shift/divisor should leave it
                                // unchanged — confirm in getUtilInfo).
                                for (int i = 0; i < vector.size(); i++) {
                                    vector.set(i, (vector.get(i) - this.meanVar[0].get(i)) / this.meanVar[1].get(i));
                                }
                            } else {
                                for (int i2 = 0; i2 < vector.size(); i2++) {
                                    vector.set(i2, vector.get(i2) / this.meanVar[1].get(i2));
                                }
                            }
                        }
                    } else if (this.standardization) {
                        // Sparse vectors: only stored values are scaled; no shift is subtracted, so
                        // zero entries stay zero.
                        // NOTE(review): BuildModelFromCoefs still shifts the intercept by
                        // meanVar[0] when an intercept is fitted. That is only consistent if
                        // meanVar[0] is all zeros for sparse input — confirm in getUtilInfo.
                        int[] indices = ((SparseVector) vector).getIndices();
                        double[] values = ((SparseVector) vector).getValues();
                        for (int i3 = 0; i3 < indices.length; i3++) {
                            values[i3] = values[i3] / this.meanVar[1].get(indices[i3]);
                        }
                    }
                    collector.collect(Tuple3.of(d, d2, vector));
                }
            }
        }
    }

    /** Creates the operator with default params. */
    public SoftmaxTrainBatchOp() {
        this(null);
    }

    /**
     * Creates the operator with the given params.
     *
     * @param params initial params (may be null)
     */
    public SoftmaxTrainBatchOp(Params params) {
        super(params);
    }

    /**
     * Builds the training dataflow. The dataflow:
     * <ul>
     *   <li>collects label values and feature statistics;</li>
     *   <li>assigns label ids and preprocesses samples;</li>
     *   <li>runs the optimizer;</li>
     *   <li>converts the coefficients into the model (main output);</li>
     *   <li>writes the meta and loss curve to side output 0.</li>
     * </ul>
     *
     * @param batchOperatorArr {@code [0]} training data; optional {@code [1]} an initial model.
     *     Any further inputs are ignored.
     * @return this operator
     * @throws AkIllegalArgumentException if both featureCols and vectorCol are set
     */
    @Override // com.alibaba.alink.operator.batch.BatchOperator
    public SoftmaxTrainBatchOp linkFrom(BatchOperator<?>... batchOperatorArr) {
        BatchOperator<?> batchOperator;
        BatchOperator<?> batchOperator2 = null;
        if (batchOperatorArr.length == 1) {
            batchOperator = checkAndGetFirst(batchOperatorArr);
        } else {
            batchOperator = batchOperatorArr[0];
            batchOperator2 = batchOperatorArr[1];
        }
        // Model name written into the model meta.
        String str = "softmax";
        final boolean booleanValue = getWithIntercept().booleanValue();
        boolean booleanValue2 = getStandardization().booleanValue();
        String[] featureCols = getFeatureCols();
        String labelCol = getLabelCol();
        String vectorCol = getVectorCol();
        if (getParams().contains(HasFeatureCols.FEATURE_COLS) && getParams().contains(HasVectorCol.VECTOR_COL)) {
            throw new AkIllegalArgumentException("featureCols and vectorCol cannot be set at the same time.");
        }
        TableSchema schema = batchOperator.getSchema();
        // Neither features nor a vector column given: use every numeric column except the label,
        // and record the choice in the params.
        if (null == featureCols && null == vectorCol) {
            featureCols = TableUtil.getNumericCols(batchOperator.getSchema(), new String[]{labelCol});
            getParams().set((ParamInfo<ParamInfo<String[]>>) SoftmaxTrainParams.FEATURE_COLS, (ParamInfo<String[]>) featureCols);
        }
        final TypeInformation<?> findColTypeWithAssertAndHint = TableUtil.findColTypeWithAssertAndHint(schema, labelCol);
        // Samples as (weight, label, features).
        DataSet<Tuple3<Double, Object, Vector>> transform = BaseLinearModelTrainBatchOp.transform(batchOperator, getParams(), false, booleanValue2);
        // Fields of the util info tuple:
        //   f0 = standardization vectors (meanVar);
        //   f1 = distinct label values;
        //   f2[0] = feature size;
        //   f2[1] = sample count.
        DataSet<Tuple3<DenseVector[], Object[], Integer[]>> utilInfo = BaseLinearModelTrainBatchOp.getUtilInfo(transform, booleanValue2, false);
        // Label ids: labels are sorted (SortUtils.RowComparator on field 0) and numbered 0..n-1.
        // Fails when there are more than 1000 distinct labels and they exceed half the sample count.
        FlatMapOperator flatMap = utilInfo.flatMap(new FlatMapFunction<Tuple3<DenseVector[], Object[], Integer[]>, Row>() { // from class: com.alibaba.alink.operator.batch.classification.SoftmaxTrainBatchOp.1
            private static final long serialVersionUID = 6773656778135257500L;

            public void flatMap(Tuple3<DenseVector[], Object[], Integer[]> tuple3, Collector<Row> collector) {
                ArrayList arrayList = new ArrayList();
                for (Object obj : (Object[]) tuple3.f1) {
                    arrayList.add(Row.of(new Object[]{obj}));
                }
                arrayList.sort(new SortUtils.RowComparator(0));
                if (arrayList.size() > ((Integer[]) tuple3.f2)[1].intValue() * 0.5d && arrayList.size() > 1000) {
                    throw new AkIllegalDataException("label num is : " + arrayList.size() + ", sample num is : " + ((Integer[]) tuple3.f2)[1] + ", please check your label column.");
                }
                // Emit (label, Long id) with id = position in sorted order.
                for (int i = 0; i < arrayList.size(); i++) {
                    Row row = new Row(2);
                    row.setField(0, ((Row) arrayList.get(i)).getField(0));
                    row.setField(1, Long.valueOf(i));
                    collector.collect(row);
                }
            }

            // Decompiler-emitted copy of the compiler bridge. NOTE(review): it has the same erasure
            // as the interface method and likely will not compile as-is; drop it when cleaning up.
            public /* bridge */ /* synthetic */ void flatMap(Object obj, Collector collector) throws Exception {
                flatMap((Tuple3<DenseVector[], Object[], Integer[]>) obj, (Collector<Row>) collector);
            }
        });
        // Standardization vectors (broadcast as "meanVar").
        MapOperator map = utilInfo.map(new MapFunction<Tuple3<DenseVector[], Object[], Integer[]>, DenseVector[]>() { // from class: com.alibaba.alink.operator.batch.classification.SoftmaxTrainBatchOp.2
            private static final long serialVersionUID = 2633660310293456071L;

            public DenseVector[] map(Tuple3<DenseVector[], Object[], Integer[]> tuple3) {
                return (DenseVector[]) tuple3.f0;
            }
        });
        // Feature size (broadcast as "featSize" and used by optimize() for vector input).
        MapOperator map2 = utilInfo.map(new MapFunction<Tuple3<DenseVector[], Object[], Integer[]>, Integer>() { // from class: com.alibaba.alink.operator.batch.classification.SoftmaxTrainBatchOp.3
            private static final long serialVersionUID = -8902907232968104891L;

            public Integer map(Tuple3<DenseVector[], Object[], Integer[]> tuple3) {
                return ((Integer[]) tuple3.f2)[0];
            }
        });
        // Optimize over the preprocessed samples. The initial coefficients come from the optional
        // second input and are null when no initial model is linked.
        DataSet<Tuple2<DenseVector, double[]>> optimize = optimize(getParams(), map2, transform.mapPartition(new PreProcess(booleanValue, booleanValue2)).withBroadcastSet(flatMap, "labelIDs").withBroadcastSet(map, "meanVar"), batchOperator2 == null ? null : batchOperator2.getDataSet().reduceGroup(new RichGroupReduceFunction<Row, DenseVector>() { // from class: com.alibaba.alink.operator.batch.classification.SoftmaxTrainBatchOp.4
            // NOTE(review): the decompiler (JADX) could not reconstruct this method; 411 bytecode
            // instructions were skipped, so it now always throws. The broadcast sets attached below
            // ("featSize" and "meanVar") suggest the original converted the initial model rows into
            // a coefficient vector in the training feature space. Restore it from the original
            // source.
            public void reduce(java.lang.Iterable<org.apache.flink.types.Row> r8, org.apache.flink.util.Collector<com.alibaba.alink.common.linalg.DenseVector> r9) {
                // Original body unavailable; see the NOTE above.
                throw new UnsupportedOperationException("Method not decompiled: com.alibaba.alink.operator.batch.classification.SoftmaxTrainBatchOp.AnonymousClass4.reduce(java.lang.Iterable, org.apache.flink.util.Collector):void");
            }
        }).withBroadcastSet(map2, "featSize").withBroadcastSet(map, "meanVar"), booleanValue, flatMap);
        // Convert the coefficients into model rows (parallelism 1) using the meta built from label ids.
        SingleInputUdfOperator withBroadcastSet = optimize.mapPartition(new BuildModelFromCoefs(findColTypeWithAssertAndHint, featureCols, booleanValue2)).withBroadcastSet(flatMap.mapPartition(new CreateMeta(str, booleanValue, vectorCol, labelCol)).setParallelism(1), "meta").setParallelism(1).withBroadcastSet(map, "meanVar");
        setOutput((DataSet<Row>) withBroadcastSet, new LinearModelDataConverter(findColTypeWithAssertAndHint).getModelSchema());
        // Side output 0: model meta + loss curve (field 1 of the optimizer result).
        setSideOutputTables(getSideTablesOfCoefficient(withBroadcastSet, optimize.project(new int[]{1})));
        return this;
    }

    /**
     * Runs the optimizer on the softmax objective.
     *
     * <p>Coefficient dimension is {@code (numClasses - 1) * featureDim}, where:
     * <ul>
     *   <li>with feature columns: {@code featureDim = featureCols.length + (intercept ? 1 : 0)};</li>
     *   <li>with a vector column: {@code featureDim} is the value of {@code dataSet}. Whether
     *   that value already includes the intercept slot is not visible here — confirm in
     *   getUtilInfo.</li>
     * </ul>
     *
     * <p>Optimizer selection:
     * <ul>
     *   <li>if {@code OPTIM_METHOD} is set, the method it names (via {@link OptimizerFactory});</li>
     *   <li>otherwise OWL-QN when L1 is positive;</li>
     *   <li>otherwise L-BFGS.</li>
     * </ul>
     *
     * @param params operator params (FEATURE_COLS, VECTOR_COL, OPTIM_METHOD and L_1 are read here)
     * @param dataSet single-element feature size from the data; used only when a vector column is set
     * @param dataSet2 preprocessed samples {@code (weight, labelId, features)}
     * @param dataSet3 initial coefficients, or null to use the optimizer's default initialization
     * @param z whether an intercept is fitted
     * @param dataSet4 {@code (label, id)} rows; their count is the number of classes
     * @return the optimizer result; field 1 is used as the loss curve by the caller
     */
    private DataSet<Tuple2<DenseVector, double[]>> optimize(Params params, DataSet<Integer> dataSet, DataSet<Tuple3<Double, Double, Vector>> dataSet2, DataSet<DenseVector> dataSet3, boolean z, DataSet<Row> dataSet4) {
        SingleInputUdfOperator withBroadcastSet;
        final double doubleValue = getL1().doubleValue();
        final double doubleValue2 = getL2().doubleValue();
        String[] strArr = (String[]) params.get(SoftmaxTrainParams.FEATURE_COLS);
        String str = (String) params.get(SoftmaxTrainParams.VECTOR_COL);
        // Number of classes = number of label-id rows.
        GroupReduceOperator reduceGroup = dataSet4.reduceGroup(new GroupReduceFunction<Row, Integer>() { // from class: com.alibaba.alink.operator.batch.classification.SoftmaxTrainBatchOp.5
            private static final long serialVersionUID = -8665284351311032858L;

            public void reduce(Iterable<Row> iterable, Collector<Integer> collector) {
                int i = 0;
                for (Row row : iterable) {
                    i++;
                }
                collector.collect(Integer.valueOf(i));
            }
        });
        // Coefficient dimension = (numClass - 1) * per-class size. The two branches differ only in
        // where the per-class size comes from.
        if (str == null || str.length() == 0) {
            ExecutionEnvironment executionEnvironment = MLEnvironmentFactory.get(getMLEnvironmentId()).getExecutionEnvironment();
            Integer[] numArr = new Integer[1];
            numArr[0] = Integer.valueOf(strArr.length + (z ? 1 : 0));
            withBroadcastSet = executionEnvironment.fromElements(numArr).map(new RichMapFunction<Integer, Integer>() { // from class: com.alibaba.alink.operator.batch.classification.SoftmaxTrainBatchOp.7
                private static final long serialVersionUID = 3133807849008897754L;
                private int k1;

                public void open(Configuration configuration) throws Exception {
                    super.open(configuration);
                    this.k1 = ((Integer) getRuntimeContext().getBroadcastVariable("numClass").get(0)).intValue() - 1;
                }

                public Integer map(Integer num) {
                    return Integer.valueOf(this.k1 * num.intValue());
                }
            }).withBroadcastSet(reduceGroup, "numClass");
        } else {
            withBroadcastSet = dataSet.map(new RichMapFunction<Integer, Integer>() { // from class: com.alibaba.alink.operator.batch.classification.SoftmaxTrainBatchOp.6
                private static final long serialVersionUID = 3041217732252202526L;
                private int k1;

                public void open(Configuration configuration) throws Exception {
                    super.open(configuration);
                    this.k1 = ((Integer) getRuntimeContext().getBroadcastVariable("numClass").get(0)).intValue() - 1;
                }

                public Integer map(Integer num) {
                    return Integer.valueOf(this.k1 * num.intValue());
                }
            }).withBroadcastSet(reduceGroup, "numClass");
        }
        // Objective function configured with L1, L2 and the class count. The count dataset has a
        // single element; the loop keeps the last value seen.
        GroupReduceOperator reduceGroup2 = reduceGroup.reduceGroup(new GroupReduceFunction<Integer, OptimObjFunc>() { // from class: com.alibaba.alink.operator.batch.classification.SoftmaxTrainBatchOp.8
            private static final long serialVersionUID = -4647154716237314079L;

            public void reduce(Iterable<Integer> iterable, Collector<OptimObjFunc> collector) {
                int i = 0;
                Iterator<Integer> it = iterable.iterator();
                while (it.hasNext()) {
                    i = it.next().intValue();
                }
                collector.collect(new SoftmaxObjFunc(new Params().set((ParamInfo<ParamInfo<Double>>) SoftmaxTrainParams.L_1, (ParamInfo<Double>) Double.valueOf(doubleValue)).set((ParamInfo<ParamInfo<Double>>) SoftmaxTrainParams.L_2, (ParamInfo<Double>) Double.valueOf(doubleValue2)).set((ParamInfo<ParamInfo<Integer>>) ModelParamName.NUM_CLASSES, (ParamInfo<Integer>) Integer.valueOf(i))));
            }
        });
        // Explicit method if configured; otherwise OWL-QN for L1 > threshold, else L-BFGS.
        // NOTE(review): Criteria.INVALID_GAIN looks like a decompiler-substituted literal
        // (presumably 0.0) — confirm.
        Optimizer create = params.contains(LinearTrainParams.OPTIM_METHOD) ? OptimizerFactory.create(reduceGroup2, dataSet2, withBroadcastSet, params, (LinearTrainParams.OptimMethod) params.get(LinearTrainParams.OPTIM_METHOD)) : ((Double) params.get(HasL1.L_1)).doubleValue() > Criteria.INVALID_GAIN ? new Owlqn(reduceGroup2, dataSet2, withBroadcastSet, params) : new Lbfgs(reduceGroup2, dataSet2, withBroadcastSet, params);
        create.initCoefWith(dataSet3);
        return create.optimize();
    }

    /**
     * Builds the train-info side table.
     *
     * <p>The table has schema {@code (title LONG, info STRING)} and two rows:
     * <ul>
     *   <li>{@code (0, meta JSON)}, where meta is taken from the saved model;</li>
     *   <li>{@code (4, loss curve JSON)}.</li>
     * </ul>
     *
     * @param dataSet model rows
     * @param dataSet2 single-element loss curve (broadcast as "cinfo")
     * @return a one-element array holding the side table
     */
    private Table[] getSideTablesOfCoefficient(DataSet<Row> dataSet, DataSet<Tuple1<double[]>> dataSet2) {
        // Reassemble all model rows into a LinearModelData (parallelism 1 so one task sees every row).
        Operator parallelism = dataSet.mapPartition(new MapPartitionFunction<Row, LinearModelData>() { // from class: com.alibaba.alink.operator.batch.classification.SoftmaxTrainBatchOp.9
            private static final long serialVersionUID = 2063366042018382802L;

            public void mapPartition(Iterable<Row> iterable, Collector<LinearModelData> collector) {
                ArrayList arrayList = new ArrayList();
                Iterator<Row> it = iterable.iterator();
                while (it.hasNext()) {
                    arrayList.add(it.next());
                }
                collector.collect(new LinearModelDataConverter().load((List<Row>) arrayList));
            }
        }).setParallelism(1);
        // Emit the meta and loss-curve rows. The OptimVariable.model broadcast attached here is not
        // read by this function.
        return new Table[]{DataSetConversionUtil.toTable(getMLEnvironmentId(), (DataSet<Row>) parallelism.mapPartition(new RichMapPartitionFunction<LinearModelData, Row>() { // from class: com.alibaba.alink.operator.batch.classification.SoftmaxTrainBatchOp.10
            private static final long serialVersionUID = 8785824618242390100L;

            public void mapPartition(Iterable<LinearModelData> iterable, Collector<Row> collector) {
                LinearModelData next = iterable.iterator().next();
                double[] dArr = (double[]) ((Tuple1) getRuntimeContext().getBroadcastVariable("cinfo").get(0)).f0;
                collector.collect(Row.of(new Object[]{0L, JsonConverter.toJson(next.getMetaInfo())}));
                collector.collect(Row.of(new Object[]{4L, JsonConverter.toJson(dArr)}));
            }
        }).setParallelism(1).withBroadcastSet(parallelism, OptimVariable.model).withBroadcastSet(dataSet2, "cinfo"), new TableSchema(new String[]{"title", "info"}, new TypeInformation[]{Types.LONG, Types.STRING}))};
    }

    /** Returns an operator that summarizes the trained model, linked to this operator. */
    @Override // com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp
    public SoftmaxModelInfoBatchOp getModelInfoBatchOp() {
        return new SoftmaxModelInfoBatchOp(getParams()).linkFrom(this);
    }

    /**
     * Wraps train-info rows as {@link LinearModelTrainInfo}.
     *
     * @param list rows, presumably those of {@link #getSideOutputTrainInfo()} — confirm with callers
     * @return the train info view
     */
    @Override // com.alibaba.alink.operator.batch.utils.WithTrainInfo
    public LinearModelTrainInfo createTrainInfo(List<Row> list) {
        return new LinearModelTrainInfo(list);
    }

    /** Returns side output 0, the meta/loss-curve table built by getSideTablesOfCoefficient. */
    @Override // com.alibaba.alink.operator.batch.utils.WithTrainInfo
    public BatchOperator<?> getSideOutputTrainInfo() {
        return getSideOutput(0);
    }

    // Decompiler-emitted copy of a compiler bridge method. NOTE(review): it has the same erasure as
    // linkFrom(BatchOperator<?>...) and likely will not compile as-is; drop it when cleaning up.
    @Override // com.alibaba.alink.operator.batch.BatchOperator
    public /* bridge */ /* synthetic */ SoftmaxTrainBatchOp linkFrom(BatchOperator[] batchOperatorArr) {
        return linkFrom((BatchOperator<?>[]) batchOperatorArr);
    }

    // Decompiler-emitted copy of a compiler bridge method. NOTE(review): it has the same erasure as
    // createTrainInfo(List<Row>) and likely will not compile as-is; drop it when cleaning up.
    @Override // com.alibaba.alink.operator.batch.utils.WithTrainInfo
    public /* bridge */ /* synthetic */ LinearModelTrainInfo createTrainInfo(List list) {
        return createTrainInfo((List<Row>) list);
    }
}
