Compare commits

...

16 Commits

Author SHA1 Message Date
Lance Release
975398c3a8 Bump version: 0.12.0 → 0.13.0-beta.0 2024-11-05 23:21:32 +00:00
Lance Release
08d5f93f34 Bump version: 0.15.0 → 0.16.0-beta.0 2024-11-05 23:21:13 +00:00
Will Jones
91cab3b556 feat(python): transition Python remote sdk to use Rust implementation (#1701)
* Replaces Python implementation of Remote SDK with Rust one.
* Drops dependency on `attrs` and `cachetools`. Makes `requests` an
optional dependency used only for embeddings feature.
* Adds dependency on `nest-asyncio`. This was required to get hybrid
search working.
* Deprecate `request_thread_pool` parameter. We now use the tokio
threadpool.
* Stop caching the `schema` on a remote table. Schema is mutable and
there's no mechanism in place to invalidate the cache.
* Removed the client-side resolution of the vector column. We should
already be resolving this server-side.
2024-11-05 13:44:39 -08:00
Will Jones
c61bfc3af8 chore: update package locks (#1798) 2024-11-05 13:28:59 -08:00
Bert
4e8c7b0adf fix: serialize vectordb client errors as json (#1795) 2024-11-05 14:16:25 -05:00
Weston Pace
26f4a80e10 feat: upgrade to lance 0.19.2-beta.3 (#1794) 2024-11-05 06:43:41 -08:00
Will Jones
3604d20ad3 feat(python,node): support with_row_id in Python and remote (#1784)
Needed to support hybrid search in Remote SDK.
2024-11-04 11:25:45 -08:00
Gagan Bhullar
9708d829a9 fix: explain plan options (#1776)
PR fixes #1768
2024-11-04 10:25:34 -08:00
Will Jones
059c9794b5 fix(rust): fix update, open_table, fts search in remote client (#1785)
* `open_table` uses `POST` not `GET`
* `update` uses `predicate` key not `only_if`
* For FTS search, vector cannot be omitted. It must be passed as empty.
* Added logging of JSON request bodies to debug level logging.
2024-11-04 08:27:55 -08:00
Will Jones
15ed7f75a0 feat(python): support post filter on FTS (#1783) 2024-11-01 10:05:05 -07:00
Will Jones
96181ab421 feat: fast_search in Python and Node (#1623)
Sometimes it is acceptable to users to only search indexed data and skip
and new un-indexed data. For example, if un-indexed data will be shortly
indexed and they don't mind the delay. In these cases, we can save a lot
of CPU time in search, and provide better latency. Users can activate
this on queries using `fast_search()`.
2024-11-01 09:29:09 -07:00
Will Jones
f3fc339ef6 fix(rust): fix delete, update, query in remote SDK (#1782)
Fixes several minor issues with Rust remote SDK:

* Delete uses `predicate` not `filter` as parameter
* Update does not return the row value in remote SDK
* Update takes tuples
* Content type returned by query node is wrong, so we shouldn't validate
it. https://github.com/lancedb/sophon/issues/2742
* Data returned by query endpoint is actually an Arrow IPC file, not IPC
stream.
2024-10-31 15:22:09 -07:00
Will Jones
113cd6995b fix: index_stats works for FTS indices (#1780)
When running `index_stats()` for an FTS index, users would get the
deserialization error:

```
InvalidInput { message: "error deserializing index statistics: unknown variant `Inverted`, expected one of `IvfPq`, `IvfHnswPq`, `IvfHnswSq`, `BTree`, `Bitmap`, `LabelList`, `FTS` at line 1 column 24" }
```
2024-10-30 11:33:49 -07:00
Lance Release
02535bdc88 Updating package-lock.json 2024-10-29 22:16:51 +00:00
Lance Release
facc7d61c0 Bump version: 0.12.0-beta.0 → 0.12.0 2024-10-29 22:16:32 +00:00
Lance Release
f947259f16 Bump version: 0.11.1-beta.1 → 0.12.0-beta.0 2024-10-29 22:16:27 +00:00
55 changed files with 907 additions and 1082 deletions

View File

@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "0.11.1-beta.1"
current_version = "0.13.0-beta.0"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.

View File

@@ -21,13 +21,15 @@ categories = ["database-implementations"]
rust-version = "1.80.0" # TODO: lower this once we upgrade Lance again.
[workspace.dependencies]
lance = { "version" = "=0.19.1", "features" = ["dynamodb"] }
lance-index = { "version" = "=0.19.1" }
lance-linalg = { "version" = "=0.19.1" }
lance-table = { "version" = "=0.19.1" }
lance-testing = { "version" = "=0.19.1" }
lance-datafusion = { "version" = "=0.19.1" }
lance-encoding = { "version" = "=0.19.1" }
lance = { "version" = "=0.19.2", "features" = [
"dynamodb",
], git = "https://github.com/lancedb/lance.git", tag = "v0.19.2-beta.3" }
lance-index = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2-beta.3" }
lance-linalg = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2-beta.3" }
lance-table = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2-beta.3" }
lance-testing = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2-beta.3" }
lance-datafusion = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2-beta.3" }
lance-encoding = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2-beta.3" }
# Note that this one does not include pyarrow
arrow = { version = "52.2", optional = false }
arrow-array = "52.2"

View File

@@ -8,7 +8,7 @@
<parent>
<groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId>
<version>0.11.1-beta.1</version>
<version>0.13.0-beta.0</version>
<relativePath>../pom.xml</relativePath>
</parent>

View File

@@ -6,7 +6,7 @@
<groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId>
<version>0.11.1-beta.1</version>
<version>0.13.0-beta.0</version>
<packaging>pom</packaging>
<name>LanceDB Parent</name>

49
node/package-lock.json generated
View File

@@ -1,12 +1,12 @@
{
"name": "vectordb",
"version": "0.11.1-beta.1",
"version": "0.12.0",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "vectordb",
"version": "0.11.1-beta.1",
"version": "0.12.0",
"cpu": [
"x64",
"arm64"
@@ -52,11 +52,11 @@
"uuid": "^9.0.0"
},
"optionalDependencies": {
"@lancedb/vectordb-darwin-arm64": "0.11.1-beta.1",
"@lancedb/vectordb-darwin-x64": "0.11.1-beta.1",
"@lancedb/vectordb-linux-arm64-gnu": "0.11.1-beta.1",
"@lancedb/vectordb-linux-x64-gnu": "0.11.1-beta.1",
"@lancedb/vectordb-win32-x64-msvc": "0.11.1-beta.1"
"@lancedb/vectordb-darwin-arm64": "0.12.0",
"@lancedb/vectordb-darwin-x64": "0.12.0",
"@lancedb/vectordb-linux-arm64-gnu": "0.12.0",
"@lancedb/vectordb-linux-x64-gnu": "0.12.0",
"@lancedb/vectordb-win32-x64-msvc": "0.12.0"
},
"peerDependencies": {
"@apache-arrow/ts": "^14.0.2",
@@ -327,65 +327,60 @@
}
},
"node_modules/@lancedb/vectordb-darwin-arm64": {
"version": "0.11.1-beta.1",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.11.1-beta.1.tgz",
"integrity": "sha512-q9jcCbmcz45UHmjgecL6zK82WaqUJsARfniwXXPcnd8ooISVhPkgN+RVKv6edwI9T0PV+xVRYq+LQLlZu5fyxw==",
"version": "0.12.0",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.12.0.tgz",
"integrity": "sha512-9X6UyP/ozHkv39YZ8DWh82m3aeQmUtrVDNuRe3o8has6dJyD/qPYukI8Zked4q8J+86/lgQbr4f+WW2V4Dfc1g==",
"cpu": [
"arm64"
],
"license": "Apache-2.0",
"optional": true,
"os": [
"darwin"
]
},
"node_modules/@lancedb/vectordb-darwin-x64": {
"version": "0.11.1-beta.1",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.11.1-beta.1.tgz",
"integrity": "sha512-E5tCTS5TaTkssTPa+gdnFxZJ1f60jnSIJXhqufNFZk4s+IMViwR1BPqaqE++WY5c1uBI55ef1862CROKDKX4gg==",
"version": "0.12.0",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.12.0.tgz",
"integrity": "sha512-zG+//P3BBpmOiLR+dop68T9AFNxazWlSLF8yVdAtvsqjRzcrrMLR//rIrRcbPHxu8gvvLrMDoDZT+AHd2rElyQ==",
"cpu": [
"x64"
],
"license": "Apache-2.0",
"optional": true,
"os": [
"darwin"
]
},
"node_modules/@lancedb/vectordb-linux-arm64-gnu": {
"version": "0.11.1-beta.1",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.11.1-beta.1.tgz",
"integrity": "sha512-Obohy6TH31Uq+fp6ZisHR7iAsvgVPqBExrycVcIJqrLZnIe88N9OWUwBXkmfMAw/2hNJFwD4tU7+4U2FcBWX4w==",
"version": "0.12.0",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.12.0.tgz",
"integrity": "sha512-5RiJkcZEdMkK5WUfkV+HVFnJaAergfSiLNgUwJaovEEX8yVChkhrdZFSUj1o/k2k6Ix9mQq+xfIUF+aGN/XnDQ==",
"cpu": [
"arm64"
],
"license": "Apache-2.0",
"optional": true,
"os": [
"linux"
]
},
"node_modules/@lancedb/vectordb-linux-x64-gnu": {
"version": "0.11.1-beta.1",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.11.1-beta.1.tgz",
"integrity": "sha512-3Meu0dgrzNrnBVVQhxkUSAOhQNmgtKHvOvmrRLUicV+X19hd33udihgxVpZZb9mpXenJ8lZsS+Jq6R0hWqntag==",
"version": "0.12.0",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.12.0.tgz",
"integrity": "sha512-JFulRNBHLF0TyE0tThaAB9T7CM3zLquPsBF6oA9b1stVdXbEqVqLMltjem0tqfj30zEoEbAKDPpEKII4CPQMTA==",
"cpu": [
"x64"
],
"license": "Apache-2.0",
"optional": true,
"os": [
"linux"
]
},
"node_modules/@lancedb/vectordb-win32-x64-msvc": {
"version": "0.11.1-beta.1",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.11.1-beta.1.tgz",
"integrity": "sha512-BafZ9OJPQXsS7JW0weAl12wC+827AiRjfUrE5tvrYWZah2OwCF2U2g6uJ3x4pxfwEGsv5xcHFqgxlS7ttFkh+Q==",
"version": "0.12.0",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.12.0.tgz",
"integrity": "sha512-T3s/RzB5dvXBqU3qmS6zyHhF0RHS2sSs81zKzYQy2R2nEVPbnwutFSsdA1wEqEXZlr8uTD9nLbkKJKqRNTXVEg==",
"cpu": [
"x64"
],
"license": "Apache-2.0",
"optional": true,
"os": [
"win32"

View File

@@ -1,6 +1,6 @@
{
"name": "vectordb",
"version": "0.11.1-beta.1",
"version": "0.13.0-beta.0",
"description": " Serverless, low-latency vector database for AI applications",
"main": "dist/index.js",
"types": "dist/index.d.ts",
@@ -88,10 +88,10 @@
}
},
"optionalDependencies": {
"@lancedb/vectordb-darwin-arm64": "0.11.1-beta.1",
"@lancedb/vectordb-darwin-x64": "0.11.1-beta.1",
"@lancedb/vectordb-linux-arm64-gnu": "0.11.1-beta.1",
"@lancedb/vectordb-linux-x64-gnu": "0.11.1-beta.1",
"@lancedb/vectordb-win32-x64-msvc": "0.11.1-beta.1"
"@lancedb/vectordb-darwin-arm64": "0.13.0-beta.0",
"@lancedb/vectordb-darwin-x64": "0.13.0-beta.0",
"@lancedb/vectordb-linux-arm64-gnu": "0.13.0-beta.0",
"@lancedb/vectordb-linux-x64-gnu": "0.13.0-beta.0",
"@lancedb/vectordb-win32-x64-msvc": "0.13.0-beta.0"
}
}

View File

@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
import axios, { type AxiosResponse, type ResponseType } from 'axios'
import axios, { type AxiosError, type AxiosResponse, type ResponseType } from 'axios'
import { tableFromIPC, type Table as ArrowTable } from 'apache-arrow'
@@ -197,7 +197,7 @@ export class HttpLancedbClient {
response = await callWithMiddlewares(req, this._middlewares)
return response
} catch (err: any) {
console.error('error: ', err)
console.error(serializeErrorAsJson(err))
if (err.response === undefined) {
throw new Error(`Network Error: ${err.message as string}`)
}
@@ -247,7 +247,8 @@ export class HttpLancedbClient {
// return response
} catch (err: any) {
console.error('error: ', err)
console.error(serializeErrorAsJson(err))
if (err.response === undefined) {
throw new Error(`Network Error: ${err.message as string}`)
}
@@ -287,3 +288,15 @@ export class HttpLancedbClient {
return clone
}
}
function serializeErrorAsJson(err: AxiosError) {
const error = JSON.parse(JSON.stringify(err, Object.getOwnPropertyNames(err)))
error.response = err.response != null
? JSON.parse(JSON.stringify(
err.response,
// config contains the request data, too noisy
Object.getOwnPropertyNames(err.response).filter(prop => prop !== 'config')
))
: null
return JSON.stringify({ error })
}

View File

@@ -1,7 +1,7 @@
[package]
name = "lancedb-nodejs"
edition.workspace = true
version = "0.11.1-beta.1"
version = "0.13.0-beta.0"
license.workspace = true
description.workspace = true
repository.workspace = true

View File

@@ -402,6 +402,40 @@ describe("When creating an index", () => {
expect(rst.numRows).toBe(1);
});
it("should be able to query unindexed data", async () => {
await tbl.createIndex("vec");
await tbl.add([
{
id: 300,
vec: Array(32)
.fill(1)
.map(() => Math.random()),
tags: [],
},
]);
const plan1 = await tbl.query().nearestTo(queryVec).explainPlan(true);
expect(plan1).toMatch("LanceScan");
const plan2 = await tbl
.query()
.nearestTo(queryVec)
.fastSearch()
.explainPlan(true);
expect(plan2).not.toMatch("LanceScan");
});
it("should be able to query with row id", async () => {
const results = await tbl
.query()
.nearestTo(queryVec)
.withRowId()
.limit(1)
.toArray();
expect(results.length).toBe(1);
expect(results[0]).toHaveProperty("_rowid");
});
it("should allow parameters to be specified", async () => {
await tbl.createIndex("vec", {
config: Index.ivfPq({

View File

@@ -239,6 +239,29 @@ export class QueryBase<NativeQueryType extends NativeQuery | NativeVectorQuery>
return this;
}
/**
* Skip searching un-indexed data. This can make search faster, but will miss
* any data that is not yet indexed.
*
* Use {@link lancedb.Table#optimize} to index all un-indexed data.
*/
fastSearch(): this {
this.doCall((inner: NativeQueryType) => inner.fastSearch());
return this;
}
/**
* Whether to return the row id in the results.
*
* This column can be used to match results between different queries. For
* example, to match results from a full text search and a vector search in
* order to perform hybrid search.
*/
withRowId(): this {
this.doCall((inner: NativeQueryType) => inner.withRowId());
return this;
}
protected nativeExecute(
options?: Partial<QueryExecutionOptions>,
): Promise<NativeBatchIterator> {

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-darwin-arm64",
"version": "0.11.1-beta.1",
"version": "0.13.0-beta.0",
"os": ["darwin"],
"cpu": ["arm64"],
"main": "lancedb.darwin-arm64.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-darwin-x64",
"version": "0.11.1-beta.1",
"version": "0.13.0-beta.0",
"os": ["darwin"],
"cpu": ["x64"],
"main": "lancedb.darwin-x64.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-arm64-gnu",
"version": "0.11.1-beta.1",
"version": "0.13.0-beta.0",
"os": ["linux"],
"cpu": ["arm64"],
"main": "lancedb.linux-arm64-gnu.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-x64-gnu",
"version": "0.11.1-beta.1",
"version": "0.13.0-beta.0",
"os": ["linux"],
"cpu": ["x64"],
"main": "lancedb.linux-x64-gnu.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-win32-x64-msvc",
"version": "0.11.1-beta.1",
"version": "0.13.0-beta.0",
"os": ["win32"],
"cpu": ["x64"],
"main": "lancedb.win32-x64-msvc.node",

View File

@@ -1,12 +1,12 @@
{
"name": "@lancedb/lancedb",
"version": "0.11.1-beta.1",
"version": "0.12.0",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "@lancedb/lancedb",
"version": "0.11.1-beta.1",
"version": "0.12.0",
"cpu": [
"x64",
"arm64"

View File

@@ -10,7 +10,7 @@
"vector database",
"ann"
],
"version": "0.11.1-beta.1",
"version": "0.13.0-beta.0",
"main": "dist/index.js",
"exports": {
".": "./dist/index.js",

View File

@@ -80,6 +80,16 @@ impl Query {
Ok(VectorQuery { inner })
}
#[napi]
pub fn fast_search(&mut self) {
self.inner = self.inner.clone().fast_search();
}
#[napi]
pub fn with_row_id(&mut self) {
self.inner = self.inner.clone().with_row_id();
}
#[napi(catch_unwind)]
pub async fn execute(
&self,
@@ -183,6 +193,16 @@ impl VectorQuery {
self.inner = self.inner.clone().offset(offset as usize);
}
#[napi]
pub fn fast_search(&mut self) {
self.inner = self.inner.clone().fast_search();
}
#[napi]
pub fn with_row_id(&mut self) {
self.inner = self.inner.clone().with_row_id();
}
#[napi(catch_unwind)]
pub async fn execute(
&self,

View File

@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "0.15.0"
current_version = "0.16.0-beta.0"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb-python"
version = "0.15.0"
version = "0.16.0-beta.0"
edition.workspace = true
description = "Python bindings for LanceDB"
license.workspace = true

View File

@@ -3,13 +3,11 @@ name = "lancedb"
# version in Cargo.toml
dependencies = [
"deprecation",
"pylance==0.19.1",
"requests>=2.31.0",
"nest-asyncio~=1.0",
"pylance==0.19.2-beta.3",
"tqdm>=4.27.0",
"pydantic>=1.10",
"attrs>=21.3.0",
"packaging",
"cachetools",
"overrides>=0.7",
]
description = "lancedb"
@@ -61,6 +59,7 @@ dev = ["ruff", "pre-commit"]
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
clip = ["torch", "pillow", "open-clip"]
embeddings = [
"requests>=2.31.0",
"openai>=1.6.1",
"sentence-transformers",
"torch",

View File

@@ -19,12 +19,10 @@ from typing import Dict, Optional, Union, Any
__version__ = importlib.metadata.version("lancedb")
from lancedb.remote import ClientConfig
from ._lancedb import connect as lancedb_connect
from .common import URI, sanitize_uri
from .db import AsyncConnection, DBConnection, LanceDBConnection
from .remote.db import RemoteDBConnection
from .remote import ClientConfig
from .schema import vector
from .table import AsyncTable
@@ -37,6 +35,7 @@ def connect(
host_override: Optional[str] = None,
read_consistency_interval: Optional[timedelta] = None,
request_thread_pool: Optional[Union[int, ThreadPoolExecutor]] = None,
client_config: Union[ClientConfig, Dict[str, Any], None] = None,
**kwargs: Any,
) -> DBConnection:
"""Connect to a LanceDB database.
@@ -64,14 +63,10 @@ def connect(
the last check, then the table will be checked for updates. Note: this
consistency only applies to read operations. Write operations are
always consistent.
request_thread_pool: int or ThreadPoolExecutor, optional
The thread pool to use for making batch requests to the LanceDB Cloud API.
If an integer, then a ThreadPoolExecutor will be created with that
number of threads. If None, then a ThreadPoolExecutor will be created
with the default number of threads. If a ThreadPoolExecutor, then that
executor will be used for making requests. This is for LanceDB Cloud
only and is only used when making batch requests (i.e., passing in
multiple queries to the search method at once).
client_config: ClientConfig or dict, optional
Configuration options for the LanceDB Cloud HTTP client. If a dict, then
the keys are the attributes of the ClientConfig class. If None, then the
default configuration is used.
Examples
--------
@@ -94,6 +89,8 @@ def connect(
conn : DBConnection
A connection to a LanceDB database.
"""
from .remote.db import RemoteDBConnection
if isinstance(uri, str) and uri.startswith("db://"):
if api_key is None:
api_key = os.environ.get("LANCEDB_API_KEY")
@@ -106,7 +103,9 @@ def connect(
api_key,
region,
host_override,
# TODO: remove this (deprecation warning downstream)
request_thread_pool=request_thread_pool,
client_config=client_config,
**kwargs,
)

View File

@@ -36,6 +36,8 @@ class Connection(object):
data_storage_version: Optional[str] = None,
enable_v2_manifest_paths: Optional[bool] = None,
) -> Table: ...
async def rename_table(self, old_name: str, new_name: str) -> None: ...
async def drop_table(self, name: str) -> None: ...
class Table:
def name(self) -> str: ...

View File

@@ -817,6 +817,18 @@ class AsyncConnection(object):
table = await self._inner.open_table(name, storage_options, index_cache_size)
return AsyncTable(table)
async def rename_table(self, old_name: str, new_name: str):
"""Rename a table in the database.
Parameters
----------
old_name: str
The current name of the table.
new_name: str
The new name of the table.
"""
await self._inner.rename_table(old_name, new_name)
async def drop_table(self, name: str):
"""Drop a table from the database.

View File

@@ -13,7 +13,6 @@
import os
import io
import requests
import base64
from urllib.parse import urlparse
from pathlib import Path
@@ -226,6 +225,8 @@ class JinaEmbeddings(EmbeddingFunction):
return [result["embedding"] for result in sorted_embeddings]
def _init_client(self):
import requests
if JinaEmbeddings._session is None:
if self.api_key is None and os.environ.get("JINA_API_KEY") is None:
api_key_not_found_help("jina")

View File

@@ -467,6 +467,8 @@ class IvfPq:
The default value is 256.
"""
if distance_type is not None:
distance_type = distance_type.lower()
self._inner = LanceDbIndex.ivf_pq(
distance_type=distance_type,
num_partitions=num_partitions,

View File

@@ -481,6 +481,7 @@ class LanceQueryBuilder(ABC):
>>> plan = table.search(query).explain_plan(True)
>>> print(plan) # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
ProjectionExec: expr=[vector@0 as vector, _distance@2 as _distance]
GlobalLimitExec: skip=0, fetch=10
FilterExec: _distance@2 IS NOT NULL
SortExec: TopK(fetch=10), expr=[_distance@2 ASC NULLS LAST], preserve_partitioning=[false]
KNNVectorDistance: metric=l2
@@ -500,7 +501,16 @@ class LanceQueryBuilder(ABC):
nearest={
"column": self._vector_column,
"q": self._query,
"k": self._limit,
"metric": self._metric,
"nprobes": self._nprobes,
"refine_factor": self._refine_factor,
},
prefilter=self._prefilter,
filter=self._str_query,
limit=self._limit,
with_row_id=self._with_row_id,
offset=self._offset,
).explain_plan(verbose)
def vector(self, vector: Union[np.ndarray, list]) -> LanceQueryBuilder:
@@ -1315,6 +1325,48 @@ class AsyncQueryBase(object):
self._inner.offset(offset)
return self
def fast_search(self) -> AsyncQuery:
"""
Skip searching un-indexed data.
This can make queries faster, but will miss any data that has not been
indexed.
!!! tip
You can add new data into an existing index by calling
[AsyncTable.optimize][lancedb.table.AsyncTable.optimize].
"""
self._inner.fast_search()
return self
def with_row_id(self) -> AsyncQuery:
"""
Include the _rowid column in the results.
"""
self._inner.with_row_id()
return self
def postfilter(self) -> AsyncQuery:
"""
If this is called then filtering will happen after the search instead of
before.
By default filtering will be performed before the search. This is how
filtering is typically understood to work. This prefilter step does add some
additional latency. Creating a scalar index on the filter column(s) can
often improve this latency. However, sometimes a filter is too complex or
scalar indices cannot be applied to the column. In these cases postfiltering
can be used instead of prefiltering to improve latency.
Post filtering applies the filter to the results of the search. This
means we only run the filter on a much smaller set of data. However, it can
cause the query to return fewer than `limit` results (or even no results) if
none of the nearest results match the filter.
Post filtering happens during the "refine stage" (described in more detail in
@see {@link VectorQuery#refineFactor}). This means that setting a higher refine
factor can often help restore some of the results lost by post filtering.
"""
self._inner.postfilter()
return self
async def to_batches(
self, *, max_batch_length: Optional[int] = None
) -> AsyncRecordBatchReader:
@@ -1618,30 +1670,6 @@ class AsyncVectorQuery(AsyncQueryBase):
self._inner.distance_type(distance_type)
return self
def postfilter(self) -> AsyncVectorQuery:
"""
If this is called then filtering will happen after the vector search instead of
before.
By default filtering will be performed before the vector search. This is how
filtering is typically understood to work. This prefilter step does add some
additional latency. Creating a scalar index on the filter column(s) can
often improve this latency. However, sometimes a filter is too complex or
scalar indices cannot be applied to the column. In these cases postfiltering
can be used instead of prefiltering to improve latency.
Post filtering applies the filter to the results of the vector search. This
means we only run the filter on a much smaller set of data. However, it can
cause the query to return fewer than `limit` results (or even no results) if
none of the nearest results match the filter.
Post filtering happens during the "refine stage" (described in more detail in
@see {@link VectorQuery#refineFactor}). This means that setting a higher refine
factor can often help restore some of the results lost by post filtering.
"""
self._inner.postfilter()
return self
def bypass_vector_index(self) -> AsyncVectorQuery:
"""
If this is called then any vector index is skipped

View File

@@ -11,62 +11,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from dataclasses import dataclass
from dataclasses import dataclass, field
from datetime import timedelta
from typing import List, Optional
import attrs
from lancedb import __version__
import pyarrow as pa
from pydantic import BaseModel
from lancedb.common import VECTOR_COLUMN_NAME
__all__ = ["LanceDBClient", "VectorQuery", "VectorQueryResult"]
class VectorQuery(BaseModel):
# vector to search for
vector: List[float]
# sql filter to refine the query with
filter: Optional[str] = None
# top k results to return
k: int
# # metrics
_metric: str = "L2"
# which columns to return in the results
columns: Optional[List[str]] = None
# optional query parameters for tuning the results,
# e.g. `{"nprobes": "10", "refine_factor": "10"}`
nprobes: int = 10
refine_factor: Optional[int] = None
vector_column: str = VECTOR_COLUMN_NAME
fast_search: bool = False
@attrs.define
class VectorQueryResult:
# for now the response is directly seralized into a pandas dataframe
tbl: pa.Table
def to_arrow(self) -> pa.Table:
return self.tbl
class LanceDBClient(abc.ABC):
@abc.abstractmethod
def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
"""Query the LanceDB server for the given table and query."""
pass
__all__ = ["TimeoutConfig", "RetryConfig", "ClientConfig"]
@dataclass
@@ -165,8 +116,8 @@ class RetryConfig:
@dataclass
class ClientConfig:
user_agent: str = f"LanceDB-Python-Client/{__version__}"
retry_config: Optional[RetryConfig] = None
timeout_config: Optional[TimeoutConfig] = None
retry_config: RetryConfig = field(default_factory=RetryConfig)
timeout_config: Optional[TimeoutConfig] = field(default_factory=TimeoutConfig)
def __post_init__(self):
if isinstance(self.retry_config, dict):

View File

@@ -1,25 +0,0 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Iterable, Union
import pyarrow as pa
def to_ipc_binary(table: Union[pa.Table, Iterable[pa.RecordBatch]]) -> bytes:
"""Serialize a PyArrow Table to IPC binary."""
sink = pa.BufferOutputStream()
if isinstance(table, Iterable):
table = pa.Table.from_batches(table)
with pa.ipc.new_stream(sink, table.schema) as writer:
writer.write_table(table)
return sink.getvalue().to_pybytes()

View File

@@ -1,269 +0,0 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import logging
import os
from typing import Any, Callable, Dict, List, Optional, Union
from urllib.parse import urljoin
import attrs
import pyarrow as pa
import requests
from pydantic import BaseModel
from requests.adapters import HTTPAdapter
from urllib3 import Retry
from lancedb.common import Credential
from lancedb.remote import VectorQuery, VectorQueryResult
from lancedb.remote.connection_timeout import LanceDBClientHTTPAdapterFactory
from lancedb.remote.errors import LanceDBClientError
ARROW_STREAM_CONTENT_TYPE = "application/vnd.apache.arrow.stream"
def _check_not_closed(f):
@functools.wraps(f)
def wrapped(self, *args, **kwargs):
if self.closed:
raise ValueError("Connection is closed")
return f(self, *args, **kwargs)
return wrapped
def _read_ipc(resp: requests.Response) -> pa.Table:
resp_body = resp.content
with pa.ipc.open_file(pa.BufferReader(resp_body)) as reader:
return reader.read_all()
@attrs.define(slots=False)
class RestfulLanceDBClient:
db_name: str
region: str
api_key: Credential
host_override: Optional[str] = attrs.field(default=None)
closed: bool = attrs.field(default=False, init=False)
connection_timeout: float = attrs.field(default=120.0, kw_only=True)
read_timeout: float = attrs.field(default=300.0, kw_only=True)
@functools.cached_property
def session(self) -> requests.Session:
sess = requests.Session()
retry_adapter_instance = retry_adapter(retry_adapter_options())
sess.mount(urljoin(self.url, "/v1/table/"), retry_adapter_instance)
adapter_class = LanceDBClientHTTPAdapterFactory()
sess.mount("https://", adapter_class())
return sess
@property
def url(self) -> str:
return (
self.host_override
or f"https://{self.db_name}.{self.region}.api.lancedb.com"
)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
self.close()
return False # Do not suppress exceptions
def close(self):
self.session.close()
self.closed = True
@functools.cached_property
def headers(self) -> Dict[str, str]:
headers = {
"x-api-key": self.api_key,
}
if self.region == "local": # Local test mode
headers["Host"] = f"{self.db_name}.{self.region}.api.lancedb.com"
if self.host_override:
headers["x-lancedb-database"] = self.db_name
return headers
@staticmethod
def _check_status(resp: requests.Response):
# Leaving request id empty for now, as we'll be replacing this impl
# with the Rust one shortly.
if resp.status_code == 404:
raise LanceDBClientError(
f"Not found: {resp.text}", request_id="", status_code=404
)
elif 400 <= resp.status_code < 500:
raise LanceDBClientError(
f"Bad Request: {resp.status_code}, error: {resp.text}",
request_id="",
status_code=resp.status_code,
)
elif 500 <= resp.status_code < 600:
raise LanceDBClientError(
f"Internal Server Error: {resp.status_code}, error: {resp.text}",
request_id="",
status_code=resp.status_code,
)
elif resp.status_code != 200:
raise LanceDBClientError(
f"Unknown Error: {resp.status_code}, error: {resp.text}",
request_id="",
status_code=resp.status_code,
)
@_check_not_closed
def get(self, uri: str, params: Union[Dict[str, Any], BaseModel] = None):
"""Send a GET request and returns the deserialized response payload."""
if isinstance(params, BaseModel):
params: Dict[str, Any] = params.dict(exclude_none=True)
with self.session.get(
urljoin(self.url, uri),
params=params,
headers=self.headers,
timeout=(self.connection_timeout, self.read_timeout),
) as resp:
self._check_status(resp)
return resp.json()
@_check_not_closed
def post(
self,
uri: str,
data: Optional[Union[Dict[str, Any], BaseModel, bytes]] = None,
params: Optional[Dict[str, Any]] = None,
content_type: Optional[str] = None,
deserialize: Callable = lambda resp: resp.json(),
request_id: Optional[str] = None,
) -> Dict[str, Any]:
"""Send a POST request and returns the deserialized response payload.
Parameters
----------
uri : str
The uri to send the POST request to.
data: Union[Dict[str, Any], BaseModel]
request_id: Optional[str]
Optional client side request id to be sent in the request headers.
"""
if isinstance(data, BaseModel):
data: Dict[str, Any] = data.dict(exclude_none=True)
if isinstance(data, bytes):
req_kwargs = {"data": data}
else:
req_kwargs = {"json": data}
headers = self.headers.copy()
if content_type is not None:
headers["content-type"] = content_type
if request_id is not None:
headers["x-request-id"] = request_id
with self.session.post(
urljoin(self.url, uri),
headers=headers,
params=params,
timeout=(self.connection_timeout, self.read_timeout),
**req_kwargs,
) as resp:
self._check_status(resp)
return deserialize(resp)
@_check_not_closed
def list_tables(self, limit: int, page_token: Optional[str] = None) -> List[str]:
"""List all tables in the database."""
if page_token is None:
page_token = ""
json = self.get("/v1/table/", {"limit": limit, "page_token": page_token})
return json["tables"]
@_check_not_closed
def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
"""Query a table."""
tbl = self.post(f"/v1/table/{table_name}/query/", query, deserialize=_read_ipc)
return VectorQueryResult(tbl)
def mount_retry_adapter_for_table(self, table_name: str) -> None:
"""
Adds an http adapter to session that will retry retryable requests to the table.
"""
retry_options = retry_adapter_options(methods=["GET", "POST"])
retry_adapter_instance = retry_adapter(retry_options)
session = self.session
session.mount(
urljoin(self.url, f"/v1/table/{table_name}/query/"), retry_adapter_instance
)
session.mount(
urljoin(self.url, f"/v1/table/{table_name}/describe/"),
retry_adapter_instance,
)
session.mount(
urljoin(self.url, f"/v1/table/{table_name}/index/list/"),
retry_adapter_instance,
)
def retry_adapter_options(methods=["GET"]) -> Dict[str, Any]:
return {
"retries": int(os.environ.get("LANCE_CLIENT_MAX_RETRIES", "3")),
"connect_retries": int(os.environ.get("LANCE_CLIENT_CONNECT_RETRIES", "3")),
"read_retries": int(os.environ.get("LANCE_CLIENT_READ_RETRIES", "3")),
"backoff_factor": float(
os.environ.get("LANCE_CLIENT_RETRY_BACKOFF_FACTOR", "0.25")
),
"backoff_jitter": float(
os.environ.get("LANCE_CLIENT_RETRY_BACKOFF_JITTER", "0.25")
),
"statuses": [
int(i.strip())
for i in os.environ.get(
"LANCE_CLIENT_RETRY_STATUSES", "429, 500, 502, 503"
).split(",")
],
"methods": methods,
}
def retry_adapter(options: Dict[str, Any]) -> HTTPAdapter:
total_retries = options["retries"]
connect_retries = options["connect_retries"]
read_retries = options["read_retries"]
backoff_factor = options["backoff_factor"]
backoff_jitter = options["backoff_jitter"]
statuses = options["statuses"]
methods = frozenset(options["methods"])
logging.debug(
f"Setting up retry adapter with {total_retries} retries," # noqa G003
+ f"connect retries {connect_retries}, read retries {read_retries},"
+ f"backoff factor {backoff_factor}, statuses {statuses}, "
+ f"methods {methods}"
)
return HTTPAdapter(
max_retries=Retry(
total=total_retries,
connect=connect_retries,
read=read_retries,
backoff_factor=backoff_factor,
backoff_jitter=backoff_jitter,
status_forcelist=statuses,
allowed_methods=methods,
)
)

View File

@@ -1,115 +0,0 @@
# Copyright 2024 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This module contains an adapter that will close connections if they have not been
# used before a certain timeout. This is necessary because some load balancers will
# close connections after a certain amount of time, but the request module may not yet
# have received the FIN/ACK and will try to reuse the connection.
#
# TODO some of the code here can be simplified if/when this PR is merged:
# https://github.com/urllib3/urllib3/pull/3275
import datetime
import logging
import os
from requests.adapters import HTTPAdapter
from urllib3.connection import HTTPSConnection
from urllib3.connectionpool import HTTPSConnectionPool
from urllib3.poolmanager import PoolManager
def get_client_connection_timeout() -> int:
return int(os.environ.get("LANCE_CLIENT_CONNECTION_TIMEOUT", "300"))
class LanceDBHTTPSConnection(HTTPSConnection):
"""
HTTPSConnection that tracks the last time it was used.
"""
idle_timeout: datetime.timedelta
last_activity: datetime.datetime
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.last_activity = datetime.datetime.now()
def request(self, *args, **kwargs):
self.last_activity = datetime.datetime.now()
super().request(*args, **kwargs)
def is_expired(self):
return datetime.datetime.now() - self.last_activity > self.idle_timeout
def LanceDBHTTPSConnectionPoolFactory(client_idle_timeout: int):
"""
Creates a connection pool class that can be used to close idle connections.
"""
class LanceDBHTTPSConnectionPool(HTTPSConnectionPool):
# override the connection class
ConnectionCls = LanceDBHTTPSConnection
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def _get_conn(self, timeout: float | None = None):
logging.debug("Getting https connection")
conn = super()._get_conn(timeout)
if conn.is_expired():
logging.debug("Closing expired connection")
conn.close()
return conn
def _new_conn(self):
conn = super()._new_conn()
conn.idle_timeout = datetime.timedelta(seconds=client_idle_timeout)
return conn
return LanceDBHTTPSConnectionPool
class LanceDBClientPoolManager(PoolManager):
def __init__(
self, client_idle_timeout: int, num_pools: int, maxsize: int, **kwargs
):
super().__init__(num_pools=num_pools, maxsize=maxsize, **kwargs)
# inject our connection pool impl
connection_pool_class = LanceDBHTTPSConnectionPoolFactory(
client_idle_timeout=client_idle_timeout
)
self.pool_classes_by_scheme["https"] = connection_pool_class
def LanceDBClientHTTPAdapterFactory():
"""
Creates an HTTPAdapter class that can be used to close idle connections
"""
# closure over the timeout
client_idle_timeout = get_client_connection_timeout()
class LanceDBClientRequestHTTPAdapter(HTTPAdapter):
def init_poolmanager(self, connections, maxsize, block=False):
# inject our pool manager impl
self.poolmanager = LanceDBClientPoolManager(
client_idle_timeout=client_idle_timeout,
num_pools=connections,
maxsize=maxsize,
block=block,
)
return LanceDBClientRequestHTTPAdapter

View File

@@ -11,13 +11,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from datetime import timedelta
import logging
import uuid
from concurrent.futures import ThreadPoolExecutor
from typing import Iterable, List, Optional, Union
from typing import Any, Dict, Iterable, List, Optional, Union
from urllib.parse import urlparse
import warnings
from cachetools import TTLCache
from lancedb import connect_async
from lancedb.remote import ClientConfig
import pyarrow as pa
from overrides import override
@@ -25,10 +28,8 @@ from ..common import DATA
from ..db import DBConnection
from ..embeddings import EmbeddingFunctionConfig
from ..pydantic import LanceModel
from ..table import Table, sanitize_create_table
from ..table import Table
from ..util import validate_table_name
from .arrow import to_ipc_binary
from .client import ARROW_STREAM_CONTENT_TYPE, RestfulLanceDBClient
class RemoteDBConnection(DBConnection):
@@ -41,26 +42,70 @@ class RemoteDBConnection(DBConnection):
region: str,
host_override: Optional[str] = None,
request_thread_pool: Optional[ThreadPoolExecutor] = None,
connection_timeout: float = 120.0,
read_timeout: float = 300.0,
client_config: Union[ClientConfig, Dict[str, Any], None] = None,
connection_timeout: Optional[float] = None,
read_timeout: Optional[float] = None,
):
"""Connect to a remote LanceDB database."""
if isinstance(client_config, dict):
client_config = ClientConfig(**client_config)
elif client_config is None:
client_config = ClientConfig()
# These are legacy options from the old Python-based client. We keep them
# here for backwards compatibility, but will remove them in a future release.
if request_thread_pool is not None:
warnings.warn(
"request_thread_pool is no longer used and will be removed in "
"a future release.",
DeprecationWarning,
)
if connection_timeout is not None:
warnings.warn(
"connection_timeout is deprecated and will be removed in a future "
"release. Please use client_config.timeout_config.connect_timeout "
"instead.",
DeprecationWarning,
)
client_config.timeout_config.connect_timeout = timedelta(
seconds=connection_timeout
)
if read_timeout is not None:
warnings.warn(
"read_timeout is deprecated and will be removed in a future release. "
"Please use client_config.timeout_config.read_timeout instead.",
DeprecationWarning,
)
client_config.timeout_config.read_timeout = timedelta(seconds=read_timeout)
parsed = urlparse(db_url)
if parsed.scheme != "db":
raise ValueError(f"Invalid scheme: {parsed.scheme}, only accepts db://")
self._uri = str(db_url)
self.db_name = parsed.netloc
self.api_key = api_key
self._client = RestfulLanceDBClient(
self.db_name,
region,
api_key,
host_override,
connection_timeout=connection_timeout,
read_timeout=read_timeout,
import nest_asyncio
nest_asyncio.apply()
try:
self._loop = asyncio.get_running_loop()
except RuntimeError:
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)
self.client_config = client_config
self._conn = self._loop.run_until_complete(
connect_async(
db_url,
api_key=api_key,
region=region,
host_override=host_override,
client_config=client_config,
)
)
self._request_thread_pool = request_thread_pool
self._table_cache = TTLCache(maxsize=10000, ttl=300)
def __repr__(self) -> str:
return f"RemoteConnect(name={self.db_name})"
@@ -82,16 +127,9 @@ class RemoteDBConnection(DBConnection):
-------
An iterator of table names.
"""
while True:
result = self._client.list_tables(limit, page_token)
if len(result) > 0:
page_token = result[len(result) - 1]
else:
break
for item in result:
self._table_cache[item] = True
yield item
return self._loop.run_until_complete(
self._conn.table_names(start_after=page_token, limit=limit)
)
@override
def open_table(self, name: str, *, index_cache_size: Optional[int] = None) -> Table:
@@ -108,20 +146,14 @@ class RemoteDBConnection(DBConnection):
"""
from .table import RemoteTable
self._client.mount_retry_adapter_for_table(name)
if index_cache_size is not None:
logging.info(
"index_cache_size is ignored in LanceDb Cloud"
" (there is no local cache to configure)"
)
# check if table exists
if self._table_cache.get(name) is None:
self._client.post(f"/v1/table/{name}/describe/")
self._table_cache[name] = True
return RemoteTable(self, name)
table = self._loop.run_until_complete(self._conn.open_table(name))
return RemoteTable(table, self.db_name, self._loop)
@override
def create_table(
@@ -233,27 +265,20 @@ class RemoteDBConnection(DBConnection):
"Please vote https://github.com/lancedb/lancedb/issues/626 "
"for this feature."
)
if mode is not None:
logging.warning("mode is not yet supported on LanceDB Cloud.")
data, schema = sanitize_create_table(
data, schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
)
from .table import RemoteTable
data = to_ipc_binary(data)
request_id = uuid.uuid4().hex
self._client.post(
f"/v1/table/{name}/create/",
data=data,
request_id=request_id,
content_type=ARROW_STREAM_CONTENT_TYPE,
table = self._loop.run_until_complete(
self._conn.create_table(
name,
data,
mode=mode,
schema=schema,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
)
)
self._table_cache[name] = True
return RemoteTable(self, name)
return RemoteTable(table, self.db_name, self._loop)
@override
def drop_table(self, name: str):
@@ -264,11 +289,7 @@ class RemoteDBConnection(DBConnection):
name: str
The name of the table.
"""
self._client.post(
f"/v1/table/{name}/drop/",
)
self._table_cache.pop(name, default=None)
self._loop.run_until_complete(self._conn.drop_table(name))
@override
def rename_table(self, cur_name: str, new_name: str):
@@ -281,12 +302,7 @@ class RemoteDBConnection(DBConnection):
new_name: str
The new name of the table.
"""
self._client.post(
f"/v1/table/{cur_name}/rename/",
data={"new_table_name": new_name},
)
self._table_cache.pop(cur_name, default=None)
self._table_cache[new_name] = True
self._loop.run_until_complete(self._conn.rename_table(cur_name, new_name))
async def close(self):
"""Close the connection to the database."""

View File

@@ -11,53 +11,56 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import uuid
from concurrent.futures import Future
from functools import cached_property
from typing import Dict, Iterable, List, Optional, Union, Literal
from lancedb.index import FTS, BTree, Bitmap, HnswPq, HnswSq, IvfPq, LabelList
import pyarrow as pa
from lance import json_to_schema
from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME
from lancedb.merge import LanceMergeInsertBuilder
from lancedb.embeddings import EmbeddingFunctionRegistry
from ..query import LanceVectorQueryBuilder, LanceQueryBuilder
from ..table import Query, Table, _sanitize_data
from ..util import value_to_sql, infer_vector_column_name
from .arrow import to_ipc_binary
from .client import ARROW_STREAM_CONTENT_TYPE
from .db import RemoteDBConnection
from ..table import AsyncTable, Query, Table
class RemoteTable(Table):
def __init__(self, conn: RemoteDBConnection, name: str):
self._conn = conn
self.name = name
def __init__(
self,
table: AsyncTable,
db_name: str,
loop: Optional[asyncio.AbstractEventLoop] = None,
):
self._loop = loop
self._table = table
self.db_name = db_name
@property
def name(self) -> str:
"""The name of the table"""
return self._table.name
def __repr__(self) -> str:
return f"RemoteTable({self._conn.db_name}.{self.name})"
return f"RemoteTable({self.db_name}.{self.name})"
def __len__(self) -> int:
self.count_rows(None)
@cached_property
@property
def schema(self) -> pa.Schema:
"""The [Arrow Schema](https://arrow.apache.org/docs/python/api/datatypes.html#)
of this Table
"""
resp = self._conn._client.post(f"/v1/table/{self.name}/describe/")
schema = json_to_schema(resp["schema"])
return schema
return self._loop.run_until_complete(self._table.schema())
@property
def version(self) -> int:
"""Get the current version of the table"""
resp = self._conn._client.post(f"/v1/table/{self.name}/describe/")
return resp["version"]
return self._loop.run_until_complete(self._table.version())
@cached_property
def embedding_functions(self) -> dict:
@@ -84,20 +87,18 @@ class RemoteTable(Table):
def list_indices(self):
"""List all the indices on the table"""
resp = self._conn._client.post(f"/v1/table/{self.name}/index/list/")
return resp
return self._loop.run_until_complete(self._table.list_indices())
def index_stats(self, index_uuid: str):
"""List all the stats of a specified index"""
resp = self._conn._client.post(
f"/v1/table/{self.name}/index/{index_uuid}/stats/"
)
return resp
return self._loop.run_until_complete(self._table.index_stats(index_uuid))
def create_scalar_index(
self,
column: str,
index_type: Literal["BTREE", "BITMAP", "LABEL_LIST", "scalar"] = "scalar",
*,
replace: bool = False,
):
"""Creates a scalar index
Parameters
@@ -107,20 +108,23 @@ class RemoteTable(Table):
or string column.
index_type : str
The index type of the scalar index. Must be "scalar" (BTREE),
"BTREE", "BITMAP", or "LABEL_LIST"
"BTREE", "BITMAP", or "LABEL_LIST",
replace : bool
If True, replace the existing index with the new one.
"""
if index_type == "scalar" or index_type == "BTREE":
config = BTree()
elif index_type == "BITMAP":
config = Bitmap()
elif index_type == "LABEL_LIST":
config = LabelList()
else:
raise ValueError(f"Unknown index type: {index_type}")
data = {
"column": column,
"index_type": index_type,
"replace": True,
}
resp = self._conn._client.post(
f"/v1/table/{self.name}/create_scalar_index/", data=data
self._loop.run_until_complete(
self._table.create_index(column, config=config, replace=replace)
)
return resp
def create_fts_index(
self,
column: str,
@@ -128,15 +132,10 @@ class RemoteTable(Table):
replace: bool = False,
with_position: bool = True,
):
data = {
"column": column,
"index_type": "FTS",
"replace": replace,
}
resp = self._conn._client.post(
f"/v1/table/{self.name}/create_index/", data=data
config = FTS(with_position=with_position)
self._loop.run_until_complete(
self._table.create_index(column, config=config, replace=replace)
)
return resp
def create_index(
self,
@@ -204,17 +203,22 @@ class RemoteTable(Table):
"Existing indexes will always be replaced."
)
data = {
"column": vector_column_name,
"index_type": index_type,
"metric_type": metric,
"index_cache_size": index_cache_size,
}
resp = self._conn._client.post(
f"/v1/table/{self.name}/create_index/", data=data
)
index_type = index_type.upper()
if index_type == "VECTOR" or index_type == "IVF_PQ":
config = IvfPq(distance_type=metric)
elif index_type == "IVF_HNSW_PQ":
config = HnswPq(distance_type=metric)
elif index_type == "IVF_HNSW_SQ":
config = HnswSq(distance_type=metric)
else:
raise ValueError(
f"Unknown vector index type: {index_type}. Valid options are"
" 'IVF_PQ', 'IVF_HNSW_PQ', 'IVF_HNSW_SQ'"
)
return resp
self._loop.run_until_complete(
self._table.create_index(vector_column_name, config=config)
)
def add(
self,
@@ -246,22 +250,10 @@ class RemoteTable(Table):
The value to use when filling vectors. Only used if on_bad_vectors="fill".
"""
data, _ = _sanitize_data(
data,
self.schema,
metadata=self.schema.metadata,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
)
payload = to_ipc_binary(data)
request_id = uuid.uuid4().hex
self._conn._client.post(
f"/v1/table/{self.name}/insert/",
data=payload,
params={"request_id": request_id, "mode": mode},
content_type=ARROW_STREAM_CONTENT_TYPE,
self._loop.run_until_complete(
self._table.add(
data, mode=mode, on_bad_vectors=on_bad_vectors, fill_value=fill_value
)
)
def search(
@@ -337,12 +329,6 @@ class RemoteTable(Table):
# empty query builder is not supported in saas, raise error
if query is None and query_type != "hybrid":
raise ValueError("Empty query is not supported")
vector_column_name = infer_vector_column_name(
schema=self.schema,
query_type=query_type,
query=query,
vector_column_name=vector_column_name,
)
return LanceQueryBuilder.create(
self,
@@ -356,37 +342,9 @@ class RemoteTable(Table):
def _execute_query(
self, query: Query, batch_size: Optional[int] = None
) -> pa.RecordBatchReader:
if (
query.vector is not None
and len(query.vector) > 0
and not isinstance(query.vector[0], float)
):
if self._conn._request_thread_pool is None:
def submit(name, q):
f = Future()
f.set_result(self._conn._client.query(name, q))
return f
else:
def submit(name, q):
return self._conn._request_thread_pool.submit(
self._conn._client.query, name, q
)
results = []
for v in query.vector:
v = list(v)
q = query.copy()
q.vector = v
results.append(submit(self.name, q))
return pa.concat_tables(
[add_index(r.result().to_arrow(), i) for i, r in enumerate(results)]
).to_reader()
else:
result = self._conn._client.query(self.name, query)
return result.to_arrow().to_reader()
return self._loop.run_until_complete(
self._table._execute_query(query, batch_size=batch_size)
)
def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
"""Returns a [`LanceMergeInsertBuilder`][lancedb.merge.LanceMergeInsertBuilder]
@@ -403,42 +361,8 @@ class RemoteTable(Table):
on_bad_vectors: str,
fill_value: float,
):
data, _ = _sanitize_data(
new_data,
self.schema,
metadata=None,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
)
payload = to_ipc_binary(data)
params = {}
if len(merge._on) != 1:
raise ValueError(
"RemoteTable only supports a single on key in merge_insert"
)
params["on"] = merge._on[0]
params["when_matched_update_all"] = str(merge._when_matched_update_all).lower()
if merge._when_matched_update_all_condition is not None:
params["when_matched_update_all_filt"] = (
merge._when_matched_update_all_condition
)
params["when_not_matched_insert_all"] = str(
merge._when_not_matched_insert_all
).lower()
params["when_not_matched_by_source_delete"] = str(
merge._when_not_matched_by_source_delete
).lower()
if merge._when_not_matched_by_source_condition is not None:
params["when_not_matched_by_source_delete_filt"] = (
merge._when_not_matched_by_source_condition
)
self._conn._client.post(
f"/v1/table/{self.name}/merge_insert/",
data=payload,
params=params,
content_type=ARROW_STREAM_CONTENT_TYPE,
self._loop.run_until_complete(
self._table._do_merge(merge, new_data, on_bad_vectors, fill_value)
)
def delete(self, predicate: str):
@@ -488,8 +412,7 @@ class RemoteTable(Table):
x vector _distance # doctest: +SKIP
0 2 [3.0, 4.0] 85.0 # doctest: +SKIP
"""
payload = {"predicate": predicate}
self._conn._client.post(f"/v1/table/{self.name}/delete/", data=payload)
self._loop.run_until_complete(self._table.delete(predicate))
def update(
self,
@@ -539,18 +462,9 @@ class RemoteTable(Table):
2 2 [10.0, 10.0] # doctest: +SKIP
"""
if values is not None and values_sql is not None:
raise ValueError("Only one of values or values_sql can be provided")
if values is None and values_sql is None:
raise ValueError("Either values or values_sql must be provided")
if values is not None:
updates = [[k, value_to_sql(v)] for k, v in values.items()]
else:
updates = [[k, v] for k, v in values_sql.items()]
payload = {"predicate": where, "updates": updates}
self._conn._client.post(f"/v1/table/{self.name}/update/", data=payload)
self._loop.run_until_complete(
self._table.update(where=where, updates=values, updates_sql=values_sql)
)
def cleanup_old_versions(self, *_):
"""cleanup_old_versions() is not supported on the LanceDB cloud"""
@@ -565,11 +479,7 @@ class RemoteTable(Table):
)
def count_rows(self, filter: Optional[str] = None) -> int:
payload = {"predicate": filter}
resp = self._conn._client.post(
f"/v1/table/{self.name}/count_rows/", data=payload
)
return resp
return self._loop.run_until_complete(self._table.count_rows(filter))
def add_columns(self, transforms: Dict[str, str]):
raise NotImplementedError(

View File

@@ -12,7 +12,6 @@
# limitations under the License.
import os
import requests
from functools import cached_property
from typing import Union
@@ -57,6 +56,8 @@ class JinaReranker(Reranker):
@cached_property
def _client(self):
import requests
if os.environ.get("JINA_API_KEY") is None and self.api_key is None:
raise ValueError(
"JINA_API_KEY not set. Either set it in your environment or \

View File

@@ -11,7 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from numpy import NaN
from numpy import nan
import pyarrow as pa
from .base import Reranker
@@ -71,7 +71,7 @@ class LinearCombinationReranker(Reranker):
elif self.score == "all":
results = results.append_column(
"_distance",
pa.array([NaN] * len(fts_results), type=pa.float32()),
pa.array([nan] * len(fts_results), type=pa.float32()),
)
return results
@@ -92,7 +92,7 @@ class LinearCombinationReranker(Reranker):
elif self.score == "all":
results = results.append_column(
"_score",
pa.array([NaN] * len(vector_results), type=pa.float32()),
pa.array([nan] * len(vector_results), type=pa.float32()),
)
return results

View File

@@ -62,7 +62,7 @@ if TYPE_CHECKING:
from lance.dataset import CleanupStats, ReaderLike
from ._lancedb import Table as LanceDBTable, OptimizeStats
from .db import LanceDBConnection
from .index import BTree, IndexConfig, IvfPq, Bitmap, LabelList, FTS
from .index import BTree, IndexConfig, IvfPq, Bitmap, LabelList, FTS, HnswPq, HnswSq
pd = safe_import_pandas()
pl = safe_import_polars()
@@ -948,7 +948,9 @@ class Table(ABC):
return _table_uri(self._conn.uri, self.name)
def _get_fts_index_path(self) -> Tuple[str, pa_fs.FileSystem, bool]:
if get_uri_scheme(self._dataset_uri) != "file":
from .remote.table import RemoteTable
if isinstance(self, RemoteTable) or get_uri_scheme(self._dataset_uri) != "file":
return ("", None, False)
path = join_uri(self._dataset_uri, "_indices", "fts")
fs, path = fs_from_uri(path)
@@ -2382,7 +2384,9 @@ class AsyncTable:
column: str,
*,
replace: Optional[bool] = None,
config: Optional[Union[IvfPq, BTree, Bitmap, LabelList, FTS]] = None,
config: Optional[
Union[IvfPq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS]
] = None,
):
"""Create an index to speed up queries
@@ -2535,7 +2539,44 @@ class AsyncTable:
async def _execute_query(
self, query: Query, batch_size: Optional[int] = None
) -> pa.RecordBatchReader:
pass
# The sync remote table calls into this method, so we need to map the
# query to the async version of the query and run that here. This is only
# used for that code path right now.
async_query = self.query().limit(query.k)
if query.offset > 0:
async_query = async_query.offset(query.offset)
if query.columns:
async_query = async_query.select(query.columns)
if query.filter:
async_query = async_query.where(query.filter)
if query.fast_search:
async_query = async_query.fast_search()
if query.with_row_id:
async_query = async_query.with_row_id()
if query.vector:
async_query = (
async_query.nearest_to(query.vector)
.distance_type(query.metric)
.nprobes(query.nprobes)
)
if query.refine_factor:
async_query = async_query.refine_factor(query.refine_factor)
if query.vector_column:
async_query = async_query.column(query.vector_column)
if not query.prefilter:
async_query = async_query.postfilter()
if isinstance(query.full_text_query, str):
async_query = async_query.nearest_to_text(query.full_text_query)
elif isinstance(query.full_text_query, dict):
fts_query = query.full_text_query["query"]
fts_columns = query.full_text_query.get("columns", []) or []
async_query = async_query.nearest_to_text(fts_query, columns=fts_columns)
table = await async_query.to_arrow()
return table.to_reader()
async def _do_merge(
self,
@@ -2781,7 +2822,7 @@ class AsyncTable:
cleanup_older_than = round(cleanup_older_than.total_seconds() * 1000)
return await self._inner.optimize(cleanup_older_than, delete_unverified)
async def list_indices(self) -> IndexConfig:
async def list_indices(self) -> Iterable[IndexConfig]:
"""
List all indices that have been created with Self::create_index
"""
@@ -2865,3 +2906,8 @@ class IndexStatistics:
]
distance_type: Optional[Literal["l2", "cosine", "dot"]] = None
num_indices: Optional[int] = None
# This exists for backwards compatibility with an older API, which returned
# a dictionary instead of a class.
def __getitem__(self, key):
return getattr(self, key)

View File

@@ -18,7 +18,6 @@ import lancedb
import numpy as np
import pandas as pd
import pytest
import requests
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector
@@ -108,6 +107,7 @@ def test_basic_text_embeddings(alias, tmp_path):
@pytest.mark.slow
def test_openclip(tmp_path):
import requests
from PIL import Image
db = lancedb.connect(tmp_path)

View File

@@ -235,6 +235,29 @@ async def test_search_fts_async(async_table):
results = await async_table.query().nearest_to_text("puppy").limit(5).to_list()
assert len(results) == 5
expected_count = await async_table.count_rows(
"count > 5000 and contains(text, 'puppy')"
)
expected_count = min(expected_count, 10)
limited_results_pre_filter = await (
async_table.query()
.nearest_to_text("puppy")
.where("count > 5000")
.limit(10)
.to_list()
)
assert len(limited_results_pre_filter) == expected_count
limited_results_post_filter = await (
async_table.query()
.nearest_to_text("puppy")
.where("count > 5000")
.limit(10)
.postfilter()
.to_list()
)
assert len(limited_results_post_filter) <= expected_count
@pytest.mark.asyncio
async def test_search_fts_specify_column_async(async_table):

View File

@@ -49,7 +49,7 @@ async def test_create_scalar_index(some_table: AsyncTable):
# Can recreate if replace=True
await some_table.create_index("id", replace=True)
indices = await some_table.list_indices()
assert str(indices) == '[Index(BTree, columns=["id"])]'
assert str(indices) == '[Index(BTree, columns=["id"], name="id_idx")]'
assert len(indices) == 1
assert indices[0].index_type == "BTree"
assert indices[0].columns == ["id"]
@@ -64,7 +64,7 @@ async def test_create_scalar_index(some_table: AsyncTable):
async def test_create_bitmap_index(some_table: AsyncTable):
await some_table.create_index("id", config=Bitmap())
indices = await some_table.list_indices()
assert str(indices) == '[Index(Bitmap, columns=["id"])]'
assert str(indices) == '[Index(Bitmap, columns=["id"], name="id_idx")]'
indices = await some_table.list_indices()
assert len(indices) == 1
index_name = indices[0].name
@@ -80,7 +80,7 @@ async def test_create_bitmap_index(some_table: AsyncTable):
async def test_create_label_list_index(some_table: AsyncTable):
await some_table.create_index("tags", config=LabelList())
indices = await some_table.list_indices()
assert str(indices) == '[Index(LabelList, columns=["tags"])]'
assert str(indices) == '[Index(LabelList, columns=["tags"], name="tags_idx")]'
@pytest.mark.asyncio

View File

@@ -17,6 +17,7 @@ from typing import Optional
import lance
import lancedb
from lancedb.index import IvfPq
import numpy as np
import pandas.testing as tm
import pyarrow as pa
@@ -330,6 +331,12 @@ async def test_query_async(table_async: AsyncTable):
# Also check an empty query
await check_query(table_async.query().where("id < 0"), expected_num_rows=0)
# with row id
await check_query(
table_async.query().select(["id", "vector"]).with_row_id(),
expected_columns=["id", "vector", "_rowid"],
)
@pytest.mark.asyncio
async def test_query_to_arrow_async(table_async: AsyncTable):
@@ -358,6 +365,25 @@ async def test_query_to_pandas_async(table_async: AsyncTable):
assert df.shape == (0, 4)
@pytest.mark.asyncio
async def test_fast_search_async(tmp_path):
db = await lancedb.connect_async(tmp_path)
vectors = pa.FixedShapeTensorArray.from_numpy_ndarray(
np.random.rand(256, 32)
).storage
table = await db.create_table("test", pa.table({"vector": vectors}))
await table.create_index(
"vector", config=IvfPq(num_partitions=1, num_sub_vectors=1)
)
await table.add(pa.table({"vector": vectors}))
q = [1.0] * 32
plan = await table.query().nearest_to(q).explain_plan(True)
assert "LanceScan" in plan
plan = await table.query().nearest_to(q).fast_search().explain_plan(True)
assert "LanceScan" not in plan
def test_explain_plan(table):
q = LanceVectorQueryBuilder(table, [0, 0], "vector")
plan = q.explain_plan(verbose=True)

View File

@@ -1,96 +0,0 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import attrs
import numpy as np
import pandas as pd
import pyarrow as pa
import pytest
from aiohttp import web
from lancedb.remote.client import RestfulLanceDBClient, VectorQuery
@attrs.define
class MockLanceDBServer:
runner: web.AppRunner = attrs.field(init=False)
site: web.TCPSite = attrs.field(init=False)
async def query_handler(self, request: web.Request) -> web.Response:
table_name = request.match_info["table_name"]
assert table_name == "test_table"
await request.json()
# TODO: do some matching
vecs = pd.Series([np.random.rand(128) for x in range(10)], name="vector")
ids = pd.Series(range(10), name="id")
df = pd.DataFrame([vecs, ids]).T
batch = pa.RecordBatch.from_pandas(
df,
schema=pa.schema(
[
pa.field("vector", pa.list_(pa.float32(), 128)),
pa.field("id", pa.int64()),
]
),
)
sink = pa.BufferOutputStream()
with pa.ipc.new_file(sink, batch.schema) as writer:
writer.write_batch(batch)
return web.Response(body=sink.getvalue().to_pybytes())
async def setup(self):
app = web.Application()
app.add_routes([web.post("/table/{table_name}", self.query_handler)])
self.runner = web.AppRunner(app)
await self.runner.setup()
self.site = web.TCPSite(self.runner, "localhost", 8111)
async def start(self):
await self.site.start()
async def stop(self):
await self.runner.cleanup()
@pytest.mark.skip(reason="flaky somehow, fix later")
@pytest.mark.asyncio
async def test_e2e_with_mock_server():
mock_server = MockLanceDBServer()
await mock_server.setup()
await mock_server.start()
try:
with RestfulLanceDBClient("lancedb+http://localhost:8111") as client:
df = (
await client.query(
"test_table",
VectorQuery(
vector=np.random.rand(128).tolist(),
k=10,
_metric="L2",
columns=["id", "vector"],
),
)
).to_pandas()
assert "vector" in df.columns
assert "id" in df.columns
assert client.closed
finally:
# make sure we don't leak resources
await mock_server.stop()

View File

@@ -2,91 +2,19 @@
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
import contextlib
from datetime import timedelta
import http.server
import json
import threading
from unittest.mock import MagicMock
import uuid
import lancedb
from lancedb.conftest import MockTextEmbeddingFunction
from lancedb.remote import ClientConfig
from lancedb.remote.errors import HttpError, RetryError
import pyarrow as pa
from lancedb.remote.client import VectorQuery, VectorQueryResult
import pytest
class FakeLanceDBClient:
def close(self):
pass
def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
assert table_name == "test"
t = pa.schema([]).empty_table()
return VectorQueryResult(t)
def post(self, path: str):
pass
def mount_retry_adapter_for_table(self, table_name: str):
pass
def test_remote_db():
conn = lancedb.connect("db://client-will-be-injected", api_key="fake")
setattr(conn, "_client", FakeLanceDBClient())
table = conn["test"]
table.schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2))])
table.search([1.0, 2.0]).to_pandas()
def test_create_empty_table():
client = MagicMock()
conn = lancedb.connect("db://client-will-be-injected", api_key="fake")
conn._client = client
schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2))])
client.post.return_value = {"status": "ok"}
table = conn.create_table("test", schema=schema)
assert table.name == "test"
assert client.post.call_args[0][0] == "/v1/table/test/create/"
json_schema = {
"fields": [
{
"name": "vector",
"nullable": True,
"type": {
"type": "fixed_size_list",
"fields": [
{"name": "item", "nullable": True, "type": {"type": "float"}}
],
"length": 2,
},
},
]
}
client.post.return_value = {"schema": json_schema}
assert table.schema == schema
assert client.post.call_args[0][0] == "/v1/table/test/describe/"
client.post.return_value = 0
assert table.count_rows(None) == 0
def test_create_table_with_recordbatches():
client = MagicMock()
conn = lancedb.connect("db://client-will-be-injected", api_key="fake")
conn._client = client
batch = pa.RecordBatch.from_arrays([pa.array([[1.0, 2.0], [3.0, 4.0]])], ["vector"])
client.post.return_value = {"status": "ok"}
table = conn.create_table("test", [batch], schema=batch.schema)
assert table.name == "test"
assert client.post.call_args[0][0] == "/v1/table/test/create/"
import pyarrow as pa
def make_mock_http_handler(handler):
@@ -100,8 +28,35 @@ def make_mock_http_handler(handler):
return MockLanceDBHandler
@contextlib.contextmanager
def mock_lancedb_connection(handler):
with http.server.HTTPServer(
("localhost", 8080), make_mock_http_handler(handler)
) as server:
handle = threading.Thread(target=server.serve_forever)
handle.start()
db = lancedb.connect(
"db://dev",
api_key="fake",
host_override="http://localhost:8080",
client_config={
"retry_config": {"retries": 2},
"timeout_config": {
"connect_timeout": 1,
},
},
)
try:
yield db
finally:
server.shutdown()
handle.join()
@contextlib.asynccontextmanager
async def mock_lancedb_connection(handler):
async def mock_lancedb_connection_async(handler):
with http.server.HTTPServer(
("localhost", 8080), make_mock_http_handler(handler)
) as server:
@@ -143,7 +98,7 @@ async def test_async_remote_db():
request.end_headers()
request.wfile.write(b'{"tables": []}')
async with mock_lancedb_connection(handler) as db:
async with mock_lancedb_connection_async(handler) as db:
table_names = await db.table_names()
assert table_names == []
@@ -159,12 +114,12 @@ async def test_http_error():
request.end_headers()
request.wfile.write(b"Internal Server Error")
async with mock_lancedb_connection(handler) as db:
with pytest.raises(HttpError, match="Internal Server Error") as exc_info:
async with mock_lancedb_connection_async(handler) as db:
with pytest.raises(HttpError) as exc_info:
await db.table_names()
assert exc_info.value.request_id == request_id_holder["request_id"]
assert exc_info.value.status_code == 507
assert "Internal Server Error" in str(exc_info.value)
@pytest.mark.asyncio
@@ -178,15 +133,225 @@ async def test_retry_error():
request.end_headers()
request.wfile.write(b"Try again later")
async with mock_lancedb_connection(handler) as db:
with pytest.raises(RetryError, match="Hit retry limit") as exc_info:
async with mock_lancedb_connection_async(handler) as db:
with pytest.raises(RetryError) as exc_info:
await db.table_names()
assert exc_info.value.request_id == request_id_holder["request_id"]
assert exc_info.value.status_code == 429
cause = exc_info.value.__cause__
assert isinstance(cause, HttpError)
assert "Try again later" in str(cause)
assert cause.request_id == request_id_holder["request_id"]
assert cause.status_code == 429
@contextlib.contextmanager
def query_test_table(query_handler):
def handler(request):
if request.path == "/v1/table/test/describe/":
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
request.wfile.write(b"{}")
elif request.path == "/v1/table/test/query/":
content_len = int(request.headers.get("Content-Length"))
body = request.rfile.read(content_len)
body = json.loads(body)
data = query_handler(body)
request.send_response(200)
request.send_header("Content-Type", "application/vnd.apache.arrow.file")
request.end_headers()
with pa.ipc.new_file(request.wfile, schema=data.schema) as f:
f.write_table(data)
else:
request.send_response(404)
request.end_headers()
with mock_lancedb_connection(handler) as db:
assert repr(db) == "RemoteConnect(name=dev)"
table = db.open_table("test")
assert repr(table) == "RemoteTable(dev.test)"
yield table
def test_query_sync_minimal():
def handler(body):
assert body == {
"distance_type": "l2",
"k": 10,
"prefilter": False,
"refine_factor": None,
"vector": [1.0, 2.0, 3.0],
"nprobes": 20,
}
return pa.table({"id": [1, 2, 3]})
with query_test_table(handler) as table:
data = table.search([1, 2, 3]).to_list()
expected = [{"id": 1}, {"id": 2}, {"id": 3}]
assert data == expected
def test_query_sync_maximal():
def handler(body):
assert body == {
"distance_type": "cosine",
"k": 42,
"prefilter": True,
"refine_factor": 10,
"vector": [1.0, 2.0, 3.0],
"nprobes": 5,
"filter": "id > 0",
"columns": ["id", "name"],
"vector_column": "vector2",
"fast_search": True,
"with_row_id": True,
}
return pa.table({"id": [1, 2, 3], "name": ["a", "b", "c"]})
with query_test_table(handler) as table:
(
table.search([1, 2, 3], vector_column_name="vector2", fast_search=True)
.metric("cosine")
.limit(42)
.refine_factor(10)
.nprobes(5)
.where("id > 0", prefilter=True)
.with_row_id(True)
.select(["id", "name"])
.to_list()
)
def test_query_sync_fts():
def handler(body):
assert body == {
"full_text_query": {
"query": "puppy",
"columns": [],
},
"k": 10,
"vector": [],
}
return pa.table({"id": [1, 2, 3]})
with query_test_table(handler) as table:
(table.search("puppy", query_type="fts").to_list())
def handler(body):
assert body == {
"full_text_query": {
"query": "puppy",
"columns": ["name", "description"],
},
"k": 42,
"vector": [],
"with_row_id": True,
}
return pa.table({"id": [1, 2, 3]})
with query_test_table(handler) as table:
(
table.search("puppy", query_type="fts", fts_columns=["name", "description"])
.with_row_id(True)
.limit(42)
.to_list()
)
def test_query_sync_hybrid():
def handler(body):
if "full_text_query" in body:
# FTS query
assert body == {
"full_text_query": {
"query": "puppy",
"columns": [],
},
"k": 42,
"vector": [],
"with_row_id": True,
}
return pa.table({"_rowid": [1, 2, 3], "_score": [0.1, 0.2, 0.3]})
else:
# Vector query
assert body == {
"distance_type": "l2",
"k": 42,
"prefilter": False,
"refine_factor": None,
"vector": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
"nprobes": 20,
"with_row_id": True,
}
return pa.table({"_rowid": [1, 2, 3], "_distance": [0.1, 0.2, 0.3]})
with query_test_table(handler) as table:
embedding_func = MockTextEmbeddingFunction()
embedding_config = MagicMock()
embedding_config.function = embedding_func
embedding_funcs = MagicMock()
embedding_funcs.get = MagicMock(return_value=embedding_config)
table.embedding_functions = embedding_funcs
(table.search("puppy", query_type="hybrid").limit(42).to_list())
def test_create_client():
mandatory_args = {
"uri": "db://dev",
"api_key": "fake-api-key",
"region": "us-east-1",
}
db = lancedb.connect(**mandatory_args)
assert isinstance(db.client_config, ClientConfig)
db = lancedb.connect(**mandatory_args, client_config={})
assert isinstance(db.client_config, ClientConfig)
db = lancedb.connect(
**mandatory_args,
client_config=ClientConfig(timeout_config={"connect_timeout": 42}),
)
assert isinstance(db.client_config, ClientConfig)
assert db.client_config.timeout_config.connect_timeout == timedelta(seconds=42)
db = lancedb.connect(
**mandatory_args,
client_config={"timeout_config": {"connect_timeout": timedelta(seconds=42)}},
)
assert isinstance(db.client_config, ClientConfig)
assert db.client_config.timeout_config.connect_timeout == timedelta(seconds=42)
db = lancedb.connect(
**mandatory_args, client_config=ClientConfig(retry_config={"retries": 42})
)
assert isinstance(db.client_config, ClientConfig)
assert db.client_config.retry_config.retries == 42
db = lancedb.connect(
**mandatory_args, client_config={"retry_config": {"retries": 42}}
)
assert isinstance(db.client_config, ClientConfig)
assert db.client_config.retry_config.retries == 42
with pytest.warns(DeprecationWarning):
db = lancedb.connect(**mandatory_args, connection_timeout=42)
assert db.client_config.timeout_config.connect_timeout == timedelta(seconds=42)
with pytest.warns(DeprecationWarning):
db = lancedb.connect(**mandatory_args, read_timeout=42)
assert db.client_config.timeout_config.read_timeout == timedelta(seconds=42)
with pytest.warns(DeprecationWarning):
lancedb.connect(**mandatory_args, request_thread_pool=10)

View File

@@ -170,6 +170,17 @@ impl Connection {
})
}
pub fn rename_table(
self_: PyRef<'_, Self>,
old_name: String,
new_name: String,
) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.get_inner()?.clone();
future_into_py(self_.py(), async move {
inner.rename_table(old_name, new_name).await.infer_error()
})
}
pub fn drop_table(self_: PyRef<'_, Self>, name: String) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.get_inner()?.clone();
future_into_py(self_.py(), async move {

View File

@@ -24,8 +24,8 @@ use lancedb::{
DistanceType,
};
use pyo3::{
exceptions::{PyRuntimeError, PyValueError},
pyclass, pymethods, PyResult,
exceptions::{PyKeyError, PyRuntimeError, PyValueError},
pyclass, pymethods, IntoPy, PyObject, PyResult, Python,
};
use crate::util::parse_distance_type;
@@ -236,7 +236,21 @@ pub struct IndexConfig {
#[pymethods]
impl IndexConfig {
pub fn __repr__(&self) -> String {
format!("Index({}, columns={:?})", self.index_type, self.columns)
format!(
"Index({}, columns={:?}, name=\"{}\")",
self.index_type, self.columns, self.name
)
}
// For backwards-compatibility with the old sync SDK, we also support getting
// attributes via __getitem__.
pub fn __getitem__(&self, key: String, py: Python<'_>) -> PyResult<PyObject> {
match key.as_str() {
"index_type" => Ok(self.index_type.clone().into_py(py)),
"columns" => Ok(self.columns.clone().into_py(py)),
"name" | "index_name" => Ok(self.name.clone().into_py(py)),
_ => Err(PyKeyError::new_err(format!("Invalid key: {}", key))),
}
}
}

View File

@@ -68,6 +68,18 @@ impl Query {
self.inner = self.inner.clone().offset(offset as usize);
}
pub fn fast_search(&mut self) {
self.inner = self.inner.clone().fast_search();
}
pub fn with_row_id(&mut self) {
self.inner = self.inner.clone().with_row_id();
}
pub fn postfilter(&mut self) {
self.inner = self.inner.clone().postfilter();
}
pub fn nearest_to(&mut self, vector: Bound<'_, PyAny>) -> PyResult<VectorQuery> {
let data: ArrayData = ArrayData::from_pyarrow_bound(&vector)?;
let array = make_array(data);
@@ -146,6 +158,14 @@ impl VectorQuery {
self.inner = self.inner.clone().offset(offset as usize);
}
pub fn fast_search(&mut self) {
self.inner = self.inner.clone().fast_search();
}
pub fn with_row_id(&mut self) {
self.inner = self.inner.clone().with_row_id();
}
pub fn column(&mut self, column: String) {
self.inner = self.inner.clone().column(&column);
}

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb-node"
version = "0.11.1-beta.1"
version = "0.13.0-beta.0"
description = "Serverless, low-latency vector database for AI applications"
license.workspace = true
edition.workspace = true

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb"
version = "0.11.1-beta.1"
version = "0.13.0-beta.0"
edition.workspace = true
description = "LanceDB: A serverless, low-latency vector database for AI applications"
license.workspace = true

View File

@@ -39,9 +39,6 @@ use crate::utils::validate_table_name;
use crate::Table;
pub use lance_encoding::version::LanceFileVersion;
#[cfg(feature = "remote")]
use log::warn;
pub const LANCE_FILE_EXTENSION: &str = "lance";
pub type TableBuilderCallback = Box<dyn FnOnce(OpenTableBuilder) -> OpenTableBuilder + Send>;
@@ -719,8 +716,7 @@ impl ConnectBuilder {
let api_key = self.api_key.ok_or_else(|| Error::InvalidInput {
message: "An api_key is required when connecting to LanceDb Cloud".to_string(),
})?;
// TODO: remove this warning when the remote client is ready
warn!("The rust implementation of the remote client is not yet ready for use.");
let internal = Arc::new(crate::remote::db::RemoteDatabase::try_new(
&self.uri,
&api_key,

View File

@@ -119,6 +119,7 @@ pub enum IndexType {
#[serde(alias = "LABEL_LIST")]
LabelList,
// FTS
#[serde(alias = "INVERTED", alias = "Inverted")]
FTS,
}

View File

@@ -403,6 +403,26 @@ pub trait QueryBase {
/// By default, it is false.
fn fast_search(self) -> Self;
/// If this is called then filtering will happen after the vector search instead of
/// before.
///
/// By default filtering will be performed before the vector search. This is how
/// filtering is typically understood to work. This prefilter step does add some
/// additional latency. Creating a scalar index on the filter column(s) can
/// often improve this latency. However, sometimes a filter is too complex or scalar
/// indices cannot be applied to the column. In these cases postfiltering can be
/// used instead of prefiltering to improve latency.
///
/// Post filtering applies the filter to the results of the vector search. This means
/// we only run the filter on a much smaller set of data. However, it can cause the
/// query to return fewer than `limit` results (or even no results) if none of the nearest
/// results match the filter.
///
/// Post filtering happens during the "refine stage" (described in more detail in
/// [`Self::refine_factor`]). This means that setting a higher refine factor can often
/// help restore some of the results lost by post filtering.
fn postfilter(self) -> Self;
/// Return the `_rowid` meta column from the Table.
fn with_row_id(self) -> Self;
}
@@ -442,6 +462,11 @@ impl<T: HasQuery> QueryBase for T {
self
}
fn postfilter(mut self) -> Self {
self.mut_query().prefilter = false;
self
}
fn with_row_id(mut self) -> Self {
self.mut_query().with_row_id = true;
self
@@ -561,6 +586,9 @@ pub struct Query {
///
/// By default, this is false.
pub(crate) with_row_id: bool,
/// If set to false, the filter will be applied after the vector search.
pub(crate) prefilter: bool,
}
impl Query {
@@ -574,6 +602,7 @@ impl Query {
select: Select::All,
fast_search: false,
with_row_id: false,
prefilter: true,
}
}
@@ -678,8 +707,6 @@ pub struct VectorQuery {
pub(crate) distance_type: Option<DistanceType>,
/// Default is true. Set to false to enforce a brute force search.
pub(crate) use_index: bool,
/// Apply filter before ANN search/
pub(crate) prefilter: bool,
}
impl VectorQuery {
@@ -692,7 +719,6 @@ impl VectorQuery {
refine_factor: None,
distance_type: None,
use_index: true,
prefilter: true,
}
}
@@ -782,29 +808,6 @@ impl VectorQuery {
self
}
/// If this is called then filtering will happen after the vector search instead of
/// before.
///
/// By default filtering will be performed before the vector search. This is how
/// filtering is typically understood to work. This prefilter step does add some
/// additional latency. Creating a scalar index on the filter column(s) can
/// often improve this latency. However, sometimes a filter is too complex or scalar
/// indices cannot be applied to the column. In these cases postfiltering can be
/// used instead of prefiltering to improve latency.
///
/// Post filtering applies the filter to the results of the vector search. This means
/// we only run the filter on a much smaller set of data. However, it can cause the
/// query to return fewer than `limit` results (or even no results) if none of the nearest
/// results match the filter.
///
/// Post filtering happens during the "refine stage" (described in more detail in
/// [`Self::refine_factor`]). This means that setting a higher refine factor can often
/// help restore some of the results lost by post filtering.
pub fn postfilter(mut self) -> Self {
self.prefilter = false;
self
}
/// If this is called then any vector index is skipped
///
/// An exhaustive (flat) search will be performed. The query vector will

View File

@@ -23,6 +23,8 @@ pub(crate) mod table;
pub(crate) mod util;
const ARROW_STREAM_CONTENT_TYPE: &str = "application/vnd.apache.arrow.stream";
#[cfg(test)]
const ARROW_FILE_CONTENT_TYPE: &str = "application/vnd.apache.arrow.file";
const JSON_CONTENT_TYPE: &str = "application/json";
pub use client::{ClientConfig, RetryConfig, TimeoutConfig};

View File

@@ -341,7 +341,22 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
request_id
};
debug!("Sending request_id={}: {:?}", request_id, &request);
if log::log_enabled!(log::Level::Debug) {
let content_type = request
.headers()
.get("content-type")
.map(|v| v.to_str().unwrap());
if content_type == Some("application/json") {
let body = request.body().as_ref().unwrap().as_bytes().unwrap();
let body = String::from_utf8_lossy(body);
debug!(
"Sending request_id={}: {:?} with body {}",
request_id, request, body
);
} else {
debug!("Sending request_id={}: {:?}", request_id, request);
}
}
if with_retry {
self.send_with_retry_impl(client, request, request_id).await

View File

@@ -161,7 +161,7 @@ impl<S: HttpSend> ConnectionInternal for RemoteDatabase<S> {
if self.table_cache.get(&options.name).is_none() {
let req = self
.client
.get(&format!("/v1/table/{}/describe/", options.name));
.post(&format!("/v1/table/{}/describe/", options.name));
let (request_id, resp) = self.client.send(req, true).await?;
if resp.status() == StatusCode::NOT_FOUND {
return Err(crate::Error::TableNotFound { name: options.name });
@@ -301,7 +301,7 @@ mod tests {
#[tokio::test]
async fn test_open_table() {
let conn = Connection::new_with_handler(|request| {
assert_eq!(request.method(), &reqwest::Method::GET);
assert_eq!(request.method(), &reqwest::Method::POST);
assert_eq!(request.url().path(), "/v1/table/table1/describe/");
assert_eq!(request.url().query(), None);

View File

@@ -1,3 +1,4 @@
use std::io::Cursor;
use std::sync::{Arc, Mutex};
use crate::index::Index;
@@ -7,10 +8,9 @@ use crate::table::AddDataMode;
use crate::utils::{supported_btree_data_type, supported_vector_data_type};
use crate::Error;
use arrow_array::RecordBatchReader;
use arrow_ipc::reader::StreamReader;
use arrow_ipc::reader::FileReader;
use arrow_schema::{DataType, SchemaRef};
use async_trait::async_trait;
use bytes::Buf;
use datafusion_common::DataFusionError;
use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
use datafusion_physical_plan::{ExecutionPlan, SendableRecordBatchStream};
@@ -115,39 +115,14 @@ impl<S: HttpSend> RemoteTable<S> {
async fn read_arrow_stream(
&self,
request_id: &str,
body: reqwest::Response,
response: reqwest::Response,
) -> Result<SendableRecordBatchStream> {
// Assert that the content type is correct
let content_type = body
.headers()
.get(CONTENT_TYPE)
.ok_or_else(|| Error::Http {
source: "Missing content type".into(),
request_id: request_id.to_string(),
status_code: None,
})?
.to_str()
.map_err(|e| Error::Http {
source: format!("Failed to parse content type: {}", e).into(),
request_id: request_id.to_string(),
status_code: None,
})?;
if content_type != ARROW_STREAM_CONTENT_TYPE {
return Err(Error::Http {
source: format!(
"Expected content type {}, got {}",
ARROW_STREAM_CONTENT_TYPE, content_type
)
.into(),
request_id: request_id.to_string(),
status_code: None,
});
}
let response = self.check_table_response(request_id, response).await?;
// There isn't a way to actually stream this data yet. I have an upstream issue:
// https://github.com/apache/arrow-rs/issues/6420
let body = body.bytes().await.err_to_http(request_id.into())?;
let reader = StreamReader::try_new(body.reader(), None)?;
let body = response.bytes().await.err_to_http(request_id.into())?;
let reader = FileReader::try_new(Cursor::new(body), None)?;
let schema = reader.schema();
let stream = futures::stream::iter(reader).map_err(DataFusionError::from);
Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
@@ -192,6 +167,10 @@ impl<S: HttpSend> RemoteTable<S> {
body["fast_search"] = serde_json::Value::Bool(true);
}
if params.with_row_id {
body["with_row_id"] = serde_json::Value::Bool(true);
}
if let Some(full_text_search) = &params.full_text_search {
if full_text_search.wand_factor.is_some() {
return Err(Error::NotSupported {
@@ -277,7 +256,7 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
.post(&format!("/v1/table/{}/count_rows/", self.name));
if let Some(filter) = filter {
request = request.json(&serde_json::json!({ "filter": filter }));
request = request.json(&serde_json::json!({ "predicate": filter }));
} else {
request = request.json(&serde_json::json!({}));
}
@@ -330,13 +309,13 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
let mut body = serde_json::Value::Object(Default::default());
Self::apply_query_params(&mut body, &query.base)?;
body["prefilter"] = query.prefilter.into();
body["prefilter"] = query.base.prefilter.into();
body["distance_type"] = serde_json::json!(query.distance_type.unwrap_or_default());
body["nprobes"] = query.nprobes.into();
body["refine_factor"] = query.refine_factor.into();
if let Some(vector) = query.query_vector.as_ref() {
let vector: Vec<f32> = match vector.data_type() {
let vector: Vec<f32> = if let Some(vector) = query.query_vector.as_ref() {
match vector.data_type() {
DataType::Float32 => vector
.as_any()
.downcast_ref::<arrow_array::Float32Array>()
@@ -350,9 +329,12 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
message: "VectorQuery vector must be of type Float32".into(),
})
}
};
body["vector"] = serde_json::json!(vector);
}
}
} else {
// Server takes empty vector, not null or undefined.
Vec::new()
};
body["vector"] = serde_json::json!(vector);
if let Some(vector_column) = query.column.as_ref() {
body["vector_column"] = serde_json::Value::String(vector_column.clone());
@@ -383,6 +365,8 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
let mut body = serde_json::Value::Object(Default::default());
Self::apply_query_params(&mut body, query)?;
// Empty vector can be passed if no vector search is performed.
body["vector"] = serde_json::Value::Array(Vec::new());
let request = request.json(&body);
@@ -399,30 +383,19 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
let mut updates = Vec::new();
for (column, expression) in update.columns {
updates.push(column);
updates.push(expression);
updates.push(vec![column, expression]);
}
let request = request.json(&serde_json::json!({
"updates": updates,
"only_if": update.filter,
"predicate": update.filter,
}));
let (request_id, response) = self.client.send(request, false).await?;
let response = self.check_table_response(&request_id, response).await?;
self.check_table_response(&request_id, response).await?;
let body = response.text().await.err_to_http(request_id.clone())?;
serde_json::from_str(&body).map_err(|e| Error::Http {
source: format!(
"Failed to parse updated rows result from response {}: {}",
body, e
)
.into(),
request_id,
status_code: None,
})
Ok(0) // TODO: support returning number of modified rows once supported in SaaS.
}
async fn delete(&self, predicate: &str) -> Result<()> {
let body = serde_json::json!({ "predicate": predicate });
@@ -691,6 +664,7 @@ mod tests {
use crate::{
index::{vector::IvfPqIndexBuilder, Index, IndexStatistics, IndexType},
query::{ExecutableQuery, QueryBase},
remote::ARROW_FILE_CONTENT_TYPE,
DistanceType, Error, Table,
};
@@ -804,7 +778,7 @@ mod tests {
);
assert_eq!(
request.body().unwrap().as_bytes().unwrap(),
br#"{"filter":"a > 10"}"#
br#"{"predicate":"a > 10"}"#
);
http::Response::builder().status(200).body("42").unwrap()
@@ -839,6 +813,17 @@ mod tests {
body
}
fn write_ipc_file(data: &RecordBatch) -> Vec<u8> {
let mut body = Vec::new();
{
let mut writer = arrow_ipc::writer::FileWriter::try_new(&mut body, &data.schema())
.expect("Failed to create writer");
writer.write(data).expect("Failed to write data");
writer.finish().expect("Failed to finish");
}
body
}
#[tokio::test]
async fn test_add_append() {
let data = RecordBatch::try_new(
@@ -947,21 +932,27 @@ mod tests {
let updates = value.get("updates").unwrap().as_array().unwrap();
assert!(updates.len() == 2);
let col_name = updates[0].as_str().unwrap();
let expression = updates[1].as_str().unwrap();
let col_name = updates[0][0].as_str().unwrap();
let expression = updates[0][1].as_str().unwrap();
assert_eq!(col_name, "a");
assert_eq!(expression, "a + 1");
let only_if = value.get("only_if").unwrap().as_str().unwrap();
let col_name = updates[1][0].as_str().unwrap();
let expression = updates[1][1].as_str().unwrap();
assert_eq!(col_name, "b");
assert_eq!(expression, "b - 1");
let only_if = value.get("predicate").unwrap().as_str().unwrap();
assert_eq!(only_if, "b > 10");
}
http::Response::builder().status(200).body("1").unwrap()
http::Response::builder().status(200).body("{}").unwrap()
});
table
.update()
.column("a", "a + 1")
.column("b", "b - 1")
.only_if("b > 10")
.execute()
.await
@@ -1092,10 +1083,10 @@ mod tests {
expected_body["vector"] = vec![0.1f32, 0.2, 0.3].into();
assert_eq!(body, expected_body);
let response_body = write_ipc_stream(&expected_data_ref);
let response_body = write_ipc_file(&expected_data_ref);
http::Response::builder()
.status(200)
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE)
.header(CONTENT_TYPE, ARROW_FILE_CONTENT_TYPE)
.body(response_body)
.unwrap()
});
@@ -1142,10 +1133,10 @@ mod tests {
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
)
.unwrap();
let response_body = write_ipc_stream(&data);
let response_body = write_ipc_file(&data);
http::Response::builder()
.status(200)
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE)
.header(CONTENT_TYPE, ARROW_FILE_CONTENT_TYPE)
.body(response_body)
.unwrap()
});
@@ -1185,6 +1176,8 @@ mod tests {
"query": "hello world",
},
"k": 10,
"vector": [],
"with_row_id": true,
});
assert_eq!(body, expected_body);
@@ -1193,10 +1186,10 @@ mod tests {
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
)
.unwrap();
let response_body = write_ipc_stream(&data);
let response_body = write_ipc_file(&data);
http::Response::builder()
.status(200)
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE)
.header(CONTENT_TYPE, ARROW_FILE_CONTENT_TYPE)
.body(response_body)
.unwrap()
});
@@ -1207,6 +1200,7 @@ mod tests {
FullTextSearchQuery::new("hello world".into())
.columns(Some(vec!["a".into(), "b".into()])),
)
.with_row_id()
.limit(10)
.execute()
.await

View File

@@ -1842,7 +1842,7 @@ impl TableInternal for NativeTable {
scanner.nprobs(query.nprobes);
scanner.use_index(query.use_index);
scanner.prefilter(query.prefilter);
scanner.prefilter(query.base.prefilter);
match query.base.select {
Select::Columns(ref columns) => {
scanner.project(columns.as_slice())?;
@@ -3123,6 +3123,12 @@ mod tests {
assert_eq!(index.index_type, crate::index::IndexType::FTS);
assert_eq!(index.columns, vec!["text".to_string()]);
assert_eq!(index.name, "text_idx");
let stats = table.index_stats("text_idx").await.unwrap().unwrap();
assert_eq!(stats.num_indexed_rows, num_rows);
assert_eq!(stats.num_unindexed_rows, 0);
assert_eq!(stats.index_type, crate::index::IndexType::FTS);
assert_eq!(stats.distance_type, None);
}
#[tokio::test]