package com.alibaba.alink.operator.batch.recommendation;

import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpecs;
import com.alibaba.alink.common.annotation.PortDesc;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.TypeCollections;
import com.alibaba.alink.common.exceptions.AkIllegalDataException;
import com.alibaba.alink.common.exceptions.AkPreconditions;
import com.alibaba.alink.common.exceptions.ExceptionWithErrorCode;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.classification.FmClassifierTrainBatchOp;
import com.alibaba.alink.operator.batch.regression.FmRegressorTrainBatchOp;
import com.alibaba.alink.operator.batch.sql.LeftOuterJoinBatchOp;
import com.alibaba.alink.operator.common.utils.PackBatchOperatorUtil;
import com.alibaba.alink.params.recommendation.FmRecommTrainParams;
import com.alibaba.alink.pipeline.dataproc.vector.VectorAssembler;
import com.alibaba.alink.pipeline.feature.OneHotEncoder;
import java.util.Arrays;
import org.apache.commons.lang.ArrayUtils;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.functions.ScalarFunction;

/**
 * Trains a factorization-machine (FM) recommendation model from user/item interaction data.
 *
 * <p>Inputs (see {@link #linkFrom(BatchOperator[])}):
 * <ol>
 *   <li>interaction data holding the user column, item column and rate column;</li>
 *   <li>optional user side table keyed by the user column;</li>
 *   <li>optional item side table keyed by the item column.</li>
 * </ol>
 * When a side table is absent, the distinct ids of that side (taken from the interaction data)
 * are used and the id column itself is one-hot encoded as the only feature.
 *
 * <p>Each interaction row is left-joined with its user and item feature vectors. The two vectors
 * are concatenated into one FM input vector, and an FM regressor is trained on it (or an FM
 * classifier when {@code implicitFeedback} is set), using the rate column as the label. The
 * resulting model is packed together with the user feature table, the item feature table and the
 * (user, item) pairs of the interaction data into this operator's output table.
 */
@InputPorts(values = {
    @PortSpec(PortType.DATA),
    @PortSpec(value = PortType.DATA, isOptional = true),
    @PortSpec(value = PortType.DATA, isOptional = true)})
@OutputPorts(values = {
    @PortSpec(PortType.MODEL),
    @PortSpec(value = PortType.DATA, desc = PortDesc.USER_FACTOR),
    @PortSpec(value = PortType.DATA, desc = PortDesc.ITEM_FACTOR),
    @PortSpec(value = PortType.DATA, desc = PortDesc.APPEND_USER_FACTOR, isOptional = true),
    @PortSpec(value = PortType.DATA, desc = PortDesc.APPEND_ITEM_FACTOR, isOptional = true)})
@ParamSelectColumnSpecs({
    @ParamSelectColumnSpec(name = "userCol"),
    @ParamSelectColumnSpec(name = "itemCol"),
    @ParamSelectColumnSpec(name = "rateCol", allowedTypeCollections = {TypeCollections.NUMERIC_TYPES}),
    @ParamSelectColumnSpec(name = "userFeatureCols", portIndices = {1}, allowedTypeCollections = {TypeCollections.NUMERIC_TYPES}),
    @ParamSelectColumnSpec(name = "userCategoricalFeatureCols", portIndices = {1}),
    @ParamSelectColumnSpec(name = "itemFeatureCols", portIndices = {2}, allowedTypeCollections = {TypeCollections.NUMERIC_TYPES}),
    @ParamSelectColumnSpec(name = "itemCategoricalFeatureCols", portIndices = {2})})
@NameCn("FM推荐训练")
@NameEn("Fm Recommend Training")
public final class FmRecommTrainBatchOp extends BatchOperator<FmRecommTrainBatchOp>
    implements FmRecommTrainParams<FmRecommTrainBatchOp> {
    private static final long serialVersionUID = 4783783487314711500L;

    /** Column holding the serialized feature vector built for one user-table or item-table row. */
    private static final String FM_FEATURES_COL = "__fm_features__";
    /** Column in the joined interaction data holding the user's serialized feature vector. */
    private static final String USER_FEATURES_COL = "__user_features__";
    /** Column in the joined interaction data holding the item's serialized feature vector. */
    private static final String ITEM_FEATURES_COL = "__item_features__";
    /** Concatenated user+item feature vector fed to the FM trainer. */
    private static final String ALINK_FEATURES_COL = "__alink_features__";

    /**
     * When {@code true}, an FM classifier is trained; otherwise an FM regressor.
     * Set to {@code false} by the constructor.
     * NOTE(review): the field is package-private and this class is final, so only code in this
     * package can change it.
     */
    boolean implicitFeedback;

    /**
     * Table UDF that passes a serialized feature vector through unchanged.
     * It throws when the value is null, i.e. when the left join found no side-table row for a
     * user or item of the interaction data.
     */
    public static class CheckNotNull extends ScalarFunction {
        private static final long serialVersionUID = -8064346886560345966L;

        /**
         * @param str serialized feature vector, possibly null
         * @return {@code str} unchanged
         * @throws AkIllegalDataException if {@code str} is null
         */
        public String eval(String str) {
            if (str == null) {
                throw new AkIllegalDataException("feature vector is null, perhaps some user/item feature is missing.");
            }
            return str;
        }
    }

    /** Table UDF that converts a {@link Vector} column into its string form via {@link VectorUtil#serialize}. */
    public static class ConvertVec extends ScalarFunction {
        private static final long serialVersionUID = -8905679791356243034L;

        public String eval(Vector vector) {
            return VectorUtil.serialize(vector);
        }
    }

    public FmRecommTrainBatchOp() {
        this(new Params());
    }

    public FmRecommTrainBatchOp(Params params) {
        super(params);
        this.implicitFeedback = false;
    }

    /**
     * Returns the entries of {@code cols} that are not in {@code excluded}, preserving order.
     * Membership is tested with {@link TableUtil#findColIndex}.
     */
    private static String[] subtract(String[] cols, String[] excluded) {
        String[] kept = new String[cols.length];
        int count = 0;
        for (String col : cols) {
            if (TableUtil.findColIndex(excluded, col) < 0) {
                kept[count++] = col;
            }
        }
        return Arrays.copyOf(kept, count);
    }

    /**
     * Builds one serialized feature vector per row of a user or item side table.
     *
     * <p>Categorical columns are one-hot encoded (with {@code dropLast = false}). The remaining
     * feature columns are assembled as-is together with that encoding. The output keeps only
     * {@code keyCol} plus the {@value #FM_FEATURES_COL} string column.
     *
     * @param input           side table
     * @param keyCol          id column kept in the output (user or item column)
     * @param featureCols     all feature columns to use
     * @param categoricalCols subset of {@code featureCols} to one-hot encode; checked via
     *                        {@link TableUtil#assertSelectedColExist} against {@code featureCols}
     * @return table of ({@code keyCol}, {@value #FM_FEATURES_COL})
     */
    private static BatchOperator<?> createFeatureVectors(BatchOperator<?> input, String keyCol,
                                                         String[] featureCols, String[] categoricalCols) {
        TableUtil.assertSelectedColExist(featureCols, categoricalCols);
        String[] assembledCols = subtract(featureCols, categoricalCols);
        Long envId = input.getMLEnvironmentId();
        BatchOperator<?> data = input;
        if (categoricalCols.length > 0) {
            data = ((OneHotEncoder) new OneHotEncoder().setMLEnvironmentId(envId))
                .setSelectedCols(categoricalCols)
                .setOutputCols(FM_FEATURES_COL)
                .setDropLast(false)
                .fit(data)
                .transform(data);
            // The one-hot vector column joins the numeric columns in the final assembly.
            assembledCols = (String[]) ArrayUtils.add(assembledCols, FM_FEATURES_COL);
        }
        return ((VectorAssembler) new VectorAssembler().setMLEnvironmentId(envId))
            .setSelectedCols(assembledCols)
            .setOutputCol(FM_FEATURES_COL)
            .setReservedCols(keyCol)
            .transform(data)
            .udf(FM_FEATURES_COL, FM_FEATURES_COL, new ConvertVec());
    }

    /**
     * Left-joins {@code left} with a feature table on {@code keyCol}.
     * All columns of {@code left} are kept, and the feature vector is appended as {@code outputCol}.
     */
    private static BatchOperator<?> joinFeatures(BatchOperator<?> left, BatchOperator<?> features,
                                                 String keyCol, String outputCol, Long envId) {
        return ((LeftOuterJoinBatchOp) new LeftOuterJoinBatchOp().setMLEnvironmentId(envId))
            .setJoinPredicate("a.`" + keyCol + "`=b.`" + keyCol + "`")
            .setSelectClause("a.*, b." + FM_FEATURES_COL + " as " + outputCol)
            .linkFrom(left, features);
    }

    /**
     * Fails with {@link AkIllegalDataException} carrying {@code errorMessage} unless
     * {@code keyCol} has the same type in the side table and in the interaction data.
     */
    private static void checkKeyColTypeMatch(BatchOperator<?> sideInput, BatchOperator<?> interactions,
                                             String keyCol, String errorMessage) {
        AkPreconditions.checkArgument(
            TableUtil.findColTypeWithAssert(sideInput.getSchema(), keyCol)
                .equals(TableUtil.findColTypeWithAssert(interactions.getSchema(), keyCol)),
            (ExceptionWithErrorCode) new AkIllegalDataException(errorMessage));
    }

    /**
     * Creates the underlying FM trainer.
     * It is a classifier when {@code implicitFeedback} is set and a regressor otherwise. The
     * trainer is configured with {@code rateCol} as label and {@value #ALINK_FEATURES_COL} as
     * vector column.
     */
    private BatchOperator<?> createFmTrainer(Params params, String rateCol, Long envId) {
        if (implicitFeedback) {
            return (BatchOperator<?>) new FmClassifierTrainBatchOp(params)
                .setLabelCol(rateCol)
                .setVectorCol(ALINK_FEATURES_COL)
                .setMLEnvironmentId(envId);
        }
        return (BatchOperator<?>) new FmRegressorTrainBatchOp(params)
            .setLabelCol(rateCol)
            .setVectorCol(ALINK_FEATURES_COL)
            .setMLEnvironmentId(envId);
    }

    /**
     * Wires the FM recommendation training graph.
     *
     * <p>javac emits the erased {@code linkFrom(BatchOperator[])} bridge for this override
     * automatically. Do not declare it explicitly: it is override-equivalent to this method and
     * would not compile.
     *
     * @param inputs interaction data, then an optional user side table, then an optional item side table
     * @return this operator
     * @throws AkIllegalDataException if a side table's key column type differs from the
     *                                interaction data's, or (at run time) a user/item has no features
     */
    @Override
    public FmRecommTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
        BatchOperator<?> interactions = inputs[0];
        Long envId = interactions.getMLEnvironmentId();
        BatchOperator<?> userSide = inputs.length >= 2 ? inputs[1] : null;
        BatchOperator<?> itemSide = inputs.length >= 3 ? inputs[2] : null;

        Params params = getParams();
        String userCol = (String) params.get(USER_COL);
        String itemCol = (String) params.get(ITEM_COL);
        String rateCol = (String) params.get(RATE_COL);
        String[] userFeatureCols = (String[]) params.get(USER_FEATURE_COLS);
        String[] itemFeatureCols = (String[]) params.get(ITEM_FEATURE_COLS);
        String[] userCategoricalCols = (String[]) params.get(USER_CATEGORICAL_FEATURE_COLS);
        String[] itemCategoricalCols = (String[]) params.get(ITEM_CATEGORICAL_FEATURE_COLS);

        if (userSide == null) {
            // No user table: the user id is the only (categorical) user feature.
            userSide = interactions.select("`" + userCol + "`").distinct();
            userFeatureCols = new String[] {userCol};
            userCategoricalCols = new String[] {userCol};
        } else {
            checkKeyColTypeMatch(userSide, interactions, userCol, "user column type mismatch");
        }
        if (itemSide == null) {
            // No item table: the item id is the only (categorical) item feature.
            itemSide = interactions.select("`" + itemCol + "`").distinct();
            itemFeatureCols = new String[] {itemCol};
            itemCategoricalCols = new String[] {itemCol};
        } else {
            checkKeyColTypeMatch(itemSide, interactions, itemCol, "item column type mismatch");
        }

        BatchOperator<?> userItemPairs = interactions.select(new String[] {userCol, itemCol});
        BatchOperator<?> userFeatures = createFeatureVectors(userSide, userCol, userFeatureCols, userCategoricalCols);
        BatchOperator<?> itemFeatures = createFeatureVectors(itemSide, itemCol, itemFeatureCols, itemCategoricalCols);

        // Attach user and item vectors to every interaction. A missing side-table row yields a
        // null vector after the left join, which CheckNotNull turns into an error.
        BatchOperator<?> withUserFeatures =
            joinFeatures(interactions, userFeatures, userCol, USER_FEATURES_COL, envId);
        BatchOperator<?> withAllFeatures =
            joinFeatures(withUserFeatures, itemFeatures, itemCol, ITEM_FEATURES_COL, envId)
                .udf(USER_FEATURES_COL, USER_FEATURES_COL, new CheckNotNull())
                .udf(ITEM_FEATURES_COL, ITEM_FEATURES_COL, new CheckNotNull());

        BatchOperator<?> trainData = ((VectorAssembler) new VectorAssembler().setMLEnvironmentId(envId))
            .setSelectedCols(USER_FEATURES_COL, ITEM_FEATURES_COL)
            .setOutputCol(ALINK_FEATURES_COL)
            .setReservedCols(rateCol)
            .transform(withAllFeatures);

        BatchOperator<?> fmTrainer = createFmTrainer(params, rateCol, envId);
        fmTrainer.linkFrom(trainData);

        setOutputTable(PackBatchOperatorUtil.packBatchOps(
            new BatchOperator<?>[] {fmTrainer, userFeatures, itemFeatures, userItemPairs}).getOutputTable());
        return this;
    }
}
