Skip to content

Commit

Permalink
load/dump ubodt (#21)
Browse files Browse the repository at this point in the history
* not ready

* not ready

* not ready

* fix

* test ubodt load/dump

* good

* fix test

---------

Co-authored-by: TANG ZHIXIONG <[email protected]>
  • Loading branch information
district10 and zhixiong-tang authored Mar 8, 2024
1 parent f7e5fc1 commit 7828111
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 13 deletions.
4 changes: 2 additions & 2 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@
# built documents.
#
# The short X.Y version.
version = '0.1.7'
version = '0.1.8'
# The full version, including alpha/beta/rc tags.
release = '0.1.7'
release = '0.1.8'

# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ build-backend = "scikit_build_core.build"

[project]
name = "networkx_graph"
version = "0.1.7"
version = "0.1.8"
url = "https://github.com/cubao/networkx-graph"
description = "Some customized graph algorithms"
readme = "README.md"
Expand Down
140 changes: 131 additions & 9 deletions src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -885,6 +885,9 @@ struct DiGraph
std::vector<UbodtRecord> build_ubodt(double thresh, int pool_size = 1,
int nodes_thresh = 100) const
{
if (pool_size > 1 && nodes_.size() > nodes_thresh) {
return build_ubodt_parallel(thresh, pool_size);
}
auto records = std::vector<UbodtRecord>();
for (auto &kv : nodes_) {
auto rows = build_ubodt(kv.first, thresh);
Expand Down Expand Up @@ -919,7 +922,11 @@ struct DiGraph
return records;
}

// TODO, batching
std::vector<UbodtRecord> build_ubodt_parallel(double thresh,
int poolsize) const
{
return {};
}

void freeze() { freezed_ = true; }
void build() const {}
Expand Down Expand Up @@ -1736,6 +1743,10 @@ struct ShortestPathWithUbodt
ShortestPathWithUbodt(const DiGraph *graph,
const std::vector<UbodtRecord> &ubodt)
: graph(graph)
{
load_ubodt(ubodt);
}
void load_ubodt(const std::vector<UbodtRecord> &ubodt)
{
for (auto &r : ubodt) {
this->ubodt.emplace(std::make_pair(r.source_road, r.target_road),
Expand All @@ -1754,10 +1765,36 @@ struct ShortestPathWithUbodt
std::sort(items.begin(), items.end());
}
}
ShortestPathWithUbodt(const DiGraph *graph, double thresh)
: ShortestPathWithUbodt(graph, graph->build_ubodt(thresh))
ShortestPathWithUbodt(const DiGraph *graph, double thresh,
int pool_size = 1, int nodes_thresh = 100)
: ShortestPathWithUbodt(
graph, graph->build_ubodt(thresh, pool_size, nodes_thresh))
{
}
ShortestPathWithUbodt(const DiGraph *graph, const std::string &path)
{
load_ubodt(path);
}
void load_ubodt(const std::string &path)
{
return load_ubodt(Load_Ubodt(path));
}
std::vector<UbodtRecord> dump_ubodt() const
{
std::vector<UbodtRecord> rows;
rows.reserve(ubodt.size());
for (auto &pair : ubodt) {
rows.push_back(pair.second);
}
std::sort(rows.begin(), rows.end());
return rows;
}
bool dump_ubodt(const std::string &path) const
{
return Dump_Ubodt(dump_ubodt(), path);
}
// Number of records currently held in the UBODT table.
size_t size() const
{
    return ubodt.size();
}

std::vector<std::tuple<double, std::string>>
by_source(const std::string &source, std::optional<double> cutoff) const
{
Expand All @@ -1781,6 +1818,54 @@ struct ShortestPathWithUbodt
}
return path(*src_idx, *dst_idx);
}
// Look up the precomputed cost between two named nodes.
//
// Returns an empty optional when either name is unknown to the graph's
// indexer or when the (source, target) pair has no record in the table.
std::optional<double> dist(const std::string &source,
                           const std::string &target) const
{
    std::optional<double> result;
    if (const auto from = graph->indexer().get_id(source)) {
        if (const auto to = graph->indexer().get_id(target)) {
            const auto hit = ubodt.find({*from, *to});
            if (hit != ubodt.end()) {
                result = hit->second.cost;
            }
        }
    }
    return result;
}

// Read UBODT records from a binary file written by Dump_Ubodt().
//
// The file is a plain sequence of UbodtRecord objects copied byte-for-byte,
// so it is only readable by a build with the same record layout.
// NOTE(review): reading through &row.source_road assumes source_road is the
// first member of UbodtRecord -- confirm against the struct definition.
//
// @param path  file to read
// @return      the records in file order; empty if the file cannot be
//              opened or its size cannot be determined. A trailing partial
//              record is ignored.
static std::vector<UbodtRecord> Load_Ubodt(const std::string &path)
{
    // Open positioned at the end so tellg() reports the file size.
    std::ifstream f(path.c_str(), std::ios::binary | std::ios::ate);
    if (!f.is_open()) {
        return {};
    }
    // tellg() yields -1 on failure; casting that to size_t would request an
    // enormous reserve() and throw, so bail out instead.
    const std::streamoff bytes = f.tellg();
    if (bytes < 0) {
        return {};
    }
    std::vector<UbodtRecord> rows;
    rows.reserve(static_cast<size_t>(bytes) / sizeof(UbodtRecord));
    f.seekg(0);
    UbodtRecord row;
    while (f.read(reinterpret_cast<char *>(&row.source_road),
                  sizeof(UbodtRecord))) {
        rows.push_back(row);
    }
    return rows;
}
// Write UBODT records to a binary file, one raw UbodtRecord after another
// (the format Load_Ubodt() reads back).
// NOTE(review): writing from &row.source_road assumes source_road is the
// first member of UbodtRecord -- confirm against the struct definition.
//
// @param ubodt  records to write, in the order given
// @param path   destination file (truncated if it exists)
// @return       true only if the file was opened and every byte was written
//               and flushed; false on open, write or close failure.
static bool Dump_Ubodt(const std::vector<UbodtRecord> &ubodt,
                       const std::string &path)
{
    std::ofstream f(path.c_str(), std::ios::binary);
    if (!f.is_open()) {
        return false;
    }
    for (const auto &row : ubodt) {
        f.write(reinterpret_cast<const char *>(&row.source_road),
                sizeof(row));
    }
    // Close explicitly so buffered data is flushed before the stream state
    // is inspected; a failed write (e.g. disk full) must not be reported as
    // success.
    f.close();
    return !f.fail();
}

private:
std::optional<Path> path(int64_t source, int64_t target) const
Expand Down Expand Up @@ -2817,7 +2902,8 @@ PYBIND11_MODULE(_core, m)
.def("build_ubodt",
py::overload_cast<double, int, int>(&DiGraph::build_ubodt,
py::const_),
"thresh"_a, py::kw_only(), "pool_size"_a = 1,
"thresh"_a, py::kw_only(), //
"pool_size"_a = 1, //
"nodes_thresh"_a = 100)
.def("build_ubodt",
py::overload_cast<int64_t, double>(&DiGraph::build_ubodt,
Expand All @@ -2831,16 +2917,52 @@ PYBIND11_MODULE(_core, m)
py::dynamic_attr()) //
.def(py::init<const DiGraph *, const std::vector<UbodtRecord> &>(),
"graph"_a, "ubodt"_a)
.def(py::init<const DiGraph *, double>(), "graph"_a, "thresh"_a)
.def(py::init<const DiGraph *, double, int, int>(), //
"graph"_a, "thresh"_a, py::kw_only(), //
"pool_size"_a = 1, //
"nodes_thresh"_a = 100)
.def(py::init<const DiGraph *, const std::string &>(), //
"graph"_a, "path"_a)
//
.def(
"load_ubodt",
[](ShortestPathWithUbodt &self, const std::string &path) {
return self.load_ubodt(path);
},
"path"_a)
.def(
"load_ubodt",
[](ShortestPathWithUbodt &self,
const std::vector<UbodtRecord> &rows) {
return self.load_ubodt(rows);
},
"rows"_a)
.def(
"dump_ubodt",
[](const ShortestPathWithUbodt &self) { return self.dump_ubodt(); })
.def("dump_ubodt",
[](const ShortestPathWithUbodt &self, const std::string &path) {
return self.dump_ubodt(path);
})
.def("size", &ShortestPathWithUbodt::size)
//
.def("by_source", &ShortestPathWithUbodt::by_source, "source"_a,
"cutoff"_a = std::nullopt)
.def("by_target", &ShortestPathWithUbodt::by_target, "target"_a,
"cutoff"_a = std::nullopt)
.def_static("Load_Ubodt", &ShortestPathWithUbodt::Load_Ubodt, //
"path"_a)
.def_static("Dump_Ubodt", &ShortestPathWithUbodt::Dump_Ubodt, //
"ubodt"_a, "path"_a)
//
.def("by_source", &ShortestPathWithUbodt::by_source, //
"source"_a, "cutoff"_a = std::nullopt)
.def("by_target", &ShortestPathWithUbodt::by_target, //
"target"_a, "cutoff"_a = std::nullopt)
.def("path",
py::overload_cast<const std::string &, const std::string &>(
&ShortestPathWithUbodt::path, py::const_),
"source"_a, "target"_a)
.def("dist",
py::overload_cast<const std::string &, const std::string &>(
&ShortestPathWithUbodt::dist, py::const_),
"source"_a, "target"_a)
//
;

Expand Down
36 changes: 35 additions & 1 deletion tests/test_basic.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from __future__ import annotations

import hashlib
import tempfile

import pytest

import networkx_graph as m
Expand All @@ -15,8 +18,19 @@
)


def calculate_md5(filename, block_size=4096):
    """Return the hex MD5 digest of a file, or None if it cannot be read.

    The file is consumed in chunks of ``block_size`` bytes so that large
    files are never loaded into memory at once.
    """
    digest = hashlib.md5()
    try:
        with open(filename, "rb") as stream:  # noqa: PTH123
            while True:
                chunk = stream.read(block_size)
                if chunk == b"":
                    break
                digest.update(chunk)
    except OSError:
        return None
    return digest.hexdigest()


def test_version():
assert m.__version__ == "0.1.7"
assert m.__version__ == "0.1.8"


def test_add():
Expand Down Expand Up @@ -1356,6 +1370,26 @@ def test_ubodt():
path = spath.path("w1", "w4")
path2 = Path.Build(G, path.nodes)
assert path.to_dict() == path2.to_dict()
assert path.dist == spath.dist("w1", "w4") == 10.0

rows2 = rows[5:] + rows[:5]
assert rows2 != rows
assert sorted(rows2) == rows
assert spath.dump_ubodt() == rows
assert spath.size() == len(rows) == 15

with tempfile.TemporaryDirectory() as dir:
ubodt_path = f"{dir}/ubodt.bin"
assert spath.dump_ubodt(ubodt_path)
md5 = calculate_md5(ubodt_path)
assert md5 == "f2c5dced545563b8f5fff3a6a52985f7"
spath2 = ShortestPathWithUbodt(G2, ubodt_path)
assert spath2.dump_ubodt() == rows
assert spath2.size() == 15
assert ShortestPathWithUbodt.Load_Ubodt(ubodt_path) == rows
ubodt_path2 = f"{dir}/ubodt2.bin"
assert ShortestPathWithUbodt.Dump_Ubodt(rows, ubodt_path2)
assert calculate_md5(ubodt_path2) == md5

path2 = Path.Build(G, path.nodes, start_offset=5.0, end_offset=17.0)
assert path2.dist == 32.0
Expand Down

0 comments on commit 7828111

Please sign in to comment.