From 648327e90c45cee70bfe81e0e4ba8e55110408f9 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Wed, 15 Jan 2025 01:00:57 +0800 Subject: [PATCH] docs: show how to pack bits for binary vector (#2020) Signed-off-by: BubbleCal --- .../python/tests/docs/test_binary_vector.py | 73 ++++++++++++++----- 1 file changed, 53 insertions(+), 20 deletions(-) diff --git a/python/python/tests/docs/test_binary_vector.py b/python/python/tests/docs/test_binary_vector.py index 0bc8030d..ec8da5be 100644 --- a/python/python/tests/docs/test_binary_vector.py +++ b/python/python/tests/docs/test_binary_vector.py @@ -3,6 +3,7 @@ import shutil # --8<-- [start:imports] import lancedb import numpy as np +import pyarrow as pa import pytest # --8<-- [end:imports] @@ -12,16 +13,32 @@ shutil.rmtree("data/binary_lancedb", ignore_errors=True) def test_binary_vector(): # --8<-- [start:sync_binary_vector] db = lancedb.connect("data/binary_lancedb") - data = [ - { - "id": i, - "vector": np.random.randint(0, 256, size=16), - } - for i in range(1024) - ] - tbl = db.create_table("my_binary_vectors", data=data) - query = np.random.randint(0, 256, size=16) - tbl.search(query).metric("hamming").to_arrow() + schema = pa.schema( + [ + pa.field("id", pa.int64()), + # for dim=256, lance stores every 8 bits in a byte + # so the vector field should be a list of 256 / 8 = 32 bytes + pa.field("vector", pa.list_(pa.uint8(), 32)), + ] + ) + tbl = db.create_table("my_binary_vectors", schema=schema) + + data = [] + for i in range(1024): + vector = np.random.randint(0, 2, size=256) + # pack the binary vector into bytes to save space + packed_vector = np.packbits(vector) + data.append( + { + "id": i, + "vector": packed_vector, + } + ) + tbl.add(data) + + query = np.random.randint(0, 2, size=256) + packed_query = np.packbits(query) + tbl.search(packed_query).metric("hamming").to_arrow() # --8<-- [end:sync_binary_vector] db.drop_table("my_binary_vectors") @@ -30,15 +47,31 @@ def test_binary_vector(): async def test_binary_vector_async(): # --8<-- [start:async_binary_vector] db = await lancedb.connect_async("data/binary_lancedb") - data = [ - { - "id": i, - "vector": np.random.randint(0, 256, size=16), - } - for i in range(1024) - ] - tbl = await db.create_table("my_binary_vectors", data=data) - query = np.random.randint(0, 256, size=16) - await tbl.query().nearest_to(query).distance_type("hamming").to_arrow() + schema = pa.schema( + [ + pa.field("id", pa.int64()), + # for dim=256, lance stores every 8 bits in a byte + # so the vector field should be a list of 256 / 8 = 32 bytes + pa.field("vector", pa.list_(pa.uint8(), 32)), + ] + ) + tbl = await db.create_table("my_binary_vectors", schema=schema) + + data = [] + for i in range(1024): + vector = np.random.randint(0, 2, size=256) + # pack the binary vector into bytes to save space + packed_vector = np.packbits(vector) + data.append( + { + "id": i, + "vector": packed_vector, + } + ) + await tbl.add(data) + + query = np.random.randint(0, 2, size=256) + packed_query = np.packbits(query) + await tbl.query().nearest_to(packed_query).distance_type("hamming").to_arrow() # --8<-- [end:async_binary_vector] await db.drop_table("my_binary_vectors")