From d57bed90e57459e71f00205938e650f372896a92 Mon Sep 17 00:00:00 2001
From: BubbleCal
Date: Wed, 15 Jan 2025 13:17:05 +0800
Subject: [PATCH] docs: add missing example code (#2025)

---
 python/python/tests/docs/test_multivector.py | 85 ++++++++++++++++++++
 1 file changed, 85 insertions(+)
 create mode 100644 python/python/tests/docs/test_multivector.py

diff --git a/python/python/tests/docs/test_multivector.py b/python/python/tests/docs/test_multivector.py
new file mode 100644
index 00000000..d6b6c3a8
--- /dev/null
+++ b/python/python/tests/docs/test_multivector.py
@@ -0,0 +1,85 @@
+import shutil
+import pytest
+
+# --8<-- [start:imports]
+import lancedb
+import numpy as np
+import pyarrow as pa
+# --8<-- [end:imports]
+
+# Remove any database directory left over from a previous run before the
+# tests below connect to it.
+shutil.rmtree("data/multivector_demo", ignore_errors=True)
+
+
+def test_multivector():
+    """Run the synchronous multivector example end to end.
+
+    Creates a 1024-row table whose "vector" column holds two 256-dimensional
+    float32 vectors per row, searches it once with a single 256-dimensional
+    query and once with a 2x256 query, then drops the table.
+    """
+    # --8<-- [start:sync_multivector]
+    db = lancedb.connect("data/multivector_demo")
+    schema = pa.schema(
+        [
+            pa.field("id", pa.int64()),
+            pa.field("vector", pa.list_(pa.list_(pa.float32(), 256))),
+        ]
+    )
+    data = [
+        {
+            "id": i,
+            "vector": np.random.random(size=(2, 256)).tolist(),
+        }
+        for i in range(1024)
+    ]
+    tbl = db.create_table("my_table", data=data, schema=schema)
+
+    # query with single vector
+    query = np.random.random(256)
+    tbl.search(query).to_arrow()
+
+    # query with multiple vectors
+    query = np.random.random(size=(2, 256))
+    tbl.search(query).to_arrow()
+
+    # --8<-- [end:sync_multivector]
+    db.drop_table("my_table")
+
+
+@pytest.mark.asyncio
+async def test_multivector_async():
+    """Run the asynchronous multivector example end to end.
+
+    Mirrors ``test_multivector`` using ``connect_async`` and
+    ``query().nearest_to(...)``: both the single-vector query and the 2x256
+    query are executed before the table is dropped.
+    """
+    # --8<-- [start:async_multivector]
+    db = await lancedb.connect_async("data/multivector_demo")
+    schema = pa.schema(
+        [
+            pa.field("id", pa.int64()),
+            pa.field("vector", pa.list_(pa.list_(pa.float32(), 256))),
+        ]
+    )
+    data = [
+        {
+            "id": i,
+            "vector": np.random.random(size=(2, 256)).tolist(),
+        }
+        for i in range(1024)
+    ]
+    tbl = await db.create_table("my_table", data=data, schema=schema)
+
+    # query with single vector
+    query = np.random.random(256)
+    await tbl.query().nearest_to(query).to_arrow()
+
+    # query with multiple vectors
+    query = np.random.random(size=(2, 256))
+    await tbl.query().nearest_to(query).to_arrow()
+
+    # --8<-- [end:async_multivector]
+    await db.drop_table("my_table")