package com.alibaba.alink.operator.local.utils;

import com.alibaba.alink.common.MTable;
import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.Internal;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.PortDesc;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.ReservedColsWithFirstInputSpec;
import com.alibaba.alink.common.exceptions.AkIllegalDataException;
import com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException;
import com.alibaba.alink.common.exceptions.ExceptionWithErrorCode;
import com.alibaba.alink.common.mapper.FlatMapper;
import com.alibaba.alink.common.utils.RowCollector;
import com.alibaba.alink.operator.local.AlinkLocalSession;
import com.alibaba.alink.operator.local.LocalOperator;
import com.alibaba.alink.operator.local.utils.FlatMapLocalOp;
import com.alibaba.alink.params.shared.HasNumThreads;
import java.util.ArrayList;
import java.util.List;
import java.util.function.BiFunction;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Local operator that runs a {@link FlatMapper} over every row of its first input and
 * publishes the collected rows as its output table.
 *
 * <p>The mapper is created lazily in {@link #linkFrom(LocalOperator[])} from the input schema
 * and this operator's parameters, using the builder function supplied at construction time.
 *
 * @param <T> the concrete operator type, returned by {@link #linkFrom(LocalOperator[])}
 */
@InputPorts(values = {@PortSpec(PortType.DATA)})
@OutputPorts(values = {@PortSpec(value = PortType.DATA, desc = PortDesc.OUTPUT_RESULT)})
@Internal
@ReservedColsWithFirstInputSpec
public class FlatMapLocalOp<T extends FlatMapLocalOp<T>> extends LocalOperator<T> {
    private static final Logger LOG = LoggerFactory.getLogger(FlatMapLocalOp.class);

    /** Factory that builds the mapper from the input schema and the operator parameters. */
    protected final BiFunction<TableSchema, Params, FlatMapper> mapperBuilder;

    /**
     * Creates the operator.
     *
     * @param biFunction factory producing the {@link FlatMapper} for a given input schema and params
     * @param params     operator parameters, forwarded to the super class and to the mapper factory
     */
    public FlatMapLocalOp(BiFunction<TableSchema, Params, FlatMapper> biFunction, Params params) {
        super(params);
        this.mapperBuilder = biFunction;
    }

    /**
     * Builds the mapper for the first input, applies it to all input rows and stores the result
     * as this operator's output table.
     *
     * @param localOperatorArr the upstream operators; only the first one is used
     * @return this operator
     * @throws ExceptionWithErrorCode        rethrown unchanged when raised during mapping
     * @throws AkUnclassifiedErrorException wrapping any other exception raised during mapping
     */
    @Override // com.alibaba.alink.operator.local.LocalOperator
    public T linkFrom(LocalOperator<?>... localOperatorArr) {
        LocalOperator<?> input = checkAndGetFirst(localOperatorArr);
        try {
            FlatMapper mapper = this.mapperBuilder.apply(input.getSchema(), getParams());
            mapper.open();
            try {
                List<Row> outputRows = execFlatMapper(input, mapper, getParams());
                setOutputTable(new MTable(outputRows, mapper.getOutputSchema()));
            } finally {
                // Release the mapper even when mapping fails; previously close() was skipped on error.
                mapper.close();
            }
            // 'this' is statically FlatMapLocalOp<T>; the self-bounded type parameter makes the cast safe
            // for correctly declared subclasses.
            @SuppressWarnings("unchecked")
            T self = (T) this;
            return self;
        } catch (ExceptionWithErrorCode e) {
            throw e;
        } catch (Exception e) {
            throw new AkUnclassifiedErrorException("Error. ", e);
        }
    }

    /**
     * Applies {@code flatMapper} to every row of {@code localOperator}'s output table, splitting the
     * rows into contiguous ranges that are submitted as separate tasks.
     *
     * <p>The number of ranges is {@link HasNumThreads#NUM_THREADS} when present in {@code params},
     * otherwise {@link LocalOperator#getDefaultNumThreads()}. Results of the ranges are concatenated
     * in range order.
     *
     * @param localOperator operator whose output table supplies the input rows
     * @param flatMapper    the (already opened) mapper to apply; it is invoked from the submitted tasks
     * @param params        parameters consulted for the thread count
     * @return all rows emitted by the mapper, ordered by input range
     * @throws AkIllegalDataException when the mapper fails on a row (raised inside the task)
     */
    protected static List<Row> execFlatMapper(LocalOperator<?> localOperator, FlatMapper flatMapper, Params params) {
        final int numThreads = params.contains(HasNumThreads.NUM_THREADS)
            ? params.get(HasNumThreads.NUM_THREADS)
            : LocalOperator.getDefaultNumThreads();

        final AlinkLocalSession.TaskRunner taskRunner = new AlinkLocalSession.TaskRunner();

        // One result slot per range; each task writes only its own slot.
        @SuppressWarnings("unchecked")
        final List<Row>[] partitions = (List<Row>[]) new List[numThreads];

        final MTable inputTable = localOperator.getOutputTable();
        final int numRow = inputTable.getNumRow();

        for (int i = 0; i < numThreads; i++) {
            final int partitionIndex = i;
            final int startPos = (int) AlinkLocalSession.DISTRIBUTOR.startPos(i, numThreads, numRow);
            final int localRowCnt = (int) AlinkLocalSession.DISTRIBUTOR.localRowCnt(i, numThreads, numRow);
            if (localRowCnt <= 0) {
                // Nothing assigned to this range; its slot stays null and is skipped when merging.
                continue;
            }
            taskRunner.submit(() -> {
                RowCollector rowCollector = new RowCollector();
                final int endPos = startPos + localRowCnt;
                for (int rowIndex = startPos; rowIndex < endPos; rowIndex++) {
                    try {
                        flatMapper.flatMap(inputTable.getRow(rowIndex), rowCollector);
                    } catch (Exception e) {
                        // The cause is logged here because the rethrown exception carries only the row text.
                        LOG.error("Execute mapper error.", e);
                        throw new AkIllegalDataException(
                            "FlatMap error on the data : " + inputTable.getRow(rowIndex).toString());
                    }
                }
                partitions[partitionIndex] = rowCollector.getRows();
            });
        }
        taskRunner.join();

        List<Row> merged = new ArrayList<>();
        for (List<Row> partition : partitions) {
            if (partition != null) {
                merged.addAll(partition);
            }
        }
        return merged;
    }
}
