From f3de3d990d9ba46a72a636cca24f3fa2854473fe Mon Sep 17 00:00:00 2001 From: Rob Meng Date: Wed, 28 Feb 2024 11:06:46 -0500 Subject: [PATCH] chore: upgrade to lance 0.10.1 (#1034) upgrade to lance 0.10.1 and update doc string to reflect dynamic projection options --- Cargo.toml | 8 ++-- node/src/test/test.ts | 2 +- nodejs/src/table.rs | 2 + python/pyproject.toml | 2 +- python/python/lancedb/query.py | 21 ++++++--- python/python/lancedb/table.py | 6 +-- python/python/tests/test_query.py | 2 +- rust/ffi/node/src/table.rs | 2 + rust/lancedb/src/query.rs | 74 ++++++++++++++++++++++++++++--- 9 files changed, 97 insertions(+), 22 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 8671d22d..e3eb2f92 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,10 +14,10 @@ keywords = ["lancedb", "lance", "database", "vector", "search"] categories = ["database-implementations"] [workspace.dependencies] -lance = { "version" = "=0.9.18", "features" = ["dynamodb"] } -lance-index = { "version" = "=0.9.18" } -lance-linalg = { "version" = "=0.9.18" } -lance-testing = { "version" = "=0.9.18" } +lance = { "version" = "=0.10.1", "features" = ["dynamodb"] } +lance-index = { "version" = "=0.10.1" } +lance-linalg = { "version" = "=0.10.1" } +lance-testing = { "version" = "=0.10.1" } # Note that this one does not include pyarrow arrow = { version = "50.0", optional = false } arrow-array = "50.0" diff --git a/node/src/test/test.ts b/node/src/test/test.ts index 3f8e5161..fce5adba 100644 --- a/node/src/test/test.ts +++ b/node/src/test/test.ts @@ -198,7 +198,7 @@ describe('LanceDB client', function () { const table = await con.openTable('vectors') const results = await table .search([0.1, 0.1]) - .select(['is_active']) + .select(['is_active', 'vector']) .execute() assert.equal(results.length, 2) // vector and _distance are always returned diff --git a/nodejs/src/table.rs b/nodejs/src/table.rs index 74d73cbd..29e03f3a 100644 --- a/nodejs/src/table.rs +++ b/nodejs/src/table.rs @@ -184,6 +184,8 @@ 
impl From for LanceColumnAlteration { path, rename, nullable, + // TODO: wire up this field + data_type: None, } } } diff --git a/python/pyproject.toml b/python/pyproject.toml index 3cbe0a69..999a6104 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -3,7 +3,7 @@ name = "lancedb" version = "0.5.7" dependencies = [ "deprecation", - "pylance==0.9.18", + "pylance==0.10.1", "ratelimiter~=1.0", "retry>=0.9.2", "tqdm>=4.27.0", diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py index 44809da9..0064ba7f 100644 --- a/python/python/lancedb/query.py +++ b/python/python/lancedb/query.py @@ -16,7 +16,7 @@ from __future__ import annotations from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor from pathlib import Path -from typing import TYPE_CHECKING, List, Literal, Optional, Tuple, Type, Union +from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Type, Union import deprecation import numpy as np @@ -93,7 +93,7 @@ class Query(pydantic.BaseModel): metric: str = "L2" # which columns to return in the results - columns: Optional[List[str]] = None + columns: Optional[Union[List[str], Dict[str, str]]] = None # optional query parameters for tuning the results, # e.g. `{"nprobes": "10", "refine_factor": "10"}` @@ -321,20 +321,27 @@ class LanceQueryBuilder(ABC): self._limit = limit return self - def select(self, columns: list) -> LanceQueryBuilder: + def select(self, columns: Union[list[str], dict[str, str]]) -> LanceQueryBuilder: """Set the columns to return. Parameters ---------- - columns: list - The columns to return. + columns: list of str, or dict of str to str default None + List of column names to be fetched. + Or a dictionary of column names to SQL expressions. + All columns are fetched if None or unspecified. Returns ------- LanceQueryBuilder The LanceQueryBuilder object. 
""" - self._columns = columns + if isinstance(columns, list): + self._columns = columns + elif isinstance(columns, dict): + self._columns = list(columns.items()) + else: + raise ValueError("columns must be a list or a dictionary") return self def where(self, where: str, prefilter: bool = False) -> LanceQueryBuilder: @@ -392,7 +399,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): >>> (table.search([0.4, 0.4]) ... .metric("cosine") ... .where("b < 10") - ... .select(["b"]) + ... .select(["b", "vector"]) ... .limit(2) ... .to_pandas()) b vector _distance diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index f604cd9c..3fd32a1f 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -159,7 +159,7 @@ class Table(ABC): Can query the table with [Table.search][lancedb.table.Table.search]. - >>> table.search([0.4, 0.4]).select(["b"]).to_pandas() + >>> table.search([0.4, 0.4]).select(["b", "vector"]).to_pandas() b vector _distance 0 4 [0.5, 1.3] 0.82 1 2 [1.1, 1.2] 1.13 @@ -435,7 +435,7 @@ class Table(ABC): >>> query = [0.4, 1.4, 2.4] >>> (table.search(query) ... .where("original_width > 1000", prefilter=True) - ... .select(["caption", "original_width"]) + ... .select(["caption", "original_width", "vector"]) ... .limit(2) ... .to_pandas()) caption original_width vector _distance @@ -1264,7 +1264,7 @@ class LanceTable(Table): >>> query = [0.4, 1.4, 2.4] >>> (table.search(query) ... .where("original_width > 1000", prefilter=True) - ... .select(["caption", "original_width"]) + ... .select(["caption", "original_width", "vector"]) ... .limit(2) ... 
.to_pandas()) caption original_width vector _distance diff --git a/python/python/tests/test_query.py b/python/python/tests/test_query.py index 422f3a23..ed88f9c7 100644 --- a/python/python/tests/test_query.py +++ b/python/python/tests/test_query.py @@ -87,7 +87,7 @@ def test_query_builder(table): rs = ( LanceVectorQueryBuilder(table, [0, 0], "vector") .limit(1) - .select(["id"]) + .select(["id", "vector"]) .to_list() ) assert rs[0]["id"] == 1 diff --git a/rust/ffi/node/src/table.rs b/rust/ffi/node/src/table.rs index 5aa38727..c687f849 100644 --- a/rust/ffi/node/src/table.rs +++ b/rust/ffi/node/src/table.rs @@ -604,6 +604,8 @@ impl JsTable { path, rename, nullable, + // TODO: wire up this field + data_type: None, }) }) .collect::>>()?; diff --git a/rust/lancedb/src/query.rs b/rust/lancedb/src/query.rs index d52a026f..75e5499b 100644 --- a/rust/lancedb/src/query.rs +++ b/rust/lancedb/src/query.rs @@ -24,6 +24,13 @@ use crate::Error; const DEFAULT_TOP_K: usize = 10; +#[derive(Debug, Clone)] +pub enum Select { + All, + Simple(Vec), + Projection(Vec<(String, String)>), +} + /// A builder for nearest neighbor queries for LanceDB. #[derive(Clone)] pub struct Query { @@ -44,7 +51,7 @@ pub struct Query { /// Apply filter to the returned rows. filter: Option, /// Select column projection. - select: Option>, + select: Select, /// Default is true. Set to false to enforce a brute force search. 
use_index: bool, @@ -70,7 +77,7 @@ impl Query { metric_type: None, use_index: true, filter: None, - select: None, + select: Select::All, prefilter: false, } } @@ -115,7 +122,16 @@ impl Query { scanner.use_index(self.use_index); scanner.prefilter(self.prefilter); - self.select.as_ref().map(|p| scanner.project(p.as_slice())); + match &self.select { + Select::Simple(select) => { + scanner.project(select.as_slice())?; + } + Select::Projection(select_with_transform) => { + scanner.project_with_transform(select_with_transform.as_slice())?; + } + Select::All => { /* Do nothing */ } + } + self.filter.as_ref().map(|f| scanner.filter(f)); self.refine_factor.map(|rf| scanner.refine(rf)); self.metric_type.map(|mt| scanner.distance_metric(mt)); @@ -206,7 +222,23 @@ impl Query { /// /// Only select the specified columns. If not specified, all columns will be returned. pub fn select(mut self, columns: &[impl AsRef]) -> Self { - self.select = Some(columns.iter().map(|c| c.as_ref().to_string()).collect()); + self.select = Select::Simple(columns.iter().map(|c| c.as_ref().to_string()).collect()); + self + } + + /// Return columns computed from SQL expressions. + /// + /// Each pair is `(output column name, SQL expression)`, e.g. `("id2", "id * 2")`. If no selection is specified, all columns will be returned. 
+ pub fn select_with_projection( + mut self, + columns: &[(impl AsRef, impl AsRef)], + ) -> Self { + self.select = Select::Projection( + columns + .iter() + .map(|(c, t)| (c.as_ref().to_string(), t.as_ref().to_string())) + .collect(), + ); self } @@ -226,7 +258,7 @@ mod tests { RecordBatchReader, }; use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema}; - use futures::StreamExt; + use futures::{StreamExt, TryStreamExt}; use lance::dataset::Dataset; use lance_testing::datagen::{BatchGenerator, IncrementingInt32, RandomVector}; use tempfile::tempdir; @@ -294,6 +326,38 @@ mod tests { } } + #[tokio::test] + async fn test_select_with_transform() { + let batches = make_non_empty_batches(); + let ds = Dataset::write(batches, "memory://foo", None).await.unwrap(); + + let ds = DatasetConsistencyWrapper::new_latest(ds, None); + + let query = Query::new(ds) + .limit(10) + .select_with_projection(&[("id2", "id * 2"), ("id", "id")]); + let result = query.execute_stream().await; + let mut batches = result + .expect("should have result") + .try_collect::>() + .await + .unwrap(); + assert_eq!(batches.len(), 1); + let batch = batches.pop().unwrap(); + + // id, and id2 + assert_eq!(batch.num_columns(), 2); + + let id: &Int32Array = batch.column_by_name("id").unwrap().as_primitive(); + let id2: &Int32Array = batch.column_by_name("id2").unwrap().as_primitive(); + + id.iter().zip(id2.iter()).for_each(|(id, id2)| { + let id = id.unwrap(); + let id2 = id2.unwrap(); + assert_eq!(id * 2, id2); + }); + } + #[tokio::test] async fn test_execute_no_vector() { // test that it's ok to not specify a query vector (just filter / limit)