Compare commits

..

15 Commits

Author SHA1 Message Date
Lance Release
08d5f93f34 Bump version: 0.15.0 → 0.16.0-beta.0 2024-11-05 23:21:13 +00:00
Will Jones
91cab3b556 feat(python): transition Python remote sdk to use Rust implementation (#1701)
* Replaces Python implementation of Remote SDK with Rust one.
* Drops dependency on `attrs` and `cachetools`. Makes `requests` an
optional dependency used only for embeddings feature.
* Adds dependency on `nest-asyncio`. This was required to get hybrid
search working.
* Deprecate `request_thread_pool` parameter. We now use the tokio
threadpool.
* Stop caching the `schema` on a remote table. Schema is mutable and
there's no mechanism in place to invalidate the cache.
* Removed the client-side resolution of the vector column. We should
already be resolving this server-side.
2024-11-05 13:44:39 -08:00
Will Jones
c61bfc3af8 chore: update package locks (#1798) 2024-11-05 13:28:59 -08:00
Bert
4e8c7b0adf fix: serialize vectordb client errors as json (#1795) 2024-11-05 14:16:25 -05:00
Weston Pace
26f4a80e10 feat: upgrade to lance 0.19.2-beta.3 (#1794) 2024-11-05 06:43:41 -08:00
Will Jones
3604d20ad3 feat(python,node): support with_row_id in Python and remote (#1784)
Needed to support hybrid search in Remote SDK.
2024-11-04 11:25:45 -08:00
Gagan Bhullar
9708d829a9 fix: explain plan options (#1776)
PR fixes #1768
2024-11-04 10:25:34 -08:00
Will Jones
059c9794b5 fix(rust): fix update, open_table, fts search in remote client (#1785)
* `open_table` uses `POST` not `GET`
* `update` uses `predicate` key not `only_if`
* For FTS search, vector cannot be omitted. It must be passed as empty.
* Added logging of JSON request bodies to debug level logging.
2024-11-04 08:27:55 -08:00
Will Jones
15ed7f75a0 feat(python): support post filter on FTS (#1783) 2024-11-01 10:05:05 -07:00
Will Jones
96181ab421 feat: fast_search in Python and Node (#1623)
Sometimes it is acceptable to users to only search indexed data and skip
and new un-indexed data. For example, if un-indexed data will be shortly
indexed and they don't mind the delay. In these cases, we can save a lot
of CPU time in search, and provide better latency. Users can activate
this on queries using `fast_search()`.
2024-11-01 09:29:09 -07:00
Will Jones
f3fc339ef6 fix(rust): fix delete, update, query in remote SDK (#1782)
Fixes several minor issues with Rust remote SDK:

* Delete uses `predicate` not `filter` as parameter
* Update does not return the row value in remote SDK
* Update takes tuples
* Content type returned by query node is wrong, so we shouldn't validate
it. https://github.com/lancedb/sophon/issues/2742
* Data returned by query endpoint is actually an Arrow IPC file, not IPC
stream.
2024-10-31 15:22:09 -07:00
Will Jones
113cd6995b fix: index_stats works for FTS indices (#1780)
When running `index_stats()` for an FTS index, users would get the
deserialization error:

```
InvalidInput { message: "error deserializing index statistics: unknown variant `Inverted`, expected one of `IvfPq`, `IvfHnswPq`, `IvfHnswSq`, `BTree`, `Bitmap`, `LabelList`, `FTS` at line 1 column 24" }
```
2024-10-30 11:33:49 -07:00
Lance Release
02535bdc88 Updating package-lock.json 2024-10-29 22:16:51 +00:00
Lance Release
facc7d61c0 Bump version: 0.12.0-beta.0 → 0.12.0 2024-10-29 22:16:32 +00:00
Lance Release
f947259f16 Bump version: 0.11.1-beta.1 → 0.12.0-beta.0 2024-10-29 22:16:27 +00:00
55 changed files with 907 additions and 1082 deletions

View File

@@ -1,5 +1,5 @@
[tool.bumpversion] [tool.bumpversion]
current_version = "0.11.1-beta.1" current_version = "0.12.0"
parse = """(?x) parse = """(?x)
(?P<major>0|[1-9]\\d*)\\. (?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\. (?P<minor>0|[1-9]\\d*)\\.

View File

@@ -21,13 +21,15 @@ categories = ["database-implementations"]
rust-version = "1.80.0" # TODO: lower this once we upgrade Lance again. rust-version = "1.80.0" # TODO: lower this once we upgrade Lance again.
[workspace.dependencies] [workspace.dependencies]
lance = { "version" = "=0.19.1", "features" = ["dynamodb"] } lance = { "version" = "=0.19.2", "features" = [
lance-index = { "version" = "=0.19.1" } "dynamodb",
lance-linalg = { "version" = "=0.19.1" } ], git = "https://github.com/lancedb/lance.git", tag = "v0.19.2-beta.3" }
lance-table = { "version" = "=0.19.1" } lance-index = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2-beta.3" }
lance-testing = { "version" = "=0.19.1" } lance-linalg = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2-beta.3" }
lance-datafusion = { "version" = "=0.19.1" } lance-table = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2-beta.3" }
lance-encoding = { "version" = "=0.19.1" } lance-testing = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2-beta.3" }
lance-datafusion = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2-beta.3" }
lance-encoding = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2-beta.3" }
# Note that this one does not include pyarrow # Note that this one does not include pyarrow
arrow = { version = "52.2", optional = false } arrow = { version = "52.2", optional = false }
arrow-array = "52.2" arrow-array = "52.2"

View File

@@ -8,7 +8,7 @@
<parent> <parent>
<groupId>com.lancedb</groupId> <groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId> <artifactId>lancedb-parent</artifactId>
<version>0.11.1-beta.1</version> <version>0.12.0-final.0</version>
<relativePath>../pom.xml</relativePath> <relativePath>../pom.xml</relativePath>
</parent> </parent>

View File

@@ -6,7 +6,7 @@
<groupId>com.lancedb</groupId> <groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId> <artifactId>lancedb-parent</artifactId>
<version>0.11.1-beta.1</version> <version>0.12.0-final.0</version>
<packaging>pom</packaging> <packaging>pom</packaging>
<name>LanceDB Parent</name> <name>LanceDB Parent</name>

49
node/package-lock.json generated
View File

@@ -1,12 +1,12 @@
{ {
"name": "vectordb", "name": "vectordb",
"version": "0.11.1-beta.1", "version": "0.12.0",
"lockfileVersion": 3, "lockfileVersion": 3,
"requires": true, "requires": true,
"packages": { "packages": {
"": { "": {
"name": "vectordb", "name": "vectordb",
"version": "0.11.1-beta.1", "version": "0.12.0",
"cpu": [ "cpu": [
"x64", "x64",
"arm64" "arm64"
@@ -52,11 +52,11 @@
"uuid": "^9.0.0" "uuid": "^9.0.0"
}, },
"optionalDependencies": { "optionalDependencies": {
"@lancedb/vectordb-darwin-arm64": "0.11.1-beta.1", "@lancedb/vectordb-darwin-arm64": "0.12.0",
"@lancedb/vectordb-darwin-x64": "0.11.1-beta.1", "@lancedb/vectordb-darwin-x64": "0.12.0",
"@lancedb/vectordb-linux-arm64-gnu": "0.11.1-beta.1", "@lancedb/vectordb-linux-arm64-gnu": "0.12.0",
"@lancedb/vectordb-linux-x64-gnu": "0.11.1-beta.1", "@lancedb/vectordb-linux-x64-gnu": "0.12.0",
"@lancedb/vectordb-win32-x64-msvc": "0.11.1-beta.1" "@lancedb/vectordb-win32-x64-msvc": "0.12.0"
}, },
"peerDependencies": { "peerDependencies": {
"@apache-arrow/ts": "^14.0.2", "@apache-arrow/ts": "^14.0.2",
@@ -327,65 +327,60 @@
} }
}, },
"node_modules/@lancedb/vectordb-darwin-arm64": { "node_modules/@lancedb/vectordb-darwin-arm64": {
"version": "0.11.1-beta.1", "version": "0.12.0",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.11.1-beta.1.tgz", "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.12.0.tgz",
"integrity": "sha512-q9jcCbmcz45UHmjgecL6zK82WaqUJsARfniwXXPcnd8ooISVhPkgN+RVKv6edwI9T0PV+xVRYq+LQLlZu5fyxw==", "integrity": "sha512-9X6UyP/ozHkv39YZ8DWh82m3aeQmUtrVDNuRe3o8has6dJyD/qPYukI8Zked4q8J+86/lgQbr4f+WW2V4Dfc1g==",
"cpu": [ "cpu": [
"arm64" "arm64"
], ],
"license": "Apache-2.0",
"optional": true, "optional": true,
"os": [ "os": [
"darwin" "darwin"
] ]
}, },
"node_modules/@lancedb/vectordb-darwin-x64": { "node_modules/@lancedb/vectordb-darwin-x64": {
"version": "0.11.1-beta.1", "version": "0.12.0",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.11.1-beta.1.tgz", "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.12.0.tgz",
"integrity": "sha512-E5tCTS5TaTkssTPa+gdnFxZJ1f60jnSIJXhqufNFZk4s+IMViwR1BPqaqE++WY5c1uBI55ef1862CROKDKX4gg==", "integrity": "sha512-zG+//P3BBpmOiLR+dop68T9AFNxazWlSLF8yVdAtvsqjRzcrrMLR//rIrRcbPHxu8gvvLrMDoDZT+AHd2rElyQ==",
"cpu": [ "cpu": [
"x64" "x64"
], ],
"license": "Apache-2.0",
"optional": true, "optional": true,
"os": [ "os": [
"darwin" "darwin"
] ]
}, },
"node_modules/@lancedb/vectordb-linux-arm64-gnu": { "node_modules/@lancedb/vectordb-linux-arm64-gnu": {
"version": "0.11.1-beta.1", "version": "0.12.0",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.11.1-beta.1.tgz", "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.12.0.tgz",
"integrity": "sha512-Obohy6TH31Uq+fp6ZisHR7iAsvgVPqBExrycVcIJqrLZnIe88N9OWUwBXkmfMAw/2hNJFwD4tU7+4U2FcBWX4w==", "integrity": "sha512-5RiJkcZEdMkK5WUfkV+HVFnJaAergfSiLNgUwJaovEEX8yVChkhrdZFSUj1o/k2k6Ix9mQq+xfIUF+aGN/XnDQ==",
"cpu": [ "cpu": [
"arm64" "arm64"
], ],
"license": "Apache-2.0",
"optional": true, "optional": true,
"os": [ "os": [
"linux" "linux"
] ]
}, },
"node_modules/@lancedb/vectordb-linux-x64-gnu": { "node_modules/@lancedb/vectordb-linux-x64-gnu": {
"version": "0.11.1-beta.1", "version": "0.12.0",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.11.1-beta.1.tgz", "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.12.0.tgz",
"integrity": "sha512-3Meu0dgrzNrnBVVQhxkUSAOhQNmgtKHvOvmrRLUicV+X19hd33udihgxVpZZb9mpXenJ8lZsS+Jq6R0hWqntag==", "integrity": "sha512-JFulRNBHLF0TyE0tThaAB9T7CM3zLquPsBF6oA9b1stVdXbEqVqLMltjem0tqfj30zEoEbAKDPpEKII4CPQMTA==",
"cpu": [ "cpu": [
"x64" "x64"
], ],
"license": "Apache-2.0",
"optional": true, "optional": true,
"os": [ "os": [
"linux" "linux"
] ]
}, },
"node_modules/@lancedb/vectordb-win32-x64-msvc": { "node_modules/@lancedb/vectordb-win32-x64-msvc": {
"version": "0.11.1-beta.1", "version": "0.12.0",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.11.1-beta.1.tgz", "resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.12.0.tgz",
"integrity": "sha512-BafZ9OJPQXsS7JW0weAl12wC+827AiRjfUrE5tvrYWZah2OwCF2U2g6uJ3x4pxfwEGsv5xcHFqgxlS7ttFkh+Q==", "integrity": "sha512-T3s/RzB5dvXBqU3qmS6zyHhF0RHS2sSs81zKzYQy2R2nEVPbnwutFSsdA1wEqEXZlr8uTD9nLbkKJKqRNTXVEg==",
"cpu": [ "cpu": [
"x64" "x64"
], ],
"license": "Apache-2.0",
"optional": true, "optional": true,
"os": [ "os": [
"win32" "win32"

View File

@@ -1,6 +1,6 @@
{ {
"name": "vectordb", "name": "vectordb",
"version": "0.11.1-beta.1", "version": "0.12.0",
"description": " Serverless, low-latency vector database for AI applications", "description": " Serverless, low-latency vector database for AI applications",
"main": "dist/index.js", "main": "dist/index.js",
"types": "dist/index.d.ts", "types": "dist/index.d.ts",
@@ -88,10 +88,10 @@
} }
}, },
"optionalDependencies": { "optionalDependencies": {
"@lancedb/vectordb-darwin-arm64": "0.11.1-beta.1", "@lancedb/vectordb-darwin-arm64": "0.12.0",
"@lancedb/vectordb-darwin-x64": "0.11.1-beta.1", "@lancedb/vectordb-darwin-x64": "0.12.0",
"@lancedb/vectordb-linux-arm64-gnu": "0.11.1-beta.1", "@lancedb/vectordb-linux-arm64-gnu": "0.12.0",
"@lancedb/vectordb-linux-x64-gnu": "0.11.1-beta.1", "@lancedb/vectordb-linux-x64-gnu": "0.12.0",
"@lancedb/vectordb-win32-x64-msvc": "0.11.1-beta.1" "@lancedb/vectordb-win32-x64-msvc": "0.12.0"
} }
} }

View File

@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
import axios, { type AxiosResponse, type ResponseType } from 'axios' import axios, { type AxiosError, type AxiosResponse, type ResponseType } from 'axios'
import { tableFromIPC, type Table as ArrowTable } from 'apache-arrow' import { tableFromIPC, type Table as ArrowTable } from 'apache-arrow'
@@ -197,7 +197,7 @@ export class HttpLancedbClient {
response = await callWithMiddlewares(req, this._middlewares) response = await callWithMiddlewares(req, this._middlewares)
return response return response
} catch (err: any) { } catch (err: any) {
console.error('error: ', err) console.error(serializeErrorAsJson(err))
if (err.response === undefined) { if (err.response === undefined) {
throw new Error(`Network Error: ${err.message as string}`) throw new Error(`Network Error: ${err.message as string}`)
} }
@@ -247,7 +247,8 @@ export class HttpLancedbClient {
// return response // return response
} catch (err: any) { } catch (err: any) {
console.error('error: ', err) console.error(serializeErrorAsJson(err))
if (err.response === undefined) { if (err.response === undefined) {
throw new Error(`Network Error: ${err.message as string}`) throw new Error(`Network Error: ${err.message as string}`)
} }
@@ -287,3 +288,15 @@ export class HttpLancedbClient {
return clone return clone
} }
} }
function serializeErrorAsJson(err: AxiosError) {
const error = JSON.parse(JSON.stringify(err, Object.getOwnPropertyNames(err)))
error.response = err.response != null
? JSON.parse(JSON.stringify(
err.response,
// config contains the request data, too noisy
Object.getOwnPropertyNames(err.response).filter(prop => prop !== 'config')
))
: null
return JSON.stringify({ error })
}

View File

@@ -1,7 +1,7 @@
[package] [package]
name = "lancedb-nodejs" name = "lancedb-nodejs"
edition.workspace = true edition.workspace = true
version = "0.11.1-beta.1" version = "0.12.0"
license.workspace = true license.workspace = true
description.workspace = true description.workspace = true
repository.workspace = true repository.workspace = true

View File

@@ -402,6 +402,40 @@ describe("When creating an index", () => {
expect(rst.numRows).toBe(1); expect(rst.numRows).toBe(1);
}); });
it("should be able to query unindexed data", async () => {
await tbl.createIndex("vec");
await tbl.add([
{
id: 300,
vec: Array(32)
.fill(1)
.map(() => Math.random()),
tags: [],
},
]);
const plan1 = await tbl.query().nearestTo(queryVec).explainPlan(true);
expect(plan1).toMatch("LanceScan");
const plan2 = await tbl
.query()
.nearestTo(queryVec)
.fastSearch()
.explainPlan(true);
expect(plan2).not.toMatch("LanceScan");
});
it("should be able to query with row id", async () => {
const results = await tbl
.query()
.nearestTo(queryVec)
.withRowId()
.limit(1)
.toArray();
expect(results.length).toBe(1);
expect(results[0]).toHaveProperty("_rowid");
});
it("should allow parameters to be specified", async () => { it("should allow parameters to be specified", async () => {
await tbl.createIndex("vec", { await tbl.createIndex("vec", {
config: Index.ivfPq({ config: Index.ivfPq({

View File

@@ -239,6 +239,29 @@ export class QueryBase<NativeQueryType extends NativeQuery | NativeVectorQuery>
return this; return this;
} }
/**
* Skip searching un-indexed data. This can make search faster, but will miss
* any data that is not yet indexed.
*
* Use {@link lancedb.Table#optimize} to index all un-indexed data.
*/
fastSearch(): this {
this.doCall((inner: NativeQueryType) => inner.fastSearch());
return this;
}
/**
* Whether to return the row id in the results.
*
* This column can be used to match results between different queries. For
* example, to match results from a full text search and a vector search in
* order to perform hybrid search.
*/
withRowId(): this {
this.doCall((inner: NativeQueryType) => inner.withRowId());
return this;
}
protected nativeExecute( protected nativeExecute(
options?: Partial<QueryExecutionOptions>, options?: Partial<QueryExecutionOptions>,
): Promise<NativeBatchIterator> { ): Promise<NativeBatchIterator> {

View File

@@ -1,6 +1,6 @@
{ {
"name": "@lancedb/lancedb-darwin-arm64", "name": "@lancedb/lancedb-darwin-arm64",
"version": "0.11.1-beta.1", "version": "0.12.0",
"os": ["darwin"], "os": ["darwin"],
"cpu": ["arm64"], "cpu": ["arm64"],
"main": "lancedb.darwin-arm64.node", "main": "lancedb.darwin-arm64.node",

View File

@@ -1,6 +1,6 @@
{ {
"name": "@lancedb/lancedb-darwin-x64", "name": "@lancedb/lancedb-darwin-x64",
"version": "0.11.1-beta.1", "version": "0.12.0",
"os": ["darwin"], "os": ["darwin"],
"cpu": ["x64"], "cpu": ["x64"],
"main": "lancedb.darwin-x64.node", "main": "lancedb.darwin-x64.node",

View File

@@ -1,6 +1,6 @@
{ {
"name": "@lancedb/lancedb-linux-arm64-gnu", "name": "@lancedb/lancedb-linux-arm64-gnu",
"version": "0.11.1-beta.1", "version": "0.12.0",
"os": ["linux"], "os": ["linux"],
"cpu": ["arm64"], "cpu": ["arm64"],
"main": "lancedb.linux-arm64-gnu.node", "main": "lancedb.linux-arm64-gnu.node",

View File

@@ -1,6 +1,6 @@
{ {
"name": "@lancedb/lancedb-linux-x64-gnu", "name": "@lancedb/lancedb-linux-x64-gnu",
"version": "0.11.1-beta.1", "version": "0.12.0",
"os": ["linux"], "os": ["linux"],
"cpu": ["x64"], "cpu": ["x64"],
"main": "lancedb.linux-x64-gnu.node", "main": "lancedb.linux-x64-gnu.node",

View File

@@ -1,6 +1,6 @@
{ {
"name": "@lancedb/lancedb-win32-x64-msvc", "name": "@lancedb/lancedb-win32-x64-msvc",
"version": "0.11.1-beta.1", "version": "0.12.0",
"os": ["win32"], "os": ["win32"],
"cpu": ["x64"], "cpu": ["x64"],
"main": "lancedb.win32-x64-msvc.node", "main": "lancedb.win32-x64-msvc.node",

View File

@@ -1,12 +1,12 @@
{ {
"name": "@lancedb/lancedb", "name": "@lancedb/lancedb",
"version": "0.11.1-beta.1", "version": "0.12.0",
"lockfileVersion": 3, "lockfileVersion": 3,
"requires": true, "requires": true,
"packages": { "packages": {
"": { "": {
"name": "@lancedb/lancedb", "name": "@lancedb/lancedb",
"version": "0.11.1-beta.1", "version": "0.12.0",
"cpu": [ "cpu": [
"x64", "x64",
"arm64" "arm64"

View File

@@ -10,7 +10,7 @@
"vector database", "vector database",
"ann" "ann"
], ],
"version": "0.11.1-beta.1", "version": "0.12.0",
"main": "dist/index.js", "main": "dist/index.js",
"exports": { "exports": {
".": "./dist/index.js", ".": "./dist/index.js",

View File

@@ -80,6 +80,16 @@ impl Query {
Ok(VectorQuery { inner }) Ok(VectorQuery { inner })
} }
#[napi]
pub fn fast_search(&mut self) {
self.inner = self.inner.clone().fast_search();
}
#[napi]
pub fn with_row_id(&mut self) {
self.inner = self.inner.clone().with_row_id();
}
#[napi(catch_unwind)] #[napi(catch_unwind)]
pub async fn execute( pub async fn execute(
&self, &self,
@@ -183,6 +193,16 @@ impl VectorQuery {
self.inner = self.inner.clone().offset(offset as usize); self.inner = self.inner.clone().offset(offset as usize);
} }
#[napi]
pub fn fast_search(&mut self) {
self.inner = self.inner.clone().fast_search();
}
#[napi]
pub fn with_row_id(&mut self) {
self.inner = self.inner.clone().with_row_id();
}
#[napi(catch_unwind)] #[napi(catch_unwind)]
pub async fn execute( pub async fn execute(
&self, &self,

View File

@@ -1,5 +1,5 @@
[tool.bumpversion] [tool.bumpversion]
current_version = "0.15.0" current_version = "0.16.0-beta.0"
parse = """(?x) parse = """(?x)
(?P<major>0|[1-9]\\d*)\\. (?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\. (?P<minor>0|[1-9]\\d*)\\.

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "lancedb-python" name = "lancedb-python"
version = "0.15.0" version = "0.16.0-beta.0"
edition.workspace = true edition.workspace = true
description = "Python bindings for LanceDB" description = "Python bindings for LanceDB"
license.workspace = true license.workspace = true

View File

@@ -3,13 +3,11 @@ name = "lancedb"
# version in Cargo.toml # version in Cargo.toml
dependencies = [ dependencies = [
"deprecation", "deprecation",
"pylance==0.19.1", "nest-asyncio~=1.0",
"requests>=2.31.0", "pylance==0.19.2-beta.3",
"tqdm>=4.27.0", "tqdm>=4.27.0",
"pydantic>=1.10", "pydantic>=1.10",
"attrs>=21.3.0",
"packaging", "packaging",
"cachetools",
"overrides>=0.7", "overrides>=0.7",
] ]
description = "lancedb" description = "lancedb"
@@ -61,6 +59,7 @@ dev = ["ruff", "pre-commit"]
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"] docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
clip = ["torch", "pillow", "open-clip"] clip = ["torch", "pillow", "open-clip"]
embeddings = [ embeddings = [
"requests>=2.31.0",
"openai>=1.6.1", "openai>=1.6.1",
"sentence-transformers", "sentence-transformers",
"torch", "torch",

View File

@@ -19,12 +19,10 @@ from typing import Dict, Optional, Union, Any
__version__ = importlib.metadata.version("lancedb") __version__ = importlib.metadata.version("lancedb")
from lancedb.remote import ClientConfig
from ._lancedb import connect as lancedb_connect from ._lancedb import connect as lancedb_connect
from .common import URI, sanitize_uri from .common import URI, sanitize_uri
from .db import AsyncConnection, DBConnection, LanceDBConnection from .db import AsyncConnection, DBConnection, LanceDBConnection
from .remote.db import RemoteDBConnection from .remote import ClientConfig
from .schema import vector from .schema import vector
from .table import AsyncTable from .table import AsyncTable
@@ -37,6 +35,7 @@ def connect(
host_override: Optional[str] = None, host_override: Optional[str] = None,
read_consistency_interval: Optional[timedelta] = None, read_consistency_interval: Optional[timedelta] = None,
request_thread_pool: Optional[Union[int, ThreadPoolExecutor]] = None, request_thread_pool: Optional[Union[int, ThreadPoolExecutor]] = None,
client_config: Union[ClientConfig, Dict[str, Any], None] = None,
**kwargs: Any, **kwargs: Any,
) -> DBConnection: ) -> DBConnection:
"""Connect to a LanceDB database. """Connect to a LanceDB database.
@@ -64,14 +63,10 @@ def connect(
the last check, then the table will be checked for updates. Note: this the last check, then the table will be checked for updates. Note: this
consistency only applies to read operations. Write operations are consistency only applies to read operations. Write operations are
always consistent. always consistent.
request_thread_pool: int or ThreadPoolExecutor, optional client_config: ClientConfig or dict, optional
The thread pool to use for making batch requests to the LanceDB Cloud API. Configuration options for the LanceDB Cloud HTTP client. If a dict, then
If an integer, then a ThreadPoolExecutor will be created with that the keys are the attributes of the ClientConfig class. If None, then the
number of threads. If None, then a ThreadPoolExecutor will be created default configuration is used.
with the default number of threads. If a ThreadPoolExecutor, then that
executor will be used for making requests. This is for LanceDB Cloud
only and is only used when making batch requests (i.e., passing in
multiple queries to the search method at once).
Examples Examples
-------- --------
@@ -94,6 +89,8 @@ def connect(
conn : DBConnection conn : DBConnection
A connection to a LanceDB database. A connection to a LanceDB database.
""" """
from .remote.db import RemoteDBConnection
if isinstance(uri, str) and uri.startswith("db://"): if isinstance(uri, str) and uri.startswith("db://"):
if api_key is None: if api_key is None:
api_key = os.environ.get("LANCEDB_API_KEY") api_key = os.environ.get("LANCEDB_API_KEY")
@@ -106,7 +103,9 @@ def connect(
api_key, api_key,
region, region,
host_override, host_override,
# TODO: remove this (deprecation warning downstream)
request_thread_pool=request_thread_pool, request_thread_pool=request_thread_pool,
client_config=client_config,
**kwargs, **kwargs,
) )

View File

@@ -36,6 +36,8 @@ class Connection(object):
data_storage_version: Optional[str] = None, data_storage_version: Optional[str] = None,
enable_v2_manifest_paths: Optional[bool] = None, enable_v2_manifest_paths: Optional[bool] = None,
) -> Table: ... ) -> Table: ...
async def rename_table(self, old_name: str, new_name: str) -> None: ...
async def drop_table(self, name: str) -> None: ...
class Table: class Table:
def name(self) -> str: ... def name(self) -> str: ...

View File

@@ -817,6 +817,18 @@ class AsyncConnection(object):
table = await self._inner.open_table(name, storage_options, index_cache_size) table = await self._inner.open_table(name, storage_options, index_cache_size)
return AsyncTable(table) return AsyncTable(table)
async def rename_table(self, old_name: str, new_name: str):
"""Rename a table in the database.
Parameters
----------
old_name: str
The current name of the table.
new_name: str
The new name of the table.
"""
await self._inner.rename_table(old_name, new_name)
async def drop_table(self, name: str): async def drop_table(self, name: str):
"""Drop a table from the database. """Drop a table from the database.

View File

@@ -13,7 +13,6 @@
import os import os
import io import io
import requests
import base64 import base64
from urllib.parse import urlparse from urllib.parse import urlparse
from pathlib import Path from pathlib import Path
@@ -226,6 +225,8 @@ class JinaEmbeddings(EmbeddingFunction):
return [result["embedding"] for result in sorted_embeddings] return [result["embedding"] for result in sorted_embeddings]
def _init_client(self): def _init_client(self):
import requests
if JinaEmbeddings._session is None: if JinaEmbeddings._session is None:
if self.api_key is None and os.environ.get("JINA_API_KEY") is None: if self.api_key is None and os.environ.get("JINA_API_KEY") is None:
api_key_not_found_help("jina") api_key_not_found_help("jina")

View File

@@ -467,6 +467,8 @@ class IvfPq:
The default value is 256. The default value is 256.
""" """
if distance_type is not None:
distance_type = distance_type.lower()
self._inner = LanceDbIndex.ivf_pq( self._inner = LanceDbIndex.ivf_pq(
distance_type=distance_type, distance_type=distance_type,
num_partitions=num_partitions, num_partitions=num_partitions,

View File

@@ -481,6 +481,7 @@ class LanceQueryBuilder(ABC):
>>> plan = table.search(query).explain_plan(True) >>> plan = table.search(query).explain_plan(True)
>>> print(plan) # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE >>> print(plan) # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
ProjectionExec: expr=[vector@0 as vector, _distance@2 as _distance] ProjectionExec: expr=[vector@0 as vector, _distance@2 as _distance]
GlobalLimitExec: skip=0, fetch=10
FilterExec: _distance@2 IS NOT NULL FilterExec: _distance@2 IS NOT NULL
SortExec: TopK(fetch=10), expr=[_distance@2 ASC NULLS LAST], preserve_partitioning=[false] SortExec: TopK(fetch=10), expr=[_distance@2 ASC NULLS LAST], preserve_partitioning=[false]
KNNVectorDistance: metric=l2 KNNVectorDistance: metric=l2
@@ -500,7 +501,16 @@ class LanceQueryBuilder(ABC):
nearest={ nearest={
"column": self._vector_column, "column": self._vector_column,
"q": self._query, "q": self._query,
"k": self._limit,
"metric": self._metric,
"nprobes": self._nprobes,
"refine_factor": self._refine_factor,
}, },
prefilter=self._prefilter,
filter=self._str_query,
limit=self._limit,
with_row_id=self._with_row_id,
offset=self._offset,
).explain_plan(verbose) ).explain_plan(verbose)
def vector(self, vector: Union[np.ndarray, list]) -> LanceQueryBuilder: def vector(self, vector: Union[np.ndarray, list]) -> LanceQueryBuilder:
@@ -1315,6 +1325,48 @@ class AsyncQueryBase(object):
self._inner.offset(offset) self._inner.offset(offset)
return self return self
def fast_search(self) -> AsyncQuery:
"""
Skip searching un-indexed data.
This can make queries faster, but will miss any data that has not been
indexed.
!!! tip
You can add new data into an existing index by calling
[AsyncTable.optimize][lancedb.table.AsyncTable.optimize].
"""
self._inner.fast_search()
return self
def with_row_id(self) -> AsyncQuery:
"""
Include the _rowid column in the results.
"""
self._inner.with_row_id()
return self
def postfilter(self) -> AsyncQuery:
"""
If this is called then filtering will happen after the search instead of
before.
By default filtering will be performed before the search. This is how
filtering is typically understood to work. This prefilter step does add some
additional latency. Creating a scalar index on the filter column(s) can
often improve this latency. However, sometimes a filter is too complex or
scalar indices cannot be applied to the column. In these cases postfiltering
can be used instead of prefiltering to improve latency.
Post filtering applies the filter to the results of the search. This
means we only run the filter on a much smaller set of data. However, it can
cause the query to return fewer than `limit` results (or even no results) if
none of the nearest results match the filter.
Post filtering happens during the "refine stage" (described in more detail in
@see {@link VectorQuery#refineFactor}). This means that setting a higher refine
factor can often help restore some of the results lost by post filtering.
"""
self._inner.postfilter()
return self
async def to_batches( async def to_batches(
self, *, max_batch_length: Optional[int] = None self, *, max_batch_length: Optional[int] = None
) -> AsyncRecordBatchReader: ) -> AsyncRecordBatchReader:
@@ -1618,30 +1670,6 @@ class AsyncVectorQuery(AsyncQueryBase):
self._inner.distance_type(distance_type) self._inner.distance_type(distance_type)
return self return self
def postfilter(self) -> AsyncVectorQuery:
"""
If this is called then filtering will happen after the vector search instead of
before.
By default filtering will be performed before the vector search. This is how
filtering is typically understood to work. This prefilter step does add some
additional latency. Creating a scalar index on the filter column(s) can
often improve this latency. However, sometimes a filter is too complex or
scalar indices cannot be applied to the column. In these cases postfiltering
can be used instead of prefiltering to improve latency.
Post filtering applies the filter to the results of the vector search. This
means we only run the filter on a much smaller set of data. However, it can
cause the query to return fewer than `limit` results (or even no results) if
none of the nearest results match the filter.
Post filtering happens during the "refine stage" (described in more detail in
@see {@link VectorQuery#refineFactor}). This means that setting a higher refine
factor can often help restore some of the results lost by post filtering.
"""
self._inner.postfilter()
return self
def bypass_vector_index(self) -> AsyncVectorQuery: def bypass_vector_index(self) -> AsyncVectorQuery:
""" """
If this is called then any vector index is skipped If this is called then any vector index is skipped

View File

@@ -11,62 +11,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import abc from dataclasses import dataclass, field
from dataclasses import dataclass
from datetime import timedelta from datetime import timedelta
from typing import List, Optional from typing import List, Optional
import attrs
from lancedb import __version__ from lancedb import __version__
import pyarrow as pa
from pydantic import BaseModel
from lancedb.common import VECTOR_COLUMN_NAME __all__ = ["TimeoutConfig", "RetryConfig", "ClientConfig"]
__all__ = ["LanceDBClient", "VectorQuery", "VectorQueryResult"]
class VectorQuery(BaseModel):
# vector to search for
vector: List[float]
# sql filter to refine the query with
filter: Optional[str] = None
# top k results to return
k: int
# # metrics
_metric: str = "L2"
# which columns to return in the results
columns: Optional[List[str]] = None
# optional query parameters for tuning the results,
# e.g. `{"nprobes": "10", "refine_factor": "10"}`
nprobes: int = 10
refine_factor: Optional[int] = None
vector_column: str = VECTOR_COLUMN_NAME
fast_search: bool = False
@attrs.define
class VectorQueryResult:
# for now the response is directly seralized into a pandas dataframe
tbl: pa.Table
def to_arrow(self) -> pa.Table:
return self.tbl
class LanceDBClient(abc.ABC):
@abc.abstractmethod
def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
"""Query the LanceDB server for the given table and query."""
pass
@dataclass @dataclass
@@ -165,8 +116,8 @@ class RetryConfig:
@dataclass @dataclass
class ClientConfig: class ClientConfig:
user_agent: str = f"LanceDB-Python-Client/{__version__}" user_agent: str = f"LanceDB-Python-Client/{__version__}"
retry_config: Optional[RetryConfig] = None retry_config: RetryConfig = field(default_factory=RetryConfig)
timeout_config: Optional[TimeoutConfig] = None timeout_config: Optional[TimeoutConfig] = field(default_factory=TimeoutConfig)
def __post_init__(self): def __post_init__(self):
if isinstance(self.retry_config, dict): if isinstance(self.retry_config, dict):

View File

@@ -1,25 +0,0 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Iterable, Union
import pyarrow as pa
def to_ipc_binary(table: Union[pa.Table, Iterable[pa.RecordBatch]]) -> bytes:
"""Serialize a PyArrow Table to IPC binary."""
sink = pa.BufferOutputStream()
if isinstance(table, Iterable):
table = pa.Table.from_batches(table)
with pa.ipc.new_stream(sink, table.schema) as writer:
writer.write_table(table)
return sink.getvalue().to_pybytes()

View File

@@ -1,269 +0,0 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import logging
import os
from typing import Any, Callable, Dict, List, Optional, Union
from urllib.parse import urljoin
import attrs
import pyarrow as pa
import requests
from pydantic import BaseModel
from requests.adapters import HTTPAdapter
from urllib3 import Retry
from lancedb.common import Credential
from lancedb.remote import VectorQuery, VectorQueryResult
from lancedb.remote.connection_timeout import LanceDBClientHTTPAdapterFactory
from lancedb.remote.errors import LanceDBClientError
ARROW_STREAM_CONTENT_TYPE = "application/vnd.apache.arrow.stream"
def _check_not_closed(f):
@functools.wraps(f)
def wrapped(self, *args, **kwargs):
if self.closed:
raise ValueError("Connection is closed")
return f(self, *args, **kwargs)
return wrapped
def _read_ipc(resp: requests.Response) -> pa.Table:
resp_body = resp.content
with pa.ipc.open_file(pa.BufferReader(resp_body)) as reader:
return reader.read_all()
@attrs.define(slots=False)
class RestfulLanceDBClient:
db_name: str
region: str
api_key: Credential
host_override: Optional[str] = attrs.field(default=None)
closed: bool = attrs.field(default=False, init=False)
connection_timeout: float = attrs.field(default=120.0, kw_only=True)
read_timeout: float = attrs.field(default=300.0, kw_only=True)
@functools.cached_property
def session(self) -> requests.Session:
sess = requests.Session()
retry_adapter_instance = retry_adapter(retry_adapter_options())
sess.mount(urljoin(self.url, "/v1/table/"), retry_adapter_instance)
adapter_class = LanceDBClientHTTPAdapterFactory()
sess.mount("https://", adapter_class())
return sess
@property
def url(self) -> str:
return (
self.host_override
or f"https://{self.db_name}.{self.region}.api.lancedb.com"
)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
self.close()
return False # Do not suppress exceptions
def close(self):
self.session.close()
self.closed = True
@functools.cached_property
def headers(self) -> Dict[str, str]:
headers = {
"x-api-key": self.api_key,
}
if self.region == "local": # Local test mode
headers["Host"] = f"{self.db_name}.{self.region}.api.lancedb.com"
if self.host_override:
headers["x-lancedb-database"] = self.db_name
return headers
@staticmethod
def _check_status(resp: requests.Response):
# Leaving request id empty for now, as we'll be replacing this impl
# with the Rust one shortly.
if resp.status_code == 404:
raise LanceDBClientError(
f"Not found: {resp.text}", request_id="", status_code=404
)
elif 400 <= resp.status_code < 500:
raise LanceDBClientError(
f"Bad Request: {resp.status_code}, error: {resp.text}",
request_id="",
status_code=resp.status_code,
)
elif 500 <= resp.status_code < 600:
raise LanceDBClientError(
f"Internal Server Error: {resp.status_code}, error: {resp.text}",
request_id="",
status_code=resp.status_code,
)
elif resp.status_code != 200:
raise LanceDBClientError(
f"Unknown Error: {resp.status_code}, error: {resp.text}",
request_id="",
status_code=resp.status_code,
)
@_check_not_closed
def get(self, uri: str, params: Union[Dict[str, Any], BaseModel] = None):
"""Send a GET request and returns the deserialized response payload."""
if isinstance(params, BaseModel):
params: Dict[str, Any] = params.dict(exclude_none=True)
with self.session.get(
urljoin(self.url, uri),
params=params,
headers=self.headers,
timeout=(self.connection_timeout, self.read_timeout),
) as resp:
self._check_status(resp)
return resp.json()
@_check_not_closed
def post(
self,
uri: str,
data: Optional[Union[Dict[str, Any], BaseModel, bytes]] = None,
params: Optional[Dict[str, Any]] = None,
content_type: Optional[str] = None,
deserialize: Callable = lambda resp: resp.json(),
request_id: Optional[str] = None,
) -> Dict[str, Any]:
"""Send a POST request and returns the deserialized response payload.
Parameters
----------
uri : str
The uri to send the POST request to.
data: Union[Dict[str, Any], BaseModel]
request_id: Optional[str]
Optional client side request id to be sent in the request headers.
"""
if isinstance(data, BaseModel):
data: Dict[str, Any] = data.dict(exclude_none=True)
if isinstance(data, bytes):
req_kwargs = {"data": data}
else:
req_kwargs = {"json": data}
headers = self.headers.copy()
if content_type is not None:
headers["content-type"] = content_type
if request_id is not None:
headers["x-request-id"] = request_id
with self.session.post(
urljoin(self.url, uri),
headers=headers,
params=params,
timeout=(self.connection_timeout, self.read_timeout),
**req_kwargs,
) as resp:
self._check_status(resp)
return deserialize(resp)
@_check_not_closed
def list_tables(self, limit: int, page_token: Optional[str] = None) -> List[str]:
"""List all tables in the database."""
if page_token is None:
page_token = ""
json = self.get("/v1/table/", {"limit": limit, "page_token": page_token})
return json["tables"]
@_check_not_closed
def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
"""Query a table."""
tbl = self.post(f"/v1/table/{table_name}/query/", query, deserialize=_read_ipc)
return VectorQueryResult(tbl)
def mount_retry_adapter_for_table(self, table_name: str) -> None:
"""
Adds an http adapter to session that will retry retryable requests to the table.
"""
retry_options = retry_adapter_options(methods=["GET", "POST"])
retry_adapter_instance = retry_adapter(retry_options)
session = self.session
session.mount(
urljoin(self.url, f"/v1/table/{table_name}/query/"), retry_adapter_instance
)
session.mount(
urljoin(self.url, f"/v1/table/{table_name}/describe/"),
retry_adapter_instance,
)
session.mount(
urljoin(self.url, f"/v1/table/{table_name}/index/list/"),
retry_adapter_instance,
)
def retry_adapter_options(methods=["GET"]) -> Dict[str, Any]:
return {
"retries": int(os.environ.get("LANCE_CLIENT_MAX_RETRIES", "3")),
"connect_retries": int(os.environ.get("LANCE_CLIENT_CONNECT_RETRIES", "3")),
"read_retries": int(os.environ.get("LANCE_CLIENT_READ_RETRIES", "3")),
"backoff_factor": float(
os.environ.get("LANCE_CLIENT_RETRY_BACKOFF_FACTOR", "0.25")
),
"backoff_jitter": float(
os.environ.get("LANCE_CLIENT_RETRY_BACKOFF_JITTER", "0.25")
),
"statuses": [
int(i.strip())
for i in os.environ.get(
"LANCE_CLIENT_RETRY_STATUSES", "429, 500, 502, 503"
).split(",")
],
"methods": methods,
}
def retry_adapter(options: Dict[str, Any]) -> HTTPAdapter:
total_retries = options["retries"]
connect_retries = options["connect_retries"]
read_retries = options["read_retries"]
backoff_factor = options["backoff_factor"]
backoff_jitter = options["backoff_jitter"]
statuses = options["statuses"]
methods = frozenset(options["methods"])
logging.debug(
f"Setting up retry adapter with {total_retries} retries," # noqa G003
+ f"connect retries {connect_retries}, read retries {read_retries},"
+ f"backoff factor {backoff_factor}, statuses {statuses}, "
+ f"methods {methods}"
)
return HTTPAdapter(
max_retries=Retry(
total=total_retries,
connect=connect_retries,
read=read_retries,
backoff_factor=backoff_factor,
backoff_jitter=backoff_jitter,
status_forcelist=statuses,
allowed_methods=methods,
)
)

View File

@@ -1,115 +0,0 @@
# Copyright 2024 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This module contains an adapter that will close connections if they have not been
# used before a certain timeout. This is necessary because some load balancers will
# close connections after a certain amount of time, but the request module may not yet
# have received the FIN/ACK and will try to reuse the connection.
#
# TODO some of the code here can be simplified if/when this PR is merged:
# https://github.com/urllib3/urllib3/pull/3275
import datetime
import logging
import os
from requests.adapters import HTTPAdapter
from urllib3.connection import HTTPSConnection
from urllib3.connectionpool import HTTPSConnectionPool
from urllib3.poolmanager import PoolManager
def get_client_connection_timeout() -> int:
return int(os.environ.get("LANCE_CLIENT_CONNECTION_TIMEOUT", "300"))
class LanceDBHTTPSConnection(HTTPSConnection):
"""
HTTPSConnection that tracks the last time it was used.
"""
idle_timeout: datetime.timedelta
last_activity: datetime.datetime
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.last_activity = datetime.datetime.now()
def request(self, *args, **kwargs):
self.last_activity = datetime.datetime.now()
super().request(*args, **kwargs)
def is_expired(self):
return datetime.datetime.now() - self.last_activity > self.idle_timeout
def LanceDBHTTPSConnectionPoolFactory(client_idle_timeout: int):
"""
Creates a connection pool class that can be used to close idle connections.
"""
class LanceDBHTTPSConnectionPool(HTTPSConnectionPool):
# override the connection class
ConnectionCls = LanceDBHTTPSConnection
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def _get_conn(self, timeout: float | None = None):
logging.debug("Getting https connection")
conn = super()._get_conn(timeout)
if conn.is_expired():
logging.debug("Closing expired connection")
conn.close()
return conn
def _new_conn(self):
conn = super()._new_conn()
conn.idle_timeout = datetime.timedelta(seconds=client_idle_timeout)
return conn
return LanceDBHTTPSConnectionPool
class LanceDBClientPoolManager(PoolManager):
def __init__(
self, client_idle_timeout: int, num_pools: int, maxsize: int, **kwargs
):
super().__init__(num_pools=num_pools, maxsize=maxsize, **kwargs)
# inject our connection pool impl
connection_pool_class = LanceDBHTTPSConnectionPoolFactory(
client_idle_timeout=client_idle_timeout
)
self.pool_classes_by_scheme["https"] = connection_pool_class
def LanceDBClientHTTPAdapterFactory():
"""
Creates an HTTPAdapter class that can be used to close idle connections
"""
# closure over the timeout
client_idle_timeout = get_client_connection_timeout()
class LanceDBClientRequestHTTPAdapter(HTTPAdapter):
def init_poolmanager(self, connections, maxsize, block=False):
# inject our pool manager impl
self.poolmanager = LanceDBClientPoolManager(
client_idle_timeout=client_idle_timeout,
num_pools=connections,
maxsize=maxsize,
block=block,
)
return LanceDBClientRequestHTTPAdapter

View File

@@ -11,13 +11,16 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import asyncio
from datetime import timedelta
import logging import logging
import uuid
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from typing import Iterable, List, Optional, Union from typing import Any, Dict, Iterable, List, Optional, Union
from urllib.parse import urlparse from urllib.parse import urlparse
import warnings
from cachetools import TTLCache from lancedb import connect_async
from lancedb.remote import ClientConfig
import pyarrow as pa import pyarrow as pa
from overrides import override from overrides import override
@@ -25,10 +28,8 @@ from ..common import DATA
from ..db import DBConnection from ..db import DBConnection
from ..embeddings import EmbeddingFunctionConfig from ..embeddings import EmbeddingFunctionConfig
from ..pydantic import LanceModel from ..pydantic import LanceModel
from ..table import Table, sanitize_create_table from ..table import Table
from ..util import validate_table_name from ..util import validate_table_name
from .arrow import to_ipc_binary
from .client import ARROW_STREAM_CONTENT_TYPE, RestfulLanceDBClient
class RemoteDBConnection(DBConnection): class RemoteDBConnection(DBConnection):
@@ -41,26 +42,70 @@ class RemoteDBConnection(DBConnection):
region: str, region: str,
host_override: Optional[str] = None, host_override: Optional[str] = None,
request_thread_pool: Optional[ThreadPoolExecutor] = None, request_thread_pool: Optional[ThreadPoolExecutor] = None,
connection_timeout: float = 120.0, client_config: Union[ClientConfig, Dict[str, Any], None] = None,
read_timeout: float = 300.0, connection_timeout: Optional[float] = None,
read_timeout: Optional[float] = None,
): ):
"""Connect to a remote LanceDB database.""" """Connect to a remote LanceDB database."""
if isinstance(client_config, dict):
client_config = ClientConfig(**client_config)
elif client_config is None:
client_config = ClientConfig()
# These are legacy options from the old Python-based client. We keep them
# here for backwards compatibility, but will remove them in a future release.
if request_thread_pool is not None:
warnings.warn(
"request_thread_pool is no longer used and will be removed in "
"a future release.",
DeprecationWarning,
)
if connection_timeout is not None:
warnings.warn(
"connection_timeout is deprecated and will be removed in a future "
"release. Please use client_config.timeout_config.connect_timeout "
"instead.",
DeprecationWarning,
)
client_config.timeout_config.connect_timeout = timedelta(
seconds=connection_timeout
)
if read_timeout is not None:
warnings.warn(
"read_timeout is deprecated and will be removed in a future release. "
"Please use client_config.timeout_config.read_timeout instead.",
DeprecationWarning,
)
client_config.timeout_config.read_timeout = timedelta(seconds=read_timeout)
parsed = urlparse(db_url) parsed = urlparse(db_url)
if parsed.scheme != "db": if parsed.scheme != "db":
raise ValueError(f"Invalid scheme: {parsed.scheme}, only accepts db://") raise ValueError(f"Invalid scheme: {parsed.scheme}, only accepts db://")
self._uri = str(db_url)
self.db_name = parsed.netloc self.db_name = parsed.netloc
self.api_key = api_key
self._client = RestfulLanceDBClient( import nest_asyncio
self.db_name,
region, nest_asyncio.apply()
api_key, try:
host_override, self._loop = asyncio.get_running_loop()
connection_timeout=connection_timeout, except RuntimeError:
read_timeout=read_timeout, self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)
self.client_config = client_config
self._conn = self._loop.run_until_complete(
connect_async(
db_url,
api_key=api_key,
region=region,
host_override=host_override,
client_config=client_config,
)
) )
self._request_thread_pool = request_thread_pool
self._table_cache = TTLCache(maxsize=10000, ttl=300)
def __repr__(self) -> str: def __repr__(self) -> str:
return f"RemoteConnect(name={self.db_name})" return f"RemoteConnect(name={self.db_name})"
@@ -82,16 +127,9 @@ class RemoteDBConnection(DBConnection):
------- -------
An iterator of table names. An iterator of table names.
""" """
while True: return self._loop.run_until_complete(
result = self._client.list_tables(limit, page_token) self._conn.table_names(start_after=page_token, limit=limit)
)
if len(result) > 0:
page_token = result[len(result) - 1]
else:
break
for item in result:
self._table_cache[item] = True
yield item
@override @override
def open_table(self, name: str, *, index_cache_size: Optional[int] = None) -> Table: def open_table(self, name: str, *, index_cache_size: Optional[int] = None) -> Table:
@@ -108,20 +146,14 @@ class RemoteDBConnection(DBConnection):
""" """
from .table import RemoteTable from .table import RemoteTable
self._client.mount_retry_adapter_for_table(name)
if index_cache_size is not None: if index_cache_size is not None:
logging.info( logging.info(
"index_cache_size is ignored in LanceDb Cloud" "index_cache_size is ignored in LanceDb Cloud"
" (there is no local cache to configure)" " (there is no local cache to configure)"
) )
# check if table exists table = self._loop.run_until_complete(self._conn.open_table(name))
if self._table_cache.get(name) is None: return RemoteTable(table, self.db_name, self._loop)
self._client.post(f"/v1/table/{name}/describe/")
self._table_cache[name] = True
return RemoteTable(self, name)
@override @override
def create_table( def create_table(
@@ -233,27 +265,20 @@ class RemoteDBConnection(DBConnection):
"Please vote https://github.com/lancedb/lancedb/issues/626 " "Please vote https://github.com/lancedb/lancedb/issues/626 "
"for this feature." "for this feature."
) )
if mode is not None:
logging.warning("mode is not yet supported on LanceDB Cloud.")
data, schema = sanitize_create_table(
data, schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
)
from .table import RemoteTable from .table import RemoteTable
data = to_ipc_binary(data) table = self._loop.run_until_complete(
request_id = uuid.uuid4().hex self._conn.create_table(
name,
self._client.post( data,
f"/v1/table/{name}/create/", mode=mode,
data=data, schema=schema,
request_id=request_id, on_bad_vectors=on_bad_vectors,
content_type=ARROW_STREAM_CONTENT_TYPE, fill_value=fill_value,
) )
)
self._table_cache[name] = True return RemoteTable(table, self.db_name, self._loop)
return RemoteTable(self, name)
@override @override
def drop_table(self, name: str): def drop_table(self, name: str):
@@ -264,11 +289,7 @@ class RemoteDBConnection(DBConnection):
name: str name: str
The name of the table. The name of the table.
""" """
self._loop.run_until_complete(self._conn.drop_table(name))
self._client.post(
f"/v1/table/{name}/drop/",
)
self._table_cache.pop(name, default=None)
@override @override
def rename_table(self, cur_name: str, new_name: str): def rename_table(self, cur_name: str, new_name: str):
@@ -281,12 +302,7 @@ class RemoteDBConnection(DBConnection):
new_name: str new_name: str
The new name of the table. The new name of the table.
""" """
self._client.post( self._loop.run_until_complete(self._conn.rename_table(cur_name, new_name))
f"/v1/table/{cur_name}/rename/",
data={"new_table_name": new_name},
)
self._table_cache.pop(cur_name, default=None)
self._table_cache[new_name] = True
async def close(self): async def close(self):
"""Close the connection to the database.""" """Close the connection to the database."""

View File

@@ -11,53 +11,56 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import asyncio
import logging import logging
import uuid
from concurrent.futures import Future
from functools import cached_property from functools import cached_property
from typing import Dict, Iterable, List, Optional, Union, Literal from typing import Dict, Iterable, List, Optional, Union, Literal
from lancedb.index import FTS, BTree, Bitmap, HnswPq, HnswSq, IvfPq, LabelList
import pyarrow as pa import pyarrow as pa
from lance import json_to_schema
from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME
from lancedb.merge import LanceMergeInsertBuilder from lancedb.merge import LanceMergeInsertBuilder
from lancedb.embeddings import EmbeddingFunctionRegistry from lancedb.embeddings import EmbeddingFunctionRegistry
from ..query import LanceVectorQueryBuilder, LanceQueryBuilder from ..query import LanceVectorQueryBuilder, LanceQueryBuilder
from ..table import Query, Table, _sanitize_data from ..table import AsyncTable, Query, Table
from ..util import value_to_sql, infer_vector_column_name
from .arrow import to_ipc_binary
from .client import ARROW_STREAM_CONTENT_TYPE
from .db import RemoteDBConnection
class RemoteTable(Table): class RemoteTable(Table):
def __init__(self, conn: RemoteDBConnection, name: str): def __init__(
self._conn = conn self,
self.name = name table: AsyncTable,
db_name: str,
loop: Optional[asyncio.AbstractEventLoop] = None,
):
self._loop = loop
self._table = table
self.db_name = db_name
@property
def name(self) -> str:
"""The name of the table"""
return self._table.name
def __repr__(self) -> str: def __repr__(self) -> str:
return f"RemoteTable({self._conn.db_name}.{self.name})" return f"RemoteTable({self.db_name}.{self.name})"
def __len__(self) -> int: def __len__(self) -> int:
self.count_rows(None) self.count_rows(None)
@cached_property @property
def schema(self) -> pa.Schema: def schema(self) -> pa.Schema:
"""The [Arrow Schema](https://arrow.apache.org/docs/python/api/datatypes.html#) """The [Arrow Schema](https://arrow.apache.org/docs/python/api/datatypes.html#)
of this Table of this Table
""" """
resp = self._conn._client.post(f"/v1/table/{self.name}/describe/") return self._loop.run_until_complete(self._table.schema())
schema = json_to_schema(resp["schema"])
return schema
@property @property
def version(self) -> int: def version(self) -> int:
"""Get the current version of the table""" """Get the current version of the table"""
resp = self._conn._client.post(f"/v1/table/{self.name}/describe/") return self._loop.run_until_complete(self._table.version())
return resp["version"]
@cached_property @cached_property
def embedding_functions(self) -> dict: def embedding_functions(self) -> dict:
@@ -84,20 +87,18 @@ class RemoteTable(Table):
def list_indices(self): def list_indices(self):
"""List all the indices on the table""" """List all the indices on the table"""
resp = self._conn._client.post(f"/v1/table/{self.name}/index/list/") return self._loop.run_until_complete(self._table.list_indices())
return resp
def index_stats(self, index_uuid: str): def index_stats(self, index_uuid: str):
"""List all the stats of a specified index""" """List all the stats of a specified index"""
resp = self._conn._client.post( return self._loop.run_until_complete(self._table.index_stats(index_uuid))
f"/v1/table/{self.name}/index/{index_uuid}/stats/"
)
return resp
def create_scalar_index( def create_scalar_index(
self, self,
column: str, column: str,
index_type: Literal["BTREE", "BITMAP", "LABEL_LIST", "scalar"] = "scalar", index_type: Literal["BTREE", "BITMAP", "LABEL_LIST", "scalar"] = "scalar",
*,
replace: bool = False,
): ):
"""Creates a scalar index """Creates a scalar index
Parameters Parameters
@@ -107,20 +108,23 @@ class RemoteTable(Table):
or string column. or string column.
index_type : str index_type : str
The index type of the scalar index. Must be "scalar" (BTREE), The index type of the scalar index. Must be "scalar" (BTREE),
"BTREE", "BITMAP", or "LABEL_LIST" "BTREE", "BITMAP", or "LABEL_LIST",
replace : bool
If True, replace the existing index with the new one.
""" """
if index_type == "scalar" or index_type == "BTREE":
config = BTree()
elif index_type == "BITMAP":
config = Bitmap()
elif index_type == "LABEL_LIST":
config = LabelList()
else:
raise ValueError(f"Unknown index type: {index_type}")
data = { self._loop.run_until_complete(
"column": column, self._table.create_index(column, config=config, replace=replace)
"index_type": index_type,
"replace": True,
}
resp = self._conn._client.post(
f"/v1/table/{self.name}/create_scalar_index/", data=data
) )
return resp
def create_fts_index( def create_fts_index(
self, self,
column: str, column: str,
@@ -128,15 +132,10 @@ class RemoteTable(Table):
replace: bool = False, replace: bool = False,
with_position: bool = True, with_position: bool = True,
): ):
data = { config = FTS(with_position=with_position)
"column": column, self._loop.run_until_complete(
"index_type": "FTS", self._table.create_index(column, config=config, replace=replace)
"replace": replace,
}
resp = self._conn._client.post(
f"/v1/table/{self.name}/create_index/", data=data
) )
return resp
def create_index( def create_index(
self, self,
@@ -204,17 +203,22 @@ class RemoteTable(Table):
"Existing indexes will always be replaced." "Existing indexes will always be replaced."
) )
data = { index_type = index_type.upper()
"column": vector_column_name, if index_type == "VECTOR" or index_type == "IVF_PQ":
"index_type": index_type, config = IvfPq(distance_type=metric)
"metric_type": metric, elif index_type == "IVF_HNSW_PQ":
"index_cache_size": index_cache_size, config = HnswPq(distance_type=metric)
} elif index_type == "IVF_HNSW_SQ":
resp = self._conn._client.post( config = HnswSq(distance_type=metric)
f"/v1/table/{self.name}/create_index/", data=data else:
raise ValueError(
f"Unknown vector index type: {index_type}. Valid options are"
" 'IVF_PQ', 'IVF_HNSW_PQ', 'IVF_HNSW_SQ'"
) )
return resp self._loop.run_until_complete(
self._table.create_index(vector_column_name, config=config)
)
def add( def add(
self, self,
@@ -246,22 +250,10 @@ class RemoteTable(Table):
The value to use when filling vectors. Only used if on_bad_vectors="fill". The value to use when filling vectors. Only used if on_bad_vectors="fill".
""" """
data, _ = _sanitize_data( self._loop.run_until_complete(
data, self._table.add(
self.schema, data, mode=mode, on_bad_vectors=on_bad_vectors, fill_value=fill_value
metadata=self.schema.metadata,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
) )
payload = to_ipc_binary(data)
request_id = uuid.uuid4().hex
self._conn._client.post(
f"/v1/table/{self.name}/insert/",
data=payload,
params={"request_id": request_id, "mode": mode},
content_type=ARROW_STREAM_CONTENT_TYPE,
) )
def search( def search(
@@ -337,12 +329,6 @@ class RemoteTable(Table):
# empty query builder is not supported in saas, raise error # empty query builder is not supported in saas, raise error
if query is None and query_type != "hybrid": if query is None and query_type != "hybrid":
raise ValueError("Empty query is not supported") raise ValueError("Empty query is not supported")
vector_column_name = infer_vector_column_name(
schema=self.schema,
query_type=query_type,
query=query,
vector_column_name=vector_column_name,
)
return LanceQueryBuilder.create( return LanceQueryBuilder.create(
self, self,
@@ -356,38 +342,10 @@ class RemoteTable(Table):
def _execute_query( def _execute_query(
self, query: Query, batch_size: Optional[int] = None self, query: Query, batch_size: Optional[int] = None
) -> pa.RecordBatchReader: ) -> pa.RecordBatchReader:
if ( return self._loop.run_until_complete(
query.vector is not None self._table._execute_query(query, batch_size=batch_size)
and len(query.vector) > 0
and not isinstance(query.vector[0], float)
):
if self._conn._request_thread_pool is None:
def submit(name, q):
f = Future()
f.set_result(self._conn._client.query(name, q))
return f
else:
def submit(name, q):
return self._conn._request_thread_pool.submit(
self._conn._client.query, name, q
) )
results = []
for v in query.vector:
v = list(v)
q = query.copy()
q.vector = v
results.append(submit(self.name, q))
return pa.concat_tables(
[add_index(r.result().to_arrow(), i) for i, r in enumerate(results)]
).to_reader()
else:
result = self._conn._client.query(self.name, query)
return result.to_arrow().to_reader()
def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder: def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
"""Returns a [`LanceMergeInsertBuilder`][lancedb.merge.LanceMergeInsertBuilder] """Returns a [`LanceMergeInsertBuilder`][lancedb.merge.LanceMergeInsertBuilder]
that can be used to create a "merge insert" operation. that can be used to create a "merge insert" operation.
@@ -403,42 +361,8 @@ class RemoteTable(Table):
on_bad_vectors: str, on_bad_vectors: str,
fill_value: float, fill_value: float,
): ):
data, _ = _sanitize_data( self._loop.run_until_complete(
new_data, self._table._do_merge(merge, new_data, on_bad_vectors, fill_value)
self.schema,
metadata=None,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
)
payload = to_ipc_binary(data)
params = {}
if len(merge._on) != 1:
raise ValueError(
"RemoteTable only supports a single on key in merge_insert"
)
params["on"] = merge._on[0]
params["when_matched_update_all"] = str(merge._when_matched_update_all).lower()
if merge._when_matched_update_all_condition is not None:
params["when_matched_update_all_filt"] = (
merge._when_matched_update_all_condition
)
params["when_not_matched_insert_all"] = str(
merge._when_not_matched_insert_all
).lower()
params["when_not_matched_by_source_delete"] = str(
merge._when_not_matched_by_source_delete
).lower()
if merge._when_not_matched_by_source_condition is not None:
params["when_not_matched_by_source_delete_filt"] = (
merge._when_not_matched_by_source_condition
)
self._conn._client.post(
f"/v1/table/{self.name}/merge_insert/",
data=payload,
params=params,
content_type=ARROW_STREAM_CONTENT_TYPE,
) )
def delete(self, predicate: str): def delete(self, predicate: str):
@@ -488,8 +412,7 @@ class RemoteTable(Table):
x vector _distance # doctest: +SKIP x vector _distance # doctest: +SKIP
0 2 [3.0, 4.0] 85.0 # doctest: +SKIP 0 2 [3.0, 4.0] 85.0 # doctest: +SKIP
""" """
payload = {"predicate": predicate} self._loop.run_until_complete(self._table.delete(predicate))
self._conn._client.post(f"/v1/table/{self.name}/delete/", data=payload)
def update( def update(
self, self,
@@ -539,18 +462,9 @@ class RemoteTable(Table):
2 2 [10.0, 10.0] # doctest: +SKIP 2 2 [10.0, 10.0] # doctest: +SKIP
""" """
if values is not None and values_sql is not None: self._loop.run_until_complete(
raise ValueError("Only one of values or values_sql can be provided") self._table.update(where=where, updates=values, updates_sql=values_sql)
if values is None and values_sql is None: )
raise ValueError("Either values or values_sql must be provided")
if values is not None:
updates = [[k, value_to_sql(v)] for k, v in values.items()]
else:
updates = [[k, v] for k, v in values_sql.items()]
payload = {"predicate": where, "updates": updates}
self._conn._client.post(f"/v1/table/{self.name}/update/", data=payload)
def cleanup_old_versions(self, *_): def cleanup_old_versions(self, *_):
"""cleanup_old_versions() is not supported on the LanceDB cloud""" """cleanup_old_versions() is not supported on the LanceDB cloud"""
@@ -565,11 +479,7 @@ class RemoteTable(Table):
) )
def count_rows(self, filter: Optional[str] = None) -> int: def count_rows(self, filter: Optional[str] = None) -> int:
payload = {"predicate": filter} return self._loop.run_until_complete(self._table.count_rows(filter))
resp = self._conn._client.post(
f"/v1/table/{self.name}/count_rows/", data=payload
)
return resp
def add_columns(self, transforms: Dict[str, str]): def add_columns(self, transforms: Dict[str, str]):
raise NotImplementedError( raise NotImplementedError(

View File

@@ -12,7 +12,6 @@
# limitations under the License. # limitations under the License.
import os import os
import requests
from functools import cached_property from functools import cached_property
from typing import Union from typing import Union
@@ -57,6 +56,8 @@ class JinaReranker(Reranker):
@cached_property @cached_property
def _client(self): def _client(self):
import requests
if os.environ.get("JINA_API_KEY") is None and self.api_key is None: if os.environ.get("JINA_API_KEY") is None and self.api_key is None:
raise ValueError( raise ValueError(
"JINA_API_KEY not set. Either set it in your environment or \ "JINA_API_KEY not set. Either set it in your environment or \

View File

@@ -11,7 +11,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from numpy import NaN from numpy import nan
import pyarrow as pa import pyarrow as pa
from .base import Reranker from .base import Reranker
@@ -71,7 +71,7 @@ class LinearCombinationReranker(Reranker):
elif self.score == "all": elif self.score == "all":
results = results.append_column( results = results.append_column(
"_distance", "_distance",
pa.array([NaN] * len(fts_results), type=pa.float32()), pa.array([nan] * len(fts_results), type=pa.float32()),
) )
return results return results
@@ -92,7 +92,7 @@ class LinearCombinationReranker(Reranker):
elif self.score == "all": elif self.score == "all":
results = results.append_column( results = results.append_column(
"_score", "_score",
pa.array([NaN] * len(vector_results), type=pa.float32()), pa.array([nan] * len(vector_results), type=pa.float32()),
) )
return results return results

View File

@@ -62,7 +62,7 @@ if TYPE_CHECKING:
from lance.dataset import CleanupStats, ReaderLike from lance.dataset import CleanupStats, ReaderLike
from ._lancedb import Table as LanceDBTable, OptimizeStats from ._lancedb import Table as LanceDBTable, OptimizeStats
from .db import LanceDBConnection from .db import LanceDBConnection
from .index import BTree, IndexConfig, IvfPq, Bitmap, LabelList, FTS from .index import BTree, IndexConfig, IvfPq, Bitmap, LabelList, FTS, HnswPq, HnswSq
pd = safe_import_pandas() pd = safe_import_pandas()
pl = safe_import_polars() pl = safe_import_polars()
@@ -948,7 +948,9 @@ class Table(ABC):
return _table_uri(self._conn.uri, self.name) return _table_uri(self._conn.uri, self.name)
def _get_fts_index_path(self) -> Tuple[str, pa_fs.FileSystem, bool]: def _get_fts_index_path(self) -> Tuple[str, pa_fs.FileSystem, bool]:
if get_uri_scheme(self._dataset_uri) != "file": from .remote.table import RemoteTable
if isinstance(self, RemoteTable) or get_uri_scheme(self._dataset_uri) != "file":
return ("", None, False) return ("", None, False)
path = join_uri(self._dataset_uri, "_indices", "fts") path = join_uri(self._dataset_uri, "_indices", "fts")
fs, path = fs_from_uri(path) fs, path = fs_from_uri(path)
@@ -2382,7 +2384,9 @@ class AsyncTable:
column: str, column: str,
*, *,
replace: Optional[bool] = None, replace: Optional[bool] = None,
config: Optional[Union[IvfPq, BTree, Bitmap, LabelList, FTS]] = None, config: Optional[
Union[IvfPq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS]
] = None,
): ):
"""Create an index to speed up queries """Create an index to speed up queries
@@ -2535,7 +2539,44 @@ class AsyncTable:
async def _execute_query( async def _execute_query(
self, query: Query, batch_size: Optional[int] = None self, query: Query, batch_size: Optional[int] = None
) -> pa.RecordBatchReader: ) -> pa.RecordBatchReader:
pass # The sync remote table calls into this method, so we need to map the
# query to the async version of the query and run that here. This is only
# used for that code path right now.
async_query = self.query().limit(query.k)
if query.offset > 0:
async_query = async_query.offset(query.offset)
if query.columns:
async_query = async_query.select(query.columns)
if query.filter:
async_query = async_query.where(query.filter)
if query.fast_search:
async_query = async_query.fast_search()
if query.with_row_id:
async_query = async_query.with_row_id()
if query.vector:
async_query = (
async_query.nearest_to(query.vector)
.distance_type(query.metric)
.nprobes(query.nprobes)
)
if query.refine_factor:
async_query = async_query.refine_factor(query.refine_factor)
if query.vector_column:
async_query = async_query.column(query.vector_column)
if not query.prefilter:
async_query = async_query.postfilter()
if isinstance(query.full_text_query, str):
async_query = async_query.nearest_to_text(query.full_text_query)
elif isinstance(query.full_text_query, dict):
fts_query = query.full_text_query["query"]
fts_columns = query.full_text_query.get("columns", []) or []
async_query = async_query.nearest_to_text(fts_query, columns=fts_columns)
table = await async_query.to_arrow()
return table.to_reader()
async def _do_merge( async def _do_merge(
self, self,
@@ -2781,7 +2822,7 @@ class AsyncTable:
cleanup_older_than = round(cleanup_older_than.total_seconds() * 1000) cleanup_older_than = round(cleanup_older_than.total_seconds() * 1000)
return await self._inner.optimize(cleanup_older_than, delete_unverified) return await self._inner.optimize(cleanup_older_than, delete_unverified)
async def list_indices(self) -> IndexConfig: async def list_indices(self) -> Iterable[IndexConfig]:
""" """
List all indices that have been created with Self::create_index List all indices that have been created with Self::create_index
""" """
@@ -2865,3 +2906,8 @@ class IndexStatistics:
] ]
distance_type: Optional[Literal["l2", "cosine", "dot"]] = None distance_type: Optional[Literal["l2", "cosine", "dot"]] = None
num_indices: Optional[int] = None num_indices: Optional[int] = None
# This exists for backwards compatibility with an older API, which returned
# a dictionary instead of a class.
def __getitem__(self, key):
return getattr(self, key)

View File

@@ -18,7 +18,6 @@ import lancedb
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import pytest import pytest
import requests
from lancedb.embeddings import get_registry from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector from lancedb.pydantic import LanceModel, Vector
@@ -108,6 +107,7 @@ def test_basic_text_embeddings(alias, tmp_path):
@pytest.mark.slow @pytest.mark.slow
def test_openclip(tmp_path): def test_openclip(tmp_path):
import requests
from PIL import Image from PIL import Image
db = lancedb.connect(tmp_path) db = lancedb.connect(tmp_path)

View File

@@ -235,6 +235,29 @@ async def test_search_fts_async(async_table):
results = await async_table.query().nearest_to_text("puppy").limit(5).to_list() results = await async_table.query().nearest_to_text("puppy").limit(5).to_list()
assert len(results) == 5 assert len(results) == 5
expected_count = await async_table.count_rows(
"count > 5000 and contains(text, 'puppy')"
)
expected_count = min(expected_count, 10)
limited_results_pre_filter = await (
async_table.query()
.nearest_to_text("puppy")
.where("count > 5000")
.limit(10)
.to_list()
)
assert len(limited_results_pre_filter) == expected_count
limited_results_post_filter = await (
async_table.query()
.nearest_to_text("puppy")
.where("count > 5000")
.limit(10)
.postfilter()
.to_list()
)
assert len(limited_results_post_filter) <= expected_count
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_search_fts_specify_column_async(async_table): async def test_search_fts_specify_column_async(async_table):

View File

@@ -49,7 +49,7 @@ async def test_create_scalar_index(some_table: AsyncTable):
# Can recreate if replace=True # Can recreate if replace=True
await some_table.create_index("id", replace=True) await some_table.create_index("id", replace=True)
indices = await some_table.list_indices() indices = await some_table.list_indices()
assert str(indices) == '[Index(BTree, columns=["id"])]' assert str(indices) == '[Index(BTree, columns=["id"], name="id_idx")]'
assert len(indices) == 1 assert len(indices) == 1
assert indices[0].index_type == "BTree" assert indices[0].index_type == "BTree"
assert indices[0].columns == ["id"] assert indices[0].columns == ["id"]
@@ -64,7 +64,7 @@ async def test_create_scalar_index(some_table: AsyncTable):
async def test_create_bitmap_index(some_table: AsyncTable): async def test_create_bitmap_index(some_table: AsyncTable):
await some_table.create_index("id", config=Bitmap()) await some_table.create_index("id", config=Bitmap())
indices = await some_table.list_indices() indices = await some_table.list_indices()
assert str(indices) == '[Index(Bitmap, columns=["id"])]' assert str(indices) == '[Index(Bitmap, columns=["id"], name="id_idx")]'
indices = await some_table.list_indices() indices = await some_table.list_indices()
assert len(indices) == 1 assert len(indices) == 1
index_name = indices[0].name index_name = indices[0].name
@@ -80,7 +80,7 @@ async def test_create_bitmap_index(some_table: AsyncTable):
async def test_create_label_list_index(some_table: AsyncTable): async def test_create_label_list_index(some_table: AsyncTable):
await some_table.create_index("tags", config=LabelList()) await some_table.create_index("tags", config=LabelList())
indices = await some_table.list_indices() indices = await some_table.list_indices()
assert str(indices) == '[Index(LabelList, columns=["tags"])]' assert str(indices) == '[Index(LabelList, columns=["tags"], name="tags_idx")]'
@pytest.mark.asyncio @pytest.mark.asyncio

View File

@@ -17,6 +17,7 @@ from typing import Optional
import lance import lance
import lancedb import lancedb
from lancedb.index import IvfPq
import numpy as np import numpy as np
import pandas.testing as tm import pandas.testing as tm
import pyarrow as pa import pyarrow as pa
@@ -330,6 +331,12 @@ async def test_query_async(table_async: AsyncTable):
# Also check an empty query # Also check an empty query
await check_query(table_async.query().where("id < 0"), expected_num_rows=0) await check_query(table_async.query().where("id < 0"), expected_num_rows=0)
# with row id
await check_query(
table_async.query().select(["id", "vector"]).with_row_id(),
expected_columns=["id", "vector", "_rowid"],
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_query_to_arrow_async(table_async: AsyncTable): async def test_query_to_arrow_async(table_async: AsyncTable):
@@ -358,6 +365,25 @@ async def test_query_to_pandas_async(table_async: AsyncTable):
assert df.shape == (0, 4) assert df.shape == (0, 4)
@pytest.mark.asyncio
async def test_fast_search_async(tmp_path):
db = await lancedb.connect_async(tmp_path)
vectors = pa.FixedShapeTensorArray.from_numpy_ndarray(
np.random.rand(256, 32)
).storage
table = await db.create_table("test", pa.table({"vector": vectors}))
await table.create_index(
"vector", config=IvfPq(num_partitions=1, num_sub_vectors=1)
)
await table.add(pa.table({"vector": vectors}))
q = [1.0] * 32
plan = await table.query().nearest_to(q).explain_plan(True)
assert "LanceScan" in plan
plan = await table.query().nearest_to(q).fast_search().explain_plan(True)
assert "LanceScan" not in plan
def test_explain_plan(table): def test_explain_plan(table):
q = LanceVectorQueryBuilder(table, [0, 0], "vector") q = LanceVectorQueryBuilder(table, [0, 0], "vector")
plan = q.explain_plan(verbose=True) plan = q.explain_plan(verbose=True)

View File

@@ -1,96 +0,0 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import attrs
import numpy as np
import pandas as pd
import pyarrow as pa
import pytest
from aiohttp import web
from lancedb.remote.client import RestfulLanceDBClient, VectorQuery
@attrs.define
class MockLanceDBServer:
runner: web.AppRunner = attrs.field(init=False)
site: web.TCPSite = attrs.field(init=False)
async def query_handler(self, request: web.Request) -> web.Response:
table_name = request.match_info["table_name"]
assert table_name == "test_table"
await request.json()
# TODO: do some matching
vecs = pd.Series([np.random.rand(128) for x in range(10)], name="vector")
ids = pd.Series(range(10), name="id")
df = pd.DataFrame([vecs, ids]).T
batch = pa.RecordBatch.from_pandas(
df,
schema=pa.schema(
[
pa.field("vector", pa.list_(pa.float32(), 128)),
pa.field("id", pa.int64()),
]
),
)
sink = pa.BufferOutputStream()
with pa.ipc.new_file(sink, batch.schema) as writer:
writer.write_batch(batch)
return web.Response(body=sink.getvalue().to_pybytes())
async def setup(self):
app = web.Application()
app.add_routes([web.post("/table/{table_name}", self.query_handler)])
self.runner = web.AppRunner(app)
await self.runner.setup()
self.site = web.TCPSite(self.runner, "localhost", 8111)
async def start(self):
await self.site.start()
async def stop(self):
await self.runner.cleanup()
@pytest.mark.skip(reason="flaky somehow, fix later")
@pytest.mark.asyncio
async def test_e2e_with_mock_server():
mock_server = MockLanceDBServer()
await mock_server.setup()
await mock_server.start()
try:
with RestfulLanceDBClient("lancedb+http://localhost:8111") as client:
df = (
await client.query(
"test_table",
VectorQuery(
vector=np.random.rand(128).tolist(),
k=10,
_metric="L2",
columns=["id", "vector"],
),
)
).to_pandas()
assert "vector" in df.columns
assert "id" in df.columns
assert client.closed
finally:
# make sure we don't leak resources
await mock_server.stop()

View File

@@ -2,91 +2,19 @@
# SPDX-FileCopyrightText: Copyright The LanceDB Authors # SPDX-FileCopyrightText: Copyright The LanceDB Authors
import contextlib import contextlib
from datetime import timedelta
import http.server import http.server
import json
import threading import threading
from unittest.mock import MagicMock from unittest.mock import MagicMock
import uuid import uuid
import lancedb import lancedb
from lancedb.conftest import MockTextEmbeddingFunction
from lancedb.remote import ClientConfig
from lancedb.remote.errors import HttpError, RetryError from lancedb.remote.errors import HttpError, RetryError
import pyarrow as pa
from lancedb.remote.client import VectorQuery, VectorQueryResult
import pytest import pytest
import pyarrow as pa
class FakeLanceDBClient:
def close(self):
pass
def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
assert table_name == "test"
t = pa.schema([]).empty_table()
return VectorQueryResult(t)
def post(self, path: str):
pass
def mount_retry_adapter_for_table(self, table_name: str):
pass
def test_remote_db():
conn = lancedb.connect("db://client-will-be-injected", api_key="fake")
setattr(conn, "_client", FakeLanceDBClient())
table = conn["test"]
table.schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2))])
table.search([1.0, 2.0]).to_pandas()
def test_create_empty_table():
client = MagicMock()
conn = lancedb.connect("db://client-will-be-injected", api_key="fake")
conn._client = client
schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2))])
client.post.return_value = {"status": "ok"}
table = conn.create_table("test", schema=schema)
assert table.name == "test"
assert client.post.call_args[0][0] == "/v1/table/test/create/"
json_schema = {
"fields": [
{
"name": "vector",
"nullable": True,
"type": {
"type": "fixed_size_list",
"fields": [
{"name": "item", "nullable": True, "type": {"type": "float"}}
],
"length": 2,
},
},
]
}
client.post.return_value = {"schema": json_schema}
assert table.schema == schema
assert client.post.call_args[0][0] == "/v1/table/test/describe/"
client.post.return_value = 0
assert table.count_rows(None) == 0
def test_create_table_with_recordbatches():
client = MagicMock()
conn = lancedb.connect("db://client-will-be-injected", api_key="fake")
conn._client = client
batch = pa.RecordBatch.from_arrays([pa.array([[1.0, 2.0], [3.0, 4.0]])], ["vector"])
client.post.return_value = {"status": "ok"}
table = conn.create_table("test", [batch], schema=batch.schema)
assert table.name == "test"
assert client.post.call_args[0][0] == "/v1/table/test/create/"
def make_mock_http_handler(handler): def make_mock_http_handler(handler):
@@ -100,8 +28,35 @@ def make_mock_http_handler(handler):
return MockLanceDBHandler return MockLanceDBHandler
@contextlib.contextmanager
def mock_lancedb_connection(handler):
with http.server.HTTPServer(
("localhost", 8080), make_mock_http_handler(handler)
) as server:
handle = threading.Thread(target=server.serve_forever)
handle.start()
db = lancedb.connect(
"db://dev",
api_key="fake",
host_override="http://localhost:8080",
client_config={
"retry_config": {"retries": 2},
"timeout_config": {
"connect_timeout": 1,
},
},
)
try:
yield db
finally:
server.shutdown()
handle.join()
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def mock_lancedb_connection(handler): async def mock_lancedb_connection_async(handler):
with http.server.HTTPServer( with http.server.HTTPServer(
("localhost", 8080), make_mock_http_handler(handler) ("localhost", 8080), make_mock_http_handler(handler)
) as server: ) as server:
@@ -143,7 +98,7 @@ async def test_async_remote_db():
request.end_headers() request.end_headers()
request.wfile.write(b'{"tables": []}') request.wfile.write(b'{"tables": []}')
async with mock_lancedb_connection(handler) as db: async with mock_lancedb_connection_async(handler) as db:
table_names = await db.table_names() table_names = await db.table_names()
assert table_names == [] assert table_names == []
@@ -159,12 +114,12 @@ async def test_http_error():
request.end_headers() request.end_headers()
request.wfile.write(b"Internal Server Error") request.wfile.write(b"Internal Server Error")
async with mock_lancedb_connection(handler) as db: async with mock_lancedb_connection_async(handler) as db:
with pytest.raises(HttpError, match="Internal Server Error") as exc_info: with pytest.raises(HttpError) as exc_info:
await db.table_names() await db.table_names()
assert exc_info.value.request_id == request_id_holder["request_id"] assert exc_info.value.request_id == request_id_holder["request_id"]
assert exc_info.value.status_code == 507 assert "Internal Server Error" in str(exc_info.value)
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -178,15 +133,225 @@ async def test_retry_error():
request.end_headers() request.end_headers()
request.wfile.write(b"Try again later") request.wfile.write(b"Try again later")
async with mock_lancedb_connection(handler) as db: async with mock_lancedb_connection_async(handler) as db:
with pytest.raises(RetryError, match="Hit retry limit") as exc_info: with pytest.raises(RetryError) as exc_info:
await db.table_names() await db.table_names()
assert exc_info.value.request_id == request_id_holder["request_id"] assert exc_info.value.request_id == request_id_holder["request_id"]
assert exc_info.value.status_code == 429
cause = exc_info.value.__cause__ cause = exc_info.value.__cause__
assert isinstance(cause, HttpError) assert isinstance(cause, HttpError)
assert "Try again later" in str(cause) assert "Try again later" in str(cause)
assert cause.request_id == request_id_holder["request_id"] assert cause.request_id == request_id_holder["request_id"]
assert cause.status_code == 429 assert cause.status_code == 429
@contextlib.contextmanager
def query_test_table(query_handler):
def handler(request):
if request.path == "/v1/table/test/describe/":
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
request.wfile.write(b"{}")
elif request.path == "/v1/table/test/query/":
content_len = int(request.headers.get("Content-Length"))
body = request.rfile.read(content_len)
body = json.loads(body)
data = query_handler(body)
request.send_response(200)
request.send_header("Content-Type", "application/vnd.apache.arrow.file")
request.end_headers()
with pa.ipc.new_file(request.wfile, schema=data.schema) as f:
f.write_table(data)
else:
request.send_response(404)
request.end_headers()
with mock_lancedb_connection(handler) as db:
assert repr(db) == "RemoteConnect(name=dev)"
table = db.open_table("test")
assert repr(table) == "RemoteTable(dev.test)"
yield table
def test_query_sync_minimal():
def handler(body):
assert body == {
"distance_type": "l2",
"k": 10,
"prefilter": False,
"refine_factor": None,
"vector": [1.0, 2.0, 3.0],
"nprobes": 20,
}
return pa.table({"id": [1, 2, 3]})
with query_test_table(handler) as table:
data = table.search([1, 2, 3]).to_list()
expected = [{"id": 1}, {"id": 2}, {"id": 3}]
assert data == expected
def test_query_sync_maximal():
def handler(body):
assert body == {
"distance_type": "cosine",
"k": 42,
"prefilter": True,
"refine_factor": 10,
"vector": [1.0, 2.0, 3.0],
"nprobes": 5,
"filter": "id > 0",
"columns": ["id", "name"],
"vector_column": "vector2",
"fast_search": True,
"with_row_id": True,
}
return pa.table({"id": [1, 2, 3], "name": ["a", "b", "c"]})
with query_test_table(handler) as table:
(
table.search([1, 2, 3], vector_column_name="vector2", fast_search=True)
.metric("cosine")
.limit(42)
.refine_factor(10)
.nprobes(5)
.where("id > 0", prefilter=True)
.with_row_id(True)
.select(["id", "name"])
.to_list()
)
def test_query_sync_fts():
def handler(body):
assert body == {
"full_text_query": {
"query": "puppy",
"columns": [],
},
"k": 10,
"vector": [],
}
return pa.table({"id": [1, 2, 3]})
with query_test_table(handler) as table:
(table.search("puppy", query_type="fts").to_list())
def handler(body):
assert body == {
"full_text_query": {
"query": "puppy",
"columns": ["name", "description"],
},
"k": 42,
"vector": [],
"with_row_id": True,
}
return pa.table({"id": [1, 2, 3]})
with query_test_table(handler) as table:
(
table.search("puppy", query_type="fts", fts_columns=["name", "description"])
.with_row_id(True)
.limit(42)
.to_list()
)
def test_query_sync_hybrid():
def handler(body):
if "full_text_query" in body:
# FTS query
assert body == {
"full_text_query": {
"query": "puppy",
"columns": [],
},
"k": 42,
"vector": [],
"with_row_id": True,
}
return pa.table({"_rowid": [1, 2, 3], "_score": [0.1, 0.2, 0.3]})
else:
# Vector query
assert body == {
"distance_type": "l2",
"k": 42,
"prefilter": False,
"refine_factor": None,
"vector": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
"nprobes": 20,
"with_row_id": True,
}
return pa.table({"_rowid": [1, 2, 3], "_distance": [0.1, 0.2, 0.3]})
with query_test_table(handler) as table:
embedding_func = MockTextEmbeddingFunction()
embedding_config = MagicMock()
embedding_config.function = embedding_func
embedding_funcs = MagicMock()
embedding_funcs.get = MagicMock(return_value=embedding_config)
table.embedding_functions = embedding_funcs
(table.search("puppy", query_type="hybrid").limit(42).to_list())
def test_create_client():
mandatory_args = {
"uri": "db://dev",
"api_key": "fake-api-key",
"region": "us-east-1",
}
db = lancedb.connect(**mandatory_args)
assert isinstance(db.client_config, ClientConfig)
db = lancedb.connect(**mandatory_args, client_config={})
assert isinstance(db.client_config, ClientConfig)
db = lancedb.connect(
**mandatory_args,
client_config=ClientConfig(timeout_config={"connect_timeout": 42}),
)
assert isinstance(db.client_config, ClientConfig)
assert db.client_config.timeout_config.connect_timeout == timedelta(seconds=42)
db = lancedb.connect(
**mandatory_args,
client_config={"timeout_config": {"connect_timeout": timedelta(seconds=42)}},
)
assert isinstance(db.client_config, ClientConfig)
assert db.client_config.timeout_config.connect_timeout == timedelta(seconds=42)
db = lancedb.connect(
**mandatory_args, client_config=ClientConfig(retry_config={"retries": 42})
)
assert isinstance(db.client_config, ClientConfig)
assert db.client_config.retry_config.retries == 42
db = lancedb.connect(
**mandatory_args, client_config={"retry_config": {"retries": 42}}
)
assert isinstance(db.client_config, ClientConfig)
assert db.client_config.retry_config.retries == 42
with pytest.warns(DeprecationWarning):
db = lancedb.connect(**mandatory_args, connection_timeout=42)
assert db.client_config.timeout_config.connect_timeout == timedelta(seconds=42)
with pytest.warns(DeprecationWarning):
db = lancedb.connect(**mandatory_args, read_timeout=42)
assert db.client_config.timeout_config.read_timeout == timedelta(seconds=42)
with pytest.warns(DeprecationWarning):
lancedb.connect(**mandatory_args, request_thread_pool=10)

View File

@@ -170,6 +170,17 @@ impl Connection {
}) })
} }
pub fn rename_table(
self_: PyRef<'_, Self>,
old_name: String,
new_name: String,
) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.get_inner()?.clone();
future_into_py(self_.py(), async move {
inner.rename_table(old_name, new_name).await.infer_error()
})
}
pub fn drop_table(self_: PyRef<'_, Self>, name: String) -> PyResult<Bound<'_, PyAny>> { pub fn drop_table(self_: PyRef<'_, Self>, name: String) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.get_inner()?.clone(); let inner = self_.get_inner()?.clone();
future_into_py(self_.py(), async move { future_into_py(self_.py(), async move {

View File

@@ -24,8 +24,8 @@ use lancedb::{
DistanceType, DistanceType,
}; };
use pyo3::{ use pyo3::{
exceptions::{PyRuntimeError, PyValueError}, exceptions::{PyKeyError, PyRuntimeError, PyValueError},
pyclass, pymethods, PyResult, pyclass, pymethods, IntoPy, PyObject, PyResult, Python,
}; };
use crate::util::parse_distance_type; use crate::util::parse_distance_type;
@@ -236,7 +236,21 @@ pub struct IndexConfig {
#[pymethods] #[pymethods]
impl IndexConfig { impl IndexConfig {
pub fn __repr__(&self) -> String { pub fn __repr__(&self) -> String {
format!("Index({}, columns={:?})", self.index_type, self.columns) format!(
"Index({}, columns={:?}, name=\"{}\")",
self.index_type, self.columns, self.name
)
}
// For backwards-compatibility with the old sync SDK, we also support getting
// attributes via __getitem__.
pub fn __getitem__(&self, key: String, py: Python<'_>) -> PyResult<PyObject> {
match key.as_str() {
"index_type" => Ok(self.index_type.clone().into_py(py)),
"columns" => Ok(self.columns.clone().into_py(py)),
"name" | "index_name" => Ok(self.name.clone().into_py(py)),
_ => Err(PyKeyError::new_err(format!("Invalid key: {}", key))),
}
} }
} }

View File

@@ -68,6 +68,18 @@ impl Query {
self.inner = self.inner.clone().offset(offset as usize); self.inner = self.inner.clone().offset(offset as usize);
} }
pub fn fast_search(&mut self) {
self.inner = self.inner.clone().fast_search();
}
pub fn with_row_id(&mut self) {
self.inner = self.inner.clone().with_row_id();
}
pub fn postfilter(&mut self) {
self.inner = self.inner.clone().postfilter();
}
pub fn nearest_to(&mut self, vector: Bound<'_, PyAny>) -> PyResult<VectorQuery> { pub fn nearest_to(&mut self, vector: Bound<'_, PyAny>) -> PyResult<VectorQuery> {
let data: ArrayData = ArrayData::from_pyarrow_bound(&vector)?; let data: ArrayData = ArrayData::from_pyarrow_bound(&vector)?;
let array = make_array(data); let array = make_array(data);
@@ -146,6 +158,14 @@ impl VectorQuery {
self.inner = self.inner.clone().offset(offset as usize); self.inner = self.inner.clone().offset(offset as usize);
} }
pub fn fast_search(&mut self) {
self.inner = self.inner.clone().fast_search();
}
pub fn with_row_id(&mut self) {
self.inner = self.inner.clone().with_row_id();
}
pub fn column(&mut self, column: String) { pub fn column(&mut self, column: String) {
self.inner = self.inner.clone().column(&column); self.inner = self.inner.clone().column(&column);
} }

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "lancedb-node" name = "lancedb-node"
version = "0.11.1-beta.1" version = "0.12.0"
description = "Serverless, low-latency vector database for AI applications" description = "Serverless, low-latency vector database for AI applications"
license.workspace = true license.workspace = true
edition.workspace = true edition.workspace = true

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "lancedb" name = "lancedb"
version = "0.11.1-beta.1" version = "0.12.0"
edition.workspace = true edition.workspace = true
description = "LanceDB: A serverless, low-latency vector database for AI applications" description = "LanceDB: A serverless, low-latency vector database for AI applications"
license.workspace = true license.workspace = true

View File

@@ -39,9 +39,6 @@ use crate::utils::validate_table_name;
use crate::Table; use crate::Table;
pub use lance_encoding::version::LanceFileVersion; pub use lance_encoding::version::LanceFileVersion;
#[cfg(feature = "remote")]
use log::warn;
pub const LANCE_FILE_EXTENSION: &str = "lance"; pub const LANCE_FILE_EXTENSION: &str = "lance";
pub type TableBuilderCallback = Box<dyn FnOnce(OpenTableBuilder) -> OpenTableBuilder + Send>; pub type TableBuilderCallback = Box<dyn FnOnce(OpenTableBuilder) -> OpenTableBuilder + Send>;
@@ -719,8 +716,7 @@ impl ConnectBuilder {
let api_key = self.api_key.ok_or_else(|| Error::InvalidInput { let api_key = self.api_key.ok_or_else(|| Error::InvalidInput {
message: "An api_key is required when connecting to LanceDb Cloud".to_string(), message: "An api_key is required when connecting to LanceDb Cloud".to_string(),
})?; })?;
// TODO: remove this warning when the remote client is ready
warn!("The rust implementation of the remote client is not yet ready for use.");
let internal = Arc::new(crate::remote::db::RemoteDatabase::try_new( let internal = Arc::new(crate::remote::db::RemoteDatabase::try_new(
&self.uri, &self.uri,
&api_key, &api_key,

View File

@@ -119,6 +119,7 @@ pub enum IndexType {
#[serde(alias = "LABEL_LIST")] #[serde(alias = "LABEL_LIST")]
LabelList, LabelList,
// FTS // FTS
#[serde(alias = "INVERTED", alias = "Inverted")]
FTS, FTS,
} }

View File

@@ -403,6 +403,26 @@ pub trait QueryBase {
/// By default, it is false. /// By default, it is false.
fn fast_search(self) -> Self; fn fast_search(self) -> Self;
/// If this is called then filtering will happen after the vector search instead of
/// before.
///
/// By default filtering will be performed before the vector search. This is how
/// filtering is typically understood to work. This prefilter step does add some
/// additional latency. Creating a scalar index on the filter column(s) can
/// often improve this latency. However, sometimes a filter is too complex or scalar
/// indices cannot be applied to the column. In these cases postfiltering can be
/// used instead of prefiltering to improve latency.
///
/// Post filtering applies the filter to the results of the vector search. This means
/// we only run the filter on a much smaller set of data. However, it can cause the
/// query to return fewer than `limit` results (or even no results) if none of the nearest
/// results match the filter.
///
/// Post filtering happens during the "refine stage" (described in more detail in
/// [`Self::refine_factor`]). This means that setting a higher refine factor can often
/// help restore some of the results lost by post filtering.
fn postfilter(self) -> Self;
/// Return the `_rowid` meta column from the Table. /// Return the `_rowid` meta column from the Table.
fn with_row_id(self) -> Self; fn with_row_id(self) -> Self;
} }
@@ -442,6 +462,11 @@ impl<T: HasQuery> QueryBase for T {
self self
} }
fn postfilter(mut self) -> Self {
self.mut_query().prefilter = false;
self
}
fn with_row_id(mut self) -> Self { fn with_row_id(mut self) -> Self {
self.mut_query().with_row_id = true; self.mut_query().with_row_id = true;
self self
@@ -561,6 +586,9 @@ pub struct Query {
/// ///
/// By default, this is false. /// By default, this is false.
pub(crate) with_row_id: bool, pub(crate) with_row_id: bool,
/// If set to false, the filter will be applied after the vector search.
pub(crate) prefilter: bool,
} }
impl Query { impl Query {
@@ -574,6 +602,7 @@ impl Query {
select: Select::All, select: Select::All,
fast_search: false, fast_search: false,
with_row_id: false, with_row_id: false,
prefilter: true,
} }
} }
@@ -678,8 +707,6 @@ pub struct VectorQuery {
pub(crate) distance_type: Option<DistanceType>, pub(crate) distance_type: Option<DistanceType>,
/// Default is true. Set to false to enforce a brute force search. /// Default is true. Set to false to enforce a brute force search.
pub(crate) use_index: bool, pub(crate) use_index: bool,
/// Apply filter before ANN search/
pub(crate) prefilter: bool,
} }
impl VectorQuery { impl VectorQuery {
@@ -692,7 +719,6 @@ impl VectorQuery {
refine_factor: None, refine_factor: None,
distance_type: None, distance_type: None,
use_index: true, use_index: true,
prefilter: true,
} }
} }
@@ -782,29 +808,6 @@ impl VectorQuery {
self self
} }
/// If this is called then filtering will happen after the vector search instead of
/// before.
///
/// By default filtering will be performed before the vector search. This is how
/// filtering is typically understood to work. This prefilter step does add some
/// additional latency. Creating a scalar index on the filter column(s) can
/// often improve this latency. However, sometimes a filter is too complex or scalar
/// indices cannot be applied to the column. In these cases postfiltering can be
/// used instead of prefiltering to improve latency.
///
/// Post filtering applies the filter to the results of the vector search. This means
/// we only run the filter on a much smaller set of data. However, it can cause the
/// query to return fewer than `limit` results (or even no results) if none of the nearest
/// results match the filter.
///
/// Post filtering happens during the "refine stage" (described in more detail in
/// [`Self::refine_factor`]). This means that setting a higher refine factor can often
/// help restore some of the results lost by post filtering.
pub fn postfilter(mut self) -> Self {
self.prefilter = false;
self
}
/// If this is called then any vector index is skipped /// If this is called then any vector index is skipped
/// ///
/// An exhaustive (flat) search will be performed. The query vector will /// An exhaustive (flat) search will be performed. The query vector will

View File

@@ -23,6 +23,8 @@ pub(crate) mod table;
pub(crate) mod util; pub(crate) mod util;
const ARROW_STREAM_CONTENT_TYPE: &str = "application/vnd.apache.arrow.stream"; const ARROW_STREAM_CONTENT_TYPE: &str = "application/vnd.apache.arrow.stream";
#[cfg(test)]
const ARROW_FILE_CONTENT_TYPE: &str = "application/vnd.apache.arrow.file";
const JSON_CONTENT_TYPE: &str = "application/json"; const JSON_CONTENT_TYPE: &str = "application/json";
pub use client::{ClientConfig, RetryConfig, TimeoutConfig}; pub use client::{ClientConfig, RetryConfig, TimeoutConfig};

View File

@@ -341,7 +341,22 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
request_id request_id
}; };
debug!("Sending request_id={}: {:?}", request_id, &request); if log::log_enabled!(log::Level::Debug) {
let content_type = request
.headers()
.get("content-type")
.map(|v| v.to_str().unwrap());
if content_type == Some("application/json") {
let body = request.body().as_ref().unwrap().as_bytes().unwrap();
let body = String::from_utf8_lossy(body);
debug!(
"Sending request_id={}: {:?} with body {}",
request_id, request, body
);
} else {
debug!("Sending request_id={}: {:?}", request_id, request);
}
}
if with_retry { if with_retry {
self.send_with_retry_impl(client, request, request_id).await self.send_with_retry_impl(client, request, request_id).await

View File

@@ -161,7 +161,7 @@ impl<S: HttpSend> ConnectionInternal for RemoteDatabase<S> {
if self.table_cache.get(&options.name).is_none() { if self.table_cache.get(&options.name).is_none() {
let req = self let req = self
.client .client
.get(&format!("/v1/table/{}/describe/", options.name)); .post(&format!("/v1/table/{}/describe/", options.name));
let (request_id, resp) = self.client.send(req, true).await?; let (request_id, resp) = self.client.send(req, true).await?;
if resp.status() == StatusCode::NOT_FOUND { if resp.status() == StatusCode::NOT_FOUND {
return Err(crate::Error::TableNotFound { name: options.name }); return Err(crate::Error::TableNotFound { name: options.name });
@@ -301,7 +301,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_open_table() { async fn test_open_table() {
let conn = Connection::new_with_handler(|request| { let conn = Connection::new_with_handler(|request| {
assert_eq!(request.method(), &reqwest::Method::GET); assert_eq!(request.method(), &reqwest::Method::POST);
assert_eq!(request.url().path(), "/v1/table/table1/describe/"); assert_eq!(request.url().path(), "/v1/table/table1/describe/");
assert_eq!(request.url().query(), None); assert_eq!(request.url().query(), None);

View File

@@ -1,3 +1,4 @@
use std::io::Cursor;
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
use crate::index::Index; use crate::index::Index;
@@ -7,10 +8,9 @@ use crate::table::AddDataMode;
use crate::utils::{supported_btree_data_type, supported_vector_data_type}; use crate::utils::{supported_btree_data_type, supported_vector_data_type};
use crate::Error; use crate::Error;
use arrow_array::RecordBatchReader; use arrow_array::RecordBatchReader;
use arrow_ipc::reader::StreamReader; use arrow_ipc::reader::FileReader;
use arrow_schema::{DataType, SchemaRef}; use arrow_schema::{DataType, SchemaRef};
use async_trait::async_trait; use async_trait::async_trait;
use bytes::Buf;
use datafusion_common::DataFusionError; use datafusion_common::DataFusionError;
use datafusion_physical_plan::stream::RecordBatchStreamAdapter; use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
use datafusion_physical_plan::{ExecutionPlan, SendableRecordBatchStream}; use datafusion_physical_plan::{ExecutionPlan, SendableRecordBatchStream};
@@ -115,39 +115,14 @@ impl<S: HttpSend> RemoteTable<S> {
async fn read_arrow_stream( async fn read_arrow_stream(
&self, &self,
request_id: &str, request_id: &str,
body: reqwest::Response, response: reqwest::Response,
) -> Result<SendableRecordBatchStream> { ) -> Result<SendableRecordBatchStream> {
// Assert that the content type is correct let response = self.check_table_response(request_id, response).await?;
let content_type = body
.headers()
.get(CONTENT_TYPE)
.ok_or_else(|| Error::Http {
source: "Missing content type".into(),
request_id: request_id.to_string(),
status_code: None,
})?
.to_str()
.map_err(|e| Error::Http {
source: format!("Failed to parse content type: {}", e).into(),
request_id: request_id.to_string(),
status_code: None,
})?;
if content_type != ARROW_STREAM_CONTENT_TYPE {
return Err(Error::Http {
source: format!(
"Expected content type {}, got {}",
ARROW_STREAM_CONTENT_TYPE, content_type
)
.into(),
request_id: request_id.to_string(),
status_code: None,
});
}
// There isn't a way to actually stream this data yet. I have an upstream issue: // There isn't a way to actually stream this data yet. I have an upstream issue:
// https://github.com/apache/arrow-rs/issues/6420 // https://github.com/apache/arrow-rs/issues/6420
let body = body.bytes().await.err_to_http(request_id.into())?; let body = response.bytes().await.err_to_http(request_id.into())?;
let reader = StreamReader::try_new(body.reader(), None)?; let reader = FileReader::try_new(Cursor::new(body), None)?;
let schema = reader.schema(); let schema = reader.schema();
let stream = futures::stream::iter(reader).map_err(DataFusionError::from); let stream = futures::stream::iter(reader).map_err(DataFusionError::from);
Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream))) Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
@@ -192,6 +167,10 @@ impl<S: HttpSend> RemoteTable<S> {
body["fast_search"] = serde_json::Value::Bool(true); body["fast_search"] = serde_json::Value::Bool(true);
} }
if params.with_row_id {
body["with_row_id"] = serde_json::Value::Bool(true);
}
if let Some(full_text_search) = &params.full_text_search { if let Some(full_text_search) = &params.full_text_search {
if full_text_search.wand_factor.is_some() { if full_text_search.wand_factor.is_some() {
return Err(Error::NotSupported { return Err(Error::NotSupported {
@@ -277,7 +256,7 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
.post(&format!("/v1/table/{}/count_rows/", self.name)); .post(&format!("/v1/table/{}/count_rows/", self.name));
if let Some(filter) = filter { if let Some(filter) = filter {
request = request.json(&serde_json::json!({ "filter": filter })); request = request.json(&serde_json::json!({ "predicate": filter }));
} else { } else {
request = request.json(&serde_json::json!({})); request = request.json(&serde_json::json!({}));
} }
@@ -330,13 +309,13 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
let mut body = serde_json::Value::Object(Default::default()); let mut body = serde_json::Value::Object(Default::default());
Self::apply_query_params(&mut body, &query.base)?; Self::apply_query_params(&mut body, &query.base)?;
body["prefilter"] = query.prefilter.into(); body["prefilter"] = query.base.prefilter.into();
body["distance_type"] = serde_json::json!(query.distance_type.unwrap_or_default()); body["distance_type"] = serde_json::json!(query.distance_type.unwrap_or_default());
body["nprobes"] = query.nprobes.into(); body["nprobes"] = query.nprobes.into();
body["refine_factor"] = query.refine_factor.into(); body["refine_factor"] = query.refine_factor.into();
if let Some(vector) = query.query_vector.as_ref() { let vector: Vec<f32> = if let Some(vector) = query.query_vector.as_ref() {
let vector: Vec<f32> = match vector.data_type() { match vector.data_type() {
DataType::Float32 => vector DataType::Float32 => vector
.as_any() .as_any()
.downcast_ref::<arrow_array::Float32Array>() .downcast_ref::<arrow_array::Float32Array>()
@@ -350,9 +329,12 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
message: "VectorQuery vector must be of type Float32".into(), message: "VectorQuery vector must be of type Float32".into(),
}) })
} }
}
} else {
// Server takes empty vector, not null or undefined.
Vec::new()
}; };
body["vector"] = serde_json::json!(vector); body["vector"] = serde_json::json!(vector);
}
if let Some(vector_column) = query.column.as_ref() { if let Some(vector_column) = query.column.as_ref() {
body["vector_column"] = serde_json::Value::String(vector_column.clone()); body["vector_column"] = serde_json::Value::String(vector_column.clone());
@@ -383,6 +365,8 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
let mut body = serde_json::Value::Object(Default::default()); let mut body = serde_json::Value::Object(Default::default());
Self::apply_query_params(&mut body, query)?; Self::apply_query_params(&mut body, query)?;
// Empty vector can be passed if no vector search is performed.
body["vector"] = serde_json::Value::Array(Vec::new());
let request = request.json(&body); let request = request.json(&body);
@@ -399,30 +383,19 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
let mut updates = Vec::new(); let mut updates = Vec::new();
for (column, expression) in update.columns { for (column, expression) in update.columns {
updates.push(column); updates.push(vec![column, expression]);
updates.push(expression);
} }
let request = request.json(&serde_json::json!({ let request = request.json(&serde_json::json!({
"updates": updates, "updates": updates,
"only_if": update.filter, "predicate": update.filter,
})); }));
let (request_id, response) = self.client.send(request, false).await?; let (request_id, response) = self.client.send(request, false).await?;
let response = self.check_table_response(&request_id, response).await?; self.check_table_response(&request_id, response).await?;
let body = response.text().await.err_to_http(request_id.clone())?; Ok(0) // TODO: support returning number of modified rows once supported in SaaS.
serde_json::from_str(&body).map_err(|e| Error::Http {
source: format!(
"Failed to parse updated rows result from response {}: {}",
body, e
)
.into(),
request_id,
status_code: None,
})
} }
async fn delete(&self, predicate: &str) -> Result<()> { async fn delete(&self, predicate: &str) -> Result<()> {
let body = serde_json::json!({ "predicate": predicate }); let body = serde_json::json!({ "predicate": predicate });
@@ -691,6 +664,7 @@ mod tests {
use crate::{ use crate::{
index::{vector::IvfPqIndexBuilder, Index, IndexStatistics, IndexType}, index::{vector::IvfPqIndexBuilder, Index, IndexStatistics, IndexType},
query::{ExecutableQuery, QueryBase}, query::{ExecutableQuery, QueryBase},
remote::ARROW_FILE_CONTENT_TYPE,
DistanceType, Error, Table, DistanceType, Error, Table,
}; };
@@ -804,7 +778,7 @@ mod tests {
); );
assert_eq!( assert_eq!(
request.body().unwrap().as_bytes().unwrap(), request.body().unwrap().as_bytes().unwrap(),
br#"{"filter":"a > 10"}"# br#"{"predicate":"a > 10"}"#
); );
http::Response::builder().status(200).body("42").unwrap() http::Response::builder().status(200).body("42").unwrap()
@@ -839,6 +813,17 @@ mod tests {
body body
} }
fn write_ipc_file(data: &RecordBatch) -> Vec<u8> {
let mut body = Vec::new();
{
let mut writer = arrow_ipc::writer::FileWriter::try_new(&mut body, &data.schema())
.expect("Failed to create writer");
writer.write(data).expect("Failed to write data");
writer.finish().expect("Failed to finish");
}
body
}
#[tokio::test] #[tokio::test]
async fn test_add_append() { async fn test_add_append() {
let data = RecordBatch::try_new( let data = RecordBatch::try_new(
@@ -947,21 +932,27 @@ mod tests {
let updates = value.get("updates").unwrap().as_array().unwrap(); let updates = value.get("updates").unwrap().as_array().unwrap();
assert!(updates.len() == 2); assert!(updates.len() == 2);
let col_name = updates[0].as_str().unwrap(); let col_name = updates[0][0].as_str().unwrap();
let expression = updates[1].as_str().unwrap(); let expression = updates[0][1].as_str().unwrap();
assert_eq!(col_name, "a"); assert_eq!(col_name, "a");
assert_eq!(expression, "a + 1"); assert_eq!(expression, "a + 1");
let only_if = value.get("only_if").unwrap().as_str().unwrap(); let col_name = updates[1][0].as_str().unwrap();
let expression = updates[1][1].as_str().unwrap();
assert_eq!(col_name, "b");
assert_eq!(expression, "b - 1");
let only_if = value.get("predicate").unwrap().as_str().unwrap();
assert_eq!(only_if, "b > 10"); assert_eq!(only_if, "b > 10");
} }
http::Response::builder().status(200).body("1").unwrap() http::Response::builder().status(200).body("{}").unwrap()
}); });
table table
.update() .update()
.column("a", "a + 1") .column("a", "a + 1")
.column("b", "b - 1")
.only_if("b > 10") .only_if("b > 10")
.execute() .execute()
.await .await
@@ -1092,10 +1083,10 @@ mod tests {
expected_body["vector"] = vec![0.1f32, 0.2, 0.3].into(); expected_body["vector"] = vec![0.1f32, 0.2, 0.3].into();
assert_eq!(body, expected_body); assert_eq!(body, expected_body);
let response_body = write_ipc_stream(&expected_data_ref); let response_body = write_ipc_file(&expected_data_ref);
http::Response::builder() http::Response::builder()
.status(200) .status(200)
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE) .header(CONTENT_TYPE, ARROW_FILE_CONTENT_TYPE)
.body(response_body) .body(response_body)
.unwrap() .unwrap()
}); });
@@ -1142,10 +1133,10 @@ mod tests {
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
) )
.unwrap(); .unwrap();
let response_body = write_ipc_stream(&data); let response_body = write_ipc_file(&data);
http::Response::builder() http::Response::builder()
.status(200) .status(200)
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE) .header(CONTENT_TYPE, ARROW_FILE_CONTENT_TYPE)
.body(response_body) .body(response_body)
.unwrap() .unwrap()
}); });
@@ -1185,6 +1176,8 @@ mod tests {
"query": "hello world", "query": "hello world",
}, },
"k": 10, "k": 10,
"vector": [],
"with_row_id": true,
}); });
assert_eq!(body, expected_body); assert_eq!(body, expected_body);
@@ -1193,10 +1186,10 @@ mod tests {
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
) )
.unwrap(); .unwrap();
let response_body = write_ipc_stream(&data); let response_body = write_ipc_file(&data);
http::Response::builder() http::Response::builder()
.status(200) .status(200)
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE) .header(CONTENT_TYPE, ARROW_FILE_CONTENT_TYPE)
.body(response_body) .body(response_body)
.unwrap() .unwrap()
}); });
@@ -1207,6 +1200,7 @@ mod tests {
FullTextSearchQuery::new("hello world".into()) FullTextSearchQuery::new("hello world".into())
.columns(Some(vec!["a".into(), "b".into()])), .columns(Some(vec!["a".into(), "b".into()])),
) )
.with_row_id()
.limit(10) .limit(10)
.execute() .execute()
.await .await

View File

@@ -1842,7 +1842,7 @@ impl TableInternal for NativeTable {
scanner.nprobs(query.nprobes); scanner.nprobs(query.nprobes);
scanner.use_index(query.use_index); scanner.use_index(query.use_index);
scanner.prefilter(query.prefilter); scanner.prefilter(query.base.prefilter);
match query.base.select { match query.base.select {
Select::Columns(ref columns) => { Select::Columns(ref columns) => {
scanner.project(columns.as_slice())?; scanner.project(columns.as_slice())?;
@@ -3123,6 +3123,12 @@ mod tests {
assert_eq!(index.index_type, crate::index::IndexType::FTS); assert_eq!(index.index_type, crate::index::IndexType::FTS);
assert_eq!(index.columns, vec!["text".to_string()]); assert_eq!(index.columns, vec!["text".to_string()]);
assert_eq!(index.name, "text_idx"); assert_eq!(index.name, "text_idx");
let stats = table.index_stats("text_idx").await.unwrap().unwrap();
assert_eq!(stats.num_indexed_rows, num_rows);
assert_eq!(stats.num_unindexed_rows, 0);
assert_eq!(stats.index_type, crate::index::IndexType::FTS);
assert_eq!(stats.distance_type, None);
} }
#[tokio::test] #[tokio::test]