diff --git a/Dockerfile.compute-node b/Dockerfile.compute-node index de8a904c02..8446ef9fa0 100644 --- a/Dockerfile.compute-node +++ b/Dockerfile.compute-node @@ -556,6 +556,10 @@ RUN make -j $(getconf _NPROCESSORS_ONLN) \ make -j $(getconf _NPROCESSORS_ONLN) \ PG_CONFIG=/usr/local/pgsql/bin/pg_config \ -C pgxn/neon_utils \ + -s install && \ + make -j $(getconf _NPROCESSORS_ONLN) \ + PG_CONFIG=/usr/local/pgsql/bin/pg_config \ + -C pgxn/hnsw \ -s install ######################################################################################### diff --git a/Makefile b/Makefile index 9d78c5d0fc..ae979b8b4c 100644 --- a/Makefile +++ b/Makefile @@ -138,6 +138,11 @@ neon-pg-ext-%: postgres-% $(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config CFLAGS='$(PG_CFLAGS) $(COPT)' \ -C $(POSTGRES_INSTALL_DIR)/build/neon-utils-$* \ -f $(ROOT_PROJECT_DIR)/pgxn/neon_utils/Makefile install + +@echo "Compiling hnsw $*" + mkdir -p $(POSTGRES_INSTALL_DIR)/build/hnsw-$* + $(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config CFLAGS='$(PG_CFLAGS) $(COPT)' \ + -C $(POSTGRES_INSTALL_DIR)/build/hnsw-$* \ + -f $(ROOT_PROJECT_DIR)/pgxn/hnsw/Makefile install .PHONY: neon-pg-ext-clean-% neon-pg-ext-clean-%: @@ -153,6 +158,9 @@ neon-pg-ext-clean-%: $(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config \ -C $(POSTGRES_INSTALL_DIR)/build/neon-utils-$* \ -f $(ROOT_PROJECT_DIR)/pgxn/neon_utils/Makefile clean + $(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config \ + -C $(POSTGRES_INSTALL_DIR)/build/hnsw-$* \ + -f $(ROOT_PROJECT_DIR)/pgxn/hnsw/Makefile clean .PHONY: neon-pg-ext neon-pg-ext: \ diff --git a/pgxn/hnsw/Makefile b/pgxn/hnsw/Makefile new file mode 100644 index 0000000000..9bdd87430c --- /dev/null +++ b/pgxn/hnsw/Makefile @@ -0,0 +1,26 @@ +EXTENSION = hnsw +EXTVERSION = 0.1.0 + +MODULE_big = hnsw +DATA = $(wildcard *--*.sql) +OBJS = hnsw.o hnswalg.o + +TESTS = $(wildcard test/sql/*.sql) +REGRESS = $(patsubst test/sql/%.sql,%,$(TESTS)) +REGRESS_OPTS = --inputdir=test 
--load-extension=hnsw + +# For auto-vectorization: +# - GCC (needs -ftree-vectorize OR -O3) - https://gcc.gnu.org/projects/tree-ssa/vectorization.html +PG_CFLAGS += -O3 +PG_CPPFLAGS += -msse4.1 -O3 -march=native -ftree-vectorize -ftree-vectorizer-verbose=0 +PG_LDFLAGS += -lstdc++ + +all: $(EXTENSION)--$(EXTVERSION).sql + +PG_CONFIG ?= pg_config +PGXS := $(shell $(PG_CONFIG) --pgxs) +include $(PGXS) + +dist: + mkdir -p dist + git archive --format zip --prefix=$(EXTENSION)-$(EXTVERSION)/ --output dist/$(EXTENSION)-$(EXTVERSION).zip master diff --git a/pgxn/hnsw/README.md b/pgxn/hnsw/README.md new file mode 100644 index 0000000000..bc9c8d571c --- /dev/null +++ b/pgxn/hnsw/README.md @@ -0,0 +1,25 @@ +# Revisiting the Inverted Indices for Billion-Scale Approximate Nearest Neighbors + +This ANN extension of Postgres is based +on [ivf-hnsw](https://github.com/dbaranchuk/ivf-hnsw.git) implementation of [HNSW](https://www.pinecone.io/learn/hnsw), +the code for the current state-of-the-art billion-scale nearest neighbor search system presented in the paper: + +[Revisiting the Inverted Indices for Billion-Scale Approximate Nearest Neighbors](http://openaccess.thecvf.com/content_ECCV_2018/html/Dmitry_Baranchuk_Revisiting_the_Inverted_ECCV_2018_paper.html), +
+Dmitry Baranchuk, Artem Babenko, Yury Malkov
+
+# Postgres extension
+
+The HNSW index is held in memory (built on demand) and its maximal size is limited
+by the `maxelements` index parameter. Another required parameter is the number of dimensions, `dims` (if it is not specified in the column type).
+The optional parameters `efconstruction` and `efsearch` specify the number of neighbors that are considered during index construction and search, respectively (they correspond to the `efConstruction` and `efSearch` parameters
+described in the article).
+
+# Example of usage:
+
+```
+create extension hnsw;
+create table embeddings(id integer primary key, payload real[]);
+create index on embeddings using hnsw(payload) with (maxelements=1000000, dims=100, m=32);
+select id from embeddings order by payload <-> array[1.0, 2.0,...] limit 100;
+```
\ No newline at end of file
diff --git a/pgxn/hnsw/hnsw--0.1.0.sql b/pgxn/hnsw/hnsw--0.1.0.sql
new file mode 100644
index 0000000000..ebf424326d
--- /dev/null
+++ b/pgxn/hnsw/hnsw--0.1.0.sql
@@ -0,0 +1,29 @@
+-- complain if script is sourced in psql, rather than via CREATE EXTENSION
+\echo Use "CREATE EXTENSION hnsw" to load this file. 
\quit + +-- functions + +CREATE FUNCTION l2_distance(real[], real[]) RETURNS real + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +-- operators + +CREATE OPERATOR <-> ( + LEFTARG = real[], RIGHTARG = real[], PROCEDURE = l2_distance, + COMMUTATOR = '<->' +); + +-- access method + +CREATE FUNCTION hnsw_handler(internal) RETURNS index_am_handler + AS 'MODULE_PATHNAME' LANGUAGE C; + +CREATE ACCESS METHOD hnsw TYPE INDEX HANDLER hnsw_handler; + +COMMENT ON ACCESS METHOD hnsw IS 'hnsw index access method'; + +-- opclasses + +CREATE OPERATOR CLASS knn_ops + DEFAULT FOR TYPE real[] USING hnsw AS + OPERATOR 1 <-> (real[], real[]) FOR ORDER BY float_ops; diff --git a/pgxn/hnsw/hnsw.c b/pgxn/hnsw/hnsw.c new file mode 100644 index 0000000000..434f4986f8 --- /dev/null +++ b/pgxn/hnsw/hnsw.c @@ -0,0 +1,551 @@ +#include "postgres.h" + +#include "access/amapi.h" +#include "access/generic_xlog.h" +#include "access/relation.h" +#include "access/reloptions.h" +#include "access/tableam.h" +#include "catalog/index.h" +#include "commands/vacuum.h" +#include "nodes/execnodes.h" +#include "storage/bufmgr.h" +#include "utils/guc.h" +#include "utils/selfuncs.h" + +#include +#include + +#include "hnsw.h" + +PG_MODULE_MAGIC; + +typedef struct { + int32 vl_len_; /* varlena header (do not touch directly!) 
*/ + int dims; + int maxelements; + int efConstruction; + int efSearch; + int M; +} HnswOptions; + +static relopt_kind hnsw_relopt_kind; + +typedef struct { + HierarchicalNSW* hnsw; + size_t curr; + size_t n_results; + ItemPointer results; +} HnswScanOpaqueData; + +typedef HnswScanOpaqueData* HnswScanOpaque; + +typedef struct { + Oid relid; + uint32 status; + HierarchicalNSW* hnsw; +} HnswHashEntry; + + +#define SH_PREFIX hnsw_index +#define SH_ELEMENT_TYPE HnswHashEntry +#define SH_KEY_TYPE Oid +#define SH_KEY relid +#define SH_STORE_HASH +#define SH_GET_HASH(tb, a) ((a)->relid) +#define SH_HASH_KEY(tb, key) (key) +#define SH_EQUAL(tb, a, b) ((a) == (b)) +#define SH_SCOPE static inline +#define SH_DEFINE +#define SH_DECLARE +#include "lib/simplehash.h" + +#define INDEX_HASH_SIZE 11 + +#define DEFAULT_EF_SEARCH 64 + +PGDLLEXPORT void _PG_init(void); + +static hnsw_index_hash *hnsw_indexes; + +/* + * Initialize index options and variables + */ +void +_PG_init(void) +{ + hnsw_relopt_kind = add_reloption_kind(); + add_int_reloption(hnsw_relopt_kind, "dims", "Number of dimensions", + 0, 0, INT_MAX, AccessExclusiveLock); + add_int_reloption(hnsw_relopt_kind, "maxelements", "Maximal number of elements", + 0, 0, INT_MAX, AccessExclusiveLock); + add_int_reloption(hnsw_relopt_kind, "m", "Number of neighbors of each vertex", + 100, 0, INT_MAX, AccessExclusiveLock); + add_int_reloption(hnsw_relopt_kind, "efconstruction", "Number of inspected neighbors during index construction", + 16, 1, INT_MAX, AccessExclusiveLock); + add_int_reloption(hnsw_relopt_kind, "efsearch", "Number of inspected neighbors during index search", + 64, 1, INT_MAX, AccessExclusiveLock); + hnsw_indexes = hnsw_index_create(TopMemoryContext, INDEX_HASH_SIZE, NULL); +} + + +static void +hnsw_build_callback(Relation index, ItemPointer tid, Datum *values, + bool *isnull, bool tupleIsAlive, void *state) +{ + HierarchicalNSW* hnsw = (HierarchicalNSW*) state; + ArrayType* array; + int n_items; + label_t label = 
0; + + /* Skip nulls */ + if (isnull[0]) + return; + + array = DatumGetArrayTypeP(values[0]); + n_items = ArrayGetNItems(ARR_NDIM(array), ARR_DIMS(array)); + if (n_items != hnsw_dimensions(hnsw)) + { + elog(ERROR, "Wrong number of dimensions: %d instead of %d expected", + n_items, hnsw_dimensions(hnsw)); + } + + memcpy(&label, tid, sizeof(*tid)); + hnsw_add_point(hnsw, (coord_t*)ARR_DATA_PTR(array), label); +} + +static void +hnsw_populate(HierarchicalNSW* hnsw, Relation indexRel, Relation heapRel) +{ + IndexInfo* indexInfo = BuildIndexInfo(indexRel); + Assert(indexInfo->ii_NumIndexAttrs == 1); + table_index_build_scan(heapRel, indexRel, indexInfo, + true, true, hnsw_build_callback, (void *) hnsw, NULL); +} + +static HierarchicalNSW* +hnsw_get_index(Relation indexRel, Relation heapRel) +{ + HierarchicalNSW* hnsw; + Oid indexoid = RelationGetRelid(indexRel); + HnswHashEntry* entry = hnsw_index_lookup(hnsw_indexes, indexoid); + if (entry == NULL) + { + size_t dims, maxelements; + size_t M; + size_t maxM; + size_t size_links_level0; + size_t size_data_per_element; + size_t data_size; + dsm_handle handle = indexoid << 1; /* make it even */ + void* impl_private = NULL; + void* mapped_address = NULL; + Size mapped_size = 0; + Size shmem_size; + bool exists = true; + bool found; + HnswOptions *opts = (HnswOptions *) indexRel->rd_options; + if (opts == NULL || opts->maxelements == 0 || opts->dims == 0) { + elog(ERROR, "HNSW index requires 'maxelements' and 'dims' to be specified"); + } + dims = opts->dims; + maxelements = opts->maxelements; + M = opts->M; + maxM = M * 2; + data_size = dims * sizeof(coord_t); + size_links_level0 = (maxM + 1) * sizeof(idx_t); + size_data_per_element = size_links_level0 + data_size + sizeof(label_t); + shmem_size = hnsw_sizeof() + maxelements * size_data_per_element; + + /* first try to attach to existed index */ + if (!dsm_impl_op(DSM_OP_ATTACH, handle, 0, &impl_private, + &mapped_address, &mapped_size, DEBUG1)) + { + /* index doesn't 
exists: try to create it */ + if (!dsm_impl_op(DSM_OP_CREATE, handle, shmem_size, &impl_private, + &mapped_address, &mapped_size, DEBUG1)) + { + /* We can do it under shared lock, so some other backend may + * try to initialize index. If create is failed because index already + * created by somebody else, then try to attach to it once again + */ + if (!dsm_impl_op(DSM_OP_ATTACH, handle, 0, &impl_private, + &mapped_address, &mapped_size, ERROR)) + { + return NULL; + } + } + else + { + exists = false; + } + } + Assert(mapped_size == shmem_size); + hnsw = (HierarchicalNSW*)mapped_address; + + if (!exists) + { + hnsw_init(hnsw, dims, maxelements, M, maxM, opts->efConstruction); + hnsw_populate(hnsw, indexRel, heapRel); + } + entry = hnsw_index_insert(hnsw_indexes, indexoid, &found); + Assert(!found); + entry->hnsw = hnsw; + } + else + { + hnsw = entry->hnsw; + } + return hnsw; +} + +/* + * Start or restart an index scan + */ +static IndexScanDesc +hnsw_beginscan(Relation index, int nkeys, int norderbys) +{ + IndexScanDesc scan = RelationGetIndexScan(index, nkeys, norderbys); + HnswScanOpaque so = (HnswScanOpaque) palloc(sizeof(HnswScanOpaqueData)); + Relation heap = relation_open(index->rd_index->indrelid, NoLock); + so->hnsw = hnsw_get_index(index, heap); + relation_close(heap, NoLock); + so->curr = 0; + so->n_results = 0; + so->results = NULL; + scan->opaque = so; + return scan; +} + +/* + * Start or restart an index scan + */ +static void +hnsw_rescan(IndexScanDesc scan, ScanKey keys, int nkeys, ScanKey orderbys, int norderbys) +{ + HnswScanOpaque so = (HnswScanOpaque) scan->opaque; + if (so->results) + { + pfree(so->results); + so->results = NULL; + } + so->curr = 0; + if (orderbys && scan->numberOfOrderBys > 0) + memmove(scan->orderByData, orderbys, scan->numberOfOrderBys * sizeof(ScanKeyData)); +} + +/* + * Fetch the next tuple in the given scan + */ +static bool +hnsw_gettuple(IndexScanDesc scan, ScanDirection dir) +{ + HnswScanOpaque so = (HnswScanOpaque) 
scan->opaque; + + /* + * Index can be used to scan backward, but Postgres doesn't support + * backward scan on operators + */ + Assert(ScanDirectionIsForward(dir)); + + if (so->curr == 0) + { + Datum value; + ArrayType* array; + int n_items; + size_t n_results; + label_t* results; + HnswOptions *opts = (HnswOptions *) scan->indexRelation->rd_options; + size_t efSearch = opts ? opts->efSearch : DEFAULT_EF_SEARCH; + + /* Safety check */ + if (scan->orderByData == NULL) + elog(ERROR, "cannot scan HNSW index without order"); + + /* No items will match if null */ + if (scan->orderByData->sk_flags & SK_ISNULL) + return false; + + value = scan->orderByData->sk_argument; + array = DatumGetArrayTypeP(value); + n_items = ArrayGetNItems(ARR_NDIM(array), ARR_DIMS(array)); + if (n_items != hnsw_dimensions(so->hnsw)) + { + elog(ERROR, "Wrong number of dimensions: %d instead of %d expected", + n_items, hnsw_dimensions(so->hnsw)); + } + + if (!hnsw_search(so->hnsw, (coord_t*)ARR_DATA_PTR(array), efSearch, &n_results, &results)) + elog(ERROR, "HNSW index search failed"); + so->results = (ItemPointer)palloc(n_results*sizeof(ItemPointerData)); + so->n_results = n_results; + for (size_t i = 0; i < n_results; i++) + { + memcpy(&so->results[i], &results[i], sizeof(so->results[i])); + } + free(results); + } + if (so->curr >= so->n_results) + { + return false; + } + else + { + scan->xs_heaptid = so->results[so->curr++]; + scan->xs_recheckorderby = false; + return true; + } +} + +/* + * End a scan and release resources + */ +static void +hnsw_endscan(IndexScanDesc scan) +{ + HnswScanOpaque so = (HnswScanOpaque) scan->opaque; + if (so->results) + pfree(so->results); + pfree(so); + scan->opaque = NULL; +} + + +/* + * Estimate the cost of an index scan + */ +static void +hnsw_costestimate(PlannerInfo *root, IndexPath *path, double loop_count, + Cost *indexStartupCost, Cost *indexTotalCost, + Selectivity *indexSelectivity, double *indexCorrelation + ,double *indexPages +) +{ + GenericCosts 
costs; + + /* Never use index without order */ + if (path->indexorderbys == NULL) + { + *indexStartupCost = DBL_MAX; + *indexTotalCost = DBL_MAX; + *indexSelectivity = 0; + *indexCorrelation = 0; + *indexPages = 0; + return; + } + + MemSet(&costs, 0, sizeof(costs)); + + genericcostestimate(root, path, loop_count, &costs); + + /* Startup cost and total cost are same */ + *indexStartupCost = costs.indexTotalCost; + *indexTotalCost = costs.indexTotalCost; + *indexSelectivity = costs.indexSelectivity; + *indexCorrelation = costs.indexCorrelation; + *indexPages = costs.numIndexPages; +} + +/* + * Parse and validate the reloptions + */ +static bytea * +hnsw_options(Datum reloptions, bool validate) +{ + static const relopt_parse_elt tab[] = { + {"dims", RELOPT_TYPE_INT, offsetof(HnswOptions, dims)}, + {"maxelements", RELOPT_TYPE_INT, offsetof(HnswOptions, maxelements)}, + {"efconstruction", RELOPT_TYPE_INT, offsetof(HnswOptions, efConstruction)}, + {"efsearch", RELOPT_TYPE_INT, offsetof(HnswOptions, efSearch)}, + {"m", RELOPT_TYPE_INT, offsetof(HnswOptions, M)} + }; + + return (bytea *) build_reloptions(reloptions, validate, + hnsw_relopt_kind, + sizeof(HnswOptions), + tab, lengthof(tab)); +} + +/* + * Validate catalog entries for the specified operator class + */ +static bool +hnsw_validate(Oid opclassoid) +{ + return true; +} + +/* + * Build the index for a logged table + */ +static IndexBuildResult * +hnsw_build(Relation heap, Relation index, IndexInfo *indexInfo) +{ + HierarchicalNSW* hnsw = hnsw_get_index(index, heap); + IndexBuildResult* result = (IndexBuildResult *) palloc(sizeof(IndexBuildResult)); + result->heap_tuples = result->index_tuples = hnsw_count(hnsw); + + return result; +} + +/* + * Insert a tuple into the index + */ +static bool +hnsw_insert(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid, + Relation heap, IndexUniqueCheck checkUnique, + bool indexUnchanged, + IndexInfo *indexInfo) +{ + HierarchicalNSW* hnsw = hnsw_get_index(index, 
heap); + Datum value; + ArrayType* array; + int n_items; + label_t label = 0; + + /* Skip nulls */ + if (isnull[0]) + return false; + + /* Detoast value */ + value = PointerGetDatum(PG_DETOAST_DATUM(values[0])); + array = DatumGetArrayTypeP(value); + n_items = ArrayGetNItems(ARR_NDIM(array), ARR_DIMS(array)); + if (n_items != hnsw_dimensions(hnsw)) + { + elog(ERROR, "Wrong number of dimensions: %d instead of %d expected", + n_items, hnsw_dimensions(hnsw)); + } + memcpy(&label, heap_tid, sizeof(*heap_tid)); + if (!hnsw_add_point(hnsw, (coord_t*)ARR_DATA_PTR(array), label)) + elog(ERROR, "HNSW index insert failed"); + return true; +} + +/* + * Build the index for an unlogged table + */ +static void +hnsw_buildempty(Relation index) +{ + /* index will be constructed on dema nd when accessed */ +} + +/* + * Clean up after a VACUUM operation + */ +static IndexBulkDeleteResult * +hnsw_vacuumcleanup(IndexVacuumInfo *info, IndexBulkDeleteResult *stats) +{ + Relation rel = info->index; + + if (stats == NULL) + return NULL; + + stats->num_pages = RelationGetNumberOfBlocks(rel); + + return stats; +} + +/* + * Bulk delete tuples from the index + */ +static IndexBulkDeleteResult * +hnsw_bulkdelete(IndexVacuumInfo *info, IndexBulkDeleteResult *stats, + IndexBulkDeleteCallback callback, void *callback_state) +{ + if (stats == NULL) + stats = (IndexBulkDeleteResult *) palloc0(sizeof(IndexBulkDeleteResult)); + return stats; +} + +/* + * Define index handler + * + * See https://www.postgresql.org/docs/current/index-api.html + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(hnsw_handler); +Datum +hnsw_handler(PG_FUNCTION_ARGS) +{ + IndexAmRoutine *amroutine = makeNode(IndexAmRoutine); + + amroutine->amstrategies = 0; + amroutine->amsupport = 0; + amroutine->amoptsprocnum = 0; + amroutine->amcanorder = false; + amroutine->amcanorderbyop = true; + amroutine->amcanbackward = false; /* can change direction mid-scan */ + amroutine->amcanunique = false; + amroutine->amcanmulticol = false; + 
amroutine->amoptionalkey = true; + amroutine->amsearcharray = false; + amroutine->amsearchnulls = false; + amroutine->amstorage = false; + amroutine->amclusterable = false; + amroutine->ampredlocks = false; + amroutine->amcanparallel = false; + amroutine->amcaninclude = false; + amroutine->amusemaintenanceworkmem = false; /* not used during VACUUM */ + amroutine->amparallelvacuumoptions = VACUUM_OPTION_PARALLEL_BULKDEL; + amroutine->amkeytype = InvalidOid; + + /* Interface functions */ + amroutine->ambuild = hnsw_build; + amroutine->ambuildempty = hnsw_buildempty; + amroutine->aminsert = hnsw_insert; + amroutine->ambulkdelete = hnsw_bulkdelete; + amroutine->amvacuumcleanup = hnsw_vacuumcleanup; + amroutine->amcanreturn = NULL; /* tuple not included in heapsort */ + amroutine->amcostestimate = hnsw_costestimate; + amroutine->amoptions = hnsw_options; + amroutine->amproperty = NULL; /* TODO AMPROP_DISTANCE_ORDERABLE */ + amroutine->ambuildphasename = NULL; + amroutine->amvalidate = hnsw_validate; + amroutine->amadjustmembers = NULL; + amroutine->ambeginscan = hnsw_beginscan; + amroutine->amrescan = hnsw_rescan; + amroutine->amgettuple = hnsw_gettuple; + amroutine->amgetbitmap = NULL; + amroutine->amendscan = hnsw_endscan; + amroutine->ammarkpos = NULL; + amroutine->amrestrpos = NULL; + + /* Interface functions to support parallel index scans */ + amroutine->amestimateparallelscan = NULL; + amroutine->aminitparallelscan = NULL; + amroutine->amparallelrescan = NULL; + + PG_RETURN_POINTER(amroutine); +} + +/* + * Get the L2 distance between vectors + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(l2_distance); +Datum +l2_distance(PG_FUNCTION_ARGS) +{ + ArrayType *a = PG_GETARG_ARRAYTYPE_P(0); + ArrayType *b = PG_GETARG_ARRAYTYPE_P(1); + int a_dim = ArrayGetNItems(ARR_NDIM(a), ARR_DIMS(a)); + int b_dim = ArrayGetNItems(ARR_NDIM(b), ARR_DIMS(b)); + dist_t distance = 0.0; + dist_t diff; + coord_t *ax = (coord_t*)ARR_DATA_PTR(a); + coord_t *bx = (coord_t*)ARR_DATA_PTR(b); + + if 
(a_dim != b_dim) + { + ereport(ERROR, + (errcode(ERRCODE_DATA_EXCEPTION), + errmsg("different array dimensions %d and %d", a_dim, b_dim))); + } + + for (int i = 0; i < a_dim; i++) + { + diff = ax[i] - bx[i]; + distance += diff * diff; + } + + PG_RETURN_FLOAT4((dist_t)sqrt(distance)); +} diff --git a/pgxn/hnsw/hnsw.control b/pgxn/hnsw/hnsw.control new file mode 100644 index 0000000000..b292b96026 --- /dev/null +++ b/pgxn/hnsw/hnsw.control @@ -0,0 +1,5 @@ +comment = 'hNsw index' +default_version = '0.1.0' +module_pathname = '$libdir/hnsw' +relocatable = true +trusted = true diff --git a/pgxn/hnsw/hnsw.h b/pgxn/hnsw/hnsw.h new file mode 100644 index 0000000000..d4065ab8fe --- /dev/null +++ b/pgxn/hnsw/hnsw.h @@ -0,0 +1,15 @@ +#pragma once + +typedef float coord_t; +typedef float dist_t; +typedef uint32_t idx_t; +typedef uint64_t label_t; + +typedef struct HierarchicalNSW HierarchicalNSW; + +bool hnsw_search(HierarchicalNSW* hnsw, const coord_t *point, size_t efSearch, size_t* n_results, label_t** results); +bool hnsw_add_point(HierarchicalNSW* hnsw, const coord_t *point, label_t label); +void hnsw_init(HierarchicalNSW* hnsw, size_t dim, size_t maxelements, size_t M, size_t maxM, size_t efConstruction); +int hnsw_dimensions(HierarchicalNSW* hnsw); +size_t hnsw_count(HierarchicalNSW* hnsw); +size_t hnsw_sizeof(void); diff --git a/pgxn/hnsw/hnswalg.cpp b/pgxn/hnsw/hnswalg.cpp new file mode 100644 index 0000000000..226dcbd53d --- /dev/null +++ b/pgxn/hnsw/hnswalg.cpp @@ -0,0 +1,368 @@ +#include "hnswalg.h" + + +#if defined(__x86_64__) + +#include +#define USE_AVX +#if defined(__GNUC__) +#define PORTABLE_ALIGN32 __attribute__((aligned(32))) +#else +#define PORTABLE_ALIGN32 __declspec(align(32)) +#endif + +#define PREFETCH(addr,hint) _mm_prefetch(addr, hint) + +#else + +#define PREFETCH(addr,hint) + +#endif + +HierarchicalNSW::HierarchicalNSW(size_t dim_, size_t maxelements_, size_t M_, size_t maxM_, size_t efConstruction_) +{ + dim = dim_; + data_size = dim * 
sizeof(coord_t); + + efConstruction = efConstruction_; + + maxelements = maxelements_; + M = M_; + maxM = maxM_; + size_links_level0 = (maxM + 1) * sizeof(idx_t); + size_data_per_element = size_links_level0 + data_size + sizeof(label_t); + offset_data = size_links_level0; + offset_label = offset_data + data_size; + + enterpoint_node = 0; + cur_element_count = 0; + dist_calc = 0; +} + +std::priority_queue> HierarchicalNSW::searchBaseLayer(const coord_t *point, size_t ef) +{ + std::vector visited; + visited.resize((cur_element_count + 31) >> 5); + + std::priority_queue> topResults; + std::priority_queue> candidateSet; + + dist_t dist = fstdistfunc(point, getDataByInternalId(enterpoint_node)); + + topResults.emplace(dist, enterpoint_node); + candidateSet.emplace(-dist, enterpoint_node); + visited[enterpoint_node >> 5] = 1 << (enterpoint_node & 31); + dist_t lowerBound = dist; + + while (!candidateSet.empty()) + { + std::pair curr_el_pair = candidateSet.top(); + if (-curr_el_pair.first > lowerBound) + break; + + candidateSet.pop(); + idx_t curNodeNum = curr_el_pair.second; + + idx_t* data = get_linklist0(curNodeNum); + size_t size = *data++; + + PREFETCH(getDataByInternalId(*data), _MM_HINT_T0); + + for (size_t j = 0; j < size; ++j) { + size_t tnum = *(data + j); + + PREFETCH(getDataByInternalId(*(data + j + 1)), _MM_HINT_T0); + + if (!(visited[tnum >> 5] & (1 << (tnum & 31)))) { + visited[tnum >> 5] |= 1 << (tnum & 31); + + dist = fstdistfunc(point, getDataByInternalId(tnum)); + + if (topResults.top().first > dist || topResults.size() < ef) { + candidateSet.emplace(-dist, tnum); + + PREFETCH(get_linklist0(candidateSet.top().second), _MM_HINT_T0); + topResults.emplace(dist, tnum); + + if (topResults.size() > ef) + topResults.pop(); + + lowerBound = topResults.top().first; + } + } + } + } + return topResults; +} + + +void HierarchicalNSW::getNeighborsByHeuristic(std::priority_queue> &topResults, size_t NN) +{ + if (topResults.size() < NN) + return; + + 
std::priority_queue> resultSet; + std::vector> returnlist; + + while (topResults.size() > 0) { + resultSet.emplace(-topResults.top().first, topResults.top().second); + topResults.pop(); + } + + while (resultSet.size()) { + if (returnlist.size() >= NN) + break; + std::pair curen = resultSet.top(); + dist_t dist_to_query = -curen.first; + resultSet.pop(); + bool good = true; + for (std::pair curen2 : returnlist) { + dist_t curdist = fstdistfunc(getDataByInternalId(curen2.second), + getDataByInternalId(curen.second)); + if (curdist < dist_to_query) { + good = false; + break; + } + } + if (good) returnlist.push_back(curen); + } + for (std::pair elem : returnlist) + topResults.emplace(-elem.first, elem.second); +} + +void HierarchicalNSW::mutuallyConnectNewElement(const coord_t *point, idx_t cur_c, + std::priority_queue> topResults) +{ + getNeighborsByHeuristic(topResults, M); + + std::vector res; + res.reserve(M); + while (topResults.size() > 0) { + res.push_back(topResults.top().second); + topResults.pop(); + } + { + idx_t* data = get_linklist0(cur_c); + if (*data) + throw std::runtime_error("Should be blank"); + + *data++ = res.size(); + + for (size_t idx = 0; idx < res.size(); idx++) { + if (data[idx]) + throw std::runtime_error("Should be blank"); + data[idx] = res[idx]; + } + } + for (size_t idx = 0; idx < res.size(); idx++) { + if (res[idx] == cur_c) + throw std::runtime_error("Connection to the same element"); + + size_t resMmax = maxM; + idx_t *ll_other = get_linklist0(res[idx]); + idx_t sz_link_list_other = *ll_other; + + if (sz_link_list_other > resMmax || sz_link_list_other < 0) + throw std::runtime_error("Bad sz_link_list_other"); + + if (sz_link_list_other < resMmax) { + idx_t *data = ll_other + 1; + data[sz_link_list_other] = cur_c; + *ll_other = sz_link_list_other + 1; + } else { + // finding the "weakest" element to replace it with the new one + idx_t *data = ll_other + 1; + dist_t d_max = fstdistfunc(getDataByInternalId(cur_c), 
getDataByInternalId(res[idx])); + // Heuristic: + std::priority_queue> candidates; + candidates.emplace(d_max, cur_c); + + for (size_t j = 0; j < sz_link_list_other; j++) + candidates.emplace(fstdistfunc(getDataByInternalId(data[j]), getDataByInternalId(res[idx])), data[j]); + + getNeighborsByHeuristic(candidates, resMmax); + + size_t indx = 0; + while (!candidates.empty()) { + data[indx] = candidates.top().second; + candidates.pop(); + indx++; + } + *ll_other = indx; + } + } +} + +void HierarchicalNSW::addPoint(const coord_t *point, label_t label) +{ + if (cur_element_count >= maxelements) { + throw std::runtime_error("The number of elements exceeds the specified limit"); + } + idx_t cur_c = cur_element_count++; + memset((char *) get_linklist0(cur_c), 0, size_data_per_element); + memcpy(getDataByInternalId(cur_c), point, data_size); + memcpy(getExternalLabel(cur_c), &label, sizeof label); + + // Do nothing for the first element + if (cur_c != 0) { + std::priority_queue > topResults = searchBaseLayer(point, efConstruction); + mutuallyConnectNewElement(point, cur_c, topResults); + } +}; + +std::priority_queue> HierarchicalNSW::searchKnn(const coord_t *query, size_t k) +{ + std::priority_queue> topResults; + auto topCandidates = searchBaseLayer(query, k); + while (topCandidates.size() > k) { + topCandidates.pop(); + } + while (!topCandidates.empty()) { + std::pair rez = topCandidates.top(); + label_t label; + memcpy(&label, getExternalLabel(rez.second), sizeof(label)); + topResults.push(std::pair(rez.first, label)); + topCandidates.pop(); + } + + return topResults; +}; + +dist_t HierarchicalNSW::fstdistfunc(const coord_t *x, const coord_t *y) +{ +#if defined(__x86_64__) + float PORTABLE_ALIGN32 TmpRes[8]; + size_t qty16 = dim >> 4; + const float *pEnd1 = x + (qty16 << 4); +#ifdef USE_AVX + __m256 diff, v1, v2; + __m256 sum = _mm256_set1_ps(0); + + while (x < pEnd1) { + v1 = _mm256_loadu_ps(x); + x += 8; + v2 = _mm256_loadu_ps(y); + y += 8; + diff = _mm256_sub_ps(v1, 
v2); + sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff)); + + v1 = _mm256_loadu_ps(x); + x += 8; + v2 = _mm256_loadu_ps(y); + y += 8; + diff = _mm256_sub_ps(v1, v2); + sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff)); + } + + _mm256_store_ps(TmpRes, sum); + float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7]; + + return (res); +#else + __m128 diff, v1, v2; + __m128 sum = _mm_set1_ps(0); + + while (x < pEnd1) { + v1 = _mm_loadu_ps(x); + x += 4; + v2 = _mm_loadu_ps(y); + y += 4; + diff = _mm_sub_ps(v1, v2); + sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); + + v1 = _mm_loadu_ps(x); + x += 4; + v2 = _mm_loadu_ps(y); + y += 4; + diff = _mm_sub_ps(v1, v2); + sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); + + v1 = _mm_loadu_ps(x); + x += 4; + v2 = _mm_loadu_ps(y); + y += 4; + diff = _mm_sub_ps(v1, v2); + sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); + + v1 = _mm_loadu_ps(x); + x += 4; + v2 = _mm_loadu_ps(y); + y += 4; + diff = _mm_sub_ps(v1, v2); + sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); + } + _mm_store_ps(TmpRes, sum); + float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; + + return (res); +#endif +#else // portable implementation + dist_t distance = 0.0; + size_t n = dim; + + dist_calc++; + + for (size_t i = 0; i < n; i++) + { + dist_t diff = x[i] - y[i]; + distance += diff * diff; + } + return distance; +#endif +} + +bool hnsw_search(HierarchicalNSW* hnsw, const coord_t *point, size_t efSearch, size_t* n_results, label_t** results) +{ + try + { + auto result = hnsw->searchKnn(point, efSearch); + size_t nResults = result.size(); + *results = (label_t*)malloc(nResults*sizeof(label_t)); + for (size_t i = nResults; i-- != 0;) + { + (*results)[i] = result.top().second; + result.pop(); + } + *n_results = nResults; + return true; + } + catch (std::exception& x) + { + return false; + } +} + +bool hnsw_add_point(HierarchicalNSW* hnsw, const coord_t *point, label_t label) +{ + try + { + 
hnsw->addPoint(point, label); + return true; + } + catch (std::exception& x) + { + fprintf(stderr, "Catch %s\n", x.what()); + return false; + } +} + +void hnsw_init(HierarchicalNSW* hnsw, size_t dims, size_t maxelements, size_t M, size_t maxM, size_t efConstruction) +{ + new ((void*)hnsw) HierarchicalNSW(dims, maxelements, M, maxM, efConstruction); +} + + +int hnsw_dimensions(HierarchicalNSW* hnsw) +{ + return (int)hnsw->dim; +} + +size_t hnsw_count(HierarchicalNSW* hnsw) +{ + return hnsw->cur_element_count; +} + +size_t hnsw_sizeof(void) +{ + return sizeof(HierarchicalNSW); +} diff --git a/pgxn/hnsw/hnswalg.h b/pgxn/hnsw/hnswalg.h new file mode 100644 index 0000000000..b845ad2743 --- /dev/null +++ b/pgxn/hnsw/hnswalg.h @@ -0,0 +1,67 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +extern "C" { +#include "hnsw.h" +} + +struct HierarchicalNSW +{ + size_t maxelements; + size_t cur_element_count; + + idx_t enterpoint_node; + + size_t dist_calc; + + size_t dim; + size_t data_size; + size_t offset_data; + size_t offset_label; + size_t size_data_per_element; + size_t M; + size_t maxM; + size_t size_links_level0; + size_t efConstruction; + + char data_level0_memory[0]; // varying size + + public: + HierarchicalNSW(size_t dim, size_t maxelements, size_t M, size_t maxM, size_t efConstruction); + ~HierarchicalNSW(); + + + inline coord_t *getDataByInternalId(idx_t internal_id) const { + return (coord_t *)&data_level0_memory[internal_id * size_data_per_element + offset_data]; + } + + inline idx_t *get_linklist0(idx_t internal_id) const { + return (idx_t*)&data_level0_memory[internal_id * size_data_per_element]; + } + + inline label_t *getExternalLabel(idx_t internal_id) const { + return (label_t *)&data_level0_memory[internal_id * size_data_per_element + offset_label]; + } + + std::priority_queue> searchBaseLayer(const coord_t *x, size_t ef); + + void getNeighborsByHeuristic(std::priority_queue> 
&topResults, size_t NN); + + void mutuallyConnectNewElement(const coord_t *x, idx_t id, std::priority_queue> topResults); + + void addPoint(const coord_t *point, label_t label); + + std::priority_queue> searchKnn(const coord_t *query_data, size_t k); + + dist_t fstdistfunc(const coord_t *x, const coord_t *y); +}; diff --git a/pgxn/hnsw/test/expected/knn.out b/pgxn/hnsw/test/expected/knn.out new file mode 100644 index 0000000000..a1cee4525e --- /dev/null +++ b/pgxn/hnsw/test/expected/knn.out @@ -0,0 +1,28 @@ +SET enable_seqscan = off; +CREATE TABLE t (val real[]); +INSERT INTO t (val) VALUES ('{0,0,0}'), ('{1,2,3}'), ('{1,1,1}'), (NULL); +CREATE INDEX ON t USING hnsw (val) WITH (maxelements = 10, dims=3, m=3); +INSERT INTO t (val) VALUES (array[1,2,4]); +explain SELECT * FROM t ORDER BY val <-> array[3,3,3]; + QUERY PLAN +-------------------------------------------------------------------- + Index Scan using t_val_idx on t (cost=4.02..8.06 rows=3 width=36) + Order By: (val <-> '{3,3,3}'::real[]) +(2 rows) + +SELECT * FROM t ORDER BY val <-> array[3,3,3]; + val +--------- + {1,2,3} + {1,2,4} + {1,1,1} + {0,0,0} +(4 rows) + +SELECT COUNT(*) FROM t; + count +------- + 5 +(1 row) + +DROP TABLE t; diff --git a/pgxn/hnsw/test/sql/knn.sql b/pgxn/hnsw/test/sql/knn.sql new file mode 100644 index 0000000000..0635bda4a2 --- /dev/null +++ b/pgxn/hnsw/test/sql/knn.sql @@ -0,0 +1,13 @@ +SET enable_seqscan = off; + +CREATE TABLE t (val real[]); +INSERT INTO t (val) VALUES ('{0,0,0}'), ('{1,2,3}'), ('{1,1,1}'), (NULL); +CREATE INDEX ON t USING hnsw (val) WITH (maxelements = 10, dims=3, m=3); + +INSERT INTO t (val) VALUES (array[1,2,4]); + +explain SELECT * FROM t ORDER BY val <-> array[3,3,3]; +SELECT * FROM t ORDER BY val <-> array[3,3,3]; +SELECT COUNT(*) FROM t; + +DROP TABLE t;