mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-07 05:22:56 +00:00
Remove arch-specific stuff from HNSW extension (#4423)
This commit is contained in:
@@ -12,7 +12,7 @@ REGRESS_OPTS = --inputdir=test --load-extension=hnsw
|
||||
# For auto-vectorization:
|
||||
# - GCC (needs -ftree-vectorize OR -O3) - https://gcc.gnu.org/projects/tree-ssa/vectorization.html
|
||||
PG_CFLAGS += -O3
|
||||
PG_CXXFLAGS += -msse4 -mavx2 -O3 -std=c++11
|
||||
PG_CXXFLAGS += -O3 -std=c++11
|
||||
PG_LDFLAGS += -lstdc++
|
||||
|
||||
all: $(EXTENSION)--$(EXTVERSION).sql
|
||||
|
||||
@@ -1,22 +1,11 @@
|
||||
#include "hnswalg.h"
|
||||
|
||||
|
||||
#if defined(__x86_64__)
|
||||
|
||||
#include <x86intrin.h>
|
||||
#define USE_AVX
|
||||
#if defined(__GNUC__)
|
||||
#define PORTABLE_ALIGN32 __attribute__((aligned(32)))
|
||||
#define PREFETCH(addr,hint) __builtin_prefetch(addr, 0, hint)
|
||||
#else
|
||||
#define PORTABLE_ALIGN32 __declspec(align(32))
|
||||
#endif
|
||||
|
||||
#define PREFETCH(addr,hint) _mm_prefetch(addr, hint)
|
||||
|
||||
#else
|
||||
|
||||
#define PREFETCH(addr,hint)
|
||||
|
||||
#endif
|
||||
|
||||
HierarchicalNSW::HierarchicalNSW(size_t dim_, size_t maxelements_, size_t M_, size_t maxM_, size_t efConstruction_)
|
||||
@@ -36,7 +25,9 @@ HierarchicalNSW::HierarchicalNSW(size_t dim_, size_t maxelements_, size_t M_, si
|
||||
|
||||
enterpoint_node = 0;
|
||||
cur_element_count = 0;
|
||||
dist_calc = 0;
|
||||
#ifdef __x86_64__
|
||||
use_avx2 = __builtin_cpu_supports("avx2");
|
||||
#endif
|
||||
}
|
||||
|
||||
std::priority_queue<std::pair<dist_t, idx_t>> HierarchicalNSW::searchBaseLayer(const coord_t *point, size_t ef)
|
||||
@@ -66,12 +57,12 @@ std::priority_queue<std::pair<dist_t, idx_t>> HierarchicalNSW::searchBaseLayer(c
|
||||
idx_t* data = get_linklist0(curNodeNum);
|
||||
size_t size = *data++;
|
||||
|
||||
PREFETCH(getDataByInternalId(*data), _MM_HINT_T0);
|
||||
PREFETCH(getDataByInternalId(*data), 0);
|
||||
|
||||
for (size_t j = 0; j < size; ++j) {
|
||||
size_t tnum = *(data + j);
|
||||
|
||||
PREFETCH(getDataByInternalId(*(data + j + 1)), _MM_HINT_T0);
|
||||
PREFETCH(getDataByInternalId(*(data + j + 1)), 0);
|
||||
|
||||
if (!(visited[tnum >> 5] & (1 << (tnum & 31)))) {
|
||||
visited[tnum >> 5] |= 1 << (tnum & 31);
|
||||
@@ -81,7 +72,7 @@ std::priority_queue<std::pair<dist_t, idx_t>> HierarchicalNSW::searchBaseLayer(c
|
||||
if (topResults.top().first > dist || topResults.size() < ef) {
|
||||
candidateSet.emplace(-dist, tnum);
|
||||
|
||||
PREFETCH(get_linklist0(candidateSet.top().second), _MM_HINT_T0);
|
||||
PREFETCH(get_linklist0(candidateSet.top().second), 0);
|
||||
topResults.emplace(dist, tnum);
|
||||
|
||||
if (topResults.size() > ef)
|
||||
@@ -228,37 +219,59 @@ std::priority_queue<std::pair<dist_t, label_t>> HierarchicalNSW::searchKnn(const
|
||||
return topResults;
|
||||
};
|
||||
|
||||
dist_t HierarchicalNSW::fstdistfunc(const coord_t *x, const coord_t *y)
|
||||
dist_t fstdistfunc_scalar(const coord_t *x, const coord_t *y, size_t n)
|
||||
{
|
||||
#if defined(__x86_64__)
|
||||
float PORTABLE_ALIGN32 TmpRes[8];
|
||||
size_t qty16 = dim >> 4;
|
||||
const float *pEnd1 = x + (qty16 << 4);
|
||||
#ifdef USE_AVX
|
||||
__m256 diff, v1, v2;
|
||||
__m256 sum = _mm256_set1_ps(0);
|
||||
dist_t distance = 0.0;
|
||||
|
||||
while (x < pEnd1) {
|
||||
v1 = _mm256_loadu_ps(x);
|
||||
x += 8;
|
||||
v2 = _mm256_loadu_ps(y);
|
||||
y += 8;
|
||||
diff = _mm256_sub_ps(v1, v2);
|
||||
sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff));
|
||||
for (size_t i = 0; i < n; i++)
|
||||
{
|
||||
dist_t diff = x[i] - y[i];
|
||||
distance += diff * diff;
|
||||
}
|
||||
return distance;
|
||||
|
||||
v1 = _mm256_loadu_ps(x);
|
||||
x += 8;
|
||||
v2 = _mm256_loadu_ps(y);
|
||||
y += 8;
|
||||
diff = _mm256_sub_ps(v1, v2);
|
||||
sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff));
|
||||
}
|
||||
}
|
||||
|
||||
_mm256_store_ps(TmpRes, sum);
|
||||
float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7];
|
||||
#ifdef __x86_64__
|
||||
#include <immintrin.h>
|
||||
|
||||
__attribute__((target("avx2")))
|
||||
dist_t fstdistfunc_avx2(const coord_t *x, const coord_t *y, size_t n)
|
||||
{
|
||||
const size_t TmpResSz = sizeof(__m256) / sizeof(float);
|
||||
float PORTABLE_ALIGN32 TmpRes[TmpResSz];
|
||||
size_t qty16 = n / 16;
|
||||
const float *pEnd1 = x + (qty16 * 16);
|
||||
__m256 diff, v1, v2;
|
||||
__m256 sum = _mm256_set1_ps(0);
|
||||
|
||||
while (x < pEnd1) {
|
||||
v1 = _mm256_loadu_ps(x);
|
||||
x += 8;
|
||||
v2 = _mm256_loadu_ps(y);
|
||||
y += 8;
|
||||
diff = _mm256_sub_ps(v1, v2);
|
||||
sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff));
|
||||
|
||||
v1 = _mm256_loadu_ps(x);
|
||||
x += 8;
|
||||
v2 = _mm256_loadu_ps(y);
|
||||
y += 8;
|
||||
diff = _mm256_sub_ps(v1, v2);
|
||||
sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff));
|
||||
}
|
||||
_mm256_store_ps(TmpRes, sum);
|
||||
float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7];
|
||||
return (res);
|
||||
}
|
||||
|
||||
dist_t fstdistfunc_sse(const coord_t *x, const coord_t *y, size_t n)
|
||||
{
|
||||
const size_t TmpResSz = sizeof(__m128) / sizeof(float);
|
||||
float PORTABLE_ALIGN32 TmpRes[TmpResSz];
|
||||
size_t qty16 = n / 16;
|
||||
const float *pEnd1 = x + (qty16 * 16);
|
||||
|
||||
return (res);
|
||||
#else
|
||||
__m128 diff, v1, v2;
|
||||
__m128 sum = _mm_set1_ps(0);
|
||||
|
||||
@@ -293,21 +306,19 @@ dist_t HierarchicalNSW::fstdistfunc(const coord_t *x, const coord_t *y)
|
||||
}
|
||||
_mm_store_ps(TmpRes, sum);
|
||||
float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
|
||||
|
||||
return (res);
|
||||
return res;
|
||||
}
|
||||
#endif
|
||||
#else // portable implementation
|
||||
dist_t distance = 0.0;
|
||||
size_t n = dim;
|
||||
|
||||
dist_calc++;
|
||||
dist_t HierarchicalNSW::fstdistfunc(const coord_t *x, const coord_t *y)
|
||||
{
|
||||
#ifndef __x86_64__
|
||||
return fstdistfunc_scalar(x, y, dim);
|
||||
#else
|
||||
if(use_avx2)
|
||||
return fstdistfunc_avx2(x, y, dim);
|
||||
|
||||
for (size_t i = 0; i < n; i++)
|
||||
{
|
||||
dist_t diff = x[i] - y[i];
|
||||
distance += diff * diff;
|
||||
}
|
||||
return distance;
|
||||
return fstdistfunc_sse(x, y, dim);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
@@ -22,8 +22,6 @@ struct HierarchicalNSW
|
||||
|
||||
idx_t enterpoint_node;
|
||||
|
||||
size_t dist_calc;
|
||||
|
||||
size_t dim;
|
||||
size_t data_size;
|
||||
size_t offset_data;
|
||||
@@ -34,6 +32,10 @@ struct HierarchicalNSW
|
||||
size_t size_links_level0;
|
||||
size_t efConstruction;
|
||||
|
||||
#ifdef __x86_64__
|
||||
bool use_avx2;
|
||||
#endif
|
||||
|
||||
char data_level0_memory[0]; // varying size
|
||||
|
||||
public:
|
||||
|
||||
Reference in New Issue
Block a user