Compare commits


2 Commits

Author  SHA1        Message                          Date
rmeng   295be1ab1c  relax half req                   2024-05-02 13:46:31 -04:00
rmeng   91f980ec5d  chore: upgrade to lance 0.10.18  2024-05-02 13:42:18 -04:00
35 changed files with 129 additions and 1226 deletions


@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.4.19
current_version = 0.4.18
commit = True
message = Bump version: {current_version} → {new_version}
tag = True

.gitignore (vendored): 2 changes

@@ -6,7 +6,7 @@
venv
.vscode
.zed
rust/target
rust/Cargo.lock


@@ -14,10 +14,10 @@ keywords = ["lancedb", "lance", "database", "vector", "search"]
categories = ["database-implementations"]
[workspace.dependencies]
lance = { "version" = "=0.10.16", "features" = ["dynamodb"] }
lance-index = { "version" = "=0.10.16" }
lance-linalg = { "version" = "=0.10.16" }
lance-testing = { "version" = "=0.10.16" }
lance = { "version" = "=0.10.18", "features" = ["dynamodb"] }
lance-index = { "version" = "=0.10.18" }
lance-linalg = { "version" = "=0.10.18" }
lance-testing = { "version" = "=0.10.18" }
# Note that this one does not include pyarrow
arrow = { version = "51.0", optional = false }
arrow-array = "51.0"
@@ -29,7 +29,7 @@ arrow-arith = "51.0"
arrow-cast = "51.0"
async-trait = "0"
chrono = "0.4.35"
half = { "version" = "=2.3.1", default-features = false, features = [
half = { "version" = "2.4.1", default-features = false, features = [
"num-traits",
] }
futures = "0"


@@ -20,7 +20,7 @@
<hr />
LanceDB is an open-source database for vector-search built with persistent storage, which greatly simplifies retrieval, filtering and management of embeddings.
LanceDB is an open-source database for vector-search built with persistent storage, which greatly simplifies retrevial, filtering and management of embeddings.
The key features of LanceDB include:
@@ -36,7 +36,7 @@ The key features of LanceDB include:
* GPU support in building vector index(*).
* Ecosystem integrations with [LangChain 🦜️🔗](https://python.langchain.com/docs/integrations/vectorstores/lancedb/), [LlamaIndex 🦙](https://gpt-index.readthedocs.io/en/latest/examples/vector_stores/LanceDBIndexDemo.html), Apache-Arrow, Pandas, Polars, DuckDB and more on the way.
* Ecosystem integrations with [LangChain 🦜️🔗](https://python.langchain.com/en/latest/modules/indexes/vectorstores/examples/lanecdb.html), [LlamaIndex 🦙](https://gpt-index.readthedocs.io/en/latest/examples/vector_stores/LanceDBIndexDemo.html), Apache-Arrow, Pandas, Polars, DuckDB and more on the way.
LanceDB's core is written in Rust 🦀 and is built using <a href="https://github.com/lancedb/lance">Lance</a>, an open-source columnar format designed for performant ML workloads.


@@ -299,14 +299,6 @@ LanceDB can also connect to S3-compatible stores, such as MinIO. To do so, you m
This can also be done with the ``AWS_ENDPOINT`` and ``AWS_DEFAULT_REGION`` environment variables.
!!! tip "Local servers"
For local development, the server often has a `http` endpoint rather than a
secure `https` endpoint. In this case, you must also set the `ALLOW_HTTP`
environment variable to `true` to allow non-TLS connections, or pass the
storage option `allow_http` as `true`. If you do not do this, you will get
an error like `URL scheme is not allowed`.
#### S3 Express
LanceDB supports [S3 Express One Zone](https://aws.amazon.com/s3/storage-classes/express-one-zone/) endpoints, but requires additional configuration. Also, S3 Express endpoints only support connecting from an EC2 instance within the same region.
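For the "Local servers" tip above, a minimal Rust sketch of the environment-variable route the docs mention; the endpoint URL, region, and bucket name are illustrative assumptions, not defaults:

```rust
use lancedb::connect;

#[tokio::main]
async fn main() -> lancedb::Result<()> {
    // Hypothetical local MinIO endpoint and bucket, for illustration only.
    std::env::set_var("AWS_ENDPOINT", "http://localhost:9000");
    std::env::set_var("AWS_DEFAULT_REGION", "us-east-1");
    // Required for a plain-http endpoint; without it the connection fails
    // with an error like "URL scheme is not allowed".
    std::env::set_var("ALLOW_HTTP", "true");

    let _db = connect("s3://example-bucket/db").execute().await?;
    Ok(())
}
```

The same setting can also be passed per connection as a storage option, as the TypeScript integration test later in this compare does with `storageOptions: { allowHttp: 'true' }`.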


@@ -36,7 +36,7 @@
}
],
"source": [
"!pip install --quiet openai datasets\n",
"!pip install --quiet openai datasets \n",
"!pip install --quiet -U lancedb"
]
},
@@ -213,7 +213,7 @@
"if \"OPENAI_API_KEY\" not in os.environ:\n",
" # OR set the key here as a variable\n",
" os.environ[\"OPENAI_API_KEY\"] = \"sk-...\"\n",
"\n",
" \n",
"client = OpenAI()\n",
"assert len(client.models.list().data) > 0"
]
@@ -234,12 +234,9 @@
"metadata": {},
"outputs": [],
"source": [
"def embed_func(c):\n",
"def embed_func(c): \n",
" rs = client.embeddings.create(input=c, model=\"text-embedding-ada-002\")\n",
" return [\n",
" data.embedding\n",
" for data in rs.data\n",
" ]"
" return [rs.data[0].embedding]"
]
},
{
@@ -517,7 +514,7 @@
" prompt_start +\n",
" \"\\n\\n---\\n\\n\".join(context.text) +\n",
" prompt_end\n",
" )\n",
" ) \n",
" return prompt"
]
},

node/package-lock.json (generated): 74 changes

@@ -1,12 +1,12 @@
{
"name": "vectordb",
"version": "0.4.19",
"version": "0.4.18",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "vectordb",
"version": "0.4.19",
"version": "0.4.18",
"cpu": [
"x64",
"arm64"
@@ -52,11 +52,11 @@
"uuid": "^9.0.0"
},
"optionalDependencies": {
"@lancedb/vectordb-darwin-arm64": "0.4.19",
"@lancedb/vectordb-darwin-x64": "0.4.19",
"@lancedb/vectordb-linux-arm64-gnu": "0.4.19",
"@lancedb/vectordb-linux-x64-gnu": "0.4.19",
"@lancedb/vectordb-win32-x64-msvc": "0.4.19"
"@lancedb/vectordb-darwin-arm64": "0.4.18",
"@lancedb/vectordb-darwin-x64": "0.4.18",
"@lancedb/vectordb-linux-arm64-gnu": "0.4.18",
"@lancedb/vectordb-linux-x64-gnu": "0.4.18",
"@lancedb/vectordb-win32-x64-msvc": "0.4.18"
},
"peerDependencies": {
"@apache-arrow/ts": "^14.0.2",
@@ -333,6 +333,66 @@
"@jridgewell/sourcemap-codec": "^1.4.10"
}
},
"node_modules/@lancedb/vectordb-darwin-arm64": {
"version": "0.4.18",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.4.18.tgz",
"integrity": "sha512-CzJbkBKz30U0ocFFhRLV3ZPRZh3MtAkOmFr76jxRWeXLPM/JcLvhGOAnW9h/XdTONidHOfHNZnUtrjeWDMCyig==",
"cpu": [
"arm64"
],
"optional": true,
"os": [
"darwin"
]
},
"node_modules/@lancedb/vectordb-darwin-x64": {
"version": "0.4.18",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.4.18.tgz",
"integrity": "sha512-wyqpfdDBE5g+8SLN/6E/r37smt5i+4H3MVNZ2GZfvcMAd4xIZTwGAf5Mfx8j15t3mvKMiBEZPTvYQfFde2bQmA==",
"cpu": [
"x64"
],
"optional": true,
"os": [
"darwin"
]
},
"node_modules/@lancedb/vectordb-linux-arm64-gnu": {
"version": "0.4.18",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.4.18.tgz",
"integrity": "sha512-OfCjTrwfdmT9Qh5r92AUj7Wosvl8mSYADS6rp+ofNoht9nq1UqtlyrCot1RhuF5w6UW1aLCJXycXY0qHh1WUPw==",
"cpu": [
"arm64"
],
"optional": true,
"os": [
"linux"
]
},
"node_modules/@lancedb/vectordb-linux-x64-gnu": {
"version": "0.4.18",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.4.18.tgz",
"integrity": "sha512-uJiCminsZQ6oUvVYEElgN+/Lqd9646cJUWCbfiFSnt10PCj/kFBXWjKEuCxfG/A0bp6DTm5mU5RYWDfY9v3T0Q==",
"cpu": [
"x64"
],
"optional": true,
"os": [
"linux"
]
},
"node_modules/@lancedb/vectordb-win32-x64-msvc": {
"version": "0.4.18",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.4.18.tgz",
"integrity": "sha512-T9//HlvtNDHEfbIjz0ExLQFbqKFHRdaKf3jpFECt5oJdU0VCXe5460DIMvp6w/SDf24pb1UvOjZnmPLvX4yPNg==",
"cpu": [
"x64"
],
"optional": true,
"os": [
"win32"
]
},
"node_modules/@neon-rs/cli": {
"version": "0.0.160",
"resolved": "https://registry.npmjs.org/@neon-rs/cli/-/cli-0.0.160.tgz",


@@ -1,6 +1,6 @@
{
"name": "vectordb",
"version": "0.4.19",
"version": "0.4.18",
"description": " Serverless, low-latency vector database for AI applications",
"main": "dist/index.js",
"types": "dist/index.d.ts",
@@ -88,10 +88,10 @@
}
},
"optionalDependencies": {
"@lancedb/vectordb-darwin-arm64": "0.4.19",
"@lancedb/vectordb-darwin-x64": "0.4.19",
"@lancedb/vectordb-linux-arm64-gnu": "0.4.19",
"@lancedb/vectordb-linux-x64-gnu": "0.4.19",
"@lancedb/vectordb-win32-x64-msvc": "0.4.19"
"@lancedb/vectordb-darwin-arm64": "0.4.18",
"@lancedb/vectordb-darwin-x64": "0.4.18",
"@lancedb/vectordb-linux-arm64-gnu": "0.4.18",
"@lancedb/vectordb-linux-x64-gnu": "0.4.18",
"@lancedb/vectordb-win32-x64-msvc": "0.4.18"
}
}


@@ -51,7 +51,7 @@ describe('LanceDB Mirrored Store Integration test', function () {
const dir = tmpdir()
console.log(dir)
const conn = await lancedb.connect({ uri: `s3://lancedb-integtest?mirroredStore=${dir}`, storageOptions: { allowHttp: 'true' } })
const conn = await lancedb.connect(`s3://lancedb-integtest?mirroredStore=${dir}`)
const data = Array(200).fill({ vector: Array(128).fill(1.0), id: 0 })
data.push(...Array(200).fill({ vector: Array(128).fill(1.0), id: 1 }))
data.push(...Array(200).fill({ vector: Array(128).fill(1.0), id: 2 }))


@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-darwin-arm64",
"version": "0.4.19",
"version": "0.4.18",
"os": [
"darwin"
],


@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-darwin-x64",
"version": "0.4.19",
"version": "0.4.18",
"os": [
"darwin"
],


@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-arm64-gnu",
"version": "0.4.19",
"version": "0.4.18",
"os": [
"linux"
],


@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-x64-gnu",
"version": "0.4.19",
"version": "0.4.18",
"os": [
"linux"
],


@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb",
"version": "0.4.19",
"version": "0.4.18",
"main": "./dist/index.js",
"types": "./dist/index.d.ts",
"napi": {


@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.6.12
current_version = 0.6.11
commit = True
message = [python] Bump version: {current_version} → {new_version}
tag = True


@@ -1,6 +1,6 @@
[project]
name = "lancedb"
version = "0.6.12"
version = "0.6.11"
dependencies = [
"deprecation",
"pylance==0.10.12",


@@ -107,9 +107,6 @@ def connect(
request_thread_pool=request_thread_pool,
**kwargs,
)
if kwargs:
raise ValueError(f"Unknown keyword arguments: {kwargs}")
return LanceDBConnection(uri, read_consistency_interval=read_consistency_interval)


@@ -255,13 +255,7 @@ def retry_with_exponential_backoff(
)
delay *= exponential_base * (1 + jitter * random.random())
logging.warning(
"Error occurred: %s \n Retrying in %s seconds (retry %s of %s) \n",
e,
delay,
num_retries,
max_retries,
)
logging.info("Retrying in %s seconds...", delay)
time.sleep(delay)
return wrapper
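The hunk above only changes the log message; the delay update itself is exponential backoff with multiplicative jitter. A standalone Rust sketch of that update, where the constants and the `rand` dependency are illustrative assumptions rather than the library's values:

```rust
use std::time::Duration;

// Mirrors the Python update above: delay *= exponential_base * (1 + jitter * random()).
// Initial delay, base, jitter, and retry count are made-up example values.
fn backoff_delays(mut delay: f64, exponential_base: f64, jitter: f64, retries: u32) -> Vec<Duration> {
    (0..retries)
        .map(|_| {
            delay *= exponential_base * (1.0 + jitter * rand::random::<f64>());
            Duration::from_secs_f64(delay)
        })
        .collect()
}

fn main() {
    for (attempt, wait) in backoff_delays(1.0, 2.0, 1.0, 4).iter().enumerate() {
        println!("retry {}: sleeping {:?}", attempt + 1, wait);
    }
}
```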


@@ -30,7 +30,6 @@ from typing import (
import deprecation
import numpy as np
import pyarrow as pa
import pyarrow.fs as pa_fs
import pydantic
from . import __version__
@@ -38,7 +37,7 @@ from .arrow import AsyncRecordBatchReader
from .common import VEC
from .rerankers.base import Reranker
from .rerankers.linear_combination import LinearCombinationReranker
from .util import fs_from_uri, safe_import_pandas
from .util import safe_import_pandas
if TYPE_CHECKING:
import PIL
@@ -666,14 +665,6 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
# get the index path
index_path = self._table._get_fts_index_path()
# Check that we are on local filesystem
fs, _path = fs_from_uri(index_path)
if not isinstance(fs, pa_fs.LocalFileSystem):
raise NotImplementedError(
"Full-text search is only supported on the local filesystem"
)
# check if the index exist
if not Path(index_path).exists():
raise FileNotFoundError(


@@ -1209,11 +1209,6 @@ class LanceTable(Table):
raise ValueError("Index already exists. Use replace=True to overwrite.")
fs.delete_dir(path)
if not isinstance(fs, pa_fs.LocalFileSystem):
raise NotImplementedError(
"Full-text search is only supported on the local filesystem"
)
index = create_index(
self._get_fts_index_path(),
field_names,


@@ -213,7 +213,7 @@ def test_syntax(table):
# https://github.com/lancedb/lancedb/issues/769
table.create_fts_index("text")
with pytest.raises(ValueError, match="Syntax Error"):
table.search("they could have been dogs OR").limit(10).to_list()
table.search("they could have been dogs OR cats").limit(10).to_list()
# these should work


@@ -35,16 +35,21 @@ impl<T> PythonErrorExt<T> for std::result::Result<T, LanceError> {
match &self {
Ok(_) => Ok(self.unwrap()),
Err(err) => match err {
LanceError::InvalidInput { .. }
| LanceError::InvalidTableName { .. }
| LanceError::TableNotFound { .. }
| LanceError::Schema { .. } => self.value_error(),
LanceError::InvalidInput { .. } => self.value_error(),
LanceError::InvalidTableName { .. } => self.value_error(),
LanceError::TableNotFound { .. } => self.value_error(),
LanceError::Schema { .. } => self.value_error(),
LanceError::CreateDir { .. } => self.os_error(),
LanceError::TableAlreadyExists { .. } => self.runtime_error(),
LanceError::ObjectStore { .. } => Err(PyIOError::new_err(err.to_string())),
LanceError::Lance { .. } => self.runtime_error(),
LanceError::Runtime { .. } => self.runtime_error(),
LanceError::Http { .. } => self.runtime_error(),
LanceError::Arrow { .. } => self.runtime_error(),
LanceError::NotSupported { .. } => {
Err(PyNotImplementedError::new_err(err.to_string()))
}
_ => self.runtime_error(),
LanceError::Other { .. } => self.runtime_error(),
},
}
}
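The base side of this hunk folds the four `value_error` variants into a single or-pattern arm and relies on a `_` catch-all. A minimal, self-contained illustration of that match shape, using a hypothetical error enum rather than lancedb's actual `Error` type:

```rust
// Hypothetical error enum, standing in for the real one only for illustration.
#[allow(dead_code)]
enum DemoError {
    InvalidInput,
    InvalidTableName,
    TableNotFound,
    Schema,
    Runtime,
}

fn python_exception_kind(err: &DemoError) -> &'static str {
    match err {
        // Or-patterns group every variant that maps to the same Python exception.
        DemoError::InvalidInput
        | DemoError::InvalidTableName
        | DemoError::TableNotFound
        | DemoError::Schema => "ValueError",
        // The wildcard keeps the mapping total when new variants are added later.
        _ => "RuntimeError",
    }
}

fn main() {
    assert_eq!(python_exception_kind(&DemoError::Schema), "ValueError");
    assert_eq!(python_exception_kind(&DemoError::Runtime), "RuntimeError");
}
```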


@@ -1,6 +1,6 @@
[package]
name = "lancedb-node"
version = "0.4.19"
version = "0.4.18"
description = "Serverless, low-latency vector database for AI applications"
license.workspace = true
edition.workspace = true


@@ -59,7 +59,7 @@ fn database_new(mut cx: FunctionContext) -> JsResult<JsPromise> {
for handle in storage_options_js {
let obj = handle.downcast::<JsArray, _>(&mut cx).unwrap();
let key = obj.get::<JsString, _, _>(&mut cx, 0)?.value(&mut cx);
let value = obj.get::<JsString, _, _>(&mut cx, 1)?.value(&mut cx);
let value = obj.get::<JsString, _, _>(&mut cx, 0)?.value(&mut cx);
storage_options.push((key, value));
}


@@ -1,6 +1,6 @@
[package]
name = "lancedb"
version = "0.4.19"
version = "0.4.18"
edition.workspace = true
description = "LanceDB: A serverless, low-latency vector database for AI applications"
license.workspace = true
@@ -40,8 +40,6 @@ serde = { version = "^1" }
serde_json = { version = "1" }
# For remote feature
reqwest = { version = "0.11.24", features = ["gzip", "json"], optional = true }
polars-arrow = { version = ">=0.37", optional = true }
polars = { version = ">=0.37", optional = true}
[dev-dependencies]
tempfile = "3.5.0"
@@ -58,4 +56,3 @@ default = []
remote = ["dep:reqwest"]
fp16kernels = ["lance-linalg/fp16kernels"]
s3-test = []
polars = ["dep:polars-arrow", "dep:polars"]


@@ -14,12 +14,10 @@
use std::{pin::Pin, sync::Arc};
pub use arrow_array;
pub use arrow_schema;
use futures::{Stream, StreamExt};
#[cfg(feature = "polars")]
use {crate::polars_arrow_convertors, polars::frame::ArrowChunk, polars::prelude::DataFrame};
use crate::error::Result;
/// An iterator of batches that also has a schema
@@ -116,183 +114,8 @@ pub trait IntoArrow {
fn into_arrow(self) -> Result<Box<dyn arrow_array::RecordBatchReader + Send>>;
}
pub type BoxedRecordBatchReader = Box<dyn arrow_array::RecordBatchReader + Send>;
impl<T: arrow_array::RecordBatchReader + Send + 'static> IntoArrow for T {
fn into_arrow(self) -> Result<Box<dyn arrow_array::RecordBatchReader + Send>> {
Ok(Box::new(self))
}
}
impl<S: Stream<Item = Result<arrow_array::RecordBatch>>> SimpleRecordBatchStream<S> {
pub fn new(stream: S, schema: Arc<arrow_schema::Schema>) -> Self {
Self { schema, stream }
}
}
#[cfg(feature = "polars")]
/// An iterator of record batches formed from a Polars DataFrame.
pub struct PolarsDataFrameRecordBatchReader {
chunks: std::vec::IntoIter<ArrowChunk>,
arrow_schema: Arc<arrow_schema::Schema>,
}
#[cfg(feature = "polars")]
impl PolarsDataFrameRecordBatchReader {
/// Creates a new `PolarsDataFrameRecordBatchReader` from a given Polars DataFrame.
/// If the input dataframe does not have aligned chunks, this function undergoes
/// the costly operation of reallocating each series as a single contigous chunk.
pub fn new(mut df: DataFrame) -> Result<Self> {
df.align_chunks();
let arrow_schema =
polars_arrow_convertors::convert_polars_df_schema_to_arrow_rb_schema(df.schema())?;
Ok(Self {
chunks: df
.iter_chunks(polars_arrow_convertors::POLARS_ARROW_FLAVOR)
.collect::<Vec<ArrowChunk>>()
.into_iter(),
arrow_schema,
})
}
}
#[cfg(feature = "polars")]
impl Iterator for PolarsDataFrameRecordBatchReader {
type Item = std::result::Result<arrow_array::RecordBatch, arrow_schema::ArrowError>;
fn next(&mut self) -> Option<Self::Item> {
self.chunks.next().map(|chunk| {
let columns: std::result::Result<Vec<arrow_array::ArrayRef>, arrow_schema::ArrowError> =
chunk
.into_arrays()
.into_iter()
.zip(self.arrow_schema.fields.iter())
.map(|(polars_array, arrow_field)| {
polars_arrow_convertors::convert_polars_arrow_array_to_arrow_rs_array(
polars_array,
arrow_field.data_type().clone(),
)
})
.collect();
arrow_array::RecordBatch::try_new(self.arrow_schema.clone(), columns?)
})
}
}
#[cfg(feature = "polars")]
impl arrow_array::RecordBatchReader for PolarsDataFrameRecordBatchReader {
fn schema(&self) -> Arc<arrow_schema::Schema> {
self.arrow_schema.clone()
}
}
/// A trait for converting the result of a LanceDB query into a Polars DataFrame with aligned
/// chunks. The resulting Polars DataFrame will have aligned chunks, but the series's
/// chunks are not guaranteed to be contiguous.
#[cfg(feature = "polars")]
pub trait IntoPolars {
fn into_polars(self) -> impl std::future::Future<Output = Result<DataFrame>> + Send;
}
#[cfg(feature = "polars")]
impl IntoPolars for SendableRecordBatchStream {
async fn into_polars(mut self) -> Result<DataFrame> {
let polars_schema =
polars_arrow_convertors::convert_arrow_rb_schema_to_polars_df_schema(&self.schema())?;
let mut acc_df: DataFrame = DataFrame::from(&polars_schema);
while let Some(record_batch) = self.next().await {
let new_df = polars_arrow_convertors::convert_arrow_rb_to_polars_df(
&record_batch?,
&polars_schema,
)?;
acc_df = acc_df.vstack(&new_df)?;
}
Ok(acc_df)
}
}
#[cfg(all(test, feature = "polars"))]
mod tests {
use super::SendableRecordBatchStream;
use crate::arrow::{
IntoArrow, IntoPolars, PolarsDataFrameRecordBatchReader, SimpleRecordBatchStream,
};
use polars::prelude::{DataFrame, NamedFrom, Series};
fn get_record_batch_reader_from_polars() -> Box<dyn arrow_array::RecordBatchReader + Send> {
let mut string_series = Series::new("string", &["ab"]);
let mut int_series = Series::new("int", &[1]);
let mut float_series = Series::new("float", &[1.0]);
let df1 = DataFrame::new(vec![string_series, int_series, float_series]).unwrap();
string_series = Series::new("string", &["bc"]);
int_series = Series::new("int", &[2]);
float_series = Series::new("float", &[2.0]);
let df2 = DataFrame::new(vec![string_series, int_series, float_series]).unwrap();
PolarsDataFrameRecordBatchReader::new(df1.vstack(&df2).unwrap())
.unwrap()
.into_arrow()
.unwrap()
}
#[test]
fn from_polars_to_arrow() {
let record_batch_reader = get_record_batch_reader_from_polars();
let schema = record_batch_reader.schema();
// Test schema conversion
assert_eq!(
schema
.fields
.iter()
.map(|field| (field.name().as_str(), field.data_type()))
.collect::<Vec<_>>(),
vec![
("string", &arrow_schema::DataType::LargeUtf8),
("int", &arrow_schema::DataType::Int32),
("float", &arrow_schema::DataType::Float64)
]
);
let record_batches: Vec<arrow_array::RecordBatch> =
record_batch_reader.map(|result| result.unwrap()).collect();
assert_eq!(record_batches.len(), 2);
assert_eq!(schema, record_batches[0].schema());
assert_eq!(record_batches[0].schema(), record_batches[1].schema());
// Test number of rows
assert_eq!(record_batches[0].num_rows(), 1);
assert_eq!(record_batches[1].num_rows(), 1);
}
#[tokio::test]
async fn from_arrow_to_polars() {
let record_batch_reader = get_record_batch_reader_from_polars();
let schema = record_batch_reader.schema();
let stream: SendableRecordBatchStream = Box::pin(SimpleRecordBatchStream {
schema: schema.clone(),
stream: futures::stream::iter(
record_batch_reader
.into_iter()
.map(|r| r.map_err(Into::into)),
),
});
let df = stream.into_polars().await.unwrap();
// Test number of chunks and rows
assert_eq!(df.n_chunks(), 2);
assert_eq!(df.height(), 2);
// Test schema conversion
assert_eq!(
df.schema()
.into_iter()
.map(|(name, datatype)| (name.to_string(), datatype))
.collect::<Vec<_>>(),
vec![
("string".to_string(), polars::prelude::DataType::String),
("int".to_owned(), polars::prelude::DataType::Int32),
("float".to_owned(), polars::prelude::DataType::Float64)
]
);
}
}

View File

@@ -27,12 +27,9 @@ use object_store::{aws::AwsCredential, local::LocalFileSystem};
use snafu::prelude::*;
use crate::arrow::IntoArrow;
use crate::embeddings::{
EmbeddingDefinition, EmbeddingFunction, EmbeddingRegistry, MemoryRegistry, WithEmbeddings,
};
use crate::error::{CreateDirSnafu, Error, InvalidTableNameSnafu, Result};
use crate::io::object_store::MirroringObjectStoreWrapper;
use crate::table::{NativeTable, TableDefinition, WriteOptions};
use crate::table::{NativeTable, WriteOptions};
use crate::utils::validate_table_name;
use crate::Table;
@@ -136,10 +133,9 @@ pub struct CreateTableBuilder<const HAS_DATA: bool, T: IntoArrow> {
parent: Arc<dyn ConnectionInternal>,
pub(crate) name: String,
pub(crate) data: Option<T>,
pub(crate) schema: Option<SchemaRef>,
pub(crate) mode: CreateTableMode,
pub(crate) write_options: WriteOptions,
pub(crate) table_definition: Option<TableDefinition>,
pub(crate) embeddings: Vec<(EmbeddingDefinition, Arc<dyn EmbeddingFunction>)>,
}
// Builder methods that only apply when we have initial data
@@ -149,10 +145,9 @@ impl<T: IntoArrow> CreateTableBuilder<true, T> {
parent,
name,
data: Some(data),
schema: None,
mode: CreateTableMode::default(),
write_options: WriteOptions::default(),
table_definition: None,
embeddings: Vec::new(),
}
}
@@ -180,43 +175,24 @@ impl<T: IntoArrow> CreateTableBuilder<true, T> {
parent: self.parent,
name: self.name,
data: None,
table_definition: self.table_definition,
schema: self.schema,
mode: self.mode,
write_options: self.write_options,
embeddings: self.embeddings,
};
Ok((data, builder))
}
pub fn add_embedding(mut self, definition: EmbeddingDefinition) -> Result<Self> {
// Early verification of the embedding name
let embedding_func = self
.parent
.embedding_registry()
.get(&definition.embedding_name)
.ok_or_else(|| Error::EmbeddingFunctionNotFound {
name: definition.embedding_name.to_string(),
reason: "No embedding function found in the connection's embedding_registry"
.to_string(),
})?;
self.embeddings.push((definition, embedding_func));
Ok(self)
}
}
// Builder methods that only apply when we do not have initial data
impl CreateTableBuilder<false, NoData> {
fn new(parent: Arc<dyn ConnectionInternal>, name: String, schema: SchemaRef) -> Self {
let table_definition = TableDefinition::new_from_schema(schema);
Self {
parent,
name,
data: None,
table_definition: Some(table_definition),
schema: Some(schema),
mode: CreateTableMode::default(),
write_options: WriteOptions::default(),
embeddings: Vec::new(),
}
}
@@ -374,7 +350,6 @@ impl OpenTableBuilder {
pub(crate) trait ConnectionInternal:
Send + Sync + std::fmt::Debug + std::fmt::Display + 'static
{
fn embedding_registry(&self) -> &dyn EmbeddingRegistry;
async fn table_names(&self, options: TableNamesBuilder) -> Result<Vec<String>>;
async fn do_create_table(
&self,
@@ -391,7 +366,7 @@ pub(crate) trait ConnectionInternal:
) -> Result<Table> {
let batches = Box::new(RecordBatchIterator::new(
vec![],
options.table_definition.clone().unwrap().schema.clone(),
options.schema.as_ref().unwrap().clone(),
));
self.do_create_table(options, batches).await
}
@@ -478,13 +453,6 @@ impl Connection {
pub async fn drop_db(&self) -> Result<()> {
self.internal.drop_db().await
}
/// Get the in-memory embedding registry.
/// It's important to note that the embedding registry is not persisted across connections.
/// So if a table contains embeddings, you will need to make sure that you are using a connection that has the same embedding functions registered
pub fn embedding_registry(&self) -> &dyn EmbeddingRegistry {
self.internal.embedding_registry()
}
}
#[derive(Debug)]
@@ -518,7 +486,6 @@ pub struct ConnectBuilder {
/// consistency only applies to read operations. Write operations are
/// always consistent.
read_consistency_interval: Option<std::time::Duration>,
embedding_registry: Option<Arc<dyn EmbeddingRegistry>>,
}
impl ConnectBuilder {
@@ -531,7 +498,6 @@ impl ConnectBuilder {
host_override: None,
read_consistency_interval: None,
storage_options: HashMap::new(),
embedding_registry: None,
}
}
@@ -550,12 +516,6 @@ impl ConnectBuilder {
self
}
/// Provide a custom [`EmbeddingRegistry`] to use for this connection.
pub fn embedding_registry(mut self, registry: Arc<dyn EmbeddingRegistry>) -> Self {
self.embedding_registry = Some(registry);
self
}
/// [`AwsCredential`] to use when connecting to S3.
#[deprecated(note = "Pass through storage_options instead")]
pub fn aws_creds(mut self, aws_creds: AwsCredential) -> Self {
@@ -682,7 +642,6 @@ struct Database {
// Storage options to be inherited by tables created from this connection
storage_options: HashMap<String, String>,
embedding_registry: Arc<dyn EmbeddingRegistry>,
}
impl std::fmt::Display for Database {
@@ -716,12 +675,7 @@ impl Database {
// TODO: pass params regardless of OS
match parse_res {
Ok(url) if url.scheme().len() == 1 && cfg!(windows) => {
Self::open_path(
uri,
options.read_consistency_interval,
options.embedding_registry.clone(),
)
.await
Self::open_path(uri, options.read_consistency_interval).await
}
Ok(mut url) => {
// iter thru the query params and extract the commit store param
@@ -791,10 +745,6 @@ impl Database {
None => None,
};
let embedding_registry = options
.embedding_registry
.clone()
.unwrap_or_else(|| Arc::new(MemoryRegistry::new()));
Ok(Self {
uri: table_base_uri,
query_string,
@@ -803,33 +753,20 @@ impl Database {
store_wrapper: write_store_wrapper,
read_consistency_interval: options.read_consistency_interval,
storage_options,
embedding_registry,
})
}
Err(_) => {
Self::open_path(
uri,
options.read_consistency_interval,
options.embedding_registry.clone(),
)
.await
}
Err(_) => Self::open_path(uri, options.read_consistency_interval).await,
}
}
async fn open_path(
path: &str,
read_consistency_interval: Option<std::time::Duration>,
embedding_registry: Option<Arc<dyn EmbeddingRegistry>>,
) -> Result<Self> {
let (object_store, base_path) = ObjectStore::from_uri(path).await?;
if object_store.is_local() {
Self::try_create_dir(path).context(CreateDirSnafu { path })?;
}
let embedding_registry =
embedding_registry.unwrap_or_else(|| Arc::new(MemoryRegistry::new()));
Ok(Self {
uri: path.to_string(),
query_string: None,
@@ -838,7 +775,6 @@ impl Database {
store_wrapper: None,
read_consistency_interval,
storage_options: HashMap::new(),
embedding_registry,
})
}
@@ -879,9 +815,6 @@ impl Database {
#[async_trait::async_trait]
impl ConnectionInternal for Database {
fn embedding_registry(&self) -> &dyn EmbeddingRegistry {
self.embedding_registry.as_ref()
}
async fn table_names(&self, options: TableNamesBuilder) -> Result<Vec<String>> {
let mut f = self
.object_store
@@ -918,7 +851,7 @@ impl ConnectionInternal for Database {
data: Box<dyn RecordBatchReader + Send>,
) -> Result<Table> {
let table_uri = self.table_uri(&options.name)?;
let embedding_registry = self.embedding_registry.clone();
// Inherit storage options from the connection
let storage_options = options
.write_options
@@ -933,11 +866,6 @@ impl ConnectionInternal for Database {
storage_options.insert(key.clone(), value.clone());
}
}
let data = if options.embeddings.is_empty() {
data
} else {
Box::new(WithEmbeddings::new(data, options.embeddings))
};
let mut write_params = options.write_options.lance_write_params.unwrap_or_default();
if matches!(&options.mode, CreateTableMode::Overwrite) {
@@ -954,10 +882,7 @@ impl ConnectionInternal for Database {
)
.await
{
Ok(table) => Ok(Table::new_with_embedding_registry(
Arc::new(table),
embedding_registry,
)),
Ok(table) => Ok(Table::new(Arc::new(table))),
Err(Error::TableAlreadyExists { name }) => match options.mode {
CreateTableMode::Create => Err(Error::TableAlreadyExists { name }),
CreateTableMode::ExistOk(callback) => {


@@ -1,307 +0,0 @@
// Copyright 2024 LanceDB Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use lance::arrow::RecordBatchExt;
use std::{
borrow::Cow,
collections::{HashMap, HashSet},
sync::{Arc, RwLock},
};
use arrow_array::{Array, RecordBatch, RecordBatchReader};
use arrow_schema::{DataType, Field, SchemaBuilder};
// use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use crate::{
error::Result,
table::{ColumnDefinition, ColumnKind, TableDefinition},
Error,
};
/// Trait for embedding functions
///
/// An embedding function is a function that is applied to a column of input data
/// to produce an "embedding" of that input. This embedding is then stored in the
/// database alongside (or instead of) the original input.
///
/// An "embedding" is often a lower-dimensional representation of the input data.
/// For example, sentence-transformers can be used to embed sentences into a 768-dimensional
/// vector space. This is useful for tasks like similarity search, where we want to find
/// similar sentences to a query sentence.
///
/// To use an embedding function you must first register it with the `EmbeddingsRegistry`.
/// Then you can define it on a column in the table schema. That embedding will then be used
/// to embed the data in that column.
pub trait EmbeddingFunction: std::fmt::Debug + Send + Sync {
fn name(&self) -> &str;
/// The type of the input data
fn source_type(&self) -> Result<Cow<DataType>>;
/// The type of the output data
/// This should **always** match the output of the `embed` function
fn dest_type(&self) -> Result<Cow<DataType>>;
/// Embed the input
fn embed(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>>;
}
/// Defines an embedding from input data into a lower-dimensional space
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct EmbeddingDefinition {
/// The name of the column in the input data
pub source_column: String,
/// The name of the embedding column, if not specified
/// it will be the source column with `_embedding` appended
pub dest_column: Option<String>,
/// The name of the embedding function to apply
pub embedding_name: String,
}
impl EmbeddingDefinition {
pub fn new<S: Into<String>>(source_column: S, embedding_name: S, dest: Option<S>) -> Self {
Self {
source_column: source_column.into(),
dest_column: dest.map(|d| d.into()),
embedding_name: embedding_name.into(),
}
}
}
/// A registry of embedding
pub trait EmbeddingRegistry: Send + Sync + std::fmt::Debug {
/// Return the names of all registered embedding functions
fn functions(&self) -> HashSet<String>;
/// Register a new [`EmbeddingFunction
/// Returns an error if the function can not be registered
fn register(&self, name: &str, function: Arc<dyn EmbeddingFunction>) -> Result<()>;
/// Get an embedding function by name
fn get(&self, name: &str) -> Option<Arc<dyn EmbeddingFunction>>;
}
/// A [`EmbeddingRegistry`] that uses in-memory [`HashMap`]s
#[derive(Debug, Default, Clone)]
pub struct MemoryRegistry {
functions: Arc<RwLock<HashMap<String, Arc<dyn EmbeddingFunction>>>>,
}
impl EmbeddingRegistry for MemoryRegistry {
fn functions(&self) -> HashSet<String> {
self.functions.read().unwrap().keys().cloned().collect()
}
fn register(&self, name: &str, function: Arc<dyn EmbeddingFunction>) -> Result<()> {
self.functions
.write()
.unwrap()
.insert(name.to_string(), function);
Ok(())
}
fn get(&self, name: &str) -> Option<Arc<dyn EmbeddingFunction>> {
self.functions.read().unwrap().get(name).cloned()
}
}
impl MemoryRegistry {
/// Create a new `MemoryRegistry`
pub fn new() -> Self {
Self::default()
}
}
/// A record batch reader that has embeddings applied to it
/// This is a wrapper around another record batch reader that applies an embedding function
/// when reading from the record batch
pub struct WithEmbeddings<R: RecordBatchReader> {
inner: R,
embeddings: Vec<(EmbeddingDefinition, Arc<dyn EmbeddingFunction>)>,
}
/// A record batch that might have embeddings applied to it.
pub enum MaybeEmbedded<R: RecordBatchReader> {
/// The record batch reader has embeddings applied to it
Yes(WithEmbeddings<R>),
/// The record batch reader does not have embeddings applied to it
/// The inner record batch reader is returned as-is
No(R),
}
impl<R: RecordBatchReader> MaybeEmbedded<R> {
/// Create a new RecordBatchReader with embeddings applied to it if the table definition
/// specifies an embedding column and the registry contains an embedding function with that name
/// Otherwise, this is a no-op and the inner RecordBatchReader is returned.
pub fn try_new(
inner: R,
table_definition: TableDefinition,
registry: Option<Arc<dyn EmbeddingRegistry>>,
) -> Result<Self> {
if let Some(registry) = registry {
let mut embeddings = Vec::with_capacity(table_definition.column_definitions.len());
for cd in table_definition.column_definitions.iter() {
if let ColumnKind::Embedding(embedding_def) = &cd.kind {
match registry.get(&embedding_def.embedding_name) {
Some(func) => {
embeddings.push((embedding_def.clone(), func));
}
None => {
return Err(Error::EmbeddingFunctionNotFound {
name: embedding_def.embedding_name.to_string(),
reason: format!(
"Table was defined with an embedding column `{}` but no embedding function was found with that name within the registry.",
embedding_def.embedding_name
),
});
}
}
}
}
if !embeddings.is_empty() {
return Ok(Self::Yes(WithEmbeddings { inner, embeddings }));
}
};
// No embeddings to apply
Ok(Self::No(inner))
}
}
impl<R: RecordBatchReader> WithEmbeddings<R> {
pub fn new(
inner: R,
embeddings: Vec<(EmbeddingDefinition, Arc<dyn EmbeddingFunction>)>,
) -> Self {
Self { inner, embeddings }
}
}
impl<R: RecordBatchReader> WithEmbeddings<R> {
fn dest_fields(&self) -> Result<Vec<Field>> {
let schema = self.inner.schema();
self.embeddings
.iter()
.map(|(ed, func)| {
let src_field = schema.field_with_name(&ed.source_column).unwrap();
let field_name = ed
.dest_column
.clone()
.unwrap_or_else(|| format!("{}_embedding", &ed.source_column));
Ok(Field::new(
field_name,
func.dest_type()?.into_owned(),
src_field.is_nullable(),
))
})
.collect()
}
fn column_defs(&self) -> Vec<ColumnDefinition> {
let base_schema = self.inner.schema();
base_schema
.fields()
.iter()
.map(|_| ColumnDefinition {
kind: ColumnKind::Physical,
})
.chain(self.embeddings.iter().map(|(ed, _)| ColumnDefinition {
kind: ColumnKind::Embedding(ed.clone()),
}))
.collect::<Vec<_>>()
}
pub fn table_definition(&self) -> Result<TableDefinition> {
let base_schema = self.inner.schema();
let output_fields = self.dest_fields()?;
let column_definitions = self.column_defs();
let mut sb: SchemaBuilder = base_schema.as_ref().into();
sb.extend(output_fields);
let schema = Arc::new(sb.finish());
Ok(TableDefinition {
schema,
column_definitions,
})
}
}
impl<R: RecordBatchReader> Iterator for MaybeEmbedded<R> {
type Item = std::result::Result<RecordBatch, arrow_schema::ArrowError>;
fn next(&mut self) -> Option<Self::Item> {
match self {
Self::Yes(inner) => inner.next(),
Self::No(inner) => inner.next(),
}
}
}
impl<R: RecordBatchReader> RecordBatchReader for MaybeEmbedded<R> {
fn schema(&self) -> Arc<arrow_schema::Schema> {
match self {
Self::Yes(inner) => inner.schema(),
Self::No(inner) => inner.schema(),
}
}
}
impl<R: RecordBatchReader> Iterator for WithEmbeddings<R> {
type Item = std::result::Result<RecordBatch, arrow_schema::ArrowError>;
fn next(&mut self) -> Option<Self::Item> {
let batch = self.inner.next()?;
match batch {
Ok(mut batch) => {
// todo: parallelize this
for (fld, func) in self.embeddings.iter() {
let src_column = batch.column_by_name(&fld.source_column).unwrap();
let embedding = match func.embed(src_column.clone()) {
Ok(embedding) => embedding,
Err(e) => {
return Some(Err(arrow_schema::ArrowError::ComputeError(format!(
"Error computing embedding: {}",
e
))))
}
};
let dst_field_name = fld
.dest_column
.clone()
.unwrap_or_else(|| format!("{}_embedding", &fld.source_column));
let dst_field = Field::new(
dst_field_name,
embedding.data_type().clone(),
embedding.nulls().is_some(),
);
match batch.try_with_column(dst_field.clone(), embedding) {
Ok(b) => batch = b,
Err(e) => return Some(Err(e)),
};
}
Some(Ok(batch))
}
Err(e) => Some(Err(e)),
}
}
}
impl<R: RecordBatchReader> RecordBatchReader for WithEmbeddings<R> {
fn schema(&self) -> Arc<arrow_schema::Schema> {
self.table_definition()
.expect("table definition should be infallible at this point")
.into_rich_schema()
}
}
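The `EmbeddingFunction` and `EmbeddingRegistry` traits and the `MemoryRegistry` removed in this file can be exercised on the base side roughly as follows; the no-op `Identity` function is a made-up stand-in (the integration tests at the end of this compare use a richer `MockEmbed`):

```rust
use std::{borrow::Cow, sync::Arc};

use arrow_array::Array;
use arrow_schema::DataType;
use lancedb::embeddings::{EmbeddingFunction, EmbeddingRegistry, MemoryRegistry};
use lancedb::Result;

/// Hypothetical "embedding" that passes the source column through unchanged.
#[derive(Debug)]
struct Identity;

impl EmbeddingFunction for Identity {
    fn name(&self) -> &str {
        "identity"
    }
    fn source_type(&self) -> Result<Cow<DataType>> {
        Ok(Cow::Owned(DataType::Utf8))
    }
    fn dest_type(&self) -> Result<Cow<DataType>> {
        Ok(Cow::Owned(DataType::Utf8))
    }
    fn embed(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
        Ok(source)
    }
}

fn main() -> Result<()> {
    // Register the function in an in-memory registry and look it up by name.
    let registry = MemoryRegistry::new();
    registry.register("identity", Arc::new(Identity))?;
    assert!(registry.get("identity").is_some());
    assert!(registry.functions().contains("identity"));
    Ok(())
}
```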


@@ -26,9 +26,6 @@ pub enum Error {
InvalidInput { message: String },
#[snafu(display("Table '{name}' was not found"))]
TableNotFound { name: String },
#[snafu(display("Embedding function '{name}' was not found. : {reason}"))]
EmbeddingFunctionNotFound { name: String, reason: String },
#[snafu(display("Table '{name}' already exists"))]
TableAlreadyExists { name: String },
#[snafu(display("Unable to created lance dataset at {path}: {source}"))]
@@ -115,13 +112,3 @@ impl From<url::ParseError> for Error {
}
}
}
#[cfg(feature = "polars")]
impl From<polars::prelude::PolarsError> for Error {
fn from(source: polars::prelude::PolarsError) -> Self {
Self::Other {
message: "Error in Polars DataFrame integration.".to_string(),
source: Some(Box::new(source)),
}
}
}


@@ -194,13 +194,10 @@
pub mod arrow;
pub mod connection;
pub mod data;
pub mod embeddings;
pub mod error;
pub mod index;
pub mod io;
pub mod ipc;
#[cfg(feature = "polars")]
mod polars_arrow_convertors;
pub mod query;
#[cfg(feature = "remote")]
pub(crate) mod remote;


@@ -1,123 +0,0 @@
/// Polars and LanceDB both use Arrow for their in memory-representation, but use
/// different Rust Arrow implementations. LanceDB uses the arrow-rs crate and
/// Polars uses the polars-arrow crate.
///
/// This crate defines zero-copy conversions (of the underlying buffers)
/// between polars-arrow and arrow-rs using the C FFI.
///
/// The polars-arrow does implement conversions to and from arrow-rs, but
/// requires a feature flagged dependency on arrow-rs. The version of arrow-rs
/// depended on by polars-arrow and LanceDB may not be compatible,
/// which necessitates using the C FFI.
use crate::error::Result;
use polars::prelude::{DataFrame, Series};
use std::{mem, sync::Arc};
/// When interpreting Polars dataframes as polars-arrow record batches,
/// one must decide whether to use Arrow string/binary view types
/// instead of the standard Arrow string/binary types.
/// For now, we will not use string view types because conversions
/// for string view types from polars-arrow to arrow-rs are not yet implemented.
/// See: https://lists.apache.org/thread/w88tpz76ox8h3rxkjl4so6rg3f1rv7wt for the
/// differences in the types.
pub const POLARS_ARROW_FLAVOR: bool = false;
const IS_ARRAY_NULLABLE: bool = true;
/// Converts a Polars DataFrame schema to an Arrow RecordBatch schema.
pub fn convert_polars_df_schema_to_arrow_rb_schema(
polars_df_schema: polars::prelude::Schema,
) -> Result<Arc<arrow_schema::Schema>> {
let arrow_fields: Result<Vec<arrow_schema::Field>> = polars_df_schema
.into_iter()
.map(|(name, df_dtype)| {
let polars_arrow_dtype = df_dtype.to_arrow(POLARS_ARROW_FLAVOR);
let polars_field =
polars_arrow::datatypes::Field::new(name, polars_arrow_dtype, IS_ARRAY_NULLABLE);
convert_polars_arrow_field_to_arrow_rs_field(polars_field)
})
.collect();
Ok(Arc::new(arrow_schema::Schema::new(arrow_fields?)))
}
/// Converts an Arrow RecordBatch schema to a Polars DataFrame schema.
pub fn convert_arrow_rb_schema_to_polars_df_schema(
arrow_schema: &arrow_schema::Schema,
) -> Result<polars::prelude::Schema> {
let polars_df_fields: Result<Vec<polars::prelude::Field>> = arrow_schema
.fields()
.iter()
.map(|arrow_rs_field| {
let polars_arrow_field = convert_arrow_rs_field_to_polars_arrow_field(arrow_rs_field)?;
Ok(polars::prelude::Field::new(
arrow_rs_field.name(),
polars::datatypes::DataType::from(polars_arrow_field.data_type()),
))
})
.collect();
Ok(polars::prelude::Schema::from_iter(polars_df_fields?))
}
/// Converts an Arrow RecordBatch to a Polars DataFrame, using a provided Polars DataFrame schema.
pub fn convert_arrow_rb_to_polars_df(
arrow_rb: &arrow::record_batch::RecordBatch,
polars_schema: &polars::prelude::Schema,
) -> Result<DataFrame> {
let mut columns: Vec<Series> = Vec::with_capacity(arrow_rb.num_columns());
for (i, column) in arrow_rb.columns().iter().enumerate() {
let polars_df_dtype = polars_schema.try_get_at_index(i)?.1;
let polars_arrow_dtype = polars_df_dtype.to_arrow(POLARS_ARROW_FLAVOR);
let polars_array =
convert_arrow_rs_array_to_polars_arrow_array(column, polars_arrow_dtype)?;
columns.push(Series::from_arrow(
polars_schema.try_get_at_index(i)?.0,
polars_array,
)?);
}
Ok(DataFrame::from_iter(columns))
}
/// Converts a polars-arrow Arrow array to an arrow-rs Arrow array.
pub fn convert_polars_arrow_array_to_arrow_rs_array(
polars_array: Box<dyn polars_arrow::array::Array>,
arrow_datatype: arrow_schema::DataType,
) -> std::result::Result<arrow_array::ArrayRef, arrow_schema::ArrowError> {
let polars_c_array = polars_arrow::ffi::export_array_to_c(polars_array);
let arrow_c_array = unsafe { mem::transmute(polars_c_array) };
Ok(arrow_array::make_array(unsafe {
arrow::ffi::from_ffi_and_data_type(arrow_c_array, arrow_datatype)
}?))
}
/// Converts an arrow-rs Arrow array to a polars-arrow Arrow array.
fn convert_arrow_rs_array_to_polars_arrow_array(
arrow_rs_array: &Arc<dyn arrow_array::Array>,
polars_arrow_dtype: polars::datatypes::ArrowDataType,
) -> Result<Box<dyn polars_arrow::array::Array>> {
let arrow_c_array = arrow::ffi::FFI_ArrowArray::new(&arrow_rs_array.to_data());
let polars_c_array = unsafe { mem::transmute(arrow_c_array) };
Ok(unsafe { polars_arrow::ffi::import_array_from_c(polars_c_array, polars_arrow_dtype) }?)
}
fn convert_polars_arrow_field_to_arrow_rs_field(
polars_arrow_field: polars_arrow::datatypes::Field,
) -> Result<arrow_schema::Field> {
let polars_c_schema = polars_arrow::ffi::export_field_to_c(&polars_arrow_field);
let arrow_c_schema: arrow::ffi::FFI_ArrowSchema = unsafe { mem::transmute(polars_c_schema) };
let arrow_rs_dtype = arrow_schema::DataType::try_from(&arrow_c_schema)?;
Ok(arrow_schema::Field::new(
polars_arrow_field.name,
arrow_rs_dtype,
IS_ARRAY_NULLABLE,
))
}
fn convert_arrow_rs_field_to_polars_arrow_field(
arrow_rs_field: &arrow_schema::Field,
) -> Result<polars_arrow::datatypes::Field> {
let arrow_rs_dtype = arrow_rs_field.data_type();
let arrow_c_schema = arrow::ffi::FFI_ArrowSchema::try_from(arrow_rs_dtype)?;
let polars_c_schema: polars_arrow::ffi::ArrowSchema = unsafe { mem::transmute(arrow_c_schema) };
Ok(unsafe { polars_arrow::ffi::import_field_from_c(&polars_c_schema) }?)
}


@@ -23,7 +23,6 @@ use tokio::task::spawn_blocking;
use crate::connection::{
ConnectionInternal, CreateTableBuilder, NoData, OpenTableBuilder, TableNamesBuilder,
};
use crate::embeddings::EmbeddingRegistry;
use crate::error::Result;
use crate::Table;
@@ -116,8 +115,4 @@ impl ConnectionInternal for RemoteDatabase {
async fn drop_db(&self) -> Result<()> {
todo!()
}
fn embedding_registry(&self) -> &dyn EmbeddingRegistry {
todo!()
}
}


@@ -10,7 +10,7 @@ use crate::{
query::{Query, QueryExecutionOptions, VectorQuery},
table::{
merge::MergeInsertBuilder, AddDataBuilder, NativeTable, OptimizeAction, OptimizeStats,
TableDefinition, TableInternal, UpdateBuilder,
TableInternal, UpdateBuilder,
},
};
@@ -120,7 +120,4 @@ impl TableInternal for RemoteTable {
async fn list_indices(&self) -> Result<Vec<IndexConfig>> {
todo!()
}
async fn table_definition(&self) -> Result<TableDefinition> {
todo!()
}
}


@@ -41,12 +41,10 @@ use lance::io::WrappingObjectStore;
use lance_index::IndexType;
use lance_index::{optimize::OptimizeOptions, DatasetIndexExt};
use log::info;
use serde::{Deserialize, Serialize};
use snafu::whatever;
use crate::arrow::IntoArrow;
use crate::connection::NoData;
use crate::embeddings::{EmbeddingDefinition, EmbeddingRegistry, MaybeEmbedded, MemoryRegistry};
use crate::error::{Error, Result};
use crate::index::vector::{IvfPqIndexBuilder, VectorIndex, VectorIndexStatistics};
use crate::index::IndexConfig;
@@ -65,79 +63,6 @@ use self::merge::MergeInsertBuilder;
pub(crate) mod dataset;
pub mod merge;
/// Defines the type of column
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ColumnKind {
/// Columns populated by data from the user (this is the most common case)
Physical,
/// Columns populated by applying an embedding function to the input
Embedding(EmbeddingDefinition),
}
/// Defines a column in a table
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ColumnDefinition {
/// The source of the column data
pub kind: ColumnKind,
}
#[derive(Debug, Clone)]
pub struct TableDefinition {
pub column_definitions: Vec<ColumnDefinition>,
pub schema: SchemaRef,
}
impl TableDefinition {
pub fn new(schema: SchemaRef, column_definitions: Vec<ColumnDefinition>) -> Self {
Self {
column_definitions,
schema,
}
}
pub fn new_from_schema(schema: SchemaRef) -> Self {
let column_definitions = schema
.fields()
.iter()
.map(|_| ColumnDefinition {
kind: ColumnKind::Physical,
})
.collect();
Self::new(schema, column_definitions)
}
pub fn try_from_rich_schema(schema: SchemaRef) -> Result<Self> {
let column_definitions = schema.metadata.get("lancedb::column_definitions");
if let Some(column_definitions) = column_definitions {
let column_definitions: Vec<ColumnDefinition> =
serde_json::from_str(column_definitions).map_err(|e| Error::Runtime {
message: format!("Failed to deserialize column definitions: {}", e),
})?;
Ok(Self::new(schema, column_definitions))
} else {
let column_definitions = schema
.fields()
.iter()
.map(|_| ColumnDefinition {
kind: ColumnKind::Physical,
})
.collect();
Ok(Self::new(schema, column_definitions))
}
}
pub fn into_rich_schema(self) -> SchemaRef {
// We have full control over the structure of column definitions. This should
// not fail, except for a bug
let lancedb_metadata = serde_json::to_string(&self.column_definitions).unwrap();
let mut schema_with_metadata = (*self.schema).clone();
schema_with_metadata
.metadata
.insert("lancedb::column_definitions".to_string(), lancedb_metadata);
Arc::new(schema_with_metadata)
}
}
/// Optimize the dataset.
///
/// Similar to `VACUUM` in PostgreSQL, it offers different options to
@@ -207,7 +132,6 @@ pub struct AddDataBuilder<T: IntoArrow> {
pub(crate) data: T,
pub(crate) mode: AddDataMode,
pub(crate) write_options: WriteOptions,
embedding_registry: Option<Arc<dyn EmbeddingRegistry>>,
}
impl<T: IntoArrow> std::fmt::Debug for AddDataBuilder<T> {
@@ -239,7 +163,6 @@ impl<T: IntoArrow> AddDataBuilder<T> {
mode: self.mode,
parent: self.parent,
write_options: self.write_options,
embedding_registry: self.embedding_registry,
};
parent.add(without_data, data).await
}
@@ -357,7 +280,6 @@ pub(crate) trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Syn
async fn checkout(&self, version: u64) -> Result<()>;
async fn checkout_latest(&self) -> Result<()>;
async fn restore(&self) -> Result<()>;
async fn table_definition(&self) -> Result<TableDefinition>;
}
/// A Table is a collection of strong typed Rows.
@@ -366,7 +288,6 @@ pub(crate) trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Syn
#[derive(Clone)]
pub struct Table {
inner: Arc<dyn TableInternal>,
embedding_registry: Arc<dyn EmbeddingRegistry>,
}
impl std::fmt::Display for Table {
@@ -377,20 +298,7 @@ impl std::fmt::Display for Table {
impl Table {
pub(crate) fn new(inner: Arc<dyn TableInternal>) -> Self {
Self {
inner,
embedding_registry: Arc::new(MemoryRegistry::new()),
}
}
pub(crate) fn new_with_embedding_registry(
inner: Arc<dyn TableInternal>,
embedding_registry: Arc<dyn EmbeddingRegistry>,
) -> Self {
Self {
inner,
embedding_registry,
}
Self { inner }
}
/// Cast as [`NativeTable`], or return None it if is not a [`NativeTable`].
@@ -432,7 +340,6 @@ impl Table {
data: batches,
mode: AddDataMode::Append,
write_options: WriteOptions::default(),
embedding_registry: Some(self.embedding_registry.clone()),
}
}
@@ -836,10 +743,11 @@ impl Table {
impl From<NativeTable> for Table {
fn from(table: NativeTable) -> Self {
Self::new(Arc::new(table))
Self {
inner: Arc::new(table),
}
}
}
/// A table in a LanceDB database.
#[derive(Debug, Clone)]
pub struct NativeTable {
@@ -1010,6 +918,7 @@ impl NativeTable {
Some(wrapper) => params.patch_with_store_wrapper(wrapper)?,
None => params,
};
let storage_options = params
.store_params
.clone()
@@ -1433,11 +1342,6 @@ impl TableInternal for NativeTable {
Ok(Arc::new(Schema::from(&lance_schema)))
}
async fn table_definition(&self) -> Result<TableDefinition> {
let schema = self.schema().await?;
TableDefinition::try_from_rich_schema(schema)
}
async fn count_rows(&self, filter: Option<String>) -> Result<usize> {
Ok(self.dataset.get().await?.count_rows(filter).await?)
}
@@ -1447,9 +1351,6 @@ impl TableInternal for NativeTable {
add: AddDataBuilder<NoData>,
data: Box<dyn RecordBatchReader + Send>,
) -> Result<()> {
let data =
MaybeEmbedded::try_new(data, self.table_definition().await?, add.embedding_registry)?;
let mut lance_params = add.write_options.lance_write_params.unwrap_or(WriteParams {
mode: match add.mode {
AddDataMode::Append => WriteMode::Append,
@@ -1477,8 +1378,8 @@ impl TableInternal for NativeTable {
};
self.dataset.ensure_mutable().await?;
let dataset = Dataset::write(data, &self.uri, Some(lance_params)).await?;
let dataset = Dataset::write(data, &self.uri, Some(lance_params)).await?;
self.dataset.set_latest(dataset).await;
Ok(())
}


@@ -1,320 +0,0 @@
use std::{
borrow::Cow,
collections::{HashMap, HashSet},
iter::repeat,
sync::Arc,
};
use arrow::buffer::NullBuffer;
use arrow_array::{
Array, FixedSizeListArray, Float32Array, Int32Array, RecordBatch, RecordBatchIterator,
StringArray,
};
use arrow_schema::{DataType, Field, Schema};
use futures::StreamExt;
use lancedb::{
arrow::IntoArrow,
connect,
embeddings::{EmbeddingDefinition, EmbeddingFunction, EmbeddingRegistry},
query::ExecutableQuery,
Error, Result,
};
#[tokio::test]
async fn test_custom_func() -> Result<()> {
let tempdir = tempfile::tempdir().unwrap();
let tempdir = tempdir.path().to_str().unwrap();
let db = connect(tempdir).execute().await?;
let embed_fun = MockEmbed::new("embed_fun".to_string(), 1);
db.embedding_registry()
.register("embed_fun", Arc::new(embed_fun.clone()))?;
let tbl = db
.create_table("test", create_some_records()?)
.add_embedding(EmbeddingDefinition::new(
"text",
&embed_fun.name,
Some("embeddings"),
))?
.execute()
.await?;
let mut res = tbl.query().execute().await?;
while let Some(Ok(batch)) = res.next().await {
let embeddings = batch.column_by_name("embeddings");
assert!(embeddings.is_some());
let embeddings = embeddings.unwrap();
assert_eq!(embeddings.data_type(), embed_fun.dest_type()?.as_ref());
}
// now make sure the embeddings are applied when
// we add new records too
tbl.add(create_some_records()?).execute().await?;
let mut res = tbl.query().execute().await?;
while let Some(Ok(batch)) = res.next().await {
let embeddings = batch.column_by_name("embeddings");
assert!(embeddings.is_some());
let embeddings = embeddings.unwrap();
assert_eq!(embeddings.data_type(), embed_fun.dest_type()?.as_ref());
}
Ok(())
}
#[tokio::test]
async fn test_custom_registry() -> Result<()> {
let tempdir = tempfile::tempdir().unwrap();
let tempdir = tempdir.path().to_str().unwrap();
let db = connect(tempdir)
.embedding_registry(Arc::new(MyRegistry::default()))
.execute()
.await?;
let tbl = db
.create_table("test", create_some_records()?)
.add_embedding(EmbeddingDefinition::new(
"text",
"func_1",
Some("embeddings"),
))?
.execute()
.await?;
let mut res = tbl.query().execute().await?;
while let Some(Ok(batch)) = res.next().await {
let embeddings = batch.column_by_name("embeddings");
assert!(embeddings.is_some());
let embeddings = embeddings.unwrap();
assert_eq!(
embeddings.data_type(),
MockEmbed::new("func_1".to_string(), 1)
.dest_type()?
.as_ref()
);
}
Ok(())
}
#[tokio::test]
async fn test_multiple_embeddings() -> Result<()> {
let tempdir = tempfile::tempdir().unwrap();
let tempdir = tempdir.path().to_str().unwrap();
let db = connect(tempdir).execute().await?;
let func_1 = MockEmbed::new("func_1".to_string(), 1);
let func_2 = MockEmbed::new("func_2".to_string(), 10);
db.embedding_registry()
.register(&func_1.name, Arc::new(func_1.clone()))?;
db.embedding_registry()
.register(&func_2.name, Arc::new(func_2.clone()))?;
let tbl = db
.create_table("test", create_some_records()?)
.add_embedding(EmbeddingDefinition::new(
"text",
&func_1.name,
Some("first_embeddings"),
))?
.add_embedding(EmbeddingDefinition::new(
"text",
&func_2.name,
Some("second_embeddings"),
))?
.execute()
.await?;
let mut res = tbl.query().execute().await?;
while let Some(Ok(batch)) = res.next().await {
let embeddings = batch.column_by_name("first_embeddings");
assert!(embeddings.is_some());
let second_embeddings = batch.column_by_name("second_embeddings");
assert!(second_embeddings.is_some());
let embeddings = embeddings.unwrap();
assert_eq!(embeddings.data_type(), func_1.dest_type()?.as_ref());
let second_embeddings = second_embeddings.unwrap();
assert_eq!(second_embeddings.data_type(), func_2.dest_type()?.as_ref());
}
// now make sure the embeddings are applied when
// we add new records too
tbl.add(create_some_records()?).execute().await?;
let mut res = tbl.query().execute().await?;
while let Some(Ok(batch)) = res.next().await {
let embeddings = batch.column_by_name("first_embeddings");
assert!(embeddings.is_some());
let second_embeddings = batch.column_by_name("second_embeddings");
assert!(second_embeddings.is_some());
let embeddings = embeddings.unwrap();
assert_eq!(embeddings.data_type(), func_1.dest_type()?.as_ref());
let second_embeddings = second_embeddings.unwrap();
assert_eq!(second_embeddings.data_type(), func_2.dest_type()?.as_ref());
}
Ok(())
}
#[tokio::test]
async fn test_no_func_in_registry() -> Result<()> {
let tempdir = tempfile::tempdir().unwrap();
let tempdir = tempdir.path().to_str().unwrap();
let db = connect(tempdir).execute().await?;
let res = db
.create_table("test", create_some_records()?)
.add_embedding(EmbeddingDefinition::new(
"text",
"some_func",
Some("first_embeddings"),
));
assert!(res.is_err());
assert!(matches!(
res.err().unwrap(),
Error::EmbeddingFunctionNotFound { .. }
));
Ok(())
}
#[tokio::test]
async fn test_no_func_in_registry_on_add() -> Result<()> {
let tempdir = tempfile::tempdir().unwrap();
let tempdir = tempdir.path().to_str().unwrap();
let db = connect(tempdir).execute().await?;
db.embedding_registry().register(
"some_func",
Arc::new(MockEmbed::new("some_func".to_string(), 1)),
)?;
db.create_table("test", create_some_records()?)
.add_embedding(EmbeddingDefinition::new(
"text",
"some_func",
Some("first_embeddings"),
))?
.execute()
.await?;
let db = connect(tempdir).execute().await?;
let tbl = db.open_table("test").execute().await?;
// This should fail because 'tbl' is expecting "some_func" to be in the registry
let res = tbl.add(create_some_records()?).execute().await;
assert!(res.is_err());
assert!(matches!(
res.unwrap_err(),
crate::Error::EmbeddingFunctionNotFound { .. }
));
Ok(())
}
fn create_some_records() -> Result<impl IntoArrow> {
const TOTAL: usize = 2;
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int32, false),
Field::new("text", DataType::Utf8, true),
]));
// Create a RecordBatch stream.
let batches = RecordBatchIterator::new(
vec![RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from_iter_values(0..TOTAL as i32)),
Arc::new(StringArray::from_iter(
repeat(Some("hello world".to_string())).take(TOTAL),
)),
],
)
.unwrap()]
.into_iter()
.map(Ok),
schema.clone(),
);
Ok(Box::new(batches))
}
#[derive(Debug)]
struct MyRegistry {
functions: HashMap<String, Arc<dyn EmbeddingFunction>>,
}
impl Default for MyRegistry {
fn default() -> Self {
let funcs: Vec<Arc<dyn EmbeddingFunction>> = vec![
Arc::new(MockEmbed::new("func_1".to_string(), 1)),
Arc::new(MockEmbed::new("func_2".to_string(), 10)),
];
Self {
functions: funcs
.into_iter()
.map(|f| (f.name().to_string(), f))
.collect(),
}
}
}
/// a mock registry that only has one function called `embed_fun`
impl EmbeddingRegistry for MyRegistry {
fn functions(&self) -> HashSet<String> {
self.functions.keys().cloned().collect()
}
fn register(&self, _name: &str, _function: Arc<dyn EmbeddingFunction>) -> Result<()> {
Err(Error::Other {
message: "MyRegistry is read-only".to_string(),
source: None,
})
}
fn get(&self, name: &str) -> Option<Arc<dyn EmbeddingFunction>> {
self.functions.get(name).cloned()
}
}
#[derive(Debug, Clone)]
struct MockEmbed {
source_type: DataType,
dest_type: DataType,
name: String,
dim: usize,
}
impl MockEmbed {
pub fn new(name: String, dim: usize) -> Self {
Self {
source_type: DataType::Utf8,
dest_type: DataType::new_fixed_size_list(DataType::Float32, dim as _, true),
name,
dim,
}
}
}
impl EmbeddingFunction for MockEmbed {
fn name(&self) -> &str {
&self.name
}
fn source_type(&self) -> Result<Cow<DataType>> {
Ok(Cow::Borrowed(&self.source_type))
}
fn dest_type(&self) -> Result<Cow<DataType>> {
Ok(Cow::Borrowed(&self.dest_type))
}
fn embed(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
// We can't use the FixedSizeListBuilder here because it always adds a null bitmap
// and we want to explicitly work with non-nullable arrays.
let len = source.len();
let inner = Arc::new(Float32Array::from(vec![Some(1.0); len * self.dim]));
let field = Field::new("item", inner.data_type().clone(), false);
let arr = FixedSizeListArray::new(
Arc::new(field),
self.dim as _,
inner,
Some(NullBuffer::new_valid(len)),
);
Ok(Arc::new(arr))
}
}