package com.alibaba.alink.operator.batch.graph;

import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpecs;
import com.alibaba.alink.common.annotation.PortDesc;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.graph.memory.MemoryComputeFunction;
import com.alibaba.alink.operator.batch.graph.memory.MemoryVertexCentricIteration;
import com.alibaba.alink.params.graph.PageRankParams;
import java.util.Iterator;
import java.util.List;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;
import org.apache.flink.util.Preconditions;

/**
 * PageRank over a single edge table.
 *
 * <p>The edges are read from the columns named by {@code edgeSourceCol} and {@code edgeTargetCol}.
 * An optional {@code edgeWeightCol} supplies weights; the source and target columns must have the
 * same type. Ranks are computed with {@link MemoryVertexCentricIteration} and then divided by
 * their total, so the emitted scores sum to 1.
 *
 * <p>Output schema: {@code vertex} (same type as the source/target columns) and {@code label}
 * (the normalized score, a double).
 */
@InputPorts(values = {@PortSpec(value = PortType.DATA, opType = PortSpec.OpType.BATCH, desc = PortDesc.GRPAH_EDGES)})
@OutputPorts(values = {@PortSpec(PortType.DATA)})
@ParamSelectColumnSpecs({
    @ParamSelectColumnSpec(name = "edgeSourceCol", portIndices = 0),
    @ParamSelectColumnSpec(name = "edgeTargetCol", portIndices = 0),
    @ParamSelectColumnSpec(name = "edgeWeightCol", portIndices = 0)
})
@NameCn("PageRank算法")
@NameEn("PageRank")
public class PageRankBatchOp extends BatchOperator<PageRankBatchOp> implements PageRankParams<PageRankBatchOp> {

    /** Broadcast-variable name under which the total rank mass is shipped to the normalizer. */
    private static final String SUM_PAGE_RANK = "sumPageRank";

    /**
     * Compute function that drives the PageRank iteration inside {@link MemoryVertexCentricIteration}.
     *
     * <p>In superstep 1 every vertex pushes its value along its edges, scaled by each edge value.
     * The edge values are set up by {@code normalizeEdgeValuesByVertex()} in
     * {@link #initEdgesValues()}. In later supersteps only the change since the previous superstep
     * is pushed. A vertex stops sending once that change, relative to its current value, is no
     * longer above {@code epsilon}. Incoming messages are scaled by the damping factor before they
     * are added to the receiving vertex.
     */
    public static class PageRankComputeFunction extends MemoryComputeFunction {
        private final double dampingFactor;
        private final double epsilon;

        /**
         * @param dampingFactor factor applied to every incoming message
         * @param epsilon       relative-change threshold below which a vertex stops propagating
         */
        public PageRankComputeFunction(double dampingFactor, double epsilon) {
            this.dampingFactor = dampingFactor;
            this.epsilon = epsilon;
        }

        @Override
        public void gatherMessage(long vertexId, double message) {
            incCurVertexValue(vertexId, message * this.dampingFactor);
        }

        /**
         * Sends this vertex's contribution to its neighbours.
         *
         * @param edges    (neighbour id, edge value) pairs of this vertex
         * @param vertexId the vertex currently being processed
         */
        @Override
        public void sendMessage(Iterator<Tuple2<Long, Double>> edges, long vertexId) {
            if (getSuperStep() == 1) {
                if (!edges.hasNext()) {
                    // Vertex without edges. Target -1 presumably addresses every vertex, so its mass
                    // is spread evenly. NOTE(review): confirm against MemoryComputeFunction.
                    sendMessageTo(-1L, 1.0d / getGraphContext().numVertex);
                    return;
                }
                double value = getLastStepVertexValue(vertexId);
                while (edges.hasNext()) {
                    Tuple2<Long, Double> edge = edges.next();
                    sendMessageTo(edge.f0, value * edge.f1);
                }
                return;
            }

            // After the first superstep, only the delta since the previous superstep is propagated.
            double curValue = getCurVertexValue(vertexId);
            double delta = curValue - getLastStepVertexValue(vertexId);
            if (Math.abs(delta) / curValue > this.epsilon) {
                if (!edges.hasNext()) {
                    sendMessageTo(-1L, delta / getGraphContext().numVertex);
                    return;
                }
                while (edges.hasNext()) {
                    Tuple2<Long, Double> edge = edges.next();
                    sendMessageTo(edge.f0, delta * edge.f1);
                }
            }
        }

        @Override
        public void initVerticesValues() {
            setAllVertexValues(1.0d);
        }

        @Override
        public void initEdgesValues() {
            normalizeEdgeValuesByVertex();
        }
    }

    public PageRankBatchOp(Params params) {
        super(params);
    }

    public PageRankBatchOp() {
        this(new Params());
    }

    /**
     * Runs PageRank on the edge table and sets the normalized (vertex, label) result as output.
     *
     * @param inputs exactly one operator holding the edge table
     * @return this operator
     * @throws IllegalStateException if the source and target columns have different types
     */
    @Override
    public PageRankBatchOp linkFrom(BatchOperator<?>... inputs) {
        BatchOperator<?> in = checkAndGetFirst(inputs);
        String sourceCol = getEdgeSourceCol();
        String targetCol = getEdgeTargetCol();

        String[] colNames = in.getColNames();
        int sourceIdx = TableUtil.findColIndexWithAssertAndHint(colNames, sourceCol);
        int targetIdx = TableUtil.findColIndexWithAssertAndHint(colNames, targetCol);

        String weightCol = getEdgeWeightCol();
        boolean hasWeight = weightCol != null;
        String[] selectedCols = hasWeight
            ? new String[] {sourceCol, targetCol, weightCol}
            : new String[] {sourceCol, targetCol};

        TypeInformation<?>[] colTypes = in.getColTypes();
        TypeInformation<?> vertexType = colTypes[sourceIdx];
        // equals() rather than ==: TypeInformation instances are not guaranteed to be singletons.
        Preconditions.checkState(vertexType.equals(colTypes[targetIdx]),
            "The source and target should be the same type.");

        DataSet<Row> vertices = MemoryVertexCentricIteration.runAndGetVertices(
            in.select(selectedCols).getDataSet(),
            vertexType,
            hasWeight,
            false,
            getMLEnvironmentId().longValue(),
            getMaxIter().intValue(),
            new PageRankComputeFunction(getDampingFactor().doubleValue(), getEpsilon().doubleValue()));

        DataSet<Double> totalRank = vertices
            .map(new ExtractRank())
            .reduce(new SumDoubles());

        setOutput(
            vertices.map(new NormalizeByTotal()).withBroadcastSet(totalRank, SUM_PAGE_RANK),
            new String[] {"vertex", "label"},
            new TypeInformation<?>[] {vertexType, BasicTypeInfo.DOUBLE_TYPE_INFO});
        return this;
    }

    /** Projects a (vertex, rank) row onto its rank (field 1). */
    private static class ExtractRank implements MapFunction<Row, Double> {
        @Override
        public Double map(Row row) throws Exception {
            return (Double) row.getField(1);
        }
    }

    /** Adds two ranks together; used to compute the total rank mass. */
    private static class SumDoubles implements ReduceFunction<Double> {
        @Override
        public Double reduce(Double left, Double right) throws Exception {
            return left + right;
        }
    }

    /** Divides each vertex's rank by the broadcast total so the scores sum to 1. */
    private static class NormalizeByTotal extends RichMapFunction<Row, Row> {
        private double sum;

        @Override
        public void open(Configuration parameters) throws Exception {
            super.open(parameters);
            List<Double> sums = getRuntimeContext().getBroadcastVariable(SUM_PAGE_RANK);
            // open() runs on every parallel instance even when the graph is empty. In that case the
            // reduce emits nothing and map() is never called, so fall back to a neutral divisor
            // instead of failing on get(0).
            this.sum = sums.isEmpty() ? 1.0d : sums.get(0);
        }

        @Override
        public Row map(Row row) throws Exception {
            return Row.of(row.getField(0), (Double) row.getField(1) / this.sum);
        }
    }
}
