package com.alibaba.alink.operator.batch.graph;

import com.alibaba.alink.common.MLEnvironmentFactory;
import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpecs;
import com.alibaba.alink.common.annotation.PortDesc;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.type.AlinkTypes;
import com.alibaba.alink.common.utils.AlinkSerializable;
import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.common.viz.AlinkViz;
import com.alibaba.alink.common.viz.VizDataWriterInterface;
import com.alibaba.alink.common.viz.VizOpChartData;
import com.alibaba.alink.common.viz.VizOpDataInfo;
import com.alibaba.alink.common.viz.VizOpMeta;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil;
import com.alibaba.alink.operator.batch.utils.DataSetUtil;
import com.alibaba.alink.operator.common.dataproc.FirstReducer;
import com.alibaba.alink.params.graph.VertexNeighborSearchParams;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import org.apache.flink.api.common.functions.CrossFunction;
import org.apache.flink.api.common.functions.FilterFunction;
import org.apache.flink.api.common.functions.GroupCombineFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.JoinFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.operators.JoinOperator;
import org.apache.flink.api.java.operators.MapOperator;
import org.apache.flink.api.java.tuple.Tuple1;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.graph.Edge;
import org.apache.flink.graph.Graph;
import org.apache.flink.graph.GraphAlgorithm;
import org.apache.flink.graph.Vertex;
import org.apache.flink.graph.asm.translate.TranslateFunction;
import org.apache.flink.graph.pregel.ComputeFunction;
import org.apache.flink.graph.pregel.MessageCombiner;
import org.apache.flink.graph.pregel.MessageIterator;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.Table;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

@InputPorts(values = {@PortSpec(value = PortType.DATA, opType = PortSpec.OpType.BATCH, desc = PortDesc.GRPAH_EDGES), @PortSpec(value = PortType.DATA, opType = PortSpec.OpType.BATCH, desc = PortDesc.GRAPH_VERTICES, isOptional = true)})
@OutputPorts(values = {@PortSpec(value = PortType.DATA, desc = PortDesc.OUTPUT_RESULT)})
@ParamSelectColumnSpecs({@ParamSelectColumnSpec(name = "vertexIdCol", portIndices = {1}), @ParamSelectColumnSpec(name = "edgeSourceCol", portIndices = {VectorUtil.VectorSerialType.DENSE_VECTOR}), @ParamSelectColumnSpec(name = "edgeTargetCol", portIndices = {VectorUtil.VectorSerialType.DENSE_VECTOR})})
@NameCn("点邻居搜索")
@NameEn("Vertex Neighbor Search")
/* loaded from: input_file:com/alibaba/alink/operator/batch/graph/VertexNeighborSearchBatchOp.class */
public final class VertexNeighborSearchBatchOp extends BatchOperator<VertexNeighborSearchBatchOp> implements VertexNeighborSearchParams<VertexNeighborSearchBatchOp>, AlinkViz<VertexNeighborSearchBatchOp> {
    private static final long serialVersionUID = 4341880061845091326L;
    static int MAX_NUM_EDGES_TO_ES = 1200;

    /* loaded from: input_file:com/alibaba/alink/operator/batch/graph/VertexNeighborSearchBatchOp$AllInOneGroupCombineFunction.class */
    public static class AllInOneGroupCombineFunction<T> implements GroupCombineFunction<T, List<T>> {
        private static final long serialVersionUID = -7055580437134926663L;

        public void combine(Iterable<T> iterable, Collector<List<T>> collector) {
            ArrayList arrayList = new ArrayList();
            Iterator<T> it = iterable.iterator();
            while (it.hasNext()) {
                arrayList.add(it.next());
            }
            collector.collect(arrayList);
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/batch/graph/VertexNeighborSearchBatchOp$Graph2JsonCrossFunction.class */
    public static class Graph2JsonCrossFunction implements CrossFunction<List<Row>, List<Row>, String>, AlinkSerializable {
        private static final long serialVersionUID = -6978604589815110412L;
        private Set<String> selectedVertexIds;
        private String[] edgesColNames;
        private String edgeSourceColName;
        private String edgeTargetColName;
        private String[] verticesColNames;
        private String vertexIdColName;
        private boolean isUndirected;
        private int maxNumEdgesToEs;
        private Object[][] edges;
        private Object[][] vertices;

        Graph2JsonCrossFunction(Set<String> set, String[] strArr, String str, String str2, String[] strArr2, String str3, boolean z, int i) {
            this.selectedVertexIds = set;
            this.edgesColNames = strArr;
            this.edgeSourceColName = str;
            this.edgeTargetColName = str2;
            this.verticesColNames = strArr2;
            this.vertexIdColName = str3;
            this.isUndirected = z;
            this.maxNumEdgesToEs = i;
        }

        /* JADX WARN: Type inference failed for: r1v10, types: [java.lang.Object[], java.lang.Object[][]] */
        /* JADX WARN: Type inference failed for: r1v4, types: [java.lang.Object[], java.lang.Object[][]] */
        public String cross(List<Row> list, List<Row> list2) {
            ArrayList arrayList = new ArrayList();
            int i = 0;
            HashSet hashSet = new HashSet(list.size() * 2);
            for (Row row : list) {
                Object[] objArr = new Object[this.edgesColNames.length];
                for (int i2 = 0; i2 < this.edgesColNames.length; i2++) {
                    objArr[i2] = row.getField(i2);
                }
                arrayList.add(objArr);
                hashSet.add((String) objArr[0]);
                hashSet.add((String) objArr[1]);
                i++;
                if (i >= this.maxNumEdgesToEs) {
                    break;
                }
            }
            this.edges = new Object[arrayList.size()];
            arrayList.toArray(this.edges);
            ArrayList arrayList2 = new ArrayList();
            for (Row row2 : list2) {
                Object[] objArr2 = new Object[this.verticesColNames.length];
                for (int i3 = 0; i3 < this.verticesColNames.length; i3++) {
                    objArr2[i3] = row2.getField(i3);
                }
                if (hashSet.contains(objArr2[0])) {
                    arrayList2.add(objArr2);
                }
            }
            this.vertices = new Object[arrayList2.size()];
            arrayList2.toArray(this.vertices);
            return JsonConverter.gson.toJson(this);
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/batch/graph/VertexNeighborSearchBatchOp$OnlySecondJoinFunction.class */
    public static class OnlySecondJoinFunction<IN1, IN2> implements JoinFunction<IN1, IN2, IN2> {
        private static final long serialVersionUID = 5961146867726046621L;

        public IN2 join(IN1 in1, IN2 in2) {
            return in2;
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/batch/graph/VertexNeighborSearchBatchOp$Row2EdgeTuple.class */
    public static class Row2EdgeTuple implements MapFunction<Row, Tuple3<String, String, Long>> {
        private static final long serialVersionUID = -7430905996667254712L;
        private int sourceColId;
        private int targetColId;

        Row2EdgeTuple(int i, int i2) {
            this.sourceColId = i;
            this.targetColId = i2;
        }

        public Tuple3<String, String, Long> map(Row row) throws Exception {
            return Tuple3.of((String) row.getField(this.sourceColId), (String) row.getField(this.targetColId), 0L);
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/batch/graph/VertexNeighborSearchBatchOp$Row2VertexTuple.class */
    public static class Row2VertexTuple implements MapFunction<Row, Tuple2<String, Long>> {
        private static final long serialVersionUID = -1149958337899075070L;
        private int vertexColId;

        Row2VertexTuple(int i) {
            this.vertexColId = i;
        }

        public Tuple2<String, Long> map(Row row) throws Exception {
            return Tuple2.of((String) row.getField(this.vertexColId), 0L);
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/batch/graph/VertexNeighborSearchBatchOp$VertexNeighborSearch.class */
    public static class VertexNeighborSearch implements GraphAlgorithm<String, Long, Long, Graph<String, Long, Long>> {
        private int depth;
        private HashSet<String> sources;

        /* loaded from: input_file:com/alibaba/alink/operator/batch/graph/VertexNeighborSearchBatchOp$VertexNeighborSearch$FilterByValue.class */
        public static final class FilterByValue implements FilterFunction<Vertex<String, Long>> {
            private static final long serialVersionUID = -3443337881858305297L;
            private int thresh;

            FilterByValue(int i) {
                this.thresh = i;
            }

            public boolean filter(Vertex<String, Long> vertex) {
                return ((Long) vertex.getValue()).longValue() <= ((long) this.thresh);
            }
        }

        /* loaded from: input_file:com/alibaba/alink/operator/batch/graph/VertexNeighborSearchBatchOp$VertexNeighborSearch$MinimumDistanceCombiner.class */
        public static final class MinimumDistanceCombiner extends MessageCombiner<String, Long> {
            private static final long serialVersionUID = 1916706983491173310L;

            public void combineMessages(MessageIterator<Long> messageIterator) {
                long j = 4611686018427387903L;
                Iterator it = messageIterator.iterator();
                while (it.hasNext()) {
                    j = Math.min(j, ((Long) it.next()).longValue());
                }
                sendCombinedMessage(Long.valueOf(j));
            }
        }

        /* loaded from: input_file:com/alibaba/alink/operator/batch/graph/VertexNeighborSearchBatchOp$VertexNeighborSearch$SetLongMaxValue.class */
        public static final class SetLongMaxValue implements TranslateFunction<Long, Long> {
            private static final long serialVersionUID = 6439208445273249327L;

            public Long translate(Long l, Long l2) {
                return 4611686018427387903L;
            }
        }

        /* loaded from: input_file:com/alibaba/alink/operator/batch/graph/VertexNeighborSearchBatchOp$VertexNeighborSearch$VertexNeighborComputeFunction.class */
        public static final class VertexNeighborComputeFunction extends ComputeFunction<String, Long, Long, Long> {
            private static final long serialVersionUID = -4927352871373053814L;
            private HashSet<String> sources;

            VertexNeighborComputeFunction(HashSet<String> hashSet) {
                this.sources = hashSet;
            }

            public void compute(Vertex<String, Long> vertex, MessageIterator<Long> messageIterator) {
                long j = this.sources.contains(vertex.getId()) ? 0L : 4611686018427387903L;
                Iterator it = messageIterator.iterator();
                while (it.hasNext()) {
                    j = Math.min(j, ((Long) it.next()).longValue());
                }
                if (j < ((Long) vertex.getValue()).longValue()) {
                    setNewVertexValue(Long.valueOf(j));
                    Iterator it2 = getEdges().iterator();
                    while (it2.hasNext()) {
                        sendMessageTo(((Edge) it2.next()).getTarget(), Long.valueOf(j + 1));
                    }
                }
            }
        }

        public VertexNeighborSearch(HashSet<String> hashSet, int i) {
            this.sources = hashSet;
            this.depth = i;
        }

        public Graph<String, Long, Long> run(Graph<String, Long, Long> graph) throws Exception {
            Graph<String, Long, Long> filterOnVertices = graph.translateVertexValues(new SetLongMaxValue()).runVertexCentricIteration(new VertexNeighborComputeFunction(this.sources), new MinimumDistanceCombiner(), this.depth + 1).filterOnVertices(new FilterByValue(this.depth));
            filterOnVertices.getVertices();
            return filterOnVertices;
        }

        /* renamed from: run, reason: collision with other method in class */
        public /* bridge */ /* synthetic */ Object m266run(Graph graph) throws Exception {
            return run((Graph<String, Long, Long>) graph);
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/batch/graph/VertexNeighborSearchBatchOp$VertexValueInitializer.class */
    public static class VertexValueInitializer implements MapFunction<String, Long> {
        private static final long serialVersionUID = -8771018283053295267L;

        public Long map(String str) throws Exception {
            return 0L;
        }
    }

    /* loaded from: input_file:com/alibaba/alink/operator/batch/graph/VertexNeighborSearchBatchOp$VizWriterMapFunction.class */
    public static class VizWriterMapFunction implements MapFunction<String, String> {
        private static final long serialVersionUID = -3562810264038340464L;
        int dataId;
        VizDataWriterInterface writer;

        VizWriterMapFunction(int i, VizDataWriterInterface vizDataWriterInterface) {
            this.dataId = i;
            this.writer = vizDataWriterInterface;
        }

        public String map(String str) throws Exception {
            this.writer.writeBatchData(this.dataId, str, System.currentTimeMillis());
            return str;
        }
    }

    public VertexNeighborSearchBatchOp() {
        super(new Params());
    }

    public VertexNeighborSearchBatchOp(Params params) {
        super(params);
    }

    /* JADX WARN: Can't rename method to resolve collision */
    @Override // com.alibaba.alink.operator.batch.BatchOperator
    public VertexNeighborSearchBatchOp linkFrom(BatchOperator<?>... batchOperatorArr) {
        Graph<String, Long, Long> fromTupleDataSet;
        DataSet map;
        String[] strArr;
        checkMinOpSize(1, batchOperatorArr);
        VizDataWriterInterface vizDataWriter = getVizDataWriter();
        BatchOperator<?> batchOperator = batchOperatorArr[0];
        Boolean asUndirectedGraph = getAsUndirectedGraph();
        String vertexIdCol = getVertexIdCol();
        String edgeSourceCol = getEdgeSourceCol();
        String edgeTargetCol = getEdgeTargetCol();
        int findColIndexWithAssertAndHint = TableUtil.findColIndexWithAssertAndHint(batchOperator.getColNames(), edgeSourceCol);
        int findColIndexWithAssertAndHint2 = TableUtil.findColIndexWithAssertAndHint(batchOperator.getColNames(), edgeTargetCol);
        int intValue = getDepth().intValue();
        HashSet hashSet = new HashSet(Arrays.asList(getSources()));
        MapOperator map2 = batchOperator.getDataSet().map(new Row2EdgeTuple(findColIndexWithAssertAndHint, findColIndexWithAssertAndHint2));
        BatchOperator<?> batchOperator2 = null;
        if (batchOperatorArr.length > 1) {
            batchOperator2 = batchOperatorArr[1];
            fromTupleDataSet = Graph.fromTupleDataSet(batchOperator2.getDataSet().map(new Row2VertexTuple(TableUtil.findColIndexWithAssertAndHint(batchOperator2.getColNames(), vertexIdCol))), map2, MLEnvironmentFactory.get(getMLEnvironmentId()).getExecutionEnvironment());
        } else {
            fromTupleDataSet = Graph.fromTupleDataSet(map2, new VertexValueInitializer(), MLEnvironmentFactory.get(getMLEnvironmentId()).getExecutionEnvironment());
        }
        if (asUndirectedGraph.booleanValue()) {
            fromTupleDataSet = fromTupleDataSet.getUndirected();
        }
        Graph<String, Long, Long> graph = fromTupleDataSet;
        try {
            graph = new VertexNeighborSearch(hashSet, intValue).run(fromTupleDataSet);
        } catch (Exception e) {
            e.printStackTrace();
        }
        JoinOperator.EquiJoin with = graph.getEdgesAsTuple3().joinWithHuge(batchOperator.getDataSet()).where(new int[]{0, 1}).equalTo(new int[]{0, 1}).with(new OnlySecondJoinFunction());
        setOutput((DataSet<Row>) with, batchOperator.getSchema());
        Table[] tableArr = new Table[1];
        setSideOutputTables(new Table[1]);
        if (batchOperator2 != null) {
            map = graph.getVerticesAsTuple2().join(batchOperator2.getDataSet()).where(new int[]{0}).equalTo(new int[]{0}).with(new OnlySecondJoinFunction());
            strArr = batchOperator2.getColNames();
            tableArr[0] = DataSetConversionUtil.toTable(getMLEnvironmentId(), (DataSet<Row>) map, batchOperator2.getSchema());
        } else {
            map = graph.getVerticesAsTuple2().project(new int[]{0}).map(new MapFunction<Tuple1<String>, Row>() { // from class: com.alibaba.alink.operator.batch.graph.VertexNeighborSearchBatchOp.1
                private static final long serialVersionUID = 599089156563158818L;

                public Row map(Tuple1<String> tuple1) {
                    return Row.of(new Object[]{tuple1.f0});
                }
            });
            strArr = new String[]{"name"};
            vertexIdCol = "name";
            tableArr[0] = DataSetConversionUtil.toTable(getMLEnvironmentId(), (DataSet<Row>) map, new String[]{"String"}, (TypeInformation<?>[]) new TypeInformation[]{AlinkTypes.STRING});
        }
        setSideOutputTables(tableArr);
        if (vizDataWriter != null) {
            DataSetUtil.linkDummySink(with.reduceGroup(new FirstReducer(MAX_NUM_EDGES_TO_ES)).combineGroup(new AllInOneGroupCombineFunction()).cross(map.combineGroup(new AllInOneGroupCombineFunction())).with(new Graph2JsonCrossFunction(hashSet, batchOperator.getColNames(), edgeSourceCol, edgeTargetCol, strArr, vertexIdCol, asUndirectedGraph.booleanValue(), MAX_NUM_EDGES_TO_ES)).map(new VizWriterMapFunction(0, vizDataWriter)));
            VizOpMeta vizOpMeta = new VizOpMeta();
            vizOpMeta.dataInfos = new VizOpDataInfo[1];
            vizOpMeta.dataInfos[0] = new VizOpDataInfo(0);
            vizOpMeta.cascades = new HashMap();
            vizOpMeta.cascades.put(JsonConverter.gson.toJson(new String[]{"图可视化"}), new VizOpChartData(0));
            vizOpMeta.setSchema(batchOperator.getSchema());
            vizOpMeta.params = getParams();
            vizOpMeta.isOutput = false;
            vizOpMeta.opName = "VertexNeighborSearchBatchOp";
            vizDataWriter.writeBatchMeta(vizOpMeta);
        }
        return this;
    }

    @Override // com.alibaba.alink.operator.batch.BatchOperator
    public /* bridge */ /* synthetic */ VertexNeighborSearchBatchOp linkFrom(BatchOperator[] batchOperatorArr) {
        return linkFrom((BatchOperator<?>[]) batchOperatorArr);
    }
}
