Skip to content

Commit

Permalink
Merge pull request #131 from rhpvorderman/ssse3dispatch
Browse files Browse the repository at this point in the history
Use dispatch rather than rely on separate compile commands for SSSE3 bam parser
  • Loading branch information
marcelm committed Apr 22, 2024
2 parents 36c8fe0 + 4fbcc23 commit 94e09d7
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 54 deletions.
4 changes: 0 additions & 4 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ jobs:
matrix:
os: [ubuntu-latest]
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
compile_flags: [""]
include:
- os: macos-latest
python-version: "3.10"
Expand All @@ -61,7 +60,6 @@ jobs:
python-version: "3.10"
- os: ubuntu-latest
python-version: "3.10"
compile_flags: "-mssse3"
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
Expand All @@ -72,8 +70,6 @@ jobs:
run: python -m pip install tox
- name: Test
run: tox -e py
env:
CFLAGS: ${{ matrix.compile_flags }}
- name: Upload coverage report
uses: codecov/codecov-action@v3

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ CFLAGS = "-g0 -DNDEBUG"
CFLAGS = "-g0 -DNDEBUG"

[tool.cibuildwheel.linux.environment]
CFLAGS = "-g0 -DNDEBUG -mssse3"
CFLAGS = "-g0 -DNDEBUG"

[tool.cibuildwheel]
test-requires = "pytest"
Expand Down
129 changes: 80 additions & 49 deletions src/dnaio/bam.h
Original file line number Diff line number Diff line change
@@ -1,21 +1,33 @@
#include <stdint.h>
#include <stddef.h>
#include <string.h>
#include <assert.h>
// Macros also used in htslib, very useful.
#if defined __GNUC__
#define GCC_AT_LEAST(major, minor) \
(__GNUC__ > (major) || (__GNUC__ == (major) && __GNUC_MINOR__ >= (minor)))
#else
# define GCC_AT_LEAST(major, minor) 0
#endif

#ifdef __SSE2__
#include "emmintrin.h"
#if defined(__clang__) && defined(__has_attribute)
#define CLANG_COMPILER_HAS(attribute) __has_attribute(attribute)
#else
#define CLANG_COMPILER_HAS(attribute) 0
#endif

#ifdef __SSSE3__
#include "tmmintrin.h"
#define COMPILER_HAS_TARGET (GCC_AT_LEAST(4, 8) || CLANG_COMPILER_HAS(__target__))
#define COMPILER_HAS_OPTIMIZE (GCC_AT_LEAST(4,4) || CLANG_COMPILER_HAS(optimize))

#if defined(__x86_64__) || defined(_M_X64)
#define BUILD_IS_X86_64 1
#include "immintrin.h"
#else
#define BUILD_IS_X86_64 0
#endif

static void
decode_bam_sequence(uint8_t *dest, const uint8_t *encoded_sequence, size_t length)
{
/* Reuse a trick from sam_internal.h in htslib. Have a table to lookup
two characters simultaneously.*/
#include <stdint.h>
#include <string.h>
#include <stddef.h>

/* Portable (non-SIMD) decoder for a sequence packed as two 4-bit codes per
   byte.

   dest:             output buffer; exactly `length` bytes are written.
   encoded_sequence: packed input; (length + 1) / 2 bytes are read.
   length:           number of decoded characters to produce.

   Each full input byte is expanded to two characters with a single table
   lookup. When `length` is odd, the final character is taken from the upper
   four bits of the last input byte. No terminator is appended to dest. */
static void
decode_bam_sequence_default(uint8_t *dest, const uint8_t *encoded_sequence, size_t length) {
/* Two-character expansion table indexed by (byte value * 2): 512 bytes hold
   one character pair for each of the 256 possible byte values. */
static const char code2base[512] =
"===A=C=M=G=R=S=V=T=W=Y=H=K=D=B=N"
"A=AAACAMAGARASAVATAWAYAHAKADABAN"
Expand All @@ -34,10 +46,26 @@ decode_bam_sequence(uint8_t *dest, const uint8_t *encoded_sequence, size_t lengt
"B=BABCBMBGBRBSBVBTBWBYBHBKBDBBBN"
"N=NANCNMNGNRNSNVNTNWNYNHNKNDNBNN";
/* Single-character table indexed by one 4-bit code (16 entries). */
static const uint8_t *nuc_lookup = (uint8_t *)"=ACMGRSVTWYHKDBN";
/* Number of input bytes that decode to a complete pair of characters. */
size_t length_2 = length / 2;
for (size_t i=0; i < length_2; i++) {
/* Copy both characters for this byte at once; the size_t cast keeps the
   index arithmetic in size_t. */
memcpy(dest + i*2, code2base + ((size_t)encoded_sequence[i] * 2), 2);
}
if (length & 1) {
/* Odd length: one character remains, stored in the high nibble of the
   byte following the last fully decoded one. The low nibble is ignored. */
uint8_t encoded = encoded_sequence[length_2] >> 4;
dest[(length - 1)] = nuc_lookup[encoded];
}
}

#if COMPILER_HAS_TARGET && BUILD_IS_X86_64
__attribute__((__target__("ssse3")))
static void
decode_bam_sequence_ssse3(uint8_t *dest, const uint8_t *encoded_sequence, size_t length)
{

static const uint8_t *nuc_lookup = (uint8_t *)"=ACMGRSVTWYHKDBN";
const uint8_t *dest_end_ptr = dest + length;
uint8_t *dest_cursor = dest;
const uint8_t *encoded_cursor = encoded_sequence;
#ifdef __SSSE3__
const uint8_t *dest_vec_end_ptr = dest_end_ptr - (2 * sizeof(__m128i));
__m128i first_upper_shuffle = _mm_setr_epi8(
0, 0xff, 1, 0xff, 2, 0xff, 3, 0xff, 4, 0xff, 5, 0xff, 6, 0xff, 7, 0xff);
Expand Down Expand Up @@ -84,44 +112,47 @@ decode_bam_sequence(uint8_t *dest, const uint8_t *encoded_sequence, size_t lengt
encoded_cursor += sizeof(__m128i);
dest_cursor += 2 * sizeof(__m128i);
}
#endif
/* Do two at the time until it gets to the last even address. */
const uint8_t *dest_end_ptr_twoatatime = dest + (length & (~1ULL));
while (dest_cursor < dest_end_ptr_twoatatime) {
/* According to htslib, size_t cast helps the optimizer.
Code confirmed to indeed run faster. */
memcpy(dest_cursor, code2base + ((size_t)*encoded_cursor * 2), 2);
dest_cursor += 2;
encoded_cursor += 1;
decode_bam_sequence_default(dest_cursor, encoded_cursor, dest_end_ptr - dest_cursor);
}

static void (*decode_bam_sequence)(
uint8_t *dest, const uint8_t *encoded_sequence, size_t length);

/* Simple dispatcher function, updates the function pointer after testing the
CPU capabilities. After this, the dispatcher function is not needed anymore. */
static void decode_bam_sequence_dispatch(
uint8_t *dest, const uint8_t *encoded_sequence, size_t length) {
if (__builtin_cpu_supports("ssse3")) {
decode_bam_sequence = decode_bam_sequence_ssse3;
}
assert((dest_end_ptr - dest_cursor) < 2);
if (dest_cursor != dest_end_ptr) {
/* There is a single encoded nuc left */
uint8_t encoded_nucs = *encoded_cursor;
uint8_t upper_nuc_index = encoded_nucs >> 4;
dest_cursor[0] = nuc_lookup[upper_nuc_index];
else {
decode_bam_sequence = decode_bam_sequence_default;
}
decode_bam_sequence(dest, encoded_sequence, length);
}

static void (*decode_bam_sequence)(
uint8_t *dest, const uint8_t *encoded_sequence, size_t length
) = decode_bam_sequence_dispatch;

#else
/* Definition of decode_bam_sequence for builds that take this #else branch
   instead of the function-pointer variant defined in the other branch.
   It forwards all arguments unchanged to decode_bam_sequence_default, so
   callers can use the same name and signature in either build. */
static inline void decode_bam_sequence(uint8_t *dest, const uint8_t *encoded_sequence, size_t length)
{
decode_bam_sequence_default(dest, encoded_sequence, length);
}
#endif

static void
decode_bam_qualities(uint8_t *dest, const uint8_t *encoded_qualities, size_t length)
// Code is simple enough to be auto vectorized.
#if COMPILER_HAS_OPTIMIZE
__attribute__((optimize("O3")))
#endif
static void
decode_bam_qualities(
uint8_t *restrict dest,
const uint8_t *restrict encoded_qualities,
size_t length)
{
const uint8_t *end_ptr = encoded_qualities + length;
const uint8_t *cursor = encoded_qualities;
uint8_t *dest_cursor = dest;
#ifdef __SSE2__
const uint8_t *vec_end_ptr = end_ptr - sizeof(__m128i);
while (cursor < vec_end_ptr) {
__m128i quals = _mm_loadu_si128((__m128i *)cursor);
__m128i phreds = _mm_add_epi8(quals, _mm_set1_epi8(33));
_mm_storeu_si128((__m128i *)dest_cursor, phreds);
cursor += sizeof(__m128i);
dest_cursor += sizeof(__m128i);
}
#endif
while (cursor < end_ptr) {
*dest_cursor = *cursor + 33;
cursor += 1;
dest_cursor += 1;
for (size_t i=0; i<length; i++) {
dest[i] = encoded_qualities[i] + 33;
}
}
}

0 comments on commit 94e09d7

Please sign in to comment.