mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-26 14:49:57 +00:00
feat: add to_query_object method (#2239)
This PR adds a `to_query_object` method to the various query builders (except not hybrid queries yet). This makes it possible to inspect the query that is built. In addition this PR does some normalization between the sync and async query paths. A few custom defaults were removed in favor of None (with the default getting set once, in rust). Also, the synchronous to_batches method will now actually stream results Also, the remote API now defaults to prefiltering
This commit is contained in:
@@ -26,10 +26,12 @@ from lancedb.query import (
|
||||
AsyncVectorQuery,
|
||||
LanceVectorQueryBuilder,
|
||||
Query,
|
||||
FullTextSearchQuery,
|
||||
)
|
||||
from lancedb.rerankers.cross_encoder import CrossEncoderReranker
|
||||
from lancedb.table import AsyncTable, LanceTable
|
||||
from utils import exception_output
|
||||
from importlib.util import find_spec
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@@ -392,12 +394,28 @@ def test_query_builder_batches(table):
|
||||
for item in rs:
|
||||
rs_list.append(item)
|
||||
assert isinstance(item, pa.RecordBatch)
|
||||
assert len(rs_list) == 1
|
||||
assert len(rs_list[0]["id"]) == 2
|
||||
assert len(rs_list) == 2
|
||||
assert len(rs_list[0]["id"]) == 1
|
||||
assert all(rs_list[0].to_pandas()["vector"][0] == [1.0, 2.0])
|
||||
assert rs_list[0].to_pandas()["id"][0] == 1
|
||||
assert all(rs_list[0].to_pandas()["vector"][1] == [3.0, 4.0])
|
||||
assert rs_list[0].to_pandas()["id"][1] == 2
|
||||
assert all(rs_list[1].to_pandas()["vector"][0] == [3.0, 4.0])
|
||||
assert rs_list[1].to_pandas()["id"][0] == 2
|
||||
|
||||
rs = (
|
||||
LanceVectorQueryBuilder(table, [0, 0], "vector")
|
||||
.limit(2)
|
||||
.select(["id", "vector"])
|
||||
.to_batches(2)
|
||||
)
|
||||
rs_list = []
|
||||
for item in rs:
|
||||
rs_list.append(item)
|
||||
assert isinstance(item, pa.RecordBatch)
|
||||
assert len(rs_list) == 1
|
||||
assert len(rs_list[0]["id"]) == 2
|
||||
rs_list = rs_list[0].to_pandas()
|
||||
assert rs_list["id"][0] == 1
|
||||
assert rs_list["id"][1] == 2
|
||||
|
||||
|
||||
def test_dynamic_projection(table):
|
||||
@@ -488,12 +506,9 @@ def test_query_builder_with_different_vector_column():
|
||||
Query(
|
||||
vector=query,
|
||||
filter="b < 10",
|
||||
prefilter=True,
|
||||
k=2,
|
||||
metric="cosine",
|
||||
limit=2,
|
||||
distance_type="cosine",
|
||||
columns=["b"],
|
||||
nprobes=20,
|
||||
refine_factor=None,
|
||||
vector_column="foo_vector",
|
||||
),
|
||||
None,
|
||||
@@ -595,6 +610,10 @@ async def test_query_async(table_async: AsyncTable):
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.slow
|
||||
async def test_query_reranked_async(table_async: AsyncTable):
|
||||
# CrossEncoderReranker requires torch
|
||||
if find_spec("torch") is None:
|
||||
pytest.skip("torch not installed")
|
||||
|
||||
# FTS with rerank
|
||||
await table_async.create_index("text", config=FTS(with_position=False))
|
||||
await check_query(
|
||||
@@ -823,3 +842,223 @@ async def test_query_search_specified(mem_db_async: AsyncConnection):
|
||||
assert "No embedding functions are registered for any columns" in exception_output(
|
||||
e
|
||||
)
|
||||
|
||||
|
||||
# Helper method used in the following tests. Looks at the simple python object `q` and
|
||||
# checks that the properties match the expected values in kwargs.
|
||||
def check_set_props(q, **kwargs):
|
||||
for k in dict(q):
|
||||
if not k.startswith("_"):
|
||||
if k in kwargs:
|
||||
assert kwargs[k] == getattr(q, k), (
|
||||
f"{k} should be {kwargs[k]} but is {getattr(q, k)}"
|
||||
)
|
||||
else:
|
||||
assert getattr(q, k) is None, f"{k} should be None"
|
||||
|
||||
|
||||
def test_query_serialization_sync(table: lancedb.table.Table):
|
||||
# Simple queries
|
||||
q = table.search().where("id = 1").limit(500).offset(10).to_query_object()
|
||||
check_set_props(q, limit=500, offset=10, filter="id = 1")
|
||||
|
||||
q = table.search().select(["id", "vector"]).to_query_object()
|
||||
check_set_props(q, columns=["id", "vector"])
|
||||
|
||||
q = table.search().with_row_id(True).to_query_object()
|
||||
check_set_props(q, with_row_id=True)
|
||||
|
||||
# Vector queries
|
||||
q = table.search([5.0, 6.0]).limit(10).to_query_object()
|
||||
check_set_props(q, limit=10, vector_column="vector", vector=[5.0, 6.0])
|
||||
|
||||
q = table.search([5.0, 6.0]).to_query_object()
|
||||
check_set_props(q, vector_column="vector", vector=[5.0, 6.0])
|
||||
|
||||
q = (
|
||||
table.search([5.0, 6.0])
|
||||
.limit(10)
|
||||
.where("id = 1", prefilter=False)
|
||||
.to_query_object()
|
||||
)
|
||||
check_set_props(
|
||||
q,
|
||||
limit=10,
|
||||
vector_column="vector",
|
||||
filter="id = 1",
|
||||
postfilter=True,
|
||||
vector=[5.0, 6.0],
|
||||
)
|
||||
|
||||
q = table.search([5.0, 6.0]).nprobes(10).refine_factor(5).to_query_object()
|
||||
check_set_props(
|
||||
q, vector_column="vector", vector=[5.0, 6.0], nprobes=10, refine_factor=5
|
||||
)
|
||||
|
||||
q = table.search([5.0, 6.0]).distance_range(0.0, 1.0).to_query_object()
|
||||
check_set_props(
|
||||
q, vector_column="vector", vector=[5.0, 6.0], lower_bound=0.0, upper_bound=1.0
|
||||
)
|
||||
|
||||
q = table.search([5.0, 6.0]).distance_type("cosine").to_query_object()
|
||||
check_set_props(
|
||||
q, distance_type="cosine", vector_column="vector", vector=[5.0, 6.0]
|
||||
)
|
||||
|
||||
q = table.search([5.0, 6.0]).ef(7).to_query_object()
|
||||
check_set_props(q, ef=7, vector_column="vector", vector=[5.0, 6.0])
|
||||
|
||||
q = table.search([5.0, 6.0]).bypass_vector_index().to_query_object()
|
||||
check_set_props(
|
||||
q, bypass_vector_index=True, vector_column="vector", vector=[5.0, 6.0]
|
||||
)
|
||||
|
||||
# FTS queries
|
||||
q = table.search("foo").limit(10).to_query_object()
|
||||
check_set_props(
|
||||
q, limit=10, full_text_query=FullTextSearchQuery(columns=[], query="foo")
|
||||
)
|
||||
|
||||
q = table.search("foo", query_type="fts").to_query_object()
|
||||
check_set_props(q, full_text_query=FullTextSearchQuery(columns=[], query="foo"))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_serialization_async(table_async: AsyncTable):
|
||||
# Simple queries
|
||||
q = table_async.query().where("id = 1").limit(500).offset(10).to_query_object()
|
||||
check_set_props(q, limit=500, offset=10, filter="id = 1", with_row_id=False)
|
||||
|
||||
q = table_async.query().select(["id", "vector"]).to_query_object()
|
||||
check_set_props(q, columns=["id", "vector"], with_row_id=False)
|
||||
|
||||
q = table_async.query().with_row_id().to_query_object()
|
||||
check_set_props(q, with_row_id=True)
|
||||
|
||||
sample_vector = [pa.array([5.0, 6.0], type=pa.float32())]
|
||||
|
||||
# Vector queries
|
||||
q = (await table_async.search([5.0, 6.0])).limit(10).to_query_object()
|
||||
check_set_props(
|
||||
q,
|
||||
limit=10,
|
||||
vector=sample_vector,
|
||||
postfilter=False,
|
||||
nprobes=20,
|
||||
with_row_id=False,
|
||||
bypass_vector_index=False,
|
||||
)
|
||||
|
||||
q = (await table_async.search([5.0, 6.0])).to_query_object()
|
||||
check_set_props(
|
||||
q,
|
||||
vector=sample_vector,
|
||||
postfilter=False,
|
||||
nprobes=20,
|
||||
with_row_id=False,
|
||||
bypass_vector_index=False,
|
||||
limit=10,
|
||||
)
|
||||
|
||||
q = (
|
||||
(await table_async.search([5.0, 6.0]))
|
||||
.limit(10)
|
||||
.where("id = 1")
|
||||
.postfilter()
|
||||
.to_query_object()
|
||||
)
|
||||
check_set_props(
|
||||
q,
|
||||
limit=10,
|
||||
filter="id = 1",
|
||||
postfilter=True,
|
||||
vector=sample_vector,
|
||||
nprobes=20,
|
||||
with_row_id=False,
|
||||
bypass_vector_index=False,
|
||||
)
|
||||
|
||||
q = (
|
||||
(await table_async.search([5.0, 6.0]))
|
||||
.nprobes(10)
|
||||
.refine_factor(5)
|
||||
.to_query_object()
|
||||
)
|
||||
check_set_props(
|
||||
q,
|
||||
vector=sample_vector,
|
||||
nprobes=10,
|
||||
refine_factor=5,
|
||||
postfilter=False,
|
||||
with_row_id=False,
|
||||
bypass_vector_index=False,
|
||||
limit=10,
|
||||
)
|
||||
|
||||
q = (
|
||||
(await table_async.search([5.0, 6.0]))
|
||||
.distance_range(0.0, 1.0)
|
||||
.to_query_object()
|
||||
)
|
||||
check_set_props(
|
||||
q,
|
||||
vector=sample_vector,
|
||||
lower_bound=0.0,
|
||||
upper_bound=1.0,
|
||||
postfilter=False,
|
||||
nprobes=20,
|
||||
with_row_id=False,
|
||||
bypass_vector_index=False,
|
||||
limit=10,
|
||||
)
|
||||
|
||||
q = (await table_async.search([5.0, 6.0])).distance_type("cosine").to_query_object()
|
||||
check_set_props(
|
||||
q,
|
||||
distance_type="cosine",
|
||||
vector=sample_vector,
|
||||
postfilter=False,
|
||||
nprobes=20,
|
||||
with_row_id=False,
|
||||
bypass_vector_index=False,
|
||||
limit=10,
|
||||
)
|
||||
|
||||
q = (await table_async.search([5.0, 6.0])).ef(7).to_query_object()
|
||||
check_set_props(
|
||||
q,
|
||||
ef=7,
|
||||
vector=sample_vector,
|
||||
postfilter=False,
|
||||
nprobes=20,
|
||||
with_row_id=False,
|
||||
bypass_vector_index=False,
|
||||
limit=10,
|
||||
)
|
||||
|
||||
q = (await table_async.search([5.0, 6.0])).bypass_vector_index().to_query_object()
|
||||
check_set_props(
|
||||
q,
|
||||
bypass_vector_index=True,
|
||||
vector=sample_vector,
|
||||
postfilter=False,
|
||||
nprobes=20,
|
||||
with_row_id=False,
|
||||
limit=10,
|
||||
)
|
||||
|
||||
# FTS queries
|
||||
q = (await table_async.search("foo")).limit(10).to_query_object()
|
||||
check_set_props(
|
||||
q,
|
||||
limit=10,
|
||||
full_text_query=FullTextSearchQuery(columns=[], query="foo"),
|
||||
with_row_id=False,
|
||||
)
|
||||
|
||||
q = (await table_async.search("foo", query_type="fts")).to_query_object()
|
||||
check_set_props(
|
||||
q,
|
||||
full_text_query=FullTextSearchQuery(columns=[], query="foo"),
|
||||
with_row_id=False,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user