Mirror of https://github.com/lancedb/lancedb.git (synced 2025-12-23 05:19:58 +00:00)

Compare commits: 47 commits, python-v0. ... python-v0.
| Author | SHA1 | Date |
|---|---|---|
|  | a33a0670f6 |  |
|  | 14c9ff46d1 |  |
|  | 1865f7decf |  |
|  | a608621476 |  |
|  | 00514999ff |  |
|  | b3b597fef6 |  |
|  | bf17144591 |  |
|  | 09e110525f |  |
|  | 40f0dbb64d |  |
|  | 3b19e96ae7 |  |
|  | 78a17ad54c |  |
|  | a8e6b491e2 |  |
|  | cea541ca46 |  |
|  | 873ffc1042 |  |
|  | 83273ad997 |  |
|  | d18d63c69d |  |
|  | c3e865e8d0 |  |
|  | a7755cb313 |  |
|  | 3490f3456f |  |
|  | 0a1d0693e1 |  |
|  | fd330b4b4b |  |
|  | d4e9fc08e0 |  |
|  | 3626f2f5e1 |  |
|  | e64712cfa5 |  |
|  | 3e3118f85c |  |
|  | 592598a333 |  |
|  | 5ad21341c9 |  |
|  | 6e08caa091 |  |
|  | 7e259d8b0f |  |
|  | e84f747464 |  |
|  | 998cd43fe6 |  |
|  | 4bc7eebe61 |  |
|  | 2e3b34e79b |  |
|  | e7574698eb |  |
|  | 801a9e5f6f |  |
|  | 4e5fbe6c99 |  |
|  | 1a449fa49e |  |
|  | 6bf742c759 |  |
|  | ef3093bc23 |  |
|  | 16851389ea |  |
|  | c269524b2f |  |
|  | f6eef14313 |  |
|  | 32716adaa3 |  |
|  | 5e98b7f4c0 |  |
|  | 3f2589c11f |  |
|  | e3b99694d6 |  |
|  | 9d42dc349c |  |
@@ -1,5 +1,5 @@
 [tool.bumpversion]
-current_version = "0.15.1-beta.2"
+current_version = "0.16.1-beta.2"
 parse = """(?x)
     (?P<major>0|[1-9]\\d*)\\.
    (?P<minor>0|[1-9]\\d*)\\.
@@ -7,7 +7,7 @@ repos:
       - id: trailing-whitespace
   - repo: https://github.com/astral-sh/ruff-pre-commit
     # Ruff version.
-    rev: v0.2.2
+    rev: v0.8.4
    hooks:
       - id: ruff
   - repo: local
Cargo.lock (generated, 1001 changes): file diff suppressed because it is too large.
Cargo.toml (32 changes)
@@ -21,16 +21,16 @@ categories = ["database-implementations"]
|
||||
rust-version = "1.78.0"
|
||||
|
||||
[workspace.dependencies]
|
||||
lance = { "version" = "=0.23.0", "features" = [
|
||||
lance = { "version" = "=0.23.1", "features" = [
|
||||
"dynamodb",
|
||||
], git = "https://github.com/lancedb/lance.git", tag = "v0.23.0-beta.4" }
|
||||
lance-io = { version = "=0.23.0", git = "https://github.com/lancedb/lance.git", tag = "v0.23.0-beta.4" }
|
||||
lance-index = { version = "=0.23.0", git = "https://github.com/lancedb/lance.git", tag = "v0.23.0-beta.4" }
|
||||
lance-linalg = { version = "=0.23.0", git = "https://github.com/lancedb/lance.git", tag = "v0.23.0-beta.4" }
|
||||
lance-table = { version = "=0.23.0", git = "https://github.com/lancedb/lance.git", tag = "v0.23.0-beta.4" }
|
||||
lance-testing = { version = "=0.23.0", git = "https://github.com/lancedb/lance.git", tag = "v0.23.0-beta.4" }
|
||||
lance-datafusion = { version = "=0.23.0", git = "https://github.com/lancedb/lance.git", tag = "v0.23.0-beta.4" }
|
||||
lance-encoding = { version = "=0.23.0", git = "https://github.com/lancedb/lance.git", tag = "v0.23.0-beta.4" }
|
||||
], git = "https://github.com/lancedb/lance.git", tag = "v0.23.1-beta.4"}
|
||||
lance-io = {version = "=0.23.1", tag="v0.23.1-beta.4", git = "https://github.com/lancedb/lance.git"}
|
||||
lance-index = {version = "=0.23.1", tag="v0.23.1-beta.4", git = "https://github.com/lancedb/lance.git"}
|
||||
lance-linalg = {version = "=0.23.1", tag="v0.23.1-beta.4", git = "https://github.com/lancedb/lance.git"}
|
||||
lance-table = {version = "=0.23.1", tag="v0.23.1-beta.4", git = "https://github.com/lancedb/lance.git"}
|
||||
lance-testing = {version = "=0.23.1", tag="v0.23.1-beta.4", git = "https://github.com/lancedb/lance.git"}
|
||||
lance-datafusion = {version = "=0.23.1", tag="v0.23.1-beta.4", git = "https://github.com/lancedb/lance.git"}
|
||||
lance-encoding = {version = "=0.23.1", tag="v0.23.1-beta.4", git = "https://github.com/lancedb/lance.git"}
|
||||
# Note that this one does not include pyarrow
|
||||
arrow = { version = "53.2", optional = false }
|
||||
arrow-array = "53.2"
|
||||
@@ -42,18 +42,22 @@ arrow-arith = "53.2"
|
||||
arrow-cast = "53.2"
|
||||
async-trait = "0"
|
||||
chrono = "0.4.35"
|
||||
datafusion-common = "44.0"
|
||||
datafusion = { version = "44.0", default-features = false }
|
||||
datafusion-catalog = "44.0"
|
||||
datafusion-common = { version = "44.0", default-features = false }
|
||||
datafusion-execution = "44.0"
|
||||
datafusion-expr = "44.0"
|
||||
datafusion-physical-plan = "44.0"
|
||||
env_logger = "0.10"
|
||||
env_logger = "0.11"
|
||||
half = { "version" = "=2.4.1", default-features = false, features = [
|
||||
"num-traits",
|
||||
] }
|
||||
futures = "0"
|
||||
log = "0.4"
|
||||
moka = { version = "0.11", features = ["future"] }
|
||||
object_store = "0.10.2"
|
||||
moka = { version = "0.12", features = ["future"] }
|
||||
object_store = "0.11.0"
|
||||
pin-project = "1.0.7"
|
||||
snafu = "0.7.4"
|
||||
snafu = "0.8"
|
||||
url = "2"
|
||||
num-traits = "0.2"
|
||||
rand = "0.8"
|
||||
|
||||
@@ -38,6 +38,13 @@ components:
       required: true
       schema:
         type: string
+    index_name:
+      name: index_name
+      in: path
+      description: name of the index
+      required: true
+      schema:
+        type: string
   responses:
     invalid_request:
       description: Invalid request

@@ -485,3 +492,22 @@ paths:
           $ref: "#/components/responses/unauthorized"
         "404":
           $ref: "#/components/responses/not_found"
+  /v1/table/{name}/index/{index_name}/drop/:
+    post:
+      description: Drop an index from the table
+      tags:
+        - Tables
+      summary: Drop an index from the table
+      operationId: dropIndex
+      parameters:
+        - $ref: "#/components/parameters/table_name"
+        - $ref: "#/components/parameters/index_name"
+      responses:
+        "200":
+          description: Index successfully dropped
+        "400":
+          $ref: "#/components/responses/invalid_request"
+        "401":
+          $ref: "#/components/responses/unauthorized"
+        "404":
+          $ref: "#/components/responses/not_found"
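The new endpoint can be exercised directly over HTTP. A minimal TypeScript sketch follows; the base URL is a placeholder and any auth headers a real deployment requires are omitted:

```ts
// Hypothetical client call for the dropIndex operation defined above.
const BASE_URL = "https://example-lancedb-host"; // placeholder

async function dropIndex(tableName: string, indexName: string): Promise<void> {
  const res = await fetch(
    `${BASE_URL}/v1/table/${encodeURIComponent(tableName)}/index/${encodeURIComponent(indexName)}/drop/`,
    { method: "POST" },
  );
  if (res.status !== 200) {
    throw new Error(`dropIndex failed: ${res.status} ${res.statusText}`);
  }
}
```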
@@ -3,6 +3,7 @@ import * as vectordb from "vectordb";
|
||||
// --8<-- [end:import]
|
||||
|
||||
(async () => {
|
||||
console.log("ann_indexes.ts: start");
|
||||
// --8<-- [start:ingest]
|
||||
const db = await vectordb.connect("data/sample-lancedb");
|
||||
|
||||
@@ -49,5 +50,5 @@ import * as vectordb from "vectordb";
|
||||
.execute();
|
||||
// --8<-- [end:search3]
|
||||
|
||||
console.log("Ann indexes: done");
|
||||
console.log("ann_indexes.ts: done");
|
||||
})();
|
||||
|
||||
@@ -107,7 +107,6 @@ const example = async () => {
   // --8<-- [start:search]
   const query = await tbl.search([100, 100]).limit(2).execute();
   // --8<-- [end:search]
-  console.log(query);

   // --8<-- [start:delete]
   await tbl.delete('item = "fizz"');

@@ -119,8 +118,9 @@ const example = async () => {
 };

 async function main() {
+  console.log("basic_legacy.ts: start");
   await example();
-  console.log("Basic example: done");
+  console.log("basic_legacy.ts: done");
 }

 main();
@@ -131,6 +131,20 @@ Return a brief description of the connection

 ***

+### dropAllTables()
+
+```ts
+abstract dropAllTables(): Promise<void>
+```
+
+Drop all tables in the database.
+
+#### Returns
+
+`Promise`<`void`>
+
+***
+
 ### dropTable()

 ```ts
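A minimal usage sketch for the new method, assuming an `@lancedb/lancedb` connection (table name illustrative):

```ts
import { connect } from "@lancedb/lancedb";

(async () => {
  const db = await connect("data/sample-lancedb");
  await db.createTable("scratch", [{ id: 1 }]);
  // Removes every table in the database.
  await db.dropAllTables();
  console.log(await db.tableNames()); // []
})();
```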
@@ -22,8 +22,6 @@ when creating a table or adding data to it)
 This function converts an array of Record<String, any> (row-major JS objects)
 to an Arrow Table (a columnar structure)

-Note that it currently does not support nulls.
-
 If a schema is provided then it will be used to determine the resulting array
 types. Fields will also be reordered to fit the order defined by the schema.

@@ -31,6 +29,9 @@ If a schema is not provided then the types will be inferred and the field order
 will be controlled by the order of properties in the first record. If a type
 is inferred it will always be nullable.

+If not all fields are found in the data, then a subset of the schema will be
+returned.
+
 If the input is empty then a schema must be provided to create an empty table.

 When a schema is not specified then data types will be inferred. The inference

@@ -38,6 +39,7 @@ rules are as follows:

 - boolean => Bool
 - number => Float64
+- bigint => Int64
 - String => Utf8
 - Buffer => Binary
 - Record<String, any> => Struct
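A short sketch of these inference rules in action; the import path is illustrative, assuming `makeArrowTable` is exported by the package:

```ts
// Illustrative import; adjust to wherever makeArrowTable is exported from.
import { makeArrowTable } from "@lancedb/lancedb";

const table = makeArrowTable([
  { flag: true, score: 1.5, count: 7n, name: "a", vector: [1, 2, 3] },
]);
// Per the rules above: flag => Bool, score => Float64, count => Int64,
// name => Utf8, and `vector` defaults to FixedSizeList<Float32>.
console.log(table.schema.names);
```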
@@ -8,6 +8,14 @@

 ## Properties

+### extraHeaders?
+
+```ts
+optional extraHeaders: Record<string, string>;
+```
+
+***
+
 ### retryConfig?

 ```ts
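A connection sketch using the new option; the URI, API key, and header values are placeholders:

```ts
import { connect } from "@lancedb/lancedb";

(async () => {
  // extraHeaders is attached to every request the remote client makes.
  const db = await connect("db://my-database", {
    apiKey: "sk-...",
    clientConfig: {
      extraHeaders: { "x-my-header": "my-value" },
    },
  });
  console.log(await db.tableNames());
})();
```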
@@ -8,7 +8,7 @@

 ## Properties

-### dataStorageVersion?
+### ~~dataStorageVersion?~~

 ```ts
 optional dataStorageVersion: string;

@@ -19,6 +19,10 @@ The version of the data storage format to use.
 The default is `stable`.
 Set to "legacy" to use the old format.

+#### Deprecated
+
+Pass `new_table_data_storage_version` to storageOptions instead.
+
 ***

 ### embeddingFunction?

@@ -29,7 +33,7 @@ optional embeddingFunction: EmbeddingFunctionConfig;

 ***

-### enableV2ManifestPaths?
+### ~~enableV2ManifestPaths?~~

 ```ts
 optional enableV2ManifestPaths: boolean;

@@ -41,6 +45,10 @@ turning this on will make the dataset unreadable for older versions
 of LanceDB (prior to 0.10.0). To migrate an existing dataset, instead
 use the LocalTable#migrateManifestPathsV2 method.

+#### Deprecated
+
+Pass `new_table_enable_v2_manifest_paths` to storageOptions instead.
+
 ***

 ### existOk

@@ -90,17 +98,3 @@ Options already set on the connection will be inherited by the table,
 but can be overridden here.

 The available options are described at https://lancedb.github.io/lancedb/guides/storage/
-
-***
-
-### useLegacyFormat?
-
-```ts
-optional useLegacyFormat: boolean;
-```
-
-If true then data files will be written with the legacy format
-
-The default is false.
-
-Deprecated. Use data storage version instead.
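The replacement pattern for both deprecated properties, sketched with the option keys used in this release's test suite (`db` and `data` assumed in scope):

```ts
// Instead of dataStorageVersion / enableV2ManifestPaths, pass the
// equivalent storage options at table-creation time.
const table = await db.createTable("my_table", data, {
  storageOptions: {
    newTableDataStorageVersion: "stable", // or "legacy"
    newTableEnableV2ManifestPaths: "true",
  },
});
```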
@@ -20,6 +20,7 @@ async function setup() {
 }

 async () => {
+  console.log("search_legacy.ts: start");
   await setup();

   // --8<-- [start:search1]

@@ -37,5 +38,5 @@ async () => {
     .execute();
   // --8<-- [end:search2]

-  console.log("search: done");
+  console.log("search_legacy.ts: done");
 };
@@ -1,6 +1,7 @@
 import * as vectordb from "vectordb";

 (async () => {
+  console.log("sql_legacy.ts: start");
   const db = await vectordb.connect("data/sample-lancedb");

   let data = [];

@@ -34,5 +35,5 @@ import * as vectordb from "vectordb";
   await tbl.filter("id = 10").limit(10).execute();
   // --8<-- [end:sql_search]

-  console.log("SQL search: done");
+  console.log("sql_legacy.ts: done");
 })();
@@ -8,7 +8,7 @@
   <parent>
     <groupId>com.lancedb</groupId>
     <artifactId>lancedb-parent</artifactId>
-    <version>0.15.1-beta.2</version>
+    <version>0.16.1-beta.2</version>
     <relativePath>../pom.xml</relativePath>
   </parent>

@@ -6,7 +6,7 @@
   <groupId>com.lancedb</groupId>
   <artifactId>lancedb-parent</artifactId>
-  <version>0.15.1-beta.2</version>
+  <version>0.16.1-beta.2</version>
   <packaging>pom</packaging>

   <name>LanceDB Parent</name>
node/package-lock.json (generated, 124 changes)
@@ -1,12 +1,12 @@
 {
   "name": "vectordb",
-  "version": "0.15.1-beta.2",
+  "version": "0.16.1-beta.2",
   "lockfileVersion": 3,
   "requires": true,
   "packages": {
     "": {
       "name": "vectordb",
-      "version": "0.15.1-beta.2",
+      "version": "0.16.1-beta.2",
       "cpu": [
         "x64",
         "arm64"

@@ -52,14 +52,14 @@
         "uuid": "^9.0.0"
       },
       "optionalDependencies": {
-        "@lancedb/vectordb-darwin-arm64": "0.15.1-beta.2",
-        "@lancedb/vectordb-darwin-x64": "0.15.1-beta.2",
-        "@lancedb/vectordb-linux-arm64-gnu": "0.15.1-beta.2",
-        "@lancedb/vectordb-linux-arm64-musl": "0.15.1-beta.2",
-        "@lancedb/vectordb-linux-x64-gnu": "0.15.1-beta.2",
-        "@lancedb/vectordb-linux-x64-musl": "0.15.1-beta.2",
-        "@lancedb/vectordb-win32-arm64-msvc": "0.15.1-beta.2",
-        "@lancedb/vectordb-win32-x64-msvc": "0.15.1-beta.2"
+        "@lancedb/vectordb-darwin-arm64": "0.16.1-beta.2",
+        "@lancedb/vectordb-darwin-x64": "0.16.1-beta.2",
+        "@lancedb/vectordb-linux-arm64-gnu": "0.16.1-beta.2",
+        "@lancedb/vectordb-linux-arm64-musl": "0.16.1-beta.2",
+        "@lancedb/vectordb-linux-x64-gnu": "0.16.1-beta.2",
+        "@lancedb/vectordb-linux-x64-musl": "0.16.1-beta.2",
+        "@lancedb/vectordb-win32-arm64-msvc": "0.16.1-beta.2",
+        "@lancedb/vectordb-win32-x64-msvc": "0.16.1-beta.2"
       },
       "peerDependencies": {
         "@apache-arrow/ts": "^14.0.2",

@@ -329,110 +329,6 @@
         "@jridgewell/sourcemap-codec": "^1.4.10"
       }
     },
-    "node_modules/@lancedb/vectordb-darwin-arm64": {
-      "version": "0.15.1-beta.2",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.15.1-beta.2.tgz",
-      "integrity": "sha512-hq5VkIW7oP+S030+o14TfSnsanjtKNuYMtv4ANBAsV7Lwr/q7EKwMZLwrbLW4Y/1hrjTjdwZF2ePgd5UwXdq1w==",
-      "cpu": [
-        "arm64"
-      ],
-      "license": "Apache-2.0",
-      "optional": true,
-      "os": [
-        "darwin"
-      ]
-    },
-    "node_modules/@lancedb/vectordb-darwin-x64": {
-      "version": "0.15.1-beta.2",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.15.1-beta.2.tgz",
-      "integrity": "sha512-Xvms7y1PG52gDlbaWSotYjYhiT7oF9eF8T29H6wk5FA43Eu6fg1+wrrvMR8+7gDfwjsCN3kNabk9Cu4DDYtldg==",
-      "cpu": [
-        "x64"
-      ],
-      "license": "Apache-2.0",
-      "optional": true,
-      "os": [
-        "darwin"
-      ]
-    },
-    "node_modules/@lancedb/vectordb-linux-arm64-gnu": {
-      "version": "0.15.1-beta.2",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.15.1-beta.2.tgz",
-      "integrity": "sha512-jTJEan8TLTtly1cucWvQZrh3N7fuJWDBz+ADmHUMQbRH6SDUoGqIekgs0u1WrgYuLFIPAi6V1VI66qbCWncCLQ==",
-      "cpu": [
-        "arm64"
-      ],
-      "license": "Apache-2.0",
-      "optional": true,
-      "os": [
-        "linux"
-      ]
-    },
-    "node_modules/@lancedb/vectordb-linux-arm64-musl": {
-      "version": "0.15.1-beta.2",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-musl/-/vectordb-linux-arm64-musl-0.15.1-beta.2.tgz",
-      "integrity": "sha512-KRTcMTsMIdkYz2u1BZ5xPB8q1unyyKEXXQZSj4Sye0vP5lBm1vvH4IAwZ3BkxffULFDLHpPegGSnnq0xNBWp2g==",
-      "cpu": [
-        "arm64"
-      ],
-      "license": "Apache-2.0",
-      "optional": true,
-      "os": [
-        "linux"
-      ]
-    },
-    "node_modules/@lancedb/vectordb-linux-x64-gnu": {
-      "version": "0.15.1-beta.2",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.15.1-beta.2.tgz",
-      "integrity": "sha512-oILD1M8BYAgNmdIuK1jqueqVtdkAvXfpuMNMOiUq2NhUaa/vFxO6Tb2XR8tOWoE14/VJkWTLpk+nLhjgfiNNTA==",
-      "cpu": [
-        "x64"
-      ],
-      "license": "Apache-2.0",
-      "optional": true,
-      "os": [
-        "linux"
-      ]
-    },
-    "node_modules/@lancedb/vectordb-linux-x64-musl": {
-      "version": "0.15.1-beta.2",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-musl/-/vectordb-linux-x64-musl-0.15.1-beta.2.tgz",
-      "integrity": "sha512-LO3WE7UUDOJIAN4hifTpaX8tzjfXRtTghqCilUGei8uaZbOD2T1pJKT89Ub7AurNVr3NEWT4NYu/Fa7ilfkb6w==",
-      "cpu": [
-        "x64"
-      ],
-      "license": "Apache-2.0",
-      "optional": true,
-      "os": [
-        "linux"
-      ]
-    },
-    "node_modules/@lancedb/vectordb-win32-arm64-msvc": {
-      "version": "0.15.1-beta.2",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-arm64-msvc/-/vectordb-win32-arm64-msvc-0.15.1-beta.2.tgz",
-      "integrity": "sha512-5H52qrtS1GIif6HC3lvuilKpzKhPJNkkAB5R9bx7ly2EN3d+ZFF4fwdlXhyoUaHI0SUf7z00hzgIVHqWnDadCg==",
-      "cpu": [
-        "arm64"
-      ],
-      "license": "Apache-2.0",
-      "optional": true,
-      "os": [
-        "win32"
-      ]
-    },
-    "node_modules/@lancedb/vectordb-win32-x64-msvc": {
-      "version": "0.15.1-beta.2",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.15.1-beta.2.tgz",
-      "integrity": "sha512-9sqH6uw8WWgQSZPzHCV9Ncurlx2YMwr6dnQtQGQ/KTDm7n/V2bz9m0t9+07jhyVzfJIxBkR/UlDMC3U3M/uAuQ==",
-      "cpu": [
-        "x64"
-      ],
-      "license": "Apache-2.0",
-      "optional": true,
-      "os": [
-        "win32"
-      ]
-    },
     "node_modules/@neon-rs/cli": {
       "version": "0.0.160",
       "resolved": "https://registry.npmjs.org/@neon-rs/cli/-/cli-0.0.160.tgz",
@@ -1,6 +1,6 @@
 {
   "name": "vectordb",
-  "version": "0.15.1-beta.2",
+  "version": "0.16.1-beta.2",
   "description": " Serverless, low-latency vector database for AI applications",
   "private": false,
   "main": "dist/index.js",

@@ -92,13 +92,13 @@
     }
   },
   "optionalDependencies": {
-    "@lancedb/vectordb-darwin-x64": "0.15.1-beta.2",
-    "@lancedb/vectordb-darwin-arm64": "0.15.1-beta.2",
-    "@lancedb/vectordb-linux-x64-gnu": "0.15.1-beta.2",
-    "@lancedb/vectordb-linux-arm64-gnu": "0.15.1-beta.2",
-    "@lancedb/vectordb-linux-x64-musl": "0.15.1-beta.2",
-    "@lancedb/vectordb-linux-arm64-musl": "0.15.1-beta.2",
-    "@lancedb/vectordb-win32-x64-msvc": "0.15.1-beta.2",
-    "@lancedb/vectordb-win32-arm64-msvc": "0.15.1-beta.2"
+    "@lancedb/vectordb-darwin-x64": "0.16.1-beta.2",
+    "@lancedb/vectordb-darwin-arm64": "0.16.1-beta.2",
+    "@lancedb/vectordb-linux-x64-gnu": "0.16.1-beta.2",
+    "@lancedb/vectordb-linux-arm64-gnu": "0.16.1-beta.2",
+    "@lancedb/vectordb-linux-x64-musl": "0.16.1-beta.2",
+    "@lancedb/vectordb-linux-arm64-musl": "0.16.1-beta.2",
+    "@lancedb/vectordb-win32-x64-msvc": "0.16.1-beta.2",
+    "@lancedb/vectordb-win32-arm64-msvc": "0.16.1-beta.2"
   }
 }
@@ -47,7 +47,8 @@ const {
   tableSchema,
   tableAddColumns,
   tableAlterColumns,
-  tableDropColumns
+  tableDropColumns,
+  tableDropIndex
   // eslint-disable-next-line @typescript-eslint/no-var-requires
 } = require("../native.js");

@@ -604,6 +605,13 @@ export interface Table<T = number[]> {
   */
  dropColumns(columnNames: string[]): Promise<void>

+  /**
+   * Drop an index from the table
+   *
+   * @param indexName The name of the index to drop
+   */
+  dropIndex(indexName: string): Promise<void>
+
  /**
   * Instrument the behavior of this Table with middleware.
   *

@@ -1206,6 +1214,10 @@ export class LocalTable<T = number[]> implements Table<T> {
     return tableDropColumns.call(this._tbl, columnNames);
   }

+  async dropIndex(indexName: string): Promise<void> {
+    return tableDropIndex.call(this._tbl, indexName);
+  }
+
   withMiddleware(middleware: HttpMiddleware): Table<T> {
     return this;
   }
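Usage of the new method, sketched with the legacy `vectordb` package; the table and index names are illustrative, taken from the commented-out test further below:

```ts
import * as vectordb from "vectordb";

(async () => {
  const db = await vectordb.connect("data/sample-lancedb");
  const table = await db.openTable("vectors");
  // Remove a previously created vector index by name.
  await table.dropIndex("vector_idx");
})();
```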
@@ -471,6 +471,18 @@ export class RemoteTable<T = number[]> implements Table<T> {
|
||||
)
|
||||
}
|
||||
}
|
||||
async dropIndex (index_name: string): Promise<void> {
|
||||
const res = await this._client.post(
|
||||
`/v1/table/${encodeURIComponent(this._name)}/index/${encodeURIComponent(index_name)}/drop/`
|
||||
)
|
||||
if (res.status !== 200) {
|
||||
throw new Error(
|
||||
`Server Error, status: ${res.status}, ` +
|
||||
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
|
||||
`message: ${res.statusText}: ${await res.body()}`
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
async countRows (filter?: string): Promise<number> {
|
||||
const result = await this._client.post(`/v1/table/${encodeURIComponent(this._name)}/count_rows/`, {
|
||||
|
||||
@@ -894,6 +894,27 @@ describe("LanceDB client", function () {
|
||||
expect(stats.distanceType).to.equal("l2");
|
||||
expect(stats.numIndices).to.equal(1);
|
||||
}).timeout(50_000);
|
||||
|
||||
// not yet implemented
|
||||
// it("can drop index", async function () {
|
||||
// const uri = await createTestDB(32, 300);
|
||||
// const con = await lancedb.connect(uri);
|
||||
// const table = await con.openTable("vectors");
|
||||
// await table.createIndex({
|
||||
// type: "ivf_pq",
|
||||
// column: "vector",
|
||||
// num_partitions: 2,
|
||||
// max_iters: 2,
|
||||
// num_sub_vectors: 2
|
||||
// });
|
||||
//
|
||||
// const indices = await table.listIndices();
|
||||
// expect(indices).to.have.lengthOf(1);
|
||||
// expect(indices[0].name).to.equal("vector_idx");
|
||||
//
|
||||
// await table.dropIndex("vector_idx");
|
||||
// expect(await table.listIndices()).to.have.lengthOf(0);
|
||||
// }).timeout(50_000);
|
||||
});
|
||||
|
||||
describe("when using a custom embedding function", function () {
|
||||
|
||||
@@ -1,7 +1,7 @@
 [package]
 name = "lancedb-nodejs"
 edition.workspace = true
-version = "0.15.1-beta.2"
+version = "0.16.1-beta.2"
 license.workspace = true
 description.workspace = true
 repository.workspace = true
@@ -55,6 +55,7 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
     Float64,
     Struct,
     List,
+    Int16,
     Int32,
     Int64,
     Float,

@@ -108,13 +109,16 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
         false,
       ),
     ]);

     const table = (await tableCreationMethod(
-      records,
+      recordsReversed,
       schema,
       // biome-ignore lint/suspicious/noExplicitAny: <explanation>
     )) as any;

+    // We expect deterministic ordering of the fields
+    expect(table.schema.names).toEqual(schema.names);
+
     schema.fields.forEach(
       (
         // biome-ignore lint/suspicious/noExplicitAny: <explanation>

@@ -141,13 +145,13 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
   describe("The function makeArrowTable", function () {
     it("will use data types from a provided schema instead of inference", async function () {
       const schema = new Schema([
-        new Field("a", new Int32()),
-        new Field("b", new Float32()),
+        new Field("a", new Int32(), false),
+        new Field("b", new Float32(), true),
         new Field(
           "c",
           new FixedSizeList(3, new Field("item", new Float16())),
         ),
-        new Field("d", new Int64()),
+        new Field("d", new Int64(), true),
       ]);
       const table = makeArrowTable(
         [

@@ -165,12 +169,15 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
       expect(actual.numRows).toBe(3);
       const actualSchema = actual.schema;
       expect(actualSchema).toEqual(schema);
+      expect(table.getChild("a")?.toJSON()).toEqual([1, 4, 7]);
+      expect(table.getChild("b")?.toJSON()).toEqual([2, 5, 8]);
+      expect(table.getChild("d")?.toJSON()).toEqual([9n, 10n, null]);
     });

     it("will assume the column `vector` is FixedSizeList<Float32> by default", async function () {
       const schema = new Schema([
         new Field("a", new Float(Precision.DOUBLE), true),
-        new Field("b", new Float(Precision.DOUBLE), true),
+        new Field("b", new Int64(), true),
         new Field(
           "vector",
           new FixedSizeList(

@@ -181,9 +188,9 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
         ),
       ]);
       const table = makeArrowTable([
-        { a: 1, b: 2, vector: [1, 2, 3] },
-        { a: 4, b: 5, vector: [4, 5, 6] },
-        { a: 7, b: 8, vector: [7, 8, 9] },
+        { a: 1, b: 2n, vector: [1, 2, 3] },
+        { a: 4, b: 5n, vector: [4, 5, 6] },
+        { a: 7, b: 8n, vector: [7, 8, 9] },
       ]);

       const buf = await fromTableToBuffer(table);

@@ -193,6 +200,19 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
       expect(actual.numRows).toBe(3);
       const actualSchema = actual.schema;
       expect(actualSchema).toEqual(schema);
+
+      expect(table.getChild("a")?.toJSON()).toEqual([1, 4, 7]);
+      expect(table.getChild("b")?.toJSON()).toEqual([2n, 5n, 8n]);
+      expect(
+        table
+          .getChild("vector")
+          ?.toJSON()
+          .map((v) => v.toJSON()),
+      ).toEqual([
+        [1, 2, 3],
+        [4, 5, 6],
+        [7, 8, 9],
+      ]);
     });

     it("can support multiple vector columns", async function () {

@@ -206,7 +226,7 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
         ),
         new Field(
           "vec2",
-          new FixedSizeList(3, new Field("item", new Float16(), true)),
+          new FixedSizeList(3, new Field("item", new Float64(), true)),
           true,
         ),
       ]);

@@ -219,7 +239,7 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
       {
         vectorColumns: {
           vec1: { type: new Float16() },
-          vec2: { type: new Float16() },
+          vec2: { type: new Float64() },
         },
       },
     );

@@ -307,6 +327,53 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
       false,
     );
   });

+  it("will allow subsets of columns if nullable", async function () {
+    const schema = new Schema([
+      new Field("a", new Int64(), true),
+      new Field(
+        "s",
+        new Struct([
+          new Field("x", new Int32(), true),
+          new Field("y", new Int32(), true),
+        ]),
+        true,
+      ),
+      new Field("d", new Int16(), true),
+    ]);
+
+    const table = makeArrowTable([{ a: 1n }], { schema });
+    expect(table.numCols).toBe(1);
+    expect(table.numRows).toBe(1);
+
+    const table2 = makeArrowTable([{ a: 1n, d: 2 }], { schema });
+    expect(table2.numCols).toBe(2);
+
+    const table3 = makeArrowTable([{ s: { y: 3 } }], { schema });
+    expect(table3.numCols).toBe(1);
+    const expectedSchema = new Schema([
+      new Field("s", new Struct([new Field("y", new Int32(), true)]), true),
+    ]);
+    expect(table3.schema).toEqual(expectedSchema);
+  });
+
+  it("will work even if columns are sparsely provided", async function () {
+    const sparseRecords = [{ a: 1n }, { b: 2n }, { c: 3n }, { d: 4n }];
+    const table = makeArrowTable(sparseRecords);
+    expect(table.numCols).toBe(4);
+    expect(table.numRows).toBe(4);
+
+    const schema = new Schema([
+      new Field("a", new Int64(), true),
+      new Field("b", new Int32(), true),
+      new Field("c", new Int64(), true),
+      new Field("d", new Int16(), true),
+    ]);
+    const table2 = makeArrowTable(sparseRecords, { schema });
+    expect(table2.numCols).toBe(4);
+    expect(table2.numRows).toBe(4);
+    expect(table2.schema).toEqual(schema);
+  });
 });

 class DummyEmbedding extends EmbeddingFunction<string> {
@@ -17,14 +17,14 @@ describe("when connecting", () => {
|
||||
it("should connect", async () => {
|
||||
const db = await connect(tmpDir.name);
|
||||
expect(db.display()).toBe(
|
||||
`NativeDatabase(uri=${tmpDir.name}, read_consistency_interval=None)`,
|
||||
`ListingDatabase(uri=${tmpDir.name}, read_consistency_interval=None)`,
|
||||
);
|
||||
});
|
||||
|
||||
it("should allow read consistency interval to be specified", async () => {
|
||||
const db = await connect(tmpDir.name, { readConsistencyInterval: 5 });
|
||||
expect(db.display()).toBe(
|
||||
`NativeDatabase(uri=${tmpDir.name}, read_consistency_interval=5s)`,
|
||||
`ListingDatabase(uri=${tmpDir.name}, read_consistency_interval=5s)`,
|
||||
);
|
||||
});
|
||||
});
|
||||
@@ -61,6 +61,26 @@ describe("given a connection", () => {
|
||||
await expect(tbl.countRows()).resolves.toBe(1);
|
||||
});
|
||||
|
||||
it("should be able to drop tables`", async () => {
|
||||
await db.createTable("test", [{ id: 1 }, { id: 2 }]);
|
||||
await db.createTable("test2", [{ id: 1 }, { id: 2 }]);
|
||||
await db.createTable("test3", [{ id: 1 }, { id: 2 }]);
|
||||
|
||||
await expect(db.tableNames()).resolves.toEqual(["test", "test2", "test3"]);
|
||||
|
||||
await db.dropTable("test2");
|
||||
|
||||
await expect(db.tableNames()).resolves.toEqual(["test", "test3"]);
|
||||
|
||||
await db.dropAllTables();
|
||||
|
||||
await expect(db.tableNames()).resolves.toEqual([]);
|
||||
|
||||
// Make sure we can still create more tables after dropping all
|
||||
|
||||
await db.createTable("test4", [{ id: 1 }, { id: 2 }]);
|
||||
});
|
||||
|
||||
it("should fail if creating table twice, unless overwrite is true", async () => {
|
||||
let tbl = await db.createTable("test", [{ id: 1 }, { id: 2 }]);
|
||||
await expect(tbl.countRows()).resolves.toBe(2);
|
||||
@@ -96,14 +116,15 @@ describe("given a connection", () => {
|
||||
const data = [...Array(10000).keys()].map((i) => ({ id: i }));
|
||||
|
||||
// Create in v1 mode
|
||||
let table = await db.createTable("test", data, { useLegacyFormat: true });
|
||||
let table = await db.createTable("test", data, {
|
||||
storageOptions: { newTableDataStorageVersion: "legacy" },
|
||||
});
|
||||
|
||||
const isV2 = async (table: Table) => {
|
||||
const data = await table
|
||||
.query()
|
||||
.limit(10000)
|
||||
.toArrow({ maxBatchLength: 100000 });
|
||||
console.log(data.batches.length);
|
||||
return data.batches.length < 5;
|
||||
};
|
||||
|
||||
@@ -122,7 +143,7 @@ describe("given a connection", () => {
|
||||
const schema = new Schema([new Field("id", new Float64(), true)]);
|
||||
|
||||
table = await db.createEmptyTable("test_v2_empty", schema, {
|
||||
useLegacyFormat: false,
|
||||
storageOptions: { newTableDataStorageVersion: "stable" },
|
||||
});
|
||||
|
||||
await table.add(data);
|
||||
|
||||
@@ -104,4 +104,26 @@ describe("remote connection", () => {
|
||||
},
|
||||
);
|
||||
});
|
||||
|
||||
it("should pass on requested extra headers", async () => {
|
||||
await withMockDatabase(
|
||||
(req, res) => {
|
||||
expect(req.headers["x-my-header"]).toEqual("my-value");
|
||||
|
||||
const body = JSON.stringify({ tables: [] });
|
||||
res.writeHead(200, { "Content-Type": "application/json" }).end(body);
|
||||
},
|
||||
async (db) => {
|
||||
const tableNames = await db.tableNames();
|
||||
expect(tableNames).toEqual([]);
|
||||
},
|
||||
{
|
||||
clientConfig: {
|
||||
extraHeaders: {
|
||||
"x-my-header": "my-value",
|
||||
},
|
||||
},
|
||||
},
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -253,6 +253,31 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
     const arrowTbl = await table.toArrow();
     expect(arrowTbl).toBeInstanceOf(ArrowTable);
   });

+  it("should be able to handle missing fields", async () => {
+    const schema = new arrow.Schema([
+      new arrow.Field("id", new arrow.Int32(), true),
+      new arrow.Field("y", new arrow.Int32(), true),
+      new arrow.Field("z", new arrow.Int64(), true),
+    ]);
+    const db = await connect(tmpDir.name);
+    const table = await db.createEmptyTable("testNull", schema);
+    await table.add([{ id: 1, y: 2 }]);
+    await table.add([{ id: 2 }]);
+
+    await table
+      .mergeInsert("id")
+      .whenNotMatchedInsertAll()
+      .execute([
+        { id: 3, z: 3 },
+        { id: 4, z: 5 },
+      ]);
+
+    const res = await table.query().toArrow();
+    expect(res.getChild("id")?.toJSON()).toEqual([1, 2, 3, 4]);
+    expect(res.getChild("y")?.toJSON()).toEqual([2, null, null, null]);
+    expect(res.getChild("z")?.toJSON()).toEqual([null, null, 3n, 5n]);
+  });
 },
);
@@ -42,4 +42,4 @@ test("full text search", async () => {
|
||||
expect(result.length).toBe(10);
|
||||
// --8<-- [end:full_text_search]
|
||||
});
|
||||
});
|
||||
}, 10_000);
|
||||
|
||||
@@ -2,31 +2,37 @@
 // SPDX-FileCopyrightText: Copyright The LanceDB Authors

 import {
   Data as ArrowData,
   Table as ArrowTable,
   Binary,
   Bool,
   BufferType,
   DataType,
   Dictionary,
   Field,
   FixedSizeBinary,
   FixedSizeList,
   Float,
   Float32,
   Float64,
   Int,
   Int32,
   Int64,
   LargeBinary,
   List,
   Null,
   RecordBatch,
   RecordBatchFileReader,
   RecordBatchFileWriter,
   RecordBatchReader,
   RecordBatchStreamWriter,
   Schema,
   Struct,
   Utf8,
   Vector,
   makeVector as arrowMakeVector,
   makeBuilder,
   makeData,
-  type makeTable,
+  makeTable,
   vectorFromArray,
 } from "apache-arrow";
 import { Buffers } from "apache-arrow/data";
@@ -236,8 +242,6 @@ export class MakeArrowTableOptions {
 * This function converts an array of Record<String, any> (row-major JS objects)
 * to an Arrow Table (a columnar structure)
 *
-* Note that it currently does not support nulls.
-*
 * If a schema is provided then it will be used to determine the resulting array
 * types. Fields will also be reordered to fit the order defined by the schema.
 *

@@ -245,6 +249,9 @@ export class MakeArrowTableOptions {
 * will be controlled by the order of properties in the first record. If a type
 * is inferred it will always be nullable.
 *
+* If not all fields are found in the data, then a subset of the schema will be
+* returned.
+*
 * If the input is empty then a schema must be provided to create an empty table.
 *
 * When a schema is not specified then data types will be inferred. The inference

@@ -252,6 +259,7 @@ export class MakeArrowTableOptions {
 *
 * - boolean => Bool
 * - number => Float64
+* - bigint => Int64
 * - String => Utf8
 * - Buffer => Binary
 * - Record<String, any> => Struct
@@ -322,126 +330,316 @@ export function makeArrowTable(
  options?: Partial<MakeArrowTableOptions>,
  metadata?: Map<string, string>,
): ArrowTable {
  const opt = new MakeArrowTableOptions(options !== undefined ? options : {});
  let schema: Schema | undefined = undefined;
  if (opt.schema !== undefined && opt.schema !== null) {
    schema = sanitizeSchema(opt.schema);
    schema = validateSchemaEmbeddings(
      schema as Schema,
      data,
      options?.embeddingFunction,
    );
  }

  let schemaMetadata = schema?.metadata || new Map<string, string>();
  if (metadata !== undefined) {
    schemaMetadata = new Map([...schemaMetadata, ...metadata]);
  }

  if (
    data.length === 0 &&
    (options?.schema === undefined || options?.schema === null)
  ) {
    throw new Error("At least one record or a schema needs to be provided");
  }

  const opt = new MakeArrowTableOptions(options !== undefined ? options : {});
  if (opt.schema !== undefined && opt.schema !== null) {
    opt.schema = sanitizeSchema(opt.schema);
    opt.schema = validateSchemaEmbeddings(
      opt.schema as Schema,
      data,
      options?.embeddingFunction,
    );
  }
  const columns: Record<string, Vector> = {};
  // TODO: sample dataset to find missing columns
  // Prefer the field ordering of the schema, if present
  const columnNames =
    opt.schema != null ? (opt.schema.names as string[]) : Object.keys(data[0]);
  for (const colName of columnNames) {
    if (
      data.length !== 0 &&
      !Object.prototype.hasOwnProperty.call(data[0], colName)
    ) {
      // The field is present in the schema, but not in the data, skip it
      continue;
    }
    // Extract a single column from the records (transpose from row-major to col-major)
    let values = data.map((datum) => datum[colName]);

    // By default (type === undefined) arrow will infer the type from the JS type
    let type;
    if (opt.schema !== undefined) {
      // If there is a schema provided, then use that for the type instead
      type = opt.schema?.fields.filter((f) => f.name === colName)[0]?.type;
      if (DataType.isInt(type) && type.bitWidth === 64) {
        // wrap in BigInt to avoid bug: https://github.com/apache/arrow/issues/40051
        values = values.map((v) => {
          if (v === null) {
            return v;
          }
          if (typeof v === "bigint") {
            return v;
          }
          if (typeof v === "number") {
            return BigInt(v);
          }
          throw new Error(
            `Expected BigInt or number for column ${colName}, got ${typeof v}`,
          );
        });
      }
    } else if (data.length === 0) {
      if (schema === undefined) {
        throw new Error("A schema must be provided if data is empty");
      } else {
        // Otherwise, check to see if this column is one of the vector columns
        // defined by opt.vectorColumns and, if so, use the fixed size list type
        const vectorColumnOptions = opt.vectorColumns[colName];
        if (vectorColumnOptions !== undefined) {
          const firstNonNullValue = values.find((v) => v !== null);
          if (Array.isArray(firstNonNullValue)) {
            type = newVectorType(
              firstNonNullValue.length,
              vectorColumnOptions.type,
            );
            schema = new Schema(schema.fields, schemaMetadata);
            return new ArrowTable(schema);
          }
        }

  let inferredSchema = inferSchema(data, schema, opt);
  inferredSchema = new Schema(inferredSchema.fields, schemaMetadata);

  const finalColumns: Record<string, Vector> = {};
  for (const field of inferredSchema.fields) {
    finalColumns[field.name] = transposeData(data, field);
  }

  return new ArrowTable(inferredSchema, finalColumns);
}

function inferSchema(
  data: Array<Record<string, unknown>>,
  schema: Schema | undefined,
  opts: MakeArrowTableOptions,
): Schema {
  // We will collect all fields we see in the data.
  const pathTree = new PathTree<DataType>();

  for (const [rowI, row] of data.entries()) {
    for (const [path, value] of rowPathsAndValues(row)) {
      if (!pathTree.has(path)) {
        // First time seeing this field.
        if (schema !== undefined) {
          const field = getFieldForPath(schema, path);
          if (field === undefined) {
            throw new Error(
              `Found field not in schema: ${path.join(".")} at row ${rowI}`,
            );
          } else {
            pathTree.set(path, field.type);
          }
        } else {
          throw new Error(
            `Column ${colName} is expected to be a vector column but first non-null value is not an array. Could not determine size of vector column`,
          );
          const inferredType = inferType(value, path, opts);
          if (inferredType === undefined) {
            throw new Error(`Failed to infer data type for field ${path.join(".")} at row ${rowI}. \
Consider providing an explicit schema.`);
          }
          pathTree.set(path, inferredType);
        }
      } else if (schema === undefined) {
        const currentType = pathTree.get(path);
        const newType = inferType(value, path, opts);
        if (currentType !== newType) {
          new Error(`Failed to infer schema for data. Previously inferred type \
${currentType} but found ${newType} at row ${rowI}. Consider \
providing an explicit schema.`);
        }
      }
    }

  try {
    // Convert an Array of JS values to an arrow vector
    columns[colName] = makeVector(values, type, opt.dictionaryEncodeStrings);
  } catch (error: unknown) {
    // eslint-disable-next-line @typescript-eslint/restrict-template-expressions
    throw Error(`Could not convert column "${colName}" to Arrow: ${error}`);
  }
}

if (opt.schema != null) {
  // `new ArrowTable(columns)` infers a schema which may sometimes have
  // incorrect nullability (it assumes nullable=true always)
  //
  // `new ArrowTable(schema, columns)` will also fail because it will create a
  // batch with an inferred schema and then complain that the batch schema
  // does not match the provided schema.
  //
  // To work around this we first create a table with the wrong schema and
  // then patch the schema of the batches so we can use
  // `new ArrowTable(schema, batches)` which does not do any schema inference
  const firstTable = new ArrowTable(columns);
  const batchesFixed = firstTable.batches.map(
    (batch) => new RecordBatch(opt.schema as Schema, batch.data),
  );
  let schema: Schema;
  if (metadata !== undefined) {
    let schemaMetadata = opt.schema.metadata;
    if (schemaMetadata.size === 0) {
      schemaMetadata = metadata;
    } else {
      for (const [key, entry] of schemaMetadata.entries()) {
        schemaMetadata.set(key, entry);

  if (schema === undefined) {
    function fieldsFromPathTree(pathTree: PathTree<DataType>): Field[] {
      const fields = [];
      for (const [name, value] of pathTree.map.entries()) {
        if (value instanceof PathTree) {
          const children = fieldsFromPathTree(value);
          fields.push(new Field(name, new Struct(children), true));
        } else {
          fields.push(new Field(name, value, true));
        }
      }
      return fields;
    }
    const fields = fieldsFromPathTree(pathTree);
    return new Schema(fields);
  } else {
    function takeMatchingFields(
      fields: Field[],
      pathTree: PathTree<DataType>,
    ): Field[] {
      const outFields = [];
      for (const field of fields) {
        if (pathTree.map.has(field.name)) {
          const value = pathTree.get([field.name]);
          if (value instanceof PathTree) {
            const struct = field.type as Struct;
            const children = takeMatchingFields(struct.children, value);
            outFields.push(
              new Field(field.name, new Struct(children), field.nullable),
            );
          } else {
            outFields.push(
              new Field(field.name, value as DataType, field.nullable),
            );
          }
        }
      }
      return outFields;
    }
    const fields = takeMatchingFields(schema.fields, pathTree);
    return new Schema(fields);
  }
}

  schema = new Schema(opt.schema.fields as Field[], schemaMetadata);

function* rowPathsAndValues(
  row: Record<string, unknown>,
  basePath: string[] = [],
): Generator<[string[], unknown]> {
  for (const [key, value] of Object.entries(row)) {
    if (isObject(value)) {
      yield* rowPathsAndValues(value, [...basePath, key]);
    } else {
      schema = opt.schema as Schema;
      yield [[...basePath, key], value];
    }
    return new ArrowTable(schema, batchesFixed);
  }
  const tbl = new ArrowTable(columns);
  if (metadata !== undefined) {
    // biome-ignore lint/suspicious/noExplicitAny: <explanation>
    (<any>tbl.schema).metadata = metadata;
  }
}

function isObject(value: unknown): value is Record<string, unknown> {
  return (
    typeof value === "object" &&
    value !== null &&
    !Array.isArray(value) &&
    !(value instanceof RegExp) &&
    !(value instanceof Date) &&
    !(value instanceof Set) &&
    !(value instanceof Map) &&
    !(value instanceof Buffer)
  );
}

function getFieldForPath(schema: Schema, path: string[]): Field | undefined {
  let current: Field | Schema = schema;
  for (const key of path) {
    if (current instanceof Schema) {
      const field: Field | undefined = current.fields.find(
        (f) => f.name === key,
      );
      if (field === undefined) {
        return undefined;
      }
      current = field;
    } else if (current instanceof Field && DataType.isStruct(current.type)) {
      const struct: Struct = current.type;
      const field = struct.children.find((f) => f.name === key);
      if (field === undefined) {
        return undefined;
      }
      current = field;
    } else {
      return undefined;
    }
  }
  if (current instanceof Field) {
    return current;
  } else {
    return undefined;
  }
}

/**
 * Try to infer which Arrow type to use for a given value.
 *
 * May return undefined if the type cannot be inferred.
 */
function inferType(
  value: unknown,
  path: string[],
  opts: MakeArrowTableOptions,
): DataType | undefined {
  if (typeof value === "bigint") {
    return new Int64();
  } else if (typeof value === "number") {
    // Even if it's an integer, it's safer to assume Float64. Users can
    // always provide an explicit schema or use BigInt if they mean integer.
    return new Float64();
  } else if (typeof value === "string") {
    if (opts.dictionaryEncodeStrings) {
      return new Dictionary(new Utf8(), new Int32());
    } else {
      return new Utf8();
    }
  } else if (typeof value === "boolean") {
    return new Bool();
  } else if (value instanceof Buffer) {
    return new Binary();
  } else if (Array.isArray(value)) {
    if (value.length === 0) {
      return undefined; // Without any values we can't infer the type
    }
    if (path.length === 1 && Object.hasOwn(opts.vectorColumns, path[0])) {
      const floatType = sanitizeType(opts.vectorColumns[path[0]].type);
      return new FixedSizeList(
        value.length,
        new Field("item", floatType, true),
      );
    }
    const valueType = inferType(value[0], path, opts);
    if (valueType === undefined) {
      return undefined;
    }
    // Try to automatically detect embedding columns.
    if (valueType instanceof Float && path[path.length - 1] === "vector") {
      // We default to Float32 for vectors.
      const child = new Field("item", new Float32(), true);
      return new FixedSizeList(value.length, child);
    } else {
      const child = new Field("item", valueType, true);
      return new List(child);
    }
  } else {
    // TODO: timestamp
    return undefined;
  }
}

class PathTree<V> {
  map: Map<string, V | PathTree<V>>;

  constructor(entries?: [string[], V][]) {
    this.map = new Map();
    if (entries !== undefined) {
      for (const [path, value] of entries) {
        this.set(path, value);
      }
    }
  }
  has(path: string[]): boolean {
    let ref: PathTree<V> = this;
    for (const part of path) {
      if (!(ref instanceof PathTree) || !ref.map.has(part)) {
        return false;
      }
      ref = ref.map.get(part) as PathTree<V>;
    }
    return true;
  }
  get(path: string[]): V | undefined {
    let ref: PathTree<V> = this;
    for (const part of path) {
      if (!(ref instanceof PathTree) || !ref.map.has(part)) {
        return undefined;
      }
      ref = ref.map.get(part) as PathTree<V>;
    }
    return ref as V;
  }
  set(path: string[], value: V): void {
    let ref: PathTree<V> = this;
    for (const part of path.slice(0, path.length - 1)) {
      if (!ref.map.has(part)) {
        ref.map.set(part, new PathTree<V>());
      }
      ref = ref.map.get(part) as PathTree<V>;
    }
    ref.map.set(path[path.length - 1], value);
  }
}

function transposeData(
  data: Record<string, unknown>[],
  field: Field,
  path: string[] = [],
): Vector {
  if (field.type instanceof Struct) {
    const childFields = field.type.children;
    const childVectors = childFields.map((child) => {
      return transposeData(data, child, [...path, child.name]);
    });
    const structData = makeData({
      type: field.type,
      children: childVectors as unknown as ArrowData<DataType>[],
    });
    return arrowMakeVector(structData);
  } else {
    const valuesPath = [...path, field.name];
    const values = data.map((datum) => {
      let current: unknown = datum;
      for (const key of valuesPath) {
        if (isObject(current) && Object.hasOwn(current, key)) {
          current = current[key];
        } else {
          return null;
        }
      }
      return current;
    });
    return makeVector(values, field.type);
  }
  return tbl;
}

/**
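To illustrate the traversal the new inferSchema relies on, here is a self-contained sketch (not the library code) of rowPathsAndValues-style flattening of nested records into [path, value] pairs:

```ts
// Standalone sketch of nested-record flattening; plain objects recurse,
// everything else (including arrays) is treated as a leaf value.
function* pathsAndValues(
  row: Record<string, unknown>,
  base: string[] = [],
): Generator<[string[], unknown]> {
  for (const [key, value] of Object.entries(row)) {
    if (typeof value === "object" && value !== null && !Array.isArray(value)) {
      yield* pathsAndValues(value as Record<string, unknown>, [...base, key]);
    } else {
      yield [[...base, key], value];
    }
  }
}

for (const [path, value] of pathsAndValues({ a: 1n, s: { y: 3 } })) {
  console.log(path.join("."), value); // "a 1n" then "s.y 3"
}
```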
@@ -491,6 +689,31 @@ function makeVector(
 ): Vector<any> {
   if (type !== undefined) {
     // No need for inference, let Arrow create it
+    if (type instanceof Int) {
+      if (DataType.isInt(type) && type.bitWidth === 64) {
+        // wrap in BigInt to avoid bug: https://github.com/apache/arrow/issues/40051
+        values = values.map((v) => {
+          if (v === null) {
+            return v;
+          } else if (typeof v === "bigint") {
+            return v;
+          } else if (typeof v === "number") {
+            return BigInt(v);
+          } else {
+            return v;
+          }
+        });
+      } else {
+        // Similarly, bigint isn't supported for 16 or 32-bit ints.
+        values = values.map((v) => {
+          if (typeof v == "bigint") {
+            return Number(v);
+          } else {
+            return v;
+          }
+        });
+      }
+    }
     return vectorFromArray(values, type);
   }
   if (values.length === 0) {
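The same coercion shown standalone with apache-arrow (values illustrative):

```ts
import { Int64, vectorFromArray } from "apache-arrow";

// Mirrors the branch above: 64-bit integer columns accept plain JS numbers
// by converting them to BigInt before the vector is built.
const raw: Array<number | bigint | null> = [1, 2n, null];
const coerced = raw.map((v) => (typeof v === "number" ? BigInt(v) : v));
const vec = vectorFromArray(coerced, new Int64());
console.log(vec.toJSON()); // [1n, 2n, null]
```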
@@ -902,7 +1125,7 @@ function validateSchemaEmbeddings(
   schema: Schema,
   data: Array<Record<string, unknown>>,
   embeddings: EmbeddingFunctionConfig | undefined,
-) {
+): Schema {
   const fields = [];
   const missingEmbeddingFields = [];
@@ -52,6 +52,8 @@ export interface CreateTableOptions {
   *
   * The default is `stable`.
   * Set to "legacy" to use the old format.
+  *
+  * @deprecated Pass `new_table_data_storage_version` to storageOptions instead.
   */
  dataStorageVersion?: string;

@@ -61,17 +63,11 @@ export interface CreateTableOptions {
   * turning this on will make the dataset unreadable for older versions
   * of LanceDB (prior to 0.10.0). To migrate an existing dataset, instead
   * use the {@link LocalTable#migrateManifestPathsV2} method.
+  *
+  * @deprecated Pass `new_table_enable_v2_manifest_paths` to storageOptions instead.
   */
  enableV2ManifestPaths?: boolean;

-  /**
-   * If true then data files will be written with the legacy format
-   *
-   * The default is false.
-   *
-   * Deprecated. Use data storage version instead.
-   */
-  useLegacyFormat?: boolean;
  schema?: SchemaLike;
  embeddingFunction?: EmbeddingFunctionConfig;
 }

@@ -215,6 +211,11 @@ export abstract class Connection {
   * @param {string} name The name of the table to drop.
   */
  abstract dropTable(name: string): Promise<void>;

+  /**
+   * Drop all tables in the database.
+   */
+  abstract dropAllTables(): Promise<void>;
 }

 /** @hideconstructor */
@@ -256,6 +257,28 @@ export class LocalConnection extends Connection {
     return new LocalTable(innerTable);
   }

+  private getStorageOptions(
+    options?: Partial<CreateTableOptions>,
+  ): Record<string, string> | undefined {
+    if (options?.dataStorageVersion !== undefined) {
+      if (options.storageOptions === undefined) {
+        options.storageOptions = {};
+      }
+      options.storageOptions["newTableDataStorageVersion"] =
+        options.dataStorageVersion;
+    }
+
+    if (options?.enableV2ManifestPaths !== undefined) {
+      if (options.storageOptions === undefined) {
+        options.storageOptions = {};
+      }
+      options.storageOptions["newTableEnableV2ManifestPaths"] =
+        options.enableV2ManifestPaths ? "true" : "false";
+    }
+
+    return cleanseStorageOptions(options?.storageOptions);
+  }
+
   async createTable(
     nameOrOptions:
       | string

@@ -272,20 +295,14 @@ export class LocalConnection extends Connection {
       throw new Error("data is required");
     }
     const { buf, mode } = await parseTableData(data, options);
-    let dataStorageVersion = "stable";
-    if (options?.dataStorageVersion !== undefined) {
-      dataStorageVersion = options.dataStorageVersion;
-    } else if (options?.useLegacyFormat !== undefined) {
-      dataStorageVersion = options.useLegacyFormat ? "legacy" : "stable";
-    }
+    const storageOptions = this.getStorageOptions(options);

     const innerTable = await this.inner.createTable(
       nameOrOptions,
       buf,
       mode,
-      cleanseStorageOptions(options?.storageOptions),
-      dataStorageVersion,
-      options?.enableV2ManifestPaths,
+      storageOptions,
     );

     return new LocalTable(innerTable);

@@ -309,22 +326,14 @@ export class LocalConnection extends Connection {
       metadata = registry.getTableMetadata([embeddingFunction]);
     }

-    let dataStorageVersion = "stable";
-    if (options?.dataStorageVersion !== undefined) {
-      dataStorageVersion = options.dataStorageVersion;
-    } else if (options?.useLegacyFormat !== undefined) {
-      dataStorageVersion = options.useLegacyFormat ? "legacy" : "stable";
-    }
+    const storageOptions = this.getStorageOptions(options);
     const table = makeEmptyTable(schema, metadata);
     const buf = await fromTableToBuffer(table);
     const innerTable = await this.inner.createEmptyTable(
       name,
       buf,
       mode,
-      cleanseStorageOptions(options?.storageOptions),
-      dataStorageVersion,
-      options?.enableV2ManifestPaths,
+      storageOptions,
     );
     return new LocalTable(innerTable);
   }

@@ -332,6 +341,10 @@ export class LocalConnection extends Connection {
   async dropTable(name: string): Promise<void> {
     return this.inner.dropTable(name);
   }

+  async dropAllTables(): Promise<void> {
+    return this.inner.dropAllTables();
+  }
 }

 /**
@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-darwin-arm64",
-  "version": "0.15.1-beta.2",
+  "version": "0.16.1-beta.2",
   "os": ["darwin"],
   "cpu": ["arm64"],
   "main": "lancedb.darwin-arm64.node",

@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-darwin-x64",
-  "version": "0.15.1-beta.2",
+  "version": "0.16.1-beta.2",
   "os": ["darwin"],
   "cpu": ["x64"],
   "main": "lancedb.darwin-x64.node",

@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-linux-arm64-gnu",
-  "version": "0.15.1-beta.2",
+  "version": "0.16.1-beta.2",
   "os": ["linux"],
   "cpu": ["arm64"],
   "main": "lancedb.linux-arm64-gnu.node",

@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-linux-arm64-musl",
-  "version": "0.15.1-beta.2",
+  "version": "0.16.1-beta.2",
   "os": ["linux"],
   "cpu": ["arm64"],
   "main": "lancedb.linux-arm64-musl.node",

@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-linux-x64-gnu",
-  "version": "0.15.1-beta.2",
+  "version": "0.16.1-beta.2",
   "os": ["linux"],
   "cpu": ["x64"],
   "main": "lancedb.linux-x64-gnu.node",

@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-linux-x64-musl",
-  "version": "0.15.1-beta.2",
+  "version": "0.16.1-beta.2",
   "os": ["linux"],
   "cpu": ["x64"],
   "main": "lancedb.linux-x64-musl.node",

@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-win32-arm64-msvc",
-  "version": "0.15.1-beta.2",
+  "version": "0.16.1-beta.2",
   "os": [
     "win32"
   ],

@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-win32-x64-msvc",
-  "version": "0.15.1-beta.2",
+  "version": "0.16.1-beta.2",
   "os": ["win32"],
   "cpu": ["x64"],
   "main": "lancedb.win32-x64-msvc.node",
4
nodejs/package-lock.json
generated
4
nodejs/package-lock.json
generated
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb",
|
||||
"version": "0.15.1-beta.2",
|
||||
"version": "0.16.1-beta.2",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "@lancedb/lancedb",
|
||||
"version": "0.15.1-beta.2",
|
||||
"version": "0.16.1-beta.2",
|
||||
"cpu": [
|
||||
"x64",
|
||||
"arm64"
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
"ann"
|
||||
],
|
||||
"private": false,
|
||||
"version": "0.15.1-beta.2",
|
||||
"version": "0.16.1-beta.2",
|
||||
"main": "dist/index.js",
|
||||
"exports": {
|
||||
".": "./dist/index.js",
|
||||
|
||||
@@ -2,17 +2,15 @@
// SPDX-FileCopyrightText: Copyright The LanceDB Authors

use std::collections::HashMap;
use std::str::FromStr;

use lancedb::database::CreateTableMode;
use napi::bindgen_prelude::*;
use napi_derive::*;

use crate::error::{convert_error, NapiErrorExt};
use crate::error::NapiErrorExt;
use crate::table::Table;
use crate::ConnectionOptions;
use lancedb::connection::{
    ConnectBuilder, Connection as LanceDBConnection, CreateTableMode, LanceFileVersion,
};
use lancedb::connection::{ConnectBuilder, Connection as LanceDBConnection};
use lancedb::ipc::{ipc_file_to_batches, ipc_file_to_schema};

#[napi]

@@ -124,8 +122,6 @@ impl Connection {
        buf: Buffer,
        mode: String,
        storage_options: Option<HashMap<String, String>>,
        data_storage_options: Option<String>,
        enable_v2_manifest_paths: Option<bool>,
    ) -> napi::Result<Table> {
        let batches = ipc_file_to_batches(buf.to_vec())
            .map_err(|e| napi::Error::from_reason(format!("Failed to read IPC file: {}", e)))?;

@@ -137,14 +133,6 @@ impl Connection {
                builder = builder.storage_option(key, value);
            }
        }
        if let Some(data_storage_option) = data_storage_options.as_ref() {
            builder = builder.data_storage_version(
                LanceFileVersion::from_str(data_storage_option).map_err(|e| convert_error(&e))?,
            );
        }
        if let Some(enable_v2_manifest_paths) = enable_v2_manifest_paths {
            builder = builder.enable_v2_manifest_paths(enable_v2_manifest_paths);
        }
        let tbl = builder.execute().await.default_error()?;
        Ok(Table::new(tbl))
    }

@@ -156,8 +144,6 @@ impl Connection {
        schema_buf: Buffer,
        mode: String,
        storage_options: Option<HashMap<String, String>>,
        data_storage_options: Option<String>,
        enable_v2_manifest_paths: Option<bool>,
    ) -> napi::Result<Table> {
        let schema = ipc_file_to_schema(schema_buf.to_vec()).map_err(|e| {
            napi::Error::from_reason(format!("Failed to marshal schema from JS to Rust: {}", e))

@@ -172,14 +158,6 @@ impl Connection {
                builder = builder.storage_option(key, value);
            }
        }
        if let Some(data_storage_option) = data_storage_options.as_ref() {
            builder = builder.data_storage_version(
                LanceFileVersion::from_str(data_storage_option).map_err(|e| convert_error(&e))?,
            );
        }
        if let Some(enable_v2_manifest_paths) = enable_v2_manifest_paths {
            builder = builder.enable_v2_manifest_paths(enable_v2_manifest_paths);
        }
        let tbl = builder.execute().await.default_error()?;
        Ok(Table::new(tbl))
    }

@@ -209,4 +187,9 @@ impl Connection {
    pub async fn drop_table(&self, name: String) -> napi::Result<()> {
        self.get_inner()?.drop_table(&name).await.default_error()
    }

    #[napi(catch_unwind)]
    pub async fn drop_all_tables(&self) -> napi::Result<()> {
        self.get_inner()?.drop_all_tables().await.default_error()
    }
}
@@ -1,6 +1,8 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors

use std::collections::HashMap;

use napi_derive::*;

/// Timeout configuration for remote HTTP client.

@@ -67,6 +69,7 @@ pub struct ClientConfig {
    pub user_agent: Option<String>,
    pub retry_config: Option<RetryConfig>,
    pub timeout_config: Option<TimeoutConfig>,
    pub extra_headers: Option<HashMap<String, String>>,
}

impl From<TimeoutConfig> for lancedb::remote::TimeoutConfig {

@@ -104,6 +107,7 @@ impl From<ClientConfig> for lancedb::remote::ClientConfig {
                .unwrap_or(concat!("LanceDB-Node-Client/", env!("CARGO_PKG_VERSION")).to_string()),
            retry_config: config.retry_config.map(Into::into).unwrap_or_default(),
            timeout_config: config.timeout_config.map(Into::into).unwrap_or_default(),
            extra_headers: config.extra_headers.unwrap_or_default(),
        }
    }
}
@@ -1,5 +1,5 @@
|
||||
[tool.bumpversion]
|
||||
current_version = "0.18.1-beta.3"
|
||||
current_version = "0.19.1-beta.3"
|
||||
parse = """(?x)
|
||||
(?P<major>0|[1-9]\\d*)\\.
|
||||
(?P<minor>0|[1-9]\\d*)\\.
|
||||
|
||||
2
python/.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
# Test data created by some example tests
data/
@@ -1,6 +1,6 @@
[package]
name = "lancedb-python"
version = "0.18.1-beta.3"
version = "0.19.1-beta.3"
edition.workspace = true
description = "Python bindings for LanceDB"
license.workspace = true
@@ -29,4 +29,4 @@ doctest: ## Run documentation tests.

.PHONY: test
test: ## Run tests.
	pytest python/tests -vv --durations=10 -m "not slow"
	pytest python/tests -vv --durations=10 -m "not slow and not s3_test"
@@ -4,7 +4,7 @@ name = "lancedb"
dynamic = ["version"]
dependencies = [
    "deprecation",
    "pylance==0.23.0b4",
    "pylance==0.23.0",
    "tqdm>=4.27.0",
    "pydantic>=1.10",
    "packaging",
@@ -15,8 +15,6 @@ class Connection(object):
        mode: str,
        data: pa.RecordBatchReader,
        storage_options: Optional[Dict[str, str]] = None,
        data_storage_version: Optional[str] = None,
        enable_v2_manifest_paths: Optional[bool] = None,
    ) -> Table: ...
    async def create_empty_table(
        self,

@@ -24,8 +22,6 @@ class Connection(object):
        mode: str,
        schema: pa.Schema,
        storage_options: Optional[Dict[str, str]] = None,
        data_storage_version: Optional[str] = None,
        enable_v2_manifest_paths: Optional[bool] = None,
    ) -> Table: ...
    async def rename_table(self, old_name: str, new_name: str) -> None: ...
    async def drop_table(self, name: str) -> None: ...
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors

from typing import List, Optional, Union
from typing import List, Optional, Tuple, Union

import pyarrow as pa

@@ -66,3 +66,17 @@ class AsyncRecordBatchReader:
        batches = table.to_batches(max_chunksize=max_batch_length)
        for batch in batches:
            yield batch


def peek_reader(
    reader: pa.RecordBatchReader,
) -> Tuple[pa.RecordBatch, pa.RecordBatchReader]:
    if not isinstance(reader, pa.RecordBatchReader):
        raise TypeError("reader must be a RecordBatchReader")
    batch = reader.read_next_batch()

    def all_batches():
        yield batch
        yield from reader

    return batch, pa.RecordBatchReader.from_batches(batch.schema, all_batches())
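The new peek_reader helper reads one batch ahead and then re-yields it, so callers can inspect a stream without consuming it. A small usage sketch (the schema and values are illustrative):

    import pyarrow as pa
    from lancedb.arrow import peek_reader

    schema = pa.schema({"x": pa.int64()})
    batches = (pa.record_batch([pa.array([i])], schema=schema) for i in range(3))
    reader = pa.RecordBatchReader.from_batches(schema, batches)

    # Look at the first batch, then keep reading the full stream.
    first, reader = peek_reader(reader)
    assert first.num_rows == 1
    assert reader.read_all().num_rows == 3  # the peeked batch is re-yielded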
@@ -14,6 +14,7 @@ from overrides import EnforceOverrides, override  # type: ignore
from lancedb.common import data_to_reader, sanitize_uri, validate_schema
from lancedb.background_loop import LOOP

from . import __version__
from ._lancedb import connect as lancedb_connect  # type: ignore
from .table import (
    AsyncTable,

@@ -26,6 +27,8 @@ from .util import (
    validate_table_name,
)

import deprecation

if TYPE_CHECKING:
    import pyarrow as pa
    from .pydantic import LanceModel

@@ -119,19 +122,11 @@ class DBConnection(EnforceOverrides):
            See available options at
            <https://lancedb.github.io/lancedb/guides/storage/>
        data_storage_version: optional, str, default "stable"
            The version of the data storage format to use. Newer versions are more
            efficient but require newer versions of lance to read. The default is
            "stable" which will use the legacy v2 version. See the user guide
            for more details.
        enable_v2_manifest_paths: bool, optional, default False
            Use the new V2 manifest paths. These paths provide more efficient
            opening of datasets with many versions on object stores. WARNING:
            turning this on will make the dataset unreadable for older versions
            of LanceDB (prior to 0.13.0). To migrate an existing dataset, instead
            use the
            [Table.migrate_manifest_paths_v2][lancedb.table.Table.migrate_v2_manifest_paths]
            method.

            Deprecated. Set `storage_options` when connecting to the database and set
            `new_table_data_storage_version` in the options.
        enable_v2_manifest_paths: optional, bool, default False
            Deprecated. Set `storage_options` when connecting to the database and set
            `new_table_enable_v2_manifest_paths` in the options.
        Returns
        -------
        LanceTable

@@ -302,6 +297,12 @@ class DBConnection(EnforceOverrides):
        """
        raise NotImplementedError

    def drop_all_tables(self):
        """
        Drop all tables from the database
        """
        raise NotImplementedError

    @property
    def uri(self) -> str:
        return self._uri

@@ -452,8 +453,6 @@ class LanceDBConnection(DBConnection):
            fill_value=fill_value,
            embedding_functions=embedding_functions,
            storage_options=storage_options,
            data_storage_version=data_storage_version,
            enable_v2_manifest_paths=enable_v2_manifest_paths,
        )
        return tbl

@@ -496,9 +495,19 @@ class LanceDBConnection(DBConnection):
        """
        LOOP.run(self._conn.drop_table(name, ignore_missing=ignore_missing))

    @override
    def drop_all_tables(self):
        LOOP.run(self._conn.drop_all_tables())

    @deprecation.deprecated(
        deprecated_in="0.15.1",
        removed_in="0.17",
        current_version=__version__,
        details="Use drop_all_tables() instead",
    )
    @override
    def drop_database(self):
        LOOP.run(self._conn.drop_database())
        LOOP.run(self._conn.drop_all_tables())
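With drop_database() deprecated in favor of drop_all_tables(), existing callers migrate like this (a hedged sketch; the URI is illustrative):

    import lancedb

    db = lancedb.connect("/tmp/example-db")  # illustrative URI
    # Before (now emits a DeprecationWarning):
    #     db.drop_database()
    # After:
    db.drop_all_tables()
    assert db.table_names() == []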
class AsyncConnection(object):

@@ -595,9 +604,6 @@ class AsyncConnection(object):
        storage_options: Optional[Dict[str, str]] = None,
        *,
        embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
        data_storage_version: Optional[str] = None,
        use_legacy_format: Optional[bool] = None,
        enable_v2_manifest_paths: Optional[bool] = None,
    ) -> AsyncTable:
        """Create an [AsyncTable][lancedb.table.AsyncTable] in the database.

@@ -640,23 +646,6 @@ class AsyncConnection(object):
            connection will be inherited by the table, but can be overridden here.
            See available options at
            <https://lancedb.github.io/lancedb/guides/storage/>
        data_storage_version: optional, str, default "stable"
            The version of the data storage format to use. Newer versions are more
            efficient but require newer versions of lance to read. The default is
            "stable" which will use the legacy v2 version. See the user guide
            for more details.
        use_legacy_format: bool, optional, default False. (Deprecated)
            If True, use the legacy format for the table. If False, use the new format.
            This method is deprecated, use `data_storage_version` instead.
        enable_v2_manifest_paths: bool, optional, default False
            Use the new V2 manifest paths. These paths provide more efficient
            opening of datasets with many versions on object stores. WARNING:
            turning this on will make the dataset unreadable for older versions
            of LanceDB (prior to 0.13.0). To migrate an existing dataset, instead
            use the
            [AsyncTable.migrate_manifest_paths_v2][lancedb.table.AsyncTable.migrate_manifest_paths_v2]
            method.

        Returns
        -------

@@ -795,17 +784,12 @@ class AsyncConnection(object):
        if mode == "create" and exist_ok:
            mode = "exist_ok"

        if not data_storage_version:
            data_storage_version = "legacy" if use_legacy_format else "stable"

        if data is None:
            new_table = await self._inner.create_empty_table(
                name,
                mode,
                schema,
                storage_options=storage_options,
                data_storage_version=data_storage_version,
                enable_v2_manifest_paths=enable_v2_manifest_paths,
            )
        else:
            data = data_to_reader(data, schema)

@@ -814,8 +798,6 @@ class AsyncConnection(object):
                mode,
                data,
                storage_options=storage_options,
                data_storage_version=data_storage_version,
                enable_v2_manifest_paths=enable_v2_manifest_paths,
            )

        return AsyncTable(new_table)

@@ -885,9 +867,19 @@ class AsyncConnection(object):
            if f"Table '{name}' was not found" not in str(e):
                raise e

    async def drop_all_tables(self):
        """Drop all tables from the database."""
        await self._inner.drop_all_tables()

    @deprecation.deprecated(
        deprecated_in="0.15.1",
        removed_in="0.17",
        current_version=__version__,
        details="Use drop_all_tables() instead",
    )
    async def drop_database(self):
        """
        Drop database
        This is the same thing as dropping all the tables
        """
        await self._inner.drop_db()
        await self._inner.drop_all_tables()
@@ -116,7 +116,7 @@ class EmbeddingFunction(BaseModel, ABC):
    )

    @abstractmethod
    def ndims(self):
    def ndims(self) -> int:
        """
        Return the dimensions of the vector column
        """
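ndims() now carries an explicit int return annotation. A minimal sketch of a conforming subclass, assuming a model with a fixed 384-dimensional output (the class name and dimension are illustrative):

    from lancedb.embeddings import EmbeddingFunction


    class MyEmbedding(EmbeddingFunction):
        """Hypothetical embedding function with a fixed output size."""

        def ndims(self) -> int:
            return 384  # illustrative dimension

        def compute_query_embeddings(self, query, *args, **kwargs):
            return [[0.0] * self.ndims()]  # placeholder vectors

        def compute_source_embeddings(self, texts, *args, **kwargs):
            return [[0.0] * self.ndims() for _ in texts]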
@@ -199,18 +199,29 @@ else:
    ]


def _pydantic_type_to_arrow_type(tp: Any, field: FieldInfo) -> pa.DataType:
    if inspect.isclass(tp):
        if issubclass(tp, pydantic.BaseModel):
            # Struct
            fields = _pydantic_model_to_fields(tp)
            return pa.struct(fields)
        if issubclass(tp, FixedSizeListMixin):
            return pa.list_(tp.value_arrow_type(), tp.dim())
    return _py_type_to_arrow_type(tp, field)


def _pydantic_to_arrow_type(field: FieldInfo) -> pa.DataType:
    """Convert a Pydantic FieldInfo to Arrow DataType"""

    if isinstance(field.annotation, (_GenericAlias, GenericAlias)):
        origin = field.annotation.__origin__
        args = field.annotation.__args__

        if origin is list:
            child = args[0]
            return pa.list_(_py_type_to_arrow_type(child, field))
        elif origin == Union:
            if len(args) == 2 and args[1] is type(None):
                return _py_type_to_arrow_type(args[0], field)
                return _pydantic_type_to_arrow_type(args[0], field)
    elif sys.version_info >= (3, 10) and isinstance(field.annotation, types.UnionType):
        args = field.annotation.__args__
        if len(args) == 2:

@@ -218,14 +229,7 @@ def _pydantic_to_arrow_type(field: FieldInfo) -> pa.DataType:
            if typ is type(None):
                continue
            return _py_type_to_arrow_type(typ, field)
    elif inspect.isclass(field.annotation):
        if issubclass(field.annotation, pydantic.BaseModel):
            # Struct
            fields = _pydantic_model_to_fields(field.annotation)
            return pa.struct(fields)
        elif issubclass(field.annotation, FixedSizeListMixin):
            return pa.list_(field.annotation.value_arrow_type(), field.annotation.dim())
    return _py_type_to_arrow_type(field.annotation, field)
    return _pydantic_type_to_arrow_type(field.annotation, field)


def is_nullable(field: FieldInfo) -> bool:
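The refactor routes class annotations inside Optional[...] through _pydantic_type_to_arrow_type, so an optional nested Pydantic model now maps to a nullable Arrow struct. A short sketch of the expected mapping (model names are illustrative):

    from typing import Optional

    import pyarrow as pa
    from pydantic import BaseModel
    from lancedb.pydantic import LanceModel, pydantic_to_schema


    class Inner(BaseModel):
        lat: str
        lon: str


    class Outer(LanceModel):
        id: str
        place: Optional[Inner] = None


    schema = pydantic_to_schema(Outer)
    assert schema.field("place").nullable  # Optional nested model -> nullable struct
    assert schema.field("place").type == pa.struct(
        [pa.field("lat", pa.utf8(), False), pa.field("lon", pa.utf8(), False)]
    )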
@@ -109,6 +109,7 @@ class ClientConfig:
    user_agent: str = f"LanceDB-Python-Client/{__version__}"
    retry_config: RetryConfig = field(default_factory=RetryConfig)
    timeout_config: Optional[TimeoutConfig] = field(default_factory=TimeoutConfig)
    extra_headers: Optional[dict] = None

    def __post_init__(self):
        if isinstance(self.retry_config, dict):
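With the new extra_headers field, every request from the remote client carries the given headers. A hedged connection sketch (the host, key, and header values are placeholders):

    import lancedb

    db = lancedb.connect(
        "db://dev",  # placeholder database URI
        api_key="fake",  # placeholder key
        host_override="http://localhost:8080",
        client_config={"extra_headers": {"x-request-source": "batch-job"}},
    )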
@@ -526,6 +526,9 @@ class RemoteTable(Table):
    def drop_columns(self, columns: Iterable[str]):
        return LOOP.run(self._table.drop_columns(columns))

    def drop_index(self, index_name: str):
        return LOOP.run(self._table.drop_index(index_name))

    def uses_v2_manifest_paths(self) -> bool:
        raise NotImplementedError(
            "uses_v2_manifest_paths() is not supported on the LanceDB Cloud"
@@ -4,6 +4,7 @@
from __future__ import annotations

import inspect
import warnings
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime, timedelta

@@ -23,6 +24,7 @@ from typing import (
from urllib.parse import urlparse

import lance
from lancedb.arrow import peek_reader
from lancedb.background_loop import LOOP
from .dependencies import _check_for_pandas
import pyarrow as pa

@@ -73,17 +75,19 @@ pl = safe_import_polars()
QueryType = Literal["vector", "fts", "hybrid", "auto"]


def _into_pyarrow_table(data) -> pa.Table:
def _into_pyarrow_reader(data) -> pa.RecordBatchReader:
    if _check_for_hugging_face(data):
        # Huggingface datasets
        from lance.dependencies import datasets

        if isinstance(data, datasets.Dataset):
            schema = data.features.arrow_schema
            return pa.Table.from_batches(data.data.to_batches(), schema=schema)
            return pa.RecordBatchReader.from_batches(schema, data.data.to_batches())
        elif isinstance(data, datasets.dataset_dict.DatasetDict):
            schema = _schema_from_hf(data, schema)
            return pa.Table.from_batches(_to_batches_with_split(data), schema=schema)
            return pa.RecordBatchReader.from_batches(
                schema, _to_batches_with_split(data)
            )
    if isinstance(data, LanceModel):
        raise ValueError("Cannot add a single LanceModel to a table. Use a list.")

@@ -95,41 +99,41 @@ def _into_pyarrow_table(data) -> pa.Table:
        if isinstance(data[0], LanceModel):
            schema = data[0].__class__.to_arrow_schema()
            data = [model_to_dict(d) for d in data]
            return pa.Table.from_pylist(data, schema=schema)
            return pa.Table.from_pylist(data, schema=schema).to_reader()
        elif isinstance(data[0], pa.RecordBatch):
            return pa.Table.from_batches(data)
            return pa.Table.from_batches(data).to_reader()
        else:
            return pa.Table.from_pylist(data)
            return pa.Table.from_pylist(data).to_reader()
    elif _check_for_pandas(data) and isinstance(data, pd.DataFrame):
        table = pa.Table.from_pandas(data, preserve_index=False)
        # Do not serialize Pandas metadata
        meta = table.schema.metadata if table.schema.metadata is not None else {}
        meta = {k: v for k, v in meta.items() if k != b"pandas"}
        return table.replace_schema_metadata(meta)
        return table.replace_schema_metadata(meta).to_reader()
    elif isinstance(data, pa.Table):
        return data
        return data.to_reader()
    elif isinstance(data, pa.RecordBatch):
        return pa.Table.from_batches([data])
        return pa.RecordBatchReader.from_batches(data.schema, [data])
    elif isinstance(data, LanceDataset):
        return data.scanner().to_table()
        return data.scanner().to_reader()
    elif isinstance(data, pa.dataset.Dataset):
        return data.to_table()
        return data.scanner().to_reader()
    elif isinstance(data, pa.dataset.Scanner):
        return data.to_table()
        return data.to_reader()
    elif isinstance(data, pa.RecordBatchReader):
        return data.read_all()
        return data
    elif (
        type(data).__module__.startswith("polars")
        and data.__class__.__name__ == "DataFrame"
    ):
        return data.to_arrow()
        return data.to_arrow().to_reader()
    elif (
        type(data).__module__.startswith("polars")
        and data.__class__.__name__ == "LazyFrame"
    ):
        return data.collect().to_arrow()
        return data.collect().to_arrow().to_reader()
    elif isinstance(data, Iterable):
        return _iterator_to_table(data)
        return _iterator_to_reader(data)
    else:
        raise TypeError(
            f"Unknown data type {type(data)}. "
@@ -139,30 +143,28 @@ def _into_pyarrow_table(data) -> pa.Table:
    )


def _iterator_to_table(data: Iterable) -> pa.Table:
    batches = []
    schema = None  # Will get schema from first batch
    for batch in data:
        batch_table = _into_pyarrow_table(batch)
        if schema is not None:
            if batch_table.schema != schema:
def _iterator_to_reader(data: Iterable) -> pa.RecordBatchReader:
    # Each batch is treated as its own reader, mainly so we can
    # re-use the _into_pyarrow_reader logic.
    first = _into_pyarrow_reader(next(data))
    schema = first.schema

    def gen():
        yield from first
        for batch in data:
            table: pa.Table = _into_pyarrow_reader(batch).read_all()
            if table.schema != schema:
                try:
                    batch_table = batch_table.cast(schema)
                    table = table.cast(schema)
                except pa.lib.ArrowInvalid:
                    raise ValueError(
                        f"Input iterator yielded a batch with schema that "
                        f"does not match the schema of other batches.\n"
                        f"Expected:\n{schema}\nGot:\n{batch_table.schema}"
                        f"Expected:\n{schema}\nGot:\n{batch.schema}"
                    )
            else:
                # Use the first schema for the remainder of the batches
                schema = batch_table.schema
            batches.append(batch_table)
            yield from table.to_batches()

    if batches:
        return pa.concat_tables(batches)
    else:
        raise ValueError("Input iterable is empty")
    return pa.RecordBatchReader.from_batches(schema, gen())
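Because _iterator_to_reader wraps the input in a generator, batches are pulled only as the consumer reads them. The underlying pattern, in isolation from LanceDB internals:

    import pyarrow as pa

    schema = pa.schema({"x": pa.int64()})

    def batches():
        for i in range(3):
            print(f"producing batch {i}")  # runs lazily, one read at a time
            yield pa.record_batch([pa.array([i])], schema=schema)

    reader = pa.RecordBatchReader.from_batches(schema, batches())
    first = reader.read_next_batch()  # only "producing batch 0" has printed so far
    rest = reader.read_all()  # drains the remaining batches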
def _sanitize_data(

@@ -173,7 +175,7 @@ def _sanitize_data(
    fill_value: float = 0.0,
    *,
    allow_subschema: bool = False,
) -> pa.Table:
) -> pa.RecordBatchReader:
    """
    Handle input data, applying all standard transformations.

@@ -206,20 +208,20 @@ def _sanitize_data(
    # 1. There might be embedding columns missing that will be added
    #    in the add_embeddings step.
    # 2. If `allow_subschemas` is True, there might be columns missing.
    table = _into_pyarrow_table(data)
    reader = _into_pyarrow_reader(data)

    table = _append_vector_columns(table, target_schema, metadata=metadata)
    reader = _append_vector_columns(reader, target_schema, metadata=metadata)

    # This happens before the cast so we can fix vector columns with
    # incorrect lengths before they are cast to FSL.
    table = _handle_bad_vectors(
        table,
    reader = _handle_bad_vectors(
        reader,
        on_bad_vectors=on_bad_vectors,
        fill_value=fill_value,
    )

    if target_schema is None:
        target_schema = _infer_target_schema(table)
        target_schema, reader = _infer_target_schema(reader)

    if metadata:
        new_metadata = target_schema.metadata or {}

@@ -228,25 +230,25 @@ def _sanitize_data(

    _validate_schema(target_schema)

    table = _cast_to_target_schema(table, target_schema, allow_subschema)
    reader = _cast_to_target_schema(reader, target_schema, allow_subschema)

    return table
    return reader


def _cast_to_target_schema(
    table: pa.Table,
    reader: pa.RecordBatchReader,
    target_schema: pa.Schema,
    allow_subschema: bool = False,
) -> pa.Table:
) -> pa.RecordBatchReader:
    # pa.Table.cast expects field order not to be changed.
    # Lance doesn't care about field order, so we don't need to rearrange fields
    # to match the target schema. We just need to correctly cast the fields.
    if table.schema == target_schema:
    if reader.schema == target_schema:
        # Fast path when the schemas are already the same
        return table
        return reader

    fields = []
    for field in table.schema:
    for field in reader.schema:
        target_field = target_schema.field(field.name)
        if target_field is None:
            raise ValueError(f"Field {field.name} not found in target schema")

@@ -259,12 +261,16 @@ def _cast_to_target_schema(

    if allow_subschema and len(reordered_schema) != len(target_schema):
        fields = _infer_subschema(
            list(iter(table.schema)), list(iter(reordered_schema))
            list(iter(reader.schema)), list(iter(reordered_schema))
        )
        subschema = pa.schema(fields, metadata=target_schema.metadata)
        return table.cast(subschema)
    else:
        return table.cast(reordered_schema)
        reordered_schema = pa.schema(fields, metadata=target_schema.metadata)

    def gen():
        for batch in reader:
            # Table but not RecordBatch has cast.
            yield pa.Table.from_batches([batch]).cast(reordered_schema).to_batches()[0]

    return pa.RecordBatchReader.from_batches(reordered_schema, gen())


def _infer_subschema(

@@ -343,7 +349,10 @@ def sanitize_create_table(
    if metadata:
        schema = schema.with_metadata(metadata)
        # Need to apply metadata to the data as well
        data = data.replace_schema_metadata(metadata)
        if isinstance(data, pa.Table):
            data = data.replace_schema_metadata(metadata)
        elif isinstance(data, pa.RecordBatchReader):
            data = pa.RecordBatchReader.from_batches(schema, data)

    return data, schema
|
||||
|
||||
|
||||
def _append_vector_columns(
|
||||
data: pa.Table,
|
||||
reader: pa.RecordBatchReader,
|
||||
schema: Optional[pa.Schema] = None,
|
||||
*,
|
||||
metadata: Optional[dict] = None,
|
||||
) -> pa.Table:
|
||||
) -> pa.RecordBatchReader:
|
||||
"""
|
||||
Use the embedding function to automatically embed the source columns and add the
|
||||
vector columns to the table.
|
||||
@@ -395,28 +404,43 @@ def _append_vector_columns(
|
||||
metadata = schema.metadata or metadata or {}
|
||||
functions = EmbeddingFunctionRegistry.get_instance().parse_functions(metadata)
|
||||
|
||||
if not functions:
|
||||
return reader
|
||||
|
||||
fields = list(reader.schema)
|
||||
for vector_column, conf in functions.items():
|
||||
func = conf.function
|
||||
no_vector_column = vector_column not in data.column_names
|
||||
if no_vector_column or pc.all(pc.is_null(data[vector_column])).as_py():
|
||||
col_data = func.compute_source_embeddings_with_retry(
|
||||
data[conf.source_column]
|
||||
)
|
||||
if vector_column not in reader.schema.names:
|
||||
if schema is not None:
|
||||
dtype = schema.field(vector_column).type
|
||||
field = schema.field(vector_column)
|
||||
else:
|
||||
dtype = pa.list_(pa.float32(), len(col_data[0]))
|
||||
if no_vector_column:
|
||||
data = data.append_column(
|
||||
pa.field(vector_column, type=dtype), pa.array(col_data, type=dtype)
|
||||
)
|
||||
else:
|
||||
data = data.set_column(
|
||||
data.column_names.index(vector_column),
|
||||
pa.field(vector_column, type=dtype),
|
||||
pa.array(col_data, type=dtype),
|
||||
)
|
||||
return data
|
||||
dtype = pa.list_(pa.float32(), conf.function.ndims())
|
||||
field = pa.field(vector_column, type=dtype, nullable=True)
|
||||
fields.append(field)
|
||||
schema = pa.schema(fields, metadata=reader.schema.metadata)
|
||||
|
||||
def gen():
|
||||
for batch in reader:
|
||||
for vector_column, conf in functions.items():
|
||||
func = conf.function
|
||||
no_vector_column = vector_column not in batch.column_names
|
||||
if no_vector_column or pc.all(pc.is_null(batch[vector_column])).as_py():
|
||||
col_data = func.compute_source_embeddings_with_retry(
|
||||
batch[conf.source_column]
|
||||
)
|
||||
if no_vector_column:
|
||||
batch = batch.append_column(
|
||||
schema.field(vector_column),
|
||||
pa.array(col_data, type=schema.field(vector_column).type),
|
||||
)
|
||||
else:
|
||||
batch = batch.set_column(
|
||||
batch.column_names.index(vector_column),
|
||||
schema.field(vector_column),
|
||||
pa.array(col_data, type=schema.field(vector_column).type),
|
||||
)
|
||||
yield batch
|
||||
|
||||
return pa.RecordBatchReader.from_batches(schema, gen())
|
||||
|
||||
|
||||
def _table_path(base: str, table_name: str) -> str:
|
||||
@@ -2085,10 +2109,37 @@ class LanceTable(Table):
            The value to use when filling vectors. Only used if on_bad_vectors="fill".
        embedding_functions: list of EmbeddingFunctionModel, default None
            The embedding functions to use when creating the table.
        data_storage_version: optional, str, default "stable"
            Deprecated. Set `storage_options` when connecting to the database and set
            `new_table_data_storage_version` in the options.
        enable_v2_manifest_paths: optional, bool, default False
            Deprecated. Set `storage_options` when connecting to the database and set
            `new_table_enable_v2_manifest_paths` in the options.
        """
        self = cls.__new__(cls)
        self._conn = db

        if data_storage_version is not None:
            warnings.warn(
                "setting data_storage_version directly on create_table is deprecated. "
                "Use database_options instead.",
                DeprecationWarning,
            )
            if storage_options is None:
                storage_options = {}
            storage_options["new_table_data_storage_version"] = data_storage_version
        if enable_v2_manifest_paths is not None:
            warnings.warn(
                "setting enable_v2_manifest_paths directly on create_table is "
                "deprecated. Use database_options instead.",
                DeprecationWarning,
            )
            if storage_options is None:
                storage_options = {}
            storage_options["new_table_enable_v2_manifest_paths"] = (
                enable_v2_manifest_paths
            )

        self._table = LOOP.run(
            self._conn._conn.create_table(
                name,

@@ -2100,8 +2151,6 @@ class LanceTable(Table):
                fill_value=fill_value,
                embedding_functions=embedding_functions,
                storage_options=storage_options,
                data_storage_version=data_storage_version,
                enable_v2_manifest_paths=enable_v2_manifest_paths,
            )
        )
        return self
@@ -2332,11 +2381,13 @@ class LanceTable(Table):


def _handle_bad_vectors(
    table: pa.Table,
    reader: pa.RecordBatchReader,
    on_bad_vectors: Literal["error", "drop", "fill", "null"] = "error",
    fill_value: float = 0.0,
) -> pa.Table:
    for field in table.schema:
) -> pa.RecordBatchReader:
    vector_columns = []

    for field in reader.schema:
        # They can provide a 'vector' column that isn't yet a FSL
        named_vector_col = (
            (

@@ -2356,22 +2407,28 @@ def _handle_bad_vectors(
        )

        if named_vector_col or likely_vector_col:
            table = _handle_bad_vector_column(
                table,
                vector_column_name=field.name,
                on_bad_vectors=on_bad_vectors,
                fill_value=fill_value,
            )
            vector_columns.append(field.name)

    return table

    def gen():
        for batch in reader:
            for name in vector_columns:
                batch = _handle_bad_vector_column(
                    batch,
                    vector_column_name=name,
                    on_bad_vectors=on_bad_vectors,
                    fill_value=fill_value,
                )
            yield batch

    return pa.RecordBatchReader.from_batches(reader.schema, gen())
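Bad-vector handling is now applied batch by batch, but the user-facing knobs are unchanged. A hedged sketch of the on_bad_vectors options at table creation (the data is illustrative):

    import lancedb

    db = lancedb.connect("memory://")
    rows = [
        {"vector": [1.0, 2.0]},
        {"vector": [3.0]},  # wrong length: a "bad" vector
    ]
    # "fill" replaces bad vectors with fill_value; the other modes are
    # "error" (the default), "drop", and "null".
    tbl = db.create_table("t", rows, on_bad_vectors="fill", fill_value=0.0)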
def _handle_bad_vector_column(
    data: pa.Table,
    data: pa.RecordBatch,
    vector_column_name: str,
    on_bad_vectors: str = "error",
    fill_value: float = 0.0,
) -> pa.Table:
) -> pa.RecordBatch:
    """
    Ensure that the vector column exists and has type fixed_size_list(float)
@@ -2459,8 +2516,11 @@ def has_nan_values(arr: Union[pa.ListArray, pa.ChunkedArray]) -> pa.BooleanArray
    return pc.is_in(indices, has_nan_indices)


def _infer_target_schema(table: pa.Table) -> pa.Schema:
    schema = table.schema
def _infer_target_schema(
    reader: pa.RecordBatchReader,
) -> Tuple[pa.Schema, pa.RecordBatchReader]:
    schema = reader.schema
    peeked = None

    for i, field in enumerate(schema):
        if (

@@ -2468,8 +2528,10 @@ def _infer_target_schema(table: pa.Table) -> pa.Schema:
            and (pa.types.is_list(field.type) or pa.types.is_large_list(field.type))
            and pa.types.is_floating(field.type.value_type)
        ):
            if peeked is None:
                peeked, reader = peek_reader(reader)
            # Use the most common length of the list as the dimensions
            dim = _modal_list_size(table.column(i))
            dim = _modal_list_size(peeked.column(i))

            new_field = pa.field(
                VECTOR_COLUMN_NAME,

@@ -2483,8 +2545,10 @@ def _infer_target_schema(table: pa.Table) -> pa.Schema:
            and (pa.types.is_list(field.type) or pa.types.is_large_list(field.type))
            and pa.types.is_integer(field.type.value_type)
        ):
            if peeked is None:
                peeked, reader = peek_reader(reader)
            # Use the most common length of the list as the dimensions
            dim = _modal_list_size(table.column(i))
            dim = _modal_list_size(peeked.column(i))
            new_field = pa.field(
                VECTOR_COLUMN_NAME,
                pa.list_(pa.uint8(), dim),

@@ -2493,7 +2557,7 @@ def _infer_target_schema(table: pa.Table) -> pa.Schema:

        schema = schema.set(i, new_field)

    return schema
    return schema, reader


def _modal_list_size(arr: Union[pa.ListArray, pa.ChunkedArray]) -> int:
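Since a reader's schema may only say list&lt;float32&gt; with no fixed size, the dimension is now inferred by peeking at the first batch and taking the modal list length. A sketch mirroring the unit tests later in this diff:

    import pyarrow as pa
    from lancedb.table import _infer_target_schema

    data = pa.table(
        {"vector": pa.array([[1.0, 2.0], [3.0, 4.0], [5.0]], type=pa.list_(pa.float32()))}
    )
    schema, reader = _infer_target_schema(data.to_reader())
    # Most rows have length 2, so the column becomes fixed_size_list<float32, 2>.
    assert schema.field("vector").type == pa.list_(pa.float32(), 2)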
@@ -299,12 +299,12 @@ def test_create_exist_ok(tmp_db: lancedb.DBConnection):
@pytest.mark.asyncio
async def test_connect(tmp_path):
    db = await lancedb.connect_async(tmp_path)
    assert str(db) == f"NativeDatabase(uri={tmp_path}, read_consistency_interval=None)"
    assert str(db) == f"ListingDatabase(uri={tmp_path}, read_consistency_interval=None)"

    db = await lancedb.connect_async(
        tmp_path, read_consistency_interval=timedelta(seconds=5)
    )
    assert str(db) == f"NativeDatabase(uri={tmp_path}, read_consistency_interval=5s)"
    assert str(db) == f"ListingDatabase(uri={tmp_path}, read_consistency_interval=5s)"


@pytest.mark.asyncio
@@ -396,13 +396,16 @@ async def test_create_exist_ok_async(tmp_db_async: lancedb.AsyncConnection):

@pytest.mark.asyncio
async def test_create_table_v2_manifest_paths_async(tmp_path):
    db = await lancedb.connect_async(tmp_path)
    db_with_v2_paths = await lancedb.connect_async(
        tmp_path, storage_options={"new_table_enable_v2_manifest_paths": "true"}
    )
    db_no_v2_paths = await lancedb.connect_async(
        tmp_path, storage_options={"new_table_enable_v2_manifest_paths": "false"}
    )
    # Create table in v2 mode with v2 manifest paths enabled
    tbl = await db.create_table(
    tbl = await db_with_v2_paths.create_table(
        "test_v2_manifest_paths",
        data=[{"id": 0}],
        use_legacy_format=False,
        enable_v2_manifest_paths=True,
    )
    assert await tbl.uses_v2_manifest_paths()
    manifests_dir = tmp_path / "test_v2_manifest_paths.lance" / "_versions"

@@ -410,11 +413,9 @@ async def test_create_table_v2_manifest_paths_async(tmp_path):
    assert re.match(r"\d{20}\.manifest", manifest)

    # Start a table in V1 mode then migrate
    tbl = await db.create_table(
    tbl = await db_no_v2_paths.create_table(
        "test_v2_migration",
        data=[{"id": 0}],
        use_legacy_format=False,
        enable_v2_manifest_paths=False,
    )
    assert not await tbl.uses_v2_manifest_paths()
    manifests_dir = tmp_path / "test_v2_migration.lance" / "_versions"

@@ -498,6 +499,10 @@ def test_delete_table(tmp_db: lancedb.DBConnection):
    # if ignore_missing=True
    tmp_db.drop_table("does_not_exist", ignore_missing=True)

    tmp_db.drop_all_tables()

    assert tmp_db.table_names() == []
@pytest.mark.asyncio
async def test_delete_table_async(tmp_db: lancedb.DBConnection):

@@ -583,7 +588,7 @@ def test_empty_or_nonexistent_table(mem_db: lancedb.DBConnection):


@pytest.mark.asyncio
async def test_create_in_v2_mode(mem_db_async: lancedb.AsyncConnection):
async def test_create_in_v2_mode():
    def make_data():
        for i in range(10):
            yield pa.record_batch([pa.array([x for x in range(1024)])], names=["x"])

@@ -594,10 +599,13 @@ async def test_create_in_v2_mode(mem_db_async: lancedb.AsyncConnection):
    schema = pa.schema([pa.field("x", pa.int64())])

    # Create table in v1 mode
    tbl = await mem_db_async.create_table(
        "test", data=make_data(), schema=schema, data_storage_version="legacy"
    v1_db = await lancedb.connect_async(
        "memory://", storage_options={"new_table_data_storage_version": "legacy"}
    )

    tbl = await v1_db.create_table("test", data=make_data(), schema=schema)

    async def is_in_v2_mode(tbl):
        batches = (
            await tbl.query().limit(10 * 1024).to_batches(max_batch_length=1024 * 10)

@@ -610,10 +618,12 @@ async def test_create_in_v2_mode(mem_db_async: lancedb.AsyncConnection):
    assert not await is_in_v2_mode(tbl)

    # Create table in v2 mode
    tbl = await mem_db_async.create_table(
        "test_v2", data=make_data(), schema=schema, use_legacy_format=False
    v2_db = await lancedb.connect_async(
        "memory://", storage_options={"new_table_data_storage_version": "stable"}
    )

    tbl = await v2_db.create_table("test_v2", data=make_data(), schema=schema)

    assert await is_in_v2_mode(tbl)

    # Add data (should remain in v2 mode)

@@ -622,20 +632,18 @@ async def test_create_in_v2_mode(mem_db_async: lancedb.AsyncConnection):
    assert await is_in_v2_mode(tbl)

    # Create empty table in v2 mode and add data
    tbl = await mem_db_async.create_table(
        "test_empty_v2", data=None, schema=schema, use_legacy_format=False
    )
    tbl = await v2_db.create_table("test_empty_v2", data=None, schema=schema)
    await tbl.add(make_table())

    assert await is_in_v2_mode(tbl)

    # Create empty table uses v1 mode by default
    tbl = await mem_db_async.create_table(
        "test_empty_v2_default", data=None, schema=schema, data_storage_version="legacy"
    )
    # Db uses v2 mode by default
    db = await lancedb.connect_async("memory://")

    tbl = await db.create_table("test_empty_v2_default", data=None, schema=schema)
    await tbl.add(make_table())

    assert not await is_in_v2_mode(tbl)
    assert await is_in_v2_mode(tbl)


def test_replace_index(mem_db: lancedb.DBConnection):
@@ -107,7 +107,7 @@ def test_embedding_with_bad_results(tmp_path):
        vector: Vector(model.ndims()) = model.VectorField()

    table = db.create_table("test", schema=Schema, mode="overwrite")
    with pytest.raises(ValueError):
    with pytest.raises(RuntimeError):
        # Default on_bad_vectors is "error"
        table.add([{"text": "hello world"}])

@@ -341,6 +341,7 @@ def test_add_optional_vector(tmp_path):
    assert not (np.abs(tbl.to_pandas()["vector"][0]) < 1e-6).all()


@pytest.mark.slow
@pytest.mark.parametrize(
    "embedding_type",
    [
@@ -10,6 +10,7 @@ import pyarrow as pa
import pydantic
import pytest
from lancedb.pydantic import PYDANTIC_VERSION, LanceModel, Vector, pydantic_to_schema
from pydantic import BaseModel
from pydantic import Field


@@ -252,3 +253,104 @@ def test_lance_model():

    t = TestModel()
    assert t == TestModel(vec=[0.0] * 16, li=[1, 2, 3])


def test_optional_nested_model():
    class WAMedia(BaseModel):
        url: str
        mimetype: str
        filename: Optional[str]
        error: Optional[str]
        data: bytes

    class WALocation(BaseModel):
        description: Optional[str]
        latitude: str
        longitude: str

    class ReplyToMessage(BaseModel):
        id: str
        participant: str
        body: str

    class Message(BaseModel):
        id: str
        timestamp: int
        from_: str
        fromMe: bool
        to: str
        body: str
        hasMedia: Optional[bool]
        media: WAMedia
        mediaUrl: Optional[str]
        ack: Optional[int]
        ackName: Optional[str]
        author: Optional[str]
        location: Optional[WALocation]
        vCards: Optional[List[str]]
        replyTo: Optional[ReplyToMessage]

    class AnyEvent(LanceModel):
        id: str
        session: str
        metadata: Optional[str] = None
        engine: str
        event: str

    class MessageEvent(AnyEvent):
        payload: Message

    schema = pydantic_to_schema(MessageEvent)

    payload = schema.field("payload")
    assert payload.type == pa.struct(
        [
            pa.field("id", pa.utf8(), False),
            pa.field("timestamp", pa.int64(), False),
            pa.field("from_", pa.utf8(), False),
            pa.field("fromMe", pa.bool_(), False),
            pa.field("to", pa.utf8(), False),
            pa.field("body", pa.utf8(), False),
            pa.field("hasMedia", pa.bool_(), True),
            pa.field(
                "media",
                pa.struct(
                    [
                        pa.field("url", pa.utf8(), False),
                        pa.field("mimetype", pa.utf8(), False),
                        pa.field("filename", pa.utf8(), True),
                        pa.field("error", pa.utf8(), True),
                        pa.field("data", pa.binary(), False),
                    ]
                ),
                False,
            ),
            pa.field("mediaUrl", pa.utf8(), True),
            pa.field("ack", pa.int64(), True),
            pa.field("ackName", pa.utf8(), True),
            pa.field("author", pa.utf8(), True),
            pa.field(
                "location",
                pa.struct(
                    [
                        pa.field("description", pa.utf8(), True),
                        pa.field("latitude", pa.utf8(), False),
                        pa.field("longitude", pa.utf8(), False),
                    ]
                ),
                True,  # Optional
            ),
            pa.field("vCards", pa.list_(pa.utf8()), True),
            pa.field(
                "replyTo",
                pa.struct(
                    [
                        pa.field("id", pa.utf8(), False),
                        pa.field("participant", pa.utf8(), False),
                        pa.field("body", pa.utf8(), False),
                    ]
                ),
                True,
            ),
        ]
    )
@@ -232,6 +232,71 @@ async def test_distance_range_async(table_async: AsyncTable):
    assert res["_distance"].to_pylist() == [min_dist, max_dist]


@pytest.mark.asyncio
async def test_distance_range_with_new_rows_async():
    conn = await lancedb.connect_async(
        "memory://", read_consistency_interval=timedelta(seconds=0)
    )
    data = pa.table(
        {
            "vector": pa.FixedShapeTensorArray.from_numpy_ndarray(
                np.random.rand(256, 2)
            ),
        }
    )
    table = await conn.create_table("test", data)
    await table.create_index("vector", config=IvfPq(num_partitions=1, num_sub_vectors=2))

    q = [0, 0]
    rs = await table.query().nearest_to(q).to_arrow()
    dists = rs["_distance"].to_pylist()
    min_dist = dists[0]
    max_dist = dists[-1]

    # append more rows so that execution plan would be mixed with ANN & Flat KNN
    new_data = pa.table(
        {
            "vector": pa.FixedShapeTensorArray.from_numpy_ndarray(np.random.rand(4, 2)),
        }
    )
    await table.add(new_data)

    res = (
        await table.query()
        .nearest_to(q)
        .distance_range(upper_bound=min_dist)
        .to_arrow()
    )
    assert len(res) == 0

    res = (
        await table.query()
        .nearest_to(q)
        .distance_range(lower_bound=max_dist)
        .to_arrow()
    )
    for dist in res["_distance"].to_pylist():
        assert dist >= max_dist

    res = (
        await table.query()
        .nearest_to(q)
        .distance_range(upper_bound=max_dist)
        .to_arrow()
    )
    for dist in res["_distance"].to_pylist():
        assert dist < max_dist

    res = (
        await table.query()
        .nearest_to(q)
        .distance_range(lower_bound=min_dist)
        .to_arrow()
    )
    for dist in res["_distance"].to_pylist():
        assert dist >= min_dist


@pytest.mark.parametrize(
    "multivec_table", [pa.float16(), pa.float32(), pa.float64()], indirect=True
)
@@ -32,15 +32,16 @@ def make_mock_http_handler(handler):
@contextlib.contextmanager
def mock_lancedb_connection(handler):
    with http.server.HTTPServer(
        ("localhost", 8080), make_mock_http_handler(handler)
        ("localhost", 0), make_mock_http_handler(handler)
    ) as server:
        port = server.server_address[1]
        handle = threading.Thread(target=server.serve_forever)
        handle.start()

        db = lancedb.connect(
            "db://dev",
            api_key="fake",
            host_override="http://localhost:8080",
            host_override=f"http://localhost:{port}",
            client_config={
                "retry_config": {"retries": 2},
                "timeout_config": {

@@ -57,22 +58,24 @@ def mock_lancedb_connection(handler):


@contextlib.asynccontextmanager
async def mock_lancedb_connection_async(handler):
async def mock_lancedb_connection_async(handler, **client_config):
    with http.server.HTTPServer(
        ("localhost", 8080), make_mock_http_handler(handler)
        ("localhost", 0), make_mock_http_handler(handler)
    ) as server:
        port = server.server_address[1]
        handle = threading.Thread(target=server.serve_forever)
        handle.start()

        db = await lancedb.connect_async(
            "db://dev",
            api_key="fake",
            host_override="http://localhost:8080",
            host_override=f"http://localhost:{port}",
            client_config={
                "retry_config": {"retries": 2},
                "timeout_config": {
                    "connect_timeout": 1,
                },
                **client_config,
            },
        )
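Binding the mock server to port 0 lets the OS choose a free port, removing the hard-coded 8080 and making the tests safe to run in parallel. The stdlib pattern on its own:

    import http.server

    class Handler(http.server.BaseHTTPRequestHandler):
        def do_GET(self):
            self.send_response(200)
            self.end_headers()

    with http.server.HTTPServer(("localhost", 0), Handler) as server:
        port = server.server_address[1]  # the ephemeral port the OS assigned
        print(f"listening on http://localhost:{port}")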
@@ -254,6 +257,9 @@ def test_table_create_indices():
                )
            )
            request.wfile.write(payload.encode())
        elif "/drop/" in request.path:
            request.send_response(200)
            request.end_headers()
        else:
            request.send_response(404)
            request.end_headers()

@@ -265,6 +271,9 @@ def test_table_create_indices():
        table.create_scalar_index("id")
        table.create_fts_index("text")
        table.create_scalar_index("vector")
        table.drop_index("vector_idx")
        table.drop_index("id_idx")
        table.drop_index("text_idx")


@contextlib.contextmanager

@@ -522,3 +531,19 @@ def test_create_client():

    with pytest.warns(DeprecationWarning):
        lancedb.connect(**mandatory_args, request_thread_pool=10)


@pytest.mark.asyncio
async def test_pass_through_headers():
    def handler(request):
        assert request.headers["foo"] == "bar"
        request.send_response(200)
        request.send_header("Content-Type", "application/json")
        request.end_headers()
        request.wfile.write(b'{"tables": []}')

    async with mock_lancedb_connection_async(
        handler, extra_headers={"foo": "bar"}
    ) as db:
        table_names = await db.table_names()
        assert table_names == []
@@ -14,7 +14,7 @@ from lancedb.table import (
    _append_vector_columns,
    _cast_to_target_schema,
    _handle_bad_vectors,
    _into_pyarrow_table,
    _into_pyarrow_reader,
    _sanitize_data,
    _infer_target_schema,
)

@@ -145,19 +145,19 @@ def test_append_vector_columns():
        schema=schema,
    )
    output = _append_vector_columns(
        data,
        data.to_reader(),
        schema,  # metadata passed separate from schema
        metadata=metadata,
    )
    ).read_all()
    assert output.schema == schema
    assert output["vector"].null_count == 0

    # Adds if missing
    data = pa.table({"text": ["hello"]})
    output = _append_vector_columns(
        data,
        data.to_reader(),
        schema.with_metadata(metadata),
    )
    ).read_all()
    assert output.schema == schema
    assert output["vector"].null_count == 0

@@ -170,9 +170,9 @@ def test_append_vector_columns():
        schema=schema,
    )
    output = _append_vector_columns(
        data,
        data.to_reader(),
        schema.with_metadata(metadata),
    )
    ).read_all()
    assert output == data  # No change

    # No provided schema

@@ -182,9 +182,9 @@ def test_append_vector_columns():
        }
    )
    output = _append_vector_columns(
        data,
        data.to_reader(),
        metadata=metadata,
    )
    ).read_all()
    expected_schema = pa.schema(
        {
            "text": pa.string(),

@@ -204,9 +204,9 @@ def test_handle_bad_vectors_jagged(on_bad_vectors):
    if on_bad_vectors == "error":
        with pytest.raises(ValueError) as e:
            output = _handle_bad_vectors(
                data,
                data.to_reader(),
                on_bad_vectors=on_bad_vectors,
            )
            ).read_all()
        output = exception_output(e)
        assert output == (
            "ValueError: Vector column 'vector' has variable length vectors. Set "

@@ -217,10 +217,10 @@ def test_handle_bad_vectors_jagged(on_bad_vectors):
        return
    else:
        output = _handle_bad_vectors(
            data,
            data.to_reader(),
            on_bad_vectors=on_bad_vectors,
            fill_value=42.0,
        )
        ).read_all()

    if on_bad_vectors == "drop":
        expected = pa.array([[1.0, 2.0], [4.0, 5.0]])

@@ -240,9 +240,9 @@ def test_handle_bad_vectors_nan(on_bad_vectors):
    if on_bad_vectors == "error":
        with pytest.raises(ValueError) as e:
            output = _handle_bad_vectors(
                data,
                data.to_reader(),
                on_bad_vectors=on_bad_vectors,
            )
            ).read_all()
        output = exception_output(e)
        assert output == (
            "ValueError: Vector column 'vector' has NaNs. Set "

@@ -253,10 +253,10 @@ def test_handle_bad_vectors_nan(on_bad_vectors):
        return
    else:
        output = _handle_bad_vectors(
            data,
            data.to_reader(),
            on_bad_vectors=on_bad_vectors,
            fill_value=42.0,
        )
        ).read_all()

    if on_bad_vectors == "drop":
        expected = pa.array([[3.0, 4.0]])

@@ -274,7 +274,7 @@ def test_handle_bad_vectors_noop():
        [[[1.0, 2.0], [3.0, 4.0]]], type=pa.list_(pa.float64(), 2)
    )
    data = pa.table({"vector": vector})
    output = _handle_bad_vectors(data)
    output = _handle_bad_vectors(data.to_reader()).read_all()
    assert output["vector"] == vector


@@ -325,7 +325,7 @@ class TestModel(lancedb.pydantic.LanceModel):
)
def test_into_pyarrow_table(data):
    expected = pa.table({"a": [1], "b": [2]})
    output = _into_pyarrow_table(data())
    output = _into_pyarrow_reader(data()).read_all()
    assert output == expected


@@ -349,7 +349,7 @@ def test_infer_target_schema():
            "vector": pa.list_(pa.float32(), 2),
        }
    )
    output = _infer_target_schema(data)
    output, _ = _infer_target_schema(data.to_reader())
    assert output == expected

    # Handle large list and use modal size

@@ -370,7 +370,7 @@ def test_infer_target_schema():
            "vector": pa.list_(pa.float32(), 2),
        }
    )
    output = _infer_target_schema(data)
    output, _ = _infer_target_schema(data.to_reader())
    assert output == expected

    # ignore if not list

@@ -386,7 +386,7 @@ def test_infer_target_schema():
        schema=example,
    )
    expected = example
    output = _infer_target_schema(data)
    output, _ = _infer_target_schema(data.to_reader())
    assert output == expected


@@ -476,7 +476,7 @@ def test_sanitize_data(
        target_schema=schema,
        metadata=metadata,
        allow_subschema=True,
    )
    ).read_all()

    assert output_data == expected

@@ -519,7 +519,7 @@ def test_cast_to_target_schema():
            "vec2": pa.list_(pa.float32(), 2),
        }
    )
    output = _cast_to_target_schema(data, target)
    output = _cast_to_target_schema(data.to_reader(), target)
    expected = pa.table(
        {
            "id": [1],

@@ -550,8 +550,10 @@ def test_cast_to_target_schema():
        }
    )
    with pytest.raises(Exception):
        _cast_to_target_schema(data, target)
    output = _cast_to_target_schema(data, target, allow_subschema=True)
        _cast_to_target_schema(data.to_reader(), target)
    output = _cast_to_target_schema(
        data.to_reader(), target, allow_subschema=True
    ).read_all()
    expected_schema = pa.schema(
        {
            "id": pa.int64(),

@@ -576,3 +578,22 @@ def test_cast_to_target_schema():
        schema=expected_schema,
    )
    assert output == expected


def test_sanitize_data_stream():
    # Make sure we don't collect the whole stream when running sanitize_data
    schema = pa.schema({"a": pa.int32()})

    def stream():
        yield pa.record_batch([pa.array([1, 2, 3])], schema=schema)
        raise ValueError("error")

    reader = pa.RecordBatchReader.from_batches(schema, stream())

    output = _sanitize_data(reader)

    first = next(output)
    assert first == pa.record_batch([pa.array([1, 2, 3])], schema=schema)

    with pytest.raises(ValueError):
        next(output)
@@ -1,10 +1,10 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors

use std::{collections::HashMap, str::FromStr, sync::Arc, time::Duration};
use std::{collections::HashMap, sync::Arc, time::Duration};

use arrow::{datatypes::Schema, ffi_stream::ArrowArrayStreamReader, pyarrow::FromPyArrow};
use lancedb::connection::{Connection as LanceConnection, CreateTableMode, LanceFileVersion};
use lancedb::{connection::Connection as LanceConnection, database::CreateTableMode};
use pyo3::{
    exceptions::{PyRuntimeError, PyValueError},
    pyclass, pyfunction, pymethods, Bound, FromPyObject, PyAny, PyRef, PyResult, Python,
@@ -80,15 +80,13 @@ impl Connection {
        future_into_py(self_.py(), async move { op.execute().await.infer_error() })
    }

    #[pyo3(signature = (name, mode, data, storage_options=None, data_storage_version=None, enable_v2_manifest_paths=None))]
    #[pyo3(signature = (name, mode, data, storage_options=None))]
    pub fn create_table<'a>(
        self_: PyRef<'a, Self>,
        name: String,
        mode: &str,
        data: Bound<'_, PyAny>,
        storage_options: Option<HashMap<String, String>>,
        data_storage_version: Option<String>,
        enable_v2_manifest_paths: Option<bool>,
    ) -> PyResult<Bound<'a, PyAny>> {
        let inner = self_.get_inner()?.clone();

@@ -101,32 +99,19 @@ impl Connection {
            builder = builder.storage_options(storage_options);
        }

        if let Some(enable_v2_manifest_paths) = enable_v2_manifest_paths {
            builder = builder.enable_v2_manifest_paths(enable_v2_manifest_paths);
        }

        if let Some(data_storage_version) = data_storage_version.as_ref() {
            builder = builder.data_storage_version(
                LanceFileVersion::from_str(data_storage_version)
                    .map_err(|e| PyValueError::new_err(e.to_string()))?,
            );
        }

        future_into_py(self_.py(), async move {
            let table = builder.execute().await.infer_error()?;
            Ok(Table::new(table))
        })
    }

    #[pyo3(signature = (name, mode, schema, storage_options=None, data_storage_version=None, enable_v2_manifest_paths=None))]
    #[pyo3(signature = (name, mode, schema, storage_options=None))]
    pub fn create_empty_table<'a>(
        self_: PyRef<'a, Self>,
        name: String,
        mode: &str,
        schema: Bound<'_, PyAny>,
        storage_options: Option<HashMap<String, String>>,
        data_storage_version: Option<String>,
        enable_v2_manifest_paths: Option<bool>,
    ) -> PyResult<Bound<'a, PyAny>> {
        let inner = self_.get_inner()?.clone();

@@ -140,17 +125,6 @@ impl Connection {
            builder = builder.storage_options(storage_options);
        }

        if let Some(enable_v2_manifest_paths) = enable_v2_manifest_paths {
            builder = builder.enable_v2_manifest_paths(enable_v2_manifest_paths);
        }

        if let Some(data_storage_version) = data_storage_version.as_ref() {
            builder = builder.data_storage_version(
                LanceFileVersion::from_str(data_storage_version)
                    .map_err(|e| PyValueError::new_err(e.to_string()))?,
            );
        }

        future_into_py(self_.py(), async move {
            let table = builder.execute().await.infer_error()?;
            Ok(Table::new(table))
@@ -196,12 +170,11 @@ impl Connection {
        })
    }

    pub fn drop_db(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
    pub fn drop_all_tables(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
        let inner = self_.get_inner()?.clone();
        future_into_py(
            self_.py(),
            async move { inner.drop_db().await.infer_error() },
        )
        future_into_py(self_.py(), async move {
            inner.drop_all_tables().await.infer_error()
        })
    }
}

@@ -249,6 +222,7 @@ pub struct PyClientConfig {
    user_agent: String,
    retry_config: Option<PyClientRetryConfig>,
    timeout_config: Option<PyClientTimeoutConfig>,
    extra_headers: Option<HashMap<String, String>>,
}

#[derive(FromPyObject)]
@@ -300,6 +274,7 @@ impl From<PyClientConfig> for lancedb::remote::ClientConfig {
            user_agent: value.user_agent,
            retry_config: value.retry_config.map(Into::into).unwrap_or_default(),
            timeout_config: value.timeout_config.map(Into::into).unwrap_or_default(),
            extra_headers: value.extra_headers.unwrap_or_default(),
        }
    }
}

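Note the direction of this change: create_table and create_empty_table lose their dedicated data_storage_version and enable_v2_manifest_paths parameters, and those knobs move into the generic storage_options map (see the OPT_NEW_TABLE_* constants in rust/lancedb/src/database/listing.rs below). A hedged Rust sketch of the new calling convention, assuming the ConnectBuilder exposes a storage_option setter and using an illustrative local path:

use lancedb::connect;

async fn connect_with_new_table_options() -> lancedb::Result<()> {
    // New-table behavior is configured on the connection, not per call.
    let db = connect("data/sample-lancedb")
        .storage_option("new_table_data_storage_version", "stable")
        .storage_option("new_table_enable_v2_manifest_paths", "true")
        .execute()
        .await?;
    let _names = db.table_names().execute().await?;
    Ok(())
}
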
@@ -7,8 +7,7 @@ use arrow::pyarrow::FromPyArrow;
use lancedb::index::scalar::FullTextSearchQuery;
use lancedb::query::QueryExecutionOptions;
use lancedb::query::{
    ExecutableQuery, HasQuery, Query as LanceDbQuery, QueryBase, Select,
    VectorQuery as LanceDbVectorQuery,
    ExecutableQuery, Query as LanceDbQuery, QueryBase, Select, VectorQuery as LanceDbVectorQuery,
};
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::{PyAnyMethods, PyDictMethods};
@@ -313,7 +312,8 @@ impl VectorQuery {
    }

    pub fn nearest_to_text(&mut self, query: Bound<'_, PyDict>) -> PyResult<HybridQuery> {
        let fts_query = Query::new(self.inner.mut_query().clone()).nearest_to_text(query)?;
        let base_query = self.inner.clone().into_plain();
        let fts_query = Query::new(base_query).nearest_to_text(query)?;
        Ok(HybridQuery {
            inner_vec: self.clone(),
            inner_fts: fts_query,
@@ -411,10 +411,14 @@ impl HybridQuery {
    }

    pub fn get_limit(&mut self) -> Option<u32> {
        self.inner_fts.inner.limit.map(|i| i as u32)
        self.inner_fts
            .inner
            .current_request()
            .limit
            .map(|i| i as u32)
    }

    pub fn get_with_row_id(&mut self) -> bool {
        self.inner_fts.inner.with_row_id
        self.inner_fts.inner.current_request().with_row_id
    }
}

@@ -1,6 +1,6 @@
[package]
name = "lancedb-node"
version = "0.15.1-beta.2"
version = "0.16.1-beta.2"
description = "Serverless, low-latency vector database for AI applications"
license.workspace = true
edition.workspace = true

@@ -169,5 +169,6 @@ fn main(mut cx: ModuleContext) -> NeonResult<()> {
    cx.export_function("tableAddColumns", JsTable::js_add_columns)?;
    cx.export_function("tableAlterColumns", JsTable::js_alter_columns)?;
    cx.export_function("tableDropColumns", JsTable::js_drop_columns)?;
    cx.export_function("tableDropIndex", JsTable::js_drop_index)?;
    Ok(())
}

@@ -638,4 +638,8 @@ impl JsTable {

        Ok(promise)
    }

    pub(crate) fn js_drop_index(_cx: FunctionContext) -> JsResult<JsPromise> {
        todo!("not implemented")
    }
}

@@ -1,6 +1,6 @@
[package]
name = "lancedb"
version = "0.15.1-beta.2"
version = "0.16.1-beta.2"
edition.workspace = true
description = "LanceDB: A serverless, low-latency vector database for AI applications"
license.workspace = true
@@ -19,7 +19,10 @@ arrow-ord = { workspace = true }
arrow-cast = { workspace = true }
arrow-ipc.workspace = true
chrono = { workspace = true }
datafusion-catalog.workspace = true
datafusion-common.workspace = true
datafusion-execution.workspace = true
datafusion-expr.workspace = true
datafusion-physical-plan.workspace = true
object_store = { workspace = true }
snafu = { workspace = true }
@@ -33,7 +36,7 @@ lance-table = { workspace = true }
lance-linalg = { workspace = true }
lance-testing = { workspace = true }
lance-encoding = { workspace = true }
moka = { workspace = true}
moka = { workspace = true }
pin-project = { workspace = true }
tokio = { version = "1.23", features = ["rt-multi-thread"] }
log.workspace = true
@@ -82,7 +85,8 @@ aws-sdk-s3 = { version = "1.38.0" }
aws-sdk-kms = { version = "1.37" }
aws-config = { version = "1.0" }
aws-smithy-runtime = { version = "1.3" }
http-body = "1" # Matching reqwest
datafusion.workspace = true
http-body = "1" # Matching reqwest


[features]
@@ -98,7 +102,7 @@ sentence-transformers = [
    "dep:candle-core",
    "dep:candle-transformers",
    "dep:candle-nn",
    "dep:tokenizers"
    "dep:tokenizers",
]

# TLS

File diff suppressed because it is too large
Load Diff
133
rust/lancedb/src/database.rs
Normal file
@@ -0,0 +1,133 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors

//! The database module defines the `Database` trait and related types.
//!
//! A "database" is a generic concept for something that manages tables and their metadata.
//!
//! We provide a basic implementation of a database that requires no additional infrastructure
//! and is based off listing directories in a filesystem.
//!
//! Users may want to provide their own implementations for a variety of reasons:
//! * Tables may be arranged in a different order on the S3 filesystem
//! * Tables may be managed by some kind of independent application
//! * Tables may be managed by a database system (e.g. Postgres)
//! * A custom table implementation (e.g. remote table, etc.) may be used

use std::collections::HashMap;
use std::sync::Arc;

use arrow_array::RecordBatchReader;
use lance::dataset::ReadParams;

use crate::error::Result;
use crate::table::{BaseTable, TableDefinition, WriteOptions};

pub mod listing;

pub trait DatabaseOptions {
    fn serialize_into_map(&self, map: &mut HashMap<String, String>);
}

/// A request to list names of tables in the database
#[derive(Clone, Debug, Default)]
pub struct TableNamesRequest {
    /// If present, only return names that come lexicographically after the supplied
    /// value.
    ///
    /// This can be combined with limit to implement pagination by setting this to
    /// the last table name from the previous page.
    pub start_after: Option<String>,
    /// The maximum number of table names to return
    pub limit: Option<u32>,
}

/// A request to open a table
#[derive(Clone, Debug)]
pub struct OpenTableRequest {
    pub name: String,
    pub index_cache_size: Option<u32>,
    pub lance_read_params: Option<ReadParams>,
}

pub type TableBuilderCallback = Box<dyn FnOnce(OpenTableRequest) -> OpenTableRequest + Send>;

/// Describes what happens when creating a table and a table with
/// the same name already exists
pub enum CreateTableMode {
    /// If the table already exists, an error is returned
    Create,
    /// If the table already exists, it is opened. Any provided data is
    /// ignored. The function will be passed an OpenTableBuilder to customize
    /// how the table is opened
    ExistOk(TableBuilderCallback),
    /// If the table already exists, it is overwritten
    Overwrite,
}

impl CreateTableMode {
    pub fn exist_ok(
        callback: impl FnOnce(OpenTableRequest) -> OpenTableRequest + Send + 'static,
    ) -> Self {
        Self::ExistOk(Box::new(callback))
    }
}

impl Default for CreateTableMode {
    fn default() -> Self {
        Self::Create
    }
}

/// The data to start a table or a schema to create an empty table
pub enum CreateTableData {
    /// Creates a table from the given data; no schema is required since it is inferred from the data
    Data(Box<dyn RecordBatchReader + Send>),
    /// Creates an empty table; the definition / schema must be provided separately
    Empty(TableDefinition),
}

/// A request to create a table
pub struct CreateTableRequest {
    /// The name of the new table
    pub name: String,
    /// Initial data to write to the table; use `CreateTableData::Empty` to create an empty table
    pub data: CreateTableData,
    /// The mode to use when creating the table
    pub mode: CreateTableMode,
    /// Options to use when writing data (only used if `data` contains data)
    pub write_options: WriteOptions,
}

impl CreateTableRequest {
    pub fn new(name: String, data: CreateTableData) -> Self {
        Self {
            name,
            data,
            mode: CreateTableMode::default(),
            write_options: WriteOptions::default(),
        }
    }
}

/// The `Database` trait defines the interface for database implementations.
///
/// A database is responsible for managing tables and their metadata.
#[async_trait::async_trait]
pub trait Database:
    Send + Sync + std::any::Any + std::fmt::Debug + std::fmt::Display + 'static
{
    /// List the names of tables in the database
    async fn table_names(&self, request: TableNamesRequest) -> Result<Vec<String>>;
    /// Create a table in the database
    async fn create_table(&self, request: CreateTableRequest) -> Result<Arc<dyn BaseTable>>;
    /// Open a table in the database
    async fn open_table(&self, request: OpenTableRequest) -> Result<Arc<dyn BaseTable>>;
    /// Rename a table in the database
    async fn rename_table(&self, old_name: &str, new_name: &str) -> Result<()>;
    /// Drop a table in the database
    async fn drop_table(&self, name: &str) -> Result<()>;
    /// Drop all tables in the database
    async fn drop_all_tables(&self) -> Result<()>;
    fn as_any(&self) -> &dyn std::any::Any;
}
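Since the trait deals in plain request structs rather than builders, driving it directly is straightforward. A minimal sketch, assuming a db: Arc<dyn Database> obtained elsewhere and an illustrative table name; exist_ok installs a callback that can tweak how the existing table is opened when the name is already taken:

use std::sync::Arc;

use arrow_array::RecordBatchIterator;
use arrow_schema::{DataType, Field, Schema};
use lancedb::database::{
    CreateTableData, CreateTableMode, CreateTableRequest, Database, TableNamesRequest,
};

async fn create_or_open(db: Arc<dyn Database>) -> lancedb::Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]));
    let mut request = CreateTableRequest::new(
        "my_table".into(),
        CreateTableData::Data(Box::new(RecordBatchIterator::new(vec![], schema))),
    );
    // If "my_table" already exists, open it with a larger index cache instead of failing.
    request.mode = CreateTableMode::exist_ok(|mut open| {
        open.index_cache_size = Some(512);
        open
    });
    let _table = db.create_table(request).await?;
    // Pagination via TableNamesRequest: names after "my_table", at most 10.
    let _names = db
        .table_names(TableNamesRequest {
            start_after: Some("my_table".into()),
            limit: Some(10),
        })
        .await?;
    Ok(())
}
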
542
rust/lancedb/src/database/listing.rs
Normal file
@@ -0,0 +1,542 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors

//! Provides the `ListingDatabase`, a simple database where tables are folders in a directory

use std::fs::create_dir_all;
use std::path::Path;
use std::{collections::HashMap, sync::Arc};

use arrow_array::RecordBatchIterator;
use lance::dataset::{ReadParams, WriteMode};
use lance::io::{ObjectStore, ObjectStoreParams, ObjectStoreRegistry, WrappingObjectStore};
use lance_encoding::version::LanceFileVersion;
use lance_table::io::commit::commit_handler_from_url;
use object_store::local::LocalFileSystem;
use snafu::{OptionExt, ResultExt};

use crate::connection::ConnectRequest;
use crate::error::{CreateDirSnafu, Error, InvalidTableNameSnafu, Result};
use crate::io::object_store::MirroringObjectStoreWrapper;
use crate::table::NativeTable;
use crate::utils::validate_table_name;

use super::{
    BaseTable, CreateTableData, CreateTableMode, CreateTableRequest, Database, DatabaseOptions,
    OpenTableRequest, TableNamesRequest,
};

/// File extension to indicate a lance table
pub const LANCE_FILE_EXTENSION: &str = "lance";

pub const OPT_NEW_TABLE_STORAGE_VERSION: &str = "new_table_data_storage_version";
pub const OPT_NEW_TABLE_V2_MANIFEST_PATHS: &str = "new_table_enable_v2_manifest_paths";

/// Controls how new tables should be created
#[derive(Clone, Debug, Default)]
pub struct NewTableConfig {
    /// The storage version to use for new tables
    ///
    /// If unset, then the latest stable version will be used
    pub data_storage_version: Option<LanceFileVersion>,
    /// Whether to enable V2 manifest paths for new tables
    ///
    /// V2 manifest paths are more efficient than V1 manifest paths but are not
    /// supported by old clients.
    pub enable_v2_manifest_paths: Option<bool>,
}

/// Options specific to the listing database
#[derive(Debug, Default, Clone)]
pub struct ListingDatabaseOptions {
    /// Controls what kind of Lance tables will be created by this database
    pub new_table_config: NewTableConfig,
}

impl ListingDatabaseOptions {
    fn parse_from_map(map: &HashMap<String, String>) -> Result<Self> {
        let new_table_config = NewTableConfig {
            data_storage_version: map
                .get(OPT_NEW_TABLE_STORAGE_VERSION)
                .map(|s| s.parse())
                .transpose()?,
            enable_v2_manifest_paths: map
                .get(OPT_NEW_TABLE_V2_MANIFEST_PATHS)
                .map(|s| {
                    s.parse::<bool>().map_err(|_| Error::InvalidInput {
                        message: format!(
                            "enable_v2_manifest_paths must be a boolean, received {}",
                            s
                        ),
                    })
                })
                .transpose()?,
        };
        Ok(Self { new_table_config })
    }
}

impl DatabaseOptions for ListingDatabaseOptions {
    fn serialize_into_map(&self, map: &mut HashMap<String, String>) {
        if let Some(storage_version) = &self.new_table_config.data_storage_version {
            map.insert(
                OPT_NEW_TABLE_STORAGE_VERSION.to_string(),
                storage_version.to_string(),
            );
        }
        if let Some(enable_v2_manifest_paths) = self.new_table_config.enable_v2_manifest_paths {
            map.insert(
                OPT_NEW_TABLE_V2_MANIFEST_PATHS.to_string(),
                enable_v2_manifest_paths.to_string(),
            );
        }
    }
}

/// A database that stores tables in a flat directory structure
///
/// Tables are stored as directories in the base path of the object store.
///
/// It is called a "listing database" because we use a "list directory" operation
/// to discover what tables are available. Table names are determined from the directory
/// names.
///
/// For example, given the following directory structure:
///
/// ```text
/// /data
///   /table1.lance
///   /table2.lance
/// ```
///
/// We will have two tables named `table1` and `table2`.
#[derive(Debug)]
pub struct ListingDatabase {
    object_store: ObjectStore,
    query_string: Option<String>,

    pub(crate) uri: String,
    pub(crate) base_path: object_store::path::Path,

    // the object store wrapper to use on write path
    pub(crate) store_wrapper: Option<Arc<dyn WrappingObjectStore>>,

    read_consistency_interval: Option<std::time::Duration>,

    // Storage options to be inherited by tables created from this connection
    storage_options: HashMap<String, String>,

    // Options for tables created by this connection
    new_table_config: NewTableConfig,
}

impl std::fmt::Display for ListingDatabase {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "ListingDatabase(uri={}, read_consistency_interval={})",
            self.uri,
            match self.read_consistency_interval {
                None => {
                    "None".to_string()
                }
                Some(duration) => {
                    format!("{}s", duration.as_secs_f64())
                }
            }
        )
    }
}

const LANCE_EXTENSION: &str = "lance";
const ENGINE: &str = "engine";
const MIRRORED_STORE: &str = "mirroredStore";

/// A connection to LanceDB
impl ListingDatabase {
    /// Connect to a listing database
    ///
    /// The URI should be a path to a directory where the tables are stored.
    ///
    /// See [`ListingDatabaseOptions`] for options that can be set on the connection (via
    /// `storage_options`).
    pub async fn connect_with_options(request: &ConnectRequest) -> Result<Self> {
        let uri = &request.uri;
        let parse_res = url::Url::parse(uri);

        let options = ListingDatabaseOptions::parse_from_map(&request.storage_options)?;

        // TODO: pass params regardless of OS
        match parse_res {
            Ok(url) if url.scheme().len() == 1 && cfg!(windows) => {
                Self::open_path(
                    uri,
                    request.read_consistency_interval,
                    options.new_table_config,
                )
                .await
            }
            Ok(mut url) => {
                // iter thru the query params and extract the commit store param
                let mut engine = None;
                let mut mirrored_store = None;
                let mut filtered_querys = vec![];

                // WARNING: specifying engine is NOT a publicly supported feature in lancedb yet
                // THE API WILL CHANGE
                for (key, value) in url.query_pairs() {
                    if key == ENGINE {
                        engine = Some(value.to_string());
                    } else if key == MIRRORED_STORE {
                        if cfg!(windows) {
                            return Err(Error::NotSupported {
                                message: "mirrored store is not supported on windows".into(),
                            });
                        }
                        mirrored_store = Some(value.to_string());
                    } else {
                        // to owned so we can modify the url
                        filtered_querys.push((key.to_string(), value.to_string()));
                    }
                }

                // Filter out the commit store query param -- it's a lancedb param
                url.query_pairs_mut().clear();
                url.query_pairs_mut().extend_pairs(filtered_querys);
                // Take a copy of the query string so we can propagate it to lance
                let query_string = url.query().map(|s| s.to_string());
                // clear the query string so we can use the url as the base uri
                // use .set_query(None) instead of .set_query("") because the latter
                // will add a trailing '?' to the url
                url.set_query(None);

                let table_base_uri = if let Some(store) = engine {
                    static WARN_ONCE: std::sync::Once = std::sync::Once::new();
                    WARN_ONCE.call_once(|| {
                        log::warn!("Specifying engine is not a publicly supported feature in lancedb yet. THE API WILL CHANGE");
                    });
                    let old_scheme = url.scheme().to_string();
                    let new_scheme = format!("{}+{}", old_scheme, store);
                    url.to_string().replacen(&old_scheme, &new_scheme, 1)
                } else {
                    url.to_string()
                };

                let plain_uri = url.to_string();

                let registry = Arc::new(ObjectStoreRegistry::default());
                let storage_options = request.storage_options.clone();
                let os_params = ObjectStoreParams {
                    storage_options: Some(storage_options.clone()),
                    ..Default::default()
                };
                let (object_store, base_path) =
                    ObjectStore::from_uri_and_params(registry, &plain_uri, &os_params).await?;
                if object_store.is_local() {
                    Self::try_create_dir(&plain_uri).context(CreateDirSnafu { path: plain_uri })?;
                }

                let write_store_wrapper = match mirrored_store {
                    Some(path) => {
                        let mirrored_store = Arc::new(LocalFileSystem::new_with_prefix(path)?);
                        let wrapper = MirroringObjectStoreWrapper::new(mirrored_store);
                        Some(Arc::new(wrapper) as Arc<dyn WrappingObjectStore>)
                    }
                    None => None,
                };

                Ok(Self {
                    uri: table_base_uri,
                    query_string,
                    base_path,
                    object_store,
                    store_wrapper: write_store_wrapper,
                    read_consistency_interval: request.read_consistency_interval,
                    storage_options,
                    new_table_config: options.new_table_config,
                })
            }
            Err(_) => {
                Self::open_path(
                    uri,
                    request.read_consistency_interval,
                    options.new_table_config,
                )
                .await
            }
        }
    }

    async fn open_path(
        path: &str,
        read_consistency_interval: Option<std::time::Duration>,
        new_table_config: NewTableConfig,
    ) -> Result<Self> {
        let (object_store, base_path) = ObjectStore::from_uri(path).await?;
        if object_store.is_local() {
            Self::try_create_dir(path).context(CreateDirSnafu { path })?;
        }

        Ok(Self {
            uri: path.to_string(),
            query_string: None,
            base_path,
            object_store,
            store_wrapper: None,
            read_consistency_interval,
            storage_options: HashMap::new(),
            new_table_config,
        })
    }

    /// Try to create a local directory to store the lancedb dataset
    fn try_create_dir(path: &str) -> core::result::Result<(), std::io::Error> {
        let path = Path::new(path);
        if !path.try_exists()? {
            create_dir_all(path)?;
        }
        Ok(())
    }

    /// Get the URI of a table in the database.
    fn table_uri(&self, name: &str) -> Result<String> {
        validate_table_name(name)?;

        let path = Path::new(&self.uri);
        let table_uri = path.join(format!("{}.{}", name, LANCE_FILE_EXTENSION));

        let mut uri = table_uri
            .as_path()
            .to_str()
            .context(InvalidTableNameSnafu {
                name,
                reason: "Name is not valid URL",
            })?
            .to_string();

        // If there is a query string set on the connection, propagate it to lance
if let Some(query) = self.query_string.as_ref() {
|
||||
uri.push('?');
|
||||
uri.push_str(query.as_str());
|
||||
}
|
||||
|
||||
Ok(uri)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Database for ListingDatabase {
|
||||
async fn table_names(&self, request: TableNamesRequest) -> Result<Vec<String>> {
|
||||
let mut f = self
|
||||
.object_store
|
||||
.read_dir(self.base_path.clone())
|
||||
.await?
|
||||
.iter()
|
||||
.map(Path::new)
|
||||
.filter(|path| {
|
||||
let is_lance = path
|
||||
.extension()
|
||||
.and_then(|e| e.to_str())
|
||||
.map(|e| e == LANCE_EXTENSION);
|
||||
is_lance.unwrap_or(false)
|
||||
})
|
||||
.filter_map(|p| p.file_stem().and_then(|s| s.to_str().map(String::from)))
|
||||
.collect::<Vec<String>>();
|
||||
f.sort();
|
||||
if let Some(start_after) = request.start_after {
|
||||
let index = f
|
||||
.iter()
|
||||
.position(|name| name.as_str() > start_after.as_str())
|
||||
.unwrap_or(f.len());
|
||||
f.drain(0..index);
|
||||
}
|
||||
if let Some(limit) = request.limit {
|
||||
f.truncate(limit as usize);
|
||||
}
|
||||
Ok(f)
|
||||
}
|
||||
|
||||
async fn create_table(&self, mut request: CreateTableRequest) -> Result<Arc<dyn BaseTable>> {
|
||||
let table_uri = self.table_uri(&request.name)?;
|
||||
// Inherit storage options from the connection
|
||||
let storage_options = request
|
||||
.write_options
|
||||
.lance_write_params
|
||||
.get_or_insert_with(Default::default)
|
||||
.store_params
|
||||
.get_or_insert_with(Default::default)
|
||||
.storage_options
|
||||
.get_or_insert_with(Default::default);
|
||||
for (key, value) in self.storage_options.iter() {
|
||||
if !storage_options.contains_key(key) {
|
||||
storage_options.insert(key.clone(), value.clone());
|
||||
}
|
||||
}
|
||||
|
||||
let storage_options = storage_options.clone();
|
||||
|
||||
let mut write_params = request.write_options.lance_write_params.unwrap_or_default();
|
||||
|
||||
if let Some(storage_version) = &self.new_table_config.data_storage_version {
|
||||
write_params.data_storage_version = Some(*storage_version);
|
||||
} else {
|
||||
// Allow the user to override the storage version via storage options (backwards compatibility)
|
||||
if let Some(data_storage_version) = storage_options.get(OPT_NEW_TABLE_STORAGE_VERSION) {
|
||||
write_params.data_storage_version = Some(data_storage_version.parse()?);
|
||||
}
|
||||
}
|
||||
if let Some(enable_v2_manifest_paths) = self.new_table_config.enable_v2_manifest_paths {
|
||||
write_params.enable_v2_manifest_paths = enable_v2_manifest_paths;
|
||||
} else {
|
||||
// Allow the user to override the storage version via storage options (backwards compatibility)
|
||||
if let Some(enable_v2_manifest_paths) = storage_options
|
||||
.get(OPT_NEW_TABLE_V2_MANIFEST_PATHS)
|
||||
.map(|s| s.parse::<bool>().unwrap())
|
||||
{
|
||||
write_params.enable_v2_manifest_paths = enable_v2_manifest_paths;
|
||||
}
|
||||
}
|
||||
|
||||
if matches!(&request.mode, CreateTableMode::Overwrite) {
|
||||
write_params.mode = WriteMode::Overwrite;
|
||||
}
|
||||
|
||||
let data = match request.data {
|
||||
CreateTableData::Data(data) => data,
|
||||
CreateTableData::Empty(table_definition) => {
|
||||
let schema = table_definition.schema.clone();
|
||||
Box::new(RecordBatchIterator::new(vec![], schema))
|
||||
}
|
||||
};
|
||||
let data_schema = data.schema();
|
||||
|
||||
match NativeTable::create(
|
||||
&table_uri,
|
||||
&request.name,
|
||||
data,
|
||||
self.store_wrapper.clone(),
|
||||
Some(write_params),
|
||||
self.read_consistency_interval,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(table) => Ok(Arc::new(table)),
|
||||
Err(Error::TableAlreadyExists { name }) => match request.mode {
|
||||
CreateTableMode::Create => Err(Error::TableAlreadyExists { name }),
|
||||
CreateTableMode::ExistOk(callback) => {
|
||||
let req = OpenTableRequest {
|
||||
name: request.name.clone(),
|
||||
index_cache_size: None,
|
||||
lance_read_params: None,
|
||||
};
|
||||
let req = (callback)(req);
|
||||
let table = self.open_table(req).await?;
|
||||
|
||||
let table_schema = table.schema().await?;
|
||||
|
||||
if table_schema != data_schema {
|
||||
return Err(Error::Schema {
|
||||
message: "Provided schema does not match existing table schema"
|
||||
.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
Ok(table)
|
||||
}
|
||||
CreateTableMode::Overwrite => unreachable!(),
|
||||
},
|
||||
Err(err) => Err(err),
|
||||
}
|
||||
}
|
||||
|
||||
async fn open_table(&self, mut request: OpenTableRequest) -> Result<Arc<dyn BaseTable>> {
|
||||
let table_uri = self.table_uri(&request.name)?;
|
||||
|
||||
// Inherit storage options from the connection
|
||||
let storage_options = request
|
||||
.lance_read_params
|
||||
.get_or_insert_with(Default::default)
|
||||
.store_options
|
||||
.get_or_insert_with(Default::default)
|
||||
.storage_options
|
||||
.get_or_insert_with(Default::default);
|
||||
for (key, value) in self.storage_options.iter() {
|
||||
if !storage_options.contains_key(key) {
|
||||
storage_options.insert(key.clone(), value.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Some ReadParams are exposed in the OpenTableBuilder, but we also
|
||||
// let the user provide their own ReadParams.
|
||||
//
|
||||
// If we have a user provided ReadParams use that
|
||||
// If we don't then start with the default ReadParams and customize it with
|
||||
// the options from the OpenTableBuilder
|
||||
let read_params = request.lance_read_params.unwrap_or_else(|| {
|
||||
let mut default_params = ReadParams::default();
|
||||
if let Some(index_cache_size) = request.index_cache_size {
|
||||
default_params.index_cache_size = index_cache_size as usize;
|
||||
}
|
||||
default_params
|
||||
});
|
||||
|
||||
let native_table = Arc::new(
|
||||
NativeTable::open_with_params(
|
||||
&table_uri,
|
||||
&request.name,
|
||||
self.store_wrapper.clone(),
|
||||
Some(read_params),
|
||||
self.read_consistency_interval,
|
||||
)
|
||||
.await?,
|
||||
);
|
||||
Ok(native_table)
|
||||
}
|
||||
|
||||
async fn rename_table(&self, _old_name: &str, _new_name: &str) -> Result<()> {
|
||||
Err(Error::NotSupported {
|
||||
message: "rename_table is not supported in LanceDB OSS".to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
async fn drop_table(&self, name: &str) -> Result<()> {
|
||||
let dir_name = format!("{}.{}", name, LANCE_EXTENSION);
|
||||
let full_path = self.base_path.child(dir_name.clone());
|
||||
self.object_store
|
||||
.remove_dir_all(full_path.clone())
|
||||
.await
|
||||
.map_err(|err| match err {
|
||||
                // this error is not lance::Error::DatasetNotFound,
                // as the method `remove_dir_all` may be used to remove something that is not a dataset
                lance::Error::NotFound { .. } => Error::TableNotFound {
                    name: name.to_owned(),
                },
                _ => Error::from(err),
            })?;

        let object_store_params = ObjectStoreParams {
            storage_options: Some(self.storage_options.clone()),
            ..Default::default()
        };
        let mut uri = self.uri.clone();
        if let Some(query_string) = &self.query_string {
            uri.push_str(&format!("?{}", query_string));
        }
        let commit_handler = commit_handler_from_url(&uri, &Some(object_store_params))
            .await
            .unwrap();
        commit_handler.delete(&full_path).await.unwrap();
        Ok(())
    }

    async fn drop_all_tables(&self) -> Result<()> {
        self.object_store
            .remove_dir_all(self.base_path.clone())
            .await?;
        Ok(())
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}
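The options plumbing above is worth a second look: ListingDatabaseOptions rides inside the same string map as object-store options, with parse_from_map (crate-internal) recovering the typed config and serialize_into_map writing it back. A small sketch of the public half, under the assumption that the listing module is exported as lancedb::database::listing:

use std::collections::HashMap;

use lancedb::database::listing::{
    ListingDatabaseOptions, NewTableConfig, OPT_NEW_TABLE_V2_MANIFEST_PATHS,
};
use lancedb::database::DatabaseOptions;

fn main() {
    let options = ListingDatabaseOptions {
        new_table_config: NewTableConfig {
            data_storage_version: None, // keep the default (latest stable)
            enable_v2_manifest_paths: Some(true),
        },
    };
    // Serialize into the generic storage-options map the connection carries.
    let mut map = HashMap::new();
    options.serialize_into_map(&mut map);
    assert_eq!(
        map.get(OPT_NEW_TABLE_V2_MANIFEST_PATHS).map(String::as_str),
        Some("true")
    );
}
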
@@ -15,6 +15,8 @@ pub enum Error {
    InvalidInput { message: String },
    #[snafu(display("Table '{name}' was not found"))]
    TableNotFound { name: String },
    #[snafu(display("Index '{name}' was not found"))]
    IndexNotFound { name: String },
    #[snafu(display("Embedding function '{name}' was not found. : {reason}"))]
    EmbeddingFunctionNotFound { name: String, reason: String },


@@ -8,7 +8,7 @@ use serde::Deserialize;
use serde_with::skip_serializing_none;
use vector::IvfFlatIndexBuilder;

use crate::{table::TableInternal, DistanceType, Error, Result};
use crate::{table::BaseTable, DistanceType, Error, Result};

use self::{
    scalar::{BTreeIndexBuilder, BitmapIndexBuilder, LabelListIndexBuilder},
@@ -65,14 +65,14 @@ pub enum Index {
///
/// The methods on this builder are used to specify options common to all indices.
pub struct IndexBuilder {
    parent: Arc<dyn TableInternal>,
    parent: Arc<dyn BaseTable>,
    pub(crate) index: Index,
    pub(crate) columns: Vec<String>,
    pub(crate) replace: bool,
}

impl IndexBuilder {
    pub(crate) fn new(parent: Arc<dyn TableInternal>, columns: Vec<String>, index: Index) -> Self {
    pub(crate) fn new(parent: Arc<dyn BaseTable>, columns: Vec<String>, index: Index) -> Self {
        Self {
            parent,
            index,

@@ -23,7 +23,19 @@ impl VectorIndex {
        let fields = index
            .fields
            .iter()
            .map(|i| manifest.schema.fields[*i as usize].name.clone())
            .map(|field_id| {
                manifest
                    .schema
                    .field_by_id(*field_id)
                    .unwrap_or_else(|| {
                        panic!(
                            "field {field_id} of index {} must exist in schema",
                            index.name
                        )
                    })
                    .name
                    .clone()
            })
            .collect();
        Self {
            columns: fields,

@@ -193,6 +193,7 @@
pub mod arrow;
pub mod connection;
pub mod data;
pub mod database;
pub mod embeddings;
pub mod error;
pub mod index;

@@ -20,12 +20,12 @@ use lance_index::scalar::FullTextSearchQuery;
use lance_index::vector::DIST_COL;
use lance_io::stream::RecordBatchStreamAdapter;

use crate::arrow::SendableRecordBatchStream;
use crate::error::{Error, Result};
use crate::rerankers::rrf::RRFReranker;
use crate::rerankers::{check_reranker_result, NormalizeMethod, Reranker};
use crate::table::TableInternal;
use crate::table::BaseTable;
use crate::DistanceType;
use crate::{arrow::SendableRecordBatchStream, table::AnyQuery};

mod hybrid;

@@ -449,7 +449,7 @@ pub trait QueryBase {
}

pub trait HasQuery {
    fn mut_query(&mut self) -> &mut Query;
    fn mut_query(&mut self) -> &mut QueryRequest;
}

impl<T: HasQuery> QueryBase for T {
@@ -577,6 +577,65 @@ pub trait ExecutableQuery {
    fn explain_plan(&self, verbose: bool) -> impl Future<Output = Result<String>> + Send;
}

/// A basic query into a table without any kind of search
///
/// This will result in a (potentially filtered) scan if executed
#[derive(Debug, Clone)]
pub struct QueryRequest {
    /// limit the number of rows to return.
    pub limit: Option<usize>,

    /// Offset of the query.
    pub offset: Option<usize>,

    /// Apply filter to the returned rows.
    pub filter: Option<String>,

    /// Perform a full text search on the table.
    pub full_text_search: Option<FullTextSearchQuery>,

    /// Select column projection.
    pub select: Select,

    /// If set to true, the query is executed only on the indexed data,
    /// and yields faster results.
    ///
    /// By default, this is false.
    pub fast_search: bool,

    /// If set to true, the query will return the `_rowid` meta column.
    ///
    /// By default, this is false.
    pub with_row_id: bool,

    /// If set to false, the filter will be applied after the vector search.
    pub prefilter: bool,

    /// Implementation of reranker that can be used to reorder or combine query
    /// results, especially if using hybrid search
    pub reranker: Option<Arc<dyn Reranker>>,

    /// Configure how query results are normalized when doing hybrid search
    pub norm: Option<NormalizeMethod>,
}

impl Default for QueryRequest {
    fn default() -> Self {
        Self {
            limit: Some(DEFAULT_TOP_K),
            offset: None,
            filter: None,
            full_text_search: None,
            select: Select::All,
            fast_search: false,
            with_row_id: false,
            prefilter: true,
            reranker: None,
            norm: None,
        }
    }
}

/// A builder for LanceDB queries.
///
/// See [`crate::Table::query`] for more details on queries
@@ -591,59 +650,15 @@ pub trait ExecutableQuery {
/// times.
#[derive(Debug, Clone)]
pub struct Query {
    parent: Arc<dyn TableInternal>,

    /// limit the number of rows to return.
    pub limit: Option<usize>,

    /// Offset of the query.
    pub(crate) offset: Option<usize>,

    /// Apply filter to the returned rows.
    pub(crate) filter: Option<String>,

    /// Perform a full text search on the table.
    pub(crate) full_text_search: Option<FullTextSearchQuery>,

    /// Select column projection.
    pub(crate) select: Select,

    /// If set to true, the query is executed only on the indexed data,
    /// and yields faster results.
    ///
    /// By default, this is false.
    pub(crate) fast_search: bool,

    /// If set to true, the query will return the `_rowid` meta column.
    ///
    /// By default, this is false.
    pub with_row_id: bool,

    /// If set to false, the filter will be applied after the vector search.
    pub(crate) prefilter: bool,

    /// Implementation of reranker that can be used to reorder or combine query
    /// results, especially if using hybrid search
    pub(crate) reranker: Option<Arc<dyn Reranker>>,

    /// Configure how query results are normalized when doing hybrid search
    pub(crate) norm: Option<NormalizeMethod>,
    parent: Arc<dyn BaseTable>,
    request: QueryRequest,
}

impl Query {
    pub(crate) fn new(parent: Arc<dyn TableInternal>) -> Self {
    pub(crate) fn new(parent: Arc<dyn BaseTable>) -> Self {
        Self {
            parent,
            limit: Some(DEFAULT_TOP_K),
            offset: None,
            filter: None,
            full_text_search: None,
            select: Select::All,
            fast_search: false,
            with_row_id: false,
            prefilter: true,
            reranker: None,
            norm: None,
            request: QueryRequest::default(),
        }
    }

@@ -691,38 +706,98 @@ impl Query {
    pub fn nearest_to(self, vector: impl IntoQueryVector) -> Result<VectorQuery> {
        let mut vector_query = self.into_vector();
        let query_vector = vector.to_query_vector(&DataType::Float32, "default")?;
        vector_query.query_vector.push(query_vector);
        vector_query.request.query_vector.push(query_vector);
        Ok(vector_query)
    }

    pub fn into_request(self) -> QueryRequest {
        self.request
    }

    pub fn current_request(&self) -> &QueryRequest {
        &self.request
    }
}

impl HasQuery for Query {
    fn mut_query(&mut self) -> &mut Query {
        self
    fn mut_query(&mut self) -> &mut QueryRequest {
        &mut self.request
    }
}

impl ExecutableQuery for Query {
    async fn create_plan(&self, options: QueryExecutionOptions) -> Result<Arc<dyn ExecutionPlan>> {
        self.parent
            .clone()
            .create_plan(&self.clone().into_vector(), options)
            .await
        let req = AnyQuery::Query(self.request.clone());
        self.parent.clone().create_plan(&req, options).await
    }

    async fn execute_with_options(
        &self,
        options: QueryExecutionOptions,
    ) -> Result<SendableRecordBatchStream> {
        let query = AnyQuery::Query(self.request.clone());
        Ok(SendableRecordBatchStream::from(
            self.parent.clone().plain_query(self, options).await?,
            self.parent.clone().query(&query, options).await?,
        ))
    }

    async fn explain_plan(&self, verbose: bool) -> Result<String> {
        self.parent
            .explain_plan(&self.clone().into_vector(), verbose)
            .await
        let query = AnyQuery::Query(self.request.clone());
        self.parent.explain_plan(&query, verbose).await
    }
}

/// A request for a nearest-neighbors search into a table
#[derive(Debug, Clone)]
pub struct VectorQueryRequest {
    /// The base query
    pub base: QueryRequest,
    /// The column to run the search on
    ///
    /// If None, then the table will need to auto-detect which column to use
    pub column: Option<String>,
    /// The vector(s) to search for
    pub query_vector: Vec<Arc<dyn Array>>,
    /// The number of partitions to search
    pub nprobes: usize,
    /// The lower bound (inclusive) of the distance to search for.
    pub lower_bound: Option<f32>,
    /// The upper bound (exclusive) of the distance to search for.
    pub upper_bound: Option<f32>,
    /// The number of candidates to return during the refine step for HNSW,
    /// defaults to 1.5 * limit.
    pub ef: Option<usize>,
    /// A multiplier to control how many additional rows are taken during the refine step
    pub refine_factor: Option<u32>,
    /// The distance type to use for the search
    pub distance_type: Option<DistanceType>,
    /// Default is true. Set to false to enforce a brute force search.
    pub use_index: bool,
}

impl Default for VectorQueryRequest {
    fn default() -> Self {
        Self {
            base: QueryRequest::default(),
            column: None,
            query_vector: Vec::new(),
            nprobes: 20,
            lower_bound: None,
            upper_bound: None,
            ef: None,
            refine_factor: None,
            distance_type: None,
            use_index: true,
        }
    }
}

impl VectorQueryRequest {
    pub fn from_plain_query(query: QueryRequest) -> Self {
        Self {
            base: query,
            ..Default::default()
        }
    }
}

@@ -737,39 +812,30 @@ impl ExecutableQuery for Query {
/// the query and retrieve results.
#[derive(Debug, Clone)]
pub struct VectorQuery {
    pub(crate) base: Query,
    // The column to run the query on. If not specified, we will attempt to guess
    // the column based on the dataset's schema.
    pub(crate) column: Option<String>,
    // IVF PQ - ANN search.
    pub(crate) query_vector: Vec<Arc<dyn Array>>,
    pub(crate) nprobes: usize,
    // The lower bound (inclusive) of the distance to search for.
    pub(crate) lower_bound: Option<f32>,
    // The upper bound (exclusive) of the distance to search for.
    pub(crate) upper_bound: Option<f32>,
    // The number of candidates to return during the refine step for HNSW,
    // defaults to 1.5 * limit.
    pub(crate) ef: Option<usize>,
    pub(crate) refine_factor: Option<u32>,
    pub(crate) distance_type: Option<DistanceType>,
    /// Default is true. Set to false to enforce a brute force search.
    pub(crate) use_index: bool,
    parent: Arc<dyn BaseTable>,
    request: VectorQueryRequest,
}

impl VectorQuery {
    fn new(base: Query) -> Self {
        Self {
            base,
            column: None,
            query_vector: Vec::new(),
            nprobes: 20,
            lower_bound: None,
            upper_bound: None,
            ef: None,
            refine_factor: None,
            distance_type: None,
            use_index: true,
            parent: base.parent,
            request: VectorQueryRequest::from_plain_query(base.request),
        }
    }

    pub fn into_request(self) -> VectorQueryRequest {
        self.request
    }

    pub fn current_request(&self) -> &VectorQueryRequest {
        &self.request
    }

    pub fn into_plain(self) -> Query {
        Query {
            parent: self.parent,
            request: self.request.base,
        }
    }

@@ -781,7 +847,7 @@ impl VectorQuery {
    /// This parameter must be specified if the table has more than one column
    /// whose data type is a fixed-size-list of floats.
    pub fn column(mut self, column: &str) -> Self {
        self.column = Some(column.to_string());
        self.request.column = Some(column.to_string());
        self
    }

@@ -797,7 +863,7 @@ impl VectorQuery {
    /// result.
    pub fn add_query_vector(mut self, vector: impl IntoQueryVector) -> Result<Self> {
        let query_vector = vector.to_query_vector(&DataType::Float32, "default")?;
        self.query_vector.push(query_vector);
        self.request.query_vector.push(query_vector);
        Ok(self)
    }

@@ -822,15 +888,15 @@ impl VectorQuery {
    /// your actual data to find the smallest possible value that will still give
    /// you the desired recall.
    pub fn nprobes(mut self, nprobes: usize) -> Self {
        self.nprobes = nprobes;
        self.request.nprobes = nprobes;
        self
    }

    /// Set the distance range for vector search,
    /// only rows with distances in the range [lower_bound, upper_bound) will be returned
    pub fn distance_range(mut self, lower_bound: Option<f32>, upper_bound: Option<f32>) -> Self {
        self.lower_bound = lower_bound;
        self.upper_bound = upper_bound;
        self.request.lower_bound = lower_bound;
        self.request.upper_bound = upper_bound;
        self
    }

@@ -842,7 +908,7 @@ impl VectorQuery {
    /// Increasing this value will increase the recall of your query but will
    /// also increase the latency of your query. The default value is 1.5*limit.
    pub fn ef(mut self, ef: usize) -> Self {
        self.ef = Some(ef);
        self.request.ef = Some(ef);
        self
    }

@@ -874,7 +940,7 @@ impl VectorQuery {
    /// and the quantized result vectors. This can be considerably different than the true
    /// distance between the query vector and the actual uncompressed vector.
    pub fn refine_factor(mut self, refine_factor: u32) -> Self {
        self.refine_factor = Some(refine_factor);
        self.request.refine_factor = Some(refine_factor);
        self
    }

@@ -891,7 +957,7 @@ impl VectorQuery {
    ///
    /// By default [`DistanceType::L2`] is used.
    pub fn distance_type(mut self, distance_type: DistanceType) -> Self {
        self.distance_type = Some(distance_type);
        self.request.distance_type = Some(distance_type);
        self
    }

@@ -903,16 +969,19 @@ impl VectorQuery {
    /// the vector index can give you ground truth results which you can use to
    /// calculate your recall to select an appropriate value for nprobes.
    pub fn bypass_vector_index(mut self) -> Self {
        self.use_index = false;
        self.request.use_index = false;
        self
    }

    pub async fn execute_hybrid(&self) -> Result<SendableRecordBatchStream> {
        // clone query and specify we want to include row IDs, which can be needed for reranking
        let fts_query = self.base.clone().with_row_id();
        let mut fts_query = Query::new(self.parent.clone());
        fts_query.request = self.request.base.clone();
        fts_query = fts_query.with_row_id();

        let mut vector_query = self.clone().with_row_id();

        vector_query.base.full_text_search = None;
        vector_query.request.base.full_text_search = None;
        let (fts_results, vec_results) = try_join!(fts_query.execute(), vector_query.execute())?;

        let (fts_results, vec_results) = try_join!(
@@ -928,7 +997,7 @@ impl VectorQuery {
        let mut fts_results = concat_batches(&fts_schema, fts_results.iter())?;
        let mut vec_results = concat_batches(&vec_schema, vec_results.iter())?;

        if matches!(self.base.norm, Some(NormalizeMethod::Rank)) {
        if matches!(self.request.base.norm, Some(NormalizeMethod::Rank)) {
            vec_results = hybrid::rank(vec_results, DIST_COL, None)?;
            fts_results = hybrid::rank(fts_results, SCORE_COL, None)?;
        }
@@ -937,14 +1006,20 @@ impl VectorQuery {
        fts_results = hybrid::normalize_scores(fts_results, SCORE_COL, None)?;

        let reranker = self
            .request
            .base
            .reranker
            .clone()
            .unwrap_or(Arc::new(RRFReranker::default()));

        let fts_query = self.base.full_text_search.as_ref().ok_or(Error::Runtime {
            message: "there should be an FTS search".to_string(),
        })?;
        let fts_query = self
            .request
            .base
            .full_text_search
            .as_ref()
            .ok_or(Error::Runtime {
                message: "there should be an FTS search".to_string(),
            })?;

        let mut results = reranker
            .rerank_hybrid(&fts_query.query, vec_results, fts_results)
@@ -952,12 +1027,12 @@ impl VectorQuery {

        check_reranker_result(&results)?;

        let limit = self.base.limit.unwrap_or(DEFAULT_TOP_K);
        let limit = self.request.base.limit.unwrap_or(DEFAULT_TOP_K);
        if results.num_rows() > limit {
            results = results.slice(0, limit);
        }

        if !self.base.with_row_id {
        if !self.request.base.with_row_id {
            results = results.drop_column(ROW_ID)?;
        }

@@ -969,14 +1044,15 @@ impl VectorQuery {

impl ExecutableQuery for VectorQuery {
    async fn create_plan(&self, options: QueryExecutionOptions) -> Result<Arc<dyn ExecutionPlan>> {
        self.base.parent.clone().create_plan(self, options).await
        let query = AnyQuery::VectorQuery(self.request.clone());
        self.parent.clone().create_plan(&query, options).await
    }

    async fn execute_with_options(
        &self,
        options: QueryExecutionOptions,
    ) -> Result<SendableRecordBatchStream> {
        if self.base.full_text_search.is_some() {
        if self.request.base.full_text_search.is_some() {
            let hybrid_result = async move { self.execute_hybrid().await }.boxed().await?;
            return Ok(hybrid_result);
        }
@@ -990,13 +1066,14 @@ impl ExecutableQuery for VectorQuery {
    }

    async fn explain_plan(&self, verbose: bool) -> Result<String> {
        self.base.parent.explain_plan(self, verbose).await
        let query = AnyQuery::VectorQuery(self.request.clone());
        self.parent.explain_plan(&query, verbose).await
    }
}

impl HasQuery for VectorQuery {
    fn mut_query(&mut self) -> &mut Query {
        &mut self.base
    fn mut_query(&mut self) -> &mut QueryRequest {
        &mut self.request.base
    }
}

@@ -1015,7 +1092,7 @@ mod tests {
    use lance_testing::datagen::{BatchGenerator, IncrementingInt32, RandomVector};
    use tempfile::tempdir;

    use crate::{connect, connection::CreateTableMode, Table};
    use crate::{connect, database::CreateTableMode, Table};

    #[tokio::test]
    async fn test_setters_getters() {
@@ -1036,7 +1113,13 @@ mod tests {
        let vector = Float32Array::from_iter_values([0.1, 0.2]);
        let query = table.query().nearest_to(&[0.1, 0.2]).unwrap();
        assert_eq!(
            *query.query_vector.first().unwrap().as_ref().as_primitive(),
            *query
                .request
                .query_vector
                .first()
                .unwrap()
                .as_ref()
                .as_primitive(),
            vector
        );

@@ -1054,15 +1137,21 @@ mod tests {
            .refine_factor(999);

        assert_eq!(
            *query.query_vector.first().unwrap().as_ref().as_primitive(),
            *query
                .request
                .query_vector
                .first()
                .unwrap()
                .as_ref()
                .as_primitive(),
            new_vector
        );
        assert_eq!(query.base.limit.unwrap(), 100);
        assert_eq!(query.base.offset.unwrap(), 1);
        assert_eq!(query.nprobes, 1000);
        assert!(query.use_index);
        assert_eq!(query.distance_type, Some(DistanceType::Cosine));
        assert_eq!(query.refine_factor, Some(999));
        assert_eq!(query.request.base.limit.unwrap(), 100);
        assert_eq!(query.request.base.offset.unwrap(), 1);
        assert_eq!(query.request.nprobes, 1000);
        assert!(query.request.use_index);
        assert_eq!(query.request.distance_type, Some(DistanceType::Cosine));
        assert_eq!(query.request.refine_factor, Some(999));
    }

    #[tokio::test]

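The practical upshot of this refactor is that builder methods now just write into a QueryRequest / VectorQueryRequest, which callers can inspect with current_request() or detach with into_request(). A brief usage sketch (table setup elsewhere; the limit/nprobes values are illustrative):

use lancedb::query::{ExecutableQuery, QueryBase};
use lancedb::{DistanceType, Table};

async fn run(table: &Table) -> lancedb::Result<()> {
    let query = table
        .query()
        .limit(10)
        .nearest_to(&[0.1f32, 0.2])?
        .nprobes(32)
        .distance_type(DistanceType::Cosine);
    // The accumulated state is now plain data:
    let request = query.current_request();
    assert_eq!(request.nprobes, 32);
    assert_eq!(request.base.limit, Some(10));
    let _results = query.execute().await?;
    Ok(())
}
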
@@ -14,6 +14,7 @@ pub(crate) mod util;
const ARROW_STREAM_CONTENT_TYPE: &str = "application/vnd.apache.arrow.stream";
#[cfg(test)]
const ARROW_FILE_CONTENT_TYPE: &str = "application/vnd.apache.arrow.file";
#[cfg(test)]
const JSON_CONTENT_TYPE: &str = "application/json";

pub use client::{ClientConfig, RetryConfig, TimeoutConfig};

@@ -1,8 +1,9 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors

use std::{future::Future, time::Duration};
use std::{collections::HashMap, future::Future, str::FromStr, time::Duration};

use http::HeaderName;
use log::debug;
use reqwest::{
    header::{HeaderMap, HeaderValue},
@@ -15,7 +16,7 @@ use crate::remote::db::RemoteOptions;
const REQUEST_ID_HEADER: &str = "x-request-id";

/// Configuration for the LanceDB Cloud HTTP client.
#[derive(Debug)]
#[derive(Clone, Debug)]
pub struct ClientConfig {
    pub timeout_config: TimeoutConfig,
    pub retry_config: RetryConfig,
@@ -23,6 +24,7 @@ pub struct ClientConfig {
    /// name and version.
    pub user_agent: String,
    // TODO: how to configure request ids?
    pub extra_headers: HashMap<String, String>,
}

impl Default for ClientConfig {
@@ -31,12 +33,13 @@ impl Default for ClientConfig {
            timeout_config: TimeoutConfig::default(),
            retry_config: RetryConfig::default(),
            user_agent: concat!("LanceDB-Rust-Client/", env!("CARGO_PKG_VERSION")).into(),
            extra_headers: HashMap::new(),
        }
    }
}

/// How to handle timeouts for HTTP requests.
#[derive(Default, Debug)]
#[derive(Clone, Default, Debug)]
pub struct TimeoutConfig {
    /// The timeout for creating a connection to the server.
    ///
@@ -62,7 +65,7 @@ pub struct TimeoutConfig {
}

/// How to handle retries for HTTP requests.
#[derive(Default, Debug)]
#[derive(Clone, Default, Debug)]
pub struct RetryConfig {
    /// The number of times to retry a request if it fails.
    ///
@@ -256,6 +259,7 @@ impl RestfulLanceDbClient<Sender> {
                host_override.is_some(),
                options,
                db_prefix,
                &client_config,
            )?)
            .user_agent(client_config.user_agent)
            .build()
@@ -291,6 +295,7 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
        has_host_override: bool,
        options: &RemoteOptions,
        db_prefix: Option<&str>,
        config: &ClientConfig,
    ) -> Result<HeaderMap> {
        let mut headers = HeaderMap::new();
        headers.insert(
@@ -345,6 +350,18 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
            );
        }

        for (key, value) in &config.extra_headers {
            let key_parsed = HeaderName::from_str(key).map_err(|_| Error::InvalidInput {
                message: format!("non-ascii value for header '{}' provided", key),
            })?;
            headers.insert(
                key_parsed,
                HeaderValue::from_str(value).map_err(|_| Error::InvalidInput {
                    message: format!("non-ascii value for header '{}' provided", key),
                })?,
            );
        }

        Ok(headers)
    }


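The new extra_headers field shown above lets deployments attach fixed headers (proxy tokens, tracing IDs) to every request the remote client sends; invalid names or values are rejected up front with Error::InvalidInput. A short sketch, assuming the remote feature is enabled and using an illustrative header name:

use std::collections::HashMap;

use lancedb::remote::ClientConfig;

fn make_config() -> ClientConfig {
    let mut extra_headers = HashMap::new();
    extra_headers.insert("x-my-proxy-token".to_string(), "abc123".to_string());
    ClientConfig {
        extra_headers,
        ..Default::default()
    }
}
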
@@ -4,7 +4,7 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow_array::RecordBatchReader;
|
||||
use arrow_array::RecordBatchIterator;
|
||||
use async_trait::async_trait;
|
||||
use http::StatusCode;
|
||||
use lance_io::object_store::StorageOptions;
|
||||
@@ -13,13 +13,12 @@ use reqwest::header::CONTENT_TYPE;
|
||||
use serde::Deserialize;
|
||||
use tokio::task::spawn_blocking;
|
||||
|
||||
use crate::connection::{
|
||||
ConnectionInternal, CreateTableBuilder, CreateTableMode, NoData, OpenTableBuilder,
|
||||
TableNamesBuilder,
|
||||
use crate::database::{
|
||||
CreateTableData, CreateTableMode, CreateTableRequest, Database, OpenTableRequest,
|
||||
TableNamesRequest,
|
||||
};
|
||||
use crate::embeddings::EmbeddingRegistry;
|
||||
use crate::error::Result;
|
||||
use crate::Table;
|
||||
use crate::table::BaseTable;
|
||||
|
||||
use super::client::{ClientConfig, HttpSend, RequestResultExt, RestfulLanceDbClient, Sender};
|
||||
use super::table::RemoteTable;
@@ -105,13 +104,13 @@ impl From<&CreateTableMode> for &'static str {
}

#[async_trait]
impl<S: HttpSend> ConnectionInternal for RemoteDatabase<S> {
    async fn table_names(&self, options: TableNamesBuilder) -> Result<Vec<String>> {
impl<S: HttpSend> Database for RemoteDatabase<S> {
    async fn table_names(&self, request: TableNamesRequest) -> Result<Vec<String>> {
        let mut req = self.client.get("/v1/table/");
        if let Some(limit) = options.limit {
        if let Some(limit) = request.limit {
            req = req.query(&[("limit", limit)]);
        }
        if let Some(start_after) = options.start_after {
        if let Some(start_after) = request.start_after {
            req = req.query(&[("page_token", start_after)]);
        }
        let (request_id, rsp) = self.client.send(req, true).await?;
@@ -127,11 +126,15 @@ impl<S: HttpSend> ConnectionInternal for RemoteDatabase<S> {
        Ok(tables)
    }

    async fn do_create_table(
        &self,
        options: CreateTableBuilder<false, NoData>,
        data: Box<dyn RecordBatchReader + Send>,
    ) -> Result<Table> {
    async fn create_table(&self, request: CreateTableRequest) -> Result<Arc<dyn BaseTable>> {
        let data = match request.data {
            CreateTableData::Data(data) => data,
            CreateTableData::Empty(table_definition) => {
                let schema = table_definition.schema.clone();
                Box::new(RecordBatchIterator::new(vec![], schema))
            }
        };

        // TODO: https://github.com/lancedb/lancedb/issues/1026
        // We should accept data from an async source. In the meantime, spawn this as blocking
        // to make sure we don't block the tokio runtime if the source is slow.
@@ -141,8 +144,8 @@ impl<S: HttpSend> ConnectionInternal for RemoteDatabase<S> {

        let req = self
            .client
            .post(&format!("/v1/table/{}/create/", options.name))
            .query(&[("mode", Into::<&str>::into(&options.mode))])
            .post(&format!("/v1/table/{}/create/", request.name))
            .query(&[("mode", Into::<&str>::into(&request.mode))])
            .body(data_buffer)
            .header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE);

@@ -151,14 +154,18 @@ impl<S: HttpSend> ConnectionInternal for RemoteDatabase<S> {
        if rsp.status() == StatusCode::BAD_REQUEST {
            let body = rsp.text().await.err_to_http(request_id.clone())?;
            if body.contains("already exists") {
                return match options.mode {
                return match request.mode {
                    CreateTableMode::Create => {
                        Err(crate::Error::TableAlreadyExists { name: options.name })
                        Err(crate::Error::TableAlreadyExists { name: request.name })
                    }
                    CreateTableMode::ExistOk(callback) => {
                        let builder = OpenTableBuilder::new(options.parent, options.name);
                        let builder = (callback)(builder);
                        builder.execute().await
                        let req = OpenTableRequest {
                            name: request.name.clone(),
                            index_cache_size: None,
                            lance_read_params: None,
                        };
                        let req = (callback)(req);
                        self.open_table(req).await
                    }

                    // This should not happen, as we explicitly set the mode to overwrite and the server
@@ -183,31 +190,31 @@ impl<S: HttpSend> ConnectionInternal for RemoteDatabase<S> {

        self.client.check_response(&request_id, rsp).await?;

        self.table_cache.insert(options.name.clone(), ()).await;
        self.table_cache.insert(request.name.clone(), ()).await;

        Ok(Table::new(Arc::new(RemoteTable::new(
        Ok(Arc::new(RemoteTable::new(
            self.client.clone(),
            options.name,
        ))))
            request.name,
        )))
    }

    async fn do_open_table(&self, options: OpenTableBuilder) -> Result<Table> {
    async fn open_table(&self, request: OpenTableRequest) -> Result<Arc<dyn BaseTable>> {
        // We describe the table to confirm it exists before moving on.
        if self.table_cache.get(&options.name).is_none() {
        if self.table_cache.get(&request.name).await.is_none() {
            let req = self
                .client
                .post(&format!("/v1/table/{}/describe/", options.name));
                .post(&format!("/v1/table/{}/describe/", request.name));
            let (request_id, resp) = self.client.send(req, true).await?;
            if resp.status() == StatusCode::NOT_FOUND {
                return Err(crate::Error::TableNotFound { name: options.name });
                return Err(crate::Error::TableNotFound { name: request.name });
            }
            self.client.check_response(&request_id, resp).await?;
        }

        Ok(Table::new(Arc::new(RemoteTable::new(
        Ok(Arc::new(RemoteTable::new(
            self.client.clone(),
            options.name,
        ))))
            request.name,
        )))
    }

    async fn rename_table(&self, current_name: &str, new_name: &str) -> Result<()> {
@@ -230,14 +237,14 @@ impl<S: HttpSend> ConnectionInternal for RemoteDatabase<S> {
        Ok(())
    }

    async fn drop_db(&self) -> Result<()> {
    async fn drop_all_tables(&self) -> Result<()> {
        Err(crate::Error::NotSupported {
            message: "Dropping databases is not supported in the remote API".to_string(),
        })
    }

    fn embedding_registry(&self) -> &dyn EmbeddingRegistry {
        todo!()
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}
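
Editorial sketch: the builder-based open path is replaced by a plain request struct; the fields below are exactly the ones this diff constructs, while the import paths are assumptions based on the new `crate::database` module.

use std::sync::Arc;
use lancedb::database::{Database, OpenTableRequest};
use lancedb::table::BaseTable;

async fn open_my_table(db: &dyn Database) -> lancedb::Result<Arc<dyn BaseTable>> {
    db.open_table(OpenTableRequest {
        name: "my_table".to_string(),
        index_cache_size: None,
        lance_read_params: None,
    })
    .await
}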

@@ -273,7 +280,7 @@ mod tests {

    use crate::connection::ConnectBuilder;
    use crate::{
        connection::CreateTableMode,
        database::CreateTableMode,
        remote::{ARROW_STREAM_CONTENT_TYPE, JSON_CONTENT_TYPE},
        Connection, Error,
    };

@@ -2,21 +2,22 @@
// SPDX-FileCopyrightText: Copyright The LanceDB Authors

use std::io::Cursor;
use std::pin::Pin;
use std::sync::{Arc, Mutex};

use crate::index::Index;
use crate::index::IndexStatistics;
use crate::query::Select;
use crate::table::AddDataMode;
use crate::query::{QueryRequest, Select, VectorQueryRequest};
use crate::table::{AddDataMode, AnyQuery, Filter};
use crate::utils::{supported_btree_data_type, supported_vector_data_type};
use crate::{DistanceType, Error, Table};
use crate::{DistanceType, Error};
use arrow_array::RecordBatchReader;
use arrow_ipc::reader::FileReader;
use arrow_schema::{DataType, SchemaRef};
use async_trait::async_trait;
use datafusion_common::DataFusionError;
use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
use datafusion_physical_plan::{ExecutionPlan, SendableRecordBatchStream};
use datafusion_physical_plan::{ExecutionPlan, RecordBatchStream, SendableRecordBatchStream};
use futures::TryStreamExt;
use http::header::CONTENT_TYPE;
use http::StatusCode;
@@ -31,16 +32,16 @@ use crate::{
    connection::NoData,
    error::Result,
    index::{IndexBuilder, IndexConfig},
    query::{Query, QueryExecutionOptions, VectorQuery},
    query::QueryExecutionOptions,
    table::{
        merge::MergeInsertBuilder, AddDataBuilder, NativeTable, OptimizeAction, OptimizeStats,
        TableDefinition, TableInternal, UpdateBuilder,
        merge::MergeInsertBuilder, AddDataBuilder, BaseTable, OptimizeAction, OptimizeStats,
        TableDefinition, UpdateBuilder,
    },
};

use super::client::RequestResultExt;
use super::client::{HttpSend, RestfulLanceDbClient, Sender};
use super::{ARROW_STREAM_CONTENT_TYPE, JSON_CONTENT_TYPE};
use super::ARROW_STREAM_CONTENT_TYPE;

#[derive(Debug)]
pub struct RemoteTable<S: HttpSend = Sender> {
@@ -147,7 +148,7 @@ impl<S: HttpSend> RemoteTable<S> {
        Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
    }

    fn apply_query_params(body: &mut serde_json::Value, params: &Query) -> Result<()> {
    fn apply_query_params(body: &mut serde_json::Value, params: &QueryRequest) -> Result<()> {
        if let Some(offset) = params.offset {
            body["offset"] = serde_json::Value::Number(serde_json::Number::from(offset));
        }
@@ -204,10 +205,10 @@ impl<S: HttpSend> RemoteTable<S> {
    }

    fn apply_vector_query_params(
        mut body: serde_json::Value,
        query: &VectorQuery,
    ) -> Result<Vec<serde_json::Value>> {
        Self::apply_query_params(&mut body, &query.base)?;
        body: &mut serde_json::Value,
        query: &VectorQueryRequest,
    ) -> Result<()> {
        Self::apply_query_params(body, &query.base)?;

        // Apply general parameters, before we dispatch based on number of query vectors.
        body["prefilter"] = query.base.prefilter.into();
@@ -253,22 +254,21 @@ impl<S: HttpSend> RemoteTable<S> {
            0 => {
                // Server takes empty vector, not null or undefined.
                body["vector"] = serde_json::Value::Array(Vec::new());
                Ok(vec![body])
            }
            1 => {
                body["vector"] = vector_to_json(&query.query_vector[0])?;
                Ok(vec![body])
            }
            _ => {
                let mut bodies = Vec::with_capacity(query.query_vector.len());
                for vector in &query.query_vector {
                    let mut body = body.clone();
                    body["vector"] = vector_to_json(vector)?;
                    bodies.push(body);
                }
                Ok(bodies)
                let vectors = query
                    .query_vector
                    .iter()
                    .map(vector_to_json)
                    .collect::<Result<Vec<_>>>()?;
                body["vector"] = serde_json::Value::Array(vectors);
            }
        }

        Ok(())
    }

    async fn check_mutable(&self) -> Result<()> {
@@ -288,6 +288,33 @@ impl<S: HttpSend> RemoteTable<S> {
        let read_guard = self.version.read().await;
        *read_guard
    }

    async fn execute_query(
        &self,
        query: &AnyQuery,
        _options: QueryExecutionOptions,
    ) -> Result<Pin<Box<dyn RecordBatchStream + Send>>> {
        let request = self.client.post(&format!("/v1/table/{}/query/", self.name));

        let version = self.current_version().await;
        let mut body = serde_json::json!({ "version": version });

        match query {
            AnyQuery::Query(query) => {
                Self::apply_query_params(&mut body, query)?;
                // Empty vector can be passed if no vector search is performed.
                body["vector"] = serde_json::Value::Array(Vec::new());
            }
            AnyQuery::VectorQuery(query) => {
                Self::apply_vector_query_params(&mut body, query)?;
            }
        }

        let request = request.json(&body);
        let (request_id, response) = self.client.send(request, true).await?;
        let stream = self.read_arrow_stream(&request_id, response).await?;
        Ok(stream)
    }
}
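
Editorial sketch of the wire format implied above: a plain query always carries an empty "vector", and a multi-vector search now ships all vectors in one body as an array of arrays (previously it was one request per vector).

fn demo() {
    // Plain query: the server expects "vector": [].
    let mut plain = serde_json::json!({ "version": 42 }); // arbitrary example version
    plain["vector"] = serde_json::Value::Array(Vec::new());

    // Multi-vector query: everything in a single body.
    let vectors = vec![vec![0.1f32, 0.2, 0.3], vec![0.4, 0.5, 0.6]];
    let mut multi = serde_json::json!({ "version": 42 });
    multi["vector"] = serde_json::json!(vectors);
    assert_eq!(multi["vector"].as_array().unwrap().len(), 2);
}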

#[derive(Deserialize)]
@@ -325,13 +352,10 @@ mod test_utils {
}

#[async_trait]
impl<S: HttpSend> TableInternal for RemoteTable<S> {
impl<S: HttpSend> BaseTable for RemoteTable<S> {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
    fn as_native(&self) -> Option<&NativeTable> {
        None
    }
    fn name(&self) -> &str {
        &self.name
    }
@@ -398,7 +422,7 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
        let schema = self.describe().await?.schema;
        Ok(Arc::new(schema.try_into()?))
    }
    async fn count_rows(&self, filter: Option<String>) -> Result<usize> {
    async fn count_rows(&self, filter: Option<Filter>) -> Result<usize> {
        let mut request = self
            .client
            .post(&format!("/v1/table/{}/count_rows/", self.name));
@@ -406,6 +430,11 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
        let version = self.current_version().await;

        if let Some(filter) = filter {
            let Filter::Sql(filter) = filter else {
                return Err(Error::NotSupported {
                    message: "querying a remote table with a datafusion filter".to_string(),
                });
            };
            request = request.json(&serde_json::json!({ "predicate": filter, "version": version }));
        } else {
            let body = serde_json::json!({ "version": version });
@@ -453,59 +482,19 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {

    async fn create_plan(
        &self,
        query: &VectorQuery,
        _options: QueryExecutionOptions,
        query: &AnyQuery,
        options: QueryExecutionOptions,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let request = self.client.post(&format!("/v1/table/{}/query/", self.name));

        let version = self.current_version().await;
        let body = serde_json::json!({ "version": version });
        let bodies = Self::apply_vector_query_params(body, query)?;

        let mut futures = Vec::with_capacity(bodies.len());
        for body in bodies {
            let request = request.try_clone().unwrap().json(&body);
            let future = async move {
                let (request_id, response) = self.client.send(request, true).await?;
                self.read_arrow_stream(&request_id, response).await
            };
            futures.push(future);
        }
        let streams = futures::future::try_join_all(futures).await?;
        if streams.len() == 1 {
            let stream = streams.into_iter().next().unwrap();
            Ok(Arc::new(OneShotExec::new(stream)))
        } else {
            let stream_execs = streams
                .into_iter()
                .map(|stream| Arc::new(OneShotExec::new(stream)) as Arc<dyn ExecutionPlan>)
                .collect();
            Table::multi_vector_plan(stream_execs)
        }
        let stream = self.execute_query(query, options).await?;
        Ok(Arc::new(OneShotExec::new(stream)))
    }

    async fn plain_query(
    async fn query(
        &self,
        query: &Query,
        query: &AnyQuery,
        _options: QueryExecutionOptions,
    ) -> Result<DatasetRecordBatchStream> {
        let request = self
            .client
            .post(&format!("/v1/table/{}/query/", self.name))
            .header(CONTENT_TYPE, JSON_CONTENT_TYPE);

        let version = self.current_version().await;
        let mut body = serde_json::json!({ "version": version });
        Self::apply_query_params(&mut body, query)?;
        // Empty vector can be passed if no vector search is performed.
        body["vector"] = serde_json::Value::Array(Vec::new());

        let request = request.json(&body);

        let (request_id, response) = self.client.send(request, true).await?;

        let stream = self.read_arrow_stream(&request_id, response).await?;

        let stream = self.execute_query(query, _options).await?;
        Ok(DatasetRecordBatchStream::new(stream))
    }
    async fn update(&self, update: UpdateBuilder) -> Result<u64> {
@@ -820,11 +809,14 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
        Ok(Some(stats))
    }

    /// Not yet supported on LanceDB Cloud.
    async fn drop_index(&self, _name: &str) -> Result<()> {
        Err(Error::NotSupported {
            message: "Drop index is not yet supported on LanceDB Cloud.".into(),
        })
    async fn drop_index(&self, index_name: &str) -> Result<()> {
        let request = self.client.post(&format!(
            "/v1/table/{}/index/{}/drop/",
            self.name, index_name
        ));
        let (request_id, response) = self.client.send(request, true).await?;
        self.check_table_response(&request_id, response).await?;
        Ok(())
    }

    async fn table_definition(&self) -> Result<TableDefinition> {
@@ -888,6 +880,7 @@ mod tests {
    use reqwest::Body;

    use crate::index::vector::IvfFlatIndexBuilder;
    use crate::remote::JSON_CONTENT_TYPE;
    use crate::{
        index::{vector::IvfPqIndexBuilder, Index, IndexStatistics, IndexType},
        query::{ExecutableQuery, QueryBase},
@@ -1468,9 +1461,21 @@ mod tests {
            request.headers().get("Content-Type").unwrap(),
            JSON_CONTENT_TYPE
        );
        let body: serde_json::Value =
            serde_json::from_slice(request.body().unwrap().as_bytes().unwrap()).unwrap();
        let query_vectors = body["vector"].as_array().unwrap();
        assert_eq!(query_vectors.len(), 2);
        assert_eq!(query_vectors[0].as_array().unwrap().len(), 3);
        assert_eq!(query_vectors[1].as_array().unwrap().len(), 3);
        let data = RecordBatch::try_new(
            Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
            Arc::new(Schema::new(vec![
                Field::new("a", DataType::Int32, false),
                Field::new("query_index", DataType::Int32, false),
            ])),
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5, 6])),
                Arc::new(Int32Array::from(vec![0, 0, 0, 1, 1, 1])),
            ],
        )
        .unwrap();
        let response_body = write_ipc_file(&data);
@@ -1487,8 +1492,6 @@ mod tests {
            .unwrap()
            .add_query_vector(vec![0.4, 0.5, 0.6])
            .unwrap();
        let plan = query.explain_plan(true).await.unwrap();
        assert!(plan.contains("UnionExec"), "Plan: {}", plan);

        let results = query
            .execute()
@@ -2022,4 +2025,17 @@ mod tests {

        table.drop_columns(&["a", "b"]).await.unwrap();
    }

    #[tokio::test]
    async fn test_drop_index() {
        let table = Table::new_with_handler("my_table", |request| {
            assert_eq!(request.method(), "POST");
            assert_eq!(
                request.url().path(),
                "/v1/table/my_table/index/my_index/drop/"
            );
            http::Response::builder().status(200).body("{}").unwrap()
        });
        table.drop_index("my_index").await.unwrap();
    }
}

@@ -12,6 +12,7 @@ use arrow::datatypes::{Float32Type, UInt8Type};
use arrow_array::{RecordBatchIterator, RecordBatchReader};
use arrow_schema::{DataType, Field, Schema, SchemaRef};
use async_trait::async_trait;
use datafusion_expr::Expr;
use datafusion_physical_plan::display::DisplayableExecutionPlan;
use datafusion_physical_plan::projection::ProjectionExec;
use datafusion_physical_plan::repartition::RepartitionExec;
@@ -21,12 +22,13 @@ use futures::{StreamExt, TryStreamExt};
use lance::dataset::builder::DatasetBuilder;
use lance::dataset::cleanup::RemovalStats;
use lance::dataset::optimize::{compact_files, CompactionMetrics, IndexRemapperOptions};
use lance::dataset::scanner::{DatasetRecordBatchStream, Scanner};
use lance::dataset::scanner::Scanner;
pub use lance::dataset::ColumnAlteration;
pub use lance::dataset::NewColumnTransform;
pub use lance::dataset::ReadParams;
pub use lance::dataset::Version;
use lance::dataset::{
    Dataset, InsertBuilder, UpdateBuilder as LanceUpdateBuilder, Version, WhenMatched, WriteMode,
    Dataset, InsertBuilder, UpdateBuilder as LanceUpdateBuilder, WhenMatched, WriteMode,
    WriteParams,
};
use lance::dataset::{MergeInsertBuilder as LanceMergeInsertBuilder, WhenNotMatchedBySource};
@@ -60,7 +62,8 @@ use crate::index::{
};
use crate::index::{IndexConfig, IndexStatisticsImpl};
use crate::query::{
    IntoQueryVector, Query, QueryExecutionOptions, Select, VectorQuery, DEFAULT_TOP_K,
    IntoQueryVector, Query, QueryExecutionOptions, QueryRequest, Select, VectorQuery,
    VectorQueryRequest, DEFAULT_TOP_K,
};
use crate::utils::{
    default_vector_column, supported_bitmap_data_type, supported_btree_data_type,
@@ -71,11 +74,13 @@ use crate::utils::{
use self::dataset::DatasetConsistencyWrapper;
use self::merge::MergeInsertBuilder;

pub mod datafusion;
pub(crate) mod dataset;
pub mod merge;

pub use chrono::Duration;
pub use lance::dataset::optimize::CompactionOptions;
pub use lance::dataset::scanner::DatasetRecordBatchStream;
pub use lance_index::optimize::OptimizeOptions;

/// Defines the type of column
@@ -230,6 +235,24 @@ pub struct OptimizeStats {
    pub prune: Option<RemovalStats>,
}

/// Describes what happens when a vector either contains NaN or
/// does not have enough values
#[derive(Clone, Debug, Default)]
enum BadVectorHandling {
    /// An error is returned
    #[default]
    Error,
    #[allow(dead_code)] // https://github.com/lancedb/lancedb/issues/992
    /// The offending row is dropped
    Drop,
    #[allow(dead_code)] // https://github.com/lancedb/lancedb/issues/992
    /// The invalid/missing items are replaced by fill_value
    Fill(f32),
    #[allow(dead_code)] // https://github.com/lancedb/lancedb/issues/992
    /// The invalid items are replaced by NULL
    None,
}

/// Options to use when writing data
#[derive(Clone, Debug, Default)]
pub struct WriteOptions {
@@ -255,7 +278,7 @@ pub enum AddDataMode {
/// A builder for configuring a [`crate::connection::Connection::create_table`] or [`Table::add`]
/// operation
pub struct AddDataBuilder<T: IntoArrow> {
    parent: Arc<dyn TableInternal>,
    parent: Arc<dyn BaseTable>,
    pub(crate) data: T,
    pub(crate) mode: AddDataMode,
    pub(crate) write_options: WriteOptions,
@@ -300,13 +323,13 @@ impl<T: IntoArrow> AddDataBuilder<T> {
/// A builder for configuring a [`Table::update`] operation
#[derive(Debug, Clone)]
pub struct UpdateBuilder {
    parent: Arc<dyn TableInternal>,
    parent: Arc<dyn BaseTable>,
    pub(crate) filter: Option<String>,
    pub(crate) columns: Vec<(String, String)>,
}

impl UpdateBuilder {
    fn new(parent: Arc<dyn TableInternal>) -> Self {
    fn new(parent: Arc<dyn BaseTable>) -> Self {
        Self {
            parent,
            filter: None,
@@ -363,64 +386,102 @@ impl UpdateBuilder {
    }
}

/// Filters that can be used to limit the rows returned by a query
pub enum Filter {
    /// A SQL filter string
    Sql(String),
    /// A Datafusion logical expression
    Datafusion(Expr),
}

/// A query that can be used to search a LanceDB table
pub enum AnyQuery {
    Query(QueryRequest),
    VectorQuery(VectorQueryRequest),
}
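
Editorial sketch of how the two new enums compose: `col`/`lit` come from `datafusion_expr`, and `VectorQueryRequest::from_plain_query` is the helper this diff uses further down.

use datafusion_expr::{col, lit};

fn demo() {
    // The same predicate, expressed both ways:
    let _sql = Filter::Sql("i >= 5".to_string());
    let _df = Filter::Datafusion(col("i").gt_eq(lit(5)));

    // Plain scans and vector searches share one entry point:
    let _plain = AnyQuery::Query(QueryRequest::default());
    let _vec = AnyQuery::VectorQuery(VectorQueryRequest::from_plain_query(QueryRequest::default()));
}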

/// A trait for anything "table-like". This is used for both native tables (which target
/// Lance datasets) and remote tables (which target LanceDB cloud)
///
/// This trait is still EXPERIMENTAL and subject to change in the future
#[async_trait]
pub(crate) trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Sync {
#[allow(dead_code)]
pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync {
    /// Get a reference to std::any::Any
    fn as_any(&self) -> &dyn std::any::Any;
    /// Cast as [`NativeTable`], or return None if it is not a [`NativeTable`].
    fn as_native(&self) -> Option<&NativeTable>;
    /// Get the name of the table.
    fn name(&self) -> &str;
    /// Get the arrow [Schema] of the table.
    async fn schema(&self) -> Result<SchemaRef>;
    /// Count the number of rows in this table.
    async fn count_rows(&self, filter: Option<String>) -> Result<usize>;
    async fn count_rows(&self, filter: Option<Filter>) -> Result<usize>;
    /// Create a physical plan for the query.
    async fn create_plan(
        &self,
        query: &VectorQuery,
        query: &AnyQuery,
        options: QueryExecutionOptions,
    ) -> Result<Arc<dyn ExecutionPlan>>;
    async fn plain_query(
    /// Execute a query and return the results as a stream of RecordBatches.
    async fn query(
        &self,
        query: &Query,
        query: &AnyQuery,
        options: QueryExecutionOptions,
    ) -> Result<DatasetRecordBatchStream>;
    async fn explain_plan(&self, query: &VectorQuery, verbose: bool) -> Result<String> {
    /// Explain the plan for a query.
    async fn explain_plan(&self, query: &AnyQuery, verbose: bool) -> Result<String> {
        let plan = self.create_plan(query, Default::default()).await?;
        let display = DisplayableExecutionPlan::new(plan.as_ref());

        Ok(format!("{}", display.indent(verbose)))
    }
    /// Add new records to the table.
    async fn add(
        &self,
        add: AddDataBuilder<NoData>,
        data: Box<dyn arrow_array::RecordBatchReader + Send>,
    ) -> Result<()>;
    /// Delete rows from the table.
    async fn delete(&self, predicate: &str) -> Result<()>;
    /// Update rows in the table.
    async fn update(&self, update: UpdateBuilder) -> Result<u64>;
    /// Create an index on the provided column(s).
    async fn create_index(&self, index: IndexBuilder) -> Result<()>;
    /// List the indices on the table.
    async fn list_indices(&self) -> Result<Vec<IndexConfig>>;
    /// Drop an index from the table.
    async fn drop_index(&self, name: &str) -> Result<()>;
    /// Get statistics about the index.
    async fn index_stats(&self, index_name: &str) -> Result<Option<IndexStatistics>>;
    /// Merge insert new records into the table.
    async fn merge_insert(
        &self,
        params: MergeInsertBuilder,
        new_data: Box<dyn RecordBatchReader + Send>,
    ) -> Result<()>;
    /// Optimize the dataset.
    async fn optimize(&self, action: OptimizeAction) -> Result<OptimizeStats>;
    /// Add columns to the table.
    async fn add_columns(
        &self,
        transforms: NewColumnTransform,
        read_columns: Option<Vec<String>>,
    ) -> Result<()>;
    /// Alter columns in the table.
    async fn alter_columns(&self, alterations: &[ColumnAlteration]) -> Result<()>;
    /// Drop columns from the table.
    async fn drop_columns(&self, columns: &[&str]) -> Result<()>;
    /// Get the version of the table.
    async fn version(&self) -> Result<u64>;
    /// Checkout a specific version of the table.
    async fn checkout(&self, version: u64) -> Result<()>;
    /// Checkout the latest version of the table.
    async fn checkout_latest(&self) -> Result<()>;
    /// Restore the table to the currently checked out version.
    async fn restore(&self) -> Result<()>;
    /// List the versions of the table.
    async fn list_versions(&self) -> Result<Vec<Version>>;
    /// Get the table definition.
    async fn table_definition(&self) -> Result<TableDefinition>;
    /// Get the table URI
    fn dataset_uri(&self) -> &str;
}
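
Editorial sketch: with `BaseTable` now public, downstream code can target the trait object and work unchanged for native and remote tables; `count_rows` and `Filter` are as declared above.

use std::sync::Arc;

async fn filtered_count(table: &Arc<dyn BaseTable>) -> Result<usize> {
    table
        .count_rows(Some(Filter::Sql("i >= 5".to_string())))
        .await
}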

@@ -429,7 +490,7 @@ pub(crate) trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Syn
/// The type of each row is defined in Apache Arrow [Schema].
#[derive(Clone)]
pub struct Table {
    inner: Arc<dyn TableInternal>,
    inner: Arc<dyn BaseTable>,
    embedding_registry: Arc<dyn EmbeddingRegistry>,
}

@@ -465,15 +526,19 @@ impl std::fmt::Display for Table {
}

impl Table {
    pub(crate) fn new(inner: Arc<dyn TableInternal>) -> Self {
    pub fn new(inner: Arc<dyn BaseTable>) -> Self {
        Self {
            inner,
            embedding_registry: Arc::new(MemoryRegistry::new()),
        }
    }

    pub fn base_table(&self) -> &Arc<dyn BaseTable> {
        &self.inner
    }

    pub(crate) fn new_with_embedding_registry(
        inner: Arc<dyn TableInternal>,
        inner: Arc<dyn BaseTable>,
        embedding_registry: Arc<dyn EmbeddingRegistry>,
    ) -> Self {
        Self {
@@ -506,7 +571,7 @@ impl Table {
    ///
    /// * `filter` if present, only count rows matching the filter
    pub async fn count_rows(&self, filter: Option<String>) -> Result<usize> {
        self.inner.count_rows(filter).await
        self.inner.count_rows(filter.map(Filter::Sql)).await
    }

    /// Insert new records into this Table
@@ -1045,6 +1110,17 @@ impl From<NativeTable> for Table {
    }
}

pub trait NativeTableExt {
    /// Cast as [`NativeTable`], or return None if it is not a [`NativeTable`].
    fn as_native(&self) -> Option<&NativeTable>;
}

impl NativeTableExt for Arc<dyn BaseTable> {
    fn as_native(&self) -> Option<&NativeTable> {
        self.as_any().downcast_ref::<NativeTable>()
    }
}
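
Editorial sketch: the extension trait replaces the old `as_native` trait method with a downcast, so recovering the concrete `NativeTable` now looks like this (import paths assumed):

use lancedb::table::{BaseTable, NativeTableExt};

fn describe(table: &lancedb::Table) {
    match table.base_table().as_native() {
        Some(native) => println!("native table at {}", native.dataset_uri()),
        None => println!("remote table named {}", table.base_table().name()),
    }
}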

/// A table in a LanceDB database.
#[derive(Debug, Clone)]
pub struct NativeTable {
@@ -1164,7 +1240,7 @@ impl NativeTable {
    /// # Returns
    ///
    /// * A [TableImpl] object.
    pub(crate) async fn create(
    pub async fn create(
        uri: &str,
        name: &str,
        batches: impl RecordBatchReader + Send + 'static,
@@ -1304,10 +1380,11 @@ impl NativeTable {

    pub async fn load_indices(&self) -> Result<Vec<VectorIndex>> {
        let dataset = self.dataset.get().await?;
        let (indices, mf) = futures::try_join!(dataset.load_indices(), dataset.latest_manifest())?;
        let mf = dataset.manifest();
        let indices = dataset.load_indices().await?;
        Ok(indices
            .iter()
            .map(|i| VectorIndex::new_from_format(&(mf.0), i))
            .map(|i| VectorIndex::new_from_format(mf, i))
            .collect())
    }

@@ -1658,7 +1735,7 @@ impl NativeTable {

    async fn generic_query(
        &self,
        query: &VectorQuery,
        query: &AnyQuery,
        options: QueryExecutionOptions,
    ) -> Result<DatasetRecordBatchStream> {
        let plan = self.create_plan(query, options).await?;
@@ -1748,15 +1825,11 @@ impl NativeTable {
}

#[async_trait::async_trait]
impl TableInternal for NativeTable {
impl BaseTable for NativeTable {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_native(&self) -> Option<&NativeTable> {
        Some(self)
    }

    fn name(&self) -> &str {
        self.name.as_str()
    }
@@ -1812,8 +1885,15 @@ impl TableInternal for NativeTable {
        TableDefinition::try_from_rich_schema(schema)
    }

    async fn count_rows(&self, filter: Option<String>) -> Result<usize> {
        Ok(self.dataset.get().await?.count_rows(filter).await?)
    async fn count_rows(&self, filter: Option<Filter>) -> Result<usize> {
        let dataset = self.dataset.get().await?;
        match filter {
            None => Ok(dataset.count_rows(None).await?),
            Some(Filter::Sql(sql)) => Ok(dataset.count_rows(Some(sql)).await?),
            Some(Filter::Datafusion(_)) => Err(Error::NotSupported {
                message: "Datafusion filters are not yet supported".to_string(),
            }),
        }
    }

    async fn add(
@@ -1907,12 +1987,17 @@ impl TableInternal for NativeTable {

    async fn create_plan(
        &self,
        query: &VectorQuery,
        query: &AnyQuery,
        options: QueryExecutionOptions,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let query = match query {
            AnyQuery::VectorQuery(query) => query.clone(),
            AnyQuery::Query(query) => VectorQueryRequest::from_plain_query(query.clone()),
        };

        let ds_ref = self.dataset.get().await?;
        let mut column = query.column.clone();
        let schema = ds_ref.schema();
        let mut column = query.column.clone();

        let mut query_vector = query.query_vector.first().cloned();
        if query.query_vector.len() > 1 {
@@ -1957,7 +2042,10 @@ impl TableInternal for NativeTable {
                let mut sub_query = query.clone();
                sub_query.query_vector = vec![query_vector];
                let options_ref = options.clone();
                async move { self.create_plan(&sub_query, options_ref).await }
                async move {
                    self.create_plan(&AnyQuery::VectorQuery(sub_query), options_ref)
                        .await
                }
            })
            .collect::<Vec<_>>();
        let plans = futures::future::try_join_all(plan_futures).await?;
@@ -2055,13 +2143,12 @@ impl TableInternal for NativeTable {
        Ok(scanner.create_plan().await?)
    }

    async fn plain_query(
    async fn query(
        &self,
        query: &Query,
        query: &AnyQuery,
        options: QueryExecutionOptions,
    ) -> Result<DatasetRecordBatchStream> {
        self.generic_query(&query.clone().into_vector(), options)
            .await
        self.generic_query(query, options).await
    }

    async fn merge_insert(
@@ -2330,7 +2417,10 @@ mod tests {

        assert_eq!(table.count_rows(None).await.unwrap(), 10);
        assert_eq!(
            table.count_rows(Some("i >= 5".to_string())).await.unwrap(),
            table
                .count_rows(Some(Filter::Sql("i >= 5".to_string())))
                .await
                .unwrap(),
            5
        );
    }

263
rust/lancedb/src/table/datafusion.rs
Normal file
@@ -0,0 +1,263 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors

//! This module contains adapters to allow LanceDB tables to be used as DataFusion table providers.
use std::{collections::HashMap, sync::Arc};

use arrow_schema::Schema as ArrowSchema;
use async_trait::async_trait;
use datafusion_catalog::{Session, TableProvider};
use datafusion_common::{DataFusionError, Result as DataFusionResult, Statistics};
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
use datafusion_expr::{Expr, TableProviderFilterPushDown, TableType};
use datafusion_physical_plan::{
    stream::RecordBatchStreamAdapter, DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties,
};
use futures::{TryFutureExt, TryStreamExt};

use super::{AnyQuery, BaseTable};
use crate::{
    query::{QueryExecutionOptions, QueryRequest, Select},
    Result,
};

/// Datafusion attempts to maintain batch metadata
///
/// This is needless and it triggers bugs in DF. This operator erases metadata from the batches.
#[derive(Debug)]
struct MetadataEraserExec {
    input: Arc<dyn ExecutionPlan>,
    schema: Arc<ArrowSchema>,
    properties: PlanProperties,
}

impl MetadataEraserExec {
    fn compute_properties_from_input(
        input: &Arc<dyn ExecutionPlan>,
        schema: &Arc<ArrowSchema>,
    ) -> PlanProperties {
        let input_properties = input.properties();
        let eq_properties = input_properties
            .eq_properties
            .clone()
            .with_new_schema(schema.clone())
            .unwrap();
        input_properties.clone().with_eq_properties(eq_properties)
    }

    fn new(input: Arc<dyn ExecutionPlan>) -> Self {
        let schema = Arc::new(
            input
                .schema()
                .as_ref()
                .clone()
                .with_metadata(HashMap::new()),
        );
        Self {
            properties: Self::compute_properties_from_input(&input, &schema),
            input,
            schema,
        }
    }
}

impl DisplayAs for MetadataEraserExec {
    fn fmt_as(&self, _: DisplayFormatType, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "MetadataEraserExec")
    }
}

impl ExecutionPlan for MetadataEraserExec {
    fn name(&self) -> &str {
        "MetadataEraserExec"
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
        assert_eq!(children.len(), 1);
        let new_properties = Self::compute_properties_from_input(&children[0], &self.schema);
        Ok(Arc::new(Self {
            input: children[0].clone(),
            schema: self.schema.clone(),
            properties: new_properties,
        }))
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> DataFusionResult<SendableRecordBatchStream> {
        let stream = self.input.execute(partition, context)?;
        let schema = self.schema.clone();
        let stream = stream.map_ok(move |batch| batch.with_schema(schema.clone()).unwrap());
        Ok(
            Box::pin(RecordBatchStreamAdapter::new(self.schema.clone(), stream))
                as SendableRecordBatchStream,
        )
    }
}
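
Editorial aside: `execute` relies on `RecordBatch::with_schema`, which succeeds when the replacement schema differs from the batch's own in metadata alone; that invariant is what justifies the blanket `unwrap` above. A standalone sketch:

use std::{collections::HashMap, sync::Arc};
use arrow_array::{Int32Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema};

fn demo() {
    let md = HashMap::from([("foo".to_string(), "bar".to_string())]);
    let schema =
        Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)]).with_metadata(md));
    let batch =
        RecordBatch::try_new(schema.clone(), vec![Arc::new(Int32Array::from(vec![1, 2]))]).unwrap();

    // Same fields, metadata stripped: with_schema accepts this.
    let clean = Arc::new(schema.as_ref().clone().with_metadata(HashMap::new()));
    let erased = batch.with_schema(clean).unwrap();
    assert!(erased.schema().metadata().is_empty());
}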

#[derive(Debug)]
pub struct BaseTableAdapter {
    table: Arc<dyn BaseTable>,
    schema: Arc<ArrowSchema>,
}

impl BaseTableAdapter {
    pub async fn try_new(table: Arc<dyn BaseTable>) -> Result<Self> {
        let schema = Arc::new(
            table
                .schema()
                .await?
                .as_ref()
                .clone()
                .with_metadata(HashMap::default()),
        );
        Ok(Self { table, schema })
    }
}

#[async_trait]
impl TableProvider for BaseTableAdapter {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn schema(&self) -> Arc<ArrowSchema> {
        self.schema.clone()
    }

    fn table_type(&self) -> TableType {
        TableType::Base
    }

    async fn scan(
        &self,
        _state: &dyn Session,
        projection: Option<&Vec<usize>>,
        filters: &[Expr],
        limit: Option<usize>,
    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
        let mut query = QueryRequest::default();
        if let Some(projection) = projection {
            let field_names = projection
                .iter()
                .map(|i| self.schema.field(*i).name().to_string())
                .collect();
            query.select = Select::Columns(field_names);
        }
        assert!(filters.is_empty());
        if let Some(limit) = limit {
            query.limit = Some(limit);
        } else {
            // Need to override the default of 10
            query.limit = None;
        }
        let plan = self
            .table
            .create_plan(&AnyQuery::Query(query), QueryExecutionOptions::default())
            .map_err(|err| DataFusionError::External(err.into()))
            .await?;
        Ok(Arc::new(MetadataEraserExec::new(plan)))
    }

    fn supports_filters_pushdown(
        &self,
        filters: &[&Expr],
    ) -> DataFusionResult<Vec<TableProviderFilterPushDown>> {
        // TODO: Pushdown unsupported until we can support datafusion filters in BaseTable::create_plan
        Ok(vec![
            TableProviderFilterPushDown::Unsupported;
            filters.len()
        ])
    }

    fn statistics(&self) -> Option<Statistics> {
        // TODO
        None
    }
}
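
Editorial sketch of the adapter from the outside: register it on a DataFusion `SessionContext` and run SQL against the LanceDB table. The `connect`/`open_table` calls mirror the test module below; error handling is abbreviated.

use std::sync::Arc;
use datafusion::prelude::SessionContext;
use lancedb::{connect, table::datafusion::BaseTableAdapter};

async fn sql_over_lancedb(uri: &str) -> lancedb::Result<()> {
    let db = connect(uri).execute().await?;
    let tbl = db.open_table("foo").execute().await?;
    let provider = Arc::new(BaseTableAdapter::try_new(tbl.base_table().clone()).await?);

    let ctx = SessionContext::new();
    ctx.register_table("foo", provider).unwrap();
    let df = ctx.sql("SELECT i FROM foo LIMIT 5").await.unwrap();
    df.show().await.unwrap();
    Ok(())
}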

#[cfg(test)]
pub mod tests {
    use std::{collections::HashMap, sync::Arc};

    use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator, RecordBatchReader};
    use arrow_schema::{DataType, Field, Schema};
    use datafusion::{datasource::provider_as_source, prelude::SessionContext};
    use datafusion_catalog::TableProvider;
    use datafusion_expr::LogicalPlanBuilder;
    use futures::TryStreamExt;
    use tempfile::tempdir;

    use crate::{connect, table::datafusion::BaseTableAdapter};

    fn make_test_batches() -> impl RecordBatchReader + Send + Sync + 'static {
        let metadata = HashMap::from_iter(vec![("foo".to_string(), "bar".to_string())]);
        let schema = Arc::new(
            Schema::new(vec![Field::new("i", DataType::Int32, false)]).with_metadata(metadata),
        );
        RecordBatchIterator::new(
            vec![RecordBatch::try_new(
                schema.clone(),
                vec![Arc::new(Int32Array::from_iter_values(0..10))],
            )],
            schema,
        )
    }

    #[tokio::test]
    async fn test_metadata_erased() {
        let tmp_dir = tempdir().unwrap();
        let dataset_path = tmp_dir.path().join("test.lance");
        let uri = dataset_path.to_str().unwrap();

        let db = connect(uri).execute().await.unwrap();

        let tbl = db
            .create_table("foo", make_test_batches())
            .execute()
            .await
            .unwrap();

        let provider = Arc::new(
            BaseTableAdapter::try_new(tbl.base_table().clone())
                .await
                .unwrap(),
        );

        assert!(provider.schema().metadata().is_empty());

        let plan = LogicalPlanBuilder::scan("foo", provider_as_source(provider), None)
            .unwrap()
            .build()
            .unwrap();

        let mut stream = SessionContext::new()
            .execute_logical_plan(plan)
            .await
            .unwrap()
            .execute_stream()
            .await
            .unwrap();

        while let Some(batch) = stream.try_next().await.unwrap() {
            assert!(batch.schema().metadata().is_empty());
        }
    }
}

@@ -7,14 +7,14 @@ use arrow_array::RecordBatchReader;

use crate::Result;

use super::TableInternal;
use super::BaseTable;

/// A builder used to create and run a merge insert operation
///
/// See [`super::Table::merge_insert`] for more context
#[derive(Debug, Clone)]
pub struct MergeInsertBuilder {
    table: Arc<dyn TableInternal>,
    table: Arc<dyn BaseTable>,
    pub(crate) on: Vec<String>,
    pub(crate) when_matched_update_all: bool,
    pub(crate) when_matched_update_all_filt: Option<String>,
@@ -24,7 +24,7 @@ pub struct MergeInsertBuilder {
}

impl MergeInsertBuilder {
    pub(super) fn new(table: Arc<dyn TableInternal>, on: Vec<String>) -> Self {
    pub(super) fn new(table: Arc<dyn BaseTable>, on: Vec<String>) -> Self {
        Self {
            table,
            on,