Commit
Update to Spark 2.2.1, Neo4j 3.3.6, Allow bolt+routing API usage + separation into read and write transactions
jexp committed Jul 4, 2018
1 parent 5b90168 commit 8146368
Showing 5 changed files with 99 additions and 59 deletions.
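
The headline change is support for bolt+routing URLs, which the bumped 1.6 driver can use against a causal cluster. A minimal sketch of pointing the connector at such an endpoint, assuming the connector's spark.neo4j.bolt.* configuration keys and a hypothetical cluster address:

    import org.apache.spark.{SparkConf, SparkContext}

    // Hypothetical endpoint; the bolt+routing scheme lets the driver
    // route reads to followers and writes to the cluster leader.
    val conf = new SparkConf()
      .setAppName("neo4j-spark-routing")
      .set("spark.neo4j.bolt.url", "bolt+routing://cluster-core:7687")
      .set("spark.neo4j.bolt.user", "neo4j")
      .set("spark.neo4j.bolt.password", "secret")
    val sc = new SparkContext(conf)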
20 changes: 17 additions & 3 deletions pom.xml
@@ -4,7 +4,7 @@
<modelVersion>4.0.0</modelVersion>
<groupId>neo4j-contrib</groupId>
<packaging>jar</packaging>
<version>2.1.0-M4</version>
<version>2.2.1-M5</version>
<artifactId>neo4j-spark-connector</artifactId>
<name>neo4j-spark-connector</name>

@@ -30,9 +30,10 @@

<netty.version>4.1.8.Final</netty.version>
<neo4j.version>3.2.3</neo4j.version>
<driver.version>1.4.2</driver.version>
<spark.version>2.1.1</spark.version>
<driver.version>1.6.1</driver.version>
<spark.version>2.2.1</spark.version>
<graphframes.version>0.5.0-spark2.1-s_2.11</graphframes.version>
<bouncycastle.version>1.53</bouncycastle.version>
</properties>

<build>
@@ -197,6 +198,19 @@
<version>${netty.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.bouncycastle</groupId>
<artifactId>bcprov-jdk15on</artifactId>
<version>${bouncycastle.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.bouncycastle</groupId>
<artifactId>bcpkix-jdk15on</artifactId>
<version>${bouncycastle.version}</version>
<scope>test</scope>
</dependency>

</dependencies>

<repositories>
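The driver bump to 1.6.1 is what makes the routing scheme available; the Bouncy Castle artifacts are test-scoped additions, presumably needed by the newer driver and Neo4j test tooling for certificate handling. For orientation, a hedged sketch of opening a routing driver directly with this driver version (address and credentials are placeholders):

    import org.neo4j.driver.v1.{AuthTokens, GraphDatabase}

    // bolt+routing requires a causal cluster; against a single
    // instance the plain bolt:// scheme still applies.
    val driver = GraphDatabase.driver(
      "bolt+routing://cluster-core:7687",
      AuthTokens.basic("neo4j", "secret"))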
2 changes: 1 addition & 1 deletion src/main/java/org/neo4j/spark/Neo4JavaSparkContext.java
@@ -51,6 +51,6 @@ public Dataset<Row> queryDF(final String query, final Map<String,Object> parameters
}

public Dataset<Row> queryDF(final String query, final Map<String,Object> parameters) {
return Neo4jDataFrame.apply(sqlContext, query,parameters);
return Neo4jDataFrame.apply(sqlContext, query,parameters, false);
}
}
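
The extra false argument is the new write flag on Neo4jDataFrame.apply, so the Java queryDF path is now explicitly a read transaction. A hedged Scala sketch of the equivalent direct call (sqlContext is assumed to be in scope):

    import scala.collection.JavaConverters._

    // Runs under session.readTransaction; pass write = true for
    // statements that modify the graph.
    val df = Neo4jDataFrame.apply(sqlContext,
      "MATCH (p:Person) RETURN p.name AS name",
      Map.empty[String, AnyRef].asJava)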
72 changes: 40 additions & 32 deletions src/main/scala/org/neo4j/spark/Neo4j.scala
@@ -9,7 +9,7 @@ import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
import org.apache.spark.{Partition, SparkContext, TaskContext}
import org.graphframes.GraphFrame
import org.neo4j.driver.v1.{Driver, Session, StatementResult}
import org.neo4j.driver.v1._
import org.neo4j.spark.Neo4j.{LoadDsl, NameProp, PartitionsDsl, Pattern, QueriesDsl, Query, SaveDsl}

import scala.collection.JavaConverters._
@@ -378,7 +378,7 @@ object Executor {
i
}

def execute(config: Neo4jConfig, query: String, parameters: Map[String, Any]): CypherResult = {
def execute(config: Neo4jConfig, query: String, parameters: Map[String, Any], write: Boolean = false): CypherResult = {

def close(driver: Driver, session: Session) = {
try {
@@ -395,39 +395,47 @@
val session = driver.session()

try {
val result: StatementResult = session.run(query, toJava(parameters))
if (!result.hasNext) {
result.consume()
session.close()
driver.close()
return new CypherResult(new StructType(), Iterator.empty)
}
val peek = result.peek()
val keyCount = peek.size()
if (keyCount == 0) {
val res: CypherResult = new CypherResult(new StructType(), Array.fill[Array[Any]](rows(result))(EMPTY).toIterator)
result.consume()
close(driver,session)
return res
}
val keys = peek.keys().asScala
val fields = keys.map(k => (k, peek.get(k).`type`())).map(keyType => CypherTypes.field(keyType))
val schema = StructType(fields)

val it = result.asScala.map((record) => {
val row = new Array[Any](keyCount)
var i = 0
while (i < keyCount) {
row.update(i, record.get(i).asObject())
i = i + 1
}
val runner = new TransactionWork[CypherResult]() { override def execute(tx:Transaction) : CypherResult = {
val result: StatementResult = tx.run(query, toJava(parameters))
if (!result.hasNext) {
result.consume()
close(driver,session)
session.close()
driver.close()
return new CypherResult(new StructType(), Iterator.empty)
}
row
})
new CypherResult(schema, it)
val peek = result.peek()
val keyCount = peek.size()
if (keyCount == 0) {
val res: CypherResult = new CypherResult(new StructType(), Array.fill[Array[Any]](rows(result))(EMPTY).toIterator)
result.consume()
close(driver, session)
return res
}
val keys = peek.keys().asScala
val fields = keys.map(k => (k, peek.get(k).`type`())).map(keyType => CypherTypes.field(keyType))
val schema = StructType(fields)

val it = result.asScala.map((record) => {
val row = new Array[Any](keyCount)
var i = 0
while (i < keyCount) {
row.update(i, record.get(i).asObject())
i = i + 1
}
if (!result.hasNext) {
result.consume()
close(driver, session)
}
row
})
new CypherResult(schema, it)
}}

if (write)
session.writeTransaction(runner)
else
session.readTransaction(runner)

} finally {
close(driver,session)
}
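The net effect of this hunk is that every Cypher call in Executor now runs inside a driver-managed transaction function and is dispatched to readTransaction or writeTransaction based on the new write parameter. A minimal sketch of that dispatch pattern with the 1.x driver API (standalone names, not the connector's code):

    import org.neo4j.driver.v1.{Driver, Session, Transaction, TransactionWork}
    import org.neo4j.driver.v1.summary.ResultSummary

    def runManaged(driver: Driver, query: String, write: Boolean): ResultSummary = {
      val session: Session = driver.session()
      try {
        val work = new TransactionWork[ResultSummary] {
          // tx.run executes inside the managed transaction; a routing
          // driver sends it to a writer or reader member accordingly.
          override def execute(tx: Transaction): ResultSummary =
            tx.run(query).consume()
        }
        if (write) session.writeTransaction(work) else session.readTransaction(work)
      } finally {
        session.close()
      }
    }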
55 changes: 35 additions & 20 deletions src/main/scala/org/neo4j/spark/Neo4jDataFrame.scala
@@ -4,10 +4,11 @@ import java.util

import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types.{DataType, StructField, StructType}
import org.apache.spark.sql.{Column, DataFrame, Row, SQLContext, types}
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
import org.apache.spark.sql.types.{DataType, DataTypes, StructField, StructType}
import org.neo4j.driver.internal.types.InternalTypeSystem
import org.neo4j.driver.v1._
import org.neo4j.driver.v1.summary.ResultSummary
import org.neo4j.driver.v1.types.{Type, TypeSystem}

import scala.collection.JavaConverters._
@@ -30,22 +31,29 @@
"target" -> target._2.map( c => (c, r.getAs[AnyRef](c))).toMap.asJava,
"relationship" -> relationship._2.map( c => (c, r.getAs[AnyRef](c))).toMap.asJava)
.asJava).asJava
execute(config, mergeStatement, Map("rows" -> params).asJava)
execute(config, mergeStatement, Map("rows" -> params).asJava, write = true)
})
}

def execute(config : Neo4jConfig, query: String, parameters: java.util.Map[String, AnyRef]) = {
def execute(config : Neo4jConfig, query: String, parameters: java.util.Map[String, AnyRef], write: Boolean = false) : ResultSummary = {
val driver: Driver = config.driver()
val session = driver.session()
try {
session.run(query, parameters).consume()
val runner = new TransactionWork[ResultSummary]() { override def execute(tx:Transaction) : ResultSummary =
tx.run(query, parameters).consume()
}
if (write) {
session.writeTransaction(runner)
}
else
session.readTransaction(runner)
} finally {
if (session.isOpen) session.close()
driver.close()
}
}

def withDataType(sqlContext: SQLContext, query: String, parameters: Seq[(String, Any)], schema: (String, types.DataType)*) = {
def withDataType(sqlContext: SQLContext, query: String, parameters: Seq[(String, Any)], schema: (String, DataType)*) = {
val rowRdd = Neo4jRowRDD(sqlContext.sparkContext, query, parameters)
sqlContext.createDataFrame(rowRdd, CypherTypes.schemaFromDataType(schema))
}
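
With the new signature, graph-mutating callers have to opt in via write = true (as mergeEdgeList now does above). A hedged usage sketch, reusing the Neo4jConfig(sparkContext.getConf) construction seen later in this file, with a hypothetical MERGE statement and an assumed SparkContext sc:

    import scala.collection.JavaConverters._

    val config = Neo4jConfig(sc.getConf)
    val params: java.util.Map[String, AnyRef] =
      Map[String, AnyRef]("name" -> "Alice").asJava

    // write = true routes the statement through session.writeTransaction
    val summary = Neo4jDataFrame.execute(config,
      "MERGE (p:Person {name: $name})", params, write = true)
    println(summary.counters().nodesCreated())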
@@ -55,18 +63,25 @@
sqlContext.createDataFrame(rowRdd, CypherTypes.schemaFromNamedType(schema))
}

def apply(sqlContext: SQLContext, query: String, parameters: java.util.Map[String, AnyRef]) = {
def apply(sqlContext: SQLContext, query: String, parameters: java.util.Map[String, AnyRef], write: Boolean = false) : DataFrame = {
val config = Neo4jConfig(sqlContext.sparkContext.getConf)
val driver: Driver = config.driver()
val session = driver.session()
try {
val result = session.run(query, parameters)
if (!result.hasNext) throw new RuntimeException("Can't determine schema from empty result")
val peek: Record = result.peek()
val fields = peek.keys().asScala.map(k => (k, peek.get(k).`type`())).map(keyType => CypherTypes.field(keyType))
val schema = StructType(fields)
val rowRdd = new Neo4jResultRdd(sqlContext.sparkContext, result.asScala, peek.size(), session, driver)
sqlContext.createDataFrame(rowRdd, schema)
val runTransaction = new TransactionWork[DataFrame]() {
override def execute(tx:Transaction) : DataFrame = {
val result = session.run(query, parameters)
if (!result.hasNext) throw new RuntimeException("Can't determine schema from empty result")
val peek: Record = result.peek()
val fields = peek.keys().asScala.map(k => (k, peek.get(k).`type`())).map(keyType => CypherTypes.field(keyType))
val schema = StructType(fields)
val rowRdd = new Neo4jResultRdd(sqlContext.sparkContext, result.asScala, peek.size(), session, driver)
sqlContext.createDataFrame(rowRdd, schema)
}}
if (write)
session.writeTransaction(runTransaction)
else
session.readTransaction(runTransaction)
} finally {
if (session.isOpen) session.close()
driver.close()
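One subtlety visible in this hunk: the transaction-function wrapper still issues session.run and hands a lazy iterator (plus the open session and driver) to Neo4jResultRdd. Purely as a hypothetical comparison, a sketch that runs via tx.run and materializes rows before the read transaction completes (not the committed code; query, parameters and session are assumed from the enclosing method, and the driver and JavaConverters imports from the top of this file):

    val rows: Seq[org.apache.spark.sql.Row] =
      session.readTransaction(new TransactionWork[Seq[org.apache.spark.sql.Row]] {
        override def execute(tx: Transaction): Seq[org.apache.spark.sql.Row] = {
          val result = tx.run(query, parameters)
          // list() drains the result eagerly, so nothing escapes the transaction
          result.list().asScala
            .map(r => org.apache.spark.sql.Row(r.values().asScala.map(_.asObject()): _*))
        }
      })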
@@ -106,11 +121,11 @@
}

object CypherTypes {
val INTEGER = types.LongType
val FlOAT = types.DoubleType
val STRING = types.StringType
val BOOLEAN = types.BooleanType
val NULL = types.NullType
val INTEGER = DataTypes.LongType
val FlOAT = DataTypes.DoubleType
val STRING = DataTypes.StringType
val BOOLEAN = DataTypes.BooleanType
val NULL = DataTypes.NullType

def apply(typ:String) = typ.toUpperCase match {
case "LONG" => INTEGER
@@ -154,7 +169,7 @@ object CypherTypes {
StructField(field._1, CypherTypes(field._2), nullable = true) )
StructType(fields)
}
def schemaFromDataType(schemaInfo: Seq[(String, types.DataType)]) = {
def schemaFromDataType(schemaInfo: Seq[(String, DataType)]) = {
val fields = schemaInfo.map(field =>
StructField(field._1, field._2, nullable = true) )
StructType(fields)
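schemaFromDataType backs the withDataType helper shown earlier, which now takes Spark DataType values directly. A hedged usage sketch with hypothetical column names and a parameterized LIMIT:

    import org.apache.spark.sql.types.DataTypes

    // Explicit schema, so no first-record peeking is needed.
    val df = Neo4jDataFrame.withDataType(sqlContext,
      "MATCH (p:Person) RETURN p.name AS name, p.age AS age LIMIT $limit",
      Seq("limit" -> 100L),
      "name" -> DataTypes.StringType,
      "age" -> DataTypes.LongType)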
9 changes: 6 additions & 3 deletions src/test/scala/org/neo4j/spark/Neo4jSparkTest.scala
@@ -111,21 +111,24 @@ class Neo4jSparkTest {
assertEquals(1000,graph.edges.count())

val top3: Array[(VertexId, Double)] = PageRank.run(graph,5).vertices.sortBy(v => v._2, ascending = false,5).take(3)
assertEquals(0.622D, top3(0)._2, 0.01)
// assertEquals(0.622D, top3(0)._2, 0.01) // Spark 2.1.2
assertEquals(1D, top3(0)._2, 0.01) // Spark 2.2.1
}
@Test def runSimplePatternRelQueryWithPartitionGraphFrame() {
val neo4j: Neo4j = Neo4j(sc).pattern(("Person","id"),("KNOWS",null), ("Person","id")).partitions(7).batch(200)
val graphFrame: GraphFrame = neo4j.loadGraphFrame
// graphFrame.edges.foreach(x => println(x))
assertEquals(100,graphFrame.vertices.count)
assertEquals(1000,graphFrame.edges.count)

val pageRankFrame: GraphFrame = graphFrame.pageRank.maxIter(5).run()
val pageRankFrame: GraphFrame = graphFrame.pageRank.maxIter(5).resetProbability(0.15).run()
val ranked: DataFrame = pageRankFrame.vertices
ranked.printSchema()
// sorting DF http://spark.apache.org/docs/latest/api/scala/index.html#org.apache.spark.sql.Column
val top3: Array[Row] = ranked.orderBy(ranked.col("pagerank").desc).take(3)
top3.foreach(println)
assertEquals(0.622D, top3(0).getAs[Double]("pagerank"), 0.01)
// assertEquals(0.622D, top3(0).getAs[Double]("pagerank"), 0.01) // Spark 2.1.2
assertEquals(1D, top3(0).getAs[Double]("pagerank"), 0.01) // Spark 2.2.1
}


