feat: Initial remote table implementation for rust (#1024)

This will eventually replace the remote table implementations in Python
and Node.
This commit is contained in:
Weston Pace
2024-02-29 10:55:49 -08:00
parent 45b5b66c82
commit 629c622d15
18 changed files with 376 additions and 21 deletions

View File

@@ -33,7 +33,7 @@ jobs:
python-version: "3.11"
- name: Install ruff
run: |
pip install ruff
pip install ruff==0.2.2
- name: Format check
run: ruff format --check .
- name: Lint

View File

@@ -0,0 +1,37 @@
name: LanceDb Cloud Integration Test
# Triggered whenever a run of the "Rust" workflow finishes.
on:
  workflow_run:
    workflows: [Rust]
    types:
      - completed
# Connection settings read by the ignored cloud integration test.
env:
  LANCEDB_PROJECT: ${{ secrets.LANCEDB_PROJECT }}
  LANCEDB_API_KEY: ${{ secrets.LANCEDB_API_KEY }}
  LANCEDB_REGION: ${{ secrets.LANCEDB_REGION }}
jobs:
  test:
    # `completed` also fires for failed and cancelled runs; only spend cloud
    # resources (and expose the secrets above) when the Rust workflow succeeded.
    if: ${{ github.event.workflow_run.conclusion == 'success' }}
    timeout-minutes: 30
    runs-on: ubuntu-22.04
    defaults:
      run:
        shell: bash
        working-directory: rust
    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0
          lfs: true
      - uses: Swatinem/rust-cache@v2
        with:
          workspaces: rust
      - name: Install dependencies
        run: |
          sudo apt update
          sudo apt install -y protobuf-compiler libssl-dev
      - name: Build
        run: cargo build --all-features
      # The cloud test is marked #[ignore], so it only runs when asked for explicitly.
      - name: Run Integration test
        run: cargo test --tests -- --ignored

View File

@@ -119,3 +119,4 @@ jobs:
$env:VCPKG_ROOT = $env:VCPKG_INSTALLATION_ROOT
cargo build
cargo test

2
.gitignore vendored
View File

@@ -39,4 +39,6 @@ dist
## Rust
target
**/sccache.log
Cargo.lock

View File

@@ -5,17 +5,8 @@ repos:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/psf/black
rev: 22.12.0
hooks:
- id: black
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.0.277
rev: v0.2.2
hooks:
- id: ruff
- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: isort
name: isort (python)

View File

@@ -103,9 +103,9 @@ class InstructorEmbeddingFunction(TextEmbeddingFunction):
# convert_to_numpy: bool = True # Hardcoding this as numpy can be ingested directly
source_instruction: str = "represent the document for retrieval"
query_instruction: (
str
) = "represent the document for retrieving the most similar documents"
query_instruction: str = (
"represent the document for retrieving the most similar documents"
)
@weak_lru(maxsize=1)
def ndims(self):

View File

@@ -12,6 +12,7 @@
# limitations under the License.
"""Full text search index using tantivy-py"""
import os
from typing import List, Tuple

View File

@@ -277,6 +277,7 @@ class RemoteTable(Table):
f = Future()
f.set_result(self._conn._client.query(name, q))
return f
else:
def submit(name, q):

View File

@@ -12,6 +12,7 @@
# limitations under the License.
"""Schema related utilities."""
import pyarrow as pa

View File

@@ -35,14 +35,16 @@ impl<T> PythonErrorExt<T> for std::result::Result<T, LanceError> {
match &self {
Ok(_) => Ok(self.unwrap()),
Err(err) => match err {
LanceError::InvalidInput { .. } => self.value_error(),
LanceError::InvalidTableName { .. } => self.value_error(),
LanceError::TableNotFound { .. } => self.value_error(),
LanceError::TableAlreadyExists { .. } => self.runtime_error(),
LanceError::Schema { .. } => self.value_error(),
LanceError::CreateDir { .. } => self.os_error(),
LanceError::TableAlreadyExists { .. } => self.runtime_error(),
LanceError::Store { .. } => self.runtime_error(),
LanceError::Lance { .. } => self.runtime_error(),
LanceError::Schema { .. } => self.value_error(),
LanceError::Runtime { .. } => self.runtime_error(),
LanceError::Http { .. } => self.runtime_error(),
},
}
}

View File

@@ -31,11 +31,19 @@ async-trait = "0"
bytes = "1"
futures.workspace = true
num-traits.workspace = true
url = { workspace = true }
url.workspace = true
serde = { version = "^1" }
serde_json = { version = "1" }
# For remote feature
reqwest = { version = "0.11.24", features = ["gzip", "json"], optional = true }
[dev-dependencies]
tempfile = "3.5.0"
rand = { version = "0.8.3", features = ["small_rng"] }
walkdir = "2"
[features]
default = ["remote"]
remote = ["dep:reqwest"]

View File

@@ -194,7 +194,7 @@ impl OpenTableBuilder {
}
#[async_trait::async_trait]
trait ConnectionInternal: Send + Sync + std::fmt::Debug + 'static {
pub(crate) trait ConnectionInternal: Send + Sync + std::fmt::Debug + 'static {
async fn table_names(&self) -> Result<Vec<String>>;
async fn do_create_table(&self, options: CreateTableBuilder<true>) -> Result<TableRef>;
async fn do_open_table(&self, options: OpenTableBuilder) -> Result<TableRef>;
@@ -365,14 +365,46 @@ impl ConnectBuilder {
self
}
/// Establishes a connection to the database
pub async fn execute(self) -> Result<Connection> {
let internal = Arc::new(Database::connect_with_options(&self).await?);
/// Builds a [`Connection`] backed by `RemoteDatabase`.
///
/// Both `region` and `api_key` must have been set on the builder; the region
/// is checked first.
///
/// # Errors
///
/// Returns `Error::InvalidInput` when the region or the API key is missing,
/// and forwards any error produced by `RemoteDatabase::try_new`.
#[cfg(feature = "remote")]
fn execute_remote(self) -> Result<Connection> {
    let region = match self.region {
        Some(region) => region,
        None => {
            return Err(Error::InvalidInput {
                message: "A region is required when connecting to LanceDb Cloud".to_string(),
            })
        }
    };
    let api_key = match self.api_key {
        Some(api_key) => api_key,
        None => {
            return Err(Error::InvalidInput {
                message: "An api_key is required when connecting to LanceDb Cloud".to_string(),
            })
        }
    };
    let database = crate::remote::db::RemoteDatabase::try_new(
        &self.uri,
        &api_key,
        &region,
        self.host_override,
    )?;
    Ok(Connection {
        internal: Arc::new(database),
        uri: self.uri,
    })
}
#[cfg(not(feature = "remote"))]
fn execute_remote(self) -> Result<Connection> {
Err(Error::Runtime {
message: "cannot connect to LanceDb Cloud unless the 'remote' feature is enabled"
.to_string(),
})
}
/// Establishes a connection to the database
///
/// URIs that use the `db://` scheme are handed to `execute_remote`; every
/// other URI is opened through `Database::connect_with_options`.
///
/// # Errors
///
/// Returns whatever error the selected connection path produces.
pub async fn execute(self) -> Result<Connection> {
    // Match the whole `db://` scheme prefix rather than just "db": a local
    // path such as `db_files/` or `dbs/prod` must not be treated as a cloud URI.
    if self.uri.starts_with("db://") {
        self.execute_remote()
    } else {
        let internal = Arc::new(Database::connect_with_options(&self).await?);
        Ok(Connection {
            internal,
            uri: self.uri,
        })
    }
}
}
/// Connect to a LanceDB database.

View File

@@ -22,6 +22,8 @@ use snafu::Snafu;
pub enum Error {
#[snafu(display("LanceDBError: Invalid table name: {name}"))]
InvalidTableName { name: String },
#[snafu(display("LanceDBError: Invalid input, {message}"))]
InvalidInput { message: String },
#[snafu(display("LanceDBError: Table '{name}' was not found"))]
TableNotFound { name: String },
#[snafu(display("LanceDBError: Table '{name}' already exists"))]
@@ -31,6 +33,8 @@ pub enum Error {
path: String,
source: std::io::Error,
},
#[snafu(display("LanceDBError: Http error: {message}"))]
Http { message: String },
#[snafu(display("LanceDBError: {message}"))]
Store { message: String },
#[snafu(display("LanceDBError: {message}"))]
@@ -82,3 +86,21 @@ impl<T> From<PoisonError<T>> for Error {
}
}
}
#[cfg(feature = "remote")]
impl From<reqwest::Error> for Error {
    /// Wraps a `reqwest` failure as [`Error::Http`], keeping only its display text.
    fn from(source: reqwest::Error) -> Self {
        let message = source.to_string();
        Self::Http { message }
    }
}
#[cfg(feature = "remote")]
impl From<url::ParseError> for Error {
    /// Wraps a URL parsing failure as [`Error::Http`], keeping only its display text.
    fn from(source: url::ParseError) -> Self {
        let message = source.to_string();
        Self::Http { message }
    }
}

View File

@@ -188,6 +188,8 @@ pub mod index;
pub mod io;
pub mod ipc;
pub mod query;
#[cfg(feature = "remote")]
pub(crate) mod remote;
pub mod table;
pub mod utils;

View File

@@ -0,0 +1,21 @@
// Copyright 2024 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! This module contains a remote client for a LanceDB server. This is used
//! to communicate with LanceDB cloud. It can also serve as an example for
//! building client/server applications with LanceDB or as a client for some
//! other custom LanceDB service.
pub mod client;
pub mod db;

View File

@@ -0,0 +1,119 @@
// Copyright 2024 LanceDB Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::time::Duration;
use reqwest::{
header::{HeaderMap, HeaderValue},
RequestBuilder, Response,
};
use crate::error::{Error, Result};
/// HTTP client for a LanceDB REST endpoint.
///
/// Pairs a configured `reqwest::Client` with the base URL that request paths
/// are appended to.
#[derive(Debug)]
pub struct RestfulLanceDbClient {
// Underlying HTTP client, built in `try_new` with a timeout and default headers.
client: reqwest::Client,
// Base URL that `get` prepends to every request path.
host: String,
}
impl RestfulLanceDbClient {
fn default_headers(
api_key: &str,
region: &str,
db_name: &str,
has_host_override: bool,
) -> Result<HeaderMap> {
let mut headers = HeaderMap::new();
headers.insert(
"x-api-key",
HeaderValue::from_str(api_key).map_err(|_| Error::Http {
message: "non-ascii api key provided".to_string(),
})?,
);
if region == "local" {
let host = format!("{}.local.api.lancedb.com", db_name);
headers.insert(
"Host",
HeaderValue::from_str(&host).map_err(|_| Error::Http {
message: format!("non-ascii database name '{}' provided", db_name),
})?,
);
}
if has_host_override {
headers.insert(
"x-lancedb-database",
HeaderValue::from_str(db_name).map_err(|_| Error::Http {
message: format!("non-ascii database name '{}' provided", db_name),
})?,
);
}
Ok(headers)
}
pub fn try_new(
db_url: &str,
api_key: &str,
region: &str,
host_override: Option<String>,
) -> Result<Self> {
let parsed_url = url::Url::parse(db_url)?;
debug_assert_eq!(parsed_url.scheme(), "db");
if !parsed_url.has_host() {
return Err(Error::Http {
message: format!("Invalid database URL (missing host) '{}'", db_url),
});
}
let db_name = parsed_url.host_str().unwrap();
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(30))
.default_headers(Self::default_headers(
api_key,
region,
db_name,
host_override.is_some(),
)?)
.build()?;
let host = match host_override {
Some(host_override) => host_override,
None => format!("https://{}.{}.api.lancedb.com", db_name, region),
};
Ok(Self { client, host })
}
pub fn get(&self, uri: &str) -> RequestBuilder {
let full_uri = format!("{}{}", self.host, uri);
self.client.get(full_uri)
}
async fn rsp_to_str(response: Response) -> String {
let status = response.status();
response.text().await.unwrap_or_else(|_| status.to_string())
}
pub async fn check_response(&self, response: Response) -> Result<Response> {
let status_int: u16 = u16::from(response.status());
if (400..500).contains(&status_int) {
Err(Error::InvalidInput {
message: Self::rsp_to_str(response).await,
})
} else if status_int != 200 {
Err(Error::Runtime {
message: Self::rsp_to_str(response).await,
})
} else {
Ok(response)
}
}
}

View File

@@ -0,0 +1,75 @@
// Copyright 2024 LanceDB Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_trait::async_trait;
use serde::Deserialize;
use crate::connection::{ConnectionInternal, CreateTableBuilder, OpenTableBuilder};
use crate::error::Result;
use crate::TableRef;
use super::client::RestfulLanceDbClient;
/// JSON body deserialized from the response to the list-tables request.
#[derive(Deserialize)]
struct ListTablesResponse {
// Table names, read from the `tables` field of the response body.
tables: Vec<String>,
}
/// Database connection that talks to a LanceDB server through `RestfulLanceDbClient`.
#[derive(Debug)]
pub struct RemoteDatabase {
// REST client used for every request issued by this database.
client: RestfulLanceDbClient,
}
impl RemoteDatabase {
    /// Creates a remote database handle by building the REST client for `uri`.
    ///
    /// All arguments are forwarded unchanged to `RestfulLanceDbClient::try_new`.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `RestfulLanceDbClient::try_new`.
    pub fn try_new(
        uri: &str,
        api_key: &str,
        region: &str,
        host_override: Option<String>,
    ) -> Result<Self> {
        RestfulLanceDbClient::try_new(uri, api_key, region, host_override)
            .map(|client| Self { client })
    }
}
#[async_trait]
impl ConnectionInternal for RemoteDatabase {
    /// Lists table names by issuing `GET /v1/table/`.
    ///
    /// A single request is sent with `limit=10` and an empty `page_token`; no
    /// further pages are requested here.
    async fn table_names(&self) -> Result<Vec<String>> {
        let request = self
            .client
            .get("/v1/table/")
            .query(&[("limit", 10)])
            .query(&[("page_token", "")]);
        let response = request.send().await?;
        let response = self.client.check_response(response).await?;
        let listing: ListTablesResponse = response.json().await?;
        Ok(listing.tables)
    }

    /// Not implemented yet; calling this panics via `todo!()`.
    async fn do_create_table(&self, _options: CreateTableBuilder<true>) -> Result<TableRef> {
        todo!()
    }

    /// Not implemented yet; calling this panics via `todo!()`.
    async fn do_open_table(&self, _options: OpenTableBuilder) -> Result<TableRef> {
        todo!()
    }

    /// Not implemented yet; calling this panics via `todo!()`.
    async fn drop_table(&self, _name: &str) -> Result<()> {
        todo!()
    }

    /// Not implemented yet; calling this panics via `todo!()`.
    async fn drop_db(&self) -> Result<()> {
        todo!()
    }
}

View File

@@ -0,0 +1,40 @@
// Copyright 2024 LanceDB Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/// Connects to a live LanceDB service and lists its tables.
///
/// Ignored by default because it needs the `LANCEDB_PROJECT`,
/// `LANCEDB_API_KEY` and `LANCEDB_REGION` environment variables;
/// `LANCEDB_HOST_OVERRIDE` is optional and redirects the requests.
#[tokio::test]
#[ignore]
async fn cloud_integration_test() {
    let project = std::env::var("LANCEDB_PROJECT")
        .expect("the LANCEDB_PROJECT env must be set to run the cloud integration test");
    let api_key = std::env::var("LANCEDB_API_KEY")
        .expect("the LANCEDB_API_KEY env must be set to run the cloud integration test");
    let region = std::env::var("LANCEDB_REGION")
        .expect("the LANCEDB_REGION env must be set to run the cloud integration test");
    // An unset (or non-unicode) override simply means "no override".
    let host_override = std::env::var("LANCEDB_HOST_OVERRIDE").ok();
    if host_override.is_none() {
        println!("No LANCEDB_HOST_OVERRIDE has been set. Running integration test against LanceDb Cloud production instance");
    }

    let mut builder = lancedb::connect(&format!("db://{}", project))
        .api_key(&api_key)
        .region(&region);
    if let Some(host_override) = &host_override {
        builder = builder.host_override(host_override);
    }
    let db = builder.execute().await.unwrap();

    db.table_names().await.unwrap();
}