mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-23 21:39:57 +00:00
Compare commits
22 Commits
qian@saas-
...
python-v0.
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a0608044a1 | ||
|
|
2e4ea7d2bc | ||
|
|
57e5695a54 | ||
|
|
ce58ea7c38 | ||
|
|
57207eff4a | ||
|
|
2d78bff120 | ||
|
|
7c09b9b9a9 | ||
|
|
bd0034a157 | ||
|
|
144b3b5d83 | ||
|
|
b6f0a31686 | ||
|
|
9ec526f73f | ||
|
|
600bfd7237 | ||
|
|
d087e7891d | ||
|
|
098e397cf0 | ||
|
|
63ee8fa6a1 | ||
|
|
693091db29 | ||
|
|
dca4533dbe | ||
|
|
f6bbe199dc | ||
|
|
366e522c2b | ||
|
|
244b6919cc | ||
|
|
aca785ff98 | ||
|
|
bbdebf2c38 |
@@ -1,5 +1,5 @@
|
||||
[bumpversion]
|
||||
current_version = 0.3.9
|
||||
current_version = 0.3.11
|
||||
commit = True
|
||||
message = Bump version: {current_version} → {new_version}
|
||||
tag = True
|
||||
|
||||
20
.github/workflows/npm-publish.yml
vendored
20
.github/workflows/npm-publish.yml
vendored
@@ -38,13 +38,17 @@ jobs:
|
||||
node/vectordb-*.tgz
|
||||
|
||||
node-macos:
|
||||
runs-on: macos-13
|
||||
strategy:
|
||||
matrix:
|
||||
config:
|
||||
- arch: x86_64-apple-darwin
|
||||
runner: macos-13
|
||||
- arch: aarch64-apple-darwin
|
||||
# xlarge is implicitly arm64.
|
||||
runner: macos-13-xlarge
|
||||
runs-on: ${{ matrix.config.runner }}
|
||||
# Only runs on tags that matches the make-release action
|
||||
if: startsWith(github.ref, 'refs/tags/v')
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
target: [x86_64-apple-darwin, aarch64-apple-darwin]
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
@@ -54,11 +58,8 @@ jobs:
|
||||
run: |
|
||||
cd node
|
||||
npm ci
|
||||
- name: Install rustup target
|
||||
if: ${{ matrix.target == 'aarch64-apple-darwin' }}
|
||||
run: rustup target add aarch64-apple-darwin
|
||||
- name: Build MacOS native node modules
|
||||
run: bash ci/build_macos_artifacts.sh ${{ matrix.target }}
|
||||
run: bash ci/build_macos_artifacts.sh ${{ matrix.config.arch }}
|
||||
- name: Upload Darwin Artifacts
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
@@ -66,6 +67,7 @@ jobs:
|
||||
path: |
|
||||
node/dist/lancedb-vectordb-darwin*.tgz
|
||||
|
||||
|
||||
node-linux:
|
||||
name: node-linux (${{ matrix.config.arch}}-unknown-linux-gnu
|
||||
runs-on: ${{ matrix.config.runner }}
|
||||
|
||||
6
.github/workflows/python.yml
vendored
6
.github/workflows/python.yml
vendored
@@ -91,11 +91,7 @@ jobs:
|
||||
pip install "pydantic<2"
|
||||
pip install -e .[tests]
|
||||
pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985
|
||||
pip install pytest pytest-mock black isort
|
||||
- name: Black
|
||||
run: black --check --diff --no-color --quiet .
|
||||
- name: isort
|
||||
run: isort --check --diff --quiet .
|
||||
pip install pytest pytest-mock
|
||||
- name: Run tests
|
||||
run: pytest -m "not slow" -x -v --durations=30 tests
|
||||
- name: doctest
|
||||
|
||||
@@ -5,10 +5,10 @@ exclude = ["python"]
|
||||
resolver = "2"
|
||||
|
||||
[workspace.dependencies]
|
||||
lance = { "version" = "=0.8.17", "features" = ["dynamodb"] }
|
||||
lance-index = { "version" = "=0.8.17" }
|
||||
lance-linalg = { "version" = "=0.8.17" }
|
||||
lance-testing = { "version" = "=0.8.17" }
|
||||
lance = { "version" = "=0.8.20", "features" = ["dynamodb"] }
|
||||
lance-index = { "version" = "=0.8.20" }
|
||||
lance-linalg = { "version" = "=0.8.20" }
|
||||
lance-testing = { "version" = "=0.8.20" }
|
||||
# Note that this one does not include pyarrow
|
||||
arrow = { version = "47.0.0", optional = false }
|
||||
arrow-array = "47.0"
|
||||
|
||||
@@ -5,10 +5,11 @@
|
||||
|
||||
**Developer-friendly, serverless vector database for AI applications**
|
||||
|
||||
<a href="https://lancedb.github.io/lancedb/">Documentation</a> •
|
||||
<a href="https://blog.lancedb.com/">Blog</a> •
|
||||
<a href="https://discord.gg/zMM32dvNtd">Discord</a> •
|
||||
<a href="https://twitter.com/lancedb">Twitter</a>
|
||||
<a href='https://github.com/lancedb/vectordb-recipes/tree/main' target="_blank"><img alt='LanceDB' src='https://img.shields.io/badge/VectorDB_Recipes-100000?style=for-the-badge&logo=LanceDB&logoColor=white&labelColor=645cfb&color=645cfb'/></a>
|
||||
<a href='https://lancedb.github.io/lancedb/' target="_blank"><img alt='lancdb' src='https://img.shields.io/badge/DOCS-100000?style=for-the-badge&logo=lancdb&logoColor=white&labelColor=645cfb&color=645cfb'/></a>
|
||||
[](https://blog.lancedb.com/)
|
||||
[](https://discord.gg/zMM32dvNtd)
|
||||
[](https://twitter.com/lancedb)
|
||||
|
||||
</p>
|
||||
|
||||
|
||||
@@ -80,7 +80,6 @@ nav:
|
||||
- Ingest Embedding Functions: embeddings/embedding_functions.md
|
||||
- Available Functions: embeddings/default_embedding_functions.md
|
||||
- Create Custom Embedding Functions: embeddings/api.md
|
||||
- Example - Calculate CLIP Embeddings with Roboflow Inference: examples/image_embeddings_roboflow.md
|
||||
- Example - Multi-lingual semantic search: notebooks/multi_lingual_example.ipynb
|
||||
- Example - MultiModal CLIP Embeddings: notebooks/DisappearingEmbeddingFunction.ipynb
|
||||
- 🔍 Python full-text search: fts.md
|
||||
@@ -99,6 +98,7 @@ nav:
|
||||
- YouTube Transcript Search: notebooks/youtube_transcript_search.ipynb
|
||||
- Documentation QA Bot using LangChain: notebooks/code_qa_bot.ipynb
|
||||
- Multimodal search using CLIP: notebooks/multimodal_search.ipynb
|
||||
- Example - Calculate CLIP Embeddings with Roboflow Inference: examples/image_embeddings_roboflow.md
|
||||
- Serverless QA Bot with S3 and Lambda: examples/serverless_lancedb_with_s3_and_lambda.md
|
||||
- Serverless QA Bot with Modal: examples/serverless_qa_bot_with_modal_and_langchain.md
|
||||
- 🌐 Javascript examples:
|
||||
@@ -146,7 +146,8 @@ nav:
|
||||
- Serverless Chatbot from any website: examples/serverless_website_chatbot.md
|
||||
- TransformersJS Embedding Search: examples/transformerjs_embedding_search_nodejs.md
|
||||
- API references:
|
||||
- Python API: python/python.md
|
||||
- OSS Python API: python/python.md
|
||||
- SaaS Python API: python/saas-python.md
|
||||
- Javascript API: javascript/modules.md
|
||||
- LanceDB Cloud↗: https://noteforms.com/forms/lancedb-mailing-list-cloud-kty1o5?notionforms=1&utm_source=notionforms
|
||||
|
||||
|
||||
18
docs/src/python/saas-python.md
Normal file
18
docs/src/python/saas-python.md
Normal file
@@ -0,0 +1,18 @@
|
||||
# LanceDB Python API Reference
|
||||
|
||||
## Installation
|
||||
|
||||
```shell
|
||||
pip install lancedb
|
||||
```
|
||||
|
||||
## Connection
|
||||
|
||||
::: lancedb.connect
|
||||
|
||||
::: lancedb.remote.db.RemoteDBConnection
|
||||
|
||||
## Table
|
||||
|
||||
::: lancedb.remote.table.RemoteTable
|
||||
|
||||
86
node/package-lock.json
generated
86
node/package-lock.json
generated
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"name": "vectordb",
|
||||
"version": "0.3.9",
|
||||
"version": "0.3.11",
|
||||
"lockfileVersion": 2,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "vectordb",
|
||||
"version": "0.3.9",
|
||||
"version": "0.3.11",
|
||||
"cpu": [
|
||||
"x64",
|
||||
"arm64"
|
||||
@@ -53,11 +53,11 @@
|
||||
"uuid": "^9.0.0"
|
||||
},
|
||||
"optionalDependencies": {
|
||||
"@lancedb/vectordb-darwin-arm64": "0.3.9",
|
||||
"@lancedb/vectordb-darwin-x64": "0.3.9",
|
||||
"@lancedb/vectordb-linux-arm64-gnu": "0.3.9",
|
||||
"@lancedb/vectordb-linux-x64-gnu": "0.3.9",
|
||||
"@lancedb/vectordb-win32-x64-msvc": "0.3.9"
|
||||
"@lancedb/vectordb-darwin-arm64": "0.3.11",
|
||||
"@lancedb/vectordb-darwin-x64": "0.3.11",
|
||||
"@lancedb/vectordb-linux-arm64-gnu": "0.3.11",
|
||||
"@lancedb/vectordb-linux-x64-gnu": "0.3.11",
|
||||
"@lancedb/vectordb-win32-x64-msvc": "0.3.11"
|
||||
}
|
||||
},
|
||||
"node_modules/@apache-arrow/ts": {
|
||||
@@ -316,54 +316,6 @@
|
||||
"@jridgewell/sourcemap-codec": "^1.4.10"
|
||||
}
|
||||
},
|
||||
"node_modules/@lancedb/vectordb-darwin-x64": {
|
||||
"version": "0.3.9",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.3.9.tgz",
|
||||
"integrity": "sha512-4xXQoPheyIl1P5kRoKmZtaAHFrYdL9pw5yq+r6ewIx0TCemN4LSvzSUTqM5nZl3QPU8FeL0CGD8Gt2gMU0HQ2A==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
"optional": true,
|
||||
"os": [
|
||||
"darwin"
|
||||
]
|
||||
},
|
||||
"node_modules/@lancedb/vectordb-linux-arm64-gnu": {
|
||||
"version": "0.3.9",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.3.9.tgz",
|
||||
"integrity": "sha512-WIxCZKnLeSlz0PGURtKSX6hJ4CYE2o5P+IFmmuWOWB1uNapQu6zOpea6rNxcRFHUA0IJdO02lVxVfn2hDX4SMg==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
"optional": true,
|
||||
"os": [
|
||||
"linux"
|
||||
]
|
||||
},
|
||||
"node_modules/@lancedb/vectordb-linux-x64-gnu": {
|
||||
"version": "0.3.9",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.3.9.tgz",
|
||||
"integrity": "sha512-bQbcV9adKzYbJLNzDjk9OYsMnT2IjmieLfb4IQ1hj5IUoWfbg80Bd0+gZUnrmrhG6fe56TIriFZYQR9i7TSE9Q==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
"optional": true,
|
||||
"os": [
|
||||
"linux"
|
||||
]
|
||||
},
|
||||
"node_modules/@lancedb/vectordb-win32-x64-msvc": {
|
||||
"version": "0.3.9",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.3.9.tgz",
|
||||
"integrity": "sha512-7EXI7P1QvAfgJNPWWBMDOkoJ696gSBAClcyEJNYg0JV21jVFZRwJVI3bZXflesWduFi/mTuzPkFFA68us1u19A==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
"optional": true,
|
||||
"os": [
|
||||
"win32"
|
||||
]
|
||||
},
|
||||
"node_modules/@neon-rs/cli": {
|
||||
"version": "0.0.160",
|
||||
"resolved": "https://registry.npmjs.org/@neon-rs/cli/-/cli-0.0.160.tgz",
|
||||
@@ -4856,30 +4808,6 @@
|
||||
"@jridgewell/sourcemap-codec": "^1.4.10"
|
||||
}
|
||||
},
|
||||
"@lancedb/vectordb-darwin-x64": {
|
||||
"version": "0.3.9",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.3.9.tgz",
|
||||
"integrity": "sha512-4xXQoPheyIl1P5kRoKmZtaAHFrYdL9pw5yq+r6ewIx0TCemN4LSvzSUTqM5nZl3QPU8FeL0CGD8Gt2gMU0HQ2A==",
|
||||
"optional": true
|
||||
},
|
||||
"@lancedb/vectordb-linux-arm64-gnu": {
|
||||
"version": "0.3.9",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.3.9.tgz",
|
||||
"integrity": "sha512-WIxCZKnLeSlz0PGURtKSX6hJ4CYE2o5P+IFmmuWOWB1uNapQu6zOpea6rNxcRFHUA0IJdO02lVxVfn2hDX4SMg==",
|
||||
"optional": true
|
||||
},
|
||||
"@lancedb/vectordb-linux-x64-gnu": {
|
||||
"version": "0.3.9",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.3.9.tgz",
|
||||
"integrity": "sha512-bQbcV9adKzYbJLNzDjk9OYsMnT2IjmieLfb4IQ1hj5IUoWfbg80Bd0+gZUnrmrhG6fe56TIriFZYQR9i7TSE9Q==",
|
||||
"optional": true
|
||||
},
|
||||
"@lancedb/vectordb-win32-x64-msvc": {
|
||||
"version": "0.3.9",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.3.9.tgz",
|
||||
"integrity": "sha512-7EXI7P1QvAfgJNPWWBMDOkoJ696gSBAClcyEJNYg0JV21jVFZRwJVI3bZXflesWduFi/mTuzPkFFA68us1u19A==",
|
||||
"optional": true
|
||||
},
|
||||
"@neon-rs/cli": {
|
||||
"version": "0.0.160",
|
||||
"resolved": "https://registry.npmjs.org/@neon-rs/cli/-/cli-0.0.160.tgz",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "vectordb",
|
||||
"version": "0.3.9",
|
||||
"version": "0.3.11",
|
||||
"description": " Serverless, low-latency vector database for AI applications",
|
||||
"main": "dist/index.js",
|
||||
"types": "dist/index.d.ts",
|
||||
@@ -81,10 +81,10 @@
|
||||
}
|
||||
},
|
||||
"optionalDependencies": {
|
||||
"@lancedb/vectordb-darwin-arm64": "0.3.9",
|
||||
"@lancedb/vectordb-darwin-x64": "0.3.9",
|
||||
"@lancedb/vectordb-linux-arm64-gnu": "0.3.9",
|
||||
"@lancedb/vectordb-linux-x64-gnu": "0.3.9",
|
||||
"@lancedb/vectordb-win32-x64-msvc": "0.3.9"
|
||||
"@lancedb/vectordb-darwin-arm64": "0.3.11",
|
||||
"@lancedb/vectordb-darwin-x64": "0.3.11",
|
||||
"@lancedb/vectordb-linux-arm64-gnu": "0.3.11",
|
||||
"@lancedb/vectordb-linux-x64-gnu": "0.3.11",
|
||||
"@lancedb/vectordb-win32-x64-msvc": "0.3.11"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -21,9 +21,10 @@ import type { EmbeddingFunction } from './embedding/embedding_function'
|
||||
import { RemoteConnection } from './remote'
|
||||
import { Query } from './query'
|
||||
import { isEmbeddingFunction } from './embedding/embedding_function'
|
||||
import { type Literal, toSQL } from './util'
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-var-requires
|
||||
const { databaseNew, databaseTableNames, databaseOpenTable, databaseDropTable, tableCreate, tableAdd, tableCreateVectorIndex, tableCountRows, tableDelete, tableCleanupOldVersions, tableCompactFiles, tableListIndices, tableIndexStats } = require('../native.js')
|
||||
const { databaseNew, databaseTableNames, databaseOpenTable, databaseDropTable, tableCreate, tableAdd, tableCreateVectorIndex, tableCountRows, tableDelete, tableUpdate, tableCleanupOldVersions, tableCompactFiles, tableListIndices, tableIndexStats } = require('../native.js')
|
||||
|
||||
export { Query }
|
||||
export type { EmbeddingFunction }
|
||||
@@ -261,6 +262,39 @@ export interface Table<T = number[]> {
|
||||
*/
|
||||
delete: (filter: string) => Promise<void>
|
||||
|
||||
/**
|
||||
* Update rows in this table.
|
||||
*
|
||||
* This can be used to update a single row, many rows, all rows, or
|
||||
* sometimes no rows (if your predicate matches nothing).
|
||||
*
|
||||
* @param args see {@link UpdateArgs} and {@link UpdateSqlArgs} for more details
|
||||
*
|
||||
* @examples
|
||||
*
|
||||
* ```ts
|
||||
* const con = await lancedb.connect("./.lancedb")
|
||||
* const data = [
|
||||
* {id: 1, vector: [3, 3], name: 'Ye'},
|
||||
* {id: 2, vector: [4, 4], name: 'Mike'},
|
||||
* ];
|
||||
* const tbl = await con.createTable("my_table", data)
|
||||
*
|
||||
* await tbl.update({
|
||||
* filter: "id = 2",
|
||||
* updates: { vector: [2, 2], name: "Michael" },
|
||||
* })
|
||||
*
|
||||
* let results = await tbl.search([1, 1]).execute();
|
||||
* // Returns [
|
||||
* // {id: 2, vector: [2, 2], name: 'Michael'}
|
||||
* // {id: 1, vector: [3, 3], name: 'Ye'}
|
||||
* // ]
|
||||
* ```
|
||||
*
|
||||
*/
|
||||
update: (args: UpdateArgs | UpdateSqlArgs) => Promise<void>
|
||||
|
||||
/**
|
||||
* List the indicies on this table.
|
||||
*/
|
||||
@@ -272,6 +306,34 @@ export interface Table<T = number[]> {
|
||||
indexStats: (indexUuid: string) => Promise<IndexStats>
|
||||
}
|
||||
|
||||
export interface UpdateArgs {
|
||||
/**
|
||||
* A filter in the same format used by a sql WHERE clause. The filter may be empty,
|
||||
* in which case all rows will be updated.
|
||||
*/
|
||||
where?: string
|
||||
|
||||
/**
|
||||
* A key-value map of updates. The keys are the column names, and the values are the
|
||||
* new values to set
|
||||
*/
|
||||
values: Record<string, Literal>
|
||||
}
|
||||
|
||||
export interface UpdateSqlArgs {
|
||||
/**
|
||||
* A filter in the same format used by a sql WHERE clause. The filter may be empty,
|
||||
* in which case all rows will be updated.
|
||||
*/
|
||||
where?: string
|
||||
|
||||
/**
|
||||
* A key-value map of updates. The keys are the column names, and the values are the
|
||||
* new values to set as SQL expressions.
|
||||
*/
|
||||
valuesSql: Record<string, string>
|
||||
}
|
||||
|
||||
export interface VectorIndex {
|
||||
columns: string[]
|
||||
name: string
|
||||
@@ -426,6 +488,16 @@ export class LocalTable<T = number[]> implements Table<T> {
|
||||
return new Query(query, this._tbl, this._embeddings)
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a filter query to find all rows matching the specified criteria
|
||||
* @param value The filter criteria (like SQL where clause syntax)
|
||||
*/
|
||||
filter (value: string): Query<T> {
|
||||
return new Query(undefined, this._tbl, this._embeddings).filter(value)
|
||||
}
|
||||
|
||||
where = this.filter
|
||||
|
||||
/**
|
||||
* Insert records into this Table.
|
||||
*
|
||||
@@ -481,6 +553,31 @@ export class LocalTable<T = number[]> implements Table<T> {
|
||||
return tableDelete.call(this._tbl, filter).then((newTable: any) => { this._tbl = newTable })
|
||||
}
|
||||
|
||||
/**
|
||||
* Update rows in this table.
|
||||
*
|
||||
* @param args see {@link UpdateArgs} and {@link UpdateSqlArgs} for more details
|
||||
*
|
||||
* @returns
|
||||
*/
|
||||
async update (args: UpdateArgs | UpdateSqlArgs): Promise<void> {
|
||||
let filter: string | null
|
||||
let updates: Record<string, string>
|
||||
|
||||
if ('valuesSql' in args) {
|
||||
filter = args.where ?? null
|
||||
updates = args.valuesSql
|
||||
} else {
|
||||
filter = args.where ?? null
|
||||
updates = {}
|
||||
for (const [key, value] of Object.entries(args.values)) {
|
||||
updates[key] = toSQL(value)
|
||||
}
|
||||
}
|
||||
|
||||
return tableUpdate.call(this._tbl, filter, updates).then((newTable: any) => { this._tbl = newTable })
|
||||
}
|
||||
|
||||
/**
|
||||
* Clean up old versions of the table, freeing disk space.
|
||||
*
|
||||
|
||||
@@ -23,10 +23,10 @@ const { tableSearch } = require('../native.js')
|
||||
* A builder for nearest neighbor queries for LanceDB.
|
||||
*/
|
||||
export class Query<T = number[]> {
|
||||
private readonly _query: T
|
||||
private readonly _query?: T
|
||||
private readonly _tbl?: any
|
||||
private _queryVector?: number[]
|
||||
private _limit: number
|
||||
private _limit?: number
|
||||
private _refineFactor?: number
|
||||
private _nprobes: number
|
||||
private _select?: string[]
|
||||
@@ -35,10 +35,10 @@ export class Query<T = number[]> {
|
||||
private _prefilter: boolean
|
||||
protected readonly _embeddings?: EmbeddingFunction<T>
|
||||
|
||||
constructor (query: T, tbl?: any, embeddings?: EmbeddingFunction<T>) {
|
||||
constructor (query?: T, tbl?: any, embeddings?: EmbeddingFunction<T>) {
|
||||
this._tbl = tbl
|
||||
this._query = query
|
||||
this._limit = 10
|
||||
this._limit = undefined
|
||||
this._nprobes = 20
|
||||
this._refineFactor = undefined
|
||||
this._select = undefined
|
||||
@@ -113,11 +113,13 @@ export class Query<T = number[]> {
|
||||
* Execute the query and return the results as an Array of Objects
|
||||
*/
|
||||
async execute<T = Record<string, unknown>> (): Promise<T[]> {
|
||||
if (this._query !== undefined) {
|
||||
if (this._embeddings !== undefined) {
|
||||
this._queryVector = (await this._embeddings.embed([this._query]))[0]
|
||||
} else {
|
||||
this._queryVector = this._query as number[]
|
||||
}
|
||||
}
|
||||
|
||||
const isElectron = this.isElectron()
|
||||
const buffer = await tableSearch.call(this._tbl, this, isElectron)
|
||||
|
||||
@@ -16,7 +16,8 @@ import {
|
||||
type EmbeddingFunction, type Table, type VectorIndexParams, type Connection,
|
||||
type ConnectionOptions, type CreateTableOptions, type VectorIndex,
|
||||
type WriteOptions,
|
||||
type IndexStats
|
||||
type IndexStats,
|
||||
type UpdateArgs, type UpdateSqlArgs
|
||||
} from '../index'
|
||||
import { Query } from '../query'
|
||||
|
||||
@@ -24,6 +25,7 @@ import { Vector, Table as ArrowTable } from 'apache-arrow'
|
||||
import { HttpLancedbClient } from './client'
|
||||
import { isEmbeddingFunction } from '../embedding/embedding_function'
|
||||
import { createEmptyTable, fromRecordsToStreamBuffer, fromTableToStreamBuffer } from '../arrow'
|
||||
import { toSQL } from '../util'
|
||||
|
||||
/**
|
||||
* Remote connection.
|
||||
@@ -246,6 +248,26 @@ export class RemoteTable<T = number[]> implements Table<T> {
|
||||
await this._client.post(`/v1/table/${this._name}/delete/`, { predicate: filter })
|
||||
}
|
||||
|
||||
async update (args: UpdateArgs | UpdateSqlArgs): Promise<void> {
|
||||
let filter: string | null
|
||||
let updates: Record<string, string>
|
||||
|
||||
if ('valuesSql' in args) {
|
||||
filter = args.where ?? null
|
||||
updates = args.valuesSql
|
||||
} else {
|
||||
filter = args.where ?? null
|
||||
updates = {}
|
||||
for (const [key, value] of Object.entries(args.values)) {
|
||||
updates[key] = toSQL(value)
|
||||
}
|
||||
}
|
||||
await this._client.post(`/v1/table/${this._name}/update/`, {
|
||||
predicate: filter,
|
||||
updates: Object.entries(updates).map(([key, value]) => [key, value])
|
||||
})
|
||||
}
|
||||
|
||||
async listIndices (): Promise<VectorIndex[]> {
|
||||
const results = await this._client.post(`/v1/table/${this._name}/index/list/`)
|
||||
return results.data.indexes?.map((index: any) => ({
|
||||
|
||||
@@ -78,12 +78,31 @@ describe('LanceDB client', function () {
|
||||
})
|
||||
|
||||
it('limits # of results', async function () {
|
||||
const uri = await createTestDB()
|
||||
const uri = await createTestDB(2, 100)
|
||||
const con = await lancedb.connect(uri)
|
||||
const table = await con.openTable('vectors')
|
||||
const results = await table.search([0.1, 0.3]).limit(1).execute()
|
||||
let results = await table.search([0.1, 0.3]).limit(1).execute()
|
||||
assert.equal(results.length, 1)
|
||||
assert.equal(results[0].id, 1)
|
||||
|
||||
// there is a default limit if unspecified
|
||||
results = await table.search([0.1, 0.3]).execute()
|
||||
assert.equal(results.length, 10)
|
||||
})
|
||||
|
||||
it('uses a filter / where clause without vector search', async function () {
|
||||
// eslint-disable-next-line @typescript-eslint/explicit-function-return-type
|
||||
const assertResults = (results: Array<Record<string, unknown>>) => {
|
||||
assert.equal(results.length, 50)
|
||||
}
|
||||
|
||||
const uri = await createTestDB(2, 100)
|
||||
const con = await lancedb.connect(uri)
|
||||
const table = (await con.openTable('vectors')) as LocalTable
|
||||
let results = await table.filter('id % 2 = 0').execute()
|
||||
assertResults(results)
|
||||
results = await table.where('id % 2 = 0').execute()
|
||||
assertResults(results)
|
||||
})
|
||||
|
||||
it('uses a filter / where clause', async function () {
|
||||
@@ -260,6 +279,46 @@ describe('LanceDB client', function () {
|
||||
assert.equal(await table.countRows(), 2)
|
||||
})
|
||||
|
||||
it('can update records in the table', async function () {
|
||||
const uri = await createTestDB()
|
||||
const con = await lancedb.connect(uri)
|
||||
|
||||
const table = await con.openTable('vectors')
|
||||
assert.equal(await table.countRows(), 2)
|
||||
|
||||
await table.update({ where: 'price = 10', valuesSql: { price: '100' } })
|
||||
const results = await table.search([0.1, 0.2]).execute()
|
||||
assert.equal(results[0].price, 100)
|
||||
assert.equal(results[1].price, 11)
|
||||
})
|
||||
|
||||
it('can update the records using a literal value', async function () {
|
||||
const uri = await createTestDB()
|
||||
const con = await lancedb.connect(uri)
|
||||
|
||||
const table = await con.openTable('vectors')
|
||||
assert.equal(await table.countRows(), 2)
|
||||
|
||||
await table.update({ where: 'price = 10', values: { price: 100 } })
|
||||
const results = await table.search([0.1, 0.2]).execute()
|
||||
assert.equal(results[0].price, 100)
|
||||
assert.equal(results[1].price, 11)
|
||||
})
|
||||
|
||||
it('can update every record in the table', async function () {
|
||||
const uri = await createTestDB()
|
||||
const con = await lancedb.connect(uri)
|
||||
|
||||
const table = await con.openTable('vectors')
|
||||
assert.equal(await table.countRows(), 2)
|
||||
|
||||
await table.update({ valuesSql: { price: '100' } })
|
||||
const results = await table.search([0.1, 0.2]).execute()
|
||||
|
||||
assert.equal(results[0].price, 100)
|
||||
assert.equal(results[1].price, 100)
|
||||
})
|
||||
|
||||
it('can delete records from a table', async function () {
|
||||
const uri = await createTestDB()
|
||||
const con = await lancedb.connect(uri)
|
||||
@@ -542,7 +601,7 @@ describe('Compact and cleanup', function () {
|
||||
|
||||
// should have no effect, but this validates the arguments are parsed.
|
||||
await table.compactFiles({
|
||||
targetRowsPerFragment: 1024 * 10,
|
||||
targetRowsPerFragment: 102410,
|
||||
maxRowsPerGroup: 1024,
|
||||
materializeDeletions: true,
|
||||
materializeDeletionsThreshold: 0.5,
|
||||
|
||||
45
node/src/test/util.ts
Normal file
45
node/src/test/util.ts
Normal file
@@ -0,0 +1,45 @@
|
||||
// Copyright 2023 LanceDB Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
import { toSQL } from '../util'
|
||||
import * as chai from 'chai'
|
||||
|
||||
const expect = chai.expect
|
||||
|
||||
describe('toSQL', function () {
|
||||
it('should turn string to SQL expression', function () {
|
||||
expect(toSQL('foo')).to.equal("'foo'")
|
||||
})
|
||||
|
||||
it('should turn number to SQL expression', function () {
|
||||
expect(toSQL(123)).to.equal('123')
|
||||
})
|
||||
|
||||
it('should turn boolean to SQL expression', function () {
|
||||
expect(toSQL(true)).to.equal('TRUE')
|
||||
})
|
||||
|
||||
it('should turn null to SQL expression', function () {
|
||||
expect(toSQL(null)).to.equal('NULL')
|
||||
})
|
||||
|
||||
it('should turn Date to SQL expression', function () {
|
||||
const date = new Date('05 October 2011 14:48 UTC')
|
||||
expect(toSQL(date)).to.equal("'2011-10-05T14:48:00.000Z'")
|
||||
})
|
||||
|
||||
it('should turn array to SQL expression', function () {
|
||||
expect(toSQL(['foo', 'bar', true, 1])).to.equal("['foo', 'bar', TRUE, 1]")
|
||||
})
|
||||
})
|
||||
44
node/src/util.ts
Normal file
44
node/src/util.ts
Normal file
@@ -0,0 +1,44 @@
|
||||
// Copyright 2023 LanceDB Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
export type Literal = string | number | boolean | null | Date | Literal[]
|
||||
|
||||
export function toSQL (value: Literal): string {
|
||||
if (typeof value === 'string') {
|
||||
return `'${value}'`
|
||||
}
|
||||
|
||||
if (typeof value === 'number') {
|
||||
return value.toString()
|
||||
}
|
||||
|
||||
if (typeof value === 'boolean') {
|
||||
return value ? 'TRUE' : 'FALSE'
|
||||
}
|
||||
|
||||
if (value === null) {
|
||||
return 'NULL'
|
||||
}
|
||||
|
||||
if (value instanceof Date) {
|
||||
return `'${value.toISOString()}'`
|
||||
}
|
||||
|
||||
if (Array.isArray(value)) {
|
||||
return `[${value.map(toSQL).join(', ')}]`
|
||||
}
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
|
||||
throw new Error(`Unsupported value type: ${typeof value} value: (${value})`)
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
[bumpversion]
|
||||
current_version = 0.3.4
|
||||
current_version = 0.3.6
|
||||
commit = True
|
||||
message = [python] Bump version: {current_version} → {new_version}
|
||||
tag = True
|
||||
|
||||
@@ -27,7 +27,7 @@ def connect(
|
||||
uri: URI,
|
||||
*,
|
||||
api_key: Optional[str] = None,
|
||||
region: str = "us-west-2",
|
||||
region: str = "us-east-1",
|
||||
host_override: Optional[str] = None,
|
||||
) -> DBConnection:
|
||||
"""Connect to a LanceDB database.
|
||||
@@ -39,7 +39,7 @@ def connect(
|
||||
api_key: str, optional
|
||||
If presented, connect to LanceDB cloud.
|
||||
Otherwise, connect to a database on file system or cloud storage.
|
||||
region: str, default "us-west-2"
|
||||
region: str, default "us-east-1"
|
||||
The region to use for LanceDB Cloud.
|
||||
host_override: str, optional
|
||||
The override url for LanceDB Cloud.
|
||||
|
||||
@@ -348,3 +348,20 @@ def get_extras(field_info: pydantic.fields.FieldInfo, key: str) -> Any:
|
||||
if PYDANTIC_VERSION.major >= 2:
|
||||
return (field_info.json_schema_extra or {}).get(key)
|
||||
return (field_info.field_info.extra or {}).get("json_schema_extra", {}).get(key)
|
||||
|
||||
|
||||
if PYDANTIC_VERSION.major < 2:
|
||||
|
||||
def model_to_dict(model: pydantic.BaseModel) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert a Pydantic model to a dictionary.
|
||||
"""
|
||||
return model.dict()
|
||||
|
||||
else:
|
||||
|
||||
def model_to_dict(model: pydantic.BaseModel) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert a Pydantic model to a dictionary.
|
||||
"""
|
||||
return model.model_dump()
|
||||
|
||||
@@ -18,6 +18,8 @@ import attrs
|
||||
import pyarrow as pa
|
||||
from pydantic import BaseModel
|
||||
|
||||
from lancedb.common import VECTOR_COLUMN_NAME
|
||||
|
||||
__all__ = ["LanceDBClient", "VectorQuery", "VectorQueryResult"]
|
||||
|
||||
|
||||
@@ -43,6 +45,8 @@ class VectorQuery(BaseModel):
|
||||
|
||||
refine_factor: Optional[int] = None
|
||||
|
||||
vector_column: str = VECTOR_COLUMN_NAME
|
||||
|
||||
|
||||
@attrs.define
|
||||
class VectorQueryResult:
|
||||
|
||||
@@ -56,16 +56,20 @@ class RemoteDBConnection(DBConnection):
|
||||
self._loop = asyncio.get_event_loop()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"RemoveConnect(name={self.db_name})"
|
||||
return f"RemoteConnect(name={self.db_name})"
|
||||
|
||||
@override
|
||||
def table_names(self, page_token: Optional[str] = None, limit=10) -> Iterable[str]:
|
||||
def table_names(
|
||||
self, page_token: Optional[str] = None, limit: int = 10
|
||||
) -> Iterable[str]:
|
||||
"""List the names of all tables in the database.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
page_token: str
|
||||
The last token to start the new page.
|
||||
limit: int, default 10
|
||||
The maximum number of tables to return for each page.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -120,6 +124,97 @@ class RemoteDBConnection(DBConnection):
|
||||
fill_value: float = 0.0,
|
||||
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
|
||||
) -> Table:
|
||||
"""Create a [Table][lancedb.table.Table] in the database.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: str
|
||||
The name of the table.
|
||||
data: The data to initialize the table, *optional*
|
||||
User must provide at least one of `data` or `schema`.
|
||||
Acceptable types are:
|
||||
|
||||
- dict or list-of-dict
|
||||
|
||||
- pandas.DataFrame
|
||||
|
||||
- pyarrow.Table or pyarrow.RecordBatch
|
||||
schema: The schema of the table, *optional*
|
||||
Acceptable types are:
|
||||
|
||||
- pyarrow.Schema
|
||||
|
||||
- [LanceModel][lancedb.pydantic.LanceModel]
|
||||
on_bad_vectors: str, default "error"
|
||||
What to do if any of the vectors are not the same size or contains NaNs.
|
||||
One of "error", "drop", "fill".
|
||||
fill_value: float
|
||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceTable
|
||||
A reference to the newly created table.
|
||||
|
||||
!!! note
|
||||
|
||||
The vector index won't be created by default.
|
||||
To create the index, call the `create_index` method on the table.
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
Can create with list of tuples or dictionaries:
|
||||
|
||||
>>> import lancedb
|
||||
>>> db = lancedb.connect("db://...", api_key="...", region="...") # doctest: +SKIP
|
||||
>>> data = [{"vector": [1.1, 1.2], "lat": 45.5, "long": -122.7},
|
||||
... {"vector": [0.2, 1.8], "lat": 40.1, "long": -74.1}]
|
||||
>>> db.create_table("my_table", data) # doctest: +SKIP
|
||||
LanceTable(my_table)
|
||||
|
||||
You can also pass a pandas DataFrame:
|
||||
|
||||
>>> import pandas as pd
|
||||
>>> data = pd.DataFrame({
|
||||
... "vector": [[1.1, 1.2], [0.2, 1.8]],
|
||||
... "lat": [45.5, 40.1],
|
||||
... "long": [-122.7, -74.1]
|
||||
... })
|
||||
>>> db.create_table("table2", data) # doctest: +SKIP
|
||||
LanceTable(table2)
|
||||
|
||||
>>> custom_schema = pa.schema([
|
||||
... pa.field("vector", pa.list_(pa.float32(), 2)),
|
||||
... pa.field("lat", pa.float32()),
|
||||
... pa.field("long", pa.float32())
|
||||
... ])
|
||||
>>> db.create_table("table3", data, schema = custom_schema) # doctest: +SKIP
|
||||
LanceTable(table3)
|
||||
|
||||
It is also possible to create an table from `[Iterable[pa.RecordBatch]]`:
|
||||
|
||||
>>> import pyarrow as pa
|
||||
>>> def make_batches():
|
||||
... for i in range(5):
|
||||
... yield pa.RecordBatch.from_arrays(
|
||||
... [
|
||||
... pa.array([[3.1, 4.1], [5.9, 26.5]],
|
||||
... pa.list_(pa.float32(), 2)),
|
||||
... pa.array(["foo", "bar"]),
|
||||
... pa.array([10.0, 20.0]),
|
||||
... ],
|
||||
... ["vector", "item", "price"],
|
||||
... )
|
||||
>>> schema=pa.schema([
|
||||
... pa.field("vector", pa.list_(pa.float32(), 2)),
|
||||
... pa.field("item", pa.utf8()),
|
||||
... pa.field("price", pa.float32()),
|
||||
... ])
|
||||
>>> db.create_table("table4", make_batches(), schema=schema) # doctest: +SKIP
|
||||
LanceTable(table4)
|
||||
|
||||
"""
|
||||
if data is None and schema is None:
|
||||
raise ValueError("Either data or schema must be provided.")
|
||||
if embedding_functions is not None:
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
|
||||
import uuid
|
||||
from functools import cached_property
|
||||
from typing import Optional, Union
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import pyarrow as pa
|
||||
from lance import json_to_schema
|
||||
@@ -22,6 +22,7 @@ from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME
|
||||
|
||||
from ..query import LanceVectorQueryBuilder
|
||||
from ..table import Query, Table, _sanitize_data
|
||||
from ..util import value_to_sql
|
||||
from .arrow import to_ipc_binary
|
||||
from .client import ARROW_STREAM_CONTENT_TYPE
|
||||
from .db import RemoteDBConnection
|
||||
@@ -37,7 +38,10 @@ class RemoteTable(Table):
|
||||
|
||||
@cached_property
|
||||
def schema(self) -> pa.Schema:
|
||||
"""Return the schema of the table."""
|
||||
"""The [Arrow Schema](https://arrow.apache.org/docs/python/api/datatypes.html#)
|
||||
of this Table
|
||||
|
||||
"""
|
||||
resp = self._conn._loop.run_until_complete(
|
||||
self._conn._client.post(f"/v1/table/{self._name}/describe/")
|
||||
)
|
||||
@@ -53,24 +57,17 @@ class RemoteTable(Table):
|
||||
return resp["version"]
|
||||
|
||||
def to_arrow(self) -> pa.Table:
|
||||
"""Return the table as an Arrow table."""
|
||||
"""to_arrow() is not supported on the LanceDB cloud"""
|
||||
raise NotImplementedError("to_arrow() is not supported on the LanceDB cloud")
|
||||
|
||||
def to_pandas(self):
|
||||
"""Return the table as a Pandas DataFrame.
|
||||
|
||||
Intercept `to_arrow()` for better error message.
|
||||
"""
|
||||
"""to_pandas() is not supported on the LanceDB cloud"""
|
||||
return NotImplementedError("to_pandas() is not supported on the LanceDB cloud")
|
||||
|
||||
def create_index(
|
||||
self,
|
||||
metric="L2",
|
||||
num_partitions=256,
|
||||
num_sub_vectors=96,
|
||||
vector_column_name: str = VECTOR_COLUMN_NAME,
|
||||
replace: bool = True,
|
||||
accelerator: Optional[str] = None,
|
||||
index_cache_size: Optional[int] = None,
|
||||
):
|
||||
"""Create an index on the table.
|
||||
@@ -81,39 +78,28 @@ class RemoteTable(Table):
|
||||
----------
|
||||
metric : str
|
||||
The metric to use for the index. Default is "L2".
|
||||
num_partitions : int
|
||||
The number of partitions to use for the index. Default is 256.
|
||||
num_sub_vectors : int
|
||||
The number of sub-vectors to use for the index. Default is 96.
|
||||
vector_column_name : str
|
||||
The name of the vector column. Default is "vector".
|
||||
replace : bool
|
||||
Whether to replace the existing index. Default is True.
|
||||
accelerator : str, optional
|
||||
If set, use the given accelerator to create the index.
|
||||
Default is None. Currently not supported.
|
||||
index_cache_size : int, optional
|
||||
The size of the index cache in number of entries. Default value is 256.
|
||||
|
||||
Examples
|
||||
--------
|
||||
import lancedb
|
||||
import uuid
|
||||
from lancedb.schema import vector
|
||||
conn = lancedb.connect("db://...", api_key="...", region="...")
|
||||
table_name = uuid.uuid4().hex
|
||||
schema = pa.schema(
|
||||
[
|
||||
pa.field("id", pa.uint32(), False),
|
||||
pa.field("vector", vector(128), False),
|
||||
pa.field("s", pa.string(), False),
|
||||
]
|
||||
)
|
||||
table = conn.create_table(
|
||||
table_name,
|
||||
schema=schema,
|
||||
)
|
||||
table.create_index()
|
||||
>>> import lancedb
|
||||
>>> import uuid
|
||||
>>> from lancedb.schema import vector
|
||||
>>> db = lancedb.connect("db://...", api_key="...", region="...") # doctest: +SKIP
|
||||
>>> table_name = uuid.uuid4().hex
|
||||
>>> schema = pa.schema(
|
||||
... [
|
||||
... pa.field("id", pa.uint32(), False),
|
||||
... pa.field("vector", vector(128), False),
|
||||
... pa.field("s", pa.string(), False),
|
||||
... ]
|
||||
... )
|
||||
>>> table = db.create_table( # doctest: +SKIP
|
||||
... table_name, # doctest: +SKIP
|
||||
... schema=schema, # doctest: +SKIP
|
||||
... )
|
||||
>>> table.create_index("L2", "vector") # doctest: +SKIP
|
||||
"""
|
||||
index_type = "vector"
|
||||
|
||||
@@ -135,6 +121,28 @@ class RemoteTable(Table):
|
||||
on_bad_vectors: str = "error",
|
||||
fill_value: float = 0.0,
|
||||
) -> int:
|
||||
"""Add more data to the [Table](Table). It has the same API signature as the OSS version.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data: DATA
|
||||
The data to insert into the table. Acceptable types are:
|
||||
|
||||
- dict or list-of-dict
|
||||
|
||||
- pandas.DataFrame
|
||||
|
||||
- pyarrow.Table or pyarrow.RecordBatch
|
||||
mode: str
|
||||
The mode to use when writing the data. Valid values are
|
||||
"append" and "overwrite".
|
||||
on_bad_vectors: str, default "error"
|
||||
What to do if any of the vectors are not the same size or contains NaNs.
|
||||
One of "error", "drop", "fill".
|
||||
fill_value: float, default 0.
|
||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||
|
||||
"""
|
||||
data = _sanitize_data(
|
||||
data,
|
||||
self.schema,
|
||||
@@ -158,6 +166,58 @@ class RemoteTable(Table):
|
||||
def search(
|
||||
self, query: Union[VEC, str], vector_column_name: str = VECTOR_COLUMN_NAME
|
||||
) -> LanceVectorQueryBuilder:
|
||||
"""Create a search query to find the nearest neighbors
|
||||
of the given query vector. We currently support [vector search][search]
|
||||
|
||||
All query options are defined in [Query][lancedb.query.Query].
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import lancedb
|
||||
>>> db = lancedb.connect("db://...", api_key="...", region="...") # doctest: +SKIP
|
||||
>>> data = [
|
||||
... {"original_width": 100, "caption": "bar", "vector": [0.1, 2.3, 4.5]},
|
||||
... {"original_width": 2000, "caption": "foo", "vector": [0.5, 3.4, 1.3]},
|
||||
... {"original_width": 3000, "caption": "test", "vector": [0.3, 6.2, 2.6]}
|
||||
... ]
|
||||
>>> table = db.create_table("my_table", data) # doctest: +SKIP
|
||||
>>> query = [0.4, 1.4, 2.4]
|
||||
>>> (table.search(query, vector_column_name="vector") # doctest: +SKIP
|
||||
... .where("original_width > 1000", prefilter=True) # doctest: +SKIP
|
||||
... .select(["caption", "original_width"]) # doctest: +SKIP
|
||||
... .limit(2) # doctest: +SKIP
|
||||
... .to_pandas()) # doctest: +SKIP
|
||||
caption original_width vector _distance # doctest: +SKIP
|
||||
0 foo 2000 [0.5, 3.4, 1.3] 5.220000 # doctest: +SKIP
|
||||
1 test 3000 [0.3, 6.2, 2.6] 23.089996 # doctest: +SKIP
|
||||
|
||||
Parameters
|
||||
----------
|
||||
query: list/np.ndarray/str/PIL.Image.Image, default None
|
||||
The targetted vector to search for.
|
||||
|
||||
- *default None*.
|
||||
Acceptable types are: list, np.ndarray, PIL.Image.Image
|
||||
|
||||
- If None then the select/where/limit clauses are applied to filter
|
||||
the table
|
||||
vector_column_name: str
|
||||
The name of the vector column to search.
|
||||
*default "vector"*
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceQueryBuilder
|
||||
A query builder object representing the query.
|
||||
Once executed, the query returns
|
||||
|
||||
- selected columns
|
||||
|
||||
- the vector
|
||||
|
||||
- and also the "_distance" column which is the distance between the query
|
||||
vector and the returned vector.
|
||||
"""
|
||||
return LanceVectorQueryBuilder(self, query, vector_column_name)
|
||||
|
||||
def _execute_query(self, query: Query) -> pa.Table:
|
||||
@@ -165,8 +225,114 @@ class RemoteTable(Table):
|
||||
return self._conn._loop.run_until_complete(result).to_arrow()
|
||||
|
||||
def delete(self, predicate: str):
|
||||
"""Delete rows from the table."""
|
||||
"""Delete rows from the table.
|
||||
|
||||
This can be used to delete a single row, many rows, all rows, or
|
||||
sometimes no rows (if your predicate matches nothing).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
predicate: str
|
||||
The SQL where clause to use when deleting rows.
|
||||
|
||||
- For example, 'x = 2' or 'x IN (1, 2, 3)'.
|
||||
|
||||
The filter must not be empty, or it will error.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import lancedb
|
||||
>>> data = [
|
||||
... {"x": 1, "vector": [1, 2]},
|
||||
... {"x": 2, "vector": [3, 4]},
|
||||
... {"x": 3, "vector": [5, 6]}
|
||||
... ]
|
||||
>>> db = lancedb.connect("db://...", api_key="...", region="...") # doctest: +SKIP
|
||||
>>> table = db.create_table("my_table", data) # doctest: +SKIP
|
||||
>>> table.search([10,10]).to_pandas() # doctest: +SKIP
|
||||
x vector _distance # doctest: +SKIP
|
||||
0 3 [5.0, 6.0] 41.0 # doctest: +SKIP
|
||||
1 2 [3.0, 4.0] 85.0 # doctest: +SKIP
|
||||
2 1 [1.0, 2.0] 145.0 # doctest: +SKIP
|
||||
>>> table.delete("x = 2") # doctest: +SKIP
|
||||
>>> table.search([10,10]).to_pandas() # doctest: +SKIP
|
||||
x vector _distance # doctest: +SKIP
|
||||
0 3 [5.0, 6.0] 41.0 # doctest: +SKIP
|
||||
1 1 [1.0, 2.0] 145.0 # doctest: +SKIP
|
||||
|
||||
If you have a list of values to delete, you can combine them into a
|
||||
stringified list and use the `IN` operator:
|
||||
|
||||
>>> to_remove = [1, 3] # doctest: +SKIP
|
||||
>>> to_remove = ", ".join([str(v) for v in to_remove]) # doctest: +SKIP
|
||||
>>> table.delete(f"x IN ({to_remove})") # doctest: +SKIP
|
||||
>>> table.search([10,10]).to_pandas() # doctest: +SKIP
|
||||
x vector _distance # doctest: +SKIP
|
||||
0 2 [3.0, 4.0] 85.0 # doctest: +SKIP
|
||||
"""
|
||||
payload = {"predicate": predicate}
|
||||
self._conn._loop.run_until_complete(
|
||||
self._conn._client.post(f"/v1/table/{self._name}/delete/", data=payload)
|
||||
)
|
||||
|
||||
def update(
|
||||
self,
|
||||
where: Optional[str] = None,
|
||||
values: Optional[dict] = None,
|
||||
*,
|
||||
values_sql: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
"""
|
||||
This can be used to update zero to all rows depending on how many
|
||||
rows match the where clause.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
where: str, optional
|
||||
The SQL where clause to use when updating rows. For example, 'x = 2'
|
||||
or 'x IN (1, 2, 3)'. The filter must not be empty, or it will error.
|
||||
values: dict, optional
|
||||
The values to update. The keys are the column names and the values
|
||||
are the values to set.
|
||||
values_sql: dict, optional
|
||||
The values to update, expressed as SQL expression strings. These can
|
||||
reference existing columns. For example, {"x": "x + 1"} will increment
|
||||
the x column by 1.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import lancedb
|
||||
>>> data = [
|
||||
... {"x": 1, "vector": [1, 2]},
|
||||
... {"x": 2, "vector": [3, 4]},
|
||||
... {"x": 3, "vector": [5, 6]}
|
||||
... ]
|
||||
>>> db = lancedb.connect("db://...", api_key="...", region="...") # doctest: +SKIP
|
||||
>>> table = db.create_table("my_table", data) # doctest: +SKIP
|
||||
>>> table.to_pandas() # doctest: +SKIP
|
||||
x vector # doctest: +SKIP
|
||||
0 1 [1.0, 2.0] # doctest: +SKIP
|
||||
1 2 [3.0, 4.0] # doctest: +SKIP
|
||||
2 3 [5.0, 6.0] # doctest: +SKIP
|
||||
>>> table.update(where="x = 2", values={"vector": [10, 10]}) # doctest: +SKIP
|
||||
>>> table.to_pandas() # doctest: +SKIP
|
||||
x vector # doctest: +SKIP
|
||||
0 1 [1.0, 2.0] # doctest: +SKIP
|
||||
1 3 [5.0, 6.0] # doctest: +SKIP
|
||||
2 2 [10.0, 10.0] # doctest: +SKIP
|
||||
|
||||
"""
|
||||
if values is not None and values_sql is not None:
|
||||
raise ValueError("Only one of values or values_sql can be provided")
|
||||
if values is None and values_sql is None:
|
||||
raise ValueError("Either values or values_sql must be provided")
|
||||
|
||||
if values is not None:
|
||||
updates = [[k, value_to_sql(v)] for k, v in values.items()]
|
||||
else:
|
||||
updates = [[k, v] for k, v in values_sql.items()]
|
||||
|
||||
payload = {"predicate": where, "updates": updates}
|
||||
self._conn._loop.run_until_complete(
|
||||
self._conn._client.post(f"/v1/table/{self._name}/update/", data=payload)
|
||||
)
|
||||
|
||||
@@ -17,7 +17,7 @@ import inspect
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union
|
||||
|
||||
import lance
|
||||
import numpy as np
|
||||
@@ -28,9 +28,9 @@ from lance.vector import vec_to_table
|
||||
|
||||
from .common import DATA, VEC, VECTOR_COLUMN_NAME
|
||||
from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
|
||||
from .pydantic import LanceModel
|
||||
from .pydantic import LanceModel, model_to_dict
|
||||
from .query import LanceQueryBuilder, Query
|
||||
from .util import fs_from_uri, safe_import_pandas
|
||||
from .util import fs_from_uri, safe_import_pandas, value_to_sql
|
||||
from .utils.events import register_event
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -53,7 +53,9 @@ def _sanitize_data(
|
||||
# convert to list of dict if data is a bunch of LanceModels
|
||||
if isinstance(data[0], LanceModel):
|
||||
schema = data[0].__class__.to_arrow_schema()
|
||||
data = [dict(d) for d in data]
|
||||
data = [model_to_dict(d) for d in data]
|
||||
data = pa.Table.from_pylist(data, schema=schema)
|
||||
else:
|
||||
data = pa.Table.from_pylist(data)
|
||||
elif isinstance(data, dict):
|
||||
data = vec_to_table(data)
|
||||
@@ -785,7 +787,7 @@ class LanceTable(Table):
|
||||
and also the "_distance" column which is the distance between the query
|
||||
vector and the returned vector.
|
||||
"""
|
||||
register_event("search")
|
||||
register_event("search_table")
|
||||
return LanceQueryBuilder.create(
|
||||
self, query, query_type, vector_column_name=vector_column_name
|
||||
)
|
||||
@@ -906,35 +908,42 @@ class LanceTable(Table):
|
||||
f"Table {name} does not exist."
|
||||
f"Please first call db.create_table({name}, data)"
|
||||
)
|
||||
register_event("open_table")
|
||||
|
||||
return tbl
|
||||
|
||||
def delete(self, where: str):
|
||||
self._dataset.delete(where)
|
||||
|
||||
def update(self, where: str, values: dict):
|
||||
def update(
|
||||
self,
|
||||
where: Optional[str] = None,
|
||||
values: Optional[dict] = None,
|
||||
*,
|
||||
values_sql: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
"""
|
||||
EXPERIMENTAL: Update rows in the table (not threadsafe).
|
||||
|
||||
This can be used to update zero to all rows depending on how many
|
||||
rows match the where clause.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
where: str
|
||||
where: str, optional
|
||||
The SQL where clause to use when updating rows. For example, 'x = 2'
|
||||
or 'x IN (1, 2, 3)'. The filter must not be empty, or it will error.
|
||||
values: dict
|
||||
values: dict, optional
|
||||
The values to update. The keys are the column names and the values
|
||||
are the values to set.
|
||||
values_sql: dict, optional
|
||||
The values to update, expressed as SQL expression strings. These can
|
||||
reference existing columns. For example, {"x": "x + 1"} will increment
|
||||
the x column by 1.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import lancedb
|
||||
>>> data = [
|
||||
... {"x": 1, "vector": [1, 2]},
|
||||
... {"x": 2, "vector": [3, 4]},
|
||||
... {"x": 3, "vector": [5, 6]}
|
||||
... ]
|
||||
>>> import pandas as pd
|
||||
>>> data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1, 2], [3, 4], [5, 6]]})
|
||||
>>> db = lancedb.connect("./.lancedb")
|
||||
>>> table = db.create_table("my_table", data)
|
||||
>>> table.to_pandas()
|
||||
@@ -950,18 +959,15 @@ class LanceTable(Table):
|
||||
2 2 [10.0, 10.0]
|
||||
|
||||
"""
|
||||
orig_data = self._dataset.to_table(filter=where).combine_chunks()
|
||||
if len(orig_data) == 0:
|
||||
return
|
||||
for col, val in values.items():
|
||||
i = orig_data.column_names.index(col)
|
||||
if i < 0:
|
||||
raise ValueError(f"Column {col} does not exist")
|
||||
orig_data = orig_data.set_column(
|
||||
i, col, pa.array([val] * len(orig_data), type=orig_data[col].type)
|
||||
)
|
||||
self.delete(where)
|
||||
self.add(orig_data, mode="append")
|
||||
if values is not None and values_sql is not None:
|
||||
raise ValueError("Only one of values or values_sql can be provided")
|
||||
if values is None and values_sql is None:
|
||||
raise ValueError("Either values or values_sql must be provided")
|
||||
|
||||
if values is not None:
|
||||
values_sql = {k: value_to_sql(v) for k, v in values.items()}
|
||||
|
||||
self.to_lance().update(values_sql, where)
|
||||
self._reset_dataset()
|
||||
register_event("update")
|
||||
|
||||
|
||||
@@ -12,9 +12,12 @@
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from datetime import date, datetime
|
||||
from functools import singledispatch
|
||||
from typing import Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import numpy as np
|
||||
import pyarrow.fs as pa_fs
|
||||
|
||||
|
||||
@@ -88,3 +91,53 @@ def safe_import_pandas():
|
||||
return pd
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
|
||||
@singledispatch
|
||||
def value_to_sql(value):
|
||||
raise NotImplementedError("SQL conversion is not implemented for this type")
|
||||
|
||||
|
||||
@value_to_sql.register(str)
|
||||
def _(value: str):
|
||||
return f"'{value}'"
|
||||
|
||||
|
||||
@value_to_sql.register(int)
|
||||
def _(value: int):
|
||||
return str(value)
|
||||
|
||||
|
||||
@value_to_sql.register(float)
|
||||
def _(value: float):
|
||||
return str(value)
|
||||
|
||||
|
||||
@value_to_sql.register(bool)
|
||||
def _(value: bool):
|
||||
return str(value).upper()
|
||||
|
||||
|
||||
@value_to_sql.register(type(None))
|
||||
def _(value: type(None)):
|
||||
return "NULL"
|
||||
|
||||
|
||||
@value_to_sql.register(datetime)
|
||||
def _(value: datetime):
|
||||
return f"'{value.isoformat()}'"
|
||||
|
||||
|
||||
@value_to_sql.register(date)
|
||||
def _(value: date):
|
||||
return f"'{value.isoformat()}'"
|
||||
|
||||
|
||||
@value_to_sql.register(list)
|
||||
def _(value: list):
|
||||
return "[" + ", ".join(map(value_to_sql, value)) + "]"
|
||||
|
||||
|
||||
@value_to_sql.register(np.ndarray)
|
||||
def _(value: np.ndarray):
|
||||
return value_to_sql(value.tolist())
|
||||
|
||||
@@ -64,8 +64,10 @@ class _Events:
|
||||
Initializes the Events object with default values for events, rate_limit, and metadata.
|
||||
"""
|
||||
self.events = [] # events list
|
||||
self.max_events = 25 # max events to store in memory
|
||||
self.rate_limit = 60.0 # rate limit (seconds)
|
||||
self.throttled_event_names = ["search_table"]
|
||||
self.throttled_events = set()
|
||||
self.max_events = 5 # max events to store in memory
|
||||
self.rate_limit = 60.0 * 5 # rate limit (seconds)
|
||||
self.time = 0.0
|
||||
|
||||
if is_git_dir():
|
||||
@@ -112,10 +114,9 @@ class _Events:
|
||||
return
|
||||
if (
|
||||
len(self.events) < self.max_events
|
||||
): # Events list limited to 25 events (drop any events past this)
|
||||
): # Events list limited to self.max_events (drop any events past this)
|
||||
params.update(self.metadata)
|
||||
self.events.append(
|
||||
{
|
||||
event = {
|
||||
"event": event_name,
|
||||
"properties": params,
|
||||
"timestamp": datetime.datetime.now(
|
||||
@@ -123,7 +124,11 @@ class _Events:
|
||||
).isoformat(),
|
||||
"distinct_id": CONFIG["uuid"],
|
||||
}
|
||||
)
|
||||
if event_name not in self.throttled_event_names:
|
||||
self.events.append(event)
|
||||
elif event_name not in self.throttled_events:
|
||||
self.throttled_events.add(event_name)
|
||||
self.events.append(event)
|
||||
|
||||
# Check rate limit
|
||||
t = time.time()
|
||||
@@ -135,7 +140,6 @@ class _Events:
|
||||
"distinct_id": CONFIG["uuid"], # posthog needs this to accepts the event
|
||||
"batch": self.events,
|
||||
}
|
||||
|
||||
# POST equivalent to requests.post(self.url, json=data).
|
||||
# threaded request is used to avoid blocking, retries are disabled, and verbose is disabled
|
||||
# to avoid any possible disruption in the console.
|
||||
@@ -150,6 +154,7 @@ class _Events:
|
||||
|
||||
# Flush & Reset
|
||||
self.events = []
|
||||
self.throttled_events = set()
|
||||
self.time = t
|
||||
|
||||
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
[project]
|
||||
name = "lancedb"
|
||||
version = "0.3.4"
|
||||
version = "0.3.6"
|
||||
dependencies = [
|
||||
"deprecation",
|
||||
"pylance==0.8.17",
|
||||
"pylance==0.8.21",
|
||||
"ratelimiter~=1.0",
|
||||
"retry>=0.9.2",
|
||||
"tqdm>=4.1.0",
|
||||
"tqdm>=4.27.0",
|
||||
"aiohttp",
|
||||
"pydantic>=1.10",
|
||||
"attrs>=21.3.0",
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
from datetime import timedelta
|
||||
from datetime import date, datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from unittest.mock import PropertyMock, patch
|
||||
@@ -21,6 +21,7 @@ import lance
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
from pydantic import BaseModel
|
||||
import pytest
|
||||
|
||||
from lancedb.conftest import MockTextEmbeddingFunction
|
||||
@@ -141,14 +142,32 @@ def test_add(db):
|
||||
|
||||
|
||||
def test_add_pydantic_model(db):
|
||||
class TestModel(LanceModel):
|
||||
vector: Vector(16)
|
||||
li: List[int]
|
||||
# https://github.com/lancedb/lancedb/issues/562
|
||||
|
||||
data = TestModel(vector=list(range(16)), li=[1, 2, 3])
|
||||
table = LanceTable.create(db, "test", data=[data])
|
||||
assert len(table) == 1
|
||||
assert table.schema == TestModel.to_arrow_schema()
|
||||
class Document(BaseModel):
|
||||
content: str
|
||||
source: str
|
||||
|
||||
class LanceSchema(LanceModel):
|
||||
id: str
|
||||
vector: Vector(2)
|
||||
li: List[int]
|
||||
payload: Document
|
||||
|
||||
tbl = LanceTable.create(db, "mytable", schema=LanceSchema, mode="overwrite")
|
||||
assert tbl.schema == LanceSchema.to_arrow_schema()
|
||||
|
||||
# add works
|
||||
expected = LanceSchema(
|
||||
id="id",
|
||||
vector=[0.0, 0.0],
|
||||
li=[1, 2, 3],
|
||||
payload=Document(content="foo", source="bar"),
|
||||
)
|
||||
tbl.add([expected])
|
||||
|
||||
result = tbl.search([0.0, 0.0]).limit(1).to_pydantic(LanceSchema)[0]
|
||||
assert result == expected
|
||||
|
||||
|
||||
def _add(table, schema):
|
||||
@@ -348,14 +367,79 @@ def test_update(db):
|
||||
assert len(table) == 2
|
||||
assert len(table.list_versions()) == 2
|
||||
table.update(where="id=0", values={"vector": [1.1, 1.1]})
|
||||
assert len(table.list_versions()) == 4
|
||||
assert table.version == 4
|
||||
assert len(table.list_versions()) == 3
|
||||
assert table.version == 3
|
||||
assert len(table) == 2
|
||||
v = table.to_arrow()["vector"].combine_chunks()
|
||||
v = v.values.to_numpy().reshape(2, 2)
|
||||
assert np.allclose(v, np.array([[1.2, 1.9], [1.1, 1.1]]))
|
||||
|
||||
|
||||
def test_update_types(db):
|
||||
table = LanceTable.create(
|
||||
db,
|
||||
"my_table",
|
||||
data=[
|
||||
{
|
||||
"id": 0,
|
||||
"str": "foo",
|
||||
"float": 1.1,
|
||||
"timestamp": datetime(2021, 1, 1),
|
||||
"date": date(2021, 1, 1),
|
||||
"vector1": [1.0, 0.0],
|
||||
"vector2": [1.0, 1.0],
|
||||
}
|
||||
],
|
||||
)
|
||||
# Update with SQL
|
||||
table.update(
|
||||
values_sql=dict(
|
||||
id="1",
|
||||
str="'bar'",
|
||||
float="2.2",
|
||||
timestamp="TIMESTAMP '2021-01-02 00:00:00'",
|
||||
date="DATE '2021-01-02'",
|
||||
vector1="[2.0, 2.0]",
|
||||
vector2="[3.0, 3.0]",
|
||||
)
|
||||
)
|
||||
actual = table.to_arrow().to_pylist()[0]
|
||||
expected = dict(
|
||||
id=1,
|
||||
str="bar",
|
||||
float=2.2,
|
||||
timestamp=datetime(2021, 1, 2),
|
||||
date=date(2021, 1, 2),
|
||||
vector1=[2.0, 2.0],
|
||||
vector2=[3.0, 3.0],
|
||||
)
|
||||
assert actual == expected
|
||||
|
||||
# Update with values
|
||||
table.update(
|
||||
values=dict(
|
||||
id=2,
|
||||
str="baz",
|
||||
float=3.3,
|
||||
timestamp=datetime(2021, 1, 3),
|
||||
date=date(2021, 1, 3),
|
||||
vector1=[3.0, 3.0],
|
||||
vector2=np.array([4.0, 4.0]),
|
||||
)
|
||||
)
|
||||
actual = table.to_arrow().to_pylist()[0]
|
||||
expected = dict(
|
||||
id=2,
|
||||
str="baz",
|
||||
float=3.3,
|
||||
timestamp=datetime(2021, 1, 3),
|
||||
date=date(2021, 1, 3),
|
||||
vector1=[3.0, 3.0],
|
||||
vector2=[4.0, 4.0],
|
||||
)
|
||||
assert actual == expected
|
||||
|
||||
|
||||
def test_create_with_embedding_function(db):
|
||||
class MyTable(LanceModel):
|
||||
text: str
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "vectordb-node"
|
||||
version = "0.3.9"
|
||||
version = "0.3.11"
|
||||
description = "Serverless, low-latency vector database for AI applications"
|
||||
license = "Apache-2.0"
|
||||
edition = "2018"
|
||||
|
||||
@@ -237,6 +237,7 @@ fn main(mut cx: ModuleContext) -> NeonResult<()> {
|
||||
cx.export_function("tableAdd", JsTable::js_add)?;
|
||||
cx.export_function("tableCountRows", JsTable::js_count_rows)?;
|
||||
cx.export_function("tableDelete", JsTable::js_delete)?;
|
||||
cx.export_function("tableUpdate", JsTable::js_update)?;
|
||||
cx.export_function("tableCleanupOldVersions", JsTable::js_cleanup)?;
|
||||
cx.export_function("tableCompactFiles", JsTable::js_compact)?;
|
||||
cx.export_function("tableListIndices", JsTable::js_list_indices)?;
|
||||
|
||||
@@ -23,8 +23,14 @@ impl JsQuery {
|
||||
let query_obj = cx.argument::<JsObject>(0)?;
|
||||
|
||||
let limit = query_obj
|
||||
.get::<JsNumber, _, _>(&mut cx, "_limit")?
|
||||
.value(&mut cx);
|
||||
.get_opt::<JsNumber, _, _>(&mut cx, "_limit")?
|
||||
.map(|value| {
|
||||
let limit = value.value(&mut cx) as u64;
|
||||
if limit <= 0 {
|
||||
panic!("Limit must be a positive integer");
|
||||
}
|
||||
limit
|
||||
});
|
||||
let select = query_obj
|
||||
.get_opt::<JsArray, _, _>(&mut cx, "_select")?
|
||||
.map(|arr| {
|
||||
@@ -48,7 +54,9 @@ impl JsQuery {
|
||||
.map(|s| s.value(&mut cx))
|
||||
.map(|s| MetricType::try_from(s.as_str()).unwrap());
|
||||
|
||||
let prefilter = query_obj.get::<JsBoolean, _, _>(&mut cx, "_prefilter")?.value(&mut cx);
|
||||
let prefilter = query_obj
|
||||
.get::<JsBoolean, _, _>(&mut cx, "_prefilter")?
|
||||
.value(&mut cx);
|
||||
|
||||
let is_electron = cx
|
||||
.argument::<JsBoolean>(1)
|
||||
@@ -59,20 +67,23 @@ impl JsQuery {
|
||||
|
||||
let (deferred, promise) = cx.promise();
|
||||
let channel = cx.channel();
|
||||
let query_vector = query_obj.get::<JsArray, _, _>(&mut cx, "_queryVector")?;
|
||||
let query = convert::js_array_to_vec(query_vector.deref(), &mut cx);
|
||||
let query_vector = query_obj.get_opt::<JsArray, _, _>(&mut cx, "_queryVector")?;
|
||||
let table = js_table.table.clone();
|
||||
let query = query_vector.map(|q| convert::js_array_to_vec(q.deref(), &mut cx));
|
||||
|
||||
rt.spawn(async move {
|
||||
let builder = table
|
||||
.search(Float32Array::from(query))
|
||||
.limit(limit as usize)
|
||||
let mut builder = table
|
||||
.search(query.map(|q| Float32Array::from(q)))
|
||||
.refine_factor(refine_factor)
|
||||
.nprobes(nprobes)
|
||||
.filter(filter)
|
||||
.metric_type(metric_type)
|
||||
.select(select)
|
||||
.prefilter(prefilter);
|
||||
if let Some(limit) = limit {
|
||||
builder = builder.limit(limit as usize);
|
||||
};
|
||||
|
||||
let record_batch_stream = builder.execute();
|
||||
let results = record_batch_stream
|
||||
.and_then(|stream| {
|
||||
|
||||
@@ -165,6 +165,69 @@ impl JsTable {
|
||||
Ok(promise)
|
||||
}
|
||||
|
||||
pub(crate) fn js_update(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
|
||||
let mut table = js_table.table.clone();
|
||||
|
||||
let rt = runtime(&mut cx)?;
|
||||
let (deferred, promise) = cx.promise();
|
||||
let channel = cx.channel();
|
||||
|
||||
// create a vector of updates from the passed map
|
||||
let updates_arg = cx.argument::<JsObject>(1)?;
|
||||
let properties = updates_arg.get_own_property_names(&mut cx)?;
|
||||
let mut updates: Vec<(String, String)> =
|
||||
Vec::with_capacity(properties.len(&mut cx) as usize);
|
||||
|
||||
let len_properties = properties.len(&mut cx);
|
||||
for i in 0..len_properties {
|
||||
let property = properties
|
||||
.get_value(&mut cx, i)?
|
||||
.downcast_or_throw::<JsString, _>(&mut cx)?;
|
||||
|
||||
let value = updates_arg
|
||||
.get_value(&mut cx, property.clone())?
|
||||
.downcast_or_throw::<JsString, _>(&mut cx)?;
|
||||
|
||||
let property = property.value(&mut cx);
|
||||
let value = value.value(&mut cx);
|
||||
updates.push((property, value));
|
||||
}
|
||||
|
||||
// get the filter/predicate if the user passed one
|
||||
let predicate = cx.argument_opt(0);
|
||||
let predicate = predicate.unwrap().downcast::<JsString, _>(&mut cx);
|
||||
let predicate = match predicate {
|
||||
Ok(_) => {
|
||||
let val = predicate.map(|s| s.value(&mut cx)).unwrap();
|
||||
Some(val)
|
||||
}
|
||||
Err(_) => {
|
||||
// if the predicate is not string, check it's null otherwise an invalid
|
||||
// type was passed
|
||||
cx.argument::<JsNull>(0)?;
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
rt.spawn(async move {
|
||||
let updates_arg = updates
|
||||
.iter()
|
||||
.map(|(k, v)| (k.as_str(), v.as_str()))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let predicate = predicate.as_ref().map(|s| s.as_str());
|
||||
|
||||
let update_result = table.update(predicate, updates_arg).await;
|
||||
deferred.settle_with(&channel, move |mut cx| {
|
||||
update_result.or_throw(&mut cx)?;
|
||||
Ok(cx.boxed(JsTable::from(table)))
|
||||
})
|
||||
});
|
||||
|
||||
Ok(promise)
|
||||
}
|
||||
|
||||
pub(crate) fn js_cleanup(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
|
||||
let rt = runtime(&mut cx)?;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "vectordb"
|
||||
version = "0.3.9"
|
||||
version = "0.3.11"
|
||||
edition = "2021"
|
||||
description = "LanceDB: A serverless, low-latency vector database for AI applications"
|
||||
license = "Apache-2.0"
|
||||
|
||||
@@ -359,7 +359,9 @@ mod test {
|
||||
assert_eq!(t.count_rows().await.unwrap(), 100);
|
||||
|
||||
let q = t
|
||||
.search(PrimitiveArray::from_iter_values(vec![0.1, 0.1, 0.1, 0.1]))
|
||||
.search(Some(PrimitiveArray::from_iter_values(vec![
|
||||
0.1, 0.1, 0.1, 0.1,
|
||||
])))
|
||||
.limit(10)
|
||||
.execute()
|
||||
.await
|
||||
|
||||
@@ -24,8 +24,9 @@ use crate::error::Result;
|
||||
/// A builder for nearest neighbor queries for LanceDB.
|
||||
pub struct Query {
|
||||
pub dataset: Arc<Dataset>,
|
||||
pub query_vector: Float32Array,
|
||||
pub limit: usize,
|
||||
pub query_vector: Option<Float32Array>,
|
||||
pub column: String,
|
||||
pub limit: Option<usize>,
|
||||
pub filter: Option<String>,
|
||||
pub select: Option<Vec<String>>,
|
||||
pub nprobes: usize,
|
||||
@@ -46,11 +47,12 @@ impl Query {
|
||||
/// # Returns
|
||||
///
|
||||
/// * A [Query] object.
|
||||
pub(crate) fn new(dataset: Arc<Dataset>, vector: Float32Array) -> Self {
|
||||
pub(crate) fn new(dataset: Arc<Dataset>, vector: Option<Float32Array>) -> Self {
|
||||
Query {
|
||||
dataset,
|
||||
query_vector: vector,
|
||||
limit: 10,
|
||||
column: crate::table::VECTOR_COLUMN_NAME.to_string(),
|
||||
limit: None,
|
||||
nprobes: 20,
|
||||
refine_factor: None,
|
||||
metric_type: None,
|
||||
@@ -69,11 +71,13 @@ impl Query {
|
||||
pub async fn execute(&self) -> Result<DatasetRecordBatchStream> {
|
||||
let mut scanner: Scanner = self.dataset.scan();
|
||||
|
||||
scanner.nearest(
|
||||
crate::table::VECTOR_COLUMN_NAME,
|
||||
&self.query_vector,
|
||||
self.limit,
|
||||
)?;
|
||||
if let Some(query) = self.query_vector.as_ref() {
|
||||
// If there is a vector query, default to limit=10 if unspecified
|
||||
scanner.nearest(&self.column, query, self.limit.unwrap_or(10))?;
|
||||
} else {
|
||||
// If there is no vector query, it's ok to not have a limit
|
||||
scanner.limit(self.limit.map(|limit| limit as i64), None)?;
|
||||
}
|
||||
scanner.nprobs(self.nprobes);
|
||||
scanner.use_index(self.use_index);
|
||||
scanner.prefilter(self.prefilter);
|
||||
@@ -85,13 +89,23 @@ impl Query {
|
||||
Ok(scanner.try_into_stream().await?)
|
||||
}
|
||||
|
||||
/// Set the column to query
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `column` - The column name
|
||||
pub fn column(mut self, column: &str) -> Query {
|
||||
self.column = column.into();
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the maximum number of results to return.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `limit` - The maximum number of results to return.
|
||||
pub fn limit(mut self, limit: usize) -> Query {
|
||||
self.limit = limit;
|
||||
self.limit = Some(limit);
|
||||
self
|
||||
}
|
||||
|
||||
@@ -101,7 +115,7 @@ impl Query {
|
||||
///
|
||||
/// * `vector` - The vector that will be used for search.
|
||||
pub fn query_vector(mut self, query_vector: Float32Array) -> Query {
|
||||
self.query_vector = query_vector;
|
||||
self.query_vector = Some(query_vector);
|
||||
self
|
||||
}
|
||||
|
||||
@@ -174,7 +188,10 @@ mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::*;
|
||||
use arrow_array::{Float32Array, RecordBatch, RecordBatchIterator, RecordBatchReader};
|
||||
use arrow_array::{
|
||||
cast::AsArray, Float32Array, Int32Array, RecordBatch, RecordBatchIterator,
|
||||
RecordBatchReader,
|
||||
};
|
||||
use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema};
|
||||
use futures::StreamExt;
|
||||
use lance::dataset::Dataset;
|
||||
@@ -187,7 +204,7 @@ mod tests {
|
||||
let batches = make_test_batches();
|
||||
let ds = Dataset::write(batches, "memory://foo", None).await.unwrap();
|
||||
|
||||
let vector = Float32Array::from_iter_values([0.1, 0.2]);
|
||||
let vector = Some(Float32Array::from_iter_values([0.1, 0.2]));
|
||||
let query = Query::new(Arc::new(ds), vector.clone());
|
||||
assert_eq!(query.query_vector, vector);
|
||||
|
||||
@@ -201,8 +218,8 @@ mod tests {
|
||||
.metric_type(Some(MetricType::Cosine))
|
||||
.refine_factor(Some(999));
|
||||
|
||||
assert_eq!(query.query_vector, new_vector);
|
||||
assert_eq!(query.limit, 100);
|
||||
assert_eq!(query.query_vector.unwrap(), new_vector);
|
||||
assert_eq!(query.limit.unwrap(), 100);
|
||||
assert_eq!(query.nprobes, 1000);
|
||||
assert_eq!(query.use_index, true);
|
||||
assert_eq!(query.metric_type, Some(MetricType::Cosine));
|
||||
@@ -214,7 +231,7 @@ mod tests {
|
||||
let batches = make_non_empty_batches();
|
||||
let ds = Arc::new(Dataset::write(batches, "memory://foo", None).await.unwrap());
|
||||
|
||||
let vector = Float32Array::from_iter_values([0.1; 4]);
|
||||
let vector = Some(Float32Array::from_iter_values([0.1; 4]));
|
||||
|
||||
let query = Query::new(ds.clone(), vector.clone());
|
||||
let result = query
|
||||
@@ -244,6 +261,27 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_execute_no_vector() {
|
||||
// test that it's ok to not specify a query vector (just filter / limit)
|
||||
let batches = make_non_empty_batches();
|
||||
let ds = Arc::new(Dataset::write(batches, "memory://foo", None).await.unwrap());
|
||||
|
||||
let query = Query::new(ds.clone(), None);
|
||||
let result = query
|
||||
.filter(Some("id % 2 == 0".to_string()))
|
||||
.execute()
|
||||
.await;
|
||||
let mut stream = result.expect("should have result");
|
||||
// should only have one batch
|
||||
while let Some(batch) = stream.next().await {
|
||||
let b = batch.expect("should be Ok");
|
||||
// cast arr into Int32Array
|
||||
let arr: &Int32Array = b["id"].as_primitive();
|
||||
assert!(arr.iter().all(|x| x.unwrap() % 2 == 0));
|
||||
}
|
||||
}
|
||||
|
||||
fn make_non_empty_batches() -> impl RecordBatchReader + Send + 'static {
|
||||
let vec = Box::new(RandomVector::new().named("vector".to_string()));
|
||||
let id = Box::new(IncrementingInt32::new().named("id".to_string()));
|
||||
|
||||
@@ -23,7 +23,7 @@ use lance::dataset::cleanup::RemovalStats;
|
||||
use lance::dataset::optimize::{
|
||||
compact_files, CompactionMetrics, CompactionOptions, IndexRemapperOptions,
|
||||
};
|
||||
use lance::dataset::{Dataset, WriteParams};
|
||||
use lance::dataset::{Dataset, UpdateBuilder, WriteParams};
|
||||
use lance::index::DatasetIndexExt;
|
||||
use lance::io::object_store::WrappingObjectStore;
|
||||
use std::path::Path;
|
||||
@@ -308,10 +308,14 @@ impl Table {
|
||||
/// # Returns
|
||||
///
|
||||
/// * A [Query] object.
|
||||
pub fn search(&self, query_vector: Float32Array) -> Query {
|
||||
pub fn search(&self, query_vector: Option<Float32Array>) -> Query {
|
||||
Query::new(self.dataset.clone(), query_vector)
|
||||
}
|
||||
|
||||
pub fn filter(&self, expr: String) -> Query {
|
||||
Query::new(self.dataset.clone(), None).filter(Some(expr))
|
||||
}
|
||||
|
||||
/// Returns the number of rows in this Table
|
||||
pub async fn count_rows(&self) -> Result<usize> {
|
||||
Ok(self.dataset.count_rows().await?)
|
||||
@@ -338,6 +342,27 @@ impl Table {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn update(
|
||||
&mut self,
|
||||
predicate: Option<&str>,
|
||||
updates: Vec<(&str, &str)>,
|
||||
) -> Result<()> {
|
||||
let mut builder = UpdateBuilder::new(self.dataset.clone());
|
||||
if let Some(predicate) = predicate {
|
||||
builder = builder.update_where(predicate)?;
|
||||
}
|
||||
|
||||
for (column, value) in updates {
|
||||
builder = builder.set(column, value)?;
|
||||
}
|
||||
|
||||
let operation = builder.build()?;
|
||||
let new_ds = operation.execute().await?;
|
||||
self.dataset = new_ds;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Remove old versions of the dataset from disk.
|
||||
///
|
||||
/// # Arguments
|
||||
@@ -413,11 +438,14 @@ mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow_array::{
|
||||
Array, FixedSizeListArray, Float32Array, Int32Array, RecordBatch, RecordBatchIterator,
|
||||
RecordBatchReader,
|
||||
Array, BooleanArray, Date32Array, FixedSizeListArray, Float32Array, Float64Array,
|
||||
Int32Array, Int64Array, LargeStringArray, RecordBatch, RecordBatchIterator,
|
||||
RecordBatchReader, StringArray, TimestampMillisecondArray, TimestampNanosecondArray,
|
||||
UInt32Array,
|
||||
};
|
||||
use arrow_data::ArrayDataBuilder;
|
||||
use arrow_schema::{DataType, Field, Schema};
|
||||
use arrow_schema::{DataType, Field, Schema, TimeUnit};
|
||||
use futures::TryStreamExt;
|
||||
use lance::dataset::{Dataset, WriteMode};
|
||||
use lance::index::vector::pq::PQBuildParams;
|
||||
use lance::io::object_store::{ObjectStoreParams, WrappingObjectStore};
|
||||
@@ -540,6 +568,272 @@ mod tests {
|
||||
assert_eq!(table.name, "test");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_update_with_predicate() {
|
||||
let tmp_dir = tempdir().unwrap();
|
||||
let dataset_path = tmp_dir.path().join("test.lance");
|
||||
let uri = dataset_path.to_str().unwrap();
|
||||
|
||||
let schema = Arc::new(Schema::new(vec![
|
||||
Field::new("id", DataType::Int32, false),
|
||||
Field::new("name", DataType::Utf8, false),
|
||||
]));
|
||||
|
||||
let record_batch_iter = RecordBatchIterator::new(
|
||||
vec![RecordBatch::try_new(
|
||||
schema.clone(),
|
||||
vec![
|
||||
Arc::new(Int32Array::from_iter_values(0..10)),
|
||||
Arc::new(StringArray::from_iter_values(vec![
|
||||
"a", "b", "c", "d", "e", "f", "g", "h", "i", "j",
|
||||
])),
|
||||
],
|
||||
)
|
||||
.unwrap()]
|
||||
.into_iter()
|
||||
.map(Ok),
|
||||
schema.clone(),
|
||||
);
|
||||
|
||||
Dataset::write(record_batch_iter, uri, None).await.unwrap();
|
||||
let mut table = Table::open(uri).await.unwrap();
|
||||
|
||||
table
|
||||
.update(Some("id > 5"), vec![("name", "'foo'")])
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let ds_after = Dataset::open(uri).await.unwrap();
|
||||
let mut batches = ds_after
|
||||
.scan()
|
||||
.project(&["id", "name"])
|
||||
.unwrap()
|
||||
.try_into_stream()
|
||||
.await
|
||||
.unwrap()
|
||||
.try_collect::<Vec<_>>()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
while let Some(batch) = batches.pop() {
|
||||
let ids = batch
|
||||
.column(0)
|
||||
.as_any()
|
||||
.downcast_ref::<Int32Array>()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.collect::<Vec<_>>();
|
||||
let names = batch
|
||||
.column(1)
|
||||
.as_any()
|
||||
.downcast_ref::<StringArray>()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.collect::<Vec<_>>();
|
||||
for (i, name) in names.iter().enumerate() {
|
||||
let id = ids[i].unwrap();
|
||||
let name = name.unwrap();
|
||||
if id > 5 {
|
||||
assert_eq!(name, "foo");
|
||||
} else {
|
||||
assert_eq!(name, &format!("{}", (b'a' + id as u8) as char));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_update_all_types() {
|
||||
let tmp_dir = tempdir().unwrap();
|
||||
let dataset_path = tmp_dir.path().join("test.lance");
|
||||
let uri = dataset_path.to_str().unwrap();
|
||||
|
||||
let schema = Arc::new(Schema::new(vec![
|
||||
Field::new("int32", DataType::Int32, false),
|
||||
Field::new("int64", DataType::Int64, false),
|
||||
Field::new("uint32", DataType::UInt32, false),
|
||||
Field::new("string", DataType::Utf8, false),
|
||||
Field::new("large_string", DataType::LargeUtf8, false),
|
||||
Field::new("float32", DataType::Float32, false),
|
||||
Field::new("float64", DataType::Float64, false),
|
||||
Field::new("bool", DataType::Boolean, false),
|
||||
Field::new("date32", DataType::Date32, false),
|
||||
Field::new(
|
||||
"timestamp_ns",
|
||||
DataType::Timestamp(TimeUnit::Nanosecond, None),
|
||||
false,
|
||||
),
|
||||
Field::new(
|
||||
"timestamp_ms",
|
||||
DataType::Timestamp(TimeUnit::Millisecond, None),
|
||||
false,
|
||||
),
|
||||
Field::new(
|
||||
"vec_f32",
|
||||
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 2),
|
||||
false,
|
||||
),
|
||||
Field::new(
|
||||
"vec_f64",
|
||||
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float64, true)), 2),
|
||||
false,
|
||||
),
|
||||
]));
|
||||
|
||||
let record_batch_iter = RecordBatchIterator::new(
|
||||
vec![RecordBatch::try_new(
|
||||
schema.clone(),
|
||||
vec![
|
||||
Arc::new(Int32Array::from_iter_values(0..10)),
|
||||
Arc::new(Int64Array::from_iter_values(0..10)),
|
||||
Arc::new(UInt32Array::from_iter_values(0..10)),
|
||||
Arc::new(StringArray::from_iter_values(vec![
|
||||
"a", "b", "c", "d", "e", "f", "g", "h", "i", "j",
|
||||
])),
|
||||
Arc::new(LargeStringArray::from_iter_values(vec![
|
||||
"a", "b", "c", "d", "e", "f", "g", "h", "i", "j",
|
||||
])),
|
||||
Arc::new(Float32Array::from_iter_values(
|
||||
(0..10).into_iter().map(|i| i as f32),
|
||||
)),
|
||||
Arc::new(Float64Array::from_iter_values(
|
||||
(0..10).into_iter().map(|i| i as f64),
|
||||
)),
|
||||
Arc::new(Into::<BooleanArray>::into(vec![
|
||||
true, false, true, false, true, false, true, false, true, false,
|
||||
])),
|
||||
Arc::new(Date32Array::from_iter_values(0..10)),
|
||||
Arc::new(TimestampNanosecondArray::from_iter_values(0..10)),
|
||||
Arc::new(TimestampMillisecondArray::from_iter_values(0..10)),
|
||||
Arc::new(
|
||||
create_fixed_size_list(
|
||||
Float32Array::from_iter_values((0..20).into_iter().map(|i| i as f32)),
|
||||
2,
|
||||
)
|
||||
.unwrap(),
|
||||
),
|
||||
Arc::new(
|
||||
create_fixed_size_list(
|
||||
Float64Array::from_iter_values((0..20).into_iter().map(|i| i as f64)),
|
||||
2,
|
||||
)
|
||||
.unwrap(),
|
||||
),
|
||||
],
|
||||
)
|
||||
.unwrap()]
|
||||
.into_iter()
|
||||
.map(Ok),
|
||||
schema.clone(),
|
||||
);
|
||||
|
||||
Dataset::write(record_batch_iter, uri, None).await.unwrap();
|
||||
let mut table = Table::open(uri).await.unwrap();
|
||||
|
||||
// check it can do update for each type
|
||||
let updates: Vec<(&str, &str)> = vec![
|
||||
("string", "'foo'"),
|
||||
("large_string", "'large_foo'"),
|
||||
("int32", "1"),
|
||||
("int64", "1"),
|
||||
("uint32", "1"),
|
||||
("float32", "1.0"),
|
||||
("float64", "1.0"),
|
||||
("bool", "true"),
|
||||
("date32", "1"),
|
||||
("timestamp_ns", "1"),
|
||||
("timestamp_ms", "1"),
|
||||
("vec_f32", "[1.0, 1.0]"),
|
||||
("vec_f64", "[1.0, 1.0]"),
|
||||
];
|
||||
|
||||
// for (column, value) in test_cases {
|
||||
table.update(None, updates).await.unwrap();
|
||||
|
||||
let ds_after = Dataset::open(uri).await.unwrap();
|
||||
let mut batches = ds_after
|
||||
.scan()
|
||||
.project(&[
|
||||
"string",
|
||||
"large_string",
|
||||
"int32",
|
||||
"int64",
|
||||
"uint32",
|
||||
"float32",
|
||||
"float64",
|
||||
"bool",
|
||||
"date32",
|
||||
"timestamp_ns",
|
||||
"timestamp_ms",
|
||||
"vec_f32",
|
||||
"vec_f64",
|
||||
])
|
||||
.unwrap()
|
||||
.try_into_stream()
|
||||
.await
|
||||
.unwrap()
|
||||
.try_collect::<Vec<_>>()
|
||||
.await
|
||||
.unwrap();
|
||||
let batch = batches.pop().unwrap();
|
||||
|
||||
macro_rules! assert_column {
|
||||
($column:expr, $array_type:ty, $expected:expr) => {
|
||||
let array = $column
|
||||
.as_any()
|
||||
.downcast_ref::<$array_type>()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.collect::<Vec<_>>();
|
||||
for v in array {
|
||||
assert_eq!(v, Some($expected));
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
assert_column!(batch.column(0), StringArray, "foo");
|
||||
assert_column!(batch.column(1), LargeStringArray, "large_foo");
|
||||
assert_column!(batch.column(2), Int32Array, 1);
|
||||
assert_column!(batch.column(3), Int64Array, 1);
|
||||
assert_column!(batch.column(4), UInt32Array, 1);
|
||||
assert_column!(batch.column(5), Float32Array, 1.0);
|
||||
assert_column!(batch.column(6), Float64Array, 1.0);
|
||||
assert_column!(batch.column(7), BooleanArray, true);
|
||||
assert_column!(batch.column(8), Date32Array, 1);
|
||||
assert_column!(batch.column(9), TimestampNanosecondArray, 1);
|
||||
assert_column!(batch.column(10), TimestampMillisecondArray, 1);
|
||||
|
||||
let array = batch
|
||||
.column(11)
|
||||
.as_any()
|
||||
.downcast_ref::<FixedSizeListArray>()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.collect::<Vec<_>>();
|
||||
for v in array {
|
||||
let v = v.unwrap();
|
||||
let f32array = v.as_any().downcast_ref::<Float32Array>().unwrap();
|
||||
for v in f32array {
|
||||
assert_eq!(v, Some(1.0));
|
||||
}
|
||||
}
|
||||
|
||||
let array = batch
|
||||
.column(12)
|
||||
.as_any()
|
||||
.downcast_ref::<FixedSizeListArray>()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.collect::<Vec<_>>();
|
||||
for v in array {
|
||||
let v = v.unwrap();
|
||||
let f64array = v.as_any().downcast_ref::<Float64Array>().unwrap();
|
||||
for v in f64array {
|
||||
assert_eq!(v, Some(1.0));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_search() {
|
||||
let tmp_dir = tempdir().unwrap();
|
||||
@@ -554,8 +848,8 @@ mod tests {
|
||||
let table = Table::open(uri).await.unwrap();
|
||||
|
||||
let vector = Float32Array::from_iter_values([0.1, 0.2]);
|
||||
let query = table.search(vector.clone());
|
||||
assert_eq!(vector, query.query_vector);
|
||||
let query = table.search(Some(vector.clone()));
|
||||
assert_eq!(vector, query.query_vector.unwrap());
|
||||
}
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
|
||||
Reference in New Issue
Block a user