package com.alibaba.alink.operator.local.feature;

import com.alibaba.alink.common.MTable;
import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.SelectedColsWithFirstInputSpec;
import com.alibaba.alink.common.utils.RowCollector;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.common.feature.OneHotModelDataConverter;
import com.alibaba.alink.operator.common.feature.OneHotModelInfo;
import com.alibaba.alink.operator.common.io.types.FlinkTypeConverter;
import com.alibaba.alink.operator.local.LocalOperator;
import com.alibaba.alink.operator.local.lazy.WithModelInfoLocalOp;
import com.alibaba.alink.params.dataproc.HasSelectedColTypes;
import com.alibaba.alink.params.feature.HasEnableElse;
import com.alibaba.alink.params.feature.OneHotTrainParams;
import com.alibaba.alink.params.shared.colname.HasSelectedCols;
import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/**
 * Local (in-memory) trainer for one-hot encoding.
 *
 * <p>For every selected column it counts the occurrences of each distinct non-null value (keyed by
 * {@code String.valueOf(value)}) and assigns consecutive indices, starting at 0, to the values whose
 * count reaches that column's discrete threshold.
 *
 * <p>Outputs:
 * <ul>
 *   <li>Main output: the model rows produced by {@link OneHotModelDataConverter}. Its meta params
 *       hold the selected columns, their type strings and the {@code enableElse} flag.</li>
 *   <li>Side output 0: a table {@code (selectedCol STRING, distinctTokenNumber LONG)} giving the
 *       number of indexed tokens per selected column.</li>
 * </ul>
 */
@InputPorts(values = {@PortSpec(value = PortType.DATA, opType = PortSpec.OpType.BATCH)})
@OutputPorts(values = {@PortSpec(PortType.MODEL)})
@SelectedColsWithFirstInputSpec
@NameCn("独热编码训练")
@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.feature.OneHotEncoder")
public final class OneHotTrainLocalOp extends LocalOperator<OneHotTrainLocalOp>
    implements OneHotTrainParams<OneHotTrainLocalOp>,
    WithModelInfoLocalOp<OneHotModelInfo, OneHotTrainLocalOp, OneHotModelInfoLocalOp> {

    /** Creates the operator without explicit params (passes {@code null} to the superclass). */
    public OneHotTrainLocalOp() {
        super(null);
    }

    /**
     * Creates the operator with the given params.
     *
     * @param params operator parameters (selected columns, discrete thresholds, ...)
     */
    public OneHotTrainLocalOp(Params params) {
        super(params);
    }

    /**
     * Trains the one-hot model from the first input operator.
     *
     * @param inputs upstream operators; only the first one is used
     * @return this operator, with its main and side output tables set
     * @throws IllegalArgumentException if an explicit per-column threshold array is shorter than the
     *     number of selected columns
     */
    @Override
    public OneHotTrainLocalOp linkFrom(LocalOperator<?>... inputs) {
        LocalOperator<?> in = checkAndGetFirst(inputs);
        String[] selectedCols = getSelectedCols();
        int numCols = selectedCols.length;

        // Record the type string of every selected column in the model meta.
        String[] selectedColTypes = new String[numCols];
        for (int i = 0; i < numCols; i++) {
            selectedColTypes[i] = FlinkTypeConverter.getTypeString(
                TableUtil.findColTypeWithAssertAndHint(in.getSchema(), selectedCols[i]));
        }

        int[] thresholds = resolveThresholds(numCols);
        boolean enableElse = isEnableElse(thresholds);

        // Nulls are ignored (ignoreNull = true), so no null token is ever produced here.
        List<Tuple3<Integer, String, Long>> indexedToken = indexedToken(
            in.select(selectedCols).getOutputTable().getRows(), numCols, true, thresholds);

        Params meta = new Params()
            .set(HasSelectedCols.SELECTED_COLS, selectedCols)
            .set(HasSelectedColTypes.SELECTED_COL_TYPES, selectedColTypes)
            .set(HasEnableElse.ENABLE_ELSE, enableElse);
        RowCollector rowCollector = new RowCollector();
        new OneHotModelDataConverter().save2(Tuple2.of(meta, indexedToken), rowCollector);

        setOutputTable(new MTable(rowCollector.getRows(), new OneHotModelDataConverter().getModelSchema()));
        setSideOutputTables(new MTable[] {distinctTokenNumberTable(selectedCols, indexedToken)});
        return this;
    }

    /**
     * Returns a model-info operator built from this operator's params and linked to this operator.
     */
    @Override
    public OneHotModelInfoLocalOp getModelInfoLocalOp() {
        return new OneHotModelInfoLocalOp(getParams()).linkFrom(this);
    }

    /**
     * Resolves the per-column discrete thresholds.
     *
     * <p>The explicit array param is used when it is set. Otherwise the scalar threshold is
     * replicated for every column.
     *
     * @param numCols number of selected columns
     * @return one threshold per column (an explicit array may contain extra trailing entries)
     * @throws IllegalArgumentException if the explicit array has fewer than {@code numCols} entries
     */
    private int[] resolveThresholds(int numCols) {
        if (getParams().contains(OneHotTrainParams.DISCRETE_THRESHOLDS_ARRAY)) {
            int[] thresholds = Arrays.stream(getDiscreteThresholdsArray())
                .mapToInt(Number::intValue)
                .toArray();
            if (thresholds.length < numCols) {
                throw new IllegalArgumentException(
                    "discreteThresholdsArray has " + thresholds.length
                        + " entries but " + numCols + " columns are selected.");
            }
            return thresholds;
        }
        int[] thresholds = new int[numCols];
        Arrays.fill(thresholds, getDiscreteThresholds().intValue());
        return thresholds;
    }

    /** Returns {@code true} when at least one column threshold is positive. */
    private static boolean isEnableElse(int[] thresholds) {
        return Arrays.stream(thresholds).anyMatch(t -> t > 0);
    }

    /**
     * Assigns indices to frequent tokens, column by column.
     *
     * <p>For column {@code col}, a value is kept when its occurrence count is at least
     * {@code thresholds[col]}. When nulls are not ignored and their count is non-zero and reaches
     * the threshold, a {@code null} token takes index 0. The kept non-null tokens then take the
     * following consecutive indices, in {@link HashMap} iteration order.
     *
     * @param rows input rows; only fields {@code 0 .. numCols-1} are read
     * @param numCols number of columns to index (taken from the caller rather than the first row,
     *     so empty input is handled)
     * @param ignoreNull whether null fields are skipped instead of counted as a token
     * @param thresholds per-column minimum occurrence count; must have at least {@code numCols}
     *     entries
     * @return {@code (columnIndex, token, tokenIndex)} triples
     */
    private static List<Tuple3<Integer, String, Long>> indexedToken(
        List<Row> rows, int numCols, boolean ignoreNull, int[] thresholds) {
        List<Tuple3<Integer, String, Long>> tokens = new ArrayList<>();
        for (int col = 0; col < numCols; col++) {
            long nullCount = 0;
            Map<String, Long> counts = new HashMap<>();
            for (Row row : rows) {
                Object value = row.getField(col);
                if (value != null) {
                    counts.merge(String.valueOf(value), 1L, Long::sum);
                } else if (!ignoreNull) {
                    nullCount++;
                }
            }

            long nextIndex = 0;
            if (nullCount != 0 && nullCount >= thresholds[col]) {
                tokens.add(Tuple3.of(col, (String) null, nextIndex));
                nextIndex++;
            }
            for (Map.Entry<String, Long> entry : counts.entrySet()) {
                if (entry.getValue() >= thresholds[col]) {
                    tokens.add(Tuple3.of(col, entry.getKey(), nextIndex));
                    nextIndex++;
                }
            }
        }
        return tokens;
    }

    /**
     * Builds the side-output table holding the number of indexed tokens for each selected column.
     *
     * <p>Tokens are counted directly, so a column whose values all fall below the threshold reports
     * 0. The previous "max index + 1" computation reported 1 for such columns.
     *
     * @param selectedCols selected column names; row order of the result follows this array
     * @param tokens {@code (columnIndex, token, tokenIndex)} triples from {@link #indexedToken}
     * @return table with columns {@code selectedCol} (STRING) and {@code distinctTokenNumber} (LONG)
     */
    private static MTable distinctTokenNumberTable(
        String[] selectedCols, List<Tuple3<Integer, String, Long>> tokens) {
        long[] tokenCounts = new long[selectedCols.length];
        for (Tuple3<Integer, String, Long> token : tokens) {
            tokenCounts[token.f0]++;
        }
        List<Row> rows = new ArrayList<>(selectedCols.length);
        for (int i = 0; i < selectedCols.length; i++) {
            rows.add(Row.of(selectedCols[i], tokenCounts[i]));
        }
        return new MTable(rows, new String[] {"selectedCol", "distinctTokenNumber"},
            new TypeInformation<?>[] {Types.STRING, Types.LONG});
    }
}
