chore: upgrade to lance 0.10.1 (#1034)

upgrade to lance 0.10.1 and update doc string to reflect dynamic
projection options
This commit is contained in:
Rob Meng
2024-02-28 11:06:46 -05:00
committed by Weston Pace
parent 0a8e258247
commit f3de3d990d
9 changed files with 97 additions and 22 deletions

View File

@@ -14,10 +14,10 @@ keywords = ["lancedb", "lance", "database", "vector", "search"]
categories = ["database-implementations"]
[workspace.dependencies]
lance = { "version" = "=0.9.18", "features" = ["dynamodb"] }
lance-index = { "version" = "=0.9.18" }
lance-linalg = { "version" = "=0.9.18" }
lance-testing = { "version" = "=0.9.18" }
lance = { "version" = "=0.10.1", "features" = ["dynamodb"] }
lance-index = { "version" = "=0.10.1" }
lance-linalg = { "version" = "=0.10.1" }
lance-testing = { "version" = "=0.10.1" }
# Note that this one does not include pyarrow
arrow = { version = "50.0", optional = false }
arrow-array = "50.0"

View File

@@ -198,7 +198,7 @@ describe('LanceDB client', function () {
const table = await con.openTable('vectors')
const results = await table
.search([0.1, 0.1])
.select(['is_active'])
.select(['is_active', 'vector'])
.execute()
assert.equal(results.length, 2)
// _distance is always returned; vector is returned only because it is selected above

View File

@@ -184,6 +184,8 @@ impl From<ColumnAlteration> for LanceColumnAlteration {
path,
rename,
nullable,
// TODO: wire up this field
data_type: None,
}
}
}

View File

@@ -3,7 +3,7 @@ name = "lancedb"
version = "0.5.7"
dependencies = [
"deprecation",
"pylance==0.9.18",
"pylance==0.10.1",
"ratelimiter~=1.0",
"retry>=0.9.2",
"tqdm>=4.27.0",

View File

@@ -16,7 +16,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import TYPE_CHECKING, List, Literal, Optional, Tuple, Type, Union
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Type, Union
import deprecation
import numpy as np
@@ -93,7 +93,7 @@ class Query(pydantic.BaseModel):
metric: str = "L2"
# which columns to return in the results
columns: Optional[List[str]] = None
columns: Optional[Union[List[str], Dict[str, str]]] = None
# optional query parameters for tuning the results,
# e.g. `{"nprobes": "10", "refine_factor": "10"}`
@@ -321,20 +321,27 @@ class LanceQueryBuilder(ABC):
self._limit = limit
return self
def select(self, columns: list) -> LanceQueryBuilder:
def select(self, columns: Union[list[str], dict[str, str]]) -> LanceQueryBuilder:
"""Set the columns to return.
Parameters
----------
columns: list
The columns to return.
columns: list of str, or dict of str to str
List of column names to be fetched,
or a dictionary mapping output column names to SQL expressions.
All columns are fetched if `select` is never called.
Returns
-------
LanceQueryBuilder
The LanceQueryBuilder object.
"""
self._columns = columns
if isinstance(columns, list):
self._columns = columns
elif isinstance(columns, dict):
self._columns = list(columns.items())
else:
raise ValueError("columns must be a list or a dictionary")
return self
def where(self, where: str, prefilter: bool = False) -> LanceQueryBuilder:
@@ -392,7 +399,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
>>> (table.search([0.4, 0.4])
... .metric("cosine")
... .where("b < 10")
... .select(["b"])
... .select(["b", "vector"])
... .limit(2)
... .to_pandas())
b vector _distance

View File

@@ -159,7 +159,7 @@ class Table(ABC):
Can query the table with [Table.search][lancedb.table.Table.search].
>>> table.search([0.4, 0.4]).select(["b"]).to_pandas()
>>> table.search([0.4, 0.4]).select(["b", "vector"]).to_pandas()
b vector _distance
0 4 [0.5, 1.3] 0.82
1 2 [1.1, 1.2] 1.13
@@ -435,7 +435,7 @@ class Table(ABC):
>>> query = [0.4, 1.4, 2.4]
>>> (table.search(query)
... .where("original_width > 1000", prefilter=True)
... .select(["caption", "original_width"])
... .select(["caption", "original_width", "vector"])
... .limit(2)
... .to_pandas())
caption original_width vector _distance
@@ -1264,7 +1264,7 @@ class LanceTable(Table):
>>> query = [0.4, 1.4, 2.4]
>>> (table.search(query)
... .where("original_width > 1000", prefilter=True)
... .select(["caption", "original_width"])
... .select(["caption", "original_width", "vector"])
... .limit(2)
... .to_pandas())
caption original_width vector _distance

View File

@@ -87,7 +87,7 @@ def test_query_builder(table):
rs = (
LanceVectorQueryBuilder(table, [0, 0], "vector")
.limit(1)
.select(["id"])
.select(["id", "vector"])
.to_list()
)
assert rs[0]["id"] == 1

View File

@@ -604,6 +604,8 @@ impl JsTable {
path,
rename,
nullable,
// TODO: wire up this field
data_type: None,
})
})
.collect::<NeonResult<Vec<ColumnAlteration>>>()?;

View File

@@ -24,6 +24,13 @@ use crate::Error;
const DEFAULT_TOP_K: usize = 10;
/// Column projection that a `Query` applies to its scanner.
///
/// `Query::new` starts out with [`Select::All`].
#[derive(Debug, Clone)]
pub enum Select {
/// No explicit projection: the scanner's projection is left untouched.
All,
/// Column names handed to `scanner.project`.
Simple(Vec<String>),
/// Pairs handed to `scanner.project_with_transform`, e.g. `("id2", "id * 2")`.
/// NOTE(review): presumably `(output column name, SQL expression)` — confirm
/// against the lance `Scanner` documentation.
Projection(Vec<(String, String)>),
}
/// A builder for nearest neighbor queries for LanceDB.
#[derive(Clone)]
pub struct Query {
@@ -44,7 +51,7 @@ pub struct Query {
/// Apply filter to the returned rows.
filter: Option<String>,
/// Select column projection.
select: Option<Vec<String>>,
select: Select,
/// Default is true. Set to false to enforce a brute force search.
use_index: bool,
@@ -70,7 +77,7 @@ impl Query {
metric_type: None,
use_index: true,
filter: None,
select: None,
select: Select::All,
prefilter: false,
}
}
@@ -115,7 +122,16 @@ impl Query {
scanner.use_index(self.use_index);
scanner.prefilter(self.prefilter);
self.select.as_ref().map(|p| scanner.project(p.as_slice()));
match &self.select {
Select::Simple(select) => {
scanner.project(select.as_slice())?;
}
Select::Projection(select_with_transform) => {
scanner.project_with_transform(select_with_transform.as_slice())?;
}
Select::All => { /* Do nothing */ }
}
self.filter.as_ref().map(|f| scanner.filter(f));
self.refine_factor.map(|rf| scanner.refine(rf));
self.metric_type.map(|mt| scanner.distance_metric(mt));
@@ -206,7 +222,23 @@ impl Query {
///
/// Only select the specified columns. If not specified, all columns will be returned.
pub fn select(mut self, columns: &[impl AsRef<str>]) -> Self {
self.select = Some(columns.iter().map(|c| c.as_ref().to_string()).collect());
self.select = Select::Simple(columns.iter().map(|c| c.as_ref().to_string()).collect());
self
}
/// Return columns computed from expressions.
///
/// Each entry of `columns` is a `(name, expression)` pair; for example
/// `("id2", "id * 2")` yields a column named `id2` holding `id * 2`.
/// The pairs are stored as a [`Select::Projection`] and replace any selection
/// previously set on this query.
pub fn select_with_projection(
    mut self,
    columns: &[(impl AsRef<str>, impl AsRef<str>)],
) -> Self {
    // Copy the borrowed pairs into owned strings so the query can keep them.
    let mut projection = Vec::with_capacity(columns.len());
    for (name, expression) in columns {
        projection.push((name.as_ref().to_string(), expression.as_ref().to_string()));
    }
    self.select = Select::Projection(projection);
    self
}
@@ -226,7 +258,7 @@ mod tests {
RecordBatchReader,
};
use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema};
use futures::StreamExt;
use futures::{StreamExt, TryStreamExt};
use lance::dataset::Dataset;
use lance_testing::datagen::{BatchGenerator, IncrementingInt32, RandomVector};
use tempfile::tempdir;
@@ -294,6 +326,38 @@ mod tests {
}
}
#[tokio::test]
async fn test_select_with_transform() {
    // Write the generated batches to an in-memory dataset and wrap it for `Query`.
    let batches = make_non_empty_batches();
    let dataset = Dataset::write(batches, "memory://foo", None).await.unwrap();
    let dataset = DatasetConsistencyWrapper::new_latest(dataset, None);

    // Project `id` unchanged and a derived column `id2 = id * 2`.
    let projection = [("id2", "id * 2"), ("id", "id")];
    let query = Query::new(dataset)
        .limit(10)
        .select_with_projection(&projection);

    let stream = query.execute_stream().await;
    let mut collected = stream
        .expect("should have result")
        .try_collect::<Vec<_>>()
        .await
        .unwrap();
    assert_eq!(collected.len(), 1);
    let batch = collected.pop().unwrap();

    // Exactly the two projected columns come back: `id` and `id2`.
    assert_eq!(batch.num_columns(), 2);

    let ids: &Int32Array = batch.column_by_name("id").unwrap().as_primitive();
    let doubled: &Int32Array = batch.column_by_name("id2").unwrap().as_primitive();
    for (id, id2) in ids.iter().zip(doubled.iter()) {
        let id = id.unwrap();
        let id2 = id2.unwrap();
        assert_eq!(id * 2, id2);
    }
}
#[tokio::test]
async fn test_execute_no_vector() {
// test that it's ok to not specify a query vector (just filter / limit)