feat: add to_query_object method (#2239)

This PR adds a `to_query_object` method to the various query builders
(hybrid queries are not supported yet). This makes it possible to inspect
the query that is built.

In addition, this PR does some normalization between the sync and async
query paths. A few custom defaults were removed in favor of None (with
the default getting set once, in Rust).

Also, the synchronous `to_batches` method will now actually stream results.

Also, the remote API now defaults to prefiltering.
This commit is contained in:
Weston Pace
2025-03-21 13:01:51 -07:00
committed by GitHub
parent b2a38ac366
commit 9403254442
8 changed files with 867 additions and 177 deletions

View File

@@ -26,10 +26,12 @@ from lancedb.query import (
AsyncVectorQuery,
LanceVectorQueryBuilder,
Query,
FullTextSearchQuery,
)
from lancedb.rerankers.cross_encoder import CrossEncoderReranker
from lancedb.table import AsyncTable, LanceTable
from utils import exception_output
from importlib.util import find_spec
@pytest.fixture(scope="module")
@@ -392,12 +394,28 @@ def test_query_builder_batches(table):
for item in rs:
rs_list.append(item)
assert isinstance(item, pa.RecordBatch)
assert len(rs_list) == 1
assert len(rs_list[0]["id"]) == 2
assert len(rs_list) == 2
assert len(rs_list[0]["id"]) == 1
assert all(rs_list[0].to_pandas()["vector"][0] == [1.0, 2.0])
assert rs_list[0].to_pandas()["id"][0] == 1
assert all(rs_list[0].to_pandas()["vector"][1] == [3.0, 4.0])
assert rs_list[0].to_pandas()["id"][1] == 2
assert all(rs_list[1].to_pandas()["vector"][0] == [3.0, 4.0])
assert rs_list[1].to_pandas()["id"][0] == 2
rs = (
LanceVectorQueryBuilder(table, [0, 0], "vector")
.limit(2)
.select(["id", "vector"])
.to_batches(2)
)
rs_list = []
for item in rs:
rs_list.append(item)
assert isinstance(item, pa.RecordBatch)
assert len(rs_list) == 1
assert len(rs_list[0]["id"]) == 2
rs_list = rs_list[0].to_pandas()
assert rs_list["id"][0] == 1
assert rs_list["id"][1] == 2
def test_dynamic_projection(table):
@@ -488,12 +506,9 @@ def test_query_builder_with_different_vector_column():
Query(
vector=query,
filter="b < 10",
prefilter=True,
k=2,
metric="cosine",
limit=2,
distance_type="cosine",
columns=["b"],
nprobes=20,
refine_factor=None,
vector_column="foo_vector",
),
None,
@@ -595,6 +610,10 @@ async def test_query_async(table_async: AsyncTable):
@pytest.mark.asyncio
@pytest.mark.slow
async def test_query_reranked_async(table_async: AsyncTable):
# CrossEncoderReranker requires torch
if find_spec("torch") is None:
pytest.skip("torch not installed")
# FTS with rerank
await table_async.create_index("text", config=FTS(with_position=False))
await check_query(
@@ -823,3 +842,223 @@ async def test_query_search_specified(mem_db_async: AsyncConnection):
assert "No embedding functions are registered for any columns" in exception_output(
e
)
# Helper function used in the following tests. Looks at the simple Python object `q`
# and checks that the properties match the expected values in kwargs. Any public
# property that is not listed in kwargs is expected to be None.
def check_set_props(q, **kwargs):
    """Assert that the public properties of ``q`` match the expected values.

    Args:
        q: Object that can be converted with ``dict(q)`` and that exposes
            each resulting key as an attribute.
        **kwargs: Expected value for each property that should be set.
            Every other property whose name does not start with ``_`` must
            be ``None``.

    Raises:
        AssertionError: If a listed property differs from its expected
            value, or an unlisted public property is not ``None``.
    """
    public_names = (name for name in dict(q) if not name.startswith("_"))
    for name in public_names:
        actual = getattr(q, name)
        if name not in kwargs:
            # Properties the caller did not mention must have been left unset.
            assert actual is None, f"{name} should be None"
            continue
        expected = kwargs[name]
        assert expected == actual, f"{name} should be {expected} but is {actual}"
def test_query_serialization_sync(table: lancedb.table.Table):
    """Check the query objects produced by the synchronous query builders.

    Each case builds a query, converts it with ``to_query_object`` and hands
    the result to ``check_set_props`` together with the properties that are
    expected to be set on it.
    """

    def vector_props(**extra):
        # Properties expected on every vector query in this test. A fresh
        # list is built per call so cases never share a mutable expectation.
        return {"vector_column": "vector", "vector": [5.0, 6.0], **extra}

    def fts_props(**extra):
        # Properties expected on every full text search query in this test.
        return {
            "full_text_query": FullTextSearchQuery(columns=[], query="foo"),
            **extra,
        }

    # Simple queries
    builder = table.search().where("id = 1").limit(500).offset(10)
    check_set_props(
        builder.to_query_object(), limit=500, offset=10, filter="id = 1"
    )

    builder = table.search().select(["id", "vector"])
    check_set_props(builder.to_query_object(), columns=["id", "vector"])

    builder = table.search().with_row_id(True)
    check_set_props(builder.to_query_object(), with_row_id=True)

    # Vector queries
    builder = table.search([5.0, 6.0]).limit(10)
    check_set_props(builder.to_query_object(), **vector_props(limit=10))

    builder = table.search([5.0, 6.0])
    check_set_props(builder.to_query_object(), **vector_props())

    builder = table.search([5.0, 6.0]).limit(10).where("id = 1", prefilter=False)
    check_set_props(
        builder.to_query_object(),
        **vector_props(limit=10, filter="id = 1", postfilter=True),
    )

    builder = table.search([5.0, 6.0]).nprobes(10).refine_factor(5)
    check_set_props(
        builder.to_query_object(), **vector_props(nprobes=10, refine_factor=5)
    )

    builder = table.search([5.0, 6.0]).distance_range(0.0, 1.0)
    check_set_props(
        builder.to_query_object(),
        **vector_props(lower_bound=0.0, upper_bound=1.0),
    )

    builder = table.search([5.0, 6.0]).distance_type("cosine")
    check_set_props(
        builder.to_query_object(), **vector_props(distance_type="cosine")
    )

    builder = table.search([5.0, 6.0]).ef(7)
    check_set_props(builder.to_query_object(), **vector_props(ef=7))

    builder = table.search([5.0, 6.0]).bypass_vector_index()
    check_set_props(
        builder.to_query_object(), **vector_props(bypass_vector_index=True)
    )

    # FTS queries
    builder = table.search("foo").limit(10)
    check_set_props(builder.to_query_object(), **fts_props(limit=10))

    builder = table.search("foo", query_type="fts")
    check_set_props(builder.to_query_object(), **fts_props())
@pytest.mark.asyncio
async def test_query_serialization_async(table_async: AsyncTable):
    """Check the query objects produced by the asynchronous query builders.

    Each case builds a query, converts it with ``to_query_object`` and hands
    the result to ``check_set_props`` together with the properties that are
    expected to be set on it.
    """
    # Simple queries
    builder = table_async.query().where("id = 1").limit(500).offset(10)
    check_set_props(
        builder.to_query_object(),
        limit=500,
        offset=10,
        filter="id = 1",
        with_row_id=False,
    )

    builder = table_async.query().select(["id", "vector"])
    check_set_props(
        builder.to_query_object(), columns=["id", "vector"], with_row_id=False
    )

    builder = table_async.query().with_row_id()
    check_set_props(builder.to_query_object(), with_row_id=True)

    sample_vector = [pa.array([5.0, 6.0], type=pa.float32())]

    def vector_props(**overrides):
        # Values this test expects on an async vector query when nothing else
        # is configured; individual cases override or extend them.
        return {
            "vector": sample_vector,
            "postfilter": False,
            "nprobes": 20,
            "with_row_id": False,
            "bypass_vector_index": False,
            "limit": 10,
            **overrides,
        }

    def fts_props(**extra):
        # Values this test expects on an async full text search query.
        return {
            "full_text_query": FullTextSearchQuery(columns=[], query="foo"),
            "with_row_id": False,
            **extra,
        }

    # Vector queries
    builder = await table_async.search([5.0, 6.0])
    check_set_props(builder.limit(10).to_query_object(), **vector_props())

    builder = await table_async.search([5.0, 6.0])
    check_set_props(builder.to_query_object(), **vector_props())

    builder = await table_async.search([5.0, 6.0])
    check_set_props(
        builder.limit(10).where("id = 1").postfilter().to_query_object(),
        **vector_props(filter="id = 1", postfilter=True),
    )

    builder = await table_async.search([5.0, 6.0])
    check_set_props(
        builder.nprobes(10).refine_factor(5).to_query_object(),
        **vector_props(nprobes=10, refine_factor=5),
    )

    builder = await table_async.search([5.0, 6.0])
    check_set_props(
        builder.distance_range(0.0, 1.0).to_query_object(),
        **vector_props(lower_bound=0.0, upper_bound=1.0),
    )

    builder = await table_async.search([5.0, 6.0])
    check_set_props(
        builder.distance_type("cosine").to_query_object(),
        **vector_props(distance_type="cosine"),
    )

    builder = await table_async.search([5.0, 6.0])
    check_set_props(builder.ef(7).to_query_object(), **vector_props(ef=7))

    builder = await table_async.search([5.0, 6.0])
    check_set_props(
        builder.bypass_vector_index().to_query_object(),
        **vector_props(bypass_vector_index=True),
    )

    # FTS queries
    builder = await table_async.search("foo")
    check_set_props(builder.limit(10).to_query_object(), **fts_props(limit=10))

    builder = await table_async.search("foo", query_type="fts")
    check_set_props(builder.to_query_object(), **fts_props())