diff --git a/nodejs/__test__/table.test.ts b/nodejs/__test__/table.test.ts index 1c3fb4ab..8ed9ec8b 100644 --- a/nodejs/__test__/table.test.ts +++ b/nodejs/__test__/table.test.ts @@ -745,3 +745,27 @@ describe("table.search", () => { expect(results[0].text).toBe(data[1].text); }); }); + +describe("when calling explainPlan", () => { + let tmpDir: tmp.DirResult; + let table: Table; + let queryVec: number[]; + beforeEach(async () => { + tmpDir = tmp.dirSync({ unsafeCleanup: true }); + const con = await connect(tmpDir.name); + table = await con.createTable("vectors", [{ id: 1, vector: [0.1, 0.2] }]); + }); + + afterEach(() => { + tmpDir.removeCallback(); + }); + + it("retrieves query plan", async () => { + queryVec = Array(2) + .fill(1) + .map(() => Math.random()); + const plan = await table.query().nearestTo(queryVec).explainPlan(true); + + expect(plan).toMatch("KNN"); + }); +}); diff --git a/nodejs/lancedb/query.ts b/nodejs/lancedb/query.ts index 97829cda..48cc7693 100644 --- a/nodejs/lancedb/query.ts +++ b/nodejs/lancedb/query.ts @@ -226,6 +226,24 @@ export class QueryBase< const tbl = await this.toArrow(options); return tbl.toArray(); } + + /** + * Generates an explanation of the query execution plan. + * + * @example + * import * as lancedb from "@lancedb/lancedb" + * const db = await lancedb.connect("./.lancedb"); + * const table = await db.createTable("my_table", [ + * { vector: [1.1, 0.9], id: "1" }, + * ]); + * const plan = await table.query().nearestTo([0.5, 0.2]).explainPlan(); + * + * @param verbose - If true, provides a more detailed explanation. Defaults to false. + * @returns A Promise that resolves to a string containing the query execution plan explanation. + */ + async explainPlan(verbose = false): Promise { + return await this.inner.explainPlan(verbose); + } } /** diff --git a/nodejs/src/query.rs b/nodejs/src/query.rs index 012cd125..68b6511d 100644 --- a/nodejs/src/query.rs +++ b/nodejs/src/query.rs @@ -80,6 +80,13 @@ impl Query { })?; Ok(RecordBatchIterator::new(inner_stream)) } + + #[napi] + pub async fn explain_plan(&self, verbose: bool) -> napi::Result { + self.inner.explain_plan(verbose).await.map_err(|e| { + napi::Error::from_reason(format!("Failed to retrieve the query plan: {}", e)) + }) + } } #[napi] @@ -154,4 +161,11 @@ impl VectorQuery { })?; Ok(RecordBatchIterator::new(inner_stream)) } + + #[napi] + pub async fn explain_plan(&self, verbose: bool) -> napi::Result { + self.inner.explain_plan(verbose).await.map_err(|e| { + napi::Error::from_reason(format!("Failed to retrieve the query plan: {}", e)) + }) + } } diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py index 35201de9..c50810a0 100644 --- a/python/python/lancedb/query.py +++ b/python/python/lancedb/query.py @@ -417,6 +417,38 @@ class LanceQueryBuilder(ABC): self._with_row_id = with_row_id return self + def explain_plan(self, verbose: Optional[bool] = False) -> str: + """Return the execution plan for this query. + + Examples + -------- + >>> import lancedb + >>> db = lancedb.connect("./.lancedb") + >>> table = db.create_table("my_table", [{"vector": [99, 99]}]) + >>> query = [100, 100] + >>> plan = table.search(query).explain_plan(True) + >>> print(plan) # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE + Projection: fields=[vector, _distance] + KNNFlat: k=10 metric=l2 + LanceScan: uri=..., projection=[vector], row_id=true, row_addr=false, ordered=false + + Parameters + ---------- + verbose : bool, default False + Use a verbose output format. + + Returns + ------- + plan : str + """ # noqa: E501 + ds = self._table.to_lance() + return ds.scanner( + nearest={ + "column": self._vector_column, + "q": self._query, + }, + ).explain_plan(verbose) + class LanceVectorQueryBuilder(LanceQueryBuilder): """ @@ -1166,6 +1198,35 @@ class AsyncQueryBase(object): """ return (await self.to_arrow()).to_pandas() + async def explain_plan(self, verbose: Optional[bool] = False): + """Return the execution plan for this query. + + Examples + -------- + >>> import asyncio + >>> from lancedb import connect_async + >>> async def doctest_example(): + ... conn = await connect_async("./.lancedb") + ... table = await conn.create_table("my_table", [{"vector": [99, 99]}]) + ... query = [100, 100] + ... plan = await table.query().nearest_to([1, 2]).explain_plan(True) + ... print(plan) + >>> asyncio.run(doctest_example()) # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE + Projection: fields=[vector, _distance] + KNNFlat: k=10 metric=l2 + LanceScan: uri=..., projection=[vector], row_id=true, row_addr=false, ordered=false + + Parameters + ---------- + verbose : bool, default False + Use a verbose output format. + + Returns + ------- + plan : str + """ # noqa: E501 + return await self._inner.explain_plan(verbose) + class AsyncQuery(AsyncQueryBase): def __init__(self, inner: LanceQuery): diff --git a/python/python/tests/test_query.py b/python/python/tests/test_query.py index 6f047bd3..89c5530e 100644 --- a/python/python/tests/test_query.py +++ b/python/python/tests/test_query.py @@ -333,3 +333,15 @@ async def test_query_to_pandas_async(table_async: AsyncTable): df = await table_async.query().where("id < 0").to_pandas() assert df.shape == (0, 4) + + +def test_explain_plan(table): + q = LanceVectorQueryBuilder(table, [0, 0], "vector") + plan = q.explain_plan(verbose=True) + assert "KNN" in plan + + +@pytest.mark.asyncio +async def test_explain_plan_async(table_async: AsyncTable): + plan = await table_async.query().nearest_to(pa.array([1, 2])).explain_plan(True) + assert "KNN" in plan diff --git a/python/src/query.rs b/python/src/query.rs index fadf2aeb..1cdedc66 100644 --- a/python/src/query.rs +++ b/python/src/query.rs @@ -19,6 +19,7 @@ use lancedb::query::QueryExecutionOptions; use lancedb::query::{ ExecutableQuery, Query as LanceDbQuery, QueryBase, Select, VectorQuery as LanceDbVectorQuery, }; +use pyo3::exceptions::PyRuntimeError; use pyo3::pyclass; use pyo3::pymethods; use pyo3::PyAny; @@ -73,6 +74,16 @@ impl Query { Ok(RecordBatchStream::new(inner_stream)) }) } + + fn explain_plan(self_: PyRef<'_, Self>, verbose: bool) -> PyResult<&PyAny> { + let inner = self_.inner.clone(); + future_into_py(self_.py(), async move { + inner + .explain_plan(verbose) + .await + .map_err(|e| PyRuntimeError::new_err(e.to_string())) + }) + } } #[pyclass] @@ -131,4 +142,14 @@ impl VectorQuery { Ok(RecordBatchStream::new(inner_stream)) }) } + + fn explain_plan(self_: PyRef<'_, Self>, verbose: bool) -> PyResult<&PyAny> { + let inner = self_.inner.clone(); + future_into_py(self_.py(), async move { + inner + .explain_plan(verbose) + .await + .map_err(|e| PyRuntimeError::new_err(e.to_string())) + }) + } } diff --git a/rust/lancedb/src/query.rs b/rust/lancedb/src/query.rs index cc915556..87a829d4 100644 --- a/rust/lancedb/src/query.rs +++ b/rust/lancedb/src/query.rs @@ -465,6 +465,8 @@ pub trait ExecutableQuery { &self, options: QueryExecutionOptions, ) -> impl Future> + Send; + + fn explain_plan(&self, verbose: bool) -> impl Future> + Send; } /// A builder for LanceDB queries. @@ -572,6 +574,12 @@ impl ExecutableQuery for Query { self.parent.clone().plain_query(self, options).await?, )) } + + async fn explain_plan(&self, verbose: bool) -> Result { + self.parent + .explain_plan(&self.clone().into_vector(), verbose) + .await + } } /// A builder for vector searches @@ -752,6 +760,10 @@ impl ExecutableQuery for VectorQuery { )?), )) } + + async fn explain_plan(&self, verbose: bool) -> Result { + self.base.parent.explain_plan(self, verbose).await + } } impl HasQuery for VectorQuery { diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs index 1b7c6d20..5add6e6a 100644 --- a/rust/lancedb/src/remote/table.rs +++ b/rust/lancedb/src/remote/table.rs @@ -1,10 +1,12 @@ use std::sync::Arc; +use crate::table::dataset::DatasetReadGuard; use arrow_array::RecordBatchReader; use arrow_schema::SchemaRef; use async_trait::async_trait; use datafusion_physical_plan::ExecutionPlan; -use lance::dataset::{scanner::DatasetRecordBatchStream, ColumnAlteration, NewColumnTransform}; +use lance::dataset::scanner::{DatasetRecordBatchStream, Scanner}; +use lance::dataset::{ColumnAlteration, NewColumnTransform}; use crate::{ connection::NoData, @@ -74,6 +76,14 @@ impl TableInternal for RemoteTable { ) -> Result<()> { todo!() } + async fn build_plan( + &self, + _ds_ref: &DatasetReadGuard, + _query: &VectorQuery, + _options: Option, + ) -> Result { + todo!() + } async fn create_plan( &self, _query: &VectorQuery, @@ -81,6 +91,9 @@ impl TableInternal for RemoteTable { ) -> Result> { unimplemented!() } + async fn explain_plan(&self, _query: &VectorQuery, _verbose: bool) -> Result { + todo!() + } async fn plain_query( &self, _query: &Query, diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs index 18c4e592..2e40cb47 100644 --- a/rust/lancedb/src/table.rs +++ b/rust/lancedb/src/table.rs @@ -65,7 +65,7 @@ use crate::query::{ }; use crate::utils::{default_vector_column, PatchReadParam, PatchWriteParam}; -use self::dataset::DatasetConsistencyWrapper; +use self::dataset::{DatasetConsistencyWrapper, DatasetReadGuard}; use self::merge::MergeInsertBuilder; pub(crate) mod dataset; @@ -369,6 +369,12 @@ pub(crate) trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Syn async fn schema(&self) -> Result; /// Count the number of rows in this table. async fn count_rows(&self, filter: Option) -> Result; + async fn build_plan( + &self, + ds_ref: &DatasetReadGuard, + query: &VectorQuery, + options: Option, + ) -> Result; async fn create_plan( &self, query: &VectorQuery, @@ -379,6 +385,7 @@ pub(crate) trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Syn query: &Query, options: QueryExecutionOptions, ) -> Result; + async fn explain_plan(&self, query: &VectorQuery, verbose: bool) -> Result; async fn add( &self, add: AddDataBuilder, @@ -1667,12 +1674,12 @@ impl TableInternal for NativeTable { Ok(()) } - async fn create_plan( + async fn build_plan( &self, + ds_ref: &DatasetReadGuard, query: &VectorQuery, - options: QueryExecutionOptions, - ) -> Result> { - let ds_ref = self.dataset.get().await?; + options: Option, + ) -> Result { let mut scanner: Scanner = ds_ref.scan(); if let Some(query_vector) = query.query_vector.as_ref() { @@ -1684,9 +1691,11 @@ impl TableInternal for NativeTable { let arrow_schema = Schema::from(ds_ref.schema()); default_vector_column(&arrow_schema, Some(query_vector.len() as i32))? }; + let field = ds_ref.schema().field(&column).ok_or(Error::Schema { message: format!("Column {} not found in dataset schema", column), })?; + if let arrow_schema::DataType::FixedSizeList(f, dim) = field.data_type() { if !f.data_type().is_floating() { return Err(Error::InvalidInput { @@ -1698,16 +1707,17 @@ impl TableInternal for NativeTable { } if dim != query_vector.len() as i32 { return Err(Error::InvalidInput { - message: format!( - "The dimension of the query vector does not match with the dimension of the vector column '{}': \ - query dim={}, expected vector dim={}", - column, - query_vector.len(), - dim, - ), - }); + message: format!( + "The dimension of the query vector does not match with the dimension of the vector column '{}': \ + query dim={}, expected vector dim={}", + column, + query_vector.len(), + dim, + ), + }); } } + let query_vector = query_vector.as_primitive::(); scanner.nearest( &column, @@ -1718,10 +1728,26 @@ impl TableInternal for NativeTable { // If there is no vector query, it's ok to not have a limit scanner.limit(query.base.limit.map(|limit| limit as i64), None)?; } + scanner.nprobs(query.nprobes); scanner.use_index(query.use_index); scanner.prefilter(query.prefilter); - scanner.batch_size(options.max_batch_length as usize); + + if let Some(opts) = options { + scanner.batch_size(opts.max_batch_length as usize); + } + + Ok(scanner) + } + + async fn create_plan( + &self, + query: &VectorQuery, + options: QueryExecutionOptions, + ) -> Result> { + let ds_ref = self.dataset.get().await?; + + let mut scanner = self.build_plan(&ds_ref, query, Some(options)).await?; match &query.base.select { Select::Columns(select) => { @@ -1756,6 +1782,16 @@ impl TableInternal for NativeTable { .await } + async fn explain_plan(&self, query: &VectorQuery, verbose: bool) -> Result { + let ds_ref = self.dataset.get().await?; + + let scanner = self.build_plan(&ds_ref, query, None).await?; + + let plan = scanner.explain_plan(verbose).await?; + + Ok(plan) + } + async fn merge_insert( &self, params: MergeInsertBuilder,