package com.alibaba.alink.operator.stream.onlinelearning;

import com.alibaba.alink.common.AlinkGlobalConfiguration;
import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.Internal;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpecs;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.TypeCollections;
import com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException;
import com.alibaba.alink.common.mapper.ModelMapper;
import com.alibaba.alink.common.utils.RowUtil;
import com.alibaba.alink.operator.common.evaluation.BinaryClassMetrics;
import com.alibaba.alink.operator.common.evaluation.BinaryMetricsSummary;
import com.alibaba.alink.operator.common.evaluation.ClassificationEvaluationUtil;
import com.alibaba.alink.operator.common.evaluation.EvaluationUtil;
import com.alibaba.alink.operator.common.modelstream.ModelStreamUtils;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.operator.stream.StreamOperator;
import com.alibaba.alink.operator.stream.onlinelearning.BinaryClassModelFilterStreamOp;
import com.alibaba.alink.params.evaluation.EvalBinaryClassParams;
import com.alibaba.alink.params.onlinelearning.BinaryClassModelFilterParams;
import com.alibaba.alink.params.shared.colname.HasLabelCol;
import com.alibaba.alink.params.shared.colname.HasPredictionCol;
import com.alibaba.alink.params.shared.colname.HasReservedColsDefaultAsNull;
import java.sql.Timestamp;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.functions.co.RichCoFlatMapFunction;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;
import org.apache.flink.util.function.TriFunction;

/**
 * Filters a binary-classification model stream by the quality each model version achieves on a
 * labeled data stream.
 *
 * <p>Input 0 is a model stream. Every model version is a group of rows sharing the same value in
 * the model-stream timestamp column, and the count column carries the number of rows in that group.
 * Input 1 is a stream of samples that contain the label column. Each sample is scored by the most
 * recently loaded model version and folded into binary-classification metrics. Once a newer model
 * version has been loaded, the collected metrics are checked:
 * <ul>
 *   <li>AUC must be at least {@code aucThreshold};</li>
 *   <li>accuracy must be at least {@code accuracyThreshold};</li>
 *   <li>log loss must be strictly below {@code logLossThreshold}.</li>
 * </ul>
 * If all three pass, the rows of the previous model version are forwarded unchanged.
 *
 * <p>The output has the same schema as the input model stream. The filtering function is run with
 * parallelism 1.
 *
 * @param <T> concrete operator type, for fluent chaining
 */
@InputPorts(values = {@PortSpec(value = PortType.MODEL_STREAM, opType = PortSpec.OpType.SAME), @PortSpec(value = PortType.DATA, opType = PortSpec.OpType.SAME)})
@OutputPorts(values = {@PortSpec(PortType.MODEL_STREAM)})
@Internal
@ParamSelectColumnSpecs({@ParamSelectColumnSpec(name = "vectorCol", allowedTypeCollections = {TypeCollections.VECTOR_TYPES}), @ParamSelectColumnSpec(name = "labelCol")})
// Display name (Chinese): "binary classification model filter".
@NameCn("二分类模型过滤")
public class BinaryClassModelFilterStreamOp<T extends BinaryClassModelFilterStreamOp<T>> extends StreamOperator<T> {

    /**
     * Builds the model mapper from (raw model schema, data schema, params).
     *
     * <p>It is applied inside the {@link FilterProcess} constructor, so only the resulting mapper
     * (not this function) becomes part of the shipped function object.
     */
    private final TriFunction<TableSchema, TableSchema, Params, ModelMapper> mapperBuilder;

    /**
     * Co-flat-map that does three things:
     * <ul>
     *   <li>assembles model versions from the model stream (input 1);</li>
     *   <li>scores the data stream (input 2) with the latest loaded version;</li>
     *   <li>forwards the previous version when the metrics collected for it pass every threshold.</li>
     * </ul>
     *
     * <p>All state lives in plain instance fields, so this function must run with parallelism 1.
     */
    public static class FilterProcess extends RichCoFlatMapFunction<Row, Row, Row> {
        private final String positiveValue;
        private final double aucThreshold;
        private final double accuracyThreshold;
        private final double logLossThreshold;
        /** Label type, taken from the last column of the raw model schema. */
        private final TypeInformation<?> labelType;
        private final ModelMapper mapper;
        private final int timestampColIndex;
        private final int countColIndex;
        /** Rows (with identifier columns) of the model version currently loaded into {@link #mapper}. */
        private List<Row> model;
        /** Rows of the version loaded before {@link #model}; the candidate that may be emitted. */
        private List<Row> oldModel;
        private transient BinaryMetricsSummary binaryMetricsSummary;
        private transient long[] bin0;
        private transient long[] bin1;
        /** Data rows received before the first model version was loaded. */
        private final List<Row> bufferRows = new ArrayList<>();
        /** Partially received model versions, keyed by their model timestamp. */
        private final Map<Timestamp, List<Row>> buffers = new HashMap<>(0);
        /** Number of model versions loaded so far. */
        private long modelBid = 0L;
        /**
         * Evaluation happens once {@link #modelBid} exceeds this value.
         *
         * <p>It starts at 1 so the first loaded version, which has no predecessor, does not trigger
         * an evaluation.
         */
        private long evalID = 1L;
        /** Label array of the current metrics window; {@code null} until the window sees a valid row. */
        private Object[] recordLabel = null;

        /**
         * Creates the function.
         *
         * @param modelSchema raw model schema (identifier columns removed); its last column's type
         *                    is used as the label type
         * @param dataSchema schema of the data stream
         * @param params operator parameters, including the filter thresholds and the positive label
         * @param mapperBuilder builds the {@link ModelMapper}; applied once here and not retained
         * @param timestampColIndex index of the model-version timestamp column in model-stream rows
         * @param countColIndex index of the row-count column in model-stream rows
         */
        public FilterProcess(TableSchema modelSchema, TableSchema dataSchema, Params params,
                             TriFunction<TableSchema, TableSchema, Params, ModelMapper> mapperBuilder,
                             int timestampColIndex, int countColIndex) {
            this.mapper = mapperBuilder.apply(modelSchema, dataSchema, params);
            TypeInformation<?>[] modelTypes = modelSchema.getFieldTypes();
            this.labelType = modelTypes[modelTypes.length - 1];
            this.positiveValue = (String) params.get(BinaryClassModelFilterParams.POS_LABEL_VAL_STR);
            this.aucThreshold = (Double) params.get(BinaryClassModelFilterParams.AUC_THRESHOLD);
            this.accuracyThreshold = (Double) params.get(BinaryClassModelFilterParams.ACCURACY_THRESHOLD);
            this.logLossThreshold = (Double) params.get(BinaryClassModelFilterParams.LOG_LOSS);
            this.timestampColIndex = timestampColIndex;
            this.countColIndex = countColIndex;
        }

        /** Handles a data row: buffers it until a model exists, otherwise scores it immediately. */
        @Override
        public void flatMap2(Row row, Collector<Row> collector) throws Exception {
            if (modelBid == 0L) {
                bufferRows.add(row);
            } else {
                evalRow(row, collector);
            }
        }

        /**
         * Scores one data row with the current model and folds it into the metrics window.
         *
         * <p>If a newer model version has been loaded since the last evaluation, this method also
         * closes the window, emits {@link #oldModel} when its metrics pass, and resets the metrics.
         */
        private void evalRow(Row row, Collector<Row> collector) throws Exception {
            // Drop column 1 of the mapper output. NOTE(review): presumably the output is
            // (label, prediction, detail), given the reserved/prediction/detail params set in
            // linkFrom. Field 0 is then used as the label.
            Row labelAndDetail = RowUtil.remove(mapper.map(row), 1);
            if (bin0 == null) {
                bin0 = new long[ClassificationEvaluationUtil.DETAIL_BIN_NUMBER];
                bin1 = new long[ClassificationEvaluationUtil.DETAIL_BIN_NUMBER];
            }
            // The null check is on the raw input row. A sample with any null field is still scored
            // above, but it is not counted in the metrics.
            if (EvaluationUtil.checkRowFieldNotNull(row)) {
                TreeMap<Object, Double> labelProbMap = EvaluationUtil.extractLabelProbMap(labelAndDetail, labelType);
                if (recordLabel == null) {
                    recordLabel = (Object[]) ClassificationEvaluationUtil.buildLabelIndexLabelArray(
                        new HashSet<>(labelProbMap.keySet()), true, positiveValue, labelType, true).f1;
                    binaryMetricsSummary = new BinaryMetricsSummary(bin0, bin1, recordLabel, Criteria.INVALID_GAIN, 0L);
                }
                EvaluationUtil.updateBinaryMetricsSummary(labelProbMap, labelAndDetail.getField(0), binaryMetricsSummary);
            }
            if (modelBid > evalID) {
                if (recordLabel != null) {
                    BinaryClassMetrics metrics = binaryMetricsSummary.toMetrics();
                    double auc = metrics.getAuc();
                    double accuracy = metrics.getAccuracy();
                    double logLoss = metrics.getLogLoss();
                    if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
                        System.out.println("auc : " + auc + "     accuracy : " + accuracy + "    logLoss : " + logLoss);
                    }
                    if (auc >= aucThreshold && accuracy >= accuracyThreshold && logLoss < logLossThreshold) {
                        for (Row modelRow : oldModel) {
                            collector.collect(modelRow);
                        }
                    }
                }
                // Catch up with every version loaded since the last evaluation. Advancing by only
                // one would make each following row re-open a one-row window and could emit the
                // same oldModel repeatedly.
                evalID = modelBid;
                // Forces a fresh BinaryMetricsSummary on the next valid row. The histograms it is
                // built over are zeroed here.
                recordLabel = null;
                for (int i = 0; i < ClassificationEvaluationUtil.DETAIL_BIN_NUMBER; i++) {
                    bin0[i] = 0;
                    bin1[i] = 0;
                }
            }
        }

        /**
         * Handles a model row.
         *
         * <p>Rows are buffered per timestamp until the group holds as many rows as its count column
         * states. The completed version is then loaded into the mapper, and any data rows buffered
         * before the first model arrived are scored.
         */
        @Override
        public void flatMap1(Row row, Collector<Row> collector) throws Exception {
            Timestamp timestamp = (Timestamp) row.getField(timestampColIndex);
            long count = (Long) row.getField(countColIndex);

            // Append first, then test completeness, so that a single-row model (count == 1)
            // completes immediately instead of being buffered forever.
            List<Row> pending = buffers.computeIfAbsent(timestamp, ignored -> new ArrayList<>());
            pending.add(row);
            if (pending.size() != count) {
                return;
            }
            buffers.remove(timestamp);

            oldModel = model;
            model = pending;
            List<Row> rawModel = new ArrayList<>(model.size());
            for (Row modelRow : model) {
                rawModel.add(ModelStreamUtils.genRowWithoutIdentifier(modelRow, timestampColIndex, countColIndex));
            }
            mapper.loadModel(rawModel);
            if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
                System.out.println("load model : " + modelBid);
            }
            modelBid++;
            if (!bufferRows.isEmpty()) {
                for (Row buffered : bufferRows) {
                    evalRow(buffered, collector);
                }
                bufferRows.clear();
            }
        }
    }

    /**
     * @param mapperBuilder builds the model mapper from (raw model schema, data schema, params)
     * @param params operator parameters
     */
    public BinaryClassModelFilterStreamOp(TriFunction<TableSchema, TableSchema, Params, ModelMapper> mapperBuilder, Params params) {
        super(params);
        this.mapperBuilder = mapperBuilder;
    }

    /**
     * Links the model stream ({@code inputs[0]}) and the data stream ({@code inputs[1]}).
     *
     * <p>Before building the function, this method sets the prediction, detail and reserved-column
     * parameters so the mapper keeps the label column next to its prediction detail.
     */
    @Override
    @SuppressWarnings("unchecked")
    public T linkFrom(StreamOperator<?>... inputs) {
        checkOpSize(2, inputs);
        StreamOperator<?> modelStream = inputs[0];
        StreamOperator<?> dataStream = inputs[1];
        getParams()
            .set(HasPredictionCol.PREDICTION_COL, "alink_inner_pred")
            .set(EvalBinaryClassParams.PREDICTION_DETAIL_COL, "alink_inner_detail")
            .set(HasReservedColsDefaultAsNull.RESERVED_COLS,
                new String[] {(String) getParams().get(HasLabelCol.LABEL_COL)});
        try {
            TableSchema modelSchema = modelStream.getSchema();
            int timestampColIndex = ModelStreamUtils.findTimestampColIndexWithAssertAndHint(modelSchema);
            int countColIndex = ModelStreamUtils.findCountColIndexWithAssertAndHint(modelSchema);
            FilterProcess filter = new FilterProcess(
                ModelStreamUtils.getRawModelSchema(modelSchema, timestampColIndex, countColIndex),
                dataStream.getSchema(), getParams(), mapperBuilder, timestampColIndex, countColIndex);
            DataStream<Row> filtered = modelStream.getDataStream()
                .connect(dataStream.getDataStream())
                .flatMap(filter)
                .setParallelism(1);
            setOutput(filtered, modelSchema);
            return (T) this;
        } catch (Exception e) {
            // Keep the original message but chain the cause instead of printing and discarding it.
            throw new AkUnclassifiedErrorException(e.toString(), e);
        }
    }
}
