package com.alibaba.alink.operator.common.fm;

import com.alibaba.alink.common.annotation.Internal;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.model.ModelParamName;
import com.alibaba.alink.operator.common.fm.BaseFmTrainBatchOp;
import com.alibaba.alink.operator.common.fm.FmTrainBatchOp;
import com.alibaba.alink.operator.common.optim.FmOptimizer;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.params.recommendation.FmTrainParams;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/**
 * Batch training operator for factorization machines (FM).
 *
 * <p>The constructor stores the requested {@link BaseFmTrainBatchOp.Task} in the params under
 * {@link ModelParamName#TASK}. {@link #optimize} builds initial factors with
 * {@link BaseFmTrainBatchOp.FmDataFormat} and delegates training to {@link FmOptimizer}.
 * {@link #transformModel} turns the trained result into model rows via
 * {@link FmModelDataConverter}.
 *
 * @param <T> the concrete operator type (self type)
 */
@Internal
public class FmTrainBatchOp<T extends FmTrainBatchOp<T>> extends BaseFmTrainBatchOp<T> {
    private static final long serialVersionUID = -3985394692845121356L;

    /**
     * Flat-map function that packs the trained FM model and the training params into an
     * {@code FmModelData} and emits it as rows through {@link FmModelDataConverter}.
     *
     * <p>Requires two broadcast variables: {@link BaseFmTrainBatchOp#LABEL_VALUES} (whose first
     * element is read as an {@code Object[]}) and {@link BaseFmTrainBatchOp#VEC_SIZE} (whose
     * first element is read as an {@code Integer}).
     */
    public static class GenerateModelRows
            extends RichFlatMapFunction<Tuple2<BaseFmTrainBatchOp.FmDataFormat, double[]>, Row> {
        private static final long serialVersionUID = -380930181466110905L;
        private final Params params;
        private final int[] dim;
        private final TypeInformation<?> labelType;
        /** Filled from the LABEL_VALUES broadcast variable in {@link #open}. */
        private Object[] labelValues;
        private final boolean isRegProc;
        /** Filled from the VEC_SIZE broadcast variable in {@link #open}. */
        private int vecSize;

        /**
         * @param params    training params; column names, regularization coefficients
         *                  (LAMBDA_0/1/2), label column and task are read from here
         * @param dim       copied as-is into the model data's {@code dim}
         * @param labelType label type handed to {@link FmModelDataConverter}
         * @param isRegProc when true, the broadcast label values are replaced by a single
         *                  placeholder value (presumably the regression case)
         */
        public GenerateModelRows(Params params, int[] dim, TypeInformation<?> labelType, boolean isRegProc) {
            this.params = params;
            this.labelType = labelType;
            this.dim = dim;
            this.isRegProc = isRegProc;
        }

        @Override
        public void open(Configuration parameters) throws Exception {
            super.open(parameters);
            labelValues = (Object[]) getRuntimeContext()
                    .getBroadcastVariable(BaseFmTrainBatchOp.LABEL_VALUES).get(0);
            vecSize = (Integer) getRuntimeContext()
                    .getBroadcastVariable(BaseFmTrainBatchOp.VEC_SIZE).get(0);
        }

        // The erased flatMap(Object, Collector) bridge is generated by javac; declaring it
        // explicitly (as decompiled output does) is a name-clash compile error.
        @Override
        public void flatMap(Tuple2<BaseFmTrainBatchOp.FmDataFormat, double[]> value, Collector<Row> out)
                throws Exception {
            FmModelData modelData = new FmModelData();
            modelData.fmModel = value.f0;
            modelData.vectorColName = (String) params.get(FmTrainParams.VECTOR_COL);
            modelData.featureColNames = (String[]) params.get(FmTrainParams.FEATURE_COLS);
            modelData.dim = dim;
            modelData.regular = new double[] {
                (Double) params.get(FmTrainParams.LAMBDA_0),
                (Double) params.get(FmTrainParams.LAMBDA_1),
                (Double) params.get(FmTrainParams.LAMBDA_2)
            };
            modelData.labelColName = (String) params.get(FmTrainParams.LABEL_COL);
            modelData.task = (BaseFmTrainBatchOp.Task) params.get(ModelParamName.TASK);
            // NOTE(review): Criteria.INVALID_GAIN is probably a decompiler substitution for a
            // plain double literal used as a placeholder label; confirm the intended value
            // before replacing it with a literal.
            modelData.labelValues = isRegProc
                    ? new Object[] {Double.valueOf(Criteria.INVALID_GAIN)}
                    : labelValues;
            modelData.vectorSize = vecSize;
            new FmModelDataConverter(labelType).save2(modelData, out);
        }
    }

    /**
     * Maps each incoming integer to an initial {@link BaseFmTrainBatchOp.FmDataFormat}.
     * This is a static class so the serialized function does not capture the enclosing
     * operator instance (an anonymous class would).
     */
    private static final class InitFactorsMapper
            extends RichMapFunction<Integer, BaseFmTrainBatchOp.FmDataFormat> {
        private static final long serialVersionUID = 76796953320215874L;
        private final int[] dim;
        private final double initStdev;

        InitFactorsMapper(int[] dim, double initStdev) {
            this.dim = dim;
            this.initStdev = initStdev;
        }

        @Override
        public BaseFmTrainBatchOp.FmDataFormat map(Integer vecSize) {
            return new BaseFmTrainBatchOp.FmDataFormat(vecSize.intValue(), dim, initStdev);
        }
    }

    /**
     * Creates the operator and records {@code task} in {@code params} (note: the given
     * {@code params} object is mutated) under {@link ModelParamName#TASK}.
     *
     * @param params operator params
     * @param task   the FM task to train for
     */
    public FmTrainBatchOp(Params params, BaseFmTrainBatchOp.Task task) {
        super(params.set(ModelParamName.TASK, task));
    }

    /**
     * Runs FM optimization.
     *
     * @param trainData training samples; NOTE(review): the meaning of the two Double fields
     *                  (presumably weight and label) is defined by BaseFmTrainBatchOp — confirm
     * @param vecSize   integer data set whose elements seed the initial factors; presumably
     *                  the feature-vector size (cf. VEC_SIZE in transformModel) — confirm
     * @param params    training params; INIT_STDEV is read here and the whole object is
     *                  passed to {@link FmOptimizer}
     * @param dim       passed to the {@link BaseFmTrainBatchOp.FmDataFormat} constructor
     * @return the result of {@link FmOptimizer#optimize()}
     */
    @Override
    protected DataSet<Tuple2<BaseFmTrainBatchOp.FmDataFormat, double[]>> optimize(
            DataSet<Tuple3<Double, Double, Vector>> trainData,
            DataSet<Integer> vecSize,
            Params params,
            final int[] dim) {
        final double initStdev = (Double) params.get(FmTrainParams.INIT_STDEV);
        DataSet<BaseFmTrainBatchOp.FmDataFormat> initFactors =
                vecSize.map(new InitFactorsMapper(dim, initStdev));
        FmOptimizer optimizer = new FmOptimizer(trainData, params);
        optimizer.setWithInitFactors(initFactors);
        return optimizer.optimize();
    }

    /**
     * Converts the trained model into rows using {@link GenerateModelRows}, broadcasting
     * {@code labelValues} as LABEL_VALUES and {@code vecSize} as VEC_SIZE.
     *
     * @param model       trained model produced by {@link #optimize}
     * @param labelValues broadcast as {@link BaseFmTrainBatchOp#LABEL_VALUES}
     * @param vecSize     broadcast as {@link BaseFmTrainBatchOp#VEC_SIZE}
     * @param params      training params
     * @param dim         stored into the model data
     * @param isRegProc   see {@link GenerateModelRows#GenerateModelRows}
     * @param labelType   label type for the model converter
     * @return the model rows
     */
    @Override
    protected DataSet<Row> transformModel(
            DataSet<Tuple2<BaseFmTrainBatchOp.FmDataFormat, double[]>> model,
            DataSet<Object[]> labelValues,
            DataSet<Integer> vecSize,
            Params params,
            int[] dim,
            boolean isRegProc,
            TypeInformation<?> labelType) {
        return model
                .flatMap(new GenerateModelRows(params, dim, labelType, isRegProc))
                .withBroadcastSet(labelValues, BaseFmTrainBatchOp.LABEL_VALUES)
                .withBroadcastSet(vecSize, BaseFmTrainBatchOp.VEC_SIZE);
    }
}
