Skip to content

Commit

Permalink
Improved bit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed May 19, 2024
1 parent 362ced0 commit e82d96d
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 24 deletions.
30 changes: 22 additions & 8 deletions src/diesel_ext/bit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,19 +67,25 @@ mod tests {
diesel::sql_query("CREATE EXTENSION IF NOT EXISTS vector").execute(&mut conn)?;
diesel::sql_query("DROP TABLE IF EXISTS diesel_bit_items").execute(&mut conn)?;
diesel::sql_query(
"CREATE TABLE diesel_bit_items (id serial PRIMARY KEY, embedding bit(3))",
"CREATE TABLE diesel_bit_items (id serial PRIMARY KEY, embedding bit(10))",
)
.execute(&mut conn)?;

let new_items = vec![
NewItem {
embedding: Some(Bit::new(&[false, false, false])),
embedding: Some(Bit::new(&[
false, false, false, false, false, false, false, false, false, true,
])),
},
NewItem {
embedding: Some(Bit::new(&[true, false, true])),
embedding: Some(Bit::new(&[
true, false, true, false, false, false, false, false, false, true,
])),
},
NewItem {
embedding: Some(Bit::new(&[true, true, true])),
embedding: Some(Bit::new(&[
true, true, true, false, false, false, false, false, false, true,
])),
},
NewItem { embedding: None },
];
Expand All @@ -92,20 +98,26 @@ mod tests {
assert_eq!(4, all.len());

let neighbors = items::table
.order(items::embedding.hamming_distance(Bit::new(&[true, false, true])))
.order(items::embedding.hamming_distance(Bit::new(&[
true, false, true, false, false, false, false, false, false, true,
])))
.limit(5)
.load::<Item>(&mut conn)?;
assert_eq!(
vec![2, 3, 1, 4],
neighbors.iter().map(|v| v.id).collect::<Vec<i32>>()
);
assert_eq!(
Some(Bit::new(&[true, false, true])),
Some(Bit::new(&[
true, false, true, false, false, false, false, false, false, true
])),
neighbors.first().unwrap().embedding
);

let neighbors = items::table
.order(items::embedding.jaccard_distance(Bit::new(&[true, false, true])))
.order(items::embedding.jaccard_distance(Bit::new(&[
true, false, true, false, false, false, false, false, false, true,
])))
.limit(5)
.load::<Item>(&mut conn)?;
assert_eq!(
Expand All @@ -114,7 +126,9 @@ mod tests {
);

let distances = items::table
.select(items::embedding.hamming_distance(Bit::new(&[true, false, true])))
.select(items::embedding.hamming_distance(Bit::new(&[
true, false, true, false, false, false, false, false, false, true,
])))
.order(items::id)
.load::<Option<f64>>(&mut conn)?;
assert_eq!(vec![Some(2.0), Some(0.0), Some(1.0), None], distances);
Expand Down
28 changes: 19 additions & 9 deletions src/postgres_ext/bit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,26 +53,32 @@ mod tests {
client.execute("CREATE EXTENSION IF NOT EXISTS vector", &[])?;
client.execute("DROP TABLE IF EXISTS postgres_bit_items", &[])?;
client.execute(
"CREATE TABLE postgres_bit_items (id bigserial PRIMARY KEY, embedding bit(3))",
"CREATE TABLE postgres_bit_items (id bigserial PRIMARY KEY, embedding bit(10))",
&[],
)?;

let vec = Bit::new(&[true, false, true]);
let vec2 = Bit::new(&[false, true, false]);
let vec = Bit::new(&[
true, false, true, false, false, false, false, false, false, true,
]);
let vec2 = Bit::new(&[
false, true, false, false, false, false, false, false, false, true,
]);
client.execute(
"INSERT INTO postgres_bit_items (embedding) VALUES ($1), ($2), (NULL)",
&[&vec, &vec2],
)?;

let query_vec = Bit::new(&[true, false, true]);
let query_vec = Bit::new(&[
true, false, true, false, false, false, false, false, false, true,
]);
let row = client.query_one(
"SELECT embedding FROM postgres_bit_items ORDER BY embedding <~> $1 LIMIT 1",
&[&query_vec],
)?;
let res_vec: Bit = row.get(0);
assert_eq!(vec, res_vec);
assert_eq!(3, res_vec.len());
assert_eq!(&[0b10100000], res_vec.as_bytes());
assert_eq!(10, res_vec.len());
assert_eq!(&[0b10100000, 0b01000000], res_vec.as_bytes());

let null_row = client.query_one(
"SELECT embedding FROM postgres_bit_items WHERE embedding IS NULL LIMIT 1",
Expand All @@ -87,15 +93,19 @@ mod tests {
&[],
)?;
let text_res: String = text_row.get(0);
assert_eq!("101", text_res);
assert_eq!("1010000001", text_res);

// copy
let bit_type = Type::BIT;
let writer = client
.copy_in("COPY postgres_bit_items (embedding) FROM STDIN WITH (FORMAT BINARY)")?;
let mut writer = BinaryCopyInWriter::new(writer, &[bit_type]);
writer.write(&[&Bit::new(&[true, false, true])])?;
writer.write(&[&Bit::new(&[false, true, false])])?;
writer.write(&[&Bit::new(&[
true, false, true, false, false, false, false, false, false, true,
])])?;
writer.write(&[&Bit::new(&[
false, true, false, false, false, false, false, false, false, true,
])])?;
writer.finish()?;

Ok(())
Expand Down
20 changes: 13 additions & 7 deletions src/sqlx_ext/bit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,27 +57,33 @@ mod tests {
sqlx::query("DROP TABLE IF EXISTS sqlx_bit_items")
.execute(&pool)
.await?;
sqlx::query("CREATE TABLE sqlx_bit_items (id bigserial PRIMARY KEY, embedding bit(3))")
sqlx::query("CREATE TABLE sqlx_bit_items (id bigserial PRIMARY KEY, embedding bit(10))")
.execute(&pool)
.await?;

let vec = Bit::new(&[true, false, true]);
let vec2 = Bit::new(&[false, true, false]);
let vec = Bit::new(&[
true, false, true, false, false, false, false, false, false, true,
]);
let vec2 = Bit::new(&[
false, true, false, false, false, false, false, false, false, true,
]);
sqlx::query("INSERT INTO sqlx_bit_items (embedding) VALUES ($1), ($2), (NULL)")
.bind(&vec)
.bind(&vec2)
.execute(&pool)
.await?;

let query_vec = Bit::new(&[true, false, true]);
let query_vec = Bit::new(&[
true, false, true, false, false, false, false, false, false, true,
]);
let row =
sqlx::query("SELECT embedding FROM sqlx_bit_items ORDER BY embedding <~> $1 LIMIT 1")
.bind(query_vec)
.fetch_one(&pool)
.await?;
let res_vec: Bit = row.try_get("embedding").unwrap();
assert_eq!(vec, res_vec);
assert_eq!(&[0b10100000], res_vec.as_bytes());
assert_eq!(&[0b10100000, 0b01000000], res_vec.as_bytes());

let null_row =
sqlx::query("SELECT embedding FROM sqlx_bit_items WHERE embedding IS NULL LIMIT 1")
Expand All @@ -92,9 +98,9 @@ mod tests {
.fetch_one(&pool)
.await?;
let text_res: String = text_row.try_get("embedding").unwrap();
assert_eq!("101", text_res);
assert_eq!("1010000001", text_res);

sqlx::query("ALTER TABLE sqlx_bit_items ADD COLUMN factors bit(3)[]")
sqlx::query("ALTER TABLE sqlx_bit_items ADD COLUMN factors bit(10)[]")
.execute(&pool)
.await?;

Expand Down

0 comments on commit e82d96d

Please sign in to comment.