mirror of https://github.com/lancedb/lancedb.git
synced 2026-01-04 02:42:57 +00:00

Compare commits: v0.1.10-py...v0.1.12 (11 commits)
| SHA1 |
|---|
| 0a03f7ca5a |
| 88be978e87 |
| 98b12caa06 |
| 091dffb171 |
| ace6aa883a |
| 80c25f9896 |
| caf22fdb71 |
| 0e7ae5dfbf |
| b261e27222 |
| 9f603f73a9 |
| 9ef846929b |
@@ -1,5 +1,5 @@
 [bumpversion]
-current_version = 0.1.10
+current_version = 0.1.12
 commit = True
 message = Bump version: {current_version} → {new_version}
 tag = True
.github/workflows/rust.yml (vendored, 1 change)
@@ -6,6 +6,7 @@ on:
       - main
   pull_request:
     paths:
+      - Cargo.toml
       - rust/**
       - .github/workflows/rust.yml
 
Cargo.toml (10 changes)
@@ -6,9 +6,9 @@ members = [
 resolver = "2"
 
 [workspace.dependencies]
-lance = "0.5.3"
-arrow-array = "40.0"
-arrow-data = "40.0"
-arrow-schema = "40.0"
-arrow-ipc = "40.0"
+lance = "=0.5.5"
+arrow-array = "42.0"
+arrow-data = "42.0"
+arrow-schema = "42.0"
+arrow-ipc = "42.0"
 object_store = "0.6.1"
@@ -43,3 +43,10 @@ pip install lancedb
 ::: lancedb.fts.populate_index
 
 ::: lancedb.fts.search_index
+
+## Utilities
+
+::: lancedb.schema.schema_to_dict
+::: lancedb.schema.dict_to_schema
+::: lancedb.vector
+
@@ -79,38 +79,32 @@ await db_setup.createTable('my_vectors', data)
     const tbl = await db.openTable("my_vectors")
 
     const results_1 = await tbl.search(Array(1536).fill(1.2))
-      .limit(20)
+      .limit(10)
       .execute()
     ```
 
 
-<!-- Commenting out for now since metricType fails for JS on Ubuntu 22.04.
-
 By default, `l2` will be used as `Metric` type. You can customize the metric type
 as well.
--->
 
-<!--
 === "Python"
--->
-<!-- ```python
+
+    ```python
     df = tbl.search(np.random.random((1536))) \
       .metric("cosine") \
       .limit(10) \
       .to_df()
     ```
--->
-<!--
-=== "JavaScript"
--->
 
-<!-- ```javascript
+=== "JavaScript"
+
+    ```javascript
     const results_2 = await tbl.search(Array(1536).fill(1.2))
       .metricType("cosine")
-      .limit(20)
+      .limit(10)
       .execute()
     ```
--->
 
 ### Search with Vector Index.
 
node/package-lock.json (generated, 4 changes)
@@ -1,12 +1,12 @@
 {
   "name": "vectordb",
-  "version": "0.1.9",
+  "version": "0.1.10",
   "lockfileVersion": 2,
   "requires": true,
   "packages": {
     "": {
       "name": "vectordb",
-      "version": "0.1.9",
+      "version": "0.1.10",
       "license": "Apache-2.0",
       "dependencies": {
         "@apache-arrow/ts": "^12.0.0",
@@ -1,6 +1,6 @@
 {
   "name": "vectordb",
-  "version": "0.1.10",
+  "version": "0.1.12",
   "description": " Serverless, low-latency vector database for AI applications",
   "main": "dist/index.js",
   "types": "dist/index.d.ts",
@@ -122,6 +122,14 @@ export interface Table<T = number[]> {
   delete: (filter: string) => Promise<void>
 }
 
+export interface AwsCredentials {
+  accessKeyId: string
+
+  secretKey: string
+
+  sessionToken?: string
+}
+
 /**
  * A connection to a LanceDB database.
  */
@@ -158,13 +166,10 @@ export class LocalConnection implements Connection {
    * @param embeddings An embedding function to use on this Table
    */
  async openTable<T> (name: string, embeddings: EmbeddingFunction<T>): Promise<Table<T>>
-  async openTable<T> (name: string, embeddings?: EmbeddingFunction<T>): Promise<Table<T>> {
+  async openTable<T> (name: string, embeddings?: EmbeddingFunction<T>, awsCredentials?: AwsCredentials): Promise<Table<T>>
+  async openTable<T> (name: string, embeddings?: EmbeddingFunction<T>, awsCredentials?: AwsCredentials): Promise<Table<T>> {
     const tbl = await databaseOpenTable.call(this._db, name)
-    if (embeddings !== undefined) {
-      return new LocalTable(tbl, name, embeddings)
-    } else {
-      return new LocalTable(tbl, name)
-    }
+    return new LocalTable(tbl, name, embeddings, awsCredentials)
   }
 
   /**
@@ -186,16 +191,24 @@ export class LocalConnection implements Connection {
    * @param embeddings An embedding function to use on this Table
    */
  async createTable<T> (name: string, data: Array<Record<string, unknown>>, mode: WriteMode, embeddings: EmbeddingFunction<T>): Promise<Table<T>>
-  async createTable<T> (name: string, data: Array<Record<string, unknown>>, mode: WriteMode, embeddings?: EmbeddingFunction<T>): Promise<Table<T>> {
+  async createTable<T> (name: string, data: Array<Record<string, unknown>>, mode: WriteMode, embeddings?: EmbeddingFunction<T>, awsCredentials?: AwsCredentials): Promise<Table<T>>
+  async createTable<T> (name: string, data: Array<Record<string, unknown>>, mode: WriteMode, embeddings?: EmbeddingFunction<T>, awsCredentials?: AwsCredentials): Promise<Table<T>> {
     if (mode === undefined) {
       mode = WriteMode.Create
     }
-    const tbl = await tableCreate.call(this._db, name, await fromRecordsToBuffer(data, embeddings), mode.toLowerCase())
-    if (embeddings !== undefined) {
-      return new LocalTable(tbl, name, embeddings)
-    } else {
-      return new LocalTable(tbl, name)
-    }
+    const createArgs = [this._db, name, await fromRecordsToBuffer(data, embeddings), mode.toLowerCase()]
+    if (awsCredentials !== undefined) {
+      createArgs.push(awsCredentials.accessKeyId)
+      createArgs.push(awsCredentials.secretKey)
+      if (awsCredentials.sessionToken !== undefined) {
+        createArgs.push(awsCredentials.sessionToken)
+      }
+    }
+
+    const tbl = await tableCreate.call(...createArgs)
+
+    return new LocalTable(tbl, name, embeddings, awsCredentials)
   }
 
   async createTableArrow (name: string, table: ArrowTable): Promise<Table> {
@@ -217,6 +230,7 @@ export class LocalTable<T = number[]> implements Table<T> {
   private readonly _tbl: any
   private readonly _name: string
   private readonly _embeddings?: EmbeddingFunction<T>
+  private readonly _awsCredentials?: AwsCredentials
 
   constructor (tbl: any, name: string)
   /**
@@ -225,10 +239,12 @@ export class LocalTable<T = number[]> implements Table<T> {
    * @param embeddings An embedding function to use when interacting with this table
    */
  constructor (tbl: any, name: string, embeddings: EmbeddingFunction<T>)
-  constructor (tbl: any, name: string, embeddings?: EmbeddingFunction<T>) {
+  constructor (tbl: any, name: string, embeddings?: EmbeddingFunction<T>, awsCredentials?: AwsCredentials)
+  constructor (tbl: any, name: string, embeddings?: EmbeddingFunction<T>, awsCredentials?: AwsCredentials) {
     this._tbl = tbl
     this._name = name
     this._embeddings = embeddings
+    this._awsCredentials = awsCredentials
   }
 
   get name (): string {
@@ -250,7 +266,15 @@ export class LocalTable<T = number[]> implements Table<T> {
    * @return The number of rows added to the table
    */
  async add (data: Array<Record<string, unknown>>): Promise<number> {
-    return tableAdd.call(this._tbl, await fromRecordsToBuffer(data, this._embeddings), WriteMode.Append.toString())
+    const callArgs = [this._tbl, await fromRecordsToBuffer(data, this._embeddings), WriteMode.Append.toString()]
+    if (this._awsCredentials !== undefined) {
+      callArgs.push(this._awsCredentials.accessKeyId)
+      callArgs.push(this._awsCredentials.secretKey)
+      if (this._awsCredentials.sessionToken !== undefined) {
+        callArgs.push(this._awsCredentials.sessionToken)
+      }
+    }
+    return tableAdd.call(...callArgs)
   }
 
   /**
@@ -260,6 +284,14 @@ export class LocalTable<T = number[]> implements Table<T> {
    * @return The number of rows added to the table
    */
  async overwrite (data: Array<Record<string, unknown>>): Promise<number> {
+    const callArgs = [this._tbl, await fromRecordsToBuffer(data, this._embeddings), WriteMode.Overwrite.toString()]
+    if (this._awsCredentials !== undefined) {
+      callArgs.push(this._awsCredentials.accessKeyId)
+      callArgs.push(this._awsCredentials.secretKey)
+      if (this._awsCredentials.sessionToken !== undefined) {
+        callArgs.push(this._awsCredentials.sessionToken)
+      }
+    }
     return tableAdd.call(this._tbl, await fromRecordsToBuffer(data, this._embeddings), WriteMode.Overwrite.toString())
   }
 
@@ -15,6 +15,7 @@ from typing import Optional
 
 from .db import URI, DBConnection, LanceDBConnection
 from .remote.db import RemoteDBConnection
+from .schema import vector
 
 
 def connect(
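Together with the new `## Utilities` docs entries above, this re-export makes the vector type helper available at the package root. A small sketch mirroring the docstring added in `python/lancedb/schema.py` below:

```python
import pyarrow as pa

import lancedb

# lancedb.vector(dim) is shorthand for pa.list_(pa.float32(), dim),
# i.e. a fixed-size-list ("vector") column type.
schema = pa.schema([
    pa.field("id", pa.int64()),
    pa.field("vector", lancedb.vector(1536)),
])
```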
@@ -13,11 +13,12 @@
 
 
 import functools
-from typing import Dict
+from typing import Any, Callable, Dict, Union
 
 import aiohttp
 import attr
 import pyarrow as pa
+from pydantic import BaseModel
 
 from lancedb.common import Credential
 from lancedb.remote import VectorQuery, VectorQueryResult
@@ -34,6 +35,12 @@ def _check_not_closed(f):
        return wrapped
 
 
+async def _read_ipc(resp: aiohttp.ClientResponse) -> pa.Table:
+    resp_body = await resp.read()
+    with pa.ipc.open_file(pa.BufferReader(resp_body)) as reader:
+        return reader.read_all()
+
+
 @attr.define(slots=False)
 class RestfulLanceDBClient:
     db_name: str
@@ -56,28 +63,67 @@ class RestfulLanceDBClient:
             "x-api-key": self.api_key,
         }
 
+    @staticmethod
+    async def _check_status(resp: aiohttp.ClientResponse):
+        if resp.status == 404:
+            raise LanceDBClientError(f"Not found: {await resp.text()}")
+        elif 400 <= resp.status < 500:
+            raise LanceDBClientError(
+                f"Bad Request: {resp.status}, error: {await resp.text()}"
+            )
+        elif 500 <= resp.status < 600:
+            raise LanceDBClientError(
+                f"Internal Server Error: {resp.status}, error: {await resp.text()}"
+            )
+        elif resp.status != 200:
+            raise LanceDBClientError(
+                f"Unknown Error: {resp.status}, error: {await resp.text()}"
+            )
+
     @_check_not_closed
-    async def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
+    async def get(self, uri: str, params: Union[Dict[str, Any], BaseModel] = None):
+        """Send a GET request and returns the deserialized response payload."""
+        if isinstance(params, BaseModel):
+            params: Dict[str, Any] = params.dict(exclude_none=True)
+        async with self.session.get(uri, params=params, headers=self.headers) as resp:
+            await self._check_status(resp)
+            return await resp.json()
+
+    @_check_not_closed
+    async def post(
+        self,
+        uri: str,
+        data: Union[Dict[str, Any], BaseModel],
+        deserialize: Callable = lambda resp: resp.json(),
+    ) -> Dict[str, Any]:
+        """Send a POST request and returns the deserialized response payload.
+
+        Parameters
+        ----------
+        uri : str
+            The uri to send the POST request to.
+        data: Union[Dict[str, Any], BaseModel]
+
+        """
+        if isinstance(data, BaseModel):
+            data: Dict[str, Any] = data.dict(exclude_none=True)
         async with self.session.post(
-            f"/1/table/{table_name}/",
-            json=query.dict(exclude_none=True),
+            uri,
+            json=data,
             headers=self.headers,
         ) as resp:
             resp: aiohttp.ClientResponse = resp
-            if 400 <= resp.status < 500:
-                raise LanceDBClientError(
-                    f"Bad Request: {resp.status}, error: {await resp.text()}"
-                )
-            if 500 <= resp.status < 600:
-                raise LanceDBClientError(
-                    f"Internal Server Error: {resp.status}, error: {await resp.text()}"
-                )
-            if resp.status != 200:
-                raise LanceDBClientError(
-                    f"Unknown Error: {resp.status}, error: {await resp.text()}"
-                )
-
-            resp_body = await resp.read()
-            with pa.ipc.open_file(pa.BufferReader(resp_body)) as reader:
-                tbl = reader.read_all()
+            await self._check_status(resp)
+            return await deserialize(resp)
+
+    @_check_not_closed
+    async def list_tables(self):
+        """List all tables in the database."""
+        json = await self.get("/1/table/", {})
+        return json["tables"]
+
+    @_check_not_closed
+    async def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
+        """Query a table."""
+        tbl = await self.post(f"/1/table/{table_name}/", query, deserialize=_read_ipc)
         return VectorQueryResult(tbl)
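The refactor centralizes HTTP status handling in `_check_status` and splits transport into reusable `get`/`post` helpers, so `query` shrinks to a `post` call with an Arrow-IPC deserializer. A minimal usage sketch; the db name, region, and key here are placeholders, and the constructor order follows the `RestfulLanceDBClient(self.db_name, region, api_key)` call site shown in `remote/db.py` below:

```python
import asyncio

from lancedb.remote.client import RestfulLanceDBClient


async def main() -> None:
    # Placeholder values; argument order (db_name, region, api_key) matches
    # the call site in remote/db.py.
    client = RestfulLanceDBClient("my-db", "us-east-1", "sk-placeholder")
    # GET /1/table/ -> {"tables": [...]}; errors surface as LanceDBClientError.
    print(await client.list_tables())
    # client.query(...) now delegates to post(..., deserialize=_read_ipc),
    # which parses the Arrow IPC response body into a pyarrow.Table.


asyncio.run(main())
```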
@@ -11,6 +11,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import asyncio
 from typing import List
 from urllib.parse import urlparse
 
@@ -34,12 +35,18 @@ class RemoteDBConnection(DBConnection):
         self.db_name = parsed.netloc
         self.api_key = api_key
         self._client = RestfulLanceDBClient(self.db_name, region, api_key)
+        try:
+            self._loop = asyncio.get_running_loop()
+        except RuntimeError:
+            self._loop = asyncio.get_event_loop()
 
     def __repr__(self) -> str:
         return f"RemoveConnect(name={self.db_name})"
 
     def table_names(self) -> List[str]:
-        raise NotImplementedError
+        """List the names of all tables in the database."""
+        result = self._loop.run_until_complete(self._client.list_tables())
+        return result
 
     def open_table(self, name: str) -> Table:
         """Open a Lance Table in the database.
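`RemoteDBConnection` now grabs (or creates) an event loop once in `__init__` and drives every synchronous wrapper through it: `table_names` above and the `RemoteTable._execute_query` change below both reuse `self._loop`. A sketch of the resulting sync-over-async API; the URI, key, and region are placeholders, and the keyword names are assumptions based on the attributes used in the diff:

```python
from lancedb.remote.db import RemoteDBConnection

# Hypothetical arguments: the diff shows the URI's netloc becoming db_name,
# with api_key and region forwarded to RestfulLanceDBClient.
db = RemoteDBConnection("db://my-db", api_key="sk-placeholder", region="us-east-1")
print(db.table_names())  # blocks on db._loop.run_until_complete(list_tables())
```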
@@ -11,7 +11,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import asyncio
 from typing import Union
 
 import pyarrow as pa
@@ -62,9 +61,5 @@ class RemoteTable(Table):
         return LanceQueryBuilder(self, query, vector_column)
 
     def _execute_query(self, query: Query) -> pa.Table:
-        try:
-            loop = asyncio.get_running_loop()
-        except RuntimeError:
-            loop = asyncio.get_event_loop()
         result = self._conn._client.query(self._name, query)
-        return loop.run_until_complete(result).to_arrow()
+        return self._conn._loop.run_until_complete(result).to_arrow()
python/lancedb/schema.py (new file, 289 lines)
@@ -0,0 +1,289 @@
+# Copyright 2023 LanceDB Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Schema related utilities."""
+
+import json
+from typing import Any, Dict, Type
+
+import pyarrow as pa
+
+
+def vector(dimension: int, value_type: pa.DataType = pa.float32()) -> pa.DataType:
+    """A helper function to create a vector type.
+
+    Parameters
+    ----------
+    dimension: The dimension of the vector.
+    value_type: pa.DataType, optional
+        The type of the value in the vector.
+
+    Returns
+    -------
+    A PyArrow DataType for vectors.
+
+    Examples
+    --------
+
+    >>> import pyarrow as pa
+    >>> import lancedb
+    >>> schema = pa.schema([
+    ...     pa.field("id", pa.int64()),
+    ...     pa.field("vector", lancedb.vector(756)),
+    ... ])
+    """
+    return pa.list_(value_type, dimension)
+
+
+def _type_to_dict(dt: pa.DataType) -> Dict[str, Any]:
+    if pa.types.is_boolean(dt):
+        return {"type": "boolean"}
+    elif pa.types.is_int8(dt):
+        return {"type": "int8"}
+    elif pa.types.is_int16(dt):
+        return {"type": "int16"}
+    elif pa.types.is_int32(dt):
+        return {"type": "int32"}
+    elif pa.types.is_int64(dt):
+        return {"type": "int64"}
+    elif pa.types.is_uint8(dt):
+        return {"type": "uint8"}
+    elif pa.types.is_uint16(dt):
+        return {"type": "uint16"}
+    elif pa.types.is_uint32(dt):
+        return {"type": "uint32"}
+    elif pa.types.is_uint64(dt):
+        return {"type": "uint64"}
+    elif pa.types.is_float16(dt):
+        return {"type": "float16"}
+    elif pa.types.is_float32(dt):
+        return {"type": "float32"}
+    elif pa.types.is_float64(dt):
+        return {"type": "float64"}
+    elif pa.types.is_date32(dt):
+        return {"type": f"date32"}
+    elif pa.types.is_date64(dt):
+        return {"type": f"date64"}
+    elif pa.types.is_time32(dt):
+        return {"type": f"time32:{dt.unit}"}
+    elif pa.types.is_time64(dt):
+        return {"type": f"time64:{dt.unit}"}
+    elif pa.types.is_timestamp(dt):
+        return {"type": f"timestamp:{dt.unit}:{dt.tz if dt.tz is not None else ''}"}
+    elif pa.types.is_string(dt):
+        return {"type": "string"}
+    elif pa.types.is_binary(dt):
+        return {"type": "binary"}
+    elif pa.types.is_large_string(dt):
+        return {"type": "large_string"}
+    elif pa.types.is_large_binary(dt):
+        return {"type": "large_binary"}
+    elif pa.types.is_fixed_size_binary(dt):
+        return {"type": "fixed_size_binary", "width": dt.byte_width}
+    elif pa.types.is_fixed_size_list(dt):
+        return {
+            "type": "fixed_size_list",
+            "width": dt.list_size,
+            "value_type": _type_to_dict(dt.value_type),
+        }
+    elif pa.types.is_list(dt):
+        return {
+            "type": "list",
+            "value_type": _type_to_dict(dt.value_type),
+        }
+    elif pa.types.is_struct(dt):
+        return {
+            "type": "struct",
+            "fields": [_field_to_dict(dt.field(i)) for i in range(dt.num_fields)],
+        }
+    elif pa.types.is_dictionary(dt):
+        return {
+            "type": "dictionary",
+            "index_type": _type_to_dict(dt.index_type),
+            "value_type": _type_to_dict(dt.value_type),
+        }
+    # TODO: support extension types
+
+    raise TypeError(f"Unsupported type: {dt}")
+
+
+def _field_to_dict(field: pa.field) -> Dict[str, Any]:
+    ret = {
+        "name": field.name,
+        "type": _type_to_dict(field.type),
+        "nullable": field.nullable,
+    }
+    if field.metadata is not None:
+        ret["metadata"] = field.metadata
+    return ret
+
+
+def schema_to_dict(schema: pa.Schema) -> Dict[str, Any]:
+    """Convert a PyArrow [Schema](pyarrow.Schema) to a dictionary.
+
+    Parameters
+    ----------
+    schema : pa.Schema
+        The PyArrow Schema to convert
+
+    Returns
+    -------
+    A dict of the data type.
+
+    Examples
+    --------
+
+    >>> import pyarrow as pa
+    >>> import lancedb
+    >>> schema = pa.schema(
+    ...     [
+    ...         pa.field("id", pa.int64()),
+    ...         pa.field("vector", lancedb.vector(512), nullable=False),
+    ...         pa.field(
+    ...             "struct",
+    ...             pa.struct(
+    ...                 [
+    ...                     pa.field("a", pa.utf8()),
+    ...                     pa.field("b", pa.float32()),
+    ...                 ]
+    ...             ),
+    ...             True,
+    ...         ),
+    ...     ],
+    ...     metadata={"key": "value"},
+    ... )
+    >>> json_schema = schema_to_dict(schema)
+    >>> assert json_schema == {
+    ...     "fields": [
+    ...         {"name": "id", "type": {"type": "int64"}, "nullable": True},
+    ...         {
+    ...             "name": "vector",
+    ...             "type": {
+    ...                 "type": "fixed_size_list",
+    ...                 "value_type": {"type": "float32"},
+    ...                 "width": 512,
+    ...             },
+    ...             "nullable": False,
+    ...         },
+    ...         {
+    ...             "name": "struct",
+    ...             "type": {
+    ...                 "type": "struct",
+    ...                 "fields": [
+    ...                     {"name": "a", "type": {"type": "string"}, "nullable": True},
+    ...                     {"name": "b", "type": {"type": "float32"}, "nullable": True},
+    ...                 ],
+    ...             },
+    ...             "nullable": True,
+    ...         },
+    ...     ],
+    ...     "metadata": {"key": "value"},
+    ... }
+
+    """
+    fields = []
+    for name in schema.names:
+        field = schema.field(name)
+        fields.append(_field_to_dict(field))
+    json_schema = {
+        "fields": fields,
+        "metadata": {
+            k.decode("utf-8"): v.decode("utf-8") for (k, v) in schema.metadata.items()
+        }
+        if schema.metadata is not None
+        else {},
+    }
+    return json_schema
+
+
+def _dict_to_type(dt: Dict[str, Any]) -> pa.DataType:
+    type_name = dt["type"]
+    try:
+        return {
+            "boolean": pa.bool_(),
+            "int8": pa.int8(),
+            "int16": pa.int16(),
+            "int32": pa.int32(),
+            "int64": pa.int64(),
+            "uint8": pa.uint8(),
+            "uint16": pa.uint16(),
+            "uint32": pa.uint32(),
+            "uint64": pa.uint64(),
+            "float16": pa.float16(),
+            "float32": pa.float32(),
+            "float64": pa.float64(),
+            "string": pa.string(),
+            "binary": pa.binary(),
+            "large_string": pa.large_string(),
+            "large_binary": pa.large_binary(),
+            "date32": pa.date32(),
+            "date64": pa.date64(),
+        }[type_name]
+    except KeyError:
+        pass
+
+    if type_name == "fixed_size_binary":
+        return pa.binary(dt["width"])
+    elif type_name == "fixed_size_list":
+        return pa.list_(_dict_to_type(dt["value_type"]), dt["width"])
+    elif type_name == "list":
+        return pa.list_(_dict_to_type(dt["value_type"]))
+    elif type_name == "struct":
+        fields = []
+        for field in dt["fields"]:
+            fields.append(_dict_to_field(field))
+        return pa.struct(fields)
+    elif type_name == "dictionary":
+        return pa.dictionary(
+            _dict_to_type(dt["index_type"]), _dict_to_type(dt["value_type"])
+        )
+    elif type_name.startswith("time32:"):
+        return pa.time32(type_name.split(":")[1])
+    elif type_name.startswith("time64:"):
+        return pa.time64(type_name.split(":")[1])
+    elif type_name.startswith("timestamp:"):
+        fields = type_name.split(":")
+        unit = fields[1]
+        tz = fields[2] if len(fields) > 2 else None
+        return pa.timestamp(unit, tz)
+    raise TypeError(f"Unsupported type: {dt}")
+
+
+def _dict_to_field(field: Dict[str, Any]) -> pa.Field:
+    name = field["name"]
+    nullable = field["nullable"] if "nullable" in field else True
+    dt = _dict_to_type(field["type"])
+    metadata = field.get("metadata", None)
+    return pa.field(name, dt, nullable, metadata)
+
+
+def dict_to_schema(json: Dict[str, Any]) -> pa.Schema:
+    """Reconstruct a PyArrow Schema from a JSON dict.
+
+    Parameters
+    ----------
+    json : Dict[str, Any]
+        The JSON dict to reconstruct Schema from.
+
+    Returns
+    -------
+    A PyArrow Schema.
+    """
+    fields = []
+    for field in json["fields"]:
+        fields.append(_dict_to_field(field))
+    metadata = {
+        k.encode("utf-8"): v.encode("utf-8")
+        for (k, v) in json.get("metadata", {}).items()
+    }
+    return pa.schema(fields, metadata)
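Taken together, `schema_to_dict` and `dict_to_schema` are inverses for the types handled above. A minimal round-trip sketch using only the functions in this file, with a JSON hop to show that the resulting dict is JSON-safe:

```python
import json

import pyarrow as pa

import lancedb
from lancedb.schema import dict_to_schema, schema_to_dict

schema = pa.schema(
    [
        pa.field("id", pa.int64()),
        pa.field("vector", lancedb.vector(512), nullable=False),
    ],
    metadata={"key": "value"},
)

# schema -> dict -> JSON string -> dict -> schema round-trips losslessly.
payload = json.dumps(schema_to_dict(schema))
assert dict_to_schema(json.loads(payload)) == schema
```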
python/tests/test_schema.py (new file, 109 lines)
@@ -0,0 +1,109 @@
+# Copyright 2023 LanceDB Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pyarrow as pa
+
+import lancedb
+from lancedb.schema import dict_to_schema, schema_to_dict
+
+
+def test_schema_to_dict():
+    schema = pa.schema(
+        [
+            pa.field("id", pa.int64()),
+            pa.field("vector", lancedb.vector(512), nullable=False),
+            pa.field(
+                "struct",
+                pa.struct(
+                    [
+                        pa.field("a", pa.utf8()),
+                        pa.field("b", pa.float32()),
+                    ]
+                ),
+                True,
+            ),
+            pa.field("d", pa.dictionary(pa.int64(), pa.utf8()), False),
+        ],
+        metadata={"key": "value"},
+    )
+
+    json_schema = schema_to_dict(schema)
+    assert json_schema == {
+        "fields": [
+            {"name": "id", "type": {"type": "int64"}, "nullable": True},
+            {
+                "name": "vector",
+                "type": {
+                    "type": "fixed_size_list",
+                    "value_type": {"type": "float32"},
+                    "width": 512,
+                },
+                "nullable": False,
+            },
+            {
+                "name": "struct",
+                "type": {
+                    "type": "struct",
+                    "fields": [
+                        {"name": "a", "type": {"type": "string"}, "nullable": True},
+                        {"name": "b", "type": {"type": "float32"}, "nullable": True},
+                    ],
+                },
+                "nullable": True,
+            },
+            {
+                "name": "d",
+                "type": {
+                    "type": "dictionary",
+                    "index_type": {"type": "int64"},
+                    "value_type": {"type": "string"},
+                },
+                "nullable": False,
+            },
+        ],
+        "metadata": {"key": "value"},
+    }
+
+    actual_schema = dict_to_schema(json_schema)
+    assert actual_schema == schema
+
+
+def test_temporal_types():
+    schema = pa.schema(
+        [
+            pa.field("t32", pa.time32("s")),
+            pa.field("t32ms", pa.time32("ms")),
+            pa.field("t64", pa.time64("ns")),
+            pa.field("ts", pa.timestamp("s")),
+            pa.field("ts_us_tz", pa.timestamp("us", tz="America/New_York")),
+        ],
+    )
+    json_schema = schema_to_dict(schema)
+
+    assert json_schema == {
+        "fields": [
+            {"name": "t32", "type": {"type": "time32:s"}, "nullable": True},
+            {"name": "t32ms", "type": {"type": "time32:ms"}, "nullable": True},
+            {"name": "t64", "type": {"type": "time64:ns"}, "nullable": True},
+            {"name": "ts", "type": {"type": "timestamp:s:"}, "nullable": True},
+            {
+                "name": "ts_us_tz",
+                "type": {"type": "timestamp:us:America/New_York"},
+                "nullable": True,
+            },
+        ],
+        "metadata": {},
+    }
+
+    actual_schema = dict_to_schema(json_schema)
+    assert actual_schema == schema
@@ -1,6 +1,6 @@
 [package]
 name = "vectordb-node"
-version = "0.1.10"
+version = "0.1.12"
 description = "Serverless, low-latency vector database for AI applications"
 license = "Apache-2.0"
 edition = "2018"
@@ -19,3 +19,6 @@ lance = { workspace = true }
 vectordb = { path = "../../vectordb" }
 tokio = { version = "1.23", features = ["rt-multi-thread"] }
 neon = {version = "0.10.1", default-features = false, features = ["channel-api", "napi-6", "promise-api", "task-api"] }
+object_store = { workspace = true, features = ["aws"] }
+async-trait = "0"
+env_logger = "0"
@@ -13,7 +13,6 @@
 // limitations under the License.
 
 use std::io::Cursor;
-use std::ops::Deref;
 use std::sync::Arc;
 
 use arrow_array::cast::as_list_array;
@@ -25,10 +24,13 @@ use lance::arrow::{FixedSizeListArrayExt, RecordBatchExt};
 pub(crate) fn convert_record_batch(record_batch: RecordBatch) -> RecordBatch {
     let column = record_batch
         .column_by_name("vector")
+        .cloned()
         .expect("vector column is missing");
-    let arr = as_list_array(column.deref());
+    // TODO: we should just consume the underlying JS buffer in the future
+    // instead of copying the Arrow data around a bunch of times
+    let arr = as_list_array(column.as_ref());
     let list_size = arr.values().len() / record_batch.num_rows();
-    let r = FixedSizeListArray::try_new(arr.values(), list_size as i32).unwrap();
+    let r =
+        FixedSizeListArray::try_new_from_values(arr.values().to_owned(), list_size as i32).unwrap();
 
     let schema = Arc::new(Schema::new(vec![Field::new(
         "vector",
@@ -17,19 +17,23 @@ use std::convert::TryFrom;
 use std::ops::Deref;
 use std::sync::{Arc, Mutex};
 
-use arrow_array::{Float32Array, RecordBatchIterator, RecordBatchReader};
+use arrow_array::{Float32Array, RecordBatchIterator};
 use arrow_ipc::writer::FileWriter;
+use async_trait::async_trait;
 use futures::{TryFutureExt, TryStreamExt};
-use lance::dataset::{WriteMode, WriteParams};
+use lance::dataset::{ReadParams, WriteMode, WriteParams};
 use lance::index::vector::MetricType;
+use lance::io::object_store::ObjectStoreParams;
 use neon::prelude::*;
 use neon::types::buffer::TypedArray;
+use object_store::aws::{AwsCredential, AwsCredentialProvider};
+use object_store::CredentialProvider;
 use once_cell::sync::OnceCell;
 use tokio::runtime::Runtime;
 
 use vectordb::database::Database;
 use vectordb::error::Error;
-use vectordb::table::Table;
+use vectordb::table::{OpenTableParams, Table};
 
 use crate::arrow::arrow_buffer_to_record_batch;
@@ -49,8 +53,38 @@ struct JsTable {
 
 impl Finalize for JsTable {}
 
+// TODO: object_store didn't export this type so I copied it.
+// Make a request to object_store to export this type
+#[derive(Debug)]
+pub struct StaticCredentialProvider<T> {
+    credential: Arc<T>,
+}
+
+impl<T> StaticCredentialProvider<T> {
+    pub fn new(credential: T) -> Self {
+        Self {
+            credential: Arc::new(credential),
+        }
+    }
+}
+
+#[async_trait]
+impl<T> CredentialProvider for StaticCredentialProvider<T>
+where
+    T: std::fmt::Debug + Send + Sync,
+{
+    type Credential = T;
+
+    async fn get_credential(&self) -> object_store::Result<Arc<T>> {
+        Ok(Arc::clone(&self.credential))
+    }
+}
+
 fn runtime<'a, C: Context<'a>>(cx: &mut C) -> NeonResult<&'static Runtime> {
     static RUNTIME: OnceCell<Runtime> = OnceCell::new();
+    static LOG: OnceCell<()> = OnceCell::new();
+
+    LOG.get_or_init(|| env_logger::init());
+
     RUNTIME.get_or_try_init(|| Runtime::new().or_else(|err| cx.throw_error(err.to_string())))
 }
@@ -97,19 +131,74 @@ fn database_table_names(mut cx: FunctionContext) -> JsResult<JsPromise> {
     Ok(promise)
 }
 
+fn get_aws_creds<T>(
+    cx: &mut FunctionContext,
+    arg_starting_location: i32,
+) -> Result<Option<AwsCredentialProvider>, NeonResult<T>> {
+    let secret_key_id = cx
+        .argument_opt(arg_starting_location)
+        .map(|arg| arg.downcast_or_throw::<JsString, FunctionContext>(cx).ok())
+        .flatten()
+        .map(|v| v.value(cx));
+
+    let secret_key = cx
+        .argument_opt(arg_starting_location + 1)
+        .map(|arg| arg.downcast_or_throw::<JsString, FunctionContext>(cx).ok())
+        .flatten()
+        .map(|v| v.value(cx));
+
+    let temp_token = cx
+        .argument_opt(arg_starting_location + 2)
+        .map(|arg| arg.downcast_or_throw::<JsString, FunctionContext>(cx).ok())
+        .flatten()
+        .map(|v| v.value(cx));
+
+    match (secret_key_id, secret_key, temp_token) {
+        (Some(key_id), Some(key), optional_token) => Ok(Some(Arc::new(
+            StaticCredentialProvider::new(AwsCredential {
+                key_id: key_id,
+                secret_key: key,
+                token: optional_token,
+            }),
+        ))),
+        (None, None, None) => Ok(None),
+        _ => Err(cx.throw_error("Invalid credentials configuration")),
+    }
+}
+
 fn database_open_table(mut cx: FunctionContext) -> JsResult<JsPromise> {
     let db = cx
         .this()
         .downcast_or_throw::<JsBox<JsDatabase>, _>(&mut cx)?;
     let table_name = cx.argument::<JsString>(0)?.value(&mut cx);
+
+    let aws_creds = match get_aws_creds(&mut cx, 1) {
+        Ok(creds) => creds,
+        Err(err) => return err,
+    };
+
+    let param = ReadParams {
+        store_options: Some(ObjectStoreParams {
+            aws_credentials: aws_creds,
+            ..ObjectStoreParams::default()
+        }),
+        ..ReadParams::default()
+    };
+
     let rt = runtime(&mut cx)?;
     let channel = cx.channel();
     let database = db.database.clone();
 
     let (deferred, promise) = cx.promise();
     rt.spawn(async move {
-        let table_rst = database.open_table(&table_name).await;
+        let table_rst = database
+            .open_table_with_params(
+                &table_name,
+                OpenTableParams {
+                    open_table_params: param,
+                },
+            )
+            .await;
 
         deferred.settle_with(&channel, move |mut cx| {
             let table = Arc::new(Mutex::new(
@@ -241,8 +330,6 @@ fn table_create(mut cx: FunctionContext) -> JsResult<JsPromise> {
         "create" => WriteMode::Create,
         _ => return cx.throw_error("Table::create only supports 'overwrite' and 'create' modes"),
     };
-    let mut params = WriteParams::default();
-    params.mode = mode;
 
     let rt = runtime(&mut cx)?;
     let channel = cx.channel();
@@ -250,11 +337,22 @@ fn table_create(mut cx: FunctionContext) -> JsResult<JsPromise> {
     let (deferred, promise) = cx.promise();
     let database = db.database.clone();
 
+    let aws_creds = match get_aws_creds(&mut cx, 3) {
+        Ok(creds) => creds,
+        Err(err) => return err,
+    };
+
+    let params = WriteParams {
+        store_params: Some(ObjectStoreParams {
+            aws_credentials: aws_creds,
+            ..ObjectStoreParams::default()
+        }),
+        mode: mode,
+        ..WriteParams::default()
+    };
+
     rt.block_on(async move {
-        let batch_reader: Box<dyn RecordBatchReader> = Box::new(RecordBatchIterator::new(
-            batches.into_iter().map(Ok),
-            schema,
-        ));
+        let batch_reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema);
         let table_rst = database
             .create_table(&table_name, batch_reader, Some(params))
             .await;
@@ -289,16 +387,27 @@ fn table_add(mut cx: FunctionContext) -> JsResult<JsPromise> {
     let table = js_table.table.clone();
     let write_mode = write_mode_map.get(write_mode.as_str()).cloned();
 
+    let aws_creds = match get_aws_creds(&mut cx, 2) {
+        Ok(creds) => creds,
+        Err(err) => return err,
+    };
+
+    let params = WriteParams {
+        store_params: Some(ObjectStoreParams {
+            aws_credentials: aws_creds,
+            ..ObjectStoreParams::default()
+        }),
+        mode: write_mode.unwrap_or(WriteMode::Append),
+        ..WriteParams::default()
+    };
+
     rt.block_on(async move {
-        let batch_reader: Box<dyn RecordBatchReader> = Box::new(RecordBatchIterator::new(
-            batches.into_iter().map(Ok),
-            schema,
-        ));
-        let add_result = table.lock().unwrap().add(batch_reader, write_mode).await;
+        let batch_reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema);
+        let add_result = table.lock().unwrap().add(batch_reader, Some(params)).await;
 
         deferred.settle_with(&channel, move |mut cx| {
-            let added = add_result.or_else(|err| cx.throw_error(err.to_string()))?;
-            Ok(cx.number(added as f64))
+            let _added = add_result.or_else(|err| cx.throw_error(err.to_string()))?;
+            Ok(cx.boolean(true))
         });
     });
     Ok(promise)
@@ -1,6 +1,6 @@
 [package]
 name = "vectordb"
-version = "0.1.10"
+version = "0.1.12"
 edition = "2021"
 description = "Serverless, low-latency vector database for AI applications"
 license = "Apache-2.0"
@@ -100,7 +100,7 @@ impl Database {
     pub async fn create_table(
         &self,
         name: &str,
-        batches: Box<dyn RecordBatchReader>,
+        batches: impl RecordBatchReader + Send + 'static,
         params: Option<WriteParams>,
     ) -> Result<Table> {
         Table::create(&self.uri, name, batches, params).await
@@ -173,10 +173,8 @@ mod tests {
 
     #[tokio::test]
     async fn test_setters_getters() {
-        let mut batches: Box<dyn RecordBatchReader> = make_test_batches();
-        let ds = Dataset::write(&mut batches, "memory://foo", None)
-            .await
-            .unwrap();
+        let batches = make_test_batches();
+        let ds = Dataset::write(batches, "memory://foo", None).await.unwrap();
 
         let vector = Float32Array::from_iter_values([0.1, 0.2]);
         let query = Query::new(Arc::new(ds), vector.clone());
@@ -202,10 +200,8 @@ mod tests {
 
     #[tokio::test]
     async fn test_execute() {
-        let mut batches: Box<dyn RecordBatchReader> = make_test_batches();
-        let ds = Dataset::write(&mut batches, "memory://foo", None)
-            .await
-            .unwrap();
+        let batches = make_test_batches();
+        let ds = Dataset::write(batches, "memory://foo", None).await.unwrap();
 
         let vector = Float32Array::from_iter_values([0.1; 128]);
         let query = Query::new(Arc::new(ds), vector.clone());
@@ -213,7 +209,7 @@ mod tests {
         assert_eq!(result.is_ok(), true);
     }
 
-    fn make_test_batches() -> Box<dyn RecordBatchReader> {
+    fn make_test_batches() -> impl RecordBatchReader + Send + 'static {
         let dim: usize = 128;
         let schema = Arc::new(ArrowSchema::new(vec![
             ArrowField::new("key", DataType::Int32, false),
@@ -227,11 +223,11 @@ mod tests {
             ),
             ArrowField::new("uri", DataType::Utf8, true),
         ]));
-        Box::new(RecordBatchIterator::new(
+        RecordBatchIterator::new(
             vec![RecordBatch::new_empty(schema.clone())]
                 .into_iter()
                 .map(Ok),
             schema,
-        ))
+        )
     }
 }
@@ -22,8 +22,8 @@ use snafu::prelude::*;
 
 use crate::error::{Error, InvalidTableNameSnafu, Result};
 use crate::index::vector::VectorIndexBuilder;
-use crate::WriteMode;
 use crate::query::Query;
+use crate::WriteMode;
 
 pub const VECTOR_COLUMN_NAME: &str = "vector";
 pub const LANCE_FILE_EXTENSION: &str = "lance";
@@ -117,7 +117,7 @@ impl Table {
     pub async fn create(
         base_uri: &str,
         name: &str,
-        mut batches: Box<dyn RecordBatchReader>,
+        batches: impl RecordBatchReader + Send + 'static,
         params: Option<WriteParams>,
     ) -> Result<Self> {
         let base_path = Path::new(base_uri);
@@ -127,7 +127,7 @@ impl Table {
             .to_str()
             .context(InvalidTableNameSnafu { name })?
             .to_string();
-        let dataset = Dataset::write(&mut batches, &uri, params)
+        let dataset = Dataset::write(batches, &uri, params)
             .await
             .map_err(|e| match e {
                 lance::Error::DatasetAlreadyExists { .. } => Error::TableAlreadyExists {
@@ -176,14 +176,16 @@ impl Table {
     /// * The number of rows added
     pub async fn add(
         &mut self,
-        mut batches: Box<dyn RecordBatchReader>,
-        write_mode: Option<WriteMode>,
-    ) -> Result<usize> {
-        let mut params = WriteParams::default();
-        params.mode = write_mode.unwrap_or(WriteMode::Append);
+        batches: impl RecordBatchReader + Send + 'static,
+        params: Option<WriteParams>,
+    ) -> Result<()> {
+        let params = params.unwrap_or(WriteParams {
+            mode: WriteMode::Append,
+            ..WriteParams::default()
+        });
 
-        self.dataset = Arc::new(Dataset::write(&mut batches, &self.uri, Some(params)).await?);
-        Ok(batches.count())
+        self.dataset = Arc::new(Dataset::write(batches, &self.uri, Some(params)).await?);
+        Ok(())
     }
 
     /// Creates a new Query object that can be executed.
@@ -207,12 +209,12 @@ impl Table {
     /// Merge new data into this table.
     pub async fn merge(
         &mut self,
-        mut batches: Box<dyn RecordBatchReader>,
+        batches: impl RecordBatchReader + Send + 'static,
         left_on: &str,
         right_on: &str,
     ) -> Result<()> {
         let mut dataset = self.dataset.as_ref().clone();
-        dataset.merge(&mut batches, left_on, right_on).await?;
+        dataset.merge(batches, left_on, right_on).await?;
         self.dataset = Arc::new(dataset);
         Ok(())
     }
@@ -253,8 +255,8 @@ mod tests {
         let dataset_path = tmp_dir.path().join("test.lance");
         let uri = tmp_dir.path().to_str().unwrap();
 
-        let mut batches: Box<dyn RecordBatchReader> = make_test_batches();
-        Dataset::write(&mut batches, dataset_path.to_str().unwrap(), None)
+        let batches = make_test_batches();
+        Dataset::write(batches, dataset_path.to_str().unwrap(), None)
             .await
             .unwrap();
 
@@ -284,11 +286,11 @@ mod tests {
         let tmp_dir = tempdir().unwrap();
         let uri = tmp_dir.path().to_str().unwrap();
 
-        let batches: Box<dyn RecordBatchReader> = make_test_batches();
+        let batches = make_test_batches();
         let _ = batches.schema().clone();
         Table::create(&uri, "test", batches, None).await.unwrap();
 
-        let batches: Box<dyn RecordBatchReader> = make_test_batches();
+        let batches = make_test_batches();
         let result = Table::create(&uri, "test", batches, None).await;
         assert!(matches!(
             result.unwrap_err(),
@@ -301,12 +303,12 @@ mod tests {
         let tmp_dir = tempdir().unwrap();
         let uri = tmp_dir.path().to_str().unwrap();
 
-        let batches: Box<dyn RecordBatchReader> = make_test_batches();
+        let batches = make_test_batches();
         let schema = batches.schema().clone();
         let mut table = Table::create(&uri, "test", batches, None).await.unwrap();
         assert_eq!(table.count_rows().await.unwrap(), 10);
 
-        let new_batches: Box<dyn RecordBatchReader> = Box::new(RecordBatchIterator::new(
+        let new_batches = RecordBatchIterator::new(
             vec![RecordBatch::try_new(
                 schema.clone(),
                 vec![Arc::new(Int32Array::from_iter_values(100..110))],
@@ -315,7 +317,7 @@ mod tests {
             .into_iter()
             .map(Ok),
             schema.clone(),
-        ));
+        );
 
         table.add(new_batches, None).await.unwrap();
         assert_eq!(table.count_rows().await.unwrap(), 20);
@@ -327,12 +329,12 @@ mod tests {
         let tmp_dir = tempdir().unwrap();
         let uri = tmp_dir.path().to_str().unwrap();
 
-        let batches: Box<dyn RecordBatchReader> = make_test_batches();
+        let batches = make_test_batches();
         let schema = batches.schema().clone();
         let mut table = Table::create(uri, "test", batches, None).await.unwrap();
         assert_eq!(table.count_rows().await.unwrap(), 10);
 
-        let new_batches: Box<dyn RecordBatchReader> = Box::new(RecordBatchIterator::new(
+        let new_batches = RecordBatchIterator::new(
             vec![RecordBatch::try_new(
                 schema.clone(),
                 vec![Arc::new(Int32Array::from_iter_values(100..110))],
@@ -341,10 +343,15 @@ mod tests {
             .into_iter()
             .map(Ok),
             schema.clone(),
-        ));
+        );
+
+        let param: WriteParams = WriteParams {
+            mode: WriteMode::Overwrite,
+            ..Default::default()
+        };
 
         table
-            .add(new_batches, Some(WriteMode::Overwrite))
+            .add(new_batches, Some(param))
             .await
             .unwrap();
         assert_eq!(table.count_rows().await.unwrap(), 10);
@@ -357,8 +364,8 @@ mod tests {
         let dataset_path = tmp_dir.path().join("test.lance");
         let uri = tmp_dir.path().to_str().unwrap();
 
-        let mut batches: Box<dyn RecordBatchReader> = make_test_batches();
-        Dataset::write(&mut batches, dataset_path.to_str().unwrap(), None)
+        let batches = make_test_batches();
+        Dataset::write(batches, dataset_path.to_str().unwrap(), None)
            .await
            .unwrap();
 
@@ -369,7 +376,7 @@ mod tests {
         assert_eq!(vector, query.query_vector);
     }
 
-    #[derive(Default)]
+    #[derive(Default, Debug)]
     struct NoOpCacheWrapper {
         called: AtomicBool,
     }
@@ -396,8 +403,8 @@ mod tests {
         let dataset_path = tmp_dir.path().join("test.lance");
         let uri = tmp_dir.path().to_str().unwrap();
 
-        let mut batches: Box<dyn RecordBatchReader> = make_test_batches();
-        Dataset::write(&mut batches, dataset_path.to_str().unwrap(), None)
+        let batches = make_test_batches();
+        Dataset::write(batches, dataset_path.to_str().unwrap(), None)
            .await
            .unwrap();
 
@@ -417,15 +424,15 @@ mod tests {
         assert!(wrapper.called());
     }
 
-    fn make_test_batches() -> Box<dyn RecordBatchReader> {
+    fn make_test_batches() -> impl RecordBatchReader + Send + Sync + 'static {
         let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)]));
-        Box::new(RecordBatchIterator::new(
+        RecordBatchIterator::new(
             vec![RecordBatch::try_new(
                 schema.clone(),
                 vec![Arc::new(Int32Array::from_iter_values(0..10))],
             )],
             schema,
-        ))
+        )
    }
 
    #[tokio::test]
@@ -465,9 +472,7 @@ mod tests {
             schema,
         );
 
-        let reader: Box<dyn RecordBatchReader + Send> = Box::new(batches);
-        let mut table = Table::create(uri, "test", reader, None).await.unwrap();
+        let mut table = Table::create(uri, "test", batches, None).await.unwrap();
 
         let mut i = IvfPQIndexBuilder::new();
 
         let index_builder = i