package com.alibaba.alink.operator.batch.graph;

import com.alibaba.alink.common.MLEnvironmentFactory;
import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpecs;
import com.alibaba.alink.common.annotation.PortDesc;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.utils.GraphTransformUtils;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.params.graph.TreeDepthParams;
import java.util.Iterator;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.graph.Edge;
import org.apache.flink.graph.Graph;
import org.apache.flink.graph.Vertex;
import org.apache.flink.graph.asm.degree.annotate.directed.VertexInDegree;
import org.apache.flink.graph.pregel.ComputeFunction;
import org.apache.flink.graph.pregel.MessageCombiner;
import org.apache.flink.graph.pregel.MessageIterator;
import org.apache.flink.graph.pregel.VertexCentricConfiguration;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.LongValue;
import org.apache.flink.types.Row;

/**
 * Computes, for every vertex of a weighted forest given as an edge list, the root of the tree it
 * belongs to and its depth, i.e. the sum of the edge weights on the path from that root.
 *
 * <p>Edges are directed from {@code edgeSourceCol} (parent) to {@code edgeTargetCol} (child):
 * <ul>
 *   <li>A vertex without incoming edges is treated as a root.</li>
 *   <li>Every other vertex must have exactly one incoming edge.</li>
 *   <li>Every edge weight must be positive. This is checked against {@code Criteria.INVALID_GAIN},
 *       which the code treats as zero.</li>
 * </ul>
 * When {@code edgeWeightCol} is not set, the weights are whatever
 * {@code GraphUtilsWithString.inputType2longEdge} produces. NOTE(review): presumably a constant —
 * confirm.
 *
 * <p>Output columns:
 * <ul>
 *   <li>{@code vertices} and {@code root}, both typed like the source column;</li>
 *   <li>{@code treeDepth}, typed {@code DOUBLE}.</li>
 * </ul>
 * The job fails if a vertex has several parents, if a weight is not positive, or if a depth cannot
 * be resolved within {@code maxIter} supersteps (cycle, or {@code maxIter} too small).
 */
@InputPorts(values = {@PortSpec(value = PortType.DATA, opType = PortSpec.OpType.BATCH, desc = PortDesc.GRPAH_EDGES)})
@OutputPorts(values = {@PortSpec(value = PortType.DATA, desc = PortDesc.OUTPUT_RESULT)})
@ParamSelectColumnSpecs({
    // All edge columns are read from input port 0, the only input port declared above.
    @ParamSelectColumnSpec(name = "edgeSourceCol", portIndices = 0),
    @ParamSelectColumnSpec(name = "edgeTargetCol", portIndices = 0),
    @ParamSelectColumnSpec(name = "edgeWeightCol", portIndices = 0)
})
@NameCn("树深度")
@NameEn("Tree Depth")
public class TreeDepthBatchOp extends BatchOperator<TreeDepthBatchOp> implements TreeDepthParams<TreeDepthBatchOp> {
    private static final long serialVersionUID = -6574485904046547006L;

    /**
     * Vertex-centric (Pregel) computation of root and depth for every vertex of a forest.
     *
     * <p>Each vertex carries a value {@code (self, pointer, dist)}:
     * <ul>
     *   <li>root: {@code (v, v, ROOT_DEPTH)};</li>
     *   <li>non-root before the first superstep: {@code (v, v, UNRESOLVED_DEPTH)};</li>
     *   <li>unresolved non-root: {@code pointer} is an ancestor and {@code dist} is the negated
     *       weighted distance to it, so {@code dist < ROOT_DEPTH};</li>
     *   <li>resolved non-root: {@code pointer} is the root and {@code dist >= ROOT_DEPTH} is the
     *       depth.</li>
     * </ul>
     * Unresolved vertices repeatedly ask their current pointer for its value. They then either
     * adopt its resolved depth or jump to its pointer (pointer jumping), so each request/reply
     * round roughly doubles how far up the tree a vertex can see.
     */
    public static class TreeDepth {
        /**
         * Depth of a root, and the threshold separating resolved ({@code >=}) from unresolved
         * ({@code <}) vertex values. The value is {@code Criteria.INVALID_GAIN}, kept for behavior
         * compatibility. NOTE(review): the checks below only make sense if this is 0.0 — confirm,
         * and consider a literal.
         */
        private static final double ROOT_DEPTH = Criteria.INVALID_GAIN;

        /** Marker depth for non-root vertices before their first superstep. It only has to differ from ROOT_DEPTH. */
        private static final double UNRESOLVED_DEPTH = -1.0d;

        /** Maximum number of supersteps of the vertex-centric iteration. */
        public Integer maxIter;

        /**
         * Pregel compute function implementing the request/reply pointer jumping.
         *
         * <p>Messages have the same {@code (self, pointer, dist)} layout as vertex values:
         * <ul>
         *   <li>A message whose {@code pointer} equals the receiver is a request from a
         *       descendant.</li>
         *   <li>Any other message is the reply of the receiver's current pointer.</li>
         * </ul>
         */
        public static class Execute extends ComputeFunction<Long, Tuple3<Long, Long, Double>, Double, Tuple3<Long, Long, Double>> {
            private static final long serialVersionUID = -2503583975560433984L;

            /**
             * Processes one vertex in one superstep.
             *
             * @param vertex   vertex whose value has the {@code (self, pointer, dist)} layout
             * @param messages requests from descendants and/or the reply of this vertex's pointer
             * @throws IllegalArgumentException if a non-root vertex has more than one parent
             */
            @Override
            public void compute(Vertex<Long, Tuple3<Long, Long, Double>> vertex,
                                MessageIterator<Tuple3<Long, Long, Double>> messages) throws Exception {
                // First step of a non-root vertex: it still points to itself with a non-root depth.
                // It must link to its unique parent, which is its only out-edge after ReverseEdge.
                if (vertex.f1.f1.equals(vertex.f0)
                        && vertex.f1.f1.equals(vertex.f1.f0)
                        && vertex.f1.f2 != ROOT_DEPTH) {
                    boolean linked = false;
                    for (Edge<Long, Double> edge : getEdges()) {
                        if (linked) {
                            throw new IllegalArgumentException(
                                "illegal input!!! Vertex " + vertex.f0 + " has more than one incoming edge.");
                        }
                        Tuple3<Long, Long, Double> link = new Tuple3<>(vertex.f1.f0, edge.f1, -edge.f2);
                        sendMessageTo(edge.getTarget(), link);
                        setNewVertexValue(link);
                        linked = true;
                    }
                    return;
                }

                for (Tuple3<Long, Long, Double> msg : messages) {
                    if (msg.f1.equals(vertex.f0)) {
                        // A descendant points at us: reply with our current value.
                        sendMessageTo(msg.f0, vertex.f1);
                    } else if (msg.f2 >= ROOT_DEPTH) {
                        // Our pointer is resolved. depth = its depth + our distance to it
                        // (vertex.f1.f2 holds the negated distance).
                        Tuple3<Long, Long, Double> resolved =
                            new Tuple3<>(vertex.f0, msg.f1, msg.f2 - vertex.f1.f2);
                        setNewVertexValue(resolved);
                    } else {
                        // Our pointer is unresolved: jump to its pointer, accumulating the
                        // negated distance, then ask the new pointer.
                        Tuple3<Long, Long, Double> jumped =
                            new Tuple3<>(vertex.f0, msg.f1, msg.f2 + vertex.f1.f2);
                        setNewVertexValue(jumped);
                        sendMessageTo(jumped.f1, jumped);
                    }
                }
            }
        }

        /** Extracts the final vertex value and fails if its depth was never resolved. */
        public static class JudgeTupleIllegal implements MapFunction<Vertex<Long, Tuple3<Long, Long, Double>>, Tuple3<Long, Long, Double>> {
            private static final long serialVersionUID = 858956933724773542L;

            /**
             * @param vertex vertex after the iteration
             * @return its {@code (vertex, root, depth)} value
             * @throws IllegalArgumentException if the depth is still unresolved (cycle in the input,
             *                                  or {@code maxIter} too small)
             */
            @Override
            public Tuple3<Long, Long, Double> map(Vertex<Long, Tuple3<Long, Long, Double>> vertex) throws Exception {
                if (vertex.f1.f2 < ROOT_DEPTH) {
                    throw new IllegalArgumentException("illegal input!!! The tree depth of vertex " + vertex.f0
                        + " could not be resolved; the input may contain a cycle or maxIter is too small.");
                }
                return vertex.f1;
            }
        }

        /** Builds the initial vertex value from the in-degree: in-degree 0 marks a root. */
        public static class MapVertexValue implements MapFunction<Vertex<Long, LongValue>, Vertex<Long, Tuple3<Long, Long, Double>>> {
            private static final long serialVersionUID = 2154022863365357679L;

            @Override
            public Vertex<Long, Tuple3<Long, Long, Double>> map(Vertex<Long, LongValue> vertex) {
                double depth = vertex.f1.getValue() == 0 ? ROOT_DEPTH : UNRESOLVED_DEPTH;
                Tuple3<Long, Long, Double> value = new Tuple3<>(vertex.f0, vertex.f0, depth);
                return new Vertex<>(vertex.f0, value);
            }
        }

        /** Reverses an edge (child -> parent) after checking that its weight is positive. */
        public static class ReverseEdge implements MapFunction<Edge<Long, Double>, Edge<Long, Double>> {
            private static final long serialVersionUID = 7575794558756147475L;

            @Override
            public Edge<Long, Double> map(Edge<Long, Double> edge) {
                if (edge.f2 <= ROOT_DEPTH) {
                    throw new RuntimeException("Edge " + edge + " is illegal. Edge weight must be positive!");
                }
                return new Edge<>(edge.f1, edge.f0, edge.f2);
            }
        }

        /** @param maxIter maximum number of supersteps of the vertex-centric iteration */
        public TreeDepth(int maxIter) {
            this.maxIter = maxIter;
        }

        /** Prepares the child -> parent graph with initial values and runs the iteration. */
        private Graph<Long, Tuple3<Long, Long, Double>, Double> operation(Graph<Long, Double, Double> graph) {
            try {
                // Vertices with in-degree 0 (zero-degree vertices included) become roots.
                DataSet<Vertex<Long, Tuple3<Long, Long, Double>>> vertices = graph
                    .run(new VertexInDegree<Long, Double, Double>().setIncludeZeroDegreeVertices(true))
                    .map(new MapVertexValue());
                // Reverse edges so that each non-root vertex can message its parent.
                DataSet<Edge<Long, Double>> childToParent = graph.getEdges().map(new ReverseEdge());
                Graph<Long, Tuple3<Long, Long, Double>, Double> working = Graph.fromDataSet(
                    vertices, childToParent,
                    BatchOperator.getExecutionEnvironmentFromDataSets(vertices, childToParent));

                VertexCentricConfiguration config = new VertexCentricConfiguration();
                config.setName("tree depth iteration");
                // No message combiner: requests from different descendants must stay distinct.
                return working.runVertexCentricIteration(
                    new Execute(), (MessageCombiner<Long, Tuple3<Long, Long, Double>>) null, maxIter, config);
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }

        /**
         * Runs the computation.
         *
         * @param graph forest with edges directed parent -> child and weights as edge values
         * @return one {@code (vertex, root, depth)} tuple per vertex
         */
        public DataSet<Tuple3<Long, Long, Double>> run(Graph<Long, Double, Double> graph) {
            return operation(graph).getVertices().map(new JudgeTupleIllegal());
        }
    }

    public TreeDepthBatchOp(Params params) {
        super(params);
    }

    public TreeDepthBatchOp() {
        super(new Params());
    }

    /**
     * Reads the edge table from the single input and writes the {@code vertices, root, treeDepth}
     * table.
     *
     * <p>The erased bridge method required by the covariant return type is generated by the
     * compiler; it must not be declared by hand.
     */
    @Override
    public TreeDepthBatchOp linkFrom(BatchOperator<?>... inputs) {
        BatchOperator<?> in = checkAndGetFirst(inputs);
        String sourceCol = getEdgeSourceCol();
        String targetCol = getEdgeTargetCol();
        String weightCol = getEdgeWeightCol();
        String[] outputCols = {"vertices", "root", "treeDepth"};
        Integer maxIter = getMaxIter();
        String[] edgeCols = weightCol != null
            ? new String[] {sourceCol, targetCol, weightCol}
            : new String[] {sourceCol, targetCol};

        // Vertex ids in the output are mapped back to the type of the source column.
        int sourceColIndex = TableUtil.findColIndexWithAssertAndHint(in.getColNames(), sourceCol);
        TypeInformation<?> vertexType = in.getColTypes()[sourceColIndex];

        DataSet<Row> edgeRows = GraphUtilsWithString.input2json(in, edgeCols, 2, true);
        GraphUtilsWithString idMapping = new GraphUtilsWithString(edgeRows, vertexType);
        Graph<Long, Double, Double> graph = Graph
            .fromDataSet(idMapping.inputType2longEdge(edgeRows, true),
                MLEnvironmentFactory.get(getMLEnvironmentId()).getExecutionEnvironment())
            .mapVertices(new GraphTransformUtils.MapVerticesTreeDepth());

        DataSet<Tuple3<Long, Long, Double>> depths = new TreeDepth(maxIter).run(graph);
        setOutput(idMapping.long2outputTreeDepth(depths), outputCols,
            new TypeInformation<?>[] {vertexType, vertexType, Types.DOUBLE});
        return this;
    }
}
