Skip to content

Commit

Permalink
DBSCAN notebook Twitter example (#638)
Browse files Browse the repository at this point in the history
* Twitter DBSCAN example

Signed-off-by: Hongzhe Cheng <[email protected]>

* Parquet Save

---------

Signed-off-by: Hongzhe Cheng <[email protected]>
  • Loading branch information
Er1cCheng committed May 3, 2024
1 parent 96de8d2 commit 9f9b4e6
Showing 1 changed file with 105 additions and 1 deletion.
106 changes: 105 additions & 1 deletion notebooks/dbscan.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,110 @@
"plt.grid(True)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Twitter Dataset"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Download Data and Store to Parquet"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Download the Twitter geo-coordinate dataset (HDF5), convert it to Parquet,\n",
"# and load it back as a Spark DataFrame for DBSCAN.\n",
"\n",
"# Full dataset\n",
"# !curl --output twitter.h5.h5 https://b2share.eudat.eu/api/files/189c8eaf-d596-462b-8a07-93b5922c4a9f/twitter.h5.h5\n",
"\n",
"# Partial small dataset\n",
"!curl --output twitterSmall.h5.h5 https://b2share.eudat.eu/api/files/189c8eaf-d596-462b-8a07-93b5922c4a9f/twitterSmall.h5.h5\n",
"\n",
"import h5py\n",
"import pandas as pd  # needed below; do not rely on an earlier cell having imported it\n",
"import pyarrow\n",
"import pyarrow.parquet as pq\n",
"\n",
"# Read the 'DBSCAN' dataset (two coordinate columns) fully into memory.\n",
"with h5py.File('twitterSmall.h5.h5', 'r') as f: \n",
"    data = f[\"DBSCAN\"][:]\n",
"\n",
"df = pd.DataFrame(data, columns=['f1', 'f2'])\n",
"arrow_table = pyarrow.Table.from_pandas(df)\n",
"\n",
"# REMEMBER to change the dbfs path to your designated space\n",
"# Or to local like \"./twitter.parquet\" (then use the same path for both below)\n",
"# NOTE(review): on Databricks, local file APIs such as pyarrow write through the\n",
"# '/dbfs/...' FUSE mount, while Spark reads the same file via the 'dbfs:/...'\n",
"# scheme — passing the FUSE path to spark.read resolves to dbfs:/dbfs/... and\n",
"# fails. Hence two path spellings for the same file.\n",
"fuse_path = \"/dbfs/temp/twitter.parquet\"\n",
"spark_path = \"dbfs:/temp/twitter.parquet\"\n",
"pq.write_table(arrow_table, fuse_path)\n",
"\n",
"df = spark.read.parquet(spark_path)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Run DBSCAN over Twitter Dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Run GPU DBSCAN over the Twitter coordinates, time it, and plot the clusters.\n",
"start_time = time.time()\n",
"\n",
"eps = 0.1\n",
"gpu_dbscan = DBSCAN(eps=eps, min_samples=40, metric=\"euclidean\")\n",
"gpu_dbscan.setFeaturesCols([\"f1\", \"f2\"])\n",
"gpu_model = gpu_dbscan.fit(df)\n",
"gpu_model.setPredictionCol(\"prediction\")\n",
"transformed = gpu_model.transform(df)\n",
"transformed.show()\n",
"\n",
"end_time = time.time()\n",
"elapsed_time = (end_time - start_time)\n",
"\n",
"print(\"Time\", elapsed_time)\n",
"\n",
"# Select the columns explicitly so the numpy layout is guaranteed to be\n",
"# [f1, f2, prediction]; previously column 2 was assumed positionally.\n",
"dbscan_np = transformed.select(\"f1\", \"f2\", \"prediction\").toPandas().to_numpy()\n",
"\n",
"# Cluster labels are 0..n_cluster; noise points are labeled -1.\n",
"n_cluster = max(dbscan_np[:,2])\n",
"clusters = [[[],[]] for i in range(int(n_cluster) + 1)]\n",
"\n",
"# Bucket each point's (x, y) into its cluster, skipping noise.\n",
"for p in dbscan_np:\n",
"    if int(p[2]) == -1:\n",
"        continue\n",
"\n",
"    clusters[int(p[2])][0].append(p[0])\n",
"    clusters[int(p[2])][1].append(p[1])\n",
"\n",
"# Draw the largest clusters first so smaller ones stay visible on top.\n",
"clusters = sorted(clusters, key=lambda x: len(x[0]), reverse=True)\n",
"print(\"Number of clusters: \", len(clusters))\n",
"\n",
"for i, c in enumerate(clusters):\n",
"    plt.scatter(c[0], c[1], s=0.5, label=f\"cluster {i}\")\n",
"\n",
"plt.xlabel('X')\n",
"plt.ylabel('Y')\n",
"plt.title(f'Twitter API Geo Clusters with DBSCAN eps={eps}')\n",
"# To save the figure, uncomment savefig; it must run BEFORE plt.show(),\n",
"# otherwise the saved image is blank (show() clears the current figure).\n",
"# plt.savefig('plot.png', dpi=1200)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down Expand Up @@ -509,7 +613,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
"version": "3.9.19"
},
"vscode": {
"interpreter": {
Expand Down

0 comments on commit 9f9b4e6

Please sign in to comment.