diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index d80cb948..e4a95569 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -33,7 +33,7 @@ jobs: python-version: "3.11" - name: Install ruff run: | - pip install ruff + pip install ruff==0.2.2 - name: Format check run: ruff format --check . - name: Lint diff --git a/.github/workflows/remote-integration.yml b/.github/workflows/remote-integration.yml new file mode 100644 index 00000000..68862ebf --- /dev/null +++ b/.github/workflows/remote-integration.yml @@ -0,0 +1,37 @@ +name: LanceDb Cloud Integration Test + +on: + workflow_run: + workflows: [Rust] + types: + - completed + +env: + LANCEDB_PROJECT: ${{ secrets.LANCEDB_PROJECT }} + LANCEDB_API_KEY: ${{ secrets.LANCEDB_API_KEY }} + LANCEDB_REGION: ${{ secrets.LANCEDB_REGION }} + +jobs: + test: + timeout-minutes: 30 + runs-on: ubuntu-22.04 + defaults: + run: + shell: bash + working-directory: rust + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + lfs: true + - uses: Swatinem/rust-cache@v2 + with: + workspaces: rust + - name: Install dependencies + run: | + sudo apt update + sudo apt install -y protobuf-compiler libssl-dev + - name: Build + run: cargo build --all-features + - name: Run Integration test + run: cargo test --tests -- --ignored diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index c43a5d4f..d9c5358f 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -119,3 +119,4 @@ jobs: $env:VCPKG_ROOT = $env:VCPKG_INSTALLATION_ROOT cargo build cargo test + \ No newline at end of file diff --git a/.gitignore b/.gitignore index 607466e6..46e13e7b 100644 --- a/.gitignore +++ b/.gitignore @@ -39,4 +39,6 @@ dist ## Rust target +**/sccache.log + Cargo.lock diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b8f409a0..93801fcd 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,17 +5,8 @@ repos: - id: check-yaml - id: 
end-of-file-fixer - id: trailing-whitespace -- repo: https://github.com/psf/black - rev: 22.12.0 - hooks: - - id: black - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.0.277 + rev: v0.2.2 hooks: - id: ruff -- repo: https://github.com/pycqa/isort - rev: 5.12.0 - hooks: - - id: isort - name: isort (python) \ No newline at end of file diff --git a/python/python/lancedb/embeddings/instructor.py b/python/python/lancedb/embeddings/instructor.py index e6481e19..98206bc5 100644 --- a/python/python/lancedb/embeddings/instructor.py +++ b/python/python/lancedb/embeddings/instructor.py @@ -103,9 +103,9 @@ class InstructorEmbeddingFunction(TextEmbeddingFunction): # convert_to_numpy: bool = True # Hardcoding this as numpy can be ingested directly source_instruction: str = "represent the document for retrieval" - query_instruction: ( - str - ) = "represent the document for retrieving the most similar documents" + query_instruction: str = ( + "represent the document for retrieving the most similar documents" + ) @weak_lru(maxsize=1) def ndims(self): diff --git a/python/python/lancedb/fts.py b/python/python/lancedb/fts.py index 750e3076..cb36aa79 100644 --- a/python/python/lancedb/fts.py +++ b/python/python/lancedb/fts.py @@ -12,6 +12,7 @@ # limitations under the License. """Full text search index using tantivy-py""" + import os from typing import List, Tuple diff --git a/python/python/lancedb/remote/table.py b/python/python/lancedb/remote/table.py index 147a0b08..dafa2e2c 100644 --- a/python/python/lancedb/remote/table.py +++ b/python/python/lancedb/remote/table.py @@ -277,6 +277,7 @@ class RemoteTable(Table): f = Future() f.set_result(self._conn._client.query(name, q)) return f + else: def submit(name, q): diff --git a/python/python/lancedb/schema.py b/python/python/lancedb/schema.py index 9b5dd5e7..1a604dd5 100644 --- a/python/python/lancedb/schema.py +++ b/python/python/lancedb/schema.py @@ -12,6 +12,7 @@ # limitations under the License. 
"""Schema related utilities.""" + import pyarrow as pa diff --git a/python/src/error.rs b/python/src/error.rs index c65192ca..20ae7c2a 100644 --- a/python/src/error.rs +++ b/python/src/error.rs @@ -35,14 +35,16 @@ impl PythonErrorExt for std::result::Result { match &self { Ok(_) => Ok(self.unwrap()), Err(err) => match err { + LanceError::InvalidInput { .. } => self.value_error(), LanceError::InvalidTableName { .. } => self.value_error(), LanceError::TableNotFound { .. } => self.value_error(), - LanceError::TableAlreadyExists { .. } => self.runtime_error(), + LanceError::Schema { .. } => self.value_error(), LanceError::CreateDir { .. } => self.os_error(), + LanceError::TableAlreadyExists { .. } => self.runtime_error(), LanceError::Store { .. } => self.runtime_error(), LanceError::Lance { .. } => self.runtime_error(), - LanceError::Schema { .. } => self.value_error(), LanceError::Runtime { .. } => self.runtime_error(), + LanceError::Http { .. } => self.runtime_error(), }, } } diff --git a/rust/lancedb/Cargo.toml b/rust/lancedb/Cargo.toml index 0429dca2..e0febf1f 100644 --- a/rust/lancedb/Cargo.toml +++ b/rust/lancedb/Cargo.toml @@ -31,11 +31,19 @@ async-trait = "0" bytes = "1" futures.workspace = true num-traits.workspace = true -url = { workspace = true } +url.workspace = true serde = { version = "^1" } serde_json = { version = "1" } +# For remote feature + +reqwest = { version = "0.11.24", features = ["gzip", "json"], optional = true } + [dev-dependencies] tempfile = "3.5.0" rand = { version = "0.8.3", features = ["small_rng"] } walkdir = "2" + +[features] +default = ["remote"] +remote = ["dep:reqwest"] diff --git a/rust/lancedb/src/connection.rs b/rust/lancedb/src/connection.rs index 5b871b8b..11d72cde 100644 --- a/rust/lancedb/src/connection.rs +++ b/rust/lancedb/src/connection.rs @@ -194,7 +194,7 @@ impl OpenTableBuilder { } #[async_trait::async_trait] -trait ConnectionInternal: Send + Sync + std::fmt::Debug + 'static { +pub(crate) trait ConnectionInternal: Send 
+ Sync + std::fmt::Debug + 'static { async fn table_names(&self) -> Result<Vec<String>>; async fn do_create_table(&self, options: CreateTableBuilder) -> Result<TableRef>; async fn do_open_table(&self, options: OpenTableBuilder) -> Result<TableRef>; @@ -365,14 +365,46 @@ impl ConnectBuilder { self } - /// Establishes a connection to the database - pub async fn execute(self) -> Result<Connection> { - let internal = Arc::new(Database::connect_with_options(&self).await?); + #[cfg(feature = "remote")] + fn execute_remote(self) -> Result<Connection> { + let region = self.region.ok_or_else(|| Error::InvalidInput { + message: "A region is required when connecting to LanceDb Cloud".to_string(), + })?; + let api_key = self.api_key.ok_or_else(|| Error::InvalidInput { + message: "An api_key is required when connecting to LanceDb Cloud".to_string(), + })?; + let internal = Arc::new(crate::remote::db::RemoteDatabase::try_new( + &self.uri, + &api_key, + &region, + self.host_override, + )?); Ok(Connection { internal, uri: self.uri, }) } + + #[cfg(not(feature = "remote"))] + fn execute_remote(self) -> Result<Connection> { + Err(Error::Runtime { + message: "cannot connect to LanceDb Cloud unless the 'remote' feature is enabled" + .to_string(), + }) + } + + /// Establishes a connection to the database + pub async fn execute(self) -> Result<Connection> { + if self.uri.starts_with("db") { + self.execute_remote() + } else { + let internal = Arc::new(Database::connect_with_options(&self).await?); + Ok(Connection { + internal, + uri: self.uri, + }) + } + } } /// Connect to a LanceDB database.
diff --git a/rust/lancedb/src/error.rs b/rust/lancedb/src/error.rs index 2bdf97d6..f8961b85 100644 --- a/rust/lancedb/src/error.rs +++ b/rust/lancedb/src/error.rs @@ -22,6 +22,8 @@ use snafu::Snafu; pub enum Error { #[snafu(display("LanceDBError: Invalid table name: {name}"))] InvalidTableName { name: String }, + #[snafu(display("LanceDBError: Invalid input, {message}"))] + InvalidInput { message: String }, #[snafu(display("LanceDBError: Table '{name}' was not found"))] TableNotFound { name: String }, #[snafu(display("LanceDBError: Table '{name}' already exists"))] @@ -31,6 +33,8 @@ pub enum Error { path: String, source: std::io::Error, }, + #[snafu(display("LanceDBError: Http error: {message}"))] + Http { message: String }, #[snafu(display("LanceDBError: {message}"))] Store { message: String }, #[snafu(display("LanceDBError: {message}"))] @@ -82,3 +86,21 @@ impl From> for Error { } } } + +#[cfg(feature = "remote")] +impl From<reqwest::Error> for Error { + fn from(e: reqwest::Error) -> Self { + Self::Http { + message: e.to_string(), + } + } +} + +#[cfg(feature = "remote")] +impl From<url::ParseError> for Error { + fn from(e: url::ParseError) -> Self { + Self::Http { + message: e.to_string(), + } + } +} diff --git a/rust/lancedb/src/lib.rs b/rust/lancedb/src/lib.rs index ac086ef6..a04826aa 100644 --- a/rust/lancedb/src/lib.rs +++ b/rust/lancedb/src/lib.rs @@ -188,6 +188,8 @@ pub mod index; pub mod io; pub mod ipc; pub mod query; +#[cfg(feature = "remote")] +pub(crate) mod remote; pub mod table; pub mod utils; diff --git a/rust/lancedb/src/remote.rs b/rust/lancedb/src/remote.rs new file mode 100644 index 00000000..57a86a92 --- /dev/null +++ b/rust/lancedb/src/remote.rs @@ -0,0 +1,21 @@ +// Copyright 2024 Lance Developers. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! This module contains a remote client for a LanceDB server. This is used +//! to communicate with LanceDB cloud. It can also serve as an example for +//! building client/server applications with LanceDB or as a client for some +//! other custom LanceDB service. + +pub mod client; +pub mod db; diff --git a/rust/lancedb/src/remote/client.rs b/rust/lancedb/src/remote/client.rs new file mode 100644 index 00000000..8b516d20 --- /dev/null +++ b/rust/lancedb/src/remote/client.rs @@ -0,0 +1,119 @@ +// Copyright 2024 LanceDB Developers. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use std::time::Duration; + +use reqwest::{ + header::{HeaderMap, HeaderValue}, + RequestBuilder, Response, +}; + +use crate::error::{Error, Result}; + +#[derive(Debug)] +pub struct RestfulLanceDbClient { + client: reqwest::Client, + host: String, +} + +impl RestfulLanceDbClient { + fn default_headers( + api_key: &str, + region: &str, + db_name: &str, + has_host_override: bool, + ) -> Result<HeaderMap> { + let mut headers = HeaderMap::new(); + headers.insert( + "x-api-key", + HeaderValue::from_str(api_key).map_err(|_| Error::Http { + message: "non-ascii api key provided".to_string(), + })?, + ); + if region == "local" { + let host = format!("{}.local.api.lancedb.com", db_name); + headers.insert( + "Host", + HeaderValue::from_str(&host).map_err(|_| Error::Http { + message: format!("non-ascii database name '{}' provided", db_name), + })?, + ); + } + if has_host_override { + headers.insert( + "x-lancedb-database", + HeaderValue::from_str(db_name).map_err(|_| Error::Http { + message: format!("non-ascii database name '{}' provided", db_name), + })?, + ); + } + + Ok(headers) + } + + pub fn try_new( + db_url: &str, + api_key: &str, + region: &str, + host_override: Option<String>, + ) -> Result<Self> { + let parsed_url = url::Url::parse(db_url)?; + debug_assert_eq!(parsed_url.scheme(), "db"); + if !parsed_url.has_host() { + return Err(Error::Http { + message: format!("Invalid database URL (missing host) '{}'", db_url), + }); + } + let db_name = parsed_url.host_str().unwrap(); + let client = reqwest::Client::builder() + .timeout(Duration::from_secs(30)) + .default_headers(Self::default_headers( + api_key, + region, + db_name, + host_override.is_some(), + )?)
+ .build()?; + let host = match host_override { + Some(host_override) => host_override, + None => format!("https://{}.{}.api.lancedb.com", db_name, region), + }; + Ok(Self { client, host }) + } + + pub fn get(&self, uri: &str) -> RequestBuilder { + let full_uri = format!("{}{}", self.host, uri); + self.client.get(full_uri) + } + + async fn rsp_to_str(response: Response) -> String { + let status = response.status(); + response.text().await.unwrap_or_else(|_| status.to_string()) + } + + pub async fn check_response(&self, response: Response) -> Result<Response> { + let status_int: u16 = u16::from(response.status()); + if (400..500).contains(&status_int) { + Err(Error::InvalidInput { + message: Self::rsp_to_str(response).await, + }) + } else if status_int != 200 { + Err(Error::Runtime { + message: Self::rsp_to_str(response).await, + }) + } else { + Ok(response) + } + } +} diff --git a/rust/lancedb/src/remote/db.rs b/rust/lancedb/src/remote/db.rs new file mode 100644 index 00000000..db7ee00d --- /dev/null +++ b/rust/lancedb/src/remote/db.rs @@ -0,0 +1,75 @@ +// Copyright 2024 LanceDB Developers. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +use async_trait::async_trait; +use serde::Deserialize; + +use crate::connection::{ConnectionInternal, CreateTableBuilder, OpenTableBuilder}; +use crate::error::Result; +use crate::TableRef; + +use super::client::RestfulLanceDbClient; + +#[derive(Deserialize)] +struct ListTablesResponse { + tables: Vec<String>, +} + +#[derive(Debug)] +pub struct RemoteDatabase { + client: RestfulLanceDbClient, +} + +impl RemoteDatabase { + pub fn try_new( + uri: &str, + api_key: &str, + region: &str, + host_override: Option<String>, + ) -> Result<Self> { + let client = RestfulLanceDbClient::try_new(uri, api_key, region, host_override)?; + Ok(Self { client }) + } +} + +#[async_trait] +impl ConnectionInternal for RemoteDatabase { + async fn table_names(&self) -> Result<Vec<String>> { + let rsp = self + .client + .get("/v1/table/") + .query(&[("limit", 10)]) + .query(&[("page_token", "")]) + .send() + .await?; + let rsp = self.client.check_response(rsp).await?; + Ok(rsp.json::<ListTablesResponse>().await?.tables) + } + + async fn do_create_table(&self, _options: CreateTableBuilder) -> Result<TableRef> { + todo!() + } + + async fn do_open_table(&self, _options: OpenTableBuilder) -> Result<TableRef> { + todo!() + } + + async fn drop_table(&self, _name: &str) -> Result<()> { + todo!() + } + + async fn drop_db(&self) -> Result<()> { + todo!() + } +} diff --git a/rust/lancedb/tests/lancedb_cloud.rs b/rust/lancedb/tests/lancedb_cloud.rs new file mode 100644 index 00000000..88fae82d --- /dev/null +++ b/rust/lancedb/tests/lancedb_cloud.rs @@ -0,0 +1,40 @@ +// Copyright 2024 LanceDB Developers. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and +// limitations under the License. + +#[tokio::test] +#[ignore] +async fn cloud_integration_test() { + let project = std::env::var("LANCEDB_PROJECT") + .expect("the LANCEDB_PROJECT env must be set to run the cloud integration test"); + let api_key = std::env::var("LANCEDB_API_KEY") + .expect("the LANCEDB_API_KEY env must be set to run the cloud integration test"); + let region = std::env::var("LANCEDB_REGION") + .expect("the LANCEDB_REGION env must be set to run the cloud integration test"); + let host_override = std::env::var("LANCEDB_HOST_OVERRIDE") + .map(Some) + .unwrap_or(None); + if host_override.is_none() { + println!("No LANCEDB_HOST_OVERRIDE has been set. Running integration test against LanceDb Cloud production instance"); + } + + let mut builder = lancedb::connect(&format!("db://{}", project)) + .api_key(&api_key) + .region(&region); + if let Some(host_override) = &host_override { + builder = builder.host_override(host_override); + } + let db = builder.execute().await.unwrap(); + + db.table_names().await.unwrap(); +}