mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-23 13:29:57 +00:00
Compare commits
9 Commits
python-v0.
...
v0.2.6
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
31c5df6d99 | ||
|
|
dbf37a0434 | ||
|
|
f20f19b804 | ||
|
|
55207ce844 | ||
|
|
c21f9cdda0 | ||
|
|
bc38abb781 | ||
|
|
731f86e44c | ||
|
|
31dad71c94 | ||
|
|
9585f550b3 |
@@ -1,5 +1,5 @@
|
|||||||
[bumpversion]
|
[bumpversion]
|
||||||
current_version = 0.2.5
|
current_version = 0.2.6
|
||||||
commit = True
|
commit = True
|
||||||
message = Bump version: {current_version} → {new_version}
|
message = Bump version: {current_version} → {new_version}
|
||||||
tag = True
|
tag = True
|
||||||
|
|||||||
3
.github/workflows/node.yml
vendored
3
.github/workflows/node.yml
vendored
@@ -9,6 +9,7 @@ on:
|
|||||||
- node/**
|
- node/**
|
||||||
- rust/ffi/node/**
|
- rust/ffi/node/**
|
||||||
- .github/workflows/node.yml
|
- .github/workflows/node.yml
|
||||||
|
- docker-compose.yml
|
||||||
|
|
||||||
env:
|
env:
|
||||||
# Disable full debug symbol generation to speed up CI build and keep memory down
|
# Disable full debug symbol generation to speed up CI build and keep memory down
|
||||||
@@ -133,7 +134,7 @@ jobs:
|
|||||||
cache: 'npm'
|
cache: 'npm'
|
||||||
cache-dependency-path: node/package-lock.json
|
cache-dependency-path: node/package-lock.json
|
||||||
- name: start local stack
|
- name: start local stack
|
||||||
run: docker compose -f ../docker-compose.yml up -d
|
run: docker compose -f ../docker-compose.yml up -d --wait
|
||||||
- name: create s3
|
- name: create s3
|
||||||
run: aws s3 mb s3://lancedb-integtest --endpoint $AWS_ENDPOINT
|
run: aws s3 mb s3://lancedb-integtest --endpoint $AWS_ENDPOINT
|
||||||
- name: create ddb
|
- name: create ddb
|
||||||
|
|||||||
34
.github/workflows/python.yml
vendored
34
.github/workflows/python.yml
vendored
@@ -38,7 +38,7 @@ jobs:
|
|||||||
- name: isort
|
- name: isort
|
||||||
run: isort --check --diff --quiet .
|
run: isort --check --diff --quiet .
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: pytest -x -v --durations=30 tests
|
run: pytest -m "not slow" -x -v --durations=30 tests
|
||||||
- name: doctest
|
- name: doctest
|
||||||
run: pytest --doctest-modules lancedb
|
run: pytest --doctest-modules lancedb
|
||||||
mac:
|
mac:
|
||||||
@@ -65,4 +65,34 @@ jobs:
|
|||||||
- name: Black
|
- name: Black
|
||||||
run: black --check --diff --no-color --quiet .
|
run: black --check --diff --no-color --quiet .
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: pytest -x -v --durations=30 tests
|
run: pytest -m "not slow" -x -v --durations=30 tests
|
||||||
|
pydantic1x:
|
||||||
|
timeout-minutes: 30
|
||||||
|
runs-on: "ubuntu-22.04"
|
||||||
|
defaults:
|
||||||
|
run:
|
||||||
|
shell: bash
|
||||||
|
working-directory: python
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v3
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
lfs: true
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v4
|
||||||
|
with:
|
||||||
|
python-version: 3.9
|
||||||
|
- name: Install lancedb
|
||||||
|
run: |
|
||||||
|
pip install "pydantic<2"
|
||||||
|
pip install -e .[tests]
|
||||||
|
pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985
|
||||||
|
pip install pytest pytest-mock black isort
|
||||||
|
- name: Black
|
||||||
|
run: black --check --diff --no-color --quiet .
|
||||||
|
- name: isort
|
||||||
|
run: isort --check --diff --quiet .
|
||||||
|
- name: Run tests
|
||||||
|
run: pytest -m "not slow" -x -v --durations=30 tests
|
||||||
|
- name: doctest
|
||||||
|
run: pytest --doctest-modules lancedb
|
||||||
@@ -5,8 +5,8 @@ exclude = ["python"]
|
|||||||
resolver = "2"
|
resolver = "2"
|
||||||
|
|
||||||
[workspace.dependencies]
|
[workspace.dependencies]
|
||||||
lance = { "version" = "=0.7.4", "features" = ["dynamodb"] }
|
lance = { "version" = "=0.7.5", "features" = ["dynamodb"] }
|
||||||
lance-linalg = { "version" = "=0.7.4" }
|
lance-linalg = { "version" = "=0.7.5" }
|
||||||
# Note that this one does not include pyarrow
|
# Note that this one does not include pyarrow
|
||||||
arrow = { version = "43.0.0", optional = false }
|
arrow = { version = "43.0.0", optional = false }
|
||||||
arrow-array = "43.0"
|
arrow-array = "43.0"
|
||||||
|
|||||||
@@ -13,3 +13,6 @@ services:
|
|||||||
- AWS_SECRET_ACCESS_KEY=SECRETKEY
|
- AWS_SECRET_ACCESS_KEY=SECRETKEY
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: [ "CMD", "curl", "-f", "http://localhost:4566/health" ]
|
test: [ "CMD", "curl", "-f", "http://localhost:4566/health" ]
|
||||||
|
interval: 5s
|
||||||
|
retries: 3
|
||||||
|
start_period: 10s
|
||||||
|
|||||||
@@ -84,7 +84,17 @@ A Table is a collection of Records in a LanceDB Database. You can follow along o
|
|||||||
```
|
```
|
||||||
|
|
||||||
### From Pydantic Models
|
### From Pydantic Models
|
||||||
LanceDB supports to create Apache Arrow Schema from a Pydantic BaseModel via pydantic_to_schema() method.
|
When you create an empty table without data, you must specify the table schema.
|
||||||
|
LanceDB supports creating tables by specifying a pyarrow schema or a specialized
|
||||||
|
pydantic model called `LanceModel`.
|
||||||
|
|
||||||
|
For example, the following Content model specifies a table with 5 columns:
|
||||||
|
movie_id, vector, genres, title, and imdb_id. When you create a table, you can
|
||||||
|
pass the class as the value of the `schema` parameter to `create_table`.
|
||||||
|
The `vector` column is a `Vector` type, which is a specialized pydantic type that
|
||||||
|
can be configured with the vector dimensions. It is also important to note that
|
||||||
|
LanceDB only understands subclasses of `lancedb.pydantic.LanceModel`
|
||||||
|
(which itself derives from `pydantic.BaseModel`).
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from lancedb.pydantic import Vector, LanceModel
|
from lancedb.pydantic import Vector, LanceModel
|
||||||
|
|||||||
@@ -26,15 +26,19 @@ pip install lancedb
|
|||||||
|
|
||||||
## Embeddings
|
## Embeddings
|
||||||
|
|
||||||
::: lancedb.embeddings.with_embeddings
|
|
||||||
|
|
||||||
::: lancedb.embeddings.functions.EmbeddingFunctionRegistry
|
::: lancedb.embeddings.functions.EmbeddingFunctionRegistry
|
||||||
|
|
||||||
::: lancedb.embeddings.functions.EmbeddingFunctionModel
|
::: lancedb.embeddings.functions.EmbeddingFunction
|
||||||
|
|
||||||
::: lancedb.embeddings.functions.TextEmbeddingFunctionModel
|
::: lancedb.embeddings.functions.TextEmbeddingFunction
|
||||||
|
|
||||||
::: lancedb.embeddings.functions.SentenceTransformerEmbeddingFunction
|
::: lancedb.embeddings.functions.SentenceTransformerEmbeddings
|
||||||
|
|
||||||
|
::: lancedb.embeddings.functions.OpenAIEmbeddings
|
||||||
|
|
||||||
|
::: lancedb.embeddings.functions.OpenClipEmbeddings
|
||||||
|
|
||||||
|
::: lancedb.embeddings.with_embeddings
|
||||||
|
|
||||||
## Context
|
## Context
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "vectordb",
|
"name": "vectordb",
|
||||||
"version": "0.2.5",
|
"version": "0.2.6",
|
||||||
"description": " Serverless, low-latency vector database for AI applications",
|
"description": " Serverless, low-latency vector database for AI applications",
|
||||||
"main": "dist/index.js",
|
"main": "dist/index.js",
|
||||||
"types": "dist/index.d.ts",
|
"types": "dist/index.d.ts",
|
||||||
@@ -81,10 +81,10 @@
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
"optionalDependencies": {
|
"optionalDependencies": {
|
||||||
"@lancedb/vectordb-darwin-arm64": "0.2.5",
|
"@lancedb/vectordb-darwin-arm64": "0.2.6",
|
||||||
"@lancedb/vectordb-darwin-x64": "0.2.5",
|
"@lancedb/vectordb-darwin-x64": "0.2.6",
|
||||||
"@lancedb/vectordb-linux-arm64-gnu": "0.2.5",
|
"@lancedb/vectordb-linux-arm64-gnu": "0.2.6",
|
||||||
"@lancedb/vectordb-linux-x64-gnu": "0.2.5",
|
"@lancedb/vectordb-linux-x64-gnu": "0.2.6",
|
||||||
"@lancedb/vectordb-win32-x64-msvc": "0.2.5"
|
"@lancedb/vectordb-win32-x64-msvc": "0.2.6"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ chai.use(chaiAsPromised)
|
|||||||
|
|
||||||
describe('LanceDB AWS Integration test', function () {
|
describe('LanceDB AWS Integration test', function () {
|
||||||
it('s3+ddb schema is processed correctly', async function () {
|
it('s3+ddb schema is processed correctly', async function () {
|
||||||
this.timeout(5000)
|
this.timeout(15000)
|
||||||
|
|
||||||
// WARNING: specifying engine is NOT a publicly supported feature in lancedb yet
|
// WARNING: specifying engine is NOT a publicly supported feature in lancedb yet
|
||||||
// THE API WILL CHANGE
|
// THE API WILL CHANGE
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ import * as chaiAsPromised from 'chai-as-promised'
|
|||||||
|
|
||||||
import * as lancedb from '../index'
|
import * as lancedb from '../index'
|
||||||
import { type AwsCredentials, type EmbeddingFunction, MetricType, Query, WriteMode, DefaultWriteOptions, isWriteOptions } from '../index'
|
import { type AwsCredentials, type EmbeddingFunction, MetricType, Query, WriteMode, DefaultWriteOptions, isWriteOptions } from '../index'
|
||||||
import { Field, Int32, makeVector, Schema, Utf8, Table as ArrowTable, vectorFromArray } from 'apache-arrow'
|
import { FixedSizeList, Field, Int32, makeVector, Schema, Utf8, Table as ArrowTable, vectorFromArray, Float32 } from 'apache-arrow'
|
||||||
|
|
||||||
const expect = chai.expect
|
const expect = chai.expect
|
||||||
const assert = chai.assert
|
const assert = chai.assert
|
||||||
@@ -258,6 +258,36 @@ describe('LanceDB client', function () {
|
|||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
describe('when searching an empty dataset', function () {
|
||||||
|
it('should not fail', async function () {
|
||||||
|
const dir = await track().mkdir('lancejs')
|
||||||
|
const con = await lancedb.connect(dir)
|
||||||
|
|
||||||
|
const schema = new Schema(
|
||||||
|
[new Field('vector', new FixedSizeList(128, new Field('float32', new Float32())))]
|
||||||
|
)
|
||||||
|
const table = await con.createTable({ name: 'vectors', schema })
|
||||||
|
const result = await table.search(Array(128).fill(0.1)).execute()
|
||||||
|
assert.isEmpty(result)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('when searching an empty-after-delete dataset', function () {
|
||||||
|
it('should not fail', async function () {
|
||||||
|
const dir = await track().mkdir('lancejs')
|
||||||
|
const con = await lancedb.connect(dir)
|
||||||
|
|
||||||
|
const schema = new Schema(
|
||||||
|
[new Field('vector', new FixedSizeList(128, new Field('float32', new Float32())))]
|
||||||
|
)
|
||||||
|
const table = await con.createTable({ name: 'vectors', schema })
|
||||||
|
await table.add([{ vector: Array(128).fill(0.1) }])
|
||||||
|
await table.delete('vector IS NOT NULL')
|
||||||
|
const result = await table.search(Array(128).fill(0.1)).execute()
|
||||||
|
assert.isEmpty(result)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
describe('when creating a vector index', function () {
|
describe('when creating a vector index', function () {
|
||||||
it('overwrite all records in a table', async function () {
|
it('overwrite all records in a table', async function () {
|
||||||
const uri = await createTestDB(32, 300)
|
const uri = await createTestDB(32, 300)
|
||||||
|
|||||||
@@ -11,12 +11,15 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import importlib.metadata
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from .db import URI, DBConnection, LanceDBConnection
|
from .db import URI, DBConnection, LanceDBConnection
|
||||||
from .remote.db import RemoteDBConnection
|
from .remote.db import RemoteDBConnection
|
||||||
from .schema import vector
|
from .schema import vector
|
||||||
|
|
||||||
|
__version__ = importlib.metadata.version("lancedb")
|
||||||
|
|
||||||
|
|
||||||
def connect(
|
def connect(
|
||||||
uri: URI,
|
uri: URI,
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
import pyarrow as pa
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from lancedb.embeddings import EmbeddingFunctionModel, EmbeddingFunctionRegistry
|
from .embeddings import EmbeddingFunctionRegistry, TextEmbeddingFunction
|
||||||
|
|
||||||
# import lancedb so we don't have to in every example
|
# import lancedb so we don't have to in every example
|
||||||
|
|
||||||
@@ -22,17 +22,19 @@ def doctest_setup(monkeypatch, tmpdir):
|
|||||||
registry = EmbeddingFunctionRegistry.get_instance()
|
registry = EmbeddingFunctionRegistry.get_instance()
|
||||||
|
|
||||||
|
|
||||||
@registry.register()
|
@registry.register("test")
|
||||||
class MockEmbeddingFunction(EmbeddingFunctionModel):
|
class MockTextEmbeddingFunction(TextEmbeddingFunction):
|
||||||
def __call__(self, data):
|
"""
|
||||||
if isinstance(data, str):
|
Return the hash of the first 10 characters
|
||||||
data = [data]
|
"""
|
||||||
elif isinstance(data, pa.ChunkedArray):
|
|
||||||
data = data.combine_chunks().to_pylist()
|
|
||||||
elif isinstance(data, pa.Array):
|
|
||||||
data = data.to_pylist()
|
|
||||||
|
|
||||||
return [self.embed(row) for row in data]
|
def generate_embeddings(self, texts):
|
||||||
|
return [self._compute_one_embedding(row) for row in texts]
|
||||||
|
|
||||||
def embed(self, row):
|
def _compute_one_embedding(self, row):
|
||||||
return [float(hash(c)) for c in row[:10]]
|
emb = np.array([float(hash(c)) for c in row[:10]])
|
||||||
|
emb /= np.linalg.norm(emb)
|
||||||
|
return emb
|
||||||
|
|
||||||
|
def ndims(self):
|
||||||
|
return 10
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ import pyarrow as pa
|
|||||||
from pyarrow import fs
|
from pyarrow import fs
|
||||||
|
|
||||||
from .common import DATA, URI
|
from .common import DATA, URI
|
||||||
from .embeddings import EmbeddingFunctionModel
|
from .embeddings import EmbeddingFunctionConfig
|
||||||
from .pydantic import LanceModel
|
from .pydantic import LanceModel
|
||||||
from .table import LanceTable, Table
|
from .table import LanceTable, Table
|
||||||
from .util import fs_from_uri, get_uri_location, get_uri_scheme
|
from .util import fs_from_uri, get_uri_location, get_uri_scheme
|
||||||
@@ -290,7 +290,7 @@ class LanceDBConnection(DBConnection):
|
|||||||
mode: str = "create",
|
mode: str = "create",
|
||||||
on_bad_vectors: str = "error",
|
on_bad_vectors: str = "error",
|
||||||
fill_value: float = 0.0,
|
fill_value: float = 0.0,
|
||||||
embedding_functions: Optional[List[EmbeddingFunctionModel]] = None,
|
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
|
||||||
) -> LanceTable:
|
) -> LanceTable:
|
||||||
"""Create a table in the database.
|
"""Create a table in the database.
|
||||||
|
|
||||||
|
|||||||
@@ -13,10 +13,12 @@
|
|||||||
|
|
||||||
|
|
||||||
from .functions import (
|
from .functions import (
|
||||||
REGISTRY,
|
EmbeddingFunction,
|
||||||
EmbeddingFunctionModel,
|
EmbeddingFunctionConfig,
|
||||||
EmbeddingFunctionRegistry,
|
EmbeddingFunctionRegistry,
|
||||||
SentenceTransformerEmbeddingFunction,
|
OpenAIEmbeddings,
|
||||||
TextEmbeddingFunctionModel,
|
OpenClipEmbeddings,
|
||||||
|
SentenceTransformerEmbeddings,
|
||||||
|
TextEmbeddingFunction,
|
||||||
)
|
)
|
||||||
from .utils import with_embeddings
|
from .utils import with_embeddings
|
||||||
|
|||||||
@@ -10,43 +10,78 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import concurrent.futures
|
||||||
|
import importlib
|
||||||
|
import io
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
|
import socket
|
||||||
|
import urllib.error
|
||||||
|
import urllib.parse as urlparse
|
||||||
|
import urllib.request
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import List, Optional, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
from cachetools import cached
|
from cachetools import cached
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, Field, PrivateAttr
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingFunctionRegistry:
|
class EmbeddingFunctionRegistry:
|
||||||
"""
|
"""
|
||||||
This is a singleton class used to register embedding functions
|
This is a singleton class used to register embedding functions
|
||||||
and fetch them by name. It also handles serializing and deserializing
|
and fetch them by name. It also handles serializing and deserializing.
|
||||||
|
You can implement your own embedding function by subclassing EmbeddingFunction
|
||||||
|
or TextEmbeddingFunction and registering it with the registry.
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
>>> registry = EmbeddingFunctionRegistry.get_instance()
|
||||||
|
>>> @registry.register("my-embedding-function")
|
||||||
|
... class MyEmbeddingFunction(EmbeddingFunction):
|
||||||
|
... def ndims(self) -> int:
|
||||||
|
... return 128
|
||||||
|
...
|
||||||
|
... def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
|
||||||
|
... return self.compute_source_embeddings(query, *args, **kwargs)
|
||||||
|
...
|
||||||
|
... def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
|
||||||
|
... return [np.random.rand(self.ndims()) for _ in range(len(texts))]
|
||||||
|
...
|
||||||
|
>>> registry.get("my-embedding-function")
|
||||||
|
<class 'lancedb.embeddings.functions.MyEmbeddingFunction'>
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_instance(cls):
|
def get_instance(cls):
|
||||||
return REGISTRY
|
return __REGISTRY__
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._functions = {}
|
self._functions = {}
|
||||||
|
|
||||||
def register(self):
|
def register(self, alias: str = None):
|
||||||
"""
|
"""
|
||||||
This creates a decorator that can be used to register
|
This creates a decorator that can be used to register
|
||||||
an EmbeddingFunctionModel.
|
an EmbeddingFunction.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
alias : Optional[str]
|
||||||
|
a human friendly name for the embedding function. If not
|
||||||
|
provided, the class name will be used.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# This is a decorator for a class that inherits from BaseModel
|
# This is a decorator for a class that inherits from BaseModel
|
||||||
# It adds the class to the registry
|
# It adds the class to the registry
|
||||||
def decorator(cls):
|
def decorator(cls):
|
||||||
if not issubclass(cls, EmbeddingFunctionModel):
|
if not issubclass(cls, EmbeddingFunction):
|
||||||
raise TypeError("Must be a subclass of EmbeddingFunctionModel")
|
raise TypeError("Must be a subclass of EmbeddingFunction")
|
||||||
if cls.__name__ in self._functions:
|
if cls.__name__ in self._functions:
|
||||||
raise KeyError(f"{cls.__name__} was already registered")
|
raise KeyError(f"{cls.__name__} was already registered")
|
||||||
self._functions[cls.__name__] = cls
|
key = alias or cls.__name__
|
||||||
|
self._functions[key] = cls
|
||||||
|
cls.__embedding_function_registry_alias__ = alias
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
@@ -57,13 +92,22 @@ class EmbeddingFunctionRegistry:
|
|||||||
"""
|
"""
|
||||||
self._functions = {}
|
self._functions = {}
|
||||||
|
|
||||||
def load(self, name: str):
|
def get(self, name: str):
|
||||||
"""
|
"""
|
||||||
Fetch an embedding function class by name
|
Fetch an embedding function class by name
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
name : str
|
||||||
|
The name of the embedding function to fetch
|
||||||
|
Either the alias or the class name if no alias was provided
|
||||||
|
during registration
|
||||||
"""
|
"""
|
||||||
return self._functions[name]
|
return self._functions[name]
|
||||||
|
|
||||||
def parse_functions(self, metadata: Optional[dict]) -> dict:
|
def parse_functions(
|
||||||
|
self, metadata: Optional[Dict[bytes, bytes]]
|
||||||
|
) -> Dict[str, "EmbeddingFunctionConfig"]:
|
||||||
"""
|
"""
|
||||||
Parse the metadata from an arrow table and
|
Parse the metadata from an arrow table and
|
||||||
return a mapping of the vector column to the
|
return a mapping of the vector column to the
|
||||||
@@ -71,9 +115,9 @@ class EmbeddingFunctionRegistry:
|
|||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
metadata : Optional[dict]
|
metadata : Optional[Dict[bytes, bytes]]
|
||||||
The metadata from an arrow table. Note that
|
The metadata from an arrow table. Note that
|
||||||
the keys and values are bytes.
|
the keys and values are bytes (pyarrow api)
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
@@ -86,68 +130,94 @@ class EmbeddingFunctionRegistry:
|
|||||||
return {}
|
return {}
|
||||||
serialized = metadata[b"embedding_functions"]
|
serialized = metadata[b"embedding_functions"]
|
||||||
raw_list = json.loads(serialized.decode("utf-8"))
|
raw_list = json.loads(serialized.decode("utf-8"))
|
||||||
functions = {}
|
return {
|
||||||
for obj in raw_list:
|
obj["vector_column"]: EmbeddingFunctionConfig(
|
||||||
model = self.load(obj["schema"]["title"])
|
vector_column=obj["vector_column"],
|
||||||
functions[obj["model"]["vector_column"]] = model(**obj["model"])
|
source_column=obj["source_column"],
|
||||||
return functions
|
function=self.get(obj["name"])(**obj["model"]),
|
||||||
|
)
|
||||||
|
for obj in raw_list
|
||||||
|
}
|
||||||
|
|
||||||
def function_to_metadata(self, func):
|
def function_to_metadata(self, conf: "EmbeddingFunctionConfig"):
|
||||||
"""
|
"""
|
||||||
Convert the given embedding function and source / vector column configs
|
Convert the given embedding function and source / vector column configs
|
||||||
into a config dictionary that can be serialized into arrow metadata
|
into a config dictionary that can be serialized into arrow metadata
|
||||||
"""
|
"""
|
||||||
schema = func.model_json_schema()
|
func = conf.function
|
||||||
json_data = func.model_dump()
|
name = getattr(
|
||||||
|
func, "__embedding_function_registry_alias__", func.__class__.__name__
|
||||||
|
)
|
||||||
|
json_data = func.safe_model_dump()
|
||||||
return {
|
return {
|
||||||
"schema": schema,
|
"name": name,
|
||||||
"model": json_data,
|
"model": json_data,
|
||||||
|
"source_column": conf.source_column,
|
||||||
|
"vector_column": conf.vector_column,
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_table_metadata(self, func_list):
|
def get_table_metadata(self, func_list):
|
||||||
"""
|
"""
|
||||||
Convert a list of embedding functions and source / vector column configs
|
Convert a list of embedding functions and source / vector configs
|
||||||
into a config dictionary that can be serialized into arrow metadata
|
into a config dictionary that can be serialized into arrow metadata
|
||||||
"""
|
"""
|
||||||
|
if func_list is None or len(func_list) == 0:
|
||||||
|
return None
|
||||||
json_data = [self.function_to_metadata(func) for func in func_list]
|
json_data = [self.function_to_metadata(func) for func in func_list]
|
||||||
# Note that metadata dictionary values must be bytes so we need to json dump then utf8 encode
|
# Note that metadata dictionary values must be bytes
|
||||||
|
# so we need to json dump then utf8 encode
|
||||||
metadata = json.dumps(json_data, indent=2).encode("utf-8")
|
metadata = json.dumps(json_data, indent=2).encode("utf-8")
|
||||||
return {"embedding_functions": metadata}
|
return {"embedding_functions": metadata}
|
||||||
|
|
||||||
|
|
||||||
REGISTRY = EmbeddingFunctionRegistry()
|
# Global instance
|
||||||
|
__REGISTRY__ = EmbeddingFunctionRegistry()
|
||||||
|
|
||||||
class EmbeddingFunctionModel(BaseModel, ABC):
|
|
||||||
"""
|
|
||||||
A callable ABC for embedding functions
|
|
||||||
"""
|
|
||||||
|
|
||||||
source_column: Optional[str]
|
|
||||||
vector_column: str
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def __call__(self, *args, **kwargs) -> List[np.array]:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
TEXT = Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray]
|
TEXT = Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray]
|
||||||
|
IMAGES = Union[
|
||||||
|
str, bytes, List[str], List[bytes], pa.Array, pa.ChunkedArray, np.ndarray
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class TextEmbeddingFunctionModel(EmbeddingFunctionModel):
|
class EmbeddingFunction(BaseModel, ABC):
|
||||||
"""
|
"""
|
||||||
A callable ABC for embedding functions that take text as input
|
An ABC for embedding functions.
|
||||||
|
|
||||||
|
All concrete embedding functions must implement the following:
|
||||||
|
1. compute_query_embeddings() which takes a query and returns a list of embeddings
|
||||||
|
2. compute_source_embeddings() which returns a list of embeddings for the source column
|
||||||
|
For text data, the two will be the same. For multi-modal data, the source column
|
||||||
|
might be images and the vector column might be text.
|
||||||
|
3. ndims method which returns the number of dimensions of the vector column
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __call__(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
|
_ndims: int = PrivateAttr()
|
||||||
texts = self.sanitize_input(texts)
|
|
||||||
return self.generate_embeddings(texts)
|
@classmethod
|
||||||
|
def create(cls, **kwargs):
|
||||||
|
"""
|
||||||
|
Create an instance of the embedding function
|
||||||
|
"""
|
||||||
|
return cls(**kwargs)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def compute_query_embeddings(self, *args, **kwargs) -> List[np.array]:
|
||||||
|
"""
|
||||||
|
Compute the embeddings for a given user query
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def compute_source_embeddings(self, *args, **kwargs) -> List[np.array]:
|
||||||
|
"""
|
||||||
|
Compute the embeddings for the source column in the database
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
def sanitize_input(self, texts: TEXT) -> Union[List[str], np.ndarray]:
|
def sanitize_input(self, texts: TEXT) -> Union[List[str], np.ndarray]:
|
||||||
"""
|
"""
|
||||||
Sanitize the input to the embedding function. This is called
|
Sanitize the input to the embedding function.
|
||||||
before generate_embeddings() and is useful for stripping
|
|
||||||
whitespace, lowercasing, etc.
|
|
||||||
"""
|
"""
|
||||||
if isinstance(texts, str):
|
if isinstance(texts, str):
|
||||||
texts = [texts]
|
texts = [texts]
|
||||||
@@ -157,6 +227,78 @@ class TextEmbeddingFunctionModel(EmbeddingFunctionModel):
|
|||||||
texts = texts.combine_chunks().to_pylist()
|
texts = texts.combine_chunks().to_pylist()
|
||||||
return texts
|
return texts
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def safe_import(cls, module: str, mitigation=None):
|
||||||
|
"""
|
||||||
|
Import the specified module. If the module is not installed,
|
||||||
|
raise an ImportError with a helpful message.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
module : str
|
||||||
|
The name of the module to import
|
||||||
|
mitigation : Optional[str]
|
||||||
|
The package(s) to install to mitigate the error.
|
||||||
|
If not provided then the module name will be used.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return importlib.import_module(module)
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(f"Please install {mitigation or module}")
|
||||||
|
|
||||||
|
def safe_model_dump(self):
|
||||||
|
from ..pydantic import PYDANTIC_VERSION
|
||||||
|
|
||||||
|
if PYDANTIC_VERSION.major < 2:
|
||||||
|
return dict(self)
|
||||||
|
return self.model_dump()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def ndims(self):
|
||||||
|
"""
|
||||||
|
Return the dimensions of the vector column
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def SourceField(self, **kwargs):
|
||||||
|
"""
|
||||||
|
Creates a pydantic Field that can automatically annotate
|
||||||
|
the source column for this embedding function
|
||||||
|
"""
|
||||||
|
return Field(json_schema_extra={"source_column_for": self}, **kwargs)
|
||||||
|
|
||||||
|
def VectorField(self, **kwargs):
|
||||||
|
"""
|
||||||
|
Creates a pydantic Field that can automatically annotate
|
||||||
|
the target vector column for this embedding function
|
||||||
|
"""
|
||||||
|
return Field(json_schema_extra={"vector_column_for": self}, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingFunctionConfig(BaseModel):
|
||||||
|
"""
|
||||||
|
This model encapsulates the configuration for an embedding function
|
||||||
|
in a lancedb table. It holds the embedding function, the source column,
|
||||||
|
and the vector column
|
||||||
|
"""
|
||||||
|
|
||||||
|
vector_column: str
|
||||||
|
source_column: str
|
||||||
|
function: EmbeddingFunction
|
||||||
|
|
||||||
|
|
||||||
|
class TextEmbeddingFunction(EmbeddingFunction):
|
||||||
|
"""
|
||||||
|
A callable ABC for embedding functions that take text as input
|
||||||
|
"""
|
||||||
|
|
||||||
|
def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
|
||||||
|
return self.compute_source_embeddings(query, *args, **kwargs)
|
||||||
|
|
||||||
|
def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
|
||||||
|
texts = self.sanitize_input(texts)
|
||||||
|
return self.generate_embeddings(texts)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def generate_embeddings(
|
def generate_embeddings(
|
||||||
self, texts: Union[List[str], np.ndarray]
|
self, texts: Union[List[str], np.ndarray]
|
||||||
@@ -167,15 +309,25 @@ class TextEmbeddingFunctionModel(EmbeddingFunctionModel):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@REGISTRY.register()
|
# @EmbeddingFunctionRegistry.get_instance().register(name) doesn't work in 3.8
|
||||||
class SentenceTransformerEmbeddingFunction(TextEmbeddingFunctionModel):
|
register = lambda name: EmbeddingFunctionRegistry.get_instance().register(name)
|
||||||
|
|
||||||
|
|
||||||
|
@register("sentence-transformers")
|
||||||
|
class SentenceTransformerEmbeddings(TextEmbeddingFunction):
|
||||||
"""
|
"""
|
||||||
An embedding function that uses the sentence-transformers library
|
An embedding function that uses the sentence-transformers library
|
||||||
|
|
||||||
|
https://huggingface.co/sentence-transformers
|
||||||
"""
|
"""
|
||||||
|
|
||||||
name: str = "all-MiniLM-L6-v2"
|
name: str = "all-MiniLM-L6-v2"
|
||||||
device: str = "cpu"
|
device: str = "cpu"
|
||||||
normalize: bool = False
|
normalize: bool = True
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self._ndims = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def embedding_model(self):
|
def embedding_model(self):
|
||||||
@@ -186,6 +338,11 @@ class SentenceTransformerEmbeddingFunction(TextEmbeddingFunctionModel):
|
|||||||
"""
|
"""
|
||||||
return self.__class__.get_embedding_model(self.name, self.device)
|
return self.__class__.get_embedding_model(self.name, self.device)
|
||||||
|
|
||||||
|
def ndims(self):
|
||||||
|
if self._ndims is None:
|
||||||
|
self._ndims = len(self.generate_embeddings("foo")[0])
|
||||||
|
return self._ndims
|
||||||
|
|
||||||
def generate_embeddings(
|
def generate_embeddings(
|
||||||
self, texts: Union[List[str], np.ndarray]
|
self, texts: Union[List[str], np.ndarray]
|
||||||
) -> List[np.array]:
|
) -> List[np.array]:
|
||||||
@@ -220,9 +377,201 @@ class SentenceTransformerEmbeddingFunction(TextEmbeddingFunctionModel):
|
|||||||
|
|
||||||
TODO: use lru_cache instead with a reasonable/configurable maxsize
|
TODO: use lru_cache instead with a reasonable/configurable maxsize
|
||||||
"""
|
"""
|
||||||
try:
|
sentence_transformers = cls.safe_import(
|
||||||
from sentence_transformers import SentenceTransformer
|
"sentence_transformers", "sentence-transformers"
|
||||||
|
)
|
||||||
|
return sentence_transformers.SentenceTransformer(name, device=device)
|
||||||
|
|
||||||
return SentenceTransformer(name, device=device)
|
|
||||||
except ImportError:
|
@register("openai")
|
||||||
raise ValueError("Please install sentence_transformers")
|
class OpenAIEmbeddings(TextEmbeddingFunction):
|
||||||
|
"""
|
||||||
|
An embedding function that uses the OpenAI API
|
||||||
|
|
||||||
|
https://platform.openai.com/docs/guides/embeddings
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str = "text-embedding-ada-002"
|
||||||
|
|
||||||
|
def ndims(self):
|
||||||
|
# TODO don't hardcode this
|
||||||
|
return 1536
|
||||||
|
|
||||||
|
def generate_embeddings(
|
||||||
|
self, texts: Union[List[str], np.ndarray]
|
||||||
|
) -> List[np.array]:
|
||||||
|
"""
|
||||||
|
Get the embeddings for the given texts
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
texts: list[str] or np.ndarray (of str)
|
||||||
|
The texts to embed
|
||||||
|
"""
|
||||||
|
# TODO retry, rate limit, token limit
|
||||||
|
openai = self.safe_import("openai")
|
||||||
|
rs = openai.Embedding.create(input=texts, model=self.name)["data"]
|
||||||
|
return [v["embedding"] for v in rs]
|
||||||
|
|
||||||
|
|
||||||
|
@register("open-clip")
|
||||||
|
class OpenClipEmbeddings(EmbeddingFunction):
|
||||||
|
"""
|
||||||
|
An embedding function that uses the OpenClip API
|
||||||
|
For multi-modal text-to-image search
|
||||||
|
|
||||||
|
https://github.com/mlfoundations/open_clip
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str = "ViT-B-32"
|
||||||
|
pretrained: str = "laion2b_s34b_b79k"
|
||||||
|
device: str = "cpu"
|
||||||
|
batch_size: int = 64
|
||||||
|
normalize: bool = True
|
||||||
|
_model = PrivateAttr()
|
||||||
|
_preprocess = PrivateAttr()
|
||||||
|
_tokenizer = PrivateAttr()
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
open_clip = self.safe_import("open_clip", "open-clip")
|
||||||
|
model, _, preprocess = open_clip.create_model_and_transforms(
|
||||||
|
self.name, pretrained=self.pretrained
|
||||||
|
)
|
||||||
|
model.to(self.device)
|
||||||
|
self._model, self._preprocess = model, preprocess
|
||||||
|
self._tokenizer = open_clip.get_tokenizer(self.name)
|
||||||
|
self._ndims = None
|
||||||
|
|
||||||
|
def ndims(self):
|
||||||
|
if self._ndims is None:
|
||||||
|
self._ndims = self.generate_text_embeddings("foo").shape[0]
|
||||||
|
return self._ndims
|
||||||
|
|
||||||
|
def compute_query_embeddings(
|
||||||
|
self, query: Union[str, "PIL.Image.Image"], *args, **kwargs
|
||||||
|
) -> List[np.ndarray]:
|
||||||
|
"""
|
||||||
|
Compute the embeddings for a given user query
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
query : Union[str, PIL.Image.Image]
|
||||||
|
The query to embed. A query can be either text or an image.
|
||||||
|
"""
|
||||||
|
if isinstance(query, str):
|
||||||
|
return [self.generate_text_embeddings(query)]
|
||||||
|
else:
|
||||||
|
PIL = self.safe_import("PIL", "pillow")
|
||||||
|
if isinstance(query, PIL.Image.Image):
|
||||||
|
return [self.generate_image_embedding(query)]
|
||||||
|
else:
|
||||||
|
raise TypeError("OpenClip supports str or PIL Image as query")
|
||||||
|
|
||||||
|
def generate_text_embeddings(self, text: str) -> np.ndarray:
|
||||||
|
torch = self.safe_import("torch")
|
||||||
|
text = self.sanitize_input(text)
|
||||||
|
text = self._tokenizer(text)
|
||||||
|
text.to(self.device)
|
||||||
|
with torch.no_grad():
|
||||||
|
text_features = self._model.encode_text(text.to(self.device))
|
||||||
|
if self.normalize:
|
||||||
|
text_features /= text_features.norm(dim=-1, keepdim=True)
|
||||||
|
return text_features.cpu().numpy().squeeze()
|
||||||
|
|
||||||
|
def sanitize_input(self, images: IMAGES) -> Union[List[bytes], np.ndarray]:
|
||||||
|
"""
|
||||||
|
Sanitize the input to the embedding function.
|
||||||
|
"""
|
||||||
|
if isinstance(images, (str, bytes)):
|
||||||
|
images = [images]
|
||||||
|
elif isinstance(images, pa.Array):
|
||||||
|
images = images.to_pylist()
|
||||||
|
elif isinstance(images, pa.ChunkedArray):
|
||||||
|
images = images.combine_chunks().to_pylist()
|
||||||
|
return images
|
||||||
|
|
||||||
|
def compute_source_embeddings(
|
||||||
|
self, images: IMAGES, *args, **kwargs
|
||||||
|
) -> List[np.array]:
|
||||||
|
"""
|
||||||
|
Get the embeddings for the given images
|
||||||
|
"""
|
||||||
|
images = self.sanitize_input(images)
|
||||||
|
embeddings = []
|
||||||
|
for i in range(0, len(images), self.batch_size):
|
||||||
|
j = min(i + self.batch_size, len(images))
|
||||||
|
batch = images[i:j]
|
||||||
|
embeddings.extend(self._parallel_get(batch))
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
def _parallel_get(self, images: Union[List[str], List[bytes]]) -> List[np.ndarray]:
|
||||||
|
"""
|
||||||
|
Issue concurrent requests to retrieve the image data
|
||||||
|
"""
|
||||||
|
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||||
|
futures = [
|
||||||
|
executor.submit(self.generate_image_embedding, image)
|
||||||
|
for image in images
|
||||||
|
]
|
||||||
|
return [future.result() for future in futures]
|
||||||
|
|
||||||
|
def generate_image_embedding(
|
||||||
|
self, image: Union[str, bytes, "PIL.Image.Image"]
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Generate the embedding for a single image
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
image : Union[str, bytes, PIL.Image.Image]
|
||||||
|
The image to embed. If the image is a str, it is treated as a uri.
|
||||||
|
If the image is bytes, it is treated as the raw image bytes.
|
||||||
|
"""
|
||||||
|
torch = self.safe_import("torch")
|
||||||
|
# TODO handle retry and errors for https
|
||||||
|
image = self._to_pil(image)
|
||||||
|
image = self._preprocess(image).unsqueeze(0)
|
||||||
|
with torch.no_grad():
|
||||||
|
return self._encode_and_normalize_image(image)
|
||||||
|
|
||||||
|
def _to_pil(self, image: Union[str, bytes]):
|
||||||
|
PIL = self.safe_import("PIL", "pillow")
|
||||||
|
if isinstance(image, bytes):
|
||||||
|
return PIL.Image.open(io.BytesIO(image))
|
||||||
|
if isinstance(image, PIL.Image.Image):
|
||||||
|
return image
|
||||||
|
elif isinstance(image, str):
|
||||||
|
parsed = urlparse.urlparse(image)
|
||||||
|
# TODO handle drive letter on windows.
|
||||||
|
if parsed.scheme == "file":
|
||||||
|
return PIL.Image.open(parsed.path)
|
||||||
|
elif parsed.scheme == "":
|
||||||
|
return PIL.Image.open(image if os.name == "nt" else parsed.path)
|
||||||
|
elif parsed.scheme.startswith("http"):
|
||||||
|
return PIL.Image.open(io.BytesIO(url_retrieve(image)))
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("Only local and http(s) urls are supported")
|
||||||
|
|
||||||
|
def _encode_and_normalize_image(self, image_tensor: "torch.Tensor"):
|
||||||
|
"""
|
||||||
|
encode a single image tensor and optionally normalize the output
|
||||||
|
"""
|
||||||
|
image_features = self._model.encode_image(image_tensor)
|
||||||
|
if self.normalize:
|
||||||
|
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||||
|
return image_features.cpu().numpy().squeeze()
|
||||||
|
|
||||||
|
|
||||||
|
def url_retrieve(url: str):
|
||||||
|
"""
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
url: str
|
||||||
|
URL to download from
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with urllib.request.urlopen(url) as conn:
|
||||||
|
return conn.read()
|
||||||
|
except (socket.gaierror, urllib.error.URLError) as err:
|
||||||
|
raise ConnectionError("could not download {} due to {}".format(url, err))
|
||||||
|
|||||||
@@ -26,6 +26,8 @@ import pyarrow as pa
|
|||||||
import pydantic
|
import pydantic
|
||||||
import semver
|
import semver
|
||||||
|
|
||||||
|
from .embeddings import EmbeddingFunctionRegistry
|
||||||
|
|
||||||
PYDANTIC_VERSION = semver.Version.parse(pydantic.__version__)
|
PYDANTIC_VERSION = semver.Version.parse(pydantic.__version__)
|
||||||
try:
|
try:
|
||||||
from pydantic_core import CoreSchema, core_schema
|
from pydantic_core import CoreSchema, core_schema
|
||||||
@@ -126,7 +128,7 @@ def Vector(
|
|||||||
def validate(cls, v):
|
def validate(cls, v):
|
||||||
if not isinstance(v, (list, range, np.ndarray)) or len(v) != dim:
|
if not isinstance(v, (list, range, np.ndarray)) or len(v) != dim:
|
||||||
raise TypeError("A list of numbers or numpy.ndarray is needed")
|
raise TypeError("A list of numbers or numpy.ndarray is needed")
|
||||||
return v
|
return cls(v)
|
||||||
|
|
||||||
if PYDANTIC_VERSION < (2, 0):
|
if PYDANTIC_VERSION < (2, 0):
|
||||||
|
|
||||||
@@ -236,27 +238,18 @@ def pydantic_to_schema(model: Type[pydantic.BaseModel]) -> pa.Schema:
|
|||||||
>>> from typing import List, Optional
|
>>> from typing import List, Optional
|
||||||
>>> import pydantic
|
>>> import pydantic
|
||||||
>>> from lancedb.pydantic import pydantic_to_schema
|
>>> from lancedb.pydantic import pydantic_to_schema
|
||||||
...
|
|
||||||
>>> class InnerModel(pydantic.BaseModel):
|
|
||||||
... a: str
|
|
||||||
... b: Optional[float]
|
|
||||||
>>>
|
|
||||||
>>> class FooModel(pydantic.BaseModel):
|
>>> class FooModel(pydantic.BaseModel):
|
||||||
... id: int
|
... id: int
|
||||||
... s: Optional[str] = None
|
... s: str
|
||||||
... vec: List[float]
|
... vec: List[float]
|
||||||
... li: List[int]
|
... li: List[int]
|
||||||
... inner: InnerModel
|
...
|
||||||
>>> schema = pydantic_to_schema(FooModel)
|
>>> schema = pydantic_to_schema(FooModel)
|
||||||
>>> assert schema == pa.schema([
|
>>> assert schema == pa.schema([
|
||||||
... pa.field("id", pa.int64(), False),
|
... pa.field("id", pa.int64(), False),
|
||||||
... pa.field("s", pa.utf8(), True),
|
... pa.field("s", pa.utf8(), False),
|
||||||
... pa.field("vec", pa.list_(pa.float64()), False),
|
... pa.field("vec", pa.list_(pa.float64()), False),
|
||||||
... pa.field("li", pa.list_(pa.int64()), False),
|
... pa.field("li", pa.list_(pa.int64()), False),
|
||||||
... pa.field("inner", pa.struct([
|
|
||||||
... pa.field("a", pa.utf8(), False),
|
|
||||||
... pa.field("b", pa.float64(), True),
|
|
||||||
... ]), False),
|
|
||||||
... ])
|
... ])
|
||||||
"""
|
"""
|
||||||
fields = _pydantic_model_to_fields(model)
|
fields = _pydantic_model_to_fields(model)
|
||||||
@@ -290,13 +283,58 @@ class LanceModel(pydantic.BaseModel):
|
|||||||
"""
|
"""
|
||||||
Get the Arrow Schema for this model.
|
Get the Arrow Schema for this model.
|
||||||
"""
|
"""
|
||||||
return pydantic_to_schema(cls)
|
schema = pydantic_to_schema(cls)
|
||||||
|
functions = cls.parse_embedding_functions()
|
||||||
|
if len(functions) > 0:
|
||||||
|
metadata = EmbeddingFunctionRegistry.get_instance().get_table_metadata(
|
||||||
|
functions
|
||||||
|
)
|
||||||
|
schema = schema.with_metadata(metadata)
|
||||||
|
return schema
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def field_names(cls) -> List[str]:
|
def field_names(cls) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Get the field names of this model.
|
Get the field names of this model.
|
||||||
"""
|
"""
|
||||||
|
return list(cls.safe_get_fields().keys())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def safe_get_fields(cls):
|
||||||
if PYDANTIC_VERSION.major < 2:
|
if PYDANTIC_VERSION.major < 2:
|
||||||
return list(cls.__fields__.keys())
|
return cls.__fields__
|
||||||
return list(cls.model_fields.keys())
|
return cls.model_fields
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def parse_embedding_functions(cls) -> List["EmbeddingFunctionConfig"]:
|
||||||
|
"""
|
||||||
|
Parse the embedding functions from this model.
|
||||||
|
"""
|
||||||
|
from .embeddings import EmbeddingFunctionConfig
|
||||||
|
|
||||||
|
vec_and_function = []
|
||||||
|
for name, field_info in cls.safe_get_fields().items():
|
||||||
|
func = get_extras(field_info, "vector_column_for")
|
||||||
|
if func is not None:
|
||||||
|
vec_and_function.append([name, func])
|
||||||
|
|
||||||
|
configs = []
|
||||||
|
for vec, func in vec_and_function:
|
||||||
|
for source, field_info in cls.safe_get_fields().items():
|
||||||
|
src_func = get_extras(field_info, "source_column_for")
|
||||||
|
if src_func == func:
|
||||||
|
configs.append(
|
||||||
|
EmbeddingFunctionConfig(
|
||||||
|
source_column=source, vector_column=vec, function=func
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return configs
|
||||||
|
|
||||||
|
|
||||||
|
def get_extras(field_info: pydantic.fields.FieldInfo, key: str) -> Any:
|
||||||
|
"""
|
||||||
|
Get the extra metadata from a Pydantic FieldInfo.
|
||||||
|
"""
|
||||||
|
if PYDANTIC_VERSION.major >= 2:
|
||||||
|
return (field_info.json_schema_extra or {}).get(key)
|
||||||
|
return (field_info.field_info.extra or {}).get("json_schema_extra", {}).get(key)
|
||||||
|
|||||||
@@ -60,13 +60,15 @@ class LanceQueryBuilder(ABC):
|
|||||||
def create(
|
def create(
|
||||||
cls,
|
cls,
|
||||||
table: "lancedb.table.Table",
|
table: "lancedb.table.Table",
|
||||||
query: Optional[Union[np.ndarray, str]],
|
query: Optional[Union[np.ndarray, str, "PIL.Image.Image"]],
|
||||||
query_type: str,
|
query_type: str,
|
||||||
vector_column_name: str,
|
vector_column_name: str,
|
||||||
) -> LanceQueryBuilder:
|
) -> LanceQueryBuilder:
|
||||||
if query is None:
|
if query is None:
|
||||||
return LanceEmptyQueryBuilder(table)
|
return LanceEmptyQueryBuilder(table)
|
||||||
|
|
||||||
|
# convert "auto" query_type to "vector" or "fts"
|
||||||
|
# and convert the query to vector if needed
|
||||||
query, query_type = cls._resolve_query(
|
query, query_type = cls._resolve_query(
|
||||||
table, query, query_type, vector_column_name
|
table, query, query_type, vector_column_name
|
||||||
)
|
)
|
||||||
@@ -90,30 +92,27 @@ class LanceQueryBuilder(ABC):
|
|||||||
# otherwise raise TypeError
|
# otherwise raise TypeError
|
||||||
if query_type == "fts":
|
if query_type == "fts":
|
||||||
if not isinstance(query, str):
|
if not isinstance(query, str):
|
||||||
raise TypeError(
|
raise TypeError(f"'fts' queries must be a string: {type(query)}")
|
||||||
f"Query type is 'fts' but query is not a string: {type(query)}"
|
|
||||||
)
|
|
||||||
return query, query_type
|
return query, query_type
|
||||||
elif query_type == "vector":
|
elif query_type == "vector":
|
||||||
# If query_type is vector, then query must be a list or np.ndarray.
|
|
||||||
# otherwise raise TypeError
|
|
||||||
if not isinstance(query, (list, np.ndarray)):
|
if not isinstance(query, (list, np.ndarray)):
|
||||||
raise TypeError(
|
conf = table.embedding_functions.get(vector_column_name)
|
||||||
f"Query type is 'vector' but query is not a list or np.ndarray: {type(query)}"
|
if conf is not None:
|
||||||
)
|
query = conf.function.compute_query_embeddings(query)[0]
|
||||||
|
else:
|
||||||
|
msg = f"No embedding function for {vector_column_name}"
|
||||||
|
raise ValueError(msg)
|
||||||
return query, query_type
|
return query, query_type
|
||||||
elif query_type == "auto":
|
elif query_type == "auto":
|
||||||
if isinstance(query, (list, np.ndarray)):
|
if isinstance(query, (list, np.ndarray)):
|
||||||
return query, "vector"
|
return query, "vector"
|
||||||
elif isinstance(query, str):
|
else:
|
||||||
func = table.embedding_functions.get(vector_column_name, None)
|
conf = table.embedding_functions.get(vector_column_name)
|
||||||
if func is not None:
|
if conf is not None:
|
||||||
query = func(query)[0]
|
query = conf.function.compute_query_embeddings(query)[0]
|
||||||
return query, "vector"
|
return query, "vector"
|
||||||
else:
|
else:
|
||||||
return query, "fts"
|
return query, "fts"
|
||||||
else:
|
|
||||||
raise TypeError("Query must be a list, np.ndarray, or str")
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Invalid query_type, must be 'vector', 'fts', or 'auto': {query_type}"
|
f"Invalid query_type, must be 'vector', 'fts', or 'auto': {query_type}"
|
||||||
@@ -238,7 +237,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
table: "lancedb.table.Table",
|
table: "lancedb.table.Table",
|
||||||
query: Union[np.ndarray, list],
|
query: Union[np.ndarray, list, "PIL.Image.Image"],
|
||||||
vector_column: str = VECTOR_COLUMN_NAME,
|
vector_column: str = VECTOR_COLUMN_NAME,
|
||||||
):
|
):
|
||||||
super().__init__(table)
|
super().__init__(table)
|
||||||
|
|||||||
@@ -18,10 +18,9 @@ from urllib.parse import urlparse
|
|||||||
|
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
|
|
||||||
from lancedb.common import DATA
|
from ..common import DATA
|
||||||
from lancedb.db import DBConnection
|
from ..db import DBConnection
|
||||||
from lancedb.table import Table, _sanitize_data
|
from ..table import Table, _sanitize_data
|
||||||
|
|
||||||
from .arrow import to_ipc_binary
|
from .arrow import to_ipc_binary
|
||||||
from .client import ARROW_STREAM_CONTENT_TYPE, RestfulLanceDBClient
|
from .client import ARROW_STREAM_CONTENT_TYPE, RestfulLanceDBClient
|
||||||
|
|
||||||
|
|||||||
@@ -28,7 +28,8 @@ from lance.dataset import ReaderLike
|
|||||||
from lance.vector import vec_to_table
|
from lance.vector import vec_to_table
|
||||||
|
|
||||||
from .common import DATA, VEC, VECTOR_COLUMN_NAME
|
from .common import DATA, VEC, VECTOR_COLUMN_NAME
|
||||||
from .embeddings import EmbeddingFunctionModel, EmbeddingFunctionRegistry
|
from .embeddings import EmbeddingFunctionRegistry
|
||||||
|
from .embeddings.functions import EmbeddingFunctionConfig
|
||||||
from .pydantic import LanceModel
|
from .pydantic import LanceModel
|
||||||
from .query import LanceQueryBuilder, Query
|
from .query import LanceQueryBuilder, Query
|
||||||
from .util import fs_from_uri, safe_import_pandas
|
from .util import fs_from_uri, safe_import_pandas
|
||||||
@@ -81,15 +82,16 @@ def _append_vector_col(data: pa.Table, metadata: dict, schema: Optional[pa.Schem
|
|||||||
vector column to the table.
|
vector column to the table.
|
||||||
"""
|
"""
|
||||||
functions = EmbeddingFunctionRegistry.get_instance().parse_functions(metadata)
|
functions = EmbeddingFunctionRegistry.get_instance().parse_functions(metadata)
|
||||||
for vector_col, func in functions.items():
|
for vector_column, conf in functions.items():
|
||||||
if vector_col not in data.column_names:
|
func = conf.function
|
||||||
col_data = func(data[func.source_column])
|
if vector_column not in data.column_names:
|
||||||
|
col_data = func.compute_source_embeddings(data[conf.source_column])
|
||||||
if schema is not None:
|
if schema is not None:
|
||||||
dtype = schema.field(vector_col).type
|
dtype = schema.field(vector_column).type
|
||||||
else:
|
else:
|
||||||
dtype = pa.list_(pa.float32(), len(col_data[0]))
|
dtype = pa.list_(pa.float32(), len(col_data[0]))
|
||||||
data = data.append_column(
|
data = data.append_column(
|
||||||
pa.field(vector_col, type=dtype), pa.array(col_data, type=dtype)
|
pa.field(vector_column, type=dtype), pa.array(col_data, type=dtype)
|
||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@@ -230,7 +232,7 @@ class Table(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def search(
|
def search(
|
||||||
self,
|
self,
|
||||||
query: Optional[Union[VEC, str]] = None,
|
query: Optional[Union[VEC, str, "PIL.Image.Image"]] = None,
|
||||||
vector_column_name: str = VECTOR_COLUMN_NAME,
|
vector_column_name: str = VECTOR_COLUMN_NAME,
|
||||||
query_type: str = "auto",
|
query_type: str = "auto",
|
||||||
) -> LanceQueryBuilder:
|
) -> LanceQueryBuilder:
|
||||||
@@ -239,7 +241,7 @@ class Table(ABC):
|
|||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
query: str, list, np.ndarray, default None
|
query: str, list, np.ndarray, PIL.Image.Image, default None
|
||||||
The query to search for. If None then
|
The query to search for. If None then
|
||||||
the select/where/limit clauses are applied to filter
|
the select/where/limit clauses are applied to filter
|
||||||
the table
|
the table
|
||||||
@@ -249,6 +251,8 @@ class Table(ABC):
|
|||||||
"vector", "fts", or "auto"
|
"vector", "fts", or "auto"
|
||||||
If "auto" then the query type is inferred from the query;
|
If "auto" then the query type is inferred from the query;
|
||||||
If `query` is a list/np.ndarray then the query type is "vector";
|
If `query` is a list/np.ndarray then the query type is "vector";
|
||||||
|
If `query` is a PIL.Image.Image then either do vector search
|
||||||
|
or raise an error if no corresponding embedding function is found.
|
||||||
If `query` is a string, then the query type is "vector" if the
|
If `query` is a string, then the query type is "vector" if the
|
||||||
table has embedding functions else the query type is "fts"
|
table has embedding functions else the query type is "fts"
|
||||||
|
|
||||||
@@ -524,6 +528,9 @@ class LanceTable(Table):
|
|||||||
fill_value: float = 0.0,
|
fill_value: float = 0.0,
|
||||||
):
|
):
|
||||||
"""Add data to the table.
|
"""Add data to the table.
|
||||||
|
If vector columns are missing and the table
|
||||||
|
has embedding functions, then the vector columns
|
||||||
|
are automatically computed and added.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
@@ -617,12 +624,6 @@ class LanceTable(Table):
|
|||||||
)
|
)
|
||||||
self._reset_dataset()
|
self._reset_dataset()
|
||||||
|
|
||||||
def _get_embedding_function_for_source_col(self, column_name: str):
|
|
||||||
for k, v in self.embedding_functions.items():
|
|
||||||
if v.source_column == column_name:
|
|
||||||
return v
|
|
||||||
return None
|
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def embedding_functions(self) -> dict:
|
def embedding_functions(self) -> dict:
|
||||||
"""
|
"""
|
||||||
@@ -640,7 +641,7 @@ class LanceTable(Table):
|
|||||||
|
|
||||||
def search(
|
def search(
|
||||||
self,
|
self,
|
||||||
query: Optional[Union[VEC, str]] = None,
|
query: Optional[Union[VEC, str, "PIL.Image.Image"]] = None,
|
||||||
vector_column_name: str = VECTOR_COLUMN_NAME,
|
vector_column_name: str = VECTOR_COLUMN_NAME,
|
||||||
query_type: str = "auto",
|
query_type: str = "auto",
|
||||||
) -> LanceQueryBuilder:
|
) -> LanceQueryBuilder:
|
||||||
@@ -649,7 +650,7 @@ class LanceTable(Table):
|
|||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
query: str, list, np.ndarray, or None
|
query: str, list, np.ndarray, a PIL Image or None
|
||||||
The query to search for. If None then
|
The query to search for. If None then
|
||||||
the select/where/limit clauses are applied to filter
|
the select/where/limit clauses are applied to filter
|
||||||
the table
|
the table
|
||||||
@@ -658,9 +659,11 @@ class LanceTable(Table):
|
|||||||
query_type: str, default "auto"
|
query_type: str, default "auto"
|
||||||
"vector", "fts", or "auto"
|
"vector", "fts", or "auto"
|
||||||
If "auto" then the query type is inferred from the query;
|
If "auto" then the query type is inferred from the query;
|
||||||
If the query is a list/np.ndarray then the query type is "vector";
|
If `query` is a list/np.ndarray then the query type is "vector";
|
||||||
|
If `query` is a PIL.Image.Image then either do vector search
|
||||||
|
or raise an error if no corresponding embedding function is found.
|
||||||
If the query is a string, then the query type is "vector" if the
|
If the query is a string, then the query type is "vector" if the
|
||||||
table has embedding functions else the query type is "fts"
|
table has embedding functions, else the query type is "fts"
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
@@ -684,7 +687,7 @@ class LanceTable(Table):
|
|||||||
mode="create",
|
mode="create",
|
||||||
on_bad_vectors: str = "error",
|
on_bad_vectors: str = "error",
|
||||||
fill_value: float = 0.0,
|
fill_value: float = 0.0,
|
||||||
embedding_functions: List[EmbeddingFunctionModel] = None,
|
embedding_functions: List[EmbeddingFunctionConfig] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Create a new table.
|
Create a new table.
|
||||||
@@ -727,10 +730,16 @@ class LanceTable(Table):
|
|||||||
"""
|
"""
|
||||||
tbl = LanceTable(db, name)
|
tbl = LanceTable(db, name)
|
||||||
if inspect.isclass(schema) and issubclass(schema, LanceModel):
|
if inspect.isclass(schema) and issubclass(schema, LanceModel):
|
||||||
|
# convert LanceModel to pyarrow schema
|
||||||
|
# note that it's possible this contains
|
||||||
|
# embedding function metadata already
|
||||||
schema = schema.to_arrow_schema()
|
schema = schema.to_arrow_schema()
|
||||||
|
|
||||||
metadata = None
|
metadata = None
|
||||||
if embedding_functions is not None:
|
if embedding_functions is not None:
|
||||||
|
# If we passed in embedding functions explicitly
|
||||||
|
# then we'll override any schema metadata that
|
||||||
|
# may was implicitly specified by the LanceModel schema
|
||||||
registry = EmbeddingFunctionRegistry.get_instance()
|
registry = EmbeddingFunctionRegistry.get_instance()
|
||||||
metadata = registry.get_table_metadata(embedding_functions)
|
metadata = registry.get_table_metadata(embedding_functions)
|
||||||
|
|
||||||
|
|||||||
@@ -70,7 +70,11 @@ def fs_from_uri(uri: str) -> Tuple[pa_fs.FileSystem, str]:
|
|||||||
Get a PyArrow FileSystem from a URI, handling extra environment variables.
|
Get a PyArrow FileSystem from a URI, handling extra environment variables.
|
||||||
"""
|
"""
|
||||||
if get_uri_scheme(uri) == "s3":
|
if get_uri_scheme(uri) == "s3":
|
||||||
fs = pa_fs.S3FileSystem(endpoint_override=os.environ.get("AWS_ENDPOINT"))
|
fs = pa_fs.S3FileSystem(
|
||||||
|
endpoint_override=os.environ.get("AWS_ENDPOINT"),
|
||||||
|
request_timeout=30,
|
||||||
|
connect_timeout=30,
|
||||||
|
)
|
||||||
path = get_uri_location(uri)
|
path = get_uri_location(uri)
|
||||||
return fs, path
|
return fs, path
|
||||||
|
|
||||||
|
|||||||
@@ -44,9 +44,11 @@ classifiers = [
|
|||||||
repository = "https://github.com/lancedb/lancedb"
|
repository = "https://github.com/lancedb/lancedb"
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
tests = ["pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio"]
|
tests = ["pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "requests"]
|
||||||
dev = ["ruff", "pre-commit", "black"]
|
dev = ["ruff", "pre-commit", "black"]
|
||||||
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
|
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
|
||||||
|
clip = ["torch", "pillow", "open-clip"]
|
||||||
|
embeddings = ["openai", "sentence-transformers", "torch", "pillow", "open-clip"]
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["setuptools", "wheel"]
|
requires = ["setuptools", "wheel"]
|
||||||
@@ -54,3 +56,10 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[tool.isort]
|
[tool.isort]
|
||||||
profile = "black"
|
profile = "black"
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
addopts = "--strict-markers"
|
||||||
|
markers = [
|
||||||
|
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
|
||||||
|
"asyncio"
|
||||||
|
]
|
||||||
@@ -136,11 +136,9 @@ def test_ingest_iterator(tmp_path):
|
|||||||
def run_tests(schema):
|
def run_tests(schema):
|
||||||
db = lancedb.connect(tmp_path)
|
db = lancedb.connect(tmp_path)
|
||||||
tbl = db.create_table("table2", make_batches(), schema=schema, mode="overwrite")
|
tbl = db.create_table("table2", make_batches(), schema=schema, mode="overwrite")
|
||||||
|
|
||||||
tbl.to_pandas()
|
tbl.to_pandas()
|
||||||
assert tbl.search([3.1, 4.1]).limit(1).to_df()["_distance"][0] == 0.0
|
assert tbl.search([3.1, 4.1]).limit(1).to_df()["_distance"][0] == 0.0
|
||||||
assert tbl.search([5.9, 26.5]).limit(1).to_df()["_distance"][0] == 0.0
|
assert tbl.search([5.9, 26.5]).limit(1).to_df()["_distance"][0] == 0.0
|
||||||
|
|
||||||
tbl_len = len(tbl)
|
tbl_len = len(tbl)
|
||||||
tbl.add(make_batches())
|
tbl.add(make_batches())
|
||||||
assert tbl_len == 50
|
assert tbl_len == 50
|
||||||
|
|||||||
@@ -16,8 +16,12 @@ import lance
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
|
|
||||||
from lancedb.conftest import MockEmbeddingFunction
|
from lancedb.conftest import MockTextEmbeddingFunction
|
||||||
from lancedb.embeddings import EmbeddingFunctionRegistry, with_embeddings
|
from lancedb.embeddings import (
|
||||||
|
EmbeddingFunctionConfig,
|
||||||
|
EmbeddingFunctionRegistry,
|
||||||
|
with_embeddings,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def mock_embed_func(input_data):
|
def mock_embed_func(input_data):
|
||||||
@@ -54,8 +58,12 @@ def test_embedding_function(tmp_path):
|
|||||||
"vector": [np.random.randn(10), np.random.randn(10)],
|
"vector": [np.random.randn(10), np.random.randn(10)],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
func = MockEmbeddingFunction(source_column="text", vector_column="vector")
|
conf = EmbeddingFunctionConfig(
|
||||||
metadata = registry.get_table_metadata([func])
|
source_column="text",
|
||||||
|
vector_column="vector",
|
||||||
|
function=MockTextEmbeddingFunction(),
|
||||||
|
)
|
||||||
|
metadata = registry.get_table_metadata([conf])
|
||||||
table = table.replace_schema_metadata(metadata)
|
table = table.replace_schema_metadata(metadata)
|
||||||
|
|
||||||
# Write it to disk
|
# Write it to disk
|
||||||
@@ -65,14 +73,13 @@ def test_embedding_function(tmp_path):
|
|||||||
ds = lance.dataset(tmp_path / "test.lance")
|
ds = lance.dataset(tmp_path / "test.lance")
|
||||||
|
|
||||||
# can we get the serialized version back out?
|
# can we get the serialized version back out?
|
||||||
functions = registry.parse_functions(ds.schema.metadata)
|
configs = registry.parse_functions(ds.schema.metadata)
|
||||||
|
|
||||||
func = functions["vector"]
|
conf = configs["vector"]
|
||||||
actual = func("hello world")
|
func = conf.function
|
||||||
|
actual = func.compute_query_embeddings("hello world")
|
||||||
|
|
||||||
# We create an instance
|
|
||||||
expected_func = MockEmbeddingFunction(source_column="text", vector_column="vector")
|
|
||||||
# And we make sure we can call it
|
# And we make sure we can call it
|
||||||
expected = expected_func("hello world")
|
expected = func.compute_query_embeddings("hello world")
|
||||||
|
|
||||||
assert np.allclose(actual, expected)
|
assert np.allclose(actual, expected)
|
||||||
|
|||||||
125
python/tests/test_embeddings_slow.py
Normal file
125
python/tests/test_embeddings_slow.py
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
# Copyright (c) 2023. LanceDB Developers
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import io
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import pytest
|
||||||
|
import requests
|
||||||
|
|
||||||
|
import lancedb
|
||||||
|
from lancedb.embeddings import EmbeddingFunctionRegistry
|
||||||
|
from lancedb.pydantic import LanceModel, Vector
|
||||||
|
|
||||||
|
# These are integration tests for embedding functions.
|
||||||
|
# They are slow because they require downloading models
|
||||||
|
# or a connection to an external API.
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.slow
@pytest.mark.parametrize("alias", ["sentence-transformers", "openai"])
def test_sentence_transformer(alias, tmp_path):
    """Searching by raw text must agree with searching by the embedded query.

    Runs once per registered embedding-function alias. The function is
    attached to the table through the ``SourceField`` / ``VectorField``
    declarations on the schema; rows are then added with only the ``text``
    column supplied.
    """
    db = lancedb.connect(tmp_path)
    func = EmbeddingFunctionRegistry.get_instance().get(alias).create()

    class Words(LanceModel):
        text: str = func.SourceField()
        vector: Vector(func.ndims()) = func.VectorField()

    corpus = ["hello world", "goodbye world", "fizz", "buzz", "foo", "bar", "baz"]
    table = db.create_table("words", schema=Words)
    table.add(pd.DataFrame({"text": corpus}))

    query = "greetings"

    # Path 1: hand the table the raw string.
    by_text = table.search(query).limit(1).to_pydantic(Words)
    actual = by_text[0]

    # Path 2: embed the query ourselves and search with the resulting vector.
    query_vector = func.compute_query_embeddings(query)[0]
    by_vector = table.search(query_vector).limit(1).to_pydantic(Words)
    expected = by_vector[0]

    assert actual.text == expected.text
    assert actual.text == "hello world"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.slow
def test_openclip(tmp_path):
    """Integration test for the "open-clip" embedding function.

    Six labelled images are downloaded and stored both as URIs and as raw
    bytes. The test then checks that a text query and an image query each
    return a "dog" row, and that the ``vector`` and ``vec_from_bytes``
    columns hold the same values for every row.

    Requires network access and Pillow.
    """
    from PIL import Image

    def fetch(uri):
        # Fail loudly on an HTTP error or a stalled connection instead of
        # hanging, or storing an error page as if it were image data.
        response = requests.get(uri, timeout=30)
        response.raise_for_status()
        return response.content

    db = lancedb.connect(tmp_path)
    registry = EmbeddingFunctionRegistry.get_instance()
    func = registry.get("open-clip").create()

    class Images(LanceModel):
        label: str
        image_uri: str = func.SourceField()
        image_bytes: bytes = func.SourceField()
        # NOTE(review): presumably `vector` is derived from `image_uri` and
        # `vec_from_bytes` from `image_bytes`; confirm how fields are paired.
        vector: Vector(func.ndims()) = func.VectorField()
        vec_from_bytes: Vector(func.ndims()) = func.VectorField()

    table = db.create_table("images", schema=Images)
    labels = ["cat", "cat", "dog", "dog", "horse", "horse"]
    uris = [
        "http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
        "http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
        "http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
        "http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
        "http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
        "http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
    ]
    # get each uri as bytes
    image_bytes = [fetch(uri) for uri in uris]
    table.add(
        pd.DataFrame({"label": labels, "image_uri": uris, "image_bytes": image_bytes})
    )

    # text search
    actual = table.search("man's best friend").limit(1).to_pydantic(Images)[0]
    assert actual.label == "dog"
    frombytes = (
        table.search("man's best friend", vector_column_name="vec_from_bytes")
        .limit(1)
        .to_pydantic(Images)[0]
    )
    assert actual.label == frombytes.label
    assert np.allclose(actual.vector, frombytes.vector)

    # image search
    query_image_uri = "http://farm1.staticflickr.com/200/467715466_ed4a31801f_z.jpg"
    # Separate name so the list of stored image bytes above is not shadowed.
    query_image_bytes = fetch(query_image_uri)
    query_image = Image.open(io.BytesIO(query_image_bytes))
    actual = table.search(query_image).limit(1).to_pydantic(Images)[0]
    assert actual.label == "dog"
    other = (
        table.search(query_image, vector_column_name="vec_from_bytes")
        .limit(1)
        .to_pydantic(Images)[0]
    )
    assert actual.label == other.label

    # Both vector columns must hold the same values across all rows.
    arrow_table = table.search().select(["vector", "vec_from_bytes"]).to_arrow()
    assert np.allclose(
        arrow_table["vector"].combine_chunks().values.to_numpy(),
        arrow_table["vec_from_bytes"].combine_chunks().values.to_numpy(),
    )
|
||||||
@@ -22,8 +22,9 @@ import pandas as pd
|
|||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from lancedb.conftest import MockEmbeddingFunction
|
from lancedb.conftest import MockTextEmbeddingFunction
|
||||||
from lancedb.db import LanceDBConnection
|
from lancedb.db import LanceDBConnection
|
||||||
|
from lancedb.embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
|
||||||
from lancedb.pydantic import LanceModel, Vector
|
from lancedb.pydantic import LanceModel, Vector
|
||||||
from lancedb.table import LanceTable
|
from lancedb.table import LanceTable
|
||||||
|
|
||||||
@@ -356,20 +357,23 @@ def test_create_with_embedding_function(db):
|
|||||||
text: str
|
text: str
|
||||||
vector: Vector(10)
|
vector: Vector(10)
|
||||||
|
|
||||||
func = MockEmbeddingFunction(source_column="text", vector_column="vector")
|
func = MockTextEmbeddingFunction()
|
||||||
texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"]
|
texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"]
|
||||||
df = pd.DataFrame({"text": texts, "vector": func(texts)})
|
df = pd.DataFrame({"text": texts, "vector": func.compute_source_embeddings(texts)})
|
||||||
|
|
||||||
|
conf = EmbeddingFunctionConfig(
|
||||||
|
source_column="text", vector_column="vector", function=func
|
||||||
|
)
|
||||||
table = LanceTable.create(
|
table = LanceTable.create(
|
||||||
db,
|
db,
|
||||||
"my_table",
|
"my_table",
|
||||||
schema=MyTable,
|
schema=MyTable,
|
||||||
embedding_functions=[func],
|
embedding_functions=[conf],
|
||||||
)
|
)
|
||||||
table.add(df)
|
table.add(df)
|
||||||
|
|
||||||
query_str = "hi how are you?"
|
query_str = "hi how are you?"
|
||||||
query_vector = func(query_str)[0]
|
query_vector = func.compute_query_embeddings(query_str)[0]
|
||||||
expected = table.search(query_vector).limit(2).to_arrow()
|
expected = table.search(query_vector).limit(2).to_arrow()
|
||||||
|
|
||||||
actual = table.search(query_str).limit(2).to_arrow()
|
actual = table.search(query_str).limit(2).to_arrow()
|
||||||
@@ -377,17 +381,13 @@ def test_create_with_embedding_function(db):
|
|||||||
|
|
||||||
|
|
||||||
def test_add_with_embedding_function(db):
|
def test_add_with_embedding_function(db):
|
||||||
class MyTable(LanceModel):
|
emb = EmbeddingFunctionRegistry.get_instance().get("test")()
|
||||||
text: str
|
|
||||||
vector: Vector(10)
|
|
||||||
|
|
||||||
func = MockEmbeddingFunction(source_column="text", vector_column="vector")
|
class MyTable(LanceModel):
|
||||||
table = LanceTable.create(
|
text: str = emb.SourceField()
|
||||||
db,
|
vector: Vector(emb.ndims()) = emb.VectorField()
|
||||||
"my_table",
|
|
||||||
schema=MyTable,
|
table = LanceTable.create(db, "my_table", schema=MyTable)
|
||||||
embedding_functions=[func],
|
|
||||||
)
|
|
||||||
|
|
||||||
texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"]
|
texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"]
|
||||||
df = pd.DataFrame({"text": texts})
|
df = pd.DataFrame({"text": texts})
|
||||||
@@ -397,7 +397,7 @@ def test_add_with_embedding_function(db):
|
|||||||
table.add([{"text": t} for t in texts])
|
table.add([{"text": t} for t in texts])
|
||||||
|
|
||||||
query_str = "hi how are you?"
|
query_str = "hi how are you?"
|
||||||
query_vector = func(query_str)[0]
|
query_vector = emb.compute_query_embeddings(query_str)[0]
|
||||||
expected = table.search(query_vector).limit(2).to_arrow()
|
expected = table.search(query_vector).limit(2).to_arrow()
|
||||||
|
|
||||||
actual = table.search(query_str).limit(2).to_arrow()
|
actual = table.search(query_str).limit(2).to_arrow()
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "vectordb-node"
|
name = "vectordb-node"
|
||||||
version = "0.2.5"
|
version = "0.2.6"
|
||||||
description = "Serverless, low-latency vector database for AI applications"
|
description = "Serverless, low-latency vector database for AI applications"
|
||||||
license = "Apache-2.0"
|
license = "Apache-2.0"
|
||||||
edition = "2018"
|
edition = "2018"
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "vectordb"
|
name = "vectordb"
|
||||||
version = "0.2.5"
|
version = "0.2.6"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
description = "LanceDB: A serverless, low-latency vector database for AI applications"
|
description = "LanceDB: A serverless, low-latency vector database for AI applications"
|
||||||
license = "Apache-2.0"
|
license = "Apache-2.0"
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ use lance::dataset::WriteParams;
|
|||||||
use lance::io::object_store::ObjectStore;
|
use lance::io::object_store::ObjectStore;
|
||||||
use snafu::prelude::*;
|
use snafu::prelude::*;
|
||||||
|
|
||||||
use crate::error::{CreateDirSnafu, Error, InvalidTableNameSnafu, Result};
|
use crate::error::{CreateDirSnafu, InvalidTableNameSnafu, Result};
|
||||||
use crate::table::{ReadParams, Table};
|
use crate::table::{ReadParams, Table};
|
||||||
|
|
||||||
pub const LANCE_FILE_EXTENSION: &str = "lance";
|
pub const LANCE_FILE_EXTENSION: &str = "lance";
|
||||||
@@ -36,17 +36,6 @@ pub struct Database {
|
|||||||
const LANCE_EXTENSION: &str = "lance";
|
const LANCE_EXTENSION: &str = "lance";
|
||||||
const ENGINE: &str = "engine";
|
const ENGINE: &str = "engine";
|
||||||
|
|
||||||
/// Parse a url, if it's not a valid url, assume it's a local file
|
|
||||||
/// and try to parse with file:// appended
|
|
||||||
fn parse_url(url: &str) -> Result<url::Url> {
|
|
||||||
match url::Url::parse(url) {
|
|
||||||
Ok(url) => Ok(url),
|
|
||||||
Err(_) => url::Url::parse(format!("file://{}", url).as_str()).map_err(|e| Error::Lance {
|
|
||||||
message: format!("Failed to parse uri: {}", e),
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A connection to LanceDB
|
/// A connection to LanceDB
|
||||||
impl Database {
|
impl Database {
|
||||||
/// Connects to LanceDB
|
/// Connects to LanceDB
|
||||||
@@ -59,71 +48,73 @@ impl Database {
|
|||||||
///
|
///
|
||||||
/// * A [Database] object.
|
/// * A [Database] object.
|
||||||
pub async fn connect(uri: &str) -> Result<Database> {
|
pub async fn connect(uri: &str) -> Result<Database> {
|
||||||
// For a native (using lance directly) connection
|
let parse_res = url::Url::parse(uri);
|
||||||
// The DB doesn't use any uri parameters, but lance does
|
|
||||||
// So we need to parse the uri, extract the query string, and progate it to lance
|
|
||||||
let mut url = parse_url(uri)?;
|
|
||||||
|
|
||||||
// special handling for windows
|
match parse_res {
|
||||||
if url.scheme().len() == 1 && cfg!(windows) {
|
Ok(url) if url.scheme().len() == 1 && cfg!(windows) => Self::open_path(uri).await,
|
||||||
let (object_store, base_path) = ObjectStore::from_uri(uri).await?;
|
Ok(mut url) => {
|
||||||
if object_store.is_local() {
|
// iter thru the query params and extract the commit store param
|
||||||
Self::try_create_dir(uri).context(CreateDirSnafu { path: uri })?;
|
let mut engine = None;
|
||||||
|
let mut filtered_querys = vec![];
|
||||||
|
|
||||||
|
// WARNING: specifying engine is NOT a publicly supported feature in lancedb yet
|
||||||
|
// THE API WILL CHANGE
|
||||||
|
for (key, value) in url.query_pairs() {
|
||||||
|
if key == ENGINE {
|
||||||
|
engine = Some(value.to_string());
|
||||||
|
} else {
|
||||||
|
// to owned so we can modify the url
|
||||||
|
filtered_querys.push((key.to_string(), value.to_string()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter out the commit store query param -- it's a lancedb param
|
||||||
|
url.query_pairs_mut().clear();
|
||||||
|
url.query_pairs_mut().extend_pairs(filtered_querys);
|
||||||
|
// Take a copy of the query string so we can propagate it to lance
|
||||||
|
let query_string = url.query().map(|s| s.to_string());
|
||||||
|
// clear the query string so we can use the url as the base uri
|
||||||
|
// use .set_query(None) instead of .set_query("") because the latter
|
||||||
|
// will add a trailing '?' to the url
|
||||||
|
url.set_query(None);
|
||||||
|
|
||||||
|
let table_base_uri = if let Some(store) = engine {
|
||||||
|
static WARN_ONCE: std::sync::Once = std::sync::Once::new();
|
||||||
|
WARN_ONCE.call_once(|| {
|
||||||
|
log::warn!("Specifing engine is not a publicly supported feature in lancedb yet. THE API WILL CHANGE");
|
||||||
|
});
|
||||||
|
let old_scheme = url.scheme().to_string();
|
||||||
|
let new_scheme = format!("{}+{}", old_scheme, store);
|
||||||
|
url.to_string().replacen(&old_scheme, &new_scheme, 1)
|
||||||
|
} else {
|
||||||
|
url.to_string()
|
||||||
|
};
|
||||||
|
|
||||||
|
let plain_uri = url.to_string();
|
||||||
|
let (object_store, base_path) = ObjectStore::from_uri(&plain_uri).await?;
|
||||||
|
if object_store.is_local() {
|
||||||
|
Self::try_create_dir(&plain_uri).context(CreateDirSnafu { path: plain_uri })?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Database {
|
||||||
|
uri: table_base_uri,
|
||||||
|
query_string,
|
||||||
|
base_path,
|
||||||
|
object_store,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
return Ok(Database {
|
Err(_) => Self::open_path(uri).await,
|
||||||
uri: uri.to_string(),
|
|
||||||
query_string: None,
|
|
||||||
base_path,
|
|
||||||
object_store,
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// iter thru the query params and extract the commit store param
|
async fn open_path(path: &str) -> Result<Database> {
|
||||||
let mut engine = None;
|
let (object_store, base_path) = ObjectStore::from_uri(path).await?;
|
||||||
let mut filtered_querys = vec![];
|
|
||||||
|
|
||||||
// WARNING: specifying engine is NOT a publicly supported feature in lancedb yet
|
|
||||||
// THE API WILL CHANGE
|
|
||||||
for (key, value) in url.query_pairs() {
|
|
||||||
if key == ENGINE {
|
|
||||||
engine = Some(value.to_string());
|
|
||||||
} else {
|
|
||||||
// to owned so we can modify the url
|
|
||||||
filtered_querys.push((key.to_string(), value.to_string()));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Filter out the commit store query param -- it's a lancedb param
|
|
||||||
url.query_pairs_mut().clear();
|
|
||||||
url.query_pairs_mut().extend_pairs(filtered_querys);
|
|
||||||
// Take a copy of the query string so we can propagate it to lance
|
|
||||||
let query_string = url.query().map(|s| s.to_string());
|
|
||||||
// clear the query string so we can use the url as the base uri
|
|
||||||
// use .set_query(None) instead of .set_query("") because the latter
|
|
||||||
// will add a trailing '?' to the url
|
|
||||||
url.set_query(None);
|
|
||||||
|
|
||||||
let table_base_uri = if let Some(store) = engine {
|
|
||||||
static WARN_ONCE: std::sync::Once = std::sync::Once::new();
|
|
||||||
WARN_ONCE.call_once(|| {
|
|
||||||
log::warn!("Specifing engine is not a publicly supported feature in lancedb yet. THE API WILL CHANGE");
|
|
||||||
});
|
|
||||||
let old_scheme = url.scheme().to_string();
|
|
||||||
let new_scheme = format!("{}+{}", old_scheme, store);
|
|
||||||
url.to_string().replacen(&old_scheme, &new_scheme, 1)
|
|
||||||
} else {
|
|
||||||
url.to_string()
|
|
||||||
};
|
|
||||||
|
|
||||||
let plain_uri = url.to_string();
|
|
||||||
let (object_store, base_path) = ObjectStore::from_uri(&plain_uri).await?;
|
|
||||||
if object_store.is_local() {
|
if object_store.is_local() {
|
||||||
Self::try_create_dir(&plain_uri).context(CreateDirSnafu { path: plain_uri })?;
|
Self::try_create_dir(path).context(CreateDirSnafu { path: path })?;
|
||||||
}
|
}
|
||||||
|
Ok(Self {
|
||||||
Ok(Database {
|
uri: path.to_string(),
|
||||||
uri: table_base_uri,
|
query_string: None,
|
||||||
query_string,
|
|
||||||
base_path,
|
base_path,
|
||||||
object_store,
|
object_store,
|
||||||
})
|
})
|
||||||
@@ -240,6 +231,7 @@ impl Database {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use std::fs::create_dir_all;
|
use std::fs::create_dir_all;
|
||||||
|
|
||||||
use tempfile::tempdir;
|
use tempfile::tempdir;
|
||||||
|
|
||||||
use crate::database::Database;
|
use crate::database::Database;
|
||||||
@@ -250,15 +242,29 @@ mod tests {
|
|||||||
let uri = tmp_dir.path().to_str().unwrap();
|
let uri = tmp_dir.path().to_str().unwrap();
|
||||||
let db = Database::connect(uri).await.unwrap();
|
let db = Database::connect(uri).await.unwrap();
|
||||||
|
|
||||||
// file:// scheme should be automatically appended if not specified
|
assert_eq!(db.uri, uri);
|
||||||
// windows path come with drive letter, so file:// won't be appended
|
}
|
||||||
let expected = if cfg!(windows) {
|
|
||||||
uri.to_string()
|
|
||||||
} else {
|
|
||||||
format!("file://{}", uri)
|
|
||||||
};
|
|
||||||
|
|
||||||
assert_eq!(db.uri, expected);
|
#[cfg(not(windows))]
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_connect_relative() {
|
||||||
|
let tmp_dir = tempdir().unwrap();
|
||||||
|
let uri = std::fs::canonicalize(tmp_dir.path().to_str().unwrap()).unwrap();
|
||||||
|
|
||||||
|
let mut relative_anacestors = vec![];
|
||||||
|
let current_dir = std::env::current_dir().unwrap();
|
||||||
|
let mut ancestors = current_dir.ancestors();
|
||||||
|
while let Some(_) = ancestors.next() {
|
||||||
|
relative_anacestors.push("..");
|
||||||
|
}
|
||||||
|
let relative_root = std::path::PathBuf::from(relative_anacestors.join("/"));
|
||||||
|
let relative_uri = relative_root.join(&uri);
|
||||||
|
|
||||||
|
let db = Database::connect(relative_uri.to_str().unwrap())
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(db.uri, relative_uri.to_str().unwrap().to_string());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
|
|||||||
Reference in New Issue
Block a user