package com.alibaba.alink.operator.batch.graph;

import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.NameEn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpecs;
import com.alibaba.alink.common.annotation.PortDesc;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.dataproc.HugeIndexerStringPredictBatchOp;
import com.alibaba.alink.operator.batch.dataproc.HugeStringIndexerPredictBatchOp;
import com.alibaba.alink.operator.batch.dataproc.StringIndexerTrainBatchOp;
import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil;
import com.alibaba.alink.operator.common.tree.Criteria;
import com.alibaba.alink.params.graph.CommonNeighborsTrainParams;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.JoinFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.operators.Operator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.api.java.tuple.Tuple5;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

@InputPorts(values = {@PortSpec(value = PortType.DATA, opType = PortSpec.OpType.BATCH, desc = PortDesc.GRPAH_EDGES)})
@OutputPorts(values = {@PortSpec(value = PortType.DATA, desc = PortDesc.OUTPUT_RESULT)})
@ParamSelectColumnSpecs({@ParamSelectColumnSpec(name = "edgeSourceCol", portIndices = {VectorUtil.VectorSerialType.DENSE_VECTOR}), @ParamSelectColumnSpec(name = "edgeTargetCol", portIndices = {VectorUtil.VectorSerialType.DENSE_VECTOR})})
@NameCn("共同邻居计算")
@NameEn("Common Neighbors")
/* loaded from: input_file:com/alibaba/alink/operator/batch/graph/CommonNeighborsBatchOp.class */
public class CommonNeighborsBatchOp extends BatchOperator<CommonNeighborsBatchOp> implements CommonNeighborsTrainParams<CommonNeighborsBatchOp> {
    private static final long serialVersionUID = -9221019571132151284L;
    private static final String COMMON_NEIGHBOR_ID_COL = "alink_common_neighbors_col";
    private static final String ADAMIC_OUTPUT_COL = "adamic_score";
    private static final String JACCARDS_OUTPUT_COL = "jaccards_score";
    private static final String NEIGHBORS_OUTPUT_COL = "neighbors";
    private static final String CN_OUTPUT_COL = "cn_count";

    /**
     * Emits every incoming edge followed by its mirrored copy (target, source),
     * so that each edge is present in both directions downstream.
     *
     * <p>Note: the decompiler widened this class from {@code private} to {@code public}.
     */
    public static class AddReverseEdge implements MapPartitionFunction<Tuple2<Long, Long>, Tuple2<Long, Long>> {
        private static final long serialVersionUID = 8536124777027322457L;

        AddReverseEdge() {
        }

        /**
         * Forwards each edge of the partition unchanged and then emits the same edge with
         * its two endpoints swapped.
         *
         * @param iterable  edges of the current partition as (source, target) pairs
         * @param collector receives two records per input edge: the edge and its reverse
         */
        public void mapPartition(Iterable<Tuple2<Long, Long>> iterable, Collector<Tuple2<Long, Long>> collector) throws Exception {
            for (Tuple2<Long, Long> edge : iterable) {
                collector.collect(edge);
                collector.collect(Tuple2.of(edge.f1, edge.f0));
            }
        }
    }

    /**
     * Collects the neighbor list of one node (all edges in a group share field 0) and
     * re-emits it once per neighbor, keyed by that neighbor.
     *
     * <p>Note: the decompiler widened this class from {@code private} to {@code public}.
     */
    public static class BuildNeighborData implements GroupReduceFunction<Tuple2<Long, Long>, Tuple3<Long, Long, Long[]>> {
        private static final long serialVersionUID = -987657040866491039L;

        BuildNeighborData() {
        }

        /**
         * Gathers field 1 of every edge in the group into an array and emits
         * (neighbor, node, neighbors) for each entry of that array.
         *
         * @param iterable  edges of one group; the node id is taken from field 0 of the last edge seen
         * @param collector receives one record per neighbor; all records share the same array instance
         */
        public void reduce(Iterable<Tuple2<Long, Long>> iterable, Collector<Tuple3<Long, Long, Long[]>> collector) throws Exception {
            ArrayList<Long> collected = new ArrayList<>();
            Long node = null;
            for (Tuple2<Long, Long> edge : iterable) {
                node = edge.f0;
                collected.add(edge.f1);
            }

            // Duplicates are kept on purpose: the array mirrors the group content one-to-one.
            Long[] neighbors = collected.toArray(new Long[0]);
            for (Long neighbor : neighbors) {
                collector.collect(Tuple3.of(neighbor, node, neighbors));
            }
        }
    }

    /**
     * For all nodes that arrive in the same group, computes the pairwise intersection of
     * their neighbor sets together with the Jaccard coefficient of the two sets.
     *
     * <p>Note: the decompiler widened this class from {@code private} to {@code public}.
     */
    public static class ComputeCommonNeighbors implements GroupReduceFunction<Tuple3<Long, Long, Long[]>, Tuple4<Long, Long, Long[], Double>> {
        private static final long serialVersionUID = -4665094384893739366L;

        ComputeCommonNeighbors() {
        }

        /**
         * Emits, for every unordered pair of group members, two records (one per direction)
         * carrying the shared neighbors and
         * {@code |common| / (|first| + |second| - |common|)} as the Jaccard score.
         *
         * @param iterable  records of one group; field 1 is the node id, field 2 its neighbor array
         * @param collector receives (nodeA, nodeB, commonNeighbors, jaccard) in both directions
         */
        public void reduce(Iterable<Tuple3<Long, Long, Long[]>> iterable, Collector<Tuple4<Long, Long, Long[], Double>> collector) throws Exception {
            ArrayList<Long> nodes = new ArrayList<>();
            ArrayList<HashSet<Long>> neighborSets = new ArrayList<>();
            for (Tuple3<Long, Long, Long[]> entry : iterable) {
                nodes.add(entry.f1);
                HashSet<Long> neighborSet = new HashSet<>();
                neighborSet.addAll(Arrays.asList(entry.f2));
                neighborSets.add(neighborSet);
            }

            final int memberCount = neighborSets.size();
            for (int a = 0; a < memberCount; a++) {
                HashSet<Long> first = neighborSets.get(a);
                for (int b = a + 1; b < memberCount; b++) {
                    HashSet<Long> second = neighborSets.get(b);

                    // Intersect on a copy so the stored sets stay intact for later pairs.
                    // The emitted array follows the iteration order of this copy.
                    @SuppressWarnings("unchecked")
                    HashSet<Long> shared = (HashSet<Long>) first.clone();
                    shared.retainAll(second);
                    Long[] common = shared.toArray(new Long[0]);

                    int unionSize = (first.size() + second.size()) - common.length;
                    Double jaccard = (common.length * 1.0d) / unionSize;

                    collector.collect(Tuple4.of(nodes.get(a), nodes.get(b), common, jaccard));
                    collector.collect(Tuple4.of(nodes.get(b), nodes.get(a), common, jaccard));
                }
            }
        }
    }

    /**
     * Converts two-column rows into (Long, Long) edge tuples.
     *
     * <p>Note: the decompiler widened this class from {@code private} to {@code public}.
     */
    public static class LongTupleData implements MapPartitionFunction<Row, Tuple2<Long, Long>> {
        private static final long serialVersionUID = -2659982156687136255L;

        LongTupleData() {
        }

        /**
         * Casts fields 0 and 1 of every row to {@code Long} and emits them as one tuple.
         *
         * @param iterable  rows of the current partition; fields 0 and 1 must hold {@code Long} values
         * @param collector receives one (field0, field1) tuple per row
         */
        public void mapPartition(Iterable<Row> iterable, Collector<Tuple2<Long, Long>> collector) throws Exception {
            for (Row row : iterable) {
                Long source = (Long) row.getField(0);
                Long target = (Long) row.getField(1);
                collector.collect(Tuple2.of(source, target));
            }
        }
    }

    /**
     * Creates the operator with the given parameter set, which is handed to the
     * superclass constructor unchanged.
     *
     * @param params parameters for this operator
     */
    public CommonNeighborsBatchOp(Params params) {
        super(params);
    }

    /**
     * Creates the operator with a fresh, empty {@link Params} instance.
     */
    public CommonNeighborsBatchOp() {
        this(new Params());
    }

    /**
     * Builds the common-neighbors pipeline on top of a single edge-list input.
     *
     * <p>Steps:
     * <ol>
     *   <li>Select the source/target columns; when {@code needTransformID} is set, map both
     *       columns to Long ids with a string indexer trained on them.</li>
     *   <li>Unless the graph is flagged as bipartite, add the reverse of every edge.</li>
     *   <li>Per node (grouped on edge field 1) compute the Adamic-Adar weight
     *       {@code 1 / ln(count)}, or 0 when the node occurs at most once.</li>
     *   <li>Compute, per node pair, the shared neighbors and the Jaccard score, keeping one
     *       record per ordered pair.</li>
     *   <li>Sum the weights of the shared neighbors per pair and join the sum back.</li>
     * </ol>
     *
     * <p>The output has six columns: the source column, the target column,
     * {@code neighbors} (comma-separated when ids are not transformed), {@code cn_count},
     * {@code jaccards_score} and {@code adamic_score}.
     *
     * @param batchOperatorArr exactly one operator providing the edge table
     * @return this operator, with its output set
     */
    @Override // com.alibaba.alink.operator.batch.BatchOperator
    public CommonNeighborsBatchOp linkFrom(BatchOperator<?>... batchOperatorArr) {
        checkOpSize(1, batchOperatorArr);
        final String edgeSourceCol = getEdgeSourceCol();
        final String edgeTargetCol = getEdgeTargetCol();
        final String[] edgeCols = {edgeSourceCol, edgeTargetCol};
        final Long mlEnvironmentId = getMLEnvironmentId();
        final boolean needTransformId = getNeedTransformID();

        BatchOperator<?> edgeOp = batchOperatorArr[0].select(edgeCols);
        // Created unconditionally because it is needed again at the end to map ids back.
        StringIndexerTrainBatchOp indexer = new StringIndexerTrainBatchOp();
        if (needTransformId) {
            indexer.setSelectedCol(edgeSourceCol).setSelectedCols(edgeTargetCol).setStringOrderType("random").linkFrom(edgeOp);
            edgeOp = new HugeStringIndexerPredictBatchOp().setSelectedCols(edgeCols).linkFrom(indexer, edgeOp);
        }

        DataSet<Tuple2<Long, Long>> edges = edgeOp.getDataSet().mapPartition(new LongTupleData());
        if (!getIsBipartiteGraph()) {
            edges = edges.mapPartition(new AddReverseEdge());
        }

        // (node, 1 / ln(count)) where count is the number of edges ending at the node.
        DataSet<Tuple2<Long, Double>> adamicWeights = edges
            .groupBy(1)
            .reduceGroup(new GroupReduceFunction<Tuple2<Long, Long>, Tuple2<Long, Double>>() {
                public void reduce(Iterable<Tuple2<Long, Long>> group, Collector<Tuple2<Long, Double>> out) throws Exception {
                    int count = 0;
                    Long nodeId = null;
                    for (Tuple2<Long, Long> edge : group) {
                        count++;
                        nodeId = edge.f1;
                    }
                    // ln(1) == 0 would divide by zero, so such nodes get weight 0.
                    Double weight = count > 1 ? 1.0d / Math.log(count) : 0.0d;
                    out.collect(Tuple2.of(nodeId, weight));
                }
            })
            .name("compute_adamic_weight");

        // (nodeA, nodeB, commonNeighbors, jaccard), exactly one record per ordered pair.
        DataSet<Tuple4<Long, Long, Long[], Double>> commonNeighbors = edges
            .groupBy(0)
            .reduceGroup(new BuildNeighborData())
            .name("build_neighbor_data")
            .groupBy(0)
            .reduceGroup(new ComputeCommonNeighbors())
            .name("compute_neighbor_data")
            .groupBy(0, 1)
            .reduceGroup(new GroupReduceFunction<Tuple4<Long, Long, Long[], Double>, Tuple4<Long, Long, Long[], Double>>() {
                public void reduce(Iterable<Tuple4<Long, Long, Long[], Double>> group, Collector<Tuple4<Long, Long, Long[], Double>> out) throws Exception {
                    // The same pair is produced once per shared neighbor; keep the first copy only.
                    out.collect(group.iterator().next());
                }
            })
            .name("filter_multi_data");

        // (nodeA, nodeB, sum of the Adamic-Adar weights of their common neighbors).
        DataSet<Tuple3<Long, Long, Double>> adamicScores = commonNeighbors
            .flatMap(new FlatMapFunction<Tuple4<Long, Long, Long[], Double>, Tuple3<Long, Long, Long>>() {
                public void flatMap(Tuple4<Long, Long, Long[], Double> pair, Collector<Tuple3<Long, Long, Long>> out) throws Exception {
                    for (Long neighbor : pair.f2) {
                        out.collect(Tuple3.of(pair.f0, pair.f1, neighbor));
                    }
                }
            })
            .leftOuterJoin(adamicWeights)
            .where(2)
            .equalTo(0)
            .with(new JoinFunction<Tuple3<Long, Long, Long>, Tuple2<Long, Double>, Tuple3<Long, Long, Double>>() {
                public Tuple3<Long, Long, Double> join(Tuple3<Long, Long, Long> left, Tuple2<Long, Double> right) throws Exception {
                    // In an outer join the right side is null when no weight matched; count it as 0.
                    Double weight = right == null ? Double.valueOf(0.0d) : right.f1;
                    return Tuple3.of(left.f0, left.f1, weight);
                }
            })
            .groupBy(0, 1)
            .reduceGroup(new GroupReduceFunction<Tuple3<Long, Long, Double>, Tuple3<Long, Long, Double>>() {
                public void reduce(Iterable<Tuple3<Long, Long, Double>> group, Collector<Tuple3<Long, Long, Double>> out) throws Exception {
                    Long first = null;
                    Long second = null;
                    double total = 0.0d;
                    for (Tuple3<Long, Long, Double> weighted : group) {
                        first = weighted.f0;
                        second = weighted.f1;
                        total += weighted.f2;
                    }
                    out.collect(Tuple3.of(first, second, Double.valueOf(total)));
                }
            })
            .name("merge_adamic_adar_score");

        DataSet<Row> result = commonNeighbors
            .join(adamicScores)
            .where(0, 1)
            .equalTo(0, 1)
            .projectFirst(0, 1, 2, 3)
            .<Tuple5<Long, Long, Long[], Double, Double>>projectSecond(2)
            .map(new MapFunction<Tuple5<Long, Long, Long[], Double, Double>, Row>() {
                public Row map(Tuple5<Long, Long, Long[], Double, Double> pair) throws Exception {
                    Row row = new Row(7);
                    row.setField(0, pair.f0);
                    row.setField(1, pair.f1);
                    row.setField(2, pair.f2);
                    row.setField(3, Long.valueOf(pair.f2.length));
                    row.setField(4, joinNeighborIds(pair.f2));
                    row.setField(5, pair.f3);
                    row.setField(6, pair.f4);
                    return row;
                }
            })
            .name("refine_final_result");

        String[] outputCols = {edgeSourceCol, edgeTargetCol, NEIGHBORS_OUTPUT_COL, CN_OUTPUT_COL, JACCARDS_OUTPUT_COL, ADAMIC_OUTPUT_COL};
        // The intermediate table carries the neighbors twice: as a Long[] (for mapping the ids
        // back to strings) and as a pre-joined string (used when ids were not transformed).
        String[] intermediateCols = {edgeSourceCol, edgeTargetCol, NEIGHBORS_OUTPUT_COL, CN_OUTPUT_COL, COMMON_NEIGHBOR_ID_COL, JACCARDS_OUTPUT_COL, ADAMIC_OUTPUT_COL};
        TypeInformation<?>[] intermediateTypes = new TypeInformation<?>[]{Types.LONG, Types.LONG, Types.OBJECT_ARRAY(Types.LONG), Types.LONG, Types.STRING, Types.DOUBLE, Types.DOUBLE};
        BatchOperator<?> intermediate = BatchOperator.fromTable(DataSetConversionUtil.toTable(mlEnvironmentId, result, intermediateCols, intermediateTypes));

        if (needTransformId) {
            TypeInformation<?>[] stringIdTypes = new TypeInformation<?>[]{Types.STRING, Types.STRING, Types.STRING, Types.LONG, Types.DOUBLE, Types.DOUBLE};
            setOutput(new HugeIndexerStringPredictBatchOp().setSelectedCols(edgeSourceCol, edgeTargetCol, NEIGHBORS_OUTPUT_COL).linkFrom(indexer, intermediate).select(outputCols).getDataSet(), outputCols, stringIdTypes);
        } else {
            TypeInformation<?>[] longIdTypes = new TypeInformation<?>[]{Types.LONG, Types.LONG, Types.STRING, Types.LONG, Types.DOUBLE, Types.DOUBLE};
            String[] selectedCols = {edgeSourceCol, edgeTargetCol, COMMON_NEIGHBOR_ID_COL, CN_OUTPUT_COL, JACCARDS_OUTPUT_COL, ADAMIC_OUTPUT_COL};
            setOutput(intermediate.select(selectedCols).getDataSet(), outputCols, longIdTypes);
        }
        return this;
    }

    /**
     * Joins the given ids with commas, without a leading or trailing separator.
     *
     * @param ids ids to join; may be empty
     * @return the comma-separated ids, or the empty string when {@code ids} is empty
     */
    private static String joinNeighborIds(Long[] ids) {
        StringBuilder joined = new StringBuilder();
        for (Long id : ids) {
            if (joined.length() > 0) {
                joined.append(',');
            }
            joined.append(id);
        }
        return joined.toString();
    }
}
