mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-27 23:12:58 +00:00
Compare commits
13 Commits
python-v0.
...
hybrid_que
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
887ac0d79d | ||
|
|
c1af53b787 | ||
|
|
2a02d1394b | ||
|
|
085066d2a8 | ||
|
|
adf1a38f4d | ||
|
|
294c33a42e | ||
|
|
3ad4992282 | ||
|
|
51cc422799 | ||
|
|
a696dbc8f4 | ||
|
|
9ca0260d54 | ||
|
|
6486ec870b | ||
|
|
64db2393f7 | ||
|
|
bd4e8341fe |
2
.github/workflows/python.yml
vendored
2
.github/workflows/python.yml
vendored
@@ -33,7 +33,7 @@ jobs:
|
|||||||
python-version: "3.11"
|
python-version: "3.11"
|
||||||
- name: Install ruff
|
- name: Install ruff
|
||||||
run: |
|
run: |
|
||||||
pip install ruff
|
pip install ruff==0.2.2
|
||||||
- name: Format check
|
- name: Format check
|
||||||
run: ruff format --check .
|
run: ruff format --check .
|
||||||
- name: Lint
|
- name: Lint
|
||||||
|
|||||||
37
.github/workflows/remote-integration.yml
vendored
Normal file
37
.github/workflows/remote-integration.yml
vendored
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
name: LanceDb Cloud Integration Test
|
||||||
|
|
||||||
|
on:
|
||||||
|
workflow_run:
|
||||||
|
workflows: [Rust]
|
||||||
|
types:
|
||||||
|
- completed
|
||||||
|
|
||||||
|
env:
|
||||||
|
LANCEDB_PROJECT: ${{ secrets.LANCEDB_PROJECT }}
|
||||||
|
LANCEDB_API_KEY: ${{ secrets.LANCEDB_API_KEY }}
|
||||||
|
LANCEDB_REGION: ${{ secrets.LANCEDB_REGION }}
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
test:
|
||||||
|
timeout-minutes: 30
|
||||||
|
runs-on: ubuntu-22.04
|
||||||
|
defaults:
|
||||||
|
run:
|
||||||
|
shell: bash
|
||||||
|
working-directory: rust
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
lfs: true
|
||||||
|
- uses: Swatinem/rust-cache@v2
|
||||||
|
with:
|
||||||
|
workspaces: rust
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
sudo apt update
|
||||||
|
sudo apt install -y protobuf-compiler libssl-dev
|
||||||
|
- name: Build
|
||||||
|
run: cargo build --all-features
|
||||||
|
- name: Run Integration test
|
||||||
|
run: cargo test --tests -- --ignored
|
||||||
2
.github/workflows/run_tests/action.yml
vendored
2
.github/workflows/run_tests/action.yml
vendored
@@ -11,7 +11,7 @@ runs:
|
|||||||
- name: Install lancedb
|
- name: Install lancedb
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
pip3 install $(ls target/wheels/lancedb-*.whl)[tests,dev,embeddings]
|
pip3 install $(ls target/wheels/lancedb-*.whl)[tests,dev]
|
||||||
- name: pytest
|
- name: pytest
|
||||||
shell: bash
|
shell: bash
|
||||||
run: pytest -m "not slow" -x -v --durations=30 python/python/tests
|
run: pytest -m "not slow" -x -v --durations=30 python/python/tests
|
||||||
|
|||||||
1
.github/workflows/rust.yml
vendored
1
.github/workflows/rust.yml
vendored
@@ -119,3 +119,4 @@ jobs:
|
|||||||
$env:VCPKG_ROOT = $env:VCPKG_INSTALLATION_ROOT
|
$env:VCPKG_ROOT = $env:VCPKG_INSTALLATION_ROOT
|
||||||
cargo build
|
cargo build
|
||||||
cargo test
|
cargo test
|
||||||
|
|
||||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -39,4 +39,6 @@ dist
|
|||||||
## Rust
|
## Rust
|
||||||
target
|
target
|
||||||
|
|
||||||
|
**/sccache.log
|
||||||
|
|
||||||
Cargo.lock
|
Cargo.lock
|
||||||
|
|||||||
@@ -5,17 +5,8 @@ repos:
|
|||||||
- id: check-yaml
|
- id: check-yaml
|
||||||
- id: end-of-file-fixer
|
- id: end-of-file-fixer
|
||||||
- id: trailing-whitespace
|
- id: trailing-whitespace
|
||||||
- repo: https://github.com/psf/black
|
|
||||||
rev: 22.12.0
|
|
||||||
hooks:
|
|
||||||
- id: black
|
|
||||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||||
# Ruff version.
|
# Ruff version.
|
||||||
rev: v0.0.277
|
rev: v0.2.2
|
||||||
hooks:
|
hooks:
|
||||||
- id: ruff
|
- id: ruff
|
||||||
- repo: https://github.com/pycqa/isort
|
|
||||||
rev: 5.12.0
|
|
||||||
hooks:
|
|
||||||
- id: isort
|
|
||||||
name: isort (python)
|
|
||||||
@@ -341,6 +341,7 @@ export interface Table<T = number[]> {
|
|||||||
*
|
*
|
||||||
* @param column The column to index
|
* @param column The column to index
|
||||||
* @param replace If false, fail if an index already exists on the column
|
* @param replace If false, fail if an index already exists on the column
|
||||||
|
* it is always set to true for remote connections
|
||||||
*
|
*
|
||||||
* Scalar indices, like vector indices, can be used to speed up scans. A scalar
|
* Scalar indices, like vector indices, can be used to speed up scans. A scalar
|
||||||
* index can speed up scans that contain filter expressions on the indexed column.
|
* index can speed up scans that contain filter expressions on the indexed column.
|
||||||
@@ -384,7 +385,7 @@ export interface Table<T = number[]> {
|
|||||||
* await table.createScalarIndex('my_col')
|
* await table.createScalarIndex('my_col')
|
||||||
* ```
|
* ```
|
||||||
*/
|
*/
|
||||||
createScalarIndex: (column: string, replace: boolean) => Promise<void>
|
createScalarIndex: (column: string, replace?: boolean) => Promise<void>
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns the number of rows in this table.
|
* Returns the number of rows in this table.
|
||||||
@@ -914,7 +915,10 @@ export class LocalTable<T = number[]> implements Table<T> {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
async createScalarIndex (column: string, replace: boolean): Promise<void> {
|
async createScalarIndex (column: string, replace?: boolean): Promise<void> {
|
||||||
|
if (replace === undefined) {
|
||||||
|
replace = true
|
||||||
|
}
|
||||||
return tableCreateScalarIndex.call(this._tbl, column, replace)
|
return tableCreateScalarIndex.call(this._tbl, column, replace)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -397,7 +397,7 @@ export class RemoteTable<T = number[]> implements Table<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const column = indexParams.column ?? 'vector'
|
const column = indexParams.column ?? 'vector'
|
||||||
const indexType = 'vector' // only vector index is supported for remote connections
|
const indexType = 'vector'
|
||||||
const metricType = indexParams.metric_type ?? 'L2'
|
const metricType = indexParams.metric_type ?? 'L2'
|
||||||
const indexCacheSize = indexParams.index_cache_size ?? null
|
const indexCacheSize = indexParams.index_cache_size ?? null
|
||||||
|
|
||||||
@@ -420,8 +420,25 @@ export class RemoteTable<T = number[]> implements Table<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async createScalarIndex (column: string, replace: boolean): Promise<void> {
|
async createScalarIndex (column: string): Promise<void> {
|
||||||
throw new Error('Not implemented')
|
const indexType = 'scalar'
|
||||||
|
|
||||||
|
const data = {
|
||||||
|
column,
|
||||||
|
index_type: indexType,
|
||||||
|
replace: true
|
||||||
|
}
|
||||||
|
const res = await this._client.post(
|
||||||
|
`/v1/table/${this._name}/create_scalar_index/`,
|
||||||
|
data
|
||||||
|
)
|
||||||
|
if (res.status !== 200) {
|
||||||
|
throw new Error(
|
||||||
|
`Server Error, status: ${res.status}, ` +
|
||||||
|
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
|
||||||
|
`message: ${res.statusText}: ${res.data}`
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async countRows (): Promise<number> {
|
async countRows (): Promise<number> {
|
||||||
|
|||||||
34
nodejs/__test__/connection.test.ts
Normal file
34
nodejs/__test__/connection.test.ts
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
// Copyright 2024 Lance Developers.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
import * as os from "os";
|
||||||
|
import * as path from "path";
|
||||||
|
import * as fs from "fs";
|
||||||
|
|
||||||
|
import { connect } from "../dist/index.js";
|
||||||
|
|
||||||
|
describe("when working with a connection", () => {
|
||||||
|
|
||||||
|
const tmpDir = fs.mkdtempSync(path.join(os.tmpdir(), "test-connection"));
|
||||||
|
|
||||||
|
it("should fail if creating table twice, unless overwrite is true", async() => {
|
||||||
|
const db = await connect(tmpDir);
|
||||||
|
let tbl = await db.createTable("test", [{ id: 1 }, { id: 2 }]);
|
||||||
|
await expect(tbl.countRows()).resolves.toBe(2);
|
||||||
|
await expect(db.createTable("test", [{ id: 1 }, { id: 2 }])).rejects.toThrow();
|
||||||
|
tbl = await db.createTable("test", [{ id: 3 }], { mode: "overwrite" });
|
||||||
|
await expect(tbl.countRows()).resolves.toBe(1);
|
||||||
|
})
|
||||||
|
|
||||||
|
});
|
||||||
@@ -201,17 +201,17 @@ describe("Read consistency interval", () => {
|
|||||||
await table.add([{ id: 2 }]);
|
await table.add([{ id: 2 }]);
|
||||||
|
|
||||||
if (interval === undefined) {
|
if (interval === undefined) {
|
||||||
expect(await table2.countRows()).toEqual(1n);
|
expect(await table2.countRows()).toEqual(1);
|
||||||
// TODO: once we implement time travel we can uncomment this part of the test.
|
// TODO: once we implement time travel we can uncomment this part of the test.
|
||||||
// await table2.checkout_latest();
|
// await table2.checkout_latest();
|
||||||
// expect(await table2.countRows()).toEqual(2);
|
// expect(await table2.countRows()).toEqual(2);
|
||||||
} else if (interval === 0) {
|
} else if (interval === 0) {
|
||||||
expect(await table2.countRows()).toEqual(2n);
|
expect(await table2.countRows()).toEqual(2);
|
||||||
} else {
|
} else {
|
||||||
// interval == 0.1
|
// interval == 0.1
|
||||||
expect(await table2.countRows()).toEqual(1n);
|
expect(await table2.countRows()).toEqual(1);
|
||||||
await new Promise(r => setTimeout(r, 100));
|
await new Promise(r => setTimeout(r, 100));
|
||||||
expect(await table2.countRows()).toEqual(2n);
|
expect(await table2.countRows()).toEqual(2);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
15
nodejs/__test__/tsconfig.json
Normal file
15
nodejs/__test__/tsconfig.json
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
{
|
||||||
|
"extends": "../tsconfig.json",
|
||||||
|
"compilerOptions": {
|
||||||
|
"outDir": "./dist/spec",
|
||||||
|
"module": "commonjs",
|
||||||
|
"target": "es2022",
|
||||||
|
"types": [
|
||||||
|
"jest",
|
||||||
|
"node"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"include": [
|
||||||
|
"**/*",
|
||||||
|
]
|
||||||
|
}
|
||||||
@@ -17,6 +17,24 @@ import { Connection as _NativeConnection } from "./native";
|
|||||||
import { Table } from "./table";
|
import { Table } from "./table";
|
||||||
import { Table as ArrowTable } from "apache-arrow";
|
import { Table as ArrowTable } from "apache-arrow";
|
||||||
|
|
||||||
|
export interface CreateTableOptions {
|
||||||
|
/**
|
||||||
|
* The mode to use when creating the table.
|
||||||
|
*
|
||||||
|
* If this is set to "create" and the table already exists then either
|
||||||
|
* an error will be thrown or, if existOk is true, then nothing will
|
||||||
|
* happen. Any provided data will be ignored.
|
||||||
|
*
|
||||||
|
* If this is set to "overwrite" then any existing table will be replaced.
|
||||||
|
*/
|
||||||
|
mode: "create" | "overwrite";
|
||||||
|
/**
|
||||||
|
* If this is true and the table already exists and the mode is "create"
|
||||||
|
* then no error will be raised.
|
||||||
|
*/
|
||||||
|
existOk: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A LanceDB Connection that allows you to open tables and create new ones.
|
* A LanceDB Connection that allows you to open tables and create new ones.
|
||||||
*
|
*
|
||||||
@@ -53,10 +71,18 @@ export class Connection {
|
|||||||
*/
|
*/
|
||||||
async createTable(
|
async createTable(
|
||||||
name: string,
|
name: string,
|
||||||
data: Record<string, unknown>[] | ArrowTable
|
data: Record<string, unknown>[] | ArrowTable,
|
||||||
|
options?: Partial<CreateTableOptions>
|
||||||
): Promise<Table> {
|
): Promise<Table> {
|
||||||
|
let mode: string = options?.mode ?? "create";
|
||||||
|
const existOk = options?.existOk ?? false;
|
||||||
|
|
||||||
|
if (mode === "create" && existOk) {
|
||||||
|
mode = "exist_ok";
|
||||||
|
}
|
||||||
|
|
||||||
const buf = toBuffer(data);
|
const buf = toBuffer(data);
|
||||||
const innerTable = await this.inner.createTable(name, buf);
|
const innerTable = await this.inner.createTable(name, buf, mode);
|
||||||
return new Table(innerTable);
|
return new Table(innerTable);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
4
nodejs/lancedb/native.d.ts
vendored
4
nodejs/lancedb/native.d.ts
vendored
@@ -85,7 +85,7 @@ export class Connection {
|
|||||||
* - buf: The buffer containing the IPC file.
|
* - buf: The buffer containing the IPC file.
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
createTable(name: string, buf: Buffer): Promise<Table>
|
createTable(name: string, buf: Buffer, mode: string): Promise<Table>
|
||||||
openTable(name: string): Promise<Table>
|
openTable(name: string): Promise<Table>
|
||||||
/** Drop table with the name. Or raise an error if the table does not exist. */
|
/** Drop table with the name. Or raise an error if the table does not exist. */
|
||||||
dropTable(name: string): Promise<void>
|
dropTable(name: string): Promise<void>
|
||||||
@@ -117,7 +117,7 @@ export class Table {
|
|||||||
/** Return Schema as empty Arrow IPC file. */
|
/** Return Schema as empty Arrow IPC file. */
|
||||||
schema(): Promise<Buffer>
|
schema(): Promise<Buffer>
|
||||||
add(buf: Buffer): Promise<void>
|
add(buf: Buffer): Promise<void>
|
||||||
countRows(filter?: string | undefined | null): Promise<bigint>
|
countRows(filter?: string | undefined | null): Promise<number>
|
||||||
delete(predicate: string): Promise<void>
|
delete(predicate: string): Promise<void>
|
||||||
createIndex(): IndexBuilder
|
createIndex(): IndexBuilder
|
||||||
query(): Query
|
query(): Query
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ export class Table {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/** Count the total number of rows in the dataset. */
|
/** Count the total number of rows in the dataset. */
|
||||||
async countRows(filter?: string): Promise<bigint> {
|
async countRows(filter?: string): Promise<number> {
|
||||||
return await this.inner.countRows(filter);
|
return await this.inner.countRows(filter);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -51,8 +51,7 @@
|
|||||||
"docs": "typedoc --plugin typedoc-plugin-markdown lancedb/index.ts",
|
"docs": "typedoc --plugin typedoc-plugin-markdown lancedb/index.ts",
|
||||||
"lint": "eslint lancedb --ext .js,.ts",
|
"lint": "eslint lancedb --ext .js,.ts",
|
||||||
"prepublishOnly": "napi prepublish -t npm",
|
"prepublishOnly": "napi prepublish -t npm",
|
||||||
"//": "maxWorkers=1 is workaround for bigint issue in jest: https://github.com/jestjs/jest/issues/11617#issuecomment-1068732414",
|
"test": "npm run build && jest --verbose",
|
||||||
"test": "npm run build && jest --maxWorkers=1",
|
|
||||||
"universal": "napi universal",
|
"universal": "napi universal",
|
||||||
"version": "napi version"
|
"version": "napi version"
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ use napi_derive::*;
|
|||||||
|
|
||||||
use crate::table::Table;
|
use crate::table::Table;
|
||||||
use crate::ConnectionOptions;
|
use crate::ConnectionOptions;
|
||||||
use lancedb::connection::{ConnectBuilder, Connection as LanceDBConnection};
|
use lancedb::connection::{ConnectBuilder, Connection as LanceDBConnection, CreateTableMode};
|
||||||
use lancedb::ipc::ipc_file_to_batches;
|
use lancedb::ipc::ipc_file_to_batches;
|
||||||
|
|
||||||
#[napi]
|
#[napi]
|
||||||
@@ -25,6 +25,17 @@ pub struct Connection {
|
|||||||
conn: LanceDBConnection,
|
conn: LanceDBConnection,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Connection {
|
||||||
|
fn parse_create_mode_str(mode: &str) -> napi::Result<CreateTableMode> {
|
||||||
|
match mode {
|
||||||
|
"create" => Ok(CreateTableMode::Create),
|
||||||
|
"overwrite" => Ok(CreateTableMode::Overwrite),
|
||||||
|
"exist_ok" => Ok(CreateTableMode::exist_ok(|builder| builder)),
|
||||||
|
_ => Err(napi::Error::from_reason(format!("Invalid mode {}", mode))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[napi]
|
#[napi]
|
||||||
impl Connection {
|
impl Connection {
|
||||||
/// Create a new Connection instance from the given URI.
|
/// Create a new Connection instance from the given URI.
|
||||||
@@ -65,12 +76,19 @@ impl Connection {
|
|||||||
/// - buf: The buffer containing the IPC file.
|
/// - buf: The buffer containing the IPC file.
|
||||||
///
|
///
|
||||||
#[napi]
|
#[napi]
|
||||||
pub async fn create_table(&self, name: String, buf: Buffer) -> napi::Result<Table> {
|
pub async fn create_table(
|
||||||
|
&self,
|
||||||
|
name: String,
|
||||||
|
buf: Buffer,
|
||||||
|
mode: String,
|
||||||
|
) -> napi::Result<Table> {
|
||||||
let batches = ipc_file_to_batches(buf.to_vec())
|
let batches = ipc_file_to_batches(buf.to_vec())
|
||||||
.map_err(|e| napi::Error::from_reason(format!("Failed to read IPC file: {}", e)))?;
|
.map_err(|e| napi::Error::from_reason(format!("Failed to read IPC file: {}", e)))?;
|
||||||
|
let mode = Self::parse_create_mode_str(&mode)?;
|
||||||
let tbl = self
|
let tbl = self
|
||||||
.conn
|
.conn
|
||||||
.create_table(&name, Box::new(batches))
|
.create_table(&name, Box::new(batches))
|
||||||
|
.mode(mode)
|
||||||
.execute()
|
.execute()
|
||||||
.await
|
.await
|
||||||
.map_err(|e| napi::Error::from_reason(format!("{}", e)))?;
|
.map_err(|e| napi::Error::from_reason(format!("{}", e)))?;
|
||||||
|
|||||||
@@ -68,13 +68,17 @@ impl Table {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[napi]
|
#[napi]
|
||||||
pub async fn count_rows(&self, filter: Option<String>) -> napi::Result<usize> {
|
pub async fn count_rows(&self, filter: Option<String>) -> napi::Result<i64> {
|
||||||
self.table.count_rows(filter).await.map_err(|e| {
|
self.table
|
||||||
napi::Error::from_reason(format!(
|
.count_rows(filter)
|
||||||
"Failed to count rows in table {}: {}",
|
.await
|
||||||
self.table, e
|
.map(|val| val as i64)
|
||||||
))
|
.map_err(|e| {
|
||||||
})
|
napi::Error::from_reason(format!(
|
||||||
|
"Failed to count rows in table {}: {}",
|
||||||
|
self.table, e
|
||||||
|
))
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
#[napi]
|
#[napi]
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
[bumpversion]
|
[bumpversion]
|
||||||
current_version = 0.6.0
|
current_version = 0.6.1
|
||||||
commit = True
|
commit = True
|
||||||
message = [python] Bump version: {current_version} → {new_version}
|
message = [python] Bump version: {current_version} → {new_version}
|
||||||
tag = True
|
tag = True
|
||||||
|
|||||||
24
python/ASYNC_MIGRATION.md
Normal file
24
python/ASYNC_MIGRATION.md
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
# Migration from Sync to Async API
|
||||||
|
|
||||||
|
A new asynchronous API has been added to LanceDb. This API is built
|
||||||
|
on top of the rust lancedb crate (instead of being built on top of
|
||||||
|
pylance). This will help keep the various language bindings in sync.
|
||||||
|
There are some slight changes between the synchronous and the asynchronous
|
||||||
|
APIs. This document will help you migrate. These changes relate mostly
|
||||||
|
to the Connection and Table classes.
|
||||||
|
|
||||||
|
## Almost all functions are async
|
||||||
|
|
||||||
|
The most important change is that almost all functions are now async.
|
||||||
|
This means the functions now return `asyncio` coroutines. You will
|
||||||
|
need to use `await` to call these functions.
|
||||||
|
|
||||||
|
## Connection
|
||||||
|
|
||||||
|
No changes yet.
|
||||||
|
|
||||||
|
## Table
|
||||||
|
|
||||||
|
* Previously `Table.schema` was a property. Now it is an async method.
|
||||||
|
* The method `Table.__len__` was removed and `len(table)` will no longer
|
||||||
|
work. Use `Table.count_rows` instead.
|
||||||
@@ -14,6 +14,7 @@ name = "_lancedb"
|
|||||||
crate-type = ["cdylib"]
|
crate-type = ["cdylib"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
|
arrow = { version = "50.0.0", features = ["pyarrow"] }
|
||||||
lancedb = { path = "../rust/lancedb" }
|
lancedb = { path = "../rust/lancedb" }
|
||||||
env_logger = "0.10"
|
env_logger = "0.10"
|
||||||
pyo3 = { version = "0.20", features = ["extension-module", "abi3-py38"] }
|
pyo3 = { version = "0.20", features = ["extension-module", "abi3-py38"] }
|
||||||
@@ -23,4 +24,7 @@ pyo3-asyncio = { version = "0.20", features = ["attributes", "tokio-runtime"] }
|
|||||||
lzma-sys = { version = "*", features = ["static"] }
|
lzma-sys = { version = "*", features = ["static"] }
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
pyo3-build-config = { version = "0.20.3", features = ["extension-module", "abi3-py38"] }
|
pyo3-build-config = { version = "0.20.3", features = [
|
||||||
|
"extension-module",
|
||||||
|
"abi3-py38",
|
||||||
|
] }
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "lancedb"
|
name = "lancedb"
|
||||||
version = "0.6.0"
|
version = "0.6.1"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"deprecation",
|
"deprecation",
|
||||||
"pylance==0.10.1",
|
"pylance==0.10.1",
|
||||||
|
|||||||
@@ -1,7 +1,19 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
import pyarrow as pa
|
||||||
|
|
||||||
class Connection(object):
|
class Connection(object):
|
||||||
async def table_names(self) -> list[str]: ...
|
async def table_names(self) -> list[str]: ...
|
||||||
|
async def create_table(
|
||||||
|
self, name: str, mode: str, data: pa.RecordBatchReader
|
||||||
|
) -> Table: ...
|
||||||
|
async def create_empty_table(
|
||||||
|
self, name: str, mode: str, schema: pa.Schema
|
||||||
|
) -> Table: ...
|
||||||
|
|
||||||
|
class Table(object):
|
||||||
|
def name(self) -> str: ...
|
||||||
|
async def schema(self) -> pa.Schema: ...
|
||||||
|
|
||||||
async def connect(
|
async def connect(
|
||||||
uri: str,
|
uri: str,
|
||||||
|
|||||||
@@ -11,7 +11,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Iterable, List, Union
|
from typing import Iterable, List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
@@ -38,3 +38,99 @@ class Credential(str):
|
|||||||
|
|
||||||
def sanitize_uri(uri: URI) -> str:
|
def sanitize_uri(uri: URI) -> str:
|
||||||
return str(uri)
|
return str(uri)
|
||||||
|
|
||||||
|
|
||||||
|
def _casting_recordbatch_iter(
|
||||||
|
input_iter: Iterable[pa.RecordBatch], schema: pa.Schema
|
||||||
|
) -> Iterable[pa.RecordBatch]:
|
||||||
|
"""
|
||||||
|
Wrapper around an iterator of record batches. If the batches don't match the
|
||||||
|
schema, try to cast them to the schema. If that fails, raise an error.
|
||||||
|
|
||||||
|
This is helpful for users who might have written the iterator with default
|
||||||
|
data types in PyArrow, but specified more specific types in the schema. For
|
||||||
|
example, PyArrow defaults to float64 for floating point types, but Lance
|
||||||
|
uses float32 for vectors.
|
||||||
|
"""
|
||||||
|
for batch in input_iter:
|
||||||
|
if not isinstance(batch, pa.RecordBatch):
|
||||||
|
raise TypeError(f"Expected RecordBatch, got {type(batch)}")
|
||||||
|
if batch.schema != schema:
|
||||||
|
try:
|
||||||
|
# RecordBatch doesn't have a cast method, but table does.
|
||||||
|
batch = pa.Table.from_batches([batch]).cast(schema).to_batches()[0]
|
||||||
|
except pa.lib.ArrowInvalid:
|
||||||
|
raise ValueError(
|
||||||
|
f"Input RecordBatch iterator yielded a batch with schema that "
|
||||||
|
f"does not match the expected schema.\nExpected:\n{schema}\n"
|
||||||
|
f"Got:\n{batch.schema}"
|
||||||
|
)
|
||||||
|
yield batch
|
||||||
|
|
||||||
|
|
||||||
|
def data_to_reader(
|
||||||
|
data: DATA, schema: Optional[pa.Schema] = None
|
||||||
|
) -> pa.RecordBatchReader:
|
||||||
|
"""Convert various types of input into a RecordBatchReader"""
|
||||||
|
if pd is not None and isinstance(data, pd.DataFrame):
|
||||||
|
return pa.Table.from_pandas(data, schema=schema).to_reader()
|
||||||
|
elif isinstance(data, pa.Table):
|
||||||
|
return data.to_reader()
|
||||||
|
elif isinstance(data, pa.RecordBatch):
|
||||||
|
return pa.Table.from_batches([data]).to_reader()
|
||||||
|
# elif isinstance(data, LanceDataset):
|
||||||
|
# return data_obj.scanner().to_reader()
|
||||||
|
elif isinstance(data, pa.dataset.Dataset):
|
||||||
|
return pa.dataset.Scanner.from_dataset(data).to_reader()
|
||||||
|
elif isinstance(data, pa.dataset.Scanner):
|
||||||
|
return data.to_reader()
|
||||||
|
elif isinstance(data, pa.RecordBatchReader):
|
||||||
|
return data
|
||||||
|
elif (
|
||||||
|
type(data).__module__.startswith("polars")
|
||||||
|
and data.__class__.__name__ == "DataFrame"
|
||||||
|
):
|
||||||
|
return data.to_arrow().to_reader()
|
||||||
|
# for other iterables, assume they are of type Iterable[RecordBatch]
|
||||||
|
elif isinstance(data, Iterable):
|
||||||
|
if schema is not None:
|
||||||
|
data = _casting_recordbatch_iter(data, schema)
|
||||||
|
return pa.RecordBatchReader.from_batches(schema, data)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Must provide schema to write dataset from RecordBatch iterable"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
f"Unknown data type {type(data)}. "
|
||||||
|
"Please check "
|
||||||
|
"https://lancedb.github.io/lance/read_and_write.html "
|
||||||
|
"to see supported types."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_schema(schema: pa.Schema):
|
||||||
|
"""
|
||||||
|
Make sure the metadata is valid utf8
|
||||||
|
"""
|
||||||
|
if schema.metadata is not None:
|
||||||
|
_validate_metadata(schema.metadata)
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_metadata(metadata: dict):
|
||||||
|
"""
|
||||||
|
Make sure the metadata values are valid utf8 (can be nested)
|
||||||
|
|
||||||
|
Raises ValueError if not valid utf8
|
||||||
|
"""
|
||||||
|
for k, v in metadata.items():
|
||||||
|
if isinstance(v, bytes):
|
||||||
|
try:
|
||||||
|
v.decode("utf8")
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
raise ValueError(
|
||||||
|
f"Metadata key {k} is not valid utf8. "
|
||||||
|
"Consider base64 encode for generic binary metadata."
|
||||||
|
)
|
||||||
|
elif isinstance(v, dict):
|
||||||
|
_validate_metadata(v)
|
||||||
|
|||||||
@@ -13,6 +13,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import inspect
|
||||||
import os
|
import os
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -22,7 +23,12 @@ import pyarrow as pa
|
|||||||
from overrides import EnforceOverrides, override
|
from overrides import EnforceOverrides, override
|
||||||
from pyarrow import fs
|
from pyarrow import fs
|
||||||
|
|
||||||
from .table import LanceTable, Table
|
from lancedb.common import data_to_reader, validate_schema
|
||||||
|
from lancedb.embeddings.registry import EmbeddingFunctionRegistry
|
||||||
|
from lancedb.utils.events import register_event
|
||||||
|
|
||||||
|
from .pydantic import LanceModel
|
||||||
|
from .table import AsyncLanceTable, LanceTable, Table, _sanitize_data
|
||||||
from .util import fs_from_uri, get_uri_location, get_uri_scheme, join_uri
|
from .util import fs_from_uri, get_uri_location, get_uri_scheme, join_uri
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -31,7 +37,6 @@ if TYPE_CHECKING:
|
|||||||
from ._lancedb import Connection as LanceDbConnection
|
from ._lancedb import Connection as LanceDbConnection
|
||||||
from .common import DATA, URI
|
from .common import DATA, URI
|
||||||
from .embeddings import EmbeddingFunctionConfig
|
from .embeddings import EmbeddingFunctionConfig
|
||||||
from .pydantic import LanceModel
|
|
||||||
|
|
||||||
|
|
||||||
class DBConnection(EnforceOverrides):
|
class DBConnection(EnforceOverrides):
|
||||||
@@ -644,6 +649,7 @@ class AsyncLanceDBConnection(AsyncConnection):
|
|||||||
page_token=None,
|
page_token=None,
|
||||||
limit=None,
|
limit=None,
|
||||||
) -> Iterable[str]:
|
) -> Iterable[str]:
|
||||||
|
# TODO: hook in page_token and limit
|
||||||
return await self._inner.table_names()
|
return await self._inner.table_names()
|
||||||
|
|
||||||
@override
|
@override
|
||||||
@@ -657,8 +663,66 @@ class AsyncLanceDBConnection(AsyncConnection):
|
|||||||
on_bad_vectors: str = "error",
|
on_bad_vectors: str = "error",
|
||||||
fill_value: float = 0.0,
|
fill_value: float = 0.0,
|
||||||
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
|
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
|
||||||
) -> LanceTable:
|
) -> Table:
|
||||||
raise NotImplementedError
|
if mode.lower() not in ["create", "overwrite"]:
|
||||||
|
raise ValueError("mode must be either 'create' or 'overwrite'")
|
||||||
|
|
||||||
|
if inspect.isclass(schema) and issubclass(schema, LanceModel):
|
||||||
|
# convert LanceModel to pyarrow schema
|
||||||
|
# note that it's possible this contains
|
||||||
|
# embedding function metadata already
|
||||||
|
schema = schema.to_arrow_schema()
|
||||||
|
|
||||||
|
metadata = None
|
||||||
|
if embedding_functions is not None:
|
||||||
|
# If we passed in embedding functions explicitly
|
||||||
|
# then we'll override any schema metadata that
|
||||||
|
# may was implicitly specified by the LanceModel schema
|
||||||
|
registry = EmbeddingFunctionRegistry.get_instance()
|
||||||
|
metadata = registry.get_table_metadata(embedding_functions)
|
||||||
|
|
||||||
|
if data is not None:
|
||||||
|
data = _sanitize_data(
|
||||||
|
data,
|
||||||
|
schema,
|
||||||
|
metadata=metadata,
|
||||||
|
on_bad_vectors=on_bad_vectors,
|
||||||
|
fill_value=fill_value,
|
||||||
|
)
|
||||||
|
|
||||||
|
if schema is None:
|
||||||
|
if data is None:
|
||||||
|
raise ValueError("Either data or schema must be provided")
|
||||||
|
elif hasattr(data, "schema"):
|
||||||
|
schema = data.schema
|
||||||
|
elif isinstance(data, Iterable):
|
||||||
|
if metadata:
|
||||||
|
raise TypeError(
|
||||||
|
(
|
||||||
|
"Persistent embedding functions not yet "
|
||||||
|
"supported for generator data input"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if metadata:
|
||||||
|
schema = schema.with_metadata(metadata)
|
||||||
|
validate_schema(schema)
|
||||||
|
|
||||||
|
if mode == "create" and exist_ok:
|
||||||
|
mode = "exist_ok"
|
||||||
|
|
||||||
|
if data is None:
|
||||||
|
new_table = await self._inner.create_empty_table(name, mode, schema)
|
||||||
|
else:
|
||||||
|
data = data_to_reader(data, schema)
|
||||||
|
new_table = await self._inner.create_table(
|
||||||
|
name,
|
||||||
|
mode,
|
||||||
|
data,
|
||||||
|
)
|
||||||
|
|
||||||
|
register_event("create_table")
|
||||||
|
return AsyncLanceTable(new_table)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def open_table(self, name: str) -> LanceTable:
|
async def open_table(self, name: str) -> LanceTable:
|
||||||
|
|||||||
@@ -103,9 +103,9 @@ class InstructorEmbeddingFunction(TextEmbeddingFunction):
|
|||||||
# convert_to_numpy: bool = True # Hardcoding this as numpy can be ingested directly
|
# convert_to_numpy: bool = True # Hardcoding this as numpy can be ingested directly
|
||||||
|
|
||||||
source_instruction: str = "represent the document for retrieval"
|
source_instruction: str = "represent the document for retrieval"
|
||||||
query_instruction: (
|
query_instruction: str = (
|
||||||
str
|
"represent the document for retrieving the most similar documents"
|
||||||
) = "represent the document for retrieving the most similar documents"
|
)
|
||||||
|
|
||||||
@weak_lru(maxsize=1)
|
@weak_lru(maxsize=1)
|
||||||
def ndims(self):
|
def ndims(self):
|
||||||
|
|||||||
@@ -12,6 +12,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
"""Full text search index using tantivy-py"""
|
"""Full text search index using tantivy-py"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
|||||||
@@ -117,23 +117,36 @@ class LanceQueryBuilder(ABC):
|
|||||||
query: Optional[Union[np.ndarray, str, "PIL.Image.Image", Tuple]],
|
query: Optional[Union[np.ndarray, str, "PIL.Image.Image", Tuple]],
|
||||||
query_type: str,
|
query_type: str,
|
||||||
vector_column_name: str,
|
vector_column_name: str,
|
||||||
|
vector: Optional[VEC] = None,
|
||||||
|
text: Optional[str] = None,
|
||||||
) -> LanceQueryBuilder:
|
) -> LanceQueryBuilder:
|
||||||
if query is None:
|
if query is None and vector is None and text is None:
|
||||||
return LanceEmptyQueryBuilder(table)
|
return LanceEmptyQueryBuilder(table)
|
||||||
|
|
||||||
if query_type == "hybrid":
|
if query_type == "hybrid":
|
||||||
# hybrid fts and vector query
|
# hybrid fts and vector query
|
||||||
return LanceHybridQueryBuilder(table, query, vector_column_name)
|
return LanceHybridQueryBuilder(
|
||||||
|
table, query, vector_column_name, vector, text
|
||||||
|
)
|
||||||
|
|
||||||
# convert "auto" query_type to "vector", "fts"
|
# Resolve hybrid query with explicit vector and text params here to avoid
|
||||||
# or "hybrid" and convert the query to vector if needed
|
# adding them as params in the BaseQueryBuilder class
|
||||||
|
if vector is not None or text is not None:
|
||||||
|
if query_type not in ["hybrid", "auto"]:
|
||||||
|
raise ValueError(
|
||||||
|
"If `vector` and `text` are provided, then `query_type`\
|
||||||
|
must be 'hybrid' or 'auto'"
|
||||||
|
)
|
||||||
|
return LanceHybridQueryBuilder(
|
||||||
|
table, query, vector_column_name, vector, text
|
||||||
|
)
|
||||||
|
|
||||||
|
# convert "auto" query_type to "vector" or "fts"
|
||||||
|
# and convert the query to vector if needed
|
||||||
query, query_type = cls._resolve_query(
|
query, query_type = cls._resolve_query(
|
||||||
table, query, query_type, vector_column_name
|
table, query, query_type, vector_column_name
|
||||||
)
|
)
|
||||||
|
|
||||||
if query_type == "hybrid":
|
|
||||||
return LanceHybridQueryBuilder(table, query, vector_column_name)
|
|
||||||
|
|
||||||
if isinstance(query, str):
|
if isinstance(query, str):
|
||||||
# fts
|
# fts
|
||||||
return LanceFtsQueryBuilder(table, query)
|
return LanceFtsQueryBuilder(table, query)
|
||||||
@@ -161,8 +174,6 @@ class LanceQueryBuilder(ABC):
|
|||||||
elif query_type == "auto":
|
elif query_type == "auto":
|
||||||
if isinstance(query, (list, np.ndarray)):
|
if isinstance(query, (list, np.ndarray)):
|
||||||
return query, "vector"
|
return query, "vector"
|
||||||
if isinstance(query, tuple):
|
|
||||||
return query, "hybrid"
|
|
||||||
else:
|
else:
|
||||||
conf = table.embedding_functions.get(vector_column_name)
|
conf = table.embedding_functions.get(vector_column_name)
|
||||||
if conf is not None:
|
if conf is not None:
|
||||||
@@ -336,10 +347,8 @@ class LanceQueryBuilder(ABC):
|
|||||||
LanceQueryBuilder
|
LanceQueryBuilder
|
||||||
The LanceQueryBuilder object.
|
The LanceQueryBuilder object.
|
||||||
"""
|
"""
|
||||||
if isinstance(columns, list):
|
if isinstance(columns, list) or isinstance(columns, dict):
|
||||||
self._columns = columns
|
self._columns = columns
|
||||||
elif isinstance(columns, dict):
|
|
||||||
self._columns = list(columns.items())
|
|
||||||
else:
|
else:
|
||||||
raise ValueError("columns must be a list or a dictionary")
|
raise ValueError("columns must be a list or a dictionary")
|
||||||
return self
|
return self
|
||||||
@@ -630,12 +639,20 @@ class LanceEmptyQueryBuilder(LanceQueryBuilder):
|
|||||||
|
|
||||||
|
|
||||||
class LanceHybridQueryBuilder(LanceQueryBuilder):
|
class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||||
def __init__(self, table: "Table", query: str, vector_column: str):
|
def __init__(
|
||||||
|
self,
|
||||||
|
table: "Table",
|
||||||
|
query: str,
|
||||||
|
vector_column: str,
|
||||||
|
vector: Optional[VEC] = None,
|
||||||
|
text: Optional[str] = None,
|
||||||
|
):
|
||||||
super().__init__(table)
|
super().__init__(table)
|
||||||
self._validate_fts_index()
|
self._validate_fts_index()
|
||||||
vector_query, fts_query = self._validate_query(query)
|
vector_query, fts_query = self._validate_query(
|
||||||
|
query, vector_column, vector, text
|
||||||
|
)
|
||||||
self._fts_query = LanceFtsQueryBuilder(table, fts_query)
|
self._fts_query = LanceFtsQueryBuilder(table, fts_query)
|
||||||
vector_query = self._query_to_vector(table, vector_query, vector_column)
|
|
||||||
self._vector_query = LanceVectorQueryBuilder(table, vector_query, vector_column)
|
self._vector_query = LanceVectorQueryBuilder(table, vector_query, vector_column)
|
||||||
self._norm = "score"
|
self._norm = "score"
|
||||||
self._reranker = LinearCombinationReranker(weight=0.7, fill=1.0)
|
self._reranker = LinearCombinationReranker(weight=0.7, fill=1.0)
|
||||||
@@ -646,23 +663,31 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
|||||||
"Please create a full-text search index " "to perform hybrid search."
|
"Please create a full-text search index " "to perform hybrid search."
|
||||||
)
|
)
|
||||||
|
|
||||||
def _validate_query(self, query):
|
def _validate_query(self, query, vector_column, vector, text):
|
||||||
# Temp hack to support vectorized queries for hybrid search
|
if query is not None:
|
||||||
if isinstance(query, str):
|
if vector is not None or text is not None:
|
||||||
return query, query
|
|
||||||
elif isinstance(query, tuple):
|
|
||||||
if len(query) != 2:
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The query must be a tuple of (vector_query, fts_query)."
|
"Either pass `query` or `vector` and `text` separately, not both."
|
||||||
)
|
)
|
||||||
if not isinstance(query[0], (list, np.ndarray, pa.Array, pa.ChunkedArray)):
|
else:
|
||||||
|
if vector is None or text is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Either pass `query` or `vector` and `text` separately, not both."
|
||||||
|
)
|
||||||
|
|
||||||
|
if vector is not None and text is not None:
|
||||||
|
if not isinstance(vector, (list, np.ndarray, pa.Array, pa.ChunkedArray)):
|
||||||
raise ValueError(f"The vector query must be one of {VEC}.")
|
raise ValueError(f"The vector query must be one of {VEC}.")
|
||||||
if not isinstance(query[1], str):
|
if not isinstance(text, str):
|
||||||
raise ValueError("The fts query must be a string.")
|
raise ValueError("The fts query must be a string.")
|
||||||
return query[0], query[1]
|
return vector, text
|
||||||
|
if isinstance(query, str):
|
||||||
|
vector = self._query_to_vector(self._table, query, vector_column)
|
||||||
|
return vector, query
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The query must be either a string or a tuple of (vector, string)."
|
f"For hybrid search `query` must be a string or `vector` and `text` \
|
||||||
|
must be provided explicitly of types {VEC} and str respectively."
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_arrow(self) -> pa.Table:
|
def to_arrow(self) -> pa.Table:
|
||||||
|
|||||||
@@ -66,12 +66,36 @@ class RemoteTable(Table):
|
|||||||
"""to_pandas() is not yet supported on LanceDB cloud."""
|
"""to_pandas() is not yet supported on LanceDB cloud."""
|
||||||
return NotImplementedError("to_pandas() is not yet supported on LanceDB cloud.")
|
return NotImplementedError("to_pandas() is not yet supported on LanceDB cloud.")
|
||||||
|
|
||||||
def create_scalar_index(self, *args, **kwargs):
|
def list_indices(self):
|
||||||
"""Creates a scalar index"""
|
"""List all the indices on the table"""
|
||||||
return NotImplementedError(
|
print(self._name)
|
||||||
"create_scalar_index() is not yet supported on LanceDB cloud."
|
resp = self._conn._client.post(f"/v1/table/{self._name}/index/list/")
|
||||||
|
return resp
|
||||||
|
|
||||||
|
def create_scalar_index(
|
||||||
|
self,
|
||||||
|
column: str,
|
||||||
|
):
|
||||||
|
"""Creates a scalar index
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
column : str
|
||||||
|
The column to be indexed. Must be a boolean, integer, float,
|
||||||
|
or string column.
|
||||||
|
"""
|
||||||
|
index_type = "scalar"
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"column": column,
|
||||||
|
"index_type": index_type,
|
||||||
|
"replace": True,
|
||||||
|
}
|
||||||
|
resp = self._conn._client.post(
|
||||||
|
f"/v1/table/{self._name}/create_scalar_index/", data=data
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return resp
|
||||||
|
|
||||||
def create_index(
|
def create_index(
|
||||||
self,
|
self,
|
||||||
metric="L2",
|
metric="L2",
|
||||||
@@ -277,6 +301,7 @@ class RemoteTable(Table):
|
|||||||
f = Future()
|
f = Future()
|
||||||
f.set_result(self._conn._client.query(name, q))
|
f.set_result(self._conn._client.query(name, q))
|
||||||
return f
|
return f
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
||||||
def submit(name, q):
|
def submit(name, q):
|
||||||
|
|||||||
@@ -12,6 +12,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
"""Schema related utilities."""
|
"""Schema related utilities."""
|
||||||
|
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ import pyarrow.compute as pc
|
|||||||
import pyarrow.fs as pa_fs
|
import pyarrow.fs as pa_fs
|
||||||
from lance import LanceDataset
|
from lance import LanceDataset
|
||||||
from lance.vector import vec_to_table
|
from lance.vector import vec_to_table
|
||||||
|
from overrides import override
|
||||||
|
|
||||||
from .common import DATA, VEC, VECTOR_COLUMN_NAME
|
from .common import DATA, VEC, VECTOR_COLUMN_NAME
|
||||||
from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
|
from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
|
||||||
@@ -48,6 +49,7 @@ if TYPE_CHECKING:
|
|||||||
import PIL
|
import PIL
|
||||||
from lance.dataset import CleanupStats, ReaderLike
|
from lance.dataset import CleanupStats, ReaderLike
|
||||||
|
|
||||||
|
from ._lancedb import Table as LanceDBTable
|
||||||
from .db import LanceDBConnection
|
from .db import LanceDBConnection
|
||||||
|
|
||||||
|
|
||||||
@@ -416,6 +418,8 @@ class Table(ABC):
|
|||||||
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
|
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
|
||||||
vector_column_name: Optional[str] = None,
|
vector_column_name: Optional[str] = None,
|
||||||
query_type: str = "auto",
|
query_type: str = "auto",
|
||||||
|
vector: Optional[VEC] = None,
|
||||||
|
text: Optional[str] = None,
|
||||||
) -> LanceQueryBuilder:
|
) -> LanceQueryBuilder:
|
||||||
"""Create a search query to find the nearest neighbors
|
"""Create a search query to find the nearest neighbors
|
||||||
of the given query vector. We currently support [vector search][search]
|
of the given query vector. We currently support [vector search][search]
|
||||||
@@ -1251,6 +1255,8 @@ class LanceTable(Table):
|
|||||||
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
|
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
|
||||||
vector_column_name: Optional[str] = None,
|
vector_column_name: Optional[str] = None,
|
||||||
query_type: str = "auto",
|
query_type: str = "auto",
|
||||||
|
vector: Optional[VEC] = None,
|
||||||
|
text: Optional[str] = None,
|
||||||
) -> LanceQueryBuilder:
|
) -> LanceQueryBuilder:
|
||||||
"""Create a search query to find the nearest neighbors
|
"""Create a search query to find the nearest neighbors
|
||||||
of the given query vector. We currently support [vector search][search]
|
of the given query vector. We currently support [vector search][search]
|
||||||
@@ -1305,6 +1311,10 @@ class LanceTable(Table):
|
|||||||
or raise an error if no corresponding embedding function is found.
|
or raise an error if no corresponding embedding function is found.
|
||||||
If the `query` is a string, then the query type is "vector" if the
|
If the `query` is a string, then the query type is "vector" if the
|
||||||
table has embedding functions, else the query type is "fts"
|
table has embedding functions, else the query type is "fts"
|
||||||
|
vector: list/np.ndarray, default None
|
||||||
|
vector query for hybrid search
|
||||||
|
text: str, default None
|
||||||
|
text query for hybrid search
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
@@ -1314,11 +1324,17 @@ class LanceTable(Table):
|
|||||||
and also the "_distance" column which is the distance between the query
|
and also the "_distance" column which is the distance between the query
|
||||||
vector and the returned vector.
|
vector and the returned vector.
|
||||||
"""
|
"""
|
||||||
if vector_column_name is None and query is not None:
|
is_query_defined = query is not None or vector is not None or text is not None
|
||||||
|
if vector_column_name is None and is_query_defined:
|
||||||
vector_column_name = inf_vector_column_query(self.schema)
|
vector_column_name = inf_vector_column_query(self.schema)
|
||||||
register_event("search_table")
|
register_event("search_table")
|
||||||
return LanceQueryBuilder.create(
|
return LanceQueryBuilder.create(
|
||||||
self, query, query_type, vector_column_name=vector_column_name
|
self,
|
||||||
|
query,
|
||||||
|
query_type,
|
||||||
|
vector_column_name=vector_column_name,
|
||||||
|
vector=vector,
|
||||||
|
text=text,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -1780,3 +1796,715 @@ def _sanitize_nans(data, fill_value, on_bad_vectors, vec_arr, vector_column_name
|
|||||||
is_full = np.any(~is_value_nan.reshape(-1, vec_arr.type.list_size), axis=1)
|
is_full = np.any(~is_value_nan.reshape(-1, vec_arr.type.list_size), axis=1)
|
||||||
data = data.filter(is_full)
|
data = data.filter(is_full)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncTable(ABC):
|
||||||
|
"""
|
||||||
|
A Table is a collection of Records in a LanceDB Database.
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
|
||||||
|
Create using [DBConnection.create_table][lancedb.DBConnection.create_table]
|
||||||
|
(more examples in that method's documentation).
|
||||||
|
|
||||||
|
>>> import lancedb
|
||||||
|
>>> db = lancedb.connect("./.lancedb")
|
||||||
|
>>> table = db.create_table("my_table", data=[{"vector": [1.1, 1.2], "b": 2}])
|
||||||
|
>>> table.head()
|
||||||
|
pyarrow.Table
|
||||||
|
vector: fixed_size_list<item: float>[2]
|
||||||
|
child 0, item: float
|
||||||
|
b: int64
|
||||||
|
----
|
||||||
|
vector: [[[1.1,1.2]]]
|
||||||
|
b: [[2]]
|
||||||
|
|
||||||
|
Can append new data with [Table.add()][lancedb.table.Table.add].
|
||||||
|
|
||||||
|
>>> table.add([{"vector": [0.5, 1.3], "b": 4}])
|
||||||
|
|
||||||
|
Can query the table with [Table.search][lancedb.table.Table.search].
|
||||||
|
|
||||||
|
>>> table.search([0.4, 0.4]).select(["b", "vector"]).to_pandas()
|
||||||
|
b vector _distance
|
||||||
|
0 4 [0.5, 1.3] 0.82
|
||||||
|
1 2 [1.1, 1.2] 1.13
|
||||||
|
|
||||||
|
Search queries are much faster when an index is created. See
|
||||||
|
[Table.create_index][lancedb.table.Table.create_index].
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def name(self) -> str:
|
||||||
|
"""The name of the table."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def schema(self) -> pa.Schema:
|
||||||
|
"""The [Arrow Schema](https://arrow.apache.org/docs/python/api/datatypes.html#)
|
||||||
|
of this Table
|
||||||
|
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def count_rows(self, filter: Optional[str] = None) -> int:
|
||||||
|
"""
|
||||||
|
Count the number of rows in the table.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
filter: str, optional
|
||||||
|
A SQL where clause to filter the rows to count.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def to_pandas(self) -> "pd.DataFrame":
|
||||||
|
"""Return the table as a pandas DataFrame.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
pd.DataFrame
|
||||||
|
"""
|
||||||
|
return self.to_arrow().to_pandas()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def to_arrow(self) -> pa.Table:
|
||||||
|
"""Return the table as a pyarrow Table.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
pa.Table
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def create_index(
|
||||||
|
self,
|
||||||
|
metric="L2",
|
||||||
|
num_partitions=256,
|
||||||
|
num_sub_vectors=96,
|
||||||
|
vector_column_name: str = VECTOR_COLUMN_NAME,
|
||||||
|
replace: bool = True,
|
||||||
|
accelerator: Optional[str] = None,
|
||||||
|
index_cache_size: Optional[int] = None,
|
||||||
|
):
|
||||||
|
"""Create an index on the table.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
metric: str, default "L2"
|
||||||
|
The distance metric to use when creating the index.
|
||||||
|
Valid values are "L2", "cosine", or "dot".
|
||||||
|
L2 is euclidean distance.
|
||||||
|
num_partitions: int, default 256
|
||||||
|
The number of IVF partitions to use when creating the index.
|
||||||
|
Default is 256.
|
||||||
|
num_sub_vectors: int, default 96
|
||||||
|
The number of PQ sub-vectors to use when creating the index.
|
||||||
|
Default is 96.
|
||||||
|
vector_column_name: str, default "vector"
|
||||||
|
The vector column name to create the index.
|
||||||
|
replace: bool, default True
|
||||||
|
- If True, replace the existing index if it exists.
|
||||||
|
|
||||||
|
- If False, raise an error if duplicate index exists.
|
||||||
|
accelerator: str, default None
|
||||||
|
If set, use the given accelerator to create the index.
|
||||||
|
Only support "cuda" for now.
|
||||||
|
index_cache_size : int, optional
|
||||||
|
The size of the index cache in number of entries. Default value is 256.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def create_scalar_index(
|
||||||
|
self,
|
||||||
|
column: str,
|
||||||
|
*,
|
||||||
|
replace: bool = True,
|
||||||
|
):
|
||||||
|
"""Create a scalar index on a column.
|
||||||
|
|
||||||
|
Scalar indices, like vector indices, can be used to speed up scans. A scalar
|
||||||
|
index can speed up scans that contain filter expressions on the indexed column.
|
||||||
|
For example, the following scan will be faster if the column ``my_col`` has
|
||||||
|
a scalar index:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
import lancedb
|
||||||
|
|
||||||
|
db = lancedb.connect("/data/lance")
|
||||||
|
img_table = db.open_table("images")
|
||||||
|
my_df = img_table.search().where("my_col = 7", prefilter=True).to_pandas()
|
||||||
|
|
||||||
|
Scalar indices can also speed up scans containing a vector search and a
|
||||||
|
prefilter:
|
||||||
|
|
||||||
|
.. code-block::python
|
||||||
|
|
||||||
|
import lancedb
|
||||||
|
|
||||||
|
db = lancedb.connect("/data/lance")
|
||||||
|
img_table = db.open_table("images")
|
||||||
|
img_table.search([1, 2, 3, 4], vector_column_name="vector")
|
||||||
|
.where("my_col != 7", prefilter=True)
|
||||||
|
.to_pandas()
|
||||||
|
|
||||||
|
Scalar indices can only speed up scans for basic filters using
|
||||||
|
equality, comparison, range (e.g. ``my_col BETWEEN 0 AND 100``), and set
|
||||||
|
membership (e.g. `my_col IN (0, 1, 2)`)
|
||||||
|
|
||||||
|
Scalar indices can be used if the filter contains multiple indexed columns and
|
||||||
|
the filter criteria are AND'd or OR'd together
|
||||||
|
(e.g. ``my_col < 0 AND other_col> 100``)
|
||||||
|
|
||||||
|
Scalar indices may be used if the filter contains non-indexed columns but,
|
||||||
|
depending on the structure of the filter, they may not be usable. For example,
|
||||||
|
if the column ``not_indexed`` does not have a scalar index then the filter
|
||||||
|
``my_col = 0 OR not_indexed = 1`` will not be able to use any scalar index on
|
||||||
|
``my_col``.
|
||||||
|
|
||||||
|
**Experimental API**
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
column : str
|
||||||
|
The column to be indexed. Must be a boolean, integer, float,
|
||||||
|
or string column.
|
||||||
|
replace : bool, default True
|
||||||
|
Replace the existing index if it exists.
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
import lance
|
||||||
|
|
||||||
|
dataset = lance.dataset("./images.lance")
|
||||||
|
dataset.create_scalar_index("category")
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def add(
|
||||||
|
self,
|
||||||
|
data: DATA,
|
||||||
|
mode: str = "append",
|
||||||
|
on_bad_vectors: str = "error",
|
||||||
|
fill_value: float = 0.0,
|
||||||
|
):
|
||||||
|
"""Add more data to the [Table](Table).
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
data: DATA
|
||||||
|
The data to insert into the table. Acceptable types are:
|
||||||
|
|
||||||
|
- dict or list-of-dict
|
||||||
|
|
||||||
|
- pandas.DataFrame
|
||||||
|
|
||||||
|
- pyarrow.Table or pyarrow.RecordBatch
|
||||||
|
mode: str
|
||||||
|
The mode to use when writing the data. Valid values are
|
||||||
|
"append" and "overwrite".
|
||||||
|
on_bad_vectors: str, default "error"
|
||||||
|
What to do if any of the vectors are not the same size or contains NaNs.
|
||||||
|
One of "error", "drop", "fill".
|
||||||
|
fill_value: float, default 0.
|
||||||
|
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||||
|
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
|
||||||
|
"""
|
||||||
|
Returns a [`LanceMergeInsertBuilder`][lancedb.merge.LanceMergeInsertBuilder]
|
||||||
|
that can be used to create a "merge insert" operation
|
||||||
|
|
||||||
|
This operation can add rows, update rows, and remove rows all in a single
|
||||||
|
transaction. It is a very generic tool that can be used to create
|
||||||
|
behaviors like "insert if not exists", "update or insert (i.e. upsert)",
|
||||||
|
or even replace a portion of existing data with new data (e.g. replace
|
||||||
|
all data where month="january")
|
||||||
|
|
||||||
|
The merge insert operation works by combining new data from a
|
||||||
|
**source table** with existing data in a **target table** by using a
|
||||||
|
join. There are three categories of records.
|
||||||
|
|
||||||
|
"Matched" records are records that exist in both the source table and
|
||||||
|
the target table. "Not matched" records exist only in the source table
|
||||||
|
(e.g. these are new data) "Not matched by source" records exist only
|
||||||
|
in the target table (this is old data)
|
||||||
|
|
||||||
|
The builder returned by this method can be used to customize what
|
||||||
|
should happen for each category of data.
|
||||||
|
|
||||||
|
Please note that the data may appear to be reordered as part of this
|
||||||
|
operation. This is because updated rows will be deleted from the
|
||||||
|
dataset and then reinserted at the end with the new values.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
|
||||||
|
on: Union[str, Iterable[str]]
|
||||||
|
A column (or columns) to join on. This is how records from the
|
||||||
|
source table and target table are matched. Typically this is some
|
||||||
|
kind of key or id column.
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
>>> import lancedb
|
||||||
|
>>> data = pa.table({"a": [2, 1, 3], "b": ["a", "b", "c"]})
|
||||||
|
>>> db = lancedb.connect("./.lancedb")
|
||||||
|
>>> table = db.create_table("my_table", data)
|
||||||
|
>>> new_data = pa.table({"a": [2, 3, 4], "b": ["x", "y", "z"]})
|
||||||
|
>>> # Perform a "upsert" operation
|
||||||
|
>>> table.merge_insert("a") \\
|
||||||
|
... .when_matched_update_all() \\
|
||||||
|
... .when_not_matched_insert_all() \\
|
||||||
|
... .execute(new_data)
|
||||||
|
>>> # The order of new rows is non-deterministic since we use
|
||||||
|
>>> # a hash-join as part of this operation and so we sort here
|
||||||
|
>>> table.to_arrow().sort_by("a").to_pandas()
|
||||||
|
a b
|
||||||
|
0 1 b
|
||||||
|
1 2 x
|
||||||
|
2 3 y
|
||||||
|
3 4 z
|
||||||
|
"""
|
||||||
|
on = [on] if isinstance(on, str) else list(on.iter())
|
||||||
|
|
||||||
|
return LanceMergeInsertBuilder(self, on)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def search(
|
||||||
|
self,
|
||||||
|
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
|
||||||
|
vector_column_name: Optional[str] = None,
|
||||||
|
query_type: str = "auto",
|
||||||
|
) -> LanceQueryBuilder:
|
||||||
|
"""Create a search query to find the nearest neighbors
|
||||||
|
of the given query vector. We currently support [vector search][search]
|
||||||
|
and [full-text search][experimental-full-text-search].
|
||||||
|
|
||||||
|
All query options are defined in [Query][lancedb.query.Query].
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
>>> import lancedb
|
||||||
|
>>> db = lancedb.connect("./.lancedb")
|
||||||
|
>>> data = [
|
||||||
|
... {"original_width": 100, "caption": "bar", "vector": [0.1, 2.3, 4.5]},
|
||||||
|
... {"original_width": 2000, "caption": "foo", "vector": [0.5, 3.4, 1.3]},
|
||||||
|
... {"original_width": 3000, "caption": "test", "vector": [0.3, 6.2, 2.6]}
|
||||||
|
... ]
|
||||||
|
>>> table = db.create_table("my_table", data)
|
||||||
|
>>> query = [0.4, 1.4, 2.4]
|
||||||
|
>>> (table.search(query)
|
||||||
|
... .where("original_width > 1000", prefilter=True)
|
||||||
|
... .select(["caption", "original_width", "vector"])
|
||||||
|
... .limit(2)
|
||||||
|
... .to_pandas())
|
||||||
|
caption original_width vector _distance
|
||||||
|
0 foo 2000 [0.5, 3.4, 1.3] 5.220000
|
||||||
|
1 test 3000 [0.3, 6.2, 2.6] 23.089996
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
query: list/np.ndarray/str/PIL.Image.Image, default None
|
||||||
|
The targetted vector to search for.
|
||||||
|
|
||||||
|
- *default None*.
|
||||||
|
Acceptable types are: list, np.ndarray, PIL.Image.Image
|
||||||
|
|
||||||
|
- If None then the select/where/limit clauses are applied to filter
|
||||||
|
the table
|
||||||
|
vector_column_name: str, optional
|
||||||
|
The name of the vector column to search.
|
||||||
|
|
||||||
|
The vector column needs to be a pyarrow fixed size list type
|
||||||
|
|
||||||
|
- If not specified then the vector column is inferred from
|
||||||
|
the table schema
|
||||||
|
|
||||||
|
- If the table has multiple vector columns then the *vector_column_name*
|
||||||
|
needs to be specified. Otherwise, an error is raised.
|
||||||
|
query_type: str
|
||||||
|
*default "auto"*.
|
||||||
|
Acceptable types are: "vector", "fts", "hybrid", or "auto"
|
||||||
|
|
||||||
|
- If "auto" then the query type is inferred from the query;
|
||||||
|
|
||||||
|
- If `query` is a list/np.ndarray then the query type is
|
||||||
|
"vector";
|
||||||
|
|
||||||
|
- If `query` is a PIL.Image.Image then either do vector search,
|
||||||
|
or raise an error if no corresponding embedding function is found.
|
||||||
|
|
||||||
|
- If `query` is a string, then the query type is "vector" if the
|
||||||
|
table has embedding functions else the query type is "fts"
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
LanceQueryBuilder
|
||||||
|
A query builder object representing the query.
|
||||||
|
Once executed, the query returns
|
||||||
|
|
||||||
|
- selected columns
|
||||||
|
|
||||||
|
- the vector
|
||||||
|
|
||||||
|
- and also the "_distance" column which is the distance between the query
|
||||||
|
vector and the returned vector.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def _execute_query(self, query: Query) -> pa.Table:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def _do_merge(
|
||||||
|
self,
|
||||||
|
merge: LanceMergeInsertBuilder,
|
||||||
|
new_data: DATA,
|
||||||
|
on_bad_vectors: str,
|
||||||
|
fill_value: float,
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def delete(self, where: str):
|
||||||
|
"""Delete rows from the table.
|
||||||
|
|
||||||
|
This can be used to delete a single row, many rows, all rows, or
|
||||||
|
sometimes no rows (if your predicate matches nothing).
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
where: str
|
||||||
|
The SQL where clause to use when deleting rows.
|
||||||
|
|
||||||
|
- For example, 'x = 2' or 'x IN (1, 2, 3)'.
|
||||||
|
|
||||||
|
The filter must not be empty, or it will error.
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
>>> import lancedb
|
||||||
|
>>> data = [
|
||||||
|
... {"x": 1, "vector": [1, 2]},
|
||||||
|
... {"x": 2, "vector": [3, 4]},
|
||||||
|
... {"x": 3, "vector": [5, 6]}
|
||||||
|
... ]
|
||||||
|
>>> db = lancedb.connect("./.lancedb")
|
||||||
|
>>> table = db.create_table("my_table", data)
|
||||||
|
>>> table.to_pandas()
|
||||||
|
x vector
|
||||||
|
0 1 [1.0, 2.0]
|
||||||
|
1 2 [3.0, 4.0]
|
||||||
|
2 3 [5.0, 6.0]
|
||||||
|
>>> table.delete("x = 2")
|
||||||
|
>>> table.to_pandas()
|
||||||
|
x vector
|
||||||
|
0 1 [1.0, 2.0]
|
||||||
|
1 3 [5.0, 6.0]
|
||||||
|
|
||||||
|
If you have a list of values to delete, you can combine them into a
|
||||||
|
stringified list and use the `IN` operator:
|
||||||
|
|
||||||
|
>>> to_remove = [1, 5]
|
||||||
|
>>> to_remove = ", ".join([str(v) for v in to_remove])
|
||||||
|
>>> to_remove
|
||||||
|
'1, 5'
|
||||||
|
>>> table.delete(f"x IN ({to_remove})")
|
||||||
|
>>> table.to_pandas()
|
||||||
|
x vector
|
||||||
|
0 3 [5.0, 6.0]
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def update(
|
||||||
|
self,
|
||||||
|
where: Optional[str] = None,
|
||||||
|
values: Optional[dict] = None,
|
||||||
|
*,
|
||||||
|
values_sql: Optional[Dict[str, str]] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
This can be used to update zero to all rows depending on how many
|
||||||
|
rows match the where clause. If no where clause is provided, then
|
||||||
|
all rows will be updated.
|
||||||
|
|
||||||
|
Either `values` or `values_sql` must be provided. You cannot provide
|
||||||
|
both.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
where: str, optional
|
||||||
|
The SQL where clause to use when updating rows. For example, 'x = 2'
|
||||||
|
or 'x IN (1, 2, 3)'. The filter must not be empty, or it will error.
|
||||||
|
values: dict, optional
|
||||||
|
The values to update. The keys are the column names and the values
|
||||||
|
are the values to set.
|
||||||
|
values_sql: dict, optional
|
||||||
|
The values to update, expressed as SQL expression strings. These can
|
||||||
|
reference existing columns. For example, {"x": "x + 1"} will increment
|
||||||
|
the x column by 1.
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
>>> import lancedb
|
||||||
|
>>> import pandas as pd
|
||||||
|
>>> data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1, 2], [3, 4], [5, 6]]})
|
||||||
|
>>> db = lancedb.connect("./.lancedb")
|
||||||
|
>>> table = db.create_table("my_table", data)
|
||||||
|
>>> table.to_pandas()
|
||||||
|
x vector
|
||||||
|
0 1 [1.0, 2.0]
|
||||||
|
1 2 [3.0, 4.0]
|
||||||
|
2 3 [5.0, 6.0]
|
||||||
|
>>> table.update(where="x = 2", values={"vector": [10, 10]})
|
||||||
|
>>> table.to_pandas()
|
||||||
|
x vector
|
||||||
|
0 1 [1.0, 2.0]
|
||||||
|
1 3 [5.0, 6.0]
|
||||||
|
2 2 [10.0, 10.0]
|
||||||
|
>>> table.update(values_sql={"x": "x + 1"})
|
||||||
|
>>> table.to_pandas()
|
||||||
|
x vector
|
||||||
|
0 2 [1.0, 2.0]
|
||||||
|
1 4 [5.0, 6.0]
|
||||||
|
2 3 [10.0, 10.0]
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def cleanup_old_versions(
|
||||||
|
self,
|
||||||
|
older_than: Optional[timedelta] = None,
|
||||||
|
*,
|
||||||
|
delete_unverified: bool = False,
|
||||||
|
) -> CleanupStats:
|
||||||
|
"""
|
||||||
|
Clean up old versions of the table, freeing disk space.
|
||||||
|
|
||||||
|
Note: This function is not available in LanceDb Cloud (since LanceDb
|
||||||
|
Cloud manages cleanup for you automatically)
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
older_than: timedelta, default None
|
||||||
|
The minimum age of the version to delete. If None, then this defaults
|
||||||
|
to two weeks.
|
||||||
|
delete_unverified: bool, default False
|
||||||
|
Because they may be part of an in-progress transaction, files newer
|
||||||
|
than 7 days old are not deleted by default. If you are sure that
|
||||||
|
there are no in-progress transactions, then you can set this to True
|
||||||
|
to delete all files older than `older_than`.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
CleanupStats
|
||||||
|
The stats of the cleanup operation, including how many bytes were
|
||||||
|
freed.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def compact_files(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Run the compaction process on the table.
|
||||||
|
|
||||||
|
Note: This function is not available in LanceDb Cloud (since LanceDb
|
||||||
|
Cloud manages compaction for you automatically)
|
||||||
|
|
||||||
|
This can be run after making several small appends to optimize the table
|
||||||
|
for faster reads.
|
||||||
|
|
||||||
|
Arguments are passed onto :meth:`lance.dataset.DatasetOptimizer.compact_files`.
|
||||||
|
For most cases, the default should be fine.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def add_columns(self, transforms: Dict[str, str]):
|
||||||
|
"""
|
||||||
|
Add new columns with defined values.
|
||||||
|
|
||||||
|
This is not yet available in LanceDB Cloud.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
transforms: Dict[str, str]
|
||||||
|
A map of column name to a SQL expression to use to calculate the
|
||||||
|
value of the new column. These expressions will be evaluated for
|
||||||
|
each row in the table, and can reference existing columns.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def alter_columns(self, alterations: Iterable[Dict[str, str]]):
|
||||||
|
"""
|
||||||
|
Alter column names and nullability.
|
||||||
|
|
||||||
|
This is not yet available in LanceDB Cloud.
|
||||||
|
|
||||||
|
alterations : Iterable[Dict[str, Any]]
|
||||||
|
A sequence of dictionaries, each with the following keys:
|
||||||
|
- "path": str
|
||||||
|
The column path to alter. For a top-level column, this is the name.
|
||||||
|
For a nested column, this is the dot-separated path, e.g. "a.b.c".
|
||||||
|
- "name": str, optional
|
||||||
|
The new name of the column. If not specified, the column name is
|
||||||
|
not changed.
|
||||||
|
- "nullable": bool, optional
|
||||||
|
Whether the column should be nullable. If not specified, the column
|
||||||
|
nullability is not changed. Only non-nullable columns can be changed
|
||||||
|
to nullable. Currently, you cannot change a nullable column to
|
||||||
|
non-nullable.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def drop_columns(self, columns: Iterable[str]):
|
||||||
|
"""
|
||||||
|
Drop columns from the table.
|
||||||
|
|
||||||
|
This is not yet available in LanceDB Cloud.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
columns : Iterable[str]
|
||||||
|
The names of the columns to drop.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncLanceTable(AsyncTable):
|
||||||
|
def __init__(self, table: LanceDBTable):
|
||||||
|
self._inner = table
|
||||||
|
|
||||||
|
@property
|
||||||
|
@override
|
||||||
|
def name(self) -> str:
|
||||||
|
return self._inner.name()
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def schema(self) -> pa.Schema:
|
||||||
|
return await self._inner.schema()
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def count_rows(self, filter: Optional[str] = None) -> int:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def to_pandas(self) -> "pd.DataFrame":
|
||||||
|
return self.to_arrow().to_pandas()
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def to_arrow(self) -> pa.Table:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def create_index(
|
||||||
|
self,
|
||||||
|
metric="L2",
|
||||||
|
num_partitions=256,
|
||||||
|
num_sub_vectors=96,
|
||||||
|
vector_column_name: str = VECTOR_COLUMN_NAME,
|
||||||
|
replace: bool = True,
|
||||||
|
accelerator: Optional[str] = None,
|
||||||
|
index_cache_size: Optional[int] = None,
|
||||||
|
):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def create_scalar_index(
|
||||||
|
self,
|
||||||
|
column: str,
|
||||||
|
*,
|
||||||
|
replace: bool = True,
|
||||||
|
):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def add(
|
||||||
|
self,
|
||||||
|
data: DATA,
|
||||||
|
mode: str = "append",
|
||||||
|
on_bad_vectors: str = "error",
|
||||||
|
fill_value: float = 0.0,
|
||||||
|
):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
|
||||||
|
on = [on] if isinstance(on, str) else list(on.iter())
|
||||||
|
|
||||||
|
return LanceMergeInsertBuilder(self, on)
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def search(
|
||||||
|
self,
|
||||||
|
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
|
||||||
|
vector_column_name: Optional[str] = None,
|
||||||
|
query_type: str = "auto",
|
||||||
|
) -> LanceQueryBuilder:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def _execute_query(self, query: Query) -> pa.Table:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def _do_merge(
|
||||||
|
self,
|
||||||
|
merge: LanceMergeInsertBuilder,
|
||||||
|
new_data: DATA,
|
||||||
|
on_bad_vectors: str,
|
||||||
|
fill_value: float,
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def delete(self, where: str):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def update(
|
||||||
|
self,
|
||||||
|
where: Optional[str] = None,
|
||||||
|
values: Optional[dict] = None,
|
||||||
|
*,
|
||||||
|
values_sql: Optional[Dict[str, str]] = None,
|
||||||
|
):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def cleanup_old_versions(
|
||||||
|
self,
|
||||||
|
older_than: Optional[timedelta] = None,
|
||||||
|
*,
|
||||||
|
delete_unverified: bool = False,
|
||||||
|
) -> CleanupStats:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def compact_files(self, *args, **kwargs):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def add_columns(self, transforms: Dict[str, str]):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def alter_columns(self, alterations: Iterable[Dict[str, str]]):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def drop_columns(self, columns: Iterable[str]):
|
||||||
|
raise NotImplementedError
|
||||||
|
|||||||
@@ -250,6 +250,78 @@ def test_create_exist_ok(tmp_path):
|
|||||||
db.create_table("test", schema=bad_schema, exist_ok=True)
|
db.create_table("test", schema=bad_schema, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_mode_async(tmp_path):
|
||||||
|
db = await lancedb.connect_async(tmp_path)
|
||||||
|
data = pd.DataFrame(
|
||||||
|
{
|
||||||
|
"vector": [[3.1, 4.1], [5.9, 26.5]],
|
||||||
|
"item": ["foo", "bar"],
|
||||||
|
"price": [10.0, 20.0],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
await db.create_table("test", data=data)
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
await db.create_table("test", data=data)
|
||||||
|
|
||||||
|
new_data = pd.DataFrame(
|
||||||
|
{
|
||||||
|
"vector": [[3.1, 4.1], [5.9, 26.5]],
|
||||||
|
"item": ["fizz", "buzz"],
|
||||||
|
"price": [10.0, 20.0],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
_tbl = await db.create_table("test", data=new_data, mode="overwrite")
|
||||||
|
|
||||||
|
# MIGRATION: to_pandas() is not available in async
|
||||||
|
# assert tbl.to_pandas().item.tolist() == ["fizz", "buzz"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_exist_ok_async(tmp_path):
|
||||||
|
db = await lancedb.connect_async(tmp_path)
|
||||||
|
data = pd.DataFrame(
|
||||||
|
{
|
||||||
|
"vector": [[3.1, 4.1], [5.9, 26.5]],
|
||||||
|
"item": ["foo", "bar"],
|
||||||
|
"price": [10.0, 20.0],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
tbl = await db.create_table("test", data=data)
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
await db.create_table("test", data=data)
|
||||||
|
|
||||||
|
# open the table but don't add more rows
|
||||||
|
tbl2 = await db.create_table("test", data=data, exist_ok=True)
|
||||||
|
assert tbl.name == tbl2.name
|
||||||
|
assert await tbl.schema() == await tbl2.schema()
|
||||||
|
|
||||||
|
schema = pa.schema(
|
||||||
|
[
|
||||||
|
pa.field("vector", pa.list_(pa.float32(), list_size=2)),
|
||||||
|
pa.field("item", pa.utf8()),
|
||||||
|
pa.field("price", pa.float64()),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
tbl3 = await db.create_table("test", schema=schema, exist_ok=True)
|
||||||
|
assert await tbl3.schema() == schema
|
||||||
|
|
||||||
|
# Migration: When creating a table, but the table already exists, but
|
||||||
|
# the schema is different, it should raise an error.
|
||||||
|
# bad_schema = pa.schema(
|
||||||
|
# [
|
||||||
|
# pa.field("vector", pa.list_(pa.float32(), list_size=2)),
|
||||||
|
# pa.field("item", pa.utf8()),
|
||||||
|
# pa.field("price", pa.float64()),
|
||||||
|
# pa.field("extra", pa.float32()),
|
||||||
|
# ]
|
||||||
|
# )
|
||||||
|
# with pytest.raises(ValueError):
|
||||||
|
# await db.create_table("test", schema=bad_schema, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
def test_delete_table(tmp_path):
|
def test_delete_table(tmp_path):
|
||||||
db = lancedb.connect(tmp_path)
|
db = lancedb.connect(tmp_path)
|
||||||
data = pd.DataFrame(
|
data = pd.DataFrame(
|
||||||
|
|||||||
@@ -94,6 +94,17 @@ def test_query_builder(table):
|
|||||||
assert all(np.array(rs[0]["vector"]) == [1, 2])
|
assert all(np.array(rs[0]["vector"]) == [1, 2])
|
||||||
|
|
||||||
|
|
||||||
|
def test_dynamic_projection(table):
|
||||||
|
rs = (
|
||||||
|
LanceVectorQueryBuilder(table, [0, 0], "vector")
|
||||||
|
.limit(1)
|
||||||
|
.select({"id": "id", "id2": "id * 2"})
|
||||||
|
.to_list()
|
||||||
|
)
|
||||||
|
assert rs[0]["id"] == 1
|
||||||
|
assert rs[0]["id2"] == 2
|
||||||
|
|
||||||
|
|
||||||
def test_query_builder_with_filter(table):
|
def test_query_builder_with_filter(table):
|
||||||
rs = LanceVectorQueryBuilder(table, [0, 0], "vector").where("id = 2").to_list()
|
rs = LanceVectorQueryBuilder(table, [0, 0], "vector").where("id = 2").to_list()
|
||||||
assert rs[0]["id"] == 2
|
assert rs[0]["id"] == 2
|
||||||
|
|||||||
@@ -104,7 +104,7 @@ def test_linear_combination(tmp_path):
|
|||||||
query = "Our father who art in heaven"
|
query = "Our father who art in heaven"
|
||||||
query_vector = table.to_pandas()["vector"][0]
|
query_vector = table.to_pandas()["vector"][0]
|
||||||
result = (
|
result = (
|
||||||
table.search((query_vector, query))
|
table.search(vector=query_vector, text=query, query_type="hybrid")
|
||||||
.limit(30)
|
.limit(30)
|
||||||
.rerank(normalize="score")
|
.rerank(normalize="score")
|
||||||
.to_arrow()
|
.to_arrow()
|
||||||
@@ -118,6 +118,32 @@ def test_linear_combination(tmp_path):
|
|||||||
"be descending."
|
"be descending."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# automatically deduce the query type
|
||||||
|
result = (
|
||||||
|
table.search(vector=query_vector, text=query)
|
||||||
|
.limit(30)
|
||||||
|
.rerank(normalize="score")
|
||||||
|
.to_arrow()
|
||||||
|
)
|
||||||
|
|
||||||
|
# wrong query type raises an error
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
table.search(vector=query_vector, text=query, query_type="vector").rerank(
|
||||||
|
normalize="score"
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
table.search(vector=query_vector, text=query, query_type="fts").rerank(
|
||||||
|
normalize="score"
|
||||||
|
)
|
||||||
|
|
||||||
|
# raise an error if only vector or text is provided
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
table.search(vector=query_vector).to_arrow()
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
table.search(text=query).to_arrow()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"
|
os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"
|
||||||
@@ -141,7 +167,7 @@ def test_cohere_reranker(tmp_path):
|
|||||||
query = "Our father who art in heaven"
|
query = "Our father who art in heaven"
|
||||||
query_vector = table.to_pandas()["vector"][0]
|
query_vector = table.to_pandas()["vector"][0]
|
||||||
result = (
|
result = (
|
||||||
table.search((query_vector, query))
|
table.search(vector=query_vector, text=query)
|
||||||
.limit(30)
|
.limit(30)
|
||||||
.rerank(reranker=CohereReranker())
|
.rerank(reranker=CohereReranker())
|
||||||
.to_arrow()
|
.to_arrow()
|
||||||
@@ -175,7 +201,7 @@ def test_cross_encoder_reranker(tmp_path):
|
|||||||
query = "Our father who art in heaven"
|
query = "Our father who art in heaven"
|
||||||
query_vector = table.to_pandas()["vector"][0]
|
query_vector = table.to_pandas()["vector"][0]
|
||||||
result = (
|
result = (
|
||||||
table.search((query_vector, query), query_type="hybrid")
|
table.search(vector=query_vector, text=query, query_type="hybrid")
|
||||||
.limit(30)
|
.limit(30)
|
||||||
.rerank(reranker=CrossEncoderReranker())
|
.rerank(reranker=CrossEncoderReranker())
|
||||||
.to_arrow()
|
.to_arrow()
|
||||||
@@ -209,7 +235,7 @@ def test_colbert_reranker(tmp_path):
|
|||||||
query = "Our father who art in heaven"
|
query = "Our father who art in heaven"
|
||||||
query_vector = table.to_pandas()["vector"][0]
|
query_vector = table.to_pandas()["vector"][0]
|
||||||
result = (
|
result = (
|
||||||
table.search((query_vector, query))
|
table.search(vector=query_vector, text=query)
|
||||||
.limit(30)
|
.limit(30)
|
||||||
.rerank(reranker=ColbertReranker())
|
.rerank(reranker=ColbertReranker())
|
||||||
.to_arrow()
|
.to_arrow()
|
||||||
@@ -246,7 +272,7 @@ def test_openai_reranker(tmp_path):
|
|||||||
query = "Our father who art in heaven"
|
query = "Our father who art in heaven"
|
||||||
query_vector = table.to_pandas()["vector"][0]
|
query_vector = table.to_pandas()["vector"][0]
|
||||||
result = (
|
result = (
|
||||||
table.search((query_vector, query))
|
table.search(vector=query_vector, text=query)
|
||||||
.limit(30)
|
.limit(30)
|
||||||
.rerank(reranker=OpenaiReranker())
|
.rerank(reranker=OpenaiReranker())
|
||||||
.to_arrow()
|
.to_arrow()
|
||||||
|
|||||||
@@ -12,19 +12,33 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use std::time::Duration;
|
use std::{sync::Arc, time::Duration};
|
||||||
|
|
||||||
use lancedb::connection::Connection as LanceConnection;
|
use arrow::{datatypes::Schema, ffi_stream::ArrowArrayStreamReader, pyarrow::FromPyArrow};
|
||||||
use pyo3::{pyclass, pyfunction, pymethods, PyAny, PyRef, PyResult, Python};
|
use lancedb::connection::{Connection as LanceConnection, CreateTableMode};
|
||||||
|
use pyo3::{
|
||||||
|
exceptions::PyValueError, pyclass, pyfunction, pymethods, PyAny, PyRef, PyResult, Python,
|
||||||
|
};
|
||||||
use pyo3_asyncio::tokio::future_into_py;
|
use pyo3_asyncio::tokio::future_into_py;
|
||||||
|
|
||||||
use crate::error::PythonErrorExt;
|
use crate::{error::PythonErrorExt, table::Table};
|
||||||
|
|
||||||
#[pyclass]
|
#[pyclass]
|
||||||
pub struct Connection {
|
pub struct Connection {
|
||||||
inner: LanceConnection,
|
inner: LanceConnection,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Connection {
|
||||||
|
fn parse_create_mode_str(mode: &str) -> PyResult<CreateTableMode> {
|
||||||
|
match mode {
|
||||||
|
"create" => Ok(CreateTableMode::Create),
|
||||||
|
"overwrite" => Ok(CreateTableMode::Overwrite),
|
||||||
|
"exist_ok" => Ok(CreateTableMode::exist_ok(|builder| builder)),
|
||||||
|
_ => Err(PyValueError::new_err(format!("Invalid mode {}", mode))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[pymethods]
|
#[pymethods]
|
||||||
impl Connection {
|
impl Connection {
|
||||||
pub fn table_names(self_: PyRef<'_, Self>) -> PyResult<&PyAny> {
|
pub fn table_names(self_: PyRef<'_, Self>) -> PyResult<&PyAny> {
|
||||||
@@ -33,6 +47,51 @@ impl Connection {
|
|||||||
inner.table_names().await.infer_error()
|
inner.table_names().await.infer_error()
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn create_table<'a>(
|
||||||
|
self_: PyRef<'a, Self>,
|
||||||
|
name: String,
|
||||||
|
mode: &str,
|
||||||
|
data: &PyAny,
|
||||||
|
) -> PyResult<&'a PyAny> {
|
||||||
|
let inner = self_.inner.clone();
|
||||||
|
|
||||||
|
let mode = Self::parse_create_mode_str(mode)?;
|
||||||
|
|
||||||
|
let batches = Box::new(ArrowArrayStreamReader::from_pyarrow(data)?);
|
||||||
|
future_into_py(self_.py(), async move {
|
||||||
|
let table = inner
|
||||||
|
.create_table(name, batches)
|
||||||
|
.mode(mode)
|
||||||
|
.execute()
|
||||||
|
.await
|
||||||
|
.infer_error()?;
|
||||||
|
Ok(Table::new(table))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn create_empty_table<'a>(
|
||||||
|
self_: PyRef<'a, Self>,
|
||||||
|
name: String,
|
||||||
|
mode: &str,
|
||||||
|
schema: &PyAny,
|
||||||
|
) -> PyResult<&'a PyAny> {
|
||||||
|
let inner = self_.inner.clone();
|
||||||
|
|
||||||
|
let mode = Self::parse_create_mode_str(mode)?;
|
||||||
|
|
||||||
|
let schema = Schema::from_pyarrow(schema)?;
|
||||||
|
|
||||||
|
future_into_py(self_.py(), async move {
|
||||||
|
let table = inner
|
||||||
|
.create_empty_table(name, Arc::new(schema))
|
||||||
|
.mode(mode)
|
||||||
|
.execute()
|
||||||
|
.await
|
||||||
|
.infer_error()?;
|
||||||
|
Ok(Table::new(table))
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[pyfunction]
|
#[pyfunction]
|
||||||
|
|||||||
@@ -35,14 +35,17 @@ impl<T> PythonErrorExt<T> for std::result::Result<T, LanceError> {
|
|||||||
match &self {
|
match &self {
|
||||||
Ok(_) => Ok(self.unwrap()),
|
Ok(_) => Ok(self.unwrap()),
|
||||||
Err(err) => match err {
|
Err(err) => match err {
|
||||||
|
LanceError::InvalidInput { .. } => self.value_error(),
|
||||||
LanceError::InvalidTableName { .. } => self.value_error(),
|
LanceError::InvalidTableName { .. } => self.value_error(),
|
||||||
LanceError::TableNotFound { .. } => self.value_error(),
|
LanceError::TableNotFound { .. } => self.value_error(),
|
||||||
LanceError::TableAlreadyExists { .. } => self.runtime_error(),
|
LanceError::Schema { .. } => self.value_error(),
|
||||||
LanceError::CreateDir { .. } => self.os_error(),
|
LanceError::CreateDir { .. } => self.os_error(),
|
||||||
|
LanceError::TableAlreadyExists { .. } => self.runtime_error(),
|
||||||
LanceError::Store { .. } => self.runtime_error(),
|
LanceError::Store { .. } => self.runtime_error(),
|
||||||
LanceError::Lance { .. } => self.runtime_error(),
|
LanceError::Lance { .. } => self.runtime_error(),
|
||||||
LanceError::Schema { .. } => self.value_error(),
|
|
||||||
LanceError::Runtime { .. } => self.runtime_error(),
|
LanceError::Runtime { .. } => self.runtime_error(),
|
||||||
|
LanceError::Http { .. } => self.runtime_error(),
|
||||||
|
LanceError::Arrow { .. } => self.runtime_error(),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,7 +17,8 @@ use env_logger::Env;
|
|||||||
use pyo3::{pymodule, types::PyModule, wrap_pyfunction, PyResult, Python};
|
use pyo3::{pymodule, types::PyModule, wrap_pyfunction, PyResult, Python};
|
||||||
|
|
||||||
pub mod connection;
|
pub mod connection;
|
||||||
pub(crate) mod error;
|
pub mod error;
|
||||||
|
pub mod table;
|
||||||
|
|
||||||
#[pymodule]
|
#[pymodule]
|
||||||
pub fn _lancedb(_py: Python, m: &PyModule) -> PyResult<()> {
|
pub fn _lancedb(_py: Python, m: &PyModule) -> PyResult<()> {
|
||||||
|
|||||||
34
python/src/table.rs
Normal file
34
python/src/table.rs
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use arrow::pyarrow::ToPyArrow;
|
||||||
|
use lancedb::table::Table as LanceTable;
|
||||||
|
use pyo3::{pyclass, pymethods, PyAny, PyRef, PyResult, Python};
|
||||||
|
use pyo3_asyncio::tokio::future_into_py;
|
||||||
|
|
||||||
|
use crate::error::PythonErrorExt;
|
||||||
|
|
||||||
|
#[pyclass]
|
||||||
|
pub struct Table {
|
||||||
|
inner: Arc<dyn LanceTable>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Table {
|
||||||
|
pub(crate) fn new(inner: Arc<dyn LanceTable>) -> Self {
|
||||||
|
Self { inner }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[pymethods]
|
||||||
|
impl Table {
|
||||||
|
pub fn name(&self) -> String {
|
||||||
|
self.inner.name().to_string()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn schema(self_: PyRef<'_, Self>) -> PyResult<&PyAny> {
|
||||||
|
let inner = self_.inner.clone();
|
||||||
|
future_into_py(self_.py(), async move {
|
||||||
|
let schema = inner.schema().await.infer_error()?;
|
||||||
|
Python::with_gil(|py| schema.to_pyarrow(py))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -31,11 +31,20 @@ async-trait = "0"
|
|||||||
bytes = "1"
|
bytes = "1"
|
||||||
futures.workspace = true
|
futures.workspace = true
|
||||||
num-traits.workspace = true
|
num-traits.workspace = true
|
||||||
url = { workspace = true }
|
url.workspace = true
|
||||||
serde = { version = "^1" }
|
serde = { version = "^1" }
|
||||||
serde_json = { version = "1" }
|
serde_json = { version = "1" }
|
||||||
|
|
||||||
|
# For remote feature
|
||||||
|
|
||||||
|
reqwest = { version = "0.11.24", features = ["gzip", "json"], optional = true }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
tempfile = "3.5.0"
|
tempfile = "3.5.0"
|
||||||
rand = { version = "0.8.3", features = ["small_rng"] }
|
rand = { version = "0.8.3", features = ["small_rng"] }
|
||||||
|
uuid = { version = "1.7.0", features = ["v4"] }
|
||||||
walkdir = "2"
|
walkdir = "2"
|
||||||
|
|
||||||
|
[features]
|
||||||
|
default = ["remote"]
|
||||||
|
remote = ["dep:reqwest"]
|
||||||
|
|||||||
@@ -80,11 +80,11 @@ enum BadVectorHandling {
|
|||||||
/// A builder for configuring a [`Connection::create_table`] operation
|
/// A builder for configuring a [`Connection::create_table`] operation
|
||||||
pub struct CreateTableBuilder<const HAS_DATA: bool> {
|
pub struct CreateTableBuilder<const HAS_DATA: bool> {
|
||||||
parent: Arc<dyn ConnectionInternal>,
|
parent: Arc<dyn ConnectionInternal>,
|
||||||
name: String,
|
pub(crate) name: String,
|
||||||
data: Option<Box<dyn RecordBatchReader + Send>>,
|
pub(crate) data: Option<Box<dyn RecordBatchReader + Send>>,
|
||||||
schema: Option<SchemaRef>,
|
pub(crate) schema: Option<SchemaRef>,
|
||||||
mode: CreateTableMode,
|
pub(crate) mode: CreateTableMode,
|
||||||
write_options: WriteOptions,
|
pub(crate) write_options: WriteOptions,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Builder methods that only apply when we have initial data
|
// Builder methods that only apply when we have initial data
|
||||||
@@ -194,7 +194,7 @@ impl OpenTableBuilder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait::async_trait]
|
#[async_trait::async_trait]
|
||||||
trait ConnectionInternal: Send + Sync + std::fmt::Debug + 'static {
|
pub(crate) trait ConnectionInternal: Send + Sync + std::fmt::Debug + 'static {
|
||||||
async fn table_names(&self) -> Result<Vec<String>>;
|
async fn table_names(&self) -> Result<Vec<String>>;
|
||||||
async fn do_create_table(&self, options: CreateTableBuilder<true>) -> Result<TableRef>;
|
async fn do_create_table(&self, options: CreateTableBuilder<true>) -> Result<TableRef>;
|
||||||
async fn do_open_table(&self, options: OpenTableBuilder) -> Result<TableRef>;
|
async fn do_open_table(&self, options: OpenTableBuilder) -> Result<TableRef>;
|
||||||
@@ -365,14 +365,46 @@ impl ConnectBuilder {
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Establishes a connection to the database
|
#[cfg(feature = "remote")]
|
||||||
pub async fn execute(self) -> Result<Connection> {
|
fn execute_remote(self) -> Result<Connection> {
|
||||||
let internal = Arc::new(Database::connect_with_options(&self).await?);
|
let region = self.region.ok_or_else(|| Error::InvalidInput {
|
||||||
|
message: "A region is required when connecting to LanceDb Cloud".to_string(),
|
||||||
|
})?;
|
||||||
|
let api_key = self.api_key.ok_or_else(|| Error::InvalidInput {
|
||||||
|
message: "An api_key is required when connecting to LanceDb Cloud".to_string(),
|
||||||
|
})?;
|
||||||
|
let internal = Arc::new(crate::remote::db::RemoteDatabase::try_new(
|
||||||
|
&self.uri,
|
||||||
|
&api_key,
|
||||||
|
®ion,
|
||||||
|
self.host_override,
|
||||||
|
)?);
|
||||||
Ok(Connection {
|
Ok(Connection {
|
||||||
internal,
|
internal,
|
||||||
uri: self.uri,
|
uri: self.uri,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "remote"))]
|
||||||
|
fn execute_remote(self) -> Result<Connection> {
|
||||||
|
Err(Error::Runtime {
|
||||||
|
message: "cannot connect to LanceDb Cloud unless the 'remote' feature is enabled"
|
||||||
|
.to_string(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Establishes a connection to the database
|
||||||
|
pub async fn execute(self) -> Result<Connection> {
|
||||||
|
if self.uri.starts_with("db") {
|
||||||
|
self.execute_remote()
|
||||||
|
} else {
|
||||||
|
let internal = Arc::new(Database::connect_with_options(&self).await?);
|
||||||
|
Ok(Connection {
|
||||||
|
internal,
|
||||||
|
uri: self.uri,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Connect to a LanceDB database.
|
/// Connect to a LanceDB database.
|
||||||
|
|||||||
@@ -20,32 +20,40 @@ use snafu::Snafu;
|
|||||||
#[derive(Debug, Snafu)]
|
#[derive(Debug, Snafu)]
|
||||||
#[snafu(visibility(pub(crate)))]
|
#[snafu(visibility(pub(crate)))]
|
||||||
pub enum Error {
|
pub enum Error {
|
||||||
#[snafu(display("LanceDBError: Invalid table name: {name}"))]
|
#[snafu(display("Invalid table name: {name}"))]
|
||||||
InvalidTableName { name: String },
|
InvalidTableName { name: String },
|
||||||
#[snafu(display("LanceDBError: Table '{name}' was not found"))]
|
#[snafu(display("Invalid input, {message}"))]
|
||||||
|
InvalidInput { message: String },
|
||||||
|
#[snafu(display("Table '{name}' was not found"))]
|
||||||
TableNotFound { name: String },
|
TableNotFound { name: String },
|
||||||
#[snafu(display("LanceDBError: Table '{name}' already exists"))]
|
#[snafu(display("Table '{name}' already exists"))]
|
||||||
TableAlreadyExists { name: String },
|
TableAlreadyExists { name: String },
|
||||||
#[snafu(display("LanceDBError: Unable to created lance dataset at {path}: {source}"))]
|
#[snafu(display("Unable to created lance dataset at {path}: {source}"))]
|
||||||
CreateDir {
|
CreateDir {
|
||||||
path: String,
|
path: String,
|
||||||
source: std::io::Error,
|
source: std::io::Error,
|
||||||
},
|
},
|
||||||
#[snafu(display("LanceDBError: {message}"))]
|
#[snafu(display("Schema Error: {message}"))]
|
||||||
Store { message: String },
|
|
||||||
#[snafu(display("LanceDBError: {message}"))]
|
|
||||||
Lance { message: String },
|
|
||||||
#[snafu(display("LanceDB Schema Error: {message}"))]
|
|
||||||
Schema { message: String },
|
Schema { message: String },
|
||||||
#[snafu(display("Runtime error: {message}"))]
|
#[snafu(display("Runtime error: {message}"))]
|
||||||
Runtime { message: String },
|
Runtime { message: String },
|
||||||
|
|
||||||
|
// 3rd party / external errors
|
||||||
|
#[snafu(display("object_store error: {message}"))]
|
||||||
|
Store { message: String },
|
||||||
|
#[snafu(display("lance error: {message}"))]
|
||||||
|
Lance { message: String },
|
||||||
|
#[snafu(display("Http error: {message}"))]
|
||||||
|
Http { message: String },
|
||||||
|
#[snafu(display("Arrow error: {message}"))]
|
||||||
|
Arrow { message: String },
|
||||||
}
|
}
|
||||||
|
|
||||||
pub type Result<T> = std::result::Result<T, Error>;
|
pub type Result<T> = std::result::Result<T, Error>;
|
||||||
|
|
||||||
impl From<ArrowError> for Error {
|
impl From<ArrowError> for Error {
|
||||||
fn from(e: ArrowError) -> Self {
|
fn from(e: ArrowError) -> Self {
|
||||||
Self::Lance {
|
Self::Arrow {
|
||||||
message: e.to_string(),
|
message: e.to_string(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -82,3 +90,21 @@ impl<T> From<PoisonError<T>> for Error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "remote")]
|
||||||
|
impl From<reqwest::Error> for Error {
|
||||||
|
fn from(e: reqwest::Error) -> Self {
|
||||||
|
Self::Http {
|
||||||
|
message: e.to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "remote")]
|
||||||
|
impl From<url::ParseError> for Error {
|
||||||
|
fn from(e: url::ParseError) -> Self {
|
||||||
|
Self::Http {
|
||||||
|
message: e.to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -188,6 +188,8 @@ pub mod index;
|
|||||||
pub mod io;
|
pub mod io;
|
||||||
pub mod ipc;
|
pub mod ipc;
|
||||||
pub mod query;
|
pub mod query;
|
||||||
|
#[cfg(feature = "remote")]
|
||||||
|
pub(crate) mod remote;
|
||||||
pub mod table;
|
pub mod table;
|
||||||
pub mod utils;
|
pub mod utils;
|
||||||
|
|
||||||
|
|||||||
23
rust/lancedb/src/remote.rs
Normal file
23
rust/lancedb/src/remote.rs
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
// Copyright 2024 Lance Developers.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
//! This module contains a remote client for a LanceDB server. This is used
|
||||||
|
//! to communicate with LanceDB cloud. It can also serve as an example for
|
||||||
|
//! building client/server applications with LanceDB or as a client for some
|
||||||
|
//! other custom LanceDB service.
|
||||||
|
|
||||||
|
pub mod client;
|
||||||
|
pub mod db;
|
||||||
|
pub mod table;
|
||||||
|
pub mod util;
|
||||||
124
rust/lancedb/src/remote/client.rs
Normal file
124
rust/lancedb/src/remote/client.rs
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
// Copyright 2024 LanceDB Developers.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use reqwest::{
|
||||||
|
header::{HeaderMap, HeaderValue},
|
||||||
|
RequestBuilder, Response,
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::error::{Error, Result};
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct RestfulLanceDbClient {
|
||||||
|
client: reqwest::Client,
|
||||||
|
host: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RestfulLanceDbClient {
|
||||||
|
fn default_headers(
|
||||||
|
api_key: &str,
|
||||||
|
region: &str,
|
||||||
|
db_name: &str,
|
||||||
|
has_host_override: bool,
|
||||||
|
) -> Result<HeaderMap> {
|
||||||
|
let mut headers = HeaderMap::new();
|
||||||
|
headers.insert(
|
||||||
|
"x-api-key",
|
||||||
|
HeaderValue::from_str(api_key).map_err(|_| Error::Http {
|
||||||
|
message: "non-ascii api key provided".to_string(),
|
||||||
|
})?,
|
||||||
|
);
|
||||||
|
if region == "local" {
|
||||||
|
let host = format!("{}.local.api.lancedb.com", db_name);
|
||||||
|
headers.insert(
|
||||||
|
"Host",
|
||||||
|
HeaderValue::from_str(&host).map_err(|_| Error::Http {
|
||||||
|
message: format!("non-ascii database name '{}' provided", db_name),
|
||||||
|
})?,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if has_host_override {
|
||||||
|
headers.insert(
|
||||||
|
"x-lancedb-database",
|
||||||
|
HeaderValue::from_str(db_name).map_err(|_| Error::Http {
|
||||||
|
message: format!("non-ascii database name '{}' provided", db_name),
|
||||||
|
})?,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(headers)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn try_new(
|
||||||
|
db_url: &str,
|
||||||
|
api_key: &str,
|
||||||
|
region: &str,
|
||||||
|
host_override: Option<String>,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let parsed_url = url::Url::parse(db_url)?;
|
||||||
|
debug_assert_eq!(parsed_url.scheme(), "db");
|
||||||
|
if !parsed_url.has_host() {
|
||||||
|
return Err(Error::Http {
|
||||||
|
message: format!("Invalid database URL (missing host) '{}'", db_url),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
let db_name = parsed_url.host_str().unwrap();
|
||||||
|
let client = reqwest::Client::builder()
|
||||||
|
.timeout(Duration::from_secs(30))
|
||||||
|
.default_headers(Self::default_headers(
|
||||||
|
api_key,
|
||||||
|
region,
|
||||||
|
db_name,
|
||||||
|
host_override.is_some(),
|
||||||
|
)?)
|
||||||
|
.build()?;
|
||||||
|
let host = match host_override {
|
||||||
|
Some(host_override) => host_override,
|
||||||
|
None => format!("https://{}.{}.api.lancedb.com", db_name, region),
|
||||||
|
};
|
||||||
|
Ok(Self { client, host })
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get(&self, uri: &str) -> RequestBuilder {
|
||||||
|
let full_uri = format!("{}{}", self.host, uri);
|
||||||
|
self.client.get(full_uri)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn post(&self, uri: &str) -> RequestBuilder {
|
||||||
|
let full_uri = format!("{}{}", self.host, uri);
|
||||||
|
self.client.post(full_uri)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn rsp_to_str(response: Response) -> String {
|
||||||
|
let status = response.status();
|
||||||
|
response.text().await.unwrap_or_else(|_| status.to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn check_response(&self, response: Response) -> Result<Response> {
|
||||||
|
let status_int: u16 = u16::from(response.status());
|
||||||
|
if (400..500).contains(&status_int) {
|
||||||
|
Err(Error::InvalidInput {
|
||||||
|
message: Self::rsp_to_str(response).await,
|
||||||
|
})
|
||||||
|
} else if status_int != 200 {
|
||||||
|
Err(Error::Runtime {
|
||||||
|
message: Self::rsp_to_str(response).await,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
Ok(response)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
102
rust/lancedb/src/remote/db.rs
Normal file
102
rust/lancedb/src/remote/db.rs
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
// Copyright 2024 LanceDB Developers.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use reqwest::header::CONTENT_TYPE;
|
||||||
|
use serde::Deserialize;
|
||||||
|
use tokio::task::spawn_blocking;
|
||||||
|
|
||||||
|
use crate::connection::{ConnectionInternal, CreateTableBuilder, OpenTableBuilder};
|
||||||
|
use crate::error::Result;
|
||||||
|
use crate::TableRef;
|
||||||
|
|
||||||
|
use super::client::RestfulLanceDbClient;
|
||||||
|
use super::table::RemoteTable;
|
||||||
|
use super::util::batches_to_ipc_bytes;
|
||||||
|
|
||||||
|
const ARROW_STREAM_CONTENT_TYPE: &str = "application/vnd.apache.arrow.stream";
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct ListTablesResponse {
|
||||||
|
tables: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct RemoteDatabase {
|
||||||
|
client: RestfulLanceDbClient,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RemoteDatabase {
|
||||||
|
pub fn try_new(
|
||||||
|
uri: &str,
|
||||||
|
api_key: &str,
|
||||||
|
region: &str,
|
||||||
|
host_override: Option<String>,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let client = RestfulLanceDbClient::try_new(uri, api_key, region, host_override)?;
|
||||||
|
Ok(Self { client })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl ConnectionInternal for RemoteDatabase {
|
||||||
|
async fn table_names(&self) -> Result<Vec<String>> {
|
||||||
|
let rsp = self
|
||||||
|
.client
|
||||||
|
.get("/v1/table/")
|
||||||
|
.query(&[("limit", 10)])
|
||||||
|
.query(&[("page_token", "")])
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
let rsp = self.client.check_response(rsp).await?;
|
||||||
|
Ok(rsp.json::<ListTablesResponse>().await?.tables)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn do_create_table(&self, options: CreateTableBuilder<true>) -> Result<TableRef> {
|
||||||
|
let data = options.data.unwrap();
|
||||||
|
// TODO: https://github.com/lancedb/lancedb/issues/1026
|
||||||
|
// We should accept data from an async source. In the meantime, spawn this as blocking
|
||||||
|
// to make sure we don't block the tokio runtime if the source is slow.
|
||||||
|
let data_buffer = spawn_blocking(move || batches_to_ipc_bytes(data))
|
||||||
|
.await
|
||||||
|
.unwrap()?;
|
||||||
|
|
||||||
|
self.client
|
||||||
|
.post(&format!("/v1/table/{}/create", options.name))
|
||||||
|
.body(data_buffer)
|
||||||
|
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE)
|
||||||
|
.header("x-request-id", "na")
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(Arc::new(RemoteTable::new(
|
||||||
|
self.client.clone(),
|
||||||
|
options.name,
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn do_open_table(&self, _options: OpenTableBuilder) -> Result<TableRef> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn drop_table(&self, _name: &str) -> Result<()> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn drop_db(&self) -> Result<()> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
}
|
||||||
89
rust/lancedb/src/remote/table.rs
Normal file
89
rust/lancedb/src/remote/table.rs
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
use arrow_array::RecordBatchReader;
|
||||||
|
use arrow_schema::SchemaRef;
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use lance::dataset::{ColumnAlteration, NewColumnTransform};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
error::Result,
|
||||||
|
index::IndexBuilder,
|
||||||
|
query::Query,
|
||||||
|
table::{
|
||||||
|
merge::MergeInsertBuilder, AddDataOptions, NativeTable, OptimizeAction, OptimizeStats,
|
||||||
|
},
|
||||||
|
Table,
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::client::RestfulLanceDbClient;
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct RemoteTable {
|
||||||
|
#[allow(dead_code)]
|
||||||
|
client: RestfulLanceDbClient,
|
||||||
|
name: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RemoteTable {
|
||||||
|
pub fn new(client: RestfulLanceDbClient, name: String) -> Self {
|
||||||
|
Self { client, name }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for RemoteTable {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
write!(f, "RemoteTable({})", self.name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Table for RemoteTable {
|
||||||
|
fn as_any(&self) -> &dyn std::any::Any {
|
||||||
|
self
|
||||||
|
}
|
||||||
|
fn as_native(&self) -> Option<&NativeTable> {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
&self.name
|
||||||
|
}
|
||||||
|
async fn schema(&self) -> Result<SchemaRef> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
async fn count_rows(&self, _filter: Option<String>) -> Result<usize> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
async fn add(
|
||||||
|
&self,
|
||||||
|
_batches: Box<dyn RecordBatchReader + Send>,
|
||||||
|
_options: AddDataOptions,
|
||||||
|
) -> Result<()> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
async fn delete(&self, _predicate: &str) -> Result<()> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
fn create_index(&self, _column: &[&str]) -> IndexBuilder {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
fn merge_insert(&self, _on: &[&str]) -> MergeInsertBuilder {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
fn query(&self) -> Query {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
async fn optimize(&self, _action: OptimizeAction) -> Result<OptimizeStats> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
async fn add_columns(
|
||||||
|
&self,
|
||||||
|
_transforms: NewColumnTransform,
|
||||||
|
_read_columns: Option<Vec<String>>,
|
||||||
|
) -> Result<()> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
async fn alter_columns(&self, _alterations: &[ColumnAlteration]) -> Result<()> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
async fn drop_columns(&self, _columns: &[&str]) -> Result<()> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
}
|
||||||
21
rust/lancedb/src/remote/util.rs
Normal file
21
rust/lancedb/src/remote/util.rs
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
use std::io::Cursor;
|
||||||
|
|
||||||
|
use arrow_array::RecordBatchReader;
|
||||||
|
|
||||||
|
use crate::Result;
|
||||||
|
|
||||||
|
pub fn batches_to_ipc_bytes(batches: impl RecordBatchReader) -> Result<Vec<u8>> {
|
||||||
|
const WRITE_BUF_SIZE: usize = 4096;
|
||||||
|
let buf = Vec::with_capacity(WRITE_BUF_SIZE);
|
||||||
|
let mut buf = Cursor::new(buf);
|
||||||
|
{
|
||||||
|
let mut writer = arrow_ipc::writer::FileWriter::try_new(&mut buf, &batches.schema())?;
|
||||||
|
|
||||||
|
for batch in batches {
|
||||||
|
let batch = batch?;
|
||||||
|
writer.write(&batch)?;
|
||||||
|
}
|
||||||
|
writer.finish()?;
|
||||||
|
}
|
||||||
|
Ok(buf.into_inner())
|
||||||
|
}
|
||||||
67
rust/lancedb/tests/lancedb_cloud.rs
Normal file
67
rust/lancedb/tests/lancedb_cloud.rs
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
// Copyright 2024 LanceDB Developers.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use arrow_array::RecordBatchIterator;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
#[ignore]
|
||||||
|
async fn cloud_integration_test() {
|
||||||
|
let project = std::env::var("LANCEDB_PROJECT")
|
||||||
|
.expect("the LANCEDB_PROJECT env must be set to run the cloud integration test");
|
||||||
|
let api_key = std::env::var("LANCEDB_API_KEY")
|
||||||
|
.expect("the LANCEDB_API_KEY env must be set to run the cloud integration test");
|
||||||
|
let region = std::env::var("LANCEDB_REGION")
|
||||||
|
.expect("the LANCEDB_REGION env must be set to run the cloud integration test");
|
||||||
|
let host_override = std::env::var("LANCEDB_HOST_OVERRIDE")
|
||||||
|
.map(Some)
|
||||||
|
.unwrap_or(None);
|
||||||
|
if host_override.is_none() {
|
||||||
|
println!("No LANCEDB_HOST_OVERRIDE has been set. Running integration test against LanceDb Cloud production instance");
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut builder = lancedb::connect(&format!("db://{}", project))
|
||||||
|
.api_key(&api_key)
|
||||||
|
.region(®ion);
|
||||||
|
if let Some(host_override) = &host_override {
|
||||||
|
builder = builder.host_override(host_override);
|
||||||
|
}
|
||||||
|
let db = builder.execute().await.unwrap();
|
||||||
|
|
||||||
|
let schema = Arc::new(arrow_schema::Schema::new(vec![
|
||||||
|
arrow_schema::Field::new("id", arrow_schema::DataType::Int64, false),
|
||||||
|
arrow_schema::Field::new("name", arrow_schema::DataType::Utf8, false),
|
||||||
|
]));
|
||||||
|
let initial_data = arrow::record_batch::RecordBatch::try_new(
|
||||||
|
schema.clone(),
|
||||||
|
vec![
|
||||||
|
Arc::new(arrow_array::Int64Array::from(vec![1, 2, 3])),
|
||||||
|
Arc::new(arrow_array::StringArray::from(vec!["a", "b", "c"])),
|
||||||
|
],
|
||||||
|
);
|
||||||
|
let rbr = RecordBatchIterator::new(vec![initial_data], schema);
|
||||||
|
|
||||||
|
let name = uuid::Uuid::new_v4().to_string();
|
||||||
|
let tbl = db
|
||||||
|
.create_table(name.clone(), Box::new(rbr))
|
||||||
|
.execute()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(tbl.name(), name);
|
||||||
|
|
||||||
|
let table_names = db.table_names().await.unwrap();
|
||||||
|
assert!(table_names.contains(&name));
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user