feat: add output_schema method to queries (#2717)

This is a helper utility I need for some of my data loader work. It
makes it easy to see the output schema even when a `select` has been
applied.
This commit is contained in:
Weston Pace
2025-10-14 05:13:28 -07:00
committed by GitHub
parent 03eab0f091
commit 8f8e06a2da
17 changed files with 563 additions and 12 deletions

View File

@@ -123,6 +123,8 @@ class Table:
@property
def tags(self) -> Tags: ...
def query(self) -> Query: ...
def take_offsets(self, offsets: list[int]) -> TakeQuery: ...
def take_row_ids(self, row_ids: list[int]) -> TakeQuery: ...
def vector_search(self) -> VectorQuery: ...
class Tags:
@@ -165,6 +167,7 @@ class Query:
def postfilter(self): ...
def nearest_to(self, query_vec: pa.Array) -> VectorQuery: ...
def nearest_to_text(self, query: dict) -> FTSQuery: ...
async def output_schema(self) -> pa.Schema: ...
async def execute(
self, max_batch_length: Optional[int], timeout: Optional[timedelta]
) -> RecordBatchStream: ...
@@ -172,6 +175,13 @@ class Query:
async def analyze_plan(self) -> str: ...
def to_query_request(self) -> PyQueryRequest: ...
class TakeQuery:
def select(self, columns: List[str]): ...
def with_row_id(self): ...
async def output_schema(self) -> pa.Schema: ...
async def execute(self) -> RecordBatchStream: ...
def to_query_request(self) -> PyQueryRequest: ...
class FTSQuery:
def where(self, filter: str): ...
def select(self, columns: List[str]): ...
@@ -183,12 +193,14 @@ class FTSQuery:
def get_query(self) -> str: ...
def add_query_vector(self, query_vec: pa.Array) -> None: ...
def nearest_to(self, query_vec: pa.Array) -> HybridQuery: ...
async def output_schema(self) -> pa.Schema: ...
async def execute(
self, max_batch_length: Optional[int], timeout: Optional[timedelta]
) -> RecordBatchStream: ...
def to_query_request(self) -> PyQueryRequest: ...
class VectorQuery:
async def output_schema(self) -> pa.Schema: ...
async def execute(self) -> RecordBatchStream: ...
def where(self, filter: str): ...
def select(self, columns: List[str]): ...

View File

@@ -1237,6 +1237,14 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
self._refine_factor = refine_factor
return self
def output_schema(self) -> pa.Schema:
    """Return the schema the query results will have.

    Only the query plan is inspected; the query is not executed. The
    schema reflects any projection applied via ``select``.
    """
    query_obj = self.to_query_object()
    return self._table._output_schema(query_obj)
def to_arrow(self, *, timeout: Optional[timedelta] = None) -> pa.Table:
"""
Execute the query and return the results as an
@@ -1452,6 +1460,14 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
offset=self._offset,
)
def output_schema(self) -> pa.Schema:
    """Return the output schema of this full-text-search query.

    This is computed from the query plan alone and does not execute
    the query.
    """
    query_obj = self.to_query_object()
    return self._table._output_schema(query_obj)
def to_arrow(self, *, timeout: Optional[timedelta] = None) -> pa.Table:
path, fs, exist = self._table._get_fts_index_path()
if exist:
@@ -1595,6 +1611,10 @@ class LanceEmptyQueryBuilder(LanceQueryBuilder):
offset=self._offset,
)
def output_schema(self) -> pa.Schema:
    """
    Return the output schema for the query

    This does not execute the query.
    """
    return self._table._output_schema(self.to_query_object())
def to_batches(
self, /, batch_size: Optional[int] = None, timeout: Optional[timedelta] = None
) -> pa.RecordBatchReader:
@@ -2238,6 +2258,14 @@ class AsyncQueryBase(object):
)
)
async def output_schema(self) -> pa.Schema:
    """Return the schema of this query's results.

    Only the plan is consulted; the query itself is not executed.
    """
    schema_awaitable = self._inner.output_schema()
    return await schema_awaitable
async def to_arrow(self, timeout: Optional[timedelta] = None) -> pa.Table:
"""
Execute the query and collect the results into an Apache Arrow Table.
@@ -3193,6 +3221,14 @@ class BaseQueryBuilder(object):
self._inner.with_row_id()
return self
def output_schema(self) -> pa.Schema:
    """Return the output schema for the query.

    This does not execute the query; the inner async result is
    resolved synchronously on the shared event loop.
    """
    pending = self._inner.output_schema()
    return LOOP.run(pending)
def to_batches(
self,
*,

View File

@@ -436,6 +436,9 @@ class RemoteTable(Table):
def _analyze_plan(self, query: Query) -> str:
return LOOP.run(self._table._analyze_plan(query))
def _output_schema(self, query: Query) -> pa.Schema:
    # Resolve the underlying async table's schema computation on the
    # shared event loop so this sync API can return a plain result.
    schema_coro = self._table._output_schema(query)
    return LOOP.run(schema_coro)
def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
"""Returns a [`LanceMergeInsertBuilder`][lancedb.merge.LanceMergeInsertBuilder]
that can be used to create a "merge insert" operation.

View File

@@ -1248,6 +1248,9 @@ class Table(ABC):
@abstractmethod
def _analyze_plan(self, query: Query) -> str: ...
@abstractmethod
def _output_schema(self, query: Query) -> pa.Schema: ...
@abstractmethod
def _do_merge(
self,
@@ -2761,6 +2764,9 @@ class LanceTable(Table):
def _analyze_plan(self, query: Query) -> str:
return LOOP.run(self._table._analyze_plan(query))
def _output_schema(self, query: Query) -> pa.Schema:
    # Block on the async table implementation via the shared loop.
    schema_coro = self._table._output_schema(query)
    return LOOP.run(schema_coro)
def _do_merge(
self,
merge: LanceMergeInsertBuilder,
@@ -3918,6 +3924,10 @@ class AsyncTable:
async_query = self._sync_query_to_async(query)
return await async_query.analyze_plan()
async def _output_schema(self, query: Query) -> pa.Schema:
async_query = self._sync_query_to_async(query)
return await async_query.output_schema()
async def _do_merge(
self,
merge: LanceMergeInsertBuilder,

View File

@@ -1298,6 +1298,79 @@ async def test_query_serialization_async(table_async: AsyncTable):
)
def test_query_schema(tmp_path):
    """Sync query builders report the projected output schema without
    executing the query."""
    db = lancedb.connect(tmp_path)
    data = pa.table(
        {
            "a": [1, 2, 3],
            "text": ["a", "b", "c"],
            "vec": pa.array(
                [[1, 2], [3, 4], [5, 6]], pa.list_(pa.float32(), list_size=2)
            ),
        }
    )
    tbl = db.create_table("test", data)

    # With no filter/projection the full table schema is reported.
    full_schema = pa.schema(
        {
            "a": pa.int64(),
            "text": pa.string(),
            "vec": pa.list_(pa.float32(), list_size=2),
        }
    )
    assert tbl.search(None).output_schema() == full_schema

    # A dynamic projection exposes only the computed column.
    projected = tbl.search(None).select({"bl": "a * 2"}).output_schema()
    assert projected == pa.schema({"bl": pa.int64()})

    # Vector search adds the implicit _distance column.
    vector_schema = tbl.search([1, 2]).select(["a"]).output_schema()
    assert vector_schema == pa.schema({"a": pa.int64(), "_distance": pa.float32()})

    # Full-text search keeps only the selected column.
    fts_schema = tbl.search("blah").select(["a"]).output_schema()
    assert fts_schema == pa.schema({"a": pa.int64()})

    # Take-by-offset queries honor the projection too.
    take_schema = tbl.take_offsets([0]).select(["text"]).output_schema()
    assert take_schema == pa.schema({"text": pa.string()})
@pytest.mark.asyncio
async def test_query_schema_async(tmp_path):
    """Async query builders report the projected output schema without
    executing the query."""
    db = await lancedb.connect_async(tmp_path)
    data = pa.table(
        {
            "a": [1, 2, 3],
            "text": ["a", "b", "c"],
            "vec": pa.array(
                [[1, 2], [3, 4], [5, 6]], pa.list_(pa.float32(), list_size=2)
            ),
        }
    )
    tbl = await db.create_table("test", data)

    # With no projection the full table schema is reported.
    full_schema = pa.schema(
        {
            "a": pa.int64(),
            "text": pa.string(),
            "vec": pa.list_(pa.float32(), list_size=2),
        }
    )
    assert await tbl.query().output_schema() == full_schema

    # A dynamic projection exposes only the computed column.
    projected = await tbl.query().select({"bl": "a * 2"}).output_schema()
    assert projected == pa.schema({"bl": pa.int64()})

    # Vector search adds the implicit _distance column.
    vector_schema = await tbl.vector_search([1, 2]).select(["a"]).output_schema()
    assert vector_schema == pa.schema({"a": pa.int64(), "_distance": pa.float32()})

    # Full-text search keeps only the selected column.
    fts_query = await tbl.search("blah")
    assert await fts_query.select(["a"]).output_schema() == pa.schema(
        {"a": pa.int64()}
    )

    # Take-by-offset queries honor the projection too.
    take_schema = await tbl.take_offsets([0]).select(["text"]).output_schema()
    assert take_schema == pa.schema({"text": pa.string()})
def test_query_timeout(tmp_path):
# Use local directory instead of memory:// to add a bit of latency to
# operations so a timeout of zero will trigger exceptions.

View File

@@ -9,6 +9,7 @@ use arrow::array::Array;
use arrow::array::ArrayData;
use arrow::pyarrow::FromPyArrow;
use arrow::pyarrow::IntoPyArrow;
use arrow::pyarrow::ToPyArrow;
use lancedb::index::scalar::{
BooleanQuery, BoostQuery, FtsQuery, FullTextSearchQuery, MatchQuery, MultiMatchQuery, Occur,
Operator, PhraseQuery,
@@ -30,6 +31,7 @@ use pyo3::IntoPyObject;
use pyo3::PyAny;
use pyo3::PyRef;
use pyo3::PyResult;
use pyo3::Python;
use pyo3::{exceptions::PyRuntimeError, FromPyObject};
use pyo3::{
exceptions::{PyNotImplementedError, PyValueError},
@@ -445,6 +447,15 @@ impl Query {
})
}
/// Return the output schema of the query without executing it.
///
/// Produces a Python awaitable that resolves to a `pyarrow.Schema`.
#[pyo3(signature = ())]
pub fn output_schema(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
// Clone the inner query so the async block owns it and does not borrow `self_`.
let inner = self_.inner.clone();
future_into_py(self_.py(), async move {
let schema = inner.output_schema().await.infer_error()?;
// Re-acquire the GIL to convert the Arrow schema into a pyarrow object.
Python::with_gil(|py| schema.to_pyarrow(py))
})
}
#[pyo3(signature = (max_batch_length=None, timeout=None))]
pub fn execute(
self_: PyRef<'_, Self>,
@@ -515,6 +526,15 @@ impl TakeQuery {
self.inner = self.inner.clone().with_row_id();
}
/// Return the output schema of this take query without executing it.
///
/// Produces a Python awaitable that resolves to a `pyarrow.Schema`.
#[pyo3(signature = ())]
pub fn output_schema(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
// Clone the inner query so the async block owns it and does not borrow `self_`.
let inner = self_.inner.clone();
future_into_py(self_.py(), async move {
let schema = inner.output_schema().await.infer_error()?;
// Re-acquire the GIL to convert the Arrow schema into a pyarrow object.
Python::with_gil(|py| schema.to_pyarrow(py))
})
}
#[pyo3(signature = (max_batch_length=None, timeout=None))]
pub fn execute(
self_: PyRef<'_, Self>,
@@ -601,6 +621,15 @@ impl FTSQuery {
self.inner = self.inner.clone().postfilter();
}
/// Return the output schema of this full-text-search query without executing it.
///
/// Produces a Python awaitable that resolves to a `pyarrow.Schema`.
#[pyo3(signature = ())]
pub fn output_schema(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
// Clone the inner query so the async block owns it and does not borrow `self_`.
let inner = self_.inner.clone();
future_into_py(self_.py(), async move {
let schema = inner.output_schema().await.infer_error()?;
// Re-acquire the GIL to convert the Arrow schema into a pyarrow object.
Python::with_gil(|py| schema.to_pyarrow(py))
})
}
#[pyo3(signature = (max_batch_length=None, timeout=None))]
pub fn execute(
self_: PyRef<'_, Self>,
@@ -771,6 +800,15 @@ impl VectorQuery {
self.inner = self.inner.clone().bypass_vector_index()
}
/// Return the output schema of this vector query without executing it.
///
/// Produces a Python awaitable that resolves to a `pyarrow.Schema`.
#[pyo3(signature = ())]
pub fn output_schema(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
// Clone the inner query so the async block owns it and does not borrow `self_`.
let inner = self_.inner.clone();
future_into_py(self_.py(), async move {
let schema = inner.output_schema().await.infer_error()?;
// Re-acquire the GIL to convert the Arrow schema into a pyarrow object.
Python::with_gil(|py| schema.to_pyarrow(py))
})
}
#[pyo3(signature = (max_batch_length=None, timeout=None))]
pub fn execute(
self_: PyRef<'_, Self>,