Compare commits

...

17 Commits

Author SHA1 Message Date
Lance Release
291ed41c3e Bump version: 0.8.1-beta.0 → 0.8.1 2024-05-30 01:00:21 +00:00
Lance Release
fdda7b1a76 Bump version: 0.8.0 → 0.8.1-beta.0 2024-05-30 01:00:21 +00:00
Weston Pace
eb2cbedf19 feat: upgrade lance to 0.11.1 (#1338) 2024-05-29 16:28:09 -07:00
Cory Grinstead
bc139000bd feat(nodejs): add compatibility across arrow versions (#1337)
While adding some more docs & examples for the new JS SDK, I ran across
a few compatibility issues when using different Arrow versions. This
should fix those issues.
2024-05-29 17:36:34 -05:00
Cory Grinstead
dbea3a7544 feat: js embedding registry (#1308)
---------

Co-authored-by: Will Jones <willjones127@gmail.com>
2024-05-29 13:12:19 -05:00
zhongpu
3bb7c546d7 fix: the bug of async connection context manager (#1333)
- add `return` for `__enter__`

The buggy code didn't return the object, so `db` is always None inside
the context manager:

```python
import lancedb

with await lancedb.connect_async("./.lancedb") as db:
    # db is always None here
    ...
```
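
The fix is a one-line change. Here is a minimal sketch of the pattern, using a simplified stand-in for the real connection class (not the actual LanceDB source):

```python
class AsyncConnection:
    """Simplified stand-in for the real class, for illustration only."""

    def close(self):
        ...  # release the underlying native connection

    def __enter__(self):
        # The bug: without this `return`, `with ... as db` binds None.
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False  # don't swallow exceptions
```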

(BTW, why not design an async context manager?)

- add a unit test for Async connection context manager

- update return type of `AsyncConnection.open_table` to `AsyncTable`

Although the type annotation doesn't affect functionality, it is helpful
for IDEs.
2024-05-29 09:33:32 -07:00
Cory Grinstead
2f4b70ecfe chore: clippy warnings inside java bindings (#1330)
This was causing unrelated PRs to fail.
https://github.com/lancedb/lancedb/actions/runs/9274579178/job/25517248069?pr=1308
2024-05-28 14:05:07 -05:00
Philip Meier
1ad1c0820d chore: replace semver dependency with packaging (#1311)
Fixes #1296 per title. See
https://github.com/lancedb/lancedb/pull/1298#discussion_r1603931457 Cc
@wjones127

---------

Co-authored-by: Will Jones <willjones127@gmail.com>
2024-05-28 10:05:16 -07:00
LuQQiu
db712b0f99 feat(java): add table names java api (#1279)
Add lancedb-jni and the table names API

---------

Co-authored-by: Lei Xu <eddyxu@gmail.com>
2024-05-24 11:49:11 -07:00
BubbleCal
fd1a5ce788 feat: support IVF_HNSW_PQ (#1314)
This also simplifies the index-creation code with a macro.

---------

Signed-off-by: BubbleCal <bubble-cal@outlook.com>
2024-05-24 18:32:00 +08:00
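
As context for the new index type, here is a hedged sketch of requesting it from the Python SDK; it assumes `create_index` accepts `index_type="IVF_HNSW_PQ"`, and the table name and tuning values are purely illustrative:

```python
import lancedb

db = lancedb.connect("./.lancedb")
tbl = db.open_table("my_vectors")  # hypothetical table with a vector column

# IVF_HNSW_PQ: an IVF coarse quantizer partitions the vectors, and an HNSW
# graph over PQ-compressed vectors is used for search within partitions.
tbl.create_index(
    metric="L2",
    index_type="IVF_HNSW_PQ",
    num_partitions=256,  # number of IVF partitions (illustrative)
    num_sub_vectors=16,  # number of PQ sub-vectors (illustrative)
)
```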
QianZhu
def087fc85 fix: parse index_stats for scalar index (#1319)
Parse the index stats for scalar indices; the format differs from the
index stats for vector indices.
2024-05-23 13:10:46 -07:00
Lance Release
43f920182a Bump version: 0.8.0-beta.0 → 0.8.0 2024-05-23 17:32:36 +00:00
Lance Release
718963d1fb Bump version: 0.7.0 → 0.8.0-beta.0 2024-05-23 17:32:36 +00:00
Weston Pace
e4dac751e7 chore: remove working-directory from pypi upload step (#1322)
The wheels are built to `WORKDIR/target/wheels` and the step was
configured to look for them at `WORKDIR/python/target/wheels`.
2024-05-23 10:31:32 -07:00
Lance Release
aae02953eb Updating package-lock.json 2024-05-23 16:30:46 +00:00
Lance Release
1d9f76bdda Bump version: 0.5.0-beta.0 → 0.5.0 2024-05-23 16:30:27 +00:00
Lance Release
affdfc4d48 Bump version: 0.4.20 → 0.5.0-beta.0 2024-05-23 16:30:26 +00:00
78 changed files with 10471 additions and 8245 deletions


@@ -1,5 +1,5 @@
 [tool.bumpversion]
-current_version = "0.4.20"
+current_version = "0.5.0"
 parse = """(?x)
     (?P<major>0|[1-9]\\d*)\\.
     (?P<minor>0|[1-9]\\d*)\\.

.github/workflows/java.yml vendored Normal file

@@ -0,0 +1,85 @@
name: Build and Run Java JNI Tests
on:
push:
branches:
- main
pull_request:
paths:
- java/**
- rust/**
- .github/workflows/java.yml
env:
# This env var is used by Swatinem/rust-cache@v2 for the cache
# key, so we set it to make sure it is always consistent.
CARGO_TERM_COLOR: always
# Disable full debug symbol generation to speed up CI build and keep memory down
# "1" means line tables only, which is useful for panic tracebacks.
RUSTFLAGS: "-C debuginfo=1"
RUST_BACKTRACE: "1"
# according to: https://matklad.github.io/2021/09/04/fast-rust-builds.html
# CI builds are faster with incremental disabled.
CARGO_INCREMENTAL: "0"
CARGO_BUILD_JOBS: "1"
jobs:
linux-build:
runs-on: ubuntu-22.04
name: ubuntu-22.04 + Java 11 & 17
defaults:
run:
working-directory: ./java
steps:
- name: Checkout repository
uses: actions/checkout@v4
- uses: Swatinem/rust-cache@v2
with:
workspaces: java/core/lancedb-jni
- name: Run cargo fmt
run: cargo fmt --check
working-directory: ./java/core/lancedb-jni
- name: Install dependencies
run: |
sudo apt update
sudo apt install -y protobuf-compiler libssl-dev
- name: Install Java 17
uses: actions/setup-java@v4
with:
distribution: temurin
java-version: 17
cache: "maven"
- run: echo "JAVA_17=$JAVA_HOME" >> $GITHUB_ENV
- name: Install Java 11
uses: actions/setup-java@v4
with:
distribution: temurin
java-version: 11
cache: "maven"
- name: Java Style Check
run: mvn checkstyle:check
# Disable because of issues in lancedb rust core code
# - name: Rust Clippy
# working-directory: java/core/lancedb-jni
# run: cargo clippy --all-targets -- -D warnings
- name: Running tests with Java 11
run: mvn clean test
- name: Running tests with Java 17
run: |
export JAVA_TOOL_OPTIONS="$JAVA_TOOL_OPTIONS \
-XX:+IgnoreUnrecognizedVMOptions \
--add-opens=java.base/java.lang=ALL-UNNAMED \
--add-opens=java.base/java.lang.invoke=ALL-UNNAMED \
--add-opens=java.base/java.lang.reflect=ALL-UNNAMED \
--add-opens=java.base/java.io=ALL-UNNAMED \
--add-opens=java.base/java.net=ALL-UNNAMED \
--add-opens=java.base/java.nio=ALL-UNNAMED \
--add-opens=java.base/java.util=ALL-UNNAMED \
--add-opens=java.base/java.util.concurrent=ALL-UNNAMED \
--add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED \
--add-opens=java.base/jdk.internal.ref=ALL-UNNAMED \
--add-opens=java.base/sun.nio.ch=ALL-UNNAMED \
--add-opens=java.base/sun.nio.cs=ALL-UNNAMED \
--add-opens=java.base/sun.security.action=ALL-UNNAMED \
--add-opens=java.base/sun.util.calendar=ALL-UNNAMED \
--add-opens=java.security.jgss/sun.security.krb5=ALL-UNNAMED \
-Djdk.reflect.useDirectMethodHandle=false \
-Dio.netty.tryReflectionSetAccessible=true"
JAVA_HOME=$JAVA_17 mvn clean test


@@ -27,7 +27,6 @@ runs:
         echo "repo=pypi" >> $GITHUB_OUTPUT
       fi
     - name: Publish to PyPI
-      working-directory: python
       shell: bash
       env:
         FURY_TOKEN: ${{ inputs.fury_token }}


@@ -14,7 +14,7 @@ repos:
   hooks:
     - id: local-biome-check
       name: biome check
-      entry: npx biome check
+      entry: npx @biomejs/biome check --config-path nodejs/biome.json nodejs/
       language: system
       types: [text]
       files: "nodejs/.*"


@@ -1,5 +1,5 @@
 [workspace]
-members = ["rust/ffi/node", "rust/lancedb", "nodejs", "python"]
+members = ["rust/ffi/node", "rust/lancedb", "nodejs", "python", "java/core/lancedb-jni"]
 # Python package needs to be built by maturin.
 exclude = ["python"]
 resolver = "2"
@@ -14,10 +14,10 @@ keywords = ["lancedb", "lance", "database", "vector", "search"]
 categories = ["database-implementations"]
 [workspace.dependencies]
-lance = { "version" = "=0.11.0", "features" = ["dynamodb"] }
-lance-index = { "version" = "=0.11.0" }
-lance-linalg = { "version" = "=0.11.0" }
-lance-testing = { "version" = "=0.11.0" }
+lance = { "version" = "=0.11.1", "features" = ["dynamodb"] }
+lance-index = { "version" = "=0.11.1" }
+lance-linalg = { "version" = "=0.11.1" }
+lance-testing = { "version" = "=0.11.1" }
 # Note that this one does not include pyarrow
 arrow = { version = "51.0", optional = false }
 arrow-array = "51.0"


@@ -0,0 +1,27 @@
[package]
name = "lancedb-jni"
description = "JNI bindings for LanceDB"
# TODO modify lancedb/Cargo.toml for version and dependencies
version = "0.4.18"
edition.workspace = true
repository.workspace = true
readme.workspace = true
license.workspace = true
keywords.workspace = true
categories.workspace = true
publish = false
[lib]
crate-type = ["cdylib"]
[dependencies]
lancedb = { path = "../../../rust/lancedb" }
lance = { workspace = true }
arrow = { workspace = true, features = ["ffi"] }
arrow-schema.workspace = true
tokio = "1.23"
jni = "0.21.1"
snafu.workspace = true
lazy_static.workspace = true
serde = { version = "^1" }
serde_json = { version = "1" }


@@ -0,0 +1,130 @@
use crate::ffi::JNIEnvExt;
use crate::traits::IntoJava;
use crate::{Error, RT};
use jni::objects::{JObject, JString, JValue};
use jni::JNIEnv;
pub const NATIVE_CONNECTION: &str = "nativeConnectionHandle";
use crate::Result;
use lancedb::connection::{connect, Connection};
#[derive(Clone)]
pub struct BlockingConnection {
pub(crate) inner: Connection,
}
impl BlockingConnection {
pub fn create(dataset_uri: &str) -> Result<Self> {
let inner = RT.block_on(connect(dataset_uri).execute())?;
Ok(Self { inner })
}
pub fn table_names(
&self,
start_after: Option<String>,
limit: Option<i32>,
) -> Result<Vec<String>> {
let mut op = self.inner.table_names();
if let Some(start_after) = start_after {
op = op.start_after(start_after);
}
if let Some(limit) = limit {
op = op.limit(limit as u32);
}
Ok(RT.block_on(op.execute())?)
}
}
impl IntoJava for BlockingConnection {
fn into_java<'a>(self, env: &mut JNIEnv<'a>) -> JObject<'a> {
attach_native_connection(env, self)
}
}
fn attach_native_connection<'local>(
env: &mut JNIEnv<'local>,
connection: BlockingConnection,
) -> JObject<'local> {
let j_connection = create_java_connection_object(env);
// This block sets a native Rust object (Connection) as a field in the Java object (j_connection).
// Caution: This creates a potential for memory leaks. The Rust object (Connection) is not
// automatically garbage-collected by Java, and its memory will not be freed unless
// explicitly handled.
//
// To prevent memory leaks, ensure the following:
// 1. The Java object (`j_connection`) should implement the `java.io.Closeable` interface.
// 2. Users of this Java object should be instructed to always use it within a try-with-resources
// statement (or manually call the `close()` method) to ensure that `close()` is invoked.
match unsafe { env.set_rust_field(&j_connection, NATIVE_CONNECTION, connection) } {
Ok(_) => j_connection,
Err(err) => {
env.throw_new(
"java/lang/RuntimeException",
format!("Failed to set native handle for Connection: {}", err),
)
.expect("Error throwing exception");
JObject::null()
}
}
}
fn create_java_connection_object<'a>(env: &mut JNIEnv<'a>) -> JObject<'a> {
env.new_object("com/lancedb/lancedb/Connection", "()V", &[])
.expect("Failed to create Java Lance Connection instance")
}
#[no_mangle]
pub extern "system" fn Java_com_lancedb_lancedb_Connection_releaseNativeConnection(
mut env: JNIEnv,
j_connection: JObject,
) {
let _: BlockingConnection = unsafe {
env.take_rust_field(j_connection, NATIVE_CONNECTION)
.expect("Failed to take native Connection handle")
};
}
#[no_mangle]
pub extern "system" fn Java_com_lancedb_lancedb_Connection_connect<'local>(
mut env: JNIEnv<'local>,
_obj: JObject,
dataset_uri_object: JString,
) -> JObject<'local> {
let dataset_uri: String = ok_or_throw!(env, env.get_string(&dataset_uri_object)).into();
let blocking_connection = ok_or_throw!(env, BlockingConnection::create(&dataset_uri));
blocking_connection.into_java(&mut env)
}
#[no_mangle]
pub extern "system" fn Java_com_lancedb_lancedb_Connection_tableNames<'local>(
mut env: JNIEnv<'local>,
j_connection: JObject,
start_after_obj: JObject, // Optional<String>
limit_obj: JObject, // Optional<Integer>
) -> JObject<'local> {
ok_or_throw!(
env,
inner_table_names(&mut env, j_connection, start_after_obj, limit_obj)
)
}
fn inner_table_names<'local>(
env: &mut JNIEnv<'local>,
j_connection: JObject,
start_after_obj: JObject, // Optional<String>
limit_obj: JObject, // Optional<Integer>
) -> Result<JObject<'local>> {
let start_after = env.get_string_opt(&start_after_obj)?;
let limit = env.get_int_opt(&limit_obj)?;
let conn =
unsafe { env.get_rust_field::<_, _, BlockingConnection>(j_connection, NATIVE_CONNECTION) }?;
let table_names = conn.table_names(start_after, limit)?;
drop(conn);
let j_names = env.new_object("java/util/ArrayList", "()V", &[])?;
for item in table_names {
let jstr_item = env.new_string(item)?;
let item_jobj = JObject::from(jstr_item);
let item_gen = JValue::Object(&item_jobj);
env.call_method(&j_names, "add", "(Ljava/lang/Object;)Z", &[item_gen])?;
}
Ok(j_names)
}


@@ -0,0 +1,225 @@
// Copyright 2024 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::str::Utf8Error;
use arrow_schema::ArrowError;
use jni::errors::Error as JniError;
use serde_json::Error as JsonError;
use snafu::{Location, Snafu};
type BoxedError = Box<dyn std::error::Error + Send + Sync + 'static>;
/// Java Exception types
pub enum JavaException {
IllegalArgumentException,
IOException,
RuntimeException,
}
impl JavaException {
pub fn as_str(&self) -> &str {
match self {
Self::IllegalArgumentException => "java/lang/IllegalArgumentException",
Self::IOException => "java/io/IOException",
Self::RuntimeException => "java/lang/RuntimeException",
}
}
}
/// TODO(lu) change to lancedb-jni
#[derive(Debug, Snafu)]
#[snafu(visibility(pub))]
pub enum Error {
#[snafu(display("JNI error: {message}, {location}"))]
Jni { message: String, location: Location },
#[snafu(display("Invalid argument: {message}, {location}"))]
InvalidArgument { message: String, location: Location },
#[snafu(display("IO error: {source}, {location}"))]
IO {
source: BoxedError,
location: Location,
},
#[snafu(display("Arrow error: {message}, {location}"))]
Arrow { message: String, location: Location },
#[snafu(display("Index error: {message}, {location}"))]
Index { message: String, location: Location },
#[snafu(display("JSON error: {message}, {location}"))]
JSON { message: String, location: Location },
#[snafu(display("Dataset at path {path} was not found, {location}"))]
DatasetNotFound { path: String, location: Location },
#[snafu(display("Dataset already exists: {uri}, {location}"))]
DatasetAlreadyExists { uri: String, location: Location },
#[snafu(display("Table '{name}' already exists"))]
TableAlreadyExists { name: String },
#[snafu(display("Table '{name}' was not found"))]
TableNotFound { name: String },
#[snafu(display("Invalid table name '{name}': {reason}"))]
InvalidTableName { name: String, reason: String },
#[snafu(display("Embedding function '{name}' was not found: {reason}, {location}"))]
EmbeddingFunctionNotFound {
name: String,
reason: String,
location: Location,
},
#[snafu(display("Other Lance error: {message}, {location}"))]
OtherLance { message: String, location: Location },
#[snafu(display("Other LanceDB error: {message}, {location}"))]
OtherLanceDB { message: String, location: Location },
}
impl Error {
/// Throw as Java Exception
pub fn throw(&self, env: &mut jni::JNIEnv) {
match self {
Self::InvalidArgument { .. }
| Self::DatasetNotFound { .. }
| Self::DatasetAlreadyExists { .. }
| Self::TableAlreadyExists { .. }
| Self::TableNotFound { .. }
| Self::InvalidTableName { .. }
| Self::EmbeddingFunctionNotFound { .. } => {
self.throw_as(env, JavaException::IllegalArgumentException)
}
Self::IO { .. } | Self::Index { .. } => self.throw_as(env, JavaException::IOException),
Self::Arrow { .. }
| Self::JSON { .. }
| Self::OtherLance { .. }
| Self::OtherLanceDB { .. }
| Self::Jni { .. } => self.throw_as(env, JavaException::RuntimeException),
}
}
/// Throw as a concrete Java Exception
pub fn throw_as(&self, env: &mut jni::JNIEnv, exception: JavaException) {
let message = &format!(
"Error when throwing Java exception: {}:{}",
exception.as_str(),
self
);
env.throw_new(exception.as_str(), self.to_string())
.expect(message);
}
}
pub type Result<T> = std::result::Result<T, Error>;
trait ToSnafuLocation {
fn to_snafu_location(&'static self) -> snafu::Location;
}
impl ToSnafuLocation for std::panic::Location<'static> {
fn to_snafu_location(&'static self) -> snafu::Location {
snafu::Location::new(self.file(), self.line(), self.column())
}
}
impl From<JniError> for Error {
#[track_caller]
fn from(source: JniError) -> Self {
Self::Jni {
message: source.to_string(),
location: std::panic::Location::caller().to_snafu_location(),
}
}
}
impl From<Utf8Error> for Error {
#[track_caller]
fn from(source: Utf8Error) -> Self {
Self::InvalidArgument {
message: source.to_string(),
location: std::panic::Location::caller().to_snafu_location(),
}
}
}
impl From<ArrowError> for Error {
#[track_caller]
fn from(source: ArrowError) -> Self {
Self::Arrow {
message: source.to_string(),
location: std::panic::Location::caller().to_snafu_location(),
}
}
}
impl From<JsonError> for Error {
#[track_caller]
fn from(source: JsonError) -> Self {
Self::JSON {
message: source.to_string(),
location: std::panic::Location::caller().to_snafu_location(),
}
}
}
impl From<lance::Error> for Error {
#[track_caller]
fn from(source: lance::Error) -> Self {
match source {
lance::Error::DatasetNotFound {
path,
source: _,
location,
} => Self::DatasetNotFound { path, location },
lance::Error::DatasetAlreadyExists { uri, location } => {
Self::DatasetAlreadyExists { uri, location }
}
lance::Error::IO { source, location } => Self::IO { source, location },
lance::Error::Arrow { message, location } => Self::Arrow { message, location },
lance::Error::Index { message, location } => Self::Index { message, location },
lance::Error::InvalidInput { source, location } => Self::InvalidArgument {
message: source.to_string(),
location,
},
_ => Self::OtherLance {
message: source.to_string(),
location: std::panic::Location::caller().to_snafu_location(),
},
}
}
}
impl From<lancedb::Error> for Error {
#[track_caller]
fn from(source: lancedb::Error) -> Self {
match source {
lancedb::Error::InvalidTableName { name, reason } => {
Self::InvalidTableName { name, reason }
}
lancedb::Error::InvalidInput { message } => Self::InvalidArgument {
message,
location: std::panic::Location::caller().to_snafu_location(),
},
lancedb::Error::TableNotFound { name } => Self::TableNotFound { name },
lancedb::Error::TableAlreadyExists { name } => Self::TableAlreadyExists { name },
lancedb::Error::EmbeddingFunctionNotFound { name, reason } => {
Self::EmbeddingFunctionNotFound {
name,
reason,
location: std::panic::Location::caller().to_snafu_location(),
}
}
lancedb::Error::Arrow { source } => Self::Arrow {
message: source.to_string(),
location: std::panic::Location::caller().to_snafu_location(),
},
lancedb::Error::Lance { source } => Self::from(source),
_ => Self::OtherLanceDB {
message: source.to_string(),
location: std::panic::Location::caller().to_snafu_location(),
},
}
}
}


@@ -0,0 +1,204 @@
// Copyright 2024 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use core::slice;
use jni::objects::{JByteBuffer, JObjectArray, JString};
use jni::sys::jobjectArray;
use jni::{objects::JObject, JNIEnv};
use crate::error::{Error, Result};
/// TODO(lu) import from lance-jni without duplicating code
/// Extend JNIEnv with helper functions.
pub trait JNIEnvExt {
/// Get integers from Java List<Integer> object.
fn get_integers(&mut self, obj: &JObject) -> Result<Vec<i32>>;
/// Get strings from Java List<String> object.
fn get_strings(&mut self, obj: &JObject) -> Result<Vec<String>>;
/// Get strings from Java String[] object.
/// Note that getting Option<Vec<String>> from a Java Optional<String[]> just doesn't work.
#[allow(unused)]
fn get_strings_array(&mut self, obj: jobjectArray) -> Result<Vec<String>>;
/// Get Option<String> from Java Optional<String>.
fn get_string_opt(&mut self, obj: &JObject) -> Result<Option<String>>;
/// Get Option<Vec<String>> from Java Optional<List<String>>.
#[allow(unused)]
fn get_strings_opt(&mut self, obj: &JObject) -> Result<Option<Vec<String>>>;
/// Get Option<i32> from Java Optional<Integer>.
fn get_int_opt(&mut self, obj: &JObject) -> Result<Option<i32>>;
/// Get Option<Vec<i32>> from Java Optional<List<Integer>>.
fn get_ints_opt(&mut self, obj: &JObject) -> Result<Option<Vec<i32>>>;
/// Get Option<i64> from Java Optional<Long>.
#[allow(unused)]
fn get_long_opt(&mut self, obj: &JObject) -> Result<Option<i64>>;
/// Get Option<u64> from Java Optional<Long>.
#[allow(unused)]
fn get_u64_opt(&mut self, obj: &JObject) -> Result<Option<u64>>;
/// Get Option<&[u8]> from Java Optional<ByteBuffer>.
#[allow(unused)]
fn get_bytes_opt(&mut self, obj: &JObject) -> Result<Option<&[u8]>>;
fn get_optional<T, F>(&mut self, obj: &JObject, f: F) -> Result<Option<T>>
where
F: FnOnce(&mut JNIEnv, &JObject) -> Result<T>;
}
impl JNIEnvExt for JNIEnv<'_> {
fn get_integers(&mut self, obj: &JObject) -> Result<Vec<i32>> {
let list = self.get_list(obj)?;
let mut iter = list.iter(self)?;
let mut results = Vec::with_capacity(list.size(self)? as usize);
while let Some(elem) = iter.next(self)? {
let int_obj = self.call_method(elem, "intValue", "()I", &[])?;
let int_value = int_obj.i()?;
results.push(int_value);
}
Ok(results)
}
fn get_strings(&mut self, obj: &JObject) -> Result<Vec<String>> {
let list = self.get_list(obj)?;
let mut iter = list.iter(self)?;
let mut results = Vec::with_capacity(list.size(self)? as usize);
while let Some(elem) = iter.next(self)? {
let jstr = JString::from(elem);
let val = self.get_string(&jstr)?;
results.push(val.to_str()?.to_string())
}
Ok(results)
}
fn get_strings_array(&mut self, obj: jobjectArray) -> Result<Vec<String>> {
let jobject_array = unsafe { JObjectArray::from_raw(obj) };
let array_len = self.get_array_length(&jobject_array)?;
let mut res: Vec<String> = Vec::new();
for i in 0..array_len {
let item: JString = self.get_object_array_element(&jobject_array, i)?.into();
res.push(self.get_string(&item)?.into());
}
Ok(res)
}
fn get_string_opt(&mut self, obj: &JObject) -> Result<Option<String>> {
self.get_optional(obj, |env, inner_obj| {
let java_obj_gen = env.call_method(inner_obj, "get", "()Ljava/lang/Object;", &[])?;
let java_string_obj = java_obj_gen.l()?;
let jstr = JString::from(java_string_obj);
let val = env.get_string(&jstr)?;
Ok(val.to_str()?.to_string())
})
}
fn get_strings_opt(&mut self, obj: &JObject) -> Result<Option<Vec<String>>> {
self.get_optional(obj, |env, inner_obj| {
let java_obj_gen = env.call_method(inner_obj, "get", "()Ljava/lang/Object;", &[])?;
let java_list_obj = java_obj_gen.l()?;
env.get_strings(&java_list_obj)
})
}
fn get_int_opt(&mut self, obj: &JObject) -> Result<Option<i32>> {
self.get_optional(obj, |env, inner_obj| {
let java_obj_gen = env.call_method(inner_obj, "get", "()Ljava/lang/Object;", &[])?;
let java_int_obj = java_obj_gen.l()?;
let int_obj = env.call_method(java_int_obj, "intValue", "()I", &[])?;
let int_value = int_obj.i()?;
Ok(int_value)
})
}
fn get_ints_opt(&mut self, obj: &JObject) -> Result<Option<Vec<i32>>> {
self.get_optional(obj, |env, inner_obj| {
let java_obj_gen = env.call_method(inner_obj, "get", "()Ljava/lang/Object;", &[])?;
let java_list_obj = java_obj_gen.l()?;
env.get_integers(&java_list_obj)
})
}
fn get_long_opt(&mut self, obj: &JObject) -> Result<Option<i64>> {
self.get_optional(obj, |env, inner_obj| {
let java_obj_gen = env.call_method(inner_obj, "get", "()Ljava/lang/Object;", &[])?;
let java_long_obj = java_obj_gen.l()?;
let long_obj = env.call_method(java_long_obj, "longValue", "()J", &[])?;
let long_value = long_obj.j()?;
Ok(long_value)
})
}
fn get_u64_opt(&mut self, obj: &JObject) -> Result<Option<u64>> {
self.get_optional(obj, |env, inner_obj| {
let java_obj_gen = env.call_method(inner_obj, "get", "()Ljava/lang/Object;", &[])?;
let java_long_obj = java_obj_gen.l()?;
let long_obj = env.call_method(java_long_obj, "longValue", "()J", &[])?;
let long_value = long_obj.j()?;
Ok(long_value as u64)
})
}
fn get_bytes_opt(&mut self, obj: &JObject) -> Result<Option<&[u8]>> {
self.get_optional(obj, |env, inner_obj| {
let java_obj_gen = env.call_method(inner_obj, "get", "()Ljava/lang/Object;", &[])?;
let java_byte_buffer_obj = java_obj_gen.l()?;
let j_byte_buffer = JByteBuffer::from(java_byte_buffer_obj);
let raw_data = env.get_direct_buffer_address(&j_byte_buffer)?;
let capacity = env.get_direct_buffer_capacity(&j_byte_buffer)?;
let data = unsafe { slice::from_raw_parts(raw_data, capacity) };
Ok(data)
})
}
fn get_optional<T, F>(&mut self, obj: &JObject, f: F) -> Result<Option<T>>
where
F: FnOnce(&mut JNIEnv, &JObject) -> Result<T>,
{
if obj.is_null() {
return Ok(None);
}
let is_empty = self.call_method(obj, "isEmpty", "()Z", &[])?;
if is_empty.z()? {
// TODO(lu): move the Java `get()` call in here, since we can only retrieve a java.lang.Object
Ok(None)
} else {
f(self, obj).map(Some)
}
}
}
#[no_mangle]
pub extern "system" fn Java_com_lancedb_lance_test_JniTestHelper_parseInts(
mut env: JNIEnv,
_obj: JObject,
list_obj: JObject, // List<Integer>
) {
ok_or_throw_without_return!(env, env.get_integers(&list_obj));
}
#[no_mangle]
pub extern "system" fn Java_com_lancedb_lance_test_JniTestHelper_parseIntsOpt(
mut env: JNIEnv,
_obj: JObject,
list_obj: JObject, // Optional<List<Integer>>
) {
ok_or_throw_without_return!(env, env.get_ints_opt(&list_obj));
}


@@ -0,0 +1,68 @@
// Copyright 2024 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use lazy_static::lazy_static;
// TODO import from lance-jni without duplicating code
#[macro_export]
macro_rules! ok_or_throw {
($env:expr, $result:expr) => {
match $result {
Ok(value) => value,
Err(err) => {
Error::from(err).throw(&mut $env);
return JObject::null();
}
}
};
}
macro_rules! ok_or_throw_without_return {
($env:expr, $result:expr) => {
match $result {
Ok(value) => value,
Err(err) => {
Error::from(err).throw(&mut $env);
return;
}
}
};
}
#[macro_export]
macro_rules! ok_or_throw_with_return {
($env:expr, $result:expr, $ret:expr) => {
match $result {
Ok(value) => value,
Err(err) => {
Error::from(err).throw(&mut $env);
return $ret;
}
}
};
}
mod connection;
pub mod error;
mod ffi;
mod traits;
pub use error::{Error, Result};
lazy_static! {
static ref RT: tokio::runtime::Runtime = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.expect("Failed to create tokio runtime");
}


@@ -0,0 +1,122 @@
// Copyright 2024 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use jni::objects::{JMap, JObject, JString, JValue};
use jni::JNIEnv;
use crate::Result;
pub trait FromJObject<T> {
fn extract(&self) -> Result<T>;
}
/// Convert a Rust type into a Java Object.
pub trait IntoJava {
fn into_java<'a>(self, env: &mut JNIEnv<'a>) -> JObject<'a>;
}
impl FromJObject<i32> for JObject<'_> {
fn extract(&self) -> Result<i32> {
Ok(JValue::from(self).i()?)
}
}
impl FromJObject<i64> for JObject<'_> {
fn extract(&self) -> Result<i64> {
Ok(JValue::from(self).j()?)
}
}
impl FromJObject<f32> for JObject<'_> {
fn extract(&self) -> Result<f32> {
Ok(JValue::from(self).f()?)
}
}
impl FromJObject<f64> for JObject<'_> {
fn extract(&self) -> Result<f64> {
Ok(JValue::from(self).d()?)
}
}
pub trait FromJString {
fn extract(&self, env: &mut JNIEnv) -> Result<String>;
}
impl FromJString for JString<'_> {
fn extract(&self, env: &mut JNIEnv) -> Result<String> {
Ok(env.get_string(self)?.into())
}
}
pub trait JMapExt {
#[allow(dead_code)]
fn get_string(&self, env: &mut JNIEnv, key: &str) -> Result<Option<String>>;
#[allow(dead_code)]
fn get_i32(&self, env: &mut JNIEnv, key: &str) -> Result<Option<i32>>;
#[allow(dead_code)]
fn get_i64(&self, env: &mut JNIEnv, key: &str) -> Result<Option<i64>>;
#[allow(dead_code)]
fn get_f32(&self, env: &mut JNIEnv, key: &str) -> Result<Option<f32>>;
#[allow(dead_code)]
fn get_f64(&self, env: &mut JNIEnv, key: &str) -> Result<Option<f64>>;
}
fn get_map_value<T>(env: &mut JNIEnv, map: &JMap, key: &str) -> Result<Option<T>>
where
for<'a> JObject<'a>: FromJObject<T>,
{
let key_obj: JObject = env.new_string(key)?.into();
if let Some(value) = map.get(env, &key_obj)? {
if value.is_null() {
Ok(None)
} else {
Ok(Some(value.extract()?))
}
} else {
Ok(None)
}
}
impl JMapExt for JMap<'_, '_, '_> {
fn get_string(&self, env: &mut JNIEnv, key: &str) -> Result<Option<String>> {
let key_obj: JObject = env.new_string(key)?.into();
if let Some(value) = self.get(env, &key_obj)? {
let value_str: JString = value.into();
Ok(Some(value_str.extract(env)?))
} else {
Ok(None)
}
}
fn get_i32(&self, env: &mut JNIEnv, key: &str) -> Result<Option<i32>> {
get_map_value(env, self, key)
}
fn get_i64(&self, env: &mut JNIEnv, key: &str) -> Result<Option<i64>> {
get_map_value(env, self, key)
}
fn get_f32(&self, env: &mut JNIEnv, key: &str) -> Result<Option<f32>> {
get_map_value(env, self, key)
}
fn get_f64(&self, env: &mut JNIEnv, key: &str) -> Result<Option<f64>> {
get_map_value(env, self, key)
}
}

java/core/pom.xml Normal file

@@ -0,0 +1,94 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId>
<version>0.1-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>
<artifactId>lancedb-core</artifactId>
<name>LanceDB Core</name>
<packaging>jar</packaging>
<dependencies>
<dependency>
<groupId>org.apache.arrow</groupId>
<artifactId>arrow-vector</artifactId>
</dependency>
<dependency>
<groupId>org.apache.arrow</groupId>
<artifactId>arrow-memory-netty</artifactId>
</dependency>
<dependency>
<groupId>org.apache.arrow</groupId>
<artifactId>arrow-c-data</artifactId>
</dependency>
<dependency>
<groupId>org.apache.arrow</groupId>
<artifactId>arrow-dataset</artifactId>
</dependency>
<dependency>
<groupId>org.json</groupId>
<artifactId>json</artifactId>
</dependency>
<dependency>
<groupId>org.questdb</groupId>
<artifactId>jar-jni</artifactId>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
<profiles>
<profile>
<id>build-jni</id>
<activation>
<activeByDefault>true</activeByDefault>
</activation>
<build>
<plugins>
<plugin>
<groupId>org.questdb</groupId>
<artifactId>rust-maven-plugin</artifactId>
<version>1.1.1</version>
<executions>
<execution>
<id>lancedb-jni</id>
<goals>
<goal>build</goal>
</goals>
<configuration>
<path>lancedb-jni</path>
<!--<release>true</release>-->
<!-- Copy native libraries to target/classes for runtime access -->
<copyTo>${project.build.directory}/classes/nativelib</copyTo>
<copyWithPlatformDir>true</copyWithPlatformDir>
</configuration>
</execution>
<execution>
<id>lancedb-jni-test</id>
<goals>
<goal>test</goal>
</goals>
<configuration>
<path>lancedb-jni</path>
<release>false</release>
<verbosity>-v</verbosity>
</configuration>
</execution>
</executions>
</plugin>
</plugins>
</build>
</profile>
</profiles>
</project>


@@ -0,0 +1,120 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.lancedb.lancedb;
import io.questdb.jar.jni.JarJniLoader;
import java.io.Closeable;
import java.util.List;
import java.util.Optional;
/**
 * Represents a LanceDB database.
*/
public class Connection implements Closeable {
static {
JarJniLoader.loadLib(Connection.class, "/nativelib", "lancedb_jni");
}
private long nativeConnectionHandle;
/**
* Connect to a LanceDB instance.
*/
public static native Connection connect(String uri);
/**
* Get the names of all tables in the database. The names are sorted in
* ascending order.
*
* @return the table names
*/
public List<String> tableNames() {
return tableNames(Optional.empty(), Optional.empty());
}
/**
* Get the names of filtered tables in the database. The names are sorted in
* ascending order.
*
 * @param limit The maximum number of results to return.
* @return the table names
*/
public List<String> tableNames(int limit) {
return tableNames(Optional.empty(), Optional.of(limit));
}
/**
* Get the names of filtered tables in the database. The names are sorted in
* ascending order.
*
* @param startAfter If present, only return names that come lexicographically after the supplied
* value. This can be combined with limit to implement pagination
* by setting this to the last table name from the previous page.
* @return the table names
*/
public List<String> tableNames(String startAfter) {
return tableNames(Optional.of(startAfter), Optional.empty());
}
/**
* Get the names of filtered tables in the database. The names are sorted in
* ascending order.
*
* @param startAfter If present, only return names that come lexicographically after the supplied
* value. This can be combined with limit to implement pagination
* by setting this to the last table name from the previous page.
 * @param limit The maximum number of results to return.
* @return the table names
*/
public List<String> tableNames(String startAfter, int limit) {
return tableNames(Optional.of(startAfter), Optional.of(limit));
}
/**
* Get the names of filtered tables in the database. The names are sorted in
* ascending order.
*
* @param startAfter If present, only return names that come lexicographically after the supplied
* value. This can be combined with limit to implement pagination
* by setting this to the last table name from the previous page.
 * @param limit The maximum number of results to return.
* @return the table names
*/
public native List<String> tableNames(
Optional<String> startAfter, Optional<Integer> limit);
/**
 * Closes this connection and releases any system resources associated with it.
 * If the connection is already closed, then invoking this method has no effect.
*/
@Override
public void close() {
if (nativeConnectionHandle != 0) {
releaseNativeConnection(nativeConnectionHandle);
nativeConnectionHandle = 0;
}
}
/**
* Native method to release the Lance connection resources associated with the
* given handle.
*
* @param handle The native handle to the connection resource.
*/
private native void releaseNativeConnection(long handle);
private Connection() {}
}
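
The `startAfter`/`limit` pagination contract documented above also exists in the async Python SDK; here is a hedged sketch of paging through every table name, assuming `table_names(start_after=..., limit=...)` as exposed by the Python API of this release:

```python
import asyncio

import lancedb

async def list_all_tables(uri: str, page_size: int = 2) -> list[str]:
    db = await lancedb.connect_async(uri)
    names: list[str] = []
    start_after = None
    while True:
        page = await db.table_names(start_after=start_after, limit=page_size)
        if not page:
            return names
        names.extend(page)
        # Names are sorted ascending, so the last name of this page is the
        # cursor for the next page.
        start_after = page[-1]

print(asyncio.run(list_all_tables("./.lancedb")))
```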


@@ -0,0 +1,135 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.lancedb.lancedb;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import java.nio.file.Path;
import java.util.List;
import java.net.URL;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
public class ConnectionTest {
private static final String[] TABLE_NAMES = {
"dataset_version",
"new_empty_dataset",
"test",
"write_stream"
};
@TempDir
static Path tempDir; // Temporary directory for the tests
private static URL lanceDbURL;
@BeforeAll
static void setUp() {
ClassLoader classLoader = ConnectionTest.class.getClassLoader();
lanceDbURL = classLoader.getResource("example_db");
}
@Test
void emptyDB() {
String databaseUri = tempDir.resolve("emptyDB").toString();
try (Connection conn = Connection.connect(databaseUri)) {
List<String> tableNames = conn.tableNames();
assertTrue(tableNames.isEmpty());
}
}
@Test
void tableNames() {
try (Connection conn = Connection.connect(lanceDbURL.toString())) {
List<String> tableNames = conn.tableNames();
assertEquals(4, tableNames.size());
for (int i = 0; i < TABLE_NAMES.length; i++) {
assertEquals(TABLE_NAMES[i], tableNames.get(i));
}
}
}
@Test
void tableNamesStartAfter() {
try (Connection conn = Connection.connect(lanceDbURL.toString())) {
assertTableNamesStartAfter(conn, TABLE_NAMES[0], 3, TABLE_NAMES[1], TABLE_NAMES[2], TABLE_NAMES[3]);
assertTableNamesStartAfter(conn, TABLE_NAMES[1], 2, TABLE_NAMES[2], TABLE_NAMES[3]);
assertTableNamesStartAfter(conn, TABLE_NAMES[2], 1, TABLE_NAMES[3]);
assertTableNamesStartAfter(conn, TABLE_NAMES[3], 0);
assertTableNamesStartAfter(conn, "a_dataset", 4, TABLE_NAMES[0], TABLE_NAMES[1], TABLE_NAMES[2], TABLE_NAMES[3]);
assertTableNamesStartAfter(conn, "o_dataset", 2, TABLE_NAMES[2], TABLE_NAMES[3]);
assertTableNamesStartAfter(conn, "v_dataset", 1, TABLE_NAMES[3]);
assertTableNamesStartAfter(conn, "z_dataset", 0);
}
}
private void assertTableNamesStartAfter(Connection conn, String startAfter, int expectedSize, String... expectedNames) {
List<String> tableNames = conn.tableNames(startAfter);
assertEquals(expectedSize, tableNames.size());
for (int i = 0; i < expectedNames.length; i++) {
assertEquals(expectedNames[i], tableNames.get(i));
}
}
@Test
void tableNamesLimit() {
try (Connection conn = Connection.connect(lanceDbURL.toString())) {
for (int i = 0; i <= TABLE_NAMES.length; i++) {
List<String> tableNames = conn.tableNames(i);
assertEquals(i, tableNames.size());
for (int j = 0; j < i; j++) {
assertEquals(TABLE_NAMES[j], tableNames.get(j));
}
}
}
}
@Test
void tableNamesStartAfterLimit() {
try (Connection conn = Connection.connect(lanceDbURL.toString())) {
List<String> tableNames = conn.tableNames(TABLE_NAMES[0], 2);
assertEquals(2, tableNames.size());
assertEquals(TABLE_NAMES[1], tableNames.get(0));
assertEquals(TABLE_NAMES[2], tableNames.get(1));
tableNames = conn.tableNames(TABLE_NAMES[1], 1);
assertEquals(1, tableNames.size());
assertEquals(TABLE_NAMES[2], tableNames.get(0));
tableNames = conn.tableNames(TABLE_NAMES[2], 2);
assertEquals(1, tableNames.size());
assertEquals(TABLE_NAMES[3], tableNames.get(0));
tableNames = conn.tableNames(TABLE_NAMES[3], 2);
assertEquals(0, tableNames.size());
tableNames = conn.tableNames(TABLE_NAMES[0], 0);
assertEquals(0, tableNames.size());
// Limit larger than the number of remaining tables
tableNames = conn.tableNames(TABLE_NAMES[0], 10);
assertEquals(3, tableNames.size());
assertEquals(TABLE_NAMES[1], tableNames.get(0));
assertEquals(TABLE_NAMES[2], tableNames.get(1));
assertEquals(TABLE_NAMES[3], tableNames.get(2));
// Start after a value not in the list
tableNames = conn.tableNames("non_existent_table", 2);
assertEquals(2, tableNames.size());
assertEquals(TABLE_NAMES[2], tableNames.get(0));
assertEquals(TABLE_NAMES[3], tableNames.get(1));
// Start after the last table with a limit
tableNames = conn.tableNames(TABLE_NAMES[3], 1);
assertEquals(0, tableNames.size());
}
}
}


@@ -0,0 +1 @@
(binary Lance metadata blob; recoverable fields: id: int32, name: string)


@@ -0,0 +1 @@
(binary Lance metadata blob; recoverable fields: id: int32, name: string)


@@ -0,0 +1 @@
(binary Lance metadata blob; recoverable fields: item: string, price: double, vector: fixed_size_list&lt;float&gt;[2])

java/pom.xml Normal file

@@ -0,0 +1,129 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId>
<version>0.1-SNAPSHOT</version>
<packaging>pom</packaging>
<name>Lance Parent</name>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<maven.compiler.source>11</maven.compiler.source>
<maven.compiler.target>11</maven.compiler.target>
<arrow.version>15.0.0</arrow.version>
</properties>
<modules>
<module>core</module>
</modules>
<dependencyManagement>
<dependencies>
<dependency>
<groupId>org.apache.arrow</groupId>
<artifactId>arrow-vector</artifactId>
<version>${arrow.version}</version>
</dependency>
<dependency>
<groupId>org.apache.arrow</groupId>
<artifactId>arrow-memory-netty</artifactId>
<version>${arrow.version}</version>
</dependency>
<dependency>
<groupId>org.apache.arrow</groupId>
<artifactId>arrow-c-data</artifactId>
<version>${arrow.version}</version>
</dependency>
<dependency>
<groupId>org.apache.arrow</groupId>
<artifactId>arrow-dataset</artifactId>
<version>${arrow.version}</version>
</dependency>
<dependency>
<groupId>org.questdb</groupId>
<artifactId>jar-jni</artifactId>
<version>1.1.1</version>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter</artifactId>
<version>5.10.1</version>
</dependency>
<dependency>
<groupId>org.json</groupId>
<artifactId>json</artifactId>
<version>20210307</version>
</dependency>
</dependencies>
</dependencyManagement>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-checkstyle-plugin</artifactId>
<version>3.3.1</version>
<configuration>
<configLocation>google_checks.xml</configLocation>
<consoleOutput>true</consoleOutput>
<failsOnError>true</failsOnError>
<violationSeverity>warning</violationSeverity>
<linkXRef>false</linkXRef>
</configuration>
<executions>
<execution>
<id>validate</id>
<phase>validate</phase>
<goals>
<goal>check</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
<pluginManagement>
<plugins>
<plugin>
<artifactId>maven-clean-plugin</artifactId>
<version>3.1.0</version>
</plugin>
<plugin>
<artifactId>maven-resources-plugin</artifactId>
<version>3.0.2</version>
</plugin>
<plugin>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.8.1</version>
<configuration>
<compilerArgs>
<arg>-h</arg>
<arg>target/headers</arg>
</compilerArgs>
</configuration>
</plugin>
<plugin>
<artifactId>maven-surefire-plugin</artifactId>
<version>3.2.5</version>
<configuration>
<argLine>--add-opens=java.base/java.nio=ALL-UNNAMED</argLine>
<forkNode implementation="org.apache.maven.plugin.surefire.extensions.SurefireForkNodeFactory"/>
<useSystemClassLoader>false</useSystemClassLoader>
</configuration>
</plugin>
<plugin>
<artifactId>maven-jar-plugin</artifactId>
<version>3.0.2</version>
</plugin>
<plugin>
<artifactId>maven-install-plugin</artifactId>
<version>2.5.2</version>
</plugin>
</plugins>
</pluginManagement>
</build>
</project>


@@ -1,12 +1,12 @@
 {
   "name": "vectordb",
-  "version": "0.4.20",
+  "version": "0.5.0",
   "lockfileVersion": 3,
   "requires": true,
   "packages": {
     "": {
       "name": "vectordb",
-      "version": "0.4.20",
+      "version": "0.5.0",
       "cpu": [
         "x64",
         "arm64"


@@ -1,6 +1,6 @@
 {
   "name": "vectordb",
-  "version": "0.4.20",
+  "version": "0.5.0",
   "description": " Serverless, low-latency vector database for AI applications",
   "main": "dist/index.js",
   "types": "dist/index.d.ts",


@@ -31,6 +31,7 @@ import {
   Schema,
   Struct,
   type Table,
+  Type,
   Utf8,
   tableFromIPC,
 } from "apache-arrow";
@@ -51,7 +52,12 @@ import {
   makeArrowTable,
   makeEmptyTable,
 } from "../lancedb/arrow";
-import { type EmbeddingFunction } from "../lancedb/embedding/embedding_function";
+import {
+  EmbeddingFunction,
+  FieldOptions,
+  FunctionOptions,
+} from "../lancedb/embedding/embedding_function";
+import { EmbeddingFunctionConfig } from "../lancedb/embedding/registry";
 
 // biome-ignore lint/suspicious/noExplicitAny: skip
 function sampleRecords(): Array<Record<string, any>> {
@@ -280,23 +286,46 @@ describe("The function makeArrowTable", function () {
   });
 });
 
-class DummyEmbedding implements EmbeddingFunction<string> {
-  public readonly sourceColumn = "string";
-  public readonly embeddingDimension = 2;
-  public readonly embeddingDataType = new Float16();
-  async embed(data: string[]): Promise<number[][]> {
+class DummyEmbedding extends EmbeddingFunction<string> {
+  toJSON(): Partial<FunctionOptions> {
+    return {};
+  }
+  async computeSourceEmbeddings(data: string[]): Promise<number[][]> {
     return data.map(() => [0.0, 0.0]);
   }
+  ndims(): number {
+    return 2;
+  }
+  embeddingDataType() {
+    return new Float16();
+  }
 }
 
-class DummyEmbeddingWithNoDimension implements EmbeddingFunction<string> {
-  public readonly sourceColumn = "string";
-  async embed(data: string[]): Promise<number[][]> {
+class DummyEmbeddingWithNoDimension extends EmbeddingFunction<string> {
+  toJSON(): Partial<FunctionOptions> {
+    return {};
+  }
+  embeddingDataType(): Float {
+    return new Float16();
+  }
+  async computeSourceEmbeddings(data: string[]): Promise<number[][]> {
     return data.map(() => [0.0, 0.0]);
   }
 }
+
+const dummyEmbeddingConfig: EmbeddingFunctionConfig = {
+  sourceColumn: "string",
+  function: new DummyEmbedding(),
+};
+const dummyEmbeddingConfigWithNoDimension: EmbeddingFunctionConfig = {
+  sourceColumn: "string",
+  function: new DummyEmbeddingWithNoDimension(),
+};
 
 describe("convertToTable", function () {
   it("will infer data types correctly", async function () {
@@ -331,7 +360,7 @@ describe("convertToTable", function () {
   it("will apply embeddings", async function () {
     const records = sampleRecords();
-    const table = await convertToTable(records, new DummyEmbedding());
+    const table = await convertToTable(records, dummyEmbeddingConfig);
     expect(DataType.isFixedSizeList(table.getChild("vector")?.type)).toBe(true);
     expect(table.getChild("vector")?.type.children[0].type.toString()).toEqual(
       new Float16().toString(),
@@ -340,7 +369,7 @@ describe("convertToTable", function () {
   it("will fail if missing the embedding source column", async function () {
     await expect(
-      convertToTable([{ id: 1 }], new DummyEmbedding()),
+      convertToTable([{ id: 1 }], dummyEmbeddingConfig),
     ).rejects.toThrow("'string' was not present");
   });
@@ -351,7 +380,7 @@ describe("convertToTable", function () {
     const table = makeEmptyTable(schema);
     // If the embedding specifies the dimension we are fine
-    await fromTableToBuffer(table, new DummyEmbedding());
+    await fromTableToBuffer(table, dummyEmbeddingConfig);
     // We can also supply a schema and should be ok
     const schemaWithEmbedding = new Schema([
@@ -364,13 +393,13 @@ describe("convertToTable", function () {
     ]);
     await fromTableToBuffer(
       table,
-      new DummyEmbeddingWithNoDimension(),
+      dummyEmbeddingConfigWithNoDimension,
       schemaWithEmbedding,
     );
     // Otherwise we will get an error
     await expect(
-      fromTableToBuffer(table, new DummyEmbeddingWithNoDimension()),
+      fromTableToBuffer(table, dummyEmbeddingConfigWithNoDimension),
     ).rejects.toThrow("does not specify `embeddingDimension`");
   });
@@ -383,7 +412,7 @@ describe("convertToTable", function () {
       false,
     ),
   ]);
-  const table = await convertToTable([], new DummyEmbedding(), { schema });
+  const table = await convertToTable([], dummyEmbeddingConfig, { schema });
   expect(DataType.isFixedSizeList(table.getChild("vector")?.type)).toBe(true);
   expect(table.getChild("vector")?.type.children[0].type.toString()).toEqual(
     new Float16().toString(),
@@ -393,16 +422,17 @@ describe("convertToTable", function () {
   it("will complain if embeddings present but schema missing embedding column", async function () {
     const schema = new Schema([new Field("string", new Utf8(), false)]);
     await expect(
-      convertToTable([], new DummyEmbedding(), { schema }),
+      convertToTable([], dummyEmbeddingConfig, { schema }),
     ).rejects.toThrow("column vector was missing");
   });
 
   it("will provide a nice error if run twice", async function () {
     const records = sampleRecords();
-    const table = await convertToTable(records, new DummyEmbedding());
+    const table = await convertToTable(records, dummyEmbeddingConfig);
     // fromTableToBuffer will try and apply the embeddings again
     await expect(
-      fromTableToBuffer(table, new DummyEmbedding()),
+      fromTableToBuffer(table, dummyEmbeddingConfig),
     ).rejects.toThrow("already existed");
   });
 });


@@ -13,7 +13,6 @@
 // limitations under the License.
 import * as tmp from "tmp";
 import { Connection, connect } from "../lancedb";
 describe("when connecting", () => {


@@ -0,0 +1,169 @@
// Copyright 2024 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import * as arrow from "apache-arrow";
import * as arrowOld from "apache-arrow-old";
import * as tmp from "tmp";
import { connect } from "../lancedb";
import { EmbeddingFunction, LanceSchema } from "../lancedb/embedding";
import { getRegistry, register } from "../lancedb/embedding/registry";
describe.each([arrow, arrowOld])("LanceSchema", (arrow) => {
test("should preserve input order", async () => {
const schema = LanceSchema({
id: new arrow.Int32(),
text: new arrow.Utf8(),
vector: new arrow.Float32(),
});
expect(schema.fields.map((x) => x.name)).toEqual(["id", "text", "vector"]);
});
});
describe("Registry", () => {
let tmpDir: tmp.DirResult;
beforeEach(() => {
tmpDir = tmp.dirSync({ unsafeCleanup: true });
});
afterEach(() => {
tmpDir.removeCallback();
getRegistry().reset();
});
it("should register a new item to the registry", async () => {
@register("mock-embedding")
class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {
someText: "hello",
};
}
constructor() {
super();
}
ndims() {
return 3;
}
embeddingDataType(): arrow.Float {
return new arrow.Float32();
}
async computeSourceEmbeddings(data: string[]) {
return data.map(() => [1, 2, 3]);
}
}
const func = getRegistry()
.get<MockEmbeddingFunction>("mock-embedding")!
.create();
const schema = LanceSchema({
id: new arrow.Int32(),
text: func.sourceField(new arrow.Utf8()),
vector: func.vectorField(),
});
const db = await connect(tmpDir.name);
const table = await db.createTable(
"test",
[
{ id: 1, text: "hello" },
{ id: 2, text: "world" },
],
{ schema },
);
const expected = [
[1, 2, 3],
[1, 2, 3],
];
const actual = await table.query().toArrow();
const vectors = actual
.getChild("vector")
?.toArray()
.map((x: unknown) => {
if (x instanceof arrow.Vector) {
return [...x];
} else {
return x;
}
});
expect(vectors).toEqual(expected);
});
test("should error if registering with the same name", async () => {
class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {
someText: "hello",
};
}
constructor() {
super();
}
ndims() {
return 3;
}
embeddingDataType(): arrow.Float {
return new arrow.Float32();
}
async computeSourceEmbeddings(data: string[]) {
return data.map(() => [1, 2, 3]);
}
}
register("mock-embedding")(MockEmbeddingFunction);
expect(() => register("mock-embedding")(MockEmbeddingFunction)).toThrow(
'Embedding function with alias "mock-embedding" already exists',
);
});
test("schema should contain correct metadata", async () => {
class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {
someText: "hello",
};
}
constructor() {
super();
}
ndims() {
return 3;
}
embeddingDataType(): arrow.Float {
return new arrow.Float32();
}
async computeSourceEmbeddings(data: string[]) {
return data.map(() => [1, 2, 3]);
}
}
const func = new MockEmbeddingFunction();
const schema = LanceSchema({
id: new arrow.Int32(),
text: func.sourceField(new arrow.Utf8()),
vector: func.vectorField(),
});
const expectedMetadata = new Map<string, string>([
[
"embedding_functions",
JSON.stringify([
{
sourceColumn: "text",
vectorColumn: "vector",
name: "MockEmbeddingFunction",
model: { someText: "hello" },
},
]),
],
]);
expect(schema.metadata).toEqual(expectedMetadata);
});
});

View File

@@ -16,23 +16,34 @@ import * as fs from "fs";
 import * as path from "path";
 import * as tmp from "tmp";
+import * as arrow from "apache-arrow";
+import * as arrowOld from "apache-arrow-old";
+import { Table, connect } from "../lancedb";
 import {
   Field,
   FixedSizeList,
+  Float,
   Float32,
   Float64,
   Int32,
   Int64,
   Schema,
-} from "apache-arrow";
-import { Table, connect } from "../lancedb";
-import { makeArrowTable } from "../lancedb/arrow";
+  Utf8,
+  makeArrowTable,
+} from "../lancedb/arrow";
+import { EmbeddingFunction, LanceSchema } from "../lancedb/embedding";
+import { getRegistry, register } from "../lancedb/embedding/registry";
 import { Index } from "../lancedb/indices";
 
-describe("Given a table", () => {
+// biome-ignore lint/suspicious/noExplicitAny: <explanation>
+describe.each([arrow, arrowOld])("Given a table", (arrow: any) => {
   let tmpDir: tmp.DirResult;
   let table: Table;
-  const schema = new Schema([new Field("id", new Float64(), true)]);
+  const schema = new arrow.Schema([
+    new arrow.Field("id", new arrow.Float64(), true),
+  ]);
 
   beforeEach(async () => {
     tmpDir = tmp.dirSync({ unsafeCleanup: true });
     const conn = await connect(tmpDir.name);
@@ -420,6 +431,161 @@ describe("when dealing with versioning", () => {
  });
});
describe("embedding functions", () => {
let tmpDir: tmp.DirResult;
beforeEach(() => {
tmpDir = tmp.dirSync({ unsafeCleanup: true });
});
afterEach(() => tmpDir.removeCallback());
it("should be able to create a table with an embedding function", async () => {
class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {};
}
ndims() {
return 3;
}
embeddingDataType(): Float {
return new Float32();
}
async computeQueryEmbeddings(_data: string) {
return [1, 2, 3];
}
async computeSourceEmbeddings(data: string[]) {
return Array.from({ length: data.length }).fill([
1, 2, 3,
]) as number[][];
}
}
const func = new MockEmbeddingFunction();
const db = await connect(tmpDir.name);
const table = await db.createTable(
"test",
[
{ id: 1, text: "hello" },
{ id: 2, text: "world" },
],
{
embeddingFunction: {
function: func,
sourceColumn: "text",
},
},
);
// biome-ignore lint/suspicious/noExplicitAny: test
const arr = (await table.query().toArray()) as any;
expect(arr[0].vector).toBeDefined();
// we round trip through JSON to make sure the vector properly gets converted to an array
// otherwise it'll be a TypedArray or Vector
const vector0 = JSON.parse(JSON.stringify(arr[0].vector));
expect(vector0).toEqual([1, 2, 3]);
});
it("should be able to create an empty table with an embedding function", async () => {
@register()
class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {};
}
ndims() {
return 3;
}
embeddingDataType(): Float {
return new Float32();
}
async computeQueryEmbeddings(_data: string) {
return [1, 2, 3];
}
async computeSourceEmbeddings(data: string[]) {
return Array.from({ length: data.length }).fill([
1, 2, 3,
]) as number[][];
}
}
const schema = new Schema([
new Field("text", new Utf8(), true),
new Field(
"vector",
new FixedSizeList(3, new Field("item", new Float32(), true)),
true,
),
]);
const func = new MockEmbeddingFunction();
const db = await connect(tmpDir.name);
const table = await db.createEmptyTable("test", schema, {
embeddingFunction: {
function: func,
sourceColumn: "text",
},
});
const outSchema = await table.schema();
expect(outSchema.metadata.get("embedding_functions")).toBeDefined();
await table.add([{ text: "hello world" }]);
// biome-ignore lint/suspicious/noExplicitAny: test
const arr = (await table.query().toArray()) as any;
expect(arr[0].vector).toBeDefined();
// we round trip through JSON to make sure the vector properly gets converted to an array
// otherwise it'll be a TypedArray or Vector
const vector0 = JSON.parse(JSON.stringify(arr[0].vector));
expect(vector0).toEqual([1, 2, 3]);
});
it("should error when appending to a table with an unregistered embedding function", async () => {
@register("mock")
class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {};
}
ndims() {
return 3;
}
embeddingDataType(): Float {
return new Float32();
}
async computeQueryEmbeddings(_data: string) {
return [1, 2, 3];
}
async computeSourceEmbeddings(data: string[]) {
return Array.from({ length: data.length }).fill([
1, 2, 3,
]) as number[][];
}
}
const func = getRegistry().get<MockEmbeddingFunction>("mock")!.create();
const schema = LanceSchema({
id: new arrow.Float64(),
text: func.sourceField(new Utf8()),
vector: func.vectorField(),
});
const db = await connect(tmpDir.name);
await db.createTable(
"test",
[
{ id: 1, text: "hello" },
{ id: 2, text: "world" },
],
{
schema,
},
);
getRegistry().reset();
const db2 = await connect(tmpDir.name);
const tbl = await db2.openTable("test");
expect(tbl.add([{ id: 3, text: "hello" }])).rejects.toThrow(
`Function "mock" not found in registry`,
);
});
});
describe("when optimizing a dataset", () => { describe("when optimizing a dataset", () => {
let tmpDir: tmp.DirResult; let tmpDir: tmp.DirResult;
let table: Table; let table: Table;

View File

@@ -48,7 +48,7 @@
        "noUnsafeFinally": "error",
        "noUnsafeOptionalChaining": "error",
        "noUnusedLabels": "error",
        "noUnusedVariables": "warn",
        "useIsNan": "error",
        "useValidForDirection": "error",
        "useYield": "error"
@@ -101,7 +101,13 @@
  },
  "overrides": [
    {
      "include": [
        "**/*.ts",
        "**/*.tsx",
        "**/*.mts",
        "**/*.cts",
        "__test__/*.test.ts"
      ],
      "linter": {
        "rules": {
          "correctness": {

View File

@@ -17,10 +17,14 @@ import {
  Binary,
  DataType,
  Field,
  FixedSizeBinary,
  FixedSizeList,
  Float,
  Float32,
  Int,
  LargeBinary,
  List,
  Null,
  RecordBatch,
  RecordBatchFileWriter,
  RecordBatchStreamWriter,
@@ -34,7 +38,99 @@ import {
  vectorFromArray,
} from "apache-arrow";
import { type EmbeddingFunction } from "./embedding/embedding_function";
import { EmbeddingFunctionConfig, getRegistry } from "./embedding/registry";
import { sanitizeField, sanitizeSchema, sanitizeType } from "./sanitize";
export * from "apache-arrow";
export function isArrowTable(value: object): value is ArrowTable {
if (value instanceof ArrowTable) return true;
return "schema" in value && "batches" in value;
}
export function isDataType(value: unknown): value is DataType {
return (
value instanceof DataType ||
DataType.isNull(value) ||
DataType.isInt(value) ||
DataType.isFloat(value) ||
DataType.isBinary(value) ||
DataType.isLargeBinary(value) ||
DataType.isUtf8(value) ||
DataType.isLargeUtf8(value) ||
DataType.isBool(value) ||
DataType.isDecimal(value) ||
DataType.isDate(value) ||
DataType.isTime(value) ||
DataType.isTimestamp(value) ||
DataType.isInterval(value) ||
DataType.isDuration(value) ||
DataType.isList(value) ||
DataType.isStruct(value) ||
DataType.isUnion(value) ||
DataType.isFixedSizeBinary(value) ||
DataType.isFixedSizeList(value) ||
DataType.isMap(value) ||
DataType.isDictionary(value)
);
}
export function isNull(value: unknown): value is Null {
return value instanceof Null || DataType.isNull(value);
}
export function isInt(value: unknown): value is Int {
return value instanceof Int || DataType.isInt(value);
}
export function isFloat(value: unknown): value is Float {
return value instanceof Float || DataType.isFloat(value);
}
export function isBinary(value: unknown): value is Binary {
return value instanceof Binary || DataType.isBinary(value);
}
export function isLargeBinary(value: unknown): value is LargeBinary {
return value instanceof LargeBinary || DataType.isLargeBinary(value);
}
export function isUtf8(value: unknown): value is Utf8 {
return value instanceof Utf8 || DataType.isUtf8(value);
}
export function isLargeUtf8(value: unknown): value is Utf8 {
return value instanceof Utf8 || DataType.isLargeUtf8(value);
}
export function isBool(value: unknown): value is Utf8 {
return value instanceof Utf8 || DataType.isBool(value);
}
export function isDecimal(value: unknown): value is Utf8 {
return value instanceof Utf8 || DataType.isDecimal(value);
}
export function isDate(value: unknown): value is Utf8 {
return value instanceof Utf8 || DataType.isDate(value);
}
export function isTime(value: unknown): value is Utf8 {
return value instanceof Utf8 || DataType.isTime(value);
}
export function isTimestamp(value: unknown): value is Utf8 {
return value instanceof Utf8 || DataType.isTimestamp(value);
}
export function isInterval(value: unknown): value is Utf8 {
return value instanceof Utf8 || DataType.isInterval(value);
}
export function isDuration(value: unknown): value is Utf8 {
return value instanceof Utf8 || DataType.isDuration(value);
}
export function isList(value: unknown): value is List {
return value instanceof List || DataType.isList(value);
}
export function isStruct(value: unknown): value is Struct {
return value instanceof Struct || DataType.isStruct(value);
}
export function isUnion(value: unknown): value is Struct {
return value instanceof Struct || DataType.isUnion(value);
}
export function isFixedSizeBinary(value: unknown): value is FixedSizeBinary {
return value instanceof FixedSizeBinary || DataType.isFixedSizeBinary(value);
}
export function isFixedSizeList(value: unknown): value is FixedSizeList {
return value instanceof FixedSizeList || DataType.isFixedSizeList(value);
}
/** Data type accepted by NodeJS SDK */
export type Data = Record<string, unknown>[] | ArrowTable;
@@ -198,6 +294,7 @@ export class MakeArrowTableOptions {
export function makeArrowTable(
  data: Array<Record<string, unknown>>,
  options?: Partial<MakeArrowTableOptions>,
  metadata?: Map<string, string>,
): ArrowTable {
  if (
    data.length === 0 &&
@@ -290,20 +387,41 @@ export function makeArrowTable(
    // `new ArrowTable(schema, batches)` which does not do any schema inference
    const firstTable = new ArrowTable(columns);
    const batchesFixed = firstTable.batches.map(
      (batch) => new RecordBatch(opt.schema!, batch.data),
    );
    let schema: Schema;
    if (metadata !== undefined) {
      let schemaMetadata = opt.schema.metadata;
      if (schemaMetadata.size === 0) {
        schemaMetadata = metadata;
      } else {
        for (const [key, entry] of schemaMetadata.entries()) {
          schemaMetadata.set(key, entry);
        }
      }
      schema = new Schema(opt.schema.fields, schemaMetadata);
    } else {
      schema = opt.schema;
    }
    return new ArrowTable(schema, batchesFixed);
  }
  const tbl = new ArrowTable(columns);
  if (metadata !== undefined) {
    // biome-ignore lint/suspicious/noExplicitAny: <explanation>
    (<any>tbl.schema).metadata = metadata;
  }
  return tbl;
}

/**
 * Create an empty Arrow table with the provided schema
 */
export function makeEmptyTable(
  schema: Schema,
  metadata?: Map<string, string>,
): ArrowTable {
  return makeArrowTable([], { schema }, metadata);
}

/**
@@ -375,13 +493,75 @@
  }
}
/** Helper function to apply embeddings from metadata to an input table */
async function applyEmbeddingsFromMetadata(
table: ArrowTable,
schema: Schema,
): Promise<ArrowTable> {
const registry = getRegistry();
const functions = registry.parseFunctions(schema.metadata);
const columns = Object.fromEntries(
table.schema.fields.map((field) => [
field.name,
table.getChild(field.name)!,
]),
);
for (const functionEntry of functions.values()) {
const sourceColumn = columns[functionEntry.sourceColumn];
const destColumn = functionEntry.vectorColumn ?? "vector";
if (sourceColumn === undefined) {
throw new Error(
`Cannot apply embedding function because the source column '${functionEntry.sourceColumn}' was not present in the data`,
);
}
if (columns[destColumn] !== undefined) {
throw new Error(
`Attempt to apply embeddings to table failed because column ${destColumn} already existed`,
);
}
if (table.batches.length > 1) {
throw new Error(
"Internal error: `makeArrowTable` unexpectedly created a table with more than one batch",
);
}
const values = sourceColumn.toArray();
const vectors =
await functionEntry.function.computeSourceEmbeddings(values);
if (vectors.length !== values.length) {
throw new Error(
"Embedding function did not return an embedding for each input element",
);
}
let destType: DataType;
const dtype = schema.fields.find((f) => f.name === destColumn)!.type;
if (isFixedSizeList(dtype)) {
destType = sanitizeType(dtype);
} else {
throw new Error(
"Expected FixedSizeList as datatype for vector field, instead got: " +
dtype,
);
}
const vector = makeVector(vectors, destType);
columns[destColumn] = vector;
}
const newTable = new ArrowTable(columns);
return alignTable(newTable, schema);
}
/** Helper function to apply embeddings to an input table */
async function applyEmbeddings<T>(
  table: ArrowTable,
  embeddings?: EmbeddingFunctionConfig,
  schema?: Schema,
): Promise<ArrowTable> {
  if (schema?.metadata.has("embedding_functions")) {
    return applyEmbeddingsFromMetadata(table, schema!);
  } else if (embeddings == null || embeddings === undefined) {
    return table;
  }
@@ -399,8 +579,9 @@ async function applyEmbeddings<T>(
  const newColumns = Object.fromEntries(colEntries);

  const sourceColumn = newColumns[embeddings.sourceColumn];
  const destColumn = embeddings.vectorColumn ?? "vector";
  const innerDestType =
    embeddings.function.embeddingDataType() ?? new Float32();
  if (sourceColumn === undefined) {
    throw new Error(
      `Cannot apply embedding function because the source column '${embeddings.sourceColumn}' was not present in the data`,
@@ -414,11 +595,9 @@ async function applyEmbeddings<T>(
    // if we call convertToTable with 0 records and a schema that includes the embedding
    return table;
  }
  const dimensions = embeddings.function.ndims();
  if (dimensions !== undefined) {
    const destType = newVectorType(dimensions, innerDestType);
    newColumns[destColumn] = makeVector([], destType);
  } else if (schema != null) {
    const destField = schema.fields.find((f) => f.name === destColumn);
@@ -446,7 +625,9 @@ async function applyEmbeddings<T>(
    );
  }
  const values = sourceColumn.toArray();
  const vectors = await embeddings.function.computeSourceEmbeddings(
    values as T[],
  );
  if (vectors.length !== values.length) {
    throw new Error(
      "Embedding function did not return an embedding for each input element",
@@ -486,9 +667,9 @@ async function applyEmbeddings<T>(
 * embedding columns. If no schema is provded then embedding columns will
 * be placed at the end of the table, after all of the input columns.
 */
export async function convertToTable(
  data: Array<Record<string, unknown>>,
  embeddings?: EmbeddingFunctionConfig,
  makeTableOptions?: Partial<MakeArrowTableOptions>,
): Promise<ArrowTable> {
  const table = makeArrowTable(data, makeTableOptions);
@@ -496,13 +677,13 @@ export async function convertToTable<T>(
}

/** Creates the Arrow Type for a Vector column with dimension `dim` */
export function newVectorType<T extends Float>(
  dim: number,
  innerType: T,
): FixedSizeList<T> {
  // in Lance we always default to have the elements nullable, so we need to set it to true
  // otherwise we often get schema mismatches because the stored data always has schema with nullable elements
  const children = new Field("item", <T>sanitizeType(innerType), true);
  return new FixedSizeList(dim, children);
}
@@ -513,9 +694,9 @@ function newVectorType<T extends Float>(
 *
 * `schema` is required if data is empty
 */
export async function fromRecordsToBuffer(
  data: Array<Record<string, unknown>>,
  embeddings?: EmbeddingFunctionConfig,
  schema?: Schema,
): Promise<Buffer> {
  if (schema !== undefined && schema !== null) {
@@ -533,9 +714,9 @@ export async function fromRecordsToBuffer(
 *
 * `schema` is required if data is empty
 */
export async function fromRecordsToStreamBuffer(
  data: Array<Record<string, unknown>>,
  embeddings?: EmbeddingFunctionConfig,
  schema?: Schema,
): Promise<Buffer> {
  if (schema !== undefined && schema !== null) {
@@ -554,9 +735,9 @@ export async function fromRecordsToStreamBuffer(
 *
 * `schema` is required if the table is empty
 */
export async function fromTableToBuffer(
  table: ArrowTable,
  embeddings?: EmbeddingFunctionConfig,
  schema?: Schema,
): Promise<Buffer> {
  if (schema !== undefined && schema !== null) {
@@ -575,19 +756,19 @@ export async function fromTableToBuffer(
 *
 * `schema` is required if the table is empty
 */
export async function fromDataToBuffer(
  data: Data,
  embeddings?: EmbeddingFunctionConfig,
  schema?: Schema,
): Promise<Buffer> {
  if (schema !== undefined && schema !== null) {
    schema = sanitizeSchema(schema);
  }
  if (isArrowTable(data)) {
    return fromTableToBuffer(data, embeddings, schema);
  } else {
    const table = await convertToTable(data, embeddings, { schema });
    return fromTableToBuffer(table);
  }
}
@@ -599,9 +780,9 @@ export async function fromDataToBuffer(
 *
 * `schema` is required if the table is empty
 */
export async function fromTableToStreamBuffer(
  table: ArrowTable,
  embeddings?: EmbeddingFunctionConfig,
  schema?: Schema,
): Promise<Buffer> {
  const tableWithEmbeddings = await applyEmbeddings(table, embeddings, schema);
@@ -664,10 +845,25 @@ function validateSchemaEmbeddings(
  // if it does not, we add it to the list of missing embedding fields
  // Finally, we check if those missing embedding fields are `this._embeddings`
  // if they are not, we throw an error
  for (let field of schema.fields) {
    if (isFixedSizeList(field.type)) {
      field = sanitizeField(field);
      if (data.length !== 0 && data?.[0]?.[field.name] === undefined) {
        if (schema.metadata.has("embedding_functions")) {
          const embeddings = JSON.parse(
            schema.metadata.get("embedding_functions")!,
          );
          if (
            // biome-ignore lint/suspicious/noExplicitAny: we don't know the type of `f`
            embeddings.find((f: any) => f["vectorColumn"] === field.name) ===
              undefined
          ) {
            missingEmbeddingFields.push(field);
          }
        } else {
          missingEmbeddingFields.push(field);
        }
      } else {
        fields.push(field);
      }
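The new duck-typed `is*` guards are the backbone of the cross-version compatibility work above: `instanceof` fails whenever a `DataType` was constructed by a different copy of `apache-arrow` than the one doing the check, which is exactly the situation the `describe.each([arrow, arrowOld])` tests exercise. A minimal sketch of the failure mode, assuming `apache-arrow-old` is an npm alias for an older apache-arrow release as in the test setup:

```ts
import * as arrow from "apache-arrow";
import * as arrowOld from "apache-arrow-old"; // npm alias for an older release, as in the tests

const type = new arrowOld.Float32();

// Fails across library copies: each package has its own class identity.
console.log(type instanceof arrow.Float); // false

// Works across copies: the static guard inspects the typeId tag instead.
console.log(arrow.DataType.isFloat(type)); // true
```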

View File

@@ -12,8 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.

import { Table as ArrowTable, Schema } from "./arrow";
import {
  fromTableToBuffer,
  isArrowTable,
  makeArrowTable,
  makeEmptyTable,
} from "./arrow";
import { EmbeddingFunctionConfig, getRegistry } from "./embedding/registry";
import { ConnectionOptions, Connection as LanceDbConnection } from "./native";
import { Table } from "./table";
@@ -65,6 +71,8 @@ export interface CreateTableOptions {
   * The available options are described at https://lancedb.github.io/lancedb/guides/storage/
   */
  storageOptions?: Record<string, string>;
  schema?: Schema;
  embeddingFunction?: EmbeddingFunctionConfig;
}

export interface OpenTableOptions {
@@ -174,6 +182,7 @@ export class Connection {
      cleanseStorageOptions(options?.storageOptions),
      options?.indexCacheSize,
    );
    return new Table(innerTable);
  }
@@ -196,18 +205,24 @@ export class Connection {
    }

    let table: ArrowTable;
    if (isArrowTable(data)) {
      table = data;
    } else {
      table = makeArrowTable(data, options);
    }
    const buf = await fromTableToBuffer(
      table,
      options?.embeddingFunction,
      options?.schema,
    );
    const innerTable = await this.inner.createTable(
      name,
      buf,
      mode,
      cleanseStorageOptions(options?.storageOptions),
    );
    return new Table(innerTable);
  }
@@ -227,8 +242,14 @@ export class Connection {
    if (mode === "create" && existOk) {
      mode = "exist_ok";
    }
    let metadata: Map<string, string> | undefined = undefined;
    if (options?.embeddingFunction !== undefined) {
      const embeddingFunction = options.embeddingFunction;
      const registry = getRegistry();
      metadata = registry.getTableMetadata([embeddingFunction]);
    }
    const table = makeEmptyTable(schema, metadata);
    const buf = await fromTableToBuffer(table);
    const innerTable = await this.inner.createEmptyTable(
      name,
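Taken together, the new `schema` and `embeddingFunction` entries in `CreateTableOptions` let `createTable` embed at write time. A minimal sketch mirroring the tests above; the database path and table name are placeholders, and the `"openai"` alias (registered by the decorator on `OpenAIEmbeddingFunction`) needs `OPENAI_API_KEY` set, or an explicit `apiKey` option:

```ts
import { connect } from "@lancedb/lancedb";
import { getRegistry } from "@lancedb/lancedb/embedding";

async function main() {
  const db = await connect("data/example-db"); // hypothetical path
  // "openai" is registered by @register("openai"); create() throws
  // unless OPENAI_API_KEY (or an apiKey option) is available.
  const func = getRegistry().get("openai")!.create();
  const table = await db.createTable(
    "docs",
    [{ text: "hello" }, { text: "world" }],
    { embeddingFunction: { function: func, sourceColumn: "text" } },
  );
  const rows = await table.query().toArray();
  console.log(rows[0]); // includes a "vector" column filled at write time
}
main();
```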

View File

@@ -1,4 +1,4 @@
// Copyright 2024 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,67 +12,151 @@
// See the License for the specific language governing permissions and
// limitations under the License.

import "reflect-metadata";
import {
  DataType,
  Field,
  FixedSizeList,
  Float,
  Float32,
  isDataType,
  isFixedSizeList,
  isFloat,
  newVectorType,
} from "../arrow";
import { sanitizeType } from "../sanitize";

/**
 * Options for a given embedding function
 */
export interface FunctionOptions {
  // biome-ignore lint/suspicious/noExplicitAny: options can be anything
  [key: string]: any;
}

/**
 * An embedding function that automatically creates vector representation for a given column.
 */
export abstract class EmbeddingFunction<
  // biome-ignore lint/suspicious/noExplicitAny: we don't know what the implementor will do
  T = any,
  M extends FunctionOptions = FunctionOptions,
> {
  /**
   * Convert the embedding function to a JSON object
   * It is used to serialize the embedding function to the schema
   * It's important that any object returned by this method contains all the necessary
   * information to recreate the embedding function
   *
   * It should return the same object that was passed to the constructor
   * If it does not, the embedding function will not be able to be recreated, or could be recreated incorrectly
   *
   * @example
   * ```ts
   * class MyEmbeddingFunction extends EmbeddingFunction {
   *   constructor(options: {model: string, timeout: number}) {
   *     super();
   *     this.model = options.model;
   *     this.timeout = options.timeout;
   *   }
   *   toJSON() {
   *     return {
   *       model: this.model,
   *       timeout: this.timeout,
   *     };
   *   }
   * ```
   */
  abstract toJSON(): Partial<M>;

  /**
   * sourceField is used in combination with `LanceSchema` to provide a declarative data model
   *
   * @param optionsOrDatatype - The options for the field or the datatype
   *
   * @see {@link lancedb.LanceSchema}
   */
  sourceField(
    optionsOrDatatype: Partial<FieldOptions> | DataType,
  ): [DataType, Map<string, EmbeddingFunction>] {
    let datatype = isDataType(optionsOrDatatype)
      ? optionsOrDatatype
      : optionsOrDatatype?.datatype;
    if (!datatype) {
      throw new Error("Datatype is required");
    }
    datatype = sanitizeType(datatype);
    const metadata = new Map<string, EmbeddingFunction>();
    metadata.set("source_column_for", this);
    return [datatype, metadata];
  }

  /**
   * vectorField is used in combination with `LanceSchema` to provide a declarative data model
   *
   * @param options - The options for the field
   *
   * @see {@link lancedb.LanceSchema}
   */
  vectorField(
    options?: Partial<FieldOptions>,
  ): [DataType, Map<string, EmbeddingFunction>] {
    let dtype: DataType;
    const dims = this.ndims() ?? options?.dims;
    if (!options?.datatype) {
      if (dims === undefined) {
        throw new Error("ndims is required for vector field");
      }
      dtype = new FixedSizeList(dims, new Field("item", new Float32(), true));
    } else {
      if (isFixedSizeList(options.datatype)) {
        dtype = options.datatype;
      } else if (isFloat(options.datatype)) {
        if (dims === undefined) {
          throw new Error("ndims is required for vector field");
        }
        dtype = newVectorType(dims, options.datatype);
      } else {
        throw new Error(
          "Expected FixedSizeList or Float as datatype for vector field",
        );
      }
    }
    const metadata = new Map<string, EmbeddingFunction>();
    metadata.set("vector_column_for", this);

    return [dtype, metadata];
  }

  /** The number of dimensions of the embeddings */
  ndims(): number | undefined {
    return undefined;
  }

  /** The datatype of the embeddings */
  abstract embeddingDataType(): Float;

  /**
   * Creates a vector representation for the given values.
   */
  abstract computeSourceEmbeddings(
    data: T[],
  ): Promise<number[][] | Float32Array[] | Float64Array[]>;

  /**
   * Compute the embeddings for a single query
   */
  async computeQueryEmbeddings(
    data: T,
  ): Promise<number[] | Float32Array | Float64Array> {
    return this.computeSourceEmbeddings([data]).then(
      (embeddings) => embeddings[0],
    );
  }
}

export interface FieldOptions<T extends DataType = DataType> {
  datatype: T;
  dims?: number;
}
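For reference, the smallest concrete subclass of the new abstract class only has to supply `toJSON`, `embeddingDataType`, and `computeSourceEmbeddings`; overriding `ndims` is optional but lets `vectorField()` derive the `FixedSizeList` width. A sketch with a toy deterministic embedding (`LengthEmbedding` is a made-up name, not part of the library):

```ts
import { EmbeddingFunction, LanceSchema } from "@lancedb/lancedb/embedding";
import { Float, Float32, Int32, Utf8 } from "apache-arrow";

// Toy function: embeds a string as [length, 1] so the sketch needs no network.
class LengthEmbedding extends EmbeddingFunction<string> {
  toJSON(): object {
    return {}; // no constructor options to round-trip
  }
  ndims(): number {
    return 2;
  }
  embeddingDataType(): Float {
    return new Float32();
  }
  async computeSourceEmbeddings(data: string[]): Promise<number[][]> {
    return data.map((s) => [s.length, 1]);
  }
}

const func = new LengthEmbedding();
const schema = LanceSchema({
  id: new Int32(),
  text: func.sourceField(new Utf8()),
  vector: func.vectorField(), // FixedSizeList(2, Float32), width from ndims()
});
```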

View File

@@ -1,2 +1,113 @@
// Copyright 2023 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import { DataType, Field, Schema } from "../arrow";
import { isDataType } from "../arrow";
import { sanitizeType } from "../sanitize";
import { EmbeddingFunction } from "./embedding_function";
import { EmbeddingFunctionConfig, getRegistry } from "./registry";
export { EmbeddingFunction } from "./embedding_function";
// We need to explicitly export '*' so that the `register` decorator actually registers the class.
export * from "./openai";
export * from "./registry";
/**
* Create a schema with embedding functions.
*
* @param fields
* @returns Schema
* @example
* ```ts
* class MyEmbeddingFunction extends EmbeddingFunction {
* // ...
* }
* const func = new MyEmbeddingFunction();
* const schema = LanceSchema({
* id: new Int32(),
* text: func.sourceField(new Utf8()),
* vector: func.vectorField(),
* // optional: specify the datatype and/or dimensions
* vector2: func.vectorField({ datatype: new Float32(), dims: 3}),
* });
*
* const table = await db.createTable("my_table", data, { schema });
* ```
*/
export function LanceSchema(
fields: Record<string, [object, Map<string, EmbeddingFunction>] | object>,
): Schema {
const arrowFields: Field[] = [];
const embeddingFunctions = new Map<
EmbeddingFunction,
Partial<EmbeddingFunctionConfig>
>();
Object.entries(fields).forEach(([key, value]) => {
if (isDataType(value)) {
arrowFields.push(new Field(key, sanitizeType(value), true));
} else {
const [dtype, metadata] = value as [
object,
Map<string, EmbeddingFunction>,
];
arrowFields.push(new Field(key, sanitizeType(dtype), true));
parseEmbeddingFunctions(embeddingFunctions, key, metadata);
}
});
const registry = getRegistry();
const metadata = registry.getTableMetadata(
Array.from(embeddingFunctions.values()) as EmbeddingFunctionConfig[],
);
const schema = new Schema(arrowFields, metadata);
return schema;
}
function parseEmbeddingFunctions(
embeddingFunctions: Map<EmbeddingFunction, Partial<EmbeddingFunctionConfig>>,
key: string,
metadata: Map<string, EmbeddingFunction>,
): void {
if (metadata.has("source_column_for")) {
const embedFunction = metadata.get("source_column_for")!;
const current = embeddingFunctions.get(embedFunction);
if (current !== undefined) {
embeddingFunctions.set(embedFunction, {
...current,
sourceColumn: key,
});
} else {
embeddingFunctions.set(embedFunction, {
sourceColumn: key,
function: embedFunction,
});
}
} else if (metadata.has("vector_column_for")) {
const embedFunction = metadata.get("vector_column_for")!;
const current = embeddingFunctions.get(embedFunction);
if (current !== undefined) {
embeddingFunctions.set(embedFunction, {
...current,
vectorColumn: key,
});
} else {
embeddingFunctions.set(embedFunction, {
vectorColumn: key,
function: embedFunction,
});
}
}
}
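Continuing the `LengthEmbedding` sketch from the previous section: the schema built by `LanceSchema` carries the serialized function config in its `embedding_functions` metadata entry, which is what `getTableMetadata` writes and `parseFunctions` reads back later. Since `LengthEmbedding` was never passed to `register`, `functionToMetadata` falls back to the constructor name:

```ts
// schema comes from the LengthEmbedding sketch above
const json = schema.metadata.get("embedding_functions")!;
console.log(JSON.parse(json));
// => [ { sourceColumn: "text", vectorColumn: "vector",
//        name: "LengthEmbedding", model: {} } ]
```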

View File

@@ -13,17 +13,31 @@
// limitations under the License.

import type OpenAI from "openai";
import { Float, Float32 } from "../arrow";
import { EmbeddingFunction } from "./embedding_function";
import { register } from "./registry";

export type OpenAIOptions = {
  apiKey?: string;
  model?: string;
};

@register("openai")
export class OpenAIEmbeddingFunction extends EmbeddingFunction<
  string,
  OpenAIOptions
> {
  #openai: OpenAI;
  #modelName: string;

  constructor(options: OpenAIOptions = { model: "text-embedding-ada-002" }) {
    super();
    const openAIKey = options?.apiKey ?? process.env.OPENAI_API_KEY;
    if (!openAIKey) {
      throw new Error("OpenAI API key is required");
    }
    const modelName = options?.model ?? "text-embedding-ada-002";

    /**
     * @type {import("openai").default}
     */
@@ -36,18 +50,40 @@ export class OpenAIEmbeddingFunction implements EmbeddingFunction<string> {
      throw new Error("please install openai@^4.24.1 using npm install openai");
    }

    const configuration = {
      apiKey: openAIKey,
    };

    this.#openai = new Openai(configuration);
    this.#modelName = modelName;
  }

  toJSON() {
    return {
      model: this.#modelName,
    };
  }

  ndims(): number {
    switch (this.#modelName) {
      case "text-embedding-ada-002":
        return 1536;
      case "text-embedding-3-large":
        return 3072;
      case "text-embedding-3-small":
        return 1536;
      default:
        return null as never;
    }
  }

  embeddingDataType(): Float {
    return new Float32();
  }

  async computeSourceEmbeddings(data: string[]): Promise<number[][]> {
    const response = await this.#openai.embeddings.create({
      model: this.#modelName,
      input: data,
    });
@@ -58,5 +94,15 @@ export class OpenAIEmbeddingFunction implements EmbeddingFunction<string> {
    return embeddings;
  }

  async computeQueryEmbeddings(data: string): Promise<number[]> {
    if (typeof data !== "string") {
      throw new Error("Data must be a string");
    }
    const response = await this.#openai.embeddings.create({
      model: this.#modelName,
      input: data,
    });
    return response.data[0].embedding;
  }
}
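Usage of the reworked constructor, as a sketch: an options object replaces the old positional `(sourceColumn, openAIKey, modelName)` signature, and the key falls back to the environment, so running this assumes `OPENAI_API_KEY` is set and one of the model names from the `ndims()` table above:

```ts
import { OpenAIEmbeddingFunction } from "@lancedb/lancedb/embedding";

async function main() {
  // Falls back to process.env.OPENAI_API_KEY when apiKey is omitted.
  const openai = new OpenAIEmbeddingFunction({
    model: "text-embedding-3-small",
  });
  console.log(openai.ndims()); // 1536 for text-embedding-3-small
  const [vector] = await openai.computeSourceEmbeddings(["hello world"]);
  console.log(vector.length); // 1536
}
main();
```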

View File

@@ -0,0 +1,172 @@
// Copyright 2024 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import type { EmbeddingFunction } from "./embedding_function";
import "reflect-metadata";
export interface EmbeddingFunctionOptions {
[key: string]: unknown;
}
export interface EmbeddingFunctionFactory<
T extends EmbeddingFunction = EmbeddingFunction,
> {
new (modelOptions?: EmbeddingFunctionOptions): T;
}
interface EmbeddingFunctionCreate<T extends EmbeddingFunction> {
create(options?: EmbeddingFunctionOptions): T;
}
/**
* This is a singleton class used to register embedding functions
* and fetch them by name. It also handles serializing and deserializing.
* You can implement your own embedding function by subclassing EmbeddingFunction
* or TextEmbeddingFunction and registering it with the registry
*/
export class EmbeddingFunctionRegistry {
#functions: Map<string, EmbeddingFunctionFactory> = new Map();
/**
* Register an embedding function
* @param name The name of the function
* @param func The function to register
*/
register<T extends EmbeddingFunctionFactory = EmbeddingFunctionFactory>(
this: EmbeddingFunctionRegistry,
alias?: string,
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
): (ctor: T) => any {
const self = this;
return function (ctor: T) {
if (!alias) {
alias = ctor.name;
}
if (self.#functions.has(alias)) {
throw new Error(
`Embedding function with alias "${alias}" already exists`,
);
}
self.#functions.set(alias, ctor);
Reflect.defineMetadata("lancedb::embedding::name", alias, ctor);
return ctor;
};
}
/**
* Fetch an embedding function by name
* @param name The name of the function
*/
get<T extends EmbeddingFunction<unknown> = EmbeddingFunction>(
name: string,
): EmbeddingFunctionCreate<T> | undefined {
const factory = this.#functions.get(name);
if (!factory) {
return undefined;
}
return {
create: function (options: EmbeddingFunctionOptions) {
return new factory(options) as unknown as T;
},
};
}
/**
* reset the registry to the initial state
*/
reset(this: EmbeddingFunctionRegistry) {
this.#functions.clear();
}
parseFunctions(
this: EmbeddingFunctionRegistry,
metadata: Map<string, string>,
): Map<string, EmbeddingFunctionConfig> {
if (!metadata.has("embedding_functions")) {
return new Map();
} else {
type FunctionConfig = {
name: string;
sourceColumn: string;
vectorColumn: string;
model: EmbeddingFunctionOptions;
};
const functions = <FunctionConfig[]>(
JSON.parse(metadata.get("embedding_functions")!)
);
return new Map(
functions.map((f) => {
const fn = this.get(f.name);
if (!fn) {
throw new Error(`Function "${f.name}" not found in registry`);
}
return [
f.name,
{
sourceColumn: f.sourceColumn,
vectorColumn: f.vectorColumn,
function: this.get(f.name)!.create(f.model),
},
];
}),
);
}
}
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
functionToMetadata(conf: EmbeddingFunctionConfig): Record<string, any> {
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
const metadata: Record<string, any> = {};
const name = Reflect.getMetadata(
"lancedb::embedding::name",
conf.function.constructor,
);
metadata["sourceColumn"] = conf.sourceColumn;
metadata["vectorColumn"] = conf.vectorColumn ?? "vector";
metadata["name"] = name ?? conf.function.constructor.name;
metadata["model"] = conf.function.toJSON();
return metadata;
}
getTableMetadata(functions: EmbeddingFunctionConfig[]): Map<string, string> {
const metadata = new Map<string, string>();
const jsonData = functions.map((conf) => this.functionToMetadata(conf));
metadata.set("embedding_functions", JSON.stringify(jsonData));
return metadata;
}
}
const _REGISTRY = new EmbeddingFunctionRegistry();
export function register(name?: string) {
return _REGISTRY.register(name);
}
/**
* Utility function to get the global instance of the registry
* @returns `EmbeddingFunctionRegistry` The global instance of the registry
* @example
* ```ts
* const registry = getRegistry();
* const openai = registry.get("openai").create();
 * ```
 */
export function getRegistry(): EmbeddingFunctionRegistry {
return _REGISTRY;
}
export interface EmbeddingFunctionConfig {
sourceColumn: string;
vectorColumn?: string;
function: EmbeddingFunction;
}
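A sketch of the round trip the registry enables, under a hypothetical `"zeros"` alias; registering twice under one alias throws, and the `@register` decorator needs `experimentalDecorators`, which the tsconfig.json change below turns on:

```ts
import {
  EmbeddingFunction,
  getRegistry,
  register,
} from "@lancedb/lancedb/embedding";
import { Float, Float32 } from "apache-arrow";

// "zeros" is a made-up alias for this sketch.
@register("zeros")
class ZeroEmbedding extends EmbeddingFunction<string> {
  toJSON(): object {
    return {};
  }
  ndims(): number {
    return 3;
  }
  embeddingDataType(): Float {
    return new Float32();
  }
  async computeSourceEmbeddings(data: string[]): Promise<number[][]> {
    return data.map(() => [0, 0, 0]);
  }
}

// Round trip: look the factory up by alias and instantiate it.
const zeros = getRegistry().get<ZeroEmbedding>("zeros")!.create();
console.log(zeros.ndims()); // 3
```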

View File

@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

import { Table as ArrowTable, RecordBatch, tableFromIPC } from "./arrow";
import { type IvfPqOptions } from "./indices";
import {
  RecordBatchIterator as NativeBatchIterator,
@@ -170,6 +170,7 @@ export class QueryBase<
  /** Collect the results as an array of objects. */
  async toArray(): Promise<unknown[]> {
    const tbl = await this.toArrow();
    // eslint-disable-next-line @typescript-eslint/no-unsafe-return
    return tbl.toArray();
  }

View File

@@ -20,6 +20,7 @@
// comes from the exact same library instance. This is not always the case
// and so we must sanitize the input to ensure that it is compatible.

import type { IntBitWidth, TKeys, TimeBitWidth } from "apache-arrow/type";
import {
  Binary,
  Bool,
@@ -75,10 +76,9 @@
  Uint64,
  Union,
  Utf8,
} from "./arrow";

export function sanitizeMetadata(
  metadataLike?: unknown,
): Map<string, string> | undefined {
  if (metadataLike === undefined || metadataLike === null) {
@@ -97,7 +97,7 @@ function sanitizeMetadata(
  return metadataLike as Map<string, string>;
}

export function sanitizeInt(typeLike: object) {
  if (
    !("bitWidth" in typeLike) ||
    typeof typeLike.bitWidth !== "number" ||
@@ -111,14 +111,14 @@ function sanitizeInt(typeLike: object) {
  return new Int(typeLike.isSigned, typeLike.bitWidth as IntBitWidth);
}

export function sanitizeFloat(typeLike: object) {
  if (!("precision" in typeLike) || typeof typeLike.precision !== "number") {
    throw Error("Expected a Float Type to have a `precision` property");
  }
  return new Float(typeLike.precision as Precision);
}

export function sanitizeDecimal(typeLike: object) {
  if (
    !("scale" in typeLike) ||
    typeof typeLike.scale !== "number" ||
@@ -134,14 +134,14 @@ function sanitizeDecimal(typeLike: object) {
  return new Decimal(typeLike.scale, typeLike.precision, typeLike.bitWidth);
}

export function sanitizeDate(typeLike: object) {
  if (!("unit" in typeLike) || typeof typeLike.unit !== "number") {
    throw Error("Expected a Date type to have a `unit` property");
  }
  return new Date_(typeLike.unit as DateUnit);
}

export function sanitizeTime(typeLike: object) {
  if (
    !("unit" in typeLike) ||
    typeof typeLike.unit !== "number" ||
@@ -155,7 +155,7 @@ function sanitizeTime(typeLike: object) {
  return new Time(typeLike.unit, typeLike.bitWidth as TimeBitWidth);
}

export function sanitizeTimestamp(typeLike: object) {
  if (!("unit" in typeLike) || typeof typeLike.unit !== "number") {
    throw Error("Expected a Timestamp type to have a `unit` property");
  }
@@ -166,7 +166,7 @@ function sanitizeTimestamp(typeLike: object) {
  return new Timestamp(typeLike.unit, timezone);
}

export function sanitizeTypedTimestamp(
  typeLike: object,
  // eslint-disable-next-line @typescript-eslint/naming-convention
  Datatype:
@@ -182,14 +182,14 @@ function sanitizeTypedTimestamp(
  return new Datatype(timezone);
}

export function sanitizeInterval(typeLike: object) {
  if (!("unit" in typeLike) || typeof typeLike.unit !== "number") {
    throw Error("Expected an Interval type to have a `unit` property");
  }
  return new Interval(typeLike.unit);
}

export function sanitizeList(typeLike: object) {
  if (!("children" in typeLike) || !Array.isArray(typeLike.children)) {
    throw Error(
      "Expected a List type to have an array-like `children` property",
@@ -201,7 +201,7 @@ function sanitizeList(typeLike: object) {
  return new List(sanitizeField(typeLike.children[0]));
}

export function sanitizeStruct(typeLike: object) {
  if (!("children" in typeLike) || !Array.isArray(typeLike.children)) {
    throw Error(
      "Expected a Struct type to have an array-like `children` property",
@@ -210,7 +210,7 @@ function sanitizeStruct(typeLike: object) {
  return new Struct(typeLike.children.map((child) => sanitizeField(child)));
}

export function sanitizeUnion(typeLike: object) {
  if (
    !("typeIds" in typeLike) ||
    !("mode" in typeLike) ||
@@ -234,7 +234,7 @@ function sanitizeUnion(typeLike: object) {
  );
}

export function sanitizeTypedUnion(
  typeLike: object,
  // eslint-disable-next-line @typescript-eslint/naming-convention
  UnionType: typeof DenseUnion | typeof SparseUnion,
@@ -256,7 +256,7 @@ function sanitizeTypedUnion(
  );
}

export function sanitizeFixedSizeBinary(typeLike: object) {
  if (!("byteWidth" in typeLike) || typeof typeLike.byteWidth !== "number") {
    throw Error(
      "Expected a FixedSizeBinary type to have a `byteWidth` property",
@@ -265,7 +265,7 @@ function sanitizeFixedSizeBinary(typeLike: object) {
  return new FixedSizeBinary(typeLike.byteWidth);
}

export function sanitizeFixedSizeList(typeLike: object) {
  if (!("listSize" in typeLike) || typeof typeLike.listSize !== "number") {
    throw Error("Expected a FixedSizeList type to have a `listSize` property");
  }
@@ -283,7 +283,7 @@ function sanitizeFixedSizeList(typeLike: object) {
  );
}

export function sanitizeMap(typeLike: object) {
  if (!("children" in typeLike) || !Array.isArray(typeLike.children)) {
    throw Error(
      "Expected a Map type to have an array-like `children` property",
@@ -300,14 +300,14 @@ function sanitizeMap(typeLike: object) {
  );
}

export function sanitizeDuration(typeLike: object) {
  if (!("unit" in typeLike) || typeof typeLike.unit !== "number") {
    throw Error("Expected a Duration type to have a `unit` property");
  }
  return new Duration(typeLike.unit);
}

export function sanitizeDictionary(typeLike: object) {
  if (!("id" in typeLike) || typeof typeLike.id !== "number") {
    throw Error("Expected a Dictionary type to have an `id` property");
  }
@@ -329,7 +329,7 @@ function sanitizeDictionary(typeLike: object) {
}

// biome-ignore lint/suspicious/noExplicitAny: skip
export function sanitizeType(typeLike: unknown): DataType<any> {
  if (typeof typeLike !== "object" || typeLike === null) {
    throw Error("Expected a Type but object was null/undefined");
  }
@@ -449,7 +449,7 @@ function sanitizeType(typeLike: unknown): DataType<any> {
  }
}

export function sanitizeField(fieldLike: unknown): Field {
  if (fieldLike instanceof Field) {
    return fieldLike;
  }

View File

@@ -12,8 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.

import { Data, Schema, fromDataToBuffer, tableFromIPC } from "./arrow";
import { getRegistry } from "./embedding/registry";
import { IndexOptions } from "./indices";
import {
  AddColumnsSql,
@@ -122,8 +123,14 @@ export class Table {
   */
  async add(data: Data, options?: Partial<AddDataOptions>): Promise<void> {
    const mode = options?.mode ?? "append";
    const schema = await this.schema();
    const registry = getRegistry();
    const functions = registry.parseFunctions(schema.metadata);

    const buffer = await fromDataToBuffer(
      data,
      functions.values().next().value,
    );
    await this.inner.add(buffer, mode);
  }
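The practical effect of this change, sketched below: because `add()` re-creates embedding functions from the stored schema metadata, data appended in a later session is embedded automatically, provided the function's alias is still registered in that process (the table.test.ts case above shows the `Function "mock" not found in registry` error when it is not). Assumes the imports and the `docs` table from the earlier `createTable` sketch, inside an async context:

```ts
const db2 = await connect("data/example-db"); // hypothetical path from before
const tbl = await db2.openTable("docs");
await tbl.add([{ text: "appended later" }]); // "vector" computed via the registry
```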

View File

@@ -1,6 +1,6 @@
{
  "name": "@lancedb/lancedb-darwin-arm64",
  "version": "0.5.0",
  "os": ["darwin"],
  "cpu": ["arm64"],
  "main": "lancedb.darwin-arm64.node",

View File

@@ -1,6 +1,6 @@
{
  "name": "@lancedb/lancedb-darwin-x64",
  "version": "0.5.0",
  "os": ["darwin"],
  "cpu": ["x64"],
  "main": "lancedb.darwin-x64.node",

View File

@@ -1,6 +1,6 @@
{
  "name": "@lancedb/lancedb-linux-arm64-gnu",
  "version": "0.5.0",
  "os": ["linux"],
  "cpu": ["arm64"],
  "main": "lancedb.linux-arm64-gnu.node",

View File

@@ -1,6 +1,6 @@
{
  "name": "@lancedb/lancedb-linux-x64-gnu",
  "version": "0.5.0",
  "os": ["linux"],
  "cpu": ["x64"],
  "main": "lancedb.linux-x64-gnu.node",

View File

@@ -1,6 +1,6 @@
{
  "name": "@lancedb/lancedb-win32-x64-msvc",
  "version": "0.5.0",
  "os": ["win32"],
  "cpu": ["x64"],
  "main": "lancedb.win32-x64-msvc.node",

nodejs/package-lock.json (generated; 15383 changed lines)

File diff suppressed because it is too large

View File

@@ -1,8 +1,12 @@
{
  "name": "@lancedb/lancedb",
  "version": "0.5.0",
  "main": "dist/index.js",
  "exports": {
    ".": "./dist/index.js",
    "./embedding": "./dist/embedding/index.js"
  },
  "types": "dist/index.d.ts",
  "napi": {
    "name": "lancedb",
    "triples": {
@@ -62,6 +66,7 @@
  },
  "dependencies": {
    "apache-arrow": "^15.0.0",
    "openai": "^4.29.2",
    "reflect-metadata": "^0.2.2"
  }
}


@@ -7,7 +7,9 @@
     "outDir": "./dist",
     "strict": true,
     "allowJs": true,
-    "resolveJsonModule": true
+    "resolveJsonModule": true,
+    "emitDecoratorMetadata": true,
+    "experimentalDecorators": true
   },
   "exclude": ["./dist/*"],
   "typedocOptions": {


@@ -1,5 +1,5 @@
 [tool.bumpversion]
-current_version = "0.7.0"
+current_version = "0.8.1"
 parse = """(?x)
     (?P<major>0|[1-9]\\d*)\\.
     (?P<minor>0|[1-9]\\d*)\\.


@@ -1,6 +1,6 @@
 [package]
 name = "lancedb-python"
-version = "0.7.0"
+version = "0.8.1"
 edition.workspace = true
 description = "Python bindings for LanceDB"
 license.workspace = true


@@ -3,14 +3,14 @@ name = "lancedb"
 # version in Cargo.toml
 dependencies = [
     "deprecation",
-    "pylance==0.11.0",
+    "pylance==0.11.1",
     "ratelimiter~=1.0",
     "requests>=2.31.0",
     "retry>=0.9.2",
     "tqdm>=4.27.0",
     "pydantic>=1.10",
     "attrs>=21.3.0",
-    "semver",
+    "packaging",
     "cachetools",
     "overrides>=0.7",
 ]


@@ -509,7 +509,7 @@ class AsyncConnection(object):
         return self._inner.__repr__()

     def __enter__(self):
-        self
+        return self

     def __exit__(self, *_):
         self.close()
@@ -779,7 +779,7 @@ class AsyncConnection(object):
         name: str,
         storage_options: Optional[Dict[str, str]] = None,
         index_cache_size: Optional[int] = None,
-    ) -> Table:
+    ) -> AsyncTable:
         """Open a Lance Table in the database.

         Parameters


@@ -74,7 +74,7 @@ class BedRockText(TextEmbeddingFunction):
     profile_name: Union[str, None] = None
     role_session_name: str = "lancedb-embeddings"

-    if PYDANTIC_VERSION < (2, 0):  # Pydantic 1.x compat
+    if PYDANTIC_VERSION.major < 2:  # Pydantic 1.x compat

         class Config:
             keep_untouched = (cached_property,)


@@ -90,7 +90,7 @@ class GeminiText(TextEmbeddingFunction):
     query_task_type: str = "retrieval_query"
     source_task_type: str = "retrieval_document"

-    if PYDANTIC_VERSION < (2, 0):  # Pydantic 1.x compat
+    if PYDANTIC_VERSION.major < 2:  # Pydantic 1.x compat

         class Config:
             keep_untouched = (cached_property,)


@@ -40,7 +40,7 @@ class ImageBindEmbeddings(EmbeddingFunction):
     device: str = "cpu"
     normalize: bool = False

-    if PYDANTIC_VERSION < (2, 0):  # Pydantic 1.x compat
+    if PYDANTIC_VERSION.major < 2:  # Pydantic 1.x compat

         class Config:
             keep_untouched = (cached_property,)


@@ -54,7 +54,7 @@ class TransformersEmbeddingFunction(EmbeddingFunction):
         self._tokenizer = transformers.AutoTokenizer.from_pretrained(self.name)
         self._model = transformers.AutoModel.from_pretrained(self.name)

-    if PYDANTIC_VERSION < (2, 0):  # Pydantic 1.x compat
+    if PYDANTIC_VERSION.major < 2:  # Pydantic 1.x compat

         class Config:
             keep_untouched = (cached_property,)


@@ -35,13 +35,13 @@ from typing import (
 import numpy as np
 import pyarrow as pa
 import pydantic
-import semver
+from packaging.version import Version

-PYDANTIC_VERSION = semver.parse_version_info(pydantic.__version__)
+PYDANTIC_VERSION = Version(pydantic.__version__)

 try:
     from pydantic_core import CoreSchema, core_schema
 except ImportError:
-    if PYDANTIC_VERSION >= (2,):
+    if PYDANTIC_VERSION.major >= 2:
         raise

 if TYPE_CHECKING:
@@ -144,7 +144,7 @@ def Vector(
             raise TypeError("A list of numbers or numpy.ndarray is needed")
         return cls(v)

-    if PYDANTIC_VERSION < (2, 0):
+    if PYDANTIC_VERSION.major < 2:

         @classmethod
         def __modify_schema__(cls, field_schema: Dict[str, Any]):


@@ -1,5 +1,5 @@
 import os
-import semver
+from packaging.version import Version

 from functools import cached_property
 from typing import Union
@@ -44,9 +44,8 @@ class CohereReranker(Reranker):
     def _client(self):
         cohere = attempt_import_or_raise("cohere")
         # ensure version is at least 0.5.0
-        if (
-            hasattr(cohere, "__version__")
-            and semver.compare(cohere.__version__, "5.0.0") < 0
+        if hasattr(cohere, "__version__") and Version(cohere.__version__) < Version(
+            "0.5.0"
         ):
             raise ValueError(
                 f"cohere version must be at least 0.5.0, found {cohere.__version__}"


@@ -296,6 +296,13 @@ async def test_close(tmp_path):
         await db.table_names()


+@pytest.mark.asyncio
+async def test_context_manager(tmp_path):
+    with await lancedb.connect_async(tmp_path) as db:
+        assert db.is_open()
+    assert not db.is_open()
+
+
 @pytest.mark.asyncio
 async def test_create_mode_async(tmp_path):
     db = await lancedb.connect_async(tmp_path)


@@ -178,7 +178,7 @@ def test_fixed_size_list_field():
         li: List[int]

     data = TestModel(vec=list(range(16)), li=[1, 2, 3])
-    if PYDANTIC_VERSION >= (2,):
+    if PYDANTIC_VERSION.major >= 2:
         assert json.loads(data.model_dump_json()) == {
             "vec": list(range(16)),
             "li": [1, 2, 3],
@@ -197,7 +197,7 @@ def test_fixed_size_list_field():
         ]
     )
-    if PYDANTIC_VERSION >= (2,):
+    if PYDANTIC_VERSION.major >= 2:
         json_schema = TestModel.model_json_schema()
     else:
         json_schema = TestModel.schema()


@@ -1,6 +1,6 @@
 [package]
 name = "lancedb-node"
-version = "0.4.20"
+version = "0.5.0"
 description = "Serverless, low-latency vector database for AI applications"
 license.workspace = true
 edition.workspace = true


@@ -1,6 +1,6 @@
 [package]
 name = "lancedb"
-version = "0.4.20"
+version = "0.5.0"
 edition.workspace = true
 description = "LanceDB: A serverless, low-latency vector database for AI applications"
 license.workspace = true
@@ -38,6 +38,7 @@ url.workspace = true
 regex.workspace = true
 serde = { version = "^1" }
 serde_json = { version = "1" }
+serde_with = { version = "3.8.1" }

 # For remote feature
 reqwest = { version = "0.11.24", features = ["gzip", "json"], optional = true }
 polars-arrow = { version = ">=0.37,<0.40.0", optional = true }


@@ -14,11 +14,14 @@
 use std::sync::Arc;

+use serde::Deserialize;
+use serde_with::skip_serializing_none;
+
 use crate::{table::TableInternal, Result};

 use self::{
     scalar::BTreeIndexBuilder,
-    vector::{IvfHnswSqIndexBuilder, IvfPqIndexBuilder},
+    vector::{IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder, IvfPqIndexBuilder},
 };

 pub mod scalar;
@@ -28,6 +31,7 @@ pub enum Index {
     Auto,
     BTree(BTreeIndexBuilder),
     IvfPq(IvfPqIndexBuilder),
+    IvfHnswPq(IvfHnswPqIndexBuilder),
     IvfHnswSq(IvfHnswSqIndexBuilder),
 }
@@ -69,6 +73,7 @@ impl IndexBuilder {
 #[derive(Debug, Clone, PartialEq)]
 pub enum IndexType {
     IvfPq,
+    IvfHnswPq,
     IvfHnswSq,
     BTree,
 }
@@ -83,3 +88,19 @@ pub struct IndexConfig {
     /// be more columns to represent composite indices.
     pub columns: Vec<String>,
 }
+
+#[skip_serializing_none]
+#[derive(Debug, Deserialize)]
+pub struct IndexMetadata {
+    pub metric_type: Option<String>,
+    pub index_type: Option<String>,
+}
+
+#[skip_serializing_none]
+#[derive(Debug, Deserialize)]
+pub struct IndexStatistics {
+    pub num_indexed_rows: usize,
+    pub num_unindexed_rows: usize,
+    pub index_type: Option<String>,
+    pub indices: Vec<IndexMetadata>,
+}
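
The `Option` fields above are what let one `IndexStatistics` type cover both scalar and vector indexes, since scalar index stats carry no `metric_type`. A minimal sketch of deserializing the two payload shapes; the JSON values are illustrative, not captured Lance output:

```rust
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct IndexMetadata {
    metric_type: Option<String>,
    index_type: Option<String>,
}

#[derive(Debug, Deserialize)]
struct IndexStatistics {
    num_indexed_rows: usize,
    num_unindexed_rows: usize,
    index_type: Option<String>,
    indices: Vec<IndexMetadata>,
}

fn main() -> serde_json::Result<()> {
    // A scalar (BTree) index reports no metric_type; the Option absorbs that.
    let scalar = r#"{"num_indexed_rows": 1, "num_unindexed_rows": 0,
                     "index_type": "BTree", "indices": [{"index_type": "BTree"}]}"#;
    // A vector index reports both fields.
    let vector = r#"{"num_indexed_rows": 512, "num_unindexed_rows": 0,
                     "index_type": "IVF",
                     "indices": [{"metric_type": "l2", "index_type": "IVF"}]}"#;
    let s: IndexStatistics = serde_json::from_str(scalar)?;
    let v: IndexStatistics = serde_json::from_str(vector)?;
    println!("{s:?}\n{v:?}");
    Ok(())
}
```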


@@ -19,8 +19,6 @@
 //! values

 use std::cmp::max;

-use serde::Deserialize;
-
 use lance::table::format::{Index, Manifest};

 use crate::DistanceType;
@@ -46,18 +44,118 @@ impl VectorIndex {
     }
 }

-#[derive(Debug, Deserialize)]
-pub struct VectorIndexMetadata {
-    pub metric_type: String,
-    pub index_type: String,
-}
-
-#[derive(Debug, Deserialize)]
-pub struct VectorIndexStatistics {
-    pub num_indexed_rows: usize,
-    pub num_unindexed_rows: usize,
-    pub index_type: String,
-    pub indices: Vec<VectorIndexMetadata>,
-}
+macro_rules! impl_distance_type_setter {
+    () => {
+        /// [DistanceType] to use to build the index.
+        ///
+        /// Default value is [DistanceType::L2].
+        ///
+        /// This is used when training the index to calculate the IVF partitions (vectors are
+        /// grouped in partitions with similar vectors according to this distance type) and to
+        /// calculate a subvector's code during quantization.
+        ///
+        /// The metric type used to train an index MUST match the metric type used to search the
+        /// index. Failure to do so will yield inaccurate results.
+        pub fn distance_type(mut self, distance_type: DistanceType) -> Self {
+            self.distance_type = distance_type;
+            self
+        }
+    };
+}
+
+macro_rules! impl_ivf_params_setter {
+    () => {
+        /// The number of IVF partitions to create.
+        ///
+        /// This value should generally scale with the number of rows in the dataset. By default
+        /// the number of partitions is the square root of the number of rows.
+        ///
+        /// If this value is too large then the first part of the search (picking the right partition)
+        /// will be slow. If this value is too small then the second part of the search (searching
+        /// within a partition) will be slow.
+        pub fn num_partitions(mut self, num_partitions: u32) -> Self {
+            self.num_partitions = Some(num_partitions);
+            self
+        }
+
+        /// The rate used to calculate the number of training vectors for kmeans.
+        ///
+        /// When an IVF index is trained, we need to calculate partitions. These are groups
+        /// of vectors that are similar to each other. To do this we use an algorithm called kmeans.
+        ///
+        /// Running kmeans on a large dataset can be slow. To speed this up we run kmeans on a
+        /// random sample of the data. This parameter controls the size of the sample. The total
+        /// number of vectors used to train the index is `sample_rate * num_partitions`.
+        ///
+        /// Increasing this value might improve the quality of the index but in most cases the
+        /// default should be sufficient.
+        ///
+        /// The default value is 256.
+        pub fn sample_rate(mut self, sample_rate: u32) -> Self {
+            self.sample_rate = sample_rate;
+            self
+        }
+
+        /// Max iterations to train kmeans.
+        ///
+        /// When training an IVF index we use kmeans to calculate the partitions. This parameter
+        /// controls how many iterations of kmeans to run.
+        ///
+        /// Increasing this might improve the quality of the index but in most cases the parameter
+        /// is unused because kmeans will converge with fewer iterations. The parameter is only
+        /// used in cases where kmeans does not appear to converge. In those cases it is unlikely
+        /// that setting this larger will lead to the index converging anyway.
+        ///
+        /// The default value is 50.
+        pub fn max_iterations(mut self, max_iterations: u32) -> Self {
+            self.max_iterations = max_iterations;
+            self
+        }
+    };
+}
+
+macro_rules! impl_pq_params_setter {
+    () => {
+        /// Number of sub-vectors of PQ.
+        ///
+        /// This value controls how much the vector is compressed during the quantization step.
+        /// The more sub-vectors there are the less the vector is compressed. The default is
+        /// the dimension of the vector divided by 16. If the dimension is not evenly divisible
+        /// by 16 we use the dimension divided by 8.
+        ///
+        /// The above two cases are highly preferred. Having 8 or 16 values per subvector allows
+        /// us to use efficient SIMD instructions.
+        ///
+        /// If the dimension is not divisible by 8 then we use 1 subvector. This is not ideal and
+        /// will likely result in poor performance.
+        pub fn num_sub_vectors(mut self, num_sub_vectors: u32) -> Self {
+            self.num_sub_vectors = Some(num_sub_vectors);
+            self
+        }
+    };
+}
+
+macro_rules! impl_hnsw_params_setter {
+    () => {
+        /// The number of neighbors to select for each vector in the HNSW graph.
+        ///
+        /// This value controls the tradeoff between search speed and accuracy.
+        /// The higher the value the more accurate the search but the slower it will be.
+        /// The default value is 20.
+        pub fn num_edges(mut self, m: u32) -> Self {
+            self.m = m;
+            self
+        }
+
+        /// The number of candidates to evaluate during the construction of the HNSW graph.
+        ///
+        /// This value controls the tradeoff between build speed and accuracy.
+        /// The higher the value the more accurate the build but the slower it will be.
+        /// This value should be set to a value that is not less than `ef` in the search phase.
+        /// The default value is 300.
+        pub fn ef_construction(mut self, ef_construction: u32) -> Self {
+            self.ef_construction = ef_construction;
+            self
+        }
+    };
+}
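
The four macros above exist so the same documented setters can be stamped into several builder `impl` blocks instead of being copy-pasted. The pattern in miniature, using a made-up `WidgetBuilder` rather than a LanceDB type:

```rust
// A declarative macro that expands to a consuming builder setter inside
// whatever `impl` block invokes it. This is the same trick the diff uses
// for impl_distance_type_setter! and friends.
macro_rules! impl_limit_setter {
    () => {
        /// Caps the number of widgets produced (illustrative only).
        pub fn limit(mut self, limit: u32) -> Self {
            self.limit = limit;
            self
        }
    };
}

#[derive(Debug, Default)]
struct WidgetBuilder {
    limit: u32,
}

impl WidgetBuilder {
    impl_limit_setter!(); // expands to the `limit` method above
}

fn main() {
    let b = WidgetBuilder::default().limit(42);
    println!("{b:?}");
}
```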
 /// Builder for an IVF PQ index.
@@ -106,84 +204,9 @@ impl Default for IvfPqIndexBuilder {
 }

 impl IvfPqIndexBuilder {
-    /// [DistanceType] to use to build the index.
-    ///
-    /// Default value is [DistanceType::L2].
-    ///
-    /// This is used when training the index to calculate the IVF partitions (vectors are
-    /// grouped in partitions with similar vectors according to this distance type) and to
-    /// calculate a subvector's code during quantization.
-    ///
-    /// The metric type used to train an index MUST match the metric type used to search the
-    /// index. Failure to do so will yield inaccurate results.
-    pub fn distance_type(mut self, distance_type: DistanceType) -> Self {
-        self.distance_type = distance_type;
-        self
-    }
-
-    /// The number of IVF partitions to create.
-    ///
-    /// This value should generally scale with the number of rows in the dataset. By default
-    /// the number of partitions is the square root of the number of rows.
-    ///
-    /// If this value is too large then the first part of the search (picking the right partition)
-    /// will be slow. If this value is too small then the second part of the search (searching
-    /// within a partition) will be slow.
-    pub fn num_partitions(mut self, num_partitions: u32) -> Self {
-        self.num_partitions = Some(num_partitions);
-        self
-    }
-
-    /// Number of sub-vectors of PQ.
-    ///
-    /// This value controls how much the vector is compressed during the quantization step.
-    /// The more sub vectors there are the less the vector is compressed. The default is
-    /// the dimension of the vector divided by 16. If the dimension is not evenly divisible
-    /// by 16 we use the dimension divded by 8.
-    ///
-    /// The above two cases are highly preferred. Having 8 or 16 values per subvector allows
-    /// us to use efficient SIMD instructions.
-    ///
-    /// If the dimension is not visible by 8 then we use 1 subvector. This is not ideal and
-    /// will likely result in poor performance.
-    pub fn num_sub_vectors(mut self, num_sub_vectors: u32) -> Self {
-        self.num_sub_vectors = Some(num_sub_vectors);
-        self
-    }
-
-    /// The rate used to calculate the number of training vectors for kmeans.
-    ///
-    /// When an IVF PQ index is trained, we need to calculate partitions. These are groups
-    /// of vectors that are similar to each other. To do this we use an algorithm called kmeans.
-    ///
-    /// Running kmeans on a large dataset can be slow. To speed this up we run kmeans on a
-    /// random sample of the data. This parameter controls the size of the sample. The total
-    /// number of vectors used to train the index is `sample_rate * num_partitions`.
-    ///
-    /// Increasing this value might improve the quality of the index but in most cases the
-    /// default should be sufficient.
-    ///
-    /// The default value is 256.
-    pub fn sample_rate(mut self, sample_rate: u32) -> Self {
-        self.sample_rate = sample_rate;
-        self
-    }
-
-    /// Max iterations to train kmeans.
-    ///
-    /// When training an IVF PQ index we use kmeans to calculate the partitions. This parameter
-    /// controls how many iterations of kmeans to run.
-    ///
-    /// Increasing this might improve the quality of the index but in most cases the parameter
-    /// is unused because kmeans will converge with fewer iterations. The parameter is only
-    /// used in cases where kmeans does not appear to converge. In those cases it is unlikely
-    /// that setting this larger will lead to the index converging anyways.
-    ///
-    /// The default value is 50.
-    pub fn max_iterations(mut self, max_iterations: u32) -> Self {
-        self.max_iterations = max_iterations;
-        self
-    }
+    impl_distance_type_setter!();
+    impl_ivf_params_setter!();
+    impl_pq_params_setter!();
 }
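
After this refactor the public API is unchanged: the macro invocations expand to the same chainable setters the hand-written methods provided. A sketch of the resulting fluent usage; the crate paths are assumed from the imports elsewhere in this diff, and the parameter values are illustrative, not recommendations:

```rust
use lancedb::index::vector::IvfPqIndexBuilder;
use lancedb::DistanceType;

// Each setter consumes and returns the builder, so calls chain.
fn ivf_pq_params() -> IvfPqIndexBuilder {
    IvfPqIndexBuilder::default()
        .distance_type(DistanceType::Cosine) // must match the search-time metric
        .num_partitions(256)  // default heuristic is ~sqrt(row count)
        .num_sub_vectors(16)  // default is dimension / 16
        .sample_rate(256)
        .max_iterations(50)
}
```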
 pub(crate) fn suggested_num_partitions(rows: usize) -> u32 {
@@ -206,6 +229,51 @@ pub(crate) fn suggested_num_sub_vectors(dim: u32) -> u32 {
     }
 }

+/// Builder for an IVF HNSW PQ index.
+///
+/// This index is a combination of IVF and HNSW.
+/// The IVF part is the same as in the IVF PQ index.
+/// For each IVF partition, this builds an HNSW graph; the graph is used to
+/// quickly find the closest vectors to a query vector.
+///
+/// The PQ (product quantizer) is used to compress the vectors, as in IVF PQ.
+#[derive(Debug, Clone)]
+pub struct IvfHnswPqIndexBuilder {
+    // IVF
+    pub(crate) distance_type: DistanceType,
+    pub(crate) num_partitions: Option<u32>,
+    pub(crate) sample_rate: u32,
+    pub(crate) max_iterations: u32,
+    // HNSW
+    pub(crate) m: u32,
+    pub(crate) ef_construction: u32,
+    // PQ
+    pub(crate) num_sub_vectors: Option<u32>,
+}
+
+impl Default for IvfHnswPqIndexBuilder {
+    fn default() -> Self {
+        Self {
+            distance_type: DistanceType::L2,
+            num_partitions: None,
+            num_sub_vectors: None,
+            sample_rate: 256,
+            max_iterations: 50,
+            m: 20,
+            ef_construction: 300,
+        }
+    }
+}
+
+impl IvfHnswPqIndexBuilder {
+    impl_distance_type_setter!();
+    impl_ivf_params_setter!();
+    impl_hnsw_params_setter!();
+    impl_pq_params_setter!();
+}
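
How the new builder plugs into table indexing, mirroring the `test_create_index_ivf_hnsw_pq` test added later in this diff. This is a minimal sketch: `table` is assumed to be an open `lancedb::Table` with a `FixedSizeList<Float32>` column named `embeddings`, and the `lancedb::Result` path is assumed from the crate's error module:

```rust
use lancedb::index::vector::IvfHnswPqIndexBuilder;
use lancedb::index::Index;
use lancedb::Table;

// Builds an IVF_HNSW_PQ index on the "embeddings" column.
async fn index_embeddings(table: &Table) -> lancedb::Result<()> {
    let index = IvfHnswPqIndexBuilder::default()
        .num_edges(20)          // HNSW graph degree (the `m` field)
        .ef_construction(300);  // candidates evaluated while building the graph
    table
        .create_index(&["embeddings"], Index::IvfHnswPq(index))
        .execute()
        .await
}
```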
 /// Builder for an IVF_HNSW_SQ index.
 ///
 /// This index is a combination of IVF and HNSW.
@@ -244,85 +312,7 @@ impl Default for IvfHnswSqIndexBuilder {
 }

 impl IvfHnswSqIndexBuilder {
-    /// [DistanceType] to use to build the index.
-    ///
-    /// Default value is [DistanceType::L2].
-    ///
-    /// This is used when training the index to calculate the IVF partitions (vectors are
-    /// grouped in partitions with similar vectors according to this distance type)
-    ///
-    /// The metric type used to train an index MUST match the metric type used to search the
-    /// index. Failure to do so will yield inaccurate results.
-    ///
-    /// Now IVF_HNSW_SQ only supports L2 and Cosine distance types.
-    pub fn distance_type(mut self, distance_type: DistanceType) -> Self {
-        self.distance_type = distance_type;
-        self
-    }
-
-    /// The number of IVF partitions to create.
-    ///
-    /// This value should generally scale with the number of rows in the dataset. By default
-    /// the number of partitions is the square root of the number of rows.
-    ///
-    /// If this value is too large then the first part of the search (picking the right partition)
-    /// will be slow. If this value is too small then the second part of the search (searching
-    /// within a partition) will be slow.
-    pub fn num_partitions(mut self, num_partitions: u32) -> Self {
-        self.num_partitions = Some(num_partitions);
-        self
-    }
-
-    /// The rate used to calculate the number of training vectors for kmeans and SQ.
-    ///
-    /// When an IVF_HNSW_SQ index is trained, we need to calculate partitions and min/max value of vectors. These are groups
-    /// of vectors that are similar to each other. To do this we use an algorithm called kmeans.
-    ///
-    /// Running kmeans on a large dataset can be slow. To speed this up we run kmeans on a
-    /// random sample of the data. This parameter controls the size of the sample. The total
-    /// number of vectors used to train the IVF is `sample_rate * num_partitions`.
-    ///
-    /// The total number of vectors used to train the SQ is `sample_rate * 2^{num_bits}`.
-    ///
-    /// Increasing this value might improve the quality of the index but in most cases the
-    /// default should be sufficient.
-    ///
-    /// The default value is 256.
-    pub fn sample_rate(mut self, sample_rate: u32) -> Self {
-        self.sample_rate = sample_rate;
-        self
-    }
-
-    /// Max iterations to train kmeans.
-    ///
-    /// When training an IVF index we use kmeans to calculate the partitions. This parameter
-    /// controls how many iterations of kmeans to run.
-    ///
-    /// Increasing this might improve the quality of the index but in most cases the parameter
-    /// is unused because kmeans will converge with fewer iterations. The parameter is only
-    /// used in cases where kmeans does not appear to converge. In those cases it is unlikely
-    /// that setting this larger will lead to the index converging anyways.
-    ///
-    /// The default value is 50.
-    pub fn max_iterations(mut self, max_iterations: u32) -> Self {
-        self.max_iterations = max_iterations;
-        self
-    }
-
-    /// The number of neighbors to select for each vector in the HNSW graph.
-    /// Bumping this number will increase the recall of the search but also increase the build/search time.
-    /// The default value is 20.
-    pub fn m(mut self, m: u32) -> Self {
-        self.m = m;
-        self
-    }
-
-    /// The number of candidates to evaluate during the construction of the HNSW graph.
-    /// Bumping this number will increase the recall of the search but also increase the build/search time.
-    /// This value should be not less than `ef` in the search phase.
-    /// The default value is 300.
-    pub fn ef_construction(mut self, ef_construction: u32) -> Self {
-        self.ef_construction = ef_construction;
-        self
-    }
+    impl_distance_type_setter!();
+    impl_ivf_params_setter!();
+    impl_hnsw_params_setter!();
 }


@@ -37,6 +37,7 @@ use lance::dataset::{MergeInsertBuilder as LanceMergeInsertBuilder, WhenNotMatch
 use lance::io::WrappingObjectStore;
 use lance_index::vector::hnsw::builder::HnswBuildParams;
 use lance_index::vector::ivf::IvfBuildParams;
+use lance_index::vector::pq::PQBuildParams;
 use lance_index::vector::sq::builder::SQBuildParams;
 use lance_index::DatasetIndexExt;
 use lance_index::IndexType;
@@ -49,9 +50,10 @@ use crate::connection::NoData;
 use crate::embeddings::{EmbeddingDefinition, EmbeddingRegistry, MaybeEmbedded, MemoryRegistry};
 use crate::error::{Error, Result};
 use crate::index::vector::{
-    IvfHnswSqIndexBuilder, IvfPqIndexBuilder, VectorIndex, VectorIndexStatistics,
+    IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder, IvfPqIndexBuilder, VectorIndex,
 };
 use crate::index::IndexConfig;
+use crate::index::IndexStatistics;
 use crate::index::{
     vector::{suggested_num_partitions, suggested_num_sub_vectors},
     Index, IndexBuilder,
@@ -1217,7 +1219,7 @@ impl NativeTable {
     pub async fn get_index_type(&self, index_uuid: &str) -> Result<Option<String>> {
         match self.load_index_stats(index_uuid).await? {
-            Some(stats) => Ok(Some(stats.index_type)),
+            Some(stats) => Ok(Some(stats.index_type.unwrap_or_default())),
             None => Ok(None),
         }
     }
@@ -1228,7 +1230,7 @@ impl NativeTable {
                 stats
                     .indices
                     .iter()
-                    .map(|i| i.metric_type.clone())
+                    .filter_map(|i| i.metric_type.clone())
                     .collect(),
             )),
             None => Ok(None),
@@ -1244,7 +1246,7 @@ impl NativeTable {
             .collect())
     }

-    async fn load_index_stats(&self, index_uuid: &str) -> Result<Option<VectorIndexStatistics>> {
+    async fn load_index_stats(&self, index_uuid: &str) -> Result<Option<IndexStatistics>> {
         let index = self
             .load_indices()
             .await?
@@ -1255,7 +1257,7 @@ impl NativeTable {
         }
         let dataset = self.dataset.get().await?;
         let index_stats = dataset.index_statistics(&index.unwrap().index_name).await?;
-        let index_stats: VectorIndexStatistics = whatever!(
+        let index_stats: IndexStatistics = whatever!(
             serde_json::from_str(&index_stats),
             "error deserializing index statistics {index_stats}",
         );
@@ -1316,6 +1318,69 @@ impl NativeTable {
         Ok(())
     }

+    async fn create_ivf_hnsw_pq_index(
+        &self,
+        index: IvfHnswPqIndexBuilder,
+        field: &Field,
+        replace: bool,
+    ) -> Result<()> {
+        if !Self::supported_vector_data_type(field.data_type()) {
+            return Err(Error::InvalidInput {
+                message: format!(
+                    "An IVF HNSW PQ index cannot be created on the column `{}` which has data type {}",
+                    field.name(),
+                    field.data_type()
+                ),
+            });
+        }
+
+        let num_partitions = if let Some(n) = index.num_partitions {
+            n
+        } else {
+            suggested_num_partitions(self.count_rows(None).await?)
+        };
+        let num_sub_vectors: u32 = if let Some(n) = index.num_sub_vectors {
+            n
+        } else {
+            match field.data_type() {
+                arrow_schema::DataType::FixedSizeList(_, n) => {
+                    Ok::<u32, Error>(suggested_num_sub_vectors(*n as u32))
+                }
+                _ => Err(Error::Schema {
+                    message: format!("Column '{}' is not a FixedSizeList", field.name()),
+                }),
+            }?
+        };
+        let mut dataset = self.dataset.get_mut().await?;
+        let mut ivf_params = IvfBuildParams::new(num_partitions as usize);
+        ivf_params.sample_rate = index.sample_rate as usize;
+        ivf_params.max_iters = index.max_iterations as usize;
+        let hnsw_params = HnswBuildParams::default()
+            .num_edges(index.m as usize)
+            .ef_construction(index.ef_construction as usize);
+        let pq_params = PQBuildParams {
+            num_sub_vectors: num_sub_vectors as usize,
+            ..Default::default()
+        };
+        let lance_idx_params = lance::index::vector::VectorIndexParams::with_ivf_hnsw_pq_params(
+            index.distance_type.into(),
+            ivf_params,
+            hnsw_params,
+            pq_params,
+        );
+        dataset
+            .create_index(
+                &[field.name()],
+                IndexType::Vector,
+                None,
+                &lance_idx_params,
+                replace,
+            )
+            .await?;
+        Ok(())
+    }
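
A worked example of the two fallback heuristics used above when the builder leaves `num_partitions` and `num_sub_vectors` unset, assuming `suggested_num_partitions` and `suggested_num_sub_vectors` behave as their doc comments describe:

```rust
// Plain arithmetic illustrating the documented defaults, not the
// actual helper implementations.
fn main() {
    let rows = 1_000_000;
    let dim: u32 = 768;

    // num_partitions defaults to ~sqrt(rows): sqrt(1_000_000) = 1000.
    let num_partitions = (rows as f64).sqrt() as u32;

    // num_sub_vectors prefers dim / 16 (768 / 16 = 48); if dim is not
    // divisible by 16 it falls back to dim / 8, and otherwise to 1.
    let num_sub_vectors = if dim % 16 == 0 {
        dim / 16
    } else if dim % 8 == 0 {
        dim / 8
    } else {
        1
    };

    assert_eq!((num_partitions, num_sub_vectors), (1000, 48));
}
```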
     async fn create_ivf_hnsw_sq_index(
         &self,
         index: IvfHnswSqIndexBuilder,
@@ -1610,6 +1675,10 @@ impl TableInternal for NativeTable {
             Index::Auto => self.create_auto_index(field, opts).await,
             Index::BTree(_) => self.create_btree_index(field, opts).await,
             Index::IvfPq(ivf_pq) => self.create_ivf_pq_index(ivf_pq, field, opts.replace).await,
+            Index::IvfHnswPq(ivf_hnsw_pq) => {
+                self.create_ivf_hnsw_pq_index(ivf_hnsw_pq, field, opts.replace)
+                    .await
+            }
             Index::IvfHnswSq(ivf_hnsw_sq) => {
                 self.create_ivf_hnsw_sq_index(ivf_hnsw_sq, field, opts.replace)
                     .await
@@ -1682,7 +1751,7 @@ impl TableInternal for NativeTable {
             builder.when_not_matched_by_source(WhenNotMatchedBySource::Keep);
         }
         let job = builder.try_build()?;
-        let new_dataset = job.execute_reader(new_data).await?;
+        let (new_dataset, _stats) = job.execute_reader(new_data).await?;
         self.dataset.set_latest(new_dataset.as_ref().clone()).await;
         Ok(())
     }
@@ -2475,6 +2544,25 @@ mod tests {
                 .unwrap(),
             Some(0)
         );
+        assert_eq!(
+            table
+                .as_native()
+                .unwrap()
+                .get_index_type(index_uuid)
+                .await
+                .unwrap()
+                .map(|index_type| index_type.to_string()),
+            Some("IVF".to_string())
+        );
+        assert_eq!(
+            table
+                .as_native()
+                .unwrap()
+                .get_distance_type(index_uuid)
+                .await
+                .unwrap(),
+            Some(crate::DistanceType::L2.to_string())
+        );
     }
     #[tokio::test]
@@ -2573,6 +2661,102 @@ mod tests {
         );
     }

+    #[tokio::test]
+    async fn test_create_index_ivf_hnsw_pq() {
+        use arrow_array::RecordBatch;
+        use arrow_schema::{DataType, Field, Schema as ArrowSchema};
+        use rand;
+        use std::iter::repeat_with;
+
+        use arrow_array::Float32Array;
+
+        let tmp_dir = tempdir().unwrap();
+        let uri = tmp_dir.path().to_str().unwrap();
+        let conn = connect(uri).execute().await.unwrap();
+
+        let dimension = 16;
+        let schema = Arc::new(ArrowSchema::new(vec![Field::new(
+            "embeddings",
+            DataType::FixedSizeList(
+                Arc::new(Field::new("item", DataType::Float32, true)),
+                dimension,
+            ),
+            false,
+        )]));
+
+        let mut rng = rand::thread_rng();
+        let float_arr = Float32Array::from(
+            repeat_with(|| rng.gen::<f32>())
+                .take(512 * dimension as usize)
+                .collect::<Vec<f32>>(),
+        );
+        let vectors = Arc::new(create_fixed_size_list(float_arr, dimension).unwrap());
+        let batches = RecordBatchIterator::new(
+            vec![RecordBatch::try_new(schema.clone(), vec![vectors.clone()]).unwrap()]
+                .into_iter()
+                .map(Ok),
+            schema,
+        );
+
+        let table = conn.create_table("test", batches).execute().await.unwrap();
+
+        assert_eq!(
+            table
+                .as_native()
+                .unwrap()
+                .count_indexed_rows("my_index")
+                .await
+                .unwrap(),
+            None
+        );
+        assert_eq!(
+            table
+                .as_native()
+                .unwrap()
+                .count_unindexed_rows("my_index")
+                .await
+                .unwrap(),
+            None
+        );
+
+        let index = IvfHnswPqIndexBuilder::default();
+        table
+            .create_index(&["embeddings"], Index::IvfHnswPq(index))
+            .execute()
+            .await
+            .unwrap();
+
+        let index_configs = table.list_indices().await.unwrap();
+        assert_eq!(index_configs.len(), 1);
+        let index = index_configs.into_iter().next().unwrap();
+        assert_eq!(index.index_type, crate::index::IndexType::IvfPq);
+        assert_eq!(index.columns, vec!["embeddings".to_string()]);
+        assert_eq!(table.count_rows(None).await.unwrap(), 512);
+        assert_eq!(table.name(), "test");
+
+        let indices = table.as_native().unwrap().load_indices().await.unwrap();
+        let index_uuid = &indices[0].index_uuid;
+        assert_eq!(
+            table
+                .as_native()
+                .unwrap()
+                .count_indexed_rows(index_uuid)
+                .await
+                .unwrap(),
+            Some(512)
+        );
+        assert_eq!(
+            table
+                .as_native()
+                .unwrap()
+                .count_unindexed_rows(index_uuid)
+                .await
+                .unwrap(),
+            Some(0)
+        );
+    }
     fn create_fixed_size_list<T: Array>(values: T, list_size: i32) -> Result<FixedSizeListArray> {
         let list_type = DataType::FixedSizeList(
             Arc::new(Field::new("item", values.data_type().clone(), true)),
@@ -2644,6 +2828,27 @@ mod tests {
         let index = index_configs.into_iter().next().unwrap();
         assert_eq!(index.index_type, crate::index::IndexType::BTree);
         assert_eq!(index.columns, vec!["i".to_string()]);
+
+        let indices = table.as_native().unwrap().load_indices().await.unwrap();
+        let index_uuid = &indices[0].index_uuid;
+        assert_eq!(
+            table
+                .as_native()
+                .unwrap()
+                .count_indexed_rows(index_uuid)
+                .await
+                .unwrap(),
+            Some(1)
+        );
+        assert_eq!(
+            table
+                .as_native()
+                .unwrap()
+                .count_unindexed_rows(index_uuid)
+                .await
+                .unwrap(),
+            Some(0)
+        );
     }

     #[tokio::test]