From 569b83348f009a29fbb0dcf7a5a7253adf9af067 Mon Sep 17 00:00:00 2001 From: WANG YEFU Date: Thu, 19 Sep 2024 23:32:34 +0800 Subject: [PATCH] Add bfloat16 quantization support (#30) * Add quantization to float16/bfloat16 * make distance ops generic on float type * Add f16/bf16 to f32 conversion * Add normalization for bfloat16 * Add inner product distance for bfloat16 * Implement L2DistanceSquared for bf16 * Implement L2 for f32 * bf16 * Add benchmark for bf16 ops * Add bfloat16 vector --- .../vectorlite_py/test/vectorlite_test.py | 113 ++-- format.sh | 2 +- vcpkg | 2 +- vectorlite/CMakeLists.txt | 2 +- vectorlite/constraint.cpp | 44 +- vectorlite/constraint.h | 2 +- vectorlite/distance.h | 34 +- vectorlite/macros.h | 12 + vectorlite/ops/ops.cpp | 516 ++++++++++++++++-- vectorlite/ops/ops.h | 49 +- vectorlite/ops/ops_benchmark.cpp | 57 +- vectorlite/ops/ops_test.cpp | 186 +++++++ vectorlite/quantization.cpp | 19 + vectorlite/quantization.h | 10 + vectorlite/util.h | 81 +++ vectorlite/vector.cpp | 115 ---- vectorlite/vector.h | 134 ++++- vectorlite/vector_space.cpp | 43 +- vectorlite/vector_space.h | 1 + vectorlite/vector_space_test.cpp | 184 ++++--- vectorlite/vector_test.cpp | 16 +- vectorlite/vector_view.cpp | 44 -- vectorlite/vector_view.h | 71 ++- vectorlite/vector_view_test.cpp | 3 + vectorlite/virtual_table.cpp | 72 +-- vectorlite/virtual_table.h | 4 +- 26 files changed, 1365 insertions(+), 451 deletions(-) create mode 100644 vectorlite/quantization.cpp create mode 100644 vectorlite/quantization.h delete mode 100644 vectorlite/vector.cpp delete mode 100644 vectorlite/vector_view.cpp diff --git a/bindings/python/vectorlite_py/test/vectorlite_test.py b/bindings/python/vectorlite_py/test/vectorlite_test.py index 4007f1f..3df96c0 100644 --- a/bindings/python/vectorlite_py/test/vectorlite_test.py +++ b/bindings/python/vectorlite_py/test/vectorlite_test.py @@ -121,59 +121,60 @@ def remove_quote(s: str): file_path = os.path.join(tempdir, 'index.bin') file_paths = 
[f'\"{file_path}\"', f'\'{file_path}\''] - for index_file_path in file_paths: - assert not os.path.exists(remove_quote(index_file_path)) - - conn = get_connection() - cur = conn.cursor() - cur.execute(f'create virtual table my_table using vectorlite(my_embedding float32[{DIM}], hnsw(max_elements={NUM_ELEMENTS}), {index_file_path})') - - for i in range(NUM_ELEMENTS): - cur.execute('insert into my_table (rowid, my_embedding) values (?, ?)', (i, random_vectors[i].tobytes())) - - result = cur.execute('select rowid, distance from my_table where knn_search(my_embedding, knn_param(?, ?))', (random_vectors[0].tobytes(), 10)).fetchall() - assert len(result) == 10 - - conn.close() - # The index file should be created - index_file_size = os.path.getsize(remove_quote(index_file_path)) - assert os.path.exists(remove_quote(index_file_path)) and index_file_size > 0 - - # test if the index file could be loaded with the same parameters without inserting data again - conn = get_connection() - cur = conn.cursor() - cur.execute(f'create virtual table my_table using vectorlite(my_embedding float32[{DIM}], hnsw(max_elements={NUM_ELEMENTS}), {index_file_path})') - result = cur.execute('select rowid, distance from my_table where knn_search(my_embedding, knn_param(?, ?))', (random_vectors[0].tobytes(), 10)).fetchall() - assert len(result) == 10 - conn.close() - # The index file should be created - assert os.path.exists(remove_quote(index_file_path)) and os.path.getsize(remove_quote(index_file_path)) == index_file_size - - # test if the index file could be loaded with different hnsw parameters and distance type without inserting data again - # But hnsw parameters can't be changed even if different values are set, they will be owverwritten by the value from the index file - # todo: test whether hnsw parameters are overwritten after more functions are introduced to provide runtime stats. 
- conn = get_connection() - cur = conn.cursor() - cur.execute(f'create virtual table my_table2 using vectorlite(my_embedding float32[{DIM}] cosine, hnsw(max_elements={NUM_ELEMENTS},ef_construction=32,M=32), {index_file_path})') - result = cur.execute('select rowid, distance from my_table2 where knn_search(my_embedding, knn_param(?, ?))', (random_vectors[0].tobytes(), 10)).fetchall() - assert len(result) == 10 - - # test searching with ef_search = 30, which defaults to 10 - result = cur.execute('select rowid, distance from my_table2 where knn_search(my_embedding, knn_param(?, ?, ?))', (random_vectors[0].tobytes(), 10, 30)).fetchall() - assert len(result) == 10 - conn.close() - assert os.path.exists(remove_quote(index_file_path)) and os.path.getsize(remove_quote(index_file_path)) == index_file_size - - - # test if `drop table` deletes the index file - conn = get_connection() - cur = conn.cursor() - cur.execute(f'create virtual table my_table2 using vectorlite(my_embedding float32[{DIM}] cosine, hnsw(max_elements={NUM_ELEMENTS},ef_construction=64,M=32), {index_file_path})') - result = cur.execute('select rowid, distance from my_table2 where knn_search(my_embedding, knn_param(?, ?))', (random_vectors[0].tobytes(), 10)).fetchall() - assert len(result) == 10 - - cur.execute(f'drop table my_table2') - assert not os.path.exists(remove_quote(index_file_path)) - conn.close() - - + for vector_type in ['float32', 'bfloat16']: + for index_file_path in file_paths: + assert not os.path.exists(remove_quote(index_file_path)) + + conn = get_connection() + cur = conn.cursor() + cur.execute(f'create virtual table my_table using vectorlite(my_embedding {vector_type}[{DIM}], hnsw(max_elements={NUM_ELEMENTS}), {index_file_path})') + + for i in range(NUM_ELEMENTS): + cur.execute('insert into my_table (rowid, my_embedding) values (?, ?)', (i, random_vectors[i].tobytes())) + + result = cur.execute('select rowid, distance from my_table where knn_search(my_embedding, knn_param(?, ?))', 
(random_vectors[0].tobytes(), 10)).fetchall() + assert len(result) == 10 + + conn.close() + # The index file should be created + index_file_size = os.path.getsize(remove_quote(index_file_path)) + assert os.path.exists(remove_quote(index_file_path)) and index_file_size > 0 + + # test if the index file could be loaded with the same parameters without inserting data again + conn = get_connection() + cur = conn.cursor() + cur.execute(f'create virtual table my_table using vectorlite(my_embedding {vector_type}[{DIM}], hnsw(max_elements={NUM_ELEMENTS}), {index_file_path})') + result = cur.execute('select rowid, distance from my_table where knn_search(my_embedding, knn_param(?, ?))', (random_vectors[0].tobytes(), 10)).fetchall() + assert len(result) == 10 + conn.close() + # The index file should be created + assert os.path.exists(remove_quote(index_file_path)) and os.path.getsize(remove_quote(index_file_path)) == index_file_size + + # test if the index file could be loaded with different hnsw parameters and distance type without inserting data again + # But hnsw parameters can't be changed even if different values are set, they will be overwritten by the value from the index file + # todo: test whether hnsw parameters are overwritten after more functions are introduced to provide runtime stats. 
+ conn = get_connection() + cur = conn.cursor() + cur.execute(f'create virtual table my_table2 using vectorlite(my_embedding {vector_type}[{DIM}] cosine, hnsw(max_elements={NUM_ELEMENTS},ef_construction=32,M=32), {index_file_path})') + result = cur.execute('select rowid, distance from my_table2 where knn_search(my_embedding, knn_param(?, ?))', (random_vectors[0].tobytes(), 10)).fetchall() + assert len(result) == 10 + + # test searching with ef_search = 30, which defaults to 10 + result = cur.execute('select rowid, distance from my_table2 where knn_search(my_embedding, knn_param(?, ?, ?))', (random_vectors[0].tobytes(), 10, 30)).fetchall() + assert len(result) == 10 + conn.close() + assert os.path.exists(remove_quote(index_file_path)) and os.path.getsize(remove_quote(index_file_path)) == index_file_size + + + # test if `drop table` deletes the index file + conn = get_connection() + cur = conn.cursor() + cur.execute(f'create virtual table my_table2 using vectorlite(my_embedding {vector_type}[{DIM}] cosine, hnsw(max_elements={NUM_ELEMENTS},ef_construction=64,M=32), {index_file_path})') + result = cur.execute('select rowid, distance from my_table2 where knn_search(my_embedding, knn_param(?, ?))', (random_vectors[0].tobytes(), 10)).fetchall() + assert len(result) == 10 + + cur.execute(f'drop table my_table2') + assert not os.path.exists(remove_quote(index_file_path)) + conn.close() + + diff --git a/format.sh b/format.sh index 69966de..06830aa 100644 --- a/format.sh +++ b/format.sh @@ -1 +1 @@ -clang-format -style=file -i src/*.h src/*.cpp \ No newline at end of file +clang-format -style=file -i vectorlite/*.h vectorlite/*.cpp \ No newline at end of file diff --git a/vcpkg b/vcpkg index e590c2b..85392b1 160000 --- a/vcpkg +++ b/vcpkg @@ -1 +1 @@ -Subproject commit e590c2b30c08caf1dd8d612ec602a003f9784b7d +Subproject commit 85392b146fdbdb1ef68cb3bfd48e9ee9a6311064 diff --git a/vectorlite/CMakeLists.txt b/vectorlite/CMakeLists.txt index c9e47aa..9ebb666 100644 --- 
a/vectorlite/CMakeLists.txt +++ b/vectorlite/CMakeLists.txt @@ -10,7 +10,7 @@ message(STATUS "Compiling on ${CMAKE_SYSTEM_PROCESSOR}") add_subdirectory(ops) -add_library(vectorlite SHARED vectorlite.cpp virtual_table.cpp vector.cpp vector_view.cpp util.cpp vector_space.cpp index_options.cpp sqlite_functions.cpp constraint.cpp) +add_library(vectorlite SHARED vectorlite.cpp virtual_table.cpp util.cpp vector_space.cpp index_options.cpp sqlite_functions.cpp constraint.cpp quantization.cpp) # remove the lib prefix to make the shared library name consistent on all platforms. set_target_properties(vectorlite PROPERTIES PREFIX "") target_include_directories(vectorlite PUBLIC ${RAPIDJSON_INCLUDE_DIRS} ${HNSWLIB_INCLUDE_DIRS} ${PROJECT_BINARY_DIR}) diff --git a/vectorlite/constraint.cpp b/vectorlite/constraint.cpp index 0357938..6e7acde 100644 --- a/vectorlite/constraint.cpp +++ b/vectorlite/constraint.cpp @@ -13,6 +13,7 @@ #include "absl/strings/str_join.h" #include "hnswlib/hnswlib.h" #include "macros.h" +#include "quantization.h" #include "sqlite3ext.h" #include "util.h" #include "vector.h" @@ -195,20 +196,39 @@ absl::StatusOr QueryExecutor::Execute() const { index_.setEf(*knn_param->ef_search); } try { - if (!space_.normalize) { - return index_.searchKnnCloserFirst( - knn_param->query_vector.data().data(), knn_param->k, - rowid_filter.get()); + if (space_.vector_type == VectorType::Float32) { + if (!space_.normalize) { + return index_.searchKnnCloserFirst( + knn_param->query_vector.data().data(), knn_param->k, + rowid_filter.get()); + } + + VECTORLITE_ASSERT(space_.normalize); + // Copy the query vector and normalize it. 
+ Vector normalized_vector = Vector::Normalize(knn_param->query_vector); + + auto result = index_.searchKnnCloserFirst( + normalized_vector.data().data(), knn_param->k, rowid_filter.get()); + return result; + } else if (space_.vector_type == VectorType::BFloat16) { + BF16Vector quantized_vector = Quantize(knn_param->query_vector); + + if (!space_.normalize) { + return index_.searchKnnCloserFirst(quantized_vector.data().data(), + knn_param->k, rowid_filter.get()); + } + + VECTORLITE_ASSERT(space_.normalize); + BF16Vector normalized_vector = quantized_vector.Normalize(); + + auto result = index_.searchKnnCloserFirst( + normalized_vector.data().data(), knn_param->k, rowid_filter.get()); + return result; + } else { + return absl::InternalError( + absl::StrFormat("Unknown vector type: %d", space_.vector_type)); } - VECTORLITE_ASSERT(space_.normalize); - // Copy the query vector and normalize it. - Vector normalized_vector = Vector::Normalize(knn_param->query_vector); - - auto result = index_.searchKnnCloserFirst( - normalized_vector.data().data(), knn_param->k, rowid_filter.get()); - return result; - } catch (const std::runtime_error& e) { return absl::InternalError(e.what()); } diff --git a/vectorlite/constraint.h b/vectorlite/constraint.h index 90af01c..014217f 100644 --- a/vectorlite/constraint.h +++ b/vectorlite/constraint.h @@ -12,8 +12,8 @@ #include "hnswlib/hnswlib.h" #include "macros.h" #include "sqlite3.h" -#include "vector_view.h" #include "vector_space.h" +#include "vector_view.h" namespace vectorlite { diff --git a/vectorlite/distance.h b/vectorlite/distance.h index d431e7b..4e68d79 100644 --- a/vectorlite/distance.h +++ b/vectorlite/distance.h @@ -1,6 +1,8 @@ #pragma once #include "hnswlib/hnswlib.h" +#include "hwy/base.h" +#include "macros.h" #include "ops/ops.h" // This file implements hnswlib::SpaceInterface using vectorlite @@ -9,12 +11,13 @@ // PC(i5-12600KF with AVX2 support) namespace vectorlite { -class InnerProductSpace : public 
hnswlib::SpaceInterface { +template +class GenericInnerProductSpace : public hnswlib::SpaceInterface { public: - explicit InnerProductSpace(size_t dim) - : dim_(dim), func_(InnerProductSpace::InnerProductDistanceFunc) {} + explicit GenericInnerProductSpace(size_t dim) + : dim_(dim), func_(GenericInnerProductSpace::InnerProductDistanceFunc) {} - size_t get_data_size() override { return dim_ * sizeof(float); } + size_t get_data_size() override { return dim_ * sizeof(T); } void* get_dist_func_param() override { return &dim_; } @@ -26,18 +29,22 @@ class InnerProductSpace : public hnswlib::SpaceInterface { static float InnerProductDistanceFunc(const void* v1, const void* v2, const void* dim) { - return ops::InnerProductDistance(static_cast(v1), - static_cast(v2), + return ops::InnerProductDistance(static_cast(v1), + static_cast(v2), *reinterpret_cast(dim)); } }; -class L2Space : public hnswlib::SpaceInterface { +using InnerProductSpace = GenericInnerProductSpace; +using InnerProductSpaceBF16 = GenericInnerProductSpace; + +template +class GenericL2Space : public hnswlib::SpaceInterface { public: - explicit L2Space(size_t dim) - : dim_(dim), func_(L2Space::L2DistanceSquaredFunc) {} + explicit GenericL2Space(size_t dim) + : dim_(dim), func_(GenericL2Space::L2DistanceSquaredFunc) {} - size_t get_data_size() override { return dim_ * sizeof(float); } + size_t get_data_size() override { return dim_ * sizeof(T); } void* get_dist_func_param() override { return &dim_; } @@ -49,10 +56,13 @@ class L2Space : public hnswlib::SpaceInterface { static float L2DistanceSquaredFunc(const void* v1, const void* v2, const void* dim) { - return ops::L2DistanceSquared(static_cast(v1), - static_cast(v2), + return ops::L2DistanceSquared(static_cast(v1), + static_cast(v2), *reinterpret_cast(dim)); } }; +using L2Space = GenericL2Space; +using L2SpaceBF16 = GenericL2Space; + } // namespace vectorlite \ No newline at end of file diff --git a/vectorlite/macros.h b/vectorlite/macros.h index 
5559e80..81b1b67 100644 --- a/vectorlite/macros.h +++ b/vectorlite/macros.h @@ -1,5 +1,9 @@ #pragma once +#include + +#include "hwy/base.h" + #if defined(_WIN32) || defined(__WIN32__) #define VECTORLITE_EXPORT __declspec(dllexport) #else @@ -11,3 +15,11 @@ #include #define VECTORLITE_ASSERT(x) assert(x) #endif + +#define VECTORLITE_IF_FLOAT_SUPPORTED(T) \ + std::enable_if_t || \ + std::is_same_v>* = nullptr + +#define VECTORLITE_IF_FLOAT_SUPPORTED_FWD_DECL(T) \ + std::enable_if_t || \ + std::is_same_v>* diff --git a/vectorlite/ops/ops.cpp b/vectorlite/ops/ops.cpp index 436ad5b..db5381f 100644 --- a/vectorlite/ops/ops.cpp +++ b/vectorlite/ops/ops.cpp @@ -1,8 +1,11 @@ #include "ops.h" -#include - +#include #include +#include +#include + +#include "hwy/base.h" // >>>> for dynamic dispatch only, skip if you want static dispatch // For dynamic dispatch, specify the name of the current file (unfortunately @@ -17,7 +20,6 @@ // Must come after foreach_target.h to avoid redefinition errors. #include "hwy/contrib/algo/transform-inl.h" #include "hwy/contrib/dot/dot-inl.h" -#include "hwy/contrib/math/math-inl.h" #include "hwy/highway.h" #include "hwy/targets.h" @@ -32,8 +34,9 @@ namespace HWY_NAMESPACE { // Highway ops reside here; ADL does not find templates nor builtins. 
namespace hn = hwy::HWY_NAMESPACE; -static float SquaredSumVectorized(const float* v, size_t num_elements) { - const hn::ScalableTag d; +template > +static float SquaredSumVectorized(const D d, const T* v, size_t num_elements) { + static_assert(hwy::IsFloat(), "MulAdd requires float type"); using V = hn::Vec; const size_t N = hn::Lanes(d); HWY_DASSERT(num_elements >= N && num_elements % N == 0); @@ -74,9 +77,49 @@ static float SquaredSumVectorized(const float* v, size_t num_elements) { return hn::ReduceSum(d, sum0); } -static float InnerProductImplVectorized(const float* v1, const float* v2, +template +static float SquaredSumVectorized(const D d, const hwy::bfloat16_t* v, + size_t num_elements) { + const hn::Repartition df32; + + using V = decltype(Zero(df32)); + const size_t N = Lanes(d); + + size_t i = 0; + // See comment in the hwy::Dot::Compute() overload. Unroll 2x, but we need + // twice as many sums for ReorderWidenMulAccumulate. + V sum0 = Zero(df32); + V sum1 = Zero(df32); + V sum2 = Zero(df32); + V sum3 = Zero(df32); + + // Main loop: unrolled + for (; i + 2 * N <= num_elements; /* i += 2 * N */) { // incr in loop + const auto a0 = LoadU(d, v + i); + i += N; + sum0 = ReorderWidenMulAccumulate(df32, a0, a0, sum0, sum1); + const auto a1 = LoadU(d, v + i); + i += N; + sum2 = ReorderWidenMulAccumulate(df32, a1, a1, sum2, sum3); + } + + // Possibly one more iteration of whole vectors + if (i + N <= num_elements) { + const auto a0 = LoadU(d, v + i); + i += N; + sum0 = ReorderWidenMulAccumulate(df32, a0, a0, sum0, sum1); + } + + // Reduction tree: sum of all accumulators by pairs, then across lanes. 
+ sum0 = Add(sum0, sum1); + sum2 = Add(sum2, sum3); + sum0 = Add(sum0, sum2); + return ReduceSum(df32, sum0); +} + +template > +static float InnerProductImplVectorized(const D d, const T* v1, const T* v2, size_t num_elements) { - const hn::ScalableTag d; const size_t N = hn::Lanes(d); HWY_DASSERT(num_elements >= N && num_elements % N == 0); @@ -85,13 +128,13 @@ static float InnerProductImplVectorized(const float* v1, const float* v2, if (v1 != v2) { return hn::Dot::Compute(d, v1, v2, num_elements); } else { - return SquaredSumVectorized(v1, num_elements); + return SquaredSumVectorized(d, v1, num_elements); } } -static float InnerProductImpl(const float* v1, const float* v2, +template > +static float InnerProductImpl(const D d, const T* v1, const T* v2, size_t num_elements) { - const hn::ScalableTag d; const size_t N = hn::Lanes(d); const size_t leftover = num_elements % N; @@ -99,7 +142,7 @@ static float InnerProductImpl(const float* v1, const float* v2, float result = 0; if (num_elements >= N) { - result = InnerProductImplVectorized(v1, v2, num_elements - leftover); + result = InnerProductImplVectorized(d, v1, v2, num_elements - leftover); } if (leftover > 0) { @@ -108,12 +151,15 @@ static float InnerProductImpl(const float* v1, const float* v2, float sum1 = 0; size_t i = num_elements - leftover; for (; i + 2 <= num_elements; i += 2) { - sum0 += v1[i] * v2[i]; - sum1 += v1[i + 1] * v2[i + 1]; + sum0 += hwy::ConvertScalarTo(v1[i]) * + hwy::ConvertScalarTo(v2[i]); + sum1 += hwy::ConvertScalarTo(v1[i + 1]) * + hwy::ConvertScalarTo(v2[i + 1]); } if (i < num_elements) { - sum0 += v1[i] * v2[i]; + sum0 += hwy::ConvertScalarTo(v1[i]) * + hwy::ConvertScalarTo(v2[i]); } return result + sum0 + sum1; } else { @@ -121,10 +167,141 @@ static float InnerProductImpl(const float* v1, const float* v2, } } -static float L2DistanceSquaredImplVectorized(const float* HWY_RESTRICT v1, - const float* HWY_RESTRICT v2, +template +static float L2DistanceSquaredImplVectorized( + const D d, 
const hwy::bfloat16_t* HWY_RESTRICT v1, + const hwy::bfloat16_t* HWY_RESTRICT v2, size_t num_elements) { + const hn::Repartition df32; + + using V = decltype(Zero(df32)); + const size_t N = Lanes(d); + HWY_DASSERT(num_elements >= N && num_elements % N == 0); + + size_t i = 0; + + V sum0 = Zero(df32); + V sum1 = Zero(df32); + V sum2 = Zero(df32); + V sum3 = Zero(df32); + + // Main loop: unrolled + for (; i + 2 * N <= num_elements; /* i += 2 * N */) { // incr in loop + const auto a0 = LoadU(d, v1 + i); + const auto a0_lower = hn::PromoteLowerTo(df32, a0); + const auto a0_upper = hn::PromoteUpperTo(df32, a0); + const auto a1 = LoadU(d, v2 + i); + const auto a1_lower = hn::PromoteLowerTo(df32, a1); + const auto a1_upper = hn::PromoteUpperTo(df32, a1); + const auto diff_a_lower = hn::Sub(a0_lower, a1_lower); + const auto diff_a_upper = hn::Sub(a0_upper, a1_upper); + i += N; + sum0 = MulAdd(diff_a_lower, diff_a_lower, sum0); + sum1 = MulAdd(diff_a_upper, diff_a_upper, sum1); + + const auto b0 = LoadU(d, v1 + i); + const auto b0_lower = hn::PromoteLowerTo(df32, b0); + const auto b0_upper = hn::PromoteUpperTo(df32, b0); + const auto b1 = LoadU(d, v2 + i); + const auto b1_lower = hn::PromoteLowerTo(df32, b1); + const auto b1_upper = hn::PromoteUpperTo(df32, b1); + const auto diff_b_lower = hn::Sub(b0_lower, b1_lower); + const auto diff_b_upper = hn::Sub(b0_upper, b1_upper); + i += N; + sum2 = MulAdd(diff_b_lower, diff_b_lower, sum2); + sum3 = MulAdd(diff_b_upper, diff_b_upper, sum3); + } + + // Up to 1 iterations of whole vectors + for (; i + N <= num_elements; i += N) { + const auto a0 = LoadU(d, v1 + i); + const auto a0_lower = hn::PromoteLowerTo(df32, a0); + const auto a0_upper = hn::PromoteUpperTo(df32, a0); + const auto a1 = LoadU(d, v2 + i); + const auto a1_lower = hn::PromoteLowerTo(df32, a1); + const auto a1_upper = hn::PromoteUpperTo(df32, a1); + const auto diff_a_lower = hn::Sub(a0_lower, a1_lower); + const auto diff_a_upper = hn::Sub(a0_upper, a1_upper); + i += 
N; + sum0 = MulAdd(diff_a_lower, diff_a_lower, sum0); + sum1 = MulAdd(diff_a_upper, diff_a_upper, sum1); + } + // Reduction tree: sum of all accumulators by pairs, then across lanes. + sum0 = Add(sum0, sum1); + sum2 = Add(sum2, sum3); + sum0 = Add(sum0, sum2); + + return hwy::ConvertScalarTo(hn::ReduceSum(df32, sum0)); +} + +template +static float L2DistanceSquaredImplVectorized( + const D df, const float* HWY_RESTRICT v1, + const hwy::bfloat16_t* HWY_RESTRICT v2, size_t num_elements) { + const hn::Repartition dbf; + using VBF = decltype(Zero(dbf)); + const hn::Half dbfh; + using VF = decltype(Zero(df)); + + const size_t NF = Lanes(df); + HWY_DASSERT(num_elements >= NF && num_elements % NF == 0); + + size_t i = 0; + + VF sum0 = Zero(df); + VF sum1 = Zero(df); + VF sum2 = Zero(df); + VF sum3 = Zero(df); + + // Main loop: unrolled + for (; i + 4 * NF <= num_elements; /* i += 4 * NF */) { + const VF a0 = LoadU(df, v1 + i); + const VBF b0 = LoadU(dbf, v2 + i); + i += NF; + const VF b0_lower = hn::PromoteLowerTo(df, b0); + const VF diff0 = hn::Sub(a0, b0_lower); + sum0 = MulAdd(diff0, diff0, sum0); + + const VF a1 = LoadU(df, v1 + i); + i += NF; + const VF b0_upper = hn::PromoteUpperTo(df, b0); + const VF diff1 = hn::Sub(a1, b0_upper); + sum1 = MulAdd(diff1, diff1, sum1); + + const VF a2 = LoadU(df, v1 + i); + const VBF b2 = LoadU(dbf, v2 + i); + i += NF; + const VF b2_lower = hn::PromoteLowerTo(df, b2); + const VF diff2 = hn::Sub(a2, b2_lower); + sum2 = MulAdd(diff2, diff2, sum2); + + const VF a3 = LoadU(df, v1 + i); + i += NF; + const VF b2_upper = hn::PromoteUpperTo(df, b2); + const VF diff3 = hn::Sub(a3, b2_upper); + sum3 = MulAdd(diff3, diff3, sum3); + } + + // Up to 3 iterations of whole vectors + for (; i + NF <= num_elements; i += NF) { + const VF a = LoadU(df, v1 + i); + const VF b = PromoteTo(df, LoadU(dbfh, v2 + i)); + const VF diff = Sub(a, b); + sum0 = MulAdd(diff, diff, sum0); + } + // Reduction tree: sum of all accumulators by pairs, then across lanes. 
+ sum0 = Add(sum0, sum1); + sum2 = Add(sum2, sum3); + sum0 = Add(sum0, sum2); + + return hwy::ConvertScalarTo(hn::ReduceSum(df, sum0)); +} + +template > +static float L2DistanceSquaredImplVectorized(const D d, + const T* HWY_RESTRICT v1, + const T* HWY_RESTRICT v2, size_t num_elements) { - const hn::ScalableTag d; + static_assert(hwy::IsFloat(), "MulAdd requires float type"); const size_t N = hn::Lanes(d); HWY_DASSERT(num_elements >= N && num_elements % N == 0); using V = hn::Vec; @@ -156,24 +333,25 @@ static float L2DistanceSquaredImplVectorized(const float* HWY_RESTRICT v1, const auto diff = hn::Sub(LoadU(d, v1 + i), LoadU(d, v2 + i)); sum0 = MulAdd(diff, diff, sum0); } -// Reduction tree: sum of all accumulators by pairs, then across lanes. + // Reduction tree: sum of all accumulators by pairs, then across lanes. sum0 = Add(sum0, sum1); sum2 = Add(sum2, sum3); sum0 = Add(sum0, sum2); - return hn::ReduceSum(d, sum0); + return hwy::ConvertScalarTo(hn::ReduceSum(d, sum0)); } -static float L2DistanceSquaredImpl(const float* HWY_RESTRICT v1, - const float* HWY_RESTRICT v2, +template , typename T2 = T1> +static float L2DistanceSquaredImpl(const D d, const T1* HWY_RESTRICT v1, + const T2* HWY_RESTRICT v2, size_t num_elements) { - const hn::ScalableTag d; const size_t N = hn::Lanes(d); const size_t leftover = num_elements % N; float result = 0; if (num_elements >= N) { - result = L2DistanceSquaredImplVectorized(v1, v2, num_elements - leftover); + result = + L2DistanceSquaredImplVectorized(d, v1, v2, num_elements - leftover); } if (leftover > 0) { @@ -182,14 +360,17 @@ static float L2DistanceSquaredImpl(const float* HWY_RESTRICT v1, float sum1 = 0; size_t i = num_elements - leftover; for (; i + 2 <= num_elements; i += 2) { - float diff0 = v1[i] - v2[i]; + float diff0 = hwy::ConvertScalarTo(v1[i]) - + hwy::ConvertScalarTo(v2[i]); sum0 += diff0 * diff0; - float diff1 = v1[i + 1] - v2[i + 1]; + float diff1 = hwy::ConvertScalarTo(v1[i + 1]) - + hwy::ConvertScalarTo(v2[i + 
1]); sum1 += diff1 * diff1; } if (i < num_elements) { - float diff = v1[i] - v2[i]; + float diff = hwy::ConvertScalarTo(v1[i]) - + hwy::ConvertScalarTo(v2[i]); sum0 += diff * diff; } @@ -201,16 +382,175 @@ static float L2DistanceSquaredImpl(const float* HWY_RESTRICT v1, // A vectorized implementation following // https://github.com/nmslib/hnswlib/blob/v0.8.0/python_bindings/bindings.cpp#L241 -static void NormalizeImpl(float* HWY_RESTRICT inout, size_t num_elements) { - using D = hn::ScalableTag; - const D d; - const float squared_sum = InnerProductImpl(inout, inout, num_elements); - const float norm = 1.0f / (sqrtf(squared_sum) + 1e-30f); +template > +static void NormalizeImpl(const D d, T* HWY_RESTRICT inout, + size_t num_elements) { + const float squared_sum = InnerProductImpl(d, inout, inout, num_elements); + const float norm = + hwy::ConvertScalarTo(1.0f / (sqrtf(squared_sum) + 1e-30f)); hn::Transform(d, inout, num_elements, [norm](D d, hn::Vec v) HWY_ATTR { return hn::Mul(v, hn::Set(d, norm)); }); } +template +static void NormalizeImpl(const D d, hwy::bfloat16_t* HWY_RESTRICT inout, + size_t num_elements) { + const float squared_sum = InnerProductImpl(d, inout, inout, num_elements); + const float norm = + hwy::ConvertScalarTo(1.0f / (sqrtf(squared_sum) + 1e-30f)); + hn::Transform(d, inout, num_elements, [norm](D d, hn::Vec v) HWY_ATTR { + const hn::RepartitionToWide df32; + const auto norm_vector = hn::Set(df32, norm); + const auto lower = hn::Mul(hn::PromoteLowerTo(df32, v), norm_vector); + const auto upper = hn::Mul(hn::PromoteUpperTo(df32, v), norm_vector); + return hn::OrderedDemote2To(d, lower, upper); + }); +} + +template +static void QuantizeF32ToHalf(const float* HWY_RESTRICT in, + HalfFloat* HWY_RESTRICT out, size_t size) { + static_assert(sizeof(float) / sizeof(HalfFloat) == 2, + "HalfFloat must be 16-bit"); + const hn::ScalableTag df32; + // f16 here refers to the 16-bit floating point type, including float16_t and + // bfloat16_t + const 
hn::Repartition df16; + const size_t NF = hn::Lanes(df32); + using VF = hn::Vec; + using VBF = hn::Vec; + const hn::Half df16h; + constexpr bool is_bfloat16 = std::is_same::value; + + size_t i = 0; + if (size >= 2 * NF) { + for (; i <= size - 2 * NF; i += 2 * NF) { + const VF v0 = hn::LoadU(df32, in + i); + const VF v1 = hn::LoadU(df32, in + i + NF); + if constexpr (is_bfloat16) { + const VBF bf = hn::OrderedDemote2To(df16, v0, v1); + hn::StoreU(bf, df16, out + i); + } else { + static_assert(std::is_same::value, + "Unsupported HalfFloat type"); + // todo: use OrderedDemote2To once it's implemented for float16_t + const VBF bf = + hn::Combine(df16, hn::DemoteTo(df16h, v1), hn::DemoteTo(df16h, v0)); + hn::StoreU(bf, df16, out + i); + } + } + } + if (size - i >= NF) { + const VF v0 = hn::LoadU(df32, in + i); + const hn::Vec bfh = hn::DemoteTo(df16h, v0); + hn::StoreU(bfh, df16h, out + i); + i += NF; + } + + if (i != size) { + const size_t remaining = size - i; + const VF v0 = hn::LoadN(df32, in + i, remaining); + const hn::Vec bfh = hn::DemoteTo(df16h, v0); + hn::StoreN(bfh, df16h, out + i, remaining); + } +} + +template +static void HalfFloatToF32(const HalfFloat* HWY_RESTRICT in, + float* HWY_RESTRICT out, size_t size) { + static_assert(sizeof(float) / sizeof(HalfFloat) == 2, + "HalfFloat must be 16-bit"); + const hn::ScalableTag df32; + // f16 here refers to the 16-bit floating point type, including float16_t and + // bfloat16_t + const hn::Repartition df16; + const size_t NF = hn::Lanes(df32); + using VF = hn::Vec; + using VBF = hn::Vec; + const hn::Half df16h; + + size_t i = 0; + if (size >= NF) { + for (; i <= size - NF; i += NF) { + const auto v = hn::LoadU(df16h, in + i); + hn::StoreU(hn::PromoteTo(df32, v), df32, out + i); + } + } + + if (i != size) { + const size_t remaining = size - i; + const auto v = hn::LoadN(df16h, in + i, remaining); + hn::StoreN(hn::PromoteTo(df32, v), df32, out + i, remaining); + } +} + +static void QuantizeF32ToBF16Impl(const float* 
HWY_RESTRICT in, + hwy::bfloat16_t* HWY_RESTRICT out, + size_t size) { + QuantizeF32ToHalf(in, out, size); +} + +static void QuantizeF32ToF16Impl(const float* HWY_RESTRICT in, + hwy::float16_t* HWY_RESTRICT out, + size_t size) { + QuantizeF32ToHalf(in, out, size); +} + +static float InnerProductImplF32(const float* v1, const float* v2, + size_t num_elements) { + return InnerProductImpl(hn::ScalableTag(), v1, v2, num_elements); +} + +static float InnerProductImplBF16(const hwy::bfloat16_t* v1, + const hwy::bfloat16_t* v2, + size_t num_elements) { + return InnerProductImpl(hn::ScalableTag(), v1, v2, + num_elements); +} + +static float L2DistanceSquaredImplF32(const float* v1, const float* v2, + size_t num_elements) { + return L2DistanceSquaredImpl(hn::ScalableTag(), v1, v2, num_elements); +} + +static float L2DistanceSquaredImplBF16(const hwy::bfloat16_t* v1, + const hwy::bfloat16_t* v2, + size_t num_elements) { + return L2DistanceSquaredImpl(hn::ScalableTag(), v1, v2, + num_elements); +} + +static float L2DistanceSquaredImplF32BF16(const float* v1, + const hwy::bfloat16_t* v2, + size_t num_elements) { + return L2DistanceSquaredImpl(hn::ScalableTag(), v1, v2, num_elements); +} + +static void NormalizeImplF32(float* HWY_RESTRICT inout, size_t num_elements) { + return NormalizeImpl(hn::ScalableTag(), inout, num_elements); +} + +// static void NormalizeImplF16(hwy::float16_t* HWY_RESTRICT inout, size_t +// num_elements) { +// return NormalizeImpl(hn::Half>(), inout, +// num_elements); +// } + +static void NormalizeImplBF16(hwy::bfloat16_t* HWY_RESTRICT inout, + size_t num_elements) { + return NormalizeImpl(hn::ScalableTag(), inout, num_elements); +} + +static void F16ToF32Impl(const hwy::float16_t* HWY_RESTRICT in, + float* HWY_RESTRICT out, size_t num_elements) { + HalfFloatToF32(in, out, num_elements); +} +static void BF16ToF32Impl(const hwy::bfloat16_t* HWY_RESTRICT in, + float* HWY_RESTRICT out, size_t num_elements) { + HalfFloatToF32(in, out, num_elements); +} + } 
// namespace HWY_NAMESPACE HWY_AFTER_NAMESPACE(); @@ -225,13 +565,29 @@ namespace ops { // This macro declares a static array used for dynamic dispatch; it resides in // the same outer namespace that contains FloorLog2. -HWY_EXPORT(InnerProductImpl); -HWY_EXPORT(NormalizeImpl); -HWY_EXPORT(L2DistanceSquaredImpl); +HWY_EXPORT(InnerProductImplF32); +HWY_EXPORT(InnerProductImplBF16); +HWY_EXPORT(L2DistanceSquaredImplF32); +HWY_EXPORT(L2DistanceSquaredImplBF16); +HWY_EXPORT(L2DistanceSquaredImplF32BF16); +HWY_EXPORT(QuantizeF32ToF16Impl); +HWY_EXPORT(QuantizeF32ToBF16Impl); +HWY_EXPORT(F16ToF32Impl); +HWY_EXPORT(BF16ToF32Impl); + +HWY_EXPORT(NormalizeImplF32); +// HWY_EXPORT(NormalizeImplF16); +HWY_EXPORT(NormalizeImplBF16); HWY_DLLEXPORT float InnerProduct(const float* v1, const float* v2, size_t num_elements) { - return HWY_DYNAMIC_DISPATCH(InnerProductImpl)(v1, v2, num_elements); + return HWY_DYNAMIC_DISPATCH(InnerProductImplF32)(v1, v2, num_elements); +} + +HWY_DLLEXPORT float InnerProduct(const hwy::bfloat16_t* v1, + const hwy::bfloat16_t* v2, + size_t num_elements) { + return HWY_DYNAMIC_DISPATCH(InnerProductImplBF16)(v1, v2, num_elements); } HWY_DLLEXPORT float InnerProductDistance(const float* v1, const float* v2, @@ -239,8 +595,25 @@ HWY_DLLEXPORT float InnerProductDistance(const float* v1, const float* v2, return 1.0f - InnerProduct(v1, v2, num_elements); } +HWY_DLLEXPORT float InnerProductDistance(const hwy::bfloat16_t* v1, + const hwy::bfloat16_t* v2, + size_t num_elements) { + return 1.0f - InnerProduct(v1, v2, num_elements); +} + HWY_DLLEXPORT void Normalize(float* HWY_RESTRICT inout, size_t size) { - HWY_DYNAMIC_DISPATCH(NormalizeImpl)(inout, size); + HWY_DYNAMIC_DISPATCH(NormalizeImplF32)(inout, size); + return; +} + +// HWY_DLLEXPORT void Normalize(hwy::float16_t* HWY_RESTRICT inout, size_t size) +// { +// HWY_DYNAMIC_DISPATCH(NormalizeImplF16)(inout, size); +// return; +// } + +HWY_DLLEXPORT void Normalize(hwy::bfloat16_t* HWY_RESTRICT inout, size_t 
size) { + HWY_DYNAMIC_DISPATCH(NormalizeImplBF16)(inout, size); return; } @@ -250,7 +623,25 @@ HWY_DLLEXPORT float L2DistanceSquared(const float* v1, const float* v2, return 0.0f; } - return HWY_DYNAMIC_DISPATCH(L2DistanceSquaredImpl)(v1, v2, num_elements); + return HWY_DYNAMIC_DISPATCH(L2DistanceSquaredImplF32)(v1, v2, num_elements); +} + +HWY_DLLEXPORT float L2DistanceSquared(const hwy::bfloat16_t* v1, + const hwy::bfloat16_t* v2, + size_t num_elements) { + if (HWY_UNLIKELY(v1 == v2)) { + return 0.0f; + } + + return HWY_DYNAMIC_DISPATCH(L2DistanceSquaredImplBF16)(v1, v2, num_elements); +} + +// v1 and v2 MUST not be nullptr but **cannot** point to the same array. +HWY_DLLEXPORT float L2DistanceSquared(const float* v1, + const hwy::bfloat16_t* v2, + size_t num_elements) { + return HWY_DYNAMIC_DISPATCH(L2DistanceSquaredImplF32BF16)(v1, v2, + num_elements); } // Implementation follows @@ -269,13 +660,50 @@ HWY_DLLEXPORT void Normalize_Scalar(float* HWY_RESTRICT inout, size_t size) { return; } -// HWY_DLLEXPORT std::string_view DetectTarget() { -// uint64_t supported_targets = HWY_SUPPORTED_TARGETS; -// hwy::GetChosenTarget().Update(supported_targets); -// return hwy::TargetName(supported_targets); -// } +HWY_DLLEXPORT void Normalize_Scalar(hwy::bfloat16_t* HWY_RESTRICT inout, + size_t size) { + float norm = 0.0f; + for (int i = 0; i < size; i++) { + float data = hwy::F32FromBF16(inout[i]); + norm += data * data; + } + norm = 1.0f / (sqrtf(norm) + 1e-30f); + for (int i = 0; i < size; i++) { + inout[i] = hwy::BF16FromF32(hwy::F32FromBF16(inout[i]) * norm); + } + return; +} + +HWY_DLLEXPORT std::vector GetSupportedTargets() { + std::vector targets = hwy::SupportedAndGeneratedTargets(); + std::vector target_names(targets.size()); + std::transform(targets.cbegin(), targets.cend(), target_names.begin(), + [](int64_t target) { return hwy::TargetName(target); }); + return target_names; +} + +HWY_DLLEXPORT void QuantizeF32ToF16(const float* HWY_RESTRICT in, + hwy::float16_t* 
HWY_RESTRICT out, + size_t num_elements) { + HWY_DYNAMIC_DISPATCH(QuantizeF32ToF16Impl)(in, out, num_elements); +} + +HWY_DLLEXPORT void QuantizeF32ToBF16(const float* HWY_RESTRICT in, + hwy::bfloat16_t* HWY_RESTRICT out, + size_t num_elements) { + HWY_DYNAMIC_DISPATCH(QuantizeF32ToBF16Impl)(in, out, num_elements); +} + +HWY_DLLEXPORT void F16ToF32(const hwy::float16_t* HWY_RESTRICT in, + float* HWY_RESTRICT out, size_t num_elements) { + HWY_DYNAMIC_DISPATCH(F16ToF32Impl)(in, out, num_elements); +} +HWY_DLLEXPORT void BF16ToF32(const hwy::bfloat16_t* HWY_RESTRICT in, + float* HWY_RESTRICT out, size_t num_elements) { + HWY_DYNAMIC_DISPATCH(BF16ToF32Impl)(in, out, num_elements); +} -} // namespace distance +} // namespace ops } // namespace vectorlite -#endif // HWY_ONCE \ No newline at end of file +#endif // HWY_ONCE \ No newline at end of file diff --git a/vectorlite/ops/ops.h b/vectorlite/ops/ops.h index 5ea66d4..9ca76fe 100644 --- a/vectorlite/ops/ops.h +++ b/vectorlite/ops/ops.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include "hwy/base.h" @@ -16,28 +16,65 @@ namespace vectorlite { namespace ops { -using DistanceFunc = float (*)(const float*, const float*, size_t); - // v1 and v2 MUST not be nullptr but can point to the same array. HWY_DLLEXPORT float InnerProduct(const float* v1, const float* v2, size_t num_elements); +HWY_DLLEXPORT float InnerProduct(const hwy::bfloat16_t* v1, + const hwy::bfloat16_t* v2, + size_t num_elements); HWY_DLLEXPORT float InnerProductDistance(const float* v1, const float* v2, size_t num_elements); +HWY_DLLEXPORT float InnerProductDistance(const hwy::bfloat16_t* v1, + const hwy::bfloat16_t* v2, + size_t num_elements); // v1 and v2 MUST not be nullptr but can point to the same array. HWY_DLLEXPORT float L2DistanceSquared(const float* v1, const float* v2, size_t num_elements); +// v1 and v2 MUST not be nullptr but can point to the same array. 
+HWY_DLLEXPORT float L2DistanceSquared(const hwy::bfloat16_t* v1, + const hwy::bfloat16_t* v2, + size_t num_elements); + +// v1 and v2 MUST not be nullptr and MUST not point to the same array. +HWY_DLLEXPORT float L2DistanceSquared(const float* HWY_RESTRICT v1, + const hwy::bfloat16_t* HWY_RESTRICT v2, + size_t num_elements); + // Nornalize the input vector in place. HWY_DLLEXPORT void Normalize(float* HWY_RESTRICT inout, size_t num_elements); +// HWY_DLLEXPORT void Normalize(hwy::float16_t* HWY_RESTRICT inout, size_t +// num_elements); +HWY_DLLEXPORT void Normalize(hwy::bfloat16_t* HWY_RESTRICT inout, + size_t num_elements); // Normalize the input vector in place. Implemented using non-SIMD code for // testing and benchmarking purposes. HWY_DLLEXPORT void Normalize_Scalar(float* HWY_RESTRICT inout, size_t num_elements); -// Detect best available SIMD target to ensure future dynamic dispatch avoids -// the overhead of CPU detection. HWY_DLLEXPORT std::string_view DetectTarget(); +// Normalize the input vector in place. Implemented using non-SIMD code for +// testing and benchmarking purposes. +HWY_DLLEXPORT void Normalize_Scalar(hwy::bfloat16_t* HWY_RESTRICT inout, + size_t num_elements); + +// Get supported SIMD target name strings. +HWY_DLLEXPORT std::vector GetSuppportedTargets(); + +// in and out should not be nullptr and points to valid memory of required size. 
+HWY_DLLEXPORT void QuantizeF32ToF16(const float* HWY_RESTRICT in, + hwy::float16_t* HWY_RESTRICT out, + size_t num_elements); +HWY_DLLEXPORT void QuantizeF32ToBF16(const float* HWY_RESTRICT in, + hwy::bfloat16_t* HWY_RESTRICT out, + size_t num_elements); + +// Convert fp16/bf16 to fp32, useful for json serde +HWY_DLLEXPORT void F16ToF32(const hwy::float16_t* HWY_RESTRICT in, + float* HWY_RESTRICT out, size_t num_elements); +HWY_DLLEXPORT void BF16ToF32(const hwy::bfloat16_t* HWY_RESTRICT in, + float* HWY_RESTRICT out, size_t num_elements); -} // namespace distance +} // namespace ops } // namespace vectorlite diff --git a/vectorlite/ops/ops_benchmark.cpp b/vectorlite/ops/ops_benchmark.cpp index 5dfef5b..c012e29 100644 --- a/vectorlite/ops/ops_benchmark.cpp +++ b/vectorlite/ops/ops_benchmark.cpp @@ -1,8 +1,10 @@ +#include + #include #include "benchmark/benchmark.h" -#include "ops.h" #include "hnswlib/hnswlib.h" +#include "ops.h" static std::vector GenerateOneRandomVector(size_t dim) { std::random_device rd; @@ -31,6 +33,24 @@ static void BM_InnerProduct_Vectorlite(benchmark::State& state) { } } +static void BM_InnerProduct_Vectorlite_BF16(benchmark::State& state) { + size_t dim = state.range(0); + auto v1 = GenerateOneRandomVector(dim); + auto v2 = GenerateOneRandomVector(dim); + + std::vector v1_bf16(dim); + vectorlite::ops::QuantizeF32ToBF16(v1.data(), v1_bf16.data(), dim); + + std::vector v2_bf16(dim); + vectorlite::ops::QuantizeF32ToBF16(v2.data(), v2_bf16.data(), dim); + + for (auto _ : state) { + benchmark::DoNotOptimize( + vectorlite::ops::InnerProductDistance(v1.data(), v2.data(), dim)); + benchmark::ClobberMemory(); + } +} + static void BM_InnerProduct_Scalar(benchmark::State& state) { size_t dim = state.range(0); size_t self_product = state.range(1); @@ -71,6 +91,23 @@ static void BM_L2DistanceSquared_Vectorlite(benchmark::State& state) { } } +static void BM_L2DistanceSquared_Vectorlite_BF16(benchmark::State& state) { + size_t dim = state.range(0); + auto 
v1 = GenerateOneRandomVector(dim); + auto v2 = GenerateOneRandomVector(dim); + + std::vector v1_bf16(dim); + std::vector v2_bf16(dim); + vectorlite::ops::QuantizeF32ToBF16(v1.data(), v1_bf16.data(), dim); + vectorlite::ops::QuantizeF32ToBF16(v2.data(), v2_bf16.data(), dim); + + for (auto _ : state) { + benchmark::DoNotOptimize( + vectorlite::ops::L2DistanceSquared(v1_bf16.data(), v2_bf16.data(), dim)); + benchmark::ClobberMemory(); + } +} + static void BM_L2DistanceSquared_Scalar(benchmark::State& state) { size_t dim = state.range(0); auto v1 = GenerateOneRandomVector(dim); @@ -104,6 +141,17 @@ static void BM_Normalize_Vectorlite(benchmark::State& state) { } } +static void BM_Normalize_Vectorlite_BF16(benchmark::State& state) { + size_t dim = state.range(0); + auto v1 = GenerateOneRandomVector(dim); + std::vector v1_bf16(dim); + vectorlite::ops::QuantizeF32ToBF16(v1.data(), v1_bf16.data(), dim); + + for (auto _ : state) { + vectorlite::ops::Normalize(v1_bf16.data(), dim); + } +} + static void BM_Normalize_Scalar(benchmark::State& state) { size_t dim = state.range(0); auto v1 = GenerateOneRandomVector(dim); @@ -125,12 +173,19 @@ BENCHMARK(BM_InnerProduct_Vectorlite) ->ArgsProduct({ benchmark::CreateRange(128, 8 << 11, 2), {0, 1} // self product }); +BENCHMARK(BM_InnerProduct_Vectorlite_BF16) + ->RangeMultiplier(2) + ->Range(128, 8 << 11); BENCHMARK(BM_Normalize_Vectorlite)->RangeMultiplier(2)->Range(128, 8 << 11); +BENCHMARK(BM_Normalize_Vectorlite_BF16)->RangeMultiplier(2)->Range(128, 8 << 11); BENCHMARK(BM_Normalize_Scalar)->RangeMultiplier(2)->Range(128, 8 << 11); BENCHMARK(BM_L2DistanceSquared_Scalar)->RangeMultiplier(2)->Range(128, 8 << 11); BENCHMARK(BM_L2DistanceSquared_Vectorlite) ->RangeMultiplier(2) ->Range(128, 8 << 11); +BENCHMARK(BM_L2DistanceSquared_Vectorlite_BF16) + ->RangeMultiplier(2) + ->Range(128, 8 << 11); BENCHMARK(BM_L2DistanceSquared_HNSWLIB) ->RangeMultiplier(2) ->Range(128, 8 << 11); \ No newline at end of file diff --git 
a/vectorlite/ops/ops_test.cpp b/vectorlite/ops/ops_test.cpp index f0e2c3a..1be71fe 100644 --- a/vectorlite/ops/ops_test.cpp +++ b/vectorlite/ops/ops_test.cpp @@ -4,6 +4,7 @@ #include "gtest/gtest.h" #include "hnswlib/hnswlib.h" +#include "hwy/base.h" static std::vector> GenerateRandomVectors(size_t num_vectors, size_t dim) { @@ -59,6 +60,35 @@ TEST(InnerProduct, ShouldWorkWithRandomVectors) { } } +TEST(InnerProduct_BF16, ShouldWorkWithRandomVectors) { + for (int dim = 1; dim <= 128; dim++) { + auto vectors = GenerateRandomVectors(10, dim); + for (int i = 0; i < vectors.size(); ++i) { + for (int j = 0; j < vectors.size(); ++j) { + auto v1 = vectors[i]; + auto v2 = vectors[j]; + auto size = dim; + std::vector v1_bf16(size); + vectorlite::ops::QuantizeF32ToBF16(v1.data(), v1_bf16.data(), size); + + std::vector v2_bf16(size); + vectorlite::ops::QuantizeF32ToBF16(v2.data(), v2_bf16.data(), size); + + float ip = vectorlite::ops::InnerProduct(v1_bf16.data(), v2_bf16.data(), size); + + float expected = 0.0f; + for (int k = 0; k < size; ++k) { + expected += hwy::F32FromBF16(hwy::BF16FromF32(v1[k])) * hwy::F32FromBF16(hwy::BF16FromF32(v2[k])); + } + // Note: floating point operations are not associative. SIMD version and + // scalar version traverse elements in different order. 
So the result + // should be different but close enough + EXPECT_NEAR(ip, expected, kEpsilon) << " dim = " << dim; + } + } + } +} + TEST(InnerProductDistance, ShouldReturnOneForEmptyVectors) { // Fixes C2466: cannot allocate an array of constant size 0 on MSVC float v1[] = {1}; @@ -120,6 +150,54 @@ TEST(L2DistanceSquared, ShouldWorkWithRandomVectors) { } } +TEST(L2DistanceSquared_BF16, ShouldWorkWithRandomVectors) { + for (size_t dim = 1; dim <= 128; dim++) { + auto vectors = GenerateRandomVectors(10, dim); + for (int i = 0; i < vectors.size(); ++i) { + for (int j = 0; j < vectors.size(); ++j) { + const auto& v1 = vectors[i]; + const auto& v2 = vectors[j]; + std::vector v1_bf16(dim); + vectorlite::ops::QuantizeF32ToBF16(v1.data(), v1_bf16.data(), dim); + + std::vector v2_bf16(dim); + vectorlite::ops::QuantizeF32ToBF16(v2.data(), v2_bf16.data(), dim); + + float result = vectorlite::ops::L2DistanceSquared(v1_bf16.data(), v2_bf16.data(), dim); + float expected = 0; + for (int k = 0; k < dim; ++k) { + float diff = hwy::F32FromBF16(v1_bf16[k]) - hwy::F32FromBF16(v2_bf16[k]); + expected += diff * diff; + } + EXPECT_NEAR(result, expected, 1e-2); + } + } + } +} + +TEST(L2DistanceSquared_F32_BF16, ShouldWorkWithRandomVectors) { + for (size_t dim = 1; dim <= 128; dim++) { + auto vectors = GenerateRandomVectors(2, dim); + for (int i = 0; i < vectors.size(); ++i) { + for (int j = 0; j < vectors.size(); ++j) { + const auto& v1 = vectors[i]; + const auto& v2 = vectors[j]; + + std::vector v2_bf16(dim); + vectorlite::ops::QuantizeF32ToBF16(v2.data(), v2_bf16.data(), dim); + + float result = vectorlite::ops::L2DistanceSquared(v1.data(), v2_bf16.data(), dim); + float expected = 0; + for (int k = 0; k < dim; ++k) { + float diff = v1[k] - hwy::F32FromBF16(v2_bf16[k]); + expected += diff * diff; + } + EXPECT_NEAR(result, expected, 1e-2) << " dim = " << dim; + } + } + } +} + TEST(Normalize, ShouldReturnCorrectResult) { for (int dim = 1; dim <= 1000; dim++) { auto vectors = 
GenerateRandomVectors(10, dim); @@ -136,3 +214,111 @@ TEST(Normalize, ShouldReturnCorrectResult) { } } } + +TEST(Normalize_F32ToBF16, ShouldReturnCorrectResult) { + for (int dim = 1; dim <= 1000; dim++) { + auto vectors = GenerateRandomVectors(1, dim); + for (int i = 0; i < vectors.size(); ++i) { + std::vector v = vectors[i]; + auto size = dim; + std::vector v_bf16(size); + vectorlite::ops::QuantizeF32ToBF16(v.data(), v_bf16.data(), size); + std::vector v_bf16_scalar = v_bf16; + + vectorlite::ops::Normalize(v_bf16.data(), size); + vectorlite::ops::Normalize_Scalar(v_bf16_scalar.data(), size); + + float sum = 0; + float sum_scalar = 0; + for (int j = 0; j < size; ++j) { + sum += hwy::F32FromBF16(v_bf16[j]) * hwy::F32FromBF16(v_bf16[j]); + sum_scalar += hwy::F32FromBF16(v_bf16_scalar[j]) * + hwy::F32FromBF16(v_bf16_scalar[j]); + EXPECT_NEAR(hwy::F32FromBF16(v_bf16[i]), + hwy::F32FromBF16(v_bf16_scalar[i]), 1e-3) + << " dim = " << dim; + } + + EXPECT_NEAR(sum, 1.0, 1e-2); + EXPECT_NEAR(sum, sum_scalar, 1e-2); + } + } +} + +TEST(QuantizeF32ToBF16, ShouldReturnCorrectResult) { + for (int dim = 0; dim <= 100; dim++) { + auto vectors = GenerateRandomVectors(10, dim); + for (int i = 0; i < vectors.size(); ++i) { + std::vector v = vectors[i]; + auto size = dim; + std::vector out(size); + vectorlite::ops::QuantizeF32ToBF16(v.data(), out.data(), size); + + for (int j = 0; j < size; ++j) { + float expected = hwy::F32FromBF16(hwy::BF16FromF32(v[j])); + EXPECT_NEAR(expected, hwy::F32FromBF16(out[j]), 1e-6) + << "v[" << j << "] = " << v[j] << " dim = " << dim; + } + } + } +} + +TEST(QuantizeF32ToF16, ShouldReturnCorrectResult) { + for (int dim = 0; dim <= 100; dim++) { + auto vectors = GenerateRandomVectors(10, dim); + for (int i = 0; i < vectors.size(); ++i) { + std::vector v = vectors[i]; + auto size = dim; + std::vector out(size); + vectorlite::ops::QuantizeF32ToF16(v.data(), out.data(), size); + + for (int j = 0; j < size; ++j) { + float expected = 
hwy::F32FromF16(hwy::F16FromF32(v[j])); + EXPECT_NEAR(expected, hwy::F32FromF16(out[j]), 1e-6) + << "v[" << j << "] = " << v[j] << " dim = " << dim; + } + } + } +} + +TEST(F16ToF32, ShouldReturnCorrectResult) { + for (int dim = 0; dim <= 100; dim++) { + auto vectors = GenerateRandomVectors(10, dim); + for (int i = 0; i < vectors.size(); ++i) { + std::vector v = vectors[i]; + auto size = dim; + std::vector f16(size); + for (int j = 0; j < size; ++j) { + f16[j] = hwy::F16FromF32(v[j]); + } + std::vector out(size); + vectorlite::ops::F16ToF32(f16.data(), out.data(), size); + + for (int j = 0; j < size; ++j) { + EXPECT_NEAR(hwy::F32FromF16(f16[j]), out[j], 1e-6) + << "v[" << j << "] = " << v[j] << " dim = " << dim; + } + } + } +} + +TEST(BF16ToF32, ShouldReturnCorrectResult) { + for (int dim = 0; dim <= 100; dim++) { + auto vectors = GenerateRandomVectors(10, dim); + for (int i = 0; i < vectors.size(); ++i) { + std::vector v = vectors[i]; + auto size = dim; + std::vector bf16(size); + for (int j = 0; j < size; ++j) { + bf16[j] = hwy::BF16FromF32(v[j]); + } + std::vector out(size); + vectorlite::ops::BF16ToF32(bf16.data(), out.data(), size); + + for (int j = 0; j < size; ++j) { + EXPECT_NEAR(hwy::F32FromBF16(bf16[j]), out[j], 1e-6) + << "v[" << j << "] = " << v[j] << " dim = " << dim; + } + } + } +} \ No newline at end of file diff --git a/vectorlite/quantization.cpp b/vectorlite/quantization.cpp new file mode 100644 index 0000000..e5eb7d1 --- /dev/null +++ b/vectorlite/quantization.cpp @@ -0,0 +1,19 @@ +#include "quantization.h" + +#include + +#include "hwy/base.h" +#include "ops/ops.h" +#include "vector.h" +#include "vector_view.h" + +namespace vectorlite { + +BF16Vector Quantize(VectorView v) { + std::vector quantized(v.dim()); + ops::QuantizeF32ToBF16(v.data().data(), quantized.data(), v.dim()); + + return BF16Vector(std::move(quantized)); +} + +} // namespace vectorlite \ No newline at end of file diff --git a/vectorlite/quantization.h b/vectorlite/quantization.h 
new file mode 100644 index 0000000..9c31ae4 --- /dev/null +++ b/vectorlite/quantization.h @@ -0,0 +1,10 @@ +#pragma once + +#include "vector.h" +#include "vector_view.h" + +namespace vectorlite { + +BF16Vector Quantize(VectorView v); + +} // namespace vectorlite \ No newline at end of file diff --git a/vectorlite/util.h b/vectorlite/util.h index e56c359..1e126a0 100644 --- a/vectorlite/util.h +++ b/vectorlite/util.h @@ -2,6 +2,7 @@ #include #include +#include #include #include "hnswlib/hnswlib.h" @@ -25,4 +26,84 @@ std::optional DetectSIMD(); bool IsRowidInIndex(const hnswlib::HierarchicalNSW& index, hnswlib::labeltype rowid); +// Below *Base classes are taken from +// https://github.com/abseil/abseil-cpp/blob/20240722.0/absl/status/internal/statusor_internal.h#L368 +// to allow implicitly deleted constructors and assignment +// operators in a Derived class. For example, `CopyCtorBase` will explicitly +// delete the copy constructor when T is not copy constructible and `Derived` +// class will inherit that behavior implicitly. 
+template ::value> +struct CopyCtorBase { + CopyCtorBase() = default; + CopyCtorBase(const CopyCtorBase&) = default; + CopyCtorBase(CopyCtorBase&&) = default; + CopyCtorBase& operator=(const CopyCtorBase&) = default; + CopyCtorBase& operator=(CopyCtorBase&&) = default; +}; + +template +struct CopyCtorBase { + CopyCtorBase() = default; + CopyCtorBase(const CopyCtorBase&) = delete; + CopyCtorBase(CopyCtorBase&&) = default; + CopyCtorBase& operator=(const CopyCtorBase&) = default; + CopyCtorBase& operator=(CopyCtorBase&&) = default; +}; + +template ::value> +struct MoveCtorBase { + MoveCtorBase() = default; + MoveCtorBase(const MoveCtorBase&) = default; + MoveCtorBase(MoveCtorBase&&) = default; + MoveCtorBase& operator=(const MoveCtorBase&) = default; + MoveCtorBase& operator=(MoveCtorBase&&) = default; +}; + +template +struct MoveCtorBase { + MoveCtorBase() = default; + MoveCtorBase(const MoveCtorBase&) = default; + MoveCtorBase(MoveCtorBase&&) = delete; + MoveCtorBase& operator=(const MoveCtorBase&) = default; + MoveCtorBase& operator=(MoveCtorBase&&) = default; +}; + +template ::value&& + std::is_copy_assignable::value> +struct CopyAssignBase { + CopyAssignBase() = default; + CopyAssignBase(const CopyAssignBase&) = default; + CopyAssignBase(CopyAssignBase&&) = default; + CopyAssignBase& operator=(const CopyAssignBase&) = default; + CopyAssignBase& operator=(CopyAssignBase&&) = default; +}; + +template +struct CopyAssignBase { + CopyAssignBase() = default; + CopyAssignBase(const CopyAssignBase&) = default; + CopyAssignBase(CopyAssignBase&&) = default; + CopyAssignBase& operator=(const CopyAssignBase&) = delete; + CopyAssignBase& operator=(CopyAssignBase&&) = default; +}; + +template ::value&& + std::is_move_assignable::value> +struct MoveAssignBase { + MoveAssignBase() = default; + MoveAssignBase(const MoveAssignBase&) = default; + MoveAssignBase(MoveAssignBase&&) = default; + MoveAssignBase& operator=(const MoveAssignBase&) = default; + MoveAssignBase& 
operator=(MoveAssignBase&&) = default; +}; + +template +struct MoveAssignBase { + MoveAssignBase() = default; + MoveAssignBase(const MoveAssignBase&) = default; + MoveAssignBase(MoveAssignBase&&) = default; + MoveAssignBase& operator=(const MoveAssignBase&) = default; + MoveAssignBase& operator=(MoveAssignBase&&) = delete; +}; + } // end namespace vectorlite diff --git a/vectorlite/vector.cpp b/vectorlite/vector.cpp deleted file mode 100644 index a2a4808..0000000 --- a/vectorlite/vector.cpp +++ /dev/null @@ -1,115 +0,0 @@ -#include "vector.h" - -#include -#include - -#include -#include -#include - -#include "hnswlib/hnswlib.h" -#include "hnswlib/space_l2.h" -#include "macros.h" -#include "rapidjson/document.h" -#include "rapidjson/error/en.h" -#include "rapidjson/stringbuffer.h" -#include "rapidjson/writer.h" -#include "vector_space.h" -#include "vector_view.h" -#include "ops/ops.h" - -namespace vectorlite { - -absl::StatusOr Vector::FromJSON(std::string_view json) { - rapidjson::Document doc; - doc.Parse(json.data(), json.size()); - auto err = doc.GetParseError(); - if (err != rapidjson::ParseErrorCode::kParseErrorNone) { - return absl::InvalidArgumentError(rapidjson::GetParseError_En(err)); - } - - Vector result; - - if (doc.IsArray()) { - for (auto& v : doc.GetArray()) { - if (v.IsNumber()) { - result.data_.push_back(v.GetFloat()); - } else { - return absl::InvalidArgumentError( - "JSON array contains non-numeric value."); - } - } - return result; - } - - return absl::InvalidArgumentError("Input JSON is not an array."); -} - -absl::StatusOr Vector::FromBlob(std::string_view blob) { - auto vector_view = VectorView::FromBlob(blob); - if (vector_view.ok()) { - return Vector(*vector_view); - } - return vector_view.status(); -} - -std::string Vector::ToJSON() const { - VectorView vector_view(*this); - - return vector_view.ToJSON(); -} - -absl::StatusOr Distance(VectorView v1, VectorView v2, - DistanceType distance_type) { - if (v1.dim() != v2.dim()) { - std::string 
err = - absl::StrFormat("Dimension mismatch: %d != %d", v1.dim(), v2.dim()); - return absl::InvalidArgumentError(err); - } - - ops::DistanceFunc distance_func = nullptr; - - switch (distance_type) { - case DistanceType::L2: - distance_func = ops::L2DistanceSquared; - break; - case DistanceType::InnerProduct: - distance_func = ops::InnerProductDistance; - break; - case DistanceType::Cosine: - distance_func = ops::InnerProductDistance; - break; - default: - return absl::InvalidArgumentError("Invalid distance type"); - } - - bool normalize = distance_type == DistanceType::Cosine; - - if (!normalize) { - return distance_func(v1.data().data(), v2.data().data(), v1.dim()); - } - - Vector lhs = Vector::Normalize(v1); - Vector rhs = Vector::Normalize(v2); - return distance_func(lhs.data().data(), rhs.data().data(), v1.dim()); -} - -std::string_view Vector::ToBlob() const { - VectorView vector_view(*this); - - return vector_view.ToBlob(); -} - -Vector Vector::Normalize() const { - VectorView vector_view(*this); - - return Vector::Normalize(vector_view); -} - -Vector Vector::Normalize(VectorView vector_view) { - Vector v(vector_view); - ops::Normalize(v.data_.data(), vector_view.dim()); - return v; -} - -} // namespace vectorlite \ No newline at end of file diff --git a/vectorlite/vector.h b/vectorlite/vector.h index d833335..fb78204 100644 --- a/vectorlite/vector.h +++ b/vectorlite/vector.h @@ -1,52 +1,148 @@ #pragma once +#include + #include #include #include #include "absl/status/statusor.h" #include "macros.h" +#include "ops/ops.h" +#include "rapidjson/document.h" +#include "rapidjson/error/en.h" +#include "rapidjson/stringbuffer.h" +#include "rapidjson/writer.h" +#include "util.h" #include "vector_space.h" #include "vector_view.h" namespace vectorlite { -class Vector { +template +class GenericVector : private CopyAssignBase, + private CopyCtorBase, + private MoveCtorBase, + private MoveAssignBase { public: - Vector() = default; - Vector(const Vector&) = default; - 
Vector(Vector&&) = default; + GenericVector() = default; + GenericVector(const GenericVector&) = default; + GenericVector(GenericVector&&) = default; - explicit Vector(std::vector&& data) : data_(std::move(data)) {} - explicit Vector(const std::vector& data) : data_(data) {} - explicit Vector(VectorView vector_view) + explicit GenericVector(std::vector&& data) : data_(std::move(data)) {} + explicit GenericVector(const std::vector& data) : data_(data) {} + explicit GenericVector(GenericVectorView vector_view) : data_(vector_view.data().begin(), vector_view.data().end()) {} - Vector& operator=(const Vector&) = default; - Vector& operator=(Vector&&) = default; + GenericVector& operator=(const GenericVector&) = default; + GenericVector& operator=(GenericVector&&) = default; + + static absl::StatusOr> FromJSON(std::string_view json) { + rapidjson::Document doc; + doc.Parse(json.data(), json.size()); + auto err = doc.GetParseError(); + if (err != rapidjson::ParseErrorCode::kParseErrorNone) { + return absl::InvalidArgumentError(rapidjson::GetParseError_En(err)); + } + + GenericVector result; + + if (doc.IsArray()) { + for (auto& v : doc.GetArray()) { + if (v.IsNumber()) { + result.data_.push_back(hwy::ConvertScalarTo(v.GetFloat())); + } else { + return absl::InvalidArgumentError( + "JSON array contains non-numeric value."); + } + } + return result; + } - static absl::StatusOr FromJSON(std::string_view json); + return absl::InvalidArgumentError("Input JSON is not an array."); + } - static absl::StatusOr FromBlob(std::string_view blob); + static absl::StatusOr> FromBlob(std::string_view blob) { + auto vector_view = GenericVectorView::FromBlob(blob); + if (vector_view.ok()) { + return GenericVector(*vector_view); + } + return vector_view.status(); + } - std::string ToJSON() const; + std::string ToJSON() const { + GenericVectorView vector_view(*this); - std::string_view ToBlob() const; + return vector_view.ToJSON(); + } - const std::vector& data() const { return data_; } + 
std::string_view ToBlob() const { + GenericVectorView vector_view(*this); + + return vector_view.ToBlob(); + }; + + const std::vector& data() const { return data_; } std::size_t dim() const { return data_.size(); } - Vector Normalize() const; + GenericVector Normalize() const { + GenericVectorView vector_view(*this); - static Vector Normalize(VectorView vector_view); + return GenericVector::Normalize(vector_view); + } + + static GenericVector Normalize(GenericVectorView vector_view) { + GenericVector v(vector_view); + ops::Normalize(v.data_.data(), vector_view.dim()); + return v; + } private: - std::vector data_; + std::vector data_; }; +template +using DistanceFunc = float (*)(const T*, const T*, size_t); + // Calculate the distance between two vectors. -absl::StatusOr Distance(VectorView v1, VectorView v2, - DistanceType space); +template +absl::StatusOr Distance(GenericVectorView v1, GenericVectorView v2, + DistanceType distance_type) { + if (v1.dim() != v2.dim()) { + std::string err = + absl::StrFormat("Dimension mismatch: %d != %d", v1.dim(), v2.dim()); + return absl::InvalidArgumentError(err); + } + + DistanceFunc distance_func = nullptr; + + switch (distance_type) { + case DistanceType::L2: + distance_func = ops::L2DistanceSquared; + break; + case DistanceType::InnerProduct: + distance_func = ops::InnerProductDistance; + break; + case DistanceType::Cosine: + distance_func = ops::InnerProductDistance; + break; + default: + return absl::InvalidArgumentError("Invalid distance type"); + } + + bool normalize = distance_type == DistanceType::Cosine; + + if (!normalize) { + return distance_func(v1.data().data(), v2.data().data(), v1.dim()); + } + + GenericVector lhs = GenericVector::Normalize(v1); + GenericVector rhs = GenericVector::Normalize(v2); + return distance_func(lhs.data().data(), rhs.data().data(), v1.dim()); +} + +using Vector = GenericVector; +using BF16Vector = GenericVector; } // namespace vectorlite diff --git a/vectorlite/vector_space.cpp 
b/vectorlite/vector_space.cpp index c9b1226..bd69e56 100644 --- a/vectorlite/vector_space.cpp +++ b/vectorlite/vector_space.cpp @@ -1,8 +1,12 @@ #include "vector_space.h" +#include + +#include #include #include +#include "absl/base/optimization.h" #include "absl/strings/numbers.h" #include "absl/strings/str_format.h" #include "distance.h" @@ -27,9 +31,42 @@ std::optional ParseVectorType(std::string_view vector_type) { if (vector_type == "float32") { return VectorType::Float32; } + + if (vector_type == "bfloat16") { + return VectorType::BFloat16; + } + return std::nullopt; } +static std::unique_ptr> CreateL2Space( + size_t dim, VectorType vector_type) { + switch (vector_type) { + case VectorType::Float32: + return std::make_unique(dim); + case VectorType::BFloat16: + return std::make_unique(dim); + default: + // This should never happen, but we include it for completeness + ABSL_UNREACHABLE(); + return nullptr; + } +} + +static std::unique_ptr> CreateInnerProductSpace( + size_t dim, VectorType vector_type) { + switch (vector_type) { + case VectorType::Float32: + return std::make_unique(dim); + case VectorType::BFloat16: + return std::make_unique(dim); + default: + // This should never happen, but we include it for completeness + ABSL_UNREACHABLE(); + return nullptr; + } +} + absl::StatusOr VectorSpace::Create(size_t dim, DistanceType distance_type, VectorType vector_type) { @@ -43,13 +80,13 @@ absl::StatusOr VectorSpace::Create(size_t dim, result.vector_type = vector_type; switch (distance_type) { case DistanceType::L2: - result.space = std::make_unique(dim); + result.space = CreateL2Space(dim, vector_type); break; case DistanceType::InnerProduct: - result.space = std::make_unique(dim); + result.space = CreateInnerProductSpace(dim, vector_type); break; case DistanceType::Cosine: - result.space = std::make_unique(dim); + result.space = CreateInnerProductSpace(dim, vector_type); break; default: std::string err_msg = diff --git a/vectorlite/vector_space.h 
b/vectorlite/vector_space.h index c8df5a7..cef25e4 100644 --- a/vectorlite/vector_space.h +++ b/vectorlite/vector_space.h @@ -19,6 +19,7 @@ std::optional ParseDistanceType(std::string_view distance_type); enum class VectorType { Float32, + BFloat16, }; std::optional ParseVectorType(std::string_view vector_type); diff --git a/vectorlite/vector_space_test.cpp b/vectorlite/vector_space_test.cpp index 065008d..c3f0a90 100644 --- a/vectorlite/vector_space_test.cpp +++ b/vectorlite/vector_space_test.cpp @@ -1,5 +1,6 @@ #include "vector_space.h" +#include "absl/strings/str_format.h" #include "gtest/gtest.h" TEST(ParseDistanceType, ShouldSupport_L2_InnerProduct_Cosine) { @@ -35,91 +36,114 @@ TEST(ParseVectorType, ShouldReturnNullOptForInvalidVectorType) { EXPECT_FALSE(uint8); } +TEST(ParseVectorType, ShouldSupportBFloat16) { + auto float16 = vectorlite::ParseVectorType("bfloat16"); + EXPECT_TRUE(float16); +} + TEST(CreateVectorSpace, ShouldWorkWithValidInput) { - auto l2 = vectorlite::CreateNamedVectorSpace(3, vectorlite::DistanceType::L2, - "my_vector", - vectorlite::VectorType::Float32); - ASSERT_TRUE(l2.ok()); - EXPECT_EQ(l2->distance_type, vectorlite::DistanceType::L2); - EXPECT_EQ(l2->normalize, false); - EXPECT_NE(l2->space, nullptr); - EXPECT_EQ(l2->dimension(), 3); - EXPECT_EQ(l2->vector_type, vectorlite::VectorType::Float32); - - auto ip = vectorlite::CreateNamedVectorSpace( - 4, vectorlite::DistanceType::InnerProduct, "my_vector", - vectorlite::VectorType::Float32); - ASSERT_TRUE(ip.ok()); - EXPECT_EQ(ip->distance_type, vectorlite::DistanceType::InnerProduct); - EXPECT_EQ(ip->normalize, false); - EXPECT_NE(ip->space, nullptr); - EXPECT_EQ(ip->dimension(), 4); - EXPECT_EQ(ip->vector_type, vectorlite::VectorType::Float32); - - auto cosine = vectorlite::CreateNamedVectorSpace( - 5, vectorlite::DistanceType::Cosine, "my_vector", - vectorlite::VectorType::Float32); - ASSERT_TRUE(cosine.ok()); - EXPECT_EQ(cosine->distance_type, vectorlite::DistanceType::Cosine); - 
EXPECT_EQ(cosine->normalize, true); - EXPECT_NE(cosine->space, nullptr); - EXPECT_EQ(cosine->dimension(), 5); - EXPECT_EQ(cosine->vector_type, vectorlite::VectorType::Float32); + for (auto vector_type : + {vectorlite::VectorType::Float32, vectorlite::VectorType::BFloat16}) { + auto l2 = vectorlite::CreateNamedVectorSpace( + 3, vectorlite::DistanceType::L2, "my_vector", vector_type); + ASSERT_TRUE(l2.ok()); + EXPECT_EQ(l2->distance_type, vectorlite::DistanceType::L2); + EXPECT_EQ(l2->normalize, false); + EXPECT_NE(l2->space, nullptr); + EXPECT_EQ(l2->dimension(), 3); + EXPECT_EQ(l2->vector_type, vector_type); + + auto ip = vectorlite::CreateNamedVectorSpace( + 4, vectorlite::DistanceType::InnerProduct, "my_vector", vector_type); + ASSERT_TRUE(ip.ok()); + EXPECT_EQ(ip->distance_type, vectorlite::DistanceType::InnerProduct); + EXPECT_EQ(ip->normalize, false); + EXPECT_NE(ip->space, nullptr); + EXPECT_EQ(ip->dimension(), 4); + EXPECT_EQ(ip->vector_type, vector_type); + + auto cosine = vectorlite::CreateNamedVectorSpace( + 5, vectorlite::DistanceType::Cosine, "my_vector", vector_type); + ASSERT_TRUE(cosine.ok()); + EXPECT_EQ(cosine->distance_type, vectorlite::DistanceType::Cosine); + EXPECT_EQ(cosine->normalize, true); + EXPECT_NE(cosine->space, nullptr); + EXPECT_EQ(cosine->dimension(), 5); + EXPECT_EQ(cosine->vector_type, vector_type); + } } TEST(CreateNamedVectorSpace, ShouldReturnErrorForDimOfZero) { - auto l2 = vectorlite::CreateNamedVectorSpace(0, vectorlite::DistanceType::L2, - "my_vector", - vectorlite::VectorType::Float32); - EXPECT_FALSE(l2.ok()); - - auto ip = vectorlite::CreateNamedVectorSpace( - 0, vectorlite::DistanceType::InnerProduct, "my_vector", - vectorlite::VectorType::Float32); - EXPECT_FALSE(ip.ok()); - - auto cosine = vectorlite::CreateNamedVectorSpace( - 0, vectorlite::DistanceType::Cosine, "my_vector", - vectorlite::VectorType::Float32); - EXPECT_FALSE(cosine.ok()); + for (auto vector_type : + {vectorlite::VectorType::Float32, 
vectorlite::VectorType::BFloat16}) { + auto l2 = vectorlite::CreateNamedVectorSpace( + 0, vectorlite::DistanceType::L2, "my_vector", vector_type); + EXPECT_FALSE(l2.ok()); + + auto ip = vectorlite::CreateNamedVectorSpace( + 0, vectorlite::DistanceType::InnerProduct, "my_vector", vector_type); + EXPECT_FALSE(ip.ok()); + + auto cosine = vectorlite::CreateNamedVectorSpace( + 0, vectorlite::DistanceType::Cosine, "my_vector", vector_type); + EXPECT_FALSE(cosine.ok()); + } +} + +static std::string VectorTypeToString(vectorlite::VectorType type) { + switch (type) { + case vectorlite::VectorType::Float32: + return "float32"; + case vectorlite::VectorType::BFloat16: + return "bfloat16"; + default: + return "unknown"; + } } TEST(NamedVectorSpace_FromString, ShouldWorkWithValidInput) { - // If distance type is not specifed, it should default to L2 - auto space = vectorlite::NamedVectorSpace::FromString("my_vec float32[3]"); - ASSERT_TRUE(space.ok()); - EXPECT_EQ(space->normalize, false); - EXPECT_NE(space->space, nullptr); - EXPECT_EQ(space->distance_type, vectorlite::DistanceType::L2); - EXPECT_EQ(3, space->dimension()); - EXPECT_EQ("my_vec", space->vector_name); - EXPECT_EQ(vectorlite::VectorType::Float32, space->vector_type); - - space = vectorlite::NamedVectorSpace::FromString("my_vec float32[3] l2"); - ASSERT_TRUE(space.ok()); - EXPECT_EQ(space->normalize, false); - EXPECT_NE(space->space, nullptr); - EXPECT_EQ(space->distance_type, vectorlite::DistanceType::L2); - EXPECT_EQ(3, space->dimension()); - EXPECT_EQ("my_vec", space->vector_name); - EXPECT_EQ(vectorlite::VectorType::Float32, space->vector_type); - - space = - vectorlite::NamedVectorSpace::FromString("my_vec float32[10086] cosine"); - ASSERT_TRUE(space.ok()); - EXPECT_EQ(space->normalize, true); - EXPECT_NE(space->space, nullptr); - EXPECT_EQ(space->distance_type, vectorlite::DistanceType::Cosine); - EXPECT_EQ(10086, space->dimension()); - EXPECT_EQ("my_vec", space->vector_name); - 
EXPECT_EQ(vectorlite::VectorType::Float32, space->vector_type); - - space = vectorlite::NamedVectorSpace::FromString("my_vec float32[42] ip"); - ASSERT_TRUE(space.ok()); - EXPECT_EQ(space->normalize, false); - EXPECT_NE(space->space, nullptr); - EXPECT_EQ(space->distance_type, vectorlite::DistanceType::InnerProduct); - EXPECT_EQ(42, space->dimension()); - EXPECT_EQ("my_vec", space->vector_name); - EXPECT_EQ(vectorlite::VectorType::Float32, space->vector_type); + for (auto vector_type : + {vectorlite::VectorType::Float32, vectorlite::VectorType::BFloat16}) { + // If distance type is not specifed, it should default to L2 + std::string vector_type_str = VectorTypeToString(vector_type); + auto space = vectorlite::NamedVectorSpace::FromString( + absl::StrFormat("my_vec %s[3]", vector_type_str)); + ASSERT_TRUE(space.ok()); + EXPECT_EQ(space->normalize, false); + EXPECT_NE(space->space, nullptr); + EXPECT_EQ(space->distance_type, vectorlite::DistanceType::L2); + EXPECT_EQ(3, space->dimension()); + EXPECT_EQ("my_vec", space->vector_name); + EXPECT_EQ(vector_type, space->vector_type); + + space = vectorlite::NamedVectorSpace::FromString( + absl::StrFormat("my_vec %s[3] l2", vector_type_str)); + ASSERT_TRUE(space.ok()); + EXPECT_EQ(space->normalize, false); + EXPECT_NE(space->space, nullptr); + EXPECT_EQ(space->distance_type, vectorlite::DistanceType::L2); + EXPECT_EQ(3, space->dimension()); + EXPECT_EQ("my_vec", space->vector_name); + EXPECT_EQ(vector_type, space->vector_type); + + space = vectorlite::NamedVectorSpace::FromString( + absl::StrFormat("my_vec %s[10086] cosine", vector_type_str)); + ASSERT_TRUE(space.ok()); + EXPECT_EQ(space->normalize, true); + EXPECT_NE(space->space, nullptr); + EXPECT_EQ(space->distance_type, vectorlite::DistanceType::Cosine); + EXPECT_EQ(10086, space->dimension()); + EXPECT_EQ("my_vec", space->vector_name); + EXPECT_EQ(vector_type, space->vector_type); + + space = vectorlite::NamedVectorSpace::FromString( + absl::StrFormat("my_vec %s[42] 
ip", vector_type_str)); + ASSERT_TRUE(space.ok()); + EXPECT_EQ(space->normalize, false); + EXPECT_NE(space->space, nullptr); + EXPECT_EQ(space->distance_type, vectorlite::DistanceType::InnerProduct); + EXPECT_EQ(42, space->dimension()); + EXPECT_EQ("my_vec", space->vector_name); + EXPECT_EQ(vector_type, space->vector_type); + } } diff --git a/vectorlite/vector_test.cpp b/vectorlite/vector_test.cpp index 2cdaf77..acd79a7 100644 --- a/vectorlite/vector_test.cpp +++ b/vectorlite/vector_test.cpp @@ -4,6 +4,7 @@ #include "gtest/gtest.h" #include "vector_space.h" +#include "vector_view.h" TEST(VectorTest, FromJSON) { // Test valid JSON input @@ -69,15 +70,19 @@ TEST(VectorDistance, ShouldWork) { // Test valid input vectorlite::Vector v1({1.0, 2.0, 3.0}); vectorlite::Vector v2({4.0, 5.0, 6.0}); - auto distance = Distance(v1, v2, vectorlite::DistanceType::L2); + + vectorlite::VectorView v1_view(v1); + vectorlite::VectorView v2_view(v2); + + auto distance = Distance(v1_view, v2_view, vectorlite::DistanceType::L2); EXPECT_TRUE(distance.ok()); EXPECT_FLOAT_EQ(*distance, 27); - distance = Distance(v2, v1, vectorlite::DistanceType::InnerProduct); + distance = Distance(v2_view, v1_view, vectorlite::DistanceType::InnerProduct); EXPECT_TRUE(distance.ok()); EXPECT_FLOAT_EQ(*distance, -31); - distance = Distance(v1, v2, vectorlite::DistanceType::Cosine); + distance = Distance(v1_view, v2_view, vectorlite::DistanceType::Cosine); EXPECT_TRUE(distance.ok()); // On osx arm64, no vectoration is used and the following test fails. 
// EXPECT_FLOAT_EQ(*distance, 0.025368214); @@ -87,9 +92,12 @@ TEST(VectorDistance, ShouldWork) { // Test 0 dimension vectorlite::Vector v3; vectorlite::Vector v4; + + vectorlite::VectorView v3_view(v3); + vectorlite::VectorView v4_view(v4); for (auto space : {vectorlite::DistanceType::L2, vectorlite::DistanceType::InnerProduct}) { - distance = Distance(v3, v4, space); + distance = Distance(v3_view, v4_view, space); EXPECT_TRUE(distance.ok()); } } diff --git a/vectorlite/vector_view.cpp b/vectorlite/vector_view.cpp deleted file mode 100644 index adb20e8..0000000 --- a/vectorlite/vector_view.cpp +++ /dev/null @@ -1,44 +0,0 @@ -#include "vector_view.h" - -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "rapidjson/document.h" -#include "rapidjson/error/en.h" -#include "rapidjson/stringbuffer.h" -#include "rapidjson/writer.h" -#include "vector.h" - -namespace vectorlite { - -VectorView::VectorView(const Vector& vector) : data_(vector.data()) {} - -absl::StatusOr VectorView::FromBlob(std::string_view blob) { - if (blob.size() % sizeof(float) != 0) { - return absl::InvalidArgumentError("Blob size is not a multiple of float"); - } - return VectorView(absl::MakeSpan(reinterpret_cast(blob.data()), - blob.size() / sizeof(float))); -} - -std::string VectorView::ToJSON() const { - rapidjson::Document doc; - doc.SetArray(); - - auto& allocator = doc.GetAllocator(); - for (float v : data_) { - doc.PushBack(v, allocator); - } - - rapidjson::StringBuffer buf; - rapidjson::Writer writer(buf); - doc.Accept(writer); - - return buf.GetString(); -} - -std::string_view VectorView::ToBlob() const { - return std::string_view(reinterpret_cast(data_.data()), - data_.size() * sizeof(float)); -} - -} // namespace vectorlite \ No newline at end of file diff --git a/vectorlite/vector_view.h b/vectorlite/vector_view.h index dfadfd6..61dfbf6 100644 --- a/vectorlite/vector_view.h +++ b/vectorlite/vector_view.h @@ -5,37 +5,76 @@ #include "absl/status/statusor.h" 
#include "absl/types/span.h" +#include "hwy/base.h" +#include "macros.h" +#include "rapidjson/document.h" +#include "rapidjson/error/en.h" +#include "rapidjson/stringbuffer.h" +#include "rapidjson/writer.h" +#include "util.h" namespace vectorlite { -class Vector; +template +class GenericVector; -// VectorView is a read-only view of a vector, like std::string_view is to -// std::string. -class VectorView { +// GenericVectorView is a read-only view of a vector, like what std::string_view +// is to std::string. +template +class GenericVectorView : private CopyAssignBase, + private CopyCtorBase, + private MoveCtorBase, + private MoveAssignBase { public: - VectorView() = default; - VectorView(const VectorView&) = default; - VectorView(VectorView&&) = default; + GenericVectorView() = default; + GenericVectorView(const GenericVectorView&) = default; + GenericVectorView(GenericVectorView&&) = default; - VectorView(const Vector& vector); - explicit VectorView(absl::Span data) : data_(data) {} + GenericVectorView(const GenericVector& vector) + : data_(vector.data()) {} + explicit GenericVectorView(absl::Span data) : data_(data) {} - VectorView& operator=(const VectorView&) = default; - VectorView& operator=(VectorView&&) = default; + GenericVectorView& operator=(const GenericVectorView&) = default; + GenericVectorView& operator=(GenericVectorView&&) = default; - static absl::StatusOr FromBlob(std::string_view blob); + static absl::StatusOr> FromBlob(std::string_view blob) { + if (blob.size() % sizeof(T) != 0) { + return absl::InvalidArgumentError("Blob size is not a multiple of float"); + } + return GenericVectorView(absl::MakeSpan( + reinterpret_cast(blob.data()), blob.size() / sizeof(T))); + }; - std::string ToJSON() const; + std::string ToJSON() const { + rapidjson::Document doc; + doc.SetArray(); - std::string_view ToBlob() const; + auto& allocator = doc.GetAllocator(); + for (T v : data_) { + doc.PushBack(hwy::ConvertScalarTo(v), allocator); + } + + 
rapidjson::StringBuffer buf; + rapidjson::Writer writer(buf); + doc.Accept(writer); + + return buf.GetString(); + }; + + std::string_view ToBlob() const { + return std::string_view(reinterpret_cast(data_.data()), + data_.size() * sizeof(T)); + }; std::size_t dim() const { return data_.size(); } - absl::Span data() const { return data_; } + absl::Span data() const { return data_; } private: - absl::Span data_; + absl::Span data_; }; +using VectorView = GenericVectorView; +using BF16VectorView = GenericVectorView; + } // namespace vectorlite \ No newline at end of file diff --git a/vectorlite/vector_view_test.cpp b/vectorlite/vector_view_test.cpp index 4b46321..c8ec1ae 100644 --- a/vectorlite/vector_view_test.cpp +++ b/vectorlite/vector_view_test.cpp @@ -21,6 +21,9 @@ TEST(VectorViewTest, Reversible_ToBinary_FromBinary) { TEST(VectorViewTest, FromBinaryShouldFailWithInvalidInput) { auto v1 = vectorlite::VectorView::FromBlob(std::string_view("aaa")); EXPECT_FALSE(v1.ok()); + + auto v2 = vectorlite::BF16VectorView::FromBlob(std::string_view("aaa")); + EXPECT_FALSE(v2.ok()); } TEST(VectorViewTest, ToJSON) { diff --git a/vectorlite/virtual_table.cpp b/vectorlite/virtual_table.cpp index e8ad984..c8920bd 100644 --- a/vectorlite/virtual_table.cpp +++ b/vectorlite/virtual_table.cpp @@ -20,8 +20,10 @@ #include "hnswlib/hnswlib.h" #include "index_options.h" #include "macros.h" +#include "quantization.h" #include "sqlite3ext.h" #include "util.h" +#include "vector.h" #include "vector_space.h" #include "vector_view.h" @@ -599,7 +601,42 @@ constexpr bool IsRowidOutOfRange(sqlite3_int64 rowid) { std::numeric_limits::max()); } -// Only insert is supported for now +int VirtualTable::InsertOrUpdateVector(VectorView vector, Cursor::Rowid rowid) { + try { + if (space_.vector_type == vectorlite::VectorType::Float32) { + if (!space_.normalize) { + index_->addPoint(vector.data().data(), rowid, + index_->allow_replace_deleted_); + } else { + Vector normalized_vector = 
Vector::Normalize(vector); + index_->addPoint(normalized_vector.data().data(), rowid, + index_->allow_replace_deleted_); + } + } else if (space_.vector_type == vectorlite::VectorType::BFloat16) { + BF16Vector bf16_vector = Quantize(vector); + if (!space_.normalize) { + index_->addPoint(bf16_vector.data().data(), rowid, + index_->allow_replace_deleted_); + } else { + BF16Vector normalized_vector = bf16_vector.Normalize(); + index_->addPoint(normalized_vector.data().data(), rowid, + index_->allow_replace_deleted_); + } + + } else { + SetZErrMsg(&this->zErrMsg, "Unrecognized vector type %d", + space_.vector_type); + return SQLITE_ERROR; + } + + } catch (const std::runtime_error& e) { + SetZErrMsg(&this->zErrMsg, "Failed to insert row %lld due to: %s", rowid, + e.what()); + return SQLITE_ERROR; + } + return SQLITE_OK; +} + int VirtualTable::Update(sqlite3_vtab* pVTab, int argc, sqlite3_value** argv, sqlite_int64* pRowid) { VirtualTable* vtab = static_cast(pVTab); @@ -646,20 +683,7 @@ int VirtualTable::Update(sqlite3_vtab* pVTab, int argc, sqlite3_value** argv, return SQLITE_ERROR; } - try { - if (!vtab->space_.normalize) { - vtab->index_->addPoint(vector->data().data(), rowid, true); - } else { - Vector normalized_vector = Vector::Normalize(*vector); - vtab->index_->addPoint(normalized_vector.data().data(), rowid, true); - } - - } catch (const std::runtime_error& e) { - SetZErrMsg(&vtab->zErrMsg, "Failed to insert row %lld due to: %s", - rowid, e.what()); - return SQLITE_ERROR; - } - return SQLITE_OK; + return vtab->InsertOrUpdateVector(*vector, rowid); } else { SetZErrMsg(&vtab->zErrMsg, "Failed to perform insertion due to: %s", absl::StatusMessageAsCStr(vector.status())); @@ -729,23 +753,7 @@ int VirtualTable::Update(sqlite3_vtab* pVTab, int argc, sqlite3_value** argv, return SQLITE_ERROR; } - try { - if (!vtab->space_.normalize) { - vtab->index_->addPoint(vector->data().data(), rowid, - vtab->index_->allow_replace_deleted_); - } else { - Vector normalized_vector = 
Vector::Normalize(*vector); - vtab->index_->addPoint(normalized_vector.data().data(), rowid, - vtab->index_->allow_replace_deleted_); - } - - } catch (const std::runtime_error& e) { - SetZErrMsg(&vtab->zErrMsg, "Failed to update row %lld due to: %s", - rowid, e.what()); - return SQLITE_ERROR; - } - - return SQLITE_OK; + return vtab->InsertOrUpdateVector(*vector, rowid); } else { SetZErrMsg(&vtab->zErrMsg, "Failed to perform row %lld due to: %s", rowid, absl::StatusMessageAsCStr(vector.status())); diff --git a/vectorlite/virtual_table.h b/vectorlite/virtual_table.h index 021a509..7e4d624 100644 --- a/vectorlite/virtual_table.h +++ b/vectorlite/virtual_table.h @@ -17,9 +17,6 @@ namespace vectorlite { -// Note there shouldn't be any virtual functions in this class. -// Because VirtualTable* is expected to be static_cast-ed to sqlite3_vtab*. -// vptr could cause UB. class VirtualTable : public sqlite3_vtab { public: // No virtual function @@ -102,6 +99,7 @@ class VirtualTable : public sqlite3_vtab { private: absl::StatusOr GetVectorByRowid(int64_t rowid) const; + int InsertOrUpdateVector(VectorView vector, Cursor::Rowid rowid); NamedVectorSpace space_; std::unique_ptr> index_;