mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-23 05:19:58 +00:00
Compare commits
16 Commits
python-v0.
...
v0.13.0-be
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
975398c3a8 | ||
|
|
08d5f93f34 | ||
|
|
91cab3b556 | ||
|
|
c61bfc3af8 | ||
|
|
4e8c7b0adf | ||
|
|
26f4a80e10 | ||
|
|
3604d20ad3 | ||
|
|
9708d829a9 | ||
|
|
059c9794b5 | ||
|
|
15ed7f75a0 | ||
|
|
96181ab421 | ||
|
|
f3fc339ef6 | ||
|
|
113cd6995b | ||
|
|
02535bdc88 | ||
|
|
facc7d61c0 | ||
|
|
f947259f16 |
@@ -1,5 +1,5 @@
|
||||
[tool.bumpversion]
|
||||
current_version = "0.11.1-beta.1"
|
||||
current_version = "0.13.0-beta.0"
|
||||
parse = """(?x)
|
||||
(?P<major>0|[1-9]\\d*)\\.
|
||||
(?P<minor>0|[1-9]\\d*)\\.
|
||||
|
||||
16
Cargo.toml
16
Cargo.toml
@@ -21,13 +21,15 @@ categories = ["database-implementations"]
|
||||
rust-version = "1.80.0" # TODO: lower this once we upgrade Lance again.
|
||||
|
||||
[workspace.dependencies]
|
||||
lance = { "version" = "=0.19.1", "features" = ["dynamodb"] }
|
||||
lance-index = { "version" = "=0.19.1" }
|
||||
lance-linalg = { "version" = "=0.19.1" }
|
||||
lance-table = { "version" = "=0.19.1" }
|
||||
lance-testing = { "version" = "=0.19.1" }
|
||||
lance-datafusion = { "version" = "=0.19.1" }
|
||||
lance-encoding = { "version" = "=0.19.1" }
|
||||
lance = { "version" = "=0.19.2", "features" = [
|
||||
"dynamodb",
|
||||
], git = "https://github.com/lancedb/lance.git", tag = "v0.19.2-beta.3" }
|
||||
lance-index = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2-beta.3" }
|
||||
lance-linalg = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2-beta.3" }
|
||||
lance-table = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2-beta.3" }
|
||||
lance-testing = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2-beta.3" }
|
||||
lance-datafusion = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2-beta.3" }
|
||||
lance-encoding = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2-beta.3" }
|
||||
# Note that this one does not include pyarrow
|
||||
arrow = { version = "52.2", optional = false }
|
||||
arrow-array = "52.2"
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
<parent>
|
||||
<groupId>com.lancedb</groupId>
|
||||
<artifactId>lancedb-parent</artifactId>
|
||||
<version>0.11.1-beta.1</version>
|
||||
<version>0.13.0-beta.0</version>
|
||||
<relativePath>../pom.xml</relativePath>
|
||||
</parent>
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
|
||||
<groupId>com.lancedb</groupId>
|
||||
<artifactId>lancedb-parent</artifactId>
|
||||
<version>0.11.1-beta.1</version>
|
||||
<version>0.13.0-beta.0</version>
|
||||
<packaging>pom</packaging>
|
||||
|
||||
<name>LanceDB Parent</name>
|
||||
|
||||
49
node/package-lock.json
generated
49
node/package-lock.json
generated
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"name": "vectordb",
|
||||
"version": "0.11.1-beta.1",
|
||||
"version": "0.12.0",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "vectordb",
|
||||
"version": "0.11.1-beta.1",
|
||||
"version": "0.12.0",
|
||||
"cpu": [
|
||||
"x64",
|
||||
"arm64"
|
||||
@@ -52,11 +52,11 @@
|
||||
"uuid": "^9.0.0"
|
||||
},
|
||||
"optionalDependencies": {
|
||||
"@lancedb/vectordb-darwin-arm64": "0.11.1-beta.1",
|
||||
"@lancedb/vectordb-darwin-x64": "0.11.1-beta.1",
|
||||
"@lancedb/vectordb-linux-arm64-gnu": "0.11.1-beta.1",
|
||||
"@lancedb/vectordb-linux-x64-gnu": "0.11.1-beta.1",
|
||||
"@lancedb/vectordb-win32-x64-msvc": "0.11.1-beta.1"
|
||||
"@lancedb/vectordb-darwin-arm64": "0.12.0",
|
||||
"@lancedb/vectordb-darwin-x64": "0.12.0",
|
||||
"@lancedb/vectordb-linux-arm64-gnu": "0.12.0",
|
||||
"@lancedb/vectordb-linux-x64-gnu": "0.12.0",
|
||||
"@lancedb/vectordb-win32-x64-msvc": "0.12.0"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@apache-arrow/ts": "^14.0.2",
|
||||
@@ -327,65 +327,60 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@lancedb/vectordb-darwin-arm64": {
|
||||
"version": "0.11.1-beta.1",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.11.1-beta.1.tgz",
|
||||
"integrity": "sha512-q9jcCbmcz45UHmjgecL6zK82WaqUJsARfniwXXPcnd8ooISVhPkgN+RVKv6edwI9T0PV+xVRYq+LQLlZu5fyxw==",
|
||||
"version": "0.12.0",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.12.0.tgz",
|
||||
"integrity": "sha512-9X6UyP/ozHkv39YZ8DWh82m3aeQmUtrVDNuRe3o8has6dJyD/qPYukI8Zked4q8J+86/lgQbr4f+WW2V4Dfc1g==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
"license": "Apache-2.0",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"darwin"
|
||||
]
|
||||
},
|
||||
"node_modules/@lancedb/vectordb-darwin-x64": {
|
||||
"version": "0.11.1-beta.1",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.11.1-beta.1.tgz",
|
||||
"integrity": "sha512-E5tCTS5TaTkssTPa+gdnFxZJ1f60jnSIJXhqufNFZk4s+IMViwR1BPqaqE++WY5c1uBI55ef1862CROKDKX4gg==",
|
||||
"version": "0.12.0",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.12.0.tgz",
|
||||
"integrity": "sha512-zG+//P3BBpmOiLR+dop68T9AFNxazWlSLF8yVdAtvsqjRzcrrMLR//rIrRcbPHxu8gvvLrMDoDZT+AHd2rElyQ==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
"license": "Apache-2.0",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"darwin"
|
||||
]
|
||||
},
|
||||
"node_modules/@lancedb/vectordb-linux-arm64-gnu": {
|
||||
"version": "0.11.1-beta.1",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.11.1-beta.1.tgz",
|
||||
"integrity": "sha512-Obohy6TH31Uq+fp6ZisHR7iAsvgVPqBExrycVcIJqrLZnIe88N9OWUwBXkmfMAw/2hNJFwD4tU7+4U2FcBWX4w==",
|
||||
"version": "0.12.0",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.12.0.tgz",
|
||||
"integrity": "sha512-5RiJkcZEdMkK5WUfkV+HVFnJaAergfSiLNgUwJaovEEX8yVChkhrdZFSUj1o/k2k6Ix9mQq+xfIUF+aGN/XnDQ==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
"license": "Apache-2.0",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"linux"
|
||||
]
|
||||
},
|
||||
"node_modules/@lancedb/vectordb-linux-x64-gnu": {
|
||||
"version": "0.11.1-beta.1",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.11.1-beta.1.tgz",
|
||||
"integrity": "sha512-3Meu0dgrzNrnBVVQhxkUSAOhQNmgtKHvOvmrRLUicV+X19hd33udihgxVpZZb9mpXenJ8lZsS+Jq6R0hWqntag==",
|
||||
"version": "0.12.0",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.12.0.tgz",
|
||||
"integrity": "sha512-JFulRNBHLF0TyE0tThaAB9T7CM3zLquPsBF6oA9b1stVdXbEqVqLMltjem0tqfj30zEoEbAKDPpEKII4CPQMTA==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
"license": "Apache-2.0",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"linux"
|
||||
]
|
||||
},
|
||||
"node_modules/@lancedb/vectordb-win32-x64-msvc": {
|
||||
"version": "0.11.1-beta.1",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.11.1-beta.1.tgz",
|
||||
"integrity": "sha512-BafZ9OJPQXsS7JW0weAl12wC+827AiRjfUrE5tvrYWZah2OwCF2U2g6uJ3x4pxfwEGsv5xcHFqgxlS7ttFkh+Q==",
|
||||
"version": "0.12.0",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.12.0.tgz",
|
||||
"integrity": "sha512-T3s/RzB5dvXBqU3qmS6zyHhF0RHS2sSs81zKzYQy2R2nEVPbnwutFSsdA1wEqEXZlr8uTD9nLbkKJKqRNTXVEg==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
"license": "Apache-2.0",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"win32"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "vectordb",
|
||||
"version": "0.11.1-beta.1",
|
||||
"version": "0.13.0-beta.0",
|
||||
"description": " Serverless, low-latency vector database for AI applications",
|
||||
"main": "dist/index.js",
|
||||
"types": "dist/index.d.ts",
|
||||
@@ -88,10 +88,10 @@
|
||||
}
|
||||
},
|
||||
"optionalDependencies": {
|
||||
"@lancedb/vectordb-darwin-arm64": "0.11.1-beta.1",
|
||||
"@lancedb/vectordb-darwin-x64": "0.11.1-beta.1",
|
||||
"@lancedb/vectordb-linux-arm64-gnu": "0.11.1-beta.1",
|
||||
"@lancedb/vectordb-linux-x64-gnu": "0.11.1-beta.1",
|
||||
"@lancedb/vectordb-win32-x64-msvc": "0.11.1-beta.1"
|
||||
"@lancedb/vectordb-darwin-arm64": "0.13.0-beta.0",
|
||||
"@lancedb/vectordb-darwin-x64": "0.13.0-beta.0",
|
||||
"@lancedb/vectordb-linux-arm64-gnu": "0.13.0-beta.0",
|
||||
"@lancedb/vectordb-linux-x64-gnu": "0.13.0-beta.0",
|
||||
"@lancedb/vectordb-win32-x64-msvc": "0.13.0-beta.0"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
import axios, { type AxiosResponse, type ResponseType } from 'axios'
|
||||
import axios, { type AxiosError, type AxiosResponse, type ResponseType } from 'axios'
|
||||
|
||||
import { tableFromIPC, type Table as ArrowTable } from 'apache-arrow'
|
||||
|
||||
@@ -197,7 +197,7 @@ export class HttpLancedbClient {
|
||||
response = await callWithMiddlewares(req, this._middlewares)
|
||||
return response
|
||||
} catch (err: any) {
|
||||
console.error('error: ', err)
|
||||
console.error(serializeErrorAsJson(err))
|
||||
if (err.response === undefined) {
|
||||
throw new Error(`Network Error: ${err.message as string}`)
|
||||
}
|
||||
@@ -247,7 +247,8 @@ export class HttpLancedbClient {
|
||||
|
||||
// return response
|
||||
} catch (err: any) {
|
||||
console.error('error: ', err)
|
||||
console.error(serializeErrorAsJson(err))
|
||||
|
||||
if (err.response === undefined) {
|
||||
throw new Error(`Network Error: ${err.message as string}`)
|
||||
}
|
||||
@@ -287,3 +288,15 @@ export class HttpLancedbClient {
|
||||
return clone
|
||||
}
|
||||
}
|
||||
|
||||
function serializeErrorAsJson(err: AxiosError) {
|
||||
const error = JSON.parse(JSON.stringify(err, Object.getOwnPropertyNames(err)))
|
||||
error.response = err.response != null
|
||||
? JSON.parse(JSON.stringify(
|
||||
err.response,
|
||||
// config contains the request data, too noisy
|
||||
Object.getOwnPropertyNames(err.response).filter(prop => prop !== 'config')
|
||||
))
|
||||
: null
|
||||
return JSON.stringify({ error })
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
[package]
|
||||
name = "lancedb-nodejs"
|
||||
edition.workspace = true
|
||||
version = "0.11.1-beta.1"
|
||||
version = "0.13.0-beta.0"
|
||||
license.workspace = true
|
||||
description.workspace = true
|
||||
repository.workspace = true
|
||||
|
||||
@@ -402,6 +402,40 @@ describe("When creating an index", () => {
|
||||
expect(rst.numRows).toBe(1);
|
||||
});
|
||||
|
||||
it("should be able to query unindexed data", async () => {
|
||||
await tbl.createIndex("vec");
|
||||
await tbl.add([
|
||||
{
|
||||
id: 300,
|
||||
vec: Array(32)
|
||||
.fill(1)
|
||||
.map(() => Math.random()),
|
||||
tags: [],
|
||||
},
|
||||
]);
|
||||
|
||||
const plan1 = await tbl.query().nearestTo(queryVec).explainPlan(true);
|
||||
expect(plan1).toMatch("LanceScan");
|
||||
|
||||
const plan2 = await tbl
|
||||
.query()
|
||||
.nearestTo(queryVec)
|
||||
.fastSearch()
|
||||
.explainPlan(true);
|
||||
expect(plan2).not.toMatch("LanceScan");
|
||||
});
|
||||
|
||||
it("should be able to query with row id", async () => {
|
||||
const results = await tbl
|
||||
.query()
|
||||
.nearestTo(queryVec)
|
||||
.withRowId()
|
||||
.limit(1)
|
||||
.toArray();
|
||||
expect(results.length).toBe(1);
|
||||
expect(results[0]).toHaveProperty("_rowid");
|
||||
});
|
||||
|
||||
it("should allow parameters to be specified", async () => {
|
||||
await tbl.createIndex("vec", {
|
||||
config: Index.ivfPq({
|
||||
|
||||
@@ -239,6 +239,29 @@ export class QueryBase<NativeQueryType extends NativeQuery | NativeVectorQuery>
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Skip searching un-indexed data. This can make search faster, but will miss
|
||||
* any data that is not yet indexed.
|
||||
*
|
||||
* Use {@link lancedb.Table#optimize} to index all un-indexed data.
|
||||
*/
|
||||
fastSearch(): this {
|
||||
this.doCall((inner: NativeQueryType) => inner.fastSearch());
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Whether to return the row id in the results.
|
||||
*
|
||||
* This column can be used to match results between different queries. For
|
||||
* example, to match results from a full text search and a vector search in
|
||||
* order to perform hybrid search.
|
||||
*/
|
||||
withRowId(): this {
|
||||
this.doCall((inner: NativeQueryType) => inner.withRowId());
|
||||
return this;
|
||||
}
|
||||
|
||||
protected nativeExecute(
|
||||
options?: Partial<QueryExecutionOptions>,
|
||||
): Promise<NativeBatchIterator> {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-darwin-arm64",
|
||||
"version": "0.11.1-beta.1",
|
||||
"version": "0.13.0-beta.0",
|
||||
"os": ["darwin"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.darwin-arm64.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-darwin-x64",
|
||||
"version": "0.11.1-beta.1",
|
||||
"version": "0.13.0-beta.0",
|
||||
"os": ["darwin"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.darwin-x64.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-arm64-gnu",
|
||||
"version": "0.11.1-beta.1",
|
||||
"version": "0.13.0-beta.0",
|
||||
"os": ["linux"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.linux-arm64-gnu.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-x64-gnu",
|
||||
"version": "0.11.1-beta.1",
|
||||
"version": "0.13.0-beta.0",
|
||||
"os": ["linux"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.linux-x64-gnu.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-win32-x64-msvc",
|
||||
"version": "0.11.1-beta.1",
|
||||
"version": "0.13.0-beta.0",
|
||||
"os": ["win32"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.win32-x64-msvc.node",
|
||||
|
||||
4
nodejs/package-lock.json
generated
4
nodejs/package-lock.json
generated
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb",
|
||||
"version": "0.11.1-beta.1",
|
||||
"version": "0.12.0",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "@lancedb/lancedb",
|
||||
"version": "0.11.1-beta.1",
|
||||
"version": "0.12.0",
|
||||
"cpu": [
|
||||
"x64",
|
||||
"arm64"
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
"vector database",
|
||||
"ann"
|
||||
],
|
||||
"version": "0.11.1-beta.1",
|
||||
"version": "0.13.0-beta.0",
|
||||
"main": "dist/index.js",
|
||||
"exports": {
|
||||
".": "./dist/index.js",
|
||||
|
||||
@@ -80,6 +80,16 @@ impl Query {
|
||||
Ok(VectorQuery { inner })
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub fn fast_search(&mut self) {
|
||||
self.inner = self.inner.clone().fast_search();
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub fn with_row_id(&mut self) {
|
||||
self.inner = self.inner.clone().with_row_id();
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
pub async fn execute(
|
||||
&self,
|
||||
@@ -183,6 +193,16 @@ impl VectorQuery {
|
||||
self.inner = self.inner.clone().offset(offset as usize);
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub fn fast_search(&mut self) {
|
||||
self.inner = self.inner.clone().fast_search();
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub fn with_row_id(&mut self) {
|
||||
self.inner = self.inner.clone().with_row_id();
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
pub async fn execute(
|
||||
&self,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
[tool.bumpversion]
|
||||
current_version = "0.15.0"
|
||||
current_version = "0.16.0-beta.0"
|
||||
parse = """(?x)
|
||||
(?P<major>0|[1-9]\\d*)\\.
|
||||
(?P<minor>0|[1-9]\\d*)\\.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "lancedb-python"
|
||||
version = "0.15.0"
|
||||
version = "0.16.0-beta.0"
|
||||
edition.workspace = true
|
||||
description = "Python bindings for LanceDB"
|
||||
license.workspace = true
|
||||
|
||||
@@ -3,13 +3,11 @@ name = "lancedb"
|
||||
# version in Cargo.toml
|
||||
dependencies = [
|
||||
"deprecation",
|
||||
"pylance==0.19.1",
|
||||
"requests>=2.31.0",
|
||||
"nest-asyncio~=1.0",
|
||||
"pylance==0.19.2-beta.3",
|
||||
"tqdm>=4.27.0",
|
||||
"pydantic>=1.10",
|
||||
"attrs>=21.3.0",
|
||||
"packaging",
|
||||
"cachetools",
|
||||
"overrides>=0.7",
|
||||
]
|
||||
description = "lancedb"
|
||||
@@ -61,6 +59,7 @@ dev = ["ruff", "pre-commit"]
|
||||
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
|
||||
clip = ["torch", "pillow", "open-clip"]
|
||||
embeddings = [
|
||||
"requests>=2.31.0",
|
||||
"openai>=1.6.1",
|
||||
"sentence-transformers",
|
||||
"torch",
|
||||
|
||||
@@ -19,12 +19,10 @@ from typing import Dict, Optional, Union, Any
|
||||
|
||||
__version__ = importlib.metadata.version("lancedb")
|
||||
|
||||
from lancedb.remote import ClientConfig
|
||||
|
||||
from ._lancedb import connect as lancedb_connect
|
||||
from .common import URI, sanitize_uri
|
||||
from .db import AsyncConnection, DBConnection, LanceDBConnection
|
||||
from .remote.db import RemoteDBConnection
|
||||
from .remote import ClientConfig
|
||||
from .schema import vector
|
||||
from .table import AsyncTable
|
||||
|
||||
@@ -37,6 +35,7 @@ def connect(
|
||||
host_override: Optional[str] = None,
|
||||
read_consistency_interval: Optional[timedelta] = None,
|
||||
request_thread_pool: Optional[Union[int, ThreadPoolExecutor]] = None,
|
||||
client_config: Union[ClientConfig, Dict[str, Any], None] = None,
|
||||
**kwargs: Any,
|
||||
) -> DBConnection:
|
||||
"""Connect to a LanceDB database.
|
||||
@@ -64,14 +63,10 @@ def connect(
|
||||
the last check, then the table will be checked for updates. Note: this
|
||||
consistency only applies to read operations. Write operations are
|
||||
always consistent.
|
||||
request_thread_pool: int or ThreadPoolExecutor, optional
|
||||
The thread pool to use for making batch requests to the LanceDB Cloud API.
|
||||
If an integer, then a ThreadPoolExecutor will be created with that
|
||||
number of threads. If None, then a ThreadPoolExecutor will be created
|
||||
with the default number of threads. If a ThreadPoolExecutor, then that
|
||||
executor will be used for making requests. This is for LanceDB Cloud
|
||||
only and is only used when making batch requests (i.e., passing in
|
||||
multiple queries to the search method at once).
|
||||
client_config: ClientConfig or dict, optional
|
||||
Configuration options for the LanceDB Cloud HTTP client. If a dict, then
|
||||
the keys are the attributes of the ClientConfig class. If None, then the
|
||||
default configuration is used.
|
||||
|
||||
Examples
|
||||
--------
|
||||
@@ -94,6 +89,8 @@ def connect(
|
||||
conn : DBConnection
|
||||
A connection to a LanceDB database.
|
||||
"""
|
||||
from .remote.db import RemoteDBConnection
|
||||
|
||||
if isinstance(uri, str) and uri.startswith("db://"):
|
||||
if api_key is None:
|
||||
api_key = os.environ.get("LANCEDB_API_KEY")
|
||||
@@ -106,7 +103,9 @@ def connect(
|
||||
api_key,
|
||||
region,
|
||||
host_override,
|
||||
# TODO: remove this (deprecation warning downstream)
|
||||
request_thread_pool=request_thread_pool,
|
||||
client_config=client_config,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -36,6 +36,8 @@ class Connection(object):
|
||||
data_storage_version: Optional[str] = None,
|
||||
enable_v2_manifest_paths: Optional[bool] = None,
|
||||
) -> Table: ...
|
||||
async def rename_table(self, old_name: str, new_name: str) -> None: ...
|
||||
async def drop_table(self, name: str) -> None: ...
|
||||
|
||||
class Table:
|
||||
def name(self) -> str: ...
|
||||
|
||||
@@ -817,6 +817,18 @@ class AsyncConnection(object):
|
||||
table = await self._inner.open_table(name, storage_options, index_cache_size)
|
||||
return AsyncTable(table)
|
||||
|
||||
async def rename_table(self, old_name: str, new_name: str):
|
||||
"""Rename a table in the database.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
old_name: str
|
||||
The current name of the table.
|
||||
new_name: str
|
||||
The new name of the table.
|
||||
"""
|
||||
await self._inner.rename_table(old_name, new_name)
|
||||
|
||||
async def drop_table(self, name: str):
|
||||
"""Drop a table from the database.
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
|
||||
import os
|
||||
import io
|
||||
import requests
|
||||
import base64
|
||||
from urllib.parse import urlparse
|
||||
from pathlib import Path
|
||||
@@ -226,6 +225,8 @@ class JinaEmbeddings(EmbeddingFunction):
|
||||
return [result["embedding"] for result in sorted_embeddings]
|
||||
|
||||
def _init_client(self):
|
||||
import requests
|
||||
|
||||
if JinaEmbeddings._session is None:
|
||||
if self.api_key is None and os.environ.get("JINA_API_KEY") is None:
|
||||
api_key_not_found_help("jina")
|
||||
|
||||
@@ -467,6 +467,8 @@ class IvfPq:
|
||||
|
||||
The default value is 256.
|
||||
"""
|
||||
if distance_type is not None:
|
||||
distance_type = distance_type.lower()
|
||||
self._inner = LanceDbIndex.ivf_pq(
|
||||
distance_type=distance_type,
|
||||
num_partitions=num_partitions,
|
||||
|
||||
@@ -481,6 +481,7 @@ class LanceQueryBuilder(ABC):
|
||||
>>> plan = table.search(query).explain_plan(True)
|
||||
>>> print(plan) # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
|
||||
ProjectionExec: expr=[vector@0 as vector, _distance@2 as _distance]
|
||||
GlobalLimitExec: skip=0, fetch=10
|
||||
FilterExec: _distance@2 IS NOT NULL
|
||||
SortExec: TopK(fetch=10), expr=[_distance@2 ASC NULLS LAST], preserve_partitioning=[false]
|
||||
KNNVectorDistance: metric=l2
|
||||
@@ -500,7 +501,16 @@ class LanceQueryBuilder(ABC):
|
||||
nearest={
|
||||
"column": self._vector_column,
|
||||
"q": self._query,
|
||||
"k": self._limit,
|
||||
"metric": self._metric,
|
||||
"nprobes": self._nprobes,
|
||||
"refine_factor": self._refine_factor,
|
||||
},
|
||||
prefilter=self._prefilter,
|
||||
filter=self._str_query,
|
||||
limit=self._limit,
|
||||
with_row_id=self._with_row_id,
|
||||
offset=self._offset,
|
||||
).explain_plan(verbose)
|
||||
|
||||
def vector(self, vector: Union[np.ndarray, list]) -> LanceQueryBuilder:
|
||||
@@ -1315,6 +1325,48 @@ class AsyncQueryBase(object):
|
||||
self._inner.offset(offset)
|
||||
return self
|
||||
|
||||
def fast_search(self) -> AsyncQuery:
|
||||
"""
|
||||
Skip searching un-indexed data.
|
||||
|
||||
This can make queries faster, but will miss any data that has not been
|
||||
indexed.
|
||||
|
||||
!!! tip
|
||||
You can add new data into an existing index by calling
|
||||
[AsyncTable.optimize][lancedb.table.AsyncTable.optimize].
|
||||
"""
|
||||
self._inner.fast_search()
|
||||
return self
|
||||
|
||||
def with_row_id(self) -> AsyncQuery:
|
||||
"""
|
||||
Include the _rowid column in the results.
|
||||
"""
|
||||
self._inner.with_row_id()
|
||||
return self
|
||||
|
||||
def postfilter(self) -> AsyncQuery:
|
||||
"""
|
||||
If this is called then filtering will happen after the search instead of
|
||||
before.
|
||||
By default filtering will be performed before the search. This is how
|
||||
filtering is typically understood to work. This prefilter step does add some
|
||||
additional latency. Creating a scalar index on the filter column(s) can
|
||||
often improve this latency. However, sometimes a filter is too complex or
|
||||
scalar indices cannot be applied to the column. In these cases postfiltering
|
||||
can be used instead of prefiltering to improve latency.
|
||||
Post filtering applies the filter to the results of the search. This
|
||||
means we only run the filter on a much smaller set of data. However, it can
|
||||
cause the query to return fewer than `limit` results (or even no results) if
|
||||
none of the nearest results match the filter.
|
||||
Post filtering happens during the "refine stage" (described in more detail in
|
||||
@see {@link VectorQuery#refineFactor}). This means that setting a higher refine
|
||||
factor can often help restore some of the results lost by post filtering.
|
||||
"""
|
||||
self._inner.postfilter()
|
||||
return self
|
||||
|
||||
async def to_batches(
|
||||
self, *, max_batch_length: Optional[int] = None
|
||||
) -> AsyncRecordBatchReader:
|
||||
@@ -1618,30 +1670,6 @@ class AsyncVectorQuery(AsyncQueryBase):
|
||||
self._inner.distance_type(distance_type)
|
||||
return self
|
||||
|
||||
def postfilter(self) -> AsyncVectorQuery:
|
||||
"""
|
||||
If this is called then filtering will happen after the vector search instead of
|
||||
before.
|
||||
|
||||
By default filtering will be performed before the vector search. This is how
|
||||
filtering is typically understood to work. This prefilter step does add some
|
||||
additional latency. Creating a scalar index on the filter column(s) can
|
||||
often improve this latency. However, sometimes a filter is too complex or
|
||||
scalar indices cannot be applied to the column. In these cases postfiltering
|
||||
can be used instead of prefiltering to improve latency.
|
||||
|
||||
Post filtering applies the filter to the results of the vector search. This
|
||||
means we only run the filter on a much smaller set of data. However, it can
|
||||
cause the query to return fewer than `limit` results (or even no results) if
|
||||
none of the nearest results match the filter.
|
||||
|
||||
Post filtering happens during the "refine stage" (described in more detail in
|
||||
@see {@link VectorQuery#refineFactor}). This means that setting a higher refine
|
||||
factor can often help restore some of the results lost by post filtering.
|
||||
"""
|
||||
self._inner.postfilter()
|
||||
return self
|
||||
|
||||
def bypass_vector_index(self) -> AsyncVectorQuery:
|
||||
"""
|
||||
If this is called then any vector index is skipped
|
||||
|
||||
@@ -11,62 +11,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import timedelta
|
||||
from typing import List, Optional
|
||||
|
||||
import attrs
|
||||
from lancedb import __version__
|
||||
import pyarrow as pa
|
||||
from pydantic import BaseModel
|
||||
|
||||
from lancedb.common import VECTOR_COLUMN_NAME
|
||||
|
||||
__all__ = ["LanceDBClient", "VectorQuery", "VectorQueryResult"]
|
||||
|
||||
|
||||
class VectorQuery(BaseModel):
|
||||
# vector to search for
|
||||
vector: List[float]
|
||||
|
||||
# sql filter to refine the query with
|
||||
filter: Optional[str] = None
|
||||
|
||||
# top k results to return
|
||||
k: int
|
||||
|
||||
# # metrics
|
||||
_metric: str = "L2"
|
||||
|
||||
# which columns to return in the results
|
||||
columns: Optional[List[str]] = None
|
||||
|
||||
# optional query parameters for tuning the results,
|
||||
# e.g. `{"nprobes": "10", "refine_factor": "10"}`
|
||||
nprobes: int = 10
|
||||
|
||||
refine_factor: Optional[int] = None
|
||||
|
||||
vector_column: str = VECTOR_COLUMN_NAME
|
||||
|
||||
fast_search: bool = False
|
||||
|
||||
|
||||
@attrs.define
|
||||
class VectorQueryResult:
|
||||
# for now the response is directly seralized into a pandas dataframe
|
||||
tbl: pa.Table
|
||||
|
||||
def to_arrow(self) -> pa.Table:
|
||||
return self.tbl
|
||||
|
||||
|
||||
class LanceDBClient(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
|
||||
"""Query the LanceDB server for the given table and query."""
|
||||
pass
|
||||
__all__ = ["TimeoutConfig", "RetryConfig", "ClientConfig"]
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -165,8 +116,8 @@ class RetryConfig:
|
||||
@dataclass
|
||||
class ClientConfig:
|
||||
user_agent: str = f"LanceDB-Python-Client/{__version__}"
|
||||
retry_config: Optional[RetryConfig] = None
|
||||
timeout_config: Optional[TimeoutConfig] = None
|
||||
retry_config: RetryConfig = field(default_factory=RetryConfig)
|
||||
timeout_config: Optional[TimeoutConfig] = field(default_factory=TimeoutConfig)
|
||||
|
||||
def __post_init__(self):
|
||||
if isinstance(self.retry_config, dict):
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
# Copyright 2023 LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Iterable, Union
|
||||
import pyarrow as pa
|
||||
|
||||
|
||||
def to_ipc_binary(table: Union[pa.Table, Iterable[pa.RecordBatch]]) -> bytes:
|
||||
"""Serialize a PyArrow Table to IPC binary."""
|
||||
sink = pa.BufferOutputStream()
|
||||
if isinstance(table, Iterable):
|
||||
table = pa.Table.from_batches(table)
|
||||
with pa.ipc.new_stream(sink, table.schema) as writer:
|
||||
writer.write_table(table)
|
||||
return sink.getvalue().to_pybytes()
|
||||
@@ -1,269 +0,0 @@
|
||||
# Copyright 2023 LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import functools
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import attrs
|
||||
import pyarrow as pa
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
from requests.adapters import HTTPAdapter
|
||||
from urllib3 import Retry
|
||||
|
||||
from lancedb.common import Credential
|
||||
from lancedb.remote import VectorQuery, VectorQueryResult
|
||||
from lancedb.remote.connection_timeout import LanceDBClientHTTPAdapterFactory
|
||||
from lancedb.remote.errors import LanceDBClientError
|
||||
|
||||
ARROW_STREAM_CONTENT_TYPE = "application/vnd.apache.arrow.stream"
|
||||
|
||||
|
||||
def _check_not_closed(f):
|
||||
@functools.wraps(f)
|
||||
def wrapped(self, *args, **kwargs):
|
||||
if self.closed:
|
||||
raise ValueError("Connection is closed")
|
||||
return f(self, *args, **kwargs)
|
||||
|
||||
return wrapped
|
||||
|
||||
|
||||
def _read_ipc(resp: requests.Response) -> pa.Table:
|
||||
resp_body = resp.content
|
||||
with pa.ipc.open_file(pa.BufferReader(resp_body)) as reader:
|
||||
return reader.read_all()
|
||||
|
||||
|
||||
@attrs.define(slots=False)
|
||||
class RestfulLanceDBClient:
|
||||
db_name: str
|
||||
region: str
|
||||
api_key: Credential
|
||||
host_override: Optional[str] = attrs.field(default=None)
|
||||
|
||||
closed: bool = attrs.field(default=False, init=False)
|
||||
|
||||
connection_timeout: float = attrs.field(default=120.0, kw_only=True)
|
||||
read_timeout: float = attrs.field(default=300.0, kw_only=True)
|
||||
|
||||
@functools.cached_property
|
||||
def session(self) -> requests.Session:
|
||||
sess = requests.Session()
|
||||
|
||||
retry_adapter_instance = retry_adapter(retry_adapter_options())
|
||||
sess.mount(urljoin(self.url, "/v1/table/"), retry_adapter_instance)
|
||||
|
||||
adapter_class = LanceDBClientHTTPAdapterFactory()
|
||||
sess.mount("https://", adapter_class())
|
||||
return sess
|
||||
|
||||
@property
|
||||
def url(self) -> str:
|
||||
return (
|
||||
self.host_override
|
||||
or f"https://{self.db_name}.{self.region}.api.lancedb.com"
|
||||
)
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self.close()
|
||||
return False # Do not suppress exceptions
|
||||
|
||||
def close(self):
|
||||
self.session.close()
|
||||
self.closed = True
|
||||
|
||||
@functools.cached_property
|
||||
def headers(self) -> Dict[str, str]:
|
||||
headers = {
|
||||
"x-api-key": self.api_key,
|
||||
}
|
||||
if self.region == "local": # Local test mode
|
||||
headers["Host"] = f"{self.db_name}.{self.region}.api.lancedb.com"
|
||||
if self.host_override:
|
||||
headers["x-lancedb-database"] = self.db_name
|
||||
return headers
|
||||
|
||||
@staticmethod
|
||||
def _check_status(resp: requests.Response):
|
||||
# Leaving request id empty for now, as we'll be replacing this impl
|
||||
# with the Rust one shortly.
|
||||
if resp.status_code == 404:
|
||||
raise LanceDBClientError(
|
||||
f"Not found: {resp.text}", request_id="", status_code=404
|
||||
)
|
||||
elif 400 <= resp.status_code < 500:
|
||||
raise LanceDBClientError(
|
||||
f"Bad Request: {resp.status_code}, error: {resp.text}",
|
||||
request_id="",
|
||||
status_code=resp.status_code,
|
||||
)
|
||||
elif 500 <= resp.status_code < 600:
|
||||
raise LanceDBClientError(
|
||||
f"Internal Server Error: {resp.status_code}, error: {resp.text}",
|
||||
request_id="",
|
||||
status_code=resp.status_code,
|
||||
)
|
||||
elif resp.status_code != 200:
|
||||
raise LanceDBClientError(
|
||||
f"Unknown Error: {resp.status_code}, error: {resp.text}",
|
||||
request_id="",
|
||||
status_code=resp.status_code,
|
||||
)
|
||||
|
||||
@_check_not_closed
|
||||
def get(self, uri: str, params: Union[Dict[str, Any], BaseModel] = None):
|
||||
"""Send a GET request and returns the deserialized response payload."""
|
||||
if isinstance(params, BaseModel):
|
||||
params: Dict[str, Any] = params.dict(exclude_none=True)
|
||||
with self.session.get(
|
||||
urljoin(self.url, uri),
|
||||
params=params,
|
||||
headers=self.headers,
|
||||
timeout=(self.connection_timeout, self.read_timeout),
|
||||
) as resp:
|
||||
self._check_status(resp)
|
||||
return resp.json()
|
||||
|
||||
@_check_not_closed
|
||||
def post(
|
||||
self,
|
||||
uri: str,
|
||||
data: Optional[Union[Dict[str, Any], BaseModel, bytes]] = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
content_type: Optional[str] = None,
|
||||
deserialize: Callable = lambda resp: resp.json(),
|
||||
request_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Send a POST request and returns the deserialized response payload.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
uri : str
|
||||
The uri to send the POST request to.
|
||||
data: Union[Dict[str, Any], BaseModel]
|
||||
request_id: Optional[str]
|
||||
Optional client side request id to be sent in the request headers.
|
||||
|
||||
"""
|
||||
if isinstance(data, BaseModel):
|
||||
data: Dict[str, Any] = data.dict(exclude_none=True)
|
||||
if isinstance(data, bytes):
|
||||
req_kwargs = {"data": data}
|
||||
else:
|
||||
req_kwargs = {"json": data}
|
||||
|
||||
headers = self.headers.copy()
|
||||
if content_type is not None:
|
||||
headers["content-type"] = content_type
|
||||
if request_id is not None:
|
||||
headers["x-request-id"] = request_id
|
||||
with self.session.post(
|
||||
urljoin(self.url, uri),
|
||||
headers=headers,
|
||||
params=params,
|
||||
timeout=(self.connection_timeout, self.read_timeout),
|
||||
**req_kwargs,
|
||||
) as resp:
|
||||
self._check_status(resp)
|
||||
return deserialize(resp)
|
||||
|
||||
@_check_not_closed
|
||||
def list_tables(self, limit: int, page_token: Optional[str] = None) -> List[str]:
|
||||
"""List all tables in the database."""
|
||||
if page_token is None:
|
||||
page_token = ""
|
||||
json = self.get("/v1/table/", {"limit": limit, "page_token": page_token})
|
||||
return json["tables"]
|
||||
|
||||
@_check_not_closed
|
||||
def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
|
||||
"""Query a table."""
|
||||
tbl = self.post(f"/v1/table/{table_name}/query/", query, deserialize=_read_ipc)
|
||||
return VectorQueryResult(tbl)
|
||||
|
||||
def mount_retry_adapter_for_table(self, table_name: str) -> None:
|
||||
"""
|
||||
Adds an http adapter to session that will retry retryable requests to the table.
|
||||
"""
|
||||
retry_options = retry_adapter_options(methods=["GET", "POST"])
|
||||
retry_adapter_instance = retry_adapter(retry_options)
|
||||
session = self.session
|
||||
|
||||
session.mount(
|
||||
urljoin(self.url, f"/v1/table/{table_name}/query/"), retry_adapter_instance
|
||||
)
|
||||
session.mount(
|
||||
urljoin(self.url, f"/v1/table/{table_name}/describe/"),
|
||||
retry_adapter_instance,
|
||||
)
|
||||
session.mount(
|
||||
urljoin(self.url, f"/v1/table/{table_name}/index/list/"),
|
||||
retry_adapter_instance,
|
||||
)
|
||||
|
||||
|
||||
def retry_adapter_options(methods=["GET"]) -> Dict[str, Any]:
|
||||
return {
|
||||
"retries": int(os.environ.get("LANCE_CLIENT_MAX_RETRIES", "3")),
|
||||
"connect_retries": int(os.environ.get("LANCE_CLIENT_CONNECT_RETRIES", "3")),
|
||||
"read_retries": int(os.environ.get("LANCE_CLIENT_READ_RETRIES", "3")),
|
||||
"backoff_factor": float(
|
||||
os.environ.get("LANCE_CLIENT_RETRY_BACKOFF_FACTOR", "0.25")
|
||||
),
|
||||
"backoff_jitter": float(
|
||||
os.environ.get("LANCE_CLIENT_RETRY_BACKOFF_JITTER", "0.25")
|
||||
),
|
||||
"statuses": [
|
||||
int(i.strip())
|
||||
for i in os.environ.get(
|
||||
"LANCE_CLIENT_RETRY_STATUSES", "429, 500, 502, 503"
|
||||
).split(",")
|
||||
],
|
||||
"methods": methods,
|
||||
}
|
||||
|
||||
|
||||
def retry_adapter(options: Dict[str, Any]) -> HTTPAdapter:
|
||||
total_retries = options["retries"]
|
||||
connect_retries = options["connect_retries"]
|
||||
read_retries = options["read_retries"]
|
||||
backoff_factor = options["backoff_factor"]
|
||||
backoff_jitter = options["backoff_jitter"]
|
||||
statuses = options["statuses"]
|
||||
methods = frozenset(options["methods"])
|
||||
logging.debug(
|
||||
f"Setting up retry adapter with {total_retries} retries," # noqa G003
|
||||
+ f"connect retries {connect_retries}, read retries {read_retries},"
|
||||
+ f"backoff factor {backoff_factor}, statuses {statuses}, "
|
||||
+ f"methods {methods}"
|
||||
)
|
||||
|
||||
return HTTPAdapter(
|
||||
max_retries=Retry(
|
||||
total=total_retries,
|
||||
connect=connect_retries,
|
||||
read=read_retries,
|
||||
backoff_factor=backoff_factor,
|
||||
backoff_jitter=backoff_jitter,
|
||||
status_forcelist=statuses,
|
||||
allowed_methods=methods,
|
||||
)
|
||||
)
|
||||
@@ -1,115 +0,0 @@
|
||||
# Copyright 2024 LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This module contains an adapter that will close connections if they have not been
|
||||
# used before a certain timeout. This is necessary because some load balancers will
|
||||
# close connections after a certain amount of time, but the request module may not yet
|
||||
# have received the FIN/ACK and will try to reuse the connection.
|
||||
#
|
||||
# TODO some of the code here can be simplified if/when this PR is merged:
|
||||
# https://github.com/urllib3/urllib3/pull/3275
|
||||
|
||||
import datetime
|
||||
import logging
|
||||
import os
|
||||
|
||||
from requests.adapters import HTTPAdapter
|
||||
from urllib3.connection import HTTPSConnection
|
||||
from urllib3.connectionpool import HTTPSConnectionPool
|
||||
from urllib3.poolmanager import PoolManager
|
||||
|
||||
|
||||
def get_client_connection_timeout() -> int:
|
||||
return int(os.environ.get("LANCE_CLIENT_CONNECTION_TIMEOUT", "300"))
|
||||
|
||||
|
||||
class LanceDBHTTPSConnection(HTTPSConnection):
|
||||
"""
|
||||
HTTPSConnection that tracks the last time it was used.
|
||||
"""
|
||||
|
||||
idle_timeout: datetime.timedelta
|
||||
last_activity: datetime.datetime
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.last_activity = datetime.datetime.now()
|
||||
|
||||
def request(self, *args, **kwargs):
|
||||
self.last_activity = datetime.datetime.now()
|
||||
super().request(*args, **kwargs)
|
||||
|
||||
def is_expired(self):
|
||||
return datetime.datetime.now() - self.last_activity > self.idle_timeout
|
||||
|
||||
|
||||
def LanceDBHTTPSConnectionPoolFactory(client_idle_timeout: int):
|
||||
"""
|
||||
Creates a connection pool class that can be used to close idle connections.
|
||||
"""
|
||||
|
||||
class LanceDBHTTPSConnectionPool(HTTPSConnectionPool):
|
||||
# override the connection class
|
||||
ConnectionCls = LanceDBHTTPSConnection
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def _get_conn(self, timeout: float | None = None):
|
||||
logging.debug("Getting https connection")
|
||||
conn = super()._get_conn(timeout)
|
||||
if conn.is_expired():
|
||||
logging.debug("Closing expired connection")
|
||||
conn.close()
|
||||
|
||||
return conn
|
||||
|
||||
def _new_conn(self):
|
||||
conn = super()._new_conn()
|
||||
conn.idle_timeout = datetime.timedelta(seconds=client_idle_timeout)
|
||||
return conn
|
||||
|
||||
return LanceDBHTTPSConnectionPool
|
||||
|
||||
|
||||
class LanceDBClientPoolManager(PoolManager):
|
||||
def __init__(
|
||||
self, client_idle_timeout: int, num_pools: int, maxsize: int, **kwargs
|
||||
):
|
||||
super().__init__(num_pools=num_pools, maxsize=maxsize, **kwargs)
|
||||
# inject our connection pool impl
|
||||
connection_pool_class = LanceDBHTTPSConnectionPoolFactory(
|
||||
client_idle_timeout=client_idle_timeout
|
||||
)
|
||||
self.pool_classes_by_scheme["https"] = connection_pool_class
|
||||
|
||||
|
||||
def LanceDBClientHTTPAdapterFactory():
|
||||
"""
|
||||
Creates an HTTPAdapter class that can be used to close idle connections
|
||||
"""
|
||||
|
||||
# closure over the timeout
|
||||
client_idle_timeout = get_client_connection_timeout()
|
||||
|
||||
class LanceDBClientRequestHTTPAdapter(HTTPAdapter):
|
||||
def init_poolmanager(self, connections, maxsize, block=False):
|
||||
# inject our pool manager impl
|
||||
self.poolmanager = LanceDBClientPoolManager(
|
||||
client_idle_timeout=client_idle_timeout,
|
||||
num_pools=connections,
|
||||
maxsize=maxsize,
|
||||
block=block,
|
||||
)
|
||||
|
||||
return LanceDBClientRequestHTTPAdapter
|
||||
@@ -11,13 +11,16 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
from datetime import timedelta
|
||||
import logging
|
||||
import uuid
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Iterable, List, Optional, Union
|
||||
from typing import Any, Dict, Iterable, List, Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
import warnings
|
||||
|
||||
from cachetools import TTLCache
|
||||
from lancedb import connect_async
|
||||
from lancedb.remote import ClientConfig
|
||||
import pyarrow as pa
|
||||
from overrides import override
|
||||
|
||||
@@ -25,10 +28,8 @@ from ..common import DATA
|
||||
from ..db import DBConnection
|
||||
from ..embeddings import EmbeddingFunctionConfig
|
||||
from ..pydantic import LanceModel
|
||||
from ..table import Table, sanitize_create_table
|
||||
from ..table import Table
|
||||
from ..util import validate_table_name
|
||||
from .arrow import to_ipc_binary
|
||||
from .client import ARROW_STREAM_CONTENT_TYPE, RestfulLanceDBClient
|
||||
|
||||
|
||||
class RemoteDBConnection(DBConnection):
|
||||
@@ -41,26 +42,70 @@ class RemoteDBConnection(DBConnection):
|
||||
region: str,
|
||||
host_override: Optional[str] = None,
|
||||
request_thread_pool: Optional[ThreadPoolExecutor] = None,
|
||||
connection_timeout: float = 120.0,
|
||||
read_timeout: float = 300.0,
|
||||
client_config: Union[ClientConfig, Dict[str, Any], None] = None,
|
||||
connection_timeout: Optional[float] = None,
|
||||
read_timeout: Optional[float] = None,
|
||||
):
|
||||
"""Connect to a remote LanceDB database."""
|
||||
|
||||
if isinstance(client_config, dict):
|
||||
client_config = ClientConfig(**client_config)
|
||||
elif client_config is None:
|
||||
client_config = ClientConfig()
|
||||
|
||||
# These are legacy options from the old Python-based client. We keep them
|
||||
# here for backwards compatibility, but will remove them in a future release.
|
||||
if request_thread_pool is not None:
|
||||
warnings.warn(
|
||||
"request_thread_pool is no longer used and will be removed in "
|
||||
"a future release.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
if connection_timeout is not None:
|
||||
warnings.warn(
|
||||
"connection_timeout is deprecated and will be removed in a future "
|
||||
"release. Please use client_config.timeout_config.connect_timeout "
|
||||
"instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
client_config.timeout_config.connect_timeout = timedelta(
|
||||
seconds=connection_timeout
|
||||
)
|
||||
|
||||
if read_timeout is not None:
|
||||
warnings.warn(
|
||||
"read_timeout is deprecated and will be removed in a future release. "
|
||||
"Please use client_config.timeout_config.read_timeout instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
client_config.timeout_config.read_timeout = timedelta(seconds=read_timeout)
|
||||
|
||||
parsed = urlparse(db_url)
|
||||
if parsed.scheme != "db":
|
||||
raise ValueError(f"Invalid scheme: {parsed.scheme}, only accepts db://")
|
||||
self._uri = str(db_url)
|
||||
self.db_name = parsed.netloc
|
||||
self.api_key = api_key
|
||||
self._client = RestfulLanceDBClient(
|
||||
self.db_name,
|
||||
region,
|
||||
api_key,
|
||||
host_override,
|
||||
connection_timeout=connection_timeout,
|
||||
read_timeout=read_timeout,
|
||||
|
||||
import nest_asyncio
|
||||
|
||||
nest_asyncio.apply()
|
||||
try:
|
||||
self._loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
self._loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self._loop)
|
||||
|
||||
self.client_config = client_config
|
||||
|
||||
self._conn = self._loop.run_until_complete(
|
||||
connect_async(
|
||||
db_url,
|
||||
api_key=api_key,
|
||||
region=region,
|
||||
host_override=host_override,
|
||||
client_config=client_config,
|
||||
)
|
||||
)
|
||||
self._request_thread_pool = request_thread_pool
|
||||
self._table_cache = TTLCache(maxsize=10000, ttl=300)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"RemoteConnect(name={self.db_name})"
|
||||
@@ -82,16 +127,9 @@ class RemoteDBConnection(DBConnection):
|
||||
-------
|
||||
An iterator of table names.
|
||||
"""
|
||||
while True:
|
||||
result = self._client.list_tables(limit, page_token)
|
||||
|
||||
if len(result) > 0:
|
||||
page_token = result[len(result) - 1]
|
||||
else:
|
||||
break
|
||||
for item in result:
|
||||
self._table_cache[item] = True
|
||||
yield item
|
||||
return self._loop.run_until_complete(
|
||||
self._conn.table_names(start_after=page_token, limit=limit)
|
||||
)
|
||||
|
||||
@override
|
||||
def open_table(self, name: str, *, index_cache_size: Optional[int] = None) -> Table:
|
||||
@@ -108,20 +146,14 @@ class RemoteDBConnection(DBConnection):
|
||||
"""
|
||||
from .table import RemoteTable
|
||||
|
||||
self._client.mount_retry_adapter_for_table(name)
|
||||
|
||||
if index_cache_size is not None:
|
||||
logging.info(
|
||||
"index_cache_size is ignored in LanceDb Cloud"
|
||||
" (there is no local cache to configure)"
|
||||
)
|
||||
|
||||
# check if table exists
|
||||
if self._table_cache.get(name) is None:
|
||||
self._client.post(f"/v1/table/{name}/describe/")
|
||||
self._table_cache[name] = True
|
||||
|
||||
return RemoteTable(self, name)
|
||||
table = self._loop.run_until_complete(self._conn.open_table(name))
|
||||
return RemoteTable(table, self.db_name, self._loop)
|
||||
|
||||
@override
|
||||
def create_table(
|
||||
@@ -233,27 +265,20 @@ class RemoteDBConnection(DBConnection):
|
||||
"Please vote https://github.com/lancedb/lancedb/issues/626 "
|
||||
"for this feature."
|
||||
)
|
||||
if mode is not None:
|
||||
logging.warning("mode is not yet supported on LanceDB Cloud.")
|
||||
|
||||
data, schema = sanitize_create_table(
|
||||
data, schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
|
||||
)
|
||||
|
||||
from .table import RemoteTable
|
||||
|
||||
data = to_ipc_binary(data)
|
||||
request_id = uuid.uuid4().hex
|
||||
|
||||
self._client.post(
|
||||
f"/v1/table/{name}/create/",
|
||||
data=data,
|
||||
request_id=request_id,
|
||||
content_type=ARROW_STREAM_CONTENT_TYPE,
|
||||
table = self._loop.run_until_complete(
|
||||
self._conn.create_table(
|
||||
name,
|
||||
data,
|
||||
mode=mode,
|
||||
schema=schema,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
)
|
||||
)
|
||||
|
||||
self._table_cache[name] = True
|
||||
return RemoteTable(self, name)
|
||||
return RemoteTable(table, self.db_name, self._loop)
|
||||
|
||||
@override
|
||||
def drop_table(self, name: str):
|
||||
@@ -264,11 +289,7 @@ class RemoteDBConnection(DBConnection):
|
||||
name: str
|
||||
The name of the table.
|
||||
"""
|
||||
|
||||
self._client.post(
|
||||
f"/v1/table/{name}/drop/",
|
||||
)
|
||||
self._table_cache.pop(name, default=None)
|
||||
self._loop.run_until_complete(self._conn.drop_table(name))
|
||||
|
||||
@override
|
||||
def rename_table(self, cur_name: str, new_name: str):
|
||||
@@ -281,12 +302,7 @@ class RemoteDBConnection(DBConnection):
|
||||
new_name: str
|
||||
The new name of the table.
|
||||
"""
|
||||
self._client.post(
|
||||
f"/v1/table/{cur_name}/rename/",
|
||||
data={"new_table_name": new_name},
|
||||
)
|
||||
self._table_cache.pop(cur_name, default=None)
|
||||
self._table_cache[new_name] = True
|
||||
self._loop.run_until_complete(self._conn.rename_table(cur_name, new_name))
|
||||
|
||||
async def close(self):
|
||||
"""Close the connection to the database."""
|
||||
|
||||
@@ -11,53 +11,56 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from concurrent.futures import Future
|
||||
from functools import cached_property
|
||||
from typing import Dict, Iterable, List, Optional, Union, Literal
|
||||
|
||||
from lancedb.index import FTS, BTree, Bitmap, HnswPq, HnswSq, IvfPq, LabelList
|
||||
import pyarrow as pa
|
||||
from lance import json_to_schema
|
||||
|
||||
from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME
|
||||
from lancedb.merge import LanceMergeInsertBuilder
|
||||
from lancedb.embeddings import EmbeddingFunctionRegistry
|
||||
|
||||
from ..query import LanceVectorQueryBuilder, LanceQueryBuilder
|
||||
from ..table import Query, Table, _sanitize_data
|
||||
from ..util import value_to_sql, infer_vector_column_name
|
||||
from .arrow import to_ipc_binary
|
||||
from .client import ARROW_STREAM_CONTENT_TYPE
|
||||
from .db import RemoteDBConnection
|
||||
from ..table import AsyncTable, Query, Table
|
||||
|
||||
|
||||
class RemoteTable(Table):
|
||||
def __init__(self, conn: RemoteDBConnection, name: str):
|
||||
self._conn = conn
|
||||
self.name = name
|
||||
def __init__(
|
||||
self,
|
||||
table: AsyncTable,
|
||||
db_name: str,
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
||||
):
|
||||
self._loop = loop
|
||||
self._table = table
|
||||
self.db_name = db_name
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""The name of the table"""
|
||||
return self._table.name
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"RemoteTable({self._conn.db_name}.{self.name})"
|
||||
return f"RemoteTable({self.db_name}.{self.name})"
|
||||
|
||||
def __len__(self) -> int:
|
||||
self.count_rows(None)
|
||||
|
||||
@cached_property
|
||||
@property
|
||||
def schema(self) -> pa.Schema:
|
||||
"""The [Arrow Schema](https://arrow.apache.org/docs/python/api/datatypes.html#)
|
||||
of this Table
|
||||
|
||||
"""
|
||||
resp = self._conn._client.post(f"/v1/table/{self.name}/describe/")
|
||||
schema = json_to_schema(resp["schema"])
|
||||
return schema
|
||||
return self._loop.run_until_complete(self._table.schema())
|
||||
|
||||
@property
|
||||
def version(self) -> int:
|
||||
"""Get the current version of the table"""
|
||||
resp = self._conn._client.post(f"/v1/table/{self.name}/describe/")
|
||||
return resp["version"]
|
||||
return self._loop.run_until_complete(self._table.version())
|
||||
|
||||
@cached_property
|
||||
def embedding_functions(self) -> dict:
|
||||
@@ -84,20 +87,18 @@ class RemoteTable(Table):
|
||||
|
||||
def list_indices(self):
|
||||
"""List all the indices on the table"""
|
||||
resp = self._conn._client.post(f"/v1/table/{self.name}/index/list/")
|
||||
return resp
|
||||
return self._loop.run_until_complete(self._table.list_indices())
|
||||
|
||||
def index_stats(self, index_uuid: str):
|
||||
"""List all the stats of a specified index"""
|
||||
resp = self._conn._client.post(
|
||||
f"/v1/table/{self.name}/index/{index_uuid}/stats/"
|
||||
)
|
||||
return resp
|
||||
return self._loop.run_until_complete(self._table.index_stats(index_uuid))
|
||||
|
||||
def create_scalar_index(
|
||||
self,
|
||||
column: str,
|
||||
index_type: Literal["BTREE", "BITMAP", "LABEL_LIST", "scalar"] = "scalar",
|
||||
*,
|
||||
replace: bool = False,
|
||||
):
|
||||
"""Creates a scalar index
|
||||
Parameters
|
||||
@@ -107,20 +108,23 @@ class RemoteTable(Table):
|
||||
or string column.
|
||||
index_type : str
|
||||
The index type of the scalar index. Must be "scalar" (BTREE),
|
||||
"BTREE", "BITMAP", or "LABEL_LIST"
|
||||
"BTREE", "BITMAP", or "LABEL_LIST",
|
||||
replace : bool
|
||||
If True, replace the existing index with the new one.
|
||||
"""
|
||||
if index_type == "scalar" or index_type == "BTREE":
|
||||
config = BTree()
|
||||
elif index_type == "BITMAP":
|
||||
config = Bitmap()
|
||||
elif index_type == "LABEL_LIST":
|
||||
config = LabelList()
|
||||
else:
|
||||
raise ValueError(f"Unknown index type: {index_type}")
|
||||
|
||||
data = {
|
||||
"column": column,
|
||||
"index_type": index_type,
|
||||
"replace": True,
|
||||
}
|
||||
resp = self._conn._client.post(
|
||||
f"/v1/table/{self.name}/create_scalar_index/", data=data
|
||||
self._loop.run_until_complete(
|
||||
self._table.create_index(column, config=config, replace=replace)
|
||||
)
|
||||
|
||||
return resp
|
||||
|
||||
def create_fts_index(
|
||||
self,
|
||||
column: str,
|
||||
@@ -128,15 +132,10 @@ class RemoteTable(Table):
|
||||
replace: bool = False,
|
||||
with_position: bool = True,
|
||||
):
|
||||
data = {
|
||||
"column": column,
|
||||
"index_type": "FTS",
|
||||
"replace": replace,
|
||||
}
|
||||
resp = self._conn._client.post(
|
||||
f"/v1/table/{self.name}/create_index/", data=data
|
||||
config = FTS(with_position=with_position)
|
||||
self._loop.run_until_complete(
|
||||
self._table.create_index(column, config=config, replace=replace)
|
||||
)
|
||||
return resp
|
||||
|
||||
def create_index(
|
||||
self,
|
||||
@@ -204,17 +203,22 @@ class RemoteTable(Table):
|
||||
"Existing indexes will always be replaced."
|
||||
)
|
||||
|
||||
data = {
|
||||
"column": vector_column_name,
|
||||
"index_type": index_type,
|
||||
"metric_type": metric,
|
||||
"index_cache_size": index_cache_size,
|
||||
}
|
||||
resp = self._conn._client.post(
|
||||
f"/v1/table/{self.name}/create_index/", data=data
|
||||
)
|
||||
index_type = index_type.upper()
|
||||
if index_type == "VECTOR" or index_type == "IVF_PQ":
|
||||
config = IvfPq(distance_type=metric)
|
||||
elif index_type == "IVF_HNSW_PQ":
|
||||
config = HnswPq(distance_type=metric)
|
||||
elif index_type == "IVF_HNSW_SQ":
|
||||
config = HnswSq(distance_type=metric)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown vector index type: {index_type}. Valid options are"
|
||||
" 'IVF_PQ', 'IVF_HNSW_PQ', 'IVF_HNSW_SQ'"
|
||||
)
|
||||
|
||||
return resp
|
||||
self._loop.run_until_complete(
|
||||
self._table.create_index(vector_column_name, config=config)
|
||||
)
|
||||
|
||||
def add(
|
||||
self,
|
||||
@@ -246,22 +250,10 @@ class RemoteTable(Table):
|
||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||
|
||||
"""
|
||||
data, _ = _sanitize_data(
|
||||
data,
|
||||
self.schema,
|
||||
metadata=self.schema.metadata,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
)
|
||||
payload = to_ipc_binary(data)
|
||||
|
||||
request_id = uuid.uuid4().hex
|
||||
|
||||
self._conn._client.post(
|
||||
f"/v1/table/{self.name}/insert/",
|
||||
data=payload,
|
||||
params={"request_id": request_id, "mode": mode},
|
||||
content_type=ARROW_STREAM_CONTENT_TYPE,
|
||||
self._loop.run_until_complete(
|
||||
self._table.add(
|
||||
data, mode=mode, on_bad_vectors=on_bad_vectors, fill_value=fill_value
|
||||
)
|
||||
)
|
||||
|
||||
def search(
|
||||
@@ -337,12 +329,6 @@ class RemoteTable(Table):
|
||||
# empty query builder is not supported in saas, raise error
|
||||
if query is None and query_type != "hybrid":
|
||||
raise ValueError("Empty query is not supported")
|
||||
vector_column_name = infer_vector_column_name(
|
||||
schema=self.schema,
|
||||
query_type=query_type,
|
||||
query=query,
|
||||
vector_column_name=vector_column_name,
|
||||
)
|
||||
|
||||
return LanceQueryBuilder.create(
|
||||
self,
|
||||
@@ -356,37 +342,9 @@ class RemoteTable(Table):
|
||||
def _execute_query(
|
||||
self, query: Query, batch_size: Optional[int] = None
|
||||
) -> pa.RecordBatchReader:
|
||||
if (
|
||||
query.vector is not None
|
||||
and len(query.vector) > 0
|
||||
and not isinstance(query.vector[0], float)
|
||||
):
|
||||
if self._conn._request_thread_pool is None:
|
||||
|
||||
def submit(name, q):
|
||||
f = Future()
|
||||
f.set_result(self._conn._client.query(name, q))
|
||||
return f
|
||||
|
||||
else:
|
||||
|
||||
def submit(name, q):
|
||||
return self._conn._request_thread_pool.submit(
|
||||
self._conn._client.query, name, q
|
||||
)
|
||||
|
||||
results = []
|
||||
for v in query.vector:
|
||||
v = list(v)
|
||||
q = query.copy()
|
||||
q.vector = v
|
||||
results.append(submit(self.name, q))
|
||||
return pa.concat_tables(
|
||||
[add_index(r.result().to_arrow(), i) for i, r in enumerate(results)]
|
||||
).to_reader()
|
||||
else:
|
||||
result = self._conn._client.query(self.name, query)
|
||||
return result.to_arrow().to_reader()
|
||||
return self._loop.run_until_complete(
|
||||
self._table._execute_query(query, batch_size=batch_size)
|
||||
)
|
||||
|
||||
def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
|
||||
"""Returns a [`LanceMergeInsertBuilder`][lancedb.merge.LanceMergeInsertBuilder]
|
||||
@@ -403,42 +361,8 @@ class RemoteTable(Table):
|
||||
on_bad_vectors: str,
|
||||
fill_value: float,
|
||||
):
|
||||
data, _ = _sanitize_data(
|
||||
new_data,
|
||||
self.schema,
|
||||
metadata=None,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
)
|
||||
payload = to_ipc_binary(data)
|
||||
|
||||
params = {}
|
||||
if len(merge._on) != 1:
|
||||
raise ValueError(
|
||||
"RemoteTable only supports a single on key in merge_insert"
|
||||
)
|
||||
params["on"] = merge._on[0]
|
||||
params["when_matched_update_all"] = str(merge._when_matched_update_all).lower()
|
||||
if merge._when_matched_update_all_condition is not None:
|
||||
params["when_matched_update_all_filt"] = (
|
||||
merge._when_matched_update_all_condition
|
||||
)
|
||||
params["when_not_matched_insert_all"] = str(
|
||||
merge._when_not_matched_insert_all
|
||||
).lower()
|
||||
params["when_not_matched_by_source_delete"] = str(
|
||||
merge._when_not_matched_by_source_delete
|
||||
).lower()
|
||||
if merge._when_not_matched_by_source_condition is not None:
|
||||
params["when_not_matched_by_source_delete_filt"] = (
|
||||
merge._when_not_matched_by_source_condition
|
||||
)
|
||||
|
||||
self._conn._client.post(
|
||||
f"/v1/table/{self.name}/merge_insert/",
|
||||
data=payload,
|
||||
params=params,
|
||||
content_type=ARROW_STREAM_CONTENT_TYPE,
|
||||
self._loop.run_until_complete(
|
||||
self._table._do_merge(merge, new_data, on_bad_vectors, fill_value)
|
||||
)
|
||||
|
||||
def delete(self, predicate: str):
|
||||
@@ -488,8 +412,7 @@ class RemoteTable(Table):
|
||||
x vector _distance # doctest: +SKIP
|
||||
0 2 [3.0, 4.0] 85.0 # doctest: +SKIP
|
||||
"""
|
||||
payload = {"predicate": predicate}
|
||||
self._conn._client.post(f"/v1/table/{self.name}/delete/", data=payload)
|
||||
self._loop.run_until_complete(self._table.delete(predicate))
|
||||
|
||||
def update(
|
||||
self,
|
||||
@@ -539,18 +462,9 @@ class RemoteTable(Table):
|
||||
2 2 [10.0, 10.0] # doctest: +SKIP
|
||||
|
||||
"""
|
||||
if values is not None and values_sql is not None:
|
||||
raise ValueError("Only one of values or values_sql can be provided")
|
||||
if values is None and values_sql is None:
|
||||
raise ValueError("Either values or values_sql must be provided")
|
||||
|
||||
if values is not None:
|
||||
updates = [[k, value_to_sql(v)] for k, v in values.items()]
|
||||
else:
|
||||
updates = [[k, v] for k, v in values_sql.items()]
|
||||
|
||||
payload = {"predicate": where, "updates": updates}
|
||||
self._conn._client.post(f"/v1/table/{self.name}/update/", data=payload)
|
||||
self._loop.run_until_complete(
|
||||
self._table.update(where=where, updates=values, updates_sql=values_sql)
|
||||
)
|
||||
|
||||
def cleanup_old_versions(self, *_):
|
||||
"""cleanup_old_versions() is not supported on the LanceDB cloud"""
|
||||
@@ -565,11 +479,7 @@ class RemoteTable(Table):
|
||||
)
|
||||
|
||||
def count_rows(self, filter: Optional[str] = None) -> int:
|
||||
payload = {"predicate": filter}
|
||||
resp = self._conn._client.post(
|
||||
f"/v1/table/{self.name}/count_rows/", data=payload
|
||||
)
|
||||
return resp
|
||||
return self._loop.run_until_complete(self._table.count_rows(filter))
|
||||
|
||||
def add_columns(self, transforms: Dict[str, str]):
|
||||
raise NotImplementedError(
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import requests
|
||||
from functools import cached_property
|
||||
from typing import Union
|
||||
|
||||
@@ -57,6 +56,8 @@ class JinaReranker(Reranker):
|
||||
|
||||
@cached_property
|
||||
def _client(self):
|
||||
import requests
|
||||
|
||||
if os.environ.get("JINA_API_KEY") is None and self.api_key is None:
|
||||
raise ValueError(
|
||||
"JINA_API_KEY not set. Either set it in your environment or \
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from numpy import NaN
|
||||
from numpy import nan
|
||||
import pyarrow as pa
|
||||
|
||||
from .base import Reranker
|
||||
@@ -71,7 +71,7 @@ class LinearCombinationReranker(Reranker):
|
||||
elif self.score == "all":
|
||||
results = results.append_column(
|
||||
"_distance",
|
||||
pa.array([NaN] * len(fts_results), type=pa.float32()),
|
||||
pa.array([nan] * len(fts_results), type=pa.float32()),
|
||||
)
|
||||
return results
|
||||
|
||||
@@ -92,7 +92,7 @@ class LinearCombinationReranker(Reranker):
|
||||
elif self.score == "all":
|
||||
results = results.append_column(
|
||||
"_score",
|
||||
pa.array([NaN] * len(vector_results), type=pa.float32()),
|
||||
pa.array([nan] * len(vector_results), type=pa.float32()),
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
@@ -62,7 +62,7 @@ if TYPE_CHECKING:
|
||||
from lance.dataset import CleanupStats, ReaderLike
|
||||
from ._lancedb import Table as LanceDBTable, OptimizeStats
|
||||
from .db import LanceDBConnection
|
||||
from .index import BTree, IndexConfig, IvfPq, Bitmap, LabelList, FTS
|
||||
from .index import BTree, IndexConfig, IvfPq, Bitmap, LabelList, FTS, HnswPq, HnswSq
|
||||
|
||||
pd = safe_import_pandas()
|
||||
pl = safe_import_polars()
|
||||
@@ -948,7 +948,9 @@ class Table(ABC):
|
||||
return _table_uri(self._conn.uri, self.name)
|
||||
|
||||
def _get_fts_index_path(self) -> Tuple[str, pa_fs.FileSystem, bool]:
|
||||
if get_uri_scheme(self._dataset_uri) != "file":
|
||||
from .remote.table import RemoteTable
|
||||
|
||||
if isinstance(self, RemoteTable) or get_uri_scheme(self._dataset_uri) != "file":
|
||||
return ("", None, False)
|
||||
path = join_uri(self._dataset_uri, "_indices", "fts")
|
||||
fs, path = fs_from_uri(path)
|
||||
@@ -2382,7 +2384,9 @@ class AsyncTable:
|
||||
column: str,
|
||||
*,
|
||||
replace: Optional[bool] = None,
|
||||
config: Optional[Union[IvfPq, BTree, Bitmap, LabelList, FTS]] = None,
|
||||
config: Optional[
|
||||
Union[IvfPq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS]
|
||||
] = None,
|
||||
):
|
||||
"""Create an index to speed up queries
|
||||
|
||||
@@ -2535,7 +2539,44 @@ class AsyncTable:
|
||||
async def _execute_query(
|
||||
self, query: Query, batch_size: Optional[int] = None
|
||||
) -> pa.RecordBatchReader:
|
||||
pass
|
||||
# The sync remote table calls into this method, so we need to map the
|
||||
# query to the async version of the query and run that here. This is only
|
||||
# used for that code path right now.
|
||||
async_query = self.query().limit(query.k)
|
||||
if query.offset > 0:
|
||||
async_query = async_query.offset(query.offset)
|
||||
if query.columns:
|
||||
async_query = async_query.select(query.columns)
|
||||
if query.filter:
|
||||
async_query = async_query.where(query.filter)
|
||||
if query.fast_search:
|
||||
async_query = async_query.fast_search()
|
||||
if query.with_row_id:
|
||||
async_query = async_query.with_row_id()
|
||||
|
||||
if query.vector:
|
||||
async_query = (
|
||||
async_query.nearest_to(query.vector)
|
||||
.distance_type(query.metric)
|
||||
.nprobes(query.nprobes)
|
||||
)
|
||||
if query.refine_factor:
|
||||
async_query = async_query.refine_factor(query.refine_factor)
|
||||
if query.vector_column:
|
||||
async_query = async_query.column(query.vector_column)
|
||||
|
||||
if not query.prefilter:
|
||||
async_query = async_query.postfilter()
|
||||
|
||||
if isinstance(query.full_text_query, str):
|
||||
async_query = async_query.nearest_to_text(query.full_text_query)
|
||||
elif isinstance(query.full_text_query, dict):
|
||||
fts_query = query.full_text_query["query"]
|
||||
fts_columns = query.full_text_query.get("columns", []) or []
|
||||
async_query = async_query.nearest_to_text(fts_query, columns=fts_columns)
|
||||
|
||||
table = await async_query.to_arrow()
|
||||
return table.to_reader()
|
||||
|
||||
async def _do_merge(
|
||||
self,
|
||||
@@ -2781,7 +2822,7 @@ class AsyncTable:
|
||||
cleanup_older_than = round(cleanup_older_than.total_seconds() * 1000)
|
||||
return await self._inner.optimize(cleanup_older_than, delete_unverified)
|
||||
|
||||
async def list_indices(self) -> IndexConfig:
|
||||
async def list_indices(self) -> Iterable[IndexConfig]:
|
||||
"""
|
||||
List all indices that have been created with Self::create_index
|
||||
"""
|
||||
@@ -2865,3 +2906,8 @@ class IndexStatistics:
|
||||
]
|
||||
distance_type: Optional[Literal["l2", "cosine", "dot"]] = None
|
||||
num_indices: Optional[int] = None
|
||||
|
||||
# This exists for backwards compatibility with an older API, which returned
|
||||
# a dictionary instead of a class.
|
||||
def __getitem__(self, key):
|
||||
return getattr(self, key)
|
||||
|
||||
@@ -18,7 +18,6 @@ import lancedb
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
import requests
|
||||
from lancedb.embeddings import get_registry
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
|
||||
@@ -108,6 +107,7 @@ def test_basic_text_embeddings(alias, tmp_path):
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_openclip(tmp_path):
|
||||
import requests
|
||||
from PIL import Image
|
||||
|
||||
db = lancedb.connect(tmp_path)
|
||||
|
||||
@@ -235,6 +235,29 @@ async def test_search_fts_async(async_table):
|
||||
results = await async_table.query().nearest_to_text("puppy").limit(5).to_list()
|
||||
assert len(results) == 5
|
||||
|
||||
expected_count = await async_table.count_rows(
|
||||
"count > 5000 and contains(text, 'puppy')"
|
||||
)
|
||||
expected_count = min(expected_count, 10)
|
||||
|
||||
limited_results_pre_filter = await (
|
||||
async_table.query()
|
||||
.nearest_to_text("puppy")
|
||||
.where("count > 5000")
|
||||
.limit(10)
|
||||
.to_list()
|
||||
)
|
||||
assert len(limited_results_pre_filter) == expected_count
|
||||
limited_results_post_filter = await (
|
||||
async_table.query()
|
||||
.nearest_to_text("puppy")
|
||||
.where("count > 5000")
|
||||
.limit(10)
|
||||
.postfilter()
|
||||
.to_list()
|
||||
)
|
||||
assert len(limited_results_post_filter) <= expected_count
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_fts_specify_column_async(async_table):
|
||||
|
||||
@@ -49,7 +49,7 @@ async def test_create_scalar_index(some_table: AsyncTable):
|
||||
# Can recreate if replace=True
|
||||
await some_table.create_index("id", replace=True)
|
||||
indices = await some_table.list_indices()
|
||||
assert str(indices) == '[Index(BTree, columns=["id"])]'
|
||||
assert str(indices) == '[Index(BTree, columns=["id"], name="id_idx")]'
|
||||
assert len(indices) == 1
|
||||
assert indices[0].index_type == "BTree"
|
||||
assert indices[0].columns == ["id"]
|
||||
@@ -64,7 +64,7 @@ async def test_create_scalar_index(some_table: AsyncTable):
|
||||
async def test_create_bitmap_index(some_table: AsyncTable):
|
||||
await some_table.create_index("id", config=Bitmap())
|
||||
indices = await some_table.list_indices()
|
||||
assert str(indices) == '[Index(Bitmap, columns=["id"])]'
|
||||
assert str(indices) == '[Index(Bitmap, columns=["id"], name="id_idx")]'
|
||||
indices = await some_table.list_indices()
|
||||
assert len(indices) == 1
|
||||
index_name = indices[0].name
|
||||
@@ -80,7 +80,7 @@ async def test_create_bitmap_index(some_table: AsyncTable):
|
||||
async def test_create_label_list_index(some_table: AsyncTable):
|
||||
await some_table.create_index("tags", config=LabelList())
|
||||
indices = await some_table.list_indices()
|
||||
assert str(indices) == '[Index(LabelList, columns=["tags"])]'
|
||||
assert str(indices) == '[Index(LabelList, columns=["tags"], name="tags_idx")]'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -17,6 +17,7 @@ from typing import Optional
|
||||
|
||||
import lance
|
||||
import lancedb
|
||||
from lancedb.index import IvfPq
|
||||
import numpy as np
|
||||
import pandas.testing as tm
|
||||
import pyarrow as pa
|
||||
@@ -330,6 +331,12 @@ async def test_query_async(table_async: AsyncTable):
|
||||
# Also check an empty query
|
||||
await check_query(table_async.query().where("id < 0"), expected_num_rows=0)
|
||||
|
||||
# with row id
|
||||
await check_query(
|
||||
table_async.query().select(["id", "vector"]).with_row_id(),
|
||||
expected_columns=["id", "vector", "_rowid"],
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_to_arrow_async(table_async: AsyncTable):
|
||||
@@ -358,6 +365,25 @@ async def test_query_to_pandas_async(table_async: AsyncTable):
|
||||
assert df.shape == (0, 4)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fast_search_async(tmp_path):
|
||||
db = await lancedb.connect_async(tmp_path)
|
||||
vectors = pa.FixedShapeTensorArray.from_numpy_ndarray(
|
||||
np.random.rand(256, 32)
|
||||
).storage
|
||||
table = await db.create_table("test", pa.table({"vector": vectors}))
|
||||
await table.create_index(
|
||||
"vector", config=IvfPq(num_partitions=1, num_sub_vectors=1)
|
||||
)
|
||||
await table.add(pa.table({"vector": vectors}))
|
||||
|
||||
q = [1.0] * 32
|
||||
plan = await table.query().nearest_to(q).explain_plan(True)
|
||||
assert "LanceScan" in plan
|
||||
plan = await table.query().nearest_to(q).fast_search().explain_plan(True)
|
||||
assert "LanceScan" not in plan
|
||||
|
||||
|
||||
def test_explain_plan(table):
|
||||
q = LanceVectorQueryBuilder(table, [0, 0], "vector")
|
||||
plan = q.explain_plan(verbose=True)
|
||||
|
||||
@@ -1,96 +0,0 @@
|
||||
# Copyright 2023 LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import attrs
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
from lancedb.remote.client import RestfulLanceDBClient, VectorQuery
|
||||
|
||||
|
||||
@attrs.define
|
||||
class MockLanceDBServer:
|
||||
runner: web.AppRunner = attrs.field(init=False)
|
||||
site: web.TCPSite = attrs.field(init=False)
|
||||
|
||||
async def query_handler(self, request: web.Request) -> web.Response:
|
||||
table_name = request.match_info["table_name"]
|
||||
assert table_name == "test_table"
|
||||
|
||||
await request.json()
|
||||
# TODO: do some matching
|
||||
|
||||
vecs = pd.Series([np.random.rand(128) for x in range(10)], name="vector")
|
||||
ids = pd.Series(range(10), name="id")
|
||||
df = pd.DataFrame([vecs, ids]).T
|
||||
|
||||
batch = pa.RecordBatch.from_pandas(
|
||||
df,
|
||||
schema=pa.schema(
|
||||
[
|
||||
pa.field("vector", pa.list_(pa.float32(), 128)),
|
||||
pa.field("id", pa.int64()),
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
sink = pa.BufferOutputStream()
|
||||
with pa.ipc.new_file(sink, batch.schema) as writer:
|
||||
writer.write_batch(batch)
|
||||
|
||||
return web.Response(body=sink.getvalue().to_pybytes())
|
||||
|
||||
async def setup(self):
|
||||
app = web.Application()
|
||||
app.add_routes([web.post("/table/{table_name}", self.query_handler)])
|
||||
self.runner = web.AppRunner(app)
|
||||
await self.runner.setup()
|
||||
self.site = web.TCPSite(self.runner, "localhost", 8111)
|
||||
|
||||
async def start(self):
|
||||
await self.site.start()
|
||||
|
||||
async def stop(self):
|
||||
await self.runner.cleanup()
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="flaky somehow, fix later")
|
||||
@pytest.mark.asyncio
|
||||
async def test_e2e_with_mock_server():
|
||||
mock_server = MockLanceDBServer()
|
||||
await mock_server.setup()
|
||||
await mock_server.start()
|
||||
|
||||
try:
|
||||
with RestfulLanceDBClient("lancedb+http://localhost:8111") as client:
|
||||
df = (
|
||||
await client.query(
|
||||
"test_table",
|
||||
VectorQuery(
|
||||
vector=np.random.rand(128).tolist(),
|
||||
k=10,
|
||||
_metric="L2",
|
||||
columns=["id", "vector"],
|
||||
),
|
||||
)
|
||||
).to_pandas()
|
||||
|
||||
assert "vector" in df.columns
|
||||
assert "id" in df.columns
|
||||
|
||||
assert client.closed
|
||||
finally:
|
||||
# make sure we don't leak resources
|
||||
await mock_server.stop()
|
||||
@@ -2,91 +2,19 @@
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
import contextlib
|
||||
from datetime import timedelta
|
||||
import http.server
|
||||
import json
|
||||
import threading
|
||||
from unittest.mock import MagicMock
|
||||
import uuid
|
||||
|
||||
import lancedb
|
||||
from lancedb.conftest import MockTextEmbeddingFunction
|
||||
from lancedb.remote import ClientConfig
|
||||
from lancedb.remote.errors import HttpError, RetryError
|
||||
import pyarrow as pa
|
||||
from lancedb.remote.client import VectorQuery, VectorQueryResult
|
||||
import pytest
|
||||
|
||||
|
||||
class FakeLanceDBClient:
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
|
||||
assert table_name == "test"
|
||||
t = pa.schema([]).empty_table()
|
||||
return VectorQueryResult(t)
|
||||
|
||||
def post(self, path: str):
|
||||
pass
|
||||
|
||||
def mount_retry_adapter_for_table(self, table_name: str):
|
||||
pass
|
||||
|
||||
|
||||
def test_remote_db():
|
||||
conn = lancedb.connect("db://client-will-be-injected", api_key="fake")
|
||||
setattr(conn, "_client", FakeLanceDBClient())
|
||||
|
||||
table = conn["test"]
|
||||
table.schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2))])
|
||||
table.search([1.0, 2.0]).to_pandas()
|
||||
|
||||
|
||||
def test_create_empty_table():
|
||||
client = MagicMock()
|
||||
conn = lancedb.connect("db://client-will-be-injected", api_key="fake")
|
||||
|
||||
conn._client = client
|
||||
|
||||
schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2))])
|
||||
|
||||
client.post.return_value = {"status": "ok"}
|
||||
table = conn.create_table("test", schema=schema)
|
||||
assert table.name == "test"
|
||||
assert client.post.call_args[0][0] == "/v1/table/test/create/"
|
||||
|
||||
json_schema = {
|
||||
"fields": [
|
||||
{
|
||||
"name": "vector",
|
||||
"nullable": True,
|
||||
"type": {
|
||||
"type": "fixed_size_list",
|
||||
"fields": [
|
||||
{"name": "item", "nullable": True, "type": {"type": "float"}}
|
||||
],
|
||||
"length": 2,
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
client.post.return_value = {"schema": json_schema}
|
||||
assert table.schema == schema
|
||||
assert client.post.call_args[0][0] == "/v1/table/test/describe/"
|
||||
|
||||
client.post.return_value = 0
|
||||
assert table.count_rows(None) == 0
|
||||
|
||||
|
||||
def test_create_table_with_recordbatches():
|
||||
client = MagicMock()
|
||||
conn = lancedb.connect("db://client-will-be-injected", api_key="fake")
|
||||
|
||||
conn._client = client
|
||||
|
||||
batch = pa.RecordBatch.from_arrays([pa.array([[1.0, 2.0], [3.0, 4.0]])], ["vector"])
|
||||
|
||||
client.post.return_value = {"status": "ok"}
|
||||
table = conn.create_table("test", [batch], schema=batch.schema)
|
||||
assert table.name == "test"
|
||||
assert client.post.call_args[0][0] == "/v1/table/test/create/"
|
||||
import pyarrow as pa
|
||||
|
||||
|
||||
def make_mock_http_handler(handler):
|
||||
@@ -100,8 +28,35 @@ def make_mock_http_handler(handler):
|
||||
return MockLanceDBHandler
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def mock_lancedb_connection(handler):
|
||||
with http.server.HTTPServer(
|
||||
("localhost", 8080), make_mock_http_handler(handler)
|
||||
) as server:
|
||||
handle = threading.Thread(target=server.serve_forever)
|
||||
handle.start()
|
||||
|
||||
db = lancedb.connect(
|
||||
"db://dev",
|
||||
api_key="fake",
|
||||
host_override="http://localhost:8080",
|
||||
client_config={
|
||||
"retry_config": {"retries": 2},
|
||||
"timeout_config": {
|
||||
"connect_timeout": 1,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
server.shutdown()
|
||||
handle.join()
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def mock_lancedb_connection(handler):
|
||||
async def mock_lancedb_connection_async(handler):
|
||||
with http.server.HTTPServer(
|
||||
("localhost", 8080), make_mock_http_handler(handler)
|
||||
) as server:
|
||||
@@ -143,7 +98,7 @@ async def test_async_remote_db():
|
||||
request.end_headers()
|
||||
request.wfile.write(b'{"tables": []}')
|
||||
|
||||
async with mock_lancedb_connection(handler) as db:
|
||||
async with mock_lancedb_connection_async(handler) as db:
|
||||
table_names = await db.table_names()
|
||||
assert table_names == []
|
||||
|
||||
@@ -159,12 +114,12 @@ async def test_http_error():
|
||||
request.end_headers()
|
||||
request.wfile.write(b"Internal Server Error")
|
||||
|
||||
async with mock_lancedb_connection(handler) as db:
|
||||
with pytest.raises(HttpError, match="Internal Server Error") as exc_info:
|
||||
async with mock_lancedb_connection_async(handler) as db:
|
||||
with pytest.raises(HttpError) as exc_info:
|
||||
await db.table_names()
|
||||
|
||||
assert exc_info.value.request_id == request_id_holder["request_id"]
|
||||
assert exc_info.value.status_code == 507
|
||||
assert "Internal Server Error" in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -178,15 +133,225 @@ async def test_retry_error():
|
||||
request.end_headers()
|
||||
request.wfile.write(b"Try again later")
|
||||
|
||||
async with mock_lancedb_connection(handler) as db:
|
||||
with pytest.raises(RetryError, match="Hit retry limit") as exc_info:
|
||||
async with mock_lancedb_connection_async(handler) as db:
|
||||
with pytest.raises(RetryError) as exc_info:
|
||||
await db.table_names()
|
||||
|
||||
assert exc_info.value.request_id == request_id_holder["request_id"]
|
||||
assert exc_info.value.status_code == 429
|
||||
|
||||
cause = exc_info.value.__cause__
|
||||
assert isinstance(cause, HttpError)
|
||||
assert "Try again later" in str(cause)
|
||||
assert cause.request_id == request_id_holder["request_id"]
|
||||
assert cause.status_code == 429
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def query_test_table(query_handler):
|
||||
def handler(request):
|
||||
if request.path == "/v1/table/test/describe/":
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/json")
|
||||
request.end_headers()
|
||||
request.wfile.write(b"{}")
|
||||
elif request.path == "/v1/table/test/query/":
|
||||
content_len = int(request.headers.get("Content-Length"))
|
||||
body = request.rfile.read(content_len)
|
||||
body = json.loads(body)
|
||||
|
||||
data = query_handler(body)
|
||||
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/vnd.apache.arrow.file")
|
||||
request.end_headers()
|
||||
|
||||
with pa.ipc.new_file(request.wfile, schema=data.schema) as f:
|
||||
f.write_table(data)
|
||||
else:
|
||||
request.send_response(404)
|
||||
request.end_headers()
|
||||
|
||||
with mock_lancedb_connection(handler) as db:
|
||||
assert repr(db) == "RemoteConnect(name=dev)"
|
||||
table = db.open_table("test")
|
||||
assert repr(table) == "RemoteTable(dev.test)"
|
||||
yield table
|
||||
|
||||
|
||||
def test_query_sync_minimal():
|
||||
def handler(body):
|
||||
assert body == {
|
||||
"distance_type": "l2",
|
||||
"k": 10,
|
||||
"prefilter": False,
|
||||
"refine_factor": None,
|
||||
"vector": [1.0, 2.0, 3.0],
|
||||
"nprobes": 20,
|
||||
}
|
||||
|
||||
return pa.table({"id": [1, 2, 3]})
|
||||
|
||||
with query_test_table(handler) as table:
|
||||
data = table.search([1, 2, 3]).to_list()
|
||||
expected = [{"id": 1}, {"id": 2}, {"id": 3}]
|
||||
assert data == expected
|
||||
|
||||
|
||||
def test_query_sync_maximal():
|
||||
def handler(body):
|
||||
assert body == {
|
||||
"distance_type": "cosine",
|
||||
"k": 42,
|
||||
"prefilter": True,
|
||||
"refine_factor": 10,
|
||||
"vector": [1.0, 2.0, 3.0],
|
||||
"nprobes": 5,
|
||||
"filter": "id > 0",
|
||||
"columns": ["id", "name"],
|
||||
"vector_column": "vector2",
|
||||
"fast_search": True,
|
||||
"with_row_id": True,
|
||||
}
|
||||
|
||||
return pa.table({"id": [1, 2, 3], "name": ["a", "b", "c"]})
|
||||
|
||||
with query_test_table(handler) as table:
|
||||
(
|
||||
table.search([1, 2, 3], vector_column_name="vector2", fast_search=True)
|
||||
.metric("cosine")
|
||||
.limit(42)
|
||||
.refine_factor(10)
|
||||
.nprobes(5)
|
||||
.where("id > 0", prefilter=True)
|
||||
.with_row_id(True)
|
||||
.select(["id", "name"])
|
||||
.to_list()
|
||||
)
|
||||
|
||||
|
||||
def test_query_sync_fts():
|
||||
def handler(body):
|
||||
assert body == {
|
||||
"full_text_query": {
|
||||
"query": "puppy",
|
||||
"columns": [],
|
||||
},
|
||||
"k": 10,
|
||||
"vector": [],
|
||||
}
|
||||
|
||||
return pa.table({"id": [1, 2, 3]})
|
||||
|
||||
with query_test_table(handler) as table:
|
||||
(table.search("puppy", query_type="fts").to_list())
|
||||
|
||||
def handler(body):
|
||||
assert body == {
|
||||
"full_text_query": {
|
||||
"query": "puppy",
|
||||
"columns": ["name", "description"],
|
||||
},
|
||||
"k": 42,
|
||||
"vector": [],
|
||||
"with_row_id": True,
|
||||
}
|
||||
|
||||
return pa.table({"id": [1, 2, 3]})
|
||||
|
||||
with query_test_table(handler) as table:
|
||||
(
|
||||
table.search("puppy", query_type="fts", fts_columns=["name", "description"])
|
||||
.with_row_id(True)
|
||||
.limit(42)
|
||||
.to_list()
|
||||
)
|
||||
|
||||
|
||||
def test_query_sync_hybrid():
|
||||
def handler(body):
|
||||
if "full_text_query" in body:
|
||||
# FTS query
|
||||
assert body == {
|
||||
"full_text_query": {
|
||||
"query": "puppy",
|
||||
"columns": [],
|
||||
},
|
||||
"k": 42,
|
||||
"vector": [],
|
||||
"with_row_id": True,
|
||||
}
|
||||
return pa.table({"_rowid": [1, 2, 3], "_score": [0.1, 0.2, 0.3]})
|
||||
else:
|
||||
# Vector query
|
||||
assert body == {
|
||||
"distance_type": "l2",
|
||||
"k": 42,
|
||||
"prefilter": False,
|
||||
"refine_factor": None,
|
||||
"vector": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
|
||||
"nprobes": 20,
|
||||
"with_row_id": True,
|
||||
}
|
||||
return pa.table({"_rowid": [1, 2, 3], "_distance": [0.1, 0.2, 0.3]})
|
||||
|
||||
with query_test_table(handler) as table:
|
||||
embedding_func = MockTextEmbeddingFunction()
|
||||
embedding_config = MagicMock()
|
||||
embedding_config.function = embedding_func
|
||||
|
||||
embedding_funcs = MagicMock()
|
||||
embedding_funcs.get = MagicMock(return_value=embedding_config)
|
||||
table.embedding_functions = embedding_funcs
|
||||
|
||||
(table.search("puppy", query_type="hybrid").limit(42).to_list())
|
||||
|
||||
|
||||
def test_create_client():
|
||||
mandatory_args = {
|
||||
"uri": "db://dev",
|
||||
"api_key": "fake-api-key",
|
||||
"region": "us-east-1",
|
||||
}
|
||||
|
||||
db = lancedb.connect(**mandatory_args)
|
||||
assert isinstance(db.client_config, ClientConfig)
|
||||
|
||||
db = lancedb.connect(**mandatory_args, client_config={})
|
||||
assert isinstance(db.client_config, ClientConfig)
|
||||
|
||||
db = lancedb.connect(
|
||||
**mandatory_args,
|
||||
client_config=ClientConfig(timeout_config={"connect_timeout": 42}),
|
||||
)
|
||||
assert isinstance(db.client_config, ClientConfig)
|
||||
assert db.client_config.timeout_config.connect_timeout == timedelta(seconds=42)
|
||||
|
||||
db = lancedb.connect(
|
||||
**mandatory_args,
|
||||
client_config={"timeout_config": {"connect_timeout": timedelta(seconds=42)}},
|
||||
)
|
||||
assert isinstance(db.client_config, ClientConfig)
|
||||
assert db.client_config.timeout_config.connect_timeout == timedelta(seconds=42)
|
||||
|
||||
db = lancedb.connect(
|
||||
**mandatory_args, client_config=ClientConfig(retry_config={"retries": 42})
|
||||
)
|
||||
assert isinstance(db.client_config, ClientConfig)
|
||||
assert db.client_config.retry_config.retries == 42
|
||||
|
||||
db = lancedb.connect(
|
||||
**mandatory_args, client_config={"retry_config": {"retries": 42}}
|
||||
)
|
||||
assert isinstance(db.client_config, ClientConfig)
|
||||
assert db.client_config.retry_config.retries == 42
|
||||
|
||||
with pytest.warns(DeprecationWarning):
|
||||
db = lancedb.connect(**mandatory_args, connection_timeout=42)
|
||||
assert db.client_config.timeout_config.connect_timeout == timedelta(seconds=42)
|
||||
|
||||
with pytest.warns(DeprecationWarning):
|
||||
db = lancedb.connect(**mandatory_args, read_timeout=42)
|
||||
assert db.client_config.timeout_config.read_timeout == timedelta(seconds=42)
|
||||
|
||||
with pytest.warns(DeprecationWarning):
|
||||
lancedb.connect(**mandatory_args, request_thread_pool=10)
|
||||
|
||||
@@ -170,6 +170,17 @@ impl Connection {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn rename_table(
|
||||
self_: PyRef<'_, Self>,
|
||||
old_name: String,
|
||||
new_name: String,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.get_inner()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
inner.rename_table(old_name, new_name).await.infer_error()
|
||||
})
|
||||
}
|
||||
|
||||
pub fn drop_table(self_: PyRef<'_, Self>, name: String) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.get_inner()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
|
||||
@@ -24,8 +24,8 @@ use lancedb::{
|
||||
DistanceType,
|
||||
};
|
||||
use pyo3::{
|
||||
exceptions::{PyRuntimeError, PyValueError},
|
||||
pyclass, pymethods, PyResult,
|
||||
exceptions::{PyKeyError, PyRuntimeError, PyValueError},
|
||||
pyclass, pymethods, IntoPy, PyObject, PyResult, Python,
|
||||
};
|
||||
|
||||
use crate::util::parse_distance_type;
|
||||
@@ -236,7 +236,21 @@ pub struct IndexConfig {
|
||||
#[pymethods]
|
||||
impl IndexConfig {
|
||||
pub fn __repr__(&self) -> String {
|
||||
format!("Index({}, columns={:?})", self.index_type, self.columns)
|
||||
format!(
|
||||
"Index({}, columns={:?}, name=\"{}\")",
|
||||
self.index_type, self.columns, self.name
|
||||
)
|
||||
}
|
||||
|
||||
// For backwards-compatibility with the old sync SDK, we also support getting
|
||||
// attributes via __getitem__.
|
||||
pub fn __getitem__(&self, key: String, py: Python<'_>) -> PyResult<PyObject> {
|
||||
match key.as_str() {
|
||||
"index_type" => Ok(self.index_type.clone().into_py(py)),
|
||||
"columns" => Ok(self.columns.clone().into_py(py)),
|
||||
"name" | "index_name" => Ok(self.name.clone().into_py(py)),
|
||||
_ => Err(PyKeyError::new_err(format!("Invalid key: {}", key))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -68,6 +68,18 @@ impl Query {
|
||||
self.inner = self.inner.clone().offset(offset as usize);
|
||||
}
|
||||
|
||||
pub fn fast_search(&mut self) {
|
||||
self.inner = self.inner.clone().fast_search();
|
||||
}
|
||||
|
||||
pub fn with_row_id(&mut self) {
|
||||
self.inner = self.inner.clone().with_row_id();
|
||||
}
|
||||
|
||||
pub fn postfilter(&mut self) {
|
||||
self.inner = self.inner.clone().postfilter();
|
||||
}
|
||||
|
||||
pub fn nearest_to(&mut self, vector: Bound<'_, PyAny>) -> PyResult<VectorQuery> {
|
||||
let data: ArrayData = ArrayData::from_pyarrow_bound(&vector)?;
|
||||
let array = make_array(data);
|
||||
@@ -146,6 +158,14 @@ impl VectorQuery {
|
||||
self.inner = self.inner.clone().offset(offset as usize);
|
||||
}
|
||||
|
||||
pub fn fast_search(&mut self) {
|
||||
self.inner = self.inner.clone().fast_search();
|
||||
}
|
||||
|
||||
pub fn with_row_id(&mut self) {
|
||||
self.inner = self.inner.clone().with_row_id();
|
||||
}
|
||||
|
||||
pub fn column(&mut self, column: String) {
|
||||
self.inner = self.inner.clone().column(&column);
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "lancedb-node"
|
||||
version = "0.11.1-beta.1"
|
||||
version = "0.13.0-beta.0"
|
||||
description = "Serverless, low-latency vector database for AI applications"
|
||||
license.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "lancedb"
|
||||
version = "0.11.1-beta.1"
|
||||
version = "0.13.0-beta.0"
|
||||
edition.workspace = true
|
||||
description = "LanceDB: A serverless, low-latency vector database for AI applications"
|
||||
license.workspace = true
|
||||
|
||||
@@ -39,9 +39,6 @@ use crate::utils::validate_table_name;
|
||||
use crate::Table;
|
||||
pub use lance_encoding::version::LanceFileVersion;
|
||||
|
||||
#[cfg(feature = "remote")]
|
||||
use log::warn;
|
||||
|
||||
pub const LANCE_FILE_EXTENSION: &str = "lance";
|
||||
|
||||
pub type TableBuilderCallback = Box<dyn FnOnce(OpenTableBuilder) -> OpenTableBuilder + Send>;
|
||||
@@ -719,8 +716,7 @@ impl ConnectBuilder {
|
||||
let api_key = self.api_key.ok_or_else(|| Error::InvalidInput {
|
||||
message: "An api_key is required when connecting to LanceDb Cloud".to_string(),
|
||||
})?;
|
||||
// TODO: remove this warning when the remote client is ready
|
||||
warn!("The rust implementation of the remote client is not yet ready for use.");
|
||||
|
||||
let internal = Arc::new(crate::remote::db::RemoteDatabase::try_new(
|
||||
&self.uri,
|
||||
&api_key,
|
||||
|
||||
@@ -119,6 +119,7 @@ pub enum IndexType {
|
||||
#[serde(alias = "LABEL_LIST")]
|
||||
LabelList,
|
||||
// FTS
|
||||
#[serde(alias = "INVERTED", alias = "Inverted")]
|
||||
FTS,
|
||||
}
|
||||
|
||||
|
||||
@@ -403,6 +403,26 @@ pub trait QueryBase {
|
||||
/// By default, it is false.
|
||||
fn fast_search(self) -> Self;
|
||||
|
||||
/// If this is called then filtering will happen after the vector search instead of
|
||||
/// before.
|
||||
///
|
||||
/// By default filtering will be performed before the vector search. This is how
|
||||
/// filtering is typically understood to work. This prefilter step does add some
|
||||
/// additional latency. Creating a scalar index on the filter column(s) can
|
||||
/// often improve this latency. However, sometimes a filter is too complex or scalar
|
||||
/// indices cannot be applied to the column. In these cases postfiltering can be
|
||||
/// used instead of prefiltering to improve latency.
|
||||
///
|
||||
/// Post filtering applies the filter to the results of the vector search. This means
|
||||
/// we only run the filter on a much smaller set of data. However, it can cause the
|
||||
/// query to return fewer than `limit` results (or even no results) if none of the nearest
|
||||
/// results match the filter.
|
||||
///
|
||||
/// Post filtering happens during the "refine stage" (described in more detail in
|
||||
/// [`Self::refine_factor`]). This means that setting a higher refine factor can often
|
||||
/// help restore some of the results lost by post filtering.
|
||||
fn postfilter(self) -> Self;
|
||||
|
||||
/// Return the `_rowid` meta column from the Table.
|
||||
fn with_row_id(self) -> Self;
|
||||
}
|
||||
@@ -442,6 +462,11 @@ impl<T: HasQuery> QueryBase for T {
|
||||
self
|
||||
}
|
||||
|
||||
fn postfilter(mut self) -> Self {
|
||||
self.mut_query().prefilter = false;
|
||||
self
|
||||
}
|
||||
|
||||
fn with_row_id(mut self) -> Self {
|
||||
self.mut_query().with_row_id = true;
|
||||
self
|
||||
@@ -561,6 +586,9 @@ pub struct Query {
|
||||
///
|
||||
/// By default, this is false.
|
||||
pub(crate) with_row_id: bool,
|
||||
|
||||
/// If set to false, the filter will be applied after the vector search.
|
||||
pub(crate) prefilter: bool,
|
||||
}
|
||||
|
||||
impl Query {
|
||||
@@ -574,6 +602,7 @@ impl Query {
|
||||
select: Select::All,
|
||||
fast_search: false,
|
||||
with_row_id: false,
|
||||
prefilter: true,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -678,8 +707,6 @@ pub struct VectorQuery {
|
||||
pub(crate) distance_type: Option<DistanceType>,
|
||||
/// Default is true. Set to false to enforce a brute force search.
|
||||
pub(crate) use_index: bool,
|
||||
/// Apply filter before ANN search/
|
||||
pub(crate) prefilter: bool,
|
||||
}
|
||||
|
||||
impl VectorQuery {
|
||||
@@ -692,7 +719,6 @@ impl VectorQuery {
|
||||
refine_factor: None,
|
||||
distance_type: None,
|
||||
use_index: true,
|
||||
prefilter: true,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -782,29 +808,6 @@ impl VectorQuery {
|
||||
self
|
||||
}
|
||||
|
||||
/// If this is called then filtering will happen after the vector search instead of
|
||||
/// before.
|
||||
///
|
||||
/// By default filtering will be performed before the vector search. This is how
|
||||
/// filtering is typically understood to work. This prefilter step does add some
|
||||
/// additional latency. Creating a scalar index on the filter column(s) can
|
||||
/// often improve this latency. However, sometimes a filter is too complex or scalar
|
||||
/// indices cannot be applied to the column. In these cases postfiltering can be
|
||||
/// used instead of prefiltering to improve latency.
|
||||
///
|
||||
/// Post filtering applies the filter to the results of the vector search. This means
|
||||
/// we only run the filter on a much smaller set of data. However, it can cause the
|
||||
/// query to return fewer than `limit` results (or even no results) if none of the nearest
|
||||
/// results match the filter.
|
||||
///
|
||||
/// Post filtering happens during the "refine stage" (described in more detail in
|
||||
/// [`Self::refine_factor`]). This means that setting a higher refine factor can often
|
||||
/// help restore some of the results lost by post filtering.
|
||||
pub fn postfilter(mut self) -> Self {
|
||||
self.prefilter = false;
|
||||
self
|
||||
}
|
||||
|
||||
/// If this is called then any vector index is skipped
|
||||
///
|
||||
/// An exhaustive (flat) search will be performed. The query vector will
|
||||
|
||||
@@ -23,6 +23,8 @@ pub(crate) mod table;
|
||||
pub(crate) mod util;
|
||||
|
||||
const ARROW_STREAM_CONTENT_TYPE: &str = "application/vnd.apache.arrow.stream";
|
||||
#[cfg(test)]
|
||||
const ARROW_FILE_CONTENT_TYPE: &str = "application/vnd.apache.arrow.file";
|
||||
const JSON_CONTENT_TYPE: &str = "application/json";
|
||||
|
||||
pub use client::{ClientConfig, RetryConfig, TimeoutConfig};
|
||||
|
||||
@@ -341,7 +341,22 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
|
||||
request_id
|
||||
};
|
||||
|
||||
debug!("Sending request_id={}: {:?}", request_id, &request);
|
||||
if log::log_enabled!(log::Level::Debug) {
|
||||
let content_type = request
|
||||
.headers()
|
||||
.get("content-type")
|
||||
.map(|v| v.to_str().unwrap());
|
||||
if content_type == Some("application/json") {
|
||||
let body = request.body().as_ref().unwrap().as_bytes().unwrap();
|
||||
let body = String::from_utf8_lossy(body);
|
||||
debug!(
|
||||
"Sending request_id={}: {:?} with body {}",
|
||||
request_id, request, body
|
||||
);
|
||||
} else {
|
||||
debug!("Sending request_id={}: {:?}", request_id, request);
|
||||
}
|
||||
}
|
||||
|
||||
if with_retry {
|
||||
self.send_with_retry_impl(client, request, request_id).await
|
||||
|
||||
@@ -161,7 +161,7 @@ impl<S: HttpSend> ConnectionInternal for RemoteDatabase<S> {
|
||||
if self.table_cache.get(&options.name).is_none() {
|
||||
let req = self
|
||||
.client
|
||||
.get(&format!("/v1/table/{}/describe/", options.name));
|
||||
.post(&format!("/v1/table/{}/describe/", options.name));
|
||||
let (request_id, resp) = self.client.send(req, true).await?;
|
||||
if resp.status() == StatusCode::NOT_FOUND {
|
||||
return Err(crate::Error::TableNotFound { name: options.name });
|
||||
@@ -301,7 +301,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_open_table() {
|
||||
let conn = Connection::new_with_handler(|request| {
|
||||
assert_eq!(request.method(), &reqwest::Method::GET);
|
||||
assert_eq!(request.method(), &reqwest::Method::POST);
|
||||
assert_eq!(request.url().path(), "/v1/table/table1/describe/");
|
||||
assert_eq!(request.url().query(), None);
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use std::io::Cursor;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use crate::index::Index;
|
||||
@@ -7,10 +8,9 @@ use crate::table::AddDataMode;
|
||||
use crate::utils::{supported_btree_data_type, supported_vector_data_type};
|
||||
use crate::Error;
|
||||
use arrow_array::RecordBatchReader;
|
||||
use arrow_ipc::reader::StreamReader;
|
||||
use arrow_ipc::reader::FileReader;
|
||||
use arrow_schema::{DataType, SchemaRef};
|
||||
use async_trait::async_trait;
|
||||
use bytes::Buf;
|
||||
use datafusion_common::DataFusionError;
|
||||
use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
|
||||
use datafusion_physical_plan::{ExecutionPlan, SendableRecordBatchStream};
|
||||
@@ -115,39 +115,14 @@ impl<S: HttpSend> RemoteTable<S> {
|
||||
async fn read_arrow_stream(
|
||||
&self,
|
||||
request_id: &str,
|
||||
body: reqwest::Response,
|
||||
response: reqwest::Response,
|
||||
) -> Result<SendableRecordBatchStream> {
|
||||
// Assert that the content type is correct
|
||||
let content_type = body
|
||||
.headers()
|
||||
.get(CONTENT_TYPE)
|
||||
.ok_or_else(|| Error::Http {
|
||||
source: "Missing content type".into(),
|
||||
request_id: request_id.to_string(),
|
||||
status_code: None,
|
||||
})?
|
||||
.to_str()
|
||||
.map_err(|e| Error::Http {
|
||||
source: format!("Failed to parse content type: {}", e).into(),
|
||||
request_id: request_id.to_string(),
|
||||
status_code: None,
|
||||
})?;
|
||||
if content_type != ARROW_STREAM_CONTENT_TYPE {
|
||||
return Err(Error::Http {
|
||||
source: format!(
|
||||
"Expected content type {}, got {}",
|
||||
ARROW_STREAM_CONTENT_TYPE, content_type
|
||||
)
|
||||
.into(),
|
||||
request_id: request_id.to_string(),
|
||||
status_code: None,
|
||||
});
|
||||
}
|
||||
let response = self.check_table_response(request_id, response).await?;
|
||||
|
||||
// There isn't a way to actually stream this data yet. I have an upstream issue:
|
||||
// https://github.com/apache/arrow-rs/issues/6420
|
||||
let body = body.bytes().await.err_to_http(request_id.into())?;
|
||||
let reader = StreamReader::try_new(body.reader(), None)?;
|
||||
let body = response.bytes().await.err_to_http(request_id.into())?;
|
||||
let reader = FileReader::try_new(Cursor::new(body), None)?;
|
||||
let schema = reader.schema();
|
||||
let stream = futures::stream::iter(reader).map_err(DataFusionError::from);
|
||||
Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
|
||||
@@ -192,6 +167,10 @@ impl<S: HttpSend> RemoteTable<S> {
|
||||
body["fast_search"] = serde_json::Value::Bool(true);
|
||||
}
|
||||
|
||||
if params.with_row_id {
|
||||
body["with_row_id"] = serde_json::Value::Bool(true);
|
||||
}
|
||||
|
||||
if let Some(full_text_search) = ¶ms.full_text_search {
|
||||
if full_text_search.wand_factor.is_some() {
|
||||
return Err(Error::NotSupported {
|
||||
@@ -277,7 +256,7 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
|
||||
.post(&format!("/v1/table/{}/count_rows/", self.name));
|
||||
|
||||
if let Some(filter) = filter {
|
||||
request = request.json(&serde_json::json!({ "filter": filter }));
|
||||
request = request.json(&serde_json::json!({ "predicate": filter }));
|
||||
} else {
|
||||
request = request.json(&serde_json::json!({}));
|
||||
}
|
||||
@@ -330,13 +309,13 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
|
||||
let mut body = serde_json::Value::Object(Default::default());
|
||||
Self::apply_query_params(&mut body, &query.base)?;
|
||||
|
||||
body["prefilter"] = query.prefilter.into();
|
||||
body["prefilter"] = query.base.prefilter.into();
|
||||
body["distance_type"] = serde_json::json!(query.distance_type.unwrap_or_default());
|
||||
body["nprobes"] = query.nprobes.into();
|
||||
body["refine_factor"] = query.refine_factor.into();
|
||||
|
||||
if let Some(vector) = query.query_vector.as_ref() {
|
||||
let vector: Vec<f32> = match vector.data_type() {
|
||||
let vector: Vec<f32> = if let Some(vector) = query.query_vector.as_ref() {
|
||||
match vector.data_type() {
|
||||
DataType::Float32 => vector
|
||||
.as_any()
|
||||
.downcast_ref::<arrow_array::Float32Array>()
|
||||
@@ -350,9 +329,12 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
|
||||
message: "VectorQuery vector must be of type Float32".into(),
|
||||
})
|
||||
}
|
||||
};
|
||||
body["vector"] = serde_json::json!(vector);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Server takes empty vector, not null or undefined.
|
||||
Vec::new()
|
||||
};
|
||||
body["vector"] = serde_json::json!(vector);
|
||||
|
||||
if let Some(vector_column) = query.column.as_ref() {
|
||||
body["vector_column"] = serde_json::Value::String(vector_column.clone());
|
||||
@@ -383,6 +365,8 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
|
||||
|
||||
let mut body = serde_json::Value::Object(Default::default());
|
||||
Self::apply_query_params(&mut body, query)?;
|
||||
// Empty vector can be passed if no vector search is performed.
|
||||
body["vector"] = serde_json::Value::Array(Vec::new());
|
||||
|
||||
let request = request.json(&body);
|
||||
|
||||
@@ -399,30 +383,19 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
|
||||
|
||||
let mut updates = Vec::new();
|
||||
for (column, expression) in update.columns {
|
||||
updates.push(column);
|
||||
updates.push(expression);
|
||||
updates.push(vec![column, expression]);
|
||||
}
|
||||
|
||||
let request = request.json(&serde_json::json!({
|
||||
"updates": updates,
|
||||
"only_if": update.filter,
|
||||
"predicate": update.filter,
|
||||
}));
|
||||
|
||||
let (request_id, response) = self.client.send(request, false).await?;
|
||||
|
||||
let response = self.check_table_response(&request_id, response).await?;
|
||||
self.check_table_response(&request_id, response).await?;
|
||||
|
||||
let body = response.text().await.err_to_http(request_id.clone())?;
|
||||
|
||||
serde_json::from_str(&body).map_err(|e| Error::Http {
|
||||
source: format!(
|
||||
"Failed to parse updated rows result from response {}: {}",
|
||||
body, e
|
||||
)
|
||||
.into(),
|
||||
request_id,
|
||||
status_code: None,
|
||||
})
|
||||
Ok(0) // TODO: support returning number of modified rows once supported in SaaS.
|
||||
}
|
||||
async fn delete(&self, predicate: &str) -> Result<()> {
|
||||
let body = serde_json::json!({ "predicate": predicate });
|
||||
@@ -691,6 +664,7 @@ mod tests {
|
||||
use crate::{
|
||||
index::{vector::IvfPqIndexBuilder, Index, IndexStatistics, IndexType},
|
||||
query::{ExecutableQuery, QueryBase},
|
||||
remote::ARROW_FILE_CONTENT_TYPE,
|
||||
DistanceType, Error, Table,
|
||||
};
|
||||
|
||||
@@ -804,7 +778,7 @@ mod tests {
|
||||
);
|
||||
assert_eq!(
|
||||
request.body().unwrap().as_bytes().unwrap(),
|
||||
br#"{"filter":"a > 10"}"#
|
||||
br#"{"predicate":"a > 10"}"#
|
||||
);
|
||||
|
||||
http::Response::builder().status(200).body("42").unwrap()
|
||||
@@ -839,6 +813,17 @@ mod tests {
|
||||
body
|
||||
}
|
||||
|
||||
fn write_ipc_file(data: &RecordBatch) -> Vec<u8> {
|
||||
let mut body = Vec::new();
|
||||
{
|
||||
let mut writer = arrow_ipc::writer::FileWriter::try_new(&mut body, &data.schema())
|
||||
.expect("Failed to create writer");
|
||||
writer.write(data).expect("Failed to write data");
|
||||
writer.finish().expect("Failed to finish");
|
||||
}
|
||||
body
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_add_append() {
|
||||
let data = RecordBatch::try_new(
|
||||
@@ -947,21 +932,27 @@ mod tests {
|
||||
let updates = value.get("updates").unwrap().as_array().unwrap();
|
||||
assert!(updates.len() == 2);
|
||||
|
||||
let col_name = updates[0].as_str().unwrap();
|
||||
let expression = updates[1].as_str().unwrap();
|
||||
let col_name = updates[0][0].as_str().unwrap();
|
||||
let expression = updates[0][1].as_str().unwrap();
|
||||
assert_eq!(col_name, "a");
|
||||
assert_eq!(expression, "a + 1");
|
||||
|
||||
let only_if = value.get("only_if").unwrap().as_str().unwrap();
|
||||
let col_name = updates[1][0].as_str().unwrap();
|
||||
let expression = updates[1][1].as_str().unwrap();
|
||||
assert_eq!(col_name, "b");
|
||||
assert_eq!(expression, "b - 1");
|
||||
|
||||
let only_if = value.get("predicate").unwrap().as_str().unwrap();
|
||||
assert_eq!(only_if, "b > 10");
|
||||
}
|
||||
|
||||
http::Response::builder().status(200).body("1").unwrap()
|
||||
http::Response::builder().status(200).body("{}").unwrap()
|
||||
});
|
||||
|
||||
table
|
||||
.update()
|
||||
.column("a", "a + 1")
|
||||
.column("b", "b - 1")
|
||||
.only_if("b > 10")
|
||||
.execute()
|
||||
.await
|
||||
@@ -1092,10 +1083,10 @@ mod tests {
|
||||
expected_body["vector"] = vec![0.1f32, 0.2, 0.3].into();
|
||||
assert_eq!(body, expected_body);
|
||||
|
||||
let response_body = write_ipc_stream(&expected_data_ref);
|
||||
let response_body = write_ipc_file(&expected_data_ref);
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE)
|
||||
.header(CONTENT_TYPE, ARROW_FILE_CONTENT_TYPE)
|
||||
.body(response_body)
|
||||
.unwrap()
|
||||
});
|
||||
@@ -1142,10 +1133,10 @@ mod tests {
|
||||
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
|
||||
)
|
||||
.unwrap();
|
||||
let response_body = write_ipc_stream(&data);
|
||||
let response_body = write_ipc_file(&data);
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE)
|
||||
.header(CONTENT_TYPE, ARROW_FILE_CONTENT_TYPE)
|
||||
.body(response_body)
|
||||
.unwrap()
|
||||
});
|
||||
@@ -1185,6 +1176,8 @@ mod tests {
|
||||
"query": "hello world",
|
||||
},
|
||||
"k": 10,
|
||||
"vector": [],
|
||||
"with_row_id": true,
|
||||
});
|
||||
assert_eq!(body, expected_body);
|
||||
|
||||
@@ -1193,10 +1186,10 @@ mod tests {
|
||||
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
|
||||
)
|
||||
.unwrap();
|
||||
let response_body = write_ipc_stream(&data);
|
||||
let response_body = write_ipc_file(&data);
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE)
|
||||
.header(CONTENT_TYPE, ARROW_FILE_CONTENT_TYPE)
|
||||
.body(response_body)
|
||||
.unwrap()
|
||||
});
|
||||
@@ -1207,6 +1200,7 @@ mod tests {
|
||||
FullTextSearchQuery::new("hello world".into())
|
||||
.columns(Some(vec!["a".into(), "b".into()])),
|
||||
)
|
||||
.with_row_id()
|
||||
.limit(10)
|
||||
.execute()
|
||||
.await
|
||||
|
||||
@@ -1842,7 +1842,7 @@ impl TableInternal for NativeTable {
|
||||
|
||||
scanner.nprobs(query.nprobes);
|
||||
scanner.use_index(query.use_index);
|
||||
scanner.prefilter(query.prefilter);
|
||||
scanner.prefilter(query.base.prefilter);
|
||||
match query.base.select {
|
||||
Select::Columns(ref columns) => {
|
||||
scanner.project(columns.as_slice())?;
|
||||
@@ -3123,6 +3123,12 @@ mod tests {
|
||||
assert_eq!(index.index_type, crate::index::IndexType::FTS);
|
||||
assert_eq!(index.columns, vec!["text".to_string()]);
|
||||
assert_eq!(index.name, "text_idx");
|
||||
|
||||
let stats = table.index_stats("text_idx").await.unwrap().unwrap();
|
||||
assert_eq!(stats.num_indexed_rows, num_rows);
|
||||
assert_eq!(stats.num_unindexed_rows, 0);
|
||||
assert_eq!(stats.index_type, crate::index::IndexType::FTS);
|
||||
assert_eq!(stats.distance_type, None);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
||||
Reference in New Issue
Block a user