Skip to content

Commit

Permalink
herk range
Browse files Browse the repository at this point in the history
  • Loading branch information
alfC committed Sep 15, 2024
1 parent 124860f commit 7fda805
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 11 deletions.
54 changes: 44 additions & 10 deletions include/boost/multi/adaptors/blas/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,6 @@ The main purpose of these functions is to manipulate arguments to BLAS interface
The functions in this category operate on one-dimensional arrays (vectors).
Here, we use `multi::array<T, 1>` as representative of a vector, but a one-dimensional subarray, such as a row or a column of a 2D array, can also be used as a vector.

### `auto multi::blas::swap(`_complex/real vector_`, `_complex/real vector_`) -> void`

Swaps the values of two vectors.
Vector extensions must match.

### `auto multi::blas::copy(`_complex/real vector_`) -> `_convertible to complex/real vector_

Copies the values of a vector to another.
Expand Down Expand Up @@ -179,6 +174,11 @@ The importance of this case is that it guarantees that no allocations are perfor
v4({0, 2}) = multi::blas::copy(v); // case 3: LHS is not resizable, assigns copies (resizing is not possible or necessary)
```

### `auto multi::blas::swap(`_complex/real vector_`, `_complex/real vector_`) -> void`

Swaps the values of two vectors.
Vector extensions must match.

Note that the utility of `multi::blas::copy` and `multi::blas::swap` is redundant with native features of the library (such as plain assignment, copy construction and swap), the only difference is that these operations will be performed using the BLAS operations elementwise.

## `auto multi::blas::nrm2(`_complex/real vector_`) -> `_convertible to real scalar_
Expand All @@ -202,13 +202,17 @@ double const n = multi::blas::nrm2(v[0]); // acting on a row view

Returns the sum of the absolute values of the elements of a vector (norm-1).

### `auto multi::blas::iamax(`_complex/real vector_`) -> `_index_type_

Index of the element with the largest absolute value (zero-based)

### `auto multi::blas::dot(`_complex/real vector_, _complex/real vector_`) -> `_convertible to complex/real scalar_

Returns the dot product of two vectors with complex or real elements (`T`).

```cpp
multi::array<double, 1> const v = {1.0, 2.0, 3.0};
multi::array<double, 1> const w = {4.0, 5.0, 6.0};
multi::array<double, 1> const v = {1.0, 2.0, 3.0};
multi::array<double, 1> const w = {4.0, 5.0, 6.0};
double const d = multi::blas::dot(v, w);
// auto const d = +multi::blas::dot(v, w);
```
Expand All @@ -229,20 +233,50 @@ It is important to note that the right hand side of the assignment can be a scal
In this case, the result is going to directly put at this location.

```cpp
multi::array<double, 1> z = {0.0, 0.0, 0.0};
z[1] = multi::blas::dot(v, w);
multi::array<double, 1> z = {0.0, 0.0, 0.0};
z[1] = multi::blas::dot(v, w);
```
This feature regarding scalar results is essential when operating on GPU memory since the whole operation can be performed on the device.
### `auto multi::blas::scal(`_complex/real scalar`, `_complex/real vector_`)`
Scales a vector.
### `auto multi::blas::axpy(`_complex/real scalar`, `_complex/real vector_`) -> `_convertible to complex/real_
Vector addition.
```cpp
multi::array<double, 1> const x = ...;
multi::array<double, 1> y = ...;
y += blas::axpy(2.0, x); // same as blas:::axpy(+2.0, x, y)
y -= blas::axpy(2.0, x); // same as blas:::axpy(-2.0, x, y)
```

## BLAS level 2

These functions operate on vectors and arrays.
Again, we use `multi::array<T, 1>` as representative of a vector, but a one-dimensional subarray, such as a row or a column of a 2D array, can also be used as a vector.
`multi::array<T, 2>` as representative of a matrices, but a two-dimensional subarray or larger of higher dimensional arrays can be used as long as one of the two interternal strides in 1.
This is limitation of BLAS, that only acts on certain layouts of 2D arrays.

### GEMV
### `auto multi::blas::gemv(`_complex/real scalar_ `,` _complex/real matrix_`) -> `_convertible to complex/real vector_

```cpp
multi::array<double, 2> const A({4, 3});
multi::array<double, 1> const x = {1.0, 2.0, 3.0};
multi::array<double, 1> const x = {1.0, 2.0, 3.0, 4.0};

y = blas::gemv(5.0, A, x); // y <- 5.0 A * x
```
The gemv expression can be used for addition and subtraction,
```
y += blas::gemv(1.0, A, x); // y <- + A * x + y
y -= blas::gemv(1.0, A, x); // y <- - A * x + y
```
### GEMM
Expand Down
3 changes: 2 additions & 1 deletion include/boost/multi/adaptors/blas/axpy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,8 @@ class axpy_range {
auto begin() const -> iterator{ return {ctxt_, alpha_, x_begin_ }; }
auto end() const -> iterator{ return {ctxt_, alpha_, x_begin_ + count_}; }

auto size() const -> size_type{return end() - begin();}
auto size() const -> size_type { return end() - begin(); }
auto extensions() const { return extensions_t<1>{ {0, size()} }; }

template<class Other>
friend auto operator+=(Other&& other, axpy_range const& self) -> Other&& {
Expand Down
44 changes: 44 additions & 0 deletions include/boost/multi/adaptors/blas/herk.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,50 @@ auto base_aux(A&& array)

using core::herk;

template<class ContextPtr, class Scalar, class ItA, class DecayType>
class herk_range {
ContextPtr ctxtp_;
Scalar s_;
ItA a_begin_;
ItA a_end_;

public:
herk_range(herk_range const&) = delete;
herk_range(herk_range&&) = delete;
auto operator=(herk_range const&) -> herk_range& = delete;
auto operator=(herk_range&&) -> herk_range& = delete;
~herk_range() = default;

herk_range(ContextPtr ctxtp, Scalar s, ItA a_first, ItA a_last) // NOLINT(bugprone-easily-swappable-parameters,readability-identifier-length) BLAS naming
: ctxtp_{ctxtp}
, s_{s}, a_begin_{std::move(a_first)}, a_end_{std::move(a_last)}
{}

// using iterator = herk_iterator<ContextPtr, Scalar, ItA>;
using decay_type = DecayType;
using size_type = typename decay_type::size_type;

// auto begin() const& -> iterator {return {ctxtp_, s_, a_begin_, b_begin_};}
// auto end() const& -> iterator {return {ctxtp_, s_, a_end_ , b_begin_};}
// friend auto begin(gemm_range const& self) {return self.begin();}
// friend auto end (gemm_range const& self) {return self.end ();}

// auto size() const -> size_type {return a_end_ - a_begin_;}

// auto extensions() const -> typename decay_type::extensions_type {return size()*(*b_begin_).extensions();}
// friend auto extensions(gemm_range const& self) {return self.extensions();}

// auto operator+() const -> decay_type {return *this;} // TODO(correaa) : investigate why return decay_type{*this} doesn't work
// template<class Arr>
// friend auto operator+=(Arr&& a, gemm_range const& self) -> Arr&& { // NOLINT(readability-identifier-length) BLAS naming
// blas::gemm_n(self.ctxtp_, self.s_, self.a_begin_, self.a_end_ - self.a_begin_, self.b_begin_, 1., a.begin());
// return std::forward<Arr>(a);
// }
// friend auto operator*(Scalar factor, gemm_range const& self) {
// return gemm_range{self.ctxtp_, factor*self.s_, self.a_begin_, self.a_end_, self.b_begin_};
// }
};

template<class AA, class BB, class A2D, class C2D, class = typename A2D::element_ptr,
std::enable_if_t<is_complex_array<C2D>{}, int> =0> // NOLINT(modernize-use-constraints) TODO(correaa) for C++20
auto herk(filling c_side, AA alpha, A2D const& a, BB beta, C2D&& c) -> C2D&& { // NOLINT(readability-function-cognitive-complexity,readability-identifier-length) 74, BLAS naming
Expand Down
9 changes: 9 additions & 0 deletions include/boost/multi/adaptors/blas/test/axpy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,15 @@ auto main() -> int { // NOLINT(readability-function-cognitive-complexity,bugpro
BOOST_TEST( arr[1][2] == 2.0*y[2] + carr[1][2] );
}

BOOST_AUTO_TEST_CASE(axpy_assignment) {
multi::array<double, 1> const xx = {1.0, 1.0, 1.0};
multi::array<double, 1> yy = {2.0, 2.0, 2.0};

yy += blas::axpy(3.0, xx);

BOOST_TEST( yy[0] == 5.0 );
}

BOOST_AUTO_TEST_CASE(multi_blas_axpy_complex_as_operator_minus_equal) {
multi::array<complex, 2> arr = {
{{1.0, 0.0}, {2.0, 0.0}, {3.0, 0.0}, {4.0, 0.0}},
Expand Down

0 comments on commit 7fda805

Please sign in to comment.