Skip to content

Commit 219b9c2

Browse files
authored
add similarity algo and test (#31)
* add similarity algo * modify comment
1 parent 67d9396 commit 219b9c2

File tree

6 files changed

+183
-1
lines changed

6 files changed

+183
-1
lines changed

nebula-algorithm/src/main/resources/application.conf

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,5 +173,10 @@
173173
embSeparate: ",",
174174
modelPath: "hdfs://127.0.0.1:9000/model"
175175
}
176+
177+
# Jaccard algorithm parameters (tol: maximum Jaccard distance for the similarity join)
178+
jaccard:{
179+
tol: 1.0
180+
}
176181
}
177182
}

nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/Main.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import com.vesoft.nebula.algorithm.config.{
1414
CoefficientConfig,
1515
Configs,
1616
HanpConfig,
17+
JaccardConfig,
1718
KCoreConfig,
1819
LPAConfig,
1920
LouvainConfig,
@@ -31,6 +32,7 @@ import com.vesoft.nebula.algorithm.lib.{
3132
DegreeStaticAlgo,
3233
GraphTriangleCountAlgo,
3334
HanpAlgo,
35+
JaccardAlgo,
3436
KCoreAlgo,
3537
LabelPropagationAlgo,
3638
LouvainAlgo,
@@ -191,6 +193,10 @@ object Main {
191193
val bfsConfig = BfsConfig.getBfsConfig(configs)
192194
BfsAlgo(spark, dataSet, bfsConfig)
193195
}
196+
case "jaccard" => {
197+
val jaccardConfig = JaccardConfig.getJaccardConfig(configs)
198+
JaccardAlgo(spark, dataSet, jaccardConfig)
199+
}
194200
case _ => throw new UnknownParameterException("unknown executeAlgo name.")
195201
}
196202
}

nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/config/AlgoConfig.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,20 @@ object Node2vecConfig {
266266
}
267267
}
268268

269+
/**
  * Jaccard similarity parameters.
  *
  * @param tol maximum Jaccard DISTANCE (i.e. 1 - similarity) accepted by the
  *            approximate similarity join in [[JaccardAlgo]]; node pairs whose
  *            distance exceeds tol are dropped from the result.
  */
case class JaccardConfig(tol: Double)

object JaccardConfig {
  // Mutable mirror of the last parsed value; kept for interface compatibility
  // with the other *Config objects in this file, which follow the same pattern.
  var tol: Double = _

  /**
    * Read `algorithm.jaccard.tol` from the parsed configuration.
    *
    * Falls back to 1.0 (the maximum Jaccard distance, i.e. "keep every pair")
    * when the key is absent — the same default shipped in application.conf —
    * instead of failing with a bare NoSuchElementException.
    */
  def getJaccardConfig(configs: Configs): JaccardConfig = {
    val jaccardConfig = configs.algorithmConfig.map
    tol = jaccardConfig.getOrElse("algorithm.jaccard.tol", "1.0").toDouble
    JaccardConfig(tol)
  }
}
282+
269283
case class AlgoConfig(configs: Configs)
270284

271285
object AlgoConfig {
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
/* Copyright (c) 2021 vesoft inc. All rights reserved.
2+
*
3+
* This source code is licensed under Apache 2.0 License.
4+
*/
5+
6+
package com.vesoft.nebula.algorithm.lib
7+
8+
import com.vesoft.nebula.algorithm.config.JaccardConfig
9+
import org.apache.log4j.Logger
10+
import org.apache.spark.ml.feature.{
11+
CountVectorizer,
12+
CountVectorizerModel,
13+
MinHashLSH,
14+
MinHashLSHModel
15+
}
16+
import org.apache.spark.ml.linalg.SparseVector
17+
import org.apache.spark.rdd.RDD
18+
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
19+
import org.apache.spark.sql.functions.col
20+
import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType}
21+
22+
object JaccardAlgo {
  private val LOGGER = Logger.getLogger(this.getClass)

  val ALGORITHM: String = "Jaccard"

  // Partition count used when shuffling the per-node neighbor lists.
  // Hoisted from inline magic numbers so both shuffles stay in sync.
  private val NUM_PARTITIONS = 100

  /**
    * Run the Jaccard similarity algorithm for nebula graph.
    *
    * @param spark         active session
    * @param dataset       edge table; column 0 is the source id, column 1 the
    *                      destination id (both are stringified)
    * @param jaccardConfig tol is the maximum Jaccard DISTANCE passed to the
    *                      approximate similarity join
    * @return DataFrame with schema (srcId: String, dstId: String, similarity: Double),
    *         where similarity = 1 - JaccardDistance
    */
  def apply(spark: SparkSession, dataset: Dataset[Row], jaccardConfig: JaccardConfig): DataFrame = {

    val jaccardResult: RDD[Row] = execute(spark, dataset, jaccardConfig.tol)

    val schema = StructType(
      List(
        StructField("srcId", StringType, nullable = true),
        StructField("dstId", StringType, nullable = true),
        StructField("similarity", DoubleType, nullable = true)
      ))
    val algoResult = spark.sqlContext.createDataFrame(jaccardResult, schema)
    algoResult
  }

  /**
    * Compute approximate Jaccard similarity between all node pairs whose
    * 1-degree neighborhoods (union of in- and out-neighbors) overlap.
    *
    * @param tol maximum Jaccard distance for approxSimilarityJoin; pairs
    *            farther apart than tol are not emitted
    * @return RDD of Row(node1, node2, similarity) with node1 < node2
    *         lexicographically and self-pairs removed
    */
  def execute(spark: SparkSession, dataset: Dataset[Row], tol: Double): RDD[Row] = {
    // compute the node's 1-degree neighbor set
    import spark.implicits._
    val edges = dataset
      .map(row => {
        (row.get(0).toString, row.get(1).toString)
      })
      .rdd

    // get in-degree neighbors (key: dst, values: all srcs pointing at it)
    val inputNodeVector: RDD[(String, List[String])] = edges
      .map(_.swap)
      .combineByKey((v: String) => List(v),
                    (c: List[String], v: String) => v :: c,
                    (c1: List[String], c2: List[String]) => c1 ::: c2)
      .repartition(NUM_PARTITIONS)

    // get out-degree neighbors (key: src, values: all dsts it points at)
    val outputNodeVector: RDD[(String, List[String])] = edges
      .combineByKey(
        (v: String) => List(v),
        (c: List[String], v: String) => v :: c,
        (c1: List[String], c2: List[String]) => c1 ::: c2
      )
      .repartition(NUM_PARTITIONS)

    // combine the neighbors: after a fullOuterJoin at least one side is
    // defined for every key, but we fall back to Nil defensively anyway.
    val nodeVector: RDD[(String, List[String])] = inputNodeVector
      .fullOuterJoin(outputNodeVector)
      .map {
        case (node, (inNeighbors, outNeighbors)) =>
          val neighbors = (inNeighbors, outNeighbors) match {
            // only deduplicate when merging the two sides, matching the
            // original behavior (single-sided lists are passed through as-is;
            // duplicates are harmless because CountVectorizer is binary)
            case (Some(in), Some(out)) => (in ::: out).distinct
            case (Some(in), None)      => in
            case (None, other)         => other.getOrElse(List.empty[String])
          }
          (node, neighbors)
      }

    // Preprocess the input data, process it into a 0-1 vector in the form of bag of words
    val inputNodeVectorDF = spark.createDataFrame(nodeVector).toDF("node", "neighbors")
    val cvModel: CountVectorizerModel =
      new CountVectorizer()
        .setInputCol("neighbors")
        .setOutputCol("features")
        .setBinary(true) // presence/absence only; multiplicity is ignored
        .fit(inputNodeVectorDF)

    val inputNodeVectorDFSparse: DataFrame =
      cvModel.transform(inputNodeVectorDF).select("node", "features")

    // MinHashLSH requires at least one non-zero entry per vector, so drop
    // nodes whose neighbor vector is all zeros.
    val nodeVectorDFSparseFilter = spark
      .createDataFrame(
        inputNodeVectorDFSparse.rdd
          .map(row => (row.getAs[String]("node"), row.getAs[SparseVector]("features")))
          .map(x => (x._1, x._2, x._2.numNonzeros))
          .filter(x => x._3 >= 1)
          .map(x => (x._1, x._2)))
      .toDF("node", "features")

    // call ml's MinHashLSH to approximate the Jaccard distance; tol is the
    // maximum distance a pair may have to be emitted by the join
    val mh = new MinHashLSH().setNumHashTables(NUM_PARTITIONS).setInputCol("features").setOutputCol("hashes")
    val model: MinHashLSHModel = mh.fit(nodeVectorDFSparseFilter)
    val nodeDistance: DataFrame = model
      .approxSimilarityJoin(nodeVectorDFSparseFilter,
                            nodeVectorDFSparseFilter,
                            tol,
                            "JaccardDistance")
      .select(col("datasetA.node").alias("node1"),
              col("datasetB.node").alias("node2"),
              col("JaccardDistance"))

    // convert distance to similarity, canonicalize pair order so (a,b) and
    // (b,a) collapse to one row, and drop self-joins
    val nodeOverlapRatio = nodeDistance.rdd
      .map(x => {
        val node1        = x.getString(0)
        val node2        = x.getString(1)
        val overlapRatio = 1 - x.getDouble(2)
        if (node1 < node2) ((node1, node2), overlapRatio) else ((node2, node1), overlapRatio)
      })
      .filter(x => x._1._1 != x._1._2)
      .map(row => {
        Row(row._1._1, row._1._2, row._2)
      })

    nodeOverlapRatio.distinct()
  }
}

nebula-algorithm/src/test/resources/edge.csv

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,4 @@ src,dst,weight
1414
4,1,1.0
1515
4,2,5.0
1616
4,3,1.0
17-
4,4,5.0
17+
4,4,5.0
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
/* Copyright (c) 2021 vesoft inc. All rights reserved.
2+
*
3+
* This source code is licensed under Apache 2.0 License.
4+
*/
5+
6+
package com.vesoft.nebula.algorithm.lib
7+
8+
import com.vesoft.nebula.algorithm.config.{JaccardConfig, KCoreConfig}
9+
import org.apache.spark.sql.SparkSession
10+
import org.junit.Test
11+
12+
class JaccardAlgoSuite {

  /**
    * Smoke test for [[JaccardAlgo]]: with tol = 0.01 (only near-identical
    * neighborhoods pass the join), the bundled edge.csv yields 6 pairs.
    * Renamed from the copy-pasted `kcoreSuite` to match the algorithm under test.
    */
  @Test
  def jaccardSuite(): Unit = {
    val spark         = SparkSession.builder().master("local").getOrCreate()
    val data          = spark.read.option("header", true).csv("src/test/resources/edge.csv")
    // case class: use apply instead of `new`
    val jaccardConfig = JaccardConfig(0.01)
    val jaccardResult = JaccardAlgo.apply(spark, data, jaccardConfig)
    jaccardResult.show()
    assert(jaccardResult.count() == 6)
  }
}

0 commit comments

Comments
 (0)