mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-25 06:19:57 +00:00
Compare commits
16 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0a03f7ca5a | ||
|
|
88be978e87 | ||
|
|
98b12caa06 | ||
|
|
091dffb171 | ||
|
|
ace6aa883a | ||
|
|
80c25f9896 | ||
|
|
caf22fdb71 | ||
|
|
0e7ae5dfbf | ||
|
|
b261e27222 | ||
|
|
9f603f73a9 | ||
|
|
9ef846929b | ||
|
|
97364a2514 | ||
|
|
e6c6da6104 | ||
|
|
a5eb665b7d | ||
|
|
e2325c634b | ||
|
|
507eeae9c8 |
@@ -1,5 +1,5 @@
|
||||
[bumpversion]
|
||||
current_version = 0.1.10
|
||||
current_version = 0.1.12
|
||||
commit = True
|
||||
message = Bump version: {current_version} → {new_version}
|
||||
tag = True
|
||||
|
||||
24
.github/workflows/docs.yml
vendored
24
.github/workflows/docs.yml
vendored
@@ -39,6 +39,28 @@ jobs:
|
||||
run: |
|
||||
python -m pip install -e .
|
||||
python -m pip install -r ../docs/requirements.txt
|
||||
- name: Set up node
|
||||
uses: actions/setup-node@v3
|
||||
with:
|
||||
node-version: ${{ matrix.node-version }}
|
||||
cache: 'npm'
|
||||
cache-dependency-path: node/package-lock.json
|
||||
- uses: Swatinem/rust-cache@v2
|
||||
- name: Install node dependencies
|
||||
working-directory: node
|
||||
run: |
|
||||
sudo apt update
|
||||
sudo apt install -y protobuf-compiler libssl-dev
|
||||
- name: Build node
|
||||
working-directory: node
|
||||
run: |
|
||||
npm ci
|
||||
npm run build
|
||||
npm run tsc
|
||||
- name: Create markdown files
|
||||
working-directory: node
|
||||
run: |
|
||||
npx typedoc --plugin typedoc-plugin-markdown --out ../docs/src/javascript src/index.ts
|
||||
- name: Build docs
|
||||
run: |
|
||||
PYTHONPATH=. mkdocs build -f docs/mkdocs.yml
|
||||
@@ -50,4 +72,4 @@ jobs:
|
||||
path: "docs/site"
|
||||
- name: Deploy to GitHub Pages
|
||||
id: deployment
|
||||
uses: actions/deploy-pages@v1
|
||||
uses: actions/deploy-pages@v1
|
||||
6
.github/workflows/python.yml
vendored
6
.github/workflows/python.yml
vendored
@@ -61,6 +61,8 @@ jobs:
|
||||
run: |
|
||||
pip install -e .
|
||||
pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985
|
||||
pip install pytest pytest-mock
|
||||
pip install pytest pytest-mock black
|
||||
- name: Black
|
||||
run: black --check --diff --no-color --quiet .
|
||||
- name: Run tests
|
||||
run: pytest -x -v --durations=30 tests
|
||||
run: pytest -x -v --durations=30 tests
|
||||
|
||||
1
.github/workflows/rust.yml
vendored
1
.github/workflows/rust.yml
vendored
@@ -6,6 +6,7 @@ on:
|
||||
- main
|
||||
pull_request:
|
||||
paths:
|
||||
- Cargo.toml
|
||||
- rust/**
|
||||
- .github/workflows/rust.yml
|
||||
|
||||
|
||||
10
Cargo.toml
10
Cargo.toml
@@ -6,9 +6,9 @@ members = [
|
||||
resolver = "2"
|
||||
|
||||
[workspace.dependencies]
|
||||
lance = "0.5.3"
|
||||
arrow-array = "40.0"
|
||||
arrow-data = "40.0"
|
||||
arrow-schema = "40.0"
|
||||
arrow-ipc = "40.0"
|
||||
lance = "=0.5.5"
|
||||
arrow-array = "42.0"
|
||||
arrow-data = "42.0"
|
||||
arrow-schema = "42.0"
|
||||
arrow-ipc = "42.0"
|
||||
object_store = "0.6.1"
|
||||
|
||||
@@ -10,14 +10,16 @@ pip install lancedb
|
||||
|
||||
::: lancedb.connect
|
||||
|
||||
::: lancedb.LanceDBConnection
|
||||
::: lancedb.db.DBConnection
|
||||
|
||||
## Table
|
||||
|
||||
::: lancedb.table.LanceTable
|
||||
::: lancedb.table.Table
|
||||
|
||||
## Querying
|
||||
|
||||
::: lancedb.query.Query
|
||||
|
||||
::: lancedb.query.LanceQueryBuilder
|
||||
|
||||
::: lancedb.query.LanceFtsQueryBuilder
|
||||
@@ -41,3 +43,10 @@ pip install lancedb
|
||||
::: lancedb.fts.populate_index
|
||||
|
||||
::: lancedb.fts.search_index
|
||||
|
||||
## Utilities
|
||||
|
||||
::: lancedb.schema.schema_to_dict
|
||||
::: lancedb.schema.dict_to_schema
|
||||
::: lancedb.vector
|
||||
|
||||
|
||||
@@ -79,38 +79,32 @@ await db_setup.createTable('my_vectors', data)
|
||||
const tbl = await db.openTable("my_vectors")
|
||||
|
||||
const results_1 = await tbl.search(Array(1536).fill(1.2))
|
||||
.limit(20)
|
||||
.limit(10)
|
||||
.execute()
|
||||
```
|
||||
|
||||
|
||||
<!-- Commenting out for now since metricType fails for JS on Ubuntu 22.04.
|
||||
|
||||
By default, `l2` will be used as `Metric` type. You can customize the metric type
|
||||
as well.
|
||||
-->
|
||||
|
||||
<!--
|
||||
=== "Python"
|
||||
-->
|
||||
<!-- ```python
|
||||
|
||||
```python
|
||||
df = tbl.search(np.random.random((1536))) \
|
||||
.metric("cosine") \
|
||||
.limit(10) \
|
||||
.to_df()
|
||||
```
|
||||
-->
|
||||
<!--
|
||||
=== "JavaScript"
|
||||
-->
|
||||
|
||||
<!-- ```javascript
|
||||
|
||||
=== "JavaScript"
|
||||
|
||||
```javascript
|
||||
const results_2 = await tbl.search(Array(1536).fill(1.2))
|
||||
.metricType("cosine")
|
||||
.limit(20)
|
||||
.limit(10)
|
||||
.execute()
|
||||
```
|
||||
-->
|
||||
|
||||
|
||||
### Search with Vector Index.
|
||||
|
||||
|
||||
4
node/package-lock.json
generated
4
node/package-lock.json
generated
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"name": "vectordb",
|
||||
"version": "0.1.9",
|
||||
"version": "0.1.10",
|
||||
"lockfileVersion": 2,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "vectordb",
|
||||
"version": "0.1.9",
|
||||
"version": "0.1.10",
|
||||
"license": "Apache-2.0",
|
||||
"dependencies": {
|
||||
"@apache-arrow/ts": "^12.0.0",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "vectordb",
|
||||
"version": "0.1.10",
|
||||
"version": "0.1.12",
|
||||
"description": " Serverless, low-latency vector database for AI applications",
|
||||
"main": "dist/index.js",
|
||||
"types": "dist/index.d.ts",
|
||||
|
||||
@@ -122,6 +122,14 @@ export interface Table<T = number[]> {
|
||||
delete: (filter: string) => Promise<void>
|
||||
}
|
||||
|
||||
export interface AwsCredentials {
|
||||
accessKeyId: string
|
||||
|
||||
secretKey: string
|
||||
|
||||
sessionToken?: string
|
||||
}
|
||||
|
||||
/**
|
||||
* A connection to a LanceDB database.
|
||||
*/
|
||||
@@ -158,13 +166,10 @@ export class LocalConnection implements Connection {
|
||||
* @param embeddings An embedding function to use on this Table
|
||||
*/
|
||||
async openTable<T> (name: string, embeddings: EmbeddingFunction<T>): Promise<Table<T>>
|
||||
async openTable<T> (name: string, embeddings?: EmbeddingFunction<T>): Promise<Table<T>> {
|
||||
async openTable<T> (name: string, embeddings?: EmbeddingFunction<T>, awsCredentials?: AwsCredentials): Promise<Table<T>>
|
||||
async openTable<T> (name: string, embeddings?: EmbeddingFunction<T>, awsCredentials?: AwsCredentials): Promise<Table<T>> {
|
||||
const tbl = await databaseOpenTable.call(this._db, name)
|
||||
if (embeddings !== undefined) {
|
||||
return new LocalTable(tbl, name, embeddings)
|
||||
} else {
|
||||
return new LocalTable(tbl, name)
|
||||
}
|
||||
return new LocalTable(tbl, name, embeddings, awsCredentials)
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -186,16 +191,24 @@ export class LocalConnection implements Connection {
|
||||
* @param embeddings An embedding function to use on this Table
|
||||
*/
|
||||
async createTable<T> (name: string, data: Array<Record<string, unknown>>, mode: WriteMode, embeddings: EmbeddingFunction<T>): Promise<Table<T>>
|
||||
async createTable<T> (name: string, data: Array<Record<string, unknown>>, mode: WriteMode, embeddings?: EmbeddingFunction<T>): Promise<Table<T>> {
|
||||
async createTable<T> (name: string, data: Array<Record<string, unknown>>, mode: WriteMode, embeddings?: EmbeddingFunction<T>, awsCredentials?: AwsCredentials): Promise<Table<T>>
|
||||
async createTable<T> (name: string, data: Array<Record<string, unknown>>, mode: WriteMode, embeddings?: EmbeddingFunction<T>, awsCredentials?: AwsCredentials): Promise<Table<T>> {
|
||||
if (mode === undefined) {
|
||||
mode = WriteMode.Create
|
||||
}
|
||||
const tbl = await tableCreate.call(this._db, name, await fromRecordsToBuffer(data, embeddings), mode.toLowerCase())
|
||||
if (embeddings !== undefined) {
|
||||
return new LocalTable(tbl, name, embeddings)
|
||||
} else {
|
||||
return new LocalTable(tbl, name)
|
||||
|
||||
const createArgs = [this._db, name, await fromRecordsToBuffer(data, embeddings), mode.toLowerCase()]
|
||||
if (awsCredentials !== undefined) {
|
||||
createArgs.push(awsCredentials.accessKeyId)
|
||||
createArgs.push(awsCredentials.secretKey)
|
||||
if (awsCredentials.sessionToken !== undefined) {
|
||||
createArgs.push(awsCredentials.sessionToken)
|
||||
}
|
||||
}
|
||||
|
||||
const tbl = await tableCreate.call(...createArgs)
|
||||
|
||||
return new LocalTable(tbl, name, embeddings, awsCredentials)
|
||||
}
|
||||
|
||||
async createTableArrow (name: string, table: ArrowTable): Promise<Table> {
|
||||
@@ -217,6 +230,7 @@ export class LocalTable<T = number[]> implements Table<T> {
|
||||
private readonly _tbl: any
|
||||
private readonly _name: string
|
||||
private readonly _embeddings?: EmbeddingFunction<T>
|
||||
private readonly _awsCredentials?: AwsCredentials
|
||||
|
||||
constructor (tbl: any, name: string)
|
||||
/**
|
||||
@@ -225,10 +239,12 @@ export class LocalTable<T = number[]> implements Table<T> {
|
||||
* @param embeddings An embedding function to use when interacting with this table
|
||||
*/
|
||||
constructor (tbl: any, name: string, embeddings: EmbeddingFunction<T>)
|
||||
constructor (tbl: any, name: string, embeddings?: EmbeddingFunction<T>) {
|
||||
constructor (tbl: any, name: string, embeddings?: EmbeddingFunction<T>, awsCredentials?: AwsCredentials)
|
||||
constructor (tbl: any, name: string, embeddings?: EmbeddingFunction<T>, awsCredentials?: AwsCredentials) {
|
||||
this._tbl = tbl
|
||||
this._name = name
|
||||
this._embeddings = embeddings
|
||||
this._awsCredentials = awsCredentials
|
||||
}
|
||||
|
||||
get name (): string {
|
||||
@@ -250,7 +266,15 @@ export class LocalTable<T = number[]> implements Table<T> {
|
||||
* @return The number of rows added to the table
|
||||
*/
|
||||
async add (data: Array<Record<string, unknown>>): Promise<number> {
|
||||
return tableAdd.call(this._tbl, await fromRecordsToBuffer(data, this._embeddings), WriteMode.Append.toString())
|
||||
const callArgs = [this._tbl, await fromRecordsToBuffer(data, this._embeddings), WriteMode.Append.toString()]
|
||||
if (this._awsCredentials !== undefined) {
|
||||
callArgs.push(this._awsCredentials.accessKeyId)
|
||||
callArgs.push(this._awsCredentials.secretKey)
|
||||
if (this._awsCredentials.sessionToken !== undefined) {
|
||||
callArgs.push(this._awsCredentials.sessionToken)
|
||||
}
|
||||
}
|
||||
return tableAdd.call(...callArgs)
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -260,6 +284,14 @@ export class LocalTable<T = number[]> implements Table<T> {
|
||||
* @return The number of rows added to the table
|
||||
*/
|
||||
async overwrite (data: Array<Record<string, unknown>>): Promise<number> {
|
||||
const callArgs = [this._tbl, await fromRecordsToBuffer(data, this._embeddings), WriteMode.Overwrite.toString()]
|
||||
if (this._awsCredentials !== undefined) {
|
||||
callArgs.push(this._awsCredentials.accessKeyId)
|
||||
callArgs.push(this._awsCredentials.secretKey)
|
||||
if (this._awsCredentials.sessionToken !== undefined) {
|
||||
callArgs.push(this._awsCredentials.sessionToken)
|
||||
}
|
||||
}
|
||||
return tableAdd.call(this._tbl, await fromRecordsToBuffer(data, this._embeddings), WriteMode.Overwrite.toString())
|
||||
}
|
||||
|
||||
|
||||
@@ -11,16 +11,25 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .db import URI, LanceDBConnection
|
||||
from typing import Optional
|
||||
|
||||
from .db import URI, DBConnection, LanceDBConnection
|
||||
from .remote.db import RemoteDBConnection
|
||||
from .schema import vector
|
||||
|
||||
|
||||
def connect(uri: URI) -> LanceDBConnection:
|
||||
"""Connect to a LanceDB instance at the given URI
|
||||
def connect(
|
||||
uri: URI, *, api_key: Optional[str] = None, region: str = "us-west-2"
|
||||
) -> DBConnection:
|
||||
"""Connect to a LanceDB database.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
uri: str or Path
|
||||
The uri of the database.
|
||||
api_token: str, optional
|
||||
If presented, connect to LanceDB cloud.
|
||||
Otherwise, connect to a database on file system or cloud storage.
|
||||
|
||||
Examples
|
||||
--------
|
||||
@@ -34,9 +43,17 @@ def connect(uri: URI) -> LanceDBConnection:
|
||||
|
||||
>>> db = lancedb.connect("s3://my-bucket/lancedb")
|
||||
|
||||
Connect to LancdDB cloud:
|
||||
|
||||
>>> db = lancedb.connect("db://my_database", api_key="ldb_...")
|
||||
|
||||
Returns
|
||||
-------
|
||||
conn : LanceDBConnection
|
||||
conn : DBConnection
|
||||
A connection to a LanceDB database.
|
||||
"""
|
||||
if isinstance(uri, str) and uri.startswith("db://"):
|
||||
if api_key is None:
|
||||
raise ValueError(f"api_key is required to connected LanceDB cloud: {uri}")
|
||||
return RemoteDBConnection(uri, api_key, region)
|
||||
return LanceDBConnection(uri)
|
||||
|
||||
@@ -23,3 +23,13 @@ URI = Union[str, Path]
|
||||
# TODO support generator
|
||||
DATA = Union[List[dict], dict, pd.DataFrame]
|
||||
VECTOR_COLUMN_NAME = "vector"
|
||||
|
||||
|
||||
class Credential(str):
|
||||
"""Credential field"""
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "********"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "********"
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
import builtins
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
# import lancedb so we don't have to in every example
|
||||
import lancedb
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
|
||||
@@ -15,17 +15,161 @@ from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
|
||||
import pyarrow as pa
|
||||
from pyarrow import fs
|
||||
|
||||
from .common import DATA, URI
|
||||
from .table import LanceTable
|
||||
from .table import LanceTable, Table
|
||||
from .util import get_uri_location, get_uri_scheme
|
||||
|
||||
|
||||
class LanceDBConnection:
|
||||
class DBConnection(ABC):
|
||||
"""An active LanceDB connection interface."""
|
||||
|
||||
@abstractmethod
|
||||
def table_names(self) -> list[str]:
|
||||
"""List all table names in the database."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create_table(
|
||||
self,
|
||||
name: str,
|
||||
data: DATA = None,
|
||||
schema: pa.Schema = None,
|
||||
mode: str = "create",
|
||||
on_bad_vectors: str = "error",
|
||||
fill_value: float = 0.0,
|
||||
) -> Table:
|
||||
"""Create a [Table][lancedb.table.Table] in the database.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: str
|
||||
The name of the table.
|
||||
data: list, tuple, dict, pd.DataFrame; optional
|
||||
The data to insert into the table.
|
||||
schema: pyarrow.Schema; optional
|
||||
The schema of the table.
|
||||
mode: str; default "create"
|
||||
The mode to use when creating the table. Can be either "create" or "overwrite".
|
||||
By default, if the table already exists, an exception is raised.
|
||||
If you want to overwrite the table, use mode="overwrite".
|
||||
on_bad_vectors: str, default "error"
|
||||
What to do if any of the vectors are not the same size or contains NaNs.
|
||||
One of "error", "drop", "fill".
|
||||
fill_value: float
|
||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||
|
||||
Note
|
||||
----
|
||||
The vector index won't be created by default.
|
||||
To create the index, call the `create_index` method on the table.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceTable
|
||||
A reference to the newly created table.
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
Can create with list of tuples or dictionaries:
|
||||
|
||||
>>> import lancedb
|
||||
>>> db = lancedb.connect("./.lancedb")
|
||||
>>> data = [{"vector": [1.1, 1.2], "lat": 45.5, "long": -122.7},
|
||||
... {"vector": [0.2, 1.8], "lat": 40.1, "long": -74.1}]
|
||||
>>> db.create_table("my_table", data)
|
||||
LanceTable(my_table)
|
||||
>>> db["my_table"].head()
|
||||
pyarrow.Table
|
||||
vector: fixed_size_list<item: float>[2]
|
||||
child 0, item: float
|
||||
lat: double
|
||||
long: double
|
||||
----
|
||||
vector: [[[1.1,1.2],[0.2,1.8]]]
|
||||
lat: [[45.5,40.1]]
|
||||
long: [[-122.7,-74.1]]
|
||||
|
||||
You can also pass a pandas DataFrame:
|
||||
|
||||
>>> import pandas as pd
|
||||
>>> data = pd.DataFrame({
|
||||
... "vector": [[1.1, 1.2], [0.2, 1.8]],
|
||||
... "lat": [45.5, 40.1],
|
||||
... "long": [-122.7, -74.1]
|
||||
... })
|
||||
>>> db.create_table("table2", data)
|
||||
LanceTable(table2)
|
||||
>>> db["table2"].head()
|
||||
pyarrow.Table
|
||||
vector: fixed_size_list<item: float>[2]
|
||||
child 0, item: float
|
||||
lat: double
|
||||
long: double
|
||||
----
|
||||
vector: [[[1.1,1.2],[0.2,1.8]]]
|
||||
lat: [[45.5,40.1]]
|
||||
long: [[-122.7,-74.1]]
|
||||
|
||||
Data is converted to Arrow before being written to disk. For maximum
|
||||
control over how data is saved, either provide the PyArrow schema to
|
||||
convert to or else provide a PyArrow table directly.
|
||||
|
||||
>>> custom_schema = pa.schema([
|
||||
... pa.field("vector", pa.list_(pa.float32(), 2)),
|
||||
... pa.field("lat", pa.float32()),
|
||||
... pa.field("long", pa.float32())
|
||||
... ])
|
||||
>>> db.create_table("table3", data, schema = custom_schema)
|
||||
LanceTable(table3)
|
||||
>>> db["table3"].head()
|
||||
pyarrow.Table
|
||||
vector: fixed_size_list<item: float>[2]
|
||||
child 0, item: float
|
||||
lat: float
|
||||
long: float
|
||||
----
|
||||
vector: [[[1.1,1.2],[0.2,1.8]]]
|
||||
lat: [[45.5,40.1]]
|
||||
long: [[-122.7,-74.1]]
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def __getitem__(self, name: str) -> LanceTable:
|
||||
return self.open_table(name)
|
||||
|
||||
def open_table(self, name: str) -> Table:
|
||||
"""Open a Lance Table in the database.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: str
|
||||
The name of the table.
|
||||
|
||||
Returns
|
||||
-------
|
||||
A LanceTable object representing the table.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def drop_table(self, name: str):
|
||||
"""Drop a table from the database.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: str
|
||||
The name of the table.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class LanceDBConnection(DBConnection):
|
||||
"""
|
||||
A connection to a LanceDB database.
|
||||
|
||||
@@ -59,13 +203,6 @@ class LanceDBConnection:
|
||||
if not isinstance(uri, Path):
|
||||
scheme = get_uri_scheme(uri)
|
||||
is_local = isinstance(uri, Path) or scheme == "file"
|
||||
# managed lancedb remote uses schema like lancedb+[http|grpc|...]://
|
||||
self._is_managed_remote = not is_local and scheme.startswith("lancedb")
|
||||
if self._is_managed_remote:
|
||||
if len(scheme.split("+")) != 2:
|
||||
raise ValueError(
|
||||
f"Invalid LanceDB URI: {uri}, expected uri to have scheme like lancedb+<flavor>://..."
|
||||
)
|
||||
if is_local:
|
||||
if isinstance(uri, str):
|
||||
uri = Path(uri)
|
||||
@@ -79,43 +216,6 @@ class LanceDBConnection:
|
||||
def uri(self) -> str:
|
||||
return self._uri
|
||||
|
||||
@functools.cached_property
|
||||
def is_managed_remote(self) -> bool:
|
||||
return self._is_managed_remote
|
||||
|
||||
@functools.cached_property
|
||||
def remote_flavor(self) -> str:
|
||||
if not self.is_managed_remote:
|
||||
raise ValueError(
|
||||
"Not a managed remote LanceDB, there should be no server flavor"
|
||||
)
|
||||
return get_uri_scheme(self.uri).split("+")[1]
|
||||
|
||||
@functools.cached_property
|
||||
def _client(self) -> "lancedb.remote.LanceDBClient":
|
||||
if not self.is_managed_remote:
|
||||
raise ValueError("Not a managed remote LanceDB, there should be no client")
|
||||
|
||||
# don't import unless we are really using remote
|
||||
from lancedb.remote.client import RestfulLanceDBClient
|
||||
|
||||
if self.remote_flavor == "http":
|
||||
return RestfulLanceDBClient(self._uri)
|
||||
|
||||
raise ValueError("Unsupported remote flavor: " + self.remote_flavor)
|
||||
|
||||
async def close(self):
|
||||
if self._entered:
|
||||
raise ValueError("Cannot re-enter the same LanceDBConnection twice")
|
||||
self._entered = True
|
||||
await self._client.close()
|
||||
|
||||
async def __aenter__(self) -> LanceDBConnection:
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||
await self.close()
|
||||
|
||||
def table_names(self) -> list[str]:
|
||||
"""Get the names of all tables in the database.
|
||||
|
||||
@@ -149,16 +249,13 @@ class LanceDBConnection:
|
||||
def __contains__(self, name: str) -> bool:
|
||||
return name in self.table_names()
|
||||
|
||||
def __getitem__(self, name: str) -> LanceTable:
|
||||
return self.open_table(name)
|
||||
|
||||
def create_table(
|
||||
self,
|
||||
name: str,
|
||||
data: DATA = None,
|
||||
schema: pa.Schema = None,
|
||||
mode: str = "create",
|
||||
on_bad_vectors: str = "drop",
|
||||
on_bad_vectors: str = "error",
|
||||
fill_value: float = 0.0,
|
||||
) -> LanceTable:
|
||||
"""Create a table in the database.
|
||||
@@ -175,9 +272,9 @@ class LanceDBConnection:
|
||||
The mode to use when creating the table. Can be either "create" or "overwrite".
|
||||
By default, if the table already exists, an exception is raised.
|
||||
If you want to overwrite the table, use mode="overwrite".
|
||||
on_bad_vectors: str
|
||||
on_bad_vectors: str, default "error"
|
||||
What to do if any of the vectors are not the same size or contains NaNs.
|
||||
One of "raise", "drop", "fill".
|
||||
One of "error", "drop", "fill".
|
||||
fill_value: float
|
||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||
|
||||
|
||||
@@ -10,18 +10,47 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Awaitable, Literal
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .common import VECTOR_COLUMN_NAME
|
||||
|
||||
|
||||
class Query(BaseModel):
|
||||
"""A Query"""
|
||||
|
||||
vector_column: str = VECTOR_COLUMN_NAME
|
||||
|
||||
# vector to search for
|
||||
vector: List[float]
|
||||
|
||||
# sql filter to refine the query with
|
||||
filter: Optional[str] = None
|
||||
|
||||
# top k results to return
|
||||
k: int
|
||||
|
||||
# # metrics
|
||||
metric: str = "L2"
|
||||
|
||||
# which columns to return in the results
|
||||
columns: Optional[List[str]] = None
|
||||
|
||||
# optional query parameters for tuning the results,
|
||||
# e.g. `{"nprobes": "10", "refine_factor": "10"}`
|
||||
nprobes: int = 10
|
||||
|
||||
# Refine factor.
|
||||
refine_factor: Optional[int] = None
|
||||
|
||||
|
||||
class LanceQueryBuilder:
|
||||
"""
|
||||
A builder for nearest neighbor queries for LanceDB.
|
||||
@@ -47,9 +76,9 @@ class LanceQueryBuilder:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
table: "lancedb.table.LanceTable",
|
||||
query: np.ndarray,
|
||||
vector_column_name: str = VECTOR_COLUMN_NAME,
|
||||
table: "lancedb.table.Table",
|
||||
query: Union[np.ndarray, str],
|
||||
vector_column: str = VECTOR_COLUMN_NAME,
|
||||
):
|
||||
self._metric = "L2"
|
||||
self._nprobes = 20
|
||||
@@ -59,7 +88,7 @@ class LanceQueryBuilder:
|
||||
self._limit = 10
|
||||
self._columns = None
|
||||
self._where = None
|
||||
self._vector_column_name = vector_column_name
|
||||
self._vector_column = vector_column
|
||||
|
||||
def limit(self, limit: int) -> LanceQueryBuilder:
|
||||
"""Set the maximum number of results to return.
|
||||
@@ -181,52 +210,28 @@ class LanceQueryBuilder:
|
||||
|
||||
def to_arrow(self) -> pa.Table:
|
||||
"""
|
||||
Execute the query and return the results as a arrow Table.
|
||||
Execute the query and return the results as an
|
||||
[Apache Arrow Table](https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table).
|
||||
|
||||
In addition to the selected columns, LanceDB also returns a vector
|
||||
and also the "score" column which is the distance between the query
|
||||
vector and the returned vector.
|
||||
vector and the returned vectors.
|
||||
"""
|
||||
if self._table._conn.is_managed_remote:
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.get_event_loop()
|
||||
result = self._table._conn._client.query(
|
||||
self._table.name, self.to_remote_query()
|
||||
)
|
||||
return loop.run_until_complete(result).to_arrow()
|
||||
|
||||
ds = self._table.to_lance()
|
||||
return ds.to_table(
|
||||
columns=self._columns,
|
||||
filter=self._where,
|
||||
nearest={
|
||||
"column": self._vector_column_name,
|
||||
"q": self._query,
|
||||
"k": self._limit,
|
||||
"metric": self._metric,
|
||||
"nprobes": self._nprobes,
|
||||
"refine_factor": self._refine_factor,
|
||||
},
|
||||
)
|
||||
|
||||
def to_remote_query(self) -> "VectorQuery":
|
||||
# don't import unless we are connecting to remote
|
||||
from lancedb.remote.client import VectorQuery
|
||||
|
||||
return VectorQuery(
|
||||
vector=self._query.tolist(),
|
||||
vector = self._query if isinstance(self._query, list) else self._query.tolist()
|
||||
query = Query(
|
||||
vector=vector,
|
||||
filter=self._where,
|
||||
k=self._limit,
|
||||
_metric=self._metric,
|
||||
metric=self._metric,
|
||||
columns=self._columns,
|
||||
nprobes=self._nprobes,
|
||||
refine_factor=self._refine_factor,
|
||||
)
|
||||
return self._table._execute_query(query)
|
||||
|
||||
|
||||
class LanceFtsQueryBuilder(LanceQueryBuilder):
|
||||
def to_df(self) -> pd.DataFrame:
|
||||
def to_arrow(self) -> pd.Table:
|
||||
try:
|
||||
import tantivy
|
||||
except ImportError:
|
||||
@@ -243,8 +248,9 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
|
||||
# get the scores and doc ids
|
||||
row_ids, scores = search_index(index, self._query, self._limit)
|
||||
if len(row_ids) == 0:
|
||||
return pd.DataFrame()
|
||||
empty_schema = pa.schema([pa.field("score", pa.float32())])
|
||||
return pa.Table.from_pylist([], schema=empty_schema)
|
||||
scores = pa.array(scores)
|
||||
output_tbl = self._table.to_lance().take(row_ids, columns=self._columns)
|
||||
output_tbl = output_tbl.append_column("score", scores)
|
||||
return output_tbl.to_pandas()
|
||||
return output_tbl
|
||||
|
||||
@@ -15,7 +15,6 @@ import abc
|
||||
from typing import List, Optional
|
||||
|
||||
import attr
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@@ -13,12 +13,14 @@
|
||||
|
||||
|
||||
import functools
|
||||
import urllib.parse
|
||||
from typing import Any, Callable, Dict, Union
|
||||
|
||||
import aiohttp
|
||||
import attr
|
||||
import pyarrow as pa
|
||||
from pydantic import BaseModel
|
||||
|
||||
from lancedb.common import Credential
|
||||
from lancedb.remote import VectorQuery, VectorQueryResult
|
||||
from lancedb.remote.errors import LanceDBClientError
|
||||
|
||||
@@ -33,47 +35,95 @@ def _check_not_closed(f):
|
||||
return wrapped
|
||||
|
||||
|
||||
async def _read_ipc(resp: aiohttp.ClientResponse) -> pa.Table:
|
||||
resp_body = await resp.read()
|
||||
with pa.ipc.open_file(pa.BufferReader(resp_body)) as reader:
|
||||
return reader.read_all()
|
||||
|
||||
|
||||
@attr.define(slots=False)
|
||||
class RestfulLanceDBClient:
|
||||
url: str
|
||||
db_name: str
|
||||
region: str
|
||||
api_key: Credential
|
||||
closed: bool = attr.field(default=False, init=False)
|
||||
|
||||
@functools.cached_property
|
||||
def session(self) -> aiohttp.ClientSession:
|
||||
parsed = urllib.parse.urlparse(self.url)
|
||||
scheme = parsed.scheme
|
||||
if not scheme.startswith("lancedb"):
|
||||
raise ValueError(
|
||||
f"Invalid scheme: {scheme}, must be like lancedb+<flavor>://"
|
||||
)
|
||||
flavor = scheme.split("+")[1]
|
||||
url = f"{flavor}://{parsed.hostname}:{parsed.port}"
|
||||
url = f"https://{self.db_name}.{self.region}.api.lancedb.com"
|
||||
return aiohttp.ClientSession(url)
|
||||
|
||||
async def close(self):
|
||||
await self.session.close()
|
||||
self.closed = True
|
||||
|
||||
@functools.cached_property
|
||||
def headers(self) -> Dict[str, str]:
|
||||
return {
|
||||
"x-api-key": self.api_key,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def _check_status(resp: aiohttp.ClientResponse):
|
||||
if resp.status == 404:
|
||||
raise LanceDBClientError(f"Not found: {await resp.text()}")
|
||||
elif 400 <= resp.status < 500:
|
||||
raise LanceDBClientError(
|
||||
f"Bad Request: {resp.status}, error: {await resp.text()}"
|
||||
)
|
||||
elif 500 <= resp.status < 600:
|
||||
raise LanceDBClientError(
|
||||
f"Internal Server Error: {resp.status}, error: {await resp.text()}"
|
||||
)
|
||||
elif resp.status != 200:
|
||||
raise LanceDBClientError(
|
||||
f"Unknown Error: {resp.status}, error: {await resp.text()}"
|
||||
)
|
||||
|
||||
@_check_not_closed
|
||||
async def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
|
||||
async def get(self, uri: str, params: Union[Dict[str, Any], BaseModel] = None):
|
||||
"""Send a GET request and returns the deserialized response payload."""
|
||||
if isinstance(params, BaseModel):
|
||||
params: Dict[str, Any] = params.dict(exclude_none=True)
|
||||
async with self.session.get(uri, params=params, headers=self.headers) as resp:
|
||||
await self._check_status(resp)
|
||||
return await resp.json()
|
||||
|
||||
@_check_not_closed
|
||||
async def post(
|
||||
self,
|
||||
uri: str,
|
||||
data: Union[Dict[str, Any], BaseModel],
|
||||
deserialize: Callable = lambda resp: resp.json(),
|
||||
) -> Dict[str, Any]:
|
||||
"""Send a POST request and returns the deserialized response payload.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
uri : str
|
||||
The uri to send the POST request to.
|
||||
data: Union[Dict[str, Any], BaseModel]
|
||||
|
||||
"""
|
||||
if isinstance(data, BaseModel):
|
||||
data: Dict[str, Any] = data.dict(exclude_none=True)
|
||||
async with self.session.post(
|
||||
f"/table/{table_name}/", json=query.dict(exclude_none=True)
|
||||
uri,
|
||||
json=data,
|
||||
headers=self.headers,
|
||||
) as resp:
|
||||
resp: aiohttp.ClientResponse = resp
|
||||
if 400 <= resp.status < 500:
|
||||
raise LanceDBClientError(
|
||||
f"Bad Request: {resp.status}, error: {await resp.text()}"
|
||||
)
|
||||
if 500 <= resp.status < 600:
|
||||
raise LanceDBClientError(
|
||||
f"Internal Server Error: {resp.status}, error: {await resp.text()}"
|
||||
)
|
||||
if resp.status != 200:
|
||||
raise LanceDBClientError(
|
||||
f"Unknown Error: {resp.status}, error: {await resp.text()}"
|
||||
)
|
||||
await self._check_status(resp)
|
||||
return await deserialize(resp)
|
||||
|
||||
resp_body = await resp.read()
|
||||
with pa.ipc.open_file(pa.BufferReader(resp_body)) as reader:
|
||||
tbl = reader.read_all()
|
||||
@_check_not_closed
|
||||
async def list_tables(self):
|
||||
"""List all tables in the database."""
|
||||
json = await self.get("/1/table/", {})
|
||||
return json["tables"]
|
||||
|
||||
@_check_not_closed
|
||||
async def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
|
||||
"""Query a table."""
|
||||
tbl = await self.post(f"/1/table/{table_name}/", query, deserialize=_read_ipc)
|
||||
return VectorQueryResult(tbl)
|
||||
|
||||
78
python/lancedb/remote/db.py
Normal file
78
python/lancedb/remote/db.py
Normal file
@@ -0,0 +1,78 @@
|
||||
# Copyright 2023 LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
from typing import List
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import pyarrow as pa
|
||||
|
||||
from lancedb.common import DATA
|
||||
from lancedb.db import DBConnection
|
||||
from lancedb.table import Table
|
||||
|
||||
from .client import RestfulLanceDBClient
|
||||
|
||||
|
||||
class RemoteDBConnection(DBConnection):
|
||||
"""A connection to a remote LanceDB database."""
|
||||
|
||||
def __init__(self, db_url: str, api_key: str, region: str):
|
||||
"""Connect to a remote LanceDB database."""
|
||||
parsed = urlparse(db_url)
|
||||
if parsed.scheme != "db":
|
||||
raise ValueError(f"Invalid scheme: {parsed.scheme}, only accepts db://")
|
||||
self.db_name = parsed.netloc
|
||||
self.api_key = api_key
|
||||
self._client = RestfulLanceDBClient(self.db_name, region, api_key)
|
||||
try:
|
||||
self._loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
self._loop = asyncio.get_event_loop()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"RemoveConnect(name={self.db_name})"
|
||||
|
||||
def table_names(self) -> List[str]:
|
||||
"""List the names of all tables in the database."""
|
||||
result = self._loop.run_until_complete(self._client.list_tables())
|
||||
return result
|
||||
|
||||
def open_table(self, name: str) -> Table:
|
||||
"""Open a Lance Table in the database.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: str
|
||||
The name of the table.
|
||||
|
||||
Returns
|
||||
-------
|
||||
A LanceTable object representing the table.
|
||||
"""
|
||||
from .table import RemoteTable
|
||||
|
||||
# TODO: check if table exists
|
||||
|
||||
return RemoteTable(self, name)
|
||||
|
||||
def create_table(
|
||||
self,
|
||||
name: str,
|
||||
data: DATA = None,
|
||||
schema: pa.Schema = None,
|
||||
mode: str = "create",
|
||||
on_bad_vectors: str = "error",
|
||||
fill_value: float = 0.0,
|
||||
) -> Table:
|
||||
raise NotImplementedError
|
||||
65
python/lancedb/remote/table.py
Normal file
65
python/lancedb/remote/table.py
Normal file
@@ -0,0 +1,65 @@
|
||||
# Copyright 2023 LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Union
|
||||
|
||||
import pyarrow as pa
|
||||
|
||||
from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME
|
||||
|
||||
from ..query import LanceQueryBuilder, Query
|
||||
from ..table import Query, Table
|
||||
from .db import RemoteDBConnection
|
||||
|
||||
|
||||
class RemoteTable(Table):
|
||||
def __init__(self, conn: RemoteDBConnection, name: str):
|
||||
self._conn = conn
|
||||
self._name = name
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"RemoteTable({self._conn.db_name}.{self.name})"
|
||||
|
||||
def schema(self) -> pa.Schema:
|
||||
raise NotImplementedError
|
||||
|
||||
def to_arrow(self) -> pa.Table:
|
||||
raise NotImplementedError
|
||||
|
||||
def create_index(
|
||||
self,
|
||||
metric="L2",
|
||||
num_partitions=256,
|
||||
num_sub_vectors=96,
|
||||
vector_column_name: str = VECTOR_COLUMN_NAME,
|
||||
replace: bool = True,
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
def add(
|
||||
self,
|
||||
data: DATA,
|
||||
mode: str = "append",
|
||||
on_bad_vectors: str = "error",
|
||||
fill_value: float = 0.0,
|
||||
) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
def search(
|
||||
self, query: Union[VEC, str], vector_column: str = VECTOR_COLUMN_NAME
|
||||
) -> LanceQueryBuilder:
|
||||
return LanceQueryBuilder(self, query, vector_column)
|
||||
|
||||
def _execute_query(self, query: Query) -> pa.Table:
|
||||
result = self._conn._client.query(self._name, query)
|
||||
return self._conn._loop.run_until_complete(result).to_arrow()
|
||||
289
python/lancedb/schema.py
Normal file
289
python/lancedb/schema.py
Normal file
@@ -0,0 +1,289 @@
|
||||
# Copyright 2023 LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Schema related utilities."""
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, Type
|
||||
|
||||
import pyarrow as pa
|
||||
|
||||
|
||||
def vector(dimension: int, value_type: pa.DataType = pa.float32()) -> pa.DataType:
|
||||
"""A help function to create a vector type.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dimension: The dimension of the vector.
|
||||
value_type: pa.DataType, optional
|
||||
The type of the value in the vector.
|
||||
|
||||
Returns
|
||||
-------
|
||||
A PyArrow DataType for vectors.
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
>>> import pyarrow as pa
|
||||
>>> import lancedb
|
||||
>>> schema = pa.schema([
|
||||
... pa.field("id", pa.int64()),
|
||||
... pa.field("vector", lancedb.vector(756)),
|
||||
... ])
|
||||
"""
|
||||
return pa.list_(value_type, dimension)
|
||||
|
||||
|
||||
def _type_to_dict(dt: pa.DataType) -> Dict[str, Any]:
|
||||
if pa.types.is_boolean(dt):
|
||||
return {"type": "boolean"}
|
||||
elif pa.types.is_int8(dt):
|
||||
return {"type": "int8"}
|
||||
elif pa.types.is_int16(dt):
|
||||
return {"type": "int16"}
|
||||
elif pa.types.is_int32(dt):
|
||||
return {"type": "int32"}
|
||||
elif pa.types.is_int64(dt):
|
||||
return {"type": "int64"}
|
||||
elif pa.types.is_uint8(dt):
|
||||
return {"type": "uint8"}
|
||||
elif pa.types.is_uint16(dt):
|
||||
return {"type": "uint16"}
|
||||
elif pa.types.is_uint32(dt):
|
||||
return {"type": "uint32"}
|
||||
elif pa.types.is_uint64(dt):
|
||||
return {"type": "uint64"}
|
||||
elif pa.types.is_float16(dt):
|
||||
return {"type": "float16"}
|
||||
elif pa.types.is_float32(dt):
|
||||
return {"type": "float32"}
|
||||
elif pa.types.is_float64(dt):
|
||||
return {"type": "float64"}
|
||||
elif pa.types.is_date32(dt):
|
||||
return {"type": f"date32"}
|
||||
elif pa.types.is_date64(dt):
|
||||
return {"type": f"date64"}
|
||||
elif pa.types.is_time32(dt):
|
||||
return {"type": f"time32:{dt.unit}"}
|
||||
elif pa.types.is_time64(dt):
|
||||
return {"type": f"time64:{dt.unit}"}
|
||||
elif pa.types.is_timestamp(dt):
|
||||
return {"type": f"timestamp:{dt.unit}:{dt.tz if dt.tz is not None else ''}"}
|
||||
elif pa.types.is_string(dt):
|
||||
return {"type": "string"}
|
||||
elif pa.types.is_binary(dt):
|
||||
return {"type": "binary"}
|
||||
elif pa.types.is_large_string(dt):
|
||||
return {"type": "large_string"}
|
||||
elif pa.types.is_large_binary(dt):
|
||||
return {"type": "large_binary"}
|
||||
elif pa.types.is_fixed_size_binary(dt):
|
||||
return {"type": "fixed_size_binary", "width": dt.byte_width}
|
||||
elif pa.types.is_fixed_size_list(dt):
|
||||
return {
|
||||
"type": "fixed_size_list",
|
||||
"width": dt.list_size,
|
||||
"value_type": _type_to_dict(dt.value_type),
|
||||
}
|
||||
elif pa.types.is_list(dt):
|
||||
return {
|
||||
"type": "list",
|
||||
"value_type": _type_to_dict(dt.value_type),
|
||||
}
|
||||
elif pa.types.is_struct(dt):
|
||||
return {
|
||||
"type": "struct",
|
||||
"fields": [_field_to_dict(dt.field(i)) for i in range(dt.num_fields)],
|
||||
}
|
||||
elif pa.types.is_dictionary(dt):
|
||||
return {
|
||||
"type": "dictionary",
|
||||
"index_type": _type_to_dict(dt.index_type),
|
||||
"value_type": _type_to_dict(dt.value_type),
|
||||
}
|
||||
# TODO: support extension types
|
||||
|
||||
raise TypeError(f"Unsupported type: {dt}")
|
||||
|
||||
|
||||
def _field_to_dict(field: pa.field) -> Dict[str, Any]:
|
||||
ret = {
|
||||
"name": field.name,
|
||||
"type": _type_to_dict(field.type),
|
||||
"nullable": field.nullable,
|
||||
}
|
||||
if field.metadata is not None:
|
||||
ret["metadata"] = field.metadata
|
||||
return ret
|
||||
|
||||
|
||||
def schema_to_dict(schema: pa.Schema) -> Dict[str, Any]:
|
||||
"""Convert a PyArrow [Schema](pyarrow.Schema) to a dictionary.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
schema : pa.Schema
|
||||
The PyArrow Schema to convert
|
||||
|
||||
Returns
|
||||
-------
|
||||
A dict of the data type.
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
>>> import pyarrow as pa
|
||||
>>> import lancedb
|
||||
>>> schema = pa.schema(
|
||||
... [
|
||||
... pa.field("id", pa.int64()),
|
||||
... pa.field("vector", lancedb.vector(512), nullable=False),
|
||||
... pa.field(
|
||||
... "struct",
|
||||
... pa.struct(
|
||||
... [
|
||||
... pa.field("a", pa.utf8()),
|
||||
... pa.field("b", pa.float32()),
|
||||
... ]
|
||||
... ),
|
||||
... True,
|
||||
... ),
|
||||
... ],
|
||||
... metadata={"key": "value"},
|
||||
... )
|
||||
>>> json_schema = schema_to_dict(schema)
|
||||
>>> assert json_schema == {
|
||||
... "fields": [
|
||||
... {"name": "id", "type": {"type": "int64"}, "nullable": True},
|
||||
... {
|
||||
... "name": "vector",
|
||||
... "type": {
|
||||
... "type": "fixed_size_list",
|
||||
... "value_type": {"type": "float32"},
|
||||
... "width": 512,
|
||||
... },
|
||||
... "nullable": False,
|
||||
... },
|
||||
... {
|
||||
... "name": "struct",
|
||||
... "type": {
|
||||
... "type": "struct",
|
||||
... "fields": [
|
||||
... {"name": "a", "type": {"type": "string"}, "nullable": True},
|
||||
... {"name": "b", "type": {"type": "float32"}, "nullable": True},
|
||||
... ],
|
||||
... },
|
||||
... "nullable": True,
|
||||
... },
|
||||
... ],
|
||||
... "metadata": {"key": "value"},
|
||||
... }
|
||||
|
||||
"""
|
||||
fields = []
|
||||
for name in schema.names:
|
||||
field = schema.field(name)
|
||||
fields.append(_field_to_dict(field))
|
||||
json_schema = {
|
||||
"fields": fields,
|
||||
"metadata": {
|
||||
k.decode("utf-8"): v.decode("utf-8") for (k, v) in schema.metadata.items()
|
||||
}
|
||||
if schema.metadata is not None
|
||||
else {},
|
||||
}
|
||||
return json_schema
|
||||
|
||||
|
||||
def _dict_to_type(dt: Dict[str, Any]) -> pa.DataType:
|
||||
type_name = dt["type"]
|
||||
try:
|
||||
return {
|
||||
"boolean": pa.bool_(),
|
||||
"int8": pa.int8(),
|
||||
"int16": pa.int16(),
|
||||
"int32": pa.int32(),
|
||||
"int64": pa.int64(),
|
||||
"uint8": pa.uint8(),
|
||||
"uint16": pa.uint16(),
|
||||
"uint32": pa.uint32(),
|
||||
"uint64": pa.uint64(),
|
||||
"float16": pa.float16(),
|
||||
"float32": pa.float32(),
|
||||
"float64": pa.float64(),
|
||||
"string": pa.string(),
|
||||
"binary": pa.binary(),
|
||||
"large_string": pa.large_string(),
|
||||
"large_binary": pa.large_binary(),
|
||||
"date32": pa.date32(),
|
||||
"date64": pa.date64(),
|
||||
}[type_name]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
if type_name == "fixed_size_binary":
|
||||
return pa.binary(dt["width"])
|
||||
elif type_name == "fixed_size_list":
|
||||
return pa.list_(_dict_to_type(dt["value_type"]), dt["width"])
|
||||
elif type_name == "list":
|
||||
return pa.list_(_dict_to_type(dt["value_type"]))
|
||||
elif type_name == "struct":
|
||||
fields = []
|
||||
for field in dt["fields"]:
|
||||
fields.append(_dict_to_field(field))
|
||||
return pa.struct(fields)
|
||||
elif type_name == "dictionary":
|
||||
return pa.dictionary(
|
||||
_dict_to_type(dt["index_type"]), _dict_to_type(dt["value_type"])
|
||||
)
|
||||
elif type_name.startswith("time32:"):
|
||||
return pa.time32(type_name.split(":")[1])
|
||||
elif type_name.startswith("time64:"):
|
||||
return pa.time64(type_name.split(":")[1])
|
||||
elif type_name.startswith("timestamp:"):
|
||||
fields = type_name.split(":")
|
||||
unit = fields[1]
|
||||
tz = fields[2] if len(fields) > 2 else None
|
||||
return pa.timestamp(unit, tz)
|
||||
raise TypeError(f"Unsupported type: {dt}")
|
||||
|
||||
|
||||
def _dict_to_field(field: Dict[str, Any]) -> pa.Field:
|
||||
name = field["name"]
|
||||
nullable = field["nullable"] if "nullable" in field else True
|
||||
dt = _dict_to_type(field["type"])
|
||||
metadata = field.get("metadata", None)
|
||||
return pa.field(name, dt, nullable, metadata)
|
||||
|
||||
|
||||
def dict_to_schema(json: Dict[str, Any]) -> pa.Schema:
|
||||
"""Reconstruct a PyArrow Schema from a JSON dict.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
json : Dict[str, Any]
|
||||
The JSON dict to reconstruct Schema from.
|
||||
|
||||
Returns
|
||||
-------
|
||||
A PyArrow Schema.
|
||||
"""
|
||||
fields = []
|
||||
for field in json["fields"]:
|
||||
fields.append(_dict_to_field(field))
|
||||
metadata = {
|
||||
k.encode("utf-8"): v.encode("utf-8")
|
||||
for (k, v) in json.get("metadata", {}).items()
|
||||
}
|
||||
return pa.schema(fields, metadata)
|
||||
@@ -14,19 +14,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import cached_property
|
||||
from typing import Any, List, Union
|
||||
from typing import List, Union
|
||||
|
||||
import lance
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
import pyarrow.compute as pc
|
||||
import pyarrow.fs
|
||||
from lance import LanceDataset
|
||||
from lance.vector import vec_to_table
|
||||
|
||||
from .common import DATA, VEC, VECTOR_COLUMN_NAME
|
||||
from .query import LanceFtsQueryBuilder, LanceQueryBuilder
|
||||
from .query import LanceFtsQueryBuilder, LanceQueryBuilder, Query
|
||||
|
||||
|
||||
def _sanitize_data(data, schema, on_bad_vectors, fill_value):
|
||||
@@ -47,14 +49,14 @@ def _sanitize_data(data, schema, on_bad_vectors, fill_value):
|
||||
return data
|
||||
|
||||
|
||||
class LanceTable:
|
||||
class Table(ABC):
|
||||
"""
|
||||
A table in a LanceDB database.
|
||||
A [Table](Table) is a collection of Records in a LanceDB [Database](Database).
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
Create using [LanceDBConnection.create_table][lancedb.LanceDBConnection.create_table]
|
||||
Create using [DBConnection.create_table][lancedb.DBConnection.create_table]
|
||||
(more examples in that method's documentation).
|
||||
|
||||
>>> import lancedb
|
||||
@@ -69,12 +71,12 @@ class LanceTable:
|
||||
vector: [[[1.1,1.2]]]
|
||||
b: [[2]]
|
||||
|
||||
Can append new data with [LanceTable.add][lancedb.table.LanceTable.add].
|
||||
Can append new data with [Table.add()][lancedb.table.Table.add].
|
||||
|
||||
>>> table.add([{"vector": [0.5, 1.3], "b": 4}])
|
||||
2
|
||||
|
||||
Can query the table with [LanceTable.search][lancedb.table.LanceTable.search].
|
||||
Can query the table with [Table.search][lancedb.table.Table.search].
|
||||
|
||||
>>> table.search([0.4, 0.4]).select(["b"]).to_df()
|
||||
b vector score
|
||||
@@ -82,8 +84,128 @@ class LanceTable:
|
||||
1 2 [1.1, 1.2] 1.13
|
||||
|
||||
Search queries are much faster when an index is created. See
|
||||
[LanceTable.create_index][lancedb.table.LanceTable.create_index].
|
||||
[Table.create_index][lancedb.table.Table.create_index].
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def schema(self) -> pa.Schema:
|
||||
"""Return the [Arrow Schema](https://arrow.apache.org/docs/python/api/datatypes.html#) of
|
||||
this [Table](Table)
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def to_pandas(self) -> pd.DataFrame:
|
||||
"""Return the table as a pandas DataFrame.
|
||||
|
||||
Returns
|
||||
-------
|
||||
pd.DataFrame
|
||||
"""
|
||||
return self.to_arrow().to_pandas()
|
||||
|
||||
@abstractmethod
|
||||
def to_arrow(self) -> pa.Table:
|
||||
"""Return the table as a pyarrow Table.
|
||||
|
||||
Returns
|
||||
-------
|
||||
pa.Table
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def create_index(
|
||||
self,
|
||||
metric="L2",
|
||||
num_partitions=256,
|
||||
num_sub_vectors=96,
|
||||
vector_column_name: str = VECTOR_COLUMN_NAME,
|
||||
replace: bool = True,
|
||||
):
|
||||
"""Create an index on the table.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
metric: str, default "L2"
|
||||
The distance metric to use when creating the index.
|
||||
Valid values are "L2", "cosine", or "dot".
|
||||
L2 is euclidean distance.
|
||||
num_partitions: int
|
||||
The number of IVF partitions to use when creating the index.
|
||||
Default is 256.
|
||||
num_sub_vectors: int
|
||||
The number of PQ sub-vectors to use when creating the index.
|
||||
Default is 96.
|
||||
vector_column_name: str, default "vector"
|
||||
The vector column name to create the index.
|
||||
replace: bool, default True
|
||||
If True, replace the existing index if it exists.
|
||||
If False, raise an error if duplicate index exists.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def add(
|
||||
self,
|
||||
data: DATA,
|
||||
mode: str = "append",
|
||||
on_bad_vectors: str = "error",
|
||||
fill_value: float = 0.0,
|
||||
) -> int:
|
||||
"""Add more data to the [Table](Table).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data: list-of-dict, dict, pd.DataFrame
|
||||
The data to insert into the table.
|
||||
mode: str
|
||||
The mode to use when writing the data. Valid values are
|
||||
"append" and "overwrite".
|
||||
on_bad_vectors: str, default "error"
|
||||
What to do if any of the vectors are not the same size or contains NaNs.
|
||||
One of "error", "drop", "fill".
|
||||
fill_value: float, default 0.
|
||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
The number of vectors in the table.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def search(
|
||||
self, query: Union[VEC, str], vector_column: str = VECTOR_COLUMN_NAME
|
||||
) -> LanceQueryBuilder:
|
||||
"""Create a search query to find the nearest neighbors
|
||||
of the given query vector.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
query: list, np.ndarray
|
||||
The query vector.
|
||||
vector_column: str, default "vector"
|
||||
The name of the vector column to search.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceQueryBuilder
|
||||
A query builder object representing the query.
|
||||
Once executed, the query returns selected columns, the vector,
|
||||
and also the "score" column which is the distance between the query
|
||||
vector and the returned vector.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def _execute_query(self, query: Query) -> pa.Table:
|
||||
pass
|
||||
|
||||
|
||||
class LanceTable(Table):
|
||||
"""
|
||||
A table in a LanceDB database.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -95,7 +217,8 @@ class LanceTable:
|
||||
|
||||
def _reset_dataset(self):
|
||||
try:
|
||||
del self.__dict__["_dataset"]
|
||||
if "_dataset" in self.__dict__:
|
||||
del self.__dict__["_dataset"]
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
@@ -195,26 +318,7 @@ class LanceTable:
|
||||
vector_column_name=VECTOR_COLUMN_NAME,
|
||||
replace: bool = True,
|
||||
):
|
||||
"""Create an index on the table.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
metric: str, default "L2"
|
||||
The distance metric to use when creating the index.
|
||||
Valid values are "L2", "cosine", or "dot".
|
||||
L2 is euclidean distance.
|
||||
num_partitions: int
|
||||
The number of IVF partitions to use when creating the index.
|
||||
Default is 256.
|
||||
num_sub_vectors: int
|
||||
The number of PQ sub-vectors to use when creating the index.
|
||||
Default is 96.
|
||||
vector_column_name: str, default "vector"
|
||||
The vector column name to create the index.
|
||||
replace: bool, default True
|
||||
If True, replace the existing index if it exists.
|
||||
If False, raise an error if duplicate index exists.
|
||||
"""
|
||||
"""Create an index on the table."""
|
||||
self._dataset.create_index(
|
||||
column=vector_column_name,
|
||||
index_type="IVF_PQ",
|
||||
@@ -258,7 +362,7 @@ class LanceTable:
|
||||
self,
|
||||
data: DATA,
|
||||
mode: str = "append",
|
||||
on_bad_vectors: str = "drop",
|
||||
on_bad_vectors: str = "error",
|
||||
fill_value: float = 0.0,
|
||||
) -> int:
|
||||
"""Add data to the table.
|
||||
@@ -270,9 +374,9 @@ class LanceTable:
|
||||
mode: str
|
||||
The mode to use when writing the data. Valid values are
|
||||
"append" and "overwrite".
|
||||
on_bad_vectors: str
|
||||
on_bad_vectors: str, default "error"
|
||||
What to do if any of the vectors are not the same size or contains NaNs.
|
||||
One of "raise", "drop", "fill".
|
||||
One of "error", "drop", "fill".
|
||||
fill_value: float, default 0.
|
||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||
|
||||
@@ -281,6 +385,7 @@ class LanceTable:
|
||||
int
|
||||
The number of vectors in the table.
|
||||
"""
|
||||
# TODO: manage table listing and metadata separately
|
||||
data = _sanitize_data(
|
||||
data, self.schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
|
||||
)
|
||||
@@ -326,10 +431,10 @@ class LanceTable:
|
||||
cls,
|
||||
db,
|
||||
name,
|
||||
data,
|
||||
data=None,
|
||||
schema=None,
|
||||
mode="create",
|
||||
on_bad_vectors: str = "drop",
|
||||
on_bad_vectors: str = "error",
|
||||
fill_value: float = 0.0,
|
||||
):
|
||||
"""
|
||||
@@ -354,37 +459,40 @@ class LanceTable:
|
||||
The LanceDB instance to create the table in.
|
||||
name: str
|
||||
The name of the table to create.
|
||||
data: list-of-dict, dict, pd.DataFrame
|
||||
data: list-of-dict, dict, pd.DataFrame, default None
|
||||
The data to insert into the table.
|
||||
At least one of `data` or `schema` must be provided.
|
||||
schema: dict, optional
|
||||
The schema of the table. If not provided, the schema is inferred from the data.
|
||||
At least one of `data` or `schema` must be provided.
|
||||
mode: str, default "create"
|
||||
The mode to use when writing the data. Valid values are
|
||||
"create", "overwrite", and "append".
|
||||
on_bad_vectors: str
|
||||
on_bad_vectors: str, default "error"
|
||||
What to do if any of the vectors are not the same size or contains NaNs.
|
||||
One of "raise", "drop", "fill".
|
||||
One of "error", "drop", "fill".
|
||||
fill_value: float, default 0.
|
||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||
"""
|
||||
tbl = LanceTable(db, name)
|
||||
data = _sanitize_data(
|
||||
data, schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
|
||||
)
|
||||
if data is not None:
|
||||
data = _sanitize_data(
|
||||
data, schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
|
||||
)
|
||||
else:
|
||||
if schema is None:
|
||||
raise ValueError("Either data or schema must be provided")
|
||||
data = pa.Table.from_pylist([], schema=schema)
|
||||
lance.write_dataset(data, tbl._dataset_uri, mode=mode)
|
||||
return tbl
|
||||
return LanceTable(db, name)
|
||||
|
||||
@classmethod
|
||||
def open(cls, db, name):
|
||||
tbl = cls(db, name)
|
||||
if tbl._conn.is_managed_remote:
|
||||
# Not completely sure how to check for remote table existence yet.
|
||||
return tbl
|
||||
if not os.path.exists(tbl._dataset_uri):
|
||||
raise FileNotFoundError(
|
||||
f"Table {name} does not exist. Please first call db.create_table({name}, data)"
|
||||
)
|
||||
|
||||
return tbl
|
||||
|
||||
def delete(self, where: str):
|
||||
@@ -415,11 +523,26 @@ class LanceTable:
|
||||
"""
|
||||
self._dataset.delete(where)
|
||||
|
||||
def _execute_query(self, query: Query) -> pa.Table:
|
||||
ds = self.to_lance()
|
||||
return ds.to_table(
|
||||
columns=query.columns,
|
||||
filter=query.filter,
|
||||
nearest={
|
||||
"column": query.vector_column,
|
||||
"q": query.vector,
|
||||
"k": query.k,
|
||||
"metric": query.metric,
|
||||
"nprobes": query.nprobes,
|
||||
"refine_factor": query.refine_factor,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _sanitize_schema(
|
||||
data: pa.Table,
|
||||
schema: pa.Schema = None,
|
||||
on_bad_vectors: str = "drop",
|
||||
on_bad_vectors: str = "error",
|
||||
fill_value: float = 0.0,
|
||||
) -> pa.Table:
|
||||
"""Ensure that the table has the expected schema.
|
||||
@@ -431,10 +554,10 @@ def _sanitize_schema(
|
||||
schema: pa.Schema; optional
|
||||
The expected schema. If not provided, this just converts the
|
||||
vector column to fixed_size_list(float32) if necessary.
|
||||
on_bad_vectors: str
|
||||
on_bad_vectors: str, default "error"
|
||||
What to do if any of the vectors are not the same size or contains NaNs.
|
||||
One of "raise", "drop", "fill".
|
||||
fill_value: float
|
||||
One of "error", "drop", "fill".
|
||||
fill_value: float, default 0.
|
||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||
"""
|
||||
if schema is not None:
|
||||
@@ -463,7 +586,7 @@ def _sanitize_schema(
|
||||
def _sanitize_vector_column(
|
||||
data: pa.Table,
|
||||
vector_column_name: str,
|
||||
on_bad_vectors: str = "drop",
|
||||
on_bad_vectors: str = "error",
|
||||
fill_value: float = 0.0,
|
||||
) -> pa.Table:
|
||||
"""
|
||||
@@ -475,10 +598,10 @@ def _sanitize_vector_column(
|
||||
The table to sanitize.
|
||||
vector_column_name: str
|
||||
The name of the vector column.
|
||||
on_bad_vectors: str
|
||||
on_bad_vectors: str, default "error"
|
||||
What to do if any of the vectors are not the same size or contains NaNs.
|
||||
One of "raise", "drop", "fill".
|
||||
fill_value: float
|
||||
One of "error", "drop", "fill".
|
||||
fill_value: float, default 0.0
|
||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||
"""
|
||||
if vector_column_name not in data.column_names:
|
||||
@@ -501,7 +624,7 @@ def _sanitize_vector_column(
|
||||
data.column_names.index(vector_column_name), vector_column_name, vec_arr
|
||||
)
|
||||
|
||||
has_nans = pc.any(vec_arr.values.is_nan()).as_py()
|
||||
has_nans = pc.any(pc.is_nan(vec_arr.values)).as_py()
|
||||
if has_nans:
|
||||
data = _sanitize_nans(
|
||||
data, fill_value, on_bad_vectors, vec_arr, vector_column_name
|
||||
@@ -524,7 +647,7 @@ def ensure_fixed_size_list_of_f32(vec_arr):
|
||||
|
||||
def _sanitize_jagged(data, fill_value, on_bad_vectors, vec_arr, vector_column_name):
|
||||
"""Sanitize jagged vectors."""
|
||||
if on_bad_vectors == "raise":
|
||||
if on_bad_vectors == "error":
|
||||
raise ValueError(
|
||||
f"Vector column {vector_column_name} has variable length vectors "
|
||||
"Set on_bad_vectors='drop' to remove them, or "
|
||||
@@ -538,7 +661,7 @@ def _sanitize_jagged(data, fill_value, on_bad_vectors, vec_arr, vector_column_na
|
||||
if on_bad_vectors == "fill":
|
||||
if fill_value is None:
|
||||
raise ValueError(
|
||||
f"`fill_value` must not be None if `on_bad_vectors` is 'fill'"
|
||||
"`fill_value` must not be None if `on_bad_vectors` is 'fill'"
|
||||
)
|
||||
fill_arr = pa.scalar([float(fill_value)] * ndims)
|
||||
vec_arr = pc.if_else(correct_ndims, vec_arr, fill_arr)
|
||||
@@ -552,7 +675,7 @@ def _sanitize_jagged(data, fill_value, on_bad_vectors, vec_arr, vector_column_na
|
||||
|
||||
def _sanitize_nans(data, fill_value, on_bad_vectors, vec_arr, vector_column_name):
|
||||
"""Sanitize NaNs in vectors"""
|
||||
if on_bad_vectors == "raise":
|
||||
if on_bad_vectors == "error":
|
||||
raise ValueError(
|
||||
f"Vector column {vector_column_name} has NaNs. "
|
||||
"Set on_bad_vectors='drop' to remove them, or "
|
||||
@@ -561,10 +684,10 @@ def _sanitize_nans(data, fill_value, on_bad_vectors, vec_arr, vector_column_name
|
||||
elif on_bad_vectors == "fill":
|
||||
if fill_value is None:
|
||||
raise ValueError(
|
||||
f"`fill_value` must not be None if `on_bad_vectors` is 'fill'"
|
||||
"`fill_value` must not be None if `on_bad_vectors` is 'fill'"
|
||||
)
|
||||
fill_value = float(fill_value)
|
||||
values = pc.if_else(vec_arr.values.is_nan(), fill_value, vec_arr.values)
|
||||
values = pc.if_else(pc.is_nan(vec_arr.values), fill_value, vec_arr.values)
|
||||
ndims = len(vec_arr[0])
|
||||
vec_arr = pa.FixedSizeListArray.from_arrays(values, ndims)
|
||||
data = data.set_column(
|
||||
|
||||
@@ -11,9 +11,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from urllib.parse import ParseResult, urlparse
|
||||
|
||||
from pyarrow import fs
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
||||
def get_uri_scheme(uri: str) -> str:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "lancedb"
|
||||
version = "0.1.9"
|
||||
version = "0.1.10"
|
||||
dependencies = ["pylance~=0.5.0", "ratelimiter", "retry", "tqdm", "aiohttp", "pydantic", "attr"]
|
||||
description = "lancedb"
|
||||
authors = [
|
||||
|
||||
@@ -20,18 +20,33 @@ import pyarrow as pa
|
||||
import pytest
|
||||
|
||||
from lancedb.db import LanceDBConnection
|
||||
from lancedb.query import LanceQueryBuilder
|
||||
from lancedb.query import LanceQueryBuilder, Query
|
||||
from lancedb.table import LanceTable
|
||||
|
||||
|
||||
class MockTable:
|
||||
def __init__(self, tmp_path):
|
||||
self.uri = tmp_path
|
||||
self._conn = LanceDBConnection("/tmp/lance/")
|
||||
self._conn = LanceDBConnection(self.uri)
|
||||
|
||||
def to_lance(self):
|
||||
return lance.dataset(self.uri)
|
||||
|
||||
def _execute_query(self, query):
|
||||
ds = self.to_lance()
|
||||
return ds.to_table(
|
||||
columns=query.columns,
|
||||
filter=query.filter,
|
||||
nearest={
|
||||
"column": query.vector_column,
|
||||
"q": query.vector,
|
||||
"k": query.k,
|
||||
"metric": query.metric,
|
||||
"nprobes": query.nprobes,
|
||||
"refine_factor": query.refine_factor,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def table(tmp_path) -> MockTable:
|
||||
@@ -94,20 +109,17 @@ def test_query_builder_with_different_vector_column():
|
||||
)
|
||||
ds = mock.Mock()
|
||||
table.to_lance.return_value = ds
|
||||
table._conn = mock.MagicMock()
|
||||
table._conn.is_managed_remote = False
|
||||
builder.to_arrow()
|
||||
ds.to_table.assert_called_once_with(
|
||||
columns=["b"],
|
||||
filter="b < 10",
|
||||
nearest={
|
||||
"column": vector_column_name,
|
||||
"q": query,
|
||||
"k": 2,
|
||||
"metric": "cosine",
|
||||
"nprobes": 20,
|
||||
"refine_factor": None,
|
||||
},
|
||||
table._execute_query.assert_called_once_with(
|
||||
Query(
|
||||
vector=query,
|
||||
filter="b < 10",
|
||||
k=2,
|
||||
metric="cosine",
|
||||
columns=["b"],
|
||||
nprobes=20,
|
||||
refine_factor=None,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
|
||||
import pyarrow as pa
|
||||
|
||||
from lancedb.db import LanceDBConnection
|
||||
import lancedb
|
||||
from lancedb.remote.client import VectorQuery, VectorQueryResult
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ class FakeLanceDBClient:
|
||||
|
||||
|
||||
def test_remote_db():
|
||||
conn = LanceDBConnection("lancedb+http://client-will-be-injected")
|
||||
conn = lancedb.connect("db://client-will-be-injected", api_key="fake")
|
||||
setattr(conn, "_client", FakeLanceDBClient())
|
||||
|
||||
table = conn["test"]
|
||||
|
||||
109
python/tests/test_schema.py
Normal file
109
python/tests/test_schema.py
Normal file
@@ -0,0 +1,109 @@
|
||||
# Copyright 2023 LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pyarrow as pa
|
||||
|
||||
import lancedb
|
||||
from lancedb.schema import dict_to_schema, schema_to_dict
|
||||
|
||||
|
||||
def test_schema_to_dict():
|
||||
schema = pa.schema(
|
||||
[
|
||||
pa.field("id", pa.int64()),
|
||||
pa.field("vector", lancedb.vector(512), nullable=False),
|
||||
pa.field(
|
||||
"struct",
|
||||
pa.struct(
|
||||
[
|
||||
pa.field("a", pa.utf8()),
|
||||
pa.field("b", pa.float32()),
|
||||
]
|
||||
),
|
||||
True,
|
||||
),
|
||||
pa.field("d", pa.dictionary(pa.int64(), pa.utf8()), False),
|
||||
],
|
||||
metadata={"key": "value"},
|
||||
)
|
||||
|
||||
json_schema = schema_to_dict(schema)
|
||||
assert json_schema == {
|
||||
"fields": [
|
||||
{"name": "id", "type": {"type": "int64"}, "nullable": True},
|
||||
{
|
||||
"name": "vector",
|
||||
"type": {
|
||||
"type": "fixed_size_list",
|
||||
"value_type": {"type": "float32"},
|
||||
"width": 512,
|
||||
},
|
||||
"nullable": False,
|
||||
},
|
||||
{
|
||||
"name": "struct",
|
||||
"type": {
|
||||
"type": "struct",
|
||||
"fields": [
|
||||
{"name": "a", "type": {"type": "string"}, "nullable": True},
|
||||
{"name": "b", "type": {"type": "float32"}, "nullable": True},
|
||||
],
|
||||
},
|
||||
"nullable": True,
|
||||
},
|
||||
{
|
||||
"name": "d",
|
||||
"type": {
|
||||
"type": "dictionary",
|
||||
"index_type": {"type": "int64"},
|
||||
"value_type": {"type": "string"},
|
||||
},
|
||||
"nullable": False,
|
||||
},
|
||||
],
|
||||
"metadata": {"key": "value"},
|
||||
}
|
||||
|
||||
actual_schema = dict_to_schema(json_schema)
|
||||
assert actual_schema == schema
|
||||
|
||||
|
||||
def test_temporal_types():
|
||||
schema = pa.schema(
|
||||
[
|
||||
pa.field("t32", pa.time32("s")),
|
||||
pa.field("t32ms", pa.time32("ms")),
|
||||
pa.field("t64", pa.time64("ns")),
|
||||
pa.field("ts", pa.timestamp("s")),
|
||||
pa.field("ts_us_tz", pa.timestamp("us", tz="America/New_York")),
|
||||
],
|
||||
)
|
||||
json_schema = schema_to_dict(schema)
|
||||
|
||||
assert json_schema == {
|
||||
"fields": [
|
||||
{"name": "t32", "type": {"type": "time32:s"}, "nullable": True},
|
||||
{"name": "t32ms", "type": {"type": "time32:ms"}, "nullable": True},
|
||||
{"name": "t64", "type": {"type": "time64:ns"}, "nullable": True},
|
||||
{"name": "ts", "type": {"type": "timestamp:s:"}, "nullable": True},
|
||||
{
|
||||
"name": "ts_us_tz",
|
||||
"type": {"type": "timestamp:us:America/New_York"},
|
||||
"nullable": True,
|
||||
},
|
||||
],
|
||||
"metadata": {},
|
||||
}
|
||||
|
||||
actual_schema = dict_to_schema(json_schema)
|
||||
assert actual_schema == schema
|
||||
@@ -19,6 +19,7 @@ import numpy as np
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
from lance.vector import vec_to_table
|
||||
|
||||
from lancedb.db import LanceDBConnection
|
||||
from lancedb.table import LanceTable
|
||||
@@ -89,7 +90,31 @@ def test_create_table(db):
|
||||
assert expected == tbl
|
||||
|
||||
|
||||
def test_empty_table(db):
|
||||
schema = pa.schema(
|
||||
[
|
||||
pa.field("vector", pa.list_(pa.float32(), 2)),
|
||||
pa.field("item", pa.string()),
|
||||
pa.field("price", pa.float32()),
|
||||
]
|
||||
)
|
||||
tbl = LanceTable.create(db, "test", schema=schema)
|
||||
data = [
|
||||
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
|
||||
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
|
||||
]
|
||||
tbl.add(data=data)
|
||||
|
||||
|
||||
def test_add(db):
|
||||
schema = pa.schema(
|
||||
[
|
||||
pa.field("vector", pa.list_(pa.float32(), 2)),
|
||||
pa.field("item", pa.string()),
|
||||
pa.field("price", pa.float64()),
|
||||
]
|
||||
)
|
||||
|
||||
table = LanceTable.create(
|
||||
db,
|
||||
"test",
|
||||
@@ -98,7 +123,19 @@ def test_add(db):
|
||||
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
|
||||
],
|
||||
)
|
||||
_add(table, schema)
|
||||
|
||||
table = LanceTable.create(db, "test2", schema=schema)
|
||||
table.add(
|
||||
data=[
|
||||
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
|
||||
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
|
||||
],
|
||||
)
|
||||
_add(table, schema)
|
||||
|
||||
|
||||
def _add(table, schema):
|
||||
# table = LanceTable(db, "test")
|
||||
assert len(table) == 2
|
||||
|
||||
@@ -113,13 +150,7 @@ def test_add(db):
|
||||
pa.array(["foo", "bar", "new"]),
|
||||
pa.array([10.0, 20.0, 30.0]),
|
||||
],
|
||||
schema=pa.schema(
|
||||
[
|
||||
pa.field("vector", pa.list_(pa.float32(), 2)),
|
||||
pa.field("item", pa.string()),
|
||||
pa.field("price", pa.float64()),
|
||||
]
|
||||
),
|
||||
schema=schema,
|
||||
)
|
||||
assert expected == table.to_arrow()
|
||||
|
||||
@@ -181,7 +212,21 @@ def test_create_index_method():
|
||||
|
||||
|
||||
def test_add_with_nans(db):
|
||||
# By default we drop bad input vectors
|
||||
# by default we raise an error on bad input vectors
|
||||
bad_data = [
|
||||
{"vector": [np.nan], "item": "bar", "price": 20.0},
|
||||
{"vector": [5], "item": "bar", "price": 20.0},
|
||||
{"vector": [np.nan, np.nan], "item": "bar", "price": 20.0},
|
||||
{"vector": [np.nan, 5.0], "item": "bar", "price": 20.0},
|
||||
]
|
||||
for row in bad_data:
|
||||
with pytest.raises(ValueError):
|
||||
LanceTable.create(
|
||||
db,
|
||||
"error_test",
|
||||
data=[{"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, row],
|
||||
)
|
||||
|
||||
table = LanceTable.create(
|
||||
db,
|
||||
"drop_test",
|
||||
@@ -191,6 +236,7 @@ def test_add_with_nans(db):
|
||||
{"vector": [5], "item": "bar", "price": 20.0},
|
||||
{"vector": [np.nan, np.nan], "item": "bar", "price": 20.0},
|
||||
],
|
||||
on_bad_vectors="drop",
|
||||
)
|
||||
assert len(table) == 1
|
||||
|
||||
@@ -210,18 +256,3 @@ def test_add_with_nans(db):
|
||||
arrow_tbl = table.to_lance().to_table(filter="item == 'bar'")
|
||||
v = arrow_tbl["vector"].to_pylist()[0]
|
||||
assert np.allclose(v, np.array([0.0, 0.0]))
|
||||
|
||||
bad_data = [
|
||||
{"vector": [np.nan], "item": "bar", "price": 20.0},
|
||||
{"vector": [5], "item": "bar", "price": 20.0},
|
||||
{"vector": [np.nan, np.nan], "item": "bar", "price": 20.0},
|
||||
{"vector": [np.nan, 5.0], "item": "bar", "price": 20.0},
|
||||
]
|
||||
for row in bad_data:
|
||||
with pytest.raises(ValueError):
|
||||
LanceTable.create(
|
||||
db,
|
||||
"raise_test",
|
||||
data=[{"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, row],
|
||||
on_bad_vectors="raise",
|
||||
)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "vectordb-node"
|
||||
version = "0.1.10"
|
||||
version = "0.1.12"
|
||||
description = "Serverless, low-latency vector database for AI applications"
|
||||
license = "Apache-2.0"
|
||||
edition = "2018"
|
||||
@@ -19,3 +19,6 @@ lance = { workspace = true }
|
||||
vectordb = { path = "../../vectordb" }
|
||||
tokio = { version = "1.23", features = ["rt-multi-thread"] }
|
||||
neon = {version = "0.10.1", default-features = false, features = ["channel-api", "napi-6", "promise-api", "task-api"] }
|
||||
object_store = { workspace = true, features = ["aws"] }
|
||||
async-trait = "0"
|
||||
env_logger = "0"
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
// limitations under the License.
|
||||
|
||||
use std::io::Cursor;
|
||||
use std::ops::Deref;
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow_array::cast::as_list_array;
|
||||
@@ -25,10 +24,13 @@ use lance::arrow::{FixedSizeListArrayExt, RecordBatchExt};
|
||||
pub(crate) fn convert_record_batch(record_batch: RecordBatch) -> RecordBatch {
|
||||
let column = record_batch
|
||||
.column_by_name("vector")
|
||||
.cloned()
|
||||
.expect("vector column is missing");
|
||||
let arr = as_list_array(column.deref());
|
||||
// TODO: we should just consume the underlaying js buffer in the future instead of this arrow around a bunch of times
|
||||
let arr = as_list_array(column.as_ref());
|
||||
let list_size = arr.values().len() / record_batch.num_rows();
|
||||
let r = FixedSizeListArray::try_new(arr.values(), list_size as i32).unwrap();
|
||||
let r =
|
||||
FixedSizeListArray::try_new_from_values(arr.values().to_owned(), list_size as i32).unwrap();
|
||||
|
||||
let schema = Arc::new(Schema::new(vec![Field::new(
|
||||
"vector",
|
||||
|
||||
@@ -17,19 +17,23 @@ use std::convert::TryFrom;
|
||||
use std::ops::Deref;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use arrow_array::{Float32Array, RecordBatchIterator, RecordBatchReader};
|
||||
use arrow_array::{Float32Array, RecordBatchIterator};
|
||||
use arrow_ipc::writer::FileWriter;
|
||||
use async_trait::async_trait;
|
||||
use futures::{TryFutureExt, TryStreamExt};
|
||||
use lance::dataset::{WriteMode, WriteParams};
|
||||
use lance::dataset::{ReadParams, WriteMode, WriteParams};
|
||||
use lance::index::vector::MetricType;
|
||||
use lance::io::object_store::ObjectStoreParams;
|
||||
use neon::prelude::*;
|
||||
use neon::types::buffer::TypedArray;
|
||||
use object_store::aws::{AwsCredential, AwsCredentialProvider};
|
||||
use object_store::CredentialProvider;
|
||||
use once_cell::sync::OnceCell;
|
||||
use tokio::runtime::Runtime;
|
||||
|
||||
use vectordb::database::Database;
|
||||
use vectordb::error::Error;
|
||||
use vectordb::table::Table;
|
||||
use vectordb::table::{OpenTableParams, Table};
|
||||
|
||||
use crate::arrow::arrow_buffer_to_record_batch;
|
||||
|
||||
@@ -49,8 +53,38 @@ struct JsTable {
|
||||
|
||||
impl Finalize for JsTable {}
|
||||
|
||||
// TODO: object_store didn't export this type so I copied it.
|
||||
// Make a requiest to object_store to export this type
|
||||
#[derive(Debug)]
|
||||
pub struct StaticCredentialProvider<T> {
|
||||
credential: Arc<T>,
|
||||
}
|
||||
|
||||
impl<T> StaticCredentialProvider<T> {
|
||||
pub fn new(credential: T) -> Self {
|
||||
Self {
|
||||
credential: Arc::new(credential),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<T> CredentialProvider for StaticCredentialProvider<T>
|
||||
where
|
||||
T: std::fmt::Debug + Send + Sync,
|
||||
{
|
||||
type Credential = T;
|
||||
|
||||
async fn get_credential(&self) -> object_store::Result<Arc<T>> {
|
||||
Ok(Arc::clone(&self.credential))
|
||||
}
|
||||
}
|
||||
|
||||
fn runtime<'a, C: Context<'a>>(cx: &mut C) -> NeonResult<&'static Runtime> {
|
||||
static RUNTIME: OnceCell<Runtime> = OnceCell::new();
|
||||
static LOG: OnceCell<()> = OnceCell::new();
|
||||
|
||||
LOG.get_or_init(|| env_logger::init());
|
||||
|
||||
RUNTIME.get_or_try_init(|| Runtime::new().or_else(|err| cx.throw_error(err.to_string())))
|
||||
}
|
||||
@@ -97,19 +131,74 @@ fn database_table_names(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
Ok(promise)
|
||||
}
|
||||
|
||||
fn get_aws_creds<T>(
|
||||
cx: &mut FunctionContext,
|
||||
arg_starting_location: i32,
|
||||
) -> Result<Option<AwsCredentialProvider>, NeonResult<T>> {
|
||||
let secret_key_id = cx
|
||||
.argument_opt(arg_starting_location)
|
||||
.map(|arg| arg.downcast_or_throw::<JsString, FunctionContext>(cx).ok())
|
||||
.flatten()
|
||||
.map(|v| v.value(cx));
|
||||
|
||||
let secret_key = cx
|
||||
.argument_opt(arg_starting_location + 1)
|
||||
.map(|arg| arg.downcast_or_throw::<JsString, FunctionContext>(cx).ok())
|
||||
.flatten()
|
||||
.map(|v| v.value(cx));
|
||||
|
||||
let temp_token = cx
|
||||
.argument_opt(arg_starting_location + 2)
|
||||
.map(|arg| arg.downcast_or_throw::<JsString, FunctionContext>(cx).ok())
|
||||
.flatten()
|
||||
.map(|v| v.value(cx));
|
||||
|
||||
match (secret_key_id, secret_key, temp_token) {
|
||||
(Some(key_id), Some(key), optional_token) => Ok(Some(Arc::new(
|
||||
StaticCredentialProvider::new(AwsCredential {
|
||||
key_id: key_id,
|
||||
secret_key: key,
|
||||
token: optional_token,
|
||||
}),
|
||||
))),
|
||||
(None, None, None) => Ok(None),
|
||||
_ => Err(cx.throw_error("Invalid credentials configuration")),
|
||||
}
|
||||
}
|
||||
|
||||
fn database_open_table(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
let db = cx
|
||||
.this()
|
||||
.downcast_or_throw::<JsBox<JsDatabase>, _>(&mut cx)?;
|
||||
let table_name = cx.argument::<JsString>(0)?.value(&mut cx);
|
||||
|
||||
let aws_creds = match get_aws_creds(&mut cx, 1) {
|
||||
Ok(creds) => creds,
|
||||
Err(err) => return err,
|
||||
};
|
||||
|
||||
let param = ReadParams {
|
||||
store_options: Some(ObjectStoreParams {
|
||||
aws_credentials: aws_creds,
|
||||
..ObjectStoreParams::default()
|
||||
}),
|
||||
..ReadParams::default()
|
||||
};
|
||||
|
||||
let rt = runtime(&mut cx)?;
|
||||
let channel = cx.channel();
|
||||
let database = db.database.clone();
|
||||
|
||||
let (deferred, promise) = cx.promise();
|
||||
rt.spawn(async move {
|
||||
let table_rst = database.open_table(&table_name).await;
|
||||
let table_rst = database
|
||||
.open_table_with_params(
|
||||
&table_name,
|
||||
OpenTableParams {
|
||||
open_table_params: param,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
deferred.settle_with(&channel, move |mut cx| {
|
||||
let table = Arc::new(Mutex::new(
|
||||
@@ -241,8 +330,6 @@ fn table_create(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
"create" => WriteMode::Create,
|
||||
_ => return cx.throw_error("Table::create only supports 'overwrite' and 'create' modes"),
|
||||
};
|
||||
let mut params = WriteParams::default();
|
||||
params.mode = mode;
|
||||
|
||||
let rt = runtime(&mut cx)?;
|
||||
let channel = cx.channel();
|
||||
@@ -250,11 +337,22 @@ fn table_create(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
let (deferred, promise) = cx.promise();
|
||||
let database = db.database.clone();
|
||||
|
||||
let aws_creds = match get_aws_creds(&mut cx, 3) {
|
||||
Ok(creds) => creds,
|
||||
Err(err) => return err,
|
||||
};
|
||||
|
||||
let params = WriteParams {
|
||||
store_params: Some(ObjectStoreParams {
|
||||
aws_credentials: aws_creds,
|
||||
..ObjectStoreParams::default()
|
||||
}),
|
||||
mode: mode,
|
||||
..WriteParams::default()
|
||||
};
|
||||
|
||||
rt.block_on(async move {
|
||||
let batch_reader: Box<dyn RecordBatchReader> = Box::new(RecordBatchIterator::new(
|
||||
batches.into_iter().map(Ok),
|
||||
schema,
|
||||
));
|
||||
let batch_reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema);
|
||||
let table_rst = database
|
||||
.create_table(&table_name, batch_reader, Some(params))
|
||||
.await;
|
||||
@@ -289,16 +387,27 @@ fn table_add(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
let table = js_table.table.clone();
|
||||
let write_mode = write_mode_map.get(write_mode.as_str()).cloned();
|
||||
|
||||
let aws_creds = match get_aws_creds(&mut cx, 2) {
|
||||
Ok(creds) => creds,
|
||||
Err(err) => return err,
|
||||
};
|
||||
|
||||
let params = WriteParams {
|
||||
store_params: Some(ObjectStoreParams {
|
||||
aws_credentials: aws_creds,
|
||||
..ObjectStoreParams::default()
|
||||
}),
|
||||
mode: write_mode.unwrap_or(WriteMode::Append),
|
||||
..WriteParams::default()
|
||||
};
|
||||
|
||||
rt.block_on(async move {
|
||||
let batch_reader: Box<dyn RecordBatchReader> = Box::new(RecordBatchIterator::new(
|
||||
batches.into_iter().map(Ok),
|
||||
schema,
|
||||
));
|
||||
let add_result = table.lock().unwrap().add(batch_reader, write_mode).await;
|
||||
let batch_reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema);
|
||||
let add_result = table.lock().unwrap().add(batch_reader, Some(params)).await;
|
||||
|
||||
deferred.settle_with(&channel, move |mut cx| {
|
||||
let added = add_result.or_else(|err| cx.throw_error(err.to_string()))?;
|
||||
Ok(cx.number(added as f64))
|
||||
let _added = add_result.or_else(|err| cx.throw_error(err.to_string()))?;
|
||||
Ok(cx.boolean(true))
|
||||
});
|
||||
});
|
||||
Ok(promise)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "vectordb"
|
||||
version = "0.1.10"
|
||||
version = "0.1.12"
|
||||
edition = "2021"
|
||||
description = "Serverless, low-latency vector database for AI applications"
|
||||
license = "Apache-2.0"
|
||||
|
||||
@@ -100,7 +100,7 @@ impl Database {
|
||||
pub async fn create_table(
|
||||
&self,
|
||||
name: &str,
|
||||
batches: Box<dyn RecordBatchReader>,
|
||||
batches: impl RecordBatchReader + Send + 'static,
|
||||
params: Option<WriteParams>,
|
||||
) -> Result<Table> {
|
||||
Table::create(&self.uri, name, batches, params).await
|
||||
|
||||
@@ -173,10 +173,8 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_setters_getters() {
|
||||
let mut batches: Box<dyn RecordBatchReader> = make_test_batches();
|
||||
let ds = Dataset::write(&mut batches, "memory://foo", None)
|
||||
.await
|
||||
.unwrap();
|
||||
let batches = make_test_batches();
|
||||
let ds = Dataset::write(batches, "memory://foo", None).await.unwrap();
|
||||
|
||||
let vector = Float32Array::from_iter_values([0.1, 0.2]);
|
||||
let query = Query::new(Arc::new(ds), vector.clone());
|
||||
@@ -202,10 +200,8 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_execute() {
|
||||
let mut batches: Box<dyn RecordBatchReader> = make_test_batches();
|
||||
let ds = Dataset::write(&mut batches, "memory://foo", None)
|
||||
.await
|
||||
.unwrap();
|
||||
let batches = make_test_batches();
|
||||
let ds = Dataset::write(batches, "memory://foo", None).await.unwrap();
|
||||
|
||||
let vector = Float32Array::from_iter_values([0.1; 128]);
|
||||
let query = Query::new(Arc::new(ds), vector.clone());
|
||||
@@ -213,7 +209,7 @@ mod tests {
|
||||
assert_eq!(result.is_ok(), true);
|
||||
}
|
||||
|
||||
fn make_test_batches() -> Box<dyn RecordBatchReader> {
|
||||
fn make_test_batches() -> impl RecordBatchReader + Send + 'static {
|
||||
let dim: usize = 128;
|
||||
let schema = Arc::new(ArrowSchema::new(vec![
|
||||
ArrowField::new("key", DataType::Int32, false),
|
||||
@@ -227,11 +223,11 @@ mod tests {
|
||||
),
|
||||
ArrowField::new("uri", DataType::Utf8, true),
|
||||
]));
|
||||
Box::new(RecordBatchIterator::new(
|
||||
RecordBatchIterator::new(
|
||||
vec![RecordBatch::new_empty(schema.clone())]
|
||||
.into_iter()
|
||||
.map(Ok),
|
||||
schema,
|
||||
))
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,8 +22,8 @@ use snafu::prelude::*;
|
||||
|
||||
use crate::error::{Error, InvalidTableNameSnafu, Result};
|
||||
use crate::index::vector::VectorIndexBuilder;
|
||||
use crate::WriteMode;
|
||||
use crate::query::Query;
|
||||
use crate::WriteMode;
|
||||
|
||||
pub const VECTOR_COLUMN_NAME: &str = "vector";
|
||||
pub const LANCE_FILE_EXTENSION: &str = "lance";
|
||||
@@ -117,7 +117,7 @@ impl Table {
|
||||
pub async fn create(
|
||||
base_uri: &str,
|
||||
name: &str,
|
||||
mut batches: Box<dyn RecordBatchReader>,
|
||||
batches: impl RecordBatchReader + Send + 'static,
|
||||
params: Option<WriteParams>,
|
||||
) -> Result<Self> {
|
||||
let base_path = Path::new(base_uri);
|
||||
@@ -127,7 +127,7 @@ impl Table {
|
||||
.to_str()
|
||||
.context(InvalidTableNameSnafu { name })?
|
||||
.to_string();
|
||||
let dataset = Dataset::write(&mut batches, &uri, params)
|
||||
let dataset = Dataset::write(batches, &uri, params)
|
||||
.await
|
||||
.map_err(|e| match e {
|
||||
lance::Error::DatasetAlreadyExists { .. } => Error::TableAlreadyExists {
|
||||
@@ -176,14 +176,16 @@ impl Table {
|
||||
/// * The number of rows added
|
||||
pub async fn add(
|
||||
&mut self,
|
||||
mut batches: Box<dyn RecordBatchReader>,
|
||||
write_mode: Option<WriteMode>,
|
||||
) -> Result<usize> {
|
||||
let mut params = WriteParams::default();
|
||||
params.mode = write_mode.unwrap_or(WriteMode::Append);
|
||||
batches: impl RecordBatchReader + Send + 'static,
|
||||
params: Option<WriteParams>,
|
||||
) -> Result<()> {
|
||||
let params = params.unwrap_or(WriteParams {
|
||||
mode: WriteMode::Append,
|
||||
..WriteParams::default()
|
||||
});
|
||||
|
||||
self.dataset = Arc::new(Dataset::write(&mut batches, &self.uri, Some(params)).await?);
|
||||
Ok(batches.count())
|
||||
self.dataset = Arc::new(Dataset::write(batches, &self.uri, Some(params)).await?);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Creates a new Query object that can be executed.
|
||||
@@ -207,12 +209,12 @@ impl Table {
|
||||
/// Merge new data into this table.
|
||||
pub async fn merge(
|
||||
&mut self,
|
||||
mut batches: Box<dyn RecordBatchReader>,
|
||||
batches: impl RecordBatchReader + Send + 'static,
|
||||
left_on: &str,
|
||||
right_on: &str,
|
||||
) -> Result<()> {
|
||||
let mut dataset = self.dataset.as_ref().clone();
|
||||
dataset.merge(&mut batches, left_on, right_on).await?;
|
||||
dataset.merge(batches, left_on, right_on).await?;
|
||||
self.dataset = Arc::new(dataset);
|
||||
Ok(())
|
||||
}
|
||||
@@ -253,8 +255,8 @@ mod tests {
|
||||
let dataset_path = tmp_dir.path().join("test.lance");
|
||||
let uri = tmp_dir.path().to_str().unwrap();
|
||||
|
||||
let mut batches: Box<dyn RecordBatchReader> = make_test_batches();
|
||||
Dataset::write(&mut batches, dataset_path.to_str().unwrap(), None)
|
||||
let batches = make_test_batches();
|
||||
Dataset::write(batches, dataset_path.to_str().unwrap(), None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -284,11 +286,11 @@ mod tests {
|
||||
let tmp_dir = tempdir().unwrap();
|
||||
let uri = tmp_dir.path().to_str().unwrap();
|
||||
|
||||
let batches: Box<dyn RecordBatchReader> = make_test_batches();
|
||||
let batches = make_test_batches();
|
||||
let _ = batches.schema().clone();
|
||||
Table::create(&uri, "test", batches, None).await.unwrap();
|
||||
|
||||
let batches: Box<dyn RecordBatchReader> = make_test_batches();
|
||||
let batches = make_test_batches();
|
||||
let result = Table::create(&uri, "test", batches, None).await;
|
||||
assert!(matches!(
|
||||
result.unwrap_err(),
|
||||
@@ -301,12 +303,12 @@ mod tests {
|
||||
let tmp_dir = tempdir().unwrap();
|
||||
let uri = tmp_dir.path().to_str().unwrap();
|
||||
|
||||
let batches: Box<dyn RecordBatchReader> = make_test_batches();
|
||||
let batches = make_test_batches();
|
||||
let schema = batches.schema().clone();
|
||||
let mut table = Table::create(&uri, "test", batches, None).await.unwrap();
|
||||
assert_eq!(table.count_rows().await.unwrap(), 10);
|
||||
|
||||
let new_batches: Box<dyn RecordBatchReader> = Box::new(RecordBatchIterator::new(
|
||||
let new_batches = RecordBatchIterator::new(
|
||||
vec![RecordBatch::try_new(
|
||||
schema.clone(),
|
||||
vec![Arc::new(Int32Array::from_iter_values(100..110))],
|
||||
@@ -315,7 +317,7 @@ mod tests {
|
||||
.into_iter()
|
||||
.map(Ok),
|
||||
schema.clone(),
|
||||
));
|
||||
);
|
||||
|
||||
table.add(new_batches, None).await.unwrap();
|
||||
assert_eq!(table.count_rows().await.unwrap(), 20);
|
||||
@@ -327,12 +329,12 @@ mod tests {
|
||||
let tmp_dir = tempdir().unwrap();
|
||||
let uri = tmp_dir.path().to_str().unwrap();
|
||||
|
||||
let batches: Box<dyn RecordBatchReader> = make_test_batches();
|
||||
let batches = make_test_batches();
|
||||
let schema = batches.schema().clone();
|
||||
let mut table = Table::create(uri, "test", batches, None).await.unwrap();
|
||||
assert_eq!(table.count_rows().await.unwrap(), 10);
|
||||
|
||||
let new_batches: Box<dyn RecordBatchReader> = Box::new(RecordBatchIterator::new(
|
||||
let new_batches = RecordBatchIterator::new(
|
||||
vec![RecordBatch::try_new(
|
||||
schema.clone(),
|
||||
vec![Arc::new(Int32Array::from_iter_values(100..110))],
|
||||
@@ -341,10 +343,15 @@ mod tests {
|
||||
.into_iter()
|
||||
.map(Ok),
|
||||
schema.clone(),
|
||||
));
|
||||
);
|
||||
|
||||
let param: WriteParams = WriteParams {
|
||||
mode: WriteMode::Overwrite,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
table
|
||||
.add(new_batches, Some(WriteMode::Overwrite))
|
||||
.add(new_batches, Some(param))
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(table.count_rows().await.unwrap(), 10);
|
||||
@@ -357,8 +364,8 @@ mod tests {
|
||||
let dataset_path = tmp_dir.path().join("test.lance");
|
||||
let uri = tmp_dir.path().to_str().unwrap();
|
||||
|
||||
let mut batches: Box<dyn RecordBatchReader> = make_test_batches();
|
||||
Dataset::write(&mut batches, dataset_path.to_str().unwrap(), None)
|
||||
let batches = make_test_batches();
|
||||
Dataset::write(batches, dataset_path.to_str().unwrap(), None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -369,7 +376,7 @@ mod tests {
|
||||
assert_eq!(vector, query.query_vector);
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
#[derive(Default, Debug)]
|
||||
struct NoOpCacheWrapper {
|
||||
called: AtomicBool,
|
||||
}
|
||||
@@ -396,8 +403,8 @@ mod tests {
|
||||
let dataset_path = tmp_dir.path().join("test.lance");
|
||||
let uri = tmp_dir.path().to_str().unwrap();
|
||||
|
||||
let mut batches: Box<dyn RecordBatchReader> = make_test_batches();
|
||||
Dataset::write(&mut batches, dataset_path.to_str().unwrap(), None)
|
||||
let batches = make_test_batches();
|
||||
Dataset::write(batches, dataset_path.to_str().unwrap(), None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -417,15 +424,15 @@ mod tests {
|
||||
assert!(wrapper.called());
|
||||
}
|
||||
|
||||
fn make_test_batches() -> Box<dyn RecordBatchReader> {
|
||||
fn make_test_batches() -> impl RecordBatchReader + Send + Sync + 'static {
|
||||
let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)]));
|
||||
Box::new(RecordBatchIterator::new(
|
||||
RecordBatchIterator::new(
|
||||
vec![RecordBatch::try_new(
|
||||
schema.clone(),
|
||||
vec![Arc::new(Int32Array::from_iter_values(0..10))],
|
||||
)],
|
||||
schema,
|
||||
))
|
||||
)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -465,9 +472,7 @@ mod tests {
|
||||
schema,
|
||||
);
|
||||
|
||||
let reader: Box<dyn RecordBatchReader + Send> = Box::new(batches);
|
||||
let mut table = Table::create(uri, "test", reader, None).await.unwrap();
|
||||
|
||||
let mut table = Table::create(uri, "test", batches, None).await.unwrap();
|
||||
let mut i = IvfPQIndexBuilder::new();
|
||||
|
||||
let index_builder = i
|
||||
|
||||
Reference in New Issue
Block a user