Optimized popcnt with AVX2

Jul 11, 2024

How does one count the number of set bits in a register? For a 64-bit register, the answer is the POPCNT instruction (see any intrinsics cheatsheet). POPCNT is limited to 64 bits, however, so there is no straightforward way to count the set bits in a 256-bit YMM register.
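As a baseline, what POPCNT computes can be modeled in a few lines of Python (popcnt64 is a hypothetical helper name for illustration, not an intrinsic):

```python
def popcnt64(x: int) -> int:
    """Count set bits in a 64-bit value, like the POPCNT instruction."""
    x &= (1 << 64) - 1  # keep only the low 64 bits
    return bin(x).count("1")

print(popcnt64(0xFF))           # 8
print(popcnt64((1 << 64) - 1))  # 64
```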

A quick glance through the Intel manual turns up a family of _mm256_popcnt_epi* intrinsics. However, these ship as part of AVX-512 (the CPU feature flags AVX512_BITALG and AVX512VPOPCNTDQ), and my CPU is too old to support them.

Literature

Obviously someone has thought of this exact problem before! Faster Population Counts using AVX2 instructions describes a set of approaches and their performance characteristics. For 256-bit vectors, the Mula function is measured to be the fastest. In my macro-benchmarks, manually invoking _popcnt64 on the four 64-bit components of a 256-bit vector, extracted using _mm256_extract_epi64, is a little slower than Mula. To give a sense of how close they are in performance, the former takes 0.38 ins/op and the latter 0.56 ins/op, which translates to a 6% drop in ops/second in this macro-benchmark.

AVX2 Mula

Having picked the AVX2 version of the Mula function, let's look into how it works. The following implementation is reproduced from the paper.

__m256i count(__m256i v) {
    __m256i lookup = _mm256_setr_epi8(0, 1, 1, 2, 1, 2, 2, 3,
                                      1, 2, 2, 3, 2, 3, 3, 4,
                                      0, 1, 1, 2, 1, 2, 2, 3,
                                      1, 2, 2, 3, 2, 3, 3, 4);
    __m256i low_mask = _mm256_set1_epi8(0x0f);
    __m256i lo = _mm256_and_si256(v, low_mask);
    __m256i hi = _mm256_and_si256(_mm256_srli_epi32(v, 4), low_mask);
    __m256i popcnt1 = _mm256_shuffle_epi8(lookup, lo);
    __m256i popcnt2 = _mm256_shuffle_epi8(lookup, hi);
    __m256i total = _mm256_add_epi8(popcnt1, popcnt2);

    return _mm256_sad_epu8(total, _mm256_setzero_si256());
}

How does this actually work? Let's translate it into a slightly more readable form.

Let's assign each byte an index.

v = [B0, B1, ..., B31]

We then create a mask to split each byte into a pair of four-bit halves.

low_mask = [0x0f, 0x0f, ..., 0x0f]

Here, lo holds 32 4-bit values, each the lower half of a byte, and hi holds the corresponding upper halves, shifted down into the low four bits.

lo = [B0 & 0x0f, B1 & 0x0f, ..., B31 & 0x0f]
hi = [(B0 >> 4) & 0x0f, (B1 >> 4) & 0x0f, ...]
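To make the split concrete, here is a small Python sketch of the masking steps, with plain ints standing in for the 32 byte lanes of a YMM register (split_nibbles is a made-up name for illustration):

```python
def split_nibbles(byte_lanes):
    """Split each byte into its low and high nibble,
    mirroring the two AND/shift steps in the intrinsics code."""
    lo = [b & 0x0F for b in byte_lanes]
    hi = [(b >> 4) & 0x0F for b in byte_lanes]
    return lo, hi

lo, hi = split_nibbles([0xAB, 0x3C])
print(lo)  # [11, 12]  (0xB, 0xC)
print(hi)  # [10, 3]   (0xA, 0x3)
```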

The use of _mm256_srli_epi32 here is a little suspicious, as I expected a version of _mm256_srli_epi8 instead to extract the upper half of each byte. That intrinsic, however, does not exist in the manual. The 32-bit shift works anyway because the bits it drags across byte boundaries are exactly the ones cleared by the subsequent AND with low_mask.

Let's prove that _mm256_srli_epi32 also works, using our trusty z3. z3's bitvector support is great for our prototyping needs; we use its Python API here. Note that >> is an arithmetic shift, which fills with the sign bit, so we use the logical shift LShR to fill the left side with zeroes. [1]

from z3 import *
a = BitVec('a', 32)

# The approach used above
l = LShR(a, 4) & 0x0f0f0f0f

# Intuitive way to select the upper four bits of a byte
r = LShR(a & 0xf0, 4) | LShR(a & 0xf000, 4) | LShR(a & 0xf00000, 4) | LShR(a & 0xf0000000, 4)

s = Solver()
s.add(l != r)
print(s.check())  # unsat

That's enough proof for me.

Magic numbers

__m256i lookup = _mm256_setr_epi8(0, 1, 1, 2, 1, 2, 2, 3,
                                  1, 2, 2, 3, 2, 3, 3, 4,
                                  0, 1, 1, 2, 1, 2, 2, 3,
                                  1, 2, 2, 3, 2, 3, 3, 4);

Having split our input into upper and lower nibbles [2], we call _mm256_shuffle_epi8 with the magic number sequence as the table. This intrinsic shuffles the bytes of its first argument according to the indices in its second argument, working independently within each 128-bit lane, which is why the 16-entry table is duplicated across both halves of lookup. An interesting observation is that the index sequence may repeat a specific 8-bit chunk more than once. In other words, if we need a vector where each element is the Nth element of another vector, we can use _mm256_shuffle_epi8 with N, N, ... as indices, although this is likely not the fastest way to do it.
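A rough Python model of the per-lane shuffle semantics shows the broadcast trick; shuffle_epi8_lane is a hypothetical helper, and it ignores the zeroing behavior that real PSHUFB applies when an index byte has its top bit set:

```python
def shuffle_epi8_lane(table, idx):
    """Model of a byte shuffle within one 16-byte lane:
    result[i] = table[idx[i] & 0x0F]."""
    return [table[i & 0x0F] for i in idx]

table = list(range(100, 116))
# Repeating index 5 broadcasts lane element 5 to every position.
print(shuffle_epi8_lane(table, [5] * 16))  # [105, 105, ..., 105]
```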

__m256i popcnt1 = _mm256_shuffle_epi8(lookup, lo);

So why does this work at all? The number sequence in lookup is the number of set bits in each of the integers [0..15]. For every nibble, we use its value as an index into this table: the _mm256_shuffle_epi8 above selects, for each nibble, the lookup entry holding its popcount.

For instance, a hypothetical shuffle_epi8(lookup, [0b1110, 0b1100, 0b1000]) would return [lookup[0b1110], lookup[0b1100], lookup[0b1000]], which would then be [3, 2, 1]. Cool trick, right?
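We can sanity-check the table against a direct bit count in Python:

```python
# The magic sequence is the popcount of each integer in [0..15].
lookup = [0, 1, 1, 2, 1, 2, 2, 3,
          1, 2, 2, 3, 2, 3, 3, 4]
assert lookup == [bin(n).count("1") for n in range(16)]
# e.g. lookup[0b1110] == 3, lookup[0b1100] == 2, lookup[0b1000] == 1
```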

I tried deriving a linear equation out of lo so that I could use lo itself as the first argument to shuffle_epi8, but ended up with far more arithmetic operations.

Sum

Now that we have the popcounts of the upper and lower nibbles in two separate vectors, we add them byte-wise and invoke _mm256_sad_epu8 with a zeroed second argument, which horizontally sums each group of 8 bytes into a 64-bit lane, yielding four partial sums of the popcount.
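Putting the pieces together, here is a plain-Python sketch of the whole routine, with byte strings standing in for vectors (mula_popcount is a made-up name); it models the algorithm, not the intrinsics themselves, and can be checked against a direct bit count:

```python
def mula_popcount(data: bytes) -> int:
    """Model of the Mula popcount: per-nibble table lookups,
    a byte-wise add, then a final sum."""
    lookup = [bin(n).count("1") for n in range(16)]
    per_byte = [lookup[b & 0x0F] + lookup[b >> 4] for b in data]
    # _mm256_sad_epu8 would sum each group of 8 bytes into a 64-bit
    # lane; here we simply sum all 32 byte counts.
    return sum(per_byte)

data = bytes(range(32))  # one YMM register's worth of input
assert mula_popcount(data) == sum(bin(b).count("1") for b in data)
```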

More generally, we now know how to do vectorized lookups of small tables.


  1. I don't know why the defaults were chosen this way. Seems counter-intuitive. ↩︎

  2. A group of four bits is a nibble! ↩︎