mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-22 21:09:58 +00:00
chore: upgrade to lance 0.10.1 (#1034)
upgrade to lance 0.10.1 and update doc string to reflect dynamic projection options
This commit is contained in:
@@ -14,10 +14,10 @@ keywords = ["lancedb", "lance", "database", "vector", "search"]
|
||||
categories = ["database-implementations"]
|
||||
|
||||
[workspace.dependencies]
|
||||
lance = { "version" = "=0.9.18", "features" = ["dynamodb"] }
|
||||
lance-index = { "version" = "=0.9.18" }
|
||||
lance-linalg = { "version" = "=0.9.18" }
|
||||
lance-testing = { "version" = "=0.9.18" }
|
||||
lance = { "version" = "=0.10.1", "features" = ["dynamodb"] }
|
||||
lance-index = { "version" = "=0.10.1" }
|
||||
lance-linalg = { "version" = "=0.10.1" }
|
||||
lance-testing = { "version" = "=0.10.1" }
|
||||
# Note that this one does not include pyarrow
|
||||
arrow = { version = "50.0", optional = false }
|
||||
arrow-array = "50.0"
|
||||
|
||||
@@ -198,7 +198,7 @@ describe('LanceDB client', function () {
|
||||
const table = await con.openTable('vectors')
|
||||
const results = await table
|
||||
.search([0.1, 0.1])
|
||||
.select(['is_active'])
|
||||
.select(['is_active', 'vector'])
|
||||
.execute()
|
||||
assert.equal(results.length, 2)
|
||||
// vector and _distance are always returned
|
||||
|
||||
@@ -184,6 +184,8 @@ impl From<ColumnAlteration> for LanceColumnAlteration {
|
||||
path,
|
||||
rename,
|
||||
nullable,
|
||||
// TODO: wire up this field
|
||||
data_type: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ name = "lancedb"
|
||||
version = "0.5.7"
|
||||
dependencies = [
|
||||
"deprecation",
|
||||
"pylance==0.9.18",
|
||||
"pylance==0.10.1",
|
||||
"ratelimiter~=1.0",
|
||||
"retry>=0.9.2",
|
||||
"tqdm>=4.27.0",
|
||||
|
||||
@@ -16,7 +16,7 @@ from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, List, Literal, Optional, Tuple, Type, Union
|
||||
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Type, Union
|
||||
|
||||
import deprecation
|
||||
import numpy as np
|
||||
@@ -93,7 +93,7 @@ class Query(pydantic.BaseModel):
|
||||
metric: str = "L2"
|
||||
|
||||
# which columns to return in the results
|
||||
columns: Optional[List[str]] = None
|
||||
columns: Optional[Union[List[str], Dict[str, str]]] = None
|
||||
|
||||
# optional query parameters for tuning the results,
|
||||
# e.g. `{"nprobes": "10", "refine_factor": "10"}`
|
||||
@@ -321,20 +321,27 @@ class LanceQueryBuilder(ABC):
|
||||
self._limit = limit
|
||||
return self
|
||||
|
||||
def select(self, columns: list) -> LanceQueryBuilder:
|
||||
def select(self, columns: Union[list[str], dict[str, str]]) -> LanceQueryBuilder:
|
||||
"""Set the columns to return.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
columns: list
|
||||
The columns to return.
|
||||
columns: list of str, or dict of str to str default None
|
||||
List of column names to be fetched.
|
||||
Or a dictionary of column names to SQL expressions.
|
||||
All columns are fetched if None or unspecified.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceQueryBuilder
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
self._columns = columns
|
||||
if isinstance(columns, list):
|
||||
self._columns = columns
|
||||
elif isinstance(columns, dict):
|
||||
self._columns = list(columns.items())
|
||||
else:
|
||||
raise ValueError("columns must be a list or a dictionary")
|
||||
return self
|
||||
|
||||
def where(self, where: str, prefilter: bool = False) -> LanceQueryBuilder:
|
||||
@@ -392,7 +399,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
>>> (table.search([0.4, 0.4])
|
||||
... .metric("cosine")
|
||||
... .where("b < 10")
|
||||
... .select(["b"])
|
||||
... .select(["b", "vector"])
|
||||
... .limit(2)
|
||||
... .to_pandas())
|
||||
b vector _distance
|
||||
|
||||
@@ -159,7 +159,7 @@ class Table(ABC):
|
||||
|
||||
Can query the table with [Table.search][lancedb.table.Table.search].
|
||||
|
||||
>>> table.search([0.4, 0.4]).select(["b"]).to_pandas()
|
||||
>>> table.search([0.4, 0.4]).select(["b", "vector"]).to_pandas()
|
||||
b vector _distance
|
||||
0 4 [0.5, 1.3] 0.82
|
||||
1 2 [1.1, 1.2] 1.13
|
||||
@@ -435,7 +435,7 @@ class Table(ABC):
|
||||
>>> query = [0.4, 1.4, 2.4]
|
||||
>>> (table.search(query)
|
||||
... .where("original_width > 1000", prefilter=True)
|
||||
... .select(["caption", "original_width"])
|
||||
... .select(["caption", "original_width", "vector"])
|
||||
... .limit(2)
|
||||
... .to_pandas())
|
||||
caption original_width vector _distance
|
||||
@@ -1264,7 +1264,7 @@ class LanceTable(Table):
|
||||
>>> query = [0.4, 1.4, 2.4]
|
||||
>>> (table.search(query)
|
||||
... .where("original_width > 1000", prefilter=True)
|
||||
... .select(["caption", "original_width"])
|
||||
... .select(["caption", "original_width", "vector"])
|
||||
... .limit(2)
|
||||
... .to_pandas())
|
||||
caption original_width vector _distance
|
||||
|
||||
@@ -87,7 +87,7 @@ def test_query_builder(table):
|
||||
rs = (
|
||||
LanceVectorQueryBuilder(table, [0, 0], "vector")
|
||||
.limit(1)
|
||||
.select(["id"])
|
||||
.select(["id", "vector"])
|
||||
.to_list()
|
||||
)
|
||||
assert rs[0]["id"] == 1
|
||||
|
||||
@@ -604,6 +604,8 @@ impl JsTable {
|
||||
path,
|
||||
rename,
|
||||
nullable,
|
||||
// TODO: wire up this field
|
||||
data_type: None,
|
||||
})
|
||||
})
|
||||
.collect::<NeonResult<Vec<ColumnAlteration>>>()?;
|
||||
|
||||
@@ -24,6 +24,13 @@ use crate::Error;
|
||||
|
||||
const DEFAULT_TOP_K: usize = 10;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Select {
|
||||
All,
|
||||
Simple(Vec<String>),
|
||||
Projection(Vec<(String, String)>),
|
||||
}
|
||||
|
||||
/// A builder for nearest neighbor queries for LanceDB.
|
||||
#[derive(Clone)]
|
||||
pub struct Query {
|
||||
@@ -44,7 +51,7 @@ pub struct Query {
|
||||
/// Apply filter to the returned rows.
|
||||
filter: Option<String>,
|
||||
/// Select column projection.
|
||||
select: Option<Vec<String>>,
|
||||
select: Select,
|
||||
|
||||
/// Default is true. Set to false to enforce a brute force search.
|
||||
use_index: bool,
|
||||
@@ -70,7 +77,7 @@ impl Query {
|
||||
metric_type: None,
|
||||
use_index: true,
|
||||
filter: None,
|
||||
select: None,
|
||||
select: Select::All,
|
||||
prefilter: false,
|
||||
}
|
||||
}
|
||||
@@ -115,7 +122,16 @@ impl Query {
|
||||
scanner.use_index(self.use_index);
|
||||
scanner.prefilter(self.prefilter);
|
||||
|
||||
self.select.as_ref().map(|p| scanner.project(p.as_slice()));
|
||||
match &self.select {
|
||||
Select::Simple(select) => {
|
||||
scanner.project(select.as_slice())?;
|
||||
}
|
||||
Select::Projection(select_with_transform) => {
|
||||
scanner.project_with_transform(select_with_transform.as_slice())?;
|
||||
}
|
||||
Select::All => { /* Do nothing */ }
|
||||
}
|
||||
|
||||
self.filter.as_ref().map(|f| scanner.filter(f));
|
||||
self.refine_factor.map(|rf| scanner.refine(rf));
|
||||
self.metric_type.map(|mt| scanner.distance_metric(mt));
|
||||
@@ -206,7 +222,23 @@ impl Query {
|
||||
///
|
||||
/// Only select the specified columns. If not specified, all columns will be returned.
|
||||
pub fn select(mut self, columns: &[impl AsRef<str>]) -> Self {
|
||||
self.select = Some(columns.iter().map(|c| c.as_ref().to_string()).collect());
|
||||
self.select = Select::Simple(columns.iter().map(|c| c.as_ref().to_string()).collect());
|
||||
self
|
||||
}
|
||||
|
||||
/// Return only the specified columns.
|
||||
///
|
||||
/// Only select the specified columns. If not specified, all columns will be returned.
|
||||
pub fn select_with_projection(
|
||||
mut self,
|
||||
columns: &[(impl AsRef<str>, impl AsRef<str>)],
|
||||
) -> Self {
|
||||
self.select = Select::Projection(
|
||||
columns
|
||||
.iter()
|
||||
.map(|(c, t)| (c.as_ref().to_string(), t.as_ref().to_string()))
|
||||
.collect(),
|
||||
);
|
||||
self
|
||||
}
|
||||
|
||||
@@ -226,7 +258,7 @@ mod tests {
|
||||
RecordBatchReader,
|
||||
};
|
||||
use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema};
|
||||
use futures::StreamExt;
|
||||
use futures::{StreamExt, TryStreamExt};
|
||||
use lance::dataset::Dataset;
|
||||
use lance_testing::datagen::{BatchGenerator, IncrementingInt32, RandomVector};
|
||||
use tempfile::tempdir;
|
||||
@@ -294,6 +326,38 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_select_with_transform() {
|
||||
let batches = make_non_empty_batches();
|
||||
let ds = Dataset::write(batches, "memory://foo", None).await.unwrap();
|
||||
|
||||
let ds = DatasetConsistencyWrapper::new_latest(ds, None);
|
||||
|
||||
let query = Query::new(ds)
|
||||
.limit(10)
|
||||
.select_with_projection(&[("id2", "id * 2"), ("id", "id")]);
|
||||
let result = query.execute_stream().await;
|
||||
let mut batches = result
|
||||
.expect("should have result")
|
||||
.try_collect::<Vec<_>>()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(batches.len(), 1);
|
||||
let batch = batches.pop().unwrap();
|
||||
|
||||
// id, and id2
|
||||
assert_eq!(batch.num_columns(), 2);
|
||||
|
||||
let id: &Int32Array = batch.column_by_name("id").unwrap().as_primitive();
|
||||
let id2: &Int32Array = batch.column_by_name("id2").unwrap().as_primitive();
|
||||
|
||||
id.iter().zip(id2.iter()).for_each(|(id, id2)| {
|
||||
let id = id.unwrap();
|
||||
let id2 = id2.unwrap();
|
||||
assert_eq!(id * 2, id2);
|
||||
});
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_execute_no_vector() {
|
||||
// test that it's ok to not specify a query vector (just filter / limit)
|
||||
|
||||
Reference in New Issue
Block a user