diff --git a/docs/src/js/classes/Query.md b/docs/src/js/classes/Query.md
index 7e467b85..c10c2f28 100644
--- a/docs/src/js/classes/Query.md
+++ b/docs/src/js/classes/Query.md
@@ -343,6 +343,29 @@ This is useful for pagination.
 
 ***
 
+### outputSchema()
+
+```ts
+outputSchema(): Promise<Schema<any>>
+```
+
+Returns the schema of the results this query will produce.
+
+This can be used to inspect the names and types of the columns the query
+will return before executing it.
+
+#### Returns
+
+`Promise`<`Schema`<`any`>>
+
+An Arrow Schema describing the output columns.
+
+#### Inherited from
+
+`StandardQueryBase.outputSchema`
+
+***
+
 ### select()
 
 ```ts
diff --git a/docs/src/js/classes/QueryBase.md b/docs/src/js/classes/QueryBase.md
index b177b9e0..91aae97b 100644
--- a/docs/src/js/classes/QueryBase.md
+++ b/docs/src/js/classes/QueryBase.md
@@ -140,6 +140,25 @@ const plan = await table.query().nearestTo([0.5, 0.2]).explainPlan();
 
 ***
 
+### outputSchema()
+
+```ts
+outputSchema(): Promise<Schema<any>>
+```
+
+Returns the schema of the results this query will produce.
+
+This can be used to inspect the names and types of the columns the query
+will return before executing it.
+
+#### Returns
+
+`Promise`<`Schema`<`any`>>
+
+An Arrow Schema describing the output columns.
+
+***
+
 ### select()
 
 ```ts
diff --git a/docs/src/js/classes/TakeQuery.md b/docs/src/js/classes/TakeQuery.md
index cda76fd5..4b1d168d 100644
--- a/docs/src/js/classes/TakeQuery.md
+++ b/docs/src/js/classes/TakeQuery.md
@@ -143,6 +143,29 @@ const plan = await table.query().nearestTo([0.5, 0.2]).explainPlan();
 
 ***
 
+### outputSchema()
+
+```ts
+outputSchema(): Promise<Schema<any>>
+```
+
+Returns the schema of the results this query will produce.
+
+This can be used to inspect the names and types of the columns the query
+will return before executing it.
+
+#### Returns
+
+`Promise`<`Schema`<`any`>>
+
+An Arrow Schema describing the output columns.
+
+#### Inherited from
+
+[`QueryBase`](QueryBase.md).[`outputSchema`](QueryBase.md#outputschema)
+
+***
+
 ### select()
 
 ```ts
diff --git a/docs/src/js/classes/VectorQuery.md b/docs/src/js/classes/VectorQuery.md
index f935cd21..05554f91 100644
--- a/docs/src/js/classes/VectorQuery.md
+++ b/docs/src/js/classes/VectorQuery.md
@@ -498,6 +498,29 @@ This is useful for pagination.
 
 ***
 
+### outputSchema()
+
+```ts
+outputSchema(): Promise<Schema<any>>
+```
+
+Returns the schema of the results this query will produce.
+
+This can be used to inspect the names and types of the columns the query
+will return before executing it.
+
+#### Returns
+
+`Promise`<`Schema`<`any`>>
+
+An Arrow Schema describing the output columns.
+
+#### Inherited from
+
+`StandardQueryBase.outputSchema`
+
+***
+
 ### postfilter()
 
 ```ts
diff --git a/docs/src/js/interfaces/IvfRqOptions.md b/docs/src/js/interfaces/IvfRqOptions.md
new file mode 100644
index 00000000..6b48cf65
--- /dev/null
+++ b/docs/src/js/interfaces/IvfRqOptions.md
@@ -0,0 +1,101 @@
+[**@lancedb/lancedb**](../README.md) • **Docs**
+
+***
+
+[@lancedb/lancedb](../globals.md) / IvfRqOptions
+
+# Interface: IvfRqOptions
+
+## Properties
+
+### distanceType?
+
+```ts
+optional distanceType: "l2" | "cosine" | "dot";
+```
+
+Distance type to use to build the index.
+
+Default value is "l2".
+
+This is used when training the index to calculate the IVF partitions
+(vectors are grouped in partitions with similar vectors according to this
+distance type) and during quantization.
+
+The distance type used to train an index MUST match the distance type used
+to search the index. Failure to do so will yield inaccurate results.
+
+The following distance types are available:
+
+"l2" - Euclidean distance.
+"cosine" - Cosine distance.
+"dot" - Dot product.
+
+***
+
+### maxIterations?
+
+```ts
+optional maxIterations: number;
+```
+
+Max iterations to train IVF kmeans.
+
+When training an IVF index we use kmeans to calculate the partitions. This parameter
+controls how many iterations of kmeans to run.
+
+The default value is 50.
+
+***
+
+### numBits?
+
+```ts
+optional numBits: number;
+```
+
+Number of bits per dimension for residual quantization.
+
+This value controls how much each residual component is compressed. The more
+bits, the more accurate the index will be, but the slower the search. Typical
+values are small integers; the default is 1 bit per dimension.
+
+***
+
+### numPartitions?
+
+```ts
+optional numPartitions: number;
+```
+
+The number of IVF partitions to create.
+
+This value should generally scale with the number of rows in the dataset.
+By default the number of partitions is the square root of the number of
+rows.
+
+If this value is too large then the first part of the search (picking the
+right partition) will be slow. If this value is too small then the second
+part of the search (searching within a partition) will be slow.
+
+***
+
+### sampleRate?
+
+```ts
+optional sampleRate: number;
+```
+
+The number of vectors, per partition, to sample when training IVF kmeans.
+
+When an IVF index is trained, we need to calculate partitions. These are groups
+of vectors that are similar to each other. To do this we use an algorithm called kmeans.
+
+Running kmeans on a large dataset can be slow. To speed this up we run kmeans on a
+random sample of the data. This parameter controls the size of the sample. The total
+number of vectors used to train the index is `sampleRate * numPartitions`.
+
+Increasing this value might improve the quality of the index but in most cases the
+default should be sufficient.
+
+The default value is 256.
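Taken together, these options configure an IVF_RQ vector index. A minimal sketch of how they might be passed at index creation; the `Index.ivfRq` factory and the root-package import are assumptions modeled on the existing `Index.ivfPq` and `Index.fts` factories, not something this diff confirms:

```ts
// Hypothetical sketch: creating an IVF_RQ index with IvfRqOptions.
// `Index.ivfRq` and the root export are assumed; adjust to the released API.
import { connect, Index } from "@lancedb/lancedb";

async function buildIvfRqIndex(dbPath: string): Promise<void> {
  const db = await connect(dbPath);
  const table = await db.openTable("vectors"); // assumed existing table
  await table.createIndex("vec", {
    config: Index.ivfRq({
      distanceType: "cosine", // must match the distance type used at search time
      numPartitions: 256, // roughly sqrt(row count) is the default heuristic
      numBits: 1, // bits per residual dimension; more bits = better recall, slower search
      sampleRate: 256, // vectors sampled per partition for kmeans training
      maxIterations: 50, // kmeans iteration cap
    }),
  });
}
```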
diff --git a/nodejs/__test__/query.test.ts b/nodejs/__test__/query.test.ts new file mode 100644 index 00000000..a3d85dc3 --- /dev/null +++ b/nodejs/__test__/query.test.ts @@ -0,0 +1,111 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The LanceDB Authors + +import * as tmp from "tmp"; + +import { type Table, connect } from "../lancedb"; +import { + Field, + FixedSizeList, + Float32, + Int64, + Schema, + Utf8, + makeArrowTable, +} from "../lancedb/arrow"; +import { Index } from "../lancedb/indices"; + +describe("Query outputSchema", () => { + let tmpDir: tmp.DirResult; + let table: Table; + + beforeEach(async () => { + tmpDir = tmp.dirSync({ unsafeCleanup: true }); + const db = await connect(tmpDir.name); + + // Create table with explicit schema to ensure proper types + const schema = new Schema([ + new Field("a", new Int64(), true), + new Field("text", new Utf8(), true), + new Field( + "vec", + new FixedSizeList(2, new Field("item", new Float32())), + true, + ), + ]); + + const data = makeArrowTable( + [ + { a: 1n, text: "foo", vec: [1, 2] }, + { a: 2n, text: "bar", vec: [3, 4] }, + { a: 3n, text: "baz", vec: [5, 6] }, + ], + { schema }, + ); + table = await db.createTable("test", data); + }); + + afterEach(() => { + tmpDir.removeCallback(); + }); + + it("should return schema for plain query", async () => { + const schema = await table.query().outputSchema(); + + expect(schema.fields.length).toBe(3); + expect(schema.fields.map((f) => f.name)).toEqual(["a", "text", "vec"]); + expect(schema.fields[0].type.toString()).toBe("Int64"); + expect(schema.fields[1].type.toString()).toBe("Utf8"); + }); + + it("should return schema with dynamic projection", async () => { + const schema = await table.query().select({ bl: "a * 2" }).outputSchema(); + + expect(schema.fields.length).toBe(1); + expect(schema.fields[0].name).toBe("bl"); + expect(schema.fields[0].type.toString()).toBe("Int64"); + }); + + it("should return schema for vector search with _distance column", async () => { + const schema = await table + .vectorSearch([1, 2]) + .select(["a"]) + .outputSchema(); + + expect(schema.fields.length).toBe(2); + expect(schema.fields.map((f) => f.name)).toEqual(["a", "_distance"]); + expect(schema.fields[0].type.toString()).toBe("Int64"); + expect(schema.fields[1].type.toString()).toBe("Float32"); + }); + + it("should return schema for FTS search", async () => { + await table.createIndex("text", { config: Index.fts() }); + + const schema = await table + .search("foo", "fts") + .select(["a"]) + .outputSchema(); + + // FTS search includes _score column in addition to selected columns + expect(schema.fields.length).toBe(2); + expect(schema.fields.map((f) => f.name)).toContain("a"); + expect(schema.fields.map((f) => f.name)).toContain("_score"); + const aField = schema.fields.find((f) => f.name === "a"); + expect(aField?.type.toString()).toBe("Int64"); + }); + + it("should return schema for take query", async () => { + const schema = await table.takeOffsets([0]).select(["text"]).outputSchema(); + + expect(schema.fields.length).toBe(1); + expect(schema.fields[0].name).toBe("text"); + expect(schema.fields[0].type.toString()).toBe("Utf8"); + }); + + it("should return full schema when no select is specified", async () => { + const schema = await table.query().outputSchema(); + + // Should return all columns + expect(schema.fields.length).toBe(3); + }); +}); diff --git a/nodejs/lancedb/query.ts b/nodejs/lancedb/query.ts index 2fbe48b8..19c87c70 100644 --- a/nodejs/lancedb/query.ts +++ 
b/nodejs/lancedb/query.ts
@@ -326,6 +326,25 @@ export class QueryBase<
       return this.inner.analyzePlan();
     }
   }
+
+  /**
+   * Returns the schema of the results this query will produce.
+   *
+   * This can be used to inspect the names and types of the columns the query
+   * will return before executing it.
+   *
+   * @returns An Arrow Schema describing the output columns.
+   */
+  async outputSchema(): Promise<Schema> {
+    let schemaBuffer: Buffer;
+    if (this.inner instanceof Promise) {
+      schemaBuffer = await this.inner.then((inner) => inner.outputSchema());
+    } else {
+      schemaBuffer = await this.inner.outputSchema();
+    }
+    const schema = tableFromIPC(schemaBuffer).schema;
+    return schema;
+  }
 }
 
 export class StandardQueryBase<
diff --git a/nodejs/src/query.rs b/nodejs/src/query.rs
index 32b43482..e4a4dbfb 100644
--- a/nodejs/src/query.rs
+++ b/nodejs/src/query.rs
@@ -22,7 +22,7 @@ use crate::error::NapiErrorExt;
 use crate::iterator::RecordBatchIterator;
 use crate::rerankers::Reranker;
 use crate::rerankers::RerankerCallbacks;
-use crate::util::parse_distance_type;
+use crate::util::{parse_distance_type, schema_to_buffer};
 
 #[napi]
 pub struct Query {
@@ -88,6 +88,12 @@ impl Query {
         self.inner = self.inner.clone().with_row_id();
     }
 
+    #[napi(catch_unwind)]
+    pub async fn output_schema(&self) -> napi::Result<Buffer> {
+        let schema = self.inner.output_schema().await.default_error()?;
+        schema_to_buffer(&schema)
+    }
+
     #[napi(catch_unwind)]
     pub async fn execute(
         &self,
@@ -273,6 +279,12 @@ impl VectorQuery {
             .rerank(Arc::new(Reranker::new(callbacks)));
     }
 
+    #[napi(catch_unwind)]
+    pub async fn output_schema(&self) -> napi::Result<Buffer> {
+        let schema = self.inner.output_schema().await.default_error()?;
+        schema_to_buffer(&schema)
+    }
+
     #[napi(catch_unwind)]
     pub async fn execute(
         &self,
@@ -346,6 +358,12 @@ impl TakeQuery {
         self.inner = self.inner.clone().with_row_id();
     }
 
+    #[napi(catch_unwind)]
+    pub async fn output_schema(&self) -> napi::Result<Buffer> {
+        let schema = self.inner.output_schema().await.default_error()?;
+        schema_to_buffer(&schema)
+    }
+
     #[napi(catch_unwind)]
     pub async fn execute(
         &self,
diff --git a/nodejs/src/table.rs b/nodejs/src/table.rs
index b1f037fe..56517e76 100644
--- a/nodejs/src/table.rs
+++ b/nodejs/src/table.rs
@@ -3,7 +3,6 @@
 
 use std::collections::HashMap;
 
-use arrow_ipc::writer::FileWriter;
 use lancedb::ipc::ipc_file_to_batches;
 use lancedb::table::{
     AddDataMode, ColumnAlteration as LanceColumnAlteration, Duration, NewColumnTransform,
@@ -16,6 +15,7 @@ use crate::error::NapiErrorExt;
 use crate::index::Index;
 use crate::merge::NativeMergeInsertBuilder;
 use crate::query::{Query, TakeQuery, VectorQuery};
+use crate::util::schema_to_buffer;
 
 #[napi]
 pub struct Table {
@@ -64,14 +64,7 @@ impl Table {
     #[napi(catch_unwind)]
     pub async fn schema(&self) -> napi::Result<Buffer> {
         let schema = self.inner_ref()?.schema().await.default_error()?;
-        let mut writer = FileWriter::try_new(vec![], &schema)
-            .map_err(|e| napi::Error::from_reason(format!("Failed to create IPC file: {}", e)))?;
-        writer
-            .finish()
-            .map_err(|e| napi::Error::from_reason(format!("Failed to finish IPC file: {}", e)))?;
-        Ok(Buffer::from(writer.into_inner().map_err(|e| {
-            napi::Error::from_reason(format!("Failed to get IPC file: {}", e))
-        })?))
+        schema_to_buffer(&schema)
     }
 
     #[napi(catch_unwind)]
diff --git a/nodejs/src/util.rs b/nodejs/src/util.rs
index a29a67f9..ba9cf40c 100644
--- a/nodejs/src/util.rs
+++ b/nodejs/src/util.rs
@@ -1,7 +1,10 @@
 // SPDX-License-Identifier: Apache-2.0
 // SPDX-FileCopyrightText: Copyright The LanceDB Authors
 
+use arrow_ipc::writer::FileWriter;
+use arrow_schema::Schema;
 use lancedb::DistanceType;
+use napi::bindgen_prelude::Buffer;
 
 pub fn parse_distance_type(distance_type: impl AsRef<str>) -> napi::Result<DistanceType> {
     match distance_type.as_ref().to_lowercase().as_str() {
@@ -15,3 +18,15 @@ pub fn parse_distance_type(distance_type: impl AsRef<str>) -> napi::Result<Dist
         )),
     }
 }
+
+pub fn schema_to_buffer(schema: &Schema) -> napi::Result<Buffer> {
+    let mut writer = FileWriter::try_new(vec![], schema)
+        .map_err(|e| napi::Error::from_reason(format!("Failed to create IPC file: {}", e)))?;
+    writer
+        .finish()
+        .map_err(|e| napi::Error::from_reason(format!("Failed to finish IPC file: {}", e)))?;
+    Ok(Buffer::from(writer.into_inner().map_err(|e| {
+        napi::Error::from_reason(format!("Failed to get IPC file: {}", e))
+    })?))
+}
diff --git a/python/python/lancedb/_lancedb.pyi b/python/python/lancedb/_lancedb.pyi
index 378c4a09..1b8eb0d8 100644
--- a/python/python/lancedb/_lancedb.pyi
+++ b/python/python/lancedb/_lancedb.pyi
@@ -123,6 +123,8 @@ class Table:
     @property
     def tags(self) -> Tags: ...
     def query(self) -> Query: ...
+    def take_offsets(self, offsets: list[int]) -> TakeQuery: ...
+    def take_row_ids(self, row_ids: list[int]) -> TakeQuery: ...
    def vector_search(self) -> VectorQuery: ...
 
 class Tags:
@@ -165,6 +167,7 @@ class Query:
     def postfilter(self): ...
     def nearest_to(self, query_vec: pa.Array) -> VectorQuery: ...
     def nearest_to_text(self, query: dict) -> FTSQuery: ...
+    async def output_schema(self) -> pa.Schema: ...
     async def execute(
         self, max_batch_length: Optional[int], timeout: Optional[timedelta]
     ) -> RecordBatchStream: ...
@@ -172,6 +175,13 @@ class Query:
     async def analyze_plan(self) -> str: ...
     def to_query_request(self) -> PyQueryRequest: ...
 
+class TakeQuery:
+    def select(self, columns: List[str]): ...
+    def with_row_id(self): ...
+    async def output_schema(self) -> pa.Schema: ...
+    async def execute(self) -> RecordBatchStream: ...
+    def to_query_request(self) -> PyQueryRequest: ...
+
 class FTSQuery:
     def where(self, filter: str): ...
     def select(self, columns: List[str]): ...
@@ -183,12 +193,14 @@ class FTSQuery:
     def get_query(self) -> str: ...
     def add_query_vector(self, query_vec: pa.Array) -> None: ...
     def nearest_to(self, query_vec: pa.Array) -> HybridQuery: ...
+    async def output_schema(self) -> pa.Schema: ...
     async def execute(
         self, max_batch_length: Optional[int], timeout: Optional[timedelta]
     ) -> RecordBatchStream: ...
     def to_query_request(self) -> PyQueryRequest: ...
 
 class VectorQuery:
+    async def output_schema(self) -> pa.Schema: ...
     async def execute(self) -> RecordBatchStream: ...
     def where(self, filter: str): ...
     def select(self, columns: List[str]): ...
diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py
index 6985ed45..0b2b842b 100644
--- a/python/python/lancedb/query.py
+++ b/python/python/lancedb/query.py
@@ -1237,6 +1237,14 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
         self._refine_factor = refine_factor
         return self
 
+    def output_schema(self) -> pa.Schema:
+        """
+        Return the output schema for the query.
+
+        This does not execute the query.
+        """
+        return self._table._output_schema(self.to_query_object())
+
     def to_arrow(self, *, timeout: Optional[timedelta] = None) -> pa.Table:
         """
         Execute the query and return the results as an
@@ -1452,6 +1460,14 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
             offset=self._offset,
         )
 
+    def output_schema(self) -> pa.Schema:
+        """
+        Return the output schema for the query.
+
+        This does not execute the query.
+        """
+        return self._table._output_schema(self.to_query_object())
+
     def to_arrow(self, *, timeout: Optional[timedelta] = None) -> pa.Table:
         path, fs, exist = self._table._get_fts_index_path()
         if exist:
@@ -1595,6 +1611,10 @@ class LanceEmptyQueryBuilder(LanceQueryBuilder):
             offset=self._offset,
         )
 
+    def output_schema(self) -> pa.Schema:
+        query = self.to_query_object()
+        return self._table._output_schema(query)
+
     def to_batches(
         self, /, batch_size: Optional[int] = None, timeout: Optional[timedelta] = None
     ) -> pa.RecordBatchReader:
@@ -2238,6 +2258,14 @@ class AsyncQueryBase(object):
             )
         )
 
+    async def output_schema(self) -> pa.Schema:
+        """
+        Return the output schema for the query.
+
+        This does not execute the query.
+        """
+        return await self._inner.output_schema()
+
     async def to_arrow(self, timeout: Optional[timedelta] = None) -> pa.Table:
         """
         Execute the query and collect the results into an Apache Arrow Table.
@@ -3193,6 +3221,14 @@ class BaseQueryBuilder(object):
         self._inner.with_row_id()
         return self
 
+    def output_schema(self) -> pa.Schema:
+        """
+        Return the output schema for the query.
+
+        This does not execute the query.
+        """
+        return LOOP.run(self._inner.output_schema())
+
     def to_batches(
         self,
         *,
diff --git a/python/python/lancedb/remote/table.py b/python/python/lancedb/remote/table.py
index 81d46d35..6c83f917 100644
--- a/python/python/lancedb/remote/table.py
+++ b/python/python/lancedb/remote/table.py
@@ -436,6 +436,9 @@ class RemoteTable(Table):
     def _analyze_plan(self, query: Query) -> str:
         return LOOP.run(self._table._analyze_plan(query))
 
+    def _output_schema(self, query: Query) -> pa.Schema:
+        return LOOP.run(self._table._output_schema(query))
+
     def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
         """Returns a [`LanceMergeInsertBuilder`][lancedb.merge.LanceMergeInsertBuilder]
         that can be used to create a "merge insert" operation.
diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py
index 6749064d..4b595099 100644
--- a/python/python/lancedb/table.py
+++ b/python/python/lancedb/table.py
@@ -1248,6 +1248,9 @@ class Table(ABC):
     @abstractmethod
     def _analyze_plan(self, query: Query) -> str: ...
 
+    @abstractmethod
+    def _output_schema(self, query: Query) -> pa.Schema: ...
+ @abstractmethod def _do_merge( self, @@ -2761,6 +2764,9 @@ class LanceTable(Table): def _analyze_plan(self, query: Query) -> str: return LOOP.run(self._table._analyze_plan(query)) + def _output_schema(self, query: Query) -> pa.Schema: + return LOOP.run(self._table._output_schema(query)) + def _do_merge( self, merge: LanceMergeInsertBuilder, @@ -3918,6 +3924,10 @@ class AsyncTable: async_query = self._sync_query_to_async(query) return await async_query.analyze_plan() + async def _output_schema(self, query: Query) -> pa.Schema: + async_query = self._sync_query_to_async(query) + return await async_query.output_schema() + async def _do_merge( self, merge: LanceMergeInsertBuilder, diff --git a/python/python/tests/test_query.py b/python/python/tests/test_query.py index b80984da..bd1905bf 100644 --- a/python/python/tests/test_query.py +++ b/python/python/tests/test_query.py @@ -1298,6 +1298,79 @@ async def test_query_serialization_async(table_async: AsyncTable): ) +def test_query_schema(tmp_path): + db = lancedb.connect(tmp_path) + tbl = db.create_table( + "test", + pa.table( + { + "a": [1, 2, 3], + "text": ["a", "b", "c"], + "vec": pa.array( + [[1, 2], [3, 4], [5, 6]], pa.list_(pa.float32(), list_size=2) + ), + } + ), + ) + + assert tbl.search(None).output_schema() == pa.schema( + { + "a": pa.int64(), + "text": pa.string(), + "vec": pa.list_(pa.float32(), list_size=2), + } + ) + assert tbl.search(None).select({"bl": "a * 2"}).output_schema() == pa.schema( + {"bl": pa.int64()} + ) + assert tbl.search([1, 2]).select(["a"]).output_schema() == pa.schema( + {"a": pa.int64(), "_distance": pa.float32()} + ) + assert tbl.search("blah").select(["a"]).output_schema() == pa.schema( + {"a": pa.int64()} + ) + assert tbl.take_offsets([0]).select(["text"]).output_schema() == pa.schema( + {"text": pa.string()} + ) + + +@pytest.mark.asyncio +async def test_query_schema_async(tmp_path): + db = await lancedb.connect_async(tmp_path) + tbl = await db.create_table( + "test", + pa.table( + { + "a": [1, 2, 3], + "text": ["a", "b", "c"], + "vec": pa.array( + [[1, 2], [3, 4], [5, 6]], pa.list_(pa.float32(), list_size=2) + ), + } + ), + ) + + assert await tbl.query().output_schema() == pa.schema( + { + "a": pa.int64(), + "text": pa.string(), + "vec": pa.list_(pa.float32(), list_size=2), + } + ) + assert await tbl.query().select({"bl": "a * 2"}).output_schema() == pa.schema( + {"bl": pa.int64()} + ) + assert await tbl.vector_search([1, 2]).select(["a"]).output_schema() == pa.schema( + {"a": pa.int64(), "_distance": pa.float32()} + ) + assert await (await tbl.search("blah")).select(["a"]).output_schema() == pa.schema( + {"a": pa.int64()} + ) + assert await tbl.take_offsets([0]).select(["text"]).output_schema() == pa.schema( + {"text": pa.string()} + ) + + def test_query_timeout(tmp_path): # Use local directory instead of memory:// to add a bit of latency to # operations so a timeout of zero will trigger exceptions. 
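For orientation, here is a minimal sync-API sketch that mirrors the tests above; the database path, table name, and columns are illustrative:

```python
# Sketch: inspecting a query's output schema without executing it (sync API).
import lancedb
import pyarrow as pa

db = lancedb.connect("/tmp/lancedb-demo")
tbl = db.create_table(
    "demo",
    pa.table(
        {
            "a": [1, 2, 3],
            "vec": pa.array(
                [[1, 2], [3, 4], [5, 6]], pa.list_(pa.float32(), list_size=2)
            ),
        }
    ),
)

# A vector search appends a _distance column to the selected columns.
schema = tbl.search([1, 2]).select(["a"]).output_schema()
assert schema == pa.schema({"a": pa.int64(), "_distance": pa.float32()})
```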
diff --git a/python/src/query.rs b/python/src/query.rs
index aa285c01..aeb03182 100644
--- a/python/src/query.rs
+++ b/python/src/query.rs
@@ -9,6 +9,7 @@ use arrow::array::Array;
 use arrow::array::ArrayData;
 use arrow::pyarrow::FromPyArrow;
 use arrow::pyarrow::IntoPyArrow;
+use arrow::pyarrow::ToPyArrow;
 use lancedb::index::scalar::{
     BooleanQuery, BoostQuery, FtsQuery, FullTextSearchQuery, MatchQuery, MultiMatchQuery, Occur,
     Operator, PhraseQuery,
@@ -30,6 +31,7 @@ use pyo3::IntoPyObject;
 use pyo3::PyAny;
 use pyo3::PyRef;
 use pyo3::PyResult;
+use pyo3::Python;
 use pyo3::{exceptions::PyRuntimeError, FromPyObject};
 use pyo3::{
     exceptions::{PyNotImplementedError, PyValueError},
@@ -445,6 +447,15 @@ impl Query {
         })
     }
 
+    #[pyo3(signature = ())]
+    pub fn output_schema(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
+        let inner = self_.inner.clone();
+        future_into_py(self_.py(), async move {
+            let schema = inner.output_schema().await.infer_error()?;
+            Python::with_gil(|py| schema.to_pyarrow(py))
+        })
+    }
+
     #[pyo3(signature = (max_batch_length=None, timeout=None))]
     pub fn execute(
         self_: PyRef<'_, Self>,
@@ -515,6 +526,15 @@ impl TakeQuery {
         self.inner = self.inner.clone().with_row_id();
     }
 
+    #[pyo3(signature = ())]
+    pub fn output_schema(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
+        let inner = self_.inner.clone();
+        future_into_py(self_.py(), async move {
+            let schema = inner.output_schema().await.infer_error()?;
+            Python::with_gil(|py| schema.to_pyarrow(py))
+        })
+    }
+
     #[pyo3(signature = (max_batch_length=None, timeout=None))]
     pub fn execute(
         self_: PyRef<'_, Self>,
@@ -601,6 +621,15 @@ impl FTSQuery {
         self.inner = self.inner.clone().postfilter();
     }
 
+    #[pyo3(signature = ())]
+    pub fn output_schema(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
+        let inner = self_.inner.clone();
+        future_into_py(self_.py(), async move {
+            let schema = inner.output_schema().await.infer_error()?;
+            Python::with_gil(|py| schema.to_pyarrow(py))
+        })
+    }
+
     #[pyo3(signature = (max_batch_length=None, timeout=None))]
     pub fn execute(
         self_: PyRef<'_, Self>,
@@ -771,6 +800,15 @@ impl VectorQuery {
         self.inner = self.inner.clone().bypass_vector_index()
     }
 
+    #[pyo3(signature = ())]
+    pub fn output_schema(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
+        let inner = self_.inner.clone();
+        future_into_py(self_.py(), async move {
+            let schema = inner.output_schema().await.infer_error()?;
+            Python::with_gil(|py| schema.to_pyarrow(py))
+        })
+    }
+
     #[pyo3(signature = (max_batch_length=None, timeout=None))]
     pub fn execute(
         self_: PyRef<'_, Self>,
diff --git a/rust/lancedb/src/query.rs b/rust/lancedb/src/query.rs
index 9663379d..34151e3b 100644
--- a/rust/lancedb/src/query.rs
+++ b/rust/lancedb/src/query.rs
@@ -6,10 +6,10 @@ use std::{future::Future, time::Duration};
 
 use arrow::compute::concat_batches;
 use arrow_array::{make_array, Array, Float16Array, Float32Array, Float64Array};
-use arrow_schema::DataType;
+use arrow_schema::{DataType, SchemaRef};
 use datafusion_expr::Expr;
 use datafusion_physical_plan::ExecutionPlan;
-use futures::{stream, try_join, FutureExt, TryStreamExt};
+use futures::{stream, try_join, FutureExt, TryFutureExt, TryStreamExt};
 use half::f16;
 use lance::{
     arrow::RecordBatchExt,
@@ -582,16 +582,40 @@ pub trait ExecutableQuery {
         options: QueryExecutionOptions,
     ) -> impl Future<Output = Result<DatasetRecordBatchStream>> + Send;
 
+    /// Explain the plan for a query
+    ///
+    /// This will create a string representation of the plan that will be used to
+    /// execute the query. It does not execute the query.
+    ///
+    /// This function can be used to get an understanding of what work will be done by the query
+    /// and is useful for debugging query performance.
     fn explain_plan(&self, verbose: bool) -> impl Future<Output = Result<String>> + Send;
 
+    /// Execute the query and display the runtime metrics
+    ///
+    /// This shows the same plan as [`ExecutableQuery::explain_plan`] but includes runtime metrics.
+    ///
+    /// This function will actually execute the query in order to get the runtime metrics.
     fn analyze_plan(&self) -> impl Future<Output = Result<String>> + Send {
         self.analyze_plan_with_options(QueryExecutionOptions::default())
     }
 
+    /// Execute the query and display the runtime metrics
+    ///
+    /// This is the same as [`ExecutableQuery::analyze_plan`] but allows specifying the execution options.
     fn analyze_plan_with_options(
         &self,
         options: QueryExecutionOptions,
     ) -> impl Future<Output = Result<String>> + Send;
+
+    /// Return the schema of the data the query will return, without actually executing the query
+    ///
+    /// This can be useful when the selection for a query is built dynamically as it is not always
+    /// obvious what the output schema will be.
+    fn output_schema(&self) -> impl Future<Output = Result<SchemaRef>> + Send {
+        self.create_plan(QueryExecutionOptions::default())
+            .and_then(|plan| std::future::ready(Ok(plan.schema())))
+    }
 }
 
 /// A query filter that can be applied to a query
@@ -1505,6 +1529,16 @@ mod tests {
             .query()
             .limit(10)
             .select(Select::dynamic(&[("id2", "id * 2"), ("id", "id")]));
+
+        let schema = query.output_schema().await.unwrap();
+        assert_eq!(
+            schema,
+            Arc::new(ArrowSchema::new(vec![
+                ArrowField::new("id2", DataType::Int32, true),
+                ArrowField::new("id", DataType::Int32, true),
+            ]))
+        );
+
         let result = query.execute().await;
         let mut batches = result
             .expect("should have result")
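Because `output_schema` is a provided method on `ExecutableQuery`, every query builder picks it up automatically. A usage sketch, assuming an already-open `lancedb::Table` and illustrative column expressions (the `lancedb::Result` alias and trait paths follow the crate's public exports):

```rust
// Sketch: checking the shape of a dynamic projection before running the query.
// Assumes an open `lancedb::Table`; column names are illustrative.
use lancedb::query::{ExecutableQuery, QueryBase, Select};

async fn preview_projection(table: &lancedb::Table) -> lancedb::Result<()> {
    let query = table
        .query()
        .select(Select::dynamic(&[("id2", "id * 2"), ("id", "id")]));

    // Plans the query (including dynamic expressions) but does not execute it.
    let schema = query.output_schema().await?;
    for field in schema.fields() {
        println!("{}: {}", field.name(), field.data_type());
    }
    Ok(())
}
```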