Mirror of https://github.com/lancedb/lancedb.git (synced 2025-12-23 13:29:57 +00:00)

Compare commits: python-v0. ... python-v0.

22 Commits
Commit SHA1s:

- fe89a373a2
- 3d3915edef
- e2e8b6aee4
- 12dbca5248
- a6babfa651
- 75ede86fab
- becd649130
- 9d2fb7d602
- fdb5d6fdf1
- 2f13fa225f
- e933de003d
- 05fd387425
- 82a1da554c
- a7c0d80b9e
- 71323a064a
- df48454b70
- 6603414885
- c256f6c502
- cc03f90379
- 975da09b02
- c32e17b497
- 0528abdf97
@@ -1,5 +1,5 @@
 [bumpversion]
-current_version = 0.4.17
+current_version = 0.4.19
 commit = True
 message = Bump version: {current_version} → {new_version}
 tag = True

.gitignore (vendored, 2 changes)

@@ -6,7 +6,7 @@
 venv
 
 .vscode
+.zed
 rust/target
 rust/Cargo.lock
 

@@ -20,7 +20,7 @@
 
 <hr />
 
-LanceDB is an open-source database for vector-search built with persistent storage, which greatly simplifies retrevial, filtering and management of embeddings.
+LanceDB is an open-source database for vector-search built with persistent storage, which greatly simplifies retrieval, filtering and management of embeddings.
 
 The key features of LanceDB include:
 
@@ -36,7 +36,7 @@ The key features of LanceDB include:
 
 * GPU support in building vector index(*).
 
-* Ecosystem integrations with [LangChain 🦜️🔗](https://python.langchain.com/en/latest/modules/indexes/vectorstores/examples/lanecdb.html), [LlamaIndex 🦙](https://gpt-index.readthedocs.io/en/latest/examples/vector_stores/LanceDBIndexDemo.html), Apache-Arrow, Pandas, Polars, DuckDB and more on the way.
+* Ecosystem integrations with [LangChain 🦜️🔗](https://python.langchain.com/docs/integrations/vectorstores/lancedb/), [LlamaIndex 🦙](https://gpt-index.readthedocs.io/en/latest/examples/vector_stores/LanceDBIndexDemo.html), Apache-Arrow, Pandas, Polars, DuckDB and more on the way.
 
 LanceDB's core is written in Rust 🦀 and is built using <a href="https://github.com/lancedb/lance">Lance</a>, an open-source columnar format designed for performant ML workloads.
 
@@ -46,7 +46,7 @@ For this purpose, LanceDB introduces an **embedding functions API**, that allow
 
 ```python
 class Pets(LanceModel):
-    vector: Vector(clip.ndims) = clip.VectorField()
+    vector: Vector(clip.ndims()) = clip.VectorField()
     image_uri: str = clip.SourceField()
 ```
 
@@ -149,7 +149,7 @@ You can also use the integration for adding utility operations in the schema. Fo
 
 ```python
 class Pets(LanceModel):
-    vector: Vector(clip.ndims) = clip.VectorField()
+    vector: Vector(clip.ndims()) = clip.VectorField()
     image_uri: str = clip.SourceField()
 
     @property
@@ -299,6 +299,14 @@ LanceDB can also connect to S3-compatible stores, such as MinIO. To do so, you m
 
 This can also be done with the ``AWS_ENDPOINT`` and ``AWS_DEFAULT_REGION`` environment variables.
 
+!!! tip "Local servers"
+
+    For local development, the server often has a `http` endpoint rather than a
+    secure `https` endpoint. In this case, you must also set the `ALLOW_HTTP`
+    environment variable to `true` to allow non-TLS connections, or pass the
+    storage option `allow_http` as `true`. If you do not do this, you will get
+    an error like `URL scheme is not allowed`.
+
 #### S3 Express
 
 LanceDB supports [S3 Express One Zone](https://aws.amazon.com/s3/storage-classes/express-one-zone/) endpoints, but requires additional configuration. Also, S3 Express endpoints only support connecting from an EC2 instance within the same region.
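
A minimal sketch of the tip added above, connecting to a local MinIO server over plain HTTP via the environment variables the docs mention (`AWS_ENDPOINT`, `AWS_DEFAULT_REGION`, `ALLOW_HTTP`). The endpoint URL, bucket name, and credentials are placeholders, not values from this changeset:

```python
import os

import lancedb

# Point the S3 client at a local MinIO server (hypothetical endpoint/credentials).
os.environ["AWS_ENDPOINT"] = "http://localhost:9000"
os.environ["AWS_DEFAULT_REGION"] = "us-east-1"
os.environ["AWS_ACCESS_KEY_ID"] = "minioadmin"
os.environ["AWS_SECRET_ACCESS_KEY"] = "minioadmin"

# Without this, a plain-http endpoint fails with "URL scheme is not allowed".
os.environ["ALLOW_HTTP"] = "true"

db = lancedb.connect("s3://my-bucket/my-db")
print(db.table_names())
```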

@@ -36,7 +36,7 @@
 }
 ],
 "source": [
-"!pip install --quiet openai datasets \n",
+"!pip install --quiet openai datasets\n",
 "!pip install --quiet -U lancedb"
 ]
 },
@@ -213,7 +213,7 @@
 "if \"OPENAI_API_KEY\" not in os.environ:\n",
 "    # OR set the key here as a variable\n",
 "    os.environ[\"OPENAI_API_KEY\"] = \"sk-...\"\n",
-"    \n",
+"\n",
 "client = OpenAI()\n",
 "assert len(client.models.list().data) > 0"
 ]
@@ -234,9 +234,12 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"def embed_func(c): \n",
+"def embed_func(c):\n",
 "    rs = client.embeddings.create(input=c, model=\"text-embedding-ada-002\")\n",
-"    return [rs.data[0].embedding]"
+"    return [\n",
+"        data.embedding\n",
+"        for data in rs.data\n",
+"    ]"
 ]
 },
 {
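
The `embed_func` hunk above fixes a bug where only the first embedding was returned even when a batch of inputs was sent in one request. A standalone sketch of the corrected behavior (the client and model name follow the notebook; the sample sentences are made up):

```python
from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set, as in the notebook

def embed_func(c):
    # One request can embed a whole batch; return one vector per input.
    rs = client.embeddings.create(input=c, model="text-embedding-ada-002")
    return [data.embedding for data in rs.data]

vecs = embed_func(["a cat sat on a mat", "dogs chase cats"])
assert len(vecs) == 2  # previously this returned a single vector regardless of batch size
```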
@@ -514,7 +517,7 @@
|
|||||||
" prompt_start +\n",
|
" prompt_start +\n",
|
||||||
" \"\\n\\n---\\n\\n\".join(context.text) +\n",
|
" \"\\n\\n---\\n\\n\".join(context.text) +\n",
|
||||||
" prompt_end\n",
|
" prompt_end\n",
|
||||||
" ) \n",
|
" )\n",
|
||||||
" return prompt"
|
" return prompt"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|

node/package-lock.json (generated, 14 changes)

@@ -1,12 +1,12 @@
 {
   "name": "vectordb",
-  "version": "0.4.17",
+  "version": "0.4.19",
   "lockfileVersion": 3,
   "requires": true,
   "packages": {
     "": {
       "name": "vectordb",
-      "version": "0.4.17",
+      "version": "0.4.19",
       "cpu": [
         "x64",
         "arm64"
@@ -52,11 +52,11 @@
       "uuid": "^9.0.0"
       },
       "optionalDependencies": {
-        "@lancedb/vectordb-darwin-arm64": "0.4.17",
-        "@lancedb/vectordb-darwin-x64": "0.4.17",
-        "@lancedb/vectordb-linux-arm64-gnu": "0.4.17",
-        "@lancedb/vectordb-linux-x64-gnu": "0.4.17",
-        "@lancedb/vectordb-win32-x64-msvc": "0.4.17"
+        "@lancedb/vectordb-darwin-arm64": "0.4.19",
+        "@lancedb/vectordb-darwin-x64": "0.4.19",
+        "@lancedb/vectordb-linux-arm64-gnu": "0.4.19",
+        "@lancedb/vectordb-linux-x64-gnu": "0.4.19",
+        "@lancedb/vectordb-win32-x64-msvc": "0.4.19"
       },
       "peerDependencies": {
         "@apache-arrow/ts": "^14.0.2",

@@ -1,6 +1,6 @@
 {
   "name": "vectordb",
-  "version": "0.4.17",
+  "version": "0.4.19",
   "description": " Serverless, low-latency vector database for AI applications",
   "main": "dist/index.js",
   "types": "dist/index.d.ts",
@@ -88,10 +88,10 @@
     }
   },
   "optionalDependencies": {
-    "@lancedb/vectordb-darwin-arm64": "0.4.17",
-    "@lancedb/vectordb-darwin-x64": "0.4.17",
-    "@lancedb/vectordb-linux-arm64-gnu": "0.4.17",
-    "@lancedb/vectordb-linux-x64-gnu": "0.4.17",
-    "@lancedb/vectordb-win32-x64-msvc": "0.4.17"
+    "@lancedb/vectordb-darwin-arm64": "0.4.19",
+    "@lancedb/vectordb-darwin-x64": "0.4.19",
+    "@lancedb/vectordb-linux-arm64-gnu": "0.4.19",
+    "@lancedb/vectordb-linux-x64-gnu": "0.4.19",
+    "@lancedb/vectordb-win32-x64-msvc": "0.4.19"
   }
 }

@@ -51,7 +51,7 @@ describe('LanceDB Mirrored Store Integration test', function () {
 
   const dir = tmpdir()
   console.log(dir)
-  const conn = await lancedb.connect(`s3://lancedb-integtest?mirroredStore=${dir}`)
+  const conn = await lancedb.connect({ uri: `s3://lancedb-integtest?mirroredStore=${dir}`, storageOptions: { allowHttp: 'true' } })
   const data = Array(200).fill({ vector: Array(128).fill(1.0), id: 0 })
   data.push(...Array(200).fill({ vector: Array(128).fill(1.0), id: 1 }))
   data.push(...Array(200).fill({ vector: Array(128).fill(1.0), id: 2 }))

nodejs/.gitignore (vendored, new file, 1 line)

@@ -0,0 +1 @@
+yarn.lock
@@ -169,17 +169,20 @@ export class Table {
  * // If the column has a vector (fixed size list) data type then
  * // an IvfPq vector index will be created.
  * const table = await conn.openTable("my_table");
- * await table.createIndex(["vector"]);
+ * await table.createIndex("vector");
  * @example
  * // For advanced control over vector index creation you can specify
  * // the index type and options.
  * const table = await conn.openTable("my_table");
- * await table.createIndex(["vector"], I)
- *   .ivf_pq({ num_partitions: 128, num_sub_vectors: 16 })
- *   .build();
+ * await table.createIndex("vector", {
+ *   config: lancedb.Index.ivfPq({
+ *     numPartitions: 128,
+ *     numSubVectors: 16,
+ *   }),
+ * });
  * @example
  * // Or create a Scalar index
- * await table.createIndex("my_float_col").build();
+ * await table.createIndex("my_float_col");
  */
 async createIndex(column: string, options?: Partial<IndexOptions>) {
   // Bit of a hack to get around the fact that TS has no package-scope.
@@ -197,8 +200,7 @@ export class Table {
  * vector similarity, sorting, and more.
  *
  * Note: By default, all columns are returned. For best performance, you should
- * only fetch the columns you need. See [`Query::select_with_projection`] for
- * more details.
+ * only fetch the columns you need.
  *
  * When appropriate, various indices and statistics based pruning will be used to
  * accelerate the query.
@@ -207,8 +209,11 @@ export class Table {
  * //
  * // This query will return up to 1000 rows whose value in the `id` column
  * // is greater than 5. LanceDb supports a broad set of filtering functions.
- * for await (const batch of table.query()
- *   .filter("id > 1").select(["id"]).limit(20)) {
+ * for await (const batch of table
+ *   .query()
+ *   .where("id > 1")
+ *   .select(["id"])
+ *   .limit(20)) {
  *   console.log(batch);
  * }
 * @example
@@ -218,12 +223,13 @@ export class Table {
  * // closest to the query vector [1.0, 2.0, 3.0]. If an index has been created
  * // on the "vector" column then this will perform an ANN search.
  * //
- * // The `refine_factor` and `nprobes` methods are used to control the recall /
+ * // The `refineFactor` and `nprobes` methods are used to control the recall /
  * // latency tradeoff of the search.
- * for await (const batch of table.query()
- *   .nearestTo([1, 2, 3])
- *   .refineFactor(5).nprobe(10)
- *   .limit(10)) {
+ * for await (const batch of table
+ *   .query()
+ *   .where("id > 1")
+ *   .select(["id"])
+ *   .limit(20)) {
  *   console.log(batch);
  * }
 * @example
@@ -286,43 +292,45 @@ export class Table {
   await this.inner.dropColumns(columnNames);
 }
 
-/**
- * Retrieve the version of the table
- *
- * LanceDb supports versioning. Every operation that modifies the table increases
- * version. As long as a version hasn't been deleted you can `[Self::checkout]` that
- * version to view the data at that point. In addition, you can `[Self::restore]` the
- * version to replace the current table with a previous version.
- */
+/** Retrieve the version of the table */
 async version(): Promise<number> {
   return await this.inner.version();
 }
 
 /**
- * Checks out a specific version of the Table
+ * Checks out a specific version of the table _This is an in-place operation._
  *
- * Any read operation on the table will now access the data at the checked out version.
- * As a consequence, calling this method will disable any read consistency interval
- * that was previously set.
+ * This allows viewing previous versions of the table. If you wish to
+ * keep writing to the dataset starting from an old version, then use
+ * the `restore` function.
  *
- * This is a read-only operation that turns the table into a sort of "view"
- * or "detached head". Other table instances will not be affected. To make the change
- * permanent you can use the `[Self::restore]` method.
+ * Calling this method will set the table into time-travel mode. If you
+ * wish to return to standard mode, call `checkoutLatest`.
+ * @param {number} version The version to checkout
+ * @example
+ * ```typescript
+ * import * as lancedb from "@lancedb/lancedb"
+ * const db = await lancedb.connect("./.lancedb");
+ * const table = await db.createTable("my_table", [
+ *   { vector: [1.1, 0.9], type: "vector" },
+ * ]);
  *
- * Any operation that modifies the table will fail while the table is in a checked
- * out state.
- *
- * To return the table to a normal state use `[Self::checkout_latest]`
+ * console.log(await table.version()); // 1
+ * console.log(table.display());
+ * await table.add([{ vector: [0.5, 0.2], type: "vector" }]);
+ * await table.checkout(1);
+ * console.log(await table.version()); // 2
+ * ```
  */
 async checkout(version: number): Promise<void> {
   await this.inner.checkout(version);
 }
 
 /**
- * Ensures the table is pointing at the latest version
+ * Checkout the latest version of the table. _This is an in-place operation._
  *
- * This can be used to manually update a table when the read_consistency_interval is None
- * It can also be used to undo a `[Self::checkout]` operation
+ * The table will be set back into standard mode, and will track the latest
+ * version of the table.
  */
 async checkoutLatest(): Promise<void> {
   await this.inner.checkoutLatest();
@@ -344,9 +352,7 @@ export class Table {
   await this.inner.restore();
 }
 
-/**
- * List all indices that have been created with Self::create_index
- */
+/** List all indices that have been created with {@link Table.createIndex} */
 async listIndices(): Promise<IndexConfig[]> {
   return await this.inner.listIndices();
 }

@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-darwin-arm64",
-  "version": "0.4.17",
+  "version": "0.4.19",
   "os": [
     "darwin"
   ],

@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-darwin-x64",
-  "version": "0.4.17",
+  "version": "0.4.19",
   "os": [
     "darwin"
   ],

@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-linux-arm64-gnu",
-  "version": "0.4.17",
+  "version": "0.4.19",
   "os": [
     "linux"
   ],

@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb-linux-x64-gnu",
-  "version": "0.4.17",
+  "version": "0.4.19",
   "os": [
     "linux"
   ],

nodejs/package-lock.json (generated, 86 changes)

@@ -1,12 +1,12 @@
 {
   "name": "@lancedb/lancedb",
-  "version": "0.4.16",
+  "version": "0.4.18",
   "lockfileVersion": 3,
   "requires": true,
   "packages": {
     "": {
       "name": "@lancedb/lancedb",
-      "version": "0.4.16",
+      "version": "0.4.18",
       "cpu": [
         "x64",
         "arm64"
@@ -45,13 +45,6 @@
       },
       "engines": {
         "node": ">= 18"
-      },
-      "optionalDependencies": {
-        "@lancedb/lancedb-darwin-arm64": "0.4.16",
-        "@lancedb/lancedb-darwin-x64": "0.4.16",
-        "@lancedb/lancedb-linux-arm64-gnu": "0.4.16",
-        "@lancedb/lancedb-linux-x64-gnu": "0.4.16",
-        "@lancedb/lancedb-win32-x64-msvc": "0.4.16"
       }
     },
     "node_modules/@75lb/deep-merge": {
@@ -2221,81 +2214,6 @@
         "@jridgewell/sourcemap-codec": "^1.4.14"
       }
     },
-    "node_modules/@lancedb/lancedb-darwin-arm64": {
-      "version": "0.4.16",
-      "resolved": "https://registry.npmjs.org/@lancedb/lancedb-darwin-arm64/-/lancedb-darwin-arm64-0.4.16.tgz",
-      "integrity": "sha512-CV65ouIDQbBSNtdHbQSr2fqXflOuqud1cfweUS+EiK7eEOEYl7nO2oiFYO49Jy76MEwZxiP99hW825aCqIQJqg==",
-      "cpu": [
-        "arm64"
-      ],
-      "optional": true,
-      "os": [
-        "darwin"
-      ],
-      "engines": {
-        "node": ">= 18"
-      }
-    },
-    "node_modules/@lancedb/lancedb-darwin-x64": {
-      "version": "0.4.16",
-      "resolved": "https://registry.npmjs.org/@lancedb/lancedb-darwin-x64/-/lancedb-darwin-x64-0.4.16.tgz",
-      "integrity": "sha512-1CwIYCNdbFmV7fvqM+qUxbYgwxx0slcCV48PC/I19Ejitgtzw/NJiWDCvONhaLqG85lWNZm1xYceRpVv7b8seQ==",
-      "cpu": [
-        "x64"
-      ],
-      "optional": true,
-      "os": [
-        "darwin"
-      ],
-      "engines": {
-        "node": ">= 18"
-      }
-    },
-    "node_modules/@lancedb/lancedb-linux-arm64-gnu": {
-      "version": "0.4.16",
-      "resolved": "https://registry.npmjs.org/@lancedb/lancedb-linux-arm64-gnu/-/lancedb-linux-arm64-gnu-0.4.16.tgz",
-      "integrity": "sha512-CzLEbzoHKS6jV0k52YnvsiVNx0VzLp1Vz/zmbHI6HmB/XbS67qDO93Jk71MDmXq3JDw0FKFCw9ghkg+6YWq7ZA==",
-      "cpu": [
-        "arm64"
-      ],
-      "optional": true,
-      "os": [
-        "linux"
-      ],
-      "engines": {
-        "node": ">= 18"
-      }
-    },
-    "node_modules/@lancedb/lancedb-linux-x64-gnu": {
-      "version": "0.4.16",
-      "resolved": "https://registry.npmjs.org/@lancedb/lancedb-linux-x64-gnu/-/lancedb-linux-x64-gnu-0.4.16.tgz",
-      "integrity": "sha512-nKChybybi8uA0AFRHBFm7Fz3VXcRm8riv5Gs7xQsrsCtYxxf4DT/0BfUvQ0xKbwNJa+fawHRxi9BOQewdj49fg==",
-      "cpu": [
-        "x64"
-      ],
-      "optional": true,
-      "os": [
-        "linux"
-      ],
-      "engines": {
-        "node": ">= 18"
-      }
-    },
-    "node_modules/@lancedb/lancedb-win32-x64-msvc": {
-      "version": "0.4.16",
-      "resolved": "https://registry.npmjs.org/@lancedb/lancedb-win32-x64-msvc/-/lancedb-win32-x64-msvc-0.4.16.tgz",
-      "integrity": "sha512-KMeBPMpv2g+ZMVsHVibed7BydrBlxje1qS0bZTDrLw9BtZOk6XH2lh1mCDnCJI6sbAscUKNA6fDCdquhQPHL7w==",
-      "cpu": [
-        "x64"
-      ],
-      "optional": true,
-      "os": [
-        "win32"
-      ],
-      "engines": {
-        "node": ">= 18"
-      }
-    },
     "node_modules/@napi-rs/cli": {
       "version": "2.18.0",
       "resolved": "https://registry.npmjs.org/@napi-rs/cli/-/cli-2.18.0.tgz",

@@ -1,6 +1,6 @@
 {
   "name": "@lancedb/lancedb",
-  "version": "0.4.17",
+  "version": "0.4.19",
   "main": "./dist/index.js",
   "types": "./dist/index.d.ts",
   "napi": {
@@ -62,20 +62,14 @@
     "build-release": "npm run build:release && tsc -b && shx cp lancedb/native.d.ts dist/native.d.ts",
     "chkformat": "prettier . --check",
     "docs": "typedoc --plugin typedoc-plugin-markdown --out ../docs/src/js lancedb/index.ts",
-    "lint": "eslint lancedb && eslint __test__",
+    "lint": "eslint lancedb __test__",
+    "lint-fix": "eslint lancedb __test__ --fix",
     "prepublishOnly": "napi prepublish -t npm",
     "test": "npm run build && jest --verbose",
     "integration": "S3_TEST=1 npm run test",
     "universal": "napi universal",
     "version": "napi version"
   },
-  "optionalDependencies": {
-    "@lancedb/lancedb-darwin-arm64": "0.4.17",
-    "@lancedb/lancedb-darwin-x64": "0.4.17",
-    "@lancedb/lancedb-linux-arm64-gnu": "0.4.17",
-    "@lancedb/lancedb-linux-x64-gnu": "0.4.17",
-    "@lancedb/lancedb-win32-x64-msvc": "0.4.17"
-  },
   "dependencies": {
     "openai": "^4.29.2",
     "apache-arrow": "^15.0.0"

@@ -1,5 +1,5 @@
 [bumpversion]
-current_version = 0.6.11
+current_version = 0.6.12
 commit = True
 message = [python] Bump version: {current_version} → {new_version}
 tag = True

@@ -1,6 +1,6 @@
 [project]
 name = "lancedb"
-version = "0.6.11"
+version = "0.6.12"
 dependencies = [
     "deprecation",
     "pylance==0.10.12",

@@ -107,6 +107,9 @@ def connect(
             request_thread_pool=request_thread_pool,
             **kwargs,
         )
 
+    if kwargs:
+        raise ValueError(f"Unknown keyword arguments: {kwargs}")
     return LanceDBConnection(uri, read_consistency_interval=read_consistency_interval)
 
 
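
With this change, leftover keyword arguments that the remote-connection path did not consume make `connect` fail fast instead of being silently ignored. A sketch with a hypothetical misspelled option (the typo is deliberate, not a real parameter):

```python
import lancedb

try:
    db = lancedb.connect("data/sample-db", read_consistency_intervall=5)
except ValueError as e:
    print(e)  # Unknown keyword arguments: {'read_consistency_intervall': 5}
```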

@@ -255,7 +255,13 @@ def retry_with_exponential_backoff(
                 )
 
             delay *= exponential_base * (1 + jitter * random.random())
-            logging.info("Retrying in %s seconds...", delay)
+            logging.warning(
+                "Error occurred: %s \n Retrying in %s seconds (retry %s of %s) \n",
+                e,
+                delay,
+                num_retries,
+                max_retries,
+            )
             time.sleep(delay)
 
     return wrapper
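
The delay update above is a standard exponential-backoff-with-jitter rule: each retry multiplies the delay by `exponential_base`, scaled by a random factor in `[1, 1 + jitter]` so that concurrent clients do not retry in lockstep. A self-contained sketch of the loop shape, with made-up parameter values (the real decorator in this changeset also controls which exception types it retries):

```python
import random
import time

def flaky_call():
    # Hypothetical operation that sometimes fails transiently.
    if random.random() < 0.7:
        raise ConnectionError("transient failure")
    return "ok"

def call_with_backoff(max_retries=5, initial_delay=1.0, exponential_base=2.0, jitter=0.5):
    delay = initial_delay
    for attempt in range(1, max_retries + 1):
        try:
            return flaky_call()
        except ConnectionError:
            if attempt == max_retries:
                raise  # give up after the last attempt
            # Same update rule as the diff: exponential growth plus random jitter.
            delay *= exponential_base * (1 + jitter * random.random())
            time.sleep(delay)

print(call_with_backoff())
```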

@@ -30,6 +30,7 @@ from typing import (
 import deprecation
 import numpy as np
 import pyarrow as pa
+import pyarrow.fs as pa_fs
 import pydantic
 
 from . import __version__
@@ -37,7 +38,7 @@ from .arrow import AsyncRecordBatchReader
 from .common import VEC
 from .rerankers.base import Reranker
 from .rerankers.linear_combination import LinearCombinationReranker
-from .util import safe_import_pandas
+from .util import fs_from_uri, safe_import_pandas
 
 if TYPE_CHECKING:
     import PIL
@@ -665,6 +666,14 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
 
         # get the index path
         index_path = self._table._get_fts_index_path()
+
+        # Check that we are on local filesystem
+        fs, _path = fs_from_uri(index_path)
+        if not isinstance(fs, pa_fs.LocalFileSystem):
+            raise NotImplementedError(
+                "Full-text search is only supported on the local filesystem"
+            )
+
         # check if the index exist
         if not Path(index_path).exists():
             raise FileNotFoundError(
@@ -1209,6 +1209,11 @@ class LanceTable(Table):
                 raise ValueError("Index already exists. Use replace=True to overwrite.")
             fs.delete_dir(path)
 
+        if not isinstance(fs, pa_fs.LocalFileSystem):
+            raise NotImplementedError(
+                "Full-text search is only supported on the local filesystem"
+            )
+
         index = create_index(
             self._get_fts_index_path(),
             field_names,
@@ -213,7 +213,7 @@ def test_syntax(table):
     # https://github.com/lancedb/lancedb/issues/769
     table.create_fts_index("text")
     with pytest.raises(ValueError, match="Syntax Error"):
-        table.search("they could have been dogs OR cats").limit(10).to_list()
+        table.search("they could have been dogs OR").limit(10).to_list()
 
     # these should work
 
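
Taken together, the table.py hunks above make full-text search fail with an explicit NotImplementedError when the table lives on object storage, since the FTS index is only built and read on the local filesystem. A minimal sketch of the happy path on local storage (table name and rows are made up, and this assumes the optional FTS dependency is installed):

```python
import lancedb

db = lancedb.connect("data/fts-demo")  # local path, so FTS is allowed
table = db.create_table(
    "docs",
    [{"text": "cats are great"}, {"text": "dogs are loyal"}],
)

table.create_fts_index("text")
# A trailing OR (as in the test above) would raise a Syntax Error; this is valid:
hits = table.search("cats OR dogs").limit(10).to_list()
print(hits)
```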

@@ -35,21 +35,16 @@ impl<T> PythonErrorExt<T> for std::result::Result<T, LanceError> {
         match &self {
             Ok(_) => Ok(self.unwrap()),
             Err(err) => match err {
-                LanceError::InvalidInput { .. } => self.value_error(),
-                LanceError::InvalidTableName { .. } => self.value_error(),
-                LanceError::TableNotFound { .. } => self.value_error(),
-                LanceError::Schema { .. } => self.value_error(),
+                LanceError::InvalidInput { .. }
+                | LanceError::InvalidTableName { .. }
+                | LanceError::TableNotFound { .. }
+                | LanceError::Schema { .. } => self.value_error(),
                 LanceError::CreateDir { .. } => self.os_error(),
-                LanceError::TableAlreadyExists { .. } => self.runtime_error(),
                 LanceError::ObjectStore { .. } => Err(PyIOError::new_err(err.to_string())),
-                LanceError::Lance { .. } => self.runtime_error(),
-                LanceError::Runtime { .. } => self.runtime_error(),
-                LanceError::Http { .. } => self.runtime_error(),
-                LanceError::Arrow { .. } => self.runtime_error(),
                 LanceError::NotSupported { .. } => {
                     Err(PyNotImplementedError::new_err(err.to_string()))
                 }
-                LanceError::Other { .. } => self.runtime_error(),
+                _ => self.runtime_error(),
             },
         }
     }
 }
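
This consolidation keeps the Python exceptions callers see unchanged: it only merges match arms that already produced the same exception, and routes everything unlisted to a runtime error. Restated from the match arms above as a plain mapping (the dict name is illustrative, not part of the codebase):

```python
# Assumed Rust-error to Python-exception mapping, per the match arms above.
RUST_TO_PYTHON = {
    "InvalidInput": ValueError,
    "InvalidTableName": ValueError,
    "TableNotFound": ValueError,
    "Schema": ValueError,
    "CreateDir": OSError,
    "ObjectStore": IOError,
    "NotSupported": NotImplementedError,
    # Everything else (TableAlreadyExists, Lance, Runtime, Http, Arrow, Other, ...)
    "_": RuntimeError,
}
```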

@@ -1,6 +1,6 @@
 [package]
 name = "lancedb-node"
-version = "0.4.17"
+version = "0.4.19"
 description = "Serverless, low-latency vector database for AI applications"
 license.workspace = true
 edition.workspace = true

@@ -59,7 +59,7 @@ fn database_new(mut cx: FunctionContext) -> JsResult<JsPromise> {
     for handle in storage_options_js {
         let obj = handle.downcast::<JsArray, _>(&mut cx).unwrap();
         let key = obj.get::<JsString, _, _>(&mut cx, 0)?.value(&mut cx);
-        let value = obj.get::<JsString, _, _>(&mut cx, 0)?.value(&mut cx);
+        let value = obj.get::<JsString, _, _>(&mut cx, 1)?.value(&mut cx);
 
         storage_options.push((key, value));
     }

@@ -1,6 +1,6 @@
 [package]
 name = "lancedb"
-version = "0.4.17"
+version = "0.4.19"
 edition.workspace = true
 description = "LanceDB: A serverless, low-latency vector database for AI applications"
 license.workspace = true
@@ -40,6 +40,8 @@ serde = { version = "^1" }
 serde_json = { version = "1" }
 # For remote feature
 reqwest = { version = "0.11.24", features = ["gzip", "json"], optional = true }
+polars-arrow = { version = ">=0.37", optional = true }
+polars = { version = ">=0.37", optional = true }
 
 [dev-dependencies]
 tempfile = "3.5.0"
@@ -56,3 +58,4 @@ default = []
 remote = ["dep:reqwest"]
 fp16kernels = ["lance-linalg/fp16kernels"]
 s3-test = []
+polars = ["dep:polars-arrow", "dep:polars"]

@@ -14,10 +14,12 @@
 
 use std::{pin::Pin, sync::Arc};
 
-pub use arrow_array;
 pub use arrow_schema;
 use futures::{Stream, StreamExt};
 
+#[cfg(feature = "polars")]
+use {crate::polars_arrow_convertors, polars::frame::ArrowChunk, polars::prelude::DataFrame};
+
 use crate::error::Result;
 
 /// An iterator of batches that also has a schema
@@ -114,8 +116,183 @@ pub trait IntoArrow {
     fn into_arrow(self) -> Result<Box<dyn arrow_array::RecordBatchReader + Send>>;
 }
 
+pub type BoxedRecordBatchReader = Box<dyn arrow_array::RecordBatchReader + Send>;
+
 impl<T: arrow_array::RecordBatchReader + Send + 'static> IntoArrow for T {
     fn into_arrow(self) -> Result<Box<dyn arrow_array::RecordBatchReader + Send>> {
         Ok(Box::new(self))
     }
 }
 
+impl<S: Stream<Item = Result<arrow_array::RecordBatch>>> SimpleRecordBatchStream<S> {
+    pub fn new(stream: S, schema: Arc<arrow_schema::Schema>) -> Self {
+        Self { schema, stream }
+    }
+}
+
+#[cfg(feature = "polars")]
+/// An iterator of record batches formed from a Polars DataFrame.
+pub struct PolarsDataFrameRecordBatchReader {
+    chunks: std::vec::IntoIter<ArrowChunk>,
+    arrow_schema: Arc<arrow_schema::Schema>,
+}
+
+#[cfg(feature = "polars")]
+impl PolarsDataFrameRecordBatchReader {
+    /// Creates a new `PolarsDataFrameRecordBatchReader` from a given Polars DataFrame.
+    /// If the input dataframe does not have aligned chunks, this function undergoes
+    /// the costly operation of reallocating each series as a single contigous chunk.
+    pub fn new(mut df: DataFrame) -> Result<Self> {
+        df.align_chunks();
+        let arrow_schema =
+            polars_arrow_convertors::convert_polars_df_schema_to_arrow_rb_schema(df.schema())?;
+        Ok(Self {
+            chunks: df
+                .iter_chunks(polars_arrow_convertors::POLARS_ARROW_FLAVOR)
+                .collect::<Vec<ArrowChunk>>()
+                .into_iter(),
+            arrow_schema,
+        })
+    }
+}
+
+#[cfg(feature = "polars")]
+impl Iterator for PolarsDataFrameRecordBatchReader {
+    type Item = std::result::Result<arrow_array::RecordBatch, arrow_schema::ArrowError>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        self.chunks.next().map(|chunk| {
+            let columns: std::result::Result<Vec<arrow_array::ArrayRef>, arrow_schema::ArrowError> =
+                chunk
+                    .into_arrays()
+                    .into_iter()
+                    .zip(self.arrow_schema.fields.iter())
+                    .map(|(polars_array, arrow_field)| {
+                        polars_arrow_convertors::convert_polars_arrow_array_to_arrow_rs_array(
+                            polars_array,
+                            arrow_field.data_type().clone(),
+                        )
+                    })
+                    .collect();
+            arrow_array::RecordBatch::try_new(self.arrow_schema.clone(), columns?)
+        })
+    }
+}
+
+#[cfg(feature = "polars")]
+impl arrow_array::RecordBatchReader for PolarsDataFrameRecordBatchReader {
+    fn schema(&self) -> Arc<arrow_schema::Schema> {
+        self.arrow_schema.clone()
+    }
+}
+
+/// A trait for converting the result of a LanceDB query into a Polars DataFrame with aligned
+/// chunks. The resulting Polars DataFrame will have aligned chunks, but the series's
+/// chunks are not guaranteed to be contiguous.
+#[cfg(feature = "polars")]
+pub trait IntoPolars {
+    fn into_polars(self) -> impl std::future::Future<Output = Result<DataFrame>> + Send;
+}
+
+#[cfg(feature = "polars")]
+impl IntoPolars for SendableRecordBatchStream {
+    async fn into_polars(mut self) -> Result<DataFrame> {
+        let polars_schema =
+            polars_arrow_convertors::convert_arrow_rb_schema_to_polars_df_schema(&self.schema())?;
+        let mut acc_df: DataFrame = DataFrame::from(&polars_schema);
+        while let Some(record_batch) = self.next().await {
+            let new_df = polars_arrow_convertors::convert_arrow_rb_to_polars_df(
+                &record_batch?,
+                &polars_schema,
+            )?;
+            acc_df = acc_df.vstack(&new_df)?;
+        }
+        Ok(acc_df)
+    }
+}
+
+#[cfg(all(test, feature = "polars"))]
+mod tests {
+    use super::SendableRecordBatchStream;
+    use crate::arrow::{
+        IntoArrow, IntoPolars, PolarsDataFrameRecordBatchReader, SimpleRecordBatchStream,
+    };
+    use polars::prelude::{DataFrame, NamedFrom, Series};
+
+    fn get_record_batch_reader_from_polars() -> Box<dyn arrow_array::RecordBatchReader + Send> {
+        let mut string_series = Series::new("string", &["ab"]);
+        let mut int_series = Series::new("int", &[1]);
+        let mut float_series = Series::new("float", &[1.0]);
+        let df1 = DataFrame::new(vec![string_series, int_series, float_series]).unwrap();
+
+        string_series = Series::new("string", &["bc"]);
+        int_series = Series::new("int", &[2]);
+        float_series = Series::new("float", &[2.0]);
+        let df2 = DataFrame::new(vec![string_series, int_series, float_series]).unwrap();
+
+        PolarsDataFrameRecordBatchReader::new(df1.vstack(&df2).unwrap())
+            .unwrap()
+            .into_arrow()
+            .unwrap()
+    }
+
+    #[test]
+    fn from_polars_to_arrow() {
+        let record_batch_reader = get_record_batch_reader_from_polars();
+        let schema = record_batch_reader.schema();
+
+        // Test schema conversion
+        assert_eq!(
+            schema
+                .fields
+                .iter()
+                .map(|field| (field.name().as_str(), field.data_type()))
+                .collect::<Vec<_>>(),
+            vec![
+                ("string", &arrow_schema::DataType::LargeUtf8),
+                ("int", &arrow_schema::DataType::Int32),
+                ("float", &arrow_schema::DataType::Float64)
+            ]
+        );
+        let record_batches: Vec<arrow_array::RecordBatch> =
+            record_batch_reader.map(|result| result.unwrap()).collect();
+        assert_eq!(record_batches.len(), 2);
+        assert_eq!(schema, record_batches[0].schema());
+        assert_eq!(record_batches[0].schema(), record_batches[1].schema());
+
+        // Test number of rows
+        assert_eq!(record_batches[0].num_rows(), 1);
+        assert_eq!(record_batches[1].num_rows(), 1);
+    }
+
+    #[tokio::test]
+    async fn from_arrow_to_polars() {
+        let record_batch_reader = get_record_batch_reader_from_polars();
+        let schema = record_batch_reader.schema();
+        let stream: SendableRecordBatchStream = Box::pin(SimpleRecordBatchStream {
+            schema: schema.clone(),
+            stream: futures::stream::iter(
+                record_batch_reader
+                    .into_iter()
+                    .map(|r| r.map_err(Into::into)),
+            ),
+        });
+        let df = stream.into_polars().await.unwrap();
+
+        // Test number of chunks and rows
+        assert_eq!(df.n_chunks(), 2);
+        assert_eq!(df.height(), 2);
+
+        // Test schema conversion
+        assert_eq!(
+            df.schema()
+                .into_iter()
+                .map(|(name, datatype)| (name.to_string(), datatype))
+                .collect::<Vec<_>>(),
+            vec![
+                ("string".to_string(), polars::prelude::DataType::String),
+                ("int".to_owned(), polars::prelude::DataType::Int32),
+                ("float".to_owned(), polars::prelude::DataType::Float64)
+            ]
+        );
+    }
+}
|
|||||||
@@ -27,9 +27,12 @@ use object_store::{aws::AwsCredential, local::LocalFileSystem};
|
|||||||
use snafu::prelude::*;
|
use snafu::prelude::*;
|
||||||
|
|
||||||
use crate::arrow::IntoArrow;
|
use crate::arrow::IntoArrow;
|
||||||
|
use crate::embeddings::{
|
||||||
|
EmbeddingDefinition, EmbeddingFunction, EmbeddingRegistry, MemoryRegistry, WithEmbeddings,
|
||||||
|
};
|
||||||
use crate::error::{CreateDirSnafu, Error, InvalidTableNameSnafu, Result};
|
use crate::error::{CreateDirSnafu, Error, InvalidTableNameSnafu, Result};
|
||||||
use crate::io::object_store::MirroringObjectStoreWrapper;
|
use crate::io::object_store::MirroringObjectStoreWrapper;
|
||||||
use crate::table::{NativeTable, WriteOptions};
|
use crate::table::{NativeTable, TableDefinition, WriteOptions};
|
||||||
use crate::utils::validate_table_name;
|
use crate::utils::validate_table_name;
|
||||||
use crate::Table;
|
use crate::Table;
|
||||||
|
|
||||||
@@ -133,9 +136,10 @@ pub struct CreateTableBuilder<const HAS_DATA: bool, T: IntoArrow> {
|
|||||||
parent: Arc<dyn ConnectionInternal>,
|
parent: Arc<dyn ConnectionInternal>,
|
||||||
pub(crate) name: String,
|
pub(crate) name: String,
|
||||||
pub(crate) data: Option<T>,
|
pub(crate) data: Option<T>,
|
||||||
pub(crate) schema: Option<SchemaRef>,
|
|
||||||
pub(crate) mode: CreateTableMode,
|
pub(crate) mode: CreateTableMode,
|
||||||
pub(crate) write_options: WriteOptions,
|
pub(crate) write_options: WriteOptions,
|
||||||
|
pub(crate) table_definition: Option<TableDefinition>,
|
||||||
|
pub(crate) embeddings: Vec<(EmbeddingDefinition, Arc<dyn EmbeddingFunction>)>,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Builder methods that only apply when we have initial data
|
// Builder methods that only apply when we have initial data
|
||||||
@@ -145,9 +149,10 @@ impl<T: IntoArrow> CreateTableBuilder<true, T> {
|
|||||||
parent,
|
parent,
|
||||||
name,
|
name,
|
||||||
data: Some(data),
|
data: Some(data),
|
||||||
schema: None,
|
|
||||||
mode: CreateTableMode::default(),
|
mode: CreateTableMode::default(),
|
||||||
write_options: WriteOptions::default(),
|
write_options: WriteOptions::default(),
|
||||||
|
table_definition: None,
|
||||||
|
embeddings: Vec::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -175,24 +180,43 @@ impl<T: IntoArrow> CreateTableBuilder<true, T> {
|
|||||||
parent: self.parent,
|
parent: self.parent,
|
||||||
name: self.name,
|
name: self.name,
|
||||||
data: None,
|
data: None,
|
||||||
schema: self.schema,
|
table_definition: self.table_definition,
|
||||||
mode: self.mode,
|
mode: self.mode,
|
||||||
write_options: self.write_options,
|
write_options: self.write_options,
|
||||||
|
embeddings: self.embeddings,
|
||||||
};
|
};
|
||||||
Ok((data, builder))
|
Ok((data, builder))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn add_embedding(mut self, definition: EmbeddingDefinition) -> Result<Self> {
|
||||||
|
// Early verification of the embedding name
|
||||||
|
let embedding_func = self
|
||||||
|
.parent
|
||||||
|
.embedding_registry()
|
||||||
|
.get(&definition.embedding_name)
|
||||||
|
.ok_or_else(|| Error::EmbeddingFunctionNotFound {
|
||||||
|
name: definition.embedding_name.to_string(),
|
||||||
|
reason: "No embedding function found in the connection's embedding_registry"
|
||||||
|
.to_string(),
|
||||||
|
})?;
|
||||||
|
|
||||||
|
self.embeddings.push((definition, embedding_func));
|
||||||
|
Ok(self)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Builder methods that only apply when we do not have initial data
|
// Builder methods that only apply when we do not have initial data
|
||||||
impl CreateTableBuilder<false, NoData> {
|
impl CreateTableBuilder<false, NoData> {
|
||||||
fn new(parent: Arc<dyn ConnectionInternal>, name: String, schema: SchemaRef) -> Self {
|
fn new(parent: Arc<dyn ConnectionInternal>, name: String, schema: SchemaRef) -> Self {
|
||||||
|
let table_definition = TableDefinition::new_from_schema(schema);
|
||||||
Self {
|
Self {
|
||||||
parent,
|
parent,
|
||||||
name,
|
name,
|
||||||
data: None,
|
data: None,
|
||||||
schema: Some(schema),
|
table_definition: Some(table_definition),
|
||||||
mode: CreateTableMode::default(),
|
mode: CreateTableMode::default(),
|
||||||
write_options: WriteOptions::default(),
|
write_options: WriteOptions::default(),
|
||||||
|
embeddings: Vec::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -350,6 +374,7 @@ impl OpenTableBuilder {
|
|||||||
pub(crate) trait ConnectionInternal:
|
pub(crate) trait ConnectionInternal:
|
||||||
Send + Sync + std::fmt::Debug + std::fmt::Display + 'static
|
Send + Sync + std::fmt::Debug + std::fmt::Display + 'static
|
||||||
{
|
{
|
||||||
|
fn embedding_registry(&self) -> &dyn EmbeddingRegistry;
|
||||||
async fn table_names(&self, options: TableNamesBuilder) -> Result<Vec<String>>;
|
async fn table_names(&self, options: TableNamesBuilder) -> Result<Vec<String>>;
|
||||||
async fn do_create_table(
|
async fn do_create_table(
|
||||||
&self,
|
&self,
|
||||||
@@ -366,7 +391,7 @@ pub(crate) trait ConnectionInternal:
|
|||||||
) -> Result<Table> {
|
) -> Result<Table> {
|
||||||
let batches = Box::new(RecordBatchIterator::new(
|
let batches = Box::new(RecordBatchIterator::new(
|
||||||
vec![],
|
vec![],
|
||||||
options.schema.as_ref().unwrap().clone(),
|
options.table_definition.clone().unwrap().schema.clone(),
|
||||||
));
|
));
|
||||||
self.do_create_table(options, batches).await
|
self.do_create_table(options, batches).await
|
||||||
}
|
}
|
||||||
@@ -453,6 +478,13 @@ impl Connection {
|
|||||||
pub async fn drop_db(&self) -> Result<()> {
|
pub async fn drop_db(&self) -> Result<()> {
|
||||||
self.internal.drop_db().await
|
self.internal.drop_db().await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Get the in-memory embedding registry.
|
||||||
|
/// It's important to note that the embedding registry is not persisted across connections.
|
||||||
|
/// So if a table contains embeddings, you will need to make sure that you are using a connection that has the same embedding functions registered
|
||||||
|
pub fn embedding_registry(&self) -> &dyn EmbeddingRegistry {
|
||||||
|
self.internal.embedding_registry()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
@@ -486,6 +518,7 @@ pub struct ConnectBuilder {
|
|||||||
/// consistency only applies to read operations. Write operations are
|
/// consistency only applies to read operations. Write operations are
|
||||||
/// always consistent.
|
/// always consistent.
|
||||||
read_consistency_interval: Option<std::time::Duration>,
|
read_consistency_interval: Option<std::time::Duration>,
|
||||||
|
embedding_registry: Option<Arc<dyn EmbeddingRegistry>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ConnectBuilder {
|
impl ConnectBuilder {
|
||||||
@@ -498,6 +531,7 @@ impl ConnectBuilder {
|
|||||||
host_override: None,
|
host_override: None,
|
||||||
read_consistency_interval: None,
|
read_consistency_interval: None,
|
||||||
storage_options: HashMap::new(),
|
storage_options: HashMap::new(),
|
||||||
|
embedding_registry: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -516,6 +550,12 @@ impl ConnectBuilder {
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Provide a custom [`EmbeddingRegistry`] to use for this connection.
|
||||||
|
pub fn embedding_registry(mut self, registry: Arc<dyn EmbeddingRegistry>) -> Self {
|
||||||
|
self.embedding_registry = Some(registry);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
/// [`AwsCredential`] to use when connecting to S3.
|
/// [`AwsCredential`] to use when connecting to S3.
|
||||||
#[deprecated(note = "Pass through storage_options instead")]
|
#[deprecated(note = "Pass through storage_options instead")]
|
||||||
pub fn aws_creds(mut self, aws_creds: AwsCredential) -> Self {
|
pub fn aws_creds(mut self, aws_creds: AwsCredential) -> Self {
|
||||||
@@ -642,6 +682,7 @@ struct Database {
|
|||||||
|
|
||||||
// Storage options to be inherited by tables created from this connection
|
// Storage options to be inherited by tables created from this connection
|
||||||
storage_options: HashMap<String, String>,
|
storage_options: HashMap<String, String>,
|
||||||
|
embedding_registry: Arc<dyn EmbeddingRegistry>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::fmt::Display for Database {
|
impl std::fmt::Display for Database {
|
||||||
@@ -675,7 +716,12 @@ impl Database {
|
|||||||
// TODO: pass params regardless of OS
|
// TODO: pass params regardless of OS
|
||||||
match parse_res {
|
match parse_res {
|
||||||
Ok(url) if url.scheme().len() == 1 && cfg!(windows) => {
|
Ok(url) if url.scheme().len() == 1 && cfg!(windows) => {
|
||||||
Self::open_path(uri, options.read_consistency_interval).await
|
Self::open_path(
|
||||||
|
uri,
|
||||||
|
options.read_consistency_interval,
|
||||||
|
options.embedding_registry.clone(),
|
||||||
|
)
|
||||||
|
.await
|
||||||
}
|
}
|
||||||
Ok(mut url) => {
|
Ok(mut url) => {
|
||||||
// iter thru the query params and extract the commit store param
|
// iter thru the query params and extract the commit store param
|
||||||
@@ -745,6 +791,10 @@ impl Database {
|
|||||||
None => None,
|
None => None,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let embedding_registry = options
|
||||||
|
.embedding_registry
|
||||||
|
.clone()
|
||||||
|
.unwrap_or_else(|| Arc::new(MemoryRegistry::new()));
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
uri: table_base_uri,
|
uri: table_base_uri,
|
||||||
query_string,
|
query_string,
|
||||||
@@ -753,20 +803,33 @@ impl Database {
|
|||||||
store_wrapper: write_store_wrapper,
|
store_wrapper: write_store_wrapper,
|
||||||
read_consistency_interval: options.read_consistency_interval,
|
read_consistency_interval: options.read_consistency_interval,
|
||||||
storage_options,
|
storage_options,
|
||||||
|
embedding_registry,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
Err(_) => Self::open_path(uri, options.read_consistency_interval).await,
|
Err(_) => {
|
||||||
|
Self::open_path(
|
||||||
|
uri,
|
||||||
|
options.read_consistency_interval,
|
+                options.embedding_registry.clone(),
+            )
+            .await
+        }
     }
 }

     async fn open_path(
         path: &str,
         read_consistency_interval: Option<std::time::Duration>,
+        embedding_registry: Option<Arc<dyn EmbeddingRegistry>>,
     ) -> Result<Self> {
         let (object_store, base_path) = ObjectStore::from_uri(path).await?;
         if object_store.is_local() {
             Self::try_create_dir(path).context(CreateDirSnafu { path })?;
         }

+        let embedding_registry =
+            embedding_registry.unwrap_or_else(|| Arc::new(MemoryRegistry::new()));

         Ok(Self {
             uri: path.to_string(),
             query_string: None,
@@ -775,6 +838,7 @@ impl Database {
             store_wrapper: None,
             read_consistency_interval,
             storage_options: HashMap::new(),
+            embedding_registry,
         })
     }

@@ -815,6 +879,9 @@ impl Database {

 #[async_trait::async_trait]
 impl ConnectionInternal for Database {
+    fn embedding_registry(&self) -> &dyn EmbeddingRegistry {
+        self.embedding_registry.as_ref()
+    }
     async fn table_names(&self, options: TableNamesBuilder) -> Result<Vec<String>> {
         let mut f = self
             .object_store
@@ -851,7 +918,7 @@ impl ConnectionInternal for Database {
         data: Box<dyn RecordBatchReader + Send>,
     ) -> Result<Table> {
         let table_uri = self.table_uri(&options.name)?;
-
+        let embedding_registry = self.embedding_registry.clone();
         // Inherit storage options from the connection
         let storage_options = options
             .write_options
@@ -866,6 +933,11 @@ impl ConnectionInternal for Database {
                 storage_options.insert(key.clone(), value.clone());
             }
         }
+        let data = if options.embeddings.is_empty() {
+            data
+        } else {
+            Box::new(WithEmbeddings::new(data, options.embeddings))
+        };

         let mut write_params = options.write_options.lance_write_params.unwrap_or_default();
         if matches!(&options.mode, CreateTableMode::Overwrite) {
@@ -882,7 +954,10 @@ impl ConnectionInternal for Database {
         )
         .await
         {
-            Ok(table) => Ok(Table::new(Arc::new(table))),
+            Ok(table) => Ok(Table::new_with_embedding_registry(
+                Arc::new(table),
+                embedding_registry,
+            )),
             Err(Error::TableAlreadyExists { name }) => match options.mode {
                 CreateTableMode::Create => Err(Error::TableAlreadyExists { name }),
                 CreateTableMode::ExistOk(callback) => {
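The net effect of these connection changes: a caller can hand the builder a registry at connect time, and anything registered there is visible to tables created from that connection; otherwise `open_path` falls back to a fresh in-memory registry. A minimal sketch of the caller side, based on the builder methods exercised by the new test file further below (the path is a placeholder):

```rust
use std::sync::Arc;

use lancedb::{connect, embeddings::MemoryRegistry};

#[tokio::main]
async fn main() -> lancedb::Result<()> {
    // Explicitly supply a registry; omitting this falls back to a MemoryRegistry.
    let registry = Arc::new(MemoryRegistry::new());
    let db = connect("/tmp/lancedb-demo") // placeholder path
        .embedding_registry(registry)
        .execute()
        .await?;
    // Functions registered on the connection are visible to its tables.
    println!("{:?}", db.embedding_registry().functions());
    Ok(())
}
```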

307	rust/lancedb/src/embeddings.rs (new file)
@@ -0,0 +1,307 @@
// Copyright 2024 LanceDB Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use lance::arrow::RecordBatchExt;
use std::{
    borrow::Cow,
    collections::{HashMap, HashSet},
    sync::{Arc, RwLock},
};

use arrow_array::{Array, RecordBatch, RecordBatchReader};
use arrow_schema::{DataType, Field, SchemaBuilder};
// use async_trait::async_trait;
use serde::{Deserialize, Serialize};

use crate::{
    error::Result,
    table::{ColumnDefinition, ColumnKind, TableDefinition},
    Error,
};

/// Trait for embedding functions
///
/// An embedding function is a function that is applied to a column of input data
/// to produce an "embedding" of that input. This embedding is then stored in the
/// database alongside (or instead of) the original input.
///
/// An "embedding" is often a lower-dimensional representation of the input data.
/// For example, sentence-transformers can be used to embed sentences into a 768-dimensional
/// vector space. This is useful for tasks like similarity search, where we want to find
/// similar sentences to a query sentence.
///
/// To use an embedding function you must first register it with the `EmbeddingRegistry`.
/// Then you can define it on a column in the table schema. That embedding will then be used
/// to embed the data in that column.
pub trait EmbeddingFunction: std::fmt::Debug + Send + Sync {
    fn name(&self) -> &str;
    /// The type of the input data
    fn source_type(&self) -> Result<Cow<DataType>>;
    /// The type of the output data.
    /// This should **always** match the output of the `embed` function
    fn dest_type(&self) -> Result<Cow<DataType>>;
    /// Embed the input
    fn embed(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>>;
}

/// Defines an embedding from input data into a lower-dimensional space
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct EmbeddingDefinition {
    /// The name of the column in the input data
    pub source_column: String,
    /// The name of the embedding column; if not specified,
    /// it will be the source column with `_embedding` appended
    pub dest_column: Option<String>,
    /// The name of the embedding function to apply
    pub embedding_name: String,
}

impl EmbeddingDefinition {
    pub fn new<S: Into<String>>(source_column: S, embedding_name: S, dest: Option<S>) -> Self {
        Self {
            source_column: source_column.into(),
            dest_column: dest.map(|d| d.into()),
            embedding_name: embedding_name.into(),
        }
    }
}

/// A registry of embedding functions
pub trait EmbeddingRegistry: Send + Sync + std::fmt::Debug {
    /// Return the names of all registered embedding functions
    fn functions(&self) -> HashSet<String>;
    /// Register a new [`EmbeddingFunction`].
    /// Returns an error if the function can not be registered.
    fn register(&self, name: &str, function: Arc<dyn EmbeddingFunction>) -> Result<()>;
    /// Get an embedding function by name
    fn get(&self, name: &str) -> Option<Arc<dyn EmbeddingFunction>>;
}

/// An [`EmbeddingRegistry`] that uses in-memory [`HashMap`]s
#[derive(Debug, Default, Clone)]
pub struct MemoryRegistry {
    functions: Arc<RwLock<HashMap<String, Arc<dyn EmbeddingFunction>>>>,
}

impl EmbeddingRegistry for MemoryRegistry {
    fn functions(&self) -> HashSet<String> {
        self.functions.read().unwrap().keys().cloned().collect()
    }
    fn register(&self, name: &str, function: Arc<dyn EmbeddingFunction>) -> Result<()> {
        self.functions
            .write()
            .unwrap()
            .insert(name.to_string(), function);

        Ok(())
    }

    fn get(&self, name: &str) -> Option<Arc<dyn EmbeddingFunction>> {
        self.functions.read().unwrap().get(name).cloned()
    }
}

impl MemoryRegistry {
    /// Create a new `MemoryRegistry`
    pub fn new() -> Self {
        Self::default()
    }
}

/// A record batch reader that has embeddings applied to it.
/// This is a wrapper around another record batch reader that applies an embedding function
/// when reading from the record batch.
pub struct WithEmbeddings<R: RecordBatchReader> {
    inner: R,
    embeddings: Vec<(EmbeddingDefinition, Arc<dyn EmbeddingFunction>)>,
}

/// A record batch reader that might have embeddings applied to it.
pub enum MaybeEmbedded<R: RecordBatchReader> {
    /// The record batch reader has embeddings applied to it
    Yes(WithEmbeddings<R>),
    /// The record batch reader does not have embeddings applied to it;
    /// the inner record batch reader is returned as-is
    No(R),
}

impl<R: RecordBatchReader> MaybeEmbedded<R> {
    /// Create a new RecordBatchReader with embeddings applied to it if the table definition
    /// specifies an embedding column and the registry contains an embedding function with that name.
    /// Otherwise, this is a no-op and the inner RecordBatchReader is returned.
    pub fn try_new(
        inner: R,
        table_definition: TableDefinition,
        registry: Option<Arc<dyn EmbeddingRegistry>>,
    ) -> Result<Self> {
        if let Some(registry) = registry {
            let mut embeddings = Vec::with_capacity(table_definition.column_definitions.len());
            for cd in table_definition.column_definitions.iter() {
                if let ColumnKind::Embedding(embedding_def) = &cd.kind {
                    match registry.get(&embedding_def.embedding_name) {
                        Some(func) => {
                            embeddings.push((embedding_def.clone(), func));
                        }
                        None => {
                            return Err(Error::EmbeddingFunctionNotFound {
                                name: embedding_def.embedding_name.to_string(),
                                reason: format!(
                                    "Table was defined with an embedding column `{}` but no embedding function was found with that name within the registry.",
                                    embedding_def.embedding_name
                                ),
                            });
                        }
                    }
                }
            }

            if !embeddings.is_empty() {
                return Ok(Self::Yes(WithEmbeddings { inner, embeddings }));
            }
        };

        // No embeddings to apply
        Ok(Self::No(inner))
    }
}

impl<R: RecordBatchReader> WithEmbeddings<R> {
    pub fn new(
        inner: R,
        embeddings: Vec<(EmbeddingDefinition, Arc<dyn EmbeddingFunction>)>,
    ) -> Self {
        Self { inner, embeddings }
    }
}

impl<R: RecordBatchReader> WithEmbeddings<R> {
    fn dest_fields(&self) -> Result<Vec<Field>> {
        let schema = self.inner.schema();
        self.embeddings
            .iter()
            .map(|(ed, func)| {
                let src_field = schema.field_with_name(&ed.source_column).unwrap();

                let field_name = ed
                    .dest_column
                    .clone()
                    .unwrap_or_else(|| format!("{}_embedding", &ed.source_column));
                Ok(Field::new(
                    field_name,
                    func.dest_type()?.into_owned(),
                    src_field.is_nullable(),
                ))
            })
            .collect()
    }

    fn column_defs(&self) -> Vec<ColumnDefinition> {
        let base_schema = self.inner.schema();
        base_schema
            .fields()
            .iter()
            .map(|_| ColumnDefinition {
                kind: ColumnKind::Physical,
            })
            .chain(self.embeddings.iter().map(|(ed, _)| ColumnDefinition {
                kind: ColumnKind::Embedding(ed.clone()),
            }))
            .collect::<Vec<_>>()
    }

    pub fn table_definition(&self) -> Result<TableDefinition> {
        let base_schema = self.inner.schema();

        let output_fields = self.dest_fields()?;
        let column_definitions = self.column_defs();

        let mut sb: SchemaBuilder = base_schema.as_ref().into();
        sb.extend(output_fields);

        let schema = Arc::new(sb.finish());
        Ok(TableDefinition {
            schema,
            column_definitions,
        })
    }
}

impl<R: RecordBatchReader> Iterator for MaybeEmbedded<R> {
    type Item = std::result::Result<RecordBatch, arrow_schema::ArrowError>;
    fn next(&mut self) -> Option<Self::Item> {
        match self {
            Self::Yes(inner) => inner.next(),
            Self::No(inner) => inner.next(),
        }
    }
}

impl<R: RecordBatchReader> RecordBatchReader for MaybeEmbedded<R> {
    fn schema(&self) -> Arc<arrow_schema::Schema> {
        match self {
            Self::Yes(inner) => inner.schema(),
            Self::No(inner) => inner.schema(),
        }
    }
}

impl<R: RecordBatchReader> Iterator for WithEmbeddings<R> {
    type Item = std::result::Result<RecordBatch, arrow_schema::ArrowError>;

    fn next(&mut self) -> Option<Self::Item> {
        let batch = self.inner.next()?;
        match batch {
            Ok(mut batch) => {
                // todo: parallelize this
                for (fld, func) in self.embeddings.iter() {
                    let src_column = batch.column_by_name(&fld.source_column).unwrap();
                    let embedding = match func.embed(src_column.clone()) {
                        Ok(embedding) => embedding,
                        Err(e) => {
                            return Some(Err(arrow_schema::ArrowError::ComputeError(format!(
                                "Error computing embedding: {}",
                                e
                            ))))
                        }
                    };
                    let dst_field_name = fld
                        .dest_column
                        .clone()
                        .unwrap_or_else(|| format!("{}_embedding", &fld.source_column));

                    let dst_field = Field::new(
                        dst_field_name,
                        embedding.data_type().clone(),
                        embedding.nulls().is_some(),
                    );

                    match batch.try_with_column(dst_field.clone(), embedding) {
                        Ok(b) => batch = b,
                        Err(e) => return Some(Err(e)),
                    };
                }
                Some(Ok(batch))
            }
            Err(e) => Some(Err(e)),
        }
    }
}

impl<R: RecordBatchReader> RecordBatchReader for WithEmbeddings<R> {
    fn schema(&self) -> Arc<arrow_schema::Schema> {
        self.table_definition()
            .expect("table definition should be infallible at this point")
            .into_rich_schema()
    }
}
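To make the flow of this new module concrete: implement `EmbeddingFunction`, describe the mapping with an `EmbeddingDefinition`, and let `WithEmbeddings` append the computed column to each batch as it streams through. A self-contained sketch (`ConstEmbed` and the column names are invented for illustration, not part of this commit):

```rust
use std::{borrow::Cow, sync::Arc};

use arrow_array::{
    Array, FixedSizeListArray, Float32Array, RecordBatch, RecordBatchIterator, StringArray,
};
use arrow_schema::{DataType, Field, Schema};
use lancedb::embeddings::{EmbeddingDefinition, EmbeddingFunction, WithEmbeddings};
use lancedb::Result;

/// A toy embedder: every input row maps to the same 2-d vector.
#[derive(Debug)]
struct ConstEmbed;

impl EmbeddingFunction for ConstEmbed {
    fn name(&self) -> &str {
        "const_embed"
    }
    fn source_type(&self) -> Result<Cow<DataType>> {
        Ok(Cow::Owned(DataType::Utf8))
    }
    fn dest_type(&self) -> Result<Cow<DataType>> {
        Ok(Cow::Owned(DataType::new_fixed_size_list(
            DataType::Float32,
            2,
            true,
        )))
    }
    fn embed(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
        // Two float32 values per input row, packed into a fixed-size list.
        let values = Arc::new(Float32Array::from(vec![0.5_f32; source.len() * 2]));
        let item = Arc::new(Field::new("item", DataType::Float32, false));
        Ok(Arc::new(FixedSizeListArray::new(item, 2, values, None)))
    }
}

fn main() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("text", DataType::Utf8, true)]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(StringArray::from(vec!["hello", "world"]))],
    )
    .unwrap();
    let reader = RecordBatchIterator::new(vec![Ok(batch)], schema);

    // Wrap the reader; each batch that streams through gains a "text_vec" column.
    let def = EmbeddingDefinition::new("text", "const_embed", Some("text_vec"));
    let func: Arc<dyn EmbeddingFunction> = Arc::new(ConstEmbed);
    let mut embedded = WithEmbeddings::new(reader, vec![(def, func)]);

    let batch = embedded.next().unwrap().unwrap();
    assert!(batch.column_by_name("text_vec").is_some());
    Ok(())
}
```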
@@ -26,6 +26,9 @@ pub enum Error {
     InvalidInput { message: String },
     #[snafu(display("Table '{name}' was not found"))]
     TableNotFound { name: String },
+    #[snafu(display("Embedding function '{name}' was not found: {reason}"))]
+    EmbeddingFunctionNotFound { name: String, reason: String },
+
     #[snafu(display("Table '{name}' already exists"))]
     TableAlreadyExists { name: String },
     #[snafu(display("Unable to create lance dataset at {path}: {source}"))]
@@ -112,3 +115,13 @@ impl From<url::ParseError> for Error {
         }
     }
 }

+#[cfg(feature = "polars")]
+impl From<polars::prelude::PolarsError> for Error {
+    fn from(source: polars::prelude::PolarsError) -> Self {
+        Self::Other {
+            message: "Error in Polars DataFrame integration.".to_string(),
+            source: Some(Box::new(source)),
+        }
+    }
+}
@@ -194,10 +194,13 @@
 pub mod arrow;
 pub mod connection;
 pub mod data;
+pub mod embeddings;
 pub mod error;
 pub mod index;
 pub mod io;
 pub mod ipc;
+#[cfg(feature = "polars")]
+mod polars_arrow_convertors;
 pub mod query;
 #[cfg(feature = "remote")]
 pub(crate) mod remote;

123	rust/lancedb/src/polars_arrow_convertors.rs (new file)
@@ -0,0 +1,123 @@
/// Polars and LanceDB both use Arrow for their in-memory representation, but use
/// different Rust Arrow implementations. LanceDB uses the arrow-rs crate and
/// Polars uses the polars-arrow crate.
///
/// This module defines zero-copy conversions (of the underlying buffers)
/// between polars-arrow and arrow-rs using the C FFI.
///
/// polars-arrow does implement conversions to and from arrow-rs, but
/// requires a feature-flagged dependency on arrow-rs. The version of arrow-rs
/// depended on by polars-arrow and LanceDB may not be compatible,
/// which necessitates using the C FFI.
use crate::error::Result;
use polars::prelude::{DataFrame, Series};
use std::{mem, sync::Arc};

/// When interpreting Polars dataframes as polars-arrow record batches,
/// one must decide whether to use Arrow string/binary view types
/// instead of the standard Arrow string/binary types.
/// For now, we do not use string view types because conversions
/// for string view types from polars-arrow to arrow-rs are not yet implemented.
/// See https://lists.apache.org/thread/w88tpz76ox8h3rxkjl4so6rg3f1rv7wt for the
/// differences between the types.
pub const POLARS_ARROW_FLAVOR: bool = false;
const IS_ARRAY_NULLABLE: bool = true;

/// Converts a Polars DataFrame schema to an Arrow RecordBatch schema.
pub fn convert_polars_df_schema_to_arrow_rb_schema(
    polars_df_schema: polars::prelude::Schema,
) -> Result<Arc<arrow_schema::Schema>> {
    let arrow_fields: Result<Vec<arrow_schema::Field>> = polars_df_schema
        .into_iter()
        .map(|(name, df_dtype)| {
            let polars_arrow_dtype = df_dtype.to_arrow(POLARS_ARROW_FLAVOR);
            let polars_field =
                polars_arrow::datatypes::Field::new(name, polars_arrow_dtype, IS_ARRAY_NULLABLE);
            convert_polars_arrow_field_to_arrow_rs_field(polars_field)
        })
        .collect();
    Ok(Arc::new(arrow_schema::Schema::new(arrow_fields?)))
}

/// Converts an Arrow RecordBatch schema to a Polars DataFrame schema.
pub fn convert_arrow_rb_schema_to_polars_df_schema(
    arrow_schema: &arrow_schema::Schema,
) -> Result<polars::prelude::Schema> {
    let polars_df_fields: Result<Vec<polars::prelude::Field>> = arrow_schema
        .fields()
        .iter()
        .map(|arrow_rs_field| {
            let polars_arrow_field = convert_arrow_rs_field_to_polars_arrow_field(arrow_rs_field)?;
            Ok(polars::prelude::Field::new(
                arrow_rs_field.name(),
                polars::datatypes::DataType::from(polars_arrow_field.data_type()),
            ))
        })
        .collect();
    Ok(polars::prelude::Schema::from_iter(polars_df_fields?))
}

/// Converts an Arrow RecordBatch to a Polars DataFrame, using a provided Polars DataFrame schema.
pub fn convert_arrow_rb_to_polars_df(
    arrow_rb: &arrow::record_batch::RecordBatch,
    polars_schema: &polars::prelude::Schema,
) -> Result<DataFrame> {
    let mut columns: Vec<Series> = Vec::with_capacity(arrow_rb.num_columns());

    for (i, column) in arrow_rb.columns().iter().enumerate() {
        let polars_df_dtype = polars_schema.try_get_at_index(i)?.1;
        let polars_arrow_dtype = polars_df_dtype.to_arrow(POLARS_ARROW_FLAVOR);
        let polars_array =
            convert_arrow_rs_array_to_polars_arrow_array(column, polars_arrow_dtype)?;
        columns.push(Series::from_arrow(
            polars_schema.try_get_at_index(i)?.0,
            polars_array,
        )?);
    }

    Ok(DataFrame::from_iter(columns))
}

/// Converts a polars-arrow Arrow array to an arrow-rs Arrow array.
pub fn convert_polars_arrow_array_to_arrow_rs_array(
    polars_array: Box<dyn polars_arrow::array::Array>,
    arrow_datatype: arrow_schema::DataType,
) -> std::result::Result<arrow_array::ArrayRef, arrow_schema::ArrowError> {
    let polars_c_array = polars_arrow::ffi::export_array_to_c(polars_array);
    let arrow_c_array = unsafe { mem::transmute(polars_c_array) };
    Ok(arrow_array::make_array(unsafe {
        arrow::ffi::from_ffi_and_data_type(arrow_c_array, arrow_datatype)
    }?))
}

/// Converts an arrow-rs Arrow array to a polars-arrow Arrow array.
fn convert_arrow_rs_array_to_polars_arrow_array(
    arrow_rs_array: &Arc<dyn arrow_array::Array>,
    polars_arrow_dtype: polars::datatypes::ArrowDataType,
) -> Result<Box<dyn polars_arrow::array::Array>> {
    let arrow_c_array = arrow::ffi::FFI_ArrowArray::new(&arrow_rs_array.to_data());
    let polars_c_array = unsafe { mem::transmute(arrow_c_array) };
    Ok(unsafe { polars_arrow::ffi::import_array_from_c(polars_c_array, polars_arrow_dtype) }?)
}

fn convert_polars_arrow_field_to_arrow_rs_field(
    polars_arrow_field: polars_arrow::datatypes::Field,
) -> Result<arrow_schema::Field> {
    let polars_c_schema = polars_arrow::ffi::export_field_to_c(&polars_arrow_field);
    let arrow_c_schema: arrow::ffi::FFI_ArrowSchema = unsafe { mem::transmute(polars_c_schema) };
    let arrow_rs_dtype = arrow_schema::DataType::try_from(&arrow_c_schema)?;
    Ok(arrow_schema::Field::new(
        polars_arrow_field.name,
        arrow_rs_dtype,
        IS_ARRAY_NULLABLE,
    ))
}

fn convert_arrow_rs_field_to_polars_arrow_field(
    arrow_rs_field: &arrow_schema::Field,
) -> Result<polars_arrow::datatypes::Field> {
    let arrow_rs_dtype = arrow_rs_field.data_type();
    let arrow_c_schema = arrow::ffi::FFI_ArrowSchema::try_from(arrow_rs_dtype)?;
    let polars_c_schema: polars_arrow::ffi::ArrowSchema = unsafe { mem::transmute(arrow_c_schema) };
    Ok(unsafe { polars_arrow::ffi::import_field_from_c(&polars_c_schema) }?)
}
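The essential trick above is that both libraries speak the Arrow C data interface, so an array crosses between them by transmuting the C structs rather than copying buffers. If one wanted to sanity-check the round trip, a test inside this module might look like the following sketch (not part of the commit):

```rust
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::Int32Array;

    #[test]
    fn arrow_rs_to_polars_and_back() {
        // Build an arrow-rs array...
        let original: arrow_array::ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
        // ...move it into polars-arrow via the C FFI (no buffer copy)...
        let polars_array = convert_arrow_rs_array_to_polars_arrow_array(
            &original,
            polars::datatypes::ArrowDataType::Int32,
        )
        .unwrap();
        assert_eq!(polars_array.len(), 3);
        // ...and back into arrow-rs.
        let round_tripped = convert_polars_arrow_array_to_arrow_rs_array(
            polars_array,
            arrow_schema::DataType::Int32,
        )
        .unwrap();
        assert_eq!(round_tripped.len(), 3);
    }
}
```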
@@ -23,6 +23,7 @@ use tokio::task::spawn_blocking;
 use crate::connection::{
     ConnectionInternal, CreateTableBuilder, NoData, OpenTableBuilder, TableNamesBuilder,
 };
+use crate::embeddings::EmbeddingRegistry;
 use crate::error::Result;
 use crate::Table;

@@ -87,14 +88,16 @@ impl ConnectionInternal for RemoteDatabase {
         .await
         .unwrap()?;

-        self.client
-            .post(&format!("/v1/table/{}/create", options.name))
+        let rsp = self
+            .client
+            .post(&format!("/v1/table/{}/create/", options.name))
             .body(data_buffer)
             .header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE)
             // This is currently expected by LanceDb cloud but will be removed soon.
             .header("x-request-id", "na")
             .send()
             .await?;
+        self.client.check_response(rsp).await?;

         Ok(Table::new(Arc::new(RemoteTable::new(
             self.client.clone(),
@@ -113,4 +116,8 @@ impl ConnectionInternal for RemoteDatabase {
     async fn drop_db(&self) -> Result<()> {
         todo!()
     }

+    fn embedding_registry(&self) -> &dyn EmbeddingRegistry {
+        todo!()
+    }
 }
@@ -10,7 +10,7 @@ use crate::{
     query::{Query, QueryExecutionOptions, VectorQuery},
     table::{
         merge::MergeInsertBuilder, AddDataBuilder, NativeTable, OptimizeAction, OptimizeStats,
-        TableInternal, UpdateBuilder,
+        TableDefinition, TableInternal, UpdateBuilder,
     },
 };

@@ -120,4 +120,7 @@ impl TableInternal for RemoteTable {
     async fn list_indices(&self) -> Result<Vec<IndexConfig>> {
         todo!()
     }
+    async fn table_definition(&self) -> Result<TableDefinition> {
+        todo!()
+    }
 }
@@ -41,10 +41,12 @@ use lance::io::WrappingObjectStore;
 use lance_index::IndexType;
 use lance_index::{optimize::OptimizeOptions, DatasetIndexExt};
 use log::info;
+use serde::{Deserialize, Serialize};
 use snafu::whatever;

 use crate::arrow::IntoArrow;
 use crate::connection::NoData;
+use crate::embeddings::{EmbeddingDefinition, EmbeddingRegistry, MaybeEmbedded, MemoryRegistry};
 use crate::error::{Error, Result};
 use crate::index::vector::{IvfPqIndexBuilder, VectorIndex, VectorIndexStatistics};
 use crate::index::IndexConfig;
@@ -63,6 +65,79 @@ use self::merge::MergeInsertBuilder;
 pub(crate) mod dataset;
 pub mod merge;

+/// Defines the type of column
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub enum ColumnKind {
+    /// Columns populated by data from the user (this is the most common case)
+    Physical,
+    /// Columns populated by applying an embedding function to the input
+    Embedding(EmbeddingDefinition),
+}
+
+/// Defines a column in a table
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ColumnDefinition {
+    /// The source of the column data
+    pub kind: ColumnKind,
+}
+
+#[derive(Debug, Clone)]
+pub struct TableDefinition {
+    pub column_definitions: Vec<ColumnDefinition>,
+    pub schema: SchemaRef,
+}
+
+impl TableDefinition {
+    pub fn new(schema: SchemaRef, column_definitions: Vec<ColumnDefinition>) -> Self {
+        Self {
+            column_definitions,
+            schema,
+        }
+    }
+
+    pub fn new_from_schema(schema: SchemaRef) -> Self {
+        let column_definitions = schema
+            .fields()
+            .iter()
+            .map(|_| ColumnDefinition {
+                kind: ColumnKind::Physical,
+            })
+            .collect();
+        Self::new(schema, column_definitions)
+    }
+
+    pub fn try_from_rich_schema(schema: SchemaRef) -> Result<Self> {
+        let column_definitions = schema.metadata.get("lancedb::column_definitions");
+        if let Some(column_definitions) = column_definitions {
+            let column_definitions: Vec<ColumnDefinition> =
+                serde_json::from_str(column_definitions).map_err(|e| Error::Runtime {
+                    message: format!("Failed to deserialize column definitions: {}", e),
+                })?;
+            Ok(Self::new(schema, column_definitions))
+        } else {
+            let column_definitions = schema
+                .fields()
+                .iter()
+                .map(|_| ColumnDefinition {
+                    kind: ColumnKind::Physical,
+                })
+                .collect();
+            Ok(Self::new(schema, column_definitions))
+        }
+    }
+
+    pub fn into_rich_schema(self) -> SchemaRef {
+        // We have full control over the structure of column definitions. This should
+        // not fail, except for a bug.
+        let lancedb_metadata = serde_json::to_string(&self.column_definitions).unwrap();
+        let mut schema_with_metadata = (*self.schema).clone();
+        schema_with_metadata
+            .metadata
+            .insert("lancedb::column_definitions".to_string(), lancedb_metadata);
+        Arc::new(schema_with_metadata)
+    }
+}
+
 /// Optimize the dataset.
 ///
 /// Similar to `VACUUM` in PostgreSQL, it offers different options to
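The `lancedb::column_definitions` metadata key is what lets embedding configuration survive a round trip through storage: the column definitions are serialized to JSON and stashed on the Arrow schema itself, then recovered when the table is reopened. A small sketch of that round trip, assuming `TableDefinition` is exposed under `lancedb::table` as the imports in this diff suggest:

```rust
use std::sync::Arc;

use arrow_schema::{DataType, Field, Schema};
use lancedb::table::TableDefinition;

fn main() -> lancedb::Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("text", DataType::Utf8, true)]));

    // Every column starts out as ColumnKind::Physical.
    let definition = TableDefinition::new_from_schema(schema);

    // Serialize the column definitions into the schema's metadata...
    let rich_schema = definition.into_rich_schema();
    assert!(rich_schema.metadata.contains_key("lancedb::column_definitions"));

    // ...and recover them again from the annotated schema.
    let recovered = TableDefinition::try_from_rich_schema(rich_schema)?;
    assert_eq!(recovered.column_definitions.len(), 1);
    Ok(())
}
```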
@@ -132,6 +207,7 @@ pub struct AddDataBuilder<T: IntoArrow> {
     pub(crate) data: T,
     pub(crate) mode: AddDataMode,
     pub(crate) write_options: WriteOptions,
+    embedding_registry: Option<Arc<dyn EmbeddingRegistry>>,
 }

 impl<T: IntoArrow> std::fmt::Debug for AddDataBuilder<T> {
@@ -163,6 +239,7 @@ impl<T: IntoArrow> AddDataBuilder<T> {
             mode: self.mode,
             parent: self.parent,
             write_options: self.write_options,
+            embedding_registry: self.embedding_registry,
         };
         parent.add(without_data, data).await
     }
@@ -280,6 +357,7 @@ pub(crate) trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Sync {
     async fn checkout(&self, version: u64) -> Result<()>;
     async fn checkout_latest(&self) -> Result<()>;
     async fn restore(&self) -> Result<()>;
+    async fn table_definition(&self) -> Result<TableDefinition>;
 }

 /// A Table is a collection of strongly typed rows.
@@ -288,6 +366,7 @@ pub(crate) trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Sync {
 #[derive(Clone)]
 pub struct Table {
     inner: Arc<dyn TableInternal>,
+    embedding_registry: Arc<dyn EmbeddingRegistry>,
 }

 impl std::fmt::Display for Table {
@@ -298,7 +377,20 @@ impl std::fmt::Display for Table {

 impl Table {
     pub(crate) fn new(inner: Arc<dyn TableInternal>) -> Self {
-        Self { inner }
+        Self {
+            inner,
+            embedding_registry: Arc::new(MemoryRegistry::new()),
+        }
+    }
+
+    pub(crate) fn new_with_embedding_registry(
+        inner: Arc<dyn TableInternal>,
+        embedding_registry: Arc<dyn EmbeddingRegistry>,
+    ) -> Self {
+        Self {
+            inner,
+            embedding_registry,
+        }
     }

     /// Cast as [`NativeTable`], or return None if it is not a [`NativeTable`].
@@ -340,6 +432,7 @@ impl Table {
             data: batches,
             mode: AddDataMode::Append,
             write_options: WriteOptions::default(),
+            embedding_registry: Some(self.embedding_registry.clone()),
         }
     }

@@ -743,11 +836,10 @@ impl Table {

 impl From<NativeTable> for Table {
     fn from(table: NativeTable) -> Self {
-        Self {
-            inner: Arc::new(table),
-        }
+        Self::new(Arc::new(table))
     }
 }

 /// A table in a LanceDB database.
 #[derive(Debug, Clone)]
 pub struct NativeTable {
@@ -918,7 +1010,6 @@ impl NativeTable {
             Some(wrapper) => params.patch_with_store_wrapper(wrapper)?,
             None => params,
         };
-
         let storage_options = params
             .store_params
             .clone()
@@ -1342,6 +1433,11 @@ impl TableInternal for NativeTable {
         Ok(Arc::new(Schema::from(&lance_schema)))
     }

+    async fn table_definition(&self) -> Result<TableDefinition> {
+        let schema = self.schema().await?;
+        TableDefinition::try_from_rich_schema(schema)
+    }
+
     async fn count_rows(&self, filter: Option<String>) -> Result<usize> {
         Ok(self.dataset.get().await?.count_rows(filter).await?)
     }
@@ -1351,6 +1447,9 @@ impl TableInternal for NativeTable {
         add: AddDataBuilder<NoData>,
         data: Box<dyn RecordBatchReader + Send>,
     ) -> Result<()> {
+        let data =
+            MaybeEmbedded::try_new(data, self.table_definition().await?, add.embedding_registry)?;
+
         let mut lance_params = add.write_options.lance_write_params.unwrap_or(WriteParams {
             mode: match add.mode {
                 AddDataMode::Append => WriteMode::Append,
@@ -1378,8 +1477,8 @@ impl TableInternal for NativeTable {
         };

         self.dataset.ensure_mutable().await?;
         let dataset = Dataset::write(data, &self.uri, Some(lance_params)).await?;
         self.dataset.set_latest(dataset).await;
         Ok(())
     }

320	rust/lancedb/tests/embedding_registry_test.rs (new file)
@@ -0,0 +1,320 @@
use std::{
    borrow::Cow,
    collections::{HashMap, HashSet},
    iter::repeat,
    sync::Arc,
};

use arrow::buffer::NullBuffer;
use arrow_array::{
    Array, FixedSizeListArray, Float32Array, Int32Array, RecordBatch, RecordBatchIterator,
    StringArray,
};
use arrow_schema::{DataType, Field, Schema};
use futures::StreamExt;
use lancedb::{
    arrow::IntoArrow,
    connect,
    embeddings::{EmbeddingDefinition, EmbeddingFunction, EmbeddingRegistry},
    query::ExecutableQuery,
    Error, Result,
};

#[tokio::test]
async fn test_custom_func() -> Result<()> {
    let tempdir = tempfile::tempdir().unwrap();
    let tempdir = tempdir.path().to_str().unwrap();
    let db = connect(tempdir).execute().await?;
    let embed_fun = MockEmbed::new("embed_fun".to_string(), 1);
    db.embedding_registry()
        .register("embed_fun", Arc::new(embed_fun.clone()))?;

    let tbl = db
        .create_table("test", create_some_records()?)
        .add_embedding(EmbeddingDefinition::new(
            "text",
            &embed_fun.name,
            Some("embeddings"),
        ))?
        .execute()
        .await?;
    let mut res = tbl.query().execute().await?;
    while let Some(Ok(batch)) = res.next().await {
        let embeddings = batch.column_by_name("embeddings");
        assert!(embeddings.is_some());
        let embeddings = embeddings.unwrap();
        assert_eq!(embeddings.data_type(), embed_fun.dest_type()?.as_ref());
    }
    // Now make sure the embeddings are applied when
    // we add new records too.
    tbl.add(create_some_records()?).execute().await?;
    let mut res = tbl.query().execute().await?;
    while let Some(Ok(batch)) = res.next().await {
        let embeddings = batch.column_by_name("embeddings");
        assert!(embeddings.is_some());
        let embeddings = embeddings.unwrap();
        assert_eq!(embeddings.data_type(), embed_fun.dest_type()?.as_ref());
    }
    Ok(())
}

#[tokio::test]
async fn test_custom_registry() -> Result<()> {
    let tempdir = tempfile::tempdir().unwrap();
    let tempdir = tempdir.path().to_str().unwrap();

    let db = connect(tempdir)
        .embedding_registry(Arc::new(MyRegistry::default()))
        .execute()
        .await?;

    let tbl = db
        .create_table("test", create_some_records()?)
        .add_embedding(EmbeddingDefinition::new(
            "text",
            "func_1",
            Some("embeddings"),
        ))?
        .execute()
        .await?;
    let mut res = tbl.query().execute().await?;
    while let Some(Ok(batch)) = res.next().await {
        let embeddings = batch.column_by_name("embeddings");
        assert!(embeddings.is_some());
        let embeddings = embeddings.unwrap();
        assert_eq!(
            embeddings.data_type(),
            MockEmbed::new("func_1".to_string(), 1)
                .dest_type()?
                .as_ref()
        );
    }
    Ok(())
}

#[tokio::test]
async fn test_multiple_embeddings() -> Result<()> {
    let tempdir = tempfile::tempdir().unwrap();
    let tempdir = tempdir.path().to_str().unwrap();

    let db = connect(tempdir).execute().await?;
    let func_1 = MockEmbed::new("func_1".to_string(), 1);
    let func_2 = MockEmbed::new("func_2".to_string(), 10);
    db.embedding_registry()
        .register(&func_1.name, Arc::new(func_1.clone()))?;
    db.embedding_registry()
        .register(&func_2.name, Arc::new(func_2.clone()))?;

    let tbl = db
        .create_table("test", create_some_records()?)
        .add_embedding(EmbeddingDefinition::new(
            "text",
            &func_1.name,
            Some("first_embeddings"),
        ))?
        .add_embedding(EmbeddingDefinition::new(
            "text",
            &func_2.name,
            Some("second_embeddings"),
        ))?
        .execute()
        .await?;
    let mut res = tbl.query().execute().await?;
    while let Some(Ok(batch)) = res.next().await {
        let embeddings = batch.column_by_name("first_embeddings");
        assert!(embeddings.is_some());
        let second_embeddings = batch.column_by_name("second_embeddings");
        assert!(second_embeddings.is_some());

        let embeddings = embeddings.unwrap();
        assert_eq!(embeddings.data_type(), func_1.dest_type()?.as_ref());

        let second_embeddings = second_embeddings.unwrap();
        assert_eq!(second_embeddings.data_type(), func_2.dest_type()?.as_ref());
    }

    // Now make sure the embeddings are applied when
    // we add new records too.
    tbl.add(create_some_records()?).execute().await?;
    let mut res = tbl.query().execute().await?;
    while let Some(Ok(batch)) = res.next().await {
        let embeddings = batch.column_by_name("first_embeddings");
        assert!(embeddings.is_some());
        let second_embeddings = batch.column_by_name("second_embeddings");
        assert!(second_embeddings.is_some());

        let embeddings = embeddings.unwrap();
        assert_eq!(embeddings.data_type(), func_1.dest_type()?.as_ref());

        let second_embeddings = second_embeddings.unwrap();
        assert_eq!(second_embeddings.data_type(), func_2.dest_type()?.as_ref());
    }
    Ok(())
}

#[tokio::test]
async fn test_no_func_in_registry() -> Result<()> {
    let tempdir = tempfile::tempdir().unwrap();
    let tempdir = tempdir.path().to_str().unwrap();

    let db = connect(tempdir).execute().await?;

    let res = db
        .create_table("test", create_some_records()?)
        .add_embedding(EmbeddingDefinition::new(
            "text",
            "some_func",
            Some("first_embeddings"),
        ));
    assert!(res.is_err());
    assert!(matches!(
        res.err().unwrap(),
        Error::EmbeddingFunctionNotFound { .. }
    ));

    Ok(())
}

#[tokio::test]
async fn test_no_func_in_registry_on_add() -> Result<()> {
    let tempdir = tempfile::tempdir().unwrap();
    let tempdir = tempdir.path().to_str().unwrap();

    let db = connect(tempdir).execute().await?;
    db.embedding_registry().register(
        "some_func",
        Arc::new(MockEmbed::new("some_func".to_string(), 1)),
    )?;

    db.create_table("test", create_some_records()?)
        .add_embedding(EmbeddingDefinition::new(
            "text",
            "some_func",
            Some("first_embeddings"),
        ))?
        .execute()
        .await?;

    let db = connect(tempdir).execute().await?;

    let tbl = db.open_table("test").execute().await?;
    // This should fail because 'tbl' is expecting "some_func" to be in the registry.
    let res = tbl.add(create_some_records()?).execute().await;
    assert!(res.is_err());
    assert!(matches!(
        res.unwrap_err(),
        crate::Error::EmbeddingFunctionNotFound { .. }
    ));

    Ok(())
}

fn create_some_records() -> Result<impl IntoArrow> {
    const TOTAL: usize = 2;

    let schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Int32, false),
        Field::new("text", DataType::Utf8, true),
    ]));

    // Create a RecordBatch stream.
    let batches = RecordBatchIterator::new(
        vec![RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from_iter_values(0..TOTAL as i32)),
                Arc::new(StringArray::from_iter(
                    repeat(Some("hello world".to_string())).take(TOTAL),
                )),
            ],
        )
        .unwrap()]
        .into_iter()
        .map(Ok),
        schema.clone(),
    );
    Ok(Box::new(batches))
}

#[derive(Debug)]
struct MyRegistry {
    functions: HashMap<String, Arc<dyn EmbeddingFunction>>,
}
impl Default for MyRegistry {
    fn default() -> Self {
        let funcs: Vec<Arc<dyn EmbeddingFunction>> = vec![
            Arc::new(MockEmbed::new("func_1".to_string(), 1)),
            Arc::new(MockEmbed::new("func_2".to_string(), 10)),
        ];
        Self {
            functions: funcs
                .into_iter()
                .map(|f| (f.name().to_string(), f))
                .collect(),
        }
    }
}

/// A read-only mock registry whose functions are fixed at construction time.
impl EmbeddingRegistry for MyRegistry {
    fn functions(&self) -> HashSet<String> {
        self.functions.keys().cloned().collect()
    }

    fn register(&self, _name: &str, _function: Arc<dyn EmbeddingFunction>) -> Result<()> {
        Err(Error::Other {
            message: "MyRegistry is read-only".to_string(),
            source: None,
        })
    }

    fn get(&self, name: &str) -> Option<Arc<dyn EmbeddingFunction>> {
        self.functions.get(name).cloned()
    }
}

#[derive(Debug, Clone)]
struct MockEmbed {
    source_type: DataType,
    dest_type: DataType,
    name: String,
    dim: usize,
}

impl MockEmbed {
    pub fn new(name: String, dim: usize) -> Self {
        Self {
            source_type: DataType::Utf8,
            dest_type: DataType::new_fixed_size_list(DataType::Float32, dim as _, true),
            name,
            dim,
        }
    }
}

impl EmbeddingFunction for MockEmbed {
    fn name(&self) -> &str {
        &self.name
    }
    fn source_type(&self) -> Result<Cow<DataType>> {
        Ok(Cow::Borrowed(&self.source_type))
    }
    fn dest_type(&self) -> Result<Cow<DataType>> {
        Ok(Cow::Borrowed(&self.dest_type))
    }
    fn embed(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
        // We can't use the FixedSizeListBuilder here because it always adds a null bitmap
        // and we want to explicitly work with non-nullable arrays.
        let len = source.len();
        let inner = Arc::new(Float32Array::from(vec![Some(1.0); len * self.dim]));
        let field = Field::new("item", inner.data_type().clone(), false);
        let arr = FixedSizeListArray::new(
            Arc::new(field),
            self.dim as _,
            inner,
            Some(NullBuffer::new_valid(len)),
        );

        Ok(Arc::new(arr))
    }
}