package com.alibaba.alink.operator.stream.onlinelearning;

import com.alibaba.alink.common.AlinkGlobalConfiguration;
import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.Internal;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException;
import com.alibaba.alink.common.mapper.MapperChain;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.common.evaluation.BinaryClassMetrics;
import com.alibaba.alink.operator.common.evaluation.BinaryMetricsSummary;
import com.alibaba.alink.operator.common.evaluation.ClassificationEvaluationUtil;
import com.alibaba.alink.operator.common.evaluation.EvaluationUtil;
import com.alibaba.alink.operator.common.modelstream.ModelStreamUtils;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.operator.stream.StreamOperator;
import com.alibaba.alink.params.evaluation.EvalBinaryClassParams;
import com.alibaba.alink.params.onlinelearning.BinaryClassModelFilterParams;
import com.alibaba.alink.params.shared.colname.HasPredictionDetailCol;
import com.alibaba.alink.params.shared.colname.HasReservedColsDefaultAsNull;
import com.alibaba.alink.pipeline.ModelExporterUtils;
import com.alibaba.alink.pipeline.PipelineStageBase;
import java.sql.Timestamp;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.TreeMap;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.functions.co.RichCoFlatMapFunction;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/**
 * Filters a stream of pipeline models, forwarding only the models whose binary-classification metrics on
 * recently seen data pass the configured thresholds.
 *
 * <p>Input 0 is the model stream, input 1 is the data stream used for evaluation. The most recent
 * {@code NUM_EVAL_SAMPLES} data rows are retained. Whenever all rows of a model (grouped by the model
 * stream's timestamp column, completed by its count column) have arrived, the model is loaded, applied to
 * the retained rows, and its rows are emitted only if {@code auc >= AUC_THRESHOLD},
 * {@code accuracy >= ACCURACY_THRESHOLD} and {@code logLoss < LOG_LOSS}. Models for which no retained row
 * could be evaluated (e.g. no data has arrived yet) are dropped. The filter runs with parallelism 1.
 */
@InputPorts(values = {@PortSpec(value = PortType.MODEL_STREAM, opType = PortSpec.OpType.SAME), @PortSpec(value = PortType.DATA, opType = PortSpec.OpType.SAME)})
@OutputPorts(values = {@PortSpec(PortType.MODEL_STREAM)})
@Internal
@ParamSelectColumnSpec(name = "labelCol")
@NameCn("Pipeline二分类模型过滤")
public class BinaryClassPipelineModelFilterStreamOp extends StreamOperator<BinaryClassPipelineModelFilterStreamOp>
    implements BinaryClassModelFilterParams<BinaryClassPipelineModelFilterStreamOp> {

    /**
     * Co-flat-map that buffers evaluation data (second input) and, for every complete pipeline model
     * (first input), evaluates the model on the buffered data and forwards the model rows only when the
     * metrics pass the thresholds.
     */
    public static class FilterProcess extends RichCoFlatMapFunction<Row, Row, Row> {
        private final String positiveValue;
        private final double aucThreshold;
        private final double accuracyThreshold;
        private final double logLossThreshold;
        private final TypeInformation<?> labelType;
        private final String modelSchemaStr;
        private final String dataSchemaStr;
        private final String labelCol;
        private final int timestampColIndex;
        private final int countColIndex;
        /** Maximum number of recent data rows kept for evaluation. */
        private final int evalBatchSize;

        /** Prediction-detail column of the last pipeline stage; refreshed on every model load. */
        private String predDetailCol;
        /** Index of the label column in the last mapper's output schema. */
        private int indexLabelCol;
        /** Index of the prediction-detail column in the last mapper's output schema. */
        private int indexPredDetailCol;
        /** Mapper chain of the most recently loaded model; null until the first model is loaded. */
        private transient MapperChain mapperChain;
        private transient BinaryMetricsSummary binaryMetricsSummary;
        /** Histogram bins shared with {@link #binaryMetricsSummary}; allocated lazily on first evaluation. */
        private transient long[] bin0;
        private transient long[] bin1;
        /** Model rows grouped by model timestamp until every row of that model has arrived. */
        private final Map<Timestamp, List<Row>> buffers = new HashMap<>();
        /** Number of models loaded so far; only used for progress output. */
        private long modelBid = 0L;
        /** Label array of the model under evaluation; null while no row has been evaluated for it. */
        private Object[] recordLabel = null;
        /** Sliding window holding at most {@code evalBatchSize} of the most recent data rows. */
        Queue<Row> queue = new ArrayDeque<>();

        /**
         * @param modelSchema       schema of the raw model rows (model stream schema without timestamp/count)
         * @param dataSchema        schema of the evaluation data stream
         * @param params            operator params (label col, positive label, thresholds, eval sample count)
         * @param timestampColIndex index of the model timestamp column in the model stream
         * @param countColIndex     index of the model row-count column in the model stream
         * @param labelType         type of the label column in the data stream
         */
        public FilterProcess(TableSchema modelSchema, TableSchema dataSchema, Params params,
                             int timestampColIndex, int countColIndex, TypeInformation<?> labelType) {
            this.modelSchemaStr = TableUtil.schema2SchemaStr(modelSchema);
            this.dataSchemaStr = TableUtil.schema2SchemaStr(dataSchema);
            this.labelCol = params.get(EvalBinaryClassParams.LABEL_COL);
            this.positiveValue = params.get(BinaryClassModelFilterParams.POS_LABEL_VAL_STR);
            this.aucThreshold = params.get(BinaryClassModelFilterParams.AUC_THRESHOLD);
            this.accuracyThreshold = params.get(BinaryClassModelFilterParams.ACCURACY_THRESHOLD);
            this.logLossThreshold = params.get(BinaryClassModelFilterParams.LOG_LOSS);
            this.evalBatchSize = params.get(BinaryClassModelFilterParams.NUM_EVAL_SAMPLES);
            this.timestampColIndex = timestampColIndex;
            this.countColIndex = countColIndex;
            this.labelType = labelType;
        }

        /** Appends a data row to the evaluation window, evicting the oldest row once the window is full. */
        @Override
        public void flatMap2(Row dataRow, Collector<Row> out) throws Exception {
            queue.add(dataRow);
            if (queue.size() > evalBatchSize) {
                queue.remove();
            }
        }

        /**
         * Buffers a model row; once the model is complete, loads and evaluates it and emits its rows if the
         * metrics pass the thresholds.
         */
        @Override
        public void flatMap1(Row modelRow, Collector<Row> out) throws Exception {
            Timestamp timestamp = (Timestamp) modelRow.getField(timestampColIndex);
            long count = (Long) modelRow.getField(countColIndex);

            List<Row> modelRows = buffers.computeIfAbsent(timestamp, k -> new ArrayList<>());
            modelRows.add(modelRow);
            // Comparing after the add also handles single-row models (count == 1).
            if (modelRows.size() < count) {
                return;
            }
            buffers.remove(timestamp);

            List<Row> rawModelRows = new ArrayList<>(modelRows.size());
            for (Row r : modelRows) {
                rawModelRows.add(ModelStreamUtils.genRowWithoutIdentifier(r, timestampColIndex, countColIndex));
            }
            TableSchema outputSchema = loadPipelineModel(rawModelRows);
            indexLabelCol = TableUtil.findColIndex(outputSchema, labelCol);
            indexPredDetailCol = TableUtil.findColIndex(outputSchema, predDetailCol);
            if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
                System.out.println("load model : " + modelBid);
            }
            modelBid++;

            for (Row sample : queue) {
                evalRow(sample);
            }

            // recordLabel stays null when no buffered row could be evaluated; such models are dropped.
            if (recordLabel != null && meetsThresholds()) {
                for (Row r : modelRows) {
                    out.collect(r);
                }
            }
            resetMetrics();
        }

        /** Applies the current model to one data row and accumulates the result into the metrics summary. */
        private void evalRow(Row row) throws Exception {
            if (null == mapperChain) {
                return;
            }
            Row predicted = mapperChain.map(row);
            Row labelAndDetail = Row.of(predicted.getField(indexLabelCol), predicted.getField(indexPredDetailCol));
            if (bin0 == null) {
                bin0 = new long[ClassificationEvaluationUtil.DETAIL_BIN_NUMBER];
                bin1 = new long[ClassificationEvaluationUtil.DETAIL_BIN_NUMBER];
            }
            // Only the label and prediction detail are evaluated, so those are the fields that must be non-null.
            if (!EvaluationUtil.checkRowFieldNotNull(labelAndDetail)) {
                return;
            }
            TreeMap<Object, Double> labelProbMap = EvaluationUtil.extractLabelProbMap(labelAndDetail, labelType);
            if (null == recordLabel) {
                recordLabel = (Object[]) ClassificationEvaluationUtil.buildLabelIndexLabelArray(
                    new HashSet<>(labelProbMap.keySet()), true, positiveValue, labelType, true).f1;
                // NOTE(review): Criteria.INVALID_GAIN is passed as the initial log-loss accumulator; this looks
                // like a decompiler constant substitution — confirm the intended value (likely 0.0).
                binaryMetricsSummary = new BinaryMetricsSummary(bin0, bin1, recordLabel, Criteria.INVALID_GAIN, 0L);
            }
            EvaluationUtil.updateBinaryMetricsSummary(labelProbMap, labelAndDetail.getField(0), binaryMetricsSummary);
        }

        /** Computes the accumulated metrics and checks them against the configured thresholds. */
        private boolean meetsThresholds() {
            BinaryClassMetrics metrics = binaryMetricsSummary.toMetrics();
            double auc = metrics.getAuc();
            double accuracy = metrics.getAccuracy();
            double logLoss = metrics.getLogLoss();
            if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
                System.out.println("auc : " + auc + "     accuracy : " + accuracy + "    logLoss : " + logLoss);
            }
            return auc >= aucThreshold && accuracy >= accuracyThreshold && logLoss < logLossThreshold;
        }

        /** Clears per-model evaluation state; safe to call before any row has been evaluated. */
        private void resetMetrics() {
            recordLabel = null;
            binaryMetricsSummary = null;
            if (bin0 != null) {
                Arrays.fill(bin0, 0L);
                Arrays.fill(bin1, 0L);
            }
        }

        /**
         * Builds and opens a mapper chain for the given raw pipeline-model rows, closing the previous chain.
         *
         * <p>The last stage's reserved columns are restricted to the label column, so the returned output
         * schema contains the label plus the prediction output.
         *
         * @return output schema of the last mapper in the new chain
         */
        private TableSchema loadPipelineModel(List<Row> modelRows) throws Exception {
            List<Tuple3<PipelineStageBase<?>, TableSchema, List<Row>>> stages =
                ModelExporterUtils.loadStagesFromPipelineModel(modelRows, TableUtil.schemaStr2Schema(modelSchemaStr));
            Params lastStageParams = stages.get(stages.size() - 1).f0.getParams();
            lastStageParams.set(HasReservedColsDefaultAsNull.RESERVED_COLS, new String[] {labelCol});
            predDetailCol = lastStageParams.get(HasPredictionDetailCol.PREDICTION_DETAIL_COL);

            MapperChain previous = mapperChain;
            mapperChain = ModelExporterUtils.loadMapperListFromStages(stages, TableUtil.schemaStr2Schema(dataSchemaStr));
            mapperChain.open();
            // Release resources held by the replaced model's mappers.
            if (previous != null) {
                previous.close();
            }
            return mapperChain.getMappers()[mapperChain.getMappers().length - 1].getOutputSchema();
        }

        /** Closes the mapper chain of the last loaded model. */
        @Override
        public void close() throws Exception {
            try {
                if (mapperChain != null) {
                    mapperChain.close();
                    mapperChain = null;
                }
            } finally {
                super.close();
            }
        }
    }

    public BinaryClassPipelineModelFilterStreamOp() {
    }

    public BinaryClassPipelineModelFilterStreamOp(Params params) {
        super(params);
    }

    /**
     * Connects the model stream (input 0) with the data stream (input 1) and filters the models.
     *
     * @throws AkUnclassifiedErrorException if the inputs cannot be wired (e.g. missing columns)
     */
    @Override
    public BinaryClassPipelineModelFilterStreamOp linkFrom(StreamOperator<?>... inputs) {
        checkOpSize(2, inputs);
        StreamOperator<?> modelStream = inputs[0];
        StreamOperator<?> dataStream = inputs[1];
        try {
            TableSchema modelStreamSchema = modelStream.getSchema();
            int timestampColIndex = ModelStreamUtils.findTimestampColIndexWithAssertAndHint(modelStreamSchema);
            int countColIndex = ModelStreamUtils.findCountColIndexWithAssertAndHint(modelStreamSchema);

            TableSchema dataSchema = dataStream.getSchema();
            TypeInformation<?> labelType =
                dataSchema.getFieldTypes()[TableUtil.findColIndexWithAssertAndHint(dataSchema, getLabelCol())];

            FilterProcess filter = new FilterProcess(
                ModelStreamUtils.getRawModelSchema(modelStreamSchema, timestampColIndex, countColIndex),
                dataSchema, getParams(), timestampColIndex, countColIndex, labelType);

            // Parallelism 1: the filter keeps a single evaluation window and a single per-model buffer.
            DataStream<Row> filtered = modelStream.getDataStream()
                .connect(dataStream.getDataStream())
                .flatMap(filter)
                .setParallelism(1);
            setOutput(filtered, modelStreamSchema);
            return this;
        } catch (Exception e) {
            throw new AkUnclassifiedErrorException(e.toString(), e);
        }
    }
}
