Remove arch-specific stuff from HNSW extension (#4423)

This commit is contained in:
Sasha Krassovsky
2023-06-05 22:04:15 -08:00
committed by GitHub
parent 8e1b5e1224
commit ac11e7c32d
3 changed files with 70 additions and 57 deletions

View File

@@ -12,7 +12,7 @@ REGRESS_OPTS = --inputdir=test --load-extension=hnsw
# For auto-vectorization:
# - GCC (needs -ftree-vectorize OR -O3) - https://gcc.gnu.org/projects/tree-ssa/vectorization.html
PG_CFLAGS += -O3
PG_CXXFLAGS += -msse4 -mavx2 -O3 -std=c++11
PG_CXXFLAGS += -O3 -std=c++11
PG_LDFLAGS += -lstdc++
all: $(EXTENSION)--$(EXTVERSION).sql

View File

@@ -1,22 +1,11 @@
#include "hnswalg.h"
#if defined(__x86_64__)
#include <x86intrin.h>
#define USE_AVX
#if defined(__GNUC__)
#define PORTABLE_ALIGN32 __attribute__((aligned(32)))
#define PREFETCH(addr,hint) __builtin_prefetch(addr, 0, hint)
#else
#define PORTABLE_ALIGN32 __declspec(align(32))
#endif
#define PREFETCH(addr,hint) _mm_prefetch(addr, hint)
#else
#define PREFETCH(addr,hint)
#endif
HierarchicalNSW::HierarchicalNSW(size_t dim_, size_t maxelements_, size_t M_, size_t maxM_, size_t efConstruction_)
@@ -36,7 +25,9 @@ HierarchicalNSW::HierarchicalNSW(size_t dim_, size_t maxelements_, size_t M_, si
enterpoint_node = 0;
cur_element_count = 0;
dist_calc = 0;
#ifdef __x86_64__
use_avx2 = __builtin_cpu_supports("avx2");
#endif
}
std::priority_queue<std::pair<dist_t, idx_t>> HierarchicalNSW::searchBaseLayer(const coord_t *point, size_t ef)
@@ -66,12 +57,12 @@ std::priority_queue<std::pair<dist_t, idx_t>> HierarchicalNSW::searchBaseLayer(c
idx_t* data = get_linklist0(curNodeNum);
size_t size = *data++;
PREFETCH(getDataByInternalId(*data), _MM_HINT_T0);
PREFETCH(getDataByInternalId(*data), 0);
for (size_t j = 0; j < size; ++j) {
size_t tnum = *(data + j);
PREFETCH(getDataByInternalId(*(data + j + 1)), _MM_HINT_T0);
PREFETCH(getDataByInternalId(*(data + j + 1)), 0);
if (!(visited[tnum >> 5] & (1 << (tnum & 31)))) {
visited[tnum >> 5] |= 1 << (tnum & 31);
@@ -81,7 +72,7 @@ std::priority_queue<std::pair<dist_t, idx_t>> HierarchicalNSW::searchBaseLayer(c
if (topResults.top().first > dist || topResults.size() < ef) {
candidateSet.emplace(-dist, tnum);
PREFETCH(get_linklist0(candidateSet.top().second), _MM_HINT_T0);
PREFETCH(get_linklist0(candidateSet.top().second), 0);
topResults.emplace(dist, tnum);
if (topResults.size() > ef)
@@ -228,37 +219,59 @@ std::priority_queue<std::pair<dist_t, label_t>> HierarchicalNSW::searchKnn(const
return topResults;
};
dist_t HierarchicalNSW::fstdistfunc(const coord_t *x, const coord_t *y)
dist_t fstdistfunc_scalar(const coord_t *x, const coord_t *y, size_t n)
{
#if defined(__x86_64__)
float PORTABLE_ALIGN32 TmpRes[8];
size_t qty16 = dim >> 4;
const float *pEnd1 = x + (qty16 << 4);
#ifdef USE_AVX
__m256 diff, v1, v2;
__m256 sum = _mm256_set1_ps(0);
dist_t distance = 0.0;
while (x < pEnd1) {
v1 = _mm256_loadu_ps(x);
x += 8;
v2 = _mm256_loadu_ps(y);
y += 8;
diff = _mm256_sub_ps(v1, v2);
sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff));
for (size_t i = 0; i < n; i++)
{
dist_t diff = x[i] - y[i];
distance += diff * diff;
}
return distance;
v1 = _mm256_loadu_ps(x);
x += 8;
v2 = _mm256_loadu_ps(y);
y += 8;
diff = _mm256_sub_ps(v1, v2);
sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff));
}
}
_mm256_store_ps(TmpRes, sum);
float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7];
#ifdef __x86_64__
#include <immintrin.h>
__attribute__((target("avx2")))
dist_t fstdistfunc_avx2(const coord_t *x, const coord_t *y, size_t n)
{
const size_t TmpResSz = sizeof(__m256) / sizeof(float);
float PORTABLE_ALIGN32 TmpRes[TmpResSz];
size_t qty16 = n / 16;
const float *pEnd1 = x + (qty16 * 16);
__m256 diff, v1, v2;
__m256 sum = _mm256_set1_ps(0);
while (x < pEnd1) {
v1 = _mm256_loadu_ps(x);
x += 8;
v2 = _mm256_loadu_ps(y);
y += 8;
diff = _mm256_sub_ps(v1, v2);
sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff));
v1 = _mm256_loadu_ps(x);
x += 8;
v2 = _mm256_loadu_ps(y);
y += 8;
diff = _mm256_sub_ps(v1, v2);
sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff));
}
_mm256_store_ps(TmpRes, sum);
float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7];
return (res);
}
dist_t fstdistfunc_sse(const coord_t *x, const coord_t *y, size_t n)
{
const size_t TmpResSz = sizeof(__m128) / sizeof(float);
float PORTABLE_ALIGN32 TmpRes[TmpResSz];
size_t qty16 = n / 16;
const float *pEnd1 = x + (qty16 * 16);
return (res);
#else
__m128 diff, v1, v2;
__m128 sum = _mm_set1_ps(0);
@@ -293,21 +306,19 @@ dist_t HierarchicalNSW::fstdistfunc(const coord_t *x, const coord_t *y)
}
_mm_store_ps(TmpRes, sum);
float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
return (res);
return res;
}
#endif
#else // portable implementation
dist_t distance = 0.0;
size_t n = dim;
dist_calc++;
dist_t HierarchicalNSW::fstdistfunc(const coord_t *x, const coord_t *y)
{
#ifndef __x86_64__
return fstdistfunc_scalar(x, y, dim);
#else
if(use_avx2)
return fstdistfunc_avx2(x, y, dim);
for (size_t i = 0; i < n; i++)
{
dist_t diff = x[i] - y[i];
distance += diff * diff;
}
return distance;
return fstdistfunc_sse(x, y, dim);
#endif
}

View File

@@ -22,8 +22,6 @@ struct HierarchicalNSW
idx_t enterpoint_node;
size_t dist_calc;
size_t dim;
size_t data_size;
size_t offset_data;
@@ -34,6 +32,10 @@ struct HierarchicalNSW
size_t size_links_level0;
size_t efConstruction;
#ifdef __x86_64__
bool use_avx2;
#endif
char data_level0_memory[0]; // varying size
public: