mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-23 05:19:58 +00:00
Compare commits
16 Commits
python-v0.
...
v0.3.7
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ebcf9bf6ae | ||
|
|
797514bcbf | ||
|
|
1c872ce501 | ||
|
|
479f471c14 | ||
|
|
ae0d2f2599 | ||
|
|
1e8678f11a | ||
|
|
662968559d | ||
|
|
9d895801f2 | ||
|
|
80613a40fd | ||
|
|
d43ef7f11e | ||
|
|
554e068917 | ||
|
|
567734dd6e | ||
|
|
1589499f89 | ||
|
|
682e95fa83 | ||
|
|
1ad5e7f2f0 | ||
|
|
ddb3ef4ce5 |
@@ -1,5 +1,5 @@
|
||||
[bumpversion]
|
||||
current_version = 0.3.5
|
||||
current_version = 0.3.7
|
||||
commit = True
|
||||
message = Bump version: {current_version} → {new_version}
|
||||
tag = True
|
||||
|
||||
17
.github/workflows/python.yml
vendored
17
.github/workflows/python.yml
vendored
@@ -37,18 +37,19 @@ jobs:
|
||||
run: |
|
||||
pip install -e .[tests]
|
||||
pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985
|
||||
pip install pytest pytest-mock black isort
|
||||
- name: Black
|
||||
run: black --check --diff --no-color --quiet .
|
||||
- name: isort
|
||||
run: isort --check --diff --quiet .
|
||||
pip install pytest pytest-mock ruff
|
||||
- name: Lint
|
||||
run: ruff format --check .
|
||||
- name: Run tests
|
||||
run: pytest -m "not slow" -x -v --durations=30 tests
|
||||
- name: doctest
|
||||
run: pytest --doctest-modules lancedb
|
||||
mac:
|
||||
timeout-minutes: 30
|
||||
runs-on: "macos-13"
|
||||
strategy:
|
||||
matrix:
|
||||
mac-runner: [ "macos-13", "macos-13-xlarge" ]
|
||||
runs-on: "${{ matrix.mac-runner }}"
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
@@ -67,8 +68,6 @@ jobs:
|
||||
pip install -e .[tests]
|
||||
pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985
|
||||
pip install pytest pytest-mock black
|
||||
- name: Black
|
||||
run: black --check --diff --no-color --quiet .
|
||||
- name: Run tests
|
||||
run: pytest -m "not slow" -x -v --durations=30 tests
|
||||
pydantic1x:
|
||||
@@ -100,4 +99,4 @@ jobs:
|
||||
- name: Run tests
|
||||
run: pytest -m "not slow" -x -v --durations=30 tests
|
||||
- name: doctest
|
||||
run: pytest --doctest-modules lancedb
|
||||
run: pytest --doctest-modules lancedb
|
||||
|
||||
5
.github/workflows/rust.yml
vendored
5
.github/workflows/rust.yml
vendored
@@ -48,8 +48,11 @@ jobs:
|
||||
- name: Run tests
|
||||
run: cargo test --all-features
|
||||
macos:
|
||||
runs-on: macos-13
|
||||
timeout-minutes: 30
|
||||
strategy:
|
||||
matrix:
|
||||
mac-runner: [ "macos-13", "macos-13-xlarge" ]
|
||||
runs-on: "${{ matrix.mac-runner }}"
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
|
||||
@@ -5,9 +5,9 @@ exclude = ["python"]
|
||||
resolver = "2"
|
||||
|
||||
[workspace.dependencies]
|
||||
lance = { "version" = "=0.8.10", "features" = ["dynamodb"] }
|
||||
lance-linalg = { "version" = "=0.8.10" }
|
||||
lance-testing = { "version" = "=0.8.10" }
|
||||
lance = { "version" = "=0.8.14", "features" = ["dynamodb"] }
|
||||
lance-linalg = { "version" = "=0.8.14" }
|
||||
lance-testing = { "version" = "=0.8.14" }
|
||||
# Note that this one does not include pyarrow
|
||||
arrow = { version = "47.0.0", optional = false }
|
||||
arrow-array = "47.0"
|
||||
|
||||
74
node/package-lock.json
generated
74
node/package-lock.json
generated
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"name": "vectordb",
|
||||
"version": "0.3.5",
|
||||
"version": "0.3.6",
|
||||
"lockfileVersion": 2,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "vectordb",
|
||||
"version": "0.3.5",
|
||||
"version": "0.3.6",
|
||||
"cpu": [
|
||||
"x64",
|
||||
"arm64"
|
||||
@@ -53,11 +53,11 @@
|
||||
"uuid": "^9.0.0"
|
||||
},
|
||||
"optionalDependencies": {
|
||||
"@lancedb/vectordb-darwin-arm64": "0.3.5",
|
||||
"@lancedb/vectordb-darwin-x64": "0.3.5",
|
||||
"@lancedb/vectordb-linux-arm64-gnu": "0.3.5",
|
||||
"@lancedb/vectordb-linux-x64-gnu": "0.3.5",
|
||||
"@lancedb/vectordb-win32-x64-msvc": "0.3.5"
|
||||
"@lancedb/vectordb-darwin-arm64": "0.3.6",
|
||||
"@lancedb/vectordb-darwin-x64": "0.3.6",
|
||||
"@lancedb/vectordb-linux-arm64-gnu": "0.3.6",
|
||||
"@lancedb/vectordb-linux-x64-gnu": "0.3.6",
|
||||
"@lancedb/vectordb-win32-x64-msvc": "0.3.6"
|
||||
}
|
||||
},
|
||||
"node_modules/@apache-arrow/ts": {
|
||||
@@ -317,9 +317,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@lancedb/vectordb-darwin-arm64": {
|
||||
"version": "0.3.5",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.3.5.tgz",
|
||||
"integrity": "sha512-Nnso+WXMSTIUouddDgPDNt40K6d2fF7W5OsfgAMDXAhUrdSMOZbVP0bWklRz9J7JluseBL9/MfLSEYZDTvrACg==",
|
||||
"version": "0.3.6",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.3.6.tgz",
|
||||
"integrity": "sha512-GR5v+4kHUCZ71gVxd3mLsUdlreXPUIbvBgvr+BmEXRbLfc7+JsFUjsRgxmoctQ0mXxkW67Sl7v6kQCWcBLCk/Q==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
@@ -329,9 +329,9 @@
|
||||
]
|
||||
},
|
||||
"node_modules/@lancedb/vectordb-darwin-x64": {
|
||||
"version": "0.3.5",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.3.5.tgz",
|
||||
"integrity": "sha512-gvg/iq13zAamLL7jueiIw7Q67dygm/NmILkFQ3WrAOUjr0IMxLBCv+XMxt62xajTrA+ObyfmU1uiuhrJL81PWw==",
|
||||
"version": "0.3.6",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.3.6.tgz",
|
||||
"integrity": "sha512-4qemi4jUXG8jOk7ecECmb0+5Nm0n7YF5/1X9/5uc81I+4What+yhZE9nEsmCGRBqmtuQXkYl35ePvQgj3rCQjQ==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
@@ -341,9 +341,9 @@
|
||||
]
|
||||
},
|
||||
"node_modules/@lancedb/vectordb-linux-arm64-gnu": {
|
||||
"version": "0.3.5",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.3.5.tgz",
|
||||
"integrity": "sha512-6PvCBIXI9zPqF478TibZxxiAehFZ530g0FOFDT49xtp540HvhE9+XQk/yO0w96mvyoCfzB2lK4haDmdhCoehNw==",
|
||||
"version": "0.3.6",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.3.6.tgz",
|
||||
"integrity": "sha512-I/lFqIUcXYxJnUG5+DILzUzcfHRGHXL3kl5bs1MGkR9a7F3oPx1IAwY9wkskVnClM7XF9H7MVcFRVTjHUqoUwA==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
@@ -353,9 +353,9 @@
|
||||
]
|
||||
},
|
||||
"node_modules/@lancedb/vectordb-linux-x64-gnu": {
|
||||
"version": "0.3.5",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.3.5.tgz",
|
||||
"integrity": "sha512-e3nqurUeCow4QONeNf/QP50Z90mgrh9xoUfjRSHcCPQcP6WgmFEafbt0jeSVgZ7tbt7+03/MK0YexhHM/5sBjA==",
|
||||
"version": "0.3.6",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.3.6.tgz",
|
||||
"integrity": "sha512-UTA/4bpA3UoByhfDx//S5m4o6uQ1qfpneD0PbuftAjkt9eHg0ABIEpZdiTI3xUBdrjXSKZtpVTxOin9X39IBKQ==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
@@ -365,9 +365,9 @@
|
||||
]
|
||||
},
|
||||
"node_modules/@lancedb/vectordb-win32-x64-msvc": {
|
||||
"version": "0.3.5",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.3.5.tgz",
|
||||
"integrity": "sha512-RC1FfgEr6Z9sADuvspT2PG1B2mpKRdckgeiHqTHkIXdq3Qp5V5TeQJAbVvMr2xd1q99W6zreub52QXf+AilLVQ==",
|
||||
"version": "0.3.6",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.3.6.tgz",
|
||||
"integrity": "sha512-70IS0TX4BpjSX4GP1Pq835cqQ5LZpfOJuBNtGv93OxMTWTVQUxtp2MLNwOR6OJMGNQz6q84NNKrKOSf15ZGwGg==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
@@ -4869,33 +4869,33 @@
|
||||
}
|
||||
},
|
||||
"@lancedb/vectordb-darwin-arm64": {
|
||||
"version": "0.3.5",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.3.5.tgz",
|
||||
"integrity": "sha512-Nnso+WXMSTIUouddDgPDNt40K6d2fF7W5OsfgAMDXAhUrdSMOZbVP0bWklRz9J7JluseBL9/MfLSEYZDTvrACg==",
|
||||
"version": "0.3.6",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.3.6.tgz",
|
||||
"integrity": "sha512-GR5v+4kHUCZ71gVxd3mLsUdlreXPUIbvBgvr+BmEXRbLfc7+JsFUjsRgxmoctQ0mXxkW67Sl7v6kQCWcBLCk/Q==",
|
||||
"optional": true
|
||||
},
|
||||
"@lancedb/vectordb-darwin-x64": {
|
||||
"version": "0.3.5",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.3.5.tgz",
|
||||
"integrity": "sha512-gvg/iq13zAamLL7jueiIw7Q67dygm/NmILkFQ3WrAOUjr0IMxLBCv+XMxt62xajTrA+ObyfmU1uiuhrJL81PWw==",
|
||||
"version": "0.3.6",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.3.6.tgz",
|
||||
"integrity": "sha512-4qemi4jUXG8jOk7ecECmb0+5Nm0n7YF5/1X9/5uc81I+4What+yhZE9nEsmCGRBqmtuQXkYl35ePvQgj3rCQjQ==",
|
||||
"optional": true
|
||||
},
|
||||
"@lancedb/vectordb-linux-arm64-gnu": {
|
||||
"version": "0.3.5",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.3.5.tgz",
|
||||
"integrity": "sha512-6PvCBIXI9zPqF478TibZxxiAehFZ530g0FOFDT49xtp540HvhE9+XQk/yO0w96mvyoCfzB2lK4haDmdhCoehNw==",
|
||||
"version": "0.3.6",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.3.6.tgz",
|
||||
"integrity": "sha512-I/lFqIUcXYxJnUG5+DILzUzcfHRGHXL3kl5bs1MGkR9a7F3oPx1IAwY9wkskVnClM7XF9H7MVcFRVTjHUqoUwA==",
|
||||
"optional": true
|
||||
},
|
||||
"@lancedb/vectordb-linux-x64-gnu": {
|
||||
"version": "0.3.5",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.3.5.tgz",
|
||||
"integrity": "sha512-e3nqurUeCow4QONeNf/QP50Z90mgrh9xoUfjRSHcCPQcP6WgmFEafbt0jeSVgZ7tbt7+03/MK0YexhHM/5sBjA==",
|
||||
"version": "0.3.6",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.3.6.tgz",
|
||||
"integrity": "sha512-UTA/4bpA3UoByhfDx//S5m4o6uQ1qfpneD0PbuftAjkt9eHg0ABIEpZdiTI3xUBdrjXSKZtpVTxOin9X39IBKQ==",
|
||||
"optional": true
|
||||
},
|
||||
"@lancedb/vectordb-win32-x64-msvc": {
|
||||
"version": "0.3.5",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.3.5.tgz",
|
||||
"integrity": "sha512-RC1FfgEr6Z9sADuvspT2PG1B2mpKRdckgeiHqTHkIXdq3Qp5V5TeQJAbVvMr2xd1q99W6zreub52QXf+AilLVQ==",
|
||||
"version": "0.3.6",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.3.6.tgz",
|
||||
"integrity": "sha512-70IS0TX4BpjSX4GP1Pq835cqQ5LZpfOJuBNtGv93OxMTWTVQUxtp2MLNwOR6OJMGNQz6q84NNKrKOSf15ZGwGg==",
|
||||
"optional": true
|
||||
},
|
||||
"@neon-rs/cli": {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "vectordb",
|
||||
"version": "0.3.5",
|
||||
"version": "0.3.7",
|
||||
"description": " Serverless, low-latency vector database for AI applications",
|
||||
"main": "dist/index.js",
|
||||
"types": "dist/index.d.ts",
|
||||
@@ -81,10 +81,10 @@
|
||||
}
|
||||
},
|
||||
"optionalDependencies": {
|
||||
"@lancedb/vectordb-darwin-arm64": "0.3.5",
|
||||
"@lancedb/vectordb-darwin-x64": "0.3.5",
|
||||
"@lancedb/vectordb-linux-arm64-gnu": "0.3.5",
|
||||
"@lancedb/vectordb-linux-x64-gnu": "0.3.5",
|
||||
"@lancedb/vectordb-win32-x64-msvc": "0.3.5"
|
||||
"@lancedb/vectordb-darwin-arm64": "0.3.7",
|
||||
"@lancedb/vectordb-darwin-x64": "0.3.7",
|
||||
"@lancedb/vectordb-linux-arm64-gnu": "0.3.7",
|
||||
"@lancedb/vectordb-linux-x64-gnu": "0.3.7",
|
||||
"@lancedb/vectordb-win32-x64-msvc": "0.3.7"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -63,6 +63,9 @@ export class HttpLancedbClient {
|
||||
}
|
||||
).catch((err) => {
|
||||
console.error('error: ', err)
|
||||
if (err.response === undefined) {
|
||||
throw new Error(`Network Error: ${err.message as string}`)
|
||||
}
|
||||
return err.response
|
||||
})
|
||||
if (response.status !== 200) {
|
||||
@@ -86,13 +89,17 @@ export class HttpLancedbClient {
|
||||
{
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-api-key': this._apiKey()
|
||||
'x-api-key': this._apiKey(),
|
||||
...(this._dbName !== undefined ? { 'x-lancedb-database': this._dbName } : {})
|
||||
},
|
||||
params,
|
||||
timeout: 10000
|
||||
}
|
||||
).catch((err) => {
|
||||
console.error('error: ', err)
|
||||
if (err.response === undefined) {
|
||||
throw new Error(`Network Error: ${err.message as string}`)
|
||||
}
|
||||
return err.response
|
||||
})
|
||||
if (response.status !== 200) {
|
||||
@@ -128,6 +135,9 @@ export class HttpLancedbClient {
|
||||
}
|
||||
).catch((err) => {
|
||||
console.error('error: ', err)
|
||||
if (err.response === undefined) {
|
||||
throw new Error(`Network Error: ${err.message as string}`)
|
||||
}
|
||||
return err.response
|
||||
})
|
||||
if (response.status !== 200) {
|
||||
|
||||
@@ -237,7 +237,8 @@ export class RemoteTable<T = number[]> implements Table<T> {
|
||||
}
|
||||
|
||||
async countRows (): Promise<number> {
|
||||
throw new Error('Not implemented')
|
||||
const result = await this._client.post(`/v1/table/${this._name}/describe/`)
|
||||
return result.data?.stats?.num_rows
|
||||
}
|
||||
|
||||
async delete (filter: string): Promise<void> {
|
||||
|
||||
@@ -396,6 +396,40 @@ describe('LanceDB client', function () {
|
||||
})
|
||||
})
|
||||
|
||||
describe('Remote LanceDB client', function () {
|
||||
describe('when the server is not reachable', function () {
|
||||
it('produces a network error', async function () {
|
||||
const con = await lancedb.connect({
|
||||
uri: 'db://test-1234',
|
||||
region: 'asdfasfasfdf',
|
||||
apiKey: 'some-api-key'
|
||||
})
|
||||
|
||||
// GET
|
||||
try {
|
||||
await con.tableNames()
|
||||
} catch (err) {
|
||||
expect(err).to.have.property('message', 'Network Error: getaddrinfo ENOTFOUND test-1234.asdfasfasfdf.api.lancedb.com')
|
||||
}
|
||||
|
||||
// POST
|
||||
try {
|
||||
await con.createTable({ name: 'vectors', schema: new Schema([]) })
|
||||
} catch (err) {
|
||||
expect(err).to.have.property('message', 'Network Error: getaddrinfo ENOTFOUND test-1234.asdfasfasfdf.api.lancedb.com')
|
||||
}
|
||||
|
||||
// Search
|
||||
const table = await con.openTable('vectors')
|
||||
try {
|
||||
await table.search([0.1, 0.3]).execute()
|
||||
} catch (err) {
|
||||
expect(err).to.have.property('message', 'Network Error: getaddrinfo ENOTFOUND test-1234.asdfasfasfdf.api.lancedb.com')
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Query object', function () {
|
||||
it('sets custom parameters', async function () {
|
||||
const query = new Query([0.1, 0.3])
|
||||
|
||||
@@ -16,10 +16,11 @@ from typing import Optional
|
||||
|
||||
__version__ = importlib.metadata.version("lancedb")
|
||||
|
||||
from .db import URI, DBConnection, LanceDBConnection
|
||||
from .common import URI
|
||||
from .db import DBConnection, LanceDBConnection
|
||||
from .remote.db import RemoteDBConnection
|
||||
from .schema import vector
|
||||
from .utils import sentry_log
|
||||
from .schema import vector # noqa: F401
|
||||
from .utils import sentry_log # noqa: F401
|
||||
|
||||
|
||||
def connect(
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import os
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
@@ -38,3 +40,26 @@ class MockTextEmbeddingFunction(TextEmbeddingFunction):
|
||||
|
||||
def ndims(self):
|
||||
return 10
|
||||
|
||||
|
||||
class RateLimitedAPI:
|
||||
rate_limit = 0.1 # 1 request per 0.1 second
|
||||
last_request_time = 0
|
||||
|
||||
@staticmethod
|
||||
def make_request():
|
||||
current_time = time.time()
|
||||
|
||||
if current_time - RateLimitedAPI.last_request_time < RateLimitedAPI.rate_limit:
|
||||
raise Exception("Rate limit exceeded. Please try again later.")
|
||||
|
||||
# Simulate a successful request
|
||||
RateLimitedAPI.last_request_time = current_time
|
||||
return "Request successful"
|
||||
|
||||
|
||||
@registry.register("test-rate-limited")
|
||||
class MockRateLimitedEmbeddingFunction(MockTextEmbeddingFunction):
|
||||
def generate_embeddings(self, texts):
|
||||
RateLimitedAPI.make_request()
|
||||
return [self._compute_one_embedding(row) for row in texts]
|
||||
|
||||
@@ -14,26 +14,39 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from abc import abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Iterable, List, Optional, Union
|
||||
|
||||
import pyarrow as pa
|
||||
from overrides import EnforceOverrides, override
|
||||
from pyarrow import fs
|
||||
|
||||
from .common import DATA, URI
|
||||
from .embeddings import EmbeddingFunctionConfig
|
||||
from .pydantic import LanceModel
|
||||
from .table import LanceTable, Table
|
||||
from .util import fs_from_uri, get_uri_location, get_uri_scheme
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .common import DATA, URI
|
||||
from .embeddings import EmbeddingFunctionConfig
|
||||
from .pydantic import LanceModel
|
||||
|
||||
class DBConnection(ABC):
|
||||
|
||||
class DBConnection(EnforceOverrides):
|
||||
"""An active LanceDB connection interface."""
|
||||
|
||||
@abstractmethod
|
||||
def table_names(self) -> list[str]:
|
||||
"""List all table names in the database."""
|
||||
def table_names(
|
||||
self, page_token: Optional[str] = None, limit: int = 10
|
||||
) -> Iterable[str]:
|
||||
"""List all table in this database
|
||||
|
||||
Parameters
|
||||
----------
|
||||
page_token: str, optional
|
||||
The token to use for pagination. If not present, start from the beginning.
|
||||
limit: int, default 10
|
||||
The size of the page to return.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -45,6 +58,7 @@ class DBConnection(ABC):
|
||||
mode: str = "create",
|
||||
on_bad_vectors: str = "error",
|
||||
fill_value: float = 0.0,
|
||||
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
|
||||
) -> Table:
|
||||
"""Create a [Table][lancedb.table.Table] in the database.
|
||||
|
||||
@@ -262,12 +276,15 @@ class LanceDBConnection(DBConnection):
|
||||
def uri(self) -> str:
|
||||
return self._uri
|
||||
|
||||
def table_names(self) -> list[str]:
|
||||
@override
|
||||
def table_names(
|
||||
self, page_token: Optional[str] = None, limit: int = 10
|
||||
) -> Iterable[str]:
|
||||
"""Get the names of all tables in the database. The names are sorted.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list of str
|
||||
Iterator of str.
|
||||
A list of table names.
|
||||
"""
|
||||
try:
|
||||
@@ -296,6 +313,7 @@ class LanceDBConnection(DBConnection):
|
||||
def __contains__(self, name: str) -> bool:
|
||||
return name in self.table_names()
|
||||
|
||||
@override
|
||||
def create_table(
|
||||
self,
|
||||
name: str,
|
||||
@@ -327,6 +345,7 @@ class LanceDBConnection(DBConnection):
|
||||
)
|
||||
return tbl
|
||||
|
||||
@override
|
||||
def open_table(self, name: str) -> LanceTable:
|
||||
"""Open a table in the database.
|
||||
|
||||
@@ -341,6 +360,7 @@ class LanceDBConnection(DBConnection):
|
||||
"""
|
||||
return LanceTable.open(self, name)
|
||||
|
||||
@override
|
||||
def drop_table(self, name: str, ignore_missing: bool = False):
|
||||
"""Drop a table from the database.
|
||||
|
||||
@@ -359,6 +379,7 @@ class LanceDBConnection(DBConnection):
|
||||
if not ignore_missing:
|
||||
raise
|
||||
|
||||
@override
|
||||
def drop_database(self):
|
||||
filesystem, path = fs_from_uri(self.uri)
|
||||
filesystem.delete_dir(path)
|
||||
|
||||
@@ -11,8 +11,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa: F401
|
||||
from .base import EmbeddingFunction, EmbeddingFunctionConfig, TextEmbeddingFunction
|
||||
from .cohere import CohereEmbeddingFunction
|
||||
from .instructor import InstructorEmbeddingFunction
|
||||
from .open_clip import OpenClipEmbeddings
|
||||
from .openai import OpenAIEmbeddings
|
||||
from .registry import EmbeddingFunctionRegistry, get_registry
|
||||
|
||||
@@ -1,3 +1,15 @@
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Union
|
||||
@@ -6,7 +18,7 @@ import numpy as np
|
||||
import pyarrow as pa
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
|
||||
from .utils import TEXT
|
||||
from .utils import TEXT, retry_with_exponential_backoff
|
||||
|
||||
|
||||
class EmbeddingFunction(BaseModel, ABC):
|
||||
@@ -21,6 +33,10 @@ class EmbeddingFunction(BaseModel, ABC):
|
||||
3. ndims method which returns the number of dimensions of the vector column
|
||||
"""
|
||||
|
||||
__slots__ = ("__weakref__",) # pydantic 1.x compatibility
|
||||
max_retries: int = (
|
||||
7 # Setitng 0 disables retires. Maybe this should not be enabled by default,
|
||||
)
|
||||
_ndims: int = PrivateAttr()
|
||||
|
||||
@classmethod
|
||||
@@ -44,6 +60,25 @@ class EmbeddingFunction(BaseModel, ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
def compute_query_embeddings_with_retry(self, *args, **kwargs) -> List[np.array]:
|
||||
"""
|
||||
Compute the embeddings for a given user query with retries
|
||||
"""
|
||||
return retry_with_exponential_backoff(
|
||||
self.compute_query_embeddings, max_retries=self.max_retries
|
||||
)(
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def compute_source_embeddings_with_retry(self, *args, **kwargs) -> List[np.array]:
|
||||
"""
|
||||
Compute the embeddings for the source column in the database with retries
|
||||
"""
|
||||
return retry_with_exponential_backoff(
|
||||
self.compute_source_embeddings, max_retries=self.max_retries
|
||||
)(*args, **kwargs)
|
||||
|
||||
def sanitize_input(self, texts: TEXT) -> Union[List[str], np.ndarray]:
|
||||
"""
|
||||
Sanitize the input to the embedding function.
|
||||
@@ -103,6 +138,14 @@ class EmbeddingFunction(BaseModel, ABC):
|
||||
"""
|
||||
return Field(json_schema_extra={"vector_column_for": self}, **kwargs)
|
||||
|
||||
def __eq__(self, __value: object) -> bool:
|
||||
if not hasattr(__value, "__dict__"):
|
||||
return False
|
||||
return vars(self) == vars(__value)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(frozenset(vars(self).items()))
|
||||
|
||||
|
||||
class EmbeddingFunctionConfig(BaseModel):
|
||||
"""
|
||||
|
||||
@@ -31,7 +31,8 @@ class CohereEmbeddingFunction(TextEmbeddingFunction):
|
||||
Parameters
|
||||
----------
|
||||
name: str, default "embed-multilingual-v2.0"
|
||||
The name of the model to use. See the Cohere documentation for a list of available models.
|
||||
The name of the model to use. See the Cohere documentation for
|
||||
a list of available models.
|
||||
|
||||
Examples
|
||||
--------
|
||||
@@ -39,7 +40,10 @@ class CohereEmbeddingFunction(TextEmbeddingFunction):
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
from lancedb.embeddings import EmbeddingFunctionRegistry
|
||||
|
||||
cohere = EmbeddingFunctionRegistry.get_instance().get("cohere").create(name="embed-multilingual-v2.0")
|
||||
cohere = EmbeddingFunctionRegistry
|
||||
.get_instance()
|
||||
.get("cohere")
|
||||
.create(name="embed-multilingual-v2.0")
|
||||
|
||||
class TextModel(LanceModel):
|
||||
text: str = cohere.SourceField()
|
||||
|
||||
137
python/lancedb/embeddings/instructor.py
Normal file
137
python/lancedb/embeddings/instructor.py
Normal file
@@ -0,0 +1,137 @@
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .base import TextEmbeddingFunction
|
||||
from .registry import register
|
||||
from .utils import TEXT, weak_lru
|
||||
|
||||
|
||||
@register("instructor")
|
||||
class InstructorEmbeddingFunction(TextEmbeddingFunction):
|
||||
"""
|
||||
An embedding function that uses the InstructorEmbedding library. Instructor models support multi-task learning, and can be used for a
|
||||
variety of tasks, including text classification, sentence similarity, and document retrieval.
|
||||
If you want to calculate customized embeddings for specific sentences, you may follow the unified template to write instructions:
|
||||
"Represent the `domain` `text_type` for `task_objective`":
|
||||
|
||||
* domain is optional, and it specifies the domain of the text, e.g., science, finance, medicine, etc.
|
||||
* text_type is required, and it specifies the encoding unit, e.g., sentence, document, paragraph, etc.
|
||||
* task_objective is optional, and it specifies the objective of embedding, e.g., retrieve a document, classify the sentence, etc.
|
||||
|
||||
For example, if you want to calculate embeddings for a document, you may write the instruction as follows:
|
||||
"Represent the document for retreival"
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: str
|
||||
The name of the model to use. Available models are listed at https://github.com/xlang-ai/instructor-embedding#model-list;
|
||||
The default model is hkunlp/instructor-base
|
||||
batch_size: int, default 32
|
||||
The batch size to use when generating embeddings
|
||||
device: str, default "cpu"
|
||||
The device to use when generating embeddings
|
||||
show_progress_bar: bool, default True
|
||||
Whether to show a progress bar when generating embeddings
|
||||
normalize_embeddings: bool, default True
|
||||
Whether to normalize the embeddings
|
||||
quantize: bool, default False
|
||||
Whether to quantize the model
|
||||
source_instruction: str, default "represent the docuement for retreival"
|
||||
The instruction for the source column
|
||||
query_instruction: str, default "represent the document for retreiving the most similar documents"
|
||||
The instruction for the query
|
||||
|
||||
Examples
|
||||
--------
|
||||
import lancedb
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
from lancedb.embeddings import get_registry, InstuctorEmbeddingFunction
|
||||
|
||||
instructor = get_registry().get("instructor").create(
|
||||
source_instruction="represent the docuement for retreival",
|
||||
query_instruction="represent the document for retreiving the most similar documents"
|
||||
)
|
||||
|
||||
class Schema(LanceModel):
|
||||
vector: Vector(instructor.ndims()) = instructor.VectorField()
|
||||
text: str = instructor.SourceField()
|
||||
|
||||
db = lancedb.connect("~/.lancedb")
|
||||
tbl = db.create_table("test", schema=Schema, mode="overwrite")
|
||||
|
||||
texts = [{"text": "Capitalism has been dominant in the Western world since the end of feudalism, but most feel[who?] that..."},
|
||||
{"text": "The disparate impact theory is especially controversial under the Fair Housing Act because the Act..."},
|
||||
{"text": "Disparate impact in United States labor law refers to practices in employment, housing, and other areas that.."}]
|
||||
|
||||
tbl.add(texts)
|
||||
|
||||
"""
|
||||
|
||||
name: str = "hkunlp/instructor-base"
|
||||
batch_size: int = 32
|
||||
device: str = "cpu"
|
||||
show_progress_bar: bool = True
|
||||
normalize_embeddings: bool = True
|
||||
quantize: bool = False
|
||||
# convert_to_numpy: bool = True # Hardcoding this as numpy can be ingested directly
|
||||
|
||||
source_instruction: str = "represent the document for retrieval"
|
||||
query_instruction: str = (
|
||||
"represent the document for retrieving the most similar documents"
|
||||
)
|
||||
|
||||
@weak_lru(maxsize=1)
|
||||
def ndims(self):
|
||||
model = self.get_model()
|
||||
return model.encode("foo").shape[0]
|
||||
|
||||
def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
|
||||
return self.generate_embeddings([[self.query_instruction, query]])
|
||||
|
||||
def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
|
||||
texts = self.sanitize_input(texts)
|
||||
texts_formatted = []
|
||||
for text in texts:
|
||||
texts_formatted.append([self.source_instruction, text])
|
||||
return self.generate_embeddings(texts_formatted)
|
||||
|
||||
def generate_embeddings(self, texts: List) -> List:
|
||||
model = self.get_model()
|
||||
res = model.encode(
|
||||
texts,
|
||||
batch_size=self.batch_size,
|
||||
show_progress_bar=self.show_progress_bar,
|
||||
normalize_embeddings=self.normalize_embeddings,
|
||||
).tolist()
|
||||
return res
|
||||
|
||||
@weak_lru(maxsize=1)
|
||||
def get_model(self):
|
||||
instructor_embedding = self.safe_import(
|
||||
"InstructorEmbedding", "InstructorEmbedding"
|
||||
)
|
||||
torch = self.safe_import("torch", "torch")
|
||||
|
||||
model = instructor_embedding.INSTRUCTOR(self.name)
|
||||
if self.quantize:
|
||||
if (
|
||||
"qnnpack" in torch.backends.quantized.supported_engines
|
||||
): # fix for https://github.com/pytorch/pytorch/issues/29327
|
||||
torch.backends.quantized.engine = "qnnpack"
|
||||
model = torch.quantization.quantize_dynamic(
|
||||
model, {torch.nn.Linear}, dtype=torch.qint8
|
||||
)
|
||||
return model
|
||||
@@ -1,3 +1,15 @@
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import concurrent.futures
|
||||
import io
|
||||
import os
|
||||
|
||||
@@ -1,3 +1,15 @@
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -1,3 +1,15 @@
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -5,6 +17,7 @@ from cachetools import cached
|
||||
|
||||
from .base import TextEmbeddingFunction
|
||||
from .registry import register
|
||||
from .utils import weak_lru
|
||||
|
||||
|
||||
@register("sentence-transformers")
|
||||
@@ -30,7 +43,7 @@ class SentenceTransformerEmbeddings(TextEmbeddingFunction):
|
||||
name and device. This is cached so that the model is only loaded
|
||||
once per process.
|
||||
"""
|
||||
return self.__class__.get_embedding_model(self.name, self.device)
|
||||
return self.get_embedding_model()
|
||||
|
||||
def ndims(self):
|
||||
if self._ndims is None:
|
||||
@@ -54,9 +67,8 @@ class SentenceTransformerEmbeddings(TextEmbeddingFunction):
|
||||
normalize_embeddings=self.normalize,
|
||||
).tolist()
|
||||
|
||||
@classmethod
|
||||
@cached(cache={})
|
||||
def get_embedding_model(cls, name, device):
|
||||
@weak_lru(maxsize=1)
|
||||
def get_embedding_model(self):
|
||||
"""
|
||||
Get the sentence-transformers embedding model specified by the
|
||||
name and device. This is cached so that the model is only loaded
|
||||
@@ -71,7 +83,7 @@ class SentenceTransformerEmbeddings(TextEmbeddingFunction):
|
||||
|
||||
TODO: use lru_cache instead with a reasonable/configurable maxsize
|
||||
"""
|
||||
sentence_transformers = cls.safe_import(
|
||||
sentence_transformers = self.safe_import(
|
||||
"sentence_transformers", "sentence-transformers"
|
||||
)
|
||||
return sentence_transformers.SentenceTransformer(name, device=device)
|
||||
return sentence_transformers.SentenceTransformer(self.name, device=self.device)
|
||||
|
||||
@@ -11,10 +11,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
import math
|
||||
import random
|
||||
import socket
|
||||
import sys
|
||||
import time
|
||||
import urllib.error
|
||||
import weakref
|
||||
from typing import Callable, List, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -162,6 +166,99 @@ class FunctionWrapper:
|
||||
yield from _chunker(arr)
|
||||
|
||||
|
||||
def weak_lru(maxsize=128):
|
||||
"""
|
||||
LRU cache that keeps weak references to the objects it caches. Only caches the latest instance of the objects to make sure memory usage
|
||||
is bounded.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
maxsize : int, default 128
|
||||
The maximum number of objects to cache.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Callable
|
||||
A decorator that can be applied to a method.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> class Foo:
|
||||
... @weak_lru()
|
||||
... def bar(self, x):
|
||||
... return x
|
||||
>>> foo = Foo()
|
||||
>>> foo.bar(1)
|
||||
1
|
||||
>>> foo.bar(2)
|
||||
2
|
||||
>>> foo.bar(1)
|
||||
1
|
||||
"""
|
||||
|
||||
def wrapper(func):
|
||||
@functools.lru_cache(maxsize)
|
||||
def _func(_self, *args, **kwargs):
|
||||
return func(_self(), *args, **kwargs)
|
||||
|
||||
@functools.wraps(func)
|
||||
def inner(self, *args, **kwargs):
|
||||
return _func(weakref.ref(self), *args, **kwargs)
|
||||
|
||||
return inner
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def retry_with_exponential_backoff(
|
||||
func,
|
||||
initial_delay: float = 1,
|
||||
exponential_base: float = 2,
|
||||
jitter: bool = True,
|
||||
max_retries: int = 7,
|
||||
# errors: tuple = (),
|
||||
):
|
||||
"""Retry a function with exponential backoff.
|
||||
|
||||
Args:
|
||||
func (function): The function to be retried.
|
||||
initial_delay (float): Initial delay in seconds (default is 1).
|
||||
exponential_base (float): The base for exponential backoff (default is 2).
|
||||
jitter (bool): Whether to add jitter to the delay (default is True).
|
||||
max_retries (int): Maximum number of retries (default is 10).
|
||||
errors (tuple): Tuple of specific exceptions to retry on (default is (openai.error.RateLimitError,)).
|
||||
|
||||
Returns:
|
||||
function: The decorated function.
|
||||
"""
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
num_retries = 0
|
||||
delay = initial_delay
|
||||
|
||||
# Loop until a successful response or max_retries is hit or an exception is raised
|
||||
while True:
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
# Currently retrying on all exceptions as there is no way to know the format of the error msgs used by different APIs
|
||||
# We'll log the error and say that it is assumed that if this portion errors out, it's due to rate limit but the user
|
||||
# should check the error message to be sure
|
||||
except Exception as e:
|
||||
num_retries += 1
|
||||
|
||||
if num_retries > max_retries:
|
||||
raise Exception(
|
||||
f"Maximum number of retries ({max_retries}) exceeded."
|
||||
)
|
||||
|
||||
delay *= exponential_base * (1 + jitter * random.random())
|
||||
LOGGER.info(f"Retrying in {delay:.2f} seconds due to {e}")
|
||||
time.sleep(delay)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def url_retrieve(url: str):
|
||||
"""
|
||||
Parameters
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Literal, Optional, Type, Union
|
||||
from typing import TYPE_CHECKING, List, Literal, Optional, Type, Union
|
||||
|
||||
import deprecation
|
||||
import numpy as np
|
||||
@@ -23,9 +23,11 @@ import pydantic
|
||||
|
||||
from . import __version__
|
||||
from .common import VECTOR_COLUMN_NAME
|
||||
from .pydantic import LanceModel
|
||||
from .util import safe_import_pandas
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .pydantic import LanceModel
|
||||
|
||||
pd = safe_import_pandas()
|
||||
|
||||
|
||||
@@ -140,7 +142,7 @@ class LanceQueryBuilder(ABC):
|
||||
if not isinstance(query, (list, np.ndarray)):
|
||||
conf = table.embedding_functions.get(vector_column_name)
|
||||
if conf is not None:
|
||||
query = conf.function.compute_query_embeddings(query)[0]
|
||||
query = conf.function.compute_query_embeddings_with_retry(query)[0]
|
||||
else:
|
||||
msg = f"No embedding function for {vector_column_name}"
|
||||
raise ValueError(msg)
|
||||
@@ -151,7 +153,7 @@ class LanceQueryBuilder(ABC):
|
||||
else:
|
||||
conf = table.embedding_functions.get(vector_column_name)
|
||||
if conf is not None:
|
||||
query = conf.function.compute_query_embeddings(query)[0]
|
||||
query = conf.function.compute_query_embeddings_with_retry(query)[0]
|
||||
return query, "vector"
|
||||
else:
|
||||
return query, "fts"
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
|
||||
|
||||
import functools
|
||||
from typing import Any, Callable, Dict, Optional, Union
|
||||
from typing import Any, Callable, Dict, Iterable, Optional, Union
|
||||
|
||||
import aiohttp
|
||||
import attrs
|
||||
@@ -151,15 +151,14 @@ class RestfulLanceDBClient:
|
||||
return await deserialize(resp)
|
||||
|
||||
@_check_not_closed
|
||||
async def list_tables(self, limit: int, page_token: str):
|
||||
async def list_tables(
|
||||
self, limit: int, page_token: Optional[str] = None
|
||||
) -> Iterable[str]:
|
||||
"""List all tables in the database."""
|
||||
try:
|
||||
json = await self.get(
|
||||
"/v1/table/", {"limit": limit, "page_token": page_token}
|
||||
)
|
||||
return json["tables"]
|
||||
except StopAsyncIteration:
|
||||
return []
|
||||
if page_token is None:
|
||||
page_token = ""
|
||||
json = await self.get("/v1/table/", {"limit": limit, "page_token": page_token})
|
||||
return json["tables"]
|
||||
|
||||
@_check_not_closed
|
||||
async def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
|
||||
|
||||
@@ -12,14 +12,19 @@
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Iterator, Optional
|
||||
from typing import Iterable, List, Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import pyarrow as pa
|
||||
from overrides import override
|
||||
|
||||
from ..common import DATA
|
||||
from ..db import DBConnection
|
||||
from ..embeddings import EmbeddingFunctionConfig
|
||||
from ..pydantic import LanceModel
|
||||
from ..table import Table, _sanitize_data
|
||||
from .arrow import to_ipc_binary
|
||||
from .client import ARROW_STREAM_CONTENT_TYPE, RestfulLanceDBClient
|
||||
@@ -52,11 +57,13 @@ class RemoteDBConnection(DBConnection):
|
||||
def __repr__(self) -> str:
|
||||
return f"RemoveConnect(name={self.db_name})"
|
||||
|
||||
def table_names(self, last_token: str, limit=10) -> Iterator[str]:
|
||||
@override
|
||||
def table_names(self, page_token: Optional[str] = None, limit=10) -> Iterable[str]:
|
||||
"""List the names of all tables in the database.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
last_token: str
|
||||
page_token: str
|
||||
The last token to start the new page.
|
||||
|
||||
Returns
|
||||
@@ -65,15 +72,16 @@ class RemoteDBConnection(DBConnection):
|
||||
"""
|
||||
while True:
|
||||
result = self._loop.run_until_complete(
|
||||
self._client.list_tables(limit, last_token)
|
||||
self._client.list_tables(limit, page_token)
|
||||
)
|
||||
if len(result) > 0:
|
||||
last_token = result[len(result) - 1]
|
||||
page_token = result[len(result) - 1]
|
||||
else:
|
||||
break
|
||||
for item in result:
|
||||
yield result
|
||||
yield item
|
||||
|
||||
@override
|
||||
def open_table(self, name: str) -> Table:
|
||||
"""Open a Lance Table in the database.
|
||||
|
||||
@@ -88,20 +96,43 @@ class RemoteDBConnection(DBConnection):
|
||||
"""
|
||||
from .table import RemoteTable
|
||||
|
||||
# TODO: check if table exists
|
||||
|
||||
# check if table exists
|
||||
try:
|
||||
self._loop.run_until_complete(
|
||||
self._client.post(f"/v1/table/{name}/describe/")
|
||||
)
|
||||
except Exception:
|
||||
logging.error(
|
||||
"Table {name} does not exist."
|
||||
"Please first call db.create_table({name}, data)"
|
||||
)
|
||||
return RemoteTable(self, name)
|
||||
|
||||
@override
|
||||
def create_table(
|
||||
self,
|
||||
name: str,
|
||||
data: DATA = None,
|
||||
schema: pa.Schema = None,
|
||||
schema: Optional[Union[pa.Schema, LanceModel]] = None,
|
||||
on_bad_vectors: str = "error",
|
||||
fill_value: float = 0.0,
|
||||
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
|
||||
) -> Table:
|
||||
if data is None and schema is None:
|
||||
raise ValueError("Either data or schema must be provided.")
|
||||
if embedding_functions is not None:
|
||||
raise NotImplementedError(
|
||||
"embedding_functions is not supported for remote databases."
|
||||
"Please vote https://github.com/lancedb/lancedb/issues/626 "
|
||||
"for this feature."
|
||||
)
|
||||
|
||||
if inspect.isclass(schema) and issubclass(schema, LanceModel):
|
||||
# convert LanceModel to pyarrow schema
|
||||
# note that it's possible this contains
|
||||
# embedding function metadata already
|
||||
schema = schema.to_arrow_schema()
|
||||
|
||||
if data is not None:
|
||||
data = _sanitize_data(
|
||||
data,
|
||||
@@ -130,6 +161,7 @@ class RemoteDBConnection(DBConnection):
|
||||
)
|
||||
return RemoteTable(self, name)
|
||||
|
||||
@override
|
||||
def drop_table(self, name: str):
|
||||
"""Drop a table from the database.
|
||||
|
||||
|
||||
@@ -44,6 +44,14 @@ class RemoteTable(Table):
|
||||
schema = json_to_schema(resp["schema"])
|
||||
return schema
|
||||
|
||||
@property
|
||||
def version(self) -> int:
|
||||
"""Get the current version of the table"""
|
||||
resp = self._conn._loop.run_until_complete(
|
||||
self._conn._client.post(f"/v1/table/{self._name}/describe/")
|
||||
)
|
||||
return resp["version"]
|
||||
|
||||
def to_arrow(self) -> pa.Table:
|
||||
"""Return the table as an Arrow table."""
|
||||
raise NotImplementedError("to_arrow() is not supported on the LanceDB cloud")
|
||||
@@ -99,8 +107,6 @@ class RemoteTable(Table):
|
||||
return LanceVectorQueryBuilder(self, query, vector_column_name)
|
||||
|
||||
def _execute_query(self, query: Query) -> pa.Table:
|
||||
if query.prefilter:
|
||||
raise NotImplementedError("Cloud support for prefiltering is coming soon")
|
||||
result = self._conn._client.query(self._name, query)
|
||||
return self._conn._loop.run_until_complete(result).to_arrow()
|
||||
|
||||
|
||||
@@ -16,16 +16,14 @@ from __future__ import annotations
|
||||
import inspect
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import timedelta
|
||||
from functools import cached_property
|
||||
from typing import Any, Iterable, List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Union
|
||||
|
||||
import lance
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
import pyarrow.compute as pc
|
||||
from lance import LanceDataset
|
||||
from lance.dataset import CleanupStats, ReaderLike
|
||||
from lance.vector import vec_to_table
|
||||
|
||||
from .common import DATA, VEC, VECTOR_COLUMN_NAME
|
||||
@@ -35,6 +33,12 @@ from .query import LanceQueryBuilder, Query
|
||||
from .util import fs_from_uri, safe_import_pandas
|
||||
from .utils.events import register_event
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datetime import timedelta
|
||||
|
||||
from lance.dataset import CleanupStats, ReaderLike
|
||||
|
||||
|
||||
pd = safe_import_pandas()
|
||||
|
||||
|
||||
@@ -86,7 +90,9 @@ def _append_vector_col(data: pa.Table, metadata: dict, schema: Optional[pa.Schem
|
||||
for vector_column, conf in functions.items():
|
||||
func = conf.function
|
||||
if vector_column not in data.column_names:
|
||||
col_data = func.compute_source_embeddings(data[conf.source_column])
|
||||
col_data = func.compute_source_embeddings_with_retry(
|
||||
data[conf.source_column]
|
||||
)
|
||||
if schema is not None:
|
||||
dtype = schema.field(vector_column).type
|
||||
else:
|
||||
|
||||
@@ -14,7 +14,8 @@ dependencies = [
|
||||
"cachetools",
|
||||
"pyyaml>=6.0",
|
||||
"click>=8.1.7",
|
||||
"requests>=2.31.0"
|
||||
"requests>=2.31.0",
|
||||
"overrides>=0.7"
|
||||
]
|
||||
description = "lancedb"
|
||||
authors = [{ name = "LanceDB Devs", email = "dev@lancedb.com" }]
|
||||
@@ -52,7 +53,7 @@ tests = ["pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "requests"]
|
||||
dev = ["ruff", "pre-commit", "black"]
|
||||
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
|
||||
clip = ["torch", "pillow", "open-clip"]
|
||||
embeddings = ["openai", "sentence-transformers", "torch", "pillow", "open-clip-torch", "cohere"]
|
||||
embeddings = ["openai", "sentence-transformers", "torch", "pillow", "open-clip-torch", "cohere", "InstructorEmbedding"]
|
||||
|
||||
[project.scripts]
|
||||
lancedb = "lancedb.cli.cli:cli"
|
||||
@@ -64,6 +65,9 @@ build-backend = "setuptools.build_meta"
|
||||
[tool.isort]
|
||||
profile = "black"
|
||||
|
||||
[tool.ruff]
|
||||
select = ["F", "E", "W", "I", "G", "TCH", "PERF"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
addopts = "--strict-markers"
|
||||
markers = [
|
||||
|
||||
@@ -129,7 +129,7 @@ def test_ingest_iterator(tmp_path):
|
||||
[
|
||||
PydanticSchema(vector=[3.1, 4.1], item="foo", price=10.0),
|
||||
PydanticSchema(vector=[5.9, 26.5], item="bar", price=20.0),
|
||||
]
|
||||
],
|
||||
# TODO: test pydict separately. it is unique column number and names contraint
|
||||
]
|
||||
|
||||
|
||||
@@ -15,13 +15,16 @@ import sys
|
||||
import lance
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
|
||||
from lancedb.conftest import MockTextEmbeddingFunction
|
||||
import lancedb
|
||||
from lancedb.conftest import MockRateLimitedEmbeddingFunction, MockTextEmbeddingFunction
|
||||
from lancedb.embeddings import (
|
||||
EmbeddingFunctionConfig,
|
||||
EmbeddingFunctionRegistry,
|
||||
with_embeddings,
|
||||
)
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
|
||||
|
||||
def mock_embed_func(input_data):
|
||||
@@ -83,3 +86,29 @@ def test_embedding_function(tmp_path):
|
||||
expected = func.compute_query_embeddings("hello world")
|
||||
|
||||
assert np.allclose(actual, expected)
|
||||
|
||||
|
||||
def test_embedding_function_rate_limit(tmp_path):
|
||||
def _get_schema_from_model(model):
|
||||
class Schema(LanceModel):
|
||||
text: str = model.SourceField()
|
||||
vector: Vector(model.ndims()) = model.VectorField()
|
||||
|
||||
return Schema
|
||||
|
||||
db = lancedb.connect(tmp_path)
|
||||
registry = EmbeddingFunctionRegistry.get_instance()
|
||||
model = registry.get("test-rate-limited").create(max_retries=0)
|
||||
schema = _get_schema_from_model(model)
|
||||
table = db.create_table("test", schema=schema, mode="overwrite")
|
||||
table.add([{"text": "hello world"}])
|
||||
with pytest.raises(Exception):
|
||||
table.add([{"text": "hello world"}])
|
||||
assert len(table) == 1
|
||||
|
||||
model = registry.get("test-rate-limited").create()
|
||||
schema = _get_schema_from_model(model)
|
||||
table = db.create_table("test", schema=schema, mode="overwrite")
|
||||
table.add([{"text": "hello world"}])
|
||||
table.add([{"text": "hello world"}])
|
||||
assert len(table) == 2
|
||||
|
||||
@@ -32,8 +32,8 @@ from lancedb.pydantic import LanceModel, Vector
|
||||
def test_sentence_transformer(alias, tmp_path):
|
||||
db = lancedb.connect(tmp_path)
|
||||
registry = get_registry()
|
||||
func = registry.get(alias).create()
|
||||
func2 = registry.get(alias).create()
|
||||
func = registry.get(alias).create(max_retries=0)
|
||||
func2 = registry.get(alias).create(max_retries=0)
|
||||
|
||||
class Words(LanceModel):
|
||||
text: str = func.SourceField()
|
||||
@@ -150,7 +150,11 @@ def test_openclip(tmp_path):
|
||||
os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"
|
||||
) # also skip if cohere not installed
|
||||
def test_cohere_embedding_function():
|
||||
cohere = get_registry().get("cohere").create(name="embed-multilingual-v2.0")
|
||||
cohere = (
|
||||
get_registry()
|
||||
.get("cohere")
|
||||
.create(name="embed-multilingual-v2.0", max_retries=0)
|
||||
)
|
||||
|
||||
class TextModel(LanceModel):
|
||||
text: str = cohere.SourceField()
|
||||
@@ -162,3 +166,19 @@ def test_cohere_embedding_function():
|
||||
|
||||
tbl.add(df)
|
||||
assert len(tbl.to_pandas()["vector"][0]) == cohere.ndims()
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_instructor_embedding(tmp_path):
|
||||
model = get_registry().get("instructor").create()
|
||||
|
||||
class TextModel(LanceModel):
|
||||
text: str = model.SourceField()
|
||||
vector: Vector(model.ndims()) = model.VectorField()
|
||||
|
||||
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
|
||||
db = lancedb.connect(tmp_path)
|
||||
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
|
||||
|
||||
tbl.add(df)
|
||||
assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "vectordb-node"
|
||||
version = "0.3.5"
|
||||
version = "0.3.7"
|
||||
description = "Serverless, low-latency vector database for AI applications"
|
||||
license = "Apache-2.0"
|
||||
edition = "2018"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "vectordb"
|
||||
version = "0.3.5"
|
||||
version = "0.3.7"
|
||||
edition = "2021"
|
||||
description = "LanceDB: A serverless, low-latency vector database for AI applications"
|
||||
license = "Apache-2.0"
|
||||
|
||||
@@ -25,7 +25,8 @@ use bytes::Bytes;
|
||||
use futures::{stream::BoxStream, FutureExt, StreamExt};
|
||||
use lance::io::object_store::WrappingObjectStore;
|
||||
use object_store::{
|
||||
path::Path, GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, Result,
|
||||
path::Path, Error, GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore,
|
||||
Result,
|
||||
};
|
||||
|
||||
use async_trait::async_trait;
|
||||
@@ -120,7 +121,10 @@ impl ObjectStore for MirroringObjectStore {
|
||||
|
||||
async fn delete(&self, location: &Path) -> Result<()> {
|
||||
if !location.primary_only() {
|
||||
self.secondary.delete(location).await?;
|
||||
match self.secondary.delete(location).await {
|
||||
Err(Error::NotFound { .. }) | Ok(_) => {}
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
}
|
||||
self.primary.delete(location).await
|
||||
}
|
||||
|
||||
@@ -376,12 +376,12 @@ impl Table {
|
||||
self.dataset.count_fragments()
|
||||
}
|
||||
|
||||
pub fn count_deleted_rows(&self) -> usize {
|
||||
self.dataset.count_deleted_rows()
|
||||
pub async fn count_deleted_rows(&self) -> Result<usize> {
|
||||
Ok(self.dataset.count_deleted_rows().await?)
|
||||
}
|
||||
|
||||
pub fn num_small_files(&self, max_rows_per_group: usize) -> usize {
|
||||
self.dataset.num_small_files(max_rows_per_group)
|
||||
pub async fn num_small_files(&self, max_rows_per_group: usize) -> usize {
|
||||
self.dataset.num_small_files(max_rows_per_group).await
|
||||
}
|
||||
|
||||
pub async fn count_indexed_rows(&self, index_uuid: &str) -> Result<Option<usize>> {
|
||||
|
||||
Reference in New Issue
Block a user