package com.alibaba.alink.operator.local.classification;

import com.alibaba.alink.common.MTable;
import com.alibaba.alink.common.annotation.Internal;
import com.alibaba.alink.common.exceptions.AkIllegalArgumentException;
import com.alibaba.alink.common.exceptions.AkIllegalDataException;
import com.alibaba.alink.common.exceptions.AkPreconditions;
import com.alibaba.alink.common.exceptions.ExceptionWithErrorCode;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.model.ModelParamName;
import com.alibaba.alink.common.utils.RowCollector;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.common.fm.BaseFmTrainBatchOp;
import com.alibaba.alink.operator.common.fm.FmModelData;
import com.alibaba.alink.operator.common.fm.FmModelDataConverter;
import com.alibaba.alink.operator.common.optim.LocalFmOptimizer;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.operator.local.LocalOperator;
import com.alibaba.alink.operator.local.classification.FmTrainLocalOp;
import com.alibaba.alink.params.recommendation.FmTrainParams;
import com.alibaba.alink.params.shared.colname.HasFeatureCols;
import com.alibaba.alink.params.shared.colname.HasVectorCol;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/**
 * Local (in-memory) training operator for factorization machine (FM) models.
 *
 * <p>{@link #linkFrom(LocalOperator[])} reads the rows of the first input operator, turns each row
 * into a (weight, label, feature vector) sample, hands the samples to {@link LocalFmOptimizer} and
 * stores the serialized model rows as this operator's output table.
 *
 * <p>Features come either from {@code featureCols} (numeric columns packed into a dense vector) or
 * from {@code vectorCol}; setting both is rejected. When neither is set, every numeric column other
 * than the label column is used.
 *
 * @param <T> concrete operator type, returned by {@code linkFrom} for call chaining
 */
@Internal
public class FmTrainLocalOp<T extends FmTrainLocalOp<T>> extends LocalOperator<T> {
    private static final long serialVersionUID = -3985394692858121356L;

    /**
     * Threshold for weights and the numeric value used for the "other" class / placeholder label.
     *
     * <p>NOTE(review): the decompiled code references {@code Criteria.INVALID_GAIN} at every one of
     * these sites; it is presumably the literal {@code 0.0} that a decompiler matched to an
     * unrelated constant. Confirm, then replace with a plain {@code 0.0d}.
     */
    private static final double ZERO = Criteria.INVALID_GAIN;

    /** Weight carried by the trailing metadata entry that {@code preprocess} appends. */
    private static final double SENTINEL_WEIGHT = -1.0d;

    /**
     * Creates the operator and records the task in the parameters.
     *
     * @param params training parameters; the task is written into this object
     * @param task   the FM task to train for
     */
    public FmTrainLocalOp(Params params, BaseFmTrainBatchOp.Task task) {
        super(params.set(ModelParamName.TASK, task));
    }

    /**
     * Trains an FM model on the rows of the first input and sets the model as the output table.
     *
     * @param localOperatorArr input operators; only the first one is used
     * @return this operator
     * @throws AkIllegalArgumentException if both {@code featureCols} and {@code vectorCol} are set
     */
    @Override // com.alibaba.alink.operator.local.LocalOperator
    public T linkFrom(LocalOperator<?>... localOperatorArr) {
        LocalOperator<?> in = checkAndGetFirst(localOperatorArr);
        Params params = getParams();
        if (params.contains(HasFeatureCols.FEATURE_COLS) && params.contains(HasVectorCol.VECTOR_COL)) {
            throw new AkIllegalArgumentException("featureCols and vectorCol cannot be set at the same time.");
        }

        // dim = {with intercept (0/1), with linear item (0/1), number of factors}
        int[] dim = new int[3];
        dim[0] = ((Boolean) params.get(FmTrainParams.WITH_INTERCEPT)).booleanValue() ? 1 : 0;
        dim[1] = ((Boolean) params.get(FmTrainParams.WITH_LINEAR_ITEM)).booleanValue() ? 1 : 0;
        dim[2] = ((Integer) params.get(FmTrainParams.NUM_FACTOR)).intValue();

        boolean isRegression =
            ((BaseFmTrainBatchOp.Task) params.get(ModelParamName.TASK)).equals(BaseFmTrainBatchOp.Task.REGRESSION);

        TypeInformation<?> labelType;
        if (isRegression) {
            labelType = Types.DOUBLE;
        } else {
            // Use the asserting lookup so that a missing label column yields a descriptive error
            // instead of an out-of-range array index.
            int labelIdx = TableUtil.findColIndexWithAssertAndHint(
                in.getColNames(), (String) params.get(FmTrainParams.LABEL_COL));
            labelType = in.getColTypes()[labelIdx];
        }

        List<Tuple3<Double, Object, Vector>> samples = transform(in, params, isRegression);
        Tuple2<Object[], Integer> meta = getLabelsAndFeatureSize(samples, isRegression);
        Object[] labels = meta.f0;
        Integer featureSize = meta.f1;

        List<Tuple3<Double, Double, Vector>> trainData = transferLabel(samples, isRegression, labels);
        Tuple2<BaseFmTrainBatchOp.FmDataFormat, double[]> trained =
            optimize(trainData, featureSize.intValue(), params, dim);
        List<Row> modelRows = transformModel(trained, labels, featureSize, params, dim, isRegression, labelType);

        setOutputTable(new MTable(modelRows, new FmModelDataConverter(labelType).getModelSchema()));

        @SuppressWarnings("unchecked") // T is the self type of this operator
        T self = (T) this;
        return self;
    }

    /**
     * Runs the local FM optimizer on the prepared samples.
     *
     * @param list   samples as (weight, numeric label, feature vector)
     * @param i      feature vector size
     * @param params parameters passed to the optimizer
     * @param iArr   {with intercept, with linear item, number of factors}
     * @return the result of {@link LocalFmOptimizer#optimize()}
     */
    protected Tuple2<BaseFmTrainBatchOp.FmDataFormat, double[]> optimize(List<Tuple3<Double, Double, Vector>> list, int i, Params params, int[] iArr) {
        // NOTE(review): INIT_STDEV is read from getParams() rather than the `params` argument.
        // Both are the same object when called from linkFrom; kept as-is for subclasses.
        double initStdev = ((Double) getParams().get(FmTrainParams.INIT_STDEV)).doubleValue();
        BaseFmTrainBatchOp.FmDataFormat initFactors = new BaseFmTrainBatchOp.FmDataFormat(i, iArr, initStdev);
        LocalFmOptimizer optimizer = new LocalFmOptimizer(list, params);
        optimizer.setWithInitFactors(initFactors);
        return optimizer.optimize();
    }

    /**
     * Packs the trained factors and the training meta data into model rows.
     *
     * @param tuple2          optimizer result; only {@code f0} (the factors) is stored
     * @param objArr          label values, used for classification only
     * @param num             feature vector size
     * @param params          training parameters
     * @param iArr            {with intercept, with linear item, number of factors}
     * @param z               {@code true} for regression
     * @param typeInformation label type used by the model converter
     * @return serialized model rows
     */
    private List<Row> transformModel(Tuple2<BaseFmTrainBatchOp.FmDataFormat, double[]> tuple2, Object[] objArr, Integer num, Params params, int[] iArr, boolean z, TypeInformation<?> typeInformation) {
        FmModelData model = new FmModelData();
        model.fmModel = tuple2.f0;
        model.vectorColName = (String) params.get(FmTrainParams.VECTOR_COL);
        model.featureColNames = (String[]) params.get(FmTrainParams.FEATURE_COLS);
        model.dim = iArr;
        model.regular = new double[] {
            ((Double) params.get(FmTrainParams.LAMBDA_0)).doubleValue(),
            ((Double) params.get(FmTrainParams.LAMBDA_1)).doubleValue(),
            ((Double) params.get(FmTrainParams.LAMBDA_2)).doubleValue()
        };
        model.labelColName = (String) params.get(FmTrainParams.LABEL_COL);
        model.task = (BaseFmTrainBatchOp.Task) params.get(ModelParamName.TASK);
        model.vectorSize = num.intValue();
        // Regression has no label set; a single placeholder value is stored instead.
        model.labelValues = z ? new Object[] {Double.valueOf(ZERO)} : objArr;

        RowCollector collector = new RowCollector();
        // NOTE(review): `save2` looks like a decompiler-renamed overload of `save` — confirm
        // against FmModelDataConverter before compiling.
        new FmModelDataConverter(typeInformation).save2(model, (Collector<Row>) collector);
        return collector.getRows();
    }

    /**
     * Converts raw labels to numeric training targets and drops entries whose weight is not
     * strictly positive (which also removes the trailing metadata entry).
     *
     * @param list   output of {@code preprocess}
     * @param z      {@code true} for regression: the label's string form is parsed as a double
     * @param objArr ordered labels; for classification {@code objArr[0]} maps to 1.0, others to 0.0
     * @return samples as (weight, numeric label, feature vector)
     */
    private List<Tuple3<Double, Double, Vector>> transferLabel(List<Tuple3<Double, Object, Vector>> list, boolean z, Object[] objArr) {
        List<Tuple3<Double, Double, Vector>> result = new ArrayList<>(list.size());
        for (Tuple3<Double, Object, Vector> sample : list) {
            if (sample.f0.doubleValue() > ZERO) {
                double target;
                if (z) {
                    target = Double.parseDouble(sample.f1.toString());
                } else {
                    target = sample.f1.equals(objArr[0]) ? 1.0d : ZERO;
                }
                result.add(Tuple3.of(sample.f0, Double.valueOf(target), sample.f2));
            }
        }
        return result;
    }

    /**
     * Extracts the label set and the feature size from the metadata entry (negative weight) that
     * {@code preprocess} appends to the sample list.
     *
     * @param list output of {@code preprocess}
     * @param z    {@code true} for regression; labels are then returned unordered
     * @return (labels, feature size); the feature size is -1 if no metadata entry was found
     */
    private Tuple2<Object[], Integer> getLabelsAndFeatureSize(List<Tuple3<Double, Object, Vector>> list, boolean z) {
        int featureSize = -1;
        HashSet<Object> labelSet = new HashSet<>();
        for (Tuple3<Double, Object, Vector> entry : list) {
            if (entry.f0.doubleValue() < ZERO) {
                Tuple2<?, ?> meta = (Tuple2<?, ?>) entry.f1;
                Collections.addAll(labelSet, (Object[]) meta.f1);
                featureSize = Math.max(featureSize, ((Integer) meta.f0).intValue());
            }
        }
        Object[] labels = z ? labelSet.toArray() : orderLabels(labelSet);
        return Tuple2.of(labels, Integer.valueOf(featureSize));
    }

    /**
     * Orders the two labels of a binary classification problem so that the label trained as the
     * positive class (target 1.0) comes first.
     *
     * <p>Non-numeric labels: the one whose string form compares greater (or equal) goes first.
     * Numeric labels: the order is only changed when the first value equals {@code ZERO} and the
     * two values sum to 1, in which case they are swapped.
     *
     * @param iterable distinct label values
     * @return the two labels in training order
     * @throws AkIllegalDataException (via {@code AkPreconditions}) if there are not exactly two labels
     */
    private Object[] orderLabels(Iterable<Object> iterable) {
        List<Object> collected = new ArrayList<>();
        for (Object label : iterable) {
            collected.add(label);
        }
        Object[] labels = collected.toArray(new Object[0]);
        AkPreconditions.checkState(labels.length == 2, (ExceptionWithErrorCode) new AkIllegalDataException("labels count should be 2 in 2 classification algo."));

        boolean swap;
        if (labels[0] instanceof Number) {
            double first = ((Number) labels[0]).doubleValue();
            double second = ((Number) labels[1]).doubleValue();
            swap = first + second == 1.0d && first == ZERO;
        } else {
            swap = labels[1].toString().compareTo(labels[0].toString()) >= 0;
        }
        if (swap) {
            Object tmp = labels[0];
            labels[0] = labels[1];
            labels[1] = tmp;
        }
        return labels;
    }

    /**
     * Resolves the label, weight, vector and feature column indices and converts the input rows
     * into samples.
     *
     * <p>If neither feature columns nor a vector column is configured, all numeric columns except
     * the label column are used and written back into {@code params} as {@code FEATURE_COLS}.
     *
     * @param localOperator input operator
     * @param params        training parameters (may be updated, see above)
     * @param z             {@code true} for regression
     * @return output of {@code preprocess}
     */
    private List<Tuple3<Double, Object, Vector>> transform(LocalOperator<?> localOperator, Params params, boolean z) {
        String[] featureCols = (String[]) params.get(FmTrainParams.FEATURE_COLS);
        String labelCol = (String) params.get(FmTrainParams.LABEL_COL);
        String weightCol = (String) params.get(FmTrainParams.WEIGHT_COL);
        String vectorCol = (String) params.get(FmTrainParams.VECTOR_COL);
        TableSchema schema = localOperator.getSchema();

        if (featureCols == null && vectorCol == null) {
            featureCols = TableUtil.getNumericCols(schema, new String[] {labelCol});
            params.set(FmTrainParams.FEATURE_COLS, featureCols);
        }

        int labelIdx = TableUtil.findColIndexWithAssertAndHint(schema.getFieldNames(), labelCol);

        int[] featureIdx = null;
        if (featureCols != null) {
            featureIdx = new int[featureCols.length];
            for (int k = 0; k < featureCols.length; k++) {
                featureIdx[k] = TableUtil.findColIndexWithAssertAndHint(localOperator.getColNames(), featureCols[k]);
            }
        }

        int weightIdx = -1;
        if (weightCol != null) {
            weightIdx = TableUtil.findColIndexWithAssertAndHint(localOperator.getColNames(), weightCol);
        }
        int vectorIdx = -1;
        if (vectorCol != null) {
            vectorIdx = TableUtil.findColIndexWithAssertAndHint(localOperator.getColNames(), vectorCol);
        }

        return preprocess(localOperator.getOutputTable(), z, weightIdx, vectorIdx, featureIdx, labelIdx);
    }

    /**
     * Converts table rows into (weight, raw label, feature vector) samples and appends one
     * metadata entry with weight -1 whose second field is a {@code Tuple2} of (feature size,
     * distinct labels).
     *
     * @param mTable input rows
     * @param z      {@code true} for regression; the label set then only holds a placeholder
     * @param i      weight column index, or -1 to use weight 1.0 for every row
     * @param i2     vector column index; used only when {@code iArr} is {@code null}
     * @param iArr   feature column indices, or {@code null} to read vectors from column {@code i2}
     * @param i3     label column index
     * @return the samples followed by the metadata entry
     * @throws AkIllegalDataException if a weight is negative or a feature vector is null
     */
    private List<Tuple3<Double, Object, Vector>> preprocess(MTable mTable, boolean z, int i, int i2, int[] iArr, int i3) {
        HashSet<Object> labelSet = new HashSet<>();
        List<Tuple3<Double, Object, Vector>> samples = new ArrayList<>();
        int featureSize = iArr != null ? iArr.length : -1;

        for (Row row : mTable.getRows()) {
            double weight = i == -1 ? 1.0d : ((Number) row.getField(i)).doubleValue();
            if (weight < ZERO) {
                // A negative weight is reserved for the metadata entry appended below; letting it
                // through would make the row be mistaken for that entry later on.
                throw new AkIllegalDataException("weight must be non-negative, but got " + weight + ".");
            }

            Object label = row.getField(i3);
            labelSet.add(z ? Double.valueOf(ZERO) : label);

            Vector features;
            if (iArr != null) {
                DenseVector dense = new DenseVector(iArr.length);
                for (int k = 0; k < iArr.length; k++) {
                    dense.set(k, ((Number) row.getField(iArr[k])).doubleValue());
                }
                features = dense;
            } else {
                features = VectorUtil.getVector(row.getField(i2));
                if (features == null) {
                    throw new AkIllegalDataException("Vector for fm model train is null, please check your input data.");
                }
                // Track the largest size seen so that rows of differing sizes cannot shrink it.
                if (features instanceof SparseVector) {
                    int declaredSize = features.size();
                    if (declaredSize > 0) {
                        featureSize = Math.max(featureSize, declaredSize);
                    } else {
                        // No declared size: derive it from the largest index present.
                        for (int index : ((SparseVector) features).getIndices()) {
                            featureSize = Math.max(featureSize, index + 1);
                        }
                    }
                } else {
                    featureSize = Math.max(featureSize, ((DenseVector) features).getData().length);
                }
            }
            samples.add(Tuple3.of(Double.valueOf(weight), label, features));
        }

        Object meta = Tuple2.of(Integer.valueOf(featureSize), labelSet.toArray());
        Vector empty = new DenseVector(0);
        samples.add(Tuple3.of(Double.valueOf(SENTINEL_WEIGHT), meta, empty));
        return samples;
    }
}
