package com.alibaba.alink.operator.common.fm;

import com.alibaba.alink.common.annotation.FeatureColsVectorColMutexRule;
import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpecs;
import com.alibaba.alink.common.annotation.PortDesc;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.TypeCollections;
import com.alibaba.alink.common.exceptions.AkIllegalArgumentException;
import com.alibaba.alink.common.exceptions.AkIllegalDataException;
import com.alibaba.alink.common.exceptions.AkPreconditions;
import com.alibaba.alink.common.exceptions.ExceptionWithErrorCode;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.model.ModelParamName;
import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil;
import com.alibaba.alink.operator.common.fm.BaseFmTrainBatchOp;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.params.recommendation.FmTrainParams;
import com.alibaba.alink.params.shared.colname.HasFeatureCols;
import com.alibaba.alink.params.shared.colname.HasVectorCol;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import org.apache.flink.api.common.functions.FilterFunction;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.operators.FlatMapOperator;
import org.apache.flink.api.java.operators.MapOperator;
import org.apache.flink.api.java.tuple.Tuple1;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/**
 * Abstract base class of the batch operators that train a factorization machine (FM) model.
 *
 * <p>{@code linkFrom} converts the input table into {@code (weight, label, vector)} tuples, derives the
 * feature-vector size and the label values from them, delegates the actual training to {@code optimize}
 * and the model serialization to {@code transformModel}, and finally publishes the model rows as the main
 * output plus one side-output table with training information.
 *
 * <p>NOTE(review): this file looks like decompiler output (JADX markers are present), so local names such
 * as {@code d}, {@code d2} or {@code iArr} are not the original identifiers.
 *
 * @param <T> the concrete operator type, returned by {@code linkFrom} for call chaining
 */
@InputPorts(values = {@PortSpec(PortType.DATA)})
@OutputPorts(values = {@PortSpec(PortType.MODEL), @PortSpec(value = PortType.DATA, desc = PortDesc.MODEL_INFO)})
@FeatureColsVectorColMutexRule
@ParamSelectColumnSpecs({@ParamSelectColumnSpec(name = "featureCols", allowedTypeCollections = {TypeCollections.NUMERIC_TYPES}), @ParamSelectColumnSpec(name = "vectorCol", allowedTypeCollections = {TypeCollections.VECTOR_TYPES}), @ParamSelectColumnSpec(name = "labelCol"), @ParamSelectColumnSpec(name = "weightCol", allowedTypeCollections = {TypeCollections.NUMERIC_TYPES})})
/* loaded from: input_file:com/alibaba/alink/operator/common/fm/BaseFmTrainBatchOp.class */
public abstract class BaseFmTrainBatchOp<T extends BaseFmTrainBatchOp<T>> extends BatchOperator<T> {
    /** Name of the broadcast set that carries the label values into {@code transferLabel}. */
    public static final String LABEL_VALUES = "labelValues";
    /** Not referenced inside this file; presumably a broadcast-set name used by subclasses — TODO confirm. */
    public static final String VEC_SIZE = "vecSize";
    private static final long serialVersionUID = -5308557491809175331L;
    /** Type of the label column; assigned in {@code linkFrom} ({@code Types.DOUBLE} for regression). */
    protected TypeInformation<?> labelType;

    /* loaded from: input_file:com/alibaba/alink/operator/common/fm/BaseFmTrainBatchOp$FmDataFormat.class */
    /**
     * Serializable holder for factorization-machine parameters: a parameter matrix, a bias value and the
     * layout descriptor the matrix was sized from.
     *
     * <p>The matrix is filled with reproducible pseudo-random values (fixed seed), so two instances built
     * with the same arguments contain identical numbers.
     */
    public static class FmDataFormat implements Serializable {
        private static final long serialVersionUID = 192926704450234984L;

        /** Parameter matrix; every row is allocated with {@code dim[2] + dim[1]} entries. */
        public double[][] factors;
        /** Bias value; this class never assigns it, so it stays at the Java default until a caller sets it. */
        public double bias;
        /**
         * Layout descriptor supplied by the caller; only entries 1 and 2 are read here.
         * Presumably the {withIntercept, withLinearItem, numFactor} array built in {@code linkFrom} — TODO confirm.
         */
        public int[] dim;

        /** Creates an empty instance; all fields keep their default values. */
        public FmDataFormat() {
        }

        /**
         * Allocates {@code i} rows and initializes them randomly.
         *
         * @param i    number of rows of the parameter matrix
         * @param iArr layout descriptor, stored by reference in {@link #dim}
         * @param d    scale applied to each Gaussian sample during initialization
         */
        public FmDataFormat(int i, int[] iArr, double d) {
            this.dim = iArr;
            final int rowWidth = iArr[2] + iArr[1];
            this.factors = new double[i][rowWidth];
            reset(d);
        }

        /**
         * Same as {@link #FmDataFormat(int, int[], double)} with {@code i * i2} rows.
         *
         * @param i    first row-count factor
         * @param i2   second row-count factor
         * @param iArr layout descriptor, stored by reference in {@link #dim}
         * @param d    scale applied to each Gaussian sample during initialization
         */
        public FmDataFormat(int i, int i2, int[] iArr, double d) {
            this(i * i2, iArr, d);
        }

        /**
         * Overwrites the matrix, row by row, with Gaussian samples multiplied by {@code d}.
         * The generator is re-seeded with the constant 2020 on every call.
         *
         * @param d scale applied to each Gaussian sample
         */
        public void reset(double d) {
            final Random generator = new Random(2020L);
            for (double[] row : this.factors) {
                // The width is always taken from the first row, exactly as the index-based loop did.
                final int width = this.factors[0].length;
                for (int col = 0; col < width; col++) {
                    row[col] = generator.nextGaussian() * d;
                }
            }
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/common/fm/BaseFmTrainBatchOp$LogitLoss.class */
    /**
     * Logistic loss for binary labels.
     *
     * <p>In both methods the first argument {@code d} is the label and the second argument {@code d2} is
     * the raw model output, which is squashed through a sigmoid before use.
     */
    public static class LogitLoss implements LossFunction {
        private static final long serialVersionUID = -166213844104644622L;

        /**
         * Negative log-likelihood of the label under the sigmoid of the model output.
         * Labels of at least 0.5 count as the positive class, labels below 0.5 as the negative class.
         *
         * @param d  label
         * @param d2 raw model output
         * @return the loss value
         * @throws AkIllegalDataException if the label satisfies neither comparison (i.e. it is NaN)
         */
        @Override // com.alibaba.alink.operator.common.fm.BaseFmTrainBatchOp.LossFunction
        public double l(double d, double d2) {
            final double probability = sigmoid(d2);
            if (d >= 0.5d) {
                return -Math.log(probability);
            } else if (d < 0.5d) {
                return -Math.log(1.0d - probability);
            }
            // Only reachable when both comparisons above are false.
            throw new AkIllegalDataException("Invalid label: " + d);
        }

        /**
         * Difference between the sigmoid of the model output and the label.
         *
         * @param d  label
         * @param d2 raw model output
         * @return {@code sigmoid(d2) - d}
         */
        @Override // com.alibaba.alink.operator.common.fm.BaseFmTrainBatchOp.LossFunction
        public double dldy(double d, double d2) {
            final double probability = sigmoid(d2);
            return probability - d;
        }

        /**
         * Sigmoid with hard cut-offs: inputs above 34 map to 1, inputs below -37 map to
         * {@code Criteria.INVALID_GAIN} (presumably 0.0 — the constant looks decompiler-substituted; verify).
         */
        private double sigmoid(double score) {
            if (score > 34.0d) {
                return 1.0d;
            }
            if (score < -37.0d) {
                return Criteria.INVALID_GAIN;
            }
            final double denominator = 1.0d + Math.exp(-score);
            return 1.0d / denominator;
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/common/fm/BaseFmTrainBatchOp$LossFunction.class */
    /**
     * Serializable loss function used by FM training.
     *
     * <p>In the two implementations in this file the first argument is the label/target and the second
     * argument is the model output.
     */
    public interface LossFunction extends Serializable {
        /**
         * Loss value for one sample.
         *
         * @param d  label/target value
         * @param d2 model output
         * @return the loss
         */
        double l(double d, double d2);

        /**
         * Gradient-style term for one sample; judging by the name and the implementations in this file it
         * is the derivative of the loss with respect to the model output — TODO confirm.
         *
         * @param d  label/target value
         * @param d2 model output
         * @return the derivative term
         */
        double dldy(double d, double d2);
    }

    /* loaded from: input_file:com/alibaba/alink/operator/common/fm/BaseFmTrainBatchOp$SquareLoss.class */
    /**
     * Squared-error loss whose derivative term clips the model output into a fixed target range.
     *
     * <p>In both methods the first argument {@code d} is the target and {@code d2} is the model output.
     */
    public static class SquareLoss implements LossFunction {
        private static final long serialVersionUID = -3903508209287601504L;
        private final double maxTarget;
        private final double minTarget;

        /**
         * @param d  upper bound used when clipping the model output
         * @param d2 lower bound used when clipping the model output
         */
        public SquareLoss(double d, double d2) {
            this.maxTarget = d;
            this.minTarget = d2;
        }

        /**
         * Squared difference between target and (unclipped) model output.
         *
         * @param d  target value
         * @param d2 model output
         * @return {@code (d - d2)^2}
         */
        @Override // com.alibaba.alink.operator.common.fm.BaseFmTrainBatchOp.LossFunction
        public double l(double d, double d2) {
            final double residual = d - d2;
            return residual * residual;
        }

        /**
         * Twice the difference between the clipped model output and the target. The output is first capped
         * at the upper bound and then raised to the lower bound.
         *
         * @param d  target value
         * @param d2 model output
         * @return {@code 2 * (clip(d2) - d)}
         */
        @Override // com.alibaba.alink.operator.common.fm.BaseFmTrainBatchOp.LossFunction
        public double dldy(double d, double d2) {
            final double capped = Math.min(d2, this.maxTarget);
            final double clipped = Math.max(capped, this.minTarget);
            return 2.0d * (clipped - d);
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/common/fm/BaseFmTrainBatchOp$Task.class */
    /** Kind of learning task; {@code linkFrom} reads it from the params to choose the label handling. */
    public enum Task {
        /** Labels are parsed as doubles and the label type is fixed to {@code Types.DOUBLE}. */
        REGRESSION,
        /** Labels are kept as objects; exactly two distinct values are required (see {@code orderLabels}). */
        BINARY_CLASSIFICATION
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/alibaba/alink/operator/common/fm/BaseFmTrainBatchOp$Transform.class */
    /**
     * Converts every input {@link Row} of a partition into a {@code (weight, label, featureVector)} tuple
     * and appends one trailing metadata tuple per partition.
     *
     * <p>The metadata tuple is {@code (-1.0, Tuple2(vectorSize, labelValues[]), emptyDenseVector)}. Its
     * weight of {@code -1.0} is the only thing that distinguishes it from sample tuples downstream
     * ({@code getUtilInfo} keeps tuples whose weight is below {@code Criteria.INVALID_GAIN},
     * {@code transferLabel} keeps tuples whose weight is above it). For that reason a sample whose weight
     * is below {@code Criteria.INVALID_GAIN} is rejected here with a descriptive error instead of being
     * mistaken for metadata later on.
     */
    public static class Transform extends RichMapPartitionFunction<Row, Tuple3<Double, Object, Vector>> {
        private static final long serialVersionUID = 5935792357245627952L;
        private final int vecIdx;
        private final int labelIdx;
        private final int weightIdx;
        private final boolean isRegProc;
        private final int[] featureIndices;

        /**
         * @param z    {@code true} for regression; label values are then not collected
         * @param i    index of the weight column, or {@code -1} to use a weight of 1.0 for every row
         * @param i2   index of the vector column; only read when {@code iArr} is {@code null}
         * @param iArr indices of the numeric feature columns, or {@code null} to read the vector column
         * @param i3   index of the label column
         */
        public Transform(boolean z, int i, int i2, int[] iArr, int i3) {
            this.vecIdx = i2;
            this.labelIdx = i3;
            this.weightIdx = i;
            this.isRegProc = z;
            this.featureIndices = iArr;
        }

        /**
         * Emits one tuple per row followed by the partition's metadata tuple.
         *
         * @param iterable  rows of this partition
         * @param collector receives the sample tuples and, last, the metadata tuple
         * @throws IllegalArgumentException if a row carries a weight below {@code Criteria.INVALID_GAIN}
         */
        public void mapPartition(Iterable<Row> iterable, Collector<Tuple3<Double, Object, Vector>> collector) throws Exception {
            Vector vector;
            HashSet<Object> labelValues = new HashSet<Object>();
            int vectorSize = this.featureIndices != null ? this.featureIndices.length : -1;
            for (Row row : iterable) {
                double weight = this.weightIdx == -1 ? 1.0d : ((Number) row.getField(this.weightIdx)).doubleValue();
                // Same predicate as the metadata filter in getUtilInfo: such a row would be treated as the
                // metadata tuple there and fail with a ClassCastException.
                if (weight < Criteria.INVALID_GAIN) {
                    throw new IllegalArgumentException("Invalid sample weight: " + weight
                        + ". Sample weights must not be negative.");
                }
                Object label = row.getField(this.labelIdx);
                if (this.isRegProc) {
                    // Regression does not need real label values; a single placeholder is recorded.
                    labelValues.add(Double.valueOf(Criteria.INVALID_GAIN));
                } else {
                    labelValues.add(label);
                }
                if (this.featureIndices != null) {
                    vector = new DenseVector(this.featureIndices.length);
                    for (int i = 0; i < this.featureIndices.length; i++) {
                        vector.set(i, ((Number) row.getField(this.featureIndices[i])).doubleValue());
                    }
                } else {
                    vector = VectorUtil.getVector(row.getField(this.vecIdx));
                    if (vector instanceof SparseVector) {
                        // A positive declared size wins; otherwise the size is inferred from the largest index.
                        for (int index : ((SparseVector) vector).getIndices()) {
                            vectorSize = vector.size() > 0 ? vector.size() : Math.max(vectorSize, index + 1);
                        }
                    } else {
                        vectorSize = ((DenseVector) vector).getData().length;
                    }
                }
                collector.collect(Tuple3.of(Double.valueOf(weight), label, vector));
            }
            // Trailing metadata tuple, marked by the weight -1.0.
            collector.collect(Tuple3.of(Double.valueOf(-1.0d), Tuple2.of(Integer.valueOf(vectorSize), labelValues.toArray()), new DenseVector(0)));
        }
    }

    /**
     * Creates the operator with the given parameters, which are passed unchanged to the super constructor.
     *
     * @param params operator parameters
     */
    public BaseFmTrainBatchOp(Params params) {
        super(params);
    }

    /** Creates the operator with a fresh, empty {@link Params} instance. */
    public BaseFmTrainBatchOp() {
        this(new Params());
    }

    /**
     * Trains the model. Called once from {@code linkFrom}.
     *
     * @param dataSet  training samples as {@code (weight, numeric label, feature vector)}, produced by
     *                 {@code transferLabel}
     * @param dataSet2 single-element data set holding the feature-vector size
     * @param params   the operator's parameters
     * @param iArr     three-element array built in {@code linkFrom}:
     *                 {@code [withIntercept ? 1 : 0, withLinearItem ? 1 : 0, numFactor]}
     * @return the trained parameters together with a {@code double[]}; {@code linkFrom} emits that array as
     *         JSON in the side-output table (presumably convergence information — TODO confirm)
     */
    protected abstract DataSet<Tuple2<FmDataFormat, double[]>> optimize(DataSet<Tuple3<Double, Double, Vector>> dataSet, DataSet<Integer> dataSet2, Params params, int[] iArr);

    /**
     * Converts the training result into model rows. Called once from {@code linkFrom}; the returned rows
     * become the operator's main output with the schema of {@code FmModelDataConverter}.
     *
     * @param dataSet         result of {@code optimize}
     * @param dataSet2        single-element data set holding the label values
     * @param dataSet3        single-element data set holding the feature-vector size
     * @param params          the operator's parameters
     * @param iArr            {@code [withIntercept ? 1 : 0, withLinearItem ? 1 : 0, numFactor]}
     * @param z               {@code true} when the task is regression
     * @param typeInformation type of the label column
     * @return the serialized model rows
     */
    protected abstract DataSet<Row> transformModel(DataSet<Tuple2<FmDataFormat, double[]>> dataSet, DataSet<Object[]> dataSet2, DataSet<Integer> dataSet3, Params params, int[] iArr, boolean z, TypeInformation<?> typeInformation);

    /**
     * Builds the FM training pipeline on the first input operator.
     *
     * <p>Steps: validate that {@code featureCols} and {@code vectorCol} are not both set; assemble the
     * {@code [withIntercept, withLinearItem, numFactor]} descriptor; resolve the label type; convert the
     * input rows to tuples; derive vector size and label values; run {@code optimize}; serialize with
     * {@code transformModel}; set the main output and the side-output table.
     *
     * @param batchOperatorArr input operators; only the first one is used
     * @return this operator
     * @throws AkIllegalArgumentException if both {@code featureCols} and {@code vectorCol} are set
     */
    @Override // com.alibaba.alink.operator.batch.BatchOperator
    public T linkFrom(BatchOperator<?>... batchOperatorArr) {
        BatchOperator<?> input = checkAndGetFirst(batchOperatorArr);
        Params params = getParams();
        if (params.contains(HasFeatureCols.FEATURE_COLS) && params.contains(HasVectorCol.VECTOR_COL)) {
            throw new AkIllegalArgumentException("featureCols and vectorCol cannot be set at the same time.");
        }
        int[] dim = new int[3];
        dim[0] = ((Boolean) params.get(FmTrainParams.WITH_INTERCEPT)).booleanValue() ? 1 : 0;
        dim[1] = ((Boolean) params.get(FmTrainParams.WITH_LINEAR_ITEM)).booleanValue() ? 1 : 0;
        dim[2] = ((Integer) params.get(FmTrainParams.NUM_FACTOR)).intValue();
        boolean isRegression = ((Task) params.get(ModelParamName.TASK)).equals(Task.REGRESSION);
        // Resolve the label column with the asserting lookup (the one transform() uses as well) so that a
        // missing column is reported by name instead of surfacing as a bad array index.
        this.labelType = isRegression ? Types.DOUBLE : input.getColTypes()[TableUtil.findColIndexWithAssertAndHint(input.getColNames(), (String) params.get(FmTrainParams.LABEL_COL))];
        DataSet<Tuple3<Double, Object, Vector>> samples = transform(input, params, isRegression);
        DataSet<Tuple2<Object[], Integer>> utilInfo = getUtilInfo(samples, isRegression);
        // Split the single (labelValues, vectorSize) record into two single-element data sets.
        DataSet<Integer> vectorSize = utilInfo.map(new MapFunction<Tuple2<Object[], Integer>, Integer>() {
            private static final long serialVersionUID = 1099531852518545431L;

            public Integer map(Tuple2<Object[], Integer> tuple2) {
                return (Integer) tuple2.f1;
            }
        });
        // No explicit flatMap(Object, Collector) bridge here: the compiler generates it, and declaring it
        // by hand clashes with the generated one.
        DataSet<Object[]> labelValues = utilInfo.flatMap(new FlatMapFunction<Tuple2<Object[], Integer>, Object[]>() {
            private static final long serialVersionUID = -4407775357759305675L;

            public void flatMap(Tuple2<Object[], Integer> tuple2, Collector<Object[]> collector) {
                collector.collect(tuple2.f0);
            }
        });
        DataSet<Tuple2<FmDataFormat, double[]>> trained = optimize(transferLabel(samples, isRegression, labelValues), vectorSize, params, dim);
        DataSet<Row> modelRows = transformModel(trained, labelValues, vectorSize, params, dim, isRegression, this.labelType);
        setOutput(modelRows, new FmModelDataConverter(this.labelType).getModelSchema());
        // Field 1 of the training result is the double[] that ends up in the side-output table.
        setSideOutputTables(getSideTablesOfCoefficient(modelRows, trained.project(new int[]{1}), this.labelType));
        // 'this' is a BaseFmTrainBatchOp<T>, not a T; the cast is unchecked and erased at runtime.
        @SuppressWarnings("unchecked")
        T self = (T) this;
        return self;
    }

    /**
     * Turns {@code (weight, raw label, vector)} tuples into {@code (weight, numeric label, vector)} tuples.
     *
     * <p>Only tuples whose weight is greater than {@code Criteria.INVALID_GAIN} are forwarded, which drops
     * the metadata tuples emitted with weight {@code -1.0}. For regression the label's string form is
     * parsed as a double; otherwise the label becomes {@code 1.0} when it equals the first broadcast label
     * value and {@code Criteria.INVALID_GAIN} when it does not.
     *
     * @param dataSet  tuples produced by {@code transform}
     * @param z        {@code true} for regression
     * @param dataSet2 label values, broadcast under the name {@link #LABEL_VALUES}
     * @return the tuples with numeric labels
     */
    private static DataSet<Tuple3<Double, Double, Vector>> transferLabel(DataSet<Tuple3<Double, Object, Vector>> dataSet, final boolean z, DataSet<Object[]> dataSet2) {
        RichMapPartitionFunction<Tuple3<Double, Object, Vector>, Tuple3<Double, Double, Vector>> labelEncoder =
            new RichMapPartitionFunction<Tuple3<Double, Object, Vector>, Tuple3<Double, Double, Vector>>() {
                private static final long serialVersionUID = 1609901151679856341L;
                private Object[] labelValues = null;

                public void open(Configuration configuration) {
                    Object firstBroadcast = getRuntimeContext().getBroadcastVariable(BaseFmTrainBatchOp.LABEL_VALUES).get(0);
                    this.labelValues = (Object[]) firstBroadcast;
                }

                public void mapPartition(Iterable<Tuple3<Double, Object, Vector>> iterable, Collector<Tuple3<Double, Double, Vector>> collector) {
                    for (Tuple3<Double, Object, Vector> sample : iterable) {
                        if (!(sample.f0.doubleValue() > Criteria.INVALID_GAIN)) {
                            // Metadata tuples (and any other non-positive weight) are not forwarded.
                            continue;
                        }
                        double numericLabel;
                        if (z) {
                            numericLabel = Double.parseDouble(sample.f1.toString());
                        } else if (sample.f1.equals(this.labelValues[0])) {
                            numericLabel = 1.0d;
                        } else {
                            numericLabel = Criteria.INVALID_GAIN;
                        }
                        collector.collect(Tuple3.of(sample.f0, Double.valueOf(numericLabel), sample.f2));
                    }
                }
            };
        return dataSet.mapPartition(labelEncoder).withBroadcastSet(dataSet2, LABEL_VALUES);
    }

    /**
     * Merges the per-partition metadata tuples into a single {@code (labelValues, vectorSize)} record.
     *
     * <p>Metadata tuples are recognised by a weight below {@code Criteria.INVALID_GAIN}; their second field
     * is a {@code Tuple2(vectorSize, labelValues[])}. The result carries the union of all label values and
     * the largest vector size. For classification the label values are passed through
     * {@code orderLabels}; for regression they are returned in set order.
     *
     * @param dataSet tuples produced by {@code transform}
     * @param z       {@code true} for regression
     * @return a data set with one {@code (labelValues, vectorSize)} element
     */
    private static DataSet<Tuple2<Object[], Integer>> getUtilInfo(DataSet<Tuple3<Double, Object, Vector>> dataSet, final boolean z) {
        FilterFunction<Tuple3<Double, Object, Vector>> metadataOnly = new FilterFunction<Tuple3<Double, Object, Vector>>() {
            private static final long serialVersionUID = 4954837288144406855L;

            public boolean filter(Tuple3<Double, Object, Vector> tuple3) {
                double weight = tuple3.f0.doubleValue();
                return weight < Criteria.INVALID_GAIN;
            }
        };
        GroupReduceFunction<Tuple3<Double, Object, Vector>, Tuple2<Object[], Integer>> merger =
            new GroupReduceFunction<Tuple3<Double, Object, Vector>, Tuple2<Object[], Integer>>() {
                private static final long serialVersionUID = 3520762756658301627L;

                public void reduce(Iterable<Tuple3<Double, Object, Vector>> iterable, Collector<Tuple2<Object[], Integer>> collector) {
                    int largestSize = -1;
                    HashSet<Object> distinctLabels = new HashSet<Object>();
                    for (Tuple3<Double, Object, Vector> record : iterable) {
                        // f1 of a metadata tuple is Tuple2(vectorSize, labelValues[]).
                        Tuple2 meta = (Tuple2) record.f1;
                        Collections.addAll(distinctLabels, (Object[]) meta.f1);
                        largestSize = Math.max(largestSize, ((Integer) meta.f0).intValue());
                    }
                    Object[] labels;
                    if (z) {
                        labels = distinctLabels.toArray();
                    } else {
                        labels = BaseFmTrainBatchOp.orderLabels(distinctLabels);
                    }
                    collector.collect(Tuple2.of(labels, Integer.valueOf(largestSize)));
                }
            };
        return dataSet.filter(metadataOnly).reduceGroup(merger);
    }

    /**
     * Copies the two label values into an array and decides which one comes first. {@code transferLabel}
     * maps the value at index 0 to {@code 1.0}.
     *
     * <ul>
     *   <li>Non-numeric first label: the value whose string form is lexicographically larger (or equal)
     *       ends up at index 0.</li>
     *   <li>Numeric first label: the two values are swapped only when they sum to 1.0 and the first one
     *       equals {@code Criteria.INVALID_GAIN}; otherwise the iteration order is kept.</li>
     * </ul>
     *
     * @param iterable the distinct label values
     * @return a two-element array of the label values
     * @throws AkIllegalDataException if the number of labels is not exactly two (raised through
     *                                {@code AkPreconditions.checkState})
     */
    protected static Object[] orderLabels(Iterable<Object> iterable) {
        List<Object> collected = new ArrayList<Object>();
        for (Object label : iterable) {
            collected.add(label);
        }
        Object[] labels = collected.toArray(new Object[0]);
        AkPreconditions.checkState(labels.length == 2, (ExceptionWithErrorCode) new AkIllegalDataException("labels count should be 2 in 2 classification algo."));

        boolean swap;
        if (labels[0] instanceof Number) {
            double first = ((Number) labels[0]).doubleValue();
            double second = ((Number) labels[1]).doubleValue();
            swap = first + second == 1.0d && first == Criteria.INVALID_GAIN;
        } else {
            String firstText = labels[0].toString();
            String secondText = labels[1].toString();
            // "second equals max(first, second)" is the same as "second >= first" for strings.
            swap = secondText.compareTo(firstText) >= 0;
        }
        if (swap) {
            Object previousFirst = labels[0];
            labels[0] = labels[1];
            labels[1] = previousFirst;
        }
        return labels;
    }

    /**
     * Resolves the configured column names to indices and maps the input rows with {@link Transform}.
     *
     * <p>When neither feature columns nor a vector column are configured, all numeric columns except the
     * label column are used as features, and that selection is written back into {@code params}.
     *
     * @param batchOperator the input operator
     * @param params        operator parameters; {@code FEATURE_COLS} may be set as a side effect
     * @param z             {@code true} for regression
     * @return {@code (weight, label, vector)} tuples plus one metadata tuple per partition
     */
    private DataSet<Tuple3<Double, Object, Vector>> transform(BatchOperator<?> batchOperator, Params params, boolean z) {
        String[] featureCols = (String[]) params.get(FmTrainParams.FEATURE_COLS);
        String labelCol = (String) params.get(FmTrainParams.LABEL_COL);
        String weightCol = (String) params.get(FmTrainParams.WEIGHT_COL);
        String vectorCol = (String) params.get(FmTrainParams.VECTOR_COL);
        TableSchema schema = batchOperator.getSchema();
        if (null == featureCols && null == vectorCol) {
            featureCols = TableUtil.getNumericCols(schema, new String[]{labelCol});
            // Plain typed call: the value is a String[], not a ParamInfo, so it must not be cast to one.
            params.set(FmTrainParams.FEATURE_COLS, featureCols);
        }
        int[] featureIndices = null;
        int labelIdx = TableUtil.findColIndexWithAssertAndHint(schema.getFieldNames(), labelCol);
        if (featureCols != null) {
            featureIndices = new int[featureCols.length];
            for (int i = 0; i < featureCols.length; i++) {
                featureIndices[i] = TableUtil.findColIndexWithAssertAndHint(batchOperator.getColNames(), featureCols[i]);
            }
        }
        // -1 marks "column not configured" for the optional weight and vector columns.
        int weightIdx = weightCol != null ? TableUtil.findColIndexWithAssertAndHint(batchOperator.getColNames(), weightCol) : -1;
        int vecIdx = vectorCol != null ? TableUtil.findColIndexWithAssertAndHint(batchOperator.getColNames(), vectorCol) : -1;
        return batchOperator.getDataSet().mapPartition(new Transform(z, weightIdx, vecIdx, featureIndices, labelIdx));
    }

    /**
     * Builds the side-output table with two rows of schema {@code (title INT, info STRING)}:
     * row {@code 0} holds the model meta parameters as JSON, row {@code 1} holds the {@code double[]}
     * returned by {@code optimize} as JSON.
     *
     * <p>Both map steps run with parallelism 1 so that all model rows are loaded into one
     * {@code FmModelData} instance.
     *
     * @param dataSet         serialized model rows
     * @param dataSet2        the {@code double[]} part of the training result, broadcast as {@code "cinfo"}
     * @param typeInformation type of the label column, needed to decode the model rows
     * @return a one-element array containing the side-output table
     */
    private Table[] getSideTablesOfCoefficient(DataSet<Row> dataSet, DataSet<Tuple1<double[]>> dataSet2, final TypeInformation<?> typeInformation) {
        DataSet<Row> infoRows = dataSet.mapPartition(new MapPartitionFunction<Row, FmModelData>() {
            private static final long serialVersionUID = 2063366042018382802L;

            public void mapPartition(Iterable<Row> iterable, Collector<FmModelData> collector) {
                List<Row> modelRows = new ArrayList<Row>();
                for (Row row : iterable) {
                    modelRows.add(row);
                }
                collector.collect(new FmModelDataConverter(typeInformation).load(modelRows));
            }
        }).setParallelism(1).mapPartition(new RichMapPartitionFunction<FmModelData, Row>() {
            private static final long serialVersionUID = 8785824618242390100L;

            public void mapPartition(Iterable<FmModelData> iterable, Collector<Row> collector) {
                FmModelData model = iterable.iterator().next();
                double[] trainInfo = (double[]) ((Tuple1) getRuntimeContext().getBroadcastVariable("cinfo").get(0)).f0;
                // Values are passed as-is; they are Integer/Boolean/Object[] and must not be cast to ParamInfo.
                Params meta = new Params();
                meta.set(ModelParamName.VECTOR_SIZE, Integer.valueOf(model.vectorSize));
                meta.set(ModelParamName.LABEL_VALUES, model.labelValues);
                meta.set(FmTrainParams.WITH_LINEAR_ITEM, Boolean.valueOf(model.dim[1] == 1));
                meta.set(FmTrainParams.WITH_INTERCEPT, Boolean.valueOf(model.dim[0] == 1));
                meta.set(FmTrainParams.NUM_FACTOR, Integer.valueOf(model.dim[2]));
                collector.collect(Row.of(new Object[]{0, JsonConverter.toJson(meta)}));
                collector.collect(Row.of(new Object[]{1, JsonConverter.toJson(trainInfo)}));
            }
        }).setParallelism(1).withBroadcastSet(dataSet2, "cinfo");
        TableSchema infoSchema = new TableSchema(new String[]{"title", "info"}, new TypeInformation[]{Types.INT, Types.STRING});
        return new Table[]{DataSetConversionUtil.toTable(getMLEnvironmentId(), infoRows, infoSchema)};
    }

    /**
     * Decompiler rendering of the compiler-generated bridge for {@code linkFrom}; it only forwards to the
     * varargs overload.
     *
     * <p>NOTE(review): {@code BatchOperator<?>...} erases to {@code BatchOperator[]}, so this declaration
     * has the same erasure as the varargs {@code linkFrom} in this class. The compiler emits the bridge on
     * its own; this explicit copy should be removed before the file is recompiled.
     */
    @Override // com.alibaba.alink.operator.batch.BatchOperator
    public /* bridge */ /* synthetic */ BatchOperator linkFrom(BatchOperator[] batchOperatorArr) {
        return linkFrom((BatchOperator<?>[]) batchOperatorArr);
    }
}
