Mirror of https://github.com/lancedb/lancedb.git (synced 2025-12-23 13:29:57 +00:00)

Compare commits: python-v0. ... python-v0. (45 commits)
| SHA1 |
|---|
| b37c58342e |
| a06e64f22d |
| e983198f0e |
| 76e7b4abf8 |
| 5f6eb4651e |
| 805c78bb20 |
| 4746281b21 |
| 7b3b6bdccd |
| 37e1124c0f |
| 93f037ee41 |
| e4fc06825a |
| fe89a373a2 |
| 3d3915edef |
| e2e8b6aee4 |
| 12dbca5248 |
| a6babfa651 |
| 75ede86fab |
| becd649130 |
| 9d2fb7d602 |
| fdb5d6fdf1 |
| 2f13fa225f |
| e933de003d |
| 05fd387425 |
| 82a1da554c |
| a7c0d80b9e |
| 71323a064a |
| df48454b70 |
| 6603414885 |
| c256f6c502 |
| cc03f90379 |
| 975da09b02 |
| c32e17b497 |
| 0528abdf97 |
| 1090c311e8 |
| e767cbb374 |
| 3d7c48feca |
| 08d62550bb |
| b272408b05 |
| 46ffa87cd4 |
| cd9fc37b95 |
| 431f94e564 |
| c1a7d65473 |
| 1e5ccb1614 |
| 2e7ab373dc |
| c7fbc4aaee |
@@ -1,5 +1,5 @@
 [bumpversion]
-current_version = 0.4.17
+current_version = 0.4.20
 commit = True
 message = Bump version: {current_version} → {new_version}
 tag = True
2 .gitignore (vendored)
@@ -6,7 +6,7 @@
 venv
 
 .vscode
-
+.zed
 rust/target
 rust/Cargo.lock
 
26 Cargo.toml
@@ -14,22 +14,22 @@ keywords = ["lancedb", "lance", "database", "vector", "search"]
 categories = ["database-implementations"]
 
 [workspace.dependencies]
-lance = { "version" = "=0.10.12", "features" = ["dynamodb"] }
-lance-index = { "version" = "=0.10.12" }
-lance-linalg = { "version" = "=0.10.12" }
-lance-testing = { "version" = "=0.10.12" }
+lance = { "version" = "=0.10.18", "features" = ["dynamodb"] }
+lance-index = { "version" = "=0.10.18" }
+lance-linalg = { "version" = "=0.10.18" }
+lance-testing = { "version" = "=0.10.18" }
 # Note that this one does not include pyarrow
-arrow = { version = "50.0", optional = false }
-arrow-array = "50.0"
-arrow-data = "50.0"
-arrow-ipc = "50.0"
-arrow-ord = "50.0"
-arrow-schema = "50.0"
-arrow-arith = "50.0"
-arrow-cast = "50.0"
+arrow = { version = "51.0", optional = false }
+arrow-array = "51.0"
+arrow-data = "51.0"
+arrow-ipc = "51.0"
+arrow-ord = "51.0"
+arrow-schema = "51.0"
+arrow-arith = "51.0"
+arrow-cast = "51.0"
 async-trait = "0"
 chrono = "0.4.35"
-half = { "version" = "=2.3.1", default-features = false, features = [
+half = { "version" = "=2.4.1", default-features = false, features = [
     "num-traits",
 ] }
 futures = "0"
@@ -20,7 +20,7 @@
 
 <hr />
 
-LanceDB is an open-source database for vector-search built with persistent storage, which greatly simplifies retrevial, filtering and management of embeddings.
+LanceDB is an open-source database for vector-search built with persistent storage, which greatly simplifies retrieval, filtering and management of embeddings.
 
 The key features of LanceDB include:
 
@@ -36,7 +36,7 @@ The key features of LanceDB include:
 
 * GPU support in building vector index(*).
 
-* Ecosystem integrations with [LangChain 🦜️🔗](https://python.langchain.com/en/latest/modules/indexes/vectorstores/examples/lanecdb.html), [LlamaIndex 🦙](https://gpt-index.readthedocs.io/en/latest/examples/vector_stores/LanceDBIndexDemo.html), Apache-Arrow, Pandas, Polars, DuckDB and more on the way.
+* Ecosystem integrations with [LangChain 🦜️🔗](https://python.langchain.com/docs/integrations/vectorstores/lancedb/), [LlamaIndex 🦙](https://gpt-index.readthedocs.io/en/latest/examples/vector_stores/LanceDBIndexDemo.html), Apache-Arrow, Pandas, Polars, DuckDB and more on the way.
 
 LanceDB's core is written in Rust 🦀 and is built using <a href="https://github.com/lancedb/lance">Lance</a>, an open-source columnar format designed for performant ML workloads.
 
@@ -159,7 +159,7 @@ Allows you to set parameters when registering a `sentence-transformers` object.
 from lancedb.embeddings import get_registry
 
 db = lancedb.connect("/tmp/db")
-model = get_registry.get("sentence-transformers").create(name="BAAI/bge-small-en-v1.5", device="cpu")
+model = get_registry().get("sentence-transformers").create(name="BAAI/bge-small-en-v1.5", device="cpu")
 
 class Words(LanceModel):
     text: str = model.SourceField()
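For readers following along, here is a minimal, self-contained sketch of the corrected call chain from the hunk above: `get_registry` is a function and must be called before `.get()`. The `Vector` field is added here for completeness, assuming the standard `lancedb.pydantic` helpers.

```python
import lancedb
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector

db = lancedb.connect("/tmp/db")
# get_registry() returns the registry instance; calling .get on the bare
# function (the old docs) would raise AttributeError.
model = get_registry().get("sentence-transformers").create(
    name="BAAI/bge-small-en-v1.5", device="cpu"
)

class Words(LanceModel):
    text: str = model.SourceField()                       # text embedded on ingest
    vector: Vector(model.ndims()) = model.VectorField()   # auto-populated

table = db.create_table("words", schema=Words)
```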
@@ -46,7 +46,7 @@ For this purpose, LanceDB introduces an **embedding functions API**, that allow
 
 ```python
 class Pets(LanceModel):
-    vector: Vector(clip.ndims) = clip.VectorField()
+    vector: Vector(clip.ndims()) = clip.VectorField()
     image_uri: str = clip.SourceField()
 ```
 
@@ -149,7 +149,7 @@ You can also use the integration for adding utility operations in the schema. Fo
 
 ```python
 class Pets(LanceModel):
-    vector: Vector(clip.ndims) = clip.VectorField()
+    vector: Vector(clip.ndims()) = clip.VectorField()
     image_uri: str = clip.SourceField()
 
     @property
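A brief note on why both hunks above add parentheses: `ndims` is a method on the embedding function, so `Vector(clip.ndims)` passed a bound method where an integer dimension is expected. A hedged sketch (treat the `open-clip` registry name as an assumption):

```python
from lancedb.embeddings import get_registry

clip = get_registry().get("open-clip").create()  # assumed registry name

print(clip.ndims())  # an int such as 512 -- valid inside Vector(...)
print(clip.ndims)    # a bound method -- the bug the old snippets had
```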
@@ -299,6 +299,14 @@ LanceDB can also connect to S3-compatible stores, such as MinIO. To do so, you m
 
 This can also be done with the ``AWS_ENDPOINT`` and ``AWS_DEFAULT_REGION`` environment variables.
 
+!!! tip "Local servers"
+
+    For local development, the server often has a `http` endpoint rather than a
+    secure `https` endpoint. In this case, you must also set the `ALLOW_HTTP`
+    environment variable to `true` to allow non-TLS connections, or pass the
+    storage option `allow_http` as `true`. If you do not do this, you will get
+    an error like `URL scheme is not allowed`.
+
 #### S3 Express
 
 LanceDB supports [S3 Express One Zone](https://aws.amazon.com/s3/storage-classes/express-one-zone/) endpoints, but requires additional configuration. Also, S3 Express endpoints only support connecting from an EC2 instance within the same region.
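To make the tip concrete, a minimal sketch of connecting to a local MinIO server over plain HTTP. The endpoint, bucket, and `minioadmin` credentials are hypothetical defaults; adjust them to your deployment.

```python
import os
import lancedb

# Point LanceDB's object store at the local MinIO endpoint.
os.environ["AWS_ENDPOINT"] = "http://localhost:9000"
os.environ["AWS_DEFAULT_REGION"] = "us-east-1"
os.environ["AWS_ACCESS_KEY_ID"] = "minioadmin"
os.environ["AWS_SECRET_ACCESS_KEY"] = "minioadmin"
# Required because the endpoint is http:// rather than https://;
# without it the connection fails with "URL scheme is not allowed".
os.environ["ALLOW_HTTP"] = "true"

db = lancedb.connect("s3://my-bucket/my-db")
```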
@@ -36,7 +36,7 @@
    }
   ],
   "source": [
-   "!pip install --quiet openai datasets \n",
+   "!pip install --quiet openai datasets\n",
    "!pip install --quiet -U lancedb"
   ]
  },
@@ -213,7 +213,7 @@
    "if \"OPENAI_API_KEY\" not in os.environ:\n",
    "    # OR set the key here as a variable\n",
    "    os.environ[\"OPENAI_API_KEY\"] = \"sk-...\"\n",
-   "    \n",
+   "\n",
    "client = OpenAI()\n",
    "assert len(client.models.list().data) > 0"
   ]
@@ -234,9 +234,12 @@
   "metadata": {},
   "outputs": [],
   "source": [
-   "def embed_func(c): \n",
+   "def embed_func(c):\n",
    "    rs = client.embeddings.create(input=c, model=\"text-embedding-ada-002\")\n",
-   "    return [rs.data[0].embedding]"
+   "    return [\n",
+   "        data.embedding\n",
+   "        for data in rs.data\n",
+   "    ]"
   ]
  },
 {
@@ -514,7 +517,7 @@
    "        prompt_start +\n",
    "        \"\\n\\n---\\n\\n\".join(context.text) +\n",
    "        prompt_end\n",
-   "    ) \n",
+   "    )\n",
    "    return prompt"
   ]
  },
74 node/package-lock.json (generated)
@@ -1,12 +1,12 @@
 {
   "name": "vectordb",
-  "version": "0.4.17",
+  "version": "0.4.20",
   "lockfileVersion": 3,
   "requires": true,
   "packages": {
     "": {
       "name": "vectordb",
-      "version": "0.4.17",
+      "version": "0.4.20",
       "cpu": [
         "x64",
         "arm64"
@@ -52,11 +52,11 @@
         "uuid": "^9.0.0"
       },
       "optionalDependencies": {
-        "@lancedb/vectordb-darwin-arm64": "0.4.17",
-        "@lancedb/vectordb-darwin-x64": "0.4.17",
-        "@lancedb/vectordb-linux-arm64-gnu": "0.4.17",
-        "@lancedb/vectordb-linux-x64-gnu": "0.4.17",
-        "@lancedb/vectordb-win32-x64-msvc": "0.4.17"
+        "@lancedb/vectordb-darwin-arm64": "0.4.20",
+        "@lancedb/vectordb-darwin-x64": "0.4.20",
+        "@lancedb/vectordb-linux-arm64-gnu": "0.4.20",
+        "@lancedb/vectordb-linux-x64-gnu": "0.4.20",
+        "@lancedb/vectordb-win32-x64-msvc": "0.4.20"
       },
       "peerDependencies": {
         "@apache-arrow/ts": "^14.0.2",
@@ -333,6 +333,66 @@
         "@jridgewell/sourcemap-codec": "^1.4.10"
       }
     },
+    "node_modules/@lancedb/vectordb-darwin-arm64": {
+      "version": "0.4.20",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.4.20.tgz",
+      "integrity": "sha512-ffP2K4sA5mQTgePyARw1y8dPN996FmpvyAYoWO+TSItaXlhcXvc+KVa5udNMCZMDYeEnEv2Xpj6k4PwW3oBz+A==",
+      "cpu": [
+        "arm64"
+      ],
+      "optional": true,
+      "os": [
+        "darwin"
+      ]
+    },
+    "node_modules/@lancedb/vectordb-darwin-x64": {
+      "version": "0.4.20",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.4.20.tgz",
+      "integrity": "sha512-GSYsXE20RIehDu30FjREhJdEzhnwOTV7ZsrSXagStzLY1gr7pyd7sfqxmmUtdD09di7LnQoiM71AOpPTa01YwQ==",
+      "cpu": [
+        "x64"
+      ],
+      "optional": true,
+      "os": [
+        "darwin"
+      ]
+    },
+    "node_modules/@lancedb/vectordb-linux-arm64-gnu": {
+      "version": "0.4.20",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.4.20.tgz",
+      "integrity": "sha512-FpNOjOsz3nJVm6EBGyNgbOW2aFhsWZ/igeY45Z8hbZaaK2YBwrg/DASoNlUzgv6IR8cUaGJ2irNVJfsKR2cG6g==",
+      "cpu": [
+        "arm64"
+      ],
+      "optional": true,
+      "os": [
+        "linux"
+      ]
+    },
+    "node_modules/@lancedb/vectordb-linux-x64-gnu": {
+      "version": "0.4.20",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.4.20.tgz",
+      "integrity": "sha512-pOqWjrRZQSrLTlQPkjidRii7NZDw8Xu9pN6ouVu2JAK8n81FXaPtFCyAI+Y3v9GpnYDN0rvD4eQ36aHAVPsa2g==",
+      "cpu": [
+        "x64"
+      ],
+      "optional": true,
+      "os": [
+        "linux"
+      ]
+    },
+    "node_modules/@lancedb/vectordb-win32-x64-msvc": {
+      "version": "0.4.20",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.4.20.tgz",
+      "integrity": "sha512-5J5SsYSJ7jRCmU/sgwVHdrGz43B/7R2T9OEoFTKyVAtqTZdu75rkytXyn9SyEayXVhlUOaw76N0ASm0hAoDS/A==",
+      "cpu": [
+        "x64"
+      ],
+      "optional": true,
+      "os": [
+        "win32"
+      ]
+    },
     "node_modules/@neon-rs/cli": {
       "version": "0.0.160",
       "resolved": "https://registry.npmjs.org/@neon-rs/cli/-/cli-0.0.160.tgz",
@@ -1,6 +1,6 @@
 {
   "name": "vectordb",
-  "version": "0.4.17",
+  "version": "0.4.20",
   "description": " Serverless, low-latency vector database for AI applications",
   "main": "dist/index.js",
   "types": "dist/index.d.ts",
@@ -88,10 +88,10 @@
     }
   },
   "optionalDependencies": {
-    "@lancedb/vectordb-darwin-arm64": "0.4.17",
-    "@lancedb/vectordb-darwin-x64": "0.4.17",
-    "@lancedb/vectordb-linux-arm64-gnu": "0.4.17",
-    "@lancedb/vectordb-linux-x64-gnu": "0.4.17",
-    "@lancedb/vectordb-win32-x64-msvc": "0.4.17"
+    "@lancedb/vectordb-darwin-arm64": "0.4.20",
+    "@lancedb/vectordb-darwin-x64": "0.4.20",
+    "@lancedb/vectordb-linux-arm64-gnu": "0.4.20",
+    "@lancedb/vectordb-linux-x64-gnu": "0.4.20",
+    "@lancedb/vectordb-win32-x64-msvc": "0.4.20"
   }
 }
@@ -163,7 +163,7 @@ export interface CreateTableOptions<T> {
 /**
  * Connect to a LanceDB instance at the given URI.
  *
- * Accpeted formats:
+ * Accepted formats:
  *
  * - `/path/to/database` - local database
  * - `s3://bucket/path/to/database` or `gs://bucket/path/to/database` - database on cloud storage
@@ -51,7 +51,7 @@ describe('LanceDB Mirrored Store Integration test', function () {
 
   const dir = tmpdir()
   console.log(dir)
-  const conn = await lancedb.connect(`s3://lancedb-integtest?mirroredStore=${dir}`)
+  const conn = await lancedb.connect({ uri: `s3://lancedb-integtest?mirroredStore=${dir}`, storageOptions: { allowHttp: 'true' } })
   const data = Array(200).fill({ vector: Array(128).fill(1.0), id: 0 })
   data.push(...Array(200).fill({ vector: Array(128).fill(1.0), id: 1 }))
   data.push(...Array(200).fill({ vector: Array(128).fill(1.0), id: 2 }))
@@ -140,6 +140,9 @@ export class RemoteConnection implements Connection {
       schema = nameOrOpts.schema
       embeddings = nameOrOpts.embeddingFunction
       tableName = nameOrOpts.name
+      if (data === undefined) {
+        data = nameOrOpts.data
+      }
     }
 
     let buffer: Buffer
1 nodejs/.gitignore (vendored, new file)
@@ -0,0 +1 @@
+yarn.lock
@@ -20,7 +20,7 @@ import { Table as ArrowTable, Schema } from "apache-arrow";
 /**
  * Connect to a LanceDB instance at the given URI.
  *
- * Accpeted formats:
+ * Accepted formats:
  *
  * - `/path/to/database` - local database
  * - `s3://bucket/path/to/database` or `gs://bucket/path/to/database` - database on cloud storage
@@ -77,6 +77,18 @@ export interface OpenTableOptions {
    * The available options are described at https://lancedb.github.io/lancedb/guides/storage/
    */
   storageOptions?: Record<string, string>;
+  /**
+   * Set the size of the index cache, specified as a number of entries
+   *
+   * The exact meaning of an "entry" will depend on the type of index:
+   * - IVF: there is one entry for each IVF partition
+   * - BTREE: there is one entry for the entire index
+   *
+   * This cache applies to the entire opened table, across all indices.
+   * Setting this value higher will increase performance on larger datasets
+   * at the expense of more RAM
+   */
+  indexCacheSize?: number;
 }
 
 export interface TableNamesOptions {
@@ -160,6 +172,7 @@ export class Connection {
     const innerTable = await this.inner.openTable(
       name,
       cleanseStorageOptions(options?.storageOptions),
+      options?.indexCacheSize,
     );
     return new Table(innerTable);
   }
@@ -169,17 +169,20 @@ export class Table {
    * // If the column has a vector (fixed size list) data type then
    * // an IvfPq vector index will be created.
    * const table = await conn.openTable("my_table");
-   * await table.createIndex(["vector"]);
+   * await table.createIndex("vector");
    * @example
    * // For advanced control over vector index creation you can specify
    * // the index type and options.
    * const table = await conn.openTable("my_table");
-   * await table.createIndex(["vector"], I)
-   *   .ivf_pq({ num_partitions: 128, num_sub_vectors: 16 })
-   *   .build();
+   * await table.createIndex("vector", {
+   *   config: lancedb.Index.ivfPq({
+   *     numPartitions: 128,
+   *     numSubVectors: 16,
+   *   }),
+   * });
    * @example
    * // Or create a Scalar index
-   * await table.createIndex("my_float_col").build();
+   * await table.createIndex("my_float_col");
    */
   async createIndex(column: string, options?: Partial<IndexOptions>) {
     // Bit of a hack to get around the fact that TS has no package-scope.
@@ -197,8 +200,7 @@ export class Table {
    * vector similarity, sorting, and more.
    *
    * Note: By default, all columns are returned. For best performance, you should
-   * only fetch the columns you need. See [`Query::select_with_projection`] for
-   * more details.
+   * only fetch the columns you need.
    *
    * When appropriate, various indices and statistics based pruning will be used to
    * accelerate the query.
@@ -207,8 +209,11 @@ export class Table {
    * //
    * // This query will return up to 1000 rows whose value in the `id` column
    * // is greater than 5. LanceDb supports a broad set of filtering functions.
-   * for await (const batch of table.query()
-   *   .filter("id > 1").select(["id"]).limit(20)) {
+   * for await (const batch of table
+   *   .query()
+   *   .where("id > 1")
+   *   .select(["id"])
+   *   .limit(20)) {
    *   console.log(batch);
    * }
    * @example
@@ -218,12 +223,13 @@ export class Table {
    * // closest to the query vector [1.0, 2.0, 3.0]. If an index has been created
    * // on the "vector" column then this will perform an ANN search.
    * //
-   * // The `refine_factor` and `nprobes` methods are used to control the recall /
+   * // The `refineFactor` and `nprobes` methods are used to control the recall /
    * // latency tradeoff of the search.
-   * for await (const batch of table.query()
-   *   .nearestTo([1, 2, 3])
-   *   .refineFactor(5).nprobe(10)
-   *   .limit(10)) {
+   * for await (const batch of table
+   *   .query()
+   *   .where("id > 1")
+   *   .select(["id"])
+   *   .limit(20)) {
    *   console.log(batch);
    * }
    * @example
@@ -286,43 +292,45 @@ export class Table {
     await this.inner.dropColumns(columnNames);
   }
 
-  /**
-   * Retrieve the version of the table
-   *
-   * LanceDb supports versioning. Every operation that modifies the table increases
-   * version. As long as a version hasn't been deleted you can `[Self::checkout]` that
-   * version to view the data at that point. In addition, you can `[Self::restore]` the
-   * version to replace the current table with a previous version.
-   */
+  /** Retrieve the version of the table */
   async version(): Promise<number> {
     return await this.inner.version();
   }
 
   /**
-   * Checks out a specific version of the Table
+   * Checks out a specific version of the table _This is an in-place operation._
    *
-   * Any read operation on the table will now access the data at the checked out version.
-   * As a consequence, calling this method will disable any read consistency interval
-   * that was previously set.
+   * This allows viewing previous versions of the table. If you wish to
+   * keep writing to the dataset starting from an old version, then use
+   * the `restore` function.
    *
-   * This is a read-only operation that turns the table into a sort of "view"
-   * or "detached head". Other table instances will not be affected. To make the change
-   * permanent you can use the `[Self::restore]` method.
-   *
-   * Any operation that modifies the table will fail while the table is in a checked
-   * out state.
-   *
-   * To return the table to a normal state use `[Self::checkout_latest]`
+   * Calling this method will set the table into time-travel mode. If you
+   * wish to return to standard mode, call `checkoutLatest`.
+   * @param {number} version The version to checkout
+   * @example
+   * ```typescript
+   * import * as lancedb from "@lancedb/lancedb"
+   * const db = await lancedb.connect("./.lancedb");
+   * const table = await db.createTable("my_table", [
+   *   { vector: [1.1, 0.9], type: "vector" },
+   * ]);
+   *
+   * console.log(await table.version()); // 1
+   * console.log(table.display());
+   * await table.add([{ vector: [0.5, 0.2], type: "vector" }]);
+   * await table.checkout(1);
+   * console.log(await table.version()); // 2
+   * ```
    */
   async checkout(version: number): Promise<void> {
     await this.inner.checkout(version);
   }
 
   /**
-   * Ensures the table is pointing at the latest version
+   * Checkout the latest version of the table. _This is an in-place operation._
    *
-   * This can be used to manually update a table when the read_consistency_interval is None
-   * It can also be used to undo a `[Self::checkout]` operation
+   * The table will be set back into standard mode, and will track the latest
+   * version of the table.
    */
   async checkoutLatest(): Promise<void> {
     await this.inner.checkoutLatest();
@@ -344,9 +352,7 @@ export class Table {
     await this.inner.restore();
   }
 
-  /**
-   * List all indices that have been created with Self::create_index
-   */
+  /** List all indices that have been created with {@link Table.createIndex} */
   async listIndices(): Promise<IndexConfig[]> {
     return await this.inner.listIndices();
   }
@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-darwin-arm64",
-  "version": "0.4.17",
+  "version": "0.4.20",
   "os": [
     "darwin"
   ],
@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-darwin-x64",
-  "version": "0.4.17",
+  "version": "0.4.20",
   "os": [
     "darwin"
   ],
@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-linux-arm64-gnu",
-  "version": "0.4.17",
+  "version": "0.4.20",
   "os": [
     "linux"
   ],
@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-linux-x64-gnu",
-  "version": "0.4.17",
+  "version": "0.4.20",
   "os": [
     "linux"
   ],
86 nodejs/package-lock.json (generated)
@@ -1,12 +1,12 @@
 {
   "name": "@lancedb/lancedb",
-  "version": "0.4.16",
+  "version": "0.4.20",
   "lockfileVersion": 3,
   "requires": true,
   "packages": {
     "": {
       "name": "@lancedb/lancedb",
-      "version": "0.4.16",
+      "version": "0.4.20",
       "cpu": [
         "x64",
         "arm64"
@@ -45,13 +45,6 @@
       },
       "engines": {
         "node": ">= 18"
-      },
-      "optionalDependencies": {
-        "@lancedb/lancedb-darwin-arm64": "0.4.16",
-        "@lancedb/lancedb-darwin-x64": "0.4.16",
-        "@lancedb/lancedb-linux-arm64-gnu": "0.4.16",
-        "@lancedb/lancedb-linux-x64-gnu": "0.4.16",
-        "@lancedb/lancedb-win32-x64-msvc": "0.4.16"
       }
     },
     "node_modules/@75lb/deep-merge": {
@@ -2221,81 +2214,6 @@
         "@jridgewell/sourcemap-codec": "^1.4.14"
       }
     },
-    "node_modules/@lancedb/lancedb-darwin-arm64": {
-      "version": "0.4.16",
-      "resolved": "https://registry.npmjs.org/@lancedb/lancedb-darwin-arm64/-/lancedb-darwin-arm64-0.4.16.tgz",
-      "integrity": "sha512-CV65ouIDQbBSNtdHbQSr2fqXflOuqud1cfweUS+EiK7eEOEYl7nO2oiFYO49Jy76MEwZxiP99hW825aCqIQJqg==",
-      "cpu": [
-        "arm64"
-      ],
-      "optional": true,
-      "os": [
-        "darwin"
-      ],
-      "engines": {
-        "node": ">= 18"
-      }
-    },
-    "node_modules/@lancedb/lancedb-darwin-x64": {
-      "version": "0.4.16",
-      "resolved": "https://registry.npmjs.org/@lancedb/lancedb-darwin-x64/-/lancedb-darwin-x64-0.4.16.tgz",
-      "integrity": "sha512-1CwIYCNdbFmV7fvqM+qUxbYgwxx0slcCV48PC/I19Ejitgtzw/NJiWDCvONhaLqG85lWNZm1xYceRpVv7b8seQ==",
-      "cpu": [
-        "x64"
-      ],
-      "optional": true,
-      "os": [
-        "darwin"
-      ],
-      "engines": {
-        "node": ">= 18"
-      }
-    },
-    "node_modules/@lancedb/lancedb-linux-arm64-gnu": {
-      "version": "0.4.16",
-      "resolved": "https://registry.npmjs.org/@lancedb/lancedb-linux-arm64-gnu/-/lancedb-linux-arm64-gnu-0.4.16.tgz",
-      "integrity": "sha512-CzLEbzoHKS6jV0k52YnvsiVNx0VzLp1Vz/zmbHI6HmB/XbS67qDO93Jk71MDmXq3JDw0FKFCw9ghkg+6YWq7ZA==",
-      "cpu": [
-        "arm64"
-      ],
-      "optional": true,
-      "os": [
-        "linux"
-      ],
-      "engines": {
-        "node": ">= 18"
-      }
-    },
-    "node_modules/@lancedb/lancedb-linux-x64-gnu": {
-      "version": "0.4.16",
-      "resolved": "https://registry.npmjs.org/@lancedb/lancedb-linux-x64-gnu/-/lancedb-linux-x64-gnu-0.4.16.tgz",
-      "integrity": "sha512-nKChybybi8uA0AFRHBFm7Fz3VXcRm8riv5Gs7xQsrsCtYxxf4DT/0BfUvQ0xKbwNJa+fawHRxi9BOQewdj49fg==",
-      "cpu": [
-        "x64"
-      ],
-      "optional": true,
-      "os": [
-        "linux"
-      ],
-      "engines": {
-        "node": ">= 18"
-      }
-    },
-    "node_modules/@lancedb/lancedb-win32-x64-msvc": {
-      "version": "0.4.16",
-      "resolved": "https://registry.npmjs.org/@lancedb/lancedb-win32-x64-msvc/-/lancedb-win32-x64-msvc-0.4.16.tgz",
-      "integrity": "sha512-KMeBPMpv2g+ZMVsHVibed7BydrBlxje1qS0bZTDrLw9BtZOk6XH2lh1mCDnCJI6sbAscUKNA6fDCdquhQPHL7w==",
-      "cpu": [
-        "x64"
-      ],
-      "optional": true,
-      "os": [
-        "win32"
-      ],
-      "engines": {
-        "node": ">= 18"
-      }
-    },
     "node_modules/@napi-rs/cli": {
       "version": "2.18.0",
       "resolved": "https://registry.npmjs.org/@napi-rs/cli/-/cli-2.18.0.tgz",
@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb",
-  "version": "0.4.17",
+  "version": "0.4.20",
   "main": "./dist/index.js",
   "types": "./dist/index.d.ts",
   "napi": {
@@ -62,20 +62,14 @@
     "build-release": "npm run build:release && tsc -b && shx cp lancedb/native.d.ts dist/native.d.ts",
     "chkformat": "prettier . --check",
     "docs": "typedoc --plugin typedoc-plugin-markdown --out ../docs/src/js lancedb/index.ts",
-    "lint": "eslint lancedb && eslint __test__",
+    "lint": "eslint lancedb __test__",
    "lint-fix": "eslint lancedb __test__ --fix",
    "prepublishOnly": "napi prepublish -t npm",
    "test": "npm run build && jest --verbose",
    "integration": "S3_TEST=1 npm run test",
    "universal": "napi universal",
    "version": "napi version"
  },
-  "optionalDependencies": {
-    "@lancedb/lancedb-darwin-arm64": "0.4.17",
-    "@lancedb/lancedb-darwin-x64": "0.4.17",
-    "@lancedb/lancedb-linux-arm64-gnu": "0.4.17",
-    "@lancedb/lancedb-linux-x64-gnu": "0.4.17",
-    "@lancedb/lancedb-win32-x64-msvc": "0.4.17"
-  },
  "dependencies": {
    "openai": "^4.29.2",
    "apache-arrow": "^15.0.0"
@@ -176,6 +176,7 @@ impl Connection {
         &self,
         name: String,
         storage_options: Option<HashMap<String, String>>,
+        index_cache_size: Option<u32>,
     ) -> napi::Result<Table> {
         let mut builder = self.get_inner()?.open_table(&name);
         if let Some(storage_options) = storage_options {
@@ -183,6 +184,9 @@ impl Connection {
                 builder = builder.storage_option(key, value);
             }
         }
+        if let Some(index_cache_size) = index_cache_size {
+            builder = builder.index_cache_size(index_cache_size);
+        }
         let tbl = builder
             .execute()
             .await
@@ -1,5 +1,5 @@
 [bumpversion]
-current_version = 0.6.9
+current_version = 0.6.13
 commit = True
 message = [python] Bump version: {current_version} → {new_version}
 tag = True
@@ -14,7 +14,7 @@ name = "_lancedb"
 crate-type = ["cdylib"]
 
 [dependencies]
-arrow = { version = "50.0.0", features = ["pyarrow"] }
+arrow = { version = "51.0.0", features = ["pyarrow"] }
 lancedb = { path = "../rust/lancedb" }
 env_logger = "0.10"
 pyo3 = { version = "0.20", features = ["extension-module", "abi3-py38"] }
@@ -1,6 +1,6 @@
 [project]
 name = "lancedb"
-version = "0.6.9"
+version = "0.6.13"
 dependencies = [
     "deprecation",
     "pylance==0.10.12",
@@ -10,7 +10,7 @@ dependencies = [
     "tqdm>=4.27.0",
     "pydantic>=1.10",
     "attrs>=21.3.0",
-    "semver>=3.0",
+    "semver",
     "cachetools",
     "overrides>=0.7",
 ]
@@ -107,6 +107,9 @@ def connect(
             request_thread_pool=request_thread_pool,
             **kwargs,
         )
+
+    if kwargs:
+        raise ValueError(f"Unknown keyword arguments: {kwargs}")
     return LanceDBConnection(uri, read_consistency_interval=read_consistency_interval)
@@ -224,13 +224,23 @@ class DBConnection(EnforceOverrides):
     def __getitem__(self, name: str) -> LanceTable:
         return self.open_table(name)
 
-    def open_table(self, name: str) -> Table:
+    def open_table(self, name: str, *, index_cache_size: Optional[int] = None) -> Table:
         """Open a Lance Table in the database.
 
         Parameters
         ----------
         name: str
             The name of the table.
+        index_cache_size: int, default 256
+            Set the size of the index cache, specified as a number of entries
+
+            The exact meaning of an "entry" will depend on the type of index:
+            * IVF - there is one entry for each IVF partition
+            * BTREE - there is one entry for the entire index
+
+            This cache applies to the entire opened table, across all indices.
+            Setting this value higher will increase performance on larger datasets
+            at the expense of more RAM
 
         Returns
         -------
@@ -248,6 +258,18 @@ class DBConnection(EnforceOverrides):
         """
         raise NotImplementedError
 
+    def rename_table(self, cur_name: str, new_name: str):
+        """Rename a table in the database.
+
+        Parameters
+        ----------
+        cur_name: str
+            The current name of the table.
+        new_name: str
+            The new name of the table.
+        """
+        raise NotImplementedError
+
     def drop_database(self):
         """
         Drop database
@@ -407,7 +429,9 @@ class LanceDBConnection(DBConnection):
         return tbl
 
     @override
-    def open_table(self, name: str) -> LanceTable:
+    def open_table(
+        self, name: str, *, index_cache_size: Optional[int] = None
+    ) -> LanceTable:
         """Open a table in the database.
 
         Parameters
@@ -419,7 +443,7 @@ class LanceDBConnection(DBConnection):
         -------
         A LanceTable object representing the table.
         """
-        return LanceTable.open(self, name)
+        return LanceTable.open(self, name, index_cache_size=index_cache_size)
 
     @override
     def drop_table(self, name: str, ignore_missing: bool = False):
@@ -751,7 +775,10 @@ class AsyncConnection(object):
         return AsyncTable(new_table)
 
     async def open_table(
-        self, name: str, storage_options: Optional[Dict[str, str]] = None
+        self,
+        name: str,
+        storage_options: Optional[Dict[str, str]] = None,
+        index_cache_size: Optional[int] = None,
     ) -> Table:
         """Open a Lance Table in the database.
 
@@ -764,12 +791,22 @@ class AsyncConnection(object):
             connection will be inherited by the table, but can be overridden here.
             See available options at
             https://lancedb.github.io/lancedb/guides/storage/
+        index_cache_size: int, default 256
+            Set the size of the index cache, specified as a number of entries
+
+            The exact meaning of an "entry" will depend on the type of index:
+            * IVF - there is one entry for each IVF partition
+            * BTREE - there is one entry for the entire index
+
+            This cache applies to the entire opened table, across all indices.
+            Setting this value higher will increase performance on larger datasets
+            at the expense of more RAM
 
         Returns
         -------
         A LanceTable object representing the table.
         """
-        table = await self._inner.open_table(name, storage_options)
+        table = await self._inner.open_table(name, storage_options, index_cache_size)
         return AsyncTable(table)
 
     async def drop_table(self, name: str):
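A short usage sketch of the keyword added above, for both the sync and async connections. The table name is illustrative; the cache size only manifests as memory/latency behavior.

```python
import lancedb

db = lancedb.connect("/tmp/db")
# One cache entry per IVF partition (vector index) or one per BTREE index;
# raising it trades RAM for fewer index reloads on large datasets.
table = db.open_table("my_table", index_cache_size=512)

async def open_async():
    adb = await lancedb.connect_async("/tmp/db")
    return await adb.open_table("my_table", index_cache_size=512)
```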
@@ -255,7 +255,13 @@ def retry_with_exponential_backoff(
                 )
 
             delay *= exponential_base * (1 + jitter * random.random())
-            logging.info("Retrying in %s seconds...", delay)
+            logging.warning(
+                "Error occurred: %s \n Retrying in %s seconds (retry %s of %s) \n",
+                e,
+                delay,
+                num_retries,
+                max_retries,
+            )
             time.sleep(delay)
 
     return wrapper
@@ -37,7 +37,7 @@ import pyarrow as pa
 import pydantic
 import semver
 
-PYDANTIC_VERSION = semver.Version.parse(pydantic.__version__)
+PYDANTIC_VERSION = semver.parse_version_info(pydantic.__version__)
 try:
     from pydantic_core import CoreSchema, core_schema
 except ImportError:
@@ -30,6 +30,7 @@ from typing import (
 import deprecation
 import numpy as np
 import pyarrow as pa
+import pyarrow.fs as pa_fs
 import pydantic
 
 from . import __version__
@@ -37,7 +38,7 @@ from .arrow import AsyncRecordBatchReader
 from .common import VEC
 from .rerankers.base import Reranker
 from .rerankers.linear_combination import LinearCombinationReranker
-from .util import safe_import_pandas
+from .util import fs_from_uri, safe_import_pandas
 
 if TYPE_CHECKING:
     import PIL
@@ -665,6 +666,14 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
 
         # get the index path
         index_path = self._table._get_fts_index_path()
+
+        # Check that we are on local filesystem
+        fs, _path = fs_from_uri(index_path)
+        if not isinstance(fs, pa_fs.LocalFileSystem):
+            raise NotImplementedError(
+                "Full-text search is only supported on the local filesystem"
+            )
+
         # check if the index exist
         if not Path(index_path).exists():
             raise FileNotFoundError(
@@ -94,7 +94,7 @@ class RemoteDBConnection(DBConnection):
             yield item
 
     @override
-    def open_table(self, name: str) -> Table:
+    def open_table(self, name: str, *, index_cache_size: Optional[int] = None) -> Table:
         """Open a Lance Table in the database.
 
         Parameters
@@ -110,6 +110,12 @@ class RemoteDBConnection(DBConnection):
 
         self._client.mount_retry_adapter_for_table(name)
 
+        if index_cache_size is not None:
+            logging.info(
+                "index_cache_size is ignored in LanceDb Cloud"
+                " (there is no local cache to configure)"
+            )
+
         # check if table exists
         if self._table_cache.get(name) is None:
             self._client.post(f"/v1/table/{name}/describe/")
@@ -279,7 +285,25 @@ class RemoteDBConnection(DBConnection):
         self._client.post(
             f"/v1/table/{name}/drop/",
         )
-        self._table_cache.pop(name)
+        self._table_cache.pop(name, default=None)
+
+    @override
+    def rename_table(self, cur_name: str, new_name: str):
+        """Rename a table in the database.
+
+        Parameters
+        ----------
+        cur_name: str
+            The current name of the table.
+        new_name: str
+            The new name of the table.
+        """
+        self._client.post(
+            f"/v1/table/{cur_name}/rename/",
+            data={"new_table_name": new_name},
+        )
+        self._table_cache.pop(cur_name, default=None)
+        self._table_cache[new_name] = True
 
     async def close(self):
         """Close the connection to the database."""
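A hedged sketch of calling the new `rename_table` method through a LanceDB Cloud connection; the `db://` URI, API key, and region are placeholders.

```python
import lancedb

# A db:// URI plus api_key yields a RemoteDBConnection.
db = lancedb.connect("db://my-project", api_key="sk-...", region="us-east-1")
db.rename_table("old_table", "new_table")  # POSTs /v1/table/old_table/rename/
```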
@@ -72,7 +72,7 @@ class RemoteTable(Table):
         return resp
 
     def index_stats(self, index_uuid: str):
-        """List all the indices on the table"""
+        """List all the stats of a specified index"""
         resp = self._conn._client.post(
             f"/v1/table/{self._name}/index/{index_uuid}/stats/"
         )
@@ -806,6 +806,7 @@ class _LanceLatestDatasetRef(_LanceDatasetRef):
     """Reference to the latest version of a LanceDataset."""
 
     uri: str
+    index_cache_size: Optional[int] = None
     read_consistency_interval: Optional[timedelta] = None
     last_consistency_check: Optional[float] = None
     _dataset: Optional[LanceDataset] = None
@@ -813,7 +814,9 @@ class _LanceLatestDatasetRef(_LanceDatasetRef):
     @property
     def dataset(self) -> LanceDataset:
         if not self._dataset:
-            self._dataset = lance.dataset(self.uri)
+            self._dataset = lance.dataset(
+                self.uri, index_cache_size=self.index_cache_size
+            )
             self.last_consistency_check = time.monotonic()
         elif self.read_consistency_interval is not None:
             now = time.monotonic()
@@ -842,12 +845,15 @@ class _LanceLatestDatasetRef(_LanceDatasetRef):
 class _LanceTimeTravelRef(_LanceDatasetRef):
     uri: str
     version: int
+    index_cache_size: Optional[int] = None
     _dataset: Optional[LanceDataset] = None
 
     @property
     def dataset(self) -> LanceDataset:
         if not self._dataset:
-            self._dataset = lance.dataset(self.uri, version=self.version)
+            self._dataset = lance.dataset(
+                self.uri, version=self.version, index_cache_size=self.index_cache_size
+            )
         return self._dataset
 
     @dataset.setter
@@ -884,6 +890,8 @@ class LanceTable(Table):
         connection: "LanceDBConnection",
         name: str,
         version: Optional[int] = None,
+        *,
+        index_cache_size: Optional[int] = None,
     ):
         self._conn = connection
         self.name = name
@@ -892,11 +900,13 @@ class LanceTable(Table):
             self._ref = _LanceTimeTravelRef(
                 uri=self._dataset_uri,
                 version=version,
+                index_cache_size=index_cache_size,
             )
         else:
             self._ref = _LanceLatestDatasetRef(
                 uri=self._dataset_uri,
                 read_consistency_interval=connection.read_consistency_interval,
+                index_cache_size=index_cache_size,
             )
 
     @classmethod
@@ -1199,6 +1209,11 @@ class LanceTable(Table):
             raise ValueError("Index already exists. Use replace=True to overwrite.")
             fs.delete_dir(path)
 
+        if not isinstance(fs, pa_fs.LocalFileSystem):
+            raise NotImplementedError(
+                "Full-text search is only supported on the local filesystem"
+            )
+
         index = create_index(
             self._get_fts_index_path(),
             field_names,
@@ -368,6 +368,15 @@ async def test_create_exist_ok_async(tmp_path):
     #     await db.create_table("test", schema=bad_schema, exist_ok=True)
 
 
+def test_open_table_sync(tmp_path):
+    db = lancedb.connect(tmp_path)
+    db.create_table("test", data=[{"id": 0}])
+    assert db.open_table("test").count_rows() == 1
+    assert db.open_table("test", index_cache_size=0).count_rows() == 1
+    with pytest.raises(FileNotFoundError, match="does not exist"):
+        db.open_table("does_not_exist")
+
+
 @pytest.mark.asyncio
 async def test_open_table(tmp_path):
     db = await lancedb.connect_async(tmp_path)
@@ -397,6 +406,10 @@ async def test_open_table(tmp_path):
         }
     )
 
+    # No way to verify this yet, but at least make sure we
+    # can pass the parameter
+    await db.open_table("test", index_cache_size=0)
+
     with pytest.raises(ValueError, match="was not found"):
         await db.open_table("does_not_exist")
@@ -213,7 +213,7 @@ def test_syntax(table):
     # https://github.com/lancedb/lancedb/issues/769
     table.create_fts_index("text")
     with pytest.raises(ValueError, match="Syntax Error"):
-        table.search("they could have been dogs OR cats").limit(10).to_list()
+        table.search("they could have been dogs OR").limit(10).to_list()
 
     # these should work
@@ -134,17 +134,21 @@ impl Connection {
         })
     }
 
-    #[pyo3(signature = (name, storage_options = None))]
+    #[pyo3(signature = (name, storage_options = None, index_cache_size = None))]
     pub fn open_table(
         self_: PyRef<'_, Self>,
         name: String,
         storage_options: Option<HashMap<String, String>>,
+        index_cache_size: Option<u32>,
     ) -> PyResult<&PyAny> {
         let inner = self_.get_inner()?.clone();
         let mut builder = inner.open_table(name);
         if let Some(storage_options) = storage_options {
             builder = builder.storage_options(storage_options);
         }
+        if let Some(index_cache_size) = index_cache_size {
+            builder = builder.index_cache_size(index_cache_size);
+        }
         future_into_py(self_.py(), async move {
             let table = builder.execute().await.infer_error()?;
             Ok(Table::new(table))
@@ -35,21 +35,16 @@ impl<T> PythonErrorExt<T> for std::result::Result<T, LanceError> {
         match &self {
             Ok(_) => Ok(self.unwrap()),
             Err(err) => match err {
-                LanceError::InvalidInput { .. } => self.value_error(),
-                LanceError::InvalidTableName { .. } => self.value_error(),
-                LanceError::TableNotFound { .. } => self.value_error(),
-                LanceError::Schema { .. } => self.value_error(),
+                LanceError::InvalidInput { .. }
+                | LanceError::InvalidTableName { .. }
+                | LanceError::TableNotFound { .. }
+                | LanceError::Schema { .. } => self.value_error(),
                 LanceError::CreateDir { .. } => self.os_error(),
                 LanceError::TableAlreadyExists { .. } => self.runtime_error(),
                 LanceError::ObjectStore { .. } => Err(PyIOError::new_err(err.to_string())),
-                LanceError::Lance { .. } => self.runtime_error(),
-                LanceError::Runtime { .. } => self.runtime_error(),
-                LanceError::Http { .. } => self.runtime_error(),
-                LanceError::Arrow { .. } => self.runtime_error(),
                 LanceError::NotSupported { .. } => {
                     Err(PyNotImplementedError::new_err(err.to_string()))
                 }
-                LanceError::Other { .. } => self.runtime_error(),
+                _ => self.runtime_error(),
             },
         }
     }
@@ -1,6 +1,6 @@
 [package]
 name = "lancedb-node"
-version = "0.4.17"
+version = "0.4.20"
 description = "Serverless, low-latency vector database for AI applications"
 license.workspace = true
 edition.workspace = true
@@ -59,7 +59,7 @@ fn database_new(mut cx: FunctionContext) -> JsResult<JsPromise> {
     for handle in storage_options_js {
         let obj = handle.downcast::<JsArray, _>(&mut cx).unwrap();
         let key = obj.get::<JsString, _, _>(&mut cx, 0)?.value(&mut cx);
-        let value = obj.get::<JsString, _, _>(&mut cx, 0)?.value(&mut cx);
+        let value = obj.get::<JsString, _, _>(&mut cx, 1)?.value(&mut cx);
 
         storage_options.push((key, value));
     }
@@ -1,6 +1,6 @@
 [package]
 name = "lancedb"
-version = "0.4.17"
+version = "0.4.20"
 edition.workspace = true
 description = "LanceDB: A serverless, low-latency vector database for AI applications"
 license.workspace = true
@@ -40,6 +40,8 @@ serde = { version = "^1" }
 serde_json = { version = "1" }
 # For remote feature
 reqwest = { version = "0.11.24", features = ["gzip", "json"], optional = true }
+polars-arrow = { version = ">=0.37", optional = true }
+polars = { version = ">=0.37", optional = true}
 
 [dev-dependencies]
 tempfile = "3.5.0"
@@ -52,7 +54,8 @@ aws-sdk-kms = { version = "1.0" }
 aws-config = { version = "1.0" }
 
 [features]
-default = ["remote"]
+default = []
 remote = ["dep:reqwest"]
 fp16kernels = ["lance-linalg/fp16kernels"]
 s3-test = []
+polars = ["dep:polars-arrow", "dep:polars"]
@@ -14,10 +14,12 @@
 
 use std::{pin::Pin, sync::Arc};
 
 pub use arrow_array;
 pub use arrow_schema;
 use futures::{Stream, StreamExt};
 
+#[cfg(feature = "polars")]
+use {crate::polars_arrow_convertors, polars::frame::ArrowChunk, polars::prelude::DataFrame};
+
 use crate::error::Result;
 
 /// An iterator of batches that also has a schema
@@ -114,8 +116,183 @@ pub trait IntoArrow {
     fn into_arrow(self) -> Result<Box<dyn arrow_array::RecordBatchReader + Send>>;
 }
 
+pub type BoxedRecordBatchReader = Box<dyn arrow_array::RecordBatchReader + Send>;
+
 impl<T: arrow_array::RecordBatchReader + Send + 'static> IntoArrow for T {
     fn into_arrow(self) -> Result<Box<dyn arrow_array::RecordBatchReader + Send>> {
         Ok(Box::new(self))
     }
 }
+
+impl<S: Stream<Item = Result<arrow_array::RecordBatch>>> SimpleRecordBatchStream<S> {
+    pub fn new(stream: S, schema: Arc<arrow_schema::Schema>) -> Self {
+        Self { schema, stream }
+    }
+}
+
+#[cfg(feature = "polars")]
+/// An iterator of record batches formed from a Polars DataFrame.
+pub struct PolarsDataFrameRecordBatchReader {
+    chunks: std::vec::IntoIter<ArrowChunk>,
+    arrow_schema: Arc<arrow_schema::Schema>,
+}
+
+#[cfg(feature = "polars")]
+impl PolarsDataFrameRecordBatchReader {
+    /// Creates a new `PolarsDataFrameRecordBatchReader` from a given Polars DataFrame.
+    /// If the input dataframe does not have aligned chunks, this function undergoes
+    /// the costly operation of reallocating each series as a single contigous chunk.
+    pub fn new(mut df: DataFrame) -> Result<Self> {
+        df.align_chunks();
+        let arrow_schema =
+            polars_arrow_convertors::convert_polars_df_schema_to_arrow_rb_schema(df.schema())?;
+        Ok(Self {
+            chunks: df
+                .iter_chunks(polars_arrow_convertors::POLARS_ARROW_FLAVOR)
+                .collect::<Vec<ArrowChunk>>()
+                .into_iter(),
+            arrow_schema,
+        })
+    }
+}
+
+#[cfg(feature = "polars")]
+impl Iterator for PolarsDataFrameRecordBatchReader {
+    type Item = std::result::Result<arrow_array::RecordBatch, arrow_schema::ArrowError>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        self.chunks.next().map(|chunk| {
+            let columns: std::result::Result<Vec<arrow_array::ArrayRef>, arrow_schema::ArrowError> =
+                chunk
+                    .into_arrays()
+                    .into_iter()
+                    .zip(self.arrow_schema.fields.iter())
+                    .map(|(polars_array, arrow_field)| {
+                        polars_arrow_convertors::convert_polars_arrow_array_to_arrow_rs_array(
+                            polars_array,
+                            arrow_field.data_type().clone(),
+                        )
+                    })
+                    .collect();
+            arrow_array::RecordBatch::try_new(self.arrow_schema.clone(), columns?)
+        })
+    }
+}
+
+#[cfg(feature = "polars")]
+impl arrow_array::RecordBatchReader for PolarsDataFrameRecordBatchReader {
+    fn schema(&self) -> Arc<arrow_schema::Schema> {
+        self.arrow_schema.clone()
+    }
+}
+
+/// A trait for converting the result of a LanceDB query into a Polars DataFrame with aligned
+/// chunks. The resulting Polars DataFrame will have aligned chunks, but the series's
+/// chunks are not guaranteed to be contiguous.
+#[cfg(feature = "polars")]
+pub trait IntoPolars {
+    fn into_polars(self) -> impl std::future::Future<Output = Result<DataFrame>> + Send;
+}
+
+#[cfg(feature = "polars")]
+impl IntoPolars for SendableRecordBatchStream {
+    async fn into_polars(mut self) -> Result<DataFrame> {
+        let polars_schema =
+            polars_arrow_convertors::convert_arrow_rb_schema_to_polars_df_schema(&self.schema())?;
+        let mut acc_df: DataFrame = DataFrame::from(&polars_schema);
+        while let Some(record_batch) = self.next().await {
+            let new_df = polars_arrow_convertors::convert_arrow_rb_to_polars_df(
+                &record_batch?,
+                &polars_schema,
+            )?;
+            acc_df = acc_df.vstack(&new_df)?;
+        }
+        Ok(acc_df)
+    }
+}
+
+#[cfg(all(test, feature = "polars"))]
+mod tests {
+    use super::SendableRecordBatchStream;
+    use crate::arrow::{
+        IntoArrow, IntoPolars, PolarsDataFrameRecordBatchReader, SimpleRecordBatchStream,
+    };
+    use polars::prelude::{DataFrame, NamedFrom, Series};
+
+    fn get_record_batch_reader_from_polars() -> Box<dyn arrow_array::RecordBatchReader + Send> {
+        let mut string_series = Series::new("string", &["ab"]);
+        let mut int_series = Series::new("int", &[1]);
+        let mut float_series = Series::new("float", &[1.0]);
+        let df1 = DataFrame::new(vec![string_series, int_series, float_series]).unwrap();
+
+        string_series = Series::new("string", &["bc"]);
+        int_series = Series::new("int", &[2]);
+        float_series = Series::new("float", &[2.0]);
+        let df2 = DataFrame::new(vec![string_series, int_series, float_series]).unwrap();
+
+        PolarsDataFrameRecordBatchReader::new(df1.vstack(&df2).unwrap())
+            .unwrap()
+            .into_arrow()
+            .unwrap()
+    }
+
+    #[test]
+    fn from_polars_to_arrow() {
+        let record_batch_reader = get_record_batch_reader_from_polars();
+        let schema = record_batch_reader.schema();
+
+        // Test schema conversion
+        assert_eq!(
+            schema
+                .fields
+                .iter()
+                .map(|field| (field.name().as_str(), field.data_type()))
+                .collect::<Vec<_>>(),
+            vec![
+                ("string", &arrow_schema::DataType::LargeUtf8),
+                ("int", &arrow_schema::DataType::Int32),
+                ("float", &arrow_schema::DataType::Float64)
+            ]
+        );
+        let record_batches: Vec<arrow_array::RecordBatch> =
+            record_batch_reader.map(|result| result.unwrap()).collect();
+        assert_eq!(record_batches.len(), 2);
+        assert_eq!(schema, record_batches[0].schema());
+        assert_eq!(record_batches[0].schema(), record_batches[1].schema());
+
+        // Test number of rows
+        assert_eq!(record_batches[0].num_rows(), 1);
+        assert_eq!(record_batches[1].num_rows(), 1);
+    }
+
+    #[tokio::test]
+    async fn from_arrow_to_polars() {
+        let record_batch_reader = get_record_batch_reader_from_polars();
+        let schema = record_batch_reader.schema();
+        let stream: SendableRecordBatchStream = Box::pin(SimpleRecordBatchStream {
+            schema: schema.clone(),
+            stream: futures::stream::iter(
+                record_batch_reader
+                    .into_iter()
+                    .map(|r| r.map_err(Into::into)),
+            ),
+        });
+        let df = stream.into_polars().await.unwrap();
+
+        // Test number of chunks and rows
+        assert_eq!(df.n_chunks(), 2);
+        assert_eq!(df.height(), 2);
+
+        // Test schema conversion
+        assert_eq!(
+            df.schema()
+                .into_iter()
+                .map(|(name, datatype)| (name.to_string(), datatype))
+                .collect::<Vec<_>>(),
+            vec![
+                ("string".to_string(), polars::prelude::DataType::String),
+                ("int".to_owned(), polars::prelude::DataType::Int32),
+                ("float".to_owned(), polars::prelude::DataType::Float64)
+            ]
+        );
+    }
+}
@@ -27,12 +27,18 @@ use object_store::{aws::AwsCredential, local::LocalFileSystem};
|
||||
use snafu::prelude::*;
|
||||
|
||||
use crate::arrow::IntoArrow;
|
||||
use crate::embeddings::{
|
||||
EmbeddingDefinition, EmbeddingFunction, EmbeddingRegistry, MemoryRegistry, WithEmbeddings,
|
||||
};
|
||||
use crate::error::{CreateDirSnafu, Error, InvalidTableNameSnafu, Result};
|
||||
use crate::io::object_store::MirroringObjectStoreWrapper;
|
||||
use crate::table::{NativeTable, WriteOptions};
|
||||
use crate::table::{NativeTable, TableDefinition, WriteOptions};
|
||||
use crate::utils::validate_table_name;
|
||||
use crate::Table;
|
||||
|
||||
#[cfg(feature = "remote")]
|
||||
use log::warn;
|
||||
|
||||
pub const LANCE_FILE_EXTENSION: &str = "lance";
|
||||
|
||||
pub type TableBuilderCallback = Box<dyn FnOnce(OpenTableBuilder) -> OpenTableBuilder + Send>;
|
||||
@@ -130,9 +136,10 @@ pub struct CreateTableBuilder<const HAS_DATA: bool, T: IntoArrow> {
|
||||
parent: Arc<dyn ConnectionInternal>,
|
||||
pub(crate) name: String,
|
||||
pub(crate) data: Option<T>,
|
||||
pub(crate) schema: Option<SchemaRef>,
|
||||
pub(crate) mode: CreateTableMode,
|
||||
pub(crate) write_options: WriteOptions,
|
||||
pub(crate) table_definition: Option<TableDefinition>,
|
||||
pub(crate) embeddings: Vec<(EmbeddingDefinition, Arc<dyn EmbeddingFunction>)>,
|
||||
}
|
||||
|
||||
// Builder methods that only apply when we have initial data
|
||||
@@ -142,9 +149,10 @@ impl<T: IntoArrow> CreateTableBuilder<true, T> {
|
||||
parent,
|
||||
name,
|
||||
data: Some(data),
|
||||
schema: None,
|
||||
mode: CreateTableMode::default(),
|
||||
write_options: WriteOptions::default(),
|
||||
table_definition: None,
|
||||
embeddings: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -172,24 +180,43 @@ impl<T: IntoArrow> CreateTableBuilder<true, T> {
|
||||
parent: self.parent,
|
||||
name: self.name,
|
||||
data: None,
|
||||
schema: self.schema,
|
||||
table_definition: self.table_definition,
|
||||
mode: self.mode,
|
||||
write_options: self.write_options,
|
||||
embeddings: self.embeddings,
|
||||
};
|
||||
Ok((data, builder))
|
||||
}
|
||||
|
||||
pub fn add_embedding(mut self, definition: EmbeddingDefinition) -> Result<Self> {
|
||||
// Early verification of the embedding name
|
||||
let embedding_func = self
|
||||
.parent
|
||||
.embedding_registry()
|
||||
.get(&definition.embedding_name)
|
||||
.ok_or_else(|| Error::EmbeddingFunctionNotFound {
|
||||
name: definition.embedding_name.to_string(),
|
||||
reason: "No embedding function found in the connection's embedding_registry"
|
||||
.to_string(),
|
||||
})?;
|
||||
|
||||
self.embeddings.push((definition, embedding_func));
|
||||
Ok(self)
|
||||
}
|
||||
}
|
||||
|
||||
// Builder methods that only apply when we do not have initial data
|
||||
impl CreateTableBuilder<false, NoData> {
|
||||
fn new(parent: Arc<dyn ConnectionInternal>, name: String, schema: SchemaRef) -> Self {
|
||||
let table_definition = TableDefinition::new_from_schema(schema);
|
||||
Self {
|
||||
parent,
|
||||
name,
|
||||
data: None,
|
||||
schema: Some(schema),
|
||||
table_definition: Some(table_definition),
|
||||
mode: CreateTableMode::default(),
|
||||
write_options: WriteOptions::default(),
|
||||
embeddings: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -347,6 +374,7 @@ impl OpenTableBuilder {
|
||||
pub(crate) trait ConnectionInternal:
|
||||
Send + Sync + std::fmt::Debug + std::fmt::Display + 'static
|
||||
{
|
||||
fn embedding_registry(&self) -> &dyn EmbeddingRegistry;
|
||||
async fn table_names(&self, options: TableNamesBuilder) -> Result<Vec<String>>;
|
||||
async fn do_create_table(
|
||||
&self,
|
||||
@@ -363,7 +391,7 @@ pub(crate) trait ConnectionInternal:
|
||||
) -> Result<Table> {
|
||||
let batches = Box::new(RecordBatchIterator::new(
|
||||
vec![],
|
||||
options.schema.as_ref().unwrap().clone(),
|
||||
options.table_definition.clone().unwrap().schema.clone(),
|
||||
));
|
||||
self.do_create_table(options, batches).await
|
||||
}
|
||||
@@ -450,6 +478,13 @@ impl Connection {
|
||||
pub async fn drop_db(&self) -> Result<()> {
|
||||
self.internal.drop_db().await
|
||||
}
|
||||
|
||||
/// Get the in-memory embedding registry.
|
||||
/// It's important to note that the embedding registry is not persisted across connections.
|
||||
/// So if a table contains embeddings, you will need to make sure that you are using a connection that has the same embedding functions registered
|
||||
pub fn embedding_registry(&self) -> &dyn EmbeddingRegistry {
|
||||
self.internal.embedding_registry()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
@@ -483,6 +518,7 @@ pub struct ConnectBuilder {
    /// consistency only applies to read operations. Write operations are
    /// always consistent.
    read_consistency_interval: Option<std::time::Duration>,
    embedding_registry: Option<Arc<dyn EmbeddingRegistry>>,
}

impl ConnectBuilder {
@@ -495,6 +531,7 @@ impl ConnectBuilder {
            host_override: None,
            read_consistency_interval: None,
            storage_options: HashMap::new(),
            embedding_registry: None,
        }
    }

@@ -513,6 +550,12 @@ impl ConnectBuilder {
        self
    }

    /// Provide a custom [`EmbeddingRegistry`] to use for this connection.
    pub fn embedding_registry(mut self, registry: Arc<dyn EmbeddingRegistry>) -> Self {
        self.embedding_registry = Some(registry);
        self
    }

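Alternatively, when several connections should share one pre-built registry, it can be supplied up front through the builder method above (sketch; `my_registry` is any `Arc<dyn EmbeddingRegistry>`):

    let db = lancedb::connect("data/sample-db")
        .embedding_registry(my_registry.clone())
        .execute()
        .await?;
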
    /// [`AwsCredential`] to use when connecting to S3.
    #[deprecated(note = "Pass through storage_options instead")]
    pub fn aws_creds(mut self, aws_creds: AwsCredential) -> Self {
@@ -579,6 +622,7 @@ impl ConnectBuilder {
        let api_key = self.api_key.ok_or_else(|| Error::InvalidInput {
            message: "An api_key is required when connecting to LanceDb Cloud".to_string(),
        })?;
        warn!("The rust implementation of the remote client is not yet ready for use.");
        let internal = Arc::new(crate::remote::db::RemoteDatabase::try_new(
            &self.uri,
            &api_key,
@@ -638,6 +682,7 @@ struct Database {

    // Storage options to be inherited by tables created from this connection
    storage_options: HashMap<String, String>,
    embedding_registry: Arc<dyn EmbeddingRegistry>,
}

impl std::fmt::Display for Database {
@@ -671,7 +716,12 @@ impl Database {
        // TODO: pass params regardless of OS
        match parse_res {
            Ok(url) if url.scheme().len() == 1 && cfg!(windows) => {
                Self::open_path(uri, options.read_consistency_interval).await
                Self::open_path(
                    uri,
                    options.read_consistency_interval,
                    options.embedding_registry.clone(),
                )
                .await
            }
            Ok(mut url) => {
                // iter thru the query params and extract the commit store param
@@ -741,6 +791,10 @@ impl Database {
                    None => None,
                };

                let embedding_registry = options
                    .embedding_registry
                    .clone()
                    .unwrap_or_else(|| Arc::new(MemoryRegistry::new()));
                Ok(Self {
                    uri: table_base_uri,
                    query_string,
@@ -749,20 +803,33 @@ impl Database {
                    store_wrapper: write_store_wrapper,
                    read_consistency_interval: options.read_consistency_interval,
                    storage_options,
                    embedding_registry,
                })
            }
            Err(_) => Self::open_path(uri, options.read_consistency_interval).await,
            Err(_) => {
                Self::open_path(
                    uri,
                    options.read_consistency_interval,
                    options.embedding_registry.clone(),
                )
                .await
            }
        }
    }

    async fn open_path(
        path: &str,
        read_consistency_interval: Option<std::time::Duration>,
        embedding_registry: Option<Arc<dyn EmbeddingRegistry>>,
    ) -> Result<Self> {
        let (object_store, base_path) = ObjectStore::from_uri(path).await?;
        if object_store.is_local() {
            Self::try_create_dir(path).context(CreateDirSnafu { path })?;
        }

        let embedding_registry =
            embedding_registry.unwrap_or_else(|| Arc::new(MemoryRegistry::new()));

        Ok(Self {
            uri: path.to_string(),
            query_string: None,
@@ -771,6 +838,7 @@ impl Database {
            store_wrapper: None,
            read_consistency_interval,
            storage_options: HashMap::new(),
            embedding_registry,
        })
    }

@@ -811,6 +879,9 @@ impl Database {

#[async_trait::async_trait]
impl ConnectionInternal for Database {
    fn embedding_registry(&self) -> &dyn EmbeddingRegistry {
        self.embedding_registry.as_ref()
    }
    async fn table_names(&self, options: TableNamesBuilder) -> Result<Vec<String>> {
        let mut f = self
            .object_store
@@ -847,7 +918,7 @@ impl ConnectionInternal for Database {
        data: Box<dyn RecordBatchReader + Send>,
    ) -> Result<Table> {
        let table_uri = self.table_uri(&options.name)?;

        let embedding_registry = self.embedding_registry.clone();
        // Inherit storage options from the connection
        let storage_options = options
            .write_options
@@ -862,6 +933,11 @@ impl ConnectionInternal for Database {
                storage_options.insert(key.clone(), value.clone());
            }
        }
        let data = if options.embeddings.is_empty() {
            data
        } else {
            Box::new(WithEmbeddings::new(data, options.embeddings))
        };

        let mut write_params = options.write_options.lance_write_params.unwrap_or_default();
        if matches!(&options.mode, CreateTableMode::Overwrite) {
@@ -878,7 +954,10 @@ impl ConnectionInternal for Database {
        )
        .await
        {
            Ok(table) => Ok(Table::new(Arc::new(table))),
            Ok(table) => Ok(Table::new_with_embedding_registry(
                Arc::new(table),
                embedding_registry,
            )),
            Err(Error::TableAlreadyExists { name }) => match options.mode {
                CreateTableMode::Create => Err(Error::TableAlreadyExists { name }),
                CreateTableMode::ExistOk(callback) => {
@@ -909,12 +988,23 @@ impl ConnectionInternal for Database {
            }
        }

        // Some ReadParams are exposed in the OpenTableBuilder, but we also
        // let the user provide their own ReadParams.
        //
        // If we have a user-provided ReadParams, use that.
        // If we don't, then start with the default ReadParams and customize it with
        // the options from the OpenTableBuilder.
        let read_params = options.lance_read_params.unwrap_or_else(|| ReadParams {
            index_cache_size: options.index_cache_size as usize,
            ..Default::default()
        });

        let native_table = Arc::new(
            NativeTable::open_with_params(
                &table_uri,
                &options.name,
                self.store_wrapper.clone(),
                options.lance_read_params,
                Some(read_params),
                self.read_consistency_interval,
            )
            .await?,
@@ -1032,7 +1122,6 @@ mod tests {
    }

    #[tokio::test]
    #[ignore = "this can't pass due to https://github.com/lancedb/lancedb/issues/1019, enable it after the bug is fixed"]
    async fn test_open_table() {
        let tmp_dir = tempdir().unwrap();
        let uri = tmp_dir.path().to_str().unwrap();

307 rust/lancedb/src/embeddings.rs Normal file
@@ -0,0 +1,307 @@
// Copyright 2024 LanceDB Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use lance::arrow::RecordBatchExt;
use std::{
    borrow::Cow,
    collections::{HashMap, HashSet},
    sync::{Arc, RwLock},
};

use arrow_array::{Array, RecordBatch, RecordBatchReader};
use arrow_schema::{DataType, Field, SchemaBuilder};
// use async_trait::async_trait;
use serde::{Deserialize, Serialize};

use crate::{
    error::Result,
    table::{ColumnDefinition, ColumnKind, TableDefinition},
    Error,
};

/// Trait for embedding functions
///
/// An embedding function is a function that is applied to a column of input data
/// to produce an "embedding" of that input. This embedding is then stored in the
/// database alongside (or instead of) the original input.
///
/// An "embedding" is often a lower-dimensional representation of the input data.
/// For example, sentence-transformers can be used to embed sentences into a 768-dimensional
/// vector space. This is useful for tasks like similarity search, where we want to find
/// similar sentences to a query sentence.
///
/// To use an embedding function you must first register it with the `EmbeddingsRegistry`.
/// Then you can define it on a column in the table schema. That embedding will then be used
/// to embed the data in that column.
pub trait EmbeddingFunction: std::fmt::Debug + Send + Sync {
    fn name(&self) -> &str;
    /// The type of the input data
    fn source_type(&self) -> Result<Cow<DataType>>;
    /// The type of the output data
    /// This should **always** match the output of the `embed` function
    fn dest_type(&self) -> Result<Cow<DataType>>;
    /// Embed the input
    fn embed(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>>;
}

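A minimal sketch of an implementor, assuming a constant two-dimensional embedder purely for illustration (the `MockEmbed` in the new test file below does the same more thoroughly); note that the array returned by `embed` must agree with `dest_type`:

    #[derive(Debug)]
    struct ConstEmbed;

    impl EmbeddingFunction for ConstEmbed {
        fn name(&self) -> &str {
            "const_embed"
        }
        fn source_type(&self) -> Result<Cow<DataType>> {
            Ok(Cow::Owned(DataType::Utf8))
        }
        fn dest_type(&self) -> Result<Cow<DataType>> {
            Ok(Cow::Owned(DataType::new_fixed_size_list(
                DataType::Float32,
                2,
                true,
            )))
        }
        fn embed(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
            use arrow_array::{FixedSizeListArray, Float32Array};
            // One [1.0, 1.0] vector per input row; shape must match dest_type above.
            let inner = Arc::new(Float32Array::from(vec![1.0_f32; source.len() * 2]));
            let field = Arc::new(Field::new("item", DataType::Float32, true));
            Ok(Arc::new(FixedSizeListArray::new(field, 2, inner, None)))
        }
    }
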
/// Defines an embedding from input data into a lower-dimensional space
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct EmbeddingDefinition {
    /// The name of the column in the input data
    pub source_column: String,
    /// The name of the embedding column; if not specified,
    /// it will be the source column with `_embedding` appended
    pub dest_column: Option<String>,
    /// The name of the embedding function to apply
    pub embedding_name: String,
}

impl EmbeddingDefinition {
    pub fn new<S: Into<String>>(source_column: S, embedding_name: S, dest: Option<S>) -> Self {
        Self {
            source_column: source_column.into(),
            dest_column: dest.map(|d| d.into()),
            embedding_name: embedding_name.into(),
        }
    }
}

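For illustration, the two definitions below target the same source column; with no explicit destination, the generated column name defaults to the source name with `_embedding` appended:

    let explicit = EmbeddingDefinition::new("text", "my_fn", Some("text_vec"));
    let defaulted = EmbeddingDefinition::new("text", "my_fn", None); // dest column: "text_embedding"
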
/// A registry of embedding functions
pub trait EmbeddingRegistry: Send + Sync + std::fmt::Debug {
    /// Return the names of all registered embedding functions
    fn functions(&self) -> HashSet<String>;
    /// Register a new [`EmbeddingFunction`].
    /// Returns an error if the function cannot be registered
    fn register(&self, name: &str, function: Arc<dyn EmbeddingFunction>) -> Result<()>;
    /// Get an embedding function by name
    fn get(&self, name: &str) -> Option<Arc<dyn EmbeddingFunction>>;
}

/// An [`EmbeddingRegistry`] that uses in-memory [`HashMap`]s
#[derive(Debug, Default, Clone)]
pub struct MemoryRegistry {
    functions: Arc<RwLock<HashMap<String, Arc<dyn EmbeddingFunction>>>>,
}

impl EmbeddingRegistry for MemoryRegistry {
    fn functions(&self) -> HashSet<String> {
        self.functions.read().unwrap().keys().cloned().collect()
    }
    fn register(&self, name: &str, function: Arc<dyn EmbeddingFunction>) -> Result<()> {
        self.functions
            .write()
            .unwrap()
            .insert(name.to_string(), function);

        Ok(())
    }

    fn get(&self, name: &str) -> Option<Arc<dyn EmbeddingFunction>> {
        self.functions.read().unwrap().get(name).cloned()
    }
}

impl MemoryRegistry {
    /// Create a new `MemoryRegistry`
    pub fn new() -> Self {
        Self::default()
    }
}

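A quick sketch of the registry contract using the in-memory implementation (`ConstEmbed` is the hypothetical embedder sketched above):

    let registry = MemoryRegistry::new();
    registry.register("const_embed", Arc::new(ConstEmbed))?;
    assert!(registry.functions().contains("const_embed"));
    let _func = registry.get("const_embed").expect("just registered");
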
/// A record batch reader that has embeddings applied to it.
/// This is a wrapper around another record batch reader that applies an embedding function
/// when reading from the record batch
pub struct WithEmbeddings<R: RecordBatchReader> {
    inner: R,
    embeddings: Vec<(EmbeddingDefinition, Arc<dyn EmbeddingFunction>)>,
}

/// A record batch reader that might have embeddings applied to it.
pub enum MaybeEmbedded<R: RecordBatchReader> {
    /// The record batch reader has embeddings applied to it
    Yes(WithEmbeddings<R>),
    /// The record batch reader does not have embeddings applied to it.
    /// The inner record batch reader is returned as-is
    No(R),
}

impl<R: RecordBatchReader> MaybeEmbedded<R> {
    /// Create a new RecordBatchReader with embeddings applied to it if the table definition
    /// specifies an embedding column and the registry contains an embedding function with that name.
    /// Otherwise, this is a no-op and the inner RecordBatchReader is returned.
    pub fn try_new(
        inner: R,
        table_definition: TableDefinition,
        registry: Option<Arc<dyn EmbeddingRegistry>>,
    ) -> Result<Self> {
        if let Some(registry) = registry {
            let mut embeddings = Vec::with_capacity(table_definition.column_definitions.len());
            for cd in table_definition.column_definitions.iter() {
                if let ColumnKind::Embedding(embedding_def) = &cd.kind {
                    match registry.get(&embedding_def.embedding_name) {
                        Some(func) => {
                            embeddings.push((embedding_def.clone(), func));
                        }
                        None => {
                            return Err(Error::EmbeddingFunctionNotFound {
                                name: embedding_def.embedding_name.to_string(),
                                reason: format!(
                                    "Table was defined with an embedding column `{}` but no embedding function was found with that name within the registry.",
                                    embedding_def.embedding_name
                                ),
                            });
                        }
                    }
                }
            }

            if !embeddings.is_empty() {
                return Ok(Self::Yes(WithEmbeddings { inner, embeddings }));
            }
        };

        // No embeddings to apply
        Ok(Self::No(inner))
    }
}

impl<R: RecordBatchReader> WithEmbeddings<R> {
    pub fn new(
        inner: R,
        embeddings: Vec<(EmbeddingDefinition, Arc<dyn EmbeddingFunction>)>,
    ) -> Self {
        Self { inner, embeddings }
    }
}

impl<R: RecordBatchReader> WithEmbeddings<R> {
    fn dest_fields(&self) -> Result<Vec<Field>> {
        let schema = self.inner.schema();
        self.embeddings
            .iter()
            .map(|(ed, func)| {
                let src_field = schema.field_with_name(&ed.source_column).unwrap();

                let field_name = ed
                    .dest_column
                    .clone()
                    .unwrap_or_else(|| format!("{}_embedding", &ed.source_column));
                Ok(Field::new(
                    field_name,
                    func.dest_type()?.into_owned(),
                    src_field.is_nullable(),
                ))
            })
            .collect()
    }

    fn column_defs(&self) -> Vec<ColumnDefinition> {
        let base_schema = self.inner.schema();
        base_schema
            .fields()
            .iter()
            .map(|_| ColumnDefinition {
                kind: ColumnKind::Physical,
            })
            .chain(self.embeddings.iter().map(|(ed, _)| ColumnDefinition {
                kind: ColumnKind::Embedding(ed.clone()),
            }))
            .collect::<Vec<_>>()
    }

    pub fn table_definition(&self) -> Result<TableDefinition> {
        let base_schema = self.inner.schema();

        let output_fields = self.dest_fields()?;
        let column_definitions = self.column_defs();

        let mut sb: SchemaBuilder = base_schema.as_ref().into();
        sb.extend(output_fields);

        let schema = Arc::new(sb.finish());
        Ok(TableDefinition {
            schema,
            column_definitions,
        })
    }
}

impl<R: RecordBatchReader> Iterator for MaybeEmbedded<R> {
    type Item = std::result::Result<RecordBatch, arrow_schema::ArrowError>;
    fn next(&mut self) -> Option<Self::Item> {
        match self {
            Self::Yes(inner) => inner.next(),
            Self::No(inner) => inner.next(),
        }
    }
}

impl<R: RecordBatchReader> RecordBatchReader for MaybeEmbedded<R> {
    fn schema(&self) -> Arc<arrow_schema::Schema> {
        match self {
            Self::Yes(inner) => inner.schema(),
            Self::No(inner) => inner.schema(),
        }
    }
}

impl<R: RecordBatchReader> Iterator for WithEmbeddings<R> {
    type Item = std::result::Result<RecordBatch, arrow_schema::ArrowError>;

    fn next(&mut self) -> Option<Self::Item> {
        let batch = self.inner.next()?;
        match batch {
            Ok(mut batch) => {
                // todo: parallelize this
                for (fld, func) in self.embeddings.iter() {
                    let src_column = batch.column_by_name(&fld.source_column).unwrap();
                    let embedding = match func.embed(src_column.clone()) {
                        Ok(embedding) => embedding,
                        Err(e) => {
                            return Some(Err(arrow_schema::ArrowError::ComputeError(format!(
                                "Error computing embedding: {}",
                                e
                            ))))
                        }
                    };
                    let dst_field_name = fld
                        .dest_column
                        .clone()
                        .unwrap_or_else(|| format!("{}_embedding", &fld.source_column));

                    let dst_field = Field::new(
                        dst_field_name,
                        embedding.data_type().clone(),
                        embedding.nulls().is_some(),
                    );

                    match batch.try_with_column(dst_field.clone(), embedding) {
                        Ok(b) => batch = b,
                        Err(e) => return Some(Err(e)),
                    };
                }
                Some(Ok(batch))
            }
            Err(e) => Some(Err(e)),
        }
    }
}

impl<R: RecordBatchReader> RecordBatchReader for WithEmbeddings<R> {
    fn schema(&self) -> Arc<arrow_schema::Schema> {
        self.table_definition()
            .expect("table definition should be infallible at this point")
            .into_rich_schema()
    }
}

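To make the wrapper concrete: reading through `WithEmbeddings` should yield the source batches with the embedding columns appended. A hedged sketch (`reader` is any `RecordBatchReader` with a `text` column; `ConstEmbed` is the hypothetical embedder from earlier):

    let defn = EmbeddingDefinition::new("text", "const_embed", None);
    let func: Arc<dyn EmbeddingFunction> = Arc::new(ConstEmbed);
    for batch in WithEmbeddings::new(reader, vec![(defn, func)]) {
        assert!(batch.unwrap().column_by_name("text_embedding").is_some());
    }
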
@@ -26,6 +26,9 @@ pub enum Error {
    InvalidInput { message: String },
    #[snafu(display("Table '{name}' was not found"))]
    TableNotFound { name: String },
    #[snafu(display("Embedding function '{name}' was not found: {reason}"))]
    EmbeddingFunctionNotFound { name: String, reason: String },

    #[snafu(display("Table '{name}' already exists"))]
    TableAlreadyExists { name: String },
    #[snafu(display("Unable to create lance dataset at {path}: {source}"))]
@@ -112,3 +115,13 @@ impl From<url::ParseError> for Error {
        }
    }
}

#[cfg(feature = "polars")]
impl From<polars::prelude::PolarsError> for Error {
    fn from(source: polars::prelude::PolarsError) -> Self {
        Self::Other {
            message: "Error in Polars DataFrame integration.".to_string(),
            source: Some(Box::new(source)),
        }
    }
}

@@ -46,10 +46,18 @@ impl VectorIndex {
    }
}

#[derive(Debug, Deserialize)]
pub struct VectorIndexMetadata {
    pub metric_type: String,
    pub index_type: String,
}

#[derive(Debug, Deserialize)]
pub struct VectorIndexStatistics {
    pub num_indexed_rows: usize,
    pub num_unindexed_rows: usize,
    pub index_type: String,
    pub indices: Vec<VectorIndexMetadata>,
}

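Both structs derive `Deserialize`, so index statistics can be decoded straight from JSON. A sketch with made-up values showing the expected shape (the string values are illustrative, not the canonical metric names):

    let stats: VectorIndexStatistics = serde_json::from_str(
        r#"{"num_indexed_rows": 1000, "num_unindexed_rows": 0, "index_type": "IVF_PQ",
            "indices": [{"metric_type": "l2", "index_type": "IVF_PQ"}]}"#,
    )
    .unwrap();
    assert_eq!(stats.indices[0].metric_type, "l2");
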
/// Builder for an IVF PQ index.

@@ -350,8 +350,16 @@ mod test {

    #[tokio::test]
    async fn test_e2e() {
        let dir1 = tempfile::tempdir().unwrap().into_path();
        let dir2 = tempfile::tempdir().unwrap().into_path();
        let dir1 = tempfile::tempdir()
            .unwrap()
            .into_path()
            .canonicalize()
            .unwrap();
        let dir2 = tempfile::tempdir()
            .unwrap()
            .into_path()
            .canonicalize()
            .unwrap();

        let secondary_store = LocalFileSystem::new_with_prefix(dir2.to_str().unwrap()).unwrap();
        let object_store_wrapper = Arc::new(MirroringObjectStoreWrapper {

@@ -34,6 +34,16 @@
//! cargo install lancedb
//! ```
//!
//! ## Crate Features
//!
//! ### Experimental Features
//!
//! These features are not enabled by default. They are experimental or in-development features
//! that are not yet ready to be released.
//!
//! - `remote` - Enable the remote client to connect to LanceDB cloud. This is not yet fully
//!   implemented and should not be enabled.
//!
//! ### Quick Start
//!
//! #### Connect to a database.
@@ -184,10 +194,13 @@
pub mod arrow;
pub mod connection;
pub mod data;
pub mod embeddings;
pub mod error;
pub mod index;
pub mod io;
pub mod ipc;
#[cfg(feature = "polars")]
mod polars_arrow_convertors;
pub mod query;
#[cfg(feature = "remote")]
pub(crate) mod remote;
@@ -225,6 +238,9 @@ pub enum DistanceType {
    /// distance has a range of (-∞, ∞). If the vectors are normalized (i.e. their
    /// L2 norm is 1), then dot distance is equivalent to the cosine distance.
    Dot,
    /// Hamming distance. Hamming distance is a distance metric that measures
    /// the number of positions at which the corresponding elements are different.
    Hamming,
}

impl From<DistanceType> for LanceDistanceType {
@@ -233,6 +249,7 @@ impl From<DistanceType> for LanceDistanceType {
            DistanceType::L2 => Self::L2,
            DistanceType::Cosine => Self::Cosine,
            DistanceType::Dot => Self::Dot,
            DistanceType::Hamming => Self::Hamming,
        }
    }
}
@@ -243,6 +260,7 @@ impl From<LanceDistanceType> for DistanceType {
            LanceDistanceType::L2 => Self::L2,
            LanceDistanceType::Cosine => Self::Cosine,
            LanceDistanceType::Dot => Self::Dot,
            LanceDistanceType::Hamming => Self::Hamming,
        }
    }
}
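A small self-contained check of the claim in the `Dot` docs above (illustration only, not library code): for unit-norm vectors, cosine similarity and the raw dot product coincide, so the two distances rank neighbors identically.

    fn dot(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).map(|(x, y)| x * y).sum()
    }

    let a = [0.6_f32, 0.8]; // L2 norm = 1
    let b = [1.0_f32, 0.0]; // L2 norm = 1
    let cosine = dot(&a, &b) / (dot(&a, &a).sqrt() * dot(&b, &b).sqrt());
    assert!((cosine - dot(&a, &b)).abs() < 1e-6);
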
123 rust/lancedb/src/polars_arrow_convertors.rs Normal file
@@ -0,0 +1,123 @@
/// Polars and LanceDB both use Arrow for their in-memory representation, but use
/// different Rust Arrow implementations. LanceDB uses the arrow-rs crate and
/// Polars uses the polars-arrow crate.
///
/// This module defines zero-copy conversions (of the underlying buffers)
/// between polars-arrow and arrow-rs using the C FFI.
///
/// polars-arrow does implement conversions to and from arrow-rs, but
/// requires a feature-flagged dependency on arrow-rs. The version of arrow-rs
/// depended on by polars-arrow and LanceDB may not be compatible,
/// which necessitates using the C FFI.
use crate::error::Result;
use polars::prelude::{DataFrame, Series};
use std::{mem, sync::Arc};

/// When interpreting Polars dataframes as polars-arrow record batches,
/// one must decide whether to use Arrow string/binary view types
/// instead of the standard Arrow string/binary types.
/// For now, we will not use string view types because conversions
/// for string view types from polars-arrow to arrow-rs are not yet implemented.
/// See: https://lists.apache.org/thread/w88tpz76ox8h3rxkjl4so6rg3f1rv7wt for the
/// differences in the types.
pub const POLARS_ARROW_FLAVOR: bool = false;
const IS_ARRAY_NULLABLE: bool = true;

/// Converts a Polars DataFrame schema to an Arrow RecordBatch schema.
pub fn convert_polars_df_schema_to_arrow_rb_schema(
    polars_df_schema: polars::prelude::Schema,
) -> Result<Arc<arrow_schema::Schema>> {
    let arrow_fields: Result<Vec<arrow_schema::Field>> = polars_df_schema
        .into_iter()
        .map(|(name, df_dtype)| {
            let polars_arrow_dtype = df_dtype.to_arrow(POLARS_ARROW_FLAVOR);
            let polars_field =
                polars_arrow::datatypes::Field::new(name, polars_arrow_dtype, IS_ARRAY_NULLABLE);
            convert_polars_arrow_field_to_arrow_rs_field(polars_field)
        })
        .collect();
    Ok(Arc::new(arrow_schema::Schema::new(arrow_fields?)))
}

/// Converts an Arrow RecordBatch schema to a Polars DataFrame schema.
pub fn convert_arrow_rb_schema_to_polars_df_schema(
    arrow_schema: &arrow_schema::Schema,
) -> Result<polars::prelude::Schema> {
    let polars_df_fields: Result<Vec<polars::prelude::Field>> = arrow_schema
        .fields()
        .iter()
        .map(|arrow_rs_field| {
            let polars_arrow_field = convert_arrow_rs_field_to_polars_arrow_field(arrow_rs_field)?;
            Ok(polars::prelude::Field::new(
                arrow_rs_field.name(),
                polars::datatypes::DataType::from(polars_arrow_field.data_type()),
            ))
        })
        .collect();
    Ok(polars::prelude::Schema::from_iter(polars_df_fields?))
}

/// Converts an Arrow RecordBatch to a Polars DataFrame, using a provided Polars DataFrame schema.
pub fn convert_arrow_rb_to_polars_df(
    arrow_rb: &arrow::record_batch::RecordBatch,
    polars_schema: &polars::prelude::Schema,
) -> Result<DataFrame> {
    let mut columns: Vec<Series> = Vec::with_capacity(arrow_rb.num_columns());

    for (i, column) in arrow_rb.columns().iter().enumerate() {
        let polars_df_dtype = polars_schema.try_get_at_index(i)?.1;
        let polars_arrow_dtype = polars_df_dtype.to_arrow(POLARS_ARROW_FLAVOR);
        let polars_array =
            convert_arrow_rs_array_to_polars_arrow_array(column, polars_arrow_dtype)?;
        columns.push(Series::from_arrow(
            polars_schema.try_get_at_index(i)?.0,
            polars_array,
        )?);
    }

    Ok(DataFrame::from_iter(columns))
}

/// Converts a polars-arrow Arrow array to an arrow-rs Arrow array.
pub fn convert_polars_arrow_array_to_arrow_rs_array(
    polars_array: Box<dyn polars_arrow::array::Array>,
    arrow_datatype: arrow_schema::DataType,
) -> std::result::Result<arrow_array::ArrayRef, arrow_schema::ArrowError> {
    let polars_c_array = polars_arrow::ffi::export_array_to_c(polars_array);
    let arrow_c_array = unsafe { mem::transmute(polars_c_array) };
    Ok(arrow_array::make_array(unsafe {
        arrow::ffi::from_ffi_and_data_type(arrow_c_array, arrow_datatype)
    }?))
}

/// Converts an arrow-rs Arrow array to a polars-arrow Arrow array.
fn convert_arrow_rs_array_to_polars_arrow_array(
    arrow_rs_array: &Arc<dyn arrow_array::Array>,
    polars_arrow_dtype: polars::datatypes::ArrowDataType,
) -> Result<Box<dyn polars_arrow::array::Array>> {
    let arrow_c_array = arrow::ffi::FFI_ArrowArray::new(&arrow_rs_array.to_data());
    let polars_c_array = unsafe { mem::transmute(arrow_c_array) };
    Ok(unsafe { polars_arrow::ffi::import_array_from_c(polars_c_array, polars_arrow_dtype) }?)
}

fn convert_polars_arrow_field_to_arrow_rs_field(
    polars_arrow_field: polars_arrow::datatypes::Field,
) -> Result<arrow_schema::Field> {
    let polars_c_schema = polars_arrow::ffi::export_field_to_c(&polars_arrow_field);
    let arrow_c_schema: arrow::ffi::FFI_ArrowSchema = unsafe { mem::transmute(polars_c_schema) };
    let arrow_rs_dtype = arrow_schema::DataType::try_from(&arrow_c_schema)?;
    Ok(arrow_schema::Field::new(
        polars_arrow_field.name,
        arrow_rs_dtype,
        IS_ARRAY_NULLABLE,
    ))
}

fn convert_arrow_rs_field_to_polars_arrow_field(
    arrow_rs_field: &arrow_schema::Field,
) -> Result<polars_arrow::datatypes::Field> {
    let arrow_rs_dtype = arrow_rs_field.data_type();
    let arrow_c_schema = arrow::ffi::FFI_ArrowSchema::try_from(arrow_rs_dtype)?;
    let polars_c_schema: polars_arrow::ffi::ArrowSchema = unsafe { mem::transmute(arrow_c_schema) };
    Ok(unsafe { polars_arrow::ffi::import_field_from_c(&polars_c_schema) }?)
}

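A hedged round-trip sketch using the converters above (requires the `polars` feature; exact polars types vary by version, and the one-column frame is illustrative):

    use polars::prelude::df;

    let frame = df!("id" => &[1i64, 2, 3]).unwrap();
    let arrow_schema = convert_polars_df_schema_to_arrow_rb_schema(frame.schema()).unwrap();
    assert_eq!(arrow_schema.field(0).name(), "id");
    let polars_schema = convert_arrow_rb_schema_to_polars_df_schema(&arrow_schema).unwrap();
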
@@ -23,6 +23,7 @@ use tokio::task::spawn_blocking;
use crate::connection::{
    ConnectionInternal, CreateTableBuilder, NoData, OpenTableBuilder, TableNamesBuilder,
};
use crate::embeddings::EmbeddingRegistry;
use crate::error::Result;
use crate::Table;

@@ -87,14 +88,16 @@ impl ConnectionInternal for RemoteDatabase {
            .await
            .unwrap()?;

        self.client
            .post(&format!("/v1/table/{}/create", options.name))
        let rsp = self
            .client
            .post(&format!("/v1/table/{}/create/", options.name))
            .body(data_buffer)
            .header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE)
            // This is currently expected by LanceDb cloud but will be removed soon.
            .header("x-request-id", "na")
            .send()
            .await?;
        self.client.check_response(rsp).await?;

        Ok(Table::new(Arc::new(RemoteTable::new(
            self.client.clone(),
@@ -113,4 +116,8 @@ impl ConnectionInternal for RemoteDatabase {
    async fn drop_db(&self) -> Result<()> {
        todo!()
    }

    fn embedding_registry(&self) -> &dyn EmbeddingRegistry {
        todo!()
    }
}

@@ -10,7 +10,7 @@ use crate::{
    query::{Query, QueryExecutionOptions, VectorQuery},
    table::{
        merge::MergeInsertBuilder, AddDataBuilder, NativeTable, OptimizeAction, OptimizeStats,
        TableInternal, UpdateBuilder,
        TableDefinition, TableInternal, UpdateBuilder,
    },
};

@@ -120,4 +120,7 @@ impl TableInternal for RemoteTable {
    async fn list_indices(&self) -> Result<Vec<IndexConfig>> {
        todo!()
    }
    async fn table_definition(&self) -> Result<TableDefinition> {
        todo!()
    }
}

@@ -41,10 +41,12 @@ use lance::io::WrappingObjectStore;
use lance_index::IndexType;
use lance_index::{optimize::OptimizeOptions, DatasetIndexExt};
use log::info;
use serde::{Deserialize, Serialize};
use snafu::whatever;

use crate::arrow::IntoArrow;
use crate::connection::NoData;
use crate::embeddings::{EmbeddingDefinition, EmbeddingRegistry, MaybeEmbedded, MemoryRegistry};
use crate::error::{Error, Result};
use crate::index::vector::{IvfPqIndexBuilder, VectorIndex, VectorIndexStatistics};
use crate::index::IndexConfig;
@@ -63,6 +65,79 @@ use self::merge::MergeInsertBuilder;
pub(crate) mod dataset;
pub mod merge;

/// Defines the type of column
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ColumnKind {
    /// Columns populated by data from the user (this is the most common case)
    Physical,
    /// Columns populated by applying an embedding function to the input
    Embedding(EmbeddingDefinition),
}

/// Defines a column in a table
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ColumnDefinition {
    /// The source of the column data
    pub kind: ColumnKind,
}

#[derive(Debug, Clone)]
pub struct TableDefinition {
    pub column_definitions: Vec<ColumnDefinition>,
    pub schema: SchemaRef,
}

impl TableDefinition {
    pub fn new(schema: SchemaRef, column_definitions: Vec<ColumnDefinition>) -> Self {
        Self {
            column_definitions,
            schema,
        }
    }

    pub fn new_from_schema(schema: SchemaRef) -> Self {
        let column_definitions = schema
            .fields()
            .iter()
            .map(|_| ColumnDefinition {
                kind: ColumnKind::Physical,
            })
            .collect();
        Self::new(schema, column_definitions)
    }

    pub fn try_from_rich_schema(schema: SchemaRef) -> Result<Self> {
        let column_definitions = schema.metadata.get("lancedb::column_definitions");
        if let Some(column_definitions) = column_definitions {
            let column_definitions: Vec<ColumnDefinition> =
                serde_json::from_str(column_definitions).map_err(|e| Error::Runtime {
                    message: format!("Failed to deserialize column definitions: {}", e),
                })?;
            Ok(Self::new(schema, column_definitions))
        } else {
            let column_definitions = schema
                .fields()
                .iter()
                .map(|_| ColumnDefinition {
                    kind: ColumnKind::Physical,
                })
                .collect();
            Ok(Self::new(schema, column_definitions))
        }
    }

    pub fn into_rich_schema(self) -> SchemaRef {
        // We have full control over the structure of column definitions. This should
        // not fail, except for a bug
        let lancedb_metadata = serde_json::to_string(&self.column_definitions).unwrap();
        let mut schema_with_metadata = (*self.schema).clone();
        schema_with_metadata
            .metadata
            .insert("lancedb::column_definitions".to_string(), lancedb_metadata);
        Arc::new(schema_with_metadata)
    }
}

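To make the metadata convention concrete: a table definition survives a trip through the "rich" schema because the column definitions are JSON-encoded under the `lancedb::column_definitions` metadata key. Sketch (`schema` stands in for any `SchemaRef`):

    let defn = TableDefinition::new_from_schema(schema.clone());
    let rich = defn.into_rich_schema(); // embeds column definitions into schema metadata
    let roundtrip = TableDefinition::try_from_rich_schema(rich)?;
    assert_eq!(roundtrip.column_definitions.len(), schema.fields().len());
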
/// Optimize the dataset.
///
/// Similar to `VACUUM` in PostgreSQL, it offers different options to
@@ -132,6 +207,7 @@ pub struct AddDataBuilder<T: IntoArrow> {
    pub(crate) data: T,
    pub(crate) mode: AddDataMode,
    pub(crate) write_options: WriteOptions,
    embedding_registry: Option<Arc<dyn EmbeddingRegistry>>,
}

impl<T: IntoArrow> std::fmt::Debug for AddDataBuilder<T> {
@@ -163,6 +239,7 @@ impl<T: IntoArrow> AddDataBuilder<T> {
            mode: self.mode,
            parent: self.parent,
            write_options: self.write_options,
            embedding_registry: self.embedding_registry,
        };
        parent.add(without_data, data).await
    }
@@ -280,6 +357,7 @@ pub(crate) trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Sync {
    async fn checkout(&self, version: u64) -> Result<()>;
    async fn checkout_latest(&self) -> Result<()>;
    async fn restore(&self) -> Result<()>;
    async fn table_definition(&self) -> Result<TableDefinition>;
}

/// A Table is a collection of strongly typed Rows.
@@ -288,6 +366,7 @@ pub(crate) trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Sync {
#[derive(Clone)]
pub struct Table {
    inner: Arc<dyn TableInternal>,
    embedding_registry: Arc<dyn EmbeddingRegistry>,
}

impl std::fmt::Display for Table {
@@ -298,7 +377,20 @@ impl std::fmt::Display for Table {

impl Table {
    pub(crate) fn new(inner: Arc<dyn TableInternal>) -> Self {
        Self { inner }
        Self {
            inner,
            embedding_registry: Arc::new(MemoryRegistry::new()),
        }
    }

    pub(crate) fn new_with_embedding_registry(
        inner: Arc<dyn TableInternal>,
        embedding_registry: Arc<dyn EmbeddingRegistry>,
    ) -> Self {
        Self {
            inner,
            embedding_registry,
        }
    }

    /// Cast as [`NativeTable`], or return None if it is not a [`NativeTable`].
@@ -340,6 +432,7 @@ impl Table {
            data: batches,
            mode: AddDataMode::Append,
            write_options: WriteOptions::default(),
            embedding_registry: Some(self.embedding_registry.clone()),
        }
    }

@@ -743,11 +836,10 @@ impl Table {

impl From<NativeTable> for Table {
    fn from(table: NativeTable) -> Self {
        Self {
            inner: Arc::new(table),
        }
        Self::new(Arc::new(table))
    }
}

/// A table in a LanceDB database.
#[derive(Debug, Clone)]
pub struct NativeTable {
@@ -918,7 +1010,6 @@ impl NativeTable {
            Some(wrapper) => params.patch_with_store_wrapper(wrapper)?,
            None => params,
        };

        let storage_options = params
            .store_params
            .clone()
@@ -1061,6 +1152,26 @@ impl NativeTable {
        }
    }

    pub async fn get_index_type(&self, index_uuid: &str) -> Result<Option<String>> {
        match self.load_index_stats(index_uuid).await? {
            Some(stats) => Ok(Some(stats.index_type)),
            None => Ok(None),
        }
    }

    pub async fn get_distance_type(&self, index_uuid: &str) -> Result<Option<String>> {
        match self.load_index_stats(index_uuid).await? {
            Some(stats) => Ok(Some(
                stats
                    .indices
                    .iter()
                    .map(|i| i.metric_type.clone())
                    .collect(),
            )),
            None => Ok(None),
        }
    }

    pub async fn load_indices(&self) -> Result<Vec<VectorIndex>> {
        let dataset = self.dataset.get().await?;
        let (indices, mf) = futures::try_join!(dataset.load_indices(), dataset.latest_manifest())?;
@@ -1322,6 +1433,11 @@ impl TableInternal for NativeTable {
        Ok(Arc::new(Schema::from(&lance_schema)))
    }

    async fn table_definition(&self) -> Result<TableDefinition> {
        let schema = self.schema().await?;
        TableDefinition::try_from_rich_schema(schema)
    }

    async fn count_rows(&self, filter: Option<String>) -> Result<usize> {
        Ok(self.dataset.get().await?.count_rows(filter).await?)
    }
@@ -1331,6 +1447,9 @@ impl TableInternal for NativeTable {
        add: AddDataBuilder<NoData>,
        data: Box<dyn RecordBatchReader + Send>,
    ) -> Result<()> {
        let data =
            MaybeEmbedded::try_new(data, self.table_definition().await?, add.embedding_registry)?;

        let mut lance_params = add.write_options.lance_write_params.unwrap_or(WriteParams {
            mode: match add.mode {
                AddDataMode::Append => WriteMode::Append,
@@ -1358,8 +1477,8 @@ impl TableInternal for NativeTable {
        };

        self.dataset.ensure_mutable().await?;

        let dataset = Dataset::write(data, &self.uri, Some(lance_params)).await?;

        self.dataset.set_latest(dataset).await;
        Ok(())
    }

320 rust/lancedb/tests/embedding_registry_test.rs Normal file
@@ -0,0 +1,320 @@
use std::{
    borrow::Cow,
    collections::{HashMap, HashSet},
    iter::repeat,
    sync::Arc,
};

use arrow::buffer::NullBuffer;
use arrow_array::{
    Array, FixedSizeListArray, Float32Array, Int32Array, RecordBatch, RecordBatchIterator,
    StringArray,
};
use arrow_schema::{DataType, Field, Schema};
use futures::StreamExt;
use lancedb::{
    arrow::IntoArrow,
    connect,
    embeddings::{EmbeddingDefinition, EmbeddingFunction, EmbeddingRegistry},
    query::ExecutableQuery,
    Error, Result,
};

#[tokio::test]
async fn test_custom_func() -> Result<()> {
    let tempdir = tempfile::tempdir().unwrap();
    let tempdir = tempdir.path().to_str().unwrap();
    let db = connect(tempdir).execute().await?;
    let embed_fun = MockEmbed::new("embed_fun".to_string(), 1);
    db.embedding_registry()
        .register("embed_fun", Arc::new(embed_fun.clone()))?;

    let tbl = db
        .create_table("test", create_some_records()?)
        .add_embedding(EmbeddingDefinition::new(
            "text",
            &embed_fun.name,
            Some("embeddings"),
        ))?
        .execute()
        .await?;
    let mut res = tbl.query().execute().await?;
    while let Some(Ok(batch)) = res.next().await {
        let embeddings = batch.column_by_name("embeddings");
        assert!(embeddings.is_some());
        let embeddings = embeddings.unwrap();
        assert_eq!(embeddings.data_type(), embed_fun.dest_type()?.as_ref());
    }
    // now make sure the embeddings are applied when
    // we add new records too
    tbl.add(create_some_records()?).execute().await?;
    let mut res = tbl.query().execute().await?;
    while let Some(Ok(batch)) = res.next().await {
        let embeddings = batch.column_by_name("embeddings");
        assert!(embeddings.is_some());
        let embeddings = embeddings.unwrap();
        assert_eq!(embeddings.data_type(), embed_fun.dest_type()?.as_ref());
    }
    Ok(())
}

#[tokio::test]
async fn test_custom_registry() -> Result<()> {
    let tempdir = tempfile::tempdir().unwrap();
    let tempdir = tempdir.path().to_str().unwrap();

    let db = connect(tempdir)
        .embedding_registry(Arc::new(MyRegistry::default()))
        .execute()
        .await?;

    let tbl = db
        .create_table("test", create_some_records()?)
        .add_embedding(EmbeddingDefinition::new(
            "text",
            "func_1",
            Some("embeddings"),
        ))?
        .execute()
        .await?;
    let mut res = tbl.query().execute().await?;
    while let Some(Ok(batch)) = res.next().await {
        let embeddings = batch.column_by_name("embeddings");
        assert!(embeddings.is_some());
        let embeddings = embeddings.unwrap();
        assert_eq!(
            embeddings.data_type(),
            MockEmbed::new("func_1".to_string(), 1)
                .dest_type()?
                .as_ref()
        );
    }
    Ok(())
}

#[tokio::test]
async fn test_multiple_embeddings() -> Result<()> {
    let tempdir = tempfile::tempdir().unwrap();
    let tempdir = tempdir.path().to_str().unwrap();

    let db = connect(tempdir).execute().await?;
    let func_1 = MockEmbed::new("func_1".to_string(), 1);
    let func_2 = MockEmbed::new("func_2".to_string(), 10);
    db.embedding_registry()
        .register(&func_1.name, Arc::new(func_1.clone()))?;
    db.embedding_registry()
        .register(&func_2.name, Arc::new(func_2.clone()))?;

    let tbl = db
        .create_table("test", create_some_records()?)
        .add_embedding(EmbeddingDefinition::new(
            "text",
            &func_1.name,
            Some("first_embeddings"),
        ))?
        .add_embedding(EmbeddingDefinition::new(
            "text",
            &func_2.name,
            Some("second_embeddings"),
        ))?
        .execute()
        .await?;
    let mut res = tbl.query().execute().await?;
    while let Some(Ok(batch)) = res.next().await {
        let embeddings = batch.column_by_name("first_embeddings");
        assert!(embeddings.is_some());
        let second_embeddings = batch.column_by_name("second_embeddings");
        assert!(second_embeddings.is_some());

        let embeddings = embeddings.unwrap();
        assert_eq!(embeddings.data_type(), func_1.dest_type()?.as_ref());

        let second_embeddings = second_embeddings.unwrap();
        assert_eq!(second_embeddings.data_type(), func_2.dest_type()?.as_ref());
    }

    // now make sure the embeddings are applied when
    // we add new records too
    tbl.add(create_some_records()?).execute().await?;
    let mut res = tbl.query().execute().await?;
    while let Some(Ok(batch)) = res.next().await {
        let embeddings = batch.column_by_name("first_embeddings");
        assert!(embeddings.is_some());
        let second_embeddings = batch.column_by_name("second_embeddings");
        assert!(second_embeddings.is_some());

        let embeddings = embeddings.unwrap();
        assert_eq!(embeddings.data_type(), func_1.dest_type()?.as_ref());

        let second_embeddings = second_embeddings.unwrap();
        assert_eq!(second_embeddings.data_type(), func_2.dest_type()?.as_ref());
    }
    Ok(())
}

#[tokio::test]
async fn test_no_func_in_registry() -> Result<()> {
    let tempdir = tempfile::tempdir().unwrap();
    let tempdir = tempdir.path().to_str().unwrap();

    let db = connect(tempdir).execute().await?;

    let res = db
        .create_table("test", create_some_records()?)
        .add_embedding(EmbeddingDefinition::new(
            "text",
            "some_func",
            Some("first_embeddings"),
        ));
    assert!(res.is_err());
    assert!(matches!(
        res.err().unwrap(),
        Error::EmbeddingFunctionNotFound { .. }
    ));

    Ok(())
}

#[tokio::test]
async fn test_no_func_in_registry_on_add() -> Result<()> {
    let tempdir = tempfile::tempdir().unwrap();
    let tempdir = tempdir.path().to_str().unwrap();

    let db = connect(tempdir).execute().await?;
    db.embedding_registry().register(
        "some_func",
        Arc::new(MockEmbed::new("some_func".to_string(), 1)),
    )?;

    db.create_table("test", create_some_records()?)
        .add_embedding(EmbeddingDefinition::new(
            "text",
            "some_func",
            Some("first_embeddings"),
        ))?
        .execute()
        .await?;

    let db = connect(tempdir).execute().await?;

    let tbl = db.open_table("test").execute().await?;
    // This should fail because 'tbl' is expecting "some_func" to be in the registry
    let res = tbl.add(create_some_records()?).execute().await;
    assert!(res.is_err());
    assert!(matches!(
        res.unwrap_err(),
        Error::EmbeddingFunctionNotFound { .. }
    ));

    Ok(())
}

fn create_some_records() -> Result<impl IntoArrow> {
    const TOTAL: usize = 2;

    let schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Int32, false),
        Field::new("text", DataType::Utf8, true),
    ]));

    // Create a RecordBatch stream.
    let batches = RecordBatchIterator::new(
        vec![RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from_iter_values(0..TOTAL as i32)),
                Arc::new(StringArray::from_iter(
                    repeat(Some("hello world".to_string())).take(TOTAL),
                )),
            ],
        )
        .unwrap()]
        .into_iter()
        .map(Ok),
        schema.clone(),
    );
    Ok(Box::new(batches))
}

#[derive(Debug)]
struct MyRegistry {
    functions: HashMap<String, Arc<dyn EmbeddingFunction>>,
}
impl Default for MyRegistry {
    fn default() -> Self {
        let funcs: Vec<Arc<dyn EmbeddingFunction>> = vec![
            Arc::new(MockEmbed::new("func_1".to_string(), 1)),
            Arc::new(MockEmbed::new("func_2".to_string(), 10)),
        ];
        Self {
            functions: funcs
                .into_iter()
                .map(|f| (f.name().to_string(), f))
                .collect(),
        }
    }
}

/// A read-only mock registry preloaded with `func_1` and `func_2`
impl EmbeddingRegistry for MyRegistry {
    fn functions(&self) -> HashSet<String> {
        self.functions.keys().cloned().collect()
    }

    fn register(&self, _name: &str, _function: Arc<dyn EmbeddingFunction>) -> Result<()> {
        Err(Error::Other {
            message: "MyRegistry is read-only".to_string(),
            source: None,
        })
    }

    fn get(&self, name: &str) -> Option<Arc<dyn EmbeddingFunction>> {
        self.functions.get(name).cloned()
    }
}

#[derive(Debug, Clone)]
struct MockEmbed {
    source_type: DataType,
    dest_type: DataType,
    name: String,
    dim: usize,
}

impl MockEmbed {
    pub fn new(name: String, dim: usize) -> Self {
        Self {
            source_type: DataType::Utf8,
            dest_type: DataType::new_fixed_size_list(DataType::Float32, dim as _, true),
            name,
            dim,
        }
    }
}

impl EmbeddingFunction for MockEmbed {
    fn name(&self) -> &str {
        &self.name
    }
    fn source_type(&self) -> Result<Cow<DataType>> {
        Ok(Cow::Borrowed(&self.source_type))
    }
    fn dest_type(&self) -> Result<Cow<DataType>> {
        Ok(Cow::Borrowed(&self.dest_type))
    }
    fn embed(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
        // We can't use the FixedSizeListBuilder here because it always adds a null bitmap
        // and we want to explicitly work with non-nullable arrays.
        let len = source.len();
        let inner = Arc::new(Float32Array::from(vec![Some(1.0); len * self.dim]));
        let field = Field::new("item", inner.data_type().clone(), false);
        let arr = FixedSizeListArray::new(
            Arc::new(field),
            self.dim as _,
            inner,
            Some(NullBuffer::new_valid(len)),
        );

        Ok(Arc::new(arr))
    }
}