mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-23 05:19:58 +00:00
feat: Initial remote table implementation for rust (#1024)
This will eventually replace the remote table implementations in python and node.
This commit is contained in:
2
.github/workflows/python.yml
vendored
2
.github/workflows/python.yml
vendored
@@ -33,7 +33,7 @@ jobs:
|
||||
python-version: "3.11"
|
||||
- name: Install ruff
|
||||
run: |
|
||||
pip install ruff
|
||||
pip install ruff==0.2.2
|
||||
- name: Format check
|
||||
run: ruff format --check .
|
||||
- name: Lint
|
||||
|
||||
37
.github/workflows/remote-integration.yml
vendored
Normal file
37
.github/workflows/remote-integration.yml
vendored
Normal file
@@ -0,0 +1,37 @@
|
||||
# Runs the `#[ignore]`d Rust integration tests against LanceDb Cloud after the
# "Rust" workflow finishes.
name: LanceDb Cloud Integration Test

on:
  workflow_run:
    workflows: [Rust]
    types:
      - completed

# Connection settings read by rust/lancedb/tests/lancedb_cloud.rs.
env:
  LANCEDB_PROJECT: ${{ secrets.LANCEDB_PROJECT }}
  LANCEDB_API_KEY: ${{ secrets.LANCEDB_API_KEY }}
  LANCEDB_REGION: ${{ secrets.LANCEDB_REGION }}

jobs:
  test:
    # `workflow_run` with `types: [completed]` fires on every completion,
    # including failed and cancelled runs. Only exercise the cloud service
    # when the upstream Rust workflow actually passed.
    if: ${{ github.event.workflow_run.conclusion == 'success' }}
    timeout-minutes: 30
    runs-on: ubuntu-22.04
    defaults:
      run:
        shell: bash
        working-directory: rust
    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0
          lfs: true
      - uses: Swatinem/rust-cache@v2
        with:
          workspaces: rust
      - name: Install dependencies
        run: |
          sudo apt update
          sudo apt install -y protobuf-compiler libssl-dev
      - name: Build
        run: cargo build --all-features
      # `--ignored` selects the tests marked `#[ignore]`, i.e. the cloud tests.
      - name: Run Integration test
        run: cargo test --tests -- --ignored
|
||||
1
.github/workflows/rust.yml
vendored
1
.github/workflows/rust.yml
vendored
@@ -119,3 +119,4 @@ jobs:
|
||||
$env:VCPKG_ROOT = $env:VCPKG_INSTALLATION_ROOT
|
||||
cargo build
|
||||
cargo test
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -39,4 +39,6 @@ dist
|
||||
## Rust
|
||||
target
|
||||
|
||||
**/sccache.log
|
||||
|
||||
Cargo.lock
|
||||
|
||||
@@ -5,17 +5,8 @@ repos:
|
||||
- id: check-yaml
|
||||
- id: end-of-file-fixer
|
||||
- id: trailing-whitespace
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 22.12.0
|
||||
hooks:
|
||||
- id: black
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
# Ruff version.
|
||||
rev: v0.0.277
|
||||
rev: v0.2.2
|
||||
hooks:
|
||||
- id: ruff
|
||||
- repo: https://github.com/pycqa/isort
|
||||
rev: 5.12.0
|
||||
hooks:
|
||||
- id: isort
|
||||
name: isort (python)
|
||||
@@ -103,9 +103,9 @@ class InstructorEmbeddingFunction(TextEmbeddingFunction):
|
||||
# convert_to_numpy: bool = True # Hardcoding this as numpy can be ingested directly
|
||||
|
||||
source_instruction: str = "represent the document for retrieval"
|
||||
query_instruction: (
|
||||
str
|
||||
) = "represent the document for retrieving the most similar documents"
|
||||
query_instruction: str = (
|
||||
"represent the document for retrieving the most similar documents"
|
||||
)
|
||||
|
||||
@weak_lru(maxsize=1)
|
||||
def ndims(self):
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""Full text search index using tantivy-py"""
|
||||
|
||||
import os
|
||||
from typing import List, Tuple
|
||||
|
||||
|
||||
@@ -277,6 +277,7 @@ class RemoteTable(Table):
|
||||
f = Future()
|
||||
f.set_result(self._conn._client.query(name, q))
|
||||
return f
|
||||
|
||||
else:
|
||||
|
||||
def submit(name, q):
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""Schema related utilities."""
|
||||
|
||||
import pyarrow as pa
|
||||
|
||||
|
||||
|
||||
@@ -35,14 +35,16 @@ impl<T> PythonErrorExt<T> for std::result::Result<T, LanceError> {
|
||||
match &self {
|
||||
Ok(_) => Ok(self.unwrap()),
|
||||
Err(err) => match err {
|
||||
LanceError::InvalidInput { .. } => self.value_error(),
|
||||
LanceError::InvalidTableName { .. } => self.value_error(),
|
||||
LanceError::TableNotFound { .. } => self.value_error(),
|
||||
LanceError::TableAlreadyExists { .. } => self.runtime_error(),
|
||||
LanceError::Schema { .. } => self.value_error(),
|
||||
LanceError::CreateDir { .. } => self.os_error(),
|
||||
LanceError::TableAlreadyExists { .. } => self.runtime_error(),
|
||||
LanceError::Store { .. } => self.runtime_error(),
|
||||
LanceError::Lance { .. } => self.runtime_error(),
|
||||
LanceError::Schema { .. } => self.value_error(),
|
||||
LanceError::Runtime { .. } => self.runtime_error(),
|
||||
LanceError::Http { .. } => self.runtime_error(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,11 +31,19 @@ async-trait = "0"
|
||||
bytes = "1"
|
||||
futures.workspace = true
|
||||
num-traits.workspace = true
|
||||
url = { workspace = true }
|
||||
url.workspace = true
|
||||
serde = { version = "^1" }
|
||||
serde_json = { version = "1" }
|
||||
|
||||
# For remote feature
|
||||
|
||||
reqwest = { version = "0.11.24", features = ["gzip", "json"], optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = "3.5.0"
|
||||
rand = { version = "0.8.3", features = ["small_rng"] }
|
||||
walkdir = "2"
|
||||
|
||||
[features]
|
||||
default = ["remote"]
|
||||
remote = ["dep:reqwest"]
|
||||
|
||||
@@ -194,7 +194,7 @@ impl OpenTableBuilder {
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
trait ConnectionInternal: Send + Sync + std::fmt::Debug + 'static {
|
||||
pub(crate) trait ConnectionInternal: Send + Sync + std::fmt::Debug + 'static {
|
||||
async fn table_names(&self) -> Result<Vec<String>>;
|
||||
async fn do_create_table(&self, options: CreateTableBuilder<true>) -> Result<TableRef>;
|
||||
async fn do_open_table(&self, options: OpenTableBuilder) -> Result<TableRef>;
|
||||
@@ -365,14 +365,46 @@ impl ConnectBuilder {
|
||||
self
|
||||
}
|
||||
|
||||
/// Builds a [`Connection`] backed by a remote LanceDb Cloud database.
///
/// # Errors
///
/// Returns [`Error::InvalidInput`] when `region` or `api_key` was not set on
/// the builder, and propagates any error from `RemoteDatabase::try_new`.
#[cfg(feature = "remote")]
fn execute_remote(self) -> Result<Connection> {
    let region = self.region.ok_or_else(|| Error::InvalidInput {
        message: "A region is required when connecting to LanceDb Cloud".to_string(),
    })?;
    let api_key = self.api_key.ok_or_else(|| Error::InvalidInput {
        message: "An api_key is required when connecting to LanceDb Cloud".to_string(),
    })?;
    let internal = Arc::new(crate::remote::db::RemoteDatabase::try_new(
        &self.uri,
        &api_key,
        &region,
        self.host_override,
    )?);
    Ok(Connection {
        internal,
        uri: self.uri,
    })
}
|
||||
|
||||
#[cfg(not(feature = "remote"))]
|
||||
fn execute_remote(self) -> Result<Connection> {
|
||||
Err(Error::Runtime {
|
||||
message: "cannot connect to LanceDb Cloud unless the 'remote' feature is enabled"
|
||||
.to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Establishes a connection to the database
|
||||
pub async fn execute(self) -> Result<Connection> {
|
||||
if self.uri.starts_with("db") {
|
||||
self.execute_remote()
|
||||
} else {
|
||||
let internal = Arc::new(Database::connect_with_options(&self).await?);
|
||||
Ok(Connection {
|
||||
internal,
|
||||
uri: self.uri,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Connect to a LanceDB database.
|
||||
|
||||
@@ -22,6 +22,8 @@ use snafu::Snafu;
|
||||
pub enum Error {
|
||||
#[snafu(display("LanceDBError: Invalid table name: {name}"))]
|
||||
InvalidTableName { name: String },
|
||||
#[snafu(display("LanceDBError: Invalid input, {message}"))]
|
||||
InvalidInput { message: String },
|
||||
#[snafu(display("LanceDBError: Table '{name}' was not found"))]
|
||||
TableNotFound { name: String },
|
||||
#[snafu(display("LanceDBError: Table '{name}' already exists"))]
|
||||
@@ -31,6 +33,8 @@ pub enum Error {
|
||||
path: String,
|
||||
source: std::io::Error,
|
||||
},
|
||||
#[snafu(display("LanceDBError: Http error: {message}"))]
|
||||
Http { message: String },
|
||||
#[snafu(display("LanceDBError: {message}"))]
|
||||
Store { message: String },
|
||||
#[snafu(display("LanceDBError: {message}"))]
|
||||
@@ -82,3 +86,21 @@ impl<T> From<PoisonError<T>> for Error {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Maps a `reqwest` failure onto [`Error::Http`], keeping only the rendered
/// error text.
#[cfg(feature = "remote")]
impl From<reqwest::Error> for Error {
    fn from(e: reqwest::Error) -> Self {
        let message = e.to_string();
        Self::Http { message }
    }
}
|
||||
|
||||
/// Maps a URL parsing failure onto [`Error::Http`], keeping only the rendered
/// error text.
#[cfg(feature = "remote")]
impl From<url::ParseError> for Error {
    fn from(e: url::ParseError) -> Self {
        let message = e.to_string();
        Self::Http { message }
    }
}
|
||||
|
||||
@@ -188,6 +188,8 @@ pub mod index;
|
||||
pub mod io;
|
||||
pub mod ipc;
|
||||
pub mod query;
|
||||
#[cfg(feature = "remote")]
|
||||
pub(crate) mod remote;
|
||||
pub mod table;
|
||||
pub mod utils;
|
||||
|
||||
|
||||
21
rust/lancedb/src/remote.rs
Normal file
21
rust/lancedb/src/remote.rs
Normal file
@@ -0,0 +1,21 @@
|
||||
// Copyright 2024 Lance Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! This module contains a remote client for a LanceDB server. This is used
|
||||
//! to communicate with LanceDB cloud. It can also serve as an example for
|
||||
//! building client/server applications with LanceDB or as a client for some
|
||||
//! other custom LanceDB service.
|
||||
|
||||
pub mod client;
|
||||
pub mod db;
|
||||
119
rust/lancedb/src/remote/client.rs
Normal file
119
rust/lancedb/src/remote/client.rs
Normal file
@@ -0,0 +1,119 @@
|
||||
// Copyright 2024 LanceDB Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::time::Duration;
|
||||
|
||||
use reqwest::{
|
||||
header::{HeaderMap, HeaderValue},
|
||||
RequestBuilder, Response,
|
||||
};
|
||||
|
||||
use crate::error::{Error, Result};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct RestfulLanceDbClient {
|
||||
client: reqwest::Client,
|
||||
host: String,
|
||||
}
|
||||
|
||||
impl RestfulLanceDbClient {
|
||||
fn default_headers(
|
||||
api_key: &str,
|
||||
region: &str,
|
||||
db_name: &str,
|
||||
has_host_override: bool,
|
||||
) -> Result<HeaderMap> {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(
|
||||
"x-api-key",
|
||||
HeaderValue::from_str(api_key).map_err(|_| Error::Http {
|
||||
message: "non-ascii api key provided".to_string(),
|
||||
})?,
|
||||
);
|
||||
if region == "local" {
|
||||
let host = format!("{}.local.api.lancedb.com", db_name);
|
||||
headers.insert(
|
||||
"Host",
|
||||
HeaderValue::from_str(&host).map_err(|_| Error::Http {
|
||||
message: format!("non-ascii database name '{}' provided", db_name),
|
||||
})?,
|
||||
);
|
||||
}
|
||||
if has_host_override {
|
||||
headers.insert(
|
||||
"x-lancedb-database",
|
||||
HeaderValue::from_str(db_name).map_err(|_| Error::Http {
|
||||
message: format!("non-ascii database name '{}' provided", db_name),
|
||||
})?,
|
||||
);
|
||||
}
|
||||
|
||||
Ok(headers)
|
||||
}
|
||||
|
||||
pub fn try_new(
|
||||
db_url: &str,
|
||||
api_key: &str,
|
||||
region: &str,
|
||||
host_override: Option<String>,
|
||||
) -> Result<Self> {
|
||||
let parsed_url = url::Url::parse(db_url)?;
|
||||
debug_assert_eq!(parsed_url.scheme(), "db");
|
||||
if !parsed_url.has_host() {
|
||||
return Err(Error::Http {
|
||||
message: format!("Invalid database URL (missing host) '{}'", db_url),
|
||||
});
|
||||
}
|
||||
let db_name = parsed_url.host_str().unwrap();
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(Duration::from_secs(30))
|
||||
.default_headers(Self::default_headers(
|
||||
api_key,
|
||||
region,
|
||||
db_name,
|
||||
host_override.is_some(),
|
||||
)?)
|
||||
.build()?;
|
||||
let host = match host_override {
|
||||
Some(host_override) => host_override,
|
||||
None => format!("https://{}.{}.api.lancedb.com", db_name, region),
|
||||
};
|
||||
Ok(Self { client, host })
|
||||
}
|
||||
|
||||
pub fn get(&self, uri: &str) -> RequestBuilder {
|
||||
let full_uri = format!("{}{}", self.host, uri);
|
||||
self.client.get(full_uri)
|
||||
}
|
||||
|
||||
async fn rsp_to_str(response: Response) -> String {
|
||||
let status = response.status();
|
||||
response.text().await.unwrap_or_else(|_| status.to_string())
|
||||
}
|
||||
|
||||
pub async fn check_response(&self, response: Response) -> Result<Response> {
|
||||
let status_int: u16 = u16::from(response.status());
|
||||
if (400..500).contains(&status_int) {
|
||||
Err(Error::InvalidInput {
|
||||
message: Self::rsp_to_str(response).await,
|
||||
})
|
||||
} else if status_int != 200 {
|
||||
Err(Error::Runtime {
|
||||
message: Self::rsp_to_str(response).await,
|
||||
})
|
||||
} else {
|
||||
Ok(response)
|
||||
}
|
||||
}
|
||||
}
|
||||
75
rust/lancedb/src/remote/db.rs
Normal file
75
rust/lancedb/src/remote/db.rs
Normal file
@@ -0,0 +1,75 @@
|
||||
// Copyright 2024 LanceDB Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde::Deserialize;
|
||||
|
||||
use crate::connection::{ConnectionInternal, CreateTableBuilder, OpenTableBuilder};
|
||||
use crate::error::Result;
|
||||
use crate::TableRef;
|
||||
|
||||
use super::client::RestfulLanceDbClient;
|
||||
|
||||
/// JSON body of the table-listing response; deserialized in `table_names`.
#[derive(Deserialize)]
|
||||
struct ListTablesResponse {
|
||||
/// Table names carried in the response's `tables` field.
tables: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct RemoteDatabase {
|
||||
client: RestfulLanceDbClient,
|
||||
}
|
||||
|
||||
impl RemoteDatabase {
|
||||
pub fn try_new(
|
||||
uri: &str,
|
||||
api_key: &str,
|
||||
region: &str,
|
||||
host_override: Option<String>,
|
||||
) -> Result<Self> {
|
||||
let client = RestfulLanceDbClient::try_new(uri, api_key, region, host_override)?;
|
||||
Ok(Self { client })
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ConnectionInternal for RemoteDatabase {
|
||||
async fn table_names(&self) -> Result<Vec<String>> {
|
||||
let rsp = self
|
||||
.client
|
||||
.get("/v1/table/")
|
||||
.query(&[("limit", 10)])
|
||||
.query(&[("page_token", "")])
|
||||
.send()
|
||||
.await?;
|
||||
let rsp = self.client.check_response(rsp).await?;
|
||||
Ok(rsp.json::<ListTablesResponse>().await?.tables)
|
||||
}
|
||||
|
||||
async fn do_create_table(&self, _options: CreateTableBuilder<true>) -> Result<TableRef> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
async fn do_open_table(&self, _options: OpenTableBuilder) -> Result<TableRef> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
async fn drop_table(&self, _name: &str) -> Result<()> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
async fn drop_db(&self) -> Result<()> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
40
rust/lancedb/tests/lancedb_cloud.rs
Normal file
40
rust/lancedb/tests/lancedb_cloud.rs
Normal file
@@ -0,0 +1,40 @@
|
||||
// Copyright 2024 LanceDB Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore]
|
||||
async fn cloud_integration_test() {
|
||||
let project = std::env::var("LANCEDB_PROJECT")
|
||||
.expect("the LANCEDB_PROJECT env must be set to run the cloud integration test");
|
||||
let api_key = std::env::var("LANCEDB_API_KEY")
|
||||
.expect("the LANCEDB_API_KEY env must be set to run the cloud integration test");
|
||||
let region = std::env::var("LANCEDB_REGION")
|
||||
.expect("the LANCEDB_REGION env must be set to run the cloud integration test");
|
||||
let host_override = std::env::var("LANCEDB_HOST_OVERRIDE")
|
||||
.map(Some)
|
||||
.unwrap_or(None);
|
||||
if host_override.is_none() {
|
||||
println!("No LANCEDB_HOST_OVERRIDE has been set. Running integration test against LanceDb Cloud production instance");
|
||||
}
|
||||
|
||||
let mut builder = lancedb::connect(&format!("db://{}", project))
|
||||
.api_key(&api_key)
|
||||
.region(®ion);
|
||||
if let Some(host_override) = &host_override {
|
||||
builder = builder.host_override(host_override);
|
||||
}
|
||||
let db = builder.execute().await.unwrap();
|
||||
|
||||
db.table_names().await.unwrap();
|
||||
}
|
||||
Reference in New Issue
Block a user